Add unedited test RAFT build with deepseek

This commit is contained in:
2025-10-20 23:43:50 +02:00
parent 8cad184cb5
commit 71886c9091

View File

@@ -0,0 +1,323 @@
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.18.0
# ---
# %% [markdown]
# # RAFT Data Builder for Bali Culture Tourism (DeepSeek)
#
# This notebook builds a **.jsonl chat dataset** from a text corpus in `./corpus` for **PEFT fine-tuning** a 7B model.
# It performs a RAFT-style generation loop with the **DeepSeek API** (using `deepseek-reasoner`) to synthesize high-quality, domain-specific Q&A pairs with detailed, insightful answers for **culture-focused tourism in Bali**.
#
# > Created: **2025-10-20 21:22:57 UTC**
#
# ## What this notebook does
#
# - Parses `.txt` files in `./corpus` that look like the sample you provided (lines with `(1)`, `(2)`, ... reviews).
# - Batches snippets and prompts DeepSeek to generate question-answer pairs (36 per batch) with **rich cultural context**.
# - Ensures **assistant answers are long-form, practical, and expert-level**, while keeping the final dataset **free of chain-of-thought**.
# - Outputs a **chat-format JSONL** with entries like:
# ```json
# {
# "messages": [
# { "role": "system", "content": "..." },
# { "role": "user", "content": "Question..." },
# { "role": "assistant", "content": "Answer..." }
# ]
# }
# ```
# - Targets **5k10k** samples (configurable) so its substantial enough for **PEFT on a 7B** model.
#
# %% [markdown]
# ## 1) Setup
#
# - Put your `.txt` files under `./corpus/`.
# - Set your DeepSeek API key in the environment as `DEEPSEEK_API_KEY` (or edit the cell below).
# - (Optional) Adjust batching, limits, and output paths.
#
# %%
# If needed, install dependencies
# You can comment out packages you already have.
# %pip install -q --upgrade openai tenacity tqdm tiktoken nbformat
# %%
import os
from pathlib import Path
# -------- Configuration --------
CORPUS_DIR = Path("./corpus")
OUTPUT_JSONL = Path("./bali_culture_raft_dataset.jsonl")
FAILED_LOG = Path("./raft_failures.log")
# Aim high for a 7B PEFT: 5k10k samples is a solid start.
TARGET_MIN_SAMPLES = 5000
TARGET_MAX_SAMPLES = 10000
# How many Q&A pairs to request per API call.
GEN_PAIRS_PER_BATCH = (3, 6) # (min, max)
# Number of review snippets to include in one request (to anchor the generations).
SNIPPETS_PER_BATCH = 6
# Model + API
DEEPSEEK_MODEL = "deepseek-reasoner" # reasoning model with CoT (we will discard CoT in dataset)
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")
# Rate & safety
TEMPERATURE = 0.8
MAX_TOKENS = 1200 # allow long, detailed answers
TIMEOUT = 60
# Reproducibility
SEED = 42
# --------------------------------
os.makedirs(CORPUS_DIR, exist_ok=True)
print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")
# %%
import re
from typing import List, Dict
def parse_corpus_text(text: str) -> List[str]:
"""
Parse a file that contains lines like:
(1) review...
(2) review...
Returns list of review strings.
"""
# Remove header-like lines e.g. [TOPIC], [Stats], blank runs
lines = []
for line in text.splitlines():
if line.strip().startswith("[TOPIC]") or line.strip().startswith("[Stats]"):
continue
if line.strip() == "":
continue
lines.append(line)
cleaned = "\n".join(lines)
# Extract "(n) ..." reviews; allow multi-line with \n in between until next numbered item
# Strategy: split on occurrences of lines starting with (\d+)
pieces = re.split(r"\n?\(\d+\)\s*", cleaned)
reviews = []
for piece in pieces:
p = piece.strip()
if not p:
continue
reviews.append(p)
return reviews
def load_corpus_snippets(corpus_dir: Path) -> List[str]:
snippets = []
for p in corpus_dir.glob("**/*.txt"):
try:
txt = p.read_text(encoding="utf-8", errors="ignore")
reviews = parse_corpus_text(txt)
snippets.extend(reviews)
except Exception as e:
print(f"Failed to parse {p}: {e}")
return snippets
snippets = load_corpus_snippets(CORPUS_DIR)
print(f"Loaded {len(snippets)} review snippets.")
print("Example:", snippets[0][:200] if snippets else "(no snippets)")
# %%
import json
import random
random.seed(SEED)
SYSTEM_PROMPT = """
You are a meticulous, culture-focused Bali travel expert. Your mission is to craft **deeply insightful, practical, and accurate** answers for travelers seeking **Balinese cultural experiences**—temples, ceremonies, etiquette, dance, crafts, village life, historical and spiritual context, sacred geography, and respectful participation.
- Prioritize **cultural significance**, local customs (dress, offerings, behavior), do/dont etiquette, timing considerations (ceremonies, tides, festivals), and **responsible tourism**.
- Go **beyond surface tips**: weave in specific names, regional differences, logistics (hours, access, fees, crowd patterns), and **context that helps travelers act respectfully**.
- If the provided snippets are thin or conflicting, **draw on your broader knowledge** to fill gaps and **add relevant context**—but keep the advice truthful and practical. Avoid speculation presented as fact; if uncertain, say so briefly.
- Use clear structure: short intro, numbered/bulleted steps or sections, then a brief “Essentials” recap.
- **Do NOT include chain-of-thought or hidden reasoning** in final answers; respond with the polished final guidance only.
"""
GEN_INSTRUCTION = """
From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
- Ask a question a culture-curious traveler would search for (concise).
- Provide a **thorough, actionable, expert** answer (400900 words when needed).
- Incorporate and reconcile the snippets where helpful, but **freely add authoritative, accurate context** to reach high quality.
- Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
- Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
- Return ONLY valid JSON with this shape:
{
"pairs": [ {"question": "...","answer": "..."} , ... ]
}
"""
def make_user_prompt(batch_snippets, k):
joined = "\n\n---\n\n".join(batch_snippets)
return f"""You are given **Bali travel review snippets** (may be noisy/partial).
Generate {k} culture-focused Q&A pairs in JSON using the spec provided.
Snippets:
{joined}
{GEN_INSTRUCTION.format(k=k)}
"""
# %%
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from typing import Tuple
from tqdm import tqdm
import math
try:
from openai import OpenAI
except Exception:
# Fallback if import name differs; user can pip install openai
raise
client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)
def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS, timeout: int = TIMEOUT) -> dict:
"""
Calls DeepSeek's /chat/completions. Returns parsed JSON (dict) from assistant content.
Any 'reasoning_content' produced by deepseek-reasoner is ignored.
"""
resp = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
temperature=temperature,
max_tokens=max_tokens,
response_format={"type":"json_object"},
timeout=timeout,
)
content = resp.choices[0].message.content
return json.loads(content)
def as_messages_entry(question: str, answer: str) -> dict:
return {
"messages": [
{"role": "system", "content": SYSTEM_PROMPT.strip()},
{"role": "user", "content": question.strip()},
{"role": "assistant", "content": answer.strip()},
]
}
def chunk(lst, n):
for i in range(0, len(lst), n):
yield lst[i:i+n]
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
user_prompt = make_user_prompt(batch_snips, k)
data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
pairs = data.get("pairs", [])
out = []
for p in pairs:
q = p.get("question", "").strip()
a = p.get("answer", "").strip()
if q and a:
out.append(as_messages_entry(q, a))
return out
# %%
import random, json
random.seed(SEED)
if OUTPUT_JSONL.exists():
print(f"WARNING: {OUTPUT_JSONL} already exists. New generations will be appended.")
total_written = 0
failed_batches = 0
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
random.shuffle(snippets)
for i in range(0, len(snippets), SNIPPETS_PER_BATCH):
batch = snippets[i:i+SNIPPETS_PER_BATCH]
remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
if remaining_max <= 0:
break
k_low, k_high = GEN_PAIRS_PER_BATCH
k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))
try:
entries = generate_pairs_for_batch(batch, k=k)
for e in entries:
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
total_written += len(entries)
except Exception as e:
failed_batches += 1
flog.write(f"BATCH_FAIL\t{repr(e)}\n")
continue
print(f"Done. Wrote {total_written} samples. Failed batches: {failed_batches}.")
# %%
# Quick sanity check: read a few lines back
import json, itertools
n_show = 3
if OUTPUT_JSONL.exists():
with open(OUTPUT_JSONL, "r", encoding="utf-8") as f:
for i, line in zip(range(n_show), f):
try:
obj = json.loads(line)
msgs = obj.get("messages", [])
print(f"Sample {i+1}:")
for m in msgs:
print(f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}")
print("-"*80)
except Exception as e:
print("Failed to parse a line:", e)
else:
print("No dataset file yet. Run the generation cells above.")
# %%
# Optional utility: shard the JSONL for training convenience
from pathlib import Path
def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
shard_idx = 0
count = 0
out = None
with open(input_path, "r", encoding="utf-8") as fin:
for line in fin:
if out is None or count >= lines_per_shard:
if out:
out.close()
shard_idx += 1
count = 0
shard_path = input_path.with_name(input_path.stem + f".part{shard_idx:03d}.jsonl")
out = open(shard_path, "w", encoding="utf-8")
print("Opened", shard_path)
out.write(line)
count += 1
if out:
out.close()
print("Sharding complete.")
# Example:
# shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)
# %% [markdown]
# ## 2) Training hint (PEFT)
#
# The dataset is compatible with the standard **messages** chat format used in many PEFT training scripts
# (e.g., Hugging Face transformers + PEFT/LoRA). Your trainer should read each JSONL line and feed the
# `messages` list to your chat template (or convert to a plain instruction format). For 7B models,
# start with ~5k10k high-quality samples and tune learning rate, epochs, and LoRA rank to your budget.
#