While working as a data scientist, some of the most frequently occurring problem statements are related to binary classification. A common problem when solving them is class imbalance, which exists when one class has far more observations than the other(s). Example: detecting fraudulent credit card transactions. As shown in the graph below, there are only around 400 fraudulent transactions compared to around 90,000 non-fraudulent ones.
Class imbalance is a common problem in machine learning, especially in classification problems, and imbalanced data can hamper a model's usefulness big time. It appears in many domains, including fraud detection, spam filtering, disease screening, SaaS subscription churn, and advertising click-throughs. Let's understand how to deal with imbalanced data in machine learning.
Learning Objectives
Most machine learning algorithms work best when the number of samples in each class is about equal, because most algorithms are designed to maximize accuracy and reduce overall error.
However, if the dataset has imbalanced classes, you can get a pretty high accuracy just by predicting the majority class while failing to capture the minority class, which is most often the point of creating the model in the first place. For example, if 99% of the data belongs to the majority class, a basic classification model such as logistic regression or a decision tree may struggle to identify the minority-class data points.
Let's say we have a dataset from a credit card company, and we have to find out whether each credit card transaction was fraudulent or not.
But here's the catch… fraudulent transactions are relatively rare: only 6% of the transactions are fraudulent.
Now, before you even start, do you see how the problem might break? Imagine you didn't bother training a model at all and instead wrote a single line of code that always predicts 'no fraudulent transaction':
def transaction(transaction_data):
    return 'No fraudulent transaction'
Well, guess what? Your “solution” would have 94% accuracy!
Unfortunately, that accuracy is misleading.
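To make the 94% figure concrete, here is a small sketch with hypothetical labels (1,000 transactions, 6% fraud) and the always-'no fraud' predictor from above. It computes accuracy alongside recall on the fraud class, which is the share of actual frauds that get flagged.

```python
import numpy as np

# Hypothetical labels: 1,000 transactions, 6% fraudulent (1 = fraud)
y_true = np.array([1] * 60 + [0] * 940)

# The "model" from above: always predict 0, i.e. "no fraudulent transaction"
y_pred = np.zeros_like(y_true)

# Accuracy: share of all predictions that are correct
accuracy = (y_pred == y_true).mean()

# Recall on the fraud class: share of actual frauds that were flagged
fraud_recall = (y_pred[y_true == 1] == 1).mean()

print(f"accuracy:     {accuracy:.2f}")      # 0.94
print(f"fraud recall: {fraud_recall:.2f}")  # 0.00
```

Accuracy looks excellent while fraud recall is zero, which is exactly the failure the techniques in this article try to address.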
This is clearly a problem because many machine learning algorithms are designed to maximize overall accuracy. In this article, we will see different techniques to handle imbalanced data.
We will use a credit card fraud detection dataset for this article. You can find the dataset here.
After loading the data, display the first five rows of the dataset.
You can clearly see that there is a huge difference between the two classes: 9000 non-fraudulent transactions and only 492 fraudulent ones.
One of the major pitfalls for newcomers working with unbalanced datasets relates to the metrics used to evaluate their machine learning model. Simple metrics like the accuracy score can be misleading. In a dataset with highly unbalanced classes, a classifier can "predict" the most common class every time without analyzing the features at all and still achieve a high accuracy rate, which is obviously not a meaningful result.
Let's do this experiment using a simple XGBClassifier and no feature engineering:
# import library
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
xgb_model = XGBClassifier().fit(x_train, y_train)
# predict
xgb_y_predict = xgb_model.predict(x_test)
# accuracy score (true labels first, then predictions)
xgb_score = accuracy_score(y_test, xgb_y_predict)
print('Accuracy score is:', xgb_score)
OUTPUT
Accuracy score is: 0.992
We get about 99% accuracy, and it is this high mainly because the model predicts the majority class, 0 (non-fraudulent), for almost every transaction.
One of the most widely adopted techniques for dealing with highly unbalanced datasets is called resampling. It consists of removing samples from the majority class (under-sampling) and/or adding more examples from the minority class (over-sampling).
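The XGBoost snippet above assumes that x_train, x_test, y_train and y_test already exist. A minimal sketch of producing them for imbalanced data follows; it uses synthetic data from scikit-learn's make_classification as a stand-in for the fraud dataset (the sample count and the roughly 2% positive rate are assumptions). Passing stratify=y to train_test_split keeps the class ratio the same in both splits, so the rare class cannot end up nearly absent from the test set by chance.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the fraud data: about 2% positive labels (assumption)
x, y = make_classification(
    n_samples=5000, n_features=10, weights=[0.98], random_state=42
)

# stratify=y keeps the class ratio (nearly) identical in both splits
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.25, stratify=y, random_state=42
)

print("positive rate, full data:", y.mean())
print("positive rate, train    :", y_train.mean())
print("positive rate, test     :", y_test.mean())
```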
Despite the advantage of balancing classes, these techniques also have their weaknesses (there is no free lunch).
The simplest implementation of over-sampling is to duplicate random records from the minority class, which can cause overfitting.
In under-sampling, the simplest technique involves removing random records from the majority class, which can cause a loss of information.
Let’s implement this with the credit card fraud detection example.
We will start by separating class 0 and class 1.
# class count: value_counts() sorts by frequency, so the majority class (0) comes first
class_count_0, class_count_1 = data['Class'].value_counts()
# separate class
class_0 = data[data['Class'] == 0]
class_1 = data[data['Class'] == 1]
# print the shape of each class
print('class 0:', class_0.shape)
print('class 1:', class_1.shape)
Undersampling can be defined as removing some observations of the majority class until the majority and minority classes are balanced.
Undersampling can be a good choice when you have a ton of data (think millions of rows). But a drawback of undersampling is that we remove information that may be valuable.
import pandas as pd
# randomly keep as many class-0 rows as there are class-1 rows
class_0_under = class_0.sample(class_count_1)
test_under = pd.concat([class_0_under, class_1], axis=0)
print("total class of 1 and 0:", test_under['Class'].value_counts())
# plot the count after under-sampling
test_under['Class'].value_counts().plot(kind='bar', title='count (target)')
Oversampling can be defined as adding more copies to the minority class. Oversampling can be a good choice when you don’t have a ton of data to work with.
A con to consider when oversampling is that it can cause overfitting and poor generalization to your test set.
# draw class-1 rows with replacement until they match the class-0 count
class_1_over = class_1.sample(class_count_0, replace=True)
test_over = pd.concat([class_1_over, class_0], axis=0)
print("total class of 1 and 0:", test_over['Class'].value_counts())
# plot the count after over-sampling
test_over['Class'].value_counts().plot(kind='bar', title='count (target)')
A number of more sophisticated resampling techniques have been proposed in the scientific literature.
For example, we can cluster the records of the majority class and do the under-sampling by removing records from each cluster, thus seeking to preserve information. In over-sampling, instead of creating exact copies of the minority class records, we can introduce small variations into those copies, creating more diverse synthetic samples.
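The cluster-based idea can be sketched with k-means: cluster the majority class, then keep rows from every cluster in proportion to its size, so no region of the majority class is wiped out. This is one possible reading of the idea, not a specific published algorithm; the helper cluster_undersample and its parameters (n_clusters=10 and so on) are hypothetical, and it assumes n_keep is smaller than the number of majority rows. imbalanced-learn's ClusterCentroids is a related approach that replaces majority clusters with their centroids instead of sampling rows.

```python
import numpy as np
from sklearn.cluster import KMeans

def cluster_undersample(x_major, n_keep, n_clusters=10, random_state=0):
    """Keep n_keep majority-class rows, sampled from every k-means cluster
    in proportion to the cluster's size (hypothetical helper)."""
    rng = np.random.default_rng(random_state)
    labels = KMeans(n_clusters=n_clusters, n_init=10,
                    random_state=random_state).fit_predict(x_major)
    sizes = np.bincount(labels, minlength=n_clusters)
    # proportional quota per cluster, rounded down ...
    quota = np.floor(n_keep * sizes / sizes.sum()).astype(int)
    # ... then hand the rows lost to rounding to the largest clusters
    for c in np.argsort(-sizes)[: n_keep - quota.sum()]:
        quota[c] += 1
    # sample each cluster's quota without replacement
    kept = [rng.choice(np.flatnonzero(labels == c), size=q, replace=False)
            for c, q in enumerate(quota)]
    return x_major[np.concatenate(kept)]

# Usage on hypothetical data: 1,000 majority rows reduced to 100
x_major = np.random.default_rng(1).normal(size=(1000, 2))
x_major_under = cluster_undersample(x_major, n_keep=100)
print(x_major_under.shape)  # (100, 2)
```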
Let's apply some of these resampling techniques using the Python library imbalanced-learn. It is compatible with scikit-learn and is part of the scikit-learn-contrib projects.
import imblearn
You may have heard about pandas, numpy, matplotlib, etc. while learning data science. But there is another library, imblearn, which is used to resample imbalanced datasets and improve your model's performance.
RandomUnderSampler is a fast and easy way to balance the data by randomly selecting a subset of data for the targeted classes. It under-samples the majority class(es) by randomly picking samples with or without replacement.
# import library
from collections import Counter
from imblearn.under_sampling import RandomUnderSampler
# x: feature matrix, y: target labels (e.g. the 'Class' column)
rus = RandomUnderSampler(random_state=42, replacement=True)
# fit predictor and target variable
x_rus, y_rus = rus.fit_resample(x, y)
print('original dataset shape:', Counter(y))
print('Resample dataset shape', Counter(y_rus))
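In the snippet above, x is the feature matrix (one row per transaction) and y holds the class labels. The sketch below reimplements the basic idea behind random under-sampling on NumPy arrays, shrinking every class to the size of the smallest one by sampling without replacement. It illustrates the technique and is not imbalanced-learn's actual implementation; random_undersample is a hypothetical helper and the tiny arrays are made up.

```python
from collections import Counter

import numpy as np

def random_undersample(x, y, random_state=42):
    """Shrink every class to the size of the smallest one by sampling rows
    without replacement (hypothetical helper, not imblearn's code)."""
    rng = np.random.default_rng(random_state)
    classes, counts = np.unique(y, return_counts=True)
    n_min = counts.min()
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=n_min, replace=False)
        for c in classes
    ])
    rng.shuffle(keep)  # avoid returning the rows grouped by class
    return x[keep], y[keep]

# x: feature matrix (one row per sample), y: class labels (toy data)
x = np.arange(20).reshape(10, 2)
y = np.array([0] * 8 + [1] * 2)

x_rus, y_rus = random_undersample(x, y)
print("original:", Counter(y.tolist()))    # Counter({0: 8, 1: 2})
print("resampled:", Counter(y_rus.tolist()))
```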
One way to fight imbalanced data is to generate new samples in the minority classes. The most naive strategy is to generate new samples by randomly sampling, with replacement, from the currently available samples. The RandomOverSampler offers such a scheme.
# import library
from collections import Counter
from imblearn.over_sampling import RandomOverSampler
ros = RandomOverSampler(random_state=42)
# fit predictor and target variable
x_ros, y_ros = ros.fit_resample(x, y)
print('Original dataset shape', Counter(y))
print('Resample dataset shape', Counter(y_ros))
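Random over-sampling is equally simple to sketch: every class is grown to the size of the largest one by re-drawing its existing rows with replacement, so the new minority rows are exact duplicates. The helper random_oversample and the toy arrays are hypothetical; this mirrors the idea behind RandomOverSampler rather than its source code.

```python
from collections import Counter

import numpy as np

def random_oversample(x, y, random_state=42):
    """Grow every class to the size of the largest one by re-drawing its
    existing rows with replacement (hypothetical helper)."""
    rng = np.random.default_rng(random_state)
    classes, counts = np.unique(y, return_counts=True)
    n_max = counts.max()
    parts = []
    for c, n in zip(classes, counts):
        members = np.flatnonzero(y == c)
        # keep every original row, then add random duplicates to reach n_max
        extra = rng.choice(members, size=n_max - n, replace=True)
        parts.append(np.concatenate([members, extra]))
    idx = np.concatenate(parts)
    return x[idx], y[idx]

x = np.arange(20).reshape(10, 2)
y = np.array([0] * 8 + [1] * 2)

x_ros, y_ros = random_oversample(x, y)
print("resampled:", Counter(y_ros.tolist()))  # Counter({0: 8, 1: 8})
```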
Tomek links are pairs of very close instances but of opposite classes. Removing the instances of the majority class of each pair increases the space between the two classes, facilitating the classification process.
A Tomek link exists if two samples of opposite classes are each other's nearest neighbors.
In the code below, we'll use sampling_strategy='majority' to resample only the majority class.
# import library
from collections import Counter
from imblearn.under_sampling import TomekLinks
tl = TomekLinks(sampling_strategy='majority')
# fit predictor and target variable
x_tl, y_tl = tl.fit_resample(x, y)
print('Original dataset shape', Counter(y))
print('Resample dataset shape', Counter(y_tl))
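The definition can be checked directly with scikit-learn's NearestNeighbors: two points form a Tomek link when each is the other's nearest neighbor and their labels differ. The function find_tomek_links and the one-dimensional toy data are hypothetical, and the sketch assumes no duplicate points (with duplicates, a point's first neighbor may not be itself).

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def find_tomek_links(x, y):
    """Return (i, j) index pairs that are each other's nearest neighbor
    but have different labels (hypothetical helper)."""
    # n_neighbors=2: the closest point to each sample is the sample itself
    nn = NearestNeighbors(n_neighbors=2).fit(x)
    nearest = nn.kneighbors(x, return_distance=False)[:, 1]
    pairs = []
    for i, j in enumerate(nearest):
        if i < j and nearest[j] == i and y[i] != y[j]:
            pairs.append((i, int(j)))
    return pairs

# Toy one-dimensional data; label 0 is the majority class
x = np.array([[0.0], [0.1], [5.0], [5.2], [9.0]])
y = np.array([0, 1, 0, 0, 1])

pairs = find_tomek_links(x, y)
print(pairs)  # [(0, 1)]
```

Only points 0 and 1 qualify: points 2 and 3 are mutual nearest neighbors but share a label, and the nearest neighbor of point 4 is point 3, whose own nearest neighbor is point 2. Removing the majority-class member of each pair is what widens the gap between the classes.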
This technique generates synthetic data for the minority class.
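SMOTE (Synthetic Minority Oversampling Technique) works by randomly picking a point from the minority class and computing its k nearest neighbors within the minority class. Synthetic points are then added between the chosen point and its neighbors.
SMOTE algorithm works in 4 simple steps:
1. Choose a minority class sample as the input vector.
2. Find its k nearest minority-class neighbors (k_neighbors is an argument of SMOTE()).
3. Choose one of these neighbors and place a synthetic point at a random position on the line segment joining the sample and the chosen neighbor.
4. Repeat these steps until the classes are balanced.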
# import library
from collections import Counter
from imblearn.over_sampling import SMOTE
smote = SMOTE()
# fit predictor and target variable
x_smote, y_smote = smote.fit_resample(x, y)
print('Original dataset shape', Counter(y))
print('Resample dataset shape', Counter(y_smote))
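The interpolation at the heart of SMOTE fits in a few lines. This sketch draws a random minority point, picks one of its k nearest minority neighbors, and places a new point a random fraction of the way between them (x_new = x_i + λ(x_nn - x_i) with λ in [0, 1)). It is a simplified illustration, not imbalanced-learn's implementation; smote_sketch and the random toy data are hypothetical.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def smote_sketch(x_min, n_new, k=5, random_state=0):
    """Create n_new synthetic minority rows by interpolating between a
    random minority point and one of its k nearest minority neighbors."""
    rng = np.random.default_rng(random_state)
    # k + 1 neighbors because each point's nearest neighbor is itself
    nn = NearestNeighbors(n_neighbors=k + 1).fit(x_min)
    neighbors = nn.kneighbors(x_min, return_distance=False)[:, 1:]
    base = rng.integers(0, len(x_min), size=n_new)            # random minority points
    chosen = neighbors[base, rng.integers(0, k, size=n_new)]  # one neighbor each
    gap = rng.random((n_new, 1))                               # fraction in [0, 1)
    return x_min[base] + gap * (x_min[chosen] - x_min[base])

# Hypothetical minority-class features: 20 points in 2-D
x_min = np.random.default_rng(1).normal(size=(20, 2))
x_new = smote_sketch(x_min, n_new=30, k=3)
print(x_new.shape)  # (30, 2)
```

Because every synthetic point lies on a segment between two real minority points, the new samples stay inside the region the minority class already occupies instead of being exact copies.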
NearMiss is an under-sampling technique. Instead of generating new minority-class samples, it uses distances to decide which majority-class samples to keep, reducing the majority class to the size of the minority class.
from collections import Counter
from imblearn.under_sampling import NearMiss
nm = NearMiss()
x_nm, y_nm = nm.fit_resample(x, y)
print('Original dataset shape:', Counter(y))
print('Resample dataset shape:', Counter(y_nm))
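imbalanced-learn's NearMiss defaults to version 1, which, as we understand its documentation, keeps the majority samples whose average distance to their closest minority samples is smallest. A rough NumPy sketch of that rule for the binary case with k = 3 neighbors is shown below; near_miss_1 and the toy data are hypothetical, and this is not the library's code.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def near_miss_1(x, y, majority=0, minority=1, k=3):
    """Keep the majority rows with the smallest mean distance to their k
    nearest minority rows, until both classes have the same size."""
    x_maj, x_min = x[y == majority], x[y == minority]
    nn = NearestNeighbors(n_neighbors=k).fit(x_min)
    dist, _ = nn.kneighbors(x_maj)  # distances from each majority row
    keep = np.argsort(dist.mean(axis=1), kind="stable")[: len(x_min)]
    x_out = np.vstack([x_maj[keep], x_min])
    y_out = np.concatenate([np.full(len(keep), majority),
                            np.full(len(x_min), minority)])
    return x_out, y_out

# Toy 1-D data: five majority points (0) and three minority points (1)
x = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [0.5], [1.5], [2.5]])
y = np.array([0, 0, 0, 0, 0, 1, 1, 1])

x_out, y_out = near_miss_1(x, y)
print(x_out[y_out == 0].ravel())  # the kept majority points
```

The far-away majority points at 10 and 11 are dropped, leaving the majority samples that sit closest to the minority class.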
Accuracy is not the best metric to use when evaluating imbalanced datasets, as it can be misleading.
Metrics that can provide better insight include the confusion matrix, precision, recall, the F1 score, and the ROC AUC score.
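As a small worked example of these metrics, take ten hypothetical transactions, three of them fraudulent, and a model that catches two frauds while raising one false alarm. Accuracy is 8/10 here, but precision and recall describe the fraud class directly.

```python
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)

# Hypothetical labels for ten transactions (1 = fraud)
y_true = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 1, 1, 1, 0]

# For binary labels, ravel() yields the counts in the order tn, fp, fn, tp
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(tn, fp, fn, tp)  # 6 1 1 2

accuracy = accuracy_score(y_true, y_pred)    # (tn + tp) / 10 = 0.8
precision = precision_score(y_true, y_pred)  # tp / (tp + fp) = 2/3
recall = recall_score(y_true, y_pred)        # tp / (tp + fn) = 2/3
f1 = f1_score(y_true, y_pred)                # harmonic mean of the two = 2/3
print(accuracy, precision, recall, f1)
```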
The next tactic is to use penalized learning algorithms that increase the cost of classification mistakes in the minority class.
A popular algorithm for this technique is Penalized-SVM.
During training, we can use the argument class_weight='balanced' to penalize mistakes on the minority class by an amount proportional to how under-represented it is.
We also want to include the argument probability=True if we want to enable probability estimates for SVM algorithms.
Let’s train a model using Penalized-SVM on the original imbalanced dataset:
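# load library
from sklearn.svm import SVC
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
# class_weight='balanced' penalizes mistakes on the minority class more heavily
svc_model = SVC(class_weight='balanced', probability=True)
svc_model.fit(x_train, y_train)
svc_predict = svc_model.predict(x_test)
# ROC AUC is computed from the predicted probability of class 1, not hard labels
svc_proba = svc_model.predict_proba(x_test)[:, 1]
# check performance
print('ROCAUC score:', roc_auc_score(y_test, svc_proba))
print('Accuracy score:', accuracy_score(y_test, svc_predict))
print('F1 score:', f1_score(y_test, svc_predict))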
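The 'balanced' setting used above weights each class by n_samples / (n_classes * count_of_class). The sketch below, with hypothetical labels that are 6% positive, computes that formula by hand and compares it with scikit-learn's compute_class_weight, which applies the same rule.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical training labels: 94 negatives, 6 positives
y_demo = np.array([0] * 94 + [1] * 6)
classes = np.unique(y_demo)

# 'balanced' rule: weight_c = n_samples / (n_classes * count_c)
manual = len(y_demo) / (len(classes) * np.bincount(y_demo))
library = compute_class_weight(class_weight="balanced", classes=classes, y=y_demo)

print(manual)   # about [0.53, 8.33]
print(library)  # same values
```

With 6 positives out of 100, each mistake on a fraud case counts roughly 15.7 times as much as a mistake on a non-fraud case (8.33 / 0.53).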
While it's a good rule of thumb to try a variety of algorithms on any machine learning problem, doing so can be especially beneficial with imbalanced datasets.
Decision trees often perform well on imbalanced data. In modern machine learning, tree ensembles (random forests, gradient boosted trees, etc.) usually outperform single decision trees, so we'll jump right into those:
Tree-based algorithms work by learning a hierarchy of if/else questions, which can force both classes to be addressed.
# load library
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
rfc = RandomForestClassifier()
# fit the predictor and target
rfc.fit(x_train, y_train)
# predict
rfc_predict = rfc.predict(x_test)
# ROC AUC is computed from the predicted probability of class 1
rfc_proba = rfc.predict_proba(x_test)[:, 1]
# check performance
print('ROCAUC score:', roc_auc_score(y_test, rfc_proba))
print('Accuracy score:', accuracy_score(y_test, rfc_predict))
print('F1 score:', f1_score(y_test, rfc_predict))
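Tree ensembles can also be combined with the penalized-learning idea from the SVM example: RandomForestClassifier accepts class_weight='balanced'. A sketch on synthetic data (make_classification with about 3% positives, an assumption) is shown below; whether reweighting actually improves recall or F1 depends on the data, so compare it against the unweighted model on your own validation split.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, recall_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in data with roughly 3% positives (assumption)
x, y = make_classification(n_samples=3000, n_features=10,
                           weights=[0.97], random_state=0)
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.3, stratify=y, random_state=0)

# class_weight='balanced' reweights samples inversely to class frequency
rfc_balanced = RandomForestClassifier(
    n_estimators=200, class_weight="balanced", random_state=0)
rfc_balanced.fit(x_train, y_train)
pred = rfc_balanced.predict(x_test)

print("Recall (class 1):", recall_score(y_test, pred))
print("F1 score:", f1_score(y_test, pred))
```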
To summarize, in this article we have seen various techniques to handle class imbalance in a dataset. There are many more methods to try when dealing with imbalanced data. You can check the implementation of this code in my GitHub repository here.
Key Takeaways
A. There are multiple ways to sample imbalanced data: you could apply oversampling methods to the minority class, or apply undersampling methods to the majority class.
A. The ratio of classes in a dataset is the proportion of one class relative to the other(s). The threshold for applying sampling techniques varies with your problem statement, but a common rule of thumb is to treat a class as a minority if it makes up less than 10% of the dataset.
A. All the methods that we have talked about in this article can be used for classification problems with imbalanced datasets. That is why it is advantageous for a data scientist to know about sampling libraries like the imblearn library.
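The 10% rule of thumb is easy to check. This sketch, with a hypothetical helper minority_share, computes the fraction of rows that belong to the rarest class, using the 9000 / 492 class counts quoted earlier in the article.

```python
from collections import Counter

def minority_share(labels):
    """Fraction of all rows that belong to the rarest class."""
    counts = Counter(labels)
    return min(counts.values()) / sum(counts.values())

# Class counts quoted earlier in this article: 9000 vs 492
labels = [0] * 9000 + [1] * 492
share = minority_share(labels)
print(f"minority share: {share:.1%}")                # 5.2%
print("below the 10% rule of thumb:", share < 0.10)  # True
```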