Confusion Matrix
tags: #ML/supervised/classification/logit
Python Code
from sklearn.metrics import confusion_matrix, accuracy_score
# confusion matrix: rows = actual classes, columns = predicted classes
cm = confusion_matrix(ytest, ypred)  # ytest = true labels, ypred = labels predicted by the model
print("Confusion Matrix:\n", cm)
# accuracy score of the model
print("Test accuracy =", accuracy_score(ytest, ypred))
Visualizing the Confusion Matrix
We can use seaborn's heatmap function to plot the confusion matrix as an annotated heatmap:
import seaborn as sns
sns.heatmap(confusion_matrix(ytest, ypred), annot=True, fmt="d")  # fmt="d" prints the counts as integers
What is a confusion matrix?
The confusion matrix is an N x N table (where N is the number of classes) that counts the correct and incorrect predictions of a classification model. Each cell holds the number of samples of one actual class that were predicted as a given class.
For a binary classifier, we can evaluate performance from the confusion matrix using four outcomes: true positives (TP), true negatives (TN), false positives (FP) and false negatives (FN).
|              | Predicted 0 | Predicted 1 |
|--------------|-------------|-------------|
| **Actual 0** | TN          | FP          |
| **Actual 1** | FN          | TP          |
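The following minimal sketch shows how the four outcomes map onto scikit-learn's output. The labels are made up for illustration. `confusion_matrix` puts actual classes on the rows and predicted classes on the columns. For binary labels `[0, 1]`, flattening the matrix with `ravel()` therefore yields TN, FP, FN, TP in that order.

```python
from sklearn.metrics import confusion_matrix

# Made-up toy labels for illustration (1 = positive class)
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 1, 1, 0]

cm = confusion_matrix(y_true, y_pred)  # rows = actual, columns = predicted
tn, fp, fn, tp = cm.ravel()            # row-major flattening of the 2 x 2 matrix
print(cm)
print(f"TN={tn} FP={fp} FN={fn} TP={tp}")
```

Because the toy data has two false positives but only one false negative, the off-diagonal cells differ. This makes it easy to check which corner holds which outcome.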
Performance Metrics
Accuracy
This is the proportion of correct predictions out of all predictions:
$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$$
Precision
This is the proportion of CORRECT POSITIVES (TP) out of all POSITIVE predictions made by the model:
$$\text{Precision} = \frac{TP}{TP + FP}$$
Recall
This is the proportion of CORRECT POSITIVES (TP) out of all ACTUAL POSITIVE cases in the data:
$$\text{Recall} = \frac{TP}{TP + FN}$$
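The next sketch computes precision and recall directly from the confusion-matrix counts and then cross-checks them against scikit-learn's `precision_score` and `recall_score`. It uses made-up toy labels and assumes 1 is the positive class, which is scikit-learn's default `pos_label`.

```python
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Made-up toy labels (1 = positive class)
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 1, 1, 0]

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()  # TN=2, FP=2, FN=1, TP=3

precision = tp / (tp + fp)  # 3 / (3 + 2) = 0.6
recall = tp / (tp + fn)     # 3 / (3 + 1) = 0.75

# Cross-check against scikit-learn's built-in metrics
print(precision, precision_score(y_true, y_pred))
print(recall, recall_score(y_true, y_pred))
```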
Error Rate
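This is the proportion of incorrect predictions out of all predictions, which equals 1 minus the accuracy:
$$\text{Error Rate} = \frac{FP + FN}{TP + TN + FP + FN} = 1 - \text{Accuracy}$$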
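Here is a small sketch, on the same made-up toy labels, that computes accuracy and error rate from the four counts. It checks the result against `accuracy_score` and confirms that the error rate is the complement of the accuracy.

```python
from sklearn.metrics import accuracy_score, confusion_matrix

# Made-up toy labels (1 = positive class)
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 1, 1, 0]

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
total = tp + tn + fp + fn

accuracy = (tp + tn) / total    # 5 / 8 = 0.625
error_rate = (fp + fn) / total  # 3 / 8 = 0.375

print(accuracy, accuracy_score(y_true, y_pred))
print(error_rate, 1 - accuracy)
```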
F1-Score
Ideally, we want both high precision (ratio of correct positive predictions to all positive predictions) and high recall (ratio of correct positive predictions to all actual positive cases).
However, the two usually trade off against each other, so this is difficult to achieve. Instead, we can use the F1-score, which combines precision and recall into a single number by taking their harmonic mean, giving equal weight to both:
$$F_1 = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}$$
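The sketch below computes the F1-score as the harmonic mean of precision and recall on the same made-up toy labels and compares it with scikit-learn's `f1_score`. `f1_from` is a hypothetical helper written only for this example. It illustrates why a harmonic mean is used: a model with perfect precision but very poor recall still gets a low F1.

```python
from sklearn.metrics import f1_score, precision_score, recall_score

# Made-up toy labels (1 = positive class)
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 1, 1, 1, 0, 1, 1, 0]

def f1_from(p, r):
    """Hypothetical helper: harmonic mean of precision p and recall r."""
    return 2 * p * r / (p + r)

precision = precision_score(y_true, y_pred)  # 0.6
recall = recall_score(y_true, y_pred)        # 0.75
f1 = f1_from(precision, recall)              # 2 * 0.45 / 1.35 = 2/3

print(f1, f1_score(y_true, y_pred))

# Harmonic mean punishes imbalance: about 0.18, versus an arithmetic mean of 0.55
print(f1_from(1.0, 0.1))
```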
Weighted F1-Score (Multi-classification)
To compute the weighted F1-score for a multi-class classification task from the confusion matrix, where each class's weight is the proportion of samples belonging to that class:
- Calculate the F1-score for each class (from that class's precision and recall).
- Compute each class's weight as its row sum divided by the sum of all entries in the matrix, i.e. the proportion of samples in that class.
- Take the weighted average of the per-class F1-scores using these weights.
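Such that:
$$F_{1,\text{weighted}} = \sum_{i=1}^{N} w_i \cdot F_{1,i}, \qquad w_i = \frac{n_i}{\sum_{j=1}^{N} n_j}$$
where $n_i$ is the number of actual samples of class $i$ (the sum of row $i$) and $F_{1,i}$ is the F1-score of class $i$.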
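The steps above can be sketched with NumPy on a made-up 3-class example. Per-class precision uses column sums (predicted counts), per-class recall uses row sums (actual counts), and the weights are the row sums divided by the total. The result is compared with scikit-learn's `f1_score(..., average="weighted")`, which also weights by class support. This sketch assumes every class is predicted at least once. Otherwise a column sum is zero and that class's precision is undefined.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score

# Made-up 3-class toy labels for illustration
y_true = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 0, 1, 1, 2, 2, 2, 2]

cm = confusion_matrix(y_true, y_pred)  # rows = actual, columns = predicted
tp = np.diag(cm)                       # correct predictions per class

precision = tp / cm.sum(axis=0)        # column sums = predicted count per class
recall = tp / cm.sum(axis=1)           # row sums = actual count (support) per class
f1_per_class = 2 * precision * recall / (precision + recall)

weights = cm.sum(axis=1) / cm.sum()    # proportion of samples in each class
weighted_f1 = np.sum(weights * f1_per_class)

print(weighted_f1, f1_score(y_true, y_pred, average="weighted"))
```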