Coverage for src/flag_gems/ops/margin_ranking_loss.py: 63%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

@triton.jit
def _margin_ranking_loss_kernel(
    x1_ptr, x2_ptr, target_ptr, out_ptr, n_elements, margin, BLOCK_SIZE: tl.constexpr
):
    # Elementwise margin ranking loss over flat buffers:
    #   out[i] = max(0, -target[i] * (x1[i] - x2[i]) + margin)
    # Each program instance covers one contiguous BLOCK_SIZE-wide slice.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x1_ptr + idx, mask=in_bounds, other=0)
    b = tl.load(x2_ptr + idx, mask=in_bounds, other=0)
    sign = tl.load(target_ptr + idx, mask=in_bounds, other=0)

    # Broadcast the scalar margin to a block-shaped vector in the input dtype
    # so the arithmetic below stays in the inputs' precision.
    margin_vec = tl.full([BLOCK_SIZE], margin, a.dtype)
    hinge = -sign * (a - b) + margin_vec

    # Clamp at zero (the "hinge") and write back only the in-bounds lanes.
    result = tl.maximum(hinge, tl.zeros([BLOCK_SIZE], dtype=hinge.dtype))
    tl.store(out_ptr + idx, result, mask=in_bounds)

31 

32 

def margin_ranking_loss(*args, **kwargs):
    """Compute the margin ranking loss between two inputs.

    Implements the same contract as ``torch.nn.functional.margin_ranking_loss``:
    ``loss = max(0, -target * (input1 - input2) + margin)`` followed by the
    requested reduction. CUDA tensors are processed by a Triton kernel; any
    other device falls back to the reference ATen implementation.

    Calling convention (ATen-style):
        (input1, input2, target, margin=0.0, reduction='mean')
    The three tensors may be passed positionally or via the keyword names
    ``self``/``other``/``target`` — mixing positional and keyword forms is
    supported. ``reduction`` may be a string ('none' | 'mean' | 'sum') or the
    ATen integer code (0 | 1 | 2). An explicit ``margin=``/``reduction=``
    keyword overrides a positional value, as in the original implementation.

    Returns:
        Tensor: elementwise losses with the broadcast shape of the inputs for
        ``reduction='none'``, otherwise a scalar tensor ('mean'/'sum').

    Raises:
        TypeError: if any of the three tensor arguments is missing.
        ValueError: if ``reduction`` is not an accepted value.
    """
    logging.getLogger(__name__).debug("GEMS MARGIN_RANKING_LOSS")

    def _pick(pos, name):
        # Resolve one required tensor argument: positional first, then the
        # ATen keyword name. This fixes mixed calls such as
        # margin_ranking_loss(x1, x2, target=t), which previously raised.
        if len(args) > pos:
            return args[pos]
        if name in kwargs:
            return kwargs[name]
        raise TypeError(
            "margin_ranking_loss requires at least three positional arguments: input1, input2, target"
        )

    x1 = _pick(0, "self")
    x2 = _pick(1, "other")
    target = _pick(2, "target")

    # margin / reduction: positional value, then keyword override (keyword
    # wins, preserving the original precedence).
    margin = args[3] if len(args) >= 4 else 0.0
    reduction = args[4] if len(args) >= 5 else "mean"
    margin = kwargs.get("margin", margin)
    reduction = kwargs.get("reduction", reduction)

    # Normalize ATen integer reduction codes to strings.
    if isinstance(reduction, int):
        reduction = {0: "none", 1: "mean", 2: "sum"}.get(reduction, "mean")
    if reduction not in ("none", "mean", "sum"):
        raise ValueError("reduction must be one of 'none', 'mean', or 'sum'")

    # Non-CUDA tensors: delegate to the reference ATen implementation.
    device = x1.device
    if not (isinstance(device, torch.device) and device.type == "cuda"):
        return torch.ops.aten.margin_ranking_loss(
            x1, x2, target, float(margin), {"none": 0, "mean": 1, "sum": 2}[reduction]
        )

    # Broadcast to a common shape, then flatten contiguous buffers in a
    # floating dtype (prefer the input dtype; fall back to float32).
    x1_b, x2_b, tgt_b = torch.broadcast_tensors(x1, x2, target)
    common_dtype = x1_b.dtype if x1_b.is_floating_point() else torch.float32
    x1_c = x1_b.to(dtype=common_dtype).contiguous().view(-1)
    x2_c = x2_b.to(dtype=common_dtype).contiguous().view(-1)
    tgt_c = tgt_b.to(dtype=common_dtype).contiguous().view(-1)

    out = torch.empty_like(x1_c)
    n_elements = out.numel()
    if n_elements == 0:
        # Empty input: 'none' keeps the (empty) broadcast shape, 'sum' is 0,
        # and 'mean' of an empty tensor is NaN — matching PyTorch semantics.
        if reduction == "none":
            return out.view(x1_b.shape)
        return out.sum() if reduction == "sum" else out.mean()

    # Launch the Triton kernel over the flattened buffers.
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _margin_ranking_loss_kernel[grid](
        x1_c, x2_c, tgt_c, out, n_elements, float(margin), BLOCK_SIZE=BLOCK_SIZE
    )

    # Apply the requested reduction.
    if reduction == "none":
        return out.view(x1_b.shape)
    return out.sum() if reduction == "sum" else out.mean()