Coverage for src/flag_gems/runtime/backend/_ascend/ops/linspace.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
8from flag_gems.utils import triton_lang_extension as tle
# Module-level logger, namespaced under the Ascend backend so ops can be
# filtered per-module (e.g. "flag_gems.runtime._ascend.ops.linspace").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def linspace_kernel(
    out_ptr,  # pointer to the 1-D output tensor
    out_stride0,  # element stride of the output's dim 0
    start,  # first value of the sequence
    mid,  # index (steps // 2) where we switch from the forward to the backward formula
    end,  # last value of the sequence
    step_size,  # spacing between consecutive values: (end - start) / (steps - 1)
    steps,  # total number of elements to write
    BLOCK_SIZE: tl.constexpr,  # elements handled per inner iteration
):
    # Each program processes `loop_counts` consecutive chunks of BLOCK_SIZE
    # elements, so the whole range [0, steps) is covered even when the launch
    # grid was capped below cdiv(steps, BLOCK_SIZE) (see the grid fn in the
    # host-side wrapper, which halves the grid while it is >= 65536).
    pid = tle.program_id(0)
    pnum = tle.num_programs(0)
    work_loads = tl.cdiv(steps, BLOCK_SIZE)
    loop_counts = tl.cdiv(work_loads, pnum)
    for loop in range(0, loop_counts):
        # Global element indices for this chunk; mask guards the final
        # partial block (idx >= steps).
        idx = (pid * loop_counts + loop) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = idx < steps
        # First half counts forward from `start`, second half counts backward
        # from `end` — presumably to keep the endpoints exact and limit
        # accumulated float error at the tail (NOTE(review): intent inferred,
        # both formulas are algebraically equivalent).
        fw_mask = idx < mid
        fw_values = start + (step_size * idx)
        bd_values = end - step_size * (steps - idx - 1)
        out_val = tl.where(fw_mask, fw_values, bd_values)
        tl.store(out_ptr + idx * out_stride0, out_val, mask=mask)
def linspace(
    start, end, steps, *, dtype=None, layout=None, device=None, pin_memory=None
) -> torch.Tensor:
    """Return a 1-D tensor of ``steps`` values evenly spaced from ``start`` to ``end``.

    Mirrors ``torch.linspace``: ``steps == 0`` yields an empty tensor,
    ``steps == 1`` yields ``[start]``, otherwise both endpoints are included.

    Args:
        start: First value (number or 0-dim tensor).
        end: Last value (number or 0-dim tensor).
        steps: Number of elements; must be >= 0.
        dtype / layout / device / pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A new 1-D tensor of length ``steps``.

    Raises:
        ValueError: If ``steps`` is negative.
    """
    logger.debug("GEMS_ASCEND LINSPACE")
    # Explicit raise instead of `assert`: asserts are stripped under `-O`,
    # which would let a negative `steps` reach torch.empty.
    if steps < 0:
        raise ValueError(f"steps must be >= 0, got {steps}")
    out = torch.empty(
        steps,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        # Empty result, consistent with torch.linspace(start, end, 0).
        return out
    if steps == 1:
        # In-place fill keeps the tensor we allocated above (and its
        # layout/pin_memory); torch.fill would allocate and return a copy.
        return out.fill_(start)
    # The kernel takes scalar endpoints; unwrap 0-dim tensor arguments.
    if isinstance(start, torch.Tensor):
        start = start.item()
    if isinstance(end, torch.Tensor):
        end = end.item()
    mid = steps // 2
    step_size = (float(end) - float(start)) / (steps - 1)
    BLOCK_SIZE = 128

    def grid(meta):
        # Cap the launch grid below 65536 programs; the kernel loops over
        # multiple chunks per program to cover the remainder.
        dim0 = triton.cdiv(steps, BLOCK_SIZE)
        while dim0 >= 65536:
            dim0 = triton.cdiv(dim0, 2)
        return (dim0,)

    linspace_kernel[grid](
        out, out.stride(0), start, mid, end, step_size, steps, BLOCK_SIZE=BLOCK_SIZE
    )
    return out