Coverage for src/flag_gems/runtime/backend/_cambricon/ops/full_like.py: 0% (22 statements) — coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
3import torch
4import triton
6from flag_gems.runtime import torch_device_fn
8from ..utils import TOTAL_CORE_NUM
9from .full import check_dtype, full_scalar_kernel, full_tensor_kernel
# Child logger under the shared "flag_gems" root logger; lstrip(".") removes any
# leading dots so a relative-looking __name__ still yields a valid child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def full_like(
    x,
    fill_value,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Return a new tensor with the same shape as ``x``, filled with ``fill_value``.

    Cambricon (MLU) Triton implementation mirroring ``torch.full_like``.

    Args:
        x: Reference tensor supplying the shape, and the default dtype/device.
        fill_value: Value to fill with; ``check_dtype`` may convert it to a
            tensor on the target device (then the tensor kernel is used).
        dtype: Output dtype; defaults to ``x.dtype``.
        layout: Accepted for ``torch.full_like`` API compatibility; unused here.
        device: Output device; defaults to ``x.device``.
        pin_memory: Accepted for API compatibility; unused here
            (``torch.empty_like`` has no such parameter).
        memory_format: Forwarded to ``torch.empty_like`` (``None`` preserves
            the input's memory format).

    Returns:
        torch.Tensor: A tensor of ``x``'s shape filled with ``fill_value``.
    """
    logger.debug("GEMS_CAMBRICON FULL_LIKE")
    if device is None:
        device = x.device
    if dtype is None:
        dtype = x.dtype
    fill_value = check_dtype(fill_value, dtype, device)
    out = torch.empty_like(x, device=device, dtype=dtype, memory_format=memory_format)
    N = x.numel()
    # Cap the grid at the number of physical cores; each program instance then
    # covers multiple blocks when N is large.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    # Fix: launch under the *resolved* output device, not ``x.device`` — the two
    # differ when the caller passes an explicit ``device`` argument, and the
    # kernel must run on the device that owns ``out``.
    with torch_device_fn.device(device):
        if isinstance(fill_value, torch.Tensor):
            full_tensor_kernel[grid_fn](
                out,
                N,
                fill_value,
            )
        else:
            full_scalar_kernel[grid_fn](
                out,
                N,
                fill_value,
            )
    return out