Coverage for src/flag_gems/ops/logspace.py: 76%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from ..utils import libentry 

9from ..utils import triton_lang_extension as tle 

10 

# Module-level logger; the logspace entry point emits its debug trace here.
logger = logging.getLogger(name=__name__)

12 

13 

@libentry()
@triton.jit
def logspace_kernel(
    out_ptr,
    out_stride0,
    start,
    step_size,
    steps,
    log2_base: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Writes out[i] = 2 ** ((start + i * step_size) * log2_base) for each
    # output index i handled by this program. The host passes
    # log2_base = log2(base), so each element equals base ** exponent.

    # Each program instance covers BLOCK_SIZE consecutive output indices.
    block_start = tle.program_id(0) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may run past `steps`; masked lanes are not stored.
    in_bounds = offsets < steps

    # Linear exponent, rescaled to base 2 so exp2 can be used.
    values = tl.exp2((start + offsets * step_size) * log2_base)

    # Element i lives at out_ptr + i * out_stride0 (strided 1-D output).
    tl.store(out_ptr + offsets * out_stride0, values, mask=in_bounds)

33 

34 

def logspace(
    start,
    end,
    steps,
    base=10.0,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
) -> torch.Tensor:
    """Return ``steps`` values spaced evenly on a log scale.

    The exponents run linearly from ``start`` to ``end`` (both inclusive when
    ``steps > 1``). Each output element is ``base ** exponent``. When
    ``steps == 1`` the single element is ``base ** start``.

    Args:
        start: First exponent; a Python number or a single-element tensor
            (converted with ``.item()``).
        end: Last exponent; a Python number or a single-element tensor.
        steps: Number of output elements; must be non-negative.
        base: Base of the exponentiation. Defaults to 10.0.
        dtype: Output dtype; ``torch.get_default_dtype()`` when None.
        layout: Forwarded to ``torch.empty``.
        device: Forwarded to ``torch.empty``.
        pin_memory: Forwarded to ``torch.empty``.

    Returns:
        A 1-D tensor of length ``steps``.

    Raises:
        AssertionError: if ``steps`` is negative.
    """
    logger.debug("GEMS LOGSPACE")
    assert steps >= 0, "number of steps must be non-negative"
    out_dtype = dtype if dtype is not None else torch.get_default_dtype()
    out = torch.empty(
        steps,
        dtype=out_dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
    if steps == 0:
        return out

    # Tensor bounds are reduced to Python scalars for host-side arithmetic
    # and for passing to the kernel as scalar arguments.
    if isinstance(start, torch.Tensor):
        start = start.item()

    if steps == 1:
        # Only the first exponent matters. Fill in place: torch.fill is
        # out-of-place and would return a freshly allocated tensor instead of
        # `out`, which may not keep the options (e.g. pin_memory) `out` was
        # created with.
        out.fill_(base**start)
        return out

    if isinstance(end, torch.Tensor):
        end = end.item()
    # Normalize to float so the kernel always receives the same scalar type,
    # presumably avoiding separate Triton specializations for int vs float
    # `start`.
    start = float(start)
    step_size = (float(end) - start) / (steps - 1)

    # NOTE(review): Triton usually passes Python float scalars as fp32, so
    # float64 or integer outputs may not be computed at full precision
    # (e.g. 10**2 could land just below 100) -- confirm before relying on
    # exact values.
    BLOCK_SIZE = 256  # according to benchmark, 256 is the best block size
    grid = (triton.cdiv(steps, BLOCK_SIZE),)
    logspace_kernel[grid](
        out,
        out.stride(0),
        start,
        step_size,
        steps,
        log2_base=math.log2(float(base)),  # coerce base to a Python float
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out