"""Interactive RAG chat over a local FAISS index using an externally finetuned LM.

Usage:
    python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
"""

import argparse
import json
import os

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler. You give your opinions and guidance with local etiquette and context.
Use the provided CONTEXT; include 1-2 short quotes as evidence. If the context does not support the claim, say so.
"""


def load_docstore(path):
    """Load the JSONL docstore at *path*: one JSON object per line.

    Blank lines are skipped so a trailing newline in the file does not
    raise ``json.JSONDecodeError``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def retrieve(index, embedder, query, top_k=6):
    """Return ``(ids, scores)`` for the *top_k* nearest documents to *query*.

    FAISS pads results with id ``-1`` when the index holds fewer than
    *top_k* vectors; those sentinel entries are filtered out so callers
    can index the docstore safely (``docstore[-1]`` would otherwise
    silently return the last document).
    """
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = index.search(q, top_k)
    pairs = [
        (i, s) for i, s in zip(ids[0].tolist(), scores[0].tolist()) if i >= 0
    ]
    if not pairs:
        return [], []
    kept_ids, kept_scores = zip(*pairs)
    return list(kept_ids), list(kept_scores)


@torch.no_grad()
def main():
    """Run the interactive retrieve-then-generate chat loop."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--model_dir", required=True, help="Path to your finetuned model folder"
    )
    ap.add_argument(
        "--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
    )
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--max_new_tokens", type=int, default=320)
    args = ap.parse_args()

    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    embedder = SentenceTransformer(args.embedding_model)

    # Load your externally finetuned model directly from disk
    tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)

    # Important: ensure pad token exists for generation; Mistral often uses eos as pad
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    model.eval()

    print("Type your question (Ctrl+C to exit).")
    try:
        while True:
            q = input("\nYou: ").strip()
            if not q:
                continue

            ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
            context_docs = [docstore[i]["text"] for i in ids]
            context_blob = "\n\n".join(
                [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
            )

            messages = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
            ]

            # Use chat template from your folder (you have chat_template.jinja).
            # apply_chat_template may return either a Tensor or a BatchEncoding
            # depending on tokenizer/transformers version, so handle both.
            enc = tok.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
            )
            if isinstance(enc, torch.Tensor):
                input_ids = enc.to(model.device)
                attention_mask = torch.ones_like(input_ids, device=model.device)
            else:
                input_ids = enc["input_ids"].to(model.device)
                attention_mask = enc.get("attention_mask")
                if attention_mask is None:
                    attention_mask = torch.ones_like(input_ids)
                attention_mask = attention_mask.to(model.device)

            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=tok.eos_token_id,
                pad_token_id=tok.pad_token_id,
            )
            # Decode only the newly generated tokens, not the prompt.
            ans = tok.decode(
                out[0][input_ids.shape[1]:], skip_special_tokens=True
            ).strip()
            print(f"\nBaliTwin: {ans}")
    except (KeyboardInterrupt, EOFError):
        # Clean exit on Ctrl+C / Ctrl+D instead of a traceback.
        print("\nBye!")


if __name__ == "__main__":
    main()