Files
masterthesis-playground/raft/make_raft_data.py

160 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse
import json
import os
import random
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
# System prompt shared by the teacher model (question and answer synthesis)
# and by every emitted training example, so the fine-tuned model is trained
# on exactly the persona the teacher answered under.
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
You avoid stereotypes. You explain local etiquette, customs, and context.
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
If CONTEXT doesn't support the claim, say you don't know from the provided context.
"""
def load_docstore(path):
    """Load a JSON Lines docstore into a list of parsed documents.

    Each non-blank line of *path* is parsed as one JSON value. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped rather
    than crashing ``json.loads``.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        The parsed documents, in file order. Callers index this list
        positionally with FAISS ids, so the order must be preserved.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def retrieve(index, embedder, query, top_k=6):
    """Return the ids and scores of the *top_k* nearest chunks to *query*.

    The query is embedded with ``normalize_embeddings=True`` and cast to
    float32, the dtype FAISS expects. Whether the scores are similarities or
    distances depends on how the index was built (presumably inner product
    on normalized vectors -- confirm against the index builder).

    FAISS pads result slots with id ``-1`` when the index holds fewer than
    *top_k* vectors. Those slots are dropped here, so ``docstore[i]`` can
    never silently resolve ``-1`` to the last document.

    Returns:
        ``(ids, scores)`` as two equal-length Python lists (int / float),
        in the order returned by the index; may be shorter than *top_k*.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = index.search(query_vec, top_k)
    hits = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    return [i for i, _ in hits], [s for _, s in hits]
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7, top_p=0.9):
    """Sample one chat completion from a causal LM and return only the new text.

    Args:
        model: Hugging Face causal LM exposing ``generate`` and ``device``.
        tok: Matching tokenizer; must provide a chat template.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling threshold.

    Returns:
        The decoded completion with prompt tokens removed, special tokens
        skipped, and surrounding whitespace stripped.
    """
    # add_generation_prompt appends the assistant-turn header for templates
    # that define one; without it such models keep writing the user turn.
    # Templates without a header are unaffected.
    input_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # One unpadded sequence, so every position is attended to. Passing the
    # mask and a pad id explicitly stops generate() from guessing (and
    # warning) for tokenizers that define no pad token.
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    out = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tok.eos_token_id,
        pad_token_id=pad_token_id,
    )
    # generate() returns prompt + completion; keep only the new tokens.
    prompt_len = input_ids.shape[1]
    return tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
def _generate_question(model, tok, gold_text):
    """Ask the teacher for one traveler question answerable from *gold_text*.

    Returns the first line of the completion, stripped ("" if empty).
    """
    q_prompt = [
        {"role": "system", "content": SYSTEM_PERSONA},
        {
            "role": "user",
            "content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
            f"CONTEXT:\n{gold_text}\n\n"
            "Return only the question.",
        },
    ]
    raw = generate_text(model, tok, q_prompt, max_new_tokens=60, temperature=0.8)
    # Keep only the first line in case the teacher appends extra text.
    return raw.partition("\n")[0].strip()


def _sample_distractors(docstore, exclude_ids, n, max_attempts=50):
    """Draw up to *n* distinct random docs whose index is not in *exclude_ids*.

    Rejection sampling is capped at *max_attempts* draws, so fewer than *n*
    docs may be returned for a tiny docstore.
    """
    seen = set(exclude_ids)
    chosen = []
    attempts = 0
    while len(chosen) < n and attempts < max_attempts:
        attempts += 1
        idx = random.randrange(len(docstore))
        if idx in seen:
            continue
        seen.add(idx)  # never pick the same distractor twice
        chosen.append(docstore[idx])
    return chosen


def _format_context(docs):
    """Render docs as ``[DOC j] text`` paragraphs, each followed by a blank line."""
    return "".join(f"[DOC {j}] {d['text']}\n\n" for j, d in enumerate(docs))


def _generate_answer(model, tok, question, context_blob):
    """Ask the teacher for an answer to *question* grounded in *context_blob*."""
    a_prompt = [
        {"role": "system", "content": SYSTEM_PERSONA},
        # The quote rule says "1-2" to match SYSTEM_PERSONA and the final
        # training prompt.
        {
            "role": "user",
            "content": "Answer the question using ONLY the CONTEXT.\n"
            "Rules:\n"
            "- Include 1-2 short direct quotes from CONTEXT as evidence.\n"
            "- If the answer isn't supported, say you can't tell from the context.\n\n"
            f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
        },
    ]
    return generate_text(model, tok, a_prompt, max_new_tokens=260, temperature=0.6)


def _build_example(question, context_blob, answer):
    """Assemble one training record in TRL's conversational ``messages`` format."""
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PERSONA},
            {
                "role": "user",
                "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
                "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
            },
            {"role": "assistant", "content": answer},
        ]
    }


def main():
    """Generate a RAFT-style fine-tuning set from a prebuilt retrieval index.

    Reads ``faiss.index`` and ``docstore.jsonl`` from ``--out_dir``. For each
    example it:

    1. picks a random "gold" chunk;
    2. has the teacher LLM write a question answerable from that chunk;
    3. retrieves ``--top_k`` chunks for the question;
    4. mixes in up to ``--n_distractors`` other random chunks;
    5. asks the teacher for an answer that quotes the context.

    Writes one JSON record per line to ``<out_dir>/raft_train.jsonl``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=7)
    args = ap.parse_args()

    # `random` drives document sampling; torch drives the teacher's token
    # sampling (do_sample=True). Both must be seeded for --seed to make the
    # generated dataset reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")
    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)
    # Fail fast, before the expensive teacher model is loaded.
    if not docstore:
        raise SystemExit(f"No documents in {docstore_path}; nothing to sample from.")
    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model that synthesizes questions and answers from review chunks.
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    n_skipped = 0
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            gold = random.choice(docstore)
            question = _generate_question(model, tok, gold["text"])
            if not question:
                # An empty question would yield a meaningless training pair.
                n_skipped += 1
                continue

            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            # FAISS pads missing hits with -1, which would index the last doc.
            ids = [i for i in ids if i >= 0]
            # NOTE(review): the gold chunk is in the context only if retrieval
            # finds it. RAFT usually controls the share of gold-free contexts
            # explicitly -- confirm this implicit mix is intended.
            retrieved = [docstore[i] for i in ids]
            distractors = _sample_distractors(docstore, ids, args.n_distractors)

            context_docs = retrieved + distractors
            random.shuffle(context_docs)
            context_blob = _format_context(context_docs)

            answer = _generate_answer(model, tok, question, context_blob)
            train_ex = _build_example(question, context_blob, answer)
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")

    if n_skipped:
        print(f"Skipped {n_skipped} examples with an empty generated question")
    print(f"Wrote {out_path}")


if __name__ == "__main__":
    main()