Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/masked_select.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-09 01:57 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable, libentry 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("masked_select"),
    key=["n_elements"],
    warmup=5,
    rep=5,
)
@triton.jit
def masked_select_kernel(
    inp_ptr,  # pointer to flattened input values
    select_mask_ptr,  # pointer to the selection mask (loaded and cast to int1)
    prefix_sum_ptr,  # pointer to the inclusive cumsum of the mask
    out_ptr,  # pointer to the compacted output buffer
    n_elements,  # total number of input elements
    BLOCK_SIZE: tl.constexpr,  # elements handled per program instance
):
    """Compact masked elements of `inp` into a dense output buffer.

    Each in-bounds element i whose select_mask[i] is truthy is written to
    out[prefix_sum[i] - 1]; the inclusive prefix sum of the mask therefore
    assigns consecutive 0-based output slots to the selected elements.
    Unselected and out-of-bounds lanes store nothing.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Bounds guard for the final partial block.
    mask = offsets < n_elements

    inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
    select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(tl.int1)
    # Inclusive prefix sum is 1-based at selected positions; subtract 1 for the
    # 0-based output slot. Masked-off lanes read `other=0.0` and yield -1 here,
    # but the store mask below keeps them from ever being written.
    out_offset = tl.load(prefix_sum_ptr + offsets, mask=mask, other=0.0) - 1

    # Store only lanes that are both in bounds and selected.
    tl.store(out_ptr + out_offset, inp, mask=(select_mask & mask))

39 

40 

def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of `inp` selected by `mask`.

    `inp` and `mask` are broadcast against each other first, mirroring
    torch.masked_select semantics. The selected elements are compacted in
    row-major order via an inclusive prefix sum of the mask, which supplies
    each selected element's output slot to the Triton kernel.

    Args:
        inp: input tensor.
        mask: boolean (or bool-convertible) tensor broadcastable with `inp`.

    Returns:
        A new 1-D tensor on `inp`'s device containing the selected elements.

    Raises:
        AssertionError: if the shapes are not broadcastable.
    """
    logger.debug("GEMS_TSINGMICRO MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    # Kernel indexes linearly, so both tensors must be densely laid out.
    inp = inp.contiguous()
    mask = mask.contiguous()

    n_elements = inp.numel()
    # Empty input: `prefix_sum[-1]` below would raise IndexError on an empty
    # cumsum, and a zero-block kernel launch is pointless — return early.
    if n_elements == 0:
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    mask_flattened = mask.ravel()

    # Inclusive cumsum: prefix_sum[i] - 1 is the output slot of element i when
    # it is selected; the last entry is the total number of selected elements.
    prefix_sum = mask_flattened.cumsum(axis=0)
    # NOTE: .item() forces a device sync; unavoidable since the output size is
    # data-dependent.
    out = torch.empty(prefix_sum[-1].item(), dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(inp.device):
        masked_select_kernel[grid](inp, mask_flattened, prefix_sum, out, n_elements)
    return out