# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.18.0
#   kernelspec:
#     display_name: .venv
#     language: python
#     name: python3
# ---

# %% [markdown]
# # Retrieval-Augmented **Fine-Tuning** (RAFT) Dataset Generation with **Local Ollama**
#
# This notebook builds a **supervised fine-tuning dataset** (JSONL) for _retrieval-augmented_ tasks, by:
#
# 1. **Ingesting** your local corpus (Markdown, text, HTML; PDFs optional with extra deps).
# 2. **Chunking** and **embedding** documents using Ollama's local **embedding model** (e.g., `nomic-embed-text`, `mxbai-embed-large`).
# 3. Building a **lightweight vector index** (FAISS).
# 4. **Sampling contexts** and using a local **generation model** via Ollama (e.g., `llama3.1`, `qwen2`, `phi3`) to synthesize **grounded Q&A** or instruction–response pairs.
# 5. Emitting a **RAFT-style JSONL** for supervised training (e.g., `input`, `output`, `meta` with source citations).
#
# > **Requirements**
# >
# > - Local [Ollama](https://ollama.com/) running at `http://localhost:11434`
# > - At least one **embedding** model pulled (e.g., `ollama pull nomic-embed-text`)
# > - At least one **generation** model pulled (e.g., `ollama pull llama3.1`)
# >
# > You can adapt the prompts and schema for your specific downstream trainer or runtime (Axolotl, MLX, llama.cpp, vLLM, etc.).
#

# %% [markdown]
# ## 0) Setup
#
# Install Python dependencies. If you're offline, pre-install them or remove what you don't need.
#

# %%
# If needed, uncomment:
# # %pip install --quiet requests faiss-cpu rich markdownify python-frontmatter pypdf regex
# Optional extras:
# # %pip install --quiet tiktoken beautifulsoup4 lxml

# %% [markdown]
# ## 1) Configuration
#
# Set paths, models, and chunking/index params.
#

# %%
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple
import hashlib
import json
import os
import random
import uuid

import numpy as np
import regex
import requests
from rich import print  # NOTE: intentionally shadows the builtin so [color] markup renders

# ---- Paths ----
DATA_DIR = Path("./corpus")  # source documents to ingest
OUTPUT_DIR = Path("./outputs")  # every generated artifact is written here
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ---- Ollama server & models (each overridable through an environment variable) ----
OLLAMA_URL = os.environ.get("OLLAMA_URL", "http://localhost:11434")
EMBED_MODEL = os.environ.get("EMBED_MODEL", "nomic-embed-text")
GEN_MODEL = os.environ.get("GEN_MODEL", "llama3.1:8b")

# ---- Chunking (all sizes are character counts) ----
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 200
MIN_CHARS = 200  # stripped chunks shorter than this are dropped

# ---- Retrieval index ----
USE_FAISS = True
TOP_K = 4

# ---- RAFT generation ----
SEED = 7
SAMPLES_PER_DOC = 4
MAX_TOKENS_GEN = 512  # passed to Ollama as options.num_predict (approximate token cap)
TEMPERATURE = 0.6

# Seed the global RNGs (stdlib and NumPy) for repeatable runs.
random.seed(SEED)
np.random.seed(SEED)

print(
    dict(
        DATA_DIR=str(DATA_DIR.resolve()),
        OUTPUT_DIR=str(OUTPUT_DIR.resolve()),
        OLLAMA_URL=OLLAMA_URL,
        EMBED_MODEL=EMBED_MODEL,
        GEN_MODEL=GEN_MODEL,
        CHUNK_SIZE=CHUNK_SIZE,
        CHUNK_OVERLAP=CHUNK_OVERLAP,
        TOP_K=TOP_K,
        SAMPLES_PER_DOC=SAMPLES_PER_DOC,
    )
)

# %% [markdown]
# ## 2) Load & Normalize Documents
#
# Basic loaders for `.md`, `.txt`, `.html` (HTML needs `beautifulsoup4`). PDF support is optional (requires `pypdf`). You can extend as needed.
#

# %%
try:
    from bs4 import BeautifulSoup
except ImportError:
    # beautifulsoup4 is listed under the *optional* extras in Setup, so its absence
    # must not break the notebook; HTML files are simply not registered below.
    BeautifulSoup = None

try:
    import frontmatter
except ImportError:
    frontmatter = None


def read_text_file(p: Path) -> str:
    """Return the file's contents decoded as UTF-8, ignoring undecodable bytes."""
    return p.read_text(encoding="utf-8", errors="ignore")


def read_markdown(p: Path) -> str:
    """Return Markdown text, with YAML front matter stripped when python-frontmatter is installed."""
    text = p.read_text(encoding="utf-8", errors="ignore")
    if frontmatter is None:
        return text
    try:
        return frontmatter.loads(text).content
    except Exception:
        # Best effort: malformed front matter falls back to the raw text instead of losing the file.
        return text


def read_html(p: Path) -> str:
    """Return the visible text of an HTML file (script/style/noscript removed), joined with spaces.

    Only registered in SUPPORTED_EXTS when beautifulsoup4 is importable.
    """
    html = p.read_text(encoding="utf-8", errors="ignore")
    try:
        soup = BeautifulSoup(html, "lxml")
    except ValueError:
        # bs4 raises FeatureNotFound (a ValueError subclass) when lxml is missing; lxml is an
        # optional extra, so fall back to the parser backed by Python's stdlib.
        soup = BeautifulSoup(html, "html.parser")
    for tag in soup(["script", "style", "noscript"]):
        tag.decompose()
    return soup.get_text(" ", strip=True)


def read_pdf(p: Path) -> str:
    """Return the text of every PDF page joined by newlines; pages that fail to extract give "".

    Raises:
        ImportError: if pypdf is not installed (a hint is printed first).
    """
    try:
        from pypdf import PdfReader
    except ImportError:
        print(
            "[yellow]Install pypdf to enable PDF parsing: %pip install pypdf[/yellow]"
        )
        raise
    reader = PdfReader(str(p))
    parts = []
    for page in reader.pages:
        try:
            parts.append(page.extract_text() or "")
        except Exception:
            # One unparsable page shouldn't discard the rest of the document.
            parts.append("")
    return "\n".join(parts)


# File suffix (lower-cased) -> reader returning plain text.
SUPPORTED_EXTS = {
    ".txt": read_text_file,
    ".md": read_markdown,
    ".markdown": read_markdown,
    ".pdf": read_pdf,
}
if BeautifulSoup is not None:
    SUPPORTED_EXTS[".html"] = read_html
    SUPPORTED_EXTS[".htm"] = read_html
else:
    print("[yellow]beautifulsoup4 not installed; .html/.htm files will be skipped[/yellow]")


def load_corpus(data_dir: Path) -> Dict[str, str]:
    """Read every supported file under ``data_dir`` (recursively) into ``{path_str: text}``.

    Paths are visited in sorted order so the resulting dict -- and therefore chunk order
    and seeded sampling downstream -- is the same on every run and filesystem. Files whose
    reader raises are reported and skipped.
    """
    if not data_dir.is_dir():
        print(f"[yellow]Corpus directory not found: {data_dir}[/yellow]")
        return {}
    docs: Dict[str, str] = {}
    for p in sorted(data_dir.rglob("*")):
        reader = SUPPORTED_EXTS.get(p.suffix.lower())
        if reader is None or not p.is_file():
            continue
        try:
            docs[str(p)] = reader(p)
        except Exception as e:
            print(f"[red]Failed to read {p}: {e}[/red]")
    print(f"[green]Loaded {len(docs)} documents[/green]")
    return docs


docs = load_corpus(DATA_DIR)
len(docs)

# %% [markdown]
# ## 3) Chunking
#
# Simple character-based chunker with overlap. Swap in a token-based chunker if you prefer.
#

# %%
@dataclass
class Chunk:
    """A contiguous, whitespace-trimmed slice of one source document.

    For the document text stored under ``doc_path``: ``document[start:end] == text``.
    """

    id: str  # deterministic UUID derived from doc_path and the chunk offsets
    doc_path: str  # key of the source document in `docs`
    start: int  # character offset where `text` begins (inclusive)
    end: int  # character offset where `text` ends (exclusive)
    text: str  # chunk contents with surrounding whitespace stripped
    sha1: str  # hex SHA-1 of `text` encoded as UTF-8


def chunk_text(
    text: str, doc_path: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP
) -> List[Chunk]:
    """Split ``text`` into overlapping fixed-width character windows.

    Windows start at offset 0 and advance by ``chunk_size - overlap`` characters until
    one reaches the end of ``text``. Each window is stripped of surrounding whitespace
    and kept only if at least ``MIN_CHARS`` characters remain.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``overlap >= chunk_size``
            (the window would never advance and the loop would never end).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    chunks: List[Chunk] = []
    n = len(text)
    i = 0
    while i < n:
        j = min(i + chunk_size, n)
        window = text[i:j]
        piece = window.strip()
        if len(piece) >= MIN_CHARS:
            # Offsets of the *stripped* text, so text[start:end] == piece.
            start = i + (len(window) - len(window.lstrip()))
            end = start + len(piece)
            chunks.append(
                Chunk(
                    # uuid5 (rather than random uuid4) keeps ids identical across
                    # reruns, so the chunk_ids written into the dataset are reproducible.
                    id=str(uuid.uuid5(uuid.NAMESPACE_URL, f"{doc_path}#{start}-{end}")),
                    doc_path=doc_path,
                    start=start,
                    end=end,
                    text=piece,
                    sha1=hashlib.sha1(piece.encode("utf-8")).hexdigest(),
                )
            )
        if j == n:
            break
        # Here j < n, so j == i + chunk_size and the next window starts
        # chunk_size - overlap > 0 characters further on.
        i = j - overlap
    return chunks


all_chunks: List[Chunk] = []
for path, text in docs.items():
    all_chunks.extend(chunk_text(text, path))
print(f"[green]Total chunks: {len(all_chunks)}[/green]")

# %% [markdown]
# ## 4) Embeddings via Ollama
#
# Uses Ollama's `POST /api/embeddings` endpoint with your selected embedding model (one text per request).
# Make sure you've pulled it locally: `ollama pull nomic-embed-text` (or your chosen model).
#

# %%
EMBED_ENDPOINT = f"{OLLAMA_URL}/api/embeddings"


def embed_texts(
    texts: List[str],
    model: str = EMBED_MODEL,
    batch_size: int = 32,
    timeout: float = 120.0,
) -> np.ndarray:
    """Embed ``texts`` with Ollama and stack the vectors into a float32 matrix.

    ``/api/embeddings`` takes a single ``prompt`` per request, so texts are sent one at
    a time over a shared HTTP session (connection reuse). ``batch_size`` is kept for
    backward compatibility only; it does not change what is sent.

    Returns:
        Array of shape ``(len(texts), dim)``; ``(0, 768)`` when ``texts`` is empty.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the server does not answer within ``timeout`` seconds.
        RuntimeError: if the server returns an empty vector.
    """
    vectors: List[np.ndarray] = []
    with requests.Session() as session:
        for t in texts:
            r = session.post(
                EMBED_ENDPOINT, json={"model": model, "prompt": t}, timeout=timeout
            )
            r.raise_for_status()
            vec = np.asarray(r.json()["embedding"], dtype=np.float32)
            if vec.size == 0:
                # Would otherwise surface later as an opaque shape error in vstack/FAISS.
                raise RuntimeError(
                    f"Ollama returned an empty embedding for model {model!r}; "
                    "the input may be empty or the model may not support embeddings."
                )
            vectors.append(vec)
    return np.vstack(vectors) if vectors else np.zeros((0, 768), dtype=np.float32)


chunk_texts = [c.text for c in all_chunks]
emb_matrix = embed_texts(chunk_texts, model=EMBED_MODEL, batch_size=8)
emb_matrix.shape

# %% [markdown]
# ## 5) Build Vector Index (FAISS)
#
# We normalize vectors and use inner product (equivalent to cosine on normalized vectors).
#

# %%
def normalize_rows(x: np.ndarray) -> np.ndarray:
    """Scale each row of ``x`` to unit L2 norm (the epsilon keeps all-zero rows at zero)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
    return x / norms


# With unit-norm rows, inner product equals cosine similarity.
xb = normalize_rows(emb_matrix).astype(np.float32)
if USE_FAISS:
    import faiss

    d = xb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(xb)
    print("[green]FAISS index built:[/green]", index.ntotal, "vectors")
else:
    index = None

# %% [markdown]
# ## 6) Retrieval Helper
#

# %%
def search(query: str, top_k: int = TOP_K) -> List[Tuple[int, float]]:
    """Return up to ``top_k`` ``(chunk_index, cosine_score)`` pairs for ``query``, best first.

    ``chunk_index`` indexes ``all_chunks``. Fewer than ``top_k`` pairs are returned when
    the corpus has fewer chunks.
    """
    qv = normalize_rows(embed_texts([query], model=EMBED_MODEL, batch_size=1)).astype(
        np.float32
    )
    if USE_FAISS and index is not None:
        scores, ids = index.search(qv, top_k)
        # FAISS pads the result with id -1 when the index holds fewer than top_k
        # vectors; passing -1 on would silently select all_chunks[-1].
        return [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    sims = (xb @ qv.T).ravel()
    order = np.argsort(-sims)[:top_k]
    return [(int(i), float(sims[i])) for i in order]


# quick smoke test (no error means it's wired up)
print(search("What does this corpus talk about?", 3))

# %% [markdown]
# ## 7) Synthesize Grounded Q&A / Instructions with Ollama
#
# We sample chunks, retrieve neighbors for richer context, and prompt a local LLM to create **high-quality** pairs.
#

# %%
GEN_ENDPOINT = f"{OLLAMA_URL}/api/generate"

SYSTEM_PROMPT = (
    "You are a careful dataset writer. Given only the provided CONTEXT, craft high-quality, factual "
    "question–answer pairs for supervised fine-tuning. Answers must be grounded strictly in the context. "
    "If the context lacks the answer, say 'INSUFFICIENT_CONTEXT'. Focus on clarity, specificity, and avoid hallucinations."
)

USER_PROMPT_TEMPLATE = (
    "CONTEXT:\n\n{context}\n\n"
    "Task: Produce {n} diverse Q&A pairs about the content above. "
    "Use JSON lines (one JSON object per line) with keys: 'input' (question/instruction), 'output' (concise grounded answer), "
    "'meta' (object with 'source_path', 'chunk_ids', and optional 'citations': list of quotes). "
    "Do NOT include markdown; output JSON objects only."
)


def ollama_generate(
    prompt: str,
    model: str = GEN_MODEL,
    temperature: float = TEMPERATURE,
    num_predict: int = MAX_TOKENS_GEN,
    system: str = SYSTEM_PROMPT,
    timeout: float = 600.0,
) -> str:
    """Run one non-streaming ``/api/generate`` call and return the ``response`` text ("" if absent).

    ``system`` defaults to the dataset-writer SYSTEM_PROMPT. ``timeout`` (seconds) bounds
    how long we wait on the server; without it a stalled server would block forever.

    Raises:
        requests.HTTPError / requests.Timeout: on a failed or stalled request.
    """
    payload = {
        "model": model,
        "prompt": prompt,
        "system": system,
        "options": {"temperature": temperature, "num_predict": num_predict},
        "stream": False,
    }
    r = requests.post(GEN_ENDPOINT, json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json().get("response", "")


def build_context(primary_idx: int, k: int = TOP_K) -> Tuple[str, List[str]]:
    """Assemble the generation context for the sampled chunk ``all_chunks[primary_idx]``.

    The chunk's first 400 characters are used as a retrieval query. The sampled chunk is
    always placed first -- retrieval is not guaranteed to rank it top, and the caller
    records its document as the row's ``source_path`` -- followed by the other hits in
    score order, for at most ``max(k, 1)`` chunks.

    Returns:
        ``(context, chunk_ids)``: chunk texts, each prefixed with ``[file::start-end]``
        and separated by ``---`` rules, plus the ids of those chunks in the same order.
    """
    primary_chunk = all_chunks[primary_idx]
    query = primary_chunk.text[:400]  # use the start of the chunk as a pseudo-query
    neighbours = [i for i, _score in search(query, k) if i != primary_idx]
    selected = ([primary_idx] + neighbours)[: max(k, 1)]

    pieces, ids = [], []
    for i in selected:
        ch = all_chunks[i]
        ids.append(ch.id)
        pieces.append(f"[{Path(ch.doc_path).name}::{ch.start}-{ch.end}]\n{ch.text}")
    return "\n\n---\n\n".join(pieces), ids


# A comma directly before a closing brace/bracket (invalid in strict JSON).
_TRAILING_COMMA_RE = regex.compile(r",\s*([}\]])")


def _as_dict_rows(obj: Any) -> List[Dict[str, Any]]:
    """Return the dicts in a decoded JSON value: ``[obj]`` for a dict, dict items for a list."""
    if isinstance(obj, dict):
        return [obj]
    if isinstance(obj, list):
        return [item for item in obj if isinstance(item, dict)]
    return []


def _loads_lenient(s: str) -> Any:
    """``json.loads(s)``, retrying once with trailing commas removed; ``None`` if both fail."""
    for candidate in (s, _TRAILING_COMMA_RE.sub(r"\1", s)):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return None


def _scan_json_objects(text: str) -> List[Dict[str, Any]]:
    """Decode JSON objects starting at each ``{`` in ``text``, skipping past each decoded one."""
    decoder = json.JSONDecoder()
    found: List[Dict[str, Any]] = []
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)
            continue
        found.extend(_as_dict_rows(obj))
        pos = text.find("{", end)
    return found


def parse_llm_jsonl(text: str) -> List[Dict[str, Any]]:
    """Extract JSON objects from an LLM response that is supposed to be JSON Lines.

    Each non-empty line is parsed on its own, tolerating a separator comma after the
    object, trailing commas inside it, and one-line JSON arrays of objects; Markdown
    code-fence lines are ignored. If no line yields an object (e.g. the model
    pretty-printed objects over several lines), the whole response is scanned for JSON
    objects instead. Non-dict values are discarded.
    """
    rows: List[Dict[str, Any]] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("```"):
            continue
        rows.extend(_as_dict_rows(_loads_lenient(line.rstrip(","))))
    if not rows:
        rows = _scan_json_objects(_TRAILING_COMMA_RE.sub(r"\1", text))
    return rows


# %% [markdown]
# ## 8) Generate the RAFT Dataset
#
# This step iterates over documents, samples chunks, retrieves neighbors, and asks the model to produce JSONL rows.
#

# %%
import datetime
import time


def synthesize_dataset(
    samples_per_doc: int = SAMPLES_PER_DOC,
    out_path: Path = OUTPUT_DIR / "raft_dataset.jsonl",
    pairs_per_context: int = 3,
) -> Path:
    """Generate RAFT training rows with the local LLM and write them to ``out_path`` as JSONL.

    For each document, up to ``samples_per_doc`` of its chunks are sampled with a private
    ``random.Random(SEED)`` (deterministic for a given ``all_chunks`` list). For each
    sampled chunk a retrieval context is built and the model is asked for
    ``pairs_per_context`` Q&A pairs. Every parsed object that has both an input
    (aliases: question/query) and an output (aliases: answer/response) is written as
    ``{"input", "output", "meta"}``. ``meta`` keeps any model-supplied keys, but
    ``source_path``, ``chunk_ids``, ``generated_at``, ``model`` and ``embed_model``
    are always set here.

    A sample whose retrieval or generation request fails is reported and skipped instead
    of aborting the whole run. Returns ``out_path``.
    """
    rng = random.Random(SEED)

    # Chunk indices grouped by source document, in corpus order.
    doc_to_chunk_idx: Dict[str, List[int]] = {}
    for i, ch in enumerate(all_chunks):
        doc_to_chunk_idx.setdefault(ch.doc_path, []).append(i)

    total_docs = len(doc_to_chunk_idx)
    rows_written = 0
    failed_samples = 0
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        for doc_num, (doc_path, idxs) in enumerate(doc_to_chunk_idx.items(), start=1):
            print(f"[blue]Synthesizing for: {doc_path} ({len(idxs)} chunks)[/blue]")
            chosen = rng.sample(idxs, min(samples_per_doc, len(idxs)))
            for pi in chosen:
                try:
                    ctx, ids = build_context(pi, k=TOP_K)
                    user = USER_PROMPT_TEMPLATE.format(context=ctx, n=pairs_per_context)
                    raw = ollama_generate(
                        user,
                        model=GEN_MODEL,
                        temperature=TEMPERATURE,
                        num_predict=MAX_TOKENS_GEN,
                    )
                except requests.RequestException as e:
                    failed_samples += 1
                    print(f"[red]Skipping a sample from {doc_path}: {e}[/red]")
                    continue

                generated_at = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                for r in parse_llm_jsonl(raw):
                    # enforce schema & enrich meta
                    inp = r.get("input") or r.get("question") or r.get("query")
                    out = r.get("output") or r.get("answer") or r.get("response")
                    if not (inp and out):
                        continue
                    meta = r.get("meta")
                    meta = dict(meta) if isinstance(meta, dict) else {}
                    meta.update(
                        {
                            "source_path": str(doc_path),
                            "chunk_ids": ids,
                            "generated_at": generated_at,
                            "model": GEN_MODEL,
                            "embed_model": EMBED_MODEL,
                        }
                    )
                    obj = {"input": inp, "output": out, "meta": meta}
                    f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                    rows_written += 1

            # Reported after the document is processed, so "completed" is accurate.
            percent = doc_num / total_docs * 100
            print(
                f"[cyan]Progress: {doc_num}/{total_docs} ({percent:.1f}%) completed[/cyan]"
            )

    print(f"[green]Wrote {rows_written} rows -> {out_path}[/green]")
    if failed_samples:
        print(f"[yellow]{failed_samples} sample(s) skipped due to request errors[/yellow]")
    return out_path


OUT_JSONL = synthesize_dataset(samples_per_doc=SAMPLES_PER_DOC)
OUT_JSONL

# %% [markdown]
# ## 9) Preview Samples
#

# %%
from itertools import islice


def head_jsonl(p: Path, n: int = 5):
    """Print the first ``n`` lines of the JSONL file at ``p`` (trailing whitespace stripped)."""
    with p.open("r", encoding="utf-8") as f:
        for line in islice(f, n):
            print(line.rstrip())


head_jsonl(OUT_JSONL, 5)
# %% [markdown]
# ## 10) Optional: Spot-Check Generation Quality
#
# Run a tiny evaluation by asking the model with and without retrieval and compare answers.
#

# %%
# ollama_generate() sends the dataset-writer SYSTEM_PROMPT by default, which tells the
# model to *write Q&A pairs* -- the wrong instruction for answering a question -- so the
# evaluation calls go through _ask() with this prompt instead.
ANSWER_SYSTEM_PROMPT = (
    "You are a careful assistant. Answer the user's question accurately and concisely."
)
N_EVAL_QUESTIONS = 5


def load_eval_questions(path: Path, limit: int = N_EVAL_QUESTIONS) -> List[str]:
    """Return the ``input`` field of up to ``limit`` records of a JSONL file, in file order.

    Lines that are not JSON objects with an ``input`` key are skipped.
    """
    questions: List[str] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if len(questions) >= limit:
                break
            try:
                questions.append(json.loads(line)["input"])
            except (json.JSONDecodeError, KeyError, TypeError):
                continue
    return questions


def _ask(
    prompt: str, temperature: float = 0.2, num_predict: int = 256, timeout: float = 600.0
) -> str:
    """Send one non-streaming /api/generate request using ANSWER_SYSTEM_PROMPT; return its text."""
    payload = {
        "model": GEN_MODEL,
        "prompt": prompt,
        "system": ANSWER_SYSTEM_PROMPT,
        "options": {"temperature": temperature, "num_predict": num_predict},
        "stream": False,
    }
    r = requests.post(GEN_ENDPOINT, json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json().get("response", "")


def rag_answer(q: str, k: int = TOP_K) -> str:
    """Answer ``q`` using only the top-``k`` retrieved chunks as context."""
    hits = search(q, k)
    # Skip FAISS's -1 padding ids, which appear when the index has fewer than k vectors.
    ctx = "\n\n".join(all_chunks[i].text for i, _ in hits if i >= 0)
    user = f"Answer the question using ONLY this context. If missing, say INSUFFICIENT_CONTEXT.\n\nCONTEXT:\n{ctx}\n\nQUESTION: {q}"
    return _ask(user)


def plain_answer(q: str) -> str:
    """Answer ``q`` without any retrieved context (the no-retrieval baseline)."""
    return _ask(f"QUESTION: {q}")


def _preview(text: str, limit: int = 500) -> str:
    """Strip ``text`` and cut it to ``limit`` characters, appending ' ...' only if truncated."""
    text = text.strip()
    return text if len(text) <= limit else text[:limit] + " ..."


# Collect inputs from the dataset (first N)
EVAL_QUESTIONS = load_eval_questions(OUT_JSONL)

for q in EVAL_QUESTIONS:
    print("\n[bold]Q:[/bold]", q)
    print("[bold]A (no retrieval):[/bold]", _preview(plain_answer(q)))
    print("[bold]A (RAG):[/bold]", _preview(rag_answer(q)))

# %% [markdown]
# ## 11) Artifacts
#
# - `outputs/raft_dataset.jsonl` — your RAFT dataset (input/output/meta per line)
# - `corpus/` — your source documents (you provide)
# - The cell below persists `emb_matrix.npy`, `chunks.jsonl` (one chunk record per index row, same order) and, when FAISS is enabled, `faiss.index` for reuse.
#

# %%
from dataclasses import asdict

# Optionally persist embeddings and index for later reuse
np.save(OUTPUT_DIR / "emb_matrix.npy", emb_matrix)
# Row i of emb_matrix (and of the FAISS index) corresponds to all_chunks[i]; save the
# chunk records in the same order so reloaded vectors can be mapped back to text/source.
with (OUTPUT_DIR / "chunks.jsonl").open("w", encoding="utf-8") as f:
    for ch in all_chunks:
        f.write(json.dumps(asdict(ch), ensure_ascii=False) + "\n")
if USE_FAISS:
    import faiss

    faiss.write_index(index, str(OUTPUT_DIR / "faiss.index"))
    print("[green]Saved FAISS index, embeddings, and chunk metadata.[/green]")
else:
    print("[yellow]FAISS disabled; saved embeddings and chunk metadata only.[/yellow]")

# %% [markdown]
# ## 12) Troubleshooting
#
# - **Connection error to Ollama**: ensure `ollama serve` is running and models are pulled (`ollama pull nomic-embed-text`, `ollama pull llama3.1`).
# - **Empty dataset**: your corpus may be too small or the parser skipped files. Check `corpus/` content and chunk parameters.
# - **Hallucinations**: tighten the system prompt, lower temperature, or increase `TOP_K` and chunk size.
# - **JSON parsing issues**: the notebook tries to be forgiving; you can harden `parse_llm_jsonl` per your needs.
# - **PDFs**: `pip install pypdf` and try again.
#