# Coverage for src/flag_gems/experimental_ops/addcdiv.py: 0% (61 statements)
# coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def addcdiv_kernel(
    self_ptr, t1_ptr, t2_ptr, out_ptr, n_elements, value, BLOCK_SIZE: tl.constexpr
):
    """Element-wise ``out = self + value * (t1 / t2)`` over a flat range.

    Each program instance processes one contiguous chunk of BLOCK_SIZE
    elements; out-of-range lanes are masked off.
    """
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    x = tl.load(self_ptr + idx, mask=in_bounds)
    num = tl.load(t1_ptr + idx, mask=in_bounds)
    den = tl.load(t2_ptr + idx, mask=in_bounds)

    # Materialize the scalar multiplier at the inputs' dtype before combining.
    scale = tl.full(idx.shape, value, x.dtype)
    tl.store(out_ptr + idx, x + (num / den) * scale, mask=in_bounds)
24def _prepare_addcdiv_tensors(self, tensor1, tensor2):
25 if not (self.is_cuda and tensor1.is_cuda and tensor2.is_cuda):
26 raise NotImplementedError(
27 "addcdiv Triton implementation requires CUDA tensors."
28 )
29 if not (self.device == tensor1.device == tensor2.device):
30 raise ValueError("All tensors must be on the same CUDA device.")
31 a, b, c = torch.broadcast_tensors(self, tensor1, tensor2)
32 # Determine common dtype for computation
33 common_dtype = torch.promote_types(torch.promote_types(a.dtype, b.dtype), c.dtype)
34 a = a.to(dtype=common_dtype).contiguous()
35 b = b.to(dtype=common_dtype).contiguous()
36 c = c.to(dtype=common_dtype).contiguous()
37 return a, b, c, common_dtype
40def _launch_addcdiv(a, b, c, out, value):
41 n_elements = out.numel()
42 BLOCK_SIZE = 1024
43 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
44 # value can be Python number or 0-d tensor; convert to float
45 if torch.is_tensor(value):
46 if value.numel() != 1:
47 raise ValueError("value must be a scalar.")
48 # move to same device if needed, then to host scalar
49 if value.device.type == "cuda" and value.device != a.device:
50 raise ValueError(
51 "Scalar tensor 'value' must be on the same device as inputs."
52 )
53 value = float(value.to(dtype=out.dtype).item())
54 else:
55 value = float(value)
56 addcdiv_kernel[grid](a, b, c, out, n_elements, value, BLOCK_SIZE=BLOCK_SIZE)
def addcdiv(self, tensor1, tensor2, *, value=1):
    """Element-wise ``self + value * tensor1 / tensor2`` on CUDA via Triton."""
    a, b, c, common_dtype = _prepare_addcdiv_tensors(self, tensor1, tensor2)
    result = torch.empty_like(a, dtype=common_dtype, device=a.device)
    _launch_addcdiv(a, b, c, result, value)
    return result
def addcdiv_out(self, tensor1, tensor2, *, value=1, out=None):
    """Compute ``self + value * tensor1 / tensor2`` element-wise into ``out``.

    ``out`` must be a CUDA tensor on the inputs' device with the promoted
    common dtype; it is resized to the broadcast shape when necessary.

    Raises:
        ValueError: if ``out`` is missing or on the wrong device.
        NotImplementedError: if ``out`` is not a CUDA tensor.
        TypeError: if ``out`` has the wrong dtype.
    """
    if out is None:
        raise ValueError("out tensor must be provided for addcdiv_out.")
    a, b, c, common_dtype = _prepare_addcdiv_tensors(self, tensor1, tensor2)

    # Validate out's placement and dtype; coerce its shape to the broadcast one.
    if not out.is_cuda:
        raise NotImplementedError("out tensor must be a CUDA tensor.")
    if out.device != a.device:
        raise ValueError("out tensor must be on the same device as inputs.")
    if out.dtype != common_dtype:
        raise TypeError(f"out tensor has dtype {out.dtype}, expected {common_dtype}.")
    if out.shape != a.shape:
        out.resize_(a.shape)

    if not out.is_contiguous():
        # Kernel assumes contiguous storage: compute into scratch, then copy.
        scratch = torch.empty_like(a, dtype=common_dtype, device=a.device)
        _launch_addcdiv(a, b, c, scratch, value)
        out.copy_(scratch)
    else:
        _launch_addcdiv(a, b, c, out, value)
    return out