mirror of
https://github.com/marvinscham/masterthesis-playground.git
synced 2025-12-06 10:10:50 +01:00
Init
This commit is contained in:
101
deepseek_label_distribution.py
Normal file
101
deepseek_label_distribution.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import json
|
||||
from collections import Counter
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
|
||||
def load_labels(file_path):
|
||||
"""Load labels from JSON file"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def process_labels(data):
|
||||
"""Extract valid categories and count their occurrences"""
|
||||
categories = []
|
||||
errors = 0
|
||||
|
||||
for entry in data:
|
||||
if "deepseek" in entry:
|
||||
categories.append(entry["deepseek"]["category"])
|
||||
elif "error" in entry:
|
||||
errors += 1
|
||||
|
||||
category_counts = Counter(categories)
|
||||
return category_counts, errors
|
||||
|
||||
|
||||
def visualize_distribution(category_counts, errors, output_file=None):
|
||||
"""Create visualization of category distribution"""
|
||||
# Prepare data
|
||||
categories = list(category_counts.keys())
|
||||
counts = list(category_counts.values())
|
||||
total_valid = sum(counts)
|
||||
total = total_valid + errors
|
||||
|
||||
# Set style
|
||||
sns.set(style="whitegrid")
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
||||
# Create bar plot
|
||||
ax = sns.barplot(x=categories, y=counts, palette="viridis")
|
||||
|
||||
# Customize plot
|
||||
plt.title(
|
||||
f"Review Category Distribution\n(Total: {total} reviews - {errors} errors)",
|
||||
pad=20,
|
||||
)
|
||||
plt.xlabel("Category")
|
||||
plt.ylabel("Count")
|
||||
plt.xticks(rotation=45, ha="right")
|
||||
|
||||
# Add value labels
|
||||
for i, count in enumerate(counts):
|
||||
ax.text(i, count + 0.5, str(count), ha="center")
|
||||
|
||||
# Add error count annotation if there are errors
|
||||
if errors > 0:
|
||||
plt.annotate(
|
||||
f"{errors} errors\n({errors/total:.1%})",
|
||||
xy=(0.95, 0.95),
|
||||
xycoords="axes fraction",
|
||||
ha="right",
|
||||
va="top",
|
||||
bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
|
||||
)
|
||||
|
||||
# Adjust layout
|
||||
plt.tight_layout()
|
||||
|
||||
# Save or show
|
||||
if output_file:
|
||||
plt.savefig(output_file, dpi=300)
|
||||
print(f"Visualization saved to {output_file}")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
input_file = "deepseek_labels.json"
|
||||
output_image = (
|
||||
"./img/category_distribution.png" # Set to None to display instead of saving
|
||||
)
|
||||
|
||||
# Load and process data
|
||||
data = load_labels(input_file)
|
||||
category_counts, errors = process_labels(data)
|
||||
|
||||
# Print basic stats
|
||||
print("Category Distribution:")
|
||||
for category, count in category_counts.most_common():
|
||||
print(f"- {category}: {count} ({count/len(data):.1%})")
|
||||
if errors > 0:
|
||||
print(f"- Errors: {errors} ({errors/len(data):.1%})")
|
||||
|
||||
# Visualize
|
||||
visualize_distribution(category_counts, errors, output_image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user