Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/sum.py: 0% (169 statements)
import logging

import torch
import triton
import triton.language as tl

# from flag_gems import runtime
from flag_gems.ops.zeros import zero_
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import dim_compress, libentry
from flag_gems.utils import triton_lang_extension as tle

from ..utils.block_size_utils import get_block_size_1d

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Accumulate half-precision inputs in float32 to limit rounding error.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(inp_val)
    # Each program writes one partial sum; sum_kernel_2 reduces the partials.
    mid_ptr = mid + pid
    tl.store(mid_ptr, sum_val)


@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid + offset
    mask = offset < mid_size
    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
    sum_val = tl.sum(mid_val)
    tl.store(out, sum_val)
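

# sum_kernel_1 and sum_kernel_2 together form a two-stage tree reduction:
# stage one emits one partial sum per program, stage two reduces the partials
# to a scalar. A rough PyTorch equivalent (illustration only, not part of the
# original file; `block` is a hypothetical block size):
#
#   def two_stage_sum(x, block=1024):
#       pad = (-x.numel()) % block
#       mid = torch.nn.functional.pad(x.flatten(), (0, pad)).view(-1, block).sum(1)
#       return mid.sum()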


def heur_m_block_size(args):
    # 12 == cluster_num: split the M rows across the device's clusters.
    return triton.next_power_of_2(triton.cdiv(args["M"], 12))


def heur_n_block_size(args):
    import builtins

    return builtins.min(triton.next_power_of_2(args["N"]), 8192)
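

# Worked example for the heuristics above (illustration only): with
# M == 100_000 rows, cdiv(M, 12) == 8334, so BLOCK_M == next_power_of_2(8334)
# == 16384; with N == 300 columns, BLOCK_N == min(next_power_of_2(300), 8192)
# == 512.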


@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def sum_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute.
    pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    # Sweep the reduction axis in tiles of BLOCK_N columns.
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask & col_mask
        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    sum_val = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, sum_val, row_mask)
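

# For a contiguous (M, N) view, sum_kernel computes out[m] == inp[m, :].sum()
# for every row m; dim_compress below arranges inputs into exactly that layout.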


def sum(inp, *, dtype=None):
    logger.debug("GEMS SUM")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            # torch.sum promotes bool inputs to int64.
            inp = inp.to(torch.int64)
            dtype = torch.int64
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            # A single block already holds the full sum.
            return mid.reshape([])
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
    return out
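

# Hypothetical usage sketch (not part of the original file; "cuda" stands in
# for whatever device string torch_device_fn targets on this backend):
#
#   x = torch.randn(1 << 20, device="cuda", dtype=torch.float16)
#   torch.testing.assert_close(sum(x), x.sum(), rtol=1e-3, atol=1e-3)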


def sum_out(inp, *, dtype=None, out):
    logger.debug("GEMS SUM_OUT")
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64
    block_size = get_block_size_1d(M, inp.element_size())
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size, buffer_size_limit=2048)
        if mid_size == 1:
            # A single partial sum is already the result; copy it into out so
            # the out= contract is honored.
            out.copy_(mid.reshape(out.shape))
            return out
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid, buffer_size_limit=2048)
    return out
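

# Hypothetical out= usage, mirroring torch.sum(input, out=...):
#
#   res = torch.empty([], device=x.device, dtype=x.dtype)
#   sum_out(x, out=res)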


def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    logger.debug("GEMS SUM DIM")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        dtype = torch.int64

    if inp.numel() == 0:
        out_shape = list(inp.shape)
        if dim is None:
            out_shape = [1] * len(out_shape) if keepdim else []
        else:
            dims = dim if isinstance(dim, (list, tuple)) else [dim]
            if keepdim:
                for d in dims:
                    out_shape[d % inp.ndim] = 1
            else:
                # Pop from the back so earlier indices stay valid.
                for d in sorted(dims, key=lambda x: x % inp.ndim, reverse=True):
                    out_shape.pop(d % inp.ndim)
        out = torch.empty(out_shape, dtype=dtype, device=inp.device)
        zero_(out)
        return out

    if dim == []:
        if not keepdim:
            return sum(inp, dtype=dtype)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum(inp, dtype=dtype), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N
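    # Example (illustration): for inp.shape == (4, 5, 6) and dim == [1],
    # dim_compress moves dim 1 to the innermost position and makes inp
    # contiguous, so each of M == 24 rows reduces N == 5 elements.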

    out = torch.empty(shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
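

# Hypothetical sanity check for sum_dim (same device caveat as above):
#
#   x = torch.randn(4, 5, 6, device="cuda")
#   torch.testing.assert_close(sum_dim(x, dim=[1]), x.sum(dim=1))
#   torch.testing.assert_close(
#       sum_dim(x, dim=[0, 2], keepdim=True), x.sum(dim=(0, 2), keepdim=True)
#   )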


def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    logger.debug("GEMS SUM_DIM_OUT")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        dtype = torch.int64

    if inp.numel() == 0:
        # The caller allocates out with the correct shape, so only zeroing is
        # needed here.
        zero_(out)
        return out

    if dim == []:
        if not keepdim:
            return sum_out(inp, dtype=dtype, out=out)
        else:
            dim_num = inp.ndim
            return torch.reshape(sum_out(inp, dtype=dtype, out=out), [1] * dim_num)

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel[grid](inp, out, M, N, buffer_size_limit=2048)
    if not keepdim:
        out.squeeze_(dim=dim)
    return out
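

# Hypothetical out= usage for sum_dim_out; the caller must pre-size out, as
# the empty-input branch above assumes:
#
#   res = torch.empty((4, 6), device=x.device, dtype=x.dtype)
#   sum_dim_out(x, dim=[1], out=res)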