Coverage for src/flag_gems/experimental_ops/margin_ranking_loss.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def margin_ranking_loss(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    """Elementwise margin ranking loss: max(0, -y * (x1 - x2) + margin).

    Each program instance processes one contiguous BLOCK_SIZE chunk of the
    flattened input buffers; the caller is responsible for any reduction.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard against reading/writing past the end of the buffers in the last block.
    mask = offsets < n_elements

    x1 = tl.load(x1_ptr + offsets, mask=mask, other=0)
    x2 = tl.load(x2_ptr + offsets, mask=mask, other=0)
    y = tl.load(target_ptr + offsets, mask=mask, other=0)

    diff = x1 - x2
    # Materialize the scalar margin as a block at the inputs' dtype so the
    # arithmetic below stays in that dtype (no implicit promotion).
    m = tl.full([BLOCK_SIZE], margin, x1.dtype)
    val = -y * diff + m
    zero = tl.zeros([BLOCK_SIZE], dtype=val.dtype)
    # Hinge: clamp negatives to zero.
    loss = tl.maximum(val, zero)

    tl.store(out_ptr + offsets, loss, mask=mask)
# Preserve a handle to the Triton kernel before defining the Python wrapper with the same name.
# The wrapper below shadows `margin_ranking_loss` at module level; this alias is
# the only remaining way to reach the compiled kernel.
_margin_ranking_loss_kernel = margin_ranking_loss
def margin_ranking_loss(*args, **kwargs):
    """Compute the margin ranking loss: max(0, -target * (input1 - input2) + margin).

    Mirrors the ATen signature::

        margin_ranking_loss(input1, input2, target, margin=0.0, reduction='mean')

    Arguments may be positional, or given via the ATen keyword names
    ``self`` / ``other`` / ``target`` / ``margin`` / ``reduction`` — any mix of
    the two is accepted (keyword ``margin``/``reduction`` override positional).
    ``reduction`` may be a string ('none' | 'mean' | 'sum') or the ATen integer
    code (0 | 1 | 2).

    Non-CUDA tensors fall back to the reference ATen implementation; CUDA
    tensors are handled by the Triton kernel.

    Raises:
        TypeError: if input1/input2/target cannot be resolved from args/kwargs.
        ValueError: if ``reduction`` is not one of the accepted values.
    """

    def _required(idx, name):
        # Positional takes precedence; fall back to the ATen keyword name.
        if len(args) > idx:
            return args[idx]
        return kwargs[name]

    try:
        x1 = _required(0, "self")
        x2 = _required(1, "other")
        target = _required(2, "target")
    except KeyError:
        raise TypeError(
            "margin_ranking_loss requires at least three positional arguments: input1, input2, target"
        ) from None

    # Optional arguments: positional default, keyword overrides.
    margin = kwargs.get("margin", args[3] if len(args) >= 4 else 0.0)
    reduction = kwargs.get("reduction", args[4] if len(args) >= 5 else "mean")

    # Normalize ATen integer reduction codes to strings.
    if isinstance(reduction, int):
        reduction = {0: "none", 1: "mean", 2: "sum"}.get(reduction, "mean")
    if reduction not in ("none", "mean", "sum"):
        raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")

    # Device check: defer to the reference ATen kernel for non-CUDA tensors.
    device = x1.device
    if not (isinstance(device, torch.device) and device.type == "cuda"):
        return torch.ops.aten.margin_ranking_loss(
            x1, x2, target, float(margin), {"none": 0, "mean": 1, "sum": 2}[reduction]
        )

    # --- CUDA path: broadcast, promote, flatten, launch the Triton kernel. ---
    x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)

    # Promote across *both* inputs (using only x1's dtype would silently
    # truncate e.g. a float64 x2); non-floating results fall back to float32.
    common_dtype = torch.promote_types(x1_b.dtype, x2_b.dtype)
    if not common_dtype.is_floating_point:
        common_dtype = torch.float32

    # Flatten contiguous buffers for the 1-D kernel.
    x1_c = x1_b.to(dtype=common_dtype).contiguous().view(-1)
    x2_c = x2_b.to(dtype=common_dtype).contiguous().view(-1)
    tgt_c = tgt_b.to(dtype=common_dtype).contiguous().view(-1)

    out = torch.empty_like(x1_c)
    n_elements = out.numel()

    if n_elements > 0:
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _margin_ranking_loss_kernel[grid](
            x1_c, x2_c, tgt_c, out, n_elements, float(margin), BLOCK_SIZE=BLOCK_SIZE
        )
    # For empty inputs the reductions below already match PyTorch:
    # sum() -> 0, mean() -> nan, 'none' -> empty tensor of the broadcast shape.

    if reduction == "sum":
        return out.sum()
    if reduction == "mean":
        return out.mean()
    return out.view(x1_b.shape)