00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
'''
Memory access recognition and optimization.

Performs CSE (common subexpression elimination) and loop invariant code motion.

<h3>CSE</h3>
<code>
for (i=0; i<n; i++) {
    x[i] = a[i]
    y[i] = a[i+1]
}
</code>
When there are several accesses to a single array, as in the example above,
they are merged into one.


<h3>loop invariant code motion</h3>
<code>
for (i=0; i<n; i++) {
    for (j=0; j<n; j++) {
        x[i][j] = y[i];
    }
}
</code>
When an access is invariant inside a loop, as in the example above, it is
moved outside of that loop.
'''
00055
00056
00057 import ctrump
00058 import copy
00059
00060 __all__ = ["MemoryOp", "MemoryOpTree", "loop_tree_to_memop_tree", "PtrIncVarEnv"]
00061
class PtrIncVarEnv:
    '''Record of the arrays and increments registered via append_ptrinc().

    Both tables are plain dictionaries used as sets: every registered key
    maps to ``True``.
    '''

    def __init__(self):
        self.incr_table = {}
        self.array_table = {}

    def append_ptrinc(self, var, incr):
        '''Register ``var`` in ``array_table`` and ``incr`` in ``incr_table``.

        Keys that are already present are left untouched.
        '''
        self.array_table.setdefault(var, True)
        self.incr_table.setdefault(incr, True)
00073
class MemAccessKey:
    '''Key used to look up memory accesses in a dictionary.

    The hash value and equality are derived from the array name, the
    subscripts and the load/store kind.
    '''

    def __init__(self, array, subscripts, load_store):
        array_name = array.var.name
        subscript_hash = sum(ctrump.loop_subscript_hash(s) for s in subscripts)

        self.array_name = array_name
        self.subscripts = subscripts
        self.load_store = load_store

        # The hash is computed once here and returned as-is by __hash__.
        mixed = hash(array_name) * subscript_hash + load_store
        self.hash = int(mixed % 0x7fffffff)

    def __hash__(self):
        return self.hash

    def __eq__(self, rhs):
        '''Return True when name, load/store kind and all subscripts match.'''
        if rhs.array_name != self.array_name:
            return False
        if rhs.load_store != self.load_store:
            return False

        theirs = rhs.subscripts
        ours = self.subscripts
        if len(theirs) != len(ours):
            return False

        # Subscripts are compared pairwise, position by position.
        return all(ctrump.loop_subscript_equal(t, o)
                   for (t, o) in zip(theirs, ours))
00112
class CSETableNode:
    '''Table entry holding the list of operations stored under one key.'''

    def __init__(self, node_list):
        # The list is stored by reference, not copied.
        self.node_list = node_list
00116
class MemoryOpOptimizer:
    '''One optimization pass over a tree of memory operations.

    The pass can merge operations on the same array whose subscripts
    compare equal (``merge_neighbor``) and move operations to the loop
    level they actually depend on (``loop_invariant_motion``).  ``changed``
    becomes ``True`` as soon as the pass modifies anything.
    '''

    def __init__(self):
        self.changed = False
        self.cse_table = {}
        self.level_table = {}
        self.invariants = []
        self.neighbor_table = {}

    def change(self):
        '''Mark the pass as having modified something.'''
        self.changed = True

    def loop_invariant_motion_0(self, tree):
        '''Move ops of ``tree`` (and, recursively, its children) between levels.

        Every op whose ``max_iv_level()`` differs from the nest level of
        ``tree`` is taken out of ``tree.ops``.  It is appended to
        ``self.invariants`` when its level is ``None`` or below
        ``self.min_level``; otherwise it is appended to the ops of the tree
        registered for that level in ``self.level_table``.
        '''
        tree_level = tree.loop_node.nest_level
        self.level_table[tree_level] = tree

        motion_ops = []
        for op in tree.ops:
            level = op.max_iv_level()
            if level != tree_level:
                motion_ops.append((op, level))
                op.removed = True
                self.changed = True

        # Build a real list right away: the ``removed`` flags are reset just
        # below and other levels append to ``ops`` later, so a lazy filter
        # object (Python 3) would keep every op and could not be appended to.
        tree.ops = [op for op in tree.ops if not op.removed]

        for (op, level) in motion_ops:
            op.removed = False
            if level is None or level < self.min_level:
                self.invariants.append(op)
                op.loop_node = None
            else:
                target = self.level_table[level]
                target.ops.append(op)
                op.loop_node = target.loop_node

        for child in tree.children:
            self.loop_invariant_motion_0(child)

    def loop_invariant_motion(self, tree):
        '''Run the motion starting at ``tree``, whose nest level is the minimum.'''
        self.min_level = tree.loop_node.nest_level
        self.loop_invariant_motion_0(tree)

    def do_cse(self, prev_op, op):
        '''Fold ``op`` into ``prev_op``.

        ``op`` is flagged as removed, its children are re-chained to
        ``prev_op``, every (offset, min, max) entry of ``prev_op`` is widened
        to also cover the matching entry of ``op``, and the update expression
        list of ``op`` is appended to the one of ``prev_op``.
        '''
        op.removed = True

        for child in op.children:
            child.chain = prev_op
            prev_op.children.append(child)

        new_offset_list = []
        for (pos, (_op_off, op_min, op_max)) in enumerate(op.offset_list):
            (prev_off, prev_min, prev_max) = prev_op.offset_list[pos]
            # The offset of prev_op is kept; only the range grows.
            new_offset_list.append((prev_off,
                                    min(prev_min, op_min),
                                    max(prev_max, op_max)))

        prev_op.offset_list = new_offset_list
        prev_op.update_expr_list = prev_op.update_expr_list + op.update_expr_list

    def merge_neighbor_0(self, cur_op):
        '''Merge ``cur_op`` into an earlier matching op, or remember it.

        Candidates are earlier ops registered under the same key (array,
        the subscripts before the load-data subscript, load/store kind).
        A candidate matches when it has the very same ``chain`` object and
        its load-data subscript has either no indices, like ``cur_op``, or
        exactly one index that is equal to the one of ``cur_op``.  When no
        candidate matches, ``cur_op`` is registered for later ops.
        '''
        table = self.neighbor_table
        load_data_sub = cur_op.load_data_sub_offset()
        leading_subscripts = cur_op.subscripts[0:load_data_sub]

        key = MemAccessKey(cur_op.array, leading_subscripts, cur_op.load_store)

        if key not in table:
            table[key] = CSETableNode([cur_op])
            return

        cse_node = table[key]
        found = False

        for prev_op in cse_node.node_list:
            # Identity, not equality: both ops must hang off the same chain.
            if prev_op.chain is not cur_op.chain:
                continue

            prev_indices = prev_op.subscripts[load_data_sub].indices
            cur_indices = cur_op.subscripts[load_data_sub].indices

            if len(prev_indices) != len(cur_indices):
                continue

            if len(prev_indices) == 0:
                mergeable = True
            elif len(prev_indices) == 1:
                mergeable = ctrump.loop_index_equal(prev_indices[0],
                                                    cur_indices[0])
            else:
                # Subscripts with two or more indices are never merged.
                mergeable = False

            if mergeable:
                self.changed = True
                self.do_cse(prev_op, cur_op)
                found = True

        if not found:
            cse_node.node_list.append(cur_op)

    def merge_neighbor(self, tree):
        '''Apply merge_neighbor_0 to every op of ``tree`` and of its children.

        Ops that were merged away are dropped from ``tree.ops``.
        '''
        for op in tree.ops:
            self.merge_neighbor_0(op)

        # A list (not a lazy filter object) so later passes can append to it.
        tree.ops = [op for op in tree.ops if not op.removed]

        for child in tree.children:
            self.merge_neighbor(child)
00244
class MemoryOp:
    '''A single load or store on an array, described by its subscripts.

    An op may hang off another op through ``chain``; the op it is chained
    to lists it in its ``children``.
    '''
    ACCESS_SEQUENTIAL = 0
    ACCESS_RANDOM = 1

    ACCESS_NULL = 3
    ACCESS_CALC_ADDR = 4

    STORE = 0
    LOAD = 1

    COMBINE_OFFSET = 0
    COMBINE_TAG = 1

    def __init__(self, chain, at_expr, array, subscripts, load_store, loop_node):
        '''Create the op and register it as a child of ``chain`` (if any).

        For every subscript an ``(offset, min, max)`` triple is stored in
        ``offset_list``: record-member terminals use the member name with a
        ``(0, 0)`` range, every other subscript uses its ``offset`` for all
        three values.
        '''
        self.at_expr = at_expr

        off = []
        for sub in subscripts:
            if sub.code == ctrump.LOOP_SUBSCRIPT_RECORD_MEMBER_TERMINAL:
                off.append((sub.member_name, 0, 0))
            else:
                cur_off = sub.offset
                off.append((cur_off, cur_off, cur_off))

        self.offset_list = off
        self.update_expr_list = [(off, at_expr)]
        self.array = array
        self.subscripts = subscripts
        self.chain = chain
        if chain:
            chain.children.append(self)
        self.load_store = load_store
        self.children = []
        self.removed = False
        self.loop_node = loop_node

        self.key = MemAccessKey(array, subscripts, load_store)

    def __hash__(self):
        if self.chain:
            return int((hash(self.key) + hash(self.chain)) % 0x7fffffff)
        return hash(self.key)

    def __eq__(self, rhs):
        return (self.key == rhs.key) and (self.chain == rhs.chain)

    def __repr__(self):
        ret = str(self.array.var.name)
        ret += ctrump.format_loop_subscripts(self.subscripts)
        for (_off, lo, hi) in self.offset_list:
            ret += "(%d,%d)" % (lo, hi)
        if self.chain:
            ret += ' - (chain = %s)' % self.chain
        return ret

    def access_data_type(self):
        '''Return the ``load_type`` of the last subscript.

        Raises ``Exception`` when there are no subscripts or the last one is
        neither a coefficient terminal nor a record-member terminal.
        '''
        sub = self.subscripts
        if len(sub) == 0:
            raise Exception('subscript should terminal')

        last = sub[-1]
        if last.code in (ctrump.LOOP_SUBSCRIPT_COEF_TERMINAL,
                         ctrump.LOOP_SUBSCRIPT_RECORD_MEMBER_TERMINAL):
            return last.load_type
        raise Exception("subscript should terminal '%s'" % (ctrump.loop_subscript_code_string_table[last.code]))

    def load_data_type(self):
        '''Return the data type of the subscript at load_data_sub_offset().

        Record-member terminals yield their ``record_type``; coefficient
        terminals and load-record subscripts yield their ``load_type``.
        Any other subscript kind raises ``Exception``.
        '''
        sub = self.subscripts[self.load_data_sub_offset()]

        if sub.code == ctrump.LOOP_SUBSCRIPT_COEF_TERMINAL:
            return sub.load_type
        elif sub.code == ctrump.LOOP_SUBSCRIPT_RECORD_MEMBER_TERMINAL:
            return sub.record_type
        elif sub.code == ctrump.LOOP_SUBSCRIPT_LOAD_RECORD:
            return sub.load_type
        # ``sub`` is one subscript here, not the list, so it is not indexed.
        raise Exception("subscript should terminal '%s'" % (ctrump.loop_subscript_code_string_table[sub.code]))

    def load_data_range(self):
        '''Return the ``(min, max)`` pair stored for the load-data subscript.'''
        (_off, lo, hi) = self.offset_list[self.load_data_sub_offset()]
        return (lo, hi)

    def load_data_sub_offset(self):
        '''Return the index of the subscript that designates the loaded data.

        This is the second-to-last subscript when that one is a load-record
        subscript, otherwise the last one.  Raises ``Exception`` when there
        are no subscripts.
        '''
        sub = self.subscripts
        count = len(sub)

        if count == 0:
            raise Exception('subscript should terminal')

        if count >= 2 and sub[-2].code == ctrump.LOOP_SUBSCRIPT_LOAD_RECORD:
            return count - 2

        return count - 1

    @staticmethod
    def _is_deeper(candidate, current):
        '''Return True when level ``candidate`` should replace ``current``.

        ``None`` stands for "no level": it never replaces anything, and any
        real level replaces it.  Spelling this out keeps the comparison
        valid on Python 3, where ordering ``None`` against a number raises
        ``TypeError``.
        '''
        if candidate is None:
            return False
        return current is None or candidate > current

    def max_iv_level(self):
        '''Return the highest ``iv_level`` this op depends on, or ``None``.

        Inductive and pointer-increment indices of all subscripts are taken
        into account, as well as the level reported by ``chain``.
        '''
        max_level = None
        for sub in self.subscripts:
            for index in sub.indices:
                if index.code in (ctrump.LOOP_INDEX_INDUCTIVE,
                                  ctrump.LOOP_INDEX_POINTER_INC):
                    if self._is_deeper(index.iv_level, max_level):
                        max_level = index.iv_level
        if self.chain:
            chain_level = self.chain.max_iv_level()
            if self._is_deeper(chain_level, max_level):
                max_level = chain_level

        return max_level

    def extract_iv(self, ptrinc_env):
        '''Return the inductive indices used by the subscripts.

        Pointer-increment indices are not returned; instead the array and
        the increment are registered in ``ptrinc_env``.
        '''
        iv_list = []

        for sub in self.subscripts:
            for index in sub.indices:
                if index.code == ctrump.LOOP_INDEX_INDUCTIVE:
                    iv_list.append(index)
                elif index.code == ctrump.LOOP_INDEX_POINTER_INC:
                    ptrinc_env.append_ptrinc(self.array, index.incr)

        return iv_list

    def gen_name(self):
        '''Build a name from the array name, member names and inductive variables.'''
        name = str(self.array.var.name)
        for sub in self.subscripts:
            if sub.code in (ctrump.LOOP_SUBSCRIPT_RECORD_MEMBER,
                            ctrump.LOOP_SUBSCRIPT_RECORD_MEMBER_TERMINAL):
                name += '_%s' % sub.member_name

            for index in sub.indices:
                if index.code == ctrump.LOOP_INDEX_INDUCTIVE:
                    name += "_%s" % str(index.var.name)

        return name

    def is_partial_store(self):
        '''Tell whether this op is a store classified as partial.

        Loads are never partial.  A store is reported as not partial only
        when the load-data subscript is the last one, the range check below
        passes, and the last subscript is a coefficient terminal with exactly
        one inductive or pointer-increment index whose increment is 1.
        '''
        if self.load_store != self.STORE:
            return False

        sub_off = self.load_data_sub_offset()
        if sub_off != len(self.subscripts) - 1:
            return True

        (range_min, range_max) = self.load_data_range()

        if range_min != 0 and range_max != 0:
            return True

        # load_data_sub_offset() above already raises for an empty subscript
        # list; this guard is kept from the original code.
        if len(self.subscripts) == 0:
            return False

        sub = self.subscripts[-1]
        if sub.code != ctrump.LOOP_SUBSCRIPT_COEF_TERMINAL:
            return True

        if len(sub.indices) != 1:
            return True

        index = sub.indices[0]
        if index.code in (ctrump.LOOP_INDEX_INDUCTIVE,
                          ctrump.LOOP_INDEX_POINTER_INC):
            if index.incr == 1:
                return False

        return True
00421
00422
class MemoryOpTree:
    '''Node of the memory-operation tree: one loop, its ops and child nodes.'''

    def __init__(self, loop_node, ops, children):
        self.loop_node = loop_node
        self.ops = ops
        self.children = children

    def __repr__(self):
        ret = "ops = \n" + str(self.ops) + "\n"
        ret += "children = \n" + str(self.children) + "\n"
        return ret

    def get_depth(self):
        '''Return the height of the tree rooted here (1 for a node without children).'''
        return max([0] + [child.get_depth() for child in self.children]) + 1

    def optimize(self):
        '''Run optimizer passes until one of them changes nothing.

        Each round uses a fresh ``MemoryOpOptimizer``: the invariants
        collected so far are merged among themselves, then this tree is
        merged and loop-invariant motion is applied.

        Returns the list of ops that ended up outside of every loop of
        this tree.
        '''
        invariants = []
        while True:
            optimizer = MemoryOpOptimizer()

            for op in invariants:
                optimizer.merge_neighbor_0(op)
            # Must stay a list: it is extended with ``+=`` right below, which
            # a lazy filter object (Python 3) does not support.
            invariants = [op for op in invariants if not op.removed]

            optimizer.merge_neighbor(self)
            optimizer.loop_invariant_motion(self)

            invariants += optimizer.invariants

            if not optimizer.changed:
                break

        return invariants
00460
class MemoryOperation:
    '''A memory-operation tree together with its list of invariant ops.'''

    def __init__(self, tree, invariants):
        self.invariants = invariants
        self.tree = tree

    def optimize(self):
        '''Optimize the tree and add the invariants it returns to ``invariants``.'''
        self.invariants += self.tree.optimize()
00469
def loop_tree_to_memop_tree_0(loop_tree):
    '''Build the ``MemoryOpTree`` for ``loop_tree`` and, recursively, its children.

    Every entry of ``parallel_loads`` and ``parallel_stores`` is cut into one
    ``MemoryOp`` per terminal subscript.  The ops of one entry are chained in
    order; all but the last are loads, and only the last one carries the
    entry's ``at_expr`` and the requested load/store kind.

    Raises ``NotImplementedError`` when an entry does not end with a terminal
    subscript.
    '''
    # Snapshot the access lists before anything else is done with the tree.
    loads = copy.copy(loop_tree.parallel_loads)
    stores = copy.copy(loop_tree.parallel_stores)

    children = [loop_tree_to_memop_tree_0(child)
                for child in loop_tree.children]
    ops = []

    def filter_single_op(access, load_store):
        '''Append the chain of ops for one load/store entry to ``ops``.'''
        num_subscript = access.num_subscript
        subscripts = access.subscripts
        prev_sub_n = 0
        chain = None

        for j in range(0, num_subscript):
            code = subscripts[j].code
            if (code != ctrump.LOOP_SUBSCRIPT_RECORD_MEMBER_TERMINAL and
                    code != ctrump.LOOP_SUBSCRIPT_COEF_TERMINAL):
                continue

            # Only the final terminal performs the requested operation and
            # carries the expression; the ones before it are plain loads.
            is_last = (j == num_subscript - 1)
            memop = load_store if is_last else MemoryOp.LOAD
            replace_expr = access.at_expr if is_last else None

            cur_op = MemoryOp(chain, replace_expr, access.array,
                              subscripts[prev_sub_n:j + 1], memop, loop_tree)
            ops.append(cur_op)

            chain = cur_op
            prev_sub_n = j + 1

        if prev_sub_n != num_subscript:
            raise NotImplementedError('non terminal memory operation')

    for access in loads:
        filter_single_op(access, MemoryOp.LOAD)
    for access in stores:
        filter_single_op(access, MemoryOp.STORE)

    return MemoryOpTree(loop_tree, ops, children)
00521
def loop_tree_to_memop_tree(loop_tree):
    '''Convert ``loop_tree`` into a ``MemoryOperation`` with no invariants yet.'''
    memop_tree = loop_tree_to_memop_tree_0(loop_tree)
    return MemoryOperation(memop_tree, [])