Transfer Learning for NLP: Fine-Tuning BERT for Text Classification

Prateek Joshi 20 Jun, 2022 • 10 min read

Introduction

With the advancement of deep learning, neural network architectures like recurrent neural networks (RNNs and LSTMs) and convolutional neural networks (CNNs) have shown decent improvements in performance on several Natural Language Processing (NLP) tasks such as text classification, language modeling, and machine translation.

However, the performance of deep learning models in NLP pales in comparison to the performance of deep learning in computer vision.


One of the main reasons for this slower progress is the lack of large labeled text datasets. Most labeled text datasets are not big enough to train deep neural networks, which have a huge number of parameters; training such networks on small datasets causes overfitting.

Another important reason for NLP lagging behind computer vision was the lack of transfer learning in NLP. Transfer learning has been instrumental in the success of deep learning in computer vision. This happened due to the availability of huge labeled datasets like ImageNet, on which deep CNN-based models were trained and later used as pre-trained models for a wide range of computer vision tasks.

That was not the case with NLP until 2017, when Google introduced the Transformer architecture. Ever since, transfer learning in NLP has been helping solve many tasks with state-of-the-art performance.

In this article, I explain how to fine-tune BERT for text classification.

If you want to learn NLP from scratch, check out our course – Natural Language Processing (NLP) Using Python

 

Table of Contents

  1. Transfer Learning in NLP
  2. What is Model Fine-Tuning?
  3. Overview of BERT
  4. Fine-Tune BERT for Spam Classification

 

Transfer Learning in NLP

Transfer learning is a technique where a deep learning model trained on a large dataset is used to perform similar tasks on another dataset. We call such a deep learning model a pre-trained model. The most renowned examples of pre-trained models are the computer vision deep learning models trained on the ImageNet dataset. So, it is better to use a pre-trained model as a starting point to solve a problem rather than building a model from scratch.

 


 

This breakthrough in transfer learning for computer vision occurred around 2012–13. With recent advances, transfer learning has become a viable option in NLP as well.

Most NLP tasks, such as text classification, language modeling, and machine translation, are sequence modeling tasks. Traditional machine learning models and feed-forward neural networks cannot capture the sequential information present in text. Therefore, people started using recurrent neural networks (RNNs and LSTMs), because these architectures can model the sequential information present in text.

(Image: a typical RNN)

However, these recurrent neural networks have their own set of problems. One major issue is that RNNs cannot be parallelized across time steps, because they take one input at a time. In the case of a text sequence, an RNN or LSTM takes one token at a time as input and passes through the sequence token by token. Hence, training such a model on a big dataset takes a lot of time.

So, the need for transfer learning in NLP was at an all-time high. In 2017, Google introduced the Transformer in the paper "Attention Is All You Need", which turned out to be a groundbreaking milestone in NLP.

(Image: The Transformer – Model Architecture. Source: https://arxiv.org/abs/1706.03762)

Soon a wide range of transformer-based models started coming up for different NLP tasks. There are multiple advantages of using transformer-based models, but the most important ones are:

  • First benefit

    These models do not process an input sequence token by token; they take the entire sequence as input in one go. This is a big improvement over RNN-based models because the computation can now be accelerated by GPUs.

  • Second benefit

    We don’t need labeled data to pre-train these models. We just have to provide a huge amount of unlabeled text data to train a transformer-based model, and we can then use this pre-trained model for other NLP tasks like text classification, named entity recognition, text generation, etc. This is how transfer learning works in NLP.

BERT and GPT-2 are among the most popular transformer-based models. In this article, we will focus on BERT and learn how to use a pre-trained BERT model to perform text classification.

 

What is Model Fine-Tuning?

BERT (Bidirectional Encoder Representations from Transformers) is a big neural network architecture with a huge number of parameters, ranging from 100 million to over 300 million. Training a BERT model from scratch on a small dataset would therefore result in overfitting.

So, it is better to use a pre-trained BERT model that was trained on a huge dataset, as a starting point. We can then further train the model on our relatively smaller dataset and this process is known as model fine-tuning.

Different Fine-Tuning Techniques

  • Train the entire architecture – We can further train the entire pre-trained model on our dataset and feed the output to a softmax layer. In this case, the error is back-propagated through the entire architecture and the pre-trained weights of the model are updated based on the new dataset.
  • Train some layers while freezing others – Another way to use a pre-trained model is to train it partially. We can keep the weights of the initial layers of the model frozen while retraining only the higher layers. We can experiment with how many layers to freeze and how many to train.
  • Freeze the entire architecture – We can even freeze all the layers of the model and attach a few neural network layers of our own and train this new model. Note that the weights of only the attached layers will be updated during model training.

In this tutorial, we will use the third approach. We will freeze all the layers of BERT during fine-tuning and append a dense layer and a softmax layer to the architecture.

 

Overview of BERT

You’ve heard about BERT, you’ve read about how incredible it is, and how it’s potentially changing the NLP landscape. But what is BERT in the first place?

Here’s how the research team behind BERT describes the NLP framework:

“BERT stands for Bidirectional Encoder Representations from Transformers. It is designed to pre-train deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context. As a result, the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of NLP tasks.”

That sounds way too complex as a starting point. But it does summarize what BERT does pretty well so let’s break it down.

Firstly, BERT stands for Bidirectional Encoder Representations from Transformers. Each word here has a meaning, and we will encounter them one by one in this article. For now, the key takeaway from this line is that BERT is based on the Transformer architecture. Secondly, BERT is pre-trained on a large corpus of unlabeled text, including the entire Wikipedia (that’s 2,500 million words!) and Book Corpus (800 million words).

This pre-training step is half the magic behind BERT’s success. When a model is trained on a large text corpus, it starts to pick up a deeper, more intimate understanding of how the language works. This knowledge is the Swiss Army knife that is useful for almost any NLP task.

Thirdly, BERT is a “deep bidirectional” model. Bidirectional means that BERT learns information from both the left and the right side of a token’s context during the training phase.

To learn more about the BERT architecture and its pre-training tasks, you may like to read the article below:

 

Fine-Tune BERT for Spam Classification

Now we will fine-tune a BERT model to perform text classification with the help of the Transformers library. You should have a basic understanding of defining, training, and evaluating neural network models in PyTorch. If you want a quick refresher on PyTorch, you can go through the article below:

Link to Colab Notebook

Problem Statement

We have a collection of SMS messages. Some of these messages are spam and the rest are genuine. Our task is to build a system that would automatically detect whether a message is spam or not.

The dataset that we will be using for this use case can be downloaded from here (right-click and click on “Save link as…”).

I suggest you use Google Colab to perform this task so that you can use the GPU. Firstly, activate the GPU runtime on Colab by clicking on Runtime -> Change runtime type -> Select GPU.

Install Transformers Library

We will then install Hugging Face’s Transformers library. This library lets you import a wide range of transformer-based pre-trained models. Just execute the code below to install it.

!pip install transformers

Import Libraries
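The code blocks in the rest of this walkthrough assume roughly the following imports (PyTorch, NumPy, pandas, scikit-learn, and the Transformers library); the exact list in the original notebook may differ slightly.

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from transformers import AutoModel, BertTokenizerFast

# use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")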

Load Dataset

You would have to upload the downloaded spam dataset to your Colab runtime. Then read it into a pandas dataframe.
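A minimal sketch of loading the data, assuming the uploaded file is a CSV named spamdata_v2.csv (the file name is illustrative – use whatever name your uploaded file has):

# read the uploaded dataset into a dataframe
df = pd.read_csv("spamdata_v2.csv")
print(df.shape)
df.head()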

Output:

(Image: the first few rows of the spam dataset)

The dataset consists of two columns – "label" and "text". The "text" column contains the message body, and "label" is a binary variable where 1 means spam and 0 means the message is not spam.

Now we will split this dataset into three sets – train, validation, and test.

We will fine-tune the model using the train set and the validation set, and make predictions for the test set.
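One reasonable way to do this split (the 70/15/15 ratios and random seed below are illustrative choices; the split is stratified on the label so the spam ratio is preserved in each set):

# split off 30% of the data, then split that half-and-half into validation and test
train_text, temp_text, train_labels, temp_labels = train_test_split(
    df['text'], df['label'],
    random_state=2018,
    test_size=0.3,
    stratify=df['label']
)

val_text, test_text, val_labels, test_labels = train_test_split(
    temp_text, temp_labels,
    random_state=2018,
    test_size=0.5,
    stratify=temp_labels
)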

 

Import BERT Model and BERT Tokenizer

We will import the BERT-base model that has 110 million parameters. There is an even bigger BERT model called BERT-large that has 345 million parameters.

Python Code:
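A minimal sketch using the bert-base-uncased checkpoint from the Hugging Face model hub:

# import the BERT-base pre-trained model
bert = AutoModel.from_pretrained('bert-base-uncased')

# load the matching BERT tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')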

Let’s see how this BERT tokenizer works. We will try to encode a couple of sentences using the tokenizer.
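The two sample sentences below are a reconstruction consistent with the token ids in the output that follows; padding=True pads the shorter sentence to the length of the longer one.

sample_text = ["this is a bert model tutorial", "we will fine-tune a bert model"]

# encode both sentences in one call
tokens = tokenizer.batch_encode_plus(
    sample_text,
    padding=True,
    return_token_type_ids=False
)
print(tokens)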

Output:

{'input_ids': [[101, 2023, 2003, 1037, 14324, 2944, 14924, 4818, 102, 0],
               [101, 2057, 2097, 2986, 1011, 8694, 1037, 14324, 2944, 102]],
 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

 

As you can see the output is a dictionary of two items.

  • ‘input_ids’ contains the integer sequences of the input sentences. The integers 101 and 102 are the special [CLS] and [SEP] tokens added to each sequence, and 0 represents the padding token.
  • ‘attention_mask’ contains 1’s and 0’s. It tells the model to pay attention to the tokens corresponding to the mask value of 1 and ignore the rest.

 

Tokenize the Sentences

Since the messages (text) in the dataset are of varying lengths, we will use padding to make all the messages the same length. We could use the maximum sequence length as the padding length, but let’s first look at the distribution of sequence lengths in the train set to find the right padding length.
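A quick way to inspect that distribution (matplotlib is assumed here; word counts are used as a rough proxy for token counts):

import matplotlib.pyplot as plt

# number of words in each training message
seq_len = [len(str(text).split()) for text in train_text]

pd.Series(seq_len).hist(bins=30)
plt.xlabel('Number of words per message')
plt.show()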

(Image: histogram of message lengths in the train set)

We can clearly see that most of the messages have a length of 25 words or less, whereas the maximum length is 175. So, if we select 175 as the padding length, then all the input sequences will have length 175, and most of the tokens in those sequences will be padding tokens, which will not help the model learn anything useful and will also make training slower.

Therefore, we will set 25 as the padding length.
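A sketch of the tokenization step under that choice – every message is truncated or padded to exactly 25 tokens:

max_seq_len = 25

# tokenize and encode the sequences in the train, validation and test sets
tokens_train = tokenizer.batch_encode_plus(
    train_text.tolist(),
    max_length=max_seq_len,
    padding='max_length',
    truncation=True,
    return_token_type_ids=False
)

tokens_val = tokenizer.batch_encode_plus(
    val_text.tolist(),
    max_length=max_seq_len,
    padding='max_length',
    truncation=True,
    return_token_type_ids=False
)

tokens_test = tokenizer.batch_encode_plus(
    test_text.tolist(),
    max_length=max_seq_len,
    padding='max_length',
    truncation=True,
    return_token_type_ids=False
)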

So, we have now converted the messages in train, validation, and test set to integer sequences of length 25 tokens each.

Next, we will convert the integer sequences to tensors.
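For example (the variable names below are the ones used in the remaining sketches):

# sequences, attention masks and labels for each split as PyTorch tensors
train_seq = torch.tensor(tokens_train['input_ids'])
train_mask = torch.tensor(tokens_train['attention_mask'])
train_y = torch.tensor(train_labels.tolist())

val_seq = torch.tensor(tokens_val['input_ids'])
val_mask = torch.tensor(tokens_val['attention_mask'])
val_y = torch.tensor(val_labels.tolist())

test_seq = torch.tensor(tokens_test['input_ids'])
test_mask = torch.tensor(tokens_test['attention_mask'])
test_y = torch.tensor(test_labels.tolist())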

Now we will create dataloaders for the train and validation sets. These dataloaders will pass batches of train data and validation data to the model during the training phase.
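A minimal dataloader setup (the batch size of 32 is an illustrative choice):

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

batch_size = 32

# training batches are sampled randomly
train_data = TensorDataset(train_seq, train_mask, train_y)
train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=batch_size)

# validation batches are read sequentially
val_data = TensorDataset(val_seq, val_mask, val_y)
val_dataloader = DataLoader(val_data, sampler=SequentialSampler(val_data), batch_size=batch_size)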

 

Define Model Architecture

If you recall, I mentioned earlier in this article that I would freeze all the layers of the model before fine-tuning it. So, let’s do that first.
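Freezing amounts to switching off gradients for every BERT parameter:

# freeze all parameters of the pre-trained BERT model
for param in bert.parameters():
    param.requires_grad = False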

This will prevent the model weights from being updated during fine-tuning. If you wish to fine-tune the pre-trained weights of the BERT model as well, then you should not execute the code above.

Moving on, let’s now define our model architecture.
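Below is a sketch of such a classification head on top of the frozen encoder. The layer sizes (768 is the hidden size of BERT-base; 512 and the 0.1 dropout are illustrative choices) and the LogSoftmax output are one way to realize the dense-plus-softmax head described earlier.

class BERT_Arch(nn.Module):
    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)       # 768 = hidden size of BERT-base
        self.fc2 = nn.Linear(512, 2)         # 2 classes: ham (0) and spam (1)
        self.softmax = nn.LogSoftmax(dim=1)  # log-probabilities, paired with NLLLoss below

    def forward(self, sent_id, mask):
        # pooled [CLS] representation from the frozen BERT encoder
        cls_hs = self.bert(sent_id, attention_mask=mask).pooler_output
        x = self.relu(self.fc1(cls_hs))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)

# wrap the frozen BERT in the classification head and push the model to the GPU
model = BERT_Arch(bert)
model = model.to(device)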

We will use AdamW as our optimizer. It is an improved version of the Adam optimizer. To learn more about it, check out this paper.
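A sketch using PyTorch’s built-in AdamW (the learning rate is a typical value for this setup, not necessarily the one used originally):

from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-5)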

There is a class imbalance in our dataset. The majority of the observations are not spam. So, we will first compute class weights for the labels in the train set and then pass these weights to the loss function so that it takes care of the class imbalance.
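The class weights can be computed with scikit-learn:

from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights are inversely proportional to the class frequencies
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels
)
print(class_weights)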

Output: [0.57743559 3.72848948]
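The weights are then converted to a tensor and passed to the loss function – NLLLoss here, because the model sketched above ends in LogSoftmax. The epoch count matches the training log further down.

weights = torch.tensor(class_weights, dtype=torch.float).to(device)
cross_entropy = nn.NLLLoss(weight=weights)

# number of fine-tuning epochs
epochs = 10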

Fine-Tune BERT

So far, we have defined the model architecture, specified the optimizer and the loss function, and prepared our dataloaders. Now we have to define a couple of functions to train (fine-tune) and evaluate the model, respectively.
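Here is a training-function sketch under the setup above; the progress print mirrors the "Batch 50 of 122." lines in the log below.

def train():
    model.train()
    total_loss = 0
    total_preds = []

    for step, batch in enumerate(train_dataloader):
        # progress update after every 50 batches
        if step % 50 == 0 and step != 0:
            print('Batch {} of {}.'.format(step, len(train_dataloader)))

        # push the batch to the GPU and unpack it
        sent_id, mask, labels = [r.to(device) for r in batch]

        model.zero_grad()
        preds = model(sent_id, mask)
        loss = cross_entropy(preds, labels)
        total_loss += loss.item()

        loss.backward()
        # clip gradients to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_preds.append(preds.detach().cpu().numpy())

    avg_loss = total_loss / len(train_dataloader)
    return avg_loss, np.concatenate(total_preds, axis=0)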

We will use the following function to evaluate the model. It will use the validation set data.
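A matching evaluation sketch (same assumptions as the training function above):

def evaluate():
    print('\nEvaluating...')
    model.eval()
    total_loss = 0
    total_preds = []

    for step, batch in enumerate(val_dataloader):
        sent_id, mask, labels = [t.to(device) for t in batch]

        # no gradient bookkeeping is needed during evaluation
        with torch.no_grad():
            preds = model(sent_id, mask)
            loss = cross_entropy(preds, labels)
            total_loss += loss.item()
            total_preds.append(preds.detach().cpu().numpy())

    avg_loss = total_loss / len(val_dataloader)
    return avg_loss, np.concatenate(total_preds, axis=0)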

Now we will finally start fine-tuning the model.
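A minimal training loop consistent with the log below; it keeps a copy of the weights with the lowest validation loss in saved_weights.pt (the file name is illustrative and is reused in the prediction step later).

best_valid_loss = float('inf')

for epoch in range(epochs):
    print('\nEpoch {} / {}'.format(epoch + 1, epochs))

    train_loss, _ = train()
    valid_loss, _ = evaluate()

    # save the best model so far
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pt')

    print('\nTraining Loss: {:.3f}'.format(train_loss))
    print('Validation Loss: {:.3f}'.format(valid_loss))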

Output (truncated; the earliest epochs are omitted):

Training Loss: 0.592
Validation Loss: 0.567

Epoch 5 / 10
Batch 50 of 122.
Batch 100 of 122.

Evaluating...

Training Loss: 0.566
Validation Loss: 0.543

Epoch 6 / 10
Batch 50 of 122.
Batch 100 of 122.

Evaluating...

Training Loss: 0.552
Validation Loss: 0.525

Epoch 7 / 10
Batch 50 of 122.
Batch 100 of 122.

Evaluating...

Training Loss: 0.525
Validation Loss: 0.498

Epoch 8 / 10
Batch 50 of 122.
Batch 100 of 122.

Evaluating...

Training Loss: 0.507
Validation Loss: 0.477

Epoch 9 / 10
Batch 50 of 122.
Batch 100 of 122.

Evaluating...

Training Loss: 0.488
Validation Loss: 0.461

Epoch 10 / 10
Batch 50 of 122.
Batch 100 of 122.

Evaluating...

Training Loss: 0.474
Validation Loss: 0.454

You can see that the validation loss is still decreasing at the end of the 10th epoch. So, you may try a higher number of epochs. Now let’s see how well it performs on the test dataset.

 

Make Predictions

To make predictions, we will first load the best model weights, which were saved during the training process.
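Assuming the weights were saved to saved_weights.pt as in the training-loop sketch above:

# load the weights of the best model saved during training
path = 'saved_weights.pt'
model.load_state_dict(torch.load(path))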

Once the weights are loaded, we can use the fine-tuned model to make predictions on the test set.
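The test set here is small enough to be scored in a single forward pass; for larger datasets you would batch this step as well.

# get predictions for the test data
with torch.no_grad():
    preds = model(test_seq.to(device), test_mask.to(device))
    preds = preds.detach().cpu().numpy()

# the model outputs log-probabilities; the predicted class is the argmax
preds = np.argmax(preds, axis=1)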

Let’s check out the model’s performance.
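scikit-learn’s classification report gives per-class precision, recall, and F1:

print(classification_report(test_y, preds))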

Output:

(Image: classification report for the test set)

Both recall and precision for class 0 are quite high, which means that the model predicts this class pretty well. However, our objective was to detect spam messages, so misclassifying class 1 (spam) samples is a bigger concern than misclassifying class 0 samples. If you look at the recall for class 1, it is 0.90, which means that the model was able to correctly classify 90% of the spam messages. However, precision is a bit on the lower side for class 1, which means that the model misclassifies some class 0 messages (not spam) as spam.

Link to Colab Notebook

 

End Notes

To summarize, in this article we fine-tuned a pre-trained BERT model to perform text classification on a very small dataset. I urge you to fine-tune BERT on a different dataset and see how it performs. You can even perform multiclass or multi-label classification with the help of BERT. In addition, you can train the entire BERT architecture if you have a bigger dataset.

In case you are looking for a roadmap to becoming an expert in NLP, read the following article:

You may use the comment section in case you have any thoughts to share or have any doubts.

Prateek Joshi 20 Jun 2022

Data Scientist at Analytics Vidhya with multidisciplinary academic background. Experienced in machine learning, NLP, graphs & networks. Passionate about learning and applying data science to solve real world problems.


