Coverage for src/flag_gems/modules/normalization.py: 43% (49 statements)
# LayerNorm-related implementation.
# References:
# - PyTorch: https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/normalization.py#L321
# - vLLM: https://github.com/vllm-project/vllm/blob/v0.8.5/vllm/model_executor/layers/layernorm.py#L82
# - TransformerEngine:
#   https://github.com/NVIDIA/TransformerEngine/blob/v2.2.1/transformer_engine/pytorch/module/rmsnorm.py#L16
#
# Design notes:
# - Aligns with PyTorch's RMSNorm interface for compatibility.
# - Works with or without the flag_gems C++ wrappers.
# - Avoids relying on torch.nn.RMSNorm (introduced in v2.4.0) for broader compatibility.
import logging
import numbers
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Size
from torch.nn import Parameter, init

import flag_gems
from flag_gems.config import use_c_extension

logger = logging.getLogger(__name__)

__all__ = [
    "gems_rms_forward",
    "GemsRMSNorm",
]
def gems_rms_forward(
    x: torch.Tensor, residual: Optional[torch.Tensor], weight: torch.Tensor, eps: float
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
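    """
    Apply RMSNorm using the flag_gems kernels.

    If ``residual`` is provided, performs fused residual addition followed by
    RMSNorm and returns the ``(x, residual)`` pair; otherwise returns the
    normalized tensor. Dispatches to the C-extension ops when
    ``use_c_extension`` is set, and to the flag_gems Python ops otherwise.
    """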
    add_residual = residual is not None
    if add_residual:
        if use_c_extension:
            logger.debug("GEMS CUSTOM FUSED_ADD_RMS_NORM(C EXTENSION)")
            torch.ops.flag_gems.fused_add_rms_norm(x, residual, weight, eps)
            return x, residual
        else:
            logger.debug("GEMS CUSTOM FUSED_ADD_RMS_NORM")
            return flag_gems.fused_add_rms_norm(
                x, residual, list(weight.size()), weight, eps
            )
    else:
        if use_c_extension:
            logger.debug("GEMS CUSTOM RMS_NORM(C EXTENSION)")
            return torch.ops.flag_gems.rms_norm(x, weight, eps)
        else:
            logger.debug("GEMS CUSTOM RMS_NORM")
            return flag_gems.rms_norm(x, list(weight.size()), weight, eps)
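
# Illustrative, unfused reference of the semantics the fused kernels above are
# expected to match. This is an eager-mode sketch, not used by this module; it
# assumes the standard RMSNorm definition with the mean taken over the last
# (normalized) dimension.
def _reference_rms_norm(
    x: torch.Tensor, weight: torch.Tensor, eps: float
) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, computed in float32 for stability.
    rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rms).to(x.dtype) * weight


def _reference_fused_add_rms_norm(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    # The residual is added first; the updated sum is returned alongside the
    # normalized output, mirroring the (x, residual) pair above.
    x = x + residual
    return _reference_rms_norm(x, weight, eps), x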
class GemsRMSNorm(nn.Module):
    """
    GemsRMSNorm implementation compatible with both PyTorch and vLLM behavior.

    This module inherits directly from `nn.Module` instead of `torch.nn.RMSNorm`
    (introduced in PyTorch 2.4.0) to avoid version compatibility issues.

    It also supports fused residual addition (`fused_add_rms_norm` behavior),
    which PyTorch's RMSNorm does not provide.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Union[int, List[int], Size]
    eps: Optional[float]
    elementwise_affine: bool
    def __init__(
        self,
        normalized_shape: List[int],
        eps: float = 1e-6,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()
    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in __init__.
        """
        if self.elementwise_affine:
            init.ones_(self.weight)
    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Applies RMSNorm to the input. If a residual is provided, applies fused
        residual addition and normalization, returning both the normalized
        output and the updated residual.
        """
        return gems_rms_forward(x, residual, self.weight, self.eps)
    def extra_repr(self) -> str:
        """
        Extra information about the module.
        """
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
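
# Minimal usage sketch (assumes a CUDA device and a flag_gems build where the
# rms_norm kernels are available; shapes and dtype here are illustrative only).
if __name__ == "__main__":
    hidden = 64
    norm = GemsRMSNorm(hidden, eps=1e-6).to(device="cuda", dtype=torch.float16)
    x = torch.randn(2, 8, hidden, device="cuda", dtype=torch.float16)

    # Plain RMSNorm.
    y = norm(x)

    # Fused residual-add + RMSNorm: returns (normalized output, updated residual).
    residual = torch.randn_like(x)
    y_fused, new_residual = norm(x, residual)
    print(y.shape, y_fused.shape, new_residual.shape)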