Coverage for src/flag_gems/experimental_ops/diag.py: 0%
112 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _diag_extract_kernel(
    a_ptr, out_ptr, i0, j0, L, stride_row, stride_col, BLOCK_SIZE: tl.constexpr
):
    """Copy ``L`` elements of one diagonal of a strided 2-D buffer into ``out_ptr``.

    Element ``t`` (``0 <= t < L``) is read from
    ``a_ptr + (i0 + t) * stride_row + (j0 + t) * stride_col`` and written to
    ``out_ptr + t``; i.e. the destination is addressed with unit stride.
    Each program instance handles one ``BLOCK_SIZE``-wide slice of ``t``.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the diagonal must neither load nor store.
    mask = offs < L
    # tl.arange produces int32; widen before multiplying by the strides so the
    # flattened source index cannot wrap around on very large matrices.
    offs64 = offs.to(tl.int64)
    a_idx = (i0 + offs64) * stride_row + (j0 + offs64) * stride_col
    vals = tl.load(a_ptr + a_idx, mask=mask)
    tl.store(out_ptr + offs, vals, mask=mask)
@triton.jit
def _diag_write_kernel(
    v_ptr, out_ptr, i0, j0, N, stride_row, stride_col, BLOCK_SIZE: tl.constexpr
):
    """Scatter ``N`` values from ``v_ptr`` onto one diagonal of a strided 2-D buffer.

    Element ``t`` (``0 <= t < N``) is read from ``v_ptr + t`` (unit stride) and
    written to ``out_ptr + (i0 + t) * stride_row + (j0 + t) * stride_col``.
    Only the diagonal entries are written; every other element of the
    destination is left as the caller prepared it.
    Each program instance handles one ``BLOCK_SIZE``-wide slice of ``t``.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the vector must neither load nor store.
    mask = offs < N
    # tl.arange produces int32; widen before multiplying by the strides so the
    # flattened destination index cannot wrap around on very large outputs
    # (an output of ~46341 x 46341 elements already exceeds 2**31).
    offs64 = offs.to(tl.int64)
    out_idx = (i0 + offs64) * stride_row + (j0 + offs64) * stride_col
    vals = tl.load(v_ptr + offs, mask=mask)
    tl.store(out_ptr + out_idx, vals, mask=mask)
def diag(*args, **kwargs):
    """Build a diagonal matrix from a vector, or extract a diagonal from a matrix.

    Call as ``diag(input, diagonal=0)``; both arguments may be given
    positionally or by keyword.

    Args:
        input: A 1-D or 2-D ``torch.Tensor``.
        diagonal: Integer offset ``k`` of the diagonal. ``k > 0`` selects a
            diagonal above the main one, ``k < 0`` one below it. Any object
            implementing ``__index__`` is accepted.

    Returns:
        For a 1-D input of length ``N``: a new ``(N + |k|, N + |k|)`` tensor
        that is zero everywhere except on the ``k``-th diagonal, which holds
        ``input``. For a 2-D input: a new 1-D tensor holding the ``k``-th
        diagonal (empty when that diagonal lies outside the matrix).

    Raises:
        TypeError: If ``input`` is not a tensor or ``diagonal`` is not an
            integer.
        RuntimeError: If ``input`` is neither 1-D nor 2-D.
    """
    import operator

    input = args[0] if args else kwargs.get("input")
    if not isinstance(input, torch.Tensor):
        raise TypeError("diag expects a torch.Tensor as the first positional argument")

    # A positional offset takes precedence over the keyword one.
    raw_diagonal = args[1] if len(args) >= 2 else kwargs.get("diagonal", 0)
    try:
        k = operator.index(raw_diagonal)
    except TypeError:
        # Refuse non-integers instead of silently falling back to the main
        # diagonal, which would return a plausible-looking but wrong result.
        raise TypeError(
            f"diag: diagonal must be an integer, got {type(raw_diagonal).__name__}"
        ) from None

    # (i0, j0) is the first (row, col) coordinate of the k-th diagonal.
    i0 = max(0, -k)
    j0 = max(0, k)

    if input.dim() == 1:
        n = input.numel()
        size = n + abs(k)
        out = torch.zeros((size, size), dtype=input.dtype, device=input.device)
        if n > 0:
            # The write kernel reads its source with unit stride, so a strided
            # view (e.g. x[::2]) must be materialised contiguously first.
            src = input.contiguous()
            stride_row, stride_col = out.stride()
            grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
            _diag_write_kernel[grid](
                src, out, i0, j0, n, stride_row, stride_col, BLOCK_SIZE=1024
            )
        return out

    if input.dim() == 2:
        rows, cols = input.shape
        # Clamp to zero: the requested diagonal may lie entirely outside.
        length = max(min(rows - i0, cols - j0), 0)
        out = torch.empty((length,), dtype=input.dtype, device=input.device)
        if length > 0:
            # The extract kernel honours the input strides, so no copy is needed.
            stride_row, stride_col = input.stride()
            grid = lambda meta: (triton.cdiv(length, meta["BLOCK_SIZE"]),)
            _diag_extract_kernel[grid](
                input, out, i0, j0, length, stride_row, stride_col, BLOCK_SIZE=1024
            )
        return out

    raise RuntimeError("diag expects a 1D or 2D tensor")
def diag_out(*args, **kwargs):
    """Write ``diag(input, diagonal)`` into a caller-provided ``out`` tensor.

    Supported call shapes:
        - ``diag_out(input, diagonal, out)``
        - ``diag_out(out, input, diagonal)``
        - ``diag_out(input, diagonal, out=...)``
        - ``diag_out(input, out=..., diagonal=...)``
        - ``diag_out(input=..., out=..., diagonal=...)``

    Args:
        input: A 1-D or 2-D ``torch.Tensor``.
        diagonal: Integer offset ``k`` of the diagonal (default 0). A keyword
            ``diagonal`` overrides a positional one.
        out: Destination tensor. Must already have the result shape
            (``(N + |k|, N + |k|)`` for 1-D input, ``(L,)`` for 2-D input)
            and the same dtype and device as ``input``. It is not resized.

    Returns:
        ``out`` itself, after being filled.

    Raises:
        TypeError: If ``input`` or ``out`` is missing, or ``diagonal`` is not
            an integer.
        RuntimeError: If ``out`` has the wrong shape, dtype or device, or
            ``input`` is neither 1-D nor 2-D.
    """
    import operator

    def _to_offset(value):
        # Refuse non-integers instead of silently using the main diagonal.
        try:
            return operator.index(value)
        except TypeError:
            raise TypeError(
                f"diag_out: diagonal must be an integer, got {type(value).__name__}"
            ) from None

    input = None
    diagonal = 0
    out = kwargs.get("out")
    if not isinstance(out, torch.Tensor):
        out = None

    if args and isinstance(args[0], torch.Tensor):
        # The position of the int among the first three positionals decides
        # between (input, diagonal, out) and (out, input, diagonal).
        has_three = out is None and len(args) >= 3
        if (
            has_three
            and isinstance(args[1], int)
            and isinstance(args[2], torch.Tensor)
        ):
            input, diagonal, out = args[0], int(args[1]), args[2]
        elif (
            has_three
            and isinstance(args[1], torch.Tensor)
            and isinstance(args[2], int)
        ):
            out, input, diagonal = args[0], args[1], int(args[2])
        else:
            # Fallback: the first tensor is the input, optionally followed by
            # the offset.
            input = args[0]
            if len(args) >= 2 and not isinstance(args[1], torch.Tensor):
                diagonal = _to_offset(args[1])
    elif not args and isinstance(kwargs.get("input"), torch.Tensor):
        input = kwargs["input"]

    # A keyword offset overrides whatever was found positionally.
    if "diagonal" in kwargs:
        diagonal = _to_offset(kwargs["diagonal"])

    if input is None or out is None:
        raise TypeError("diag_out expects input tensor, diagonal, and out tensor")

    k = int(diagonal)
    # (i0, j0) is the first (row, col) coordinate of the k-th diagonal.
    i0 = max(0, -k)
    j0 = max(0, k)

    if input.dim() == 1:
        n = input.numel()
        size = n + abs(k)

        if out.dim() != 2 or out.shape[0] != size or out.shape[1] != size:
            raise RuntimeError(
                f"diag_out: expected out shape ({size}, {size}), got {tuple(out.shape)}"
            )
        if out.dtype != input.dtype or out.device != input.device:
            raise RuntimeError("diag_out: out dtype/device must match input")

        # Zero-fill out, then overwrite just the diagonal entries.
        if out.numel() > 0:
            out.zero_()
        if n > 0:
            # The write kernel reads its source with unit stride, so a strided
            # view must be materialised contiguously first. The destination is
            # addressed through its own strides and needs no such treatment.
            src = input.contiguous()
            stride_row, stride_col = out.stride()
            grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
            _diag_write_kernel[grid](
                src, out, i0, j0, n, stride_row, stride_col, BLOCK_SIZE=1024
            )
        return out

    if input.dim() == 2:
        rows, cols = input.shape
        # Clamp to zero: the requested diagonal may lie entirely outside.
        length = max(min(rows - i0, cols - j0), 0)

        if out.dim() != 1 or out.numel() != length:
            raise RuntimeError(
                f"diag_out: expected out shape ({length},), got {tuple(out.shape)}"
            )
        if out.dtype != input.dtype or out.device != input.device:
            raise RuntimeError("diag_out: out dtype/device must match input")

        if length > 0:
            # The extract kernel stores with unit stride. For a strided ``out``
            # write into a contiguous scratch buffer and copy it back, so the
            # elements land where the caller's view expects them.
            if out.is_contiguous():
                dst = out
            else:
                dst = torch.empty((length,), dtype=out.dtype, device=out.device)
            stride_row, stride_col = input.stride()
            grid = lambda meta: (triton.cdiv(length, meta["BLOCK_SIZE"]),)
            _diag_extract_kernel[grid](
                input, dst, i0, j0, length, stride_row, stride_col, BLOCK_SIZE=1024
            )
            if dst is not out:
                out.copy_(dst)
        return out

    raise RuntimeError("diag_out expects a 1D or 2D input tensor")