Source code for src.utils

import os
import sys
import yaml
import torch
import random
import datetime
import logging
import numpy as np

from configs import configDictDType, MainKeysDict

logger = logging.getLogger(__name__)


[docs] def load_config(config_path): """ Loading YAML Config file as a Dictionary :param config_path: Path to Config File :type config_path: str :return: Config Params Dictionary :rtype: dict """ with open(config_path, "r") as stream: try: config_dict = yaml.safe_load(stream) except yaml.YAMLError as exc: logging.error(exc) logging.info("Config File Loaded") return config_dict
[docs] def set_seed(seed): """ Setting seed across Libraries to reproduce results :param seed: Seed value :type seed: int """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False
[docs] class ValidateConfig: """ Validating Config File :param config_dict: Config Params Dictionary :type config_dict: dict :param algo: Name of the Algorithm :type algo: str """ def __init__(self, config_dict, algo): self.config_dict = config_dict self.algo = algo
[docs] def run_verify(self): """ Config Params Keys and Values Verification """ logger.info("Validating Config File") self.verify_main_keys(self.config_dict.keys()) self.verify_values()
[docs] def check_float(self, key, val): """ To check whether given key whose value is float has a valid value or not :param key: Param Key :type key: str :param val: Param value :type val: float """ pass
[docs] def check_int(self, key, val): """ To check whether given key whose value is int has a valid value or not :param key: Param Key :type key: str :param val: Param value :type val: int """ pass
[docs] def check_string(self, key, val): """ To check whether given key whose value is str has a valid value or not :param key: Param Key :type key: str :param val: Param value :type val: str """ pass
[docs] def check_paths(self, key, val): """ To check whether given key whose value is a filepath has a valid value or not :param key: Param Key :type key: str :param val: Param value :type val: str """ pass
[docs] def check_list(self, key, val): """ To check whether given key whose value is list has a valid value or not :param key: Param Key :type key: str :param val: Param value :type val: list """ pass
[docs] def compare_dtype(self, key, val): """ To check whether given key whose value has a valid dtype or not :param key: Param Key :type key: str :param val: Param value :type val: float/int/str/list """ type_abs = configDictDType[key] type_cfg = type(val) if type_abs != type_cfg: logging.error(f"Dtype of {key} should be {type_abs}")
[docs] def verify_values(self): """ Verifying the Datatypes of all the Parameters in Config """ for k, v in self.config_dict.items(): if isinstance(v, dict): for k_, v_ in v.items(): if type(v_) is dict: for k__, v__ in v_.items(): self.compare_dtype(k__, v__) else: self.compare_dtype(k_, v_) else: self.compare_dtype(k, v)
[docs] def verify_main_keys(self, keys): """ Verifying whether Config has all the required keys or not :param keys: Parent Config Parameters :type keys: list """ for key in keys: true_val = MainKeysDict[self.algo][key] if isinstance(true_val, list): cfg_val = list(self.config_dict[key].keys()) true_val.sort() cfg_val.sort() if true_val != cfg_val: logging.error(f"Config Keys for {key} doesn't match Default Config") elif isinstance(true_val, dict): for k in self.config_dict[key].keys(): true_val_k = MainKeysDict[self.algo][key][k] cfg_val_k = list(self.config_dict[key][k].keys()) true_val_k.sort() cfg_val_k.sort() if true_val_k != cfg_val_k: logging.error( f"Config Keys for {k} doesn't match Default Config" ) else: if not isinstance(self.config_dict[key], MainKeysDict[self.algo][key]): logging.error( f"Config Key {key} should be of type {MainKeysDict[self.algo][key]}" )
[docs] def get_logger(log_folder): """ Initializing Log File :param log_folder: Path to folder where Log file is added :type log_folder: str """ os.makedirs(log_folder, exist_ok=True) current_time = datetime.datetime.now() timestamp = current_time.strftime("%Y-%m-%d-%H.%M.%S") logging.basicConfig( level=logging.INFO, filename=os.path.join(log_folder, f"logs-{timestamp}.txt"), filemode="w", format="%(asctime)-8s [%(filename)s:%(lineno)d] %(levelname)s: %(message)s", ) console = logging.StreamHandler(stream=sys.stdout) console.setLevel(logging.INFO) formatter = logging.Formatter( "[%(filename)s:%(lineno)d] %(levelname)s: %(message)s" ) console.setFormatter(formatter) logging.getLogger().addHandler(console)