diff --git a/raft/nb_raft_finetune_qlora.py b/raft/nb_raft_finetune_qlora.py
index 3587a96..f559913 100644
--- a/raft/nb_raft_finetune_qlora.py
+++ b/raft/nb_raft_finetune_qlora.py
@@ -29,8 +29,9 @@ from peft import PeftModel
 from transformers import AutoModelForCausalLM
 # Paths
-DATA_JSONL = Path("./outputs/raft_dataset.jsonl")  # change if different
-RUN_NAME = "raft_qlora_tourist_0.2"
+# DATA_JSONL = Path("./outputs/raft_dataset.jsonl")  # change if different
+DATA_JSONL = Path("../raft/bali_culture_raft_dataset.jsonl")
+RUN_NAME = "raft_qlora_tourist"
 OUTPUT_DIR = Path(f"./finetuned/{RUN_NAME}")
 OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 ADAPTER_DIR = OUTPUT_DIR / "lora_adapter"
diff --git a/raft/nb_raft_finetune_qlora_2.py b/raft/nb_raft_finetune_qlora_2.py
new file mode 100644
index 0000000..550814e
--- /dev/null
+++ b/raft/nb_raft_finetune_qlora_2.py
@@ -0,0 +1,677 @@
+# ---
+# jupyter:
+#   jupytext:
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.18.0
+#   kernelspec:
+#     display_name: .venv
+#     language: python
+#     name: python3
+# ---
+
+# %% [markdown]
+# # QLoRA/RAFT Fine-Tuning
+#
+
+# %% [markdown]
+# ## Configuration
+#
+
+# %%
+from termcolor import colored
+from pathlib import Path
+from transformers import BitsAndBytesConfig
+import torch
+from peft import PeftModel
+from transformers import AutoModelForCausalLM
+
+# Paths
+DATA_JSONL = Path("../raft/remap_bali_raft_dataset.jsonl")  # change if different
+RUN_NAME = "raft_qlora_tourist"
+OUTPUT_DIR = Path(f"./finetuned/{RUN_NAME}")
+OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+ADAPTER_DIR = OUTPUT_DIR / "checkpoint-1550"
+
+# Base model — examples: "meta-llama/Llama-3.1-8B", "Qwen/Qwen2-7B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"
+# Prefer an instruction-tuned base for better stability on SFT.
+BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+
+# Tokenization/prompt formatting
+SYSTEM_PREFIX = "You are a helpful assistant. Answer concisely and truthfully based ONLY on the user's request."
+USE_CHAT_TEMPLATE = True # if the tokenizer has a chat template, we'll leverage it + +# BitsAndBytes config +BNB_CONFIG = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, +) + + +# %% [markdown] +# ## 2) Load dataset (JSONL) +# + +# %% +import json +import random +from datasets import Dataset + + +def read_jsonl(p: Path): + rows = [] + with p.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + if "input" in obj and "output" in obj: + rows.append(obj) + except Exception: + pass + return rows + + +rows = read_jsonl(DATA_JSONL) +print(f"Loaded {len(rows)} rows from {DATA_JSONL}") +print(rows[0]) + +random.Random(42).shuffle(rows) +split = int(len(rows) * 0.85) +train_rows = rows[:split] +val_rows = rows[split:] if split < len(rows) else rows[-max(1, len(rows) // 50) :] + +train_rows = [{"input": r["input"], "output": r["output"]} for r in train_rows] +val_rows = [{"input": r["input"], "output": r["output"]} for r in val_rows] + +train_ds = Dataset.from_list(train_rows) +eval_ds = Dataset.from_list(val_rows) if val_rows else None +train_ds, eval_ds + + +# %% [markdown] +# ## 3) Prompt formatting +# + +# %% +from transformers import AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True) +tokenizer.pad_token = tokenizer.eos_token + +print(colored("Verifying eos and pad tokens...", "yellow")) +if tokenizer.pad_token_id != 2: + print(colored(f"Expected pad token to be 2, but got {tokenizer.pad_token}", "red")) +else: + print(colored("Pad token is ok", "green")) + +if tokenizer.eos_token_id != 2: + print(colored(f"Expected eos token to be 2, but got {tokenizer.eos_token}", "red")) +else: + print(colored("Eos token is ok", "green")) + + +def format_example(ex): + user = ex["input"] + assistant = ex["output"] + + messages = [ + {"role": "system", "content": SYSTEM_PREFIX}, + {"role": "user", "content": user}, + {"role": "assistant", "content": assistant}, + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=False + ) + return {"text": text} + + +train_ds_fmt = train_ds.map(format_example, remove_columns=train_ds.column_names) +eval_ds_fmt = ( + eval_ds.map(format_example, remove_columns=eval_ds.column_names) + if eval_ds + else None +) + +for i in range(10): + print("👉 " + train_ds_fmt[i]["text"]) + if train_ds_fmt[i]["text"][-4:] == tokenizer.eos_token: + print(f"✅ {colored('EOS is fine.', 'green')}") + else: + print(f"❌ {colored('EOS is missing.', 'red')}") + +# %% [markdown] +# ## 4) Tokenize +# + +# %% +IGNORE_INDEX = -100 + + +def make_supervised_tensors(batch): + enc = tokenizer( + batch["text"], + truncation=True, + max_length=2048, + padding="max_length", + return_tensors=None, + ) + input_ids = enc["input_ids"] + attn_mask = enc["attention_mask"] + + # Mask pads + labels = [ids[:] for ids in input_ids] + for i in range(len(labels)): + for j, m in enumerate(attn_mask[i]): + if m == 0: + labels[i][j] = IGNORE_INDEX + + return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels} + + +train_tok = train_ds_fmt.map( + make_supervised_tensors, batched=True, remove_columns=train_ds_fmt.column_names +) +eval_tok = ( + eval_ds_fmt.map( + make_supervised_tensors, batched=True, remove_columns=eval_ds_fmt.column_names + ) + if eval_ds_fmt + else None +) + +train_tok, eval_tok + +train_ds_fmt["text"][0] + + +# %% [markdown] +# ## Setup 
sanity check +# + +# %% +import transformers +import peft +import bitsandbytes as bnb +from bitsandbytes.nn import modules as bnb_modules + +print(colored("Sanity check...", "yellow")) +print("CUDA available:", torch.cuda.is_available()) +print("Torch version:", torch.__version__) +print("Transformers version:", transformers.__version__) +print( + "Compute capability:", + torch.cuda.get_device_capability(0) if torch.cuda.is_available() else "no cuda", +) +print("BitsAndbytes:", bnb.__version__) +print("PEFT:", peft.__version__) + + +print("Embedding4bit available:", hasattr(bnb_modules, "Embedding4bit")) + +# %% [markdown] +# ## 5) Load base model with 4-bit quantization and prepare QLoRA +# + +# %% +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training + +model = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, + quantization_config=BNB_CONFIG, + dtype=torch.bfloat16, + device_map="auto", +) + +model = prepare_model_for_kbit_training(model) + +peft_config = LoraConfig( + r=8, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], +) + +model = get_peft_model(model, peft_config) +model.print_trainable_parameters() + + +# %% [markdown] +# ## 6) Train +# + +# %% +from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling +import math + +data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + +args = TrainingArguments( + output_dir=str(OUTPUT_DIR), + run_name=RUN_NAME, + num_train_epochs=3, + per_device_train_batch_size=1, + per_device_eval_batch_size=1, + gradient_accumulation_steps=8, + learning_rate=2e-4, + warmup_ratio=0.05, + weight_decay=0.01, + logging_steps=25, + eval_steps=50, + save_steps=50, + save_total_limit=2, + bf16=True, + fp16=False, + gradient_checkpointing=True, + report_to=["none"], + seed=42, + eval_strategy="steps", + load_best_model_at_end=True, +) + +trainer = Trainer( + model=model, + args=args, + train_dataset=train_tok, + eval_dataset=eval_tok, + data_collator=data_collator, +) + +train_result = trainer.train() +metrics = trainer.evaluate() if eval_tok else {} +perplexity = ( + math.exp(metrics["eval_loss"]) if metrics and "eval_loss" in metrics else None +) +metrics, perplexity + + +# %% [markdown] +# | epochs | train_loss | eval_loss | +# | ------ | ---------- | --------- | +# | 50 | 4.377000 | 3.628506 | +# | 100 | 2.636800 | 2.558457 | +# | 150 | 2.428800 | 2.427239 | +# | 200 | 2.334800 | 2.193493 | +# | 250 | 2.188500 | 2.186310 | +# | 300 | 2.112400 | 2.173394 | +# | 350 | 2.122900 | 2.163947 | +# | 400 | 2.155400 | 2.162106 | +# | 450 | 2.072100 | 2.154830 | +# | 500 | 1.979900 | 2.165512 | +# | 550 | 1.935800 | 2.176313 | +# | 600 | 1.942800 | 2.170668 | +# | 650 | 1.968000 | 2.162810 | +# | 700 | 1.974100 | 2.167501 | +# | 750 | 1.801900 | 2.235841 | +# | 800 | 1.768000 | 2.233753 | +# | 850 | 1.779100 | 2.218278 | +# | 900 | 1.828900 | 2.220891 | +# | 950 | 1.854900 | 2.208387 | +# | 1000 | 1.653600 | 2.302763 | +# | 1050 | 1.663500 | 2.307982 | +# | 1100 | 1.673400 | 2.301423 | +# | 1150 | 1.608400 | 2.320958 | +# | 1200 | 1.683500 | 2.303580 | +# | 1250 | 1.532100 | 2.434277 | +# | 1300 | 1.558900 | 2.418276 | +# | 1350 | 1.508900 | 2.422347 | +# | 1400 | 1.535100 | 2.416650 | +# | 1450 | 1.529900 | 2.415497 | +# +# | Step | Training Loss | Evaluation Loss | +# | ---- | ------------- | --------------- | +# | 50 | 1.173100 | 1.040235 | +# | 
100 | 0.882900 | 0.875235 | +# | 150 | 0.806600 | 0.820686 | +# | 200 | 0.785700 | 0.792914 | +# | 250 | 0.764300 | 0.761308 | +# | 300 | 0.733900 | 0.745976 | +# | 350 | 0.744000 | 0.732220 | +# | 400 | 0.712000 | 0.719414 | +# | 450 | 0.703800 | 0.709955 | +# | 500 | 0.684100 | 0.699460 | +# | 550 | 0.705900 | 0.691758 | +# | 600 | 0.683200 | 0.688031 | +# | 650 | 0.670100 | 0.680539 | +# | 700 | 0.681600 | 0.674205 | +# | 750 | 0.681500 | 0.671295 | +# | 800 | 0.651700 | 0.666133 | +# | 850 | 0.662900 | 0.660661 | +# | 900 | 0.651400 | 0.656359 | +# | 950 | 0.648100 | 0.653309 | +# | 1000 | 0.631500 | 0.648716 | +# | 1050 | 0.654200 | 0.643737 | +# | 1100 | 0.571100 | 0.648199 | +# | 1150 | 0.573500 | 0.648405 | +# | 1200 | 0.556000 | 0.644185 | +# | 1250 | 0.568100 | 0.642854 | +# | 1300 | 0.570200 | 0.640425 | +# | 1350 | 0.551100 | 0.636319 | +# | 1400 | 0.551400 | 0.634054 | +# | 1450 | 0.550100 | 0.631558 | +# | 1500 | 0.559800 | 0.630046 | +# | 1550 | 0.556600 | 0.626972 | +# + +# %% [markdown] +# ## 7) Save LoRA adapters +# + +# %% +ADAPTER_DIR.mkdir(parents=True, exist_ok=True) + +model.save_pretrained(str(ADAPTER_DIR)) +tokenizer.save_pretrained(str(ADAPTER_DIR)) + +print(f"Saved LoRA adapter to: {ADAPTER_DIR}") + + +# %% [markdown] +# ## 8) Save merged model +# + +# %% +# this does not work on my system since I don't have enough VRAM. +# it should work though provided you have sufficient resources. +# my next step would have been to convert the merged model to llama.cpp GGUF format so I can run it in Ollama/OpenWebUI. +DO_MERGE = False + +base_model = None +if DO_MERGE: + base_model = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + merged = PeftModel.from_pretrained( + base_model, str(ADAPTER_DIR), offload_folder="offload/", is_trainable=False + ).merge_and_unload() + merged_dir = OUTPUT_DIR / "merged_model" + merged.save_pretrained(str(merged_dir)) + tokenizer.save_pretrained(str(merged_dir)) + print(f"Merged full model saved to: {merged_dir}") +else: + print("Skipping merge (set DO_MERGE=True to enable).") + +# %% [markdown] +# ## 9) Quick inference with the trained adapter +# + +# %% +test_model = None + +print(colored("Loading the base model + trained adapter.", "green")) +test_model = AutoModelForCausalLM.from_pretrained( + BASE_MODEL, + quantization_config=BNB_CONFIG, + dtype=torch.bfloat16, + device_map="auto", +) +test_model = PeftModel.from_pretrained( + test_model, str(ADAPTER_DIR), offload_folder="offload/", is_trainable=False +) +test_model.eval() + + +def generate_answer(prompt, max_new_tokens=256, temperature=0.2, top_p=0.9): + messages = [ + {"role": "system", "content": SYSTEM_PREFIX}, + {"role": "user", "content": prompt}, + ] + model_inputs = tokenizer.apply_chat_template( + messages, return_tensors="pt", add_generation_prompt=True + ).to(test_model.device) + + gen_kwargs = {"input_ids": model_inputs} + + with torch.no_grad(): + out = test_model.generate( + **gen_kwargs, + do_sample=True, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + ) + return tokenizer.decode(out[0], skip_special_tokens=True) + + +sample_prompt = ( + train_rows[0]["input"] + if len(train_rows) > 0 + else "What are the visitor crowd levels like?" 
+)
+
+for i in range(10):
+    print(generate_answer(train_rows[i]["input"])[:800])
+    print("---")
+
+
+# %%
+generate_answer("What are the visitor crowd levels like?")
+
+
+# %%
+def chat(
+    user, system="You are a precise assistant.", temperature=0.0, max_new_tokens=256
+):
+    msgs = [
+        {"role": "system", "content": system},
+        {"role": "user", "content": user},
+    ]
+    model_inputs = tokenizer.apply_chat_template(
+        msgs, return_tensors="pt", add_generation_prompt=True
+    ).to(test_model.device)
+    gen_kwargs = {"input_ids": model_inputs}
+    with torch.no_grad():
+        out = test_model.generate(
+            **gen_kwargs,
+            # **tokenizer(user, return_tensors="pt").to(test_model.device),
+            max_new_tokens=max_new_tokens,
+            do_sample=(temperature > 0),
+            temperature=temperature,
+            top_p=1.0,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id
+        )
+    return tokenizer.decode(out[0], skip_special_tokens=True)
+
+
+for i in range(10):
+    prompt = train_rows[i]["input"]
+    out = chat(prompt, max_new_tokens=2000, temperature=0.2)
+
+    print("\n\n💬\n" + out)
+
+# %% [markdown]
+# ## PoC Gradio setup
+#
+
+# %%
+# === Gradio chat for Mistral-Instruct (no self-replies) ===
+# Assumes: `test_model` (HF AutoModelForCausalLM + PEFT adapter) and `BASE_MODEL` are defined.
+
+import torch, threading
+import gradio as gr
+from transformers import (
+    AutoTokenizer,
+    TextIteratorStreamer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+)
+
+# -- Tokenizer (use BASE model tokenizer) --
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
+
+# Ensure pad/eos exist and are consistent
+if tokenizer.pad_token is None and tokenizer.eos_token is not None:
+    tokenizer.pad_token = tokenizer.eos_token
+elif tokenizer.eos_token is None and tokenizer.pad_token is not None:
+    tokenizer.eos_token = tokenizer.pad_token
+elif tokenizer.pad_token is None and tokenizer.eos_token is None:
+    tokenizer.add_special_tokens({"eos_token": "</s>"})
+    tokenizer.pad_token = tokenizer.eos_token
+    try:
+        test_model.resize_token_embeddings(len(tokenizer))
+    except Exception:
+        pass
+
+DEVICE = getattr(test_model, "device", "cuda" if torch.cuda.is_available() else "cpu")
+SYSTEM_PROMPT = "You are a helpful assistant."
+
+
+# --- Custom stop: if the model starts a new user turn ([INST]) stop generation immediately.
+# This prevents the model from “answering its own replies”.
+class StopOnInst(StoppingCriteria):
+    def __init__(self, tokenizer, trigger_text="[INST]"):
+        self.trigger_ids = tokenizer.encode(trigger_text, add_special_tokens=False)
+
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+    ) -> bool:
+        if not self.trigger_ids:
+            return False
+        seq = input_ids[0].tolist()
+        tlen = len(self.trigger_ids)
+        if len(seq) < tlen:
+            return False
+        return seq[-tlen:] == self.trigger_ids
+
+
+STOPPING = StoppingCriteriaList([StopOnInst(tokenizer)])
+
+
+def _build_inputs(pairs):
+    """
+    pairs: list of (user, assistant) tuples.
+    We include prior completed assistant replies and the latest user with empty assistant,
+    then ask the model to continue as assistant.
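+    Note: Mistral-Instruct's chat template wraps each user turn in [INST] ... [/INST],
+    which is why StopOnInst above watches for a fresh "[INST]" and cuts generation off there.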
+ """ + msgs = [{"role": "system", "content": SYSTEM_PROMPT}] + for u, a in pairs: + u = (u or "").strip() + a = (a or "").strip() + if not u and not a: + continue + if u: + msgs.append({"role": "user", "content": u}) + if a: + msgs.append({"role": "assistant", "content": a}) + + # Use chat template; many Mistral tokenizers return a single Tensor (input_ids) + input_ids = tokenizer.apply_chat_template( + msgs, add_generation_prompt=True, return_tensors="pt" + ) + if isinstance(input_ids, torch.Tensor): + inputs = {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)} + else: + inputs = input_ids + return {k: v.to(DEVICE) for k, v in inputs.items()} + + +def stream_reply(history_pairs, max_new_tokens=512, temperature=0.7, top_p=0.9): + inputs = _build_inputs(history_pairs) + + streamer = TextIteratorStreamer( + tokenizer, skip_prompt=True, skip_special_tokens=True + ) + + gen_kwargs = dict( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + temperature=temperature, + top_p=top_p, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, # Mistral uses as EOS + streamer=streamer, + stopping_criteria=STOPPING, # <- key fix + ) + + with torch.inference_mode(): + t = threading.Thread(target=test_model.generate, kwargs=gen_kwargs) + t.start() + partial = "" + for piece in streamer: + partial += piece + yield partial + t.join() + + +# --- Gradio handlers --- + + +def gr_respond(message, chat_history): + message = (message or "").strip() + chat_history = chat_history or [] + # Append new user turn with empty assistant; we stream into that slot. + chat_history = chat_history + [(message, "")] + pairs = [(u or "", a or "") for (u, a) in chat_history] + + for partial in stream_reply(pairs): + chat_history[-1] = (message, partial) + yield "", chat_history # clears textbox, updates chat + + +def gr_clear(): + return None + + +with gr.Blocks() as demo: + gr.Markdown("## 💬 Chat with Touristral") + chat = gr.Chatbot(height=200, layout="bubble") + with gr.Row(): + msg = gr.Textbox(placeholder="Type a message and press Enter…", scale=9) + send = gr.Button("Send", scale=1) + with gr.Row(): + clear = gr.Button("Clear chat") + + msg.submit(gr_respond, [msg, chat], [msg, chat]) + send.click(gr_respond, [msg, chat], [msg, chat]) + clear.click(gr_clear, None, chat, queue=False) + +demo.queue().launch(share=False) + +# %% [markdown] +# ## 10) Light evaluation on the validation set +# + +# %% +import evaluate + +if eval_ds: + rouge = evaluate.load("rouge") + preds, refs = [], [] + for ex in val_rows[:50]: + preds.append(generate_answer(ex["input"], max_new_tokens=192, temperature=0.2)) + refs.append(ex["output"]) + results = rouge.compute(predictions=preds, references=refs) + print(results) +else: + print("No eval split available; skipped.") + + +# %% [markdown] +# ## 11) (Optional) Use with other runtimes +# +# - **Python Inference (PEFT)**: Load base model + adapter as shown in Section 9. +# - **Merged model**: Set `DO_MERGE=True` to create a standalone model directory; you can then convert to other runtimes (e.g., llama.cpp GGUF) using their conversion tools. +# - **Ollama**: If your runtime supports adapters or merged weights for the chosen base model, create a `Modelfile` pointing to them. Need a concrete path? Tell me your base and target runtime and I’ll add exact steps. +#
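+
+# %% [markdown]
+# ## 12) (Sketch) Merged model to GGUF to Ollama
+#
+# A rough, untested sketch of the llama.cpp/Ollama path mentioned above. I could not run the
+# merge locally (see Section 8), so the script names, flags, and file names below are
+# assumptions based on the llama.cpp and Ollama documentation and may need adjusting for
+# your versions; treat them as a starting point rather than exact steps.
+#
+
+# %%
+# NOTE: untested sketch. These are shell commands; run them in a terminal, not in this notebook.
+#
+# 1) Convert the merged HF model (OUTPUT_DIR / "merged_model", see Section 8) to GGUF with
+#    llama.cpp's converter:
+#      python convert_hf_to_gguf.py ./finetuned/raft_qlora_tourist/merged_model \
+#          --outfile touristral-f16.gguf --outtype f16
+#
+# 2) Optionally quantize it to fit smaller GPUs/CPUs (Q4_K_M as an example target):
+#      ./llama-quantize touristral-f16.gguf touristral-Q4_K_M.gguf Q4_K_M
+#
+# 3) Write a minimal Ollama Modelfile next to the GGUF file:
+#      FROM ./touristral-Q4_K_M.gguf
+#      SYSTEM "You are a helpful assistant. Answer concisely and truthfully based ONLY on the user's request."
+#
+#    (If your Ollama version supports it, a Modelfile can instead combine a stock base model
+#    with a converted LoRA adapter via the ADAPTER directive, skipping the merge entirely.)
+#
+# 4) Register and chat:
+#      ollama create touristral -f Modelfile
+#      ollama run touristral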