
Counterintuitive behavior of nn.CrossEntropyLoss/nn.NLLLoss with weights and issue with gradient accumulation #72047

@idc9

Description

🚀 The feature, motivation and pitch

The behavior of the mean reduction in nn.CrossEntropyLoss and nn.NLLLoss is counterintuitive when class weights are supplied, as discussed in #9882. The current behavior computes a weighted average -- it divides the weighted per-sample losses by the sum of the sample weights -- rather than the unweighted average (division by the number of samples) that people probably expect.
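For concreteness, here is a minimal, self-contained sketch (tensor values are made up) of what reduction='mean' currently computes when class weights are supplied:

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)               # batch of 4 samples, 3 classes
targets = torch.tensor([0, 1, 2, 1])
weights = torch.tensor([1.0, 2.0, 0.5])  # per-class weights

# current 'mean' reduction
current_mean = nn.CrossEntropyLoss(weight=weights)(logits, targets)

# reproduce it by hand from the (already weighted) per-sample losses
per_sample = nn.CrossEntropyLoss(weight=weights, reduction='none')(logits, targets)
weighted_mean = per_sample.sum() / weights[targets].sum()  # sum(w[y_i]*l_i) / sum(w[y_i])
unweighted_mean = per_sample.mean()                        # sum(w[y_i]*l_i) / N

print(torch.allclose(current_mean, weighted_mean))    # True
print(torch.allclose(current_mean, unweighted_mean))  # False in general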

This counterintuitive behavior also causes an issue when doing gradient accumulation. In particular, when you adjust the loss to account for gradient accumulation (i.e. make the divisor batch_size x n_grad_accum_batches instead of just batch_size), you no longer get exact gradients (i.e. the gradients you would have had if your batch size were batch_size x n_grad_accum_batches). The reason is that the weighted mean's divisor -- the sum of the sample weights -- differs from micro-batch to micro-batch, so dividing each micro-batch loss by a constant does not reproduce the single large-batch result.

# Example gradient accumulation code
import torch.nn as nn
# loader, model, optimizer, weights, n_grad_accum_batches set above

# unweighted case -- no issue for gradient accumulation
loss_func = nn.CrossEntropyLoss(reduction='mean')

# weighted case with an issue
# loss_func = nn.CrossEntropyLoss(weight=weights, reduction='mean')

# this loop assumes everything is nicely divisible -- it can be modified to handle the case where
# the number of batches is not divisible by n_grad_accum_batches or the number of samples is not divisible by batch_size
for batch_idx, (x, y_true) in enumerate(loader):
    y_pred = model(x)
    loss = loss_func(y_pred, y_true)  # CrossEntropyLoss expects (input, target)
    
    # adjust for gradient accumulation
    # this gives you exact gradients in the unweighted case, but not in the weighted case!
    loss = loss / n_grad_accum_batches
    
    loss.backward()
    if (batch_idx + 1) % n_grad_accum_batches == 0:
        optimizer.step()
        optimizer.zero_grad()

You can of course address this by using reduction='sum' and manually averaging the loss (see the sketch below), but this is clunky and probably frequently overlooked.
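For example, here is a sketch of that workaround, under the same assumptions as the loop above (loader, model, optimizer, weights, batch_size and n_grad_accum_batches defined elsewhere): summing the weighted per-sample losses and dividing by a fixed total sample count that is known up front keeps the divisor constant across micro-batches, so the accumulated gradients match those of a single large batch with the same divisor.

# Workaround sketch: reduction='sum' plus a fixed, known divisor
loss_func = nn.CrossEntropyLoss(weight=weights, reduction='sum')
total_samples = batch_size * n_grad_accum_batches  # assumes full, evenly sized batches

for batch_idx, (x, y_true) in enumerate(loader):
    y_pred = model(x)
    loss = loss_func(y_pred, y_true) / total_samples
    loss.backward()

    if (batch_idx + 1) % n_grad_accum_batches == 0:
        optimizer.step()
        optimizer.zero_grad()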

Possible solution

A straightforward solution -- and, at least to me, the most intuitive one -- would be to (a sketch of the proposed behavior follows the list):

  1. make reduction='mean' perform an unweighted average
  2. introduce reduction='weighted_mean' for weighted averages (the current behavior of 'mean')
  3. default to the unweighted mean, which is probably what users expect
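To make the proposal concrete, here is a self-contained numerical sketch (made-up tensors, and unweighted_mean_loss is just an illustrative helper) showing why an unweighted divisor makes gradient accumulation exact: the average of the per-micro-batch losses equals the single large-batch loss, which is not true for the current weighted mean.

import torch
import torch.nn as nn

torch.manual_seed(0)
weights = torch.tensor([1.0, 2.0, 0.5])
logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))

def unweighted_mean_loss(lo, ta):
    # emulates the proposed reduction='mean': sum(w[y_i]*l_i) / N
    return nn.CrossEntropyLoss(weight=weights, reduction='none')(lo, ta).mean()

full_batch = unweighted_mean_loss(logits, targets)
accumulated = sum(unweighted_mean_loss(lo, ta)
                  for lo, ta in zip(logits.chunk(2), targets.chunk(2))) / 2

print(torch.allclose(full_batch, accumulated))  # True: micro-batch averaging is exact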

Alternatives

No response

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345

Labels

module: loss (Problem is related to loss function), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
