# Source code for src.core.rnn.rnn

import os
import json
import torch
import logging
from torch.utils.data import TensorDataset, DataLoader

from preprocess.utils import preprocess_text
from .model import RNNModel, RNNTrainer
from .dataset import create_dataloader, RNNDataset
from plot_utils import plot_embed, plot_history, plot_conf_matrix


class RNN:
    """
    A class to run RNN data preprocessing, training and inference

    Typical use is ``RNN(config_dict).run()``. ``run`` builds the dataset,
    trains the model and then calls ``save_output``, which runs inference on
    the test split. ``run_infer`` and ``save_output`` rely on attributes
    (``rnn_ds``, ``model``, ``trainer``, ``history``) that are only set by
    ``run``.

    :param config_dict: Config Params Dictionary. Keys read by this class:
        ``dataset.val_split``, ``dataset.batch_size``, ``dataset.seed``,
        optional ``dataset.num_workers`` (default 1), ``train.lr`` and
        ``paths.output_folder``. ``run`` also writes ``dataset.labels``.
    :type config_dict: dict
    """

    def __init__(self, config_dict):
        self.logger = logging.getLogger(__name__)
        self.config_dict = config_dict

    def run(self):
        """
        Trains the RNN model and saves outputs

        Loads data through ``RNNDataset`` and stores the label encoder's
        categories in ``config_dict["dataset"]["labels"]``. It then builds
        train/validation loaders, trains an ``RNNModel`` with Adam via
        ``RNNTrainer``, and finally calls ``save_output``.
        """
        self.rnn_ds = RNNDataset(self.config_dict)
        X, y = self.rnn_ds.get_data()

        dataset_cfg = self.config_dict["dataset"]
        # Written before RNNModel is constructed. NOTE(review): presumably the
        # model/trainer read the label list from the config; confirm.
        dataset_cfg["labels"] = list(self.rnn_ds.label_encoder.categories_[0])

        train_loader, val_loader = create_dataloader(
            X,
            y,
            dataset_cfg["val_split"],
            dataset_cfg["batch_size"],
            dataset_cfg["seed"],
        )

        self.model = RNNModel(self.config_dict)
        optim = torch.optim.Adam(
            self.model.parameters(), lr=self.config_dict["train"]["lr"]
        )
        self.trainer = RNNTrainer(self.model, optim, self.config_dict)
        self.history = self.trainer.fit(train_loader, val_loader)
        self.save_output()

    def run_infer(self):
        """
        Runs inference on the test split

        :return: ``(y_true, y_pred)``. ``y_true`` is the target returned by
            ``rnn_ds.get_test_data()``. ``y_pred`` is whatever
            ``trainer.predict`` returns, in the same sample order as
            ``y_true``. NOTE(review): both are presumably one-hot encoded,
            since ``save_output`` feeds them to
            ``label_encoder.inverse_transform``; confirm.
        :rtype: tuple
        """
        dataset_cfg = self.config_dict["dataset"]
        test_x, y_true = self.rnn_ds.get_test_data()
        test_ds = TensorDataset(torch.Tensor(test_x))
        test_loader = DataLoader(
            test_ds,
            batch_size=dataset_cfg["batch_size"],
            # Must not shuffle: predictions are paired element-wise with
            # y_true, which stays in dataset order.
            shuffle=False,
            drop_last=False,
            num_workers=dataset_cfg.get("num_workers", 1),
            # Pinned memory only speeds up host->GPU copies; skip it (and
            # PyTorch's warning) when no GPU is available.
            pin_memory=torch.cuda.is_available(),
        )
        y_pred = self.trainer.predict(test_loader)
        return y_true, y_pred

    def save_output(self):
        """
        Saves Training and Inference results

        Writes ``training_history.json`` plus the embedding, training-history
        and confusion-matrix plots into ``config_dict["paths"]["output_folder"]``,
        creating the folder if it does not exist. Calls ``run_infer`` to get
        the test-set predictions for the confusion matrix.
        """
        output_folder = self.config_dict["paths"]["output_folder"]
        self.logger.info("Saving Outputs %s", output_folder)
        os.makedirs(output_folder, exist_ok=True)

        with open(os.path.join(output_folder, "training_history.json"), "w") as fp:
            json.dump(self.history, fp)

        # .cpu() makes this work when the model lives on a GPU (no-op on CPU).
        embeds = self.model.embed_layer.weight.detach().cpu().numpy()
        vocab = list(self.rnn_ds.word2id.keys())
        plot_embed(embeds, vocab, output_folder)
        plot_history(self.history, output_folder)

        encoder = self.rnn_ds.label_encoder
        y_true, y_pred = self.run_infer()
        y_true = encoder.inverse_transform(y_true).squeeze()
        y_pred = encoder.inverse_transform(y_pred).squeeze()
        plot_conf_matrix(y_true, y_pred, encoder.categories_[0], output_folder)