Coverage for src/flag_gems/runtime/backend/_cambricon/ops/embedding.py: 0%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11 

12 

@libentry()
@triton.jit
def indice_freq_kernel(
    indices_freq,
    indices,  # pointer to the input
    elem_cnt: tl.constexpr,  # number of columns in X
    INDICE_BLOCK_SIZE: tl.constexpr,
):
    """Histogram kernel: count occurrences of each index value.

    Each program loads one tile of `indices` and atomically bumps the
    matching slots of `indices_freq` (programs may hit the same slot,
    hence the atomic).
    """
    tile_start = tl.program_id(0) * INDICE_BLOCK_SIZE
    tile_offsets = tile_start + tl.arange(0, INDICE_BLOCK_SIZE)
    in_bounds = tile_offsets < elem_cnt

    loaded_indices = tl.load(indices + tile_offsets, mask=in_bounds)
    tl.atomic_add(indices_freq + loaded_indices, 1, mask=in_bounds)

29 

30 

@libentry()
@triton.jit(do_not_specialize=["padding_idx"])
def embedding_backward_kernel(
    grad_in,  # pointer to the gradient input (the per-row accumulator)
    grad_out,  # pointer to the gradient output (one row per lookup)
    indices,  # pointer to the input indices (one per lookup/program)
    padding_idx,  # index whose gradient is dropped (only read when HAS_PADDING_IDX)
    HAS_PADDING_IDX: tl.constexpr,  # compile-time switch for the padding branch
    N: tl.constexpr,  # number of columns in X (embedding dim)
    BLOCK_SIZE: tl.constexpr,  # power-of-two tile >= N; excess lanes masked off
):
    # One program per lookup: program `pid` scatters row `pid` of grad_out
    # into row `indices[pid]` of grad_in. Duplicate indices across programs
    # are resolved with atomic_add.
    pid = tl.program_id(0)
    grad_out += pid * N
    indices += pid

    mask = tl.arange(0, BLOCK_SIZE) < N
    cols = tl.arange(0, BLOCK_SIZE)

    row_idx = tl.load(indices).to(tl.int32)
    if not HAS_PADDING_IDX:
        grad_in += row_idx * N
        embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
        # bf16 values are upcast before the atomic add; the caller allocates a
        # float32 accumulator for bf16 inputs (presumably because bf16
        # atomic_add is unsupported on this backend — NOTE(review): confirm).
        if tl.constexpr(embedding_grad.dtype.is_bf16()):
            embedding_grad = embedding_grad.to(tl.float32)
        tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)
    else:
        # Same scatter, but rows that hit padding_idx contribute nothing.
        if row_idx != padding_idx:
            grad_in += row_idx * N
            embedding_grad = tl.load(grad_out + cols, mask, other=0.0)
            if tl.constexpr(embedding_grad.dtype.is_bf16()):
                embedding_grad = embedding_grad.to(tl.float32)
            tl.atomic_add(grad_in + cols, embedding_grad, mask=mask)

63 

64 

@libentry()
@triton.jit(do_not_specialize=["n_rows"])
def embedding_grad_scale_kernel(
    grad_out,
    indice_freq,
    n_rows,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Rescale gradient rows in place by 1/frequency.

    Rows whose index appeared more than once are divided by their occurrence
    count; rows seen at most once are left untouched (scale 1.0). Programs
    stride over the rows so any grid size covers all `n_rows`.
    """
    first_row = tl.program_id(0)
    stride = tl.num_programs(0)

    for row in range(first_row, n_rows, stride):
        freq = tl.load(indice_freq + row)
        scale = 1.0
        if freq > 1:
            scale = 1.0 / freq

        col_offsets = tl.arange(0, BLOCK_SIZE)
        col_mask = tl.arange(0, BLOCK_SIZE) < N
        row_base = grad_out + row * N
        row_grad = tl.load(row_base + col_offsets, mask=col_mask)
        tl.store(row_base + col_offsets, row_grad * scale, mask=col_mask)

88 

89 

def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Embedding lookup: gather rows of `weight` selected by `indices`.

    Args:
        weight: 2D embedding table, one embedding vector per row.
        indices: integer tensor (any shape) selecting rows of `weight`.
        padding_idx: accepted for API compatibility; the forward pass is a
            plain gather, so padding only matters in the backward pass.
        scale_grad_by_freq: accepted for API compatibility (backward-only flag).
        sparse: must be False — sparse gradients are not supported.

    Returns:
        Tensor of shape ``indices.shape + (embedding_dim,)``.
    """
    logger.debug("GEMS_CAMBRICON EMBEDDING FORWARD")
    assert not sparse, "Currently do not support sparse format"

    # Kernels downstream assume dense, contiguous layouts.
    indices = indices.contiguous()
    weight = weight.contiguous()

    from .index_select import index_select

    # Forward embedding is exactly an index_select along dim 0 of the table.
    # (The original code also normalized negative padding_idx to None here,
    # but the result was never used — dead code, removed.)
    output = index_select(weight, 0, indices.flatten())
    output = output.reshape(indices.shape + (-1,))

    return output

106 

107 

def embedding_backward(
    grad_outputs,
    indices,
    num_weights,
    padding_idx=-1,
    scale_grad_by_freq=False,
    sparse=False,
):
    """Dense backward of embedding: scatter-add grad rows into the weight grad.

    Args:
        grad_outputs: upstream gradient; last dim is the embedding dim.
        indices: the indices used in the forward lookup (one per grad row).
        num_weights: number of rows in the embedding table.
        padding_idx: rows equal to this index receive no gradient; pass None
            to disable the padding branch entirely.
        scale_grad_by_freq: if True, each gradient row is divided by how many
            times its index occurred.
        sparse: must be False — sparse gradients are not supported.

    Returns:
        Gradient w.r.t. the weight table, shape (num_weights, embedding_dim),
        with the same dtype as `grad_outputs`.
    """
    logger.debug("GEMS_CAMBRICON EMBEDDING BACKWARD")
    assert not sparse, "Currently do not support sparse format"

    # Fix: the kernels below address grad_outputs / indices with flat pointer
    # arithmetic (pid * N, indices + pid), which silently reads garbage for
    # non-contiguous inputs. The forward contiguous-izes its inputs but the
    # backward did not — autograd may hand us non-contiguous tensors.
    grad_outputs = grad_outputs.contiguous()
    indices = indices.contiguous()

    M = indices.numel()
    N = grad_outputs.shape[-1]

    # The backward kernel upcasts bf16 to fp32 before atomic_add, so for bf16
    # inputs the accumulator is allocated in fp32 and cast back at the end.
    grad_inputs = torch.zeros(
        (num_weights, N),
        device=grad_outputs.device,
        dtype=torch.float32
        if grad_outputs.dtype is torch.bfloat16
        else grad_outputs.dtype,
    )

    if scale_grad_by_freq:
        # Histogram of index occurrences, used below to average duplicates.
        indice_freq = torch.zeros(
            (num_weights,),
            requires_grad=False,
            device=grad_outputs.device,
            dtype=torch.int32,
        )
        INDICE_BLOCK_SIZE = 256
        indice_grid = lambda meta: (triton.cdiv(M, INDICE_BLOCK_SIZE),)

        with torch_device_fn.device(grad_outputs.device):
            indice_freq_kernel[indice_grid](indice_freq, indices, M, INDICE_BLOCK_SIZE)
    else:
        indice_freq = None

    BLOCK_SIZE = triton.next_power_of_2(N)
    HAS_PADDING_IDX = padding_idx is not None

    # One program per lookup row: scatter-add grad_outputs[pid] into
    # grad_inputs[indices[pid]].
    with torch_device_fn.device(grad_outputs.device):
        embedding_backward_kernel[M,](
            grad_inputs,
            grad_outputs,
            indices,
            padding_idx,
            HAS_PADDING_IDX,
            N,
            BLOCK_SIZE,
        )

    if scale_grad_by_freq:
        with torch_device_fn.device(grad_outputs.device):
            embedding_grad_scale_kernel[M,](
                grad_inputs, indice_freq, num_weights, N, BLOCK_SIZE
            )
    return (
        grad_inputs.to(torch.bfloat16)
        if grad_outputs.dtype is torch.bfloat16
        else grad_inputs
    )