Coverage for src/flag_gems/experimental_ops/multiply.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1from numbers import Number 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

@triton.jit
def _multiply_tt_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise product of two contiguous buffers: out[i] = x[i] * y[i].

    One program instance handles one BLOCK_SIZE-wide slice; lanes past
    n_elements are masked out on both the loads and the store.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs * rhs, mask=in_bounds)

17 

18 

@triton.jit
def _multiply_ts_kernel(x_ptr, scalar, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Scale a contiguous buffer by a runtime scalar: out[i] = x[i] * scalar.

    One program instance handles one BLOCK_SIZE-wide slice; lanes past
    n_elements are masked out on the load and the store.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    # Triton implicitly casts `scalar` to vals' dtype during the multiply.
    tl.store(out_ptr + idx, vals * scalar, mask=in_bounds)

28 

29 

30def _broadcast_shape(a_shape, b_shape): 

31 return torch.broadcast_shapes(a_shape, b_shape) 

32 

33 

34def _result_dtype_for(a, b): 

35 if isinstance(b, torch.Tensor): 

36 return torch.result_type(a, b) 

37 else: 

38 # b is a Python scalar/Number 

39 return torch.result_type(a, torch.tensor(b)) 

40 

41 

42def _ensure_cuda_device(t): 

43 if not (isinstance(t, torch.Tensor) and t.is_cuda): 

44 raise ValueError("Input tensors must be CUDA tensors for Triton kernels.") 

45 

46 

def _launch_tt(a_ctg, b_ctg, out_t):
    """Launch the tensor-tensor multiply kernel writing into out_t.

    All three tensors are expected to be contiguous and the same size;
    an empty output is a no-op (no kernel launch).
    """
    total = out_t.numel()
    if total == 0:
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _multiply_tt_kernel[grid](a_ctg, b_ctg, out_t, total, BLOCK_SIZE=1024)

53 

54 

def _launch_ts(a_ctg, scalar, out_t):
    """Launch the tensor-scalar multiply kernel writing into out_t.

    a_ctg and out_t are expected to be contiguous and the same size;
    an empty output is a no-op (no kernel launch).
    """
    total = out_t.numel()
    if total == 0:
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _multiply_ts_kernel[grid](a_ctg, scalar, out_t, total, BLOCK_SIZE=1024)

61 

62 

63def _multiply_impl(a, b, out=None): 

64 if not isinstance(a, torch.Tensor): 

65 raise TypeError("First argument must be a torch.Tensor") 

66 _ensure_cuda_device(a) 

67 device = a.device 

68 

69 # Determine result dtype and broadcasted shape 

70 res_dtype = _result_dtype_for(a, b) 

71 

72 if isinstance(b, torch.Tensor): 

73 _ensure_cuda_device(b) 

74 if b.device != device: 

75 raise ValueError("Both tensors must be on the same CUDA device.") 

76 out_shape = _broadcast_shape(a.shape, b.shape) 

77 a_ctg = a.to(res_dtype).expand(out_shape).contiguous() 

78 b_ctg = b.to(res_dtype).expand(out_shape).contiguous() 

79 if out is None: 

80 out_t = torch.empty(out_shape, device=device, dtype=res_dtype) 

81 else: 

82 if not isinstance(out, torch.Tensor) or not out.is_cuda: 

83 raise TypeError("out must be a CUDA torch.Tensor") 

84 if out.shape != out_shape: 

85 raise ValueError( 

86 f"out shape {out.shape} does not match broadcasted shape {out_shape}" 

87 ) 

88 if out.dtype != res_dtype: 

89 raise TypeError( 

90 f"out dtype {out.dtype} does not match result dtype {res_dtype}" 

91 ) 

92 if out.device != device: 

93 raise ValueError("out must be on the same CUDA device as inputs") 

94 out_t = out 

95 _launch_tt(a_ctg, b_ctg, out_t) 

96 return out_t 

97 elif isinstance(b, Number): 

98 # Scalar path 

99 out_shape = a.shape 

100 a_ctg = a.to(res_dtype).contiguous() 

101 if out is None: 

102 out_t = torch.empty(out_shape, device=device, dtype=res_dtype) 

103 else: 

104 if not isinstance(out, torch.Tensor) or not out.is_cuda: 

105 raise TypeError("out must be a CUDA torch.Tensor") 

106 if out.shape != out_shape: 

107 raise ValueError( 

108 f"out shape {out.shape} does not match input tensor shape {out_shape}" 

109 ) 

110 if out.dtype != res_dtype: 

111 raise TypeError( 

112 f"out dtype {out.dtype} does not match result dtype {res_dtype}" 

113 ) 

114 if out.device != device: 

115 raise ValueError("out must be on the same CUDA device as inputs") 

116 out_t = out 

117 _launch_ts(a_ctg, b, out_t) 

118 return out_t 

119 else: 

120 raise TypeError("Second argument must be a torch.Tensor or a Python scalar.") 

121 

122 

def multiply_Tensor(self, other):
    """Multiply `self` by tensor `other`, allocating a fresh output tensor."""
    return _multiply_impl(self, other)

125 

126 

def multiply_Scalar(self, other):
    """Multiply `self` by scalar `other`, allocating a fresh output tensor."""
    return _multiply_impl(self, other)

129 

130 

def multiply_out(self, other, out):
    """Multiply `self` by `other`, writing into the pre-allocated `out`."""
    result = _multiply_impl(self, other, out=out)
    return result