Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sum.py: 0%
142 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry, libtuner
11from ..utils import MAX_GRID_SIZE_X, TOTAL_CORE_NUM, cfggen_reduce_op
# Module-level logger: a child of the "flag_gems" logger, named after this
# module's __name__ with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=cfggen_reduce_op(), key=["M"], strategy=["log"], reset_to_zero=["out"]
)
@triton.jit
def sum_kernel_1(
    inp,
    out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Add the sum of the first ``M`` elements at ``inp`` into ``out``.

    Every program starts at its own ``BLOCK_SIZE`` chunk and advances by
    ``num_programs * BLOCK_SIZE`` until it passes ``M``, keeping one running
    total per lane.  The lanes are then reduced to a scalar which is
    atomically added to the value already stored at ``out``.

    Args:
        inp: Pointer to the input elements, addressed as a flat buffer.
        out: Pointer to the scalar that receives the accumulated sum.
        M: Number of input elements to reduce.
        BLOCK_SIZE: Compile-time number of elements handled per step.
    """
    # float16 / bfloat16 inputs are accumulated in float32; every other
    # element type is accumulated in that same type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = inp.dtype.element_ty

    program = tl.program_id(0)
    program_count = tl.num_programs(axis=0)
    stride = program_count * BLOCK_SIZE
    # The starting offset is widened to int64 before it drives the loop.
    first = (program * BLOCK_SIZE).to(tl.int64)
    lanes = tl.arange(0, BLOCK_SIZE)
    partial = tl.zeros([BLOCK_SIZE], dtype=acc_dtype)
    for base in range(first, M, stride):
        idx = base + lanes
        # Lanes past the end of the input contribute zero.
        vals = tl.load(inp + idx, mask=idx < M, other=0.0)
        partial = vals + partial

    # Accumulates on top of whatever `out` currently holds.
    tl.atomic_add(out, tl.sum(partial))
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("sum"),
    key=["M", "N"],
    strategy=["log", "log"],
)
@triton.jit
def sum_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Sum each row of an ``M`` x ``N`` buffer and store one value per row.

    Element ``(r, c)`` is read from ``inp + r * N + c`` and the sum of row
    ``r`` is stored to ``out + r``.  Each program handles blocks of
    ``BLOCK_M`` rows, starting at its own program id and stepping by the
    number of programs until all ``cdiv(M, BLOCK_M)`` row blocks are done.

    Args:
        inp: Pointer to the input elements.
        out: Pointer to the per-row results.
        M: Number of rows.
        N: Number of elements reduced per row.
        BLOCK_M: Compile-time number of rows handled per step.
        BLOCK_N: Compile-time number of columns loaded per inner step.
    """
    # Accumulator type: float32 for half-precision inputs, int32 for int1
    # inputs, otherwise the input element type itself.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    elif tl.constexpr(inp.dtype.element_ty == tl.int1):
        cdtype = tl.int32
    else:
        cdtype = inp.dtype.element_ty
    prog_num = tl.num_programs(0).to(tl.uint64)
    sub_pid = tl.program_id(0).to(tl.uint64)
    task_num = tl.cdiv(M, BLOCK_M).to(tl.uint64)
    while sub_pid < task_num:
        # Map the program id to the rows of inp it should compute.
        rows = sub_pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
        inp_ = inp + rows * N
        out_ = out + rows
        row_mask = rows < M

        _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
        for off in range(0, N, BLOCK_N):
            cols = off + tl.arange(0, BLOCK_N)[None, :]
            col_mask = cols < N
            # Element-wise AND of the two boolean masks.  `&` is used
            # instead of Python's `and`, which newer Triton releases
            # deprecate for non-scalar tensors.
            mask = row_mask & col_mask

            a = tl.load(inp_ + cols, mask, other=0).to(cdtype)
            _sum += a
        # Named `row_sum` so the local does not shadow `sum`.
        row_sum = tl.sum(_sum, axis=1)[:, None]
        tl.store(out_, row_sum, row_mask)
        sub_pid += prog_num
def sum(inp, *, dtype=None):
    """Return the sum of all elements of ``inp`` as a 0-d tensor.

    Args:
        inp: Input tensor.
        dtype: Result dtype.  When omitted, the input dtype is used, except
            that bool inputs produce an ``int32`` result.

    Returns:
        A 0-d tensor of ``dtype`` on ``inp.device``.
    """
    logger.debug("GEMS_CAMBRICON SUM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int32
    # sum_kernel_1 accumulates in the input element type and has no int1
    # branch, so bool inputs are widened before launch even when the caller
    # supplied an explicit dtype.
    if inp.dtype is torch.bool:
        inp = inp.to(torch.int32)
    # The kernel addresses the input as a flat buffer of numel() elements,
    # which only matches the tensor's contents when it is contiguous.
    inp = inp.contiguous()
    M = inp.numel()

    # Zero-initialised because the kernel accumulates with atomic_add.
    out = torch.zeros([], dtype=dtype, device=inp.device)
    if M == 0:
        # Nothing to reduce: skip the launch (the grid would be empty).
        return out

    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, out, M)
    return out.to(dtype)
def sum_out(inp, *, dtype=None, out):
    """Sum all elements of ``inp`` into the caller-provided tensor ``out``.

    Args:
        inp: Input tensor.
        dtype: Result dtype.  When omitted, the input dtype is used, except
            that bool inputs use ``int32``.
        out: Tensor that receives the result; it is overwritten.

    Returns:
        ``out.to(dtype)``, which is ``out`` itself when its dtype already
        matches ``dtype``.
    """
    logger.debug("GEMS_CAMBRICON SUM_OUT")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int32
    # sum_kernel_1 accumulates in the input element type and has no int1
    # branch, so bool inputs are widened before launch even when the caller
    # supplied an explicit dtype.
    if inp.dtype is torch.bool:
        inp = inp.to(torch.int32)
    # The kernel addresses the input as a flat buffer of numel() elements,
    # which only matches the tensor's contents when it is contiguous.
    inp = inp.contiguous()
    M = inp.numel()

    # The kernel adds into `out` with atomic_add, so any previous contents
    # must be cleared or they would be folded into the result.
    out.zero_()
    if M == 0:
        # Nothing to reduce: skip the launch (the grid would be empty).
        return out.to(dtype)

    grid = lambda meta: (min(triton.cdiv(M, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[grid](inp, out, M)
    return out.to(dtype)
def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over the dimensions listed in ``dim``.

    Args:
        inp: Input tensor.
        dim: List of dimensions to reduce (negative indices allowed).
            ``None`` defers to ``torch.sum`` over all elements; an empty
            list reduces over all elements via ``sum``.
        keepdim: Keep reduced dimensions with size 1 when true.
        dtype: Result dtype.  When omitted, the input dtype is used, except
            that bool inputs produce an ``int64`` result.

    Returns:
        The reduced tensor.
    """
    logger.debug("GEMS_CAMBRICON SUM DIM")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return result

    # A 0-d input has a single element, so any reduction over it is a full
    # reduction; routing it here also avoids `d % 0` below.
    if dim == [] or inp.ndim == 0:
        result = sum(inp, dtype=dtype)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return result

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # N is the number of elements folded into each output value; `shape`
    # becomes the keepdim=True output shape.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Empty reduction: every output element (if any) is 0.  Handling it
        # here avoids dividing by N == 0 and launching an empty grid.
        out = torch.zeros(shape, dtype=dtype, device=inp.device)
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N
        out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (
            min(triton.cdiv(M, meta["BLOCK_M"]), MAX_GRID_SIZE_X // 4),
        )
        with torch_device_fn.device(inp.device):
            sum_kernel[grid](inp, out, M, N)

    if not keepdim:
        out = out.squeeze(dim=dim)
    return out
def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """Sum ``inp`` over ``dim`` and write the result into ``out``.

    Args:
        inp: Input tensor.
        dim: List of dimensions to reduce (negative indices allowed).
            ``None`` defers to ``torch.sum`` over all elements; an empty
            list reduces over all elements via ``sum_out``.
        keepdim: Keep reduced dimensions with size 1 when true.
        dtype: Result dtype.  When omitted, the input dtype is used, except
            that bool inputs use ``int64``.
        out: Tensor that receives the result.

    Returns:
        The reduced tensor.
    """
    logger.debug("GEMS_CAMBRICON SUM_DIM_OUT")
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None:
        # NOTE(review): this path returns a fresh tensor and never writes to
        # `out` -- confirm whether callers rely on `out` being filled here.
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return result

    # A 0-d input has a single element, so any reduction over it is a full
    # reduction; routing it here also avoids `d % 0` below.
    if dim == [] or inp.ndim == 0:
        result = sum_out(inp, dtype=dtype, out=out)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return result

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # N is the number of elements folded into each output value.
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1

    if inp.numel() == 0:
        # Empty reduction: every output element (if any) is 0.  Handling it
        # here avoids dividing by N == 0 and launching an empty grid.
        out.zero_()
    else:
        inp = dim_compress(inp, dim)
        M = inp.numel() // N

        grid = lambda meta: (
            min(triton.cdiv(M, meta["BLOCK_M"]), MAX_GRID_SIZE_X // 4),
        )
        with torch_device_fn.device(inp.device):
            sum_kernel[grid](inp, out, M, N)

    if not keepdim:
        # NOTE(review): squeezes the caller's tensor in place, which assumes
        # `out` was passed with the reduced dimensions kept as size 1.
        out.squeeze_(dim=dim)
    return out