Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/masked_select.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7# from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable, libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger nested under the package-wide "flag_gems" logger; lstrip(".")
# removes any leading dots from __name__ before using it as the child name.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

def heur_block_size(args, cluster_num=12):
    """Heuristic for the Triton ``BLOCK_SIZE`` meta-parameter.

    Splits ``n_elements`` roughly evenly across ``cluster_num`` clusters and
    rounds the per-cluster share up to the next power of two, as Triton block
    sizes must be powers of two.

    Args:
        args: Argument dict supplied by ``triton.heuristics``; must contain
            ``"n_elements"``.
        cluster_num: Number of compute clusters to spread the work over.
            Defaults to 12, the value originally hard-coded for this backend.

    Returns:
        int: A power-of-two block size.
    """
    return triton.next_power_of_2(triton.cdiv(args["n_elements"], cluster_num))

17 

18 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_SIZE": heur_block_size,
    },
)
@triton.jit
def masked_select_kernel(
    inp_ptr,  # pointer to the flattened input values
    select_mask_ptr,  # pointer to the flattened selection mask
    prefix_sum_ptr,  # pointer to the inclusive prefix sum of the mask
    out_ptr,  # pointer to the compacted output buffer
    n_elements,  # total number of input elements
    BLOCK_SIZE: tl.constexpr,
):
    # Stream-compaction kernel: copies each element of `inp_ptr` whose mask is
    # truthy into a dense `out_ptr`, using the inclusive prefix sum of the mask
    # to compute the destination index of every selected element.
    pid = tle.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
    select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(tl.int1)
    # The inclusive prefix sum is 1-based at selected positions, so subtract 1
    # to obtain the 0-based output slot.  Unselected lanes load `other=0.0`
    # and would yield -1, but the store below is masked by the same
    # `select_mask & mask`, so those lanes never write.
    out_offset = (
        tl.load(prefix_sum_ptr + offsets, mask=(select_mask & mask), other=0.0) - 1
    )

    tl.store(out_ptr + out_offset, inp, mask=(select_mask & mask))

45 

46 

def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of ``inp`` where ``mask`` is True.

    Mirrors ``torch.masked_select``: ``inp`` and ``mask`` are broadcast
    against each other and the elements at True positions are returned in
    row-major (flattened) order.

    Args:
        inp: Input tensor.
        mask: Boolean tensor broadcastable with ``inp``.

    Returns:
        torch.Tensor: 1-D tensor of selected elements, on ``inp``'s device
        with ``inp``'s dtype.
    """
    logger.debug("GEMS MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    # The kernel addresses flat, contiguous buffers.
    inp = inp.contiguous()
    mask = mask.contiguous()

    mask_flattened = mask.ravel()
    n_elements = inp.numel()

    # Empty input: nothing to select.  This also avoids indexing
    # `prefix_sum[-1]` on a 0-element tensor, which would raise IndexError.
    if n_elements == 0:
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    # Inclusive prefix sum: at each selected position, its 1-based slot in the
    # compacted output; the final entry is the total number of selections.
    prefix_sum = mask_flattened.cumsum(axis=0)
    out = torch.empty(prefix_sum[-1].item(), dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    import os

    # Backend simulator switches apparently required by this XPU Triton build;
    # restore the environment even if the launch raises.
    os.environ["TRITONXPU_OTHER_SIM"] = "1"
    os.environ["TRITONXPU_STORE_MASK_SIM"] = "1"
    try:
        with torch_device_fn.device(inp.device):
            masked_select_kernel[grid](inp, mask_flattened, prefix_sum, out, n_elements)
    finally:
        os.environ.pop("TRITONXPU_OTHER_SIM", None)
        os.environ.pop("TRITONXPU_STORE_MASK_SIM", None)

    return out