mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 08:22:43 +01:00
160 lines
5.9 KiB
Python
160 lines
5.9 KiB
Python
import argparse
|
||
import json
|
||
import os
|
||
import random
|
||
|
||
import faiss
|
||
import numpy as np
|
||
import torch
|
||
from sentence_transformers import SentenceTransformer
|
||
from tqdm import tqdm
|
||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
||
# System prompt shared by question synthesis, answer synthesis, and the final
# training examples. It pins the teacher to a grounded-QA contract: quote 1-2
# short snippets from CONTEXT and refuse when CONTEXT doesn't support a claim.
# Fix: "opinions nand guidance" -> "opinions and guidance" (typo in persona text).
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
You avoid stereotypes. You explain local etiquette, customs, and context.
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
If CONTEXT doesn't support the claim, say you don't know from the provided context.
"""
|
||
|
||
|
||
def load_docstore(path):
    """Load a JSONL docstore: one JSON object per line, returned as a list.

    Args:
        path: Path to a UTF-8 encoded .jsonl file.

    Returns:
        List of parsed documents in file order.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw_line) for raw_line in handle]
|
||
|
||
|
||
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and search the FAISS index for its nearest chunks.

    Args:
        index: A FAISS index exposing ``search(vectors, k)``.
        embedder: A SentenceTransformer-style encoder with ``encode``.
        query: The query string to embed.
        top_k: Number of neighbors to return.

    Returns:
        Tuple ``(ids, scores)`` — plain Python lists for the single query.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True)
    query_vec = query_vec.astype(np.float32)
    score_mat, id_mat = index.search(query_vec, top_k)
    return id_mat[0].tolist(), score_mat[0].tolist()
|
||
|
||
|
||
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Sample a chat completion and return only the newly generated text.

    The prompt is built via the tokenizer's chat template (where available),
    and the prompt tokens are sliced off before decoding so only the model's
    continuation is returned, stripped of surrounding whitespace.

    Args:
        model: Causal LM with a ``generate`` method and a ``device``.
        tok: Tokenizer providing ``apply_chat_template`` and ``decode``.
        messages: Chat messages (list of role/content dicts).
        max_new_tokens: Generation length cap.
        temperature: Sampling temperature.
    """
    prompt_ids = tok.apply_chat_template(messages, return_tensors="pt")
    prompt_ids = prompt_ids.to(model.device)
    sampled = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
    )
    continuation = sampled[0][prompt_ids.shape[1] :]
    return tok.decode(continuation, skip_special_tokens=True).strip()
|
||
|
||
|
||
def main():
    """Generate a RAFT-style grounded-QA training set as JSONL.

    Pipeline per example: pick a random "gold" chunk from the docstore,
    have the teacher model synthesize a question answerable from it,
    retrieve top-k chunks for that question, mix in random distractor
    chunks, then have the teacher write a quote-grounded answer over the
    shuffled context. Each example is written as a TRL-style chat record
    to <out_dir>/raft_train.jsonl.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=7)
    args = ap.parse_args()

    # Seeds Python's RNG (choice/randrange/shuffle below). NOTE(review):
    # torch sampling in generate_text is not seeded, so runs are not fully
    # reproducible — confirm whether that's intended.
    random.seed(args.seed)

    # Index and docstore are expected to have been built by an earlier step
    # into out_dir; docstore positions must align with FAISS ids.
    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")

    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)

    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model to synthesize questions & answers from review chunks
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # pick a "gold" chunk
            gold = random.choice(docstore)
            gold_text = gold["text"]

            # 1) generate a question answerable from gold_text
            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
                    f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            # Keep only the first line in case the model rambles past the question.
            question = question.split("\n")[0].strip()

            # 2) retrieve top-k for that question
            # NOTE(review): FAISS can return -1 ids when the index holds fewer
            # than top_k vectors, which would index docstore[-1] here — confirm
            # the index is always large enough.
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            retrieved = [docstore[i] for i in ids]

            # 3) add distractors (random docs not in retrieved)
            # Bounded rejection sampling: give up after 50 attempts so a tiny
            # docstore can't loop forever (may yield fewer distractors, and
            # duplicates among distractors are possible).
            retrieved_ids = set(ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in retrieved_ids:
                    continue
                distractors.append(docstore[cand_idx])

            # Mix: retrieved + distractors (shuffled so the gold position is random)
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 4) generate grounded answer WITH short quotes
            context_blob = ""
            for j, d in enumerate(context_docs):
                context_blob += f"[DOC {j}] {d['text']}\n\n"

            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    "- Include 1–2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # Final training example (conversational dataset format for TRL)
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
                        "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")

    print(f"Wrote {out_path}")
|
||
|
||
|
||
# Script entry point: run the generation pipeline only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()
|