Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/full_like.py: 0%
20 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-29 04:01 +0800
1import logging
3import torch
4import triton
6from flag_gems.runtime import torch_device_fn
8from .full import check_dtype, full_kernel
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module's dotted path with any leading dots removed.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
def full_like(
    x,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Return a new tensor with the shape of ``x``, filled with ``fill_value``.

    Kunlunxin backend implementation of ``torch.full_like``.

    Args:
        x: Tensor whose shape (and, by default, dtype and device) the result
            copies.
        fill_value: Value written to every element. It is first normalized by
            ``check_dtype`` against the target dtype/device; if that yields a
            tensor, the kernel is told to read the value through a pointer.
        dtype: Result dtype; defaults to ``x.dtype``.
        layout: Accepted for signature compatibility; not used.
        device: Result device; defaults to ``x.device``.
        pin_memory: Accepted for signature compatibility; not used.
        memory_format: Forwarded to ``torch.empty_like``; ``None`` keeps the
            ``empty_like`` default (preserve the format of ``x``).

    Returns:
        The newly allocated tensor with every element set to ``fill_value``.
    """
    logger.debug("GEMS FULL_LIKE")
    if device is None:
        device = x.device
    if dtype is None:
        dtype = x.dtype
    # Validate/normalize the fill value before allocating anything.
    fill_value = check_dtype(fill_value, dtype, device)
    out = torch.empty_like(
        x, device=device, dtype=dtype, memory_format=memory_format
    )

    n_elements = out.numel()
    # Nothing to fill. Returning early also avoids a launch with
    # BLOCK_SIZE == 0, since triton.next_power_of_2(0) returns 0.
    if n_elements == 0:
        return out

    # Fixed launch of `num_programs` programs, each covering a power-of-two
    # chunk large enough that together they span every element.
    # NOTE(review): 12 presumably matches the target XPU's cluster count —
    # confirm before changing.
    num_programs = 12
    grid = (num_programs, 1, 1)
    block_size = triton.next_power_of_2(triton.cdiv(n_elements, num_programs))

    # Launch under the device that owns `out`; it differs from x.device when
    # the caller passes an explicit `device`.
    with torch_device_fn.device(out.device):
        full_kernel[grid](
            out,
            n_elements,
            fill_value,
            FILL_VALUE_IS_PTR=isinstance(fill_value, torch.Tensor),
            BLOCK_SIZE=block_size,
            # Backend-specific launch options, kept unchanged; presumably
            # consumed by the Kunlunxin Triton toolchain (semantics not
            # visible in this file).
            buffer_size_limit=2048,
            isCloseDtypeConvert=True,
        )
    return out