mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 08:22:43 +01:00
New RAFT approach
New file: raft/train_mistral_raft.py (+100 lines)
@@ -0,0 +1,100 @@
import argparse
import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

## Usage: python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mistral_balitwin_lora


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
    ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--max_seq_len", type=int, default=2048)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--epochs", type=int, default=1)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # bf16 needs an Ampere-or-newer GPU; fall back to fp16 elsewhere.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    # QLoRA (4-bit) config (good default for 7B on limited VRAM)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    # Mistral ships a valid chat template; keep it intact.
    if tokenizer.pad_token is None:
        # Mistral defines no pad token; reuse EOS so padded batches work.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    )

    # LoRA adapter config (tweak r/alpha if needed); targets all attention
    # and MLP projection matrices of the Mistral architecture.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    dataset = load_dataset("json", data_files=args.train_jsonl, split="train")
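    # Assumed record format (not pinned down in this commit): one JSON object
    # per line, using the chat "messages" schema that SFTTrainer and
    # assistant_only_loss expect, e.g.
    #   {"messages": [{"role": "user", "content": "<question + retrieved docs>"},
    #                 {"role": "assistant", "content": "<grounded answer>"}]}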

    training_args = SFTConfig(
        output_dir=args.out_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        max_length=args.max_seq_len,
        bf16=use_bf16,
        fp16=torch.cuda.is_available() and not use_bf16,
        # Only learn from assistant turns; needs a chat template with
        # {% generation %} markers so TRL can build the assistant token mask.
        assistant_only_loss=True,
        report_to=[],
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()
    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)

    print(f"Saved LoRA adapter to: {args.out_dir}")


if __name__ == "__main__":
    main()
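
A minimal inference sketch for sanity-checking the result (not part of this
commit; the prompt text is a placeholder and the paths simply mirror the
script's defaults):

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    # Load the 4-bit base model and attach the trained LoRA adapter.
    base = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.3",
        device_map="auto",
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4"
        ),
    )
    model = PeftModel.from_pretrained(base, "out/mistral_balitwin_lora")
    tokenizer = AutoTokenizer.from_pretrained("out/mistral_balitwin_lora")

    # Build a chat-formatted prompt and generate.
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": "<your RAFT-style question here>"}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    out = model.generate(inputs, max_new_tokens=256)
    print(tokenizer.decode(out[0], skip_special_tokens=True))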