Finding the Optimal k
tags: #ML/supervised/classification/knn
How do we find the optimal k (the number of neighbours)?
The basic idea is to try different values of k, evaluate the model's performance for each one using cross-validation, and select the value of k that gives the best performance.
Approach 1: Cross-Validation
- Perform cross-validation (e.g., 5-fold) for each candidate value of k and compute the average performance metric (e.g., accuracy) across the folds. The number of folds is unrelated to k, the number of neighbours.
- Plot the average performance metric against k and select the value of k that gives the highest performance.
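```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# example data (assumption): replace with your own feature matrix X and labels y
X, y = load_iris(return_X_y=True)
# partition dataset; X_val/y_val are held out for a final check after choosing k
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# k values (numbers of neighbours) to test: 1, 2, ..., 9
k_values = np.arange(1, 10, 1)
# 5-fold CV splitter, created once so every k is scored on the same folds
kfold = KFold(n_splits=5, random_state=11, shuffle=True)
for k in k_values:
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X_train, y_train, cv=kfold)
    print(f'k={k:<2}; mean accuracy={scores.mean():.2%}; sd={scores.std():.2%}')
```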
This code runs 5-fold cross-validation on the training split for each k from 1 to 9 and prints the mean and standard deviation of the accuracy across the folds.
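To see what `cross_val_score` and the selection step from the list actually compute, the same procedure can be sketched without scikit-learn. This is a minimal pure-Python sketch under stated assumptions: Euclidean distance, plain majority voting, and shuffled (non-stratified) folds. The helper names `knn_predict`, `cv_accuracy` and `select_k`, and the two-cluster toy data, are hypothetical and for illustration only.

```python
import random
from collections import Counter


def knn_predict(train_X, train_y, x, k):
    # Indices of the k training points closest to x (squared Euclidean distance)
    nearest = sorted(
        range(len(train_X)),
        key=lambda i: sum((a - b) ** 2 for a, b in zip(train_X[i], x)),
    )[:k]
    # Majority vote over the labels of those neighbours
    return Counter(train_y[i] for i in nearest).most_common(1)[0][0]


def cv_accuracy(X, y, k, n_folds=5, seed=0):
    # Shuffle indices once, then deal them into n_folds folds (assumes len(X) >= n_folds)
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    folds = [idx[f::n_folds] for f in range(n_folds)]
    fold_scores = []
    for fold in folds:
        held_out = set(fold)
        train_idx = [i for i in idx if i not in held_out]
        tr_X = [X[i] for i in train_idx]
        tr_y = [y[i] for i in train_idx]
        # Accuracy on the held-out fold, using the other folds as the neighbour pool
        correct = sum(knn_predict(tr_X, tr_y, X[i], k) == y[i] for i in fold)
        fold_scores.append(correct / len(fold))
    # Average accuracy over the folds
    return sum(fold_scores) / n_folds


def select_k(X, y, k_values, n_folds=5):
    # Mean CV accuracy for every candidate k (same seed, so same folds for every k)
    mean_scores = {k: cv_accuracy(X, y, k, n_folds) for k in k_values}
    # max() keeps the first maximum it meets, so ties go to the smallest k
    best_k = max(mean_scores, key=mean_scores.get)
    return best_k, mean_scores


# Hypothetical toy data: two well-separated 2-D Gaussian clusters
rng = random.Random(42)
X = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(30)]
X += [(rng.gauss(5, 1), rng.gauss(5, 1)) for _ in range(30)]
y = [0] * 30 + [1] * 30

best_k, mean_scores = select_k(X, y, range(1, 10))
for k, score in mean_scores.items():
    print(f"k={k}; mean accuracy={score:.2%}")
print("best k:", best_k)
```

Reusing the same seed for every k keeps the folds identical across candidates, so differences in mean accuracy come from k rather than from the split. Breaking ties toward the smaller k is a simple, arbitrary choice made for this sketch.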
Approach 2: GridSearchCV
Another approach is grid search: we specify a set of candidate values for each hyperparameter (here, just k) and a cross-validation scheme, and `GridSearchCV` cross-validates every combination in the grid and reports the one with the best mean score.
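"Every combination" means the Cartesian product of the candidate lists, and each combination is cross-validated separately, so the cost grows multiplicatively with each added hyperparameter. Below is a small pure-Python sketch of that enumeration, conceptually what scikit-learn's `ParameterGrid` supplies to `GridSearchCV`. The extra keys `weights` and `p` are genuine `KNeighborsClassifier` parameters, but this three-parameter grid and the helper `iter_grid` are hypothetical examples; this note itself only tunes `n_neighbors`.

```python
from itertools import product

# Hypothetical grid: n_neighbors plus two other KNeighborsClassifier options
param_grid = {
    "n_neighbors": list(range(1, 10)),
    "weights": ["uniform", "distance"],
    "p": [1, 2],  # Minkowski exponent: 1 = Manhattan, 2 = Euclidean
}


def iter_grid(grid):
    # Yield one dict per point in the Cartesian product of the value lists
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


candidates = list(iter_grid(param_grid))
n_folds = 5
# Each candidate is fitted once per fold during the search
total_fits = len(candidates) * n_folds
print(f"{len(candidates)} candidates x {n_folds} folds = {total_fits} fits")
```

With `refit=True` (the default), `GridSearchCV` adds one final fit of the best combination on the whole training set. A grid with the single key `n_neighbors` over 1 to 9 yields 9 candidates and 45 cross-validation fits with 5 folds.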
```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
# create param grid; keys must match the estimator's parameter names
param_grid = {
    'n_neighbors': np.arange(1, 10, 1)
}
# instantiate estimator (n_neighbors is set by the grid search)
knn_model = KNeighborsClassifier()
# Perform grid search with 5-fold cross-validation over every value in the grid
grid_search = GridSearchCV(estimator=knn_model, param_grid=param_grid, cv=5)
# Fit the grid search to the training data (X_train, y_train from Approach 1)
grid_search.fit(X_train, y_train)
# Print the best hyperparameters
print("Best hyperparameters: ", grid_search.best_params_)
# Print best score (mean cross-validated accuracy of the best k)
print("Best Score: {}".format(grid_search.best_score_))
```
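`best_score_` is the mean cross-validated accuracy of the winning k on the training folds; because k was chosen to maximise it, it tends to be slightly optimistic. The validation split held out in Approach 1 gives a less biased check. The sketch below is self-contained and assumes the iris dataset as stand-in data: with the default `refit=True`, `GridSearchCV` refits the best k on all of `X_train` and exposes it as `best_estimator_`, which can then be scored on `X_val`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Stand-in data (assumption): replace with your own X, y
X, y = load_iris(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# refit=True is the default, so the best k is refit on all of X_train
grid_search = GridSearchCV(
    estimator=KNeighborsClassifier(),
    param_grid={"n_neighbors": np.arange(1, 10)},
    cv=5,
)
grid_search.fit(X_train, y_train)

best_k = grid_search.best_params_["n_neighbors"]
cv_acc = grid_search.best_score_
# Accuracy of the refit best model on data never seen during the search
val_acc = grid_search.best_estimator_.score(X_val, y_val)
print(f"best k={best_k}; CV accuracy={cv_acc:.2%}; held-out accuracy={val_acc:.2%}")
```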