Coverage for src/flag_gems/experimental_ops/tril.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

from typing import Optional

import torch
import triton
import triton.language as tl

4 

5 

@triton.jit
def _tril_kernel(
    in_ptr, out_ptr, M, N, B, diag, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Copy one (BLOCK_M, BLOCK_N) tile of a batch matrix, zeroing every
    element strictly above the diagonal shifted by ``diag``.

    Launched on a 3-D grid: (row tiles, column tiles, batch index). ``B`` is
    accepted for interface symmetry with the launcher but the batch index is
    taken from the grid, so it is not read here.
    """
    tile_m = tl.program_id(0)
    tile_n = tl.program_id(1)
    batch = tl.program_id(2)

    # Row/column indices covered by this tile, shaped for 2-D broadcasting.
    rows = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    cols = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    in_bounds = (rows < M) & (cols < N)

    # Flat offsets into the contiguous (B, M, N) buffer.
    offsets = batch * M * N + rows * N + cols

    vals = tl.load(in_ptr + offsets, mask=in_bounds, other=0)
    # Keep entries on or below the (possibly shifted) diagonal.
    below = cols <= (rows + diag)
    tl.store(out_ptr + offsets, tl.where(below, vals, 0), mask=in_bounds)

25 

26 

27def _launch_tril_kernel(input: torch.Tensor, out: torch.Tensor, diagonal: int): 

28 assert input.is_cuda and out.is_cuda, "Input and output must be CUDA tensors" 

29 assert ( 

30 input.is_contiguous() and out.is_contiguous() 

31 ), "Only contiguous tensors are supported" 

32 assert input.shape == out.shape, "Input and output must have the same shape" 

33 assert input.dtype == out.dtype, "Input and output must have the same dtype" 

34 

35 if input.dim() < 2: 

36 out.copy_(input) 

37 return out 

38 

39 M = input.size(-2) 

40 N = input.size(-1) 

41 B = input.numel() // (M * N) 

42 

43 if M == 0 or N == 0 or B == 0: 

44 # Nothing to compute 

45 return out 

46 

47 BLOCK_M = 32 

48 BLOCK_N = 32 

49 grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N), B) 

50 

51 _tril_kernel[grid]( 

52 input, 

53 out, 

54 M, 

55 N, 

56 B, 

57 int(diagonal), 

58 BLOCK_M=BLOCK_M, 

59 BLOCK_N=BLOCK_N, 

60 num_warps=4, 

61 ) 

62 return out 

63 

64 

def tril(input: torch.Tensor, diagonal: int = 0):
    """Return a new tensor holding the lower-triangular part of ``input``,
    with elements above the ``diagonal``-shifted diagonal set to zero."""
    result = torch.empty_like(input)
    _launch_tril_kernel(input, result, diagonal)
    return result

68 

69 

def tril_out(
    input: torch.Tensor, diagonal: int = 0, out: Optional[torch.Tensor] = None
):
    """Out-of-place / in-place tril: fill ``out`` with the lower-triangular
    part of ``input`` and return it, allocating ``out`` when not supplied.

    Fix vs. original: the annotation ``out: torch.Tensor = None`` was an
    invalid implicit-Optional (PEP 484); it is now ``Optional[torch.Tensor]``.
    Runtime behavior is unchanged.
    """
    if out is None:
        out = torch.empty_like(input)
    _launch_tril_kernel(input, out, diagonal)
    return out