Compare commits

...

2 Commits

Author SHA1 Message Date
49c622db08 RAG cleanup 2026-02-21 15:24:21 +01:00
04cc3f8e77 tsv to csv 2026-02-21 14:42:56 +01:00
4 changed files with 77 additions and 41 deletions

View File

@@ -9,7 +9,7 @@
## Vorbereiten des Retrieval-Corpus
```bash
python prepare_corpus.py --input_tab ../data/intermediate/culture_reviews.csv --out_dir out
python prepare_corpus.py --input_csv ../data/intermediate/culture_reviews.csv --out_dir out
```
## Erstellen des RAFT-Datensatzes
@@ -26,14 +26,16 @@ python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mi
## Inferenz
### Per Baseline Mistral 7B + PEFT-Adapter
```bash
python rag_chat.py --lora_dir out/mistral_balitwin_lora
```
### Pre-Merged Modell + Adapter
```bash
python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
```
### Per Baseline Mistral 7B + PEFT-Adapter
Hinweis: das Skript wurde nach wenigen oberflächlichen Evaluationsrunden nicht weiter verwendet, da der beste Kandidat durch einen Merge des Basismodells und seiner PEFT-Adapter beschleunigt werden konnte und dieses Skript nicht länger relevant war.
```bash
python deprecated_rag_chat.py --lora_dir out/mistral_balitwin_lora
```

View File

@@ -8,6 +8,7 @@ import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoTokenizer
def simple_clean(text: str) -> str:
@@ -18,21 +19,34 @@ def simple_clean(text: str) -> str:
return text
def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
def chunk_text(text: str, tokenizer, chunk_tokens: int = 250, overlap_tokens: int = 50):
"""
Simple char-based chunking (good enough for reviews).
For better chunking, split by sentences and cap token length.
Token-based chunking using a transformer tokenizer.
"""
text = simple_clean(text)
if len(text) <= chunk_chars:
if not text:
return []
# Tokenize the entire text
tokens = tokenizer.encode(text, add_special_tokens=False)
if len(tokens) <= chunk_tokens:
return [text] if text else []
chunks = []
i = 0
while i < len(text):
chunk = text[i : i + chunk_chars]
if chunk:
chunks.append(chunk)
i += max(1, chunk_chars - overlap)
while i < len(tokens):
# Get chunk of tokens
chunk_token_ids = tokens[i : i + chunk_tokens]
if chunk_token_ids:
# Decode tokens back to text
chunk_text_str = tokenizer.decode(chunk_token_ids, skip_special_tokens=True)
if chunk_text_str:
chunks.append(chunk_text_str)
i += max(1, chunk_tokens - overlap_tokens)
return chunks
@@ -48,28 +62,29 @@ def detect_text_col(df: pd.DataFrame) -> str:
best_score = avg_len
best_col = col
if best_col is None:
raise ValueError("Could not detect a text column in the .tab file.")
raise ValueError("Could not detect a text column in the .csv file.")
return best_col
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--input_tab", required=True, help="Tripadvisor reviews .tab file")
ap.add_argument("--input_csv", required=True, help="Tripadvisor reviews .csv file")
ap.add_argument("--out_dir", default="out")
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--chunk_chars", type=int, default=900)
ap.add_argument("--overlap", type=int, default=150)
ap.add_argument("--chunk_tokens", type=int, default=250)
ap.add_argument("--overlap_tokens", type=int, default=50)
args = ap.parse_args()
os.makedirs(args.out_dir, exist_ok=True)
# Many .tab files are TSV
df = pd.read_csv(args.input_tab, sep="\t", dtype=str, on_bad_lines="skip")
text_col = detect_text_col(df)
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
rows = df[text_col].fillna("").astype(str).tolist()
df = pd.read_csv(args.input_csv, sep=",", dtype=str, on_bad_lines="skip")
rows = df["Original"].fillna("").astype(str).tolist()
corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
corpus = []
@@ -79,7 +94,12 @@ def main():
r = simple_clean(r)
if len(r) < 30:
continue
chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
chunks = chunk_text(
r,
tokenizer,
chunk_tokens=args.chunk_tokens,
overlap_tokens=args.overlap_tokens,
)
for ch in chunks:
if len(ch) < 30:
continue
@@ -111,9 +131,7 @@ def main():
for i, c in enumerate(corpus):
f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")
print(
f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}\nText column detected: {text_col}"
)
print(f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}")
if __name__ == "__main__":

View File

@@ -37,7 +37,7 @@ def load_docstore(path):
return docs
def retrieve(index, embedder, query, top_k=6):
def retrieve(index, embedder, query, top_k=12):
q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
scores, ids = index.search(q, top_k)
return ids[0].tolist(), scores[0].tolist()
@@ -55,8 +55,9 @@ def main():
ap.add_argument(
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
)
ap.add_argument("--top_k", type=int, default=6)
ap.add_argument("--top_k", type=int, default=12)
ap.add_argument("--max_new_tokens", type=int, default=320)
ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)
args = ap.parse_args()
index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
@@ -70,12 +71,13 @@ def main():
if tok.pad_token is None:
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
device_map="auto",
torch_dtype=torch.float16,
)
model.eval()
if not args.no_model:
model = AutoModelForCausalLM.from_pretrained(
args.model_dir,
device_map="auto",
torch_dtype=torch.float16,
)
model.eval()
print("Type your question (Ctrl+C to exit).")
while True:
@@ -83,20 +85,34 @@ def main():
if not q:
continue
ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
ids, scores = retrieve(index, embedder, q, top_k=args.top_k)
# Drop irrelevant context
if scores[0] > 0:
filtered = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= 0.75]
if not filtered:
print("No relevant context found.")
continue
ids, scores = zip(*filtered)
else:
print("No relevant context found.")
continue
context_docs = [docstore[i]["text"] for i in ids]
context_blob = "\n\n".join(
[f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
)
context_blob = "\n\n".join([t for _, t in enumerate(context_docs)])
print("\nRetrieved Context:")
print(context_blob)
for i, (doc, score) in enumerate(zip(context_docs, scores)):
print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")
messages = [
{"role": "system", "content": SYSTEM_PERSONA},
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
]
if args.no_model:
continue
enc = tok.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)