Coverage for src/flag_gems/ops/glu.py: 79%
34 statements
coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

logger = logging.getLogger(__name__)
exp = tl_extra_shim.exp


@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def glu_kernel(a, b):
    # GLU forward: a * sigmoid(b); the sigmoid is computed in float32 for numerical stability.
    sigmoid_b = 1 / (1 + exp(-b.to(tl.float32)))
    result = a * sigmoid_b
    return result


@pointwise_dynamic(
    promotion_methods=[
        (0, 1, 2, "DEFAULT"),
        (0, 1, 2, "DEFAULT"),
    ]
)
@triton.jit
def glu_backward_kernel(grad_output, a, b):
    # GLU backward: da = grad * sigmoid(b), db = grad * a * sigmoid(b) * (1 - sigmoid(b)).
    sigmoid_b = 1 / (1 + exp(-b.to(tl.float32)))
    da = grad_output * sigmoid_b
    db = grad_output.to(tl.float32) * a * sigmoid_b * (1.0 - sigmoid_b)
    return da, db


def glu(self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Split dimension must be even"
    logger.debug("GEMS GLU FORWARD")
    # Split the input into a and b along dim
    a, b = torch.chunk(self, 2, dim=dim)
    out = glu_kernel(a, b)
    return out


def glu_backward(grad_output, self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Split dimension must be even"
    logger.debug("GEMS GLU BACKWARD")
    # Recreate a and b from the saved input
    a, b = torch.chunk(self, 2, dim=dim)
    # Write both gradient halves directly into one contiguous buffer, then view it in halves.
    grad_input = torch.empty_like(self, memory_format=torch.contiguous_format)
    grad_a, grad_b = torch.chunk(grad_input, 2, dim=dim)
    glu_backward_kernel(grad_output, a, b, out0=grad_a, out1=grad_b)
    return grad_input
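A minimal usage sketch, not part of the covered source: it checks the forward and backward paths above against torch.nn.functional.glu and autograd, assuming flag_gems is installed, a CUDA-capable device is available, and the module path mirrors src/flag_gems/ops/glu.py.

import torch
import torch.nn.functional as F

from flag_gems.ops.glu import glu, glu_backward

x = torch.randn(4, 8, device="cuda")

# Forward: a * sigmoid(b) should match the PyTorch reference.
out = glu(x, dim=-1)
ref = F.glu(x, dim=-1)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)

# Backward: compare the fused kernel's grad_input with autograd on the reference op.
grad_out = torch.randn_like(out)
grad_in = glu_backward(grad_out, x, dim=-1)

x_ref = x.clone().requires_grad_(True)
F.glu(x_ref, dim=-1).backward(grad_out)
torch.testing.assert_close(grad_in, x_ref.grad, rtol=1e-4, atol=1e-4)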