1
+ package hat ;
2
+
3
+ import hat .buffer .Buffer ;
4
+ import hat .ifacemapper .MappableIface ;
5
+ import jdk .incubator .code .*;
6
+ import jdk .incubator .code .analysis .Inliner ;
7
+ import jdk .incubator .code .analysis .SSA ;
8
+ import jdk .incubator .code .dialect .core .CoreOp ;
9
+ import jdk .incubator .code .dialect .java .*;
10
+
11
+ import java .lang .invoke .MethodHandles ;
12
+ import java .lang .reflect .Method ;
13
+ import java .util .*;
14
+ import java .util .concurrent .atomic .AtomicBoolean ;
15
+
16
+ public class BufferTagger {
17
+ static HashMap <Value , AccessType > accessMap = new HashMap <>();
18
+ static HashMap <Value , Value > remappedVals = new HashMap <>(); // maps values to their "root" parameter/value
19
+ static HashMap <Block , List <Block .Parameter >> blockParams = new HashMap <>(); // holds block parameters for easy lookup
20
+
21
+ public enum AccessType {
22
+ NA (1 ),
23
+ RO (2 ),
24
+ WO (4 ),
25
+ RW (6 ),
26
+ NOT_BUFFER (0 );
27
+
28
+ public final int value ;
29
+ AccessType (int i ) {
30
+ value = i ;
31
+ }
32
+ }
33
+
34
+ // generates a list of AccessTypes matching the given FuncOp's parameter order
35
+ public static ArrayList <AccessType > getAccessList (MethodHandles .Lookup l , CoreOp .FuncOp f ) {
36
+ CoreOp .FuncOp inlinedFunc = inlineLoop (l , f );
37
+ buildAccessMap (l , inlinedFunc );
38
+ ArrayList <AccessType > accessList = new ArrayList <>();
39
+ for (Block .Parameter p : inlinedFunc .body ().entryBlock ().parameters ()) {
40
+ if (accessMap .containsKey (p )) {
41
+ accessList .add (accessMap .get (p )); // is an accessed buffer
42
+ } else if (getClass (l , p .type ()) instanceof Class <?> c && MappableIface .class .isAssignableFrom (c )) {
43
+ accessList .add (AccessType .NA ); // is a buffer but not accessed
44
+ } else {
45
+ accessList .add (AccessType .NOT_BUFFER ); // is not a buffer
46
+ }
47
+ }
48
+ return accessList ;
49
+ }
50
+
51
+ // inlines functions found in FuncOp f until no more inline-able functions are present
52
+ public static CoreOp .FuncOp inlineLoop (MethodHandles .Lookup l , CoreOp .FuncOp f ) {
53
+ CoreOp .FuncOp ssaFunc = SSA .transform (f .transform (OpTransformer .LOWERING_TRANSFORMER ));
54
+ AtomicBoolean changed = new AtomicBoolean (true );
55
+ while (changed .get ()) { // loop until no more inline-able functions
56
+ changed .set (false );
57
+ ssaFunc = ssaFunc .transform ((bb , op ) -> {
58
+ if (op instanceof JavaOp .InvokeOp iop ) {
59
+ MethodRef methodRef = iop .invokeDescriptor ();
60
+ Method invokeOpCalledMethod ;
61
+ try {
62
+ invokeOpCalledMethod = methodRef .resolveToMethod (l , iop .invokeKind ());
63
+ } catch (ReflectiveOperationException _) {
64
+ throw new IllegalStateException ("Could not resolve invokeOp to method" );
65
+ }
66
+ if (invokeOpCalledMethod instanceof Method method ) { // if method isn't a buffer access (is code reflected)
67
+ if (Op .ofMethod (method ).isPresent ()) {
68
+ CoreOp .FuncOp inline = Op .ofMethod (method ).get (); // method to be inlined
69
+ CoreOp .FuncOp ssaInline = SSA .transform (inline .transform (OpTransformer .LOWERING_TRANSFORMER ));
70
+
71
+ Block .Builder exit = Inliner .inline (bb , ssaInline , bb .context ().getValues (iop .operands ()), (_ , v ) -> {
72
+ if (v != null ) bb .context ().mapValue (iop .result (), v );
73
+ });
74
+
75
+ if (!exit .parameters ().isEmpty ()) {
76
+ bb .context ().mapValue (iop .result (), exit .parameters ().getFirst ());
77
+ }
78
+ changed .set (true );
79
+ return exit .rebind (bb .context (), bb .transformer ()); // return exit in same context as block
80
+ }
81
+ }
82
+ }
83
+ bb .op (op );
84
+ return bb ;
85
+ });
86
+ }
87
+ return ssaFunc ;
88
+ }
89
+
90
+ // creates the access map
91
+ public static void buildAccessMap (MethodHandles .Lookup l , CoreOp .FuncOp f ) {
92
+ // build blockParams so that we can map params to "root" params later
93
+ for (Body b : f .bodies ()) {
94
+ for (Block block : b .blocks ()) {
95
+ if (!block .parameters ().isEmpty ()) {
96
+ blockParams .put (block , block .parameters ());
97
+ }
98
+ }
99
+ }
100
+
101
+ f .traverse (null , (map , op ) -> {
102
+ if (op instanceof CoreOp .BranchOp b ) {
103
+ mapBranch (l , b .branch ());
104
+ } else if (op instanceof CoreOp .ConditionalBranchOp cb ) {
105
+ mapBranch (l , cb .trueBranch ()); // handle true branch
106
+ mapBranch (l , cb .falseBranch ()); // handle false branch
107
+ } else if (op instanceof JavaOp .InvokeOp iop ) { // (almost) all the buffer accesses happen here
108
+ if (isAssignable (l , iop .invokeDescriptor ().refType (), MappableIface .class )) {
109
+ updateAccessType (getRootValue (iop ), getAccessType (iop )); // update buffer access
110
+ if (isAssignable (l , iop .invokeDescriptor ().refType (), Buffer .class )
111
+ && iop .result () != null && !(iop .resultType () instanceof PrimitiveType )
112
+ && isAssignable (l , iop .resultType (), MappableIface .class )) {
113
+ // if we access a struct/union from a buffer, we map the struct/union to the buffer root
114
+ remappedVals .put (iop .result (), getRootValue (iop ));
115
+ }
116
+ }
117
+ } else if (op instanceof CoreOp .VarOp vop ) { // map the new VarOp to the "root" param
118
+ if (isAssignable (l , vop .resultType ().valueType (), Buffer .class )) {
119
+ remappedVals .put (vop .initOperand (), getRootValue (vop ));
120
+ }
121
+ } else if (op instanceof JavaOp .FieldAccessOp .FieldLoadOp flop ) {
122
+ if (isAssignable (l , flop .fieldDescriptor ().refType (), KernelContext .class )) {
123
+ updateAccessType (getRootValue (flop ), AccessType .RO ); // handle kc access
124
+ }
125
+ }
126
+ return map ;
127
+ });
128
+ }
129
+
130
+ // maps the parameters of a block to the values passed to a branch
131
+ public static void mapBranch (MethodHandles .Lookup l , Block .Reference b ) {
132
+ List <Value > args = b .arguments ();
133
+ for (int i = 0 ; i < args .size (); i ++) {
134
+ Value key = blockParams .get (b .targetBlock ()).get (i );
135
+ Value val = args .get (i );
136
+
137
+ if (val instanceof Op .Result ) {
138
+ // either find root param or it doesnt exist (is a constant for example)
139
+ if (isAssignable (l , val .type (), MappableIface .class )) {
140
+ val = getRootValue (((Op .Result ) val ).op ());
141
+ if (val instanceof Block .Parameter ) {
142
+ val = remappedVals .getOrDefault (val , val );
143
+ }
144
+ }
145
+ }
146
+ remappedVals .put (key , val );
147
+ }
148
+ }
149
+
150
+ // checks if a TypeElement is assignable to a certain class
151
+ public static boolean isAssignable (MethodHandles .Lookup l , TypeElement type , Class <?> clazz ) {
152
+ Class <?> fopClass = getClass (l , type );
153
+ return (fopClass != null && (clazz .isAssignableFrom (fopClass )));
154
+ }
155
+
156
+ // retrieves the class of a TypeElement
157
+ public static Class <?> getClass (MethodHandles .Lookup l , TypeElement type ) {
158
+ if (type instanceof ClassType classType ) {
159
+ try {
160
+ return (Class <?>) classType .resolve (l );
161
+ } catch (ReflectiveOperationException e ) {
162
+ throw new RuntimeException (e );
163
+ }
164
+ }
165
+ return null ;
166
+ }
167
+
168
+ // retrieves "root" value of an op, the origin of the parameter (or value) used by the op
169
+ public static Value getRootValue (Op op ) {
170
+ if (op .operands ().isEmpty ()) {
171
+ return op .result ();
172
+ }
173
+ if (op .operands ().getFirst () instanceof Block .Parameter param ) {
174
+ return param ;
175
+ }
176
+ Value val = op .operands ().getFirst ();
177
+ while (!(val instanceof Block .Parameter )) {
178
+ // or if the "root VarOp" is an invoke (not sure how to tell)
179
+ // if (tempOp instanceof JavaOp.InvokeOp iop
180
+ // && ((TypeElement) iop.resultType()) instanceof ClassType classType
181
+ // && !hasOperandType(iop, classType)) return ((CoreOp.VarOp) op);
182
+ val = ((Op .Result ) val ).op ().operands ().getFirst ();
183
+ }
184
+ return val ;
185
+ }
186
+
187
+ // retrieves accessType based on return value of InvokeOp
188
+ public static AccessType getAccessType (JavaOp .InvokeOp iop ) {
189
+ return iop .invokeDescriptor ().type ().returnType ().equals (JavaType .VOID ) ? AccessType .WO : AccessType .RO ;
190
+ }
191
+
192
+ // updates accessMap
193
+ public static void updateAccessType (Value val , AccessType curAccess ) {
194
+ Value remappedVal = remappedVals .getOrDefault (val , val );
195
+ AccessType storedAccess = accessMap .get (remappedVal );
196
+ if (storedAccess == null ) {
197
+ accessMap .put (remappedVal , curAccess );
198
+ } else if (curAccess != storedAccess && storedAccess != AccessType .RW ) {
199
+ accessMap .put (remappedVal , AccessType .RW );
200
+ }
201
+ }
202
+
203
+ public static void printAccessMap () {
204
+ System .out .println ("access map output:" );
205
+ for (Value val : accessMap .keySet ()) {
206
+ if (val instanceof Block .Parameter param ) {
207
+ System .out .println ("\t " + ((CoreOp .FuncOp ) param .declaringBlock ().parent ().parent ()).funcName ()
208
+ + " param w/ idx " + param .index () + ": " + accessMap .get (val ));
209
+ } else {
210
+ System .out .println ("\t " + val .toString () + ": " + accessMap .get (val ));
211
+ }
212
+ }
213
+ }
214
+ }
0 commit comments