# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.18.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# %% [markdown]
# # Retrieval-Augmented **Fine-Tuning** (RAFT) Dataset Generation with **Local Ollama**
#
# This notebook builds a **supervised fine-tuning dataset** (JSONL) for _retrieval-augmented_ tasks by:
#
# 1. **Ingesting** your local corpus (Markdown, text, HTML; PDFs optional with extra deps).
# 2. **Chunking** and **embedding** documents using Ollama's local **embedding model** (e.g., `nomic-embed-text`, `mxbai-embed-large`).
# 3. Building a **lightweight vector index** (FAISS).
# 4. **Sampling contexts** and using a local **generation model** via Ollama (e.g., `llama3.1`, `qwen2`, `phi3`) to synthesize **grounded Q&A** or instruction–response pairs.
# 5. Emitting a **RAFT-style JSONL** for supervised training (e.g., `input`, `output`, `meta` with source citations); an illustrative row is sketched below.
#
# > **Requirements**
# >
# > - Local [Ollama](https://ollama.com/) running at `http://localhost:11434`
# > - At least one **embedding** model pulled (e.g., `ollama pull nomic-embed-text`)
# > - At least one **generation** model pulled (e.g., `ollama pull llama3.1`)
# >
# > You can adapt the prompts and schema for your specific downstream trainer (Llama.cpp, vLLM, Axolotl, mlx, etc.).
#

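# %% [markdown]
# The row below is purely illustrative (hypothetical file name, IDs, and wording, not generated output); it
# just shows the shape of one dataset line with the `input`/`output`/`meta` keys described in step 5.

# %%
import json

example_row = {
    "input": "What port does the local Ollama server listen on by default?",
    "output": "It listens on http://localhost:11434.",
    "meta": {
        "source_path": "corpus/example.md",  # hypothetical source file
        "chunk_ids": ["00000000-0000-0000-0000-000000000000"],  # placeholder UUID
        "citations": ["Local Ollama running at http://localhost:11434"],
    },
}
print(json.dumps(example_row, ensure_ascii=False, indent=2))
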
# %% [markdown]
# ## 0) Setup
#
# Install Python dependencies. If you're offline, pre-install or remove what you don't need.
#

# %%
# If needed, uncomment:
# # %pip install --quiet requests faiss-cpu rich markdownify python-frontmatter pypdf regex
# Optional extras:
# # %pip install --quiet tiktoken beautifulsoup4 lxml

# %% [markdown]
# ## 1) Configuration
#
# Set paths, models, and chunking/index params.
#

# %%
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Any, Tuple
import os, json, uuid, random
import hashlib
import requests
from rich import print
import regex
import numpy as np

# ---- Core config ----
DATA_DIR = Path("./corpus")  # Put your source docs here
OUTPUT_DIR = Path("./outputs")  # Where artifacts are saved
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Ollama endpoints & models
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text")
GEN_MODEL = os.environ.get("GEN_MODEL", "llama3.1:8b")

# Chunking
CHUNK_SIZE = 1200  # characters
CHUNK_OVERLAP = 200  # characters
MIN_CHARS = 200  # minimum viable chunk length

# Index
USE_FAISS = True
TOP_K = 4

# RAFT generation
SEED = 7
SAMPLES_PER_DOC = 4
MAX_TOKENS_GEN = 512  # Generation max tokens (approx; Ollama supports 'num_predict')
TEMPERATURE = 0.6

random.seed(SEED)
np.random.seed(SEED)

print(
    {
        "DATA_DIR": str(DATA_DIR.resolve()),
        "OUTPUT_DIR": str(OUTPUT_DIR.resolve()),
        "OLLAMA_URL": OLLAMA_URL,
        "EMBED_MODEL": EMBED_MODEL,
        "GEN_MODEL": GEN_MODEL,
        "CHUNK_SIZE": CHUNK_SIZE,
        "CHUNK_OVERLAP": CHUNK_OVERLAP,
        "TOP_K": TOP_K,
        "SAMPLES_PER_DOC": SAMPLES_PER_DOC,
    }
)

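# %% [markdown]
# Optional sanity check (a small sketch, assuming the standard `GET /api/tags` endpoint of a locally
# running Ollama): list the pulled models so a missing `EMBED_MODEL` or `GEN_MODEL` is caught here rather
# than as a failed request during embedding or generation.

# %%
try:
    tags_resp = requests.get(f"{OLLAMA_URL}/api/tags", timeout=5)
    tags_resp.raise_for_status()
    available_models = [m.get("name", "") for m in tags_resp.json().get("models", [])]
    print("[green]Ollama reachable at[/green]", OLLAMA_URL, available_models)
    for needed in (EMBED_MODEL, GEN_MODEL):
        if not any(name.startswith(needed.split(":")[0]) for name in available_models):
            print(f"[yellow]Model '{needed}' not found locally; try: ollama pull {needed}[/yellow]")
except requests.RequestException as e:
    print(f"[red]Could not reach Ollama at {OLLAMA_URL}:[/red]", e)
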
# %% [markdown]
# ## 2) Load & Normalize Documents
#
# Basic loaders for `.md`, `.txt`, `.html`. PDF support is optional (requires `pypdf`). You can extend as needed.
#

# %%
from bs4 import BeautifulSoup  # if you didn't install bs4, comment HTML support below

try:
    import frontmatter
except Exception:
    frontmatter = None


def read_text_file(p: Path) -> str:
    return p.read_text(encoding="utf-8", errors="ignore")


def read_markdown(p: Path) -> str:
    text = p.read_text(encoding="utf-8", errors="ignore")
    # Optional: strip YAML frontmatter
    if frontmatter:
        try:
            fm = frontmatter.loads(text)
            return fm.content
        except Exception:
            return text
    return text


def read_html(p: Path) -> str:
    html = p.read_text(encoding="utf-8", errors="ignore")
    soup = BeautifulSoup(html, "lxml")
    # Remove script/style
    for tag in soup(["script", "style", "noscript"]):
        tag.decompose()
    text = soup.get_text(" ", strip=True)
    return text


def read_pdf(p: Path) -> str:
    try:
        from pypdf import PdfReader
    except Exception as e:
        print(
            "[yellow]Install pypdf to enable PDF parsing: %pip install pypdf[/yellow]"
        )
        raise e
    reader = PdfReader(str(p))
    parts = []
    for page in reader.pages:
        try:
            parts.append(page.extract_text() or "")
        except Exception:
            parts.append("")
    return "\n".join(parts)


SUPPORTED_EXTS = {
    ".txt": read_text_file,
    ".md": read_markdown,
    ".markdown": read_markdown,
    ".html": read_html,
    ".htm": read_html,
    ".pdf": read_pdf,
}


def load_corpus(data_dir: Path) -> Dict[str, str]:
    docs = {}
    for p in data_dir.rglob("*"):
        if not p.is_file():
            continue
        fn = p.suffix.lower()
        if fn in SUPPORTED_EXTS:
            try:
                docs[str(p)] = SUPPORTED_EXTS[fn](p)
            except Exception as e:
                print(f"[red]Failed to read {p}: {e}[/red]")
    print(f"[green]Loaded {len(docs)} documents[/green]")
    return docs


docs = load_corpus(DATA_DIR)
len(docs)

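# %% [markdown]
# Extending the loaders: a minimal sketch of registering an extra format. The `.csv` reader below
# (stdlib `csv` only) is a hypothetical example, not part of the original pipeline; any format works as
# long as the function maps a `Path` to plain text and is added to `SUPPORTED_EXTS`.

# %%
import csv


def read_csv_as_text(p: Path) -> str:
    # Flatten each row into one line of comma-separated cells.
    with p.open("r", encoding="utf-8", errors="ignore", newline="") as fh:
        return "\n".join(", ".join(row) for row in csv.reader(fh))


SUPPORTED_EXTS[".csv"] = read_csv_as_text
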
# %% [markdown]
# ## 3) Chunking
#
# Simple character-based chunker with overlap. Swap in a token-based chunker if you prefer (a rough sketch
# follows the cell below).
#

# %%
@dataclass
class Chunk:
    id: str
    doc_path: str
    start: int
    end: int
    text: str
    sha1: str


def chunk_text(
    text: str, doc_path: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP
) -> List[Chunk]:
    chunks: List[Chunk] = []
    i = 0
    n = len(text)
    while i < n:
        j = min(i + chunk_size, n)
        piece = text[i:j].strip()
        if len(piece) >= MIN_CHARS:
            sha1 = hashlib.sha1(piece.encode("utf-8")).hexdigest()
            chunks.append(
                Chunk(
                    id=str(uuid.uuid4()),
                    doc_path=doc_path,
                    start=i,
                    end=j,
                    text=piece,
                    sha1=sha1,
                )
            )
        if j == n:
            break
        i = j - overlap
        if i < 0:
            i = 0
        if i >= n:
            break
    return chunks


all_chunks: List[Chunk] = []
for path, text in docs.items():
    all_chunks.extend(chunk_text(text, path))

print(f"[green]Total chunks: {len(all_chunks)}[/green]")

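# %% [markdown]
# Token-based alternative: a minimal sketch using `tiktoken` (listed under the optional extras above). Its
# `cl100k_base` vocabulary does not match the Ollama models' own tokenizers, so the counts are only an
# approximation, but it keeps chunk lengths closer to what the model actually sees than raw characters do.

# %%
def chunk_text_by_tokens(
    text: str, doc_path: str, max_tokens: int = 300, overlap_tokens: int = 50
) -> List[Chunk]:
    import tiktoken  # optional dependency; %pip install tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    ids = enc.encode(text)
    chunks: List[Chunk] = []
    i = 0
    while i < len(ids):
        j = min(i + max_tokens, len(ids))
        piece = enc.decode(ids[i:j]).strip()
        if len(piece) >= MIN_CHARS:
            chunks.append(
                Chunk(
                    id=str(uuid.uuid4()),
                    doc_path=doc_path,
                    start=i,  # token offsets here, not character offsets
                    end=j,
                    text=piece,
                    sha1=hashlib.sha1(piece.encode("utf-8")).hexdigest(),
                )
            )
        if j == len(ids):
            break
        i = j - overlap_tokens
    return chunks
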
# %% [markdown]
# ## 4) Embeddings via Ollama
#
# Uses Ollama's `POST /api/embeddings` endpoint with your selected embedding model.
# Make sure you've pulled it locally: `ollama pull nomic-embed-text` (or your chosen model).
# A batched alternative is sketched after this cell.
#

# %%
EMBED_ENDPOINT = f"{OLLAMA_URL}/api/embeddings"


def embed_texts(
    texts: List[str], model: str = EMBED_MODEL, batch_size: int = 32
) -> np.ndarray:
    vectors = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        # The /api/embeddings endpoint takes a single 'prompt' per request, so we embed one text at a
        # time; batch_size only controls how the loop is grouped, not the request payload.
        for t in batch:
            r = requests.post(EMBED_ENDPOINT, json={"model": model, "prompt": t})
            r.raise_for_status()
            data = r.json()
            vec = np.array(data["embedding"], dtype=np.float32)
            vectors.append(vec)
    # Fallback dimension 768 matches nomic-embed-text; adjust for other embedding models.
    return np.vstack(vectors) if vectors else np.zeros((0, 768), dtype=np.float32)


chunk_texts = [c.text for c in all_chunks]
emb_matrix = embed_texts(chunk_texts, model=EMBED_MODEL, batch_size=8)
emb_matrix.shape

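# %% [markdown]
# Optional speed-up (a sketch, assuming a newer Ollama build that exposes the batched `POST /api/embed`
# endpoint, which accepts a list of strings under `input` and returns `embeddings`): embed several chunks
# per request instead of one call per text. If your Ollama version lacks this endpoint, keep using
# `embed_texts` above.

# %%
def embed_texts_batched(
    texts: List[str], model: str = EMBED_MODEL, batch_size: int = 32
) -> np.ndarray:
    vectors: List[List[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        r = requests.post(f"{OLLAMA_URL}/api/embed", json={"model": model, "input": batch})
        r.raise_for_status()
        vectors.extend(r.json()["embeddings"])
    if not vectors:
        return np.zeros((0, 768), dtype=np.float32)  # same fallback dimension as embed_texts
    return np.array(vectors, dtype=np.float32)
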
# %% [markdown]
# ## 5) Build Vector Index (FAISS)
#
# We normalize vectors and use inner product (equivalent to cosine on normalized vectors).
#

# %%
def normalize_rows(x: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
    return x / norms


if USE_FAISS:
    import faiss

    xb = normalize_rows(emb_matrix).astype(np.float32)
    d = xb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(xb)
    print("[green]FAISS index built:[/green]", index.ntotal, "vectors")
else:
    index = None
    xb = normalize_rows(emb_matrix).astype(np.float32)

# %% [markdown]
# ## 6) Retrieval Helper
#

# %%
def search(query: str, top_k: int = TOP_K) -> List[Tuple[int, float]]:
    # Embed the query
    qv = embed_texts([query], model=EMBED_MODEL, batch_size=1)
    qv = normalize_rows(qv).astype(np.float32)
    if USE_FAISS and index is not None:
        D, I = index.search(qv, top_k)
        hits = list(zip(I[0].tolist(), D[0].tolist()))
    else:
        sims = (xb @ qv.T).ravel()
        I = np.argsort(-sims)[:top_k]
        hits = [(int(i), float(sims[i])) for i in I]
    return hits


# quick smoke test (no error means it's wired up)
print(search("What does this corpus talk about?", 3))

# %% [markdown]
# ## 7) Synthesize Grounded Q&A / Instructions with Ollama
#
# We sample chunks, retrieve neighbors for richer context, and prompt a local LLM to create **high-quality** pairs.
#

# %%
GEN_ENDPOINT = f"{OLLAMA_URL}/api/generate"

SYSTEM_PROMPT = (
    "You are a careful dataset writer. Given only the provided CONTEXT, craft high-quality, factual "
    "question–answer pairs for supervised fine-tuning. Answers must be grounded strictly in the context. "
    "If the context lacks the answer, say 'INSUFFICIENT_CONTEXT'. Focus on clarity, specificity, and avoid hallucinations."
)

USER_PROMPT_TEMPLATE = (
    "CONTEXT:\n\n{context}\n\n"
    "Task: Produce {n} diverse Q&A pairs about the content above. "
    "Use JSON lines (one JSON object per line) with keys: 'input' (question/instruction), 'output' (concise grounded answer), "
    "'meta' (object with 'source_path', 'chunk_ids', and optional 'citations': list of quotes). "
    "Do NOT include markdown; output JSON objects only."
)


def ollama_generate(
    prompt: str,
    model: str = GEN_MODEL,
    temperature: float = TEMPERATURE,
    num_predict: int = MAX_TOKENS_GEN,
) -> str:
    payload = {
        "model": model,
        "prompt": prompt,
        "system": SYSTEM_PROMPT,
        "options": {"temperature": temperature, "num_predict": num_predict},
        "stream": False,
    }
    r = requests.post(GEN_ENDPOINT, json=payload)
    r.raise_for_status()
    data = r.json()
    return data.get("response", "")


def build_context(primary_idx: int, k: int = TOP_K) -> Tuple[str, List[str]]:
    primary_chunk = all_chunks[primary_idx]
    query = primary_chunk.text[:400]  # use the start of the chunk as a pseudo-query
    hits = search(query, k)
    pieces, ids = [], []
    for i, score in hits:
        ch = all_chunks[i]
        ids.append(ch.id)
        pieces.append(f"[{Path(ch.doc_path).name}::{ch.start}-{ch.end}]\n{ch.text}")
    return "\n\n---\n\n".join(pieces), ids


def parse_llm_jsonl(text: str) -> List[Dict[str, Any]]:
    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # be forgiving for trailing commas etc.
        try:
            obj = json.loads(line)
            if isinstance(obj, dict):
                rows.append(obj)
        except Exception:
            # try to salvage with regex for JSON-ish
            try:
                fixed = regex.sub(r",\s*}", "}", line)
                fixed = regex.sub(r",\s*]", "]", fixed)
                obj = json.loads(fixed)
                if isinstance(obj, dict):
                    rows.append(obj)
            except Exception:
                pass
    return rows

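# %% [markdown]
# Alternative to the lenient JSONL parsing above (a sketch, assuming your Ollama version supports the
# `"format": "json"` option of `/api/generate`): constrain the model to emit a single JSON object with a
# `pairs` array and unpack that instead. The prompt template would need to be adjusted accordingly; this
# variant is not wired into the pipeline below.

# %%
def ollama_generate_json(prompt: str, model: str = GEN_MODEL) -> List[Dict[str, Any]]:
    payload = {
        "model": model,
        "prompt": prompt
        + '\n\nReturn one JSON object of the form {"pairs": [{"input": "...", "output": "...", "meta": {}}]}.',
        "system": SYSTEM_PROMPT,
        "format": "json",  # ask Ollama to constrain the response to valid JSON
        "options": {"temperature": TEMPERATURE, "num_predict": MAX_TOKENS_GEN},
        "stream": False,
    }
    r = requests.post(GEN_ENDPOINT, json=payload)
    r.raise_for_status()
    try:
        obj = json.loads(r.json().get("response", "{}"))
    except json.JSONDecodeError:
        return []
    if not isinstance(obj, dict):
        return []
    return [p for p in obj.get("pairs", []) if isinstance(p, dict)]
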
# %% [markdown]
# ## 8) Generate the RAFT Dataset
#
# This step iterates over documents, samples chunks, retrieves neighbors, and asks the model to produce JSONL rows.
#

# %%
import datetime
import time


def synthesize_dataset(
    samples_per_doc: int = SAMPLES_PER_DOC,
    out_path: Path = OUTPUT_DIR / "raft_dataset.jsonl",
) -> Path:
    rng = random.Random(SEED)
    doc_to_chunk_idx = {}
    for i, ch in enumerate(all_chunks):
        doc_to_chunk_idx.setdefault(ch.doc_path, []).append(i)

    total_target = 0
    total_docs = len(doc_to_chunk_idx)
    with out_path.open("w", encoding="utf-8") as f:
        for doc_idx, (doc_path, idxs) in enumerate(doc_to_chunk_idx.items()):
            print(f"[blue]Synthesizing for: {doc_path} ({len(idxs)} chunks)[/blue]")
            percent = (doc_idx + 1) / total_docs * 100
            print(
                f"[cyan]Progress: {doc_idx + 1}/{total_docs} ({percent:.1f}%) completed[/cyan]"
            )
            if not idxs:
                continue
            chosen = rng.sample(idxs, min(samples_per_doc, len(idxs)))
            for pi in chosen:
                ctx, ids = build_context(pi, k=TOP_K)
                user = USER_PROMPT_TEMPLATE.format(context=ctx, n=3)
                raw = ollama_generate(
                    user,
                    model=GEN_MODEL,
                    temperature=TEMPERATURE,
                    num_predict=MAX_TOKENS_GEN,
                )
                rows = parse_llm_jsonl(raw)
                for r in rows:
                    # enforce schema & enrich meta
                    inp = r.get("input") or r.get("question") or r.get("query")
                    out = r.get("output") or r.get("answer") or r.get("response")
                    meta = r.get("meta") or {}
                    if not isinstance(meta, dict):
                        meta = {}
                    meta.update(
                        {
                            "source_path": str(doc_path),
                            "chunk_ids": ids,
                            "generated_at": datetime.datetime.fromtimestamp(
                                time.time()
                            ).strftime("%Y-%m-%d %H:%M:%S"),
                            "model": GEN_MODEL,
                            "embed_model": EMBED_MODEL,
                        }
                    )
                    if inp and out:
                        obj = {"input": inp, "output": out, "meta": meta}
                        f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                        total_target += 1
    print(f"[green]Wrote {total_target} rows -> {out_path}[/green]")
    return out_path


OUT_JSONL = synthesize_dataset(samples_per_doc=SAMPLES_PER_DOC)
OUT_JSONL

# %% [markdown]
# ## 9) Preview Samples
#

# %%
from itertools import islice


def head_jsonl(p: Path, n: int = 5):
    with p.open("r", encoding="utf-8") as f:
        for line in islice(f, n):
            print(line.rstrip())


head_jsonl(OUT_JSONL, 5)

# %% [markdown]
# ## 10) Optional: Spot-Check Generation Quality
#
# Run a tiny evaluation by asking the model the same questions with and without retrieval and comparing
# the answers (a no-retrieval baseline is sketched after the cell below).
#

# %%
EVAL_QUESTIONS = []

# Collect inputs from the dataset (first N)
with (OUTPUT_DIR / "raft_dataset.jsonl").open("r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        try:
            obj = json.loads(line)
            EVAL_QUESTIONS.append(obj["input"])
        except Exception:
            pass
        if len(EVAL_QUESTIONS) >= 5:
            break


def rag_answer(q: str, k: int = TOP_K) -> str:
    hits = search(q, k)
    ctx = "\n\n".join([all_chunks[i].text for i, _ in hits])
    user = f"Answer the question using ONLY this context. If missing, say INSUFFICIENT_CONTEXT.\n\nCONTEXT:\n{ctx}\n\nQUESTION: {q}"
    return ollama_generate(user, model=GEN_MODEL, temperature=0.2, num_predict=256)


for q in EVAL_QUESTIONS:
    print("\n[bold]Q:[/bold]", q)
    ans = rag_answer(q)
    print("[bold]A:[/bold]", ans.strip()[:500], "...")

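# %% [markdown]
# No-retrieval baseline (a small sketch reusing `ollama_generate`, not part of the original cells): ask the
# generator the same questions without any context, so the grounded answers above have a point of comparison.

# %%
def plain_answer(q: str) -> str:
    # Same generator and settings, but no retrieved context at all.
    return ollama_generate(f"QUESTION: {q}", model=GEN_MODEL, temperature=0.2, num_predict=256)


for q in EVAL_QUESTIONS:
    print("\n[bold]Q:[/bold]", q)
    print("[bold]A (no retrieval):[/bold]", plain_answer(q).strip()[:500])
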
# %% [markdown]
# ## 11) Artifacts
#
# - `outputs/raft_dataset.jsonl` — your RAFT dataset (input/output/meta per line)
# - `corpus/` — your source documents (you provide)
# - You can also persist `emb_matrix.npy` and a FAISS index for reuse (a reload sketch follows the cell below).
#

# %%
# Optionally persist embeddings and index for later reuse
np.save(OUTPUT_DIR / "emb_matrix.npy", emb_matrix)

if USE_FAISS:
    import faiss

    faiss.write_index(index, str(OUTPUT_DIR / "faiss.index"))
    print("[green]Saved FAISS index and embeddings.[/green]")
else:
    print("[yellow]FAISS disabled; only saved embeddings.[/yellow]")

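# %% [markdown]
# Reloading later (a minimal sketch using `np.load` and `faiss.read_index`): restore the persisted
# embeddings and index in a fresh session instead of re-embedding the corpus. Note that `all_chunks` is not
# persisted here, so you would still need to pickle or re-derive it to map indices back to chunk text.

# %%
reloaded_emb = np.load(OUTPUT_DIR / "emb_matrix.npy")
print("Reloaded embeddings:", reloaded_emb.shape)
if USE_FAISS:
    import faiss

    reloaded_index = faiss.read_index(str(OUTPUT_DIR / "faiss.index"))
    print("[green]Reloaded FAISS index with[/green]", reloaded_index.ntotal, "vectors")
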
# %% [markdown]
# ## 12) Troubleshooting
#
# - **Connection error to Ollama**: ensure `ollama serve` is running and models are pulled (`ollama pull nomic-embed-text`, `ollama pull llama3.1`).
# - **Empty dataset**: your corpus may be too small or the parser skipped files. Check `corpus/` content and chunk parameters.
# - **Hallucinations**: tighten the system prompt, lower temperature, or increase `TOP_K` and chunk size.
# - **JSON parsing issues**: the notebook tries to be forgiving; you can harden `parse_llm_jsonl` per your needs.
# - **PDFs**: `pip install pypdf` and try again.
#