Coverage for src/flag_gems/runtime/backend/_mthreads/ops/arange.py: 0%
82 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
2import math
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.arange import arange_start as default_arange_start
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry
13from flag_gems.utils import triton_lang_extension as tle
# Logger named after this op module (resolves to "...ops.arange").
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Runtime device descriptor for the mthreads (MUSA) backend.
device_ = runtime.device

# Dtypes the Triton arange kernel handles; anything else falls back to
# the default flag_gems implementation (see _use_triton / arange_start).
_SUPPORTED_DTYPES = {
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.int32,
    torch.int64,
}

# Autotune candidates for the kernel launch; selection is keyed on
# n_elements and USE_INT64 in arange_kernel's @triton.autotune decorator.
_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCK_SIZE": 256}, num_warps=4, num_stages=1),
    triton.Config({"BLOCK_SIZE": 512}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
    triton.Config({"BLOCK_SIZE": 2048}, num_warps=8, num_stages=1),
]
@libentry()
@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["n_elements", "USE_INT64"])
@triton.jit(do_not_specialize=["start", "step"])
def arange_kernel(
    out_ptr,
    start,
    step,
    n_elements,
    IS_FLOAT: tl.constexpr,
    USE_INT64: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write out_ptr[i] = start + i * step for i in [0, n_elements).

    Each program instance fills one BLOCK_SIZE-wide slice of the output.
    IS_FLOAT selects fp32 arithmetic via fma; otherwise values are
    computed in int32 or int64, chosen by USE_INT64. `start` and `step`
    are marked do_not_specialize so new scalar values don't trigger
    recompilation.
    """
    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if USE_INT64:
        # Widen indices so large outputs don't overflow 32-bit offsets.
        offsets = offsets.to(tl.int64)
        n_elements = tl.full((1,), n_elements, tl.int64)
    else:
        offsets = offsets.to(tl.int32)
        n_elements = tl.full((1,), n_elements, tl.int32)
    # Guard the ragged final block.
    mask = offsets < n_elements

    if IS_FLOAT:
        # Compute in fp32; the masked store writes out_ptr's element type
        # (for fp16/bf16 outputs this relies on conversion at the store).
        idx = offsets.to(tl.float32)
        step_val = tl.full((1,), step, tl.float32)
        start_val = tl.full((1,), start, tl.float32)
        values = tl.fma(idx, step_val, start_val)
    else:
        value_dtype = tl.int64 if USE_INT64 else tl.int32
        idx = offsets.to(value_dtype)
        step_val = tl.full((1,), step, value_dtype)
        start_val = tl.full((1,), start, value_dtype)
        values = start_val + idx * step_val

    tl.store(out_ptr + offsets, values, mask=mask)
72def _normalize_scalar(value):
73 if isinstance(value, torch.Tensor):
74 return value.item()
75 return value
78def _compute_size(start, end, step, is_float_dtype: bool) -> int:
79 if step == 0:
80 raise ValueError("arange(): step must be non-zero.")
81 if is_float_dtype:
82 size = math.ceil((end - start) / step)
83 else:
84 sgn = (step > 0) - (step < 0)
85 size = (end - start + step - sgn) // step
86 return int(size) if size > 0 else 0
89def _use_triton(dtype: torch.dtype, device: torch.device, size: int) -> bool:
90 if device.type != "musa":
91 return False
92 if dtype not in _SUPPORTED_DTYPES:
93 return False
94 return size > 0
def _launch_triton_kernel(
    out: torch.Tensor,
    start,
    step,
    size: int,
    *,
    is_float_dtype: bool,
    use_int64: bool,
):
    """Launch arange_kernel to fill *out* (1-D, *size* elements); return it."""

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the output.
        return (triton.cdiv(size, meta["BLOCK_SIZE"]),)

    # Target the device that owns the output buffer for the launch.
    with torch_device_fn.device(out.device):
        arange_kernel[grid](
            out,
            start,
            step,
            size,
            IS_FLOAT=is_float_dtype,
            USE_INT64=use_int64,
        )
    return out
def arange_start(
    start,
    end,
    step=1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """torch.arange(start, end, step) for the mthreads backend.

    Dispatches to the Triton kernel when the target is a MUSA device with
    a supported dtype and a non-empty result; otherwise falls back to the
    default flag_gems implementation. Scalar tensor arguments are
    unwrapped to Python scalars first.
    """
    logger.debug("GEMS_MTHREADS ARANGE")
    start = _normalize_scalar(start)
    end = _normalize_scalar(end)
    step = _normalize_scalar(step)

    if dtype is None:
        # Match torch.arange dtype inference: all-integral arguments give
        # int64; any float argument gives the current default float dtype.
        # (Previously this was hard-coded to int64, which made e.g.
        # arange(0.5, 5.0) silently produce an integer tensor.)
        if any(isinstance(v, float) for v in (start, end, step)):
            dtype = torch.get_default_dtype()
        else:
            dtype = torch.int64
    if pin_memory is None:
        pin_memory = False
    device = torch.device(device_.name if device is None else device)

    # dtype.is_floating_point avoids materializing a throwaway tensor
    # just to classify the dtype.
    is_float_dtype = dtype.is_floating_point
    use_int64 = dtype == torch.int64
    size = _compute_size(start, end, step, is_float_dtype)

    if not _use_triton(dtype, device, size):
        return default_arange_start(
            start,
            end,
            step,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
        )

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    return _launch_triton_kernel(
        result,
        start,
        step,
        size,
        is_float_dtype=is_float_dtype,
        use_int64=use_int64,
    )
def arange(
    end,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """torch.arange(end): delegate to arange_start with start=0, step=1."""
    return arange_start(
        0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )