Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/weight_norm.py: 0%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12from ..ops import weight_norm_interface, weight_norm_interface_backward 

13 

# Module logger nested under the package-level "flag_gems" logger; lstrip(".")
# removes any leading "." characters from __name__ before using it as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

def heur_row_weight_norm_except_dim_kernel(args):
    """Heuristic for BLOCK_ROW_SIZE: next power of two covering ~1/12 of the rows.

    ``args["v_shape1"]`` is the length of the kept (normalized-over) dimension.
    """
    rows_per_program = triton.cdiv(args["v_shape1"], 12)
    return triton.next_power_of_2(rows_per_program)

19 

20 

def heur_col_weight_norm_except_dim_kernel(args):
    """Heuristic for BLOCK_COL_SIZE: fixed at 1 on this backend (``args`` unused)."""
    return 1

23 

24 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel"),
#     key=["v_shape0", "v_shape1", "v_shape2"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_row_weight_norm_except_dim_kernel,
        "BLOCK_COL_SIZE": heur_col_weight_norm_except_dim_kernel,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_kernel(
    output,
    norm,
    v,
    g,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Forward weight norm for an interior `dim`: v is treated as a contiguous
    # 3D tensor (v_shape0, v_shape1, v_shape2) where v_shape1 is the kept dim.
    # Each program owns BLOCK_ROW_SIZE rows (indices along v_shape1) and
    # reduces over the flattened v_shape0 * v_shape2 remaining positions.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m  # [BLOCK_ROW_SIZE, 1] row indices
    row_mask = row_offset < v_shape1

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Pass 1: accumulate sum of squares of v over every position except `dim`.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Decompose the flat reduction index into outer/inner coordinates.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        # Elementwise combination of column tail mask and row mask.
        mask = m_idx < v_shape0 and row_mask

        # Flat offset into the contiguous (v_shape0, v_shape1, v_shape2) layout.
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        v_block += v_value * v_value
    # eps keeps sqrt (and later division) away from zero for all-zero rows.
    v_sum = tl.sum(v_block, axis=1) + eps
    v_norm = tl.sqrt(v_sum[:, None])
    tl.store(norm + row_offset, v_norm, mask=row_mask)

    # Pass 2: output = v * g / ||v||, with g and the norm broadcast per row.
    g_value = tl.load(g + row_offset, mask=row_mask)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        out = v_value * g_value / v_norm
        tl.store(output + v_offsets, out, mask=mask)

84 

85 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("weight_norm_kernel"),
#     key=["v_shape0", "v_shape1", "v_shape2"],
# )
@triton.heuristics(
    values={
        "BLOCK_ROW_SIZE": heur_row_weight_norm_except_dim_kernel,
        "BLOCK_COL_SIZE": heur_col_weight_norm_except_dim_kernel,
    },
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_bwd_kernel(
    v_grad,
    g_grad,
    grad,
    v,
    g,
    norm,
    v_shape0,
    v_shape1,
    v_shape2,
    eps,
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Backward weight norm for an interior `dim`, mirroring the forward kernel's
    # (v_shape0, v_shape1, v_shape2) flattened view. Per row r (index along
    # v_shape1) it computes:
    #   vw_sum  = sum(v * grad)           over all positions except dim
    #   v_grad  = g * (grad / norm - v * vw_sum / norm^3)
    #   g_grad  = vw_sum / norm
    # eps is added to each denominator to avoid division by zero.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m  # [BLOCK_ROW_SIZE, 1] row indices
    row_mask = row_offset < v_shape1

    # Per-row scale and saved forward norm, promoted to f32 for accumulation.
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    # Pass 1: accumulate the per-row inner product sum(v * grad).
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        # Elementwise combination of column tail mask and row mask.
        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_block += v_value * grad_value
    vw_sum = tl.sum(v_block, axis=1)[:, None]

    # Pass 2: elementwise gradient w.r.t. v.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            grad_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + v_offsets, v_grad_value, mask=mask)

    # Gradient w.r.t. g: one scalar per row.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)

156 

157 

def weight_norm_except_dim(v, g, dim):
    """Weight-norm forward for an interior ``dim`` (not first or last).

    Views ``v`` as a 3D tensor ``(prod(shape[:dim]), shape[dim],
    prod(shape[dim+1:]))`` and launches ``weight_norm_except_dim_kernel``.

    Returns:
        (output, norm): ``output`` shaped like ``v`` (``v * g / ||v||`` per row
        of ``dim``), and ``norm`` shaped like ``g`` in float32.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM FORWARD")
    # The kernel computes flat offsets by hand, so both inputs must be
    # contiguous; the original only ensured this for v, not for g.
    v = v.contiguous()
    g = g.contiguous()
    output = torch.empty_like(v)
    norm = torch.empty_like(g, dtype=torch.float32)
    # Collapse v to (outer, dim, inner) extents for the kernel's 3D view.
    v_shape = [
        math.prod(v.shape[:dim]),
        v.shape[dim],
        math.prod(v.shape[dim + 1 :]),
    ]

    # One program per BLOCK_ROW_SIZE rows along the kept dimension.
    grid = lambda META: (triton.cdiv(v_shape[1], META["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_kernel[grid](
            output,
            norm,
            v,
            g,
            *v_shape,  # unpacked, matching the backward wrapper's style
            eps=torch.finfo(torch.float32).tiny,
        )
    return output, norm

183 

184 

def weight_norm_except_dim_backward(grad, v, g, norm, dim):
    """Weight-norm backward for an interior ``dim``.

    Launches ``weight_norm_except_dim_bwd_kernel`` over the same
    ``(outer, dim, inner)`` 3D view used by the forward pass.

    Returns:
        (v_grad, g_grad): gradients shaped like ``v`` and ``g`` respectively.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM BACKWARD")
    grad = grad.contiguous()
    # The kernel addresses v and g with flat linear offsets. The autograd
    # context saves the tensors passed to WeightNorm.forward (before the
    # forward wrapper's .contiguous()), so they may be non-contiguous here;
    # the original only made grad contiguous. Making v contiguous also gives
    # v_grad (empty_like) the dense layout the kernel writes into.
    v = v.contiguous()
    g = g.contiguous()
    v_grad = torch.empty_like(v)
    g_grad = torch.empty_like(g)
    # Collapse v to (outer, dim, inner) extents for the kernel's 3D view.
    v_shape = [
        math.prod(v.shape[:dim]),
        v.shape[dim],
        math.prod(v.shape[dim + 1 :]),
    ]

    # One program per BLOCK_ROW_SIZE rows along the kept dimension.
    grid = lambda META: (triton.cdiv(v_shape[1], META["BLOCK_ROW_SIZE"]),)
    with torch_device_fn.device(v.device):
        weight_norm_except_dim_bwd_kernel[grid](
            v_grad,
            g_grad,
            grad,
            v,
            g,
            norm,
            *v_shape,
            eps=torch.finfo(torch.float32).tiny,
        )
    return v_grad, g_grad

209 

210 

class WeightNorm(torch.autograd.Function):
    """Autograd op for weight normalization.

    Dispatches to the fused interface kernels when the kept dimension is the
    first or last axis, and to the generic except-dim kernels otherwise.
    """

    @staticmethod
    def forward(ctx, v, g, dim=0):
        logger.debug("GEMS WEIGHT NORM")
        dim = dim % v.ndim
        # Fused path only covers the first and last axes.
        fused = dim in (0, v.ndim - 1)
        if fused:
            output, norm = weight_norm_interface(v, g, dim)
        else:
            output, norm = weight_norm_except_dim(v, g, dim)
        ctx.save_for_backward(v, g, norm)
        ctx.dim = dim
        ctx.can_use_fused = fused
        return output

    @staticmethod
    def backward(ctx, grad):
        logger.debug("GEMS WEIGHT NORM BACKWARD")
        v, g, norm = ctx.saved_tensors
        backward_fn = (
            weight_norm_interface_backward
            if ctx.can_use_fused
            else weight_norm_except_dim_backward
        )
        v_grad, g_grad = backward_fn(grad, v, g, norm, ctx.dim)
        # No gradient for the `dim` argument.
        return v_grad, g_grad, None

236 

237 

def weight_norm(v, g, dim=0):
    """Apply weight normalization to ``v`` with scale ``g`` along ``dim``."""
    result = WeightNorm.apply(v, g, dim)
    return result