Skip to content

Commit

Permalink
Fix defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
EvgeniyS99 committed Dec 18, 2023
1 parent 1877af2 commit 03b36cc
Showing 1 changed file with 80 additions and 75 deletions.
155 changes: 80 additions & 75 deletions batchflow/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
""" Config class"""
from pprint import pformat

class Config(dict):
""" Class for configs that can be represented as nested dicts with easy indexing by slashes. """

# Should be defined temporarily for the already pickled configs
class IAddDict(dict):
Expand All @@ -27,6 +29,7 @@ def __init__(self, config=None, **kwargs):
kwargs :
Parameters from kwargs also are parsed and saved to self.config.
"""
# pylint: disable=super-init-not-called
self.config = {}

if config is None:
Expand All @@ -40,7 +43,7 @@ def __init__(self, config=None, **kwargs):

for key, value in kwargs.items():
self.put(key, value)

def parse(self, config):
""" Parses flatten config with slashes.
Expand All @@ -54,14 +57,14 @@ def parse(self, config):
"""
if isinstance(config, Config):
items = config.items(flatten=True) # suppose we have config = {'a': {'b': {'c': 1}}},
# and we try to update config with other = {'a': {'b': {'d': 3}}},
# and expect to see config = {'a': {'b': {'c': 1, 'd': 3}}}
# suppose we have config = {'a': {'b': {'c': 1}}},
# and we try to update config with other = {'a': {'b': {'d': 3}}},
# and expect to see config = {'a': {'b': {'c': 1, 'd': 3}}}
items = config.items(flatten=True)
elif isinstance(config, dict):
items = config.items()
else:
items = dict(config).items()
# items = config.items() if isinstance(config, dict) else dict(config).items()

for key, value in items:
if isinstance(key, str): # if key contains multiple consecutive '/'
Expand Down Expand Up @@ -98,11 +101,12 @@ def put(self, key, value):
if isinstance(value, dict) and last_level in config and isinstance(config[last_level], dict):
config[last_level].update(value)
else:
# for example, we try to set config['a/b/c'] = 3, where config = Config({'a/b': 1}) and don't want error here
if isinstance(config, dict):
config[last_level] = value
# for example, we try to set my_config['a/b/c'] = 3,
# where my_config = Config({'a/b': 1}) and don't want error here
else:
prev_config[level] = {last_level: value}
prev_config[level] = {last_level: value} # pylint: disable=undefined-loop-variable
else:
self.config[key] = value

Expand All @@ -118,15 +122,19 @@ def _get(self, key, default=None, has_default=False, pop=False):
key = [key]
unpack = True

# Provide `default` for each variable in key
if default is not None and len(key) != 1 and len(default) != len(key):
raise ValueError('You should provide `default` for each variable in `key`') # edit
default = [default] if not isinstance(default, list) else default
n = len(key)
if n > 1:
default = [default] * n if not isinstance(default, list) else default
if len(default) != n:
raise ValueError('The length of `default` must be equal to the length of `key`')
else:
default = [default]

ret_vars = []
for ix, variable in enumerate(key):

if isinstance(variable, str) and '/' in variable:

value = self.config
levels = variable.split('/')
values = []
Expand All @@ -137,29 +145,29 @@ def _get(self, key, default=None, has_default=False, pop=False):
if not has_default:
raise KeyError(level)
value = default[ix]
ret_vars.append(value)
values.append(value)
break

elif level not in value:
if level not in value:
if not has_default:
raise KeyError(level)
value = default[ix]
ret_vars.append(value)
values.append(value)
break

else:
value = value[level]
values.append(value)
value = value[level]
values.append(value)

if pop:
del values[-2][level] # delete the last level from the parent dict
# delete the last level from the parent dict
values[-2].pop(level, default[ix]) # pylint: disable=undefined-loop-variable

else:

if variable not in self.config:
if not has_default:
raise KeyError(variable)
value = default[ix]
ret_vars.append(value)

else:
value = method(variable)
Expand All @@ -182,7 +190,8 @@ def get(self, key, default=None):
A key in the dictionary. '/' is used to get value from nested dict.
default : misc
Default value if key doesn't exist in config.
Defaults to None, so that this method never raises a KeyError.
By default None, so this method never raises a KeyError.
If key has several variables, `default` can be a list with defaults for each variable.
Returns
-------
Expand All @@ -192,7 +201,7 @@ def get(self, key, default=None):
value = self._get(key, default=default, has_default=True)

return value

def pop(self, key, **kwargs):
""" Returns the value or tuple of values for key in the config.
If not found, returns a default value.
Expand All @@ -203,7 +212,6 @@ def pop(self, key, **kwargs):
A key in the dictionary. '/' is used to get value from nested dict.
default : misc
Default value if key doesn't exist in config.
Defaults to None, so that this method never raises a KeyError.
Returns
-------
Expand All @@ -216,13 +224,6 @@ def pop(self, key, **kwargs):

return value

def __repr__(self):
return repr(self.config)

def __getitem__(self, key):
value = self._get(key)
return value

def update(self, other=None, **kwargs):
other = other or {}
if not isinstance(other, (dict, tuple, list)):
Expand All @@ -233,6 +234,34 @@ def update(self, other=None, **kwargs):
for key, value in kwargs.items():
self.put(key, value)

def flatten(self, config=None):
""" Transforms nested dict into flatten dict.
Parameters
----------
config : dict, Config or None
If None `self.config` will be parsed else config.
Returns
-------
new_config : dict
"""
config = self.config if config is None else config
new_config = {}
for key, value in config.items():
if isinstance(value, dict) and len(value) > 0:
value = self.flatten(value)
for _key, _value in value.items():
if isinstance(_key, str):
new_config[key + '/' + _key] = _value
else:
new_config[key] = {_key: _value}
else:
new_config[key] = value

return new_config

def keys(self, flatten=False):
""" Returns config keys
Expand Down Expand Up @@ -290,33 +319,13 @@ def items(self, flatten=False):
items = self.config.items()
return items

def flatten(self, config=None):
""" Transforms nested dict into flatten dict.
Parameters
----------
config : dict, Config or None
If None `self.config` will be parsed else config.
Returns
-------
new_config : dict
"""
config = self.config if config is None else config
new_config = {}
for key, value in config.items():
if isinstance(value, dict) and len(value) > 0:
value = self.flatten(value)
for _key, _value in value.items():
if isinstance(_key, str):
new_config[key + '/' + _key] = _value
else:
new_config[key] = {_key: _value}
else:
new_config[key] = value
def copy(self):
""" Create a shallow copy of the instance. """
return Config(self.config.copy())

return new_config
def __getitem__(self, key):
value = self._get(key)
return value

def __setitem__(self, key, value):
if key in self.config:
Expand All @@ -326,10 +335,6 @@ def __setitem__(self, key, value):
def __delitem__(self, key):
self.pop(key)

def copy(self):
""" Create a shallow copy of the instance. """
return Config(self.config.copy())

def __getattr__(self, key):
if key in self.config:
value = self.config.get(key)
Expand All @@ -344,13 +349,6 @@ def __add__(self, other):
return Config([*self.flatten().items(), *other.flatten().items()])
return NotImplemented

def __iter__(self):
return iter(self.config)

def __repr__(self):
lines = ['\n' + 4 * ' ' + line for line in pformat(self.config).split('\n')]
return f"Config({''.join(lines)})"

def __iadd__(self, other):
if isinstance(other, dict):
self.update(other)
Expand All @@ -363,21 +361,20 @@ def __radd__(self, other):
other = Config(other)
return other.__add__(self)

def __len__(self):
return len(self.config)

def __eq__(self, other):
self_ = self.flatten()
other_ = Config(other).flatten() if isinstance(other, dict) else other
return self_.__eq__(other_)

def __getstate__(self):
""" Must be explicitly defined for pickling to work. """
return vars(self)
def __len__(self):
return len(self.config)

def __setstate__(self, state):
""" Must be explicitly defined for pickling to work. """
vars(self).update(state)
def __iter__(self):
return iter(self.config)

def __repr__(self):
lines = ['\n' + 4 * ' ' + line for line in pformat(self.config).split('\n')]
return f"Config({''.join(lines)})"

def __rshift__(self, other):
""" Parameters
Expand All @@ -390,3 +387,11 @@ def __rshift__(self, other):
Pipeline object with an updated config.
"""
return other << self

def __getstate__(self):
""" Must be explicitly defined for pickling to work. """
return vars(self)

def __setstate__(self, state):
""" Must be explicitly defined for pickling to work. """
vars(self).update(state)

0 comments on commit 03b36cc

Please sign in to comment.