Coverage for src/flag_gems/experimental_ops/diag.py: 0%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _diag_extract_kernel(
    a_ptr, out_ptr, i0, j0, L, stride_row, stride_col, BLOCK_SIZE: tl.constexpr
):
    # Copy L elements of a diagonal of a strided 2D buffer into a flat buffer.
    #
    # a_ptr:       base pointer of the source matrix.
    # out_ptr:     base pointer of the destination; element t is stored at
    #              out_ptr + t.
    # i0, j0:      row / column of the first diagonal element.
    # L:           number of diagonal elements to copy.
    # stride_row,
    # stride_col:  element strides of the source along rows / columns.
    # BLOCK_SIZE:  number of diagonal elements handled by one program.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < L

    # tl.arange yields int32; widen before multiplying by the strides so the
    # flat source index cannot wrap around for matrices whose element offsets
    # exceed the int32 range.
    offs64 = offs.to(tl.int64)
    a_idx = (i0 + offs64) * stride_row + (j0 + offs64) * stride_col

    vals = tl.load(a_ptr + a_idx, mask=mask)
    tl.store(out_ptr + offs, vals, mask=mask)

17 

18 

@triton.jit
def _diag_write_kernel(
    v_ptr, out_ptr, i0, j0, N, stride_row, stride_col, BLOCK_SIZE: tl.constexpr
):
    # Scatter N elements of a flat buffer onto a diagonal of a strided 2D buffer.
    #
    # v_ptr:       base pointer of the source; element t is loaded from
    #              v_ptr + t.
    # out_ptr:     base pointer of the destination matrix.
    # i0, j0:      row / column of the first diagonal element.
    # N:           number of elements to write.
    # stride_row,
    # stride_col:  element strides of the destination along rows / columns.
    # BLOCK_SIZE:  number of elements handled by one program.
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < N

    # tl.arange yields int32; widen before multiplying by the strides so the
    # flat destination index cannot wrap around once the output matrix holds
    # more elements than the int32 range can address.
    offs64 = offs.to(tl.int64)
    out_idx = (i0 + offs64) * stride_row + (j0 + offs64) * stride_col

    vals = tl.load(v_ptr + offs, mask=mask)
    tl.store(out_ptr + out_idx, vals, mask=mask)

30 

31 

def diag(*args, **kwargs):
    """Build a diagonal matrix from a 1D tensor, or extract a diagonal of a 2D one.

    Args:
        *args: ``(input,)`` or ``(input, diagonal)``. ``input`` must be a
            ``torch.Tensor``. A positional ``diagonal`` is used only when it
            is an ``int``.
        **kwargs: ``diagonal`` is read from here when no ``int`` was supplied
            positionally, and only when it is an ``int``. Otherwise the
            diagonal defaults to 0.

    Returns:
        For a 1D input of ``N`` elements: a new zero-filled square tensor of
        side ``N + abs(diagonal)`` with the input written on the requested
        diagonal. For a 2D input: a new 1D tensor holding the requested
        diagonal (empty when the diagonal lies outside the matrix).

    Raises:
        TypeError: if the first positional argument is missing or not a tensor.
        RuntimeError: if the input is neither 1D nor 2D.
    """
    # Parse input tensor and diagonal
    if len(args) < 1 or not isinstance(args[0], torch.Tensor):
        raise TypeError("diag expects a torch.Tensor as the first positional argument")
    src = args[0]
    diagonal = 0
    if len(args) >= 2 and isinstance(args[1], int):
        diagonal = args[1]
    elif "diagonal" in kwargs and isinstance(kwargs["diagonal"], int):
        diagonal = kwargs["diagonal"]

    # A positive diagonal starts in row 0, a negative one in column 0.
    k = int(diagonal)
    i0 = max(0, -k)
    j0 = max(0, k)

    if src.dim() == 1:
        n = src.numel()
        size = n + abs(k)
        out = torch.zeros((size, size), dtype=src.dtype, device=src.device)
        if n > 0:
            # The write kernel loads the source at unit stride, so a strided
            # 1D view (e.g. x[::2]) must be compacted first.
            src = src.contiguous()
            stride_row, stride_col = out.stride()
            grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
            _diag_write_kernel[grid](
                src, out, i0, j0, n, stride_row, stride_col, BLOCK_SIZE=1024
            )
        return out

    if src.dim() == 2:
        rows, cols = src.shape
        # Clamp to zero when the requested diagonal is outside the matrix.
        length = max(min(rows - i0, cols - j0), 0)
        out = torch.empty((length,), dtype=src.dtype, device=src.device)
        if length > 0:
            # The source strides are forwarded, so any 2D layout is supported.
            stride_row, stride_col = src.stride()
            grid = lambda meta: (triton.cdiv(length, meta["BLOCK_SIZE"]),)
            _diag_extract_kernel[grid](
                src, out, i0, j0, length, stride_row, stride_col, BLOCK_SIZE=1024
            )
        return out

    raise RuntimeError("diag expects a 1D or 2D tensor")

74 

75 

def diag_out(*args, **kwargs):
    """Variant of ``diag`` that writes its result into a caller-provided tensor.

    Supported call shapes:
        - ``diag_out(input, diagonal, out)``
        - ``diag_out(out, input, diagonal)``
        - ``diag_out(input, diagonal, out=...)``
        - ``diag_out(input, out=..., diagonal=...)``

    The two three-positional forms are told apart by which positions hold
    tensors and which holds an ``int``. A ``diagonal`` keyword that is an
    ``int`` overrides any positional diagonal.

    Returns:
        ``out``. For a 1D input it is zero-filled and the input is written on
        the requested diagonal. For a 2D input it receives the requested
        diagonal.

    Raises:
        TypeError: if the input tensor or ``out`` could not be determined.
        RuntimeError: if ``out`` has the wrong shape, dtype or device, or if
            the input is neither 1D nor 2D.
    """
    src = None
    out = None
    diagonal = 0

    # Extract out from kwargs if provided
    if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
        out = kwargs["out"]

    # Try positional interpretations
    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
        three_positional = out is None and len(args) >= 3
        if (
            three_positional
            and isinstance(args[1], int)
            and isinstance(args[2], torch.Tensor)
        ):
            # (input, diagonal, out)
            src, diagonal, out = args[0], int(args[1]), args[2]
        elif (
            three_positional
            and isinstance(args[1], torch.Tensor)
            and isinstance(args[2], int)
        ):
            # (out, input, diagonal)
            out, src, diagonal = args[0], args[1], int(args[2])
        else:
            # Fallback: treat first tensor as input
            src = args[0]
            if len(args) >= 2 and isinstance(args[1], int):
                diagonal = int(args[1])

    # Override diagonal from kwargs if provided
    if "diagonal" in kwargs and isinstance(kwargs["diagonal"], int):
        diagonal = int(kwargs["diagonal"])

    if src is None or out is None:
        raise TypeError("diag_out expects input tensor, diagonal, and out tensor")

    # A positive diagonal starts in row 0, a negative one in column 0.
    k = int(diagonal)
    i0 = max(0, -k)
    j0 = max(0, k)

    if src.dim() == 1:
        n = src.numel()
        size = n + abs(k)

        if out.dim() != 2 or out.shape[0] != size or out.shape[1] != size:
            raise RuntimeError(
                f"diag_out: expected out shape ({size}, {size}), got {tuple(out.shape)}"
            )
        if out.dtype != src.dtype or out.device != src.device:
            raise RuntimeError("diag_out: out dtype/device must match input")

        # Zero-fill out and write diagonal
        if out.numel() > 0:
            out.zero_()
        if n > 0:
            # The write kernel loads the source at unit stride, so a strided
            # 1D view must be compacted first. The destination strides are
            # forwarded, so `out` may have any 2D layout.
            src = src.contiguous()
            stride_row, stride_col = out.stride()
            grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
            _diag_write_kernel[grid](
                src, out, i0, j0, n, stride_row, stride_col, BLOCK_SIZE=1024
            )
        return out

    if src.dim() == 2:
        rows, cols = src.shape
        # Clamp to zero when the requested diagonal is outside the matrix.
        length = max(min(rows - i0, cols - j0), 0)

        if out.dim() != 1 or out.numel() != length:
            raise RuntimeError(
                f"diag_out: expected out shape ({length},), got {tuple(out.shape)}"
            )
        if out.dtype != src.dtype or out.device != src.device:
            raise RuntimeError("diag_out: out dtype/device must match input")

        if length > 0:
            # The extract kernel stores at unit stride. When `out` is a
            # strided view, extract into a contiguous scratch tensor and copy
            # the result across afterwards.
            if out.is_contiguous():
                dst = out
            else:
                dst = torch.empty((length,), dtype=out.dtype, device=out.device)
            stride_row, stride_col = src.stride()
            grid = lambda meta: (triton.cdiv(length, meta["BLOCK_SIZE"]),)
            _diag_extract_kernel[grid](
                src, dst, i0, j0, length, stride_row, stride_col, BLOCK_SIZE=1024
            )
            if dst is not out:
                out.copy_(dst)
        return out

    raise RuntimeError("diag_out expects a 1D or 2D input tensor")