Coverage for src/flag_gems/experimental_ops/expand.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def expand(
    x_ptr,
    out_ptr,
    n_elements,
    ndims,
    out_shape_ptr,
    out_cumprod_ptr,
    in_stride_ptr,
    BLOCK_SIZE: tl.constexpr,
    MAX_DIMS: tl.constexpr,
):
    # Gather-copy kernel: for every output linear index, reconstruct the
    # multi-dimensional coordinate and map it back to an input linear offset
    # using the (broadcast-aware, possibly zero) input strides.
    #
    # The host pads out_shape with 1s, out_cumprod with 1s, and in_stride with
    # 0s up to MAX_DIMS, so iterating the full MAX_DIMS range is harmless:
    # padded dims always contribute (idx // 1) % 1 * 0 == 0.
    # `ndims` is accepted for launch-signature compatibility but unused here.
    block_id = tl.program_id(axis=0)
    out_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = out_idx < n_elements

    # Accumulated input linear offset for each lane of the block.
    src_off = tl.zeros([BLOCK_SIZE], dtype=tl.int64)

    # Peel off one dimension per iteration (MAX_DIMS is constexpr, so this
    # loop is unrolled at compile time).
    for dim in range(MAX_DIMS):
        dim_size = tl.load(out_shape_ptr + dim)
        right_prod = tl.load(out_cumprod_ptr + dim)  # product of sizes to the right
        src_stride = tl.load(in_stride_ptr + dim)  # 0 for broadcast dims
        coord = (out_idx // right_prod) % dim_size
        src_off += coord * src_stride

    # Broadcast-read from the input and write densely to the output.
    vals = tl.load(x_ptr + src_off, mask=valid)
    tl.store(out_ptr + out_idx, vals, mask=valid)

40 

41 

# Preserve a handle to the Triton kernel before the Python wrapper defined
# below shadows the module-level name `expand`.
_expand_kernel = expand

43 

44 

def expand(*args, **kwargs):
    """Broadcast a tensor to a requested shape by materializing a copy.

    Mirrors ``torch.Tensor.expand`` semantics — including ``-1`` meaning
    "keep the input's size in this dimension" — except that the result is a
    freshly allocated contiguous tensor filled by a Triton kernel rather
    than a zero-copy view.

    Positional args:
        args[0] (torch.Tensor): input tensor.
        args[1] (list | tuple | torch.Size): requested output shape.

    Keyword args:
        implicit (bool): accepted for signature compatibility; unused.

    Returns:
        torch.Tensor: tensor of the requested shape, same dtype/device as
        the input.

    Raises:
        TypeError: if ``size`` is not a list/tuple/torch.Size.
        RuntimeError: if ``size`` has fewer dims than the input, if a
            non-singleton input dim mismatches its target, or if ``-1`` is
            given for a new leading dimension that the input does not have
            (matching ``torch.Tensor.expand``, where -1 is only valid for
            existing dimensions).
    """
    x = args[0]
    size = args[1]
    implicit = kwargs.get(  # noqa: F841
        "implicit", False
    )  # not used but accepted for signature compatibility

    if not isinstance(size, (list, tuple, torch.Size)):
        raise TypeError("expand size must be a list/tuple/torch.Size of ints")

    size = list(size)
    in_shape = list(x.shape)
    in_strides = list(x.stride())

    out_ndim = len(size)
    in_ndim = len(in_shape)

    if in_ndim > out_ndim:
        raise RuntimeError(
            f"expand: requested size has fewer dimensions ({out_ndim}) than input ({in_ndim})"
        )

    # Pad input shape/strides on the left to match output ndim. New leading
    # dims behave like size-1 broadcast dims, so their stride is 0.
    n_new_dims = out_ndim - in_ndim
    in_shape = [1] * n_new_dims + in_shape
    in_strides = [0] * n_new_dims + in_strides

    # Resolve -1 entries and validate broadcastability.
    out_shape = []
    for d in range(out_ndim):
        req = size[d]
        src = in_shape[d]
        if req == -1:
            # torch.Tensor.expand forbids -1 for dimensions that do not
            # exist in the input; previously this silently resolved to 1.
            if d < n_new_dims:
                raise RuntimeError(
                    f"The expanded size of the tensor (-1) isn't allowed in a leading, non-existing dimension {d}"
                )
            target = src
        else:
            target = req
            if src != target and src != 1:
                raise RuntimeError(
                    f"The expanded size of the tensor ({target}) must match the existing size ({src}) at non-singleton "
                    f"dimension {d}. Target sizes must be the same, or -1, or the size of dimension in the original tensor must be 1."  # noqa: E501
                )
        out_shape.append(int(target))

    # Effective input strides: 0 for broadcasted dims, original stride otherwise.
    in_stride_eff = [
        int(in_strides[d]) if in_shape[d] != 1 else 0 for d in range(out_ndim)
    ]

    # Decomposition multipliers: product of sizes to the right of each dim,
    # used by the kernel to turn a linear index into per-dim coordinates.
    out_cumprod_right = [0] * out_ndim
    prod = 1
    for d in range(out_ndim - 1, -1, -1):
        out_cumprod_right[d] = prod
        prod *= out_shape[d]

    # Allocate output (contiguous, same dtype/device as input).
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)

    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to copy; skip the kernel launch entirely.
        return out

    # Triton kernel parameters.
    BLOCK_SIZE = 1024
    MAX_DIMS = max(out_ndim, 1)  # at least 1
    # Use a static upper bound for compilation reuse (16 covers typical
    # tensors), growing it only when the output has more dims than that.
    STATIC_MAX = 16
    if MAX_DIMS > STATIC_MAX:
        STATIC_MAX = MAX_DIMS

    # Device arrays for shapes/strides, padded so that extra dims are
    # no-ops in the kernel (shape 1, cumprod 1, stride 0).
    pad_len = STATIC_MAX - out_ndim
    out_shape_arr = torch.tensor(
        out_shape + [1] * pad_len, dtype=torch.int64, device=x.device
    )
    out_cumprod_arr = torch.tensor(
        out_cumprod_right + [1] * pad_len, dtype=torch.int64, device=x.device
    )
    in_stride_arr = torch.tensor(
        in_stride_eff + [0] * pad_len, dtype=torch.int64, device=x.device
    )

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)  # noqa: E731

    _expand_kernel[grid](
        x,
        out,
        n_elements,
        out_ndim,
        out_shape_arr,
        out_cumprod_arr,
        in_stride_arr,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DIMS=STATIC_MAX,
    )
    return out