Coverage for src/flag_gems/runtime/backend/_cambricon/ops/linspace.py: 0%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import libentry 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("linspace"))
@triton.jit
def linspace_kernel(
    out_ptr,
    out_stride0,
    start,
    mid,
    end,
    step_size,
    steps,
    BLOCK_SIZE: tl.constexpr,
    INNER_BLOCK_SIZE: tl.constexpr,
):
    # Each program owns the BLOCK_SIZE consecutive output indices that begin
    # at program_id * BLOCK_SIZE and walks them in tiles of INNER_BLOCK_SIZE.
    # NOTE(review): INNER_BLOCK_SIZE is not passed by the launcher in this
    # file; presumably the "linspace" heuristic config supplies it - confirm.
    block_base = tl.program_id(0) * BLOCK_SIZE
    tile_count = (BLOCK_SIZE + INNER_BLOCK_SIZE - 1) // INNER_BLOCK_SIZE
    for tile in range(0, tile_count):
        # Offset inside this program's block; the last tile may overshoot
        # BLOCK_SIZE, which the store mask below accounts for.
        local_offs = tile * INNER_BLOCK_SIZE + tl.arange(0, INNER_BLOCK_SIZE)
        offs = block_base + local_offs

        # Indices below `mid` are stepped forward from `start`; the remaining
        # indices are stepped backward from `end`.
        from_start = start + (step_size * offs)
        from_end = end - step_size * (steps - offs - 1)
        values = tl.where(offs < mid, from_start, from_end)

        # Write only lanes that are inside both this program's block and the
        # overall output length.
        in_bounds = (local_offs < BLOCK_SIZE) & (offs < steps)
        tl.store(out_ptr + offs * out_stride0, values, mask=in_bounds)

41 

42 

def linspace(
    start, end, steps, *, dtype=None, layout=None, device=None, pin_memory=None
) -> torch.Tensor:
    """Create a 1-D tensor of ``steps`` evenly spaced values from start to end.

    Args:
        start: First value; a Python scalar or a ``torch.Tensor`` (converted
            with ``.item()`` when more than one step is requested).
        end: Last value; a Python scalar or a ``torch.Tensor`` (converted
            with ``.item()`` when more than one step is requested).
        steps: Number of elements to produce; must be at least 1.
        dtype: Forwarded to ``torch.empty`` for the output allocation.
        layout: Forwarded to ``torch.empty`` for the output allocation.
        device: Forwarded to ``torch.empty`` for the output allocation.
        pin_memory: Forwarded to ``torch.empty`` for the output allocation.

    Returns:
        A tensor with ``steps`` elements. For ``steps == 1`` this is the
        result of ``torch.fill(out, start)``; otherwise it is the buffer
        written by ``linspace_kernel``.

    Raises:
        AssertionError: If ``steps`` is smaller than 1.
    """
    logger.debug("GEMS_CAMBRICON LINSPACE")
    # Raise explicitly instead of using `assert`, which is stripped under
    # `python -O`. The exception type and message are kept so callers that
    # relied on the previous assertion observe the same failure.
    if steps < 1:
        raise AssertionError("steps must be >= 1")

    out = torch.empty(
        steps,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 1:
        # A single element is just `start`; no kernel launch is needed.
        return torch.fill(out, start)

    # The kernel takes plain scalars, so unwrap tensor endpoints first.
    if isinstance(start, torch.Tensor):
        start = start.item()
    if isinstance(end, torch.Tensor):
        end = end.item()

    # Indices below `mid` are computed from `start`, the rest from `end`.
    mid = steps // 2
    step_size = (float(end) - float(start)) / (steps - 1)

    # One program per core; each program covers `block_size` consecutive
    # output elements.
    block_size = triton.cdiv(steps, TOTAL_CORE_NUM)
    grid = (TOTAL_CORE_NUM,)
    linspace_kernel[grid](
        out, out.stride(0), start, mid, end, step_size, steps, BLOCK_SIZE=block_size
    )
    return out