Coverage for src/flag_gems/ops/nllloss.py: 39%

196 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(__name__) 

11 

12 

# Forward kernel for 1-D NLL loss over an (N, C) log-probability input.
#
# out_ptr layout depends on `reduction`:
#   - reduction == 0 ("none"): length-N buffer of per-sample losses.
#   - reduction == 1 ("mean"): 4-element float32 scratch buffer:
#       [0] running weighted-loss sum, [1] running weight sum,
#       [2] block-completion counter, [3] final mean (written by last block).
#   - otherwise   ("sum"):     scalar accumulator.
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_forward_kernel(
    inp_ptr,  # (N, C) contiguous log-probabilities
    tgt_ptr,  # (N,) class indices
    wgt_ptr,  # optional (C,) per-class weights; may be None
    out_ptr,  # output / scratch buffer; see layout above
    ignore_index,  # target value excluded from the loss
    N,
    C,
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2+ = sum
    BLOCK_N: tl.constexpr = 128,
):
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_n = offsets_n < N

    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    # Device-side sanity check; masked-out lanes load 0 and therefore pass.
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # Lanes that are in range and whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_n

    if wgt_ptr is None:
        # Unweighted: weight 1.0 for counted lanes, 0.0 for ignored lanes.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Gather the log-probability of each sample's target class.
    inp_tgt_ptrs = inp_ptr + offsets_n * C + tgt
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    out = inp_tgt * wgt_tgt * -1

    # none
    if reduction == 0:
        tl.store(out_ptr + offsets_n, out, mask=mask_n)
    # mean
    elif reduction == 1:
        total_out = tl.sum(out)
        total_wgt = tl.sum(wgt_tgt)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")  # output
        tl.atomic_add(out_ptr + 1, total_wgt, sem="relaxed")  # weight
        tl.atomic_add(out_ptr + 2, 1, sem="release")  # counter
        counter = tl.load(out_ptr + 2)
        # The last block to finish finalizes the mean in-place.
        # NOTE(review): this relies on the release-increment making the
        # earlier relaxed adds visible to the plain loads below — confirm
        # the memory-ordering guarantee holds on all supported backends.
        if counter == tl.num_programs(0):
            total_out = tl.load(out_ptr)
            total_wgt = tl.load(out_ptr + 1)
            tl.store(out_ptr + 3, total_out / total_wgt)
    # sum
    else:
        total_out = tl.sum(out)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")

63 

64 

# Backward kernel for 1-D NLL loss: writes d(loss)/d(input) at each
# sample's target class; all other classes keep the zero-initialized grad.
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_backward_kernel(
    out_grad_ptr,  # upstream grad: length-N for "none", scalar otherwise
    tgt_ptr,  # (N,) class indices
    wgt_ptr,  # optional (C,) per-class weights; may be None
    inp_grad_ptr,  # (N, C) zero-initialized gradient output
    ignore_index,  # target value excluded from the loss
    total_weight,  # scalar tensor with the forward's weight sum (mean mode)
    N,
    C,
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2+ = sum
    BLOCK_N: tl.constexpr = 128,
):
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_n = offsets_n < N

    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    # Lanes that are in range and whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_n

    if wgt_ptr is None:
        # Unweighted: weight 1.0 for counted lanes, 0.0 for ignored lanes.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    if reduction == 0:
        # Per-sample upstream gradient.
        out_grad_ptrs = out_grad_ptr + offsets_n
        out_grad = tl.load(out_grad_ptrs, mask=mask_n, other=0).to(tl.float32)
    else:
        # Scalar upstream gradient broadcast to all lanes.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)
    if reduction == 1:
        # "mean" divides by the total target weight accumulated in forward.
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1

    # grad = -g * w_{y_n} / total_w at position (n, y_n); ignored lanes get 0.
    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offsets_n * C + tgt
    tl.store(inp_grad_ptrs, inp_grad, mask=mask_n)

105 

106 

# Forward kernel for spatial (2d) NLL loss over an (N, C, D) view of the
# input, where D is the flattened spatial extent (D1 * D2).
#
# out_ptr layout depends on `reduction`:
#   - reduction == 0 ("none"): (N, D) buffer of per-element losses.
#   - reduction == 1 ("mean"): 4-element float32 scratch buffer:
#       [0] running weighted-loss sum, [1] running weight sum,
#       [2] block-completion counter, [3] final mean (written by last block).
#   - otherwise   ("sum"):     scalar accumulator.
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_forward_kernel(
    inp_ptr,  # (N, C, D) contiguous log-probabilities
    tgt_ptr,  # (N, D) class indices
    wgt_ptr,  # optional (C,) per-class weights; may be None
    out_ptr,  # output / scratch buffer; see layout above
    ignore_index,  # target value excluded from the loss
    N,
    C,
    D,
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2+ = sum
    BLOCK_ND: tl.constexpr = 128,
):
    pid_nd = tl.program_id(0)
    # Flattened (n, d) indices covered by this block.
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    offset_d = offset_nd % D
    offset_n = offset_nd // D

    mask_block = offset_nd < N * D

    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    # Device-side sanity check; masked-out lanes load 0 and therefore pass.
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # Lanes that are in range and whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_block

    if wgt_ptr is None:
        # Unweighted: weight 1.0 for counted lanes, 0.0 for ignored lanes.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Gather the log-probability at (n, tgt, d).
    inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    out = inp_tgt * wgt_tgt * -1

    # none
    if reduction == 0:
        out_ptrs = out_ptr + offset_n * D + offset_d
        tl.store(out_ptrs, out, mask=mask_block)
    # mean
    elif reduction == 1:
        total_out = tl.sum(out)
        total_wgt = tl.sum(wgt_tgt)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")  # output
        tl.atomic_add(out_ptr + 1, total_wgt, sem="relaxed")  # weight
        tl.atomic_add(out_ptr + 2, 1, sem="release")  # counter
        counter = tl.load(out_ptr + 2)
        # The last block to finish finalizes the mean in-place.
        # NOTE(review): this relies on the release-increment making the
        # earlier relaxed adds visible to the plain loads below — confirm
        # the memory-ordering guarantee holds on all supported backends.
        if counter == tl.num_programs(0):
            total_out = tl.load(out_ptr)
            total_wgt = tl.load(out_ptr + 1)
            tl.store(out_ptr + 3, total_out / total_wgt)
    # sum
    else:
        total_out = tl.sum(out)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")

162 

163 

# Backward kernel for spatial (2d) NLL loss: writes d(loss)/d(input) at
# (n, y_{n,d}, d) for each spatial position; other classes keep the
# zero-initialized gradient. D is the flattened spatial extent (D1 * D2).
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_backward_kernel(
    out_grad_ptr,  # upstream grad: (N, D) for "none", scalar otherwise
    tgt_ptr,  # (N, D) class indices
    wgt_ptr,  # optional (C,) per-class weights; may be None
    inp_grad_ptr,  # (N, C, D) zero-initialized gradient output
    ignore_index,  # target value excluded from the loss
    total_weight,  # scalar tensor with the forward's weight sum (mean mode)
    N,
    C,
    D,
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2+ = sum
    BLOCK_ND: tl.constexpr = 128,
):
    pid_nd = tl.program_id(0)
    # Flattened (n, d) indices covered by this block.
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    offset_d = offset_nd % D
    offset_n = offset_nd // D

    mask_block = offset_nd < N * D

    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    # Lanes that are in range and whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_block

    if wgt_ptr is None:
        # Unweighted: weight 1.0 for counted lanes, 0.0 for ignored lanes.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    if reduction == 0:
        # Per-element upstream gradient.
        out_grad_ptrs = out_grad_ptr + offset_n * D + offset_d
        out_grad = tl.load(out_grad_ptrs, mask=mask_block, other=0).to(tl.float32)
    else:
        # Scalar upstream gradient broadcast to all lanes.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)

    if reduction == 1:
        # "mean" divides by the total target weight accumulated in forward.
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1
    # grad = -g * w_{y} / total_w at (n, y, d); ignored lanes write 0.
    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offset_n * C * D + tgt * D + offset_d
    tl.store(inp_grad_ptrs, inp_grad, mask=mask_block)

208 

209 

210# Negative Log Likelihood Loss (NLLLoss) 

211# 

212# This loss function is used for training classification problems with C classes. 

213# 

214# Parameters: 

215# - input (Tensor): 

216# - Expected to contain log-probabilities for each class. 

217# - Shape can be either: 

218# - (minibatch, C) for standard classification tasks. 

219# - (minibatch, C, d1, d2, ..., dK) for K-dimensional inputs (e.g., per-pixel loss for 2D images). 

220# 

221# - target (Tensor): 

222# - Should contain class indices in the range [0, C-1]. 

223# - If ignore_index is specified, this index can be outside the class range 

224# and will be ignored in the loss computation. 

225# 

226# - weight (1D Tensor, optional): 

227# - Assigns weight to each class, useful for unbalanced datasets. 

228# 

229# Reduction modes: 

230# - 'none': returns per-sample loss (shape: (N,)). 

231# - 'mean' (default): computes the mean of the weighted losses. 

232# - 'sum': computes the sum of the weighted losses. 

233# 

# Mathematical description:
# - Unreduced loss:
#   l_n = -w_{y_n} * x_{n, y_n}, where w_c = weight[c] * 1{c != ignore_index}.
# - Reduced loss (depending on the specified reduction mode):
#   - mean: ℓ(x, y) = Σ l_n / Σ w_{y_n}   (sum of losses over sum of target weights)
#   - sum:  ℓ(x, y) = Σ l_n

240 

241 

242# 1d & 2d tensor 

def nll_loss_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """Forward NLL loss for a 1-d or 2-d log-probability input.

    Mirrors aten::nll_loss_forward and returns an ``(output, total_weight)``
    pair. ``reduction`` follows the ATen encoding: 0 = none, 1 = mean,
    2 = sum. ``total_weight`` is only meaningful in mean mode.
    """
    logger.debug("GEMS NLL Loss FWD")
    assert self.ndim <= 2, "Invalid input ndim"

    num_classes = self.shape[-1]
    batch = 1 if self.ndim == 1 else self.shape[0]
    target_shape = list(target.shape)
    assert target.numel() == batch, "Invalid target size"

    self = self.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # reduction: 0 = none, 1 = mean, 2 = sum
    if reduction == 0:
        out = torch.empty(target_shape, dtype=self.dtype, device=self.device)
    elif reduction == 1:
        # 4-slot scratch: [loss sum, weight sum, block counter, final mean].
        out = torch.zeros([4], dtype=torch.float32, device=self.device)
    else:
        out = torch.zeros([], dtype=torch.float32, device=self.device)

    grid = lambda meta: (triton.cdiv(batch, meta["BLOCK_N"]),)
    with torch_device_fn.device(self.device):
        nll_loss_forward_kernel[grid](
            self,
            target,
            weight,
            out,
            ignore_index,
            batch,
            num_classes,
            reduction,
        )

    # Unpack the kernel's result according to the reduction mode.
    if reduction == 0:
        output = out
        total_weight = torch.empty([], dtype=self.dtype, device=self.device)
    elif reduction == 1:
        out = out.to(self.dtype)
        output = out[3]
        total_weight = out[1]
    else:
        output = out.to(self.dtype)
        total_weight = torch.empty([], dtype=self.dtype, device=self.device)

    return output, total_weight

295 

296 

def nll_loss_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Backward NLL loss for a 1-d or 2-d input.

    Scatters the upstream gradient into an input-shaped tensor at each
    sample's target class. ``reduction``: 0 = none, 1 = mean, 2 = sum;
    ``total_weight`` is the forward's weight sum (used in mean mode).
    """
    logger.debug("GEMS NLL Loss BWD")
    batch = 1 if self.ndim == 1 else self.shape[0]
    num_classes = self.shape[-1]

    grad_output = grad_output.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # Zero-init so classes other than the target keep zero gradient.
    grad_input = torch.zeros_like(self).contiguous()

    grid = lambda meta: (triton.cdiv(batch, meta["BLOCK_N"]),)
    with torch_device_fn.device(self.device):
        nll_loss_backward_kernel[grid](
            grad_output,
            target,
            weight,
            grad_input,
            ignore_index,
            total_weight,
            batch,
            num_classes,
            reduction,
        )

    return grad_input

331 

332 

333# 3d+ tensor 

def nll_loss2d_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """Forward spatial (2d) NLL loss for a 4-d (N, C, D1, D2) input.

    Mirrors aten::nll_loss2d_forward and returns ``(output, total_weight)``.
    ``reduction``: 0 = none, 1 = mean, 2 = sum; ``total_weight`` is only
    meaningful in mean mode.
    """
    logger.debug("GEMS NLL Loss2d FWD")
    assert self.ndim == 4, "Invalid input ndim"

    target_shape = list(target.shape)
    batch, num_classes, height, width = self.shape
    assert target_shape == [batch, height, width], "Invalid target size"
    spatial = height * width

    self = self.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # reduction: 0 = none, 1 = mean, 2 = sum
    if reduction == 0:
        out = torch.empty(target_shape, dtype=self.dtype, device=self.device)
    elif reduction == 1:
        # 4-slot scratch: [loss sum, weight sum, block counter, final mean].
        out = torch.zeros([4], dtype=torch.float32, device=self.device)
    else:
        out = torch.zeros([], dtype=torch.float32, device=self.device)

    grid = lambda meta: (triton.cdiv(batch * spatial, meta["BLOCK_ND"]),)
    with torch_device_fn.device(self.device):
        nll_loss2d_forward_kernel[grid](
            self,
            target,
            weight,
            out,
            ignore_index,
            batch,
            num_classes,
            spatial,
            reduction,
        )

    # Unpack the kernel's result according to the reduction mode.
    if reduction == 0:
        output = out
        total_weight = torch.empty([], dtype=self.dtype, device=self.device)
    elif reduction == 1:
        out = out.to(self.dtype)
        output = out[3]
        total_weight = out[1]
    else:
        output = out.to(self.dtype)
        total_weight = torch.empty([], dtype=self.dtype, device=self.device)

    return output, total_weight

378 

379 

def nll_loss2d_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Backward spatial (2d) NLL loss for a 4-d (N, C, D1, D2) input.

    Scatters the upstream gradient into an input-shaped tensor at each
    spatial position's target class. ``reduction``: 0 = none, 1 = mean,
    2 = sum; ``total_weight`` is the forward's weight sum (mean mode).
    """
    logger.debug("GEMS NLL Loss2d BWD")
    batch, num_classes, height, width = self.shape
    spatial = height * width

    grad_output = grad_output.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # Zero-init so classes other than the target keep zero gradient.
    grad_input = torch.zeros_like(self).contiguous()

    grid = lambda meta: (triton.cdiv(batch * spatial, meta["BLOCK_ND"]),)
    with torch_device_fn.device(self.device):
        nll_loss2d_backward_kernel[grid](
            grad_output,
            target,
            weight,
            grad_input,
            ignore_index,
            total_weight,
            batch,
            num_classes,
            spatial,
            reduction,
        )

    return grad_input