Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/linspace.py: 0%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8from flag_gems.utils import triton_lang_extension as tle 

9 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11 

12 

# Kernel: fill a 1-D output with `steps` values evenly spaced from `start`
# to `end` (both endpoints included).
@libentry()
@triton.jit
def linspace_kernel(
    out_ptr,        # pointer to the output tensor's data
    out_stride0,    # element stride of the output's dim 0
    start,          # first value of the sequence
    mid,            # split index: i < mid filled forward from start, i >= mid backward from end
    end,            # last value of the sequence
    step_size,      # spacing between consecutive values, (end - start) / (steps - 1)
    steps,          # total number of elements to write
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    # Each program instance covers one contiguous block of indices.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < steps
    fw_mask = idx < mid
    # Lower half computed forward from `start` (idx=0 yields exactly `start`);
    # upper half computed backward from `end` (idx=steps-1 yields exactly `end`).
    # The split limits accumulated floating-point error at both endpoints.
    fw_values = start + (step_size * idx)
    bd_values = end - step_size * (steps - idx - 1)

    out_val = tl.where(fw_mask, fw_values, bd_values)
    # Store strided so non-contiguous 1-D outputs are handled correctly.
    tl.store(out_ptr + idx * out_stride0, out_val, mask=mask)

34 

35 

def linspace(
    start, end, steps, *, dtype=None, layout=None, device=None, pin_memory=None
) -> torch.Tensor:
    """Return a 1-D tensor of `steps` values evenly spaced from `start` to `end`.

    Mirrors ``torch.linspace``: both endpoints are included. ``start`` and
    ``end`` may be Python numbers or 0-dim tensors.

    Args:
        start: First value of the sequence (number or 0-dim tensor).
        end: Last value of the sequence (number or 0-dim tensor).
        steps: Number of elements; must be a non-negative integer.
        dtype, layout, device, pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A 1-D tensor of length ``steps``.

    Raises:
        RuntimeError: If ``steps`` is negative.
    """
    logger.debug("GEMS LINSPACE")
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and torch.linspace also raises RuntimeError here.
    if steps < 0:
        raise RuntimeError("number of steps must be non-negative")

    out = torch.empty(
        steps,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        # Nothing to fill; matches torch.linspace(start, end, 0).
        return out
    if steps == 1:
        # Degenerate sequence: the single element is `start`.
        # In-place fill_ preserves the tensor's allocation (e.g. pin_memory).
        return out.fill_(start)

    # The kernel expects plain Python scalars for start/end.
    if isinstance(start, torch.Tensor):
        start = start.item()
    if isinstance(end, torch.Tensor):
        end = end.item()
    # First half is filled forward from `start`, second half backward from
    # `end`, keeping both endpoints exact.
    mid = steps // 2
    step_size = (float(end) - float(start)) / (steps - 1)
    BLOCK_SIZE = 128
    grid = (triton.cdiv(steps, BLOCK_SIZE),)
    linspace_kernel[grid](
        out, out.stride(0), start, mid, end, step_size, steps, BLOCK_SIZE=BLOCK_SIZE
    )
    return out