import json
import logging
import os

import numpy as np
import pandas as pd
import scipy.special
import torch
from .dataset_pretrain import create_dataloader_pretrain, PreprocessBERTPretrain
from .dataset_finetune import create_data_loader_finetune, PreprocessBERTFinetune
from .model import BERTPretrainModel, BERTFinetuneModel
from .pretrain import BERTPretrainTrainer
from .finetune import BERTFinetuneTrainer
from plot_utils import plot_embed, plot_history
class BERT:
"""
A class to run BERT data preprocessing, training and inference
:param config_dict: Config Params Dictionary
:type config_dict: dict
"""
def __init__(self, config_dict):
self.logger = logging.getLogger(__name__)
self.config_dict = config_dict
def run(self):
"""
Runs BERT pretrain and finetune stages and saves output
"""
self.trainer_pretrain, self.pretrain_history = self.run_pretrain()
self.model_pretrain = self.trainer_pretrain.model
self.trainer_finetune, self.finetune_history = self.run_finetune()
self.model_finetune = self.trainer_finetune.model
self.save_output()
def run_pretrain(self):
"""
Pretraining stage of BERT
:return: BERT Pretrain Trainer and Training History
:rtype: tuple (torch.nn.Module, dict)
"""
val_split = self.config_dict["dataset"]["val_split"]
test_split = self.config_dict["dataset"]["test_split"]
batch_size = self.config_dict["dataset"]["batch_size"]
seed = self.config_dict["dataset"]["seed"]
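        # Preprocess the pretraining corpus. The WordPiece tokenizer and the
        # word2id vocab built here are kept on the instance because the
        # finetuning stage reuses them, keeping token ids consistent across
        # both stages.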
self.bert_pretrain_ds = PreprocessBERTPretrain(self.config_dict)
text_tokens, nsp_labels = self.bert_pretrain_ds.get_data()
self.word2id = self.bert_pretrain_ds.word2id
self.wordpiece = self.bert_pretrain_ds.wordpiece
train_loader, val_loader, self.test_loader_pretrain = (
create_dataloader_pretrain(
text_tokens,
nsp_labels,
self.config_dict,
self.word2id,
val_split,
test_split,
batch_size,
seed,
)
)
model = BERTPretrainModel(self.config_dict)
lr = self.config_dict["train"]["lr"]
optim = torch.optim.Adam(model.parameters(), lr=lr)
trainer = BERTPretrainTrainer(model, optim, self.config_dict)
self.logger.info(f"-----------BERT Pretraining-----------")
history = trainer.fit(train_loader, val_loader)
return trainer, history
def run_finetune(self):
"""
Finetuning stage of BERT
:return: BERT Fientune Trainer and Training History
:rtype: tuple (torch.nn.Module, dict)
"""
val_split = self.config_dict["finetune"]["dataset"]["val_split"]
test_split = self.config_dict["finetune"]["dataset"]["test_split"]
batch_size = self.config_dict["finetune"]["dataset"]["batch_size"]
seed = self.config_dict["finetune"]["dataset"]["seed"]
self.bert_finetune_ds = PreprocessBERTFinetune(
self.config_dict, self.wordpiece, self.word2id
)
tokens, start_ids, end_ids, topics = self.bert_finetune_ds.get_data()
train_loader, val_loader, self.test_loader = create_data_loader_finetune(
tokens,
start_ids,
end_ids,
topics,
val_split=val_split,
test_split=test_split,
batch_size=batch_size,
seed=seed,
)
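        # Build the finetune model, then warm-start it with the encoder and
        # embedding weights learned during pretraining (see
        # load_pretrain_weights below).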
model = BERTFinetuneModel(self.config_dict)
model = self.load_pretrain_weights(self.model_pretrain, model)
lr = self.config_dict["finetune"]["train"]["lr"]
optim = torch.optim.Adam(model.parameters(), lr=lr)
trainer = BERTFinetuneTrainer(model, optim, self.config_dict)
self.logger.info(f"-----------BERT Finetuning-----------")
history = trainer.fit(train_loader, val_loader)
return trainer, history
def run_infer_finetune(self):
"""
Runs inference using Finetuned BERT
:return: True and Predicted start, end ids
:rtype: tuple (numpy.ndarray [num_samples,], numpy.ndarray [num_samples,], numpy.ndarray [num_samples,], numpy.ndarray [num_samples,])
"""
start_ids, end_ids, enc_outputs = self.trainer_finetune.predict(
self.test_loader
)
num_samples = len(start_ids)
seq_len = self.config_dict["dataset"]["seq_len"]
start = self.model_finetune.start.cpu().detach().numpy()
end = self.model_finetune.end.cpu().detach().numpy()
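        # `start` and `end` are the learned span-scoring vectors of the
        # finetune model. The slice below assumes the second half of each
        # encoded sequence is the context segment (the question occupying the
        # first half), so spans are only scored over context positions.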
cxt_enc_output = enc_outputs[:, seq_len // 2 :, :]
start_muls = np.matmul(cxt_enc_output, start).squeeze()
start_probs = scipy.special.softmax(start_muls, axis=1)
end_muls = np.matmul(cxt_enc_output, end).squeeze()
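        # Pick the most likely start position, then score every candidate end
        # position jointly with that start (start score + end score), as in
        # standard extractive QA span selection.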
start_ids_pred = np.argmax(start_probs, axis=1)
start_max_muls = start_muls[np.arange(num_samples), start_ids_pred]
sum_muls = start_max_muls.reshape(-1, 1) + end_muls
        def get_max_end_index(arr, start_indices):
            # Mask out positions at or before each row's predicted start so
            # the predicted end always lands strictly after the start.
            num_cols = arr.shape[1]
            mask = np.arange(num_cols)[None, :] <= start_indices[:, None]
            masked_arr = arr.copy()  # don't mutate the caller's array
            masked_arr[mask] = -1e9
            return np.argmax(masked_arr, axis=1)
end_ids_pred = get_max_end_index(sum_muls, start_ids_pred)
return start_ids, end_ids, start_ids_pred, end_ids_pred
def load_pretrain_weights(self, pretrain_model, finetune_model):
"""
Copies pretrain weights to finetune BERT model object
:param pretrain_model: Pretrain BERT model
:type pretrain_model: torch.nn.Module
:param finetune_model: Finetune BERT model
:type finetune_model: torch.nn.Module
:return: Finetune BERT model with Pretrained weights
:rtype: torch.nn.Module
"""
for i, layer in enumerate(pretrain_model.encoder_layers):
finetune_model.encoder_layers[i].load_state_dict(layer.state_dict())
finetune_model.embed_layer.load_state_dict(
pretrain_model.embed_layer.state_dict()
)
return finetune_model
def save_output(self):
"""
Saves Training and Inference results
"""
output_folder = self.config_dict["paths"]["output_folder"]
self.logger.info(f"Saving Outputs {output_folder}")
with open(os.path.join(output_folder, "pretraining_history.json"), "w") as fp:
json.dump(self.pretrain_history, fp)
plot_history(self.pretrain_history, output_folder, "Pretrain History")
with open(os.path.join(output_folder, "finetuning_history.json"), "w") as fp:
json.dump(self.finetune_history, fp)
plot_history(self.finetune_history, output_folder, "Finetune History")
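        # Project and plot the learned token embeddings over the pretraining
        # vocab (the plot filename suggests a t-SNE projection).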
        embeds = self.model_pretrain.embed_layer.weight.detach().cpu().numpy()
vocab = list(self.bert_pretrain_ds.word2id.keys())
plot_embed(embeds, vocab, output_folder, fname="Tokens Embeddings TSNE")
start_ids, end_ids, start_ids_pred, end_ids_pred = self.run_infer_finetune()
df = pd.DataFrame.from_dict(
{
"Start ID": start_ids,
"End ID": end_ids,
"Start ID Pred": start_ids_pred,
"End ID Pred": end_ids_pred,
}
)
df.to_csv(os.path.join(output_folder, "Finetune Predictions.csv"), index=False)
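

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the original pipeline. The
# config keys below are only the ones this module itself reads; the model,
# preprocessing, and trainer classes almost certainly consume additional keys
# (model dimensions, data paths, epochs, etc.), so treat this as a template
# rather than a complete configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config_dict = {
        "paths": {"output_folder": "outputs"},
        "dataset": {
            "val_split": 0.1,
            "test_split": 0.1,
            "batch_size": 32,
            "seed": 42,
            "seq_len": 128,
        },
        "train": {"lr": 1e-4},
        "finetune": {
            "dataset": {
                "val_split": 0.1,
                "test_split": 0.1,
                "batch_size": 32,
                "seed": 42,
            },
            "train": {"lr": 5e-5},
        },
    }

    bert = BERT(config_dict)
    bert.run()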