Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions keras/src/layers/core/lambda_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,15 @@ def _serialize_function_to_config(self, fn):
)

@staticmethod
def _raise_for_lambda_deserialization(arg_name, safe_mode):
def _raise_for_lambda_deserialization(safe_mode):
if safe_mode:
raise ValueError(
f"The `{arg_name}` of this `Lambda` layer is a Python lambda. "
"Deserializing it is unsafe. If you trust the source of the "
"config artifact, you can override this error "
"by passing `safe_mode=False` "
"to `from_config()`, or calling "
"Requested the deserialization of a `Lambda` layer whose "
"`function` is a Python lambda. This carries a potential risk "
"of arbitrary code execution and thus it is disallowed by "
"default. If you trust the source of the artifact, you can "
"override this error by passing `safe_mode=False` to the "
"loading function, or calling "
"`keras.config.enable_unsafe_deserialization()."
)

Expand All @@ -187,7 +188,7 @@ def from_config(cls, config, custom_objects=None, safe_mode=None):
and "class_name" in fn_config
and fn_config["class_name"] == "__lambda__"
):
cls._raise_for_lambda_deserialization("function", safe_mode)
cls._raise_for_lambda_deserialization(safe_mode)
inner_config = fn_config["config"]
fn = python_utils.func_load(
inner_config["code"],
Expand All @@ -206,7 +207,7 @@ def from_config(cls, config, custom_objects=None, safe_mode=None):
and "class_name" in fn_config
and fn_config["class_name"] == "__lambda__"
):
cls._raise_for_lambda_deserialization("function", safe_mode)
cls._raise_for_lambda_deserialization(safe_mode)
inner_config = fn_config["config"]
fn = python_utils.func_load(
inner_config["code"],
Expand Down
9 changes: 7 additions & 2 deletions keras/src/legacy/saving/legacy_h5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from keras.src.legacy.saving import saving_options
from keras.src.legacy.saving import saving_utils
from keras.src.saving import object_registration
from keras.src.saving import serialization_lib
from keras.src.utils import io_utils

try:
Expand Down Expand Up @@ -72,7 +73,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
f.close()


def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
def load_model_from_hdf5(
filepath, custom_objects=None, compile=True, safe_mode=True
):
"""Loads a model saved via `save_model_to_hdf5`.
Args:
Expand Down Expand Up @@ -128,7 +131,9 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
model_config = model_config.decode("utf-8")
model_config = json_utils.decode(model_config)

with saving_options.keras_option_scope(use_legacy_config=True):
legacy_scope = saving_options.keras_option_scope(use_legacy_config=True)
safe_mode_scope = serialization_lib.SafeModeScope(safe_mode)
with legacy_scope, safe_mode_scope:
model = saving_utils.model_from_config(
model_config, custom_objects=custom_objects
)
Expand Down
14 changes: 12 additions & 2 deletions keras/src/legacy/saving/legacy_h5_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,13 @@ def test_saving_lambda(self):

temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
legacy_h5_format.save_model_to_hdf5(model, temp_filepath)
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)

with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
legacy_h5_format.load_model_from_hdf5(temp_filepath)

loaded = legacy_h5_format.load_model_from_hdf5(
temp_filepath, safe_mode=False
)
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
self.assertAllClose(std, loaded.layers[1].arguments["std"])

Expand Down Expand Up @@ -353,8 +358,13 @@ def test_saving_lambda(self):

temp_filepath = os.path.join(self.get_temp_dir(), "lambda_model.h5")
tf_keras_model.save(temp_filepath)
loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)

with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
legacy_h5_format.load_model_from_hdf5(temp_filepath)

loaded = legacy_h5_format.load_model_from_hdf5(
temp_filepath, safe_mode=False
)
self.assertAllClose(mean, loaded.layers[1].arguments["mu"])
self.assertAllClose(std, loaded.layers[1].arguments["std"])

Expand Down
12 changes: 0 additions & 12 deletions keras/src/legacy/saving/saving_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import threading

from absl import logging
Expand Down Expand Up @@ -81,10 +80,6 @@ def model_from_config(config, custom_objects=None):
function_dict["config"]["closure"] = function_config[2]
config["config"]["function"] = function_dict

# TODO(nkovela): Swap find and replace args during Keras 3.0 release
# Replace keras refs with keras
config = _find_replace_nested_dict(config, "keras.", "keras.")

return serialization.deserialize_keras_object(
config,
module_objects=MODULE_OBJECTS.ALL_OBJECTS,
Expand Down Expand Up @@ -231,13 +226,6 @@ def _deserialize_metric(metric_config):
return metrics_module.deserialize(metric_config)


def _find_replace_nested_dict(config, find, replace):
dict_str = json.dumps(config)
dict_str = dict_str.replace(find, replace)
config = json.loads(dict_str)
return config


def _resolve_compile_arguments_compat(obj, obj_config, module):
"""Resolves backwards compatibility issues with training config arguments.

Expand Down
14 changes: 0 additions & 14 deletions keras/src/legacy/saving/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import contextlib
import inspect
import json
import threading
import weakref

Expand Down Expand Up @@ -485,12 +484,6 @@ def deserialize(config, custom_objects=None):
arg_spec = inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}

# TODO(nkovela): Swap find and replace args during Keras 3.0 release
# Replace keras refs with keras
cls_config = _find_replace_nested_dict(
cls_config, "keras.", "keras."
)

if "custom_objects" in arg_spec.args:
deserialized_obj = cls.from_config(
cls_config,
Expand Down Expand Up @@ -565,10 +558,3 @@ def validate_config(config):
def is_default(method):
"""Check if a method is decorated with the `default` wrapper."""
return getattr(method, "_is_default", False)


def _find_replace_nested_dict(config, find, replace):
dict_str = json.dumps(config)
dict_str = dict_str.replace(find, replace)
config = json.loads(dict_str)
return config
5 changes: 4 additions & 1 deletion keras/src/saving/saving_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
)
if str(filepath).endswith((".h5", ".hdf5")):
return legacy_h5_format.load_model_from_hdf5(
filepath, custom_objects=custom_objects, compile=compile
filepath,
custom_objects=custom_objects,
compile=compile,
safe_mode=safe_mode,
)
elif str(filepath).endswith(".keras"):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion keras/src/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ def test_safe_mode(self):
]
)
model.save(temp_filepath)
with self.assertRaisesRegex(ValueError, "Deserializing it is unsafe"):
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
model = saving_lib.load_model(temp_filepath)
model = saving_lib.load_model(temp_filepath, safe_mode=False)

Expand Down
12 changes: 6 additions & 6 deletions keras/src/saving/serialization_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,12 +656,12 @@ class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
if config["class_name"] == "__lambda__":
if safe_mode:
raise ValueError(
"Requested the deserialization of a `lambda` object. "
"This carries a potential risk of arbitrary code execution "
"and thus it is disallowed by default. If you trust the "
"source of the saved model, you can pass `safe_mode=False` to "
"the loading function in order to allow `lambda` loading, "
"or call `keras.config.enable_unsafe_deserialization()`."
"Requested the deserialization of a Python lambda. This "
"carries a potential risk of arbitrary code execution and thus "
"it is disallowed by default. If you trust the source of the "
"artifact, you can override this error by passing "
"`safe_mode=False` to the loading function, or calling "
"`keras.config.enable_unsafe_deserialization()."
)
return python_utils.func_load(inner_config["value"])
if tf is not None and config["class_name"] == "__typespec__":
Expand Down
47 changes: 22 additions & 25 deletions keras/src/saving/serialization_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,31 +175,28 @@ def test_lambda_fn(self):
_, new_obj, _ = self.roundtrip(obj, safe_mode=False)
self.assertEqual(obj["activation"](3), new_obj["activation"](3))

# TODO
# def test_lambda_layer(self):
# lmbda = keras.layers.Lambda(lambda x: x**2)
# with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
# self.roundtrip(lmbda, safe_mode=True)

# _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
# x = ops.random.normal((2, 2))
# y1 = lmbda(x)
# y2 = new_lmbda(x)
# self.assertAllClose(y1, y2, atol=1e-5)

# def test_safe_mode_scope(self):
# lmbda = keras.layers.Lambda(lambda x: x**2)
# with serialization_lib.SafeModeScope(safe_mode=True):
# with self.assertRaisesRegex(
# ValueError, "arbitrary code execution"
# ):
# self.roundtrip(lmbda)
# with serialization_lib.SafeModeScope(safe_mode=False):
# _, new_lmbda, _ = self.roundtrip(lmbda)
# x = ops.random.normal((2, 2))
# y1 = lmbda(x)
# y2 = new_lmbda(x)
# self.assertAllClose(y1, y2, atol=1e-5)
def test_lambda_layer(self):
lmbda = keras.layers.Lambda(lambda x: x**2)
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
self.roundtrip(lmbda, safe_mode=True)

_, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)
x = ops.random.normal((2, 2))
y1 = lmbda(x)
y2 = new_lmbda(x)
self.assertAllClose(y1, y2, atol=1e-5)

def test_safe_mode_scope(self):
lmbda = keras.layers.Lambda(lambda x: x**2)
with serialization_lib.SafeModeScope(safe_mode=True):
with self.assertRaisesRegex(ValueError, "arbitrary code execution"):
self.roundtrip(lmbda)
with serialization_lib.SafeModeScope(safe_mode=False):
_, new_lmbda, _ = self.roundtrip(lmbda)
x = ops.random.normal((2, 2))
y1 = lmbda(x)
y2 = new_lmbda(x)
self.assertAllClose(y1, y2, atol=1e-5)

@pytest.mark.requires_trainable_backend
def test_dict_inputs_outputs(self):
Expand Down
8 changes: 4 additions & 4 deletions keras/src/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ def from_config(cls, config):
"Requested the deserialization of a `torch.nn.Module` "
"object via `torch.load()`. This carries a potential risk "
"of arbitrary code execution and thus it is disallowed by "
"default. If you trust the source of the saved model, you "
"can pass `safe_mode=False` to the loading function in "
"order to allow `torch.nn.Module` loading, or call "
"`keras.config.enable_unsafe_deserialization()`."
"default. If you trust the source of the artifact, you can "
"override this error by passing `safe_mode=False` to the "
"loading function, or calling "
"`keras.config.enable_unsafe_deserialization()."
)

# Decode the base64 string back to bytes
Expand Down