Skip to content

Commit 813fbc5

Browse files
authored
Clarification: new mean option is synonym with sum_over_batch_size in loss function base class (#20352)
* mean as a synonym of sum_over_batch_size * mean synonym with sum_over_batch_size * typo fix * Update loss.py Undid ruff format while keeping typo fixes * reduction in ("mean", "sum_over_batch_size") loss.py
1 parent 3de677e commit 813fbc5

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

keras/src/losses/loss.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class Loss(KerasSaveable):
1717
Args:
1818
reduction: Type of reduction to apply to the loss. In almost all cases
1919
this should be `"sum_over_batch_size"`.
20-
Supported options are `"sum"`, `"sum_over_batch_size"` or `None`.
20+
Supported options are `"sum"`, `"sum_over_batch_size"`, `"mean"`
21+
or `None`.
2122
name: Optional name for the loss instance.
2223
dtype: The dtype of the loss's computations. Defaults to `None`, which
2324
means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
@@ -95,7 +96,7 @@ def _obj_type(self):
9596

9697

9798
def standardize_reduction(reduction):
98-
allowed = {"sum_over_batch_size", "sum", None, "none"}
99+
allowed = {"sum_over_batch_size", "sum", None, "none", "mean"}
99100
if reduction not in allowed:
100101
raise ValueError(
101102
"Invalid value for argument `reduction`. "
@@ -135,7 +136,7 @@ def reduce_values(values, reduction="sum_over_batch_size"):
135136
):
136137
return values
137138
loss = ops.sum(values)
138-
if reduction == "sum_over_batch_size":
139+
if reduction in ("mean", "sum_over_batch_size"):
139140
loss /= ops.cast(
140141
ops.prod(ops.convert_to_tensor(ops.shape(values), dtype="int32")),
141142
loss.dtype,
@@ -180,7 +181,7 @@ def apply_mask(sample_weight, mask, dtype, reduction):
180181
"""Applies any mask on predictions to sample weights."""
181182
if mask is not None:
182183
mask = ops.cast(mask, dtype=dtype)
183-
if reduction == "sum_over_batch_size":
184+
if reduction in ("mean", "sum_over_batch_size"):
184185
# Valid entries have weight `total/valid`, while invalid ones
185186
# have 0. When summed over batch, they will be reduced to:
186187
#

keras/src/losses/loss_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_reduction(self):
6969
self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")
7070
self.assertAllClose(np.sum((y_true - y_pred) ** 2), loss)
7171

72-
# sum_over_batch_size
72+
# sum_over_batch_size or mean
7373
loss_fn = ExampleLoss(reduction="sum_over_batch_size")
7474
loss = loss_fn(y_true, y_pred)
7575
self.assertEqual(backend.standardize_dtype(loss.dtype), "float32")

0 commit comments

Comments
 (0)