Programar algoritmo kNN desde cero en R

El algoritmo kNN es uno de los algoritmos más conocidos dentro del mundo del machine learning, muy utilizado, entre otras cosas, en la imputación de valores perdidos. Hoy vamos a programar un algoritmo kNN desde 0 en R para que entiendas perfectamente cómo funciona en detalle este algoritmo y lo puedas usar a la perfección. ¡Vamos a ello!

Introducción al k Nearest Neighbors (kNN) en R

La lógica detrás del algoritmo kNN es muy sencilla: me guardo la tabla de datos de entrenamiento y cuando me llegue un nuevo dato, encuentro los k observaciones (vecinos) más cercanos y hago la clasificación en base a esas observaciones. Al fin y al cabo, es de esperar que observaciones cercanas sean similares a la nueva observación.

Como ves, aquí vemos una gran diferencia respecto a la mayoría de algoritmos supervisados, y es que se trata de un algoritmo no paramétrico. Es decir, que el algoritmo no debe aprender el valor de ningún parámetro, por lo que no hay un proceso de entrenamiento como tal.

Así pues, la clave del algoritmo kNN que programaremos en R se basa en tres aspectos clave que debemos conocer:

  1. Conocer las distintas medidas de distancia que existen, cómo funcionan y cuándo usar cada una de las medidas.
  2. Entender cómo elegir la cantidad de k vecinos a los que se debe observar.
  3. Conocer cómo hace el algortimo kNN las predicciones.

Así pues, si te parece, vamos a ir viendo cada uno de estos aspectos.

Medidas de distancia que puede usar el algoritmo kNN

Dentro del algoritmo kNN las medidas de distancia más utilizadas son: distancia Euclídea, distancia de Minkowski, distancia Manhattan, distancia de Coseno y distancia Jaccard. Estas no son las únicas, el algoritmo kNN puede usar cualquier otra medida de distancia, aunque con estas cubriremos la gran mayoría de casos.

Distancia Euclídea

La distancia Euclídea es algo que ya hemos visto en este blog al programar el algoritmo K-means tanto en R como en Python. La distancia Euclídea se basa en el teorema de Pitágoras, según el cual, la hipotenusa al cuadrado es igual a la suma de catetos al cuadrado.

Esta fórmula funcionará independientemente del número de variables que haya. Partiendo del teorema de Pitágoras, podremos encontrar la distancia en línea recta entre dos puntos, es decir, la distancia Euclídea. En la imagen siguiente podemos ver cómo se calcularía la distancia entre los puntos p y q.

Fórmula Distancia Euclidea

Así pues, como la distancia Euclídea es una de las posibles medidas de distancia que puede usar el algoritmo kNN, vamos a programar la distancia Euclídea desde 0 en R:

euclidean_distance = function(a, b){
  # Comprobamos que tienen la misma cantidad de observaciones
  if(length(a) == length(b)){
    sqrt(sum((a-b)^2))  
  } else{
    stop('Vectors must be of the same length')
  }
}

euclidean_distance(1:10, 11:20)
[1] 31.62278

Distancia de Manhattan

La distancia de Manhattan no mide la distancia en línea recta, sino que considera la distancia como la suma de los catetos. Esto, que así dicho parece que no tiene mucho sentido, se entiende mejor con una imagen:

Distancia de Manhattan

Como ves, en este caso, se intenta medir la distancia entre los puntos (x1, y1) e (x2, y2). La línea azul representa la distancia Euclídea, mientras que la línea roja representa la distancia Manhattan. Obviamente, cuando calculas la distancia en una ciudad es imposible ir por línea recta, por lo que la distancia Manhattan tiene mucho más sentido.

Como decía, la distancia Mnahattan se calcula como la suma de los catetos. Por tanto, la fórmula para calcular la distancia Manhattan (con dos dimensiones) es la siguiente:

\[ Mdis = |x_2 – x_1| + |y_2 – y_1| \]

Así pues, vamos a programar la distancia Manhattan desde 0 en R para poder usarlo en nuestro algoritmo kNN:

manhattan_distance = function(a, b){
  # Comprobamos que tienen la misma cantidad de observaciones
  if(length(a) == length(b)){
    sum(abs(a-b))
  } else{
    stop('Vectors must be of the same length')
  }
}

manhattan_distance(1:10, 11:20)
[1] 100

Similitud Coseno

Como ya vimos el post sobre cómo programar un sistema de recomendación en R desde 0, la similitud de coseno mide el ángulo entre dos vectores, de tal forma que sepamos si apuntan, o no, en la misma dirección.

La fórmula de la similitud del coseno se calcula con la siguiente fórmula:

Similitud de coseno

Sin embargo, en nuestro caso no queremos medir la similitud, sino la distancia. La similitud de coseno será 1 cuando los dos vectores tengan ángulo 0 (apunten en la misma dirección), mientras que en ese caso la distancia debería ser 0 . Por tanto, para obtener la distancia en base a la similitud del coseno, simplemente deberemos restar la similitud a 1.

Así pues, podemos programar la distancia en base a similitud del coseno en R desde 0:

cos_similarity = function(a,b){
  if(length(a) == length(b)){
    num = sum(a *b, na.rm = T)
    den = sqrt(sum(a^2, na.rm = T)) * sqrt(sum(b^2, na.rm = T)) 
    result = num/den

    return(1-result)
  } else{
    stop('Vectors must be of the same length')
  }
}


cos_similarity(1:10, 11:20)
[1] 0.0440877

Coeficiente de Jaccard

El coeficiente de Jaccard mide el grado de similitud entre dos vectores, dando un valor de 1 cuando todos los valores son iguales y 0 cuando los valores son diferentes. El coeficiente de Jaccard se piede calcular la siguiente manera:

\[ J(A, B) = \frac{|A∩B|}{|A∪B|} = \frac{|A∩B|}{|A|+|B|-|A∩B|} \]

Así pues, vamos a programar el coeficiente de Jaccard en R para poder usarlo en nuestro algoritmo kNN programado desde 0:

jaccard = function(a, b){
  if(length(a) == length(b)){
    intersection = length(intersect(a,b))
    union = length(a) +  length(b) - intersection
    return(intersection/union)
  } else{
      stop('Vectors must be of the same length')
    }
}

jaccard(1:10, 11:20)
[1] 0

Distancia Minkowski

La distancia de Minkowski es un tipo de distancia que generaliza las distancias Euclídea y de Manhattan. Básicamente, la distancia Minkowski es una distancia que requiere de un parámetro p. Cuando p=2 obtenemos la distancia Euclídea y si p=1 obtenemos la distancia Manhattan.

De todos modos, aunque los valores más típicos de la distancia Minkowski suelen ser 1 y 2, también se le podría asignar otros valores.

La fórmula de la distancia de Minkowski es la siguiente:

\[ Minkowski = \Bigg(\sum^d_{l=1}|x_{il}-x_{jl}|^{1/p}\Bigg)^p \]

Así pues, podemos programar la distancia Minkowski desde 0 en R:

minkowski_distance = function(a,b,p){
  if(p<=0){
   stop('p must be higher than 0') 
  }

  if(length(a)== length(b)){
    sum(abs(a-b)^p)^(1/p)
  }else{
     stop('Vectors must be of the same length')

  }
}

minkowski_distance(1:10, 11:20, 1)
[1] 100
minkowski_distance(1:10, 11:20, 2)
[1] 31.62278

Como vemos, los resultados para p=1 coinciden con la distancia de Manhattan y con p=2 coinciden con la distancia Euclídea.

Cuándo usar cada tipo de distancia

Con esto ya conoces las principales medidas de distancia que se suelen utilizar en el algoritmo kNN. Sin embargo, ¿cuándo deberíamos usar cada una de ellas?

Como todo, la métrica de distancia que usemos dependerá del tipo de dato que tengamos, las dimensiones que tengamos y el objetivo de negocio. Por ejemplo, si queremos encontrar la ruta más cercana que debe realizar un taxi, o las distancias en un tablero de ajedrez, parece claro que los propios datos nos llevan a a usar la función Manhattan, ya que es la única distancia que tiene sentido.

Asimismo, cuando tenemos un nivel de dimensionalidad alto, es decir, cuando hay muchas variables, según este paper, la distancia de Manhattan funciona mejor que la distancia Euclídea. Y es que, con alta dimensionalidad todo está lejos de todo, por lo que otra opción suele ser mirar la dirección de los vectores, es decir, usar la distancia del coseno.

Por otro lado, entre usar la distancia Jaccard o la distancia del coseno, dependerá de la duplicidad de los datos. Si la duplicidad de datos no importa, entonces se usará la distancia Jaccard, y, en caso contrario, se usará la distancia del coseno. Estas distancias suelen ser típicas de datos que involucran palabras (NLP) y sistemas de recomendación.

Ahora que ya conocemos las diferentes medidas de distancia y cuándo usar cada una de ellas, vamos a usar estas distancias para encontrar a los k vecinos más cercanos y, a partir de ahí, ver cómo se hace la predicción.

Encontrar a los k vecinos más cercanos

Vamos a programar una función que, dada una medida de distancia, una serie de observaciones, nos devuelva los k vecinos de una observación que le pasemos.

Esto es muy interesante, ya que nos va a permitir usar la función kNN simplemente para encontrar los vecinos más cercanos, sin tener que realizar ninguna predicción, al estilo de lo que permite la librería scikit learn en Python y que resulta muy útil de cara a crear sistemas de recomendación.

En cualquier caso, la función para encontrar los vecinos más cercanos funcionará de la siguiente manera:

  • Calcular la distancia de la observación respecto a todas las observaciones.
  • Filtrar y devolver las k observaciones con menor distancia

Por tanto, vamos a programar esta parte tan importante de nuestro algoritmo kNN en R. Para que fácilmente encaje con el resto del código, haremos que en vez de devolver un dataset, el algoritmo devuelva los índices de las observaciones.

nearest_neighbors = function(x,obs, k, FUN, p = NULL){

  # Check the number of observations is the same
  if(ncol(x) != ncol(obs)){
    stop('Data must have the same number of variables')
  }

  # Calculate distance, considering p for Minkowski
  if(is.null(p)){
    dist = apply(x,1, FUN,obs)  
  }else{
    dist = apply(x,1, FUN,obs,p)
  }

  # Find closest neighbours
  distances = sort(dist)[1:k]
  neighbor_ind = which(dist %in% sort(dist)[1:k])

  if(length(neighbor_ind)!= k){
    warning(
      paste('Several variables with equal distance. Used k:',length(neighbor_ind))
    )
  }

  ret = list(neighbor_ind, distances)
  return(ret)
}

Ahora, podemos comprobar que nuestra función nearest_neighbors funciona. Para ello, vamos a usar el dataset iris, del cual usaremos la última observación con la que encontrar 3 vecinos más cercanos:

x = iris[1:(nrow(iris)-1),]
obs = iris[nrow(iris),]

ind = nearest_neighbors(x[,1:4], obs[,1:4],4, euclidean_distance)[[1]]
as.matrix(x[ind,1:4])
    Sepal.Length Sepal.Width Petal.Length Petal.Width
102          5.8         2.7          5.1         1.9
128          6.1         3.0          4.9         1.8
139          6.0         3.0          4.8         1.8
143          5.8         2.7          5.1         1.9

Como vemos, los 3 vecinos son bastante parecidos. De hecho, si lo comparamos con la observación utilizada vemos que son muy similares:

obs[,1:4]
    Sepal.Length Sepal.Width Petal.Length Petal.Width
150          5.9           3          5.1         1.8

Como vemos, ya sabemos cómo encontrar los k vecinos más cercanos. Sin embargo, ¿cómo hace el algoritmo la predicción? ¡Veámoslo!

Programar la predicción del algoritmo kNN

El algoritmo kNN sirve tanto para problemas de clasificación como de regresión. Obviamente, el tipo de predicción que haga dependerá del tipo de problema que se le pase.

Predicción del algoritmo kNN en problemas de clasificación

En el caso de los problemas de clasificación, el algoritmo kNN se basa en hallár la moda, es decir, el valor que más se repite. Si de los 10 vecinos, 8 son Setosa, entonces nuestra predicción será Setosa. Es como si de una votación se tratara, siendo la predicción la opción más votada.

Sin embargo, esto tiene un problema, y es que, ¿qué pasa si hay dos (o más) clases con el mismo número de votos? En ese caso, no habría un valor predominante, por lo que este sistema no serviría.

En esos casos, el algoritmo aumenta en 1 la k, es decir, añade un nuevo vecino que será el que (probablemente) desempate. Si añadiendo la k en 1 no se desempata, se vuelve incrementar en 1 hasta que se de ese «desempate».

Así pues, vamos a realizar la función de predicción en base a los datos.

knn_prediction = function(x,y){

  groups = table(x[,y])
  pred = groups[groups == max(groups)]
  return(pred)

}

knn_prediction(x[ind,], 'Species')
virginica 
        4 

Como vemos, el algortimo kNN nos devuelve la predicción de que la planta que teníamos es de la especie virginica. Comprobamos a ver si es cierto:

obs[,'Species']
[1] virginica
Levels: setosa versicolor virginica

Efectivamente, vemos que el algoritmo ha acertado en la predicción. De todos modos, aún no hemos arreglado el problema de que haya dos clases con la misma cantidad de vecinos. Esto lo arreglaremos más adelante al montar el algoritmo mediante recursividad.

Vista la predicción en el caso de la clasificación, vemos cómo hace el algoritmo kNN la predicción en caso de que se trate de una regresión:

Predicción del algoritmo kNN en problemas de regresión

En el caso del algoritmo kNN para regresiones, podemos optar pos dos enfoques:

  1. Realizar la predicción en base a la media, tal como se hace en el caso de los árboles de decisión, como vimos en este post.
  2. Realizar la predicción en base a una media ponderada que tenga en cuenta la distancia del resto de observaciones respecto al target, de tal forma que aquellas observaciones que sean más cercanas tengan más peso que aquellas que están más lejos.

Existen distintas formas de calcular la media ponderada (enlace), aunque lo más habitual suele ser ponderar en base a la inversa de las distancias.

Así pues, vamos a modificar la función de predicción anterior para que:

  1. Tenga en cuenta el tipo de variable sobre la que se hace la predicción.
  2. En caso de tratarse de una regresión, acepte que la predicción se haga tanto mediante una media simple como con una media ponderada.
knn_prediction = function(x,y, weights = NULL){

  x = as.matrix(x)

  if(is.factor(x[,y]) | is.character(x[,y])){
    groups = table(x[,y])
    pred = names(groups[groups == max(groups)])
  } 

  if(is.numeric(x[,y])){

    # Calculate weighted prediction
    if(!is.null(weights)){
      w = 1/weights/ sum(weights)
      pred = weighted.mean(x[,y], w)

    # Calculate standard prediction  
    }else{
      pred = mean(x[,y])
    }

  }

  # If no pred, then class is not correct
  if(try(class(x[,y])) == 'try-error'){
    stop('Y should be factor or numeric.')
  }

  return(pred)

}

Ahora, vamos a probar a hacer la predicción. Para ello, vamos a usar el dataset gapminder, con el cual intentaremos predecir la esperanza de vida de un país para el 2007, usando su población y su PIB per Cápita:l

cat('Prediccion:', pred,'\n',
    'Valor real: ', obs$lifeExp,
    sep ='')
Prediccion: 54.5806
Valor real: 43.487

Como vemos, hemos obtenido una predicción que es bastante cercana a lo que podríamos haber esperado. Comprobamos lo mismo, pero en el caso de que hubiéramos hecho un weighted kNN:

pred = knn_prediction(x[neighbors[[1]],],'lifeExp',weights = neighbors[[2]])

cat('Prediccion:', pred,'\n',
    'Valor real: ', obs$lifeExp,
    sep ='')
Prediccion:51.67521
Valor real: 43.487

Como vemos, la predicción ha mejorado un poco. Esto es debido a que las observaciones más cercanas son las que más peso han tenido y, por lo que parece, son las que mejor predicen (lo cual tiene sentido y, por eso precisamente, existe se usa el algoritmo kNN que acabamos de programar en R).

Con esto, ya sabríamos cómo hace la predicción el algoritmo kNN. Ahora, vamos a poner todo junto para que terminemos de programar nuestro algoritmo y que, usar kNN simplemente sea llamar a una función. ¡Vamos a ello!

Terminando de programar el algoritmo kNN desde 0 en R

De cara a terminar de programar nuestro algoritmo kNN en R, hay que tener en cuenta cómo se usa este algoritmo. En general, la predicción no se suele hacer sobre una observación, sino sobre varias a la vez. Por tanto, tendremos que permitir que nuestro algoritmo reciba varias observaciones sobre las que predecir y que devuelva un vector de predicciones.

Con lo que hemos hecho hasta ahora esto es muy sencillo, ya que simplemente tendremos que iterar sobre los datos de entrada para y hacer un append de los resultados.

Además, vamos a programar otra cuestión que anteriormente hemos dejado en el tintero: los empates en las votaciones al hacer clasificación.

Tal como tenemos el algoritmo programado solucionar eso es muy sencillo si aplicamos la recursividad. Así pues, cuando obtienes una predicción simplemente debes comprobar el número de clases de respuesta. Si es superior a uno, para esa misma observación vuelves a llamar al algoritmo, pero con un k superior.

Veámos cómo hacerlo:

knn = function(x_fit, x_pred, y, k, 
               func = euclidean_distance,weighted_pred = F, p = NULL){

  # Inicilizamos las predicciones
  predictions = c()

  y_ind = which(colnames(x_pred) == y)

  # Para cada observaciones, obtenemos la prediccion
  for(i in 1:nrow(x_pred)){

    neighbors = nearest_neighbors(x_fit[,-y_ind], 
                                  x_pred[i,-y_ind],k,FUN = func)

    if(weighted_pred){
      pred = knn_prediction(x_fit[neighbors[[1]], ],y, neighbors[[2]])
    } else{
      pred = knn_prediction(x_fit[neighbors[[1]], ],y)
    }

    # If more than 1 predictions, make prediction with 1 more k
    if(length(pred)>1){
      pred = knn(x_fit, x_pred[i,],y, k = k+1, 
                 func = func, weighted_pred = weighted_pred, p == p)
    }

    predictions[i] = pred

  }
  return(predictions)

}

Como ves, montar nuestro algoritmo ha sido muy sencillo. Veamos a ver cómo funciona. De cara a medir el rendimiento en la clasificación, usaré los datos de iris, pero en este caso, muestreados.

set.seed(1234)

n_fit = 20
train_ind = sample(1:nrow(iris),n_fit)

x_fit = iris[-train_ind,]
x_pred = iris[train_ind,]

predictions = knn(x_fit, x_pred, 'Species', k = 5)
predictions
 [1] 'setosa'     'versicolor' 'virginica'  'virginica'  'virginica'  'virginica'  'virginica'  'virginica' 
 [9] 'versicolor' 'virginica'  'versicolor' 'versicolor' 'versicolor' 'virginica'  'setosa'     'virginica' 
[17] 'versicolor' 'setosa'     'virginica'  'setosa'    

Ahora, podemos comparar las predicciones con los datos reales, a ver cómo se ha comportado el algoritmo. Al tratarse de un problema de clasificación, realizaremos una matriz de confusión, para lo cual usaremos la función confusionMatrix de caret.

library(caret)

predictions = factor(predictions, levels = levels(x_pred$Species))

confusionMatrix(as.factor(predictions), x_pred$Species)
Confusion Matrix and Statistics

            Reference
Prediction   setosa versicolor virginica
  setosa          4          0         0
  versicolor      0          6         0
  virginica       0          0        10

Overall Statistics

               Accuracy : 1          
                 95% CI : (0.8316, 1)
    No Information Rate : 0.5        
    P-Value [Acc > NIR] : 9.537e-07  

                  Kappa : 1          

 Mcnemar's Test P-Value : NA         

Statistics by Class:

                     Class: setosa Class: versicolor Class: virginica
Sensitivity                    1.0               1.0              1.0
Specificity                    1.0               1.0              1.0
Pos Pred Value                 1.0               1.0              1.0
Neg Pred Value                 1.0               1.0              1.0
Prevalence                     0.2               0.3              0.5
Detection Rate                 0.2               0.3              0.5
Detection Prevalence           0.2               0.3              0.5
Balanced Accuracy              1.0               1.0              1.0

Como vemos, nuestro algoritmo a acertado el 100% de los casos. Aunque si bien es cierto que se trata de un caso sencillo que se le da especialmente bien a este algoritmo, pero bueno, al menos nos sirve para comprobar que el algoritmo funciona.

Además, podemos comprobar también cómo funciona nuestro algoritmo a la hora de resolver problemas de regresión. Para ello, usaremos el dataset gapminder que hemos usado anteriormente. Para evaluarlo, usaré el RMSE y, además haremos una visualización de las desviaciones con el paquete ggplot.

library(ggplot2)
library(dplyr)
library(tidyr)

set.seed(12345)

n_fit = 20
train_ind = sample(1:nrow(gapminder_2007),n_fit)

x_fit = gapminder_2007[-train_ind, 3:6]
x_pred = gapminder_2007[train_ind, 3:6]

predictions = knn(x_fit, x_pred, 'lifeExp', k = 5)

results = data.frame(real = x_pred$lifeExp, 
                     prediction = predictions)
results %>%
  mutate(id = row_number()) %>%
  pivot_longer(cols = -c('id'), names_to = 'variable', values_to = 'valor') %>%
  ggplot(aes(id, valor, col = variable)) + 
  geom_point(size = 2) +  geom_line(aes(group = id),color='grey') +
  theme_minimal() + theme(legend.position = 'bottom') +
  labs(x = '', y = 'Predicted Value', col = '', 
       title = 'kNN model prediction fit', subtitle = 'k = 5 & Non-weighted')
Predicción del modelo kNN con k5 vs datos reales

Como vemos, la predicción parece ser bastante acertada en algunos casos, aunque muy alejada en otros. Veámos ahora cómo se comporta el weighted kNN, de tal forma que podamos ver si hay alguna variación significativa:

predictions_w = knn(x_fit, x_pred, 'lifeExp', k = 5, weighted_pred = T)

results_w = data.frame(real = x_pred$lifeExp, 
                     prediction = predictions_w)
results_w %>%
  mutate(id = row_number()) %>%
  pivot_longer(cols = -c('id'), 
               names_to = 'variable',
               values_to = 'valor') %>%
  ggplot(aes(id, valor, col = variable)) + 
  geom_point(size = 2) +  geom_line(aes(group = id),color='grey') +
  theme_minimal() + theme(legend.position = 'bottom') +
  labs(x = '', y = 'Predicted Value', col = '', 
       title = 'Weighted kNN model prediction fit', subtitle = 'k = 5')
Predicción del modelo kNN con k5 vs datos reales

Como podemos ver, las diferncias respecto a las predicciones parecen algo menores, es decir, que el algoritmo algo ha mejorado. Vamos a medir el RMSE para comparar si es así o no:

rmse = function(y_pred, y_real){
  sqrt(mean((y_pred-y_real)^2))
}

rmse_model1 = rmse(results$prediction, results$real)
rmse_model2 = rmse(results_w$prediction, results_w$real)

cat(
  'RMSE kNN: ', rmse_model1,'\n',
  'RMSE weighted kNN: ', rmse_model2, sep=''
)
RMSE kNN: 13.47053
RMSE weighted kNN: 13.54268

Como podemos ver, en este caso el weighted kNN ha tenido una capacidad predictiva un poco mejor que el kNN normal, aunque tampoco excesivamente grande. Pero, ¿puede ser porque no estemos aplicando bien el algoritmo? Y es que, hasta ahora hemos visto cómo se programa desde 0 en R. Y sí, ya sabemos cómo funciona. Pero.. ¿cómo se elige el número de k, por ejemplo? ¿En qué casos debemos usar el algoritmo kNN? ¡Veámoslo!

Cómo y cuándo usar el algoritmo kNN

Elección del número de vecinos a revisar

Una de las preguntas más importantes sobre el algoritmo kNN es cuántos vecinos deberíamos revisar. Como norma general, existen dos problemas principales a la hora de elegir la k:

  • Si la k es demasiado baja, el algoritmo utilizará muy pocos vecinos, lo que tenderá al overfitting.
  • Si la k es demasiado alta, la predicción tenderá cada vez más a parecerse a la media, por lo que tendrás un problema de underfitting.

Como norma general, la k se suele elegir mediante la siguiente fórmula:

\[ k = \sqrt{n} \]

Por curiosidad, vamos a comprobar cuál sería el RMSE de nuestro modelo al ajustar la k de esta manera en el caso de la regresión:

k = round(sqrt(nrow(x_fit)))

# Make predictions
predictions = knn(x_fit, x_pred, 'lifeExp', k )
predictions_w = knn(x_fit, x_pred, 'lifeExp', k , weighted_pred = T)

# Get RMSE for both cases
rmse_model1 = rmse(predictions, x_pred$lifeExp)
rmse_model2 = rmse(predictions_w, x_pred$lifeExp)


cat(
  '-- k correctly selected --','\n',
  'selected value of k: ', k,'\n',
  'RMSE kNN: ', rmse_model1,'\n',
  'RMSE weighted kNN: ', rmse_model2, sep=''
)
-- k correctly selected --
selected value of k: 11
RMSE kNN: 12.15191
RMSE weighted kNN: 12.69471

Como vemos, un mejor número de k, la predicción mejora su RMSE al menos 1 punto, lo cual no está nada mal.

Aunque esta regla sea bastante sencilla y rápida, otra forma de elegir el número de k es iterando el modelo con diferentes ks. De esta forma, nos podemos asegurar que la k que elegimos es la que maximiza los resultados (al menos para ese dataset).

Veámos cómo se hace:

n_iterations =  50

errors = c()

for(i in 1:n_iterations){
  prediction = knn(x_fit, x_pred, 'lifeExp', k = i)
  errors[i] = prediction
  if(i%%10==0){print(i)}
}

error = data.frame(k = c(1:n_iterations), error = errors)

error %>%
  ggplot(aes(k,error)) + geom_line() +
  geom_vline(xintercept = 4, linetype = 'dashed') +
  theme_minimal() +
  scale_x_continuous(breaks = seq(0,n_iterations,2)) +
  labs(title = 'RMSE evolution for different values of k',
       subtitle = 'Non-weighted kNN'
       )
Cómo elegir el número de K en el algoritmo kNN

Como podemos ver, con k = 4 obtenemos la menor cantidad de RMSE. Antes de eso, la predicción está sufriendo de overfitting y con k>4, cada vez predecimos peor hasta que con k >8 el modelo deja de generalizar y empieza a sufrir de underfitting.

Sin embargo, lo malo de obtener el número de k de esta forma es que es computacionalmente muy costoso, lo que hace que pierda uno de los valores del algoritmo kNN.

Normalización de los datos

Al igual que vimos con el algoritmo k-Means, para que el algoritmo kNN funcione adecuadamente debemos normalizar los datos sobre los que lo vamos a aplicar.

El motivo es que si tenemos dos variables con escalas muy diferentes (véase la población y el PIB per Cápita), aquella variable con valores más altos será la que determine la distancia. Esto lleva a que, en la práctica, en vez de que el algoritmo tenga en cuenta ambas variables, se está centrando únicamente en una variable para hacer la predicción.

gapminder_2007 %>%
  ggplot(aes(pop, gdpPercap)) + geom_point() +
  theme_minimal() +
  labs(title = 'Distribución Pop vs Gdp per Cap', 
       subtitle = 'Datos sin escalar') 
Distribución de la población vs el PIB per Cap sin escalar

Si escalamos los datos, veremos cómo encontrar a los vecinos más cercanos resultará mas sencillo:

normalization = function(x){
  (x - min(x))/(max(x) - min(x))
}

gapminder_2007$pop_norm = normalization(gapminder_2007$pop)
gapminder_2007$gdpPercap_norm = normalization(gapminder_2007$gdpPercap)

gapminder_2007 %>%
  ggplot(aes(pop_norm, gdpPercap_norm)) + geom_point() +
  theme_minimal() +
  labs(title = 'Distribución Pop vs Gdp per Cap', 
       subtitle = 'Datos sin escalar')
Distribución población vs Pib per Cápita datos escalados

Los datos parece que no hayan cambiado (debido a los outliers), pero si comprobamos el RMSE veremos como se ha reducido significativamente:

x_fit = gapminder_2007[-train_ind, c(4,7,8)]
x_pred = gapminder_2007[train_ind, c(4,7,8)]

prediction = knn(x_fit, x_pred, 'lifeExp', k = 11)

cat(
  'RMSE kNN sin normalizar: ', rmse_model1,'\n',
  'RMSE kNN datos normalizados: ', rmse(prediction, x_pred$lifeExp), sep=''
)
RMSE kNN sin normalizar: 12.15191
RMSE kNN datos normalizados: 7.510772

Como vemos, la predicción ha mejorado mucho y esto simplemente se debe al hecho de haber normalizado los datos.

Usar variables categóricas en kNN

Si te has fijado, en los ejemplos utilizados hasta ahora únicamente he utilizado variables numéricas. Y es que, para poder utilizar variables categóricas en el algoritmo kNN que acabamos de programar en R, hay dos opciones:

  • Convertir las categorías en variables numéricas, aplicando técnicas como el one-hot encoding (típico de redes neuronales).
  • Utilizar una medida de distancia que sí permita trabajar con variables categóricas. Y es que, mientras que algunas medidas de distancia (como la distancia Euclídea o la distancia Manhattan) únicamente admiten variables numéricas, otras medidas de distancia de Hamming o la distancia de Gower.

Conclusión

Ahora que ya sabemos todo lo que hay que saber sobre el algoritmo kNN: lo hemos programado desde 0 en R, las distintas medidas de distancia que puede usar y cuándo usar cada una de ellas, cómo usarlo correctamente, cómo elegir la k y sabemos cómo hacer para que use variables categóricas.

Espero que este post te haya resultado útil. Si te gustaría seguir aprendiendo sobre diferentes algoritmos en R, te recomendaría te suscribas a la newsletter para estar atento de los posts que voy subiendo. En cualquier caso, ¡nos vemos en el siguiente!