Files
masterthesis-playground/raft/jsonl_remapper.py
2025-12-27 16:38:45 +01:00

139 lines
4.1 KiB
Python

#!/usr/bin/env python3
"""
Rewrite chat-style JSONL into {"input": ..., "output": ...} JSONL for LLM tuning.
Expected input line shape (example):
{
"messages": [
{"role":"system","content":"..."},
{"role":"user","content":"..."},
{"role":"assistant","content":"..."}
],
"meta": {...} # optional
}
Output line shape:
{"input": "<user text>", "output": "<assistant text>"}
By default:
- Ignores all non-user/assistant roles (e.g., system).
- Emits one record per (user -> next assistant) pair in the conversation; if
  several user messages precede a reply, only the most recent one is used.
- Drops all other fields (including meta) unless --keep-meta is set.
Usage:
python jsonl_remapper.py in.jsonl out.jsonl
cat in.jsonl | python jsonl_remapper.py - - > out.jsonl
python jsonl_remapper.py in.jsonl out.jsonl --only-last
python jsonl_remapper.py in.jsonl out.jsonl --keep-meta
"""
import argparse
import json
import sys
from typing import Any, Dict, List, Optional, Tuple
def iter_user_assistant_pairs(messages: List[Dict[str, Any]]) -> List[Tuple[str, str]]:
    """
    Return list of (user_content, assistant_content) pairs.

    Pairing rule: whenever a 'user' message is followed later by the next
    'assistant' message, emit a pair. Messages with any other role
    (system/tool/developer/...) are ignored, and so are entries that are not
    dicts at all (malformed input), instead of raising.

    Details:
      - A later 'user' message replaces an earlier one that has not been
        answered yet, so only the most recent user turn gets paired.
      - A user message whose content is not a non-blank string still opens
        a turn, with "" as its text.
      - An assistant message whose content is not a string closes the
        pending turn with "" as its text.
      - An assistant message with no pending user turn is dropped.

    Args:
        messages: Chat messages, each expected to be a dict with "role" and
            "content" keys.

    Returns:
        The pairs in conversation order (possibly empty).
    """
    pairs: List[Tuple[str, str]] = []
    pending_user: Optional[str] = None
    for m in messages:
        if not isinstance(m, dict):
            # Malformed entry (string, null, number, ...): skip it rather than
            # crash on .get() and abort the whole conversion run.
            continue
        role = m.get("role")
        content = m.get("content")
        if role == "user":
            # Blank/non-string user content still restarts the turn (as ""),
            # so an older unanswered user message is never paired with the
            # reply that followed this newer one.
            has_text = isinstance(content, str) and content.strip()
            pending_user = content if has_text else ""
        elif role == "assistant" and pending_user is not None:
            # NOTE(review): non-string assistant content (e.g. null) still
            # consumes the pending user turn and yields "" as output; if such
            # messages occur as intermediate turns, the later real reply is
            # dropped — confirm this is intended.
            assistant_text = content if isinstance(content, str) else ""
            pairs.append((pending_user, assistant_text))
            pending_user = None
        # Any other role (system/tool/developer/...) is ignored.
    return pairs
def read_lines(path: str) -> List[str]:
    """
    Read a JSONL source and return its lines without line terminators.

    Only LF, CRLF and a lone CR are treated as record separators.
    str.splitlines() is deliberately not used: it also breaks on characters
    such as U+0085, U+2028 and U+2029. JSON permits these unescaped inside
    strings, and json.dumps(..., ensure_ascii=False) emits them raw, so
    splitlines() would cut a valid record in two.

    A leading UTF-8 byte-order mark is removed so that the first record
    still parses with json.loads().

    Args:
        path: File path (read as UTF-8), or "-" to read from stdin.

    Returns:
        The lines in order. A final line terminator does not produce a
        trailing empty entry, and empty input yields [].
    """
    if path == "-":
        text = sys.stdin.read()
    else:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
    if text.startswith("\ufeff"):
        text = text[1:]
    # Normalise CRLF / lone CR (stdin is not guaranteed to be opened with
    # universal-newline translation), then split on LF only.
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    if lines[-1] == "":
        # Mirror splitlines(): a trailing terminator adds no empty line.
        lines.pop()
    return lines
def write_lines(path: str, lines: List[str]) -> None:
    """
    Write every entry of *lines*, each followed by a newline.

    Args:
        path: Destination file (created or truncated, written as UTF-8), or
            "-" to write to sys.stdout using its configured encoding.
        lines: Lines without terminators; an empty list writes nothing
            (a destination file is still created/truncated).
    """
    payload = "".join(f"{text}\n" for text in lines)
    if path != "-":
        with open(path, "w", encoding="utf-8") as sink:
            sink.write(payload)
    else:
        sys.stdout.write(payload)
def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point: convert chat JSONL into input/output JSONL.

    Each non-blank input line is parsed as JSON. A line is reported on
    stderr and skipped if it is not valid JSON, is not a JSON object, or
    has no list under "messages". Every (user -> assistant) pair found in a
    line becomes one output record; with --only-last only the last pair is
    kept. With --keep-meta, a dict-valued "meta" field is copied onto each
    of that line's records.

    Args:
        argv: Arguments to parse instead of sys.argv[1:] (for tests and
            programmatic use). None keeps the normal command-line behaviour.

    Returns:
        Process exit status; always 0 (skipped lines do not fail the run).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("infile", help="Input JSONL path, or '-' for stdin")
    ap.add_argument("outfile", help="Output JSONL path, or '-' for stdout")
    ap.add_argument(
        "--only-last",
        action="store_true",
        help="Emit only the last (user -> assistant) pair per input line.",
    )
    ap.add_argument(
        "--keep-meta",
        action="store_true",
        help="If input line has 'meta', copy it through to output records.",
    )
    args = ap.parse_args(argv)

    out_lines: List[str] = []
    for idx, raw in enumerate(read_lines(args.infile), start=1):
        line = raw.strip()
        if not line:
            continue
        try:
            obj = json.loads(line)
        except json.JSONDecodeError as e:
            sys.stderr.write(f"[line {idx}] JSON decode error: {e}\n")
            continue
        if not isinstance(obj, dict):
            # Valid JSON but not an object (e.g. a bare list or string):
            # obj.get() would raise, so report and skip instead of aborting.
            sys.stderr.write(f"[line {idx}] skipped: expected a JSON object\n")
            continue
        messages = obj.get("messages")
        if not isinstance(messages, list):
            sys.stderr.write(f"[line {idx}] skipped: no 'messages' list\n")
            continue

        pairs = iter_user_assistant_pairs(messages)
        if args.only_last:
            # Slicing keeps an empty result empty (nothing emitted).
            pairs = pairs[-1:]

        # Resolve --keep-meta once per input line rather than once per pair.
        meta = obj.get("meta") if args.keep_meta else None
        for user_text, assistant_text in pairs:
            out_obj: Dict[str, Any] = {
                "input": user_text,
                "output": assistant_text,
            }
            if isinstance(meta, dict):
                out_obj["meta"] = meta
            out_lines.append(json.dumps(out_obj, ensure_ascii=False))

    write_lines(args.outfile, out_lines)
    return 0
if __name__ == "__main__":
    # Exit with main()'s return value as the process status.
    sys.exit(main())