Updated BERTopic quality gates

This commit is contained in:
2026-02-20 01:58:00 +01:00
parent 101bd81ca1
commit 99ba5031ca

View File

@@ -7,7 +7,7 @@
# format_version: '1.3' # format_version: '1.3'
# jupytext_version: 1.18.0 # jupytext_version: 1.18.0
# kernelspec: # kernelspec:
# display_name: .venv # display_name: .venv (3.12.3)
# language: python # language: python
# name: python3 # name: python3
# --- # ---
@@ -56,7 +56,7 @@ nltk.download("wordnet")
# %% # %%
RECREATE_MODEL = True RECREATE_MODEL = True
RECREATE_REDUCED_MODEL = True RECREATE_REDUCED_MODEL = True
PROCESS_DATA = False PROCESS_DATA = True
REDUCE_OUTLIERS = False REDUCE_OUTLIERS = False
# Data Sample Size, -1 for all data # Data Sample Size, -1 for all data
@@ -107,6 +107,12 @@ else:
.review.to_list() .review.to_list()
) )
# Remove all duplicate reviews
reviews = list(set(reviews))
# Remove reviews that contain less than x words
reviews = [review for review in reviews if len(review.split()) >= 9]
print("Loaded {} reviews".format(len(reviews))) print("Loaded {} reviews".format(len(reviews)))
# %% # %%
@@ -115,9 +121,6 @@ rep = {
r"\n": " ", r"\n": " ",
r'\\"': "", r'\\"': "",
r'"': "", r'"': "",
"mongkey": "monkey",
"monky": "monkey",
"verry": "very",
"bali": "", "bali": "",
r"\s+": " ", r"\s+": " ",
} }
@@ -302,6 +305,8 @@ if REDUCE_OUTLIERS:
# #
# %% # %%
CLASSIFICATION = False
if CLASSIFICATION:
import random import random
from pathlib import Path from pathlib import Path
@@ -316,12 +321,6 @@ BATCH_SIZE = 60
MIN_CHARS = 40 MIN_CHARS = 40
SEED = 42 SEED = 42
# --- load data ---
data = pd.read_csv(INPUT_PATH, sep="\t")
# If you already have `reviews` elsewhere, replace the next line with that variable
reviews = data["review"].astype(str).fillna("")
# Topic model document info # Topic model document info
df = topic_model.get_document_info(reviews) # assumes your model is already fitted df = topic_model.get_document_info(reviews) # assumes your model is already fitted
df["Original"] = reviews.values df["Original"] = reviews.values
@@ -353,7 +352,8 @@ for topic_val, g in filtered.groupby("Topic", sort=True):
# simple header for traceability # simple header for traceability
header = ( header = (
f"[TOPIC] {topic_val}\n" + f"[Stats] N={len(chunk)} | Source={INPUT_PATH}\n" f"[TOPIC] {topic_val}\n"
+ f"[Stats] N={len(chunk)} | Source={INPUT_PATH}\n"
) )
lines = [header, ""] lines = [header, ""]