# Coverage for src/flag_gems/experimental_ops/soft_margin_loss.py: 0%
# 121 statements
# coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _soft_margin_loss_elementwise_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Per-element soft-margin loss: out[i] = log(1 + exp(-x[i] * y[i])).

    Each program instance handles one contiguous BLOCK_SIZE slice. Math is
    done in float32 regardless of the input dtype; the store casts back to
    the output pointer's dtype.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    # Numerically stable softplus: log(1+exp(z)) = max(z,0) + log(1+exp(-|z|)),
    # which avoids exp() overflow for large z = -x*y.
    z = -a * b
    loss = tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(z)))

    tl.store(out_ptr + idx, loss, mask=in_bounds)
@triton.jit
def _soft_margin_loss_sum_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Accumulate sum(log(1 + exp(-x * y))) into the scalar at out_ptr.

    Each program instance reduces its BLOCK_SIZE slice locally in float32,
    then folds the partial sum into *out_ptr with an atomic add; out_ptr is
    expected to start at zero.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    # Stable softplus of z = -x*y (see elementwise kernel for the identity).
    z = -a * b
    loss = tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(z)))
    # Zero out lanes past the end so they don't pollute the reduction.
    loss = tl.where(in_bounds, loss, 0.0)

    partial = tl.sum(loss, axis=0)
    tl.atomic_add(out_ptr, partial)
50def _normalize_reduction(reduction):
51 # Accept both string and enum/int forms: 0=none,1=mean,2=sum
52 if isinstance(reduction, str):
53 r = reduction.lower()
54 if r == "none":
55 return 0
56 if r == "mean":
57 return 1
58 if r == "sum":
59 return 2
60 raise ValueError(f"Invalid reduction: {reduction}")
61 if isinstance(reduction, int):
62 if reduction in (0, 1, 2):
63 return reduction
64 raise ValueError(f"Invalid reduction int: {reduction}")
65 raise ValueError(f"Unsupported reduction type: {type(reduction)}")
68def _check_tensors(input: torch.Tensor, target: torch.Tensor):
69 if not (input.is_cuda and target.is_cuda):
70 raise AssertionError(
71 "soft_margin_loss: input and target must be CUDA tensors for Triton kernel."
72 )
73 if input.device != target.device:
74 raise AssertionError(
75 "soft_margin_loss: input and target must be on the same device."
76 )
77 if input.numel() != target.numel():
78 raise AssertionError(
79 "soft_margin_loss: input and target must have the same number of elements."
80 )
81 if not input.is_contiguous():
82 input = input.contiguous()
83 if not target.is_contiguous():
84 target = target.contiguous()
85 return input, target
def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Soft-margin loss, loss_i = log(1 + exp(-input_i * target_i)), via Triton.

    Args:
        input: CUDA tensor of predictions.
        target: CUDA tensor with the same number of elements as ``input``.
        reduction: "none"/0 -> per-element losses, "mean"/1 -> average,
            "sum"/2 -> total.

    Returns:
        A tensor shaped like ``input`` for reduction='none', otherwise a
        0-dim scalar in ``input``'s dtype. Empty inputs follow PyTorch:
        sum -> 0, mean -> NaN.
    """
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    n = input.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    if red == 0:  # reduction = 'none'
        out = torch.empty_like(input)
        if n > 0:
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, out, n, BLOCK_SIZE=BLOCK_SIZE
            )
        return out

    # reduction = 'mean' (1) or 'sum' (2)
    if n == 0:
        # Follow PyTorch behavior: sum -> 0, mean -> NaN.
        if red == 2:
            return torch.zeros((), device=input.device, dtype=input.dtype)
        return torch.full((), float("nan"), device=input.device, dtype=input.dtype)

    # Accumulate in float32 for accuracy, cast back to the input dtype at the end.
    total = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, total, n, BLOCK_SIZE=BLOCK_SIZE
    )
    if red == 2:
        return total.to(dtype=input.dtype)
    return (total / float(n)).to(dtype=input.dtype)
def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Out-variant of soft_margin_loss: writes the result into ``out``.

    When ``out`` is None it is allocated here (input-shaped for
    reduction='none', a 0-dim scalar otherwise). A caller-supplied ``out``
    must be a CUDA tensor on ``input``'s device with a matching element
    count (numel == input.numel() for 'none', numel == 1 otherwise).

    Returns:
        ``out``, filled with the loss. Empty inputs follow PyTorch:
        sum -> 0, mean -> NaN.
    """
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    n = input.numel()

    if out is None:
        # Allocate an output matching the requested reduction.
        if red == 0:
            out = torch.empty_like(input)
        else:
            out = torch.empty((), device=input.device, dtype=input.dtype)
    else:
        if not out.is_cuda:
            raise AssertionError("soft_margin_loss_out: out must be a CUDA tensor.")
        if red == 0 and out.numel() != n:
            raise AssertionError(
                "soft_margin_loss_out: for reduction='none', out must match input shape."
            )
        if red != 0 and out.numel() != 1:
            raise AssertionError(
                "soft_margin_loss_out: for reduction='sum' or 'mean', out must be a scalar tensor."
            )
        if out.device != input.device:
            raise AssertionError(
                "soft_margin_loss_out: out must be on the same device as input."
            )

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    if red == 0:
        # NOTE(review): the kernel writes assume `out` is contiguous; only the
        # element count is validated above — confirm callers never pass a
        # non-contiguous `out`.
        if n > 0:
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, out, n, BLOCK_SIZE=BLOCK_SIZE
            )
        return out

    if n == 0:
        # PyTorch convention for empty inputs: sum -> 0, mean -> NaN.
        out.fill_(0 if red == 2 else float("nan"))
        return out

    # Reduce in float32, then cast into out's storage via fill_.
    total = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, total, n, BLOCK_SIZE=BLOCK_SIZE
    )
    if red == 2:
        out.fill_(total.to(dtype=input.dtype))
    else:
        out.fill_((total / float(n)).to(dtype=input.dtype))
    return out