Coverage for src/flag_gems/experimental_ops/addcmul_.py: 0%

60 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-16 02:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def addcmul_(self_ptr, t1_ptr, t2_ptr, n_elements, value, BLOCK_SIZE: tl.constexpr):
    """In-place elementwise ``self += tensor1 * tensor2 * value``.

    One program handles one BLOCK_SIZE-wide slice of the flattened tensors.
    Arithmetic is done in float32 and the result is cast back to the
    storage dtype before being written over ``self_ptr``.
    """
    program = tl.program_id(axis=0)
    idx = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    base = tl.load(self_ptr + idx, mask=in_bounds)
    lhs = tl.load(t1_ptr + idx, mask=in_bounds)
    rhs = tl.load(t2_ptr + idx, mask=in_bounds)

    # Accumulate in float32 for precision regardless of the input dtype.
    acc = base.to(tl.float32) + lhs.to(tl.float32) * rhs.to(tl.float32) * value

    tl.store(self_ptr + idx, acc.to(base.dtype), mask=in_bounds)

25 

26 

# Keep a private handle to the Triton kernel: the Python wrapper defined
# below rebinds the module-level name ``addcmul_``.
_addcmul_kernel = addcmul_

28 

29 

def addcmul_(*args, **kwargs):
    """In-place addcmul: ``self += value * tensor1 * tensor2``; returns ``self``.

    Mirrors ``torch.Tensor.addcmul_(tensor1, tensor2, value=1)``. The Triton
    kernel is used for contiguous CUDA tensors of dtype float16/bfloat16/
    float32; everything else falls back to ``torch.ops.aten.addcmul_``.

    Args (positional or keyword):
        self: tensor updated in place (first positional argument, required).
        tensor1, tensor2: operands, broadcast to ``self``'s shape.
        value: scalar multiplier, default 1.0 (``alpha`` also accepted).

    Raises:
        TypeError: if ``self``, ``tensor1`` or ``tensor2`` is missing.
    """
    if not args:
        raise TypeError("addcmul_ expected at least 1 argument (self tensor)")
    self = args[0]

    # Accept tensor1/tensor2/value positionally or as keywords, including
    # mixed forms such as addcmul_(x, t1, tensor2=t2) which the previous
    # parser dropped on the floor.
    tensor1 = args[1] if len(args) >= 2 else kwargs.get("tensor1")
    tensor2 = args[2] if len(args) >= 3 else kwargs.get("tensor2")
    if len(args) >= 4:
        value = args[3]
    else:
        value = kwargs.get("value", kwargs.get("alpha", 1.0))

    if tensor1 is None or tensor2 is None:
        raise TypeError("addcmul_ requires tensor1 and tensor2")

    # A value we cannot turn into a float (e.g. complex) goes to the aten
    # fallback instead of raising here; aten does its own Scalar handling
    # and broadcasting, so the original tensors are passed through.
    try:
        value = float(value)
    except (TypeError, ValueError):
        return torch.ops.aten.addcmul_(self, tensor1, tensor2, value=value)

    # Fallback conditions for the Triton path:
    # - non-CUDA tensors
    # - non-contiguous self (in-place update with non-contiguous memory)
    # - unsupported dtype
    if (
        not (self.is_cuda and tensor1.is_cuda and tensor2.is_cuda)
        or not self.is_contiguous()
        or self.dtype not in (torch.float16, torch.bfloat16, torch.float32)
    ):
        return torch.ops.aten.addcmul_(self, tensor1, tensor2, value=value)

    # Broadcast tensor1 and tensor2 to match self's shape (only needed on
    # the kernel path, so done after the fallback decision).
    try:
        t1 = tensor1.expand_as(self)
        t2 = tensor2.expand_as(self)
    except Exception:
        t1 = torch.broadcast_to(tensor1, self.shape)
        t2 = torch.broadcast_to(tensor2, self.shape)

    # Make inputs contiguous for efficient loads and match self's dtype.
    t1 = t1.contiguous()
    t2 = t2.contiguous()
    if t1.dtype != self.dtype:
        t1 = t1.to(self.dtype)
    if t2.dtype != self.dtype:
        t2 = t2.to(self.dtype)

    n_elements = self.numel()
    if n_elements == 0:
        return self

    BLOCK_SIZE = 1024

    def grid(meta):
        # 1-D launch: one program per BLOCK_SIZE elements.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _addcmul_kernel[grid](
        self,
        t1,
        t2,
        n_elements,
        value,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return self