Coverage for src/flag_gems/experimental_ops/addcmul_.py: 0%
60 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-17 02:35 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def addcmul_(self_ptr, t1_ptr, t2_ptr, n_elements, value, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise fused op: self += tensor1 * tensor2 * value.

    Each program instance processes one contiguous BLOCK_SIZE-wide chunk of
    the flattened tensors; out-of-range lanes are masked off.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    base = tl.load(self_ptr + offs, mask=in_bounds)
    lhs = tl.load(t1_ptr + offs, mask=in_bounds)
    rhs = tl.load(t2_ptr + offs, mask=in_bounds)

    # Accumulate in float32 for precision, then cast back to the storage dtype
    # before writing the result over `self` in place.
    acc = base.to(tl.float32) + lhs.to(tl.float32) * rhs.to(tl.float32) * value
    tl.store(self_ptr + offs, acc.to(base.dtype), mask=in_bounds)
# Keep a handle to the Triton kernel: the Python wrapper defined below
# reuses (shadows) the name `addcmul_`, so the launch site goes through
# this alias instead.
_addcmul_kernel = addcmul_
def addcmul_(*args, **kwargs):
    """In-place addcmul: ``self += tensor1 * tensor2 * value``.

    Mirrors ``torch.Tensor.addcmul_``: positional layout is
    ``(self, tensor1, tensor2, value)``; ``tensor1``, ``tensor2`` and
    ``value`` may also be passed as keywords (``alpha`` is accepted as a
    legacy alias for ``value``). Falls back to ``torch.ops.aten.addcmul_``
    for non-CUDA tensors, a non-contiguous ``self``, or dtypes other than
    float16/bfloat16/float32; otherwise launches the Triton kernel.

    Returns:
        ``self`` (mutated in place).

    Raises:
        TypeError: if ``self`` is missing, or tensor1/tensor2 cannot be
            resolved from args/kwargs.
    """
    if len(args) == 0:
        raise TypeError("addcmul_ expected at least 1 argument (self tensor)")
    self = args[0]

    # Positional arguments take precedence over keywords at every arity.
    # (Bug fix: the original ignored args[1] when only two positionals were
    # given, e.g. addcmul_(self, t1, tensor2=t2) raised spuriously.)
    tensor1 = args[1] if len(args) >= 2 else kwargs.get("tensor1", None)
    tensor2 = args[2] if len(args) >= 3 else kwargs.get("tensor2", None)
    if len(args) >= 4:
        value = args[3]
    else:
        value = kwargs.get("value", kwargs.get("alpha", 1.0))

    if tensor1 is None or tensor2 is None:
        raise TypeError("addcmul_ requires tensor1 and tensor2")

    value = float(value)

    # Broadcast tensor1 and tensor2 to self's shape. expand_as is the cheap
    # (view-only) path; broadcast_to is kept as a fallback for inputs that
    # expand_as rejects.
    try:
        t1 = tensor1.expand_as(self)
        t2 = tensor2.expand_as(self)
    except Exception:
        t1 = torch.broadcast_to(tensor1, self.shape)
        t2 = torch.broadcast_to(tensor2, self.shape)

    # Fall back to the reference aten op whenever the Triton path cannot be
    # used: non-CUDA tensors, a non-contiguous self (the in-place store
    # assumes contiguous layout), or an unsupported dtype.
    if (
        not (self.is_cuda and t1.is_cuda and t2.is_cuda)
        or not self.is_contiguous()
        or self.dtype not in (torch.float16, torch.bfloat16, torch.float32)
    ):
        return torch.ops.aten.addcmul_(self, tensor1, tensor2, value=value)

    # Materialize contiguous inputs in self's dtype for efficient loads.
    t1 = t1.contiguous()
    t2 = t2.contiguous()
    if t1.dtype != self.dtype:
        t1 = t1.to(self.dtype)
    if t2.dtype != self.dtype:
        t2 = t2.to(self.dtype)

    n_elements = self.numel()
    if n_elements == 0:
        return self

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _addcmul_kernel[grid](
        self,
        t1,
        t2,
        n_elements,
        value,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return self