Coverage for src/flag_gems/ops/arange.py: 75%

48 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit
def arange_func(y_ptr, start, end, step, size, BLOCK_SIZE: tl.constexpr):
    # Each program instance writes one BLOCK_SIZE-wide slice of the output.
    # `end` is accepted but not read inside the kernel; the host passes the
    # already-computed element count as `size`.
    pid = tle.program_id(0)
    block_begin = pid * BLOCK_SIZE

    # Offset of this block's first element: output pointer and value.
    out_ptr = y_ptr + block_begin
    block_value_offset = block_begin * step

    lanes = tl.arange(0, BLOCK_SIZE)
    values = lanes * step + block_value_offset + start

    # Lanes whose global index falls past `size` must not be written.
    global_index = lanes + block_begin
    tl.store(out_ptr + lanes, values, mask=global_index < size)

26 

27 

def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Build a 1-D tensor ``start, start + step, ...`` that stops before ``end``.

    Args:
        start: First value of the sequence.
        end: Exclusive bound of the sequence.
        step: Spacing between consecutive values; must be nonzero.
        dtype: Element type of the result. When ``None`` it is inferred the
            way ``torch.arange`` documents it: the default floating dtype if
            any of ``start``/``end``/``step`` is a Python float, otherwise
            ``torch.int64``.
        layout: Accepted for signature compatibility; not used here.
        device: Target device. Defaults to ``runtime.device.name``.
        pin_memory: Forwarded to ``torch.empty``; ``None`` means ``False``.

    Returns:
        A tensor of shape ``(size,)`` filled by the ``arange_func`` kernel.

    Raises:
        RuntimeError: If ``step`` is zero (for ``torch.int64`` this is checked
            after truncating ``step`` to ``int``), or if the bounds are
            inconsistent with the sign of ``step``.
    """
    logger.debug("GEMS ARANGE")

    # Resolve the dtype *before* computing the size so the integer path is
    # used for integer sequences and float inputs are not silently stored
    # into an int64 tensor.
    if dtype is None:
        has_float = any(isinstance(value, float) for value in (start, end, step))
        dtype = torch.get_default_dtype() if has_float else torch.int64

    if dtype is torch.int64:
        start = int(start)
        end = int(end)
        step = int(step)

    # Checked for every dtype: previously only the int64 path validated this
    # and other dtypes hit a ZeroDivisionError below.
    if step == 0:
        raise RuntimeError("step must be nonzero")

    if dtype is torch.int64:
        # Exact integer ceil-division that works for either sign of step.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        size = math.ceil((end - start) / step)
    size = int(size)

    if size < 0:
        raise RuntimeError(
            "upper bound and lower bound inconsistent with step sign"
        )

    if pin_memory is None:
        pin_memory = False

    if device is None:
        # Note(Zhengzekang): Torch default value is CPU, but triton is target to GPU.
        device = runtime.device.name

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)

    # Nothing to fill for an empty range; skip the zero-sized launch.
    if size > 0:
        BLOCK_SIZE = 128
        grid = triton.cdiv(size, BLOCK_SIZE)
        arange_func[grid,](result, start, end, step, size, BLOCK_SIZE)
    return result

69 

70 

def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return ``arange_start(0, end, 1, ...)``: the sequence from 0 with step 1."""
    tensor_options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return arange_start(0, end, 1, **tensor_options)