Shivani Sharma — September 1, 2021

This article was published as a part of the Data Science Blogathon


Trax is a full-featured deep learning library focused on clear code and fast computation. Its syntax is broadly similar to Keras, and a Trax model can be converted to a Keras model. The library is actively developed and maintained by the Google Brain team. Trax builds on JAX (its default backend) and TensorFlow and is part of the TensorFlow ecosystem; the same code runs on CPU, GPU, and TPU.

A transformer is designed to work with sequences, including text, but unlike recurrent architectures it does not need to process the sequence in order. Simplifying greatly: take a Seq2Seq LSTM model with attention, keep only the attention mechanism, drop the recurrence, and add a feed-forward neural network; what remains is, roughly, a transformer.
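For reference, these are the two ingredients in their standard formulation from the original Transformer paper: scaled dot-product attention over query, key, and value matrices Q, K, V with key dimension d_k, and a position-wise feed-forward network with weights W_1, W_2 and biases b_1, b_2. Trax's layers may differ in details such as normalization and dropout placement.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
\qquad
\mathrm{FFN}(x) = \max(0,\; x W_1 + b_1)\, W_2 + b_2
```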

Here is my experiment in building a summarizer with the Trax library: the model takes an article as input and generates a short text that captures its essence. The summary can be as short as a headline. I'll go through everything in detail, starting with data analysis.


As the dataset for the experiment, I used the Lenta.Ru news corpus, the latest version of which I found on Kaggle. The corpus contains over 800 thousand news articles with the fields (URL, title, text, topic, tags, date). The article text is the model's input, and the title serves as its target summary: a complete sentence that states the main message of the news item.

First, I filtered out abnormally short and abnormally long articles. Then I extracted the texts and headlines, converted everything to lowercase, and saved the result both as a list of (text, title) tuples and as one full text. I split the list of tuples into two parts, one for training (train) and one for evaluation (eval). Then I wrote an “infinite” generator that, on reaching the end of the list, shuffles it and starts over: it is unpleasant when a generator “runs out” somewhere in the middle of an epoch. This matters mostly for the evaluation set, which is small: I took only 5% of the articles for it, about 36 thousand pairs.
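The “infinite” generator described above can be sketched in plain Python as follows. The name `infinite_stream` and the `seed` parameter are my own, hypothetical choices; the loop simply follows the description: walk through the list, reshuffle at the end, and start over.

```python
import random
from itertools import islice


def infinite_stream(pairs, seed=None):
    """Yield (text, title) pairs forever, reshuffling after every full pass.

    A private copy is shuffled, so the caller's list is left untouched.
    """
    rng = random.Random(seed)
    data = list(pairs)
    if not data:
        return  # nothing to stream; avoid an endless empty loop
    while True:
        for pair in data:
            yield pair
        rng.shuffle(data)  # end of the list reached: shuffle and start over


# Usage: take as many pairs as needed; the stream never ends mid-epoch.
pairs = [("text a", "title a"), ("text b", "title b"), ("text c", "title c")]
stream = infinite_stream(pairs, seed=0)
first_pass = list(islice(stream, 3))   # original order
second_pass = list(islice(stream, 3))  # a shuffled permutation of the same pairs
```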

I trained the tokenizer on the full text and used subwords (parts of words) as tokens. The problem with tokenizing into whole words is that many words in the text are rare, some appearing only once, and there are a lot of them, while the vocabulary is finite and should stay small enough to fit into the virtual machine's memory. With a word-level vocabulary you have to replace some words with named templates, frequently fall back on a placeholder for out-of-vocabulary words, and even use special techniques such as the pointer-generator network. Splitting into subwords, by contrast, gives a tokenizer with a small vocabulary that loses practically no information.

There are several established methods for this kind of segmentation. I chose the Byte Pair Encoding (BPE) model implemented in the SentencePiece library. BPE originated as a text compression method: a frequently repeated pair of symbols is replaced with a new symbol that does not occur in the original data. Segmentation works the same way: the most frequent pair of adjacent symbols becomes a new token, and this repeats until the vocabulary reaches the specified size. My vocabulary contains 16,000 tokens.
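To make the merge procedure concrete, here is a toy BPE learner in plain Python. It is a sketch of the idea only, not SentencePiece's implementation, which works on raw text with its own whitespace handling and is far faster. `learn_bpe` is a hypothetical name.

```python
from collections import Counter


def learn_bpe(words, n_merges):
    """Learn up to n_merges BPE merges from a list of words (toy version)."""
    # Start with every word split into single characters, counted by frequency.
    corpus = Counter(tuple(word) for word in words)
    merges = []
    for _ in range(n_merges):
        # Count adjacent symbol pairs across the corpus.
        pairs = Counter()
        for symbols, freq in corpus.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break  # every word is already a single token
        best = max(pairs, key=pairs.get)  # most frequent pair (first one on ties)
        merges.append(best)
        # Replace every occurrence of the best pair with one merged symbol.
        new_corpus = Counter()
        for symbols, freq in corpus.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_corpus[tuple(merged)] += freq
        corpus = new_corpus
    return merges, corpus


merges, corpus = learn_bpe(["low", "low", "lower", "lowest"], n_merges=2)
# merges == [('l', 'o'), ('lo', 'w')]; "low" is now a single token
```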

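The tokenizer model was trained with this short snippet:

import sentencepiece as spm
# 'corpus.txt' is a placeholder for the full-text file; the original path is not shown
spm.SentencePieceTrainer.train('--input=corpus.txt '
                               '--pad_id=0 --bos_id=-1 --eos_id=1 --unk_id=2 '
                               '--model_prefix=bpe --vocab_size=16000 --model_type=bpe')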

The result is two files: a vocabulary (bpe.vocab) for inspection and a model (bpe.model) that can be loaded into a tokenizer wrapper. For the model I chose, the article and the title must be converted to sequences of integers and concatenated with the service tokens EOS = 1 (end of sequence) and PAD = 0 (padding placeholder).
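A minimal sketch of such a concatenation step, assuming the layout article + [EOS, PAD] + summary + [EOS] (the text names the two service tokens but not their exact positions). The names `concatenate` and `concat_stream` are hypothetical, and the loss mask is an optional extra that lets a loss ignore the article part.

```python
EOS, PAD = 1, 0  # ids set when training the tokenizer (--eos_id=1, --pad_id=0)


def concatenate(article_ids, summary_ids):
    """Join article and summary into one language-model sequence.

    Assumed layout: article + [EOS, PAD] + summary + [EOS]. The mask marks the
    summary part (including its final EOS) so a loss can skip the article.
    """
    joint = list(article_ids) + [EOS, PAD] + list(summary_ids) + [EOS]
    mask = [0] * (len(article_ids) + 2) + [1] * (len(summary_ids) + 1)
    return joint, mask


def concat_stream(stream):
    """Generator over (article_ids, summary_ids) pairs, e.g. after tokenization."""
    for article_ids, summary_ids in stream:
        joint, mask = concatenate(article_ids, summary_ids)
        # For a language model, input and target are the same sequence.
        yield joint, joint, mask


joint, mask = concatenate([5, 6, 7], [8, 9])
# joint == [5, 6, 7, 1, 0, 8, 9, 1]; mask == [0, 0, 0, 0, 0, 1, 1, 1]
```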

After conversion, each sequence is placed into a fixed-length bucket. I have three of them: 256, 512, and 1024 tokens. Sequences in a bucket are automatically padded with the placeholder to the bucket length and collected into batches. The number of sequences per batch depends on the bucket: 16, 8, and 4 respectively.
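The bucketing idea can be sketched in plain Python. This is a simplified illustration, not `trax.data.BucketByLength` itself; in particular, padding every batch to its bucket boundary (or to the longest sequence in the open-ended last bucket) is my simplification of Trax's padding rules. `bucket_by_length` is a hypothetical name.

```python
import math


def bucket_by_length(stream, boundaries, batch_sizes, pad_id=0):
    """Group sequences into length buckets and yield padded batches.

    A sequence goes to the first bucket whose boundary is >= its length; the
    last bucket takes everything longer. Partially filled buckets left when a
    finite stream ends are dropped (the real input stream is infinite).
    """
    assert len(batch_sizes) == len(boundaries) + 1
    limits = list(boundaries) + [math.inf]
    buckets = [[] for _ in limits]
    for seq in stream:
        idx = next(i for i, b in enumerate(limits) if len(seq) <= b)
        buckets[idx].append(seq)
        if len(buckets[idx]) == batch_sizes[idx]:
            batch = buckets[idx]
            width = limits[idx] if limits[idx] != math.inf else max(map(len, batch))
            yield [s + [pad_id] * (width - len(s)) for s in batch]
            buckets[idx] = []


batches = list(bucket_by_length([[1, 2], [3, 4, 5, 6, 7], [8], [9] * 10],
                                boundaries=[4, 8], batch_sizes=[2, 2, 1]))
# [[[1, 2, 0, 0], [8, 0, 0, 0]], [[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]]]
```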

(Figure: reflection on sequences longer than 512 tokens)

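Segmentation and concatenation are done in a Trax data pipeline:
import trax
input_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_type='sentencepiece', vocab_dir='.', vocab_file='bpe.model'),
    preprocessing)  # vocab_dir='.' assumes bpe.model is in the working directory
train_stream = input_pipeline(train_data_stream())
eval_stream = input_pipeline(eval_data_stream())
Here preprocessing is my own concatenation function, written as a generator. Sorting into buckets and forming batches is done with the following snippet: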
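boundaries = [256, 512]
batch_sizes = [16, 8, 4]  # one more entry than boundaries: the last bucket holds sequences over 512
train_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(train_stream)
eval_batch_stream = trax.data.BucketByLength(
    boundaries, batch_sizes)(eval_stream)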


A transformer that works with two sequences, for example in machine translation, includes two blocks, an encoder and a decoder, but a decoder alone is sufficient for summarization. Such an architecture implements a language model, in which the probability of the next word is determined from the previous ones. It is also called a decoder-only transformer and is similar to GPT (Generative Pre-trained Transformer).
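Written out, the language-model view factorizes the probability of a token sequence x_1, …, x_T (for the summarizer, the concatenated article and headline) into next-token probabilities, and training minimizes the average negative log-likelihood, which is the cross-entropy loss used later:

```latex
p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1})
\qquad
\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log p(x_t \mid x_1, \dots, x_{t-1})
```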

For this task, Trax has a separate model class, trax.models.transformer.TransformerLM(…), so you can create the model with one line of code. In the mentioned specialization, the model is built from scratch. I chose something in between: I built the model from ready-made blocks, following code examples.

The diagram of the model is shown in the figure:

(Figure: model diagram)

PositionalEncoder() is a block that maps tokens into a d_model-dimensional vector space (an embedding) and encodes each token's position in the input sequence. Code:

from trax import layers as tl

def PositionalEncoder(vocab_size, d_model, dropout, max_len, mode):
    return [
        tl.Embedding(vocab_size, d_model),  # token ids -> d_model-dimensional vectors
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode)]  # add position information


  1. vocab_size (int): vocabulary size
  2. d_model (int): number of features in the embedding space
  3. dropout (float): dropout rate
  4. max_len (int): maximum sequence length for the positional encoding
  5. mode (str): ‘train’ or ‘eval’; controls the behavior of dropout and positional encoding
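As an illustration of what a positional encoding adds, here is the classic sinusoidal scheme from the original Transformer paper in plain Python. I am assuming `tl.PositionalEncoding` follows the same idea; check its source for details. The function names are hypothetical.

```python
import math


def sinusoidal_positional_encoding(max_len, d_model):
    """Sinusoidal encoding: even dimensions use sin, odd dimensions use cos.

    pe[pos][2i]   = sin(pos / 10000 ** (2i / d_model))
    pe[pos][2i+1] = cos(pos / 10000 ** (2i / d_model))
    """
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):  # i is the even index 2k
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


def add_positions(embeddings):
    """Add the encoding to token embeddings (one d_model vector per token)."""
    pe = sinusoidal_positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(row, pe_row)] for row, pe_row in zip(embeddings, pe)]


pe = sinusoidal_positional_encoding(4, 6)
# pe[0] == [0.0, 1.0, 0.0, 1.0, 0.0, 1.0], since sin(0) = 0 and cos(0) = 1
```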
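FeedForward() builds a feed-forward block with the chosen activation function, following the standard pre-norm transformer layout:
def FeedForward(d_model, d_ff, dropout, mode, ff_activation):
    return [
        tl.LayerNorm(),
        tl.Dense(d_ff),      # expand to d_ff units
        ff_activation(),     # e.g. tl.Relu
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_model),   # project back to d_model
        tl.Dropout(rate=dropout, mode=mode)]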


  1. d_model (int): the number of embedding features
  2. d_ff (int): the “width” of the block, i.e. the number of units in the inner dense layer
  3. dropout (float): the dropout rate
  4. mode (str): ‘train’ or ‘eval’, so that dropout is not applied when evaluating the model
  5. ff_activation (layer class): activation function; in my model, ReLU (tl.Relu)
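Stripped of layer normalization and dropout, the block computes Dense(d_ff), ReLU, Dense(d_model) at every position independently. A plain-Python sketch of that forward pass for a single position vector (the helper names are hypothetical):

```python
def dense(x, weights, bias):
    """y = x @ W + b for one vector x; weights has len(x) rows of len(bias) columns."""
    return [sum(xi * w_row[j] for xi, w_row in zip(x, weights)) + bias[j]
            for j in range(len(bias))]


def relu(x):
    return [max(0.0, v) for v in x]


def feed_forward(x, w1, b1, w2, b2):
    """Position-wise FFN: expand d_model -> d_ff, apply ReLU, project back to d_model.

    LayerNorm and dropout from the Trax block are left out of this sketch.
    """
    return dense(relu(dense(x, w1, b1)), w2, b2)


# d_model = 2, d_ff = 3
out = feed_forward([1.0, -1.0],
                   w1=[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]], b1=[0.0, 0.0, 0.0],
                   w2=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], b2=[0.5, 0.5])
# out == [1.5, 2.5]
```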
DecoderBlock(…) consists of two residual blocks. A residual (skip) connection adds a block's input to its output, which is a workaround for vanishing gradients in deep architectures.
Going from input to output, the first block contains the attention mechanism, for which I used a ready-made layer from the library. The second is the feed-forward block described above. The attention mechanism here is unusual: it “looks” at the same sequence for which the next token is being generated, and a special mask is applied when computing the weights so that it does not “look into the future”.
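def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    return [
        tl.Residual(  # block 1: causal self-attention with pre-normalization
            tl.LayerNorm(),
            tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode),
            tl.Dropout(rate=dropout, mode=mode)),
        tl.Residual(  # block 2: the feed-forward block defined above
            FeedForward(d_model, d_ff, dropout, mode, ff_activation))]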
The only new argument is n_heads (int), the number of attention heads. Each head learns to pay attention to something different.
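To show how the causal mask and the heads fit together, here is a plain-Python sketch. It is a simplification, not `tl.CausalAttention`: there are no learned query, key, value, or output projections (each head uses its slice of the input directly as q, k, and v). Skipping positions j > i is equivalent to setting their scores to minus infinity before the softmax. All names are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_head(q, k, v):
    """One attention head over lists of vectors (one vector per position)."""
    d = len(q[0])
    out = []
    for i in range(len(q)):
        # Causal mask: position i only scores positions 0..i.
        scores = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d) for j in range(i + 1)]
        w = softmax(scores)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) for c in range(len(v[0]))])
    return out


def multi_head_causal(x, n_heads):
    """Split each d_model vector into n_heads slices, attend per head, concatenate."""
    d_model = len(x[0])
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_head = d_model // n_heads
    heads = []
    for h in range(n_heads):
        xs = [row[h * d_head:(h + 1) * d_head] for row in x]
        heads.append(causal_head(xs, xs, xs))
    return [sum((head[i] for head in heads), []) for i in range(len(x))]


x = [[1.0, 0.0, 0.5, 0.2], [0.0, 1.0, 0.3, 0.1], [0.5, 0.5, 0.9, 0.4]]
out = multi_head_causal(x, n_heads=2)
# The first position can only attend to itself, so out[0] == x[0].
```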

Now I put all the parts together and set the model parameters: six decoder blocks, each with eight attention heads. The model has 37,412,480 parameters in total.
The only layer not covered yet is ShiftRight. It shifts the input sequence to the right, by one position by default, and fills the vacated positions with zeros (the PAD id). This is needed for teacher forcing, a technique that makes training a language model easier, especially in the early stages. The idea: when the model learns to predict the next word from the previous ones, the ground-truth previous words are fed in instead of the model's own predictions, which may be wrong. In short, the target at step t is the input token at step t + 1:
y(t) = x(t + 1).
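A plain-Python sketch of what the shift achieves: after shifting, the model sees the ground-truth token t - 1 at position t and must predict token t. The function names are hypothetical; the Trax model uses `tl.ShiftRight` for this.

```python
def shift_right(tokens, n_positions=1, pad_id=0):
    """Shift a sequence right, filling the vacated slots with pad_id; length is kept."""
    return ([pad_id] * n_positions + list(tokens))[:len(tokens)]


def teacher_forcing_pairs(tokens):
    """Pair each model input (ground truth, shifted) with the token it must predict."""
    return list(zip(shift_right(tokens), tokens))


pairs = teacher_forcing_pairs([5, 6, 7])
# [(0, 5), (5, 6), (6, 7)]: at each step the input is the previous true token
```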
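def SumTransformer(vocab_size=16000, d_model=512, d_ff=2048, n_layers=6,
                   n_heads=8, dropout=0.1, max_len=4096, mode='train',
                   ff_activation=tl.Relu):  # d_model, d_ff, dropout, max_len: assumed values
    decoder_blocks = [DecoderBlock(d_model, d_ff, n_heads, dropout, mode,
                                   ff_activation) for _ in range(n_layers)]
    return tl.Serial(
        tl.ShiftRight(mode=mode),  # teacher forcing: shift inputs right by one
        PositionalEncoder(vocab_size, d_model, dropout, max_len, mode),
        decoder_blocks,
        tl.LayerNorm(), tl.Dense(vocab_size), tl.LogSoftmax())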
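The text reports 37,412,480 parameters but does not list d_model, d_ff, or max_len. As a rough cross-check, the sketch below counts the weights of this layout in plain Python. The values d_model=512, d_ff=2048, and max_len=4096 are my guesses, not from the text; with them, and assuming Trax keeps the positional-encoding table (max_len x d_model) among the model weights, the count comes out at exactly 37,412,480. `count_params` is a hypothetical helper.

```python
def count_params(vocab_size, d_model, d_ff, n_layers, max_len):
    """Rough parameter count for the decoder-only layout, under these assumptions:

    - Embedding: vocab_size x d_model, no bias
    - positional-encoding table: max_len x d_model, counted as weights
    - per decoder block: 2 LayerNorms (scale + bias), 4 Dense(d_model) layers in
      attention (query, key, value, output) and Dense(d_ff) + Dense(d_model) in
      the feed-forward block, all with biases
    - final LayerNorm and Dense(vocab_size) with bias
    """
    layer_norm = 2 * d_model
    attention = 4 * (d_model * d_model + d_model)
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    block = 2 * layer_norm + attention + ffn
    embedding = vocab_size * d_model
    positional = max_len * d_model
    head = layer_norm + d_model * vocab_size + vocab_size
    return embedding + positional + n_layers * block + head


total = count_params(vocab_size=16000, d_model=512, d_ff=2048, n_layers=6, max_len=4096)
# total == 37_412_480
```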


In my experience, Google Colab does not like long-running use of its GPUs and does not always allocate one, especially in the afternoon. Therefore, I trained the model in separate epochs of 20,000 steps, where a step corresponds to one batch. I managed 1-2 epochs a day: 100 steps take about a minute, so an epoch takes about three hours.
The first epoch showed that the model improved only during the first few thousand steps, with no further progress after that. It turned out that I had chosen too large a learning_rate. For my model, it should be 0.0002 for the first few epochs, then 0.0001, and 0.00005 at the end. If I were training the model in a single run, I could use the schedules in trax.supervised.lr_schedules, which offers various convenient options with warm-up and gradual decay.
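The manual schedule above can be written as a piecewise-constant function of the step number. The learning-rate values are from the text, but the epoch boundaries (switching after epoch 4 and after epoch 10) are hypothetical, since the text only says “the first few epochs”. A callable of this shape could, I believe, be passed as `lr_schedule` to `training.TrainTask`, but I have not checked that against the author's setup. `manual_lr` is a hypothetical name.

```python
STEPS_PER_EPOCH = 20_000  # from the text: one epoch = 20,000 steps


def manual_lr(step, boundaries=(4, 10), values=(2e-4, 1e-4, 5e-5)):
    """Piecewise-constant learning rate by epoch.

    Epochs [0, 4) use 2e-4, epochs [4, 10) use 1e-4, later epochs use 5e-5.
    The boundaries are assumptions; only the three values come from the text.
    """
    epoch = step // STEPS_PER_EPOCH  # 0-based epoch index
    for boundary, value in zip(boundaries, values):
        if epoch < boundary:
            return value
    return values[-1]
```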

I used CrossEntropyLoss and Accuracy as metrics. Over 12 epochs, the loss on the evaluation set fell from 10 to 2, and accuracy rose to almost 60%. This turned out to be enough to generate almost acceptable headlines.
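A plain-Python sketch of what the two metrics measure, with weights that let padding positions be ignored (this mirrors the idea of a weighted loss, not Trax's exact code; `masked_metrics` is a hypothetical name). It also shows why a starting loss of about 10 is plausible: a model that predicts uniformly over 16,000 tokens has a cross-entropy of ln(16000) ≈ 9.68 nats.

```python
import math


def masked_metrics(log_probs, targets, weights):
    """Weighted cross-entropy and accuracy over one sequence.

    log_probs[t] is a list of log-probabilities over the vocabulary at position t;
    weights[t] is 0 for positions to ignore (e.g. padding) and 1 otherwise.
    """
    total = sum(weights)
    if total == 0:
        raise ValueError("at least one position must have a non-zero weight")
    loss = -sum(w * lp[y] for lp, y, w in zip(log_probs, targets, weights)) / total
    correct = sum(w for lp, y, w in zip(log_probs, targets, weights)
                  if max(range(len(lp)), key=lp.__getitem__) == y)  # argmax == target
    return loss, correct / total


# Cross-entropy of a uniform guess over a 16,000-token vocabulary, in nats.
UNIFORM_LOSS = math.log(16000)  # about 9.68
```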
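The training loop looks like this:
import os
import trax
from trax import layers as tl
from trax.supervised import training
def training_loop(SumTransformer, train_gen, eval_gen, output_dir="~/model"):
    output_dir = os.path.expanduser(output_dir)
    train_task = training.TrainTask(
        labeled_data=train_gen,
        loss_layer=tl.CrossEntropyLoss(),
        optimizer=trax.optimizers.Adam(0.0002))  # Adam is assumed; lr from the text
    eval_task = training.EvalTask(
        labeled_data=eval_gen,
        metrics=[tl.CrossEntropyLoss(), tl.Accuracy()])
    loop = training.Loop(SumTransformer(), train_task,
                         eval_tasks=[eval_task], output_dir=output_dir)
    return loop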


  1. SumTransformer (function returning a trax.layers.combinators.Serial): the model
  2. train_gen (generator): data stream for training
  3. eval_gen (generator): data stream for evaluation
  4. output_dir (str): folder for the model checkpoint, from which it can be copied to Google Drive before the virtual machine shuts down

Then everything is simple:

loop = training_loop(SumTransformer, train_batch_stream, eval_batch_stream)
loop.run(20000)  # one epoch of 20,000 steps
and three hours of waiting…

Evaluation of results

To evaluate the results, I used a greedy argmax decoder: it picks the most probable token, whose index in the vocabulary is the position of the maximum value in the output tensor. The token is then appended to the input sequence, and the step is repeated until the EOS token appears or the specified maximum length is reached.
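The decoding loop described above, sketched in plain Python. The model is replaced by a stand-in function `next_token_scores(ids)` that returns one score per vocabulary entry; with the real model, `prompt_ids` would be the tokenized article followed by the separator tokens used during training. All names are hypothetical.

```python
EOS = 1  # end-of-sequence id, as in the tokenizer settings


def greedy_decode(next_token_scores, prompt_ids, max_len=32, eos_id=EOS):
    """Greedy (argmax) decoding: repeatedly append the highest-scoring token.

    Stops when EOS is produced (EOS itself is not returned) or after max_len tokens.
    """
    ids = list(prompt_ids)
    generated = []
    for _ in range(max_len):
        scores = next_token_scores(ids)
        token = max(range(len(scores)), key=scores.__getitem__)  # position of the max
        if token == eos_id:
            break
        ids.append(token)
        generated.append(token)
    return generated


# Toy stand-in model over a 5-token vocabulary: always predicts "last id + 1",
# wrapping from 4 to the EOS id 1.
def toy_model(ids):
    nxt = ids[-1] + 1 if ids[-1] < 4 else EOS
    return [1.0 if i == nxt else 0.0 for i in range(5)]


result = greedy_decode(toy_model, [2])
# result == [3, 4]: after 4 the toy model predicts EOS and decoding stops
```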

Examples from the evaluation set:

Test #1: The Swiss watch company Audemars Piguet has presented a new model from the Royal Oak collection. As reported by Luxury Launches, this is a perpetual calendar watch. The official presentation will take place at the SIHH international fine watchmaking salon in Geneva…

Sample: Audemars Piguet has equipped a watch with a perpetual calendar

Model: Audemars Piguet has presented a new model from the royal oak collection

Test #2: At the annual festival in Grahamstown, South Africa, a magician accidentally shot his partner in the head during a performance, the local newspaper The Daily Dispatch reports. The incident took place on June 30th. Brendon Peel and his assistant Li Lau were performing a magic trick in front of a large audience when Peel inadvertently shot an arrow into his partner's back…

Sample: magician accidentally shot an assistant in front of the audience

Model: at the festival in Pulkovo attacked with a knife
(It was not in Pulkovo, there was no attack, and there was no knife; but credit where it is due, at least the model chose a melee weapon and not a pistol.)


The Trax library is easy to use for simple deep learning projects, and a headline generator like this one is a beginner-friendly way to try it: once trained, the model produces a short, headline-style summary of a news article.

The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.
