Coverage for src/flag_gems/modules/activation.py: 73%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import torch.nn as nn 

5 

6import flag_gems 

7 

# Module-level logger, namespaced by this module's import path.
logger = logging.getLogger(__name__)

# Disable C extension for now, as we have not implemented c++ wrapper for silu_and_mul yet.
has_c_extension = False

# Public API of this module.
__all__ = [
    "gems_silu_and_mul",
    "GemsSiluAndMul",
]

16 

17 

def gems_silu_and_mul(
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Apply the fused SiLU-and-mul custom op from flag_gems.

    Equivalent to ``torch.mul(torch.nn.functional.silu(x), y)``, but
    dispatched to the fused flag_gems kernel in a single call.

    Args:
        x: Tensor passed through the SiLU activation.
        y: Tensor multiplied element-wise with the activated ``x``.

    Returns:
        The fused result tensor produced by ``flag_gems.silu_and_mul``.
    """
    logger.debug("GEMS CUSTOM SILU_AND_MUL FORWARD")
    # TODO: Implement C++ wrapper for silu_and_mul
    result = flag_gems.silu_and_mul(x, y)
    return result

25 

26 

class GemsSiluAndMul(nn.Module):
    """Fused SiLU-and-mul activation module.

    Computes ``torch.mul(torch.nn.functional.silu(x), y)`` in a single
    fused kernel by delegating to :func:`gems_silu_and_mul`.
    """

    def __init__(self):
        # No parameters or buffers; just initialize the nn.Module machinery.
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Run the fused activation: ``silu(x) * y``."""
        return gems_silu_and_mul(x, y)