import matplotlib.pyplot as plt
import math
import json
import argparse

def plot_loss_curve(path2json, title='loss'):
    """Plot training and evaluation loss per epoch from a trainer-state JSON file.

    The file must contain a top-level "log_history" list of dicts. Entries
    with a 'loss' key are treated as training points; otherwise entries with
    an 'eval_loss' key are treated as evaluation points. Entries with neither
    key are ignored. Every used entry must also provide an "epoch" value.
    (Presumably the file is a ``trainer_state.json`` written by the Hugging
    Face ``Trainer`` -- confirm against the producer of the file.)

    Args:
        path2json: Path to the JSON file to read.
        title: Title shown above the plot.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If "log_history" or a required "epoch" entry is missing.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # JSON is UTF-8 encoded; do not rely on the platform's locale encoding.
    with open(path2json, encoding="utf-8") as f:
        log_history = json.load(f)["log_history"]

    # Only the values that are actually plotted are collected. Perplexity
    # (exp(loss)) is intentionally not computed here: it was never used and
    # math.exp raises OverflowError for very large (diverged) losses.
    train_epochs, train_loss = [], []
    eval_epochs, eval_loss = [], []

    for log in log_history:
        if 'loss' in log:
            # Deal with training loss.
            train_epochs.append(log["epoch"])
            train_loss.append(log['loss'])
        elif 'eval_loss' in log:
            # Deal with eval loss.
            eval_epochs.append(log["epoch"])
            eval_loss.append(log['eval_loss'])

    # Plot losses.
    plt.figure()
    plt.plot(eval_epochs, eval_loss, label="eval loss")
    plt.plot(train_epochs, train_loss, label="train loss")
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel("loss", fontsize=14)
    plt.title(title, fontsize=16)
    plt.grid(True)
    plt.legend()
    plt.show()

def main():
    """Parse the command line and plot the loss curve of the chosen file.

    The only option is ``--path_to_trainer_state_file``. The plot title is
    that path with every '/trainer_state.json' occurrence removed.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--path_to_trainer_state_file",
        default='./LanguageModelling/ger-patho-bert-w3/trainer_state.json',
    )
    state_file = arg_parser.parse_args().path_to_trainer_state_file

    # Example of how to plot a loss curve: title it with the run directory.
    plot_title = state_file.replace('/trainer_state.json', '')
    plot_loss_curve(state_file, plot_title)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()