Coverage for src/flag_gems/ops/atan.py: 94%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

7 

8_atan = tl_extra_shim.atan 

9logger = logging.getLogger(__name__) 

10 

11 

# Elementwise arctangent kernel.
#
# The input is cast to ``tl.float32`` and then passed to ``_atan`` (the
# module-level alias of ``tl_extra_shim.atan``).
# NOTE(review): the "INT_TO_FLOAT" promotion rule for argument 0 presumably
# makes integer inputs yield a floating-point result -- confirm against
# ``pointwise_dynamic``.
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def atan_kernel(x):
    x_fp32 = x.to(tl.float32)
    return _atan(x_fp32)

16 

17 

def atan(A):
    """Return the result of running ``atan_kernel`` on ``A``.

    A debug-level log record ("GEMS ATAN") is emitted before the kernel is
    invoked.

    Args:
        A: Input forwarded unchanged to ``atan_kernel``.

    Returns:
        Whatever ``atan_kernel(A)`` returns.
    """
    logger.debug("GEMS ATAN")
    return atan_kernel(A)

22 

23 

def atan_(A):
    """Run ``atan_kernel`` on ``A`` using ``A`` itself as the output, then return ``A``.

    A debug-level log record ("GEMS ATAN_") is emitted before the kernel is
    invoked.

    Args:
        A: Input that is also passed as the kernel's ``out0`` argument.

    Returns:
        The same ``A`` object that was passed in.
    """
    logger.debug("GEMS ATAN_")
    # ``A`` is supplied as both the input and ``out0``; presumably the kernel
    # writes its result back into ``A`` -- the kernel's own return value is
    # deliberately ignored and ``A`` is returned instead.
    atan_kernel(A, out0=A)
    return A