Source code for src.core.lstm.lstm

import os
import json
import torch
import logging
import pandas as pd

from preprocess.flickr import PreprocessFlickr
from .model import LSTMModel, LSTMTrainer
from .dataset import create_dataloader
from plot_utils import plot_history, plot_embed


[docs] class LSTM: """ A class to run LSTM data preprocessing, training and inference :param config_dict: Config Params Dictionary :type config_dict: dict """ def __init__(self, config_dict): self.logger = logging.getLogger(__name__) self.config_dict = config_dict self.val_split = self.config_dict["dataset"]["val_split"] self.batch_size = self.config_dict["dataset"]["batch_size"] self.seed = self.config_dict["dataset"]["seed"]
[docs] def run(self): """ Runs LSTM Training and saves output """ self.lstm_ds = PreprocessFlickr(self.config_dict) train_paths, train_tokens, transforms = self.lstm_ds.get_data() train_loader, val_loader = create_dataloader( train_paths, train_tokens, transforms, self.val_split, self.batch_size, self.seed, "train", ) self.model = LSTMModel(self.config_dict) lr = self.config_dict["train"]["lr"] optim = torch.optim.Adam(self.model.parameters(), lr=lr) self.trainer = LSTMTrainer(self.model, optim, self.config_dict) self.history = self.trainer.fit(train_loader, val_loader) self.save_output()
[docs] def run_infer(self): """ Runs inference :return: Test image paths, True captions, Predicted captions :rtype: tuple (list, list, list) """ test_paths, test_tokens, transforms = self.lstm_ds.get_test_data() test_loader = create_dataloader( test_paths, test_tokens, transforms, self.val_split, self.batch_size, self.seed, "test", ) test_pred_tokens = self.trainer.predict(test_loader) test_pred_tokens = test_pred_tokens.argmax(axis=-1).astype("int") test_captions = self.lstm_ds.batched_ids2captions(test_tokens) test_pred_captions = self.lstm_ds.batched_ids2captions(test_pred_tokens) return test_paths, test_captions, test_pred_captions
[docs] def save_output(self): """ Saves Training and Inference results """ output_folder = self.config_dict["paths"]["output_folder"] self.logger.info(f"Saving Outputs {output_folder}") with open(os.path.join(output_folder, "training_history.json"), "w") as fp: json.dump(self.history, fp) embeds = self.model.embed_layer.weight.detach().numpy() vocab = list(self.lstm_ds.word2id.keys()) plot_embed(embeds, vocab, output_folder) plot_history(self.history, output_folder) test_paths, test_captions, test_pred_captions = self.run_infer() test_df = pd.DataFrame.from_dict( { "Image": [os.path.basename(i) for i in test_paths], "True Caption": test_captions, "Generated Caption": test_pred_captions, } ) test_df.to_csv(os.path.join(output_folder, "Test Predictions.csv"), index=False)