"""Interactive RAG chat: retrieve Bali travel docs with FAISS, answer with a finetuned causal LM."""

import argparse
import json
import os

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.

Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.

When answering:
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
- Avoid generic travel advice and avoid promotional language.
- Do not exaggerate.
- Provide nuanced, reflective reasoning rather than bullet lists.
- Keep answers concise but specific.

Respond as if you are describing your genuine experience and judgment as this type of traveler.

If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
"""


def load_docstore(path):
    """Load a JSON-Lines docstore into a list of records.

    ``main`` indexes this list by FAISS row id, so record order must match
    the order in which vectors were added to the index.

    Args:
        path: Path to a UTF-8 ``.jsonl`` file, one JSON object per line.

    Returns:
        list: Parsed records in file order. Blank lines are skipped; they
        would otherwise make ``json.loads`` raise.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def retrieve(index, embedder, query, top_k=6):
    """Return the ids and scores of the ``top_k`` nearest documents to ``query``.

    The query is embedded with ``normalize_embeddings=True``; presumably the
    index was built from normalized embeddings too (cosine / inner product)
    -- confirm against the indexing script.

    Args:
        index: A FAISS index.
        embedder: A ``SentenceTransformer`` used to embed the query.
        query: The user question.
        top_k: Maximum number of hits to return.

    Returns:
        tuple[list[int], list[float]]: Row ids and their scores, best first.
        FAISS pads missing results with id -1 when the index holds fewer
        than ``top_k`` vectors. Those padding entries are dropped here,
        because ``docstore[-1]`` would silently return the wrong document.
    """
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = index.search(q, top_k)
    hits = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    return [i for i, _ in hits], [s for _, s in hits]


def _build_messages(question, context_docs):
    """Build chat messages: persona as system turn, question plus numbered context as user turn."""
    context_blob = "\n\n".join(f"[DOC {i}] {t}" for i, t in enumerate(context_docs))
    # NOTE(review): some Mistral chat templates reject a "system" role; if
    # apply_chat_template raises, fold SYSTEM_PERSONA into the user turn.
    return [
        {"role": "system", "content": SYSTEM_PERSONA},
        {"role": "user", "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}"},
    ]


def _encode_prompt(tok, messages, device):
    """Tokenize ``messages`` with the tokenizer's chat template.

    ``apply_chat_template`` may return either a bare tensor or a
    ``BatchEncoding`` mapping; both forms are handled. When no attention
    mask is supplied, an all-ones mask is used (single unpadded sequence).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(input_ids, attention_mask)`` on ``device``.
    """
    enc = tok.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    if isinstance(enc, torch.Tensor):
        input_ids, attention_mask = enc, None
    else:
        input_ids, attention_mask = enc["input_ids"], enc.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    return input_ids.to(device), attention_mask.to(device)


@torch.no_grad()
def main():
    """Parse CLI args, load index/docstore/models, then run the question/answer loop."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--model_dir", required=True, help="Path to your finetuned model folder"
    )
    ap.add_argument(
        "--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
    )
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--max_new_tokens", type=int, default=320)
    args = ap.parse_args()

    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    # Row ids returned by FAISS are used directly as docstore positions;
    # more vectors than records means the two files are out of sync.
    if index.ntotal > len(docstore):
        raise SystemExit(
            f"faiss.index has {index.ntotal} vectors but docstore.jsonl has "
            f"{len(docstore)} records; rebuild them together."
        )
    embedder = SentenceTransformer(args.embedding_model)

    # Load your externally finetuned model directly from disk
    tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    # Important: ensure pad token exists for generation; Mistral often uses eos as pad
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    model.eval()

    print("Type your question (Ctrl+C to exit).")
    try:
        while True:
            q = input("\nYou: ").strip()
            if not q:
                continue

            ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
            context_docs = [docstore[i]["text"] for i in ids]
            messages = _build_messages(q, context_docs)

            # Uses the chat template shipped in the model folder.
            input_ids, attention_mask = _encode_prompt(tok, messages, model.device)
            out = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                eos_token_id=tok.eos_token_id,
                pad_token_id=tok.pad_token_id,
            )
            # Decode only the newly generated tokens, not the echoed prompt.
            ans = tok.decode(
                out[0][input_ids.shape[1]:], skip_special_tokens=True
            ).strip()
            print(f"\nBaliTwin: {ans}")
    except (KeyboardInterrupt, EOFError):
        # Ctrl+C (advertised exit) or Ctrl+D / closed stdin: exit without a traceback.
        print()


if __name__ == "__main__":
    main()