Coverage for src/flag_gems/utils/type_utils.py: 100%
8 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import torch
2from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes
def type_promotion(*args, type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND):
    """Resolve the dtypes an elementwise op should compute in and produce.

    Thin wrapper around ``torch._prims_common.elementwise_dtypes``.

    Args:
        *args: Operands of the elementwise op, forwarded positionally to
            ``elementwise_dtypes`` (tensors and/or scalars, whatever that
            function accepts).
        type_promotion: Promotion rule to apply (keyword-only); forwarded as
            ``type_promotion_kind``.

    Returns:
        A ``(computation_dtype, result_dtype)`` tuple.
    """
    # Bind the keyword-only parameter to a distinct local: it shares its name
    # with this function, which makes the call below easy to misread.
    promotion_kind = type_promotion
    compute_dtype, out_dtype = elementwise_dtypes(
        *args, type_promotion_kind=promotion_kind
    )
    return (compute_dtype, out_dtype)
# Reduced-precision dtypes that are widened before accumulating.
# Any dtype not listed here accumulates in its own precision.
_accumulator_dtype_map = {
    torch.bfloat16: torch.float32,
    torch.float16: torch.float32,
    torch.complex32: torch.complex64,
}


def get_accumulator_dtype(dtype: torch.dtype) -> torch.dtype:
    """Return the dtype to use when accumulating values of ``dtype``.

    bfloat16 and float16 map to float32, and complex32 maps to complex64.
    Every other dtype is returned unchanged.

    Args:
        dtype: The input element dtype.

    Returns:
        The widened accumulator dtype, or ``dtype`` itself when no widening
        is registered for it.
    """
    if dtype in _accumulator_dtype_map:
        return _accumulator_dtype_map[dtype]
    return dtype