Coverage for src/flag_gems/ops/glu.py: 79%

34 statements  

coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic, tl_extra_shim

logger = logging.getLogger(__name__)
exp = tl_extra_shim.exp


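# Forward: out = a * sigmoid(b). b is upcast to float32 before exp so the
# sigmoid stays accurate for half-precision inputs.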

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def glu_kernel(a, b):
    sigmoid_b = 1 / (1 + exp(-b.to(tl.float32)))
    result = a * sigmoid_b

    return result


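# Gradients of y = a * sigmoid(b), using sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)):
#   dL/da = grad_output * sigmoid(b)
#   dL/db = grad_output * a * sigmoid(b) * (1 - sigmoid(b))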

@pointwise_dynamic(
    promotion_methods=[
        (0, 1, 2, "DEFAULT"),
        (0, 1, 2, "DEFAULT"),
    ]
)
@triton.jit
def glu_backward_kernel(grad_output, a, b):
    sigmoid_b = 1 / (1 + exp(-b.to(tl.float32)))
    da = grad_output * sigmoid_b
    db = grad_output.to(tl.float32) * a * sigmoid_b * (1.0 - sigmoid_b)

    return da, db



def glu(self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Split dimension must be even"
    logger.debug("GEMS GLU FORWARD")
    # Split into a and b
    a, b = torch.chunk(self, 2, dim=dim)
    out = glu_kernel(a, b)

    return out



def glu_backward(grad_output, self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Split dimension must be even"
    logger.debug("GEMS GLU BACKWARD")
    # Recreate a and b
    a, b = torch.chunk(self, 2, dim=dim)
    grad_input = torch.empty_like(self, memory_format=torch.contiguous_format)
    grad_a, grad_b = torch.chunk(grad_input, 2, dim=dim)
    glu_backward_kernel(grad_output, a, b, out0=grad_a, out1=grad_b)

    return grad_input
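
As a quick sanity check, both entry points can be compared against PyTorch's reference implementation. The snippet below is a minimal sketch, not part of the file above: it assumes a CUDA device and that the functions are importable as flag_gems.ops.glu.glu / glu_backward (the import path mirrors the file path in the report header); the tensor shape is illustrative.

import torch
import torch.nn.functional as F

from flag_gems.ops.glu import glu, glu_backward

x = torch.randn(4, 8, device="cuda", requires_grad=True)

# Forward: should match torch.nn.functional.glu.
out = glu(x, dim=-1)
ref = F.glu(x, dim=-1)
torch.testing.assert_close(out, ref)

# Backward: should match autograd applied to the reference result.
grad_out = torch.randn_like(ref)
(ref_grad,) = torch.autograd.grad(ref, x, grad_out)
grad_in = glu_backward(grad_out, x.detach(), dim=-1)
torch.testing.assert_close(grad_in, ref_grad)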