In this post I share the results of a weekend project on fine-tuning BART and Flan-T5 models for sequence-to-sequence generation. I used common English misspellings (single words) for training and evaluating the models. As a benchmark I first fully fine-tuned and evaluated a pre-trained BART checkpoint, and then followed with LoRA-based fine-tuning for BART and Flan-T5-XXL.
- Data set used: Peter Norvig's Common English Misspellings Data Set
- Model Type: AutoModelForSeq2SeqLM
- PEFT Task Type: SEQ_2_SEQ_LM
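To make the data setup concrete, here is a minimal sketch (my own, not taken from the notebooks) of turning a Norvig-style misspellings file into (misspelled, correct) pairs, assuming lines of the form "correct: misspelling1 misspelling2 ..." and using the column names that appear in the preprocessing functions later in the post:

# Hedged sketch: parse a Norvig-style misspellings file into training pairs.
# Assumes each line looks like "correct: wrong1 wrong2 ..."; the file name below is hypothetical.
def load_misspelling_pairs(path):
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if ":" not in line:
                continue
            correct, wrongs = line.split(":", 1)
            for wrong in wrongs.replace(",", " ").split():
                pairs.append({"misspelled_queries": wrong, "correct_spelling": correct.strip()})
    return pairs

# e.g. pairs = load_misspelling_pairs("spell-errors.txt")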
| Model | Colab Code Pointers | Checkpoint | PEFT Type | #Params | #Trainable Params |
|---|---|---|---|---|---|
| BART | bart-seq-2-seq | facebook/bart-base | NA (full model fine-tuning) | 255 million | 255 million |
| BART | bart-lora | facebook/bart-base | LoRA | 255 million | 884.7k (0.41%) |
| BART | bart-8bit-lora | facebook/bart-base (8-bit) | LoRA | 255 million | 884.7k (0.41%) |
| T5-Flan-XXL | t5-flan-xxl-lora | philschmid/flan-t5-xxl-sharded-fp16 (8-bit) | LoRA | ~11.15 billion | 18.87 million (0.17%) |
BART vs T5
BART[1] and T5[2] both have a Seq2Seq[3] model architecture. They use an encoder-decoder design, where the encoder is similar to BERT's encoder and the decoder is autoregressive. Essentially, BART is a denoising autoencoder that maps a corrupted document back to the original document using an autoregressive decoder.
BART corrupts the input text with a variety of noising mechanisms and trains a Seq2Seq model to reconstruct the original text. BART's authors evaluated a number of noising approaches for pre-training (a toy sketch of two of them follows the list), e.g.:

- Token Masking: Random tokens are sampled and replaced with [MASK].
- Token Deletion: Random tokens are deleted from the input. This teaches the model to learn which input positions are missing.
- Text Infilling: A number of text spans are sampled, with span lengths drawn from a Poisson distribution. Each span is replaced with a single [MASK] token. Text infilling is inspired by SpanBERT; it teaches the model to predict how many tokens are missing from a span.
- Sentence Permutation: A document is divided into sentences based on full stops, and these sentences are shuffled at random.
- Document Rotation: A token is chosen uniformly at random and the document is rotated so that it begins with that token.
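To make the noising concrete, here is a toy, hedged sketch (my own illustration, not BART's actual pre-processing code) of token masking and Poisson-based text infilling on whitespace tokens; BART itself operates on subword tokens:

import random
import numpy as np

def token_mask(tokens, p=0.15, mask="[MASK]"):
    # Token Masking: replace each token with [MASK] with probability p
    return [mask if random.random() < p else tok for tok in tokens]

def text_infill(tokens, lam=3.0, mask="[MASK]"):
    # Text Infilling: replace one span (length drawn from a Poisson distribution)
    # with a single [MASK] token
    tokens = list(tokens)
    span_len = min(int(np.random.poisson(lam)), len(tokens))
    start = random.randrange(0, len(tokens) - span_len + 1)
    return tokens[:start] + [mask] + tokens[start + span_len:]

tokens = "the quick brown fox jumps over the lazy dog".split()
print(token_mask(tokens))
print(text_infill(tokens))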
Flan-T5[4] has the same number of parameters as T5 (Text-to-Text Transfer Transformer), but it is fine-tuned on more than 1,000 additional tasks covering more languages.
| | #Layers | #Parameters | Vocab Size | Embedding Size |
|---|---|---|---|---|
| BART-Base | 6 encoder, 6 decoder | 255 million | ~50k | 1024 |
| T5-Large | 14 encoder, 14 decoder | 780 million | ~32k | 4096 |
| T5-Flan-XXL | 23 encoder (T5 blocks), 23 decoder (T5 blocks) | 11.15 billion | ~32k | 4096 |
The T5 model has the following modules (Hugging Face implementation, 8-bit checkpoint):
a. Shared embedding layer of size (32128, 4096), shared by the encoder and decoder modules
b. Encoder T5 block: self-attention layer and feed-forward layer
T5Block(
  (layer): ModuleList(
    (0): T5LayerSelfAttention(
      (SelfAttention): T5Attention(
        (q): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (k): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (v): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (o): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (relative_attention_bias): Embedding(32, 64)
      )
      (layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): T5LayerFF(
      (DenseReluDense): T5DenseGatedActDense(
        (wi_0): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
        (wi_1): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
        (wo): Linear(in_features=10240, out_features=4096, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
        (act): GELUActivation()
      )
      (layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
c. Decoder T5 block:
i. Self-attention layer
ii. Cross-attention layer
iii. Feed-forward layer
T5Block(
  (layer): ModuleList(
    (0): T5LayerSelfAttention(
      (SelfAttention): T5Attention(
        (q): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (k): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (v): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (o): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (relative_attention_bias): Embedding(32, 64)
      )
      (layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): T5LayerCrossAttention(
      (EncDecAttention): T5Attention(
        (q): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (k): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (v): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
        (o): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
      )
      (layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (2): T5LayerFF(
      (DenseReluDense): T5DenseGatedActDense(
        (wi_0): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
        (wi_1): Linear8bitLt(in_features=4096, out_features=10240, bias=False)
        (wo): Linear(in_features=10240, out_features=4096, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
        (act): GELUActivation()
      )
      (layer_norm): T5LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
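The printouts above can be reproduced by loading the sharded 8-bit checkpoint and printing one encoder and one decoder block; a minimal sketch, assuming bitsandbytes and accelerate are installed (8-bit loading requires a GPU):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "philschmid/flan-t5-xxl-sharded-fp16"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

print(model.shared)            # shared embedding layer (32128, 4096)
print(model.encoder.block[0])  # encoder T5Block (self-attention + feed-forward)
print(model.decoder.block[0])  # decoder T5Block (self-attention + cross-attention + feed-forward)
print(f"total parameters: {sum(p.numel() for p in model.parameters()):,}")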
Tokenization
For BART we follow the standard sequence-to-sequence tokenization scheme. In the example below we tokenize the input text and the label text.
def bart_preprocess_function(sample, padding="max_length"):
    # tokenize inputs
    model_inputs = tokenizer(sample["misspelled_queries"], max_length=max_source_length, padding=padding, truncation=True)
    # tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["correct_spelling"], max_length=max_target_length, padding=padding, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
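Assuming the pairs live in a Hugging Face datasets object with the columns misspelled_queries and correct_spelling, the function can be applied with a batched map; a hedged usage sketch (the length values and the dataset variable are assumptions):

max_source_length, max_target_length = 12, 12  # assumed values for single-word inputs
tokenized_dataset = dataset.map(              # `dataset` is an assumed datasets.Dataset / DatasetDict
    bart_preprocess_function,
    batched=True,
    remove_columns=["misspelled_queries", "correct_spelling"],
)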
Whereas for T5-Flan-XXL we have to prepend a context prompt to the input text:
def t5_preprocess_function(sample, max_source_length=12, max_target_length=12, padding="max_length"):
    # add prefix to the input for t5
    inputs = ["correct spelling of following word : " + item for item in sample["misspelled_queries"]]
    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
    # tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["correct_spelling"], max_length=max_target_length, padding=padding, truncation=True)
    # if we are padding here, replace all tokenizer.pad_token_id in the labels by -100
    # so that padding is ignored in the loss
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
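At inference time the same prompt prefix has to be prepended; a hedged generation sketch (the misspelled word below is just an example):

# Hedged sketch: generate a correction with the fine-tuned T5 model,
# reusing the prompt prefix from t5_preprocess_function above.
text = "correct spelling of following word : " + "absense"  # example misspelling
inputs = tokenizer(text, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=12)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))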
Adapter Based Fine-tuning
Once we have downloaded an 8-bit model checkpoint, we wrap it with a LoRA adapter from the Hugging Face PEFT module:
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

# define LoRA config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],  # for BART use ["q_proj", "v_proj"]
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)

# prepare the int-8 model for training
model = prepare_model_for_int8_training(model)  # skip this for BART

# add the LoRA adapter
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
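From here the wrapped model can be trained with the standard Seq2SeqTrainer; a minimal sketch with hypothetical hyperparameters and output directory (the exact settings in the notebooks may differ):

from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

# pad labels with -100 so that padding is ignored in the loss
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, label_pad_token_id=-100)

training_args = Seq2SeqTrainingArguments(
    output_dir="spell-correction-lora",  # hypothetical output directory
    per_device_train_batch_size=8,       # hypothetical hyperparameters
    learning_rate=1e-3,
    num_train_epochs=3,
    logging_steps=100,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],  # assumed preprocessed split from the map above
    data_collator=data_collator,
)
trainer.train()
model.save_pretrained("spell-correction-lora")  # saves only the LoRA adapter weights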
32-bit checkpoint vs 8-bit checkpoint
I also compared the performance of PEFT BART with a 32-bit checkpoint and an 8-bit checkpoint. The results were roughly comparable, although inference on the 8-bit checkpoint was noticeably faster.
Additional improvements for the 8-bit checkpoint involved casting the layer-norm parameters to FP16 and the LM head output to FP32 for model stability during training.
import torch
import torch.nn as nn

for param in model.parameters():
    # param.requires_grad = False  # freeze the model - train adapters later
    if param.ndim == 1:
        # cast the small parameters (e.g. layer norms) to fp16 for stability
        param.data = param.data.to(torch.float16)

# model.gradient_checkpointing_enable()  # reduce number of stored activations
# model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
    # cast the LM head output to fp32
    def forward(self, x):
        return super().forward(x).to(torch.float32)

model.lm_head = CastOutputToFloat(model.lm_head)
A recalcitrant issue was using gradient_checkpointing together with the above approach; if I remember correctly, the Hugging Face Trainer API threw an error.
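Gradient checkpointing aside, the trained LoRA adapter can be re-attached to a freshly loaded base model for evaluation; a hedged sketch, assuming the adapter was saved with model.save_pretrained as in the training sketch above:

from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

base_id = "philschmid/flan-t5-xxl-sharded-fp16"
tokenizer = AutoTokenizer.from_pretrained(base_id)
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_id, load_in_8bit=True, device_map="auto")

# attach the saved adapter weights ("spell-correction-lora" is the hypothetical directory from above)
model = PeftModel.from_pretrained(base_model, "spell-correction-lora")
model.eval()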
Performance Results (Test Data)
The table below shows the accuracy of the different fine-tuned models, bucketed by the edit distance between the misspelled query and the corrected spelling.
| Edit Distance | Accuracy (%) BART | Accuracy (%) BART + LoRA | Accuracy (%) BART 8-bit + LoRA | Accuracy (%) T5-Flan-XXL + LoRA |
|---|---|---|---|---|
| 1 | 49.34 | 32.12 | 31.79 | 53.79 |
| 2 | 52.55 | 18.82 | 18.82 | 42.75 |
| 3 | 48.60 | 14.04 | 14.02 | 35.51 |
| 4 | 48.76 | 11.57 | 12.40 | 33.88 |
| 5 | 27.27 | 5.45 | 3.64 | 12.73 |
| 6 | 39.39 | 9.09 | 9.09 | 18.18 |
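For reference, a hedged sketch of how such a bucketed accuracy can be computed, using a plain dynamic-programming Levenshtein distance (my own illustration; the notebooks may compute it differently):

from collections import defaultdict

def edit_distance(a, b):
    # standard Levenshtein distance via dynamic programming (single rolling row)
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (ca != cb))
    return dp[-1]

def accuracy_by_edit_distance(misspelled, corrected, predicted):
    # bucket exact-match accuracy by edit distance between misspelled query and gold spelling
    hits, totals = defaultdict(int), defaultdict(int)
    for wrong, right, pred in zip(misspelled, corrected, predicted):
        d = edit_distance(wrong, right)
        totals[d] += 1
        hits[d] += int(pred.strip().lower() == right.strip().lower())
    return {d: 100.0 * hits[d] / totals[d] for d in sorted(totals)}

# e.g. print(accuracy_by_edit_distance(test_misspelled, test_correct, predictions))  # hypothetical lists of strings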
References
1. BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
2. Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
3. Sequence to Sequence Learning with Neural Networks
4. Flan-T5: Scaling Instruction-Finetuned Language Models