Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras/models/cloning.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def _clone_layer(layer):
)
try:
tree.assert_same_structure(input_tensors, model.input)
except TypeError as e:
except (ValueError, TypeError) as e:
raise ValueError(
"`input_tensors` must have the same structure as model.input"
f"\nReference structure: {model.input}"
Expand Down
4 changes: 2 additions & 2 deletions keras/ops/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,8 +790,8 @@ def test_cond_check_output_spec_list_tuple(self):
def test_cond_check_output_spec_other_types(self):
cond_op = core.Cond()
# Create mock objects with dtype and shape attributes
mock_spec1 = Mock(dtype="float32", shape=(2, 2))
mock_spec2 = Mock(dtype="float32", shape=(2, 2))
mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32")
mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32")
self.assertTrue(cond_op._check_output_spec(mock_spec1, mock_spec2))

def test_cond_check_output_spec_none(self):
Expand Down
36 changes: 36 additions & 0 deletions keras/utils/tracking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from functools import wraps

import optree
import optree.utils

from keras.backend.common.global_state import get_global_attribute
from keras.backend.common.global_state import set_global_attribute
from keras.utils import python_utils
Expand Down Expand Up @@ -110,6 +113,7 @@ def add_to_store(self, store_name, value):
self.stored_ids[store_name].add(id(value))


@optree.register_pytree_node_class(namespace="keras")
class TrackedList(list):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
Expand Down Expand Up @@ -160,7 +164,17 @@ def __delitem__(self, index):
if self.tracker:
self.tracker.untrack(value)

def tree_flatten(self):
# For optree
return (self, None)

@classmethod
def tree_unflatten(cls, metadata, children):
# For optree
return cls(children)


@optree.register_pytree_node_class(namespace="keras")
class TrackedDict(dict):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
Expand Down Expand Up @@ -199,7 +213,20 @@ def clear(self):
self.tracker.untrack(value)
super().clear()

def tree_flatten(self):
# For optree
keys, values = optree.utils.unzip2(
optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0])
)
return values, list(keys), keys

@classmethod
def tree_unflatten(cls, keys, values):
# For optree
return cls(optree.utils.safe_zip(keys, values))


@optree.register_pytree_node_class(namespace="keras")
class TrackedSet(set):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
Expand Down Expand Up @@ -233,3 +260,12 @@ def clear(self):
for value in self:
self.tracker.untrack(value)
super().clear()

def tree_flatten(self):
# For optree
return (self, None)

@classmethod
def tree_unflatten(cls, metadata, children):
# For optree
return cls(children)
Loading