diff --git a/raft/README.md b/raft/README.md
index 45e8be3..56ffd87 100644
--- a/raft/README.md
+++ b/raft/README.md
@@ -26,14 +26,16 @@ python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mi
 
 ## Inference
 
-### Baseline Mistral 7B + PEFT adapter
-
-```bash
-python rag_chat.py --lora_dir out/mistral_balitwin_lora
-```
-
 ### Pre-merged model + adapter
 
 ```bash
 python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
 ```
+
+### Baseline Mistral 7B + PEFT adapter
+
+Note: this script is no longer used. After a few superficial evaluation rounds, the best candidate was sped up by merging the base model with its PEFT adapters, which made this script obsolete.
+
+```bash
+python deprecated_rag_chat.py --lora_dir out/mistral_balitwin_lora
+```
diff --git a/raft/rag_chat.py b/raft/deprecated_rag_chat.py
similarity index 100%
rename from raft/rag_chat.py
rename to raft/deprecated_rag_chat.py
diff --git a/raft/prepare_corpus.py b/raft/prepare_corpus.py
index 88c05e9..182d2ea 100644
--- a/raft/prepare_corpus.py
+++ b/raft/prepare_corpus.py
@@ -8,6 +8,7 @@ import numpy as np
 import pandas as pd
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 
 def simple_clean(text: str) -> str:
@@ -18,21 +19,34 @@ def simple_clean(text: str) -> str:
     return text
 
 
-def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
+def chunk_text(text: str, tokenizer, chunk_tokens: int = 250, overlap_tokens: int = 50):
     """
-    Simple char-based chunking (good enough for reviews).
-    For better chunking, split by sentences and cap token length.
+    Token-based chunking using a transformer tokenizer.
     """
     text = simple_clean(text)
-    if len(text) <= chunk_chars:
+    if not text:
+        return []
+
+    # Tokenize the entire text once, then slide a fixed-size window over the ids
+    tokens = tokenizer.encode(text, add_special_tokens=False)
+
+    if len(tokens) <= chunk_tokens:
         return [text] if text else []
+
     chunks = []
     i = 0
-    while i < len(text):
-        chunk = text[i : i + chunk_chars]
-        if chunk:
-            chunks.append(chunk)
-        i += max(1, chunk_chars - overlap)
+    while i < len(tokens):
+        # Take the next window of token ids
+        chunk_token_ids = tokens[i : i + chunk_tokens]
+
+        if chunk_token_ids:
+            # Decode the window back to text
+            chunk_text_str = tokenizer.decode(chunk_token_ids, skip_special_tokens=True)
+            if chunk_text_str:
+                chunks.append(chunk_text_str)
+
+        i += max(1, chunk_tokens - overlap_tokens)  # step = chunk size - overlap
+
     return chunks
 
 
@@ -59,12 +73,15 @@ def main():
     ap.add_argument(
         "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
     )
-    ap.add_argument("--chunk_chars", type=int, default=900)
-    ap.add_argument("--overlap", type=int, default=150)
+    ap.add_argument("--chunk_tokens", type=int, default=250)
+    ap.add_argument("--overlap_tokens", type=int, default=50)
     args = ap.parse_args()
 
     os.makedirs(args.out_dir, exist_ok=True)
 
+    # Initialize the tokenizer used for chunking
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
     df = pd.read_csv(args.input_csv, sep=",", dtype=str, on_bad_lines="skip")
     rows = df["Original"].fillna("").astype(str).tolist()
 
@@ -77,7 +94,12 @@ def main():
         r = simple_clean(r)
         if len(r) < 30:
             continue
-        chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
+        chunks = chunk_text(
+            r,
+            tokenizer,
+            chunk_tokens=args.chunk_tokens,
+            overlap_tokens=args.overlap_tokens,
+        )
         for ch in chunks:
             if len(ch) < 30:
                 continue
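For reference while reviewing the chunking change: a minimal sketch of the sliding-window arithmetic in `chunk_text`, with a whitespace split standing in for the tokenizer so the snippet runs without downloads. `chunk_windows` and the sample sizes are hypothetical; the script itself encodes and decodes with `transformers.AutoTokenizer`.

```python
# Sketch of the sliding-window logic behind prepare_corpus.chunk_text.
# A whitespace split stands in for the real tokenizer; chunk_windows
# and the pseudo-token count are illustrative only.

def chunk_windows(tokens, chunk_tokens=250, overlap_tokens=50):
    if len(tokens) <= chunk_tokens:
        return [tokens]
    windows, i = [], 0
    while i < len(tokens):
        window = tokens[i : i + chunk_tokens]
        if window:
            windows.append(window)
        i += max(1, chunk_tokens - overlap_tokens)  # step = size - overlap
    return windows

words = "some review text ".split() * 200           # 600 pseudo-tokens
windows = chunk_windows(words, chunk_tokens=250, overlap_tokens=50)
print([len(w) for w in windows])                     # [250, 250, 200]
```

Each step advances by `chunk_tokens - overlap_tokens`, so consecutive chunks share 50 tokens of context and no text is lost at the seams.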
diff --git a/raft/rag_chat_merged.py b/raft/rag_chat_merged.py
index 3079a8b..d1cb445 100644
--- a/raft/rag_chat_merged.py
+++ b/raft/rag_chat_merged.py
@@ -37,7 +37,7 @@ def load_docstore(path):
     return docs
 
 
-def retrieve(index, embedder, query, top_k=6):
+def retrieve(index, embedder, query, top_k=12):
     q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
     scores, ids = index.search(q, top_k)
     return ids[0].tolist(), scores[0].tolist()
@@ -55,8 +55,9 @@ def main():
     ap.add_argument(
         "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )
-    ap.add_argument("--top_k", type=int, default=6)
+    ap.add_argument("--top_k", type=int, default=12)
     ap.add_argument("--max_new_tokens", type=int, default=320)
+    ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)
     args = ap.parse_args()
 
     index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
@@ -70,12 +71,13 @@ def main():
     if tok.pad_token is None:
         tok.pad_token = tok.eos_token
 
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_dir,
-        device_map="auto",
-        torch_dtype=torch.float16,
-    )
-    model.eval()
+    if not args.no_model:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_dir,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+        model.eval()
 
     print("Type your question (Ctrl+C to exit).")
     while True:
@@ -83,20 +85,34 @@ def main():
         if not q:
             continue
 
-        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
+        ids, scores = retrieve(index, embedder, q, top_k=args.top_k)
+
+        # Drop hits scoring far below the best one (relative cutoff at 75%)
+        if scores[0] > 0:
+            filtered = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= 0.75]
+            if not filtered:
+                print("No relevant context found.")
+                continue
+            ids, scores = zip(*filtered)
+        else:
+            print("No relevant context found.")
+            continue
+
         context_docs = [docstore[i]["text"] for i in ids]
-        context_blob = "\n\n".join(
-            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
-        )
+        context_blob = "\n\n".join(context_docs)
 
         print("\nRetrieved Context:")
-        print(context_blob)
+        for i, (doc, score) in enumerate(zip(context_docs, scores)):
+            print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")
 
         messages = [
             {"role": "system", "content": SYSTEM_PERSONA},
             {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
         ]
+
+        if args.no_model:
+            continue
+
         enc = tok.apply_chat_template(
             messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
         )
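And a small sketch of the relative-score cutoff introduced in `rag_chat_merged.py`, run on fabricated scores. It assumes hits arrive sorted best-first, as a FAISS inner-product search over normalized embeddings returns them, so `scores[0]` is the maximum.

```python
# Sketch of the 75% relative-score cutoff from rag_chat_merged.py.
# ids and scores are fabricated for illustration; real values come
# from retrieve(), which returns FAISS hits sorted best-first.
ids = [3, 17, 42, 8]               # hypothetical docstore ids
scores = [0.81, 0.74, 0.60, 0.12]  # cosine similarities, descending

if scores[0] > 0:
    kept = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= 0.75]
else:
    kept = []                      # best hit non-positive: no usable context

print(kept)  # [(3, 0.81), (17, 0.74)]: 0.74/0.81 ≈ 0.91 passes, 0.60/0.81 ≈ 0.74 does not
```

Anything below 75% of the best hit is dropped before prompt construction; if even the best hit is non-positive, the query is treated as having no relevant context and the model is never called.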