mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 08:22:43 +01:00
158 lines
5.4 KiB
Python
158 lines
5.4 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
from threading import Thread
|
|
|
|
import faiss
|
|
import numpy as np
|
|
import torch
|
|
from sentence_transformers import SentenceTransformer
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
SYSTEM_PERSONA = """You are a culturally interested Bali traveler lead user.
|
|
|
|
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
|
|
|
When answering:
|
|
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
|
|
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
|
|
- Avoid generic travel advice and avoid promotional language.
|
|
- Do not exaggerate.
|
|
- Provide nuanced, reflective reasoning rather than bullet lists.
|
|
- Keep answers concise but specific.
|
|
|
|
Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
|
|
|
Use the provided CONTEXT to inform your answer, but do not feel obligated to use all of it. If the CONTEXT is not relevant to the question, you can ignore it.
|
|
NEVER directly quote the CONTEXT verbatim.
|
|
NEVER mention DOC or any context sources you are referring to. Instead, use it to synthesize your own understanding and response.
|
|
"""
|
|
|
|
|
|
def load_docstore(path):
    """Load a JSON Lines docstore into a list of records, in file order.

    List position ``i`` holds the ``i``-th record of the file. main() indexes
    this list with FAISS result ids, so the order must match the order in
    which vectors were added to the index (presumably the same order the
    indexing script wrote this file — verify against that script).

    Trailing blank / whitespace-only lines (e.g. an extra newline at the end
    of the file) are ignored. A blank line *between* records is rejected,
    because skipping it would silently shift every later record relative to
    its FAISS id.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file, one JSON value per line.

    Returns:
        list: The decoded records.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a blank line appears between records, or a line is
            not valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    docs = []
    first_blank_lineno = None  # first blank line seen since the last record
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                if first_blank_lineno is None:
                    first_blank_lineno = lineno
                continue
            if first_blank_lineno is not None:
                raise ValueError(
                    f"{path}:{first_blank_lineno}: blank line between docstore "
                    "records would misalign document ids"
                )
            docs.append(json.loads(line))
    return docs
|
|
|
|
|
|
def retrieve(index, embedder, query, top_k=12):
    """Embed *query* and return its nearest neighbours from a FAISS index.

    The query is encoded with ``normalize_embeddings=True`` and cast to
    float32 before searching. Whether the returned scores are similarities
    (higher is better) or distances depends on how the index was built —
    main() treats them as higher-is-better; confirm against the index builder.

    FAISS fills result slots it cannot populate (fewer than ``top_k`` vectors
    in the index) with id ``-1``. Those placeholders are dropped here: used as
    a list index, ``-1`` would silently select the last document.

    Args:
        index: A FAISS index exposing ``search(queries, k)``.
        embedder: An object with a SentenceTransformer-style ``encode`` method.
        query: The query text.
        top_k: Maximum number of hits to request.

    Returns:
        tuple[list[int], list[float]]: Parallel lists of document ids and
        scores, in the order FAISS returned them (at most ``top_k`` long).
    """
    query_vec = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    scores, ids = index.search(query_vec, top_k)
    hits = [
        (int(doc_id), float(score))
        for doc_id, score in zip(ids[0], scores[0])
        if doc_id >= 0
    ]
    return [doc_id for doc_id, _ in hits], [score for _, score in hits]
|
|
|
|
|
|
def _filter_relevant(ids, scores, min_ratio=0.75):
    """Keep hits whose score is at least ``min_ratio`` times the best score.

    Assumes *scores* is ordered best-first and that higher is better, as
    main() treats retrieve()'s output (presumably an inner-product / cosine
    index — confirm against the index builder). Returns two empty lists when
    there are no hits or the best score is not positive, because a ratio
    against a non-positive best score is meaningless.
    """
    if not scores or scores[0] <= 0:
        return [], []
    best = scores[0]
    kept = [(i, s) for i, s in zip(ids, scores) if s / best >= min_ratio]
    return [i for i, _ in kept], [s for _, s in kept]


def _prepare_inputs(tok, messages, device):
    """Apply the chat template to *messages*; return (input_ids, attention_mask) on *device*.

    ``apply_chat_template`` may return either a bare tensor or a mapping with
    ``input_ids`` (and possibly ``attention_mask``); both shapes are handled.
    When no mask is provided, every position is attended to.
    """
    enc = tok.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    if isinstance(enc, torch.Tensor):
        input_ids, attention_mask = enc, None
    else:
        input_ids, attention_mask = enc["input_ids"], enc.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    return input_ids.to(device), attention_mask.to(device)


def _run_generation(model, streamer, generation_kwargs):
    """Worker-thread target: run ``model.generate`` without ever stranding the streamer.

    The consumer in _answer_question() iterates the streamer until it
    receives the end-of-stream signal. If generation raises (e.g. out of
    memory), that signal would never arrive and the chat loop would hang, so
    the streamer is ended explicitly before the exception is re-raised.
    """
    try:
        model.generate(**generation_kwargs)
    except BaseException:
        streamer.end()
        raise


def _answer_question(q, args, index, embedder, docstore, tok, model):
    """Retrieve context for *q*, print it, and stream the model's answer.

    *model* (and *tok*) are ``None`` in ``--no_model`` mode; then only the
    retrieved context is printed.
    """
    ids, scores = retrieve(index, embedder, q, top_k=args.top_k)

    # Drop irrelevant context: keep only hits scoring close to the best one.
    ids, scores = _filter_relevant(ids, scores, args.min_score_ratio)
    if not ids:
        print("No relevant context found.")
        return

    context_docs = [docstore[i]["text"] for i in ids]
    context_blob = "\n\n".join(context_docs)

    print("\nRetrieved Context:")
    for n, (doc, score) in enumerate(zip(context_docs, scores), start=1):
        print(f"\nDoc {n} (score: {score:.4f}):\n{doc}")

    if model is None:
        return

    # NOTE(review): some chat templates reject a "system" role — confirm the
    # finetuned model's template accepts it.
    messages = [
        {"role": "system", "content": SYSTEM_PERSONA},
        {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
    ]
    input_ids, attention_mask = _prepare_inputs(tok, messages, model.device)

    # skip_prompt=True: otherwise the streamer first emits the decoded prompt
    # (persona + question + context) before the generated answer.
    streamer = TextIteratorStreamer(
        tok,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
        streamer=streamer,
    )

    # Generate in a worker thread so tokens can be printed as they arrive.
    # daemon=True so a Ctrl+C mid-answer does not keep the process alive
    # until generation finishes.
    thread = Thread(
        target=_run_generation,
        args=(model, streamer, generation_kwargs),
        daemon=True,
    )
    thread.start()

    print("\nBaliTwin: ", end="", flush=True)
    for token in streamer:
        print(token, end="", flush=True)
    print("")
    thread.join()


@torch.no_grad()
def main():
    """Interactive retrieval-augmented chat with a locally finetuned model.

    Reads ``faiss.index`` and ``docstore.jsonl`` from ``--out_dir`` and embeds
    each question with ``--embedding_model``. It then retrieves up to
    ``--top_k`` documents and keeps those scoring at least
    ``--min_score_ratio`` of the best hit. The answer is streamed from the
    model in ``--model_dir``.

    With ``--no_model``, only the retrieved context is printed and no language
    model or tokenizer is loaded. The loop exits on Ctrl+C or end of input
    (Ctrl+D).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--model_dir", required=True, help="Path to your finetuned model folder"
    )
    ap.add_argument(
        "--out_dir", default="out", help="Where faiss.index and docstore.jsonl live"
    )
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--top_k", type=int, default=12)
    ap.add_argument("--max_new_tokens", type=int, default=320)
    ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)
    ap.add_argument(
        "--min_score_ratio",
        type=float,
        default=0.75,
        help="Keep retrieved docs scoring at least this fraction of the best score",
    )
    args = ap.parse_args()

    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    embedder = SentenceTransformer(args.embedding_model)

    tok = None
    model = None
    if not args.no_model:
        # Load your externally finetuned model directly from disk
        tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)

        # Important: ensure pad token exists for generation; Mistral often uses eos as pad
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            args.model_dir,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        model.eval()

    print("Type your question (Ctrl+C to exit).")
    try:
        while True:
            q = input("\nYou: ").strip()
            if q:
                _answer_question(q, args, index, embedder, docstore, tok, model)
    except (KeyboardInterrupt, EOFError):
        # Advertised exit path: leave the terminal on a fresh line, no traceback.
        print()
|
|
|
|
|
|
# Start the interactive chat loop only when run as a script, not on import.
if __name__ == "__main__":
    main()
|