KNNImputer

  • Fills in missing values using the k-nearest neighbors (kNN) approach.

  • Each missing value is imputed with the mean of that feature over the k nearest neighbors found in the training set.

[1]:
import numpy as np

X_train = [
    [1, 2, np.nan],
    [3, 4, 3],
    [np.nan, 6, 5],
    [8, 8, 7],
]

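#
# Simplified view: neighbors are ranked here by the shared column 1 only.
# KNNImputer itself uses the nan_euclidean distance over every column present
# in both rows, which selects the same neighbors for this data.
#
#   Missing value in column 0      Missing value in column 2
#       (row [nan, 6, 5])               (row [1, 2, nan])
#  --------------------------   ---------------------------
#     Data   Point    In k=2?       Data   Point    In k=2?
#                   (vs. [6])                     (vs. [2])
#  --------------------------   ---------------------------
#   [1, 2]     [2]        No     [4, 3]     [4]       Yes
#   [3, 4]     [4]       Yes     [6, 5]     [6]       Yes
#   [8, 8]     [8]       Yes     [8, 7]     [8]        No
#
#   (3 + 8) / 2 = 5.5 <---        (3 + 5) / 2 = 4.0 <---
#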
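The neighbor selection described in the comments above can be checked by hand. The following is a minimal sketch, not the imputer's internal code: it assumes `nan_euclidean_distances` from `sklearn.metrics.pairwise` as the distance, and the names `donors_col0`, `donors_col2`, `nearest_col0`, and `nearest_col2` are illustrative only.

```python
import numpy as np
from sklearn.metrics.pairwise import nan_euclidean_distances

X_train = np.array([
    [1, 2, np.nan],
    [3, 4, 3],
    [np.nan, 6, 5],
    [8, 8, 7],
])

# Pairwise distances. Coordinates missing in either row are skipped, and the
# squared sum is rescaled by (total coordinates / coordinates present).
dist = nan_euclidean_distances(X_train, X_train)

# Row 2 is missing column 0: rank the rows that have a column-0 value.
donors_col0 = [0, 1, 3]
nearest_col0 = sorted(donors_col0, key=lambda i: dist[2, i])[:2]

# Row 0 is missing column 2: rank the rows that have a column-2 value.
donors_col2 = [1, 2, 3]
nearest_col2 = sorted(donors_col2, key=lambda i: dist[0, i])[:2]

# Uniform weights: the imputed value is the plain mean of the two donors.
imputed_col0 = X_train[nearest_col0, 0].mean()
imputed_col2 = X_train[nearest_col2, 2].mean()

print(nearest_col0, imputed_col0)
print(nearest_col2, imputed_col2)
```

For this data the two means come out to 5.5 and 4.0, matching the hand calculation in the comments.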
[2]:
from sklearn.impute import KNNImputer

knnImputer = KNNImputer(
    # -------------------------------------------------------------------------
    # The placeholder for the missing values.
    missing_values=np.nan,
    # -------------------------------------------------------------------------
    # Number of neighboring samples to use for imputation.
    n_neighbors=2,
    # -------------------------------------------------------------------------
    # Weight function used in prediction.
    # - 'uniform' : uniform weights. All points in each neighborhood are
    #   weighted equally.
    # - 'distance' : weight points by the inverse of their distance. In this
    #   case, closer neighbors of a query point have a greater influence
    #   than neighbors that are farther away.
    weights="uniform",
    # -------------------------------------------------------------------------
    # Distance metric for searching neighbors.
    # - 'nan_euclidean' : Euclidean distance that skips missing coordinates.
    # - a user-defined callable.
    metric="nan_euclidean",
)

knnImputer.fit(X_train)

knnImputer.transform(X_train)
[2]:
array([[1. , 2. , 4. ],
       [3. , 4. , 3. ],
       [5.5, 6. , 5. ],
       [8. , 8. , 7. ]])
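For comparison, the same data can be imputed with `weights="distance"`, the other option listed in the comments above. This is a small sketch; the names `distance_imputer` and `X_distance` are illustrative and not part of the original notebook.

```python
import numpy as np
from sklearn.impute import KNNImputer

X_train = [
    [1, 2, np.nan],
    [3, 4, 3],
    [np.nan, 6, 5],
    [8, 8, 7],
]

# Same two neighbors as before, but each donor is weighted by the inverse of
# its nan_euclidean distance to the row being imputed.
distance_imputer = KNNImputer(n_neighbors=2, weights="distance")
X_distance = distance_imputer.fit_transform(X_train)

print(X_distance)
```

The value in row 2, column 0 should stay at 5.5, because both donors are equally distant. The value in row 0, column 2 should move from 4.0 toward 3, since the donor holding 3 is twice as close as the donor holding 5, giving (2 * 3 + 1 * 5) / 3 ≈ 3.67.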