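"""Fine-tune Mistral-7B-Instruct with QLoRA on a RAFT-style JSONL training set.

Loads the base model in 4-bit (NF4), attaches LoRA adapters to the attention
and MLP projection layers, and runs supervised fine-tuning via TRL's
SFTTrainer. The trained adapter and tokenizer are saved to --out_dir.
"""
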
import argparse
import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--max_seq_len", type=int, default=2048)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--epochs", type=int, default=1)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # QLoRA (4-bit) config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=(
            torch.bfloat16 if torch.cuda.is_available() else torch.float16
        ),
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    # Mistral's tokenizer ships without a pad token; reuse EOS so padded
    # batches don't fail in the data collator.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
    )

    # LoRA adapter config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    # SFTTrainer expects a "text" column (or prompt/completion resp.
    # "messages" columns) in the JSONL records.
    dataset = load_dataset("json", data_files=args.train_jsonl, split="train")

    # Effective batch size = batch_size * grad_accum (1 * 16 = 16 by default).
    training_args = SFTConfig(
        output_dir=args.out_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        max_length=args.max_seq_len,
        bf16=torch.cuda.is_available(),
        fp16=not torch.cuda.is_available(),
        report_to=[],
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainer.train()

    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)
    print(f"Done! LoRA adapter saved to: {args.out_dir}")


if __name__ == "__main__":
    main()
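
# Usage sketch (defaults shown above; adjust paths to your setup):
#   python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --epochs 1
#
# One way to load the trained adapter for inference (untested sketch; names
# match the defaults above) is peft's PeftModel:
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(
#       "mistralai/Mistral-7B-Instruct-v0.2",
#       quantization_config=bnb_config,  # same 4-bit config as in main()
#       device_map="auto",
#   )
#   model = PeftModel.from_pretrained(base, "out/mistral_balitwin_lora")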