src.core.glove package
Submodules
src.core.glove.dataset module
- class src.core.glove.dataset.GloVeDataset(config_dict)
Bases:
Word2VecDataset
GloVe Dataset
- Parameters:
config_dict (dict) – Config Params Dictionary
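create_dataloader takes three aligned arrays: center ids, context ids and co-occurrence counts. This page does not document which part of the package builds them from the corpus. The NumPy sketch below shows one common approach. It assumes the corpus is already a list of integer token ids and uses a symmetric window in which each pair at distance d adds 1/d, as in the original GloVe paper. The function name build_cooccurrence and the window_size parameter are hypothetical.

```python
from collections import defaultdict

import numpy as np


def build_cooccurrence(token_ids, window_size=2):
    """Hypothetical helper: count weighted (center, context) pairs in a symmetric window.

    Each pair at distance d contributes 1/d (the GloVe paper's weighting).
    Returns three aligned arrays: center ids, context ids, counts.
    """
    counts = defaultdict(float)
    n = len(token_ids)
    for i, ctr in enumerate(token_ids):
        # Visit neighbours on both sides, up to window_size positions away
        for j in range(max(0, i - window_size), min(n, i + window_size + 1)):
            if j == i:
                continue
            counts[(ctr, token_ids[j])] += 1.0 / abs(i - j)
    pairs = sorted(counts)  # deterministic order: by center id, then context id
    X_ctr = np.array([p[0] for p in pairs], dtype=np.int64)
    X_cxt = np.array([p[1] for p in pairs], dtype=np.int64)
    X_count = np.array([counts[p] for p in pairs], dtype=np.float32)
    return X_ctr, X_cxt, X_count
```

The three returned arrays correspond to the X_ctr, X_cxt and X_count arguments of create_dataloader.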
- src.core.glove.dataset.create_dataloader(X_ctr, X_cxt, X_count, val_split=0.2, batch_size=32, seed=2024)
Creates the train and validation DataLoaders
- Parameters:
X_ctr (numpy.ndarray (num_samples, )) – Center words
X_cxt (numpy.ndarray (num_samples, )) – Context words
X_count (numpy.ndarray (num_samples, )) – Co-occurrence matrix elements
val_split (float, optional) – Fraction of samples held out for validation, defaults to 0.2
batch_size (int, optional) – Batch size, defaults to 32
seed (int, optional) – Random seed, defaults to 2024
- Returns:
Train and validation DataLoaders
- Return type:
tuple (torch.utils.data.DataLoader, torch.utils.data.DataLoader)
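The following is a rough NumPy stand-in for what create_dataloader is documented to do: a seeded train/validation split followed by mini-batching. The real function returns torch.utils.data.DataLoader objects. Here, batches are plain tuples of arrays. Splitting with a single seeded permutation, rather than something like torch.utils.data.random_split, is an assumption. The name split_and_batch is hypothetical.

```python
import numpy as np


def split_and_batch(X_ctr, X_cxt, X_count, val_split=0.2, batch_size=32, seed=2024):
    """Hypothetical stand-in: seeded train/val split, then fixed-size mini-batches."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X_ctr))
    n_val = int(len(idx) * val_split)
    val_idx, train_idx = idx[:n_val], idx[n_val:]

    def batches(indices, shuffle):
        # Training batches are reshuffled; validation order is left as is
        if shuffle:
            indices = rng.permutation(indices)
        for start in range(0, len(indices), batch_size):
            b = indices[start:start + batch_size]
            yield X_ctr[b], X_cxt[b], X_count[b]

    return list(batches(train_idx, True)), list(batches(val_idx, False))
```

With 100 samples and the default arguments, this produces 80 training samples in batches of 32, 32 and 16, plus one validation batch of 20.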
src.core.glove.glove module
- class src.core.glove.glove.GloVe(config_dict)
Bases:
object
A class to run GloVe data preprocessing, training and inference
- Parameters:
config_dict (dict) – Config Params Dictionary
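This class lists inference among its responsibilities, but its methods are not documented here. A typical GloVe inference step is a nearest-neighbour lookup by cosine similarity. The sketch below assumes access to the learned center and context embedding matrices. Following the GloVe paper, it sums the two matrices to form the final word vectors. most_similar is a hypothetical helper, not part of this package's API.

```python
import numpy as np


def most_similar(word_id, ctr_weights, cxt_weights, top_k=3):
    """Hypothetical helper: top_k token ids closest to word_id by cosine similarity."""
    # Final word vectors = center + context embeddings (as suggested in the GloVe paper)
    emb = ctr_weights + cxt_weights
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sims = emb @ emb[word_id]
    sims[word_id] = -np.inf  # never return the query word itself
    return np.argsort(-sims)[:top_k].tolist()
```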
src.core.glove.model module
- class src.core.glove.model.GloVeModel(config_dict)
Bases:
Module
GloVe Model
- Parameters:
config_dict (dict) – Config Params Dictionary
- forward(ctr, cxt)
Forward propagation
- Parameters:
ctr (torch.Tensor (batch_size,)) – Center tokens
cxt (torch.Tensor (batch_size,)) – Context tokens
- Returns:
Center and context embeddings, followed by center and context biases
- Return type:
tuple (torch.Tensor (batch_size, embed_dim), torch.Tensor (batch_size, embed_dim), torch.Tensor (batch_size, 1), torch.Tensor (batch_size, 1))
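Below is a minimal NumPy sketch of the lookups behind forward, matching the documented shapes. Each index array of shape (batch_size,) selects rows from a center table, a context table and two bias tables. In the PyTorch module these are most likely torch.nn.Embedding layers. The class name, attribute names and initialization are assumptions.

```python
import numpy as np


class GloVeForwardSketch:
    """Hypothetical NumPy stand-in for the embedding lookups in a GloVe forward pass."""

    def __init__(self, vocab_size, embed_dim, seed=0):
        rng = np.random.default_rng(seed)
        # Separate tables for center (w) and context (w~) vectors and their biases
        self.ctr_embed = rng.normal(scale=0.1, size=(vocab_size, embed_dim))
        self.cxt_embed = rng.normal(scale=0.1, size=(vocab_size, embed_dim))
        self.ctr_bias = np.zeros((vocab_size, 1))
        self.cxt_bias = np.zeros((vocab_size, 1))

    def forward(self, ctr, cxt):
        # Row lookups: (batch_size,) ids -> (batch_size, embed_dim) and (batch_size, 1)
        return (self.ctr_embed[ctr], self.cxt_embed[cxt],
                self.ctr_bias[ctr], self.cxt_bias[cxt])
```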
- class src.core.glove.model.GloVeTrainer(model, optimizer, config_dict)
Bases:
Module
GloVe Trainer
- Parameters:
model (torch.nn.Module) – GloVe model
optimizer (torch.optim.Optimizer) – Optimizer
config_dict (dict) – Config Params Dictionary
- fit(train_loader, val_loader)
Fits the model on the dataset. Runs training and validation steps for the configured number of epochs and saves the best model according to the evaluation metric.
- Parameters:
train_loader (torch.utils.data.DataLoader) – Train data loader
val_loader (torch.utils.data.DataLoader) – Validation data loader
- Returns:
Training History
- Return type:
dict
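The documented behaviour of fit (train, validate, keep the best model) follows a standard epoch loop. This pure-Python sketch moves the per-epoch work into two hypothetical callables. It treats the lowest validation loss as the evaluation metric, which is an assumption, and the history keys are illustrative. Unlike the documented method, it returns the best state alongside the history instead of saving it.

```python
import copy


def fit_sketch(model_state, train_step, val_step, num_epochs):
    """Hypothetical epoch loop that tracks history and the best state by val loss.

    train_step(state) -> (new_state, train_loss); val_step(state) -> val_loss.
    """
    history = {"train_loss": [], "val_loss": []}
    best_loss, best_state = float("inf"), None
    for _ in range(num_epochs):
        model_state, train_loss = train_step(model_state)
        val_loss = val_step(model_state)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        # Keep a copy of the best state so far (a real trainer would checkpoint it)
        if val_loss < best_loss:
            best_loss, best_state = val_loss, copy.deepcopy(model_state)
    return history, best_state
```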
- loss_fn(ctr_embed, cxt_embed, ctr_bias, cxt_bias, count)
GloVe loss
- Parameters:
ctr_embed (torch.Tensor (batch_size, embed_dim)) – Center embedding
cxt_embed (torch.Tensor (batch_size, embed_dim)) – Context embedding
ctr_bias (torch.Tensor (batch_size, 1)) – Center Bias
cxt_bias (torch.Tensor (batch_size, 1)) – Context Bias
count (float) – Co-occurrence matrix element for (center, context)
- Returns:
Loss
- Return type:
torch.Tensor (scalar, dtype torch.float32)
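For reference, the GloVe objective from Pennington et al. (2014) is J = Σ f(X_ij) (w_iᵀ w̃_j + b_i + b̃_j − log X_ij)². The weighting is f(x) = (x / x_max)^α for x < x_max and 1 otherwise, and the paper uses x_max = 100 and α = 0.75. The NumPy sketch below implements that formula. This page does not say whether loss_fn uses the same x_max and α, or whether it sums or averages over the batch, so both are assumptions (this version averages).

```python
import numpy as np


def glove_loss(ctr_embed, cxt_embed, ctr_bias, cxt_bias, count, x_max=100.0, alpha=0.75):
    """Weighted least-squares GloVe loss, averaged over the batch (an assumption).

    Shapes: embeds (batch, dim), biases (batch, 1), count (batch,), counts must be > 0.
    """
    count = np.asarray(count, dtype=np.float64)
    # Weighting f(x): down-weights rare pairs and caps frequent ones at 1
    weight = np.where(count < x_max, (count / x_max) ** alpha, 1.0)
    # Model prediction w_i . w~_j + b_i + b~_j for each (center, context) pair
    pred = np.sum(ctr_embed * cxt_embed, axis=1) + ctr_bias[:, 0] + cxt_bias[:, 0]
    return float(np.mean(weight * (pred - np.log(count)) ** 2))
```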