import os
import time
import torch
import tqdm
import numpy as np
import logging
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
class Word2VecModel(nn.Module):
"""
Word2Vec Model
:param config_dict: Config Params Dictionary
:type config_dict: dict
"""
def __init__(self, config_dict):
super(Word2VecModel, self).__init__()
self.embed_dim = config_dict["model"]["embed_dim"]
self.num_vocab = 1 + config_dict["dataset"]["num_vocab"]
self.cxt_embedding = nn.Embedding(self.num_vocab, self.embed_dim)
self.lbl_embedding = nn.Embedding(self.num_vocab - 1, self.embed_dim)
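        # Two separate tables in the style of word2vec's input/output vectors:
        # cxt_embedding holds context-word vectors (with one extra row, presumably
        # for a padding/unknown index), lbl_embedding holds target-word vectors.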
def forward(self, l_cxt, r_cxt, l_lbl, r_lbl):
"""
        Forward propagation
:param l_cxt: Left context
:type l_cxt: torch.Tensor (batch_size,)
:param r_cxt: Right context
:type r_cxt: torch.Tensor (batch_size,)
:param l_lbl: Left label
:type l_lbl: torch.Tensor (batch_size,)
:param r_lbl: Right label
:type r_lbl: torch.Tensor (batch_size,)
:return: Loss
:rtype: torch.float32
"""
l_cxt_emb = self.compute_cxt_embed(l_cxt)
r_cxt_emb = self.compute_cxt_embed(r_cxt)
        l_lbl_emb = self.lbl_embedding(l_lbl.long() - self.num_vocab)
        r_lbl_emb = self.lbl_embedding(r_lbl.long() - self.num_vocab)
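        # As written, left pairs are scored with logsigmoid(-u.v) and right pairs
        # with logsigmoid(u.v); the negated sum over the batch is returned, i.e.
        # loss = -(sum logsigmoid(-u_l.v_l) + sum logsigmoid(u_r.v_r)).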
l_loss = torch.mul(l_cxt_emb, l_lbl_emb).squeeze()
l_loss = torch.sum(l_loss, dim=1)
l_loss = F.logsigmoid(-1 * l_loss)
r_loss = torch.mul(r_cxt_emb, r_lbl_emb).squeeze()
r_loss = torch.sum(r_loss, dim=1)
r_loss = F.logsigmoid(r_loss)
loss = torch.sum(l_loss) + torch.sum(r_loss)
return -1 * loss
def compute_cxt_embed(self, cxt):
"""
Computes context embedding vector
:param cxt: Context vector
:type cxt: torch.Tensor (batch_size, context_len)
        :return: Context embedding
:rtype: torch.Tensor (batch_size, embed_dim)
"""
        # Mean-pool the context word vectors over the window dimension
        cxt_emb = self.cxt_embedding(cxt.long())
        return torch.mean(cxt_emb, dim=1)
class Word2VecTrainer(nn.Module):
"""
Word2Vec Trainer
:param model: Word2Vec model
:type model: torch.nn.Module
    :param optimizer: Optimizer
    :type optimizer: torch.optim.Optimizer
:param config_dict: Config Params Dictionary
:type config_dict: dict
"""
def __init__(self, model, optimizer, config_dict):
super(Word2VecTrainer, self).__init__()
self.logger = logging.getLogger(__name__)
self.model = model
self.optim = optimizer
self.config_dict = config_dict
def train_one_epoch(self, data_loader, epoch):
"""
Train step
:param data_loader: Train Data Loader
        :type data_loader: torch.utils.data.DataLoader
:param epoch: Epoch number
:type epoch: int
        :return: Train Loss
:rtype: torch.float32
"""
self.model.train()
total_loss, num_instances = 0, 0
left_loader, right_loader = data_loader
self.logger.info(
f"-----------Epoch {epoch}/{self.config_dict['train']['epochs']}-----------"
)
pbar = tqdm.tqdm(
enumerate(left_loader), total=len(left_loader), desc="Training"
)
right_iter = iter(right_loader)
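        # The left and right loaders are consumed in lockstep; the right iterator
        # is restarted whenever it is exhausted before the left one.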
for batch_id, (l_cxt, l_lbl) in pbar:
            try:
                r_cxt, r_lbl = next(right_iter)
            except StopIteration:
                right_iter = iter(right_loader)
                r_cxt, r_lbl = next(right_iter)
l_cxt = l_cxt.to(dtype=torch.long)
r_cxt = r_cxt.to(dtype=torch.long)
l_lbl = l_lbl.to(dtype=torch.long)
r_lbl = r_lbl.to(dtype=torch.long)
loss = self.model(l_cxt, r_cxt, l_lbl, r_lbl)
loss.backward()
self.optim.step()
self.optim.zero_grad()
            # Detach so the running total does not keep the computation graph alive
            total_loss += loss.detach()
num_instances += r_lbl.size(0)
train_loss = total_loss / num_instances
return train_loss
@torch.no_grad()
def val_one_epoch(self, data_loader):
"""
Validation step
:param data_loader: Validation Data Loader
        :type data_loader: torch.utils.data.DataLoader
:return: Validation Loss
:rtype: torch.float32
"""
self.model.eval()
total_loss, num_instances = 0, 0
left_loader, right_loader = data_loader
pbar = tqdm.tqdm(
enumerate(left_loader), total=len(left_loader), desc="Validation"
)
right_iter = iter(right_loader)
for batch_id, (l_cxt, l_lbl) in pbar:
            try:
                r_cxt, r_lbl = next(right_iter)
            except StopIteration:
                right_iter = iter(right_loader)
                r_cxt, r_lbl = next(right_iter)
l_cxt = l_cxt.to(dtype=torch.long)
r_cxt = r_cxt.to(dtype=torch.long)
l_lbl = l_lbl.to(dtype=torch.long)
r_lbl = r_lbl.to(dtype=torch.long)
loss = self.model(l_cxt, r_cxt, l_lbl, r_lbl)
total_loss += loss
num_instances += r_lbl.size(0)
val_loss = total_loss / num_instances
return val_loss
def fit(self, train_loader, val_loader):
"""
        Fits the model on the dataset. Runs training and validation steps for the given number of epochs and saves the best model based on validation loss
:param train_loader: Train Data loader
:type train_loader: torch.utils.data.DataLoader
        :param val_loader: Validation Data Loader
:type val_loader: torch.utils.data.DataLoader
:return: Training History
:rtype: dict
"""
logger = logging.getLogger(__name__)
num_epochs = self.config_dict["train"]["epochs"]
output_folder = self.config_dict["paths"]["output_folder"]
best_val_loss = np.inf
history = defaultdict(list)
start = time.time()
for epoch in range(1, num_epochs + 1):
train_loss = self.train_one_epoch(train_loader, epoch)
val_loss = self.val_one_epoch(val_loader)
logger.info(f"Train Loss : {train_loss} - Val Loss : {val_loss}")
history["train_loss"].append(float(train_loss.detach().numpy()))
history["val_loss"].append(float(val_loss.detach().numpy()))
if val_loss <= best_val_loss:
logger.info(
f"Validation Loss improved from {best_val_loss} to {val_loss}"
)
best_val_loss = val_loss
torch.save(
self.model.state_dict(),
os.path.join(output_folder, "best_model.pt"),
)
else:
logger.info(f"Validation loss didn't improve from {best_val_loss}")
end = time.time()
time_taken = end - start
logger.info(
"Training completed in {:.0f}h {:.0f}m {:.0f}s".format(
time_taken // 3600, (time_taken % 3600) // 60, (time_taken % 3600) % 60
)
)
logger.info(f"Best Val RMSE: {best_val_loss}")
return history
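

# A minimal, illustrative usage sketch. Everything below (the config values, the
# toy loaders, the Adam optimizer and its learning rate) is an assumption for
# demonstration only and is not part of the module above.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    logging.basicConfig(level=logging.INFO)

    config_dict = {
        "model": {"embed_dim": 16},
        "dataset": {"num_vocab": 100},
        "train": {"epochs": 2},
        "paths": {"output_folder": "."},
    }
    num_vocab = 1 + config_dict["dataset"]["num_vocab"]
    context_len, n_samples, batch_size = 4, 64, 8

    def toy_loader():
        # Random context windows and labels; labels are offset by num_vocab so
        # that the (label - num_vocab) lookup in forward() stays in range.
        cxt = torch.randint(0, num_vocab, (n_samples, context_len))
        lbl = torch.randint(num_vocab, 2 * num_vocab - 1, (n_samples,))
        return DataLoader(TensorDataset(cxt, lbl), batch_size=batch_size)

    # fit() expects a (left_loader, right_loader) pair for both train and val.
    train_loaders = (toy_loader(), toy_loader())
    val_loaders = (toy_loader(), toy_loader())

    model = Word2VecModel(config_dict)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trainer = Word2VecTrainer(model, optimizer, config_dict)
    history = trainer.fit(train_loaders, val_loaders)
    print(history)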