Coverage for src/flag_gems/runtime/backend/_cambricon/ops/to.py: 0%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6 

7from ..utils.pointwise_dynamic import pointwise_dynamic 

8 

# Child logger under the package-wide "flag_gems" logger; lstrip(".") strips a
# leading dot in case __name__ arrives as a relative module path.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Keyset used to redispatch aten._to_copy back to PyTorch's composite
# implementation for the cases this backend does not handle itself
# (see the fallback branch in to_copy below).
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

14 

15 

@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def _to_copy_func(x):
    """Elementwise identity kernel backing to_copy.

    The kernel body just forwards ``x``; callers pass a preallocated output
    via ``out0=`` (see to_copy), and the pointwise_dynamic wrapper presumably
    performs the dtype conversion when writing into it — confirm against the
    pointwise_dynamic implementation.
    """
    return x

25 

26 

27def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype: 

28 if dtype is None: 

29 return x.dtype 

30 if isinstance(dtype, torch.dtype): 

31 return dtype 

32 raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}") 

33 

34 

35def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device: 

36 if device is None: 

37 return x.device 

38 return torch.device(device) 

39 

40 

41def _normalize_memory_format( 

42 memory_format: Optional[torch.memory_format], 

43) -> torch.memory_format: 

44 if memory_format is None: 

45 return torch.preserve_format 

46 return memory_format 

47 

48 

49def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor: 

50 """Recreate tensor storage while honoring preserve_format semantics.""" 

51 if torch.ops.aten.is_non_overlapping_and_dense(x): 

52 return torch.empty_strided(x.size(), x.stride(), **empty_kwargs) 

53 # Fall back to PyTorch's best-effort layout suggestion when stride replication is unsafe. 

54 return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs) 

55 

56 

# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#                bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Cambricon implementation of ``aten._to_copy``.

    Copies ``x`` into a freshly allocated tensor, optionally converting dtype
    and memory format. Cross-device transfers (and CPU-to-CPU copies) are
    redispatched to PyTorch's composite implementation; same-device copies run
    through the Triton identity kernel ``_to_copy_func``.

    Raises:
        NotImplementedError: for non-strided layouts, an explicit
            ``pin_memory=True`` request, or quantized tensors.
    """
    # We only implement the dense strided kernel today; all other layouts fall back to PyTorch.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # BUGFIX: the schema is `bool? pin_memory=None`, so callers may legally pass
    # an explicit pin_memory=False (a no-op request). The previous check
    # (`pin_memory is not None`) raised for False too, contradicting the error
    # message; only a truthy (True) request is actually unsupported.
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    if target_device != x.device or (
        x.device.type == "cpu" and target_device.type == "cpu"
    ):
        # Device transfer (d2h/h2d etc.) relies on PyTorch's implementation.
        return torch.ops.aten._to_copy.default.redispatch(
            _FALLBACK_KEYSET,
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS_CAMBRICON _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # NOTE(review): non_blocking is treated as a hint and ignored on the kernel
    # path; presumably same-device copies are stream-ordered anyway — confirm.
    return _to_copy_func(x, out0=out)