Compare commits

..

2 Commits

Author SHA1 Message Date
ef99f152ac QLoRA stuff + datasets 2025-12-27 16:38:45 +01:00
edafc06cab RAFT adjustments for deepseek-based approach 2025-12-15 21:00:17 +01:00
10 changed files with 40264 additions and 31 deletions

View File

@@ -0,0 +1,311 @@
#!/usr/bin/env python3
"""
QLoRA SFT fine-tune for Mistral-7B on chat-style JSONL:
Each line: {"messages": [{"role":"system","content":...}, {"role":"user","content":...}, {"role":"assistant","content":...}, ...]}
Produces a LoRA adapter you can merge or load at inference time.
Example:
python finetune_mistral_bali_qlora.py \
--model_id mistralai/Mistral-7B-Instruct-v0.2 \
--train_jsonl /path/to/bali_train.jsonl \
--output_dir ./mistral-bali-lora \
--max_seq_len 2048 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--learning_rate 2e-4 \
--num_train_epochs 2 \
--streaming true
"""
import argparse
import json
import os
from typing import Any, Dict, List, Optional
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from trl import SFTTrainer
# -----------------------------
# Data formatting
# -----------------------------
def normalize_messages(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""
Ensures message roles/content are well-formed and in allowed roles.
"""
allowed = {"system", "user", "assistant"}
out = []
for m in messages:
role = (m.get("role") or "").strip().lower()
content = m.get("content")
if role not in allowed or content is None:
continue
content = str(content)
out.append({"role": role, "content": content})
return out
def messages_to_text(tokenizer: AutoTokenizer, example: Dict[str, Any]) -> str:
"""
Converts {"messages":[...]} to a single training text using the model's chat template if available.
For Mistral Instruct models, tokenizer.apply_chat_template is typically present.
"""
messages = normalize_messages(example.get("messages", []))
if not messages:
return ""
# Prefer tokenizer chat template when available.
if (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
# add_generation_prompt=False -> include the assistant content in the formatted text
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
# Fallback formatting (less ideal than the native template):
# Keep it deterministic and simple.
parts = []
for m in messages:
r = m["role"]
c = m["content"].strip()
if r == "system":
# Override system message
system_message = """
You are a specialized Balinese cultural travel expert. Your role is to provide accurate, culturally grounded, and practical guidance for travelers engaging with Balinese culture, including temples, ceremonies, etiquette, ritual calendars, dance, crafts, village life, sacred landscapes, and historicalspiritual context.
Prioritize cultural meaning and lived practice over sightseeing. Explain why places, rituals, and customs matter, and how visitors should behave respectfully. Emphasize dress codes, offerings, bodily conduct, photography rules, gender and purity considerations, and community norms.
Integrate timing and context where relevant, including ceremonial cycles (Pawukon/Wuku, full and new moons), festival periods, tides, agricultural rhythms, and temple schedules. Promote responsible tourism, community benefit, and environmental care, and discourage entry into restricted or sacred spaces.
Go beyond generic tips by naming specific temples, villages, regions, ceremonies, deities, and regional variations. Include practical logistics (access, hours, customary donations, crowd patterns) when helpful, without speculation. If uncertain, state this briefly and suggest local confirmation.
Structure responses clearly: a brief contextual introduction, followed by well-labeled sections or bullet points, and a short “Essentials” or “Respect Checklist” summary.
Do not include chain-of-thought, hidden reasoning, or meta commentary. Provide only polished, user-facing guidance in a calm, authoritative, and respectful tone.
"""
parts.append(f"<<SYS>>\n{system_message}\n<</SYS>>\n")
elif r == "user":
parts.append(f"[USER]\n{c}\n")
else:
parts.append(f"[ASSISTANT]\n{c}\n")
return "\n".join(parts).strip() + "\n"
# -----------------------------
# Main
# -----------------------------
def parse_args():
p = argparse.ArgumentParser()
p.add_argument(
"--model_id",
type=str,
required=True,
help="e.g. mistralai/Mistral-7B-Instruct-v0.2 (recommended) or base model",
)
p.add_argument(
"--train_jsonl",
type=str,
required=True,
help="Path to JSONL training file; each line has a 'messages' list.",
)
p.add_argument(
"--eval_jsonl",
type=str,
default=None,
help="Optional eval JSONL with same format.",
)
p.add_argument("--output_dir", type=str, required=True)
# Training hyperparameters
p.add_argument("--max_seq_len", type=int, default=2048)
p.add_argument("--per_device_train_batch_size", type=int, default=1)
p.add_argument("--per_device_eval_batch_size", type=int, default=1)
p.add_argument("--gradient_accumulation_steps", type=int, default=16)
p.add_argument("--learning_rate", type=float, default=2e-4)
p.add_argument("--weight_decay", type=float, default=0.0)
p.add_argument("--num_train_epochs", type=float, default=1.0)
p.add_argument("--warmup_ratio", type=float, default=0.03)
p.add_argument("--logging_steps", type=int, default=10)
p.add_argument("--save_steps", type=int, default=200)
p.add_argument("--eval_steps", type=int, default=200)
p.add_argument("--seed", type=int, default=42)
# Performance / memory
p.add_argument(
"--streaming",
type=str,
default="true",
help="true/false. Use streaming for very large JSONL.",
)
p.add_argument(
"--bf16",
type=str,
default="true",
help="true/false. Prefer bf16 if your GPU supports it.",
)
p.add_argument("--gradient_checkpointing", type=str, default="true")
# LoRA config
p.add_argument("--lora_r", type=int, default=16)
p.add_argument("--lora_alpha", type=int, default=32)
p.add_argument("--lora_dropout", type=float, default=0.05)
p.add_argument(
"--target_modules",
type=str,
default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
help="Comma-separated module names for Mistral-style architectures.",
)
# Optional: limit samples for quick smoke tests
p.add_argument("--max_train_samples", type=int, default=None)
p.add_argument("--max_eval_samples", type=int, default=None)
return p.parse_args()
def str2bool(x: str) -> bool:
return str(x).strip().lower() in {"1", "true", "yes", "y", "t"}
def main():
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
streaming = str2bool(args.streaming)
use_bf16 = str2bool(args.bf16)
use_gc = str2bool(args.gradient_checkpointing)
# -----------------------------
# Tokenizer
# -----------------------------
tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True)
if tokenizer.pad_token is None:
# Common for causal LMs
tokenizer.pad_token = tokenizer.eos_token
# -----------------------------
# Model (4-bit QLoRA)
# -----------------------------
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16 if use_bf16 else torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
)
model.config.use_cache = False # important for training
if use_gc:
model.gradient_checkpointing_enable()
# Prepare for k-bit + LoRA
model = prepare_model_for_kbit_training(model)
target_modules = [m.strip() for m in args.target_modules.split(",") if m.strip()]
lora_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=target_modules,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# -----------------------------
# Dataset
# -----------------------------
data_files = {"train": args.train_jsonl}
if args.eval_jsonl:
data_files["eval"] = args.eval_jsonl
ds = load_dataset("json", data_files=data_files, streaming=streaming)
def format_fn(example: Dict[str, Any]) -> Dict[str, str]:
text = messages_to_text(tokenizer, example)
return {"text": text}
train_ds = ds["train"].map(format_fn)
eval_ds = ds["eval"].map(format_fn) if args.eval_jsonl else None
# Optional sample limits (works differently for streaming vs non-streaming)
if args.max_train_samples is not None:
if streaming:
train_ds = train_ds.take(args.max_train_samples)
else:
train_ds = train_ds.select(
range(min(args.max_train_samples, len(train_ds)))
)
if eval_ds is not None and args.max_eval_samples is not None:
if streaming:
eval_ds = eval_ds.take(args.max_eval_samples)
else:
eval_ds = eval_ds.select(range(min(args.max_eval_samples, len(eval_ds))))
# -----------------------------
# Training
# -----------------------------
training_args = TrainingArguments(
output_dir=args.output_dir,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
num_train_epochs=args.num_train_epochs,
warmup_ratio=args.warmup_ratio,
logging_steps=args.logging_steps,
save_steps=args.save_steps,
eval_strategy="steps" if eval_ds is not None else "no",
eval_steps=args.eval_steps if eval_ds is not None else None,
save_total_limit=3,
bf16=use_bf16,
fp16=not use_bf16,
optim="paged_adamw_8bit", # good default for QLoRA
lr_scheduler_type="cosine",
seed=args.seed,
report_to="none",
)
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
packing=True, # packs multiple conversations per sequence for higher throughput
)
trainer.train()
# Save LoRA adapter + tokenizer
trainer.model.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
print(f"\nDone. Saved LoRA adapter to: {args.output_dir}")
print("Inference: load base model + peft adapter from this directory.")
if __name__ == "__main__":
main()

10
qlora/run.sh Normal file
View File

@@ -0,0 +1,10 @@
python finetune_mistral_bali_qlora.py \
--model_id mistralai/Mistral-7B-Instruct-v0.2 \
--train_jsonl ../raft/bali_culture_raft_dataset.jsonl \
--output_dir ./mistral-bali-lora \
--max_seq_len 2048 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--learning_rate 2e-4 \
--num_train_epochs 2 \
--streaming true

File diff suppressed because one or more lines are too long

138
raft/jsonl_remapper.py Normal file
View File

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
"""
Rewrite chat-style JSONL into {"input": ..., "output": ...} JSONL for LLM tuning.
Expected input line shape (example):
{
"messages": [
{"role":"system","content":"..."},
{"role":"user","content":"..."},
{"role":"assistant","content":"..."}
],
"meta": {...} # optional
}
Output line shape:
{"input": "<user text>", "output": "<assistant text>"}
By default:
- Ignores all non-user/assistant roles (e.g., system).
- Emits one record per (user -> next assistant) pair in the conversation.
- Drops all other fields (including meta) unless --keep-meta is set.
Usage:
python rewrite_jsonl.py in.jsonl out.jsonl
cat in.jsonl | python rewrite_jsonl.py - - > out.jsonl
python rewrite_jsonl.py in.jsonl out.jsonl --only-last
python rewrite_jsonl.py in.jsonl out.jsonl --keep-meta
"""
import argparse
import json
import sys
from typing import Any, Dict, List, Optional, Tuple
def iter_user_assistant_pairs(messages: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
"""
Return list of (user_content, assistant_content) pairs.
Pairing rule: whenever a 'user' message is followed later by the next 'assistant'
message, emit a pair. Intermediate system/tool messages are ignored.
"""
pairs: List[Tuple[str, str]] = []
pending_user: Optional[str] = None
for m in messages:
role = m.get("role")
content = m.get("content")
if role == "user":
# Start (or restart) a pending user turn
if isinstance(content, str) and content.strip():
pending_user = content
else:
pending_user = ""
elif role == "assistant":
if pending_user is not None:
assistant_text = content if isinstance(content, str) else ""
pairs.append((pending_user, assistant_text))
pending_user = None
else:
# ignore system/tool/developer/etc.
continue
return pairs
def read_lines(path: str) -> List[str]:
if path == "-":
return sys.stdin.read().splitlines()
with open(path, "r", encoding="utf-8") as f:
return f.read().splitlines()
def write_lines(path: str, lines: List[str]) -> None:
if path == "-":
sys.stdout.write("\n".join(lines) + ("\n" if lines else ""))
return
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + ("\n" if lines else ""))
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("infile", help="Input JSONL path, or '-' for stdin")
ap.add_argument("outfile", help="Output JSONL path, or '-' for stdout")
ap.add_argument(
"--only-last",
action="store_true",
help="Emit only the last (user -> assistant) pair per input line.",
)
ap.add_argument(
"--keep-meta",
action="store_true",
help="If input line has 'meta', copy it through to output records.",
)
args = ap.parse_args()
in_lines = read_lines(args.infile)
out_lines: List[str] = []
for idx, line in enumerate(in_lines, start=1):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
except json.JSONDecodeError as e:
sys.stderr.write(f"[line {idx}] JSON decode error: {e}\n")
continue
messages = obj.get("messages")
if not isinstance(messages, list):
# Not in expected format; skip silently (or log if desired)
continue
pairs = iter_user_assistant_pairs(messages)
if not pairs:
continue
if args.only_last:
pairs = [pairs[-1]]
for user_text, assistant_text in pairs:
out_obj: Dict[str, Any] = {
"input": user_text,
"output": assistant_text,
}
if args.keep_meta and isinstance(obj.get("meta"), dict):
out_obj["meta"] = obj["meta"]
out_lines.append(json.dumps(out_obj, ensure_ascii=False))
write_lines(args.outfile, out_lines)
return 0
if __name__ == "__main__":
raise SystemExit(main())

36
raft/jsonl_remapper_2.py Normal file
View File

@@ -0,0 +1,36 @@
import argparse
import json
def rewrite_jsonl(input_path, output_path):
with open(input_path, "r", encoding="utf-8") as infile, open(
output_path, "w", encoding="utf-8"
) as outfile:
for line_num, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
record = json.loads(line)
user_text = record.get("input", "")
bot_text = record.get("output", "")
new_record = {"text": f"<user>: {user_text} <bot>: {bot_text}"}
outfile.write(json.dumps(new_record, ensure_ascii=False) + "\n")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON on line {line_num}") from e
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Rewrite JSONL from {input, output} to {text: '<user>: ... <bot>: ...'} format"
)
parser.add_argument("--input", required=True, help="Path to input JSONL file")
parser.add_argument("--output", required=True, help="Path to output JSONL file")
args = parser.parse_args()
rewrite_jsonl(args.input, args.output)

54
raft/jsonl_remapper_3.py Normal file
View File

@@ -0,0 +1,54 @@
import argparse
import json
def rewrite_jsonl(input_path, output_path):
with open(input_path, "r", encoding="utf-8") as infile, open(
output_path, "w", encoding="utf-8"
) as outfile:
for line_num, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
record = json.loads(line)
messages = record.get("messages", [])
user_parts = []
bot_parts = []
for msg in messages:
role = msg.get("role")
content = msg.get("content", "")
if role == "user":
user_parts.append(content)
elif role == "assistant":
bot_parts.append(content)
# Skip entries without both sides
if not user_parts or not bot_parts:
continue
user_text = " ".join(user_parts)
bot_text = " ".join(bot_parts)
new_record = {"text": f"<user>: {user_text} <bot>: {bot_text}"}
outfile.write(json.dumps(new_record, ensure_ascii=False) + "\n")
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON on line {line_num}") from e
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Rewrite messages-based JSONL to {text: '<user>: ... <bot>: ...'} format"
)
parser.add_argument("--input", required=True, help="Path to input JSONL file")
parser.add_argument("--output", required=True, help="Path to output JSONL file")
args = parser.parse_args()
rewrite_jsonl(args.input, args.output)

View File

@@ -6,6 +6,10 @@
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.18.0
# kernelspec:
# display_name: .venv
# language: python
# name: python3
# ---
# %% [markdown]
@@ -60,14 +64,13 @@ FAILED_LOG = Path("./raft_failures.log")
TARGET_MIN_SAMPLES = 5000
TARGET_MAX_SAMPLES = 10000
# How many Q&A pairs to request per API call.
GEN_PAIRS_PER_BATCH = (3, 6) # (min, max)
# Number of review snippets to include in one request (to anchor the generations).
SNIPPETS_PER_BATCH = 6
# Model + API
DEEPSEEK_MODEL = "deepseek-reasoner" # reasoning model with CoT (we will discard CoT in dataset)
DEEPSEEK_MODEL = (
"deepseek-chat" # reasoning model with CoT (we will discard CoT in dataset)
)
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "PUT_YOUR_KEY_HERE")
@@ -81,12 +84,15 @@ SEED = 42
# --------------------------------
os.makedirs(CORPUS_DIR, exist_ok=True)
print(f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}")
print(
f"Corpus dir: {CORPUS_DIR.resolve()}\nOutput: {OUTPUT_JSONL.resolve()}\nModel: {DEEPSEEK_MODEL}"
)
# %%
import re
from typing import List, Dict
def parse_corpus_text(text: str) -> List[str]:
"""
Parse a file that contains lines like:
@@ -115,6 +121,7 @@ def parse_corpus_text(text: str) -> List[str]:
reviews.append(p)
return reviews
def load_corpus_snippets(corpus_dir: Path) -> List[str]:
snippets = []
for p in corpus_dir.glob("**/*.txt"):
@@ -126,6 +133,7 @@ def load_corpus_snippets(corpus_dir: Path) -> List[str]:
print(f"Failed to parse {p}: {e}")
return snippets
snippets = load_corpus_snippets(CORPUS_DIR)
print(f"Loaded {len(snippets)} review snippets.")
print("Example:", snippets[0][:200] if snippets else "(no snippets)")
@@ -146,33 +154,39 @@ You are a meticulous, culture-focused Bali travel expert. Your mission is to cra
"""
GEN_INSTRUCTION = """
From the provided review snippets, generate {k} distinct **Q&A pairs** valuable for travelers focused on Balinese culture. Each Q&A should:
From the provided review snippets, generate a distinct **Q&A pair** valuable for travelers focused on Balinese culture. Each Q&A should:
- Ask a question a culture-curious traveler would search for (concise).
- Provide a **thorough, actionable, expert** answer (400900 words when needed).
- Incorporate and reconcile the snippets where helpful, but **freely add authoritative, accurate context** to reach high quality.
- Emphasize respect, safety, logistics, cultural sensitivity, and practical steps.
- Do **NOT** output chain-of-thought. Do **NOT** include references to “snippets” or meta-instructions in the final answer.
- Return ONLY valid JSON with this shape:
{
"pairs": [ {"question": "...","answer": "..."} , ... ]
}
{{
"pairs": [ {{"question": "...","answer": "..."}} ]
}}
"""
def make_user_prompt(batch_snippets, k):
def make_user_prompt(batch_snippets):
joined = "\n\n---\n\n".join(batch_snippets)
return f"""You are given **Bali travel review snippets** (may be noisy/partial).
Generate {k} culture-focused Q&A pairs in JSON using the spec provided.
return f"""You are given a **Bali travel review snippet** (may be noisy/partial).
Generate a culture-focused Q&A pair in JSON using the spec provided.
Snippets:
{joined}
{GEN_INSTRUCTION.format(k=k)}
{GEN_INSTRUCTION}
"""
# %%
import time
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from typing import Tuple
from tqdm import tqdm
import math
@@ -185,8 +199,15 @@ except Exception:
client = OpenAI(api_key=API_KEY, base_url=DEEPSEEK_BASE_URL)
def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MODEL,
temperature: float = TEMPERATURE, max_tokens: int = MAX_TOKENS, timeout: int = TIMEOUT) -> dict:
def ask_deepseek(
system_prompt: str,
user_prompt: str,
model: str = DEEPSEEK_MODEL,
temperature: float = TEMPERATURE,
max_tokens: int = MAX_TOKENS,
timeout: int = TIMEOUT,
) -> dict:
"""
Calls DeepSeek's /chat/completions. Returns parsed JSON (dict) from assistant content.
Any 'reasoning_content' produced by deepseek-reasoner is ignored.
@@ -199,12 +220,14 @@ def ask_deepseek(system_prompt: str, user_prompt: str, model: str = DEEPSEEK_MOD
],
temperature=temperature,
max_tokens=max_tokens,
response_format={"type":"json_object"},
response_format={"type": "json_object"},
timeout=timeout,
)
content = resp.choices[0].message.content
print(resp)
return json.loads(content)
def as_messages_entry(question: str, answer: str) -> dict:
return {
"messages": [
@@ -214,14 +237,16 @@ def as_messages_entry(question: str, answer: str) -> dict:
]
}
def chunk(lst, n):
for i in range(0, len(lst), n):
yield lst[i:i+n]
yield lst[i : i + n]
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30), reraise=True)
def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
user_prompt = make_user_prompt(batch_snips, k)
def generate_pairs_for_batch(batch_snips: list) -> list:
user_prompt = make_user_prompt(batch_snips)
data = ask_deepseek(SYSTEM_PROMPT, user_prompt)
print(data)
pairs = data.get("pairs", [])
out = []
for p in pairs:
@@ -232,6 +257,52 @@ def generate_pairs_for_batch(batch_snips: list, k: int) -> list:
return out
# %%
import json, random
from concurrent.futures import ThreadPoolExecutor, as_completed
MAX_WORKERS = 32 # tune: start at 8/16/32, adjust for rate limits & CPU
def safe_generate(batch):
# Keep this small: just generate; let caller handle logging/writing
return generate_pairs_for_batch(batch)
random.shuffle(snippets)
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
FAILED_LOG, "a", encoding="utf-8"
) as flog:
total_written = 0
failed_batches = 0
# Submit all jobs up front (or see “windowed submission” below)
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
futures = {
ex.submit(safe_generate, batch): (i, batch)
for i, batch in enumerate(snippets)
}
for fut in as_completed(futures):
i, batch = futures[fut]
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
if remaining_max <= 0:
# Optional: you can stop early; outstanding futures still run.
break
try:
entries = fut.result()
# serialize writes in this single consumer loop
for e in entries[:remaining_max]:
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
total_written += min(len(entries), remaining_max)
except Exception as e:
failed_batches += 1
flog.write(f"BATCH_FAIL\tidx={i}\t{repr(e)}\n")
# %%
import random, json
@@ -239,23 +310,24 @@ random.seed(SEED)
if OUTPUT_JSONL.exists():
print(f"WARNING: {OUTPUT_JSONL} already exists. New generations will be appended.")
total_written = 0
failed_batches = 0
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(FAILED_LOG, "a", encoding="utf-8") as flog:
with open(OUTPUT_JSONL, "a", encoding="utf-8") as fout, open(
FAILED_LOG, "a", encoding="utf-8"
) as flog:
random.shuffle(snippets)
for i in range(0, len(snippets), SNIPPETS_PER_BATCH):
batch = snippets[i:i+SNIPPETS_PER_BATCH]
for i in range(0, len(snippets)):
batch = snippets[i]
print(i, batch)
remaining_min = max(TARGET_MIN_SAMPLES - total_written, 0)
remaining_max = max(TARGET_MAX_SAMPLES - total_written, 0)
if remaining_max <= 0:
break
k_low, k_high = GEN_PAIRS_PER_BATCH
k = min(k_high, max(k_low, remaining_min // 2 if remaining_min else k_high))
try:
entries = generate_pairs_for_batch(batch, k=k)
entries = generate_pairs_for_batch(batch)
for e in entries:
fout.write(json.dumps(e, ensure_ascii=False) + "\n")
total_written += len(entries)
@@ -279,8 +351,10 @@ if OUTPUT_JSONL.exists():
msgs = obj.get("messages", [])
print(f"Sample {i+1}:")
for m in msgs:
print(f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}")
print("-"*80)
print(
f"[{m['role']}] {m['content'][:120].replace('\n',' ')}{'...' if len(m['content'])>120 else ''}"
)
print("-" * 80)
except Exception as e:
print("Failed to parse a line:", e)
else:
@@ -290,6 +364,7 @@ else:
# Optional utility: shard the JSONL for training convenience
from pathlib import Path
def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
shard_idx = 0
count = 0
@@ -301,7 +376,9 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
out.close()
shard_idx += 1
count = 0
shard_path = input_path.with_name(input_path.stem + f".part{shard_idx:03d}.jsonl")
shard_path = input_path.with_name(
input_path.stem + f".part{shard_idx:03d}.jsonl"
)
out = open(shard_path, "w", encoding="utf-8")
print("Opened", shard_path)
out.write(line)
@@ -310,6 +387,7 @@ def shard_jsonl(input_path: Path, lines_per_shard: int = 2000):
out.close()
print("Sharding complete.")
# Example:
# shard_jsonl(OUTPUT_JSONL, lines_per_shard=4000)

9901
raft/remap2_bali.jsonl Normal file

File diff suppressed because one or more lines are too long

9902
raft/remap3_bali.jsonl Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long