NanoFormer

NanoFormer is a lightweight transformer implementation designed for efficient training and inference. It collects several transformer architectures (and a few variants) that can be easily experimented with.

Features

  • Supports transformer variants selectable via --attention_type ("gqa", "diff" or "ngpt")
  • Dynamic batch size handling with efficient padding
  • Mixed precision training (bfloat16)
  • Gradient checkpointing for memory efficiency
  • Gradient accumulation support (see the training-step sketch after this list)
  • Wandb integration for experiment tracking
  • Automatic model checkpointing
  • Custom training loop with validation
  • Torch compile support
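
Roughly, these training features combine into a step like the sketch below. The model, data and loss here are stand-ins rather than NanoFormer's actual classes; the snippet only illustrates how bfloat16 autocast, gradient accumulation and torch.compile fit together (it assumes a CUDA GPU and PyTorch 2.x).

import torch
import torch.nn as nn

# Stand-in model and data; train.py builds the real ones from the CLI args.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 50257)).cuda()
model = torch.compile(model)                        # what --compile enables
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_fn = nn.CrossEntropyLoss()
accum_steps = 16                                    # --gradient_accumulation_steps

for step in range(64):
    x = torch.randn(8, 256, device="cuda")          # stand-in for a tokenised batch
    y = torch.randint(0, 50257, (8,), device="cuda")

    # bfloat16 mixed precision for the forward pass
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = loss_fn(model(x), y) / accum_steps   # scale the loss for accumulation
    loss.backward()

    # Step the optimiser only every accum_steps micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)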

Installation

git clone https://github.com/Datta0/nanoformer.git
cd nanoformer

Usage

Training

To train the model with default parameters:

python train.py \
    --dataset "imdatta0/wikipedia_en_sample" \
    --batch_size 8 \
    --gradient_accumulation_steps 16 \
    --num_epochs 1 \
    --lr 5e-4 \
    --hidden_dim 256 \
    --num_hidden_layers 8 \
    --attention_type "gqa" \
    --compile
    # --attention_type can be one of "gqa", "diff" or "ngpt"
    # optional: --logit_cap (default None) and --logit_scale (default 1.0)
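
The --attention_type flag selects the attention variant; "gqa" refers to grouped-query attention, where groups of query heads share a smaller set of key/value heads. The sketch below shows the idea in plain PyTorch; it is illustrative only and not NanoFormer's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GQASketch(nn.Module):
    """Minimal grouped-query attention: several query heads share each KV head.
    Illustrative only, not NanoFormer's actual code."""
    def __init__(self, hidden_dim=256, num_heads=8, num_kv_heads=2):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.head_dim = hidden_dim // num_heads
        self.q_proj = nn.Linear(hidden_dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_dim, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_dim, bias=False)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head so a whole group of query heads attends to the same K/V
        groups = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(groups, dim=1)
        v = v.repeat_interleave(groups, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, t, -1))

out = GQASketch()(torch.randn(2, 16, 256))   # -> shape (2, 16, 256)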

To estimate the number of tokens in a dataset and the model's parameter count for a given config (this will need refactoring so the model isn't created just for the estimate):

python train.py \
    --dataset "imdatta0/wikipedia_en_sample" \
    --batch_size 8 \
    --gradient_accumulation_steps 16 \
    --num_epochs 1 \
    --lr 5e-4 \
    --hidden_dim 256 \
    --num_hidden_layers 8 \
    --estimate
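
In PyTorch, a model's parameter count is simply the total number of elements across its parameters, which is essentially what the --estimate report boils down to. For reference, the standard way to compute such a count looks like the snippet below (illustrative, not the repo's exact code).

import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Total number of trainable parameters in a PyTorch module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Stand-in module; NanoFormer would pass its own model here.
print(count_parameters(nn.Linear(256, 256)))   # 65792 = 256*256 weights + 256 biases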

Note: If you use --compile, the first few batches can take a while to start while the model compiles, so the tqdm ETA estimates might be off initially. GPU utilisation picks up after a few batches and training speeds up, so please be patient.

TODO

  • Implement Differential Transformer
  • Implement nGPT
  • Implement custom optimisers like Shampoo, SOAP and whatnot
  • Add support for Sliding Window Attention
  • Modify configs to be closer to Chinchilla Optimal Ratios

WIP

  • Inference support
  • Loading from checkpoint
