Coverage for src/flag_gems/runtime/backend/_ascend/ops/linspace.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8from flag_gems.utils import triton_lang_extension as tle 

9 

10logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}') 

11 

12 

@libentry()
@triton.jit
def linspace_kernel(
    out_ptr,
    out_stride0,
    start,
    mid,
    end,
    step_size,
    steps,
    BLOCK_SIZE: tl.constexpr,
):
    """Store `steps` evenly spaced values through `out_ptr`.

    Elements whose index is below `mid` are computed forward from `start`;
    all others are computed backward from `end`. Each program processes a
    run of consecutive `BLOCK_SIZE`-wide tiles, so the launch grid may hold
    fewer programs than there are tiles.
    """
    program = tle.program_id(0)
    n_programs = tle.num_programs(0)

    # Split the tiles evenly (rounding up) across the launched programs.
    n_tiles = tl.cdiv(steps, BLOCK_SIZE)
    tiles_per_program = tl.cdiv(n_tiles, n_programs)

    lane = tl.arange(0, BLOCK_SIZE)
    first_tile = program * tiles_per_program
    for tile in range(0, tiles_per_program):
        offs = (first_tile + tile) * BLOCK_SIZE + lane
        # Rounding up above can push trailing tiles past the end of the output.
        in_bounds = offs < steps

        from_start = start + (step_size * offs)
        from_end = end - step_size * (steps - offs - 1)
        values = tl.where(offs < mid, from_start, from_end)

        tl.store(out_ptr + offs * out_stride0, values, mask=in_bounds)

38 

39 

def linspace(
    start, end, steps, *, dtype=None, layout=None, device=None, pin_memory=None
) -> torch.Tensor:
    """Return a 1-D tensor of `steps` values evenly spaced from `start` to `end`.

    Args:
        start: First value; a Python number or a single-element tensor.
        end: Last value; a Python number or a single-element tensor.
        steps: Number of elements to generate. ``0`` yields an empty tensor
            and ``1`` yields a tensor holding only `start`.
        dtype, layout, device, pin_memory: Forwarded unchanged to
            ``torch.empty`` when allocating the result.

    Returns:
        The newly allocated tensor of length `steps`.

    Raises:
        AssertionError: If `steps` is negative.
    """
    logger.debug("GEMS_ASCEND LINSPACE")
    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O`. AssertionError is kept so callers see the same exception
    # type as before.
    if steps < 0:
        raise AssertionError("steps must be >= 0")

    out = torch.empty(
        steps,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        # Nothing to compute; also avoids launching a zero-sized grid.
        return out

    # Unwrap tensor endpoints once so both paths below work on plain scalars.
    if isinstance(start, torch.Tensor):
        start = start.item()
    if isinstance(end, torch.Tensor):
        end = end.item()

    if steps == 1:
        # Fill the buffer that was just allocated rather than creating a
        # second tensor.
        return out.fill_(start)

    mid = steps // 2
    step_size = (float(end) - float(start)) / (steps - 1)
    BLOCK_SIZE = 128

    def grid(meta):
        # Halve the program count until it is below 65536; the kernel loops
        # over several tiles per program to cover the remainder.
        # NOTE(review): 65536 is presumably a launch-dimension limit of the
        # target device; confirm.
        dim0 = triton.cdiv(steps, BLOCK_SIZE)
        while dim0 >= 65536:
            dim0 = triton.cdiv(dim0, 2)
        return (dim0,)

    linspace_kernel[grid](
        out, out.stride(0), start, mid, end, step_size, steps, BLOCK_SIZE=BLOCK_SIZE
    )
    return out