Coverage for src/flag_gems/ops/margin_ranking_loss.py: 63%
71 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
# Elementwise margin-ranking loss over flat buffers:
#   out[i] = max(0, -target[i] * (x1[i] - x2[i]) + margin)
# Each program (1-D grid) handles BLOCK_SIZE consecutive elements; lanes at or
# past n_elements are masked off on both load and store. The buffers are
# expected to be contiguous and already cast to one dtype by the caller.
@triton.jit
def _margin_ranking_loss_kernel(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    # Widen to int64 before scaling so the element offsets cannot overflow
    # int32 for tensors with more than 2**31 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x1 = tl.load(x1_ptr + offsets, mask=mask, other=0)
    x2 = tl.load(x2_ptr + offsets, mask=mask, other=0)
    y = tl.load(target_ptr + offsets, mask=mask, other=0)

    diff = x1 - x2
    m = tl.full([BLOCK_SIZE], margin, x1.dtype)
    val = -y * diff + m
    zero = tl.zeros([BLOCK_SIZE], dtype=val.dtype)
    # tl.maximum's default (propagate_nan=NONE) may return the non-NaN operand,
    # turning a NaN loss into 0, whereas torch.clamp_min propagates NaN.
    # `val < 0` is False for NaN, so this select keeps NaN and otherwise
    # equals max(val, 0).
    loss = tl.where(val < zero, zero, val)

    tl.store(out_ptr + offsets, loss, mask=mask)
# Reduction names and the ATen integer enum values used for the fallback call.
_REDUCTION_TO_ENUM = {"none": 0, "mean": 1, "sum": 2}
_ENUM_TO_REDUCTION = {value: name for name, value in _REDUCTION_TO_ENUM.items()}


def _parse_margin_ranking_args(args, kwargs):
    """Resolve ``(input1, input2, target, margin, reduction)`` from a call.

    The three tensors may be passed positionally, by keyword, or mixed (e.g.
    ``f(a, b, target=t)``). Keyword names follow the functional API
    (``input1``, ``input2``, ``target``); ``self`` / ``other`` are also
    accepted for the first two. ``margin`` and ``reduction`` may be the 4th /
    5th positional arguments or keywords; keywords take precedence. Integer
    reductions use the ATen enum (0 = none, 1 = mean, 2 = sum).

    Raises:
        TypeError: if any of the three tensors is missing.
        ValueError: if ``reduction`` is not 'none', 'mean', 'sum' or 0/1/2.
    """
    tensor_keys = (("input1", "self"), ("input2", "other"), ("target",))
    tensors = list(args[:3])
    for aliases in tensor_keys[len(tensors):]:
        key = next((name for name in aliases if name in kwargs), None)
        if key is None:
            raise TypeError(
                "margin_ranking_loss requires three tensor arguments: "
                "input1, input2, target"
            )
        tensors.append(kwargs[key])
    x1, x2, target = tensors

    margin = kwargs.get("margin", args[3] if len(args) > 3 else 0.0)
    reduction = kwargs.get("reduction", args[4] if len(args) > 4 else "mean")

    if isinstance(reduction, int):
        # Unknown integers fall through to the check below and are rejected,
        # instead of silently becoming "mean".
        reduction = _ENUM_TO_REDUCTION.get(reduction, reduction)
    if reduction not in ("none", "mean", "sum"):
        raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")
    return x1, x2, target, margin, reduction


def _apply_reduction(loss, reduction, shape):
    """Return ``loss`` viewed as ``shape`` ('none'), summed ('sum') or averaged ('mean')."""
    if reduction == "none":
        return loss.view(shape)
    if reduction == "sum":
        return loss.sum()
    return loss.mean()


def margin_ranking_loss(*args, **kwargs):
    """Compute ``max(0, -target * (input1 - input2) + margin)`` with a reduction.

    Call signature: ``(input1, input2, target, margin=0.0, reduction='mean')``.
    ``reduction`` may be a name or the ATen integer enum.

    Inputs are broadcast together and promoted to a common dtype, which
    becomes float32 if the promoted dtype is not floating point. Tensors whose
    device type is not ``flag_gems.device`` are delegated to
    ``torch.ops.aten.margin_ranking_loss``.

    Returns:
        The elementwise loss in the broadcast shape for 'none', otherwise a
        0-dim tensor holding the sum or the mean.

    Raises:
        TypeError: if one of the three tensors is missing.
        ValueError: for an unsupported ``reduction``.
    """
    logger.debug("GEMS MARGIN_RANKING_LOSS")
    x1, x2, target, margin, reduction = _parse_margin_ranking_args(args, kwargs)

    device = x1.device
    if device.type != flag_gems.device:
        # Fallback to the PyTorch implementation for tensors on other devices.
        return torch.ops.aten.margin_ranking_loss(
            x1, x2, target, float(margin), _REDUCTION_TO_ENUM[reduction]
        )

    # The Triton kernel dereferences raw pointers, so every operand must live
    # on x1's device (e.g. a 0-dim CPU target tensor); .to() is a no-op when
    # the tensor is already there.
    x2 = x2.to(device)
    target = target.to(device)

    # Promote across all three operands rather than following x1 alone, so
    # higher-precision x2 / target values are not silently down-cast.
    common_dtype = torch.promote_types(
        torch.promote_types(x1.dtype, x2.dtype), target.dtype
    )
    if not common_dtype.is_floating_point:
        common_dtype = torch.float32

    x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)

    # Flat contiguous buffers for the kernel.
    x1_c = x1_b.to(dtype=common_dtype).contiguous().view(-1)
    x2_c = x2_b.to(dtype=common_dtype).contiguous().view(-1)
    tgt_c = tgt_b.to(dtype=common_dtype).contiguous().view(-1)

    out = torch.empty_like(x1_c)
    n_elements = out.numel()
    if n_elements > 0:
        BLOCK_SIZE = 1024
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        _margin_ranking_loss_kernel[grid](
            x1_c, x2_c, tgt_c, out, n_elements, float(margin), BLOCK_SIZE=BLOCK_SIZE
        )

    # Empty inputs skip the launch: sum() yields 0 and mean() yields NaN.
    return _apply_reduction(out, reduction, x1_b.shape)