From 28823dc0b5dea5723bbcac982593cd98b130f1df Mon Sep 17 00:00:00 2001
From: Marvin Scham
Date: Thu, 19 Feb 2026 14:24:42 +0100
Subject: [PATCH] New RAFT approach

---
 raft/make_raft_data.py     | 162 +++++++++++++++++++++++++++++++++++++
 raft/prepare_corpus.py     | 122 ++++++++++++++++++++++++++++
 raft/rag_chat.py           |  89 ++++++++++++++++++++
 raft/rag_chat_merged.py    | 118 +++++++++++++++++++++++++++
 raft/train_mistral_raft.py | 100 +++++++++++++++++++++++
 5 files changed, 591 insertions(+)
 create mode 100644 raft/make_raft_data.py
 create mode 100644 raft/prepare_corpus.py
 create mode 100644 raft/rag_chat.py
 create mode 100644 raft/rag_chat_merged.py
 create mode 100644 raft/train_mistral_raft.py

diff --git a/raft/make_raft_data.py b/raft/make_raft_data.py
new file mode 100644
index 0000000..5ebc616
--- /dev/null
+++ b/raft/make_raft_data.py
@@ -0,0 +1,162 @@
+import argparse
+import json
+import os
+import random
+
+import faiss
+import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+## Usage: python make_raft_data.py --out_dir out --n_examples 5000
+
+
+SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
+You give your opinions and guidance with local etiquette and context.
+You avoid stereotypes. You explain local etiquette, customs, and context.
+When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
+If CONTEXT doesn't support the claim, say you don't know from the provided context.
+"""
+
+
+def load_docstore(path):
+    docs = []
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f:
+            docs.append(json.loads(line))
+    return docs
+
+
+def retrieve(index, embedder, query, top_k=6):
+    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
+    scores, ids = index.search(q, top_k)
+    return ids[0].tolist(), scores[0].tolist()
+
+
+@torch.no_grad()
+def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
+    # Using tokenizer chat template where available
+    input_ids = tok.apply_chat_template(
+        messages, add_generation_prompt=True, return_tensors="pt"
+    ).to(model.device)
+    out = model.generate(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        temperature=temperature,
+        top_p=0.9,
+        eos_token_id=tok.eos_token_id,
+    )
+    return tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
+
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--out_dir", default="out")
+    ap.add_argument(
+        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
+    )
+    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
+    ap.add_argument("--n_examples", type=int, default=5000)
+    ap.add_argument("--top_k", type=int, default=6)
+    ap.add_argument("--n_distractors", type=int, default=3)
+    ap.add_argument("--seed", type=int, default=7)
+    args = ap.parse_args()
+
+    random.seed(args.seed)
+
+    faiss_path = os.path.join(args.out_dir, "faiss.index")
+    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")
+
+    index = faiss.read_index(faiss_path)
+    docstore = load_docstore(docstore_path)
+
+    embedder = SentenceTransformer(args.embedding_model)
+
+    # Teacher model to synthesize questions & answers from review chunks
+    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
+    )
+    model.eval()
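+
+    # For each example: sample a "gold" review chunk, have the teacher write a
+    # question answerable from it, retrieve top-k chunks for that question, mix
+    # in random distractor chunks, then have the teacher answer with short
+    # quotes from the context -- the RAFT-style supervision signal.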
+
+    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
+    with open(out_path, "w", encoding="utf-8") as f:
+        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
+            # pick a "gold" chunk
+            gold = random.choice(docstore)
+            gold_text = gold["text"]
+
+            # 1) generate a question answerable from gold_text
+            q_prompt = [
+                {"role": "system", "content": SYSTEM_PERSONA},
+                {
+                    "role": "user",
+                    "content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
+                    f"CONTEXT:\n{gold_text}\n\n"
+                    "Return only the question.",
+                },
+            ]
+            question = generate_text(
+                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
+            )
+            question = question.split("\n")[0].strip()
+
+            # 2) retrieve top-k for that question
+            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
+            retrieved = [docstore[i] for i in ids]
+
+            # 3) add distractors (random docs not in retrieved)
+            retrieved_ids = set(ids)
+            distractors = []
+            attempts = 0
+            while len(distractors) < args.n_distractors and attempts < 50:
+                cand_idx = random.randrange(len(docstore))
+                attempts += 1
+                if cand_idx in retrieved_ids:
+                    continue
+                distractors.append(docstore[cand_idx])
+
+            # Mix: retrieved + distractors
+            context_docs = retrieved + distractors
+            random.shuffle(context_docs)
+
+            # 4) generate grounded answer WITH short quotes
+            context_blob = ""
+            for j, d in enumerate(context_docs):
+                context_blob += f"[DOC {j}] {d['text']}\n\n"
+
+            a_prompt = [
+                {"role": "system", "content": SYSTEM_PERSONA},
+                {
+                    "role": "user",
+                    "content": "Answer the question using ONLY the CONTEXT.\n"
+                    "Rules:\n"
+                    "- Include 1-2 short direct quotes from CONTEXT as evidence.\n"
+                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
+                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
+                },
+            ]
+            answer = generate_text(
+                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
+            )
+
+            # Final training example (conversational dataset format for TRL)
+            train_ex = {
+                "messages": [
+                    {"role": "system", "content": SYSTEM_PERSONA},
+                    {
+                        "role": "user",
+                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
+                        "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
+                    },
+                    {"role": "assistant", "content": answer},
+                ]
+            }
+            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")
+
+    print(f"Wrote {out_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/raft/prepare_corpus.py b/raft/prepare_corpus.py
new file mode 100644
index 0000000..ed29d1b
--- /dev/null
+++ b/raft/prepare_corpus.py
@@ -0,0 +1,122 @@
+import argparse
+import json
+import os
+import re
+
+import faiss
+import numpy as np
+import pandas as pd
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
+
+## Usage: python prepare_corpus.py --input_tab your_reviews.tab --out_dir out
+
+
+def simple_clean(text: str) -> str:
+    if not isinstance(text, str):
+        return ""
+    text = text.replace("\u00a0", " ")
+    text = re.sub(r"\s+", " ", text).strip()
+    return text
+
+
+def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
+    """
+    Simple char-based chunking (good enough for reviews).
+    For better chunking, split by sentences and cap token length (see the
+    commented sketch below).
+    """
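+    # Rough sketch of that sentence-based alternative (hypothetical helper, not
+    # used by this script): split on sentence boundaries, then pack sentences
+    # greedily into chunks of at most `chunk_chars` characters; overlong single
+    # sentences are kept whole.
+    #
+    #   def chunk_by_sentences(text, chunk_chars=900):
+    #       sents = re.split(r"(?<=[.!?])\s+", simple_clean(text))
+    #       chunks, cur = [], ""
+    #       for s in sents:
+    #           if cur and len(cur) + len(s) + 1 > chunk_chars:
+    #               chunks.append(cur)
+    #               cur = s
+    #           else:
+    #               cur = (cur + " " + s).strip()
+    #       if cur:
+    #           chunks.append(cur)
+    #       return chunks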
+    text = simple_clean(text)
+    if len(text) <= chunk_chars:
+        return [text] if text else []
+    chunks = []
+    i = 0
+    while i < len(text):
+        chunk = text[i : i + chunk_chars]
+        if chunk:
+            chunks.append(chunk)
+        i += max(1, chunk_chars - overlap)
+    return chunks
+
+
+def detect_text_col(df: pd.DataFrame) -> str:
+    # Heuristic: pick the longest average string column
+    best_col, best_score = None, -1
+    for col in df.columns:
+        sample = df[col].dropna().astype(str).head(200)
+        if len(sample) == 0:
+            continue
+        avg_len = sample.map(len).mean()
+        if avg_len > best_score:
+            best_score = avg_len
+            best_col = col
+    if best_col is None:
+        raise ValueError("Could not detect a text column in the .tab file.")
+    return best_col
+
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--input_tab", required=True, help="Tripadvisor reviews .tab file")
+    ap.add_argument("--out_dir", default="out")
+    ap.add_argument(
+        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
+    )
+    ap.add_argument("--chunk_chars", type=int, default=900)
+    ap.add_argument("--overlap", type=int, default=150)
+    args = ap.parse_args()
+
+    os.makedirs(args.out_dir, exist_ok=True)
+
+    # Many .tab files are TSV
+    df = pd.read_csv(args.input_tab, sep="\t", dtype=str, on_bad_lines="skip")
+    text_col = detect_text_col(df)
+
+    rows = df[text_col].fillna("").astype(str).tolist()
+
+    corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
+    corpus = []
+    doc_id = 0
+
+    for r in tqdm(rows, desc="Chunking"):
+        r = simple_clean(r)
+        if len(r) < 30:
+            continue
+        chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
+        for ch in chunks:
+            if len(ch) < 30:
+                continue
+            corpus.append({"doc_id": doc_id, "text": ch})
+        doc_id += 1
+
+    with open(corpus_path, "w", encoding="utf-8") as f:
+        for ex in corpus:
+            f.write(json.dumps(ex, ensure_ascii=False) + "\n")
+
+    # Build FAISS index
+    embedder = SentenceTransformer(args.embedding_model)
+    texts = [c["text"] for c in corpus]
+    embs = embedder.encode(
+        texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True
+    )
+    embs = np.asarray(embs, dtype=np.float32)
+
+    dim = embs.shape[1]
+    index = faiss.IndexFlatIP(dim)  # cosine if normalized
+    index.add(embs)
+
+    faiss_path = os.path.join(args.out_dir, "faiss.index")
+    faiss.write_index(index, faiss_path)
+
+    # Store mapping doc row -> text
+    mapping_path = os.path.join(args.out_dir, "docstore.jsonl")
+    with open(mapping_path, "w", encoding="utf-8") as f:
+        for i, c in enumerate(corpus):
+            f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")
+
+    print(
+        f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}\nText column detected: {text_col}"
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/raft/rag_chat.py b/raft/rag_chat.py
new file mode 100644
index 0000000..35b1b22
--- /dev/null
+++ b/raft/rag_chat.py
@@ -0,0 +1,89 @@
+import argparse
+import json
+import os
+
+import faiss
+import numpy as np
+import torch
+from peft import PeftModel
+from sentence_transformers import SentenceTransformer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+## Usage: python rag_chat.py --lora_dir out/mistral_balitwin_lora
+
+SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
+You give your opinions and guidance with local etiquette and context.
+Use the provided CONTEXT; include 1-2 short quotes as evidence.
+If the context does not support the claim, say so.
+"""
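+
+# This chat keeps base model and adapter separate: the LoRA adapter produced by
+# train_mistral_raft.py is attached at load time via PeftModel.
+# rag_chat_merged.py is the variant for a standalone finetuned/merged model folder.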
+
+
+def load_docstore(path):
+    docs = []
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f:
+            docs.append(json.loads(line))
+    return docs
+
+
+def retrieve(index, embedder, query, top_k=6):
+    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
+    scores, ids = index.search(q, top_k)
+    return ids[0].tolist(), scores[0].tolist()
+
+
+@torch.no_grad()
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
+    ap.add_argument("--lora_dir", default="out/mistral_balitwin_lora")
+    ap.add_argument("--out_dir", default="out")
+    ap.add_argument(
+        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
+    )
+    ap.add_argument("--top_k", type=int, default=6)
+    args = ap.parse_args()
+
+    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
+    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
+    embedder = SentenceTransformer(args.embedding_model)
+
+    tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
+    base = AutoModelForCausalLM.from_pretrained(
+        args.base_model, device_map="auto", torch_dtype=torch.float16
+    )
+    model = PeftModel.from_pretrained(base, args.lora_dir)
+    model.eval()
+
+    print("Type your question (Ctrl+C to exit).")
+    while True:
+        q = input("\nYou: ").strip()
+        if not q:
+            continue
+
+        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
+        context_docs = [docstore[i]["text"] for i in ids]
+        context_blob = "\n\n".join(
+            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
+        )
+
+        messages = [
+            {"role": "system", "content": SYSTEM_PERSONA},
+            {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
+        ]
+        inp = tok.apply_chat_template(
+            messages, add_generation_prompt=True, return_tensors="pt"
+        ).to(model.device)
+
+        out = model.generate(
+            inp,
+            max_new_tokens=320,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            eos_token_id=tok.eos_token_id,
+        )
+        ans = tok.decode(out[0][inp.shape[1] :], skip_special_tokens=True).strip()
+        print(f"\nBaliTwin: {ans}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/raft/rag_chat_merged.py b/raft/rag_chat_merged.py
new file mode 100644
index 0000000..3e9b04f
--- /dev/null
+++ b/raft/rag_chat_merged.py
@@ -0,0 +1,118 @@
+import argparse
+import json
+import os
+
+import faiss
+import numpy as np
+import torch
+from sentence_transformers import SentenceTransformer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+## Usage: python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
+
+SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
+You give your opinions and guidance with local etiquette and context.
+Use the provided CONTEXT; include 1-2 short quotes as evidence.
+If the context does not support the claim, say so.
+"""
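+
+# Same retrieval-augmented chat loop as rag_chat.py, but the model is loaded
+# directly from a local folder (e.g. a merged or otherwise finetuned checkpoint)
+# instead of attaching a LoRA adapter to the base model.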
+
+
+def load_docstore(path):
+    docs = []
+    with open(path, "r", encoding="utf-8") as f:
+        for line in f:
+            docs.append(json.loads(line))
+    return docs
+
+
+def retrieve(index, embedder, query, top_k=6):
+    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
+    scores, ids = index.search(q, top_k)
+    return ids[0].tolist(), scores[0].tolist()
+
+
+@torch.no_grad()
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument(
+        "--model_dir", required=True, help="Path to your finetuned model folder"
+    )
+    ap.add_argument(
+        "--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
+    )
+    ap.add_argument(
+        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
+    )
+    ap.add_argument("--top_k", type=int, default=6)
+    ap.add_argument("--max_new_tokens", type=int, default=320)
+    args = ap.parse_args()
+
+    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
+    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
+    embedder = SentenceTransformer(args.embedding_model)
+
+    # Load your externally finetuned model directly from disk
+    tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
+
+    # Important: ensure pad token exists for generation; Mistral often uses eos as pad
+    if tok.pad_token is None:
+        tok.pad_token = tok.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model_dir,
+        device_map="auto",
+        torch_dtype=torch.float16,
+    )
+    model.eval()
+
+    print("Type your question (Ctrl+C to exit).")
+    while True:
+        q = input("\nYou: ").strip()
+        if not q:
+            continue
+
+        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
+        context_docs = [docstore[i]["text"] for i in ids]
+        context_blob = "\n\n".join(
+            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
+        )
+
+        messages = [
+            {"role": "system", "content": SYSTEM_PERSONA},
+            {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
+        ]
+
+        # Use the chat template shipped with the model folder (chat_template.jinja)
+        enc = tok.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+        )
+
+        if isinstance(enc, torch.Tensor):
+            input_ids = enc.to(model.device)
+            attention_mask = torch.ones_like(input_ids, device=model.device)
+        else:
+            input_ids = enc["input_ids"].to(model.device)
+            attention_mask = enc.get("attention_mask")
+            if attention_mask is None:
+                attention_mask = torch.ones_like(input_ids)
+            attention_mask = attention_mask.to(model.device)
+
+        out = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=args.max_new_tokens,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9,
+            eos_token_id=tok.eos_token_id,
+            pad_token_id=tok.pad_token_id,
+        )
+
+        ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
+        print(f"\nBaliTwin: {ans}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/raft/train_mistral_raft.py b/raft/train_mistral_raft.py
new file mode 100644
index 0000000..c833613
--- /dev/null
+++ b/raft/train_mistral_raft.py
@@ -0,0 +1,100 @@
+import argparse
+import json
+import os
+
+import torch
+from datasets import load_dataset
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from trl import SFTConfig, SFTTrainer
+
+## Usage: python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mistral_balitwin_lora
+
+
+def main():
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
+    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
+    ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
+    ap.add_argument("--max_seq_len", type=int, default=2048)
+    ap.add_argument("--batch_size", type=int, default=1)
+    ap.add_argument("--grad_accum", type=int, default=16)
+    ap.add_argument("--lr", type=float, default=2e-4)
+    ap.add_argument("--epochs", type=int, default=1)
+    args = ap.parse_args()
+
+    os.makedirs(args.out_dir, exist_ok=True)
+
+    # QLoRA (4-bit) config (good default for 7B on limited VRAM)
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=(
+            torch.bfloat16 if torch.cuda.is_available() else torch.float16
+        ),
+        bnb_4bit_use_double_quant=True,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
+    # Mistral usually has a valid chat template; keep it intact.
+
+    model = AutoModelForCausalLM.from_pretrained(
+        args.base_model,
+        device_map="auto",
+        quantization_config=bnb_config,
+        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
+    )
+
+    # LoRA adapter config (tweak r/alpha if needed)
+    peft_config = LoraConfig(
+        r=16,
+        lora_alpha=32,
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=[
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+    )
+
+    dataset = load_dataset("json", data_files=args.train_jsonl, split="train")
+
+    training_args = SFTConfig(
+        output_dir=args.out_dir,
+        num_train_epochs=args.epochs,
+        per_device_train_batch_size=args.batch_size,
+        gradient_accumulation_steps=args.grad_accum,
+        learning_rate=args.lr,
+        logging_steps=10,
+        save_steps=200,
+        save_total_limit=2,
+        max_length=args.max_seq_len,
+        bf16=torch.cuda.is_available(),
+        fp16=not torch.cuda.is_available(),
+        assistant_only_loss=True,  # only learn from assistant turns in messages
+        report_to=[],
+    )
+
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=dataset,
+        processing_class=tokenizer,
+        peft_config=peft_config,
+    )
+
+    trainer.train()
+    trainer.save_model(args.out_dir)
+    tokenizer.save_pretrained(args.out_dir)
+
+    print(f"Saved LoRA adapter to: {args.out_dir}")
+
+
+if __name__ == "__main__":
+    main()
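+
+# rag_chat.py attaches the saved adapter to the base model at load time, so no
+# merge step is needed there. To produce a standalone folder for
+# rag_chat_merged.py instead, one option is to merge the adapter into an
+# unquantized copy of the base model (a sketch; the output path is hypothetical):
+#
+#   from peft import PeftModel
+#
+#   base = AutoModelForCausalLM.from_pretrained(
+#       "mistralai/Mistral-7B-Instruct-v0.3", torch_dtype=torch.bfloat16
+#   )
+#   merged = PeftModel.from_pretrained(base, "out/mistral_balitwin_lora").merge_and_unload()
+#   merged.save_pretrained("out/mistral_balitwin_merged")
+#   AutoTokenizer.from_pretrained("out/mistral_balitwin_lora").save_pretrained(
+#       "out/mistral_balitwin_merged"
+#   )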