Coverage for src/flag_gems/experimental_ops/_unsafe_view.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _copy_1d_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise 1D copy: move ``n_elements`` values from ``x_ptr`` to ``y_ptr``.

    Each program instance handles one contiguous BLOCK_SIZE-wide slice; the
    mask guards the final partial block when n_elements is not a multiple of
    BLOCK_SIZE.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(y_ptr + idx, vals, mask=in_bounds)

14 

15 

16def _infer_view_size(input_numel, size): 

17 if isinstance(size, torch.Size): 

18 size = list(size) 

19 elif isinstance(size, (list, tuple)): 

20 size = list(size) 

21 else: 

22 raise TypeError("size must be a list/tuple/torch.Size of ints") 

23 neg_one_count = sum(1 for s in size if s == -1) 

24 if neg_one_count > 1: 

25 raise ValueError("only one dimension can be inferred") 

26 known_prod = 1 

27 for s in size: 

28 if s != -1: 

29 if s < 0: 

30 raise ValueError( 

31 "invalid size, negative dimensions other than -1 not allowed" 

32 ) 

33 known_prod *= s if s != 0 else 1 

34 if neg_one_count == 0: 

35 prod = 1 

36 for s in size: 

37 prod *= s 

38 if prod != input_numel: 

39 raise ValueError( 

40 f"requested view size {tuple(size)} does not match input numel {input_numel}" 

41 ) 

42 return tuple(size) 

43 else: 

44 if known_prod == 0: 

45 if input_numel != 0: 

46 raise ValueError( 

47 f"cannot infer dimension with zero known product and non-zero numel {input_numel}" 

48 ) 

49 inferred = 0 

50 else: 

51 if input_numel % known_prod != 0: 

52 raise ValueError( 

53 "input numel not divisible by known product for inferred dimension" 

54 ) 

55 inferred = input_numel // known_prod 

56 out = [] 

57 inferred_used = False 

58 for s in size: 

59 if s == -1 and not inferred_used: 

60 out.append(int(inferred)) 

61 inferred_used = True 

62 else: 

63 out.append(int(s)) 

64 return tuple(out) 

65 

66 

def _launch_copy_kernel(src_flat: torch.Tensor, dst_flat: torch.Tensor):
    """Copy every element of ``src_flat`` into ``dst_flat`` via the Triton kernel.

    Both tensors must be CUDA-resident and share a dtype; both are assumed
    already flattened to 1D by the caller. An empty copy skips the launch.
    """
    assert src_flat.is_cuda and dst_flat.is_cuda, "tensors must be on CUDA device"
    assert src_flat.dtype == dst_flat.dtype, "dtypes must match"
    total = src_flat.numel()
    if total == 0:
        # Nothing to move; launching a zero-sized grid is pointless.
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _copy_1d_kernel[grid](src_flat, dst_flat, total, BLOCK_SIZE=1024)

75 

76 

def _unsafe_view(self: torch.Tensor, size):
    """Return a freshly allocated tensor of shape ``size`` holding a copy of
    ``self``'s elements.

    Unlike a true view, the result does not alias ``self``: the data is
    copied on-device by the Triton kernel. ``size`` may contain one ``-1``.
    """
    target_shape = _infer_view_size(self.numel(), size)
    result = torch.empty(target_shape, device=self.device, dtype=self.dtype)
    # Flatten both sides so the 1D copy kernel can move the data.
    _launch_copy_kernel(self.contiguous().view(-1), result.view(-1))
    return result

84 

85 

def _unsafe_view_out(self: torch.Tensor, size, out: "torch.Tensor | None" = None):
    """Out-variant of ``_unsafe_view``: copy ``self``'s elements into ``out``
    resized to ``size``.

    Args:
        self: source tensor (made contiguous before copying).
        size: target shape; may contain a single ``-1`` to be inferred.
        out: optional destination tensor. Must match ``self``'s device and
            dtype; it is resized in place to the resolved shape. When ``None``
            a new tensor is allocated on ``self``'s device.

    Returns:
        ``out`` (the provided tensor, or the newly allocated one).

    Raises:
        ValueError: if ``out`` is on a different device or has a different
            dtype than ``self``, or the shape is invalid for ``self``'s numel.
    """
    if out is None:
        # Placeholder allocation; resize_ below grows it to the target shape.
        out = torch.empty(0, device=self.device, dtype=self.dtype)
    if out.device != self.device:
        raise ValueError("out tensor must be on the same device as input")
    if out.dtype != self.dtype:
        raise ValueError("out tensor must have the same dtype as input")
    new_size = _infer_view_size(self.numel(), size)
    out.resize_(new_size)
    src_flat = self.contiguous().view(-1)
    dst_flat = out.view(-1)
    _launch_copy_kernel(src_flat, dst_flat)
    return out