Coverage for src/flag_gems/runtime/backend/_ascend/ops/where.py: 0%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.codegen_config_utils import CodeGenConfig 

9 

# Module logger: a child of "flag_gems.runtime._ascend.ops" named after the
# last component of this module's dotted ``__name__``.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)

11 

12 

# Code-generation settings handed to ``pointwise_dynamic`` for the kernels in
# this module. The meaning of the positional values (760, (48, 1, 1), 32,
# False) is defined by ``CodeGenConfig``, which is not visible here.
# NOTE(review): confirm each against the CodeGenConfig field order.
config_ = CodeGenConfig(
    760,
    (48, 1, 1),
    32,
    False,
    # Enable 1-D tiling only on Triton releases older than 3.x. Compare the
    # whole major-version component rather than just the first character, so
    # a multi-digit major version (e.g. "10.0.0") is not misread as "1".
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)

20 

21 

# Elementwise body of ``where``: select ``self`` where ``condition`` is true,
# otherwise ``other`` (via ``tl.where``). ``pointwise_dynamic`` turns this
# scalar Triton function into a callable that ``where_self_out`` invokes as
# ``where_inner(c, a, b, out0=out)`` after ``where_inner.instantiate(ndim)``.
# All three inputs are declared tensors (``is_tensor``); the output dtype rule
# is ``(1, 2, "NO_OPMATH")`` -- presumably "promote from arguments 1 and 2
# (self/other) without op-math upcasting"; confirm in pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, True, True],
    promotion_methods=[(1, 2, "NO_OPMATH")],
    config=config_,
)
@triton.jit
def where_inner(condition, self, other):
    # Comments only (no docstring) so the source parsed by triton.jit and the
    # pointwise codegen stays exactly as before.
    return tl.where(condition, self, other)

30 

31 

def where_self_out(condition, self, other, out=None):
    """Compute ``torch.where(condition, self, other)`` on the Ascend backend.

    Non-tensor arguments are wrapped with ``torch.tensor``, and ``self`` and
    ``other`` are cast to ``torch.result_type(self, other)``. The computation
    device is the first non-CPU device among the operands. 0-dim tensors
    (including wrapped Python scalars) that are elsewhere are moved to it, so
    scalars may be mixed freely with device tensors.

    Args:
        condition: Condition tensor (or scalar). Must have dtype ``torch.bool``.
        self: Values taken where ``condition`` is true (tensor or scalar).
        other: Values taken where ``condition`` is false (tensor or scalar).
        out: Optional preallocated output. Its dtype must equal
            ``torch.result_type(self, other)`` and it must be on the
            computation device. When ``None``, a tensor with the broadcast
            shape of the three operands is allocated.

    Returns:
        The output tensor (``out`` itself when it was given).

    Raises:
        AssertionError: If ``out`` has the wrong dtype, if every operand is on
            the CPU, if operands (or ``out``) are on different devices after
            scalar placement, or if ``condition`` is not boolean.
    """
    logger.debug("GEMS_ASCEND WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    c, a, b = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (condition, self, other)
    )

    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    # The kernel runs on the first non-CPU device, so at least one is required.
    devices = [x.device for x in (c, a, b) if x.device.type != "cpu"]
    assert devices, "CPU only. There seems a mistake to dispatch to here."
    device = devices[0]

    # Only 0-dim tensors may be relocated implicitly (as in PyTorch).
    c, a, b = (
        x.to(device) if x.ndim == 0 and x.device != device else x
        for x in (c, a, b)
    )

    # Check placement after the scalar moves and over every operand. The
    # non-CPU filter above ignores CPU tensors, so without this check a
    # non-scalar CPU tensor (or a CPU ``out``) would reach the device kernel.
    placed = [x.device for x in (c, a, b)]
    if out is not None:
        placed.append(out.device)
    assert (
        len(set(placed)) == 1
    ), f"Expected all tensors to be on the same device, but found at least two devices, {placed}"
    # Report ``c.dtype``, not ``condition.dtype``: ``condition`` may have been
    # a Python scalar, which has no ``.dtype``.
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=device)

    # Prepare the generated kernel for the highest operand rank before the
    # launch (presumably a per-rank codegen step in pointwise_dynamic).
    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out

80 

81 

def where_self(condition, self, other):
    """Log the call and return ``where_self_out`` with a freshly allocated output.

    Judging by its name, this is presumably the entry for the variant where
    ``self`` and ``other`` are both tensors. All validation and computation
    happen in ``where_self_out``.
    """
    logger.debug("GEMS_ASCEND WHERE_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result

85 

86 

def where_scalar_self(condition, self, other):
    """Log the call and return ``where_self_out`` with a freshly allocated output.

    Judging by its name, ``self`` is presumably a Python scalar here.
    ``where_self_out`` wraps non-tensor arguments with ``torch.tensor``, so
    no special handling is needed in this wrapper.
    """
    logger.debug("GEMS_ASCEND WHERE_SCALAR_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result

90 

91 

def where_scalar_other(condition, self, other):
    """Log the call and return ``where_self_out`` with a freshly allocated output.

    Judging by its name, ``other`` is presumably a Python scalar here.
    ``where_self_out`` wraps non-tensor arguments with ``torch.tensor``, so
    no special handling is needed in this wrapper.
    """
    logger.debug("GEMS_ASCEND WHERE_SCALAR_OTHER")
    result = where_self_out(condition, self, other, out=None)
    return result