mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
Updated BERTopic quality gates
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
# format_version: '1.3'
|
||||
# jupytext_version: 1.18.0
|
||||
# kernelspec:
|
||||
# display_name: .venv
|
||||
# display_name: .venv (3.12.3)
|
||||
# language: python
|
||||
# name: python3
|
||||
# ---
|
||||
@@ -56,7 +56,7 @@ nltk.download("wordnet")
|
||||
# %%
|
||||
RECREATE_MODEL = True
|
||||
RECREATE_REDUCED_MODEL = True
|
||||
PROCESS_DATA = False
|
||||
PROCESS_DATA = True
|
||||
REDUCE_OUTLIERS = False
|
||||
|
||||
# Data Sample Size, -1 for all data
|
||||
@@ -107,6 +107,12 @@ else:
|
||||
.review.to_list()
|
||||
)
|
||||
|
||||
# Remove all duplicate reviews
|
||||
reviews = list(set(reviews))
|
||||
|
||||
# Remove reviews that contain less than x words
|
||||
reviews = [review for review in reviews if len(review.split()) >= 9]
|
||||
|
||||
print("Loaded {} reviews".format(len(reviews)))
|
||||
|
||||
# %%
|
||||
@@ -115,9 +121,6 @@ rep = {
|
||||
r"\n": " ",
|
||||
r'\\"': "",
|
||||
r'"': "",
|
||||
"mongkey": "monkey",
|
||||
"monky": "monkey",
|
||||
"verry": "very",
|
||||
"bali": "",
|
||||
r"\s+": " ",
|
||||
}
|
||||
@@ -302,44 +305,40 @@ if REDUCE_OUTLIERS:
|
||||
#
|
||||
|
||||
# %%
|
||||
import random
|
||||
from pathlib import Path
|
||||
CLASSIFICATION = False
|
||||
if CLASSIFICATION:
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
# --- config ---
|
||||
topics_to_keep = {2, 4, 5, 9, 22, 26}
|
||||
INPUT_PATH = "../data/original/reviews.tab" # TSV with a 'review' column
|
||||
OUTPUT_CSV = "../data/intermediate/selected_topics_documents.csv"
|
||||
OUTPUT_DIR = Path("../raft/corpus")
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
# --- config ---
|
||||
topics_to_keep = {2, 4, 5, 9, 22, 26}
|
||||
INPUT_PATH = "../data/original/reviews.tab" # TSV with a 'review' column
|
||||
OUTPUT_CSV = "../data/intermediate/selected_topics_documents.csv"
|
||||
OUTPUT_DIR = Path("../raft/corpus")
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
BATCH_SIZE = 60
|
||||
MIN_CHARS = 40
|
||||
SEED = 42
|
||||
BATCH_SIZE = 60
|
||||
MIN_CHARS = 40
|
||||
SEED = 42
|
||||
|
||||
# --- load data ---
|
||||
data = pd.read_csv(INPUT_PATH, sep="\t")
|
||||
# Topic model document info
|
||||
df = topic_model.get_document_info(reviews) # assumes your model is already fitted
|
||||
df["Original"] = reviews.values
|
||||
|
||||
# If you already have `reviews` elsewhere, replace the next line with that variable
|
||||
reviews = data["review"].astype(str).fillna("")
|
||||
# --- filter by topics and length ---
|
||||
filtered = df[df["Topic"].isin(topics_to_keep)].copy()
|
||||
filtered["Original"] = filtered["Original"].str.strip()
|
||||
filtered = filtered[filtered["Original"].str.len() >= MIN_CHARS]
|
||||
|
||||
# Topic model document info
|
||||
df = topic_model.get_document_info(reviews) # assumes your model is already fitted
|
||||
df["Original"] = reviews.values
|
||||
# Save an audit CSV
|
||||
filtered[["Original", "Topic"]].to_csv(OUTPUT_CSV, index=False)
|
||||
|
||||
# --- filter by topics and length ---
|
||||
filtered = df[df["Topic"].isin(topics_to_keep)].copy()
|
||||
filtered["Original"] = filtered["Original"].str.strip()
|
||||
filtered = filtered[filtered["Original"].str.len() >= MIN_CHARS]
|
||||
# --- deterministic shuffle + write batched corpus files ---
|
||||
total_files = 0
|
||||
total_reviews = 0
|
||||
rng = random.Random(SEED)
|
||||
|
||||
# Save an audit CSV
|
||||
filtered[["Original", "Topic"]].to_csv(OUTPUT_CSV, index=False)
|
||||
|
||||
# --- deterministic shuffle + write batched corpus files ---
|
||||
total_files = 0
|
||||
total_reviews = 0
|
||||
rng = random.Random(SEED)
|
||||
|
||||
for topic_val, g in filtered.groupby("Topic", sort=True):
|
||||
for topic_val, g in filtered.groupby("Topic", sort=True):
|
||||
reviews_list = g["Original"].tolist()
|
||||
|
||||
# deterministic shuffle within topic
|
||||
@@ -353,7 +352,8 @@ for topic_val, g in filtered.groupby("Topic", sort=True):
|
||||
|
||||
# simple header for traceability
|
||||
header = (
|
||||
f"[TOPIC] {topic_val}\n" + f"[Stats] N={len(chunk)} | Source={INPUT_PATH}\n"
|
||||
f"[TOPIC] {topic_val}\n"
|
||||
+ f"[Stats] N={len(chunk)} | Source={INPUT_PATH}\n"
|
||||
)
|
||||
|
||||
lines = [header, ""]
|
||||
@@ -367,10 +367,10 @@ for topic_val, g in filtered.groupby("Topic", sort=True):
|
||||
total_files += 1
|
||||
total_reviews += len(chunk)
|
||||
|
||||
print(
|
||||
print(
|
||||
f"[green]Wrote {total_files} docs with {total_reviews} reviews to {OUTPUT_DIR}[/green]"
|
||||
)
|
||||
print(f"[green]Filtered CSV saved to {OUTPUT_CSV}[/green]")
|
||||
)
|
||||
print(f"[green]Filtered CSV saved to {OUTPUT_CSV}[/green]")
|
||||
|
||||
# %%
|
||||
doc_topic_matrix = probs
|
||||
|
||||
Reference in New Issue
Block a user