Breast cancer diagnosis using neural networks#

  • 60 min | Last modified: November 6, 2020.

Bibliography. * Machine Learning with R. Brett Lantz, Packt Publishing, Second Edition, 2015.

Problem description#

The goal is to determine whether a breast mass is a benign or a malignant tumor from measurements obtained from digitized images of a fine needle aspirate. The values describe the characteristics of the cell nuclei present in the digital image. The sample contains 569 examples of biopsy results. Each record contains 32 variables: an identification number, the diagnosis, and 30 measurements, which correspond to three statistics (mean, standard error in the `_se` columns, worst case) for each of ten different features (radius, texture, …).

  • Identification number

  • Cancer diagnosis (“M” for malignant and “B” for benign)

  • Radius

  • Texture

  • Perimeter

  • Area

  • Smoothness

  • Compactness

  • Concavity

  • Concave points

  • Symmetry

  • Fractal dimension

In terms of the data, the goal is to predict whether a mass is benign or malignant (class B or M) from the 30 numeric variables.
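
The 30 predictors follow a regular naming pattern: ten features times three statistics. The sketch below rebuilds that list, assuming the column names and suffixes (`_mean`, `_se`, `_worst`) that `printSchema()` reports for the CSV file. The names `FEATURES`, `STATISTICS`, and `feature_columns` are introduced here only for illustration.

```python
# Ten nuclear features measured on each image (names as they appear in the CSV header).
FEATURES = [
    "radius", "texture", "perimeter", "area", "smoothness",
    "compactness", "concavity", "concave_points", "symmetry",
    "fractal_dimension",
]
# Three statistics reported for every feature.
STATISTICS = ["mean", "se", "worst"]

# All "_mean" columns come first, then "_se", then "_worst".
feature_columns = [f"{feature}_{stat}" for stat in STATISTICS for feature in FEATURES]

print(len(feature_columns))   # 30
print(feature_columns[:3])
```

Together with `id` and `diagnosis`, these 30 columns account for the 32 variables of each record.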

Data source: https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)

In mathematical terms, the problem is defined as follows.

  • There are M examples (the 569 observations of the problem analyzed).

  • Each example is described by a set of variables (x_1, x_2, …, x_N); that is, the N = 30 data columns.

  • Each example belongs to one class, and there are P different classes; in this case there are only two classes: benign or malignant.

  • For a new case (tumor), and based on its 30 measurements (variables), the goal is to predict which class it belongs to (malignant or benign).

Data preparation and loading#

[1]:
!wget https://raw.githubusercontent.com/jdvelasq/datalabs/master/datasets/wisc_bc_data.csv
--2020-11-02 02:01:13--  https://raw.githubusercontent.com/jdvelasq/datalabs/master/datasets/wisc_bc_data.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 199.232.48.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|199.232.48.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 125773 (123K) [text/plain]
Saving to: ‘wisc_bc_data.csv’

wisc_bc_data.csv    100%[===================>] 122.83K  --.-KB/s    in 0.1s

2020-11-02 02:01:14 (1.07 MB/s) - ‘wisc_bc_data.csv’ saved [125773/125773]

[2]:
##
## Copy the data file to HDFS
##
!hdfs dfs -copyFromLocal wisc_bc_data.csv /tmp/wisc_bc_data.csv

Spark#

[3]:
##
## Load the Spark libraries
##
import findspark

# findspark.init() adds pyspark to sys.path, so it runs before the pyspark imports
findspark.init()

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

APP_NAME = "spark-app"

conf = SparkConf().setAppName(APP_NAME)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)

Data loading#

[4]:
##
## Read the file.
##
spark_df = spark.read.load(
    "/tmp/wisc_bc_data.csv", format="csv", sep=",", inferSchema="true", header="true"
)


##
## Data types of the DataFrame fields
##
spark_df.printSchema()
root
 |-- id: integer (nullable = true)
 |-- diagnosis: string (nullable = true)
 |-- radius_mean: double (nullable = true)
 |-- texture_mean: double (nullable = true)
 |-- perimeter_mean: double (nullable = true)
 |-- area_mean: double (nullable = true)
 |-- smoothness_mean: double (nullable = true)
 |-- compactness_mean: double (nullable = true)
 |-- concavity_mean: double (nullable = true)
 |-- concave_points_mean: double (nullable = true)
 |-- symmetry_mean: double (nullable = true)
 |-- fractal_dimension_mean: double (nullable = true)
 |-- radius_se: double (nullable = true)
 |-- texture_se: double (nullable = true)
 |-- perimeter_se: double (nullable = true)
 |-- area_se: double (nullable = true)
 |-- smoothness_se: double (nullable = true)
 |-- compactness_se: double (nullable = true)
 |-- concavity_se: double (nullable = true)
 |-- concave_points_se: double (nullable = true)
 |-- symmetry_se: double (nullable = true)
 |-- fractal_dimension_se: double (nullable = true)
 |-- radius_worst: double (nullable = true)
 |-- texture_worst: double (nullable = true)
 |-- perimeter_worst: double (nullable = true)
 |-- area_worst: double (nullable = true)
 |-- smoothness_worst: double (nullable = true)
 |-- compactness_worst: double (nullable = true)
 |-- concavity_worst: double (nullable = true)
 |-- concave_points_worst: double (nullable = true)
 |-- symmetry_worst: double (nullable = true)
 |-- fractal_dimension_worst: double (nullable = true)

[5]:
##
## File contents
##
spark_df.show()
+--------+---------+-----------+------------+--------------+---------+---------------+----------------+--------------+-------------------+-------------+----------------------+---------+----------+------------+-------+-------------+--------------+------------+-----------------+-----------+--------------------+------------+-------------+---------------+----------+----------------+-----------------+---------------+--------------------+--------------+-----------------------+
|      id|diagnosis|radius_mean|texture_mean|perimeter_mean|area_mean|smoothness_mean|compactness_mean|concavity_mean|concave_points_mean|symmetry_mean|fractal_dimension_mean|radius_se|texture_se|perimeter_se|area_se|smoothness_se|compactness_se|concavity_se|concave_points_se|symmetry_se|fractal_dimension_se|radius_worst|texture_worst|perimeter_worst|area_worst|smoothness_worst|compactness_worst|concavity_worst|concave_points_worst|symmetry_worst|fractal_dimension_worst|
+--------+---------+-----------+------------+--------------+---------+---------------+----------------+--------------+-------------------+-------------+----------------------+---------+----------+------------+-------+-------------+--------------+------------+-----------------+-----------+--------------------+------------+-------------+---------------+----------+----------------+-----------------+---------------+--------------------+--------------+-----------------------+
|  842302|        M|      17.99|       10.38|         122.8|   1001.0|         0.1184|          0.2776|        0.3001|             0.1471|       0.2419|               0.07871|    1.095|    0.9053|       8.589|  153.4|     0.006399|       0.04904|     0.05373|          0.01587|    0.03003|            0.006193|       25.38|        17.33|          184.6|    2019.0|          0.1622|           0.6656|         0.7119|              0.2654|        0.4601|                 0.1189|
|  842517|        M|      20.57|       17.77|         132.9|   1326.0|        0.08474|         0.07864|        0.0869|            0.07017|       0.1812|               0.05667|   0.5435|    0.7339|       3.398|  74.08|     0.005225|       0.01308|      0.0186|           0.0134|    0.01389|            0.003532|       24.99|        23.41|          158.8|    1956.0|          0.1238|           0.1866|         0.2416|               0.186|         0.275|                0.08902|
|84300903|        M|      19.69|       21.25|         130.0|   1203.0|         0.1096|          0.1599|        0.1974|             0.1279|       0.2069|               0.05999|   0.7456|    0.7869|       4.585|  94.03|      0.00615|       0.04006|     0.03832|          0.02058|     0.0225|            0.004571|       23.57|        25.53|          152.5|    1709.0|          0.1444|           0.4245|         0.4504|               0.243|        0.3613|                0.08758|
|84348301|        M|      11.42|       20.38|         77.58|    386.1|         0.1425|          0.2839|        0.2414|             0.1052|       0.2597|               0.09744|   0.4956|     1.156|       3.445|  27.23|      0.00911|       0.07458|     0.05661|          0.01867|    0.05963|            0.009208|       14.91|         26.5|          98.87|     567.7|          0.2098|           0.8663|         0.6869|              0.2575|        0.6638|                  0.173|
|84358402|        M|      20.29|       14.34|         135.1|   1297.0|         0.1003|          0.1328|         0.198|             0.1043|       0.1809|               0.05883|   0.7572|    0.7813|       5.438|  94.44|      0.01149|       0.02461|     0.05688|          0.01885|    0.01756|            0.005115|       22.54|        16.67|          152.2|    1575.0|          0.1374|            0.205|            0.4|              0.1625|        0.2364|                0.07678|
|  843786|        M|      12.45|        15.7|         82.57|    477.1|         0.1278|            0.17|        0.1578|            0.08089|       0.2087|               0.07613|   0.3345|    0.8902|       2.217|  27.19|      0.00751|       0.03345|     0.03672|          0.01137|    0.02165|            0.005082|       15.47|        23.75|          103.4|     741.6|          0.1791|           0.5249|         0.5355|              0.1741|        0.3985|                 0.1244|
|  844359|        M|      18.25|       19.98|         119.6|   1040.0|        0.09463|           0.109|        0.1127|              0.074|       0.1794|               0.05742|   0.4467|    0.7732|        3.18|  53.91|     0.004314|       0.01382|     0.02254|          0.01039|    0.01369|            0.002179|       22.88|        27.66|          153.2|    1606.0|          0.1442|           0.2576|         0.3784|              0.1932|        0.3063|                0.08368|
|84458202|        M|      13.71|       20.83|          90.2|    577.9|         0.1189|          0.1645|       0.09366|            0.05985|       0.2196|               0.07451|   0.5835|     1.377|       3.856|  50.96|     0.008805|       0.03029|     0.02488|          0.01448|    0.01486|            0.005412|       17.06|        28.14|          110.6|     897.0|          0.1654|           0.3682|         0.2678|              0.1556|        0.3196|                 0.1151|
|  844981|        M|       13.0|       21.82|          87.5|    519.8|         0.1273|          0.1932|        0.1859|            0.09353|        0.235|               0.07389|   0.3063|     1.002|       2.406|  24.32|     0.005731|       0.03502|     0.03553|          0.01226|    0.02143|            0.003749|       15.49|        30.73|          106.2|     739.3|          0.1703|           0.5401|          0.539|               0.206|        0.4378|                 0.1072|
|84501001|        M|      12.46|       24.04|         83.97|    475.9|         0.1186|          0.2396|        0.2273|            0.08543|        0.203|               0.08243|   0.2976|     1.599|       2.039|  23.94|     0.007149|       0.07217|     0.07743|          0.01432|    0.01789|             0.01008|       15.09|        40.68|          97.65|     711.4|          0.1853|            1.058|          1.105|               0.221|        0.4366|                 0.2075|
|  845636|        M|      16.02|       23.24|         102.7|    797.8|        0.08206|         0.06669|       0.03299|            0.03323|       0.1528|               0.05697|   0.3795|     1.187|       2.466|  40.51|     0.004029|      0.009269|     0.01101|         0.007591|     0.0146|            0.003042|       19.19|        33.88|          123.8|    1150.0|          0.1181|           0.1551|         0.1459|             0.09975|        0.2948|                0.08452|
|84610002|        M|      15.78|       17.89|         103.6|    781.0|         0.0971|          0.1292|       0.09954|            0.06606|       0.1842|               0.06082|   0.5058|    0.9849|       3.564|  54.16|     0.005771|       0.04061|     0.02791|          0.01282|    0.02008|            0.004144|       20.42|        27.28|          136.5|    1299.0|          0.1396|           0.5609|         0.3965|               0.181|        0.3792|                 0.1048|
|  846226|        M|      19.17|        24.8|         132.4|   1123.0|         0.0974|          0.2458|        0.2065|             0.1118|       0.2397|                 0.078|   0.9555|     3.568|       11.07|  116.2|     0.003139|       0.08297|      0.0889|           0.0409|    0.04484|             0.01284|       20.96|        29.94|          151.7|    1332.0|          0.1037|           0.3903|         0.3639|              0.1767|        0.3176|                 0.1023|
|  846381|        M|      15.85|       23.95|         103.7|    782.7|        0.08401|          0.1002|       0.09938|            0.05364|       0.1847|               0.05338|   0.4033|     1.078|       2.903|  36.58|     0.009769|       0.03126|     0.05051|          0.01992|    0.02981|            0.003002|       16.84|        27.66|          112.0|     876.5|          0.1131|           0.1924|         0.2322|              0.1119|        0.2809|                0.06287|
|84667401|        M|      13.73|       22.61|          93.6|    578.3|         0.1131|          0.2293|        0.2128|            0.08025|       0.2069|               0.07682|   0.2121|     1.169|       2.061|  19.21|     0.006429|       0.05936|     0.05501|          0.01628|    0.01961|            0.008093|       15.03|        32.01|          108.8|     697.7|          0.1651|           0.7725|         0.6943|              0.2208|        0.3596|                 0.1431|
|84799002|        M|      14.54|       27.54|         96.73|    658.8|         0.1139|          0.1595|        0.1639|            0.07364|       0.2303|               0.07077|     0.37|     1.033|       2.879|  32.55|     0.005607|        0.0424|     0.04741|           0.0109|    0.01857|            0.005466|       17.46|        37.13|          124.1|     943.2|          0.1678|           0.6577|         0.7026|              0.1712|        0.4218|                 0.1341|
|  848406|        M|      14.68|       20.13|         94.74|    684.5|        0.09867|           0.072|       0.07395|            0.05259|       0.1586|               0.05922|   0.4727|      1.24|       3.195|   45.4|     0.005718|       0.01162|     0.01998|          0.01109|     0.0141|            0.002085|       19.07|        30.88|          123.4|    1138.0|          0.1464|           0.1871|         0.2914|              0.1609|        0.3029|                0.08216|
|84862001|        M|      16.13|       20.68|         108.1|    798.8|          0.117|          0.2022|        0.1722|             0.1028|       0.2164|               0.07356|   0.5692|     1.073|       3.854|  54.18|     0.007026|       0.02501|     0.03188|          0.01297|    0.01689|            0.004142|       20.96|        31.48|          136.8|    1315.0|          0.1789|           0.4233|         0.4784|              0.2073|        0.3706|                 0.1142|
|  849014|        M|      19.81|       22.15|         130.0|   1260.0|        0.09831|          0.1027|        0.1479|            0.09498|       0.1582|               0.05395|   0.7582|     1.017|       5.865|  112.4|     0.006494|       0.01893|     0.03391|          0.01521|    0.01356|            0.001997|       27.32|        30.88|          186.8|    2398.0|          0.1512|            0.315|         0.5372|              0.2388|        0.2768|                0.07615|
| 8510426|        B|      13.54|       14.36|         87.46|    566.3|        0.09779|         0.08129|       0.06664|            0.04781|       0.1885|               0.05766|   0.2699|    0.7886|       2.058|  23.56|     0.008462|        0.0146|     0.02387|          0.01315|     0.0198|              0.0023|       15.11|        19.26|           99.7|     711.2|           0.144|           0.1773|          0.239|              0.1288|        0.2977|                0.07259|
+--------+---------+-----------+------------+--------------+---------+---------------+----------------+--------------+-------------------+-------------+----------------------+---------+----------+------------+-------+-------------+--------------+------------+-----------------+-----------+--------------------+------------+-------------+---------------+----------+----------------+-----------------+---------------+--------------------+--------------+-----------------------+
only showing top 20 rows

[6]:
##
## Number of records loaded
##
spark_df.count()
[6]:
569

Exploratory analysis#

[7]:
##
## Number of cases for each diagnosis (bar chart).
##
spark_df.groupby("diagnosis").count().toPandas().set_index("diagnosis").plot.bar()
[7]:
<AxesSubplot:xlabel='diagnosis'>
[Figure: bar chart of the number of cases per diagnosis]
[8]:
##
## Number of cases for each diagnosis (table).
##
spark_df.groupby("diagnosis").count().toPandas()
[8]:
  diagnosis  count
0         B    357
1         M    212
[9]:
##
## Percentage of cases in each class
##
pdf = spark_df.groupby("diagnosis").count().toPandas().set_index("diagnosis")
round(100 * pdf["count"] / pdf["count"].sum(), 1)
[9]:
diagnosis
B    62.7
M    37.3
Name: count, dtype: float64
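
These proportions also give a reference point for the models trained later: a classifier that always predicts the majority class is right about 62.7% of the time. The sketch below reproduces the percentages from the counts shown above and computes that baseline error. The majority-class baseline is an illustration added here, not part of the original analysis, and the variable names are hypothetical.

```python
# Class counts taken from the groupby output above.
counts = {"B": 357, "M": 212}
total = sum(counts.values())

# Same percentages as the pandas computation above.
percentages = {label: round(100 * n / total, 1) for label, n in counts.items()}

# Error of a trivial classifier that always predicts the most frequent class.
majority_label = max(counts, key=counts.get)
baseline_error = 1 - counts[majority_label] / total

print(percentages)
print(majority_label, round(baseline_error, 3))
```

Any useful model should achieve an error well below this baseline of roughly 0.373.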

Data preparation#

[10]:
##
## Feature columns: every column of the DataFrame except diagnosis and id
##
inputCols = [a for a, _ in spark_df.dtypes]
inputCols.remove("diagnosis")
inputCols.remove("id")
len(inputCols)
[10]:
30
[11]:
from pyspark.ml.feature import VectorAssembler

vectorAssembler = VectorAssembler(
    inputCols=inputCols,
    outputCol="features",
)

spark_df = vectorAssembler.transform(spark_df)

spark_df.select("features").show()
+--------------------+
|            features|
+--------------------+
|[17.99,10.38,122....|
|[20.57,17.77,132....|
|[19.69,21.25,130....|
|[11.42,20.38,77.5...|
|[20.29,14.34,135....|
|[12.45,15.7,82.57...|
|[18.25,19.98,119....|
|[13.71,20.83,90.2...|
|[13.0,21.82,87.5,...|
|[12.46,24.04,83.9...|
|[16.02,23.24,102....|
|[15.78,17.89,103....|
|[19.17,24.8,132.4...|
|[15.85,23.95,103....|
|[13.73,22.61,93.6...|
|[14.54,27.54,96.7...|
|[14.68,20.13,94.7...|
|[16.13,20.68,108....|
|[19.81,22.15,130....|
|[13.54,14.36,87.4...|
+--------------------+
only showing top 20 rows

[12]:
import matplotlib.pyplot as plt
import seaborn as sns

##
## The ranges of the numeric variables
## are quite different
##
plt.figure(figsize=(10, 6))
pdf = spark_df.toPandas()
pdf.pop('id')
sns.boxplot(data=pdf)
plt.xticks(rotation=90);
[Figure: box plots of the numeric variables]
[13]:
##
## Scaling
##
from pyspark.ml.feature import MinMaxScaler

scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures")
scalerModel = scaler.fit(spark_df)

spark_df = scalerModel.transform(spark_df)

spark_df.select("scaledFeatures").show()
+--------------------+
|      scaledFeatures|
+--------------------+
|[0.52103743669837...|
|[0.64314449335037...|
|[0.60149557480240...|
|[0.21009039708457...|
|[0.62989256472147...|
|[0.25883856311231...|
|[0.53334279899663...|
|[0.31847224194235...|
|[0.28486913720478...|
|[0.25931184627762...|
|[0.42780065313076...|
|[0.41644185716314...|
|[0.57688485020587...|
|[0.41975483932036...|
|[0.31941880827298...|
|[0.35775474466373...|
|[0.36438070897818...|
|[0.43300676794926...|
|[0.60717497278621...|
|[0.31042642813195...|
+--------------------+
only showing top 20 rows
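
`MinMaxScaler` rescales each feature linearly so that its minimum maps to 0 and its maximum to 1, i.e. x' = (x - min) / (max - min). The sketch below is a minimal pure-Python illustration of that formula, not Spark's implementation. `min_max_scale` is a helper introduced here, and mapping a constant column to the midpoint of the range follows what Spark's documentation describes.

```python
def min_max_scale(values, new_min=0.0, new_max=1.0):
    """Rescale a list of numbers linearly to [new_min, new_max]."""
    lo, hi = min(values), max(values)
    if hi == lo:
        # Constant column: use the midpoint of the target range.
        return [0.5 * (new_min + new_max)] * len(values)
    return [(v - lo) / (hi - lo) * (new_max - new_min) + new_min for v in values]


# First five radius_mean values from the table shown earlier (a toy sample,
# so the results differ from the full-data scaledFeatures output).
radius_sample = [17.99, 20.57, 19.69, 11.42, 20.29]
scaled = min_max_scale(radius_sample)
print([round(v, 3) for v in scaled])
```

Because the minimum and maximum are estimated from the data passed to `fit`, the scaled values depend on which rows the scaler sees.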

[14]:
##
## Replace {M,B} with {0,1}
##
from pyspark.sql.functions import when

spark_df = spark_df.withColumn('diagnosis', when(spark_df['diagnosis'] == 'M', 0).otherwise(spark_df['diagnosis']))
spark_df = spark_df.withColumn('diagnosis', when(spark_df['diagnosis'] == 'B', 1).otherwise(spark_df['diagnosis']))

spark_df.groupby("diagnosis").count().toPandas()
[14]:
  diagnosis  count
0         0    212
1         1    357
[15]:
##
## Cast the diagnosis column from string to integer
##
from pyspark.sql.types import IntegerType

spark_df = spark_df.withColumn("diagnosis", spark_df["diagnosis"].cast(IntegerType()))

Training and test data sets#

[16]:
##
## Use approximately 80% of the data for training
## and the remaining 20% for testing
##
(train_df, test_df) = spark_df.randomSplit([0.8, 0.2])
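
`randomSplit` treats the weights as probabilities rather than exact quotas, so the split sizes are only approximately 80/20. The confusion matrices reported further down add up to 445 training rows and 124 test rows (about 78% and 22%). The sketch below illustrates that idea in pure Python, assuming each row is assigned independently with probability 0.8. It illustrates the behavior and is not Spark's implementation; `random_split` and its seed are hypothetical.

```python
import random


def random_split(rows, train_fraction=0.8, seed=None):
    """Assign each row independently to train (prob. train_fraction) or test."""
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_fraction else test).append(row)
    return train, test


rows = list(range(569))  # stand-ins for the 569 records
train, test = random_split(rows, seed=1234)
print(len(train), len(test))  # close to, but usually not exactly, 455 / 114
```

Since the notebook calls `randomSplit` without a seed, each run produces a different partition and therefore slightly different error figures.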

Model training#

[17]:
from pyspark.ml.classification import MultilayerPerceptronClassifier

##
## Model creation
##
trainer = MultilayerPerceptronClassifier(
    featuresCol="scaledFeatures",
    labelCol="diagnosis",
    predictionCol="prediction_MLP",
    probabilityCol="probability_MLP",
    rawPredictionCol="rawPrediction_MLP",
    maxIter=100,
    layers=[30, 1, 2],  # 30 inputs, 1 hidden neuron, 2 classes
    seed=1234,
)

##
## Training
##
model = trainer.fit(train_df)

##
## Prediction
##
train_df = model.transform(train_df)
test_df = model.transform(test_df)
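
The `layers` list fixes the size of the network. The sketch below works through the arithmetic, assuming fully connected layers with one bias per neuron; `count_parameters` is a helper introduced here for illustration.

```python
def count_parameters(layers):
    """Number of weights and biases in a fully connected feed-forward network."""
    return sum((n_in + 1) * n_out for n_in, n_out in zip(layers, layers[1:]))


print(count_parameters([30, 1, 2]))   # (30 + 1) * 1 + (1 + 1) * 2 = 35
print(count_parameters([30, 5, 2]))   # (30 + 1) * 5 + (5 + 1) * 2 = 167
```

With a single hidden neuron the model has 35 parameters. Each additional hidden neuron adds 33 more (30 input weights, 1 bias, 2 output weights), which is worth keeping in mind when choosing the hidden layer size for a data set of 569 examples.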

Model performance#

[18]:
import pyspark.sql.functions as F
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.types import FloatType
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


def print_stats(dataframes, prediction_col):
    # dataframes: [train, test] DataFrames that already contain prediction_col

    evaluator = MulticlassClassificationEvaluator(
        labelCol="diagnosis",
        predictionCol=prediction_col,
        metricName="accuracy",
    )

    # Error rate = 1 - accuracy, computed on the DataFrames passed as arguments
    for name, df in zip(("Train", "Test"), dataframes):
        print("%s Error = %g " % (name, 1.0 - evaluator.evaluate(df)))
    print()

    for df in dataframes:

        # MulticlassMetrics expects an RDD of (prediction, label) pairs of floats
        predictionAndLabels = df.select([prediction_col, "diagnosis"])
        predictionAndLabels = predictionAndLabels.withColumn(
            prediction_col, F.col(prediction_col).cast(FloatType())
        )
        predictionAndLabels = predictionAndLabels.withColumn(
            "diagnosis", F.col("diagnosis").cast(FloatType())
        )

        metrics = MulticlassMetrics(predictionAndLabels.rdd.map(tuple))

        # Rows are actual labels, columns are predictions, ordered 0 (M), 1 (B)
        print(metrics.confusionMatrix().toArray())
        print()

print_stats(dataframes=[train_df, test_df], prediction_col="prediction_MLP")
Train Error = 0
Test Error = 0.0645161

[[167.   0.]
 [  0. 278.]]

[[40.  5.]
 [ 3. 76.]]
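
The reported error rates can be recovered directly from the confusion matrices: the error is the off-diagonal count divided by the total. The sketch below uses the two matrices printed above; `error_rate` is a helper introduced here.

```python
def error_rate(confusion):
    """Fraction of misclassified examples: off-diagonal count over total."""
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(len(confusion)))
    return (total - correct) / total


mlp_train = [[167, 0], [0, 278]]
mlp_test = [[40, 5], [3, 76]]

print("%g" % error_rate(mlp_train))  # 0
print("%g" % error_rate(mlp_test))   # 0.0645161
```

The test error corresponds to 8 misclassified cases out of 124.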

Linear Support Vector Classifier#

[19]:
from pyspark.ml.classification import LinearSVC

##
## Model creation
##
trainer = LinearSVC(
    featuresCol="scaledFeatures",
    labelCol="diagnosis",
    predictionCol="prediction_SVC",
    rawPredictionCol="rawPrediction_SVC",
    maxIter=100,
    regParam=0.0,
)

##
## Training
##
model = trainer.fit(train_df)

##
## Prediction
##
train_df = model.transform(train_df)
test_df = model.transform(test_df)

print_stats(dataframes=[train_df, test_df], prediction_col = 'prediction_SVC')
Train Error = 0.0179775
Test Error = 0.0403226

[[161.   6.]
 [  2. 276.]]

[[42.  3.]
 [ 2. 77.]]
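
In a diagnostic setting the two kinds of error are not equally costly, so it can help to look at the malignant class separately. The sketch below assumes that rows of the confusion matrix are actual labels and columns are predictions, ordered 0 (malignant), 1 (benign), which is the convention `MulticlassMetrics` documents. `malignant_recall` is a name introduced here.

```python
def malignant_recall(confusion):
    """Share of actual malignant cases (row 0) predicted as malignant (column 0)."""
    true_malignant, missed = confusion[0]
    return true_malignant / (true_malignant + missed)


# Test-set confusion matrices printed above for each model.
test_matrices = {
    "MLP": [[40, 5], [3, 76]],
    "LinearSVC": [[42, 3], [2, 77]],
}

for name, matrix in test_matrices.items():
    print(name, round(malignant_recall(matrix), 3))
```

On this particular split the linear SVC misses 3 of 45 malignant cases against 5 of 45 for the one-neuron MLP. Because `randomSplit` was called without a seed, these figures will change from run to run.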

Exercise. What is the optimal number of neurons for a single hidden layer?