Coverage for src/flag_gems/ops/embedding_dense_backward.py: 53%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

10@triton.jit 

11def _embedding_dense_backward_kernel( 

12 grad_output_ptr, 

13 indices_ptr, 

14 grad_weight_ptr, 

15 num_weights, 

16 padding_idx, 

17 BLOCK_D: tl.constexpr, 

18 EMBED_DIM: tl.constexpr, 

19): 

20 pid_n = tl.program_id(0) 

21 pid_d = tl.program_id(1) 

22 

23 offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) 

24 mask_d = offs_d < EMBED_DIM 

25 

26 idx = tl.load(indices_ptr + pid_n) 

27 valid = (idx != padding_idx) & (idx >= 0) & (idx < num_weights) 

28 

29 go_ptrs = grad_output_ptr + pid_n * EMBED_DIM + offs_d 

30 go = tl.load(go_ptrs, mask=mask_d, other=0).to(tl.float32) 

31 

32 gw_ptrs = grad_weight_ptr + idx * EMBED_DIM + offs_d 

33 mask = mask_d & valid 

34 tl.atomic_add(gw_ptrs, go, mask=mask) 

35 

36 

37@triton.jit 

38def _embedding_dense_backward_count_kernel( 

39 indices_ptr, 

40 counts_ptr, 

41 N, 

42 num_weights, 

43 padding_idx, 

44 BLOCK_N: tl.constexpr, 

45): 

46 pid = tl.program_id(0) 

47 offs = pid * BLOCK_N + tl.arange(0, BLOCK_N) 

48 mask = offs < N 

49 idx = tl.load(indices_ptr + offs, mask=mask, other=0).to(tl.int32) 

50 valid = mask & (idx != padding_idx) & (idx >= 0) & (idx < num_weights) 

51 tl.atomic_add(counts_ptr + idx, 1, mask=valid) 

52 

53 

54@triton.jit 

55def _embedding_dense_backward_kernel_scale_by_freq( 

56 grad_output_ptr, 

57 indices_ptr, 

58 counts_ptr, 

59 grad_weight_ptr, 

60 num_weights, 

61 padding_idx, 

62 BLOCK_D: tl.constexpr, 

63 EMBED_DIM: tl.constexpr, 

64): 

65 pid_n = tl.program_id(0) 

66 pid_d = tl.program_id(1) 

67 

68 offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) 

69 mask_d = offs_d < EMBED_DIM 

70 

71 idx = tl.load(indices_ptr + pid_n).to(tl.int32) 

72 valid = (idx != padding_idx) & (idx >= 0) & (idx < num_weights) 

73 

74 go_ptrs = grad_output_ptr + pid_n * EMBED_DIM + offs_d 

75 # go = tl.load(go_ptrs, mask=mask_d, other=0.0).to(tl.float32) 

76 go = tl.load(go_ptrs, mask=mask_d, other=0.0) 

77 

78 # cnt = tl.load(counts_ptr + idx, mask=valid, other=1).to(tl.float32) 

79 cnt = tl.load(counts_ptr + idx, mask=valid, other=1) 

80 go = go / cnt 

81 

82 gw_ptrs = grad_weight_ptr + idx * EMBED_DIM + offs_d 

83 mask = mask_d & valid 

84 tl.atomic_add(gw_ptrs, go, mask=mask) 

85 

86 

87def embedding_dense_backward( 

88 grad_output: torch.Tensor, 

89 indices: torch.Tensor, 

90 num_weights: int, 

91 padding_idx: int, 

92 scale_grad_by_freq: bool, 

93): 

94 logger.debug("GEMS: embedding_dense_backward") 

95 assert indices.dtype in ( 

96 torch.int32, 

97 torch.int64, 

98 ), "Indices must be int32 or int64." 

99 assert ( 

100 grad_output.is_cuda and indices.is_cuda and grad_output.device == indices.device 

101 ), "Inputs must be CUDA tensors on the same device." 

102 

103 device = grad_output.device 

104 assert ( 

105 grad_output.dim() >= 2 

106 ), "grad_output must have embedding dimension as the last dim." 

107 

108 D = grad_output.shape[-1] 

109 go = grad_output.contiguous().view(-1, D) # (N, D) 

110 idx = indices.contiguous().view(-1) 

111 N = idx.numel() 

112 

113 assert go.shape[0] == N, "indices number must match grad_output rows." 

114 grad_weight_fp32 = torch.zeros((num_weights, D), device=device, dtype=torch.float32) 

115 

116 BLOCK_D = 128 

117 grid = (N, triton.cdiv(D, BLOCK_D)) 

118 

119 if scale_grad_by_freq: 

120 counts = torch.zeros((num_weights,), device=device, dtype=torch.int32) 

121 BLOCK_N = 512 

122 _embedding_dense_backward_count_kernel[(triton.cdiv(N, BLOCK_N),)]( 

123 idx, 

124 counts, 

125 N, 

126 num_weights, 

127 padding_idx if padding_idx is not None else -1, 

128 BLOCK_N=BLOCK_N, 

129 ) 

130 

131 _embedding_dense_backward_kernel_scale_by_freq[grid]( 

132 go, 

133 idx, 

134 counts, 

135 grad_weight_fp32, 

136 num_weights, 

137 padding_idx if padding_idx is not None else -1, 

138 BLOCK_D=BLOCK_D, 

139 EMBED_DIM=D, 

140 ) 

141 else: 

142 _embedding_dense_backward_kernel[grid]( 

143 go, 

144 idx, 

145 grad_weight_fp32, 

146 num_weights, 

147 padding_idx if padding_idx is not None else -1, 

148 BLOCK_D=BLOCK_D, 

149 EMBED_DIM=D, 

150 ) 

151 

152 if grad_output.dtype != torch.float32: 

153 return grad_weight_fp32.to(grad_output.dtype) 

154 return grad_weight_fp32