import matplotlib.pyplot as plt import math import json import argparse def plot_loss_curve(path2json, title='loss'): with open(path2json) as f: log_history = json.load(f)["log_history"] # Keep track of train and evaluate loss. loss_history = {'train_loss': [], 'eval_loss': [], 'train_steps': [], 'train_epochs': [], 'eval_steps': [], 'eval_epochs': []} # Keep track of train and evaluate perplexity. # This is a metric useful to track for language models. perplexity_history = {'train_perplexity': [], 'eval_perplexity': []} for log in log_history: if 'loss' in log.keys(): # Deal with trianing loss. loss_history['train_loss'].append(log['loss']) perplexity_history['train_perplexity'].append(math.exp(log['loss'])) loss_history['train_epochs'].append(log["epoch"]) loss_history['train_steps'].append(log["step"]) elif 'eval_loss' in log.keys(): # Deal with eval loss. loss_history['eval_loss'].append(log['eval_loss']) perplexity_history['eval_perplexity'].append(math.exp(log['eval_loss'])) loss_history['eval_epochs'].append(log["epoch"]) loss_history['eval_steps'].append(log["step"]) # Plot Losses. plt.figure() plt.plot(loss_history['eval_epochs'], loss_history["eval_loss"], label="eval loss") plt.plot(loss_history['train_epochs'], loss_history["train_loss"], label="train loss") plt.xlabel("epoch", fontsize=14) plt.ylabel("loss", fontsize=14) plt.title(title, fontsize=16) plt.grid(True) plt.legend() plt.show() def main(): parser = argparse.ArgumentParser() parser.add_argument("--path_to_trainer_state_file", default='./LanguageModelling/ger-patho-bert-w3/trainer_state.json') args = parser.parse_args() # example how to plot loss curve: plot_loss_curve(args.path_to_trainer_state_file, args.path_to_trainer_state_file.replace('/trainer_state.json','')) if __name__ == '__main__': main()