Coverage for src/flag_gems/runtime/backend/_cambricon/ops/logspace.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import libentry 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit
def logspace_kernel(
    out_ptr,
    out_stride0,
    start,
    step_size,
    steps,
    log2_base: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``2 ** (log2_base * (start + i * step_size))`` for every ``i < steps``.

    Args:
        out_ptr: Pointer to the output buffer.
        out_stride0: Element stride applied to the index when storing.
        start: Exponent of the first element.
        step_size: Increment of the exponent between consecutive elements.
        steps: Number of elements to write.
        log2_base: ``log2`` of the logspace base (compile-time constant).
        BLOCK_SIZE: Number of elements handled per tile (compile-time constant).
    """
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    lane = tl.arange(0, BLOCK_SIZE)

    # The launch grid may contain fewer programs than there are tiles (the
    # caller caps it), so each program walks over tiles with a stride of
    # ``num_jobs * BLOCK_SIZE`` instead of handling a single tile. When the
    # grid already covers every tile the loop body runs at most once per
    # program, which matches the previous behaviour.
    first_tile = pid * BLOCK_SIZE
    tile_stride = num_jobs * BLOCK_SIZE
    for tile_start in range(first_tile, steps, tile_stride):
        idx = tile_start + lane
        mask = idx < steps  # guard the partially filled last tile

        exponent = start + idx * step_size
        # base ** x == 2 ** (log2(base) * x)
        vals = tl.exp2(log2_base * exponent)

        tl.store(out_ptr + idx * out_stride0, vals, mask=mask)

34 

35 

def logspace(
    start,
    end,
    steps,
    base=10.0,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
) -> torch.Tensor:
    """Create a 1-D tensor of ``steps`` values ``base ** e``.

    The exponents ``e`` are evenly spaced from ``start`` to ``end``.

    Args:
        start: First exponent; a number or a tensor (read via ``.item()``).
        end: Last exponent; a number or a tensor (read via ``.item()``).
        steps: Number of output elements; must be non-negative.
        base: Base of the power. Defaults to ``10.0``.
        dtype: Output dtype; ``torch.get_default_dtype()`` when ``None``.
        layout: Forwarded to ``torch.empty``.
        device: Forwarded to ``torch.empty``.
        pin_memory: Forwarded to ``torch.empty``.

    Returns:
        The freshly allocated tensor of length ``steps``.

    Raises:
        AssertionError: If ``steps`` is negative.
    """
    logger.debug("GEMS_CAMBRICON LOGSPACE")
    assert steps >= 0, "number of steps must be non-negative"

    result = torch.empty(
        steps,
        dtype=torch.get_default_dtype() if dtype is None else dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )

    # Nothing to compute for an empty result.
    if steps == 0:
        return result

    # Both remaining paths need ``start`` as a Python scalar.
    if isinstance(start, torch.Tensor):
        start = start.item()

    # A single element is just base ** start; no kernel launch required.
    if steps == 1:
        return torch.fill(result, base**start)

    if isinstance(end, torch.Tensor):
        end = end.item()

    exponent_step = (float(end) - float(start)) / (steps - 1)
    block_size = 256  # according to benchmark, 256 is the best block size

    def grid(meta):
        # One program per tile, capped at TOTAL_CORE_NUM programs.
        return (min(triton.cdiv(steps, block_size), TOTAL_CORE_NUM),)

    logspace_kernel[grid](
        result,
        result.stride(0),
        start,
        exponent_step,
        steps,
        log2_base=math.log2(float(base)),  # math.log2 require float input
        BLOCK_SIZE=block_size,
    )

    return result