Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/nllloss.py: 0%

171 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

11 

12 

@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_forward_kernel(
    inp_ptr,  # (N, C) log-probabilities, contiguous
    tgt_ptr,  # (N,) class indices
    wgt_ptr,  # (C,) per-class weights, or None
    out_ptr,  # (N,) per-sample loss output
    ignore_wgt_tgt_ptr,  # (N,) effective per-sample weight; written only when reduction == 1
    ignore_index,  # target value whose samples contribute no loss/weight
    N,
    C,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_N: tl.constexpr = 128,
):
    """Per-sample NLL loss: out[n] = -weight[tgt[n]] * inp[n, tgt[n]].

    Samples whose target equals ``ignore_index`` produce a zero loss and a
    zero effective weight. For mean reduction the effective per-sample
    weights are also stored so the host can normalize the summed loss.
    """
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_n = offsets_n < N

    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    # Device-side bounds check on the class indices.
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # True for in-bounds rows whose target is not ignore_index.
    # NOTE(review): relies on this backend's Triton frontend lowering Python
    # `not`/`and` elementwise; upstream Triton spells this
    # `(tgt != ignore_index) & mask_n` — confirm on this fork.
    ignore_mask = not (tgt == ignore_index) and mask_n

    if wgt_ptr is None:
        # No class weights: weight is 1 for kept samples, 0 for ignored ones.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Gather inp[n, tgt[n]]; masked (ignored / out-of-bounds) rows read 0.
    inp_tgt_ptrs = inp_ptr + offsets_n * C + tgt
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    out = inp_tgt * wgt_tgt * -1

    tl.store(out_ptr + offsets_n, out, mask=mask_n)
    if reduction == 1:
        # Save per-sample weights so the host can compute sum(out) / sum(w).
        tl.store(ignore_wgt_tgt_ptr + offsets_n, wgt_tgt, mask=mask_n)

48 

49 

@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_backward_kernel(
    out_grad_ptr,  # (N,) grads for reduction == 0, otherwise a scalar
    tgt_ptr,  # (N,) class indices
    wgt_ptr,  # (C,) per-class weights, or None
    inp_grad_ptr,  # (N, C) output gradient; assumed pre-zeroed by the host
    ignore_index,  # target value whose samples receive no gradient
    total_weight,  # scalar tensor; read only when reduction == 1
    N,
    C,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_N: tl.constexpr = 128,
):
    """Scatter NLL gradients: inp_grad[n, tgt[n]] = -g_n * w_{tgt[n]} / total_w.

    total_w is total_weight for mean reduction and 1 otherwise; g_n is the
    per-sample gradient for reduction == 0 and the scalar output gradient
    otherwise. All non-target entries are left untouched (zero).
    """
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_n = offsets_n < N

    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    # NOTE(review): relies on the frontend lowering `not`/`and` elementwise;
    # upstream Triton uses `(tgt != ignore_index) & mask_n` — confirm.
    ignore_mask = not (tgt == ignore_index) and mask_n

    if wgt_ptr is None:
        # No class weights: weight is 1 for kept samples, 0 for ignored ones.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    if reduction == 0:
        # 'none' reduction: one upstream gradient per sample.
        out_grad_ptrs = out_grad_ptr + offsets_n
        out_grad = tl.load(out_grad_ptrs, mask=mask_n, other=0).to(tl.float32)
    else:
        # 'mean'/'sum': a single scalar upstream gradient.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)
    if reduction == 1:
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1

    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offsets_n * C + tgt
    tl.store(inp_grad_ptrs, inp_grad, mask=mask_n)

90 

91 

@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_forward_kernel(
    inp_ptr,  # log-probabilities, addressed as (N, C, D), contiguous
    tgt_ptr,  # class indices, addressed as (N, D)
    wgt_ptr,  # (C,) per-class weights, or None
    out_ptr,  # per-element loss, addressed as (N, D)
    ignore_wgt_tgt_ptr,  # effective weights (N, D); written only when reduction == 1
    ignore_index,  # target value whose elements contribute no loss/weight
    N,
    C,
    D,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_ND: tl.constexpr = 128,
):
    """Spatial NLL loss: out[n, d] = -weight[tgt[n, d]] * inp[n, tgt[n, d], d].

    One program handles BLOCK_ND flattened (n, d) positions. Elements whose
    target equals ``ignore_index`` produce zero loss and zero weight; for
    mean reduction the weights are also stored for host-side normalization.
    """
    pid_nd = tl.program_id(0)
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    # Unflatten the linear offset into (n, d) coordinates.
    offset_d = offset_nd % D
    offset_n = offset_nd // D

    mask_block = offset_nd < N * D

    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    # Device-side bounds check on the class indices.
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # NOTE(review): relies on the frontend lowering `not`/`and` elementwise;
    # upstream Triton uses `(tgt != ignore_index) & mask_block` — confirm.
    ignore_mask = not (tgt == ignore_index) and mask_block

    if wgt_ptr is None:
        # No class weights: weight is 1 for kept elements, 0 for ignored ones.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Gather inp[n, tgt, d] from the (N, C, D) layout.
    inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    out = inp_tgt * wgt_tgt * -1

    out_ptrs = out_ptr + offset_n * D + offset_d
    tl.store(out_ptrs, out, mask=mask_block)

    if reduction == 1:
        # Save per-element weights so the host can compute sum(out) / sum(w).
        ignore_wgt_tgt_ptrs = ignore_wgt_tgt_ptr + offset_n * D + offset_d
        tl.store(ignore_wgt_tgt_ptrs, wgt_tgt, mask=mask_block)

135 

@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_backward_kernel(
    out_grad_ptr,  # (N, D) grads for reduction == 0, otherwise a scalar
    tgt_ptr,  # class indices, addressed as (N, D)
    wgt_ptr,  # (C,) per-class weights, or None
    inp_grad_ptr,  # output gradient (N, C, D); assumed pre-zeroed by the host
    ignore_index,  # target value whose elements receive no gradient
    total_weight,  # scalar tensor; read only when reduction == 1
    N,
    C,
    D,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_ND: tl.constexpr = 128,
):
    """Scatter spatial NLL gradients into inp_grad[n, tgt[n, d], d].

    Each kept element receives -g * w_{tgt} / total_w, where total_w is
    total_weight for mean reduction and 1 otherwise; g is per-element for
    reduction == 0 and the scalar output gradient otherwise.
    """
    pid_nd = tl.program_id(0)
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    # Unflatten the linear offset into (n, d) coordinates.
    offset_d = offset_nd % D
    offset_n = offset_nd // D

    mask_block = offset_nd < N * D

    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    # NOTE(review): relies on the frontend lowering `not`/`and` elementwise;
    # upstream Triton uses `(tgt != ignore_index) & mask_block` — confirm.
    ignore_mask = not (tgt == ignore_index) and mask_block

    if wgt_ptr is None:
        # No class weights: weight is 1 for kept elements, 0 for ignored ones.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    if reduction == 0:
        # 'none' reduction: one upstream gradient per (n, d) element.
        out_grad_ptrs = out_grad_ptr + offset_n * D + offset_d
        out_grad = tl.load(out_grad_ptrs, mask=mask_block, other=0).to(tl.float32)
    else:
        # 'mean'/'sum': a single scalar upstream gradient.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)

    if reduction == 1:
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1
    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offset_n * C * D + tgt * D + offset_d
    tl.store(inp_grad_ptrs, inp_grad, mask=mask_block)

180 

181 

182# Negative Log Likelihood Loss (NLLLoss) 

183# 

184# This loss function is used for training classification problems with C classes. 

185# 

186# Parameters: 

187# - input (Tensor): 

188# - Expected to contain log-probabilities for each class. 

189# - Shape can be either: 

190# - (minibatch, C) for standard classification tasks. 

191# - (minibatch, C, d1, d2, ..., dK) for K-dimensional inputs (e.g., per-pixel loss for 2D images). 

192# 

193# - target (Tensor): 

194# - Should contain class indices in the range [0, C-1]. 

195# - If ignore_index is specified, this index can be outside the class range 

196# and will be ignored in the loss computation. 

197# 

198# - weight (1D Tensor, optional): 

199# - Assigns weight to each class, useful for unbalanced datasets. 

200# 

201# Reduction modes: 

202# - 'none': returns per-sample loss (shape: (N,)). 

203# - 'mean' (default): computes the mean of the weighted losses. 

204# - 'sum': computes the sum of the weighted losses. 

205# 

206# Mathematical description: 

207# - Unreduced loss: 

208# l_n = -w_{y_n} * x_{n,y_n}, where w_c = weight[c] * 1{c != ignore_index}. 

209# - Reduced loss (depending on the specified reduction mode): 

210# - mean: ℓ(x, y) = Σ(l_n) / Σ(w_{y_n})  (each l_n already carries its weight) 

211# - sum: ℓ(x, y) = Σ(l_n) 

212 

213 

# 1d & 2d tensor
def nll_loss_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """Forward NLL loss for 1-D (C,) or 2-D (N, C) log-probability inputs.

    Args:
        self: log-probability tensor of shape (C,) or (N, C).
        target: class-index tensor with exactly N elements.
        weight: optional (C,) per-class weight tensor.
        reduction: 0 - none, 1 - mean, 2 - sum.
        ignore_index: target value excluded from loss and normalization.

    Returns:
        (output, total_weight): per-sample losses when reduction == 0,
        otherwise a scalar loss. total_weight is the summed effective weight
        for mean reduction and a zero scalar otherwise (it is only consumed
        by the backward pass when reduction == 1).
    """
    logger.debug("GEMS NLL Loss FWD")
    assert self.ndim <= 2, "Invalid input ndim"
    shape = list(target.shape)
    N = 1 if self.ndim == 1 else self.shape[0]
    C = self.shape[-1]
    assert target.numel() == N, "Invalid target size"

    self = self.contiguous()
    target = target.contiguous()
    weight = None if weight is None else weight.contiguous()

    out = torch.empty(shape, dtype=self.dtype, device=self.device)
    ignore_weight_tgt = None
    if reduction == 1:
        # Per-sample effective weights, summed below to normalize the mean.
        ignore_weight_tgt = torch.zeros(
            target.shape, dtype=self.dtype, device=self.device
        )

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
    with torch_device_fn.device(self.device):
        nll_loss_forward_kernel[grid](
            self,
            target,
            weight,
            out,
            ignore_weight_tgt,
            ignore_index,
            N,
            C,
            reduction,
            # NOTE(review): not a kernel parameter; presumably a launch
            # option understood by the kunlunxin Triton fork — confirm.
            is_use_mask_zero=True,
        )

    # reduction: 0-none, 1-mean, 2-sum
    if reduction == 0:
        output = out
        # zeros, not empty: never hand uninitialized memory back to callers.
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)
    elif reduction == 1:
        total_out = torch.sum(out)
        total_weight = torch.sum(ignore_weight_tgt).to(self.dtype)
        output = (total_out / total_weight).to(self.dtype)
    else:
        total_out = torch.sum(out)
        output = total_out.to(self.dtype)
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)

    return output, total_weight

263 

264 

def nll_loss_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Backward NLL loss for 1-D/2-D inputs.

    Scatters -grad * weight[target[n]] (divided by total_weight when
    reduction == 1) into the target-class column of each sample; every
    other entry of the returned gradient stays zero.
    """
    logger.debug("GEMS NLL Loss BWD")
    if self.ndim == 1:
        N = 1
    else:
        N = self.shape[0]
    C = self.shape[-1]

    grad_output = grad_output.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # Pre-zeroed: the kernel only writes the target-class positions.
    grad_input = torch.zeros_like(self).contiguous()

    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(self.device):
        nll_loss_backward_kernel[grid](
            grad_output,
            target,
            weight,
            grad_input,
            ignore_index,
            total_weight,
            N,
            C,
            reduction,
        )

    return grad_input

299 

300 

# 3d+ tensor
def nll_loss2d_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """Forward spatial NLL loss for 4-D inputs.

    Args:
        self: 4-D log-probability tensor unpacked as (N, C, _, D).
        target: class-index tensor of shape (N, 1, D).
        weight: optional (C,) per-class weight tensor.
        reduction: 0 - none, 1 - mean, 2 - sum.
        ignore_index: target value excluded from loss and normalization.

    Returns:
        (output, total_weight): per-element losses when reduction == 0,
        otherwise a scalar loss. total_weight is the summed effective weight
        for mean reduction and a zero scalar otherwise.

    NOTE(review): the kernel addresses the input as (N, C, D) and ignores
    the third dimension, so this is only correct when self.shape[2] == 1 —
    confirm against the callers that route to this op.
    """
    logger.debug("GEMS NLL Loss2d FWD")
    assert self.ndim == 4, "Invalid input ndim"

    shape = list(target.shape)
    N, C, _, D = self.shape
    assert shape == [N, 1, D], "Invalid target size"

    self = self.contiguous()
    target = target.contiguous()
    weight = None if weight is None else weight.contiguous()

    out = torch.empty(shape, dtype=self.dtype, device=self.device)
    ignore_weight_tgt = None
    if reduction == 1:
        # Per-element effective weights, summed below to normalize the mean.
        ignore_weight_tgt = torch.zeros(
            target.shape, dtype=self.dtype, device=self.device
        )

    grid = lambda meta: (triton.cdiv(N * D, meta["BLOCK_ND"]),)
    with torch_device_fn.device(self.device):
        nll_loss2d_forward_kernel[grid](
            self,
            target,
            weight,
            out,
            ignore_weight_tgt,
            ignore_index,
            N,
            C,
            D,
            reduction,
            # NOTE(review): not a kernel parameter; presumably a launch
            # option understood by the kunlunxin Triton fork — confirm.
            is_use_mask_zero=True,
        )

    # reduction: 0-none, 1-mean, 2-sum
    if reduction == 0:
        output = out
        # zeros, not empty: never hand uninitialized memory back to callers.
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)
    elif reduction == 1:
        total_out = torch.sum(out)
        total_weight = torch.sum(ignore_weight_tgt).to(self.dtype)
        output = (total_out / total_weight).to(self.dtype)
    else:
        total_out = torch.sum(out)
        output = total_out.to(self.dtype)
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)

    return output, total_weight

351 

352 

def nll_loss2d_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Backward spatial NLL loss for 4-D inputs.

    Scatters the (optionally weighted and mean-normalized) upstream gradient
    into the target-class channel of each (n, d) element; every other entry
    of the returned gradient stays zero.
    """
    logger.debug("GEMS NLL Loss2d BWD")
    N, C, _, D = self.shape

    grad_output = grad_output.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # Pre-zeroed: the kernel only writes the target-class positions.
    grad_input = torch.zeros_like(self).contiguous()

    def grid(meta):
        return (triton.cdiv(N * D, meta["BLOCK_ND"]),)

    with torch_device_fn.device(self.device):
        nll_loss2d_backward_kernel[grid](
            grad_output,
            target,
            weight,
            grad_input,
            ignore_index,
            total_weight,
            N,
            C,
            D,
            reduction,
        )

    return grad_input