Coverage for src/flag_gems/ops/nll_loss_nd.py: 12%
138 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.ops.nllloss import nll_loss_backward as nll_loss_2d_backward
8from flag_gems.ops.nllloss import nll_loss_forward as nll_loss_2d
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def nll_loss_nd_forward_kernel(
    input_ptr,
    target_ptr,
    weight_ptr,
    out_ptr,
    scratch_ptr,
    C,
    S,
    stride_in_n,
    stride_in_c,
    stride_in_s,
    stride_tgt_n,
    stride_tgt_s,
    ignore_index,
    HAS_WEIGHT: tl.constexpr,
    REDUCTION: tl.constexpr,
    BLOCK_S: tl.constexpr = 1024,
):
    """NLL-loss forward over input viewed as (N, C, S).

    Grid: axis 0 tiles the spatial dimension S in BLOCK_S chunks; axis 1 is
    the batch dimension N.

    REDUCTION: 0 = 'none' (per-element losses stored to out_ptr, densely
    laid out as (N, S)); 1 = 'mean'; 2 = 'sum'.  For mean/sum, per-block
    partial sums are accumulated with atomics into scratch_ptr and the last
    program to finish writes the final scalar to out_ptr.  scratch layout:
    [0] = loss sum, [1] = weight sum (mean only), [2] = finished-program
    counter.  NOTE(review): the mean/sum paths assume the host zero-filled
    scratch before launch — the wrapper below allocates it with torch.zeros.
    """
    pid_s = tl.program_id(0)
    pid_n = tl.program_id(1)

    s_offsets = pid_s * BLOCK_S + tl.arange(0, BLOCK_S)
    mask_s = s_offsets < S

    # Masked-off lanes read `ignore_index`, so they drop out of `valid` below.
    tgt_offsets = pid_n * stride_tgt_n + s_offsets * stride_tgt_s
    t = tl.load(target_ptr + tgt_offsets, mask=mask_s, other=ignore_index).to(tl.int32)

    # A lane contributes only when in-bounds, not ignored, and 0 <= t < C.
    valid = mask_s & (t != ignore_index) & (t >= 0) & (t < C)

    # Gather input[n, t[s], s] for each spatial position.
    in_offsets = pid_n * stride_in_n + t * stride_in_c + s_offsets * stride_in_s
    val = tl.load(input_ptr + in_offsets, mask=valid, other=0.0).to(tl.float32)

    if HAS_WEIGHT:
        w = tl.load(weight_ptr + t, mask=valid, other=0.0).to(tl.float32)
        loss_val = tl.where(valid, -val * w, 0.0)
    else:
        # Implicit unit weight on valid lanes; zero elsewhere so ignored
        # lanes do not inflate the mean denominator.
        w = tl.where(valid, 1.0, 0.0).to(tl.float32)
        loss_val = tl.where(valid, -val, 0.0)

    # none
    if REDUCTION == 0:
        # Output is written densely as (N, S), independent of input strides.
        out_offset = pid_n * S + s_offsets
        tl.store(
            out_ptr + out_offset, loss_val.to(out_ptr.dtype.element_ty), mask=mask_s
        )
    else:
        block_loss_sum = tl.sum(loss_val, axis=0)
        # mean
        if REDUCTION == 1:
            block_weight_sum = tl.sum(w, axis=0)

            tl.atomic_add(scratch_ptr, block_loss_sum, sem="relaxed")
            tl.atomic_add(scratch_ptr + 1, block_weight_sum, sem="relaxed")

            # Release ordering publishes the partial sums before the counter
            # bump; whichever program increments last finalizes the result.
            old_cnt = tl.atomic_add(scratch_ptr + 2, 1.0, sem="release")

            total_programs = tl.num_programs(0) * tl.num_programs(1)

            if old_cnt == total_programs - 1.0:
                total_loss = tl.load(scratch_ptr)
                total_weight = tl.load(scratch_ptr + 1)
                # Guard against an all-ignored batch (weight sum == 0).
                final_val = tl.where(
                    total_weight == 0.0, 0.0, total_loss / total_weight
                )
                tl.store(out_ptr, final_val.to(out_ptr.dtype.element_ty))
        # Sum
        else:
            tl.atomic_add(scratch_ptr, block_loss_sum, sem="relaxed")

            old_cnt = tl.atomic_add(scratch_ptr + 2, 1.0, sem="release")
            total_programs = tl.num_programs(0) * tl.num_programs(1)

            if old_cnt == total_programs - 1.0:
                total_loss = tl.load(scratch_ptr)
                tl.store(out_ptr, total_loss.to(out_ptr.dtype.element_ty))
@libentry()
@triton.jit
def nll_loss_nd_backward_kernel(
    grad_out_ptr,
    target_ptr,
    weight_ptr,
    grad_in_ptr,
    total_weight_ptr,
    C,
    S,
    stride_in_n,
    stride_in_c,
    stride_in_s,
    stride_tgt_n,
    stride_tgt_s,
    stride_go_n,
    stride_go_s,
    ignore_index,
    HAS_WEIGHT: tl.constexpr,
    REDUCTION: tl.constexpr,
    BLOCK_S: tl.constexpr = 1024,
):
    """NLL-loss backward: scatter -w * grad into grad_input at the target
    class of each (n, s) position.

    Grid matches the forward kernel: axis 0 tiles S, axis 1 is N.
    REDUCTION: 0 = 'none' (per-element upstream grad, strided via
    stride_go_*); 1 = 'mean' (scalar grad scaled by 1/total_weight);
    2 = 'sum' (scalar grad).  Lanes that are out of bounds, ignored, or
    have an out-of-range target store nothing, so those grad_input entries
    keep whatever the caller initialized them to (the wrapper pre-zeros).
    """
    pid_s = tl.program_id(0)
    pid_n = tl.program_id(1)

    s_offsets = pid_s * BLOCK_S + tl.arange(0, BLOCK_S)
    mask_s = s_offsets < S

    # Masked-off lanes read `ignore_index`, dropping them from `valid`.
    tgt_offsets = pid_n * stride_tgt_n + s_offsets * stride_tgt_s
    t = tl.load(target_ptr + tgt_offsets, mask=mask_s, other=ignore_index).to(tl.int32)

    valid = mask_s & (t != ignore_index) & (t >= 0) & (t < C)

    if REDUCTION == 0:  # none
        out_grad_offsets = pid_n * stride_go_n + s_offsets * stride_go_s
        out_grad = tl.load(grad_out_ptr + out_grad_offsets, mask=valid, other=0.0).to(
            tl.float32
        )
    else:  # mean or sum
        # Scalar upstream gradient, broadcast across all lanes.
        out_grad = tl.load(grad_out_ptr).to(tl.float32)

    if HAS_WEIGHT:
        w = tl.load(weight_ptr + t, mask=valid, other=0.0).to(tl.float32)
    else:
        w = tl.where(valid, 1.0, 0.0).to(tl.float32)

    if REDUCTION == 1:  # mean
        total_weight = tl.load(total_weight_ptr).to(tl.float32)
        # Zero-weight guard: if nothing contributed to the forward mean,
        # the gradient is defined as zero rather than inf/nan.
        grad_in_val = tl.where(total_weight != 0.0, -w * out_grad / total_weight, 0.0)
    else:  # sum or none
        grad_in_val = -w * out_grad

    # Scatter into grad_input[n, t[s], s]; invalid lanes write nothing.
    in_offsets = pid_n * stride_in_n + t * stride_in_c + s_offsets * stride_in_s
    tl.store(
        grad_in_ptr + in_offsets,
        grad_in_val.to(grad_in_ptr.dtype.element_ty),
        mask=valid,
    )
def nll_loss_nd_forward(
    input: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor = None,
    reduction: int = 1,
    ignore_index: int = -100,
):
    """Host wrapper for the N-D NLL-loss forward kernel.

    Views input as (N, C, S) and target as (N, S), then launches
    nll_loss_nd_forward_kernel.  Returns ``(out, total_weight)``:
    for reduction 0 ('none') out has target's shape; for 1 ('mean') and
    2 ('sum') out is a scalar.  Inputs with fewer than 3 dims are delegated
    to the dedicated 2D implementation.
    """
    logger.debug("GEMS NLL LOSS ND FWD")

    if input.dim() < 3:
        return nll_loss_2d(
            input, target, weight=weight, reduction=reduction, ignore_index=ignore_index
        )

    n_batch = input.shape[0]
    n_class = input.shape[1]
    spatial = input.numel() // (n_batch * n_class)

    inp3d = input.reshape(n_batch, n_class, spatial)

    if target.numel() != n_batch * spatial:
        raise ValueError(
            f"Target size {target.shape} doesn't match input size (N={n_batch}, S={spatial})"
        )
    tgt2d = target.reshape(n_batch, spatial)

    if weight is None:
        has_weight = False
        # Dummy pointer; the kernel never reads it when HAS_WEIGHT is False.
        w_arg = input
    else:
        has_weight = True
        if weight.numel() != n_class:
            raise ValueError(f"Weight shape {weight.shape} must be ({n_class},)")
        w_arg = weight.contiguous()

    if reduction not in (0, 1, 2):
        raise ValueError("reduction must be 0 ('none'), 1 ('mean'), or 2 ('sum')")

    grid = lambda meta: (triton.cdiv(spatial, meta["BLOCK_S"]), n_batch)

    with torch_device_fn.device(input.device):
        if reduction == 0:
            per_elem = torch.empty(
                (n_batch, spatial), device=input.device, dtype=input.dtype
            )
            # Unused by the 'none' path, but the kernel signature needs it.
            scratch = torch.empty(1, device=input.device)

            nll_loss_nd_forward_kernel[grid](
                inp3d,
                tgt2d,
                w_arg,
                per_elem,
                scratch,
                n_class,
                spatial,
                *inp3d.stride(),
                *tgt2d.stride(),
                ignore_index,
                HAS_WEIGHT=has_weight,
                REDUCTION=reduction,
            )

            res = (
                per_elem.view_as(target)
                if target.dim() == input.dim() - 1
                else per_elem.reshape(target.shape)
            )
            total_weight = torch.empty([], device=input.device, dtype=input.dtype)
            return res, total_weight

        # mean / sum: kernel reduces into float32 scratch, last block writes out.
        out = torch.empty(1, device=input.device, dtype=input.dtype)
        scratch = torch.zeros(3, device=input.device, dtype=torch.float32)

        nll_loss_nd_forward_kernel[grid](
            inp3d,
            tgt2d,
            w_arg,
            out,
            scratch,
            n_class,
            spatial,
            *inp3d.stride(),
            *tgt2d.stride(),
            ignore_index,
            HAS_WEIGHT=has_weight,
            REDUCTION=reduction,
        )

        if reduction == 1:
            total_weight = scratch[1]
        else:
            total_weight = torch.empty([], device=input.device, dtype=input.dtype)
        return out[0], total_weight
def nll_loss_nd_backward(
    grad_output: torch.Tensor,
    input: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor = None,
    reduction: int = 1,
    ignore_index: int = -100,
    total_weight: torch.Tensor = None,
):
    """Host wrapper for the N-D NLL-loss backward kernel.

    Allocates a zeroed grad_input matching ``input`` and launches
    nll_loss_nd_backward_kernel over the (N, C, S) view.  Inputs with
    fewer than 3 dims are delegated to the dedicated 2D implementation.
    Returns grad_input with ``input``'s original shape.
    """
    logger.debug("GEMS NLL LOSS ND BWD")

    if input.dim() < 3:
        return nll_loss_2d_backward(
            grad_output,
            input,
            target,
            weight=weight,
            reduction=reduction,
            ignore_index=ignore_index,
            total_weight=total_weight,
        )

    # Pre-zeroed so positions the kernel skips (ignored / invalid targets)
    # end up with zero gradient.
    grad_input = torch.zeros_like(input)

    n_batch = input.shape[0]
    n_class = input.shape[1]
    spatial = input.numel() // (n_batch * n_class)

    grad3d = grad_input.reshape(n_batch, n_class, spatial)
    tgt2d = target.reshape(n_batch, spatial)

    has_weight = weight is not None
    # When unweighted, pass any valid pointer; kernel ignores it.
    w_arg = weight.contiguous() if has_weight else input

    if reduction == 0:
        go2d = grad_output.reshape(n_batch, spatial)
        go_strides = go2d.stride()
    else:
        # Scalar upstream grad: strides are irrelevant to the kernel.
        go2d = grad_output
        go_strides = (0, 0)

    grid = lambda meta: (triton.cdiv(spatial, meta["BLOCK_S"]), n_batch)

    with torch_device_fn.device(input.device):
        nll_loss_nd_backward_kernel[grid](
            go2d,
            tgt2d,
            w_arg,
            grad_input,
            total_weight,
            n_class,
            spatial,
            *grad3d.stride(),
            *tgt2d.stride(),
            *go_strides,
            ignore_index,
            HAS_WEIGHT=has_weight,
            REDUCTION=reduction,
        )

    return grad_input