Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/arange.py: 0%

50 statements  

Report generated by coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger: a child of the shared "flag_gems" logger, named after this module.
_flag_gems_root_logger = logging.getLogger("flag_gems")
logger = _flag_gems_root_logger.getChild(__name__.lstrip("."))

13 

14 

@libentry()
@triton.jit
def arange_func(
    y_ptr,  # output buffer pointer (element type follows the allocated tensor)
    start,  # first value of the sequence
    end,  # NOTE(review): unused inside the kernel body — size already encodes the bound
    step,  # increment between consecutive values
    size,  # total number of elements to write across all program instances
    BLOCK_SIZE: tl.constexpr,  # elements written per program instance
    buffer_size_limit: tl.constexpr,  # NOTE(review): unused here; presumably a backend launch hint — confirm
):
    """Write the arithmetic sequence start, start+step, ... into y_ptr.

    One-dimensional launch: program `pid` fills the tile of BLOCK_SIZE
    consecutive elements starting at index pid * BLOCK_SIZE; lanes whose
    global index is >= size are masked off, so the last tile may be partial.
    """
    pid = tle.program_id(0)
    # Advance the output pointer to this program's tile.
    y_ptr += pid * BLOCK_SIZE
    # Value offset contributed by all tiles before this one (pid*BLOCK_SIZE steps).
    step_offset = pid * BLOCK_SIZE * step

    cols = tl.arange(0, BLOCK_SIZE)
    # Per-lane sequence value: global_index * step + start.
    arange_val = cols * step + step_offset + start
    # `mask` is the global element index of each lane, not a boolean;
    # the boolean guard is formed below as `mask < size`.
    mask = cols + pid * BLOCK_SIZE
    tl.store(y_ptr + cols, arange_val, mask=mask < size)

34 

35 

def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """GEMS implementation of ``torch.arange(start, end, step)``.

    Computes the element count on the host, then launches ``arange_func``
    to fill the output tensor on the device.

    Args:
        start, end, step: sequence bounds; ``step`` must be nonzero.
        dtype: output dtype; defaults to ``torch.int64`` when ``None``.
        layout: accepted for signature compatibility; unused here.
        device: target device; defaults to the runtime's device
            (torch's default would be CPU, but triton targets the device).
        pin_memory: defaults to ``False`` when ``None``.

    Returns:
        A 1-D tensor of ``size`` elements holding the arithmetic sequence.

    Raises:
        RuntimeError: if ``step`` is zero.
    """
    logger.debug("GEMS ARANGE")

    # --- element count -----------------------------------------------------
    if dtype is torch.int64:
        # Integer path: do the sizing in exact integer arithmetic.
        start = int(start)
        end = int(end)
        step = int(step)
        if step == 0:
            raise RuntimeError("step must be nonzero")
        # sign(step); (end - start + step - sgn) // step is ceil-division
        # toward the direction of travel, matching torch.arange's count.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        # Float/other path. (The previous revision re-checked
        # `dtype is torch.int64` here, which was unreachable; removed.)
        if step == 0:
            raise RuntimeError("step must be nonzero")
        size = int(math.ceil((end - start) / step))

    # --- resolve defaults BEFORE sizing the launch -------------------------
    # dtype must be resolved first: the per-block buffer bound below depends
    # on the element size of the tensor actually allocated. The previous
    # revision measured element_size with the unresolved dtype, so a
    # defaulted int64 output (8 bytes) was sized with float32's 4 bytes,
    # doubling the intended 2048*64-byte buffer budget.
    if dtype is None:
        dtype = torch.int64
    if pin_memory is None:
        pin_memory = False
    if device is None:
        device = runtime.device.name

    # --- launch configuration ----------------------------------------------
    cluster_num = 12
    element_size = torch.tensor([], dtype=dtype).element_size()
    BLOCK_SIZE = min(
        triton.next_power_of_2(triton.cdiv(size, cluster_num)),
        # Cap the tile so one block stays within the 2048*64-byte buffer.
        int(2048 * 64 / element_size),
    )
    grid = triton.cdiv(size, BLOCK_SIZE)

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    arange_func[grid,](
        result, start, end, step, size, BLOCK_SIZE, buffer_size_limit=2048
    )
    return result

84 

85 

def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """``torch.arange(end)`` overload: sequence 0, 1, ..., end-1.

    Delegates to :func:`arange_start` with an implicit start of 0 and
    step of 1, forwarding every keyword argument unchanged.
    """
    forwarded = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return arange_start(0, end, 1, **forwarded)