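Consider passing `tf.Tensor` inputs directly to `fit()`:

```python
import tensorflow as tf

# `model` is assumed to be an already-compiled Keras model.
x = tf.ones((1000, 2))
y = tf.ones((1000,))
model.fit(x, y, batch_size=8)
```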
The `ArrayDataAdapter` converts all `tf.Tensor` inputs to `np.array` and then converts them back to `tf.Tensor`. The same round trip through `np.array` happens for JAX and PyTorch tensors on their respective backends.
This seems inefficient. We should benchmark whether it causes a slowdown on any backend, and fix it if necessary.
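As a possible starting point for that benchmark, here is a minimal, framework-agnostic microbenchmark sketch. The function `benchmark_roundtrip` is hypothetical and not part of Keras. It times only a single tensor to NumPy to tensor conversion in isolation, not a full `fit()` call, and it does not reproduce what `ArrayDataAdapter` actually does per batch. The conversion callables are assumptions you supply per backend, for example `lambda t: t.numpy()` with `tf.convert_to_tensor` on TensorFlow, or `lambda t: t.cpu().numpy()` with `torch.from_numpy` on PyTorch.

```python
import timeit
from typing import Any, Callable, Dict

import numpy as np


def benchmark_roundtrip(
    tensor: Any,
    to_numpy: Callable[[Any], np.ndarray],
    from_numpy: Callable[[np.ndarray], Any],
    number: int = 100,
) -> Dict[str, float]:
    """Return mean seconds per call for a tensor -> np.ndarray -> tensor
    round trip, alongside a baseline that returns the input untouched.

    Hypothetical helper for illustration; not a Keras API.
    """

    def baseline() -> Any:
        # Stand-in for passing the tensor through with no conversion at all.
        return tensor

    def roundtrip() -> Any:
        # Mirrors the conversion described in the issue:
        # framework tensor -> NumPy array -> framework tensor.
        return from_numpy(to_numpy(tensor))

    return {
        "baseline": timeit.timeit(baseline, number=number) / number,
        "roundtrip": timeit.timeit(roundtrip, number=number) / number,
    }


if __name__ == "__main__":
    # NumPy-only demo: np.array copies its input, standing in for a real
    # framework conversion so the harness runs without any backend installed.
    x = np.ones((1000, 2), dtype=np.float32)
    print(benchmark_roundtrip(x, np.array, np.array))

    # On the TensorFlow backend one could instead call, for example:
    #   benchmark_roundtrip(tf.ones((1000, 2)), lambda t: t.numpy(), tf.convert_to_tensor)
```

A microbenchmark like this only bounds the conversion cost; whether it matters in practice would still need an end-to-end `model.fit()` timing on each backend.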