RAG cleanup

This commit is contained in:
2026-02-21 15:24:21 +01:00
parent 04cc3f8e77
commit 49c622db08
4 changed files with 71 additions and 31 deletions

View File

@@ -8,6 +8,7 @@ import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoTokenizer
def simple_clean(text: str) -> str:
@@ -18,21 +19,34 @@ def simple_clean(text: str) -> str:
return text
def chunk_text(text: str, tokenizer, chunk_tokens: int = 250, overlap_tokens: int = 50):
    """
    Token-based chunking using a transformer tokenizer.

    The text is cleaned, encoded to token ids (without special tokens),
    split into windows of ``chunk_tokens`` ids with ``overlap_tokens`` ids
    shared between consecutive windows, and each window is decoded back
    to a text chunk.

    Args:
        text: Raw input text; normalized via ``simple_clean`` first.
        tokenizer: Hugging Face-style tokenizer exposing ``encode`` and
            ``decode`` (e.g. ``AutoTokenizer`` instance).
        chunk_tokens: Maximum number of tokens per chunk.
        overlap_tokens: Number of tokens repeated between adjacent chunks.

    Returns:
        List of non-empty chunk strings; empty list for empty input.
    """
    text = simple_clean(text)
    if not text:
        return []
    tokens = tokenizer.encode(text, add_special_tokens=False)
    # Short texts fit in a single chunk; return the cleaned text as-is
    # (text is known non-empty here, so no extra emptiness check is needed).
    if len(tokens) <= chunk_tokens:
        return [text]
    # Guard against overlap >= chunk size, which would otherwise stall the walk.
    step = max(1, chunk_tokens - overlap_tokens)
    chunks = []
    for start in range(0, len(tokens), step):
        window = tokens[start : start + chunk_tokens]
        # Decoding can yield an empty string for degenerate windows
        # (e.g. stray sub-word pieces); skip those.
        decoded = tokenizer.decode(window, skip_special_tokens=True)
        if decoded:
            chunks.append(decoded)
    return chunks
@@ -59,12 +73,15 @@ def main():
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--chunk_chars", type=int, default=900)
ap.add_argument("--overlap", type=int, default=150)
ap.add_argument("--chunk_tokens", type=int, default=250)
ap.add_argument("--overlap_tokens", type=int, default=50)
args = ap.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
df = pd.read_csv(args.input_csv, sep=",", dtype=str, on_bad_lines="skip")
rows = df["Original"].fillna("").astype(str).tolist()
@@ -77,7 +94,12 @@ def main():
r = simple_clean(r)
if len(r) < 30:
continue
chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
chunks = chunk_text(
r,
tokenizer,
chunk_tokens=args.chunk_tokens,
overlap_tokens=args.overlap_tokens,
)
for ch in chunks:
if len(ch) < 30:
continue