# Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/linspace.py: 0%
# 34 statements
# coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
8from flag_gems.utils import triton_lang_extension as tle
# Module logger, namespaced under the shared "flag_gems" logger.
# NOTE(review): lstrip(".") is a no-op unless __name__ starts with a dot —
# presumably intended to drop a leading "." from relative module names; confirm.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def linspace_kernel(
    out_ptr,  # pointer to the output buffer
    out_stride0,  # element stride of the output's (only) dimension
    start,  # first value of the sequence
    mid,  # split index: idx < mid filled forward, the rest backward
    end,  # last value of the sequence
    step_size,  # spacing between consecutive values
    steps,  # total number of elements to write
    BLOCK_SIZE: tl.constexpr,  # elements handled per program instance
):
    """Fill out_ptr with `steps` evenly spaced values from `start` to `end`.

    Elements with idx < mid are computed forward from `start`
    (start + step_size * idx); the remaining elements are computed backward
    from `end` (end - step_size * (steps - idx - 1)). Anchoring the tail on
    `end` makes the final element exactly `end` (its offset term is zero),
    rather than an accumulation of rounded steps.
    """
    pid = tle.program_id(0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < steps  # guard the final, possibly partial block
    fw_mask = idx < mid
    fw_values = start + (step_size * idx)
    bd_values = end - step_size * (steps - idx - 1)
    out_val = tl.where(fw_mask, fw_values, bd_values)
    tl.store(out_ptr + idx * out_stride0, out_val, mask=mask)
def linspace(
    start, end, steps, *, dtype=None, layout=None, device=None, pin_memory=None
) -> torch.Tensor:
    """Return a 1-D tensor of `steps` values evenly spaced from `start` to `end`.

    Drop-in for ``torch.linspace`` on this backend.

    Args:
        start: First value of the sequence (number or scalar tensor).
        end: Last value of the sequence (number or scalar tensor).
        steps: Number of elements; must be non-negative.
        dtype, layout, device, pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A tensor of shape ``(steps,)``; empty when ``steps == 0``.

    Raises:
        RuntimeError: If ``steps`` is negative. (Previously an ``assert``,
        which is stripped under ``python -O``; also rejected steps == 0,
        which torch.linspace accepts.)
    """
    logger.debug("GEMS LINSPACE")
    if steps < 0:
        raise RuntimeError("number of steps must be non-negative")
    out = torch.empty(
        steps,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    # Degenerate sizes never need the kernel.
    if steps == 0:
        return out  # empty tensor, matching torch.linspace(..., steps=0)
    if steps == 1:
        return torch.fill(out, start)  # single element: just `start`
    # The kernel takes Python scalars; unwrap 0-dim tensor arguments.
    if isinstance(start, torch.Tensor):
        start = start.item()
    if isinstance(end, torch.Tensor):
        end = end.item()
    mid = steps // 2  # forward-fill below mid, backward-fill from `end` above
    step_size = (float(end) - float(start)) / (steps - 1)
    BLOCK_SIZE = 128
    grid = (triton.cdiv(steps, BLOCK_SIZE),)
    linspace_kernel[grid](
        out, out.stride(0), start, mid, end, step_size, steps, BLOCK_SIZE=BLOCK_SIZE
    )
    return out