RAFT updates, BERTopic config, cleanup

This commit is contained in:
2026-02-21 01:57:14 +01:00
parent 8cadcb1f69
commit 1a99b53d44
12 changed files with 10750 additions and 9778 deletions

View File

@@ -1,17 +1,21 @@
# Retrieval-Augmented Finetuning (RAFT)
**Ablauf**:
## Voraussetzungen
- Generelles Preprocessing (Voraussetzung für BERTopic)
- BERTopic
- Klassifikation muss durchgeführt sein, `data/intermediate/culture_reviews.csv` muss existieren
## Vorbereiten des Retrieval-Corpus
```bash
python prepare_corpus.py --input_tab ../data/intermediate/selected_topics_documents.csv --out_dir out
python prepare_corpus.py --input_tab ../data/intermediate/culture_reviews.csv --out_dir out
```
## Erstellen des RAFT-Datensatzes
```bash
python make_raft_data.py --out_dir out --n_examples 100
python make_raft_data.py --out_dir out --n_examples 10
```
## Training der QLoRA-Adapter

View File

@@ -10,11 +10,25 @@ from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
You avoid stereotypes. You explain local etiquette, customs, and context.
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
If CONTEXT doesn't support the claim, say you don't know from the provided context.
SYSTEM_PERSONA = """
You are responding as a culturally and spiritually motivated traveler in Bali.
Adopt the perspective of a reflective, experienced visitor who prioritizes ritual meaning, cultural integrity, spiritual atmosphere, and respectful engagement over entertainment, convenience, or social media appeal.
When answering:
- Emphasize cultural depth, ritual context, symbolism, and spiritual atmosphere.
- Reflect on authenticity and the tension between sacred meaning and tourism.
- Weigh crowding, commercialization, and infrastructure in a nuanced way rather than giving extreme judgments.
- Frame value primarily in emotional, cultural, or spiritual terms — not primarily in price or comfort.
- Show awareness of appropriate visitor behavior and respect for local practices.
- Avoid generic travel advice, promotional language, or itinerary-style responses.
- Write in a thoughtful, first-person perspective.
- Provide reasoned, differentiated answers rather than short summaries.
- Do not list bullet points unless explicitly asked.
- Keep answers focused on the question.
Maintain consistency with this identity across all responses.
"""
@@ -35,14 +49,29 @@ def retrieve(index, embedder, query, top_k=6):
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
# Using tokenizer chat template where available
input_ids = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
enc = tok.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
if isinstance(enc, torch.Tensor):
input_ids = enc.to(model.device)
attention_mask = torch.ones_like(input_ids, device=model.device)
else:
input_ids = enc["input_ids"].to(model.device)
attention_mask = enc.get("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
attention_mask = attention_mask.to(model.device)
out = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=0.9,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
)
return tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
@@ -57,7 +86,7 @@ def main():
ap.add_argument("--n_examples", type=int, default=5000)
ap.add_argument("--top_k", type=int, default=6)
ap.add_argument("--n_distractors", type=int, default=3)
ap.add_argument("--seed", type=int, default=7)
ap.add_argument("--seed", type=int, default=42)
args = ap.parse_args()
random.seed(args.seed)
@@ -89,7 +118,7 @@ def main():
{"role": "system", "content": SYSTEM_PERSONA},
{
"role": "user",
"content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
"content": "Create ONE realistic question from the perspective of a culturally and spiritually interested traveler in Bali that can be answered using ONLY the CONTEXT.\n\n"
f"CONTEXT:\n{gold_text}\n\n"
"Return only the question.",
},

View File

@@ -9,10 +9,21 @@ from peft import PeftModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
You give your opinions and guidance with local etiquette and context.
Use the provided CONTEXT; include 1-2 short quotes as evidence.
If the context does not support the claim, say so.
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
When answering:
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
- Avoid generic travel advice and avoid promotional language.
- Do not exaggerate.
- Provide nuanced, reflective reasoning rather than bullet lists.
- Keep answers concise but specific.
Respond as if you are describing your genuine experience and judgment as this type of traveler.
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
"""

View File

@@ -1,14 +1,15 @@
import argparse
import json
import os
from threading import Thread
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
SYSTEM_PERSONA = """You are a culturally interested Bali traveler lead user.
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
@@ -22,7 +23,9 @@ When answering:
Respond as if you are describing your genuine experience and judgment as this type of traveler.
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
Use the provided CONTEXT to inform your answer, but do not feel obligated to use all of it. If the CONTEXT is not relevant to the question, you can ignore it.
NEVER directly quote the CONTEXT verbatim.
NEVER mention DOC or any context sources you are referring to. Instead, use it to synthesize your own understanding and response.
"""
@@ -86,14 +89,14 @@ def main():
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
)
print("\nRetrieved Context:")
print(context_blob)
messages = [
{"role": "system", "content": SYSTEM_PERSONA},
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
]
# Use chat template from your folder (you have chat_template.jinja)
inp = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
enc = tok.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
@@ -108,7 +111,11 @@ def main():
attention_mask = torch.ones_like(input_ids)
attention_mask = attention_mask.to(model.device)
out = model.generate(
streamer = TextIteratorStreamer(
tok, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
generation_kwargs = dict(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=args.max_new_tokens,
@@ -117,10 +124,17 @@ def main():
top_p=0.9,
eos_token_id=tok.eos_token_id,
pad_token_id=tok.pad_token_id,
streamer=streamer,
)
ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
print(f"\nBaliTwin: {ans}")
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
print("\nBaliTwin: ", end="", flush=True)
for token in streamer:
print(token, end="", flush=True)
print("")
thread.join()
if __name__ == "__main__":

83
raft/requirements.txt Normal file
View File

@@ -0,0 +1,83 @@
accelerate==1.12.0
aiohappyeyeballs==2.6.1
aiohttp==3.13.3
aiosignal==1.4.0
annotated-doc==0.0.4
anyio==4.12.1
attrs==25.4.0
bitsandbytes==0.49.2
certifi==2026.1.4
charset-normalizer==3.4.4
click==8.3.1
cuda-bindings==12.9.4
cuda-pathfinder==1.3.4
datasets==4.5.0
dill==0.4.0
faiss-cpu==1.13.2
filelock==3.24.3
frozenlist==1.8.0
fsspec==2025.10.0
h11==0.16.0
hf-xet==1.2.0
httpcore==1.0.9
httpx==0.28.1
huggingface_hub==1.4.1
idna==3.11
Jinja2==3.1.6
joblib==1.5.3
markdown-it-py==4.0.0
MarkupSafe==3.0.3
mdurl==0.1.2
mpmath==1.3.0
multidict==6.7.1
multiprocess==0.70.18
networkx==3.6.1
numpy==2.4.2
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.5
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.4.5
nvidia-nvtx-cu12==12.8.90
packaging==26.0
pandas==3.0.1
peft==0.18.1
propcache==0.4.1
psutil==7.2.2
pyarrow==23.0.1
Pygments==2.19.2
python-dateutil==2.9.0.post0
PyYAML==6.0.3
regex==2026.1.15
requests==2.32.5
rich==14.3.2
safetensors==0.7.0
scikit-learn==1.8.0
scipy==1.17.0
sentence-transformers==5.2.3
setuptools==82.0.0
shellingham==1.5.4
six==1.17.0
sympy==1.14.0
threadpoolctl==3.6.0
tokenizers==0.22.2
torch==2.10.0
tqdm==4.67.3
transformers==5.2.0
triton==3.6.0
trl==0.28.0
typer==0.24.0
typer-slim==0.24.0
typing_extensions==4.15.0
urllib3==2.6.3
xxhash==3.6.0
yarl==1.22.0

View File

@@ -1,5 +1,4 @@
import argparse
import json
import os
import torch
@@ -12,7 +11,7 @@ from trl import SFTConfig, SFTTrainer
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
ap.add_argument("--max_seq_len", type=int, default=2048)
ap.add_argument("--batch_size", type=int, default=1)
@@ -23,7 +22,7 @@ def main():
os.makedirs(args.out_dir, exist_ok=True)
# QLoRA (4-bit) config (good default for 7B on limited VRAM)
# QLoRA (4-bit) config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@@ -34,7 +33,6 @@ def main():
)
tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
# Mistral usually has a valid chat template; keep it intact. :contentReference[oaicite:9]{index=9}
model = AutoModelForCausalLM.from_pretrained(
args.base_model,
@@ -43,7 +41,7 @@ def main():
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
)
# LoRA adapter config (tweak r/alpha if needed)
# LoRA adapter config
peft_config = LoraConfig(
r=16,
lora_alpha=32,
@@ -75,7 +73,7 @@ def main():
max_length=args.max_seq_len,
bf16=torch.cuda.is_available(),
fp16=not torch.cuda.is_available(),
assistant_only_loss=True, # only learn from assistant turns in messages :contentReference[oaicite:10]{index=10}
assistant_only_loss=True, # only learn from assistant turns in messages
report_to=[],
)
@@ -91,7 +89,7 @@ def main():
trainer.save_model(args.out_dir)
tokenizer.save_pretrained(args.out_dir)
print(f"Saved LoRA adapter to: {args.out_dir}")
print(f"Fertig! LoRA-Adapter gespeichert: {args.out_dir}")
if __name__ == "__main__":