Skip to content

Commit 7da5b95

Browse files
committed
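Allow tensors in `tf.data.Dataset`s to have different dimensions across batches.
The shape of each `tf.TensorSpec` used to build the `tf.data.Dataset` is now determined by inspecting the first few batches and keeping only the dimensions whose size is the same in all of them; dimensions that differ are left as `None`. Fixes #19124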
1 parent 818c9fa commit 7da5b95

File tree

4 files changed

+86
-87
lines changed

4 files changed

+86
-87
lines changed

keras/trainers/data_adapters/data_adapter_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from keras.api_export import keras_export
55
from keras.utils import tree
66

7+
# Maximum number of leading batches the data adapters inspect when inferring
# the common tensor spec for a `tf.data.Dataset`.
NUM_SAMPLES_FOR_TENSOR_SPEC = 4
8+
79

810
@keras_export("keras.utils.unpack_x_y_sample_weight")
911
def unpack_x_y_sample_weight(data):
@@ -125,6 +127,54 @@ def class_weight_to_sample_weights(y, class_weight):
125127
return sample_weight
126128

127129

130+
def get_tensor_spec(batches):
    """Compute the tensor spec shared by every batch in `batches`.

    Each leaf of the returned structure is a spec whose shape keeps a
    dimension only when all the batches agree on its size; any dimension
    that differs between batches, and always the first (batch) dimension,
    is set to `None`. The dtype is taken from the first batch.

    Args:
        batches: list of structures of tensors. The structures must be
            identical, but the shape at each leaf may be different.

    Returns:
        A structure matching the batches, with a `tf.RaggedTensorSpec`,
        `tf.SparseTensorSpec` or `tf.TensorSpec` at each leaf.

    Raises:
        ValueError: if a leaf of the first batch has rank 0, or if the
            corresponding leaves of two batches have different ranks.
    """
    from keras.utils.module_utils import tensorflow as tf

    def get_single_tensor_spec(*tensors):
        # `tensors` holds the same leaf taken from each batch; the first one
        # is the reference for rank, dtype and tensor kind.
        x = tensors[0]
        rank = len(x.shape)
        if rank < 1:
            raise ValueError(
                "When passing a dataset to a Keras model, the arrays must "
                f"be at least rank 1. Received: {x} of rank {len(x.shape)}."
            )
        for t in tensors:
            if len(t.shape) != rank:
                raise ValueError(
                    "When passing a dataset to a Keras model, the "
                    "corresponding arrays in each batch must have the same "
                    f"rank. Received: {x} and {t}"
                )

        # Walk the dimensions position by position: a size survives only if
        # it is the single distinct value seen across all the batches.
        sizes_per_dim = zip(*(list(tensor.shape) for tensor in tensors))
        merged_shape = [
            distinct.pop() if len(distinct) == 1 else None
            for distinct in map(set, sizes_per_dim)
        ]
        merged_shape[0] = None  # batch size may not be static

        dtype = backend.standardize_dtype(x.dtype)
        if isinstance(x, tf.RaggedTensor):
            return tf.RaggedTensorSpec(shape=merged_shape, dtype=dtype)

        is_sparse = (
            isinstance(x, tf.SparseTensor)
            or is_scipy_sparse(x)
            or is_jax_sparse(x)
        )
        spec_cls = tf.SparseTensorSpec if is_sparse else tf.TensorSpec
        return spec_cls(shape=merged_shape, dtype=dtype)

    return tree.map_structure(get_single_tensor_spec, *batches)
176+
177+
128178
def get_jax_iterator(iterable):
129179
from keras.backend.jax.core import convert_to_tensor
130180

keras/trainers/data_adapters/generator_data_adapter.py

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import itertools
22

3-
from keras import backend
43
from keras.trainers.data_adapters import data_adapter_utils
54
from keras.trainers.data_adapters.data_adapter import DataAdapter
65
from keras.utils import tree
@@ -10,49 +9,19 @@ class GeneratorDataAdapter(DataAdapter):
109
"""Adapter for Python generators."""
1110

1211
def __init__(self, generator):
    """Wrap a Python generator and peek at its first batches.

    The first few batches are consumed up front so that they can later be
    used to infer the output signature; the generator stored on the
    adapter still yields them first, so no data is lost.

    Args:
        generator: Python generator yielding tuples, either `(inputs,)`,
            `(inputs, targets)` or `(inputs, targets, sample_weights)`.

    Raises:
        ValueError: if the generator yields no batch at all, or if the
            first batch it yields is not a tuple.
    """
    first_batches, generator = peek_and_restore(generator)
    # An exhausted generator gives an empty list; fail with a clear message
    # instead of an IndexError on `first_batches[0]` below.
    if not first_batches:
        raise ValueError(
            "When passing a Python generator to a Keras model, "
            "the generator must yield at least one batch. "
            "Received a generator that yielded no data."
        )
    self.generator = generator
    self._first_batches = first_batches
    self._output_signature = None
    if not isinstance(first_batches[0], tuple):
        raise ValueError(
            "When passing a Python generator to a Keras model, "
            "the generator must return a tuple, either "
            "(input,) or (inputs, targets) or "
            "(inputs, targets, sample_weights). "
            f"Received: {first_batches[0]}"
        )
2524

26-
def _set_tf_output_signature(self):
27-
from keras.utils.module_utils import tensorflow as tf
28-
29-
def get_tensor_spec(x):
30-
shape = x.shape
31-
if len(shape) < 1:
32-
raise ValueError(
33-
"When passing a Python generator to a Keras model, "
34-
"the arrays returned by the generator "
35-
"must be at least rank 1. Received: "
36-
f"{x} of rank {len(x.shape)}"
37-
)
38-
shape = list(shape)
39-
shape[0] = None # The batch size is not guaranteed to be static.
40-
dtype = backend.standardize_dtype(x.dtype)
41-
if isinstance(x, tf.RaggedTensor):
42-
return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
43-
if (
44-
isinstance(x, tf.SparseTensor)
45-
or data_adapter_utils.is_scipy_sparse(x)
46-
or data_adapter_utils.is_jax_sparse(x)
47-
):
48-
return tf.SparseTensorSpec(shape=shape, dtype=dtype)
49-
else:
50-
return tf.TensorSpec(shape=shape, dtype=dtype)
51-
52-
self._output_signature = tree.map_structure(
53-
get_tensor_spec, self._first_batch
54-
)
55-
5625
def get_numpy_iterator(self):
5726
return data_adapter_utils.get_numpy_iterator(self.generator)
5827

@@ -85,7 +54,9 @@ def get_tf_iterator():
8554
yield batch
8655

8756
if self._output_signature is None:
88-
self._set_tf_output_signature()
57+
self._output_signature = data_adapter_utils.get_tensor_spec(
58+
self._first_batches
59+
)
8960
ds = tf.data.Dataset.from_generator(
9061
get_tf_iterator,
9162
output_signature=self._output_signature,
@@ -106,5 +77,9 @@ def batch_size(self):
10677

10778

10879
def peek_and_restore(generator):
    """Peek at the first batches of `generator` without losing them.

    Args:
        generator: iterator to peek into. Up to
            `data_adapter_utils.NUM_SAMPLES_FOR_TENSOR_SPEC` items are
            consumed from it.

    Returns:
        A tuple `(batches, restored)` where `batches` is the list of items
        that were consumed (shorter than the limit if the generator ran
        out) and `restored` is an iterator yielding those same items first
        and then whatever remains of `generator`.
    """
    peek_count = data_adapter_utils.NUM_SAMPLES_FOR_TENSOR_SPEC
    peeked = list(itertools.islice(generator, peek_count))
    restored = itertools.chain(peeked, generator)
    return peeked, restored

keras/trainers/data_adapters/py_dataset_adapter.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99

1010
import numpy as np
1111

12-
from keras import backend
1312
from keras.api_export import keras_export
1413
from keras.trainers.data_adapters import data_adapter_utils
1514
from keras.trainers.data_adapters.data_adapter import DataAdapter
16-
from keras.utils import tree
1715

1816

1917
@keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"])
@@ -188,28 +186,6 @@ def __init__(
188186
self.shuffle = shuffle
189187
self._output_signature = None
190188

191-
def _set_tf_output_signature(self):
192-
from keras.utils.module_utils import tensorflow as tf
193-
194-
def get_tensor_spec(x):
195-
shape = x.shape
196-
if len(shape) < 1:
197-
raise ValueError(
198-
"The arrays returned by PyDataset.__getitem__() "
199-
"must be at least rank 1. Received: "
200-
f"{x} of rank {len(x.shape)}"
201-
)
202-
shape = list(shape)
203-
shape[0] = None # The batch size is not guaranteed to be static.
204-
dtype = backend.standardize_dtype(x.dtype)
205-
return tf.TensorSpec(shape=shape, dtype=dtype)
206-
207-
# Grab the first example
208-
batch = self.py_dataset[0]
209-
# Run checks on it and format it
210-
batch = self._standardize_batch(batch)
211-
self._output_signature = tree.map_structure(get_tensor_spec, batch)
212-
213189
def _standardize_batch(self, batch):
214190
if isinstance(batch, dict):
215191
return batch
@@ -287,7 +263,15 @@ def get_tf_dataset(self):
287263
from keras.utils.module_utils import tensorflow as tf
288264

289265
if self._output_signature is None:
290-
self._set_tf_output_signature()
266+
num_samples = min(
267+
data_adapter_utils.NUM_SAMPLES_FOR_TENSOR_SPEC,
268+
len(self.py_dataset),
269+
)
270+
batches = [
271+
self._standardize_batch(self.py_dataset[i])
272+
for i in range(num_samples)
273+
]
274+
self._output_signature = data_adapter_utils.get_tensor_spec(batches)
291275

292276
ds = tf.data.Dataset.from_generator(
293277
self._get_iterator,

keras/trainers/data_adapters/torch_data_loader_adapter.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import itertools
2+
13
import numpy as np
24

3-
from keras import backend
45
from keras.trainers.data_adapters import data_adapter_utils
56
from keras.trainers.data_adapters.data_adapter import DataAdapter
67
from keras.utils import tree
@@ -19,6 +20,7 @@ def __init__(self, dataloader):
1920
)
2021

2122
self._dataloader = dataloader
23+
self._output_signature = None
2224
self._batch_size = dataloader.batch_size
2325
self._num_batches = None
2426
self._partial_batch_size = None
@@ -44,36 +46,24 @@ def get_jax_iterator(self):
4446
def get_tf_dataset(self):
    """Build a `tf.data.Dataset` that yields the dataloader's batches.

    The output signature is computed once, from the first few batches of
    the dataloader, and cached on the adapter for later calls.

    Returns:
        A `tf.data.Dataset` created from `self.get_numpy_iterator` with
        the cached output signature.
    """
    from keras.utils.module_utils import tensorflow as tf

    if self._output_signature is None:
        # Sample a bounded number of leading batches so that dimensions
        # which vary from batch to batch end up as `None` in the spec.
        sample_count = data_adapter_utils.NUM_SAMPLES_FOR_TENSOR_SPEC
        leading_batches = list(
            itertools.islice(self._dataloader, sample_count)
        )
        spec = data_adapter_utils.get_tensor_spec(leading_batches)
        self._output_signature = tuple(spec)

    return tf.data.Dataset.from_generator(
        self.get_numpy_iterator,
        output_signature=self._output_signature,
    )
5263

5364
def get_torch_dataloader(self):
5465
return self._dataloader
5566

56-
def peek_and_get_tensor_spec(self):
57-
from keras.utils.module_utils import tensorflow as tf
58-
59-
batch_data = next(iter(self._dataloader))
60-
61-
def get_tensor_spec(x):
62-
shape = x.shape
63-
if len(shape) < 1:
64-
raise ValueError(
65-
"When passing a Pytorch DataLoader to a Keras model, "
66-
"the arrays returned by the generator "
67-
"must be at least rank 1. Received: "
68-
f"{x} of rank {len(x.shape)}"
69-
)
70-
shape = list(shape)
71-
shape[0] = None # The batch size is not guaranteed to be static.
72-
dtype = backend.standardize_dtype(x.dtype)
73-
return tf.TensorSpec(shape=shape, dtype=dtype)
74-
75-
return tuple(tree.map_structure(get_tensor_spec, batch_data))
76-
7767
@property
7868
def num_batches(self):
7969
return self._num_batches

0 commit comments

Comments
 (0)