import torch
import logging
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from core.word2vec.dataset import Word2VecDataset
[docs]
class GloVeDataset(Word2VecDataset):
"""
GloVe Dataset
:param config_dict: Config Params Dictionary
:type config_dict: dict
"""
def __init__(self, config_dict):
self.logger = logging.getLogger(__name__)
self.config_dict = config_dict
self.num_vocab = config_dict["dataset"]["num_vocab"]
self.context = config_dict["dataset"]["context"]
self.id2word = {}
self.word2id = {}
self.preprocess()
self.get_vocab()
[docs]
def get_data(self):
"""
Generates Coccurence Matrix
:return: Center, Context words and Co-occurence matrix
:rtype: tuple (numpy.ndarray [num_samples, ], numpy.ndarray [num_samples, ], numpy.ndarray [num_samples, ])
"""
self.cooccur_mat = np.zeros((1 + self.num_vocab, 1 + self.num_vocab))
for text in self.text_ls:
words = text.split()
if len(words) < 1 + self.context:
continue
for i in range(self.context, len(words) - self.context):
id_i = self.word2id[words[i]]
for j in range(i - self.context, i + self.context):
id_j = self.word2id[words[j]]
dist = np.abs(j - i)
if dist != 0:
self.cooccur_mat[id_i][id_j] += 1 / dist
X_ctr, X_cxt = np.indices((1 + self.num_vocab, 1 + self.num_vocab))
X_ctr, X_cxt = X_ctr.flatten(), X_cxt.flatten()
X_cnt = self.cooccur_mat.flatten()
return X_ctr, X_cxt, X_cnt
[docs]
def create_dataloader(X_ctr, X_cxt, X_count, val_split=0.2, batch_size=32, seed=2024):
"""
Creates Train, Validation DataLoader
:param X_ctr: Center words
:type X_ctr: numpy.ndarray (num_samples, )
:param X_cxt: Context words
:type X_cxt: numpy.ndarray (num_samples, )
:param X_count: Co-occurence matrix elements
:type X_count: numpy.ndarray (num_samples, )
:param val_split: validation split, defaults to 0.2
:type val_split: float, optional
:param batch_size: Batch size, defaults to 32
:type batch_size: int, optional
:param seed: Seed, defaults to 2024
:type seed: int, optional
:return: Train, Val dataloaders
:rtype: tuple (torch.utils.data.DataLoader, torch.utils.data.DataLoader)
"""
train_ctr, val_ctr, train_cxt, val_cxt, train_count, val_count = train_test_split(
X_ctr, X_cxt, X_count, test_size=val_split, random_state=seed
)
train_ds = TensorDataset(
torch.Tensor(train_ctr), torch.Tensor(train_cxt), torch.Tensor(train_count)
)
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=1,
pin_memory=True,
)
val_ds = TensorDataset(
torch.Tensor(val_ctr), torch.Tensor(val_cxt), torch.Tensor(val_count)
)
val_loader = DataLoader(
val_ds,
batch_size=batch_size,
shuffle=False,
drop_last=True,
num_workers=1,
pin_memory=True,
)
return train_loader, val_loader