Coverage for src/flag_gems/experimental_ops/soft_margin_loss.py: 0%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _soft_margin_loss_elementwise_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Per-element soft margin loss: out[i] = log(1 + exp(-x[i] * y[i])).

    Each program instance handles one contiguous block of BLOCK_SIZE flat
    offsets; `mask` guards the final partial block. Math is done in float32
    via the numerically stable softplus identity
        log(1 + exp(z)) = max(z, 0) + log(1 + exp(-|z|)),  z = -x * y,
    which avoids overflow in exp() for large |z|.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked lanes load 0.0; their results are never stored (same mask below).
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    # Upcast so half-precision inputs keep accuracy through exp/log.
    xf = x.to(tl.float32)
    yf = y.to(tl.float32)
    z = -xf * yf
    absz = tl.abs(z)
    # Stable softplus(z); exp argument is always <= 0 so it cannot overflow.
    vals = tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-absz))

    # tl.store casts float32 `vals` back to out_ptr's element dtype.
    tl.store(out_ptr + offsets, vals, mask=mask)

25 

26 

@triton.jit
def _soft_margin_loss_sum_kernel(
    x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    """Accumulate sum of log(1 + exp(-x[i] * y[i])) into *out_ptr.

    Each program instance reduces one BLOCK_SIZE chunk to a partial float32
    sum and adds it to the scalar at out_ptr with an atomic add, so out_ptr
    must point at a zero-initialized float32 scalar. NOTE: atomic addition
    order across blocks is unordered, so float results may vary slightly
    between runs.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)

    # Same stable softplus as the elementwise kernel, in float32.
    xf = x.to(tl.float32)
    yf = y.to(tl.float32)
    z = -xf * yf
    absz = tl.abs(z)
    vals = tl.maximum(z, 0.0) + tl.log(1.0 + tl.exp(-absz))
    # Zero out masked lanes: they hold log(2) from the 0.0 fill, which must
    # not leak into the reduction.
    vals = tl.where(mask, vals, 0.0)

    acc = tl.sum(vals, axis=0)
    tl.atomic_add(out_ptr, acc)

48 

49 

50def _normalize_reduction(reduction): 

51 # Accept both string and enum/int forms: 0=none,1=mean,2=sum 

52 if isinstance(reduction, str): 

53 r = reduction.lower() 

54 if r == "none": 

55 return 0 

56 if r == "mean": 

57 return 1 

58 if r == "sum": 

59 return 2 

60 raise ValueError(f"Invalid reduction: {reduction}") 

61 if isinstance(reduction, int): 

62 if reduction in (0, 1, 2): 

63 return reduction 

64 raise ValueError(f"Invalid reduction int: {reduction}") 

65 raise ValueError(f"Unsupported reduction type: {type(reduction)}") 

66 

67 

68def _check_tensors(input: torch.Tensor, target: torch.Tensor): 

69 if not (input.is_cuda and target.is_cuda): 

70 raise AssertionError( 

71 "soft_margin_loss: input and target must be CUDA tensors for Triton kernel." 

72 ) 

73 if input.device != target.device: 

74 raise AssertionError( 

75 "soft_margin_loss: input and target must be on the same device." 

76 ) 

77 if input.numel() != target.numel(): 

78 raise AssertionError( 

79 "soft_margin_loss: input and target must have the same number of elements." 

80 ) 

81 if not input.is_contiguous(): 

82 input = input.contiguous() 

83 if not target.is_contiguous(): 

84 target = target.contiguous() 

85 return input, target 

86 

87 

def soft_margin_loss(input: torch.Tensor, target: torch.Tensor, reduction="mean"):
    """Soft margin loss, computed with Triton kernels.

    loss_i = log(1 + exp(-input_i * target_i)), then reduced per `reduction`.

    Args:
        input: CUDA tensor of scores.
        target: CUDA tensor with the same number of elements as `input`
            (typically +1/-1 labels).
        reduction: "none" | "mean" | "sum", or the integer codes 0/1/2.

    Returns:
        For "none", a tensor shaped like `input`; otherwise a 0-dim scalar
        tensor in `input`'s dtype. Empty input follows PyTorch semantics:
        sum -> 0, mean -> NaN, none -> empty tensor.
    """
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    numel = input.numel()
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    if red == 0:
        # 'none': one loss value per element.
        out = torch.empty_like(input)
        if numel > 0:
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, out, numel, BLOCK_SIZE=1024
            )
        return out

    # 'mean' (1) or 'sum' (2).
    if numel == 0:
        # Match PyTorch: sum over nothing is 0, mean over nothing is NaN.
        fill = 0.0 if red == 2 else float("nan")
        return torch.full((), fill, device=input.device, dtype=input.dtype)

    # Accumulate in float32 regardless of input dtype for accuracy.
    total = torch.zeros((), device=input.device, dtype=torch.float32)
    _soft_margin_loss_sum_kernel[grid](
        input, target, total, numel, BLOCK_SIZE=1024
    )
    if red == 2:
        return total.to(dtype=input.dtype)
    return (total / float(numel)).to(dtype=input.dtype)

127 

128 

def soft_margin_loss_out(
    input: torch.Tensor,
    target: torch.Tensor,
    reduction="mean",
    out: torch.Tensor = None,
):
    """Out-variant of soft_margin_loss: writes the result into `out`.

    Args:
        input: CUDA tensor of scores.
        target: CUDA tensor with the same numel as `input`.
        reduction: "none" | "mean" | "sum", or integer codes 0/1/2.
        out: optional destination tensor. For reduction='none' it must have
            input.numel() elements; otherwise it must hold exactly one
            element. Allocated automatically when None.

    Returns:
        `out`, filled with the elementwise losses (reduction='none') or the
        reduced scalar. Empty input follows PyTorch: sum -> 0, mean -> NaN.

    Raises:
        AssertionError: if `out` is not CUDA, is on the wrong device, or has
            the wrong number of elements.
    """
    input, target = _check_tensors(input, target)
    red = _normalize_reduction(reduction)
    n_elements = input.numel()

    if out is None:
        # Allocate output based on reduction
        if red == 0:
            out = torch.empty_like(input)
        else:
            # 0-dim scalar in the input's dtype.
            out = torch.empty((), device=input.device, dtype=input.dtype)
    else:
        if not out.is_cuda:
            raise AssertionError("soft_margin_loss_out: out must be a CUDA tensor.")
        if red == 0:
            # NOTE(review): only numel is validated here, not shape, dtype, or
            # contiguity. The elementwise kernel writes flat linear offsets, so
            # a non-contiguous `out` would be filled incorrectly — confirm
            # callers always pass a contiguous `out`.
            if out.numel() != n_elements:
                raise AssertionError(
                    "soft_margin_loss_out: for reduction='none', out must match input shape."
                )
        else:
            if out.numel() != 1:
                raise AssertionError(
                    "soft_margin_loss_out: for reduction='sum' or 'mean', out must be a scalar tensor."
                )
        if out.device != input.device:
            raise AssertionError(
                "soft_margin_loss_out: out must be on the same device as input."
            )

    if red == 0:
        if n_elements > 0:
            BLOCK_SIZE = 1024
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            _soft_margin_loss_elementwise_kernel[grid](
                input, target, out, n_elements, BLOCK_SIZE=BLOCK_SIZE
            )
        return out
    else:
        if n_elements == 0:
            # Match PyTorch: sum over empty -> 0, mean over empty -> NaN.
            if red == 2:
                out.fill_(0)
            else:
                out.fill_(float("nan"))
            return out
        # Accumulate in float32 regardless of input dtype; the sum kernel
        # atomically adds partial block sums into this scalar.
        tmp_sum = torch.zeros((), device=input.device, dtype=torch.float32)
        BLOCK_SIZE = 1024
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        _soft_margin_loss_sum_kernel[grid](
            input, target, tmp_sum, n_elements, BLOCK_SIZE=BLOCK_SIZE
        )
        if red == 2:
            # sum; fill_ accepts a 0-dim tensor and casts to out's dtype.
            out.fill_(tmp_sum.to(dtype=input.dtype))
        else:
            # mean
            mean_val = (tmp_sum / float(n_elements)).to(dtype=input.dtype)
            out.fill_(mean_val)
        return out