New RAFT approach

This commit is contained in:
2026-02-19 14:24:42 +01:00
parent d0d3edae14
commit 28823dc0b5
5 changed files with 591 additions and 0 deletions

162
raft/make_raft_data.py Normal file
View File

@@ -0,0 +1,162 @@
import argparse
import json
import os
import random
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
## Usage: python make_raft_data.py --out_dir out --n_examples 5000
# Persona prompt shared by the teacher model (question & answer synthesis) and
# the final SFT examples. Fixed typo: "opinions nand guidance" -> "and guidance".
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
You avoid stereotypes. You explain local etiquette, customs, and context.
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
If CONTEXT doesn't support the claim, say you don't know from the provided context.
"""
def load_docstore(path):
    """Load a JSONL docstore file and return its records as a list of dicts.

    Each line of *path* must be one JSON object (as written by the indexing
    step); the order of records is preserved so positions match FAISS ids.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle]
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and return the *top_k* nearest docstore ids and scores.

    The embedding is L2-normalized by the sentence-transformers encoder and
    cast to float32 for FAISS. Returns two plain Python lists (ids, scores)
    for the single query.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True)
    query_vec = query_vec.astype(np.float32)
    scores, ids = index.search(query_vec, top_k)
    top_ids = ids[0].tolist()
    top_scores = scores[0].tolist()
    return top_ids, top_scores
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Sample a completion for a chat *messages* list and return the new text.

    Applies the tokenizer's chat template (where available), samples with
    nucleus top_p=0.9 at the given *temperature*, and decodes only the tokens
    generated after the prompt, stripped of surrounding whitespace.
    """
    prompt_ids = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
    prompt_len = prompt_ids.shape[1]
    generated = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
    )
    # Drop the prompt prefix so only newly generated tokens are decoded.
    new_tokens = generated[0][prompt_len:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()
def main():
    """Synthesize a RAFT-style SFT dataset from an existing FAISS index + docstore.

    Pipeline per example:
      1. pick a random "gold" chunk from the docstore,
      2. ask the teacher model to write a question answerable from that chunk,
      3. retrieve top-k chunks for the question and mix in random distractors,
      4. ask the teacher for an answer grounded in the shuffled context.

    Writes one TRL conversational-format example per line to
    <out_dir>/raft_train.jsonl. Requires faiss.index and docstore.jsonl to
    already exist in --out_dir (produced by the indexing step).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=7)
    args = ap.parse_args()

    random.seed(args.seed)

    # Artifacts produced by the indexing step; both must already exist.
    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")
    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)
    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model to synthesize questions & answers from review chunks.
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # 1) pick a "gold" chunk the question must be answerable from
            gold = random.choice(docstore)
            gold_text = gold["text"]

            # 2) generate a question answerable from gold_text
            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
                    f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            # Keep only the first line in case the teacher rambles.
            question = question.split("\n")[0].strip()

            # 3) retrieve top-k chunks for that question
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            retrieved = [docstore[i] for i in ids]

            # 4) add distractors: random docs not already retrieved and not
            # already chosen as a distractor (avoids duplicate distractors).
            # The attempts cap prevents an infinite loop on tiny docstores.
            excluded_ids = set(ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in excluded_ids:
                    continue
                excluded_ids.add(cand_idx)
                distractors.append(docstore[cand_idx])

            # Mix retrieved + distractors so the gold evidence position varies.
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 5) generate a grounded answer WITH short quotes
            context_blob = ""
            for j, d in enumerate(context_docs):
                context_blob += f"[DOC {j}] {d['text']}\n\n"
            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    # Fixed "12 short direct quotes" -> "1-2" to match the
                    # persona and the final training prompt below.
                    "- Include 1-2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # Final training example (conversational dataset format for TRL)
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
                        "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")
    print(f"Wrote {out_path}")
if __name__ == "__main__":
main()