import re
import copy as cp
import yaml
import io
from ast import literal_eval
[docs]class CfgNode(dict):
""" Config Node """
__FROZEN_KEY__ = '__frozen__'
[docs] def __init__(self, init_dict:dict=None, copy=True):
"""
Args:
init_dict (dict): a possibly-nested dictionary to initialize the CfgNode.
copy (bool): if this option is set to False, the CfgNode instance will share the value
with the `init_dict`, otherwise the contents of `init_dict` will be deepcopied.
"""
if copy:
init_dict = cp.deepcopy(init_dict)
if init_dict is None:
init_dict = {}
init_dict = CfgNode._create_config_from_dict(init_dict)
super(CfgNode, self).__init__(init_dict)
self.__dict__[CfgNode.__FROZEN_KEY__] = False
@classmethod
def _create_config_from_dict(cls, in_dict):
"""
Create a config tree from a possible nested dictionary.
Note:
this method will share variables with `in_dict` and modify `in_dict`.
Args:
in_dict (dict):
Returns:
dict: a dict whose nested dictionaries are all replaced with CfgNode.
"""
for key, value in in_dict.items():
# check key
if not isinstance(key, str):
raise KeyError("each key of the input dictionary must be `str` type.")
if not cls._check_name_valid(key):
raise KeyError("key in the input dictionary must be a vaild python variable name.")
# process value
in_dict[key] = cls._decode_value(value)
return in_dict
@classmethod
def _check_name_valid(cls, name):
"""
Check whether the name satisfies the named rule of python variable.
Args:
name (str):
Returns:
bool: whether the name is a valid python variable name.
"""
return re.match("[a-zA-Z_][a-zA-Z0-9_]*", name) is not None
@classmethod
def _decode_value(cls, value):
"""
Decode the value into CfgNode, str or other possible type.
Args:
value (Any):
Returns:
Any
"""
if isinstance(value, dict):
value = cls(init_dict=value, copy=False)
elif isinstance(value, str):
try:
value = literal_eval(value)
except(ValueError, SyntaxError):
pass
return value
[docs] def freeze(self, frozen:bool=True):
"""
freeze or unfreeze the CfgNode and all of its children
Args:
frozen (bool): freeze or unfreeze the config
"""
self.__dict__[CfgNode.__FROZEN_KEY__] = frozen
for value in self.values():
if isinstance(value, CfgNode):
value.freeze(frozen)
[docs] def is_frozen(self):
"""
get the state of the config.
Returns:
bool: whether the config tree is frozen.
"""
return self.__dict__[CfgNode.__FROZEN_KEY__]
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name, value):
if self.is_frozen():
raise AttributeError("Attempted to set {} to {}, but the CfgNode is frozen.".format(name, value))
self[name] = value
def __delattr__(self, name):
if name in self:
del self[name]
else:
raise AttributeError(name)
def __repr__(self):
return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
[docs] def copy(self):
"""
deepcopy this CfgNode
Returns:
CfgNode:
"""
return cp.deepcopy(self)
[docs] def merge(self, cfg):
"""
merge another CfgNode into this CfgNode, the another CfgNode will override this CfgNode.
Args:
cfg (CfgNode):
"""
if self.is_frozen():
raise Exception("Attempted to merge another CfgNode to this CfgNode, but this CfgNode is frozen.")
cfg = cfg.copy()
for key, value in cfg.items():
if key not in self:
self[key] = value
elif self[key] != value:
if isinstance(self[key], CfgNode) and isinstance(value, CfgNode):
self[key].merge(value)
else:
self[key] = value
[docs] def save(self, save_path, encoding='utf-8'):
"""
save the CfgNode into a yaml file
Args:
save_path:
"""
with open(save_path, 'w', encoding=encoding) as f:
CfgNode.dump(self, stream=f, encoding=encoding)
@classmethod
def _load_cfg_from_yaml_str(cls, yaml_str):
"""
load a CfgNode from a string of yaml format
Args:
yaml_str (str): a string of yaml format
Returns:
CfgNode:
"""
cfg_dict = yaml.load(yaml_str, Loader=yaml.UnsafeLoader)
return cls(cfg_dict)
@classmethod
def _load_cfg_from_yaml_file(cls, yaml_file):
"""
load a CfgNode from a yaml file object
Args:
yaml_file (io.IOBase):
Returns:
CfgNode:
"""
return cls._load_cfg_from_yaml_str(yaml_file.read())
[docs] @classmethod
def open(cls, file, encoding='utf-8'):
"""
load a CfgNode from file.
Args:
file (io.IOBase or str): file object or path to the yaml file.
encoding (str):
Returns:
CfgNode:
"""
if isinstance(file, str):
with open(file, 'r', encoding=encoding) as f:
return cls._load_cfg_from_yaml_file(f)
elif isinstance(file, io.IOBase):
return cls._load_cfg_from_yaml_file(file)
else:
raise TypeError("Expect a file object or str object, but got a {}".format(type(file)))
[docs] @classmethod
def load(cls, yaml_str:str):
"""
load a CfgNode from a string of yaml format
Args:
yaml_str (str):
Returns:
CfgNode:
"""
return cls._load_cfg_from_yaml_str(yaml_str)
@classmethod
def _convert_cfg_to_dict(cls, cfg):
"""
convert a CfgNode to pure dict
Args:
cfg (CfgNode):
Returns:
dict:
"""
out_dict = {}
for key, value in cfg.items():
if isinstance(value, cls):
out_dict[key] = cls._convert_cfg_to_dict(value)
else:
out_dict[key] = value
return out_dict
[docs] @classmethod
def dump(cls, cfg, stream=None, encoding=None, **kwargs):
"""
dump CfgNode into yaml str or yaml file
Note:
if `stream` option is set to non-None object, the CfgNode will be dumpped into stream and return None,
if `stream` option is not given or set to None, return a string instead.
Args:
cfg (CfgNode):
stream (io.IOBase or None): if set to a file object, the CfgNode will be dumpped into stream
and return None, if set to None, return a string instead.
encoding (str or None):
**kwargs: options of the yaml dumper. \n
Some useful options: ["allow_unicode", "line_break", "explicit_start",
"explicit_end", "version", "tags"]. \n
See more details at https://github.com/yaml/pyyaml/blob/2f463cf5b0e98a52bc20e348d1e69761bf263b86/lib3/yaml/__init__.py#L252
Returns:
None or str:
"""
cfg_dict = CfgNode._convert_cfg_to_dict(cfg)
return yaml.dump(cfg_dict, stream=stream, encoding=encoding, **kwargs)
[docs] def dict(self):
"""
convert to a dict
Returns:
dict:
"""
return CfgNode._convert_cfg_to_dict(self)
[docs] def __str__(self):
"""
Returns:
str: a str of dict format
"""
return str(self.dict())