Replace constant tensor initialisation with different function calls #561
base: master
Conversation
Should element inits also adopt the new …?
Pull Request Overview
This PR optimizes PyTorch tensor initialization by replacing verbose `torch.tensor()` and `torch.zeros()` calls with more efficient alternatives such as `tensor.new_zeros()` and `torch.broadcast_tensors()`, which reduce overhead by reusing the device and dtype of existing tensors.
- Replace `torch.tensor(0.0, device=..., dtype=...)` with `tensor.new_zeros(())`
- Replace `torch.zeros(..., device=..., dtype=...)` with `tensor.new_zeros(shape)`
- Replace `torch.ones_like()` with `torch.broadcast_tensors()` where appropriate
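As an illustration of the first two replacements, here is a minimal, hedged sketch (the tensor `x` and the shapes are invented for this example, not taken from the PR):

```python
import torch

# Existing tensor whose device and dtype should be reused.
x = torch.randn(4, dtype=torch.float64)

# Before: device and dtype must be threaded through explicitly.
scalar_old = torch.tensor(0.0, device=x.device, dtype=x.dtype)
zeros_old = torch.zeros(3, 3, device=x.device, dtype=x.dtype)

# After: new_zeros() inherits device and dtype from x.
scalar_new = x.new_zeros(())     # 0-dimensional scalar zero
zeros_new = x.new_zeros((3, 3))  # 3x3 zero matrix

assert scalar_new.dtype == scalar_old.dtype
assert zeros_new.device == zeros_old.device
```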
Reviewed Changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 2 comments.
Summary per file:
File | Description
---|---
`cheetah/utils/bmadx.py` | Replace tensor initialization with `new_ones()`
`cheetah/track_methods.py` | Replace zero tensor creation with `new_zeros()`
`cheetah/accelerator/space_charge_kick.py` | Replace multiple `torch.zeros()` calls with `new_zeros()`
`cheetah/accelerator/sextupole.py` | Replace zero tensor constants with `new_zeros()`
`cheetah/accelerator/screen.py` | Replace `torch.zeros()` with `new_zeros()`
`cheetah/accelerator/quadrupole.py` | Replace zero constants and `torch.ones_like()` with optimized alternatives
`cheetah/accelerator/drift.py` | Replace multiple zero tensor creations with a single `new_zeros()`
`cheetah/accelerator/dipole.py` | Replace zero tensor constants and fix a tensor shape issue
`cheetah/accelerator/cavity.py` | Replace `torch.zeros_like()` and `torch.full_like()` with `new_zeros()`
`cheetah/accelerator/bpm.py` | Move buffer registration after misalignment initialization
`CONTRIBUTING.md` | Add performance guidelines for tensor creation
`CHANGELOG.md` | Document performance improvements
The review comment below is anchored on this snippet (from `cheetah/accelerator/bpm.py`, judging by the file summary):

```python
self.register_buffer(
    "reading",
    torch.tensor((torch.nan, torch.nan), **factory_kwargs),
)
```
The buffer registration should use `self.misalignment.new_full((2,), torch.nan)` instead of `torch.tensor((torch.nan, torch.nan), **factory_kwargs)` to be consistent with the PR's optimization goals and to ensure the tensor uses the same device and dtype as the other parameters.
Suggested change:

```diff
-    torch.tensor((torch.nan, torch.nan), **factory_kwargs),
+    self.misalignment.new_full((2,), torch.nan),
```
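For context on the suggestion, `Tensor.new_full()` creates a tensor of a given shape and fill value while inheriting device and dtype from the source tensor. A small sketch (the `misalignment` tensor here is a stand-in, not the module's actual attribute):

```python
import torch

misalignment = torch.zeros(2, dtype=torch.float64)  # stand-in for self.misalignment

# new_full() inherits device and dtype from misalignment,
# so no explicit factory kwargs are needed.
reading = misalignment.new_full((2,), torch.nan)

assert reading.dtype == misalignment.dtype
assert bool(torch.isnan(reading).all())
```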
This really depends on the answer to my earlier question. I'm thinking the answer is probably no, let's not do that.
Co-authored-by: Copilot <[email protected]>
Description
Sub- / replacement PR for #538.
Where possible, replace new initialisations of tensors with `another_tensor.new_xxx` calls. If the goal is to reproduce the shape of another tensor, `torch.xxx_like` is used; a sketch of both patterns follows.
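To illustrate the distinction (names and shapes here are illustrative only): `new_xxx` methods take an explicit shape while inheriting device and dtype, whereas `torch.xxx_like` copies shape, device, and dtype from its argument:

```python
import torch

ref = torch.randn(2, 5, dtype=torch.float32)

# new_xxx: explicit shape, inherited device and dtype.
bias = ref.new_zeros((5,))    # shape (5,), float32
scale = ref.new_ones((2, 1))  # shape (2, 1), float32

# xxx_like: shape, device, and dtype all copied from ref.
mask = torch.zeros_like(ref)  # shape (2, 5), float32

assert mask.shape == ref.shape and bias.dtype == ref.dtype
```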
Motivation and Context
Types of changes
Checklist
- Code passes `flake8` (required).
- All `pytest` tests pass (required).
- Ran `pytest` on a machine with a CUDA GPU and made sure all tests pass (required).

Note: We are using a maximum length of 88 characters per line.