Coverage for src/flag_gems/ops/tril.py: 51%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7 

8 

@triton.jit
def _tril_kernel(
    in_ptr, out_ptr, M, N, B, diag, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Copy a stack of contiguous row-major (M, N) matrices from in_ptr to
    # out_ptr, zeroing every element strictly above the ``diag``-th diagonal.
    # Each program handles one BLOCK_M x BLOCK_N tile: axis 0 selects the row
    # tile, axis 1 the column tile, axis 2 the matrix within the batch.
    # ``B`` is part of the launch signature but is not read by the kernel.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    in_bounds = (offs_m < M) & (offs_n < N)

    # Do the address arithmetic in int64: pid_b * M * N (and offs_m * N) can
    # exceed 2**31 - 1 for large inputs and would silently wrap in int32.
    base = pid_b.to(tl.int64) * M * N
    idxs = base + offs_m.to(tl.int64) * N + offs_n

    # Element (i, j) is kept iff j <= i + diag.
    keep = offs_n <= (offs_m + diag)
    # Only read kept elements; masked-off lanes take ``other=0`` and are
    # written back as zeros, so the discarded triangle is never loaded.
    x = tl.load(in_ptr + idxs, mask=in_bounds & keep, other=0)
    tl.store(out_ptr + idxs, x, mask=in_bounds)

28 

29 

def tril(input: torch.Tensor, diagonal: int = 0):
    """Return the lower-triangular part of ``input`` (Triton analogue of ``torch.tril``).

    The last two dimensions are treated as an ``M x N`` matrix and all leading
    dimensions as a batch. Element ``(i, j)`` of each matrix is kept when
    ``j <= i + diagonal`` and set to zero otherwise.

    Args:
        input: Tensor with at least 2 dimensions.
        diagonal: Diagonal offset. 0 selects the main diagonal, positive values
            also keep that many diagonals above it, and negative values also
            zero that many diagonals below it.

    Returns:
        A new contiguous tensor with the same shape and dtype as ``input``.

    Raises:
        AssertionError: If ``input`` has fewer than 2 dimensions.
    """
    assert input.dim() >= 2, "Input tensor must have at least 2 dimensions"

    input = input.contiguous()
    out = torch.empty_like(input)

    # Bail out before computing B: ``numel // (M * N)`` would divide by zero
    # whenever M or N is 0.
    if input.numel() == 0:
        return out

    M = input.size(-2)
    N = input.size(-1)
    B = input.numel() // (M * N)

    BLOCK_M = 32
    BLOCK_N = 32
    # The batch index lives on grid axis 2, which CUDA limits to 65535 blocks,
    # so large batches are launched in chunks of at most this many matrices.
    max_batch_per_launch = 65535
    grid_mn = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    diag = int(diagonal)

    # ``input``/``out`` are contiguous, so viewing them as (B, M, N) is free and
    # slicing along dim 0 yields contiguous sub-tensors starting at that batch.
    in_bmn = input.view(B, M, N)
    out_bmn = out.view(B, M, N)

    with torch_device_fn.device(input.device):
        for start in range(0, B, max_batch_per_launch):
            nb = min(max_batch_per_launch, B - start)
            _tril_kernel[grid_mn + (nb,)](
                in_bmn[start : start + nb],
                out_bmn[start : start + nb],
                M,
                N,
                nb,
                diag,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
                num_warps=4,
            )
    return out

60 

61 

def tril_out(input: torch.Tensor, diagonal: int = 0, out: torch.Tensor = None):
    """Write the lower-triangular part of ``input`` into ``out`` and return it.

    Args:
        input: Tensor with at least 2 dimensions.
        diagonal: Diagonal offset, forwarded unchanged to ``tril``.
        out: Destination tensor. It must match ``input`` in shape and dtype.
            When omitted, a fresh tensor like ``input`` is allocated.

    Returns:
        The destination tensor, after it has been filled.

    Raises:
        AssertionError: If ``out`` differs from ``input`` in shape or dtype.
    """
    target = torch.empty_like(input) if out is None else out
    assert target.shape == input.shape, "Input and output must have the same shape"
    assert target.dtype == input.dtype, "Input and output must have the same dtype"
    # ``tril`` returns a freshly allocated tensor. Copying from it lets
    # ``target`` keep whatever strides it has, and is safe even if it aliases
    # ``input``.
    lowered = tril(input, diagonal)
    target.copy_(lowered)
    return target