Coverage for src/flag_gems/experimental_ops/copy_.py: 0%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

def _tl_dtype_from_torch(dtype: torch.dtype):
    """Translate a ``torch.dtype`` into the corresponding Triton dtype.

    Args:
        dtype: dtype of the destination tensor.

    Returns:
        The matching ``triton.language`` dtype object.

    Raises:
        NotImplementedError: If ``dtype`` has no Triton counterpart here.
    """
    # A lookup table replaces the original if-chain: same behavior, easier
    # to extend, and a single place to see every supported dtype.
    mapping = {
        torch.float16: tl.float16,
        torch.bfloat16: tl.bfloat16,
        torch.float32: tl.float32,
        torch.float64: tl.float64,
        torch.int8: tl.int8,
        torch.int16: tl.int16,
        torch.int32: tl.int32,
        torch.int64: tl.int64,
        torch.uint8: tl.uint8,
    }
    try:
        return mapping[dtype]
    except KeyError:
        # `from None` keeps the original non-chained traceback shape.
        raise NotImplementedError(
            f"Unsupported dtype for Triton copy_: {dtype}"
        ) from None

27 

28 

@triton.jit
def _copy_kernel(
    dst_ptr, src_ptr, n_elements, BLOCK_SIZE: tl.constexpr, DST_DTYPE: tl.constexpr
):
    """Element-wise copy of a flat buffer, casting each value to DST_DTYPE.

    Each program instance handles one contiguous BLOCK_SIZE-wide slice; the
    mask guards the tail slice when n_elements is not a multiple of
    BLOCK_SIZE.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    src_vals = tl.load(src_ptr + idx, mask=in_bounds)
    tl.store(dst_ptr + idx, tl.cast(src_vals, DST_DTYPE), mask=in_bounds)

40 

41 

@triton.jit
def _fill_kernel(
    dst_ptr, scalar_value, n_elements, BLOCK_SIZE: tl.constexpr, DST_DTYPE: tl.constexpr
):
    """Fill a flat buffer with a single scalar value cast to DST_DTYPE.

    One program instance writes one BLOCK_SIZE-wide slice; the mask keeps
    the final, possibly partial, slice from writing past n_elements.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    fill_vals = tl.full((BLOCK_SIZE,), tl.cast(scalar_value, DST_DTYPE), DST_DTYPE)
    tl.store(dst_ptr + idx, fill_vals, mask=in_bounds)

52 

53 

def _launch_copy_tensor(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Copy ``src`` into ``dst`` element-wise via the Triton copy kernel.

    Values are cast to ``dst``'s dtype inside the kernel, so the tensors
    may differ in dtype but must match in element count.

    Args:
        dst: Contiguous CUDA destination tensor (modified in place).
        src: Contiguous CUDA source tensor with the same numel as ``dst``.

    Returns:
        ``dst``, for call chaining.

    Raises:
        ValueError: If devices, contiguity, or element counts are invalid.
        NotImplementedError: If ``dst.dtype`` has no Triton counterpart.
    """
    # Explicit raises instead of `assert`: validation must survive
    # `python -O`, which strips assert statements entirely.
    if not (dst.is_cuda and src.is_cuda):
        raise ValueError("Triton copy_ supports CUDA tensors only.")
    if not (dst.is_contiguous() and src.is_contiguous()):
        raise ValueError("Only contiguous tensors are supported.")
    n_elements = dst.numel()
    if src.numel() != n_elements:
        raise ValueError(
            "Source and destination must have the same number of elements."
        )
    # Validate the dtype even for empty tensors, matching the non-empty path.
    DST_DTYPE = _tl_dtype_from_torch(dst.dtype)
    if n_elements == 0:
        # Nothing to copy; avoid launching an empty grid.
        return dst
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _copy_kernel[grid](
        dst,
        src,
        n_elements,
        BLOCK_SIZE=1024,
        DST_DTYPE=DST_DTYPE,
    )
    return dst

73 

74 

def _launch_fill_scalar(dst: torch.Tensor, scalar) -> torch.Tensor:
    """Fill ``dst`` with ``scalar`` via the Triton fill kernel.

    Args:
        dst: Contiguous CUDA destination tensor (modified in place).
        scalar: Python number (or number-like) to broadcast into ``dst``.

    Returns:
        ``dst``, for call chaining.

    Raises:
        ValueError: If ``dst`` is not a contiguous CUDA tensor.
        NotImplementedError: If ``dst.dtype`` has no Triton counterpart.
    """
    # Explicit raises instead of `assert`: validation must survive
    # `python -O`, which strips assert statements entirely.
    if not dst.is_cuda:
        raise ValueError("Triton copy_ (scalar) supports CUDA tensors only.")
    if not dst.is_contiguous():
        raise ValueError("Only contiguous tensors are supported.")
    n_elements = dst.numel()
    # Validate the dtype before any early return so unsupported dtypes
    # always raise, as in the non-empty path.
    DST_DTYPE = _tl_dtype_from_torch(dst.dtype)
    # Normalize to a plain Python number so the kernel argument type is
    # stable regardless of what numeric type the caller passed.
    if dst.dtype.is_floating_point:
        scalar_val = float(scalar)
    else:
        scalar_val = int(scalar)
    if n_elements == 0:
        # Nothing to fill; avoid launching an empty grid.
        return dst
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _fill_kernel[grid](
        dst,
        scalar_val,
        n_elements,
        BLOCK_SIZE=1024,
        DST_DTYPE=DST_DTYPE,
    )
    return dst

94 

95 

def copy_(self: torch.Tensor, src, non_blocking: bool = False):
    """In-place copy mimicking ``Tensor.copy_``.

    A tensor source is copied element-wise; an int/bool/float source fills
    the whole destination. ``non_blocking`` is accepted for signature
    compatibility and ignored.

    Raises:
        TypeError: If ``src`` is neither a tensor nor a supported scalar.
    """
    # Guard clauses dispatch on the source type; note bools fall into the
    # integer branch via int().
    if isinstance(src, torch.Tensor):
        return _launch_copy_tensor(self, src)
    if isinstance(src, (int, bool)):
        return _launch_fill_scalar(self, int(src))
    if isinstance(src, float):
        return _launch_fill_scalar(self, float(src))
    raise TypeError(f"Unsupported src type for copy_: {type(src)}")

105 

106 

def copy__Tensor(self: torch.Tensor, src: torch.Tensor, non_blocking: bool = False):
    """Overload entry point for tensor sources; delegates to the tensor
    copy launcher. ``non_blocking`` is accepted and ignored."""
    return _launch_copy_tensor(self, src)

109 

110 

def copy__int(self: torch.Tensor, src: int, non_blocking: bool = False):
    """Overload entry point for integer sources; fills ``self`` with the
    value. ``non_blocking`` is accepted and ignored."""
    value = int(src)
    return _launch_fill_scalar(self, value)

113 

114 

def copy__float(self: torch.Tensor, src: float, non_blocking: bool = False):
    """Overload entry point for float sources; fills ``self`` with the
    value. ``non_blocking`` is accepted and ignored."""
    value = float(src)
    return _launch_fill_scalar(self, value)