Coverage for src/flag_gems/experimental_ops/sin_.py: 0%
36 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def sin_(
    x_ptr,  # Pointer to the tensor; read and overwritten in place.
    n_elements,  # Total number of elements in the flat buffer.
    BLOCK_SIZE: tl.constexpr,  # Elements handled by a single program.
):
    """Elementwise in-place sine over a contiguous 1-D buffer."""
    program = tl.program_id(axis=0)
    idx = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    # Compute in fp32 (helps fp16/bf16 accuracy), then cast back to the
    # input dtype before the masked store.
    result = tl.sin(vals.to(tl.float32)).to(vals.dtype)
    tl.store(x_ptr + idx, result, mask=in_bounds)
# Keep a reference to the Triton kernel before defining the Python wrapper with the same name.
# The wrapper defined next shadows the module-level name `sin_`, so kernel
# launches must go through this alias instead.
sin__kernel = sin_
def sin_(*args, **kwargs):
    """Compute sine in place, mirroring the calling convention of ``aten.sin_``.

    The tensor may be passed positionally or via the ``input``/``self``
    keyword. fp16/bf16/fp32 CUDA tensors are processed by the Triton
    kernel; any other dtype (e.g. float64, complex) falls back to
    PyTorch's own in-place sine.

    Returns:
        The same tensor, mutated in place.

    Raises:
        ValueError: if no tensor is supplied, the tensor is not on a
            CUDA device, or it is not contiguous.
    """
    # Extract the tensor argument similar to aten.sin_
    if args:
        x = args[0]
    else:
        x = kwargs.get("input", kwargs.get("self", None))
    if x is None:
        raise ValueError("sin_ expects a tensor as the first argument")

    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device")
    if not x.is_contiguous():
        raise ValueError(
            "Input tensor must be contiguous for this Triton implementation"
        )

    # Fallback for unsupported dtypes (e.g., float64, complex). The dtype
    # allowlist below already implies floating point, so no separate
    # is_floating_point() check is needed. Use the documented Tensor.sin_()
    # method rather than the undocumented torch.sin_ function form.
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        x.sin_()
        return x

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    sin__kernel[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x