Skip to content

Commit 20b370b

Browse files
committed
fix combine mlir for llama2
1 parent ad55cb6 commit 20b370b

File tree

1 file changed

+40 / -34 lines changed

1 file changed

+40 / -34 lines changed

apps/language_models/scripts/vicuna.py

Lines changed: 40 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,8 @@ def combine_mlir_scripts(
244244
print(f"[DEBUG] output_name = {output_name}")
245245
maps1 = []
246246
maps2 = []
247-
constants = set()
247+
constants_1 = set()
248+
constants_2 = set()
248249
f1 = []
249250
f2 = []
250251

@@ -255,7 +256,7 @@ def combine_mlir_scripts(
255256
if re.search("#map\d*\s*=", line):
256257
maps1.append(line)
257258
elif re.search("arith.constant", line):
258-
constants.add(line)
259+
constants_1.add(line)
259260
elif not re.search("module", line):
260261
line = re.sub("forward", "first_vicuna_forward", line)
261262
f1.append(line)
@@ -281,7 +282,7 @@ def combine_mlir_scripts(
281282
elif "global_seed" in line:
282283
continue
283284
elif re.search("arith.constant", line):
284-
constants.add(line)
285+
constants_2.add(line)
285286
elif not re.search("module", line):
286287
line = re.sub("forward", "second_vicuna_forward", line)
287288
f2.append(line)
@@ -304,15 +305,21 @@ def combine_mlir_scripts(
304305
module_end = "}"
305306

306307
global_vars = []
307-
vnames = []
308-
global_var_loading1 = []
309-
global_var_loading2 = []
308+
global_var_loading1 = dict()
309+
global_var_loading2 = dict()
310310

311311
print(f"[DEBUG] processing constants")
312-
counter = 0
313-
constants = list(constants)
312+
# in both 1 and 2
313+
constants = [(e , "") for e in list(constants_1 & constants_2)]
314+
# only in 1
315+
constants.extend([(e, "_1") for e in list(constants_1.difference(constants_2))])
316+
# only in 2
317+
constants.extend([(e, "_2") for e in list(constants_2.difference(constants_1))])
318+
del constants_1, constants_2
319+
gc.collect()
320+
314321
while constants:
315-
constant = constants.pop(0)
322+
constant, vname_suf = constants.pop(0)
316323
vname, vbody = constant.split("=")
317324
vname = re.sub("%", "", vname)
318325
vname = vname.strip()
@@ -322,43 +329,42 @@ def combine_mlir_scripts(
322329
print(constant)
323330
vdtype = vbody.split(":")[-1].strip()
324331
fixed_vdtype = vdtype
325-
noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
326-
if "c1_i64" in vname:
327-
print(constant)
328-
counter += 1
329-
if counter == 2:
330-
counter = 0
331-
print("detected duplicate")
332-
continue
333-
vnames.append(vname)
334332
if "true" not in vname:
335333
global_vars.append(
336-
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
337-
)
338-
global_var_loading1.append(
339-
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
340-
)
341-
global_var_loading2.append(
342-
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
334+
f"ml_program.global private @{vname}{vname_suf}({vbody}) : {fixed_vdtype}"
343335
)
336+
if vname_suf != "_2":
337+
global_var_loading1[
338+
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
339+
] = ""
340+
if vname_suf != "_1":
341+
global_var_loading2[
342+
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
343+
] = ""
344344
else:
345345
global_vars.append(
346-
f"ml_program.global private @{vname}({vbody}) : i1"
347-
)
348-
global_var_loading1.append(
349-
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
350-
)
351-
global_var_loading2.append(
352-
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
346+
f"ml_program.global private @{vname}{vname_suf}({vbody}) : i1"
353347
)
348+
if vname_suf != "_2":
349+
global_var_loading1[
350+
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
351+
] = ""
352+
if vname_suf != "_1":
353+
global_var_loading2[
354+
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
355+
] = ""
356+
357+
del constants
358+
gc.collect()
359+
354360

355361
new_f1, new_f2 = [], []
356362

357363
print(f"[DEBUG] processing f1")
358364
for line in f1:
359365
if "func.func" in line:
360366
new_f1.append(line)
361-
for global_var in global_var_loading1:
367+
for global_var in global_var_loading1.keys():
362368
new_f1.append(global_var)
363369
else:
364370
new_f1.append(line)
@@ -367,7 +373,7 @@ def combine_mlir_scripts(
367373
for line in f2:
368374
if "func.func" in line:
369375
new_f2.append(line)
370-
for global_var in global_var_loading2:
376+
for global_var in global_var_loading2.keys():
371377
if (
372378
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
373379
in global_var

0 commit comments

Comments
 (0)