Skip to content

Commit 432f705

Browse files
committed
Propagate safe_mode flag to legacy h5 loading code.
Also: - removed no-op renaming code in legacy saving - uncommented unit tests in `serialization_lib_test.py`
1 parent 7da416d commit 432f705

File tree

6 files changed

+47
-56
lines changed

6 files changed

+47
-56
lines changed

keras/src/legacy/saving/legacy_h5_format.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from keras.src.legacy.saving import saving_options
1212
from keras.src.legacy.saving import saving_utils
1313
from keras.src.saving import object_registration
14+
from keras.src.saving import serialization_lib
1415
from keras.src.utils import io_utils
1516

1617
try:
@@ -72,7 +73,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
7273
f.close()
7374

7475

75-
def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
76+
def load_model_from_hdf5(
77+
filepath, custom_objects=None, compile=True, safe_mode=True
78+
):
7679
"""Loads a model saved via `save_model_to_hdf5`.
7780
7881
Args:
@@ -128,7 +131,9 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
128131
model_config = model_config.decode("utf-8")
129132
model_config = json_utils.decode(model_config)
130133

131-
with saving_options.keras_option_scope(use_legacy_config=True):
134+
legacy_scope = saving_options.keras_option_scope(use_legacy_config=True)
135+
safe_mode_scope = serialization_lib.SafeModeScope(safe_mode)
136+
with legacy_scope, safe_mode_scope:
132137
model = saving_utils.model_from_config(
133138
model_config, custom_objects=custom_objects
134139
)

keras/src/legacy/saving/legacy_h5_format_test.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,13 @@ def test_saving_lambda(self):
158158

159159
temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
160160
legacy_h5_format.save_model_to_hdf5(model, temp_filepath)
161-
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)
162161

162+
with self.assertRaisesRegex(ValueError, "Deserializing it is unsafe"):
163+
legacy_h5_format.load_model_from_hdf5(temp_filepath)
164+
165+
loaded = legacy_h5_format.load_model_from_hdf5(
166+
temp_filepath, safe_mode=False
167+
)
163168
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
164169
self.assertAllClose(std, loaded.layers[1].arguments["std"])
165170

@@ -353,8 +358,13 @@ def test_saving_lambda(self):
353358

354359
temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
355360
tf_keras_model.save(temp_filepath)
356-
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)
357361

362+
with self.assertRaisesRegex(ValueError, "Deserializing it is unsafe"):
363+
legacy_h5_format.load_model_from_hdf5(temp_filepath)
364+
365+
loaded = legacy_h5_format.load_model_from_hdf5(
366+
temp_filepath, safe_mode=False
367+
)
358368
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
359369
self.assertAllClose(std, loaded.layers[1].arguments["std"])
360370

keras/src/legacy/saving/saving_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import threading
32

43
from absl import logging
@@ -81,10 +80,6 @@ def model_from_config(config, custom_objects=None):
8180
function_dict["config"]["closure"] = function_config[2]
8281
config["config"]["function"] = function_dict
8382

84-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
85-
# Replace keras refs with keras
86-
config = _find_replace_nested_dict(config, "keras.", "keras.")
87-
8883
return serialization.deserialize_keras_object(
8984
config,
9085
module_objects=MODULE_OBJECTS.ALL_OBJECTS,
@@ -231,13 +226,6 @@ def _deserialize_metric(metric_config):
231226
return metrics_module.deserialize(metric_config)
232227

233228

234-
def _find_replace_nested_dict(config, find, replace):
235-
dict_str = json.dumps(config)
236-
dict_str = dict_str.replace(find, replace)
237-
config = json.loads(dict_str)
238-
return config
239-
240-
241229
def _resolve_compile_arguments_compat(obj, obj_config, module):
242230
"""Resolves backwards compatibility issues with training config arguments.
243231

keras/src/legacy/saving/serialization.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import contextlib
44
import inspect
5-
import json
65
import threading
76
import weakref
87

@@ -485,12 +484,6 @@ def deserialize(config, custom_objects=None):
485484
arg_spec = inspect.getfullargspec(cls.from_config)
486485
custom_objects = custom_objects or {}
487486

488-
# TODO(nkovela): Swap find and replace args during Keras 3.0 release
489-
# Replace keras refs with keras
490-
cls_config = _find_replace_nested_dict(
491-
cls_config, "keras.", "keras."
492-
)
493-
494487
if "custom_objects" in arg_spec.args:
495488
deserialized_obj = cls.from_config(
496489
cls_config,
@@ -565,10 +558,3 @@ def validate_config(config):
565558
def is_default(method):
566559
"""Check if a method is decorated with the `default` wrapper."""
567560
return getattr(method, "_is_default", False)
568-
569-
570-
def _find_replace_nested_dict(config, find, replace):
571-
dict_str = json.dumps(config)
572-
dict_str = dict_str.replace(find, replace)
573-
config = json.loads(dict_str)
574-
return config

keras/src/saving/saving_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
194194
)
195195
if str(filepath).endswith((".h5", ".hdf5")):
196196
return legacy_h5_format.load_model_from_hdf5(
197-
filepath, custom_objects=custom_objects, compile=compile
197+
filepath,
198+
custom_objects=custom_objects,
199+
compile=compile,
200+
safe_mode=safe_mode,
198201
)
199202
elif str(filepath).endswith(".keras"):
200203
raise ValueError(

keras/src/saving/serialization_lib_test.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -175,31 +175,30 @@ def test_lambda_fn(self):
175175
_, new_obj, _ = self.roundtrip(obj, safe_mode=False)
176176
self.assertEqual(obj["activation"](3), new_obj["activation"](3))
177177

178-
# TODO
179-
# def test_lambda_layer(self):
180-
# lmbda = keras.layers.Lambda(lambda x: x**2)
181-
# with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
182-
# self.roundtrip(lmbda, safe_mode=True)
183-
184-
# _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
185-
# x = ops.random.normal((2, 2))
186-
# y1 = lmbda(x)
187-
# y2 = new_lmbda(x)
188-
# self.assertAllClose(y1, y2, atol=1e-5)
189-
190-
# def test_safe_mode_scope(self):
191-
# lmbda = keras.layers.Lambda(lambda x: x**2)
192-
# with serialization_lib.SafeModeScope(safe_mode=True):
193-
# with self.assertRaisesRegex(
194-
# ValueError, "arbitrary code execution"
195-
# ):
196-
# self.roundtrip(lmbda)
197-
# with serialization_lib.SafeModeScope(safe_mode=False):
198-
# _, new_lmbda, _ = self.roundtrip(lmbda)
199-
# x = ops.random.normal((2, 2))
200-
# y1 = lmbda(x)
201-
# y2 = new_lmbda(x)
202-
# self.assertAllClose(y1, y2, atol=1e-5)
178+
def test_lambda_layer(self):
179+
lmbda = keras.layers.Lambda(lambda x: x**2)
180+
with self.assertRaisesRegex(ValueError, "Deserializing it is unsafe"):
181+
self.roundtrip(lmbda, safe_mode=True)
182+
183+
_, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
184+
x = ops.random.normal((2, 2))
185+
y1 = lmbda(x)
186+
y2 = new_lmbda(x)
187+
self.assertAllClose(y1, y2, atol=1e-5)
188+
189+
def test_safe_mode_scope(self):
190+
lmbda = keras.layers.Lambda(lambda x: x**2)
191+
with serialization_lib.SafeModeScope(safe_mode=True):
192+
with self.assertRaisesRegex(
193+
ValueError, "Deserializing it is unsafe"
194+
):
195+
self.roundtrip(lmbda)
196+
with serialization_lib.SafeModeScope(safe_mode=False):
197+
_, new_lmbda, _ = self.roundtrip(lmbda)
198+
x = ops.random.normal((2, 2))
199+
y1 = lmbda(x)
200+
y2 = new_lmbda(x)
201+
self.assertAllClose(y1, y2, atol=1e-5)
203202

204203
@pytest.mark.requires_trainable_backend
205204
def test_dict_inputs_outputs(self):

0 commit comments

Comments
 (0)