Predicting diabetes progression with PySpark (MLlib: RDD-based)#

  • 30 min | Last modified: November 6, 2020

Problem definition#

Regression models are widely used to build forecasting tools for decision making. In this particular case, a physician would like a twelve-month forecast of the progression of diabetes in their patients, based on physical variables and laboratory tests, in order to choose better treatments. See https://www4.stat.ncsu.edu/~boos/var.select/diabetes.html

The dataset contains ten baseline variables (age, sex, body mass index, blood pressure, and six blood serum measurements) for 442 patients, together with an index that measures the progression of the disease one year after baseline. The data are stored in the file diabetes.csv. The goal is to build a regression model that forecasts the progression of the disease from the available information.

Preparing the data file#

[1]:
#
# Download the dataset
#
!wget https://raw.githubusercontent.com/jdvelasq/datalabs/master/datasets/diabetes.csv
--2020-11-05 20:26:21--  https://raw.githubusercontent.com/jdvelasq/datalabs/master/datasets/diabetes.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 199.232.48.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|199.232.48.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 88791 (87K) [text/plain]
Saving to: ‘diabetes.csv’

diabetes.csv        100%[===================>]  86.71K  --.-KB/s    in 0.1s

2020-11-05 20:26:21 (800 KB/s) - ‘diabetes.csv’ saved [88791/88791]

[2]:
#
# File contents
#
!head diabetes.csv
age,sex,bmi,bp,s1,s2,s3,s4,s5,s6,Y
0.0380759064334241,0.0506801187398187,0.0616962065186885,0.0218723549949558,-0.0442234984244464,-0.0348207628376986,-0.0434008456520269,-0.00259226199818282,0.0199084208763183,-0.0176461251598052,151.0
-0.00188201652779104,-0.044641636506989,-0.0514740612388061,-0.0263278347173518,-0.00844872411121698,-0.019163339748222,0.0744115640787594,-0.0394933828740919,-0.0683297436244215,-0.09220404962683,75.0
0.0852989062966783,0.0506801187398187,0.0444512133365941,-0.00567061055493425,-0.0455994512826475,-0.0341944659141195,-0.0323559322397657,-0.00259226199818282,0.00286377051894013,-0.0259303389894746,141.0
-0.0890629393522603,-0.044641636506989,-0.0115950145052127,-0.0366564467985606,0.0121905687618,0.0249905933641021,-0.0360375700438527,0.0343088588777263,0.0226920225667445,-0.0093619113301358,206.0
0.00538306037424807,-0.044641636506989,-0.0363846922044735,0.0218723549949558,0.00393485161259318,0.0155961395104161,0.0081420836051921,-0.00259226199818282,-0.0319914449413559,-0.0466408735636482,135.0
-0.0926954778032799,-0.044641636506989,-0.0406959404999971,-0.0194420933298793,-0.0689906498720667,-0.0792878444118122,0.0412768238419757,-0.076394503750001,-0.0411803851880079,-0.0963461565416647,97.0
-0.0454724779400257,0.0506801187398187,-0.0471628129432825,-0.015999222636143,-0.040095639849843,-0.0248000120604336,0.000778807997017968,-0.0394933828740919,-0.0629129499162512,-0.0383566597339788,138.0
0.063503675590561,0.0506801187398187,-0.00189470584028465,0.0666296740135272,0.0906198816792644,0.108914381123697,0.0228686348215404,0.0177033544835672,-0.0358167281015492,0.00306440941436832,63.0
0.0417084448844436,0.0506801187398187,0.0616962065186885,-0.0400993174922969,-0.0139525355440215,0.00620168565673016,-0.0286742944356786,-0.00259226199818282,-0.0149564750249113,0.0113486232440377,110.0
[3]:
#
# Remove the header row from the file
#
!sed '1d'  diabetes.csv >  diabetes0.csv
[4]:
#
# Copy the data file to HDFS
#
!hdfs dfs -copyFromLocal diabetes0.csv /tmp/diabetes.csv
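To confirm that the file actually landed in HDFS, the standard client can list it (an optional check; it assumes the hdfs client is available on the PATH):

[ ]:
#
# Optional check: verify that the file is now in HDFS
#
!hdfs dfs -ls /tmp/diabetes.csv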

Spark initialization#

[5]:
#
# Load the Spark libraries
#
import findspark
from pyspark.sql import SparkSession

from pyspark import SparkConf, SparkContext

findspark.init()

APP_NAME = "spark-app"

conf = SparkConf().setAppName(APP_NAME)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)
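As a side note, the same session could be obtained with the builder API; the following is an equivalent sketch, meant as an alternative to the cell above rather than something to run in addition to it:

[ ]:
#
# Equivalent initialization with the builder API
# (an alternative to the cell above, not to be run in addition to it)
#
spark = SparkSession.builder.appName(APP_NAME).getOrCreate()
sc = spark.sparkContext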

Loading the data#

[6]:
rdd = sc.textFile("/tmp/diabetes.csv")

rdd = rdd.map(lambda w: w.split(","))
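At this point every field of the RDD is still a string; the cells below pass these strings to MLlib, which converts them to floating-point values when it builds its vectors. A more defensive sketch casts the fields explicitly right after the split:

[ ]:
#
# Optional: cast every field to float right after the split,
# instead of relying on MLlib's implicit numeric conversion
#
numeric_rdd = rdd.map(lambda row: [float(value) for value in row])
numeric_rdd.first()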

Exploratory analysis#

[7]:
#
# Basic statistics
#
import numpy as np
import pandas as pd
from pyspark.mllib.stat import Statistics

x = rdd.map(lambda w: np.array(w))
summary = Statistics.colStats(x)

pd.DataFrame(
    {
        "Mean": summary.mean(),
        "Var": summary.variance(),
        "Max": summary.max(),
        "Min": summary.min(),
        "Count": summary.count(),
        "numNonZero": summary.numNonzeros(),
    },
    index=["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6", "Y"],
).transpose()
[7]:
age sex bmi bp s1 s2 s3 s4 s5 s6 Y
Mean -3.642919e-16 1.288032e-16 -8.023096e-16 1.281527e-16 -8.977194e-17 1.322727e-16 -4.565575e-16 3.872770e-16 -3.885781e-16 -3.404395e-16 152.133484
Var 2.267574e-03 2.267574e-03 2.267574e-03 2.267574e-03 2.267574e-03 2.267574e-03 2.267574e-03 2.267574e-03 2.267574e-03 2.267574e-03 5943.331348
Max 1.107267e-01 5.068012e-02 1.705552e-01 1.320442e-01 1.539137e-01 1.987880e-01 1.811791e-01 1.852344e-01 1.335990e-01 1.356118e-01 346.000000
Min -1.072256e-01 -4.464164e-02 -9.027530e-02 -1.123996e-01 -1.267807e-01 -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01 25.000000
Count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 442.000000
numNonZero 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 442.000000
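Since colStats reports variances, the standard deviation of each column follows directly; a one-line sketch:

[ ]:
#
# Standard deviation of each column, derived from the variances
#
pd.Series(
    np.sqrt(summary.variance()),
    index=["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6", "Y"],
)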
[8]:
#
# Histogram.
# It shows that most patients have low values of the
# progression index, and that higher values become
# increasingly less frequent.
#
# PySpark's histogram function computes the data
# needed to draw the histogram, but it does not
# plot the histogram itself.
#
h = rdd.map(lambda w: float(w[10])).histogram(11)
h
[8]:
([25.0,
  54.18181818181819,
  83.36363636363637,
  112.54545454545455,
  141.72727272727275,
  170.9090909090909,
  200.0909090909091,
  229.27272727272728,
  258.4545454545455,
  287.6363636363636,
  316.8181818181818,
  346.0],
 [33, 68, 71, 52, 46, 51, 29, 36, 35, 15, 6])
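histogram also accepts explicit bucket boundaries instead of a bucket count; the boundaries below are hypothetical, chosen only for illustration:

[ ]:
#
# histogram with explicit (hypothetical) bucket boundaries
#
rdd.map(lambda w: float(w[10])).histogram([25.0, 100.0, 175.0, 250.0, 346.0])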
[ ]:
!pip3 install -q  matplotlib
[9]:
#
# Pandas must be used to draw
# the plot
#
pd.DataFrame(list(zip(*h)), columns=["bin", "frequency"]).set_index("bin").plot(
    kind="bar"
);
(Figure: bar chart of the histogram, bin vs. frequency.)
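Alternatively, the bins and counts returned by histogram can be plotted directly with matplotlib; a minimal sketch:

[ ]:
#
# Alternative: plot the histogram data directly with matplotlib
#
import matplotlib.pyplot as plt

bins, counts = h
plt.bar(bins[:-1], counts, width=np.diff(bins), align="edge", edgecolor="white")
plt.xlabel("Y")
plt.ylabel("frequency")
plt.show()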
[10]:
#
# Correlation matrix for the numeric
# variables in the RDD
#
Statistics.corr(rdd, method="pearson")
[10]:
array([[ 1.        ,  0.1737371 ,  0.18508467,  0.33542671,  0.26006082,
         0.21924314, -0.07518097,  0.2038409 ,  0.27077678,  0.30173101,
         0.18788875],
       [ 0.1737371 ,  1.        ,  0.0881614 ,  0.24101317,  0.03527682,
         0.14263726, -0.37908963,  0.33211509,  0.14991756,  0.20813322,
         0.043062  ],
       [ 0.18508467,  0.0881614 ,  1.        ,  0.39541532,  0.24977742,
         0.26116991, -0.36681098,  0.4138066 ,  0.44615865,  0.38867999,
         0.58645013],
       [ 0.33542671,  0.24101317,  0.39541532,  1.        ,  0.24246971,
         0.18555783, -0.17876121,  0.25765337,  0.39347814,  0.39042938,
         0.44148385],
       [ 0.26006082,  0.03527682,  0.24977742,  0.24246971,  1.        ,
         0.89666296,  0.05151936,  0.54220728,  0.51550076,  0.32571675,
         0.21202248],
       [ 0.21924314,  0.14263726,  0.26116991,  0.18555783,  0.89666296,
         1.        , -0.19645512,  0.65981689,  0.3183534 ,  0.29060038,
         0.17405359],
       [-0.07518097, -0.37908963, -0.36681098, -0.17876121,  0.05151936,
        -0.19645512,  1.        , -0.73849273, -0.398577  , -0.2736973 ,
        -0.39478925],
       [ 0.2038409 ,  0.33211509,  0.4138066 ,  0.25765337,  0.54220728,
         0.65981689, -0.73849273,  1.        ,  0.61785739,  0.41721211,
         0.43045288],
       [ 0.27077678,  0.14991756,  0.44615865,  0.39347814,  0.51550076,
         0.3183534 , -0.398577  ,  0.61785739,  1.        ,  0.46467046,
         0.56588343],
       [ 0.30173101,  0.20813322,  0.38867999,  0.39042938,  0.32571675,
         0.29060038, -0.2736973 ,  0.41721211,  0.46467046,  1.        ,
         0.38248348],
       [ 0.18788875,  0.043062  ,  0.58645013,  0.44148385,  0.21202248,
         0.17405359, -0.39478925,  0.43045288,  0.56588343,  0.38248348,
         1.        ]])
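Statistics.corr also accepts two RDDs of floats and returns a single coefficient. For instance, the correlation between bmi (column 2) and the target Y, matching the 0.586 entry in the matrix above:

[ ]:
#
# Correlation between a single feature (bmi) and the target Y
#
bmi = rdd.map(lambda w: float(w[2]))
y = rdd.map(lambda w: float(w[10]))
Statistics.corr(bmi, y, method="pearson")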

Data preparation#

[12]:
from pyspark.mllib.regression import LabeledPoint

x = rdd.map(lambda w: LabeledPoint(w[10], w[:10]))
print(x.collect()[0:5])
[LabeledPoint(151.0, [0.0380759064334241,0.0506801187398187,0.0616962065186885,0.0218723549949558,-0.0442234984244464,-0.0348207628376986,-0.0434008456520269,-0.00259226199818282,0.0199084208763183,-0.0176461251598052]), LabeledPoint(75.0, [-0.00188201652779104,-0.044641636506989,-0.0514740612388061,-0.0263278347173518,-0.00844872411121698,-0.019163339748222,0.0744115640787594,-0.0394933828740919,-0.0683297436244215,-0.09220404962683]), LabeledPoint(141.0, [0.0852989062966783,0.0506801187398187,0.0444512133365941,-0.00567061055493425,-0.0455994512826475,-0.0341944659141195,-0.0323559322397657,-0.00259226199818282,0.00286377051894013,-0.0259303389894746]), LabeledPoint(206.0, [-0.0890629393522603,-0.044641636506989,-0.0115950145052127,-0.0366564467985606,0.0121905687618,0.0249905933641021,-0.0360375700438527,0.0343088588777263,0.0226920225667445,-0.0093619113301358]), LabeledPoint(135.0, [0.00538306037424807,-0.044641636506989,-0.0363846922044735,0.0218723549949558,0.00393485161259318,0.0155961395104161,0.0081420836051921,-0.00259226199818282,-0.0319914449413559,-0.0466408735636482])]
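In a real modeling exercise the LabeledPoint RDD would normally be split into training and test sets before fitting; the cells below train on the full dataset. A minimal sketch using randomSplit, with an arbitrary 70/30 proportion and seed:

[ ]:
#
# Optional train/test split (proportions and seed are arbitrary)
#
train, test = x.randomSplit([0.7, 0.3], seed=12345)
train.count(), test.count()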

Model estimation#

[13]:
import numpy as np
from pyspark.mllib.regression import LinearRegressionWithSGD

model = LinearRegressionWithSGD.train(
    x,
    iterations=1000,
    intercept=True,
    initialWeights=np.array([1.0] * 10),
    step=0.1,
    regType="l2",
    regParam=0.0,
)

print("Parameters:")
print("---------------------------------")
print("Intercept:", model.intercept)
print("Weights:", model.weights)
print()

x.map(lambda lp: (lp.label, model.predict(lp.features))).collect()[:5]
Parameters:
---------------------------------
Intercept: 136.51915849333383
Weights: [2.509049930309272,1.3287507854099163,5.7661262578376675,4.579821109698119,2.6933893176651207,2.3834423738656616,-2.2049661818403274,4.477153030983179,5.589557700419026,4.091438278742303]
[13]:
[(151.0, 137.05902197948032),
 (75.0, 134.86923255784592),
 (141.0, 136.79619488005662),
 (206.0, 136.41563914422062),
 (135.0, 136.01228532658578)]
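
A natural next step is to quantify the fit. The notebook stops at the sample predictions above; a minimal sketch of the mean squared error over the training data:

[ ]:
#
# Mean squared error of the model over the training data
#
mse = x.map(lambda lp: (lp.label - model.predict(lp.features)) ** 2).mean()
mse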