@@ -175,31 +175,30 @@ def test_lambda_fn(self):
175
175
_ , new_obj , _ = self .roundtrip (obj , safe_mode = False )
176
176
self .assertEqual (obj ["activation" ](3 ), new_obj ["activation" ](3 ))
177
177
178
- # TODO
179
- # def test_lambda_layer(self):
180
- # lmbda = keras.layers.Lambda(lambda x: x**2)
181
- # with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
182
- # self.roundtrip(lmbda, safe_mode=True)
183
-
184
- # _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
185
- # x = ops.random.normal((2, 2))
186
- # y1 = lmbda(x)
187
- # y2 = new_lmbda(x)
188
- # self.assertAllClose(y1, y2, atol=1e-5)
189
-
190
- # def test_safe_mode_scope(self):
191
- # lmbda = keras.layers.Lambda(lambda x: x**2)
192
- # with serialization_lib.SafeModeScope(safe_mode=True):
193
- # with self.assertRaisesRegex(
194
- # ValueError, "arbitrary code execution"
195
- # ):
196
- # self.roundtrip(lmbda)
197
- # with serialization_lib.SafeModeScope(safe_mode=False):
198
- # _, new_lmbda, _ = self.roundtrip(lmbda)
199
- # x = ops.random.normal((2, 2))
200
- # y1 = lmbda(x)
201
- # y2 = new_lmbda(x)
202
- # self.assertAllClose(y1, y2, atol=1e-5)
178
+ def test_lambda_layer (self ):
179
+ lmbda = keras .layers .Lambda (lambda x : x ** 2 )
180
+ with self .assertRaisesRegex (ValueError , "Deserializing it is unsafe" ):
181
+ self .roundtrip (lmbda , safe_mode = True )
182
+
183
+ _ , new_lmbda , _ = self .roundtrip (lmbda , safe_mode = False )
184
+ x = ops .random .normal ((2 , 2 ))
185
+ y1 = lmbda (x )
186
+ y2 = new_lmbda (x )
187
+ self .assertAllClose (y1 , y2 , atol = 1e-5 )
188
+
189
+ def test_safe_mode_scope (self ):
190
+ lmbda = keras .layers .Lambda (lambda x : x ** 2 )
191
+ with serialization_lib .SafeModeScope (safe_mode = True ):
192
+ with self .assertRaisesRegex (
193
+ ValueError , "Deserializing it is unsafe"
194
+ ):
195
+ self .roundtrip (lmbda )
196
+ with serialization_lib .SafeModeScope (safe_mode = False ):
197
+ _ , new_lmbda , _ = self .roundtrip (lmbda )
198
+ x = ops .random .normal ((2 , 2 ))
199
+ y1 = lmbda (x )
200
+ y2 = new_lmbda (x )
201
+ self .assertAllClose (y1 , y2 , atol = 1e-5 )
203
202
204
203
@pytest .mark .requires_trainable_backend
205
204
def test_dict_inputs_outputs (self ):
0 commit comments