# Coverage for src/flag_gems/experimental_ops/multiply.py: 0% (91 statements)
# Report generated by coverage.py v7.6.9, created at 2026-03-12 02:21 +0800
1from numbers import Number
3import torch
4import triton
5import triton.language as tl
@triton.jit
def _multiply_tt_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise multiply of two flat buffers: out[i] = x[i] * y[i].

    Each program instance handles one BLOCK_SIZE-wide slice of the buffers;
    the mask guards the final partial block.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    lhs = tl.load(x_ptr + offs, mask=in_bounds)
    rhs = tl.load(y_ptr + offs, mask=in_bounds)
    tl.store(out_ptr + offs, lhs * rhs, mask=in_bounds)
@triton.jit
def _multiply_ts_kernel(x_ptr, scalar, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise multiply of a flat buffer by a scalar: out[i] = x[i] * scalar.

    The scalar is implicitly cast to x's dtype by Triton during the multiply.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    vals = tl.load(x_ptr + offs, mask=in_bounds)
    tl.store(out_ptr + offs, vals * scalar, mask=in_bounds)
30def _broadcast_shape(a_shape, b_shape):
31 return torch.broadcast_shapes(a_shape, b_shape)
34def _result_dtype_for(a, b):
35 if isinstance(b, torch.Tensor):
36 return torch.result_type(a, b)
37 else:
38 # b is a Python scalar/Number
39 return torch.result_type(a, torch.tensor(b))
42def _ensure_cuda_device(t):
43 if not (isinstance(t, torch.Tensor) and t.is_cuda):
44 raise ValueError("Input tensors must be CUDA tensors for Triton kernels.")
def _launch_tt(a_ctg, b_ctg, out_t):
    """Launch the tensor-tensor multiply kernel covering every element of *out_t*.

    All three tensors must be contiguous and the same size; an empty output is
    a no-op since a zero-sized launch grid is invalid.
    """
    total = out_t.numel()
    if total == 0:
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _multiply_tt_kernel[grid](a_ctg, b_ctg, out_t, total, BLOCK_SIZE=1024)
def _launch_ts(a_ctg, scalar, out_t):
    """Launch the tensor-scalar multiply kernel covering every element of *out_t*.

    *a_ctg* and *out_t* must be contiguous and the same size; an empty output
    is a no-op since a zero-sized launch grid is invalid.
    """
    total = out_t.numel()
    if total == 0:
        return

    def grid(meta):
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    _multiply_ts_kernel[grid](a_ctg, scalar, out_t, total, BLOCK_SIZE=1024)
63def _multiply_impl(a, b, out=None):
64 if not isinstance(a, torch.Tensor):
65 raise TypeError("First argument must be a torch.Tensor")
66 _ensure_cuda_device(a)
67 device = a.device
69 # Determine result dtype and broadcasted shape
70 res_dtype = _result_dtype_for(a, b)
72 if isinstance(b, torch.Tensor):
73 _ensure_cuda_device(b)
74 if b.device != device:
75 raise ValueError("Both tensors must be on the same CUDA device.")
76 out_shape = _broadcast_shape(a.shape, b.shape)
77 a_ctg = a.to(res_dtype).expand(out_shape).contiguous()
78 b_ctg = b.to(res_dtype).expand(out_shape).contiguous()
79 if out is None:
80 out_t = torch.empty(out_shape, device=device, dtype=res_dtype)
81 else:
82 if not isinstance(out, torch.Tensor) or not out.is_cuda:
83 raise TypeError("out must be a CUDA torch.Tensor")
84 if out.shape != out_shape:
85 raise ValueError(
86 f"out shape {out.shape} does not match broadcasted shape {out_shape}"
87 )
88 if out.dtype != res_dtype:
89 raise TypeError(
90 f"out dtype {out.dtype} does not match result dtype {res_dtype}"
91 )
92 if out.device != device:
93 raise ValueError("out must be on the same CUDA device as inputs")
94 out_t = out
95 _launch_tt(a_ctg, b_ctg, out_t)
96 return out_t
97 elif isinstance(b, Number):
98 # Scalar path
99 out_shape = a.shape
100 a_ctg = a.to(res_dtype).contiguous()
101 if out is None:
102 out_t = torch.empty(out_shape, device=device, dtype=res_dtype)
103 else:
104 if not isinstance(out, torch.Tensor) or not out.is_cuda:
105 raise TypeError("out must be a CUDA torch.Tensor")
106 if out.shape != out_shape:
107 raise ValueError(
108 f"out shape {out.shape} does not match input tensor shape {out_shape}"
109 )
110 if out.dtype != res_dtype:
111 raise TypeError(
112 f"out dtype {out.dtype} does not match result dtype {res_dtype}"
113 )
114 if out.device != device:
115 raise ValueError("out must be on the same CUDA device as inputs")
116 out_t = out
117 _launch_ts(a_ctg, b, out_t)
118 return out_t
119 else:
120 raise TypeError("Second argument must be a torch.Tensor or a Python scalar.")
def multiply_Tensor(self, other):
    """Tensor-by-tensor multiply entry point (no ``out`` argument)."""
    return _multiply_impl(self, other)
def multiply_Scalar(self, other):
    """Tensor-by-scalar multiply entry point (no ``out`` argument)."""
    return _multiply_impl(self, other)
def multiply_out(self, other, out):
    """Multiply entry point writing the result into a caller-provided *out* tensor."""
    result = _multiply_impl(self, other, out=out)
    return result