Merged
keras/src/models/functional.py (27 changes: 20 additions & 7 deletions)

@@ -132,7 +132,7 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
         if not all(is_input_keras_tensor(t) for t in flat_inputs):
             inputs, outputs = clone_graph_nodes(inputs, outputs)
 
-        Function.__init__(self, inputs, outputs, name=name, **kwargs)
+        Function.__init__(self, inputs, outputs, name=name)
 
         if trainable is not None:
             self.trainable = trainable
@@ -494,16 +494,28 @@ def process_layer(layer_data):
             # (e.g. a model such as A(B(A(B(x)))))
             add_unprocessed_node(layer, node_data)
 
+    # Extract config used to instantiate Functional model from the config. The
+    # remaining config will be passed as keyword arguments to the Model
+    # constructor.
+    functional_config = {}
+    for key in ["layers", "input_layers", "output_layers"]:
Collaborator:
    You can simplify this block:
      • a single iteration over a set of keys
      • use pop(key, None)

Contributor Author:
    I have applied your suggestion; I like the idea of simplifying it.
    However, it does not keep the previous behavior: instead of throwing a
    KeyError when "layers", "input_layers", or "output_layers" is missing,
    it now silently sets the value to None. It doesn't seem to be an issue,
    but I wanted to mention it.

Collaborator:
    You're right, it would lead to a strange error message instead of a
    clear one if the config is malformed. Let's avoid that.

Contributor Author:
    I reverted the commit, then. I am still waiting for the CLA to be
    handled; I'll let you know when that's done.
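
For reference, a minimal sketch of the single-pass form discussed above (an assumption of what the reverted commit looked like; the merged code keeps the explicit two-loop version shown in this diff):

```python
# Suggested simplification (later reverted): one comprehension with
# pop(key, None). Shorter, but a malformed config missing "layers" would
# surface later as a confusing error instead of an immediate KeyError here.
functional_config = {
    key: config.pop(key, None)
    for key in ("layers", "input_layers", "output_layers", "name", "trainable")
}
```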

+        functional_config[key] = config.pop(key)
+    for key in ["name", "trainable"]:
+        if key in config:
+            functional_config[key] = config.pop(key)
+        else:
+            functional_config[key] = None
 
     # First, we create all layers and enqueue nodes to be processed
-    for layer_data in config["layers"]:
+    for layer_data in functional_config["layers"]:
         process_layer(layer_data)
 
     # Then we process nodes in order of layer depth.
     # Nodes that cannot yet be processed (if the inbound node
     # does not yet exist) are re-enqueued, and the process
     # is repeated until all nodes are processed.
     while unprocessed_nodes:
-        for layer_data in config["layers"]:
+        for layer_data in functional_config["layers"]:
             layer = created_layers[layer_data["name"]]
 
             # Process all nodes in layer, if not yet processed
@@ -532,8 +544,8 @@ def process_layer(layer_data):
                     del unprocessed_nodes[layer]
 
     # Create list of input and output tensors and return new class
-    name = config.get("name")
-    trainable = config.get("trainable")
+    name = functional_config["name"]
+    trainable = functional_config["trainable"]
 
     def get_tensor(layer_name, node_index, tensor_index):
         assert layer_name in created_layers
@@ -558,8 +570,8 @@ def map_tensors(tensors):
             return tuple([map_tensors(v) for v in tensors])
         return [map_tensors(v) for v in tensors]
 
-    input_tensors = map_tensors(config["input_layers"])
-    output_tensors = map_tensors(config["output_layers"])
+    input_tensors = map_tensors(functional_config["input_layers"])
+    output_tensors = map_tensors(functional_config["output_layers"])
     if isinstance(input_tensors, list) and len(input_tensors) == 1:
         input_tensors = input_tensors[0]
     if isinstance(output_tensors, list) and len(output_tensors) == 1:
@@ -570,6 +582,7 @@ def map_tensors(tensors):
         outputs=output_tensors,
         name=name,
         trainable=trainable,
+        **config,
     )
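
The net effect: from_config() now pops only the keys that the functional reconstruction itself consumes and forwards everything left in config to the subclass constructor via **config. A standalone sketch of that splitting logic (hypothetical values, not the merged code):

```python
# A serialized config: graph keys plus a custom constructor argument.
config = {
    "layers": [], "input_layers": [], "output_layers": [],
    "name": "m", "trainable": True, "param": 3,
}

functional_config = {}
for key in ["layers", "input_layers", "output_layers"]:
    functional_config[key] = config.pop(key)  # KeyError if config is malformed
for key in ["name", "trainable"]:
    functional_config[key] = config.pop(key) if key in config else None

print(config)  # {'param': 3} -- leftover kwargs for the Model constructor
```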
keras/src/models/model_test.py (18 changes: 18 additions & 0 deletions)

@@ -169,6 +169,24 @@ def call(self, x):
         )
         self.assertIsInstance(new_model, Functional)
 
+    def test_reviving_functional_from_config_custom_model(self):
+        class CustomModel(Model):
+            def __init__(self, *args, param=1, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.param = param
+
+            def get_config(self):
+                base_config = super().get_config()
+                config = {"param": self.param}
+                return base_config | config
+
+        inputs = layers.Input((3,))
+        outputs = layers.Dense(5)(inputs)
+        model = CustomModel(inputs=inputs, outputs=outputs, param=3)
+
+        new_model = CustomModel.from_config(model.get_config())
+        self.assertEqual(new_model.param, 3)
 
     @parameterized.named_parameters(
         ("single_output_1", _get_model_single_output),
         ("single_output_2", _get_model_single_output),
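
The test above pins down the end-to-end behavior: a Model subclass built functionally (inputs/outputs passed to the constructor) now round-trips its custom constructor arguments through get_config()/from_config(). A self-contained usage sketch with the public API (assumes Keras 3 and Python >= 3.9 for the dict union in get_config):

```python
import keras

# Registration is only needed when loading from a saved file; it is shown
# for completeness and does not affect the direct from_config call below.
@keras.saving.register_keras_serializable()
class CustomModel(keras.Model):
    def __init__(self, *args, param=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.param = param

    def get_config(self):
        # Merge the functional base config with the custom field.
        return super().get_config() | {"param": self.param}

inputs = keras.layers.Input((3,))
outputs = keras.layers.Dense(5)(inputs)
model = CustomModel(inputs=inputs, outputs=outputs, param=3)

# With this change, the leftover "param" key in the serialized config is
# forwarded to CustomModel.__init__ instead of being silently dropped
# (previously, restored.param would fall back to its default of 1).
restored = CustomModel.from_config(model.get_config())
assert restored.param == 3
```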