"""Generate a RAFT-style training dataset for the 'BaliTwin' persona.

Pipeline, per example:
  1. pick a random "gold" chunk from the docstore,
  2. ask a teacher LLM to synthesize a question answerable from that chunk,
  3. retrieve top-k chunks for the question with FAISS,
  4. mix in random distractor chunks and shuffle,
  5. ask the teacher for a context-grounded answer with short quotes.

Each example is written as a TRL-style conversational record
(``{"messages": [...]}``) to ``<out_dir>/raft_train.jsonl``.
"""

import argparse
import json
import os
import random

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE: fixed typo "opinions nand guidance" -> "opinions and guidance".
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler. You give your opinions and guidance with local etiquette and context. You avoid stereotypes. You explain local etiquette, customs, and context. When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence. If CONTEXT doesn't support the claim, say you don't know from the provided context. """


def load_docstore(path):
    """Load a JSONL docstore into a list of dicts (one dict per line).

    Args:
        path: Path to a JSON-lines file; each line is one document record.

    Returns:
        List of parsed documents, in file order (so list index == FAISS id,
        assuming the index was built in the same order).
    """
    docs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            docs.append(json.loads(line))
    return docs


def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and search *index* for the *top_k* nearest documents.

    Args:
        index: A FAISS index whose ids map to docstore list positions.
        embedder: SentenceTransformer used to encode the query
            (normalized, so inner-product search == cosine similarity).
        query: Query string.
        top_k: Number of neighbors to request.

    Returns:
        (ids, scores) — two parallel lists of length top_k.  FAISS pads
        with id -1 when the index holds fewer than top_k vectors; callers
        must filter those out before indexing into the docstore.
    """
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = index.search(q, top_k)
    return ids[0].tolist(), scores[0].tolist()


@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Sample one completion for a chat *messages* list from the teacher model.

    Args:
        model: Causal LM already placed on its device(s).
        tok: Matching tokenizer with a chat template.
        messages: List of {"role": ..., "content": ...} dicts.
        max_new_tokens: Generation budget.
        temperature: Sampling temperature (nucleus sampling, top_p=0.9).

    Returns:
        The decoded completion only (prompt tokens stripped), whitespace-trimmed.
    """
    # add_generation_prompt=True is required so the template opens a fresh
    # assistant turn; without it, instruct models continue the user message.
    input_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    out = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
    )
    # Decode only the newly generated tokens, not the prompt.
    return tok.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True).strip()


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=7)
    args = ap.parse_args()

    # Seed both RNGs: `random` drives gold/distractor picks, torch drives
    # model.generate's sampling.  Seeding only `random` left runs irreproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")

    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)
    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model to synthesize questions & answers from review chunks
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # pick a "gold" chunk
            gold = random.choice(docstore)
            gold_text = gold["text"]

            # 1) generate a question answerable from gold_text
            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
                    f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            question = question.split("\n")[0].strip()

            # 2) retrieve top-k for that question.
            # FAISS pads with -1 when the index has fewer than top_k vectors;
            # docstore[-1] would silently grab the last document, so filter.
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            valid_ids = [i for i in ids if 0 <= i < len(docstore)]
            retrieved = [docstore[i] for i in valid_ids]

            # 3) add distractors (random docs not retrieved and not yet chosen).
            # `excluded` grows with each pick so the same distractor cannot be
            # appended twice (the original only excluded retrieved ids).
            excluded = set(valid_ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in excluded:
                    continue
                excluded.add(cand_idx)
                distractors.append(docstore[cand_idx])

            # Mix: retrieved + distractors
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 4) generate grounded answer WITH short quotes
            context_blob = "".join(
                f"[DOC {j}] {d['text']}\n\n" for j, d in enumerate(context_docs)
            )

            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    "- Include 1–2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # Final training example (conversational dataset format for TRL)
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
                        "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")

    print(f"Wrote {out_path}")


if __name__ == "__main__":
    main()