|
6 | 6 |
|
7 | 7 | from keras.src import backend
|
8 | 8 | from keras.src.api_export import keras_export
|
9 |
| -from keras.src.callbacks.callback import Callback |
| 9 | +from keras.src.callbacks.monitor_callback import MonitorCallback |
10 | 10 | from keras.src.utils import file_utils
|
11 | 11 | from keras.src.utils import io_utils
|
12 | 12 |
|
13 | 13 |
|
14 | 14 | @keras_export("keras.callbacks.ModelCheckpoint")
|
15 |
| -class ModelCheckpoint(Callback): |
| 15 | +class ModelCheckpoint(MonitorCallback): |
16 | 16 | """Callback to save the Keras model or model weights at some frequency.
|
17 | 17 |
|
18 | 18 | `ModelCheckpoint` callback is used in conjunction with training using
|
@@ -105,9 +105,8 @@ class ModelCheckpoint(Callback):
|
105 | 105 | decision to overwrite the current save file is made based on either
|
106 | 106 | the maximization or the minimization of the monitored quantity.
|
107 | 107 | For `val_acc`, this should be `"max"`, for `val_loss` this should be
|
108 |
| - `"min"`, etc. In `"auto"` mode, the mode is set to `"max"` if the |
109 |
| - quantities monitored are `"acc"` or start with `"fmeasure"` and are |
110 |
| - set to `"min"` for the rest of the quantities. |
| 108 | + `"min"`, etc. In `"auto"` mode, the direction is automatically |
| 109 | + inferred from the name of the monitored quantity. |
111 | 110 | save_weights_only: if `True`, then only the model's weights will be
|
112 | 111 | saved (`model.save_weights(filepath)`), else the full model is
|
113 | 112 | saved (`model.save(filepath)`).
|
@@ -136,42 +135,14 @@ def __init__(
|
136 | 135 | save_freq="epoch",
|
137 | 136 | initial_value_threshold=None,
|
138 | 137 | ):
|
139 |
| - super().__init__() |
140 |
| - self.monitor = monitor |
| 138 | + super().__init__(monitor, mode, initial_value_threshold) |
141 | 139 | self.verbose = verbose
|
142 | 140 | self.filepath = file_utils.path_to_string(filepath)
|
143 | 141 | self.save_best_only = save_best_only
|
144 | 142 | self.save_weights_only = save_weights_only
|
145 | 143 | self.save_freq = save_freq
|
146 | 144 | self._batches_seen_since_last_saving = 0
|
147 | 145 | self._last_batch_seen = 0
|
148 |
| - self.best = initial_value_threshold |
149 |
| - |
150 |
| - if mode not in ["auto", "min", "max"]: |
151 |
| - warnings.warn( |
152 |
| - f"ModelCheckpoint mode '{mode}' is unknown, " |
153 |
| - "fallback to auto mode.", |
154 |
| - stacklevel=2, |
155 |
| - ) |
156 |
| - mode = "auto" |
157 |
| - |
158 |
| - if mode == "min": |
159 |
| - self.monitor_op = np.less |
160 |
| - if self.best is None: |
161 |
| - self.best = np.inf |
162 |
| - elif mode == "max": |
163 |
| - self.monitor_op = np.greater |
164 |
| - if self.best is None: |
165 |
| - self.best = -np.inf |
166 |
| - else: |
167 |
| - if "acc" in self.monitor or self.monitor.startswith("fmeasure"): |
168 |
| - self.monitor_op = np.greater |
169 |
| - if self.best is None: |
170 |
| - self.best = -np.inf |
171 |
| - else: |
172 |
| - self.monitor_op = np.less |
173 |
| - if self.best is None: |
174 |
| - self.best = np.inf |
175 | 146 |
|
176 | 147 | if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
|
177 | 148 | raise ValueError(
|
@@ -205,6 +176,10 @@ def on_epoch_begin(self, epoch, logs=None):
|
205 | 176 | self._current_epoch = epoch
|
206 | 177 |
|
207 | 178 | def on_epoch_end(self, epoch, logs=None):
|
| 179 | + if self.monitor_op is None: |
| 180 | + # Delay setup until the model's metrics are all built |
| 181 | + self._set_monitor_op() |
| 182 | + |
208 | 183 | if self.save_freq == "epoch":
|
209 | 184 | self._save_model(epoch=epoch, batch=None, logs=logs)
|
210 | 185 |
|
@@ -262,7 +237,7 @@ def _should_save_model(self, epoch, batch, logs, filepath):
|
262 | 237 | )
|
263 | 238 | return True
|
264 | 239 | else:
|
265 |
| - if self.monitor_op(current, self.best): |
| 240 | + if self._is_improvement(current, self.best): |
266 | 241 | if self.verbose > 0:
|
267 | 242 | io_utils.print_msg(
|
268 | 243 | f"\nEpoch {epoch + 1}: {self.monitor} "
|
|
0 commit comments