RAFT adjustments for DeepSeek-based approach

This commit is contained in:
2025-12-15 21:00:17 +01:00
parent 71886c9091
commit edafc06cab
2 changed files with 10011 additions and 31 deletions


@@ -6,6 +6,10 @@
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.18.0
# kernelspec:
# display_name: .venv
# language: python
# name: python3
# ---
# %% [markdown]
@@ -60,14 +64,13 @@ FAILED_LOG = Path("./raft_failures.log")
TARGET_MIN_SAMPLES = 5000
TARGET_MAX_SAMPLES = 10000
# How many Q&A pairs to request per API call.
GEN_PAIRS_PER_BATCH = (3, 6) # (min, max)
# Number of review snippets to include in one request (to anchor the generations).
SNIPPETS_PER_BATCH = 6
# Model + API
DEEPSEEK_MODEL = "deepseek-reasoner" # reasoning model with CoT (we will discard CoT in dataset)
DEEPSEEK_MODEL = (
"deepseek-chat" # standard chat model; no separate chain-of-thought output to discard
)
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")
@@ -81,12 +84,15 @@ SEED = 42
# --------------------------------
os.makedirs(CORPUS_DIR, exist_ok=True)
print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")
print(
f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}"
)
# %%
import re
from typing import List, Dict
def parse_corpus_text(text: str) -> List[str]:
"""
Parse a file that contains lines like:
@@ -115,6 +121,7 @@ def parse_corpus_text(text: str) -> List[str]:
reviews.append(p)
return reviews
def load_corpus_snippets(corpus_dir: Path) -> List[str]:
snippets = []
for p in corpus_dir.glob("**/*.txt"):
@@ -126,6 +133,7 @@ def load_corpus_snippets(corpus_dir: Path) -> List[str]:
print(f"Failed to parse {p}: {e}")
return snippets
snippets = load_corpus_snippets(CORPUS_DIR)
print(f"Loaded {len(snippets)} review snippets.")
print("Example:", snippets[0][:200] if snippets else "(no snippets)")
@@ -146,33 +154,39 @@ You are a meticulous, culture-focused Bali travel expert. Your mission is to cra
"""
GEN_INSTRUCTION = """
From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
From the provided review snippet, generate one high-quality **Q&A pair** valuable for travelers focused on Balinese culture. The Q&A should:
- Ask a question a culture-curious traveler would search for (concise).
- Provide a **thorough, actionable, expert** answer (400–900 words when needed).
- Incorporate the snippet where helpful, but **freely add authoritative, accurate context** to reach high quality.
- Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
- Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
- Return ONLY valid JSON with this shape:
{
"pairs": [ {"question": "...","answer": "..."} , ... ]
}
{
"pairs": [ {"question": "...", "answer": "..."} ]
}
"""
def make_user_prompt(batch_snippets, k):
def make_user_prompt(snippet: str) -> str:
return f"""You are given **Bali travel review snippets** (may be noisy/partial).
Generate {k} culture-focused Q&A pairs in JSON using the spec provided.
return f"""You are given a **Bali travel review snippet** (may be noisy/partial).
Generate a culture-focused Q&A pair in JSON using the spec provided.
Snippet:
{snippet}
{GEN_INSTRUCTION.format(k=k)}
{GEN_INSTRUCTION}
"""
# %%
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from typing import Tuple
from tqdm import tqdm
import math
@@ -185,8 +199,15 @@ except Exception:
client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)
def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS, timeout: int = TIMEOUT) -> dict:
def ask_deepseek(
system_prompt: str,
user_prompt: str,
model: str = DEEPSEEK_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = MAX_TOKENS,
timeout: int = TIMEOUT,
) -> dict:
"""
Calls DeepSeek's /chat/completions. Returns parsed JSON (dict) from assistant content.
Any 'reasoning_content' produced by deepseek-reasoner is ignored.
@@ -199,12 +220,14 @@ def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MOD
],
temperature=temperature,
max_tokens=max_tokens,
response_format={"type":"json_object"},
response_format={"type": "json_object"},
timeout=timeout,
)
content = resp.choices[0].message.content
print(resp)  # debug: full raw API response (noisy across many workers)
return json.loads(content)
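A commented-out smoke test in the style of the shard example at the end of the notebook; it assumes DEEPSEEK_API_KEY is set and snippets is non-empty, and costs one API call:

# Example (one-off smoke test):
# sample = ask_deepseek(SYSTEM_PROMPT, make_user_prompt(snippets[0]))
# print(list(sample.keys()))  # expect something like: ['pairs']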
def as_messages_entry(question: str, answer: str) -> dict:
return {
"messages": [
@@ -214,14 +237,16 @@ def as_messages_entry(question: str, answer: str) -> dict:
]
}
def chunk(lst, n):
for i in range(0, len(lst), n):
yield lst[i:i+n]
yield lst[i : i + n]
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
user_prompt = make_user_prompt(batch_snips, k)
def generate_pairs_for_batch(batch_snips: list) -> list:
user_prompt = make_user_prompt(batch_snips)
data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
print(data)  # debug: parsed JSON payload from the model
pairs = data.get("pairs", [])
out = []
for p in pairs:
@@ -232,6 +257,52 @@ def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
return out
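retry_if_exception_type is imported above but never used, so the decorator retries on any exception, including non-transient ones like a KeyError. If retries should be scoped to transient failures only, it could be wired as below — a sketch, not the notebook's code: the exception classes assume the openai>=1.0 client, and generate_pairs_for_batch_scoped is a hypothetical variant.

import json
import openai

@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception_type(
        (
            openai.RateLimitError,     # transient: back off and retry
            openai.APITimeoutError,
            openai.APIConnectionError,
            json.JSONDecodeError,      # model returned invalid JSON
        )
    ),
    reraise=True,
)
def generate_pairs_for_batch_scoped(batch_snips: list) -> list:
    user_prompt = make_user_prompt(batch_snips)
    data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
    return [
        p for p in data.get("pairs", [])
        if isinstance(p.get("question"), str) and isinstance(p.get("answer"), str)
    ]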
# %%
import json, random
from concurrent.futures import ThreadPoolExecutor, as_completed
MAX_WORKERS = 32 # tune: start at 8/16/32, adjust for rate limits & CPU
def safe_generate(batch):
# Keep this small: just generate; let caller handle logging/writing
return generate_pairs_for_batch(batch)
random.shuffle(snippets)
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
FAILED_LOG, "a", encoding="utf-8"
) as flog:
total_written = 0
failed_batches = 0
# Submit all jobs up front (a windowed-submission alternative is sketched after this cell)
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
futures = {
ex.submit(safe_generate, batch): (i, batch)
for i, batch in enumerate(snippets)
}
for fut in as_completed(futures):
i, batch = futures[fut]
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
if remaining_max <= 0:
# Optional: you can stop early; outstanding futures still run.
break
try:
entries = fut.result()
# serialize writes in this single consumer loop
for e in entries[:remaining_max]:
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
total_written += min(len(entries), remaining_max)
except Exception as e:
failed_batches += 1
flog.write(f"BATCH_FAIL\tidx={i}\t{repr(e)}\n")
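The windowed-submission alternative mentioned in the cell above keeps only a bounded number of futures in flight, so an early stop at TARGET_MAX_SAMPLES never leaves thousands of queued requests behind. A sketch reusing safe_generate and the notebook's config; WINDOW and run_windowed are hypothetical names:

import json
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

WINDOW = MAX_WORKERS * 2  # illustrative bound on in-flight requests

def run_windowed(snips):
    written = 0
    with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, ThreadPoolExecutor(
        max_workers=MAX_WORKERS
    ) as ex:
        it = iter(enumerate(snips))
        in_flight = {}
        for i, s in it:  # prime the window
            in_flight[ex.submit(safe_generate, s)] = i
            if len(in_flight) >= WINDOW:
                break
        while in_flight and written < TARGET_MAX_SAMPLES:
            done, _ = wait(in_flight, return_when=FIRST_COMPLETED)
            for fut in done:
                i = in_flight.pop(fut)
                try:
                    for e in fut.result()[: TARGET_MAX_SAMPLES - written]:
                        fout.write(json.dumps(e, ensure_ascii=False) + "\n")
                        written += 1
                except Exception as err:
                    print(f"batch {i} failed: {err!r}")
                nxt = next(it, None)  # refill the window
                if nxt is not None:
                    in_flight[ex.submit(safe_generate, nxt[1])] = nxt[0]
    return written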
# %%
import random, json
@@ -239,23 +310,24 @@ random.seed(SEED)
if OUTPUT_JSONL.exists():
print(f"WARNING: {OUTPUT_JSONL} already exists. New generations will be appended.")
total_written = 0
failed_batches = 0
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
FAILED_LOG, "a", encoding="utf-8"
) as flog:
random.shuffle(snippets)
for i in range(0, len(snippets), SNIPPETS_PER_BATCH):
batch = snippets[i:i+SNIPPETS_PER_BATCH]
for i, batch in enumerate(snippets):
print(i, batch)  # debug: progress and current snippet
remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
if remaining_max <= 0:
break
k_low, k_high = GEN_PAIRS_PER_BATCH
k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))
try:
entries = generate_pairs_for_batch(batch, k=k)
entries = generate_pairs_for_batch(batch)
for e in entries:
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
total_written += len(entries)
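One caveat with the append-mode output: total_written restarts at 0 on every run, so the TARGET_* caps bound each run rather than the file. If the cap should apply to the file as a whole, the counter could be seeded from the existing line count — a sketch; count_existing_samples is a hypothetical helper:

from pathlib import Path

def count_existing_samples(path: Path) -> int:
    """Number of JSONL lines already written, or 0 if the file is absent."""
    if not path.exists():
        return 0
    with open(path, encoding="utf-8") as f:
        return sum(1 for _ in f)

# total_written = count_existing_samples(OUTPUT_JSONL)  # seed instead of 0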
@@ -279,8 +351,10 @@ if OUTPUT_JSONL.exists():
msgs = obj.get("messages", [])
print(f"Sample {i+1}:")
for m in msgs:
print(f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}")
print("-"*80)
preview = m["content"][:120].replace("\n", " ")  # no backslashes inside f-string expressions (SyntaxError before Python 3.12)
print(f"[{m['role']}] {preview}{'...' if len(m['content']) > 120 else ''}")
print("-" * 80)
except Exception as e:
print("Failed to parse a line:", e)
else:
@@ -290,6 +364,7 @@ else:
# Optional utility: shard the JSONL for training convenience
from pathlib import Path
def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
shard_idx = 0
count = 0
@@ -301,7 +376,9 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
out.close()
shard_idx += 1
count = 0
shard_path = input_path.with_name(input_path.stem + f".part{shard_idx:03d}.jsonl")
shard_path = input_path.with_name(
input_path.stem + f".part{shard_idx:03d}.jsonl"
)
out = open(shard_path, "w", encoding="utf-8")
print("Opened", shard_path)
out.write(line)
@@ -310,6 +387,7 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
out.close()
print("Sharding complete.")
# Example:
# shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)