import openpyxl
from TextClustering.utils_wordlist import get_top_cluster_words_as_latex_table



# Workbook with the manually annotated topic words of every cluster.
path2table = "WordsPerCluster_HDBSCAN.xlsx"

# Annotation font colours, compared against openpyxl's ``cell.font.color.rgb``
# (8-hex-digit strings). Green = strong topic word; blue and orange = weak.
green = 'FF00FF00'
blue = 'FF4A86E8'
orange = 'FFFF9900'
black = '1'  # NOTE(review): not referenced in this file; possibly a theme-colour index — confirm
# LaTeX colour macros emitted for weak / strong words.
latex_weak_word = r'\weakcolor'
latex_strong_word = r'\strongcolor'

# When True, an English translation of every table is printed as well.
translate_to_eng = False


if translate_to_eng:
    # googletrans 3.0 is broken; install the pinned pre-release: pip install googletrans==3.1.0a0
    from googletrans import Translator
    from utils_general import custom_translation


def color2latex_color(color):
    """Translate an annotation font colour into its LaTeX colour macro.

    Args:
        color: font colour value read from a worksheet cell; it is compared
            with ``==`` against the module-level ``green``, ``blue`` and
            ``orange`` constants.

    Returns:
        ``latex_strong_word`` for green, ``latex_weak_word`` for blue or
        orange, and ``None`` for any other (unannotated) colour.
    """
    weak_colors = (blue, orange)
    if color == green:
        result = latex_strong_word
    elif any(color == candidate for candidate in weak_colors):
        result = latex_weak_word
    else:
        result = None
    return result


def _font_rgb(cell):
    """Return the ``rgb`` value of ``cell``'s font colour, or None if the font has no colour."""
    # openpyxl allows ``Font.color`` to be None (style without a <color> element);
    # reading ``cell.font.color.rgb`` directly raises AttributeError for such cells.
    font_color = cell.font.color
    return None if font_color is None else font_color.rgb


def get_annotated_exceltable(ws, min_row=2, max_row=25, max_col=11):
    """Read cluster topics, topic words and their colour annotations from a worksheet.

    Each row in ``min_row..max_row`` describes one cluster. Column A holds the
    topic, and columns B..``max_col`` hold the cluster's words. Every cell's font
    colour is mapped to a LaTeX macro with ``color2latex_color``. Reading stops
    at the first row whose topic cell is empty.

    Args:
        ws: openpyxl worksheet to read.
        min_row: first row to read (row 1 presumably holds headers — confirm).
        max_row: last row to read.
        max_col: last column to read; column A is always the topic.

    Returns:
        ``(words_list, colors, topics)`` where ``words_list[i]`` holds the cell
        values of cluster ``i`` (None for empty cells), ``colors[i]`` holds the
        parallel LaTeX macros (or None), and ``topics[i]`` is a
        ``(topic_value, latex_macro_or_None)`` tuple.
    """
    words_list = []
    colors = []
    topics = []
    for row in ws.iter_rows(min_row=min_row, max_row=max_row, min_col=1, max_col=max_col):
        topic_cell, *word_cells = row
        if topic_cell.value is None:
            break
        topics.append((topic_cell.value, color2latex_color(_font_rgb(topic_cell))))
        words_list.append([cell.value for cell in word_cells])
        colors.append([color2latex_color(_font_rgb(cell)) for cell in word_cells])
    return words_list, colors, topics


def _render_latex(words_list, colors, topics, description, extraction_method, label):
    """Build one LaTeX table and substitute its description, method and label placeholders."""
    latex = get_top_cluster_words_as_latex_table(words_list, colors, topics)
    # NOTE(review): 'DESCRIPTON' (sic) must match the placeholder produced by
    # get_top_cluster_words_as_latex_table (not visible here) — confirm before fixing the spelling.
    return (latex.replace('DESCRIPTON', description)
            .replace('EXTRACTIONMETHOD', extraction_method)
            .replace('LABEL', label))


def _copy_tfidf_colors(words_list, colors, tfidf_words_list, tfidf_colors):
    """Give each word the colour it has in the TF-IDF row of the same cluster (in place).

    Only words that also appear in that TF-IDF row are recoloured. Clusters
    without a TF-IDF counterpart are left unchanged; previously they raised IndexError.
    """
    for words, row_colors, tfidf_words, tfidf_row_colors in zip(
            words_list, colors, tfidf_words_list, tfidf_colors):
        first_color = {}
        for word, color in zip(tfidf_words, tfidf_row_colors):
            first_color.setdefault(word, color)  # first occurrence wins, like list.index()
        for k, word in enumerate(words):
            if word in first_color:
                row_colors[k] = first_color[word]


def _translate_words(words_list):
    """Replace words found (lower-cased) in ``custom_translation``; keep all others as-is."""
    translated = []
    for words in words_list:
        row = []
        for word in words:
            # Empty cells are None (and cells may hold numbers); only strings can be looked up.
            key = word.lower() if isinstance(word, str) else None
            row.append(custom_translation[key] if key is not None and key in custom_translation else word)
        translated.append(row)
    return translated


def main():
    """Print LaTeX tables of the annotated topic words for each extraction method.

    Reads the 'TFIDF-based' and 'svm-based' sheets of ``path2table``. With
    ``anotate_svm_as_tfidf`` enabled, the SVM table reuses the TF-IDF topics,
    plus the TF-IDF colour of every word shared by both sheets in the same
    cluster. If ``translate_to_eng`` is set, an English table is printed after
    each German one.
    """
    wb = openpyxl.load_workbook(path2table)
    cluster_method = 'HDBSCAN'
    anotate_svm_as_tfidf = True
    # (extraction method, worksheet name). TF-IDF must come first: the SVM sheet borrows its annotation.
    sheets = [('tf-idf', 'TFIDF-based'), ('SVM', 'svm-based')]

    if translate_to_eng:
        translator = Translator()

    words_list_tfidf, colors_tfidf, topics_tfidf = [], [], []
    for extraction_method, sheet_name in sheets:
        words_list, colors, topics = get_annotated_exceltable(wb[sheet_name])
        if anotate_svm_as_tfidf:
            if extraction_method == 'tf-idf':
                words_list_tfidf, colors_tfidf, topics_tfidf = words_list, colors, topics
            else:
                topics = topics_tfidf
                _copy_tfidf_colors(words_list, colors, words_list_tfidf, colors_tfidf)

        # German topic words
        label = f'table_cluster_topics_{cluster_method}_{extraction_method}_ger'
        description = f'Annotated German topic words, extracted from the {cluster_method} cluster-set, ' \
                      f'using the {extraction_method} based extraction method.'
        print(_render_latex(words_list, colors, topics, description, extraction_method, label))

        # English topic words
        if translate_to_eng:
            description = f'Annotated topic words (translated from German to English), ' \
                          f'extracted from the {cluster_method} cluster-set, ' \
                          f'using the {extraction_method} based extraction method.'
            label = f'table_cluster_topics_{cluster_method}_{extraction_method}_eng'
            words_eng = _translate_words(words_list)
            topics_eng = [(translator.translate(topic, src='de').text, color) for topic, color in topics]
            print(_render_latex(words_eng, colors, topics_eng, description, extraction_method, label))


if __name__ == '__main__':
    main()