Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/embedding.py: 0%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

@libentry()
@triton.jit
def embedding_kernel(
    out_ptr,  # pointer to the output
    in_ptr,  # pointer to the input indices
    weight_ptr,  # pointer to the embedding weight table
    N: tl.constexpr,  # embedding dimension (columns per row)
    BLOCK_SIZE: tl.constexpr,  # next power of two >= N
):
    # One program per index: copy weight[in_ptr[pid]] into out_ptr[pid].
    row = tle.program_id(0)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < N  # guard the pow2 padding beyond N

    token_idx = tl.load(in_ptr + row)
    row_vals = tl.load(weight_ptr + token_idx * N + col_offsets, col_mask, other=0.0)
    tl.store(out_ptr + row * N + col_offsets, row_vals, col_mask)

34 

35 

@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,
    indices,  # pointer to the input
    elem_cnt: tl.constexpr,  # total number of indices
    INDICE_BLOCK_SIZE: tl.constexpr,
):
    # Build a histogram of index occurrences: each valid element contributes
    # one atomic increment to its counter slot.
    block_id = tle.program_id(0)
    offs = block_id * INDICE_BLOCK_SIZE + tl.arange(0, INDICE_BLOCK_SIZE)
    valid = offs < elem_cnt

    idx = tl.load(indices + offs, mask=valid)
    tl.atomic_add(indices_freq + idx, 1, mask=valid)

52 

53 

@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the gradient input, shape (num_weights, N), pre-zeroed
    grad_out,  # pointer to the gradient output
    indices,  # pointer to the input indices
    padding_idx,  # rows matching this index get no gradient (runtime value)
    HAS_PADDING_IDX: tl.constexpr,  # compile-time switch for the padding branch
    N: tl.constexpr,  # number of columns in X
    BLOCK_SIZE: tl.constexpr,  # next power of two >= N
):
    # One program per index: scatter-add grad_out row `pid` into the weight-row
    # selected by indices[pid]. Atomics are required because the same index may
    # appear in several programs.
    pid = tle.program_id(0)
    grad_out += pid * N
    indices += pid

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)

    row_idx = tl.load(indices).to(tl.int32)
    if not HAS_PADDING_IDX:
        grad_in += row_idx * N
        embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
        # bf16 is accumulated in fp32 (the wrapper allocates grad_in as fp32
        # for bf16 inputs); dtype check is resolved at compile time.
        if tl.constexpr(embedding_grad.dtype.is_bf16()):
            embedding_grad = embedding_grad.to(tl.float32)
        tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
    else:
        # Skip the accumulation entirely for padding rows.
        if row_idx != padding_idx:
            grad_in += row_idx * N
            embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
            if tl.constexpr(embedding_grad.dtype.is_bf16()):
                embedding_grad = embedding_grad.to(tl.float32)
            tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)

86 

87 

@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,
    indice_freq,
    n_rows,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid-stride loop over weight rows: divide each row's accumulated gradient
    # by its index frequency (scale_grad_by_freq semantics). Rows seen at most
    # once are left untouched (scale stays 1.0).
    first_row = tle.program_id(0)
    stride = tle.num_programs(0)

    for row in range(first_row, n_rows, stride):
        inv_scale = 1.0
        freq = tl.load(indice_freq + row)
        if freq > 1:
            inv_scale = 1.0 / freq

        offs = tl.arange(0, BLOCK_SIZE)
        valid = offs < N
        row_grad = tl.load(grad_out + row * N + offs, mask=valid)
        tl.store(grad_out + row * N + offs, row_grad * inv_scale, mask=valid)

111 

112 

def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Embedding forward: gather ``weight[indices]``.

    Returns a tensor of shape ``(*indices.shape, weight.shape[-1])`` with
    ``weight``'s dtype. ``padding_idx`` and ``scale_grad_by_freq`` only affect
    the backward pass; sparse gradients are not supported.
    """
    logger.debug("GEMS EMBEDDING FORWARD")
    assert not sparse, "Currently do not support sparse format"

    num_indices = indices.numel()
    embedding_dim = weight.shape[-1]
    block_size = triton.next_power_of_2(embedding_dim)

    # TODO: remove contiguous enforcement
    indices = indices.contiguous()
    weight = weight.contiguous()

    output = torch.empty(
        (*indices.shape, embedding_dim), device=indices.device, dtype=weight.dtype
    )

    # One program per index row.
    with torch_device_fn.device(weight.device):
        embedding_kernel[num_indices,](output, indices, weight, embedding_dim, block_size)

    return output

130 

131 

def embedding_backward(
    grad_outputs,
    indices,
    num_weights,
    padding_idx=-1,
    scale_grad_by_freq=False,
    sparse=False,
):
    """Embedding dense backward.

    Scatter-adds each row of ``grad_outputs`` into a ``(num_weights, N)``
    gradient tensor keyed by ``indices``. When ``padding_idx`` is not None,
    rows whose index equals it receive no gradient. With
    ``scale_grad_by_freq``, each weight-row's gradient is divided by how many
    times its index occurred. bf16 gradients are accumulated in fp32 and cast
    back to bf16 on return.
    """
    logger.debug("GEMS EMBEDDING BACKWARD")
    assert not sparse, "Currently do not support sparse format"

    # Fix: the kernels index both tensors as flat, row-major buffers
    # (e.g. `grad_out += pid * N`), so enforce contiguity here just as the
    # forward pass does; this is a no-op for already-contiguous tensors.
    indices = indices.contiguous()
    grad_outputs = grad_outputs.contiguous()

    M = indices.numel()
    N = grad_outputs.shape[-1]

    # bf16 atomics would lose precision, so accumulate in fp32 for bf16 input.
    grad_inputs = torch.zeros(
        (num_weights, N),
        device=grad_outputs.device,
        dtype=(
            torch.float32
            if grad_outputs.dtype is torch.bfloat16
            else grad_outputs.dtype
        ),
    )

    if scale_grad_by_freq:
        # Count how many times each weight index occurs.
        indice_freq = torch.zeros(
            (num_weights,),
            requires_grad=False,
            device=grad_outputs.device,
            dtype=torch.int32,
        )
        INDICE_BLOCK_SIZE = 256
        indice_grid = (triton.cdiv(M, INDICE_BLOCK_SIZE),)

        with torch_device_fn.device(grad_outputs.device):
            # NOTE(review): `isCLOSE_TTXPU_O_ATOMIC_SIM` appears to be a
            # kunlunxin-specific launch flag — confirm against the backend.
            indice_freq_kernel[indice_grid](
                indice_freq,
                indices,
                M,
                INDICE_BLOCK_SIZE,
                isCLOSE_TTXPU_O_ATOMIC_SIM=True,
            )
    else:
        indice_freq = None

    BLOCK_SIZE = triton.next_power_of_2(N)

    HAS_PADDING_IDX = padding_idx is not None

    # One program per index row scatter-adds into grad_inputs.
    with torch_device_fn.device(grad_outputs.device):
        embedding_backward_kernel[M,](
            grad_inputs,
            grad_outputs,
            indices,
            padding_idx,
            HAS_PADDING_IDX,
            N,
            BLOCK_SIZE,
        )

    if scale_grad_by_freq:
        with torch_device_fn.device(grad_outputs.device):
            embedding_grad_scale_kernel[M,](
                grad_inputs, indice_freq, num_weights, N, BLOCK_SIZE
            )
    # Cast the fp32 accumulator back to bf16 when the incoming grad was bf16.
    return (
        grad_inputs.to(torch.bfloat16)
        if grad_outputs.dtype is torch.bfloat16
        else grad_inputs
    )