Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/copy.py: 0%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6 

7from ..utils.codegen_config_utils import CodeGenConfig 

8from ..utils.pointwise_dynamic import pointwise_dynamic 

9 

# Module logger, created as a child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Dispatch key set passed to ``torch.ops.aten.copy_.default.redispatch`` in
# ``copy_`` whenever the Triton fast path is not taken. Presumably this routes
# to PyTorch's own copy_ kernel registered under CompositeExplicitAutograd —
# confirm against the aten registration if the fallback behaviour matters.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

# Code-generation settings handed to ``pointwise_dynamic`` for ``copy_slice``.
# NOTE(review): the meaning of the positional arguments
# (512, (65536, 65536, 65536), 32, True) is defined by CodeGenConfig in
# ..utils.codegen_config_utils and is not visible here — check there before
# tuning them.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    is_scatter_slice=True,
)

24 

25 

26# @pointwise_dynamic(is_tensor=(True,), promotion_methods=[(0, "DEFAULT")]) 

27# @triton.jit 

28# def copy(src): 

29# return src 

30 

31 

# Identity scalar function that ``pointwise_dynamic`` turns into an elementwise
# copy kernel, generated with the module-level ``config_`` settings.
# NOTE(review): not referenced in this chunk; presumably used by slice/scatter
# ops elsewhere (``config_`` sets ``is_scatter_slice=True``) — confirm callers.
@pointwise_dynamic(
    is_tensor=(True,), promotion_methods=[(0, "DEFAULT")], config=config_
)
@triton.jit
def copy_slice(src):
    return src

38 

39 

# Identity scalar function wrapped by ``pointwise_dynamic`` with its default
# code-generation config. ``copy_`` instantiates it for the (expanded) source
# rank and passes the destination tensor as ``out0``, so the generated kernel
# writes ``src`` element-wise into ``dst``.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _copy_kernel(src):
    return src

44 

45 

46def _can_use_triton(dst: torch.Tensor, src: torch.Tensor) -> bool: 

47 if dst.layout != torch.strided or src.layout != torch.strided: 

48 return False 

49 if dst.device != src.device: 

50 return False 

51 if dst.is_quantized or src.is_quantized: 

52 return False 

53 if src.is_complex() and not dst.is_complex(): 

54 # Preserve PyTorch's behaviour of warning when casting complex to real 

55 # by forcing the redispatch path, which issues the warning internally. 

56 return False 

57 if not src.is_contiguous(): 

58 return False 

59 return True 

60 

61 

62def _expand_like(src: torch.Tensor, target_shape: torch.Size) -> torch.Tensor: 

63 if src.shape == target_shape: 

64 return src 

65 return src.expand(target_shape) 

66 

67 

def copy(
    template: torch.Tensor, src: torch.Tensor, *, non_blocking: Optional[bool] = False
):
    """Functional (out-of-place) copy of ``src`` into a fresh tensor.

    The result has ``template``'s size, stride, dtype and device; ``src`` is
    written into it through ``copy_``.

    Args:
        template: Tensor whose layout metadata the result reproduces.
        src: Tensor providing the values.
        non_blocking: Forwarded to ``copy_`` after ``bool()`` coercion, so
            ``None`` behaves like ``False``.

    Returns:
        The newly allocated, filled tensor.
    """
    logger.debug("GEMS COPY (functional)")
    shape, strides = template.size(), template.stride()
    result = torch.empty_strided(
        shape,
        strides,
        dtype=template.dtype,
        device=template.device,
    )
    copy_(result, src, non_blocking=bool(non_blocking))
    return result

77 

78 

def copy_(dst: torch.Tensor, src: torch.Tensor, non_blocking: bool = False):
    """Copy ``src`` into ``dst`` in place, mirroring ``torch.Tensor.copy_``.

    The Triton kernel ``_copy_kernel`` handles the common case. Every other
    situation is handed back to PyTorch by redispatching on
    ``_FALLBACK_KEYSET``.

    Args:
        dst: Tensor written in place.
        src: Tensor read from; it must broadcast to ``dst.shape``.
        non_blocking: Forwarded to PyTorch on the fallback path only; the
            Triton path does not consult it.

    Returns:
        ``dst`` on the Triton and no-op paths, otherwise the result of
        ``dst.zero_()`` or of the fallback redispatch.

    Raises:
        TypeError: if ``src`` is not a tensor.
        RuntimeError: if ``dst`` is a ZeroTensor, or (on the Triton path) if
            ``src`` does not broadcast to exactly ``dst.shape``.
    """

    def redispatch_to_torch():
        # Let PyTorch's own copy_ handle the cases the Triton path skips.
        return torch.ops.aten.copy_.default.redispatch(
            _FALLBACK_KEYSET, dst, src, non_blocking
        )

    if not isinstance(src, torch.Tensor):
        raise TypeError("src must be a Tensor")

    # ZeroTensor rules, as in PyTorch: dst must be writable; a zero src just
    # zeroes dst.
    if dst._is_zerotensor():
        raise RuntimeError("ZeroTensors are immutable. Call clone() before copy_.")
    if src._is_zerotensor():
        return dst.zero_()

    if torch._C._is_alias_of(dst, src):
        dst_meta = (
            dst.storage_offset(),
            dst.stride(),
            dst.size(),
            dst.dtype,
            dst.device,
            dst.is_conj(),
            dst.is_neg(),
        )
        src_meta = (
            src.storage_offset(),
            src.stride(),
            src.size(),
            src.dtype,
            src.device,
            src.is_conj(),
            src.is_neg(),
        )
        # Same view of the same data: copying is a no-op (matches PyTorch).
        if dst_meta == src_meta:
            return dst
        # Overlapping but different views: PyTorch defines the semantics.
        return redispatch_to_torch()

    # Unsupported inputs take the fallback, and so do empty destinations,
    # because PyTorch still validates broadcasting for those.
    if not _can_use_triton(dst, src) or dst.numel() == 0:
        return redispatch_to_torch()

    logger.debug("GEMS COPY_")

    try:
        broadcast_shape = torch.broadcast_shapes(dst.shape, src.shape)
    except RuntimeError as exc:
        raise RuntimeError(str(exc)) from exc

    # src may broadcast up to dst, but the result must be exactly dst's shape.
    if torch.Size(broadcast_shape) != dst.shape:
        raise RuntimeError(
            f"The broadcast shape {broadcast_shape} does not match destination shape {tuple(dst.shape)}"
        )

    src_view = _expand_like(src, dst.shape)
    kernel = _copy_kernel.instantiate(src_view.ndim)
    kernel(src_view, out0=dst)
    return dst