Coverage for src/flag_gems/runtime/backend/_cambricon/ops/glu.py: 0% (35 statements)
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import tl_extra_shim

from ..utils.pointwise_dynamic import pointwise_dynamic

logger = logging.getLogger(__name__)
exp = tl_extra_shim.exp


@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def glu_kernel(a, b):
    sigmoid_b = 1 / (1 + exp(-b.to(tl.float32)))
    result = a * sigmoid_b
    return result
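

# Note (added comment): glu_kernel computes GLU(a, b) = a * sigmoid(b). The
# sigmoid is evaluated as 1 / (1 + exp(-b)) in float32 so that half-precision
# inputs do not lose accuracy inside the exponential.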


@pointwise_dynamic(
    promotion_methods=[
        (0, 1, 2, "DEFAULT"),
        (0, 1, 2, "DEFAULT"),
    ]
)
@triton.jit
def glu_backward_kernel(grad_output, a, b):
    sigmoid_b = 1 / (1 + exp(-b.to(tl.float32)))
    da = grad_output * sigmoid_b
    db = grad_output.to(tl.float32) * a * sigmoid_b * (1.0 - sigmoid_b)
    return da, db
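

# Derivation (added comment): with s = sigmoid(b), the forward output is
# out = a * s, so
#     d(out)/d(a) = s               =>  da = grad_output * s
#     d(out)/d(b) = a * s * (1 - s) =>  db = grad_output * a * s * (1 - s)
# using sigmoid'(b) = sigmoid(b) * (1 - sigmoid(b)), which is exactly what the
# kernel above computes.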


def glu(self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Split dimension must be even"
    logger.debug("GEMS_CAMBRICON GLU FORWARD")
    # Split into a and b
    a, b = torch.chunk(self, 2, dim=dim)
    out = glu_kernel(a, b)
    return out
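

# Shape semantics (added comment): for an input of shape (..., 2*H) with
# dim=-1, torch.chunk yields a and b of shape (..., H) each, so the output has
# shape (..., H), matching torch.nn.functional.glu. E.g. a (4, 8) input gives
# a (4, 4) output.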


def glu_backward(grad_output, self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Split dimension must be even"
    logger.debug("GEMS_CAMBRICON GLU BACKWARD")
    # Recreate a and b
    a, b = torch.chunk(self, 2, dim=dim)
    # Write both gradient halves directly into one contiguous grad_input buffer
    grad_input = torch.empty_like(self, memory_format=torch.contiguous_format)
    grad_a, grad_b = torch.chunk(grad_input, 2, dim=dim)
    glu_backward_kernel(grad_output, a, b, out0=grad_a, out1=grad_b)
    return grad_input
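

# A minimal sanity-check sketch (added; not part of the original module). It
# assumes a device on which this Cambricon backend runs -- the "mlu" device
# string below is an assumption based on the torch_mlu plugin -- and compares
# the wrappers above against PyTorch's eager GLU reference.
if __name__ == "__main__":
    device = "mlu"  # assumed Cambricon device string (via torch_mlu)
    x = torch.randn(4, 8, device=device)
    out = glu(x, dim=-1)
    ref = torch.nn.functional.glu(x, dim=-1)
    torch.testing.assert_close(out, ref)

    # Backward: feed the same upstream gradient to both implementations.
    grad_out = torch.randn_like(ref)
    grad_in = glu_backward(grad_out, x, dim=-1)
    x_ref = x.detach().clone().requires_grad_(True)
    torch.nn.functional.glu(x_ref, dim=-1).backward(grad_out)
    torch.testing.assert_close(grad_in, x_ref.grad)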