Finding the Optimal k

tags: #ML/supervised/classification/knn

How do we find the optimal k?

To find the optimal k in k-Nearest Neighbors, we can use cross-validation.

The basic idea is to try different values of k and evaluate the model performance using cross-validation. We can then select the value of k that gives the best performance.
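To make the idea concrete, here is a minimal from-scratch sketch of the search, using only the Python standard library. The helper names `knn_predict` and `cv_accuracy` and the toy two-cluster data are made up for illustration. This is not how scikit-learn implements it, just one way to spell out the loop "for each k, average the accuracy over the folds, keep the best k".

```python
import random
from collections import Counter


def knn_predict(train_X, train_y, query, k):
    """Majority vote among the k training points closest to `query`."""
    # Squared Euclidean distance is enough for ranking neighbors.
    by_distance = sorted(
        zip(train_X, train_y),
        key=lambda pair: sum((a - b) ** 2 for a, b in zip(pair[0], query)),
    )
    votes = Counter(label for _, label in by_distance[:k])
    # On a tied vote, Counter keeps first-seen order, so the nearer label wins.
    return votes.most_common(1)[0][0]


def cv_accuracy(X, y, k, n_folds=5, seed=0):
    """Mean accuracy of k-NN over `n_folds` shuffled folds."""
    indices = list(range(len(X)))
    random.Random(seed).shuffle(indices)
    folds = [indices[i::n_folds] for i in range(n_folds)]
    fold_scores = []
    for held_out in folds:
        held_out_set = set(held_out)
        train_idx = [i for i in indices if i not in held_out_set]
        train_X = [X[i] for i in train_idx]
        train_y = [y[i] for i in train_idx]
        correct = sum(
            knn_predict(train_X, train_y, X[i], k) == y[i] for i in held_out
        )
        fold_scores.append(correct / len(held_out))
    return sum(fold_scores) / n_folds


# Toy data (hypothetical): two well-separated 2-D clusters, 30 points each.
rng = random.Random(42)
X = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(30)] + [
    (rng.gauss(6, 1), rng.gauss(6, 1)) for _ in range(30)
]
y = [0] * 30 + [1] * 30

# Try each candidate k and keep the one with the highest mean accuracy.
results = {k: cv_accuracy(X, y, k) for k in range(1, 10)}
best_k = max(results, key=results.get)  # ties go to the smallest k

for k, acc in results.items():
    print(f"k={k}: mean accuracy={acc:.2%}")
print("best k:", best_k)
```

Note that the number of folds (`n_folds`) and the number of neighbors (`k`) are unrelated settings that happen to share a letter in the usual terminology.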

Approach 1: Cross-Validation

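  1. For each candidate value of k (the number of neighbors), run cross-validation (e.g., 5-fold) and compute the average performance metric (e.g., accuracy). Note that the number of folds is a separate setting from the k in k-Nearest Neighbors.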

  2. Plot the average performance metric against the value of k and select the value of k that gives the highest performance.

import numpy as np
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier


# partition dataset (X and y are assumed to be the already-loaded features and labels)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# k values to test
k_values = np.arange(1, 10, 1)

# use the same shuffled 5-fold split for every k so the scores are comparable
kfold = KFold(n_splits=5, random_state=11, shuffle=True)

for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X_train, y_train, cv=kfold)
    print(f'k={k:<2}; mean accuracy={scores.mean():.2%}; sd={scores.std():.2%}')

This code runs 5-fold cross-validation on the training set for each number of neighbors k from 1 to 9, and prints the mean and standard deviation of the accuracy across the five folds for each k.
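The selection step can also be done in code rather than by reading the printed table. The following is a minimal sketch, assuming the Iris dataset as a stand-in (the note does not say which data `X` and `y` hold): it collects the mean score for each k and picks the k with the highest one.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Iris is an assumption, chosen only so the sketch runs on its own.
X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)

k_values = np.arange(1, 10, 1)
kfold = KFold(n_splits=5, random_state=11, shuffle=True)

# Mean cross-validated accuracy for each candidate k.
mean_scores = np.array([
    cross_val_score(
        KNeighborsClassifier(n_neighbors=k), X_train, y_train, cv=kfold
    ).mean()
    for k in k_values
])

# np.argmax returns the first maximum, so ties go to the smallest k.
best_k = int(k_values[np.argmax(mean_scores)])
print(f"best k={best_k}; mean accuracy={mean_scores.max():.2%}")
```

When several values of k score almost the same, the standard deviation printed by the loop is worth checking before trusting a small difference in the means.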

Approach 2: GridSearchCV

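Another approach is to use grid search with GridSearchCV: we specify the candidate values of k and the cross-validation scheme, and it evaluates every candidate with cross-validation and reports the best one. With several hyperparameters in the grid, it searches over every combination of their values.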

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# create param grid: keys are the estimator's parameter names, values are the candidates
param_grid = {
    'n_neighbors': np.arange(1, 10, 1)
}

# instantiate estimator
knn_model = KNeighborsClassifier()

# Perform grid search with cross-validation 
grid_search = GridSearchCV(estimator=knn_model, param_grid=param_grid, cv=5) 

# Fit the grid search to the training data 
grid_search.fit(X_train, y_train)

# Print the best hyperparameters 
print("Best hyperparameters: ", grid_search.best_params_) 

# Print best score 
print("Best Score: {}".format(grid_search.best_score_))
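The first snippet sets aside `X_val` and `y_val` but never uses them. As a hedged sketch of how they could be used, the example below (again assuming Iris as stand-in data) lists the mean cross-validated accuracy for every candidate k and then scores the best model on the held-out split.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Iris is an assumption, chosen only so the sketch runs on its own.
X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)

grid_search = GridSearchCV(
    estimator=KNeighborsClassifier(),
    param_grid={"n_neighbors": np.arange(1, 10, 1)},
    cv=5,
)
grid_search.fit(X_train, y_train)

# Mean cross-validated accuracy for every candidate k.
for params, mean_score in zip(
    grid_search.cv_results_["params"],
    grid_search.cv_results_["mean_test_score"],
):
    print(f"k={params['n_neighbors']}: mean accuracy={mean_score:.2%}")

# With the default refit=True, best_estimator_ is refit on all of X_train.
best_k = grid_search.best_params_["n_neighbors"]
val_accuracy = grid_search.best_estimator_.score(X_val, y_val)
print(f"validation accuracy with k={best_k}: {val_accuracy:.2%}")
```

Because the held-out split played no part in choosing k, its accuracy is a less optimistic estimate than `best_score_`.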