Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addr.py: 0%

44 statements  

coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


@libentry()
@triton.jit(do_not_specialize=["beta", "alpha"])
def addr_kernel(
    input_ptr,
    vec1_ptr,
    vec2_ptr,
    output_ptr,
    beta,
    alpha,
    M,
    N,
    stride_input_m,
    stride_input_n,
    stride_vec1,
    stride_vec2,
    stride_output_m,
    stride_output_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of
    # the output: out[i, j] = beta * input[i, j] + alpha * vec1[i] * vec2[j].
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    vec1_ptrs = vec1_ptr + offs_m * stride_vec1
    vec2_ptrs = vec2_ptr + offs_n * stride_vec2

    mask_m = offs_m < M
    mask_n = offs_n < N

    # Accumulate in float32 for accuracy regardless of the input dtype.
    vec1 = tl.load(vec1_ptrs, mask=mask_m, other=0.0).to(tl.float32)
    vec2 = tl.load(vec2_ptrs, mask=mask_n, other=0.0).to(tl.float32)

    input_ptrs = (
        input_ptr + offs_m[:, None] * stride_input_m + offs_n[None, :] * stride_input_n
    )

    mask_2d = mask_m[:, None] & mask_n[None, :]
    input_val = tl.load(input_ptrs, mask=mask_2d, other=0.0).to(tl.float32)

    # Rank-1 update: broadcast the two vectors into an outer product.
    result = beta * input_val + alpha * (vec1[:, None] * vec2[None, :])

    output_ptrs = (
        output_ptr
        + offs_m[:, None] * stride_output_m
        + offs_n[None, :] * stride_output_n
    )
    tl.store(output_ptrs, result, mask=mask_2d)


def addr(input, vec1, vec2, *, beta=1, alpha=1):
    logger.debug("GEMS ADDR")
    if vec1.dim() != 1 or vec2.dim() != 1:
        raise ValueError("addr: expected 1-D vectors")

    M = vec1.shape[0]
    N = vec2.shape[0]
    output_shape = (M, N)

    try:
        # broadcast_to returns a (possibly zero-strided) view; the kernel
        # reads it through explicit strides, so no copy is needed.
        input_broadcasted = torch.broadcast_to(input, output_shape)
    except RuntimeError as e:
        raise ValueError(
            f"addr: input tensor of shape {input.shape} cannot be broadcast "
            f"to output shape {output_shape}"
        ) from e
    out = torch.empty(output_shape, device=input.device, dtype=input.dtype)

    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 32
    # One program per output tile.
    grid = lambda META: (
        triton.cdiv(M, BLOCK_SIZE_M),
        triton.cdiv(N, BLOCK_SIZE_N),
    )
    with torch_device_fn.device(input.device):
        addr_kernel[grid](
            input_broadcasted,
            vec1,
            vec2,
            out,
            beta,
            alpha,
            M,
            N,
            input_broadcasted.stride(0),
            input_broadcasted.stride(1),
            vec1.stride(0),
            vec2.stride(0),
            out.stride(0),
            out.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )
    return out
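
For reference, a minimal usage sketch (not part of the covered file): it assumes the module is importable from the path in the report header and that a flag_gems-supported device is available; the "cuda" device string is only illustrative and should be replaced by whatever device this backend targets. It checks the op against torch.addr, which computes the same beta/alpha-weighted outer-product update.

import torch

from flag_gems.runtime.backend._kunlunxin.ops.addr import addr  # path from the report header

device = "cuda"  # assumption: substitute the device this backend targets
inp = torch.randn(3, 4, device=device)
v1 = torch.randn(3, device=device)
v2 = torch.randn(4, device=device)

out = addr(inp, v1, v2, beta=0.5, alpha=2.0)
ref = torch.addr(inp, v1, v2, beta=0.5, alpha=2.0)
torch.testing.assert_close(out, ref)

Note the kernel accumulates in float32 and casts on store, so for float16/bfloat16 inputs small deviations from a reference computed entirely in the input dtype are expected; assert_close's per-dtype default tolerances normally absorb that.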