import os
import tqdm
import time
import logging
import numpy as np
from collections import defaultdict
import torch
import torch.nn as nn
from core.transformer.model import MultiHeadAttention, PositionalEncoding, FeedForward
from metrics import TextGenerationMetrics
class DecoderLayer(nn.Module):
    """
    GPT decoder layer (pre-LayerNorm: each sub-layer normalizes its input
    before attention / feed-forward and adds a residual connection)

    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    def __init__(self, config_dict):
        super(DecoderLayer, self).__init__()
        dropout = config_dict["model"]["dropout"]
        d_model = config_dict["model"]["d_model"]

        self.mh_masked_self_attn = MultiHeadAttention(config_dict)
        # One LayerNorm per sub-layer (pre-LN), as in GPT-2
        self.layer_norm1 = nn.LayerNorm(d_model)
        self.layer_norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.feed_forward = FeedForward(config_dict)
    def forward(self, tokens):
        """
        Forward propagation

        :param tokens: Input token embeddings
        :type tokens: torch.Tensor (num_samples, seq_len, d_model)
        :return: Decoder output
        :rtype: torch.Tensor (num_samples, seq_len, d_model)
        """
        # Masked (causal) self-attention sub-layer with residual connection
        norm_tokens = self.layer_norm1(tokens)
        masked_attn_output = self.mh_masked_self_attn(
            norm_tokens, norm_tokens, norm_tokens, True
        )
        output = tokens + self.dropout1(masked_attn_output)

        # Position-wise feed-forward sub-layer with residual connection
        ffwd_output = self.feed_forward(self.layer_norm2(output))
        output = output + self.dropout2(ffwd_output)
        return output
class GPTModel(nn.Module):
    """
    GPT Architecture

    :param config_dict: Config Params Dictionary
    :type config_dict: dict
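
    Example (a minimal sketch; ``num_heads`` and ``d_ff`` are *assumed* config
    keys for the attention / feed-forward sub-modules and may not match the
    project's real schema)::

        config = {
            "model": {"d_model": 64, "num_layers": 2, "dropout": 0.1,
                      "num_heads": 4, "d_ff": 128},
            "dataset": {"num_vocab": 100, "seq_len": 16},
            "test": {"predict_tokens": 8},
        }
        model = GPTModel(config)
        logits = model(torch.randint(0, 100, (2, 16)))  # (2, 16, 100)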
"""
def __init__(self, config_dict):
super(GPTModel, self).__init__()
embed_dim = config_dict["model"]["d_model"]
num_vocab = config_dict["dataset"]["num_vocab"]
num_layers = config_dict["model"]["num_layers"]
dropout = config_dict["model"]["dropout"]
d_model = config_dict["model"]["d_model"]
self.num_predict_tokens = config_dict["test"]["predict_tokens"]
self.seq_len = config_dict["dataset"]["seq_len"]
self.embed_layer = nn.Embedding(num_vocab, embed_dim)
self.positional_encoding = PositionalEncoding(config_dict)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model)
self.decoder_layers = [DecoderLayer(config_dict) for _ in range(num_layers)]
self.classifier_layer = nn.Linear(embed_dim, num_vocab)
    def forward(self, tokens):
        """
        Forward propagation

        :param tokens: Input tokens
        :type tokens: torch.Tensor (num_samples, seq_len)
        :return: Logits over the vocabulary for each position
        :rtype: torch.Tensor (num_samples, seq_len, num_vocab)
        """
        embeds = self.dropout(self.positional_encoding(self.embed_layer(tokens)))
        dec_output = embeds
        for layer in self.decoder_layers:
            dec_output = layer(dec_output)
        # Final LayerNorm followed by the projection to vocabulary logits
        output = self.layer_norm(dec_output)
        output = self.classifier_layer(output)
        return output
    def generate(self, tokens):
        """
        Generates tokens greedily, one position at a time, by sliding a
        seq_len-sized window over the running sequence

        :param tokens: Input tokens
        :type tokens: torch.Tensor (num_samples, seq_len + num_pred_tokens)
        :return: Generated tokens
        :rtype: torch.Tensor (num_samples, seq_len + num_pred_tokens)
        """
        tokens_pred = torch.zeros_like(tokens)
        tokens_pred[:, : self.seq_len] = tokens[:, : self.seq_len]
        for i in range(self.num_predict_tokens):
            # Condition on the most recent seq_len tokens (sliding window)
            inputs = tokens_pred[:, i : i + self.seq_len]
            outputs = self.forward(inputs)
            # Greedy decoding: take the highest-scoring token at the last position
            new_token = torch.argmax(outputs[:, -1, :], dim=-1)
            tokens_pred[:, self.seq_len + i] = new_token
        return tokens_pred
class GPTTrainer(nn.Module):
    """
    GPT Trainer

    :param model: GPT model
    :type model: torch.nn.Module
    :param optimizer: Optimizer
    :type optimizer: torch.optim.Optimizer
    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    def __init__(self, model, optimizer, config_dict):
        super(GPTTrainer, self).__init__()
        self.logger = logging.getLogger(__name__)
        self.model = model
        self.optim = optimizer
        self.config_dict = config_dict
        self.metric_cls = TextGenerationMetrics(config_dict)
        self.eval_metric = config_dict["train"]["eval_metric"]
    def train_one_epoch(self, data_loader, epoch):
        """
        Train step

        :param data_loader: Train Data Loader
        :type data_loader: torch.utils.data.DataLoader
        :param epoch: Epoch number
        :type epoch: int
        :return: Train Loss, Train Metrics
        :rtype: tuple (float, dict)
        """
        self.model.train()
        total_loss, num_instances = 0.0, 0
        y_true, y_pred = [], []
        self.logger.info(
            f"-----------Epoch {epoch}/{self.config_dict['train']['epochs']}-----------"
        )
        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Training"
        )
        for batch_id, sent in pbar:
            # Next-token prediction: inputs and targets are shifted by one position
            src, tgt = sent[0][:, :-1], sent[0][:, 1:]
            src = src.to(torch.long)
            tgt = tgt.to(torch.long)

            tgt_hat = self.model(src)
            loss = self.calc_loss(tgt_hat, tgt)
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            # .item() detaches the loss so the graph is not kept across batches
            total_loss += loss.item()
            num_instances += tgt.size(0)
            y_true.append(tgt.cpu().detach().numpy())
            y_pred.append(tgt_hat.cpu().detach().numpy())

        train_loss = total_loss / num_instances
        y_true = np.concatenate(y_true, axis=0)
        y_pred = np.concatenate(y_pred, axis=0)
        train_metrics = self.metric_cls.get_metrics(y_true, y_pred)
        return train_loss, train_metrics
    @torch.no_grad()
    def val_one_epoch(self, data_loader):
        """
        Validation step

        :param data_loader: Validation Data Loader
        :type data_loader: torch.utils.data.DataLoader
        :return: Validation Loss, Validation Metrics
        :rtype: tuple (float, dict)
        """
        self.model.eval()
        total_loss, num_instances = 0.0, 0
        y_true, y_pred = [], []
        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Validation"
        )
        for batch_id, sent in pbar:
            src, tgt = sent[0][:, :-1], sent[0][:, 1:]
            src = src.to(torch.long)
            tgt = tgt.to(torch.long)

            tgt_hat = self.model(src)
            loss = self.calc_loss(tgt_hat, tgt)

            total_loss += loss.item()
            num_instances += tgt.size(0)
            y_true.append(tgt.cpu().numpy())
            y_pred.append(tgt_hat.cpu().numpy())

        val_loss = total_loss / num_instances
        y_true = np.concatenate(y_true, axis=0)
        y_pred = np.concatenate(y_pred, axis=0)
        val_metrics = self.metric_cls.get_metrics(y_true, y_pred)
        return val_loss, val_metrics
    @torch.no_grad()
    def predict(self, data_loader):
        """
        Runs inference to predict a shifted sentence

        :param data_loader: Infer Data loader
        :type data_loader: torch.utils.data.DataLoader
        :return: True tokens, Predicted token logits
        :rtype: tuple (numpy.ndarray [num_samples, seq_len + 1], numpy.ndarray [num_samples, seq_len, num_vocab])
        """
        self.model.eval()
        y_pred, sents = [], []
        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Inference"
        )
        for batch_id, sent in pbar:
            # Feed all but the last token; the model predicts the shifted sequence
            src = sent[0][:, :-1].to(torch.long)
            tgt_hat = self.model(src)
            y_pred.append(tgt_hat.cpu().numpy())
            sents.append(sent[0].cpu().numpy())
        y_pred = np.concatenate(y_pred, axis=0)
        sents = np.concatenate(sents, axis=0)
        return sents, y_pred
    @torch.no_grad()
    def generate(self, data_loader):
        """
        Runs inference to generate new text

        :param data_loader: Infer Data loader
        :type data_loader: torch.utils.data.DataLoader
        :return: True tokens, Generated tokens
        :rtype: tuple (numpy.ndarray [num_samples, seq_len + num_pred_tokens], numpy.ndarray [num_samples, seq_len + num_pred_tokens])
        """
        self.model.eval()
        y_true, y_pred = [], []
        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Generation"
        )
        for batch_id, tokens in pbar:
            tokens = tokens[0].to(torch.long)
            tokens_pred = self.model.generate(tokens)
            y_pred.append(tokens_pred.cpu().numpy())
            y_true.append(tokens.cpu().numpy())
        y_pred = np.concatenate(y_pred, axis=0)
        y_true = np.concatenate(y_true, axis=0)
        return y_true, y_pred
    def fit(self, train_loader, val_loader):
        """
        Fits the model on the dataset. Runs training and validation steps for
        the given number of epochs and saves the best model based on the
        evaluation metric

        :param train_loader: Train Data loader
        :type train_loader: torch.utils.data.DataLoader
        :param val_loader: Validation Data Loader
        :type val_loader: torch.utils.data.DataLoader
        :return: Training History
        :rtype: dict
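
        Example (an illustrative sketch; assumes each batch is a tuple whose
        first element is a LongTensor of token ids of shape
        ``(batch, seq_len + 1)``, and that ``config`` also carries the
        ``train`` / ``paths`` entries read by this class)::

            from torch.utils.data import DataLoader, TensorDataset

            data = torch.randint(0, 100, (32, 17))  # seq_len = 16
            loader = DataLoader(TensorDataset(data), batch_size=8)
            trainer = GPTTrainer(model, torch.optim.Adam(model.parameters()), config)
            history = trainer.fit(loader, loader)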
"""
num_epochs = self.config_dict["train"]["epochs"]
output_folder = self.config_dict["paths"]["output_folder"]
best_val_metric = np.inf
history = defaultdict(list)
start = time.time()
for epoch in range(1, num_epochs + 1):
train_loss, train_metrics = self.train_one_epoch(train_loader, epoch)
val_loss, val_metrics = self.val_one_epoch(val_loader)
history["train_loss"].append(float(train_loss.detach().numpy()))
history["val_loss"].append(float(val_loss.detach().numpy()))
for key in train_metrics.keys():
history[f"train_{key}"].append(train_metrics[key])
history[f"val_{key}"].append(val_metrics[key])
self.logger.info(f"Train Loss : {train_loss} - Val Loss : {val_loss}")
for key in train_metrics.keys():
self.logger.info(
f"Train {key} : {train_metrics[key]} - Val {key} : {val_metrics[key]}"
)
if val_metrics[self.eval_metric] <= best_val_metric:
self.logger.info(
f"Validation {self.eval_metric} score improved from {best_val_metric} to {val_metrics[self.eval_metric]}"
)
best_val_metric = val_metrics[self.eval_metric]
torch.save(
self.model.state_dict(),
os.path.join(output_folder, "best_model.pt"),
)
else:
self.logger.info(
f"Validation {self.eval_metric} score didn't improve from {best_val_metric}"
)
end = time.time()
time_taken = end - start
self.logger.info(
"Training completed in {:.0f}h {:.0f}m {:.0f}s".format(
time_taken // 3600, (time_taken % 3600) // 60, (time_taken % 3600) % 60
)
)
self.logger.info(f"Best Val {self.eval_metric} score: {best_val_metric}")
return history
    def calc_loss(self, y_pred, y_true):
        """
        Cross-entropy loss over the predicted token logits

        :param y_pred: Predicted token logits
        :type y_pred: torch.Tensor (batch_size, seq_len, num_vocab)
        :param y_true: True tokens
        :type y_true: torch.Tensor (batch_size, seq_len)
        :return: Cross-entropy loss (summed over tokens)
        :rtype: torch.Tensor (scalar)
        """
        # Flatten to (batch_size * seq_len, num_vocab) logits and
        # (batch_size * seq_len,) integer targets; CrossEntropyLoss consumes
        # class indices directly, so no one-hot encoding is needed
        y_pred = torch.flatten(y_pred, end_dim=1)
        y_true = torch.flatten(y_true)
        loss_fn = nn.CrossEntropyLoss(reduction="sum")
        return loss_fn(y_pred, y_true)
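

# A minimal smoke-test sketch (illustrative only). The config mirrors the keys
# read in this module; "num_heads" and "d_ff" are *assumed* names for whatever
# the attention / feed-forward sub-modules expect and may need adjusting.
if __name__ == "__main__":
    config = {
        "model": {"d_model": 64, "num_layers": 2, "dropout": 0.1,
                  "num_heads": 4, "d_ff": 128},
        "dataset": {"num_vocab": 100, "seq_len": 16},
        "test": {"predict_tokens": 8},
    }
    model = GPTModel(config)
    tokens = torch.randint(0, 100, (2, 16 + 8))  # (batch, seq_len + predict_tokens)
    logits = model(tokens[:, :16])               # expected shape: (2, 16, 100)
    generated = model.generate(tokens)           # expected shape: (2, 24)
    print(logits.shape, generated.shape)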