Coverage for src/flag_gems/ops/tril.py: 51%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import torch
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
@triton.jit
def _tril_kernel(
    in_ptr, out_ptr, M, N, B, diag, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Write the lower-triangular part of one BLOCK_M x BLOCK_N tile.

    Program ids 0 and 1 select the tile's row and column block; program id 2
    selects which M x N matrix of the batch is processed. Elements are
    addressed as ``batch * M * N + row * N + col``. ``B`` is accepted to keep
    the launch signature stable but is not read inside the kernel.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    # Widen to int64 up front: with 32-bit arithmetic `row * N`, `batch * M * N`
    # and `row + diag` can wrap for tensors with more than 2**31 - 1 elements
    # or for very large diagonals, producing wrong addresses or a wrong mask.
    offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]).to(tl.int64)
    offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]).to(tl.int64)
    mask = (offs_m < M) & (offs_n < N)

    base = pid_b.to(tl.int64) * M * N
    idxs = base + offs_m * N + offs_n

    x = tl.load(in_ptr + idxs, mask=mask, other=0)
    # Keep an element when its column is at or left of the shifted diagonal.
    keep = offs_n <= (offs_m + diag)
    y = tl.where(keep, x, 0)
    tl.store(out_ptr + idxs, y, mask=mask)
def tril(input: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    """Return a new tensor holding the lower-triangular part of ``input``.

    The last two dimensions are treated as the matrix dimensions; all leading
    dimensions are flattened into a batch. Elements whose column index is
    greater than ``row + diagonal`` are replaced by zero.

    Args:
        input: Tensor with at least two dimensions. It is made contiguous
            before the kernel launch.
        diagonal: Offset of the diagonal to keep up to; converted with
            ``int()`` before being passed to the kernel.

    Returns:
        A freshly allocated tensor created with ``torch.empty_like`` on the
        contiguous input.

    Raises:
        AssertionError: If ``input`` has fewer than two dimensions.
    """
    assert input.dim() >= 2, "Input tensor must have at least 2 dimensions"

    input = input.contiguous()
    out = torch.empty_like(input)

    # Return before computing the batch count: when either matrix dimension
    # is zero, `M * N` is zero and the floor division below would raise
    # ZeroDivisionError instead of yielding an empty result.
    if input.numel() == 0:
        return out

    M = input.size(-2)
    N = input.size(-1)
    B = input.numel() // (M * N)

    BLOCK_M = 32
    BLOCK_N = 32
    # NOTE(review): the batch count is used as grid axis 2; CUDA devices cap
    # that axis at 65535, so very large batch counts may fail to launch --
    # confirm against the target backends.
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), B)

    with torch_device_fn.device(input.device):
        _tril_kernel[grid](
            input,
            out,
            M,
            N,
            B,
            int(diagonal),
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            num_warps=4,
        )
    return out
def tril_out(input: torch.Tensor, diagonal: int = 0, out: torch.Tensor = None):
    """Compute the lower-triangular part of ``input`` and store it in ``out``.

    Args:
        input: Source tensor, forwarded to ``tril``.
        diagonal: Diagonal offset, forwarded to ``tril``.
        out: Destination tensor. When ``None``, a tensor is allocated with
            ``torch.empty_like(input)``.

    Returns:
        The destination tensor after the result has been copied into it.

    Raises:
        AssertionError: If the destination's shape or dtype differs from
            that of ``input``.
    """
    destination = torch.empty_like(input) if out is None else out
    assert destination.shape == input.shape, "Input and output must have the same shape"
    assert destination.dtype == input.dtype, "Input and output must have the same dtype"
    # The result is computed into a temporary and then copied over, so the
    # destination keeps its own storage.
    destination.copy_(tril(input, diagonal))
    return destination