mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 08:22:43 +01:00
New RAFT approach
This commit is contained in:
162
raft/make_raft_data.py
Normal file
162
raft/make_raft_data.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
## Usage: python make_raft_data.py --out_dir out --n_examples 5000
|
||||
|
||||
|
||||
# System prompt shared by teacher generation and the exported training examples.
# Fix: "opinions nand guidance" -> "opinions and guidance" (typo in prompt text).
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
You avoid stereotypes. You explain local etiquette, customs, and context.
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
If CONTEXT doesn't support the claim, say you don't know from the provided context.
"""
|
||||
|
||||
|
||||
def load_docstore(path):
    """Read a JSONL docstore file and return its records as a list of dicts.

    Each line of *path* must be one JSON object (as written by the indexing
    step that produced ``docstore.jsonl``).
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(record) for record in handle]
|
||||
|
||||
|
||||
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and search the FAISS index for its nearest chunks.

    Returns a pair ``(ids, scores)`` for the ``top_k`` best matches, both as
    plain Python lists.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True)
    query_vec = query_vec.astype(np.float32)
    score_mat, id_mat = index.search(query_vec, top_k)
    return id_mat[0].tolist(), score_mat[0].tolist()
|
||||
|
||||
|
||||
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Run one chat-style sampling pass with the teacher model.

    Parameters
    ----------
    model: causal LM (e.g. ``AutoModelForCausalLM``) already placed on device.
    tok: its tokenizer; must provide a chat template.
    messages: list of ``{"role": ..., "content": ...}`` dicts.
    max_new_tokens / temperature: sampling controls.

    Returns the decoded completion text (prompt stripped, whitespace trimmed).
    """
    # Using tokenizer chat template where available.
    # Fix: add_generation_prompt=True appends the assistant-turn marker so the
    # model produces an answer instead of continuing the last user message.
    input_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    out = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
        # Explicit pad token avoids the per-call "pad_token_id not set" warning.
        pad_token_id=tok.eos_token_id,
    )
    # Decode only the newly generated tokens (slice off the prompt prefix).
    return tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
|
||||
|
||||
|
||||
def main():
    """Build a RAFT-style SFT dataset of (question, mixed context, grounded answer).

    Pipeline per example: pick a random "gold" chunk, have the teacher model
    write a question answerable from it, retrieve top-k chunks for that
    question, mix in random distractor chunks, then have the teacher answer
    using only the shuffled context. Writes ``raft_train.jsonl`` (TRL
    conversational format) into ``--out_dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=7)
    args = ap.parse_args()

    # Seed both RNGs: `random` drives chunk/distractor choice, torch drives
    # the teacher's sampling. Seeding only `random` left generation unseeded.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")

    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)

    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model to synthesize questions & answers from review chunks
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # pick a "gold" chunk
            gold = random.choice(docstore)
            gold_text = gold["text"]

            # 1) generate a question answerable from gold_text
            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
                    f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            question = question.split("\n")[0].strip()

            # 2) retrieve top-k for that question
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            # Fix: FAISS pads with -1 when fewer than top_k results exist;
            # docstore[-1] would silently pull in the LAST document.
            ids = [i for i in ids if i >= 0]
            retrieved = [docstore[i] for i in ids]

            # 3) add distractors (random docs not in retrieved)
            excluded_ids = set(ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in excluded_ids:
                    continue
                # Fix: record the pick so the same doc can't be drawn twice.
                excluded_ids.add(cand_idx)
                distractors.append(docstore[cand_idx])

            # Mix: retrieved + distractors
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 4) generate grounded answer WITH short quotes
            context_blob = ""
            for j, d in enumerate(context_docs):
                context_blob += f"[DOC {j}] {d['text']}\n\n"

            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    "- Include 1–2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # Final training example (conversational dataset format for TRL)
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
                        "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")

    print(f"Wrote {out_path}")
|
||||
|
||||
|
||||
# Script entry point; see usage comment at top of file.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user