Files
masterthesis-playground/deepseek_label_distribution.py
2025-06-06 05:14:58 +02:00

102 lines
2.6 KiB
Python

import json
import os
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
def load_labels(file_path):
    """Read the JSON document stored at ``file_path`` and return it parsed.

    The file is decoded as UTF-8. Whatever structure the JSON contains is
    returned unchanged.
    """
    with open(file_path, encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def process_labels(data):
    """Count the category of every labelled entry and tally the failures.

    Args:
        data: Iterable of entries (dicts as loaded from the labels JSON).
            An entry carrying a ``"deepseek"`` key is expected to hold a
            dict with a ``"category"`` value; an entry carrying only an
            ``"error"`` key is a failed labelling attempt.

    Returns:
        A ``(category_counts, errors)`` tuple: a ``Counter`` mapping each
        category to its number of occurrences, and the number of entries
        counted as errors.

    Entries that have neither key are ignored.
    """
    categories = []
    errors = 0
    for entry in data:
        if "deepseek" in entry:
            label = entry["deepseek"]
            if isinstance(label, dict) and "category" in label:
                categories.append(label["category"])
            else:
                # A "deepseek" payload without a usable category is a failed
                # labelling; count it instead of aborting the whole run with
                # a KeyError/TypeError.
                errors += 1
        elif "error" in entry:
            errors += 1
    return Counter(categories), errors
def visualize_distribution(category_counts, errors, output_file=None):
    """Draw a bar chart of the category counts and save or display it.

    Args:
        category_counts: Mapping of category name to number of occurrences.
        errors: Number of entries counted as errors. It is shown in the
            title and, when greater than zero, in an annotation box.
        output_file: Target passed to ``plt.savefig``. When it is a path,
            missing parent directories are created first. When falsy, the
            figure is shown with ``plt.show()`` instead of being saved.
    """
    # Prepare data
    categories = list(category_counts.keys())
    counts = list(category_counts.values())
    total = sum(counts) + errors

    # Set style
    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 6))

    # Create bar plot
    ax = sns.barplot(x=categories, y=counts, palette="viridis")

    # Customize plot
    plt.title(
        f"Review Category Distribution\n(Total: {total} reviews - {errors} errors)",
        pad=20,
    )
    plt.xlabel("Category")
    plt.ylabel("Count")
    plt.xticks(rotation=45, ha="right")

    # Add value labels slightly above each bar
    for i, count in enumerate(counts):
        ax.text(i, count + 0.5, str(count), ha="center")

    # Add error count annotation if there are errors. The guard also keeps
    # the percentage safe: errors > 0 implies total > 0.
    if errors > 0:
        plt.annotate(
            f"{errors} errors\n({errors/total:.1%})",
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    # Adjust layout
    plt.tight_layout()

    # Save or show
    if output_file:
        # savefig does not create missing directories, so make sure the
        # target folder exists. File-like targets are passed through as is.
        if isinstance(output_file, (str, os.PathLike)):
            Path(output_file).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_file, dpi=300)
        print(f"Visualization saved to {output_file}")
    else:
        plt.show()
def main(
    input_file="deepseek_labels.json",
    output_image="./img/category_distribution.png",
):
    """Load the labels, print the category statistics and plot them.

    Args:
        input_file: Path of the JSON file holding the labelled entries.
        output_image: Where the chart is saved. Set to ``None`` to display
            it instead of saving.
    """
    # Load and process data
    data = load_labels(input_file)
    category_counts, errors = process_labels(data)

    # Print basic stats; percentages are relative to all loaded entries.
    total_entries = len(data)
    print("Category Distribution:")
    for category, count in category_counts.most_common():
        print(f"- {category}: {count} ({count/total_entries:.1%})")
    if errors > 0:
        print(f"- Errors: {errors} ({errors/total_entries:.1%})")

    # Visualize
    visualize_distribution(category_counts, errors, output_image)
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()