Coverage for src/flag_gems/runtime/backend/_ascend/ops/where.py: 0%
50 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.codegen_config_utils import CodeGenConfig
# Per-op logger named "flag_gems.runtime._ascend.ops.where" (last component of
# __name__), so Ascend backend log output can be filtered per operator.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# Ascend-specific code-generation settings handed to pointwise_dynamic below.
# NOTE(review): positional arguments follow CodeGenConfig's declaration order;
# presumably (max_tile_size, max_grid_size, max_num_warps_per_cta,
# prefer_block_pointer) — TODO confirm against codegen_config_utils.
config_ = CodeGenConfig(
    760,
    (48, 1, 1),
    32,
    False,
    # Use 1-D tiling only on Triton major versions before 3.
    prefer_1d_tile=int(triton.__version__[0]) < 3,
)
@pointwise_dynamic(
    # All three arguments arrive as tensors (scalars are wrapped by the caller).
    is_tensor=[True, True, True],
    # Promote from args 1 and 2 (self/other); "NO_OPMATH" presumably skips the
    # intermediate op-math cast — see pointwise_dynamic for exact semantics.
    promotion_methods=[(1, 2, "NO_OPMATH")],
    config=config_,
)
@triton.jit
def where_inner(condition, self, other):
    # Elementwise select: take `self` where `condition` is true, else `other`.
    return tl.where(condition, self, other)
def where_self_out(condition, self, other, out=None):
    """Compute torch.where(condition, self, other), writing into ``out``.

    Scalar arguments are wrapped into 0-dim tensors and moved to the single
    non-CPU device shared by the tensor arguments. ``self``/``other`` are cast
    to their common result type. When ``out`` is None a broadcast-shaped
    result tensor is allocated on that device.

    Raises AssertionError when ``out`` has the wrong dtype, when every input
    lives on CPU, when non-CPU inputs span several devices, or when
    ``condition`` is not boolean.
    """
    logger.debug("GEMS_ASCEND WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    # Wrap plain Python scalars so every operand is a tensor.
    c, a, b = (
        arg if isinstance(arg, torch.Tensor) else torch.tensor(arg)
        for arg in (condition, self, other)
    )

    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    # Collect the non-CPU devices; this op must not be reached for all-CPU inputs.
    devices = [t.device for t in (c, a, b) if t.device.type != "cpu"]
    assert len(devices), "CPU only. There seems a mistake to dispatch to here."
    device = devices[0]

    # Only 0-dim (scalar) tensors are silently migrated to the target device;
    # mismatched non-scalar tensors trip the same-device assertion below.
    if c.device != device and c.ndim == 0:
        c = c.to(device)
    if a.device != device and a.ndim == 0:
        a = a.to(device)
    if b.device != device and b.ndim == 0:
        b = b.to(device)

    assert (
        len(set(devices)) == 1
    ), f"Expected all tensors to be on the same device, but found at least two devices, {devices}"
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {condition.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=device)

    # Specialize the generated kernel for the highest operand rank.
    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out
def where_self(condition, self, other):
    """torch.where(condition, self, other) with tensor self/other; allocates the result."""
    logger.debug("GEMS_ASCEND WHERE_SELF")
    return where_self_out(condition, self, other, out=None)
def where_scalar_self(condition, self, other):
    """torch.where variant where ``self`` is a Python scalar; delegates to where_self_out."""
    logger.debug("GEMS_ASCEND WHERE_SCALAR_SELF")
    return where_self_out(condition, self, other, out=None)
def where_scalar_other(condition, self, other):
    """torch.where variant where ``other`` is a Python scalar; delegates to where_self_out."""
    logger.debug("GEMS_ASCEND WHERE_SCALAR_OTHER")
    return where_self_out(condition, self, other, out=None)