# Coverage report header (coverage.py v7.6.9, created 2026-03-27 02:51 +0800):
# src/flag_gems/ops/soft_margin_loss.py — 38% of 125 statements covered.

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

@triton.jit
def _soft_margin_loss_elementwise_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Elementwise soft margin loss: out[i] = log(1 + exp(-x[i] * y[i])).

    One program handles one contiguous block of BLOCK_SIZE elements; math is
    done in float32 regardless of the input dtype.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    # Numerically stable softplus(z) with z = -x*y:
    #   softplus(z) = max(z, 0) + log(1 + exp(-|z|))
    z = -a * b
    loss = tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(z)))

    tl.store(out_ptr + idx, loss, mask=in_bounds)

30 

31 

@triton.jit
def _soft_margin_loss_sum_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Accumulate the total soft margin loss into the scalar at *out_ptr.

    Each program reduces its own BLOCK_SIZE slice in float32 and folds the
    partial sum into the global accumulator with an atomic add, so the caller
    must pre-zero the output scalar.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    a = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    # Numerically stable softplus(-x*y); see elementwise kernel.
    z = -a * b
    loss = tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-tl.abs(z)))
    # Zero out lanes past the end so they do not pollute the reduction.
    loss = tl.where(in_bounds, loss, 0.0)

    tl.atomic_add(out_ptr, tl.sum(loss, axis=0))

53 

54 

55def _normalize_reduction(reduction): 

56 # Accept both string and enum/int forms: 0=none,1=mean,2=sum 

57 if isinstance(reduction, str): 

58 r = reduction.lower() 

59 if r == "none": 

60 return 0 

61 if r == "mean": 

62 return 1 

63 if r == "sum": 

64 return 2 

65 raise ValueError(f"Invalid reduction: {reduction}") 

66 if isinstance(reduction, int): 

67 if reduction in (0, 1, 2): 

68 return reduction 

69 raise ValueError(f"Invalid reduction int: {reduction}") 

70 raise ValueError(f"Unsupported reduction type: {type(reduction)}") 

71 

72 

73def _check_tensors(input: torch.Tensor, target: torch.Tensor): 

74 if not (input.is_cuda and target.is_cuda): 

75 raise AssertionError( 

76 "soft_margin_loss: input and target must be CUDA tensors for Triton kernel." 

77 ) 

78 if input.device != target.device: 

79 raise AssertionError( 

80 "soft_margin_loss: input and target must be on the same device." 

81 ) 

82 if input.numel() != target.numel(): 

83 raise AssertionError( 

84 "soft_margin_loss: input and target must have the same number of elements." 

85 ) 

86 if not input.is_contiguous(): 

87 input = input.contiguous() 

88 if not target.is_contiguous(): 

89 target = target.contiguous() 

90 return input, target 

91 

92 

def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Soft margin loss loss(x, y) = log(1 + exp(-x * y)) via Triton kernels.

    Args:
        input: CUDA tensor of predictions.
        target: CUDA tensor of labels, same numel as ``input``.
        reduction: "none" | "mean" | "sum" (or the int codes 0 | 1 | 2).

    Returns:
        For "none", a tensor shaped like ``input``; otherwise a scalar tensor
        in ``input``'s dtype. Empty input follows PyTorch semantics:
        sum -> 0, mean -> NaN.
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    numel = input.numel()

    block = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if red == 0:  # reduction='none': one loss value per element
        result = torch.empty_like(input)
        if numel > 0:
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, result, numel, BLOCK_SIZE=block
            )
        return result

    # reduction='mean' (1) or 'sum' (2)
    if numel == 0:
        # Match PyTorch: sum over empty -> 0, mean over empty -> NaN.
        if red == 2:
            return torch.zeros((), device=input.device, dtype=input.dtype)
        return torch.full((), float("nan"), device=input.device, dtype=input.dtype)

    # Accumulate in float32 for accuracy, then cast back to the input dtype.
    partial = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, partial, numel, BLOCK_SIZE=block
    )
    if red == 2:
        return partial.to(dtype=input.dtype)
    return (partial / float(numel)).to(dtype=input.dtype)

133 

134 

def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Out-variant of :func:`soft_margin_loss`, writing the result into ``out``.

    Args:
        input: CUDA tensor of predictions.
        target: CUDA tensor of labels, same numel as ``input``.
        reduction: "none" | "mean" | "sum" (or the int codes 0 | 1 | 2).
        out: Optional destination. For reduction='none' it must match
            ``input``'s numel; otherwise it must be a one-element tensor.
            Allocated automatically when None.

    Returns:
        ``out`` containing the loss. Empty input follows PyTorch semantics:
        sum -> 0, mean -> NaN.
    """
    logger.debug("GEMS SOFT_MARGIN_LOSS_OUT")
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    numel = input.numel()

    if out is None:
        # Allocate per reduction mode: full-size for 'none', scalar otherwise.
        out = (
            torch.empty_like(input)
            if red == 0
            else torch.empty((), device=input.device, dtype=input.dtype)
        )
    else:
        # Validate caller-supplied destination (same check order as allocation-free path).
        if not out.is_cuda:
            raise AssertionError("soft_margin_loss_out: out must be a CUDA tensor.")
        if red == 0 and out.numel() != numel:
            raise AssertionError(
                "soft_margin_loss_out: for reduction='none', out must match input shape."
            )
        if red != 0 and out.numel() != 1:
            raise AssertionError(
                "soft_margin_loss_out: for reduction='sum' or 'mean', out must be a scalar tensor."
            )
        if out.device != input.device:
            raise AssertionError(
                "soft_margin_loss_out: out must be on the same device as input."
            )

    block = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if red == 0:  # elementwise output
        if numel > 0:
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, out, numel, BLOCK_SIZE=block
            )
        return out

    if numel == 0:
        # Match PyTorch: sum over empty -> 0, mean over empty -> NaN.
        out.fill_(0 if red == 2 else float("nan"))
        return out

    # Accumulate in float32, then write the (possibly averaged) result into out.
    partial = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, partial, numel, BLOCK_SIZE=block
    )
    if red == 2:
        out.fill_(partial.to(dtype=input.dtype))
    else:
        out.fill_((partial / float(numel)).to(dtype=input.dtype))
    return out