Coverage report for src/flag_gems/runtime/backend/_mthreads/ops/index_add.py: 0% of 79 statements covered
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
7from flag_gems.utils import dim_compress, libentry
8from flag_gems.utils import triton_lang_extension as tle
# Module-level logger; its name mirrors this module's location in the
# flag_gems mthreads backend package tree.
_module_name = __name__.split(".")[-1]
logger = logging.getLogger(f"flag_gems.runtime.backend._mthreads.ops.{_module_name}")
def cfggen():
    """Build the autotune search space for the index_add kernel.

    Returns a list of triton.Config objects covering combinations of
    BLOCK_M x BLOCK_N tile shapes and warp counts, keeping only tiles
    with at most 16384 elements per program.
    """
    tile_rows = [1, 2, 4, 8, 16]
    tile_cols = [64, 128, 256, 512, 1024, 2048]
    warp_counts = [4, 8, 16]
    configs = []
    for bm in tile_rows:
        for bn in tile_cols:
            for nw in warp_counts:
                # Cap the per-program tile at 16384 elements.
                if bm * bn <= 16384:
                    configs.append(
                        triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_warps=nw)
                    )
    return configs
@libentry()
@triton.autotune(configs=cfggen(), key=["M", "N"])
@triton.jit
def index_add_kernel(
    inp_ptr,  # (M, inp_len) input values, target dim last
    out_ptr,  # (M, inp_len) output buffer; callers may alias it with inp_ptr
    index_ptr,  # (N,) integer indices into the last dim of inp/out
    src_ptr,  # (M, N) source values to accumulate
    M,
    N,
    alpha,  # scalar multiplier applied to src
    inp_len,  # size of the (compressed) target dimension
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Kernel for index_add operation with autotune.

    After dim_compress, tensors are reshaped so that:
    - inp has shape (M, inp_len) where inp_len is the size of target dimension
    - src has shape (M, N) where N is the size of index

    For each row m and each index position n:
        out[m, index[n]] += alpha * src[m, n]

    NOTE(review): the load/add/store sequence below is non-atomic, so
    duplicate values in `index` can lose updates (two lanes read the same
    stale value and the last store wins). This appears to assume unique
    indices — confirm against callers.
    """
    pid_m = tle.program_id(axis=0)
    pid_n = tle.program_id(axis=1)

    # Calculate row and column offsets for this program's tile
    rows_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]

    # Create masks for the ragged tile edges
    rows_mask = rows_offset < M
    cols_mask = cols_offset < N
    block_mask = rows_mask & cols_mask

    # Load indices for this block of columns (masked lanes read index 0,
    # but their results are discarded by block_mask on the store)
    cur_indices = tl.load(index_ptr + cols_offset, mask=cols_mask, other=0)

    # Calculate gather/scatter offsets into inp/out (shape M x inp_len)
    inp_off = rows_offset * inp_len + cur_indices

    # Load current values from input
    cur_inp = tl.load(inp_ptr + inp_off, mask=block_mask, other=0.0)

    # Calculate offsets into src (which has shape M x N, row-major)
    src_off = rows_offset * N + cols_offset

    # Load source values
    cur_src = tl.load(src_ptr + src_off, mask=block_mask, other=0.0)

    # Compute: out = inp + alpha * src
    result = cur_inp + alpha * cur_src

    # Store result (non-atomic; see NOTE above)
    tl.store(out_ptr + inp_off, result, mask=block_mask)
def index_add(inp, dim, index, src, alpha=1):
    """
    Optimized index_add for mthreads backend.

    self.index_add_(dim, index, source, alpha=1) -> Tensor

    For a 3-D tensor the output is:
        self[index[i], :, :] += alpha * src[i, :, :]  # if dim == 0
        self[:, index[i], :] += alpha * src[:, i, :]  # if dim == 1
        self[:, :, index[i]] += alpha * src[:, :, i]  # if dim == 2

    Args:
        inp: input tensor (not modified; a new tensor is returned).
        dim: dimension along which to index; may be negative.
        index: 1-D integer tensor of indices into ``inp`` along ``dim``.
        src: values to add; same shape as ``inp`` except along ``dim``,
            where its size equals ``index.numel()``.
        alpha: scalar multiplier applied to ``src``.

    Returns:
        A new tensor equal to ``inp`` with ``alpha * src`` accumulated at
        the indexed positions.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD")

    # Make inputs contiguous
    inp = inp.contiguous()
    index = index.contiguous()
    src = src.contiguous()

    # Normalize dimension
    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()

    # Empty index is a no-op (matches torch.index_add) and would otherwise
    # raise ZeroDivisionError in the M computation below.
    if N == 0:
        return inp.clone()

    M = src.numel() // N

    # Move target dim to last position for coalesced memory access
    final_dim = inp.ndim - 1
    if dim != final_dim:
        inp = dim_compress(inp, dim)
        src = dim_compress(src, dim)

    # Clone input for output
    out = inp.clone()

    # Calculate grid with autotune
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    # NOTE(review): kernel stores are non-atomic — duplicate entries in
    # `index` may lose updates; confirm callers guarantee uniqueness.
    with torch_device_fn.device(inp.device):
        index_add_kernel[grid](inp, out, index, src, M, N, alpha, inp_len)

    # Restore original dimension order if needed
    if dim != final_dim:
        # Invert dim_compress: insert the (now-last) target dim back at `dim`
        order = list(range(out.ndim - 1))
        order.insert(dim, final_dim)
        return out.permute(order).contiguous()
    else:
        return out
def index_add_(inp, dim, index, src, alpha=1):
    """
    In-place version of index_add.

    Accumulates ``alpha * src`` into ``inp`` along ``dim`` at positions
    given by ``index`` and returns ``inp``.

    Args:
        inp: tensor modified in place and returned.
        dim: dimension along which to index; may be negative.
        index: 1-D integer tensor of indices into ``inp`` along ``dim``.
        src: values to add; same shape as ``inp`` except along ``dim``,
            where its size equals ``index.numel()``.
        alpha: scalar multiplier applied to ``src``.

    Returns:
        ``inp``, after the in-place update.
    """
    logger.debug("GEMS_MTHREADS INDEX ADD_")

    # Make index and src contiguous
    index = index.contiguous()
    src = src.contiguous()

    # Normalize dimension
    dim = dim % inp.ndim
    inp_len = inp.size(dim)
    N = index.numel()

    # Empty index is an in-place no-op (matches torch.Tensor.index_add_)
    # and would otherwise raise ZeroDivisionError in the M computation.
    if N == 0:
        return inp

    M = src.numel() // N

    # Move target dim to last position
    final_dim = inp.ndim - 1

    # Grid shared by both launch paths below (autotune picks the block sizes)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )

    if dim != final_dim:
        # Need to work on a permuted copy
        inp_work = dim_compress(inp.clone().contiguous(), dim)
        src_work = dim_compress(src, dim)

        with torch_device_fn.device(inp.device):
            # In-place on the working copy: inp_work is both input and output
            index_add_kernel[grid](
                inp_work, inp_work, index, src_work, M, N, alpha, inp_len
            )

        # Restore original dimension order and copy back into inp
        order = list(range(inp_work.ndim - 1))
        order.insert(dim, final_dim)
        inp_work = inp_work.permute(order).contiguous()
        inp.copy_(inp_work)
    else:
        # Can work directly on input if already contiguous
        inp_contig = inp.contiguous()

        with torch_device_fn.device(inp.device):
            index_add_kernel[grid](
                inp_contig, inp_contig, index, src, M, N, alpha, inp_len
            )

        # Copy back if input wasn't contiguous (contiguous() made a copy)
        if not inp.is_contiguous():
            inp.copy_(inp_contig)

    return inp