Coverage for src/flag_gems/runtime/backend/_cambricon/ops/where.py: 0%

48 statements  

« prev | ^ index | » next — generated by coverage.py v7.6.9 on 2026-03-11 02:28 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from ..utils.pointwise_dynamic import pointwise_dynamic 

8 

# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems")
logger = logger.getChild(__name__.lstrip("."))

10 

11 

# Elementwise select kernel: tl.where picks `self` where `condition` is true
# and `other` elsewhere.
#
# pointwise_dynamic turns this scalar jit function into a pointwise op:
# - is_tensor=[True, True, True]: all three arguments are declared as tensors.
# - promotion_methods=[(1, 2, "NO_OPMATH")]: presumably the output dtype is
#   promoted from arguments 1 and 2 (`self`, `other`) without upcasting to a
#   wider compute type -- TODO confirm against pointwise_dynamic.
# The caller in this file runs `where_inner.instantiate(ndim)` first and
# passes the output tensor via `out0=`.
@pointwise_dynamic(
    is_tensor=[True, True, True],
    promotion_methods=[(1, 2, "NO_OPMATH")],
)
@triton.jit
def where_inner(condition, self, other):
    # `self` is an ordinary parameter name here (mirroring where_self_out's
    # signature), not a method receiver.
    return tl.where(condition, self, other)

19 

20 

def _scalar_to_tensor(value):
    """Wrap a Python scalar in a 0-dim CPU tensor; pass tensors through unchanged.

    ``torch.tensor(x)`` stores a Python float/complex in the default dtype
    (float32 / complex64 unless changed). That rounds values such as ``0.1``
    before a later cast to a 64-bit result dtype. Floats and complexes are
    therefore kept at double precision here. Ints and bools keep
    ``torch.tensor``'s own inference (int64 / bool).
    """
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, float):
        return torch.tensor(value, dtype=torch.float64)
    if isinstance(value, complex):
        return torch.tensor(value, dtype=torch.complex128)
    return torch.tensor(value)


def where_self_out(condition, self, other, out=None):
    """Elementwise ``condition ? self : other`` computed by ``where_inner``.

    The ``where_self`` / ``where_scalar_self`` / ``where_scalar_other``
    wrappers in this module delegate here.

    Args:
        condition: Boolean tensor (a non-tensor is wrapped with torch.tensor).
        self: Tensor or Python scalar taken where ``condition`` is true.
        other: Tensor or Python scalar taken where ``condition`` is false.
        out: Optional preallocated output. Its dtype must equal
            ``torch.result_type(self, other)``.

    Returns:
        ``out`` if given. Otherwise a new tensor with the broadcast shape of
        the three inputs and dtype ``torch.result_type(self, other)``, on the
        inputs' (single) non-CPU device.

    Raises:
        AssertionError: ``out`` has the wrong dtype; no input is on a non-CPU
            device; the inputs are on different devices (only 0-dim CPU
            tensors and Python scalars are moved automatically); or
            ``condition`` is not boolean.
    """
    logger.debug("GEMS_CAMBRICON WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    c, a, b = (_scalar_to_tensor(x) for x in (condition, self, other))

    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    # The kernel runs on the first non-CPU device found among the inputs.
    accelerator_devices = [t.device for t in (c, a, b) if t.device.type != "cpu"]
    assert len(
        accelerator_devices
    ), "CPU only. There seems a mistake to dispatch to here."
    device = accelerator_devices[0]

    # 0-dim CPU tensors (including wrapped Python scalars) may be mixed with
    # device tensors; move them to the common device.
    c, a, b = (
        t.to(device) if t.ndim == 0 and t.device.type == "cpu" else t
        for t in (c, a, b)
    )

    # Check every tensor *after* the move. A non-scalar CPU tensor (or a
    # tensor on another accelerator) must be rejected here rather than handed
    # to the kernel. Filtering CPU devices out before this check let such
    # tensors slip through.
    devices = [t.device for t in (c, a, b)]
    assert (
        len(set(devices)) == 1
    ), f"Expected all tensors to be on the same device, but found at least two devices, {devices}"
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=device)

    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out

69 

70 

def where_self(condition, self, other):
    """Select elementwise from ``self`` / ``other`` according to ``condition``.

    Logs the call, then forwards to ``where_self_out`` with no preallocated
    output, so a fresh result tensor is allocated.
    """
    logger.debug("GEMS_CAMBRICON WHERE_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result

74 

75 

def where_scalar_self(condition, self, other):
    """Select elementwise from ``self`` / ``other`` according to ``condition``.

    Judging by the name, ``self`` is expected to be a Python scalar here
    (TODO confirm against the op registration). ``where_self_out`` wraps
    non-tensor inputs itself, so this just logs and forwards without ``out``.
    """
    logger.debug("GEMS_CAMBRICON WHERE_SCALAR_SELF")
    return where_self_out(condition=condition, self=self, other=other)

79 

80 

def where_scalar_other(condition, self, other):
    """Select elementwise from ``self`` / ``other`` according to ``condition``.

    Judging by the name, ``other`` is expected to be a Python scalar here
    (TODO confirm against the op registration). ``where_self_out`` wraps
    non-tensor inputs itself, so this just logs and forwards without ``out``.
    """
    logger.debug("GEMS_CAMBRICON WHERE_SCALAR_OTHER")
    operands = (condition, self, other)
    return where_self_out(*operands, out=None)