{
"cells": [
{
"cell_type": "markdown",
"id": "2957fcef",
"metadata": {},
"source": [
"\n",
"# RAFT Supervised Fine-Tuning (QLoRA) — Local Training\n",
"\n",
"This notebook fine-tunes an open-source base model on a RAFT-style dataset (`input` → `output`) using **QLoRA** with **PEFT** and **Transformers**. It is designed to run locally (single or multi-GPU) and to export both **LoRA adapters** and (optionally) a **merged** model for inference.\n",
"\n",
"> **Assumptions**\n",
"> - Your dataset lives at `./outputs/raft_dataset.jsonl` (from the previous notebook). Adjust the path if needed.\n",
"> - You have a CUDA-capable GPU and can install `bitsandbytes`. (CPU training is possible but slow.)\n",
"> - You have enough VRAM for the chosen base model when loaded in 4-bit NF4.\n"
]
},
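{
"cell_type": "markdown",
"id": "3c9a7e21",
"metadata": {},
"source": [
"\n",
"The next cell is a minimal, illustrative sketch of the expected dataset schema: each line of the JSONL file is a JSON object with at least an `input` and an `output` string field. The field values shown are made-up placeholders, not records from the real dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c9a7e22",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import json\n",
"\n",
"# Placeholder record for illustration only; real rows come from the RAFT dataset notebook.\n",
"example_record = {\n",
"    \"input\": \"Question (plus any retrieved context) the model should answer.\",\n",
"    \"output\": \"The reference answer the model should learn to produce.\",\n",
"}\n",
"print(json.dumps(example_record, ensure_ascii=False, indent=2))\n"
]
},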
{
"cell_type": "markdown",
"id": "202f729e",
"metadata": {},
"source": [
"## 0) Install dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2da670d5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# If needed, uncomment the following installs:\n",
"# %pip install --quiet transformers==4.44.2 datasets==2.20.0 peft==0.12.0 accelerate==0.34.2 bitsandbytes==0.43.3 evaluate==0.4.2 sentencepiece==0.2.0\n",
"# Optional extras:\n",
"# %pip install --quiet trl==0.9.6 sacrebleu==2.4.3 rouge-score==0.1.2\n"
]
},
{
"cell_type": "markdown",
"id": "1c047191",
"metadata": {},
"source": [
"## 1) Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8c8d385",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pathlib import Path\n",
"\n",
"# Paths\n",
"DATA_JSONL = Path(\"./outputs/raft_dataset.jsonl\") # change if different\n",
"RUN_NAME = \"raft_qlora_run\"\n",
"OUTPUT_DIR = Path(f\"./finetuned/{RUN_NAME}\")\n",
"OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Base model — examples: \"meta-llama/Llama-3.1-8B\", \"Qwen/Qwen2-7B-Instruct\", \"mistralai/Mistral-7B-Instruct-v0.3\"\n",
"# Prefer an instruction-tuned base for better stability on SFT.\n",
"BASE_MODEL = \"mistralai/Mistral-7B-Instruct-v0.3\"\n",
"\n",
"# Tokenization/prompt formatting\n",
"SYSTEM_PREFIX = \"You are a helpful assistant. Answer concisely and truthfully based ONLY on the user's request.\"\n",
"USE_CHAT_TEMPLATE = True # if the tokenizer has a chat template, we'll leverage it\n",
"\n",
"# QLoRA/PEFT params\n",
"LORA_R = 16\n",
"LORA_ALPHA = 32\n",
"LORA_DROPOUT = 0.05\n",
"TARGET_MODULES = None # None = let PEFT auto-detect common modules (works for most models)\n",
"\n",
"# 4-bit quantization (QLoRA)\n",
"LOAD_IN_4BIT = True\n",
"BNB_4BIT_COMPUTE_DTYPE = \"bfloat16\" # \"float16\" or \"bfloat16\"\n",
"BNB_4BIT_QUANT_TYPE = \"nf4\" # \"nf4\" or \"fp4\"\n",
"BNB_4BIT_USE_DOUBLE_QUANT = True\n",
"\n",
"# Training\n",
"TRAIN_VAL_SPLIT = 0.98\n",
"MAX_SEQ_LEN = 2048\n",
"PER_DEVICE_TRAIN_BATCH = 1\n",
"PER_DEVICE_EVAL_BATCH = 1\n",
"GRADIENT_ACCUM_STEPS = 16\n",
"LEARNING_RATE = 2e-4\n",
"NUM_TRAIN_EPOCHS = 2\n",
"WEIGHT_DECAY = 0.0\n",
"WARMUP_RATIO = 0.03\n",
"LR_SCHEDULER_TYPE = \"cosine\"\n",
"LOGGING_STEPS = 10\n",
"EVAL_STEPS = 200\n",
"SAVE_STEPS = 200\n",
"BF16 = True\n",
"FP16 = False\n",
"\n",
"SEED = 42\n"
]
},
{
"cell_type": "markdown",
"id": "6c1439a8",
"metadata": {},
"source": [
"## 2) Load dataset (JSONL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f43262fc",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import json, random\n",
"from datasets import Dataset\n",
"\n",
"def read_jsonl(p: Path):\n",
" rows = []\n",
" with p.open(\"r\", encoding=\"utf-8\") as f:\n",
" for line in f:\n",
" line = line.strip()\n",
" if not line:\n",
" continue\n",
" try:\n",
" obj = json.loads(line)\n",
" if \"input\" in obj and \"output\" in obj:\n",
" rows.append(obj)\n",
" except Exception:\n",
" pass\n",
" return rows\n",
"\n",
"rows = read_jsonl(DATA_JSONL)\n",
"print(f\"Loaded {len(rows)} rows from {DATA_JSONL}\")\n",
"\n",
"random.Random(SEED).shuffle(rows)\n",
"split = int(len(rows) * TRAIN_VAL_SPLIT)\n",
"train_rows = rows[:split]\n",
"val_rows = rows[split:] if split < len(rows) else rows[-max(1, len(rows)//50):]\n",
"\n",
"train_ds = Dataset.from_list(train_rows)\n",
"eval_ds = Dataset.from_list(val_rows) if val_rows else None\n",
"train_ds, eval_ds\n"
]
},
{
"cell_type": "markdown",
"id": "2dd30f5a",
"metadata": {},
"source": [
"## 3) Prompt formatting"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "155aad2a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)\n",
"if tokenizer.pad_token is None:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"def format_example(ex):\n",
" user = ex[\"input\"]\n",
" assistant = ex[\"output\"]\n",
"\n",
" if USE_CHAT_TEMPLATE and hasattr(tokenizer, \"apply_chat_template\"):\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PREFIX},\n",
" {\"role\": \"user\", \"content\": user},\n",
" {\"role\": \"assistant\", \"content\": assistant},\n",
" ]\n",
" text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
" else:\n",
" text = f\"<s>[SYSTEM]\\n{SYSTEM_PREFIX}\\n[/SYSTEM]\\n[USER]\\n{user}\\n[/USER]\\n[ASSISTANT]\\n{assistant}</s>\"\n",
" return {\"text\": text}\n",
"\n",
"train_ds_fmt = train_ds.map(format_example, remove_columns=train_ds.column_names)\n",
"eval_ds_fmt = eval_ds.map(format_example, remove_columns=eval_ds.column_names) if eval_ds else None\n",
"\n",
"print(train_ds_fmt[0][\"text\"][:400])\n"
]
},
{
"cell_type": "markdown",
"id": "4a9f30a8",
"metadata": {},
"source": [
"## 4) Tokenize"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f7eaa2c",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def tokenize(batch):\n",
" return tokenizer(\n",
" batch[\"text\"],\n",
" truncation=True,\n",
" max_length=MAX_SEQ_LEN,\n",
" padding=\"max_length\",\n",
" return_tensors=None,\n",
" )\n",
"\n",
"train_tok = train_ds_fmt.map(tokenize, batched=True, remove_columns=train_ds_fmt.column_names)\n",
"eval_tok = eval_ds_fmt.map(tokenize, batched=True, remove_columns=eval_ds_fmt.column_names) if eval_ds_fmt else None\n",
"\n",
"train_tok = train_tok.rename_column(\"input_ids\", \"input_ids\")\n",
"train_tok = train_tok.add_column(\"labels\", train_tok[\"input_ids\"])\n",
"if eval_tok:\n",
" eval_tok = eval_tok.add_column(\"labels\", eval_tok[\"input_ids\"])\n",
"\n",
"train_tok, (eval_tok[0]['input_ids'][:10] if eval_tok else [])\n"
]
},
{
"cell_type": "markdown",
"id": "5f53fc1e",
"metadata": {},
"source": [
"## 5) Load base model with 4-bit quantization and prepare QLoRA"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a21d625f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
"from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
"\n",
"bnb_config = None\n",
"if LOAD_IN_4BIT:\n",
" bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_use_double_quant=BNB_4BIT_USE_DOUBLE_QUANT,\n",
" bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE,\n",
" bnb_4bit_compute_dtype=getattr(torch, BNB_4BIT_COMPUTE_DTYPE)\n",
" )\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" quantization_config=bnb_config,\n",
" torch_dtype=torch.bfloat16 if BF16 else (torch.float16 if FP16 else None),\n",
" device_map=\"auto\",\n",
")\n",
"\n",
"model = prepare_model_for_kbit_training(model)\n",
"\n",
"peft_config = LoraConfig(\n",
" r=LORA_R,\n",
" lora_alpha=LORA_ALPHA,\n",
" lora_dropout=LORA_DROPOUT,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
" target_modules=TARGET_MODULES,\n",
")\n",
"\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()\n"
]
},
{
"cell_type": "markdown",
"id": "b081dbd3",
"metadata": {},
"source": [
"## 6) Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3afd65f7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling\n",
"import math\n",
"\n",
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
"\n",
"args = TrainingArguments(\n",
" output_dir=str(OUTPUT_DIR),\n",
" run_name=RUN_NAME,\n",
" num_train_epochs=NUM_TRAIN_EPOCHS,\n",
" per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH,\n",
" per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,\n",
" gradient_accumulation_steps=GRADIENT_ACCUM_STEPS,\n",
" learning_rate=LEARNING_RATE,\n",
" lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
" warmup_ratio=WARMUP_RATIO,\n",
" weight_decay=WEIGHT_DECAY,\n",
" logging_steps=LOGGING_STEPS,\n",
" evaluation_strategy=\"steps\",\n",
" eval_steps=EVAL_STEPS,\n",
" save_steps=SAVE_STEPS,\n",
" save_total_limit=2,\n",
" bf16=BF16,\n",
" fp16=FP16,\n",
" gradient_checkpointing=True,\n",
" report_to=[\"none\"],\n",
" seed=SEED,\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" args=args,\n",
" train_dataset=train_tok,\n",
" eval_dataset=eval_tok,\n",
" data_collator=data_collator,\n",
")\n",
"\n",
"train_result = trainer.train()\n",
"metrics = trainer.evaluate() if eval_tok else {}\n",
"perplexity = math.exp(metrics[\"eval_loss\"]) if metrics and \"eval_loss\" in metrics else None\n",
"metrics, perplexity\n"
]
},
{
"cell_type": "markdown",
"id": "e22700a2",
"metadata": {},
"source": [
"## 7) Save LoRA adapters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efc434ce",
"metadata": {},
"outputs": [],
"source": [
"\n",
"adapter_dir = OUTPUT_DIR / \"lora_adapter\"\n",
"adapter_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"model.save_pretrained(str(adapter_dir))\n",
"tokenizer.save_pretrained(str(adapter_dir))\n",
"\n",
"print(f\"Saved LoRA adapter to: {adapter_dir}\")\n"
]
},
{
"cell_type": "markdown",
"id": "afb33cae",
"metadata": {},
"source": [
"## 8) (Optional) Merge adapters into base model and save full weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc6ccdee",
"metadata": {},
"outputs": [],
"source": [
"\n",
"DO_MERGE = False # set True to produce a standalone merged model\n",
"\n",
"if DO_MERGE:\n",
" from peft import PeftModel\n",
" base_model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" torch_dtype=torch.bfloat16 if BF16 else (torch.float16 if FP16 else None),\n",
" device_map=\"auto\",\n",
" )\n",
" merged = PeftModel.from_pretrained(base_model, str(adapter_dir)).merge_and_unload()\n",
" merged_dir = OUTPUT_DIR / \"merged_model\"\n",
" merged.save_pretrained(str(merged_dir))\n",
" tokenizer.save_pretrained(str(merged_dir))\n",
" print(f\"Merged full model saved to: {merged_dir}\")\n",
"else:\n",
" print(\"Skipping merge (set DO_MERGE=True to enable).\")\n"
]
},
{
"cell_type": "markdown",
"id": "010055a7",
"metadata": {},
"source": [
"## 9) Quick inference with the trained adapter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40f3a8a5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from peft import PeftModel\n",
"import torch\n",
"\n",
"test_model = AutoModelForCausalLM.from_pretrained(\n",
" BASE_MODEL,\n",
" quantization_config=bnb_config,\n",
" torch_dtype=torch.bfloat16 if BF16 else (torch.float16 if FP16 else None),\n",
" device_map=\"auto\",\n",
")\n",
"test_model = PeftModel.from_pretrained(test_model, str(adapter_dir))\n",
"test_model.eval()\n",
"\n",
"def generate_answer(prompt, max_new_tokens=256, temperature=0.2, top_p=0.9):\n",
" if USE_CHAT_TEMPLATE and hasattr(tokenizer, \"apply_chat_template\"):\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_PREFIX},\n",
" {\"role\": \"user\", \"content\": prompt},\n",
" ]\n",
" model_inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\", add_generation_prompt=True).to(test_model.device)\n",
" else:\n",
" text = f\"<s>[SYSTEM]\\n{SYSTEM_PREFIX}\\n[/SYSTEM]\\n[USER]\\n{prompt}\\n[/USER]\\n[ASSISTANT]\\n\"\n",
" model_inputs = tokenizer([text], return_tensors=\"pt\").to(test_model.device)\n",
"\n",
" with torch.no_grad():\n",
" out = test_model.generate(\n",
" **model_inputs,\n",
" do_sample=True,\n",
" max_new_tokens=max_new_tokens,\n",
" temperature=temperature,\n",
" top_p=top_p,\n",
" eos_token_id=tokenizer.eos_token_id,\n",
" pad_token_id=tokenizer.pad_token_id,\n",
" )\n",
" return tokenizer.decode(out[0], skip_special_tokens=True)\n",
"\n",
"sample_prompt = (train_rows[0][\"input\"] if len(train_rows)>0 else \"What are the visitor crowd levels like?\")\n",
"print(generate_answer(sample_prompt)[:800])\n"
]
},
{
"cell_type": "markdown",
"id": "3638b421",
"metadata": {},
"source": [
"## 10) Light evaluation on the validation set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28129cf7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import evaluate\n",
"\n",
"if eval_ds:\n",
" rouge = evaluate.load(\"rouge\")\n",
" preds, refs = [], []\n",
" for ex in val_rows[:50]:\n",
" preds.append(generate_answer(ex[\"input\"], max_new_tokens=192, temperature=0.0))\n",
" refs.append(ex[\"output\"])\n",
" results = rouge.compute(predictions=preds, references=refs)\n",
" print(results)\n",
"else:\n",
" print(\"No eval split available; skipped.\")\n"
]
},
{
"cell_type": "markdown",
"id": "1ca0d748",
"metadata": {},
"source": [
"\n",
"## 11) (Optional) Use with other runtimes\n",
"\n",
"- **Python Inference (PEFT)**: Load base model + adapter as shown in Section 9.\n",
"- **Merged model**: Set `DO_MERGE=True` to create a standalone model directory; you can then convert to other runtimes (e.g., llama.cpp GGUF) using their conversion tools.\n",
"- **Ollama**: If your runtime supports adapters or merged weights for the chosen base model, create a `Modelfile` pointing to them. Need a concrete path? Tell me your base and target runtime and Ill add exact steps.\n"
]
}
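,
{
"cell_type": "markdown",
"id": "8d4f6b02",
"metadata": {},
"source": [
"\n",
"The cell below is a minimal sketch of the llama.cpp → Ollama path, under these assumptions: `DO_MERGE=True` was run above so a merged model exists at `OUTPUT_DIR/merged_model`, a local llama.cpp checkout provides the `convert_hf_to_gguf.py` converter, and Ollama is installed. The converter path, the model tag `raft-qlora`, and the `q8_0` quantization type are illustrative placeholders; adjust them to your setup.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d4f6b03",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# 1) Convert the merged HF model to GGUF with llama.cpp's converter (run in a shell):\n",
"# !python /path/to/llama.cpp/convert_hf_to_gguf.py ./finetuned/raft_qlora_run/merged_model \\\n",
"#     --outfile ./finetuned/raft_qlora_run/merged_model/model.gguf --outtype q8_0\n",
"\n",
"WRITE_MODELFILE = False  # set True once the GGUF file above exists\n",
"\n",
"if WRITE_MODELFILE:\n",
"    # 2) Write a Modelfile that points Ollama at the GGUF file and reuses the training system prompt.\n",
"    modelfile = (\n",
"        'FROM ./merged_model/model.gguf\\n'\n",
"        f'SYSTEM \"\"\"{SYSTEM_PREFIX}\"\"\"\\n'\n",
"    )\n",
"    (OUTPUT_DIR / \"Modelfile\").write_text(modelfile, encoding=\"utf-8\")\n",
"    print(f\"Wrote {OUTPUT_DIR / 'Modelfile'}\")\n",
"\n",
"# 3) Register and run the model with Ollama (run in a shell, from OUTPUT_DIR):\n",
"# !ollama create raft-qlora -f Modelfile\n",
"# !ollama run raft-qlora\n"
]
}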
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.x"
}
},
"nbformat": 4,
"nbformat_minor": 5
}