Fine-tune and Deploy a Custom PaliGemma Model for Your Image Tasks

Akash Das 31 May, 2024
12 min read

Introduction

PaliGemma is an open, state-of-the-art vision-language model (VLM) that Google released alongside other products at Google I/O 2024. It combines two open components developed by Google, the SigLIP vision model and the Gemma language model, and draws inspiration from PaLI-3. PaliGemma is flexible and lightweight, supports several languages, and generates text output from image and text inputs. It is intended as a base model for a range of vision-language tasks, including reading text, object detection and segmentation, visual question answering, and captioning images and short videos.

Unlike other VLMs that have struggled with object detection and segmentation, notably OpenAI’s GPT-4o, Google Gemini, and Anthropic’s Claude 3, PaliGemma supports a wide variety of tasks and can be fine-tuned for better performance on specific ones.


In this blog, we will walk through the pipeline for fine-tuning PaliGemma and deploying it with a cloud provider. Throughout the tutorial, we will use Roboflow to access a dataset in the required format, Kaggle to load the model weights, and Azure Virtual Machines for deployment. A Colab instance with an NVIDIA T4 GPU is sufficient for fine-tuning.

Learning Objectives

In this blog, you will learn:

  • About the PaLiGemma model and its components.
  • How to set up the environment for fine-tuning PaLiGemma.
  • Data preparation techniques in JSONL format.
  • The process of downloading and configuring PaLiGemma model weights.
  • Steps for fine-tuning PaLiGemma and saving the fine-tuned model.
  • Deployment strategies for the fine-tuned model using Azure Virtual Machines.

This article was published as a part of the Data Science Blogathon.

Before We Begin

Before reading this blog, you should be familiar with Python programming and the training process for large language models (LLMs). Although not compulsory, having a rudimentary understanding of JAX (or related technologies like Keras) would be beneficial when examining the sample code snippets.

To fine-tune PaliGemma, we will follow these steps:

  1. Install the required dependencies
  2. Download any image dataset in PaliGemma JSONL format
  3. Download pre-trained PaliGemma weights and tokenizer from Kaggle
  4. Finetune PaLiGemma using JAX
  5. Save our model for later use
  6. Deploy the finetuned model

Step 1: Install and Set Up the Model

PaliGemma and Kaggle Setup

First-time users must request access to PaliGemma through Kaggle and configure a Kaggle API key. The steps are as follows:

  • Log In or Sign Up on Kaggle: Log in to your Kaggle account, or create one if you don’t have one.
  • Request Access to PaliGemma: Go to the PaLiGemma model card on Kaggle, click “Request Access,” complete the consent form, and accept the terms and conditions.
  • Generate Kaggle API Key: Open your Settings page on Kaggle and click “Create New Token” to download the `kaggle.json` file containing your API credentials.
  • Add Kaggle API Key to Colab: In Colab, select “Secrets” (🔑) in the left pane and add your Kaggle username and API key. Store your username under `KAGGLE_USERNAME` and your API key under `KAGGLE_KEY`.
  • Store Credentials Securely: Ensure your Kaggle API key is stored securely and only used as needed to access Kaggle datasets or models.

Once all is done, set the environment variables as shown below.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Fetch the big_vision repository and dependencies

To fine-tune PaliGemma, we will use the big_vision project maintained by Google Research. The code below clones the repository and installs the dependencies it needs in your notebook.

import os
import sys

# Colab runtimes with remote TPUs are not supported by this notebook.
if "COLAB_TPU_ADDR" in os.environ:
  raise RuntimeError(
      "It seems you are using Colab with remote TPUs, which is not supported.")

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"


Import JAX and dependencies

The code below imports JAX and the other libraries needed to set up the model.

import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {jax.default_backend()}")
print(f"JAX devices:  {jax.device_count()}")


Step 2: Choose Suitable Data for Your Task and Prepare It in JSONL Format

Fine-tuning PaliGemma requires the data in the PaliGemma JSONL format. You might not be familiar with JSONL (JSON Lines), since it is less common for image tasks than formats such as YOLO, but it is often used for training large models because each line is an independent JSON object, which allows efficient line-by-line processing. Below is a generic example of data stored as JSONL:

{"name": "John Doe", "age": 30, "city": "New York"}
{"name": "Jane Smith", "age": 25, "city": "Los Angeles"}
{"name": "Sam Brown", "age": 22, "city": "Chicago"}

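Creating data in JSONL format is straightforward. The sample code below writes one record per image with generic `image_path` and `label` fields; adapt the field names to the schema your training pipeline expects.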

import json
import os

# Directory containing the images
image_dir = '/path/to/images'

# Dictionary containing the image labels
labels = {
    "image1.jpg": "label1",
    "image2.jpg": "label2",
    "image3.jpg": "label3"
}

# Create a list of dictionaries with image path and label
data = []
for image_name, label in labels.items():
    image_path = os.path.join(image_dir, image_name)
    data.append({"image_path": image_path, "label": label})

# Write the data to a JSONL file
with open('images_labels.jsonl', 'w') as file:
    for entry in data:
        file.write(json.dumps(entry) + '\n')
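For PaliGemma specifically, datasets in its JSONL format (such as Roboflow's PaliGemma export and the example data read by big_vision) use, as far as I know, an `image` field plus a `prefix` (the task prompt) and a `suffix` (the expected output) rather than a single `label`. Below is a sketch of writing records in that shape, assuming those field names and a `caption en` prompt for a captioning task; verify the schema against your own training code. The helper name is hypothetical.

```python
import json
from typing import Iterable, Tuple


def write_paligemma_jsonl(records: Iterable[Tuple[str, str]], out_path: str,
                          prefix: str = "caption en") -> int:
    """Write (image_file, target_text) pairs as prefix/suffix JSONL records.

    Assumed schema: "image" is the file name relative to the dataset directory,
    "prefix" is the task prompt, and "suffix" is the text the model should output.
    Returns the number of records written.
    """
    count = 0
    with open(out_path, "w", encoding="utf-8") as f:
        for image_name, target_text in records:
            entry = {"image": image_name, "prefix": prefix, "suffix": target_text}
            f.write(json.dumps(entry) + "\n")
            count += 1
    return count
```

With the `labels` dictionary from the previous snippet, `write_paligemma_jsonl(labels.items(), "data_train.jsonl")` would turn each label into a caption-style suffix.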

Here, however, we will use Roboflow, which supports the PaliGemma JSONL format and can export any dataset from Roboflow Universe in it. Choose a dataset that suits your task and download it with your Roboflow API key, as shown in the snippet below.

# Install the required dependencies to download and parse a dataset
!pip install roboflow supervision

from google.colab import userdata
from roboflow import Roboflow

ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')

rf = Roboflow(api_key=ROBOFLOW_API_KEY)
# Replace the workspace, project, and version with those of your dataset.
project = rf.workspace("workspace-user-id").project("sample-project-name")
VERSION_NUMBER = 1  # placeholder: enter your dataset's version number
version = project.version(VERSION_NUMBER)
dataset = version.download("PaliGemma")

With the environment set up and the dataset downloaded in the required format, we can now obtain the PaliGemma weights for fine-tuning.

Step 3: Download and Configure PaLiGemma Model Weights

This step downloads the PaliGemma weights from Kaggle. To keep compute requirements low, we use the paligemma-3b-pt-224 checkpoint. The JAX/Flax PaliGemma 3B model is available in three versions that differ in input image resolution (224, 448, and 896 pixels) and input text sequence length (128, 512, and 512 tokens, respectively).
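The resolution determines how many image tokens the language model has to process. A rough sketch of the arithmetic, assuming one token per non-overlapping 14×14 patch and no pooling (the model configuration later in this step uses the `So400m/14` SigLIP variant with `pool_type` set to `none`):

```python
# Assumption: one image token per non-overlapping 14x14 patch, no pooling.
PATCH_SIZE = 14

# (image resolution in pixels) -> (max text tokens) for the three 3B variants.
VARIANTS = {224: 128, 448: 512, 896: 512}


def num_image_tokens(resolution: int, patch_size: int = PATCH_SIZE) -> int:
    """Number of patches (and hence image tokens) for a square input image."""
    patches_per_side = resolution // patch_size
    return patches_per_side * patches_per_side


for resolution, text_len in VARIANTS.items():
    n_img = num_image_tokens(resolution)
    print(f"{resolution}px: {n_img} image tokens + {text_len} text tokens "
          f"= {n_img + text_len} total")
```

For the 224 variant this gives 256 image tokens, so together with 128 text tokens each example spans 384 positions; the 896 variant needs 16 times as many image tokens, which is one reason the 224 checkpoint is the practical choice on a T4.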

Run the following code to download the float16 model checkpoint from Kaggle, along with the model tokenizer and the longcap100 example dataset from Google Cloud Storage. The download may take some time.

import os
import kagglehub

MODEL_PATH = "./pt_224_128.params.f16.npz"
if not os.path.exists(MODEL_PATH):
  MODEL_PATH = kagglehub.model_download(
      'google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")

Code Output

PaLiGemma model weights being downloaded

The next step is to configure the model and move it onto the Colab T4 GPU. First, define `model_config` as an `ml_collections.FrozenConfigDict`, an immutable configuration dictionary, and use it to create an instance of `paligemma.Model`. Then load the tokenizer and the model parameters into host RAM, and define a `decode` function for sampling outputs from the model. Finally, to fit within the T4's memory, the code marks only the attention layers of the language model as trainable, keeps the remaining weights frozen in float16, casts the trainable weights to float32, and moves the parameters to the GPU one at a time. The code below performs both steps.

# Define model
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True,
     "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(),
                           eos_token=tokenizer.eos_id())

#Move model to T4 GPU
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
# Axis names must be a tuple, hence the trailing comma in ("data",).
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,
                      params, trainable)

# Loading all params simultaneously, although faster and more succinct,
# requires more RAM than the T4 Colab runtimes have by default.
# Instead, we do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(
    zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
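The reason for freezing most weights in float16 and training only the attention layers becomes clearer with a back-of-the-envelope estimate. The sketch below treats the model as roughly 3 billion parameters (an approximation taken from the model name) and ignores activations and framework overhead, so real usage is higher; an NVIDIA T4 has 16 GB of memory.

```python
GIB = 2 ** 30  # bytes per GiB


def estimate_gib(num_params: float, bytes_per_value: int, copies: int = 1) -> float:
    """Rough memory for `copies` arrays of `num_params` values each, in GiB."""
    return num_params * bytes_per_value * copies / GIB


NUM_PARAMS = 3e9  # approximation based on the "3B" in the model name

weights_f16 = estimate_gib(NUM_PARAMS, 2)            # all weights frozen in float16
full_ft_f32 = estimate_gib(NUM_PARAMS, 4, copies=2)  # float32 weights + float32 gradients
print(f"float16 weights only:            {weights_f16:.1f} GiB")
print(f"full fine-tuning (f32 + grads):  {full_ft_f32:.1f} GiB")
```

Frozen float16 weights need about 5.6 GiB, whereas keeping every weight and its gradient in float32 would need roughly 22 GiB before counting activations, more than the T4 provides.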

Code Output

An overview of the PaLiGemma parameters after model download and configuration

With the model downloaded and configured, we have everything needed for fine-tuning and can move on to the next step.


Step 4: Finetuning PaLiGemma

Before fine-tuning, a few more preprocessing steps and checks are needed. These are standard procedures whose code is lengthy, so they are outside the scope of this post; details can be found in the open-source resources mentioned in later sections. A broad overview of the steps follows.

  1. Create Model Inputs
    • Preprocess images: convert greyscale images to three RGB channels, remove any alpha channel, resize them to 224×224 pixels, and scale pixel values to the range [-1, 1].
    • Tokenize the text prefix (the prompt) and suffix (the target), adding masks that mark which tokens belong to each so that only suffix tokens are used in the training loss.
    • Post-process model outputs by removing the end-of-sequence (EOS) token and any tokens after it, then decoding the remaining tokens.
  2. Create Training and Validation Iterators
    • Define a training iterator to process data in chunks, shuffle examples, and repeat them for multiple epochs. Preprocess images and tokenize text with appropriate flags.
    • Define a validation iterator to process validation data in order, preprocess images, and tokenize text.
  3. View Training Examples
    • Display a random selection of training images and their descriptions to understand the data on which the model is being trained.
  4. Define Training and Evaluation Loops
    • Implement a stochastic gradient descent (SGD) training loop to optimize the model parameters. Calculate the loss per example, excluding prefixes and padded tokens from the loss calculation.
    • Implement an evaluation loop to make predictions on the validation dataset, handle padding for small datasets, and ensure only actual examples are counted in the output.

With these steps in place, we can fine-tune the model using the code below. It runs the training loop for 64 steps (512 training examples at a batch size of 8), printing the learning rate (lr) and loss at each step. Every 16 steps, it outputs the model’s predictions for the same set of validation images, so you can observe the model’s descriptions improving. Early in training, predictions may contain errors such as repeated or incomplete sentences, but their accuracy improves as training progresses. By step 64, the model’s predictions should closely match the descriptions in the training data.

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4

train_data_it = train_data_iterator()

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
    display(HTML(html_out))
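The step counts in the training code follow from simple arithmetic: 512 examples at a batch size of 8 give 64 steps, and evaluating four times per run means every 16 steps. The sketch below reproduces that arithmetic together with a generic linear-warmup plus cosine-decay schedule matching the `decay_type="cosine"` and `warmup_percent=0.10` settings. big_vision's `create_learning_rate_schedule` may differ in details such as how warmup steps are rounded, so treat this as an approximation rather than its implementation.

```python
import math

# Values from the training code in this step.
BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE  # number of optimizer steps
EVAL_STEPS = TRAIN_STEPS // 4               # evaluate four times per run


def warmup_cosine_lr(step: int, total_steps: int, base_lr: float,
                     warmup_percent: float = 0.10) -> float:
    """Linear warmup to base_lr, then cosine decay to zero (generic formulation)."""
    warmup_steps = int(total_steps * warmup_percent)
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)  # clamp for steps past the end
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total_steps = TRAIN_STEPS + 1  # mirrors total_steps=TRAIN_STEPS+1 in the training code
for step in (1, 6, EVAL_STEPS, 2 * EVAL_STEPS, TRAIN_STEPS):
    print(f"step {step:2d}: lr = {warmup_cosine_lr(step, total_steps, LEARNING_RATE):.5f}")
```

The short warmup avoids large updates while the newly trainable float32 weights are still adjusting, and the cosine decay lets the learning rate fall smoothly toward zero by the final step.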

You can now test the fine-tuned model using the helper function `make_predictions` (one of the helpers omitted above), which iterates over the validation images and runs inference on them in batches. The same function can be used to test a fine-tuned object detection model.

print("Model predictions")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))
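When the model is fine-tuned for object detection, its text output encodes boxes as location tokens rather than plain coordinates. Below is a sketch of converting such output to pixel boxes, assuming the convention PaliGemma uses for detection as I understand it: four `<locNNNN>` tokens per box in y_min, x_min, y_max, x_max order, each a bin index over 1024 bins, followed by the label, with multiple detections separated by `;`. The parser is hypothetical; check the format against your own model's outputs.

```python
import re
from typing import List, Tuple

# Four <locNNNN> tokens followed by a label; detections are separated by ";".
DETECTION_PATTERN = re.compile(r"((?:<loc\d{4}>){4})\s*([^;<]+)")
LOC_PATTERN = re.compile(r"<loc(\d{4})>")


def parse_detections(text: str, width: int, height: int,
                     num_bins: int = 1024) -> List[Tuple[str, float, float, float, float]]:
    """Return (label, x_min, y_min, x_max, y_max) boxes in pixel coordinates.

    Assumes each box is encoded as y_min, x_min, y_max, x_max bin indices.
    """
    boxes = []
    for locs, label in DETECTION_PATTERN.findall(text):
        # Convert bin indices to fractions of the image size.
        y0, x0, y1, x1 = (int(v) / num_bins for v in LOC_PATTERN.findall(locs))
        boxes.append((label.strip(), x0 * width, y0 * height, x1 * width, y1 * height))
    return boxes
```

For example, `parse_detections("<loc0256><loc0128><loc0768><loc0896> cat", 1024, 1024)` returns `[("cat", 128.0, 256.0, 896.0, 768.0)]`; plain captions with no location tokens yield an empty list.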

Below is a sample of the model outputs at successive steps. For this demo, fine-tuning was run for only 30 steps; choose the dataset, number of steps, and other hyperparameters to suit your own use case and requirements.

A comparison of the evolution in the responses as each finetuning step proceeded. Notice how the responses improved over each step.

Step 5: Saving the Finetuned Model

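Once fine-tuning is complete and the model predictions have been checked, save the parameters with the code below so the model can be reused or deployed later. Note that the trainable attention parameters were cast to float32 before training, so despite the .f16 suffix in the file name, the saved file contains a mix of float16 and float32 arrays: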

flat, _ = big_vision.utils.tree_flatten_with_names(params)
with open("/content/fine-tuned-PaliGemma-3b-pt-224.f16.npz", "wb") as f:
  np.savez(f, **{k: v for k, v in flat})
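To check what was written, the saved file can be read back with NumPy. Below is a minimal sketch that rebuilds the nested parameter structure from slash-separated names such as `llm/layers/attn/...` (the naming used by the trainable-parameter mask earlier); the helper name is hypothetical, and it is only for inspection, not for loading the weights into the model.

```python
import numpy as np


def load_flat_npz_as_tree(path: str) -> dict:
    """Rebuild a nested dict from an .npz whose keys look like "llm/layers/attn/q"."""
    tree: dict = {}
    with np.load(path) as data:
        for name in data.files:
            *parents, leaf = name.split("/")
            node = tree
            for part in parents:  # create intermediate dicts as needed
                node = node.setdefault(part, {})
            node[leaf] = data[name]  # materialize the array before the file closes
    return tree
```

For example, `load_flat_npz_as_tree("/content/fine-tuned-PaliGemma-3b-pt-224.f16.npz")["llm"]` gives the language-model parameters, and printing each array's `dtype` shows the float16/float32 mix.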

Step 6: Deploying the Finetuned Model

For deployment, we will rely on the Roboflow Inference Server running on a cloud virtual machine. The example uses an AWS EC2 instance, but the same steps apply to an Azure Virtual Machine, because the server only needs Docker. The Roboflow Inference Server lets you deploy computer vision models to a variety of devices and runs inside Docker. If Docker is not already installed on the machine where you want to run inference, install it by following the official Docker installation instructions. Once Docker is installed, run the following commands on the instance to install the Roboflow Inference CLI and start the server, which pulls the server's Docker image:

pip install inference-cli
inference server start

The Roboflow Inference Server will then be running on the instance (on port 9001 by default), and you can send it inference requests for the fine-tuned model. Make sure the server can load your fine-tuned weights; check the Roboflow Inference documentation for how to serve custom models.


Conclusion

In this blog, we walked through the full process of fine-tuning and deploying PaliGemma, a cutting-edge vision-language model from Google. Starting with installing the necessary dependencies and setting up the environment, we used Kaggle to access the model weights, Roboflow to prepare the dataset, and a cloud virtual machine (Azure or AWS EC2) for deployment. By following these steps, you can harness PaliGemma for vision-language tasks such as object detection, image captioning, and visual question answering. I hope this guide provides a clear and practical pathway for bringing these capabilities into your own projects.

References

In addition to this blog, here are a few more interesting reads and inspirations for this blog.

Key Takeaways

  • Integration of Advanced Models: PaLiGemma combines the capabilities of SigLIP and Gemma, providing a versatile and lightweight vision-language model that excels in multiple languages and tasks.
  • Enhanced Vision-Language Capabilities: Unlike many other VLMs, PaLiGemma effectively handles object detection and segmentation, making it a robust choice for various vision-language activities, including text reading, visual question answering, and image/video captioning.
  • Step-by-Step Fine-Tuning Process: The tutorial provides a detailed, step-by-step guide to fine-tuning PaLiGemma, covering essential steps such as setting up dependencies, preparing data, and configuring model weights using JAX.
  • Efficient Use of Resources: The tutorial demonstrates efficient resource management and practical deployment strategies by utilizing tools like Roboflow for dataset preparation, Kaggle for model weights, and Azure Virtual Machines for deployment.
  • Practical Application and Deployment: The guide culminates in deploying the fine-tuned model on an EC2 server, showcasing how to apply theoretical knowledge to practical situations and enabling users to leverage PaLiGemma’s capabilities in real-world scenarios.

The media shown in this article are not owned by Analytics Vidhya and are used at the Author’s discretion.

Frequently Asked Questions

 Q1. What prerequisites do I need to fine-tune and deploy the PaLiGemma model?

A. You must be familiar with Python programming and have experience training large language models (LLMs). Knowledge of JAX or Keras is beneficial for understanding the code snippets. Additionally, you’ll need access to Kaggle to download the model weights and datasets and an Azure account to deploy the model.

Q2. How do I access the PaLiGemma model and its weights?

A. First, log in to your Kaggle account and request access to the PaliGemma model through its model card on Kaggle. Accept the terms and generate an API key from your Kaggle settings. Store the API key securely (for example, in Colab Secrets) and use it to download the model weights.

Q3. What format should my dataset be in for fine-tuning PaLiGemma?

A. Your dataset should be in JSONL format, where each line in the file represents a JSON object. For example:
{"image_path": "/path/to/image1.jpg", "label": "label1"} {"image_path": "/path/to/image2.jpg", "label": "label2"}
You can use tools like Roboflow to prepare and download datasets in the required JSONL format.

Q4. How do I configure the PaLiGemma model to fit my specific training environment?

A. You need to set the model configuration to be compatible with your environment, such as a Colab T4 GPU. Load the model weights and tokenizer, and appropriately set up the model parameters and data sharding. Use JAX and the necessary libraries to prepare the model for training.

Q5. How can I deploy my fine-tuned PaLiGemma model using Azure?

A. After fine-tuning your model, save the model parameters. Set up an Azure Virtual Machine (VM) to host your model. Transfer the fine-tuned model to the VM and use Azure’s deployment services to make it accessible for inference. The specific deployment steps on Azure will depend on your VM configuration and preferred deployment method.
