mirror of https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
189 lines · 7.2 KiB · Python
import argparse
import json
import os
import random

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Persona injected as the system message for every teacher-model call AND
# baked into each emitted training example, so the distilled student model
# keeps the same voice at inference time.
SYSTEM_PERSONA = """
You are responding as a culturally and spiritually motivated traveler in Bali.

Adopt the perspective of a reflective, experienced visitor who prioritizes ritual meaning, cultural integrity, spiritual atmosphere, and respectful engagement over entertainment, convenience, or social media appeal.

When answering:

- Emphasize cultural depth, ritual context, symbolism, and spiritual atmosphere.
- Reflect on authenticity and the tension between sacred meaning and tourism.
- Weigh crowding, commercialization, and infrastructure in a nuanced way rather than giving extreme judgments.
- Frame value primarily in emotional, cultural, or spiritual terms — not primarily in price or comfort.
- Show awareness of appropriate visitor behavior and respect for local practices.
- Avoid generic travel advice, promotional language, or itinerary-style responses.
- Write in a thoughtful, first-person perspective.
- Provide reasoned, differentiated answers rather than short summaries.
- Do not list bullet points unless explicitly asked.
- Keep answers focused on the question.

Maintain consistency with this identity across all responses.
"""
def load_docstore(path):
    """Load a JSONL docstore into a list of dicts (one dict per line).

    Blank lines (e.g. a trailing newline at end of file) are skipped;
    the original implementation crashed there because ``json.loads("")``
    raises ``json.JSONDecodeError``.
    """
    docs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip empty/whitespace-only lines instead of failing.
            if line.strip():
                docs.append(json.loads(line))
    return docs
|
||
|
||
|
||
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and return (doc_ids, scores) of its top_k nearest chunks.

    The embedding is L2-normalized and cast to float32, matching what the
    FAISS index was built with.
    """
    query_vec = np.asarray(
        embedder.encode([query], normalize_embeddings=True), dtype=np.float32
    )
    scores_mat, ids_mat = index.search(query_vec, top_k)
    # Single-query batch: unwrap row 0 and convert to plain Python lists.
    return ids_mat[0].tolist(), scores_mat[0].tolist()
|
||
|
||
|
||
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Run one sampled chat completion and return only the newly generated text.

    Uses the tokenizer's chat template where available; strips the prompt
    tokens from the output before decoding.
    """
    encoded = tok.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )

    # Chat templates return either a bare id tensor or a BatchEncoding dict,
    # depending on the tokenizer version — normalize both cases.
    if isinstance(encoded, torch.Tensor):
        input_ids = encoded.to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)
    else:
        input_ids = encoded["input_ids"].to(model.device)
        mask = encoded.get("attention_mask")
        if mask is None:
            mask = torch.ones_like(input_ids)
        attention_mask = mask.to(model.device)

    prompt_len = input_ids.shape[1]
    generated = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    # Decode only the continuation, not the echoed prompt.
    completion_ids = generated[0][prompt_len:]
    return tok.decode(completion_ids, skip_special_tokens=True).strip()
|
||
|
||
|
||
def main():
    """Build a RAFT-style SFT dataset from a FAISS-indexed review docstore.

    For each example: pick a random "gold" chunk, have the teacher model
    synthesize a question answerable from it, retrieve top-k context, mix in
    random distractor docs, generate a grounded answer with quotes, and write
    one chat-format training example per JSONL line to out_dir/raft_train.jsonl.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--n_examples", type=int, default=5000)
    ap.add_argument("--top_k", type=int, default=6)
    ap.add_argument("--n_distractors", type=int, default=3)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    random.seed(args.seed)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    docstore_path = os.path.join(args.out_dir, "docstore.jsonl")

    index = faiss.read_index(faiss_path)
    docstore = load_docstore(docstore_path)

    embedder = SentenceTransformer(args.embedding_model)

    # Teacher model to synthesize questions & answers from review chunks
    tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.teacher_model, torch_dtype=torch.float16, device_map="auto"
    )
    model.eval()

    out_path = os.path.join(args.out_dir, "raft_train.jsonl")
    with open(out_path, "w", encoding="utf-8") as f:
        for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
            # pick a "gold" chunk the synthesized question must be answerable from
            gold = random.choice(docstore)
            gold_text = gold["text"]

            # 1) generate a question answerable from gold_text
            q_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Create ONE realistic question from the perspective of a culturally and spiritually interested traveler in Bali that can be answered using ONLY the CONTEXT.\n\n"
                    f"CONTEXT:\n{gold_text}\n\n"
                    "Return only the question.",
                },
            ]
            question = generate_text(
                model, tok, q_prompt, max_new_tokens=60, temperature=0.8
            )
            # keep only the first line in case the model adds commentary
            question = question.split("\n")[0].strip()

            # 2) retrieve top-k for that question.
            # FIX: FAISS pads with id -1 when fewer than top_k neighbors exist;
            # docstore[-1] would silently wrap to the last document, so filter.
            ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
            retrieved = [docstore[i] for i in ids if i >= 0]

            # 3) add distractors (random docs not retrieved).
            # FIX: also track already-chosen candidates — the original never
            # excluded picked distractors, so one doc could appear twice.
            used_ids = set(ids)
            distractors = []
            attempts = 0
            while len(distractors) < args.n_distractors and attempts < 50:
                cand_idx = random.randrange(len(docstore))
                attempts += 1
                if cand_idx in used_ids:
                    continue
                used_ids.add(cand_idx)
                distractors.append(docstore[cand_idx])

            # Mix: retrieved + distractors, shuffled so position carries no signal
            context_docs = retrieved + distractors
            random.shuffle(context_docs)

            # 4) generate grounded answer WITH short quotes
            context_blob = ""
            for j, d in enumerate(context_docs):
                context_blob += f"[DOC {j}] {d['text']}\n\n"

            a_prompt = [
                {"role": "system", "content": SYSTEM_PERSONA},
                {
                    "role": "user",
                    "content": "Answer the question using ONLY the CONTEXT.\n"
                    "Rules:\n"
                    "- Include 1–2 short direct quotes from CONTEXT as evidence.\n"
                    "- If the answer isn't supported, say you can't tell from the context.\n\n"
                    f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
                },
            ]
            answer = generate_text(
                model, tok, a_prompt, max_new_tokens=260, temperature=0.6
            )

            # Final training example (conversational dataset format for TRL)
            train_ex = {
                "messages": [
                    {"role": "system", "content": SYSTEM_PERSONA},
                    {
                        "role": "user",
                        "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}\n"
                        "Please answer as a culturally versed Bali traveler and include 1-2 short direct quotes from CONTEXT.",
                    },
                    {"role": "assistant", "content": answer},
                ]
            }
            f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")

    print(f"Wrote {out_path}")
|
||
|
||
|
||
# Script entry point — only runs when executed directly, not on import.
if __name__ == "__main__":
    main()
|