Compare commits

...

2 Commits

SHA1        Message      Date
49c622db08  RAG cleanup  2026-02-21 15:24:21 +01:00
04cc3f8e77  tsv to csv   2026-02-21 14:42:56 +01:00
4 changed files with 77 additions and 41 deletions

README.md

@@ -9,7 +9,7 @@
 ## Preparing the retrieval corpus
 ```bash
-python prepare_corpus.py --input_tab ../data/intermediate/culture_reviews.csv --out_dir out
+python prepare_corpus.py --input_csv ../data/intermediate/culture_reviews.csv --out_dir out
 ```
 ## Building the RAFT dataset
@@ -26,14 +26,16 @@ python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mi
 ## Inference
-### Via baseline Mistral 7B + PEFT adapter
-```bash
-python rag_chat.py --lora_dir out/mistral_balitwin_lora
-```
 ### Pre-merged model + adapter
 ```bash
 python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
 ```
+### Via baseline Mistral 7B + PEFT adapter
+Note: after a few superficial evaluation rounds this script was no longer used, since the best candidate could be sped up by merging the base model with its PEFT adapters, which made the script obsolete.
+```bash
+python deprecated_rag_chat.py --lora_dir out/mistral_balitwin_lora
+```
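The note above says the preferred setup merges the base model with its PEFT adapters, which removes the per-forward LoRA computation and explains the speed-up. For reference, a minimal sketch of how such a merge is typically done with the peft library; the base-model ID and output directory are assumptions for illustration, not taken from this repo:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Hypothetical base-model ID; the README only says "Mistral 7B".
BASE = "mistralai/Mistral-7B-Instruct-v0.2"

base = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float16)
tok = AutoTokenizer.from_pretrained(BASE)

# Load the LoRA adapter trained above and fold its weights into the base model.
merged = PeftModel.from_pretrained(base, "out/mistral_balitwin_lora").merge_and_unload()

# Save a standalone checkpoint usable as --model_dir for rag_chat_merged.py.
merged.save_pretrained("out/mistral_balitwin_merged")  # output path is illustrative
tok.save_pretrained("out/mistral_balitwin_merged")
```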

prepare_corpus.py

@@ -8,6 +8,7 @@ import numpy as np
 import pandas as pd
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 def simple_clean(text: str) -> str:
@@ -18,21 +19,34 @@ def simple_clean(text: str) -> str:
     return text
 
-def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
+def chunk_text(text: str, tokenizer, chunk_tokens: int = 250, overlap_tokens: int = 50):
     """
-    Simple char-based chunking (good enough for reviews).
-    For better chunking, split by sentences and cap token length.
+    Token-based chunking using a transformer tokenizer.
     """
     text = simple_clean(text)
-    if len(text) <= chunk_chars:
+    if not text:
+        return []
+    # Tokenize the entire text
+    tokens = tokenizer.encode(text, add_special_tokens=False)
+    if len(tokens) <= chunk_tokens:
         return [text] if text else []
     chunks = []
     i = 0
-    while i < len(text):
-        chunk = text[i : i + chunk_chars]
-        if chunk:
-            chunks.append(chunk)
-        i += max(1, chunk_chars - overlap)
+    while i < len(tokens):
+        # Get chunk of tokens
+        chunk_token_ids = tokens[i : i + chunk_tokens]
+        if chunk_token_ids:
+            # Decode tokens back to text
+            chunk_text_str = tokenizer.decode(chunk_token_ids, skip_special_tokens=True)
+            if chunk_text_str:
+                chunks.append(chunk_text_str)
+        i += max(1, chunk_tokens - overlap_tokens)
     return chunks
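For intuition on the new parameters: the window advances by chunk_tokens - overlap_tokens = 200 tokens per step, so consecutive chunks share 50 tokens of context. A small sketch, assuming it runs alongside chunk_text from this script (the input is synthetic):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Synthetic input: roughly 600 single-token words.
text = "word " * 600
chunks = chunk_text(text, tokenizer, chunk_tokens=250, overlap_tokens=50)

# Token windows covered: [0:250], [200:450], [400:600]
print(len(chunks))  # 3
```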
@@ -48,28 +62,29 @@ def detect_text_col(df: pd.DataFrame) -> str:
             best_score = avg_len
             best_col = col
     if best_col is None:
-        raise ValueError("Could not detect a text column in the .tab file.")
+        raise ValueError("Could not detect a text column in the .csv file.")
     return best_col
 
 def main():
     ap = argparse.ArgumentParser()
-    ap.add_argument("--input_tab", required=True, help="Tripadvisor reviews .tab file")
+    ap.add_argument("--input_csv", required=True, help="Tripadvisor reviews .csv file")
     ap.add_argument("--out_dir", default="out")
     ap.add_argument(
         "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
     )
-    ap.add_argument("--chunk_chars", type=int, default=900)
-    ap.add_argument("--overlap", type=int, default=150)
     args = ap.parse_args()
+    ap.add_argument("--chunk_tokens", type=int, default=250)
+    ap.add_argument("--overlap_tokens", type=int, default=50)
+    args = ap.parse_args()
 
     os.makedirs(args.out_dir, exist_ok=True)
 
-    # Many .tab files are TSV
-    df = pd.read_csv(args.input_tab, sep="\t", dtype=str, on_bad_lines="skip")
-    text_col = detect_text_col(df)
-    rows = df[text_col].fillna("").astype(str).tolist()
+    # Initialize tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+    df = pd.read_csv(args.input_csv, sep=",", dtype=str, on_bad_lines="skip")
+    rows = df["Original"].fillna("").astype(str).tolist()
 
     corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
     corpus = []
@@ -79,7 +94,12 @@ def main():
         r = simple_clean(r)
         if len(r) < 30:
             continue
-        chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
+        chunks = chunk_text(
+            r,
+            tokenizer,
+            chunk_tokens=args.chunk_tokens,
+            overlap_tokens=args.overlap_tokens,
+        )
         for ch in chunks:
             if len(ch) < 30:
                 continue
@@ -111,9 +131,7 @@ def main():
     for i, c in enumerate(corpus):
         f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")
 
-    print(
-        f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}\nText column detected: {text_col}"
-    )
+    print(f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}")
 
 if __name__ == "__main__":
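The summary line no longer reports a detected text column, since the column is now hardcoded to "Original". The output format itself is unchanged: each mapping line pairs a FAISS row id with the chunk record, whose text field is presumably what rag_chat_merged.py reads back through its docstore. A representative line (the review text is invented):

```
{"faiss_id": 0, "text": "The temple ceremony was beautiful and our guide explained every offering..."}
```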

rag_chat_merged.py

@@ -37,7 +37,7 @@ def load_docstore(path):
     return docs
 
-def retrieve(index, embedder, query, top_k=6):
+def retrieve(index, embedder, query, top_k=12):
     q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
     scores, ids = index.search(q, top_k)
     return ids[0].tolist(), scores[0].tolist()
@@ -55,8 +55,9 @@ def main():
     ap.add_argument(
         "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
     )
-    ap.add_argument("--top_k", type=int, default=6)
+    ap.add_argument("--top_k", type=int, default=12)
     ap.add_argument("--max_new_tokens", type=int, default=320)
+    ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)
     args = ap.parse_args()
 
     index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
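The new --no_model flag uses argparse.BooleanOptionalAction, so it defaults to None (falsy) and the model loads as before unless the flag is passed. Passing it turns the script into a retrieval-only inspector, presumably for debugging the index without waiting on the 7B model; a likely invocation, reusing the README's example paths:

```bash
python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out --no_model
```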
@@ -70,12 +71,13 @@ def main():
     if tok.pad_token is None:
         tok.pad_token = tok.eos_token
 
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_dir,
-        device_map="auto",
-        torch_dtype=torch.float16,
-    )
-    model.eval()
+    if not args.no_model:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_dir,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+        model.eval()
 
     print("Type your question (Ctrl+C to exit).")
     while True:
@@ -83,20 +85,34 @@ def main():
         if not q:
             continue
 
-        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
+        ids, scores = retrieve(index, embedder, q, top_k=args.top_k)
+
+        # Drop irrelevant context
+        if scores[0] > 0:
+            filtered = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= 0.75]
+            if not filtered:
+                print("No relevant context found.")
+                continue
+            ids, scores = zip(*filtered)
+        else:
+            print("No relevant context found.")
+            continue
+
         context_docs = [docstore[i]["text"] for i in ids]
-        context_blob = "\n\n".join(
-            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
-        )
+        context_blob = "\n\n".join([t for _, t in enumerate(context_docs)])
 
         print("\nRetrieved Context:")
-        print(context_blob)
+        for i, (doc, score) in enumerate(zip(context_docs, scores)):
+            print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")
 
         messages = [
             {"role": "system", "content": SYSTEM_PERSONA},
             {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
         ]
 
+        if args.no_model:
+            continue
+
         enc = tok.apply_chat_template(
             messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
         )
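The relevance filter added above keeps a hit only when its score is at least 75% of the top hit's score; since the query embedding is normalized, these inner-product scores likely behave as cosine similarities, so the scores[0] > 0 guard drops queries whose best match is already anti-correlated. A worked example with invented scores:

```python
# Invented FAISS results, best-first, to trace the 0.75 relative threshold.
ids = [17, 3, 42]
scores = [0.82, 0.70, 0.55]

filtered = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= 0.75]
# 0.82/0.82 = 1.00 -> keep; 0.70/0.82 ≈ 0.85 -> keep; 0.55/0.82 ≈ 0.67 -> drop
print(filtered)  # [(17, 0.82), (3, 0.70)]
```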