Coverage for src/flag_gems/fused/weight_norm.py: 22%

124 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.ops import weight_norm_interface, weight_norm_interface_backward 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13 

# Module-level logger for the fused weight-norm ops.
logger = logging.getLogger(__name__)

15 

16 

# Forward kernel: weight normalization over every dim except the "kept" dim.
# The (contiguous) weight v is treated as a 3-D view (v_shape0, v_shape1,
# v_shape2) where v_shape1 is the kept (per-gain) dimension. For each kept
# index n the kernel computes
#   norm[n]         = sqrt(sum_{m,k} v[m,n,k]^2 + eps)
#   output[m,n,k]   = v[m,n,k] * g[n] / norm[n]
# Each program handles BLOCK_ROW_SIZE consecutive kept-dim rows; the inner
# loops sweep the flattened (v_shape0 * v_shape2) columns in BLOCK_COL_SIZE
# chunks.
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_kernel(
    output,  # *T out: normalized weight, same layout as v
    norm,    # *fp32 out: per-row norm, length v_shape1
    v,       # *T in: contiguous (v_shape0, v_shape1, v_shape2) view
    g,       # *T in: learned gain, length v_shape1
    v_shape0,
    v_shape1,
    v_shape2,
    eps,     # small constant added under the sqrt for numerical safety
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Rows (kept-dim indices) owned by this program.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]
    # First pass: accumulate the sum of squares over the flattened (m, k) axes.
    # Accumulator is float32 regardless of v's dtype.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Decompose the flat column index into outer (m) and inner (k) coords.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        # Linear offset for a contiguous (v_shape0, v_shape1, v_shape2) layout.
        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        # NOTE(review): v_value is not explicitly cast to float32 here (the
        # backward kernel does cast) — the += promotes into the fp32
        # accumulator, but the square itself runs in v's dtype; confirm this
        # precision is intended for fp16/bf16 inputs.
        v_block += v_value * v_value
    v_sum = tl.sum(v_block, axis=1) + eps
    v_norm = tl.sqrt(v_sum[:, None])
    tl.store(norm + row_offset, v_norm, mask=row_mask)

    # Second pass: scale each element by g[n] / norm[n] and write the output.
    g_value = tl.load(g + row_offset, mask=row_mask)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask)
        out = v_value * g_value / v_norm
        tl.store(output + v_offsets, out, mask=mask)

70 

71 

# Backward kernel for weight_norm_except_dim. Using the same 3-D view
# (v_shape0, v_shape1, v_shape2) with the kept dim in the middle, for each
# kept index n it computes (all arithmetic in float32):
#   vw_sum[n]     = sum_{m,k} v[m,n,k] * grad[m,n,k]
#   v_grad[m,n,k] = g[n] * ( grad[m,n,k] / (norm[n] + eps)
#                          - v[m,n,k] * vw_sum[n] / (norm[n]^3 + eps) )
#   g_grad[n]     = vw_sum[n] / (norm[n] + eps)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("weight_norm_kernel"),
    key=["v_shape0", "v_shape1", "v_shape2"],
)
@triton.jit(do_not_specialize=["eps"])
def weight_norm_except_dim_bwd_kernel(
    v_grad,   # *T out: gradient w.r.t. v, same layout as v
    g_grad,   # *T out: gradient w.r.t. g, length v_shape1
    grad,     # *T in: upstream gradient, same layout as v
    v,        # *T in: contiguous (v_shape0, v_shape1, v_shape2) view
    g,        # *T in: gain, length v_shape1
    norm,     # *fp32 in: per-row norms saved by the forward pass
    v_shape0,
    v_shape1,
    v_shape2,
    eps,      # guards the divisions against a zero norm
    BLOCK_ROW_SIZE: tl.constexpr,
    BLOCK_COL_SIZE: tl.constexpr,
):
    # Rows (kept-dim indices) owned by this program.
    tid_m = tl.arange(0, BLOCK_ROW_SIZE)[:, None]
    pid = tle.program_id(axis=0) * BLOCK_ROW_SIZE
    row_offset = pid + tid_m
    row_mask = row_offset < v_shape1

    # Per-row scalars, promoted to float32 for the gradient math.
    g_value = tl.load(g + row_offset, mask=row_mask).to(tl.float32)
    norm_value = tl.load(norm + row_offset, mask=row_mask).to(tl.float32)

    tid_n = tl.arange(0, BLOCK_COL_SIZE)[None, :]

    # First pass: accumulate the per-row dot product <v, grad>.
    v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        # Decompose the flat column index into outer (m) and inner (k) coords.
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_block += v_value * grad_value
    vw_sum = tl.sum(v_block, axis=1)[:, None]

    # Second pass: emit the elementwise gradient w.r.t. v.
    for base in range(0, v_shape0 * v_shape2, BLOCK_COL_SIZE):
        col_offset = base + tid_n
        m_idx = col_offset // v_shape2
        n_idx = row_offset
        k_idx = col_offset % v_shape2

        mask = m_idx < v_shape0 and row_mask

        v_offsets = m_idx * v_shape1 * v_shape2 + n_idx * v_shape2 + k_idx
        v_value = tl.load(v + v_offsets, mask=mask).to(tl.float32)
        grad_value = tl.load(grad + v_offsets, mask=mask).to(tl.float32)
        v_grad_value = g_value * (
            grad_value / (norm_value + eps)
            - v_value / (norm_value * norm_value * norm_value + eps) * vw_sum
        )
        tl.store(v_grad + v_offsets, v_grad_value, mask=mask)

    # Gradient w.r.t. the gain is the normalized dot product.
    g_grad_value = vw_sum / (norm_value + eps)
    tl.store(g_grad + row_offset, g_grad_value, mask=row_mask)

136 

137 

def weight_norm_except_dim(v, g, dim):
    """Forward of weight normalization over all dims except ``dim``.

    Collapses ``v`` into a 3-D view (dims before ``dim``, ``dim`` itself,
    dims after ``dim``) and launches the Triton kernel over the kept dim.

    Args:
        v: weight tensor (made contiguous here).
        g: gain tensor, one value per slice along ``dim``.
        dim: the dimension kept (not normalized over).

    Returns:
        (output, norm): the normalized weight (same shape/dtype as ``v``)
        and the per-slice norms as a float32 tensor shaped like ``g``.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM FORWARD")
    v = v.contiguous()
    # 3-D collapse: (product of leading dims, kept dim, product of trailing dims).
    outer = math.prod(v.shape[:dim])
    kept = v.shape[dim]
    inner = math.prod(v.shape[dim + 1 :])

    output = torch.empty_like(v)
    norm = torch.empty_like(g, dtype=torch.float32)

    # One program per BLOCK_ROW_SIZE rows of the kept dimension.
    def grid(meta):
        return (triton.cdiv(kept, meta["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_kernel[grid](
            output,
            norm,
            v,
            g,
            outer,
            kept,
            inner,
            eps=torch.finfo(torch.float32).tiny,
        )
    return output, norm

163 

164 

def weight_norm_except_dim_backward(grad, v, g, norm, dim):
    """Backward of weight normalization over all dims except ``dim``.

    Args:
        grad: upstream gradient, same shape as ``v`` (made contiguous here).
        v: the weight tensor given to the forward pass.
        g: the gain tensor.
        norm: per-slice norms saved by the forward pass (float32).
        dim: the dimension that was kept (not normalized over).

    Returns:
        (v_grad, g_grad): gradients w.r.t. ``v`` and ``g``.
    """
    logger.debug("GEMS WEIGHT NORM EXCEPT DIM BACKWARD")
    grad = grad.contiguous()
    # Same 3-D collapse as the forward pass: (before-dim, dim, after-dim).
    outer = math.prod(v.shape[:dim])
    kept = v.shape[dim]
    inner = math.prod(v.shape[dim + 1 :])

    v_grad = torch.empty_like(v)
    g_grad = torch.empty_like(g)

    # One program per BLOCK_ROW_SIZE rows of the kept dimension.
    def grid(meta):
        return (triton.cdiv(kept, meta["BLOCK_ROW_SIZE"]),)

    with torch_device_fn.device(v.device):
        weight_norm_except_dim_bwd_kernel[grid](
            v_grad,
            g_grad,
            grad,
            v,
            g,
            norm,
            outer,
            kept,
            inner,
            eps=torch.finfo(torch.float32).tiny,
        )
    return v_grad, g_grad

189 

190 

class WeightNorm(torch.autograd.Function):
    """Autograd Function tying together the weight-norm forward/backward kernels.

    Dispatches to the fused interface when the kept dim is the first or last
    dimension, and to the except-dim kernels otherwise.
    """

    @staticmethod
    def forward(ctx, v, g, dim=0):
        logger.debug("GEMS WEIGHT NORM")
        # Normalize negative dims into [0, ndim).
        dim = dim % v.ndim
        # The fused path only supports the first or last dimension.
        fused = dim in (0, v.ndim - 1)
        impl = weight_norm_interface if fused else weight_norm_except_dim
        output, norm = impl(v, g, dim)
        ctx.save_for_backward(v, g, norm)
        ctx.dim = dim
        ctx.can_use_fused = fused
        return output

    @staticmethod
    def backward(ctx, grad):
        logger.debug("GEMS WEIGHT NORM BACKWARD")
        v, g, norm = ctx.saved_tensors
        if ctx.can_use_fused:
            backward_impl = weight_norm_interface_backward
        else:
            backward_impl = weight_norm_except_dim_backward
        v_grad, g_grad = backward_impl(grad, v, g, norm, ctx.dim)
        # Third slot is the gradient for `dim`, which is not differentiable.
        return v_grad, g_grad, None

216 

217 

def weight_norm(v, g, dim=0):
    """Apply weight normalization to ``v`` with gain ``g`` along ``dim``.

    Thin functional entry point around the :class:`WeightNorm` autograd
    Function so gradients flow through the custom kernels.
    """
    result = WeightNorm.apply(v, g, dim)
    return result