mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2025-12-06 02:00:50 +01:00
102 lines
2.6 KiB
Python
102 lines
2.6 KiB
Python
import json
|
|
from collections import Counter
|
|
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
|
|
|
|
def load_labels(file_path):
    """Read the UTF-8 JSON file at ``file_path`` and return the parsed data."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
|
|
|
|
|
|
def process_labels(data):
    """Tally categories and failed entries in a single pass over ``data``.

    An entry containing a ``"deepseek"`` key contributes
    ``entry["deepseek"]["category"]`` to the tally. An entry without that key
    but with an ``"error"`` key is counted as an error. Entries with neither
    key are ignored.

    Returns:
        A ``(category_counts, errors)`` tuple: a ``Counter`` keyed by category
        and the number of error entries.
    """
    tally = Counter()
    failed = 0

    for item in data:
        if "deepseek" not in item:
            # Only entries lacking a label can count as errors.
            if "error" in item:
                failed += 1
            continue
        tally[item["deepseek"]["category"]] += 1

    return tally, failed
|
|
|
|
|
|
def visualize_distribution(category_counts, errors, output_file=None):
    """Draw a bar chart of the category distribution.

    Args:
        category_counts: Mapping of category name to its count, e.g. the
            ``Counter`` returned by ``process_labels``.
        errors: Number of error entries. It is reported in the title and,
            when greater than zero, in an annotation in the top-right corner
            of the axes.
        output_file: Where to save the figure. When falsy, the figure is
            displayed with ``plt.show()`` instead of being saved.
    """
    # Function-scope import so this block does not depend on a file-level
    # ``import os`` being present.
    import os

    # Prepare data
    categories = list(category_counts.keys())
    counts = list(category_counts.values())
    total_valid = sum(counts)
    total = total_valid + errors

    # Set style
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))

    # Create bar plot
    ax = sns.barplot(x=categories, y=counts, palette="viridis")

    # Customize plot
    plt.title(
        f"Review Category Distribution\n(Total: {total} reviews - {errors} errors)",
        pad=20,
    )
    plt.xlabel("Category")
    plt.ylabel("Count")
    plt.xticks(rotation=45, ha="right")

    # Add value labels slightly above each bar
    for i, count in enumerate(counts):
        ax.text(i, count + 0.5, str(count), ha="center")

    # Add error count annotation if there are errors.
    # total >= errors > 0 here, so the division is safe.
    if errors > 0:
        plt.annotate(
            f"{errors} errors\n({errors/total:.1%})",
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    # Adjust layout
    plt.tight_layout()

    # Save or show
    if output_file:
        # savefig does not create missing directories; make sure the target
        # folder exists so the run does not fail at the very last step.
        # Only path-like targets are handled; file objects pass through.
        if isinstance(output_file, (str, os.PathLike)):
            target_dir = os.path.dirname(os.fspath(output_file))
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
        plt.savefig(output_file, dpi=300)
        print(f"Visualization saved to {output_file}")
    else:
        plt.show()
|
|
|
|
|
|
def main():
    """Load the label file, print per-category shares, and plot them."""
    source_path = "deepseek_labels.json"
    # Set to None to display instead of saving
    figure_path = "./img/category_distribution.png"

    # Load and process data
    records = load_labels(source_path)
    tally, failed = process_labels(records)

    def as_share(part):
        """Format ``part`` as a percentage of all loaded records."""
        return f"{part / len(records):.1%}"

    # Print basic stats, most frequent category first
    print("Category Distribution:")
    for label, hits in tally.most_common():
        print(f"- {label}: {hits} ({as_share(hits)})")
    if failed > 0:
        print(f"- Errors: {failed} ({as_share(failed)})")

    # Visualize
    visualize_distribution(tally, failed, figure_path)
|
|
|
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|