Shivani Sharma — September 10, 2021

This article was published as a part of the Data Science Blogathon


With ignite, you can write a network’s training loop in just a few lines, get standard metrics out of the box, save the model, and so on. For those who have moved from TensorFlow to PyTorch, ignite can be described as Keras for PyTorch.
This article walks through a detailed example of training a neural network for a classification problem using ignite.



Add even more fire to PyTorch With Ignite

I will not spend time talking about how cool the PyTorch framework is. Anyone who has already used it understands what I mean. But, for all its advantages, it is still low-level when it comes to writing loops for training, validating, and testing neural networks.

If we look at the official PyTorch examples, we will see that the network training code contains at least two nested loops, one over epochs and one over batches of the training set:

for epoch in range(1, epochs + 1):
    for batch_idx, (data, target) in enumerate(train_loader):
        # …

The main idea of the ignite library is to factor these loops into a single class, while still allowing the user to interact with these loops using event handlers.

As a result, for standard deep learning tasks, we can save a lot of lines of code. Fewer lines, fewer errors!
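The idea can be illustrated with a toy sketch in plain Python, with no PyTorch or ignite involved: the epoch and batch loops are written once inside a reusable run function, and each task only supplies a function that handles a single batch. The names run and process_batch are hypothetical; ignite’s real Engine does much more (events, state, metrics).

```python
def run(process_batch, data, max_epochs):
    """Toy stand-in for an engine: it owns the epoch and batch loops."""
    outputs = []
    for epoch in range(1, max_epochs + 1):
        for batch in data:
            # The user-supplied function only ever sees one batch at a time
            outputs.append(process_batch(epoch, batch))
    return outputs


# Per-task logic (hypothetical): here a batch is "processed" by summing it
def process_batch(epoch, batch):
    return epoch, sum(batch)


data = [[1, 2], [3, 4], [5]]
results = run(process_batch, data, max_epochs=2)
print(results)  # [(1, 3), (1, 7), (1, 5), (2, 3), (2, 7), (2, 5)]
```

Swapping in a different process_batch changes what happens per batch without touching the loops, which is the part ignite takes off your hands.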

For example, compare the code for training and validating a model using ignite (left) with the same code written in pure PyTorch (right):

[Figure: model training and validation code with ignite (left) and with pure PyTorch (right)]

So, again, what is ignite good for?

  • you no longer need to write the for epoch in range(n_epochs) and for batch in data_loader loops for every task

  • allows you to structure your code better

  • allows you to calculate basic metrics out of the box

  • provides “goodies” like

    • keeping the latest and best models (also optimizer and learning rate scheduler) during training,

    • early stopping,

    • etc.

  • easily integrates with visualization tools such as tensorboardX and visdom.

In a sense, as already mentioned, the ignite library can be compared to the well-known Keras and its API for training and testing networks. Also, at first glance, the ignite library is very similar to the TNT library, since initially both libraries pursued the same goals and had similar ideas for their implementation.

So, let’s light it up:

pip install pytorch-ignite
conda install ignite -c pytorch

Further, with a specific example, we will familiarize ourselves with the ignite library API.

Classification problem with Ignite

In this part of the article, we will walk through a textbook example of training a neural network for a classification problem using the ignite library.

So, let’s take a simple Kaggle dataset of fruit images. The task is to assign the corresponding class to each fruit picture.

Before using ignite, let’s define the main components:

Data:

  • training set batch loader, train_loader

  • test (validation) set batch loader, val_loader

Model:

  • a small SqueezeNet network from torchvision

Optimization algorithm:

  • SGD

Loss function:

  • cross-entropy

So now it’s time to bring in ignite:

from ignite.engine import Engine, _prepare_batch

def process_function(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = _prepare_batch(batch, device=device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

trainer = Engine(process_function)

Let’s see what this code means.

Engine in Ignite

The ignite.engine.Engine class is the backbone of the library, and trainer is an object of this class:

trainer = Engine(process_function)

It is constructed with an input function, process_function, that processes a single batch, and it runs the passes over the training set. Inside ignite.engine.Engine, roughly the following happens:

while epoch < max_epochs:
    epoch += 1
    # run once on data
    for batch in data:
        output = process_function(self, batch)  # self is the Engine object

Let’s go back to process_function:

def process_function(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = _prepare_batch(batch, device=device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()
    return loss.item()

We see that inside the function, as usual when training a model, we compute the predictions y_pred, the loss, and its gradients (loss.backward()). The gradients are then used to update the model weights with optimizer.step().

In general, there are no restrictions on the body of process_function. We only note that it takes two arguments: an Engine object (in our case, trainer) and a batch from the data loader. Therefore, to test a neural network, for example, we can define another ignite.engine.Engine object whose input function simply computes predictions, and run a single pass over the test set. Read on for more details.

So, the code above only defines the necessary objects without starting training. In a minimal example, we can simply call the run method:

trainer.run(train_loader, max_epochs=10)

and this code is enough to train the model “quietly” (without printing any intermediate results).

Of course, in practice, the example above is of little interest, so let’s add the following options to the “trainer”:

  • displaying the value of the loss function every 50 iterations

  • starting the calculation of metrics on a training set with a fixed model

  • starting the calculation of metrics on the test sample after each epoch

  • saving model parameters after each epoch

  • keeping the top three models

  • learning rate scheduling

  • early stopping

Events and event handlers in Python

To add the options above to the “trainer”, the ignite library provides an event system that runs custom event handlers. This lets the user act on the Engine object at each stage:

  • the engine has started/finished running

  • an epoch has started/ended

  • a batch iteration has started/ended

and run their own code on each event.
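To see when each event fires relative to the others, here is a toy, plain-Python event loop. The event names mirror members of ignite’s Events enum, but ToyEngine is a hypothetical illustration, not ignite’s implementation: it only shows the order in which handlers would run for one epoch over two batches.

```python
from collections import defaultdict


class ToyEngine:
    """A minimal, hypothetical event-driven loop (not ignite's implementation)."""

    EVENTS = ("STARTED", "EPOCH_STARTED", "ITERATION_STARTED",
              "ITERATION_COMPLETED", "EPOCH_COMPLETED", "COMPLETED")

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)
        self.epoch = 0
        self.iteration = 0
        self.output = None

    def on(self, event):
        # Decorator form: register the decorated function for `event`
        def decorator(handler):
            self.handlers[event].append(handler)
            return handler
        return decorator

    def fire(self, event):
        for handler in self.handlers[event]:
            handler(self)

    def run(self, data, max_epochs):
        self.fire("STARTED")
        while self.epoch < max_epochs:
            self.epoch += 1
            self.fire("EPOCH_STARTED")
            for batch in data:
                self.iteration += 1
                self.fire("ITERATION_STARTED")
                self.output = self.process_function(self, batch)
                self.fire("ITERATION_COMPLETED")
            self.fire("EPOCH_COMPLETED")
        self.fire("COMPLETED")
        return self


log = []
engine = ToyEngine(lambda engine, batch: batch * 10)

for name in ToyEngine.EVENTS:
    # Bind `name` now so each handler records its own event
    engine.on(name)(lambda eng, name=name: log.append((name, eng.epoch, eng.iteration)))

engine.run(data=[1, 2], max_epochs=1)
print([event for event, _, _ in log])
```

In this sketch the epoch and iteration counters are incremented just before the corresponding *_STARTED event fires, so every handler can read up-to-date values.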

Displaying the value of the loss function

To do this, you just need to define the function in which the display will take place, and add it to the “trainer”:

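from ignite.engine import Events

log_interval = 50

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    # convert the global iteration counter into an iteration index within the current epoch
    it = (engine.state.iteration - 1) % len(train_loader) + 1
    if it % log_interval == 0:
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.4f}"
              .format(engine.state.epoch, it, len(train_loader), engine.state.output))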
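engine.state.iteration keeps counting across epochs, so the handler converts it into a 1-based position within the current epoch. A small plain-Python check of that arithmetic, using 322 batches per epoch as in this article’s training log (the helper name iteration_in_epoch is hypothetical):

```python
def iteration_in_epoch(global_iteration, batches_per_epoch):
    """Map a global, 1-based iteration counter to a 1-based index within its epoch."""
    return (global_iteration - 1) % batches_per_epoch + 1


batches_per_epoch = 322
log_interval = 50

for global_it in (1, 50, 322, 323, 372):
    it = iteration_in_epoch(global_it, batches_per_epoch)
    print(global_it, "->", it, "(logged)" if it % log_interval == 0 else "")
```

Iteration 323 is the first batch of epoch 2, so the count restarts at 1, and within each epoch the loss is printed at iterations 50, 100, ..., 300.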

There are actually two ways to add an event handler: through add_event_handler or through the on decorator. The following is equivalent to the code above:

from ignite.engine import Events

log_interval = 50

def log_training_loss(engine):
    ...  # same body as in the previous example

trainer.add_event_handler(Events.ITERATION_COMPLETED, log_training_loss)

Note that additional arguments can also be passed to the event handler. In general, such a function looks like this:

def custom_handler(engine, *args, **kwargs):
    ...

trainer.add_event_handler(Events.ITERATION_COMPLETED, custom_handler, *args, **kwargs)

# or

@trainer.on(Events.ITERATION_COMPLETED, *args, **kwargs)
def custom_handler(engine, *args, **kwargs):
    ...
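ignite stores the extra positional and keyword arguments given to add_event_handler (or on) and passes them to the handler, right after the engine, each time the event fires. The plain-Python sketch below mimics that behaviour with a hypothetical HandlerRegistry class; it is an illustration, not ignite’s code, and the "fake-engine" string stands in for a real Engine.

```python
from collections import defaultdict


class HandlerRegistry:
    """Hypothetical sketch: extra arguments given at registration are passed on every call."""

    def __init__(self):
        self._handlers = defaultdict(list)

    def add_event_handler(self, event, handler, *args, **kwargs):
        self._handlers[event].append((handler, args, kwargs))

    def fire(self, event, engine):
        for handler, args, kwargs in self._handlers[event]:
            # The engine always comes first, followed by the user's extra arguments
            handler(engine, *args, **kwargs)


calls = []


def custom_handler(engine, tag, every=1):
    calls.append((engine, tag, every))


registry = HandlerRegistry()
registry.add_event_handler("ITERATION_COMPLETED", custom_handler, "train", every=50)
registry.fire("ITERATION_COMPLETED", engine="fake-engine")
print(calls)  # [('fake-engine', 'train', 50)]
```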
So, let’s start training for one epoch and see what happens:

output = trainer.run(train_loader, max_epochs=1)
Epoch[1] Iteration[50/322] Loss: 4.3459
Epoch[1] Iteration[100/322] Loss: 4.2801
Epoch[1] Iteration[150/322] Loss: 4.2294
Epoch[1] Iteration[200/322] Loss: 4.1467
Epoch[1] Iteration[250/322] Loss: 3.8607
Epoch[1] Iteration[300/322] Loss: 3.6688

Not bad! Let’s go further.

Starting the calculation of metrics on training and test samples

Let’s compute the following metrics after each epoch, on a part of the training set and on the entire test set: average loss, accuracy, precision, and recall. Note that we compute metrics on a portion of the training set after each training epoch, not during training. This makes the measurement more accurate, since the model does not change while the metrics are being computed.

So, let’s define the metrics:

from ignite.metrics import Loss, CategoricalAccuracy, Precision, Recall

metrics = {
    'avg_loss': Loss(criterion),
    'avg_accuracy': CategoricalAccuracy(),
    'avg_precision': Precision(average=True),
    'avg_recall': Recall(average=True)
}
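With average=True, Precision and Recall report the unweighted mean of the per-class scores (a macro average). As a plain-Python sketch of what these numbers mean, independent of ignite, here is a hypothetical helper applied to a tiny three-class example; treating a class that is never predicted as having precision 0 is a choice of this sketch.

```python
def classification_metrics(y_true, y_pred, num_classes):
    """Accuracy plus macro-averaged precision and recall (mean of per-class scores)."""
    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    precisions, recalls = [], []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        predicted = sum(p == c for p in y_pred)  # true positives + false positives
        actual = sum(t == c for t in y_true)     # true positives + false negatives
        precisions.append(tp / predicted if predicted else 0.0)
        recalls.append(tp / actual if actual else 0.0)
    return accuracy, sum(precisions) / num_classes, sum(recalls) / num_classes


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
acc, prec, rec = classification_metrics(y_true, y_pred, num_classes=3)
print(round(acc, 4), round(prec, 4), round(rec, 4))  # 0.6667 0.7222 0.6667
```

Per class, precision is 1/2, 2/3, and 1/1, and recall is 1/2, 2/2, and 1/2; the macro averages are the plain means of these three values.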

Next, we will create two engines for evaluating the model using ignite.engine.create_supervised_evaluator:

from ignite.engine import create_supervised_evaluator

# Recall that device = "cuda" was defined above
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

We create two engines so that additional event handlers can be attached to one of them (val_evaluator) to save the model and stop training early (more on this later).

Let’s also take a closer look at how the evaluation engine is defined, namely, how its input function (the counterpart of process_function) processes a single batch:

import torch
from ignite.engine import Engine, _prepare_batch

def create_supervised_evaluator(model, metrics={}, device=None):
    if device:
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = _prepare_batch(batch, device=device)
            y_pred = model(x)
            return y_pred, y

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
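metric.attach(engine, name) wires a metric into the engine’s events: it is reset when an epoch starts, updated with each batch output (here the (y_pred, y) pair returned by _inference), and its computed value is stored in engine.state.metrics[name] when the epoch completes. The sketch below imitates that lifecycle in plain Python with a hypothetical ToyAccuracy class; ignite’s Metric base class has reset, update, and compute methods playing these roles, but this is not ignite’s code, and the batch values are made up.

```python
class ToyAccuracy:
    """Hypothetical metric with the reset/update/compute lifecycle."""

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, output):
        y_pred, y = output  # the evaluator's input function returns (y_pred, y)
        # y_pred holds per-class scores; the predicted class is the argmax
        preds = [max(range(len(scores)), key=scores.__getitem__) for scores in y_pred]
        self.correct += sum(p == t for p, t in zip(preds, y))
        self.total += len(y)

    def compute(self):
        return self.correct / self.total


# Simulate one evaluation pass: reset on EPOCH_STARTED, update on every
# ITERATION_COMPLETED, compute on EPOCH_COMPLETED.
batches = [
    ([[0.9, 0.1], [0.2, 0.8]], [0, 1]),  # both predictions correct
    ([[0.6, 0.4], [0.7, 0.3]], [1, 0]),  # one prediction correct
]
metric = ToyAccuracy()
metric.reset()
for output in batches:
    metric.update(output)
state_metrics = {"avg_accuracy": metric.compute()}
print(state_metrics)  # {'avg_accuracy': 0.75}
```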

Let’s continue. We randomly select a part of the training set on which to compute the metrics:

import numpy as np
from torch.utils.data import Subset, DataLoader

indices = np.arange(len(train_dataset))
random_indices = np.random.permutation(indices)[:len(val_dataset)]
train_subset = Subset(train_dataset, indices=random_indices)
train_eval_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                               drop_last=True, pin_memory="cuda" in device)

Next, let’s determine at what point in the training we will start the calculation of metrics and produce output to the screen:

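@trainer.on(Events.EPOCH_COMPLETED)
def compute_and_display_offline_train_metrics(engine):
    epoch = engine.state.epoch
    print("Compute train metrics...")
    metrics = train_evaluator.run(train_eval_loader).metrics
    print("Training Results - Epoch: {} Avg Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}"
          .format(epoch, metrics['avg_loss'], metrics['avg_accuracy'],
                  metrics['avg_precision'], metrics['avg_recall']))

@trainer.on(Events.EPOCH_COMPLETED)
def compute_and_display_val_metrics(engine):
    epoch = engine.state.epoch
    print("Compute validation metrics...")
    metrics = val_evaluator.run(val_loader).metrics
    print("Validation Results - Epoch: {} Avg Loss: {:.4f} | Accuracy: {:.4f} | Precision: {:.4f} | Recall: {:.4f}"
          .format(epoch, metrics['avg_loss'], metrics['avg_accuracy'],
                  metrics['avg_precision'], metrics['avg_recall']))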

Now we can run it:

output = trainer.run(train_loader, max_epochs=1)

On screen, we get:

Epoch[1] Iteration[50/322] Loss: 3.5112
Epoch[1] Iteration[100/322] Loss: 2.9840
Epoch[1] Iteration[150/322] Loss: 2.8807
Epoch[1] Iteration[200/322] Loss: 2.9285
Epoch[1] Iteration[250/322] Loss: 2.5026
Epoch[1] Iteration[300/322] Loss: 2.1944

Compute train metrics...

Training Results - Epoch: 1 Avg Loss: 2.1018 | Accuracy: 0.3699 | Precision: 0.3981 | Recall: 0.3686

Compute validation metrics...

Validation Results - Epoch: 1 Avg Loss: 2.0519 | Accuracy: 0.3850 | Precision: 0.3578 | Recall: 0.3845

Better already!

A Few Details

Let’s take a closer look at the previous code. The reader may have noticed the following line:

metrics = train_evaluator.run(train_eval_loader).metrics

and may have wondered what kind of object train_evaluator.run() returns, given that it has a metrics attribute.

In fact, the Engine class contains a structure called state (of type State) that makes it possible to pass data between event handlers. The state attribute holds basic information about the current epoch, iteration, number of epochs, and so on. It can also be used to pass any user data, including the results of metric computations. Calling run() returns this state object:

state = train_evaluator.run(train_eval_loader)
metrics = state.metrics

# or simply
metrics = train_evaluator.state.metrics

Calculating metrics during training

If the training set is huge and computing metrics after every training epoch is expensive, but you would still like to see how some metrics change during training, you can use the RunningAverage metric that ignite provides out of the box. For example, to compute and display the classifier’s accuracy:

from ignite.metrics import RunningAverage, CategoricalAccuracy

acc_metric = RunningAverage(CategoricalAccuracy(...), alpha=0.98)
acc_metric.attach(trainer, 'running_avg_accuracy')

@trainer.on(Events.ITERATION_COMPLETED)
def log_running_avg_metrics(engine):
    print("running avg accuracy:", engine.state.metrics['running_avg_accuracy'])

At the time of writing (version 0.1.0), using RunningAverage required installing ignite from source:

pip install git+
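RunningAverage smooths a noisy per-batch value with an exponential moving average. To my understanding, the first value seeds the average and each later value v updates it as avg = alpha * avg + (1 - alpha) * v. A plain-Python sketch of that update rule (the helper name running_average is hypothetical, and the inputs are made up):

```python
def running_average(values, alpha=0.98):
    """Exponential moving average: the first value seeds it, then
    avg = alpha * avg + (1 - alpha) * value."""
    avg = None
    history = []
    for value in values:
        avg = value if avg is None else alpha * avg + (1.0 - alpha) * value
        history.append(avg)
    return history


# With alpha=0.5 the effect of each new value is easy to follow by hand
print(running_average([1.0, 0.0, 0.0], alpha=0.5))  # [1.0, 0.5, 0.25]

# With alpha=0.98 (as above), one outlier batch barely moves the average
smooth = running_average([0.5] * 10 + [1.0], alpha=0.98)
print(round(smooth[-1], 2))  # 0.51
```

With alpha=0.98, each new batch contributes only 2% to the average, so the displayed accuracy changes smoothly instead of jumping from batch to batch.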

Learning rate scheduling with Ignite

There are several ways to change the learning rate with ignite. Let’s look at the simplest one: calling lr_scheduler.step() at the beginning of each epoch.

from torch.optim.lr_scheduler import ExponentialLR

lr_scheduler = ExponentialLR(optimizer, gamma=0.8)

@trainer.on(Events.EPOCH_STARTED)
def update_lr_scheduler(engine):
    lr_scheduler.step()
    # Print the learning rate values:
    if len(optimizer.param_groups) == 1:
        lr = float(optimizer.param_groups[0]['lr'])
        print("Learning rate: {}".format(lr))
    else:
        for i, param_group in enumerate(optimizer.param_groups):
            lr = float(param_group['lr'])
            print("Learning rate (group {}): {}".format(i, lr))
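ExponentialLR multiplies each parameter group’s learning rate by gamma on every step() call, so after n steps the rate equals initial_lr * gamma**n. A plain-Python check of that formula with gamma=0.8 and a hypothetical initial rate of 0.01 (the article does not state the SGD learning rate):

```python
def exponential_lr(initial_lr, gamma, num_steps):
    """Learning rate after each scheduler step: lr_n = initial_lr * gamma ** n."""
    return [initial_lr * gamma ** n for n in range(num_steps + 1)]


# Hypothetical starting value; with gamma=0.8 the rate shrinks by 20% per epoch
lrs = exponential_lr(0.01, gamma=0.8, num_steps=4)
for epoch, lr in enumerate(lrs):
    print(epoch, round(lr, 6))
```

After 10 epochs the rate is about 0.8**10, roughly 11% of its initial value. One caveat, as far as I know: since PyTorch 1.1 the scheduler is expected to be stepped after optimizer.step(), so calling lr_scheduler.step() at the very start of the first epoch, as above, triggers a warning and skips the first value of the schedule.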

Keeping the best models and other parameters during training

During training, it would be great to write the weights of the best model to disk, and also periodically store the model weights, optimizer parameters, and parameters for changing the learning rate. The latter can be useful in order to resume learning from the last saved state.

ignite has a dedicated class for this, ModelCheckpoint. So, let’s create a ModelCheckpoint event handler that keeps the best models in terms of accuracy on the validation set. For this, we declare a score_function that returns the accuracy value to the event handler, which then decides whether or not to save the model:

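from ignite.handlers import ModelCheckpoint

def score_function(engine):
    val_avg_accuracy = engine.state.metrics['avg_accuracy']
    return val_avg_accuracy

best_model_saver = ModelCheckpoint("best_models",  # folder where 1 or more of the best models are saved
                                   # file name, e.g. model_best_model_10_val_accuracy=0.8730994.pth
                                   filename_prefix="model",
                                   score_name="val_accuracy",
                                   score_function=score_function,
                                   n_saved=3,  # keep the top three models
                                   save_as_state_dict=True)  # save as `state_dict`

# Save the best models after each evaluation on the validation set
val_evaluator.add_event_handler(Events.EPOCH_COMPLETED, best_model_saver, {"best_model": model})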
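With a score_function and n_saved=3, ModelCheckpoint keeps the three highest-scoring checkpoints and deletes the rest. That bookkeeping can be sketched in plain Python with a min-heap; TopNSaver is hypothetical, only records names instead of writing files, is fed made-up scores, and is not how ModelCheckpoint is implemented internally.

```python
import heapq


class TopNSaver:
    """Hypothetical sketch: keep only the n checkpoints with the highest score."""

    def __init__(self, n_saved):
        self.n_saved = n_saved
        self._heap = []  # min-heap of (score, name): the worst kept model is at index 0

    def __call__(self, score, name):
        if len(self._heap) < self.n_saved:
            heapq.heappush(self._heap, (score, name))  # "save" the new model
        elif score > self._heap[0][0]:
            # "save" the new model and "delete" the worst kept one
            heapq.heapreplace(self._heap, (score, name))
        # otherwise the new model is not better than any kept one and is skipped

    def kept(self):
        return sorted(self._heap, reverse=True)


saver = TopNSaver(n_saved=3)
for epoch, val_accuracy in enumerate([0.61, 0.72, 0.70, 0.80, 0.65, 0.85], start=1):
    saver(val_accuracy, "model_best_model_{}".format(epoch))
print(saver.kept())
```

The heap makes the "is this better than the worst kept model?" check a constant-time look at its first element.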

Now let’s create another ModelCheckpoint event handler to save the training state every 1000 iterations:

training_saver = ModelCheckpoint("checkpoint",
                                 filename_prefix="checkpoint",  # prefix chosen here for illustration
                                 save_interval=1000,  # save every 1000 iterations
                                 save_as_state_dict=True)
to_save = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
trainer.add_event_handler(Events.ITERATION_COMPLETED, training_saver, to_save)

So, almost everything is ready, let’s add the last element:

Early stopping in Ignite

Let’s add another event handler that will stop training if there is no improvement in the quality of the model for 10 epochs. We will again evaluate the quality of the model using the function score_function.

from ignite.handlers import EarlyStopping

early_stopping = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
val_evaluator.add_event_handler(Events.EPOCH_COMPLETED, early_stopping)
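EarlyStopping keeps a counter of evaluations in which score_function did not improve on the best score so far; the counter resets whenever the score improves, and training is terminated once it reaches patience. A plain-Python sketch of that rule with a hypothetical helper and made-up scores (patience=3 to keep the example short):

```python
def early_stopping_epoch(scores, patience):
    """Return the epoch (1-based) at which training would stop, or None.

    A counter grows with every epoch whose score is not better than the
    best score so far and resets on improvement; training stops when the
    counter reaches `patience`.
    """
    best = None
    counter = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or score > best:
            best = score
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return None


val_accuracy = [0.50, 0.60, 0.65, 0.64, 0.65, 0.63]
print(early_stopping_epoch(val_accuracy, patience=3))   # 6
print(early_stopping_epoch(val_accuracy, patience=10))  # None
```

Matching the best score exactly (0.65 at epoch 5) does not count as an improvement in this sketch.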

Start training your Model

In order to start training, we just need to call the method run(). We’ll train the model for 10 epochs:

max_epochs = 10
output = trainer.run(train_loader, max_epochs=max_epochs)

Now let’s check the models and parameters saved to disk:

ls best_models/
ls checkpoint/

Trained model predictions using Ignite

First, let’s create a test data loader (for example, let’s take a validation sample) so that the batch of data consists of images and their indices:

from torch.utils.data import Dataset, DataLoader

class TestDataset(Dataset):
    """Wraps a dataset so that each item is (image, index) instead of (image, label)."""

    def __init__(self, ds):
        self.ds = ds

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        return self.ds[index][0], index

test_dataset = TestDataset(val_dataset)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         drop_last=False, pin_memory="cuda" in device)

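Using ignite, we will create a new engine for predicting on the test data. To do this, we define a function inference_update that returns the prediction result and the image indices. To improve accuracy, we will also use the well-known test-time augmentation (TTA) trick: we run several passes over the test data and average the predictions. The passes only differ if the dataset applies random augmentations to the images.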

import torch
import torch.nn.functional as F
from ignite.engine import Engine
from ignite._utils import convert_tensor

def _prepare_batch(batch):
    x, index = batch
    x = convert_tensor(x, device=device)
    return x, index

def inference_update(engine, batch):
    x, indices = _prepare_batch(batch)
    with torch.no_grad():
        y_pred = model(x)
        y_pred = F.softmax(y_pred, dim=1)
    return {"y_pred": convert_tensor(y_pred, device='cpu'), "indices": indices}

model.eval()
inferencer = Engine(inference_update)

Next, we will create event handlers that will notify about the prediction stage and save the predictions to a dedicated array:

n_tta = 3
num_classes = 81
n_samples = len(val_dataset)

@inferencer.on(Events.EPOCH_COMPLETED)
def log_tta(engine):
    print("TTA {} / {}".format(engine.state.epoch, n_tta))

# Array for storing predictions: (sample, class, TTA pass)
y_probas_tta = np.zeros((n_samples, num_classes, n_tta), dtype=np.float32)

@inferencer.on(Events.ITERATION_COMPLETED)
def save_results(engine):
    output = engine.state.output
    tta_index = engine.state.epoch - 1
    start_ind = ((engine.state.iteration - 1) % len(test_loader)) * batch_size
    end_ind = min(start_ind + batch_size, n_samples)
    batch_y_probas = output['y_pred'].detach().numpy()
    y_probas_tta[start_ind:end_ind, :, tta_index] = batch_y_probas
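save_results relies on the test loader being unshuffled (shuffle=False) and on every batch except possibly the last being full, so the global iteration number alone determines which rows of y_probas_tta a batch fills; min() clips the last, partial batch kept by drop_last=False. A plain-Python check of that arithmetic with hypothetical sizes (10 samples, batch size 4, so 3 batches per pass):

```python
def batch_slot(iteration, batches_per_pass, batch_size, n_samples):
    """Rows [start, end) of the prediction array filled by a given global iteration."""
    start = ((iteration - 1) % batches_per_pass) * batch_size
    end = min(start + batch_size, n_samples)
    return start, end


n_samples, batch_size = 10, 4
batches_per_pass = -(-n_samples // batch_size)  # ceil(10 / 4) = 3, since drop_last=False

for it in range(1, 2 * batches_per_pass + 1):  # two TTA passes
    print(it, batch_slot(it, batches_per_pass, batch_size, n_samples))
```

Iterations 4 to 6 map back onto rows 0 to 10, because the second TTA pass overwrites the same rows in its own slice of the last axis.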
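Before starting the process, let’s load the best model:

import torch
import torch.nn as nn
from torchvision.models import squeezenet1_1

model = squeezenet1_1(pretrained=False, num_classes=num_classes)
model.classifier[-1] = nn.AdaptiveAvgPool2d(1)
model = model.to(device)
model_state_dict = torch.load("best_models/model_best_model_10_val_accuracy=0.8730994.pth")
model.load_state_dict(model_state_dict)
model.eval()

Launch:

inferencer.run(test_loader, max_epochs=n_tta)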
> TTA 1 / 3
> TTA 2 / 3
> TTA 3 / 3

Next, in the standard way, we take the average of the TTA predictions and calculate the class index with the highest probability:

y_probas = np.mean(y_probas_tta, axis=-1)
y_preds = np.argmax(y_probas, axis=-1)
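The averaging and argmax steps can be spelled out without NumPy on a tiny, hypothetical array with the same (sample, class, TTA pass) layout as y_probas_tta:

```python
# y_probas_tta[sample][class][tta_pass]: 2 samples, 3 classes, 3 passes (made-up values;
# for each pass, the probabilities over the 3 classes sum to 1)
y_probas_tta = [
    [[0.6, 0.2, 0.5], [0.3, 0.7, 0.4], [0.1, 0.1, 0.1]],
    [[0.2, 0.3, 0.1], [0.1, 0.2, 0.3], [0.7, 0.5, 0.6]],
]

# Equivalent of np.mean(y_probas_tta, axis=-1): average over the TTA passes
y_probas = [[sum(passes) / len(passes) for passes in sample] for sample in y_probas_tta]
# Equivalent of np.argmax(y_probas, axis=-1): most probable class per sample
y_preds = [max(range(len(probs)), key=probs.__getitem__) for probs in y_probas]
print(y_preds)  # [1, 2]
```

For the first sample, two of the three passes individually favour class 0, yet the averaged probabilities pick class 1: averaging probabilities is not the same as a majority vote over passes.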
And now we can compute the model’s accuracy again from the obtained predictions:

from sklearn.metrics import accuracy_score

y_test_true = [y for _, y in val_dataset]
accuracy_score(y_test_true, y_preds)
> 0.9310369676443035

So, in this part, we showed how to compute predictions on a validation set using a trained model. The example is very simple, but it should make clear how to use ignite in other, more complex situations.


In conclusion, I want to say that the ignite library is not an official Facebook product, and programmers take part in its development on a voluntary basis (for example, the author of this article). At the moment it is at version 0.1.0, but the main API (Engine, State, Events, Metric, ...) will, as far as possible, remain unchanged in future versions. Since the library is under active development, including additional modules, the developers will be glad to receive feedback and bug reports.

The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.
