Coverage for src/flag_gems/runtime/backend/_mthreads/ops/gather.py: 0%

71 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.ops.gather import gather as default_gather 

10from flag_gems.ops.gather import gather_backward as default_gather_backward 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import libentry 

13 

# Module logger; the name is rebuilt explicitly so log records always read
# "flag_gems.runtime.backend._mthreads.ops.gather" regardless of how the
# module path resolves at import time.
logger = logging.getLogger(
    f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}'
)

# Floating-point dtypes accepted by the custom Triton gather fast path;
# any other dtype falls back to the default flag_gems gather.
_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float32}

19 

20 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("gather"))
@triton.jit
def _gather_lastdim_kernel(
    inp_ptr,
    index_ptr,
    out_ptr,
    stride_inp_row,
    stride_index_row,
    stride_out_row,
    dim_stride,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Gather along the last dimension of tensors flattened to 2D.

    For each position in the (M, N) index matrix this computes
        out[r, c] = inp[r * stride_inp_row + index[r, c] * dim_stride]
    i.e. a row-wise gather where ``index`` selects elements within a row.

    The launch grid is 2D: axis 0 tiles rows in BLOCK_M chunks, axis 1
    tiles columns in BLOCK_N chunks (BLOCK_M/BLOCK_N come from the
    "gather" heuristic config). M is the number of rows, N the number of
    columns of the index/output matrices.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Per-tile row/column coordinates; broadcast to a BLOCK_M x BLOCK_N grid.
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    # Widen to int64 — presumably so address arithmetic cannot overflow on
    # large tensors (rows * stride can exceed int32 range).
    rows = rows.to(tl.int64)
    cols = cols.to(tl.int64)
    # Guard the partial tiles on the right/bottom edges of the matrix.
    mask = (rows < M) & (cols < N)

    # Row base offsets for the three (possibly differently strided) matrices.
    row_inp = rows * stride_inp_row
    row_idx = rows * stride_index_row
    row_out = rows * stride_out_row

    # Load the gather indices (0 where masked), then fetch the selected
    # elements: index value scaled by the element stride of the gather dim.
    idx = tl.load(index_ptr + row_idx + cols, mask=mask, other=0).to(tl.int64)
    gather_ptr = inp_ptr + row_inp + idx * dim_stride
    values = tl.load(gather_ptr, mask=mask, other=0)
    tl.store(out_ptr + row_out + cols, values, mask=mask)

53 

54 

55def _normalize_dim(dim: int, ndim: int) -> int: 

56 return dim if dim >= 0 else dim + ndim 

57 

58 

59def _use_triton_kernel( 

60 inp: torch.Tensor, 

61 dim: int, 

62 index: torch.Tensor, 

63 out: Optional[torch.Tensor], 

64) -> bool: 

65 if inp.device.type != "musa" or index.device != inp.device: 

66 return False 

67 if inp.dtype not in _SUPPORTED_DTYPES or index.dtype != torch.long: 

68 return False 

69 

70 dim = _normalize_dim(dim, inp.ndim) 

71 if dim != inp.ndim - 1: 

72 return False 

73 

74 if not inp.is_contiguous() or not index.is_contiguous(): 

75 return False 

76 if out is not None: 

77 if ( 

78 out.device != inp.device 

79 or out.dtype != inp.dtype 

80 or not out.is_contiguous() 

81 ): 

82 return False 

83 

84 if index.shape[:-1] != inp.shape[:-1]: 

85 return False 

86 

87 return True 

88 

89 

def _launch_triton(
    inp: torch.Tensor,
    index: torch.Tensor,
    out: torch.Tensor,
    dim_stride: int,
) -> torch.Tensor:
    """Launch the last-dim gather kernel and return ``out``.

    All tensors are flattened to 2D (rows = product of leading dims,
    cols = last dim); ``dim_stride`` is the element stride of ``inp``
    along the gathered (last) dimension. Assumes eligibility was already
    established by ``_use_triton_kernel`` (contiguous tensors, matching
    leading dims, compatible ``out``).
    """
    # Collapse leading dims; out/index share the same 2D shape.
    inp_2d = inp.view(-1, inp.shape[-1])
    index_2d = index.view(-1, index.shape[-1])
    out_2d = out.view(-1, index.shape[-1])

    M, N = index_2d.shape
    # Row strides are passed explicitly so the kernel can address each
    # matrix independently (inp's row length may differ from index/out's).
    stride_inp_row = inp_2d.stride(0)
    stride_index_row = index_2d.stride(0)
    stride_out_row = out_2d.stride(0)

    # 2D grid: BLOCK_M/BLOCK_N come from the "gather" heuristic config
    # resolved at kernel-launch time.
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(N, meta["BLOCK_N"]),
    )
    # Ensure the launch happens on the output tensor's device.
    with torch_device_fn.device(out.device):
        _gather_lastdim_kernel[grid](
            inp_2d,
            index_2d,
            out_2d,
            stride_inp_row,
            stride_index_row,
            stride_out_row,
            dim_stride,
            M,
            N,
        )
    return out

122 

123 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values from ``inp`` along ``dim`` at positions in ``index``.

    Dispatches to the custom last-dim Triton kernel when the inputs
    qualify; otherwise delegates to the default flag_gems gather.
    """
    logger.debug("GEMS_MTHREADS GATHER")
    if _use_triton_kernel(inp, dim, index, out):
        if out is None:
            # Output mirrors the index shape, with the input's dtype.
            out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
        # Element stride of the gathered dimension, resolved to a
        # non-negative dim first.
        stride_along_dim = inp.stride(_normalize_dim(dim, inp.ndim))
        return _launch_triton(inp, index, out, stride_along_dim)
    return default_gather(inp, dim, index, out, sparse_grad)

134 

135 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward pass for gather; delegates to the default implementation."""
    logger.debug("GEMS_MTHREADS GATHER BACKWARD")
    result = default_gather_backward(grad, self, dim, index, sparse_grad)
    return result