Coverage for src/flag_gems/ops/nllloss.py: 39%
196 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
10logger = logging.getLogger(__name__)
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_forward_kernel(
    inp_ptr,  # input log-probabilities, laid out as (N, C), contiguous
    tgt_ptr,  # N class-index targets
    wgt_ptr,  # C per-class weights, or None for unit weights
    out_ptr,  # reduction==0: N per-sample losses; reduction==1: 4-slot scratch
    #           [loss_sum, weight_sum, block_counter, mean]; reduction==2: scalar sum
    ignore_index,  # target value whose samples contribute no loss
    N,  # number of samples
    C,  # number of classes
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2 = sum
    BLOCK_N: tl.constexpr = 128,  # samples handled per program
):
    # Forward NLL loss over one BLOCK_N-wide slice of the N samples:
    # l_n = -weight[tgt_n] * inp[n, tgt_n], with ignored samples masked out.
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_n = offsets_n < N

    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    # NOTE(review): this device assert also fires for targets equal to
    # ignore_index when that value lies outside [0, C) (e.g. the default
    # -100); device asserts are typically compiled out, but confirm.
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # True only for in-bounds lanes whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_n

    if wgt_ptr is None:
        # No weight tensor: weight 1 for kept samples, 0 for ignored/padding.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Gather inp[n, tgt_n]; masked (ignored) lanes read as 0.
    inp_tgt_ptrs = inp_ptr + offsets_n * C + tgt
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    out = inp_tgt * wgt_tgt * -1

    # none
    if reduction == 0:
        tl.store(out_ptr + offsets_n, out, mask=mask_n)
    # mean
    elif reduction == 1:
        total_out = tl.sum(out)
        total_wgt = tl.sum(wgt_tgt)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")  # output
        tl.atomic_add(out_ptr + 1, total_wgt, sem="relaxed")  # weight
        # "release" orders the partial sums before the counter bump, so the
        # program that observes the final count sees every partial sum.
        tl.atomic_add(out_ptr + 2, 1, sem="release")  # counter
        counter = tl.load(out_ptr + 2)
        if counter == tl.num_programs(0):
            # Last finisher writes the mean; a duplicate write by a late
            # reader would store the same value, so it is benign.
            total_out = tl.load(out_ptr)
            total_wgt = tl.load(out_ptr + 1)
            tl.store(out_ptr + 3, total_out / total_wgt)
    # sum
    else:
        total_out = tl.sum(out)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss_backward_kernel(
    out_grad_ptr,  # grad of output: N values if reduction==0, else a scalar
    tgt_ptr,  # N class-index targets
    wgt_ptr,  # C per-class weights, or None for unit weights
    inp_grad_ptr,  # (N, C) gradient buffer, pre-zeroed by the caller
    ignore_index,  # target value whose samples receive no gradient
    total_weight,  # pointer to the summed target weights (mean reduction only)
    N,  # number of samples
    C,  # number of classes
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2 = sum
    BLOCK_N: tl.constexpr = 128,  # samples handled per program
):
    # Backward NLL loss: d(inp)[n, tgt_n] = -grad_n * weight[tgt_n] / total_w.
    # Only the target column of each row is written; everything else keeps
    # the caller's zero initialization.
    pid_n = tl.program_id(0)
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_n = offsets_n < N

    tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
    # True only for in-bounds lanes whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_n

    if wgt_ptr is None:
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    if reduction == 0:
        # Per-sample upstream gradient.
        out_grad_ptrs = out_grad_ptr + offsets_n
        out_grad = tl.load(out_grad_ptrs, mask=mask_n, other=0).to(tl.float32)
    else:
        # Scalar upstream gradient shared by all samples.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)
    if reduction == 1:
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1

    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offsets_n * C + tgt
    # Bug fix: mask the store with ignore_mask, not mask_n. When a target
    # equals ignore_index and that value lies outside [0, C) (e.g. the
    # default -100), the computed pointer is out of bounds; masking those
    # lanes skips the write. The skipped lanes would only have stored 0,
    # which the caller's zeros_like initialization already provides.
    tl.store(inp_grad_ptrs, inp_grad, mask=ignore_mask)
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_forward_kernel(
    inp_ptr,  # input log-probabilities, laid out as (N, C, D), contiguous
    tgt_ptr,  # targets laid out as (N, D)
    wgt_ptr,  # C per-class weights, or None for unit weights
    out_ptr,  # reduction==0: (N, D) per-element losses; reduction==1: 4-slot
    #           scratch [loss_sum, weight_sum, block_counter, mean];
    #           reduction==2: scalar sum
    ignore_index,  # target value whose elements contribute no loss
    N,  # batch size
    C,  # number of classes
    D,  # flattened spatial extent (D1 * D2)
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2 = sum
    BLOCK_ND: tl.constexpr = 128,  # (n, d) elements handled per program
):
    # Forward 2D NLL loss over one BLOCK_ND-wide slice of the N*D elements:
    # l_{n,d} = -weight[tgt_{n,d}] * inp[n, tgt_{n,d}, d].
    pid_nd = tl.program_id(0)
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    # Decompose the flat index into (batch, spatial) coordinates.
    offset_d = offset_nd % D
    offset_n = offset_nd // D

    mask_block = offset_nd < N * D

    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    # NOTE(review): this device assert also fires for targets equal to
    # ignore_index when that value lies outside [0, C) (e.g. the default
    # -100); device asserts are typically compiled out, but confirm.
    assert tgt >= 0 and tgt < C, "Invalid target value"
    # True only for in-bounds lanes whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_block

    if wgt_ptr is None:
        # No weight tensor: weight 1 for kept elements, 0 for ignored/padding.
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    # Gather inp[n, tgt, d]; masked (ignored) lanes read as 0.
    inp_tgt_ptrs = inp_ptr + offset_n * C * D + tgt * D + offset_d
    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
    out = inp_tgt * wgt_tgt * -1

    # none
    if reduction == 0:
        out_ptrs = out_ptr + offset_n * D + offset_d
        tl.store(out_ptrs, out, mask=mask_block)
    # mean
    elif reduction == 1:
        total_out = tl.sum(out)
        total_wgt = tl.sum(wgt_tgt)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")  # output
        tl.atomic_add(out_ptr + 1, total_wgt, sem="relaxed")  # weight
        # "release" orders the partial sums before the counter bump, so the
        # program that observes the final count sees every partial sum.
        tl.atomic_add(out_ptr + 2, 1, sem="release")  # counter
        counter = tl.load(out_ptr + 2)
        if counter == tl.num_programs(0):
            # Last finisher writes the mean into slot 3.
            total_out = tl.load(out_ptr)
            total_wgt = tl.load(out_ptr + 1)
            tl.store(out_ptr + 3, total_out / total_wgt)
    # sum
    else:
        total_out = tl.sum(out)
        tl.atomic_add(out_ptr, total_out, sem="relaxed")
@libentry()
@triton.jit(do_not_specialize=["ignore_index"])
def nll_loss2d_backward_kernel(
    out_grad_ptr,  # grad of output: (N, D) values if reduction==0, else scalar
    tgt_ptr,  # targets laid out as (N, D)
    wgt_ptr,  # C per-class weights, or None for unit weights
    inp_grad_ptr,  # (N, C, D) gradient buffer, pre-zeroed by the caller
    ignore_index,  # target value whose elements receive no gradient
    total_weight,  # pointer to the summed target weights (mean reduction only)
    N,  # batch size
    C,  # number of classes
    D,  # flattened spatial extent (D1 * D2)
    reduction: tl.constexpr = 1,  # 0 = none, 1 = mean, 2 = sum
    BLOCK_ND: tl.constexpr = 128,  # (n, d) elements handled per program
):
    # Backward 2D NLL loss:
    # d(inp)[n, tgt_{n,d}, d] = -grad_{n,d} * weight[tgt_{n,d}] / total_w.
    # Only the target channel of each (n, d) element is written; everything
    # else keeps the caller's zero initialization.
    pid_nd = tl.program_id(0)
    offset_nd = pid_nd * BLOCK_ND + tl.arange(0, BLOCK_ND)
    # Decompose the flat index into (batch, spatial) coordinates.
    offset_d = offset_nd % D
    offset_n = offset_nd // D

    mask_block = offset_nd < N * D

    tgt_ptrs = tgt_ptr + offset_n * D + offset_d
    tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
    # True only for in-bounds lanes whose target is not ignore_index.
    ignore_mask = not (tgt == ignore_index) and mask_block

    if wgt_ptr is None:
        wgt_tgt = ignore_mask.to(tl.float32)
    else:
        wgt_tgt = tl.load(wgt_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)

    if reduction == 0:
        # Per-element upstream gradient.
        out_grad_ptrs = out_grad_ptr + offset_n * D + offset_d
        out_grad = tl.load(out_grad_ptrs, mask=mask_block, other=0).to(tl.float32)
    else:
        # Scalar upstream gradient shared by all elements.
        out_grad = tl.load(out_grad_ptr).to(tl.float32)

    if reduction == 1:
        total_w = tl.load(total_weight).to(tl.float32)
    else:
        total_w = 1
    inp_grad = tl.where(ignore_mask, -1 * out_grad * wgt_tgt / total_w, 0)
    inp_grad_ptrs = inp_grad_ptr + offset_n * C * D + tgt * D + offset_d
    # Bug fix: mask the store with ignore_mask, not mask_block. When a target
    # equals ignore_index and that value lies outside [0, C) (e.g. the
    # default -100), the computed pointer is out of bounds; masking those
    # lanes skips the write. The skipped lanes would only have stored 0,
    # which the caller's zeros_like initialization already provides.
    tl.store(inp_grad_ptrs, inp_grad, mask=ignore_mask)
210# Negative Log Likelihood Loss (NLLLoss)
211#
212# This loss function is used for training classification problems with C classes.
213#
214# Parameters:
215# - input (Tensor):
216# - Expected to contain log-probabilities for each class.
217# - Shape can be either:
218# - (minibatch, C) for standard classification tasks.
219# - (minibatch, C, d1, d2, ..., dK) for K-dimensional inputs (e.g., per-pixel loss for 2D images).
220#
221# - target (Tensor):
222# - Should contain class indices in the range [0, C-1].
223# - If ignore_index is specified, this index can be outside the class range
224# and will be ignored in the loss computation.
225#
226# - weight (1D Tensor, optional):
227# - Assigns weight to each class, useful for unbalanced datasets.
228#
229# Reduction modes:
230# - 'none': returns per-sample loss (shape: (N,)).
231# - 'mean' (default): computes the mean of the weighted losses.
232# - 'sum': computes the sum of the weighted losses.
233#
# Mathematical description:
#   - Unreduced loss:
#       l_n = -w_{y_n} * x_{n, y_n}, where w_c = weight[c] * 1{c != ignore_index}.
#   - Reduced loss (depending on the specified reduction mode):
#       - mean: ℓ(x, y) = Σ_n l_n / Σ_n w_{y_n}   (sum of losses over sum of weights)
#       - sum:  ℓ(x, y) = Σ_n l_n
242# 1d & 2d tensor
# 1d & 2d tensor
def nll_loss_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """Compute NLL loss forward for a (C,) or (N, C) log-probability input.

    Args:
        self: input log-probabilities, 1D (C,) or 2D (N, C).
        target: class indices; must contain exactly N elements.
        weight: optional 1D per-class weights of length C.
        reduction: 0 = none, 1 = mean, 2 = sum.
        ignore_index: target value excluded from the loss.

    Returns:
        (output, total_weight) — output is per-sample (reduction 0) or a
        scalar; total_weight is the summed target weights for the mean
        reduction and an uninitialized scalar otherwise.
    """
    logger.debug("GEMS NLL Loss FWD")
    assert self.ndim <= 2, "Invalid input ndim"

    tgt_shape = list(target.shape)
    num_rows = self.shape[0] if self.ndim == 2 else 1
    num_classes = self.shape[-1]
    assert target.numel() == num_rows, "Invalid target size"

    inp = self.contiguous()
    tgt = target.contiguous()
    wgt = weight.contiguous() if weight is not None else None

    # reduction: 0 = none, 1 = mean, 2 = sum
    if reduction == 0:
        out = torch.empty(tgt_shape, dtype=inp.dtype, device=inp.device)
    elif reduction == 1:
        # Float32 scratch: [loss_sum, weight_sum, block_counter, mean].
        out = torch.zeros([4], dtype=torch.float32, device=inp.device)
    else:
        out = torch.zeros([], dtype=torch.float32, device=inp.device)

    grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_N"]),)
    with torch_device_fn.device(inp.device):
        nll_loss_forward_kernel[grid](
            inp,
            tgt,
            wgt,
            out,
            ignore_index,
            num_rows,
            num_classes,
            reduction,
        )

    if reduction == 0:
        output = out
        total_weight = torch.empty([], dtype=inp.dtype, device=inp.device)
    elif reduction == 1:
        out = out.to(inp.dtype)
        output = out[3]
        total_weight = out[1]
    else:
        output = out.to(inp.dtype)
        total_weight = torch.empty([], dtype=inp.dtype, device=inp.device)

    return output, total_weight
def nll_loss_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Compute NLL loss backward for a (C,) or (N, C) input.

    Args:
        grad_output: upstream gradient (per-sample for reduction 0, scalar
            otherwise).
        self: the forward input tensor (used for shape, dtype, device).
        target: class indices with N elements.
        weight: optional 1D per-class weights.
        reduction: 0 = none, 1 = mean, 2 = sum.
        ignore_index: target value excluded from the gradient.
        total_weight: summed target weights from the forward pass (mean only).

    Returns:
        grad_input with the same shape as ``self``.
    """
    logger.debug("GEMS NLL Loss BWD")
    num_rows = self.shape[0] if self.ndim == 2 else 1
    num_classes = self.shape[-1]

    grad_out = grad_output.contiguous()
    tgt = target.contiguous()
    wgt = weight.contiguous() if weight is not None else None

    # Zero-filled so untouched (non-target / ignored) entries stay 0.
    grad_input = torch.zeros_like(self).contiguous()

    grid = lambda meta: (triton.cdiv(num_rows, meta["BLOCK_N"]),)
    with torch_device_fn.device(self.device):
        nll_loss_backward_kernel[grid](
            grad_out,
            tgt,
            wgt,
            grad_input,
            ignore_index,
            total_weight,
            num_rows,
            num_classes,
            reduction,
        )

    return grad_input
333# 3d+ tensor
# 3d+ tensor
def nll_loss2d_forward(self, target, weight=None, reduction=1, ignore_index=-100):
    """Compute 2D NLL loss forward for an (N, C, D1, D2) log-probability input.

    Args:
        self: input log-probabilities, shape (N, C, D1, D2).
        target: class indices, shape (N, D1, D2).
        weight: optional 1D per-class weights of length C.
        reduction: 0 = none, 1 = mean, 2 = sum.
        ignore_index: target value excluded from the loss.

    Returns:
        (output, total_weight) — output is per-element (reduction 0) or a
        scalar; total_weight is the summed target weights for the mean
        reduction and an uninitialized scalar otherwise.
    """
    logger.debug("GEMS NLL Loss2d FWD")
    assert self.ndim == 4, "Invalid input ndim"

    tgt_shape = list(target.shape)
    batch, num_classes, dim1, dim2 = self.shape
    assert tgt_shape == [batch, dim1, dim2], "Invalid target size"
    spatial = dim1 * dim2

    inp = self.contiguous()
    tgt = target.contiguous()
    wgt = weight.contiguous() if weight is not None else None

    if reduction == 0:
        out = torch.empty(tgt_shape, dtype=inp.dtype, device=inp.device)
    elif reduction == 1:
        # Float32 scratch: [loss_sum, weight_sum, block_counter, mean].
        out = torch.zeros([4], dtype=torch.float32, device=inp.device)
    else:
        out = torch.zeros([], dtype=torch.float32, device=inp.device)

    grid = lambda meta: (triton.cdiv(batch * spatial, meta["BLOCK_ND"]),)
    with torch_device_fn.device(inp.device):
        nll_loss2d_forward_kernel[grid](
            inp, tgt, wgt, out, ignore_index, batch, num_classes, spatial, reduction
        )

    # reduction: 0 = none, 1 = mean, 2 = sum
    if reduction == 0:
        output = out
        total_weight = torch.empty([], dtype=inp.dtype, device=inp.device)
    elif reduction == 1:
        out = out.to(inp.dtype)
        output = out[3]
        total_weight = out[1]
    else:
        output = out.to(inp.dtype)
        total_weight = torch.empty([], dtype=inp.dtype, device=inp.device)

    return output, total_weight
def nll_loss2d_backward(
    grad_output,
    self,
    target,
    weight=None,
    reduction=1,
    ignore_index=-100,
    total_weight=None,
):
    """Compute 2D NLL loss backward for an (N, C, D1, D2) input.

    Args:
        grad_output: upstream gradient (per-element for reduction 0, scalar
            otherwise).
        self: the forward input tensor (used for shape, dtype, device).
        target: class indices, shape (N, D1, D2).
        weight: optional 1D per-class weights.
        reduction: 0 = none, 1 = mean, 2 = sum.
        ignore_index: target value excluded from the gradient.
        total_weight: summed target weights from the forward pass (mean only).

    Returns:
        grad_input with the same shape as ``self``.
    """
    logger.debug("GEMS NLL Loss2d BWD")
    batch, num_classes, dim1, dim2 = self.shape
    spatial = dim1 * dim2

    grad_out = grad_output.contiguous()
    tgt = target.contiguous()
    wgt = weight.contiguous() if weight is not None else None

    # Zero-filled so untouched (non-target / ignored) entries stay 0.
    grad_input = torch.zeros_like(self).contiguous()

    grid = lambda meta: (triton.cdiv(batch * spatial, meta["BLOCK_ND"]),)
    with torch_device_fn.device(self.device):
        nll_loss2d_backward_kernel[grid](
            grad_out,
            tgt,
            wgt,
            grad_input,
            ignore_index,
            total_weight,
            batch,
            num_classes,
            spatial,
            reduction,
        )

    return grad_input