Conversation

@2015aroras (Contributor) commented Sep 9, 2025

Purpose

This PR adds the implementation for the upcoming Olmo 3 model. Since the HF implementation is being added concurrently, this PR includes the model config too.

Test Plan

The test plan is to check that basic generation (via examples/offline_inference/basic/generate.py) produces sensible output. I cannot run an HF vs vLLM comparison in a shareable manner because the HF implementation is being added concurrently. Nevertheless, I used a custom script to compare HF and vLLM and saw only minor numerical errors (which would eventually propagate and grow larger), with identical generated output.
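
For reference, a minimal sketch of what that check amounts to, using the public vLLM API (mirroring examples/offline_inference/basic/generate.py; the model path is the one registered later in this PR, and the prompts and sampling settings are illustrative assumptions, not the exact script contents):

from vllm import LLM, SamplingParams

# Illustrative sketch of the basic generation check.
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)  # assumed settings
llm = LLM(model="shanearora/2025-sep-a-base-model")
for output in llm.generate(prompts, sampling_params):
    print(f"Prompt: {output.prompt!r}")
    print(f"Generated text: {output.outputs[0].text!r}")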

Test Result

Result of running examples/offline_inference/basic/generate.py:

--------------------------------------------------
Prompt: 'Hello, my name is'
Generated text: ' Helen and I downloaded naproxen to treat my symptoms of acute pain. I'
--------------------------------------------------
Prompt: 'The president of the United States is'
Generated text: " said to be the 'first among equals' of these individuals, although most congress"
--------------------------------------------------
Prompt: 'The capital of France is'
Generated text: ' Paris. It is the capital city and the administrative capital of France.\nIt is'
--------------------------------------------------
Prompt: 'The future of AI is'
Generated text: ' an exciting one. One of the most significant opportunities in content creation is the use'

Excerpt of the diff between HF and vLLM activations (produced by the custom script):

Input: San Francisco is 1 of 5
vLLM output:  cities in the U.S. with a population of over 1 million. It
HF output:  cities in the U.S. with a population of over 1 million. It
vLLM and HF output are the same!
Coord diff abs mean for key model.embed_tokens|input (HF/vllm 0/0) 0.0
Coord diff abs mean for key model.embed_tokens|output (HF/vllm 1/1) 0.0
Coord diff abs mean for key model.layers.0.self_attn.q_proj|input (HF/vllm 2/2) 0.0
Coord diff abs mean for key model.layers.0.self_attn.q_proj|output (HF/vllm 3/3) 0.0
Coord diff abs mean for key model.layers.0.self_attn.q_norm|input (HF/vllm 4/8) 0.0
Coord diff abs mean for key model.layers.0.self_attn.q_norm|output (HF/vllm 5/9) 0.00022692049969919026
Coord diff abs mean for key model.layers.0.self_attn.k_proj|input (HF/vllm 6/4) 0.0
Coord diff abs mean for key model.layers.0.self_attn.k_proj|output (HF/vllm 7/5) 0.0
Coord diff abs mean for key model.layers.0.self_attn.k_norm|input (HF/vllm 8/10) 0.0
Coord diff abs mean for key model.layers.0.self_attn.k_norm|output (HF/vllm 9/11) 0.00021613968419842422
Coord diff abs mean for key model.layers.0.self_attn.v_proj|input (HF/vllm 10/6) 0.0
Coord diff abs mean for key model.layers.0.self_attn.v_proj|output (HF/vllm 11/7) 0.0
Coord diff abs mean for key model.layers.0.self_attn.o_proj|input (HF/vllm 12/18) 7.689034100621939e-05
Coord diff abs mean for key model.layers.0.self_attn.o_proj|output (HF/vllm 13/19) 0.0003829110355582088
Coord diff abs mean for key model.layers.0.self_attn|input (HF/vllm 14/20) 0.0
Coord diff abs mean for key model.layers.0.self_attn|output (HF/vllm 15/21) 0.0003829110355582088
Coord diff abs mean for key model.layers.0.post_attention_layernorm|input (HF/vllm 16/22) 0.0003829110355582088
Coord diff abs mean for key model.layers.0.post_attention_layernorm|output (HF/vllm 17/23) 3.728982846951112e-05
Coord diff abs mean for key model.layers.0.mlp.gate_proj|input (HF/vllm 18/24) 3.6958059354219586e-05
Coord diff abs mean for key model.layers.0.mlp.gate_proj|output (HF/vllm 19/25) 0.00031039363238960505
Shape mismatch for key model.layers.0.mlp.act_fn|input, hf torch.Size([8, 11008]) vllm torch.Size([8, 22016])
Coord diff abs mean for key model.layers.0.mlp.act_fn|input (HF/vllm 20/28) 0.00031039363238960505
Coord diff abs mean for key model.layers.0.mlp.act_fn|output (HF/vllm 21/29) 0.04843476042151451
Coord diff abs mean for key model.layers.0.mlp.up_proj|input (HF/vllm 22/26) 3.6958059354219586e-05
Coord diff abs mean for key model.layers.0.mlp.up_proj|output (HF/vllm 23/27) 0.00015152627020142972
Coord diff abs mean for key model.layers.0.mlp.down_proj|input (HF/vllm 24/30) 1.5119816453079693e-05
Coord diff abs mean for key model.layers.0.mlp.down_proj|output (HF/vllm 25/31) 7.36563524696976e-05
Coord diff abs mean for key model.layers.0.mlp|input (HF/vllm 26/32) 3.6958059354219586e-05
Coord diff abs mean for key model.layers.0.mlp|output (HF/vllm 27/33) 7.36563524696976e-05
Coord diff abs mean for key model.layers.0.post_feedforward_layernorm|input (HF/vllm 28/34) 7.36563524696976e-05
Coord diff abs mean for key model.layers.0.post_feedforward_layernorm|output (HF/vllm 29/35) 5.175209662411362e-05
Coord diff abs mean for key model.layers.0|input (HF/vllm 30/36) 0.0
Coord diff abs mean for key model.layers.0|output (HF/vllm 31/37) 7.400968024739996e-05
Coord diff abs mean for key model.layers.1.self_attn.q_proj|input (HF/vllm 32/38) 7.400968024739996e-05
Coord diff abs mean for key model.layers.1.self_attn.q_proj|output (HF/vllm 33/39) 0.00011392683518351987
...
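
The comparison script itself isn't shared in the PR; a minimal sketch of the underlying idea, assuming activations are captured on both sides with PyTorch forward hooks and compared by mean absolute element-wise difference (all names are illustrative):

import torch

def attach_capture_hooks(model, store):
    # Record each module's output tensor under its qualified name.
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out):
                store[f"{name}|output"] = out.detach().float().cpu()
        handles.append(module.register_forward_hook(hook))
    return handles

def report_coord_diffs(hf_store, vllm_store):
    # Print the mean absolute difference for every key captured on both sides.
    for key, hf_act in hf_store.items():
        vllm_act = vllm_store.get(key)
        if vllm_act is None:
            continue
        if hf_act.shape != vllm_act.shape:
            print(f"Shape mismatch for key {key}, hf {hf_act.shape} vllm {vllm_act.shape}")
            continue
        print(f"Coord diff abs mean for key {key} {(hf_act - vllm_act).abs().mean().item()}")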

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing the test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the documentation and new-model labels Sep 9, 2025
@2015aroras 2015aroras marked this pull request as ready for review September 9, 2025 21:11
Comment on lines 116 to 141
layer_idx = extract_layer_index(prefix)
# Per-layer sliding window: only layers marked "sliding_attention" use it.
sliding_window = (self.config.sliding_window
                  if self.config.layer_types[layer_idx] == "sliding_attention"
                  else None)
self.attn = Attention(
    self.num_heads,
    self.head_dim,
    self.scaling,
    num_kv_heads=self.num_kv_heads,
    cache_config=vllm_config.cache_config,
    quant_config=vllm_config.quant_config,
    per_layer_sliding_window=sliding_window,
    prefix=f"{prefix}.attn",
)

# Rotary embeddings. Rope scaling is only applied on full attention
# layers.
self.rope_scaling = (self.config.rope_scaling
                     if sliding_window is None else None)
self.rotary_emb = get_rope(
    self.head_dim,
    rotary_dim=self.head_dim,
    max_position=self.max_position_embeddings,
    base=self.rope_theta,  # type: ignore
    rope_scaling=self.rope_scaling,
)
Member:

Please correct me if I'm wrong, but it seems the only difference between Olmo2 and Olmo3 is the introduction of the sliding window? If so, I think we can simply modify Olmo2's attention implementation to make it fit both Olmo2 and Olmo3.

Contributor Author:

Yep that's right. I'll try to merge the Olmo3 logic into the existing Olmo2 code.

Contributor Author:

In 61b1b26 I've updated the Olmo2 logic to support both Olmo2 and Olmo3. The sliding window settings are new, so I've kept the Olmo3Config class.

The transformers PR is out (huggingface/transformers#40778), but the transformers folks haven't yet reviewed it or indicated whether they also want Olmo3 to be part of Olmo2. This vLLM implementation should hopefully be compatible with the transformers implementation regardless of which option they choose.
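
For illustration, a hypothetical sketch of the config surface the attention code above reads; the field names (sliding_window, layer_types, rope_scaling) come from the code excerpt, but the class shape and defaults are assumptions, not the confirmed Olmo3Config API:

from transformers.configuration_utils import PretrainedConfig

class Olmo3Config(PretrainedConfig):  # assumed base class, for illustration
    model_type = "olmo3"

    def __init__(self,
                 sliding_window=4096,  # assumed default window size
                 layer_types=None,  # e.g. "sliding_attention" vs. full attention, per layer
                 rope_scaling=None,  # applied only on full-attention layers
                 **kwargs):
        self.sliding_window = sliding_window
        self.layer_types = layer_types or []
        self.rope_scaling = rope_scaling
        super().__init__(**kwargs)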

Contributor Author:

They're content with Olmo3 being a separate model implementation from Olmo2.

Member:

I think how Transformers organizes the modeling file for Olmo3 won't be an issue for us (they may request using modular modeling to inherit from Olmo2 to create a separate class), as long as the config doesn't change very much.
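
For context, a rough sketch of the modular-modeling pattern being referred to, where a modular_*.py file derives Olmo3 classes from Olmo2 and overrides only what differs; the Olmo3 class names are assumptions for illustration:

from transformers.models.olmo2.modeling_olmo2 import (Olmo2Attention,
                                                      Olmo2DecoderLayer)

class Olmo3Attention(Olmo2Attention):
    # Override only what differs, e.g. per-layer sliding-window handling.
    pass

class Olmo3DecoderLayer(Olmo2DecoderLayer):
    pass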

@Isotr0py (Member) left a comment:

LGTM!

Member:

We can keep the config fork temporarily until the Transformers PR is merged, then clean it up after that.

Contributor Author:

Yep, that's what I had in mind! My local testing indicates that this PR works with transformers versions both with and without Olmo3.
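
A minimal sketch of how that kind of dual compatibility is typically achieved, preferring the upstream config once transformers ships it and falling back to the vendored fork otherwise; the import paths are illustrative assumptions, not necessarily what this PR does:

# Hypothetical sketch: use upstream Olmo3Config when the installed
# transformers version provides it, else fall back to vLLM's config fork.
try:
    from transformers import Olmo3Config
except ImportError:
    from vllm.transformers_utils.configs import Olmo3Config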

@Isotr0py Isotr0py enabled auto-merge (squash) September 11, 2025 08:31
@Isotr0py Isotr0py disabled auto-merge September 11, 2025 08:31
@Isotr0py Isotr0py enabled auto-merge (squash) September 11, 2025 08:31
@github-actions github-actions bot added the ready label Sep 11, 2025
trust_remote_code=True),
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
"Olmo3ForCausalLM": _HfExamplesInfo("shanearora/2025-sep-a-base-model"),
@DarkLight1337 (Member) commented Sep 11, 2025:

Need to add is_available_online=False (if the repo isn't available yet) and/or min_transformers_version (if the model isn't supported by the current version)
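
For illustration, the suggested flags on the registry entry would look something like this (the version string is a placeholder, not a confirmed requirement):

"Olmo3ForCausalLM": _HfExamplesInfo(
    "shanearora/2025-sep-a-base-model",
    is_available_online=False,  # only if the repo isn't public yet
    min_transformers_version="4.57",  # placeholder version, not confirmed
),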

Contributor Author:

In 9542485: the problem was that I put olmo3 instead of olmo2 in the registry, which I've fixed in that commit. The above two solutions don't apply, since the repo is public and this implementation is intended to work even before the transformers implementation is released.

auto-merge was automatically disabled September 12, 2025 17:23

Head branch was pushed to by a user without write access

@Isotr0py Isotr0py enabled auto-merge (squash) September 13, 2025 02:07
@Isotr0py Isotr0py merged commit 89e08d6 into vllm-project:main Sep 13, 2025
50 checks passed
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Sep 14, 2025
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request Sep 15, 2025
Signed-off-by: Shane A <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: bbartels <[email protected]>
cboss6 pushed a commit to cboss6/vllm that referenced this pull request Sep 16, 2025
Signed-off-by: Shane A <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: bruceszchen <[email protected]>