Coverage for src/flag_gems/ops/t_copy.py: 53%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

@triton.jit
def t_copy_2d_kernel(
    in_ptr,
    out_ptr,
    in_stride_0,
    in_stride_1,
    out_stride_0,
    out_stride_1,
    M,  # input dim0
    N,  # input dim1
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the transpose of a strided (M, N) input into a strided (N, M) output.

    Program axis 0 walks output rows (input dim1) in steps of BLOCK_M and
    program axis 1 walks output columns (input dim0) in steps of BLOCK_N.
    Each program copies one tile, computing ``out[i, j] = in[j, i]`` with
    element offsets built from the strides passed in, so neither tensor has
    to be contiguous. Positions outside the (N, M) output are masked off.
    """
    # Widen the program ids *before* scaling by the block size: the product
    # is otherwise evaluated in 32-bit arithmetic and can wrap for dimensions
    # close to 2**31, which would corrupt both the mask and the offsets.
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1).to(tl.int64)

    out_rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # indices in [0..N)
    out_cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # indices in [0..M)

    # Broadcast to a [BLOCK_N, BLOCK_M] tile.
    i64 = out_rows[None, :]  # shape [1, BLOCK_M]
    j64 = out_cols[:, None]  # shape [BLOCK_N, 1]

    # out shape = (N, M)
    mask = (i64 < N) & (j64 < M)

    # in index = (j, i), out index = (i, j)
    in_offsets = j64 * in_stride_0 + i64 * in_stride_1
    out_offsets = i64 * out_stride_0 + j64 * out_stride_1

    x = tl.load(in_ptr + in_offsets, mask=mask)
    tl.store(out_ptr + out_offsets, x, mask=mask)

43 

44 

@triton.jit
def copy_1d_strided_kernel(
    in_ptr,
    out_ptr,
    in_stride,
    out_stride,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy N elements from a strided 1-D input to a strided 1-D output.

    Each program handles BLOCK_SIZE consecutive logical indices; logical
    index ``k`` is read from ``in_ptr + k * in_stride`` and written to
    ``out_ptr + k * out_stride``. Indices at or beyond N are masked off.
    """
    # Widen the program id before multiplying by BLOCK_SIZE so the logical
    # index (and therefore the bounds mask) is computed in 64-bit arithmetic;
    # casting after the multiply is too late once N approaches 2**31.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < N

    x = tl.load(in_ptr + offs * in_stride, mask=mask)
    tl.store(out_ptr + offs * out_stride, x, mask=mask)

62 

63 

def _launch_t_copy_kernel(inp: torch.Tensor, out: torch.Tensor):
    """Launch the Triton kernel that copies the transpose of ``inp`` into ``out``.

    Dispatch is by the rank of ``inp``:

    * 0-D: a single element is copied with ``copy_1d_strided_kernel``.
    * 1-D: a strided element-wise copy with ``copy_1d_strided_kernel``.
    * 2-D: ``t_copy_2d_kernel`` writes ``inp[j, i]`` to ``out[i, j]``.

    Args:
        inp: Source CUDA tensor with at most two dimensions.
        out: Destination CUDA tensor with the same dtype as ``inp``. For 1-D
            input it must hold the same number of elements; for 2-D input
            its shape must be ``(inp.size(1), inp.size(0))``.

    Raises:
        AssertionError: If either tensor is not on CUDA, the dtypes differ,
            or ``out`` has the wrong size/shape for the given input.
        RuntimeError: If ``inp`` has more than two dimensions.
    """
    assert inp.is_cuda and out.is_cuda, "t_copy kernels require CUDA tensors"
    assert inp.dtype == out.dtype, "dtype mismatch between input and output"

    ndim = inp.dim()
    if ndim > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if ndim == 2:
        M, N = inp.shape  # input dims
        # out should be (N, M)
        assert (
            out.dim() == 2 and out.shape[0] == N and out.shape[1] == M
        ), "Output shape must be (input.size(1), input.size(0)) for t_copy"
        in_s0, in_s1 = inp.stride()
        out_s0, out_s1 = out.stride()

        # NOTE(review): grid axis 1 is cdiv(M, BLOCK_N). CUDA is understood to
        # cap the second grid dimension at 65535, so inputs with more than
        # ~2M rows may fail to launch -- confirm and flatten the grid if so.
        def grid_2d(meta):
            return (
                triton.cdiv(N, meta["BLOCK_M"]),
                triton.cdiv(M, meta["BLOCK_N"]),
            )

        # Launch on the device that owns the tensors rather than whichever
        # CUDA device happens to be current (matters on multi-GPU hosts).
        with torch.cuda.device(inp.device):
            t_copy_2d_kernel[grid_2d](
                inp,
                out,
                in_s0,
                in_s1,
                out_s0,
                out_s1,
                M,
                N,
                BLOCK_M=32,
                BLOCK_N=32,
            )
        return

    if ndim == 1:
        n = inp.numel()
        in_stride = inp.stride(0)
        out_stride = out.stride(0)
        assert out.numel() == n, "Output size mismatch for 1D t_copy"
        block_size = 1024
    else:
        # Scalar copy: one element, strides are irrelevant.
        n, in_stride, out_stride, block_size = 1, 0, 0, 1

    def grid_1d(meta):
        return (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    # Same device guard as the 2-D path.
    with torch.cuda.device(inp.device):
        copy_1d_strided_kernel[grid_1d](
            inp,
            out,
            in_stride,
            out_stride,
            n,
            BLOCK_SIZE=block_size,
        )

121 

122 

def t_copy_out(
    input: torch.Tensor,
    out: torch.Tensor,
    memory_format: torch.memory_format | None = None,
):
    """Write the transpose of ``input`` into the caller-supplied ``out``.

    ``out`` is neither resized nor reallocated; its size/shape and dtype are
    checked by ``_launch_t_copy_kernel``. ``memory_format`` is accepted but
    not consulted here.

    Returns:
        The same ``out`` tensor that was passed in.
    """
    logger.debug("GEMS T_COPY_OUT")
    result = out
    _launch_t_copy_kernel(input, result)
    return result

131 

132 

def t_copy(input: torch.Tensor, memory_format: torch.memory_format | None = None):
    """Return a freshly allocated tensor holding the transpose of ``input``.

    A 0-D input yields a 0-D result and a 1-D input yields a contiguous
    tensor allocated like the input; a 2-D input of shape (M, N) yields a
    contiguous (N, M) result. The data is filled in by
    ``_launch_t_copy_kernel``. ``memory_format`` is accepted but not
    consulted here.

    Raises:
        RuntimeError: If ``input`` has more than two dimensions.
    """
    logger.debug("GEMS T_COPY")
    ndim = input.dim()

    # Reject unsupported ranks up front, before allocating anything.
    if ndim > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if ndim == 2:
        rows, cols = input.shape
        result = torch.empty(
            (cols, rows),
            dtype=input.dtype,
            device=input.device,
            memory_format=torch.contiguous_format,
        )
    elif ndim == 1:
        result = torch.empty_like(input, memory_format=torch.contiguous_format)
    else:
        result = torch.empty((), dtype=input.dtype, device=input.device)

    _launch_t_copy_kernel(input, result)
    return result