Skip to content

Commit 4cd871e

Browse files
rbrchenGary Frost
authored andcommitted
Add buffer tagging and opt-out moduleOp backend
1 parent 4330ec2 commit 4cd871e

File tree

10 files changed

+257
-34
lines changed

10 files changed

+257
-34
lines changed

hat/backends/ffi/cuda/src/main/java/hat/backend/ffi/CudaBackend.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -412,21 +412,21 @@ String createPTX(KernelCallGraph kernelCallGraph, Object... args){
412412
builder.ptxHeader(major, minor, target, addressSize);
413413
out.append(builder.getTextAndReset());
414414

415-
if (CallGraph.usingModuleOp) {
416-
System.out.println("Using ModuleOp for CudaBackend");
417-
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
418-
CoreOp.FuncOp loweredFunc = OpTk.lower(funcOp);
419-
loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns);
420-
invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false));
421-
});
422-
} else {
415+
if (CallGraph.noModuleOp) {
423416
System.out.println("NOT using ModuleOp for CudaBackend");
424417
for (KernelCallGraph.KernelReachableResolvedMethodCall k : kernelCallGraph.kernelReachableResolvedStream().toList()) {
425418
CoreOp.FuncOp calledFunc = k.funcOp();
426419
CoreOp.FuncOp loweredFunc = OpTk.lower(calledFunc);
427420
loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns);
428421
invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false));
429422
}
423+
} else {
424+
System.out.println("Using ModuleOp for CudaBackend");
425+
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
426+
CoreOp.FuncOp loweredFunc = OpTk.lower(funcOp);
427+
loweredFunc = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,loweredFunc, argsMap, usedMathFns);
428+
invokedMethods.append(createFunction(new PTXHATKernelBuilder(addressSize).nl().nl(), loweredFunc, false));
429+
});
430430
}
431431

432432
lowered = transformPTXPtrs(kernelCallGraph.computeContext.accelerator.lookup,lowered, argsMap, usedMathFns);

hat/backends/ffi/mock/src/main/java/hat/backend/ffi/MockBackend.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,14 @@ public void dispatchKernel(KernelCallGraph kernelCallGraph, NDRange ndRange, Obj
5656
// Here we receive a callgraph from the kernel entrypoint
5757
// The first time we see this we need to convert the kernel entrypoint
5858
// and rechable methods to a form that our mock backend can execute.
59-
if (CallGraph.usingModuleOp) {
60-
System.out.println("Using ModuleOp for MockBackend");
61-
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
62-
});
63-
} else {
59+
if (CallGraph.noModuleOp) {
6460
System.out.println("NOT using ModuleOp for MockBackend");
6561
kernelCallGraph.kernelReachableResolvedStream().forEach(kr -> {
6662

63+
});
64+
} else {
65+
System.out.println("Using ModuleOp for MockBackend");
66+
kernelCallGraph.moduleOp.functionTable().forEach((_, funcOp) -> {
6767
});
6868
}
6969
}

hat/backends/ffi/shared/src/main/java/hat/backend/ffi/C99FFIBackend.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,21 +233,21 @@ public <T extends C99HATKernelBuilder<T>> String createCode(KernelCallGraph kern
233233
kernelCallGraph.entrypoint.funcOp());
234234

235235
// Sorting by rank ensures we don't need forward declarations
236-
if (CallGraph.usingModuleOp) {
237-
System.out.println("Using ModuleOp for C99FFIBackend");
238-
kernelCallGraph.moduleOp.functionTable()
239-
.forEach((_, funcOp) -> builder
240-
.nl()
241-
.kernelMethod(buildContext,funcOp)
242-
.nl());
243-
} else {
236+
if (CallGraph.noModuleOp) {
244237
System.out.println("NOT using ModuleOp for C99FFIBackend");
245238
kernelCallGraph.kernelReachableResolvedStream().sorted((lhs, rhs) -> rhs.rank - lhs.rank)
246239
.forEach(kernelReachableResolvedMethod ->
247240
builder
248241
.nl()
249242
.kernelMethod(buildContext,kernelReachableResolvedMethod.funcOp())
250243
.nl());
244+
} else {
245+
System.out.println("Using ModuleOp for C99FFIBackend");
246+
kernelCallGraph.moduleOp.functionTable()
247+
.forEach((_, funcOp) -> builder
248+
.nl()
249+
.kernelMethod(buildContext,funcOp)
250+
.nl());
251251
}
252252

253253
builder.nl().kernelEntrypoint(buildContext, args).nl();
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
package hat;
2+
3+
import hat.buffer.Buffer;
4+
import hat.ifacemapper.MappableIface;
5+
import jdk.incubator.code.*;
6+
import jdk.incubator.code.analysis.Inliner;
7+
import jdk.incubator.code.analysis.SSA;
8+
import jdk.incubator.code.dialect.core.CoreOp;
9+
import jdk.incubator.code.dialect.java.*;
10+
11+
import java.lang.invoke.MethodHandles;
12+
import java.lang.reflect.Method;
13+
import java.util.*;
14+
import java.util.concurrent.atomic.AtomicBoolean;
15+
16+
public class BufferTagger {
17+
static HashMap<Value, AccessType> accessMap = new HashMap<>();
18+
static HashMap<Value, Value> remappedVals = new HashMap<>(); // maps values to their "root" parameter/value
19+
static HashMap<Block, List<Block.Parameter>> blockParams = new HashMap<>(); // holds block parameters for easy lookup
20+
21+
public enum AccessType {
22+
NA(1),
23+
RO(2),
24+
WO(4),
25+
RW(6),
26+
NOT_BUFFER(0);
27+
28+
public final int value;
29+
AccessType(int i) {
30+
value = i;
31+
}
32+
}
33+
34+
// generates a list of AccessTypes matching the given FuncOp's parameter order
35+
public static ArrayList<AccessType> getAccessList(MethodHandles.Lookup l, CoreOp.FuncOp f) {
36+
CoreOp.FuncOp inlinedFunc = inlineLoop(l, f);
37+
buildAccessMap(l, inlinedFunc);
38+
ArrayList<AccessType> accessList = new ArrayList<>();
39+
for (Block.Parameter p : inlinedFunc.body().entryBlock().parameters()) {
40+
if (accessMap.containsKey(p)) {
41+
accessList.add(accessMap.get(p)); // is an accessed buffer
42+
} else if (getClass(l, p.type()) instanceof Class<?> c && MappableIface.class.isAssignableFrom(c)) {
43+
accessList.add(AccessType.NA); // is a buffer but not accessed
44+
} else {
45+
accessList.add(AccessType.NOT_BUFFER); // is not a buffer
46+
}
47+
}
48+
return accessList;
49+
}
50+
51+
// inlines functions found in FuncOp f until no more inline-able functions are present
52+
public static CoreOp.FuncOp inlineLoop(MethodHandles.Lookup l, CoreOp.FuncOp f) {
53+
CoreOp.FuncOp ssaFunc = SSA.transform(f.transform(OpTransformer.LOWERING_TRANSFORMER));
54+
AtomicBoolean changed = new AtomicBoolean(true);
55+
while (changed.get()) { // loop until no more inline-able functions
56+
changed.set(false);
57+
ssaFunc = ssaFunc.transform((bb, op) -> {
58+
if (op instanceof JavaOp.InvokeOp iop) {
59+
MethodRef methodRef = iop.invokeDescriptor();
60+
Method invokeOpCalledMethod;
61+
try {
62+
invokeOpCalledMethod = methodRef.resolveToMethod(l, iop.invokeKind());
63+
} catch (ReflectiveOperationException _) {
64+
throw new IllegalStateException("Could not resolve invokeOp to method");
65+
}
66+
if (invokeOpCalledMethod instanceof Method method) { // if method isn't a buffer access (is code reflected)
67+
if (Op.ofMethod(method).isPresent()) {
68+
CoreOp.FuncOp inline = Op.ofMethod(method).get(); // method to be inlined
69+
CoreOp.FuncOp ssaInline = SSA.transform(inline.transform(OpTransformer.LOWERING_TRANSFORMER));
70+
71+
Block.Builder exit = Inliner.inline(bb, ssaInline, bb.context().getValues(iop.operands()), (_, v) -> {
72+
if (v != null) bb.context().mapValue(iop.result(), v);
73+
});
74+
75+
if (!exit.parameters().isEmpty()) {
76+
bb.context().mapValue(iop.result(), exit.parameters().getFirst());
77+
}
78+
changed.set(true);
79+
return exit.rebind(bb.context(), bb.transformer()); // return exit in same context as block
80+
}
81+
}
82+
}
83+
bb.op(op);
84+
return bb;
85+
});
86+
}
87+
return ssaFunc;
88+
}
89+
90+
// creates the access map
91+
public static void buildAccessMap(MethodHandles.Lookup l, CoreOp.FuncOp f) {
92+
// build blockParams so that we can map params to "root" params later
93+
for (Body b : f.bodies()) {
94+
for (Block block : b.blocks()) {
95+
if (!block.parameters().isEmpty()) {
96+
blockParams.put(block, block.parameters());
97+
}
98+
}
99+
}
100+
101+
f.traverse(null, (map, op) -> {
102+
if (op instanceof CoreOp.BranchOp b) {
103+
mapBranch(l, b.branch());
104+
} else if (op instanceof CoreOp.ConditionalBranchOp cb) {
105+
mapBranch(l, cb.trueBranch()); // handle true branch
106+
mapBranch(l, cb.falseBranch()); // handle false branch
107+
} else if (op instanceof JavaOp.InvokeOp iop) { // (almost) all the buffer accesses happen here
108+
if (isAssignable(l, iop.invokeDescriptor().refType(), MappableIface.class)) {
109+
updateAccessType(getRootValue(iop), getAccessType(iop)); // update buffer access
110+
if (isAssignable(l, iop.invokeDescriptor().refType(), Buffer.class)
111+
&& iop.result() != null && !(iop.resultType() instanceof PrimitiveType)
112+
&& isAssignable(l, iop.resultType(), MappableIface.class)) {
113+
// if we access a struct/union from a buffer, we map the struct/union to the buffer root
114+
remappedVals.put(iop.result(), getRootValue(iop));
115+
}
116+
}
117+
} else if (op instanceof CoreOp.VarOp vop) { // map the new VarOp to the "root" param
118+
if (isAssignable(l, vop.resultType().valueType(), Buffer.class)) {
119+
remappedVals.put(vop.initOperand(), getRootValue(vop));
120+
}
121+
} else if (op instanceof JavaOp.FieldAccessOp.FieldLoadOp flop) {
122+
if (isAssignable(l, flop.fieldDescriptor().refType(), KernelContext.class)) {
123+
updateAccessType(getRootValue(flop), AccessType.RO); // handle kc access
124+
}
125+
}
126+
return map;
127+
});
128+
}
129+
130+
// maps the parameters of a block to the values passed to a branch
131+
public static void mapBranch(MethodHandles.Lookup l, Block.Reference b) {
132+
List<Value> args = b.arguments();
133+
for (int i = 0; i < args.size(); i++) {
134+
Value key = blockParams.get(b.targetBlock()).get(i);
135+
Value val = args.get(i);
136+
137+
if (val instanceof Op.Result) {
138+
// either find root param or it doesnt exist (is a constant for example)
139+
if (isAssignable(l, val.type(), MappableIface.class)) {
140+
val = getRootValue(((Op.Result) val).op());
141+
if (val instanceof Block.Parameter) {
142+
val = remappedVals.getOrDefault(val, val);
143+
}
144+
}
145+
}
146+
remappedVals.put(key, val);
147+
}
148+
}
149+
150+
// checks if a TypeElement is assignable to a certain class
151+
public static boolean isAssignable(MethodHandles.Lookup l, TypeElement type, Class<?> clazz) {
152+
Class<?> fopClass = getClass(l, type);
153+
return (fopClass != null && (clazz.isAssignableFrom(fopClass)));
154+
}
155+
156+
// retrieves the class of a TypeElement
157+
public static Class<?> getClass(MethodHandles.Lookup l, TypeElement type) {
158+
if (type instanceof ClassType classType) {
159+
try {
160+
return (Class<?>) classType.resolve(l);
161+
} catch (ReflectiveOperationException e) {
162+
throw new RuntimeException(e);
163+
}
164+
}
165+
return null;
166+
}
167+
168+
// retrieves "root" value of an op, the origin of the parameter (or value) used by the op
169+
public static Value getRootValue(Op op) {
170+
if (op.operands().isEmpty()) {
171+
return op.result();
172+
}
173+
if (op.operands().getFirst() instanceof Block.Parameter param) {
174+
return param;
175+
}
176+
Value val = op.operands().getFirst();
177+
while (!(val instanceof Block.Parameter)) {
178+
// or if the "root VarOp" is an invoke (not sure how to tell)
179+
// if (tempOp instanceof JavaOp.InvokeOp iop
180+
// && ((TypeElement) iop.resultType()) instanceof ClassType classType
181+
// && !hasOperandType(iop, classType)) return ((CoreOp.VarOp) op);
182+
val = ((Op.Result) val).op().operands().getFirst();
183+
}
184+
return val;
185+
}
186+
187+
// retrieves accessType based on return value of InvokeOp
188+
public static AccessType getAccessType(JavaOp.InvokeOp iop) {
189+
return iop.invokeDescriptor().type().returnType().equals(JavaType.VOID) ? AccessType.WO : AccessType.RO;
190+
}
191+
192+
// updates accessMap
193+
public static void updateAccessType(Value val, AccessType curAccess) {
194+
Value remappedVal = remappedVals.getOrDefault(val, val);
195+
AccessType storedAccess = accessMap.get(remappedVal);
196+
if (storedAccess == null) {
197+
accessMap.put(remappedVal, curAccess);
198+
} else if (curAccess != storedAccess && storedAccess != AccessType.RW) {
199+
accessMap.put(remappedVal, AccessType.RW);
200+
}
201+
}
202+
203+
public static void printAccessMap() {
204+
System.out.println("access map output:");
205+
for (Value val : accessMap.keySet()) {
206+
if (val instanceof Block.Parameter param) {
207+
System.out.println("\t" + ((CoreOp.FuncOp) param.declaringBlock().parent().parent()).funcName()
208+
+ " param w/ idx " + param.index() + ": " + accessMap.get(val));
209+
} else {
210+
System.out.println("\t" + val.toString() + ": " + accessMap.get(val));
211+
}
212+
}
213+
}
214+
}

hat/core/src/main/java/hat/buffer/ArgArray.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
package hat.buffer;
2626

2727
import hat.Accelerator;
28+
import hat.BufferTagger;
2829
import hat.ComputeContext;
2930
import hat.callgraph.KernelCallGraph;
3031
import hat.ifacemapper.Schema;
@@ -34,6 +35,7 @@
3435
import java.lang.foreign.ValueLayout;
3536
import java.lang.invoke.MethodHandles;
3637
import java.nio.ByteOrder;
38+
import java.util.ArrayList;
3739

3840
import static hat.buffer.ArgArray.Arg.Value.Buf.UNKNOWN_BYTE;
3941
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
@@ -289,6 +291,8 @@ static ArgArray create(Accelerator accelerator, KernelCallGraph kernelCallGraph,
289291

290292
static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object... args) {
291293
Annotation[][] parameterAnnotations = kernelCallGraph.entrypoint.getMethod().getParameterAnnotations();
294+
ArrayList<BufferTagger.AccessType> bufferAccessList = kernelCallGraph.bufferAccessList;
295+
292296
for (int i = 0; i < args.length; i++) {
293297
Object argObject = args[i];
294298
Arg arg = argArray.arg(i); // this should be invariant, but if we are called from create it will be 0 for all
@@ -324,6 +328,8 @@ static void update(ArgArray argArray, KernelCallGraph kernelCallGraph, Object...
324328
buf.address(segment);
325329
buf.bytes(segment.byteSize());
326330
buf.access(accessByte);
331+
332+
assert bufferAccessList.get(i).value == accessByte;
327333
}
328334
default -> throw new IllegalStateException("Unexpected value: " + argObject);
329335
}

hat/core/src/main/java/hat/callgraph/CallGraph.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public abstract class CallGraph<E extends Entrypoint> {
4242
public final Set<MethodCall> calls = new HashSet<>();
4343
public final Map<MethodRef, MethodCall> methodRefToMethodCallMap = new LinkedHashMap<>();
4444
public CoreOp.ModuleOp moduleOp;
45-
public static boolean usingModuleOp = Boolean.getBoolean("moduleOp");
45+
public static boolean noModuleOp = Boolean.getBoolean("noModuleOp");
4646
public Stream<MethodCall> callStream() {
4747
return methodRefToMethodCallMap.values().stream();
4848
}

hat/core/src/main/java/hat/callgraph/ComputeCallGraph.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,10 @@ public void updateDag(ComputeReachableResolvedMethodCall computeReachableResolve
216216
}
217217

218218
public void close() {
219-
if (CallGraph.usingModuleOp) {
220-
closeWithModuleOp(entrypoint);
221-
} else {
219+
if (CallGraph.noModuleOp) {
222220
updateDag(entrypoint);
221+
} else {
222+
closeWithModuleOp(entrypoint);
223223
}
224224
}
225225

hat/core/src/main/java/hat/callgraph/KernelCallGraph.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
*/
2525
package hat.callgraph;
2626

27+
import hat.BufferTagger;
2728
import hat.buffer.Buffer;
2829
import hat.optools.OpTk;
2930
import jdk.incubator.code.Op;
@@ -38,6 +39,7 @@
3839
public class KernelCallGraph extends CallGraph<KernelEntrypoint> {
3940
public final ComputeCallGraph computeCallGraph;
4041
public final Map<MethodRef, MethodCall> bufferAccessToMethodCallMap = new LinkedHashMap<>();
42+
public final ArrayList<BufferTagger.AccessType> bufferAccessList;
4143

4244
public interface KernelReachable {
4345
}
@@ -77,6 +79,7 @@ public Stream<KernelReachableResolvedMethodCall> kernelReachableResolvedStream()
7779
super(computeCallGraph.computeContext, new KernelEntrypoint(null, methodRef, method, funcOp));
7880
entrypoint.callGraph = this;
7981
this.computeCallGraph = computeCallGraph;
82+
bufferAccessList = BufferTagger.getAccessList(computeContext.accelerator.lookup, entrypoint.funcOp());
8083
}
8184

8285
void updateDag(KernelReachableResolvedMethodCall kernelReachableResolvedMethodCall) {

0 commit comments

Comments
 (0)