Mirror of https://github.com/marvinscham/masterthesis-playground.git, synced 2026-03-22 00:12:42 +01:00

Compare commits: 1a99b53d44 ... 49c622db08

2 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 49c622db08 | |
| | 04cc3f8e77 | |
**README.md**

````diff
@@ -9,7 +9,7 @@
 ## Preparing the Retrieval Corpus
 
 ```bash
-python prepare_corpus.py --input_tab ../data/intermediate/culture_reviews.csv --out_dir out
+python prepare_corpus.py --input_csv ../data/intermediate/culture_reviews.csv --out_dir out
 ```
 
 ## Creating the RAFT Dataset
````
````diff
@@ -26,14 +26,16 @@ python train_mistral_raft.py --train_jsonl out/raft_train.jsonl --out_dir out/mi
 
 ## Inference
 
-### Via Baseline Mistral 7B + PEFT Adapter
-
-```bash
-python rag_chat.py --lora_dir out/mistral_balitwin_lora
-```
-
 ### Pre-Merged Model + Adapter
 
 ```bash
 python rag_chat_merged.py --model_dir /path/to/model_folder --out_dir out
 ```
+
+### Via Baseline Mistral 7B + PEFT Adapter
+
+Note: this script was retired after a few shallow evaluation rounds, since the best candidate could be sped up by merging the base model with its PEFT adapters, so the unmerged path was no longer relevant.
+
+```bash
+python deprecated_rag_chat.py --lora_dir out/mistral_balitwin_lora
+```
````
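The deprecation note refers to folding the LoRA adapter into the base model, so inference no longer pays the per-step adapter overhead. A minimal sketch of such a merge with the peft library; the base-model id and output path are assumptions, only `out/mistral_balitwin_lora` comes from this repo:

```python
# Hedged sketch: fold a LoRA adapter into its base model via peft's
# merge_and_unload(). Base model id and output path are assumptions.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "mistralai/Mistral-7B-Instruct-v0.2"  # assumption: actual base unknown
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float16)

# Load the adapter on top of the base weights, then bake it in.
merged = PeftModel.from_pretrained(base, "out/mistral_balitwin_lora").merge_and_unload()
merged.save_pretrained("out/mistral_balitwin_merged")
AutoTokenizer.from_pretrained(base_id).save_pretrained("out/mistral_balitwin_merged")
```

The resulting folder could then be passed to `rag_chat_merged.py --model_dir out/mistral_balitwin_merged`.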
**prepare_corpus.py**

```diff
@@ -8,6 +8,7 @@ import numpy as np
 import pandas as pd
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 
 def simple_clean(text: str) -> str:
```
```diff
@@ -18,21 +19,34 @@ def simple_clean(text: str) -> str:
     return text
 
 
-def chunk_text(text: str, chunk_chars: int = 900, overlap: int = 150):
+def chunk_text(text: str, tokenizer, chunk_tokens: int = 250, overlap_tokens: int = 50):
     """
-    Simple char-based chunking (good enough for reviews).
-    For better chunking, split by sentences and cap token length.
+    Token-based chunking using a transformer tokenizer.
     """
     text = simple_clean(text)
-    if len(text) <= chunk_chars:
+    if not text:
+        return []
+
+    # Tokenize the entire text
+    tokens = tokenizer.encode(text, add_special_tokens=False)
+
+    if len(tokens) <= chunk_tokens:
         return [text] if text else []
 
     chunks = []
     i = 0
-    while i < len(text):
-        chunk = text[i : i + chunk_chars]
-        if chunk:
-            chunks.append(chunk)
-        i += max(1, chunk_chars - overlap)
+    while i < len(tokens):
+        # Get chunk of tokens
+        chunk_token_ids = tokens[i : i + chunk_tokens]
+
+        if chunk_token_ids:
+            # Decode tokens back to text
+            chunk_text_str = tokenizer.decode(chunk_token_ids, skip_special_tokens=True)
+            if chunk_text_str:
+                chunks.append(chunk_text_str)
+
+        i += max(1, chunk_tokens - overlap_tokens)
 
     return chunks
```
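The rewritten `chunk_text` advances by a stride of `chunk_tokens - overlap_tokens`, so with the new defaults consecutive chunks share 50 tokens. One caveat worth knowing: decoding `bert-base-uncased` ids lowercases the stored chunks, since that tokenizer normalizes case. A quick standalone stride check:

```python
# Stride check for the token-based chunking above: with defaults 250/50
# the window starts land 200 tokens apart.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.encode("word " * 600, add_special_tokens=False)  # 600 ids

chunk_tokens, overlap_tokens = 250, 50
starts = list(range(0, len(tokens), chunk_tokens - overlap_tokens))
print(len(tokens), starts)  # 600 [0, 200, 400] -> chunks of 250, 250, 200 ids
```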
```diff
@@ -48,28 +62,29 @@ def detect_text_col(df: pd.DataFrame) -> str:
             best_score = avg_len
             best_col = col
     if best_col is None:
-        raise ValueError("Could not detect a text column in the .tab file.")
+        raise ValueError("Could not detect a text column in the .csv file.")
     return best_col
 
 
 def main():
     ap = argparse.ArgumentParser()
-    ap.add_argument("--input_tab", required=True, help="Tripadvisor reviews .tab file")
+    ap.add_argument("--input_csv", required=True, help="Tripadvisor reviews .csv file")
     ap.add_argument("--out_dir", default="out")
     ap.add_argument(
         "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
     )
-    ap.add_argument("--chunk_chars", type=int, default=900)
-    ap.add_argument("--overlap", type=int, default=150)
+    ap.add_argument("--chunk_tokens", type=int, default=250)
+    ap.add_argument("--overlap_tokens", type=int, default=50)
     args = ap.parse_args()
 
     os.makedirs(args.out_dir, exist_ok=True)
 
-    # Many .tab files are TSV
-    df = pd.read_csv(args.input_tab, sep="\t", dtype=str, on_bad_lines="skip")
-    text_col = detect_text_col(df)
-
-    rows = df[text_col].fillna("").astype(str).tolist()
+    # Initialize tokenizer
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+    df = pd.read_csv(args.input_csv, sep=",", dtype=str, on_bad_lines="skip")
+
+    rows = df["Original"].fillna("").astype(str).tolist()
 
     corpus_path = os.path.join(args.out_dir, "corpus.jsonl")
     corpus = []
```
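Two things stand out here: `detect_text_col` keeps an updated error message but appears no longer to be called in the shown hunks, since the read now hardcodes the `Original` column; and `chunk_tokens=250` sits just under the 256-word-piece input limit of the default embedder, `sentence-transformers/all-MiniLM-L6-v2`, which uses BERT's uncased WordPiece vocabulary, so chunks plausibly survive embedding without truncation. A quick check of that limit:

```python
# Confirm the embedder's window; longer inputs are truncated, which is
# presumably why chunk_tokens defaults to 250 (an assumption, not stated).
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print(embedder.max_seq_length)  # 256
```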
```diff
@@ -79,7 +94,12 @@ def main():
         r = simple_clean(r)
         if len(r) < 30:
             continue
-        chunks = chunk_text(r, chunk_chars=args.chunk_chars, overlap=args.overlap)
+        chunks = chunk_text(
+            r,
+            tokenizer,
+            chunk_tokens=args.chunk_tokens,
+            overlap_tokens=args.overlap_tokens,
+        )
         for ch in chunks:
             if len(ch) < 30:
                 continue
```
```diff
@@ -111,9 +131,7 @@ def main():
         for i, c in enumerate(corpus):
             f.write(json.dumps({"faiss_id": i, **c}, ensure_ascii=False) + "\n")
 
-    print(
-        f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}\nText column detected: {text_col}"
-    )
+    print(f"Saved:\n- {corpus_path}\n- {faiss_path}\n- {mapping_path}")
 
 
 if __name__ == "__main__":
```
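The indexing step behind `faiss_path` and `mapping_path` lies outside the changed hunks. For orientation, a minimal sketch of how a corpus like `out/corpus.jsonl` is typically embedded and indexed with normalized embeddings; the index type and details are assumptions, not taken from this diff:

```python
# Hedged sketch of the indexing step implied above; the repo's actual
# implementation is not shown in this diff, so treat details as assumptions.
import json

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

with open("out/corpus.jsonl", encoding="utf-8") as f:
    corpus = [json.loads(line) for line in f]

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb = embedder.encode(
    [c["text"] for c in corpus], normalize_embeddings=True
).astype(np.float32)

index = faiss.IndexFlatIP(emb.shape[1])  # inner product == cosine on unit vectors
index.add(emb)
faiss.write_index(index, "out/faiss.index")
```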
**rag_chat_merged.py**

```diff
@@ -37,7 +37,7 @@ def load_docstore(path):
     return docs
 
 
-def retrieve(index, embedder, query, top_k=6):
+def retrieve(index, embedder, query, top_k=12):
     q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
     scores, ids = index.search(q, top_k)
     return ids[0].tolist(), scores[0].tolist()
```
```diff
@@ -55,8 +55,9 @@ def main():
     ap.add_argument(
         "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
     )
-    ap.add_argument("--top_k", type=int, default=6)
+    ap.add_argument("--top_k", type=int, default=12)
     ap.add_argument("--max_new_tokens", type=int, default=320)
+    ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)
     args = ap.parse_args()
 
     index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
```
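`argparse.BooleanOptionalAction` (Python 3.9+) defaults to `None`, which is falsy, so the model still loads unless `--no_model` is passed; it also auto-generates a `--no-no_model` negative form. A minimal repro of the new flag's behavior:

```python
# Behavior of the new flag: absent -> None (falsy), present -> True.
import argparse

ap = argparse.ArgumentParser()
ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)

print(ap.parse_args([]))              # Namespace(no_model=None): model is loaded
print(ap.parse_args(["--no_model"]))  # Namespace(no_model=True): retrieval-only run
```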
```diff
@@ -70,12 +71,13 @@ def main():
     if tok.pad_token is None:
         tok.pad_token = tok.eos_token
 
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_dir,
-        device_map="auto",
-        torch_dtype=torch.float16,
-    )
-    model.eval()
+    if not args.no_model:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_dir,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+        model.eval()
 
     print("Type your question (Ctrl+C to exit).")
     while True:
```
```diff
@@ -83,20 +85,34 @@ def main():
         if not q:
             continue
 
-        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
+        ids, scores = retrieve(index, embedder, q, top_k=args.top_k)
+
+        # Drop irrelevant context
+        if scores[0] > 0:
+            filtered = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= 0.75]
+            if not filtered:
+                print("No relevant context found.")
+                continue
+            ids, scores = zip(*filtered)
+        else:
+            print("No relevant context found.")
+            continue
+
         context_docs = [docstore[i]["text"] for i in ids]
-        context_blob = "\n\n".join(
-            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
-        )
+        context_blob = "\n\n".join([t for _, t in enumerate(context_docs)])
 
         print("\nRetrieved Context:")
-        print(context_blob)
+        for i, (doc, score) in enumerate(zip(context_docs, scores)):
+            print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")
 
         messages = [
             {"role": "system", "content": SYSTEM_PERSONA},
             {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
         ]
 
+        if args.no_model:
+            continue
+
         enc = tok.apply_chat_template(
             messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
         )
```
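Because the query embedding is normalized (and assuming the corpus embeddings were normalized the same way at index time), the FAISS scores behave like cosine similarities, and the new filter keeps only hits scoring at least 75% of the best hit; the doubled default `top_k=12` gives this filter more candidates to prune. A standalone sketch of the same logic with a worked example:

```python
# Standalone version of the relative-score filter added above.
def filter_hits(ids, scores, ratio=0.75):
    """Keep hits within `ratio` of the best score; empty if the best is <= 0."""
    if not scores or scores[0] <= 0:
        return [], []
    kept = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= ratio]
    return [i for i, _ in kept], [s for _, s in kept]

# Best hit 0.82 -> cutoff 0.615; 0.64 passes (0.78 of best), 0.58 and 0.31 drop.
print(filter_hits([7, 3, 9, 1], [0.82, 0.64, 0.58, 0.31]))  # ([7, 3], [0.82, 0.64])
```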