|
17 | 17 | from keras.backend.common.backend_utils import to_tuple_or_list
|
18 | 18 | from keras.backend.tensorflow import sparse
|
19 | 19 | from keras.backend.tensorflow.core import convert_to_tensor
|
| 20 | +from keras.utils import tree |
20 | 21 |
|
21 | 22 |
|
22 | 23 | @sparse.elementwise_binary_union(tf.sparse.add)
|
@@ -95,7 +96,7 @@ def _normalize_einsum_subscripts(subscripts):
|
95 | 96 |
|
96 | 97 |
|
97 | 98 | def einsum(subscripts, *operands, **kwargs):
|
98 |
| - operands = tf.nest.map_structure(convert_to_tensor, operands) |
| 99 | + operands = tree.map_structure(convert_to_tensor, operands) |
99 | 100 | subscripts = _normalize_einsum_subscripts(subscripts)
|
100 | 101 |
|
101 | 102 | def is_valid_for_custom_ops(subscripts, *operands):
|
@@ -240,15 +241,15 @@ def use_custom_ops(subscripts, *operands, output_type):
|
240 | 241 | # output_type="int32"
|
241 | 242 | if "int" in compute_dtype and output_type is None:
|
242 | 243 | compute_dtype = config.floatx()
|
243 |
| - operands = tf.nest.map_structure( |
| 244 | + operands = tree.map_structure( |
244 | 245 | lambda x: tf.cast(x, compute_dtype), operands
|
245 | 246 | )
|
246 | 247 | result = use_custom_ops(subscripts, *operands, output_type=output_type)
|
247 | 248 | else:
|
248 | 249 | # TODO: tf.einsum doesn't support integer dtype with gpu
|
249 | 250 | if "int" in compute_dtype:
|
250 | 251 | compute_dtype = config.floatx()
|
251 |
| - operands = tf.nest.map_structure( |
| 252 | + operands = tree.map_structure( |
252 | 253 | lambda x: tf.cast(x, compute_dtype), operands
|
253 | 254 | )
|
254 | 255 | result = tf.einsum(subscripts, *operands, **kwargs)
|
@@ -763,11 +764,11 @@ def concatenate(xs, axis=0):
|
763 | 764 | )
|
764 | 765 | for x in xs
|
765 | 766 | ]
|
766 |
| - xs = tf.nest.map_structure(convert_to_tensor, xs) |
| 767 | + xs = tree.map_structure(convert_to_tensor, xs) |
767 | 768 | dtype_set = set([x.dtype for x in xs])
|
768 | 769 | if len(dtype_set) > 1:
|
769 | 770 | dtype = dtypes.result_type(*dtype_set)
|
770 |
| - xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs) |
| 771 | + xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs) |
771 | 772 | return tf.concat(xs, axis=axis)
|
772 | 773 |
|
773 | 774 |
|
@@ -872,7 +873,7 @@ def digitize(x, bins):
|
872 | 873 | bins = list(bins)
|
873 | 874 |
|
874 | 875 | # bins must be float type
|
875 |
| - bins = tf.nest.map_structure(lambda x: float(x), bins) |
| 876 | + bins = tree.map_structure(lambda x: float(x), bins) |
876 | 877 |
|
877 | 878 | # TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8
|
878 | 879 | # int16, uint8, uint16, uint32
|
@@ -1023,7 +1024,7 @@ def hstack(xs):
|
1023 | 1024 | dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
|
1024 | 1025 | if len(dtype_set) > 1:
|
1025 | 1026 | dtype = dtypes.result_type(*dtype_set)
|
1026 |
| - xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs) |
| 1027 | + xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs) |
1027 | 1028 | rank = tf.rank(xs[0])
|
1028 | 1029 | return tf.cond(
|
1029 | 1030 | tf.equal(rank, 1),
|
@@ -1328,9 +1329,7 @@ def ndim(x):
|
1328 | 1329 | def nonzero(x):
|
1329 | 1330 | x = convert_to_tensor(x)
|
1330 | 1331 | result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1)
|
1331 |
| - return tf.nest.map_structure( |
1332 |
| - lambda indices: tf.cast(indices, "int32"), result |
1333 |
| - ) |
| 1332 | + return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result) |
1334 | 1333 |
|
1335 | 1334 |
|
1336 | 1335 | def not_equal(x1, x2):
|
@@ -1620,7 +1619,7 @@ def stack(x, axis=0):
|
1620 | 1619 | dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
|
1621 | 1620 | if len(dtype_set) > 1:
|
1622 | 1621 | dtype = dtypes.result_type(*dtype_set)
|
1623 |
| - x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x) |
| 1622 | + x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x) |
1624 | 1623 | return tf.stack(x, axis=axis)
|
1625 | 1624 |
|
1626 | 1625 |
|
@@ -1807,7 +1806,7 @@ def vstack(xs):
|
1807 | 1806 | dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
|
1808 | 1807 | if len(dtype_set) > 1:
|
1809 | 1808 | dtype = dtypes.result_type(*dtype_set)
|
1810 |
| - xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs) |
| 1809 | + xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs) |
1811 | 1810 | return tf.concat(xs, axis=0)
|
1812 | 1811 |
|
1813 | 1812 |
|
|
0 commit comments