Mirror of https://github.com/marvinscham/masterthesis-playground.git, synced 2026-02-04 05:03:11 +01:00
RAFT adjustments for deepseek-based approach
raft/bali_culture_raft_dataset.jsonl: new file, 9902 lines
File diff suppressed because one or more lines are too long
@@ -6,6 +6,10 @@
 #       format_name: percent
 #       format_version: '1.3'
 #       jupytext_version: 1.18.0
+#   kernelspec:
+#     display_name: .venv
+#     language: python
+#     name: python3
 # ---

 # %% [markdown]
@@ -60,14 +64,13 @@ FAILED_LOG = Path("./raft_failures.log")
 TARGET_MIN_SAMPLES = 5000
 TARGET_MAX_SAMPLES = 10000

-# How many Q&A pairs to request per API call.
-GEN_PAIRS_PER_BATCH = (3, 6)  # (min, max)
-
 # Number of review snippets to include in one request (to anchor the generations).
 SNIPPETS_PER_BATCH = 6

 # Model + API
-DEEPSEEK_MODEL = "deepseek-reasoner"  # reasoning model with CoT (we will discard CoT in dataset)
+DEEPSEEK_MODEL = (
+    "deepseek-chat"  # reasoning model with CoT (we will discard CoT in dataset)
+)
 DEEPSEEK_BASE_URL = "https://api.deepseek.com"
 API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")

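Note: API_KEY falls back to the literal placeholder "PUT_YOUR_KEY_HERE", so a missing DEEPSEEK_API_KEY only surfaces later as an authentication error from the API. A minimal fail-fast check (a sketch, not part of this commit) could look like:

    import os

    API_KEY = os.environ.get("DEEPSEEK_API_KEY")
    if not API_KEY or API_KEY == "PUT_YOUR_KEY_HERE":
        # Abort at startup instead of failing on the first chat.completions call
        raise RuntimeError("Set DEEPSEEK_API_KEY before running this notebook.")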
@@ -81,12 +84,15 @@ SEED = 42

 # --------------------------------
 os.makedirs(CORPUS_DIR, exist_ok=True)
-print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")
+print(
+    f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}"
+)

 # %%
 import re
 from typing import List, Dict

+
 def parse_corpus_text(text: str) -> List[str]:
     """
     Parse a file that contains lines like:
@@ -115,6 +121,7 @@ def parse_corpus_text(text: str) -> List[str]:
             reviews.append(p)
     return reviews

+
 def load_corpus_snippets(corpus_dir: Path) -> List[str]:
     snippets = []
     for p in corpus_dir.glob("**/*.txt"):
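Note: the hunk above shows only the tail of parse_corpus_text, and the docstring's example line format is cut off by the hunk boundary. Based on the visible tail (stripped, non-empty pieces are collected into reviews), a representative implementation could look like this sketch, assuming reviews are separated by blank lines:

    import re
    from typing import List

    def parse_corpus_text(text: str) -> List[str]:
        reviews: List[str] = []
        for part in re.split(r"\n\s*\n", text):  # assumed: blank-line-separated snippets
            p = part.strip()
            if p:
                reviews.append(p)
        return reviews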
@@ -126,6 +133,7 @@ def load_corpus_snippets(corpus_dir: Path) -> List[str]:
             print(f"Failed to parse {p}: {e}")
     return snippets

+
 snippets = load_corpus_snippets(CORPUS_DIR)
 print(f"Loaded {len(snippets)} review snippets.")
 print("Example:", snippets[0][:200] if snippets else "(no snippets)")
@@ -146,33 +154,39 @@ You are a meticulous, culture-focused Bali travel expert. Your mission is to cra
 """

 GEN_INSTRUCTION = """
-From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
+From the provided review snippets, generate a distinct **Q&A pair** valuable for travelers focused on Balinese culture. Each Q&A should:
 - Ask a question a culture-curious traveler would search for (concise).
 - Provide a **thorough, actionable, expert** answer (400–900 words when needed).
 - Incorporate and reconcile the snippets where helpful, but **freely add authoritative, accurate context** to reach high quality.
 - Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
 - Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
 - Return ONLY valid JSON with this shape:
-{
-  "pairs": [ {"question": "...","answer": "..."} , ... ]
-}
+{{
+  "pairs": [ {{"question": "...","answer": "..."}} ]
+}}
 """

-def make_user_prompt(batch_snippets, k):
+
+def make_user_prompt(batch_snippets):
     joined = "\n\n---\n\n".join(batch_snippets)
-    return f"""You are given **Bali travel review snippets** (may be noisy/partial).
-Generate {k} culture-focused Q&A pairs in JSON using the spec provided.
+    return f"""You are given a **Bali travel review snippet** (may be noisy/partial).
+Generate a culture-focused Q&A pair in JSON using the spec provided.

 Snippets:
 {joined}

-{GEN_INSTRUCTION.format(k=k)}
+{GEN_INSTRUCTION}
 """


 # %%
 import time
-from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_exponential,
+    retry_if_exception_type,
+)
 from typing import Tuple
 from tqdm import tqdm
 import math
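Note on the brace changes in GEN_INSTRUCTION: the removed version combined unescaped literal braces with .format(k=k), which raises a KeyError because str.format treats the JSON braces as replacement fields. The new version drops .format(), so the doubled {{ }} are no longer unescaped by anything and reach the model verbatim, since f-string interpolation of {GEN_INSTRUCTION} inserts the value as-is. A quick illustration:

    template = '{ "pairs": [] }'
    # template.format(k=1)        # KeyError: str.format parses the braces as a field
    escaped = '{{ "pairs": [] }}'
    print(escaped.format(k=1))    # { "pairs": [] }   (.format() unescapes {{ }})
    print(f"spec: {escaped}")     # spec: {{ "pairs": [] }}   (f-strings do not)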
@@ -185,8 +199,15 @@ except Exception:

 client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)

-def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
-                 temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS, timeout: int = TIMEOUT) -> dict:
+
+def ask_deepseek(
+    system_prompt: str,
+    user_prompt: str,
+    model: str = DEEPSEEK_MODEL,
+    temperature: float = TEMPERATURE,
+    max_tokens: int = MAX_TOKENS,
+    timeout: int = TIMEOUT,
+) -> dict:
     """
     Calls DeepSeek's /chat/completions. Returns parsed JSON (dict) from assistant content.
     Any 'reasoning_content' produced by deepseek-reasoner is ignored.
@@ -199,12 +220,14 @@ def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MOD
         ],
         temperature=temperature,
         max_tokens=max_tokens,
-        response_format={"type":"json_object"},
+        response_format={"type": "json_object"},
         timeout=timeout,
     )
     content = resp.choices[0].message.content
+    print(resp)
     return json.loads(content)

+
 def as_messages_entry(question: str, answer: str) -> dict:
     return {
         "messages": [
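Note: the docstring states that reasoning_content from deepseek-reasoner is ignored, and the code does this implicitly by reading only message.content. Making the discard explicit could look like this sketch (DeepSeek's OpenAI-compatible API returns the chain-of-thought as a separate reasoning_content field on the message):

    resp = client.chat.completions.create(
        model="deepseek-reasoner",
        messages=[{"role": "user", "content": "..."}],  # "..." stands in for the real prompt
    )
    msg = resp.choices[0].message
    _cot = getattr(msg, "reasoning_content", None)  # chain-of-thought: read and dropped
    answer = msg.content                            # only the final answer is kept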
@@ -214,14 +237,16 @@ def as_messages_entry(question: str, answer: str) -> dict:
         ]
     }

+
 def chunk(lst, n):
     for i in range(0, len(lst), n):
-        yield lst[i:i+n]
+        yield lst[i : i + n]

 @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
-def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
-    user_prompt = make_user_prompt(batch_snips, k)
+def generate_pairs_for_batch(batch_snips: list) -> list:
+    user_prompt = make_user_prompt(batch_snips)
     data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
+    print(data)
     pairs = data.get("pairs", [])
     out = []
     for p in pairs:
@@ -232,6 +257,52 @@ def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
     return out


 # %%
+import json, random
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+MAX_WORKERS = 32  # tune: start at 8/16/32, adjust for rate limits & CPU
+
+
+def safe_generate(batch):
+    # Keep this small: just generate; let caller handle logging/writing
+    return generate_pairs_for_batch(batch)
+
+
+random.shuffle(snippets)
+
+with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
+    FAILED_LOG, "a", encoding="utf-8"
+) as flog:
+    total_written = 0
+    failed_batches = 0
+
+    # Submit all jobs up front (or see “windowed submission” below)
+    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
+        futures = {
+            ex.submit(safe_generate, batch): (i, batch)
+            for i, batch in enumerate(snippets)
+        }
+
+        for fut in as_completed(futures):
+            i, batch = futures[fut]
+
+            remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
+            if remaining_max <= 0:
+                # Optional: you can stop early; outstanding futures still run.
+                break
+
+            try:
+                entries = fut.result()
+                # serialize writes in this single consumer loop
+                for e in entries[:remaining_max]:
+                    fout.write(json.dumps(e, ensure_ascii=False) + "\n")
+                total_written += min(len(entries), remaining_max)
+
+            except Exception as e:
+                failed_batches += 1
+                flog.write(f"BATCH_FAIL\tidx={i}\t{repr(e)}\n")
+
+# %%
 import random, json

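Note: the new cell's comment points at “windowed submission”, but the commit itself submits every batch up front, so once TARGET_MAX_SAMPLES is reached the outstanding futures still execute (and still bill API calls). A windowed variant (a sketch, not part of this commit) caps the number of in-flight batches:

    from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

    WINDOW = 2 * MAX_WORKERS  # keep at most this many batches in flight

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        pending, it = set(), iter(enumerate(snippets))
        while True:
            while len(pending) < WINDOW:  # top up the window
                try:
                    i, batch = next(it)
                except StopIteration:
                    break
                pending.add(ex.submit(safe_generate, batch))
            if not pending:
                break
            done, pending = wait(pending, return_when=FIRST_COMPLETED)
            for fut in done:
                ...  # write results / log failures exactly as in the loop above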
@@ -243,19 +314,20 @@ if OUTPUT_JSONL.exists():
 total_written = 0
 failed_batches = 0

-with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
+with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
+    FAILED_LOG, "a", encoding="utf-8"
+) as flog:
     random.shuffle(snippets)
-    for i in range(0, len(snippets), SNIPPETS_PER_BATCH):
-        batch = snippets[i:i+SNIPPETS_PER_BATCH]
+    for i in range(0, len(snippets)):
+        batch = snippets[i]
+        print(i, batch)
         remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
         remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
         if remaining_max <= 0:
             break
-        k_low, k_high = GEN_PAIRS_PER_BATCH
-        k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))

         try:
-            entries = generate_pairs_for_batch(batch, k=k)
+            entries = generate_pairs_for_batch(batch)
             for e in entries:
                 fout.write(json.dumps(e, ensure_ascii=False) + "\n")
             total_written += len(entries)
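Note: after this change, batch is a single snippet string rather than a list of snippets, but make_user_prompt still runs "\n\n---\n\n".join(batch_snippets). Joining a str iterates its characters, so the separator lands between every character of the snippet. An illustration, plus the one-line fix of wrapping the snippet in a list:

    sep = " | "
    print(sep.join("abc"))    # 'a | b | c'  (a str is iterated character by character)
    print(sep.join(["abc"]))  # 'abc'        (a one-element list keeps the snippet whole)
    # i.e., generate_pairs_for_batch([batch]) would preserve the intended prompt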
@@ -279,8 +351,10 @@ if OUTPUT_JSONL.exists():
             msgs = obj.get("messages", [])
             print(f"Sample {i+1}:")
             for m in msgs:
-                print(f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}")
-            print("-"*80)
+                print(
+                    f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}"
+                )
+            print("-" * 80)
         except Exception as e:
             print("Failed to parse a line:", e)
 else:
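Note: the reflowed print still keeps '\n' inside the f-string expression ({... .replace('\n',' ')}). Backslashes in f-string expressions are a SyntaxError on Python versions before 3.12 (PEP 701 lifted the restriction), so a version-portable sketch hoists the escape out of the braces:

    for m in msgs:
        text = m["content"][:120].replace("\n", " ")
        suffix = "..." if len(m["content"]) > 120 else ""
        print(f"[{m['role']}] {text}{suffix}")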
@@ -290,6 +364,7 @@ else:
 # Optional utility: shard the JSONL for training convenience
 from pathlib import Path

+
 def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
     shard_idx = 0
     count = 0
@@ -301,7 +376,9 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
                 out.close()
                 shard_idx += 1
                 count = 0
-                shard_path = input_path.with_name(input_path.stem + f".part{shard_idx:03d}.jsonl")
+                shard_path = input_path.with_name(
+                    input_path.stem + f".part{shard_idx:03d}.jsonl"
+                )
                 out = open(shard_path, "w", encoding="utf-8")
                 print("Opened", shard_path)
             out.write(line)
@@ -310,6 +387,7 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
     out.close()
     print("Sharding complete.")

+
 # Example:
 # shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)