Coverage for src/flag_gems/runtime/backend/_ascend/ops/slice_scatter.py: 0%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils.codegen_config_utils import CodeGenConfig 

7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

8from flag_gems.utils.shape_utils import has_internal_overlapping 

9 

# Module-level logger; the last component of this module's dotted name is
# appended to a fixed "flag_gems.runtime._ascend.ops." prefix.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')


# Code generation settings passed to the pointwise_dynamic decorator below.
# The positional values are kept exactly as they were; their meaning is
# defined by CodeGenConfig, which is declared outside this file.
config_ = CodeGenConfig(
    1536,
    (40, 1, 1),
    32,
    False,
    # Parse the whole major version component rather than only the first
    # character, so a multi-digit major version (e.g. "10.1") is not read as 1.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)

20 

21 

@pointwise_dynamic(
    is_tensor=[True],
    promotion_methods=[(0, "DEFAULT")],
    config=config_,
)
@triton.jit
def copy(src):
    # Identity function: the wrapped callable returns its single input
    # unchanged. slice_scatter calls it as copy(x, out0=y).
    return src

26 

27 

def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
    """Return a new tensor equal to ``inp`` with ``src`` written into a slice.

    The slice is ``start:end:step`` along dimension ``dim``; every other
    dimension is taken in full. ``inp`` itself is not written to: the result
    is a freshly allocated tensor.

    Args:
        inp: Tensor providing the base values, shape, dtype and device.
        src: Tensor to embed. Its shape must equal ``inp``'s shape with
            dimension ``dim`` replaced by the slice length.
        dim: Dimension to slice along; negative values count from the end.
        start: First index of the slice. ``None`` means 0. Negative values
            count from the end; out-of-range values are clamped.
        end: Exclusive end index of the slice. ``None`` means the size of
            ``dim``. Negative values count from the end; out-of-range values
            are clamped. An explicit 0 selects an empty slice.
        step: Positive slice stride.

    Returns:
        The new tensor holding ``inp``'s values with the slice replaced by
        ``src``.

    Raises:
        AssertionError: If the devices differ, ``dim`` is out of range,
            ``step`` is not positive, or ``src`` has the wrong shape.
    """
    logger.debug("GEMS_ASCEND SLICE_SCATTER")
    assert src.device == inp.device, "inp and src reside on different devices."
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert step > 0, "slice step must be positive"
    dim = dim % inp.ndim
    dim_size = inp.size(dim)

    # Only None selects the default bound. Testing truthiness instead would
    # turn an explicit end=0 into "whole dimension".
    if start is None:
        start = 0
    elif start < 0:
        start += dim_size
    if end is None:
        end = dim_size
    elif end < 0:
        end += dim_size
    # Clamp as ordinary slicing does, so the expected length computed below
    # always matches the length of the view taken from ``out``.
    start = min(max(start, 0), dim_size)
    end = min(max(end, start), dim_size)

    valid_shape = list(inp.shape)
    valid_shape[dim] = triton.cdiv(end - start, step)
    assert (
        list(src.shape) == valid_shape
    ), "Expected src to have a size equal to the slice of self"

    # NOTE(review): this branch relies on the truthiness of the value returned
    # by has_internal_overlapping; confirm it is falsy for non-overlapping
    # inputs, otherwise the stride-preserving branch is never taken.
    if has_internal_overlapping(inp):
        out = torch.empty(inp.size(), dtype=inp.dtype, device=inp.device)
    else:
        out = torch.empty_strided(
            inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
        )

    # First pass: duplicate the whole input into the output buffer.
    copy(inp, out0=out)

    # Second pass: overwrite the selected slice through a view of ``out``.
    # The index must be a tuple; a list of slices is a deprecated, ambiguous
    # form of multidimensional indexing.
    indices = [slice(None)] * inp.ndim
    indices[dim] = slice(start, end, step)
    out_ = out[tuple(indices)]
    # Nothing to write for an empty slice, so skip the kernel launch.
    if src.numel() > 0:
        copy(src, out0=out_)

    return out