Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/nllloss.py: 0%
171 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-18 02:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
# Child logger under the shared "flag_gems" root; lstrip(".") drops any
# leading dots from relative-import forms of __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Per-sample NLL loss forward for a (N, C) input.
# Each program handles BLOCK_N samples: gathers inp[n, tgt[n]], multiplies by
# the class weight (1.0 when wgt_ptr is None), negates, and stores the
# per-sample loss. Samples whose target equals ignore_index contribute 0
# (masked loads fall back to `other=0`).
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_forward_kernel(
    inp_ptr,  # (N, C) log-probabilities
    tgt_ptr,  # (N,) class indices
    wgt_ptr,  # optional (C,) class weights; None means all-ones
    out_ptr,  # (N,) per-sample losses
    ignore_wgt_tgt_ptr,  # (N,) effective weights, written only when reduction == 1
    ignore_index,  # target value excluded from the loss
    N,
    C,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_N: tl.constexpr = 128,
):
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offsets_n < N
    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # True for in-bounds samples whose target is not ignore_index.
    # NOTE(review): `not`/`and` on tensors relies on this backend's Triton
    # fork rewriting Python boolean ops elementwise — confirm.
    ignore_mask = not (tgt == ignore_index) and mask_n
    if wgt_ptr is None:
        # No weight tensor: weight is 1 for kept samples, 0 for ignored ones.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
    # Gather the log-probability of each sample's target class.
    inp_tgt_ptrs = inp_ptr + offsets_n * C + tgt
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    # l_n = -w_{y_n} * x_{n, y_n}
    out = inp_tgt * wgt_tgt * -1
    tl.store(out_ptr + offsets_n, out, mask=mask_n)
    if reduction == 1:
        # 'mean' needs the summed effective weight as its denominator;
        # the host sums this buffer after launch.
        tl.store(ignore_wgt_tgt_ptr + offsets_n, wgt_tgt, mask=mask_n)
# NLL loss backward for a (N, C) input: scatters
# -out_grad * w_{y_n} / total_w into inp_grad[n, tgt[n]]; every other entry
# keeps the zero the host pre-filled.
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_backward_kernel(
    out_grad_ptr,  # (N,) upstream grads when reduction == 0, scalar otherwise
    tgt_ptr,  # (N,) class indices
    wgt_ptr,  # optional (C,) class weights; None means all-ones
    inp_grad_ptr,  # (N, C) output buffer, pre-zeroed by the host
    ignore_index,  # target value excluded from the loss
    total_weight,  # scalar tensor; read only when reduction == 1
    N,
    C,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_N: tl.constexpr = 128,
):
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offsets_n < N
    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    # In-bounds and not the ignored class.
    # NOTE(review): tensor `not`/`and` relies on this backend's Triton fork
    # rewriting Python boolean ops elementwise — confirm.
    ignore_mask = not (tgt == ignore_index) and mask_n
    if wgt_ptr is None:
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
    if reduction == 0:
        # 'none': one upstream grad per sample.
        out_grad_ptrs = out_grad_ptr + offsets_n
        out_grad = tl.load(out_grad_ptrs, mask=mask_n, other=0).to(tl.float32)
    else:
        # 'mean'/'sum': a single scalar upstream grad.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)
    if reduction == 1:
        # 'mean' divides by the summed effective weight from the forward pass.
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1
    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offsets_n * C + tgt
    tl.store(inp_grad_ptrs, inp_grad, mask=mask_n)
# Spatial NLL loss forward: the input is addressed as (N, C, D) with the
# class dim in the middle. Each program covers BLOCK_ND of the N*D
# (sample, position) pairs; per-pair loss is -w_{y} * inp[n, tgt, d].
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_forward_kernel(
    inp_ptr,  # input addressed as (N, C, D) log-probabilities
    tgt_ptr,  # targets addressed as (N, D) class indices
    wgt_ptr,  # optional (C,) class weights; None means all-ones
    out_ptr,  # (N, D) per-position losses
    ignore_wgt_tgt_ptr,  # (N, D) effective weights, written only when reduction == 1
    ignore_index,  # target value excluded from the loss
    N,
    C,
    D,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_ND: tl.constexpr = 128,
):
    pid_nd = tl.program_id(0)
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    # Decompose the flat index into (sample, spatial position).
    offset_d = offset_nd % D
    offset_n = offset_nd // D
    mask_block = offset_nd < N * D
    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # In-bounds and not the ignored class.
    # NOTE(review): tensor `not`/`and` relies on this backend's Triton fork
    # rewriting Python boolean ops elementwise — confirm.
    ignore_mask = not (tgt == ignore_index) and mask_block
    if wgt_ptr is None:
        # No weight tensor: weight is 1 for kept positions, 0 for ignored ones.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
    # Gather inp[n, tgt, d] with the class dim strided by D.
    inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    out = inp_tgt * wgt_tgt * -1
    out_ptrs = out_ptr + offset_n * D + offset_d
    tl.store(out_ptrs, out, mask=mask_block)
    if reduction == 1:
        # 'mean' needs the summed effective weight; host sums this buffer.
        ignore_wgt_tgt_ptrs = ignore_wgt_tgt_ptr + offset_n * D + offset_d
        tl.store(ignore_wgt_tgt_ptrs, wgt_tgt, mask=mask_block)
# Spatial NLL loss backward: scatters -out_grad * w_{y} / total_w into
# inp_grad[n, tgt, d]; all other entries keep the host's pre-filled zeros.
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_backward_kernel(
    out_grad_ptr,  # (N, D) upstream grads when reduction == 0, scalar otherwise
    tgt_ptr,  # targets addressed as (N, D) class indices
    wgt_ptr,  # optional (C,) class weights; None means all-ones
    inp_grad_ptr,  # output addressed as (N, C, D), pre-zeroed by the host
    ignore_index,  # target value excluded from the loss
    total_weight,  # scalar tensor; read only when reduction == 1
    N,
    C,
    D,
    reduction: tl.constexpr = 1,  # 0-none, 1-mean, 2-sum
    BLOCK_ND: tl.constexpr = 128,
):
    pid_nd = tl.program_id(0)
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    # Decompose the flat index into (sample, spatial position).
    offset_d = offset_nd % D
    offset_n = offset_nd // D
    mask_block = offset_nd < N * D
    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    # In-bounds and not the ignored class.
    # NOTE(review): tensor `not`/`and` relies on this backend's Triton fork
    # rewriting Python boolean ops elementwise — confirm.
    ignore_mask = not (tgt == ignore_index) and mask_block
    if wgt_ptr is None:
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
    if reduction == 0:
        # 'none': one upstream grad per (sample, position).
        out_grad_ptrs = out_grad_ptr + offset_n * D + offset_d
        out_grad = tl.load(out_grad_ptrs, mask=mask_block, other=0).to(tl.float32)
    else:
        # 'mean'/'sum': a single scalar upstream grad.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)
    if reduction == 1:
        # 'mean' divides by the summed effective weight from the forward pass.
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1
    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offset_n * C * D + tgt * D + offset_d
    tl.store(inp_grad_ptrs, inp_grad, mask=mask_block)
182# Negative Log Likelihood Loss (NLLLoss)
183#
184# This loss function is used for training classification problems with C classes.
185#
186# Parameters:
187# - input (Tensor):
188# - Expected to contain log-probabilities for each class.
189# - Shape can be either:
190# - (minibatch, C) for standard classification tasks.
191# - (minibatch, C, d1, d2, ..., dK) for K-dimensional inputs (e.g., per-pixel loss for 2D images).
192#
193# - target (Tensor):
194# - Should contain class indices in the range [0, C-1].
195# - If ignore_index is specified, this index can be outside the class range
196# and will be ignored in the loss computation.
197#
198# - weight (1D Tensor, optional):
199# - Assigns weight to each class, useful for unbalanced datasets.
200#
201# Reduction modes:
202# - 'none': returns per-sample loss (shape: (N,)).
203# - 'mean' (default): computes the mean of the weighted losses.
204# - 'sum': computes the sum of the weighted losses.
205#
206# Mathematical description:
# - Unreduced loss:
# l_n = -w_{y_n} * x_{n, y_n}, where w_c = weight[c] * 1{c != ignore_index}.
# - Reduced loss (depending on the specified reduction mode):
# - mean: ℓ(x, y) = Σ(l_n) / Σ(w_{y_n})  (weighted mean; l_n already carries w_{y_n})
# - sum: ℓ(x, y) = Σ(l_n)
# 1d & 2d tensor
def nll_loss_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """NLL loss forward for 1D (C,) or 2D (N, C) log-probability inputs.

    Args:
        self: log-probabilities, shape (C,) or (N, C).
        target: class indices in [0, C-1]; numel must equal N.
        weight: optional (C,) per-class weights.
        reduction: 0-none, 1-mean, 2-sum.
        ignore_index: target value excluded from the loss.

    Returns:
        (output, total_weight): `output` is the per-sample loss tensor for
        reduction=0, otherwise a scalar; `total_weight` is the summed
        effective weight (meaningful only for reduction=1, zero otherwise).
    """
    logger.debug("GEMS NLL Loss FWD")
    assert self.ndim <= 2, "Invalid input ndim"
    shape = list(target.shape)
    N = 1 if self.ndim == 1 else self.shape[0]
    C = self.shape[-1]
    assert target.numel() == N, "Invalid target size"
    self = self.contiguous()
    target = target.contiguous()
    weight = None if weight is None else weight.contiguous()
    out = torch.empty(shape, dtype=self.dtype, device=self.device)
    ignore_weight_tgt = None
    if reduction == 1:
        # Per-sample effective weights; summed below as the mean's denominator.
        ignore_weight_tgt = torch.zeros(
            target.shape, dtype=self.dtype, device=self.device
        )
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_N"]),)
    with torch_device_fn.device(self.device):
        nll_loss_forward_kernel[grid](
            self,
            target,
            weight,
            out,
            ignore_weight_tgt,
            ignore_index,
            N,
            C,
            reduction,
            is_use_mask_zero=True,
        )
    # reduction: 0-none, 1-mean, 2-sum
    if reduction == 0:
        output = out
        # total_weight is unused for 'none'; return a deterministic zero
        # scalar rather than uninitialized memory (torch.empty).
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)
    elif reduction == 1:
        total_out = torch.sum(out)
        total_weight = torch.sum(ignore_weight_tgt).to(self.dtype)
        # NaN when every target is ignored (total_weight == 0), matching torch.
        output = (total_out / total_weight).to(self.dtype)
    else:
        total_out = torch.sum(out)
        output = total_out.to(self.dtype)
        # Unused for 'sum'; zero for determinism.
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)
    return output, total_weight
def nll_loss_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Gradient of the 1D/2D NLL loss with respect to the input.

    Launches nll_loss_backward_kernel to scatter the upstream gradient into
    each sample's target class; positions off the target stay zero.
    Returns a tensor shaped like `self`.
    """
    logger.debug("GEMS NLL Loss BWD")
    if self.ndim == 1:
        N = 1
    else:
        N = self.shape[0]
    C = self.shape[-1]

    grad_output = grad_output.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # Pre-zeroed: the kernel writes only the target-class entries.
    grad_input = torch.zeros_like(self).contiguous()

    def grid(meta):
        return (triton.cdiv(N, meta["BLOCK_N"]),)

    with torch_device_fn.device(self.device):
        nll_loss_backward_kernel[grid](
            grad_output,
            target,
            weight,
            grad_input,
            ignore_index,
            total_weight,
            N,
            C,
            reduction,
        )
    return grad_input
# 3d+ tensor
def nll_loss2d_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """NLL loss forward for 4D log-probability inputs.

    Args:
        self: 4D log-probabilities; unpacked as (N, C, _, D).
        target: class indices; must have shape (N, 1, D).
        weight: optional (C,) per-class weights.
        reduction: 0-none, 1-mean, 2-sum.
        ignore_index: target value excluded from the loss.

    Returns:
        (output, total_weight): `output` is the per-position loss tensor for
        reduction=0, otherwise a scalar; `total_weight` is the summed
        effective weight (meaningful only for reduction=1, zero otherwise).
    """
    logger.debug("GEMS NLL Loss2d FWD")
    assert self.ndim == 4, "Invalid input ndim"
    shape = list(target.shape)
    # NOTE(review): unpacking (N, C, _, D) while requiring target shape
    # (N, 1, D) means only a singleton third target dim is accepted here —
    # confirm this restriction against the backend's callers.
    N, C, _, D = self.shape
    assert shape == [N, 1, D], "Invalid target size"
    self = self.contiguous()
    target = target.contiguous()
    weight = None if weight is None else weight.contiguous()
    out = torch.empty(shape, dtype=self.dtype, device=self.device)
    ignore_weight_tgt = None
    if reduction == 1:
        # Per-position effective weights; summed below as the mean's denominator.
        ignore_weight_tgt = torch.zeros(
            target.shape, dtype=self.dtype, device=self.device
        )
    grid = lambda meta: (triton.cdiv(N * D, meta["BLOCK_ND"]),)
    with torch_device_fn.device(self.device):
        nll_loss2d_forward_kernel[grid](
            self,
            target,
            weight,
            out,
            ignore_weight_tgt,
            ignore_index,
            N,
            C,
            D,
            reduction,
            is_use_mask_zero=True,
        )
    # reduction: 0-none, 1-mean, 2-sum
    if reduction == 0:
        output = out
        # total_weight is unused for 'none'; return a deterministic zero
        # scalar rather than uninitialized memory (torch.empty).
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)
    elif reduction == 1:
        total_out = torch.sum(out)
        total_weight = torch.sum(ignore_weight_tgt).to(self.dtype)
        # NaN when every target is ignored (total_weight == 0), matching torch.
        output = (total_out / total_weight).to(self.dtype)
    else:
        total_out = torch.sum(out)
        output = total_out.to(self.dtype)
        # Unused for 'sum'; zero for determinism.
        total_weight = torch.zeros([], dtype=self.dtype, device=self.device)
    return output, total_weight
def nll_loss2d_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Gradient of the spatial NLL loss with respect to the 4D input.

    Launches nll_loss2d_backward_kernel over the N*D (sample, position)
    pairs; non-target entries of the result stay zero. Returns a tensor
    shaped like `self`.
    """
    logger.debug("GEMS NLL Loss2d BWD")
    N, C, _, D = self.shape

    grad_output = grad_output.contiguous()
    target = target.contiguous()
    if weight is not None:
        weight = weight.contiguous()

    # Pre-zeroed: the kernel scatters only into target-class positions.
    grad_input = torch.zeros_like(self).contiguous()

    def grid(meta):
        return (triton.cdiv(N * D, meta["BLOCK_ND"]),)

    with torch_device_fn.device(self.device):
        nll_loss2d_backward_kernel[grid](
            grad_output,
            target,
            weight,
            grad_input,
            ignore_index,
            total_weight,
            N,
            C,
            D,
            reduction,
        )
    return grad_input