import json
import torch
import logging
import numpy as np
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from preprocess.utils import preprocess_text
class PreprocessBERTFinetune:
"""
A class to preprocess BERT Finetuning Data
:param config_dict: Config Params Dictionary
:type config_dict: dict
:param wordpiece: WordPiece class
:type wordpiece: src.preprocess.WordPiece
:param word2id: Words to Ids mapping
:type word2id: dict
"""
def __init__(self, config_dict, wordpiece, word2id):
self.logger = logging.getLogger(__name__)
self.input_path = config_dict["finetune"]["paths"]["input_file"]
self.num_topics = config_dict["finetune"]["dataset"]["num_topics"]
self.seq_len = config_dict["dataset"]["seq_len"]
self.operations = config_dict["preprocess"]["operations"]
self.wordpiece = wordpiece
self.word2id = word2id
self.id2word = {v: k for k, v in self.word2id.items()}
def get_data(self):
"""
Converts extracted data into tokens with start, end ids of answers along with the topics of each sample
:return: Finetuning Data (tokens, start ids, end ids, topics)
:rtype: tuple (numpy.ndarray [num_samples, seq_len], numpy.ndarray [num_samples,], numpy.ndarray [num_samples,] , numpy.ndarray [num_samples,])
"""
df = self.extract_data()
df["Context"] = df["Context"].map(lambda x: self.preprocess_text(x))
df["Question"] = df["Question"].map(lambda x: self.preprocess_text(x))
vocab = self.word2id.keys()
half_seq_len = self.seq_len // 2 - 1
ques_tokens, cxt_tokens = [], []
start_ids, end_ids = [], []
topics = []
cxt_lens = [len(text.split()) for text in df["Context"]]
ques_lens = [len(text.split()) for text in df["Question"]]
cxt_corpus = self.wordpiece.transform(df["Context"])
ques_corpus = self.wordpiece.transform(df["Question"])
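        # The code walks `wordpiece.transform`'s output as one list of subword pieces
        # per word, concatenated over all rows; the running `cxt_count` / `ques_count`
        # offsets below map each sample's words to its slice of that flat list.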
cxt_count, ques_count = 0, 0
        for idx, (cxt_len, ques_len) in enumerate(zip(cxt_lens, ques_lens)):
cxt_token = cxt_corpus[cxt_count : cxt_count + cxt_len]
cxt_token = [i for ls in cxt_token for i in ls]
cxt_token = cxt_token[:half_seq_len]
if len(cxt_token) < half_seq_len:
cxt_token = cxt_token + ["<PAD>"] * (half_seq_len - len(cxt_token))
ques_token = ques_corpus[ques_count : ques_count + ques_len]
ques_token = [i for ls in ques_token for i in ls]
ques_token = ques_token[:half_seq_len]
if len(ques_token) < half_seq_len:
ques_token = ques_token + ["<PAD>"] * (half_seq_len - len(ques_token))
            start_word_id = df.iloc[idx]["Answer Start ID"]
            num_word = df.iloc[idx]["Num Words"]
start_id = len(
[
i
for ls in cxt_corpus[cxt_count : cxt_count + start_word_id]
for i in ls
]
)
end_id = len(
[
i
for ls in cxt_corpus[
cxt_count : cxt_count + start_word_id + num_word
]
for i in ls
]
)
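            # Keep only samples whose answer, after WordPiece expansion, still ends
            # inside the truncated context window.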
if end_id <= half_seq_len:
cxt_tokens.append(
[
self.word2id[ch] if ch in vocab else self.word2id["<UNK>"]
for ch in ["<CLS>"] + cxt_token
]
)
ques_tokens.append(
[
self.word2id[ch] if ch in vocab else self.word2id["<UNK>"]
for ch in ["<SEP>"] + ques_token
]
)
start_ids.append(start_id)
end_ids.append(end_id)
                topics.append(df.iloc[idx]["Topic"])
cxt_count += cxt_len
ques_count += ques_len
tokens = np.concatenate([ques_tokens, cxt_tokens], axis=-1)
return tokens, np.array(start_ids), np.array(end_ids), np.array(topics)
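    # Illustrative shapes, assuming seq_len = 128 in the config: half_seq_len is 63,
    # so each row of `tokens` is [<SEP>] + 63 question pieces + [<CLS>] + 63 context
    # pieces = 128 ids; start/end ids are positions within the context pieces only.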
def preprocess_text(self, text):
"""
        Preprocesses a raw input string using the configured operations.

        :param text: Raw input string
        :type text: str
        :return: Preprocessed string
        :rtype: str
        """
        return preprocess_text(text, self.operations)
def batched_ids2tokens(self, tokens):
"""
        Converts sequences of token ids back into sentences, merging WordPiece
        continuation pieces ("##...") into whole words.

        :param tokens: Token ids, 2D array (num_samples, seq_len)
:type tokens: numpy.ndarray
:return: List of decoded sentences
:rtype: list
"""
        vect_func = np.vectorize(lambda x: self.id2word[x])
        tokens = vect_func(tokens)
        sentences = []
        for seq in tokens:
            start_id = 0
            words = []
            for i, ch in enumerate(seq):
                # A piece that does not start with "##" begins a new word, so the
                # pieces accumulated since `start_id` are merged into one word.
                if ch[:2] != "##" and i != 0:
                    pieces = seq[start_id:i]
                    word = "".join(
                        [w if j == 0 else w[2:] for j, w in enumerate(pieces)]
                    )
                    words.append(word)
                    start_id = i
            final_word = "".join(
                [w if j == 0 else w[2:] for j, w in enumerate(seq[start_id:])]
            )
            words.append(final_word)
sentences.append(" ".join(words))
return sentences
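    # Illustrative decoding, with a hypothetical vocabulary: ids mapping back to
    # ["<CLS>", "fine", "##tun", "##ing", "bert"] are merged into "<CLS> finetuning bert".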
def create_data_loader_finetune(
tokens,
start_ids,
end_ids,
topics,
val_split=0.2,
test_split=0.2,
batch_size=32,
seed=2024,
):
"""
Creates PyTorch DataLoaders for Finetuning data
    :param tokens: Input token ids
    :type tokens: numpy.ndarray
    :param start_ids: Answer start ids
    :type start_ids: numpy.ndarray
    :param end_ids: Answer end ids
    :type end_ids: numpy.ndarray
    :param topics: Topic of each data sample
    :type topics: numpy.ndarray
    :param val_split: Validation split, defaults to 0.2
    :type val_split: float, optional
    :param test_split: Test split, defaults to 0.2
    :type test_split: float, optional
    :param batch_size: Batch size, defaults to 32
    :type batch_size: int, optional
    :param seed: Seed, defaults to 2024
    :type seed: int, optional
    :return: Train, Val and Test dataloaders
    :rtype: tuple (torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader)
"""
(
train_tokens,
val_tokens,
train_start_ids,
val_start_ids,
train_end_ids,
val_end_ids,
_,
val_topics,
) = train_test_split(
tokens,
start_ids,
end_ids,
topics,
test_size=val_split + test_split,
random_state=seed,
stratify=topics,
)
(
val_tokens,
test_tokens,
val_start_ids,
test_start_ids,
val_end_ids,
test_end_ids,
val_topics,
_,
) = train_test_split(
val_tokens,
val_start_ids,
val_end_ids,
val_topics,
test_size=test_split / (val_split + test_split),
random_state=seed,
stratify=val_topics,
)
train_ds = TensorDataset(
torch.Tensor(train_tokens),
torch.Tensor(train_start_ids),
torch.Tensor(train_end_ids),
)
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
shuffle=True,
drop_last=True,
num_workers=1,
pin_memory=True,
)
val_ds = TensorDataset(
torch.Tensor(val_tokens), torch.Tensor(val_start_ids), torch.Tensor(val_end_ids)
)
val_loader = DataLoader(
val_ds,
batch_size=batch_size,
shuffle=False,
drop_last=True,
num_workers=1,
pin_memory=True,
)
test_ds = TensorDataset(
torch.Tensor(test_tokens),
torch.Tensor(test_start_ids),
torch.Tensor(test_end_ids),
)
test_loader = DataLoader(
test_ds,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=1,
pin_memory=False,
)
return train_loader, val_loader, test_loader
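

# Illustrative end-to-end usage (a sketch, not part of the module): the config file
# path, the WordPiece import path, and how `word2id` is obtained are assumptions
# about the surrounding project, so adapt them to your setup.
#
#   import json
#   from preprocess.wordpiece import WordPiece  # hypothetical import path
#
#   with open("config.json") as f:              # hypothetical config file
#       config_dict = json.load(f)
#   wordpiece = WordPiece(config_dict)          # assumed constructor signature
#   word2id = wordpiece.word2id                 # assumed attribute
#
#   preprocessor = PreprocessBERTFinetune(config_dict, wordpiece, word2id)
#   tokens, start_ids, end_ids, topics = preprocessor.get_data()
#   train_loader, val_loader, test_loader = create_data_loader_finetune(
#       tokens, start_ids, end_ids, topics, val_split=0.2, test_split=0.2, batch_size=32
#   )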