
ArrayDataAdapter will needlessly convert backend native tensors #18408

@mattdangerw

Description

Given the following snippet on the TensorFlow backend:

import tensorflow as tf

x = tf.ones((1000, 2))
y = tf.ones((1000,))
model.fit(x, y, batch_size=8)  # `model` is any compiled Keras model

The ArrayDataAdapter converts all tf.Tensor inputs to NumPy arrays (np.ndarray) and then converts them back to tf.Tensor. JAX and torch tensors make the same round trip through NumPy on their respective backends.
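One possible direction, sketched minimally: check whether an input is already a backend-native tensor and take the NumPy round trip only when it is not. The helper `to_backend_tensor` and its parameters are hypothetical and not part of Keras. The backend-specific check and converters are passed in as callables so the sketch stays backend-agnostic. It covers only the conversion step, not the adapter's batching or shuffling.

```python
from typing import Any, Callable


def to_backend_tensor(
    x: Any,
    is_native: Callable[[Any], bool],
    to_numpy: Callable[[Any], Any],
    from_numpy: Callable[[Any], Any],
) -> Any:
    """Hypothetical helper: return x as a backend tensor.

    If x is already a backend-native tensor it is returned unchanged,
    skipping the NumPy round trip entirely. Any other input is converted
    to NumPy first and then to the backend tensor type.
    """
    if is_native(x):
        # Already native: no conversion, no copy.
        return x
    # Non-native input (e.g. a list or NumPy array): convert via NumPy.
    return from_numpy(to_numpy(x))
```

On the TensorFlow backend this could be called as `to_backend_tensor(x, tf.is_tensor, np.asarray, tf.convert_to_tensor)`. A tf.Tensor would then pass through as-is, while a NumPy array or Python list would be converted once.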

This seems inefficient. We should benchmark whether it causes a slowdown on any backend and fix it if necessary.
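As a starting point for that benchmark, a rough micro-benchmark could time the conversion step in isolation. `compare_conversion_paths` is a hypothetical helper built on the standard-library `timeit` module. It compares passing an input through unchanged against round-tripping it through NumPy. It says nothing about end-to-end training speed, so timing `model.fit` with and without the round trip would still be needed to confirm a real slowdown.

```python
import timeit
from typing import Any, Callable, Dict


def compare_conversion_paths(
    x: Any,
    to_numpy: Callable[[Any], Any],
    from_numpy: Callable[[Any], Any],
    number: int = 100,
    repeat: int = 5,
) -> Dict[str, float]:
    """Hypothetical helper: best per-call time in seconds for each path.

    "direct" returns x unchanged (what skipping the conversion would cost);
    "numpy_roundtrip" converts x to NumPy and back to the backend type.
    """
    paths = {
        "direct": lambda: x,
        "numpy_roundtrip": lambda: from_numpy(to_numpy(x)),
    }
    results = {}
    for name, fn in paths.items():
        # timeit.repeat returns the total time of `number` calls, `repeat` times.
        best_total = min(timeit.repeat(fn, number=number, repeat=repeat))
        results[name] = best_total / number
    return results
```

For example, on the TensorFlow backend: `compare_conversion_paths(tf.ones((1000, 2)), np.asarray, tf.convert_to_tensor)`. Taking the minimum over repeats reduces noise from other processes, which is the approach the `timeit` documentation recommends.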
