|
2 | 2 | import pytest
|
3 | 3 | from absl.testing import parameterized
|
4 | 4 |
|
| 5 | +import keras |
5 | 6 | from keras import backend
|
6 | 7 | from keras import layers
|
7 | 8 | from keras import losses
|
8 | 9 | from keras import models
|
| 10 | +from keras import ops |
9 | 11 | from keras import testing
|
10 | 12 | from keras.backend.common import standardize_dtype
|
11 | 13 | from keras.backend.common.keras_tensor import KerasTensor
|
@@ -84,6 +86,22 @@ def test_softmax(self):
|
84 | 86 | self.assertEqual(knn.softmax(x, axis=1).shape, (None, 2, 3))
|
85 | 87 | self.assertEqual(knn.softmax(x, axis=-1).shape, (None, 2, 3))
|
86 | 88 |
|
| 89 | + def test_softmax_in_graph(self): |
| 90 | + class SoftmaxLayer(keras.Layer): |
| 91 | + def call(self, x): |
| 92 | + return ops.softmax(x, axis=-1) |
| 93 | + |
| 94 | + class Model(keras.Model): |
| 95 | + def __init__(self): |
| 96 | + x = keras.Input(shape=(None,)) |
| 97 | + y = SoftmaxLayer()(x) |
| 98 | + super().__init__(inputs=x, outputs=y) |
| 99 | + |
| 100 | + # Make sure Keras is able to compile the model graph |
| 101 | + model = Model() |
| 102 | + x = ops.array([[1.0, 2.0, 3.0, 4.0]]) |
| 103 | + model.predict(x) |
| 104 | + |
87 | 105 | def test_log_softmax(self):
|
88 | 106 | x = KerasTensor([None, 2, 3])
|
89 | 107 | self.assertEqual(knn.log_softmax(x).shape, (None, 2, 3))
|
|
0 commit comments