Coverage for src/flag_gems/experimental_ops/addcdiv.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-21 14:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def addcdiv_kernel(
    self_ptr, t1_ptr, t2_ptr, out_ptr, n_elements, value, BLOCK_SIZE: tl.constexpr
):
    """Element-wise ``out = self + value * (t1 / t2)`` over flat 1-D buffers.

    Each program instance processes one BLOCK_SIZE-wide slice of the
    flattened tensors; ``mask`` disables the out-of-range lanes of the
    ragged final block. All four buffers are assumed contiguous and of
    equal length ``n_elements`` (the Python-side launcher guarantees this).
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard for the tail block when n_elements is not a multiple of BLOCK_SIZE.
    mask = offsets < n_elements

    # Masked lanes load an unspecified default value; they are never stored,
    # so any garbage produced by the division below is discarded.
    a = tl.load(self_ptr + offsets, mask=mask)
    b = tl.load(t1_ptr + offsets, mask=mask)
    c = tl.load(t2_ptr + offsets, mask=mask)

    # Materialize the host scalar as a vector in a's dtype so the multiply
    # happens in the operands' dtype rather than an implicitly promoted one.
    val_vec = tl.full(offsets.shape, value, a.dtype)
    result = a + (b / c) * val_vec
    tl.store(out_ptr + offsets, result, mask=mask)

22 

23 

def _prepare_addcdiv_tensors(self, tensor1, tensor2):
    """Validate devices, broadcast the three inputs, and unify their dtype.

    Returns ``(a, b, c, common_dtype)``: contiguous broadcast copies of the
    inputs cast to their promoted dtype, plus that dtype.

    Raises:
        NotImplementedError: if any input is not a CUDA tensor.
        ValueError: if the inputs live on different CUDA devices.
    """
    # All three operands must be CUDA tensors sharing one device.
    if not all(t.is_cuda for t in (self, tensor1, tensor2)):
        raise NotImplementedError(
            "addcdiv Triton implementation requires CUDA tensors."
        )
    if self.device != tensor1.device or self.device != tensor2.device:
        raise ValueError("All tensors must be on the same CUDA device.")

    a, b, c = torch.broadcast_tensors(self, tensor1, tensor2)
    # Pairwise promotion yields the common computation dtype.
    common_dtype = torch.promote_types(
        torch.promote_types(a.dtype, b.dtype), c.dtype
    )
    # Contiguous buffers so the kernel can address them with flat offsets.
    prepared = tuple(
        t.to(dtype=common_dtype).contiguous() for t in (a, b, c)
    )
    return (*prepared, common_dtype)

38 

39 

def _launch_addcdiv(a, b, c, out, value):
    """Launch the Triton kernel computing ``out = a + value * b / c``.

    ``a``, ``b``, ``c`` and ``out`` are expected to be contiguous CUDA
    tensors of identical shape and dtype. ``value`` may be a Python number
    or a one-element tensor; it is lowered to a host float before launch.

    Raises:
        ValueError: if a tensor ``value`` is not a scalar, or is a CUDA
            tensor on a different device than the inputs.
    """
    total = out.numel()
    BLOCK_SIZE = 1024

    # Normalize `value` to a plain Python float for the kernel argument.
    if torch.is_tensor(value):
        if value.numel() != 1:
            raise ValueError("value must be a scalar.")
        # A CUDA scalar must share the inputs' device; CPU scalars are fine.
        if value.device.type == "cuda" and value.device != a.device:
            raise ValueError(
                "Scalar tensor 'value' must be on the same device as inputs."
            )
        scalar = float(value.to(dtype=out.dtype).item())
    else:
        scalar = float(value)

    def grid(meta):
        # One program per BLOCK_SIZE chunk, rounded up for the tail.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    addcdiv_kernel[grid](a, b, c, out, total, scalar, BLOCK_SIZE=BLOCK_SIZE)

57 

58 

def addcdiv(self, tensor1, tensor2, *, value=1):
    """Element-wise ``self + value * tensor1 / tensor2`` via a Triton kernel.

    The three inputs are broadcast against each other and promoted to a
    common dtype; a freshly allocated tensor holding the result is returned.
    """
    a, b, c, common_dtype = _prepare_addcdiv_tensors(self, tensor1, tensor2)
    result = torch.empty_like(a, dtype=common_dtype, device=a.device)
    _launch_addcdiv(a, b, c, result, value)
    return result

67 

68 

def addcdiv_out(self, tensor1, tensor2, *, value=1, out=None):
    """Compute ``self + value * tensor1 / tensor2`` element-wise into ``out``.

    ``out`` must be a CUDA tensor on the inputs' device with the promoted
    dtype; it is resized in place to the broadcast shape when necessary.
    Returns ``out``.

    Raises:
        ValueError: if ``out`` is missing or on the wrong device.
        NotImplementedError: if ``out`` is not a CUDA tensor.
        TypeError: if ``out`` has a dtype other than the promoted one.
    """
    if out is None:
        raise ValueError("out tensor must be provided for addcdiv_out.")
    a, b, c, common_dtype = _prepare_addcdiv_tensors(self, tensor1, tensor2)

    # Validate `out` against the prepared operands.
    if not out.is_cuda:
        raise NotImplementedError("out tensor must be a CUDA tensor.")
    if out.device != a.device:
        raise ValueError("out tensor must be on the same device as inputs.")
    if out.dtype != common_dtype:
        raise TypeError(f"out tensor has dtype {out.dtype}, expected {common_dtype}.")
    if out.shape != a.shape:
        out.resize_(a.shape)

    if not out.is_contiguous():
        # The kernel assumes flat contiguous storage: compute into a scratch
        # buffer, then copy back preserving out's layout.
        scratch = torch.empty_like(a, dtype=common_dtype, device=a.device)
        _launch_addcdiv(a, b, c, scratch, value)
        out.copy_(scratch)
    else:
        _launch_addcdiv(a, b, c, out, value)
    return out