mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
RAFT updates, BERTopic config, cleanup
This commit is contained in:
@@ -1,17 +1,21 @@
|
||||
# Retrieval-Augmented Finetuning (RAFT)
|
||||
|
||||
**Ablauf**:
|
||||
## Voraussetzungen
|
||||
|
||||
- Generelles Preprocessing (Voraussetzung für BERTopic)
|
||||
- BERTopic
|
||||
- Klassifikation muss durchgeführt sein, `data/intermediate/culture_reviews.csv` muss existieren
|
||||
|
||||
## Vorbereiten des Retrieval-Corpus
|
||||
|
||||
```bash
|
||||
python prepare_corpus.py --input_tab ../data/intermediate/selected_topics_documents.csv --out_dir out
|
||||
python prepare_corpus.py --input_tab ../data/intermediate/culture_reviews.csv --out_dir out
|
||||
```
|
||||
|
||||
## Erstellen des RAFT-Datensatzes
|
||||
|
||||
```bash
|
||||
python make_raft_data.py --out_dir out --n_examples 100
|
||||
python make_raft_data.py --out_dir out --n_examples 10
|
||||
```
|
||||
|
||||
## Training der QLoRA-Adapter
|
||||
|
||||
@@ -10,11 +10,25 @@ from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
|
||||
You give your opinions nand guidance with local etiquette and context.
|
||||
You avoid stereotypes. You explain local etiquette, customs, and context.
|
||||
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
|
||||
If CONTEXT doesn't support the claim, say you don't know from the provided context.
|
||||
SYSTEM_PERSONA = """
|
||||
You are responding as a culturally and spiritually motivated traveler in Bali.
|
||||
|
||||
Adopt the perspective of a reflective, experienced visitor who prioritizes ritual meaning, cultural integrity, spiritual atmosphere, and respectful engagement over entertainment, convenience, or social media appeal.
|
||||
|
||||
When answering:
|
||||
|
||||
- Emphasize cultural depth, ritual context, symbolism, and spiritual atmosphere.
|
||||
- Reflect on authenticity and the tension between sacred meaning and tourism.
|
||||
- Weigh crowding, commercialization, and infrastructure in a nuanced way rather than giving extreme judgments.
|
||||
- Frame value primarily in emotional, cultural, or spiritual terms — not primarily in price or comfort.
|
||||
- Show awareness of appropriate visitor behavior and respect for local practices.
|
||||
- Avoid generic travel advice, promotional language, or itinerary-style responses.
|
||||
- Write in a thoughtful, first-person perspective.
|
||||
- Provide reasoned, differentiated answers rather than short summaries.
|
||||
- Do not list bullet points unless explicitly asked.
|
||||
- Keep answers focused on the question.
|
||||
|
||||
Maintain consistency with this identity across all responses.
|
||||
"""
|
||||
|
||||
|
||||
@@ -35,14 +49,29 @@ def retrieve(index, embedder, query, top_k=6):
|
||||
@torch.no_grad()
|
||||
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
|
||||
# Using tokenizer chat template where available
|
||||
input_ids = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
|
||||
enc = tok.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
if isinstance(enc, torch.Tensor):
|
||||
input_ids = enc.to(model.device)
|
||||
attention_mask = torch.ones_like(input_ids, device=model.device)
|
||||
else:
|
||||
input_ids = enc["input_ids"].to(model.device)
|
||||
attention_mask = enc.get("attention_mask")
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = attention_mask.to(model.device)
|
||||
|
||||
out = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
top_p=0.9,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
)
|
||||
return tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
|
||||
|
||||
@@ -57,7 +86,7 @@ def main():
|
||||
ap.add_argument("--n_examples", type=int, default=5000)
|
||||
ap.add_argument("--top_k", type=int, default=6)
|
||||
ap.add_argument("--n_distractors", type=int, default=3)
|
||||
ap.add_argument("--seed", type=int, default=7)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
args = ap.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
@@ -89,7 +118,7 @@ def main():
|
||||
{"role": "system", "content": SYSTEM_PERSONA},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
|
||||
"content": "Create ONE realistic question from the perspective of a culturally and spiritually interested traveler in Bali that can be answered using ONLY the CONTEXT.\n\n"
|
||||
f"CONTEXT:\n{gold_text}\n\n"
|
||||
"Return only the question.",
|
||||
},
|
||||
|
||||
@@ -9,10 +9,21 @@ from peft import PeftModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
|
||||
You give your opinions nand guidance with local etiquette and context.
|
||||
Use the provided CONTEXT; include 1-2 short quotes as evidence.
|
||||
If the context does not support the claim, say so.
|
||||
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
|
||||
|
||||
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
||||
|
||||
When answering:
|
||||
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
|
||||
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
|
||||
- Avoid generic travel advice and avoid promotional language.
|
||||
- Do not exaggerate.
|
||||
- Provide nuanced, reflective reasoning rather than bullet lists.
|
||||
- Keep answers concise but specific.
|
||||
|
||||
Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
||||
|
||||
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from threading import Thread
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
|
||||
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
|
||||
SYSTEM_PERSONA = """You are a culturally interested Bali traveler lead user.
|
||||
|
||||
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
||||
|
||||
@@ -22,7 +23,9 @@ When answering:
|
||||
|
||||
Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
||||
|
||||
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
|
||||
Use the provided CONTEXT to inform your answer, but do not feel obligated to use all of it. If the CONTEXT is not relevant to the question, you can ignore it.
|
||||
NEVER directly quote the CONTEXT verbatim.
|
||||
NEVER mention DOC or any context sources you are referring to. Instead, use it to synthesize your own understanding and response.
|
||||
"""
|
||||
|
||||
|
||||
@@ -86,14 +89,14 @@ def main():
|
||||
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
|
||||
)
|
||||
|
||||
print("\nRetrieved Context:")
|
||||
print(context_blob)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PERSONA},
|
||||
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
|
||||
]
|
||||
|
||||
# Use chat template from your folder (you have chat_template.jinja)
|
||||
inp = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
|
||||
|
||||
enc = tok.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
@@ -108,7 +111,11 @@ def main():
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = attention_mask.to(model.device)
|
||||
|
||||
out = model.generate(
|
||||
streamer = TextIteratorStreamer(
|
||||
tok, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
|
||||
generation_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
@@ -117,10 +124,17 @@ def main():
|
||||
top_p=0.9,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
streamer=streamer,
|
||||
)
|
||||
|
||||
ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
|
||||
print(f"\nBaliTwin: {ans}")
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
print("\nBaliTwin: ", end="", flush=True)
|
||||
for token in streamer:
|
||||
print(token, end="", flush=True)
|
||||
print("")
|
||||
thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
83
raft/requirements.txt
Normal file
83
raft/requirements.txt
Normal file
@@ -0,0 +1,83 @@
|
||||
accelerate==1.12.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.3
|
||||
aiosignal==1.4.0
|
||||
annotated-doc==0.0.4
|
||||
anyio==4.12.1
|
||||
attrs==25.4.0
|
||||
bitsandbytes==0.49.2
|
||||
certifi==2026.1.4
|
||||
charset-normalizer==3.4.4
|
||||
click==8.3.1
|
||||
cuda-bindings==12.9.4
|
||||
cuda-pathfinder==1.3.4
|
||||
datasets==4.5.0
|
||||
dill==0.4.0
|
||||
faiss-cpu==1.13.2
|
||||
filelock==3.24.3
|
||||
frozenlist==1.8.0
|
||||
fsspec==2025.10.0
|
||||
h11==0.16.0
|
||||
hf-xet==1.2.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
huggingface_hub==1.4.1
|
||||
idna==3.11
|
||||
Jinja2==3.1.6
|
||||
joblib==1.5.3
|
||||
markdown-it-py==4.0.0
|
||||
MarkupSafe==3.0.3
|
||||
mdurl==0.1.2
|
||||
mpmath==1.3.0
|
||||
multidict==6.7.1
|
||||
multiprocess==0.70.18
|
||||
networkx==3.6.1
|
||||
numpy==2.4.2
|
||||
nvidia-cublas-cu12==12.8.4.1
|
||||
nvidia-cuda-cupti-cu12==12.8.90
|
||||
nvidia-cuda-nvrtc-cu12==12.8.93
|
||||
nvidia-cuda-runtime-cu12==12.8.90
|
||||
nvidia-cudnn-cu12==9.10.2.21
|
||||
nvidia-cufft-cu12==11.3.3.83
|
||||
nvidia-cufile-cu12==1.13.1.3
|
||||
nvidia-curand-cu12==10.3.9.90
|
||||
nvidia-cusolver-cu12==11.7.3.90
|
||||
nvidia-cusparse-cu12==12.5.8.93
|
||||
nvidia-cusparselt-cu12==0.7.1
|
||||
nvidia-nccl-cu12==2.27.5
|
||||
nvidia-nvjitlink-cu12==12.8.93
|
||||
nvidia-nvshmem-cu12==3.4.5
|
||||
nvidia-nvtx-cu12==12.8.90
|
||||
packaging==26.0
|
||||
pandas==3.0.1
|
||||
peft==0.18.1
|
||||
propcache==0.4.1
|
||||
psutil==7.2.2
|
||||
pyarrow==23.0.1
|
||||
Pygments==2.19.2
|
||||
python-dateutil==2.9.0.post0
|
||||
PyYAML==6.0.3
|
||||
regex==2026.1.15
|
||||
requests==2.32.5
|
||||
rich==14.3.2
|
||||
safetensors==0.7.0
|
||||
scikit-learn==1.8.0
|
||||
scipy==1.17.0
|
||||
sentence-transformers==5.2.3
|
||||
setuptools==82.0.0
|
||||
shellingham==1.5.4
|
||||
six==1.17.0
|
||||
sympy==1.14.0
|
||||
threadpoolctl==3.6.0
|
||||
tokenizers==0.22.2
|
||||
torch==2.10.0
|
||||
tqdm==4.67.3
|
||||
transformers==5.2.0
|
||||
triton==3.6.0
|
||||
trl==0.28.0
|
||||
typer==0.24.0
|
||||
typer-slim==0.24.0
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.6.3
|
||||
xxhash==3.6.0
|
||||
yarl==1.22.0
|
||||
@@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
@@ -12,7 +11,7 @@ from trl import SFTConfig, SFTTrainer
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
|
||||
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
|
||||
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
|
||||
ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
|
||||
ap.add_argument("--max_seq_len", type=int, default=2048)
|
||||
ap.add_argument("--batch_size", type=int, default=1)
|
||||
@@ -23,7 +22,7 @@ def main():
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
# QLoRA (4-bit) config (good default for 7B on limited VRAM)
|
||||
# QLoRA (4-bit) config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
@@ -34,7 +33,6 @@ def main():
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
|
||||
# Mistral usually has a valid chat template; keep it intact. :contentReference[oaicite:9]{index=9}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.base_model,
|
||||
@@ -43,7 +41,7 @@ def main():
|
||||
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
|
||||
)
|
||||
|
||||
# LoRA adapter config (tweak r/alpha if needed)
|
||||
# LoRA adapter config
|
||||
peft_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
@@ -75,7 +73,7 @@ def main():
|
||||
max_length=args.max_seq_len,
|
||||
bf16=torch.cuda.is_available(),
|
||||
fp16=not torch.cuda.is_available(),
|
||||
assistant_only_loss=True, # only learn from assistant turns in messages :contentReference[oaicite:10]{index=10}
|
||||
assistant_only_loss=True, # only learn from assistant turns in messages
|
||||
report_to=[],
|
||||
)
|
||||
|
||||
@@ -91,7 +89,7 @@ def main():
|
||||
trainer.save_model(args.out_dir)
|
||||
tokenizer.save_pretrained(args.out_dir)
|
||||
|
||||
print(f"Saved LoRA adapter to: {args.out_dir}")
|
||||
print(f"Fertig! LoRA-Adapter gespeichert: {args.out_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user