Coverage for src/flag_gems/experimental_ops/margin_ranking_loss.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def margin_ranking_loss(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    """Elementwise margin ranking loss: ``max(0, -target * (x1 - x2) + margin)``.

    Each program instance handles one contiguous BLOCK_SIZE-wide slice of the
    flattened input buffers; ``mask`` guards lanes past ``n_elements`` in the
    final (partial) block.  Reduction, if any, is done by the Python wrapper.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Out-of-range lanes of the last block are masked off for load and store.
    mask = offsets < n_elements

    x1 = tl.load(x1_ptr + offsets, mask=mask, other=0)
    x2 = tl.load(x2_ptr + offsets, mask=mask, other=0)
    y = tl.load(target_ptr + offsets, mask=mask, other=0)

    diff = x1 - x2
    # Materialize the scalar margin at the inputs' dtype so the arithmetic
    # below stays in one dtype (no implicit upcast of the whole expression).
    m = tl.full([BLOCK_SIZE], margin, x1.dtype)
    val = -y * diff + m
    zero = tl.zeros([BLOCK_SIZE], dtype=val.dtype)
    # Hinge: clamp negative values to zero.
    loss = tl.maximum(val, zero)

    tl.store(out_ptr + offsets, loss, mask=mask)

26 

27 

# Preserve a handle to the Triton kernel before defining the Python wrapper with the same name.
# The wrapper below shadows `margin_ranking_loss` at module level; kernel
# launches must therefore go through this alias.
_margin_ranking_loss_kernel = margin_ranking_loss

30 

31 

def _apply_reduction(out, shape, reduction):
    """Apply 'none'/'sum'/'mean' to the flat elementwise loss buffer `out`."""
    if reduction == "none":
        return out.view(shape)
    if reduction == "sum":
        return out.sum()
    return out.mean()  # "mean" — NaN on empty input, matching PyTorch semantics


def margin_ranking_loss(*args, **kwargs):
    """Margin ranking loss with Triton acceleration on CUDA.

    Signature mirrors ATen: ``(input1, input2, target, margin=0.0,
    reduction='mean')``.  Inputs may also be passed by keyword as
    ``self``/``other``/``target`` (ATen-style) or ``input1``/``input2``
    (``torch.nn.functional``-style).  ``reduction`` accepts the strings
    'none'/'mean'/'sum' or the ATen integer codes 0/1/2.

    Non-CUDA tensors fall back to ``torch.ops.aten.margin_ranking_loss``.

    Raises:
        TypeError: if the three required tensors are not supplied.
        ValueError: if ``reduction`` is not a recognized value.
    """
    # --- input extraction: positional first, then keyword aliases ---
    if len(args) >= 3:
        x1, x2, target = args[0], args[1], args[2]
    else:
        try:
            x1 = kwargs["self"] if "self" in kwargs else kwargs["input1"]
            x2 = kwargs["other"] if "other" in kwargs else kwargs["input2"]
            target = kwargs["target"]
        except KeyError:
            raise TypeError(
                "margin_ranking_loss requires at least three positional arguments: input1, input2, target"
            ) from None

    # --- margin / reduction: keyword values override positional ones ---
    margin = args[3] if len(args) >= 4 else 0.0
    reduction = args[4] if len(args) >= 5 else "mean"
    margin = kwargs.get("margin", margin)
    reduction = kwargs.get("reduction", reduction)

    # Normalize an ATen integer reduction code (0/1/2) to its string form.
    if isinstance(reduction, int):
        reduction = {0: "none", 1: "mean", 2: "sum"}.get(reduction, "mean")
    if reduction not in ("none", "mean", "sum"):
        raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")

    # --- device check: delegate non-CUDA tensors to the ATen implementation ---
    device = x1.device
    if not (isinstance(device, torch.device) and device.type == "cuda"):
        return torch.ops.aten.margin_ranking_loss(
            x1, x2, target, float(margin), {"none": 0, "mean": 1, "sum": 2}[reduction]
        )

    # --- broadcast and dtype promotion ---
    x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)

    # Promote across *both* inputs; the previous version keyed off x1 alone,
    # silently downcasting e.g. a float64 x2.  Integer inputs compute in
    # float32, as before.
    common_dtype = torch.promote_types(x1_b.dtype, x2_b.dtype)
    if not common_dtype.is_floating_point:
        common_dtype = torch.float32
    x1_b = x1_b.to(dtype=common_dtype)
    x2_b = x2_b.to(dtype=common_dtype)
    tgt_b = tgt_b.to(dtype=common_dtype)

    # Flat contiguous buffers for the 1-D kernel launch.
    x1_c = x1_b.contiguous().view(-1)
    x2_c = x2_b.contiguous().view(-1)
    tgt_c = tgt_b.contiguous().view(-1)

    out = torch.empty_like(x1_c)

    n_elements = out.numel()
    if n_elements == 0:
        # Empty input: skip the launch; reduction of the empty buffer matches
        # PyTorch ('sum' -> 0, 'mean' -> NaN).
        return _apply_reduction(out, x1_b.shape, reduction)

    # --- Triton kernel launch ---
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _margin_ranking_loss_kernel[grid](
        x1_c, x2_c, tgt_c, out, n_elements, float(margin), BLOCK_SIZE=BLOCK_SIZE
    )

    return _apply_reduction(out, x1_b.shape, reduction)