Coverage for src/flag_gems/runtime/backend/_ascend/ops/embedding.py: 0%

110 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-scoped logger, named "flag_gems.runtime._ascend.ops.embedding".
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

13 

14 

@libentry()
@triton.jit
def embedding_kernel(
    out_ptr,  # pointer to the output
    in_ptr,  # pointer to the input
    weight_ptr,  # pointer to the weights
    N: tl.constexpr,  # number of columns in X
    BLOCK_SIZE: tl.constexpr,
):
    # Forward lookup: one program copies one weight row into the output.
    row = tle.program_id(0)

    offsets = tl.arange(0, BLOCK_SIZE)
    valid = offsets < N  # guard the tail when N is not a power of two

    # Fetch the embedding index handled by this program, then gather the
    # corresponding weight row and write it to the matching output row.
    token = tl.load(in_ptr + row)
    row_vals = tl.load(weight_ptr + token * N + offsets, valid, other=0.0)
    tl.store(out_ptr + row * N + offsets, row_vals, valid)

35 

36 

@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,  # int32 counter buffer, one slot per embedding row
    indices,  # pointer to the input indices
    elem_cnt: tl.constexpr,  # total number of index elements
    INDICE_BLOCK_SIZE: tl.constexpr,  # indices handled per program
):
    # Histogram of index frequencies: count how many times each embedding
    # row appears in `indices` (used to scale gradients by frequency).
    pid = tle.program_id(0)
    block_start = pid * INDICE_BLOCK_SIZE

    # Scalar loop with per-element atomics so concurrent programs can
    # safely bump the same counter.
    for i in range(INDICE_BLOCK_SIZE):
        off = block_start + i
        if off < elem_cnt:  # skip the ragged tail of the last block
            idx = tl.load(indices + off)
            tl.atomic_add(indices_freq + idx, 1)

53 

54 

@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the gradient input (accumulated into; pre-zeroed by caller)
    grad_out,  # pointer to the gradient output
    indices,  # pointer to the input
    padding_idx,  # padding_idx
    HAS_PADDING_IDX: tl.constexpr,  # compile-time switch for the padding branch
    N: tl.constexpr,  # number of columns in X
    BLOCK_SIZE: tl.constexpr,
):
    # Backward scatter-add: one program takes one row of grad_out and
    # atomically adds it into grad_in at the row given by its index.
    # Atomics are required because duplicate indices map to the same row.
    pid = tle.program_id(0)
    grad_out += pid * N
    indices += pid

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)

    row_idx = tl.load(indices).to(tl.int32)
    if not HAS_PADDING_IDX:
        grad_in += row_idx * N
        embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
        if tl.constexpr(embedding_grad.dtype.is_bf16()):
            # Upcast bf16 grads before atomic_add; the caller allocates
            # grad_in as fp32 when grad_out is bf16.
            embedding_grad = embedding_grad.to(tl.float32)
        tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
    else:
        # Rows equal to padding_idx contribute no gradient.
        if row_idx != padding_idx:
            grad_in += row_idx * N
            embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
            if tl.constexpr(embedding_grad.dtype.is_bf16()):
                embedding_grad = embedding_grad.to(tl.float32)
            tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)

87 

88 

@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,
    indice_freq,
    n_rows,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    # Divide each gradient row by its index frequency (only when the index
    # occurred more than once). Programs stride over rows so any launch
    # grid covers all n_rows.
    first_row = tle.program_id(0)
    stride = tle.num_programs(0)

    # Column offsets and bounds mask are loop-invariant.
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < N

    for row in range(first_row, n_rows, stride):
        freq = tl.load(indice_freq + row)
        scale = 1.0
        if freq > 1:
            scale = 1.0 / freq

        row_base = grad_out + row * N
        grads = tl.load(row_base + cols, mask=in_bounds)
        tl.store(row_base + cols, grads * scale, mask=in_bounds)

112 

113 

class Embedding(torch.autograd.Function):
    """Triton-backed dense embedding lookup with a custom backward pass."""

    @staticmethod
    def forward(
        ctx, weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False
    ):
        """Gather rows of `weight` at `indices`.

        Args:
            weight: (num_embeddings, N) embedding table.
            indices: integer tensor of arbitrary shape.
            padding_idx: index whose rows receive zero gradient in backward
                (``None`` disables the check entirely).
            scale_grad_by_freq: scale gradients by inverse index frequency.
            sparse: unsupported; must be False.

        Returns:
            Tensor of shape (*indices.shape, N) with `weight`'s dtype.
        """
        logger.debug("GEMS_ASCEND EMBEDDING FORWARD")
        assert not sparse, "Currently do not support sparse format"

        M = math.prod(indices.shape)  # total number of lookups
        N = weight.shape[-1]  # embedding dimension

        BLOCK_SIZE = triton.next_power_of_2(N)
        # Kernels index both tensors with flat row-major arithmetic, so
        # they must be contiguous.
        indices = indices.contiguous()
        weight = weight.contiguous()
        output = torch.empty(
            (*indices.shape, N), device=indices.device, dtype=weight.dtype
        )

        with torch_device_fn.device(weight.device):
            # One program per lookup; each copies a single weight row.
            embedding_kernel[M,](output, indices, weight, N, BLOCK_SIZE)

        ctx.M = M
        ctx.N = N
        ctx.num_weights = weight.shape[0]
        ctx.padding_idx = padding_idx
        ctx.scale_grad_by_freq = scale_grad_by_freq
        ctx.sparse = sparse
        # Use save_for_backward rather than stashing the tensor on ctx:
        # it is the supported mechanism for saving input tensors.
        ctx.save_for_backward(indices)

        return output

    @staticmethod
    def backward(ctx, grad_outputs):
        """Scatter-add `grad_outputs` rows into a (num_weights, N) gradient."""
        logger.debug("GEMS_ASCEND EMBEDDING BACKWARD")
        assert not ctx.sparse, "Currently do not support sparse format"

        (indices,) = ctx.saved_tensors
        # Fix: the backward kernels read grad rows with flat pid * N
        # offsets, so grad_outputs must be contiguous (forward already
        # does this for its own inputs).
        grad_outputs = grad_outputs.contiguous()

        # bf16 atomic accumulation is done in fp32 (the kernel upcasts),
        # so allocate an fp32 buffer for bf16 inputs and cast back below.
        grad_inputs = torch.zeros(
            (ctx.num_weights, grad_outputs.shape[-1]),
            device=grad_outputs.device,
            dtype=(
                torch.float32
                if grad_outputs.dtype is torch.bfloat16
                else grad_outputs.dtype
            ),
        )

        if ctx.scale_grad_by_freq:
            # Histogram of how often each embedding row is referenced.
            indice_freq = torch.zeros(
                (ctx.num_weights,),
                requires_grad=False,
                device=grad_outputs.device,
                dtype=torch.int32,
            )
            INDICE_BLOCK_SIZE = 256
            # Plain tuple grid; no meta-parameters are needed.
            indice_grid = (triton.cdiv(ctx.M, INDICE_BLOCK_SIZE),)

            with torch_device_fn.device(grad_outputs.device):
                indice_freq_kernel[indice_grid](
                    indice_freq, indices, ctx.M, INDICE_BLOCK_SIZE
                )
        else:
            indice_freq = None

        BLOCK_SIZE = triton.next_power_of_2(ctx.N)

        HAS_PADDING_IDX = ctx.padding_idx is not None

        with torch_device_fn.device(grad_outputs.device):
            # One program per flattened index; duplicates are handled by
            # atomic adds inside the kernel.
            embedding_backward_kernel[ctx.M,](
                grad_inputs,
                grad_outputs,
                indices,
                ctx.padding_idx,
                HAS_PADDING_IDX,
                ctx.N,
                BLOCK_SIZE,
            )

        if ctx.scale_grad_by_freq:
            with torch_device_fn.device(grad_outputs.device):
                # Rescale rows hit more than once by 1/frequency.
                embedding_grad_scale_kernel[ctx.M,](
                    grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE
                )
        return (
            (
                grad_inputs.to(torch.bfloat16)
                if grad_outputs.dtype is torch.bfloat16
                else grad_inputs
            ),
            None,  # indices
            None,  # padding_idx
            None,  # scale_grad_by_freq
            None,  # sparse
        )

208 

209 

def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Functional entry point: dispatch to the autograd-aware Embedding op."""
    result = Embedding.apply(weight, indices, padding_idx, scale_grad_by_freq, sparse)
    return result