Coverage for src/flag_gems/runtime/backend/_ascend/ops/select_scatter.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping 

8 

# Module-level logger; select_scatter emits a debug record on every call.
logger = logging.getLogger(__name__)

10 

11 

# Element-wise select_scatter over flat row-major offsets.
#
# Each program instance handles BLOCK_SIZE consecutive flat offsets. A flat
# offset is split into (outer, along, inner) coordinates:
#   - inner: position within the trailing dims after `dim` (dim_prod_post elems)
#   - along: coordinate along `dim` (dim_size elems)
#   - outer: position within the leading dims before `dim`
# Where `along == index`, the value comes from `src`, which has the `dim` axis
# removed, so its flat offset is outer * dim_prod_post + inner. Everywhere else
# the value is copied from `inp`.
# Assumes out/inp/src are laid out contiguously (row-major); the caller must
# guarantee this.
@triton.jit
def select_scatter_kernel(
    out_ptr,
    inp_ptr,
    src_ptr,
    total_elements,
    dim_size,
    dim_prod_post,
    index,
    BLOCK_SIZE: tl.constexpr,
):
    flat = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = flat < total_elements

    inner = flat % dim_prod_post
    along = (flat // dim_prod_post) % dim_size
    outer = flat // (dim_size * dim_prod_post)
    hit = along == index

    # Only lanes on the selected slice read from src.
    src_off = outer * dim_prod_post + inner
    from_src = tl.load(src_ptr + src_off, mask=in_bounds & hit)
    from_inp = tl.load(inp_ptr + flat, mask=in_bounds)

    tl.store(out_ptr + flat, tl.where(hit, from_src, from_inp), mask=in_bounds)

41 

42 

def select_scatter(inp, src, dim, index):
    """Return a copy of ``inp`` with the slice at ``index`` along ``dim`` set to ``src``.

    Args:
        inp: Input tensor. It is not modified.
        src: Tensor whose shape equals ``inp.shape`` with ``dim`` removed.
        dim: Dimension to scatter along; negative values count from the end.
        index: Position along ``dim`` to overwrite; negative values count from
            the end.

    Returns:
        A new tensor with ``inp``'s shape and dtype. It keeps ``inp``'s strides
        unless ``inp`` has internal memory overlap, in which case the result
        is contiguous.

    Raises:
        AssertionError: If ``dim``/``index`` are out of range or ``src`` has the
            wrong shape. Asserts are kept to preserve the existing exception
            type for callers.
    """
    logger.debug("GEMS_ASCEND SELECT_SCATTER")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index >= -inp.size(dim) and index < inp.size(dim), "Invalid index"
    dim = dim % inp.ndim
    index = index % inp.size(dim)

    valid_shape = list(inp.shape)
    del valid_shape[dim]
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    if has_internal_overlapping(inp) == MemOverlap.Yes:
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    total_elements = inp.numel()
    if total_elements == 0:
        # Nothing to compute; avoid launching a kernel with an empty grid.
        return out

    inp = inp.contiguous()
    src = src.contiguous()

    # The kernel addresses every tensor with flat row-major offsets, so its
    # destination must be contiguous. When `out` carries inp's non-contiguous
    # strides, writing directly would place elements at the wrong logical
    # positions. Compute into a contiguous scratch buffer and copy it back.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty(out.shape, dtype=out.dtype, device=out.device)

    dim_size = inp.size(dim)
    # Number of elements spanned by the dimensions after `dim`.
    dim_prod_post = 1
    for size in inp.shape[dim + 1 :]:
        dim_prod_post *= size

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(total_elements, BLOCK_SIZE),)

    select_scatter_kernel[grid](
        dst,
        inp,
        src,
        total_elements,
        dim_size,
        dim_prod_post,
        index,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    if dst is not out:
        out.copy_(dst)
    return out