Generating per-sample estimates with cross_val_predict (2:57)

  • Last modified: 2023-02-27 | YouTube

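  • Generates a cross-validated estimate for every point in the dataset.

  • Each sample belongs to exactly one test set, and its prediction comes from an estimator fitted on the corresponding training set (the remaining folds).

  • These predictions should not be passed to an evaluation metric as a measure of generalization performance: they come from several different models, so a score computed on the pooled predictions is not, in general, the cross-validated score.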
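The procedure described above can be reproduced by hand as a sanity check. This sketch assumes that, for a regressor and an integer `cv`, `cross_val_predict` splits the data with an unshuffled `KFold(n_splits=cv)` (scikit-learn's default for regressors); `manual_cross_val_predict` is a hypothetical helper written only for this illustration, not part of scikit-learn.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.base import clone
from sklearn.model_selection import KFold, cross_val_predict


def manual_cross_val_predict(estimator, X, y, n_splits=3):
    """Hypothetical helper: out-of-fold predictions using an unshuffled KFold."""
    y_pred = np.empty(len(y), dtype=float)
    for train_idx, test_idx in KFold(n_splits=n_splits).split(X):
        # Fit a fresh, unfitted copy of the estimator on the training folds only.
        model = clone(estimator).fit(X[train_idx], y[train_idx])
        # Every sample lands in exactly one test fold, so it is predicted exactly once.
        y_pred[test_idx] = model.predict(X[test_idx])
    return y_pred


diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]

manual = manual_cross_val_predict(linear_model.Lasso(), X, y, n_splits=3)
builtin = cross_val_predict(linear_model.Lasso(), X, y, cv=3)
print(np.allclose(manual, builtin))
```

Because `Lasso` is deterministic and both versions fit on the same three folds, the two arrays should agree element by element.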

[1]:
from sklearn import datasets, linear_model
from sklearn.model_selection import cross_val_predict

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]

lasso = linear_model.Lasso()

y_pred = cross_val_predict(
    # -------------------------------------------------------------------------
    # The object to use to fit the data. Must implement fit()
    estimator=lasso,
    # -------------------------------------------------------------------------
    # The data to fit. Can be for example a list, or an array.
    X=X,
    # -------------------------------------------------------------------------
    # The target variable to try to predict in the case of supervised learning.
    y=y,
    # -------------------------------------------------------------------------
    # Group labels for the samples used while splitting the dataset into
    # train/test set. Only used in conjunction with a “Group” cv instance
    # (e.g., GroupKFold).
    groups=None,
    # -------------------------------------------------------------------------
    # Determines the cross-validation splitting strategy. An int is the number
    # of folds (an unshuffled KFold for a regressor such as Lasso).
    cv=3,
    # -------------------------------------------------------------------------
    # The verbosity level.
    verbose=0,
    # -------------------------------------------------------------------------
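    # Parameters to pass to the fit method of the estimator. Older releases
    # accepted them as `fit_params`, which scikit-learn 1.4 deprecated in favor
    # of `params` and later releases removed, so the default (None) is left
    # implicit here to keep the call working across versions.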
    # -------------------------------------------------------------------------
    # The method to be invoked by estimator (Lasso is a regressor, so only
    # 'predict' is available here):
    # * 'predict'
    # * 'predict_proba'
    # * 'predict_log_proba'
    # * 'decision_function'
    method='predict',
)

y_pred
[1]:
array([174.26880804, 117.65351812, 164.60139517, 155.64896614,
       132.68760433, 128.49628881, 120.76054246, 141.06973577,
       164.18815157, 182.37359235, 111.04089197, 127.94301803,
       135.08790036, 162.83048118, 135.35840097, 157.646062  ,
       178.95718913, 163.39063561, 143.8537171 , 144.29750064,
       133.58233051, 124.77837092, 132.90970661, 208.52815059,
       153.61883822, 154.16510156, 118.95406457, 163.50463573,
       145.89310221, 168.33041101, 155.87407921, 123.45950035,
       185.70488949, 133.38607278, 117.27952346, 150.27889968,
       174.15239599, 160.03182658, 192.31482952, 161.58472303,
       154.22198208, 119.3542199 , 146.15701123, 133.82133127,
       179.68058182, 137.96512787, 146.07828391, 126.77695128,
       123.32171096, 166.26579186, 146.41482006, 161.67179165,
       147.47956035, 138.44754455, 144.85639621, 113.77894392,
       185.54822207, 115.3184268 , 142.23537263, 171.07810595,
       132.53792741, 177.804149  , 116.56316924, 134.25390247,
       142.88636755, 173.2810216 , 154.3131706 , 149.16866669,
       144.88198422, 121.97620921, 110.38264649, 180.256848  ,
       199.05936395, 151.12103871, 161.14084613, 153.96825789,
       150.77351299, 113.30728681, 165.15799964, 115.85653979,
       174.19357338, 150.11981272, 115.47896755, 153.38941372,
       115.31491321, 156.50136756,  92.62132743, 178.15703742,
       131.59414072, 134.4595406 , 116.97477767, 190.00997209,
       166.01073896, 126.26094663, 134.29467801, 144.72144347,
       190.97839825, 182.39093099, 154.45282532, 148.30386965,
       151.72181945, 124.12933168, 138.60217036, 137.75932827,
       123.09198296, 131.74775137, 112.07413774, 124.56829499,
       156.78556054, 128.63113028,  93.68261606, 130.54248822,
       131.86857136, 154.57121467, 179.8136611 , 165.78203933,
       150.04784865, 162.3802658 , 143.92902775, 143.15527739,
       125.20186268, 145.99563268, 155.35031858, 145.97631553,
       134.66102139, 163.92609605, 101.92269793, 139.32878073,
       122.714843  , 152.20573137, 153.36795743, 116.76441525,
       131.96912343, 109.7467123 , 132.57442844, 159.38054395,
       109.31271822, 147.69895686, 156.36777035, 161.12384083,
       128.16629451, 156.784652  , 154.04498872, 124.83843865,
       143.85525003, 143.23635903, 147.76264284, 154.21725374,
       129.08004207, 157.79641814])
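The caveat in the last bullet can be made concrete with the same data. With `cv=3` on 150 samples every test fold has 50 samples, so the mean squared error of the pooled `cross_val_predict` output equals the average of the per-fold MSEs returned by `cross_val_score`. R² is not an average of per-sample losses (its denominator uses the mean of `y` within each fold), so the pooled value and the per-fold average generally differ. This is a sketch assuming both functions use the same default `KFold` splits.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import cross_val_predict, cross_val_score

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

y_pred = cross_val_predict(lasso, X, y, cv=3)

# MSE on the pooled out-of-fold predictions vs. the mean of the per-fold MSEs.
# With three equal folds of 50 samples these two numbers coincide.
mse_pooled = mean_squared_error(y, y_pred)
mse_per_fold = -cross_val_score(lasso, X, y, cv=3, scoring="neg_mean_squared_error")
print(mse_pooled, mse_per_fold.mean())

# R^2 depends on each fold's own mean of y, so pooling the predictions
# usually gives a different number than averaging the per-fold scores.
r2_pooled = r2_score(y, y_pred)
r2_per_fold = cross_val_score(lasso, X, y, cv=3, scoring="r2")
print(r2_pooled, r2_per_fold.mean())
```

For reporting generalization performance, `cross_val_score` or `cross_validate` is the appropriate tool; `cross_val_predict` is better suited to inspecting or plotting per-sample out-of-fold predictions.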