Coverage for src/flag_gems/experimental_ops/sin_.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def sin_(
    x_ptr,  # Pointer to the tensor that is read and overwritten in place.
    n_elements,  # Total number of elements in the tensor.
    BLOCK_SIZE: tl.constexpr,  # Elements handled by one program instance.
):
    # In-place elementwise sine: x[i] = sin(x[i]).
    # The math is done in float32 and cast back to the input element type.
    program_index = tl.program_id(axis=0)
    idx = program_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    result = tl.sin(vals.to(tl.float32)).to(vals.dtype)
    tl.store(x_ptr + idx, result, mask=in_bounds)

22 

23 

# Keep a reference to the Triton kernel before the Python wrapper below
# shadows the name `sin_`; the wrapper launches the kernel via this alias.
sin__kernel = sin_

26 

27 

def sin_(*args, **kwargs):
    """In-place elementwise sine, mirroring the calling convention of aten.sin_.

    The target tensor is taken either as the first positional argument or via
    the ``input`` / ``self`` keyword. Supported CUDA dtypes (float16, bfloat16,
    float32) are processed by the Triton kernel; all other dtypes fall back to
    PyTorch's in-place sine.

    Returns:
        The same tensor object, modified in place.

    Raises:
        ValueError: if no tensor is supplied, the tensor is not on a CUDA
            device, or it is not contiguous.
    """
    # Extract the tensor argument the same way aten.sin_ would.
    x = args[0] if args else kwargs.get("input", kwargs.get("self", None))
    if x is None:
        raise ValueError("sin_ expects a tensor as the first argument")

    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device")
    if not x.is_contiguous():
        raise ValueError(
            "Input tensor must be contiguous for this Triton implementation"
        )

    # Fallback for dtypes the kernel does not handle (e.g. float64, complex).
    if not x.is_floating_point() or x.dtype not in (
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ):
        # Use the documented Tensor method rather than the undocumented
        # torch.sin_ namespace function; the effect is identical.
        x.sin_()
        return x

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to launch for an empty tensor.
        return x

    BLOCK_SIZE = 1024

    def grid(meta):
        # One program per BLOCK_SIZE-sized chunk of the flattened tensor.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    sin__kernel[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x