mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 08:22:43 +01:00
RAG cleanup
This commit is contained in:
@@ -8,6 +8,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def simple_clean(text: str) -> str:
|
||||
@@ -18,21 +19,34 @@ def simple_clean(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def chunk_text(text: str, tokenizer, chunk_tokens: int = 250, overlap_tokens: int = 50):
    """
    Token-based chunking using a transformer tokenizer.

    Splits *text* into overlapping windows of at most ``chunk_tokens`` tokens,
    advancing by ``chunk_tokens - overlap_tokens`` tokens per step, and decodes
    each window back into a string.

    Args:
        text: Raw input text; normalized via ``simple_clean`` before tokenizing.
        tokenizer: HuggingFace-style tokenizer exposing ``encode``/``decode``.
        chunk_tokens: Maximum number of tokens per chunk.
        overlap_tokens: Number of tokens shared between consecutive chunks.

    Returns:
        List of chunk strings; empty list if the cleaned text is empty.
    """
    text = simple_clean(text)
    if not text:
        return []

    # Tokenize the entire text once; chunk over token ids, not characters.
    tokens = tokenizer.encode(text, add_special_tokens=False)

    # Short texts fit in a single chunk — return the cleaned text verbatim
    # (text is known non-empty here, so no further guard is needed).
    if len(tokens) <= chunk_tokens:
        return [text]

    # Stride is loop-invariant; max(1, ...) guards against a non-positive
    # stride (overlap_tokens >= chunk_tokens) that would loop forever.
    step = max(1, chunk_tokens - overlap_tokens)

    chunks = []
    i = 0
    while i < len(tokens):
        chunk_token_ids = tokens[i : i + chunk_tokens]
        if chunk_token_ids:
            # Decode tokens back to text; skip chunks that decode to "".
            chunk_str = tokenizer.decode(chunk_token_ids, skip_special_tokens=True)
            if chunk_str:
                chunks.append(chunk_str)
        i += step

    return chunks
|
||||
|
||||
|
||||
@@ -59,12 +73,15 @@ def main():
|
||||
ap.add_argument(
|
||||
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
ap.add_argument("--chunk_chars", type=int, default=900)
|
||||
ap.add_argument("--overlap", type=int, default=150)
|
||||
ap.add_argument("--chunk_tokens", type=int, default=250)
|
||||
ap.add_argument("--overlap_tokens", type=int, default=50)
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
# Initialize tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
|
||||
df = pd.read_csv(args.input_csv, sep=",", dtype=str, on_bad_lines="skip")
|
||||
|
||||
rows = df["Original"].fillna("").astype(str).tolist()
|
||||
@@ -77,7 +94,12 @@ def main():
|
||||
r = simple_clean(r)
|
||||
if len(r) < 30:
|
||||
continue
|
||||
chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
|
||||
chunks = chunk_text(
|
||||
r,
|
||||
tokenizer,
|
||||
chunk_tokens=args.chunk_tokens,
|
||||
overlap_tokens=args.overlap_tokens,
|
||||
)
|
||||
for ch in chunks:
|
||||
if len(ch) < 30:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user