# Source: src/flag_gems/experimental_ops/t_copy.py
# (Recovered from a coverage.py v7.6.9 HTML report, 2026-03-22; the report
# showed 0% coverage over 64 statements.)
1import torch
2import triton
3import triton.language as tl
@triton.jit
def t_copy_2d_kernel(
    in_ptr,  # pointer to the input tensor's data
    out_ptr,  # pointer to the output (transposed) tensor's data
    in_stride_0,  # input stride along dim 0
    in_stride_1,  # input stride along dim 1
    out_stride_0,  # output stride along dim 0
    out_stride_1,  # output stride along dim 1
    M,  # input dim0
    N,  # input dim1
    BLOCK_M: tl.constexpr,  # tile extent along the output-row index i
    BLOCK_N: tl.constexpr,  # tile extent along the output-column index j
):
    """Copy a 2D tensor of shape (M, N) into its transpose of shape (N, M).

    Each program instance handles one tile of the output: for every in-bounds
    pair (i, j) it writes out[i, j] = in[j, i]. Strides are passed explicitly,
    so both tensors may have arbitrary (non-contiguous) layouts.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    i = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # corresponds to out rows [0..N)
    j = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # corresponds to out cols [0..M)

    # Promote to int64 before multiplying by strides so offset arithmetic
    # cannot overflow 32 bits on large tensors; the broadcast shapes below
    # make every offset/mask expression a [BLOCK_N, BLOCK_M] tile.
    i64 = i.to(tl.int64)[None, :]  # shape [1, BM]
    j64 = j.to(tl.int64)[:, None]  # shape [BN, 1]

    # out shape = (N, M)
    mask = (i64 < N) & (j64 < M)

    # in index = (j, i) -> in_offset = j*in_stride_0 + i*in_stride_1
    in_offsets = j64 * in_stride_0 + i64 * in_stride_1
    # out index = (i, j) -> out_offset = i*out_stride_0 + j*out_stride_1
    out_offsets = i64 * out_stride_0 + j64 * out_stride_1

    x = tl.load(in_ptr + in_offsets, mask=mask)
    tl.store(out_ptr + out_offsets, x, mask=mask)
@triton.jit
def copy_1d_strided_kernel(
    in_ptr,
    out_ptr,
    in_stride,
    out_stride,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy N elements between two strided 1D buffers.

    Element k is read from in_ptr + k * in_stride and written to
    out_ptr + k * out_stride, so either side may be non-contiguous.
    Each program instance handles one BLOCK_SIZE-wide slice of [0, N).
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    elem = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = elem < N
    # Widen to int64 before the stride multiply so offsets cannot
    # overflow 32 bits on very large tensors.
    elem64 = elem.to(tl.int64)
    values = tl.load(in_ptr + elem64 * in_stride, mask=in_bounds)
    tl.store(out_ptr + elem64 * out_stride, values, mask=in_bounds)
def _launch_t_copy_kernel(inp: torch.Tensor, out: torch.Tensor):
    """Dispatch the appropriate Triton kernel for a t()-style copy.

    Copies ``inp`` into the pre-allocated ``out``: for 2D inputs ``out``
    receives the transpose; for 0-D and 1-D inputs it receives a plain
    element-wise copy (matching ``torch.t`` semantics).

    Args:
        inp: CUDA tensor with at most 2 dimensions.
        out: Pre-allocated CUDA tensor of the same dtype. For 2D inputs its
            shape must be (inp.size(1), inp.size(0)); for 1D inputs its numel
            must match inp's.

    Raises:
        ValueError: if either tensor is not on CUDA, the dtypes differ, or
            the output shape/size does not match the input.
        RuntimeError: if ``inp`` has more than 2 dimensions.
    """
    # Validate with explicit raises rather than assert so the checks
    # survive `python -O` (asserts are stripped under optimization).
    if not (inp.is_cuda and out.is_cuda):
        raise ValueError("t_copy kernels require CUDA tensors")
    if inp.dtype != out.dtype:
        raise ValueError("dtype mismatch between input and output")

    dim = inp.dim()
    if dim == 0:
        # Scalar copy: exactly one element, so strides are irrelevant (0).
        n = 1
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        copy_1d_strided_kernel[grid](
            inp,
            out,
            0,
            0,
            n,
            BLOCK_SIZE=1,
        )
    elif dim == 1:
        n = inp.numel()
        # Check the size before touching out.stride(0), which would raise
        # on a 0-dim output.
        if out.numel() != n:
            raise ValueError("Output size mismatch for 1D t_copy")
        in_stride = inp.stride(0)
        out_stride = out.stride(0)
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        copy_1d_strided_kernel[grid](
            inp,
            out,
            in_stride,
            out_stride,
            n,
            BLOCK_SIZE=1024,
        )
    elif dim == 2:
        M, N = inp.shape  # input dims; out is the transpose, shape (N, M)
        if not (out.dim() == 2 and out.shape[0] == N and out.shape[1] == M):
            raise ValueError(
                "Output shape must be (input.size(1), input.size(0)) for t_copy"
            )
        in_s0, in_s1 = inp.stride()
        out_s0, out_s1 = out.stride()
        # 2D grid of tiles covering the *output*: N rows by M columns.
        grid = lambda meta: (
            triton.cdiv(N, meta["BLOCK_M"]),
            triton.cdiv(M, meta["BLOCK_N"]),
        )
        t_copy_2d_kernel[grid](
            inp,
            out,
            in_s0,
            in_s1,
            out_s0,
            out_s1,
            M,
            N,
            BLOCK_M=32,
            BLOCK_N=32,
        )
    else:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")
def t_copy_out(
    input: torch.Tensor,
    out: torch.Tensor,
    memory_format: torch.memory_format | None = None,
):
    """Transpose-copy ``input`` into the caller-provided ``out`` tensor.

    ``memory_format`` is accepted for signature compatibility but is not
    consulted here; the copy honors whatever strides ``out`` already has.
    Returns ``out`` for convenience.
    """
    _launch_t_copy_kernel(input, out)
    return out
def t_copy(input: torch.Tensor, memory_format: torch.memory_format | None = None):
    """Return a freshly allocated transpose-copy of ``input``.

    For a 2D input of shape (M, N) the result has shape (N, M); 0-D and 1-D
    inputs are copied as-is (matching ``torch.t`` semantics). The result is
    always contiguous. ``memory_format`` is accepted for signature
    compatibility but not consulted.

    Raises:
        RuntimeError: if ``input`` has more than 2 dimensions.
    """
    ndims = input.dim()
    if ndims > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if ndims == 2:
        rows, cols = input.shape
        # Transposed shape: dim0 and dim1 swap.
        result = torch.empty(
            (cols, rows),
            dtype=input.dtype,
            device=input.device,
            memory_format=torch.contiguous_format,
        )
    elif ndims == 1:
        result = torch.empty_like(input, memory_format=torch.contiguous_format)
    else:
        # 0-D scalar.
        result = torch.empty((), dtype=input.dtype, device=input.device)

    _launch_t_copy_kernel(input, result)
    return result