
Commit 6a266b8

Fix keras.ops.softmax for the tensorflow backend (#19300)
1 parent df705d4 · commit 6a266b8

2 files changed, +23 −2 lines

keras/ops/nn.py

Lines changed: 5 additions & 2 deletions

@@ -538,10 +538,13 @@ def softmax(x, axis=-1):
     array([0.09003057, 0.24472847, 0.66524096], shape=(3,), dtype=float64)

     """
-    if isinstance(axis, int) and backend.shape(x)[axis] == 1:
+    # Don't use `backend.shape` since TensorFlow returns
+    # symbolic tensors for unknown shape which can trigger
+    # an error in TensorFlow graph execution.
+    if isinstance(axis, int) and x.shape[axis] == 1:
         warnings.warn(
             f"You are using a softmax over axis {axis} "
-            f"of a tensor of shape {backend.shape(x)}. This axis "
+            f"of a tensor of shape {x.shape}. This axis "
             "has size 1. The softmax operation will always return "
             "the value 1, which is likely not what you intended. "
             "Did you mean to use a sigmoid instead?"

keras/ops/nn_test.py

Lines changed: 18 additions & 0 deletions

@@ -2,10 +2,12 @@
 import pytest
 from absl.testing import parameterized

+import keras
 from keras import backend
 from keras import layers
 from keras import losses
 from keras import models
+from keras import ops
 from keras import testing
 from keras.backend.common import standardize_dtype
 from keras.backend.common.keras_tensor import KerasTensor
@@ -84,6 +86,22 @@ def test_softmax(self):
         self.assertEqual(knn.softmax(x, axis=1).shape, (None, 2, 3))
         self.assertEqual(knn.softmax(x, axis=-1).shape, (None, 2, 3))

+    def test_softmax_in_graph(self):
+        class SoftmaxLayer(keras.Layer):
+            def call(self, x):
+                return ops.softmax(x, axis=-1)
+
+        class Model(keras.Model):
+            def __init__(self):
+                x = keras.Input(shape=(None,))
+                y = SoftmaxLayer()(x)
+                super().__init__(inputs=x, outputs=y)
+
+        # Make sure Keras is able to compile the model graph
+        model = Model()
+        x = ops.array([[1.0, 2.0, 3.0, 4.0]])
+        model.predict(x)
+
     def test_log_softmax(self):
         x = KerasTensor([None, 2, 3])
         self.assertEqual(knn.log_softmax(x).shape, (None, 2, 3))
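
As a complement to the graph test above, a small eager-mode sketch (not part of the commit; the input values are illustrative) of the warning path that the changed check guards: when the softmax axis genuinely has size 1, `ops.softmax` still warns and returns all ones.

from keras import ops

x = ops.array([[1.0], [2.0]])   # shape (2, 1): the softmax axis has size 1
y = ops.softmax(x, axis=-1)     # warns: "... Did you mean to use a sigmoid instead?"
print(y)                        # [[1.], [1.]] -- softmax over a single element is always 1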
