Commit e2b43e2

Replace dm-tree with optree (#19306)
* Refactor `keras.utils.tree`
* Fix tests
* Replace `dm-tree` with `optree`
* Eliminate `tf.nest`
* Resolve comments
* Fix merge conflicts
* Update exporting path
1 parent 3fcb38c commit e2b43e2

12 files changed, +868 -115 lines
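
Most of the diff is a mechanical swap: call sites that previously went through `tf.nest` now use the backend-agnostic `keras.utils.tree` helpers, which this commit re-implements on top of optree. A minimal sketch of the before/after usage (a standalone illustration, not taken from the diff):

from keras.utils import tree

nested = {"a": [1, 2], "b": (3, 4)}

# Previously: tf.nest.map_structure(fn, nested)
doubled = tree.map_structure(lambda x: x * 2, nested)

# Previously: tf.nest.flatten(nested)
leaves = tree.flatten(nested)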

keras/backend/tensorflow/core.py

Lines changed: 3 additions & 4 deletions

@@ -9,6 +9,7 @@
 from keras.backend.common.name_scope import name_scope as base_name_scope
 from keras.backend.common.stateless_scope import StatelessScope
 from keras.backend.common.stateless_scope import in_stateless_scope
+from keras.utils import tree
 from keras.utils.naming import auto_name
 
 SUPPORTS_SPARSE_TENSORS = True
@@ -189,7 +190,7 @@ def convert_keras_tensor_to_tf(x):
                 )
             return x
 
-        args, kwargs = tf.nest.map_structure(
+        args, kwargs = tree.map_structure(
             convert_keras_tensor_to_tf, (args, kwargs)
         )
         tf_out = fn(*args, **kwargs)
@@ -201,9 +202,7 @@ def convert_tf_to_keras_tensor(x):
                 )
             return x
 
-        output_spec = tf.nest.map_structure(
-            convert_tf_to_keras_tensor, tf_out
-        )
+        output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out)
     return output_spec

keras/backend/tensorflow/layer.py

Lines changed: 7 additions & 6 deletions

@@ -3,6 +3,7 @@
 from keras.backend.tensorflow.trackable import KerasAutoTrackable
 from keras.utils import tf_utils
 from keras.utils import tracking
+from keras.utils import tree
 
 
 class TFLayer(KerasAutoTrackable):
@@ -27,16 +28,16 @@ def _set_save_spec(self, inputs, args=None, kwargs=None):
         if self._saved_model_inputs_spec is not None:
             return  # Already set.
 
-        inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
-        args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
+        inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs)
+        args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or [])
         kwargs_spec = {}
         # Filter out non-tensor arguments from kwargs.
         for key, kwarg in kwargs.items():
-            flat_kwarg = tf.nest.flatten(kwarg)
+            flat_kwarg = tree.flatten(kwarg)
             flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
             if any(s is None for s in flat_specs):
                 continue
-            kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
+            kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)
 
         self._saved_model_inputs_spec = inputs_spec
         self._saved_model_arg_spec = (
@@ -94,7 +95,7 @@ def _default_save_signature(self):
 
         if inputs is not None:
             input_signature = [
-                tf.nest.map_structure(
+                tree.map_structure(
                     lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
                     inputs,
                 )
@@ -108,7 +109,7 @@ def _default_save_signature(self):
             ]
         else:
            input_signature = [
-                tf.nest.map_structure(
+                tree.map_structure(
                     lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
                     shapes_dict,
                 )
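
The kwargs filtering in `_set_save_spec` above is a round-trip through `flatten` and `pack_sequence_as`: each keyword argument is flattened to its leaves, the leaves are mapped to specs, and the structure is rebuilt only if every leaf produced one. A standalone sketch of that idiom (the `get_spec` helper below is a hypothetical stand-in for `tf_utils.get_tensor_spec`):

from keras.utils import tree

def get_spec(x):
    # Hypothetical: pretend numbers are "tensors" and everything else is not.
    return ("spec", type(x).__name__) if isinstance(x, (int, float)) else None

kwargs = {"mask": [1, 2, 3], "state": {"h": 0.5}, "name": "layer_1"}
kwargs_spec = {}
for key, kwarg in kwargs.items():
    flat_kwarg = tree.flatten(kwarg)
    flat_specs = [get_spec(x) for x in flat_kwarg]
    if any(s is None for s in flat_specs):
        continue  # drop arguments containing non-tensor leaves
    kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)

# kwargs_spec keeps "mask" and "state" but drops the string-valued "name".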

keras/backend/tensorflow/numpy.py

Lines changed: 11 additions & 12 deletions

@@ -17,6 +17,7 @@
 from keras.backend.common.backend_utils import to_tuple_or_list
 from keras.backend.tensorflow import sparse
 from keras.backend.tensorflow.core import convert_to_tensor
+from keras.utils import tree
 
 
 @sparse.elementwise_binary_union(tf.sparse.add)
@@ -95,7 +96,7 @@ def _normalize_einsum_subscripts(subscripts):
 
 
 def einsum(subscripts, *operands, **kwargs):
-    operands = tf.nest.map_structure(convert_to_tensor, operands)
+    operands = tree.map_structure(convert_to_tensor, operands)
     subscripts = _normalize_einsum_subscripts(subscripts)
 
     def is_valid_for_custom_ops(subscripts, *operands):
@@ -240,15 +241,15 @@ def use_custom_ops(subscripts, *operands, output_type):
         # output_type="int32"
         if "int" in compute_dtype and output_type is None:
             compute_dtype = config.floatx()
-            operands = tf.nest.map_structure(
+            operands = tree.map_structure(
                 lambda x: tf.cast(x, compute_dtype), operands
             )
         result = use_custom_ops(subscripts, *operands, output_type=output_type)
     else:
         # TODO: tf.einsum doesn't support integer dtype with gpu
         if "int" in compute_dtype:
             compute_dtype = config.floatx()
-            operands = tf.nest.map_structure(
+            operands = tree.map_structure(
                 lambda x: tf.cast(x, compute_dtype), operands
             )
         result = tf.einsum(subscripts, *operands, **kwargs)
@@ -763,11 +764,11 @@ def concatenate(xs, axis=0):
                )
                for x in xs
            ]
-    xs = tf.nest.map_structure(convert_to_tensor, xs)
+    xs = tree.map_structure(convert_to_tensor, xs)
     dtype_set = set([x.dtype for x in xs])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs)
+        xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs)
     return tf.concat(xs, axis=axis)
@@ -872,7 +873,7 @@ def digitize(x, bins):
     bins = list(bins)
 
     # bins must be float type
-    bins = tf.nest.map_structure(lambda x: float(x), bins)
+    bins = tree.map_structure(lambda x: float(x), bins)
 
     # TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8
     # int16, uint8, uint16, uint32
@@ -1023,7 +1024,7 @@ def hstack(xs):
     dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
+        xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
     rank = tf.rank(xs[0])
     return tf.cond(
         tf.equal(rank, 1),
@@ -1328,9 +1329,7 @@ def ndim(x):
 def nonzero(x):
     x = convert_to_tensor(x)
     result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1)
-    return tf.nest.map_structure(
-        lambda indices: tf.cast(indices, "int32"), result
-    )
+    return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result)
 
 
 def not_equal(x1, x2):
@@ -1620,7 +1619,7 @@ def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x)
+        x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x)
     return tf.stack(x, axis=axis)
@@ -1807,7 +1806,7 @@ def vstack(xs):
     dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
     if len(dtype_set) > 1:
         dtype = dtypes.result_type(*dtype_set)
-        xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
+        xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
     return tf.concat(xs, axis=0)

keras/backend/tensorflow/trainer.py

Lines changed: 7 additions & 7 deletions

@@ -225,7 +225,7 @@ def multi_step_on_data(data):
             outputs = one_step_on_data_distributed(data[:1])
             for single_step_data in data[1:]:
                 step_outputs = one_step_on_data_distributed([single_step_data])
-                outputs = tf.nest.map_structure(
+                outputs = tree.map_structure(
                     lambda t1, t2: concat([t1, t2]), outputs, step_outputs
                 )
             return outputs
@@ -473,7 +473,7 @@ def predict(
 
         def append_to_outputs(batch_outputs, outputs):
             if outputs is None:
-                outputs = tf.nest.map_structure(
+                outputs = tree.map_structure(
                     lambda batch_output: [batch_output],
                     batch_outputs,
                 )
@@ -521,7 +521,7 @@ def get_data(iterator):
         outputs = tree.map_structure_up_to(
             batch_outputs, potentially_ragged_concat, outputs
         )
-        return tf.nest.map_structure(convert_to_np_if_not_ragged, outputs)
+        return tree.map_structure(convert_to_np_if_not_ragged, outputs)
 
     def train_on_batch(
         self,
@@ -549,7 +549,7 @@ def data():
             yield (x, y, sample_weight)
 
         logs = self.train_function(data())
-        logs = tf.nest.map_structure(lambda x: np.array(x), logs)
+        logs = tree.map_structure(lambda x: np.array(x), logs)
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)
@@ -568,15 +568,15 @@ def data():
             yield (x, y, sample_weight)
 
         logs = self.test_function(data())
-        logs = tf.nest.map_structure(lambda x: np.array(x), logs)
+        logs = tree.map_structure(lambda x: np.array(x), logs)
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)
 
     def predict_on_batch(self, x):
         self.make_predict_function()
         batch_outputs = self.predict_function([(x,)])
-        batch_outputs = tf.nest.map_structure(
+        batch_outputs = tree.map_structure(
             convert_to_np_if_not_ragged, batch_outputs
         )
         return batch_outputs
@@ -771,7 +771,7 @@ def _reduce(v):
                 f"Received: reduction={reduction}."
             )
 
-    return tf.nest.map_structure(_reduce, values)
+    return tree.map_structure(_reduce, values)
 
 
 def _multi_worker_concat(v, strategy):
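
tree.map_structure also accepts several structures at once, which the multi_step_on_data change above relies on: the mapped function receives one leaf from each structure, matched positionally. A standalone sketch with NumPy arrays standing in for batched outputs:

import numpy as np

from keras.utils import tree

outputs = {"loss": np.array([0.9]), "acc": np.array([0.5])}
step_outputs = {"loss": np.array([0.7]), "acc": np.array([0.6])}

# Mirrors tree.map_structure(lambda t1, t2: concat([t1, t2]), outputs, step_outputs)
merged = tree.map_structure(
    lambda t1, t2: np.concatenate([t1, t2]), outputs, step_outputs
)
print(merged["loss"])  # [0.9 0.7]
print(merged["acc"])   # [0.5 0.6]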

keras/export/export_lib.py

Lines changed: 10 additions & 11 deletions

@@ -8,6 +8,7 @@
 from keras.models import Functional
 from keras.models import Sequential
 from keras.utils import io_utils
+from keras.utils import tree
 from keras.utils.module_utils import tensorflow as tf
 
 
@@ -143,16 +144,16 @@ def track(self, resource):
         # Variables in the lists below are actually part of the trackables
         # that get saved, because the lists are created in __init__.
         if backend.backend() == "jax":
-            self._tf_trackable.variables += tf.nest.flatten(
-                tf.nest.map_structure(tf.Variable, resource.variables)
+            self._tf_trackable.variables += tree.flatten(
+                tree.map_structure(tf.Variable, resource.variables)
             )
-            self._tf_trackable.trainable_variables += tf.nest.flatten(
-                tf.nest.map_structure(
+            self._tf_trackable.trainable_variables += tree.flatten(
+                tree.map_structure(
                     tf.Variable, resource.trainable_variables
                 )
             )
-            self._tf_trackable.non_trainable_variables += tf.nest.flatten(
-                tf.nest.map_structure(
+            self._tf_trackable.non_trainable_variables += tree.flatten(
+                tree.map_structure(
                     tf.Variable, resource.non_trainable_variables
                 )
             )
@@ -362,9 +363,7 @@ def add_variable_collection(self, name, variables):
                 f"{list(set(type(v) for v in variables))}"
             )
         if backend.backend() == "jax":
-            variables = tf.nest.flatten(
-                tf.nest.map_structure(tf.Variable, variables)
-            )
+            variables = tree.flatten(tree.map_structure(tf.Variable, variables))
         setattr(self._tf_trackable, name, list(variables))
 
     def write_out(self, filepath, options=None):
@@ -470,7 +469,7 @@ def _convert_jax2tf_function(self, fn, input_signature):
 
     def _spec_to_poly_shape(self, spec):
         if isinstance(spec, (dict, list)):
-            return tf.nest.map_structure(self._spec_to_poly_shape, spec)
+            return tree.map_structure(self._spec_to_poly_shape, spec)
         spec_shape = spec.shape
         spec_shape = str(spec_shape).replace("None", "b")
         return spec_shape
@@ -500,7 +499,7 @@ def export_model(model, filepath):
     export_archive = ExportArchive()
     export_archive.track(model)
     if isinstance(model, (Functional, Sequential)):
-        input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
+        input_signature = tree.map_structure(_make_tensor_spec, model.inputs)
         if isinstance(input_signature, list) and len(input_signature) > 1:
             input_signature = [input_signature]
         export_archive.add_endpoint("serve", model.__call__, input_signature)
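
The track() and add_variable_collection() changes above share one flatten-after-map shape: every leaf of a nested variable structure is wrapped in a tf.Variable, then the wrapped structure is collapsed into the flat list the trackable expects. A minimal standalone sketch of that pattern (the nested values here are made up; requires TensorFlow):

import numpy as np
import tensorflow as tf

from keras.utils import tree

nested_values = {"dense": {"kernel": np.ones((2, 2)), "bias": np.zeros((2,))}}

# As in tree.flatten(tree.map_structure(tf.Variable, resource.variables)).
flat_variables = tree.flatten(tree.map_structure(tf.Variable, nested_values))
print(len(flat_variables))  # 2
print(all(isinstance(v, tf.Variable) for v in flat_variables))  # True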

keras/models/cloning.py

Lines changed: 1 addition & 1 deletion

@@ -261,7 +261,7 @@ def _clone_layer(layer):
            )
        try:
            tree.assert_same_structure(input_tensors, model.input)
-        except TypeError as e:
+        except (ValueError, TypeError) as e:
            raise ValueError(
                "`input_tensors` must have the same structure as model.input"
                f"\nReference structure: {model.input}"
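
The widened except clause is the one behavioral accommodation in this file: with the optree-backed helpers, tree.assert_same_structure may surface a structure mismatch as a ValueError where the tf.nest path raised TypeError, so the caller now catches both before re-raising with context. A standalone sketch of the pattern:

from keras.utils import tree

reference = {"x": 1, "y": 2}
candidate = [1, 2]  # same leaves, different container type

try:
    tree.assert_same_structure(candidate, reference)
except (ValueError, TypeError) as e:
    # Catch both, since the exception type depends on the tree backend.
    print(f"structure mismatch: {e}")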

keras/ops/core_test.py

Lines changed: 2 additions & 3 deletions

@@ -789,9 +789,8 @@ def test_cond_check_output_spec_list_tuple(self):
 
     def test_cond_check_output_spec_other_types(self):
         cond_op = core.Cond()
-        # Create mock objects with dtype and shape attributes
-        mock_spec1 = Mock(dtype="float32", shape=(2, 2))
-        mock_spec2 = Mock(dtype="float32", shape=(2, 2))
+        mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32")
+        mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32")
         self.assertTrue(cond_op._check_output_spec(mock_spec1, mock_spec2))
 
     def test_cond_check_output_spec_none(self):

keras/utils/tracking.py

Lines changed: 36 additions & 0 deletions

@@ -1,5 +1,8 @@
 from functools import wraps
 
+import optree
+import optree.utils
+
 from keras.backend.common.global_state import get_global_attribute
 from keras.backend.common.global_state import set_global_attribute
 from keras.utils import python_utils
@@ -110,6 +113,7 @@ def add_to_store(self, store_name, value):
         self.stored_ids[store_name].add(id(value))
 
 
+@optree.register_pytree_node_class(namespace="keras")
 class TrackedList(list):
     def __init__(self, values=None, tracker=None):
         self.tracker = tracker
@@ -160,7 +164,17 @@ def __delitem__(self, index):
         if self.tracker:
             self.tracker.untrack(value)
 
+    def tree_flatten(self):
+        # For optree
+        return (self, None)
+
+    @classmethod
+    def tree_unflatten(cls, metadata, children):
+        # For optree
+        return cls(children)
 
+
+@optree.register_pytree_node_class(namespace="keras")
 class TrackedDict(dict):
     def __init__(self, values=None, tracker=None):
         self.tracker = tracker
@@ -199,7 +213,20 @@ def clear(self):
                 self.tracker.untrack(value)
         super().clear()
 
+    def tree_flatten(self):
+        # For optree
+        keys, values = optree.utils.unzip2(
+            optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0])
+        )
+        return values, list(keys), keys
+
+    @classmethod
+    def tree_unflatten(cls, keys, values):
+        # For optree
+        return cls(optree.utils.safe_zip(keys, values))
+
 
+@optree.register_pytree_node_class(namespace="keras")
 class TrackedSet(set):
     def __init__(self, values=None, tracker=None):
         self.tracker = tracker
@@ -233,3 +260,12 @@ def clear(self):
             for value in self:
                 self.tracker.untrack(value)
         super().clear()
+
+    def tree_flatten(self):
+        # For optree
+        return (self, None)
+
+    @classmethod
+    def tree_unflatten(cls, metadata, children):
+        # For optree
+        return cls(children)
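
Registering the tracking containers as optree pytree nodes under the "keras" namespace lets tree.flatten and friends recurse into a TrackedList, TrackedDict, or TrackedSet instead of treating it as an opaque leaf. A standalone sketch of the same mechanism, using a hypothetical DemoList registered in a "demo" namespace:

import optree

@optree.register_pytree_node_class(namespace="demo")
class DemoList(list):
    def tree_flatten(self):
        # (children, metadata), same shape as TrackedList.tree_flatten above
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(children)

leaves, treedef = optree.tree_flatten(DemoList([1, 2, 3]), namespace="demo")
print(leaves)  # [1, 2, 3]
print(optree.tree_unflatten(treedef, [x * 10 for x in leaves]))  # [10, 20, 30]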
