Skip to content

Commit 395505c

Browse files
committed
Fixed missing axis annotation in tf map_coordinates
1 parent dfaca0e commit 395505c

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

keras/src/backend/tensorflow/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def process_coordinates(coords, size):
707707
gathered = tf.transpose(tf.gather_nd(input_arr, indices))
708708

709709
if fill_mode == "constant":
710-
all_valid = tf.reduce_all(validities)
710+
all_valid = tf.reduce_all(validities, axis=0)
711711
gathered = tf.where(all_valid, gathered, fill_value)
712712

713713
contribution = gathered

keras/src/ops/image_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,28 @@ def test_elastic_transform(self):
18651865
)
18661866
self.assertAllClose(np.var(ref_out), np.var(out), atol=1e-2, rtol=1e-2)
18671867

1868+
def test_map_coordinates_constant_padding_tf(self):
1869+
input_img = tf.ones((2, 2), dtype=tf.uint8)
1870+
# one pixel outside of the input space around the edges
1871+
grid = tf.stack(
1872+
tf.meshgrid(
1873+
tf.range(-1, 3, dtype=tf.float32),
1874+
tf.range(-1, 3, dtype=tf.float32),
1875+
indexing="ij",
1876+
),
1877+
axis=0,
1878+
)
1879+
out = kimage.map_coordinates(
1880+
input_img, grid, order=0, fill_mode="constant", fill_value=0
1881+
).numpy()
1882+
1883+
# check for ones in the middle and zeros around the edges
1884+
assert np.all(out[:1] == 0)
1885+
assert np.all(out[-1:] == 0)
1886+
assert np.all(out[:, :1] == 0)
1887+
assert np.all(out[:, -1:] == 0)
1888+
assert np.all(out[1:3, 1:3] == 1)
1889+
18681890

18691891
class ImageOpsBehaviorTests(testing.TestCase):
18701892
def setUp(self):

0 commit comments

Comments
 (0)