Coverage for src/flag_gems/runtime/backend/_metax/ops/full.py: 0%
66 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import triton_lang_extension as tle
10from flag_gems.utils.shape_utils import volume
12logger = logging.getLogger("flag_gems." + __name__)
@triton.jit(do_not_specialize=["fill_value_or_ptr"])
def full_kernel(
    output_ptr,
    n_elements,
    fill_value_or_ptr,
    FILL_VALUE_IS_PTR: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill a contiguous 1-D buffer of ``n_elements`` with a single value.

    Each program instance writes one BLOCK_SIZE-wide chunk. When
    FILL_VALUE_IS_PTR is set, ``fill_value_or_ptr`` is a pointer to a
    one-element tensor and the value is loaded from device memory;
    otherwise it is the scalar fill value itself.
    """
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-range lanes in the final (partial) block.
    mask = offsets < n_elements
    if FILL_VALUE_IS_PTR:
        # Fill value lives in device memory (used by the caller for
        # tensor-valued fills, e.g. the float64 path in full_).
        fill_value = tl.load(fill_value_or_ptr)
    else:
        fill_value = fill_value_or_ptr
    tl.store(output_ptr + offsets, fill_value, mask=mask)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": k}, num_warps=w, num_stages=4)
        for w in [2, 4, 8, 16]  # maca support up to 16
        for k in [1024, 2048, 4096, 8192]
    ],
    key=[
        "N",
    ],
)
@triton.jit()
def full_kernel_scale(
    output_ptr,
    N,
    fill_value,
    BLOCK_SIZE: tl.constexpr,
):
    """Fill a contiguous 1-D buffer of ``N`` elements with a scalar value.

    Autotuned variant for Python-scalar fill values: block size and warp
    count are selected per ``N``. Each program instance writes one
    BLOCK_SIZE-wide chunk.
    """
    pid = tle.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-range lanes in the final (partial) block.
    mask = offsets < N
    tl.store(output_ptr + offsets, fill_value, mask=mask)
# Dtypes that get an explicit range check before the fill is launched.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)


def check_dtype(fill_value, dtype, device):
    """Validate that ``fill_value`` is representable in ``dtype`` and normalize it.

    A bool destined for a non-bool dtype becomes a Python int. Other values
    are range-checked against the target dtype (inf/nan are exempt for float
    dtypes) and a RuntimeError is raised on overflow. For float64 the value
    is wrapped in a scalar tensor on ``device`` so the kernel can load it at
    full precision instead of receiving a down-converted scalar argument.
    """
    if isinstance(fill_value, bool):
        if dtype != torch.bool:
            fill_value = int(fill_value)
    else:
        out_of_range = False
        if dtype in ALL_INT_DTYPES:
            bounds = torch.iinfo(dtype)
            out_of_range = fill_value < bounds.min or fill_value > bounds.max
        elif dtype in ALL_FLOAT_DTYPES and not (
            math.isinf(fill_value) or math.isnan(fill_value)
        ):
            bounds = torch.finfo(dtype)
            out_of_range = fill_value < bounds.min or fill_value > bounds.max
        if out_of_range:
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )
    if dtype in [torch.double]:
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value
def full_(out, N, dtype, device, fill_value):
    """Fill the pre-allocated tensor ``out`` (``N`` elements) and return it.

    Plain scalars — and one-element non-float64 tensors, which are collapsed
    to Python scalars via ``.item()`` — go through the autotuned scalar
    kernel. Any remaining tensor fill value (multi-element, or float64 kept
    as a tensor for precision) is loaded from device memory inside the
    generic kernel.
    """
    fill_is_tensor = isinstance(fill_value, torch.Tensor)
    use_scalar_kernel = not fill_is_tensor
    if fill_is_tensor and fill_value.numel() == 1 and dtype not in [torch.double]:
        # A single-element tensor is cheaper to pass as a plain scalar.
        fill_value = fill_value.item()
        use_scalar_kernel = True

    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(device):
        if use_scalar_kernel:
            full_kernel_scale[grid](
                out,
                N,
                fill_value,
            )
        else:
            full_kernel[grid](
                out,
                N,
                fill_value,
                FILL_VALUE_IS_PTR=fill_is_tensor,
                BLOCK_SIZE=1024,
            )
    return out
def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of ``size`` filled with ``fill_value`` (torch.full).

    When ``dtype`` is omitted it is inferred from the Python type of
    ``fill_value``; when given, the value is range-checked (and, for
    float64, promoted to a scalar tensor) by ``check_dtype``.
    """
    logger.debug("METAX GEMS FULL")
    if device is None:
        device = torch.device("cpu")
    if dtype is None:
        # Mirror torch.full's dtype inference for Python scalars.
        if isinstance(fill_value, bool):
            dtype = torch.bool
        elif isinstance(fill_value, int):
            dtype = torch.int64
        else:
            dtype = torch.get_default_dtype()
    else:
        fill_value = check_dtype(fill_value, dtype, device)

    out = torch.empty(size, device=device, dtype=dtype)
    return full_(out, volume(size), dtype, device, fill_value)