Coverage for src/flag_gems/runtime/backend/_ascend/ops/threshold.py: 0%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

# Module logger, namespaced under the Ascend ops package and suffixed with the
# last dotted component of this module's ``__name__``.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)

9 

10 

# Elementwise forward of ``threshold``: keep ``self`` where it is above
# ``threshold``, otherwise substitute ``value``.
# ``self`` is the only tensor argument; ``threshold`` and ``value`` are
# non-tensor (per ``is_tensor``). The output dtype is derived from argument 0
# under DEFAULT promotion (per ``promotion_methods``).
# The condition is written as ``self <= threshold -> value`` rather than
# ``self > threshold -> self``. Every comparison with NaN is false, so this
# form lets NaN inputs propagate unchanged, as PyTorch's threshold does,
# instead of replacing them with ``value``.
@pointwise_dynamic(is_tensor=[True, False, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def threshold_kernel(self, threshold, value):
    return tl.where(self <= threshold, value, self)

15 

16 

# Elementwise gradient of ``threshold`` with respect to its input.
# ``grad_output`` is passed through where the forward input ``self`` was kept,
# and zeroed where the forward substituted the replacement value.
# ``grad_output`` and ``self`` are tensors; ``threshold`` is non-tensor (per
# ``is_tensor``). The output dtype comes from DEFAULT promotion over arguments
# 0 and 1.
# The condition mirrors the forward kernel's ``self <= threshold`` test. A NaN
# input, which the forward passes through, therefore also passes its gradient
# through, matching PyTorch's threshold_backward, instead of being forced to 0.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def threshold_backward_kernel(grad_output, self, threshold):
    return tl.where(self <= threshold, 0, grad_output)

21 

22 

def threshold(self, threshold, value):
    """Ascend entry point for the elementwise ``threshold`` forward.

    Emits a debug log line, then dispatches to ``threshold_kernel``.

    Args:
        self: Input tensor.
        threshold: Cut-off compared against each element (non-tensor
            argument of the kernel; presumably a Python scalar).
        value: Replacement used for elements not above ``threshold``
            (non-tensor argument of the kernel).

    Returns:
        The tensor produced by ``threshold_kernel``.
    """
    logger.debug("GEMS_ASCEND THRESHOLD FORWARD")
    return threshold_kernel(self, threshold, value)

27 

28 

def threshold_backward(grad_output, self, threshold):
    """Ascend entry point for the gradient of ``threshold``.

    Emits a debug log line, then dispatches to ``threshold_backward_kernel``.

    Args:
        grad_output: Upstream gradient tensor.
        self: The forward-pass input tensor.
        threshold: The cut-off used in the forward pass (non-tensor argument
            of the kernel).

    Returns:
        The gradient tensor produced by ``threshold_backward_kernel``.
    """
    logger.debug("GEMS_ASCEND THRESHOLD BACKWARD")
    return threshold_backward_kernel(grad_output, self, threshold)