
Conversation

@gagika (Collaborator) commented Jul 26, 2025

Description

Fix the Llama4 attention FLOPs calculation so that chunked-attention layers only count FLOPs over the chunk attention window, rather than over the full sequence length.

Tests

Comparing total_tflops, learnable_weight_tflops, and attention_tflops before and after the fix:

https://diff.googleplex.com/#key=wo0bVmy9wbUc

  • learnable_weight_tflops match before and after the fix.
  • attention_tflops match for context lengths <= 8192.
  • For context length 2 * 8192, the old attention FLOPs increased 4x relative to 8192; with the fix they increase 10/4x, because the 3 chunked-attention layers each scale 2x while the 1 global-attention layer scales 4x, giving (3 * 2 + 1 * 4) / 4 = 10/4. Concretely, 1187 * 10/4 ≈ 2968 (see the sketch after this list).
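
As a sanity check on the 10/4 factor, here is a small standalone sketch. The layer pattern (3 chunked-attention layers per 1 global-attention layer) and the 8192 chunk size are taken from the numbers above; the helper names and the other shape values are illustrative assumptions, not MaxText code.

def global_attention_flops(batch, seq_len, num_heads, head_dim):
  # Full non-causal attention: QK^T and attention-times-V each cost
  # 2 * batch * seq_len**2 * num_heads * head_dim FLOPs.
  return 4 * batch * seq_len**2 * num_heads * head_dim

def chunked_attention_flops(batch, seq_len, chunk_size, num_heads, head_dim):
  # Each chunk attends only within itself, so the quadratic term is
  # per-chunk (num_chunks * chunk_size**2) instead of seq_len**2.
  # Assumes seq_len is a multiple of chunk_size for simplicity.
  num_chunks = seq_len // chunk_size
  return 4 * batch * num_chunks * chunk_size**2 * num_heads * head_dim

batch, heads, head_dim, chunk = 1, 8, 128, 8192

def attention_flops(seq_len):
  # Assumed pattern: 3 chunked-attention layers for every 1 global-attention layer.
  chunked = chunked_attention_flops(batch, seq_len, chunk, heads, head_dim)
  global_ = global_attention_flops(batch, seq_len, heads, head_dim)
  return 3 * chunked + 1 * global_

print(attention_flops(2 * 8192) / attention_flops(8192))  # -> 2.5, i.e. 10/4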

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@RissyRan (Collaborator) left a comment:

LGTM

# FLOPs for a single global attention layer (full attention, non-causal)
global_attention_flops_per_layer = 4 * config.per_device_batch_size * seq_len**2 * config.num_query_heads * config.head_dim

# FLOPs for a single chunked attention layer (non-causal)
Review comment (Collaborator):

Can we separate out chunked attention flops into its own method?

@gagika (Author) replied:
done
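
For readers following the thread, a minimal sketch of what a separate chunked-attention FLOPs helper could look like, mirroring the global-attention formula quoted above. The function name, signature, and the handling of a partial final chunk are assumptions for illustration, not the code actually merged in this PR.

def get_chunked_attention_flops_per_layer(config, seq_len, chunk_size):
  # FLOPs for a single chunked attention layer (non-causal): the quadratic
  # cost applies per window of chunk_size tokens rather than over the
  # full sequence, with a smaller quadratic term for any partial final chunk.
  full_chunks, remainder = divmod(seq_len, chunk_size)
  window_sq = full_chunks * chunk_size**2 + remainder**2
  return 4 * config.per_device_batch_size * window_sq * config.num_query_heads * config.head_dim

When seq_len <= chunk_size this reduces to the global-attention formula, which matches the observation above that attention_tflops agree for context lengths <= 8192.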

@gobbleturk (Collaborator) left a comment:

Thanks Gagik!
