# Coverage report header (coverage.py v7.6.9, created 2026-03-11 02:28 +0800):
# src/flag_gems/experimental_ops/tril.py — 0% of 43 statements covered.
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _tril_kernel(
    in_ptr, out_ptr, M, N, B, diag, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Write one (BLOCK_M, BLOCK_N) tile of tril(input) to out_ptr.

    Each program instance handles one tile of one batch slice of a
    contiguous (B, M, N) buffer; entries above the `diag`-th diagonal
    are replaced with zero.
    """
    row_block = tl.program_id(0)
    col_block = tl.program_id(1)
    batch = tl.program_id(2)

    # 2-D index grids for this tile: rows as a column vector, cols as a row vector.
    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols = col_block * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    in_bounds = (rows < M) & (cols < N)

    # Flat element offsets into the contiguous batch-major buffer.
    flat = batch * M * N + rows * N + cols

    vals = tl.load(in_ptr + flat, mask=in_bounds, other=0)
    # Keep the lower triangle: column index at most row index shifted by diag.
    lower = cols <= (rows + diag)
    tl.store(out_ptr + flat, tl.where(lower, vals, 0), mask=in_bounds)
def _launch_tril_kernel(input: torch.Tensor, out: torch.Tensor, diagonal: int):
    """Compute tril(input, diagonal) into `out` via the Triton kernel.

    Args:
        input: contiguous CUDA tensor, treated as a batch of (M, N) matrices
            over its last two dimensions.
        out: contiguous CUDA tensor of the same shape and dtype as `input`.
        diagonal: diagonal offset, same meaning as ``torch.tril``.

    Returns:
        `out`, for chaining.

    Note: tensors with fewer than 2 dims are copied through unchanged
    (this wrapper's deliberate pass-through behavior, unlike torch.tril
    which rejects them).
    """
    assert input.is_cuda and out.is_cuda, "Input and output must be CUDA tensors"
    assert (
        input.is_contiguous() and out.is_contiguous()
    ), "Only contiguous tensors are supported"
    assert input.shape == out.shape, "Input and output must have the same shape"
    assert input.dtype == out.dtype, "Input and output must have the same dtype"

    if input.dim() < 2:
        out.copy_(input)
        return out

    M = input.size(-2)
    N = input.size(-1)

    # BUG FIX: bail out on empty matrices *before* computing B — the original
    # evaluated input.numel() // (M * N) first, which raises ZeroDivisionError
    # whenever M == 0 or N == 0.
    if M == 0 or N == 0:
        return out
    B = input.numel() // (M * N)
    if B == 0:
        # Empty batch: nothing to compute.
        return out

    BLOCK_M = 32
    BLOCK_N = 32
    # One program per (row-tile, col-tile, batch-slice).
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), B)

    _tril_kernel[grid](
        input,
        out,
        M,
        N,
        B,
        int(diagonal),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=4,
    )
    return out
def tril(input: torch.Tensor, diagonal: int = 0):
    """Return the lower-triangular part of `input` as a freshly allocated tensor.

    Mirrors ``torch.tril(input, diagonal)`` for contiguous CUDA tensors.
    """
    result = torch.empty_like(input)
    _launch_tril_kernel(input, result, diagonal)
    return result
def tril_out(input: torch.Tensor, diagonal: int = 0, out: torch.Tensor = None):
    """Out-variant of ``tril``: write into `out`, allocating one when omitted.

    Returns the destination tensor.
    """
    target = torch.empty_like(input) if out is None else out
    _launch_tril_kernel(input, target, diagonal)
    return target