Coverage for src/flag_gems/runtime/backend/_cambricon/ops/log_softmax.py: 0%
556 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import copy
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM
15logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
16MAX_N = 16384
def align(max_block):
    """Round *max_block* down to the nearest power of two.

    Tile sizes derived from the NRAM byte budget are generally not powers
    of two; flooring them to one keeps load/store tiling hardware friendly.
    Returns ``max_block`` unchanged when it is already a power of two,
    otherwise the largest power of two below it.  Implemented with
    ``int.bit_length`` (equivalent to the previous
    ``triton.next_power_of_2``-based logic for positive inputs).
    """
    if max_block <= 0:
        # Degenerate budget (e.g. TILE_N too large); preserve the old
        # behavior of returning 0 rather than raising.
        return max_block
    return 1 << (max_block.bit_length() - 1)
def config_prune1(configs, named_args, **kwargs):
    """Early-prune autotuner configs for ``log_softmax_kernel_non_inner``.

    When the reduction axis N fits in one tile (``N < MAX_N``), TILE_K and
    ``num_stages`` are derived directly from the NRAM budget instead of
    being searched; configs collapsing to the same
    (TILE_K, TILE_N, num_warps, num_stages) key are deduplicated.  Two
    heuristic fallback configs (TILE_K=1, TILE_N=N) are always appended.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    inp = named_args["input_ptr"]  # renamed: avoid shadowing builtin ``input``
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        TILE_K, TILE_N, num_warps, num_stages = (
            kw["TILE_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Whole row fits in one tile: pin TILE_N = N, derive TILE_K.
            config = copy.deepcopy(config)
            TILE_N = config.kwargs["TILE_N"] = N
            k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
            # 2 tile-sized fp32 buffers plus one column, 4 bytes each.
            nram_usage = (2 * TILE_N + 1) * k_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                # Each core can hold its whole K share at once.
                TILE_K = config.kwargs["TILE_K"] = k_per_core
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Unpipelined variant: biggest TILE_K fitting in NRAM.
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (2 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Pipelined variant: staging buffers shrink the budget
                # (more so for fp32 input).
                config = copy.deepcopy(config)
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (3 * TILE_N + 1)
                if inp.dtype == torch.float32:
                    max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (4 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 3
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
        else:
            # N needs looping anyway; keep the candidate as-is (dedup only).
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
    pruned_configs = list(configs_map.values())
    # Always include the degenerate split: one K element, full row.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["TILE_K"] = 1
    extra_config.kwargs["TILE_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs
def log_softmax_tile_mode_for_non_inner(M, N, K, TILE_N, TILE_K):
    """Classify the tiling for the non-inner kernel.

    Returns 0 when one tile covers both N and each core's K share,
    1 when only K must be looped, and 2 when N must be looped too.
    """
    covers_k = TILE_K * max(TOTAL_CORE_NUM // M, 1) >= K
    covers_n = TILE_N >= N
    if not covers_n:
        return 2
    return 0 if covers_k else 1
@libentry()
@libtuner(
    configs=[
        triton.Config({"TILE_K": k, "TILE_N": 2**n}, num_stages=s, num_warps=1)
        for k in [1, 2, 4, 8]
        for n in range(10, 15, 1)
        for s in [1, 3]
    ],
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune1},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_non_inner(
            args["M"], args["N"], args["K"], args["TILE_N"], args["TILE_K"]
        ),
    },
)
@triton.jit
def log_softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # log_softmax over the N axis of an (M, N, K) contiguous view, where N
    # is NOT the innermost axis.  Grid: axis 0 -> M, axis 1 -> cores that
    # share the K axis.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)  # this core's contiguous share of K
    k_start = pid_k * split_k

    log2e = 1.442695  # log2(e); tl.log2(x) / log2e == ln(x)

    if TILE_MODE == 0:
        # One tile covers all of N and this core's K share.
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Max-subtraction for numerical stability; reduce over N (axis 0).
        m = inp - tl.max(inp, axis=0)[None, :]
        e = tl.exp(m)
        s = tl.sum(e, axis=0)[None, :]
        # x - max - log(sum(exp(x - max)))
        output = m - tl.log2(s) / log2e
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, output, mask=mask)
    elif TILE_MODE == 1:
        # N fits in one tile; loop over this core's K share in TILE_K steps.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            m = inp - tl.max(inp, axis=0)[None, :]
            e = tl.exp(m)
            s = tl.sum(e, axis=0)[None, :]
            output = m - tl.log2(s) / log2e
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, output, mask=mask)
    else:
        # N does not fit: online (streaming) max/sum over N tiles, then a
        # second pass to write the normalized outputs.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
            z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                # Running rescale: z' = z * exp(m - m') + exp(x - m').
                m_new = tl.maximum(m, inp)
                all_neg_inf = m_new == float("-inf")
                # Keep z unchanged where everything so far is -inf to avoid
                # exp(-inf - (-inf)) = NaN.
                z = tl.where(
                    all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new)
                )
                m = m_new

            # Collapse the TILE_N partials into per-k statistics.
            m_reduced = tl.max(m, 0)  # (TILE_K,)
            z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
            recip_z = 1.0 / z
            m = m_reduced

            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                # log(softmax) computed via log2 for the hardware's log2 unit.
                o = tl.exp(inp - m[None, :]) * recip_z[None, :]
                output = tl.log2(o) / log2e
                tl.store(output_ptr + offset, output, mask=mask)
def config_prune2(configs, named_args, **kwargs):
    """Early-prune autotuner configs for ``log_softmax_kernel_inner``.

    Mirrors ``config_prune1`` but tiles over (M, N): when a row fits in one
    tile (``N < MAX_N``), BLOCK_M and ``num_stages`` are derived from the
    NRAM budget; duplicate (BLOCK_M, BLOCK_N, num_warps, num_stages) keys
    are dropped, and two heuristic fallbacks (BLOCK_M=1, BLOCK_N=N) are
    appended.
    """
    M = named_args["M"]
    N = named_args["N"]
    inp = named_args["input_ptr"]  # renamed: avoid shadowing builtin ``input``
    configs_map = {}
    # When N is less than MAX_C_MLU_SOFTMAX_FORWARD, no reduction loops
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_warps, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Whole row fits in one tile: pin BLOCK_N = N, derive BLOCK_M.
            config = copy.deepcopy(config)
            BLOCK_N = config.kwargs["BLOCK_N"] = N
            m_per_core = math.ceil(M / TOTAL_CORE_NUM)
            # 2 tile-sized fp32 buffers plus one column, 4 bytes each.
            nram_usage = (2 * BLOCK_N + 1) * m_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                BLOCK_M = config.kwargs["BLOCK_M"] = m_per_core
                num_stages = config.num_stages = 1
                key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Unpipelined variant: biggest BLOCK_M fitting in NRAM.
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (2 * BLOCK_N + 1)
                BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
                num_stages = config.num_stages = 1
                key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Pipelined variant: staging buffers shrink the budget
                # (more so for fp32 input).
                config = copy.deepcopy(config)
                max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (4 * BLOCK_N + 1)
                if inp.dtype == torch.float32:
                    max_block_m_without_pipe = MAX_NRAM_SIZE // 4 // (6 * BLOCK_N + 1)
                BLOCK_M = config.kwargs["BLOCK_M"] = align(max_block_m_without_pipe)
                num_stages = config.num_stages = 3
                key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
        # Registers the untouched config for the N >= MAX_N path; a no-op
        # re-registration (same key) when the branch above already ran.
        key = (BLOCK_M, BLOCK_N, num_warps, num_stages)
        # Only keep one config for the same key
        configs_map.setdefault(key, config)
    pruned_configs = list(configs_map.values())
    # Add a heuristic config.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["BLOCK_M"] = 1
    extra_config.kwargs["BLOCK_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs
def log_softmax_tile_mode_for_inner(M, N, BLOCK_M, BLOCK_N):
    """Classify the tiling for the inner kernel.

    Returns 0 when the cores cover all of M with full rows in one pass,
    1 when only M must be looped, and 2 when N must be looped too.
    """
    covers_m = BLOCK_M * TOTAL_CORE_NUM >= M
    covers_n = BLOCK_N >= N
    if not covers_n:
        return 2
    return 0 if covers_m else 1
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("log_softmax"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune2},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_inner(
            args["M"], args["N"], args["BLOCK_M"], args["BLOCK_N"]
        ),
    },
)
@triton.jit
def log_softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # log_softmax over the innermost axis of a contiguous (M, N) view.
    # One grid axis; each core processes a contiguous slab of rows.
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)  # rows per core
    m_start = pid_m * split_m

    log2e = 1.442695  # log2(e); tl.log2(x) / log2e == ln(x)

    if TILE_MODE == 0:
        # One tile covers this core's rows and the full row width.
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Stable softmax per row, then take log.
        row_minus_max = inp - tl.max(inp, axis=1)[:, None]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=1)[:, None]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output = tl.log2(softmax_output) / log2e
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, output, mask=mask)
    elif TILE_MODE == 1:
        # Full row fits; loop over this core's rows in BLOCK_M steps.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            # Transpose so the row reduction runs along axis 0 —
            # presumably the faster reduction axis on this hardware;
            # TODO(review): confirm against Cambricon triton docs.
            trans_inp = tl.trans(inp)
            row_minus_max = trans_inp - tl.max(trans_inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = tl.trans(numerator * recip)
            output = tl.log2(softmax_output) / log2e
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, output, mask=mask)
    else:
        # Row does not fit: online max/sum over N tiles, then second pass.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            block_max = tl.full(
                [BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32
            )
            block_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                # NOTE(review): `and` between tensor masks relies on this
                # triton dialect accepting it; `&` is the usual operator —
                # confirm against the other kernels above which use `&`.
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                # Running rescale: s' = s * exp(max - max') + exp(x - max').
                cur_max = tl.maximum(block_max, inp)
                all_neg_inf = cur_max == float("-inf")
                # Skip the update where everything so far is -inf to avoid NaN.
                block_sum = tl.where(
                    all_neg_inf,
                    block_sum,
                    block_sum * tl.exp(block_max - cur_max) + tl.exp(inp - cur_max),
                )
                block_max = cur_max

            # Collapse BLOCK_N partials into one max/sum per row.
            trans_block_max = tl.trans(block_max)
            trans_block_sum = tl.trans(block_sum)
            max_reduced = tl.max(trans_block_max, 0)
            total_sum = tl.sum(
                trans_block_sum * tl.exp(trans_block_max - max_reduced[None, :]), 0
            )
            recip_total_sum = 1.0 / total_sum
            total_max = max_reduced

            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - total_max[:, None]) * recip_total_sum[:, None]
                output = tl.log2(o) / log2e
                tl.store(output_ptr + offset, output, mask=mask)
@triton.jit
def log_softmax_kernel_inner_k_partial_stats(
    x_ptr,
    max_buf_ptr,
    sum_buf_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Pass 1 of the three-kernel split-N log_softmax: for each
    # (row-block, N-tile) task, store the tile max and the tile's
    # sum(exp(x - tile_max)) into (M, T)-shaped staging buffers.
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # assumes BLOCK_M divides M (BLOCK_M is 1 at the call site) —
    # TODO(review): confirm before reusing with BLOCK_M > 1.
    total_blocks = (M // BLOCK_M) * T
    # Even static split of tasks over cores.
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        tile_max = tl.max(tile, axis=1)
        all_neg_inf = tile_max == -float("inf")

        # An all -inf (i.e. fully masked) tile contributes sum 0, not NaN.
        tile_sum = tl.where(
            all_neg_inf,
            0.0,
            tl.sum(tl.exp(tile - tile_max[:, None]), axis=1),
        )

        # Staging buffers are laid out (M, T): one slot per (row, tile).
        tl.store(max_buf_ptr + offs_m * T + tile_id, tile_max, mask=(offs_m < M))
        tl.store(sum_buf_ptr + offs_m * T + tile_id, tile_sum, mask=(offs_m < M))
@triton.jit
def log_softmax_kernel_inner_k_merge_stats(
    max_buf_ptr,
    sum_buf_ptr,
    gmax_ptr,
    gsum_ptr,
    M: tl.constexpr,
    T: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Pass 2: fold the (M, T) per-tile stats into one global max and one
    # rescaled global sum per row.
    pid_m = tl.program_id(axis=0)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offs_m < M

    tile_max = tl.load(
        max_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=-float("inf"),
    )
    tile_sum = tl.load(
        sum_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=0.0,
    ).to(tl.float32)

    gmax = tl.max(tile_max, axis=1)
    # Rescale each tile's partial sum from its local max to the global max.
    scale = tl.exp(tile_max - gmax[:, None])
    # Rows that are entirely -inf get sum 0 instead of NaN from exp(-inf+inf).
    scale = tl.where(gmax[:, None] == -float("inf"), 0.0, scale)
    gsum = tl.sum(tile_sum * scale, axis=1)

    tl.store(gmax_ptr + offs_m, gmax, mask=mask_m)
    tl.store(gsum_ptr + offs_m, gsum, mask=mask_m)
@triton.jit
def log_softmax_kernel_inner_k_write_logsoftmax(
    x_ptr,
    y_ptr,
    gmax_ptr,
    gsum_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Pass 3: using the global per-row max/sum, write
    # log_softmax(x) = log(exp(x - gmax) / gsum) tile by tile.
    log2e = 1.442695  # log2(e); tl.log2(o) / log2e == ln(o)
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # Same task decomposition as the partial-stats pass (assumes BLOCK_M
    # divides M; BLOCK_M is 1 at the call site).
    total_blocks = (M // BLOCK_M) * T
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        gmax = tl.load(gmax_ptr + offs_m, mask=(offs_m < M), other=-float("inf")).to(
            tl.float32
        )
        gsum = tl.load(gsum_ptr + offs_m, mask=(offs_m < M), other=0.0).to(tl.float32)

        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        # Rows whose sum is 0 (all -inf input) get softmax 0, so the final
        # log2(0) / log2e yields -inf for them.
        valid = gsum[:, None] > 0

        o = tl.where(
            valid,
            tl.exp(tile - gmax[:, None]) / gsum[:, None],
            0.0,
        )
        out = tl.log2(o) / log2e

        tl.store(y_ptr + offs_m[:, None] * N + offs_n[None, :], out, mask=mask)
508# ------------------------ backward -------------------------------
def config_prune3(configs, named_args, **kwargs):
    """Early-prune autotuner configs for ``log_softmax_backward_kernel_non_inner``.

    Same scheme as ``config_prune1`` but with the backward kernel's larger
    NRAM footprint (3 buffers unpipelined; 6, or 7 for fp32, pipelined).
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    output = named_args["output_ptr"]
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        TILE_K, TILE_N, num_warps, num_stages = (
            kw["TILE_K"],
            kw["TILE_N"],
            config.num_warps,
            config.num_stages,
        )
        if N < MAX_N:
            # Whole row fits in one tile: pin TILE_N = N, derive TILE_K.
            config = copy.deepcopy(config)
            TILE_N = config.kwargs["TILE_N"] = N
            k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
            # 3 tile-sized fp32 buffers plus one column, 4 bytes each.
            nram_usage = (3 * TILE_N + 1) * k_per_core * 4
            if nram_usage < MAX_NRAM_SIZE:
                TILE_K = config.kwargs["TILE_K"] = k_per_core
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
            else:
                # Unpipelined variant: biggest TILE_K fitting in NRAM.
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (3 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 1
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)

                # Pipelined variant: staging buffers shrink the budget.
                config = copy.deepcopy(config)
                max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (6 * TILE_N + 1)
                if output.dtype == torch.float32:
                    max_tile_k_without_pipe = MAX_NRAM_SIZE // 4 // (7 * TILE_N + 1)
                TILE_K = config.kwargs["TILE_K"] = align(max_tile_k_without_pipe)
                num_stages = config.num_stages = 3
                key = (TILE_K, TILE_N, num_warps, num_stages)
                configs_map.setdefault(key, config)
        else:
            # N needs looping anyway; keep the candidate as-is (dedup only).
            key = (TILE_K, TILE_N, num_warps, num_stages)
            configs_map.setdefault(key, config)
    pruned_configs = []
    for k, v in configs_map.items():
        pruned_configs.append(v)
    # Always include the degenerate split: one K element, full row.
    extra_config = copy.deepcopy(pruned_configs[0])
    extra_config.kwargs["TILE_K"] = 1
    extra_config.kwargs["TILE_N"] = N
    extra_config.num_warps = 1
    extra_config.num_stages = 3
    pruned_configs.append(extra_config)
    extra_config2 = copy.deepcopy(extra_config)
    extra_config2.num_stages = 1
    pruned_configs.append(extra_config2)
    return pruned_configs
@libentry()
@libtuner(
    configs=[
        triton.Config({"TILE_K": k, "TILE_N": 2**n}, num_stages=s, num_warps=1)
        for k in [1, 2, 4, 8]
        for n in range(10, 15, 1)
        for s in [1, 3]
    ],
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune3},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_non_inner(
            args["M"], args["N"], args["K"], args["TILE_N"], args["TILE_K"]
        ),
    },
)
@triton.jit
def log_softmax_backward_kernel_non_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Backward of log_softmax over the N axis of an (M, N, K) view:
    # in_grad = out_grad - exp(output) * sum_N(out_grad).
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)  # this core's contiguous share of K
    k_start = pid_k * split_k

    if TILE_MODE == 0:
        # One tile covers all of N and this core's K share.
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_grad_tile, axis=0)
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        # N fits in one tile; loop over this core's K share.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_grad_tile, axis=0)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # N does not fit: accumulate sum_N(out_grad) first, then rewrite.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_grad_tile
            # Masked lanes loaded as 0.0 so they do not perturb the sum.
            scale = tl.sum(scale, axis=0)
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[None, :]
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
def nram_usage_for_backward_inner(bm, bn, tile_mode, num_stages, dtype):
    """Estimate NRAM bytes used by one (bm, bn) block of the backward
    inner kernel.

    The buffer-count coefficient depends on the tile mode, on whether the
    kernel is software-pipelined (num_stages > 1), and on the element
    dtype (fp32 needs wider staging than half types).  All accounting is
    in 4-byte fp32 working elements.
    """
    is_fp32 = dtype == torch.float32
    if tile_mode == 0:
        # Single-pass mode: a flat count of tile-sized buffers.
        buffers = 5 if is_fp32 else 4
        return buffers * bn * bm * 4
    if tile_mode == 1:
        coef = 3 if num_stages == 1 else (8 if is_fp32 else 6)
    else:
        coef = 4 if num_stages == 1 else (11 if is_fp32 else 8)
    # Looped modes carry one extra scalar column per row.
    return (coef * bn + 1) * bm * 4
def config_prune4(configs, named_args, **kwargs):
    """Early-prune autotuner configs for ``log_softmax_backward_kernel_inner``.

    Short-circuits to a single no-loop config when everything fits in
    NRAM; otherwise keeps only candidates whose lowest dimension is
    256-byte aligned and whose estimated NRAM usage lands in the upper
    half of the budget.
    """
    M = named_args["M"]
    N = named_args["N"]
    output = named_args["output_ptr"]
    dtype = output.dtype
    m_per_core = math.ceil(M / TOTAL_CORE_NUM)
    # No need for any loop.
    if nram_usage_for_backward_inner(m_per_core, N, 0, 1, dtype) < MAX_NRAM_SIZE:
        config = copy.deepcopy(configs[0])
        config.kwargs["BLOCK_M"] = m_per_core
        config.kwargs["BLOCK_N"] = N
        config.num_stages = 1
        return [config]
    # Elements per 256 bytes: 64 for 4-byte fp32, 128 for 2-byte types.
    align_num = 256 // 4 if dtype == torch.float32 else 256 // 2
    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, num_stages = (
            kw["BLOCK_M"],
            kw["BLOCK_N"],
            config.num_stages,
        )
        # Align the lowest dimension to 256B while loading/storing data.
        if BLOCK_N % align_num != 0:
            continue
        # nram usage should be smaller than MAX_NRAM_SIZE
        mode = log_softmax_tile_mode_for_inner(M, N, BLOCK_M, BLOCK_N)
        nram = nram_usage_for_backward_inner(BLOCK_M, BLOCK_N, mode, num_stages, dtype)
        # Also drop configs using < half the budget: they under-utilize NRAM.
        if nram > MAX_NRAM_SIZE or nram < MAX_NRAM_SIZE // 2:
            continue
        pruned_configs.append(config)
    return pruned_configs
@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK_N": 64 * k, "BLOCK_M": 2**n}, num_stages=s, num_warps=1)
        for k in range(1, 17)
        for n in range(3, 10, 1)
        for s in [1, 3]
    ],
    key=[
        "N",
        "M",
    ],
    prune_configs_by={"early_config_prune": config_prune4},
)
@triton.heuristics(
    values={
        "TILE_MODE": lambda args: log_softmax_tile_mode_for_inner(
            args["M"], args["N"], args["BLOCK_M"], args["BLOCK_N"]
        ),
    },
)
@triton.jit
def log_softmax_backward_kernel_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Backward of log_softmax over the innermost axis of an (M, N) view:
    # in_grad = out_grad - exp(output) * sum_N(out_grad).
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)  # rows per core
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        # One tile covers this core's rows and the full row width.
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_grad_tile, 1)
        in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        # Full row fits; loop over this core's rows in BLOCK_M steps.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            # NOTE(review): `and` between tensor masks — `&` is the usual
            # triton form; confirm this dialect accepts it.
            mask = m_offset[:, None] < M and n_offset[None, :] < N
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_grad_tile, 1)
            in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # Row does not fit: accumulate sum_N(out_grad), then second pass.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            scale = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_grad_tile
            # Masked lanes loaded as 0.0 so they do not perturb the sum.
            scale = tl.sum(scale, 1)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                # output values are single-use here; hint the cache.
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_first"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_grad_tile - tl.exp(out_tile) * scale[:, None]
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
def log_softmax(self, dim, half_to_float=False):
    """Compute log_softmax of ``self`` along ``dim``.

    The tensor is viewed as (M, N, K) with N the softmax axis.  K > 1
    dispatches the non-inner kernel; otherwise the inner kernel, or a
    three-kernel split-N reduction for very long rows on few cores.

    Args:
        self: input tensor (any layout; made contiguous internally).
        dim: softmax dimension, may be negative.
        half_to_float: when True, accumulate/return in float32.

    Returns:
        New tensor of log-softmax values, same shape as ``self``.
    """
    logger.debug("GEMS_CAMBRICON LOG_SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
    dim = dim % self.ndim
    # Collapse leading dims into M; N is the softmax axis; K the trailing.
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    inp = self.contiguous()
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype
    out = torch.empty_like(inp, dtype=dtype)
    K = inp.numel() // M // N

    with torch_device_fn.device(inp.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON LOGSOFTMAX USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            log_softmax_kernel_non_inner[grid](
                out,
                inp,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON LOGSOFTMAX USE INNER")
            if M > TOTAL_CORE_NUM or N < 1024 * 8 * 8:
                log_softmax_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                    out,
                    inp,
                    M,
                    N,
                )
            else:
                # Very long rows on few cores: split N into tiles and run a
                # three-kernel reduce (partial stats -> merge -> write).
                block_m = 1
                block_n = 8192 * 4
                if dtype is torch.float32:
                    block_n = 8192 * 2
                # workspace
                T = (N + block_n - 1) // block_n
                max_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                sum_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                gmax = torch.empty((M,), device=self.device, dtype=torch.float32)
                gsum = torch.empty((M,), device=self.device, dtype=torch.float32)
                # kernel 1: per-tile stats
                # Fixed: pass the contiguous ``inp`` (previously ``self``,
                # which would be mis-indexed when non-contiguous).
                log_softmax_kernel_inner_k_partial_stats[(TOTAL_CORE_NUM,)](
                    inp,
                    max_buf,
                    sum_buf,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
                # kernel 2: merge stats along N-tiles
                grid_merge = (triton.cdiv(M, block_m),)
                log_softmax_kernel_inner_k_merge_stats[grid_merge](
                    max_buf, sum_buf, gmax, gsum, M, T, BLOCK_M=block_m
                )
                # Smaller tiles for the write pass (it holds more buffers).
                block_n = block_n // 2
                T = (N + block_n - 1) // block_n
                # kernel 3: write normalized outputs
                log_softmax_kernel_inner_k_write_logsoftmax[(TOTAL_CORE_NUM,)](
                    inp,
                    out,
                    gmax,
                    gsum,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
    return out
def log_softmax_backward(grad_output, output, dim, input_dtype):
    """VJP of log_softmax: in_grad = grad_out - exp(output) * sum(grad_out, dim).

    Args:
        grad_output: upstream gradient, same shape as ``output``.
        output: forward log_softmax result.
        dim: softmax dimension, may be negative.
        input_dtype: unused here; kept for the ATen-compatible signature.

    Returns:
        Gradient with respect to the log_softmax input.
    """
    logger.debug("GEMS_CAMBRICON LOG_SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    # Collapse to (M, N, K) with N the softmax axis.
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    # Kernels index ``output`` with contiguous strides as well; a no-op
    # when it is already contiguous (the usual case for forward results).
    output = output.contiguous()
    in_grad = torch.empty_like(output)
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON LOG SOFTMAX VJP USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            log_softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON LOG SOFTMAX VJP USE INNER")
            # Removed a stale, unused grid lambda that referenced a
            # nonexistent meta["TILE_M"]; the kernel runs on a fixed grid.
            log_softmax_backward_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad