#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
RAFT dataset builder (FAISS-based retrieval) -> Together.ai chat JSONL.

Inputs (from your indexing script):
- {index_dir}/faiss.index
- {index_dir}/docstore.jsonl

Process:
- Build a set of interview-style prompts (EN)
- For each prompt:
    - Retrieve top-k chunks via FAISS cosine/IP
    - Call DeepSeek Chat Completions API to generate a vivid, human-like
      Lead User answer
    - Write training examples as JSONL in chat format (messages)

Outputs:
- raft_train.jsonl
- raft_val.jsonl (optional)

ENV:
- DEEPSEEK_API_KEY (required)
- optional: DEEPSEEK_BASE_URL (default: https://api.deepseek.com)
- optional: DEEPSEEK_MODEL (default: deepseek-chat)
"""

import argparse
import json
import os
import random
import re
import time
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Tuple

import faiss
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


# -----------------------------
# DeepSeek client (OpenAI-compatible)
# -----------------------------

# HTTP statuses worth retrying besides 5xx: request timeout and rate limiting.
_RETRIABLE_CLIENT_STATUSES = frozenset({408, 429})


@dataclass
class DeepSeekConfig:
    """Connection and retry settings for :class:`DeepSeekClient`."""

    api_key: str
    base_url: str = "https://api.deepseek.com"
    model: str = "deepseek-chat"
    # Per-request timeout passed to ``requests.post`` (seconds).
    timeout_s: int = 120
    # Total number of attempts per chat call (values < 1 are treated as 1).
    max_retries: int = 5
    # Exponential backoff base: the wait after attempt n is backoff_s ** n.
    backoff_s: float = 1.6


class DeepSeekClient:
    """Minimal client for the DeepSeek ``/chat/completions`` endpoint."""

    def __init__(self, cfg: DeepSeekConfig):
        self.cfg = cfg

    def chat(
        self, messages: List[Dict], temperature: float = 0.85, max_tokens: int = 750
    ) -> str:
        """Send a chat completion request and return the stripped reply text.

        Transient failures (network errors, HTTP 408/429/5xx, malformed
        response bodies) are retried with exponential backoff. Other 4xx
        responses cannot succeed on retry and fail immediately.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.

        Returns:
            The assistant message content with surrounding whitespace removed.

        Raises:
            RuntimeError: If the request is rejected or all attempts fail.
        """
        # Tolerate a trailing slash in the configured base URL.
        url = f"{self.cfg.base_url.rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.cfg.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.cfg.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        attempts = max(1, self.cfg.max_retries)
        last_err: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                r = requests.post(
                    url, headers=headers, json=payload, timeout=self.cfg.timeout_s
                )
                status = r.status_code
                if 400 <= status < 500 and status not in _RETRIABLE_CLIENT_STATUSES:
                    # Bad key, invalid payload, etc.: retrying only wastes time.
                    # RuntimeError is deliberately not caught by the handler below.
                    raise RuntimeError(
                        f"DeepSeek API rejected the request (HTTP {status}): "
                        f"{r.text[:300]}"
                    )
                # Turns 408/429/5xx into an HTTPError so it is recorded and retried.
                r.raise_for_status()
                content = r.json()["choices"][0]["message"]["content"]
                if not isinstance(content, str):
                    raise ValueError("response contained no text content")
                return content.strip()
            except (
                requests.RequestException,
                ValueError,
                KeyError,
                IndexError,
                TypeError,
            ) as e:
                last_err = e
            # Back off only when another attempt will actually follow.
            if attempt < attempts - 1:
                time.sleep(self.cfg.backoff_s ** (attempt + 1))

        raise RuntimeError(
            f"DeepSeek API call failed after retries. Last error: {last_err}"
        ) from last_err


# -----------------------------
# Helpers
# -----------------------------

# Prompts shorter than this (after cleaning) are ignored by the loaders.
MIN_PROMPT_CHARS = 20

_WHITESPACE_RE = re.compile(r"\s+")


def simple_clean(text: str) -> str:
    """Collapse all whitespace runs to single spaces and trim the ends.

    Non-string input yields an empty string. Non-breaking spaces are treated
    as ordinary spaces.
    """
    if not isinstance(text, str):
        return ""
    text = text.replace("\u00a0", " ")
    return _WHITESPACE_RE.sub(" ", text).strip()


def _iter_jsonl(path: str) -> Iterator[Tuple[int, Dict]]:
    """Yield ``(line_number, object)`` for each non-blank line of a JSONL file.

    Raises:
        ValueError: If a line is not valid JSON or is not a JSON object; the
            message includes the file path and 1-based line number.
    """
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                raise ValueError(f"{path}:{lineno}: invalid JSON ({e})") from e
            if not isinstance(obj, dict):
                raise ValueError(
                    f"{path}:{lineno}: expected a JSON object, "
                    f"got {type(obj).__name__}"
                )
            yield lineno, obj


def read_docstore(docstore_path: str) -> Dict[int, Dict]:
    """Load the docstore JSONL into a ``faiss_id -> row`` mapping.

    Each row is kept as the full parsed JSON object. Every row must carry a
    ``faiss_id`` key convertible to ``int``; a later row with the same id
    replaces an earlier one.

    Raises:
        ValueError: If a line is malformed or the file contains no rows.
        KeyError: If a row has no ``faiss_id`` key.
    """
    mapping: Dict[int, Dict] = {}
    for _, obj in _iter_jsonl(docstore_path):
        mapping[int(obj["faiss_id"])] = obj
    if not mapping:
        raise ValueError("docstore.jsonl is empty or unreadable.")
    return mapping


def load_prompts_from_jsonl(path: str) -> List[str]:
    """Load prompts from a JSONL file.

    The first non-empty value among the keys ``prompt`` (preferred),
    ``question`` and ``text`` is used. Prompts shorter than
    ``MIN_PROMPT_CHARS`` after cleaning are ignored.

    Raises:
        ValueError: If a line is malformed or no usable prompt is found.
    """
    prompts: List[str] = []
    for _, obj in _iter_jsonl(path):
        raw = obj.get("prompt") or obj.get("question") or obj.get("text")
        cleaned = simple_clean(raw) if raw else ""
        if len(cleaned) >= MIN_PROMPT_CHARS:
            prompts.append(cleaned)
    if not prompts:
        raise ValueError(f"No prompts found in JSONL: {path}")
    return prompts


def load_prompts_from_txt(path: str) -> List[str]:
    """Load prompts from a TXT file (one prompt per line).

    Lines shorter than ``MIN_PROMPT_CHARS`` after cleaning are ignored.

    Raises:
        ValueError: If no usable prompt is found.
    """
    with open(path, "r", encoding="utf-8") as f:
        cleaned_lines = (simple_clean(line) for line in f)
        prompts = [p for p in cleaned_lines if len(p) >= MIN_PROMPT_CHARS]
    if not prompts:
        raise ValueError(f"No prompts found in TXT: {path}")
    return prompts


def ensure_dir_for_file(path: str) -> None:
    """Create the parent directory of ``path`` if it has one and it is missing."""
    d = os.path.dirname(path)
    if d:
        os.makedirs(d, exist_ok=True)


def write_jsonl(path: str, rows: List[Dict]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSONL, one object per line.

    The file is overwritten and its parent directory is created if needed.
    """
    ensure_dir_for_file(path)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in rows)


# -----------------------------
# Persona + prompt templates (EN)
# -----------------------------
IMAGE_DIMS = [
    "Natural Attractions",
    "Atmosphere",
    "Social Environment",
    "Infrastructure",
    "Value for Money",
]

DEFAULT_PROMPTS_EN = [
    # Natural Attractions
    "In a lead user interview: what natural places in Bali felt genuinely memorable to you (rice terraces, volcanoes, waterfalls, coast), and why? Describe it like a lived experience.",
    "Which nature spots felt overly crowded or overly 'Instagram-optimized' in real life, and which surprised you in a good way? Explain with concrete moments.",
    # Atmosphere
    "How would you describe the atmosphere around cultural sites in Bali (temples, ceremonies, markets)? What signals authenticity vs. commercialization to you?",
    "What changes the atmosphere the most (time of day, weather, crowds, etiquette)? Share specific examples you would tell a marketer.",
    # Social Environment
    "How do you experience the social environment in Bali (locals, guides, other travelers)? What feels respectful and what feels performative or touristy?",
    "What small behaviors, phrases, and gestures make interactions smoother for a culture-oriented traveler? Give examples.",
    # Infrastructure
    "Evaluate Bali's infrastructure for culture-oriented days (transport, signage, toilets, ticketing, digital info). What works, what annoys you, and how do you adapt?",
    "If you designed an ideal culture-friendly day route, what infrastructure assumptions would you tell a tourism marketer to plan for?",
    # Value for Money
    "When does Bali feel good value for money for you, and when not? Discuss entrance fees, guides, food, tours, and hidden costs.",
    "How do you personally distinguish 'good value' from a tourist trap? List criteria and illustrate with examples.",
]


def build_system_prompt() -> str:
    """Return the system prompt defining the Lead User persona and constraints.

    The dimension bullets are generated from ``IMAGE_DIMS`` so the prompt and
    the metadata written with each example cannot drift apart.
    """
    dims = "\n".join(f"- {d}" for d in IMAGE_DIMS)
    return (
        "ROLE / PERSONA\n"
        "You are an experienced, culture-oriented Bali traveler (Lead User). You speak in natural, vivid English, "
        "as a real person in an interview. You share nuanced judgments, trade-offs, and concrete scenes.\n\n"
        "COGNITIVE DESTINATION IMAGE DIMENSIONS (use when relevant)\n"
        f"{dims}\n\n"
        "CRITICAL CONSTRAINTS\n"
        "- You will be given retrieved review snippets as memory support.\n"
        "- Do NOT quote them verbatim and do NOT cite them as 'the review says'.\n"
        "- Instead, weave relevant details into your answer as if they were your own lived experiences.\n"
        "- If a detail is not supported, stay general or hedge appropriately.\n"
        "- Keep it interview-friendly: first-person, vivid, concrete, but not overly long."
    )


def build_user_message(question: str, retrieved_chunks: List[str]) -> str:
    """Build the user turn: the question plus retrieved snippets as bullets.

    Each chunk is whitespace-normalized; chunks that are empty after cleaning
    are dropped.
    """
    cleaned = [c for c in map(simple_clean, retrieved_chunks) if c]
    bullets = "\n".join(f"- {c}" for c in cleaned)
    return (
        f"INTERVIEW QUESTION:\n{question}\n\n"
        "RETRIEVED CONTEXT (review snippets; do NOT quote, only use as memory support):\n"
        f"{bullets}\n\n"
        "Answer as a real Lead User in a tourism interview. Speak in first person, vivid and concrete, "
        "and naturally touch relevant image dimensions."
    )


# -----------------------------
# FAISS Retriever (cosine/IP)
# -----------------------------
class FaissRetriever:
    """Top-k text retrieval over a FAISS index with a JSONL docstore."""

    def __init__(self, index_path: str, docstore_path: str, embed_model: str):
        """Load the index, the docstore and the query embedding model.

        Raises:
            FileNotFoundError: If the index or docstore file does not exist.
        """
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Missing FAISS index at: {index_path}")
        if not os.path.exists(docstore_path):
            raise FileNotFoundError(f"Missing docstore at: {docstore_path}")

        self.index = faiss.read_index(index_path)
        self.docstore = read_docstore(docstore_path)

        # SentenceTransformer to match your indexing script defaults
        self.embedder = SentenceTransformer(embed_model)

        # Basic sanity check. A size mismatch is not necessarily fatal (the
        # docstore could include extra rows) but usually indicates the two
        # files were not generated together, so warn and carry on.
        if self.index.ntotal != len(self.docstore):
            print(
                f"Warning: index.ntotal={self.index.ntotal} but docstore rows={len(self.docstore)}. "
                "Ensure they were generated together."
            )

    def retrieve(self, query: str, k: int = 8) -> List[Tuple[int, float, str]]:
        """Return up to ``k`` hits as ``(faiss_id, score, text)`` tuples.

        The query is cleaned and embedded with normalized embeddings. Hits
        whose id is ``-1`` (no neighbor) or missing from the docstore are
        dropped, so fewer than ``k`` results may be returned.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")

        q = simple_clean(query)
        emb = self.embedder.encode([q], normalize_embeddings=True)
        emb = np.asarray(emb, dtype=np.float32)

        scores, ids = self.index.search(emb, k)

        out: List[Tuple[int, float, str]] = []
        for fid, sc in zip(ids[0].tolist(), scores[0].tolist()):
            if fid == -1:
                continue
            doc = self.docstore.get(int(fid))
            if not doc:
                continue
            out.append((int(fid), float(sc), doc.get("text", "")))
        return out


# -----------------------------
# Dataset generation
# -----------------------------

# A validation split is only produced when at least this many examples exist.
MIN_EXAMPLES_FOR_VAL = 20


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the dataset builder."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--index_dir",
        default="out",
        help="Directory containing faiss.index and docstore.jsonl",
    )
    ap.add_argument("--out_train", default="./out/raft_train.jsonl")
    ap.add_argument("--out_val", default="./out/raft_val.jsonl")
    ap.add_argument("--make_val", action="store_true")
    ap.add_argument("--val_ratio", type=float, default=0.05)
    ap.add_argument("--k", type=int, default=8)
    ap.add_argument("--seed", type=int, default=42)

    # Embeddings (must match indexing script for best results)
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )

    # External prompt sources
    ap.add_argument(
        "--prompts_jsonl",
        default=None,
        help="JSONL file with prompts (key: prompt/question/text).",
    )
    ap.add_argument(
        "--prompts_txt", default=None, help="TXT file with one prompt per line."
    )
    ap.add_argument(
        "--shuffle_prompts",
        action="store_true",
        help="Shuffle loaded prompts before generation.",
    )
    ap.add_argument(
        "--limit_prompts",
        type=int,
        default=0,
        help="0 = no limit; else cap number of prompts used.",
    )

    # DeepSeek generation config
    ap.add_argument(
        "--deepseek_base_url",
        default=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
    )
    ap.add_argument(
        "--deepseek_model", default=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
    )
    ap.add_argument("--temperature", type=float, default=0.85)
    ap.add_argument("--max_tokens", type=int, default=750)
    ap.add_argument(
        "--max_examples",
        type=int,
        default=0,
        help="0 = all prompts; else limit number of examples",
    )

    # pacing
    ap.add_argument("--sleep_s", type=float, default=0.2)
    return ap


def _select_prompts(args: argparse.Namespace) -> List[str]:
    """Load, optionally shuffle, and cap the prompts per the CLI options."""
    # Priority: JSONL -> TXT -> defaults
    if args.prompts_jsonl:
        prompts = load_prompts_from_jsonl(args.prompts_jsonl)
    elif args.prompts_txt:
        prompts = load_prompts_from_txt(args.prompts_txt)
    else:
        prompts = list(DEFAULT_PROMPTS_EN)

    if args.shuffle_prompts:
        random.shuffle(prompts)

    if args.limit_prompts and args.limit_prompts > 0:
        prompts = prompts[: args.limit_prompts]

    # Backwards-compat: args.max_examples can still cap prompts
    if args.max_examples and args.max_examples > 0:
        prompts = prompts[: args.max_examples]
    return prompts


def main():
    """Generate RAFT chat examples for each prompt and write them as JSONL.

    Cheap validation (arguments, API key, prompt files) runs before the FAISS
    index and embedding model are loaded. Prompts whose generation fails or
    returns an empty answer are skipped with a warning so that one bad call
    does not discard the examples already produced.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()

    if args.prompts_jsonl and args.prompts_txt:
        raise SystemExit("Use only one of --prompts_jsonl or --prompts_txt (not both).")
    if args.k < 1:
        ap.error("--k must be >= 1")
    if args.make_val and not 0.0 < args.val_ratio < 1.0:
        ap.error("--val_ratio must be strictly between 0 and 1")

    random.seed(args.seed)
    np.random.seed(args.seed)

    api_key = os.environ.get("DEEPSEEK_API_KEY", "").strip()
    if not api_key:
        raise SystemExit("Missing DEEPSEEK_API_KEY env var.")

    prompts = _select_prompts(args)

    index_path = os.path.join(args.index_dir, "faiss.index")
    docstore_path = os.path.join(args.index_dir, "docstore.jsonl")

    retriever = FaissRetriever(
        index_path=index_path,
        docstore_path=docstore_path,
        embed_model=args.embedding_model,
    )

    client = DeepSeekClient(
        DeepSeekConfig(
            api_key=api_key,
            base_url=args.deepseek_base_url,
            model=args.deepseek_model,
        )
    )

    system_prompt = build_system_prompt()
    index_dir_abs = os.path.abspath(args.index_dir)
    pause_s = max(0.0, args.sleep_s)

    examples = []
    skipped = 0
    for i, q in enumerate(tqdm(prompts, desc="Generating RAFT examples")):
        # Pace requests: pause between calls, not after the last one.
        if i > 0:
            time.sleep(pause_s)

        hits = retriever.retrieve(q, k=args.k)
        user_msg = build_user_message(q, [t for _, _, t in hits])
        request_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
        ]

        try:
            answer = client.chat(
                messages=request_messages,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
            )
        except RuntimeError as e:
            skipped += 1
            # tqdm.write keeps the progress bar intact while printing.
            tqdm.write(f"Warning: skipping prompt after API failure: {e}")
            continue

        if not answer:
            skipped += 1
            tqdm.write("Warning: skipping prompt because the model returned an empty answer.")
            continue

        examples.append(
            {
                "messages": [
                    *request_messages,
                    {"role": "assistant", "content": answer},
                ],
                "meta": {
                    "retrieval_k": args.k,
                    "index_dir": index_dir_abs,
                    "embedding_model": args.embedding_model,
                    "image_dimensions": IMAGE_DIMS,
                    "faiss_ids": [fid for fid, _, _ in hits],
                    "faiss_scores": [sc for _, sc, _ in hits],
                },
            }
        )

    if skipped:
        print(f"Warning: skipped {skipped} of {len(prompts)} prompts.")
    if not examples:
        raise SystemExit("No examples were generated; nothing written.")

    random.shuffle(examples)

    if args.make_val and len(examples) >= MIN_EXAMPLES_FOR_VAL:
        val_n = max(1, int(len(examples) * args.val_ratio))
        val = examples[:val_n]
        train = examples[val_n:]
        write_jsonl(args.out_train, train)
        write_jsonl(args.out_val, val)
        print(f"Wrote train: {args.out_train} ({len(train)} examples)")
        print(f"Wrote val: {args.out_val} ({len(val)} examples)")
    else:
        write_jsonl(args.out_train, examples)
        print(f"Wrote: {args.out_train} ({len(examples)} examples)")
        if args.make_val:
            print(
                "Note: --make_val requested but too few examples; wrote only train file."
            )


if __name__ == "__main__":
    main()