# File: masterthesis-playground/raft/create_raft_dataset_notebook.py
# Last modified: 2025-10-13 17:34:49 +02:00 (620 lines, 18 KiB, Python, as shown by the repository viewer)
# Note: the repository viewer flagged this file as containing ambiguous Unicode characters
# (characters that may be confused with others, e.g. typographic dashes in the markdown text).
# This script programmatically creates a Jupyter Notebook tailored for
# "Retrieval-Augmented Fine-Tuning (RAFT) dataset generation using local Ollama".
# It saves the notebook as ./raft_ollama_dataset.ipynb (in the current working
# directory) so you can open and run it.
import json
from datetime import datetime
import nbformat as nbf
nb = nbf.v4.new_notebook()
# Notebook-level metadata: a default Python 3 kernel and a generic language descriptor.
nb["metadata"]["kernelspec"] = {
    "display_name": "Python 3",
    "language": "python",
    "name": "python3",
}
nb["metadata"]["language_info"] = {"name": "python", "version": "3.x"}
# Cells are accumulated here in display order and attached to `nb` just before saving.
cells = []
# Title / Overview
cells.append(
    nbf.v4.new_markdown_cell(
        """
# Retrieval-Augmented **Fine-Tuning** (RAFT) Dataset Generation with **Local Ollama**

This notebook builds a **supervised fine-tuning dataset** (JSONL) for *retrieval-augmented* tasks, by:

1. **Ingesting** your local corpus (Markdown, text, HTML; PDFs optional with extra deps).
2. **Chunking** and **embedding** documents using Ollama's local **embedding model** (e.g., `nomic-embed-text`, `mxbai-embed-large`).
3. Building a **lightweight vector index** (FAISS).
4. **Sampling contexts** and using a local **generation model** via Ollama (e.g., `llama3.1`, `qwen2`, `phi3`) to synthesize **grounded Q&A** or instruction-response pairs.
5. Emitting a **RAFT-style JSONL** for supervised training (e.g., `input`, `output`, `meta` with source citations).

> **Requirements**
>
> - Local [Ollama](https://ollama.com/) running at `http://localhost:11434` (override with the `OLLAMA_URL` environment variable)
> - At least one **embedding** model pulled (e.g., `ollama pull nomic-embed-text`)
> - At least one **generation** model pulled (e.g., `ollama pull llama3.1`)
>
> You can adapt the prompts and schema for your specific downstream trainer (llama.cpp, vLLM, Axolotl, MLX, etc.).
"""
    )
)
# Setup: a short markdown header followed by an (initially commented-out) dependency-install cell.
cells.extend(
    [
        nbf.v4.new_markdown_cell(
            """
## 0) Setup
Install Python dependencies. If you're offline, pre-install or remove what you don't need.
"""
        ),
        nbf.v4.new_code_cell(
            """
# If needed, uncomment:
# %pip install --quiet requests faiss-cpu rich markdownify python-frontmatter pypdf regex
# Optional extras:
# %pip install --quiet tiktoken beautifulsoup4 lxml
"""
        ),
    ]
)
# Config
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 1) Configuration
Set paths, models, and chunking/index params.
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
from dataclasses import dataclass, asdict
from datetime import datetime, timezone  # used when timestamping generated rows (step 8)
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import os, re, json, uuid, math, glob, random, time
import hashlib
import requests
from rich import print  # rich's print renders the [green]...[/green] markup used below
import regex
import numpy as np

# ---- Core config ----
DATA_DIR = Path("./corpus")  # Put your source docs here
OUTPUT_DIR = Path("./outputs")  # Where artifacts are saved
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Ollama endpoints & models (each can be overridden via an environment variable)
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text")
GEN_MODEL = os.environ.get("GEN_MODEL", "llama3.1")

# Chunking
CHUNK_SIZE = 1200  # characters
CHUNK_OVERLAP = 200  # characters; must be smaller than CHUNK_SIZE or the chunker cannot advance
MIN_CHARS = 200  # minimum viable chunk length

# Index
USE_FAISS = True
TOP_K = 4

# RAFT generation
SEED = 7
SAMPLES_PER_DOC = 4
MAX_TOKENS_GEN = 512  # Generation max tokens (approx; Ollama supports 'num_predict')
TEMPERATURE = 0.6

# Fail fast on settings that would otherwise hang or silently produce nothing later on.
if CHUNK_SIZE <= 0 or CHUNK_OVERLAP >= CHUNK_SIZE:
    raise ValueError(f"Need CHUNK_SIZE > 0 and CHUNK_OVERLAP < CHUNK_SIZE, got {CHUNK_SIZE=} {CHUNK_OVERLAP=}")
if not DATA_DIR.is_dir():
    print(f"[yellow]Corpus directory {DATA_DIR.resolve()} does not exist yet; create it and add your documents.[/yellow]")

random.seed(SEED)
np.random.seed(SEED)

print({
    "DATA_DIR": str(DATA_DIR.resolve()),
    "OUTPUT_DIR": str(OUTPUT_DIR.resolve()),
    "OLLAMA_URL": OLLAMA_URL,
    "EMBED_MODEL": EMBED_MODEL,
    "GEN_MODEL": GEN_MODEL,
    "CHUNK_SIZE": CHUNK_SIZE,
    "CHUNK_OVERLAP": CHUNK_OVERLAP,
    "MIN_CHARS": MIN_CHARS,
    "TOP_K": TOP_K,
    "SAMPLES_PER_DOC": SAMPLES_PER_DOC
})
"""
    )
)
# Helpers: loaders
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 2) Load & Normalize Documents
Basic loaders for `.md`/`.markdown`, `.txt`, and `.html`/`.htm` (HTML requires `beautifulsoup4`; `lxml` is used when available).
PDF support is optional (requires `pypdf`). Extend `SUPPORTED_EXTS` to add more formats.
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
# Optional dependencies: a missing package only disables the formats that need it.
try:
    from bs4 import BeautifulSoup  # only needed for .html/.htm files
except ImportError:
    BeautifulSoup = None
try:
    import frontmatter
except Exception:
    frontmatter = None


def read_text_file(p: Path) -> str:
    # Plain text; undecodable bytes are dropped instead of aborting the load.
    return p.read_text(encoding="utf-8", errors="ignore")


def read_markdown(p: Path) -> str:
    # Markdown is kept verbatim, except that YAML frontmatter is stripped when python-frontmatter is installed.
    text = p.read_text(encoding="utf-8", errors="ignore")
    if frontmatter:
        try:
            return frontmatter.loads(text).content
        except Exception:
            return text
    return text


def read_html(p: Path) -> str:
    # Visible page text (script/style/noscript removed), joined with single spaces.
    if BeautifulSoup is None:
        raise ImportError("Install beautifulsoup4 to enable HTML parsing: %pip install beautifulsoup4 lxml")
    html = p.read_text(encoding="utf-8", errors="ignore")
    try:
        soup = BeautifulSoup(html, "lxml")
    except ValueError:
        # bs4 raises FeatureNotFound (a ValueError subclass) when lxml is not installed;
        # fall back to the parser that ships with Python.
        soup = BeautifulSoup(html, "html.parser")
    for tag in soup(["script", "style", "noscript"]):
        tag.decompose()
    return soup.get_text(" ", strip=True)


def read_pdf(p: Path) -> str:
    # Page texts joined by newlines; a page whose extraction fails contributes an empty string.
    try:
        from pypdf import PdfReader
    except ImportError:
        print("[yellow]Install pypdf to enable PDF parsing: %pip install pypdf[/yellow]")
        raise
    reader = PdfReader(str(p))
    parts = []
    for page in reader.pages:
        try:
            parts.append(page.extract_text() or "")
        except Exception:
            parts.append("")
    return "\\n".join(parts)


# Lower-cased file suffix -> reader function.
SUPPORTED_EXTS = {
    ".txt": read_text_file,
    ".md": read_markdown,
    ".markdown": read_markdown,
    ".html": read_html,
    ".htm": read_html,
    ".pdf": read_pdf,
}


def load_corpus(data_dir: Path) -> Dict[str, str]:
    # Read every supported file under `data_dir` (recursively) into {path string: text}.
    # Files that fail to load are reported and skipped.
    docs = {}
    if not data_dir.is_dir():
        print(f"[yellow]Corpus directory {data_dir} does not exist; no documents loaded.[/yellow]")
        return docs
    # rglob order depends on the filesystem; sorting keeps the document order (and therefore
    # the seeded chunk sampling later on) reproducible across machines and runs.
    for p in sorted(data_dir.rglob("*")):
        if not p.is_file():
            continue
        reader = SUPPORTED_EXTS.get(p.suffix.lower())
        if reader is None:
            continue
        try:
            docs[str(p)] = reader(p)
        except Exception as e:
            print(f"[red]Failed to read {p}: {e}[/red]")
    print(f"[green]Loaded {len(docs)} documents[/green]")
    return docs


docs = load_corpus(DATA_DIR)
len(docs)
"""
    )
)
# Chunking
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 3) Chunking
Simple character-based chunker with overlap. Swap in a token-based chunker if you prefer.
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
@dataclass
class Chunk:
    id: str  # random UUID4 assigned at chunking time (differs between runs)
    doc_path: str  # key of the source document in `docs`
    start: int  # window start offset into the source text (before whitespace stripping)
    end: int  # window end offset (exclusive)
    text: str  # the stripped window text
    sha1: str  # SHA-1 hex digest of `text`


def chunk_text(text: str, doc_path: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP, min_chars: int = MIN_CHARS) -> List[Chunk]:
    # Split `text` into fixed-size character windows, each overlapping the previous one by
    # `overlap` characters. Windows whose stripped text is shorter than `min_chars` are
    # skipped, so very short documents (or a short final tail) yield no chunk.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        # The window would never advance, i.e. the loop would never terminate.
        raise ValueError(f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})")
    chunks: List[Chunk] = []
    n = len(text)
    for start in range(0, n, chunk_size - overlap):
        end = min(start + chunk_size, n)
        piece = text[start:end].strip()
        if len(piece) >= min_chars:
            chunks.append(Chunk(
                id=str(uuid.uuid4()),
                doc_path=doc_path,
                start=start,
                end=end,
                text=piece,
                sha1=hashlib.sha1(piece.encode("utf-8")).hexdigest(),
            ))
        if end == n:
            break  # this window already reached the end of the text
    return chunks


all_chunks: List[Chunk] = []
for path, text in docs.items():
    all_chunks.extend(chunk_text(text, path))
print(f"[green]Total chunks: {len(all_chunks)}[/green]")
"""
    )
)
# Embeddings via Ollama
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 4) Embeddings via Ollama
Uses Ollama's `POST /api/embeddings` endpoint with your selected embedding model.
Make sure you've pulled it locally: `ollama pull nomic-embed-text` (or your chosen model).
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
EMBED_ENDPOINT = f"{OLLAMA_URL}/api/embeddings"


def embed_texts(texts: List[str], model: str = EMBED_MODEL, batch_size: int = 32, timeout: float = 120.0) -> np.ndarray:
    # Embed each text with Ollama and stack the vectors into a (len(texts), dim) float32 matrix.
    # This endpoint takes a single "prompt" per request, so texts are sent one at a time;
    # `batch_size` only groups them. `timeout` (seconds per request) keeps a stalled server
    # from hanging the notebook indefinitely.
    vectors = []
    with requests.Session() as session:  # reuse one HTTP connection for all requests
        for i in range(0, len(texts), batch_size):
            for t in texts[i:i + batch_size]:
                r = session.post(EMBED_ENDPOINT, json={"model": model, "prompt": t}, timeout=timeout)
                r.raise_for_status()
                vec = np.asarray(r.json().get("embedding") or [], dtype=np.float32)
                if vec.size == 0:
                    raise RuntimeError(f"Ollama returned an empty embedding for model {model!r}; check that it is an embedding model.")
                vectors.append(vec)
    # 768 is only a placeholder width for the no-texts case.
    return np.vstack(vectors) if vectors else np.zeros((0, 768), dtype=np.float32)


chunk_texts = [c.text for c in all_chunks]
emb_matrix = embed_texts(chunk_texts, model=EMBED_MODEL, batch_size=8)
emb_matrix.shape
"""
    )
)
# Build FAISS
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 5) Build Vector Index (FAISS)
We normalize vectors and use inner product (equivalent to cosine on normalized vectors).
If `faiss` is not installed, retrieval falls back to brute-force NumPy search.
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
def normalize_rows(x: np.ndarray) -> np.ndarray:
    # L2-normalize each row; the epsilon avoids division by zero for all-zero rows.
    norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
    return x / norms


# Row i of xb corresponds to all_chunks[i]; FAISS needs contiguous float32 data.
xb = np.ascontiguousarray(normalize_rows(emb_matrix), dtype=np.float32)
index = None
if USE_FAISS:
    try:
        import faiss
    except ImportError:
        print("[yellow]faiss is not installed; falling back to brute-force NumPy search.[/yellow]")
        USE_FAISS = False
if USE_FAISS:
    index = faiss.IndexFlatIP(xb.shape[1])
    index.add(xb)
    print("[green]FAISS index built:[/green]", index.ntotal, "vectors")
"""
    )
)
# Retrieval helper
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 6) Retrieval Helper
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
def search(query: str, top_k: int = TOP_K) -> List[Tuple[int, float]]:
    # Return up to `top_k` (chunk index, cosine similarity) pairs for `query`, best first.
    if top_k <= 0 or len(xb) == 0:
        return []
    qv = normalize_rows(embed_texts([query], model=EMBED_MODEL, batch_size=1)).astype(np.float32)
    if USE_FAISS and index is not None:
        D, I = index.search(qv, min(top_k, index.ntotal))
        # FAISS pads missing results with label -1; drop them so callers never
        # end up silently indexing all_chunks[-1].
        return [(int(i), float(d)) for i, d in zip(I[0], D[0]) if i >= 0]
    sims = (xb @ qv.T).ravel()
    order = np.argsort(-sims)[:top_k]
    return [(int(i), float(sims[i])) for i in order]


# quick smoke test (no error means it's wired up)
# print(search("What does this corpus talk about?", 3))
"""
    )
)
# Generation via Ollama
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 7) Synthesize Grounded Q&A / Instructions with Ollama
We sample chunks, retrieve neighbors for richer context, and prompt a local LLM to create **high-quality** pairs.
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
GEN_ENDPOINT = f"{OLLAMA_URL}/api/generate"

SYSTEM_PROMPT = (
    "You are a careful dataset writer. Given only the provided CONTEXT, craft high-quality, factual "
    "question-answer pairs for supervised fine-tuning. Answers must be grounded strictly in the context. "
    "If the context lacks the answer, say 'INSUFFICIENT_CONTEXT'. Focus on clarity, specificity, and avoid hallucinations."
)

USER_PROMPT_TEMPLATE = (
    "CONTEXT:\\n\\n{context}\\n\\n"
    "Task: Produce {n} diverse Q&A pairs about the content above. "
    "Use JSON lines (one JSON object per line) with keys: 'input' (question/instruction), 'output' (concise grounded answer), "
    "'meta' (object with 'source_path', 'chunk_ids', and optional 'citations': list of quotes). "
    "Do NOT include markdown; output JSON objects only."
)


def ollama_generate(prompt: str, model: str = GEN_MODEL, temperature: float = TEMPERATURE, num_predict: int = MAX_TOKENS_GEN, timeout: float = 600.0) -> str:
    # One non-streaming completion with SYSTEM_PROMPT; returns the reply text ("" if the
    # response JSON has no "response" field). `timeout` is in seconds.
    payload = {
        "model": model,
        "prompt": prompt,
        "system": SYSTEM_PROMPT,
        "options": {
            "temperature": temperature,
            "num_predict": num_predict
        },
        "stream": False
    }
    r = requests.post(GEN_ENDPOINT, json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json().get("response", "")


def build_context(primary_idx: int, k: int = TOP_K) -> Tuple[str, List[str]]:
    # Build the prompt context for one sampled chunk: the primary ("oracle") chunk first,
    # then up to k-1 retrieved neighbours. Returns (context text, chunk ids in context order).
    primary_chunk = all_chunks[primary_idx]
    query = primary_chunk.text[:400]  # the chunk's opening text acts as a pseudo-query
    # Retrieval is not guaranteed to return the primary chunk itself, so it is always
    # included explicitly; otherwise the generated pairs could lose their grounding chunk.
    selected = [primary_idx]
    for i, _score in search(query, k):
        if len(selected) >= k:
            break
        if i != primary_idx:
            selected.append(i)
    pieces, ids = [], []
    for i in selected:
        ch = all_chunks[i]
        ids.append(ch.id)
        pieces.append(f"[{Path(ch.doc_path).name}::{ch.start}-{ch.end}]\\n{ch.text}")
    return "\\n\\n---\\n\\n".join(pieces), ids


_TRAILING_COMMA_OBJ = regex.compile(r",\\s*}")
_TRAILING_COMMA_ARR = regex.compile(r",\\s*]")


def _loads_lenient(s: str) -> Any:
    # json.loads, retrying once with trailing commas before } or ] removed; None if still invalid.
    try:
        return json.loads(s)
    except ValueError:
        pass
    try:
        return json.loads(_TRAILING_COMMA_ARR.sub("]", _TRAILING_COMMA_OBJ.sub("}", s)))
    except ValueError:
        return None


def parse_llm_jsonl(text: str) -> List[Dict[str, Any]]:
    # Extract JSON objects from a model reply. Accepts a single JSON array/object (some models
    # ignore the JSONL instruction) or one object per line; unparseable lines (e.g. ``` fences)
    # are skipped.
    whole = _loads_lenient(text.strip())
    if isinstance(whole, list):
        return [obj for obj in whole if isinstance(obj, dict)]
    if isinstance(whole, dict):
        return [whole]
    rows = []
    for line in text.splitlines():
        line = line.strip().rstrip(",")  # tolerate array-style separators between objects
        if not line:
            continue
        obj = _loads_lenient(line)
        if isinstance(obj, dict):
            rows.append(obj)
    return rows
"""
    )
)
# Sampling and synthesis loop
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 8) Generate the RAFT Dataset
This step iterates over documents, samples chunks, retrieves neighbors, and asks the model to produce JSONL rows.
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
from datetime import datetime, timezone


def synthesize_dataset(samples_per_doc: int = SAMPLES_PER_DOC, out_path: Path = OUTPUT_DIR / "raft_dataset.jsonl", pairs_per_chunk: int = 3) -> Path:
    # For each document, sample up to `samples_per_doc` chunks (seeded), build a retrieval
    # context around each, ask the model for `pairs_per_chunk` Q&A pairs, and write every
    # row that has both an input and an output to `out_path` (overwritten). Returns `out_path`.
    rng = random.Random(SEED)
    doc_to_chunk_idx: Dict[str, List[int]] = {}
    for i, ch in enumerate(all_chunks):
        doc_to_chunk_idx.setdefault(ch.doc_path, []).append(i)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with out_path.open("w", encoding="utf-8") as f:
        for doc_path, idxs in doc_to_chunk_idx.items():
            chosen = rng.sample(idxs, max(0, min(samples_per_doc, len(idxs))))
            for pi in chosen:
                try:
                    ctx, ids = build_context(pi, k=TOP_K)
                    user = USER_PROMPT_TEMPLATE.format(context=ctx, n=pairs_per_chunk)
                    raw = ollama_generate(user, model=GEN_MODEL, temperature=TEMPERATURE, num_predict=MAX_TOKENS_GEN)
                except requests.RequestException as e:
                    # One failed request (timeout, server error) should not abort the whole run.
                    print(f"[red]Skipping chunk {pi} of {doc_path}: {e}[/red]")
                    continue
                for row in parse_llm_jsonl(raw):
                    # Accept the alternative key names models commonly use.
                    inp = row.get("input") or row.get("question") or row.get("query")
                    out = row.get("output") or row.get("answer") or row.get("response")
                    if not (inp and out):
                        continue
                    meta = row.get("meta")
                    if not isinstance(meta, dict):
                        meta = {}
                    # Authoritative provenance overrides whatever the model wrote into these keys.
                    meta.update({
                        "source_path": str(doc_path),
                        "chunk_ids": ids,
                        # Same "YYYY-MM-DDTHH:MM:SS.ffffffZ" format as the deprecated utcnow() variant.
                        "generated_at": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
                        "model": GEN_MODEL,
                        "embed_model": EMBED_MODEL
                    })
                    f.write(json.dumps({"input": inp, "output": out, "meta": meta}, ensure_ascii=False) + "\\n")
                    written += 1
    print(f"[green]Wrote {written} rows -> {out_path}[/green]")
    return out_path


OUT_JSONL = synthesize_dataset(samples_per_doc=SAMPLES_PER_DOC)
OUT_JSONL
"""
    )
)
# Sanity check / preview
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 9) Preview Samples
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
from itertools import islice


def head_jsonl(p: Path, n: int = 5):
    # Print the first `n` lines of a JSONL file verbatim.
    with p.open("r", encoding="utf-8") as f:
        for line in islice(f, n):
            print(line.rstrip())


head_jsonl(OUT_JSONL, 5)
"""
    )
)
# Optional: small eval
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 10) Optional: Spot-Check Generation Quality
Run a tiny evaluation by asking the model with and without retrieval and compare answers.
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
# Collect the first few questions from the generated dataset (malformed lines are skipped).
EVAL_QUESTIONS = []
with OUT_JSONL.open("r", encoding="utf-8") as f:
    for line in f:
        try:
            EVAL_QUESTIONS.append(json.loads(line)["input"])
        except (ValueError, KeyError, TypeError):
            continue
        if len(EVAL_QUESTIONS) >= 5:
            break


def rag_answer(q: str, k: int = TOP_K) -> str:
    # Answer `q` using only the top-k retrieved chunks as context.
    hits = search(q, k)
    ctx = "\\n\\n".join(all_chunks[i].text for i, _ in hits)
    user = f"Answer the question using ONLY this context. If missing, say INSUFFICIENT_CONTEXT.\\n\\nCONTEXT:\\n{ctx}\\n\\nQUESTION: {q}"
    return ollama_generate(user, model=GEN_MODEL, temperature=0.2, num_predict=256)


def closed_book_answer(q: str) -> str:
    # Baseline: the same question without any retrieved context.
    user = f"Answer the question. If you do not know, say INSUFFICIENT_CONTEXT.\\n\\nQUESTION: {q}"
    return ollama_generate(user, model=GEN_MODEL, temperature=0.2, num_predict=256)


def _preview(text: str, limit: int = 500) -> str:
    # Strip and truncate for display; the ellipsis marks real truncation only.
    text = text.strip()
    return text if len(text) <= limit else text[:limit] + " ..."


for q in EVAL_QUESTIONS:
    print("\\n[bold]Q:[/bold]", q)
    print("[bold]A (with retrieval):[/bold]", _preview(rag_answer(q)))
    print("[bold]A (no retrieval):[/bold]", _preview(closed_book_answer(q)))
"""
    )
)
# Save artifacts list
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 11) Artifacts
- `outputs/raft_dataset.jsonl` — your RAFT dataset (input/output/meta per line)
- `outputs/chunks.jsonl` — chunk metadata; line *i* describes row *i* of the embeddings and resolves `meta.chunk_ids`
- `corpus/` — your source documents (you provide)
- You can also persist `emb_matrix.npy` and a FAISS index for reuse (next cell).
"""
    )
)
cells.append(
    nbf.v4.new_code_cell(
        """
# Persist embeddings, chunk metadata and (if built) the FAISS index for later reuse.
np.save(OUTPUT_DIR / "emb_matrix.npy", emb_matrix)
# Chunk ids are random per run, so they must be saved with the embeddings: line i of
# chunks.jsonl describes row i of emb_matrix and resolves the dataset's meta.chunk_ids.
with (OUTPUT_DIR / "chunks.jsonl").open("w", encoding="utf-8") as f:
    for ch in all_chunks:
        f.write(json.dumps(asdict(ch), ensure_ascii=False) + "\\n")
if USE_FAISS and index is not None:
    import faiss
    faiss.write_index(index, str(OUTPUT_DIR / "faiss.index"))
    print("[green]Saved FAISS index, embeddings and chunk metadata.[/green]")
else:
    print("[yellow]FAISS disabled; saved embeddings and chunk metadata only.[/yellow]")
"""
    )
)
# Troubleshooting
cells.append(
    nbf.v4.new_markdown_cell(
        """
## 12) Troubleshooting
- **Connection error to Ollama**: ensure `ollama serve` is running and models are pulled (`ollama pull nomic-embed-text`, `ollama pull llama3.1`).
- **Empty dataset**: your corpus may be too small or the parser skipped files. Check `corpus/` content and chunk parameters; documents shorter than `MIN_CHARS` produce no chunks.
- **Hallucinations**: tighten the system prompt, lower temperature, or increase `TOP_K` and chunk size.
- **JSON parsing issues**: the notebook tries to be forgiving; you can harden `parse_llm_jsonl` per your needs.
- **PDFs**: `pip install pypdf` and try again.
- **HTML files**: `pip install beautifulsoup4 lxml` and try again.
"""
    )
)
# Save the notebook: attach the accumulated cells and write nbformat v4 JSON.
nb["cells"] = cells
out_path = "raft_ollama_dataset.ipynb"  # relative to the current working directory
with open(out_path, "w", encoding="utf-8") as f:
    nbf.write(nb, f)
# A bare `out_path` expression only displays in a notebook; report it explicitly when run as a script.
print(f"Notebook written to {out_path}")