Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dot.py: 0%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.libentry import libentry 

11 

# Module logger, nested under the package-wide "flag_gems" logger.
logger = logging.getLogger(f"flag_gems.{__name__.lstrip('.')}")

13 

14 

@libentry()
@triton.jit
def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    # Single-tile dot product: reduces one BLOCK_SIZE-wide tile of x * y and
    # writes the scalar result to out_ptr. Every program writes the same
    # out_ptr, so this is meant for a one-program launch with BLOCK_SIZE >= N.
    idx = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N
    # Masked-off lanes load 0.0 and therefore add nothing to the reduction;
    # products are accumulated in float32 regardless of the input dtype.
    xv = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    yv = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    tl.store(out_ptr, tl.sum(xv * yv))

29 

30 

@libentry()
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    # First pass of the two-stage dot product: each program reduces its own
    # BLOCK_SIZE-wide slice of x * y into one float32 partial sum stored at
    # mid_ptr[program_id].
    tile = tle.program_id(0)
    cols = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = cols < N
    # Out-of-range lanes (last, partially filled tile) load 0.0 so they do
    # not affect the partial sum.
    a = tl.load(x_ptr + cols, mask=valid, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + cols, mask=valid, other=0.0).to(tl.float32)
    tl.store(mid_ptr + tile, tl.sum(a * b))

45 

46 

@libentry()
@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr):
    # Second pass of the two-stage dot product: folds the M partial sums in
    # mid_ptr into a single scalar written to out_ptr. The program id is not
    # used, so this is meant for a one-program launch.
    # Only the first BLOCK_MID entries are read, so BLOCK_MID must be >= M;
    # lanes at or beyond M load 0.0.
    lanes = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid_ptr + lanes, mask=lanes < M, other=0.0)
    tl.store(out_ptr, tl.sum(partials))

56 

57 

def dot(x, y):
    """Compute the dot product of two 1-D tensors with Triton kernels.

    Mirrors ``torch.dot``. Both inputs must be 1-D with identical shapes. The
    result is a 0-dim tensor with dtype ``x.dtype`` on ``x.device``.

    Products are accumulated in float32: the kernels upcast loaded values and
    the intermediate buffers are float32. The result is then cast to
    ``x.dtype``, so float64 inputs are reduced with float32 precision.

    Args:
        x: 1-D tensor.
        y: 1-D tensor with the same shape as ``x``.

    Returns:
        0-dim tensor holding ``sum(x * y)`` with dtype ``x.dtype``.

    Raises:
        AssertionError: if the shapes differ or the inputs are not 1-D.
    """
    logger.debug("Triton Dot Product")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Input must be 1D tensors"

    N = x.shape[0]

    # An empty reduction is 0, as in torch.dot. It must be handled before any
    # launch: triton.next_power_of_2(0) returns 0, and a zero-width
    # tl.arange cannot be built.
    if N == 0:
        return torch.zeros([], dtype=x.dtype, device=x.device)

    # The kernels address memory as `ptr + offset`, which assumes unit stride.
    # A strided 1-D view (e.g. t[::2]) would otherwise be read incorrectly.
    # contiguous() returns the tensor itself when it is already contiguous.
    x = x.contiguous()
    y = y.contiguous()

    # Only when N is less than TRITON_MAX_TENSOR_NUMEL can it be processed with
    # a single kernel, and performance is better with one kernel when N < 4096.
    if N >= 4096:
        # Two-stage reduction. block_size ~ sqrt(N) keeps both the
        # per-program tile (stage 1) and the number of partial sums
        # (stage 2) around sqrt(N) elements.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(N)))

        mid_size = triton.cdiv(N, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        grid_1 = (mid_size, 1, 1)
        grid_2 = (1, 1, 1)

        # One float32 partial sum per stage-1 program. The final result is
        # written straight into `out`, which already has x.dtype.
        mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
        out = torch.empty([], dtype=x.dtype, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel_1[grid_1](x, y, mid, N, block_size)
            dot_kernel_2[grid_2](mid, out, mid_size, block_mid)

    else:
        # Single program whose tile covers the whole vector.
        block_size = triton.next_power_of_2(N)

        grid = (1, 1, 1)

        out = torch.empty([], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel[grid](x, y, out, N, block_size)
            out = out.to(x.dtype)

    return out