Coverage for src/flag_gems/ops/soft_margin_loss.py: 38%
125 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def _soft_margin_loss_elementwise_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Per-element soft margin loss: log(1 + exp(-x * y)).

    Each program instance covers one BLOCK_SIZE-wide slice of the
    flattened inputs; math is done in float32 for numerical stability.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    # Stable softplus: log(1 + e^t) = max(t, 0) + log(1 + e^{-|t|})
    t = -a * b
    loss = tl.maximum(t, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(t)))

    tl.store(out_ptr + idx, loss, mask=in_bounds)
@triton.jit
def _soft_margin_loss_sum_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Accumulate the total soft margin loss into a single scalar.

    Each program instance reduces its BLOCK_SIZE slice to a partial sum
    (out-of-range lanes contribute 0) and folds it into *out_ptr with an
    atomic add. Note: float atomics make the summation order, and thus
    the low-order bits, nondeterministic across launches.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    # Stable softplus: log(1 + e^t) = max(t, 0) + log(1 + e^{-|t|})
    t = -a * b
    loss = tl.maximum(t, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(t)))
    loss = tl.where(in_bounds, loss, 0.0)

    partial = tl.sum(loss, axis=0)
    tl.atomic_add(out_ptr, partial)
55def _normalize_reduction(reduction):
56 # Accept both string and enum/int forms: 0=none,1=mean,2=sum
57 if isinstance(reduction, str):
58 r = reduction.lower()
59 if r == "none":
60 return 0
61 if r == "mean":
62 return 1
63 if r == "sum":
64 return 2
65 raise ValueError(f"Invalid reduction: {reduction}")
66 if isinstance(reduction, int):
67 if reduction in (0, 1, 2):
68 return reduction
69 raise ValueError(f"Invalid reduction int: {reduction}")
70 raise ValueError(f"Unsupported reduction type: {type(reduction)}")
73def _check_tensors(input: torch.Tensor, target: torch.Tensor):
74 if not (input.is_cuda and target.is_cuda):
75 raise AssertionError(
76 "soft_margin_loss: input and target must be CUDA tensors for Triton kernel."
77 )
78 if input.device != target.device:
79 raise AssertionError(
80 "soft_margin_loss: input and target must be on the same device."
81 )
82 if input.numel() != target.numel():
83 raise AssertionError(
84 "soft_margin_loss: input and target must have the same number of elements."
85 )
86 if not input.is_contiguous():
87 input = input.contiguous()
88 if not target.is_contiguous():
89 target = target.contiguous()
90 return input, target
def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Soft margin loss log(1 + exp(-input * target)) via Triton kernels.

    reduction: 'none'/0 returns a per-element tensor shaped like input,
    'mean'/1 the average, 'sum'/2 the total. Empty inputs follow PyTorch
    semantics: sum -> 0, mean -> NaN.
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    numel = input.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if red == 0:
        # reduction = 'none': one loss value per element.
        out = torch.empty_like(input)
        if numel:
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, out, numel, BLOCK_SIZE=BLOCK_SIZE
            )
        return out

    # reduction = 'mean' (1) or 'sum' (2)
    if numel == 0:
        # Follow PyTorch behavior: sum -> 0, mean -> NaN
        fill = 0 if red == 2 else float("nan")
        return torch.full((), fill, device=input.device, dtype=input.dtype)

    # Accumulate in float32 regardless of input dtype, cast at the end.
    total = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, total, numel, BLOCK_SIZE=BLOCK_SIZE
    )
    if red == 1:
        total = total / float(numel)
    return total.to(dtype=input.dtype)
def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Out variant of soft_margin_loss: writes the result into `out`.

    Allocates `out` when it is None (shaped like input for
    reduction='none', a scalar otherwise); validates a caller-supplied
    `out` for device and element count. Returns `out`.

    Raises AssertionError on invalid tensors, ValueError on an invalid
    reduction spec.
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS_OUT")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    n_elements = input.numel()

    if out is None:
        # Allocate output based on reduction
        if red == 0:
            out = torch.empty_like(input)
        else:
            out = torch.empty((), device=input.device, dtype=input.dtype)
    else:
        if not out.is_cuda:
            raise AssertionError("soft_margin_loss_out: out must be a CUDA tensor.")
        if red == 0:
            if out.numel() != n_elements:
                raise AssertionError(
                    "soft_margin_loss_out: for reduction='none', out must match input shape."
                )
        else:
            if out.numel() != 1:
                raise AssertionError(
                    "soft_margin_loss_out: for reduction='sum' or 'mean', out must be a scalar tensor."
                )
        if out.device != input.device:
            raise AssertionError(
                "soft_margin_loss_out: out must be on the same device as input."
            )

    if red == 0:
        if n_elements > 0:
            BLOCK_SIZE = 1024
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            # BUG FIX: the kernel stores through flat contiguous offsets,
            # so writing directly into a non-contiguous caller-supplied
            # `out` would scatter values to the wrong memory locations.
            # Stage into a contiguous buffer in that case and copy back.
            dst = out if out.is_contiguous() else torch.empty_like(input)
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, dst, n_elements, BLOCK_SIZE=BLOCK_SIZE
            )
            if dst is not out:
                # dst is contiguous, so reshape to out's shape is a view;
                # copy_ handles layout (and any dtype) conversion.
                out.copy_(dst.reshape(out.shape))
        return out
    else:
        if n_elements == 0:
            # Follow PyTorch behavior: sum -> 0, mean -> NaN
            if red == 2:
                out.fill_(0)
            else:
                out.fill_(float("nan"))
            return out
        # Accumulate in float32 regardless of input dtype.
        tmp_sum = torch.zeros((), device=input.device, dtype=torch.float32)
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _soft_margin_loss_sum_kernel[grid](
            input, target, tmp_sum, n_elements, BLOCK_SIZE=BLOCK_SIZE
        )
        if red == 2:
            out.fill_(tmp_sum.to(dtype=input.dtype))
        else:
            mean_val = (tmp_sum / float(n_elements)).to(dtype=input.dtype)
            out.fill_(mean_val)
        return out