{
"cells": [
{
"cell_type": "markdown",
"id": "2957fcef",
"metadata": {},
"source": [
"\n",
"# RAFT Supervised Fine-Tuning (QLoRA) — Local Training\n",
"\n",
"This notebook fine-tunes an open-source base model on a RAFT-style dataset (`input` → `output`) using **QLoRA** with **PEFT** and **Transformers**. It is designed to run locally (single or multi-GPU) and to export both **LoRA adapters** and (optionally) a **merged** model for inference.\n",
"\n",
"> **Assumptions**\n",
"> - Your dataset lives at `./outputs/raft_dataset.jsonl` (from the previous notebook). Adjust the path if needed.\n",
"> - You have a CUDA-capable GPU and can install `bitsandbytes`. (CPU training is possible but slow.)\n",
"> - You have enough VRAM for the chosen base model when loaded in 4-bit NF4.\n"
]
},
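{
"cell_type": "markdown",
"id": "a0c1e2f4",
"metadata": {},
"source": [
"The next cell is an optional sanity check, not part of the training pipeline: it prints one *hypothetical* record in the `input` → `output` JSONL schema this notebook expects and reports whether a CUDA GPU is visible. The field contents are illustrative placeholders, not taken from the real dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1d2f3a5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Optional sanity check (illustrative only - safe to skip).\n",
"# One hypothetical record in the expected `input` -> `output` JSONL schema;\n",
"# real rows come from ./outputs/raft_dataset.jsonl produced by the previous notebook.\n",
"example_record = {\n",
"    \"input\": \"Question: <user question, optionally with retrieved context>\",\n",
"    \"output\": \"<answer grounded in the provided context>\",\n",
"}\n",
"print(example_record)\n",
"\n",
"# Report whether a CUDA GPU is visible (torch may not be installed yet; see Section 0).\n",
"try:\n",
"    import torch\n",
"    print(\"CUDA available:\", torch.cuda.is_available())\n",
"except ImportError:\n",
"    print(\"PyTorch not installed yet - run the installs in Section 0 first.\")\n"
]
},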
{
"cell_type": "markdown",
"id": "202f729e",
"metadata": {},
"source": [
"## 0) Install dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2da670d5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# If needed, uncomment the following installs:\n",
"# %pip install --quiet transformers==4.44.2 datasets==2.20.0 peft==0.12.0 accelerate==0.34.2 bitsandbytes==0.43.3 evaluate==0.4.2 sentencepiece==0.2.0\n",
"# Optional extras:\n",
"# %pip install --quiet trl==0.9.6 sacrebleu==2.4.3 rouge-score==0.1.2\n"
]
},
{
"cell_type": "markdown",
"id": "1c047191",
"metadata": {},
"source": [
"## 1) Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8c8d385",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pathlib import Path\n",
"\n",
"# Paths\n",
"DATA_JSONL = Path(\"./outputs/raft_dataset.jsonl\") # change if different\n",
"RUN_NAME = \"raft_qlora_run\"\n",
"OUTPUT_DIR = Path(f\"./finetuned/{RUN_NAME}\")\n",
"OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"# Base model — examples: \"meta-llama/Llama-3.1-8B\", \"Qwen/Qwen2-7B-Instruct\", \"mistralai/Mistral-7B-Instruct-v0.3\"\n",
"# Prefer an instruction-tuned base for better stability on SFT.\n",
"BASE_MODEL = \"mistralai/Mistral-7B-Instruct-v0.3\"\n",
"\n",
"# Tokenization/prompt formatting\n",
"SYSTEM_PREFIX = \"You are a helpful assistant. Answer concisely and truthfully based ONLY on the user's request.\"\n",
"USE_CHAT_TEMPLATE = True # if the tokenizer has a chat template, we'll leverage it\n",
"\n",
"# QLoRA/PEFT params\n",
"LORA_R = 16\n",
"LORA_ALPHA = 32\n",
"LORA_DROPOUT = 0.05\n",
"TARGET_MODULES = None # None = let PEFT auto-detect common modules (works for most models)\n",
"\n",
"# 4-bit quantization (QLoRA)\n",
"LOAD_IN_4BIT = True\n",
"BNB_4BIT_COMPUTE_DTYPE = \"bfloat16\" # \"float16\" or \"bfloat16\"\n",
"BNB_4BIT_QUANT_TYPE = \"nf4\" # \"nf4\" or \"fp4\"\n",
"BNB_4BIT_USE_DOUBLE_QUANT = True\n",
"\n",
"# Training\n",
"TRAIN_VAL_SPLIT = 0.98\n",
"MAX_SEQ_LEN = 2048\n",
"PER_DEVICE_TRAIN_BATCH = 1\n",
"PER_DEVICE_EVAL_BATCH = 1\n",
"GRADIENT_ACCUM_STEPS = 16 # effective batch size per device = 1 x 16 = 16 sequences per optimizer step\n",
"LEARNING_RATE = 2e-4\n",
"NUM_TRAIN_EPOCHS = 2\n",
"WEIGHT_DECAY = 0.0\n",
"WARMUP_RATIO = 0.03\n",
"LR_SCHEDULER_TYPE = \"cosine\"\n",
"LOGGING_STEPS = 10\n",
"EVAL_STEPS = 200\n",
"SAVE_STEPS = 200\n",
"BF16 = True # requires a GPU with bfloat16 support (e.g. Ampere or newer); otherwise set BF16=False, FP16=True\n",
"FP16 = False\n",
"\n",
"SEED = 42\n"
]
},
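{
"cell_type": "markdown",
"id": "c2e3a4b6",
"metadata": {},
"source": [
"`TARGET_MODULES = None` lets PEFT fall back to its built-in defaults for known architectures. If you would rather set the modules explicitly, the optional helper below is a rough sketch (an assumption, not part of the original pipeline): it lists the linear-layer name suffixes of a loaded model by matching class names that contain `Linear`, which also covers the bitsandbytes 4-bit layers. Run it after the base model is loaded in Section 5, before `get_peft_model`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3f4b5c7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Optional helper (sketch): list linear-layer name suffixes of a loaded model so you can\n",
"# set TARGET_MODULES explicitly (e.g. attention/MLP projections) instead of None.\n",
"# Heuristic: match any module whose class name contains \"Linear\", which also covers\n",
"# bitsandbytes Linear4bit layers used by the 4-bit loading path.\n",
"def linear_module_suffixes(m):\n",
"    suffixes = set()\n",
"    for name, module in m.named_modules():\n",
"        if \"Linear\" in module.__class__.__name__ and \".\" in name:\n",
"            suffixes.add(name.split(\".\")[-1])\n",
"    return sorted(suffixes)\n",
"\n",
"# Example (run after the base model is loaded in Section 5, before get_peft_model):\n",
"# print(linear_module_suffixes(model))\n"
]
},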
{
"cell_type": "markdown",
"id": "6c1439a8",
"metadata": {},
"source": [
"## 2) Load dataset (JSONL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f43262fc",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import json, random\n",
"from datasets import Dataset\n",
"\n",
"def read_jsonl(p: Path):\n",
"    rows = []\n",
"    with p.open(\"r\", encoding=\"utf-8\") as f:\n",
"        for line in f:\n",
"            line = line.strip()\n",
"            if not line:\n",
"                continue\n",
"            try:\n",
"                obj = json.loads(line)\n",
"                if \"input\" in obj and \"output\" in obj:\n",
"                    rows.append(obj)\n",
"            except Exception:\n",
"                pass\n",
"    return rows\n",
"\n",
"rows = read_jsonl(DATA_JSONL)\n",
"print(f\"Loaded {len(rows)} rows from {DATA_JSONL}\")\n",
"\n",
"random.Random(SEED).shuffle(rows)\n",
"split = int(len(rows) * TRAIN_VAL_SPLIT)\n",
"if split >= len(rows) and len(rows) > 1:\n",
"    # Hold out the last ~2% (at least one row) so train and validation never overlap.\n",
"    split = len(rows) - max(1, len(rows) // 50)\n",
"train_rows = rows[:split]\n",
"val_rows = rows[split:]\n",
"\n",
"train_ds = Dataset.from_list(train_rows)\n",
"eval_ds = Dataset.from_list(val_rows) if val_rows else None\n",
"train_ds, eval_ds\n"
]
},
{
"cell_type": "markdown",
"id": "2dd30f5a",
"metadata": {},
"source": [
"## 3) Prompt formatting"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "155aad2a",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)\n",
"if tokenizer.pad_token is None:\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"def format_example(ex):\n",
"    user = ex[\"input\"]\n",
"    assistant = ex[\"output\"]\n",
"\n",
"    if USE_CHAT_TEMPLATE and tokenizer.chat_template is not None:\n",
"        messages = [\n",
"            {\"role\": \"system\", \"content\": SYSTEM_PREFIX},\n",
"            {\"role\": \"user\", \"content\": user},\n",
"            {\"role\": \"assistant\", \"content\": assistant},\n",
"        ]\n",
"        try:\n",
"            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
"        except Exception:\n",
"            # Some chat templates (e.g. Mistral-Instruct) reject a system role;\n",
"            # fold the system prefix into the first user turn instead.\n",
"            messages = [\n",
"                {\"role\": \"user\", \"content\": f\"{SYSTEM_PREFIX}\\n\\n{user}\"},\n",
"                {\"role\": \"assistant\", \"content\": assistant},\n",
"            ]\n",
"            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
"    else:\n",
"        text = f\"<s>[SYSTEM]\\n{SYSTEM_PREFIX}\\n[/SYSTEM]\\n[USER]\\n{user}\\n[/USER]\\n[ASSISTANT]\\n{assistant}</s>\"\n",
"    return {\"text\": text}\n",
"\n",
"train_ds_fmt = train_ds.map(format_example, remove_columns=train_ds.column_names)\n",
"eval_ds_fmt = eval_ds.map(format_example, remove_columns=eval_ds.column_names) if eval_ds else None\n",
"\n",
"print(train_ds_fmt[0][\"text\"][:400])\n"
]
},
{
"cell_type": "markdown",
"id": "4a9f30a8",
"metadata": {},
"source": [
"## 4) Tokenize"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f7eaa2c",
"metadata": {},
"outputs": [],
"source": [
"\n",
"def tokenize(batch):\n",
"    return tokenizer(\n",
"        batch[\"text\"],\n",
"        truncation=True,\n",
"        max_length=MAX_SEQ_LEN,\n",
"        padding=\"max_length\",\n",
"        return_tensors=None,\n",
"    )\n",
"\n",
"train_tok = train_ds_fmt.map(tokenize, batched=True, remove_columns=train_ds_fmt.column_names)\n",
"eval_tok = eval_ds_fmt.map(tokenize, batched=True, remove_columns=eval_ds_fmt.column_names) if eval_ds_fmt else None\n",
"\n",
"# No manual \"labels\" column is needed: DataCollatorForLanguageModeling(mlm=False) in\n",
"# Section 6 copies input_ids into labels and masks padding tokens with -100.\n",
"\n",
"train_tok, (eval_tok[0]['input_ids'][:10] if eval_tok else [])\n"
]
},
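{
"cell_type": "markdown",
"id": "f5b6d7e9",
"metadata": {},
"source": [
"By default this notebook computes the causal-LM loss over the whole formatted sequence (system prefix, user input, and answer). A common alternative is to restrict the loss to the assistant answer. The optional cell below is a minimal sketch using TRL's `DataCollatorForCompletionOnlyLM` (from the optional `trl` install in Section 0); the `[/INST]` response template is an assumption that fits Mistral-Instruct-style chat templates, so verify it appears verbatim in your formatted examples before using it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6c7e8f0",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Optional (sketch): compute loss only on the assistant answer instead of the full\n",
"# prompt+answer sequence. Uses TRL's DataCollatorForCompletionOnlyLM (optional `trl`\n",
"# install from Section 0). The response_template below is an assumption that matches\n",
"# Mistral-Instruct-style \"[INST] ... [/INST]\" formatting - check it against\n",
"# print(train_ds_fmt[0][\"text\"]) before relying on it.\n",
"try:\n",
"    from trl import DataCollatorForCompletionOnlyLM\n",
"\n",
"    completion_collator = DataCollatorForCompletionOnlyLM(\n",
"        response_template=\"[/INST]\",\n",
"        tokenizer=tokenizer,\n",
"        mlm=False,\n",
"    )\n",
"    # To use it, pass data_collator=completion_collator to the Trainer in Section 6.\n",
"except ImportError:\n",
"    print(\"trl is not installed; skipping the completion-only collator sketch.\")\n"
]
},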
{
"cell_type": "markdown",
"id": "5f53fc1e",
"metadata": {},
"source": [
"## 5) Load base model with 4-bit quantization and prepare QLoRA"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a21d625f",
"metadata": {},
"outputs": [],
"source": [
"\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, BitsAndBytesConfig\n",
"from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
"\n",
"bnb_config = None\n",
"if LOAD_IN_4BIT:\n",
"    bnb_config = BitsAndBytesConfig(\n",
"        load_in_4bit=True,\n",
"        bnb_4bit_use_double_quant=BNB_4BIT_USE_DOUBLE_QUANT,\n",
"        bnb_4bit_quant_type=BNB_4BIT_QUANT_TYPE,\n",
"        bnb_4bit_compute_dtype=getattr(torch, BNB_4BIT_COMPUTE_DTYPE)\n",
"    )\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    BASE_MODEL,\n",
"    quantization_config=bnb_config,\n",
"    torch_dtype=torch.bfloat16 if BF16 else (torch.float16 if FP16 else None),\n",
"    device_map=\"auto\",\n",
")\n",
"\n",
"model = prepare_model_for_kbit_training(model)\n",
"\n",
"peft_config = LoraConfig(\n",
"    r=LORA_R,\n",
"    lora_alpha=LORA_ALPHA,\n",
"    lora_dropout=LORA_DROPOUT,\n",
"    bias=\"none\",\n",
"    task_type=\"CAUSAL_LM\",\n",
"    target_modules=TARGET_MODULES,\n",
")\n",
"\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()\n"
]
},
{
"cell_type": "markdown",
"id": "b081dbd3",
"metadata": {},
"source": [
"## 6) Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3afd65f7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling\n",
"import math\n",
"\n",
"# With mlm=False the collator builds causal-LM labels from input_ids and masks padding with -100.\n",
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n",
"\n",
"args = TrainingArguments(\n",
"    output_dir=str(OUTPUT_DIR),\n",
"    run_name=RUN_NAME,\n",
"    num_train_epochs=NUM_TRAIN_EPOCHS,\n",
"    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH,\n",
"    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,\n",
"    gradient_accumulation_steps=GRADIENT_ACCUM_STEPS,\n",
"    learning_rate=LEARNING_RATE,\n",
"    lr_scheduler_type=LR_SCHEDULER_TYPE,\n",
"    warmup_ratio=WARMUP_RATIO,\n",
"    weight_decay=WEIGHT_DECAY,\n",
"    logging_steps=LOGGING_STEPS,\n",
"    evaluation_strategy=\"steps\" if eval_tok is not None else \"no\",\n",
"    eval_steps=EVAL_STEPS,\n",
"    save_steps=SAVE_STEPS,\n",
"    save_total_limit=2,\n",
"    bf16=BF16,\n",
"    fp16=FP16,\n",
"    gradient_checkpointing=True,\n",
"    report_to=[\"none\"],\n",
"    seed=SEED,\n",
")\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    tokenizer=tokenizer,\n",
"    args=args,\n",
"    train_dataset=train_tok,\n",
"    eval_dataset=eval_tok,\n",
"    data_collator=data_collator,\n",
")\n",
"\n",
"train_result = trainer.train()\n",
"metrics = trainer.evaluate() if eval_tok else {}\n",
"perplexity = math.exp(metrics[\"eval_loss\"]) if metrics and \"eval_loss\" in metrics else None\n",
"metrics, perplexity\n"
]
},
{
"cell_type": "markdown",
"id": "e22700a2",
"metadata": {},
"source": [
"## 7) Save LoRA adapters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "efc434ce",
"metadata": {},
"outputs": [],
"source": [
"\n",
"adapter_dir = OUTPUT_DIR / \"lora_adapter\"\n",
"adapter_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"model.save_pretrained(str(adapter_dir))\n",
"tokenizer.save_pretrained(str(adapter_dir))\n",
"\n",
"print(f\"Saved LoRA adapter to: {adapter_dir}\")\n"
]
},
{
"cell_type": "markdown",
"id": "afb33cae",
"metadata": {},
"source": [
"## 8) (Optional) Merge adapters into base model and save full weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc6ccdee",
"metadata": {},
"outputs": [],
"source": [
"\n",
"DO_MERGE = False # set True to produce a standalone merged model\n",
"\n",
"if DO_MERGE:\n",
"    from peft import PeftModel\n",
"    base_model = AutoModelForCausalLM.from_pretrained(\n",
"        BASE_MODEL,\n",
"        torch_dtype=torch.bfloat16 if BF16 else (torch.float16 if FP16 else None),\n",
"        device_map=\"auto\",\n",
"    )\n",
"    merged = PeftModel.from_pretrained(base_model, str(adapter_dir)).merge_and_unload()\n",
"    merged_dir = OUTPUT_DIR / \"merged_model\"\n",
"    merged.save_pretrained(str(merged_dir))\n",
"    tokenizer.save_pretrained(str(merged_dir))\n",
"    print(f\"Merged full model saved to: {merged_dir}\")\n",
"else:\n",
"    print(\"Skipping merge (set DO_MERGE=True to enable).\")\n"
]
},
{
"cell_type": "markdown",
"id": "010055a7",
"metadata": {},
"source": [
"## 9) Quick inference with the trained adapter"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40f3a8a5",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from peft import PeftModel\n",
"import torch\n",
"\n",
"test_model = AutoModelForCausalLM.from_pretrained(\n",
"    BASE_MODEL,\n",
"    quantization_config=bnb_config,\n",
"    torch_dtype=torch.bfloat16 if BF16 else (torch.float16 if FP16 else None),\n",
"    device_map=\"auto\",\n",
")\n",
"test_model = PeftModel.from_pretrained(test_model, str(adapter_dir))\n",
"test_model.eval()\n",
"\n",
"def generate_answer(prompt, max_new_tokens=256, temperature=0.2, top_p=0.9):\n",
"    if USE_CHAT_TEMPLATE and tokenizer.chat_template is not None:\n",
"        messages = [\n",
"            {\"role\": \"system\", \"content\": SYSTEM_PREFIX},\n",
"            {\"role\": \"user\", \"content\": prompt},\n",
"        ]\n",
"        try:\n",
"            # return_dict=True yields input_ids and attention_mask, so **model_inputs works below.\n",
"            model_inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\", return_dict=True).to(test_model.device)\n",
"        except Exception:\n",
"            # Some chat templates (e.g. Mistral-Instruct) reject a system role; fold it into the user turn.\n",
"            messages = [{\"role\": \"user\", \"content\": f\"{SYSTEM_PREFIX}\\n\\n{prompt}\"}]\n",
"            model_inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\", return_dict=True).to(test_model.device)\n",
"    else:\n",
"        text = f\"<s>[SYSTEM]\\n{SYSTEM_PREFIX}\\n[/SYSTEM]\\n[USER]\\n{prompt}\\n[/USER]\\n[ASSISTANT]\\n\"\n",
"        model_inputs = tokenizer([text], return_tensors=\"pt\").to(test_model.device)\n",
"\n",
"    gen_kwargs = dict(\n",
"        max_new_tokens=max_new_tokens,\n",
"        eos_token_id=tokenizer.eos_token_id,\n",
"        pad_token_id=tokenizer.pad_token_id,\n",
"    )\n",
"    if temperature and temperature > 0:\n",
"        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)\n",
"    else:\n",
"        # Greedy decoding when temperature is 0 (sampling with temperature=0 is invalid).\n",
"        gen_kwargs.update(do_sample=False)\n",
"\n",
"    with torch.no_grad():\n",
"        out = test_model.generate(**model_inputs, **gen_kwargs)\n",
"\n",
"    # Return only the newly generated continuation, not the echoed prompt.\n",
"    new_tokens = out[0][model_inputs[\"input_ids\"].shape[1]:]\n",
"    return tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
"\n",
"sample_prompt = (train_rows[0][\"input\"] if len(train_rows)>0 else \"What are the visitor crowd levels like?\")\n",
"print(generate_answer(sample_prompt)[:800])\n"
]
},
{
"cell_type": "markdown",
"id": "3638b421",
"metadata": {},
"source": [
"## 10) Light evaluation on the validation set"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28129cf7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Requires the optional rouge-score package from Section 0 for evaluate.load(\"rouge\").\n",
"import evaluate\n",
"\n",
"if eval_ds:\n",
"    rouge = evaluate.load(\"rouge\")\n",
"    preds, refs = [], []\n",
"    for ex in val_rows[:50]:\n",
"        # temperature=0.0 means greedy decoding (handled inside generate_answer).\n",
"        preds.append(generate_answer(ex[\"input\"], max_new_tokens=192, temperature=0.0))\n",
"        refs.append(ex[\"output\"])\n",
"    results = rouge.compute(predictions=preds, references=refs)\n",
"    print(results)\n",
"else:\n",
"    print(\"No eval split available; skipped.\")\n"
]
},
{
"cell_type": "markdown",
"id": "1ca0d748",
"metadata": {},
"source": [
"\n",
"## 11) (Optional) Use with other runtimes\n",
"\n",
"- **Python inference (PEFT)**: load the base model plus the adapter as shown in Section 9.\n",
"- **Merged model**: set `DO_MERGE=True` in Section 8 to create a standalone model directory; you can then convert it for other runtimes (e.g., llama.cpp GGUF) using their conversion tools.\n",
"- **Ollama**: if your runtime supports adapters or merged weights for the chosen base model, create a `Modelfile` pointing to them; a minimal sketch follows below.\n"
]
}
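,
{
"cell_type": "markdown",
"id": "e4a5c6d8",
"metadata": {},
"source": [
"A minimal, hypothetical `Modelfile` sketch for the Ollama route. It assumes your Ollama version can import the merged safetensors directory from Section 8 for this model architecture; the model name `raft-qlora` is a placeholder and the path reflects this notebook's defaults, so adjust both to your setup (or convert to GGUF first and point `FROM` at the GGUF file instead):\n",
"\n",
"```\n",
"FROM ./finetuned/raft_qlora_run/merged_model\n",
"SYSTEM \"\"\"You are a helpful assistant. Answer concisely and truthfully based ONLY on the user's request.\"\"\"\n",
"PARAMETER temperature 0.2\n",
"```\n",
"\n",
"Then build and run it with `ollama create raft-qlora -f Modelfile` followed by `ollama run raft-qlora`."
]
}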
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.x"
}
},
"nbformat": 4,
"nbformat_minor": 5
}