import argparse
import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
    ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--max_seq_len", type=int, default=2048)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--epochs", type=int, default=1)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # Prefer bf16 only where the GPU actually supports it (Ampere+);
    # otherwise fall back to fp16 on CUDA, and avoid half precision on CPU.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    # QLoRA (4-bit) config (good default for 7B on limited VRAM)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    # Mistral ships a valid chat template; keep it intact.
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=compute_dtype,
    )

    # LoRA adapter config (tweak r/alpha if needed)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    dataset = load_dataset("json", data_files=args.train_jsonl, split="train")

    training_args = SFTConfig(
        output_dir=args.out_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        max_length=args.max_seq_len,
        bf16=use_bf16,
        fp16=torch.cuda.is_available() and not use_bf16,
        assistant_only_loss=True,  # only learn from assistant turns in `messages`
        report_to=[],
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainer.train()

    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)
    print(f"Saved LoRA adapter to: {args.out_dir}")


if __name__ == "__main__":
    main()
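

# ---------------------------------------------------------------------------
# Inference sketch (defined but never called by this script): reload the saved
# adapter on top of the base model. This is a minimal sketch assuming the
# default paths above; the helper name is hypothetical, but
# `PeftModel.from_pretrained` and `merge_and_unload` are standard peft APIs.
# ---------------------------------------------------------------------------
def load_adapter_for_inference(
    base_model: str = "mistralai/Mistral-7B-Instruct-v0.3",
    adapter_dir: str = "out/mistral_balitwin_lora",
):
    """Reload the trained LoRA adapter and merge it into the base weights."""
    from peft import PeftModel  # local import: only needed at inference time

    tokenizer = AutoTokenizer.from_pretrained(adapter_dir, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    )
    model = PeftModel.from_pretrained(base, adapter_dir)
    # Fold the LoRA deltas into the base weights so generation needs no peft hooks.
    model = model.merge_and_unload()
    return tokenizer, model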