Coverage for src/flag_gems/runtime/backend/_ascend/ops/triu.py: 0% (72 statements)
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger namespaced under the flag_gems Ascend backend, keyed by
# this op's module name so log filtering can target individual ops.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,
    Y,
    M,
    N,
    diagonal,
    M_BLOCK_SIZE: tl.constexpr,
    N_BLOCK_SIZE: tl.constexpr,
):
    # Upper-triangular copy of a single (M, N) matrix:
    #   Y[m, n] = X[m, n] if m + diagonal <= n else 0.
    # Each program owns a strip of M_BLOCK_SIZE rows and sweeps the columns
    # in chunks of N_BLOCK_SIZE. The wrapper passes contiguous tensors, so
    # the row stride is N elements.
    pid = tle.program_id(0)
    # Column vector of absolute row indices for this program's strip.
    row = pid * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE)[:, None]
    m_mask = row < M
    # Advance base pointers to the start of each owned row.
    X += row * N
    Y += row * N

    for n_offset in range(0, N, N_BLOCK_SIZE):
        cols = n_offset + tl.arange(0, N_BLOCK_SIZE)[None, :]
        n_mask = cols < N
        # Elementwise logical-and under triton.jit (broadcasts to a 2-D mask).
        mask = m_mask and n_mask
        x = tl.load(X + cols, mask, other=0.0)
        # Keep elements on/above the shifted diagonal, zero the rest.
        y = tl.where(row + diagonal <= cols, x, 0.0)
        tl.store(Y + cols, y, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("triu_batch"),
    key=["batch", "MN", "N", "diagonal"],
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,
    Y,
    batch,
    MN,
    N,
    diagonal,
    BATCH_BLOCK_SIZE: tl.constexpr,
    MN_BLOCK_SIZE: tl.constexpr,
):
    # Batched upper-triangular copy. X and Y are viewed as (batch, MN) with
    # MN == M * N; the matrix row/column of a flat offset are recovered as
    # cols // N and cols % N. Grid axis 0 tiles the batch dimension, axis 1
    # the flattened matrix. The launcher may shrink axis 0 to respect the
    # grid-size limit, so each program loops over several batch tiles.
    batch_id = tle.program_id(0)
    mn_id = tle.program_id(1)
    batch_workers = tle.num_programs(0)  # may be < number of batch tiles

    total_batch_workloads = tl.cdiv(batch, BATCH_BLOCK_SIZE)
    # Round the per-program loop trip count up to a power of two so it is
    # uniform across programs; out-of-range tiles are masked off below.
    # NOTE(review): the bound here is cdiv(batch, total_batch_workloads)
    # (roughly BATCH_BLOCK_SIZE), not cdiv(total_batch_workloads,
    # batch_workers) — confirm this is the intended workload split.
    batch_workloads = 1
    while batch_workloads < tl.cdiv(batch, total_batch_workloads):
        batch_workloads *= 2

    for w in range(batch_workloads):
        # Stride over batch tiles by the number of axis-0 programs.
        batch_work_id = batch_id + w * batch_workers
        row = batch_work_id * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
        batch_mask = row < batch
        NX = X + row * MN
        NY = Y + row * MN

        cols = mn_id * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
        mn_mask = cols < MN
        # Elementwise logical-and under triton.jit.
        mask = batch_mask and mn_mask
        x = tl.load(NX + cols, mask, other=0.0)
        # Recover matrix coordinates from the flat offset.
        m = cols // N
        n = cols % N
        y = tl.where(m + diagonal <= n, x, 0.0)
        tl.store(NY + cols, y, mask=mask)
# Largest signed 32-bit integer. Not referenced in this chunk — presumably a
# guard constant for index arithmetic elsewhere in the file (TODO confirm).
INT32_MAX = torch.iinfo(torch.int32).max
def triu(A, diagonal=0):
    """Return the upper-triangular part of ``A`` (``torch.triu`` equivalent).

    Elements below the ``diagonal``-th diagonal of the last two dimensions
    are zeroed; any leading dimensions are treated as batch. Returns a new
    contiguous tensor with ``A``'s original shape.
    """
    logger.debug("GEMS_ASCEND TRIU")
    # Validate before doing any work or allocating the output.
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    A = A.contiguous()
    out = torch.empty_like(A)
    M, N = A.shape[-2:]
    with torch_device_fn.device(A.device):
        if len(A.shape) == 2:
            # Plain 2-D case: one program per strip of rows.
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](A, out, M, N, diagonal)
        else:
            # Collapse leading dims into a batch axis. Integer floor division
            # avoids the float round-off of int(numel / M / N) on huge tensors.
            batch = A.numel() // (M * N)
            B = A.view(batch, -1)

            def grid(meta):
                axis0 = triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"])
                axis1 = triton.cdiv(M * N, meta["MN_BLOCK_SIZE"])
                # Keep the total program count under the launch limit by
                # shrinking the batch axis; the kernel loops over the tiles
                # that axis 0 no longer covers. Never let axis0 reach 0 —
                # a zero-sized grid would launch nothing and silently return
                # uninitialized memory.
                while axis0 > 1 and axis0 * axis1 >= 65536:
                    axis0 = axis0 // 2
                return (
                    axis0,
                    axis1,
                )

            triu_batch_kernel[grid](
                B,
                out,
                batch,
                M * N,
                N,
                diagonal,
            )
            out = out.view(A.shape)
    return out