Coverage for src/flag_gems/runtime/backend/_metax/ops/arange.py: 0% of 37 statements
import logging
import math

import torch
import triton
import triton.language as tl

# Runtime device descriptor, needed for the default-device fallback in
# arange_start below (import path follows the mainline flag_gems arange op).
from flag_gems.runtime import device as device_
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger("flag_gems." + __name__)


@libentry()
@triton.jit
def arange_func(y_ptr, start, end, step, size, BLOCK_SIZE: tl.constexpr):
    # Each program fills one BLOCK_SIZE-sized slice of the output.
    pid = tle.program_id(0)
    y_ptr += pid * BLOCK_SIZE
    step_offset = pid * BLOCK_SIZE * step

    cols = tl.arange(0, BLOCK_SIZE)
    arange_val = cols * step + step_offset + start
    # The global element index doubles as the bounds mask for the final,
    # possibly partial, block.
    global_idx = cols + pid * BLOCK_SIZE
    tl.store(y_ptr + cols, arange_val, mask=global_idx < size)
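

# Worked example of the block decomposition (illustrative values): with
# BLOCK_SIZE=4, start=0, step=2, size=10, program pid=1 writes offsets 4..7:
#   step_offset = 1 * 4 * 2 = 8
#   arange_val  = [0, 1, 2, 3] * 2 + 8 + 0 = [8, 10, 12, 14]
#   global_idx  = [4, 5, 6, 7], all < 10, so every lane is stored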


def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    logger.debug("METAX GEMS ARANGE")
    if dtype is torch.int64:
        # Exact ceiling division that stays correct for negative steps.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        size = math.ceil((end - start) / step)
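
    # Example: start=0, end=10, step=3 on the int64 path gives sgn=1 and
    # size = (10 - 0 + 3 - 1) // 3 = 4, i.e. the values 0, 3, 6, 9. The float
    # path computes math.ceil(10 / 3) = 4 as well; the integer form avoids
    # floating-point rounding for int64 ranges beyond 2**53.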

    BLOCK_SIZE = 1024
    grid = triton.cdiv(size, BLOCK_SIZE)
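    # Example: size=4097 gives grid = cdiv(4097, 1024) = 5 programs; the
    # fifth program stores only its first lane (global index 4096).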

    if dtype is None:
        dtype = torch.int64

    if pin_memory is None:
        pin_memory = False

    if device is None:
        # Note(Zhengzekang): Torch's default device is CPU, but Triton targets
        # the GPU. The original `device.name` dereferenced the None just
        # tested; fall back to the runtime device descriptor instead.
        device = torch.device(device_.name)

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    arange_func[(grid,)](result, start, end, step, size, BLOCK_SIZE)
    return result
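

# Usage sketch (assumes a GPU visible to Triton; the device string is
# illustrative):
#   arange_start(5, 20, 5)  -> tensor([ 5, 10, 15], device='cuda:0')
#   arange_start(0, 1, 0.25, dtype=torch.float32)
#       -> tensor([0.0000, 0.2500, 0.5000, 0.7500], device='cuda:0')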


def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    return arange_start(
        0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
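

if __name__ == "__main__":
    # Minimal self-check sketch (an illustrative addition; assumes flag_gems
    # is installed and a Triton-visible GPU is available).
    out = arange(10)
    ref = torch.arange(10, device=out.device)
    assert torch.equal(out, ref), (out, ref)
    print("metax arange matches torch.arange")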