import os
import tqdm
import time
import logging
import numpy as np
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from metrics import TextGenerationMetrics
### Encoder LSTM Cell
class EncoderLSTMCell(nn.Module):
    """
    Single LSTM step for the encoder, built from one pair of linear layers per gate

    Each gate combines a projection of the previous hidden state (``w*_dense``)
    with a projection of the current input (``u*_dense``).

    :param h_dim: Size of the hidden and cell state vectors
    :type h_dim: int
    :param inp_x_dim: Size of the input vector
    :type inp_x_dim: int
    """

    def __init__(self, h_dim, inp_x_dim):
        super().__init__()
        # Forget gate
        self.wf_dense = nn.Linear(h_dim, h_dim)
        self.uf_dense = nn.Linear(inp_x_dim, h_dim)
        # Input gate
        self.wi_dense = nn.Linear(h_dim, h_dim)
        self.ui_dense = nn.Linear(inp_x_dim, h_dim)
        # Output gate
        self.wo_dense = nn.Linear(h_dim, h_dim)
        self.uo_dense = nn.Linear(inp_x_dim, h_dim)
        # Candidate cell state
        self.wc_dense = nn.Linear(h_dim, h_dim)
        self.uc_dense = nn.Linear(inp_x_dim, h_dim)

    def forward(self, ht_1, ct_1, xt):
        """
        Advance the cell by one time step

        :param ht_1: Previous hidden state
        :type ht_1: torch.Tensor (batch_size, h_dim)
        :param ct_1: Previous cell state
        :type ct_1: torch.Tensor (batch_size, h_dim)
        :param xt: Input at the current step
        :type xt: torch.Tensor (batch_size, inp_x_dim)
        :return: Updated hidden state, updated cell state
        :rtype: tuple (torch.Tensor [batch_size, h_dim], torch.Tensor [batch_size, h_dim])
        """
        forget_gate = torch.sigmoid(self.wf_dense(ht_1) + self.uf_dense(xt))
        input_gate = torch.sigmoid(self.wi_dense(ht_1) + self.ui_dense(xt))
        output_gate = torch.sigmoid(self.wo_dense(ht_1) + self.uo_dense(xt))
        candidate = torch.tanh(self.wc_dense(ht_1) + self.uc_dense(xt))

        new_cell = forget_gate * ct_1 + input_gate * candidate
        new_hidden = output_gate * torch.tanh(new_cell)
        return new_hidden, new_cell
### Decoder LSTM Cell
class DecoderLSTMCell(nn.Module):
    """
    Single LSTM step for the decoder, with an extra projection of the hidden state

    The gates mirror a standard LSTM (one hidden-state projection ``w*_dense``
    and one input projection ``u*_dense`` per gate); ``xh_dense`` then maps the
    new hidden state to the cell's output vector.

    :param h_dim: Size of the hidden and cell state vectors
    :type h_dim: int
    :param inp_x_dim: Size of the input vector
    :type inp_x_dim: int
    :param out_x_dim: Size of the output vector
    :type out_x_dim: int
    """

    def __init__(self, h_dim, inp_x_dim, out_x_dim):
        super().__init__()
        # Forget gate
        self.wf_dense = nn.Linear(h_dim, h_dim)
        self.uf_dense = nn.Linear(inp_x_dim, h_dim)
        # Input gate
        self.wi_dense = nn.Linear(h_dim, h_dim)
        self.ui_dense = nn.Linear(inp_x_dim, h_dim)
        # Output gate
        self.wo_dense = nn.Linear(h_dim, h_dim)
        self.uo_dense = nn.Linear(inp_x_dim, h_dim)
        # Candidate cell state
        self.wc_dense = nn.Linear(h_dim, h_dim)
        self.uc_dense = nn.Linear(inp_x_dim, h_dim)
        # Hidden state -> output vector
        self.xh_dense = nn.Linear(h_dim, out_x_dim)

    def forward(self, ht_1, ct_1, xt):
        """
        Advance the cell by one time step

        :param ht_1: Previous hidden state
        :type ht_1: torch.Tensor (batch_size, h_dim)
        :param ct_1: Previous cell state
        :type ct_1: torch.Tensor (batch_size, h_dim)
        :param xt: Input at the current step
        :type xt: torch.Tensor (batch_size, inp_x_dim)
        :return: Updated hidden state, updated cell state, output vector
        :rtype: tuple (torch.Tensor [batch_size, h_dim], torch.Tensor [batch_size, h_dim], torch.Tensor [batch_size, out_x_dim])
        """
        forget_gate = torch.sigmoid(self.wf_dense(ht_1) + self.uf_dense(xt))
        input_gate = torch.sigmoid(self.wi_dense(ht_1) + self.ui_dense(xt))
        output_gate = torch.sigmoid(self.wo_dense(ht_1) + self.uo_dense(xt))
        candidate = torch.tanh(self.wc_dense(ht_1) + self.uc_dense(xt))

        new_cell = forget_gate * ct_1 + input_gate * candidate
        new_hidden = output_gate * torch.tanh(new_cell)
        output = self.xh_dense(new_hidden)
        return new_hidden, new_cell, output
### Seq2Seq Encoder
class Seq2SeqEncoder(nn.Module):
    """
    Multi-layer bidirectional LSTM encoder

    Every layer runs one forward-direction and one backward-direction
    ``EncoderLSTMCell`` over the sequence, concatenates the two hidden states
    per position and projects them with a dense layer; the projected sequence
    is the input of the next layer.

    The per-layer cells and dense layers are held in ``nn.ModuleList`` so their
    parameters are registered with the module (visible to ``parameters()``,
    ``state_dict()`` and ``.to(device)``).

    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    def __init__(self, config_dict):
        super().__init__()

        self.seq_len = config_dict["dataset"]["seq_len"]
        self.num_layers = config_dict["model"]["num_layers"]
        self.h_dims = config_dict["model"]["encoder_h_dim"]

        # NOTE: this appends to the list stored in config_dict (in place), so the
        # last encoder output size becomes 2 * h_dims[-1]. Seq2SeqAttention and
        # Seq2SeqDecoder read config_dict["model"]["encoder_x_dim"][-1] and rely
        # on the encoder having been constructed first.
        self.x_dims = config_dict["model"]["encoder_x_dim"]
        self.x_dims.append(2 * self.h_dims[-1])

        num_src_vocab = config_dict["dataset"]["num_src_vocab"]
        embed_dim = config_dict["model"]["embed_dim"]
        self.src_embed_layer = nn.Embedding(num_src_vocab, embed_dim)

        # ModuleList (not a plain list) so the sub-module parameters are
        # registered and therefore optimised, saved and moved across devices.
        self.enc_fwd_lstm_cells = nn.ModuleList()
        self.enc_bwd_lstm_cells = nn.ModuleList()
        self.enc_y_dense_layers = nn.ModuleList()
        for i in range(self.num_layers):
            h_dim = self.h_dims[i]
            inp_x_dim, out_x_dim = self.x_dims[i], self.x_dims[i + 1]
            self.enc_fwd_lstm_cells.append(EncoderLSTMCell(h_dim, inp_x_dim))
            self.enc_bwd_lstm_cells.append(EncoderLSTMCell(h_dim, inp_x_dim))
            self.enc_y_dense_layers.append(nn.Linear(2 * h_dim, out_x_dim))

    def forward(self, src):
        """
        Forward propagation

        :param src: Source tokens
        :type src: torch.Tensor (batch_size, seq_len)
        :return: Per-position encoder outputs, and the concatenated
            forward/backward hidden state at the last position of the last layer
        :rtype: tuple (torch.Tensor [batch_size, seq_len, out_dim], torch.Tensor [batch_size, 2*h_dim])
        """
        # init_hidden() reads the batch size from this attribute
        self.num_samples = src.size(0)
        hts_fwd, cts_fwd = self.init_hidden(), self.init_hidden()
        hts_bwd, cts_bwd = self.init_hidden(), self.init_hidden()

        src_embed = self.src_embed_layer(src.to(torch.long))
        # One (batch_size, dim) tensor per time step
        yts = list(torch.transpose(src_embed, 1, 0))

        for i in range(self.num_layers):
            fwd_cell = self.enc_fwd_lstm_cells[i]
            bwd_cell = self.enc_bwd_lstm_cells[i]
            fwd_states = [None] * self.seq_len
            bwd_states = [None] * self.seq_len

            ht_fwd, ct_fwd = hts_fwd[i], cts_fwd[i]
            ht_bwd, ct_bwd = hts_bwd[i], cts_bwd[i]
            for j in range(self.seq_len):
                ht_fwd, ct_fwd = fwd_cell(ht_fwd, ct_fwd, yts[j])
                fwd_states[j] = ht_fwd

                # The backward direction walks the sequence from the end
                k = self.seq_len - j - 1
                ht_bwd, ct_bwd = bwd_cell(ht_bwd, ct_bwd, yts[k])
                bwd_states[k] = ht_bwd

            # Replace the layer input with the projected bidirectional states
            for w in range(self.seq_len):
                ht = torch.cat([fwd_states[w], bwd_states[w]], dim=-1)
                yts[w] = self.enc_y_dense_layers[i](ht)

        # ht is the concatenated state of the final position of the last layer
        return torch.stack(yts, dim=1), ht

    def init_hidden(self):
        """
        Create randomly initialised (Kaiming uniform) states, one per layer

        The tensors are allocated on the same device as the embedding weights.
        Uses ``self.num_samples``, which ``forward`` sets from the current batch.

        :return: List of state tensors, one (num_samples, h_dim) tensor per layer
        :rtype: list
        """
        device = self.src_embed_layer.weight.device
        return [
            nn.init.kaiming_uniform_(
                torch.empty(self.num_samples, dim, device=device)
            )
            for dim in self.h_dims
        ]
### Seq2Seq Attention
class Seq2SeqAttention(nn.Module):
    """
    Additive attention over the encoder outputs

    Scores are computed as ``attn_weights(si_dense(s) + yi_dense(y))`` for every
    encoder position (no non-linearity is applied between the sum and the final
    projection), normalised with a softmax over the sequence axis.

    Layer sizes are read from ``config_dict`` at construction time:
    ``2 * encoder_h_dim[-1]`` for the decoder state and ``encoder_x_dim[-1]``
    for the encoder output.

    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    def __init__(self, config_dict):
        super().__init__()

        decoder_s_dim = 2 * config_dict["model"]["encoder_h_dim"][-1]
        encoder_y_dim = config_dict["model"]["encoder_x_dim"][-1]

        self.si_dense = nn.Linear(decoder_s_dim, decoder_s_dim)
        self.yi_dense = nn.Linear(encoder_y_dim, decoder_s_dim)
        self.attn_weights = nn.Linear(decoder_s_dim, 1)

    def forward(self, si_1, yts):
        """
        Forward propagation

        :param si_1: Hidden state vector of decoder layer
        :type si_1: torch.Tensor (batch_size, 2*h_dim)
        :param yts: Encoder output vectors
        :type yts: torch.Tensor (batch_size, seq_len, out_dim)
        :return: Attention weights, attention-weighted sum of encoder outputs
        :rtype: tuple (torch.Tensor [batch_size, 1, seq_len], torch.Tensor [batch_size, out_dim])
        """
        # (batch, 1, s_dim) + (batch, seq_len, s_dim) broadcasts over positions
        eij = self.attn_weights(self.si_dense(si_1.unsqueeze(1)) + self.yi_dense(yts))
        # (batch, seq_len, 1) -> (batch, 1, seq_len) so it can left-multiply yts
        eij = eij.squeeze(2).unsqueeze(1)
        weights = F.softmax(eij, dim=-1)

        # Squeeze only the singleton query axis; a bare squeeze() would also
        # drop the batch axis when batch_size == 1.
        si = torch.bmm(weights, yts).squeeze(1)
        return weights, si
### Seq2Seq Decoder
class Seq2SeqDecoder(nn.Module):
    """
    Attention-based LSTM decoder

    At every step the attention layer produces a context vector from the
    encoder outputs; that context is passed to the LSTM cell as its cell state,
    and the cell output is projected to target-vocabulary logits.

    Layer sizes are read from ``config_dict`` at construction time
    (``encoder_x_dim[-1]`` is used as the LSTM input size).

    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    # Token id fed to the decoder at the first step
    SOS_TOKEN_ID = 2

    def __init__(self, config_dict):
        super().__init__()

        self.seq_len = config_dict["dataset"]["seq_len"]
        s_dim = 2 * config_dict["model"]["encoder_h_dim"][-1]
        decoder_x_dim = config_dict["model"]["encoder_x_dim"][-1]
        decoder_y_dim = config_dict["model"]["decoder_y_dim"]
        num_tgt_vocab = config_dict["dataset"]["num_tgt_vocab"]
        embed_dim = config_dict["model"]["embed_dim"]

        self.tgt_embed_layer = nn.Embedding(num_tgt_vocab, embed_dim)
        self.attn_layer = Seq2SeqAttention(config_dict)
        self.dec_lstm_cell = DecoderLSTMCell(s_dim, decoder_x_dim, decoder_y_dim)
        self.tgt_word_classifier = nn.Linear(decoder_y_dim, num_tgt_vocab)

    def forward(self, encoder_yts, encoder_h, tgt=None):
        """
        Forward propagation

        Decoding always starts from the SOS token. With ``tgt`` given (teacher
        forcing) the input of step ``i + 1`` is ``tgt[:, i]``, so the output at
        position ``i`` is produced from ``tgt[:, :i]`` only and can be scored
        against ``tgt[:, i]`` without the label leaking into the input. Without
        ``tgt`` the arg-max prediction of the previous step is fed back.

        :param encoder_yts: Encoder Output vectors
        :type encoder_yts: torch.Tensor (batch_size, seq_len, out_dim)
        :param encoder_h: Encoder final hidden vectors, used as the initial decoder state
        :type encoder_h: torch.Tensor (batch_size, 2*h_dim)
        :param tgt: Target tokens, defaults to None
        :type tgt: torch.Tensor (batch_size, seq_len), optional
        :return: Unnormalised vocabulary scores (logits), per-step attention weights
        :rtype: tuple (torch.Tensor [batch_size, seq_len, num_tgt_vocab], list)
        """
        batch_size = encoder_yts.size(0)

        # Build the SOS ids on the same device as the encoder outputs
        sos_token = torch.full(
            (batch_size,),
            self.SOS_TOKEN_ID,
            dtype=torch.long,
            device=encoder_yts.device,
        )
        yt = self.tgt_embed_layer(sos_token)

        tgt_embeds = None
        if tgt is not None:
            tgt_embeds = self.tgt_embed_layer(tgt.to(torch.long))

        st = encoder_h
        pts = []
        attn_weights = []
        last_step = self.seq_len - 1
        for i in range(self.seq_len):
            # The attention context is used as the cell state of the LSTM step
            weights, ct = self.attn_layer(st, encoder_yts)
            st, _, yt = self.dec_lstm_cell(st, ct, yt)
            logits = self.tgt_word_classifier(yt)

            pts.append(logits[:, None, :])
            attn_weights.append(weights)

            if i >= last_step:
                break

            if tgt_embeds is not None:
                # Teacher forcing: next input is the token just scored
                yt = tgt_embeds[:, i, :]
            else:
                # Greedy decoding: feed back the most likely token
                yt = self.tgt_embed_layer(logits.argmax(dim=1))

        return torch.cat(pts, dim=1), attn_weights
### Seq2Seq Model
class Seq2SeqModel(nn.Module):
    """
    Encoder-decoder wrapper

    Builds a ``Seq2SeqEncoder`` followed by a ``Seq2SeqDecoder`` from the same
    config dictionary (the encoder is constructed first).

    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    def __init__(self, config_dict):
        super().__init__()

        self.encoder = Seq2SeqEncoder(config_dict)
        self.decoder = Seq2SeqDecoder(config_dict)

    def forward(self, src, tgt=None):
        """
        Encode the source tokens, then decode them

        :param src: Source tokens
        :type src: torch.Tensor (batch_size, seq_len)
        :param tgt: Target tokens passed through to the decoder, defaults to None
        :type tgt: torch.Tensor (batch_size, seq_len), optional
        :return: Decoder outputs, Attention weights
        :rtype: tuple (torch.Tensor [batch_size, seq_len, num_tgt_vocab], list)
        """
        enc_outputs, enc_hidden = self.encoder(src)
        decoder_out, attention = self.decoder(enc_outputs, enc_hidden, tgt)
        return decoder_out, attention
### Seq2Seq Trainer
class Seq2SeqTrainer(nn.Module):
    """
    Seq2Seq Trainer

    Runs the train / validation / inference loops for a seq2seq model, tracks
    metrics through ``TextGenerationMetrics`` and checkpoints the best model.

    :param model: Seq2Seq model
    :type model: torch.nn.Module
    :param optimizer: Optimizer
    :type optimizer: torch.optim
    :param config_dict: Config Params Dictionary
    :type config_dict: dict
    """

    def __init__(self, model, optimizer, config_dict):
        super().__init__()

        self.logger = logging.getLogger(__name__)
        self.model = model
        self.optim = optimizer
        self.config_dict = config_dict
        self.metric_cls = TextGenerationMetrics(config_dict)
        self.eval_metric = config_dict["train"]["eval_metric"]

    def _epoch_summary(self, total_loss, num_instances, y_true, y_pred):
        """Return the per-instance loss and the metrics over all collected batches"""
        avg_loss = total_loss / num_instances
        metrics = self.metric_cls.get_metrics(
            np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0)
        )
        return avg_loss, metrics

    def train_one_epoch(self, data_loader, epoch):
        """
        Train step

        :param data_loader: Train Data Loader
        :type data_loader: torch.utils.data.Dataloader
        :param epoch: Epoch number
        :type epoch: int
        :return: Train Loss (summed loss divided by number of sequences), Train Metrics
        :rtype: tuple (torch.float32, dict)
        """
        self.model.train()
        total_loss, num_instances = 0, 0
        y_true, y_pred = [], []

        self.logger.info(
            f"-----------Epoch {epoch}/{self.config_dict['train']['epochs']}-----------"
        )
        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Training"
        )
        for _, (src, tgt) in pbar:
            tgt = tgt.to(torch.long)
            tgt_hat, _ = self.model(src, tgt)
            loss = self.calc_loss(tgt_hat, tgt)

            loss.backward()
            self.optim.step()
            self.optim.zero_grad()

            # Detach so the running total does not keep autograd history alive
            total_loss += loss.detach()
            num_instances += tgt.size(0)
            y_true.append(tgt.cpu().detach().numpy())
            y_pred.append(tgt_hat.cpu().detach().numpy())

        return self._epoch_summary(total_loss, num_instances, y_true, y_pred)

    @torch.no_grad()
    def val_one_epoch(self, data_loader):
        """
        Validation step

        :param data_loader: Validation Data Loader
        :type data_loader: torch.utils.data.Dataloader
        :return: Validation Loss (summed loss divided by number of sequences), Validation Metrics
        :rtype: tuple (torch.float32, dict)
        """
        self.model.eval()
        total_loss, num_instances = 0, 0
        y_true, y_pred = [], []

        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Validation"
        )
        for _, (src, tgt) in pbar:
            tgt = tgt.to(torch.long)
            tgt_hat, _ = self.model(src, tgt)
            loss = self.calc_loss(tgt_hat, tgt)

            total_loss += loss.detach()
            num_instances += tgt.size(0)
            y_true.append(tgt.cpu().detach().numpy())
            y_pred.append(tgt_hat.cpu().detach().numpy())

        return self._epoch_summary(total_loss, num_instances, y_true, y_pred)

    @torch.no_grad()
    def predict(self, data_loader):
        """
        Runs inference to predict a translation of source sentence

        Each batch yielded by the loader is indexed with ``[0]`` to obtain the
        source tokens; the model is called without targets.

        :param data_loader: Infer Data loader
        :type data_loader: torch.utils.data.DataLoader
        :return: Model outputs for every sample
        :rtype: numpy.ndarray (num_samples, seq_len, num_tgt_vocab)
        """
        self.model.eval()
        y_pred = []

        pbar = tqdm.tqdm(
            enumerate(data_loader), total=len(data_loader), desc="Inference"
        )
        for _, src in pbar:
            tgt_hat, _ = self.model(src[0])
            y_pred.append(tgt_hat.cpu().detach().numpy())

        return np.concatenate(y_pred, axis=0)

    def fit(self, train_loader, val_loader):
        """
        Fits the model on dataset. Runs training and Validation steps for given epochs and saves best model based on the evaluation metric

        :param train_loader: Train Data loader
        :type train_loader: torch.utils.data.DataLoader
        :param val_loader: Validation Data Loader
        :type val_loader: torch.utils.data.DataLoader
        :return: Training History
        :rtype: dict
        """
        num_epochs = self.config_dict["train"]["epochs"]
        output_folder = self.config_dict["paths"]["output_folder"]
        # Make sure the checkpoint directory exists before the first save
        os.makedirs(output_folder, exist_ok=True)
        checkpoint_path = os.path.join(output_folder, "best_model.pt")

        # NOTE(review): starting at +inf and comparing with <= treats the eval
        # metric as lower-is-better — confirm this matches the configured metric.
        best_val_metric = np.inf
        history = defaultdict(list)

        start = time.time()
        for epoch in range(1, num_epochs + 1):
            train_loss, train_metrics = self.train_one_epoch(train_loader, epoch)
            val_loss, val_metrics = self.val_one_epoch(val_loader)

            # float() works for tensors on any device
            history["train_loss"].append(float(train_loss))
            history["val_loss"].append(float(val_loss))
            for key in train_metrics:
                history[f"train_{key}"].append(train_metrics[key])
                history[f"val_{key}"].append(val_metrics[key])

            self.logger.info(f"Train Loss : {train_loss} - Val Loss : {val_loss}")
            for key in train_metrics:
                self.logger.info(
                    f"Train {key} : {train_metrics[key]} - Val {key} : {val_metrics[key]}"
                )

            if val_metrics[self.eval_metric] <= best_val_metric:
                self.logger.info(
                    f"Validation {self.eval_metric} score improved from {best_val_metric} to {val_metrics[self.eval_metric]}"
                )
                best_val_metric = val_metrics[self.eval_metric]
                torch.save(self.model.state_dict(), checkpoint_path)
            else:
                self.logger.info(
                    f"Validation {self.eval_metric} score didn't improve from {best_val_metric}"
                )

        end = time.time()
        time_taken = end - start
        self.logger.info(
            "Training completed in {:.0f}h {:.0f}m {:.0f}s".format(
                time_taken // 3600, (time_taken % 3600) // 60, (time_taken % 3600) % 60
            )
        )
        self.logger.info(f"Best Val {self.eval_metric} score: {best_val_metric}")

        return history

    def calc_loss(self, y_pred, y_true):
        """
        Cross-entropy loss for predicted tokens, summed over every token in the batch

        ``y_pred`` holds unnormalised scores (logits); the softmax is applied
        inside the loss. The sum reduction pairs with the callers dividing the
        accumulated loss by the number of sequences.

        :param y_pred: Predicted token scores
        :type y_pred: torch.Tensor (batch_size, seq_len, num_vocab)
        :param y_true: True tokens
        :type y_true: torch.Tensor (batch_size, seq_len)
        :return: Summed cross-entropy loss
        :rtype: torch.float32
        """
        logits = torch.flatten(y_pred, end_dim=1)
        targets = torch.flatten(y_true).to(torch.long)
        # `reduction` is the supported keyword; the legacy boolean `reduce`
        # argument would resolve to mean reduction when given a string.
        return F.cross_entropy(logits, targets, reduction="sum")