@@ -244,7 +244,8 @@ def combine_mlir_scripts(
244
244
print (f"[DEBUG] output_name = { output_name } " )
245
245
maps1 = []
246
246
maps2 = []
247
- constants = set ()
247
+ constants_1 = set ()
248
+ constants_2 = set ()
248
249
f1 = []
249
250
f2 = []
250
251
@@ -255,7 +256,7 @@ def combine_mlir_scripts(
255
256
if re .search ("#map\d*\s*=" , line ):
256
257
maps1 .append (line )
257
258
elif re .search ("arith.constant" , line ):
258
- constants .add (line )
259
+ constants_1 .add (line )
259
260
elif not re .search ("module" , line ):
260
261
line = re .sub ("forward" , "first_vicuna_forward" , line )
261
262
f1 .append (line )
@@ -281,7 +282,7 @@ def combine_mlir_scripts(
281
282
elif "global_seed" in line :
282
283
continue
283
284
elif re .search ("arith.constant" , line ):
284
- constants .add (line )
285
+ constants_2 .add (line )
285
286
elif not re .search ("module" , line ):
286
287
line = re .sub ("forward" , "second_vicuna_forward" , line )
287
288
f2 .append (line )
@@ -304,15 +305,21 @@ def combine_mlir_scripts(
304
305
module_end = "}"
305
306
306
307
global_vars = []
307
- vnames = []
308
- global_var_loading1 = []
309
- global_var_loading2 = []
308
+ global_var_loading1 = dict ()
309
+ global_var_loading2 = dict ()
310
310
311
311
print (f"[DEBUG] processing constants" )
312
- counter = 0
313
- constants = list (constants )
312
+ # in both 1 and 2
313
+ constants = [(e , "" ) for e in list (constants_1 & constants_2 )]
314
+ # only in 1
315
+ constants .extend ([(e , "_1" ) for e in list (constants_1 .difference (constants_2 ))])
316
+ # only in 2
317
+ constants .extend ([(e , "_2" ) for e in list (constants_2 .difference (constants_1 ))])
318
+ del constants_1 , constants_2
319
+ gc .collect ()
320
+
314
321
while constants :
315
- constant = constants .pop (0 )
322
+ constant , vname_suf = constants .pop (0 )
316
323
vname , vbody = constant .split ("=" )
317
324
vname = re .sub ("%" , "" , vname )
318
325
vname = vname .strip ()
@@ -322,43 +329,42 @@ def combine_mlir_scripts(
322
329
print (constant )
323
330
vdtype = vbody .split (":" )[- 1 ].strip ()
324
331
fixed_vdtype = vdtype
325
- noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
326
- if "c1_i64" in vname :
327
- print (constant )
328
- counter += 1
329
- if counter == 2 :
330
- counter = 0
331
- print ("detected duplicate" )
332
- continue
333
- vnames .append (vname )
334
332
if "true" not in vname :
335
333
global_vars .append (
336
- f"ml_program.global private @{ vname } ({ vbody } ) : { fixed_vdtype } "
337
- )
338
- global_var_loading1 .append (
339
- f"\t \t %{ vname } = ml_program.global_load_const @{ vname } : { fixed_vdtype } "
340
- )
341
- global_var_loading2 .append (
342
- f"\t \t %{ vname } = ml_program.global_load_const @{ vname } : { fixed_vdtype } "
334
+ f"ml_program.global private @{ vname } { vname_suf } ({ vbody } ) : { fixed_vdtype } "
343
335
)
336
+ if vname_suf != "_2" :
337
+ global_var_loading1 [
338
+ f"\t \t %{ vname } = ml_program.global_load_const @{ vname } { vname_suf } : { fixed_vdtype } "
339
+ ] = ""
340
+ if vname_suf != "_1" :
341
+ global_var_loading2 [
342
+ f"\t \t %{ vname } = ml_program.global_load_const @{ vname } { vname_suf } : { fixed_vdtype } "
343
+ ] = ""
344
344
else :
345
345
global_vars .append (
346
- f"ml_program.global private @{ vname } ({ vbody } ) : i1"
347
- )
348
- global_var_loading1 .append (
349
- f"\t \t %{ vname } = ml_program.global_load_const @{ vname } : i1"
350
- )
351
- global_var_loading2 .append (
352
- f"\t \t %{ vname } = ml_program.global_load_const @{ vname } : i1"
346
+ f"ml_program.global private @{ vname } { vname_suf } ({ vbody } ) : i1"
353
347
)
348
+ if vname_suf != "_2" :
349
+ global_var_loading1 [
350
+ f"\t \t %{ vname } = ml_program.global_load_const @{ vname } { vname_suf } : i1"
351
+ ] = ""
352
+ if vname_suf != "_1" :
353
+ global_var_loading2 [
354
+ f"\t \t %{ vname } = ml_program.global_load_const @{ vname } { vname_suf } : i1"
355
+ ] = ""
356
+
357
+ del constants
358
+ gc .collect ()
359
+
354
360
355
361
new_f1 , new_f2 = [], []
356
362
357
363
print (f"[DEBUG] processing f1" )
358
364
for line in f1 :
359
365
if "func.func" in line :
360
366
new_f1 .append (line )
361
- for global_var in global_var_loading1 :
367
+ for global_var in global_var_loading1 . keys () :
362
368
new_f1 .append (global_var )
363
369
else :
364
370
new_f1 .append (line )
@@ -367,7 +373,7 @@ def combine_mlir_scripts(
367
373
for line in f2 :
368
374
if "func.func" in line :
369
375
new_f2 .append (line )
370
- for global_var in global_var_loading2 :
376
+ for global_var in global_var_loading2 . keys () :
371
377
if (
372
378
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
373
379
in global_var
0 commit comments