Commit 81821e0

Revert "Fixed issue with dot_product_attention when using TPU. (keras-team#21254)" (keras-team#21329)
This reverts commit d8f3f70.
1 parent 86638c2 · commit 81821e0

File tree

  • keras/src/backend/jax/nn.py

1 file changed: +42 additions, -186 deletions


keras/src/backend/jax/nn.py

Lines changed: 42 additions & 186 deletions
@@ -1126,17 +1126,16 @@ def wrap_flash_attention(
     decoder_segment_ids,
     custom_mask=None,
     attn_logits_soft_cap=None,
-    head_shards=1,
-    q_seq_shards=1,
 ):
     if decoder_segment_ids is not None:
         assert query.shape[2] == decoder_segment_ids.q.shape[1], (
-            "Sharding along sequence dimension not allowed"
-            " in TPU kernel attention"
+            "Sharding along sequence dimension not allowed in tpu kernel "
+            "attention"
         )
 
     if custom_mask is not None:
         mask = splash_attention_mask.NumpyMask(array=custom_mask)
+
     else:
         mask = splash_attention_mask.CausalMask(
             shape=(query.shape[2], query.shape[2])
@@ -1148,8 +1147,8 @@ def wrap_flash_attention(
     )
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
-        head_shards=head_shards,
-        q_seq_shards=q_seq_shards,
+        head_shards=1,
+        q_seq_shards=1,
         attn_logits_soft_cap=attn_logits_soft_cap,
     )
 
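Note (not part of the diff): the two arguments removed above are how `make_splash_mha` is told whether the attention heads and the query sequence are sharded across devices; after this revert the kernel is always built as if neither is sharded. A minimal sketch of the surrounding call, assuming the `splash_attention_kernel`/`splash_attention_mask` modules that `nn.py` imports from `jax.experimental.pallas.ops.tpu.splash_attention`, and a TPU runtime for actually executing the kernel:

```python
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

# Illustrative sizes (hypothetical, not taken from the diff).
seq_len, num_heads = 1024, 8

# One causal mask per head, stacked into a multi-head mask, mirroring the
# unchanged lines of wrap_flash_attention around this hunk.
causal = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
multi_head_mask = splash_attention_mask.MultiHeadMask(
    masks=(causal,) * num_heads
)

# After the revert the kernel is always built with head_shards=1 and
# q_seq_shards=1, i.e. no head or query-sequence sharding is assumed.
splash_kernel = splash_attention_kernel.make_splash_mha(
    mask=multi_head_mask,
    head_shards=1,
    q_seq_shards=1,
)
# wrap_flash_attention then applies splash_kernel to query/key/value in the
# (batch, heads, length, head_dim) layout; running it requires a TPU.
```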
@@ -1169,38 +1168,6 @@ def dot_product_attention(
     flash_attention=None,
     attn_logits_soft_cap=None,
 ):
-    """Computes dot-product attention given query, key, and value.
-
-    This is the core computation of attention that is used in transformers.
-    For TPU platforms, flash attention optimizations are automatically applied
-    when possible, and sharding parameters are inferred from the layout map
-    in the current distribution context.
-
-    Args:
-        query: Queries with shape `[batch, time, heads,
-            depth_k]`.
-        key: Keys with shape `[batch, time, heads,
-            depth_k]`.
-        value: Values with shape `[batch, time, heads,
-            depth_v]`.
-        bias: Optional bias with shape broadcastable to
-            `[batch, heads, dest_time, source_time]`.
-        mask: Optional mask with shape broadcastable to
-            `[batch, heads, dest_time, source_time]`.
-        scale: Float. Optional scale that is applied to the attention
-            computation.
-        is_causal: Boolean. Specifying whether causal masking is applied.
-        flash_attention: Boolean. Whether to use flash attention optimization
-            for increased performance. Default to None, which means it will
-            be auto-determined based on the platform, input shapes and
-            compatibility.
-        attn_logits_soft_cap: Float. Optional float to softly cap attention
-            logits to avoid numerical stability issues. Applied as:
-            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.
-
-    Returns:
-        JAX Array of shape `[batch, time, heads, depth_v]`.
-    """
     query = convert_to_tensor(query)
     key = convert_to_tensor(key)
     value = convert_to_tensor(value)
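Note (not part of the diff): the deleted docstring is the only place the expected layouts are spelled out: `query`, `key` and `value` are `[batch, time, heads, depth]`, and `bias`/`mask` broadcast to `[batch, heads, dest_time, source_time]`. A minimal sketch of those shapes flowing through `jax.nn.dot_product_attention`, the non-TPU path this function falls through to later in the diff (sizes are illustrative):

```python
import jax
import jax.numpy as jnp

batch, time, heads, depth = 2, 16, 4, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)

# Layout documented by the removed docstring: [batch, time, heads, depth].
query = jax.random.normal(kq, (batch, time, heads, depth))
key = jax.random.normal(kk, (batch, time, heads, depth))
value = jax.random.normal(kv, (batch, time, heads, depth))

# Plain XLA implementation; "cudnn" is what the backend requests instead
# when flash attention is enabled on GPU (requires jax>=0.4.31).
out = jax.nn.dot_product_attention(
    query,
    key,
    value,
    scale=1.0 / jnp.sqrt(depth),
    is_causal=True,
    implementation="xla",
)
print(out.shape)  # (2, 16, 4, 32) -> [batch, time, heads, depth_v]
```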
@@ -1210,155 +1177,47 @@ def dot_product_attention(
         f"Received: query.shape={query.shape}, key.shape={key.shape}, "
         f"value.shape={value.shape}."
     )
-
-    # Check platform
-    platform = jax.devices()[0].platform
-    is_tpu = platform == "tpu"
-
-    # Get sharding parameters from distribution context
-    head_shards = 1
-    q_seq_shards = 1
-
-    if is_tpu:
-        try:
-            from keras.src.distribution.distribution_lib import ModelParallel
-            from keras.src.distribution.distribution_lib import (
-                distribution as get_dist,
-            )
-
-            # Get current distribution if available
-            dist = get_dist()
-            if dist and isinstance(dist, ModelParallel):
-                mesh = dist.device_mesh
-                if "model" in mesh.axis_names:
-                    model_dim_index = mesh.axis_names.index("model")
-                    # Set head_shards based on the model dimension of the mesh
-                    head_shards = mesh.shape[model_dim_index]
-                    # Typically keep q_seq_shards=1 for best performance
-                    q_seq_shards = 1
-        except (ImportError, ValueError, AttributeError):
-            # Use default values if detection fails
-            head_shards = 1
-            q_seq_shards = 1
-
-    # Check if inputs use partial sharding (not fully replicated)
-    # Flash attention works well with fully replicated tensors on all platforms
-    # but may have issues with certain partial sharding patterns on non-TPU
-    # platforms
-    partially_sharded_inputs = any(
-        hasattr(t, "sharding") and not t.sharding.is_fully_replicated
-        for t in (query, key, value)
-    )
-
-    # Determine flash attention compatibility
     if flash_attention is None:
-        # Auto-detect flash attention availability
-        if is_tpu:
-            # TPUs have specialized hardware for attention that works with any
-            # sharding pattern
-            flash_attention = True
-        else:
-            # For GPU/CPU with partially sharded inputs, we need
-            # multiple devices to efficiently handle the sharding
-            if partially_sharded_inputs and len(jax.devices()) <= 1:
-                flash_attention = False
-            else:
-                flash_attention = _can_use_flash_attention(
-                    query, key, value, bias
-                )
-    elif flash_attention is True and not is_tpu:
-        # If flash attention is explicitly requested, validate compatibility
-        # Skip validation for TPU as it has specialized hardware support
-        try:
-            _can_use_flash_attention(query, key, value, bias, raise_error=True)
-        except Exception:
-            # Only disable flash attention on non-TPU platforms
-            # if validation fails
-            flash_attention = False
-
-    # TPU-specific flash attention path
-    if is_tpu and flash_attention:
-        # Transpose to ('batch', 'heads', 'length', 'head_dim')
-        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
-        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
-        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))
-
-        bs, num_heads, q_len, head_dim = query_tpu_layout.shape
-
-        # Apply scale to query if provided
-        if scale is not None:
-            # TPU kernel applies 1/sqrt(head_dim) internally, to achieve
-            # overall QK^T * scale, scale query by (scale * sqrt(head_dim))
-            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
-
-        # Create segment IDs for Splash Attention (for packing/batching)
-        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
-        decoder_segment_ids = splash_attention_kernel.SegmentIds(
-            q=segment_ids, kv=segment_ids
-        )
+        flash_attention = _can_use_flash_attention(query, key, value, bias)
+    elif flash_attention is True:
+        # Use `raise_error=True` to provide more details if the inputs failed to
+        # use flash attention
+        _can_use_flash_attention(query, key, value, bias, raise_error=True)
 
-        # Process mask for Splash Attention
-        custom_mask = None
-        if mask is not None:
-            mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask
-
-            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
-                custom_mask = mask_bool[0]
-            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
-                custom_mask = mask_bool[0, 0]
-
-            if is_causal and custom_mask is not None:
-                causal_mask = jnp.tril(
-                    jnp.ones((q_len, q_len), dtype=jnp.bool_)
-                )
-                custom_mask = jnp.logical_and(custom_mask, causal_mask)
-
-        if custom_mask is None and is_causal:
-            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
-
-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            flash_attention = False
+    if jax.devices()[0].platform == "tpu":
+        # Transpose to ('batch', 'heads', 'length', 'kv')
+        query = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value = jnp.transpose(value, axes=(0, 2, 1, 3))
+        B, H, S, KV = query.shape
+
+        segment_ids = jnp.ones([B, S])
+        # {token_ids, padding_mask, segment_ids} enable packing
+        out = wrap_flash_attention(
+            query,
+            key,
+            value,
+            decoder_segment_ids=splash_attention_kernel.SegmentIds(
+                segment_ids, segment_ids
+            ),
+            custom_mask=mask,
+            attn_logits_soft_cap=attn_logits_soft_cap,
+        )
+        out = jnp.transpose(out, axes=(0, 2, 1, 3))
+        return out
 
-    # JAX native dot_product_attention for GPU or fallback for TPU
+    # `dot_product_attention` is only available in jax>=0.4.31
     if hasattr(jax.nn, "dot_product_attention"):
-        try:
-            return jax.nn.dot_product_attention(
-                query,
-                key,
-                value,
-                bias=bias,
-                mask=mask,
-                scale=scale,
-                is_causal=is_causal,
-                implementation="cudnn" if flash_attention else "xla",
-            )
-        except Exception:
-            # If flash attention fails, fall back to XLA implementation
-            if flash_attention:
-                return jax.nn.dot_product_attention(
-                    query,
-                    key,
-                    value,
-                    bias=bias,
-                    mask=mask,
-                    scale=scale,
-                    is_causal=is_causal,
-                    implementation="xla",
-                )
-            raise
+        return jax.nn.dot_product_attention(
+            query,
+            key,
+            value,
+            bias=bias,
+            mask=mask,
+            scale=scale,
+            is_causal=is_causal,
+            implementation="cudnn" if flash_attention else "xla",
+        )
 
     if flash_attention:
         raise RuntimeError(
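Note (not part of the diff): before calling `wrap_flash_attention`, the reverted TPU branch only does two things: it transposes from the Keras layout `(batch, time, heads, depth)` to the kernel layout `(batch, heads, length, depth)`, and it builds an all-ones `SegmentIds` (one segment per example, i.e. no sequence packing). A sketch of just that pre-processing, runnable without a TPU since the kernel call itself is omitted:

```python
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
)

# Illustrative sizes (hypothetical, not taken from the diff).
batch, time, heads, depth = 2, 16, 4, 32
query = jnp.zeros((batch, time, heads, depth))  # Keras layout

# Transpose to ('batch', 'heads', 'length', 'head_dim') as the TPU branch does.
query_tpu = jnp.transpose(query, axes=(0, 2, 1, 3))
B, H, S, KV = query_tpu.shape

# All-ones segment ids mark every token as belonging to a single segment,
# so no packing information is passed to the splash kernel.
segment_ids = jnp.ones([B, S])
decoder_segment_ids = splash_attention_kernel.SegmentIds(
    q=segment_ids, kv=segment_ids
)

# wrap_flash_attention(query_tpu, key_tpu, value_tpu,
#     decoder_segment_ids=decoder_segment_ids, custom_mask=mask, ...)
# would run the splash kernel here; its output is transposed back to
# (batch, time, heads, depth) before being returned.
```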
@@ -1369,9 +1228,6 @@ def dot_product_attention(
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args
-
-    # Fallback to custom XLA implementation
-    # This is the reference implementation from jax.nn.dot_product_attention
     output_shape = query.shape
     _, _, K, H = key.shape
     scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
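Note (not part of the diff): the context lines above are the start of the pure-XLA fallback modeled on `jax.nn.dot_product_attention` (see the `# Ref:` link). A compressed sketch of the core computation that fallback performs; the real code additionally handles `bias`, arbitrary masks and the logit soft cap:

```python
import jax
import jax.numpy as jnp

def reference_attention(query, key, value, scale=None, is_causal=False):
    # query/key/value: [batch, time, heads, head_dim] (B, T, N, H).
    *_, H = query.shape
    scale = (1.0 / jnp.sqrt(H)) if scale is None else scale

    # [B, N, T, S] attention logits: one score per (head, query, key) pair.
    logits = jnp.einsum("BTNH,BSNH->BNTS", query, key) * scale
    if is_causal:
        T, S = logits.shape[-2:]
        causal = jnp.tril(jnp.ones((T, S), dtype=bool))
        logits = jnp.where(causal, logits, -jnp.inf)

    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("BNTS,BSNH->BTNH", probs, value)

# Example: reference_attention(q, k, v, is_causal=True) with q/k/v of shape
# (batch, time, heads, head_dim) returns an array in the same layout.
```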
