Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/to.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import logging 

2import os 

3from typing import Optional 

4 

5import torch 

6import triton 

7from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

11logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

12 

13_FALLBACK_KEYSET = torch._C.DispatchKeySet( 

14 torch._C.DispatchKey.CompositeExplicitAutograd 

15) 

16 

17 

# Identity copy kernel: the element value passes through unchanged; the
# pointwise_dynamic wrapper performs the actual dtype conversion / layout
# change by writing the result into the preallocated output tensor (out0).
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def _to_copy_func(x):
    return x

27 

28 

# Codegen configuration for the "close interleave" kernel variant, selected in
# to_copy() below for bfloat16 inputs.
# NOTE(review): the meaning of the positional arguments (512, the 65536 grid
# caps tuple, 32, True) is defined by CodeGenConfig in
# _kunlunxin.utils.codegen_config_utils — confirm against that signature.
close_interleave_config = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    isCloseInterleave=True,
)

37 

38 

# Same identity copy kernel as _to_copy_func, but generated with the
# close-interleave CodeGenConfig above; chosen for bfloat16 inputs in to_copy().
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
    config=close_interleave_config,
)
@triton.jit
def _to_copy_func_close_interleave(x):
    return x

49 

50 

51def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype: 

52 if dtype is None: 

53 return x.dtype 

54 if isinstance(dtype, torch.dtype): 

55 return dtype 

56 raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}") 

57 

58 

59def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device: 

60 if device is None: 

61 return x.device 

62 return torch.device(device) 

63 

64 

65def _normalize_memory_format( 

66 memory_format: Optional[torch.memory_format], 

67) -> torch.memory_format: 

68 if memory_format is None: 

69 return torch.preserve_format 

70 return memory_format 

71 

72 

73def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor: 

74 """Recreate tensor storage while honoring preserve_format semantics.""" 

75 if torch.ops.aten.is_non_overlapping_and_dense(x): 

76 return torch.empty_strided(x.size(), x.stride(), **empty_kwargs) 

77 # Fall back to PyTorch's best-effort layout suggestion when stride replication is unsafe. 

78 return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs) 

79 

80 

# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#   bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """FlagGems implementation of ``aten::_to_copy`` for dense strided tensors.

    Allocates the output with the requested dtype/device/memory format and
    fills it via an identity Triton kernel. Cross-device copies (and cpu->cpu)
    are redispatched to PyTorch's CompositeExplicitAutograd implementation.

    Raises:
        NotImplementedError: for non-strided layouts, pinned-memory requests,
            or quantized inputs.
    """
    # bfloat16 inputs use the close-interleave codegen variant of the kernel.
    if x.dtype == torch.bfloat16:
        to_dtype_fn = _to_copy_func_close_interleave
    else:
        to_dtype_fn = _to_copy_func

    # We only implement the dense strided kernel today; all other layouts fall back to PyTorch.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # Only an explicit request for pinned memory is unsupported; pin_memory=False
    # (or None, the schema default) must behave like a normal copy.
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    if target_device != x.device or (
        x.device.type == "cpu" and target_device.type == "cpu"
    ):
        # Device transfer (d2h/h2d etc.) relies on PyTorch's implementation.
        return torch.ops.aten._to_copy.default.redispatch(
            _FALLBACK_KEYSET,
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    # Allocate the output honoring the requested memory format. (A previous
    # revision re-allocated `out` with torch.empty_like(x, dtype=dtype, ...)
    # right after this, which discarded target_dtype/target_device and the
    # preserve_format handling — that clobbering line has been removed.)
    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # The env vars steer the XPU Triton backend's codegen; clean them up even
    # when the kernel launch raises, so later launches aren't affected.
    wide_element = out.element_size() == 8
    if wide_element:
        os.environ["TRITONXPU_ELEMBYTES"] = "8"
    os.environ["TRITONXPU_BF16_FAST"] = "1"
    try:
        res = to_dtype_fn(x, out0=out)
    finally:
        if wide_element:
            del os.environ["TRITONXPU_ELEMBYTES"]
        del os.environ["TRITONXPU_BF16_FAST"]
    return res