Coverage for src/flag_gems/ops/elu.py: 81%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Module-level logger; the ELU entry points below emit their debug traces here.
logger = logging.getLogger(name=__name__)
# Elementwise ELU forward, following PyTorch's aten::elu definition:
#   out = scale * x                                   if x > 0
#   out = scale * alpha * (exp(input_scale * x) - 1)  otherwise
# `input_scale` only scales the argument of the exponential. The positive
# branch must NOT be multiplied by it: PyTorch's celu is implemented as
# elu(x, alpha, scale=1, input_scale=1/alpha), and its positive branch is x.
@pointwise_dynamic(
    is_tensor=[True, False, False, False], promotion_methods=[(0, "DEFAULT")]
)
@triton.jit
def elu_forward_kernel(x, alpha, scale, input_scale):
    # The exponential is evaluated on a float32 copy of x, presumably so that
    # half/bfloat16 inputs keep precision inside exp.
    x_fp32 = x.to(tl.float32)
    return tl.where(
        x > 0,
        scale * x,
        scale * alpha * (tl.exp(x_fp32 * input_scale) - 1),
    )
# ELU backward computed from the forward INPUT `x`, following PyTorch's
# elu_backward (is_result=False):
#   grad_in = grad_out * scale                                              if x > 0
#   grad_in = grad_out * input_scale * scale * alpha * exp(input_scale * x)  otherwise
# As in the forward pass, `input_scale` only affects the exponential branch.
@pointwise_dynamic(
    is_tensor=[True, False, False, False, True], promotion_methods=[(0, 4, "DEFAULT")]
)
@triton.jit
def elu_backward_kernel_with_self(grad_output, alpha, scale, input_scale, x):
    # The exponential is evaluated on a float32 copy of x.
    x_fp32 = x.to(tl.float32)
    grad_input = tl.where(
        x > 0,
        grad_output * scale,
        grad_output * (scale * alpha * tl.exp(x_fp32 * input_scale) * input_scale),
    )
    return grad_input
# ELU backward computed from the forward OUTPUT `y`, following PyTorch's
# elu_backward (is_result=True). On the non-positive branch,
# y = scale * alpha * (exp(input_scale * x) - 1), so the derivative
# scale * alpha * input_scale * exp(input_scale * x) equals
# input_scale * (y + scale * alpha), and no exp is needed:
#   grad_in = grad_out * scale                       if y > 0
#   grad_in = grad_out * input_scale * (y + scale*alpha)  otherwise
# Using `y > 0` to select the positive branch is only valid for alpha >= 0.
# With alpha < 0 the exponential branch can also yield positive outputs.
@pointwise_dynamic(
    is_tensor=[True, False, False, False, True], promotion_methods=[(0, 4, "DEFAULT")]
)
@triton.jit
def elu_backward_kernel_with_result(grad_output, alpha, scale, input_scale, y):
    grad_input = tl.where(
        y > 0,
        grad_output * scale,
        grad_output * ((y + scale * alpha) * input_scale),
    )
    return grad_input
def elu(A, alpha=1.0, scale=1.0, input_scale=1.0):
    """Out-of-place ELU.

    Logs a debug trace, then returns the output of ``elu_forward_kernel``
    applied to ``A`` with the given ``alpha``, ``scale`` and ``input_scale``
    coefficients. No output buffer is passed, so ``A`` is not written to.
    """
    logger.debug("GEMS ELU")
    activated = elu_forward_kernel(A, alpha, scale, input_scale)
    return activated
def elu_(A, alpha=1.0, scale=1.0, input_scale=1.0):
    """In-place ELU.

    Runs ``elu_forward_kernel`` with ``out0=A`` so the result is written back
    into ``A``, and returns whatever the kernel call returns.
    """
    logger.debug("GEMS ELU_")
    coefficients = (alpha, scale, input_scale)
    return elu_forward_kernel(A, *coefficients, out0=A)
def elu_backward(grad_output, alpha, scale, input_scale, is_result, self_or_result):
    """Gradient of ELU with respect to its input.

    Args:
        grad_output: gradient flowing back from the ELU output.
        alpha, scale, input_scale: the coefficients used in the forward pass.
        is_result: if true, ``self_or_result`` is the forward OUTPUT;
            otherwise it is the forward INPUT.
        self_or_result: the saved forward output or input (see ``is_result``).

    Returns:
        The gradient produced by the matching backward kernel.

    Raises:
        RuntimeError: if ``is_result`` is true and ``alpha`` is negative, as in
            PyTorch's ``elu_backward``.
    """
    logger.debug("GEMS ELU BACKWARD")
    if is_result:
        # The output-based kernel selects the linear branch with `result > 0`.
        # With a negative alpha the exponential branch can also produce
        # positive values, which would silently give wrong gradients.
        if alpha < 0:
            raise RuntimeError(
                "In-place elu backward calculation is triggered with a negative "
                "slope which is not supported. This is caused by calling in-place "
                "forward function with a negative slope, please call out-of-place "
                "version instead."
            )
        return elu_backward_kernel_with_result(
            grad_output, alpha, scale, input_scale, self_or_result
        )
    return elu_backward_kernel_with_self(
        grad_output, alpha, scale, input_scale, self_or_result
    )