mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-02-04 05:03:11 +01:00
QLoRA stuff + datasets
This commit is contained in:
311
qlora/finetune_mistral_bali_qlora.py
Normal file
311
qlora/finetune_mistral_bali_qlora.py
Normal file
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
QLoRA SFT fine-tune for Mistral-7B on chat-style JSONL:
|
||||
Each line: {"messages": [{"role":"system","content":...}, {"role":"user","content":...}, {"role":"assistant","content":...}, ...]}
|
||||
|
||||
Produces a LoRA adapter you can merge or load at inference time.
|
||||
|
||||
Example:
|
||||
python finetune_mistral_bali_qlora.py \
|
||||
--model_id mistralai/Mistral-7B-Instruct-v0.2 \
|
||||
--train_jsonl /path/to/bali_train.jsonl \
|
||||
--output_dir ./mistral-bali-lora \
|
||||
--max_seq_len 2048 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--learning_rate 2e-4 \
|
||||
--num_train_epochs 2 \
|
||||
--streaming true
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
TrainingArguments,
|
||||
)
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Data formatting
|
||||
# -----------------------------
|
||||
def normalize_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Ensures message roles/content are well-formed and in allowed roles.
|
||||
"""
|
||||
allowed = {"system", "user", "assistant"}
|
||||
out = []
|
||||
for m in messages:
|
||||
role = (m.get("role") or "").strip().lower()
|
||||
content = m.get("content")
|
||||
if role not in allowed or content is None:
|
||||
continue
|
||||
content = str(content)
|
||||
out.append({"role": role, "content": content})
|
||||
return out
|
||||
|
||||
|
||||
def messages_to_text(tokenizer: AutoTokenizer, example: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Converts {"messages":[...]} to a single training text using the model's chat template if available.
|
||||
For Mistral Instruct models, tokenizer.apply_chat_template is typically present.
|
||||
"""
|
||||
messages = normalize_messages(example.get("messages", []))
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Prefer tokenizer chat template when available.
|
||||
if (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
# add_generation_prompt=False -> include the assistant content in the formatted text
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
|
||||
# Fallback formatting (less ideal than the native template):
|
||||
# Keep it deterministic and simple.
|
||||
parts = []
|
||||
for m in messages:
|
||||
r = m["role"]
|
||||
c = m["content"].strip()
|
||||
if r == "system":
|
||||
# Override system message
|
||||
system_message = """
|
||||
You are a specialized Balinese cultural travel expert. Your role is to provide accurate, culturally grounded, and practical guidance for travelers engaging with Balinese culture, including temples, ceremonies, etiquette, ritual calendars, dance, crafts, village life, sacred landscapes, and historical–spiritual context.
|
||||
|
||||
Prioritize cultural meaning and lived practice over sightseeing. Explain why places, rituals, and customs matter, and how visitors should behave respectfully. Emphasize dress codes, offerings, bodily conduct, photography rules, gender and purity considerations, and community norms.
|
||||
|
||||
Integrate timing and context where relevant, including ceremonial cycles (Pawukon/Wuku, full and new moons), festival periods, tides, agricultural rhythms, and temple schedules. Promote responsible tourism, community benefit, and environmental care, and discourage entry into restricted or sacred spaces.
|
||||
|
||||
Go beyond generic tips by naming specific temples, villages, regions, ceremonies, deities, and regional variations. Include practical logistics (access, hours, customary donations, crowd patterns) when helpful, without speculation. If uncertain, state this briefly and suggest local confirmation.
|
||||
|
||||
Structure responses clearly: a brief contextual introduction, followed by well-labeled sections or bullet points, and a short “Essentials” or “Respect Checklist” summary.
|
||||
|
||||
Do not include chain-of-thought, hidden reasoning, or meta commentary. Provide only polished, user-facing guidance in a calm, authoritative, and respectful tone.
|
||||
"""
|
||||
parts.append(f"<<SYS>>\n{system_message}\n<</SYS>>\n")
|
||||
elif r == "user":
|
||||
parts.append(f"[USER]\n{c}\n")
|
||||
else:
|
||||
parts.append(f"[ASSISTANT]\n{c}\n")
|
||||
return "\n".join(parts).strip() + "\n"
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Main
|
||||
# -----------------------------
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="e.g. mistralai/Mistral-7B-Instruct-v0.2 (recommended) or base model",
|
||||
)
|
||||
p.add_argument(
|
||||
"--train_jsonl",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to JSONL training file; each line has a 'messages' list.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--eval_jsonl",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional eval JSONL with same format.",
|
||||
)
|
||||
p.add_argument("--output_dir", type=str, required=True)
|
||||
|
||||
# Training hyperparameters
|
||||
p.add_argument("--max_seq_len", type=int, default=2048)
|
||||
p.add_argument("--per_device_train_batch_size", type=int, default=1)
|
||||
p.add_argument("--per_device_eval_batch_size", type=int, default=1)
|
||||
p.add_argument("--gradient_accumulation_steps", type=int, default=16)
|
||||
p.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
p.add_argument("--weight_decay", type=float, default=0.0)
|
||||
p.add_argument("--num_train_epochs", type=float, default=1.0)
|
||||
p.add_argument("--warmup_ratio", type=float, default=0.03)
|
||||
p.add_argument("--logging_steps", type=int, default=10)
|
||||
p.add_argument("--save_steps", type=int, default=200)
|
||||
p.add_argument("--eval_steps", type=int, default=200)
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
|
||||
# Performance / memory
|
||||
p.add_argument(
|
||||
"--streaming",
|
||||
type=str,
|
||||
default="true",
|
||||
help="true/false. Use streaming for very large JSONL.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--bf16",
|
||||
type=str,
|
||||
default="true",
|
||||
help="true/false. Prefer bf16 if your GPU supports it.",
|
||||
)
|
||||
p.add_argument("--gradient_checkpointing", type=str, default="true")
|
||||
|
||||
# LoRA config
|
||||
p.add_argument("--lora_r", type=int, default=16)
|
||||
p.add_argument("--lora_alpha", type=int, default=32)
|
||||
p.add_argument("--lora_dropout", type=float, default=0.05)
|
||||
p.add_argument(
|
||||
"--target_modules",
|
||||
type=str,
|
||||
default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
|
||||
help="Comma-separated module names for Mistral-style architectures.",
|
||||
)
|
||||
|
||||
# Optional: limit samples for quick smoke tests
|
||||
p.add_argument("--max_train_samples", type=int, default=None)
|
||||
p.add_argument("--max_eval_samples", type=int, default=None)
|
||||
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def str2bool(x: str) -> bool:
|
||||
return str(x).strip().lower() in {"1", "true", "yes", "y", "t"}
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
streaming = str2bool(args.streaming)
|
||||
use_bf16 = str2bool(args.bf16)
|
||||
use_gc = str2bool(args.gradient_checkpointing)
|
||||
|
||||
# -----------------------------
|
||||
# Tokenizer
|
||||
# -----------------------------
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True)
|
||||
if tokenizer.pad_token is None:
|
||||
# Common for causal LMs
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# -----------------------------
|
||||
# Model (4-bit QLoRA)
|
||||
# -----------------------------
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
|
||||
)
|
||||
|
||||
model.config.use_cache = False # important for training
|
||||
if use_gc:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare for k-bit + LoRA
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
target_modules = [m.strip() for m in args.target_modules.split(",") if m.strip()]
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# -----------------------------
|
||||
# Dataset
|
||||
# -----------------------------
|
||||
data_files = {"train": args.train_jsonl}
|
||||
if args.eval_jsonl:
|
||||
data_files["eval"] = args.eval_jsonl
|
||||
|
||||
ds = load_dataset("json", data_files=data_files, streaming=streaming)
|
||||
|
||||
def format_fn(example: Dict[str, Any]) -> Dict[str, str]:
|
||||
text = messages_to_text(tokenizer, example)
|
||||
return {"text": text}
|
||||
|
||||
train_ds = ds["train"].map(format_fn)
|
||||
eval_ds = ds["eval"].map(format_fn) if args.eval_jsonl else None
|
||||
|
||||
# Optional sample limits (works differently for streaming vs non-streaming)
|
||||
if args.max_train_samples is not None:
|
||||
if streaming:
|
||||
train_ds = train_ds.take(args.max_train_samples)
|
||||
else:
|
||||
train_ds = train_ds.select(
|
||||
range(min(args.max_train_samples, len(train_ds)))
|
||||
)
|
||||
|
||||
if eval_ds is not None and args.max_eval_samples is not None:
|
||||
if streaming:
|
||||
eval_ds = eval_ds.take(args.max_eval_samples)
|
||||
else:
|
||||
eval_ds = eval_ds.select(range(min(args.max_eval_samples, len(eval_ds))))
|
||||
|
||||
# -----------------------------
|
||||
# Training
|
||||
# -----------------------------
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
logging_steps=args.logging_steps,
|
||||
save_steps=args.save_steps,
|
||||
eval_strategy="steps" if eval_ds is not None else "no",
|
||||
eval_steps=args.eval_steps if eval_ds is not None else None,
|
||||
save_total_limit=3,
|
||||
bf16=use_bf16,
|
||||
fp16=not use_bf16,
|
||||
optim="paged_adamw_8bit", # good default for QLoRA
|
||||
lr_scheduler_type="cosine",
|
||||
seed=args.seed,
|
||||
report_to="none",
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
train_dataset=train_ds,
|
||||
eval_dataset=eval_ds,
|
||||
packing=True, # packs multiple conversations per sequence for higher throughput
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save LoRA adapter + tokenizer
|
||||
trainer.model.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
print(f"\nDone. Saved LoRA adapter to: {args.output_dir}")
|
||||
print("Inference: load base model + peft adapter from this directory.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
10
qlora/run.sh
Normal file
10
qlora/run.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
python finetune_mistral_bali_qlora.py \
|
||||
--model_id mistralai/Mistral-7B-Instruct-v0.2 \
|
||||
--train_jsonl ../raft/bali_culture_raft_dataset.jsonl \
|
||||
--output_dir ./mistral-bali-lora \
|
||||
--max_seq_len 2048 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--learning_rate 2e-4 \
|
||||
--num_train_epochs 2 \
|
||||
--streaming true
|
||||
Reference in New Issue
Block a user