Coverage for src/flag_gems/experimental_ops/triu.py: 0%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def triu_kernel(
    in_ptr,
    out_ptr,
    M,
    N,
    B,  # matrix rows, cols, number of batches
    stride_in_b,
    stride_in_m,
    stride_in_n,
    stride_out_b,
    stride_out_m,
    stride_out_n,
    diagonal: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Upper-triangular copy: writes out[b, m, n] = in[b, m, n] when
    # n - m >= diagonal, else 0.  One program instance handles one
    # BLOCK_M x BLOCK_N tile of one batch matrix (3D launch grid).
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    # Row/column indices covered by this tile.
    row = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    col = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # In-bounds mask for the (possibly ragged) edge tiles.
    mask = (row[:, None] < M) & (col[None, :] < N)

    # keep if col - row >= diagonal
    keep = (col[None, :] - row[:, None]) >= diagonal

    # Widen to int64 before multiplying by strides so offsets cannot
    # overflow 32 bits on large tensors.
    row_i64 = row[:, None].to(tl.int64)
    col_i64 = col[None, :].to(tl.int64)

    base_in = in_ptr + pid_b.to(tl.int64) * stride_in_b
    base_out = out_ptr + pid_b.to(tl.int64) * stride_out_b

    in_offsets = row_i64 * stride_in_m + col_i64 * stride_in_n
    out_offsets = row_i64 * stride_out_m + col_i64 * stride_out_n

    # Masked-out (below-diagonal) elements load as 0 via `other=0`; the
    # store uses the bounds mask only, so below-diagonal entries are
    # written as zeros rather than left uninitialized.
    vals = tl.load(base_in + in_offsets, mask=mask & keep, other=0)
    tl.store(base_out + out_offsets, vals, mask=mask)

46 

47 

48def _check_supported_dtype(t: torch.Tensor): 

49 if t.dtype in ( 

50 torch.complex64, 

51 torch.complex128, 

52 torch.complex32 if hasattr(torch, "complex32") else None, 

53 ): 

54 raise TypeError( 

55 "Complex dtypes are not supported by this Triton triu implementation." 

56 ) 

57 

58 

def _launch_triu_kernel(inp: torch.Tensor, out: torch.Tensor, diagonal: int):
    """Launch ``triu_kernel`` over a (possibly batched) input.

    The last two dimensions of ``inp`` are treated as an M x N matrix and
    all leading dimensions are flattened into a single batch axis.

    Args:
        inp: CUDA tensor with at least 2 dimensions.
        out: CUDA tensor of the same dtype/device; receives the result
            (copied back if ``out`` itself is non-contiguous).
        diagonal: diagonal offset, as in ``torch.triu``.
    """
    assert inp.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert inp.dtype == out.dtype, "Input and output dtypes must match"
    assert inp.device == out.device, "Input and output must be on the same device"
    _check_supported_dtype(inp)

    ndim = inp.dim()
    assert ndim >= 2, "triu expects input with at least 2 dimensions"

    M = inp.shape[-2]
    N = inp.shape[-1]
    B = 1
    for s in inp.shape[:-2]:
        B *= s

    # Work on contiguous copies so the flattened-batch addressing is valid.
    inp_c = inp.contiguous()
    out_c = out.contiguous()

    # ``.contiguous()`` guarantees row-major layout, so the strides are known
    # in closed form.  The previous per-case derivation was dead code (always
    # overridden by the ``is_contiguous()`` branch) and its
    # ``stride(-3) * size(-3)`` expression would have produced a wrong batch
    # stride for inputs with more than one batch dimension.
    stride_in_n = stride_out_n = 1
    stride_in_m = stride_out_m = N
    stride_in_b = stride_out_b = M * N

    BLOCK_M = 32
    BLOCK_N = 32

    # One program per (row tile, col tile, batch matrix).
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), B)

    triu_kernel[grid](
        inp_c,
        out_c,
        M,
        N,
        B,
        stride_in_b,
        stride_in_m,
        stride_in_n,
        stride_out_b,
        stride_out_m,
        stride_out_n,
        diagonal=diagonal,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )

    # If ``out`` was non-contiguous, ``out_c`` is a fresh buffer: copy the
    # result back into the caller's tensor.
    if out.data_ptr() != out_c.data_ptr():
        out.copy_(out_c)

136 

137 

def triu(input: torch.Tensor, diagonal: int = 0):
    """
    Wrapper for ATen op: ('triu', <Autograd.disable: False>)

    Returns the upper-triangular part of ``input`` in a freshly
    allocated tensor of the same shape and dtype.
    """
    result = torch.empty_like(input)
    _launch_triu_kernel(input, result, diagonal)
    return result

145 

146 

def triu_out(input: torch.Tensor, diagonal: int = 0, out: "torch.Tensor | None" = None):
    """
    Wrapper for ATen op: ('triu.out', <Autograd.disable: False>)

    Writes the upper-triangular part of ``input`` into ``out`` (allocated
    if omitted) and returns ``out``.

    Raises:
        ValueError: if ``out`` has a different shape, or is not a CUDA
            tensor on the same device as ``input``.
        TypeError: if ``out`` has a different dtype.
    """
    # Fixed annotation: the parameter defaults to ``None``, so it must be
    # typed as optional, not as a plain ``torch.Tensor``.
    if out is None:
        out = torch.empty_like(input)
    else:
        if out.shape != input.shape:
            raise ValueError(
                f"out tensor must have the same shape as input, got {out.shape} vs {input.shape}"
            )
        if out.dtype != input.dtype:
            raise TypeError(
                f"out dtype must match input dtype, got {out.dtype} vs {input.dtype}"
            )
        if not out.is_cuda or out.device != input.device:
            raise ValueError("out must be a CUDA tensor on the same device as input")
    _launch_triu_kernel(input, out, diagonal)
    return out