# Training Data Preparation
The most common cause of poor fine-tuning results isn't the model or the hyperparameters — it's the data. Garbage in, garbage out applies with extreme force to LLMs: the model will learn exactly what your data shows, including its inconsistencies, errors, and biases.
## Learning objectives
- Format datasets in JSONL for SFT
- Apply label masking so that the loss is computed only on completion tokens
- Assess data quality: diversity, consistency, length distribution
- Handle class imbalance and deduplicate examples
- Know how much data you need for common task types
## The JSONL format
The standard format for SFT datasets is JSONL (JSON Lines): each line of the file is one self-contained JSON object holding one training example:
import json
# Format 1: Alpaca-style (instruction, input, output)
# Each record is a flat dict: "instruction" states the task, "input" carries
# the text to operate on, and "output" is the target completion.
alpaca_examples = [
    {
        "instruction": "Classify the sentiment of this review.",
        "input": "The product broke after two days. Terrible quality.",
        "output": "negative"
    },
    {
        "instruction": "Classify the sentiment of this review.",
        "input": "Amazing! Exactly what I needed.",
        "output": "positive"
    },
]
# Format 2: ShareGPT-style (messages list — preferred for chat models)
# Each record holds a "messages" list of {"role", "content"} dicts in
# conversation order; here the assistant message is the training target.
sharegpt_examples = [
    {
        "messages": [
            {"role": "system", "content": "You are a sentiment classifier. Respond with: positive, negative, or neutral."},
            {"role": "user", "content": "The product broke after two days. Terrible quality."},
            {"role": "assistant", "content": "negative"}
        ]
    },
]
# Write to JSONL: one JSON object per line.
# The encoding is pinned to UTF-8 so the file does not depend on the platform's
# locale default, and ensure_ascii=False keeps non-English text readable in the
# file instead of being written as \uXXXX escapes.
with open("train.jsonl", "w", encoding="utf-8") as f:
    for ex in sharegpt_examples:
        f.write(json.dumps(ex, ensure_ascii=False) + "\n")
## Label masking: only train on the completion
In instruction fine-tuning, you don't want the model to learn to predict the prompt — only the response. Label masking sets the labels of the prompt tokens to -100, the index that PyTorch's cross-entropy loss ignores, so only the response tokens contribute to the loss.
from transformers import AutoTokenizer
import torch
# Load the tokenizer used to format and tokenize the training examples.
# NOTE(review): format_and_mask() calls tokenizer.apply_chat_template — confirm
# this checkpoint ships a chat template, or set tokenizer.chat_template first.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
# Reuse the end-of-sequence token as the padding token.
# presumably the checkpoint defines no pad token of its own — TODO confirm
tokenizer.pad_token = tokenizer.eos_token
def format_and_mask(example: dict) -> dict:
    """
    Format a messages-style example and mask prompt tokens in labels.

    Only the final assistant turn contributes to the loss. Everything before
    it (system prompt, user turns, earlier assistant turns and the generation
    prompt) is labelled -100. Messages after the final assistant turn are
    dropped, because they carry nothing to train on.

    Args:
        example: Dict with a "messages" list of {"role", "content"} dicts
            containing at least one assistant message.

    Returns:
        Dict with 1-D tensors "input_ids", "labels" and "attention_mask",
        all of the same length.

    Raises:
        ValueError: If the example contains no assistant message.
    """
    messages = example["messages"]
    assistant_positions = [
        i for i, m in enumerate(messages) if m["role"] == "assistant"
    ]
    if not assistant_positions:
        # Without a response every label would be -100 and the example would
        # silently contribute nothing (or a NaN loss) to training.
        raise ValueError("example has no assistant message to train on")

    # Conversation up to and including the response being trained on.
    conversation = messages[: assistant_positions[-1] + 1]
    # The prompt is everything *before* that response. Filtering out every
    # assistant message instead would drop earlier assistant turns of a
    # multi-turn chat, so the prompt would no longer be a prefix of the full
    # text and the mask boundary would land in the wrong place.
    prompt_messages = conversation[:-1]

    # Format full conversation including the response
    full_text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
    # Format prompt only (no response), ending with the generation prompt
    prompt_text = tokenizer.apply_chat_template(
        prompt_messages, tokenize=False, add_generation_prompt=True
    )

    # Tokenize both. The chat template already emits any special tokens, so
    # the tokenizer must not add them a second time.
    full_tokens = tokenizer(
        full_text, return_tensors="pt", add_special_tokens=False
    )["input_ids"][0]
    prompt_tokens = tokenizer(
        prompt_text, return_tensors="pt", add_special_tokens=False
    )["input_ids"][0]

    # Build labels: -100 for prompt, actual IDs for completion.
    # NOTE: this assumes the tokenized prompt is a prefix of the tokenized
    # full text, i.e. no token merges across the prompt/response boundary.
    labels = full_tokens.clone()
    labels[:len(prompt_tokens)] = -100

    return {
        "input_ids": full_tokens,
        "labels": labels,
        "attention_mask": torch.ones_like(full_tokens),
    }
# Smoke test: a single-turn exchange whose answer is the only unmasked part.
example = dict(
    messages=[
        dict(role="user", content="What is 2+2?"),
        dict(role="assistant", content="4"),
    ]
)
result = format_and_mask(example)
print("input_ids length:", len(result["input_ids"]))
# Keep only the label positions that were not masked out with -100.
print(
    "labels (non-masked portion):",
    result["labels"][result["labels"].ne(-100)],
)
Tip: SFTTrainer can handle masking for you
If you train with TRL's SFTTrainer and pass it a DataCollatorForCompletionOnlyLM, the collator masks the prompt tokens automatically. Implement manual masking only if you are writing a custom training loop.
## Data quality checklist
from datasets import Dataset
from collections import Counter
import re
def audit_dataset(examples: list[dict]) -> dict:
    """Quick quality audit for SFT datasets.

    Accepts both messages-style examples (the first assistant message is the
    output) and flat examples with an "output" field. A missing, ``None`` or
    whitespace-only output counts as empty and is left out of the duplicate
    and length statistics.

    Args:
        examples: Training examples in either format.

    Returns:
        Dict with "total", "empty_outputs", "duplicates" (outputs whose
        stripped, lower-cased text was already seen) and the average, minimum
        and maximum output length in whitespace-separated words. Length
        statistics are 0 when there is no non-empty output.
    """
    lengths = []
    seen = set()
    duplicates = 0
    empty_outputs = 0

    for ex in examples:
        # Get output text
        if "messages" in ex:
            output = next(
                (
                    m.get("content")
                    for m in ex["messages"]
                    if m.get("role") == "assistant"
                ),
                "",
            )
        else:
            output = ex.get("output", "")

        # A field that is present but None (e.g. an assistant turn with no
        # text) is an empty output, not a reason to crash on .strip().
        text = "" if output is None else output.strip()
        if not text:
            empty_outputs += 1
            continue

        # Check for duplicates (case-insensitive, ignoring outer whitespace)
        key = text.lower()
        if key in seen:
            duplicates += 1
        seen.add(key)
        lengths.append(len(text.split()))

    avg_len = sum(lengths) / len(lengths) if lengths else 0
    return {
        "total": len(examples),
        "empty_outputs": empty_outputs,
        "duplicates": duplicates,
        "avg_output_words": round(avg_len, 1),
        "min_output_words": min(lengths) if lengths else 0,
        "max_output_words": max(lengths) if lengths else 0,
    }
# Sample audit: three single-turn chats built from (question, answer) pairs.
examples = [
    {"messages": [{"role": "user", "content": question}, {"role": "assistant", "content": answer}]}
    for question, answer in [
        ("Q1", "Answer one"),
        ("Q2", "Answer one"),  # duplicate output
        ("Q3", ""),  # empty
    ]
]
report = audit_dataset(examples)
print(report)
# {'total': 3, 'empty_outputs': 1, 'duplicates': 1, 'avg_output_words': 2.0,
#  'min_output_words': 2, 'max_output_words': 2}
## How much data do you need?
# Dataset-size guidance per task type. Each entry maps a task name to three
# human-readable strings: "minimum" (smallest workable dataset), "good"
# (recommended size) and "notes" (what matters most for that task).
DATA_REQUIREMENTS = {
    "Classification (2–5 classes)": {
        "minimum": "50–100 examples per class",
        "good": "200–500 per class",
        "notes": "Balance classes within 2:1 ratio"
    },
    "Named entity extraction": {
        "minimum": "200–500 diverse sentences",
        "good": "1,000–5,000",
        "notes": "Entity diversity matters more than raw count"
    },
    "Style / tone transfer": {
        "minimum": "100–300 before/after pairs",
        "good": "500–1,000",
        "notes": "Quality >> quantity; inconsistent pairs hurt badly"
    },
    "Domain-specific Q&A": {
        "minimum": "500–1,000 Q&A pairs",
        "good": "2,000–10,000",
        "notes": "Cover edge cases, not just common queries"
    },
    "Code generation": {
        "minimum": "1,000 function-level examples",
        "good": "10,000+",
        "notes": "Include tests; code must be executable"
    },
}
# Print each task type followed by its three guidance lines.
for task, reqs in DATA_REQUIREMENTS.items():
    print(
        f"\n{task}",
        f" Minimum: {reqs['minimum']}",
        f" Good: {reqs['good']}",
        f" Note: {reqs['notes']}",
        sep="\n",
    )
## Handling class imbalance
from datasets import Dataset, concatenate_datasets
from collections import Counter
def balance_dataset(examples: list[dict], label_field: str = "output", max_ratio: float = 3.0) -> list[dict]:
    """Downsample majority classes to at most max_ratio × the minority class size.

    Examples are grouped by ``example[label_field]``. Every class keeps its
    first ``int(min_count * max_ratio)`` examples in input order, where
    ``min_count`` is the size of the smallest class. No shuffling is done, so
    shuffle beforehand if the input is sorted. The class counts before and
    after balancing are printed.

    Args:
        examples: Labelled examples.
        label_field: Key holding each example's class label.
        max_ratio: Largest allowed ratio between any class and the smallest one.

    Returns:
        The balanced examples, grouped by label in first-seen order. An empty
        input returns an empty list.

    Raises:
        KeyError: If an example has no ``label_field`` key.
    """
    if not examples:
        # min() over zero classes would raise ValueError; there is nothing to balance.
        return []

    by_label: dict[str, list] = {}
    for ex in examples:
        by_label.setdefault(ex[label_field], []).append(ex)

    min_count = min(len(items) for items in by_label.values())
    max_allowed = int(min_count * max_ratio)

    balanced = [
        ex
        for items in by_label.values()
        for ex in items[:max_allowed]
    ]

    print(f"Original: {Counter(ex[label_field] for ex in examples)}")
    print(f"Balanced: {Counter(ex[label_field] for ex in balanced)}")
    return balanced
# Example: a heavily skewed three-class dataset (1000 / 200 / 50).
raw = [{"output": "positive"}] * 1000
raw += [{"output": "negative"}] * 200
raw += [{"output": "neutral"}] * 50
balanced = balance_dataset(raw, max_ratio=2.0)
## Generating synthetic training data
When labeled data is scarce, you can use a strong model to generate synthetic examples:
import os
import json
from openai import OpenAI
# API client shared by the generation helper below. The key is read from the
# OPENAI_API_KEY environment variable (os.getenv returns None when it is unset).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def generate_synthetic_examples(task_description: str, n: int = 20) -> list[dict]:
    """Ask the chat model for synthetic input/output pairs for a task.

    The request uses JSON mode, which always yields a JSON *object*, so the
    prompt names the exact wrapper key ("examples") to look for. Asking for a
    bare array would leave the key up to the model and make the result
    unpredictable.

    Args:
        task_description: Plain-language description of the task.
        n: Number of examples to request.

    Returns:
        The generated examples as a list of {"input", "output"} dicts. The
        list is empty if the response has no text or contains no list.

    Raises:
        json.JSONDecodeError: If the response text is not valid JSON.
    """
    prompt = (
        f"Generate {n} diverse training examples for: {task_description}\n"
        'Return a JSON object of the form: {"examples": [{"input": "...", "output": "..."}]}\n'
        "Ensure diverse inputs covering edge cases."
    )
    resp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.9,
        response_format={"type": "json_object"},
    )

    content = resp.choices[0].message.content
    if not content:
        # No text came back, so there is nothing to parse.
        return []

    data = json.loads(content)
    if isinstance(data, list):
        return data

    # Prefer the requested key and keep "data" as a fallback.
    for key in ("examples", "data"):
        if isinstance(data.get(key), list):
            return data[key]

    # Last resort: the model used a different wrapper key; take the first list.
    return next((value for value in data.values() if isinstance(value, list)), [])
# Request ten examples for a four-way ticket-routing task.
examples = generate_synthetic_examples(
    task_description="Classify customer support tickets as: billing, technical, account, or general",
    n=10,
)
# Preview the first three pairs, truncating long inputs to 60 characters.
for ex in examples[:3]:
    print(f"Input: {ex['input'][:60]}", f"Output: {ex['output']}\n", sep="\n")
Warning: verify synthetic data quality
Synthetic data from GPT-4o is usually high quality but tends to be homogeneous: the model generates the "typical" example rather than the edge case. Always review a sample manually and supplement it with real edge cases.