mirror of https://github.com/marvinscham/masterthesis-playground.git (synced 2025-12-06 02:00:50 +01:00)
Add unedited test RAFT build with deepseek
323
raft/nb_build_raft_bali_culture_dataset.py
Normal file
@@ -0,0 +1,323 @@
# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.18.0
# ---

# %% [markdown]
# # RAFT Data Builder for Bali Culture Tourism (DeepSeek)
#
# This notebook builds a **.jsonl chat dataset** from a text corpus in `./corpus` for **PEFT fine-tuning** a 7B model.
# It runs a RAFT-style generation loop against the **DeepSeek API** (using `deepseek-reasoner`) to synthesize high-quality, domain-specific Q&A pairs with detailed, insightful answers for **culture-focused tourism in Bali**.
#
# > Created: **2025-10-20 21:22:57 UTC**
#
# ## What this notebook does
#
# - Parses `.txt` files in `./corpus` that contain numbered review lines (`(1)`, `(2)`, ...).
# - Batches snippets and prompts DeepSeek to generate question-answer pairs (3–6 per batch) with **rich cultural context**.
# - Ensures **assistant answers are long-form, practical, and expert-level**, while keeping the final dataset **free of chain-of-thought**.
# - Outputs a **chat-format JSONL** with entries like:
# ```json
# {
#   "messages": [
#     { "role": "system", "content": "..." },
#     { "role": "user", "content": "Question..." },
#     { "role": "assistant", "content": "Answer..." }
#   ]
# }
# ```
# - Targets **5k–10k** samples (configurable), enough to be substantial for **PEFT on a 7B** model.
#

# %% [markdown]
# ## 1) Setup
#
# - Put your `.txt` files under `./corpus/`.
# - Set your DeepSeek API key in the environment as `DEEPSEEK_API_KEY` (or edit the cell below).
# - (Optional) Adjust batching, limits, and output paths.
#

# %%
# If needed, install dependencies.
# You can comment out packages you already have.
# %pip install -q --upgrade openai tenacity tqdm tiktoken nbformat

# %%
import os
from pathlib import Path

# -------- Configuration --------
CORPUS_DIR = Path("./corpus")
OUTPUT_JSONL = Path("./bali_culture_raft_dataset.jsonl")
FAILED_LOG = Path("./raft_failures.log")

# Aim high for a 7B PEFT: 5k–10k samples is a solid start.
TARGET_MIN_SAMPLES = 5000
TARGET_MAX_SAMPLES = 10000

# How many Q&A pairs to request per API call.
GEN_PAIRS_PER_BATCH = (3, 6)  # (min, max)

# Number of review snippets to include in one request (to anchor the generations).
SNIPPETS_PER_BATCH = 6

# Model + API
DEEPSEEK_MODEL = "deepseek-reasoner"  # reasoning model with CoT (the CoT is discarded in the dataset)
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")

# Rate & safety
TEMPERATURE = 0.8
MAX_TOKENS = 1200  # allow long, detailed answers
TIMEOUT = 60

# Reproducibility
SEED = 42

# --------------------------------
os.makedirs(CORPUS_DIR, exist_ok=True)
print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")

# %%
import re
from typing import List

def parse_corpus_text(text: str) -> List[str]:
    """
    Parse a file that contains lines like:
        (1) review...
        (2) review...
    Returns list of review strings.
    """
    # Remove header-like lines, e.g. [TOPIC], [Stats], and blank runs.
    lines = []
    for line in text.splitlines():
        stripped = line.strip()
        if stripped.startswith(("[TOPIC]", "[Stats]")):
            continue
        if not stripped:
            continue
        lines.append(line)
    cleaned = "\n".join(lines)

    # Extract "(n) ..." reviews; allow multi-line reviews until the next numbered item.
    # Strategy: split on occurrences of "(digits)" at line starts.
    pieces = re.split(r"\n?\(\d+\)\s*", cleaned)
    reviews = []
    for piece in pieces:
        p = piece.strip()
        if not p:
            continue
        reviews.append(p)
    return reviews

def load_corpus_snippets(corpus_dir: Path) -> List[str]:
    snippets = []
    for p in corpus_dir.glob("**/*.txt"):
        try:
            txt = p.read_text(encoding="utf-8", errors="ignore")
            reviews = parse_corpus_text(txt)
            snippets.extend(reviews)
        except Exception as e:
            print(f"Failed to parse {p}: {e}")
    return snippets

snippets = load_corpus_snippets(CORPUS_DIR)
print(f"Loaded {len(snippets)} review snippets.")
print("Example:", snippets[0][:200] if snippets else "(no snippets)")

# %%
import json
import random

random.seed(SEED)

SYSTEM_PROMPT = """
You are a meticulous, culture-focused Bali travel expert. Your mission is to craft **deeply insightful, practical, and accurate** answers for travelers seeking **Balinese cultural experiences**—temples, ceremonies, etiquette, dance, crafts, village life, historical and spiritual context, sacred geography, and respectful participation.
- Prioritize **cultural significance**, local customs (dress, offerings, behavior), do/don’t etiquette, timing considerations (ceremonies, tides, festivals), and **responsible tourism**.
- Go **beyond surface tips**: weave in specific names, regional differences, logistics (hours, access, fees, crowd patterns), and **context that helps travelers act respectfully**.
- If the provided snippets are thin or conflicting, **draw on your broader knowledge** to fill gaps and **add relevant context**—but keep the advice truthful and practical. Avoid speculation presented as fact; if uncertain, say so briefly.
- Use clear structure: short intro, numbered/bulleted steps or sections, then a brief “Essentials” recap.
- **Do NOT include chain-of-thought or hidden reasoning** in final answers; respond with the polished final guidance only.
"""

# NOTE: the literal braces in the JSON shape below are doubled ({{ }}) so that
# str.format(k=...) does not try to interpret them as replacement fields.
GEN_INSTRUCTION = """
From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
- Ask a question a culture-curious traveler would search for (concise).
- Provide a **thorough, actionable, expert** answer (400–900 words when needed).
- Incorporate and reconcile the snippets where helpful, but **freely add authoritative, accurate context** to reach high quality.
- Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
- Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
- Return ONLY valid JSON with this shape:
{{
  "pairs": [ {{"question": "...", "answer": "..."}}, ... ]
}}
"""

def make_user_prompt(batch_snippets, k):
    joined = "\n\n---\n\n".join(batch_snippets)
    return f"""You are given **Bali travel review snippets** (may be noisy/partial).
Generate {k} culture-focused Q&A pairs in JSON using the spec provided.

Snippets:
{joined}

{GEN_INSTRUCTION.format(k=k)}
"""

# %%
from tenacity import retry, stop_after_attempt, wait_exponential
from tqdm import tqdm

try:
    from openai import OpenAI
except ImportError as e:
    raise ImportError("The `openai` package is required: %pip install openai") from e

client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)

def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
                 temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS, timeout: int = TIMEOUT) -> dict:
    """
    Calls DeepSeek's /chat/completions. Returns parsed JSON (dict) from assistant content.
    Any 'reasoning_content' produced by deepseek-reasoner is ignored.
    """
    resp = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        response_format={"type": "json_object"},
        timeout=timeout,
    )
    content = resp.choices[0].message.content
    return json.loads(content)

def as_messages_entry(question: str, answer: str) -> dict:
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT.strip()},
            {"role": "user", "content": question.strip()},
            {"role": "assistant", "content": answer.strip()},
        ]
    }

def chunk(lst, n):
    """Yield successive n-sized slices of lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i+n]

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
    user_prompt = make_user_prompt(batch_snips, k)
    data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
    pairs = data.get("pairs", [])
    out = []
    for p in pairs:
        q = p.get("question", "").strip()
        a = p.get("answer", "").strip()
        if q and a:
            out.append(as_messages_entry(q, a))
    return out
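
# %%
# Quick offline check of the record shape (no API call): build one entry with
# placeholder text and confirm it round-trips through JSON.
_demo = as_messages_entry("What should I wear to a Balinese temple?",
                          "Wear a sarong and sash; cover shoulders and knees.")
assert json.loads(json.dumps(_demo, ensure_ascii=False))["messages"][1]["role"] == "user"
print(json.dumps(_demo, ensure_ascii=False)[:200])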

# %%
random.seed(SEED)  # re-seed so the shuffle below is reproducible on re-runs

if OUTPUT_JSONL.exists():
    print(f"WARNING: {OUTPUT_JSONL} already exists. New generations will be appended.")

total_written = 0
failed_batches = 0

with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
    random.shuffle(snippets)
    # Iterate over snippet batches with a progress bar.
    for batch in tqdm(list(chunk(snippets, SNIPPETS_PER_BATCH))):
        remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
        remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
        if remaining_max <= 0:
            break
        # Request the max batch size until the minimum target is nearly met,
        # then taper off (but never below the configured per-batch minimum).
        k_low, k_high = GEN_PAIRS_PER_BATCH
        k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))

        try:
            entries = generate_pairs_for_batch(batch, k=k)
            for e in entries:
                fout.write(json.dumps(e, ensure_ascii=False) + "\n")
            total_written += len(entries)
        except Exception as e:
            failed_batches += 1
            flog.write(f"BATCH_FAIL\t{repr(e)}\n")
            continue

print(f"Done. Wrote {total_written} samples. Failed batches: {failed_batches}.")

# %%
# Quick sanity check: read a few lines back.
import json

n_show = 3
if OUTPUT_JSONL.exists():
    with open(OUTPUT_JSONL, "r", encoding="utf-8") as f:
        for i, line in zip(range(n_show), f):
            try:
                obj = json.loads(line)
                msgs = obj.get("messages", [])
                print(f"Sample {i+1}:")
                for m in msgs:
                    # Flatten newlines before slicing; backslashes are not allowed
                    # inside f-string expressions on Python < 3.12.
                    preview = m["content"][:120].replace("\n", " ")
                    print(f"[{m['role']}] {preview}{'...' if len(m['content']) > 120 else ''}")
                print("-" * 80)
            except Exception as e:
                print("Failed to parse a line:", e)
else:
    print("No dataset file yet. Run the generation cells above.")

# %%
# Optional utility: shard the JSONL for training convenience.

def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
    shard_idx = 0
    count = 0
    out = None
    with open(input_path, "r", encoding="utf-8") as fin:
        for line in fin:
            # Open a new shard file when none is open or the current one is full.
            if out is None or count >= lines_per_shard:
                if out:
                    out.close()
                shard_idx += 1
                count = 0
                shard_path = input_path.with_name(input_path.stem + f".part{shard_idx:03d}.jsonl")
                out = open(shard_path, "w", encoding="utf-8")
                print("Opened", shard_path)
            out.write(line)
            count += 1
    if out:
        out.close()
    print("Sharding complete.")

# Example:
# shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)
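
# %%
# Optional utility (a simple sketch, not part of the original pipeline): drop
# entries whose user question is an exact duplicate, which synthetic generation
# tends to produce. Writes a deduplicated copy alongside the original.
def dedup_jsonl(input_path: Path) -> Path:
    seen = set()
    kept = 0
    out_path = input_path.with_name(input_path.stem + ".dedup.jsonl")
    with open(input_path, "r", encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            obj = json.loads(line)
            # messages[1] is the user turn (order: system, user, assistant).
            q = obj["messages"][1]["content"].strip().lower()
            if q in seen:
                continue
            seen.add(q)
            fout.write(line)
            kept += 1
    print(f"Kept {kept} unique entries -> {out_path.name}")
    return out_path

# Example:
# dedup_jsonl(OUTPUT_JSONL)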

# %% [markdown]
# ## 2) Training hint (PEFT)
#
# The dataset is compatible with the standard **messages** chat format used by many PEFT training scripts
# (e.g., Hugging Face transformers + PEFT/LoRA). Your trainer should read each JSONL line and feed the
# `messages` list to your chat template (or convert it to a plain instruction format); a minimal sketch
# follows below. For 7B models, start with ~5k–10k high-quality samples and tune the learning rate,
# epochs, and LoRA rank to your budget.
#
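
# %%
# A minimal sketch of the step described above: render one JSONL record with a
# model's chat template. The model name below is an assumed example, not a
# recommendation; requires the `transformers` package.
def preview_chat_template(jsonl_path: Path, model_name: str = "Qwen/Qwen2.5-7B-Instruct"):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    with open(jsonl_path, "r", encoding="utf-8") as f:
        example = json.loads(next(f))
    # Render the messages into the model's prompt format (no tokenization).
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    print(text[:500])

# Example (downloads the tokenizer on first use):
# preview_chat_template(OUTPUT_JSONL)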