Coverage for src/flag_gems/ops/margin_ranking_loss.py: 63%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def _margin_ranking_loss_kernel(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    """Elementwise margin ranking loss: out[i] = max(0, -target[i] * (x1[i] - x2[i]) + margin).

    One program instance processes one contiguous BLOCK_SIZE slice of the
    flattened inputs; `mask` guards the tail block past `n_elements`.
    All three input pointers are assumed to address buffers of at least
    `n_elements` elements of the same dtype (the host wrapper below casts
    and flattens them before launch).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # lanes past the end of the last block

    # `other=0` for masked lanes is arbitrary: those lanes are never stored.
    x1 = tl.load(x1_ptr + offsets, mask=mask, other=0)
    x2 = tl.load(x2_ptr + offsets, mask=mask, other=0)
    y = tl.load(target_ptr + offsets, mask=mask, other=0)

    diff = x1 - x2
    # Materialize the scalar margin in the input dtype so the arithmetic
    # below stays in that dtype rather than implicitly upcasting.
    m = tl.full([BLOCK_SIZE], margin, x1.dtype)
    val = -y * diff + m
    zero = tl.zeros([BLOCK_SIZE], dtype=val.dtype)
    loss = tl.maximum(val, zero)  # hinge: clamp negatives to 0

    tl.store(out_ptr + offsets, loss, mask=mask)
def margin_ranking_loss(*args, **kwargs):
    """Margin ranking loss: max(0, -target * (input1 - input2) + margin).

    Accepts the ATen-style calling convention
        margin_ranking_loss(input1, input2, target, margin=0.0, reduction='mean')
    where ``reduction`` may be a string ('none' | 'mean' | 'sum') or the ATen
    integer code (0 | 1 | 2).  Keyword arguments (ATen names ``self``,
    ``other``, ``target``, ``margin``, ``reduction``) override positionals.
    Non-CUDA tensors fall back to the eager ATen implementation.

    Returns:
        A tensor with the broadcast input shape for reduction='none',
        otherwise a 0-dim (scalar) tensor.

    Raises:
        TypeError: if the three tensor arguments cannot be located.
        ValueError: if ``reduction`` is not a recognized value.
    """
    logger.debug("GEMS MARGIN_RANKING_LOSS")

    x1, x2, target = _extract_tensors(args, kwargs)
    margin, reduction = _extract_options(args, kwargs)

    # Fall back to PyTorch for anything that is not a CUDA tensor.
    device = x1.device
    if not (isinstance(device, torch.device) and device.type == "cuda"):
        return torch.ops.aten.margin_ranking_loss(
            x1, x2, target, float(margin), {"none": 0, "mean": 1, "sum": 2}[reduction]
        )

    # Broadcast to a common shape, then compute in a common floating dtype.
    # Promoting x1 and x2 together (instead of looking only at x1) avoids
    # silently losing precision when the operands have mixed dtypes
    # (e.g. float16 vs float32).
    x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)
    common_dtype = torch.promote_types(x1_b.dtype, x2_b.dtype)
    if not common_dtype.is_floating_point:
        common_dtype = torch.float32

    # Flatten into contiguous 1-D buffers for the elementwise kernel.
    x1_c = x1_b.to(dtype=common_dtype).contiguous().view(-1)
    x2_c = x2_b.to(dtype=common_dtype).contiguous().view(-1)
    tgt_c = tgt_b.to(dtype=common_dtype).contiguous().view(-1)

    out = torch.empty_like(x1_c)
    n_elements = out.numel()
    if n_elements > 0:
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _margin_ranking_loss_kernel[grid](
            x1_c, x2_c, tgt_c, out, n_elements, float(margin), BLOCK_SIZE=BLOCK_SIZE
        )
    # Empty inputs simply skip the launch; the reductions below then match
    # PyTorch semantics (none -> empty tensor, sum -> 0, mean -> NaN).

    if reduction == "none":
        return out.view(x1_b.shape)
    if reduction == "sum":
        return out.sum()
    return out.mean()


def _extract_tensors(args, kwargs):
    """Locate (input1, input2, target) from positional or keyword arguments."""
    if len(args) >= 3:
        return args[0], args[1], args[2]
    # Keyword names follow the ATen schema.
    if all(k in kwargs for k in ("self", "other", "target")):
        return kwargs["self"], kwargs["other"], kwargs["target"]
    raise TypeError(
        "margin_ranking_loss requires at least three positional arguments: input1, input2, target"
    )


def _extract_options(args, kwargs):
    """Resolve (margin, reduction) with keywords overriding positionals."""
    margin = args[3] if len(args) >= 4 else 0.0
    reduction = args[4] if len(args) >= 5 else "mean"
    margin = kwargs.get("margin", margin)
    reduction = kwargs.get("reduction", reduction)

    # ATen encodes reduction as an int: 0=none, 1=mean, 2=sum.
    if isinstance(reduction, int):
        reduction = {0: "none", 1: "mean", 2: "sum"}.get(reduction, "mean")
    if reduction not in ("none", "mean", "sum"):
        raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")
    return margin, reduction