Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/triu.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import builtins
2import logging
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_m_block_size(args):
    """Row-block size: rows split across 12 clusters, rounded up to a power of two."""
    rows_per_cluster = triton.cdiv(args["M"], 12)  # 12 == cluster_num
    return triton.next_power_of_2(rows_per_cluster)
def heur_n_block_size(args):
    """Column-tile width: the full row length N, capped at 8192."""
    n = args["N"]
    return n if n < 8192 else 8192
@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.heuristics(
    values={
        "M_BLOCK_SIZE": heur_m_block_size,
        "N_BLOCK_SIZE": heur_n_block_size,
    },
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,  # input pointer: contiguous (M, N) matrix
    Y,  # output pointer: same shape/layout as X
    M,
    N,
    diagonal,  # offset of the retained diagonal (as in torch.triu)
    M_BLOCK_SIZE: tl.constexpr,
    N_BLOCK_SIZE: tl.constexpr,
):
    # Upper-triangular copy of a 2-D matrix: each program owns a block of
    # M_BLOCK_SIZE rows and sweeps across all N columns in N_BLOCK_SIZE tiles.
    pid = tle.program_id(0)
    row = pid * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE)[:, None]
    m_mask = row < M
    # Advance base pointers to the first element of each row in this block
    # (row-major, contiguous layout: stride of a row is N).
    X += row * N
    Y += row * N

    for n_offset in range(0, N, N_BLOCK_SIZE):
        cols = n_offset + tl.arange(0, N_BLOCK_SIZE)[None, :]
        n_mask = cols < N
        mask = m_mask and n_mask

        x = tl.load(X + cols, mask, other=0.0)
        # Keep element (r, c) when c >= r + diagonal, zero it otherwise.
        y = tl.where(row + diagonal <= cols, x, 0.0)
        tl.store(Y + cols, y, mask=mask)
def heur_batch_block_size(args):
    """Batch-block size: batches split across 12 clusters, rounded up to a power of two."""
    per_cluster = triton.cdiv(args["batch"], 12)  # 12 == cluster_num
    return triton.next_power_of_2(per_cluster)
def heur_mn_block_size(args):
    """Tile size along the flattened M*N axis, capped at 8192."""
    mn = args["MN"]
    if mn > 8192:
        return 8192
    return mn
@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("triu_batch"),
#     key=["batch", "MN", "N", "diagonal"],
# )
@triton.heuristics(
    {
        "BATCH_BLOCK_SIZE": heur_batch_block_size,
        "MN_BLOCK_SIZE": heur_mn_block_size,
    }
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,  # input pointer: (batch, M*N) flattened contiguous matrices
    Y,  # output pointer: same layout as X
    batch,
    MN,  # M * N: flattened length of one matrix
    N,  # number of columns of each matrix
    diagonal,  # offset of the retained diagonal (as in torch.triu)
    BATCH_BLOCK_SIZE: tl.constexpr,
    MN_BLOCK_SIZE: tl.constexpr,
):
    # Batched upper-triangular copy: 2-D launch grid — axis 0 tiles the batch
    # dimension, axis 1 tiles the flattened M*N elements of each matrix.
    batch_id = tle.program_id(0)
    mn_id = tle.program_id(1)
    row = batch_id * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
    batch_mask = row < batch
    # Advance base pointers to the start of each matrix in this batch block.
    X += row * MN
    Y += row * MN

    cols = mn_id * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
    mn_mask = cols < MN
    mask = batch_mask and mn_mask
    x = tl.load(X + cols, mask, other=0.0)
    # Recover the 2-D coordinates (m, n) within each matrix from the flat offset.
    m = cols // N
    n = cols % N
    # Keep element (m, n) when n >= m + diagonal, zero it otherwise.
    y = tl.where(m + diagonal <= n, x, 0.0)
    tl.store(Y + cols, y, mask=mask)
# Largest signed 32-bit integer (2**31 - 1).
# NOTE(review): unused in this chunk — presumably kept for index-width checks
# elsewhere; confirm before removing.
INT32_MAX = torch.iinfo(torch.int32).max
def triu(A, diagonal=0):
    """Return the upper-triangular part of ``A`` over its last two dimensions.

    Mirrors ``torch.triu``: elements at (r, c) with ``c < r + diagonal`` are
    zeroed. The input is made contiguous before launching the kernels.

    Args:
        A: input tensor of shape (..., M, N); must have at least 2 dims.
        diagonal: diagonal offset, as in ``torch.triu`` (default 0).

    Returns:
        A new tensor with the same shape as ``A``.
    """
    logger.debug("GEMS TRIU")
    # Validate before allocating the output.
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    A = A.contiguous()
    out = torch.empty_like(A)
    M, N = A.shape[-2:]
    with torch_device_fn.device(A.device):
        if len(A.shape) == 2:
            # Plain 2-D case: 1-D grid over row blocks.
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](A, out, M, N, diagonal)
        else:
            # Fold all leading dims into a single batch dim. Use exact integer
            # arithmetic: float division (int(numel / M / N)) can lose
            # precision for very large tensors.
            batch = A.numel() // (M * N)
            B = A.view(batch, -1)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                B,
                out,
                batch,
                M * N,
                N,
                diagonal,
            )
            out = out.view(A.shape)
    return out