mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2025-12-06 18:20:53 +01:00
324 lines
12 KiB
Python
324 lines
12 KiB
Python
# ---
|
||
# jupyter:
|
||
# jupytext:
|
||
# text_representation:
|
||
# extension: .py
|
||
# format_name: percent
|
||
# format_version: '1.3'
|
||
# jupytext_version: 1.18.0
|
||
# ---
|
||
|
||
# %% [markdown]
|
||
# # RAFT Data Builder for Bali Culture Tourism (DeepSeek)
|
||
#
|
||
# This notebook builds a **.jsonl chat dataset** from a text corpus in `./corpus` for **PEFT fine-tuning** a 7B model.
|
||
# It performs a RAFT-style generation loop with the **DeepSeek API** (using `deepseek-reasoner`) to synthesize high-quality, domain-specific Q&A pairs with detailed, insightful answers for **culture-focused tourism in Bali**.
|
||
#
|
||
# > Created: **2025-10-20 21:22:57 UTC**
|
||
#
|
||
# ## What this notebook does
|
||
#
|
||
# - Parses `.txt` files in `./corpus` that contain numbered reviews (lines starting with `(1)`, `(2)`, ...).
|
||
# - Batches snippets and prompts DeepSeek to generate question-answer pairs (3–6 per batch) with **rich cultural context**.
|
||
# - Asks for **long-form, practical, expert-level assistant answers** and keeps the final dataset **free of chain-of-thought** (only the model's final message content is stored, never its separate reasoning output).
|
||
# - Outputs a **chat-format JSONL** with entries like:
|
||
# ```json
|
||
# {
|
||
# "messages": [
|
||
# { "role": "system", "content": "..." },
|
||
# { "role": "user", "content": "Question..." },
|
||
# { "role": "assistant", "content": "Answer..." }
|
||
# ]
|
||
# }
|
||
# ```
|
||
# - Targets **5k–10k** samples (configurable) so it’s substantial enough for **PEFT on a 7B** model.
|
||
#
|
||
|
||
# %% [markdown]
|
||
# ## 1) Setup
|
||
#
|
||
# - Put your `.txt` files under `./corpus/`.
|
||
# - Set your DeepSeek API key in the environment as `DEEPSEEK_API_KEY` (or edit the cell below).
|
||
# - (Optional) Adjust batching, limits, and output paths.
|
||
#
|
||
|
||
# %%
|
||
# If needed, install dependencies
|
||
# You can comment out packages you already have.
|
||
# %pip install -q --upgrade openai tenacity tqdm tiktoken nbformat
|
||
|
||
# %%
|
||
import os
|
||
from pathlib import Path
|
||
|
||
# -------- Configuration --------
CORPUS_DIR = Path("./corpus")
OUTPUT_JSONL = Path("./bali_culture_raft_dataset.jsonl")
FAILED_LOG = Path("./raft_failures.log")

# Aim high for a 7B PEFT: 5k–10k samples is a solid start.
TARGET_MIN_SAMPLES = 5000
TARGET_MAX_SAMPLES = 10000

# How many Q&A pairs to request per API call.
GEN_PAIRS_PER_BATCH = (3, 6)  # (min, max)

# Number of review snippets to include in one request (to anchor the generations).
SNIPPETS_PER_BATCH = 6

# Model + API
DEEPSEEK_MODEL = "deepseek-reasoner"  # reasoning model with CoT (we will discard CoT in dataset)
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")

# Rate & safety
# NOTE(review): DeepSeek documents that deepseek-reasoner ignores sampling
# parameters such as temperature; confirm for the model version you use.
TEMPERATURE = 0.8
# A single response must hold up to GEN_PAIRS_PER_BATCH[1] answers of
# 400–900 words each (as requested by the generation prompt) wrapped in JSON,
# i.e. several thousand tokens. The previous value (1200) would truncate the
# JSON mid-answer so it could never be parsed. Depending on the model version,
# max_tokens may also have to cover the reasoning output — raise further if
# responses still finish with reason "length".
MAX_TOKENS = 8000
# Generating thousands of tokens (plus reasoning) can take minutes; 60 s is
# likely to time out before a full batch arrives.
TIMEOUT = 300

# Reproducibility
SEED = 42

# --------------------------------
os.makedirs(CORPUS_DIR, exist_ok=True)
print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")
|
||
|
||
# %%
|
||
import re
|
||
from typing import List, Dict
|
||
|
||
def parse_corpus_text(text: str) -> List[str]:
    """Split the contents of one corpus file into individual review strings.

    The expected layout is one numbered review per line::

        [TOPIC] ...
        [Stats] ...
        (1) review...
        (2) review...

    Lines starting with ``[TOPIC]`` or ``[Stats]`` and blank lines are dropped.
    A review runs from a ``(n)`` marker at the start of a line up to the next
    such marker, so continuation lines stay attached (joined by newlines).
    A ``(n)`` that appears mid-line (e.g. "we bought (2) tickets") is part of
    the review text and does not start a new review. Any text before the first
    marker is returned as its own entry.

    Args:
        text: Raw file contents.

    Returns:
        Stripped, non-empty review strings in file order.
    """
    # Remove header-like lines (e.g. [TOPIC], [Stats]) and blank lines.
    kept = [
        line
        for line in text.splitlines()
        if line.strip() and not line.strip().startswith(("[TOPIC]", "[Stats]"))
    ]
    cleaned = "\n".join(kept)

    # (?m) makes ^ match at every line start, so only markers that begin a
    # line split reviews. [ \t]* (not \s*) keeps the match on a single line.
    pieces = re.split(r"(?m)^[ \t]*\(\d+\)\s*", cleaned)
    return [piece.strip() for piece in pieces if piece.strip()]
|
||
|
||
def load_corpus_snippets(corpus_dir: Path) -> List[str]:
    """Collect review snippets from every ``*.txt`` file under *corpus_dir*.

    The search is recursive (``**/*.txt``). Files are visited in sorted path
    order: ``Path.glob`` yields entries in filesystem-dependent order, and the
    resulting list is later shuffled under a fixed seed, so an unsorted order
    would make the generated dataset differ between machines.

    Each file is decoded as UTF-8 with undecodable bytes dropped and parsed
    with ``parse_corpus_text``. A file that fails is reported via ``print``
    and skipped.

    Args:
        corpus_dir: Root directory of the corpus.

    Returns:
        All snippets, in file order, then in-file order.
    """
    snippets: List[str] = []
    for path in sorted(corpus_dir.glob("**/*.txt")):
        try:
            txt = path.read_text(encoding="utf-8", errors="ignore")
            snippets.extend(parse_corpus_text(txt))
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole load.
            print(f"Failed to parse {path}: {e}")
    return snippets
|
||
|
||
# Load every snippet once; later cells shuffle and batch this list.
snippets = load_corpus_snippets(CORPUS_DIR)
print(f"Loaded {len(snippets)} review snippets.")

# Show the first 200 characters of one snippet as a quick visual check.
example_preview = snippets[0][:200] if snippets else "(no snippets)"
print("Example:", example_preview)
|
||
|
||
# %%
|
||
import json
|
||
import random
|
||
|
||
# Seed Python's `random` module with SEED. Note: the generation cell calls
# random.seed(SEED) again right before shuffling the snippets, so this call
# does not by itself determine the shuffle order.
random.seed(SEED)

# System message used in two places: it is sent with every generation request
# (generate_pairs_for_batch -> ask_deepseek), and it is stored, stripped, as
# the "system" turn of every training record (as_messages_entry).
SYSTEM_PROMPT = """
You are a meticulous, culture-focused Bali travel expert. Your mission is to craft **deeply insightful, practical, and accurate** answers for travelers seeking **Balinese cultural experiences**—temples, ceremonies, etiquette, dance, crafts, village life, historical and spiritual context, sacred geography, and respectful participation.
- Prioritize **cultural significance**, local customs (dress, offerings, behavior), do/don’t etiquette, timing considerations (ceremonies, tides, festivals), and **responsible tourism**.
- Go **beyond surface tips**: weave in specific names, regional differences, logistics (hours, access, fees, crowd patterns), and **context that helps travelers act respectfully**.
- If the provided snippets are thin or conflicting, **draw on your broader knowledge** to fill gaps and **add relevant context**—but keep the advice truthful and practical. Avoid speculation presented as fact; if uncertain, say so briefly.
- Use clear structure: short intro, numbered/bulleted steps or sections, then a brief “Essentials” recap.
- **Do NOT include chain-of-thought or hidden reasoning** in final answers; respond with the polished final guidance only.
"""
|
||
|
||
GEN_INSTRUCTION = """
|
||
From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
|
||
- Ask a question a culture-curious traveler would search for (concise).
|
||
- Provide a **thorough, actionable, expert** answer (400–900 words when needed).
|
||
- Incorporate and reconcile the snippets where helpful, but **freely add authoritative, accurate context** to reach high quality.
|
||
- Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
|
||
- Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
|
||
- Return ONLY valid JSON with this shape:
|
||
{
|
||
"pairs": [ {"question": "...","answer": "..."} , ... ]
|
||
}
|
||
"""
|
||
|
||
def make_user_prompt(batch_snippets, k):
|
||
joined = "\n\n---\n\n".join(batch_snippets)
|
||
return f"""You are given **Bali travel review snippets** (may be noisy/partial).
|
||
Generate {k} culture-focused Q&A pairs in JSON using the spec provided.
|
||
|
||
Snippets:
|
||
{joined}
|
||
|
||
{GEN_INSTRUCTION.format(k=k)}
|
||
"""
|
||
|
||
|
||
# %%
|
||
import time
|
||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||
from typing import Tuple
|
||
from tqdm import tqdm
|
||
import math
|
||
|
||
try:
    from openai import OpenAI
except ImportError as err:
    # The DeepSeek calls below go through the `openai` SDK; there is no
    # alternative client, so fail with an actionable message.
    raise ImportError(
        "The `openai` package is required for this notebook: run `%pip install openai`."
    ) from err

# The OpenAI SDK client is pointed at DeepSeek's endpoint via base_url.
client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)
|
||
|
||
def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
                 temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS,
                 timeout: int = TIMEOUT) -> dict:
    """Call DeepSeek's /chat/completions and return the reply parsed as a JSON object.

    Only ``message.content`` is read; any 'reasoning_content' produced by
    deepseek-reasoner is ignored, so chain-of-thought never reaches the caller.

    Args:
        system_prompt: System message text.
        user_prompt: User message text.
        model: Model name.
        temperature: Sampling temperature (forwarded as-is).
        max_tokens: Output token limit for the request.
        timeout: Per-request timeout in seconds.

    Returns:
        The parsed JSON object.

    Raises:
        ValueError: The reply was truncated at ``max_tokens``, was empty, or
            was not a JSON object. ``json.JSONDecodeError``, a ValueError
            subclass, is raised for malformed JSON.
    """
    resp = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        response_format={"type": "json_object"},
        timeout=timeout,
    )
    choice = resp.choices[0]
    # A truncated reply is never valid JSON; say why instead of surfacing a
    # confusing decode error.
    if choice.finish_reason == "length":
        raise ValueError(f"Response truncated at max_tokens={max_tokens}; increase MAX_TOKENS.")

    content = choice.message.content
    if not content or not content.strip():
        raise ValueError("Empty assistant content returned by the API.")

    text = content.strip()
    # Tolerate replies wrapped in a Markdown code fence (```json ... ```).
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text)

    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError(f"Expected a JSON object, got {type(data).__name__}.")
    return data
|
||
|
||
def as_messages_entry(question: str, answer: str) -> dict:
    """Wrap one Q&A pair as a chat-format training record.

    The record is ``{"messages": [...]}`` with three turns, in order:
    system (SYSTEM_PROMPT), user (question), assistant (answer). Every
    content string has surrounding whitespace stripped.
    """
    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", question),
        ("assistant", answer),
    )
    return {"messages": [{"role": role, "content": text.strip()} for role, text in turns]}
|
||
|
||
def chunk(lst, n):
    """Yield consecutive slices of *lst* holding *n* items each.

    The final slice may be shorter. ``n == 0`` raises ValueError (from
    ``range``) when iteration starts.
    """
    yield from (lst[pos:pos + n] for pos in range(0, len(lst), n))
|
||
|
||
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
    """Ask the model for *k* Q&A pairs grounded in *batch_snips*; return chat records.

    The whole call is retried up to 5 attempts with exponential backoff
    (2–30 s), and the last exception is re-raised. A reply whose ``"pairs"``
    is not a list raises ValueError, so it is retried like any other failure.
    Individual malformed items are skipped rather than failing the batch:
    non-objects, missing/non-string question or answer, or a question or
    answer that is blank after stripping. The model may return more or fewer
    than *k* pairs; all valid ones are kept.

    Args:
        batch_snips: Review snippets used to anchor the generation.
        k: Number of pairs to request.

    Returns:
        Records in the format produced by ``as_messages_entry``.
    """
    user_prompt = make_user_prompt(batch_snips, k)
    data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
    pairs = data.get("pairs", [])
    if not isinstance(pairs, list):
        raise ValueError(f"'pairs' must be a list, got {type(pairs).__name__}")

    out = []
    for p in pairs:
        if not isinstance(p, dict):
            continue
        q = p.get("question")
        a = p.get("answer")
        # A null or non-string field used to raise AttributeError on .strip(),
        # discarding (and re-paying for) the whole batch; skip just this item.
        if not (isinstance(q, str) and isinstance(a, str)):
            continue
        q, a = q.strip(), a.strip()
        if q and a:
            out.append(as_messages_entry(q, a))
    return out
|
||
|
||
|
||
# %%
|
||
import random, json
|
||
|
||
# Seed immediately before shuffling so the batch order depends only on SEED
# and the (sorted) snippet list.
random.seed(SEED)

if OUTPUT_JSONL.exists():
    print(f"WARNING: {OUTPUT_JSONL} already exists. New generations will be appended.")

# Counters cover this run only; samples already in an existing file are not counted.
total_written = 0
failed_batches = 0

with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
    random.shuffle(snippets)
    for batch in chunk(snippets, SNIPPETS_PER_BATCH):
        remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
        remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
        if remaining_max <= 0:
            break

        k_low, k_high = GEN_PAIRS_PER_BATCH
        k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))
        # Never request more pairs than the cap still allows; without this the
        # last batch could push the total past TARGET_MAX_SAMPLES.
        k = min(k, remaining_max)

        try:
            entries = generate_pairs_for_batch(batch, k=k)
        except Exception as e:
            # The batch already went through its retries; log it and move on.
            failed_batches += 1
            flog.write(f"BATCH_FAIL\t{repr(e)}\n")
            flog.flush()
            continue

        # The model may return more pairs than requested; enforce the cap here too.
        entries = entries[:remaining_max]
        fout.writelines(json.dumps(entry, ensure_ascii=False) + "\n" for entry in entries)
        # Flush per batch so a crash or kernel restart keeps the paid-for samples.
        fout.flush()
        total_written += len(entries)

print(f"Done. Wrote {total_written} samples. Failed batches: {failed_batches}.")
|
||
|
||
# %%
|
||
# Quick sanity check: read a few lines back
|
||
import json, itertools
|
||
|
||
# Print the first n_show records, truncating each message to 120 characters.
n_show = 3
if OUTPUT_JSONL.exists():
    with open(OUTPUT_JSONL, "r", encoding="utf-8") as f:
        for i, line in enumerate(itertools.islice(f, n_show)):
            try:
                obj = json.loads(line)
                msgs = obj.get("messages", [])
                print(f"Sample {i+1}:")
                for m in msgs:
                    content = m["content"]
                    # Build the preview outside the f-string: a backslash
                    # ('\n') inside an f-string expression is a SyntaxError
                    # before Python 3.12.
                    preview = content[:120].replace("\n", " ")
                    suffix = "..." if len(content) > 120 else ""
                    print(f"[{m['role']}] {preview}{suffix}")
                print("-" * 80)
            except Exception as e:
                # Preview only: report a bad line and keep going.
                print("Failed to parse a line:", e)
else:
    print("No dataset file yet. Run the generation cells above.")
|
||
|
||
# %%
|
||
# Optional utility: shard the JSONL for training convenience
|
||
from pathlib import Path
|
||
|
||
def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
    """Split a JSONL file into numbered shards of at most *lines_per_shard* lines.

    Shards are written next to *input_path* as ``<stem>.part001.jsonl``,
    ``<stem>.part002.jsonl``, ... Existing files with those names are
    overwritten. An empty input produces no shards.

    Args:
        input_path: JSONL file to split (``str`` is accepted too).
        lines_per_shard: Maximum number of lines per shard; must be positive.

    Returns:
        The shard paths written, in order.

    Raises:
        ValueError: If ``lines_per_shard`` is not positive.
    """
    if lines_per_shard <= 0:
        raise ValueError(f"lines_per_shard must be positive, got {lines_per_shard}")

    input_path = Path(input_path)
    shard_paths = []
    count = 0
    out = None
    try:
        with open(input_path, "r", encoding="utf-8") as fin:
            for line in fin:
                if out is None or count >= lines_per_shard:
                    if out is not None:
                        out.close()
                    shard_path = input_path.with_name(
                        input_path.stem + f".part{len(shard_paths) + 1:03d}.jsonl"
                    )
                    out = open(shard_path, "w", encoding="utf-8")
                    shard_paths.append(shard_path)
                    count = 0
                    print("Opened", shard_path)
                out.write(line)
                count += 1
    finally:
        # Close the current shard even if reading or writing raised.
        if out is not None:
            out.close()
    print("Sharding complete.")
    return shard_paths
|
||
|
||
# Example:
|
||
# shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)
|
||
|
||
# %% [markdown]
|
||
# ## 2) Training hint (PEFT)
|
||
#
|
||
# The dataset is compatible with the standard **messages** chat format used in many PEFT training scripts
|
||
# (e.g., Hugging Face transformers + PEFT/LoRA). Your trainer should read each JSONL line and feed the
|
||
# `messages` list to your chat template (or convert to a plain instruction format). For 7B models,
|
||
# start with ~5k–10k high-quality samples and tune learning rate, epochs, and LoRA rank to your budget.
|
||
#
|