Coverage for src/flag_gems/utils/type_utils.py: 100%

8 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import torch 

2from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes 

3 

4 

def type_promotion(*args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND):
    """Resolve the computation and result dtypes for an elementwise op.

    Thin wrapper around ``torch._prims_common.elementwise_dtypes``.

    Args:
        *args: Operands forwarded unchanged to ``elementwise_dtypes``
            (presumably tensors and/or Python scalars -- confirm against
            the torch version in use).
        type_promotion: Promotion rule, forwarded as ``type_promotion_kind``.

    Returns:
        A ``(computation_dtype, result_dtype)`` pair.
    """
    # NOTE: the keyword-only parameter deliberately shares the function's
    # name (callers pass ``type_promotion=...``), so inside this body the
    # name refers to the promotion kind, not to this function.
    promoted = elementwise_dtypes(*args, type_promotion_kind=type_promotion)
    compute_dtype, output_dtype = promoted
    return compute_dtype, output_dtype

11 

12 

# Reduced-precision dtypes and the wider dtype each one accumulates in.
# Any dtype not listed here accumulates in itself.
_accumulator_dtype_map = {
    # 16-bit floating point formats widen to float32.
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    # complex32 widens to complex64.
    torch.complex32: torch.complex64,
}


def get_accumulator_dtype(dtype: torch.dtype) -> torch.dtype:
    """Return the dtype to accumulate values of ``dtype`` in.

    Args:
        dtype: Input element dtype.

    Returns:
        The widened dtype from ``_accumulator_dtype_map`` when ``dtype`` is
        listed there, otherwise ``dtype`` itself.
    """
    if dtype in _accumulator_dtype_map:
        return _accumulator_dtype_map[dtype]
    return dtype