# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.18.0
# ---
# %% [markdown]
# # RAFT Data Builder for Bali Culture Tourism (DeepSeek)
#
# This notebook builds a **.jsonl chat dataset** from a text corpus in `./corpus` for **PEFT fine-tuning** a 7B model.
# It performs a RAFT-style generation loop with the **DeepSeek API** (using `deepseek-reasoner`) to synthesize high-quality, domain-specific Q&A pairs with detailed, insightful answers for **culture-focused tourism in Bali**.
#
# > Created: **2025-10-20 21:22:57 UTC**
#
# ## What this notebook does
#
# - Parses `.txt` files in `./corpus` that contain numbered reviews (lines starting with `(1)`, `(2)`, ...).
# - Batches snippets and prompts DeepSeek to generate question-answer pairs (3–6 per batch) with **rich cultural context**.
# - Ensures **assistant answers are long-form, practical, and expert-level**, while keeping the final dataset **free of chain-of-thought**.
# - Outputs a **chat-format JSONL** with entries like:
# ```json
# {
#   "messages": [
#     { "role": "system", "content": "..." },
#     { "role": "user", "content": "Question..." },
#     { "role": "assistant", "content": "Answer..." }
#   ]
# }
# ```
# - Targets **5k–10k** samples (configurable) so it's substantial enough for **PEFT on a 7B** model.
#
# %% [markdown]
# ## 1) Setup
#
# - Put your `.txt` files under `./corpus/`.
# - Set your DeepSeek API key in the environment as `DEEPSEEK_API_KEY` (or edit the cell below).
# - (Optional) Adjust batching, limits, and output paths.
#
# %%
# If needed, install dependencies
# You can comment out packages you already have.
# %pip install -q --upgrade openai tenacity tqdm tiktoken nbformat
# %%
import os
from pathlib import Path
# -------- Configuration --------
CORPUS_DIR = Path("./corpus")
OUTPUT_JSONL = Path("./bali_culture_raft_dataset.jsonl")
FAILED_LOG = Path("./raft_failures.log")
# Aim high for a 7B PEFT: 5k–10k samples is a solid start.
TARGET_MIN_SAMPLES = 5000
TARGET_MAX_SAMPLES = 10000
# How many Q&A pairs to request per API call.
GEN_PAIRS_PER_BATCH = (3, 6) # (min, max)
# Number of review snippets to include in one request (to anchor the generations).
SNIPPETS_PER_BATCH = 6
# Model + API
DEEPSEEK_MODEL = "deepseek-reasoner" # reasoning model with CoT (we will discard CoT in dataset)
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")
# Rate & safety
TEMPERATURE = 0.8
MAX_TOKENS = 6000  # must fit several 400–900-word answers per call; 1200 would truncate the JSON
TIMEOUT = 60
# Reproducibility
SEED = 42
# --------------------------------
os.makedirs(CORPUS_DIR, exist_ok=True)
print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")
# %%
import re
from typing import List
def parse_corpus_text(text: str) -> List[str]:
"""
Parse a file that contains lines like:
(1) review...
(2) review...
Returns list of review strings.
"""
# Remove header-like lines e.g. [TOPIC], [Stats], blank runs
lines = []
for line in text.splitlines():
if line.strip().startswith("[TOPIC]") or line.strip().startswith("[Stats]"):
continue
if line.strip() == "":
continue
lines.append(line)
cleaned = "\n".join(lines)
# Extract "(n) ..." reviews; allow multi-line with \n in between until next numbered item
# Strategy: split on occurrences of lines starting with (\d+)
pieces = re.split(r"\n?\(\d+\)\s*", cleaned)
reviews = []
for piece in pieces:
p = piece.strip()
if not p:
continue
reviews.append(p)
return reviews
def load_corpus_snippets(corpus_dir: Path) -> List[str]:
snippets = []
for p in corpus_dir.glob("**/*.txt"):
try:
txt = p.read_text(encoding="utf-8", errors="ignore")
reviews = parse_corpus_text(txt)
snippets.extend(reviews)
except Exception as e:
print(f"Failed to parse {p}: {e}")
return snippets
snippets = load_corpus_snippets(CORPUS_DIR)
print(f"Loaded {len(snippets)} review snippets.")
print("Example:", snippets[0][:200] if snippets else "(no snippets)")
# %%
import json
import random
random.seed(SEED)
SYSTEM_PROMPT = """
You are a meticulous, culture-focused Bali travel expert. Your mission is to craft **deeply insightful, practical, and accurate** answers for travelers seeking **Balinese cultural experiences**—temples, ceremonies, etiquette, dance, crafts, village life, historical and spiritual context, sacred geography, and respectful participation.
- Prioritize **cultural significance**, local customs (dress, offerings, behavior), do/don't etiquette, timing considerations (ceremonies, tides, festivals), and **responsible tourism**.
- Go **beyond surface tips**: weave in specific names, regional differences, logistics (hours, access, fees, crowd patterns), and **context that helps travelers act respectfully**.
- If the provided snippets are thin or conflicting, **draw on your broader knowledge** to fill gaps and **add relevant context**—but keep the advice truthful and practical. Avoid speculation presented as fact; if uncertain, say so briefly.
- Use clear structure: short intro, numbered/bulleted steps or sections, then a brief “Essentials” recap.
- **Do NOT include chain-of-thought or hidden reasoning** in final answers; respond with the polished final guidance only.
"""
GEN_INSTRUCTION = """
From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
- Ask a question a culture-curious traveler would search for (concise).
- Provide a **thorough, actionable, expert** answer (400–900 words when needed).
- Incorporate and reconcile the snippets where helpful, but **freely add authoritative, accurate context** to reach high quality.
- Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
- Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
- Return ONLY valid JSON with this shape:
{{
  "pairs": [ {{"question": "...", "answer": "..."}}, ... ]
}}
"""
def make_user_prompt(batch_snippets, k):
joined = "\n\n---\n\n".join(batch_snippets)
return f"""You are given **Bali travel review snippets** (may be noisy/partial).
Generate {k} culture-focused Q&A pairs in JSON using the spec provided.
Snippets:
{joined}
{GEN_INSTRUCTION.format(k=k)}
"""
# %%
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm
try:
    from openai import OpenAI  # the DeepSeek endpoint is OpenAI-compatible
except ImportError as e:
    raise ImportError("The `openai` package is required: %pip install openai") from e
client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)
def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS, timeout: int = TIMEOUT) -> dict:
"""
Calls DeepSeek's /chat/completions. Returns parsed JSON (dict) from assistant content.
Any 'reasoning_content' produced by deepseek-reasoner is ignored.
"""
resp = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=temperature,
max_tokens=max_tokens,
response_format={"type":"json_object"},
timeout=timeout,
)
content = resp.choices[0].message.content
return json.loads(content)
def as_messages_entry(question: str, answer: str) -> dict:
return {
"messages": [
{"role": "system", "content": SYSTEM_PROMPT.strip()},
{"role": "user", "content": question.strip()},
{"role": "assistant", "content": answer.strip()},
]
}
def chunk(lst, n):
for i in range(0, len(lst), n):
yield lst[i:i+n]
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
user_prompt = make_user_prompt(batch_snips, k)
data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
pairs = data.get("pairs", [])
out = []
for p in pairs:
q = p.get("question", "").strip()
a = p.get("answer", "").strip()
if q and a:
out.append(as_messages_entry(q, a))
return out
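# %% [markdown]
# Offline sketch of what one dataset line looks like (no API call; the Q&A text
# here is invented, real entries come from `generate_pairs_for_batch`):
# %%
_example_entry = as_messages_entry(
    "What should I wear when visiting a Balinese temple?",
    "Wear a sarong and sash, and cover shoulders and knees; many temples lend sarongs at the entrance.",
)
print(json.dumps(_example_entry, ensure_ascii=False)[:300])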
# %%
import random, json
random.seed(SEED)
if OUTPUT_JSONL.exists():
print(f"WARNING: {OUTPUT_JSONL} already exists. New generations will be appended.")
total_written = 0
failed_batches = 0
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
    random.shuffle(snippets)
    for batch in tqdm(list(chunk(snippets, SNIPPETS_PER_BATCH)), desc="Generating batches"):
remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
if remaining_max <= 0:
break
        k_low, k_high = GEN_PAIRS_PER_BATCH
        # Request the max per call while far from TARGET_MIN_SAMPLES; taper toward
        # the min as the remaining minimum shrinks (0 remaining -> keep requesting max
        # until the max target is hit).
        k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))
try:
entries = generate_pairs_for_batch(batch, k=k)
for e in entries:
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
total_written += len(entries)
except Exception as e:
failed_batches += 1
flog.write(f"BATCH_FAIL\t{repr(e)}\n")
continue
print(f"Done. Wrote {total_written} samples. Failed batches: {failed_batches}.")
# %%
# Quick sanity check: read a few lines back
import json
n_show = 3
if OUTPUT_JSONL.exists():
with open(OUTPUT_JSONL, "r", encoding="utf-8") as f:
for i, line in zip(range(n_show), f):
try:
obj = json.loads(line)
msgs = obj.get("messages", [])
print(f"Sample {i+1}:")
                for m in msgs:
                    # Avoid a backslash inside an f-string expression (SyntaxError before Python 3.12).
                    preview = m["content"][:120].replace("\n", " ")
                    suffix = "..." if len(m["content"]) > 120 else ""
                    print(f"[{m['role']}] {preview}{suffix}")
print("-"*80)
except Exception as e:
print("Failed to parse a line:", e)
else:
print("No dataset file yet. Run the generation cells above.")
# %%
# Optional utility: shard the JSONL for training convenience
from pathlib import Path
def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
shard_idx = 0
count = 0
out = None
with open(input_path, "r", encoding="utf-8") as fin:
for line in fin:
if out is None or count >= lines_per_shard:
if out:
out.close()
shard_idx += 1
count = 0
shard_path = input_path.with_name(input_path.stem + f".part{shard_idx:03d}.jsonl")
out = open(shard_path, "w", encoding="utf-8")
print("Opened", shard_path)
out.write(line)
count += 1
if out:
out.close()
print("Sharding complete.")
# Example:
# shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)
# %% [markdown]
# ## 2) Training hint (PEFT)
#
# The dataset is compatible with the standard **messages** chat format used in many PEFT training scripts
# (e.g., Hugging Face transformers + PEFT/LoRA). Your trainer should read each JSONL line and feed the
# `messages` list to your chat template (or convert to a plain instruction format). For 7B models,
# start with ~5k–10k high-quality samples and tune learning rate, epochs, and LoRA rank to your budget.
#
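# %% [markdown]
# For illustration, one way a trainer might consume a line: a minimal sketch
# assuming Hugging Face `transformers` and a tokenizer that ships a chat
# template (the model name below is only an example, not a recommendation).
# %%
# Uncomment to try (downloads a tokenizer):
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
# with open(OUTPUT_JSONL, "r", encoding="utf-8") as f:
#     sample = json.loads(next(f))
# text = tok.apply_chat_template(sample["messages"], tokenize=False)
# print(text[:400])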