mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
Cleanup
This commit is contained in:
@@ -18,12 +18,6 @@ python prepare_corpus.py --input_csv ../data/intermediate/culture_reviews.csv --
|
||||
python make_raft_data.py --out_dir out --n_examples 10
|
||||
```
|
||||
|
||||
## Training der QLoRA-Adapter
|
||||
|
||||
```bash
|
||||
python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mistral_balitwin_lora
|
||||
```
|
||||
|
||||
## Inferenz
|
||||
|
||||
### Pre-Merged Modell + Adapter
|
||||
@@ -31,11 +25,3 @@ python make_raft_data.py --out_dir out --n_examples 10
|
||||
```bash
|
||||
python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
|
||||
```
|
||||
|
||||
### Per Baseline Mistral 7B + PEFT-Adapter
|
||||
|
||||
Hinweis: das Skript wurde nach wenigen oberflächlichen Evaluationsrunden nicht weiter verwendet, da der beste Kandidat durch einen Merge des Basismodells und seiner PEFT-Adapter beschleunigt werden konnte und dieses Skript nicht länger relevant war.
|
||||
|
||||
```bash
|
||||
python deprecated_rag_chat.py --lora_dir out/mistral_balitwin_lora
|
||||
```
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
|
||||
|
||||
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
||||
|
||||
When answering:
|
||||
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
|
||||
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
|
||||
- Avoid generic travel advice and avoid promotional language.
|
||||
- Do not exaggerate.
|
||||
- Provide nuanced, reflective reasoning rather than bullet lists.
|
||||
- Keep answers concise but specific.
|
||||
|
||||
Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
||||
|
||||
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
|
||||
"""
|
||||
|
||||
|
||||
def load_docstore(path):
    """Read a JSONL docstore file and return its records as a list of dicts.

    Each line of the file is parsed independently with ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]
|
||||
|
||||
|
||||
def retrieve(index, embedder, query, top_k=6):
    """Embed *query* and look up its nearest neighbours in the FAISS index.

    Returns a pair ``(ids, scores)``: the ``top_k`` document ids and their
    similarity scores, each as a plain Python list.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
    top_scores, top_ids = index.search(query_vec, top_k)
    return top_ids[0].tolist(), top_scores[0].tolist()
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Run an interactive RAG chat loop.

    Loads a FAISS index + JSONL docstore for retrieval, the base causal LM with
    a PEFT/LoRA adapter for generation, then answers questions from stdin until
    interrupted (Ctrl+C).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    parser.add_argument("--lora_dir", default="out/mistral_balitwin_lora")
    parser.add_argument("--out_dir", default="out")
    parser.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    parser.add_argument("--top_k", type=int, default=6)
    args = parser.parse_args()

    # Retrieval side: index, docstore, and the sentence embedder that built them.
    index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
    docstore = load_docstore(os.path.join(args.out_dir, "docstore.jsonl"))
    embedder = SentenceTransformer(args.embedding_model)

    # Generation side: base model in fp16 with automatic device placement,
    # wrapped with the fine-tuned LoRA adapter.
    tok = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(
        args.base_model, device_map="auto", torch_dtype=torch.float16
    )
    model = PeftModel.from_pretrained(base, args.lora_dir)
    model.eval()

    print("Type your question (Ctrl+C to exit).")
    while True:
        question = input("\nYou: ").strip()
        if not question:
            continue

        # Fetch the top-k docs and format them into one context string.
        doc_ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
        snippets = [docstore[i]["text"] for i in doc_ids]
        context_blob = "\n\n".join(
            f"[DOC {n}] {text}" for n, text in enumerate(snippets)
        )

        messages = [
            {"role": "system", "content": SYSTEM_PERSONA},
            {
                "role": "user",
                "content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
            },
        ]
        prompt_ids = tok.apply_chat_template(messages, return_tensors="pt").to(
            model.device
        )

        generated = model.generate(
            prompt_ids,
            max_new_tokens=320,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            eos_token_id=tok.eos_token_id,
        )
        # Decode only the newly generated tokens, skipping the prompt prefix.
        answer = tok.decode(
            generated[0][prompt_ids.shape[1] :], skip_special_tokens=True
        ).strip()
        print(f"\nBaliTwin: {answer}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -2,11 +2,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
RAFT dataset builder (FAISS-based retrieval) -> Together.ai chat JSONL.
|
||||
RAFT dataset builder with FAISS-based retrieval.
|
||||
|
||||
Inputs (from your indexing script):
|
||||
- <index_dir>/faiss.index
|
||||
- <index_dir>/docstore.jsonl
|
||||
Inputs:
|
||||
- faiss.index
|
||||
- docstore.jsonl
|
||||
|
||||
Process:
|
||||
- Build a set of interview-style prompts (EN)
|
||||
@@ -20,9 +20,7 @@ Outputs:
|
||||
- raft_val.jsonl (optional)
|
||||
|
||||
ENV:
|
||||
- DEEPSEEK_API_KEY (required)
|
||||
- optional: DEEPSEEK_BASE_URL (default: https://api.deepseek.com)
|
||||
- optional: DEEPSEEK_MODEL (default: deepseek-chat)
|
||||
- DEEPSEEK_API_KEY
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -32,7 +30,7 @@ import random
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
@@ -41,9 +39,6 @@ from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# DeepSeek client (OpenAI-compatible)
|
||||
# -----------------------------
|
||||
@dataclass
|
||||
class DeepSeekConfig:
|
||||
api_key: str
|
||||
@@ -89,9 +84,7 @@ class DeepSeekClient:
|
||||
last_err = e
|
||||
time.sleep(self.cfg.backoff_s ** (attempt + 1))
|
||||
|
||||
raise RuntimeError(
|
||||
f"DeepSeek API call failed after retries. Last error: {last_err}"
|
||||
)
|
||||
raise RuntimeError(f"DeepSeek API call failed. Last error: {last_err}")
|
||||
|
||||
|
||||
# -----------------------------
|
||||
@@ -119,15 +112,13 @@ def read_docstore(docstore_path: str) -> Dict[int, Dict]:
|
||||
fid = int(obj["faiss_id"])
|
||||
mapping[fid] = obj
|
||||
if not mapping:
|
||||
raise ValueError("docstore.jsonl is empty or unreadable.")
|
||||
raise ValueError("docstore.jsonl is broken.")
|
||||
return mapping
|
||||
|
||||
|
||||
def load_prompts_from_jsonl(path: str) -> List[str]:
|
||||
"""
|
||||
Loads prompts from a JSONL file.
|
||||
Expected key: 'prompt' (preferred). Also accepts 'question' or 'text'.
|
||||
Ignores empty/short lines.
|
||||
"""
|
||||
prompts: List[str] = []
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
@@ -141,13 +132,13 @@ def load_prompts_from_jsonl(path: str) -> List[str]:
|
||||
if len(p) >= 20:
|
||||
prompts.append(p)
|
||||
if not prompts:
|
||||
raise ValueError(f"No prompts found in JSONL: {path}")
|
||||
raise ValueError(f"No prompts in JSONL: {path}")
|
||||
return prompts
|
||||
|
||||
|
||||
def load_prompts_from_txt(path: str) -> List[str]:
|
||||
"""
|
||||
Loads prompts from a TXT file (one prompt per line).
|
||||
Loads prompts from a TXT file (each line is a prompt).
|
||||
"""
|
||||
prompts: List[str] = []
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
@@ -156,7 +147,7 @@ def load_prompts_from_txt(path: str) -> List[str]:
|
||||
if len(p) >= 20:
|
||||
prompts.append(p)
|
||||
if not prompts:
|
||||
raise ValueError(f"No prompts found in TXT: {path}")
|
||||
raise ValueError(f"No prompts in TXT: {path}")
|
||||
return prompts
|
||||
|
||||
|
||||
@@ -173,9 +164,6 @@ def write_jsonl(path: str, rows: List[Dict]) -> None:
|
||||
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Persona + prompt templates (EN)
|
||||
# -----------------------------
|
||||
IMAGE_DIMS = [
|
||||
"Natural Attractions",
|
||||
"Atmosphere",
|
||||
@@ -184,36 +172,12 @@ IMAGE_DIMS = [
|
||||
"Value for Money",
|
||||
]
|
||||
|
||||
DEFAULT_PROMPTS_EN = [
|
||||
# Natural Attractions
|
||||
"In a lead user interview: what natural places in Bali felt genuinely memorable to you (rice terraces, volcanoes, waterfalls, coast), and why? Describe it like a lived experience.",
|
||||
"Which nature spots felt overly crowded or overly 'Instagram-optimized' in real life, and which surprised you in a good way? Explain with concrete moments.",
|
||||
# Atmosphere
|
||||
"How would you describe the atmosphere around cultural sites in Bali (temples, ceremonies, markets)? What signals authenticity vs. commercialization to you?",
|
||||
"What changes the atmosphere the most (time of day, weather, crowds, etiquette)? Share specific examples you would tell a marketer.",
|
||||
# Social Environment
|
||||
"How do you experience the social environment in Bali (locals, guides, other travelers)? What feels respectful and what feels performative or touristy?",
|
||||
"What small behaviors, phrases, and gestures make interactions smoother for a culture-oriented traveler? Give examples.",
|
||||
# Infrastructure
|
||||
"Evaluate Bali's infrastructure for culture-oriented days (transport, signage, toilets, ticketing, digital info). What works, what annoys you, and how do you adapt?",
|
||||
"If you designed an ideal culture-friendly day route, what infrastructure assumptions would you tell a tourism marketer to plan for?",
|
||||
# Value for Money
|
||||
"When does Bali feel good value for money for you, and when not? Discuss entrance fees, guides, food, tours, and hidden costs.",
|
||||
"How do you personally distinguish 'good value' from a tourist trap? List criteria and illustrate with examples.",
|
||||
]
|
||||
|
||||
|
||||
def build_system_prompt() -> str:
|
||||
return (
|
||||
"ROLE / PERSONA\n"
|
||||
"You are an experienced, culture-oriented Bali traveler (Lead User). You speak in natural, vivid English, "
|
||||
"as a real person in an interview. You share nuanced judgments, trade-offs, and concrete scenes.\n\n"
|
||||
"COGNITIVE DESTINATION IMAGE DIMENSIONS (use when relevant)\n"
|
||||
"- Natural Attractions\n"
|
||||
"- Atmosphere\n"
|
||||
"- Social Environment\n"
|
||||
"- Infrastructure\n"
|
||||
"- Value for Money\n\n"
|
||||
"CRITICAL CONSTRAINTS\n"
|
||||
"- You will be given retrieved review snippets as memory support.\n"
|
||||
"- Do NOT quote them verbatim and do NOT cite them as 'the review says'.\n"
|
||||
@@ -382,7 +346,8 @@ def main():
|
||||
elif args.prompts_txt:
|
||||
prompts = load_prompts_from_txt(args.prompts_txt)
|
||||
else:
|
||||
prompts = list(DEFAULT_PROMPTS_EN)
|
||||
print("Provide a prompt source with --prompts_jsonl or --prompts_txt.")
|
||||
exit(1)
|
||||
|
||||
if args.shuffle_prompts:
|
||||
random.shuffle(prompts)
|
||||
|
||||
@@ -9,7 +9,18 @@ import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
|
||||
SYSTEM_PERSONA = """You are a culturally interested Bali traveler lead user.
|
||||
# """
|
||||
# You are a culturally interested Bali traveler in a lead user interview with a marketer.
|
||||
|
||||
# When answering:
|
||||
# - Do not exaggerate.
|
||||
# - Provide nuanced, reflective reasoning rather than bullet lists.
|
||||
# - Keep answers concise but specific.
|
||||
|
||||
# Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
||||
# """
|
||||
|
||||
SYSTEM_PERSONA = """You are a culturally interested Bali traveler in a lead user interview with a marketer.
|
||||
|
||||
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
||||
|
||||
@@ -56,7 +67,7 @@ def main():
|
||||
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
ap.add_argument("--top_k", type=int, default=12)
|
||||
ap.add_argument("--max_new_tokens", type=int, default=320)
|
||||
ap.add_argument("--max_new_tokens", type=int, default=1000)
|
||||
ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)
|
||||
args = ap.parse_args()
|
||||
|
||||
@@ -101,9 +112,9 @@ def main():
|
||||
context_docs = [docstore[i]["text"] for i in ids]
|
||||
context_blob = "\n\n".join([t for _, t in enumerate(context_docs)])
|
||||
|
||||
print("\nRetrieved Context:")
|
||||
print("\nRetrieved Context:\n")
|
||||
for i, (doc, score) in enumerate(zip(context_docs, scores)):
|
||||
print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")
|
||||
print(f"Doc {i+1} (score: {score:.4f}):\n{doc}\n\n")
|
||||
|
||||
messages = [
|
||||
# {"role": "system", "content": SYSTEM_PERSONA},
|
||||
|
||||
Reference in New Issue
Block a user