Mirror of https://github.com/marvinscham/masterthesis-playground.git
(synced 2026-03-22 16:32:42 +01:00)

139 lines · 4.0 KiB · Python
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
|
|
import faiss
|
|
import numpy as np
|
|
import pandas as pd
|
|
from sentence_transformers import SentenceTransformer
|
|
from tqdm import tqdm
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
def simple_clean(text: str) -> str:
    """Normalize whitespace in *text*; anything that is not a str becomes ""."""
    if not isinstance(text, str):
        return ""
    # Turn non-breaking spaces into plain spaces, then collapse every
    # whitespace run into a single space and trim the ends.
    collapsed = re.sub(r"\s+", " ", text.replace("\u00a0", " "))
    return collapsed.strip()
|
|
|
|
|
|
def chunk_text(text: str, tokenizer, chunk_tokens: int = 250, overlap_tokens: int = 50):
    """Token-based chunking using a transformer tokenizer.

    Args:
        text: Raw input text; normalized via ``simple_clean`` first.
        tokenizer: HuggingFace-style tokenizer exposing ``encode``/``decode``.
        chunk_tokens: Maximum number of tokens per chunk.
        overlap_tokens: Tokens shared between consecutive chunks.

    Returns:
        List of chunk strings (empty list for empty/non-string input).
    """
    text = simple_clean(text)
    if not text:
        return []

    # Tokenize the entire text once.
    tokens = tokenizer.encode(text, add_special_tokens=False)

    # Short texts fit in a single chunk; return the cleaned original verbatim.
    if len(tokens) <= chunk_tokens:
        return [text]

    # Stride between window starts; max(1, ...) guards against a non-positive
    # step when overlap_tokens >= chunk_tokens (would loop forever otherwise).
    step = max(1, chunk_tokens - overlap_tokens)

    chunks = []
    for start in range(0, len(tokens), step):
        window = tokens[start : start + chunk_tokens]
        if not window:
            break
        # Decode the token window back into text.
        decoded = tokenizer.decode(window, skip_special_tokens=True)
        if decoded:
            chunks.append(decoded)
        # BUG FIX: once a window reaches the end of the token stream, every
        # subsequent window is a strict suffix of this one (pure duplicates —
        # severe when overlap_tokens is close to chunk_tokens). Stop here.
        if start + chunk_tokens >= len(tokens):
            break

    return chunks
|
|
|
|
|
|
def detect_text_col(df: pd.DataFrame) -> str:
    """Heuristically pick the text column: highest average string length.

    Only the first 200 non-null values per column are sampled; ties keep the
    earliest column (dict order matches df.columns order).

    Raises:
        ValueError: if no column has any non-null values.
    """
    scores = {}
    for col in df.columns:
        sample = df[col].dropna().astype(str).head(200)
        if sample.empty:
            continue
        scores[col] = sample.map(len).mean()

    if not scores:
        raise ValueError("Could not detect a text column in the .csv file.")

    # max() returns the first key with the maximal score, matching the
    # strict-greater-than comparison of a running-best loop.
    return max(scores, key=scores.get)
|
|
|
|
|
|
def main():
    """CLI entry point: chunk a reviews CSV and build a FAISS index over it.

    Writes three artifacts into --out_dir:
        corpus.jsonl   - one {"doc_id", "text"} record per kept chunk
        faiss.index    - IndexFlatIP over normalized embeddings (cosine)
        docstore.jsonl - mapping faiss_id -> chunk record
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input_csv", required=True, help="Tripadvisor reviews .csv file")
    ap.add_argument("--out_dir", default="out")
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    ap.add_argument(
        "--tokenizer_model",
        default="bert-base-uncased",
        help="Tokenizer defining chunk boundaries (default preserves prior behavior)",
    )
    ap.add_argument("--chunk_tokens", type=int, default=250)
    ap.add_argument("--overlap_tokens", type=int, default=50)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # The tokenizer only defines chunk boundaries; the embedding model is
    # configured separately via --embedding_model.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model)

    df = pd.read_csv(args.input_csv, sep=",", dtype=str, on_bad_lines="skip")

    # Prefer the expected "Original" column; fall back to the heuristic
    # detector so CSVs with a different layout still work (previously this
    # raised KeyError and detect_text_col was dead code).
    text_col = "Original" if "Original" in df.columns else detect_text_col(df)
    rows = df[text_col].fillna("").astype(str).tolist()

    corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
    corpus = []
    doc_id = 0

    for r in tqdm(rows, desc="Chunking"):
        r = simple_clean(r)
        if len(r) < 30:  # drop near-empty reviews
            continue
        chunks = chunk_text(
            r,
            tokenizer,
            chunk_tokens=args.chunk_tokens,
            overlap_tokens=args.overlap_tokens,
        )
        for ch in chunks:
            if len(ch) < 30:  # drop near-empty trailing chunks
                continue
            corpus.append({"doc_id": doc_id, "text": ch})
            doc_id += 1

    with open(corpus_path, "w", encoding="utf-8") as f:
        for ex in corpus:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    # BUG FIX: an empty corpus would crash below (embs.shape[1] on an empty
    # array); fail early with an actionable message instead.
    if not corpus:
        raise ValueError(
            "No usable text chunks were produced from the input CSV."
        )

    # Build FAISS index; inner product equals cosine on normalized vectors.
    embedder = SentenceTransformer(args.embedding_model)
    texts = [c["text"] for c in corpus]
    embs = embedder.encode(
        texts, batch_size=64, show_progress_bar=True, normalize_embeddings=True
    )
    embs = np.asarray(embs, dtype=np.float32)

    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)  # cosine if normalized
    index.add(embs)

    faiss_path = os.path.join(args.out_dir, "faiss.index")
    faiss.write_index(index, faiss_path)

    # Store mapping: FAISS row id -> chunk record.
    mapping_path = os.path.join(args.out_dir, "docstore.jsonl")
    with open(mapping_path, "w", encoding="utf-8") as f:
        for i, c in enumerate(corpus):
            f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")

    print(f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}")
|
|
|
|
|
|
# Script entry point: only runs when executed directly, so the module can be
# imported (e.g. to reuse chunk_text / detect_text_col) without side effects.
if __name__ == "__main__":
    main()
|