Generating an estimate for each sample with cross_val_predict (2:57)
Last modified: 2023-02-27 | YouTube
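Generates cross-validated estimates for every point in the dataset.
Each sample belongs to exactly one test set, and its prediction is made by the estimator fitted on the corresponding training set.
These results should not be passed to an evaluation metric to measure generalization performance: the pooled predictions come from several different fitted models, so a single score computed over all of them is not necessarily a valid estimate.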
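To make the mechanism concrete, the loop below rebuilds the same predictions by hand. It is a sketch, assuming that `cv=3` with a regressor and no groups resolves to an unshuffled `KFold(n_splits=3)`, which is how scikit-learn interprets an integer `cv` for non-classifiers; the names `manual_pred` and `auto_pred` are illustrative only.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.base import clone
from sklearn.model_selection import KFold, cross_val_predict

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

# Assumption: for a regressor, cv=3 means an unshuffled KFold with 3 folds
kfold = KFold(n_splits=3)

manual_pred = np.empty_like(y, dtype=float)
for train_idx, test_idx in kfold.split(X):
    # Fit a fresh, unfitted copy of the estimator on the training folds only
    est = clone(lasso).fit(X[train_idx], y[train_idx])
    # Predict the held-out fold; every sample is written exactly once
    manual_pred[test_idx] = est.predict(X[test_idx])

auto_pred = cross_val_predict(lasso, X, y, cv=3)
print(np.allclose(manual_pred, auto_pred))  # True
```

Because the test folds partition the data, each position of `manual_pred` is filled by exactly one model, and that model never saw the sample during fitting.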
[1]:
from sklearn import datasets, linear_model
from sklearn.model_selection import cross_val_predict
diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()
y_pred = cross_val_predict(
    # -------------------------------------------------------------------------
    # The object to use to fit the data. Must implement fit().
    estimator=lasso,
    # -------------------------------------------------------------------------
    # The data to fit. Can be, for example, a list or an array.
    X=X,
    # -------------------------------------------------------------------------
    # The target variable to try to predict in the case of supervised learning.
    y=y,
    # -------------------------------------------------------------------------
    # Group labels for the samples used while splitting the dataset into
    # train/test set. Only used in conjunction with a "Group" cv instance
    # (e.g., GroupKFold).
    groups=None,
    # -------------------------------------------------------------------------
    # Determines the cross-validation splitting strategy. An int gives the
    # number of folds (KFold for regressors, StratifiedKFold for classifiers).
    cv=3,
    # -------------------------------------------------------------------------
    # The verbosity level.
    verbose=0,
    # -------------------------------------------------------------------------
    # Parameters to pass to the fit method of the estimator. This argument is
    # `params` in scikit-learn >= 1.4; older releases call it `fit_params`.
    params=None,
    # -------------------------------------------------------------------------
    # The method to be invoked by the estimator (the estimator must implement
    # it; Lasso only provides 'predict'):
    # * 'predict'
    # * 'predict_proba'
    # * 'predict_log_proba'
    # * 'decision_function'
    method='predict',
)
y_pred
[1]:
array([174.26880804, 117.65351812, 164.60139517, 155.64896614,
132.68760433, 128.49628881, 120.76054246, 141.06973577,
164.18815157, 182.37359235, 111.04089197, 127.94301803,
135.08790036, 162.83048118, 135.35840097, 157.646062 ,
178.95718913, 163.39063561, 143.8537171 , 144.29750064,
133.58233051, 124.77837092, 132.90970661, 208.52815059,
153.61883822, 154.16510156, 118.95406457, 163.50463573,
145.89310221, 168.33041101, 155.87407921, 123.45950035,
185.70488949, 133.38607278, 117.27952346, 150.27889968,
174.15239599, 160.03182658, 192.31482952, 161.58472303,
154.22198208, 119.3542199 , 146.15701123, 133.82133127,
179.68058182, 137.96512787, 146.07828391, 126.77695128,
123.32171096, 166.26579186, 146.41482006, 161.67179165,
147.47956035, 138.44754455, 144.85639621, 113.77894392,
185.54822207, 115.3184268 , 142.23537263, 171.07810595,
132.53792741, 177.804149 , 116.56316924, 134.25390247,
142.88636755, 173.2810216 , 154.3131706 , 149.16866669,
144.88198422, 121.97620921, 110.38264649, 180.256848 ,
199.05936395, 151.12103871, 161.14084613, 153.96825789,
150.77351299, 113.30728681, 165.15799964, 115.85653979,
174.19357338, 150.11981272, 115.47896755, 153.38941372,
115.31491321, 156.50136756, 92.62132743, 178.15703742,
131.59414072, 134.4595406 , 116.97477767, 190.00997209,
166.01073896, 126.26094663, 134.29467801, 144.72144347,
190.97839825, 182.39093099, 154.45282532, 148.30386965,
151.72181945, 124.12933168, 138.60217036, 137.75932827,
123.09198296, 131.74775137, 112.07413774, 124.56829499,
156.78556054, 128.63113028, 93.68261606, 130.54248822,
131.86857136, 154.57121467, 179.8136611 , 165.78203933,
150.04784865, 162.3802658 , 143.92902775, 143.15527739,
125.20186268, 145.99563268, 155.35031858, 145.97631553,
134.66102139, 163.92609605, 101.92269793, 139.32878073,
122.714843 , 152.20573137, 153.36795743, 116.76441525,
131.96912343, 109.7467123 , 132.57442844, 159.38054395,
109.31271822, 147.69895686, 156.36777035, 161.12384083,
128.16629451, 156.784652 , 154.04498872, 124.83843865,
143.85525003, 143.23635903, 147.76264284, 154.21725374,
129.08004207, 157.79641814])
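The warning about evaluation metrics can be checked on these same data. The sketch below uses `r2_score` as an illustrative metric: scoring `y_pred` one test fold at a time reproduces the per-fold scores from `cross_val_score` (assuming both functions see the same unshuffled 3-fold split), whereas scoring all of `y_pred` at once mixes predictions from three different models and generally yields a different number.

```python
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold, cross_val_predict, cross_val_score

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

# Per-fold scores: each R2 is computed on one test fold by its own fitted model
fold_scores = cross_val_score(lasso, X, y, cv=3, scoring="r2")

# Out-of-fold predictions for every sample, as in the cell above
y_pred = cross_val_predict(lasso, X, y, cv=3)

# Re-scoring y_pred fold by fold recovers the per-fold scores
refolded = [r2_score(y[test], y_pred[test]) for _, test in KFold(n_splits=3).split(X)]

# Pooled score: one R2 over predictions that come from three different models
pooled_r2 = r2_score(y, y_pred)

print("per-fold R2     :", fold_scores)
print("refolded R2     :", np.array(refolded))
print("mean per-fold R2:", fold_scores.mean())
print("pooled R2       :", pooled_r2)
```

The pooled predictions remain useful as a diagnostic (for example, plotting predicted against observed values), while the per-fold scores are the quantities that cross-validation is designed to average.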
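The `method` argument changes what is collected for each held-out sample. A minimal sketch with a classifier, using the iris dataset and `LogisticRegression` purely as stand-ins (the Lasso above only implements `predict`): with `method='predict_proba'` the result has one row per sample and one column per class, in sorted label order.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=1000)

# With a classifier and an integer cv, the folds are stratified by class
proba = cross_val_predict(clf, X, y, cv=3, method="predict_proba")

print(proba.shape)                            # (150, 3)
print(np.allclose(proba.sum(axis=1), 1.0))    # True: each row is a distribution
```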