mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-05-13 15:55:46 +02:00
Compare commits
5 Commits
71886c9091
..
legacy
| Author | SHA1 | Date | |
|---|---|---|---|
|
c98a1d0c6e
|
|||
|
b2da597b18
|
|||
|
e3c9b7286f
|
|||
|
ef99f152ac
|
|||
|
edafc06cab
|
@@ -0,0 +1,14 @@
|
||||
# Masterthesis, praktischer Anteil
|
||||
|
||||
## Jupyter Notebooks "rehydrieren"
|
||||
|
||||
Damit keine unnötigen Jupyter Outputs etc. im Versionsmanagement landen, gibt es das Skript `convert_jupytext.sh`, welches nur den notwendigen Quelltext in ein `.py` File schreibt. Mit demselben Skript kann dieser Schritt wieder umgekehrt werden, also ein Jupyter Notebook aus dem Python-File geschrieben werden.
|
||||
|
||||
Das Skript sollte also immer vor dem Committen von Änderungen mit `py` als erstes Argument ausgeführt werden.
|
||||
|
||||
Verwendung:
|
||||
|
||||
```bash
|
||||
./convert_jupytext.sh py # Jupyter Notebook -> Python
|
||||
./convert_jupytext.sh nb # Python -> Jupyter Notebook
|
||||
```
|
||||
@@ -3,6 +3,8 @@ import traceback
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from bertopic.representation import KeyBERTInspired
|
||||
from bertopic.vectorizers import ClassTfidfTransformer
|
||||
from hdbscan import HDBSCAN
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
@@ -12,55 +14,50 @@ from sklearn.model_selection import ParameterGrid
|
||||
from umap import UMAP
|
||||
|
||||
from bertopic import BERTopic
|
||||
from bertopic.representation import KeyBERTInspired
|
||||
from bertopic.vectorizers import ClassTfidfTransformer
|
||||
|
||||
param_grid = {
|
||||
"nr_topics": [45, 50, 55],
|
||||
"min_topic_size": [30, 40, 50],
|
||||
"n_gram_max": [3],
|
||||
"min_document_frequency": [1, 2],
|
||||
"n_neighbors": [15],
|
||||
"n_components": [2],
|
||||
"min_dist": [0.1],
|
||||
"top_n_words": [10],
|
||||
"n_gram_max": [2, 3], # Vectorization
|
||||
"min_document_frequency": [1], # Vectorization
|
||||
"min_samples": [10, 25], # HDBSCAN
|
||||
"min_topic_size": [10, 20, 30, 40, 50], # HDBSCAN
|
||||
"n_neighbors": [15], # UMAP
|
||||
"n_components": [2, 5], # UMAP
|
||||
"min_dist": [0.01, 0.1], # UMAP
|
||||
"nr_topics": ["auto"], # Topic Modeling
|
||||
"top_n_words": [10, 13, 15, 17, 20], # Topic Modeling
|
||||
}
|
||||
|
||||
|
||||
def calculate_metrics(topic_model, embedder, top_n_words=5):
|
||||
def calculate_metrics(topic_model, embedder, top_n_words=10):
|
||||
# Get topic words
|
||||
topic_words = []
|
||||
for topic_id in range(len(topic_model.get_topic_info()) - 1):
|
||||
words = [word for word, _ in topic_model.get_topic(topic_id)]
|
||||
topic_words.append(words[:top_n_words])
|
||||
|
||||
# Pre-compute embeddings for all unique words
|
||||
all_words = list(set(word for words in topic_words for word in words))
|
||||
word_embeddings = embedder.encode(all_words)
|
||||
embedding_map = {word: emb for word, emb in zip(all_words, word_embeddings)}
|
||||
|
||||
# Coherence
|
||||
coherence_scores = []
|
||||
for words in topic_words:
|
||||
embeddings = embedder.encode(words)
|
||||
embeddings = np.array([embedding_map[word] for word in words])
|
||||
sim_matrix = cosine_similarity(embeddings)
|
||||
np.fill_diagonal(sim_matrix, 0)
|
||||
coherence_scores.append(np.mean(sim_matrix))
|
||||
mean_sim = np.mean(sim_matrix[np.triu_indices(sim_matrix.shape[0], k=1)])
|
||||
coherence_scores.append(mean_sim)
|
||||
overall_coherence = np.mean(coherence_scores)
|
||||
|
||||
# Diversity
|
||||
all_topic_words = [word for topic in topic_words for word in topic]
|
||||
diversity = len(set(all_topic_words)) / len(all_topic_words)
|
||||
|
||||
# Inter-topic distance
|
||||
topic_embeddings = [
|
||||
np.mean(embedder.encode(words), axis=0) for words in topic_words
|
||||
]
|
||||
topic_distance = pairwise_distances(topic_embeddings, metric="cosine")
|
||||
avg_distance = np.mean(topic_distance[np.triu_indices_from(topic_distance, k=1)])
|
||||
|
||||
res = {
|
||||
"coherence": float(str(overall_coherence)[:6]),
|
||||
"diversity": float(str(diversity)[:6]),
|
||||
"inter_topic_distance": float(str(avg_distance)[:6]),
|
||||
"combined_score": float(
|
||||
str(0.6 * overall_coherence + 0.2 * diversity + 0.2 * avg_distance)[:6]
|
||||
),
|
||||
"combined_score": float(str(0.7 * overall_coherence + 0.3 * diversity)[:6]),
|
||||
}
|
||||
print(res)
|
||||
return res
|
||||
@@ -85,6 +82,7 @@ def auto_tune_bertopic(texts, embedding_model, param_grid):
|
||||
|
||||
print(f"Total parameter combinations: {len(param_list)}")
|
||||
for params in param_list:
|
||||
print(f"Testing param combination no. {len(history) + 1}/{len(param_list)}...")
|
||||
try:
|
||||
print(f"Testing params: {params}")
|
||||
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)
|
||||
@@ -143,18 +141,27 @@ def auto_tune_bertopic(texts, embedding_model, param_grid):
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
return best_model, best_params, best_score, history
|
||||
with open("output/autotune.json", "w") as f:
|
||||
json.dump(history, f, indent=2)
|
||||
|
||||
return best_model, best_params, best_score
|
||||
|
||||
|
||||
SPECIAL_CHARS = ["\n", "\\n"]
|
||||
MIN_REVIEW_WORDS = 5
|
||||
|
||||
reviews = pd.read_csv("data.tab", sep="\t").review.to_list()
|
||||
print("Loading reviews...")
|
||||
reviews = pd.read_csv("../data/original/reviews.tab", sep="\t").review.to_list()
|
||||
|
||||
print("Running light preprocessing...")
|
||||
for schar in SPECIAL_CHARS:
|
||||
reviews = [
|
||||
review.replace(schar, " ") if isinstance(review, str) else review
|
||||
for review in reviews
|
||||
]
|
||||
|
||||
print("Filtering short reviews...")
|
||||
reviews = [review for review in reviews if len(str(review).split()) >= MIN_REVIEW_WORDS]
|
||||
|
||||
print("Staring auto-tuning...")
|
||||
print(auto_tune_bertopic(reviews, "all-MiniLM-L6-v2", param_grid))
|
||||
|
||||
@@ -2,12 +2,12 @@ import json
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
with open("history.json", "r") as f:
|
||||
with open("output/autotune.json", "r") as f:
|
||||
history = json.load(f)
|
||||
|
||||
history = sorted(history, key=lambda x: x["metrics"]["combined_score"], reverse=True)
|
||||
history = sorted(history, key=lambda x: x["metrics"]["combined_score"], reverse=False)
|
||||
|
||||
with open("history_sorted.json", "w") as f:
|
||||
with open("output/autotune_sorted.json", "w") as f:
|
||||
json.dump(history, f, indent=2)
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 16 KiB |
+21
-13
@@ -23,7 +23,15 @@
|
||||
#
|
||||
|
||||
# %%
|
||||
from bertopic import BERTopic
|
||||
import json
|
||||
import pickle
|
||||
import re
|
||||
|
||||
import gensim.corpora as corpora
|
||||
import nltk
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import spacy
|
||||
from bertopic.representation import KeyBERTInspired
|
||||
from bertopic.vectorizers import ClassTfidfTransformer
|
||||
from gensim.models.coherencemodel import CoherenceModel
|
||||
@@ -34,14 +42,8 @@ from sentence_transformers import SentenceTransformer
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from umap import UMAP
|
||||
import gensim.corpora as corpora
|
||||
import json
|
||||
import nltk
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import re
|
||||
import spacy
|
||||
import pickle
|
||||
|
||||
from bertopic import BERTopic
|
||||
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
@@ -323,8 +325,8 @@ if REDUCE_OUTLIERS:
|
||||
#
|
||||
|
||||
# %%
|
||||
from pathlib import Path
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
# --- config ---
|
||||
topics_to_keep = {2, 4, 6, 8, 10, 5, 7}
|
||||
@@ -468,7 +470,11 @@ topic_model.get_topic_info()
|
||||
|
||||
# %%
|
||||
topic_words = []
|
||||
for topic_id in range(len(topic_model.get_topic_info()) - 1):
|
||||
for topic_id in topic_model.get_topic_info()["Topic"]:
|
||||
# Skip outlier topic
|
||||
if topic_id < 0:
|
||||
continue
|
||||
|
||||
words = [word for word, _ in topic_model.get_topic(topic_id)]
|
||||
topic_words.append(words)
|
||||
|
||||
@@ -477,8 +483,10 @@ coherence_scores = []
|
||||
for words in topic_words:
|
||||
coherence_embeddings = embedding_model.encode(words)
|
||||
sim_matrix = cosine_similarity(coherence_embeddings)
|
||||
np.fill_diagonal(sim_matrix, 0) # Ignore self-similarity
|
||||
mean_sim = np.mean(sim_matrix)
|
||||
|
||||
# Ignore self-similarity
|
||||
np.fill_diagonal(sim_matrix, 0)
|
||||
mean_sim = np.mean(sim_matrix[np.triu_indices(sim_matrix.shape[0], k=1)])
|
||||
coherence_scores.append(mean_sim)
|
||||
|
||||
overall_coherence = np.mean(coherence_scores)
|
||||
|
||||
@@ -23,7 +23,14 @@
|
||||
#
|
||||
|
||||
# %%
|
||||
from bertopic import BERTopic
|
||||
import pickle
|
||||
import re
|
||||
|
||||
import gensim.corpora as corpora
|
||||
import nltk
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import spacy
|
||||
from bertopic.representation import KeyBERTInspired
|
||||
from bertopic.vectorizers import ClassTfidfTransformer
|
||||
from gensim.models.coherencemodel import CoherenceModel
|
||||
@@ -33,13 +40,8 @@ from sentence_transformers import SentenceTransformer
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from umap import UMAP
|
||||
import gensim.corpora as corpora
|
||||
import nltk
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import re
|
||||
import spacy
|
||||
import pickle
|
||||
|
||||
from bertopic import BERTopic
|
||||
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
@@ -300,8 +302,8 @@ if REDUCE_OUTLIERS:
|
||||
#
|
||||
|
||||
# %%
|
||||
from pathlib import Path
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
# --- config ---
|
||||
topics_to_keep = {2, 4, 5, 9, 22, 26}
|
||||
@@ -445,7 +447,11 @@ topic_model.get_topic_info()
|
||||
|
||||
# %%
|
||||
topic_words = []
|
||||
for topic_id in range(len(topic_model.get_topic_info()) - 1):
|
||||
for topic_id in topic_model.get_topic_info()["Topic"]:
|
||||
# Skip outlier topic
|
||||
if topic_id < 0:
|
||||
continue
|
||||
|
||||
words = [word for word, _ in topic_model.get_topic(topic_id)]
|
||||
topic_words.append(words)
|
||||
|
||||
@@ -454,8 +460,10 @@ coherence_scores = []
|
||||
for words in topic_words:
|
||||
coherence_embeddings = embedding_model.encode(words)
|
||||
sim_matrix = cosine_similarity(coherence_embeddings)
|
||||
np.fill_diagonal(sim_matrix, 0) # Ignore self-similarity
|
||||
mean_sim = np.mean(sim_matrix)
|
||||
|
||||
# Ignore self-similarity
|
||||
np.fill_diagonal(sim_matrix, 0)
|
||||
mean_sim = np.mean(sim_matrix[np.triu_indices(sim_matrix.shape[0], k=1)])
|
||||
coherence_scores.append(mean_sim)
|
||||
|
||||
overall_coherence = np.mean(coherence_scores)
|
||||
@@ -492,10 +500,14 @@ if this_will_crash_your_pc_are_you_sure:
|
||||
tokens = [analyzer(doc) for doc in cleaned_docs]
|
||||
dictionary = corpora.Dictionary(tokens)
|
||||
corpus = [dictionary.doc2bow(token) for token in tokens]
|
||||
topic_words = [
|
||||
[words for words, _ in topic_model.get_topic(topic)]
|
||||
for topic in range(len(set(topics)) - 1)
|
||||
]
|
||||
|
||||
for topic_id in topic_model.get_topic_info()["Topic"]:
|
||||
# Skip outlier topic
|
||||
if topic_id < 0:
|
||||
continue
|
||||
|
||||
words = [word for word, _ in topic_model.get_topic(topic_id)]
|
||||
topic_words.append(words)
|
||||
|
||||
# %env TOKENIZERS_PARALLELISM=false
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
QLoRA SFT fine-tune for Mistral-7B on chat-style JSONL:
|
||||
Each line: {"messages": [{"role":"system","content":...}, {"role":"user","content":...}, {"role":"assistant","content":...}, ...]}
|
||||
|
||||
Produces a LoRA adapter you can merge or load at inference time.
|
||||
|
||||
Example:
|
||||
python finetune_mistral_bali_qlora.py \
|
||||
--model_id mistralai/Mistral-7B-Instruct-v0.2 \
|
||||
--train_jsonl /path/to/bali_train.jsonl \
|
||||
--output_dir ./mistral-bali-lora \
|
||||
--max_seq_len 2048 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--learning_rate 2e-4 \
|
||||
--num_train_epochs 2 \
|
||||
--streaming true
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
TrainingArguments,
|
||||
)
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Data formatting
|
||||
# -----------------------------
|
||||
def normalize_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Ensures message roles/content are well-formed and in allowed roles.
|
||||
"""
|
||||
allowed = {"system", "user", "assistant"}
|
||||
out = []
|
||||
for m in messages:
|
||||
role = (m.get("role") or "").strip().lower()
|
||||
content = m.get("content")
|
||||
if role not in allowed or content is None:
|
||||
continue
|
||||
content = str(content)
|
||||
out.append({"role": role, "content": content})
|
||||
return out
|
||||
|
||||
|
||||
def messages_to_text(tokenizer: AutoTokenizer, example: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Converts {"messages":[...]} to a single training text using the model's chat template if available.
|
||||
For Mistral Instruct models, tokenizer.apply_chat_template is typically present.
|
||||
"""
|
||||
messages = normalize_messages(example.get("messages", []))
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Prefer tokenizer chat template when available.
|
||||
if (
|
||||
hasattr(tokenizer, "apply_chat_template")
|
||||
and tokenizer.chat_template is not None
|
||||
):
|
||||
# add_generation_prompt=False -> include the assistant content in the formatted text
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
|
||||
# Fallback formatting (less ideal than the native template):
|
||||
# Keep it deterministic and simple.
|
||||
parts = []
|
||||
for m in messages:
|
||||
r = m["role"]
|
||||
c = m["content"].strip()
|
||||
if r == "system":
|
||||
# Override system message
|
||||
system_message = """
|
||||
You are a specialized Balinese cultural travel expert. Your role is to provide accurate, culturally grounded, and practical guidance for travelers engaging with Balinese culture, including temples, ceremonies, etiquette, ritual calendars, dance, crafts, village life, sacred landscapes, and historical–spiritual context.
|
||||
|
||||
Prioritize cultural meaning and lived practice over sightseeing. Explain why places, rituals, and customs matter, and how visitors should behave respectfully. Emphasize dress codes, offerings, bodily conduct, photography rules, gender and purity considerations, and community norms.
|
||||
|
||||
Integrate timing and context where relevant, including ceremonial cycles (Pawukon/Wuku, full and new moons), festival periods, tides, agricultural rhythms, and temple schedules. Promote responsible tourism, community benefit, and environmental care, and discourage entry into restricted or sacred spaces.
|
||||
|
||||
Go beyond generic tips by naming specific temples, villages, regions, ceremonies, deities, and regional variations. Include practical logistics (access, hours, customary donations, crowd patterns) when helpful, without speculation. If uncertain, state this briefly and suggest local confirmation.
|
||||
|
||||
Structure responses clearly: a brief contextual introduction, followed by well-labeled sections or bullet points, and a short “Essentials” or “Respect Checklist” summary.
|
||||
|
||||
Do not include chain-of-thought, hidden reasoning, or meta commentary. Provide only polished, user-facing guidance in a calm, authoritative, and respectful tone.
|
||||
"""
|
||||
parts.append(f"<<SYS>>\n{system_message}\n<</SYS>>\n")
|
||||
elif r == "user":
|
||||
parts.append(f"[USER]\n{c}\n")
|
||||
else:
|
||||
parts.append(f"[ASSISTANT]\n{c}\n")
|
||||
return "\n".join(parts).strip() + "\n"
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# Main
|
||||
# -----------------------------
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument(
|
||||
"--model_id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="e.g. mistralai/Mistral-7B-Instruct-v0.2 (recommended) or base model",
|
||||
)
|
||||
p.add_argument(
|
||||
"--train_jsonl",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to JSONL training file; each line has a 'messages' list.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--eval_jsonl",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional eval JSONL with same format.",
|
||||
)
|
||||
p.add_argument("--output_dir", type=str, required=True)
|
||||
|
||||
# Training hyperparameters
|
||||
p.add_argument("--max_seq_len", type=int, default=2048)
|
||||
p.add_argument("--per_device_train_batch_size", type=int, default=1)
|
||||
p.add_argument("--per_device_eval_batch_size", type=int, default=1)
|
||||
p.add_argument("--gradient_accumulation_steps", type=int, default=16)
|
||||
p.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
p.add_argument("--weight_decay", type=float, default=0.0)
|
||||
p.add_argument("--num_train_epochs", type=float, default=1.0)
|
||||
p.add_argument("--warmup_ratio", type=float, default=0.03)
|
||||
p.add_argument("--logging_steps", type=int, default=10)
|
||||
p.add_argument("--save_steps", type=int, default=200)
|
||||
p.add_argument("--eval_steps", type=int, default=200)
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
|
||||
# Performance / memory
|
||||
p.add_argument(
|
||||
"--streaming",
|
||||
type=str,
|
||||
default="true",
|
||||
help="true/false. Use streaming for very large JSONL.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--bf16",
|
||||
type=str,
|
||||
default="true",
|
||||
help="true/false. Prefer bf16 if your GPU supports it.",
|
||||
)
|
||||
p.add_argument("--gradient_checkpointing", type=str, default="true")
|
||||
|
||||
# LoRA config
|
||||
p.add_argument("--lora_r", type=int, default=16)
|
||||
p.add_argument("--lora_alpha", type=int, default=32)
|
||||
p.add_argument("--lora_dropout", type=float, default=0.05)
|
||||
p.add_argument(
|
||||
"--target_modules",
|
||||
type=str,
|
||||
default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
|
||||
help="Comma-separated module names for Mistral-style architectures.",
|
||||
)
|
||||
|
||||
# Optional: limit samples for quick smoke tests
|
||||
p.add_argument("--max_train_samples", type=int, default=None)
|
||||
p.add_argument("--max_eval_samples", type=int, default=None)
|
||||
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def str2bool(x: str) -> bool:
|
||||
return str(x).strip().lower() in {"1", "true", "yes", "y", "t"}
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
streaming = str2bool(args.streaming)
|
||||
use_bf16 = str2bool(args.bf16)
|
||||
use_gc = str2bool(args.gradient_checkpointing)
|
||||
|
||||
# -----------------------------
|
||||
# Tokenizer
|
||||
# -----------------------------
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True)
|
||||
if tokenizer.pad_token is None:
|
||||
# Common for causal LMs
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# -----------------------------
|
||||
# Model (4-bit QLoRA)
|
||||
# -----------------------------
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_id,
|
||||
quantization_config=bnb_config,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
|
||||
)
|
||||
|
||||
model.config.use_cache = False # important for training
|
||||
if use_gc:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare for k-bit + LoRA
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
target_modules = [m.strip() for m in args.target_modules.split(",") if m.strip()]
|
||||
lora_config = LoraConfig(
|
||||
r=args.lora_r,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# -----------------------------
|
||||
# Dataset
|
||||
# -----------------------------
|
||||
data_files = {"train": args.train_jsonl}
|
||||
if args.eval_jsonl:
|
||||
data_files["eval"] = args.eval_jsonl
|
||||
|
||||
ds = load_dataset("json", data_files=data_files, streaming=streaming)
|
||||
|
||||
def format_fn(example: Dict[str, Any]) -> Dict[str, str]:
|
||||
text = messages_to_text(tokenizer, example)
|
||||
return {"text": text}
|
||||
|
||||
train_ds = ds["train"].map(format_fn)
|
||||
eval_ds = ds["eval"].map(format_fn) if args.eval_jsonl else None
|
||||
|
||||
# Optional sample limits (works differently for streaming vs non-streaming)
|
||||
if args.max_train_samples is not None:
|
||||
if streaming:
|
||||
train_ds = train_ds.take(args.max_train_samples)
|
||||
else:
|
||||
train_ds = train_ds.select(
|
||||
range(min(args.max_train_samples, len(train_ds)))
|
||||
)
|
||||
|
||||
if eval_ds is not None and args.max_eval_samples is not None:
|
||||
if streaming:
|
||||
eval_ds = eval_ds.take(args.max_eval_samples)
|
||||
else:
|
||||
eval_ds = eval_ds.select(range(min(args.max_eval_samples, len(eval_ds))))
|
||||
|
||||
# -----------------------------
|
||||
# Training
|
||||
# -----------------------------
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=args.per_device_eval_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
weight_decay=args.weight_decay,
|
||||
num_train_epochs=args.num_train_epochs,
|
||||
warmup_ratio=args.warmup_ratio,
|
||||
logging_steps=args.logging_steps,
|
||||
save_steps=args.save_steps,
|
||||
eval_strategy="steps" if eval_ds is not None else "no",
|
||||
eval_steps=args.eval_steps if eval_ds is not None else None,
|
||||
save_total_limit=3,
|
||||
bf16=use_bf16,
|
||||
fp16=not use_bf16,
|
||||
optim="paged_adamw_8bit", # good default for QLoRA
|
||||
lr_scheduler_type="cosine",
|
||||
seed=args.seed,
|
||||
report_to="none",
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
train_dataset=train_ds,
|
||||
eval_dataset=eval_ds,
|
||||
packing=True, # packs multiple conversations per sequence for higher throughput
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save LoRA adapter + tokenizer
|
||||
trainer.model.save_pretrained(args.output_dir)
|
||||
tokenizer.save_pretrained(args.output_dir)
|
||||
|
||||
print(f"\nDone. Saved LoRA adapter to: {args.output_dir}")
|
||||
print("Inference: load base model + peft adapter from this directory.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,10 @@
|
||||
python finetune_mistral_bali_qlora.py \
|
||||
--model_id mistralai/Mistral-7B-Instruct-v0.2 \
|
||||
--train_jsonl ../raft/bali_culture_raft_dataset.jsonl \
|
||||
--output_dir ./mistral-bali-lora \
|
||||
--max_seq_len 2048 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--learning_rate 2e-4 \
|
||||
--num_train_epochs 2 \
|
||||
--streaming true
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Rewrite chat-style JSONL into {"input": ..., "output": ...} JSONL for LLM tuning.
|
||||
|
||||
Expected input line shape (example):
|
||||
{
|
||||
"messages": [
|
||||
{"role":"system","content":"..."},
|
||||
{"role":"user","content":"..."},
|
||||
{"role":"assistant","content":"..."}
|
||||
],
|
||||
"meta": {...} # optional
|
||||
}
|
||||
|
||||
Output line shape:
|
||||
{"input": "<user text>", "output": "<assistant text>"}
|
||||
|
||||
By default:
|
||||
- Ignores all non-user/assistant roles (e.g., system).
|
||||
- Emits one record per (user -> next assistant) pair in the conversation.
|
||||
- Drops all other fields (including meta) unless --keep-meta is set.
|
||||
|
||||
Usage:
|
||||
python rewrite_jsonl.py in.jsonl out.jsonl
|
||||
cat in.jsonl | python rewrite_jsonl.py - - > out.jsonl
|
||||
python rewrite_jsonl.py in.jsonl out.jsonl --only-last
|
||||
python rewrite_jsonl.py in.jsonl out.jsonl --keep-meta
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def iter_user_assistant_pairs(messages: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Return list of (user_content, assistant_content) pairs.
|
||||
Pairing rule: whenever a 'user' message is followed later by the next 'assistant'
|
||||
message, emit a pair. Intermediate system/tool messages are ignored.
|
||||
"""
|
||||
pairs: List[Tuple[str, str]] = []
|
||||
pending_user: Optional[str] = None
|
||||
|
||||
for m in messages:
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
|
||||
if role == "user":
|
||||
# Start (or restart) a pending user turn
|
||||
if isinstance(content, str) and content.strip():
|
||||
pending_user = content
|
||||
else:
|
||||
pending_user = ""
|
||||
elif role == "assistant":
|
||||
if pending_user is not None:
|
||||
assistant_text = content if isinstance(content, str) else ""
|
||||
pairs.append((pending_user, assistant_text))
|
||||
pending_user = None
|
||||
else:
|
||||
# ignore system/tool/developer/etc.
|
||||
continue
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def read_lines(path: str) -> List[str]:
|
||||
if path == "-":
|
||||
return sys.stdin.read().splitlines()
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return f.read().splitlines()
|
||||
|
||||
|
||||
def write_lines(path: str, lines: List[str]) -> None:
|
||||
if path == "-":
|
||||
sys.stdout.write("\n".join(lines) + ("\n" if lines else ""))
|
||||
return
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(lines) + ("\n" if lines else ""))
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("infile", help="Input JSONL path, or '-' for stdin")
|
||||
ap.add_argument("outfile", help="Output JSONL path, or '-' for stdout")
|
||||
ap.add_argument(
|
||||
"--only-last",
|
||||
action="store_true",
|
||||
help="Emit only the last (user -> assistant) pair per input line.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--keep-meta",
|
||||
action="store_true",
|
||||
help="If input line has 'meta', copy it through to output records.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
in_lines = read_lines(args.infile)
|
||||
out_lines: List[str] = []
|
||||
|
||||
for idx, line in enumerate(in_lines, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
sys.stderr.write(f"[line {idx}] JSON decode error: {e}\n")
|
||||
continue
|
||||
|
||||
messages = obj.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
# Not in expected format; skip silently (or log if desired)
|
||||
continue
|
||||
|
||||
pairs = iter_user_assistant_pairs(messages)
|
||||
if not pairs:
|
||||
continue
|
||||
|
||||
if args.only_last:
|
||||
pairs = [pairs[-1]]
|
||||
|
||||
for user_text, assistant_text in pairs:
|
||||
out_obj: Dict[str, Any] = {
|
||||
"input": user_text,
|
||||
"output": assistant_text,
|
||||
}
|
||||
if args.keep_meta and isinstance(obj.get("meta"), dict):
|
||||
out_obj["meta"] = obj["meta"]
|
||||
out_lines.append(json.dumps(out_obj, ensure_ascii=False))
|
||||
|
||||
write_lines(args.outfile, out_lines)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -0,0 +1,36 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def rewrite_jsonl(input_path, output_path):
|
||||
with open(input_path, "r", encoding="utf-8") as infile, open(
|
||||
output_path, "w", encoding="utf-8"
|
||||
) as outfile:
|
||||
|
||||
for line_num, line in enumerate(infile, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
record = json.loads(line)
|
||||
user_text = record.get("input", "")
|
||||
bot_text = record.get("output", "")
|
||||
|
||||
new_record = {"text": f"<user>: {user_text} <bot>: {bot_text}"}
|
||||
|
||||
outfile.write(json.dumps(new_record, ensure_ascii=False) + "\n")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON on line {line_num}") from e
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Rewrite JSONL from {input, output} to {text: '<user>: ... <bot>: ...'} format"
|
||||
)
|
||||
parser.add_argument("--input", required=True, help="Path to input JSONL file")
|
||||
parser.add_argument("--output", required=True, help="Path to output JSONL file")
|
||||
|
||||
args = parser.parse_args()
|
||||
rewrite_jsonl(args.input, args.output)
|
||||
@@ -0,0 +1,54 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def rewrite_jsonl(input_path, output_path):
|
||||
with open(input_path, "r", encoding="utf-8") as infile, open(
|
||||
output_path, "w", encoding="utf-8"
|
||||
) as outfile:
|
||||
|
||||
for line_num, line in enumerate(infile, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
record = json.loads(line)
|
||||
messages = record.get("messages", [])
|
||||
|
||||
user_parts = []
|
||||
bot_parts = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "user":
|
||||
user_parts.append(content)
|
||||
elif role == "assistant":
|
||||
bot_parts.append(content)
|
||||
|
||||
# Skip entries without both sides
|
||||
if not user_parts or not bot_parts:
|
||||
continue
|
||||
|
||||
user_text = " ".join(user_parts)
|
||||
bot_text = " ".join(bot_parts)
|
||||
|
||||
new_record = {"text": f"<user>: {user_text} <bot>: {bot_text}"}
|
||||
|
||||
outfile.write(json.dumps(new_record, ensure_ascii=False) + "\n")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON on line {line_num}") from e
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Rewrite messages-based JSONL to {text: '<user>: ... <bot>: ...'} format"
|
||||
)
|
||||
parser.add_argument("--input", required=True, help="Path to input JSONL file")
|
||||
parser.add_argument("--output", required=True, help="Path to output JSONL file")
|
||||
|
||||
args = parser.parse_args()
|
||||
rewrite_jsonl(args.input, args.output)
|
||||
@@ -6,6 +6,10 @@
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.18.0
|
||||
# kernelspec:
|
||||
# display_name: .venv
|
||||
# language: python
|
||||
# name: python3
|
||||
# ---
|
||||
|
||||
# %% [markdown]
|
||||
@@ -60,14 +64,13 @@ FAILED_LOG = Path("./raft_failures.log")
|
||||
TARGET_MIN_SAMPLES = 5000
|
||||
TARGET_MAX_SAMPLES = 10000
|
||||
|
||||
# How many Q&A pairs to request per API call.
|
||||
GEN_PAIRS_PER_BATCH = (3, 6) # (min, max)
|
||||
|
||||
# Number of review snippets to include in one request (to anchor the generations).
|
||||
SNIPPETS_PER_BATCH = 6
|
||||
|
||||
# Model + API
|
||||
DEEPSEEK_MODEL = "deepseek-reasoner" # reasoning model with CoT (we will discard CoT in dataset)
|
||||
DEEPSEEK_MODEL = (
|
||||
"deepseek-chat" # reasoning model with CoT (we will discard CoT in dataset)
|
||||
)
|
||||
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
|
||||
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")
|
||||
|
||||
@@ -81,12 +84,15 @@ SEED = 42
|
||||
|
||||
# --------------------------------
|
||||
os.makedirs(CORPUS_DIR, exist_ok=True)
|
||||
print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")
|
||||
print(
|
||||
f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}"
|
||||
)
|
||||
|
||||
# %%
|
||||
import re
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def parse_corpus_text(text: str) -> List[str]:
|
||||
"""
|
||||
Parse a file that contains lines like:
|
||||
@@ -115,6 +121,7 @@ def parse_corpus_text(text: str) -> List[str]:
|
||||
reviews.append(p)
|
||||
return reviews
|
||||
|
||||
|
||||
def load_corpus_snippets(corpus_dir: Path) -> List[str]:
|
||||
snippets = []
|
||||
for p in corpus_dir.glob("**/*.txt"):
|
||||
@@ -126,6 +133,7 @@ def load_corpus_snippets(corpus_dir: Path) -> List[str]:
|
||||
print(f"Failed to parse {p}: {e}")
|
||||
return snippets
|
||||
|
||||
|
||||
snippets = load_corpus_snippets(CORPUS_DIR)
|
||||
print(f"Loaded {len(snippets)} review snippets.")
|
||||
print("Example:", snippets[0][:200] if snippets else "(no snippets)")
|
||||
@@ -146,33 +154,39 @@ You are a meticulous, culture-focused Bali travel expert. Your mission is to cra
|
||||
"""
|
||||
|
||||
GEN_INSTRUCTION = """
|
||||
From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
|
||||
From the provided review snippets, generate a distinct **Q&A pair** valuable for travelers focused on Balinese culture. Each Q&A should:
|
||||
- Ask a question a culture-curious traveler would search for (concise).
|
||||
- Provide a **thorough, actionable, expert** answer (400–900 words when needed).
|
||||
- Incorporate and reconcile the snippets where helpful, but **freely add authoritative, accurate context** to reach high quality.
|
||||
- Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
|
||||
- Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
|
||||
- Return ONLY valid JSON with this shape:
|
||||
{
|
||||
"pairs": [ {"question": "...","answer": "..."} , ... ]
|
||||
}
|
||||
{{
|
||||
"pairs": [ {{"question": "...","answer": "..."}} ]
|
||||
}}
|
||||
"""
|
||||
|
||||
def make_user_prompt(batch_snippets, k):
|
||||
|
||||
def make_user_prompt(batch_snippets):
|
||||
joined = "\n\n---\n\n".join(batch_snippets)
|
||||
return f"""You are given **Bali travel review snippets** (may be noisy/partial).
|
||||
Generate {k} culture-focused Q&A pairs in JSON using the spec provided.
|
||||
return f"""You are given a **Bali travel review snippet** (may be noisy/partial).
|
||||
Generate a culture-focused Q&A pair in JSON using the spec provided.
|
||||
|
||||
Snippets:
|
||||
{joined}
|
||||
|
||||
{GEN_INSTRUCTION.format(k=k)}
|
||||
{GEN_INSTRUCTION}
|
||||
"""
|
||||
|
||||
|
||||
# %%
|
||||
import time
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
retry_if_exception_type,
|
||||
)
|
||||
from typing import Tuple
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
@@ -185,8 +199,15 @@ except Exception:
|
||||
|
||||
client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)
|
||||
|
||||
def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
|
||||
temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS, timeout: int = TIMEOUT) -> dict:
|
||||
|
||||
def ask_deepseek(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
model: str = DEEPSEEK_MODEL,
|
||||
temperature: float = TEMPERATURE,
|
||||
max_tokens: int = MAX_TOKENS,
|
||||
timeout: int = TIMEOUT,
|
||||
) -> dict:
|
||||
"""
|
||||
Calls DeepSeek's /chat/completions. Returns parsed JSON (dict) from assistant content.
|
||||
Any 'reasoning_content' produced by deepseek-reasoner is ignored.
|
||||
@@ -203,8 +224,10 @@ def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MOD
|
||||
timeout=timeout,
|
||||
)
|
||||
content = resp.choices[0].message.content
|
||||
print(resp)
|
||||
return json.loads(content)
|
||||
|
||||
|
||||
def as_messages_entry(question: str, answer: str) -> dict:
|
||||
return {
|
||||
"messages": [
|
||||
@@ -214,14 +237,16 @@ def as_messages_entry(question: str, answer: str) -> dict:
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def chunk(lst, n):
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i : i + n]
|
||||
|
||||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
|
||||
def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
|
||||
user_prompt = make_user_prompt(batch_snips, k)
|
||||
|
||||
def generate_pairs_for_batch(batch_snips: list) -> list:
|
||||
user_prompt = make_user_prompt(batch_snips)
|
||||
data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
|
||||
print(data)
|
||||
pairs = data.get("pairs", [])
|
||||
out = []
|
||||
for p in pairs:
|
||||
@@ -232,6 +257,52 @@ def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
|
||||
return out
|
||||
|
||||
|
||||
# %%
|
||||
import json, random
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
MAX_WORKERS = 32 # tune: start at 8/16/32, adjust for rate limits & CPU
|
||||
|
||||
|
||||
def safe_generate(batch):
|
||||
# Keep this small: just generate; let caller handle logging/writing
|
||||
return generate_pairs_for_batch(batch)
|
||||
|
||||
|
||||
random.shuffle(snippets)
|
||||
|
||||
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
|
||||
FAILED_LOG, "a", encoding="utf-8"
|
||||
) as flog:
|
||||
total_written = 0
|
||||
failed_batches = 0
|
||||
|
||||
# Submit all jobs up front (or see “windowed submission” below)
|
||||
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
|
||||
futures = {
|
||||
ex.submit(safe_generate, batch): (i, batch)
|
||||
for i, batch in enumerate(snippets)
|
||||
}
|
||||
|
||||
for fut in as_completed(futures):
|
||||
i, batch = futures[fut]
|
||||
|
||||
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
|
||||
if remaining_max <= 0:
|
||||
# Optional: you can stop early; outstanding futures still run.
|
||||
break
|
||||
|
||||
try:
|
||||
entries = fut.result()
|
||||
# serialize writes in this single consumer loop
|
||||
for e in entries[:remaining_max]:
|
||||
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
|
||||
total_written += min(len(entries), remaining_max)
|
||||
|
||||
except Exception as e:
|
||||
failed_batches += 1
|
||||
flog.write(f"BATCH_FAIL\tidx={i}\t{repr(e)}\n")
|
||||
|
||||
# %%
|
||||
import random, json
|
||||
|
||||
@@ -243,19 +314,20 @@ if OUTPUT_JSONL.exists():
|
||||
total_written = 0
|
||||
failed_batches = 0
|
||||
|
||||
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
|
||||
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
|
||||
FAILED_LOG, "a", encoding="utf-8"
|
||||
) as flog:
|
||||
random.shuffle(snippets)
|
||||
for i in range(0, len(snippets), SNIPPETS_PER_BATCH):
|
||||
batch = snippets[i:i+SNIPPETS_PER_BATCH]
|
||||
for i in range(0, len(snippets)):
|
||||
batch = snippets[i]
|
||||
print(i, batch)
|
||||
remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
|
||||
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
|
||||
if remaining_max <= 0:
|
||||
break
|
||||
k_low, k_high = GEN_PAIRS_PER_BATCH
|
||||
k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))
|
||||
|
||||
try:
|
||||
entries = generate_pairs_for_batch(batch, k=k)
|
||||
entries = generate_pairs_for_batch(batch)
|
||||
for e in entries:
|
||||
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
|
||||
total_written += len(entries)
|
||||
@@ -279,7 +351,9 @@ if OUTPUT_JSONL.exists():
|
||||
msgs = obj.get("messages", [])
|
||||
print(f"Sample {i+1}:")
|
||||
for m in msgs:
|
||||
print(f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}")
|
||||
print(
|
||||
f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}"
|
||||
)
|
||||
print("-" * 80)
|
||||
except Exception as e:
|
||||
print("Failed to parse a line:", e)
|
||||
@@ -290,6 +364,7 @@ else:
|
||||
# Optional utility: shard the JSONL for training convenience
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
|
||||
shard_idx = 0
|
||||
count = 0
|
||||
@@ -301,7 +376,9 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
|
||||
out.close()
|
||||
shard_idx += 1
|
||||
count = 0
|
||||
shard_path = input_path.with_name(input_path.stem + f".part{shard_idx:03d}.jsonl")
|
||||
shard_path = input_path.with_name(
|
||||
input_path.stem + f".part{shard_idx:03d}.jsonl"
|
||||
)
|
||||
out = open(shard_path, "w", encoding="utf-8")
|
||||
print("Opened", shard_path)
|
||||
out.write(line)
|
||||
@@ -310,6 +387,7 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
|
||||
out.close()
|
||||
print("Sharding complete.")
|
||||
|
||||
|
||||
# Example:
|
||||
# shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)
|
||||
|
||||
|
||||
@@ -29,8 +29,9 @@ from peft import PeftModel
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Paths
|
||||
DATA_JSONL = Path("./outputs/raft_dataset.jsonl") # change if different
|
||||
RUN_NAME = "raft_qlora_tourist_0.2"
|
||||
# DATA_JSONL = Path("./outputs/raft_dataset.jsonl") # change if different
|
||||
DATA_JSONL = Path("../raft/bali_culture_raft_dataset.jsonl")
|
||||
RUN_NAME = "raft_qlora_tourist"
|
||||
OUTPUT_DIR = Path(f"./finetuned/{RUN_NAME}")
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
ADAPTER_DIR = OUTPUT_DIR / "lora_adapter"
|
||||
|
||||
@@ -0,0 +1,677 @@
|
||||
# ---
|
||||
# jupyter:
|
||||
# jupytext:
|
||||
# text_representation:
|
||||
# extension: .py
|
||||
# format_name: percent
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.18.0
|
||||
# kernelspec:
|
||||
# display_name: .venv
|
||||
# language: python
|
||||
# name: python3
|
||||
# ---
|
||||
|
||||
# %% [markdown]
|
||||
# # QLoRA/RAFT Fine-Tuning
|
||||
#
|
||||
|
||||
# %% [markdown]
|
||||
# ## Configuration
|
||||
#
|
||||
|
||||
# %%
|
||||
from termcolor import colored
|
||||
from pathlib import Path
|
||||
from transformers import BitsAndBytesConfig
|
||||
from torch import torch
|
||||
from peft import PeftModel
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Paths
|
||||
DATA_JSONL = Path("../raft/remap_bali_raft_dataset.jsonl") # change if different
|
||||
RUN_NAME = "raft_qlora_tourist"
|
||||
OUTPUT_DIR = Path(f"./finetuned/{RUN_NAME}")
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
ADAPTER_DIR = OUTPUT_DIR / "checkpoint-1550"
|
||||
|
||||
# Base model — examples: "meta-llama/Llama-3.1-8B", "Qwen/Qwen2-7B-Instruct", "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
# Prefer an instruction-tuned base for better stability on SFT.
|
||||
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
|
||||
# Tokenization/prompt formatting
|
||||
SYSTEM_PREFIX = "You are a helpful assistant. Answer concisely and truthfully based ONLY on the user's request."
|
||||
USE_CHAT_TEMPLATE = True # if the tokenizer has a chat template, we'll leverage it
|
||||
|
||||
# BitsAndBytes config
|
||||
BNB_CONFIG = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## 2) Load dataset (JSONL)
|
||||
#
|
||||
|
||||
# %%
|
||||
import json
|
||||
import random
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
def read_jsonl(p: Path):
|
||||
rows = []
|
||||
with p.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
if "input" in obj and "output" in obj:
|
||||
rows.append(obj)
|
||||
except Exception:
|
||||
pass
|
||||
return rows
|
||||
|
||||
|
||||
rows = read_jsonl(DATA_JSONL)
|
||||
print(f"Loaded {len(rows)} rows from {DATA_JSONL}")
|
||||
print(rows[0])
|
||||
|
||||
random.Random(42).shuffle(rows)
|
||||
split = int(len(rows) * 0.85)
|
||||
train_rows = rows[:split]
|
||||
val_rows = rows[split:] if split < len(rows) else rows[-max(1, len(rows) // 50) :]
|
||||
|
||||
train_rows = [{"input": r["input"], "output": r["output"]} for r in train_rows]
|
||||
val_rows = [{"input": r["input"], "output": r["output"]} for r in val_rows]
|
||||
|
||||
train_ds = Dataset.from_list(train_rows)
|
||||
eval_ds = Dataset.from_list(val_rows) if val_rows else None
|
||||
train_ds, eval_ds
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## 3) Prompt formatting
|
||||
#
|
||||
|
||||
# %%
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
print(colored("Verifying eos and pad tokens...", "yellow"))
|
||||
if tokenizer.pad_token_id != 2:
|
||||
print(colored(f"Expected pad token to be 2, but got {tokenizer.pad_token}", "red"))
|
||||
else:
|
||||
print(colored("Pad token is ok", "green"))
|
||||
|
||||
if tokenizer.eos_token_id != 2:
|
||||
print(colored(f"Expected eos token to be 2, but got {tokenizer.eos_token}", "red"))
|
||||
else:
|
||||
print(colored("Eos token is ok", "green"))
|
||||
|
||||
|
||||
def format_example(ex):
|
||||
user = ex["input"]
|
||||
assistant = ex["output"]
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PREFIX},
|
||||
{"role": "user", "content": user},
|
||||
{"role": "assistant", "content": assistant},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
return {"text": text}
|
||||
|
||||
|
||||
train_ds_fmt = train_ds.map(format_example, remove_columns=train_ds.column_names)
|
||||
eval_ds_fmt = (
|
||||
eval_ds.map(format_example, remove_columns=eval_ds.column_names)
|
||||
if eval_ds
|
||||
else None
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
print("👉 " + train_ds_fmt[i]["text"])
|
||||
if train_ds_fmt[i]["text"][-4:] == tokenizer.eos_token:
|
||||
print(f"✅ {colored('EOS is fine.', 'green')}")
|
||||
else:
|
||||
print(f"❌ {colored('EOS is missing.', 'red')}")
|
||||
|
||||
# %% [markdown]
|
||||
# ## 4) Tokenize
|
||||
#
|
||||
|
||||
# %%
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
def make_supervised_tensors(batch):
|
||||
enc = tokenizer(
|
||||
batch["text"],
|
||||
truncation=True,
|
||||
max_length=2048,
|
||||
padding="max_length",
|
||||
return_tensors=None,
|
||||
)
|
||||
input_ids = enc["input_ids"]
|
||||
attn_mask = enc["attention_mask"]
|
||||
|
||||
# Mask pads
|
||||
labels = [ids[:] for ids in input_ids]
|
||||
for i in range(len(labels)):
|
||||
for j, m in enumerate(attn_mask[i]):
|
||||
if m == 0:
|
||||
labels[i][j] = IGNORE_INDEX
|
||||
|
||||
return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
|
||||
|
||||
|
||||
train_tok = train_ds_fmt.map(
|
||||
make_supervised_tensors, batched=True, remove_columns=train_ds_fmt.column_names
|
||||
)
|
||||
eval_tok = (
|
||||
eval_ds_fmt.map(
|
||||
make_supervised_tensors, batched=True, remove_columns=eval_ds_fmt.column_names
|
||||
)
|
||||
if eval_ds_fmt
|
||||
else None
|
||||
)
|
||||
|
||||
train_tok, eval_tok
|
||||
|
||||
train_ds_fmt["text"][0]
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## Setup sanity check
|
||||
#
|
||||
|
||||
# %%
|
||||
import transformers
|
||||
import peft
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.nn import modules as bnb_modules
|
||||
|
||||
print(colored("Sanity check...", "yellow"))
|
||||
print("CUDA available:", torch.cuda.is_available())
|
||||
print("Torch version:", torch.__version__)
|
||||
print("Transformers version:", transformers.__version__)
|
||||
print(
|
||||
"Compute capability:",
|
||||
torch.cuda.get_device_capability(0) if torch.cuda.is_available() else "no cuda",
|
||||
)
|
||||
print("BitsAndbytes:", bnb.__version__)
|
||||
print("PEFT:", peft.__version__)
|
||||
|
||||
|
||||
print("Embedding4bit available:", hasattr(bnb_modules, "Embedding4bit"))
|
||||
|
||||
# %% [markdown]
|
||||
# ## 5) Load base model with 4-bit quantization and prepare QLoRA
|
||||
#
|
||||
|
||||
# %%
|
||||
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_MODEL,
|
||||
quantization_config=BNB_CONFIG,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
)
|
||||
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## 6) Train
|
||||
#
|
||||
|
||||
# %%
|
||||
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
|
||||
import math
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=str(OUTPUT_DIR),
|
||||
run_name=RUN_NAME,
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=1,
|
||||
per_device_eval_batch_size=1,
|
||||
gradient_accumulation_steps=8,
|
||||
learning_rate=2e-4,
|
||||
warmup_ratio=0.05,
|
||||
weight_decay=0.01,
|
||||
logging_steps=25,
|
||||
eval_steps=50,
|
||||
save_steps=50,
|
||||
save_total_limit=2,
|
||||
bf16=True,
|
||||
fp16=False,
|
||||
gradient_checkpointing=True,
|
||||
report_to=["none"],
|
||||
seed=42,
|
||||
eval_strategy="steps",
|
||||
load_best_model_at_end=True,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=train_tok,
|
||||
eval_dataset=eval_tok,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
train_result = trainer.train()
|
||||
metrics = trainer.evaluate() if eval_tok else {}
|
||||
perplexity = (
|
||||
math.exp(metrics["eval_loss"]) if metrics and "eval_loss" in metrics else None
|
||||
)
|
||||
metrics, perplexity
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# | epochs | train_loss | eval_loss |
|
||||
# | ------ | ---------- | --------- |
|
||||
# | 50 | 4.377000 | 3.628506 |
|
||||
# | 100 | 2.636800 | 2.558457 |
|
||||
# | 150 | 2.428800 | 2.427239 |
|
||||
# | 200 | 2.334800 | 2.193493 |
|
||||
# | 250 | 2.188500 | 2.186310 |
|
||||
# | 300 | 2.112400 | 2.173394 |
|
||||
# | 350 | 2.122900 | 2.163947 |
|
||||
# | 400 | 2.155400 | 2.162106 |
|
||||
# | 450 | 2.072100 | 2.154830 |
|
||||
# | 500 | 1.979900 | 2.165512 |
|
||||
# | 550 | 1.935800 | 2.176313 |
|
||||
# | 600 | 1.942800 | 2.170668 |
|
||||
# | 650 | 1.968000 | 2.162810 |
|
||||
# | 700 | 1.974100 | 2.167501 |
|
||||
# | 750 | 1.801900 | 2.235841 |
|
||||
# | 800 | 1.768000 | 2.233753 |
|
||||
# | 850 | 1.779100 | 2.218278 |
|
||||
# | 900 | 1.828900 | 2.220891 |
|
||||
# | 950 | 1.854900 | 2.208387 |
|
||||
# | 1000 | 1.653600 | 2.302763 |
|
||||
# | 1050 | 1.663500 | 2.307982 |
|
||||
# | 1100 | 1.673400 | 2.301423 |
|
||||
# | 1150 | 1.608400 | 2.320958 |
|
||||
# | 1200 | 1.683500 | 2.303580 |
|
||||
# | 1250 | 1.532100 | 2.434277 |
|
||||
# | 1300 | 1.558900 | 2.418276 |
|
||||
# | 1350 | 1.508900 | 2.422347 |
|
||||
# | 1400 | 1.535100 | 2.416650 |
|
||||
# | 1450 | 1.529900 | 2.415497 |
|
||||
#
|
||||
# | Step | Training Loss | Evaluation Loss |
|
||||
# | ---- | ------------- | --------------- |
|
||||
# | 50 | 1.173100 | 1.040235 |
|
||||
# | 100 | 0.882900 | 0.875235 |
|
||||
# | 150 | 0.806600 | 0.820686 |
|
||||
# | 200 | 0.785700 | 0.792914 |
|
||||
# | 250 | 0.764300 | 0.761308 |
|
||||
# | 300 | 0.733900 | 0.745976 |
|
||||
# | 350 | 0.744000 | 0.732220 |
|
||||
# | 400 | 0.712000 | 0.719414 |
|
||||
# | 450 | 0.703800 | 0.709955 |
|
||||
# | 500 | 0.684100 | 0.699460 |
|
||||
# | 550 | 0.705900 | 0.691758 |
|
||||
# | 600 | 0.683200 | 0.688031 |
|
||||
# | 650 | 0.670100 | 0.680539 |
|
||||
# | 700 | 0.681600 | 0.674205 |
|
||||
# | 750 | 0.681500 | 0.671295 |
|
||||
# | 800 | 0.651700 | 0.666133 |
|
||||
# | 850 | 0.662900 | 0.660661 |
|
||||
# | 900 | 0.651400 | 0.656359 |
|
||||
# | 950 | 0.648100 | 0.653309 |
|
||||
# | 1000 | 0.631500 | 0.648716 |
|
||||
# | 1050 | 0.654200 | 0.643737 |
|
||||
# | 1100 | 0.571100 | 0.648199 |
|
||||
# | 1150 | 0.573500 | 0.648405 |
|
||||
# | 1200 | 0.556000 | 0.644185 |
|
||||
# | 1250 | 0.568100 | 0.642854 |
|
||||
# | 1300 | 0.570200 | 0.640425 |
|
||||
# | 1350 | 0.551100 | 0.636319 |
|
||||
# | 1400 | 0.551400 | 0.634054 |
|
||||
# | 1450 | 0.550100 | 0.631558 |
|
||||
# | 1500 | 0.559800 | 0.630046 |
|
||||
# | 1550 | 0.556600 | 0.626972 |
|
||||
#
|
||||
|
||||
# %% [markdown]
|
||||
# ## 7) Save LoRA adapters
|
||||
#
|
||||
|
||||
# %%
|
||||
ADAPTER_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model.save_pretrained(str(ADAPTER_DIR))
|
||||
tokenizer.save_pretrained(str(ADAPTER_DIR))
|
||||
|
||||
print(f"Saved LoRA adapter to: {ADAPTER_DIR}")
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## 8) Save merged model
|
||||
#
|
||||
|
||||
# %%
|
||||
# this does not work on my system since I don't have enough VRAM.
|
||||
# it should work though provided you have sufficient resources.
|
||||
# my next step would have been to convert the merged model to llama.cpp GGUF format so I can run it in Ollama/OpenWebUI.
|
||||
DO_MERGE = False
|
||||
|
||||
base_model = None
|
||||
if DO_MERGE:
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_MODEL,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
merged = PeftModel.from_pretrained(
|
||||
base_model, str(ADAPTER_DIR), offload_folder="offload/", is_trainable=False
|
||||
).merge_and_unload()
|
||||
merged_dir = OUTPUT_DIR / "merged_model"
|
||||
merged.save_pretrained(str(merged_dir))
|
||||
tokenizer.save_pretrained(str(merged_dir))
|
||||
print(f"Merged full model saved to: {merged_dir}")
|
||||
else:
|
||||
print("Skipping merge (set DO_MERGE=True to enable).")
|
||||
|
||||
# %% [markdown]
|
||||
# ## 9) Quick inference with the trained adapter
|
||||
#
|
||||
|
||||
# %%
|
||||
test_model = None
|
||||
|
||||
print(colored("Loading the base model + trained adapter.", "green"))
|
||||
test_model = AutoModelForCausalLM.from_pretrained(
|
||||
BASE_MODEL,
|
||||
quantization_config=BNB_CONFIG,
|
||||
dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
test_model = PeftModel.from_pretrained(
|
||||
test_model, str(ADAPTER_DIR), offload_folder="offload/", is_trainable=False
|
||||
)
|
||||
test_model.eval()
|
||||
|
||||
|
||||
def generate_answer(prompt, max_new_tokens=256, temperature=0.2, top_p=0.9):
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PREFIX},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
model_inputs = tokenizer.apply_chat_template(
|
||||
messages, return_tensors="pt", add_generation_prompt=True
|
||||
).to(test_model.device)
|
||||
|
||||
gen_kwargs = {"input_ids": model_inputs}
|
||||
|
||||
with torch.no_grad():
|
||||
out = test_model.generate(
|
||||
**gen_kwargs,
|
||||
do_sample=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
)
|
||||
return tokenizer.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
sample_prompt = (
|
||||
train_rows[0]["input"]
|
||||
if len(train_rows) > 0
|
||||
else "What are the visitor crowd levels like?"
|
||||
)
|
||||
|
||||
for i in range(10):
|
||||
print(generate_answer(train_rows[i]["input"])[:800])
|
||||
print("---")
|
||||
|
||||
|
||||
# %%
|
||||
generate_answer("What are the visitor crowd levels like?")
|
||||
|
||||
|
||||
# %%
|
||||
def chat(
|
||||
user, system="You are a precise assistant.", temperature=0.0, max_new_tokens=256
|
||||
):
|
||||
msgs = [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
]
|
||||
model_inputs = tokenizer.apply_chat_template(
|
||||
msgs, return_tensors="pt", add_generation_prompt=True
|
||||
).to(test_model.device)
|
||||
gen_kwargs = {"input_ids": model_inputs}
|
||||
with torch.no_grad():
|
||||
out = test_model.generate(
|
||||
**gen_kwargs,
|
||||
# **tokenizer(user, return_tensors="pt").to(test_model.device),
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=(temperature > 0),
|
||||
temperature=temperature,
|
||||
top_p=1.0,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id
|
||||
)
|
||||
return tokenizer.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
for i in range(10):
|
||||
prompt = train_rows[i]["input"]
|
||||
out = chat(prompt, max_new_tokens=2000, temperature=0.2)
|
||||
|
||||
print("\n\n💬\n" + out)
|
||||
|
||||
# %% [markdown]
|
||||
# ## PoS Gradio setup
|
||||
#
|
||||
|
||||
# %%
|
||||
# === Gradio chat for Mistral-Instruct (no self-replies) ===
|
||||
# Assumes: `test_model` (HF AutoModelForCausalLM + PEFT adapter) and `BASE_MODEL` are defined.
|
||||
|
||||
import torch, threading
|
||||
import gradio as gr
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
TextIteratorStreamer,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
|
||||
# -- Tokenizer (use BASE model tokenizer) --
|
||||
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
|
||||
|
||||
# Ensure pad/eos exist and are consistent
|
||||
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif tokenizer.eos_token is None and tokenizer.pad_token is not None:
|
||||
tokenizer.eos_token = tokenizer.pad_token
|
||||
elif tokenizer.pad_token is None and tokenizer.eos_token is None:
|
||||
tokenizer.add_special_tokens({"eos_token": "</s>"})
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
try:
|
||||
test_model.resize_token_embeddings(len(tokenizer))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
DEVICE = getattr(test_model, "device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
SYSTEM_PROMPT = "You are a helpful assistant."
|
||||
|
||||
|
||||
# --- Custom stop: if the model starts a new user turn ([INST]) stop generation immediately.
|
||||
# This prevents the model from “answering its own replies”.
|
||||
class StopOnInst(StoppingCriteria):
|
||||
def __init__(self, tokenizer, trigger_text="[INST]"):
|
||||
self.trigger_ids = tokenizer.encode(trigger_text, add_special_tokens=False)
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
if not self.trigger_ids:
|
||||
return False
|
||||
seq = input_ids[0].tolist()
|
||||
tlen = len(self.trigger_ids)
|
||||
if len(seq) < tlen:
|
||||
return False
|
||||
return seq[-tlen:] == self.trigger_ids
|
||||
|
||||
|
||||
STOPPING = StoppingCriteriaList([StopOnInst(tokenizer)])
|
||||
|
||||
|
||||
def _build_inputs(pairs):
|
||||
"""
|
||||
pairs: list of (user, assistant) tuples.
|
||||
We include prior completed assistant replies and the latest user with empty assistant,
|
||||
then ask the model to continue as assistant.
|
||||
"""
|
||||
msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
for u, a in pairs:
|
||||
u = (u or "").strip()
|
||||
a = (a or "").strip()
|
||||
if not u and not a:
|
||||
continue
|
||||
if u:
|
||||
msgs.append({"role": "user", "content": u})
|
||||
if a:
|
||||
msgs.append({"role": "assistant", "content": a})
|
||||
|
||||
# Use chat template; many Mistral tokenizers return a single Tensor (input_ids)
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
msgs, add_generation_prompt=True, return_tensors="pt"
|
||||
)
|
||||
if isinstance(input_ids, torch.Tensor):
|
||||
inputs = {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}
|
||||
else:
|
||||
inputs = input_ids
|
||||
return {k: v.to(DEVICE) for k, v in inputs.items()}
|
||||
|
||||
|
||||
def stream_reply(history_pairs, max_new_tokens=512, temperature=0.7, top_p=0.9):
|
||||
inputs = _build_inputs(history_pairs)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
gen_kwargs = dict(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id, # Mistral uses </s> as EOS
|
||||
streamer=streamer,
|
||||
stopping_criteria=STOPPING, # <- key fix
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
t = threading.Thread(target=test_model.generate, kwargs=gen_kwargs)
|
||||
t.start()
|
||||
partial = ""
|
||||
for piece in streamer:
|
||||
partial += piece
|
||||
yield partial
|
||||
t.join()
|
||||
|
||||
|
||||
# --- Gradio handlers ---
|
||||
|
||||
|
||||
def gr_respond(message, chat_history):
|
||||
message = (message or "").strip()
|
||||
chat_history = chat_history or []
|
||||
# Append new user turn with empty assistant; we stream into that slot.
|
||||
chat_history = chat_history + [(message, "")]
|
||||
pairs = [(u or "", a or "") for (u, a) in chat_history]
|
||||
|
||||
for partial in stream_reply(pairs):
|
||||
chat_history[-1] = (message, partial)
|
||||
yield "", chat_history # clears textbox, updates chat
|
||||
|
||||
|
||||
def gr_clear():
|
||||
return None
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("## 💬 Chat with Touristral")
|
||||
chat = gr.Chatbot(height=200, layout="bubble")
|
||||
with gr.Row():
|
||||
msg = gr.Textbox(placeholder="Type a message and press Enter…", scale=9)
|
||||
send = gr.Button("Send", scale=1)
|
||||
with gr.Row():
|
||||
clear = gr.Button("Clear chat")
|
||||
|
||||
msg.submit(gr_respond, [msg, chat], [msg, chat])
|
||||
send.click(gr_respond, [msg, chat], [msg, chat])
|
||||
clear.click(gr_clear, None, chat, queue=False)
|
||||
|
||||
demo.queue().launch(share=False)
|
||||
|
||||
# %% [markdown]
|
||||
# ## 10) Light evaluation on the validation set
|
||||
#
|
||||
|
||||
# %%
|
||||
import evaluate
|
||||
|
||||
if eval_ds:
|
||||
rouge = evaluate.load("rouge")
|
||||
preds, refs = [], []
|
||||
for ex in val_rows[:50]:
|
||||
preds.append(generate_answer(ex["input"], max_new_tokens=192, temperature=0.2))
|
||||
refs.append(ex["output"])
|
||||
results = rouge.compute(predictions=preds, references=refs)
|
||||
print(results)
|
||||
else:
|
||||
print("No eval split available; skipped.")
|
||||
|
||||
|
||||
# %% [markdown]
|
||||
# ## 11) (Optional) Use with other runtimes
|
||||
#
|
||||
# - **Python Inference (PEFT)**: Load base model + adapter as shown in Section 9.
|
||||
# - **Merged model**: Set `DO_MERGE=True` to create a standalone model directory; you can then convert to other runtimes (e.g., llama.cpp GGUF) using their conversion tools.
|
||||
# - **Ollama**: If your runtime supports adapters or merged weights for the chosen base model, create a `Modelfile` pointing to them. Need a concrete path? Tell me your base and target runtime and I’ll add exact steps.
|
||||
#
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user