"""Generate a RAFT-style synthetic training set from a review docstore.

Pipeline per example: pick a random "gold" chunk, ask a teacher LLM to
write one question answerable from it, retrieve the top-k chunks for that
question with FAISS, mix in random distractor chunks, then ask the teacher
for an answer grounded ONLY in that mixed context.  Each example is written
as a TRL-style conversational record (``{"messages": [...]}``) to
``<out_dir>/raft_train.jsonl``.
"""

import argparse
import json
import os
import random

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PERSONA = """
You are responding as a culturally and spiritually motivated traveler in Bali. Adopt the perspective of a reflective, experienced visitor who prioritizes ritual meaning, cultural integrity, spiritual atmosphere, and respectful engagement over entertainment, convenience, or social media appeal. When answering:
- Emphasize cultural depth, ritual context, symbolism, and spiritual atmosphere.
- Reflect on authenticity and the tension between sacred meaning and tourism.
- Weigh crowding, commercialization, and infrastructure in a nuanced way rather than giving extreme judgments.
- Frame value primarily in emotional, cultural, or spiritual terms — not primarily in price or comfort.
- Show awareness of appropriate visitor behavior and respect for local practices.
- Avoid generic travel advice, promotional language, or itinerary-style responses.
- Write in a thoughtful, first-person perspective.
- Provide reasoned, differentiated answers rather than short summaries.
- Do not list bullet points unless explicitly asked.
- Keep answers focused on the question.
Maintain consistency with this identity across all responses. 
"""

TRAINER_PROMPT = "Create ONE realistic question from the perspective of a touristic marketer they might ask a culturally and spiritually interested traveler in Bali considered to be a lead user that can be answered using ONLY the CONTEXT.\n\n"


def load_docstore(path):
    """Load a JSONL docstore (one JSON object per line) into a list of dicts."""
    docs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            docs.append(json.loads(line))
    return docs


def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and return ``(ids, scores)`` for the top_k nearest chunks.

    FAISS pads its result with ``-1`` ids when the index holds fewer than
    ``top_k`` vectors; those sentinels are filtered out here, otherwise
    ``docstore[-1]`` would silently wrap to the last document.
    """
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = index.search(q, top_k)
    hits = [
        (i, s)
        for i, s in zip(ids[0].tolist(), scores[0].tolist())
        if i >= 0  # drop FAISS "no result" padding
    ]
    return [i for i, _ in hits], [s for _, s in hits]


@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Run one chat-templated generation and return the decoded completion.

    Parameters
    ----------
    model, tok : causal LM and its tokenizer.
    messages : list of ``{"role": ..., "content": ...}`` chat messages.
    max_new_tokens, temperature : sampling controls passed to ``generate``.

    Only the newly generated tokens (past the prompt) are decoded.
    """
    # Using tokenizer chat template where available.
    enc = tok.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    # apply_chat_template may return a raw tensor or a BatchEncoding dict
    # depending on the transformers version.
    if isinstance(enc, torch.Tensor):
        input_ids = enc.to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)
    else:
        input_ids = enc["input_ids"].to(model.device)
        attention_mask = enc.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.to(model.device)
    # Mistral-style tokenizers define no pad token; fall back to EOS so
    # generate() does not receive pad_token_id=None.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    out = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
        pad_token_id=pad_id,
    )
    return tok.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True).strip()


def main():
    """CLI entry point: synthesize ``--n_examples`` RAFT training records."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    random.seed(args.seed)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")

    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)
    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model to synthesize questions & answers from review chunks.
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # pick a "gold" chunk
            gold = random.choice(docstore)
            gold_text = gold["text"]

            # 1) generate a question answerable from gold_text
            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": TRAINER_PROMPT + f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            # keep only the first line in case the model rambles
            question = question.split("\n")[0].strip()

            # 2) retrieve top-k for that question
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            retrieved = [docstore[i] for i in ids]

            # 3) add distractors (random docs not in retrieved).
            # Track chosen candidates in the exclusion set so the same
            # document cannot be sampled twice as a distractor.
            excluded = set(ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in excluded:
                    continue
                excluded.add(cand_idx)
                distractors.append(docstore[cand_idx])

            # Mix: retrieved + distractors
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 4) generate grounded answer WITH short quotes
            context_blob = ""
            for j, d in enumerate(context_docs):
                context_blob += f"[DOC {j}] {d['text']}\n\n"

            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    "- Include 1–2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # Final training example (conversational dataset format for TRL)
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")

    print(f"Wrote {out_path}")


if __name__ == "__main__":
    main()