mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2026-03-22 00:12:42 +01:00
22.02.
This commit is contained in:
@@ -16,10 +16,10 @@ from bertopic import BERTopic
|
|||||||
|
|
||||||
param_grid = {
|
param_grid = {
|
||||||
"n_gram_max": [2, 3], # Vectorization
|
"n_gram_max": [2, 3], # Vectorization
|
||||||
"min_document_frequency": [1], # Vectorization
|
"min_document_frequency": [1, 2], # Vectorization
|
||||||
"min_samples": [10, 25], # HDBSCAN
|
"min_samples": [10, 25], # HDBSCAN
|
||||||
"min_topic_size": [10, 20, 30, 40, 50], # HDBSCAN
|
"min_topic_size": [100, 200], # HDBSCAN
|
||||||
"n_neighbors": [15], # UMAP
|
"n_neighbors": [15, 25], # UMAP
|
||||||
"n_components": [2, 5], # UMAP
|
"n_components": [2, 5], # UMAP
|
||||||
"min_dist": [0.01, 0.1], # UMAP
|
"min_dist": [0.01, 0.1], # UMAP
|
||||||
"nr_topics": ["auto"], # Topic Modeling
|
"nr_topics": ["auto"], # Topic Modeling
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
|
|||||||
with open("output/autotune.json", "r") as f:
|
with open("output/autotune.json", "r") as f:
|
||||||
history = json.load(f)
|
history = json.load(f)
|
||||||
|
|
||||||
history = sorted(history, key=lambda x: x["metrics"]["combined_score"], reverse=False)
|
history = sorted(history, key=lambda x: x["metrics"]["combined_score"], reverse=True)
|
||||||
|
|
||||||
with open("output/autotune_sorted.json", "w") as f:
|
with open("output/autotune_sorted.json", "w") as f:
|
||||||
json.dump(history, f, indent=2)
|
json.dump(history, f, indent=2)
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 16 KiB After Width: | Height: | Size: 21 KiB |
@@ -360,7 +360,6 @@ vis = topic_model.visualize_documents(
|
|||||||
custom_labels=True,
|
custom_labels=True,
|
||||||
hide_annotations=True,
|
hide_annotations=True,
|
||||||
)
|
)
|
||||||
# vis.write_html("output/visualization.html")
|
|
||||||
vis
|
vis
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
@@ -497,7 +496,12 @@ if CALCULATE_TOKEN_DISTRIBUTIONS:
|
|||||||
#
|
#
|
||||||
|
|
||||||
# %%
|
# %%
|
||||||
topic_model.visualize_hierarchy(custom_labels=True)
|
topic_model.visualize_hierarchy(custom_labels=True, color_threshold=0.98)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
hierarchical_topics = topic_model.hierarchical_topics(reviews)
|
||||||
|
tree = topic_model.get_topic_tree(hier_topics=hierarchical_topics)
|
||||||
|
print(tree)
|
||||||
|
|
||||||
# %% [markdown]
|
# %% [markdown]
|
||||||
# ### Intertopic Distance Map
|
# ### Intertopic Distance Map
|
||||||
@@ -512,3 +516,20 @@ topic_model.visualize_topics(use_ctfidf=True)
|
|||||||
|
|
||||||
# %%
|
# %%
|
||||||
topic_model.visualize_barchart(top_n_topics=12, custom_labels=True, n_words=10)
|
topic_model.visualize_barchart(top_n_topics=12, custom_labels=True, n_words=10)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from wordcloud import WordCloud
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def create_wordcloud(model, topic):
    """Draw a word cloud for a single topic of ``model``.

    Args:
        model: Object exposing ``get_topic(topic)`` that yields
            ``(word, value)`` pairs (used here with the fitted topic model).
        topic: Identifier of the topic to visualise.

    The figure is shown with matplotlib; nothing is returned.
    """
    # Pair list -> {word: weight} mapping expected by WordCloud.
    frequencies = dict(model.get_topic(topic))

    cloud = WordCloud(background_color="white", max_words=1000)
    cloud.generate_from_frequencies(frequencies)

    plt.imshow(cloud, interpolation="bilinear")
    plt.axis("off")  # a word cloud has no meaningful axes
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
# Show wordcloud
|
||||||
|
create_wordcloud(topic_model, topic=1)
|
||||||
|
|||||||
519
bertopic/nb_bertopic_temples.py
Normal file
519
bertopic/nb_bertopic_temples.py
Normal file
@@ -0,0 +1,519 @@
|
|||||||
|
# ---
|
||||||
|
# jupyter:
|
||||||
|
# jupytext:
|
||||||
|
# text_representation:
|
||||||
|
# extension: .py
|
||||||
|
# format_name: percent
|
||||||
|
# format_version: '1.3'
|
||||||
|
# jupytext_version: 1.18.0
|
||||||
|
# kernelspec:
|
||||||
|
# display_name: .venv (3.12.3)
|
||||||
|
# language: python
|
||||||
|
# name: python3
|
||||||
|
# ---
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# # Topic Detection: Bali Tourist Reviews
|
||||||
|
#
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Preparation
|
||||||
|
#
|
||||||
|
# ### Dependency Loading
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
|
||||||
|
import gensim.corpora as corpora
|
||||||
|
import nltk
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from bertopic.representation import KeyBERTInspired
|
||||||
|
from bertopic.vectorizers import ClassTfidfTransformer
|
||||||
|
from gensim.models.coherencemodel import CoherenceModel
|
||||||
|
from hdbscan import HDBSCAN
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from sklearn.feature_extraction.text import CountVectorizer
|
||||||
|
from sklearn.feature_extraction import text as skltext
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
from umap import UMAP
|
||||||
|
|
||||||
|
from bertopic import BERTopic
|
||||||
|
|
||||||
|
nltk.download("stopwords")
|
||||||
|
nltk.download("punkt")
|
||||||
|
nltk.download("wordnet")
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Hyperparameters and Settings
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
RECREATE_MODEL = True
|
||||||
|
RECREATE_REDUCED_MODEL = True
|
||||||
|
PROCESS_DATA = True
|
||||||
|
REDUCE_OUTLIERS = False
|
||||||
|
CALCULATE_TOKEN_DISTRIBUTIONS = False
|
||||||
|
|
||||||
|
# Data Sample Size, -1 for all data
|
||||||
|
DATA_SAMPLE_SIZE = -1
|
||||||
|
|
||||||
|
# Vectorization
|
||||||
|
MIN_DOCUMENT_FREQUENCY = 1
|
||||||
|
MAX_NGRAM = 3
|
||||||
|
|
||||||
|
# HDBSCAN Parameters
|
||||||
|
MIN_TOPIC_SIZE = 15
|
||||||
|
MIN_SAMPLES = 15
|
||||||
|
|
||||||
|
# UMAP Parameters
|
||||||
|
N_NEIGHBORS = 15
|
||||||
|
N_COMPONENTS = 2
|
||||||
|
MIN_DIST = 0.01
|
||||||
|
|
||||||
|
# Topic Modeling
|
||||||
|
TOP_N_WORDS = 10
|
||||||
|
MAX_TOPICS = None # or "auto" to pass to HDBSCAN, None to skip
|
||||||
|
|
||||||
|
TF_IDF_STOP_WORDS = ["bali", "place", "visit", "visited", "visiting"]
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Data Loading & Preprocessing
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Import data after general preprocessing
|
||||||
|
|
||||||
|
if DATA_SAMPLE_SIZE == -1:
|
||||||
|
reviews = pd.read_csv(
|
||||||
|
"../data/intermediate/culture_reviews.csv", sep=","
|
||||||
|
).Original.to_list()
|
||||||
|
else:
|
||||||
|
reviews = (
|
||||||
|
pd.read_csv("../data/intermediate/culture_reviews.csv", sep=",")
|
||||||
|
.sample(n=DATA_SAMPLE_SIZE)
|
||||||
|
.Original.to_list()
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Loaded {} reviews".format(len(reviews)))
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Literal escape artefacts found in the raw review text, mapped to their
# replacement. Keys are plain strings, not regular expressions.
rep = {
    r"\\n": " ",
    r"\n": " ",
    r'\\"': "",
    r'"': "",
}
pattern = re.compile("|".join(re.escape(literal) for literal in rep))

# Whitespace must stay a real regex. Escaping it together with the literals
# above would only ever match the three characters backslash, "s", "+", so
# real newlines, tabs and repeated spaces would never be normalised.
whitespace_pattern = re.compile(r"\s+")


def preprocess(text):
    """Normalise a single review string.

    The text is lower-cased, literal backslash-n escape sequences are
    replaced by a space, double quotes are removed, and every run of
    whitespace (spaces, tabs, real newlines) is collapsed into one space.
    Leading and trailing whitespace is removed last, so replacements made at
    the edges cannot leave stray spaces behind.

    Args:
        text: Raw review text.

    Returns:
        The cleaned, lower-cased text.
    """
    text = text.lower()
    text = pattern.sub(lambda m: rep[m.group(0)], text)
    return whitespace_pattern.sub(" ", text).strip()
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
print(
|
||||||
|
preprocess(
|
||||||
|
"Excellent. Definitely worth coming while in bali. Food and people were very nice.\n🌟 🤩 ⭐️ \nTrisna was our host"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
if PROCESS_DATA:
|
||||||
|
print("Processing reviews...")
|
||||||
|
reviews = [preprocess(review) for review in reviews]
|
||||||
|
|
||||||
|
with open("../data/intermediate/processed_texts_culture.pkl", "wb") as f:
|
||||||
|
pickle.dump(reviews, f)
|
||||||
|
else:
|
||||||
|
with open("../data/intermediate/processed_texts_culture.pkl", "rb") as f:
|
||||||
|
reviews = pickle.load(f)
|
||||||
|
|
||||||
|
print(reviews[:1])
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Pre-calculate Embeddings
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
|
embeddings = embedding_model.encode(reviews, show_progress_bar=True)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Model Creation
|
||||||
|
#
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Dimensionality Reduction (UMAP)
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
umap_model = UMAP(
|
||||||
|
n_neighbors=N_NEIGHBORS,
|
||||||
|
n_components=N_COMPONENTS,
|
||||||
|
min_dist=MIN_DIST,
|
||||||
|
metric="cosine",
|
||||||
|
low_memory=True,
|
||||||
|
random_state=42,
|
||||||
|
)
|
||||||
|
reduced_embeddings = umap_model.fit_transform(embeddings)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### BERTopic Model Creation
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
if RECREATE_MODEL:
|
||||||
|
stop_words = list(skltext.ENGLISH_STOP_WORDS.union(TF_IDF_STOP_WORDS))
|
||||||
|
|
||||||
|
ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)
|
||||||
|
vectorizer_model = CountVectorizer(
|
||||||
|
min_df=MIN_DOCUMENT_FREQUENCY,
|
||||||
|
ngram_range=(1, MAX_NGRAM),
|
||||||
|
stop_words=stop_words,
|
||||||
|
)
|
||||||
|
|
||||||
|
representation_model = KeyBERTInspired()
|
||||||
|
hdbscan_model = HDBSCAN(
|
||||||
|
min_cluster_size=MIN_TOPIC_SIZE,
|
||||||
|
min_samples=MIN_SAMPLES,
|
||||||
|
metric="euclidean",
|
||||||
|
cluster_selection_method="eom",
|
||||||
|
gen_min_span_tree=True,
|
||||||
|
prediction_data=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
topic_model = BERTopic(
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
ctfidf_model=ctfidf_model,
|
||||||
|
vectorizer_model=vectorizer_model,
|
||||||
|
umap_model=umap_model,
|
||||||
|
hdbscan_model=hdbscan_model,
|
||||||
|
representation_model=representation_model,
|
||||||
|
verbose=True,
|
||||||
|
calculate_probabilities=True,
|
||||||
|
language="english",
|
||||||
|
top_n_words=TOP_N_WORDS,
|
||||||
|
nr_topics=MAX_TOPICS,
|
||||||
|
)
|
||||||
|
|
||||||
|
topics, probs = topic_model.fit_transform(reviews, embeddings=embeddings)
|
||||||
|
|
||||||
|
topic_labels = topic_model.generate_topic_labels(
|
||||||
|
nr_words=3, topic_prefix=True, word_length=15, separator=" - "
|
||||||
|
)
|
||||||
|
topic_model.set_topic_labels(topic_labels)
|
||||||
|
# BERTopic.save(topic_model, "bertopic/model.bertopic")
|
||||||
|
else:
|
||||||
|
print("Nevermind, loading existing model")
|
||||||
|
# topic_model = BERTopic.load("bertopic/model.bertopic")
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Fine Tuning
|
||||||
|
#
|
||||||
|
# ### Topic Condensation
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
if RECREATE_REDUCED_MODEL:
|
||||||
|
done = False
|
||||||
|
iteration = 1
|
||||||
|
while not done:
|
||||||
|
print(f"Iteration {iteration}")
|
||||||
|
iteration += 1
|
||||||
|
similarity_matrix = cosine_similarity(
|
||||||
|
np.array(topic_model.topic_embeddings_)[1:, :]
|
||||||
|
)
|
||||||
|
nothing_to_merge = True
|
||||||
|
|
||||||
|
for i in range(similarity_matrix.shape[0]):
|
||||||
|
for j in range(i + 1, similarity_matrix.shape[1]):
|
||||||
|
try:
|
||||||
|
sim = similarity_matrix[i, j]
|
||||||
|
if sim > 0.9:
|
||||||
|
nothing_to_merge = False
|
||||||
|
t1, t2 = i, j
|
||||||
|
try:
|
||||||
|
t1_name = topic_model.get_topic_info(t1)["CustomName"][0]
|
||||||
|
t2_name = topic_model.get_topic_info(t2)["CustomName"][0]
|
||||||
|
print(
|
||||||
|
f"Merging topics {t1} ({t1_name}) and {t2} ({t2_name}) with similarity {sim:.2f}"
|
||||||
|
)
|
||||||
|
topic_model.merge_topics(reviews, topics_to_merge=[t1, t2])
|
||||||
|
|
||||||
|
topic_labels = topic_model.generate_topic_labels(
|
||||||
|
nr_words=3,
|
||||||
|
topic_prefix=True,
|
||||||
|
word_length=15,
|
||||||
|
separator=" - ",
|
||||||
|
)
|
||||||
|
topic_model.set_topic_labels(topic_labels)
|
||||||
|
similarity_matrix = cosine_similarity(
|
||||||
|
np.array(topic_model.topic_embeddings_)[1:, :]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to merge {t1} and {t2}: {e}")
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
if nothing_to_merge:
|
||||||
|
print("No more topics to merge.")
|
||||||
|
done = True
|
||||||
|
else:
|
||||||
|
print("Skipping topic reduction")
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Outlier Reduction
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
if REDUCE_OUTLIERS:
|
||||||
|
new_topics = topic_model.reduce_outliers(
|
||||||
|
reviews,
|
||||||
|
topic_model.topics_,
|
||||||
|
probabilities=topic_model.probabilities_,
|
||||||
|
threshold=0.05,
|
||||||
|
strategy="probabilities",
|
||||||
|
)
|
||||||
|
topic_model.update_topics(reviews, topics=new_topics)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ## Results
|
||||||
|
#
|
||||||
|
# ### Classification
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
CLASSIFICATION = False
|
||||||
|
if CLASSIFICATION:
|
||||||
|
topics_to_keep = {14, 8, 13, 18, 17, 4, 2, 30, 28}
|
||||||
|
INPUT_PATH = "../data/intermediate/preprocessed.tab" # TSV with a 'review' column
|
||||||
|
OUTPUT_CSV = "../data/intermediate/culture_reviews.csv"
|
||||||
|
|
||||||
|
# Topic model document info
|
||||||
|
df = topic_model.get_document_info(reviews)
|
||||||
|
df["Original"] = reviews
|
||||||
|
|
||||||
|
# --- filter by topics ---
|
||||||
|
filtered = df[df["Topic"].isin(topics_to_keep)].copy()
|
||||||
|
filtered["Original"] = filtered["Original"].str.strip()
|
||||||
|
|
||||||
|
# Save an audit CSV
|
||||||
|
filtered[["Original", "Topic"]].to_csv(OUTPUT_CSV, index=False, sep=",")
|
||||||
|
print(f"Filtered CSV file saved to {OUTPUT_CSV}")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Document-topic probability matrix as returned by fit_transform.
# NOTE(review): `probs` is the value captured at fit time; if topics were
# merged or reassigned afterwards, confirm whether the model's current
# probabilities should be used instead.
doc_topic_matrix = np.asarray(probs)

# column names: one label per probability column, so the labels always match
# the matrix width (counting distinct topic ids minus one breaks whenever no
# document was assigned to the outlier topic -1)
topicnames = ["Topic " + str(i) for i in range(doc_topic_matrix.shape[1])]

# index names
docnames = ["Review " + str(i) for i in range(len(reviews))]

# Make the pandas dataframe
df_document_topic = pd.DataFrame(
    np.round(doc_topic_matrix, 2), columns=topicnames, index=docnames
)

# Get dominant topic for each document (column index of the largest value)
dominant_topic = np.argmax(doc_topic_matrix, axis=1)
df_document_topic["dominant_topic"] = dominant_topic
|
||||||
|
|
||||||
|
|
||||||
|
# Styling
|
||||||
|
def color_stuff(val):
    """Return the CSS colour declaration for one table cell.

    Values above 0.1 are shown in green, values above 0.05 in orange and
    everything else in grey.
    """
    shade = "green" if val > 0.1 else ("orange" if val > 0.05 else "grey")
    return "color: {col}".format(col=shade)
|
||||||
|
|
||||||
|
|
||||||
|
def make_bold(val):
    """Return the CSS font-weight declaration for one table cell.

    Cells with a value above 0.1 are rendered bold (700); all others use the
    regular weight (400).
    """
    if val > 0.1:
        thickness = 700
    else:
        thickness = 400
    return "font-weight: {weight}".format(weight=thickness)
|
||||||
|
|
||||||
|
|
||||||
|
# Apply Style
|
||||||
|
df_document_topics = (
|
||||||
|
df_document_topic.head(15).style.applymap(color_stuff).applymap(make_bold)
|
||||||
|
)
|
||||||
|
df_document_topics
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Document Visualization
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
vis = topic_model.visualize_documents(
|
||||||
|
docs=reviews,
|
||||||
|
reduced_embeddings=reduced_embeddings,
|
||||||
|
custom_labels=True,
|
||||||
|
hide_annotations=True,
|
||||||
|
)
|
||||||
|
# vis.write_html("output/visualization.html")
|
||||||
|
vis
|
||||||
|
|
||||||
|
# %%
|
||||||
|
topic_model.visualize_document_datamap(reviews, reduced_embeddings=reduced_embeddings)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Similarity Matrix
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
topic_model.visualize_heatmap()
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Topic Info
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
topic_model.get_topic_info()
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Semantic Coherence
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
topic_words = []
|
||||||
|
for topic_id in topic_model.get_topic_info()["Topic"]:
|
||||||
|
# Skip outlier topic
|
||||||
|
if topic_id < 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
words = [word for word, _ in topic_model.get_topic(topic_id)]
|
||||||
|
topic_words.append(words)
|
||||||
|
|
||||||
|
# Compute mean pairwise cosine similarity for each topic
|
||||||
|
coherence_scores = []
|
||||||
|
for words in topic_words:
|
||||||
|
coherence_embeddings = embedding_model.encode(words)
|
||||||
|
sim_matrix = cosine_similarity(coherence_embeddings)
|
||||||
|
|
||||||
|
# Ignore self-similarity
|
||||||
|
np.fill_diagonal(sim_matrix, 0)
|
||||||
|
mean_sim = np.mean(sim_matrix[np.triu_indices(sim_matrix.shape[0], k=1)])
|
||||||
|
coherence_scores.append(mean_sim)
|
||||||
|
|
||||||
|
overall_coherence = np.mean(coherence_scores)
|
||||||
|
|
||||||
|
print(len(reviews), "reviews processed")
|
||||||
|
print(len(topic_model.get_topic_info()) - 1, "topics found")
|
||||||
|
print(f"BERT-based Topic Coherence: {overall_coherence:.4f}")
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Topic Coherence
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# https://github.com/MaartenGr/BERTopic/issues/90#issuecomment-820915389
|
||||||
|
|
||||||
|
# This will most likely crash your PC
|
||||||
|
this_will_crash_your_pc_are_you_sure = False
if this_will_crash_your_pc_are_you_sure:
    # Preprocess Documents
    # Group by the model's current assignments rather than the list returned
    # by fit_transform: the merge / outlier-reduction cells run on the model
    # after fitting, so the fit-time list may be stale.
    documents = pd.DataFrame(
        {
            "Document": reviews,
            "ID": range(len(reviews)),
            "Topic": topic_model.topics_,
        }
    )
    documents_per_topic = documents.groupby(["Topic"], as_index=False).agg(
        {"Document": " ".join}
    )
    cleaned_docs = topic_model._preprocess_text(documents_per_topic.Document.values)

    # Extract vectorizer and analyzer from BERTopic
    vectorizer = topic_model.vectorizer_model
    analyzer = vectorizer.build_analyzer()

    # Extract features for Topic Coherence evaluation
    tokens = [analyzer(doc) for doc in cleaned_docs]
    dictionary = corpora.Dictionary(tokens)
    corpus = [dictionary.doc2bow(token) for token in tokens]

    # Build the per-topic word lists from scratch. Appending to a list left
    # over from an earlier cell would feed every topic to gensim twice and
    # grow the list on each re-run of this cell.
    topic_words = [
        [word for word, _ in topic_model.get_topic(topic_id)]
        for topic_id in topic_model.get_topic_info()["Topic"]
        if topic_id >= 0  # skip outlier topic
    ]

    # %env TOKENIZERS_PARALLELISM=false

    for measurement in ["c_v", "u_mass", "c_uci", "c_npmi"]:
        coherence_model = CoherenceModel(
            topics=topic_words,
            texts=tokens,
            corpus=corpus,
            dictionary=dictionary,
            coherence=measurement,
        )
        coherence_score = coherence_model.get_coherence()
        print(f"Coherence ({measurement}): {coherence_score:.4f}")
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Term Search
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
search_term = "lempuyang"
|
||||||
|
|
||||||
|
similar_topics, similarities = topic_model.find_topics(search_term, top_n=10)
|
||||||
|
for i in range(len(similar_topics)):
|
||||||
|
print(
|
||||||
|
f"{str(similarities[i])[:5]} {topic_model.get_topic_info(similar_topics[i])['CustomName'][0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Source: https://maartengr.github.io/BERTopic/getting_started/visualization/visualize_documents.html#visualize-probabilities-or-distribution
|
||||||
|
# Calculate the topic distributions on a token-level
|
||||||
|
|
||||||
|
if CALCULATE_TOKEN_DISTRIBUTIONS:
|
||||||
|
topic_distr, topic_token_distr = topic_model.approximate_distribution(
|
||||||
|
reviews, calculate_tokens=True, use_embedding_model=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Visualize the token-level distributions
|
||||||
|
if CALCULATE_TOKEN_DISTRIBUTIONS:
|
||||||
|
DOC_INDEX = 1
|
||||||
|
df = topic_model.visualize_approximate_distribution(
|
||||||
|
reviews[DOC_INDEX], topic_token_distr[DOC_INDEX]
|
||||||
|
)
|
||||||
|
df
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Topic Hierarchy
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
topic_model.visualize_hierarchy(custom_labels=True)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
hierarchical_topics = topic_model.hierarchical_topics(reviews)
|
||||||
|
tree = topic_model.get_topic_tree(hier_topics=hierarchical_topics)
|
||||||
|
print(tree)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Intertopic Distance Map
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
topic_model.visualize_topics(use_ctfidf=True)
|
||||||
|
|
||||||
|
# %% [markdown]
|
||||||
|
# ### Topic Word Scores
|
||||||
|
#
|
||||||
|
|
||||||
|
# %%
|
||||||
|
topic_model.visualize_barchart(top_n_topics=12, custom_labels=True, n_words=10)
|
||||||
File diff suppressed because it is too large
Load Diff
290
bertopic/output/autotune_sorted.json
Normal file
290
bertopic/output/autotune_sorted.json
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.498,
|
||||||
|
"diversity": 1.0,
|
||||||
|
"combined_score": 0.6486
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.498,
|
||||||
|
"diversity": 1.0,
|
||||||
|
"combined_score": 0.6486
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4915,
|
||||||
|
"diversity": 0.9666,
|
||||||
|
"combined_score": 0.634
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4915,
|
||||||
|
"diversity": 0.9666,
|
||||||
|
"combined_score": 0.634
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4531,
|
||||||
|
"diversity": 0.975,
|
||||||
|
"combined_score": 0.6096
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4531,
|
||||||
|
"diversity": 0.975,
|
||||||
|
"combined_score": 0.6096
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4617,
|
||||||
|
"diversity": 0.95,
|
||||||
|
"combined_score": 0.6082
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4617,
|
||||||
|
"diversity": 0.95,
|
||||||
|
"combined_score": 0.6082
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4287,
|
||||||
|
"diversity": 1.0,
|
||||||
|
"combined_score": 0.6001
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4287,
|
||||||
|
"diversity": 1.0,
|
||||||
|
"combined_score": 0.6001
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.427,
|
||||||
|
"diversity": 1.0,
|
||||||
|
"combined_score": 0.5989
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.1,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 5,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.427,
|
||||||
|
"diversity": 1.0,
|
||||||
|
"combined_score": 0.5989
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4462,
|
||||||
|
"diversity": 0.925,
|
||||||
|
"combined_score": 0.5898
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 3,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4462,
|
||||||
|
"diversity": 0.925,
|
||||||
|
"combined_score": 0.5898
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 10,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4456,
|
||||||
|
"diversity": 0.925,
|
||||||
|
"combined_score": 0.5894
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": {
|
||||||
|
"min_dist": 0.01,
|
||||||
|
"min_document_frequency": 1,
|
||||||
|
"min_samples": 25,
|
||||||
|
"min_topic_size": 200,
|
||||||
|
"n_components": 2,
|
||||||
|
"n_gram_max": 2,
|
||||||
|
"n_neighbors": 15,
|
||||||
|
"nr_topics": "auto",
|
||||||
|
"top_n_words": 10
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"coherence": 0.4456,
|
||||||
|
"diversity": 0.925,
|
||||||
|
"combined_score": 0.5894
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
File diff suppressed because one or more lines are too long
@@ -131,3 +131,4 @@ spacy
|
|||||||
nbconvert
|
nbconvert
|
||||||
jupytext
|
jupytext
|
||||||
datamapplot
|
datamapplot
|
||||||
|
wordcloud
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -2,22 +2,10 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Generate 300–1000+ English interview questions targeted ONLY at culturally/spiritually
|
Generate trainer prompts
|
||||||
interested Bali tourists (Lead Users), covering 5 cognitive destination image dimensions:
|
|
||||||
- Natural Attractions
|
|
||||||
- Atmosphere
|
|
||||||
- Social Environment
|
|
||||||
- Infrastructure
|
|
||||||
- Value for Money
|
|
||||||
|
|
||||||
Key constraint:
|
|
||||||
- Every prompt must be meaningful for culture/spirituality-first travelers.
|
|
||||||
- Avoid party/shopping/hedonistic positioning.
|
|
||||||
- Include etiquette, authenticity, sacredness, commodification, meaning-making, reflection.
|
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
- JSONL: {"dimension": "...", "type": "...", "prompt": "...", "tags": [...]}
|
- JSONL: {"dimension": "...", "type": "...", "prompt": "...", "tags": [...]}
|
||||||
- or TXT: one prompt per line
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -26,6 +14,7 @@ import random
|
|||||||
import re
|
import re
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
# Cognitive Image Dimensions
|
||||||
DIMENSIONS = [
|
DIMENSIONS = [
|
||||||
"Natural Attractions",
|
"Natural Attractions",
|
||||||
"Atmosphere",
|
"Atmosphere",
|
||||||
@@ -37,7 +26,8 @@ DIMENSIONS = [
|
|||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Segment-specific building blocks
|
# Segment-specific building blocks
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Keep places generic (no need to hallucinate specific proper nouns)
|
#
|
||||||
|
# Intentionally generic, details should come from retrieved context
|
||||||
NATURE_FOR_MEANING = [
|
NATURE_FOR_MEANING = [
|
||||||
"rice terraces that feel lived-in rather than staged",
|
"rice terraces that feel lived-in rather than staged",
|
||||||
"waterfalls approached with a quiet, respectful mood",
|
"waterfalls approached with a quiet, respectful mood",
|
||||||
@@ -145,7 +135,7 @@ CONSTRAINTS = [
|
|||||||
[
|
[
|
||||||
"it's rainy season and flexibility is part of respectful travel",
|
"it's rainy season and flexibility is part of respectful travel",
|
||||||
"it's very hot and you need a pace that still feels mindful",
|
"it's very hot and you need a pace that still feels mindful",
|
||||||
"visibility is low and your sunrise plan may fail—how do you adapt meaningfully?",
|
"visibility is low and your sunrise plan may fail-how do you adapt meaningfully?",
|
||||||
"roads feel unsafe, so you prioritize fewer moves and deeper presence",
|
"roads feel unsafe, so you prioritize fewer moves and deeper presence",
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
@@ -263,7 +253,7 @@ def tmpl_single_dimension(
|
|||||||
) -> str:
|
) -> str:
|
||||||
return (
|
return (
|
||||||
f"{style} your experience with {place_hint} in Bali during {context}. "
|
f"{style} your experience with {place_hint} in Bali during {context}. "
|
||||||
f"From a {d} perspective, what stands out about {theme}—and why does it matter to you as a culture/spirit-oriented traveler?"
|
f"From a {d} perspective, what stands out about {theme}-and why does it matter to you as a culture/spirit-oriented traveler?"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -295,7 +285,7 @@ def tmpl_marketer_advice(d: str, theme: str, constraint: str, dont_claim: str) -
|
|||||||
return (
|
return (
|
||||||
f"If you had to advise a tourism marketer for culturally/spiritually interested travelers: under the constraint '{constraint}', "
|
f"If you had to advise a tourism marketer for culturally/spiritually interested travelers: under the constraint '{constraint}', "
|
||||||
f"what should they understand about {d} (especially {theme})? "
|
f"what should they understand about {d} (especially {theme})? "
|
||||||
f"Also: what is one thing they should NOT claim in messaging because it would feel misleading or disrespectful—e.g., {dont_claim}?"
|
f"Also: what is one thing they should NOT claim in messaging because it would feel misleading or disrespectful-e.g., {dont_claim}?"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -342,7 +332,7 @@ def generate_prompts(
|
|||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
rng = random.Random(seed)
|
rng = random.Random(seed)
|
||||||
|
|
||||||
# Mix of question archetypes, all segment-targeted
|
# Different weights for question archetypes
|
||||||
types = [
|
types = [
|
||||||
("single", 0.24),
|
("single", 0.24),
|
||||||
("laddering", 0.18),
|
("laddering", 0.18),
|
||||||
@@ -424,7 +414,7 @@ def generate_prompts(
|
|||||||
"dimension": d,
|
"dimension": d,
|
||||||
"type": "single",
|
"type": "single",
|
||||||
"prompt": q,
|
"prompt": q,
|
||||||
"tags": [d, theme, context, "segment:culture-spirit"],
|
"tags": [d, theme, context],
|
||||||
}
|
}
|
||||||
ok = add_prompt(obj)
|
ok = add_prompt(obj)
|
||||||
|
|
||||||
@@ -435,7 +425,7 @@ def generate_prompts(
|
|||||||
"dimension": d,
|
"dimension": d,
|
||||||
"type": "laddering",
|
"type": "laddering",
|
||||||
"prompt": q,
|
"prompt": q,
|
||||||
"tags": [d, theme, context, "laddering", "segment:culture-spirit"],
|
"tags": [d, theme, context, "laddering"],
|
||||||
}
|
}
|
||||||
ok = add_prompt(obj)
|
ok = add_prompt(obj)
|
||||||
|
|
||||||
@@ -447,7 +437,7 @@ def generate_prompts(
|
|||||||
"dimension": d,
|
"dimension": d,
|
||||||
"type": "contrast",
|
"type": "contrast",
|
||||||
"prompt": q,
|
"prompt": q,
|
||||||
"tags": [d, "contrast", context, "segment:culture-spirit"],
|
"tags": [d, "contrast", context],
|
||||||
}
|
}
|
||||||
ok = add_prompt(obj)
|
ok = add_prompt(obj)
|
||||||
|
|
||||||
@@ -459,7 +449,7 @@ def generate_prompts(
|
|||||||
"dimension": f"{d} + {d2}",
|
"dimension": f"{d} + {d2}",
|
||||||
"type": "tradeoff",
|
"type": "tradeoff",
|
||||||
"prompt": q,
|
"prompt": q,
|
||||||
"tags": [d, d2, "tradeoff", c_key, "segment:culture-spirit"],
|
"tags": [d, d2, "tradeoff", c_key],
|
||||||
}
|
}
|
||||||
ok = add_prompt(obj)
|
ok = add_prompt(obj)
|
||||||
|
|
||||||
@@ -470,7 +460,7 @@ def generate_prompts(
|
|||||||
"dimension": d,
|
"dimension": d,
|
||||||
"type": "marketer_advice",
|
"type": "marketer_advice",
|
||||||
"prompt": q,
|
"prompt": q,
|
||||||
"tags": [d, theme, "marketer", c_key, "segment:culture-spirit"],
|
"tags": [d, theme, "marketer", c_key],
|
||||||
}
|
}
|
||||||
ok = add_prompt(obj)
|
ok = add_prompt(obj)
|
||||||
|
|
||||||
@@ -481,7 +471,7 @@ def generate_prompts(
|
|||||||
"dimension": d,
|
"dimension": d,
|
||||||
"type": "etiquette",
|
"type": "etiquette",
|
||||||
"prompt": q,
|
"prompt": q,
|
||||||
"tags": [d, "etiquette", topic, context, "segment:culture-spirit"],
|
"tags": [d, "etiquette", topic, context],
|
||||||
}
|
}
|
||||||
ok = add_prompt(obj)
|
ok = add_prompt(obj)
|
||||||
|
|
||||||
@@ -493,7 +483,7 @@ def generate_prompts(
|
|||||||
"dimension": d,
|
"dimension": d,
|
||||||
"type": "route_design",
|
"type": "route_design",
|
||||||
"prompt": q,
|
"prompt": q,
|
||||||
"tags": [d, "route", c_key, "segment:culture-spirit"],
|
"tags": [d, "route", c_key],
|
||||||
}
|
}
|
||||||
ok = add_prompt(obj)
|
ok = add_prompt(obj)
|
||||||
|
|
||||||
@@ -524,7 +514,7 @@ def main():
|
|||||||
"--n",
|
"--n",
|
||||||
type=int,
|
type=int,
|
||||||
default=600,
|
default=600,
|
||||||
help="Number of prompts to generate (300–1000 recommended).",
|
help="Number of prompts to generate.",
|
||||||
)
|
)
|
||||||
ap.add_argument("--seed", type=int, default=42)
|
ap.add_argument("--seed", type=int, default=42)
|
||||||
ap.add_argument("--out", default="culture_spirit_interview_prompts.jsonl")
|
ap.add_argument("--out", default="culture_spirit_interview_prompts.jsonl")
|
||||||
@@ -1,187 +1,455 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
"""
|
||||||
|
RAFT dataset builder (FAISS-based retrieval) -> Together.ai chat JSONL.
|
||||||
|
|
||||||
|
Inputs (from your indexing script):
|
||||||
|
- <index_dir>/faiss.index
|
||||||
|
- <index_dir>/docstore.jsonl
|
||||||
|
|
||||||
|
Process:
|
||||||
|
- Build a set of interview-style prompts (EN)
|
||||||
|
- For each prompt:
|
||||||
|
- Retrieve top-k chunks via FAISS cosine/IP
|
||||||
|
- Call DeepSeek Chat Completions API to generate a vivid, human-like Lead User answer
|
||||||
|
- Write training examples as JSONL in chat format (messages)
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- raft_train.jsonl
|
||||||
|
- raft_val.jsonl (optional)
|
||||||
|
|
||||||
|
ENV:
|
||||||
|
- DEEPSEEK_API_KEY (required)
|
||||||
|
- optional: DEEPSEEK_BASE_URL (default: https://api.deepseek.com)
|
||||||
|
- optional: DEEPSEEK_MODEL (default: deepseek-chat)
|
||||||
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import faiss
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import requests
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
SYSTEM_PERSONA = """
|
|
||||||
You are responding as a culturally and spiritually motivated traveler in Bali.
|
|
||||||
|
|
||||||
Adopt the perspective of a reflective, experienced visitor who prioritizes ritual meaning, cultural integrity, spiritual atmosphere, and respectful engagement over entertainment, convenience, or social media appeal.
|
|
||||||
|
|
||||||
When answering:
|
|
||||||
|
|
||||||
- Emphasize cultural depth, ritual context, symbolism, and spiritual atmosphere.
|
|
||||||
- Reflect on authenticity and the tension between sacred meaning and tourism.
|
|
||||||
- Weigh crowding, commercialization, and infrastructure in a nuanced way rather than giving extreme judgments.
|
|
||||||
- Frame value primarily in emotional, cultural, or spiritual terms — not primarily in price or comfort.
|
|
||||||
- Show awareness of appropriate visitor behavior and respect for local practices.
|
|
||||||
- Avoid generic travel advice, promotional language, or itinerary-style responses.
|
|
||||||
- Write in a thoughtful, first-person perspective.
|
|
||||||
- Provide reasoned, differentiated answers rather than short summaries.
|
|
||||||
- Do not list bullet points unless explicitly asked.
|
|
||||||
- Keep answers focused on the question.
|
|
||||||
|
|
||||||
Maintain consistency with this identity across all responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
TRAINER_PROMPT = "Create ONE realistic question from the perspective of a touristic marketer they might ask a culturally and spiritually interested traveler in Bali considered to be a lead user that can be answered using ONLY the CONTEXT.\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def load_docstore(path):
|
# -----------------------------
|
||||||
docs = []
|
# DeepSeek client (OpenAI-compatible)
|
||||||
|
# -----------------------------
|
||||||
|
@dataclass
|
||||||
|
class DeepSeekConfig:
|
||||||
|
api_key: str
|
||||||
|
base_url: str = "https://api.deepseek.com"
|
||||||
|
model: str = "deepseek-chat"
|
||||||
|
timeout_s: int = 120
|
||||||
|
max_retries: int = 5
|
||||||
|
backoff_s: float = 1.6
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekClient:
|
||||||
|
def __init__(self, cfg: DeepSeekConfig):
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
def chat(
|
||||||
|
self, messages: List[Dict], temperature: float = 0.85, max_tokens: int = 750
|
||||||
|
) -> str:
|
||||||
|
url = f"{self.cfg.base_url}/chat/completions"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.cfg.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": self.cfg.model,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
last_err = None
|
||||||
|
for attempt in range(self.cfg.max_retries):
|
||||||
|
try:
|
||||||
|
r = requests.post(
|
||||||
|
url, headers=headers, json=payload, timeout=self.cfg.timeout_s
|
||||||
|
)
|
||||||
|
if r.status_code == 429:
|
||||||
|
time.sleep(self.cfg.backoff_s ** (attempt + 1))
|
||||||
|
continue
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
return data["choices"][0]["message"]["content"].strip()
|
||||||
|
except Exception as e:
|
||||||
|
last_err = e
|
||||||
|
time.sleep(self.cfg.backoff_s ** (attempt + 1))
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"DeepSeek API call failed after retries. Last error: {last_err}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Helpers
|
||||||
|
# -----------------------------
|
||||||
|
def simple_clean(text: str) -> str:
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return ""
|
||||||
|
text = text.replace("\u00a0", " ")
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def read_docstore(docstore_path: str) -> Dict[int, Dict]:
|
||||||
|
"""
|
||||||
|
Returns dict: faiss_id -> {"doc_id": int, "text": str, ...}
|
||||||
|
"""
|
||||||
|
mapping: Dict[int, Dict] = {}
|
||||||
|
with open(docstore_path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
obj = json.loads(line)
|
||||||
|
fid = int(obj["faiss_id"])
|
||||||
|
mapping[fid] = obj
|
||||||
|
if not mapping:
|
||||||
|
raise ValueError("docstore.jsonl is empty or unreadable.")
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def load_prompts_from_jsonl(path: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Loads prompts from a JSONL file.
|
||||||
|
Expected key: 'prompt' (preferred). Also accepts 'question' or 'text'.
|
||||||
|
Ignores empty/short lines.
|
||||||
|
"""
|
||||||
|
prompts: List[str] = []
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
docs.append(json.loads(line))
|
line = line.strip()
|
||||||
return docs
|
if not line:
|
||||||
|
continue
|
||||||
|
obj = json.loads(line)
|
||||||
|
p = obj.get("prompt") or obj.get("question") or obj.get("text")
|
||||||
|
p = simple_clean(p) if p else ""
|
||||||
|
if len(p) >= 20:
|
||||||
|
prompts.append(p)
|
||||||
|
if not prompts:
|
||||||
|
raise ValueError(f"No prompts found in JSONL: {path}")
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
def retrieve(index, embedder, query, top_k=6):
|
def load_prompts_from_txt(path: str) -> List[str]:
|
||||||
q = embedder.encode([query], normalize_embeddings=True).astype(np.float32)
|
"""
|
||||||
scores, ids = index.search(q, top_k)
|
Loads prompts from a TXT file (one prompt per line).
|
||||||
return ids[0].tolist(), scores[0].tolist()
|
"""
|
||||||
|
prompts: List[str] = []
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
p = simple_clean(line)
|
||||||
|
if len(p) >= 20:
|
||||||
|
prompts.append(p)
|
||||||
|
if not prompts:
|
||||||
|
raise ValueError(f"No prompts found in TXT: {path}")
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
def ensure_dir_for_file(path: str):
|
||||||
def generate_text(model, tok, messages, max_new_tokens=220, temperature=0.7):
|
d = os.path.dirname(path)
|
||||||
# Using tokenizer chat template where available
|
if d:
|
||||||
enc = tok.apply_chat_template(
|
os.makedirs(d, exist_ok=True)
|
||||||
messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
|
|
||||||
|
|
||||||
|
def write_jsonl(path: str, rows: List[Dict]) -> None:
|
||||||
|
ensure_dir_for_file(path)
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
for r in rows:
|
||||||
|
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Persona + prompt templates (EN)
|
||||||
|
# -----------------------------
|
||||||
|
IMAGE_DIMS = [
|
||||||
|
"Natural Attractions",
|
||||||
|
"Atmosphere",
|
||||||
|
"Social Environment",
|
||||||
|
"Infrastructure",
|
||||||
|
"Value for Money",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_PROMPTS_EN = [
|
||||||
|
# Natural Attractions
|
||||||
|
"In a lead user interview: what natural places in Bali felt genuinely memorable to you (rice terraces, volcanoes, waterfalls, coast), and why? Describe it like a lived experience.",
|
||||||
|
"Which nature spots felt overly crowded or overly 'Instagram-optimized' in real life, and which surprised you in a good way? Explain with concrete moments.",
|
||||||
|
# Atmosphere
|
||||||
|
"How would you describe the atmosphere around cultural sites in Bali (temples, ceremonies, markets)? What signals authenticity vs. commercialization to you?",
|
||||||
|
"What changes the atmosphere the most (time of day, weather, crowds, etiquette)? Share specific examples you would tell a marketer.",
|
||||||
|
# Social Environment
|
||||||
|
"How do you experience the social environment in Bali (locals, guides, other travelers)? What feels respectful and what feels performative or touristy?",
|
||||||
|
"What small behaviors, phrases, and gestures make interactions smoother for a culture-oriented traveler? Give examples.",
|
||||||
|
# Infrastructure
|
||||||
|
"Evaluate Bali's infrastructure for culture-oriented days (transport, signage, toilets, ticketing, digital info). What works, what annoys you, and how do you adapt?",
|
||||||
|
"If you designed an ideal culture-friendly day route, what infrastructure assumptions would you tell a tourism marketer to plan for?",
|
||||||
|
# Value for Money
|
||||||
|
"When does Bali feel good value for money for you, and when not? Discuss entrance fees, guides, food, tours, and hidden costs.",
|
||||||
|
"How do you personally distinguish 'good value' from a tourist trap? List criteria and illustrate with examples.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_prompt() -> str:
|
||||||
|
return (
|
||||||
|
"ROLE / PERSONA\n"
|
||||||
|
"You are an experienced, culture-oriented Bali traveler (Lead User). You speak in natural, vivid English, "
|
||||||
|
"as a real person in an interview. You share nuanced judgments, trade-offs, and concrete scenes.\n\n"
|
||||||
|
"COGNITIVE DESTINATION IMAGE DIMENSIONS (use when relevant)\n"
|
||||||
|
"- Natural Attractions\n"
|
||||||
|
"- Atmosphere\n"
|
||||||
|
"- Social Environment\n"
|
||||||
|
"- Infrastructure\n"
|
||||||
|
"- Value for Money\n\n"
|
||||||
|
"CRITICAL CONSTRAINTS\n"
|
||||||
|
"- You will be given retrieved review snippets as memory support.\n"
|
||||||
|
"- Do NOT quote them verbatim and do NOT cite them as 'the review says'.\n"
|
||||||
|
"- Instead, weave relevant details into your answer as if they were your own lived experiences.\n"
|
||||||
|
"- If a detail is not supported, stay general or hedge appropriately.\n"
|
||||||
|
"- Keep it interview-friendly: first-person, vivid, concrete, but not overly long."
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(enc, torch.Tensor):
|
|
||||||
input_ids = enc.to(model.device)
|
|
||||||
attention_mask = torch.ones_like(input_ids, device=model.device)
|
|
||||||
else:
|
|
||||||
input_ids = enc["input_ids"].to(model.device)
|
|
||||||
attention_mask = enc.get("attention_mask")
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
attention_mask = attention_mask.to(model.device)
|
|
||||||
|
|
||||||
out = model.generate(
|
def build_user_message(question: str, retrieved_chunks: List[str]) -> str:
|
||||||
input_ids=input_ids,
|
retrieved_chunks = [simple_clean(x) for x in retrieved_chunks if simple_clean(x)]
|
||||||
attention_mask=attention_mask,
|
bullets = "\n".join([f"- {c}" for c in retrieved_chunks])
|
||||||
max_new_tokens=max_new_tokens,
|
return (
|
||||||
do_sample=True,
|
f"INTERVIEW QUESTION:\n{question}\n\n"
|
||||||
temperature=temperature,
|
"RETRIEVED CONTEXT (review snippets; do NOT quote, only use as memory support):\n"
|
||||||
top_p=0.9,
|
f"{bullets}\n\n"
|
||||||
eos_token_id=tok.eos_token_id,
|
"Answer as a real Lead User in a tourism interview. Speak in first person, vivid and concrete, "
|
||||||
pad_token_id=tok.pad_token_id,
|
"and naturally touch relevant image dimensions."
|
||||||
)
|
)
|
||||||
return tok.decode(out[0][input_ids.shape[1] :], skip_special_tokens=True).strip()
|
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# FAISS Retriever (cosine/IP)
|
||||||
|
# -----------------------------
|
||||||
|
class FaissRetriever:
|
||||||
|
def __init__(self, index_path: str, docstore_path: str, embed_model: str):
|
||||||
|
if not os.path.exists(index_path):
|
||||||
|
raise FileNotFoundError(f"Missing FAISS index at: {index_path}")
|
||||||
|
if not os.path.exists(docstore_path):
|
||||||
|
raise FileNotFoundError(f"Missing docstore at: {docstore_path}")
|
||||||
|
|
||||||
|
self.index = faiss.read_index(index_path)
|
||||||
|
self.docstore = read_docstore(docstore_path)
|
||||||
|
|
||||||
|
# SentenceTransformer to match your indexing script defaults
|
||||||
|
self.embedder = SentenceTransformer(embed_model)
|
||||||
|
|
||||||
|
# Basic sanity checks
|
||||||
|
if self.index.ntotal != len(self.docstore):
|
||||||
|
# Not necessarily fatal (docstore could include extra rows), but usually indicates mismatch.
|
||||||
|
# We'll allow it but warn.
|
||||||
|
print(
|
||||||
|
f"Warning: index.ntotal={self.index.ntotal} but docstore rows={len(self.docstore)}. "
|
||||||
|
"Ensure they were generated together."
|
||||||
|
)
|
||||||
|
|
||||||
|
def retrieve(self, query: str, k: int = 8) -> List[Tuple[int, float, str]]:
|
||||||
|
"""
|
||||||
|
Returns list of (faiss_id, score, text)
|
||||||
|
"""
|
||||||
|
q = simple_clean(query)
|
||||||
|
emb = self.embedder.encode([q], normalize_embeddings=True)
|
||||||
|
emb = np.asarray(emb, dtype=np.float32)
|
||||||
|
|
||||||
|
scores, ids = self.index.search(emb, k)
|
||||||
|
ids = ids[0].tolist()
|
||||||
|
scores = scores[0].tolist()
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for fid, sc in zip(ids, scores):
|
||||||
|
if fid == -1:
|
||||||
|
continue
|
||||||
|
doc = self.docstore.get(int(fid))
|
||||||
|
if not doc:
|
||||||
|
continue
|
||||||
|
out.append((int(fid), float(sc), doc.get("text", "")))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Dataset generation
|
||||||
|
# -----------------------------
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser()
|
||||||
ap.add_argument("--out_dir", default="out")
|
ap.add_argument(
|
||||||
|
"--index_dir",
|
||||||
|
default="out",
|
||||||
|
help="Directory containing faiss.index and docstore.jsonl",
|
||||||
|
)
|
||||||
|
ap.add_argument("--out_train", default="./out/raft_train.jsonl")
|
||||||
|
ap.add_argument("--out_val", default="./out/raft_val.jsonl")
|
||||||
|
ap.add_argument("--make_val", action="store_true")
|
||||||
|
ap.add_argument("--val_ratio", type=float, default=0.05)
|
||||||
|
ap.add_argument("--k", type=int, default=8)
|
||||||
|
ap.add_argument("--seed", type=int, default=42)
|
||||||
|
|
||||||
|
# Embeddings (must match indexing script for best results)
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
|
"--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
|
||||||
)
|
)
|
||||||
ap.add_argument("--teacher_model", default="mistralai/Mistral-7B-Instruct-v0.2")
|
|
||||||
ap.add_argument("--n_examples", type=int, default=5000)
|
|
||||||
ap.add_argument("--top_k", type=int, default=6)
|
|
||||||
ap.add_argument("--n_distractors", type=int, default=3)
|
|
||||||
ap.add_argument("--seed", type=int, default=42)
|
|
||||||
args = ap.parse_args()
|
|
||||||
|
|
||||||
random.seed(args.seed)
|
# External prompt sources
|
||||||
|
ap.add_argument(
|
||||||
faiss_path = os.path.join(args.out_dir, "faiss.index")
|
"--prompts_jsonl",
|
||||||
docstore_path = os.path.join(args.out_dir, "docstore.jsonl")
|
default=None,
|
||||||
|
help="JSONL file with prompts (key: prompt/question/text).",
|
||||||
index = faiss.read_index(faiss_path)
|
)
|
||||||
docstore = load_docstore(docstore_path)
|
ap.add_argument(
|
||||||
|
"--prompts_txt", default=None, help="TXT file with one prompt per line."
|
||||||
embedder = SentenceTransformer(args.embedding_model)
|
)
|
||||||
|
ap.add_argument(
|
||||||
# Teacher model to synthesize questions & answers from review chunks
|
"--shuffle_prompts",
|
||||||
tok = AutoTokenizer.from_pretrained(args.teacher_model, use_fast=True)
|
action="store_true",
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
help="Shuffle loaded prompts before generation.",
|
||||||
args.teacher_model, torch_dtype=torch.float16, device_map="auto"
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--limit_prompts",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="0 = no limit; else cap number of prompts used.",
|
||||||
)
|
)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
out_path = os.path.join(args.out_dir, "raft_train.jsonl")
|
# DeepSeek generation config
|
||||||
with open(out_path, "w", encoding="utf-8") as f:
|
ap.add_argument(
|
||||||
for _ in tqdm(range(args.n_examples), desc="Generating RAFT examples"):
|
"--deepseek_base_url",
|
||||||
# pick a "gold" chunk
|
default=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
|
||||||
gold = random.choice(docstore)
|
)
|
||||||
gold_text = gold["text"]
|
ap.add_argument(
|
||||||
|
"--deepseek_model", default=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
|
||||||
|
)
|
||||||
|
ap.add_argument("--temperature", type=float, default=0.85)
|
||||||
|
ap.add_argument("--max_tokens", type=int, default=750)
|
||||||
|
ap.add_argument(
|
||||||
|
"--max_examples",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="0 = all prompts; else limit number of examples",
|
||||||
|
)
|
||||||
|
|
||||||
# 1) generate a question answerable from gold_text
|
# pacing
|
||||||
q_prompt = [
|
ap.add_argument("--sleep_s", type=float, default=0.2)
|
||||||
{"role": "system", "content": SYSTEM_PERSONA},
|
|
||||||
{
|
args = ap.parse_args()
|
||||||
"role": "user",
|
random.seed(args.seed)
|
||||||
"content": TRAINER_PROMPT + f"CONTEXT:\n{gold_text}\n\n"
|
np.random.seed(args.seed)
|
||||||
"Return only the question.",
|
|
||||||
},
|
api_key = os.environ.get("DEEPSEEK_API_KEY", "").strip()
|
||||||
]
|
if not api_key:
|
||||||
question = generate_text(
|
raise SystemExit("Missing DEEPSEEK_API_KEY env var.")
|
||||||
model, tok, q_prompt, max_new_tokens=60, temperature=0.8
|
|
||||||
|
index_path = os.path.join(args.index_dir, "faiss.index")
|
||||||
|
docstore_path = os.path.join(args.index_dir, "docstore.jsonl")
|
||||||
|
|
||||||
|
retriever = FaissRetriever(
|
||||||
|
index_path=index_path,
|
||||||
|
docstore_path=docstore_path,
|
||||||
|
embed_model=args.embedding_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = DeepSeekClient(
|
||||||
|
DeepSeekConfig(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=args.deepseek_base_url,
|
||||||
|
model=args.deepseek_model,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt = build_system_prompt()
|
||||||
|
|
||||||
|
# Load prompts (priority: JSONL -> TXT -> defaults)
|
||||||
|
if args.prompts_jsonl and args.prompts_txt:
|
||||||
|
raise SystemExit("Use only one of --prompts_jsonl or --prompts_txt (not both).")
|
||||||
|
|
||||||
|
if args.prompts_jsonl:
|
||||||
|
prompts = load_prompts_from_jsonl(args.prompts_jsonl)
|
||||||
|
elif args.prompts_txt:
|
||||||
|
prompts = load_prompts_from_txt(args.prompts_txt)
|
||||||
|
else:
|
||||||
|
prompts = list(DEFAULT_PROMPTS_EN)
|
||||||
|
|
||||||
|
if args.shuffle_prompts:
|
||||||
|
random.shuffle(prompts)
|
||||||
|
|
||||||
|
if args.limit_prompts and args.limit_prompts > 0:
|
||||||
|
prompts = prompts[: args.limit_prompts]
|
||||||
|
|
||||||
|
# Backwards-compat: args.max_examples can still cap prompts
|
||||||
|
if args.max_examples and args.max_examples > 0:
|
||||||
|
prompts = prompts[: args.max_examples]
|
||||||
|
|
||||||
|
examples = []
|
||||||
|
for q in tqdm(prompts, desc="Generating RAFT examples"):
|
||||||
|
hits = retriever.retrieve(q, k=args.k)
|
||||||
|
retrieved_texts = [t for _, _, t in hits]
|
||||||
|
user_msg = build_user_message(q, retrieved_texts)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_msg},
|
||||||
|
]
|
||||||
|
|
||||||
|
answer = client.chat(
|
||||||
|
messages=messages,
|
||||||
|
temperature=args.temperature,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
ex = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_msg},
|
||||||
|
{"role": "assistant", "content": answer},
|
||||||
|
],
|
||||||
|
"meta": {
|
||||||
|
"retrieval_k": args.k,
|
||||||
|
"index_dir": os.path.abspath(args.index_dir),
|
||||||
|
"embedding_model": args.embedding_model,
|
||||||
|
"image_dimensions": IMAGE_DIMS,
|
||||||
|
"faiss_ids": [fid for fid, _, _ in hits],
|
||||||
|
"faiss_scores": [sc for _, sc, _ in hits],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
examples.append(ex)
|
||||||
|
|
||||||
|
if args.max_examples and len(examples) >= args.max_examples:
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(max(0.0, args.sleep_s))
|
||||||
|
|
||||||
|
random.shuffle(examples)
|
||||||
|
|
||||||
|
if args.make_val and len(examples) >= 20:
|
||||||
|
val_n = max(1, int(len(examples) * args.val_ratio))
|
||||||
|
val = examples[:val_n]
|
||||||
|
train = examples[val_n:]
|
||||||
|
write_jsonl(args.out_train, train)
|
||||||
|
write_jsonl(args.out_val, val)
|
||||||
|
print(f"Wrote train: {args.out_train} ({len(train)} examples)")
|
||||||
|
print(f"Wrote val: {args.out_val} ({len(val)} examples)")
|
||||||
|
else:
|
||||||
|
write_jsonl(args.out_train, examples)
|
||||||
|
print(f"Wrote: {args.out_train} ({len(examples)} examples)")
|
||||||
|
if args.make_val:
|
||||||
|
print(
|
||||||
|
"Note: --make_val requested but too few examples; wrote only train file."
|
||||||
)
|
)
|
||||||
question = question.split("\n")[0].strip()
|
|
||||||
|
|
||||||
# 2) retrieve top-k for that question
|
|
||||||
ids, _ = retrieve(index, embedder, question, top_k=args.top_k)
|
|
||||||
retrieved = [docstore[i] for i in ids]
|
|
||||||
|
|
||||||
# 3) add distractors (random docs not in retrieved)
|
|
||||||
retrieved_ids = set(ids)
|
|
||||||
distractors = []
|
|
||||||
attempts = 0
|
|
||||||
while len(distractors) < args.n_distractors and attempts < 50:
|
|
||||||
cand_idx = random.randrange(len(docstore))
|
|
||||||
attempts += 1
|
|
||||||
if cand_idx in retrieved_ids:
|
|
||||||
continue
|
|
||||||
distractors.append(docstore[cand_idx])
|
|
||||||
|
|
||||||
# Mix: retrieved + distractors
|
|
||||||
context_docs = retrieved + distractors
|
|
||||||
random.shuffle(context_docs)
|
|
||||||
|
|
||||||
# 4) generate grounded answer WITH short quotes
|
|
||||||
context_blob = ""
|
|
||||||
for j, d in enumerate(context_docs):
|
|
||||||
context_blob += f"[DOC {j}] {d['text']}\n\n"
|
|
||||||
|
|
||||||
a_prompt = [
|
|
||||||
{"role": "system", "content": SYSTEM_PERSONA},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Answer the question using ONLY the CONTEXT.\n"
|
|
||||||
"Rules:\n"
|
|
||||||
"- Include 1–2 short direct quotes from CONTEXT as evidence.\n"
|
|
||||||
"- If the answer isn't supported, say you can't tell from the context.\n\n"
|
|
||||||
f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
answer = generate_text(
|
|
||||||
model, tok, a_prompt, max_new_tokens=260, temperature=0.6
|
|
||||||
)
|
|
||||||
|
|
||||||
# Final training example (conversational dataset format for TRL)
|
|
||||||
train_ex = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": SYSTEM_PERSONA},
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"QUESTION: {question}\n\nCONTEXT:\n{context_blob}",
|
|
||||||
},
|
|
||||||
{"role": "assistant", "content": answer},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
f.write(json.dumps(train_ex, ensure_ascii=False) + "\n")
|
|
||||||
|
|
||||||
print(f"Wrote {out_path}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,456 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
"""
|
|
||||||
RAFT dataset builder (FAISS-based retrieval) -> Together.ai chat JSONL.
|
|
||||||
|
|
||||||
Inputs (from your indexing script):
|
|
||||||
- <index_dir>/faiss.index
|
|
||||||
- <index_dir>/docstore.jsonl
|
|
||||||
|
|
||||||
Process:
|
|
||||||
- Build a set of interview-style prompts (EN)
|
|
||||||
- For each prompt:
|
|
||||||
- Retrieve top-k chunks via FAISS cosine/IP
|
|
||||||
- Call DeepSeek Chat Completions API to generate a vivid, human-like Lead User answer
|
|
||||||
- Write training examples as JSONL in chat format (messages)
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
- raft_train.jsonl
|
|
||||||
- raft_val.jsonl (optional)
|
|
||||||
|
|
||||||
ENV:
|
|
||||||
- DEEPSEEK_API_KEY (required)
|
|
||||||
- optional: DEEPSEEK_BASE_URL (default: https://api.deepseek.com)
|
|
||||||
- optional: DEEPSEEK_MODEL (default: deepseek-chat)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import faiss
|
|
||||||
import numpy as np
|
|
||||||
import requests
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# DeepSeek client (OpenAI-compatible)
|
|
||||||
# -----------------------------
|
|
||||||
@dataclass
class DeepSeekConfig:
    """Connection and retry settings read by ``DeepSeekClient.chat``."""

    # Sent as the Bearer token in the Authorization header.
    api_key: str
    # "/chat/completions" is appended to this to form the request URL.
    base_url: str = "https://api.deepseek.com"
    # Value of the "model" field in the request payload.
    model: str = "deepseek-chat"
    # Passed as the timeout of each HTTP request.
    timeout_s: int = 120
    # Number of attempts made before the call is reported as failed.
    max_retries: int = 5
    # Base of the exponential sleep between attempts: backoff_s ** (attempt + 1).
    backoff_s: float = 1.6
|
|
||||||
|
|
||||||
|
|
||||||
class DeepSeekClient:
    """Minimal client for an OpenAI-compatible ``/chat/completions`` endpoint."""

    def __init__(self, cfg: DeepSeekConfig):
        self.cfg = cfg

    def chat(
        self, messages: List[Dict], temperature: float = 0.85, max_tokens: int = 750
    ) -> str:
        """Send *messages* and return the stripped content of the first choice.

        A rate-limit response (HTTP 429) or any other failure counts as a
        failed attempt. Up to ``cfg.max_retries`` attempts are made, sleeping
        ``cfg.backoff_s ** attempt_number`` seconds between them.

        Args:
            messages: Chat messages in the ``{"role": ..., "content": ...}`` form.
            temperature: Sampling temperature forwarded in the payload.
            max_tokens: Completion length limit forwarded in the payload.

        Returns:
            The assistant message text with surrounding whitespace removed.

        Raises:
            RuntimeError: If every attempt failed. The last failure is part of
                the message and is attached as ``__cause__``.
        """
        url = f"{self.cfg.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.cfg.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.cfg.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        last_err: Optional[Exception] = None
        for attempt in range(1, self.cfg.max_retries + 1):
            try:
                r = requests.post(
                    url, headers=headers, json=payload, timeout=self.cfg.timeout_s
                )
                if r.status_code == 429:
                    # Record the rate limit so that running out of retries
                    # does not end up reporting "Last error: None".
                    last_err = RuntimeError("HTTP 429 Too Many Requests")
                else:
                    r.raise_for_status()
                    data = r.json()
                    return data["choices"][0]["message"]["content"].strip()
            except Exception as e:  # any failure is treated as retryable
                last_err = e

            # Back off only when another attempt will actually follow.
            if attempt < self.cfg.max_retries:
                time.sleep(self.cfg.backoff_s**attempt)

        raise RuntimeError(
            f"DeepSeek API call failed after retries. Last error: {last_err}"
        ) from last_err
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# Helpers
|
|
||||||
# -----------------------------
|
|
||||||
def simple_clean(text: str) -> str:
    """Collapse every whitespace run in *text* to one space and trim the ends.

    Anything that is not a ``str`` (``None``, numbers, ...) yields ``""``.
    """
    if isinstance(text, str):
        # Non-breaking spaces become plain spaces before the runs are collapsed.
        return re.sub(r"\s+", " ", text.replace("\u00a0", " ")).strip()
    return ""
|
|
||||||
|
|
||||||
|
|
||||||
def read_docstore(docstore_path: str) -> Dict[int, Dict]:
    """Load a JSONL docstore into a dict keyed by integer ``faiss_id``.

    Blank lines are skipped. Each remaining line is parsed as JSON and stored
    whole under ``int(obj["faiss_id"])``; a repeated id keeps the later row.

    Raises:
        ValueError: If the file yields no rows at all.
    """
    with open(docstore_path, "r", encoding="utf-8") as handle:
        # The generator is consumed inside the ``with`` so the file stays open.
        records = (json.loads(raw) for raw in map(str.strip, handle) if raw)
        by_id: Dict[int, Dict] = {int(rec["faiss_id"]): rec for rec in records}
    if by_id:
        return by_id
    raise ValueError("docstore.jsonl is empty or unreadable.")
|
|
||||||
|
|
||||||
|
|
||||||
def load_prompts_from_jsonl(path: str) -> List[str]:
    """Read interview prompts from a JSONL file.

    For every non-blank line the first truthy value among the keys
    ``prompt``, ``question`` and ``text`` is cleaned with ``simple_clean``.
    Prompts shorter than 20 characters after cleaning are dropped.

    Raises:
        ValueError: If no prompt survives the filtering.
    """
    collected: List[str] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.strip()
            if raw:
                record = json.loads(raw)
                candidate = (
                    record.get("prompt")
                    or record.get("question")
                    or record.get("text")
                )
                cleaned = simple_clean(candidate) if candidate else ""
                if len(cleaned) >= 20:
                    collected.append(cleaned)
    if collected:
        return collected
    raise ValueError(f"No prompts found in JSONL: {path}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_prompts_from_txt(path: str) -> List[str]:
    """Read interview prompts from a text file, one prompt per line.

    Every line is cleaned with ``simple_clean``; lines shorter than 20
    characters after cleaning are dropped.

    Raises:
        ValueError: If no prompt survives the filtering.
    """
    with open(path, "r", encoding="utf-8") as handle:
        collected: List[str] = [
            cleaned for cleaned in map(simple_clean, handle) if len(cleaned) >= 20
        ]
    if collected:
        return collected
    raise ValueError(f"No prompts found in TXT: {path}")
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_dir_for_file(path: str):
    """Create the directory part of *path* (and any parents) if it has one.

    A bare file name without a directory component is left alone.
    """
    parent = os.path.dirname(path)
    if not parent:
        return
    os.makedirs(parent, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
def write_jsonl(path: str, rows: List[Dict]) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines, one object per line.

    The parent directory is created first and an existing file is
    overwritten. Non-ASCII characters are written unescaped.
    """
    ensure_dir_for_file(path)
    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# Persona + prompt templates (EN)
|
|
||||||
# -----------------------------
|
|
||||||
# Destination image dimensions; stored under "image_dimensions" in the meta
# block of every generated example.
IMAGE_DIMS = [
    "Natural Attractions",
    "Atmosphere",
    "Social Environment",
    "Infrastructure",
    "Value for Money",
]
|
|
||||||
|
|
||||||
# Built-in English interview questions, two per image dimension. They are
# used only when neither --prompts_jsonl nor --prompts_txt is supplied.
DEFAULT_PROMPTS_EN = [
    # Natural Attractions
    "In a lead user interview: what natural places in Bali felt genuinely memorable to you (rice terraces, volcanoes, waterfalls, coast), and why? Describe it like a lived experience.",
    "Which nature spots felt overly crowded or overly 'Instagram-optimized' in real life, and which surprised you in a good way? Explain with concrete moments.",
    # Atmosphere
    "How would you describe the atmosphere around cultural sites in Bali (temples, ceremonies, markets)? What signals authenticity vs. commercialization to you?",
    "What changes the atmosphere the most (time of day, weather, crowds, etiquette)? Share specific examples you would tell a marketer.",
    # Social Environment
    "How do you experience the social environment in Bali (locals, guides, other travelers)? What feels respectful and what feels performative or touristy?",
    "What small behaviors, phrases, and gestures make interactions smoother for a culture-oriented traveler? Give examples.",
    # Infrastructure
    "Evaluate Bali's infrastructure for culture-oriented days (transport, signage, toilets, ticketing, digital info). What works, what annoys you, and how do you adapt?",
    "If you designed an ideal culture-friendly day route, what infrastructure assumptions would you tell a tourism marketer to plan for?",
    # Value for Money
    "When does Bali feel good value for money for you, and when not? Discuss entrance fees, guides, food, tours, and hidden costs.",
    "How do you personally distinguish 'good value' from a tourist trap? List criteria and illustrate with examples.",
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_system_prompt() -> str:
    """Return the fixed system prompt: persona, image dimensions, constraints.

    The three sections are separated by a blank line and the text carries no
    trailing newline.
    """
    persona = (
        "ROLE / PERSONA\n"
        "You are an experienced, culture-oriented Bali traveler (Lead User). You speak in natural, vivid English, "
        "as a real person in an interview. You share nuanced judgments, trade-offs, and concrete scenes."
    )
    dimensions = "\n".join(
        [
            "COGNITIVE DESTINATION IMAGE DIMENSIONS (use when relevant)",
            "- Natural Attractions",
            "- Atmosphere",
            "- Social Environment",
            "- Infrastructure",
            "- Value for Money",
        ]
    )
    constraints = "\n".join(
        [
            "CRITICAL CONSTRAINTS",
            "- You will be given retrieved review snippets as memory support.",
            "- Do NOT quote them verbatim and do NOT cite them as 'the review says'.",
            "- Instead, weave relevant details into your answer as if they were your own lived experiences.",
            "- If a detail is not supported, stay general or hedge appropriately.",
            "- Keep it interview-friendly: first-person, vivid, concrete, but not overly long.",
        ]
    )
    return "\n\n".join((persona, dimensions, constraints))
|
|
||||||
|
|
||||||
|
|
||||||
def build_user_message(question: str, retrieved_chunks: List[str]) -> str:
    """Compose the user turn: the question, the retrieved context, the task.

    Args:
        question: Interview question, inserted verbatim.
        retrieved_chunks: Retrieved snippets. Each one is normalised with
            ``simple_clean``; snippets that are empty afterwards are left out.

    Returns:
        The prompt text with one ``- `` bullet per remaining snippet.
    """
    # Clean every chunk exactly once, then drop the ones that came out empty.
    cleaned = [chunk for chunk in map(simple_clean, retrieved_chunks) if chunk]
    bullets = "\n".join(f"- {chunk}" for chunk in cleaned)
    return (
        f"INTERVIEW QUESTION:\n{question}\n\n"
        "RETRIEVED CONTEXT (review snippets; do NOT quote, only use as memory support):\n"
        f"{bullets}\n\n"
        "Answer as a real Lead User in a tourism interview. Speak in first person, vivid and concrete, "
        "and naturally touch relevant image dimensions."
    )
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# FAISS Retriever (cosine/IP)
|
|
||||||
# -----------------------------
|
|
||||||
class FaissRetriever:
    """Top-k lookup over a FAISS index paired with a JSONL docstore.

    Queries are embedded with a SentenceTransformer model using normalized
    embeddings and searched against the loaded index.
    """

    def __init__(self, index_path: str, docstore_path: str, embed_model: str):
        # Fail early (index first, then docstore) before loading anything.
        for missing_msg, required in (
            (f"Missing FAISS index at: {index_path}", index_path),
            (f"Missing docstore at: {docstore_path}", docstore_path),
        ):
            if not os.path.exists(required):
                raise FileNotFoundError(missing_msg)

        self.index = faiss.read_index(index_path)
        self.docstore = read_docstore(docstore_path)

        # The embedding model should be the one the index was built with.
        self.embedder = SentenceTransformer(embed_model)

        # A size mismatch is tolerated (the docstore may hold extra rows) but
        # usually means index and docstore were not generated together.
        n_vectors = self.index.ntotal
        n_docs = len(self.docstore)
        if n_vectors != n_docs:
            print(
                f"Warning: index.ntotal={n_vectors} but docstore rows={n_docs}. "
                "Ensure they were generated together."
            )

    def retrieve(self, query: str, k: int = 8) -> List[Tuple[int, float, str]]:
        """Return up to *k* hits for *query* as ``(faiss_id, score, text)``.

        Ids of -1 returned by the search and ids without a docstore entry
        are skipped, so fewer than *k* results may come back.
        """
        cleaned = simple_clean(query)
        vec = np.asarray(
            self.embedder.encode([cleaned], normalize_embeddings=True),
            dtype=np.float32,
        )

        score_rows, id_rows = self.index.search(vec, k)

        results = []
        for raw_id, raw_score in zip(id_rows[0].tolist(), score_rows[0].tolist()):
            if raw_id == -1:
                continue
            record = self.docstore.get(int(raw_id))
            if record:
                results.append(
                    (int(raw_id), float(raw_score), record.get("text", ""))
                )
        return results
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
|
||||||
# Dataset generation
|
|
||||||
# -----------------------------
|
|
||||||
def main():
    """Generate RAFT chat examples and write them as JSONL.

    Steps, in order: parse the CLI arguments, seed ``random`` and NumPy,
    require the ``DEEPSEEK_API_KEY`` environment variable, load the FAISS
    retriever and the DeepSeek client, pick the prompt list, and for every
    prompt retrieve ``--k`` chunks and request one answer. The collected
    examples are shuffled and written to ``--out_train``; with ``--make_val``
    and at least 20 examples a validation slice goes to ``--out_val``.

    Raises:
        SystemExit: If the API key is missing or both prompt files are given.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--index_dir",
        default="out",
        help="Directory containing faiss.index and docstore.jsonl",
    )
    ap.add_argument("--out_train", default="./out/raft_train.jsonl")
    ap.add_argument("--out_val", default="./out/raft_val.jsonl")
    ap.add_argument("--make_val", action="store_true")
    ap.add_argument("--val_ratio", type=float, default=0.05)
    ap.add_argument("--k", type=int, default=8)
    ap.add_argument("--seed", type=int, default=42)

    # Embeddings (must match indexing script for best results)
    ap.add_argument(
        "--embedding_model", default="sentence-transformers/all-MiniLM-L6-v2"
    )

    # External prompt sources
    ap.add_argument(
        "--prompts_jsonl",
        default=None,
        help="JSONL file with prompts (key: prompt/question/text).",
    )
    ap.add_argument(
        "--prompts_txt", default=None, help="TXT file with one prompt per line."
    )
    ap.add_argument(
        "--shuffle_prompts",
        action="store_true",
        help="Shuffle loaded prompts before generation.",
    )
    ap.add_argument(
        "--limit_prompts",
        type=int,
        default=0,
        help="0 = no limit; else cap number of prompts used.",
    )

    # DeepSeek generation config (environment variables supply the defaults)
    ap.add_argument(
        "--deepseek_base_url",
        default=os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com"),
    )
    ap.add_argument(
        "--deepseek_model", default=os.environ.get("DEEPSEEK_MODEL", "deepseek-chat")
    )
    ap.add_argument("--temperature", type=float, default=0.85)
    ap.add_argument("--max_tokens", type=int, default=750)
    ap.add_argument(
        "--max_examples",
        type=int,
        default=0,
        help="0 = all prompts; else limit number of examples",
    )

    # Pause between consecutive API calls, in seconds
    ap.add_argument("--sleep_s", type=float, default=0.2)

    args = ap.parse_args()
    # Seed both RNGs; ``random`` drives the prompt and example shuffles below.
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_key = os.environ.get("DEEPSEEK_API_KEY", "").strip()
    if not api_key:
        raise SystemExit("Missing DEEPSEEK_API_KEY env var.")

    index_path = os.path.join(args.index_dir, "faiss.index")
    docstore_path = os.path.join(args.index_dir, "docstore.jsonl")

    retriever = FaissRetriever(
        index_path=index_path,
        docstore_path=docstore_path,
        embed_model=args.embedding_model,
    )

    client = DeepSeekClient(
        DeepSeekConfig(
            api_key=api_key,
            base_url=args.deepseek_base_url,
            model=args.deepseek_model,
        )
    )

    system_prompt = build_system_prompt()

    # Load prompts (priority: JSONL -> TXT -> defaults)
    if args.prompts_jsonl and args.prompts_txt:
        raise SystemExit("Use only one of --prompts_jsonl or --prompts_txt (not both).")

    if args.prompts_jsonl:
        prompts = load_prompts_from_jsonl(args.prompts_jsonl)
    elif args.prompts_txt:
        prompts = load_prompts_from_txt(args.prompts_txt)
    else:
        # Copy so that shuffling does not reorder the module-level list.
        prompts = list(DEFAULT_PROMPTS_EN)

    if args.shuffle_prompts:
        random.shuffle(prompts)

    if args.limit_prompts and args.limit_prompts > 0:
        prompts = prompts[: args.limit_prompts]

    # Backwards-compat: args.max_examples can still cap prompts
    if args.max_examples and args.max_examples > 0:
        prompts = prompts[: args.max_examples]

    # Everything is held in memory and written only after the loop finishes.
    examples = []
    for q in tqdm(prompts, desc="Generating RAFT examples"):
        hits = retriever.retrieve(q, k=args.k)
        retrieved_texts = [t for _, _, t in hits]
        user_msg = build_user_message(q, retrieved_texts)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_msg},
        ]

        answer = client.chat(
            messages=messages,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
        )

        # Training example: the full conversation plus retrieval provenance.
        ex = {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": answer},
            ],
            "meta": {
                "retrieval_k": args.k,
                "index_dir": os.path.abspath(args.index_dir),
                "embedding_model": args.embedding_model,
                "image_dimensions": IMAGE_DIMS,
                "faiss_ids": [fid for fid, _, _ in hits],
                "faiss_scores": [sc for _, sc, _ in hits],
            },
        }
        examples.append(ex)

        # The prompt list was already capped above, so this only ends the
        # loop after its final item (and skips the trailing sleep).
        if args.max_examples and len(examples) >= args.max_examples:
            break

        time.sleep(max(0.0, args.sleep_s))

    random.shuffle(examples)

    # A validation split is produced only when there are at least 20 examples.
    if args.make_val and len(examples) >= 20:
        val_n = max(1, int(len(examples) * args.val_ratio))
        val = examples[:val_n]
        train = examples[val_n:]
        write_jsonl(args.out_train, train)
        write_jsonl(args.out_val, val)
        print(f"Wrote train: {args.out_train} ({len(train)} examples)")
        print(f"Wrote val: {args.out_val} ({len(val)} examples)")
    else:
        write_jsonl(args.out_train, examples)
        print(f"Wrote: {args.out_train} ({len(examples)} examples)")
        if args.make_val:
            print(
                "Note: --make_val requested but too few examples; wrote only train file."
            )
|
|
||||||
|
|
||||||
|
|
||||||
# Run the dataset builder only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
||||||
@@ -106,8 +106,11 @@ def main():
|
|||||||
print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")
|
print(f"\nDoc {i+1} (score: {score:.4f}):\n{doc}")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": SYSTEM_PERSONA},
|
# {"role": "system", "content": SYSTEM_PERSONA},
|
||||||
{"role": "user", "content": f"QUESTION: {q}\n\nCONTEXT:\n{context_blob}"},
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"PERSONA: {SYSTEM_PERSONA}\n\nQUESTION: {q}\n\nCONTEXT:\n{context_blob}",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
if args.no_model:
|
if args.no_model:
|
||||||
|
|||||||
@@ -1,83 +1,9 @@
|
|||||||
accelerate==1.12.0
|
faiss-cpu
|
||||||
aiohappyeyeballs==2.6.1
|
numpy
|
||||||
aiohttp==3.13.3
|
torch
|
||||||
aiosignal==1.4.0
|
pandas
|
||||||
annotated-doc==0.0.4
|
requests
|
||||||
anyio==4.12.1
|
tqdm
|
||||||
attrs==25.4.0
|
sentence-transformers
|
||||||
bitsandbytes==0.49.2
|
transformers
|
||||||
certifi==2026.1.4
|
peft
|
||||||
charset-normalizer==3.4.4
|
|
||||||
click==8.3.1
|
|
||||||
cuda-bindings==12.9.4
|
|
||||||
cuda-pathfinder==1.3.4
|
|
||||||
datasets==4.5.0
|
|
||||||
dill==0.4.0
|
|
||||||
faiss-cpu==1.13.2
|
|
||||||
filelock==3.24.3
|
|
||||||
frozenlist==1.8.0
|
|
||||||
fsspec==2025.10.0
|
|
||||||
h11==0.16.0
|
|
||||||
hf-xet==1.2.0
|
|
||||||
httpcore==1.0.9
|
|
||||||
httpx==0.28.1
|
|
||||||
huggingface_hub==1.4.1
|
|
||||||
idna==3.11
|
|
||||||
Jinja2==3.1.6
|
|
||||||
joblib==1.5.3
|
|
||||||
markdown-it-py==4.0.0
|
|
||||||
MarkupSafe==3.0.3
|
|
||||||
mdurl==0.1.2
|
|
||||||
mpmath==1.3.0
|
|
||||||
multidict==6.7.1
|
|
||||||
multiprocess==0.70.18
|
|
||||||
networkx==3.6.1
|
|
||||||
numpy==2.4.2
|
|
||||||
nvidia-cublas-cu12==12.8.4.1
|
|
||||||
nvidia-cuda-cupti-cu12==12.8.90
|
|
||||||
nvidia-cuda-nvrtc-cu12==12.8.93
|
|
||||||
nvidia-cuda-runtime-cu12==12.8.90
|
|
||||||
nvidia-cudnn-cu12==9.10.2.21
|
|
||||||
nvidia-cufft-cu12==11.3.3.83
|
|
||||||
nvidia-cufile-cu12==1.13.1.3
|
|
||||||
nvidia-curand-cu12==10.3.9.90
|
|
||||||
nvidia-cusolver-cu12==11.7.3.90
|
|
||||||
nvidia-cusparse-cu12==12.5.8.93
|
|
||||||
nvidia-cusparselt-cu12==0.7.1
|
|
||||||
nvidia-nccl-cu12==2.27.5
|
|
||||||
nvidia-nvjitlink-cu12==12.8.93
|
|
||||||
nvidia-nvshmem-cu12==3.4.5
|
|
||||||
nvidia-nvtx-cu12==12.8.90
|
|
||||||
packaging==26.0
|
|
||||||
pandas==3.0.1
|
|
||||||
peft==0.18.1
|
|
||||||
propcache==0.4.1
|
|
||||||
psutil==7.2.2
|
|
||||||
pyarrow==23.0.1
|
|
||||||
Pygments==2.19.2
|
|
||||||
python-dateutil==2.9.0.post0
|
|
||||||
PyYAML==6.0.3
|
|
||||||
regex==2026.1.15
|
|
||||||
requests==2.32.5
|
|
||||||
rich==14.3.2
|
|
||||||
safetensors==0.7.0
|
|
||||||
scikit-learn==1.8.0
|
|
||||||
scipy==1.17.0
|
|
||||||
sentence-transformers==5.2.3
|
|
||||||
setuptools==82.0.0
|
|
||||||
shellingham==1.5.4
|
|
||||||
six==1.17.0
|
|
||||||
sympy==1.14.0
|
|
||||||
threadpoolctl==3.6.0
|
|
||||||
tokenizers==0.22.2
|
|
||||||
torch==2.10.0
|
|
||||||
tqdm==4.67.3
|
|
||||||
transformers==5.2.0
|
|
||||||
triton==3.6.0
|
|
||||||
trl==0.28.0
|
|
||||||
typer==0.24.0
|
|
||||||
typer-slim==0.24.0
|
|
||||||
typing_extensions==4.15.0
|
|
||||||
urllib3==2.6.3
|
|
||||||
xxhash==3.6.0
|
|
||||||
yarl==1.22.0
|
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from datasets import load_dataset
|
|
||||||
from peft import LoraConfig
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
||||||
from trl import SFTConfig, SFTTrainer
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Fine-tune a causal language model with a 4-bit LoRA adapter (QLoRA).

    Parses the CLI arguments, loads the base model quantized to 4 bit, trains
    a LoRA adapter on the JSONL file given by ``--train_jsonl`` with TRL's
    ``SFTTrainer``, and saves adapter and tokenizer to ``--out_dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--train_jsonl", default="out/raft_train.jsonl")
    ap.add_argument("--base_model", default="mistralai/Mistral-7B-Instruct-v0.2")
    ap.add_argument("--out_dir", default="out/mistral_balitwin_lora")
    ap.add_argument("--max_seq_len", type=int, default=2048)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=16)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--epochs", type=int, default=1)
    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    # Choose the precision from what the hardware supports. CUDA being present
    # does not imply bf16 support, and fp16 mixed precision needs a GPU, so it
    # must not be switched on for CPU-only runs.
    has_cuda = torch.cuda.is_available()
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    use_fp16 = has_cuda and not use_bf16
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    # QLoRA (4-bit) config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=True)

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        device_map="auto",
        quantization_config=bnb_config,
        torch_dtype=compute_dtype,
    )

    # LoRA adapter config: attention and MLP projection layers are adapted.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
    )

    dataset = load_dataset("json", data_files=args.train_jsonl, split="train")

    training_args = SFTConfig(
        output_dir=args.out_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        max_length=args.max_seq_len,
        bf16=use_bf16,
        fp16=use_fp16,
        report_to=[],
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()
    trainer.save_model(args.out_dir)
    tokenizer.save_pretrained(args.out_dir)

    print(f"Fertig! LoRA-Adapter gespeichert: {args.out_dir}")
|
|
||||||
|
|
||||||
|
|
||||||
# Start training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
||||||
Reference in New Issue
Block a user