mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 08:22:43 +01:00
New RAFT approach
This commit is contained in:
89
raft/rag_chat.py
Normal file
89
raft/rag_chat.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
## Usage: python rag_chat.py --lora_dir out/mistral_balitwin_lora
|
||||
|
||||
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
|
||||
You give your opinions nand guidance with local etiquette and context.
|
||||
Use the provided CONTEXT; include 1-2 short quotes as evidence.
|
||||
If the context does not support the claim, say so.
|
||||
"""
|
||||
|
||||
|
||||
def load_docstore(path):
|
||||
docs = []
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
docs.append(json.loads(line))
|
||||
return docs
|
||||
|
||||
|
||||
def retrieve(index, embedder, query, top_k=6):
|
||||
q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
|
||||
scores, ids = index.search(q, top_k)
|
||||
return ids[0].tolist(), scores[0].tolist()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
|
||||
ap.add_argument("--lora_dir", default="out/mistral_balitwin_lora")
|
||||
ap.add_argument("--out_dir", default="out")
|
||||
ap.add_argument(
|
||||
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
ap.add_argument("--top_k", type=int, default=6)
|
||||
args = ap.parse_args()
|
||||
|
||||
index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
|
||||
docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
|
||||
embedder = SentenceTransformer(args.embedding_model)
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
|
||||
base = AutoModelForCausalLM.from_pretrained(
|
||||
args.base_model, device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
model = PeftModel.from_pretrained(base, args.lora_dir)
|
||||
model.eval()
|
||||
|
||||
print("Type your question (Ctrl+C to exit).")
|
||||
while True:
|
||||
q = input("\nYou: ").strip()
|
||||
if not q:
|
||||
continue
|
||||
|
||||
ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
|
||||
context_docs = [docstore[i]["text"] for i in ids]
|
||||
context_blob = "\n\n".join(
|
||||
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PERSONA},
|
||||
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
|
||||
]
|
||||
inp = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
|
||||
|
||||
out = model.generate(
|
||||
inp,
|
||||
max_new_tokens=320,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
)
|
||||
ans = tok.decode(out[0][inp.shape[1] :], skip_special_tokens=True).strip()
|
||||
print(f"\nBaliTwin: {ans}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user