mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
128 lines
4.4 KiB
Python
128 lines
4.4 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
|
|
import faiss
|
|
import numpy as np
|
|
import torch
|
|
from sentence_transformers import SentenceTransformer
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
|
|
|
|
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
|
|
|
When answering:
|
|
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
|
|
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
|
|
- Avoid generic travel advice and avoid promotional language.
|
|
- Do not exaggerate.
|
|
- Provide nuanced, reflective reasoning rather than bullet lists.
|
|
- Keep answers concise but specific.
|
|
|
|
Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
|
|
|
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
|
|
"""
|
|
|
|
|
|
def load_docstore(path):
    """Load a JSONL docstore from *path*.

    Each line of the file is an independent JSON object (one retrievable
    document); the returned list preserves file order, so list index i
    corresponds to FAISS id i.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return [json.loads(raw_line) for raw_line in fh]
|
|
|
|
|
|
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and search *index*, returning ``(ids, scores)``.

    The query embedding is L2-normalized, which assumes the index was
    built from normalized vectors (cosine / inner-product search).
    Both return values are plain Python lists of length ``top_k``.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True)
    query_vec = query_vec.astype(np.float32)
    hit_scores, hit_ids = index.search(query_vec, top_k)
    # search() returns batched results; we queried a single string.
    return hit_ids[0].tolist(), hit_scores[0].tolist()
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Interactive RAG chat loop.

    For each typed question: retrieve the top-k documents from the FAISS
    index, build a persona-conditioned chat prompt, and generate an answer
    with the finetuned model. Runs until Ctrl+C / Ctrl+D or EOF on stdin.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--model_dir", required=True, help="Path to your finetuned model folder"
    )
    ap.add_argument(
        "--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
    )
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--max_new_tokens", type=int, default=320)
    args = ap.parse_args()

    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    embedder = SentenceTransformer(args.embedding_model)

    # Load the externally finetuned model directly from disk.
    tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)

    # Important: ensure pad token exists for generation; Mistral often uses eos as pad.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    model.eval()

    print("Type your question (Ctrl+C to exit).")
    while True:
        # BUG FIX: the banner promises Ctrl+C exits, but the original let
        # KeyboardInterrupt (and EOFError on Ctrl+D) crash with a traceback.
        try:
            q = input("\nYou: ").strip()
        except (KeyboardInterrupt, EOFError):
            print()
            break
        if not q:
            continue

        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
        # BUG FIX: FAISS pads with id -1 when fewer than top_k hits exist;
        # docstore[-1] would silently pull in the *last* document. Skip them.
        context_docs = [docstore[i]["text"] for i in ids if i >= 0]
        context_blob = "\n\n".join(
            f"[DOC {i}] {t}" for i, t in enumerate(context_docs)
        )

        messages = [
            {"role": "system", "content": SYSTEM_PERSONA},
            {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
        ]

        # Use the chat template shipped with the model folder
        # (chat_template.jinja). BUG FIX: the original called
        # apply_chat_template twice and discarded the first result (`inp`),
        # tokenizing every prompt twice for nothing.
        enc = tok.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )

        # apply_chat_template may return a bare tensor or a dict-like
        # BatchEncoding depending on the transformers version / template.
        if isinstance(enc, torch.Tensor):
            input_ids = enc.to(model.device)
            attention_mask = torch.ones_like(input_ids, device=model.device)
        else:
            input_ids = enc["input_ids"].to(model.device)
            attention_mask = enc.get("attention_mask")
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            attention_mask = attention_mask.to(model.device)

        out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.pad_token_id,
        )

        # Decode only the newly generated tokens (strip the echoed prompt).
        ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
        print(f"\nBaliTwin: {ans}")
|
|
|
|
|
|
# Script entry point: only run the interactive loop when executed directly,
# not when imported for its retrieval helpers.
if __name__ == "__main__":
    main()
|