Coverage for src/flag_gems/runtime/backend/_metax/ops/arange.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger("flag_gems." + __name__) 

12 

13 

@libentry()
@triton.jit
def arange_func(y_ptr, start, end, step, size, BLOCK_SIZE: tl.constexpr):
    """Triton kernel: each program instance writes one BLOCK_SIZE-wide slice
    of the arithmetic sequence start, start+step, ... into y_ptr.

    `end` is unused inside the kernel; the caller precomputes `size` (the
    number of elements) and that alone bounds the writes.
    """
    pid = tle.program_id(0)
    # Advance the output pointer to this program's contiguous slice.
    y_ptr += pid * BLOCK_SIZE
    # Value offset contributed by all elements in preceding blocks.
    step_offset = pid * BLOCK_SIZE * step

    cols = tl.arange(0, BLOCK_SIZE)
    arange_val = cols * step + step_offset + start
    # `mask` is the global element index of each lane; lanes whose index is
    # >= size are masked off so the final partial block does not write out
    # of bounds.
    mask = cols + pid * BLOCK_SIZE
    tl.store(y_ptr + cols, arange_val, mask=mask < size)

25 

26 

def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """torch.arange replacement for the Metax backend, computed by a Triton kernel.

    Returns a 1-D tensor of values [start, start+step, ...) that stops before
    `end`. Mirrors torch.arange's signature; `layout` is accepted but unused.

    Raises:
        ValueError: if step is zero.
    """
    logger.debug("METAX GEMS ARANGE")
    if step == 0:
        # torch.arange also rejects a zero step; fail with a clear message
        # instead of an opaque ZeroDivisionError from the size computation.
        raise ValueError("step must be nonzero")

    if dtype is torch.int64:
        # Exact integer ceil-division, valid for either sign of step:
        # sgn is +1/-1 so the adjustment rounds away from zero.
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        size = math.ceil((end - start) / step)
    # An inverted range (e.g. end <= start with step > 0) yields a negative
    # size; torch.arange returns an empty tensor in that case, so clamp.
    size = max(int(size), 0)

    BLOCK_SIZE = 1024
    grid = triton.cdiv(size, BLOCK_SIZE)

    if dtype is None:
        dtype = torch.int64

    if pin_memory is None:
        pin_memory = False

    if device is None:
        # Bug fix: the original executed `device = device.name` here, which
        # always raised AttributeError because `device` is None on this path.
        # Torch's default device is CPU, but this Triton kernel targets the
        # GPU, so default to the CUDA device (the Metax runtime exposes its
        # accelerator through torch's cuda interface — confirm against the
        # backend's device registration).
        device = torch.device("cuda")

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    arange_func[(grid,)](result, start, end, step, size, BLOCK_SIZE)
    return result

54 

55 

def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Single-argument form of arange: values 0, 1, ..., end-1.

    Delegates to arange_start with start=0 and step=1, forwarding all
    keyword options unchanged.
    """
    options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return arange_start(0, end, 1, **options)