Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/full.py: 0%
51 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.shape_utils import volume
# Child logger under the shared "flag_gems" root logger; any leading dots
# are stripped from the module name so the child name stays well-formed.
_parent_logger = logging.getLogger("flag_gems")
logger = _parent_logger.getChild(__name__.lstrip("."))
@triton.jit(do_not_specialize=["fill_value_or_ptr"])
def full_kernel(
    output_ptr,
    n_elements,
    fill_value_or_ptr,
    FILL_VALUE_IS_PTR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill a contiguous BLOCK_SIZE-wide slice of the output with one value.

    Each program along axis 0 covers one slice; out-of-range lanes are
    masked off. The fill value arrives either as a plain scalar or as a
    pointer to a one-element tensor (check_dtype wraps float64 values in a
    tensor), selected at compile time via FILL_VALUE_IS_PTR.
    """
    block_id = tle.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Resolve the fill value: dereference when it is a pointer, otherwise
    # use the scalar directly. Constexpr branch — dead side is compiled out.
    if FILL_VALUE_IS_PTR:
        value = tl.load(fill_value_or_ptr)
    else:
        value = fill_value_or_ptr
    tl.store(output_ptr + idx, value, mask=in_bounds)
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)


def check_dtype(fill_value, dtype, device):
    """Validate that fill_value is representable in dtype and normalize it.

    A Python bool aimed at a non-bool dtype is converted to 0/1. Scalars
    outside the representable range of an integer or (finite) float dtype
    raise RuntimeError. For float64 the value is wrapped in a one-element
    tensor on `device`, which the kernel then loads through a pointer.

    Returns the (possibly converted) fill value.
    """
    if isinstance(fill_value, bool):
        if dtype != torch.bool:
            fill_value = int(fill_value)
    else:
        out_of_range = False
        if dtype in ALL_INT_DTYPES:
            limits = torch.iinfo(dtype)
            # NOTE: comparisons are intentionally written so that NaN
            # compares False on both sides and therefore does not raise.
            out_of_range = fill_value < limits.min or fill_value > limits.max
        elif dtype in ALL_FLOAT_DTYPES and math.isfinite(fill_value):
            # +/-inf and NaN are always representable; only check finite values.
            limits = torch.finfo(dtype)
            out_of_range = fill_value < limits.min or fill_value > limits.max
        if out_of_range:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )

    if dtype == torch.float64:
        # Pass float64 by pointer so the kernel keeps full precision.
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value
def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of `size` filled with `fill_value` (torch.full).

    When dtype is None it is inferred from fill_value (bool -> torch.bool,
    int -> torch.int64, otherwise the default float dtype); when dtype is
    given, fill_value is range-checked and normalized by check_dtype.
    `layout` and `pin_memory` are accepted for signature compatibility but
    are not used here.

    Fix vs. the previous version: the empty-tensor fast path compared
    `size == [0]` literally, so zero-element shapes passed as tuples or with
    more dims ((0,), [2, 0], ...) fell through to a pointless kernel launch,
    and the literal-[0] path ignored dtype inference/validation. Now device
    and dtype are resolved first and any shape with zero elements returns
    the empty tensor directly.
    """
    logger.debug("GEMS FULL")
    if device is None:
        device = torch.device("cpu")
    if dtype is None:
        if isinstance(fill_value, bool):
            dtype = torch.bool
        elif isinstance(fill_value, int):
            dtype = torch.int64
        else:
            dtype = torch.get_default_dtype()
    else:
        fill_value = check_dtype(fill_value, dtype, device)

    out = torch.empty(size, device=device, dtype=dtype)
    N = volume(size)
    if N == 0:
        # Nothing to fill — covers every zero-element shape, not just [0].
        return out

    # Fixed 12-program launch; BLOCK_SIZE is sized so the 12 programs cover
    # all N elements. NOTE(review): 12 presumably matches the KunLunXin
    # cluster count — confirm against the backend spec.
    grid = (12, 1, 1)
    block_size = triton.next_power_of_2(triton.cdiv(N, 12))
    with torch_device_fn.device(device):
        full_kernel[grid](
            out,
            N,
            fill_value,
            # float64 values arrive as one-element tensors (see check_dtype)
            # and are dereferenced inside the kernel.
            FILL_VALUE_IS_PTR=isinstance(fill_value, torch.Tensor),
            BLOCK_SIZE=block_size,
            # Backend-specific launch options — presumably KunLunXin
            # compiler knobs; keep as-is.
            buffer_size_limit=2048,
            isCloseDtypeConvert=True,
        )
    return out