An end-to-end skill that trains a 9.6M-parameter ESM-2 "mini" architecture on Swiss-Prot — from raw sequences to deployed model weights, zero-shot fitness evaluation, and GitHub release.
A compact Transformer encoder trained with masked language modeling (MLM) on protein sequences. The skill mirrors the ESM-2 "mini" configuration — 12 layers, hidden dimension 256, 8 attention heads — scaled down for single-card training while retaining meaningful zero-shot capabilities.
The skill automates every step: data download from UniProt, tokenizer construction, MLM training with three-tier checkpointing, and model upload to GitHub Release.
python scripts/download_data.py
data/swissprot_train.fasta (433,583 seqs) · data/swissprot_val.fasta (22,821 seqs)

python train.py \
--data data/swissprot_train.fasta \
--val_data data/swissprot_val.fasta \
--out_dir output \
--epochs 5 \
--batch_size 32 \
--device cuda
python scripts/evaluate_fitness.py \
--checkpoint output/checkpoint_final_best.pt
python scripts/upload_to_github.py \
--token ghp_xxxx \
--repo junior1p/ESM2-small \
--tag v1.0.0
gh release create v1.0.0 && gh release upload v1.0.0 output/checkpoint_final_best.pt

Trained on a single MLU370 accelerator for 67,500 steps (~11 hours). The model achieves competitive zero-shot fitness prediction on GFP mutations despite its compact size.
output/
├── config.json # training hyperparameters
├── checkpoint_epoch1.pt # epoch snapshots
├── checkpoint_epoch1_best.pt
├── checkpoint_epoch2.pt
├── checkpoint_epoch2_best.pt
├── checkpoint_epoch3.pt
├── checkpoint_epoch3_best.pt
├── checkpoint_epoch4.pt
├── checkpoint_epoch4_best.pt
├── checkpoint_epoch5.pt
├── checkpoint_final.pt # final epoch
├── checkpoint_final_best.pt # best val loss ← USE THIS
├── checkpoint_step2000.pt # periodic (every 2000 steps)
├── checkpoint_step4000.pt
└── ...
python train.py --resume output/checkpoint_step20000.pt --data data/swissprot_train.fasta --val_data data/swissprot_val.fasta

| Mutation | Type | Expected Delta | Interpretation |
|---|---|---|---|
| K7V | Neutral | ~0 | Conservative substitution, minimal structural impact |
| K7I | Neutral | Negative | Hydrophobic substitution at surface position |
| G66Y | Brighter | Large negative | Unexpected direction: aromatic at a small-residue site |
| G66H | Dimer | Large negative | Histidine at the dimer interface disrupts packing |
Re-pass --data and --val_data when resuming (the data loader is not saved in checkpoints): python train.py --resume output/checkpoint_*.pt --data data/swissprot_train.fasta --val_data data/swissprot_val.fasta
The complete executable skill file used by AI agents. Reproduces the full training pipeline from raw data to deployed model.
---
name: train-esm2-small
description: End-to-end training of ESM2-small (9.6M-parameter protein language model) on Swiss-Prot — data download, tokenization, training, checkpointing, evaluation, and model upload to GitHub. Works on GPU (CUDA) and Cambricon MLU370.
version: 1.0.0
author: Max
license: MIT
dependencies: [torch>=2.0, tqdm, requests]
metadata:
  hermes:
    tags: [protein language model, ESM-2, masked language modeling, MLU370, protein engineering, PyTorch, MLM]
    repo: https://github.com/junior1p/ESM2-small
---
# Train ESM2-small: Protein Language Model from Scratch
Train a compact 9.6M-parameter ESM-2 architecture on Swiss-Prot protein sequences end-to-end.
## When to Use This Skill
- Training a protein language model from scratch
- Evaluating zero-shot mutation prediction (fitness)
- Adapting the ESM-2 architecture to new protein datasets
- Setting up checkpointing and resume for long training runs
- Uploading trained models to GitHub Release
## Quick Start
```bash
# 1. Download data
python scripts/download_data.py

# 2. Train (GPU)
python train.py --data data/swissprot_train.fasta --val_data data/swissprot_val.fasta \
    --out_dir output --epochs 5 --batch_size 32 --device cuda

# 3. Evaluate zero-shot fitness
python scripts/evaluate_fitness.py --checkpoint output/checkpoint_final_best.pt
```
## Training Pipeline
### Data Download
```bash
python scripts/download_data.py
```
Downloads Swiss-Prot curated protein sequences from UniProt, splits into train/val, and saves FASTA files.
- Output: data/swissprot_train.fasta, data/swissprot_val.fasta
- Train: 433,583 sequences | Val: 22,821 sequences
- Truncation: sequences > 512 tokens are truncated at C-terminus
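The parsing-and-truncation step can be sketched as below. The 512-token C-terminal cutoff follows the description above; the function name and structure are illustrative, not the actual `scripts/download_data.py` API.

```python
def parse_fasta(text, max_len=512):
    """Parse FASTA text into (header, sequence) pairs,
    truncating sequences longer than max_len at the C-terminus."""
    records = []
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                records.append((header, "".join(chunks)[:max_len]))
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)[:max_len]))
    return records

fasta = """>sp|P0A7G6|RECA_ECOLI
MAIDENKQKALAAALGQIEK
>sp|TOY|LONG
""" + "A" * 600
recs = parse_fasta(fasta)
print(len(recs))        # 2 records
print(len(recs[1][1]))  # 512 (C-terminal residues beyond the cutoff dropped)
```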
### Tokenizer
31-token amino acid vocabulary:
- 20 standard amino acids (A, R, N, D, C, Q, E, G, H, I, L, K, M, F, P, S, T, W, Y, V)
- Special: [MASK], [PAD], [CLS], [SEP], [UNK]
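A minimal tokenizer consistent with this vocabulary might look like the sketch below. Note that the section lists 25 tokens (20 residues + 5 specials) against a vocabulary size of 31; the sketch assumes the remaining six ids are the non-standard residue codes B, Z, X, U, O, J, as in ESM-style vocabularies. Token ordering and method names are assumptions, not the repository's implementation.

```python
SPECIALS = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]
AMINO_ACIDS = list("ARNDCQEGHILKMFPSTWYV")   # 20 standard residues
NONSTANDARD = list("BZXUOJ")                  # assumption: ambiguity/rare-residue codes

VOCAB = SPECIALS + AMINO_ACIDS + NONSTANDARD  # 31 tokens total
TOKEN_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}

def encode(seq, max_len=512):
    """Map a protein sequence to token ids as [CLS] residues [SEP],
    truncating so the result never exceeds max_len ids."""
    ids = [TOKEN_TO_ID["[CLS]"]]
    for aa in seq[: max_len - 2]:  # leave room for [CLS] and [SEP]
        ids.append(TOKEN_TO_ID.get(aa, TOKEN_TO_ID["[UNK]"]))
    ids.append(TOKEN_TO_ID["[SEP]"])
    return ids

assert len(VOCAB) == 31
print(encode("MKV"))  # [CLS], M, K, V, [SEP] as integer ids
```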
### Architecture
ESM2-small mirrors ESM-2 "mini" (12-layer Transformer):
| Parameter | Value |
|---|---|
| Layers | 12 |
| Hidden dim | 256 |
| Attention heads | 8 |
| FFN dim | 1024 |
| Vocab size | 31 |
| Max length | 512 |
| Total params | 9,624,607 |
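As a sanity check, the stated total can be reproduced with a back-of-envelope count. The breakdown below is an assumption about the model's layout (learned positional embeddings, biases on all linears, two LayerNorms per layer, untied output projection with bias); it happens to match 9,624,607 exactly, but the real train.py may organize parameters differently.

```python
d, ffn, vocab, max_len, layers = 256, 1024, 31, 512, 12

tok_emb  = vocab * d                        # token embeddings: 7,936
pos_emb  = max_len * d                      # learned positions: 131,072 (assumption)
attn     = 4 * (d * d + d)                  # Q, K, V, O projections with bias
mlp      = (d * ffn + ffn) + (ffn * d + d)  # two-layer FFN with bias
norms    = 2 * (2 * d)                      # two LayerNorms (weight + bias) per layer
final_ln = 2 * d
lm_head  = d * vocab + vocab                # untied output projection (assumption)

per_layer = attn + mlp + norms
total = tok_emb + pos_emb + layers * per_layer + final_ln + lm_head
print(per_layer, total)  # 789760 9624607
```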
### Training Configuration
| Parameter | Default |
|---|---|
| Optimizer | AdamW (lr=1e-4, beta=(0.9, 0.999), eps=1e-8, weight_decay=0.01) |
| LR schedule | Linear warmup 1000 steps -> cosine decay to ~1e-9 |
| Batch size | 32 sequences |
| Masking | 15% uniform random |
| Mixed precision | FP32 weights, BF16 forward/backward (MLU370) |
| Epochs | 5 (~2h/epoch on single MLU370, ~1h/epoch on A100) |
| Steps per epoch | ~13,500 |
| Throughput | ~30K tokens/sec (single card) |
| Total tokens | ~1.1 billion |
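The warmup-then-cosine schedule from the table can be sketched as a plain function. The 1000-step warmup and ~67,500 total steps come from the tables above; the exact cosine form and the ~1e-9 floor are assumptions about train.py's implementation.

```python
import math

def lr_at(step, base_lr=1e-4, warmup=1000, total_steps=67_500, min_lr=1e-9):
    """Linear warmup for `warmup` steps, then cosine decay to ~min_lr."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    return min_lr + (base_lr - min_lr) * cosine

print(lr_at(0))       # ~1e-7 (start of warmup)
print(lr_at(999))     # 1e-4  (end of warmup)
print(lr_at(67_500))  # ~1e-9 (fully decayed)
```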
### Checkpointing
Three checkpoint types are saved automatically:
1. Periodic (every 2000 steps): full snapshot for resume
2. Epoch (end of each epoch): history checkpoint
3. Best (when val loss improves): *_best.pt copy
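The three tiers map onto the file names in the output tree; the helper below sketches that decision logic. The names match the listing above, but the function itself is illustrative, not the trainer's actual code.

```python
def checkpoint_names(step, epoch, is_epoch_end, val_improved,
                     is_final=False, periodic_every=2000):
    """Return the checkpoint file names to write at this point in training."""
    names = []
    if step % periodic_every == 0:      # tier 1: periodic resume point
        names.append(f"checkpoint_step{step}.pt")
    if is_epoch_end:                    # tier 2: epoch history snapshot
        tag = "final" if is_final else f"epoch{epoch}"
        names.append(f"checkpoint_{tag}.pt")
        if val_improved:                # tier 3: best-val-loss copy
            names.append(f"checkpoint_{tag}_best.pt")
    return names

print(checkpoint_names(step=4000, epoch=1, is_epoch_end=False, val_improved=False))
# ['checkpoint_step4000.pt']
```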
Checkpoint contents:
- model_state_dict: model weights
- optimizer_state_dict: AdamW state
- lr_scheduler_state_dict: cosine schedule position
- rng_state: torch/cuda/numpy RNG for reproducibility
- train_loss, val_loss, config
### Resume from Checkpoint
```bash
python train.py --resume output/checkpoint_step20000.pt --data data/swissprot_train.fasta --val_data data/swissprot_val.fasta
```
Resumes exact training state (model weights + optimizer + LR schedule + RNG seed).
## Evaluation: Zero-Shot Fitness Prediction
```bash
python scripts/evaluate_fitness.py --checkpoint output/checkpoint_final_best.pt
```
Uses masked language modeling logit difference for zero-shot variant effect prediction:
1. Encode wild-type (WT) sequence -> get per-position MLM logits
2. Encode mutant sequence -> get per-position MLM logits
3. Delta = Score(mutant) - Score(WT)
Positive Delta -> potentially beneficial mutation
Negative Delta -> potentially deleterious mutation
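The steps above can be sketched with a toy stand-in for the model. Here a sequence's score is assumed to be the sum of per-position MLM log-probabilities of its own residues; the toy logits and that score definition are illustrative assumptions, not the evaluate_fitness.py internals.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a {token: logit} dict."""
    m = max(logits.values())
    z = math.log(sum(math.exp(v - m) for v in logits.values())) + m
    return {aa: v - z for aa, v in logits.items()}

def sequence_score(seq, logits_fn):
    """Sum of per-position MLM log-probabilities of the sequence's own residues."""
    return sum(log_softmax(logits_fn(seq, pos))[aa] for pos, aa in enumerate(seq))

# Toy "model": strongly prefers the wild-type residue at every position.
WT = "MKG"
def toy_logits(seq, pos):
    return {aa: (4.0 if aa == WT[pos] else 0.0) for aa in "ACDEFGHIKLMNPQRSTVWY"}

mutant = "MKY"  # G->Y substitution at the last position
delta = sequence_score(mutant, toy_logits) - sequence_score(WT, toy_logits)
print(round(delta, 3))  # negative: the mutation looks deleterious to this toy model
```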
Tested on GFP (Green Fluorescent Protein) mutations:
| Mutation | Type | Expected Delta |
|---|---|---|
| K7V | neutral | ~0 |
| K7I | neutral | negative |
| G66Y | brighter | large negative |
| G66H | dimer | large negative |
## Model Upload to GitHub
After training, upload model weights to GitHub Release:
```bash
# Using GitHub CLI
gh release create v1.0.0 \
    --title "ESM2-small v1.0.0" \
    --notes "Trained on Swiss-Prot, val_loss=0.417"
gh release upload v1.0.0 output/checkpoint_final_best.pt

# Or using the upload script (pass a GitHub personal access token)
python scripts/upload_to_github.py \
    --token ghp_xxxx \
    --repo junior1p/ESM2-small \
    --tag v1.0.0
```
## Known Limitations
- GFP Spearman rho ~ 0.200 (small model, short training — larger models achieve 0.3-0.5)
- No MSA or structure features — pure MLM only
- Single card training — multi-card scaling not included
- No dropout — model is meant for transfer/fine-tuning
## Pitfalls
- Resume requires --data flag re-passed: data loader not saved in checkpoints
- Truncated sequences: sequences > 512 tokens are cut at C-terminus
- Val loss > Train loss: normal for protein MLM
- MLU370 BF16: fallback to FP32 if CNNL BF16 ops fail
- Checkpoint disk space: ~5GB for full training run
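The ~5GB figure checks out from the checkpoint layout: each full snapshot stores FP32 weights plus the two AdamW moment buffers. The arithmetic below is a rough estimate; serialization overhead and the exact file count are approximations.

```python
params = 9_624_607
bytes_per_snapshot = params * 4 * 3   # FP32 weights + AdamW exp_avg + exp_avg_sq
mb = bytes_per_snapshot / 1e6         # ~115 MB per full checkpoint

# Over 5 epochs: ~33 periodic snapshots (every 2000 of 67,500 steps),
# 5 epoch snapshots, up to 5 best copies, plus final/final_best.
n_checkpoints = 67_500 // 2000 + 5 + 5 + 2
total_gb = n_checkpoints * mb / 1000
print(round(mb), n_checkpoints, round(total_gb, 1))  # ~115 MB each, ~45 files, ~5 GB
```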
Full reproducibility in three commands. The skill handles everything from data download to model upload.
```bash
# Clone the repository
git clone https://github.com/junior1p/ESM2-small.git
cd ESM2-small

# Download Swiss-Prot data (~433K train + 22K val sequences)
python scripts/download_data.py

# Train on GPU — ~11 hours on single MLU370, ~5h on A100
python train.py \
    --data data/swissprot_train.fasta \
    --val_data data/swissprot_val.fasta \
    --out_dir output \
    --epochs 5 \
    --batch_size 32 \
    --device cuda

# Evaluate zero-shot fitness on GFP mutations
python scripts/evaluate_fitness.py \
    --checkpoint output/checkpoint_final_best.pt
```