from transformers import AutoTokenizer, AutoModelForSequenceClassification
import datasets
import torch
import pandas as pd
import numpy as np
from transformers import Trainer
from transformers import TrainingArguments
import os
import sys
import pyarrow as pa

sys.path.append(os.getcwd())
from TextClassification.argsparse_classification_preamble import argsparse_preamble
import TextClassification.classification_metrics as cls_metrics
from database_preparation.utils_labeled_datasets import get_splits_for_cross_val
from database_preparation.utils_labeled_datasets import text_label_files_to_labeled_dataset
from database_preparation.preprocess import print_meta_data
from database_preparation.utils_labeled_datasets import is_text_lst_tokenized, is_text_lst_tfidf_vectorized

# parse the command-line arguments shared by the classification scripts
# (provides at least args.clustered_data and args.path2corpus, used in main)
args = argsparse_preamble()

# all fine-tuned bert classifiers (and their tokenized datasets) are stored below this directory
models_save_path = "./TextClassification/models/bert_models_new"

if not os.path.isdir(models_save_path):
    os.makedirs(models_save_path)


########## functions ##########

def train(train_set, test_set, classifier_save_path, base_bert_model,
          overwrite=False, track_loss_curves=True, epochs=3,
          learning_rate=5e-5, save_model=True, cuda_batch_size=8):
    '''
    Fine-tune a pretrained BERT model for sequence classification.

    Trains base_bert_model on train_set and (optionally) saves the model plus
    the tokenized train/test data at classifier_save_path.

    Parameters:
        train_set, test_set: tokenized datasets with a "label" column;
            test_set may be None (then no evaluation/saving is performed).
        classifier_save_path: directory for the model and trainer artifacts.
        base_bert_model: model name or path passed to from_pretrained.
        overwrite: retrain even if classifier_save_path already exists.
        track_loss_curves: evaluate every logging step and persist
            train/eval metrics (requires a test_set).
        epochs, learning_rate: usual training hyper-parameters.
        save_model: persist model and datasets to disk.
        cuda_batch_size: per-device batch size used when CUDA is available
            (on CPU a fixed batch size of 8 is used).

    Returns:
        The trained model, or None if training was skipped because the
        target directory already exists and overwrite is False.
    '''
    if save_model and os.path.isdir(classifier_save_path):
        if overwrite:
            print(classifier_save_path + " already exists! (overwriting old model!)")
        else:
            print(classifier_save_path + " already exists! (skipping training)")
            return None

    # Loading the base model issues a warning about some pretrained weights not
    # being used and some being randomly initialized: the pretraining head is
    # thrown away and replaced with a randomly initialized classification head,
    # which is then fine-tuned on our task (transfer learning).
    if test_set is None:
        num_labels = len(np.unique(train_set["label"]))
    else:
        num_labels = len(np.unique(train_set["label"] + test_set["label"]))
    model = AutoModelForSequenceClassification.from_pretrained(base_bert_model, num_labels=num_labels)

    # cuda_batch_size is tuned for GPU memory; on CPU fall back to a default of 8
    batch_size = cuda_batch_size if torch.cuda.is_available() else 8

    # Step-wise evaluation is only possible when a test set exists; without one,
    # evaluation_strategy='steps' would fail at Trainer construction/evaluation.
    do_eval = track_loss_curves and test_set is not None
    if do_eval:
        training_args = TrainingArguments(classifier_save_path + "/trainer",
                                          overwrite_output_dir=True,
                                          save_steps=2000,
                                          do_train=True,
                                          do_eval=True,
                                          num_train_epochs=epochs,
                                          evaluation_strategy='steps',
                                          logging_steps=2000,
                                          per_device_train_batch_size=batch_size,
                                          learning_rate=learning_rate
                                          )
    else:
        training_args = TrainingArguments(classifier_save_path + "/trainer",
                                          overwrite_output_dir=True,
                                          save_steps=2000,
                                          num_train_epochs=epochs,
                                          logging_steps=2000,
                                          per_device_train_batch_size=batch_size,
                                          learning_rate=learning_rate
                                          )

    print("training args: " + str(training_args.to_dict()))
    print("device:" + str(training_args.device))
    print("gpus: " + str(training_args.n_gpu))

    trainer = Trainer(
        model=model, args=training_args, train_dataset=train_set, eval_dataset=test_set
    )

    # training
    train_result = trainer.train()

    if do_eval:
        # save train results and trainer state
        metrics = train_result.metrics
        metrics["train_samples"] = len(train_set)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        # compute and save evaluation results
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(test_set)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # save model together with the exact tokenized splits it was trained on
    if save_model and test_set is not None:
        model.save_pretrained(classifier_save_path)

        hf_data_dict = datasets.DatasetDict({"train": train_set, "test": test_set})
        hf_data_dict.save_to_disk(classifier_save_path + "/tokenized_train_test_dataset")
    return model


def evaluate_saved_model(classifier_path, metrics_obj):
    """Reload a fine-tuned classifier and its persisted test split from disk,
    then score the model on that split via evaluate()."""
    reloaded_model = AutoModelForSequenceClassification.from_pretrained(classifier_path, from_tf=False)

    # the tokenized train/test splits were stored next to the model by train()
    stored_splits = datasets.DatasetDict.load_from_disk(classifier_path + "/tokenized_train_test_dataset")

    evaluate(reloaded_model, stored_splits["test"], metrics_obj)


def evaluate(model, test_set, metrics_obj):
    """Predict test_set with model and feed gold vs. predicted labels into metrics_obj."""
    # default trainer parameters are sufficient for prediction only
    eval_args = TrainingArguments("TextClassification/models/temp_trainer", evaluation_strategy="epoch",
                                  overwrite_output_dir=True, )

    predictor = Trainer(
        model=model,
        args=eval_args,
        eval_dataset=test_set
    )

    output = predictor.predict(test_set)

    # logits -> class ids
    predicted_labels = np.argmax(output.predictions, axis=-1)

    metrics_obj.update_metrics(output.label_ids, predicted_labels, True)


def main():
    """Run k-fold cross-validation of BERT-based text classifiers.

    For every configured base model, tokenizes the corpus, trains one
    classifier per fold (all models see identical splits) and accumulates
    the test-set scores in a ClassificationMetrics object that is saved
    to disk at the end of each model's run.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("running with cuda")

    label_sets = [args.clustered_data]

    # training needs raw text: an already-tokenized corpus cannot be re-tokenized here
    if is_text_lst_tokenized(args.path2corpus):
        print(f"Error: {args.path2corpus} is a tokenized corpus. Please pass a not tokenized corpus!")
        exit(1)

    base_bert_models = ["./LanguageModelling/ger-patho-bert-2", "bert-base-german-cased"]

    # experiment switches:
    evaluate_test_set = True
    do_train = True
    test_run = False  # runs k-fold cross validation with only one test run for each model
    folds = 10
    track_loss_curves = False
    epochs = 4
    save_model = False
    cuda_batch_size = 2

    for label_set in label_sets:

        train_test_dataset = text_label_files_to_labeled_dataset(args.clustered_data, path2corpus=args.path2corpus)
        if train_test_dataset is None:
            print("cant do bert training without data!")
            sys.exit()

        # pre-generate the k train/test splits once, so that every base model
        # is trained and tested with exactly the same data
        k_train_test_sets = list(get_splits_for_cross_val(train_test_dataset, folds))

        for base_bert_model in base_bert_models:
            print(base_bert_model + " Evaluation with corpus " + args.path2corpus + " and cluster set " + label_set)
            print("infos about corpus:")
            print_meta_data(args.path2corpus)

            # derive a short model name, depending on where the base model comes from:
            if "./LanguageModelling/" in base_bert_model:
                # custom LM from our languagemodeling folder
                name = base_bert_model.replace("./LanguageModelling/", "")
            elif '/' in base_bert_model and './' not in base_bert_model:
                # hub model id of the form "org/model"
                name = base_bert_model.replace("/", "_")
            elif "gottbert-base" in base_bert_model:
                name = "gottbert-base"
            else:  # plain model name, e.g. german bert
                name = base_bert_model
            classifier_path = models_save_path + "/" + name + "_" + label_set + "_ClassificatonModel"
            metrics = cls_metrics.ClassificationMetrics(name)

            if save_model:
                print("saving model at: ")
                print(classifier_path)

            # the tokenizer only depends on the base model, not on the fold,
            # so load it once outside the cross-validation loop
            tokenizer = AutoTokenizer.from_pretrained(base_bert_model)

            def tokenize_function(examples):
                return tokenizer(examples["text"], padding="max_length", truncation=True)

            # cross validation:
            for i, (train_dataset, test_dataset) in enumerate(k_train_test_sets):
                fold_path = classifier_path + "_" + str(i)

                # convert splits to hf datasets and tokenize:
                train_dataset_ds = datasets.Dataset(pa.Table.from_pandas(pd.DataFrame(train_dataset)))
                test_dataset_ds = datasets.Dataset(pa.Table.from_pandas(pd.DataFrame(test_dataset)))
                train_set = train_dataset_ds.map(tokenize_function, batched=True)
                test_set = test_dataset_ds.map(tokenize_function, batched=True)

                # train
                model = None
                if do_train:
                    print("==> training " + fold_path)
                    model = train(train_set, test_set, fold_path, base_bert_model,
                                  track_loss_curves=track_loss_curves, epochs=epochs,
                                  save_model=save_model, cuda_batch_size=cuda_batch_size)

                # evaluate
                if evaluate_test_set:
                    print("==> predicting test set with " + fold_path)
                    if save_model:
                        evaluate_saved_model(fold_path, metrics)
                    elif model is not None:
                        evaluate(model, test_set, metrics)
                    else:
                        # without training and without a saved model there is nothing
                        # to evaluate (the original code hit a NameError here)
                        print("no model available for evaluation (enable do_train or save_model)")

                if test_run:
                    break

            metrics.save_scores_to_disk(label_set)
            metrics.pickle_object(label_set)


# script entry point
if __name__ == '__main__':
    main()