import json
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns


def load_labels(file_path):
    """Load and return the parsed contents of a UTF-8 encoded JSON file.

    Args:
        file_path: Path to the JSON labels file.

    Returns:
        The decoded JSON value. The rest of this script iterates it as a
        sequence of entry dicts.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def process_labels(data):
    """Count labelled categories and errored entries.

    An entry containing a "deepseek" key contributes
    ``entry["deepseek"]["category"]`` to the category counts. Otherwise, an
    entry containing an "error" key is counted as an error. Entries with
    neither key are ignored.

    Args:
        data: Iterable of entry dicts, as returned by ``load_labels``.

    Returns:
        A ``(category_counts, errors)`` tuple. ``category_counts`` is a
        ``Counter`` mapping category -> occurrences, and ``errors`` is the
        number of errored entries.
    """
    category_counts = Counter()
    errors = 0
    for entry in data:
        if "deepseek" in entry:
            category_counts[entry["deepseek"]["category"]] += 1
        elif "error" in entry:
            errors += 1
    return category_counts, errors


def visualize_distribution(category_counts, errors, output_file=None):
    """Draw a bar chart of the category distribution.

    Args:
        category_counts: Mapping of category -> count (a ``Counter`` or a
            plain dict).
        errors: Number of entries that failed labelling. The count appears
            in the title and, when non-zero, as a corner annotation.
        output_file: If given, the chart is saved there at 300 dpi. Missing
            parent directories are created, and the figure is closed
            afterwards. If ``None``, the chart is shown interactively.
    """
    # Order bars by descending count so the chart matches the printed
    # summary. sorted() is stable, so ties keep first-seen order.
    ordered = sorted(category_counts.items(), key=lambda kv: kv[1], reverse=True)
    categories = [name for name, _ in ordered]
    counts = [count for _, count in ordered]
    total_valid = sum(counts)
    total = total_valid + errors

    if total == 0:
        # Nothing to plot; an empty barplot is fragile across seaborn versions.
        print("No labels to visualize.")
        return

    sns.set(style="whitegrid")
    fig = plt.figure(figsize=(10, 6))

    # NOTE(review): seaborn >= 0.13 warns about `palette` without `hue`;
    # kept as-is for compatibility with older seaborn releases.
    ax = sns.barplot(x=categories, y=counts, palette="viridis")

    plt.title(
        f"Review Category Distribution\n(Total: {total} reviews - {errors} errors)",
        pad=20,
    )
    plt.xlabel("Category")
    plt.ylabel("Count")
    plt.xticks(rotation=45, ha="right")

    # Place each value label just above its bar. The offset scales with the
    # tallest bar so the gap looks the same for small and large counts.
    offset = max(counts, default=0) * 0.01 or 0.5
    for i, count in enumerate(counts):
        ax.text(i, count + offset, str(count), ha="center", va="bottom")

    if errors > 0:
        plt.annotate(
            f"{errors} errors\n({errors/total:.1%})",
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    plt.tight_layout()

    if output_file:
        # savefig does not create directories; make sure the target exists.
        Path(output_file).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_file, dpi=300)
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
        print(f"Visualization saved to {output_file}")
    else:
        plt.show()


def main():
    """Load labels, print per-category statistics and render the chart."""
    input_file = "deepseek_labels.json"
    # Set to None to display the chart instead of saving it.
    output_image = "./img/category_distribution.png"

    data = load_labels(input_file)
    category_counts, errors = process_labels(data)

    # Use the chart's denominator (labelled + errored entries). Entries with
    # neither key are excluded, so the printed percentages agree with the
    # plot's title and annotation.
    total = sum(category_counts.values()) + errors

    print("Category Distribution:")
    for category, count in category_counts.most_common():
        print(f"- {category}: {count} ({count/total:.1%})")
    if errors > 0:
        print(f"- Errors: {errors} ({errors/total:.1%})")

    visualize_distribution(category_counts, errors, output_image)


if __name__ == "__main__":
    main()