Porting a PyTorch Model to C++

This article was published as a part of the Data Science Blogathon.


In this article, we look at different ways to port a PyTorch model to C++. PyTorch is mostly used for research and for prototyping new models and systems: the framework is flexible and imperative, and therefore easy to use. The key question is how to convert a PyTorch model into a format that is better suited to production.


We will look at three pipelines for porting a PyTorch model to C++ in a production-friendly format:

1) TorchScript

2) ONNX (Open Neural Network Exchange)

3) TFLite (TensorFlow Lite)



TorchScript is an intermediate representation of a PyTorch Model (subclass of nn.Module) that can be run in a high-performance environment such as C++.  It helps to create serializable and optimizable models. After training these models in python, they can be independently run in python or in C++.  So, one can easily train a model in PyTorch using Python and then export the model via torchscript to a production environment where Python is not available.  It basically provides a tool to capture the definition of the model.

Tracing a Module:

import torch

class DummyCell(torch.nn.Module):
    def __init__(self):
        super(DummyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        out = self.linear(x)
        return out

dummy_cell = DummyCell()
x = torch.rand(2, 4)
# Run the module once on x and record the operations it executes
traced_cell = torch.jit.trace(dummy_cell, (x,))

# Print Traced Graph
print(traced_cell.graph)

# Print Traced Code
print(traced_cell.code)

Here, TorchScript has invoked the module and recorded the operations it performed into an intermediate representation known as a graph. traced_cell.graph provides a very low-level representation, and most of the information in it is not useful to end users. traced_cell.code provides a Python-syntax interpretation of the same code.

Output of traced_cell.graph:

graph(%self.1 : __torch__.DummyCell,
      %input : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %16 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %18 : Tensor = prim::CallMethod[name="forward"](%16, %input)
  return (%18)

Output of traced_cell.code:

def forward(self,
    input: Tensor) -> Tensor:
  return (self.linear).forward(input, )
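Tracing records only the operations executed for the example input, so it is worth checking that the traced module reproduces the eager module's outputs. The sketch below does this for DummyCell; the fixed seed and the second input with batch size 3 are illustrative choices, not part of the original example. A single Linear layer has no data-dependent control flow, so the traced graph should also work for the new batch size.

```python
import torch

class DummyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

torch.manual_seed(0)  # illustrative: makes the random weights reproducible
dummy_cell = DummyCell().eval()
example = torch.rand(2, 4)
traced_cell = torch.jit.trace(dummy_cell, (example,))

# Compare eager and traced outputs on the tracing input and on a new batch size
for inp in (example, torch.rand(3, 4)):
    with torch.no_grad():
        eager_out = dummy_cell(inp)
        traced_out = traced_cell(inp)
    print(tuple(inp.shape), torch.allclose(eager_out, traced_out))
```

If the model contained Python control flow that depends on tensor values, tracing would freeze the branch taken for the example input, which is exactly what this kind of check helps catch.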

Advantages of TorchScript

1) TorchScript code can be invoked in its own interpreter, and the saved graph can be loaded in C++ for production.

2) TorchScript provides a representation on which compiler optimizations can be applied to make execution more efficient.
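To make the first advantage concrete: a traced module can be serialized to a single archive and reloaded without the original Python class definition, which is what allows LibTorch's torch::jit::load to open it from C++. The sketch below checks the round trip in Python only; the file name dummy_cell.pt and the temporary directory are hypothetical choices.

```python
import os
import tempfile

import torch

class DummyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

x = torch.rand(2, 4)
traced_cell = torch.jit.trace(DummyCell().eval(), (x,))

# Serialize code and weights into one archive (hypothetical file name)
path = os.path.join(tempfile.mkdtemp(), "dummy_cell.pt")
traced_cell.save(path)

# Reload: the DummyCell class is not needed at load time
loaded = torch.jit.load(path)
with torch.no_grad():
    same = torch.allclose(traced_cell(x), loaded(x))
print(same)
```

The same .pt file is what a C++ program would pass to torch::jit::load.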

ONNX (Open Neural Network Exchange)

ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators, the building blocks of machine learning and deep learning models, and a common file format that enables AI developers to use models with a variety of frameworks, tools, runtimes, and compilers. It defines an extensible computation graph model, as well as definitions of built-in operators and standard data types.

One can export the above DummyCell model to ONNX using the following code:

torch.onnx.export(dummy_cell, x, "dummy_model.onnx", export_params=True, verbose=True)

Output:

graph(%input : Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %linear.weight : Float(4, 4, strides=[4, 1], requires_grad=1, device=cpu),
      %linear.bias : Float(4, strides=[1], requires_grad=1, device=cpu)):
  %3 : Float(2, 4, strides=[4, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%input, %linear.weight, %linear.bias)
  return (%3)

This saves the model to a file named "dummy_model.onnx", which can be loaded with the onnx Python package. For inference in Python, one can use ONNX Runtime, a performance-focused engine for ONNX models that runs inference efficiently across multiple platforms and hardware. See the ONNX Runtime documentation for more details on performance.


Inferencing in C++

To execute the ONNX models from C++, the approach taken here is to first write the inference code in Rust, using the tract library for execution. This gives us a Rust library for running inference on ONNX models. We can then use cbindgen to export the Rust library's public API as a C header. This header, together with the shared or static library generated from Rust, can be included in a C++ project to run inference on the ONNX models. When building the library in Rust, we can also pass optimization flags for different target hardware, and cross-compiling for different hardware types is straightforward.


TensorFlow Lite

TensorFlow Lite is an open-source deep learning framework for on-device inference: a set of tools that helps developers run TensorFlow models on mobile, embedded, and IoT devices with low latency and a small binary size. It has two main components:

1) TensorFlow Lite Interpreter: runs specially optimized models on many different hardware types, including mobile phones, embedded Linux devices, and microcontrollers.

2) TensorFlow Lite Converter: converts TensorFlow models into an efficient form for use by the interpreter.

The main pipeline to convert a PyTorch model into TensorFlow Lite is as follows:

1) Build the PyTorch model

2) Export the model in ONNX format

3) Convert the ONNX model into TensorFlow (using onnx-tf)

Here we can convert the ONNX model to a TensorFlow protobuf model with the command below (the leading ! runs it as a shell command from a notebook cell, e.g. in Colab):

!onnx-tf convert -i "dummy_model.onnx" -o "dummy_model_tensorflow"

4) Convert the TensorFlow model into TensorFlow Lite (tflite)
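Step 4 can be done with TensorFlow's converter API. A minimal sketch, assuming TensorFlow 2.x: tf.lite.TFLiteConverter.from_saved_model reads a SavedModel directory (the kind of output onnx-tf produces in step 3) and returns a .tflite flatbuffer. So that the snippet runs on its own, a small hand-written tf.Module (StandInLinear, hypothetical) is saved in place of the real onnx-tf output; in the actual pipeline you would point the converter at the dummy_model_tensorflow directory instead. The last lines run the converted model with TensorFlow's bundled Python tf.lite.Interpreter as a sanity check before moving to C++.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the onnx-tf output: a 4 -> 4 linear layer
class StandInLinear(tf.Module):
    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([4, 4]))
        self.b = tf.Variable(tf.zeros([4]))

    @tf.function(input_signature=[tf.TensorSpec([2, 4], tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b

tmp = tempfile.mkdtemp()
saved_model_dir = os.path.join(tmp, "dummy_model_tensorflow")
stand_in = StandInLinear()
tf.saved_model.save(stand_in, saved_model_dir,
                    signatures=stand_in.__call__.get_concrete_function())

# Step 4: convert the SavedModel directory into a .tflite flatbuffer
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
tflite_path = os.path.join(tmp, "dummy_model.tflite")
with open(tflite_path, "wb") as f:
    f.write(tflite_model)

# Sanity check with the Python interpreter before moving to C++
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()
inp = np.random.rand(2, 4).astype(np.float32)
interpreter.set_tensor(interpreter.get_input_details()[0]["index"], inp)
interpreter.invoke()
tflite_out = interpreter.get_tensor(interpreter.get_output_details()[0]["index"])
expected = stand_in(tf.constant(inp)).numpy()
print(np.allclose(tflite_out, expected, atol=1e-5))
```

The resulting .tflite file is the artifact a C++ application would load with the TensorFlow Lite interpreter.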

The tflite model (TensorFlow Lite model) can now be used in C++. For how to run inference on a tflite model in C++, see "Running Inferencing on TfLite Models in C++" in the references below.

End Notes:

I hope you find the article useful. We have briefly explained the different ways to deploy models trained in PyTorch to production.

In future articles, we will cover each of the above steps in detail: inference on a TorchScript model in C++, inference on an ONNX model in C++ using the tract library, and converting a PyTorch model into tflite and running inference in C++.

If you have feedback or questions, please share them with us. You can also connect with me on LinkedIn, and we can chat about any applications of the above topics. Thanks.

References:

1) Introduction to TorchScript

2) Loading a TorchScript Model in C++

3) Exporting a PyTorch Model to ONNX

4) Tract Neural Network Inference ToolKit in Rust

5) Running Inferencing on TfLite Models in C++

6) Colab – PyTorch Trained Model on Android Device

The media shown in this article are not owned by Analytics Vidhya and is used at the Author’s discretion. 