Skip to content

Commit e908554

Browse files
ChromeHeartsOrbax Authors
authored andcommitted
Check JAX version when using jax.layout.Format.
PiperOrigin-RevId: 778111564
1 parent c41e83e commit e908554

File tree

8 files changed

+40
-13
lines changed

8 files changed

+40
-13
lines changed

checkpoint/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.11.18] - 2025-07-01
11+
12+
### Changed
13+
14+
- For JAX>=0.6.2, JAX layout.Layout renamed to layout.Format
15+
1016
## [0.11.17] - 2025-06-30
1117

1218
### Added

checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@
4242
DLL = layout.Layout
4343
else:
4444
DLL = layout.DeviceLocalLayout # type: ignore
45-
Format = layout.Format
45+
if jax.__version_info__ >= (0, 6, 2):
46+
Format = layout.Format
47+
else:
48+
Format = layout.Layout
4649
PyTree = Any
4750
SaveArgs = type_handlers.SaveArgs
4851
StandardRestoreArgs = standard_checkpoint_handler.StandardRestoreArgs
@@ -175,7 +178,7 @@ def test_custom_layout(self):
175178
if jax.__version_info__ >= (0, 6, 3)
176179
else arr.format.device_local_layout # type: ignore
177180
)
178-
custom_layout = Format(
181+
custom_layout = Format( # pytype: disable=wrong-keyword-args
179182
DLL(
180183
major_to_minor=arr_layout.major_to_minor[::-1], # pytype: disable=attribute-error
181184
_tiling=arr_layout._tiling, # pytype: disable=attribute-error

checkpoint/orbax/checkpoint/_src/serialization/serialization.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@
4242
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
4343

4444
Index = types.Index
45-
Format = layout.Format
45+
if jax.__version_info__ >= (0, 6, 2):
46+
Format = layout.Format
47+
else:
48+
Format = layout.Layout
4649
Shape = types.Shape
4750

4851

@@ -457,7 +460,7 @@ async def _read_array_index_and_device_put(
457460
sharding = jax.sharding.SingleDeviceSharding(
458461
device, memory_kind=memory_kind
459462
)
460-
result.append(jax.device_put(shard, Format(dll, sharding)))
463+
result.append(jax.device_put(shard, Format(dll, sharding))) # pytype: disable=wrong-arg-types
461464
return result
462465

463466

@@ -508,7 +511,7 @@ async def read_and_create_array(
508511

509512

510513
async def async_deserialize(
511-
user_sharding: jax.sharding.Sharding | Format,
514+
user_sharding: jax.sharding.Sharding | Format, # pytype: disable=wrong-arg-types # pytype: disable=unsupported-operands
512515
tensorstore_spec: Union[ts.Spec, Dict[str, Any]],
513516
global_shape: Optional[Shape] = None,
514517
dtype: Optional[jnp.dtype] = None,

checkpoint/orbax/checkpoint/_src/serialization/serialization_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@
4141
DLL = layout.Layout
4242
else:
4343
DLL = layout.DeviceLocalLayout # type: ignore
44-
Format = layout.Format
45-
44+
if jax.__version_info__ >= (0, 6, 2):
45+
Format = layout.Format
46+
else:
47+
Format = layout.Layout
4648
jax.config.update('jax_enable_x64', True)
4749

4850

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@
6262
PLACEHOLDER = ...
6363
PLACEHOLDER_TYPESTR = 'placeholder'
6464

65-
Format = layout.Format
65+
if jax.__version_info__ >= (0, 6, 2):
66+
Format = layout.Format
67+
else:
68+
Format = layout.Layout
6669
Shape = arrays_types.Shape
6770
Scalar = Union[int, float, np.number]
6871
NamedSharding = jax.sharding.NamedSharding
@@ -812,9 +815,11 @@ class ArrayRestoreArgs(RestoreArgs):
812815
restore_type: Optional[Any] = jax.Array
813816
mesh: Optional[jax.sharding.Mesh] = None
814817
mesh_axes: Optional[jax.sharding.PartitionSpec] = None
815-
sharding: Optional[Union[jax.sharding.Sharding, ShardingMetadata, Format]] = (
818+
# pyformat: disable
819+
sharding: Optional[Union[jax.sharding.Sharding, ShardingMetadata, Format]] = ( # type: ignore[invalid-annotation]
816820
None
817821
)
822+
# pyformat: enable
818823
global_shape: Optional[Tuple[int, ...]] = None
819824
shape: Optional[Tuple[int, ...]] = None
820825
strict: bool = True

checkpoint/orbax/checkpoint/checkpoint_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
PyTree = Any
3636
STANDARD_ARRAY_TYPES = (int, float, np.ndarray, jax.Array)
3737
_SNAPSHOTS = '_SNAPSHOTS'
38-
Format = layout.Format
38+
if jax.__version_info__ >= (0, 6, 2):
39+
Format = layout.Format
40+
else:
41+
Format = layout.Layout
3942
PLACEHOLDER = type_handlers.PLACEHOLDER
4043

4144

@@ -428,7 +431,7 @@ def construct_restore_args(
428431

429432
def _array_restore_args(
430433
value: Any,
431-
sharding: Optional[jax.sharding.Sharding | Format],
434+
sharding: Optional[jax.sharding.Sharding | Format], # pytype: disable=unsupported-operands
432435
dtype: Optional[np.dtype] = None,
433436
) -> type_handlers.ArrayRestoreArgs:
434437
return type_handlers.ArrayRestoreArgs(

checkpoint/orbax/checkpoint/experimental/v1/_src/serialization/array_leaf_handler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@
3939
ArrayDeserializationParam = types.DeserializationParam["AbstractArray"]
4040
Shape = arrays_types_v0.Shape
4141

42+
if jax.__version_info__ >= (0, 6, 2):
43+
Format = jax_layout.Format
44+
else:
45+
Format = jax_layout.Layout
46+
4247

4348
class AbstractArray(Protocol):
4449
"""Abstract representation of an array.
@@ -63,7 +68,7 @@ class AbstractArray(Protocol):
6368

6469
shape: Shape | None
6570
dtype: jax.numpy.dtype | None
66-
sharding: jax.sharding.Sharding | jax_layout.Format | None
71+
sharding: jax.sharding.Sharding | Format | None # pytype: disable=unsupported-operands
6772

6873

6974
@dataclasses.dataclass

checkpoint/orbax/checkpoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# A new PyPI release will be pushed everytime `__version__` is increased.
1818
# Also modify version and date in CHANGELOG.
1919
# LINT.IfChange
20-
__version__ = '0.11.17'
20+
__version__ = '0.11.18'
2121
# LINT.ThenChange(//depot//orbax/checkpoint/CHANGELOG.md)
2222

2323

0 commit comments

Comments
 (0)