Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/logspace.py: 0%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

@triton.jit
def exp2_tmp(x):
    # Evaluate 2**x through the identity 2**x == exp(x * ln 2).
    # The input is cast to float32 before the multiplication.
    ln2 = 0.69314718056
    x_f32 = x.to(tl.float32)
    return tl.exp(x_f32 * ln2)

18 

19 

@libentry()
@triton.jit
def logspace_kernel(
    out_ptr,
    out_stride0,
    start,
    step_size,
    steps,
    log2_base: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance covers one BLOCK_SIZE-wide run of output indices.
    block_start = tle.program_id(0) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # base**e is evaluated as 2**(log2(base) * e), where the exponent for
    # output index i is e = start + i * step_size.
    scaled_exponent = log2_base * (start + offsets * step_size)
    powers = exp2_tmp(scaled_exponent)

    # Indices at or beyond `steps` are masked out of the store.
    tl.store(out_ptr + offsets * out_stride0, powers, mask=offsets < steps)

39 

40 

def logspace(
    start,
    end,
    steps,
    base=10.0,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
) -> torch.Tensor:
    """Return a 1-D tensor holding ``steps`` powers of ``base``.

    The exponents are evenly spaced from ``start`` to ``end`` inclusive, so
    for ``steps > 1`` element ``i`` is
    ``base ** (start + i * (end - start) / (steps - 1))``.

    Args:
        start: First exponent; a number or a single-element tensor.
        end: Last exponent; a number or a single-element tensor. Unused when
            ``steps`` is 0 or 1.
        steps: Number of elements to produce; must be non-negative.
        base: Base of the powers.
        dtype: Output dtype; ``torch.get_default_dtype()`` when ``None``.
        layout: Forwarded to ``torch.empty``.
        device: Forwarded to ``torch.empty``.
        pin_memory: Forwarded to ``torch.empty``.

    Returns:
        The tensor allocated with the requested options, filled in place.

    Raises:
        AssertionError: If ``steps`` is negative.
    """
    logger.debug("GEMS LOGSPACE")
    assert steps >= 0, "number of steps must be non-negative"

    out_dtype = dtype if dtype is not None else torch.get_default_dtype()
    out = torch.empty(
        steps,
        dtype=out_dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        return out

    if isinstance(start, torch.Tensor):
        start = start.item()

    if steps == 1:
        # Fill in place. The out-of-place torch.fill() returns a different
        # tensor from the one allocated above, so the options requested at
        # allocation (e.g. pin_memory) are not guaranteed to carry over and a
        # second allocation is made for nothing.
        out.fill_(base**start)
        return out

    if isinstance(end, torch.Tensor):
        end = end.item()

    # Normalise once so the kernel receives the same floating-point start
    # value that the step size was derived from, whatever the caller passed.
    start = float(start)
    step_size = (float(end) - start) / (steps - 1)

    BLOCK_SIZE = 256  # according to benchmark, 256 is the best block size
    grid = (triton.cdiv(steps, BLOCK_SIZE),)
    logspace_kernel[grid](
        out,
        out.stride(0),
        start,
        step_size,
        steps,
        log2_base=math.log2(float(base)),  # math.log2 require float input
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out