Coverage for src/flag_gems/runtime/backend/_cambricon/ops/full.py: 0%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-13 10:08 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
10from flag_gems.utils.shape_utils import volume
12from ..utils import TOTAL_CORE_NUM
14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
# Fix: the original listed do_not_specialize=["fill_value_or_ptr"], which
# matches no parameter of this kernel and therefore had no effect.
@triton.jit(do_not_specialize=["fill_value_ptr"])
def full_tensor_kernel(
    output_ptr,
    n_elements,
    fill_value_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    # Fill output_ptr[0:n_elements] with the scalar stored behind
    # fill_value_ptr (a 0-d tensor prepared by check_dtype).
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Promote to int64 so offsets cannot overflow on very large tensors.
    block_start = block_start.to(tl.int64)
    # The fill value is loop-invariant; load it once instead of per iteration.
    fill_value = tl.load(fill_value_ptr)
    for block_start_offset in range(block_start, n_elements, step):
        offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        tl.store(output_ptr + offsets, fill_value, mask=mask)
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=1, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=1, num_warps=1),
    ],
    key=["n_elements"],
)
# Fix: the original listed do_not_specialize=["fill_value_or_ptr"], which
# matches no parameter here, so every distinct fill value specialized (and
# recompiled) the kernel. Naming the real argument keeps one compilation
# shared across all fill values.
@triton.jit(do_not_specialize=["fill_value"])
def full_scalar_kernel(
    output_ptr,
    n_elements,
    fill_value,
    BLOCK_SIZE: tl.constexpr,
):
    # Fill output_ptr[0:n_elements] with the Python scalar fill_value.
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    block_start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    # Promote to int64 so offsets cannot overflow on very large tensors.
    block_start = block_start.to(tl.int64)
    for block_start_offset in range(block_start, n_elements, step):
        offsets = block_start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        tl.store(output_ptr + offsets, fill_value, mask=mask)
# Dtype groups used when validating the fill scalar against the target dtype.
ALL_INT_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)
ALL_FLOAT_DTYPES = (torch.bfloat16, torch.float16, torch.float32, torch.float64)


def check_dtype(fill_value, dtype, device):
    """Validate that fill_value fits in dtype and prepare it for the kernels.

    Non-floating dtypes get the scalar back unchanged (a bool targeting a
    non-bool dtype is promoted to 0/1); floating dtypes get the value
    materialized as a 0-d tensor on ``device`` so the tensor kernel can
    load it on-device.

    Raises:
        RuntimeError: if fill_value overflows dtype, or is infinite while
            dtype is not a floating-point type.
    """
    if isinstance(fill_value, bool):
        # bool is a subclass of int; promote to 0/1 unless dtype is bool.
        if dtype != torch.bool:
            fill_value = int(fill_value)
    elif isinstance(fill_value, float) and math.isinf(fill_value):
        # Infinity is representable only by floating dtypes.
        if dtype not in ALL_FLOAT_DTYPES:
            raise RuntimeError(
                f"value {fill_value!r} cannot be converted to type {dtype} without overflow"
            )
    else:
        if dtype in ALL_INT_DTYPES:
            bounds = torch.iinfo(dtype)
        elif dtype in ALL_FLOAT_DTYPES:
            bounds = torch.finfo(dtype)
        else:
            bounds = None
        # NaN deliberately slips through: both comparisons are False for it,
        # matching the original behavior.
        if bounds is not None and (
            fill_value < bounds.min or fill_value > bounds.max
        ):
            raise RuntimeError(
                f"value cannot be converted to type {dtype} without overflow"
            )
    if dtype in ALL_FLOAT_DTYPES:
        # Floating fills travel to full_tensor_kernel as a device tensor.
        fill_value = torch.tensor(fill_value, dtype=dtype, device=device)
    return fill_value
def full(size, fill_value, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Create a tensor of shape *size* filled with *fill_value* (torch.full).

    ``layout`` and ``pin_memory`` are accepted for aten signature
    compatibility only; this implementation does not use them.
    """
    logger.debug("GEMS_CAMBRICON FULL")
    if device is None:
        device = torch.device("cpu")

    if dtype is None:
        # Mirror torch's scalar-type inference when no dtype is supplied.
        if isinstance(fill_value, bool):
            dtype = torch.bool
        elif isinstance(fill_value, int):
            dtype = torch.int64
        else:
            dtype = torch.get_default_dtype()
    else:
        # May convert fill_value to a 0-d device tensor for floating dtypes.
        fill_value = check_dtype(fill_value, dtype, device)

    out = torch.empty(size, device=device, dtype=dtype)
    n_elements = volume(size)

    def grid(meta):
        # Cap the launch at the core count; the kernels stride over the rest.
        return (min(triton.cdiv(n_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    with torch_device_fn.device(device):
        # check_dtype wraps floating fills into a tensor; dispatch on that.
        if isinstance(fill_value, torch.Tensor):
            kernel = full_tensor_kernel
        else:
            kernel = full_scalar_kernel
        kernel[grid](out, n_elements, fill_value)
    return out