RAFT updates, BERTopic config, cleanup

This commit is contained in:
2026-02-21 01:57:14 +01:00
parent 8cadcb1f69
commit 1a99b53d44
12 changed files with 10750 additions and 9778 deletions

View File

@@ -1,14 +1,15 @@
import argparse
import json
import os
from threading import Thread
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
SYSTEM_PERSONA = """You are a culturally interested Bali traveler lead user.
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
@@ -22,7 +23,9 @@ When answering:
Respond as if you are describing your genuine experience and judgment as this type of traveler.
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
Use the provided CONTEXT to inform your answer, but do not feel obligated to use all of it. If the CONTEXT is not relevant to the question, you can ignore it.
NEVER directly quote the CONTEXT verbatim.
NEVER mention DOC or any context sources you are referring to. Instead, use it to synthesize your own understanding and response.
"""
@@ -86,14 +89,14 @@ def main():
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
)
print("\nRetrieved Context:")
print(context_blob)
messages = [
{"role": "system", "content": SYSTEM_PERSONA},
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
]
# Use chat template from your folder (you have chat_template.jinja)
inp = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
enc = tok.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
@@ -108,7 +111,11 @@ def main():
attention_mask = torch.ones_like(input_ids)
attention_mask = attention_mask.to(model.device)
out = model.generate(
streamer = TextIteratorStreamer(
tok, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
generation_kwargs = dict(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=args.max_new_tokens,
@@ -117,10 +124,17 @@ def main():
top_p=0.9,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
streamer=streamer,
)
ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
print(f"\nBaliTwin: {ans}")
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
print("\nBaliTwin: ", end="", flush=True)
for token in streamer:
print(token, end="", flush=True)
print("")
thread.join()
if __name__ == "__main__":