Exhaustive hyperparameter search with GridSearchCV

In many cases, models have several hyperparameters that control their configuration and how their parameters are estimated. For example, in the polynomial-fitting example, the degree n is a hyperparameter. This tutorial shows how to approach the problem when more than one hyperparameter must be tuned.
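As a point of comparison, the single-hyperparameter case from the polynomial example can be written with the same tool. The sketch below is an assumption-laden illustration, not the original polynomial tutorial: the data are synthetic (a noisy cubic), and the names `degree_search`, `x` and `y` are hypothetical. `make_pipeline` names each step after its lowercased class, which is why the grid key is `polynomialfeatures__degree`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

# Synthetic data (assumption): a noisy cubic sampled on [-1, 1].
rng = np.random.default_rng(0)
x = np.linspace(-1.0, 1.0, 60).reshape(-1, 1)
y = 1.0 - 2.0 * x[:, 0] + 3.0 * x[:, 0] ** 3 + rng.normal(scale=0.1, size=60)

# The polynomial degree is the only hyperparameter being tuned here.
pipeline = make_pipeline(PolynomialFeatures(), LinearRegression())
degree_search = GridSearchCV(
    estimator=pipeline,
    param_grid={"polynomialfeatures__degree": [1, 2, 3, 4, 5, 6]},
    # x is sorted, so shuffle before splitting; contiguous folds would
    # force the model to extrapolate outside the training range.
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    scoring="neg_mean_squared_error",
)
degree_search.fit(x, y)
print(degree_search.best_params_)
```

With one hyperparameter the grid is a single list; the rest of this tutorial deals with grids over several hyperparameters at once.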

Configuring the search

[1]:
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

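#
# An SVM is used here. The parameters that can be tuned depend on the kernel:
# gamma, for example, only affects the 'rbf' (and 'poly'/'sigmoid') kernels,
# so the 'linear' kernel gets its own grid without it.
#
# The variable param_grid is a list of dictionaries; each dictionary maps a
# parameter name to the list of values to try.
#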
param_grid = [
    # -------------------------------------------------------------------------
    # First parameter grid
    {
        "kernel": ["rbf"],
        "gamma": [1e-3, 1e-4],
        "C": [1, 10, 100, 1000],
    },
    # -------------------------------------------------------------------------
    # Second parameter grid
    {
        "kernel": ["linear"],
        "C": [1, 10, 100, 1000],
    },
]

gridSearchCV = GridSearchCV(
    # --------------------------------------------------------------------------
    # This is assumed to implement the scikit-learn estimator interface.
    estimator=SVC(),
    # --------------------------------------------------------------------------
    # Dictionary with parameter names (str) as keys and lists of parameter
    # settings to try as values, or a list of such dictionaries.
    param_grid=param_grid,
    # --------------------------------------------------------------------------
    # Determines the cross-validation splitting strategy. An integer gives the
    # number of folds (stratified folds for classifiers).
    cv=5,
    # --------------------------------------------------------------------------
    # Strategy to evaluate the performance of the cross-validated model on
    # each held-out validation fold.
    scoring="accuracy",
    # --------------------------------------------------------------------------
    # Refit an estimator with the best parameters found, using all the data
    # passed to fit() (here, X_train).
    refit=True,
    # --------------------------------------------------------------------------
    # If False, the cv_results_ attribute will not include training scores.
    return_train_score=False,
)
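Each dictionary in param_grid is expanded separately into the Cartesian product of its value lists, and the resulting candidates are concatenated; the two sub-grids are never mixed, so the linear kernel is never tried with a gamma value. The sketch below reproduces that expansion in plain Python. The helper `expand_grid` is a hypothetical name (scikit-learn itself does this with `sklearn.model_selection.ParameterGrid`); sorting the keys is an assumption chosen because it matches the candidate order shown later in `cv_results_["params"]`.

```python
from itertools import product

param_grid = [
    {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
    {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
]


def expand_grid(grid):
    """Expand a list of {name: [values]} dicts into one dict per candidate."""
    candidates = []
    for sub_grid in grid:
        keys = sorted(sub_grid)  # sorted names: C, gamma, kernel
        for values in product(*(sub_grid[k] for k in keys)):
            candidates.append(dict(zip(keys, values)))
    return candidates


candidates = expand_grid(param_grid)
n_folds = 5
print(len(candidates))            # 12 (8 rbf + 4 linear)
print(len(candidates) * n_folds)  # 60 cross-validation fits
```

With refit=True, one more fit on the full training set follows the 60 cross-validation fits.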

Main metrics available for scoring

The full list is available at:

https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter

  • Classification:

    • 'accuracy'

    • 'balanced_accuracy'

  • Regression:

    • 'neg_mean_absolute_error'

    • 'neg_mean_squared_error'

    • 'neg_root_mean_squared_error'

    • 'r2'
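The `neg_` prefix on the regression error metrics exists because the search always keeps the candidate with the largest score; negating an error turns "smallest error" into "largest score". The plain-Python sketch below illustrates that convention with made-up targets and predictions for two hypothetical models; `mean_squared_error` here is a local helper, not the scikit-learn function.

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared differences (local helper)."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


y_true = [3.0, -0.5, 2.0, 7.0]
# Hypothetical predictions from two candidate models.
predictions = {
    "model_a": [2.5, 0.0, 2.0, 8.0],
    "model_b": [3.0, -0.5, 2.5, 7.0],
}

# Negate the error so that "greater is better", then take the maximum.
scores = {name: -mean_squared_error(y_true, p) for name, p in predictions.items()}
best = max(scores, key=scores.get)
print(scores)  # {'model_a': -0.375, 'model_b': -0.0625}
print(best)    # model_b, the one with the smaller MSE
```

'r2' needs no prefix because a larger R² is already better.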

Data preparation

[2]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()

n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.5,
    random_state=0,
)
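The digits dataset contains 1797 images of 8×8 pixels, so `reshape((n_samples, -1))` yields 64 feature columns. The arithmetic below sketches the resulting split sizes, assuming scikit-learn's rule for a float `test_size` of rounding the test count up and giving the remainder to training; the variable names are illustrative.

```python
import math

n_samples = 1797                 # images in load_digits()
image_shape = (8, 8)
n_features = image_shape[0] * image_shape[1]  # columns after reshape((n_samples, -1))

test_size = 0.5
n_test = math.ceil(test_size * n_samples)     # test fraction rounded up (assumed rule)
n_train = n_samples - n_test

print(n_features, n_train, n_test)  # 64 898 899
```

The 898 training samples agree with the length of the prediction array for X_train shown at the end of this tutorial.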

Running the search

[3]:
gridSearchCV.fit(X_train, y_train)
[3]:
GridSearchCV(cv=5, estimator=SVC(),
             param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                          'kernel': ['rbf']},
                         {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
             scoring='accuracy')
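With `cv=5`, each candidate is trained on four folds of X_train and scored on the remaining one. For a classifier, an integer `cv` means stratified folds, and stratification can shift individual fold sizes; however, the fold accuracies printed in `cv_results_` (multiples of 1/180 in splits 0 to 2 and of 1/179 in splits 3 and 4) are consistent with the plain k-fold size rule sketched below. The helper `kfold_sizes` is a hypothetical name.

```python
def kfold_sizes(n_samples, n_splits):
    """Fold sizes in the KFold style: the first n_samples % n_splits folds get one extra sample."""
    base, extra = divmod(n_samples, n_splits)
    return [base + 1 if i < extra else base for i in range(n_splits)]


sizes = kfold_sizes(898, 5)
print(sizes)  # [180, 180, 180, 179, 179]

# A fold accuracy is correct / fold size; e.g. 176 correct out of 179
# matches the split3 value 0.98324022 of the first candidate.
print(176 / 179)
```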

Returned values

[4]:
gridSearchCV.cv_results_
[4]:
{'mean_fit_time': array([0.01856437, 0.01342072, 0.01756868, 0.00850973, 0.01629653,
        0.00868969, 0.01653719, 0.00836463, 0.00683012, 0.00677333,
        0.00679779, 0.00681438]),
 'std_fit_time': array([2.89455612e-03, 3.50194241e-04, 1.72907452e-03, 6.64449268e-05,
        1.29181919e-04, 6.61261387e-04, 8.92717381e-04, 1.37359679e-04,
        1.62231088e-04, 1.20423768e-04, 1.60550795e-04, 1.62279714e-04]),
 'mean_score_time': array([0.00502486, 0.00571933, 0.00491943, 0.00344534, 0.00470152,
        0.00343714, 0.00461111, 0.00336585, 0.00163321, 0.00157847,
        0.0015861 , 0.00159302]),
 'std_score_time': array([5.80807515e-04, 1.87975631e-04, 3.26420638e-04, 2.64588341e-05,
        2.11185229e-04, 1.61336117e-04, 7.61566715e-05, 6.71735789e-05,
        5.87190450e-05, 2.21226748e-05, 4.56135557e-05, 6.10881441e-05]),
 'param_C': masked_array(data=[1, 1, 10, 10, 100, 100, 1000, 1000, 1, 10, 100, 1000],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False],
        fill_value='?',
             dtype=object),
 'param_gamma': masked_array(data=[0.001, 0.0001, 0.001, 0.0001, 0.001, 0.0001, 0.001,
                    0.0001, --, --, --, --],
              mask=[False, False, False, False, False, False, False, False,
                     True,  True,  True,  True],
        fill_value='?',
             dtype=object),
 'param_kernel': masked_array(data=['rbf', 'rbf', 'rbf', 'rbf', 'rbf', 'rbf', 'rbf', 'rbf',
                    'linear', 'linear', 'linear', 'linear'],
              mask=[False, False, False, False, False, False, False, False,
                    False, False, False, False],
        fill_value='?',
             dtype=object),
 'params': [{'C': 1, 'gamma': 0.001, 'kernel': 'rbf'},
  {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'},
  {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'},
  {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'},
  {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'},
  {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'},
  {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'},
  {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'},
  {'C': 1, 'kernel': 'linear'},
  {'C': 10, 'kernel': 'linear'},
  {'C': 100, 'kernel': 'linear'},
  {'C': 1000, 'kernel': 'linear'}],
 'split0_test_score': array([0.99444444, 0.97777778, 0.99444444, 0.99444444, 0.99444444,
        0.98888889, 0.99444444, 0.98888889, 0.97222222, 0.97222222,
        0.97222222, 0.97222222]),
 'split1_test_score': array([0.98888889, 0.96111111, 0.98888889, 0.97777778, 0.98888889,
        0.98333333, 0.98888889, 0.98333333, 0.97777778, 0.97777778,
        0.97777778, 0.97777778]),
 'split2_test_score': array([0.96666667, 0.94444444, 0.96666667, 0.96111111, 0.96666667,
        0.95555556, 0.96666667, 0.95555556, 0.96111111, 0.96111111,
        0.96111111, 0.96111111]),
 'split3_test_score': array([0.98324022, 0.93854749, 0.98882682, 0.97206704, 0.98882682,
        0.98882682, 0.98882682, 0.98882682, 0.97206704, 0.97206704,
        0.97206704, 0.97206704]),
 'split4_test_score': array([0.99441341, 0.96648045, 0.99441341, 1.        , 0.99441341,
        0.99441341, 0.99441341, 0.99441341, 0.97765363, 0.97765363,
        0.97765363, 0.97765363]),
 'mean_test_score': array([0.98553073, 0.95767225, 0.98664804, 0.98108007, 0.98664804,
        0.9822036 , 0.98664804, 0.9822036 , 0.97216636, 0.97216636,
        0.97216636, 0.97216636]),
 'std_test_score': array([0.01030275, 0.01438314, 0.01029668, 0.01433506, 0.01029668,
        0.01377704, 0.01029668, 0.01377704, 0.00606349, 0.00606349,
        0.00606349, 0.00606349]),
 'rank_test_score': array([ 4, 12,  1,  7,  1,  5,  1,  5,  8,  8,  8,  8], dtype=int32)}
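In `rank_test_score`, tied candidates share the smallest rank, which is why three candidates have rank 1 and the next one has rank 4. The sketch below recomputes the ranks from the rounded `mean_test_score` values copied from the output above, then picks the first rank-1 candidate, which is how we understand `best_index_` is chosen in the single-metric case. The helper `min_rank` is a hypothetical name.

```python
# Values copied from cv_results_ above (rounded as printed).
mean_test_score = [
    0.98553073, 0.95767225, 0.98664804, 0.98108007, 0.98664804, 0.9822036,
    0.98664804, 0.9822036, 0.97216636, 0.97216636, 0.97216636, 0.97216636,
]


def min_rank(scores):
    """Rank 1 is best; tied scores share the smallest rank."""
    return [1 + sum(other > s for other in scores) for s in scores]


ranks = min_rank(mean_test_score)
best_index = ranks.index(1)  # first candidate among the ties
print(ranks)       # [4, 12, 1, 7, 1, 5, 1, 5, 8, 8, 8, 8]
print(best_index)  # 2
```

Index 2 in `params` is {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}, the combination reported as the best one.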
[5]:
#
# Estimator that was chosen by the search, i.e. the estimator that gave the
# highest mean score (or smallest loss, if specified) on the left-out data.
#
gridSearchCV.best_estimator_
[5]:
SVC(C=10, gamma=0.001)
[6]:
gridSearchCV.best_score_
[6]:
0.9866480446927375
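`best_score_` is the unweighted mean of the best candidate's five fold scores. The fractions below are inferred from the printed split0 to split4 values of candidate 2 (they reproduce them exactly) and are not read from the data; averaging them recovers the value above.

```python
from fractions import Fraction

# Fold accuracies of candidate 2 written as correct / fold size,
# inferred from the split0..split4 values printed in cv_results_.
fold_scores = [
    Fraction(179, 180), Fraction(178, 180), Fraction(174, 180),
    Fraction(177, 179), Fraction(178, 179),
]

# Plain (unweighted) mean over the folds.
best_score = sum(fold_scores) / len(fold_scores)
print(float(best_score))
```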
[7]:
gridSearchCV.best_params_
[7]:
{'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}

Prediction with the best model

[8]:
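# Predictions on the training data (the same samples used during the search);
# use X_test to estimate performance on unseen data.
gridSearchCV.predict(X_train)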
[8]:
array([1, 4, 9, 0, 4, 1, 1, 5, 9, 1, 4, 2, 6, 3, 9, 7, 6, 4, 8, 6, 8, 7,
       6, 0, 5, 9, 4, 7, 3, 4, 9, 4, 9, 7, 9, 1, 5, 6, 0, 0, 4, 3, 6, 1,
       0, 9, 4, 8, 7, 5, 9, 8, 4, 5, 0, 1, 6, 0, 5, 5, 0, 4, 3, 2, 8, 7,
       6, 3, 4, 2, 5, 8, 0, 6, 9, 4, 5, 4, 9, 7, 3, 3, 1, 4, 4, 2, 6, 8,
       1, 1, 0, 3, 7, 4, 6, 7, 4, 0, 5, 2, 9, 2, 1, 9, 2, 3, 1, 7, 7, 4,
       5, 6, 5, 6, 7, 8, 1, 4, 3, 4, 4, 3, 5, 3, 3, 4, 7, 9, 8, 0, 6, 1,
       9, 0, 8, 4, 1, 2, 3, 9, 7, 8, 8, 8, 3, 7, 5, 7, 0, 1, 7, 8, 3, 8,
       0, 4, 8, 6, 2, 3, 6, 7, 3, 7, 7, 1, 3, 5, 0, 9, 8, 5, 3, 1, 2, 0,
       3, 6, 0, 3, 4, 1, 2, 3, 1, 0, 5, 8, 9, 3, 9, 6, 6, 8, 9, 0, 7, 8,
       2, 0, 0, 7, 7, 4, 5, 3, 1, 8, 5, 9, 6, 2, 9, 7, 7, 9, 5, 4, 2, 6,
       6, 1, 3, 4, 7, 2, 8, 0, 6, 1, 6, 6, 5, 8, 4, 3, 0, 5, 2, 9, 9, 7,
       8, 0, 5, 0, 6, 3, 3, 5, 1, 5, 1, 7, 9, 6, 4, 5, 0, 1, 8, 7, 8, 8,
       8, 9, 8, 7, 7, 2, 2, 2, 8, 0, 7, 8, 6, 8, 0, 4, 2, 2, 3, 7, 9, 0,
       2, 0, 0, 2, 7, 1, 5, 6, 4, 0, 0, 5, 5, 3, 9, 6, 1, 6, 0, 6, 4, 0,
       1, 8, 2, 2, 3, 7, 6, 1, 1, 2, 4, 7, 4, 9, 4, 3, 0, 4, 3, 1, 3, 0,
       9, 4, 6, 0, 3, 2, 6, 2, 5, 6, 7, 8, 8, 4, 4, 6, 9, 4, 5, 4, 5, 7,
       1, 9, 6, 8, 0, 4, 1, 9, 9, 7, 1, 8, 5, 0, 8, 7, 7, 2, 1, 3, 7, 4,
       0, 6, 3, 1, 2, 9, 9, 2, 5, 7, 3, 0, 6, 1, 6, 1, 1, 2, 5, 5, 3, 2,
       8, 5, 0, 9, 6, 9, 8, 4, 5, 8, 1, 6, 3, 0, 4, 6, 1, 8, 3, 4, 7, 1,
       0, 7, 9, 2, 7, 2, 1, 6, 9, 3, 1, 3, 2, 4, 3, 4, 3, 3, 5, 4, 7, 3,
       6, 7, 0, 0, 1, 1, 0, 2, 0, 7, 7, 4, 7, 2, 0, 1, 2, 4, 8, 1, 6, 0,
       3, 4, 0, 6, 8, 4, 4, 9, 0, 8, 4, 6, 8, 7, 8, 2, 8, 1, 6, 6, 9, 5,
       3, 8, 5, 1, 3, 3, 1, 8, 8, 3, 0, 4, 1, 7, 2, 7, 4, 0, 4, 2, 7, 7,
       9, 1, 9, 0, 9, 3, 8, 6, 2, 5, 3, 3, 7, 2, 1, 0, 8, 7, 7, 3, 1, 2,
       4, 5, 7, 7, 9, 1, 5, 5, 2, 8, 7, 9, 4, 7, 0, 2, 6, 1, 3, 1, 3, 7,
       3, 6, 7, 1, 6, 6, 1, 0, 6, 9, 7, 7, 4, 4, 9, 1, 5, 1, 1, 7, 2, 6,
       6, 4, 3, 1, 0, 5, 3, 9, 5, 8, 1, 7, 9, 9, 8, 2, 1, 0, 6, 6, 4, 4,
       7, 8, 6, 5, 8, 8, 2, 2, 2, 9, 8, 8, 3, 6, 0, 4, 4, 7, 6, 6, 9, 0,
       4, 6, 8, 5, 1, 9, 9, 3, 1, 6, 5, 9, 7, 3, 4, 4, 2, 4, 4, 9, 2, 9,
       9, 7, 2, 3, 3, 3, 7, 2, 7, 8, 1, 0, 5, 6, 6, 8, 0, 7, 0, 4, 2, 6,
       6, 8, 6, 4, 7, 7, 0, 3, 0, 7, 4, 0, 0, 2, 1, 8, 4, 2, 2, 9, 9, 3,
       3, 4, 4, 2, 6, 3, 7, 2, 8, 4, 2, 9, 5, 1, 9, 0, 9, 7, 2, 6, 2, 1,
       6, 9, 9, 3, 8, 3, 6, 2, 2, 4, 9, 3, 4, 6, 8, 6, 1, 7, 4, 1, 4, 7,
       0, 1, 5, 6, 2, 7, 8, 4, 9, 0, 9, 0, 5, 2, 2, 4, 1, 8, 8, 7, 2, 9,
       7, 0, 0, 6, 0, 5, 0, 5, 1, 0, 8, 6, 6, 0, 3, 4, 0, 3, 5, 6, 9, 8,
       4, 8, 5, 2, 7, 5, 5, 1, 1, 8, 9, 0, 3, 4, 9, 2, 9, 3, 1, 7, 5, 4,
       9, 5, 7, 7, 7, 0, 1, 9, 1, 9, 7, 1, 3, 9, 4, 9, 2, 5, 3, 5, 6, 2,
       3, 0, 7, 3, 2, 5, 2, 6, 1, 7, 0, 2, 9, 4, 2, 0, 7, 4, 4, 9, 0, 4,
       9, 6, 8, 2, 6, 4, 4, 9, 1, 2, 4, 7, 8, 8, 2, 2, 5, 9, 5, 4, 3, 1,
       4, 7, 1, 5, 3, 8, 6, 5, 5, 2, 1, 0, 9, 1, 6, 9, 3, 7, 4, 5, 6, 3,
       6, 5, 2, 7, 6, 3, 4, 1, 1, 4, 8, 4, 5, 3, 3, 7, 7, 8])
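Predictions on X_train say little about generalization, because those samples were used to choose and refit the model. The self-contained sketch below repeats the search and scores the refitted model on the held-out X_test instead. It also checks that, with refit=True, `predict` delegates to `best_estimator_`, and that `score` applies the same "accuracy" scoring. The variable names are illustrative, and the exact accuracy may vary slightly between scikit-learn versions.

```python
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

digits = load_digits()
X = digits.images.reshape((len(digits.images), -1))
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

param_grid = [
    {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
    {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
]
search = GridSearchCV(SVC(), param_grid=param_grid, cv=5, scoring="accuracy", refit=True)
search.fit(X_train, y_train)

# predict() delegates to best_estimator_, refitted on all of X_train.
y_pred = search.predict(X_test)
assert (y_pred == search.best_estimator_.predict(X_test)).all()

# Held-out estimate: X_test was never seen during the search.
test_accuracy = accuracy_score(y_test, y_pred)
print(test_accuracy)
print(search.score(X_test, y_test))  # same metric, so the same value
```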