New RAFT approach

This commit is contained in:
2026-02-19 14:24:42 +01:00
parent d0d3edae14
commit 28823dc0b5
5 changed files with 591 additions and 0 deletions

122
raft/prepare_corpus.py Normal file
View File

@@ -0,0 +1,122 @@
import argparse
import json
import os
import re
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
## Usage: python prepare_corpus.py --input_tab your_reviews.tab --out_dir out
def simple_clean(text: str) -> str:
    """Return *text* with its whitespace normalized.

    Non-string input yields an empty string. For strings, non-breaking
    spaces are turned into regular spaces, every run of whitespace is
    collapsed to a single space, and leading/trailing whitespace is removed.
    """
    if isinstance(text, str):
        without_nbsp = text.replace("\u00a0", " ")
        collapsed = re.sub(r"\s+", " ", without_nbsp)
        return collapsed.strip()
    return ""
def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
    """Split *text* into overlapping, character-based chunks.

    The text is first normalized with ``simple_clean``. Text that fits in a
    single chunk is returned as a one-element list; empty text gives ``[]``.
    Longer text is cut into windows of ``chunk_chars`` characters, each
    window starting ``chunk_chars - overlap`` characters after the previous
    one. Chunking stops as soon as a window reaches the end of the text, so
    no trailing chunk is emitted that is already fully contained in the
    previous one.

    Character windows are good enough for short reviews; for better
    chunking, split by sentences and cap the token length.

    Args:
        text: Raw text to split.
        chunk_chars: Maximum number of characters per chunk; must be >= 1.
        overlap: Number of characters shared by consecutive chunks; must
            be >= 0.

    Returns:
        A list of chunk strings in document order.

    Raises:
        ValueError: If ``chunk_chars`` is less than 1 or ``overlap`` is
            negative.
    """
    if chunk_chars < 1:
        raise ValueError(f"chunk_chars must be >= 1, got {chunk_chars}")
    if overlap < 0:
        # A negative overlap would make the step larger than the window and
        # silently drop the text between consecutive chunks.
        raise ValueError(f"overlap must be >= 0, got {overlap}")

    text = simple_clean(text)
    if len(text) <= chunk_chars:
        return [text] if text else []

    # Always advance by at least one character so the loop terminates even
    # when overlap >= chunk_chars.
    step = max(1, chunk_chars - overlap)
    chunks = []
    for start in range(0, len(text), step):
        end = start + chunk_chars
        chunks.append(text[start:end])
        if end >= len(text):
            # This window already covers the tail; any further window would
            # only repeat text that is part of this chunk.
            break
    return chunks
def detect_text_col(df: pd.DataFrame) -> str:
    """Guess which column of *df* holds the free text.

    Heuristic: for every column, take up to the first 200 non-null values,
    convert them to strings, and compute their mean length. The column with
    the largest mean wins; on a tie the earliest column is kept. Columns
    without any non-null value are ignored.

    Raises:
        ValueError: If no column has a non-null value.
    """
    scored = []
    for name in df.columns:
        values = df[name].dropna().astype(str).head(200)
        if len(values) == 0:
            continue
        scored.append((values.map(len).mean(), name))

    # max() returns the first maximal entry, i.e. the earliest column on ties.
    winner = max(scored, key=lambda pair: pair[0])[1] if scored else None
    if winner is None:
        raise ValueError("Could not detect a text column in the .tab file.")
    return winner
def main():
    """Build a chunked corpus and a FAISS index from a tab-separated file.

    Steps:
      1. Parse the command-line arguments and create the output directory.
      2. Read the input as TSV (all columns as strings, malformed lines
         skipped) and pick the text column with ``detect_text_col``.
      3. Clean and chunk every row, dropping rows and chunks shorter than
         ``--min_chars`` characters.
      4. Write ``corpus.jsonl``, embed the chunks with the chosen
         sentence-transformers model, and write ``faiss.index`` and
         ``docstore.jsonl`` into the output directory.

    Raises:
        ValueError: If no text column can be detected, or if no chunk
            survives the length filter (nothing to embed or index).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_tab", required=True, help="Tripadvisor reviews .tab file")
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument("--chunk_chars", type=int, default=900)
    ap.add_argument("--overlap", type=int, default=150)
    ap.add_argument(
        "--min_chars",
        type=int,
        default=30,
        help="Rows and chunks shorter than this many characters are dropped",
    )
    args = ap.parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    # Many .tab files are TSV
    df = pd.read_csv(args.input_tab, sep="\t", dtype=str, on_bad_lines="skip")
    text_col = detect_text_col(df)
    rows = df[text_col].fillna("").astype(str).tolist()

    corpus = []
    for r in tqdm(rows, desc="Chunking"):
        # Clean first so the length filter applies to the normalized text.
        r = simple_clean(r)
        if len(r) < args.min_chars:
            continue
        for ch in chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap):
            if len(ch) < args.min_chars:
                continue
            # doc_id is a running counter over the kept chunks.
            corpus.append({"doc_id": len(corpus), "text": ch})

    if not corpus:
        # Fail early with a clear message instead of crashing later on
        # embs.shape[1] when the embedding array is empty.
        raise ValueError(
            f"No text chunks of at least {args.min_chars} characters were found "
            f"in column {text_col!r} of {args.input_tab}; nothing to index."
        )

    corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
    with open(corpus_path, "w", encoding="utf-8") as f:
        for ex in corpus:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    # Build FAISS index
    embedder = SentenceTransformer(args.embedding_model)
    texts = [c["text"] for c in corpus]
    embs = embedder.encode(
        texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True
    )
    embs = np.asarray(embs, dtype=np.float32)
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)  # cosine if normalized
    index.add(embs)
    faiss_path = os.path.join(args.out_dir, "faiss.index")
    faiss.write_index(index, faiss_path)

    # Store mapping from FAISS row position to the chunk record
    mapping_path = os.path.join(args.out_dir, "docstore.jsonl")
    with open(mapping_path, "w", encoding="utf-8") as f:
        for i, c in enumerate(corpus):
            f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")

    print(
        f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}\nText column detected: {text_col}"
    )
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()