Coverage for src/flag_gems/experimental_ops/_unsafe_view.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _copy_1d_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy ``n_elements`` values from ``x_ptr`` to ``y_ptr``.

    Each program instance handles one contiguous chunk of BLOCK_SIZE
    elements; the tail chunk is masked so no out-of-range lane loads
    or stores.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    values = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(y_ptr + idx, values, mask=in_bounds)
16def _infer_view_size(input_numel, size):
17 if isinstance(size, torch.Size):
18 size = list(size)
19 elif isinstance(size, (list, tuple)):
20 size = list(size)
21 else:
22 raise TypeError("size must be a list/tuple/torch.Size of ints")
23 neg_one_count = sum(1 for s in size if s == -1)
24 if neg_one_count > 1:
25 raise ValueError("only one dimension can be inferred")
26 known_prod = 1
27 for s in size:
28 if s != -1:
29 if s < 0:
30 raise ValueError(
31 "invalid size, negative dimensions other than -1 not allowed"
32 )
33 known_prod *= s if s != 0 else 1
34 if neg_one_count == 0:
35 prod = 1
36 for s in size:
37 prod *= s
38 if prod != input_numel:
39 raise ValueError(
40 f"requested view size {tuple(size)} does not match input numel {input_numel}"
41 )
42 return tuple(size)
43 else:
44 if known_prod == 0:
45 if input_numel != 0:
46 raise ValueError(
47 f"cannot infer dimension with zero known product and non-zero numel {input_numel}"
48 )
49 inferred = 0
50 else:
51 if input_numel % known_prod != 0:
52 raise ValueError(
53 "input numel not divisible by known product for inferred dimension"
54 )
55 inferred = input_numel // known_prod
56 out = []
57 inferred_used = False
58 for s in size:
59 if s == -1 and not inferred_used:
60 out.append(int(inferred))
61 inferred_used = True
62 else:
63 out.append(int(s))
64 return tuple(out)
def _launch_copy_kernel(src_flat: torch.Tensor, dst_flat: torch.Tensor):
    """Copy ``src_flat`` into ``dst_flat`` element-wise via the Triton kernel.

    Both tensors are expected to be flat (1-D) CUDA tensors of the same
    dtype. A no-op when the source is empty.

    Raises:
        ValueError: if either tensor is not on CUDA, dtypes differ, or the
            destination is smaller than the source.
    """
    # Explicit raises instead of asserts: asserts are stripped under -O,
    # and these are genuine input-validation checks.
    if not (src_flat.is_cuda and dst_flat.is_cuda):
        raise ValueError("tensors must be on CUDA device")
    if src_flat.dtype != dst_flat.dtype:
        raise ValueError("dtypes must match")
    n_elements = src_flat.numel()
    # The kernel masks only on the source length, so an undersized
    # destination would be written out of bounds — reject it up front.
    if dst_flat.numel() < n_elements:
        raise ValueError("destination has fewer elements than source")
    if n_elements == 0:
        return
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _copy_1d_kernel[grid](src_flat, dst_flat, n_elements, BLOCK_SIZE=1024)
def _unsafe_view(self: torch.Tensor, size):
    """Return a freshly-allocated tensor of shape ``size`` containing a
    copy of ``self``'s data (resolving a single -1 dim if present)."""
    target_shape = _infer_view_size(self.numel(), size)
    result = torch.empty(target_shape, device=self.device, dtype=self.dtype)
    # Flatten both sides and copy element-wise on the GPU.
    _launch_copy_kernel(self.contiguous().view(-1), result.view(-1))
    return result
def _unsafe_view_out(self: torch.Tensor, size, out: torch.Tensor = None):
    """Like ``_unsafe_view`` but writes the result into ``out``.

    ``out`` is resized in place to the resolved shape; when omitted, a
    fresh tensor is allocated. Returns ``out``.
    """
    if out is None:
        # Allocate a placeholder; resize_ below gives it the real shape.
        out = torch.empty(0, device=self.device, dtype=self.dtype)
    if out.device != self.device:
        raise ValueError("out tensor must be on the same device as input")
    if out.dtype != self.dtype:
        raise ValueError("out tensor must have the same dtype as input")
    out.resize_(_infer_view_size(self.numel(), size))
    # Flatten both sides and copy element-wise on the GPU.
    _launch_copy_kernel(self.contiguous().view(-1), out.view(-1))
    return out