Coverage for src/flag_gems/ops/cumsum.py: 40%

328 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import functools 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8from torch._prims_common import is_boolean_dtype, is_integer_dtype 

9 

10from flag_gems.runtime import device, torch_device_fn 

11from flag_gems.utils import get_device_properties, libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13 

# NOTE: rebinds `device` from the runtime device object to its string name;
# the object imported from flag_gems.runtime is shadowed from here on, so
# later code must not call `.name` on this module-level `device` again.
device = device.name
logger = logging.getLogger(__name__)

16 

17 

@functools.lru_cache
def get_num_sms(idx: int) -> int:
    """Return the streaming-multiprocessor count of device ``idx`` (memoized)."""
    props = get_device_properties(idx)
    return props.multi_processor_count

21 

22 

@tl.constexpr
def get_scan_accum_type(inp_dtype: tl.dtype) -> tl.dtype:
    """Pick the accumulator dtype for a scan over `inp_dtype` values.

    fp16/bf16 accumulate in fp32, any integer (signed, unsigned, or bool)
    accumulates in int64, and everything else (fp32/fp64) keeps its own dtype.
    """
    if inp_dtype.is_bf16() or inp_dtype.is_fp16():
        return tl.float32
    if inp_dtype.is_int():  # signed or not(including bool)
        return tl.int64
    else:
        return inp_dtype

31 

32 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass one of 1D scan-then-fan: per-block inclusive cumsum.

    Each program scans one BLOCK_SIZE chunk of `inp`, writes the local
    inclusive cumsum into `out`, and stores the chunk's total into
    `partial_sum[pid]` for the fan-out pass (`add_base_sum_kernel`).
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    # NOTE(review): masked load has no `other=0`; the trailing partial block
    # relies on masked lanes reading as zero for cumsum/sum — confirm.
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen low-precision inputs for accumulation: 64-bit types stay as-is,
    # narrower integers go to int32, everything else to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    # Block total, used as the base offset for all later blocks.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)

67 

68 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan-out pass of 1D scan-then-fan: add each block's prefix onto `out`.

    The caller re-scans `partial_sum` in place before this launch, so it
    holds *inclusive* cumulative block sums; block `pid` therefore adds
    `partial_sum[pid - 1]`. Block 0 has no predecessor and is left as-is.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid > 0:
        partial_sum_ptrs = partial_sum + pid - 1
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        # Cast back so the store dtype matches what pass one wrote.
        final_vals = out_vals + last_part_sum_via_sum
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)

91 

92 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass one of scan-then-fan along the middle axis of an (A, B, C) tensor.

    Program (a, b, c) scans one BLOCK_SIZE slice along B for outer index a
    and inner index c, writes the local inclusive cumsum to `out`, and
    stores the slice total into `partial_sum`, laid out as (A, part_num, C).
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flattened offsets assuming a contiguous (A, B, C) layout.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # NOTE(review): masked load has no `other=0`; the trailing partial slice
    # relies on masked lanes reading as zero for cumsum/sum — confirm.
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Widen low-precision inputs for accumulation: 64-bit types stay as-is,
    # narrower integers go to int32, everything else to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)

137 

138 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fan-out pass of (A, B, C) scan-then-fan: add each slice's prefix.

    `partial_sum` (shape (A, part_num, C)) must already hold inclusive
    cumulative slice sums, so program (a, b, c) adds the entry at b-1.
    The first slice along B (pid_b == 0) needs no base and is untouched.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Flattened offsets assuming a contiguous (A, B, C) layout.
    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    base_part_offset = a_idx * part_num * C + c_idx
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid_b > 0:
        partial_sum_ptrs = partial_sum + last_part_offset
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        # Cast back so the store dtype matches what pass one wrote.
        final_vals = out_vals + last_part_sum_via_sum
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)

172 

173 

def scan_then_fan_col(inp, out, n_ele, dtype):
    """Inclusive cumsum of a flat `n_ele`-element tensor via scan-then-fan.

    Blocks are scanned independently, their totals are scanned recursively
    (in place), and the resulting prefixes are fanned back onto `out`.
    """
    # TODO(all): tune on target board
    block = 1024 if n_ele > 1024 * 4 else triton.next_power_of_2(n_ele)
    num_parts = math.ceil(n_ele / block)
    block_totals = torch.empty(num_parts, dtype=dtype, device=inp.device)

    launch_grid = (num_parts,)
    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[launch_grid](
            inp, out, block_totals, n_ele, num_parts, block
        )

    if num_parts > 1:
        # Turn per-block totals into inclusive prefixes, then add them on.
        scan_then_fan_col(block_totals, block_totals, num_parts, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_kernel[launch_grid](
                out, block_totals, n_ele, num_parts, block
            )

190 

191 

def scan_then_fan(inp, out, A, B, C, dtype):
    """Inclusive cumsum along the middle axis of an (A, B, C)-shaped tensor.

    Same scan-then-fan strategy as `scan_then_fan_col`, with per-(a, c)
    block totals stored in an (A, part_num, C) scratch tensor.
    """
    # TODO(all): tune on target board
    block = 1024 if B > 1024 * 4 else triton.next_power_of_2(B)
    num_parts = math.ceil(B / block)
    block_totals = torch.empty(A, num_parts, C, dtype=dtype, device=inp.device)

    launch_grid = (A, num_parts, C)
    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[launch_grid](
            inp, out, block_totals, B, C, num_parts, block
        )

    if num_parts > 1:
        # Turn per-block totals into inclusive prefixes, then add them on.
        scan_then_fan(block_totals, block_totals, A, num_parts, C, dtype)
        with torch_device_fn.device(inp.device):
            add_base_sum_abc_kernel[launch_grid](
                out, block_totals, B, C, num_parts, block
            )

210 

211 

def cumsum_wrapper(inp, dim=1, dtype=None, out=None):
    """Cumulative sum of `inp` along `dim`, dispatching row vs. column scan.

    The (contiguous) tensor is viewed as (M, N, K) with N the scanned axis;
    K == 1 means the scan runs along the innermost axis (row scan),
    otherwise the strided column-scan path is used. Integer and boolean
    inputs are promoted to int64 output, matching torch.cumsum.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = inp.shape
    N = shape[dim]
    M = math.prod(shape[:dim])
    inp = inp.contiguous()
    K = inp.numel() // (M * N)

    if dtype is None:
        dtype = inp.dtype
        if is_integer_dtype(dtype) or is_boolean_dtype(dtype):
            dtype = torch.int64
    if out is None:
        out = torch.empty_like(inp, dtype=dtype)

    # Low-precision floats accumulate in fp32; everything else keeps out.dtype.
    low_precision = inp.dtype in (torch.float16, torch.bfloat16)
    compute_dtype = torch.float32 if low_precision else out.dtype

    if K == 1:  # innermost-axis (row) scan
        reduce_then_scan_row(inp, out, M, N, compute_dtype)
    else:  # strided (column) scan
        scan_then_fan(inp, out, M, N, K, compute_dtype)

    return out

240 

241 

def reduce_then_scan_row(x, out, M, N, compute_dtype):
    """Row-wise inclusive cumsum for an (M, N) contiguous view.

    Small rows (N <= 16384) are handled by a single persistent scan kernel.
    Larger rows use a 3-kernel reduce-then-scan: (1) each CTA sums its span
    of tiles, (2) a root kernel scans the per-CTA sums into inclusive
    prefixes, (3) each CTA re-reads its tiles, scans them, and adds the
    preceding CTAs' prefix.

    Returns `out`.
    """
    if N <= 16384:  # persistent: one CTA scans a whole row
        TILE_SIZE = triton.next_power_of_2(N)
        num_warps = 8 if TILE_SIZE > 2048 else 4
        reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
            x, out, N, TILE_SIZE, num_warps=num_warps
        )
        return out

    TILE_SIZE = min(4096, triton.next_power_of_2(N))
    num_warps = 8 if TILE_SIZE > 2048 else 4
    num_tiles = triton.cdiv(N, TILE_SIZE)
    # Cap CTAs per row at a small multiple of the SM count to bound the
    # size of the root scan.
    max_ctas = get_num_sms(x.device.index) * 4
    num_ctas = min(num_tiles, max_ctas)
    ROOT_SCAN_TILE_SIZE = triton.next_power_of_2(num_ctas)
    tiles_per_cta = triton.cdiv(num_tiles, num_ctas)
    block_sums = torch.empty((M, num_ctas), dtype=compute_dtype, device=x.device)
    block_inclusive_prefix = torch.empty(
        (M, num_ctas), dtype=compute_dtype, device=x.device
    )

    # 3-kernel implementation.
    # BUG FIX: the grid here was (M, num_ctas, 1, 1) — a 4-tuple — but
    # Triton launch grids have at most 3 dimensions.
    reduce_then_scan_block_sum_kernel_row[(M, num_ctas, 1)](
        x, block_sums, N, tiles_per_cta, TILE_SIZE, num_warps=num_warps
    )
    reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
        block_sums,
        block_inclusive_prefix,
        num_ctas,
        ROOT_SCAN_TILE_SIZE,
        num_warps=num_warps,
    )
    reduce_then_scan_block_scan_kernel_row[(M, num_ctas, 1)](
        x,
        block_inclusive_prefix,
        out,
        N,
        num_ctas,
        tiles_per_cta,
        TILE_SIZE,
        num_warps=num_warps,
    )
    return out

297 

298 

@triton.jit
def reduce_then_scan_block_sum_kernel_row(
    in_ptr,
    block_sum_ptr,
    N,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """The same kernel as the block sum in parallel reduce.

    Program (m, n) sums its `tiles_per_cta` tiles of row m into one scalar
    at block_sum_ptr[m, n], feeding the root-scan kernel.
    """
    pid_n = tl.program_id(1).to(tl.int64)
    pid_m = tl.program_id(0).to(tl.int64)
    num_programs_n = tl.num_programs(1)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)

    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    acc = tl.zeros((TILE_SIZE,), dtype=acc_dtype)
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        # NOTE(review): masked load without `other=0`; the accumulation
        # relies on masked lanes reading as zero — confirm.
        x = tl.load(in_ptr + pid_m * N + offsets, mask=offsets < N).to(acc_dtype)
        acc += x
    block_sum = tl.sum(acc, 0)
    # ".cg" cache hint: the per-block sums are read once by the root scan.
    tl.store(
        block_sum_ptr + pid_m * num_programs_n + pid_n, block_sum, cache_modifier=".cg"
    )

324 

325 

@triton.jit
def reduce_then_scan_root_scan_kernel_row(in_ptr, out_ptr, N, TILE_SIZE: tl.constexpr):
    """Almost the same kernel as the persistent scan kernel.

    Program `pid` loads one whole N-element row (N <= TILE_SIZE required)
    and writes its inclusive cumsum. Used both for small rows and for
    scanning the per-CTA block sums in the 3-kernel path.
    """
    pid = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, TILE_SIZE)
    mask = offsets < N
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)
    x = tl.load(in_ptr + pid * N + offsets, mask=mask, other=0).to(acc_dtype)
    out = tl.cumsum(x, 0)
    tl.store(out_ptr + pid * N + offsets, out, mask=mask)

336 

337 

@triton.jit
def reduce_then_scan_block_scan_kernel_row(
    in_ptr,
    previous_sum_ptr,
    out_ptr,
    N,
    num_tiles_n,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    """Final pass: tile-local cumsum plus the preceding CTAs' prefix.

    `previous_sum_ptr` holds the inclusive scan of per-CTA block sums;
    CTA n seeds its running prefix with entry n-1 (0 for the first CTA)
    and carries the prefix across its own tiles.
    """
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1).to(tl.int64)
    block_offset = pid_n * (tiles_per_cta * TILE_SIZE)
    block_end = min(block_offset + tiles_per_cta * TILE_SIZE, N)
    acc_dtype: tl.constexpr = get_scan_accum_type(in_ptr.type.element_ty)

    # Masked load: the first CTA (pid_n == 0) has no predecessor, prefix = 0.
    prefix = tl.load(
        previous_sum_ptr + pid_m * num_tiles_n + pid_n - 1, mask=pid_n > 0, other=0
    ).to(acc_dtype)
    for start in range(block_offset, block_end, TILE_SIZE):
        offsets = start + tl.arange(0, TILE_SIZE)
        mask = offsets < N
        # NOTE(review): masked load without `other=0`; the running-sum update
        # relies on masked lanes reading as zero — confirm.
        x = tl.load(in_ptr + pid_m * N + offsets, mask=mask).to(acc_dtype)
        tile_scan = prefix + tl.cumsum(x, 0)
        prefix += tl.sum(x, 0)
        tl.store(
            out_ptr + pid_m * N + offsets, tile_scan, mask=mask, cache_modifier=".cg"
        )

366 

367 

def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of `inp` along `dim` (torch.cumsum analogue)."""
    logger.debug("GEMS CUMSUM")
    result = cumsum_wrapper(inp, dim, dtype)
    return result

371 

372 

def cumsum_out(inp, dim=1, *, dtype=None, out):
    """Out-variant of `cumsum`: writes the result into `out` and returns it."""
    logger.debug("GEMS CUMSUM_OUT")
    return cumsum_wrapper(inp, dim, dtype, out=out)

376 

377 

@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Single-block row kernel: out = cumsum(row) / sum(row).

    One program handles one K-element row (requires K <= BLOCK).
    """
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # Accumulate low-precision floats in fp32. BUG FIX: previously only fp16
    # was widened, leaving bf16 rows to accumulate in bf16 — inconsistent
    # with block_cumsum_kernel, which widens fp16 and bf16 alike.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)

390 

391 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Chunked row cumsum used by `normed_cumsum`.

    Each CTA scans `t` consecutive TILE-wide tiles of its chunk per row,
    carrying a running total. Optionally writes the chunk total to `sums`
    (OUTPUT_SUMS) and/or divides the chunk by its running total (NORMALIZE;
    correct for a full row only when this CTA covers the whole row — the
    caller enables it only when there is a single chunk).
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # NOTE(review): the running accumulator is fixed to fp32 even for
        # fp64 input — possible precision loss for float64; confirm intent.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            # Inclusive scan of this tile, offset by everything before it.
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                # Output may use different strides than the input.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Chunk-total so far; the final iteration leaves the full total.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep over the same tiles, dividing by the chunk total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                cols_offset = cols * k_stride
                if HAS_OUT_LAYOUT:
                    cols_offset = cols * out_k_stride
                    row_offset = row * out_r_stride
                x = tl.load(out + row_offset + cols_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + row_offset + cols_offset, x, mask=cols < K)
                cols += TILE

460 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Final pass of multi-chunk `normed_cumsum`: add each chunk's exclusive
    prefix (`base`) and divide by the row total (`rscale_ptr`).
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y, grid.y + r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_gridx = tle.num_programs(1)

    # NOTE(review): `base` is advanced by `gridy * n_gridx + gridx` here and
    # by `gridx` per row below, while the caller lays `base` out as
    # (n_rows, n_chunks); the per-row stride would then be n_chunks, not
    # gridx — verify against the launch configuration in normed_cumsum.
    base += gridy * n_gridx + gridx
    rscale_ptr += gridy * rscale_stride

    # NOTE(review): the row range here is [gridy, gridy + r), unlike
    # block_cumsum_kernel's [gridy * r, (gridy + 1) * r) — confirm this is
    # intentional (they agree only when r == 1).
    for row in range(gridy, min(gridy + r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        base += gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                # Output may use different strides than the input.
                cols_offset = cols * out_k_stride
                row_offset = row * out_r_stride
            tl.store(out + row_offset + cols_offset, x, mask=cols < K)
            cols += TILE

517 

# Maximum launch-grid extent along dim Y (2**16 - 1); normed_cumsum batches
# rows into groups when n_rows exceeds this.
GRID_Y_LIMIT = 65535

519 

520 

def normed_cumsum(inp, dim=-1):
    """Cumulative sum along `dim`, normalized so each row's last entry is 1.

    Floating dtypes only. Each row is split into `n_chunks` chunks of
    `n_tiles` tiles; chunks go to different CTAs. With one chunk per row a
    single fused kernel scans and normalizes. Otherwise three passes run:
    per-chunk scans emitting chunk sums, a scan of the chunk sums, and a
    final pass adding each chunk's prefix and dividing by the row total.

    Returns a new tensor shaped like `inp` (transposed layout internally
    when `dim` is a middle dimension).
    """
    logger.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        # Use the cached SM-count helper defined above (was a raw
        # get_device_properties call on the module-level device-name string).
        num_sms = get_num_sms(inp.device.index)
        TILE = 2048
        # Each row is split into n_chunks of chunks where each chunk is composed
        # of n_tiles of tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            # Too many rows for one grid-Y launch: give each CTA `batch` rows.
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # Single chunk per row: scan and normalize in one fused kernel.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            return out

        # BUG FIX: acc_dtype was only assigned when inp.dtype != float64,
        # leaving it undefined (NameError) for float64 inputs.
        acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
        # BUG FIX: was `device=device.name`, but the module-level `device` is
        # already a string (rebound at import time), so `.name` raised
        # AttributeError; allocate on the input's device instead.
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two, scan partial cumsums
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Last entry of each row's chunk-sum scan is the row total.
        rscale = cumsums[..., -1]
        # cumsums - sums is the *exclusive* prefix each chunk must add.
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return out
632 return out