New RAFT approach

2026-02-19 14:24:42 +01:00
parent d0d3edae14
commit 28823dc0b5
5 changed files with 591 additions and 0 deletions

raft/rag_chat_merged.py (new file, 118 lines)

@@ -0,0 +1,118 @@
import argparse
import json
import os
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
## Usage: python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give opinions and guidance grounded in local etiquette and context.
Use the provided CONTEXT; include 1-2 short quotes as evidence.
If the context does not support the claim, say so.
"""
def load_docstore(path):
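    # docstore.jsonl is expected to hold one JSON object per line, each with
    # at least a "text" field (see the indexing sketch at the end of this page).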
docs = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
docs.append(json.loads(line))
return docs
def retrieve(index, embedder, query, top_k=6):
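    # Query embeddings are L2-normalized, so inner-product search over a
    # matching FAISS index (e.g. IndexFlatIP) is equivalent to cosine similarity.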
q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
scores, ids = index.search(q, top_k)
return ids[0].tolist(), scores[0].tolist()
@torch.no_grad()
def main():
ap = argparse.ArgumentParser()
ap.add_argument(
"--model_dir", required=True, help="Path to your finetuned model folder"
)
ap.add_argument(
"--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
)
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--top_k", type=int, default=6)
ap.add_argument("--max_new_tokens", type=int, default=320)
args = ap.parse_args()
index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
embedder = SentenceTransformer(args.embedding_model)
# Load your externally finetuned model directly from disk
tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
# Important: ensure pad token exists for generation; Mistral often uses eos as pad
if tok.pad_token is None:
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
device_map="auto",
torch_dtype=torch.float16,
)
model.eval()
print("Type your question (Ctrl+C to exit).")
    while True:
        try:
            q = input("\nYou: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nExiting.")
            break
        if not q:
            continue
ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
        # FAISS pads results with -1 when fewer than top_k neighbors exist; skip those.
        context_docs = [docstore[i]["text"] for i in ids if i >= 0]
context_blob = "\n\n".join(
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
)
messages = [
{"role": "system", "content": SYSTEM_PERSONA},
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
]
        # Use the chat template shipped in the model folder (chat_template.jinja)
        enc = tok.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
if isinstance(enc, torch.Tensor):
input_ids = enc.to(model.device)
attention_mask = torch.ones_like(input_ids, device=model.device)
else:
input_ids = enc["input_ids"].to(model.device)
attention_mask = enc.get("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = attention_mask.to(model.device)
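        # Nucleus sampling (temperature 0.7, top_p 0.9) keeps replies varied.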
out = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=args.max_new_tokens,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
)
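        # Decode only the newly generated tokens, not the echoed prompt.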
ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
print(f"\nBaliTwin: {ans}")
if __name__ == "__main__":
main()
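
For reference, a minimal sketch of how a compatible out/faiss.index and out/docstore.jsonl pair could be built. This is a companion example, not the builder in this commit (that presumably lives in one of the other changed files); it assumes the same embedding model and an inner-product index over L2-normalized vectors, which is what retrieve() above expects:

import json
import os

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Hypothetical toy corpus; replace with real documents.
docs = [
    {"text": "Ubud temples expect a sarong and sash for entry."},
    {"text": "Canggu surf works best in the dry season."},
]

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
vecs = embedder.encode([d["text"] for d in docs], normalize_embeddings=True)
vecs = np.asarray(vecs, dtype=np.float32)

index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on unit vectors
index.add(vecs)

os.makedirs("out", exist_ok=True)
faiss.write_index(index, "out/faiss.index")
with open("out/docstore.jsonl", "w", encoding="utf-8") as f:
    for d in docs:
        f.write(json.dumps(d, ensure_ascii=False) + "\n")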