Coverage for src/flag_gems/ops/t_copy.py: 53% (68 statements)
coverage.py v7.6.9, created at 2026-03-27 02:51 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
@triton.jit
def t_copy_2d_kernel(
    in_ptr,        # pointer to the input tensor, logical shape (M, N)
    out_ptr,       # pointer to the output tensor, logical shape (N, M)
    in_stride_0,   # input stride along dim 0 (in elements)
    in_stride_1,   # input stride along dim 1 (in elements)
    out_stride_0,  # output stride along dim 0 (in elements)
    out_stride_1,  # output stride along dim 1 (in elements)
    M,  # input dim0
    N,  # input dim1
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Transpose-copy kernel: out[i, j] = in[j, i] for a 2-D tensor.

    Each program instance covers a BLOCK_M x BLOCK_N tile of the output.
    Offsets are computed from explicit strides, so both input and output
    may be non-contiguous.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    i = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # corresponds to out rows [0..N)
    j = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # corresponds to out cols [0..M)

    # Promote indices to int64 before the stride multiplies so the offset
    # arithmetic cannot overflow 32 bits on large tensors.
    i64 = i.to(tl.int64)[None, :]  # shape [1, BM]
    j64 = j.to(tl.int64)[:, None]  # shape [BN, 1]

    # out shape = (N, M); the broadcasted mask covers the [BN, BM] tile.
    mask = (i64 < N) & (j64 < M)

    # in index = (j, i) -> in_offset = j*in_stride_0 + i*in_stride_1
    in_offsets = j64 * in_stride_0 + i64 * in_stride_1
    # out index = (i, j) -> out_offset = i*out_stride_0 + j*out_stride_1
    out_offsets = i64 * out_stride_0 + j64 * out_stride_1

    x = tl.load(in_ptr + in_offsets, mask=mask)
    tl.store(out_ptr + out_offsets, x, mask=mask)
@triton.jit
def copy_1d_strided_kernel(
    in_ptr,      # pointer to the input tensor
    out_ptr,     # pointer to the output tensor
    in_stride,   # element stride between consecutive input values
    out_stride,  # element stride between consecutive output values
    N,           # number of elements to copy
    BLOCK_SIZE: tl.constexpr,
):
    """Strided 1-D copy: out[k * out_stride] = in[k * in_stride] for k in [0, N).

    A stride of 0 (used by the 0-D launch path) makes every lane read/write
    the single scalar element.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < N
    # int64 offsets guard against overflow when N * stride exceeds 2**31.
    offs64 = offs.to(tl.int64)
    in_idx = offs64 * in_stride
    out_idx = offs64 * out_stride
    x = tl.load(in_ptr + in_idx, mask=mask)
    tl.store(out_ptr + out_idx, x, mask=mask)
def _launch_t_copy_kernel(inp: torch.Tensor, out: torch.Tensor):
    """Dispatch the transpose-copy kernels for tensors with at most 2 dims.

    0-D and 1-D inputs go through the strided 1-D copy kernel; 2-D inputs go
    through the tiled transpose kernel, writing ``out`` of shape
    ``(inp.size(1), inp.size(0))``. Raises RuntimeError for dim > 2.
    """
    assert inp.is_cuda and out.is_cuda, "t_copy kernels require CUDA tensors"
    assert inp.dtype == out.dtype, "dtype mismatch between input and output"

    ndim = inp.dim()
    if ndim > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if ndim < 2:
        # 0-D and 1-D share the strided 1-D kernel; they differ only in
        # element count, strides, and block size.
        if ndim == 0:
            numel, stride_in, stride_out, block = 1, 0, 0, 1
        else:
            numel = inp.numel()
            assert out.numel() == numel, "Output size mismatch for 1D t_copy"
            stride_in = inp.stride(0)
            stride_out = out.stride(0)
            block = 1024
        grid = (triton.cdiv(numel, block),)
        copy_1d_strided_kernel[grid](
            inp,
            out,
            stride_in,
            stride_out,
            numel,
            BLOCK_SIZE=block,
        )
        return

    # 2-D: transpose-copy (rows, cols) -> (cols, rows).
    rows, cols = inp.shape
    assert (
        out.dim() == 2 and out.shape[0] == cols and out.shape[1] == rows
    ), "Output shape must be (input.size(1), input.size(0)) for t_copy"
    block_m, block_n = 32, 32
    # Grid axis 0 walks output rows (cols of the input), axis 1 output cols.
    grid = (triton.cdiv(cols, block_m), triton.cdiv(rows, block_n))
    t_copy_2d_kernel[grid](
        inp,
        out,
        inp.stride(0),
        inp.stride(1),
        out.stride(0),
        out.stride(1),
        rows,
        cols,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
    )
def t_copy_out(
    input: torch.Tensor,
    out: torch.Tensor,
    memory_format: torch.memory_format | None = None,
):
    """Transpose-copy ``input`` into the caller-provided ``out`` tensor.

    ``memory_format`` is accepted for signature compatibility; the layout of
    the result is whatever ``out`` already has. Returns ``out``.
    """
    logger.debug("GEMS T_COPY_OUT")
    _launch_t_copy_kernel(input, out)
    return out
def t_copy(input: torch.Tensor, memory_format: torch.memory_format | None = None):
    """Return a transpose-copy of ``input`` (at most 2 dims).

    A fresh contiguous tensor is allocated: shape () for 0-D, the same shape
    for 1-D, and the transposed shape for 2-D. Raises RuntimeError for
    tensors with more than 2 dims.
    """
    logger.debug("GEMS T_COPY")
    ndim = input.dim()
    if ndim > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if ndim == 2:
        rows, cols = input.shape
        result = torch.empty(
            (cols, rows),
            dtype=input.dtype,
            device=input.device,
            memory_format=torch.contiguous_format,
        )
    elif ndim == 1:
        result = torch.empty_like(input, memory_format=torch.contiguous_format)
    else:
        # 0-D scalar.
        result = torch.empty((), dtype=input.dtype, device=input.device)

    _launch_t_copy_kernel(input, result)
    return result