Coverage report for src/flag_gems/experimental_ops/t_copy.py: 0% of 64 statements covered (coverage.py v7.6.9, generated 2026-03-22 16:54 +0800).

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def t_copy_2d_kernel(
    in_ptr,
    out_ptr,
    in_stride_0,
    in_stride_1,
    out_stride_0,
    out_stride_1,
    M,  # input dim0
    N,  # input dim1
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Copy a 2-D tensor of shape (M, N) into its transpose of shape (N, M).

    Each program instance handles one BLOCK_M x BLOCK_N tile of the output:
    program axis 0 tiles the output rows (length N), program axis 1 tiles
    the output columns (length M).
    """
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)

    # Output-row indices in [0, N) and output-column indices in [0, M).
    out_rows = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    out_cols = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)

    # Promote to int64 before the stride multiply so offsets on large
    # tensors cannot overflow 32 bits; broadcast to a [BLOCK_N, BLOCK_M] tile.
    rows64 = out_rows.to(tl.int64)[None, :]  # shape [1, BLOCK_M]
    cols64 = out_cols.to(tl.int64)[:, None]  # shape [BLOCK_N, 1]

    # out shape = (N, M): rows bounded by N, columns by M.
    in_bounds = (rows64 < N) & (cols64 < M)

    # input element (col, row) is written to output element (row, col)
    src_offsets = cols64 * in_stride_0 + rows64 * in_stride_1
    dst_offsets = rows64 * out_stride_0 + cols64 * out_stride_1

    vals = tl.load(in_ptr + src_offsets, mask=in_bounds)
    tl.store(out_ptr + dst_offsets, vals, mask=in_bounds)

38 

39 

@triton.jit
def copy_1d_strided_kernel(
    in_ptr,
    out_ptr,
    in_stride,
    out_stride,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy N elements from a strided 1-D source into a strided 1-D destination.

    Each program instance copies one BLOCK_SIZE chunk; out-of-range lanes
    are masked off.
    """
    chunk_start = tl.program_id(0) * BLOCK_SIZE
    lanes = chunk_start + tl.arange(0, BLOCK_SIZE)
    active = lanes < N
    # 64-bit offsets guard against overflow on huge or widely strided tensors.
    lanes64 = lanes.to(tl.int64)
    vals = tl.load(in_ptr + lanes64 * in_stride, mask=active)
    tl.store(out_ptr + lanes64 * out_stride, vals, mask=active)

57 

58 

def _launch_t_copy_kernel(inp: torch.Tensor, out: torch.Tensor):
    """Dispatch the appropriate Triton kernel to write the transpose of
    ``inp`` into ``out``.

    * 0-D: plain scalar copy.
    * 1-D: element-wise strided copy (the transpose of a vector is itself).
    * 2-D: tiled transposed copy; for ``inp`` of shape (M, N), ``out`` must
      have shape (N, M).

    Raises:
        AssertionError: on device/dtype/shape mismatches.
        RuntimeError: if ``inp`` has more than 2 dimensions.
    """
    assert inp.is_cuda and out.is_cuda, "t_copy kernels require CUDA tensors"
    # is_cuda alone does not rule out two *different* CUDA devices, which
    # would make the kernel dereference a foreign pointer.
    assert inp.device == out.device, "input and output must be on the same device"
    assert inp.dtype == out.dtype, "dtype mismatch between input and output"

    dim = inp.dim()
    if dim == 0:
        # Scalar copy: exactly one element, strides irrelevant (use 0).
        assert out.numel() == 1, "Output must hold exactly one element for 0D t_copy"
        n = 1
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        copy_1d_strided_kernel[grid](
            inp,
            out,
            0,
            0,
            n,
            BLOCK_SIZE=1,
        )
    elif dim == 1:
        n = inp.numel()
        # Require a 1-D output: a higher-rank tensor with matching numel
        # would make stride(0) below index the wrong dimension.
        assert out.dim() == 1 and out.numel() == n, "Output size mismatch for 1D t_copy"
        in_stride = inp.stride(0)
        out_stride = out.stride(0)
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        copy_1d_strided_kernel[grid](
            inp,
            out,
            in_stride,
            out_stride,
            n,
            BLOCK_SIZE=1024,
        )
    elif dim == 2:
        M, N = inp.shape  # input dims; output must be (N, M)
        assert (
            out.dim() == 2 and out.shape[0] == N and out.shape[1] == M
        ), "Output shape must be (input.size(1), input.size(0)) for t_copy"
        in_s0, in_s1 = inp.stride()
        out_s0, out_s1 = out.stride()
        # Grid axis 0 tiles output rows (N), axis 1 tiles output columns (M).
        grid = lambda meta: (
            triton.cdiv(N, meta["BLOCK_M"]),
            triton.cdiv(M, meta["BLOCK_N"]),
        )
        t_copy_2d_kernel[grid](
            inp,
            out,
            in_s0,
            in_s1,
            out_s0,
            out_s1,
            M,
            N,
            BLOCK_M=32,
            BLOCK_N=32,
        )
    else:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

116 

117 

def t_copy_out(
    input: torch.Tensor,
    out: torch.Tensor,
    memory_format: torch.memory_format | None = None,
):
    """Transpose-copy *input* into the preallocated tensor *out*.

    *memory_format* is accepted for signature compatibility but never
    consulted here; *out*'s existing layout is used as-is. Returns *out*.
    """
    _launch_t_copy_kernel(input, out)
    return out

125 

126 

def t_copy(input: torch.Tensor, memory_format: torch.memory_format | None = None):
    """Return the transpose of *input* as a newly allocated contiguous tensor.

    0-D and 1-D inputs yield an identical copy; a 2-D input of shape (M, N)
    yields a contiguous (N, M) result. *memory_format* is accepted for
    signature compatibility but never consulted here.

    Raises:
        RuntimeError: if *input* has more than 2 dimensions.
    """
    ndim = input.dim()
    if ndim > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")
    if ndim == 0:
        result = torch.empty((), dtype=input.dtype, device=input.device)
    elif ndim == 1:
        result = torch.empty_like(input, memory_format=torch.contiguous_format)
    else:
        rows, cols = input.shape
        # Transposed output: swap the two dimensions.
        result = torch.empty(
            (cols, rows),
            dtype=input.dtype,
            device=input.device,
            memory_format=torch.contiguous_format,
        )
    _launch_t_copy_kernel(input, result)
    return result