@@ -1126,17 +1126,16 @@ def wrap_flash_attention(
    decoder_segment_ids,
    custom_mask=None,
    attn_logits_soft_cap=None,
-    head_shards=1,
-    q_seq_shards=1,
):
    if decoder_segment_ids is not None:
        assert query.shape[2] == decoder_segment_ids.q.shape[1], (
-            "Sharding along sequence dimension not allowed"
-            " in TPU kernel attention"
+            "Sharding along sequence dimension not allowed in tpu kernel "
+            "attention"
        )

    if custom_mask is not None:
        mask = splash_attention_mask.NumpyMask(array=custom_mask)
+
    else:
        mask = splash_attention_mask.CausalMask(
            shape=(query.shape[2], query.shape[2])
@@ -1148,8 +1147,8 @@ def wrap_flash_attention(
    )
    splash_kernel = splash_attention_kernel.make_splash_mha(
        mask=multi_head_mask,
-        head_shards=head_shards,
-        q_seq_shards=q_seq_shards,
+        head_shards=1,
+        q_seq_shards=1,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
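For context on this hunk: `make_splash_mha` builds a kernel over a `MultiHeadMask`, and with `head_shards=1` / `q_seq_shards=1` it assumes no sharding across heads or the query sequence. Below is a minimal sketch of how such a kernel is typically constructed and applied; the import path, the `splash_causal_attention` wrapper name, and the `jax.vmap` call are assumptions based on the splash-attention API, not code taken from this diff.

import jax
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

def splash_causal_attention(query, key, value, attn_logits_soft_cap=None):
    # Sketch only: mirrors how wrap_flash_attention is expected to use the kernel.
    # query/key/value: (batch, num_heads, seq_len, head_dim), TPU layout.
    num_heads, seq_len = query.shape[1], query.shape[2]
    # One causal mask per head, wrapped in a MultiHeadMask.
    causal = splash_attention_mask.CausalMask(shape=(seq_len, seq_len))
    multi_head_mask = splash_attention_mask.MultiHeadMask(
        masks=(causal,) * num_heads
    )
    kernel = splash_attention_kernel.make_splash_mha(
        mask=multi_head_mask,
        head_shards=1,    # no sharding across attention heads
        q_seq_shards=1,   # no sharding across the query sequence
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
    # The kernel operates on a single example; vmap it over the batch axis.
    return jax.vmap(kernel)(query, key, value)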
@@ -1169,38 +1168,6 @@ def dot_product_attention(
    flash_attention=None,
    attn_logits_soft_cap=None,
):
-    """Computes dot-product attention given query, key, and value.
-
-    This is the core computation of attention that is used in transformers.
-    For TPU platforms, flash attention optimizations are automatically applied
-    when possible, and sharding parameters are inferred from the layout map
-    in the current distribution context.
-
-    Args:
-        query: Queries with shape `[batch, time, heads,
-            depth_k]`.
-        key: Keys with shape `[batch, time, heads,
-            depth_k]`.
-        value: Values with shape `[batch, time, heads,
-            depth_v]`.
-        bias: Optional bias with shape broadcastable to
-            `[batch, heads, dest_time, source_time]`.
-        mask: Optional mask with shape broadcastable to
-            `[batch, heads, dest_time, source_time]`.
-        scale: Float. Optional scale that is applied to the attention
-            computation.
-        is_causal: Boolean. Specifying whether causal masking is applied.
-        flash_attention: Boolean. Whether to use flash attention optimization
-            for increased performance. Default to None, which means it will
-            be auto-determined based on the platform, input shapes and
-            compatibility.
-        attn_logits_soft_cap: Float. Optional float to softly cap attention
-            logits to avoid numerical stability issues. Applied as:
-            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.
-
-    Returns:
-        JAX Array of shape `[batch, time, heads, depth_v]`.
-    """
    query = convert_to_tensor(query)
    key = convert_to_tensor(key)
    value = convert_to_tensor(value)
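Since the removed docstring was the only statement of the expected layout, a quick usage sketch may help: inputs and the output are laid out as `[batch, time, heads, depth]`. This is illustrative only; the internal module path (`keras.src.backend.jax.nn`) and the environment-variable backend selection are assumptions about how one would exercise this private backend function.

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax.numpy as jnp
from keras.src.backend.jax import nn as jax_nn  # assumed module path

batch, time, heads, depth = 2, 16, 4, 64
query = jnp.ones((batch, time, heads, depth), dtype="float32")
key = jnp.ones((batch, time, heads, depth), dtype="float32")
value = jnp.ones((batch, time, heads, depth), dtype="float32")

# flash_attention=None lets the backend decide (cudnn on capable GPUs,
# the splash kernel path on TPU, plain XLA otherwise).
out = jax_nn.dot_product_attention(
    query, key, value, is_causal=True, flash_attention=None
)
assert out.shape == (batch, time, heads, depth)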
@@ -1210,155 +1177,47 @@ def dot_product_attention(
        f"Received: query.shape={query.shape}, key.shape={key.shape}, "
        f"value.shape={value.shape}."
    )
-
-    # Check platform
-    platform = jax.devices()[0].platform
-    is_tpu = platform == "tpu"
-
-    # Get sharding parameters from distribution context
-    head_shards = 1
-    q_seq_shards = 1
-
-    if is_tpu:
-        try:
-            from keras.src.distribution.distribution_lib import ModelParallel
-            from keras.src.distribution.distribution_lib import (
-                distribution as get_dist,
-            )
-
-            # Get current distribution if available
-            dist = get_dist()
-            if dist and isinstance(dist, ModelParallel):
-                mesh = dist.device_mesh
-                if "model" in mesh.axis_names:
-                    model_dim_index = mesh.axis_names.index("model")
-                    # Set head_shards based on the model dimension of the mesh
-                    head_shards = mesh.shape[model_dim_index]
-                    # Typically keep q_seq_shards=1 for best performance
-                    q_seq_shards = 1
-        except (ImportError, ValueError, AttributeError):
-            # Use default values if detection fails
-            head_shards = 1
-            q_seq_shards = 1
-
-    # Check if inputs use partial sharding (not fully replicated)
-    # Flash attention works well with fully replicated tensors on all platforms
-    # but may have issues with certain partial sharding patterns on non-TPU
-    # platforms
-    partially_sharded_inputs = any(
-        hasattr(t, "sharding") and not t.sharding.is_fully_replicated
-        for t in (query, key, value)
-    )
-
-    # Determine flash attention compatibility
    if flash_attention is None:
-        # Auto-detect flash attention availability
-        if is_tpu:
-            # TPUs have specialized hardware for attention that works with any
-            # sharding pattern
-            flash_attention = True
-        else:
-            # For GPU/CPU with partially sharded inputs, we need
-            # multiple devices to efficiently handle the sharding
-            if partially_sharded_inputs and len(jax.devices()) <= 1:
-                flash_attention = False
-            else:
-                flash_attention = _can_use_flash_attention(
-                    query, key, value, bias
-                )
-    elif flash_attention is True and not is_tpu:
-        # If flash attention is explicitly requested, validate compatibility
-        # Skip validation for TPU as it has specialized hardware support
-        try:
-            _can_use_flash_attention(query, key, value, bias, raise_error=True)
-        except Exception:
-            # Only disable flash attention on non-TPU platforms
-            # if validation fails
-            flash_attention = False
-
-    # TPU-specific flash attention path
-    if is_tpu and flash_attention:
-        # Transpose to ('batch', 'heads', 'length', 'head_dim')
-        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
-        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
-        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))
-
-        bs, num_heads, q_len, head_dim = query_tpu_layout.shape
-
-        # Apply scale to query if provided
-        if scale is not None:
-            # TPU kernel applies 1/sqrt(head_dim) internally, to achieve
-            # overall QK^T * scale, scale query by (scale * sqrt(head_dim))
-            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
-
-        # Create segment IDs for Splash Attention (for packing/batching)
-        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
-        decoder_segment_ids = splash_attention_kernel.SegmentIds(
-            q=segment_ids, kv=segment_ids
-        )
+        flash_attention = _can_use_flash_attention(query, key, value, bias)
+    elif flash_attention is True:
+        # Use `raise_error=True` to provide more details if the inputs failed to
+        # use flash attention
+        _can_use_flash_attention(query, key, value, bias, raise_error=True)

-        # Process mask for Splash Attention
-        custom_mask = None
-        if mask is not None:
-            mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask
-
-            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
-                custom_mask = mask_bool[0]
-            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
-                custom_mask = mask_bool[0, 0]
-
-            if is_causal and custom_mask is not None:
-                causal_mask = jnp.tril(
-                    jnp.ones((q_len, q_len), dtype=jnp.bool_)
-                )
-                custom_mask = jnp.logical_and(custom_mask, causal_mask)
-
-        if custom_mask is None and is_causal:
-            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
-
-        try:
-            output = wrap_flash_attention(
-                query_tpu_layout,
-                key_tpu_layout,
-                value_tpu_layout,
-                decoder_segment_ids=decoder_segment_ids,
-                custom_mask=custom_mask,
-                attn_logits_soft_cap=attn_logits_soft_cap,
-                head_shards=head_shards,
-                q_seq_shards=q_seq_shards,
-            )
-            # Transpose output back to Keras layout
-            return jnp.transpose(output, axes=(0, 2, 1, 3))
-        except Exception:
-            flash_attention = False
+    if jax.devices()[0].platform == "tpu":
+        # Transpose to ('batch', 'heads', 'length', 'kv')
+        query = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value = jnp.transpose(value, axes=(0, 2, 1, 3))
+        B, H, S, KV = query.shape
+
+        segment_ids = jnp.ones([B, S])
+        # {token_ids, padding_mask, segment_ids} enable packing
+        out = wrap_flash_attention(
+            query,
+            key,
+            value,
+            decoder_segment_ids=splash_attention_kernel.SegmentIds(
+                segment_ids, segment_ids
+            ),
+            custom_mask=mask,
+            attn_logits_soft_cap=attn_logits_soft_cap,
+        )
+        out = jnp.transpose(out, axes=(0, 2, 1, 3))
+        return out

-    # JAX native dot_product_attention for GPU or fallback for TPU
+    # `dot_product_attention` is only available in jax>=0.4.31
    if hasattr(jax.nn, "dot_product_attention"):
-        try:
-            return jax.nn.dot_product_attention(
-                query,
-                key,
-                value,
-                bias=bias,
-                mask=mask,
-                scale=scale,
-                is_causal=is_causal,
-                implementation="cudnn" if flash_attention else "xla",
-            )
-        except Exception:
-            # If flash attention fails, fall back to XLA implementation
-            if flash_attention:
-                return jax.nn.dot_product_attention(
-                    query,
-                    key,
-                    value,
-                    bias=bias,
-                    mask=mask,
-                    scale=scale,
-                    is_causal=is_causal,
-                    implementation="xla",
-                )
-            raise
+        return jax.nn.dot_product_attention(
+            query,
+            key,
+            value,
+            bias=bias,
+            mask=mask,
+            scale=scale,
+            is_causal=is_causal,
+            implementation="cudnn" if flash_attention else "xla",
+        )

    if flash_attention:
        raise RuntimeError(
@@ -1369,9 +1228,6 @@ def dot_product_attention(
    # Ref: jax.nn.dot_product_attention
    # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
    # Not support `query_seq_lengths` and `key_value_seq_lengths` args
-
-    # Fallback to custom XLA implementation
-    # This is the reference implementation from jax.nn.dot_product_attention
    output_shape = query.shape
    _, _, K, H = key.shape
    scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
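The lines above begin the hand-written XLA fallback that mirrors the jax.nn.dot_product_attention reference linked in the comment. For readers who don't want to chase the upstream link, here is a condensed, illustrative sketch of the standard computation such a fallback performs; it is not the code that follows in the file, the function and variable names are my own, and soft-cap handling is omitted.

import jax
import jax.numpy as jnp

def reference_attention(query, key, value, bias=None, mask=None,
                        is_causal=False, scale=None):
    # Illustrative reference implementation (not the code in this file).
    # query: (B, T, N, H); key/value: (B, S, N, H); output: (B, T, N, H).
    _, _, _, H = key.shape
    scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
    # Per-head attention logits: (B, N, T, S).
    logits = jnp.einsum("BTNH,BSNH->BNTS", query, key) * scale
    if bias is not None:
        logits = logits + bias
    if is_causal:
        T, S = logits.shape[-2], logits.shape[-1]
        causal = jnp.tril(jnp.ones((T, S), dtype=bool))
        mask = causal if mask is None else jnp.logical_and(mask, causal)
    if mask is not None:
        # Masked positions get a very large negative logit before softmax.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("BNTS,BSNH->BTNH", probs, value)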