Adapter-Based Fine-Tuning of BART and T5-Flan-XXL for Single-Word Spell Correction

In this post I share the results of a weekend project on fine-tuning BART and T5-Flan models for sequence-to-sequence generation. I used common single-word misspellings in English for training and evaluating the models. As a benchmark, I first trained and evaluated a pre-trained checkpoint of BART with full fine-tuning, and then followed with LoRA-based fine-tuning of BART and T5-Flan-XXL.

Data set used: Peter Norvig’s Common English Misspellings Data Set

Model Type : AutoModelForSeq2SeqLM

PEFT Task Type: SEQ_2_SEQ_LM

Model | Colab Code Pointers | Checkpoint | PEFT Type | #Params | #Trainable Params
BART | bart-seq-2-seq | "facebook/bart-base" | NA (full model fine-tuning) | 255 million | 255 million
BART | bart-lora | "facebook/bart-base" | LoRA | 255 million | 884.7k (0.41%)
BART | bart-8bit-lora | "facebook/bart-base" – 8 bit | LoRA | 255 million | 884.7k (0.41%)
T5-Flan-XXL | t5-flan-xxl-lora | "philschmid/flan-t5-xxl-sharded-fp16" – 8 bit | LoRA | ~11.15 billion | 18.87 million (0.17%)

Fine-Tuned Models

BART VS T5

BART[1] and T5[2] both have a Seq2Seq[3] model architecture. Both use an encoder-decoder design, where the encoder is like BERT's encoder and the decoder is autoregressive. Essentially, BART is a denoising autoencoder that maps a corrupted document back to the original document using an autoregressive decoder.

BART uses a variety of input-text corruption mechanisms and a Seq2Seq model to reconstruct the original text. BART's authors evaluated a number of noising approaches for pre-training, e.g. (a toy illustration follows the list):

  • Token Masking: Random tokens are sampled and replaced with [MASK].
  • Token Deletion: Random tokens are deleted from the input. This helps the model learn which positions are missing input.
  • Text Infilling: A number of text spans are sampled, with span lengths drawn from a Poisson distribution. Each span is replaced with a single [MASK] token. Text infilling is inspired by SpanBERT. It teaches the model to predict how many tokens are missing from a span.
  • Sentence Permutation: A document is divided into sentences based on full stops, and these sentences are shuffled at random.
  • Document Rotation: A token is chosen uniformly at random and the document is rotated so that it begins with that token.
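
To make these transforms concrete, below is a toy illustration of token masking, token deletion and document rotation on whitespace tokens. This is a sketch only: BART applies these corruptions to subword tokens during pre-training, not with this exact code.

import random

random.seed(0)
tokens = "the quick brown fox jumps over the lazy dog".split()

# Token Masking: replace a random subset of tokens with [MASK]
masked = [t if random.random() > 0.3 else "[MASK]" for t in tokens]

# Token Deletion: drop a random subset of tokens entirely
deleted = [t for t in tokens if random.random() > 0.3]

# Document Rotation: pick a start token at random and rotate the sequence
start = random.randrange(len(tokens))
rotated = tokens[start:] + tokens[:start]

print(masked)
print(deleted)
print(rotated)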

T5-Flan has the same number of parameters as T5 (Text-to-Text Transfer Transformer), but it is fine-tuned on more than 1,000 additional tasks covering more languages.[4]

Model | #Layers | #Parameters | Vocab Size | Embedding Size
BART-Base | 6 Encoder, 6 Decoder | 255 million | ~50k | 1024
T5-Large | 14 Encoder, 14 Decoder | 780 million | ~32k | 4096
T5-Flan-XXL | 23 Encoder (T5 Blocks), 23 Decoder (T5 Blocks) | 11.15 Billion | ~32k | 4096

The T5 model has the following modules (Huggingface implementation, 8-bit checkpoint):

a. Shared Embedding Layer (Encoder and Decoder modules share this layer) of size (32128, 4096)

b. Encoder T5 Block – Attention Layer and Feed Forward Layer

T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (k): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (v): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (o): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
              (wo): Linear(in_features=10240, out_features=4096, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
              (act): GELUActivation()
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
)

c. Decoder T5 Blocks –

i. Self Attention Layer

ii. Cross Attention Layer

iii. Feed Forward Layer

T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (k): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (v): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (o): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerCrossAttention(
            (EncDecAttention): T5Attention(
              (q): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (k): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (v): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (o): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (2): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
              (wo): Linear(in_features=10240, out_features=4096, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
              (act): GELUActivation()
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
)
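
The module dump above can be reproduced by loading the 8-bit checkpoint and printing the model. A minimal sketch, assuming bitsandbytes and accelerate are installed:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "philschmid/flan-t5-xxl-sharded-fp16"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# load_in_8bit quantizes the linear layers to int8 via bitsandbytes
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

# prints the shared embedding plus the encoder/decoder T5 blocks shown above
print(model)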

Tokenization

For BART we follow the standard sequence-to-sequence tokenization scheme. In the example below we tokenize the input text and the label text.


def bart_preprocess_function(sample, max_source_length=12, max_target_length=12, padding="max_length"):
    # tokenize inputs (max lengths default to 12 tokens, mirroring the T5 preprocessing below)
    model_inputs = tokenizer(sample["misspelled_queries"], max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["correct_spelling"], max_length=max_target_length, padding=padding, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

Whereas for T5-Flan-XXL we have to prepend a context prompt to the input text:

def t5_preprocess_function(sample, max_source_length = 12, max_target_length= 12, padding="max_length"):
    # add prefix to the input for t5
    inputs = ["correct spelling of following word : " + item for item in sample["misspelled_queries"]]

    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["correct_spelling"], max_length=max_target_length, padding=padding, truncation=True)

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
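
Either preprocessing function is then mapped over the train/test splits. A sketch, assuming the Norvig misspellings have already been converted into a Huggingface dataset with "misspelled_queries" and "correct_spelling" columns (the CSV file names are hypothetical):

from datasets import load_dataset

# hypothetical CSV exports of the Norvig misspellings data set
dataset = load_dataset(
    "csv",
    data_files={"train": "misspellings_train.csv", "test": "misspellings_test.csv"},
)

# tokenize both splits and drop the raw text columns
tokenized_dataset = dataset.map(
    t5_preprocess_function,
    batched=True,
    remove_columns=["misspelled_queries", "correct_spelling"],
)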

Adapter-Based Fine-Tuning

Once we have downloaded an 8-bit model checkpoint, we use the LoRA adapter from the Huggingface PEFT module and wrap the model with it:

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

# Define LoRA Config
lora_config = LoraConfig(
 r=16,
 lora_alpha=32,
 target_modules=["q", "v"], # for BART use ["q_proj", "v_proj"]
 lora_dropout=0.05,
 bias="none",
 task_type=TaskType.SEQ_2_SEQ_LM
)
# prepare int-8 model for training
model = prepare_model_for_int8_training(model) #skip this for BART 

# add LoRA adaptor
model = get_peft_model(model, lora_config) 
model.print_trainable_parameters()
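
The wrapped model can then be trained with the regular Seq2SeqTrainer. A sketch of the setup; the hyperparameters below are illustrative, not the exact values used for the results in this post:

from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

# pad dynamically and keep -100 as the ignore index for label padding
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)

training_args = Seq2SeqTrainingArguments(
    output_dir="lora-spell-correction",
    per_device_train_batch_size=8,
    learning_rate=1e-3,
    num_train_epochs=5,
    logging_steps=100,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    data_collator=data_collator,
)
trainer.train()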

32-bit checkpoint vs 8-bit checkpoint

I also compared the performance of PEFT BART using the 32-bit checkpoint against the 8-bit checkpoint. The results were comparable, although inference on the 8-bit checkpoint was noticeably faster.

Additional improvements for the 8-bit checkpoint involved casting the layernorm layers to FP16 and the output of the LM head to FP32 for model stability during training.

import torch
import torch.nn as nn

for param in model.parameters():
  # param.requires_grad = False  # freeze the model - train adapters later
  if param.ndim == 1:
    # cast the small parameters (e.g. layernorm) to fp16 for stability
    param.data = param.data.to(torch.float16)

# model.gradient_checkpointing_enable()  # reduce number of stored activations
# model.enable_input_require_grads()

# cast the output of the LM head to fp32
class CastOutputToFloat(nn.Sequential):
  def forward(self, x): return super().forward(x).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)

A recalcitrant issue was using gradient_checkpointing together with the approach above; if I remember correctly, the Huggingface Trainer API threw an error when both were enabled.

Performance Results (Test Data)

The table below shows the accuracy (in %) of each fine-tuned model, broken down by the edit distance between the misspelled query and the corrected spelling.

Edit Distance | BART | BART – LoRA | BART 8-bit – LoRA | T5-Flan-XXL – LoRA
1 | 49.34 | 32.12 | 31.79 | 53.79
2 | 52.55 | 18.82 | 18.82 | 42.75
3 | 48.60 | 14.04 | 14.02 | 35.51
4 | 48.76 | 11.57 | 12.40 | 33.88
5 | 27.27 | 5.45 | 3.64 | 12.73
6 | 39.39 | 9.09 | 9.09 | 18.18
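
For completeness, here is a minimal inference sketch for the T5 variant (not the exact evaluation script; the adapter directory name is hypothetical):

from peft import PeftModel

# load the trained LoRA adapter on top of the 8-bit base model (hypothetical path)
inference_model = PeftModel.from_pretrained(model, "t5-flan-xxl-lora-spell-correction")

inputs = tokenizer("correct spelling of following word : beleive", return_tensors="pt").to(model.device)
outputs = inference_model.generate(**inputs, max_new_tokens=12)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))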

References

  1. BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
  2. Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
  3. Sequence to Sequence Learning with Neural Networks
  4. T5-Flan: Scaling Instruction-Finetuned Language Models
