Practical machine learning problems often show significant differences in how rare the classes being predicted are. For example, when detecting cancer we can expect datasets with a large number of negative outcomes and a relatively small number of positive outcomes.
The overall performance of any model trained on such data will be constrained by its ability to predict rare points. In problems where these rare points are no more important than non-rare points, this constraint may only become significant in the later “tuning” stages of building the model. But in problems where the rare points are important, or even the whole point of the classifier (as in the cancer example), dealing with their scarcity is a first-order concern for the model builder.
Tangentially, note that the relative importance of performance on rare observations should inform your choice of error metric for the problem you are working on; the more important they are, the more your metric should penalize underperformance on them. See my previous Model Fit Metrics and Log Loss notebooks for slightly more detail on which error metrics do and don’t care about this problem.
Several different techniques exist in practice for dealing with imbalanced datasets. The most naive class of techniques is sampling: changing the data presented to the model by undersampling common classes, oversampling (duplicating) rare classes, or both.
We’ll motivate why under- and over-sampling are useful with an example. The following visualization shows the radical effect that the relative number of points per class in a dataset has on classification as performed by a (linear-kernel) Support Vector Machine.
As you can see, when a dataset is dominated by one or a few classes, to the exclusion of the others, the optimal solution can collapse into a model that simply assigns all or most points to the majority class (as in the first and second visualizations in this grid). However, as the number of observations per class approaches an equal split, the classifier becomes less and less biased.
Re-sampling the points fed into the model is the simplest way to fix model errors like this one that stem from rare classes. Not all datasets have this issue, but for those that do, dealing with it is an important early step in modeling the data.
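The effect can also be checked numerically rather than visually. The sketch below is only an illustration under assumed settings: the synthetic data from make_classification, the three minority shares, and the use of training-set recall are all choices made here, not the setup behind the figure.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import recall_score
from sklearn.svm import SVC

minority_recall = {}
for minority_share in (0.01, 0.1, 0.5):
    # Two-class toy data; class 1 makes up `minority_share` of the points.
    X, y = make_classification(
        n_samples=1000, n_features=2, n_informative=2, n_redundant=0,
        n_clusters_per_class=1, weights=[1 - minority_share, minority_share],
        flip_y=0, random_state=0,
    )
    clf = SVC(kernel="linear").fit(X, y)
    # Fraction of class 1 points that the classifier actually labels as class 1.
    minority_recall[minority_share] = recall_score(y, clf.predict(X), pos_label=1)

for share, recall in minority_recall.items():
    print(f"minority share {share:.2f}: recall on class 1 = {recall:.2f}")
```

The exact numbers depend on the generated data, so treat them as indicative only.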
Getting some sample data
We’ll use the following data for the sake of illustration (taken from the sklearn documentation):
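One way to generate data of this shape is sketched below. It assumes sklearn's make_classification with three classes and heavily skewed weights; the specific parameter values are assumptions chosen for illustration. Note that the generated labels are 0-based and follow the order of the weights argument, so here label 2 is the common class and label 0 the rarest.

```python
from collections import Counter

from sklearn.datasets import make_classification

# Three classes with roughly 1%, 5% and 94% of the observations (assumed values).
X, y = make_classification(
    n_samples=5000, n_features=2, n_informative=2, n_redundant=0,
    n_repeated=0, n_classes=3, n_clusters_per_class=1,
    weights=[0.01, 0.05, 0.94], class_sep=0.8, random_state=0,
)

class_counts = Counter(y)
print(sorted(class_counts.items()))
```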
A group of researchers implemented the full suite of modern data sampling techniques in imbalanced-learn, a contrib module for sklearn. It is distributed as a separate package (imbalanced-learn on PyPI, imported as imblearn) rather than as part of the base sklearn install, although many hosted notebook environments ship with it preinstalled. It comes with its own documentation as well.
imblearn implements over-sampling and under-sampling using dedicated classes.
In the first graph we have oversampled the dataset, duplicating points from the rare class 2 and the ultra rare class 3 in order to match the common class 1. This results in many points getting “pasted over” a huge number of times, as there are just 64 distinct points in class 3, but 4700 of them in class 1.
In the second graph we have undersampled the dataset. This goes the other way: classes 1 and 2 are reduced in numeric volume until they reach the same number of observations as class 3, at 64.
Which of these two fields of points are you better off training your classifier on? In extreme cases where the number of observations in the rare class(es) is really small, oversampling is better, as you will not lose important information on the distribution of the other classes in the dataset. For example, if there were just five observations in class 3, we’d have an awful time training a classifier on just 15 (5 times 3) undersampled points!
Outside of this case, however, the performance of one or the other will be mostly indistinguishable. Remember, sampling doesn’t introduce new information into the dataset; it (hopefully) merely shifts it around so as to increase the “numerical stability” of the resulting models.
By default the number of observations will be balanced, i.e. each class will appear equally many times. There’s also a ratio parameter (renamed sampling_strategy in newer imbalanced-learn releases), which allows you to choose the number of observations per class (in terms of integer absolute numbers, e.g. “60 class 1 observations”) to push into your sample dataset. For example:
The metric trap
One of the major traps that novice users fall into when dealing with unbalanced datasets relates to the metrics used to evaluate their model. Simple metrics like accuracy_score can be misleading. In a dataset with highly unbalanced classes, a classifier that always “predicts” the most common class without performing any analysis of the features will still have a high accuracy rate, which is obviously illusory.
Let’s do this experiment, using simple cross-validation and no feature engineering:
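The trap is easy to reproduce on toy data. The sketch below is an assumption-laden stand-in for the experiment: it uses synthetic data with a 95/5 class split and sklearn's DummyClassifier, which always predicts the most frequent class, rather than the competition data or a real model.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score
from sklearn.model_selection import cross_val_predict

# Binary toy data with about 95% of observations in class 0 (assumed split).
X, y = make_classification(
    n_samples=2000, n_features=5, weights=[0.95, 0.05],
    flip_y=0, random_state=0,
)

# A "model" that ignores the features and always predicts the majority class.
baseline = DummyClassifier(strategy="most_frequent")
y_pred = cross_val_predict(baseline, X, y, cv=5)

accuracy = accuracy_score(y, y_pred)
minority_recall = recall_score(y, y_pred, pos_label=1)
print(f"accuracy: {accuracy:.3f}")
print(f"recall on the rare class: {minority_recall:.3f}")
```

Accuracy looks excellent even though not a single rare-class observation is detected.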
As we can see, the high accuracy rate was just an illusion. The choice of metric is therefore extremely important for unbalanced datasets. In this competition, the evaluation metric is the Normalized Gini Coefficient, a more robust metric for imbalanced datasets. The raw Gini coefficient ranges from approximately 0 for random guessing to approximately 0.5 for a perfect score; normalizing it by that theoretical maximum rescales a perfect score to 1.
An interesting way to evaluate the results is by means of a confusion matrix, which shows the correct and incorrect predictions for each class. In the first row, the first column indicates how many class 0 entries were predicted correctly, and the second column how many class 0 entries were predicted as class 1. In the second row, we note that all class 1 entries were erroneously predicted as class 0.
Therefore, the higher the diagonal values of the confusion matrix the better, indicating many correct predictions.
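As a small worked illustration, the sketch below builds the confusion matrix for a hypothetical classifier that predicts class 0 for everything. The label vectors are made up for the example (95 class 0 entries and 5 class 1 entries).

```python
from sklearn.metrics import confusion_matrix

# Hypothetical ground truth: 95 entries of class 0, 5 entries of class 1.
y_true = [0] * 95 + [1] * 5
# Hypothetical predictions: the classifier always answers class 0.
y_pred = [0] * 100

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
print(cm)
# [[95  0]
#  [ 5  0]]
```

The second diagonal entry is 0: no class 1 entry was predicted correctly, even though accuracy is 95%.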
A widely adopted technique for dealing with highly unbalanced datasets is called resampling. It consists of removing samples from the majority class (under-sampling) and / or adding more examples from the minority class (over-sampling).
Despite the advantage of balancing classes, these techniques also have their weaknesses (there is no free lunch). The simplest implementation of over-sampling is to duplicate random records from the minority class, which can cause overfitting. In under-sampling, the simplest technique involves removing random records from the majority class, which can cause loss of information.
Let’s implement a basic example, which uses the DataFrame.sample method to get random samples from each class:
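A minimal sketch with pandas, assuming a toy DataFrame with a binary target column. The column names and class sizes are made up for illustration.

```python
import numpy as np
import pandas as pd

# Toy data: 950 rows of class 0 and 50 rows of class 1 (assumed sizes).
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "feature": rng.normal(size=1000),
    "target": [0] * 950 + [1] * 50,
})

df_class_0 = df[df["target"] == 0]
df_class_1 = df[df["target"] == 1]

# Random under-sampling: keep only as many majority rows as there are minority rows.
df_under = pd.concat([
    df_class_0.sample(n=len(df_class_1), random_state=0),
    df_class_1,
])

# Random over-sampling: draw minority rows with replacement up to the majority size.
df_over = pd.concat([
    df_class_0,
    df_class_1.sample(n=len(df_class_0), replace=True, random_state=0),
])

print(df_under["target"].value_counts().to_dict())
print(df_over["target"].value_counts().to_dict())
```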
Python imbalanced-learn module
A number of more sophisticated resampling techniques have been proposed in the scientific literature.
For example, we can cluster the records of the majority class, and do the under-sampling by removing records from each cluster, thus seeking to preserve information. In over-sampling, instead of creating exact copies of the minority class records, we can introduce small variations into those copies, creating more diverse synthetic samples.
Let’s apply some of these resampling techniques, using the Python library imbalanced-learn. It is compatible with scikit-learn and is part of the scikit-learn-contrib projects.
Random under-sampling and over-sampling with imbalanced-learn
Under-sampling: Tomek links
Tomek links are pairs of instances from opposite classes that are each other’s nearest neighbors. Removing the majority-class instance of each pair increases the space between the two classes, facilitating the classification process.
Under-sampling: Cluster Centroids
This technique performs under-sampling by generating centroids based on clustering methods. The majority-class data is first grouped by similarity, and each cluster is then replaced by its centroid, in order to preserve information.
SMOTE (Synthetic Minority Oversampling TEchnique) consists of synthesizing elements for the minority class, based on those that already exist. It works by randomly picking a point from the minority class and computing the k-nearest neighbors for this point. The synthetic points are added between the chosen point and its neighbors.
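The interpolation step can be sketched in a few lines of NumPy. This is a simplified illustration of the idea, not imbalanced-learn's implementation (which is available as imblearn.over_sampling.SMOTE); the function name smote_sketch and the tiny example data are hypothetical.

```python
import numpy as np

def smote_sketch(X_min, n_new, k=5, seed=0):
    """Create n_new synthetic points by interpolating between minority points."""
    rng = np.random.default_rng(seed)
    X_min = np.asarray(X_min, dtype=float)

    # Pairwise Euclidean distances between all minority points.
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=2)
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbor
    neighbors = np.argsort(dists, axis=1)[:, :k]  # k nearest neighbors per point

    synthetic = np.empty((n_new, X_min.shape[1]))
    for i in range(n_new):
        base = rng.integers(len(X_min))    # randomly pick a minority point
        nb = rng.choice(neighbors[base])   # pick one of its k nearest neighbors
        gap = rng.random()                 # random position along the segment
        synthetic[i] = X_min[base] + gap * (X_min[nb] - X_min[base])
    return synthetic

# Four minority points on the corners of the unit square (made-up data).
X_min = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
new_points = smote_sketch(X_min, n_new=10, k=2)
print(new_points.round(2))
```

Every synthetic point lies on a line segment between two existing minority points, which is why SMOTE adds variety without leaving the region the minority class already occupies.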
Over-sampling followed by under-sampling
Now, we will do a combination of over-sampling and under-sampling, using the SMOTE and Tomek links techniques: