Coverage for src/flag_gems/runtime/backend/_cambricon/ops/softmax.py: 0%
546 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import copy
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM
15logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
16MAX_N = 16384
def align(max_block):
    """Round *max_block* down to the nearest power of two.

    Returns the value unchanged when it already is a power of two,
    otherwise the largest power of two strictly below it.
    """
    rounded_up = triton.next_power_of_2(max_block)
    if rounded_up == max_block:
        return max_block
    return rounded_up // 2
def config_prune1(configs, named_args, **kwargs):
    """Early-config pruning for the non-inner softmax forward kernel.

    For short rows (N < MAX_N) each incoming config is replaced by three
    hand-built variants (row-in-one-tile with K split evenly over cores;
    largest unpipelined TILE_K fitting NRAM; a 3-stage pipelined TILE_K).
    Otherwise configs are only de-duplicated.  Two heuristic
    (TILE_K=1, TILE_N=N) configs are always appended.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    input = named_args["input_ptr"]  # tensor argument; only .dtype is used
    # (TILE_K, TILE_N, num_warps, num_stages) -> config, de-duplicating
    # candidates that collapse onto the same tuning point.
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        TILE_K, TILE_N, num_warps, num_stages = (
            kw["TILE_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Variant 1: whole row in one tile (TILE_N = N), K split evenly
            # over the cores available to this row, no pipelining.
            config = copy.deepcopy(config)
            TILE_N = config.kwargs["TILE_N"] = N
            k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
            TILE_K = config.kwargs["TILE_K"] = k_per_core
            num_stages = config.num_stages = 1
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
            # Variant 2: largest TILE_K whose working set (budget of
            # MAX_NRAM_SIZE/4 words at 2*TILE_N+1 words per k element)
            # fits NRAM without software pipelining.
            config = copy.deepcopy(config)
            max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (2 * TILE_N + 1)
            TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
            num_stages = config.num_stages = 1
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
            # Variant 3: pipelined (num_stages=3); fp32 budgets one extra
            # tile-sized buffer per k, hence the larger divisor.
            config = copy.deepcopy(config)
            max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (3 * TILE_N + 1)
            if input.dtype == torch.float32:
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (4 * TILE_N + 1)
            TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
            num_stages = config.num_stages = 3
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
        else:
            # Long rows: keep the config as-is, de-duplicated by key.
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
    pruned_configs = []
    for k, v in configs_map.items():
        pruned_configs.append(v)
    # Always also try a per-element heuristic config (TILE_K=1, full row),
    # once pipelined and once not.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["TILE_K"] = 1
    extra_config.kwargs["TILE_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs
def softmax_tile_mode_for_non_inner(M, N, K, TILE_N, TILE_K):
    """Classify the tiling for the non-inner kernel.

    Returns 0 when one tile per program covers both N and this core's K
    range, 1 when N fits but K must be looped, and 2 when both axes loop.
    """
    covers_k = TILE_K * max(TOTAL_CORE_NUM // M, 1) >= K
    covers_n = TILE_N >= N
    if not covers_n:
        return 2
    return 0 if covers_k else 1
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune1},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax along the middle (N) axis of a contiguous (M, N, K) view.
    # Grid: axis 0 iterates M, axis 1 splits the K axis across cores.
    # TILE_MODE (set by heuristics): 0 = one (N, K) tile per program,
    # 1 = loop over K only, 2 = loop over both K and N (online softmax).
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)  # contiguous K range owned by this program
    k_start = pid_k * split_k
    if TILE_MODE == 0:
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K  # TILE_N covers N entirely in this mode
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Numerically stable softmax over axis 0 (the N axis of the tile).
        row_minus_max = inp - tl.max(inp, axis=0)[None, :]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)[None, :]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, softmax_output, mask=mask)
    elif TILE_MODE == 1:
        # N fits in one tile; walk this program's K range in TILE_K steps.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            row_minus_max = inp - tl.max(inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = numerator * recip
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, softmax_output, mask=mask)
    else:
        # Both axes loop: two-pass online softmax per TILE_K slab.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            # Running state per k column: max m, rescaled exp-sum z.
            m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
            z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                m_new = tl.maximum(m, inp)
                # Keep z at 0 until a finite value appears, avoiding
                # inf - inf = nan in the rescaling term below.
                all_neg_inf = m_new == float("-inf")
                z = tl.where(
                    all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new)
                )
                m = m_new
            # Fold the TILE_N partial rows into one max/sum per k column.
            m_reduced = tl.max(m, 0)  # (TILE_K,)
            z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
            recip_z = 1.0 / z
            m = m_reduced
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - m[None, :]) * recip_z[None, :]
                tl.store(output_ptr + offset, o, mask=mask)
def config_prune2(configs, named_args, **kwargs):
    """Early-config pruning for the inner softmax forward kernel.

    Mirrors config_prune1 but tiles a 2-D (M, N) problem over
    (BLOCK_M, BLOCK_N).  For N < MAX_N three hand-built variants are
    derived from each config; all configs are de-duplicated on
    (BLOCK_M, BLOCK_N, num_warps, num_stages), and two heuristic
    (BLOCK_M=1, BLOCK_N=N) configs are appended.
    """
    M = named_args["M"]
    N = named_args["N"]
    input = named_args["input_ptr"]  # tensor argument; only .dtype is used
    configs_map = {}
    # When N is less than MAX_N, no reduction loops over N are needed.
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Variant 1: full row per block, M split evenly over all cores,
            # no pipelining.
            config = copy.deepcopy(config)
            BLOCK_N = config.kwargs["BLOCK_N"] = N
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            BLOCK_M = config.kwargs["BLOCK_M"] = m_per_core
            num_stages = config.num_stages = 1
            key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
            # Variant 2: largest BLOCK_M fitting NRAM without pipelining.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (2 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 1
            key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
            # Variant 3: pipelined (num_stages=3); fp32 budgets extra buffers,
            # hence the larger divisor.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (4 * BLOCK_N + 1)
            if input.dtype == torch.float32:
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (6 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 3
            key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
        # Fallback insert: the only insert when N >= MAX_N, and a harmless
        # re-insert of variant 3 otherwise.
        key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, config)
    pruned_configs = []
    for k, v in configs_map.items():
        pruned_configs.append(v)
    # Add a heuristic config.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["BLOCK_M"] = 1
    extra_config.kwargs["BLOCK_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs
def softmax_tile_mode_for_inner(args):
    """Classify the tiling for the inner kernel.

    Returns 0 when the launched blocks cover both M and N in one shot,
    1 when N fits but rows must be looped, and 2 when both axes loop.
    """
    covers_m = args["BLOCK_M"] * TOTAL_CORE_NUM >= args["M"]
    covers_n = args["BLOCK_N"] >= args["N"]
    if not covers_n:
        return 2
    return 0 if covers_m else 1
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune2},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax along the last (N) axis of a contiguous (M, N) view; rows are
    # distributed over the cores on grid axis 0.
    # TILE_MODE (set by heuristics): 0 = one block covers this core's rows,
    # 1 = loop over rows only, 2 = loop over rows and N (online softmax).
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)  # rows owned by this program
    m_start = pid_m * split_m
    if TILE_MODE == 0:
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M  # BLOCK_N covers N in this mode
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Numerically stable softmax over axis 1 (the row axis).
        row_minus_max = inp - tl.max(inp, axis=1)[:, None]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=1)[:, None]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, softmax_output, mask=mask)
    elif TILE_MODE == 1:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            # Reduce along axis 0 after a transpose — presumably faster than
            # an axis-1 reduction on this backend; preserved as written.
            trans_inp = tl.trans(inp)
            row_minus_max = trans_inp - tl.max(trans_inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = tl.trans(numerator * recip)
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, softmax_output, mask=mask)
    else:
        # Two-pass online softmax per row block.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            # Running state: per-element max and rescaled exp-sum.
            block_max = tl.full(
                [BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32
            )
            block_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                # NOTE(review): tensor masks are joined with 'and' here but
                # with '&' in the non-inner kernel — relies on the toolchain
                # accepting non-scalar 'and'; confirm.
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(block_max, inp)
                # Keep the sum at 0 while only -inf has been seen, avoiding
                # inf - inf = nan in the rescaling term.
                all_neg_inf = cur_max == float("-inf")
                block_sum = tl.where(
                    all_neg_inf,
                    block_sum,
                    block_sum * tl.exp(block_max - cur_max) + tl.exp(inp - cur_max),
                )
                block_max = cur_max
            # Fold the BLOCK_N partial columns into one max/sum per row.
            trans_block_max = tl.trans(block_max)
            trans_block_sum = tl.trans(block_sum)
            max_reduced = tl.max(trans_block_max, 0)
            total_sum = tl.sum(
                trans_block_sum * tl.exp(trans_block_max - max_reduced[None, :]), 0
            )
            recip_total_sum = 1.0 / total_sum
            total_max = max_reduced
            # Pass 2: normalize and write.
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - total_max[:, None]) * recip_total_sum[:, None]
                tl.store(output_ptr + offset, o, mask=mask)
@triton.jit
def softmax_kernel_inner_k_partial_stats(
    x_ptr,
    max_buf_ptr,
    sum_buf_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Pass 1 of the split-N softmax: for every (row-block, N-tile) task,
    # compute the tile max and the sum of exp(x - tile_max), storing them
    # into the (M, T) max/sum buffers for the later merge pass.
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # FIX: ceil-division so trailing rows are still processed when M is not
    # a multiple of BLOCK_M (the original M // BLOCK_M dropped them; the
    # current caller uses BLOCK_M=1 where both forms agree).
    total_blocks = tl.cdiv(M, BLOCK_M) * T
    # Distribute tasks evenly over the launched programs.
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)
    for task in range(start, end):
        row_id = task // T
        tile_id = task % T
        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)
        tile_max = tl.max(tile, axis=1)
        # All -inf rows get a 0 partial sum so exp(-inf - -inf) never
        # contributes nan downstream.
        all_neg_inf = tile_max == -float("inf")
        tile_sum = tl.where(
            all_neg_inf,
            0.0,
            tl.sum(tl.exp(tile - tile_max[:, None]), axis=1),
        )
        tl.store(max_buf_ptr + offs_m * T + tile_id, tile_max, mask=(offs_m < M))
        tl.store(sum_buf_ptr + offs_m * T + tile_id, tile_sum, mask=(offs_m < M))
@triton.jit
def softmax_kernel_inner_k_merge_stats(
    max_buf_ptr,
    sum_buf_ptr,
    gmax_ptr,
    gsum_ptr,
    M: tl.constexpr,
    T: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Pass 2 of the split-N softmax: fold the (M, T) per-tile max/sum
    # buffers into one global max (gmax) and rescaled exp-sum (gsum) per row.
    # NOTE(review): tl.arange(0, T) ordinarily requires a power-of-2 bound,
    # and T = ceil(N / block_n) is not guaranteed to be one — presumably the
    # Cambricon toolchain relaxes this; confirm.
    pid_m = tl.program_id(axis=0)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # [BM]
    mask_m = offs_m < M
    tile_max = tl.load(
        max_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=-float("inf"),
    )
    tile_sum = tl.load(
        sum_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=0.0,
    ).to(tl.float32)
    gmax = tl.max(tile_max, axis=1)
    # Rescale each tile's exp-sum to the global max; force the scale to 0
    # for all -inf rows so inf - inf does not produce nan.
    scale = tl.exp(tile_max - gmax[:, None])
    scale = tl.where(gmax[:, None] == -float("inf"), 0.0, scale)
    gsum = tl.sum(tile_sum * scale, axis=1)
    tl.store(gmax_ptr + offs_m, gmax, mask=mask_m)
    tl.store(gsum_ptr + offs_m, gsum, mask=mask_m)
@triton.jit
def softmax_kernel_inner_k_write_softmax(
    x_ptr,
    y_ptr,
    gmax_ptr,
    gsum_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Pass 3 of the split-N softmax: read the merged per-row stats and
    # write y = exp(x - gmax) / gsum tile by tile.  Tasks (row-block,
    # N-tile) are distributed evenly over the launched programs.
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # FIX: ceil-division so trailing rows are still processed when M is not
    # a multiple of BLOCK_M (the original M // BLOCK_M dropped them; the
    # current caller uses BLOCK_M=1 where both forms agree).
    total_blocks = tl.cdiv(M, BLOCK_M) * T
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)
    for task in range(start, end):
        row_id = task // T
        tile_id = task % T
        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        # load global stats
        gmax = tl.load(gmax_ptr + offs_m, mask=(offs_m < M), other=-float("inf")).to(
            tl.float32
        )
        gsum = tl.load(gsum_ptr + offs_m, mask=(offs_m < M), other=0.0).to(tl.float32)
        # load tile
        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)
        # Rows whose exp-sum is 0 (all -inf input) yield 0, not nan.
        valid = gsum[:, None] > 0
        out = tl.where(
            valid,
            tl.exp(tile - gmax[:, None]) / gsum[:, None],
            0.0,
        )
        tl.store(y_ptr + offs_m[:, None] * N + offs_n[None, :], out, mask=mask)
479# ------------------------ backward -------------------------------
def nram_usage_for_backward_non_inner(bn, bk, tile_mode, num_stages, dtype):
    """Estimate NRAM bytes used by the backward non-inner kernel.

    The coefficient counts live (bn)-sized fp32 buffers per k element for
    the given tile mode / pipelining depth; fp32 inputs need one extra
    buffer per stage than half precision.
    """
    wide = dtype == torch.float32
    if tile_mode == 0:
        buffers = 3
    elif tile_mode == 1:
        buffers = 3 if num_stages == 1 else (7 if wide else 6)
    else:
        buffers = 5 if num_stages == 1 else (13 if wide else 10)
    # 4 bytes per fp32 element, plus one scale word per k element.
    return (buffers * bn + 1) * bk * 4
def config_prune3(configs, named_args, **kwargs):
    """Early-config pruning for the non-inner softmax backward kernel.

    If one tile per core covers the whole slice within NRAM, a single
    hand-built config is returned.  Otherwise candidates are kept only when
    TILE_K is 256-byte aligned and the estimated NRAM usage lies within
    [MAX_NRAM_SIZE // 2, MAX_NRAM_SIZE].
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    dtype = named_args["output_ptr"].dtype
    k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
    # No need for any loop: the whole per-core workload fits in NRAM.
    if nram_usage_for_backward_non_inner(N, k_per_core, 0, 1, dtype) < MAX_NRAM_SIZE:
        single = copy.deepcopy(configs[0])
        single.kwargs["TILE_K"] = k_per_core
        single.kwargs["TILE_N"] = N
        single.num_stages = 1
        return [single]
    # Align the lowest dimension to 256B while loading/storing data.
    elems_per_256b = 256 // 4 if dtype == torch.float32 else 256 // 2
    kept = []
    for candidate in configs:
        tile_k = candidate.kwargs["TILE_K"]
        tile_n = candidate.kwargs["TILE_N"]
        if tile_k % elems_per_256b != 0:
            continue
        # NRAM usage should be smaller than MAX_NRAM_SIZE, but not so small
        # that the hardware is underutilized.
        mode = softmax_tile_mode_for_non_inner(M, N, K, tile_n, tile_k)
        usage = nram_usage_for_backward_non_inner(
            tile_n, tile_k, mode, candidate.num_stages, dtype
        )
        if MAX_NRAM_SIZE // 2 <= usage <= MAX_NRAM_SIZE:
            kept.append(candidate)
    return kept
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner_bw"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune3},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax backward along the middle (N) axis of a contiguous (M, N, K)
    # view: in_grad = out * (out_grad - sum_N(out * out_grad)).
    # TILE_MODE: 0 = one tile per program, 1 = loop over K only,
    # 2 = loop over K and N (pass 1 accumulates the dot product,
    # pass 2 writes the gradients).
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)  # contiguous K range owned by this program
    k_start = pid_k * split_k
    if TILE_MODE == 0:
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K  # TILE_N covers N entirely in this mode
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, axis=0)  # per-k dot product
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            # NOTE(review): tensor masks joined with 'and' here but '&' in
            # TILE_MODE 2 — relies on the toolchain accepting non-scalar
            # 'and'; confirm.
            mask = k_offset[None, :] < K and n_offset[:, None] < N
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_tile * out_grad_tile, axis=0)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            # Pass 1: accumulate out * out_grad over the N tiles.
            scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, axis=0)  # reduce partial products over N
            # Pass 2: write in_grad using the reduced scale.
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
def config_prune4(configs, named_args, **kwargs):
    """Early-config pruning for the inner softmax backward kernel.

    Mirrors config_prune2 but sizes BLOCK_M from the backward kernel's
    larger working set (divisors 3 unpipelined, 6/7 pipelined).
    De-duplicates on (BLOCK_M, BLOCK_N, num_warps, num_stages) and appends
    two heuristic (BLOCK_M=1, BLOCK_N=N) configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    output = named_args["output_ptr"]  # tensor argument; only .dtype is used
    configs_map = {}
    # When N is less than MAX_N, no reduction loops over N are needed.
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Variant 1: full row per block, M split evenly over all cores.
            config = copy.deepcopy(config)
            BLOCK_N = config.kwargs["BLOCK_N"] = N
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            BLOCK_M = config.kwargs["BLOCK_M"] = m_per_core
            num_stages = config.num_stages = 1
            key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
            # Variant 2: largest BLOCK_M fitting NRAM without pipelining.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (3 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 1
            key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
            # Variant 3: pipelined (num_stages=3); fp32 budgets one more
            # buffer, hence the larger divisor.
            config = copy.deepcopy(config)
            max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (6 * BLOCK_N + 1)
            if output.dtype == torch.float32:
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (7 * BLOCK_N + 1)
            BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
            num_stages = config.num_stages = 3
            key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
        # Fallback insert: the only insert when N >= MAX_N, and a harmless
        # re-insert of variant 3 otherwise.
        key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, config)
    pruned_configs = []
    for k, v in configs_map.items():
        pruned_configs.append(v)
    # Add a heuristic config.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["BLOCK_M"] = 1
    extra_config.kwargs["BLOCK_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner_bw"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune4},
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax backward along the last (N) axis of a contiguous (M, N) view:
    # in_grad = out * (out_grad - sum_N(out * out_grad)).
    # TILE_MODE: 0 = one block covers this core's rows, 1 = loop over rows,
    # 2 = loop over rows and N (two passes: accumulate scale, then write).
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)  # rows owned by this program
    m_start = pid_m * split_m
    if TILE_MODE == 0:
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M  # BLOCK_N covers N in this mode
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, 1)  # per-row dot product
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_tile * out_grad_tile, 1)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            # Pass 1: accumulate out * out_grad over the N tiles.
            scale = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                # NOTE(review): tensor masks joined with 'and' (sibling code
                # uses '&') — relies on the toolchain accepting it; confirm.
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                # evict_last: cache hint to keep these tiles resident for
                # the second pass.
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_last"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, 1)  # reduce partial products per row
            # Pass 2: write in_grad using the reduced scale.
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                # evict_first: tiles are consumed once, no need to cache.
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_first"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
def softmax(self, dim, half_to_float=False):
    """Compute softmax of *self* along *dim* (ATen-style signature).

    The shape is collapsed to (M, N, K) with N = shape[dim]: K > 1 selects
    the non-inner kernel; K == 1 selects the inner kernel, or — for few,
    very long rows — a three-kernel split-N path (partial stats, merge,
    write).  *half_to_float* promotes the output to fp32.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX")
    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]  # pre_dim
    self = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim
    with torch_device_fn.device(self.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON SOFTMAX USE NON INNER")
            # One grid row per M slice; remaining cores split the K axis.
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_kernel_non_inner[grid](
                out,
                self,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX USE INNER")
            if M > TOTAL_CORE_NUM or N < 1024 * 8 * 8:
                # Enough rows (or short enough rows): single fused kernel.
                softmax_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                    out,
                    self,
                    M,
                    N,
                )
            else:
                # Few, very long rows: split each row into T tiles of
                # block_n elements and run the three-pass path.
                block_m = 1
                block_n = 8192 * 4
                if dtype is torch.float32:
                    block_n = 8192 * 2  # halve the tile: fp32 doubles bytes
                # workspace
                T = (N + block_n - 1) // block_n
                max_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                sum_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                gmax = torch.empty((M,), device=self.device, dtype=torch.float32)
                gsum = torch.empty((M,), device=self.device, dtype=torch.float32)
                # kernel 1: per-tile stats
                softmax_kernel_inner_k_partial_stats[(TOTAL_CORE_NUM,)](
                    self,
                    max_buf,
                    sum_buf,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",  # Cambricon-specific launch hint
                    num_stages=3,
                )
                # kernel 2: merge stats along N-tiles
                grid_merge = (triton.cdiv(M, block_m),)
                softmax_kernel_inner_k_merge_stats[grid_merge](
                    max_buf, sum_buf, gmax, gsum, M, T, BLOCK_M=block_m
                )
                # Smaller tiles for the write pass; this T is only the task
                # count for kernel 3, not the stats-buffer width above.
                block_n = block_n // 2
                T = (N + block_n - 1) // block_n
                # kernel 3: write normalized outputs
                softmax_kernel_inner_k_write_softmax[(TOTAL_CORE_NUM,)](
                    self,
                    out,
                    gmax,
                    gsum,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """Softmax VJP: in_grad = output * (grad_output - sum(output * grad_output, dim)).

    Collapses the shape to (M, N, K) like the forward and dispatches to the
    non-inner (K > 1) or inner (K == 1) backward kernel.  *input_dtype* is
    accepted for signature compatibility but not consulted here.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX VJP")
    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]  # pre_dim
    grad_output = grad_output.contiguous()
    # NOTE(review): the kernels index `output` with contiguous arithmetic
    # but it is not made contiguous here — presumably callers always pass
    # the (already contiguous) forward result; confirm.
    in_grad = torch.empty_like(output)
    K = output.numel() // M // N  # post_dim
    with torch_device_fn.device(in_grad.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE NON INNER")
            # One grid row per M slice; remaining cores split the K axis.
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE INNER")
            softmax_backward_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad