Coverage for src/flag_gems/ops/margin_ranking_loss.py: 63%

71 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def _margin_ranking_loss_kernel(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    """Write the unreduced loss ``max(0, -target * (x1 - x2) + margin)`` per element.

    Each program instance handles ``BLOCK_SIZE`` consecutive offsets starting at
    ``program_id * BLOCK_SIZE``; offsets at or beyond ``n_elements`` are masked
    out of every load and of the final store.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    # Masked-out lanes read 0; their results are never stored.
    lhs = tl.load(x1_ptr + idx, mask=in_bounds, other=0)
    rhs = tl.load(x2_ptr + idx, mask=in_bounds, other=0)
    sign = tl.load(target_ptr + idx, mask=in_bounds, other=0)

    # Materialise the scalar margin as a block in the dtype of the first input.
    margin_block = tl.full([BLOCK_SIZE], margin, lhs.dtype)
    shifted = -sign * (lhs - rhs) + margin_block

    # Clamp at zero (hinge).
    floor = tl.zeros([BLOCK_SIZE], dtype=shifted.dtype)
    hinge = tl.maximum(shifted, floor)

    tl.store(out_ptr + idx, hinge, mask=in_bounds)

33 

34 

# Keyword spellings accepted for the three tensor operands, in positional order.
# "self"/"other" are the spellings this wrapper has always accepted; "input1"/
# "input2" are the names used by torch's margin_ranking_loss signature.
_MRL_TENSOR_KEYWORDS = (("self", "input1"), ("other", "input2"), ("target",))

# Integer reduction codes <-> string names.
_MRL_REDUCTION_NAMES = {0: "none", 1: "mean", 2: "sum"}
_MRL_REDUCTION_CODES = {"none": 0, "mean": 1, "sum": 2}


def _mrl_resolve_tensors(args, kwargs):
    """Return ``(input1, input2, target)`` taken from positionals, then keywords."""
    resolved = []
    for position, aliases in enumerate(_MRL_TENSOR_KEYWORDS):
        if position < len(args):
            resolved.append(args[position])
            continue
        for alias in aliases:
            if alias in kwargs:
                resolved.append(kwargs[alias])
                break
        else:
            raise TypeError(
                "margin_ranking_loss requires at least three positional arguments: input1, input2, target"
            )
    return tuple(resolved)


def margin_ranking_loss(*args, **kwargs):
    """Compute the margin ranking loss ``max(0, -target * (input1 - input2) + margin)``.

    Accepted call shape: ``(input1, input2, target, margin=0.0, reduction="mean")``.
    Each of the three tensors may be given positionally or by keyword
    (``input1``/``self``, ``input2``/``other``, ``target``), in any mix.
    ``margin`` and ``reduction`` may be positional or keyword; a keyword value
    takes precedence over a positional one.

    ``reduction`` is ``"none"``, ``"mean"`` or ``"sum"``, or the integer codes
    0, 1, 2 (any other integer is treated as ``"mean"``).

    Returns:
        The elementwise loss with the broadcast shape of the inputs for
        ``"none"``, otherwise its sum or mean as a 0-dim tensor.

    Raises:
        TypeError: if any of the three tensor arguments is missing.
        ValueError: if ``reduction`` is not one of the accepted values.
    """
    logger.debug("GEMS MARGIN_RANKING_LOSS")
    x1, x2, target = _mrl_resolve_tensors(args, kwargs)

    # Positional values first, then let keywords override them.
    margin = args[3] if len(args) >= 4 else 0.0
    reduction = args[4] if len(args) >= 5 else "mean"
    margin = kwargs.get("margin", margin)
    reduction = kwargs.get("reduction", reduction)

    # Normalize reduction
    if isinstance(reduction, int):
        reduction = _MRL_REDUCTION_NAMES.get(reduction, "mean")
    if reduction not in _MRL_REDUCTION_CODES:
        raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")

    # Tensors that do not live on the flag_gems device go to the ATen op.
    device = x1.device
    if not (isinstance(device, torch.device) and device.type == flag_gems.device):
        return torch.ops.aten.margin_ranking_loss(
            x1, x2, target, float(margin), _MRL_REDUCTION_CODES[reduction]
        )

    x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)

    # Compute in the promoted dtype of all three operands so that a wider
    # input2/target is not silently downcast to input1's dtype. Non-floating
    # results fall back to float32 because the float margin is added to them.
    common_dtype = torch.promote_types(
        torch.promote_types(x1_b.dtype, x2_b.dtype), tgt_b.dtype
    )
    if not common_dtype.is_floating_point:
        common_dtype = torch.float32

    # The kernel indexes flat, densely packed buffers.
    x1_c = x1_b.to(dtype=common_dtype).contiguous().view(-1)
    x2_c = x2_b.to(dtype=common_dtype).contiguous().view(-1)
    tgt_c = tgt_b.to(dtype=common_dtype).contiguous().view(-1)

    out = torch.empty_like(x1_c)
    n_elements = out.numel()

    # Nothing to launch for empty inputs; the reductions below already give
    # the right result on an empty buffer.
    if n_elements > 0:
        BLOCK_SIZE = 1024
        grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
        _margin_ranking_loss_kernel[grid](
            x1_c, x2_c, tgt_c, out, n_elements, float(margin), BLOCK_SIZE=BLOCK_SIZE
        )

    # Apply reduction
    if reduction == "none":
        return out.view(x1_b.shape)
    if reduction == "sum":
        return out.sum()
    return out.mean()