Coverage for src/flag_gems/ops/nll_loss_nd.py: 12%

138 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.nllloss import nll_loss_backward as nll_loss_2d_backward 

8from flag_gems.ops.nllloss import nll_loss_forward as nll_loss_2d 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@triton.jit
def nll_loss_nd_forward_kernel(
    input_ptr,  # (N, C, S) log-probabilities, addressed via the stride_in_* args
    target_ptr,  # (N, S) integer class indices
    weight_ptr,  # (C,) per-class weights; only read when HAS_WEIGHT is True
    out_ptr,  # (N, S) losses for REDUCTION==0, else a single reduced scalar
    scratch_ptr,  # float32 scratch: [loss_sum, weight_sum, finished_program_count]
    C,  # number of classes
    S,  # flattened spatial extent per (n, c)
    stride_in_n,
    stride_in_c,
    stride_in_s,
    stride_tgt_n,
    stride_tgt_s,
    ignore_index,
    HAS_WEIGHT: tl.constexpr,
    REDUCTION: tl.constexpr,  # 0: none, 1: mean, 2: sum
    BLOCK_S: tl.constexpr = 1024,
):
    """Forward NLL loss over an (N, C, S) view of the input.

    Grid is (cdiv(S, BLOCK_S), N): program (pid_s, pid_n) handles one
    BLOCK_S-wide slice of sample pid_n. For 'mean'/'sum' the reduction is
    finished in-kernel via atomics on scratch_ptr: every program adds its
    partial sums, then the last program to bump the counter at
    scratch_ptr + 2 writes the final value to out_ptr.
    """
    pid_s = tl.program_id(0)
    pid_n = tl.program_id(1)

    s_offsets = pid_s * BLOCK_S + tl.arange(0, BLOCK_S)
    mask_s = s_offsets < S

    # Out-of-range lanes load ignore_index so they are rejected below.
    tgt_offsets = pid_n * stride_tgt_n + s_offsets * stride_tgt_s
    t = tl.load(target_ptr + tgt_offsets, mask=mask_s, other=ignore_index).to(tl.int32)

    # A lane contributes only if in-bounds, not ignored, and a legal class id.
    valid = mask_s & (t != ignore_index) & (t >= 0) & (t < C)

    # Gather the log-probability of each lane's target class.
    in_offsets = pid_n * stride_in_n + t * stride_in_c + s_offsets * stride_in_s
    val = tl.load(input_ptr + in_offsets, mask=valid, other=0.0).to(tl.float32)

    if HAS_WEIGHT:
        w = tl.load(weight_ptr + t, mask=valid, other=0.0).to(tl.float32)
        loss_val = tl.where(valid, -val * w, 0.0)
    else:
        # Unweighted: each valid lane counts 1 toward the mean denominator.
        w = tl.where(valid, 1.0, 0.0).to(tl.float32)
        loss_val = tl.where(valid, -val, 0.0)

    # none
    if REDUCTION == 0:
        # Per-element losses; out is a dense (N, S) buffer, so index with S
        # directly rather than the input strides.
        out_offset = pid_n * S + s_offsets
        tl.store(
            out_ptr + out_offset, loss_val.to(out_ptr.dtype.element_ty), mask=mask_s
        )
    else:
        block_loss_sum = tl.sum(loss_val, axis=0)
        # mean
        if REDUCTION == 1:
            block_weight_sum = tl.sum(w, axis=0)

            tl.atomic_add(scratch_ptr, block_loss_sum, sem="relaxed")
            tl.atomic_add(scratch_ptr + 1, block_weight_sum, sem="relaxed")

            # Release ordering on the counter publishes this program's partial
            # sums before the count becomes visible to the finisher.
            old_cnt = tl.atomic_add(scratch_ptr + 2, 1.0, sem="release")

            total_programs = tl.num_programs(0) * tl.num_programs(1)

            # Last program to arrive computes the weighted mean.
            # NOTE(review): the loads below are plain tl.load (no acquire);
            # assumes atomic visibility suffices on target hardware — confirm.
            if old_cnt == total_programs - 1.0:
                total_loss = tl.load(scratch_ptr)
                total_weight = tl.load(scratch_ptr + 1)
                final_val = tl.where(
                    total_weight == 0.0, 0.0, total_loss / total_weight
                )
                tl.store(out_ptr, final_val.to(out_ptr.dtype.element_ty))
        # Sum
        else:
            tl.atomic_add(scratch_ptr, block_loss_sum, sem="relaxed")

            old_cnt = tl.atomic_add(scratch_ptr + 2, 1.0, sem="release")
            total_programs = tl.num_programs(0) * tl.num_programs(1)

            # Last program writes the accumulated sum to the scalar output.
            if old_cnt == total_programs - 1.0:
                total_loss = tl.load(scratch_ptr)
                tl.store(out_ptr, total_loss.to(out_ptr.dtype.element_ty))

92 

93 

@libentry()
@triton.jit
def nll_loss_nd_backward_kernel(
    grad_out_ptr,  # (N, S) upstream grads for REDUCTION==0, else a scalar
    target_ptr,  # (N, S) integer class indices
    weight_ptr,  # (C,) per-class weights; only read when HAS_WEIGHT is True
    grad_in_ptr,  # (N, C, S) output gradient buffer (pre-zeroed by the host)
    total_weight_ptr,  # scalar weight sum from the forward pass; used for mean
    C,  # number of classes
    S,  # flattened spatial extent per (n, c)
    stride_in_n,
    stride_in_c,
    stride_in_s,
    stride_tgt_n,
    stride_tgt_s,
    stride_go_n,
    stride_go_s,
    ignore_index,
    HAS_WEIGHT: tl.constexpr,
    REDUCTION: tl.constexpr,  # 0: none, 1: mean, 2: sum
    BLOCK_S: tl.constexpr = 1024,
):
    """Backward NLL loss: scatter -w * grad into each lane's target class.

    Grid is (cdiv(S, BLOCK_S), N). Only valid lanes (in-bounds, not
    ignore_index, legal class id) are written; other positions keep the
    zeros the host initialized. For 'mean' the gradient is divided by the
    total weight recorded by the forward pass.
    """
    pid_s = tl.program_id(0)
    pid_n = tl.program_id(1)

    s_offsets = pid_s * BLOCK_S + tl.arange(0, BLOCK_S)
    mask_s = s_offsets < S

    # Out-of-range lanes load ignore_index so they are rejected below.
    tgt_offsets = pid_n * stride_tgt_n + s_offsets * stride_tgt_s
    t = tl.load(target_ptr + tgt_offsets, mask=mask_s, other=ignore_index).to(tl.int32)

    valid = mask_s & (t != ignore_index) & (t >= 0) & (t < C)

    if REDUCTION == 0:  # none
        # Per-element upstream gradient, one value per (n, s) position.
        out_grad_offsets = pid_n * stride_go_n + s_offsets * stride_go_s
        out_grad = tl.load(grad_out_ptr + out_grad_offsets, mask=valid, other=0.0).to(
            tl.float32
        )
    else:  # mean or sum
        # Single scalar upstream gradient, broadcast to every lane.
        out_grad = tl.load(grad_out_ptr).to(tl.float32)

    if HAS_WEIGHT:
        w = tl.load(weight_ptr + t, mask=valid, other=0.0).to(tl.float32)
    else:
        w = tl.where(valid, 1.0, 0.0).to(tl.float32)

    if REDUCTION == 1:  # mean
        # Guard against a zero denominator (all targets ignored).
        total_weight = tl.load(total_weight_ptr).to(tl.float32)
        grad_in_val = tl.where(total_weight != 0.0, -w * out_grad / total_weight, 0.0)
    else:  # sum or none
        grad_in_val = -w * out_grad

    # Scatter: only the target class of each valid lane receives a gradient.
    in_offsets = pid_n * stride_in_n + t * stride_in_c + s_offsets * stride_in_s
    tl.store(
        grad_in_ptr + in_offsets,
        grad_in_val.to(grad_in_ptr.dtype.element_ty),
        mask=valid,
    )

152 

153 

def nll_loss_nd_forward(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor = None,
    reduction: int = 1,
    ignore_index: int = -100,
):
    """Forward NLL loss for inputs with 3 or more dimensions.

    Views input as (N, C, S) where S is the product of all trailing
    (spatial) dimensions, and target as (N, S).

    Args:
        input: (N, C, d1, ..., dk) log-probabilities.
        target: class-index tensor with N * S elements.
        weight: optional (C,) per-class weights.
        reduction: 0 ('none'), 1 ('mean'), or 2 ('sum').
        ignore_index: target value whose positions contribute no loss.

    Returns:
        (loss, total_weight). For 'none', loss has target's shape and
        total_weight is an uninitialized scalar placeholder; for 'mean',
        total_weight is the float32 sum of contributing weights; for
        'sum', it is an uninitialized scalar placeholder.

    Raises:
        ValueError: on target/weight size mismatch or invalid reduction.
    """
    logger.debug("GEMS NLL LOSS ND FWD")
    if input.dim() < 3:
        # 1-d / 2-d inputs are handled by the dedicated 2d implementation.
        return nll_loss_2d(
            input, target, weight=weight, reduction=reduction, ignore_index=ignore_index
        )

    N = input.shape[0]
    C = input.shape[1]
    S = input.numel() // (N * C)  # flattened spatial extent

    inp = input.reshape(N, C, S)

    if target.numel() != N * S:
        raise ValueError(
            f"Target size {target.shape} doesn't match input size (N={N}, S={S})"
        )
    tgt = target.reshape(N, S)

    stride_in_n, stride_in_c, stride_in_s = inp.stride()
    stride_tgt_n, stride_tgt_s = tgt.stride()

    if weight is None:
        has_weight = False
        w = input  # dummy pointer; kernel ignores it when HAS_WEIGHT is False
    else:
        has_weight = True
        if weight.numel() != C:
            raise ValueError(f"Weight shape {weight.shape} must be ({C},)")
        w = weight.contiguous()

    if reduction not in (0, 1, 2):
        raise ValueError("reduction must be 0 ('none'), 1 ('mean'), or 2 ('sum')")

    # Allocate per reduction mode; the kernel launch itself is identical, so
    # it is issued exactly once below instead of being duplicated per branch.
    if reduction == 0:
        out = torch.empty((N, S), device=input.device, dtype=input.dtype)
        scratch = torch.empty(1, device=input.device)  # unused by the 'none' path
    else:
        out = torch.empty(1, device=input.device, dtype=input.dtype)
        # [loss_sum, weight_sum, finished_program_count], zeroed for atomics.
        scratch = torch.zeros(3, device=input.device, dtype=torch.float32)

    grid = lambda meta: (triton.cdiv(S, meta["BLOCK_S"]), N)
    with torch_device_fn.device(input.device):
        nll_loss_nd_forward_kernel[grid](
            inp,
            tgt,
            w,
            out,
            scratch,
            C,
            S,
            stride_in_n,
            stride_in_c,
            stride_in_s,
            stride_tgt_n,
            stride_tgt_s,
            ignore_index,
            HAS_WEIGHT=has_weight,
            REDUCTION=reduction,
        )

    if reduction == 0:
        # Restore the caller-visible shape of the per-element losses.
        if target.dim() == input.dim() - 1:
            res = out.view_as(target)
        else:
            res = out.reshape(target.shape)
        total_weight = torch.empty([], device=input.device, dtype=input.dtype)
        return res, total_weight

    out = out[0]
    if reduction == 1:
        total_weight = scratch[1]  # float32 weight sum accumulated in-kernel
    else:
        total_weight = torch.empty([], device=input.device, dtype=input.dtype)
    return out, total_weight

259 

260 

def nll_loss_nd_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor = None,
    reduction: int = 1,
    ignore_index: int = -100,
    total_weight: torch.Tensor = None,
):
    """Backward NLL loss for inputs with 3 or more dimensions.

    Args:
        grad_output: upstream gradient — (N, d1, ..., dk) for reduction
            'none', a scalar for 'mean'/'sum'.
        input: (N, C, d1, ..., dk) forward input (used for shape/dtype only).
        target: class-index tensor with N * S elements.
        weight: optional (C,) per-class weights.
        reduction: 0 ('none'), 1 ('mean'), or 2 ('sum').
        ignore_index: target value whose positions receive zero gradient.
        total_weight: weight sum returned by the forward pass; the kernel
            dereferences it for 'mean' — presumably always supplied by the
            caller in that case (TODO confirm: None would fail at launch).

    Returns:
        Gradient tensor with input's shape (contiguous layout).
    """
    logger.debug("GEMS NLL LOSS ND BWD")

    if input.dim() < 3:
        # 1-d / 2-d inputs are handled by the dedicated 2d implementation.
        return nll_loss_2d_backward(
            grad_output,
            input,
            target,
            weight=weight,
            reduction=reduction,
            ignore_index=ignore_index,
            total_weight=total_weight,
        )

    N = input.shape[0]
    C = input.shape[1]
    S = input.numel() // (N * C)  # flattened spatial extent

    # Accumulate into a dedicated contiguous (N, C, S) buffer. The previous
    # code passed the full-shape tensor as the kernel pointer while taking
    # strides from `grad_input.reshape(N, C, S)`; for a non-contiguous input
    # that reshape can be a *copy*, making pointer and strides disagree and
    # scattering writes to wrong offsets. A fresh contiguous buffer keeps
    # them consistent; positions the kernel never writes (ignored or
    # out-of-range targets) stay zero.
    grad_inp = torch.zeros((N, C, S), device=input.device, dtype=input.dtype)
    tgt = target.reshape(N, S)

    stride_in_n, stride_in_c, stride_in_s = grad_inp.stride()
    stride_tgt_n, stride_tgt_s = tgt.stride()

    if weight is None:
        has_weight = False
        w = input  # dummy pointer; kernel ignores it when HAS_WEIGHT is False
    else:
        has_weight = True
        w = weight.contiguous()

    if reduction == 0:
        # Per-element upstream gradient, one value per (n, s) position.
        grad_out = grad_output.reshape(N, S)
        stride_go_n, stride_go_s = grad_out.stride()
    else:
        # Scalar upstream gradient; the strides are never used by the kernel.
        grad_out = grad_output
        stride_go_n, stride_go_s = 0, 0

    grid = lambda meta: (triton.cdiv(S, meta["BLOCK_S"]), N)

    with torch_device_fn.device(input.device):
        nll_loss_nd_backward_kernel[grid](
            grad_out,
            tgt,
            w,
            grad_inp,
            total_weight,
            C,
            S,
            stride_in_n,
            stride_in_c,
            stride_in_s,
            stride_tgt_n,
            stride_tgt_s,
            stride_go_n,
            stride_go_s,
            ignore_index,
            HAS_WEIGHT=has_weight,
            REDUCTION=reduction,
        )

    # Safe view: grad_inp is contiguous with exactly input.numel() elements.
    return grad_inp.view(input.shape)