Coverage for src/flag_gems/runtime/backend/_cambricon/ops/triu.py: 0%
140 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-20 02:31 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils.shape_utils import can_use_int32_index
12from ..utils import TOTAL_CORE_NUM
# Per-module child of the package-wide "flag_gems" logger; lstrip(".")
# defensively drops any leading dots from __name__ (a no-op for normal
# absolute module names).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,  # pointer to the input matrix, dense row-major (M, N)
    Y,  # pointer to the output matrix, same layout as X
    M,  # number of rows
    N,  # number of columns (== row stride, since the layout is dense)
    diagonal,  # diagonal offset, as in torch.triu
    M_BLOCK_SIZE: tl.constexpr,  # rows handled per program per sweep
    N_BLOCK_SIZE: tl.constexpr,  # columns handled per inner tile
    NEED_LOOP: tl.constexpr,  # True when N_BLOCK_SIZE < N (must tile columns)
    INT64_INDEX: tl.constexpr = False,  # use 64-bit offsets for large tensors
):
    # Computes Y = triu(X, diagonal) for a single (M, N) matrix.
    pid = tl.program_id(0)
    if INT64_INDEX:
        pid = pid.to(tl.int64)
    num_jobs = tl.num_programs(0)
    m_block_step = M_BLOCK_SIZE * num_jobs

    # Grid-stride loop over row blocks: each program starts at its own block
    # and advances by the total number of rows covered by one full sweep.
    for m_offset in range(pid * M_BLOCK_SIZE, M, m_block_step):
        if NEED_LOOP:
            # A full row does not fit in one tile: iterate over column tiles.
            row = m_offset + tl.arange(0, M_BLOCK_SIZE)[:, None]
            m_mask = row < M
            PX = X + row * N
            PY = Y + row * N
            for n_offset in range(0, N, N_BLOCK_SIZE):
                cols = n_offset + tl.arange(0, N_BLOCK_SIZE)[None, :]
                n_mask = cols < N
                mask = m_mask and n_mask  # in-bounds rows AND columns

                x = tl.load(PX + cols, mask, other=0.0)
                # Keep elements on/above the shifted diagonal, zero the rest.
                y = tl.where(row + diagonal <= cols, x, 0.0)
                tl.store(PY + cols, y, mask=mask)
        else:
            # Whole rows fit in a single tile (N_BLOCK_SIZE >= N): stage the
            # block row by row, then store it in one shot.
            # NOTE(review): tl.empty and per-row item assignment below look
            # like Cambricon-specific Triton extensions — confirm upstream.
            write = tl.empty([M_BLOCK_SIZE, N_BLOCK_SIZE], X.dtype.element_ty)
            cols = tl.arange(0, N_BLOCK_SIZE)
            repeat_num = tl.minimum(M_BLOCK_SIZE, M - m_offset)
            for i in tl.range(repeat_num, num_stages=0):
                cur_row = m_offset + i
                PX = X + cur_row * N
                # The masked load doubles as the triu itself: positions below
                # the diagonal are not loaded and default to 0.
                rmask = cols >= cur_row + diagonal
                write[i, :] = tl.load(PX + cols, mask=rmask, other=0.0)

            # Rows at index >= repeat_num were never written; the row < M
            # store mask keeps them from reaching memory.
            row = m_offset + tl.arange(0, M_BLOCK_SIZE)
            offset = cols[None, :] + row[:, None] * N
            n_mask = row[:, None] < M
            tl.store(Y + offset, write, mask=n_mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("triu_batch"),
    key=["batch", "MN", "N", "diagonal"],
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,  # pointer to the input, viewed as a dense (batch, M*N) matrix
    Y,  # pointer to the output, same layout as X
    batch,  # number of matrices in the batch
    MN,  # elements per matrix (M * N)
    N,  # columns per matrix
    diagonal,  # diagonal offset, as in torch.triu
    BATCH_BLOCK_SIZE: tl.constexpr,  # matrices handled per program (grid axis 0)
    MN_BLOCK_SIZE: tl.constexpr,  # flat elements handled per program (grid axis 1)
    INT64_INDEX: tl.constexpr = False,  # use 64-bit offsets for large tensors
):
    # Batched Y = triu(X, diagonal): a 2D grid tiles (batch, M*N).
    batch_id = tl.program_id(0)
    mn_id = tl.program_id(1)
    if INT64_INDEX:
        batch_id = batch_id.to(tl.int64)
        mn_id = mn_id.to(tl.int64)
    # One "row" here is one whole matrix of the batch.
    row = batch_id * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
    batch_mask = row < batch
    X += row * MN
    Y += row * MN

    cols = mn_id * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
    mn_mask = cols < MN
    mask = batch_mask and mn_mask  # in-bounds matrices AND elements
    x = tl.load(X + cols, mask, other=0.0)
    # Recover the 2D coordinates within each matrix from the flat offset.
    m = cols // N
    n = cols % N
    # Keep elements on/above the shifted diagonal, zero the rest.
    y = tl.where(m + diagonal <= n, x, 0.0)
    tl.store(Y + cols, y, mask=mask)
104def _check_batch_contiguous(tensor, allow_zero_stride=True):
105 if tensor.is_contiguous():
106 return True, tensor
108 dims = tensor.dim()
110 if dims >= 2:
111 n = tensor.size(-1)
112 stride_row, stride_col = tensor.stride(-2), tensor.stride(-1)
114 if not (stride_col == 1 and stride_row == n):
115 return False, tensor.contiguous()
117 if allow_zero_stride and dims <= 3:
118 return True, tensor
120 expected_stride = tensor.size(-1) * tensor.size(-2)
121 for i in range(dims - 3, -1, -1):
122 if (
123 allow_zero_stride
124 and i == 0
125 and (tensor.stride(i) == 0 or tensor.size(i) == 1)
126 ):
127 continue
129 if tensor.stride(i) != expected_stride:
130 return False, tensor.contiguous()
132 expected_stride *= tensor.size(i)
134 return True, tensor
def triu(A, diagonal=0):
    """Out-of-place ``torch.triu``: upper-triangular part of ``A``.

    Elements below the ``diagonal``-th diagonal of the last two dimensions
    are zeroed; the result is a freshly allocated contiguous tensor.

    Args:
        A: input tensor with at least 2 dimensions.
        diagonal: diagonal offset, as in ``torch.triu``.

    Returns:
        A new contiguous tensor with ``A``'s shape and dtype.
    """
    logger.debug("GEMS_CAMBRICON TRIU")
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"

    # The kernels require a batch-contiguous layout; fall back to a
    # contiguous copy otherwise.  The usability flag itself is irrelevant
    # here because the output is always freshly allocated.
    _, A_input = _check_batch_contiguous(A, allow_zero_stride=False)
    out = torch.empty(
        A.shape, dtype=A.dtype, device=A.device, memory_format=torch.contiguous_format
    )

    M, N = A_input.shape[-2:]
    use_int64_index = not can_use_int32_index(A_input)
    with torch_device_fn.device(A_input.device):
        if A_input.dim() == 2:
            grid = lambda meta: (
                min(triton.cdiv(M, meta["M_BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            # A large N_BLOCK_SIZE can exhaust MLU on-chip resources and make
            # compilation fail, so the column tile is conservatively capped
            # at 256 KiB worth of elements; the true upper bound should be
            # confirmed on real hardware.
            n_block = min(256 * 1024 // A_input.element_size(), N)
            triu_kernel[grid](
                A_input,
                out,
                M,
                N,
                diagonal,
                N_BLOCK_SIZE=n_block,
                NEED_LOOP=n_block < N,
                INT64_INDEX=use_int64_index,
            )
        else:
            # Flatten all leading dimensions into a single batch axis.
            batch = A_input.numel() // (M * N)
            B = A_input.view(batch, -1)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                B, out, batch, M * N, N, diagonal, INT64_INDEX=use_int64_index
            )
    # out was allocated with A.shape already, so no reshape is needed.
    return out
def triu_(A, diagonal=0):
    """In-place ``torch.triu_``: zero elements below ``diagonal`` in ``A``.

    When ``A`` already has a batch-contiguous layout the kernels run on it
    directly (source == destination); otherwise the result is computed into
    a contiguous temporary and copied back, so arbitrary layouts still get a
    correct in-place update.

    Args:
        A: input tensor with at least 2 dimensions, mutated in place.
        diagonal: diagonal offset, as in ``torch.triu``.

    Returns:
        ``A`` (mutated).
    """
    logger.debug("GEMS_CAMBRICON TRIU_(inplace)")
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    diagonal = int(diagonal)
    M, N = A.shape[-2:]

    can_use_directly, A_to_use = _check_batch_contiguous(A, allow_zero_stride=True)
    # When usable directly, A_to_use is A itself, so indexing width is the
    # same either way.
    use_int64_index = not can_use_int32_index(A_to_use)

    def _launch(src, dst):
        # Launch the appropriate kernel writing triu(src) into dst.  Both
        # tensors must be (batch-)contiguous with A's shape; src and dst may
        # alias for the in-place path.
        if len(A.shape) == 2:
            grid = lambda meta: (
                min(triton.cdiv(M, meta["M_BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            # A large N_BLOCK_SIZE can exhaust MLU on-chip resources and make
            # compilation fail, so the column tile is conservatively capped
            # at 256 KiB worth of elements; the true upper bound should be
            # confirmed on real hardware.
            n_block = min(256 * 1024 // A.element_size(), N)
            triu_kernel[grid](
                src,
                dst,
                M,
                N,
                diagonal,
                N_BLOCK_SIZE=n_block,
                NEED_LOOP=n_block < N,
                INT64_INDEX=use_int64_index,
            )
        else:
            batch = A.numel() // (M * N)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                src.view(batch, -1),
                dst.view(batch, -1),
                batch,
                M * N,
                N,
                diagonal,
                INT64_INDEX=use_int64_index,
            )

    with torch_device_fn.device(A.device):
        if can_use_directly:
            # A can serve as both source and destination.
            _launch(A, A)
        else:
            logger.debug(
                "Input tensor does not satisfy contiguity requirements, "
                "using temporary tensor for computation"
            )
            result_temp = torch.empty_like(
                A_to_use, memory_format=torch.contiguous_format
            )
            _launch(A_to_use, result_temp)
            A.copy_(result_temp)
    return A