# src/flag_gems/runtime/backend/_kunlunxin/ops/index_add.py
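"""Triton implementation of index_add for the Kunlunxin backend.

Accumulates ``alpha * src`` into ``inp`` along dimension ``dim`` at the
positions given by ``index``, matching the semantics of
``torch.Tensor.index_add_``.
"""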
import logging

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.utils import dim_compress, libentry

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# def cfggen():
#     block_m = [1, 2, 4]
#     block_n = [128, 1024, 2048, 4096]
#     configs = [
#         triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)
#         for m in block_m
#         for n in block_n
#     ]
#     return configs


@libentry()
# @triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.heuristics(runtime.get_heuristic_config("index_add"))
# @triton.autotune(
#     configs=[], generate_configs="index_add", op_affiliation="cluster", row_sign="M", col_sign="N",
#     key=["M", "N"],
# )
@triton.jit
def index_add_kernel(
    inp,
    inp_cont,
    index,
    src,
    M: tl.constexpr,
    N: tl.constexpr,
    alpha,
    inp_len,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    # Each program handles one BLOCK_M x BLOCK_N tile: rows index the
    # flattened non-indexed dimensions, columns index positions in `index`.
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)

    rows_mask = rows_offsets < M
    index_mask = cols_offsets < N
    block_mask = rows_mask & index_mask

    # Gather the destination columns of `inp` selected by `index`.
    cur_indices = tl.load(index + cols_offsets, mask=index_mask, other=0)
    inp_off = rows_offsets * inp_len + cur_indices[None, :]
    cur_inp = tl.load(inp + inp_off, mask=block_mask, other=0.0)
    # `src` is dense along the indexed dimension, so its row stride is N.
    src_off = rows_offsets * N + cols_offsets[None, :]
    cur_src = tl.load(src + src_off, mask=block_mask, other=0.0)
    cur_inp += alpha * cur_src

    # Note: duplicate values in `index` map to the same output offset; such
    # stores race rather than accumulate, so unique indices are assumed.
    tl.store(inp_cont + inp_off, cur_inp, mask=block_mask)
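

# The wrappers below move `dim` to the last axis with dim_compress so the
# kernel sees 2D layouts: `inp` as (M, inp_len) and `src` as (M, N), where M
# is the product of the remaining dimensions and N = index.numel().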
def index_add(inp, dim, index, src, alpha=1):
    logger.debug("GEMS INDEX ADD")
    assert (
        (0 <= index) & (index < inp.size(dim))
    ).all(), "0 <= index < self.size(dim)"
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    assert (
        index.numel() == src.size(dim)
    ), "The dim-th dimension of src must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()
    M = src.numel() // N
    fine_dim = inp.ndim - 1
    if dim != fine_dim:
        # Move `dim` to the last axis so memory along it is contiguous.
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)
    inp_cont = inp.clone()

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    index_add_kernel[grid](inp, inp_cont, index, src, M, N, alpha, inp_len)
    if dim != fine_dim:
        # Undo dim_compress: move the last axis back to position `dim`.
        order = [i for i in range(inp_cont.ndim - 1)]
        order.insert(dim, fine_dim)
        return inp_cont.permute(order).contiguous()
    else:
        return inp_cont
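

# In-place variant: identical logic, but the result is copied back into `inp`.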
def index_add_(inp, dim, index, src, alpha=1):
    logger.debug("GEMS INDEX ADD_")
    assert (
        (0 <= index) & (index < inp.size(dim))
    ).all(), "0 <= index < self.size(dim)"
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    assert (
        index.numel() == src.size(dim)
    ), "The dim-th dimension of src must have the same size as the length of index"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert all(
        inp.size(i) == src.size(i) or i == dim for i in range(inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"

    inp_cont = inp.clone()
    inp_cont = inp_cont.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    dim = dim % inp_cont.ndim
    inp_len = inp_cont.size(dim)
    N = index.numel()
    M = src.numel() // N
    fine_dim = inp_cont.ndim - 1
    if dim != fine_dim:
        # Move `dim` to the last axis so memory along it is contiguous.
        inp_cont = dim_compress(inp_cont, dim)
        src = dim_compress(src, dim)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    # In-place launch: `inp_cont` serves as both source and destination buffer.
    index_add_kernel[grid](inp_cont, inp_cont, index, src, M, N, alpha, inp_len)
    if dim != fine_dim:
        # Undo dim_compress, then copy the result back into `inp`.
        order = [i for i in range(inp_cont.ndim - 1)]
        order.insert(dim, fine_dim)
        inp_cont = inp_cont.permute(order).contiguous()
        inp.copy_(inp_cont)
        return inp
    else:
        inp.copy_(inp_cont)
        return inp
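

if __name__ == "__main__":
    # Minimal sanity check (not part of the original module). It assumes a
    # device supported by this backend is reachable under the "cuda" alias
    # and compares both wrappers against torch.index_add.
    x = torch.zeros(3, 4, device="cuda")
    s = torch.ones(3, 2, device="cuda")
    idx = torch.tensor([0, 2], device="cuda")
    ref = torch.index_add(x, 1, idx, s, alpha=2.0)
    out = index_add(x, 1, idx, s, alpha=2.0)
    torch.testing.assert_close(out, ref)
    y = x.clone()
    index_add_(y, 1, idx, s, alpha=2.0)
    torch.testing.assert_close(y, ref)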