Coverage for src/flag_gems/ops/arange.py: 75%
48 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module-level logger for this op, named after the module.
logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def arange_func(y_ptr, start, end, step, size, BLOCK_SIZE: tl.constexpr):
    # Fill one BLOCK_SIZE-wide tile of the output with values start + i * step,
    # where i is the global element index.
    # Each program instance handles the tile that begins at element
    # pid * BLOCK_SIZE. The host side is expected to launch a 1-D grid of
    # cdiv(size, BLOCK_SIZE) programs (as arange_start in this file does).
    # `end` is not read here: the element count is precomputed on the host
    # and passed in as `size`.
    pid = tle.program_id(0)
    # Advance the output pointer to the first element of this program's tile.
    y_ptr += pid * BLOCK_SIZE
    # Value contributed by the pid * BLOCK_SIZE elements preceding this tile.
    # NOTE(review): pid * BLOCK_SIZE is presumably 32-bit integer arithmetic;
    # confirm behavior for outputs with 2**31 or more elements.
    step_offset = pid * BLOCK_SIZE * step

    cols = tl.arange(0, BLOCK_SIZE)
    arange_val = cols * step + step_offset + start
    # Despite its name, `mask` holds the global element indices of this tile;
    # the actual bounds check against `size` is done in the store below so the
    # last, partially filled tile does not write past the end of the output.
    mask = cols + pid * BLOCK_SIZE
    tl.store(y_ptr + cols, arange_val, mask=mask < size)
def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Return a 1-D tensor holding ``start, start + step, ...`` up to ``end``.

    Mirrors ``torch.arange(start, end, step, ...)``; the values are written by
    the ``arange_func`` Triton kernel.

    Args:
        start: First value of the sequence (inclusive).
        end: Bound of the sequence (exclusive).
        step: Spacing between consecutive values; must be nonzero and point
            from ``start`` towards ``end``.
        dtype: Output dtype. When ``None`` it is inferred like ``torch.arange``:
            ``torch.int64`` if ``start``, ``end`` and ``step`` are all
            integral, otherwise ``torch.get_default_dtype()``.
        layout: Accepted for signature compatibility; ignored.
        device: Output device; defaults to ``runtime.device.name``.
        pin_memory: Forwarded to ``torch.empty``; ``None`` means ``False``.

    Returns:
        A tensor of shape ``(size,)`` on ``device`` with dtype ``dtype``.

    Raises:
        RuntimeError: If a float bound is not finite, ``step`` is zero, or the
            sign of ``step`` does not match the direction from ``start`` to
            ``end``.
    """
    logger.debug("GEMS ARANGE")

    if dtype is None:
        # torch.arange only picks int64 when every argument is integral;
        # defaulting to int64 unconditionally would silently truncate float
        # ranges such as arange(0.0, 1.0, 0.25) to all zeros.
        if any(isinstance(v, float) for v in (start, end, step)):
            dtype = torch.get_default_dtype()
        else:
            dtype = torch.int64

    for bound in (start, end):
        if isinstance(bound, float) and not math.isfinite(bound):
            raise RuntimeError(f"unsupported range: {start} -> {end}")

    if dtype is torch.int64:
        # Truncate to Python ints so the element count below is computed with
        # exact integer arithmetic (no float rounding for large bounds).
        start, end, step = int(start), int(end), int(step)

    # Validate for every dtype; a zero step must not reach the division below.
    if step == 0:
        raise RuntimeError("step must be nonzero")
    if (step > 0 and start > end) or (step < 0 and start < end):
        raise RuntimeError("upper bound and larger bound inconsistent with step sign")

    if dtype is torch.int64:
        # ceil((end - start) / step) in exact integer arithmetic.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        size = int(math.ceil((end - start) / step))

    if pin_memory is None:
        pin_memory = False

    if device is None:
        # Note(Zhengzekang): Torch default value is CPU, but triton is target to GPU.
        device = runtime.device.name

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    if size == 0:
        # Empty range: nothing to fill, so skip the kernel launch entirely.
        return result

    BLOCK_SIZE = 128
    grid = (triton.cdiv(size, BLOCK_SIZE),)
    arange_func[grid](result, start, end, step, size, BLOCK_SIZE)
    return result
def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return values from 0 up to (but excluding) ``end`` in unit steps.

    Thin wrapper over ``arange_start`` with ``start=0`` and ``step=1``; every
    keyword option is forwarded unchanged.
    """
    options = dict(dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
    return arange_start(0, end, 1, **options)