# argparse section
import sys, os
sys.path.append(os.getcwd())

from TextClustering.argsparse_clustering_preamble import argsparse_preamble
args = argsparse_preamble()


from database_preparation.utils_labeled_datasets import is_text_lst_tokenized
if not is_text_lst_tokenized(args.path2corpus):
    print("Error: "+args.path2corpus + '.pkl is not tokenized! '
            'Please pass texts list where each text is tokenized (a list of words).')
    exit(1)

# import section
import pickle
import gensim
import gensim.corpora as corpora
from gensim.models import CoherenceModel
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from database_preparation.preprocess import print_meta_data
from TextClassification.classification_for_cluster_evaluation import cross_validate_label_corpus_with_simple_SVM

# load the tokenized diagnosis list
with open(args.path2corpus, 'rb') as f:
    diag_lst = pickle.load(f)

print_meta_data(args.path2corpus)

# prepare the data for LDA model training:
# create the dictionary, which is a mapping of word IDs to words
words = corpora.Dictionary(diag_lst)
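# e.g. words.token2id maps each token to its integer ID and words[some_id]
# returns the token for that ID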

# Turn each document into a bag of words: a list of (word_id, count) tuples.
# Note: this is not a trained model yet, just the term counts the LDA model is fit on.
corpus = [words.doc2bow(doc) for doc in diag_lst]
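# illustrative only (hypothetical tokens): words.doc2bow(['lung', 'tumor', 'lung'])
# would yield [(id_lung, 2), (id_tumor, 1)] with the corresponding dictionary IDs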

# train LDA-model with different number of clusters
if args.find_k_value:
    start, limit, step = 5, 21, 1
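    # sweep num_topics over range(start, limit, step), i.e. k = 5 .. 20, and
    # record coherence (c_v), the resulting number of non-empty clusters and a
    # simple SVM accuracy for each candidate k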
    coherence_values, n_cluster, svm_scores = [], [], []

    for num_topics in tqdm(range(start, limit, step)):

        lda_model = gensim.models.ldamodel.LdaModel(corpus=corpus,
                                                    id2word=words,
                                                    num_topics=num_topics,
                                                    random_state=5,
                                                    update_every=1,
                                                    passes=10,
                                                    alpha='auto',
                                                    per_word_topics=True)
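        # alpha='auto' learns an asymmetric document-topic prior from the data,
        # passes=10 runs ten full sweeps over the corpus, and per_word_topics=True
        # makes the model also return per-word topic information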

        coherencemodel = CoherenceModel(model=lda_model, texts=diag_lst, dictionary=words,
                                        coherence='c_v', processes=1)
        coherence = coherencemodel.get_coherence()
        coherence_values.append(coherence)

        # align weights by topic id: row_list[0] omits low-probability topics, so
        # naively dropping the ids would shift columns between documents
        topic_weights = np.zeros((len(corpus), num_topics))
        for d, row_list in enumerate(lda_model[corpus]):
            for topic_id, w in row_list[0]:
                topic_weights[d, topic_id] = w
        predictedCluster = np.argmax(pd.DataFrame(topic_weights).fillna(0).values, axis=1)
        svm_scores.append(
            cross_validate_label_corpus_with_simple_SVM(predictedCluster, args.path2corpus,
                                                        False))
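        # the SVM accuracy presumably measures how well the induced cluster labels
        # can be re-learned from the texts, i.e. how separable the clusters are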

        n_cluster.append(len(np.unique(predictedCluster)))
        print("coherence: " + str(coherence))

    # visualize the results
    x = range(start, limit, step)
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax3 = ax1.twinx()
    # move the third y-axis outward so it does not sit on top of the second one
    ax3.spines['right'].set_position(('outward', 60))
    ax1.plot(x, coherence_values,'bx-')
    ax2.plot(x, n_cluster, 'rx-')
    ax3.plot(x, svm_scores, 'gx-')
    ax1.set_xlabel('Number of topics')
    ax1.yaxis.label.set_color('blue')
    ax1.set_ylabel('Coherence score')
    ax2.yaxis.label.set_color('red')
    ax2.set_ylabel('Number of clusters')
    ax3.yaxis.label.set_color('green')
    ax3.set_ylabel('svm accuracy')
    plt.title('Elbow-method-like plot')
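    # create the output directory if it is missing (e.g. on a fresh checkout)
    os.makedirs("TextClustering/plots/elbow_method", exist_ok=True)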
    plt.savefig("TextClustering/plots/elbow_method/LDA_elbow_plot.png", dpi=300)
    plt.show()
    exit()

# train the final LDA model with the chosen number of topics
lda_model = gensim.models.ldamodel.LdaModel(corpus=corpus,
                                            id2word=words,
                                            num_topics=args.k_value,
                                            random_state=5,
                                            update_every=1,
                                            passes=10,
                                            alpha='auto',
                                            per_word_topics=True)
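# the final model reuses the hyperparameters from the sweep above, with
# num_topics fixed to the chosen k (args.k_value)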


# get topic weights / features
topic_weights = np.zeros((len(corpus), lda_model.num_topics))
for d, row_list in enumerate(lda_model[corpus]):
    # row_list[0] holds the per-document (topic_id, weight) pairs; index by
    # topic_id so columns line up even when low-weight topics are omitted
    for topic_id, w in row_list[0]:
        topic_weights[d, topic_id] = w

# Array of topic weights
text_features = pd.DataFrame(topic_weights).fillna(0).values
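# each row of text_features is one document's distribution over the k topics
# (topics omitted by the model keep weight 0)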

# get prediction
predictedCluster = np.argmax(text_features, axis=1)

# and add it to the dataframe
df = pd.read_pickle(args.df_cases_file)
df['label_LDA'] = predictedCluster


# project the topic features to 2D with PCA for plotting
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
reduced_features = pca.fit_transform(text_features)
df['pcaX_LDA'] = reduced_features[:, 0]
df['pcaY_LDA'] = reduced_features[:, 1]


# and with umap
import umap
umap_text_features2D = umap.UMAP(n_neighbors=15,
                                 n_components=2,
                                 min_dist=0.0,
                                 metric='cosine').fit_transform(text_features)
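# cosine distance is a common choice for topic-weight vectors; min_dist=0.0
# lets UMAP pack points tightly, which tends to emphasize cluster structure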

df['umapX_LDA'] = umap_text_features2D[:, 0]
df['umapY_LDA'] = umap_text_features2D[:, 1]
df.to_pickle(args.df_cases_file)

# evaluate the model
from TextClustering.utils_metrics import ClusterMetrics
evaluation = ClusterMetrics(text_features, predictedCluster,
                            file_name="TextClustering/cluster_metrics/LDA_metrics.pkl")

evaluation.write_to_file()