New RAFT approach

2026-02-19 14:24:42 +01:00
parent d0d3edae14
commit 28823dc0b5
5 changed files with 591 additions and 0 deletions

162
raft/make_raft_data.py Normal file

@@ -0,0 +1,162 @@
import argparse
import json
import os
import random
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
## Usage: python make_raft_data.py --out_dir out --n_examples 5000
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give opinions and guidance informed by local etiquette and context.
You avoid stereotypes. You explain local etiquette, customs, and context.
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
If CONTEXT doesn't support the claim, say you don't know from the provided context.
"""
def load_docstore(path):
docs = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
docs.append(json.loads(line))
return docs
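# Embeddings are L2-normalized at encode time, so the inner-product search
# below is equivalent to cosine similarity over the review chunks.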
def retrieve(index, embedder, query, top_k=6):
q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
scores, ids = index.search(q, top_k)
return ids[0].tolist(), scores[0].tolist()
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    # Build the prompt with the tokenizer's chat template and append the
    # assistant generation prompt so the model answers rather than continues
    input_ids = tok.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
out = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=0.9,
eos_token_id=tok.eos_token_id,
)
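    # Decode only the newly generated tokens, skipping the prompt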
return tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--out_dir", default="out")
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
ap.add_argument("--n_examples", type=int, default=5000)
ap.add_argument("--top_k", type=int, default=6)
ap.add_argument("--n_distractors", type=int, default=3)
ap.add_argument("--seed", type=int, default=7)
args = ap.parse_args()
random.seed(args.seed)
faiss_path = os.path.join(args.out_dir, "faiss.index")
docstore_path = os.path.join(args.out_dir, "docstore.jsonl")
index = faiss.read_index(faiss_path)
docstore = load_docstore(docstore_path)
embedder = SentenceTransformer(args.embedding_model)
# Teacher model to synthesize questions & answers from review chunks
tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
args.teacher_model, torch_dtype=torch.float16, device_map="auto"
)
model.eval()
out_path = os.path.join(args.out_dir, "raft_train.jsonl")
with open(out_path, "w", encoding="utf-8") as f:
for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
# pick a "gold" chunk
gold = random.choice(docstore)
gold_text = gold["text"]
# 1) generate a question answerable from gold_text
q_prompt = [
{"role": "system", "content": SYSTEM_PERSONA},
{
"role": "user",
"content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
f"CONTEXT:\n{gold_text}\n\n"
"Return only the question.",
},
]
question = generate_text(
model, tok, q_prompt, max_new_tokens=60, temperature=0.8
)
question = question.split("\n")[0].strip()
# 2) retrieve top-k for that question
ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
retrieved = [docstore[i] for i in ids]
# 3) add distractors (random docs not in retrieved)
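            # RAFT: mixing distractor chunks into the context trains the model
            # to ignore irrelevant passages and ground answers in the gold text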
retrieved_ids = set(ids)
distractors = []
attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in retrieved_ids:
                    continue
                retrieved_ids.add(cand_idx)  # avoid sampling the same distractor twice
                distractors.append(docstore[cand_idx])
# Mix: retrieved + distractors
context_docs = retrieved + distractors
random.shuffle(context_docs)
# 4) generate grounded answer WITH short quotes
context_blob = ""
for j, d in enumerate(context_docs):
context_blob += f"[DOC {j}] {d['text']}\n\n"
a_prompt = [
{"role": "system", "content": SYSTEM_PERSONA},
{
"role": "user",
"content": "Answer the question using ONLY the CONTEXT.\n"
"Rules:\n"
"- Include 12 short direct quotes from CONTEXT as evidence.\n"
"- If the answer isn't supported, say you can't tell from the context.\n\n"
f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
},
]
answer = generate_text(
model, tok, a_prompt, max_new_tokens=260, temperature=0.6
)
# Final training example (conversational dataset format for TRL)
train_ex = {
"messages": [
{"role": "system", "content": SYSTEM_PERSONA},
{
"role": "user",
"content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
"Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
},
{"role": "assistant", "content": answer},
]
}
f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")
print(f"Wrote {out_path}")
if __name__ == "__main__":
main()

122
raft/prepare_corpus.py Normal file

@@ -0,0 +1,122 @@
import argparse
import json
import os
import re
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
## Usage: python prepare_corpus.py --input_tab your_reviews.tab --out_dir out
def simple_clean(text: str) -> str:
if not isinstance(text, str):
return ""
text = text.replace("\u00a0", " ")
text = re.sub(r"\s+", " ", text).strip()
return text
def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
"""
Simple char-based chunking (good enough for reviews).
For better chunking, split by sentences and cap token length.
"""
text = simple_clean(text)
if len(text) <= chunk_chars:
return [text] if text else []
chunks = []
i = 0
while i < len(text):
chunk = text[i : i + chunk_chars]
if chunk:
chunks.append(chunk)
i += max(1, chunk_chars - overlap)
return chunks
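# e.g. a 2000-char review yields chunks starting at offsets 0, 750 and 1500
# (stride = chunk_chars - overlap = 750 by default)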
def detect_text_col(df: pd.DataFrame) -> str:
# Heuristic: pick the longest average string column
best_col, best_score = None, -1
for col in df.columns:
sample = df[col].dropna().astype(str).head(200)
if len(sample) == 0:
continue
avg_len = sample.map(len).mean()
if avg_len > best_score:
best_score = avg_len
best_col = col
if best_col is None:
raise ValueError("Could not detect a text column in the .tab file.")
return best_col
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--input_tab", required=True, help="Tripadvisor reviews .tab file")
ap.add_argument("--out_dir", default="out")
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--chunk_chars", type=int, default=900)
ap.add_argument("--overlap", type=int, default=150)
args = ap.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
# Many .tab files are TSV
df = pd.read_csv(args.input_tab, sep="\t", dtype=str, on_bad_lines="skip")
text_col = detect_text_col(df)
rows = df[text_col].fillna("").astype(str).tolist()
corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
corpus = []
doc_id = 0
for r in tqdm(rows, desc="Chunking"):
r = simple_clean(r)
if len(r) < 30:
continue
chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
for ch in chunks:
if len(ch) < 30:
continue
corpus.append({"doc_id": doc_id, "text": ch})
doc_id += 1
with open(corpus_path, "w", encoding="utf-8") as f:
for ex in corpus:
f.write(json.dumps(ex, ensure_ascii=False) + "\n")
# Build FAISS index
embedder = SentenceTransformer(args.embedding_model)
texts = [c["text"] for c in corpus]
embs = embedder.encode(
texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True
)
embs = np.asarray(embs, dtype=np.float32)
dim = embs.shape[1]
index = faiss.IndexFlatIP(dim) # cosine if normalized
index.add(embs)
faiss_path = os.path.join(args.out_dir, "faiss.index")
faiss.write_index(index, faiss_path)
    # Store mapping FAISS id -> chunk text
mapping_path = os.path.join(args.out_dir, "docstore.jsonl")
with open(mapping_path, "w", encoding="utf-8") as f:
for i, c in enumerate(corpus):
f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")
print(
f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}\nText column detected: {text_col}"
)
if __name__ == "__main__":
main()

89
raft/rag_chat.py Normal file

@@ -0,0 +1,89 @@
import argparse
import json
import os
import faiss
import numpy as np
import torch
from peft import PeftModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
## Usage: python rag_chat.py --lora_dir out/mistral_balitwin_lora
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give opinions and guidance informed by local etiquette and context.
Use the provided CONTEXT; include 1-2 short quotes as evidence.
If the context does not support the claim, say so.
"""
def load_docstore(path):
docs = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
docs.append(json.loads(line))
return docs
def retrieve(index, embedder, query, top_k=6):
q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
scores, ids = index.search(q, top_k)
return ids[0].tolist(), scores[0].tolist()
@torch.no_grad()
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
ap.add_argument("--lora_dir", default="out/mistral_balitwin_lora")
ap.add_argument("--out_dir", default="out")
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--top_k", type=int, default=6)
args = ap.parse_args()
index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
embedder = SentenceTransformer(args.embedding_model)
tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
base = AutoModelForCausalLM.from_pretrained(
args.base_model, device_map="auto", torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base, args.lora_dir)
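    # Attach the trained LoRA adapter on top of the frozen fp16 base weights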
model.eval()
print("Type your question (Ctrl+C to exit).")
while True:
q = input("\nYou: ").strip()
if not q:
continue
ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
context_docs = [docstore[i]["text"] for i in ids]
context_blob = "\n\n".join(
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
)
messages = [
{"role": "system", "content": SYSTEM_PERSONA},
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
]
        inp = tok.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
out = model.generate(
inp,
max_new_tokens=320,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tok.eos_token_id,
)
ans = tok.decode(out[0][inp.shape[1] :], skip_special_tokens=True).strip()
print(f"\nBaliTwin: {ans}")
if __name__ == "__main__":
main()

118
raft/rag_chat_merged.py Normal file

@@ -0,0 +1,118 @@
import argparse
import json
import os
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
## Usage: python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give opinions and guidance informed by local etiquette and context.
Use the provided CONTEXT; include 1-2 short quotes as evidence.
If the context does not support the claim, say so.
"""
def load_docstore(path):
docs = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
docs.append(json.loads(line))
return docs
def retrieve(index, embedder, query, top_k=6):
q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
scores, ids = index.search(q, top_k)
return ids[0].tolist(), scores[0].tolist()
@torch.no_grad()
def main():
ap = argparse.ArgumentParser()
ap.add_argument(
"--model_dir", required=True, help="Path to your finetuned model folder"
)
ap.add_argument(
"--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
)
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--top_k", type=int, default=6)
ap.add_argument("--max_new_tokens", type=int, default=320)
args = ap.parse_args()
index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
embedder = SentenceTransformer(args.embedding_model)
# Load your externally finetuned model directly from disk
tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
# Important: ensure pad token exists for generation; Mistral often uses eos as pad
if tok.pad_token is None:
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
device_map="auto",
torch_dtype=torch.float16,
)
model.eval()
print("Type your question (Ctrl+C to exit).")
while True:
q = input("\nYou: ").strip()
if not q:
continue
ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
context_docs = [docstore[i]["text"] for i in ids]
context_blob = "\n\n".join(
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
)
messages = [
{"role": "system", "content": SYSTEM_PERSONA},
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
]
        # Use the chat template from the model folder (chat_template.jinja)
enc = tok.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
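        # Depending on the tokenizers/transformers version, apply_chat_template
        # may return a bare tensor or a dict-like encoding; handle both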
if isinstance(enc, torch.Tensor):
input_ids = enc.to(model.device)
attention_mask = torch.ones_like(input_ids, device=model.device)
else:
input_ids = enc["input_ids"].to(model.device)
attention_mask = enc.get("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = attention_mask.to(model.device)
out = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=args.max_new_tokens,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
)
ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
print(f"\nBaliTwin: {ans}")
if __name__ == "__main__":
main()

100
raft/train_mistral_raft.py Normal file

@@ -0,0 +1,100 @@
import argparse
import json
import os
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
## Usage: python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mistral_balitwin_lora
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
ap.add_argument("--max_seq_len", type=int, default=2048)
ap.add_argument("--batch_size", type=int, default=1)
ap.add_argument("--grad_accum", type=int, default=16)
ap.add_argument("--lr", type=float, default=2e-4)
ap.add_argument("--epochs", type=int, default=1)
args = ap.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
# QLoRA (4-bit) config (good default for 7B on limited VRAM)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=(
torch.bfloat16 if torch.cuda.is_available() else torch.float16
),
bnb_4bit_use_double_quant=True,
)
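    # NF4 with double quantization stores weights in 4 bit; matmuls run in the
    # bf16/fp16 compute dtype set above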
tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    # Mistral usually has a valid chat template; keep it intact.
model = AutoModelForCausalLM.from_pretrained(
args.base_model,
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
)
# LoRA adapter config (tweak r/alpha if needed)
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
)
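    # The target modules above cover attention (q/k/v/o) and MLP (gate/up/down)
    # projections, i.e. all linear layers; a common QLoRA recipe for 7B models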
dataset = load_dataset("json", data_files=args.train_jsonl, split="train")
training_args = SFTConfig(
output_dir=args.out_dir,
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_accum,
learning_rate=args.lr,
logging_steps=10,
save_steps=200,
save_total_limit=2,
max_length=args.max_seq_len,
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
        assistant_only_loss=True,  # only learn from assistant turns in messages
report_to=[],
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
peft_config=peft_config,
)
trainer.train()
trainer.save_model(args.out_dir)
tokenizer.save_pretrained(args.out_dir)
print(f"Saved LoRA adapter to: {args.out_dir}")
if __name__ == "__main__":
main()