Coverage for src/flag_gems/runtime/backend/_cambricon/fused/weight_norm.py: 0%

124 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

12from ..ops import weight_norm_interface, weight_norm_interface_backward 

13 

14logger = logging.getLogger(__name__) 

15 

16MAX_N = 31744 

17 

18 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_kernel(
    output,  # out: normalized weights, same layout as v
    norm,  # out: one norm value per index along the middle axis
    v,  # in: weight tensor viewed as (v_shape0, v_shape1, v_shape2), contiguous
    g,  # in: gain, one value per index along the middle axis
    v_shape0,
    v_shape1,
    v_shape2,
    eps,  # tiny constant added under the sqrt to avoid a zero norm
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Forward weight norm reduced over every axis except the middle one.

    Treating v as a 3-D tensor of shape (v_shape0, v_shape1, v_shape2),
    for each index n along axis 1 this computes
        norm[n]         = sqrt(sum(v[:, n, :] ** 2) + eps)
        output[:, n, :] = v[:, n, :] * g[n] / norm[n]

    Each program handles a block of BLOCK_ROW_SIZE rows (axis-1 indices);
    the flattened v_shape0 * v_shape2 reduction axis is walked in
    BLOCK_COL_SIZE-wide chunks, twice: once to reduce, once to scale.
    """
    # Rows (axis-1 indices) covered by this program.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # Pass 1: accumulate per-row sums of squares in float32.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Decompose the flattened column index into the outer (m) and
        # inner (k) coordinates of the collapsed reduction axes.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        # Linear offset into the contiguous (m, n, k) layout.
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        v_block += v_value * v_value
    # eps goes under the sqrt so norm is strictly positive.
    v_sum = tl.sum(v_block, axis=1) + eps
    v_norm = tl.sqrt(v_sum[:, None])
    tl.store(norm + row_offset, v_norm, mask=row_mask)

    # Pass 2: re-read v and scale each element by g / norm.
    g_value = tl.load(g + row_offset, mask=row_mask)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        out = v_value * g_value / v_norm
        tl.store(output + v_offsets, out, mask=mask)

72 

73 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_bwd_kernel(
    v_grad,  # out: gradient w.r.t. v, same layout as v
    g_grad,  # out: gradient w.r.t. g, one value per middle-axis index
    grad,  # in: upstream gradient, same layout as v
    v,  # in: weight tensor viewed as (v_shape0, v_shape1, v_shape2)
    g,  # in: gain used in the forward pass
    norm,  # in: per-row norms saved by the forward kernel
    v_shape0,
    v_shape1,
    v_shape2,
    eps,  # guards the divisions by norm and norm**3
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    """Backward of weight norm reduced over every axis except the middle one.

    With vw_sum[n] = sum(v[:, n, :] * grad[:, n, :]), this computes
        v_grad[:, n, :] = g[n] * (grad[:, n, :] / (norm[n] + eps)
                          - v[:, n, :] * vw_sum[n] / (norm[n]**3 + eps))
        g_grad[n]       = vw_sum[n] / (norm[n] + eps)

    Same blocking scheme as the forward kernel: one program per block of
    BLOCK_ROW_SIZE axis-1 rows, reduction axis walked in BLOCK_COL_SIZE
    chunks, all math accumulated in float32.
    """
    # Rows (axis-1 indices) covered by this program.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tl.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1

    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # Pass 1: per-row dot product of v with the upstream gradient.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        # Linear offset into the contiguous (m, n, k) layout.
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_block += v_value * grad_value
    vw_sum = tl.sum(v_block, axis=1)[:, None]

    # Pass 2: re-read v and grad and emit the per-element v gradient.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            grad_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + v_offsets, v_grad_value, mask=mask)

    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)

138 

139 

def weight_norm_except_dim(v, g, dim):
    """Launch the forward weight-norm kernel for an interior ``dim``.

    Collapses ``v`` to a 3-D view (prefix, v.shape[dim], suffix) and lets the
    kernel reduce over the prefix and suffix axes.

    Returns:
        (output, norm): ``output`` has v's shape/dtype; ``norm`` is float32
        with g's shape.
    """
    logger.debug("GEMS_CAMBRICON WEIGHT NORM EXCEPT DIM FORWARD")
    v = v.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g, dtype=torch.float32)

    # Sizes of the collapsed 3-D view the kernel operates on.
    size_before = math.prod(v.shape[:dim])
    size_dim = v.shape[dim]
    size_after = math.prod(v.shape[dim + 1 :])

    def grid(meta):
        # One program per block of rows along the normalized axis.
        return (triton.cdiv(size_dim, meta["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_kernel[grid](
            output,
            norm,
            v,
            g,
            size_before,
            size_dim,
            size_after,
            eps=torch.finfo(torch.float32).tiny,
        )
    return output, norm

165 

166 

def weight_norm_except_dim_backward(grad, v, g, norm, dim):
    """Launch the backward weight-norm kernel for an interior ``dim``.

    Uses the same collapsed 3-D view of ``v`` as the forward launcher and the
    ``norm`` tensor saved from the forward pass.

    Returns:
        (v_grad, g_grad): gradients w.r.t. ``v`` and ``g``.
    """
    logger.debug("GEMS_CAMBRICON NORM BACKWARD")
    grad = grad.contiguous()
    v_grad = torch.empty_like(v)
    g_grad = torch.empty_like(g)

    # Sizes of the collapsed 3-D view the kernel operates on.
    size_before = math.prod(v.shape[:dim])
    size_dim = v.shape[dim]
    size_after = math.prod(v.shape[dim + 1 :])

    def grid(meta):
        # One program per block of rows along the normalized axis.
        return (triton.cdiv(size_dim, meta["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_bwd_kernel[grid](
            v_grad,
            g_grad,
            grad,
            v,
            g,
            norm,
            size_before,
            size_dim,
            size_after,
            eps=torch.finfo(torch.float32).tiny,
        )
    return v_grad, g_grad

191 

192 

class WeightNorm(torch.autograd.Function):
    """Autograd wrapper for weight normalization.

    Dispatches to the fused kernels (``weight_norm_interface``) when ``dim``
    is the first or last axis, and to the except-dim kernels otherwise.
    """

    @staticmethod
    def forward(ctx, v, g, dim=0):
        logger.debug("GEMS_CAMBRICON WEIGHT NORM")
        # Normalize negative dims to the canonical non-negative index.
        dim = dim % v.ndim
        can_use_fused = dim in (0, v.ndim - 1)
        if can_use_fused:
            output, norm = weight_norm_interface(v, g, dim)
        else:
            output, norm = weight_norm_except_dim(v, g, dim)
        # Save everything backward needs, including the dispatch decision.
        ctx.save_for_backward(v, g, norm)
        ctx.dim = dim
        ctx.can_use_fused = can_use_fused
        return output

    @staticmethod
    def backward(ctx, grad):
        logger.debug("GEMS_CAMBRICON WEIGHT NORM BACKWARD")
        v, g, norm = ctx.saved_tensors
        # Mirror the forward-pass dispatch.
        backward_fn = (
            weight_norm_interface_backward
            if ctx.can_use_fused
            else weight_norm_except_dim_backward
        )
        v_grad, g_grad = backward_fn(grad, v, g, norm, ctx.dim)
        # No gradient for the ``dim`` argument.
        return v_grad, g_grad, None

218 

219 

def weight_norm(v, g, dim=0):
    """Weight normalization entry point: returns v scaled by g over its norm.

    The norm is taken over every axis of ``v`` except ``dim`` (see the
    kernels above); the fused-interface path for first/last ``dim`` is
    presumed to match — defined in ``..ops``.
    """
    return WeightNorm.apply(v, g, dim)