Coverage for src/flag_gems/runtime/backend/_cambricon/ops/linspace.py: 0%
39 statements
coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry
10from ..utils import TOTAL_CORE_NUM
# Module logger: a child of the "flag_gems" logger named after this module
# (any leading dots are stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Fill `out_ptr` with `steps` evenly spaced values.
#
# Program `pid` owns the contiguous index range
# [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE) and walks it in chunks of
# INNER_BLOCK_SIZE elements.
#
# Args:
#   out_ptr:          pointer to the output buffer.
#   out_stride0:      element stride of the output along dimension 0.
#   start:            value written at index 0.
#   mid:              indices below `mid` are computed forward from `start`,
#                     the rest backward from `end`.
#   end:              value written at index `steps - 1`.
#   step_size:        spacing between consecutive elements.
#   steps:            total number of elements to write.
#   BLOCK_SIZE:       number of indices owned by each program.
#   INNER_BLOCK_SIZE: chunk width per loop iteration. The launch site does
#                     not pass it, so it comes from the "linspace" heuristic
#                     config.
#
# NOTE(review): start/end/step_size arrive as Python scalars, so the
# arithmetic runs at whatever precision Triton specializes them to. Confirm
# that float64 outputs get float64 math.
@libentry()
@triton.heuristics(runtime.get_heuristic_config("linspace"))
@triton.jit
def linspace_kernel(
    out_ptr,
    out_stride0,
    start,
    mid,
    end,
    step_size,
    steps,
    BLOCK_SIZE: tl.constexpr,
    INNER_BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    # Number of INNER_BLOCK_SIZE-wide chunks needed to cover BLOCK_SIZE.
    numcycle = (BLOCK_SIZE + INNER_BLOCK_SIZE - 1) // INNER_BLOCK_SIZE
    for innerid in range(0, numcycle):
        # Offsets of this chunk within the program's range.
        inneridx = innerid * INNER_BLOCK_SIZE + tl.arange(0, INNER_BLOCK_SIZE)
        idx = pid * BLOCK_SIZE + inneridx
        # Two guards are needed:
        #   - `idx < steps` stops writes past the end of the output.
        #   - `inneridx < BLOCK_SIZE` stops the last (possibly partial) chunk
        #     from spilling into the next program's range.
        mask = (idx < steps) & (inneridx < BLOCK_SIZE)
        # The lower half counts up from `start`. The upper half counts down
        # from `end`, so the last element is `end` itself rather than an
        # accumulated start + step_size * (steps - 1), which can drift by
        # rounding.
        fw_mask = idx < mid
        fw_values = start + (step_size * idx)
        bd_values = end - step_size * (steps - idx - 1)

        out_val = tl.where(fw_mask, fw_values, bd_values)
        tl.store(out_ptr + idx * out_stride0, out_val, mask=mask)
def linspace(
    start, end, steps, *, dtype=None, layout=None, device=None, pin_memory=None
) -> torch.Tensor:
    """Return ``steps`` evenly spaced values from ``start`` to ``end`` inclusive.

    Cambricon backend replacement for ``torch.linspace``.

    Args:
        start: First value of the sequence. A Python scalar or a tensor; a
            tensor is read with ``.item()`` (a 0-dim tensor is used directly
            when ``steps == 1``).
        end: Last value of the sequence (Python scalar or tensor).
        steps: Number of values to generate; must be non-negative.
        dtype, layout, device, pin_memory: Forwarded to ``torch.empty`` when
            allocating the output.

    Returns:
        A 1-D tensor of length ``steps``. It is empty when ``steps == 0`` and
        holds just ``start`` when ``steps == 1``.

    Raises:
        RuntimeError: If ``steps`` is negative.
    """
    logger.debug("GEMS_CAMBRICON LINSPACE")
    # Validate explicitly rather than with ``assert``, which is stripped
    # under ``python -O``. Zero steps is legal and yields an empty tensor.
    if steps < 0:
        raise RuntimeError("linspace(): number of steps must be non-negative")

    out = torch.empty(
        steps,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        # Nothing to compute. Launching the kernel would use BLOCK_SIZE == 0.
        return out
    if steps == 1:
        # (end - start) / (steps - 1) is undefined here; the sole element is
        # start. Fill in place instead of allocating a second tensor.
        return out.fill_(start)

    if isinstance(start, torch.Tensor):
        start = start.item()
    if isinstance(end, torch.Tensor):
        end = end.item()

    # Indices below `mid` are computed forward from `start`, the rest backward
    # from `end`, so both endpoints come out as `start` and `end` themselves.
    mid = steps // 2
    step_size = (float(end) - float(start)) / (steps - 1)
    # Launch TOTAL_CORE_NUM programs; each covers a contiguous chunk of
    # block_size indices (the kernel masks indices >= steps).
    block_size = triton.cdiv(steps, TOTAL_CORE_NUM)
    grid = (TOTAL_CORE_NUM,)
    linspace_kernel[grid](
        out, out.stride(0), start, mid, end, step_size, steps, BLOCK_SIZE=block_size
    )
    return out