Coverage for src/flag_gems/ops/to.py: 88%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-24 15:40 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

# Module-level logger for this op.
logger = logging.getLogger(__name__)

# Dispatch keyset used when a case is handed back to PyTorch's own
# CompositeExplicitAutograd implementation of aten._to_copy.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

14 

15 

# Identity elementwise function: the copy itself happens because the
# pointwise_dynamic wrapper materializes x into the caller-supplied output.
@pointwise_dynamic(
    is_tensor=[True],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def _to_copy_func(x):
    return x

25 

26 

27def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype: 

28 if dtype is None: 

29 return x.dtype 

30 if isinstance(dtype, torch.dtype): 

31 return dtype 

32 raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}") 

33 

34 

35def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device: 

36 if device is None: 

37 return x.device 

38 return torch.device(device) 

39 

40 

41def _normalize_memory_format( 

42 memory_format: Optional[torch.memory_format], 

43) -> torch.memory_format: 

44 if memory_format is None: 

45 return torch.preserve_format 

46 return memory_format 

47 

48 

49def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor: 

50 """Recreate tensor storage while honoring preserve_format semantics.""" 

51 if torch.ops.aten.is_non_overlapping_and_dense(x): 

52 return torch.empty_strided(x.size(), x.stride(), **empty_kwargs) 

53 # Fall back to PyTorch's best-effort layout suggestion when stride replication is unsafe. 

54 return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs) 

55 

56 

# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#                bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Copy ``x`` into a freshly allocated tensor, optionally converting
    dtype and memory format with the Triton kernel.

    Cases the kernel cannot handle either raise ``NotImplementedError``
    (non-strided layouts, pinned memory, quantized tensors) or are
    redispatched to PyTorch's CompositeExplicitAutograd implementation
    (complex dtypes, cross-device transfers, pure-CPU copies).
    """

    def _fallback(target_dtype, target_device, target_memory_format):
        # Hand the call back to PyTorch's reference implementation.
        return torch.ops.aten._to_copy.default.redispatch(
            _FALLBACK_KEYSET,
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    # We only implement the dense strided kernel today.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # pin_memory=None and pin_memory=False both mean "no pinning"; only an
    # explicit request to pin host memory is unsupported.  (Testing
    # ``is not None`` here would wrongly reject pin_memory=False.)
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # Triton does not support complex dtypes; fall back to PyTorch.
    if x.dtype.is_complex or target_dtype.is_complex:
        return _fallback(target_dtype, target_device, target_memory_format)

    # Device transfers (d2h/h2d etc.) and pure-CPU copies also rely on
    # PyTorch's implementation.
    if target_device != x.device or (
        x.device.type == "cpu" and target_device.type == "cpu"
    ):
        return _fallback(target_dtype, target_device, target_memory_format)

    logger.debug("GEMS _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    return _to_copy_func(x, out0=out)