"""Plot label histograms for the clustering results stored in df_cases.

For every label set in ``clustersets`` whose ``label_<name>`` column exists in
the pickled DataFrame, one bar chart is saved under
``TextClustering/plots/histograms/``:

* default mode: number of documents per cluster index;
* ``plot_author_histos = True``: number of documents per author label inside
  the single cluster selected by ``cluster``.
"""
import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Extend the import path *before* the project import below; appending it
# afterwards (as before) cannot help that import resolve.
# NOTE(review): presumably the script is meant to be started from the
# repository root so that `database_preparation` is found - confirm.
sys.path.append(os.getcwd())

import database_preparation.utils_labeled_datasets as dt  # noqa: E402

parser = argparse.ArgumentParser()
parser.add_argument("--df_cases_file", default="database/df_cases.pkl")
args = parser.parse_args()

# Switch between "documents per cluster" (False) and "authors within one
# cluster" (True) histograms.
plot_author_histos = False
# Cluster index inspected when plot_author_histos is enabled.
cluster = 2
clustersets = ["HDBSCAN", "KMeans", "LDA", "GSDPMM",
               "top2vec", "Patho_BERT", "German_BERT", "golden"]

# NOTE(review): unpickling can execute arbitrary code; only point
# --df_cases_file at files from a trusted source.
df = pd.read_pickle(args.df_cases_file)

# Plot histograms: how many documents share the same label (= cluster index)?
for label_set in clustersets:
    label_column = 'label_' + label_set
    # Explicit membership test instead of a bare `except:` so that genuine
    # errors in the label conversion are no longer reported as a missing column.
    if label_column not in df:
        print(f"skipping {label_set}. it is not in the df_cases_file.")
        continue
    cluster_labels = dt.label_list_as_int_list(df[label_column])

    if plot_author_histos:
        if 'label_author' not in df:
            print('Cant plot author histos, there is not "label_author" in df_cases.')
            # Without author labels no author histogram can be drawn for any
            # label set; stop instead of saving with undefined/stale data.
            break
        # Pair authors and cluster labels positionally; indexing the Series
        # with the enumerate counter would look rows up by index *label*.
        authors = np.asarray([
            author
            for author, label in zip(df["label_author"], cluster_labels)
            if label == cluster
        ])
        x = [-1, 0, 1, 2, 3]
        h = [sum(1 for author in authors if author == value) for value in x]
        title = label_set + " authors in cluster " + str(cluster)
        file_path = ('TextClustering/plots/histograms/histogram_' + label_set
                     + "_cluster" + str(cluster) + "_authors.png")
    else:
        # Label -1 is left out of the histogram
        # (presumably the "no cluster"/noise label - TODO confirm).
        labels = np.asarray([label for label in cluster_labels if label != -1])
        label_num = dt.get_amount_unique_labels(label_set)
        x = np.arange(label_num)
        h = [int(np.count_nonzero(labels == value)) for value in x]
        title = label_set
        file_path = 'TextClustering/plots/histograms/histogram_' + label_set + ".png"

    plt.bar(x, height=h)
    plt.xticks(x, x)
    plt.title(title)

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    plt.savefig(file_path, dpi=600)
    # Closing the figure is sufficient; a following clf() would only create
    # a new empty figure and leave it open.
    plt.close()
    print(f"generated {file_path}")