
Commit eb1fb3a

jeffdaily and msaroufim authored
[ROCm] use dataclass for fnuz type setting (#1142)
* [ROCm] use dataclass for fnuz type setting
* update float8 tests to use type alias
* fix lint

Co-authored-by: Mark Saroufim <[email protected]>
1 parent c8f1174 commit eb1fb3a

4 files changed: +55, −38 lines


test/float8/test_base.py

Lines changed: 14 additions & 14 deletions
@@ -24,8 +24,8 @@


 from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
+    CastConfig,
+    Float8LinearConfig,
     ScalingGranularity,
     ScalingType,
     Float8LinearRecipeName,
@@ -109,15 +109,15 @@ def test_split_cat(self):

     def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
+        scale_a = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

         index = torch.randint(0, 15, (16,), dtype=torch.long)

         b = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
-        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
-        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
+        scale_b = tensor_to_scale(b, e4m3_dtype)
+        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
+        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)

         with pytest.raises(AssertionError):
             b[index] = fp8_a
@@ -127,8 +127,8 @@ def test_index_put(self):

     def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
+        scale_a = tensor_to_scale(a, e4m3_dtype)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)

         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -137,7 +137,7 @@ def test_copy_(self):
             fp8_a.copy_(b)  # Should fail

         fp8_b = Float8Tensor(
-            torch.empty(16, dtype=torch.float8_e4m3fn),
+            torch.empty(16, dtype=e4m3_dtype),
             scale_a,
             torch.bfloat16,
             fp8_a._linear_mm_config,
@@ -332,11 +332,11 @@ def _test_linear_impl(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
-        "scaling_type_input",
+        "scaling_type_input",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
-        "scaling_type_weight",
+        "scaling_type_weight",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
@@ -377,7 +377,7 @@ def test_linear_from_config_params(
     # to combine with the main testing function.
     # TODO(future PR): make this cleaner.
     @pytest.mark.parametrize(
-        "recipe_name",
+        "recipe_name",
         [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@@ -610,7 +610,7 @@ def test_different_configs_error(self):
     @pytest.mark.parametrize("use_fast_accum", [True, False])
     def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
-        input_dtype = torch.float8_e4m3fn
+        input_dtype = e4m3_dtype
         compare_type = torch.float32

         a = torch.randn(16, 41, device="cuda", dtype=base_dtype)

test/float8/test_compile.py

Lines changed: 7 additions & 7 deletions
@@ -20,9 +20,9 @@
 import torch
 import torch.nn as nn
 from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
-    ScalingType,
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
     Float8LinearRecipeName,
     recipe_name_to_linear_config,
 )
@@ -77,7 +77,7 @@ def _test_compile_base(
     y_fp8.sum().backward()
     y_ref = m_ref(x_ref)
     y_ref.sum().backward()
-    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
+    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
     # tolerance
     torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2)
     torch.testing.assert_close(
@@ -199,7 +199,7 @@ def test_inductor_from_config_params(
 # to combine with the main testing function.
 # TODO(future PR): make this cleaner.
 @pytest.mark.parametrize(
-    "recipe_name",
+    "recipe_name",
     [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
 )
 @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
@@ -412,14 +412,14 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
     )
     float8_eager = hp_tensor_to_float8_dynamic(
         hp_tensor1,
-        torch.float8_e4m3fn,
+        e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
-        torch.float8_e4m3fn,
+        e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )

torchao/float8/config.py

Lines changed: 30 additions & 13 deletions
@@ -96,6 +96,29 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


+@dataclass
+class Float8TypeConfig:
+    """
+    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
+
+    Currently, ROCm only supports fnuz variants.
+    """
+
+    # The preferred e4m3 type.
+    e4m3_dtype = torch.float8_e4m3fn
+
+    # The preferred e5m2 type.
+    e5m2_dtype = torch.float8_e5m2
+
+    def __post_init__(self):
+        if torch.version.hip:
+            prop = torch.cuda.get_device_properties(0)
+            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
+            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
+                self.e4m3_dtype = torch.float8_e4m3fnuz
+                self.e5m2_dtype = torch.float8_e5m2fnuz
+
+
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -118,11 +141,11 @@ class Float8LinearConfig:
     # Per-tensor configuration for casting of `input`, `weight`, `grad_output`
     # for the operands of gemms calculating `output`, `grad_weight`, and `grad_input`.
     #
-    # Note:
-    # 1. if `cast_config_input_for_grad_weight` is None, then
+    # Note:
+    # 1. if `cast_config_input_for_grad_weight` is None, then
     #    `cast_config_input` is used for scaling `input` for both gemms that
-    #    use `input.
-    # 2. if `cast_config_input_for_grad_weight` is specified, then
+    #    use `input.
+    # 2. if `cast_config_input_for_grad_weight` is specified, then
     #    a. `cast_config_input` is used for scaling `input` for the gemm that calculates
     #       `output`
     #    b. `cast_config_input_for_grad_weight` is used for scaling `input` for
@@ -240,12 +263,6 @@ def __post_init__(self):
                 f"incompatible operand precision for {gemm_name}"


-# If True, use 'fnuz' float8 types for calculations.
-# Currently, ROCm only supports fnuz variants.
-# TODO(future PR): move this to Float8LinearConfig
-use_fnuz_dtype = False
-
-
 # Pre-made recipes for common configurations
 # TODO(future PR): go through a round of design on this, and eventually expose
 # as a top level public API.
@@ -272,7 +289,7 @@ def recipe_name_to_linear_config(
         cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-
+
         # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
         # fast with `use_fast_accum=True`. Note that rowwise scaling is more
         # accurate than tensorwise scaling, so the overall impact on accuracy
@@ -300,8 +317,8 @@ def recipe_name_to_linear_config(
     #
     # key characteristics:
     # * increased accuracy for grad_weight
-    # * `input`, `weight` and `grad_output` now only need to be scaled
-    #   axiswise across a single dim compared to vanilla all-axiswise,
+    # * `input`, `weight` and `grad_output` now only need to be scaled
+    #   axiswise across a single dim compared to vanilla all-axiswise,
     #   which is more amenable to fast kernels

     # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
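
A quick way to see which float8 pair the new dataclass resolves to on a given machine is sketched below. It uses only the names introduced in the hunk above; the comment about non-MI300 ROCm parts keeping the fn defaults is an inference from the gfx940/941/942 check in __post_init__, not something stated elsewhere in this commit.

import torch
from torchao.float8.config import Float8TypeConfig

tc = Float8TypeConfig()  # __post_init__ runs the ROCm / MI300 detection once
print("e4m3:", tc.e4m3_dtype)  # torch.float8_e4m3fn, or torch.float8_e4m3fnuz on MI300
print("e5m2:", tc.e5m2_dtype)  # torch.float8_e5m2, or torch.float8_e5m2fnuz on MI300

if torch.version.hip:
    # Only gfx940/gfx941/gfx942 switch to the fnuz variants; other ROCm
    # architectures keep the fn defaults under this check.
    print("arch:", torch.cuda.get_device_properties(0).gcnArchName)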

torchao/float8/float8_utils.py

Lines changed: 4 additions & 4 deletions
@@ -9,8 +9,7 @@
 import torch
 import torch.distributed as dist

-import torchao.float8.config as config
-from torchao.float8.config import ScalingGranularity
+from torchao.float8.config import Float8TypeConfig, ScalingGranularity

 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -29,8 +28,9 @@


 # User defined type for using the individual F8 type based on config
-e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
-e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
+type_config = Float8TypeConfig()
+e4m3_dtype = type_config.e4m3_dtype
+e5m2_dtype = type_config.e5m2_dtype


 @torch.no_grad()
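
The test changes above depend on these module-level aliases instead of hard-coding torch.float8_e4m3fn. A minimal sketch of that pattern follows; it assumes the alias and tensor_to_scale are imported from torchao.float8.float8_utils (the tests' import lines are not part of this diff), and the final cast is a simplified stand-in for the hp_tensor_and_scale_to_float8 helper the tests actually call.

import torch
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale

a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, e4m3_dtype)  # same call whether e4m3_dtype is fn or fnuz
fp8_a = (a * scale_a).to(e4m3_dtype)      # simplified cast, for illustration only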
