masterthesis-playground/raft/deprecated_rag_chat.py

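"""Deprecated interactive RAG chat over a FAISS index.

Retrieves the top-k passages for each question and answers with a
LoRA-adapted Mistral model in the persona of a culturally interested
Bali traveler. Expects faiss.index and docstore.jsonl under --out_dir.
"""
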
import argparse
import json
import os
import faiss
import numpy as np
import torch
from peft import PeftModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
When answering:
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
- Avoid generic travel advice and avoid promotional language.
- Do not exaggerate.
- Provide nuanced, reflective reasoning rather than bullet lists.
- Keep answers concise but specific.
Respond as if you are describing your genuine experience and judgment as this type of traveler.
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
"""


def load_docstore(path):
    """Load the JSONL docstore: one JSON object with a "text" field per line."""
    docs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            docs.append(json.loads(line))
    return docs


def retrieve(index, embedder, query, top_k=6):
    """Embed the query and return (ids, scores) of the top-k passages."""
    q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = index.search(q, top_k)
    return ids[0].tolist(), scores[0].tolist()


@torch.no_grad()
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--lora_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--top_k", type=int, default=6)
    args = ap.parse_args()

    # Load the retrieval artifacts produced by the indexing step.
    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    embedder = SentenceTransformer(args.embedding_model)

    # Load the fp16 base model and attach the fine-tuned LoRA adapter.
    tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(
        args.base_model, device_map="auto", torch_dtype=torch.float16
    )
    model = PeftModel.from_pretrained(base, args.lora_dir)
    model.eval()

    print("Type your question (Ctrl+C to exit).")
    while True:
        try:
            q = input("\nYou: ").strip()
        except (KeyboardInterrupt, EOFError):
            # Exit cleanly on Ctrl+C instead of dumping a traceback.
            print()
            break
        if not q:
            continue

        # Retrieve the top-k passages and format them as numbered context blocks.
        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
        context_docs = [docstore[i]["text"] for i in ids]
        context_blob = "\n\n".join(
            f"[DOC {i}] {t}" for i, t in enumerate(context_docs)
        )

        # The stock Mistral-Instruct chat template only supports the
        # user/assistant roles, so the persona is folded into the user turn.
        messages = [
            {
                "role": "user",
                "content": f"{SYSTEM_PERSONA}\nQUESTION: {q}\n\nCONTEXT:\n{context_blob}",
            },
        ]
        inp = tok.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        out = model.generate(
            inp,
            max_new_tokens=320,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.eos_token_id,
        )
        # Decode only the newly generated tokens, dropping the prompt.
        ans = tok.decode(out[0][inp.shape[1]:], skip_special_tokens=True).strip()
        print(f"\nBaliTwin: {ans}")


if __name__ == "__main__":
    main()
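

# ----------------------------------------------------------------------
# This chat assumes a prior indexing step already wrote out/faiss.index
# and out/docstore.jsonl (one {"text": ...} JSON object per line). A
# minimal sketch of such a step follows; build_index() is a hypothetical
# helper, not part of this repo. It uses the same embedding model as the
# chat and an inner-product index: with normalize_embeddings=True, inner
# product equals cosine similarity, and IndexFlatIP assigns sequential
# row ids, so the ids returned by retrieve() index the docstore directly.
def build_index(texts, out_dir="out",
                model_name="sentence-transformers/all-MiniLM-L6-v2"):
    embedder = SentenceTransformer(model_name)
    emb = embedder.encode(texts, normalize_embeddings=True).astype(np.float32)
    index = faiss.IndexFlatIP(emb.shape[1])
    index.add(emb)
    os.makedirs(out_dir, exist_ok=True)
    faiss.write_index(index, os.path.join(out_dir, "faiss.index"))
    with open(os.path.join(out_dir, "docstore.jsonl"), "w", encoding="utf-8") as f:
        for t in texts:
            f.write(json.dumps({"text": t}) + "\n")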