Coverage for src/flag_gems/experimental_ops/triu.py: 0%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def triu_kernel(
    in_ptr,
    out_ptr,
    M,
    N,
    B,  # matrix rows, cols, number of batches
    stride_in_b,
    stride_in_m,
    stride_in_n,
    stride_out_b,
    stride_out_m,
    stride_out_n,
    diagonal: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the upper-triangular part of each (M, N) matrix to out.

    One program instance handles a BLOCK_M x BLOCK_N tile of one matrix in
    the batch: elements with col - row >= diagonal are copied from the
    input, everything else inside bounds is stored as 0 (via the masked
    load's `other=0` default flowing into an unmasked-within-bounds store).
    Grid layout: axis 0 tiles rows, axis 1 tiles columns, axis 2 is the
    flattened batch index.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    # Row/column indices covered by this tile.
    row = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    col = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # In-bounds mask for partial tiles at the matrix edges.
    mask = (row[:, None] < M) & (col[None, :] < N)

    # keep if col - row >= diagonal
    keep = (col[None, :] - row[:, None]) >= diagonal

    # Promote indices to int64 before the stride multiply so large tensors
    # do not overflow 32-bit offset arithmetic.
    row_i64 = row[:, None].to(tl.int64)
    col_i64 = col[None, :].to(tl.int64)

    # Base pointers of this batch's matrix in input and output.
    base_in = in_ptr + pid_b.to(tl.int64) * stride_in_b
    base_out = out_ptr + pid_b.to(tl.int64) * stride_out_b

    in_offsets = row_i64 * stride_in_m + col_i64 * stride_in_n
    out_offsets = row_i64 * stride_out_m + col_i64 * stride_out_n

    # Load only the kept (upper-triangular) elements; masked-out positions
    # read as 0, so the store below zeroes the strictly-lower triangle.
    vals = tl.load(base_in + in_offsets, mask=mask & keep, other=0)
    tl.store(base_out + out_offsets, vals, mask=mask)
48def _check_supported_dtype(t: torch.Tensor):
49 if t.dtype in (
50 torch.complex64,
51 torch.complex128,
52 torch.complex32 if hasattr(torch, "complex32") else None,
53 ):
54 raise TypeError(
55 "Complex dtypes are not supported by this Triton triu implementation."
56 )
def _launch_triu_kernel(inp: torch.Tensor, out: torch.Tensor, diagonal: int) -> None:
    """Run ``triu_kernel`` over every trailing (M, N) matrix of ``inp``.

    Results are written into ``out`` (same shape/dtype/device as ``inp``).

    Fixes over the previous version:
    - The old batch-stride expressions (``stride(-3) * size(-3)``) were
      incorrect for inputs with more than one batch dimension; they were
      only masked because a later ``is_contiguous()`` branch always
      overrode them after ``.contiguous()``. Since we always operate on
      contiguous copies, the row-major strides are computed directly.
    - Zero-sized inputs (any dim == 0) now return early instead of
      launching a kernel with a zero-sized grid.

    Raises:
        AssertionError: if tensors are not CUDA, or dtype/device mismatch,
            or ``inp`` has fewer than 2 dimensions.
        TypeError: for complex dtypes (via ``_check_supported_dtype``).
    """
    assert inp.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert inp.dtype == out.dtype, "Input and output dtypes must match"
    assert inp.device == out.device, "Input and output must be on the same device"
    _check_supported_dtype(inp)

    ndim = inp.dim()
    assert ndim >= 2, "triu expects input with at least 2 dimensions"

    M = inp.shape[-2]
    N = inp.shape[-1]
    # Flatten all leading dimensions into one batch dimension.
    B = 1
    for s in inp.shape[:-2]:
        B *= s

    # Nothing to compute for empty tensors; also avoids a zero-sized grid.
    if B == 0 or M == 0 or N == 0:
        return

    # Ensure contiguous layout for simplicity; a contiguous tensor viewed
    # as (B, M, N) has plain row-major strides.
    inp_c = inp.contiguous()
    out_c = out.contiguous()

    stride_n = 1
    stride_m = N
    stride_b = M * N

    BLOCK_M = 32
    BLOCK_N = 32

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), B)

    triu_kernel[grid](
        inp_c,
        out_c,
        M,
        N,
        B,
        stride_b,
        stride_m,
        stride_n,
        stride_b,
        stride_m,
        stride_n,
        diagonal=diagonal,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )

    # If ``out`` was non-contiguous, ``out_c`` is a copy; write it back.
    if out.data_ptr() != out_c.data_ptr():
        out.copy_(out_c)
def triu(input: torch.Tensor, diagonal: int = 0):
    """
    Wrapper for ATen op: ('triu', <Autograd.disable: False>)

    Allocates a result tensor shaped like ``input`` and fills it with the
    upper-triangular part of each trailing matrix via the Triton kernel.
    """
    result = torch.empty_like(input)
    _launch_triu_kernel(input, result, diagonal)
    return result
def triu_out(input: torch.Tensor, diagonal: int = 0, out: torch.Tensor = None):
    """
    Wrapper for ATen op: ('triu.out', <Autograd.disable: False>)

    If ``out`` is given it is validated (shape, dtype, device) and written
    in place; otherwise a fresh tensor like ``input`` is allocated.
    """
    if out is not None:
        # Validate the caller-supplied destination before any kernel work.
        if out.shape != input.shape:
            raise ValueError(
                f"out tensor must have the same shape as input, got {out.shape} vs {input.shape}"
            )
        if out.dtype != input.dtype:
            raise TypeError(
                f"out dtype must match input dtype, got {out.dtype} vs {input.dtype}"
            )
        if not out.is_cuda or out.device != input.device:
            raise ValueError("out must be a CUDA tensor on the same device as input")
    else:
        out = torch.empty_like(input)
    _launch_triu_kernel(input, out, diagonal)
    return out