Source code for src.core.bert.pretrain

import os
import tqdm
import time
import math
import logging
import numpy as np
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F


class BERTPretrainTrainer(nn.Module):
    """
    BERT Pretrain Model trainer

    :param model: BERT Pretrain model
    :type model: torch.nn.Module
    :param optimizer: Optimizer
    :type optimizer: torch.optim
    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    def __init__(self, model, optimizer, config_dict):
        super(BERTPretrainTrainer, self).__init__()
        self.logger = logging.getLogger(__name__)
        self.model = model
        self.optim = optimizer
        self.config_dict = config_dict
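For orientation, a minimal construction sketch. The BERTPretrain class and all hyperparameter values below are assumptions for illustration; the three config keys shown are exactly the ones this trainer reads (train.epochs, paths.output_folder, dataset.num_vocab), and the model is expected to return (tokens_pred, nsp_pred) for a batch of token ids.

# Illustrative setup -- BERTPretrain and the values below are placeholders;
# only these config keys are actually read by BERTPretrainTrainer.
model = BERTPretrain(num_vocab=30522)  # assumed model class
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
config_dict = {
    "train": {"epochs": 10},
    "paths": {"output_folder": "outputs"},
    "dataset": {"num_vocab": 30522},
}
trainer = BERTPretrainTrainer(model, optimizer, config_dict)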
    def train_one_epoch(self, data_loader, epoch):
        """
        Train step

        :param data_loader: Train Data Loader
        :type data_loader: torch.utils.data.DataLoader
        :param epoch: Epoch number
        :type epoch: int
        :return: Train Losses (Train Loss, Train masked tokens loss, Train NSP loss)
        :rtype: tuple (torch.float32, torch.float32, torch.float32)
        """
        self.model.train()
        total_loss, total_clf_loss, total_nsp_loss, num_instances = 0, 0, 0, 0
        self.logger.info(
            f"-----------Epoch {epoch}/{self.config_dict['train']['epochs']}-----------"
        )

        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Training"
        )
        for batch_id, (tokens, tokens_mask, nsp_labels) in pbar:
            tokens_pred, nsp_pred = self.model(tokens)
            clf_loss, nsp_loss = self.calc_loss(
                tokens_pred, nsp_pred, tokens, nsp_labels, tokens_mask
            )
            loss = clf_loss + nsp_loss
            loss.backward()
            self.optim.step()
            self.optim.zero_grad()

            # Detach before accumulating so each batch's autograd graph
            # can be freed instead of being retained for the whole epoch
            total_loss += loss.detach()
            total_clf_loss += clf_loss.detach()
            total_nsp_loss += nsp_loss.detach()
            num_instances += tokens.size(0)

        train_loss = total_loss / num_instances
        train_clf_loss = total_clf_loss / num_instances
        train_nsp_loss = total_nsp_loss / num_instances

        return train_loss, train_clf_loss, train_nsp_loss
    @torch.no_grad()
    def val_one_epoch(self, data_loader):
        """
        Validation step

        :param data_loader: Validation Data Loader
        :type data_loader: torch.utils.data.DataLoader
        :return: Validation Losses (Validation Loss, Validation masked tokens loss, Validation NSP loss)
        :rtype: tuple (torch.float32, torch.float32, torch.float32)
        """
        self.model.eval()
        total_loss, total_clf_loss, total_nsp_loss, num_instances = 0, 0, 0, 0

        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Validation"
        )
        for batch_id, (tokens, tokens_mask, nsp_labels) in pbar:
            tokens_pred, nsp_pred = self.model(tokens)
            clf_loss, nsp_loss = self.calc_loss(
                tokens_pred, nsp_pred, tokens, nsp_labels, tokens_mask
            )
            loss = clf_loss + nsp_loss

            total_loss += loss
            total_clf_loss += clf_loss
            total_nsp_loss += nsp_loss
            num_instances += tokens.size(0)

        val_loss = total_loss / num_instances
        val_clf_loss = total_clf_loss / num_instances
        val_nsp_loss = total_nsp_loss / num_instances

        return val_loss, val_clf_loss, val_nsp_loss
    @torch.no_grad()
    def predict(self, data_loader):
        """
        Runs inference on input data

        :param data_loader: Inference Data Loader
        :type data_loader: torch.utils.data.DataLoader
        :return: Labels, Predictions (True tokens, NSP Labels, Prediction tokens, NSP Predictions)
        :rtype: tuple (numpy.ndarray [num_samples, seq_len], numpy.ndarray [num_samples,], numpy.ndarray [num_samples, seq_len, num_vocab], numpy.ndarray [num_samples,])
        """
        self.model.eval()
        y_tokens_pred, y_tokens_true = [], []
        y_nsp_pred, y_nsp_true = [], []

        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Inference"
        )
        for batch_id, (tokens, _, nsp_labels) in pbar:
            tokens_pred, nsp_pred = self.model(tokens)

            y_tokens_pred.append(tokens_pred.cpu().detach().numpy())
            y_tokens_true.append(tokens.cpu().detach().numpy())
            y_nsp_true.append(nsp_labels.cpu().detach().numpy())
            y_nsp_pred.append(nsp_pred.cpu().detach().numpy())

        y_tokens_pred = np.concatenate(y_tokens_pred, axis=0)
        y_tokens_true = np.concatenate(y_tokens_true, axis=0)
        y_nsp_pred = np.concatenate(y_nsp_pred, axis=0)
        y_nsp_true = np.concatenate(y_nsp_true, axis=0)

        return y_tokens_true, y_nsp_true, y_tokens_pred, y_nsp_pred
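Downstream of predict, the token predictions are vocabulary logits and the NSP predictions are probabilities (they are consumed by nn.BCELoss in calc_loss below). A hedged sketch of converting them into discrete predictions for evaluation, using the trainer instance from the construction sketch above; the 0.5 threshold and the accuracy metrics are illustrative choices, not part of this module.

# Illustrative post-processing of predict() outputs; the 0.5 NSP
# threshold and the accuracy metrics are assumptions, not module code.
y_tokens_true, y_nsp_true, y_tokens_pred, y_nsp_pred = trainer.predict(test_loader)

token_ids_pred = y_tokens_pred.argmax(axis=-1)              # (num_samples, seq_len)
nsp_labels_pred = (y_nsp_pred.squeeze() >= 0.5).astype(int)

# Note: this measures accuracy over all positions, not only masked ones
token_accuracy = (token_ids_pred == y_tokens_true).mean()
nsp_accuracy = (nsp_labels_pred == y_nsp_true).mean()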
    def fit(self, train_loader, val_loader):
        """
        Fits the model on the dataset. Runs the training and validation
        steps for the given number of epochs and saves the best model
        based on validation loss

        :param train_loader: Train Data Loader
        :type train_loader: torch.utils.data.DataLoader
        :param val_loader: Validation Data Loader
        :type val_loader: torch.utils.data.DataLoader
        :return: Training History
        :rtype: dict
        """
        num_epochs = self.config_dict["train"]["epochs"]
        output_folder = self.config_dict["paths"]["output_folder"]

        best_val_loss = np.inf
        history = defaultdict(list)

        start = time.time()
        for epoch in range(1, num_epochs + 1):
            train_loss, train_clf_loss, train_nsp_loss = self.train_one_epoch(
                train_loader, epoch
            )
            val_loss, val_clf_loss, val_nsp_loss = self.val_one_epoch(val_loader)

            history["train_loss"].append(float(train_loss.detach().numpy()))
            history["val_loss"].append(float(val_loss.detach().numpy()))
            history["train_clf_loss"].append(float(train_clf_loss.detach().numpy()))
            history["val_clf_loss"].append(float(val_clf_loss.detach().numpy()))
            history["train_nsp_loss"].append(float(train_nsp_loss.detach().numpy()))
            history["val_nsp_loss"].append(float(val_nsp_loss.detach().numpy()))

            self.logger.info(f"Train Loss : {train_loss} - Val Loss : {val_loss}")

            if val_loss <= best_val_loss:
                self.logger.info(
                    f"Validation Loss improved from {best_val_loss} to {val_loss}"
                )
                best_val_loss = val_loss
                torch.save(
                    self.model.state_dict(),
                    os.path.join(output_folder, "best_model_pretrain.pt"),
                )
            else:
                self.logger.info(
                    f"Validation Loss didn't improve from {best_val_loss}"
                )

        end = time.time()
        time_taken = end - start
        self.logger.info(
            "Training completed in {:.0f}h {:.0f}m {:.0f}s".format(
                time_taken // 3600, (time_taken % 3600) // 60, (time_taken % 3600) % 60
            )
        )
        self.logger.info(f"Best Val Loss: {best_val_loss}")

        return history
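A short usage sketch of fit. The loader names are assumed to exist, and the matplotlib plotting is one illustrative way to inspect the returned history, not part of this module.

# Illustrative only -- train_loader/val_loader are assumed DataLoaders.
import matplotlib.pyplot as plt

history = trainer.fit(train_loader, val_loader)

# Each history entry holds one float per epoch
plt.plot(history["train_loss"], label="train")
plt.plot(history["val_loss"], label="val")
plt.xlabel("epoch")
plt.ylabel("total loss")
plt.legend()
plt.savefig("pretrain_loss.png")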
    def calc_loss(self, tokens_pred, nsp_pred, tokens_true, nsp_labels, tokens_mask):
        """
        Calculates Training Loss components

        :param tokens_pred: Predicted Tokens
        :type tokens_pred: torch.Tensor (batch_size, seq_len, num_vocab)
        :param nsp_pred: Predicted NSP Label
        :type nsp_pred: torch.Tensor (batch_size,)
        :param tokens_true: True tokens
        :type tokens_true: torch.Tensor (batch_size, seq_len)
        :param nsp_labels: NSP labels
        :type nsp_labels: torch.Tensor (batch_size,)
        :param tokens_mask: Tokens mask
        :type tokens_mask: torch.Tensor (batch_size, seq_len)
        :return: Masked word prediction Cross Entropy loss, NSP classification loss
        :rtype: tuple (torch.float32, torch.float32)
        """
        num_vocab = self.config_dict["dataset"]["num_vocab"]

        # NSP head outputs probabilities, so binary cross entropy applies directly
        nsp_loss_fn = nn.BCELoss()
        nsp_loss = nsp_loss_fn(nsp_pred.squeeze(), nsp_labels)

        # Expand the (batch_size * seq_len,) mask to vocabulary width so it
        # can select rows of the flattened logits and one-hot targets
        mask_flatten = torch.flatten(tokens_mask).bool()
        mask_flatten = torch.stack([mask_flatten] * num_vocab, dim=1)

        y_pred = torch.flatten(tokens_pred, end_dim=1)
        y_true = torch.flatten(tokens_true)
        y_true = F.one_hot(y_true, num_classes=num_vocab).to(torch.float)

        # Keep only masked positions; the loss is computed on those tokens alone
        y_pred_mask = torch.masked_select(y_pred, mask_flatten).reshape(
            -1, y_pred.shape[1]
        )
        y_true_mask = torch.masked_select(y_true, mask_flatten).reshape(
            -1, y_true.shape[1]
        )

        clf_loss_fn = nn.CrossEntropyLoss(reduction="sum")
        clf_loss = clf_loss_fn(y_pred_mask, y_true_mask)

        return clf_loss, nsp_loss
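The mask handling above widens the token mask to vocabulary size so that masked_select works on the one-hot targets. An equivalent and leaner formulation, sketched here as an alternative rather than as this module's method, indexes with the boolean mask directly and feeds integer class targets to nn.CrossEntropyLoss:

# Alternative sketch, not module code: boolean indexing with integer
# targets gives the same masked-token cross entropy without one-hot tensors.
def masked_lm_loss(tokens_pred, tokens_true, tokens_mask):
    mask = tokens_mask.bool()          # (batch_size, seq_len)
    logits = tokens_pred[mask]         # (num_masked, num_vocab)
    targets = tokens_true[mask]        # (num_masked,)
    return nn.CrossEntropyLoss(reduction="sum")(logits, targets)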