Fine-tune Llama 3 using Direct Preference Optimization

Sunil Kumar Dash 02 May, 2024 • 10 min read

Introduction

Large Language Models have revolutionized productivity by enabling tasks like Q&A, dynamic code generation, and agentic systems. However, pre-trained vanilla models are often biased and can produce harmful content. To improve their behavior, alignment algorithms like Reinforcement Learning from Human Feedback (RLHF) and Direct Preference Optimization (DPO) can be used. This article explains how RLHF and DPO work, compares the two, and walks through a DPO implementation using Unsloth, highlighting how these methods improve the quality and effectiveness of models across tasks.

Learning Objectives

  • Understand the significance of fine-tuning large language models (LLMs) for specific tasks and applications.
  • Differentiate between RLHF (Reinforcement Learning from Human Feedback) and DPO (Direct Preference Optimization) as fine-tuning approaches for LLMs.
  • Identify the pros and cons of both RLHF and DPO methods in the context of LLM fine-tuning.
  • Explore open-source tools available for implementing DPO fine-tuning, such as the TRL (Transformer Reinforcement Learning) library, Axolotl, and Unsloth.
  • Learn the steps involved in fine-tuning a Llama 3 8B model using DPO with Unsloth, including data preparation, model loading, training, and inference.
  • Understand the key parameters and hyperparameters involved in training a model using DPO.
  • Gain insights into the benefits of using Unsloth with Llama 3 models, such as a reduced memory footprint and faster training and inference.

This article was published as a part of the Data Science Blogathon.

What is Llama 3?

Llama 3 is a family of open-source models recently released by Meta. The family consists of pre-trained and instruction-tuned chat models with 8B and 70B parameters. Since its release, it has been well received by the open-source community. The models perform well on benchmarks like MMLU, HumanEval, and MATH, and the small 8B model in particular outperforms many larger models. This makes it well suited to personal use and edge deployment. However, many use cases require fine-tuning the models on a custom dataset to perform well. So, let’s first understand what RLHF and DPO are, and then implement DPO.

What is RLHF?

RLHF is an alignment technique usually applied after supervised fine-tuning to instill certain types of behavior into a base model. For example, the model can be trained to refuse harmful requests or to avoid hate speech. This is an important step before releasing models to the public. Big companies like Google, Meta, and OpenAI spend enormous resources aligning their models before releasing them in the wild.

How Does RLHF Work?

The RLHF technique is a two-step process: first, train a reward model on preference data; then, fine-tune the base model with reinforcement learning. The preference dataset is a highly curated set of prompts with accepted and rejected responses from foundational language models, where human annotators rank the responses to indicate which ones they prefer. The reward model is then trained or fine-tuned on this preference data. It can be a copy of the same language model, a different language model, or even a traditional classification model.
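Reward models for RLHF are commonly trained with a pairwise (Bradley-Terry style) objective: the model should give the preferred response a higher score than the rejected one. The sketch below is a minimal plain-Python illustration under that assumption. It shows how a human ranking can be expanded into (chosen, rejected) pairs and how the pairwise loss is computed from two scalar scores. The function names are hypothetical and not part of any library.

```python
import math
from itertools import combinations


def ranking_to_pairs(responses_best_first):
    """Expand a ranked list (best first) into (chosen, rejected) pairs."""
    # combinations() preserves list order, so the first item of each pair
    # is always the higher-ranked (preferred) response.
    return list(combinations(responses_best_first, 2))


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_reward_loss(reward_chosen, reward_rejected):
    """-log(sigmoid(r_chosen - r_rejected)): small when chosen scores higher."""
    return softplus(-(reward_chosen - reward_rejected))


pairs = ranking_to_pairs(["answer A", "answer B", "answer C"])
print(pairs)
print(pairwise_reward_loss(2.0, 0.5))  # chosen scored higher: small loss
print(pairwise_reward_loss(0.5, 2.0))  # chosen scored lower: large loss
```

With K ranked responses per prompt, this expansion yields K(K-1)/2 comparisons, so a single ranking produces several training pairs.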


The next step is to fine-tune the base model using RL. Traditionally in RLHF, the PPO (Proximal Policy Optimization) algorithm is used to update the parameters of the base model based on a reward function. In PPO, we have an initial language model, a policy model that will be fine-tuned, and the reward model from the previous step.

Prompts from the preference dataset are fed to the RL policy model to generate responses. The same responses are also scored by the frozen initial model to calculate a KL penalty. The KL (Kullback-Leibler) divergence measures how one probability distribution differs from another; here it keeps the policy model from drifting too far from the initial model. For a prompt $x$ and a sampled response $y$, the KL term is estimated as:

$$ r_{\text{KL}} = \log \frac{\pi_{\text{RL}}(y \mid x)}{\pi_{\text{base}}(y \mid x)} $$

In the next step, the reward model assigns a preference score $r_\theta(x, y)$ to each response from the RL policy model. The parameters of the RL policy model are then updated with PPO to maximize a reward that combines the preference score with the KL penalty, scaled by a coefficient $\lambda$:

$$ r = r_\theta(x, y) - \lambda \, r_{\text{KL}} $$

From here on, the policy model can be updated iteratively.
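To make the reward shaping concrete, here is a minimal sketch, assuming per-token log-probabilities of a sampled response are already available from both the RL policy and the frozen initial model. It estimates the KL term as the sum of per-token log-ratios (log π_RL(y|x) - log π_base(y|x)) and subtracts it, scaled by a coefficient, from the reward-model score. The function name and the `kl_coef` value are illustrative assumptions, not taken from a specific RLHF library.

```python
def kl_shaped_reward(preference_score, policy_logprobs, reference_logprobs, kl_coef=0.1):
    """Reward for one sampled response: score minus a scaled KL estimate.

    policy_logprobs / reference_logprobs: per-token log-probabilities of the
    same sampled response under the RL policy and the frozen initial model.
    """
    if len(policy_logprobs) != len(reference_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    # log pi_RL(y|x) - log pi_base(y|x), summed over the response tokens
    kl_estimate = sum(p - r for p, r in zip(policy_logprobs, reference_logprobs))
    return preference_score - kl_coef * kl_estimate


# Policy identical to the initial model: no penalty, reward equals the score
print(kl_shaped_reward(2.0, [-1.0, -2.0], [-1.0, -2.0]))
# Policy has become more confident than the initial model: reward is reduced
print(kl_shaped_reward(2.0, [-0.5, -1.0], [-1.0, -2.0]))
```

Some implementations apply the penalty per token rather than once per response; this sketch uses the per-response sum for simplicity.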

Direct Preference Optimization

While RLHF using PPO has upsides like greater flexibility to incorporate various types of feedback, the implementation can be unstable. Here are some of the pros and cons of the RLHF fine-tuning method.

Pros of RLHF

  • RLHF provides greater control over fine-tuning as it allows for designing nuanced reward models.
  • RLHF can also accommodate diverse reward formats such as numbers, implicit feedback, and textual corrections.
  • RLHF can be beneficial when the model needs to be trained over massive data.

Cons of RLHF

  • Training and fitting a reward model can be challenging both technically and computationally.
  • While it allows diverse feedback to guide LLMs, it is often unstable and less reliable than DPO.

What is Direct Preference Optimization?

Direct Preference Optimization is a fine-tuning technique that aims to address the shortcomings of PPO. DPO simplifies RLHF by eliminating both the separate reward model and the RL-based optimization loop. Instead, it optimizes the language model directly on human preference data. This data consists of pairwise comparisons of model outputs, where human evaluators choose which of two responses they prefer for a given prompt, and these comparisons directly guide the training of the language model. We can also treat responses from stronger models as preferred and responses from weaker models as rejected to build a preference dataset for fine-tuning base models.

Instead of a reward model, Direct Preference Optimization uses a frozen reference model (usually the starting model itself). Training pushes the model to assign a higher probability to preferred responses and a lower probability to rejected responses, relative to the reference model. This approach is more stable and efficient than PPO-based RLHF because it bypasses reward model training and fitting altogether.
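The DPO objective from the original DPO paper can be written for a single example as -log σ(β[(log π(y_w|x) - log π_ref(y_w|x)) - (log π(y_l|x) - log π_ref(y_l|x))]), where y_w is the preferred and y_l the rejected response. The sketch below computes it in plain Python, assuming the summed log-probabilities of each full response are already available. It illustrates the formula only and is not TRL's implementation; β = 0.1 mirrors the value used later in this article.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss for one (prompt, chosen, rejected) example.

    Each argument is the summed log-probability of a full response under
    either the model being trained (policy) or the frozen reference model.
    """
    # Implicit rewards: how much more likely each response became vs. the reference
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), computed stably as softplus(-margin)
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Policy equals the reference: loss is log(2)
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
# Policy moved probability toward the chosen answer: loss drops below log(2)
print(dpo_loss(-9.0, -13.0, -10.0, -12.0))
```

No reward model appears anywhere: the scaled log-ratio against the reference plays that role, which is why DPO needs only the policy, a reference, and the preference pairs.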


Pros of Direct Preference Optimization

  • DPO is straightforward when it comes to implementation. There is no need to train a separate reward model.
  • Besides being easy to implement, DPO is also more stable and predictable. The models can be reliably guided toward a particular goal.
  • DPO is more resource-efficient, as it trains the LLM directly on a fixed preference dataset, without a separate reward model or response sampling during training.

Cons of Direct Preference Optimization

  • DPO does not provide the flexibility of a complex reward mechanism design.
  • It cannot accommodate diverse feedback the way RLHF can, since DPO relies on a binary (chosen vs. rejected) feedback format.

Open-source Tools for DPO Training

Now, let’s explore the open-source tools for implementing DPO. There are several options:

  • TRL: The most popular option is Hugging Face’s TRL library. It has all the bells and whistles for efficient DPO fine-tuning, and since it comes from Hugging Face, it integrates seamlessly with other Hugging Face libraries.
  • Axolotl: If you would rather not write Python code, there is another open-source tool called Axolotl. Instead of writing code, you define all the parameters and hyperparameters in a YAML file, which makes the fine-tuning process much easier to manage. It wraps the TRL library underneath, so we can use its functionality in a cleaner way.
  • Unsloth: Another open-source tool for fine-tuning LLMs efficiently. The library implements custom Triton GPU kernels for faster training and inference, and it uses less memory during model training. We will use Unsloth to DPO fine-tune the Llama 3 8B model.

So, let’s implement Direct Preference Optimization fine-tuning on the Llama 3 model using Unsloth.

DPO Fine-tuning with Unsloth

Let us now walk through DPO fine-tuning with Unsloth step by step.

Step 1: Install Dependencies

Before moving ahead, install the dependencies. We will install Unsloth from its Git repository, along with Flash Attention (on newer GPUs), xformers, TRL, PEFT, Accelerate, bitsandbytes, and Wandb for logging. Optionally, you can install DeepSpeed for distributed training across multiple GPUs.

# Notebook cell (e.g. Colab): lines starting with ! run as shell commands
import torch
major_version, minor_version = torch.cuda.get_device_capability()
# Must install separately since Colab has torch 2.2.1, which breaks packages
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
if major_version >= 8:
    # Use this for new GPUs like Ampere, Hopper GPUs (RTX 30xx, RTX 40xx, A100, H100, L40)
    !pip install --no-deps packaging ninja einops flash-attn xformers trl peft \
    accelerate bitsandbytes
else:
    # Use this for older GPUs (V100, Tesla T4, RTX 20xx)
    !pip install --no-deps xformers trl peft accelerate bitsandbytes
# Weights & Biases, used for experiment logging
!pip install wandb

Step 2: Set Key in Local Environment

Now, set WANDB_API_KEY in your local environment.

import os
os.environ['WANDB_API_KEY'] = "your_api_key"

Step 4: Data Preparation

We will use Intel’s Orca DPO pairs dataset for alignment through DPO. As we learned before, a DPO dataset has a prompt column, a column for chosen (preferred) answers, and a column for rejected answers.


This is a small dataset; you can also use other DPO datasets, such as Argilla’s UltraFeedback preference data.

The data is well suited to DPO tuning. We can load it using Hugging Face’s datasets library. Rename the question column to prompt, since TRL’s DPOTrainer expects a prompt column, and then split the data into train and test sets.

from datasets import load_dataset
dataset = load_dataset("Intel/orca_dpo_pairs", split = "train")

dataset = dataset.rename_column('question','prompt')

dataset_dict = dataset.train_test_split(test_size=0.04)
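As a sanity check on the prompt/chosen/rejected schema, here is a small sketch of the same conversion written as a per-row function. It assumes each Intel/orca_dpo_pairs row has `system`, `question`, `chosen`, and `rejected` string fields. Unlike `rename_column`, it drops the `system` field and rejects empty fields. `to_dpo_record` and `is_informative` are hypothetical helpers; the latter filters out pairs whose chosen and rejected answers are identical, since such pairs carry no preference signal.

```python
def to_dpo_record(row):
    """Map one source row to the prompt/chosen/rejected schema used for DPO."""
    record = {
        "prompt": row["question"],
        "chosen": row["chosen"],
        "rejected": row["rejected"],
    }
    # Every field must be a non-empty string for the pair to be usable
    for key, value in record.items():
        if not isinstance(value, str) or not value.strip():
            raise ValueError(f"field {key!r} is empty or not a string")
    return record


def is_informative(record):
    """A pair only teaches a preference if the two answers differ."""
    return record["chosen"].strip() != record["rejected"].strip()


row = {"system": "", "question": "What is 2 + 2?", "chosen": "4", "rejected": "5"}
print(to_dpo_record(row))
print(is_informative(to_dpo_record(row)))
```

With the datasets library, the same logic could be applied via `dataset.map(to_dpo_record, remove_columns=dataset.column_names).filter(is_informative)` on the freshly loaded dataset, before `train_test_split`.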

Step 5: Load Llama 3

We will now download and load the 4-bit quantized Llama 3 8B Instruct model from Unsloth. This will take a few moments; the 4-bit quantized model is around 5.76 GB. The script below downloads the model and loads it onto the GPU.

from unsloth import FastLanguageModel
import torch
max_seq_length = 4096
dtype = None # None auto-detects: bfloat16 on Ampere+ GPUs, float16 on older ones
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8B-instruct-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

Step 6: Add LoRA Adapters

We can now attach LoRA adapters to the Llama model, so that only about 1-10% of the total parameters are updated. According to Unsloth, setting use_gradient_checkpointing to “unsloth” uses 30% less memory and accommodates 2x larger batch sizes.

model = FastLanguageModel.get_peft_model(
    model,
    r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0, 
    bias = "none",    
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  
    loftq_config = None, 
)

  • r: The rank of the low-rank adapters. A higher rank increases the number of trainable parameters, which improves the model’s adaptability to the data but also increases compute and memory requirements.
  • lora_alpha: A scaling factor for the adapter updates. The low-rank update is multiplied by lora_alpha / r before being added to the original weights, so it modulates how strongly training changes the model, somewhat like a learning-rate multiplier.
  • target_modules: The layers of the model architecture (here, the attention and MLP projections) to which the LoRA updates are applied.
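To see what r means in practice, the sketch below counts the trainable LoRA parameters for the target_modules listed above. It assumes the published Llama 3 8B shapes: hidden size 4096, MLP size 14336, 32 decoder layers, and grouped-query attention with 8 key/value heads of dimension 128, so k_proj and v_proj output 1024 features. A LoRA adapter on a weight with in_features inputs and out_features outputs adds r × (in_features + out_features) parameters, and its update is scaled by lora_alpha / r (with use_rslora=True, PEFT uses lora_alpha / sqrt(r) instead).

```python
# Assumed Llama 3 8B shapes (from its published config)
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 8 * 128  # 8 key/value heads x head_dim 128
NUM_LAYERS = 32

# (in_features, out_features) of each targeted projection in one decoder layer
PROJ_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_trainable_params(r, target_modules, num_layers=NUM_LAYERS):
    """LoRA adds A (r x in) and B (out x r) per targeted weight: r * (in + out)."""
    per_layer = sum(r * (PROJ_SHAPES[m][0] + PROJ_SHAPES[m][1]) for m in target_modules)
    return per_layer * num_layers


def lora_scale(lora_alpha, r):
    """Multiplier applied to the low-rank update B @ A (use_rslora=False)."""
    return lora_alpha / r


print(lora_trainable_params(64, list(PROJ_SHAPES)))
print(lora_scale(64, 64))
```

With r = 64 this comes to about 168M trainable parameters, roughly 2% of the model's ~8B weights, which fits the 1-10% range mentioned earlier; the count scales linearly with r, and with lora_alpha equal to r the update scale is 1.0.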

Step 7: Define Training Hyperparameters

Now, define all the training arguments and hyperparameters for model training. Before that, patch the DPOTrainer. This is only needed in a notebook, where it improves how training progress is logged; skip this step if you are not running in an IPython/Jupyter notebook.

from unsloth import PatchDPOTrainer
PatchDPOTrainer()

Log in to your Weights and Biases profile.

import wandb
wandb.login()

Now define the training hyperparameters and create the DPOTrainer.

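from transformers import TrainingArguments
from trl import DPOTrainer
import torch
import wandb

project_name = "llama3"
# os.environ["WANDB_LOG_MODEL"] = "checkpoint"

wandb.init(project=project_name, name = "llama-3-8b-instruct-DPO-1")

# This follows the TRL API of early 2024. In later TRL releases, the DPO-specific
# arguments (beta, max_length, max_prompt_length) moved into trl.DPOConfig,
# and `tokenizer` was renamed `processing_class`.
dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,  # with a PEFT model, TRL uses the base model (adapters disabled) as reference
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 3,
        warmup_ratio = 0.1,
        num_train_epochs = 1,
        learning_rate = 5e-6,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        #max_steps=20,
        optim = "adamw_8bit",
        weight_decay = 0.0,
        lr_scheduler_type = "linear",
        seed = 42,
        report_to="wandb",  # enable logging to W&B
        output_dir = "outputs",
    ),
    beta = 0.1,               # DPO temperature: higher keeps the model closer to the reference
    train_dataset = dataset_dict["train"],
    eval_dataset = dataset_dict["test"],
    tokenizer = tokenizer,
    max_length = 1024,        # max tokens for prompt + response
    max_prompt_length = 512,  # max tokens for the prompt alone
)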

Here’s a quick breakdown of all the key training arguments used above.

  • per_device_train_batch_size: The batch size per device during training.
  • gradient_accumulation_steps: The number of forward/backward passes whose gradients are accumulated before each optimizer update. The effective batch size is per_device_train_batch_size × gradient_accumulation_steps × number of devices.
  • warmup_ratio: The fraction of total training steps during which the learning rate linearly ramps up to its maximum value.
  • num_train_epochs: The number of complete passes through the training dataset.
  • optim: The optimizer type. Here, it is an 8-bit AdamW.
  • lr_scheduler_type: How the learning rate changes during training. A linear scheduler decays the learning rate linearly from its peak to zero after warmup.
  • beta: The DPO temperature, which controls how far the trained model may deviate from the reference model. Higher values keep it closer to the reference.
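These arguments interact: the effective batch size and the number of optimizer steps determine how many warmup steps warmup_ratio produces and how the linear scheduler moves the learning rate. The sketch below approximates the Hugging Face Trainer's arithmetic for the values used here (batch size 2, 3 accumulation steps, warmup_ratio 0.1, learning rate 5e-6). The dataset size of 12,000 pairs is a hypothetical round number, not the exact size of the training split, and exact step counts may differ slightly between library versions.

```python
import math


def training_schedule(num_examples, per_device_batch=2, grad_accum=3,
                      num_devices=1, epochs=1, warmup_ratio=0.1):
    """Approximate effective batch size, optimizer steps, and warmup steps."""
    effective_batch = per_device_batch * grad_accum * num_devices
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    # One optimizer update per `grad_accum` micro-batches
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, total_steps, warmup_steps


def linear_lr(step, peak_lr, warmup_steps, total_steps):
    """Linear warmup to peak_lr, then linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


effective, total, warmup = training_schedule(12_000)  # hypothetical dataset size
print(effective, total, warmup)
for step in (0, warmup // 2, warmup, total // 2, total):
    print(step, linear_lr(step, 5e-6, warmup, total))
```

Lowering per_device_train_batch_size while raising gradient_accumulation_steps keeps the effective batch size (and therefore the step count and schedule) unchanged while reducing peak memory.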

Now, start training.

dpo_trainer.train()

This will kick-start model fine-tuning. If you run into an out-of-memory (OOM) error, try reducing per_device_train_batch_size (and increasing gradient_accumulation_steps to keep the effective batch size) or lowering max_length. You can follow the training run in the notebook or in your W&B dashboard.

Step 8: Inference

Once training is finished, save the LoRA adapters and the tokenizer.

model.save_pretrained("lora_model")      # saves only the LoRA adapter weights
tokenizer.save_pretrained("lora_model")

You can now load the LoRA model and start asking questions.

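from unsloth import FastLanguageModel

# Loads the base model referenced by the saved adapter config, then applies the LoRA adapters
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_model",
    max_seq_length = 512,
    # dtype = dtype,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)  # switch to Unsloth's faster inference mode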

We can then define a Transformers text-generation pipeline for inference.

import transformers

messages = [
    {"role": "system", "content": "You are a helpful assistant chatbot."},
    {"role": "user", "content": "What is a Large Language Model?"}
]

prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

# Create pipeline
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)

# Llama 3 Instruct ends each turn with <|eot_id|>, so stop on it as well as on EOS
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Generate text
sequences = pipeline(
    prompt,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    eos_token_id=terminators,
    num_return_sequences=1,
    max_new_tokens=256,  # limits only the generated tokens, not prompt + output
)

# The output contains the prompt followed by the reply; print only the reply
print(sequences[0]['generated_text'][len(prompt):])

You can also wrap it in a Gradio chat interface using the script below.

import gradio as gr

# Running conversation in chat-template format, shared between the callbacks
messages = []

def add_text(history, text):
    global messages
    history = history + [[text, ""]]  # a list (not a tuple) so the reply can be filled in
    messages = messages + [{"role": "user", "content": text}]
    return history, ""

def generate(history):
    global messages
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    outputs = pipeline(
        prompt,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response_msg = outputs[0]["generated_text"][len(prompt):]
    # Remember the assistant's reply so later turns have the full context
    messages = messages + [{"role": "assistant", "content": response_msg}]

    # Stream the finished reply into the chat window one character at a time
    for char in response_msg:
        history[-1][1] += char
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=[], elem_id="chatbot")
    with gr.Row():
        txt = gr.Textbox(
            show_label=False,
            placeholder="Enter text and press enter",
        )

    txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        generate, inputs=[chatbot], outputs=chatbot,
    )

demo.queue()
demo.launch(debug=True)

Conclusions

Llama 3 from Meta has proven to be very capable, especially the small 8B model, which can run on cheaper hardware. However, to make it commercially viable, we may need to fine-tune it for custom use cases. This article discussed fine-tuning techniques like RLHF and DPO, and the implementation of DPO using Unsloth. Here are the key takeaways from the article.

Key Takeaways

  • RLHF is an alignment technique usually applied to models after supervised fine-tuning. PPO is the most widely used RL algorithm for RLHF and allows model alignment with greater control.
  • DPO has emerged as an effective alternative to PPO training. It bypasses PPO’s complicated and less reliable workflow: instead of training a separate reward model, it compares the model against a frozen copy of itself (the reference model), so the model effectively acts as its own reward model.
  • DPO can be implemented using open-source tools like Unsloth, Hugging Face TRL and Transformers, Axolotl, etc.
  • Unsloth provides custom-optimized implementations of popular language models that reduce training time and memory footprint and speed up inference.

Frequently Asked Questions

Q1. What is LLM fine-tuning?

A. Fine-tuning, in the context of machine learning, especially deep learning, is a technique where you take a pre-trained model and adapt it to a new, specific task.

Q2. Can I fine-tune LLMs for free?

A. Yes, it is possible to fine-tune small LLMs using Colab’s T4 GPU and QLoRA.

Q3. What are the benefits of LLM fine-tuning?

A. Fine-tuning can substantially improve a pre-trained base model’s capabilities on tasks like coding, reasoning, and math.

Q4. What is DPO?

A. DPO, or Direct Preference Optimization, is an alignment technique for adapting models to preference data. It uses a reference model and a preference dataset to fine-tune the base model.

Q5. What is Llama 3?

A. Llama 3 is a family of models from Meta. There are four models in total: pre-trained and instruction-tuned variants of the 8B and 70B models.

The media shown in this article is not owned by Analytics Vidhya and is used at the Author’s discretion.
