Coverage for src/flag_gems/ops/elu.py: 81%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

8logger = logging.getLogger(__name__) 

9 

10 

11@pointwise_dynamic( 

12 is_tensor=[True, False, False, False], promotion_methods=[(0, "DEFAULT")] 

13) 

14@triton.jit 

15def elu_forward_kernel(x, alpha, scale, input_scale): 

16 return tl.where( 

17 x > 0, 

18 scale * input_scale * x, 

19 scale * alpha * (tl.exp(x.to(tl.float32) * input_scale) - 1), 

20 ) 

21 

22 

23@pointwise_dynamic( 

24 is_tensor=[True, False, False, False, True], promotion_methods=[(0, 4, "DEFAULT")] 

25) 

26@triton.jit 

27def elu_backward_kernel_with_self(grad_output, alpha, scale, input_scale, x): 

28 x_fp32 = x.to(tl.float32) 

29 grad_input = tl.where( 

30 x > 0, 

31 grad_output * scale * input_scale, 

32 grad_output * (scale * alpha * tl.exp(x_fp32 * input_scale) * input_scale), 

33 ) 

34 return grad_input 

35 

36 

37@pointwise_dynamic( 

38 is_tensor=[True, False, False, False, True], promotion_methods=[(0, 4, "DEFAULT")] 

39) 

40@triton.jit 

41def elu_backward_kernel_with_result(grad_output, alpha, scale, input_scale, y): 

42 grad_input = tl.where( 

43 y > 0, 

44 grad_output * scale * input_scale, 

45 grad_output * ((y + scale * alpha) * input_scale), 

46 ) 

47 return grad_input 

48 

49 

50def elu(A, alpha=1.0, scale=1.0, input_scale=1.0): 

51 logger.debug("GEMS ELU") 

52 return elu_forward_kernel(A, alpha, scale, input_scale) 

53 

54 

55def elu_(A, alpha=1.0, scale=1.0, input_scale=1.0): 

56 logger.debug("GEMS ELU_") 

57 return elu_forward_kernel(A, alpha, scale, input_scale, out0=A) 

58 

59 

60def elu_backward(grad_output, alpha, scale, input_scale, is_result, self_or_result): 

61 logger.debug("GEMS ELU BACKWARD") 

62 if is_result: 

63 return elu_backward_kernel_with_result( 

64 grad_output, alpha, scale, input_scale, self_or_result 

65 ) 

66 else: 

67 return elu_backward_kernel_with_self( 

68 grad_output, alpha, scale, input_scale, self_or_result 

69 )