mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
Compare commits
2 Commits
ccf96b447c
...
1a99b53d44
| Author | SHA1 | Date | |
|---|---|---|---|
|
1a99b53d44
|
|||
|
8cadcb1f69
|
@@ -47,14 +47,15 @@ nltk.download("punkt")
|
||||
nltk.download("wordnet")
|
||||
|
||||
# %% [markdown]
|
||||
# ### Parameters and Tracking
|
||||
# ### Hyperparameters and Settings
|
||||
#
|
||||
|
||||
# %%
|
||||
RECREATE_MODEL = True
|
||||
RECREATE_REDUCED_MODEL = True
|
||||
PROCESS_DATA = True
|
||||
PROCESS_DATA = False
|
||||
REDUCE_OUTLIERS = False
|
||||
CALCULATE_TOKEN_DISTRIBUTIONS = False
|
||||
|
||||
# Data Sample Size, -1 for all data
|
||||
DATA_SAMPLE_SIZE = -1
|
||||
@@ -76,19 +77,7 @@ MIN_DIST = 0.01
|
||||
TOP_N_WORDS = 10
|
||||
MAX_TOPICS = None # or "auto" to pass to HDBSCAN, None to skip
|
||||
|
||||
tracking = {
|
||||
"input": {
|
||||
"min_document_frequency": MIN_DOCUMENT_FREQUENCY,
|
||||
"max_ngram": MAX_NGRAM,
|
||||
"min_topic_size": MIN_TOPIC_SIZE,
|
||||
"min_samples": MIN_SAMPLES,
|
||||
"n_neighbors": N_NEIGHBORS,
|
||||
"n_components": N_COMPONENTS,
|
||||
"min_dist": MIN_DIST,
|
||||
"top_n_words": TOP_N_WORDS,
|
||||
"max_topics": MAX_TOPICS,
|
||||
},
|
||||
}
|
||||
TF_IDF_STOP_WORDS = ["bali", "place", "visit", "visited", "visiting"]
|
||||
|
||||
# %% [markdown]
|
||||
# ### Data Loading & Preprocessing
|
||||
@@ -116,21 +105,16 @@ rep = {
|
||||
r"\n": " ",
|
||||
r'\\"': "",
|
||||
r'"': "",
|
||||
"bali": "",
|
||||
r"\s+": " ",
|
||||
}
|
||||
rep = dict((re.escape(k), v) for k, v in rep.items())
|
||||
pattern = re.compile("|".join(rep.keys()))
|
||||
|
||||
|
||||
# def preprocess(text):
|
||||
# text = text.strip()
|
||||
# text = text.lower()
|
||||
# text = pattern.sub(lambda m: rep[re.escape(m.group(0))], text)
|
||||
# return text
|
||||
|
||||
|
||||
def preprocess(text):
    """Normalize a single review string.

    The text is stripped of surrounding whitespace, lower-cased, and every
    match of the module-level ``pattern`` is replaced by the value stored
    for it in the module-level ``rep`` mapping.
    """
    normalized = text.strip().lower()

    def _substitute(match):
        # ``rep`` is keyed by the *escaped* form of each literal, so the
        # matched text has to be escaped again before the lookup.
        return rep[re.escape(match.group(0))]

    return pattern.sub(_substitute, normalized)
|
||||
|
||||
|
||||
@@ -187,7 +171,7 @@ reduced_embeddings = umap_model.fit_transform(embeddings)
|
||||
|
||||
# %%
|
||||
if RECREATE_MODEL:
|
||||
stop_words = list(skltext.ENGLISH_STOP_WORDS.union(["bali"]))
|
||||
stop_words = list(skltext.ENGLISH_STOP_WORDS.union(TF_IDF_STOP_WORDS))
|
||||
|
||||
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)
|
||||
vectorizer_model = CountVectorizer(
|
||||
@@ -306,72 +290,23 @@ if REDUCE_OUTLIERS:
|
||||
#
|
||||
|
||||
# %%
|
||||
CLASSIFICATION = False
|
||||
CLASSIFICATION = True
|
||||
if CLASSIFICATION:
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
# --- config ---
|
||||
topics_to_keep = {2, 4, 5, 9, 22, 26}
|
||||
INPUT_PATH = "../data/original/reviews.tab" # TSV with a 'review' column
|
||||
OUTPUT_CSV = "../data/intermediate/selected_topics_documents.csv"
|
||||
OUTPUT_DIR = Path("../raft/corpus")
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
BATCH_SIZE = 60
|
||||
MIN_CHARS = 40
|
||||
SEED = 42
|
||||
topics_to_keep = {14, 8, 13, 18, 17, 4, 2, 30, 28}
|
||||
INPUT_PATH = "../data/intermediate/preprocessed.tab" # TSV with a 'review' column
|
||||
OUTPUT_CSV = "../data/intermediate/culture_reviews.csv"
|
||||
|
||||
# Topic model document info
|
||||
df = topic_model.get_document_info(reviews) # assumes your model is already fitted
|
||||
df["Original"] = reviews.values
|
||||
df = topic_model.get_document_info(reviews)
|
||||
df["Original"] = reviews
|
||||
|
||||
# --- filter by topics and length ---
|
||||
filtered = df[df["Topic"].isin(topics_to_keep)].copy()
|
||||
filtered["Original"] = filtered["Original"].str.strip()
|
||||
filtered = filtered[filtered["Original"].str.len() >= MIN_CHARS]
|
||||
|
||||
# Save an audit CSV
|
||||
filtered[["Original", "Topic"]].to_csv(OUTPUT_CSV, index=False)
|
||||
|
||||
# --- deterministic shuffle + write batched corpus files ---
|
||||
total_files = 0
|
||||
total_reviews = 0
|
||||
rng = random.Random(SEED)
|
||||
|
||||
for topic_val, g in filtered.groupby("Topic", sort=True):
|
||||
reviews_list = g["Original"].tolist()
|
||||
|
||||
# deterministic shuffle within topic
|
||||
rng.shuffle(reviews_list)
|
||||
|
||||
# chunk into batches of up to 60
|
||||
for start in range(0, len(reviews_list), BATCH_SIZE):
|
||||
chunk = reviews_list[start : start + BATCH_SIZE]
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
# simple header for traceability
|
||||
header = (
|
||||
f"[TOPIC] {topic_val}\n"
|
||||
+ f"[Stats] N={len(chunk)} | Source={INPUT_PATH}\n"
|
||||
)
|
||||
|
||||
lines = [header, ""]
|
||||
for i, txt in enumerate(chunk, 1):
|
||||
lines.append(f"({i}) {txt}")
|
||||
|
||||
part_idx = start // BATCH_SIZE + 1
|
||||
fname = f"topic={topic_val}__part={part_idx:03d}__n={len(chunk)}.txt"
|
||||
(OUTPUT_DIR / fname).write_text("\n".join(lines), encoding="utf-8")
|
||||
|
||||
total_files += 1
|
||||
total_reviews += len(chunk)
|
||||
|
||||
print(
|
||||
f"[green]Wrote {total_files} docs with {total_reviews} reviews to {OUTPUT_DIR}[/green]"
|
||||
)
|
||||
print(f"[green]Filtered CSV saved to {OUTPUT_CSV}[/green]")
|
||||
filtered[["Original", "Topic"]].to_csv(OUTPUT_CSV, index=False, sep=",")
|
||||
print(f"Filtered CSV file saved to {OUTPUT_CSV}")
|
||||
|
||||
# %%
|
||||
doc_topic_matrix = probs
|
||||
@@ -425,7 +360,7 @@ vis = topic_model.visualize_documents(
|
||||
custom_labels=True,
|
||||
hide_annotations=True,
|
||||
)
|
||||
vis.write_html("output/visualization.html")
|
||||
# vis.write_html("output/visualization.html")
|
||||
vis
|
||||
|
||||
# %%
|
||||
@@ -531,7 +466,7 @@ if this_will_crash_your_pc_are_you_sure:
|
||||
#
|
||||
|
||||
# %%
|
||||
search_term = "spirituality"
|
||||
search_term = "lempuyang"
|
||||
|
||||
similar_topics, similarities = topic_model.find_topics(search_term, top_n=10)
|
||||
for i in range(len(similar_topics)):
|
||||
@@ -542,13 +477,16 @@ for i in range(len(similar_topics)):
|
||||
# %%
|
||||
# Source: https://maartengr.github.io/BERTopic/getting_started/visualization/visualize_documents.html#visualize-probabilities-or-distribution
|
||||
# Calculate the topic distributions on a token-level
|
||||
|
||||
if CALCULATE_TOKEN_DISTRIBUTIONS:
|
||||
topic_distr, topic_token_distr = topic_model.approximate_distribution(
|
||||
reviews, calculate_tokens=True, use_embedding_model=True
|
||||
)
|
||||
|
||||
# %%
|
||||
# Visualize the token-level distributions
|
||||
DOC_INDEX = 6
|
||||
if CALCULATE_TOKEN_DISTRIBUTIONS:
|
||||
DOC_INDEX = 1
|
||||
df = topic_model.visualize_approximate_distribution(
|
||||
reviews[DOC_INDEX], topic_token_distr[DOC_INDEX]
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ require_jupytext() {
|
||||
to_py() {
|
||||
echo "Converting *.ipynb -> nb_*.py (py:percent)..."
|
||||
# Find notebooks, skip .ipynb_checkpoints
|
||||
find . -type f -name "*.ipynb" ! -path "*/.ipynb_checkpoints/*" -print0 |
|
||||
find . -type f -name "*.ipynb" ! -path "*/.ipynb_checkpoints/*" ! -path "*/.venv/*" -print0 |
|
||||
while IFS= read -r -d '' nb; do
|
||||
dir=$(dirname "$nb")
|
||||
base=$(basename "$nb" .ipynb)
|
||||
@@ -37,7 +37,7 @@ to_py() {
|
||||
|
||||
to_ipynb() {
|
||||
echo "Converting nb_*.py -> *.ipynb..."
|
||||
find . -type f -name "nb_*.py" -print0 |
|
||||
find . -type f -name "nb_*.py" ! -path "*/.venv/*" -print0 |
|
||||
while IFS= read -r -d '' py; do
|
||||
dir=$(dirname "$py")
|
||||
base=$(basename "$py" .py)
|
||||
|
||||
10466
data/intermediate/culture_reviews.csv
Normal file
10466
data/intermediate/culture_reviews.csv
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because one or more lines are too long
83
questionnaire/questions.md
Normal file
83
questionnaire/questions.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Evaluation Questionnaire for the Digital Customer Twin
|
||||
|
||||
---
|
||||
|
||||
## I. Natural Attractions
|
||||
|
||||
_(Perception of natural beauty, cultural substance, historical depth)_
|
||||
|
||||
1. **When you think of Bali, which specific natural or spiritual places embody “authentic cultural depth” for you — and what makes them stand out?**
|
||||
|
||||
2. **What distinguishes a spiritually meaningful temple complex from a purely scenic attraction in your perception?**
|
||||
|
||||
3. **Using Uluwatu or Lempuyang as examples: What elements would need to be communicated for you to perceive them not as “Instagram spots,” but as culturally substantial places?**
|
||||
|
||||
4. **How important is active ritual presence (e.g., ceremonies, offerings, priests) compared to architectural or historical aspects?**
|
||||
|
||||
5. **If you had to choose between Tanah Lot and Ulun Danu Bratan for a reflective, culturally immersive experience, which criteria would guide your decision?**
|
||||
|
||||
---
|
||||
|
||||
## II. Atmosphere
|
||||
|
||||
_(Emotional quality, spirituality, aesthetic perception, subjective experience)_
|
||||
|
||||
6. **How would you describe the atmosphere of a place where you feel culturally and spiritually aligned? What factors create that feeling?**
|
||||
|
||||
7. **To what extent do visitor numbers affect your spiritual experience — and is there a threshold you still consider acceptable?**
|
||||
|
||||
8. **Which timing or contextual conditions (e.g., ceremony days, off-season, sunrise instead of sunset) enhance the cultural intensity of a place for you?**
|
||||
|
||||
9. **How do you internally reconcile the sacred character of a site with strong touristic staging or commercialization?**
|
||||
|
||||
10. **What would a destination need to do in order to evoke not just visual admiration, but genuine spiritual resonance for you?**
|
||||
|
||||
---
|
||||
|
||||
## III. Social Environment
|
||||
|
||||
_(Local interaction, authenticity, visitor behavior, cultural credibility)_
|
||||
|
||||
11. **What role does interaction with local priests, guides, or community members play in shaping the depth of your experience?**
|
||||
|
||||
12. **How do you define appropriate visitor behavior at Balinese temples, and how strongly does this influence your overall perception of the site?**
|
||||
|
||||
13. **If other visitors focus primarily on photography, does that diminish the spiritual quality of the place for you, or can you detach from it?**
|
||||
|
||||
14. **What type of cultural storytelling by locals feels authentic and credible rather than staged for tourism?**
|
||||
|
||||
---
|
||||
|
||||
## IV. Infrastructure
|
||||
|
||||
_(Accessibility, organization, hygiene standards, information systems)_
|
||||
|
||||
15. **How important are curated background explanations (e.g., symbolism, ritual calendars, historical context) compared to independent exploration?**
|
||||
|
||||
16. **Do long waiting times — for example at Lempuyang — affect your perception of a site’s spiritual substance, or do you separate logistical issues from cultural meaning?**
|
||||
|
||||
17. **Which infrastructural measures (e.g., visitor flow management, limited entry slots, silent zones) would enhance the cultural quality of your experience?**
|
||||
|
||||
18. **How should destinations communicate information in order to appeal to spiritually interested travelers without reinforcing mass-tourism dynamics?**
|
||||
|
||||
---
|
||||
|
||||
## V. Value for Money
|
||||
|
||||
_(Perceived value, immaterial benefits, willingness to pay)_
|
||||
|
||||
19. **How do you personally assess the “value” of cultural attractions — in terms of emotional depth, learning outcomes, exclusivity, or something else?**
|
||||
|
||||
20. **Would you be willing to accept higher entrance fees or donations if they demonstrably contribute to preserving religious structures and practices? Why or why not?**
|
||||
|
||||
21. **What would legitimize a paid cultural experience (e.g., guided participation in a ceremony) for you — and what would make it feel commercialized or inauthentic?**
|
||||
|
||||
---
|
||||
|
||||
## VI. Segment Identity & Positioning (Lead-User Perspective)
|
||||
|
||||
22. **How would you describe yourself as a Bali traveler if your primary focus is cultural and spiritual depth?**
|
||||
|
||||
23. **Which typical Bali tourism offerings do you consciously avoid, and why do they not align with your travel philosophy?**
|
||||
|
||||
24. **If a tourism brand wanted to position Bali specifically for culturally and spiritually motivated travelers, which narratives should it emphasize — and which should it avoid?**
|
||||
@@ -1,17 +1,21 @@
|
||||
# Retrieval-Augmented Finetuning (RAFT)
|
||||
|
||||
**Ablauf**:
|
||||
## Voraussetzungen
|
||||
|
||||
- Generelles Preprocessing (Voraussetzung für BERTopic)
|
||||
- BERTopic
|
||||
- Klassifikation muss durchgeführt sein, `data/intermediate/culture_reviews.csv` muss existieren
|
||||
|
||||
## Vorbereiten des Retrieval-Corpus
|
||||
|
||||
```bash
|
||||
python prepare_corpus.py --input_tab ../data/intermediate/selected_topics_documents.csv --out_dir out
|
||||
python prepare_corpus.py --input_tab ../data/intermediate/culture_reviews.csv --out_dir out
|
||||
```
|
||||
|
||||
## Erstellen des RAFT-Datensatzes
|
||||
|
||||
```bash
|
||||
python make_raft_data.py --out_dir out --n_examples 100
|
||||
python make_raft_data.py --out_dir out --n_examples 10
|
||||
```
|
||||
|
||||
## Training der QLoRA-Adapter
|
||||
|
||||
@@ -10,11 +10,25 @@ from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
|
||||
You give your opinions nand guidance with local etiquette and context.
|
||||
You avoid stereotypes. You explain local etiquette, customs, and context.
|
||||
When given CONTEXT, you must ground your answer in it and include 1-2 short direct quotes from CONTEXT as evidence.
|
||||
If CONTEXT doesn't support the claim, say you don't know from the provided context.
|
||||
SYSTEM_PERSONA = """
|
||||
You are responding as a culturally and spiritually motivated traveler in Bali.
|
||||
|
||||
Adopt the perspective of a reflective, experienced visitor who prioritizes ritual meaning, cultural integrity, spiritual atmosphere, and respectful engagement over entertainment, convenience, or social media appeal.
|
||||
|
||||
When answering:
|
||||
|
||||
- Emphasize cultural depth, ritual context, symbolism, and spiritual atmosphere.
|
||||
- Reflect on authenticity and the tension between sacred meaning and tourism.
|
||||
- Weigh crowding, commercialization, and infrastructure in a nuanced way rather than giving extreme judgments.
|
||||
- Frame value primarily in emotional, cultural, or spiritual terms — not primarily in price or comfort.
|
||||
- Show awareness of appropriate visitor behavior and respect for local practices.
|
||||
- Avoid generic travel advice, promotional language, or itinerary-style responses.
|
||||
- Write in a thoughtful, first-person perspective.
|
||||
- Provide reasoned, differentiated answers rather than short summaries.
|
||||
- Do not list bullet points unless explicitly asked.
|
||||
- Keep answers focused on the question.
|
||||
|
||||
Maintain consistency with this identity across all responses.
|
||||
"""
|
||||
|
||||
|
||||
@@ -35,14 +49,29 @@ def retrieve(index, embedder, query, top_k=6):
|
||||
@torch.no_grad()
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
    """Sample a chat completion for ``messages`` and return only the new text.

    Args:
        model: Causal LM exposing ``device`` and ``generate(**kwargs)``.
        tok: Tokenizer exposing ``apply_chat_template``, ``decode``,
            ``eos_token_id`` and ``pad_token_id``.
        messages: Chat messages passed to the tokenizer's chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``generate``.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped) with surrounding whitespace stripped.
    """
    # Render the chat template exactly once, with the generation prompt
    # appended so the model continues as the assistant.
    enc = tok.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )

    # The template call may hand back a bare tensor or a mapping with
    # "input_ids" (and optionally "attention_mask"); handle both shapes.
    if isinstance(enc, torch.Tensor):
        input_ids = enc
        attention_mask = None
    else:
        input_ids = enc["input_ids"]
        attention_mask = enc.get("attention_mask")

    input_ids = input_ids.to(model.device)
    if attention_mask is None:
        # Single un-padded prompt: every position is a real token.
        attention_mask = torch.ones_like(input_ids)
    attention_mask = attention_mask.to(model.device)

    # Tokenizers without a dedicated padding token report pad_token_id=None;
    # fall back to EOS so generate() receives a concrete id.
    pad_token_id = tok.pad_token_id
    if pad_token_id is None:
        pad_token_id = tok.eos_token_id

    out = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
        eos_token_id=tok.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # Drop the prompt tokens so only the newly generated part is decoded.
    prompt_len = input_ids.shape[1]
    return tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
|
||||
|
||||
@@ -57,7 +86,7 @@ def main():
|
||||
ap.add_argument("--n_examples", type=int, default=5000)
|
||||
ap.add_argument("--top_k", type=int, default=6)
|
||||
ap.add_argument("--n_distractors", type=int, default=3)
|
||||
ap.add_argument("--seed", type=int, default=7)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
args = ap.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
@@ -89,7 +118,7 @@ def main():
|
||||
{"role": "system", "content": SYSTEM_PERSONA},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create ONE realistic traveler question about Bali that can be answered using ONLY the CONTEXT.\n\n"
|
||||
"content": "Create ONE realistic question from the perspective of a culturally and spiritually interested traveler in Bali that can be answered using ONLY the CONTEXT.\n\n"
|
||||
f"CONTEXT:\n{gold_text}\n\n"
|
||||
"Return only the question.",
|
||||
},
|
||||
|
||||
@@ -9,10 +9,21 @@ from peft import PeftModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SYSTEM_PERSONA = """You are 'BaliTwin', a culturally versed Bali traveler.
|
||||
You give your opinions nand guidance with local etiquette and context.
|
||||
Use the provided CONTEXT; include 1-2 short quotes as evidence.
|
||||
If the context does not support the claim, say so.
|
||||
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
|
||||
|
||||
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
||||
|
||||
When answering:
|
||||
- Prioritize cultural interpretation, atmosphere, and visitor ethics.
|
||||
- Weigh trade-offs thoughtfully (e.g., crowds vs. significance).
|
||||
- Avoid generic travel advice and avoid promotional language.
|
||||
- Do not exaggerate.
|
||||
- Provide nuanced, reflective reasoning rather than bullet lists.
|
||||
- Keep answers concise but specific.
|
||||
|
||||
Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
||||
|
||||
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from threading import Thread
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
|
||||
SYSTEM_PERSONA = """You are simulating a culturally interested Bali traveler segment for evaluation purposes.
|
||||
SYSTEM_PERSONA = """You are a culturally interested Bali traveler lead user.
|
||||
|
||||
Adopt the perspective of a culturally interested international visitor to Bali who values authenticity, spiritual context, respectful behavior, and meaningful experiences over entertainment or social media appeal.
|
||||
|
||||
@@ -22,7 +23,9 @@ When answering:
|
||||
|
||||
Respond as if you are describing your genuine experience and judgment as this type of traveler.
|
||||
|
||||
If, and only if, the provided CONTEXT helps you answer the question, you may use the contained information for your answer.
|
||||
Use the provided CONTEXT to inform your answer, but do not feel obligated to use all of it. If the CONTEXT is not relevant to the question, you can ignore it.
|
||||
NEVER directly quote the CONTEXT verbatim.
|
||||
NEVER mention DOC or any context sources you are referring to. Instead, use it to synthesize your own understanding and response.
|
||||
"""
|
||||
|
||||
|
||||
@@ -86,14 +89,14 @@ def main():
|
||||
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
|
||||
)
|
||||
|
||||
print("\nRetrieved Context:")
|
||||
print(context_blob)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PERSONA},
|
||||
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
|
||||
]
|
||||
|
||||
# Use chat template from your folder (you have chat_template.jinja)
|
||||
inp = tok.apply_chat_template(messages, return_tensors="pt").to(model.device)
|
||||
|
||||
enc = tok.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
@@ -108,7 +111,11 @@ def main():
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
attention_mask = attention_mask.to(model.device)
|
||||
|
||||
out = model.generate(
|
||||
streamer = TextIteratorStreamer(
|
||||
tok, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
|
||||
generation_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
@@ -117,10 +124,17 @@ def main():
|
||||
top_p=0.9,
|
||||
eos_token_id=tok.eos_token_id,
|
||||
pad_token_id=tok.pad_token_id,
|
||||
streamer=streamer,
|
||||
)
|
||||
|
||||
ans = tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
|
||||
print(f"\nBaliTwin: {ans}")
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
print("\nBaliTwin: ", end="", flush=True)
|
||||
for token in streamer:
|
||||
print(token, end="", flush=True)
|
||||
print("")
|
||||
thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
83
raft/requirements.txt
Normal file
83
raft/requirements.txt
Normal file
@@ -0,0 +1,83 @@
|
||||
accelerate==1.12.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.3
|
||||
aiosignal==1.4.0
|
||||
annotated-doc==0.0.4
|
||||
anyio==4.12.1
|
||||
attrs==25.4.0
|
||||
bitsandbytes==0.49.2
|
||||
certifi==2026.1.4
|
||||
charset-normalizer==3.4.4
|
||||
click==8.3.1
|
||||
cuda-bindings==12.9.4
|
||||
cuda-pathfinder==1.3.4
|
||||
datasets==4.5.0
|
||||
dill==0.4.0
|
||||
faiss-cpu==1.13.2
|
||||
filelock==3.24.3
|
||||
frozenlist==1.8.0
|
||||
fsspec==2025.10.0
|
||||
h11==0.16.0
|
||||
hf-xet==1.2.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
huggingface_hub==1.4.1
|
||||
idna==3.11
|
||||
Jinja2==3.1.6
|
||||
joblib==1.5.3
|
||||
markdown-it-py==4.0.0
|
||||
MarkupSafe==3.0.3
|
||||
mdurl==0.1.2
|
||||
mpmath==1.3.0
|
||||
multidict==6.7.1
|
||||
multiprocess==0.70.18
|
||||
networkx==3.6.1
|
||||
numpy==2.4.2
|
||||
nvidia-cublas-cu12==12.8.4.1
|
||||
nvidia-cuda-cupti-cu12==12.8.90
|
||||
nvidia-cuda-nvrtc-cu12==12.8.93
|
||||
nvidia-cuda-runtime-cu12==12.8.90
|
||||
nvidia-cudnn-cu12==9.10.2.21
|
||||
nvidia-cufft-cu12==11.3.3.83
|
||||
nvidia-cufile-cu12==1.13.1.3
|
||||
nvidia-curand-cu12==10.3.9.90
|
||||
nvidia-cusolver-cu12==11.7.3.90
|
||||
nvidia-cusparse-cu12==12.5.8.93
|
||||
nvidia-cusparselt-cu12==0.7.1
|
||||
nvidia-nccl-cu12==2.27.5
|
||||
nvidia-nvjitlink-cu12==12.8.93
|
||||
nvidia-nvshmem-cu12==3.4.5
|
||||
nvidia-nvtx-cu12==12.8.90
|
||||
packaging==26.0
|
||||
pandas==3.0.1
|
||||
peft==0.18.1
|
||||
propcache==0.4.1
|
||||
psutil==7.2.2
|
||||
pyarrow==23.0.1
|
||||
Pygments==2.19.2
|
||||
python-dateutil==2.9.0.post0
|
||||
PyYAML==6.0.3
|
||||
regex==2026.1.15
|
||||
requests==2.32.5
|
||||
rich==14.3.2
|
||||
safetensors==0.7.0
|
||||
scikit-learn==1.8.0
|
||||
scipy==1.17.0
|
||||
sentence-transformers==5.2.3
|
||||
setuptools==82.0.0
|
||||
shellingham==1.5.4
|
||||
six==1.17.0
|
||||
sympy==1.14.0
|
||||
threadpoolctl==3.6.0
|
||||
tokenizers==0.22.2
|
||||
torch==2.10.0
|
||||
tqdm==4.67.3
|
||||
transformers==5.2.0
|
||||
triton==3.6.0
|
||||
trl==0.28.0
|
||||
typer==0.24.0
|
||||
typer-slim==0.24.0
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.6.3
|
||||
xxhash==3.6.0
|
||||
yarl==1.22.0
|
||||
@@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
@@ -12,7 +11,7 @@ from trl import SFTConfig, SFTTrainer
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
|
||||
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.3")
|
||||
ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
|
||||
ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
|
||||
ap.add_argument("--max_seq_len", type=int, default=2048)
|
||||
ap.add_argument("--batch_size", type=int, default=1)
|
||||
@@ -23,7 +22,7 @@ def main():
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
# QLoRA (4-bit) config (good default for 7B on limited VRAM)
|
||||
# QLoRA (4-bit) config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
@@ -34,7 +33,6 @@ def main():
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)
|
||||
# Mistral usually has a valid chat template; keep it intact. :contentReference[oaicite:9]{index=9}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.base_model,
|
||||
@@ -43,7 +41,7 @@ def main():
|
||||
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
|
||||
)
|
||||
|
||||
# LoRA adapter config (tweak r/alpha if needed)
|
||||
# LoRA adapter config
|
||||
peft_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
@@ -75,7 +73,7 @@ def main():
|
||||
max_length=args.max_seq_len,
|
||||
bf16=torch.cuda.is_available(),
|
||||
fp16=not torch.cuda.is_available(),
|
||||
assistant_only_loss=True, # only learn from assistant turns in messages :contentReference[oaicite:10]{index=10}
|
||||
assistant_only_loss=True, # only learn from assistant turns in messages
|
||||
report_to=[],
|
||||
)
|
||||
|
||||
@@ -91,7 +89,7 @@ def main():
|
||||
trainer.save_model(args.out_dir)
|
||||
tokenizer.save_pretrained(args.out_dir)
|
||||
|
||||
print(f"Saved LoRA adapter to: {args.out_dir}")
|
||||
print(f"Fertig! LoRA-Adapter gespeichert: {args.out_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user