mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
422 lines
13 KiB
Python
422 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
RAFT dataset builder with FAISS-based retrieval.

Inputs:
  - faiss.index
  - docstore.jsonl
  - a prompt file: JSONL (--prompts_jsonl) or TXT (--prompts_txt)

Process:
  - Load a set of interview-style prompts (EN)
  - For each prompt:
      - Retrieve top-k chunks via FAISS cosine/IP
      - Call DeepSeek Chat Completions API to generate a vivid, human-like Lead User answer
  - Write training examples as JSONL in chat format (messages)

Outputs:
  - raft_train.jsonl
  - raft_val.jsonl (optional, with --make_val)

ENV:
  - DEEPSEEK_API_KEY (required)
  - DEEPSEEK_BASE_URL (optional, default for --deepseek_base_url)
  - DEEPSEEK_MODEL (optional, default for --deepseek_model)
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import random
|
|
import re
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List, Tuple
|
|
|
|
import faiss
|
|
import numpy as np
|
|
import requests
|
|
from sentence_transformers import SentenceTransformer
|
|
from tqdm import tqdm
|
|
|
|
|
|
@dataclass
class DeepSeekConfig:
    """Connection and retry settings for the DeepSeek chat-completions client."""

    # Secret sent as the Bearer token on every request.
    api_key: str
    # API root; the client appends "/chat/completions" to it.
    base_url: str = "https://api.deepseek.com"
    # Value sent as the "model" field of the request payload.
    model: str = "deepseek-chat"
    # Per-request timeout handed to requests.post, in seconds.
    timeout_s: int = 120
    # Number of attempts the client makes before giving up.
    max_retries: int = 5
    # Base of the exponential wait: backoff_s ** (attempt + 1) seconds.
    backoff_s: float = 1.6
|
|
|
|
|
|
class DeepSeekClient:
    """Minimal client for the DeepSeek chat-completions HTTP endpoint."""

    def __init__(self, cfg: DeepSeekConfig):
        """Store the connection/retry settings used by :meth:`chat`."""
        self.cfg = cfg

    def chat(
        self, messages: List[Dict], temperature: float = 0.85, max_tokens: int = 750
    ) -> str:
        """Send one chat-completion request and return the reply text.

        Args:
            messages: Chat messages (dicts with "role" and "content").
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion length limit forwarded to the API.

        Returns:
            The stripped content of the first returned choice.

        Raises:
            RuntimeError: if no attempt succeeded; the last failure is
                included in the message and chained as ``__cause__``.
        """
        # Tolerate a configured base URL that ends with "/".
        url = f"{self.cfg.base_url.rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.cfg.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.cfg.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        last_err = None
        for attempt in range(self.cfg.max_retries):
            try:
                r = requests.post(
                    url, headers=headers, json=payload, timeout=self.cfg.timeout_s
                )
                if r.status_code == 429:
                    # Record the rate limit so the final error is never
                    # "Last error: None" when every attempt was throttled.
                    last_err = RuntimeError("HTTP 429 Too Many Requests (rate limited)")
                else:
                    r.raise_for_status()
                    data = r.json()
                    return data["choices"][0]["message"]["content"].strip()
            except Exception as e:
                # Deliberately broad: transport errors, HTTP errors and
                # malformed response bodies are all retried, and the last
                # one is re-raised (chained) below.
                last_err = e

            # Back off only when another attempt follows; sleeping after
            # the final failure would just delay the error.
            if attempt + 1 < self.cfg.max_retries:
                time.sleep(self.cfg.backoff_s ** (attempt + 1))

        raise RuntimeError(
            f"DeepSeek API call failed. Last error: {last_err}"
        ) from last_err
|
|
|
|
|
|
# -----------------------------
|
|
# Helpers
|
|
# -----------------------------
|
|
def simple_clean(text: str) -> str:
    """Collapse every whitespace run in *text* to one space and trim the ends.

    Non-string input (for example ``None``) yields an empty string.
    """
    if isinstance(text, str):
        # split() without a separator breaks on any run of Unicode
        # whitespace (non-breaking spaces included) and drops leading and
        # trailing runs, so re-joining with " " normalises the string.
        return " ".join(text.split())
    return ""
|
|
|
|
|
|
def read_docstore(docstore_path: str) -> Dict[int, Dict]:
    """Load the JSONL docstore into a lookup keyed by FAISS id.

    Each non-blank line must be a JSON object with a ``faiss_id`` field
    convertible to ``int``; the whole object is stored as the value. If an
    id occurs more than once, the later line wins.

    Args:
        docstore_path: Path to the JSONL docstore file.

    Returns:
        dict: faiss_id -> the parsed record for that id.

    Raises:
        ValueError: if a line is not valid JSON, has no usable
            ``faiss_id``, or the file contains no records at all.
    """
    mapping: Dict[int, Dict] = {}
    with open(docstore_path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
                fid = int(obj["faiss_id"])
            except (ValueError, KeyError, TypeError) as e:
                # Point at the offending line instead of surfacing a bare
                # JSONDecodeError / KeyError / TypeError.
                raise ValueError(
                    f"{docstore_path}:{lineno}: invalid docstore record ({e!r})"
                ) from e
            mapping[fid] = obj
    if not mapping:
        raise ValueError(
            f"docstore.jsonl is broken: no records found in {docstore_path}"
        )
    return mapping
|
|
|
|
|
|
def load_prompts_from_jsonl(path: str, min_len: int = 20) -> List[str]:
    """Load prompts from a JSONL file.

    For each non-blank line the first truthy value among the keys
    ``prompt``, ``question`` and ``text`` is taken, whitespace-normalised
    with ``simple_clean``, and kept if it has at least ``min_len``
    characters. Rows that are not JSON objects are skipped.

    Args:
        path: Path to the JSONL file.
        min_len: Minimum cleaned length for a prompt to be kept.

    Returns:
        The accepted prompts, in file order.

    Raises:
        ValueError: if a line is not valid JSON (the message names the
            line), or if no prompt was accepted.
    """
    prompts: List[str] = []
    # "utf-8-sig" drops a leading byte-order mark, which json.loads would
    # otherwise reject on the first line; plain UTF-8 files read the same.
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError as e:
                raise ValueError(f"{path}:{lineno}: invalid JSON ({e})") from e
            if not isinstance(obj, dict):
                # A bare string/number/list has no prompt key to read.
                continue
            p = obj.get("prompt") or obj.get("question") or obj.get("text")
            p = simple_clean(p) if p else ""
            if len(p) >= min_len:
                prompts.append(p)
    if not prompts:
        raise ValueError(f"No prompts in JSONL: {path}")
    return prompts
|
|
|
|
|
|
def load_prompts_from_txt(path: str, min_len: int = 20) -> List[str]:
    """Load prompts from a TXT file (each line is a prompt).

    Every line is whitespace-normalised with ``simple_clean`` and kept if
    it has at least ``min_len`` characters.

    Args:
        path: Path to the text file.
        min_len: Minimum cleaned length for a prompt to be kept.

    Returns:
        The accepted prompts, in file order.

    Raises:
        ValueError: if no line qualifies as a prompt.
    """
    # "utf-8-sig" drops a leading byte-order mark; a BOM is not whitespace,
    # so it would otherwise stay glued to the first prompt.
    with open(path, "r", encoding="utf-8-sig") as f:
        cleaned = (simple_clean(line) for line in f)
        prompts = [p for p in cleaned if len(p) >= min_len]
    if not prompts:
        raise ValueError(f"No prompts in TXT: {path}")
    return prompts
|
|
|
|
|
|
def ensure_dir_for_file(path: str):
    """Create the directory that will contain *path*, including ancestors.

    The file itself is not created. Existing directories are left alone.
    """
    parent = os.path.dirname(path)
    if not parent:
        # Bare filename: there is no directory component to create.
        return
    os.makedirs(parent, exist_ok=True)
|
|
|
|
|
|
def write_jsonl(path: str, rows: List[Dict]) -> None:
    """Write *rows* to *path* as JSON Lines, one object per line.

    The parent directory is created first, an existing file is
    overwritten, and non-ASCII characters are written unescaped as UTF-8.
    """
    ensure_dir_for_file(path)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
|
|
|
|
|
|
# Image dimensions recorded under meta["image_dimensions"] of every
# generated example.
IMAGE_DIMS = [
    "Natural Attractions",
    "Atmosphere",
    "Social Environment",
    "Infrastructure",
    "Value for Money",
]
|
|
|
|
|
|
def build_system_prompt() -> str:
    """Return the fixed system prompt: the persona, then the constraint list."""
    persona = (
        "ROLE / PERSONA\n"
        "You are an experienced, culture-oriented Bali traveler (Lead User). You speak in natural, vivid English, "
        "as a real person in an interview. You share nuanced judgments, trade-offs, and concrete scenes."
    )
    constraints = [
        "- You will be given retrieved review snippets as memory support.",
        "- Do NOT quote them verbatim and do NOT cite them as 'the review says'.",
        "- Instead, weave relevant details into your answer as if they were your own lived experiences.",
        "- If a detail is not supported, stay general or hedge appropriately.",
        "- Keep it interview-friendly: first-person, vivid, concrete, but not overly long.",
    ]
    # A blank line separates the two sections; the prompt has no trailing newline.
    return persona + "\n\nCRITICAL CONSTRAINTS\n" + "\n".join(constraints)
|
|
|
|
|
|
def build_user_message(question: str, retrieved_chunks: List[str]) -> str:
    """Assemble the user turn: question, bulleted context, answer instruction.

    Each chunk is whitespace-normalised with ``simple_clean``; chunks that
    clean to an empty string are left out. The question is inserted as given.
    """
    bullet_lines = []
    for raw in retrieved_chunks:
        snippet = simple_clean(raw)
        if snippet:
            bullet_lines.append(f"- {snippet}")
    bullets = "\n".join(bullet_lines)

    question_part = f"INTERVIEW QUESTION:\n{question}\n\n"
    context_part = (
        "RETRIEVED CONTEXT (review snippets; do NOT quote, only use as memory support):\n"
        f"{bullets}\n\n"
    )
    instruction = (
        "Answer as a real Lead User in a tourism interview. Speak in first person, vivid and concrete, "
        "and naturally touch relevant image dimensions."
    )
    return question_part + context_part + instruction
|
|
|
|
|
|
# -----------------------------
|
|
# FAISS Retriever (cosine/IP)
|
|
# -----------------------------
|
|
class FaissRetriever:
    """Text retriever backed by a FAISS index and a JSONL docstore."""

    def __init__(self, index_path: str, docstore_path: str, embed_model: str):
        """Load the index, the docstore and the query embedding model.

        Args:
            index_path: Path to the serialized FAISS index.
            docstore_path: Path to the JSONL docstore (faiss_id -> record).
            embed_model: SentenceTransformer model name or path; it should
                be the model the index was built with.

        Raises:
            FileNotFoundError: if either input file is missing.
        """
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"Missing FAISS index at: {index_path}")
        if not os.path.exists(docstore_path):
            raise FileNotFoundError(f"Missing docstore at: {docstore_path}")

        self.index = faiss.read_index(index_path)
        self.docstore = read_docstore(docstore_path)

        # SentenceTransformer to match your indexing script defaults
        self.embedder = SentenceTransformer(embed_model)

        # Basic sanity check. A size mismatch is not necessarily fatal (the
        # docstore could include extra rows) but usually means the two files
        # were not generated together, so warn and carry on.
        if self.index.ntotal != len(self.docstore):
            print(
                f"Warning: index.ntotal={self.index.ntotal} but docstore rows={len(self.docstore)}. "
                "Ensure they were generated together."
            )

    def retrieve(self, query: str, k: int = 8) -> List[Tuple[int, float, str]]:
        """Return up to *k* nearest docstore entries for *query*.

        The query is whitespace-normalised, embedded with L2 normalisation
        and searched against the index.
        NOTE(review): scores are cosine similarities only if the index is an
        inner-product index over normalised vectors; confirm against the
        indexing script.

        Args:
            query: Free-text query.
            k: Number of neighbours to request; must be >= 1.

        Returns:
            List of (faiss_id, score, text) in the order FAISS returned
            them. Padding ids (-1) and ids missing from the docstore are
            skipped, so the list may hold fewer than *k* items.

        Raises:
            ValueError: if *k* < 1, or if the embedding width differs from
                the index dimensionality.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")

        q = simple_clean(query)
        emb = self.embedder.encode([q], normalize_embeddings=True)
        emb = np.asarray(emb, dtype=np.float32)

        # Fail with an actionable message instead of an opaque FAISS
        # assertion when the embedding model does not match the index.
        if emb.ndim != 2 or emb.shape[1] != self.index.d:
            raise ValueError(
                f"Query embedding has shape {emb.shape} but the index expects "
                f"dimension {self.index.d}; use the embedding model the index was built with."
            )

        scores, ids = self.index.search(emb, k)

        out = []
        for fid, sc in zip(ids[0].tolist(), scores[0].tolist()):
            if fid == -1:
                # FAISS pads with -1 when fewer than k neighbours exist.
                continue
            doc = self.docstore.get(int(fid))
            if not doc:
                continue
            # "or" also maps a JSON null text to "" so the tuple keeps its
            # declared (int, float, str) shape.
            out.append((int(fid), float(sc), doc.get("text") or ""))
        return out
|
|
|
|
|
|
# -----------------------------
|
|
# Dataset generation
|
|
# -----------------------------
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Declare the command-line interface of the dataset builder."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--index_dir",
        default="out",
        help="Directory containing faiss.index and docstore.jsonl",
    )
    ap.add_argument("--out_train", default="./out/raft_train.jsonl")
    ap.add_argument("--out_val", default="./out/raft_val.jsonl")
    ap.add_argument("--make_val", action="store_true")
    ap.add_argument("--val_ratio", type=float, default=0.05)
    ap.add_argument("--k", type=int, default=8)
    ap.add_argument("--seed", type=int, default=42)

    # Embeddings (must match indexing script for best results)
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )

    # External prompt sources
    ap.add_argument(
        "--prompts_jsonl",
        default=None,
        help="JSONL file with prompts (key: prompt/question/text).",
    )
    ap.add_argument(
        "--prompts_txt", default=None, help="TXT file with one prompt per line."
    )
    ap.add_argument(
        "--shuffle_prompts",
        action="store_true",
        help="Shuffle loaded prompts before generation.",
    )
    ap.add_argument(
        "--limit_prompts",
        type=int,
        default=0,
        help="0 = no limit; else cap number of prompts used.",
    )

    # DeepSeek generation config
    ap.add_argument(
        "--deepseek_base_url",
        default=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
    )
    ap.add_argument(
        "--deepseek_model", default=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
    )
    ap.add_argument("--temperature", type=float, default=0.85)
    ap.add_argument("--max_tokens", type=int, default=750)
    ap.add_argument(
        "--max_examples",
        type=int,
        default=0,
        help="0 = all prompts; else limit number of examples",
    )

    # pacing
    ap.add_argument("--sleep_s", type=float, default=0.2)
    return ap


def _load_prompts(args: argparse.Namespace) -> List[str]:
    """Read prompts from the single source selected on the command line."""
    if args.prompts_jsonl:
        return load_prompts_from_jsonl(args.prompts_jsonl)
    return load_prompts_from_txt(args.prompts_txt)


def _write_splits(args: argparse.Namespace, examples: List[Dict]) -> None:
    """Write the train file and, when requested and feasible, the val file."""
    min_examples_for_val = 20
    if args.make_val and len(examples) >= min_examples_for_val:
        val_n = max(1, int(len(examples) * args.val_ratio))
        val = examples[:val_n]
        train = examples[val_n:]
        write_jsonl(args.out_train, train)
        write_jsonl(args.out_val, val)
        print(f"Wrote train: {args.out_train} ({len(train)} examples)")
        print(f"Wrote val: {args.out_val} ({len(val)} examples)")
        return

    write_jsonl(args.out_train, examples)
    print(f"Wrote: {args.out_train} ({len(examples)} examples)")
    if args.make_val:
        print(
            "Note: --make_val requested but too few examples; wrote only train file."
        )


def main():
    """Build the RAFT dataset: retrieve context, generate answers, write JSONL.

    Exits via ``SystemExit`` on invalid arguments or a missing API key. If
    the DeepSeek API fails part-way through, generation stops, the examples
    collected so far are still written, and the process exits with status 1.
    """
    args = _build_arg_parser().parse_args()

    # Validate cheap things first so mistakes surface before the FAISS
    # index and the embedding model are loaded.
    if args.prompts_jsonl and args.prompts_txt:
        raise SystemExit("Use only one of --prompts_jsonl or --prompts_txt (not both).")
    if not (args.prompts_jsonl or args.prompts_txt):
        raise SystemExit("Provide a prompt source with --prompts_jsonl or --prompts_txt.")
    if args.make_val and not 0.0 <= args.val_ratio < 1.0:
        # A ratio >= 1 would move every example into the validation split.
        raise SystemExit("--val_ratio must be >= 0 and < 1.")

    api_key = os.environ.get("DEEPSEEK_API_KEY", "").strip()
    if not api_key:
        raise SystemExit("Missing DEEPSEEK_API_KEY env var.")

    random.seed(args.seed)
    np.random.seed(args.seed)

    # Reading the prompt file uses no randomness, so doing it before the
    # heavy loads does not change what the seeded shuffles produce.
    prompts = _load_prompts(args)

    retriever = FaissRetriever(
        index_path=os.path.join(args.index_dir, "faiss.index"),
        docstore_path=os.path.join(args.index_dir, "docstore.jsonl"),
        embed_model=args.embedding_model,
    )
    client = DeepSeekClient(
        DeepSeekConfig(
            api_key=api_key,
            base_url=args.deepseek_base_url,
            model=args.deepseek_model,
        )
    )
    system_prompt = build_system_prompt()

    if args.shuffle_prompts:
        random.shuffle(prompts)
    if args.limit_prompts and args.limit_prompts > 0:
        prompts = prompts[: args.limit_prompts]
    # Backwards-compat: args.max_examples can still cap prompts. One example
    # is produced per prompt, so no further check is needed inside the loop.
    if args.max_examples and args.max_examples > 0:
        prompts = prompts[: args.max_examples]

    examples = []
    failure = None
    for q in tqdm(prompts, desc="Generating RAFT examples"):
        hits = retriever.retrieve(q, k=args.k)
        user_msg = build_user_message(q, [t for _, _, t in hits])
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
        ]

        try:
            answer = client.chat(
                messages=messages,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
            )
        except RuntimeError as e:
            # The client already retried; stop here and keep what was
            # generated instead of discarding all earlier examples.
            failure = e
            break

        examples.append(
            {
                "messages": messages + [{"role": "assistant", "content": answer}],
                "meta": {
                    "retrieval_k": args.k,
                    "index_dir": os.path.abspath(args.index_dir),
                    "embedding_model": args.embedding_model,
                    "image_dimensions": IMAGE_DIMS,
                    "faiss_ids": [fid for fid, _, _ in hits],
                    "faiss_scores": [sc for _, sc, _ in hits],
                },
            }
        )

        time.sleep(max(0.0, args.sleep_s))

    if failure is not None:
        if not examples:
            # Nothing to save; do not overwrite existing output files.
            raise SystemExit(f"Generation failed before any example was produced: {failure}")
        print(
            f"Warning: generation stopped after {len(examples)} of {len(prompts)} prompts: {failure}"
        )

    random.shuffle(examples)
    _write_splits(args, examples)

    if failure is not None:
        # Signal the incomplete run to calling scripts.
        raise SystemExit(1)
|
|
|
|
|
|
# Run the builder only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|