Búsqueda aleatoria de hiperparámetros usando RandomizedSearchCV#

Parametrización de la búsqueda#

[1]:
#
# Aca se usara una SVM. Dependiendo del tipo de kernel cambian los parámetros
# que pueden ajustarse.
#
# La variable tuned_parameters es una lista de diccionarios que contiene los
# valores que pueden ajustarse.
#
import scipy
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC

param_distributions = [
    # -------------------------------------------------------------------------
    # Selección aleatoria de valores para el primer modelo
    {
        "kernel": ["rbf"],
        "gamma": scipy.stats.expon(scale=0.1),
        "C": scipy.stats.expon(scale=100),
    },
    # -------------------------------------------------------------------------
    # Selección aleatoria de valores para el segundo modelo
    {
        "kernel": ["linear"],
        "C": [1, 10, 100],
    },
]

randomizedSearchCV = RandomizedSearchCV(
    # --------------------------------------------------------------------------
    # This is assumed to implement the scikit-learn estimator interface.
    estimator=SVC(),
    # --------------------------------------------------------------------------
    # Dictionary with parameters names (str) as keys and distributions or lists
    # of parameters to try.
    param_distributions=param_distributions,
    # --------------------------------------------------------------------------
    # Number of parameter settings that are sampled.
    n_iter=10,
    # --------------------------------------------------------------------------
    # Determines the cross-validation splitting strategy.
    cv=5,
    # --------------------------------------------------------------------------
    # Strategy to evaluate the performance of the cross-validated model on the
    # test set.
    scoring="accuracy",
    # --------------------------------------------------------------------------
    # Refit an estimator using the best found parameters on the whole dataset.
    refit=True,
    # --------------------------------------------------------------------------
    random_state=12345,
)

Preparación de los datos#

[2]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()

n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.5,
    random_state=0,
)

Realización de la búsqueda#

[3]:
randomizedSearchCV.fit(X_train, y_train)
[3]:
RandomizedSearchCV(cv=5, estimator=SVC(),
                   param_distributions=[{'C': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x7f93f8500190>,
                                         'gamma': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x7f93f8500040>,
                                         'kernel': ['rbf']},
                                        {'C': [1, 10, 100],
                                         'kernel': ['linear']}],
                   random_state=12345, scoring='accuracy')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Resultados obtenidos#

[4]:
randomizedSearchCV.cv_results_
[4]:
{'mean_fit_time': array([0.04400897, 0.00801969, 0.04970441, 0.05731525, 0.00747628,
        0.0398592 , 0.04776983, 0.00783315, 0.05106816, 0.00722532]),
 'std_fit_time': array([0.00382408, 0.000706  , 0.00289493, 0.00141426, 0.0003996 ,
        0.00302346, 0.00195418, 0.00033007, 0.00175536, 0.00010216]),
 'mean_score_time': array([0.0082077 , 0.00192423, 0.00866947, 0.01225367, 0.00183663,
        0.00636201, 0.00870934, 0.00202675, 0.00914736, 0.00188689]),
 'std_score_time': array([0.0006078 , 0.00037137, 0.00070998, 0.00037714, 0.0002058 ,
        0.00035371, 0.00070078, 0.00051744, 0.00042682, 0.00018982]),
 'param_C': masked_array(data=[220.8682396496381, 100, 83.86933864671792,
                    61.97597554687909, 100, 222.88118074128667,
                    11.254669304411362, 10, 81.28259452454746, 1],
              mask=[False, False, False, False, False, False, False, False,
                    False, False],
        fill_value='?',
             dtype=object),
 'param_gamma': masked_array(data=[0.014007538087890578, --, 0.09052140627545172,
                    0.2573582562354472, --, 0.0027155110061623483,
                    0.03548248190035313, --, 0.18320769168351672, --],
              mask=[False,  True, False, False,  True, False, False,  True,
                    False,  True],
        fill_value='?',
             dtype=object),
 'param_kernel': masked_array(data=['rbf', 'linear', 'rbf', 'rbf', 'linear', 'rbf', 'rbf',
                    'linear', 'rbf', 'linear'],
              mask=[False, False, False, False, False, False, False, False,
                    False, False],
        fill_value='?',
             dtype=object),
 'params': [{'C': 220.8682396496381,
   'gamma': 0.014007538087890578,
   'kernel': 'rbf'},
  {'C': 100, 'kernel': 'linear'},
  {'C': 83.86933864671792, 'gamma': 0.09052140627545172, 'kernel': 'rbf'},
  {'C': 61.97597554687909, 'gamma': 0.2573582562354472, 'kernel': 'rbf'},
  {'C': 100, 'kernel': 'linear'},
  {'C': 222.88118074128667, 'gamma': 0.0027155110061623483, 'kernel': 'rbf'},
  {'C': 11.254669304411362, 'gamma': 0.03548248190035313, 'kernel': 'rbf'},
  {'C': 10, 'kernel': 'linear'},
  {'C': 81.28259452454746, 'gamma': 0.18320769168351672, 'kernel': 'rbf'},
  {'C': 1, 'kernel': 'linear'}],
 'split0_test_score': array([0.30555556, 0.97222222, 0.11666667, 0.11666667, 0.97222222,
        0.99444444, 0.11666667, 0.97222222, 0.11666667, 0.97222222]),
 'split1_test_score': array([0.37777778, 0.97777778, 0.11666667, 0.11666667, 0.97777778,
        0.98888889, 0.11666667, 0.97777778, 0.11666667, 0.97777778]),
 'split2_test_score': array([0.34444444, 0.96111111, 0.11666667, 0.11666667, 0.96111111,
        0.96111111, 0.13333333, 0.96111111, 0.11666667, 0.96111111]),
 'split3_test_score': array([0.36871508, 0.97206704, 0.11731844, 0.11731844, 0.97206704,
        0.97765363, 0.12290503, 0.97206704, 0.11731844, 0.97206704]),
 'split4_test_score': array([0.33519553, 0.97765363, 0.11731844, 0.11731844, 0.97765363,
        0.99441341, 0.12290503, 0.97765363, 0.11731844, 0.97765363]),
 'mean_test_score': array([0.34633768, 0.97216636, 0.11692737, 0.11692737, 0.97216636,
        0.9833023 , 0.12249534, 0.97216636, 0.11692737, 0.97216636]),
 'std_test_score': array([0.02561305, 0.00606349, 0.0003193 , 0.0003193 , 0.00606349,
        0.01267415, 0.00609499, 0.00606349, 0.0003193 , 0.00606349]),
 'rank_test_score': array([6, 2, 8, 8, 2, 1, 7, 2, 8, 2], dtype=int32)}
[5]:
#
# Estimator that was chosen by the search, i.e. estimator which gave highest
# score (or smallest loss if specified) on the left out data.
#
randomizedSearchCV.best_estimator_
[5]:
SVC(C=222.88118074128667, gamma=0.0027155110061623483)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
[6]:
randomizedSearchCV.best_score_
[6]:
0.9833022967101179
[7]:
randomizedSearchCV.best_params_
[7]:
{'C': 222.88118074128667, 'gamma': 0.0027155110061623483, 'kernel': 'rbf'}

Pronóstico con el mejor modelo#

[8]:
randomizedSearchCV.predict(X_train)
[8]:
array([1, 4, 9, 0, 4, 1, 1, 5, 9, 1, 4, 2, 6, 3, 9, 7, 6, 4, 8, 6, 8, 7,
       6, 0, 5, 9, 4, 7, 3, 4, 9, 4, 9, 7, 9, 1, 5, 6, 0, 0, 4, 3, 6, 1,
       0, 9, 4, 8, 7, 5, 9, 8, 4, 5, 0, 1, 6, 0, 5, 5, 0, 4, 3, 2, 8, 7,
       6, 3, 4, 2, 5, 8, 0, 6, 9, 4, 5, 4, 9, 7, 3, 3, 1, 4, 4, 2, 6, 8,
       1, 1, 0, 3, 7, 4, 6, 7, 4, 0, 5, 2, 9, 2, 1, 9, 2, 3, 1, 7, 7, 4,
       5, 6, 5, 6, 7, 8, 1, 4, 3, 4, 4, 3, 5, 3, 3, 4, 7, 9, 8, 0, 6, 1,
       9, 0, 8, 4, 1, 2, 3, 9, 7, 8, 8, 8, 3, 7, 5, 7, 0, 1, 7, 8, 3, 8,
       0, 4, 8, 6, 2, 3, 6, 7, 3, 7, 7, 1, 3, 5, 0, 9, 8, 5, 3, 1, 2, 0,
       3, 6, 0, 3, 4, 1, 2, 3, 1, 0, 5, 8, 9, 3, 9, 6, 6, 8, 9, 0, 7, 8,
       2, 0, 0, 7, 7, 4, 5, 3, 1, 8, 5, 9, 6, 2, 9, 7, 7, 9, 5, 4, 2, 6,
       6, 1, 3, 4, 7, 2, 8, 0, 6, 1, 6, 6, 5, 8, 4, 3, 0, 5, 2, 9, 9, 7,
       8, 0, 5, 0, 6, 3, 3, 5, 1, 5, 1, 7, 9, 6, 4, 5, 0, 1, 8, 7, 8, 8,
       8, 9, 8, 7, 7, 2, 2, 2, 8, 0, 7, 8, 6, 8, 0, 4, 2, 2, 3, 7, 9, 0,
       2, 0, 0, 2, 7, 1, 5, 6, 4, 0, 0, 5, 5, 3, 9, 6, 1, 6, 0, 6, 4, 0,
       1, 8, 2, 2, 3, 7, 6, 1, 1, 2, 4, 7, 4, 9, 4, 3, 0, 4, 3, 1, 3, 0,
       9, 4, 6, 0, 3, 2, 6, 2, 5, 6, 7, 8, 8, 4, 4, 6, 9, 4, 5, 4, 5, 7,
       1, 9, 6, 8, 0, 4, 1, 9, 9, 7, 1, 8, 5, 0, 8, 7, 7, 2, 1, 3, 7, 4,
       0, 6, 3, 1, 2, 9, 9, 2, 5, 7, 3, 0, 6, 1, 6, 1, 1, 2, 5, 5, 3, 2,
       8, 5, 0, 9, 6, 9, 8, 4, 5, 8, 1, 6, 3, 0, 4, 6, 1, 8, 3, 4, 7, 1,
       0, 7, 9, 2, 7, 2, 1, 6, 9, 3, 1, 3, 2, 4, 3, 4, 3, 3, 5, 4, 7, 3,
       6, 7, 0, 0, 1, 1, 0, 2, 0, 7, 7, 4, 7, 2, 0, 1, 2, 4, 8, 1, 6, 0,
       3, 4, 0, 6, 8, 4, 4, 9, 0, 8, 4, 6, 8, 7, 8, 2, 8, 1, 6, 6, 9, 5,
       3, 8, 5, 1, 3, 3, 1, 8, 8, 3, 0, 4, 1, 7, 2, 7, 4, 0, 4, 2, 7, 7,
       9, 1, 9, 0, 9, 3, 8, 6, 2, 5, 3, 3, 7, 2, 1, 0, 8, 7, 7, 3, 1, 2,
       4, 5, 7, 7, 9, 1, 5, 5, 2, 8, 7, 9, 4, 7, 0, 2, 6, 1, 3, 1, 3, 7,
       3, 6, 7, 1, 6, 6, 1, 0, 6, 9, 7, 7, 4, 4, 9, 1, 5, 1, 1, 7, 2, 6,
       6, 4, 3, 1, 0, 5, 3, 9, 5, 8, 1, 7, 9, 9, 8, 2, 1, 0, 6, 6, 4, 4,
       7, 8, 6, 5, 8, 8, 2, 2, 2, 9, 8, 8, 3, 6, 0, 4, 4, 7, 6, 6, 9, 0,
       4, 6, 8, 5, 1, 9, 9, 3, 1, 6, 5, 9, 7, 3, 4, 4, 2, 4, 4, 9, 2, 9,
       9, 7, 2, 3, 3, 3, 7, 2, 7, 8, 1, 0, 5, 6, 6, 8, 0, 7, 0, 4, 2, 6,
       6, 8, 6, 4, 7, 7, 0, 3, 0, 7, 4, 0, 0, 2, 1, 8, 4, 2, 2, 9, 9, 3,
       3, 4, 4, 2, 6, 3, 7, 2, 8, 4, 2, 9, 5, 1, 9, 0, 9, 7, 2, 6, 2, 1,
       6, 9, 9, 3, 8, 3, 6, 2, 2, 4, 9, 3, 4, 6, 8, 6, 1, 7, 4, 1, 4, 7,
       0, 1, 5, 6, 2, 7, 8, 4, 9, 0, 9, 0, 5, 2, 2, 4, 1, 8, 8, 7, 2, 9,
       7, 0, 0, 6, 0, 5, 0, 5, 1, 0, 8, 6, 6, 0, 3, 4, 0, 3, 5, 6, 9, 8,
       4, 8, 5, 2, 7, 5, 5, 1, 1, 8, 9, 0, 3, 4, 9, 2, 9, 3, 1, 7, 5, 4,
       9, 5, 7, 7, 7, 0, 1, 9, 1, 9, 7, 1, 3, 9, 4, 9, 2, 5, 3, 5, 6, 2,
       3, 0, 7, 3, 2, 5, 2, 6, 1, 7, 0, 2, 9, 4, 2, 0, 7, 4, 4, 9, 0, 4,
       9, 6, 8, 2, 6, 4, 4, 9, 1, 2, 4, 7, 8, 8, 2, 2, 5, 9, 5, 4, 3, 1,
       4, 7, 1, 5, 3, 8, 6, 5, 5, 2, 1, 0, 9, 1, 6, 9, 3, 7, 4, 5, 6, 3,
       6, 5, 2, 7, 6, 3, 4, 1, 1, 4, 8, 4, 5, 3, 3, 7, 7, 8])