Coverage for src/flag_gems/modules/normalization.py: 43%

49 statements

# LayerNorm-related implementation.
# References:
# - PyTorch: https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/normalization.py#L321
# - vLLM: https://github.com/vllm-project/vllm/blob/v0.8.5/vllm/model_executor/layers/layernorm.py#L82
# - TransformerEngine:
#   https://github.com/NVIDIA/TransformerEngine/blob/v2.2.1/transformer_engine/pytorch/module/rmsnorm.py#L16
#
# Design notes:
# - Aligns with PyTorch's RMSNorm interface for compatibility.
# - Works with or without the flag_gems C++ wrappers.
# - Avoids relying on torch.nn.RMSNorm (introduced in v2.4.0) for broader compatibility.
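#
# RMSNorm, for reference (the standard definition, noted here for this listing;
# it is not spelled out in the original file):
#   y = weight * x / sqrt(mean(x**2, over the normalized dims) + eps)
# The fused_add_rms_norm variant first forms x = x + residual and then applies
# RMSNorm, returning both the normalized output and the updated residual.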

import logging
import numbers
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Size
from torch.nn import Parameter, init

import flag_gems
from flag_gems.config import use_c_extension

logger = logging.getLogger(__name__)

__all__ = [
    "gems_rms_forward",
    "GemsRMSNorm",
]


def gems_rms_forward(
    x: torch.Tensor, residual: Optional[torch.Tensor], weight: torch.Tensor, eps: float
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Dispatch RMSNorm to the flag_gems kernels.

    With a residual, performs fused add + RMSNorm: the C-extension path
    updates `x` and `residual` in place and returns them as a pair, while the
    Python path returns the pair produced by `flag_gems.fused_add_rms_norm`.
    Without a residual, applies plain RMSNorm and returns the result.
    """
    add_residual = residual is not None
    if add_residual:
        if use_c_extension:
            logger.debug("GEMS CUSTOM FUSED_ADD_RMS_NORM(C EXTENSION)")
            torch.ops.flag_gems.fused_add_rms_norm(x, residual, weight, eps)
            return x, residual
        else:
            logger.debug("GEMS CUSTOM FUSED_ADD_RMS_NORM")
            return flag_gems.fused_add_rms_norm(
                x, residual, list(weight.size()), weight, eps
            )
    else:
        if use_c_extension:
            logger.debug("GEMS CUSTOM RMS_NORM(C EXTENSION)")
            return torch.ops.flag_gems.rms_norm(x, weight, eps)
        else:
            logger.debug("GEMS CUSTOM RMS_NORM")
            return flag_gems.rms_norm(x, list(weight.size()), weight, eps)
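
# Illustrative usage of gems_rms_forward (a sketch added to this listing, not
# part of the original file); the shapes, device, and dtype are assumptions:
#
#   hidden = torch.randn(4, 128, device="cuda", dtype=torch.float16)
#   weight = torch.ones(128, device="cuda", dtype=torch.float16)
#   out = gems_rms_forward(hidden, None, weight, eps=1e-6)      # plain RMSNorm
#   res = torch.randn_like(hidden)
#   out, res = gems_rms_forward(hidden, res, weight, eps=1e-6)  # fused add + norm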


class GemsRMSNorm(nn.Module):
    """
    GemsRMSNorm implementation compatible with both PyTorch and vLLM behavior.

    This module inherits directly from `nn.Module` instead of `torch.nn.RMSNorm`
    (introduced in PyTorch 2.4.0) to avoid version-compatibility issues.

    It also supports fused residual addition (`fused_add_rms_norm` behavior),
    which PyTorch's RMSNorm does not provide.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Union[int, List[int], Size]
    eps: Optional[float]
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: float = 1e-6,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        Resets parameters based on their initialization used in __init__.
        """
        if self.elementwise_affine:
            init.ones_(self.weight)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Applies RMSNorm to the input. If a residual is provided, applies
        fused residual addition and normalization and returns the
        (output, residual) pair.
        """
        return gems_rms_forward(x, residual, self.weight, self.eps)

    def extra_repr(self) -> str:
        """
        Extra information about the module.
        """
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
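
# Illustrative GemsRMSNorm usage (a sketch added to this listing, not part of
# the original file); device and dtype are assumptions for the example:
#
#   norm = GemsRMSNorm(128, eps=1e-6, device="cuda", dtype=torch.float16)
#   x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
#   y = norm(x)                     # plain RMSNorm
#   res = torch.randn_like(x)
#   y, res = norm(x, residual=res)  # fused add + RMSNorm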