Coverage for src/flag_gems/experimental_ops/maximum.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-11 02:28 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5MAX_DIMS = 8 

6BLOCK_SIZE = 1024 

7 

8 

# Elementwise maximum over two broadcast-compatible, up-to-8-D tensors.
#
# Each program instance covers BLOCK_SIZE consecutive linear output
# indices. Every linear index is unraveled into 8 per-dimension indices
# (row-major, last dimension fastest) and re-projected through each
# tensor's element strides, so 0-stride broadcast views and
# non-contiguous layouts are handled without materializing copies.
@triton.jit
def maximum_kernel(
    a_ptr,  # pointer to first input tensor
    b_ptr,  # pointer to second input tensor
    out_ptr,  # pointer to output tensor
    n_elements,  # total number of output elements
    s0,
    s1,
    s2,
    s3,
    s4,
    s5,
    s6,
    s7,  # shape dims (host pads unused trailing dims to 1)
    sa0,
    sa1,
    sa2,
    sa3,
    sa4,
    sa5,
    sa6,
    sa7,  # a strides (in elements; 0 for broadcast dims)
    sb0,
    sb1,
    sb2,
    sb3,
    sb4,
    sb5,
    sb6,
    sb7,  # b strides (in elements; 0 for broadcast dims)
    so0,
    so1,
    so2,
    so3,
    so4,
    so5,
    so6,
    so7,  # out strides (in elements)
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Lanes past n_elements are masked out of both loads and the store
    # below, so any garbage offsets they compute are never dereferenced.
    mask = offsets < n_elements

    # Use int64 for address calculations (avoids 32-bit overflow on
    # large tensors when multiplying indices by strides)
    li = offsets.to(tl.int64)

    # Compute multi-dimensional indices from linear index (row-major: last dim fastest)
    # Padded dims have size 1, so their index is always 0 and the
    # corresponding stride term contributes nothing.
    i7 = li % s7
    li = li // s7
    i6 = li % s6
    li = li // s6
    i5 = li % s5
    li = li // s5
    i4 = li % s4
    li = li // s4
    i3 = li % s3
    li = li // s3
    i2 = li % s2
    li = li // s2
    i1 = li % s1
    li = li // s1
    i0 = li % s0
    li = li // s0  # final quotient is unused

    # Compute element offsets for each tensor using strides (in elements)
    off_a = (
        i0 * sa0
        + i1 * sa1
        + i2 * sa2
        + i3 * sa3
        + i4 * sa4
        + i5 * sa5
        + i6 * sa6
        + i7 * sa7
    )
    off_b = (
        i0 * sb0
        + i1 * sb1
        + i2 * sb2
        + i3 * sb3
        + i4 * sb4
        + i5 * sb5
        + i6 * sb6
        + i7 * sb7
    )
    off_o = (
        i0 * so0
        + i1 * so1
        + i2 * so2
        + i3 * so3
        + i4 * so4
        + i5 * so5
        + i6 * so6
        + i7 * so7
    )

    # other=0 only fills masked lanes, which are never stored
    a_vals = tl.load(a_ptr + off_a, mask=mask, other=0)
    b_vals = tl.load(b_ptr + off_b, mask=mask, other=0)
    out_vals = tl.maximum(a_vals, b_vals)
    tl.store(out_ptr + off_o, out_vals, mask=mask)

111 

112 

113def _as_tensor_on_device(x, device, dtype=None): 

114 if torch.is_tensor(x): 

115 return ( 

116 x.to(device=device, dtype=dtype) 

117 if (dtype is not None and x.dtype != dtype) or (x.device != device) 

118 else x 

119 ) 

120 return torch.tensor(x, device=device, dtype=dtype) 

121 

122 

123def _broadcast_to_common(a, b): 

124 a_b, b_b = torch.broadcast_tensors(a, b) 

125 return a_b, b_b 

126 

127 

def _pad_shape_strides(shape, strides):
    """Pad *shape* and *strides* out to MAX_DIMS entries for the kernel.

    Trailing dims are padded with size 1 (keeps the kernel's unravel
    divisions safe from div-by-zero) and stride 0 (the index of a
    size-1 dim is always 0, so the stride value never contributes).
    Values are coerced to plain ``int`` because the Triton launch
    expects Python integers, not e.g. ``torch.Size`` elements.

    Returns:
        (shape_list, strides_list): two lists of MAX_DIMS ints.

    Raises:
        ValueError: if *shape* has more than MAX_DIMS dimensions.
            (Was a bare ``assert``, which vanishes under ``python -O``.)
    """
    nd = len(shape)
    if nd > MAX_DIMS:
        raise ValueError(
            f"tensor has {nd} dims; at most {MAX_DIMS} are supported"
        )
    pad = MAX_DIMS - nd
    shape_list = [int(s) for s in shape] + [1] * pad
    strides_list = [int(s) for s in strides] + [0] * pad
    return shape_list, strides_list

140 

141 

def _launch_maximum_kernel(a, b, out):
    """Broadcast *a*/*b* and launch the Triton kernel writing into *out*.

    Assumes the inputs already share *out*'s dtype and device.
    Broadcast dims keep their 0 strides (no materialization); tensors
    with negative strides are materialized first, since the kernel's
    offset arithmetic assumes non-negative element strides.
    """
    a_exp, b_exp = _broadcast_to_common(a, b)
    if any(st < 0 for st in a_exp.stride()):
        a_exp = a_exp.contiguous()
    if any(st < 0 for st in b_exp.stride()):
        b_exp = b_exp.contiguous()

    total = int(a_exp.numel())
    if total == 0:
        return  # empty result: skip the zero-size launch entirely

    # Pad shape/strides to the fixed 8-dim kernel signature.
    shape8, strides_a = _pad_shape_strides(a_exp.shape, a_exp.stride())
    _, strides_b = _pad_shape_strides(a_exp.shape, b_exp.stride())
    _, strides_o = _pad_shape_strides(a_exp.shape, out.stride())

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    maximum_kernel[grid](
        a_exp,
        b_exp,
        out,
        total,
        *shape8,
        *strides_a,
        *strides_b,
        *strides_o,
        BLOCK_SIZE=BLOCK_SIZE,
    )

203 

204 

def maximum(a, b):
    """Elementwise maximum of *a* and *b* computed by a Triton kernel.

    Accepts tensors and/or Python scalars; at least one input must be a
    CUDA tensor. Inputs are broadcast and promoted per PyTorch's type
    promotion rules; any non-CUDA input is moved to the chosen device.

    Returns:
        A new contiguous CUDA tensor of the broadcast shape.

    Raises:
        ValueError: if neither input is a CUDA tensor.
    """
    # Pick the first CUDA device found among tensor inputs. The previous
    # code used a's device even when only b was on CUDA, which raised
    # despite the error message promising "at least one CUDA tensor".
    dev = None
    for t in (a, b):
        if torch.is_tensor(t) and t.device.type == "cuda":
            dev = t.device
            break
    if dev is None:
        raise ValueError("maximum expects at least one CUDA tensor as input")

    # Determine result dtype per PyTorch promotion rules
    res_dtype = torch.result_type(a, b)
    a_t = _as_tensor_on_device(a, dev, dtype=res_dtype)
    b_t = _as_tensor_on_device(b, dev, dtype=res_dtype)

    # Allocate the broadcast-shaped output. torch.empty always yields a
    # contiguous tensor with non-negative strides, so the old
    # "copy through a contiguous buffer" branch was dead code.
    a_b, b_b = _broadcast_to_common(a_t, b_t)
    out = torch.empty(a_b.shape, device=dev, dtype=res_dtype)
    _launch_maximum_kernel(a_t, b_t, out)
    return out

233 

234 

def maximum_out(a, b, out):
    """Elementwise maximum of *a* and *b* written into *out*.

    Inputs are cast to ``out.dtype`` and moved to ``out.device``
    (matching typical ``.out=`` variant behavior); their broadcast
    shape must equal ``out.shape`` exactly.

    Raises:
        TypeError: if *out* is not a tensor.
        ValueError: if *out* is not on a CUDA device, or the inputs'
            broadcast shape differs from ``out.shape``.
    """
    if not torch.is_tensor(out):
        raise TypeError("out must be a torch.Tensor")
    if out.device.type != "cuda":
        raise ValueError("out tensor must be on CUDA device")

    device = out.device
    lhs = _as_tensor_on_device(a, device, dtype=out.dtype)
    rhs = _as_tensor_on_device(b, device, dtype=out.dtype)

    lhs_b, rhs_b = _broadcast_to_common(lhs, rhs)
    if tuple(lhs_b.shape) != tuple(out.shape):
        raise ValueError(
            f"out shape {tuple(out.shape)} is not broadcast-compatible with inputs shape {tuple(lhs_b.shape)}"
        )

    # The kernel writes through out's strides; when out is non-contiguous
    # or has negative strides, compute into a contiguous scratch buffer
    # and copy back instead.
    writable_directly = out.is_contiguous() and all(s >= 0 for s in out.stride())
    if writable_directly:
        _launch_maximum_kernel(lhs, rhs, out)
    else:
        scratch = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_maximum_kernel(lhs, rhs, scratch)
        out.copy_(scratch)
    return out