Coverage for src/flag_gems/ops/tile.py: 99%
255 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
import importlib
import importlib.util
import logging
import os
from typing import Callable, Dict, List, Mapping

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11logger = logging.getLogger(__name__)
# --------------------------- tile wrapper generation -----------------------------------
def parameter_for_wrapper() -> str:
    """Generate the parameter list for the functional tile wrapper.

    Returns:
        str: ``"in0, dims"`` -- the input tensor and the per-dimension
        repetition counts.  (The previous docstring example described the
        elementwise generator's signature, not this one.)
    """
    parameters: List[str] = ["in0", "dims"]
    return ", ".join(parameters)
def parameter_for_wrapper_out() -> str:
    """Generate the parameter list for the destination-passing tile wrapper.

    Returns:
        str: ``"in0, out0"`` -- the input tensor and the pre-allocated
        output tensor.  (The previous docstring example described the
        elementwise generator's signature, not this one.)
    """
    parameters: List[str] = ["in0", "out0"]
    return ", ".join(parameters)
def parameter_ref_for_wrapper() -> str:
    """Generate the argument list used when the functional wrapper calls
    the destination-passing wrapper.

    Returns:
        str: ``"in0, out0"``.  (The previous docstring example,
        ``in0, val0, out0, out0_offset``, was copied from another
        generator and does not match this function's output.)
    """
    parameters: List[str] = ["in0", "out0"]
    return ", ".join(parameters)
def output_ref_for_wrapper() -> str:
    """Name of the output variable assigned in the generated wrapper."""
    return "out0"
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble shared by every generated tile module.

    Emits the stdlib/triton imports, a blank line, the flag_gems helper
    imports, then two trailing blank lines.  Returns the same buffer for
    chaining.
    """
    stdlib_and_triton = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
    )
    flag_gems_helpers = (
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils.shape_utils import volume",
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.utils.type_utils import type_promotion",
        "from flag_gems.utils import triton_lang_extension as tle",
    )
    for stmt in stdlib_and_triton:
        code.writeline(stmt)
    code.newline()
    for stmt in flag_gems_helpers:
        code.writeline(stmt)
    code.newline()
    code.newline()
    return code
def generate_functional_tile_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the functional (allocating) wrapper ``def {wrapper_name}(in0, dims)``.

    The emitted wrapper:
      1. left-pads the shorter of ``in0.shape`` / ``dims`` with 1s so both
         have the same rank;
      2. allocates ``out0`` with ``out_shape[i] = in0_shape[i] * dims_shape[i]``
         (same device/dtype as ``in0``);
      3. reshapes ``in0`` to the padded shape and calls
         ``{destination_passing_func_name}(in0, out0)`` -- unless any
         repetition count is 0, in which case the empty ``out0`` is
         returned without a kernel launch.

    Returns the same ``code`` buffer for chaining.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("in0_rank = in0.dim()")
        code.writeline("dims_rank = len(dims)")
        code.writeline("in0_shape = list(in0.shape)")
        code.writeline("dims_shape = list(dims)")
        code.newline()
        # rank-align the two shapes by prepending ones to the shorter one
        code.writeline("if (dims_rank < in0_rank): ")
        with code.indent():
            code.writeline("diff = in0_rank - dims_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("dims_shape = ones + dims_shape")
        code.writeline("elif (dims_rank > in0_rank): ")
        with code.indent():
            code.writeline("diff = dims_rank - in0_rank")
            code.writeline("ones = [1 for _ in range(diff)]")
            code.writeline("in0_shape = ones + in0_shape")
        code.newline()
        # a repetition count of 0 yields an empty output; remember it so
        # the kernel launch can be skipped in the generated wrapper
        code.writeline("is_empty = False")
        code.writeline("out_shape = []")
        code.writeline("for i in range(len(in0_shape)): ")
        with code.indent():
            code.writeline(
                "assert(dims_shape[i] >= 0), 'the number of repetitions per dimension out of range (expected to >= 0) \
 but got {}'.format(dims_shape[i])"
            )
            code.writeline("if dims_shape[i] == 0: ")
            with code.indent():
                code.writeline("is_empty = True")
            code.writeline("out_shape.append(in0_shape[i] * dims_shape[i])")
        code.newline()
        code.writeline(
            "out0 = torch.empty(out_shape, device=in0.device, dtype=in0.dtype)"
        )

        # reshape to the rank-aligned shape before launching
        code.writeline("in0 = in0.reshape(in0_shape)")
        code.writeline("if not is_empty: ")
        with code.indent():
            # call destination_passing_func
            output_names: str = output_ref_for_wrapper()
            call_str = (
                f"{output_names} = {destination_passing_func_name}"
                f"({parameter_ref_for_wrapper()})"
            )
            code.writeline(call_str)

        return_str = "return out0"
        code.writeline(return_str)
    code.newline()
    code.newline()

    return code
def generate_destination_passing_tile_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper ``def {wrapper_name}(in0, out0)``.

    The emitted wrapper computes the launch configuration (tile size,
    number of CTAs, tiles per CTA) from ``out0``'s element count and
    launches ``kernel_name`` on ``in0``'s device, passing both tensors'
    strides, the task-space shape (out0's shape) and in0's shape so the
    kernel can wrap output indices back into the input.

    NOTE(review): for ``rank == 0`` the emitted launch still references
    ``shape``/``num_tasks``/``tiles_per_cta``/``tile_size``, which are only
    defined when ``rank > 0`` -- the generated wrapper would raise
    ``NameError``.  Presumably rank 0 never reaches this path; verify.

    Returns the same ``code`` buffer for chaining.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # launch configuration: a flat task space over out0's elements
        if rank > 0:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = volume(shape)")

        if rank > 0:
            code.writeline("tile_size = min(512, triton.next_power_of_2(num_tasks))")
            code.writeline("num_warps = 4")
            code.writeline("num_ctas = min(65535, triton.cdiv(num_tasks, tile_size))")
            code.writeline(
                "tiles_per_cta = triton.cdiv(num_tasks, tile_size * num_ctas)"
            )
        else:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        code.newline()

        # input strides for each input tensor w.r.t. the task index space
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            code.writeline("in0_strides = in0.stride()")
            code.writeline("in0_shape = in0.shape")
            code.writeline("out0_strides = out0.stride()")
        code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel on the tensor's own device
        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"in0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for in0")

                    s = ", ".join(f"out0_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out0")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(rank))
                    code.writeline(f"{shape_args}, # task indexing space")
                    in_shape_args: str = ", ".join(f"in0_shape[{i}]" for i in range(rank))
                    code.writeline(
                        f"{in_shape_args}, # task indexing space used when input and ouput tensor has different shape"
                    )
                code.writeline("num_tasks, # num tasks")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                code.writeline("tile_size=tile_size,")
                code.writeline("one_tile_per_cta=tiles_per_cta==1,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

        # return
        code.writeline("return out0")
    code.newline()
    code.newline()
    return code
def _generate_tile_body(
    rank: int,
    code: IndentedBuffer,
) -> None:
    """Emit the per-tile kernel body shared by both launch styles.

    Assumes the generated code has just assigned ``tid`` (the flat task
    indices of the current tile).  Emits masking, multi-index
    reconstruction, the index-wrapped load from in0, the identity
    compute, and the store to out0.
    """
    # only apply masking when rank > 0, since we only load a value instead
    # of a block of values when the rank is 0
    code.writeline("mask = tid < num_tasks")
    code.newline()

    # reconstruct the multi index from the flat task id (innermost
    # dimension varies fastest)
    code.writeline("# multi index recontruction")
    for i in reversed(range(rank)):
        if i > 0:
            code.writeline(f"i{i} = tid % s{i}")
            code.writeline(f"tid //= s{i}")
        else:
            code.writeline(f"i{i} = tid")
    code.newline()

    # loads: wrap each output index back into the input extent
    # (``i % in_s``) so in0 is read periodically -- this implements tile.
    # BUGFIX: previously written as f"... in{i}_stride{j}", silently
    # relying on the loop variable `i` leaked from the loop above (always
    # 0 at that point); spelled out as in0 for correctness and clarity.
    code.writeline("# loads")
    ptrs_expr: str = " + ".join(
        f"(i{j} % in_s{j}) * in0_stride{j}" for j in range(rank)
    )
    ptrs_expr = f"in0_ptr + {ptrs_expr}"
    code.writeline(f"in0 = tl.load({ptrs_expr}, mask=mask)")
    code.newline()

    # compute: tile is a pure copy
    code.writeline("# compute")
    code.writeline("out0 = in0")
    code.newline()

    # stores
    code.writeline("# stores")
    ptrs_expr = " + ".join(f"i{j} * out0_stride{j}" for j in range(rank))
    ptrs_expr = f"out0_ptr + {ptrs_expr}"
    code.writeline(f"tl.store({ptrs_expr}, out0, mask=mask)")


def generate_tile_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the rank-specialized Triton kernel ``kernel_name``.

    The kernel copies in0 into the tiled out0 over a flat task space of
    ``num_tasks`` elements, either as a monolithic launch (one tile per
    CTA) or as a grid-stride loop; the two branches share the body
    emitted by :func:`_generate_tile_body`.  Returns the same ``code``
    buffer for chaining.

    NOTE(review): for ``rank == 0`` the emitted pointer expressions
    degenerate to ``in0_ptr + `` (empty join) and would not compile;
    presumably rank 0 never reaches this generator -- verify.
    """
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # signature: input & output pointers
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        # signature: strides & task-space extents, only when rank > 0
        if rank > 0:
            stride_args = ", ".join(f"in0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for in0")

            stride_args = ", ".join(f"out0_stride{j}: int" for j in range(rank))
            code.writeline(f"{stride_args}, # strides for out0")

            # task space (out0's extents), used to reconstruct multi index
            task_space_args = ", ".join(f"s{i}: int" for i in range(rank))
            code.writeline(f"{task_space_args}, # task_space")

            # in0's extents, used to wrap output indices back into in0
            task_space_args2 = ", ".join(f"in_s{i}: int" for i in range(rank))
            code.writeline(
                f"{task_space_args2}, # task_space2 used when input and output tensor has different shape"
            )

        # number of tasks, used to compute mask
        code.writeline("num_tasks: int,")

        # tile size & tiles_per_cta, gsl style
        if rank > 0:
            code.writeline("tiles_per_cta,")
        code.writeline("tile_size: tl.constexpr,")
        code.writeline("one_tile_per_cta: tl.constexpr,")
    code.writeline("):")

    with code.indent():
        # per-CTA bookkeeping: program id, CTA count, first tile's task ids
        code.writeline("# task id & masking")
        code.writeline("pid = tle.program_id(0)")
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("init_tid = pid * tile_size + tl.arange(0, tile_size)")

        # one-tile-per-cta, monolithic kernel style
        code.writeline("if one_tile_per_cta: # monolitic kernel style")
        with code.indent():
            code.writeline("tid = init_tid")
            _generate_tile_body(rank, code)

        # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
        code.writeline("else: # grid-stride-loop style kernel")
        with code.indent():
            code.writeline("for j in range(0, tiles_per_cta):")
            with code.indent():
                code.writeline("tid = init_tid + j * tile_size * num_ctas")
                _generate_tile_body(rank, code)
    code.newline()
    return code
def generate_code(
    rank: int,
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit a complete generated module for `tile` specialized to `rank`:
    the import preamble, the functional wrapper, the destination-passing
    wrapper, and the Triton kernel, in that order.
    """
    # the only runtime determined factor is the rank of the task space;
    # each generator writes into (and returns) the same buffer
    generate_imports(code)
    generate_functional_tile_wrapper(wrapper_name, destination_passing_func_name, code)
    generate_destination_passing_tile_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    generate_tile_kernel(rank, kernel_name, code)
    return code
class TileFunction:
    """JIT-style dispatcher for ``tile``.

    Generates (once per task-space rank), caches, and calls a
    rank-specialized wrapper module written into the code cache directory.
    """

    def __init__(self):
        self.pid = os.getpid()
        # instantiated & cached overloads, keyed by str(rank).
        # FIX: annotated as Dict (was Mapping, which is immutable and does
        # not admit the item assignment performed in __call__).
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, x, dims):
        # note: kwargs should not be used in JITFunction directly
        ndim = self.arg_key(x, dims)
        key = str(ndim)
        overload = self.overloads.get(key)
        if overload is None:
            # generate file & import it
            code = IndentedBuffer()
            code = generate_code(
                ndim,
                "_wrapper",
                "_wrapper_out",
                "_tile_flaggems_jit_function",
                code,
            )

            file_name = f"tile_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load the generated module from the cache file
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            # do not expose it to sys.modules
            # sys.modules["_add_module"] = m
            spec.loader.exec_module(m)
            overload = getattr(m, "_wrapper")
            self.overloads[key] = overload
        return overload(x, dims)

    def arg_key(self, x, dims):
        # the task-space rank: the shorter of (input rank, len(dims)) is
        # left-padded with 1s by the generated wrapper
        max_rank = max(x.ndim, len(dims))
        return max_rank
# Module-level singleton; caches one generated overload per task-space rank.
_tile_func = TileFunction()
def tile(inp: torch.Tensor, dims) -> torch.Tensor:
    """Repeat ``inp`` along each dimension per ``dims``, dispatching to a
    cached rank-specialized generated kernel."""
    logger.debug("GEMS TILE")
    return _tile_func(inp, dims)