Files
masterthesis-playground/raft/make_raft_data.py
2026-02-21 23:47:12 +01:00

189 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse
import json
import os
import random
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# Persona injected as the system message for both the question-synthesis and
# answer-generation calls below; it is also written verbatim into every
# training example, so the fine-tuned student inherits the same identity.
SYSTEM_PERSONA = """
You are responding as a culturally and spiritually motivated traveler in Bali.
Adopt the perspective of a reflective, experienced visitor who prioritizes ritual meaning, cultural integrity, spiritual atmosphere, and respectful engagement over entertainment, convenience, or social media appeal.
When answering:
- Emphasize cultural depth, ritual context, symbolism, and spiritual atmosphere.
- Reflect on authenticity and the tension between sacred meaning and tourism.
- Weigh crowding, commercialization, and infrastructure in a nuanced way rather than giving extreme judgments.
- Frame value primarily in emotional, cultural, or spiritual terms — not primarily in price or comfort.
- Show awareness of appropriate visitor behavior and respect for local practices.
- Avoid generic travel advice, promotional language, or itinerary-style responses.
- Write in a thoughtful, first-person perspective.
- Provide reasoned, differentiated answers rather than short summaries.
- Do not list bullet points unless explicitly asked.
- Keep answers focused on the question.
Maintain consistency with this identity across all responses.
"""
# Instruction prepended to the gold chunk when asking the teacher model to
# synthesize a question.
# NOTE(review): the phrasing "from the perspective of a touristic marketer
# they might ask a ... traveler ... considered to be a lead user" reads
# garbled — confirm who is meant to be asking whom (marketer interviewing a
# lead-user traveler?) before regenerating data with this prompt.
TRAINER_PROMPT = "Create ONE realistic question from the perspective of a touristic marketer they might ask a culturally and spiritually interested traveler in Bali considered to be a lead user that can be answered using ONLY the CONTEXT.\n\n"
def load_docstore(path):
    """Read a JSONL docstore file and return its records as a list of dicts.

    Each line of the file at *path* must be one JSON object.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw_line) for raw_line in handle]
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and search *index* for its nearest documents.

    Returns a pair of plain Python lists: (doc_ids, similarity_scores) for
    the top_k hits of the single query.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    sims, doc_ids = index.search(query_vec, top_k)
    # Single-query batch: unwrap row 0 and convert to native Python types.
    return doc_ids[0].tolist(), sims[0].tolist()
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Generate a chat completion and return only the new text, stripped.

    Uses the tokenizer's chat template where available. Handles both return
    shapes of ``apply_chat_template``: a bare id tensor or a dict-like batch
    (which may or may not include an attention mask).
    """
    encoded = tok.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    if isinstance(encoded, torch.Tensor):
        # Bare tensor of input ids: synthesize an all-ones mask (no padding
        # is possible for a single un-batched prompt).
        prompt_ids = encoded.to(model.device)
        mask = torch.ones_like(prompt_ids, device=model.device)
    else:
        prompt_ids = encoded["input_ids"].to(model.device)
        mask = encoded.get("attention_mask")
        mask = torch.ones_like(prompt_ids) if mask is None else mask.to(model.device)

    output_ids = model.generate(
        input_ids=prompt_ids,
        attention_mask=mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    # Drop the prompt prefix so only the model's continuation is decoded.
    continuation = output_ids[0][prompt_ids.shape[1]:]
    return tok.decode(continuation, skip_special_tokens=True).strip()
def main():
    """Build a RAFT-style SFT dataset and write it as JSONL.

    Pipeline per example:
      1. sample a "gold" chunk from the docstore,
      2. have the teacher model write a question answerable from that chunk,
      3. retrieve top-k chunks for the question and mix in random distractors,
      4. have the teacher answer the question grounded only in the mixed
         context,
      5. write the (system, user, assistant) conversation as one JSONL line
         (TRL conversational dataset format).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    random.seed(args.seed)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")
    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)
    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model used to synthesize both questions and grounded answers.
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # 1) pick a "gold" chunk the question must be answerable from.
            gold = random.choice(docstore)
            gold_text = gold["text"]

            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": TRAINER_PROMPT + f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            # Keep only the first line in case the model rambles past it.
            question = question.split("\n")[0].strip()

            # 2) retrieve top-k chunks for that question. FAISS pads search
            # results with -1 when fewer than top_k neighbors exist; drop
            # those rather than letting docstore[-1] silently grab the last
            # document.
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            ids = [i for i in ids if i >= 0]
            retrieved = [docstore[i] for i in ids]

            # 3) add random distractor docs that are neither retrieved nor
            # already chosen (the original loop could append duplicates).
            excluded = set(ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in excluded:
                    continue
                excluded.add(cand_idx)
                distractors.append(docstore[cand_idx])

            # Shuffle so the gold/retrieved position carries no signal.
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 4) generate a grounded answer with short supporting quotes.
            context_blob = "".join(
                f"[DOC {j}] {d['text']}\n\n" for j, d in enumerate(context_docs)
            )
            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    # NOTE(review): the original read "Include 12 short direct
                    # quotes" — almost certainly "1-2" with the dash lost in a
                    # copy/paste (the file carried an ambiguous-Unicode
                    # warning); 12 quotes is infeasible in 260 new tokens.
                    "- Include 1-2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # 5) final training example (conversational dataset format for TRL).
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")
    print(f"Wrote {out_path}")
# Script entry point: run the dataset generation only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()