# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.18.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---
# %% [markdown]
# # Retrieval-Augmented **Fine-Tuning** (RAFT) Dataset Generation with **Local Ollama**
#
# This notebook builds a **supervised fine-tuning dataset** (JSONL) for _retrieval-augmented_ tasks, by:
#
# 1. **Ingesting** your local corpus (Markdown, text, HTML; PDFs optional with extra deps).
# 2. **Chunking** and **embedding** documents using Ollama's local **embedding model** (e.g., `nomic-embed-text`, `mxbai-embed-large`).
# 3. Building a **lightweight vector index** (FAISS).
# 4. **Sampling contexts** and using a local **generation model** via Ollama (e.g., `llama3.1`, `qwen2`, `phi3`) to synthesize **grounded Q&A** or instruction-response pairs.
# 5. Emitting a **RAFT-style JSONL** for supervised training (e.g., `input`, `output`, `meta` with source citations).
#
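# For orientation, one emitted line looks roughly like this (the field values here are
# purely illustrative and not taken from any real corpus):
#
# ```json
# {"input": "What is component X responsible for?", "output": "According to the context, X handles ...", "meta": {"source_path": "corpus/example.md", "chunk_ids": ["<uuid>"], "citations": ["..."]}}
# ```
#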
# > **Requirements**
# >
# > - Local [Ollama](https://ollama.com/) running at `http://localhost:11434`
# > - At least one **embedding** model pulled (e.g., `ollama pull nomic-embed-text`)
# > - At least one **generation** model pulled (e.g., `ollama pull llama3.1`)
# >
# > You can adapt the prompts and schema for your specific downstream trainer (Llama.cpp, vLLM, Axolotl, mlx, etc.).
#
# %% [markdown]
# ## 0) Setup
#
# Install Python dependencies. If you're offline, pre-install or remove what you don't need.
#
# %%
# If needed, uncomment:
# # %pip install --quiet requests faiss-cpu rich markdownify python-frontmatter pypdf regex
# Optional extras:
# # %pip install --quiet tiktoken beautifulsoup4 lxml
# %% [markdown]
# ## 1) Configuration
#
# Set paths, models, and chunking/index params.
#
# %%
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Any, Tuple
import os, json, uuid, random
import hashlib
import requests
from rich import print
import regex
import numpy as np
# ---- Core config ----
DATA_DIR = Path("./corpus") # Put your source docs here
OUTPUT_DIR = Path("./outputs") # Where artifacts are saved
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Ollama endpoints & models
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text")
GEN_MODEL = os.environ.get("GEN_MODEL", "llama3.1:8b")
# Chunking
CHUNK_SIZE = 1200 # characters
CHUNK_OVERLAP = 200 # characters
MIN_CHARS = 200 # minimum viable chunk length
# Index
USE_FAISS = True
TOP_K = 4
# RAFT generation
SEED = 7
SAMPLES_PER_DOC = 4
MAX_TOKENS_GEN = 512 # Generation max tokens (approx; Ollama supports 'num_predict')
TEMPERATURE = 0.6
random.seed(SEED)
np.random.seed(SEED)
print(
    {
        "DATA_DIR": str(DATA_DIR.resolve()),
        "OUTPUT_DIR": str(OUTPUT_DIR.resolve()),
        "OLLAMA_URL": OLLAMA_URL,
        "EMBED_MODEL": EMBED_MODEL,
        "GEN_MODEL": GEN_MODEL,
        "CHUNK_SIZE": CHUNK_SIZE,
        "CHUNK_OVERLAP": CHUNK_OVERLAP,
        "TOP_K": TOP_K,
        "SAMPLES_PER_DOC": SAMPLES_PER_DOC,
    }
)
# %% [markdown]
# ## 2) Load & Normalize Documents
#
# Basic loaders for `.md`, `.txt`, `.html`. PDF support is optional (requires `pypdf`). You can extend as needed.
#
# %%
from bs4 import BeautifulSoup # if you didn't install bs4, comment HTML support below
try:
    import frontmatter
except Exception:
    frontmatter = None


def read_text_file(p: Path) -> str:
    return p.read_text(encoding="utf-8", errors="ignore")


def read_markdown(p: Path) -> str:
    text = p.read_text(encoding="utf-8", errors="ignore")
    # Optional: strip YAML frontmatter
    if frontmatter:
        try:
            fm = frontmatter.loads(text)
            return fm.content
        except Exception:
            return text
    return text


def read_html(p: Path) -> str:
    html = p.read_text(encoding="utf-8", errors="ignore")
    soup = BeautifulSoup(html, "lxml")
    # Remove script/style
    for tag in soup(["script", "style", "noscript"]):
        tag.decompose()
    text = soup.get_text(" ", strip=True)
    return text


def read_pdf(p: Path) -> str:
    try:
        from pypdf import PdfReader
    except Exception as e:
        print(
            "[yellow]Install pypdf to enable PDF parsing: %pip install pypdf[/yellow]"
        )
        raise e
    reader = PdfReader(str(p))
    parts = []
    for page in reader.pages:
        try:
            parts.append(page.extract_text() or "")
        except Exception:
            parts.append("")
    return "\n".join(parts)


SUPPORTED_EXTS = {
    ".txt": read_text_file,
    ".md": read_markdown,
    ".markdown": read_markdown,
    ".html": read_html,
    ".htm": read_html,
    ".pdf": read_pdf,
}
def load_corpus(data_dir: Path) -> Dict[str, str]:
    docs = {}
    for p in data_dir.rglob("*"):
        if not p.is_file():
            continue
        fn = p.suffix.lower()
        if fn in SUPPORTED_EXTS:
            try:
                docs[str(p)] = SUPPORTED_EXTS[fn](p)
            except Exception as e:
                print(f"[red]Failed to read {p}: {e}[/red]")
    print(f"[green]Loaded {len(docs)} documents[/green]")
    return docs
docs = load_corpus(DATA_DIR)
len(docs)
# %% [markdown]
# ## 3) Chunking
#
# Simple character-based chunker with overlap. Swap in a token-based chunker if you prefer.
#
# %%
@dataclass
class Chunk:
    id: str
    doc_path: str
    start: int
    end: int
    text: str
    sha1: str


def chunk_text(
    text: str, doc_path: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP
) -> List[Chunk]:
    chunks: List[Chunk] = []
    i = 0
    n = len(text)
    while i < n:
        j = min(i + chunk_size, n)
        piece = text[i:j].strip()
        if len(piece) >= MIN_CHARS:
            sha1 = hashlib.sha1(piece.encode("utf-8")).hexdigest()
            chunks.append(
                Chunk(
                    id=str(uuid.uuid4()),
                    doc_path=doc_path,
                    start=i,
                    end=j,
                    text=piece,
                    sha1=sha1,
                )
            )
        if j == n:
            break
        i = j - overlap
        if i < 0:
            i = 0
        if i >= n:
            break
    return chunks
all_chunks: List[Chunk] = []
for path, text in docs.items():
    all_chunks.extend(chunk_text(text, path))
print(f"[green]Total chunks: {len(all_chunks)}[/green]")
# %% [markdown]
# ## 4) Embeddings via Ollama
#
# Uses Ollama's `POST /api/embeddings` endpoint with your selected embedding model.
# Make sure you've pulled it locally: `ollama pull nomic-embed-text` (or your chosen model).
#
# %%
EMBED_ENDPOINT = f"{OLLAMA_URL}/api/embeddings"
def embed_texts(
    texts: List[str], model: str = EMBED_MODEL, batch_size: int = 32
) -> np.ndarray:
    vectors = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        # The /api/embeddings endpoint takes a single prompt per request, so we embed
        # one text at a time; batch_size only controls the outer slicing.
        for t in batch:
            r = requests.post(EMBED_ENDPOINT, json={"model": model, "prompt": t})
            r.raise_for_status()
            data = r.json()
            vec = np.array(data["embedding"], dtype=np.float32)
            vectors.append(vec)
    return np.vstack(vectors) if vectors else np.zeros((0, 768), dtype=np.float32)
chunk_texts = [c.text for c in all_chunks]
emb_matrix = embed_texts(chunk_texts, model=EMBED_MODEL, batch_size=8)
emb_matrix.shape
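# %% [markdown]
# Newer Ollama releases also ship a batched embedding endpoint (`POST /api/embed`)
# that accepts a list of texts under `input` and returns an `embeddings` array.
# If your local Ollama version supports it, a drop-in variant of `embed_texts`
# could look like the sketch below; it is not used by the rest of the notebook.
#
# %%
def embed_texts_batched(
    texts: List[str], model: str = EMBED_MODEL, batch_size: int = 32
) -> np.ndarray:
    """Batched embedding via /api/embed (requires a recent Ollama version)."""
    vectors: List[np.ndarray] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        r = requests.post(
            f"{OLLAMA_URL}/api/embed", json={"model": model, "input": batch}
        )
        r.raise_for_status()
        for emb in r.json()["embeddings"]:
            vectors.append(np.array(emb, dtype=np.float32))
    return np.vstack(vectors) if vectors else np.zeros((0, 768), dtype=np.float32)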
# %% [markdown]
# ## 5) Build Vector Index (FAISS)
#
# We normalize vectors and use inner product (equivalent to cosine on normalized vectors).
#
# %%
def normalize_rows(x: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
    return x / norms


if USE_FAISS:
    import faiss

    xb = normalize_rows(emb_matrix).astype(np.float32)
    d = xb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(xb)
    print("[green]FAISS index built:[/green]", index.ntotal, "vectors")
else:
    index = None
    xb = normalize_rows(emb_matrix).astype(np.float32)
# %% [markdown]
# ## 6) Retrieval Helper
#
# %%
def search(query: str, top_k: int = TOP_K) -> List[Tuple[int, float]]:
    # Embed the query
    qv = embed_texts([query], model=EMBED_MODEL, batch_size=1)
    qv = normalize_rows(qv).astype(np.float32)
    if USE_FAISS and index is not None:
        D, I = index.search(qv, top_k)
        hits = list(zip(I[0].tolist(), D[0].tolist()))
    else:
        sims = (xb @ qv.T).ravel()
        I = np.argsort(-sims)[:top_k]
        hits = [(int(i), float(sims[i])) for i in I]
    return hits
# quick smoke test (no error means it's wired up)
print(search("What does this corpus talk about?", 3))
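# %% [markdown]
# The raw hits are `(chunk_index, score)` tuples; mapping them back to `all_chunks`
# shows the source file and a short text preview for each match:
#
# %%
for idx, score in search("What does this corpus talk about?", 3):
    ch = all_chunks[idx]
    print(f"{score:.3f}  {Path(ch.doc_path).name}  {ch.text[:80]!r}")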
# %% [markdown]
# ## 7) Synthesize Grounded Q&A / Instructions with Ollama
#
# We sample chunks, retrieve neighbors for richer context, and prompt a local LLM to create **high-quality** pairs.
#
# %%
GEN_ENDPOINT = f"{OLLAMA_URL}/api/generate"
SYSTEM_PROMPT = (
    "You are a careful dataset writer. Given only the provided CONTEXT, craft high-quality, factual "
    "question-answer pairs for supervised fine-tuning. Answers must be grounded strictly in the context. "
    "If the context lacks the answer, say 'INSUFFICIENT_CONTEXT'. Focus on clarity and specificity, and avoid hallucinations."
)
USER_PROMPT_TEMPLATE = (
    "CONTEXT:\n\n{context}\n\n"
    "Task: Produce {n} diverse Q&A pairs about the content above. "
    "Use JSON lines (one JSON object per line) with keys: 'input' (question/instruction), 'output' (concise grounded answer), "
    "'meta' (object with 'source_path', 'chunk_ids', and optional 'citations': list of quotes). "
    "Do NOT include markdown; output JSON objects only."
)
def ollama_generate(
    prompt: str,
    model: str = GEN_MODEL,
    temperature: float = TEMPERATURE,
    num_predict: int = MAX_TOKENS_GEN,
) -> str:
    payload = {
        "model": model,
        "prompt": prompt,
        "system": SYSTEM_PROMPT,
        "options": {"temperature": temperature, "num_predict": num_predict},
        "stream": False,
    }
    r = requests.post(GEN_ENDPOINT, json=payload)
    r.raise_for_status()
    data = r.json()
    return data.get("response", "")
def build_context(primary_idx: int, k: int = TOP_K) -> Tuple[str, List[str]]:
    primary_chunk = all_chunks[primary_idx]
    query = primary_chunk.text[:400]  # use the start of the chunk as a pseudo-query
    hits = search(query, k)
    pieces, ids = [], []
    for i, score in hits:
        ch = all_chunks[i]
        ids.append(ch.id)
        pieces.append(f"[{Path(ch.doc_path).name}::{ch.start}-{ch.end}]\n{ch.text}")
    return "\n\n---\n\n".join(pieces), ids
def parse_llm_jsonl(text: str) -> List[Dict[str, Any]]:
    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # be forgiving about trailing commas etc.
        try:
            obj = json.loads(line)
            if isinstance(obj, dict):
                rows.append(obj)
        except Exception:
            # try to salvage JSON-ish lines with a regex cleanup
            try:
                fixed = regex.sub(r",\s*}", "}", line)
                fixed = regex.sub(r",\s*]", "]", fixed)
                obj = json.loads(fixed)
                if isinstance(obj, dict):
                    rows.append(obj)
            except Exception:
                pass
    return rows
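# %% [markdown]
# A quick sanity check of the parser on a well-formed line and on one with a trailing
# comma (the kind of output the regex fallback is meant to salvage); both should be
# returned as dicts:
#
# %%
_demo = (
    '{"input": "Q1?", "output": "A1", "meta": {}}\n'
    '{"input": "Q2?", "output": "A2", "meta": {},}'
)
print(parse_llm_jsonl(_demo))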
# %% [markdown]
# ## 8) Generate the RAFT Dataset
#
# This step iterates over documents, samples chunks, retrieves neighbors, and asks the model to produce JSONL rows.
#
# %%
import datetime
import time
def synthesize_dataset(
    samples_per_doc: int = SAMPLES_PER_DOC,
    out_path: Path = OUTPUT_DIR / "raft_dataset.jsonl",
) -> Path:
    rng = random.Random(SEED)
    doc_to_chunk_idx = {}
    for i, ch in enumerate(all_chunks):
        doc_to_chunk_idx.setdefault(ch.doc_path, []).append(i)
    total_target = 0
    with out_path.open("w", encoding="utf-8") as f:
        for doc_path, idxs in doc_to_chunk_idx.items():
            print(f"[blue]Synthesizing for: {doc_path} ({len(idxs)} chunks)[/blue]")
            doc_idx = list(doc_to_chunk_idx.keys()).index(doc_path)
            total_docs = len(doc_to_chunk_idx)
            percent = (doc_idx + 1) / total_docs * 100
            print(
                f"[cyan]Progress: {doc_idx + 1}/{total_docs} ({percent:.1f}%) completed[/cyan]"
            )
            if not idxs:
                continue
            chosen = rng.sample(idxs, min(samples_per_doc, len(idxs)))
            for pi in chosen:
                ctx, ids = build_context(pi, k=TOP_K)
                user = USER_PROMPT_TEMPLATE.format(context=ctx, n=3)
                raw = ollama_generate(
                    user,
                    model=GEN_MODEL,
                    temperature=TEMPERATURE,
                    num_predict=MAX_TOKENS_GEN,
                )
                rows = parse_llm_jsonl(raw)
                for r in rows:
                    # enforce schema & enrich meta
                    inp = r.get("input") or r.get("question") or r.get("query")
                    out = r.get("output") or r.get("answer") or r.get("response")
                    meta = r.get("meta") or {}
                    if not isinstance(meta, dict):
                        meta = {}
                    meta.update(
                        {
                            "source_path": str(doc_path),
                            "chunk_ids": ids,
                            "generated_at": datetime.datetime.fromtimestamp(
                                time.time()
                            ).strftime("%Y-%m-%d %H:%M:%S"),
                            "model": GEN_MODEL,
                            "embed_model": EMBED_MODEL,
                        }
                    )
                    if inp and out:
                        obj = {"input": inp, "output": out, "meta": meta}
                        f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                        total_target += 1
    print(f"[green]Wrote {total_target} rows -> {out_path}[/green]")
    return out_path
OUT_JSONL = synthesize_dataset(samples_per_doc=SAMPLES_PER_DOC)
OUT_JSONL
# %% [markdown]
# ## 9) Preview Samples
#
# %%
from itertools import islice
def head_jsonl(p: Path, n: int = 5):
    with p.open("r", encoding="utf-8") as f:
        for line in islice(f, n):
            print(line.rstrip())
head_jsonl(OUT_JSONL, 5)
# %% [markdown]
# ## 10) Optional: Spot-Check Generation Quality
#
# Run a tiny evaluation by asking the model with and without retrieval and compare answers.
#
# %%
EVAL_QUESTIONS = []
# Collect inputs from the dataset (first N)
with (OUTPUT_DIR / "raft_dataset.jsonl").open("r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        try:
            obj = json.loads(line)
            EVAL_QUESTIONS.append(obj["input"])
        except Exception:
            pass
        if len(EVAL_QUESTIONS) >= 5:
            break
def rag_answer(q: str, k: int = TOP_K) -> str:
    hits = search(q, k)
    ctx = "\n\n".join([all_chunks[i].text for i, _ in hits])
    user = f"Answer the question using ONLY this context. If missing, say INSUFFICIENT_CONTEXT.\n\nCONTEXT:\n{ctx}\n\nQUESTION: {q}"
    return ollama_generate(user, model=GEN_MODEL, temperature=0.2, num_predict=256)
for q in EVAL_QUESTIONS:
    print("\n[bold]Q:[/bold]", q)
    ans = rag_answer(q)
    print("[bold]A:[/bold]", ans.strip()[:500], "...")
# %% [markdown]
# ## 11) Artifacts
#
# - `outputs/raft_dataset.jsonl` — your RAFT dataset (input/output/meta per line)
# - `corpus/` — your source documents (you provide)
# - You can also persist `emb_matrix.npy` and a FAISS index for reuse.
#
# %%
# Optionally persist embeddings and index for later reuse
np.save(OUTPUT_DIR / "emb_matrix.npy", emb_matrix)
if USE_FAISS:
    import faiss

    faiss.write_index(index, str(OUTPUT_DIR / "faiss.index"))
    print("[green]Saved FAISS index and embeddings.[/green]")
else:
    print("[yellow]FAISS disabled; only saved embeddings.[/yellow]")
# %% [markdown]
# ## 12) Troubleshooting
#
# - **Connection error to Ollama**: ensure `ollama serve` is running and models are pulled (`ollama pull nomic-embed-text`, `ollama pull llama3.1`).
# - **Empty dataset**: your corpus may be too small or the parser skipped files. Check `corpus/` content and chunk parameters.
# - **Hallucinations**: tighten the system prompt, lower temperature, or increase `TOP_K` and chunk size.
# - **JSON parsing issues**: the notebook tries to be forgiving; you can harden `parse_llm_jsonl` per your needs.
# - **PDFs**: `pip install pypdf` and try again.
#