# src/flag_gems/runtime/backend/_kunlunxin/ops/arange.py
import logging
import math

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.utils import libentry
from flag_gems.utils import triton_lang_extension as tle

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

@libentry()
@triton.jit
def arange_func(
    y_ptr,
    start,
    end,
    step,
    size,
    BLOCK_SIZE: tl.constexpr,
    buffer_size_limit: tl.constexpr,
):
    # Each program instance fills one contiguous BLOCK_SIZE chunk of the output.
    pid = tle.program_id(0)
    y_ptr += pid * BLOCK_SIZE
    step_offset = pid * BLOCK_SIZE * step

    cols = tl.arange(0, BLOCK_SIZE)
    arange_val = cols * step + step_offset + start
    # Global element index of each lane; mask the store so the final,
    # possibly partial, block does not write past `size`.
    global_idx = cols + pid * BLOCK_SIZE
    tl.store(y_ptr + cols, arange_val, mask=global_idx < size)
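

# Hypothetical helper, not part of the original file: a plain-Python sketch of
# the kernel's block/offset arithmetic above, handy for reasoning about the
# indexing. Block `pid` starts at value start + pid * BLOCK_SIZE * step, and
# the final block is truncated to `size`, mirroring the masked store.
def _arange_reference(start, step, size, block_size):
    out = [0] * size
    num_blocks = -(-size // block_size)  # ceil division, like triton.cdiv
    for pid in range(num_blocks):
        step_offset = pid * block_size * step
        for col in range(block_size):
            idx = pid * block_size + col
            if idx < size:  # same predicate as the store mask
                out[idx] = col * step + step_offset + start
    return out
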

def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    logger.debug("GEMS ARANGE")
    if dtype is torch.int64:
        start = int(start)
        end = int(end)
        step = int(step)
        if step == 0:
            raise RuntimeError("step must be nonzero")
        # Integer ceiling division that is exact for either sign of step:
        # sgn is +1 for a positive step and -1 for a negative one.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        # The original nested `dtype is torch.int64` check here was
        # unreachable inside this branch; guard the step directly instead.
        if step == 0:
            raise RuntimeError("step must be nonzero")
        size = int(math.ceil((end - start) / step))
    # Default the factory kwargs before sizing the launch, so element_size()
    # below reflects the dtype actually used for the output.
    if dtype is None:
        dtype = torch.int64
    if pin_memory is None:
        pin_memory = False
    if device is None:
        device = (
            runtime.device.name
        )  # Note(Zhengzekang): Torch default value is CPU, but triton is target to GPU.

    # Spread the work across the device's compute clusters (cluster_num = 12
    # on this KunLunXin backend) and cap BLOCK_SIZE at the per-block buffer
    # budget of 2048 * 64 bytes.
    cluster_num = 12
    tmp = torch.tensor([], dtype=dtype)
    BLOCK_SIZE = min(
        triton.next_power_of_2(triton.cdiv(size, cluster_num)),
        int(2048 * 64 / tmp.element_size()),
    )
    grid = triton.cdiv(size, BLOCK_SIZE)
    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    arange_func[(grid,)](
        result, start, end, step, size, BLOCK_SIZE, buffer_size_limit=2048
    )
    return result
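
# For example (values hypothetical), arange_start(2, 10, 2, dtype=torch.int64)
# computes size = (10 - 2 + 2 - 1) // 2 = 4 and fills tensor([2, 4, 6, 8]).
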
def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    return arange_start(
        0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
    )
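

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file. Assumes a
    # flag_gems-supported KunLunXin device is available and initialized;
    # otherwise the kernel launch inside arange_start will fail.
    ref = torch.arange(0, 17, 3)
    got = arange_start(0, 17, 3, dtype=torch.int64)
    assert torch.equal(got.cpu(), ref), (got, ref)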