Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/glu.py: 0%

35 statements  

coverage.py v7.6.9, created at 2026-03-15 02:11 +0800

import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import tl_extra_shim

from ..utils.pointwise_dynamic import pointwise_dynamic

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
exp = tl_extra_shim.exp



@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def glu_kernel(a, b):
    # GLU forward: out = a * sigmoid(b); the sigmoid is computed in float32
    # for numerical stability regardless of the input dtype.
    sigmoid_b = 1 / (1 + tl.exp(-b.to(tl.float32)))
    result = a * sigmoid_b

    return result
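# Reference semantics (illustrative note, not part of the original file): for an
# input x split in half along `dim` into (a, b), GLU(x) = a * sigmoid(b), the
# same contract as torch.nn.functional.glu. A minimal host-side check, assuming
# a tensor on whatever device this backend targets, might look like:
#
#   x = torch.randn(4, 8)
#   torch.testing.assert_close(glu(x), torch.nn.functional.glu(x, dim=-1))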



@pointwise_dynamic(
    promotion_methods=[
        (0, 1, 2, "DEFAULT"),
        (0, 1, 2, "DEFAULT"),
    ]
)
@triton.jit
def glu_backward_kernel(grad_output, a, b):
    # s = sigmoid(b), computed in float32; both gradients follow from out = a * s.
    sigmoid_b = 1 / (1 + tl.exp(-b.to(tl.float32)))
    da = grad_output * sigmoid_b
    db = grad_output.to(tl.float32) * a * sigmoid_b * (1.0 - sigmoid_b)

    return da, db
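# Chain rule, written out (reference derivation): with s = sigmoid(b) and
# out = a * s, using ds/db = s * (1 - s):
#
#   da = grad_output * d(out)/da = grad_output * s
#   db = grad_output * d(out)/db = grad_output * a * s * (1 - s)
#
# which matches the two values glu_backward_kernel returns.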



def glu(self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Size of the split dimension must be even"
    logger.debug("GEMS GLU FORWARD")
    # Split the input into the value half `a` and the gate half `b`.
    a, b = torch.chunk(self, 2, dim=dim)
    out = glu_kernel(a, b)

    return out
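# Usage sketch (hypothetical shapes, for illustration only): chunking halves the
# size of `dim`, so the output is half as wide as the input along that axis.
#
#   x = torch.randn(4, 8)        # on the device this backend targets
#   y = glu(x, dim=-1)           # y.shape == (4, 4)
#   y2 = glu(x, dim=0)           # y2.shape == (2, 8)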



def glu_backward(grad_output, self, dim=-1):
    assert self.shape[dim] % 2 == 0, "Size of the split dimension must be even"
    logger.debug("GEMS GLU BACKWARD")
    # Re-split the saved input exactly as the forward pass did.
    a, b = torch.chunk(self, 2, dim=dim)
    # Allocate the full gradient once, then view its two halves so the kernel
    # can write grad_a and grad_b in place.
    grad_input = torch.empty_like(self, memory_format=torch.contiguous_format)
    grad_a, grad_b = torch.chunk(grad_input, 2, dim=dim)
    glu_backward_kernel(grad_output, a, b, out0=grad_a, out1=grad_b)

    return grad_input
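# Verification sketch (an assumption, not shipped with this file): the
# hand-written backward can be cross-checked against autograd on the reference
# op, with tensors on the target device.
#
#   x = torch.randn(4, 8, requires_grad=True)
#   ref = torch.nn.functional.glu(x, dim=-1)
#   g = torch.ones_like(ref)
#   (ref_grad,) = torch.autograd.grad(ref, x, g)
#   torch.testing.assert_close(glu_backward(g, x.detach(), dim=-1), ref_grad)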