Coverage for src/flag_gems/ops/acos.py: 93%

14 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

7 

# Module-level logger, used by the public wrapper to trace op dispatch.
logger = logging.getLogger(__name__)
# acos primitive taken from flag_gems.utils.tl_extra_shim. It is aliased once at
# import time so the Triton kernel below can call it by a short global name.
_acos = tl_extra_shim.acos

10 

11 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit()
def acos_kernel(x):
    # Scalar body of the pointwise acos op. pointwise_dynamic wraps it into a
    # launchable kernel. Argument 0 uses the "INT_TO_FLOAT" promotion rule
    # (presumably integer inputs yield a float result — confirm in flag_gems).
    #
    # The computation is always carried out in fp32: the input is cast before
    # calling the shim's acos.
    # NOTE(review): if float64 tensors reach this kernel, they are downcast to
    # fp32 too, which caps result precision — confirm this is intended.
    # NOTE(review): an earlier TODO asked to use tl_extra_shim helper APIs. That
    # looks satisfied, since `_acos` is already tl_extra_shim.acos.
    x_fp32 = x.to(tl.float32)
    return _acos(x_fp32)

17 

18 

def acos(x):
    """Compute the arc cosine of ``x`` with the Triton ``acos_kernel``.

    Emits a debug log line identifying the op, then returns whatever
    ``acos_kernel(x)`` produces (presumably a new tensor — confirm against
    pointwise_dynamic).
    """
    logger.debug("GEMS ACOS FORWARD")
    return acos_kernel(x)