Saumyab271 — Published On December 31, 2021

This article was published as a part of the Data Science Blogathon


In 2018, Jacob Devlin and his colleagues at Google developed BERT, a powerful Transformer-based machine learning model for NLP applications. BERT is a pre-trained language model that helps machines learn rich representations of text with respect to context, and when it was released it outperformed the previous state of the art on many natural language tasks.

In this article, we will use a pre-trained BERT model for a binary text classification task. In text classification, the main aim of the model is to assign a text to one of a set of predefined categories or labels.

[Figure: Illustration of usage of the BERT model]

In the image above, the output is one of the categories, i.e. 1 or 0 in the case of binary classification. We are going to use the pre-trained BERT model to classify SMS text as ham or spam.

But before moving to the implementation, let’s discuss the concept of BERT and its usage briefly.

What is BERT?

BERT is an acronym for Bidirectional Encoder Representations from Transformers. The BERT architecture is composed of several Transformer encoders stacked together. Each Transformer encoder is in turn composed of two sub-layers: a self-attention layer followed by a feed-forward layer.
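To make the two sub-layers concrete, here is a minimal NumPy sketch of one simplified encoder layer, stacked a few times. It assumes a single attention head, random weights, no biases, and layer normalization without learned scale and shift; real BERT uses multi-head attention and trained parameters in every layer, so treat this as an illustration of the data flow rather than BERT's actual implementation.

```python
import numpy as np

def layer_norm(x, eps=1e-12):
    # Normalize each token vector to zero mean and unit variance
    # (learned scale and shift omitted for brevity)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gelu(x):
    # tanh approximation of GELU, the activation BERT uses in its feed-forward layer
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def make_layer_params(rng, d_model, d_ff):
    # Random weights standing in for one layer's learned parameters
    shapes = {"W_q": (d_model, d_model), "W_k": (d_model, d_model),
              "W_v": (d_model, d_model), "W_o": (d_model, d_model),
              "W_1": (d_model, d_ff), "W_2": (d_ff, d_model)}
    return {name: rng.normal(scale=0.1, size=s) for name, s in shapes.items()}

def encoder_layer(x, p):
    """One simplified encoder layer; x has shape (seq_len, d_model)."""
    # Sub-layer 1: self-attention, where every token attends to every token
    q, k, v = x @ p["W_q"], x @ p["W_k"], x @ p["W_v"]
    weights = softmax(q @ k.T / np.sqrt(k.shape[-1]))
    x = layer_norm(x + weights @ v @ p["W_o"])  # residual connection + norm
    # Sub-layer 2: position-wise feed-forward network
    ff = gelu(x @ p["W_1"]) @ p["W_2"]
    return layer_norm(x + ff)                    # residual connection + norm

rng = np.random.default_rng(0)
# Toy sizes; BERT-Base uses d_model=768, d_ff=3072 and 12 layers
seq_len, d_model, d_ff, n_layers = 5, 8, 32, 3
layers = [make_layer_params(rng, d_model, d_ff) for _ in range(n_layers)]

out = rng.normal(size=(seq_len, d_model))  # stand-in for token embeddings
for params in layers:                      # stacked encoders, each with its own weights
    out = encoder_layer(out, params)
print(out.shape)  # (5, 8)
```

Each layer keeps the shape (sequence length, hidden size), which is what allows encoders to be stacked one on top of another.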

BERT makes use of the Transformer, which learns contextual relations between words in a sentence or text. The original Transformer includes two separate mechanisms: an encoder that reads the text input and a decoder that generates a prediction for the task. Since BERT's goal is to produce a language representation model, it uses only the encoder.

If you are interested in reading more about Transformers, please refer to the original Transformer paper by Google, “Attention Is All You Need”.

In contrast to earlier directional models, which read text sequentially from left to right or right to left, the Transformer encoder reads the entire sentence at once. This bidirectional design lets the model learn from the words on both sides (left and right) of each word, which helps it better understand the context.
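A small NumPy sketch can show what “bidirectional” means in terms of attention. It compares a full attention mask, as used in BERT's encoder, with a causal (left-to-right) mask of the kind used by unidirectional language models. The scores are random stand-ins for query-key products, an assumption made only for illustration.

```python
import numpy as np

def attention_weights(scores, mask):
    # Positions where mask is False get -inf, i.e. zero weight after softmax
    masked = np.where(mask, scores, -np.inf)
    masked = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(masked)
    return e / e.sum(axis=-1, keepdims=True)

seq_len = 4
rng = np.random.default_rng(0)
scores = rng.normal(size=(seq_len, seq_len))  # stand-in for q @ k.T scores

bidirectional = np.ones((seq_len, seq_len), dtype=bool)           # BERT-style encoder
left_to_right = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # causal mask

w_bi = attention_weights(scores, bidirectional)
w_lr = attention_weights(scores, left_to_right)

# Token 1 attends to tokens 0..3 with the bidirectional mask,
# but only to tokens 0..1 with the left-to-right mask.
print(np.round(w_bi[1], 3))
print(np.round(w_lr[1], 3))
```

With the causal mask, every weight to the right of a token is exactly zero, so that token's representation cannot use any words that come after it.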

Text Classification with BERT

Now we’re going to jump to the implementation part where we will classify text using BERT. In this post, we’re going to use the SMS Spam Collection dataset. If you want to follow along, you can download the dataset from here.

This dataset is a tab-separated text file containing 5169 SMS messages, each labeled with one of two categories: ham or spam.

Let’s take a look at the first 5 rows of the dataset to get an idea of what it looks like. The dataset file is named “SMSSpamCollection”.

import pandas as pd
# The file is tab-separated and has no header row
df = pd.read_csv('SMSSpamCollection', sep='\t', names=['label', 'message'])
df.head()

[Figure: first five rows of the dataframe]

As can be seen from the above image, the dataframe has only two columns: label, which indicates whether the SMS is ham or spam, and message, which contains the SMS text that will be the input to the BERT model.
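The tab-separated layout is why sep='\t' is needed. As a quick sketch, the same read_csv call can be tried on two hypothetical lines held in memory with io.StringIO, without downloading the real file:

```python
import io
import pandas as pd

# Two hypothetical lines in the same layout as SMSSpamCollection:
# label, a tab character, then the message text (no header row)
sample = "ham\tSee you at lunch?\nspam\tWINNER! Claim your free prize now\n"

df_sample = pd.read_csv(io.StringIO(sample), sep="\t", names=["label", "message"])
print(df_sample)
print(df_sample.shape)  # (2, 2)
```

Because the file has no header row, names supplies the column labels; without it, pandas would treat the first message as the header.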

Now, for the sake of clarity, we rename the columns to Category and Message.

df.rename(columns = {'label':'Category', 'message':'Message'}, inplace = True)
[Figure: dataframe with renamed columns]

In the above image, the column names have been changed from label to Category and message to Message.

Now, we add a new column called spam that maps each Category to a numeric value the model can accept, so that each category is uniquely identified by a number.

df['spam']=df['Category'].apply(lambda x: 1 if x=='spam' else 0)

In the above code, ham is mapped to 0, and spam is mapped to 1.
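The same mapping can also be written with an explicit lookup table. This sketch uses a hypothetical four-row dataframe in place of df and shows Series.map as an alternative to the lambda, plus a quick look at the class balance:

```python
import pandas as pd

# Hypothetical mini-dataframe standing in for df after the rename step
toy = pd.DataFrame({"Category": ["ham", "spam", "ham", "ham"],
                    "Message": ["ok see you", "free prize", "call me", "lunch?"]})

# Explicit label -> number lookup, equivalent to the lambda for ham/spam
toy["spam"] = toy["Category"].map({"ham": 0, "spam": 1})

print(toy["spam"].tolist())        # [0, 1, 0, 0]
print(toy["spam"].value_counts())  # how many ham (0) and spam (1) rows
```

One design difference: map returns NaN for any label missing from the dictionary, so an unexpected category surfaces as a missing value instead of silently becoming ham (0) as it would with the lambda.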

Next, we split the dataset into training and test sets, using stratified sampling so that both sets keep the same proportion of ham and spam messages.

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(df['Message'],df['spam'], stratify=df['spam'])

The first four entries of X_train are shown in the above image. X_train contains the SMS messages of the training set, and y_train contains the corresponding labels.
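To see what stratify does, here is a sketch on hypothetical labels (80 ham, 20 spam). With scikit-learn's default test_size of 0.25, both splits keep the 20% spam ratio; random_state is added only to make the run repeatable.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical labels: 80 ham (0) and 20 spam (1)
labels = np.array([0] * 80 + [1] * 20)
texts = np.array([f"message {i}" for i in range(len(labels))])

X_tr, X_te, y_tr, y_te = train_test_split(
    texts, labels, stratify=labels, random_state=42)

# Default test_size is 0.25: 75 training and 25 test samples
print(len(y_tr), len(y_te))      # 75 25
print(y_tr.mean(), y_te.mean())  # 0.2 0.2 (spam fraction in each split)
```

Without stratify, a random split of an imbalanced dataset can leave the test set with noticeably more or fewer spam messages than the training set.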

Once preprocessing is done, the next step is to download the BERT preprocessor and encoder from TensorFlow Hub and build the model. On top of BERT, our model adds a dropout layer and one dense layer with a single output unit; because this layer uses the sigmoid activation, its output is the probability that the SMS is spam. When we trained this model for 2 epochs, it reached an accuracy of 90.07% on the training dataset. Your accuracy may differ slightly because of the randomness of the training process.
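The preprocessor turns each raw string into fixed-length integer inputs for the encoder. To our understanding, TF Hub's BERT preprocessing models return a dictionary with the keys input_word_ids, input_mask and input_type_ids. The function below is a hypothetical pure-Python sketch of only the final packing step, assuming the text has already been tokenized into WordPiece ids; the token ids are made up, and the [CLS], [SEP] and [PAD] ids (101, 102, 0) are those of the bert-base-uncased vocabulary.

```python
# [CLS], [SEP] and [PAD] ids in the bert-base-uncased vocabulary
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0

def pack_inputs(token_ids, seq_length=8):
    """Hypothetical helper: pack one tokenized sentence into BERT-style inputs."""
    body = token_ids[: seq_length - 2]  # truncate so [CLS] and [SEP] still fit
    ids = [CLS_ID] + body + [SEP_ID]
    mask = [1] * len(ids)               # 1 = real token, 0 = padding
    padding = seq_length - len(ids)
    return {
        "input_word_ids": ids + [PAD_ID] * padding,
        "input_mask": mask + [0] * padding,
        "input_type_ids": [0] * seq_length,  # single segment: all zeros
    }

# Made-up token ids for a short three-token message
print(pack_inputs([2293, 2017, 999]))
```

The mask tells the encoder which positions are padding, so they receive no attention; the real preprocessing models pad to a longer default length than the 8 used here.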

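!pip install tensorflow-text

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # registers the ops used by the BERT preprocessor

# Fill in the TF Hub URLs of the BERT preprocessor and encoder
bert_preprocess = hub.KerasLayer("")
bert_encoder = hub.KerasLayer("")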

# Bert layers
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessed_text = bert_preprocess(text_input)
outputs = bert_encoder(preprocessed_text)
# Neural network layers
l = tf.keras.layers.Dropout(0.1, name="dropout")(outputs['pooled_output'])
l = tf.keras.layers.Dense(1, activation='sigmoid', name="output")(l)
# Use inputs and outputs to construct a final model
model = tf.keras.Model(inputs=[text_input], outputs = [l])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train on the raw SMS strings; the preprocessing layer handles tokenization
model.fit(X_train, y_train, epochs=2, batch_size=32)
y_predicted = model.predict(X_test)
y_predicted = y_predicted.flatten()
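The output layer and the loss named in model.compile can be written out directly. Below is a NumPy sketch of what the sigmoid output and binary cross-entropy compute for a small batch. The pooled vectors and weights are random stand-ins (not trained values), the size 768 assumes a BERT-Base encoder, and dropout is omitted because Keras applies it only during training.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    # Clip to avoid log(0), then average -[y log p + (1 - y) log(1 - p)]
    p = np.clip(y_prob, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

rng = np.random.default_rng(0)
pooled = rng.normal(size=(4, 768))        # stand-in for outputs['pooled_output']
w = rng.normal(scale=0.02, size=(768,))   # stand-in for the Dense(1) kernel
b = 0.0                                   # stand-in for its bias

probs = sigmoid(pooled @ w + b)           # one spam probability per SMS
y_true = np.array([0, 1, 0, 0])
print(probs.shape, binary_cross_entropy(y_true, probs))
```

Training adjusts w and b (and, during fine-tuning, the BERT weights) to push this loss down, which moves the probabilities toward 1 for spam and 0 for ham.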

After training the model on the training dataset, we can predict the labels for the test dataset. Note that in the above image the predicted values lie between 0 and 1; this is because the last layer uses the sigmoid function. We can then convert these values into 1 or 0, i.e. spam or ham, using a cutoff value such as 0.5.
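Applying the 0.5 cutoff and evaluating the result takes only a few lines. This sketch uses hypothetical probabilities and labels in place of y_predicted and y_test; with the article's variables, the same np.where call and scikit-learn's confusion_matrix and classification_report would apply in the same way.

```python
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

# Hypothetical sigmoid outputs and true labels (stand-ins for y_predicted and y_test)
y_prob = np.array([0.02, 0.91, 0.40, 0.55, 0.08, 0.97])
y_true = np.array([0, 1, 1, 1, 0, 1])

# Apply the 0.5 cutoff: 1 = spam, 0 = ham
y_pred = np.where(y_prob > 0.5, 1, 0)
print(y_pred)                           # [0 1 0 1 0 1]
print(confusion_matrix(y_true, y_pred)) # rows = true class, columns = predicted class
print(classification_report(y_true, y_pred, target_names=["ham", "spam"]))
```

Here the message with probability 0.40 is spam but falls below the cutoff, so it appears as a false negative; lowering the cutoff would catch more spam at the risk of flagging more ham.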


In this post, we used a pre-trained BERT model to classify text in the simplest possible way. I hope it helps beginners get started with BERT using minimal code.

One thing to note is that BERT has a large number of parameters, so it requires substantial computing resources, and training takes considerable time and cost. To speed up training, lighter pre-trained embeddings such as GloVe can be used instead, usually at some cost in accuracy. Furthermore, BERT is not limited to text or sentence classification; it can also be applied to more advanced Natural Language Processing tasks such as next sentence prediction, question answering, and named-entity recognition.
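To put “a large number of parameters” in perspective, the count can be worked out from the model configuration. Since the exact encoder used above is not recorded here, this sketch assumes the standard BERT-Base uncased configuration and counts only the encoder and pooler, not the pre-training heads:

```python
# Assumed BERT-Base (uncased) configuration
vocab_size, max_positions, type_vocab = 30522, 512, 2
hidden, layers, intermediate = 768, 12, 3072

def dense(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias

def layer_norm(n):
    return 2 * n                 # scale + shift

# Token, position and segment embeddings, followed by a layer norm
embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm(hidden)

per_layer = (
    4 * dense(hidden, hidden)      # query, key, value and attention output
    + layer_norm(hidden)
    + dense(hidden, intermediate)  # feed-forward expansion
    + dense(intermediate, hidden)  # feed-forward projection
    + layer_norm(hidden)
)

pooler = dense(hidden, hidden)     # produces the pooled output
total = embeddings + layers * per_layer + pooler
print(f"{total:,}")                # 109,482,240
```

The embedding matrices alone account for about 23.8 million parameters, and each of the 12 encoder layers adds roughly 7.1 million.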

End Notes

Thanks for reading!

I hope you enjoyed learning about the BERT model, its usage, and implementation in classifying the text as spam or ham.

Code: Notebook

If you liked this and want to know more, go visit my other articles on Data Science and Machine Learning by clicking on the Link.

Feel free to connect with me on LinkedIn or by email.

The media shown in this article is not owned by Analytics Vidhya and are used at the Author’s discretion. 


One thought on "Text Classification using BERT and TensorFlow"

Kabu says: September 12, 2022 at 9:00 pm
I would like to use BERT to classify news articles. My questions are: 1) How do I know that my text input is less than 512 tokens? 2) How can I use BERT to classify the whole article and not just a part of it? Or should I really take an arithmetic mean of the classification results of each part of the whole article?
