Coverage for src/flag_gems/experimental_ops/sinc.py: 0%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def sinc_kernel_fp32(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise normalized sinc in fp32: sin(pi*x) / (pi*x), defined as 1 at x == 0."""
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)  # fp32 input lane values
    scaled = vals * 3.141592653589793            # pi * x
    # 0/0 at x == 0 yields NaN in `ratio`; the tl.where below selects 1.0 there
    # lane-wise, so the NaN never reaches the output.
    ratio = tl.sin(scaled) / scaled
    result = tl.where(vals == 0.0, 1.0, ratio)

    tl.store(out_ptr + idx, result, mask=in_bounds)

19 

20 

def sinc(input: torch.Tensor) -> torch.Tensor:
    """Compute the normalized sinc, sin(pi*x)/(pi*x), elementwise.

    The computation always runs in float32 and the result is cast back to
    the input's floating dtype; non-floating (e.g. integer) inputs yield a
    float32 result, matching torch.sinc's type promotion for the common
    cases.

    NOTE(review): float64 inputs are computed in fp32 and lose precision —
    behavior kept from the original implementation.
    """
    x_fp32 = input.contiguous().to(torch.float32)
    out_fp32 = torch.empty_like(x_fp32)
    n_elements = x_fp32.numel()

    # Robustness fix: skip the kernel launch entirely for empty tensors
    # instead of launching a zero-size grid.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sinc_kernel_fp32[grid](x_fp32, out_fp32, n_elements, BLOCK_SIZE=1024)

    if input.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        return out_fp32.to(input.dtype)
    return out_fp32

32 

33 

def sinc_out(input: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Out variant: write the normalized sinc of `input` into `out`.

    The computation runs in float32; the result is converted to `out`'s
    dtype before the final in-place copy. Returns `out`.

    NOTE(review): `out` is expected to match `input`'s shape (copy_ will
    raise on incompatible shapes) — verify against callers.
    """
    x_fp32 = input.contiguous().to(torch.float32)
    out_fp32 = torch.empty_like(x_fp32)
    n_elements = x_fp32.numel()

    # Robustness fix: skip the kernel launch entirely for empty tensors
    # instead of launching a zero-size grid.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sinc_kernel_fp32[grid](x_fp32, out_fp32, n_elements, BLOCK_SIZE=1024)

    out.copy_(out_fp32.to(out.dtype))
    return out