Mirror of https://github.com/marvinscham/masterthesis-playground.git
Synced 2026-03-22 00:12:42 +01:00
RAG cleanup
@@ -37,7 +37,7 @@ def load_docstore(path):
     return docs


-def retrieve(index, embedder, query, top_k=6):
+def retrieve(index, embedder, query, top_k=12):
     q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
     scores, ids = index.search(q, top_k)
     return ids[0].tolist(), scores[0].tolist()
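Note on this hunk: with normalize_embeddings=True, inner-product search over the FAISS index returns cosine similarities, so raising top_k from 6 to 12 mainly feeds more candidates into the score filter added further down. A minimal, self-contained sketch of the same retrieval pattern, assuming the index was built as an IndexFlatIP (which the normalized embeddings suggest; random vectors stand in for real chunk embeddings, so no model download is needed):

import faiss
import numpy as np

rng = np.random.default_rng(0)
emb = rng.random((100, 384), dtype=np.float32)      # stand-in for encoded chunks
emb /= np.linalg.norm(emb, axis=1, keepdims=True)   # L2-normalize, as the script does

index = faiss.IndexFlatIP(emb.shape[1])             # inner product == cosine here
index.add(emb)

q = emb[:1]                                         # query identical to doc 0
scores, ids = index.search(q, 12)                   # top_k=12, matching the commit
print(ids[0][0], round(float(scores[0][0]), 4))     # -> 0 1.0: best hit is doc 0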
@@ -55,8 +55,9 @@ def main():
     ap.add_argument(
         "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
     )
-    ap.add_argument("--top_k", type=int, default=6)
+    ap.add_argument("--top_k", type=int, default=12)
     ap.add_argument("--max_new_tokens", type=int, default=320)
+    ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)
     args = ap.parse_args()

     index = faiss.read_index(os.path.join(args.out_dir, "faiss.index"))
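Note on this hunk: argparse.BooleanOptionalAction (Python 3.9+) auto-generates a paired negation flag, so with the dest named no_model the pair is --no_model and the awkward --no-no_model, and the default is None, which is falsy, so the model still loads unless the flag is passed. A quick demonstration:

import argparse

ap = argparse.ArgumentParser()
ap.add_argument("--no_model", action=argparse.BooleanOptionalAction)

print(ap.parse_args([]).no_model)                 # None  -> model is loaded
print(ap.parse_args(["--no_model"]).no_model)     # True  -> retrieval-only mode
print(ap.parse_args(["--no-no_model"]).no_model)  # False (auto-generated negation)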
@@ -70,12 +71,13 @@ def main():
     if tok.pad_token is None:
         tok.pad_token = tok.eos_token

-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_dir,
-        device_map="auto",
-        torch_dtype=torch.float16,
-    )
-    model.eval()
+    if not args.no_model:
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_dir,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+        model.eval()

     print("Type your question (Ctrl+C to exit).")
     while True:
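One consequence of this guard worth noting: when --no_model is passed, the name model is never bound at all, so the chat loop must bail out before any generation step, which the "if args.no_model: continue" added in the next hunk does. An illustrative, runnable sketch of the pairing (names here are ours, not from the repo):

def chat_loop(no_model: bool, queries: list[str]) -> None:
    model = None
    if not no_model:
        model = str.upper  # stand-in for AutoModelForCausalLM.from_pretrained(...)

    for q in queries:
        print("retrieved context for:", q)  # retrieval always runs
        if no_model:
            continue                        # must fire before `model` is used
        print("generated:", model(q))

chat_loop(no_model=True, queries=["a", "b"])  # retrieval-only
chat_loop(no_model=False, queries=["a"])      # retrieval + generation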
@@ -83,20 +85,34 @@ def main():
         if not q:
             continue

-        ids, _ = retrieve(index, embedder, q, top_k=args.top_k)
+        ids, scores = retrieve(index, embedder, q, top_k=args.top_k)
+
+        # Drop irrelevant context
+        if scores[0] > 0:
+            filtered = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= 0.75]
+            if not filtered:
+                print("No relevant context found.")
+                continue
+            ids, scores = zip(*filtered)
+        else:
+            print("No relevant context found.")
+            continue
+
         context_docs = [docstore[i]["text"] for i in ids]
-        context_blob = "\n\n".join(
-            [f"[DOC {i}] {t}" for i, t in enumerate(context_docs)]
-        )
+        context_blob = "\n\n".join([t for _, t in enumerate(context_docs)])

         print("\nRetrieved Context:")
-        print(context_blob)
+        for i, (doc, score) in enumerate(zip(context_docs, scores)):
+            print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")

         messages = [
             {"role": "system", "content": SYSTEM_PERSONA},
             {"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
         ]

+        if args.no_model:
+            continue
+
         enc = tok.apply_chat_template(
             messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
         )
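Note on the new filter: it keeps only hits scoring at least 75% of the best hit, and the "scores[0] > 0" check guards the ratio against a non-positive best score, which cosine similarity permits. A standalone restatement of the same logic for quick testing (the function name and threshold parameter are ours; the logic mirrors the commit):

def filter_by_relative_score(ids, scores, ratio=0.75):
    """Keep hits scoring at least `ratio` of the best hit (scores sorted best-first)."""
    if not scores or scores[0] <= 0:
        return [], []
    kept = [(i, s) for i, s in zip(ids, scores) if s / scores[0] >= ratio]
    ids, scores = zip(*kept) if kept else ((), ())
    return list(ids), list(scores)

# 0.70/0.82 ~= 0.85 passes; 0.31/0.82 ~= 0.38 is dropped
print(filter_by_relative_score([3, 9, 1], [0.82, 0.70, 0.31]))  # ([3, 9], [0.82, 0.7])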