fix tracking issue with lazy filled attributes #20783

Open
wants to merge 4 commits into
base: master
21 changes: 21 additions & 0 deletions keras/src/layers/layer.py
Expand Up @@ -224,6 +224,27 @@ def build_wrapper(*args, **kwargs):
            with obj._open_name_scope():
                obj._path = current_path()
                original_build_method(*args, **kwargs)
            # Check for any untracked attrs/vars
            if obj._tracker._has_untracked_attrs:
                if backend.backend() == "tensorflow":
Collaborator

Why is it backend specific?

Contributor Author

Only the TF backend has the _tracked attribute, which lists just the tracked attrs, so there is far less to iterate. To put numbers on it, the other backends have at least 50-55 additional attrs to iterate. Since this is likely a rare case, if it is OK we can drop the TF-specific branch and keep a single check.

Collaborator

I don't quite get it -- can you explain in more detail?

Contributor Author

The layer attribute _tracked is only initialized in TFLayer, so it is only available on the TF backend. It records only the attributes set in the custom layer, which is what we are interested in here, along with a few common attrs: '_inbound_nodes', '_outbound_nodes', '_losses', '_loss_ids', '_losses_override'.

Consider the custom layer below:

class NGLayer(keras.layers.Layer):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.l1 = []
        for i in range(2):
            l2 = []
            self.l1.append(l2)
            for j in range(2):
                l2.append(keras.layers.Dense(1, name=f'dense_{i}_{j}'))
            # print("before appending l2 to l1 in OK way")
            # self.l1.append(l2) #This works
    def call(self, x):
        for l in self.l1:
            for d in l:
                x = d(x)
        return x

The final obj._tracked list will look something like this:

['_inbound_nodes', '_outbound_nodes', '_losses', '_loss_ids', '_losses_override', 'l1'], where 'l1' is the attribute set in the custom layer by the user that was not tracked properly.

For the other backends the _tracked attribute is not available, so we need to iterate over all attributes Keras has initialized on the layer object to find the one that missed tracking.

For the TF backend the untracked attribute can be retrieved quickly from the obj._tracked list. We can also find it via __dict__.keys(), but that list is much longer to iterate.
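
To make the failure mode concrete, here is a minimal sketch in plain Python. It does not use Keras: ToyTracker and ToyTrackedList are hypothetical stand-ins, and the sketch assumes (as the NGLayer example suggests) that a tracked list only inspects a value at the moment it is appended. Strings stand in for the Dense layers.

```python
class ToyTracker:
    """Hypothetical stand-in for a tracker: remembers every leaf it is shown."""

    def __init__(self):
        self.seen = []

    def track(self, value):
        # Lists are walked recursively; anything else counts as a tracked leaf.
        if isinstance(value, list):
            for item in value:
                self.track(item)
        else:
            self.seen.append(value)


class ToyTrackedList(list):
    """Hypothetical tracked list: inspects a value only when it is appended."""

    def __init__(self, tracker):
        super().__init__()
        self.tracker = tracker

    def append(self, value):
        self.tracker.track(value)  # sees only what `value` holds right now
        super().append(value)


def build(fill_first):
    tracker = ToyTracker()
    outer = ToyTrackedList(tracker)
    for i in range(2):
        inner = []
        if not fill_first:
            outer.append(inner)  # appended while still empty (the failing order)
        for j in range(2):
            inner.append(f"dense_{i}_{j}")  # plain list: the tracker never hears of this
        if fill_first:
            outer.append(inner)  # appended after filling (the order that works)
    return tracker.seen


lazy = build(fill_first=False)
eager = build(fill_first=True)
print(lazy)   # []
print(eager)  # ['dense_0_0', 'dense_0_1', 'dense_1_0', 'dense_1_1']
```

In this toy model, appending the inner list before filling it leaves all four entries unseen, while appending it after filling tracks all of them. This matches the commented-out "This works" line in the layer above.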

Collaborator

Ok -- we need to solve this issue in a completely different way. None of this is appropriate -- the reliance on the _tracked attribute, the changes to the build() wrapper, etc.

                    for attr in obj._tracked:
                        if (
                            id(getattr(obj, attr))
                            in obj._tracker._untracked_attrs_ids
                        ):
                            untracked_attr = getattr(obj, attr)
                            for var in untracked_attr:
                                obj._tracker.track(var)
                else:
                    for attr in obj.__dict__.keys():
                        if (
                            id(getattr(obj, attr))
                            in obj._tracker._untracked_attrs_ids
                        ):
                            untracked_attr = getattr(obj, attr)
                            for var in untracked_attr:
                                obj._tracker.track(var)

            # Record build config.
            signature = inspect.signature(original_build_method)
            obj._build_shapes_dict = signature.bind(*args, **kwargs).arguments
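
The two branches in this hunk differ only in which attribute names they scan. The id-matching step they share can be sketched outside Keras as follows. ToyTracker, ToyLayer and track_lazy_attrs are hypothetical names used for illustration, and this is not the approach the maintainers settled on.

```python
class ToyTracker:
    """Hypothetical stand-in: stores tracked leaves and flagged container ids."""

    def __init__(self):
        self.tracked = []
        self.untracked_attr_ids = []

    def track(self, value):
        if isinstance(value, list):
            for item in value:
                self.track(item)
        else:
            self.tracked.append(value)


class ToyLayer:
    """Hypothetical stand-in for a layer object; attributes are set ad hoc."""


def track_lazy_attrs(obj, tracker, attr_names=None):
    """Track the contents of every attribute whose id was flagged earlier.

    `attr_names` narrows the scan (the role `_tracked` plays in the diff);
    when it is None, every instance attribute is scanned instead.
    """
    names = list(vars(obj)) if attr_names is None else list(attr_names)
    flagged = set(tracker.untracked_attr_ids)
    for name in names:
        value = getattr(obj, name)
        if id(value) in flagged:
            for item in value:
                tracker.track(item)


layer = ToyLayer()
tracker = ToyTracker()
layer.other = "unrelated attribute"
layer.l1 = [["dense_0_0", "dense_0_1"], ["dense_1_0", "dense_1_1"]]
tracker.untracked_attr_ids.append(id(layer.l1))  # flag the lazily filled list

track_lazy_attrs(layer, tracker)
print(tracker.tracked)
```

Passing attr_names=["l1"] gives the same result with a shorter scan, which is the only difference between the TensorFlow branch and the generic branch above.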
22 changes: 22 additions & 0 deletions keras/src/layers/layer_test.py
Expand Up @@ -234,6 +234,24 @@ def __init__(self, inner_layer):
            def call(self, x):
                return self.inner_layer(x)

        class LayerWithLazyTracker(layers.Layer):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)

            def build(self, input_shape):
                self.l1 = []
                for i in range(2):
                    l2 = []
                    self.l1.append(l2)
                    for j in range(2):
                        l2.append(layers.Dense(1, name=f"dense_{i}_{j}"))

            def call(self, x):
                for l in self.l1:
                    for d in l:
                        x = d(x)
                return x

        layer = LayerWithDenseLayers(3)
        layer.build((1, 3))
        self.assertLen(layer._layers, 4)
Expand All @@ -253,6 +271,10 @@ def call(self, x):
        self.assertLen(layer.variables, 9)
        self.assertLen(layer.weights, 8)

        layer = LayerWithLazyTracker()
        layer(np.ones((4, 1)))
        self.assertLen(layer.weights, 8)

    def test_metric_tracking(self):
        class LayerWithMetric(layers.Layer):
            def __init__(self, units):
8 changes: 8 additions & 0 deletions keras/src/utils/tracking.py
Expand Up @@ -66,6 +66,9 @@ def __init__(self, config, exclusions=None):
        self.locked = False
        self._lock_violation_msg = None
        self.exclusions = exclusions or {}
        # log untracked attrs if any
        self._has_untracked_attrs = False
        self._untracked_attrs_ids = []

    def track(self, attr):
        if not is_tracking_enabled():
Expand Down Expand Up @@ -144,6 +147,11 @@ def __init__(self, values=None, tracker=None):
    def append(self, value):
        if self.tracker:
            self.tracker.track(value)
        # Check if an empty attr assigned with empty list and list them
        if not self and isinstance(value, list) and not value:
            if self.tracker:
                self.tracker._has_untracked_attrs = True
                self.tracker._untracked_attrs_ids.append(id(self))
        super().append(value)

    def insert(self, index, value):
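
The condition added to append() is narrow. The sketch below copies it into hypothetical stand-in classes (FlagTracker and FlaggingList, not the Keras classes) to show which calls set the flag. It reflects only the condition as written in this hunk.

```python
class FlagTracker:
    """Hypothetical stand-in holding the two fields added in this diff."""

    def __init__(self):
        self._has_untracked_attrs = False
        self._untracked_attrs_ids = []

    def track(self, value):
        return value  # tracking itself is irrelevant to this sketch


class FlaggingList(list):
    """Hypothetical stand-in for a tracked list with the new append() check."""

    def __init__(self, tracker):
        super().__init__()
        self.tracker = tracker

    def append(self, value):
        self.tracker.track(value)
        # Same condition as the diff: the receiving list is empty AND the
        # appended value is an empty list.
        if not self and isinstance(value, list) and not value:
            self.tracker._has_untracked_attrs = True
            self.tracker._untracked_attrs_ids.append(id(self))
        super().append(value)


tracker = FlagTracker()

outer = FlaggingList(tracker)
outer.append([])  # empty list into empty list: flagged
after_first = len(tracker._untracked_attrs_ids)
outer.append([])  # outer is no longer empty: not flagged again
after_second = len(tracker._untracked_attrs_ids)

other = FlaggingList(tracker)
other.append(["filled"])  # non-empty value: not flagged
other.append([])          # other is non-empty by now: not flagged either
after_other = len(tracker._untracked_attrs_ids)

print(after_first, after_second, after_other)  # 1 1 1
```

In the test layer the second empty list is still picked up, because the build wrapper walks every element of the flagged attribute. A list like other above, which receives its first empty sublist only after some other element, would never be flagged under this condition.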