Coverage for src/flag_gems/runtime/backend/_ascend/ops/full.py: 0%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
# Module-level logger; emits a debug line per op entry (see full() below).
logger = logging.getLogger(__name__)

# Integer and floating dtypes for which the fill value's range is validated
# against torch.iinfo / torch.finfo before launching the kernel.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)
def check_dtype(fill_value, dtype, device):
    """Validate that *fill_value* fits in *dtype* and normalize it.

    Rules:
      * A bool fill targeting a non-bool dtype is demoted to ``int``.
      * A non-bool fill that overflows an int/float dtype's representable
        range raises ``RuntimeError`` (inf/nan are exempt for float dtypes).
      * For ``torch.float64`` the value is wrapped in a 0-d tensor on
        *device* so full double precision survives the kernel launch.

    Returns the (possibly converted) fill value.
    """
    int_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64)
    float_dtypes = (torch.bfloat16, torch.float16, torch.float32, torch.float64)

    if isinstance(fill_value, bool):
        # bool is only kept as-is when the target dtype is bool itself.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    else:
        overflow = False
        if dtype in int_dtypes:
            info = torch.iinfo(dtype)
            overflow = fill_value < info.min or fill_value > info.max
        elif dtype in float_dtypes and not (
            math.isinf(fill_value) or math.isnan(fill_value)
        ):
            info = torch.finfo(dtype)
            overflow = fill_value < info.min or fill_value > info.max
        if overflow:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )

    if dtype == torch.float64:
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)

    return fill_value
# Tensor-valued fill kernel: every output element takes the value of the
# ``fill_value`` tensor argument. ``is_tensor=[True, True]`` marks both
# arguments as tensors for pointwise_dynamic's codegen. Used by full() when
# check_dtype has wrapped the scalar in a tensor (the float64 path).
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func(out, fill_value):
    return fill_value
# Scalar-valued fill kernel: ``fill_value`` is passed as a non-tensor scalar
# (``is_tensor=[True, False]``) and materialized into a block of the output's
# shape and dtype via tl.full. Used by full() for plain Python scalars.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def full_func_scalar(out, fill_value):
    return tl.full(out.shape, fill_value, out.dtype)
def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return a tensor of shape *size* filled with *fill_value*.

    Mirrors ``torch.full``: when *dtype* is omitted it is inferred from the
    fill value (bool -> bool, int -> int64, otherwise the default float
    dtype); when given, the fill value is range-checked via check_dtype.
    ``layout`` and ``pin_memory`` are accepted for signature compatibility
    but not used here.
    """
    logger.debug("GEMS_ASCEND FULL")

    if device is None:
        device = torch.device("cpu")

    if dtype is not None:
        # Explicit dtype: validate/normalize the fill value against it.
        fill_value = check_dtype(fill_value, dtype, device)
    elif isinstance(fill_value, bool):
        dtype = torch.bool
    elif isinstance(fill_value, int):
        dtype = torch.int64
    else:
        dtype = torch.get_default_dtype()

    out = torch.empty(size, device=device, dtype=dtype)

    # check_dtype returns a tensor for float64 fills; dispatch accordingly.
    kernel = full_func if isinstance(fill_value, torch.Tensor) else full_func_scalar
    return kernel(out, fill_value, out0=out)