Coverage for src/flag_gems/experimental_ops/cos_.py: 0%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def cos_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Triton kernel: overwrite a flat buffer with the cosine of its elements.

    Each program instance handles one run of ``BLOCK_SIZE`` consecutive
    elements starting at ``program_id * BLOCK_SIZE``.

    Args:
        x_ptr: Pointer to the first element; values are read from and written
            back to the same addresses.
        n_elements: Number of valid elements; lanes at or beyond this index
            are masked out of both the load and the store.
        BLOCK_SIZE: Compile-time number of elements per program instance.
    """
    # Promote the program id to 64-bit before scaling so the element offsets
    # do not wrap around in 32-bit arithmetic on very large buffers.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # The dtype comparison is resolved at compile time, so only one branch is
    # emitted per specialization.
    if x.dtype == tl.float64:
        # Keep double precision: a detour through float32 would throw away
        # roughly half of the mantissa bits.
        y = tl.cos(x)
    else:
        # Half-precision inputs are evaluated in float32 and rounded back.
        y = tl.cos(x.to(tl.float32)).to(x.dtype)
    tl.store(x_ptr + offsets, y, mask=mask)


# Preserve a reference to the kernel before the wrapper defined below rebinds
# the name ``cos_``.
cos__kernel = cos_

21 

22 

def cos_(*args, **kwargs):
    """Compute the element-wise cosine of a tensor in place.

    Expects a single tensor, similar to ``torch.ops.aten.cos_``, given either
    as the only positional argument or as the keyword argument ``input``.

    Tensors that are not on a CUDA device, are not contiguous, or whose dtype
    is not one of float16 / bfloat16 / float32 / float64 are delegated to
    ``torch.cos_``.  Everything else is processed by the Triton kernel bound
    to ``cos__kernel``.

    Returns:
        The tensor that was passed in, after being overwritten.

    Raises:
        TypeError: If no tensor was supplied in one of the accepted forms.
    """
    if len(args) == 1 and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif isinstance(kwargs.get("input"), torch.Tensor):
        x = kwargs["input"]
    else:
        raise TypeError(
            "cos_ expects a single Tensor argument (positional or keyword 'input')."
        )

    # Fallback to PyTorch for unsupported cases.
    supported_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if not x.is_cuda or not x.is_contiguous() or x.dtype not in supported_dtypes:
        return torch.cos_(x)

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; also avoids launching a kernel with an empty grid.
        return x

    block_size = 1024
    grid = (triton.cdiv(n_elements, block_size),)
    # Launch on the device that owns the tensor, which is not necessarily the
    # current CUDA device in a multi-GPU process.
    with torch.cuda.device(x.device):
        cos__kernel[grid](x, n_elements, BLOCK_SIZE=block_size)
    return x