Coverage for src/flag_gems/experimental_ops/sinc.py: 0%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def sinc_kernel_fp32(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise normalized sinc in fp32: sin(pi*x) / (pi*x), with sinc(0) = 1."""
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)  # fp32
    scaled = vals * 3.141592653589793
    ratio = tl.sin(scaled) / scaled
    # At x == 0 the division is 0/0 (NaN); select the limit value 1.0 there.
    result = tl.where(vals == 0.0, 1.0, ratio)

    tl.store(out_ptr + idx, result, mask=in_bounds)
def sinc(input: torch.Tensor) -> torch.Tensor:
    """Compute the normalized sinc, sin(pi*x)/(pi*x) with sinc(0) = 1, elementwise.

    The computation always runs in float32 on a contiguous copy of the input,
    then the result is cast back to the input's dtype for floating inputs.

    NOTE(review): float64 inputs are computed in float32 and cast back, which
    silently loses precision relative to torch.sinc — confirm this is intended.

    Args:
        input: tensor of any shape; assumed to live on a Triton-capable device.

    Returns:
        Tensor of the same shape as ``input``; same dtype as ``input`` for
        float16/bfloat16/float32/float64 inputs, float32 otherwise.
    """
    x_fp32 = input.contiguous().to(torch.float32)
    out_fp32 = torch.empty_like(x_fp32)
    n_elements = x_fp32.numel()
    # Skip the launch entirely for empty tensors: there is nothing to compute,
    # and a zero-size grid launch is wasted work (and rejected by some
    # Triton versions).
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sinc_kernel_fp32[grid](x_fp32, out_fp32, n_elements, BLOCK_SIZE=1024)

    if input.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
        return out_fp32.to(input.dtype)
    return out_fp32
def sinc_out(input: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Out-variant of :func:`sinc`: write sin(pi*x)/(pi*x) (sinc(0)=1) into ``out``.

    The computation runs in float32 on a contiguous copy of ``input``; the
    result is then copied into ``out`` with a cast to ``out``'s dtype.

    Args:
        input: tensor of any shape; assumed to live on a Triton-capable device.
        out: destination tensor; must be copy-compatible with the result shape.

    Returns:
        ``out``, after being filled with the computed values.
    """
    x_fp32 = input.contiguous().to(torch.float32)
    out_fp32 = torch.empty_like(x_fp32)
    n_elements = x_fp32.numel()
    # Skip the launch entirely for empty tensors, matching sinc(): nothing to
    # compute, and a zero-size grid launch is rejected by some Triton versions.
    if n_elements > 0:
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        sinc_kernel_fp32[grid](x_fp32, out_fp32, n_elements, BLOCK_SIZE=1024)

    out.copy_(out_fp32.to(out.dtype))
    return out