Coverage for src/flag_gems/modules/activation.py: 73%
15 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
import logging

import torch
import torch.nn as nn

import flag_gems
# Module-level logger, named after this module so logging configuration
# can target it hierarchically.
logger = logging.getLogger(__name__)

# C-extension dispatch is disabled for now: the C++ wrapper for
# silu_and_mul has not been implemented yet, so the pure flag_gems
# Python path is always used.
has_c_extension = False
# Public API of this module.
__all__ = [
    "gems_silu_and_mul",
    "GemsSiluAndMul",
]
def gems_silu_and_mul(
    x: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Fused SiLU-and-mul: return ``silu(x) * y`` via a flag_gems kernel.

    Computes ``torch.mul(torch.nn.functional.silu(x), y)`` in a single
    fused operation by delegating to ``flag_gems.silu_and_mul``.

    Args:
        x: Tensor passed through the SiLU activation.
        y: Tensor multiplied element-wise with ``silu(x)``.

    Returns:
        The fused result tensor.
    """
    logger.debug("GEMS CUSTOM SILU_AND_MUL FORWARD")
    # TODO: Implement C++ wrapper for silu_and_mul and route through it
    # when available (see `has_c_extension`).
    return flag_gems.silu_and_mul(x, y)
class GemsSiluAndMul(nn.Module):
    """Fused SiLU-and-mul activation module.

    Computes ``torch.mul(torch.nn.functional.silu(x), y)`` in a fused
    way by delegating to :func:`gems_silu_and_mul`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``silu(x) * y`` computed by the fused flag_gems kernel."""
        return gems_silu_and_mul(x, y)