Tensorflow XLA: The Fusion Compiler for Tensorflow

Abhilash Majumder 09 May, 2023 • 17 min read


TensorFlow XLA (Accelerated Linear Algebra) is a compiler that can speed up the execution of TensorFlow kernels. It optimizes kernels for GEMM (general matrix multiplication), activations, and other linear algebra computations. When a model built with TensorFlow core (or Keras) runs without XLA, the executor launches each kernel individually on a specific device (GPU/XPU driver), and every small kernel pays its own launch cost and writes its intermediate result back to memory. XLA reduces this memory-bandwidth overhead by fusing several linear algebra/computation kernels into one, so that intermediate values can stay in registers. XLA targets different backends, and on CPUs the compiled code is emitted through LLVM for either x64 or ARM64. From the Python perspective, a code segment wrapped in JIT (just-in-time) compilation triggers kernels specific to the hardware device. Several backends are associated with the XLA compiler, and the TensorFlow documentation provides a broad overview of the architecture.
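To make the idea of fusion concrete, here is a minimal sketch in plain Python. It is only an analogy: the function names are hypothetical, and Python lists stand in for device buffers. The unfused version materializes an intermediate result between two "kernels", while the fused version computes each output element in a single pass.

```python
# Illustrative only: plain Python standing in for device kernels.

def unfused_axpy(a, x, y):
    """Two separate "kernels": the product is written out, then read back."""
    tmp = [a * xi for xi in x]                   # kernel 1: multiply, materializes tmp
    return [ti + yi for ti, yi in zip(tmp, y)]   # kernel 2: add, re-reads tmp


def fused_axpy(a, x, y):
    """One "kernel": each element is multiplied and added in a single pass."""
    return [a * xi + yi for xi, yi in zip(x, y)]


x = [1.0, 2.0, 3.0, 4.0]
y = [10.0, 20.0, 30.0, 40.0]
print(unfused_axpy(2.0, x, y))
print(fused_axpy(2.0, x, y))
```

Both functions return the same values; the difference is only in how many times the data passes through memory, which is the cost that fusion targets.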


Tensorflow Compiler Backends – XLA

Compiling TensorFlow for XLA involves several features that depend on the compilation mode and the underlying device. Most of the TF compiler stack builds on LLVM and keeps most of LLVM's capabilities, including, but not limited to, assembling and disassembling code, generating object files, linking them, and analyzing bitcode. The TF compiler uses different approaches that broadly influence the graph creation mechanism, node updates, and low-level optimization patterns. A broad depiction is provided in the snapshot below:

[Figure: TensorFlow compiler pathways]

Depending on the runtime flexibility of the backend device, the default compiler pathway is either MLIR or JIT (just-in-time) compilation. Most programs that involve TF Variables (Placeholders) flow through the JIT pathway to LLVM, where the intermediate representation and object mapping are produced at the graph level. The TF compiler has several components, and XLA is one of them. As mentioned in the documentation, TF should be built from source to use the TF compiler semantics. The most important build option (using bazel) is the one that adds GPU support (for cuDNN/cuBLAS):

bazel build --config=cuda [--config=option] //tensorflow/tools/pip_package:build_pip_package

For most of the tutorial, we will use the TensorFlow C++ backend for XLA to understand the internal working of different operations and how the compute graph is created across devices (CPUs, GPUs, and TPUs). The reader is assumed to have some familiarity with C++ programming and the TensorFlow source setup. TF provides starter documentation on the basics of the C backend.

XLA Program to Compile MLIR HLO and Create Compute Graph

Among the TF compilers, we will primarily focus on XLA. XLA has different features to invoke different device runtimes, from CPU to TPU devices. Although MLIR is presented as a separate part of the TF compiler, XLA uses MLIR dialects to create compute graphs. For instance, a simple multiplication and addition of two tensors requires the MLIR HLO input datatypes to be created and fed to the XLA compiler, which then applies its dialect-level passes to optimize the symbolic compute graph. Nearly all compute-bound operations of deep learning (all compute kernels) that run through the XLA compiler produce this MLIR HLO representation, which is compiled into the TF graph (the backend network graph) that we see in TensorBoard. As per the TensorFlow documentation, a simple computation of ax+y through the XLA compiler looks like so:

[Figure: compute graph for alpha * x + y]

In this case, alpha, x, and y are FP32 tensors that are created as nodes in the compute graph. The MLIR representation of this compute graph, which is fed into the XLA compiler at runtime, is as follows:

func.func @main(
  %alpha: tensor<f32>, %x: tensor<4xf32>, %y: tensor<4xf32>
) -> tensor<4xf32> {
  %0 = stablehlo.broadcast_in_dim %alpha, dims = []
    : (tensor<f32>) -> tensor<4xf32>
  %1 = stablehlo.multiply %0, %x : tensor<4xf32>
  %2 = stablehlo.add %1, %y : tensor<4xf32>
  func.return %2: tensor<4xf32>
}

In MLIR we declare each tensor's size and datatype and use the “broadcast_in_dim” operation to expand the scalar alpha to the shape of x, so that the multiplication and addition are element-wise. The compiled computation is then dispatched to the respective XLA device (CPU, CUDA, etc.). Work is generally issued from the master rank of the “X”PU card (in the case of a card consisting of a single tile). For CPU devices, there is only a single master rank, whereas with multiple GPUs (each having a different tile configuration), the master is rank 0. All compute operations, when scaled to multi-cluster GPUs/TPUs, follow collective communication principles and algorithms.
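As a rough reference for what the three StableHLO operations in the listing compute, the following plain-Python sketch mirrors them one by one. The helper names are hypothetical and only mimic the operations for a scalar and 1-D lists; the inputs are the same values the C++ walkthrough uses later.

```python
def broadcast_in_dim_scalar(value, size):
    """Mimics stablehlo.broadcast_in_dim for a scalar operand: repeat it `size` times."""
    return [value] * size


def multiply(lhs, rhs):
    """Element-wise product, like stablehlo.multiply on equal-shaped tensors."""
    return [l * r for l, r in zip(lhs, rhs)]


def add(lhs, rhs):
    """Element-wise sum, like stablehlo.add on equal-shaped tensors."""
    return [l + r for l, r in zip(lhs, rhs)]


def main(alpha, x, y):
    v0 = broadcast_in_dim_scalar(alpha, len(x))  # %0
    v1 = multiply(v0, x)                         # %1
    return add(v1, y)                            # %2


result = main(3.14, [1.0, 2.0, 3.0, 4.0], [10.5, 20.5, 30.5, 40.5])
print([round(v, 2) for v in result])
```

The rounded values should agree with the expected results that the C++ program checks at the end of the walkthrough.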

These collective algorithms primarily govern data transfer across GPUs and between CPU and GPU, and mediate computing across devices; XLA uses them for any kind of distributed computation arising from TensorFlow (training or inference). We will cover them in a while, but first, let us see how to write an XLA program that uses the MLIR dialects to perform the ax+y computation. We will be using the TensorFlow C++ API for fine-grained analysis, and like most C++ files, ours needs a set of header files. These are the important includes:

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "stablehlo/dialect/Register.h"  // from @stablehlo
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor.h" //plugin gpu streamer

As this suggests, our headers come from the TensorFlow XLA framework. The XLA directory has different parts, and we will be using the client API. Each header serves a different purpose: “pjrt” plugs XLA code into third-party devices such as CUDA or Intel/AMD GPUs or TPUs (“pjrt” stands for “Pretty much Just another RunTime”), while “stream_executor” handles data transfer through streams to the respective devices. For CUDA, execution mostly follows the SIMT (single instruction, multiple threads) model. The next part of the program is split into 13 broad steps.

Broad Steps

1. The first step is to create an instance of the XLA client, which will interface with the MLIR program.

//1. Setup client
LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie();

2. The second step is to instantiate a device: the offload device on which the computing will happen. Here we create a helper class for CUDA (usable if TensorFlow was built with CUDA support) and obtain an instance of the “Platform*” datatype. “Platform” is the datatype TF uses to specify which device to run on:


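// Helper that looks up the CUDA platform (only when TF is built with CUDA support).
class GetCUDAPlatform {
 public:
  static se::Platform* getPlatform() {
    return *se::MultiPlatformManager::PlatformWithName("CUDA");
  }
};

string platform_name = "cpu";  // "cpu" or "CUDA"
TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform,
                        PlatformUtil::GetPlatform(platform_name));
se::StreamExecutorConfig config;
config.ordinal = 0;  // first device of the platform
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor,
                        platform->GetExecutor(config));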

3. Set up “LocalDeviceState” and “PjRtStreamExecutorDevice.” The latter describes the state of a device that can do computation or transfer buffers. This could represent a CPU/GPU or accelerator. While using different GPUs, callback_stream might be needed.

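// Positional arguments follow the XLA sample test; verify them against your checkout.
auto device_state = make_unique<LocalDeviceState>(
      executor, local_client, LocalDeviceState::kSynchronous,
      /*max_inflight_computations=*/32,
      /*allow_event_reuse=*/false, /*use_callback_stream=*/false);

auto device = make_unique<PjRtStreamExecutorDevice>(
      0, move(device_state), platform_name);

vector<unique_ptr<PjRtStreamExecutorDevice>> devices;
devices.emplace_back(move(device));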

4. The next step involves setting up “PjRtStreamExecutorClient”  which allows us to compile and execute compute operations on the device. The following lines have some arguments to specify the host to device transfers (cpu-gpu)  or an accelerator memory allocator and how to run the GPU instructions.

auto pjrt_se_client = PjRtStreamExecutorClient(
    platform_name, local_client, move(devices), /*process_index=*/0,
    /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
    /*should_stage_host_to_device_transfers=*/false,  // set to true for h2d memcpy staging
    /*gpu_run_options=*/nullptr);                     // supply run options when CUDA is enabled

5 & 6. In the next steps we plug in the MLIR program that we highlighted for the operation ax+y, with its float32 (FP32) tensor datatypes, for XLA to use. We specify the MLIR file that the XLA compiler will read at runtime. The StableHLO registration call then registers the dialects (operations, operators, and operands for computing) needed to parse it.

string program_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "samples", 
                      "axpy", "stablehlo_axpy.mlir");
string program_string;

TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), program_path, &program_string));

cerr << "Loaded StableHLO program from " << program_path << ":\n"<< program_string << endl;

//6. Register the MLIR dialects necessary to parse our program.
// In our case this is just the Func dialect and StableHLO.
mlir::DialectRegistry dialects;
dialects.insert<mlir::func::FuncDialect>();
mlir::stablehlo::registerAllDialects(dialects);

7-9. The next steps involve parsing the HLO program (MLIR) using MLIR contexts and compiling the HLO program into an executable. We also create containers (variables for input and output)  for our computation.

auto ctx = make_unique<mlir::MLIRContext>(dialects);
mlir::OwningOpRef<mlir::ModuleOp> program =mlir::parseSourceString<mlir::ModuleOp>
                                           (program_string, ctx.get());

//8. Use our client to compile our StableHLO program to an executable.
TF_ASSERT_OK_AND_ASSIGN(unique_ptr<PjRtLoadedExecutable> executable,
                          pjrt_se_client.Compile(*program, CompileOptions{}));

//9. Create inputs/containers to our computation.
auto alpha_literal = xla::LiteralUtil::CreateR0<float>(3.14f);
auto x_literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto y_literal =  xla::LiteralUtil::CreateR1<float>({10.5f, 20.5f, 30.5f, 40.5f});

cerr << "Computation inputs:" << endl;
cerr << "\talpha:" << alpha_literal << endl;
cerr << "\tx:" << x_literal << endl;
cerr << "\ty:" << y_literal << endl;

10-12. This is the compute part, which runs after creating buffers and transferring the variables created in the previous section. First, we get the device, transfer our literals/variables to buffers on that device (CPU/CUDA), and then use the compiled PjRt executable to run the computation. Any compute operation follows this pattern so that the computation is done through PjRt buffers.

//10. Get the host device.
PjRtDevice* device = pjrt_se_client.devices()[0];

//11. Transfer our literals to buffers. 
//If we were using a GPU, these buffers would correspond to device memory.
TF_ASSERT_OK_AND_ASSIGN(unique_ptr<PjRtBuffer> alpha,
      pjrt_se_client.BufferFromHostLiteral(alpha_literal, device));
TF_ASSERT_OK_AND_ASSIGN(unique_ptr<PjRtBuffer> x,
                          pjrt_se_client.BufferFromHostLiteral(x_literal, device));
TF_ASSERT_OK_AND_ASSIGN(unique_ptr<PjRtBuffer> y,
                          pjrt_se_client.BufferFromHostLiteral(y_literal, device));

//12. Do our computation.
TF_ASSERT_OK_AND_ASSIGN(
    vector<vector<unique_ptr<PjRtBuffer>>> axpy_result,
    executable->Execute({{alpha.get(), x.get(), y.get()}}, /*options=*/{}));

13. The last step transfers the result buffers back to literals, which are the outputs of the computation. To test the result, we can compare the output literal against the expected values with “xla::LiteralTestUtil”.

//13. Convert the result buffer back to a literal.
TF_ASSERT_OK_AND_ASSIGN(shared_ptr<Literal> axpy_result_literal,
                        axpy_result[0][0]->ToLiteralSync());

//14. Check that our results match what we expect.
xla::LiteralTestUtil::ExpectR1Near<float>({13.64f, 26.78f, 39.92f, 53.06f},
                                          *axpy_result_literal,
                                          xla::ErrorSpec(0.01f));
cerr << "Computation output: " << *axpy_result_literal << endl;

Creating Different Activation Graphs in MLIR for XLA

Now that we have seen the steps to compile MLIR dialects with XLA, we can create different activation functions. When we write “tf.nn.softmax” or “tf.nn.relu”, it is the client-side HLO (MLIR) that creates the dialects to be compiled by XLA, which in turn prepares the compute graph. Following the standard snippet created above, we can write an HLO MLIR snippet for the tanh activation (“tf.nn.tanh”) as follows:

func.func @main(
  %x: tensor<f32>
) -> tensor<f32> {
  %tanh = stablehlo.tanh %x: tensor<f32>
  func.return %tanh: tensor<f32>
}

Make Changes in Standard XLA Code in Parts 5-12

  • Allocate the new tanh MLIR code which requires only a single tensor, and create PjRt stream executor.
  • Create literals for the parsed MLIR dialect of tanh and assign them to the runtime device.
  • Compute tanh in buffers after converting the literal.
  • Transfer the buffers back to the literals as output.

//5. Read the StableHLO program into a string.
string program_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), 
                      "samples", "tanhx", "stablehlo_tanhx.mlir");
string program_string;

TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(),
             program_path, &program_string));

cerr << "Loaded StableHLO program from " << program_path << ":\n"
<< program_string << endl;

//6. Register the MLIR dialects necessary to parse our program.
// In our case this is just the Func dialect and StableHLO.
mlir::DialectRegistry dialects;
dialects.insert<mlir::func::FuncDialect>();
mlir::stablehlo::registerAllDialects(dialects);

//7. Parse StableHLO program.
auto ctx = make_unique<mlir::MLIRContext>(dialects);
mlir::OwningOpRef<mlir::ModuleOp> program = mlir::parseSourceString<mlir::ModuleOp>
                                            (program_string, ctx.get());

//8. Use our client to compile our StableHLO program to an executable.
TF_ASSERT_OK_AND_ASSIGN(unique_ptr<PjRtLoadedExecutable> executable,
                          pjrt_se_client.Compile(*program, CompileOptions{}));

//9. Create inputs/containers to our computation.
auto x_literal = xla::LiteralUtil::CreateR0<float>(3.14f);

cerr << "Computation inputs:" << endl;
cerr << "\tx:" << x_literal << endl;

//10. Get the host device.
PjRtDevice* device = pjrt_se_client.devices()[0];

//11. Transfer our literals to buffers. 
//If we were using a GPU, these buffers would correspond to device memory.
TF_ASSERT_OK_AND_ASSIGN(unique_ptr<PjRtBuffer> x,
      pjrt_se_client.BufferFromHostLiteral(x_literal, device));

//12. Do our computation.
TF_ASSERT_OK_AND_ASSIGN(
    vector<vector<unique_ptr<PjRtBuffer>>> tanh_result,
    executable->Execute({{x.get()}}, /*options=*/{}));

//13. Convert result buffer back to literal.
TF_ASSERT_OK_AND_ASSIGN(shared_ptr<Literal> tanh_result_literal,
                        tanh_result[0][0]->ToLiteralSync());

cerr << "Computation output: " << *tanh_result_literal << endl;
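For a quick sanity check of what the compiled program should print, the following sketch computes tanh from its definition in plain Python for the same scalar input. The function name is hypothetical, and this is only a host-side reference, not part of the XLA pipeline.

```python
import math


def tanh_reference(x):
    """tanh from its definition: (e^x - e^-x) / (e^x + e^-x)."""
    ex, enx = math.exp(x), math.exp(-x)
    return (ex - enx) / (ex + enx)


x = 3.14  # same scalar as x_literal in the C++ snippet
print(tanh_reference(x))
```

The printed value lies just below 1, since tanh saturates for inputs of this size.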

Similarly, we can manually write different functions for different XLA devices in low-level MLIR code (which TensorFlow otherwise generates for us).

XLA Collective Calls GPU in Tensorflow

We have seen how C++ and MLIR dialects combine with the XLA compiler to generate the compute graph of TensorFlow. Now we will focus on collective operations, a set of algorithms that govern the data flow across hosts and devices (CPUs to “X”PUs). Multi-tile or multi-card GPU and TPU systems use the concept of ranks. Ranks identify the different processes spawned by the CPU. For instance, while training BERT Large on GPUs or TPUs, we only delegate the resources (data and model) to the devices, and all collective calls get invoked by them. Let’s view this from a Python environment, where we load a “Transformer” model onto a Google TPU.

A Minor Case Analysis on TPUs

At approximately 20 inches (50 cm), a TPU v3-8 board is a fairly sizeable piece of hardware. It sports four dual-core TPU chips for a total of 8 TPU cores. Each TPU core has a traditional vector processing unit (VPU) as well as dedicated matrix multiplication hardware capable of processing 128×128 matrices. This is the part that specifically accelerates machine learning workloads. TPUs are equipped with 128GB of high-speed memory, allowing larger batches, larger models, and larger training inputs.

[Figure: TPU v3-8 board]

Given this TPU hardware architecture, we have 4 TPU chips on the board and 2 cores per chip, so 8 cores in total. This implies a rank ordering of 0-7. Ranks on a device are determined by the number of tiles on which a process can run: with n such tiles/cores, ranks 0 to n-1 can be instantiated. Frameworks like TF generally use these ranks for different collective calls, and this is how the distributed training of large models happens.
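The rank arithmetic can be sketched in a few lines of Python. The function name is hypothetical, and the chip-by-chip ordering of ranks is an assumption made for illustration; the actual enumeration is decided by the runtime.

```python
def core_ranks(num_chips, cores_per_chip):
    """Assign consecutive ranks to (chip, core) pairs, chip by chip (assumed ordering)."""
    return {
        chip * cores_per_chip + core: (chip, core)
        for chip in range(num_chips)
        for core in range(cores_per_chip)
    }


ranks = core_ranks(num_chips=4, cores_per_chip=2)
for rank, (chip, core) in sorted(ranks.items()):
    print(f"rank {rank}: chip {chip}, core {core}")
```

With 4 chips and 2 cores per chip this yields the 8 ranks, 0 through 7, described above.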

Rank 0 is the master, responsible for synchronizing all processes as they perform their own computing. For example, rank 7 may be doing a “multiplication” op in the MLIR/XLA graph, while rank 6 waits on the result of rank 7 to perform an “addition” op with input from rank 5. This is mediated by rank 0, the master. Below is the standard Python code to instantiate TPUs before running a BERT model or one of its variants:

import tensorflow as tf

try:
    # TPU detection.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    # Connect to the TPU cluster and initialize it before creating the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # Default distribution strategy in Tensorflow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)
# Let tf.data tune its parallelism automatically.
AUTO = tf.data.experimental.AUTOTUNE

# Configuration of hyperparameters
# The global batch is partitioned amongst the cluster replicas.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
MAX_LEN = 192
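The batch-size line scales the per-replica batch of 16 by the number of replicas. As a rough illustration of how such a global batch could be divided, the sketch below splits it into equal contiguous shards in plain Python. The helper name and the contiguous split are assumptions made for illustration, not the mechanism TensorFlow uses internally.

```python
def shard_batch(global_batch, num_replicas):
    """Split a global batch into equal contiguous per-replica shards."""
    per_replica = len(global_batch) // num_replicas
    return [
        global_batch[r * per_replica:(r + 1) * per_replica]
        for r in range(num_replicas)
    ]


num_replicas = 8                     # e.g. the 8 cores of a TPU v3-8
batch_size = 16 * num_replicas       # mirrors BATCH_SIZE above
shards = shard_batch(list(range(batch_size)), num_replicas)
print(batch_size, [len(s) for s in shards])
```

Each replica then processes 16 examples per step, while the optimizer sees the effect of the full global batch once the gradients are combined.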

The distribution strategy defined here (a “TPUStrategy” when a TPU is found, otherwise the default returned by “get_strategy”) encapsulates the collective calls, which we will look at in depth. XLA relies on these collective calls for data transfers and computing. The BERT training code then simply builds the model inside the strategy scope, which determines how data is circulated from the CPU across the TPU cores, and starts a run:

from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Take tokenized input IDs and feed the [CLS] token output of the transformer to a classifier.
def build_model(transformer, max_len=512):
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    sequence_output = transformer(input_word_ids)[0]
    cls_token = sequence_output[:, 0, :]
    out = Dense(1, activation='sigmoid')(cls_token)
    model = Model(inputs=input_word_ids, outputs=out)
    model.compile(Adam(learning_rate=1e-5), loss='binary_crossentropy', metrics=['accuracy'])
    return model
#define strategy for TPU can be cyclic reduce or hierarchical
with strategy.scope():
    transformer_layer = (
    model = build_model(transformer_layer, max_len=MAX_LEN)

#Standard training loop in Tensorflow
n_steps = train_x.shape[0] // BATCH_SIZE
train_history = model.fit(

The entire code and detailed explanation are in my Kaggle notebook.

Collective Calls in Tensorflow XLA

The standard collective calls for data transfer in large-scale model training (such as BERT Large or sharded GPT on TPUs/GPUs) include:

Broadcast: The Broadcast operation copies an N-element buffer on the root rank to all ranks. This is generally done at the start of model creation and data loading, when the CPU and CPU-GPU buffers are initialized.


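The XLA builder also exposes a “Broadcast” op. Note that this op works on array shapes rather than ranks: it replicates the operand along new leading dimensions. Its syntax is like so:

XlaBuilder b("broadcast");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
// Signature: Broadcast(XlaOp operand, absl::Span<const int64_t> broadcast_sizes)
auto out = Broadcast(x, /*broadcast_sizes=*/{32});  // result shape: f32[32, 4, 16]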
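The collective broadcast described above (root rank to all ranks) can be simulated in plain Python. This is a toy model under stated assumptions: each rank is represented by a list, and the function name is hypothetical rather than an XLA API.

```python
def broadcast(buffers, root=0):
    """Copy the root rank's buffer to every rank (returns new per-rank buffers)."""
    return [list(buffers[root]) for _ in buffers]


# Four ranks; only rank 0 holds initialized data.
buffers = [[1.0, 2.0, 3.0], [0.0] * 3, [0.0] * 3, [0.0] * 3]
after = broadcast(buffers, root=0)
print(after)
```

After the call, every rank holds a copy of the root's three values, which is the state a framework typically wants before training starts.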

AllReduce: The AllReduce operation performs reductions on data (for example, sum or max) across devices and writes the result to the receive buffers of every rank. The operation is rank-agnostic: reordering the ranks does not affect the outcome. It is used on multi-card systems to speed up computation, for example to combine the partial results of intermediate matrix multiplications (GEMM), fused softmax activations, or the updates of adaptive optimizers like Adam.



// `operand`, `scalar_shape`, and `replica_groups` are supplied by the caller.
XlaBuilder c("allreduce");
// Build the scalar reduction computation in a sub-builder.
auto b = c.CreateSubBuilder("sum");
auto x = b->Parameter(/*parameter_number=*/0, scalar_shape, "x");
auto y = b->Parameter(/*parameter_number=*/1, scalar_shape, "y");
if (scalar_shape.element_type() == PRED) {
      Or(x, y);   // boolean data: logical OR
} else {
      Add(x, y);  // numeric data: sum
}
TF_ASSIGN_OR_RETURN(auto computation, b->Build());
AllReduce(operand, computation, replica_groups,
          /*channel_id=*/absl::nullopt);

In this case, we are performing a “special” all-reduce: a cross-replica summation of the data across shards/replicas (a logical OR is used for boolean data). The XLA syntax is similar to “broadcast.”
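The behaviour of an all-reduce, including its independence from rank order, can be simulated with a short Python sketch. As before, ranks are modelled as lists and the function name is hypothetical; real implementations use ring or tree algorithms rather than this sequential loop.

```python
def all_reduce(buffers, op=lambda a, b: a + b):
    """Reduce element-wise across ranks and give every rank the full result."""
    reduced = list(buffers[0])
    for buf in buffers[1:]:
        reduced = [op(r, v) for r, v in zip(reduced, buf)]
    return [list(reduced) for _ in buffers]


grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one buffer per rank
summed = all_reduce(grads)
print(summed)
print(all_reduce(grads[::-1]) == summed)  # rank order does not matter for a sum
print(all_reduce(grads, op=max))
```

Passing a different `op` (here `max`) changes the reduction without changing the communication pattern.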

Reduce: The Reduce operation performs the same operation as AllReduce but writes the result only to the receive buffers of a specified root rank.

[Figure: Reduce]

The syntax of the XLA Reduce op, which applies a reduction computation along the given dimensions of an array, is similar to that of the AllReduce collective:

XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

There are other variations of reduction algorithms, such as reduce-scatter and reduce-window, which are beyond the scope of this tutorial.

AllGather: In the AllGather operation, each of the K processors aggregates N values from every processor into an output of dimension K*N. The output is ordered by rank index.

[Figure: AllGather]

A reduce-scatter followed by an all-gather is equivalent to the all-reduce operation that we saw earlier. The XLA syntax is similar to before:

XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, 
               int64_t shard_count,absl::Span<const ReplicaGroup> replica_groups = {},
               const std::optional<ChannelHandle>& channel_id = std::nullopt, 
               const std::optional<Layout>& layout = std::nullopt,
               const std::optional<bool> use_global_device_ids = std::nullopt);

These are the major algorithms that govern data transfer across hosts and devices. Most of the details can be found in the TensorFlow documentation, which outlines almost all the algorithms that TF uses for its compute and collective operations through XLA.
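The relationship between reduce-scatter, all-gather, and all-reduce noted in this section can be checked with a small simulation. The function names are hypothetical, ranks are modelled as Python lists, and the buffer length is assumed to be divisible by the number of ranks.

```python
def reduce_scatter(buffers):
    """Sum across ranks, then leave rank r holding only chunk r of the result."""
    k = len(buffers)
    n = len(buffers[0]) // k  # chunk size; assumes the length is divisible by k
    total = [sum(vals) for vals in zip(*buffers)]
    return [total[r * n:(r + 1) * n] for r in range(k)]


def all_gather(chunks):
    """Concatenate every rank's chunk in rank order and give the result to all ranks."""
    gathered = [v for chunk in chunks for v in chunk]
    return [list(gathered) for _ in chunks]


def all_reduce_sum(buffers):
    """Element-wise sum across ranks, replicated on every rank."""
    total = [sum(vals) for vals in zip(*buffers)]
    return [list(total) for _ in buffers]


buffers = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]  # K = 2 ranks
print(all_gather(reduce_scatter(buffers)))
print(all_reduce_sum(buffers))
```

Both lines print the same per-rank buffers, which is why many all-reduce implementations are built from these two phases.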

Creating a Custom XLA Call to GPU (Softmax Activation)

Code written for GPU devices with CUDA kernels is usually segregated into device and host scopes, which determine which part of the program runs on which device. In CUDA, a segment of code that runs on the device (GPU) and is launched from the host is marked with the “__global__” declaration. For instance, here is the same “ax+y” computation on arrays as a CUDA kernel:

__global__ void example_func(int n, float a, float *x, float *y) {
  // Each thread handles one element of the arrays.
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if (i < n) y[i] = a*x[i] + y[i];
}
// To inspect the result, copy y back to the host (cudaMemcpy) and print it there.

Now let us create a similar function in XLA and create softmax activation of inputs. We will not need MLIR in this case, as we will be using the XLA Builder class directly; the MLIR backend will, in turn, be created by Tensorflow.

The first step is to design the softmax activation (iteratively) using C++ and a float * array. To recollect, the softmax activation appears like so:

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}$$

After dividing the numerator and denominator by e^m (m being the maximum of the inputs z) and moving the denominator into the exponent through its logarithm, the softmax activation can be rewritten as:

$$\mathrm{softmax}(z)_i = \exp\left(z_i - m - \log S\right)$$

where the sum S is denoted as:

$$S = \sum_{j} e^{z_j - m}$$


Illustration of the Concept


#include <cmath>
#include <cstddef>

// Numerically stable softmax, computed in place on an array of `size` floats.
void softmax(float* input, size_t size) {
    size_t i;
    float m, sum, constant;
    // Find the maximum so that the exponentials below cannot overflow.
    m = -INFINITY;
    for (i = 0; i < size; ++i) {
        if (m < input[i]) {
            m = input[i];
        }
    }
    sum = 0.0;
    for (i = 0; i < size; ++i) {
        sum += exp(input[i] - m);
    }
    constant = m + log(sum);
    for (i = 0; i < size; ++i) {
        input[i] = exp(input[i] - constant);
    }
}
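To see why the max-shifted form is worth the extra pass, the following Python sketch compares it with the direct definition. The function names are hypothetical; the stable version follows the same three passes (max, sum, exponentiate) as the routine in this section.

```python
import math


def softmax_stable(values):
    """Softmax via exp(z_i - (m + log(sum_j exp(z_j - m))))."""
    m = max(values)
    total = sum(math.exp(v - m) for v in values)
    constant = m + math.log(total)
    return [math.exp(v - constant) for v in values]


def softmax_naive(values):
    """Direct definition; math.exp overflows for large inputs."""
    exps = [math.exp(v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


z = [1.0, 2.0, 3.0]
print(softmax_stable(z))
print(sum(softmax_stable(z)))
# Shifting every input by the same amount leaves softmax unchanged,
# and the stable form handles inputs where the naive form would overflow.
print(softmax_stable([1000.0, 1001.0, 1002.0]))
```

For moderate inputs both versions agree to rounding error; only the stable one survives inputs around 1000, where `math.exp` raises an overflow error in the naive version.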


The next step is to write CPU code for XLA using the XlaBuilder class from “xla_builder.h”. XLA needs to know the dimensions of the input array as well as the buffer size of the output. To start, we create the XlaBuilder object as follows:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

void compute_softmax() {
  xla::XlaBuilder b("compute_softmax");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {128}));
}

First, we include the XLA compiler client API, which provides the XlaBuilder class. The “compute_softmax” method is a CPU-only method, and when initializing an XlaBuilder instance we have to give it a name, “compute_softmax” in this example. The next step is to declare the variables/literals, since XLA needs to know the input and output buffer sizes. “param0” is the input array of size 128, and “custom_call” calls the “do_custom_call” function (defined below). The custom-call target can be implemented for the CPU as a plain C++ function, or for the GPU as a function that launches a “__global__” CUDA kernel (as mentioned in the previous section).

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);

  // Copy the input into the output buffer; the softmax routine written
  // earlier can then be applied to out_buf in place.
  for (int i = 0; i < 128; ++i) {
    out_buf[i] = in0[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Here the input array (const float* in0) is copied into “out_buf”, the output buffer that XLA preallocated, which is where the result of the “softmax” C++ function created before must end up. The last line registers the function with XLA and binds it to a specific device: “Host”, that is, the CPU.

Now let us write the GPU part of this code, where we need access to the block dimensions and thread IDs. For a general introduction to how CUDA organizes blocks and threads, refer to the introductory CUDA manual. Here we parallelize the compute on the GPU device by launching a CUDA kernel that can call the same softmax function. We only have to change the softmax function definition by adding the “__device__” keyword as follows:

// In-place, numerically stable softmax callable from device code.
__device__ void softmax(float* input, size_t size) {
    size_t i;
    float m, sum, constant;
    m = -INFINITY;
    for (i = 0; i < size; ++i) {
        if (m < input[i]) {
            m = input[i];
        }
    }
    sum = 0.0;
    for (i = 0; i < size; ++i) {
        sum += exp(input[i] - m);
    }
    constant = m + log(sum);
    for (i = 0; i < size; ++i) {
        input[i] = exp(input[i] - constant);
    }
}

// Each thread copies one element from the input to the output buffer.
__global__ void custom_call_kernel(const float* in0, float* out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx];
}


The softmax function is the same as before, but the “__device__” qualifier marks it as code that runs on the GPU and is callable from a kernel, while the CPU stays in control of launching and synchronization. This is one way of creating the kernel; another is to put the softmax computation entirely inside custom_call_kernel (CUDA), like so:

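__global__ void custom_call_kernel(const float* in0, float* out) {
    const int N = 128;  // matches the F32[128] shape declared to XLA
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N) return;
    // Every thread redundantly scans the whole input for the max and the sum.
    float m = -INFINITY;
    for (int i = 0; i < N; ++i) {
        if (m < in0[i]) {
            m = in0[i];
        }
    }
    float sum = 0.0f;
    for (int i = 0; i < N; ++i) {
        sum += exp(in0[i] - m);
    }
    float constant = m + log(sum);
    // Each thread writes only its own output element.
    out[idx] = exp(in0[idx] - constant);
}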


Although this is an unoptimized way to write kernels (we should use reductions), this gives a brief overview of writing custom CUDA kernels for XLA. The next part is to allocate and create the Blocks and Grids for CUDA:

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  float* out = reinterpret_cast<float*>(buffers[1]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 128 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

We receive the “CUstream” CUDA stream and the input/output buffers, and assign the block dimensions. Since the buffer size is 128 and each block has 64 threads, the grid dimension is 128/64, which means a grid of two blocks: block 0 and block 1. This completes the tutorial on creating custom XLA kernels for CPU and GPU, which TensorFlow can use alongside CUDA/cuBLAS. The important thing to note is that TF XLA relies on preallocated memory buffers for computation. As an exercise, try creating a similar custom call with any of the collective operations (Reduce/AllGather).
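The launch geometry can be checked by enumerating the indices that the kernel's `blockIdx.x * blockDim.x + threadIdx.x` expression produces. The sketch below does this in plain Python; the function name is hypothetical, and it assumes the element count is a multiple of the block size, as in the 128/64 example.

```python
def launch_indices(total, block_dim):
    """Enumerate (block, thread, global index) for a 1-D launch covering `total` elements."""
    grid_dim = total // block_dim  # assumes total is a multiple of block_dim
    return grid_dim, [
        (block, thread, block * block_dim + thread)
        for block in range(grid_dim)
        for thread in range(block_dim)
    ]


grid_dim, indices = launch_indices(total=128, block_dim=64)
print(grid_dim)
print(indices[0], indices[63], indices[64], indices[127])
```

Every element index from 0 to 127 is produced exactly once, so each thread owns one element of the 128-float buffer.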

XLA in TF Training

In this final section, we will briefly see how to use TF XLA at the application level from Python (just as we create TensorFlow models in Python). TF compilers use both JIT and AOT (ahead-of-time) compilation for different XLA compute devices; we will not go into depth about AOT. When using the XLA compiler with JIT, the general syntax is as follows:

# `layer` and `optimizer` are assumed to be defined beforehand.
@tf.function(jit_compile=True)
def train_code(data, labels):
    # Create the gradient tape
    with tf.GradientTape() as tape:
      # Pass the data through the layer
      predicted_labels = layer(data)
      # Compute the cross-entropy loss
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    # Backprop: gradients of the loss w.r.t. the layer variables
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    # Apply the gradients through the optimizer
    optimizer.apply_gradients(zip(grads, layer_variables))

This is a generic training step in TensorFlow code, where we specify the inputs and create the scope of “GradientTape”. The next steps pass the input through the model or layer (for example, a tf.keras.layers layer) and then perform backprop and optimizer updates. The important thing to note is the “@tf.function(jit_compile=True)” decorator. TF uses the “jit_compile” flag to ensure that the XLA compute backend is triggered. XLA then creates the backend C++ code and MLIR dialects that we studied in the previous sections. Depending on the device of allocation, XLA creates separate kernels and triggers separate collective calls through its “XlaBuilder” API. Another part of XLA relies heavily on AOT when dealing with large-scale auto-clustering on TPUs, which is beyond the scope of this tutorial.


The XLA backend is thus a powerful tool in TF for optimizing different computations across devices. It deals with memory allocation, data transfers, sharding, fusing computations, and optimizing node-level graph dependencies.

This article gave a brief overview of its abilities through the TensorFlow C++ API, where most of the low-level optimizations can be made manually. TF XLA is under active development and has several additions lined up for better performance.
