Coverage for src/flag_gems/runtime/backend/_ascend/ops/dot.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger, namespaced under the Ascend backend ops package
# (resolves to "flag_gems.runtime._ascend.ops.dot" for this module).
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

13 

14 

@libentry()
@triton.jit
def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Single-program dot product: one program reduces the whole vector.

    Launched with a (1, 1, 1) grid; BLOCK_SIZE must cover N so that a
    single masked load of each input captures every element.  The product
    is reduced in float32 and the scalar result is written to out_ptr.
    """
    pid = tle.program_id(0)
    start = pid * BLOCK_SIZE
    offs = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < N
    # Load both operands upcast to float32; out-of-range lanes contribute 0.
    a = tl.load(x_ptr + offs, mask=in_bounds, other=0.0).to(tl.float32)
    b = tl.load(y_ptr + offs, mask=in_bounds, other=0.0).to(tl.float32)
    # "acc" instead of "sum" — avoids shadowing the Python builtin.
    acc = tl.sum(a * b)
    tl.store(out_ptr, acc)

29 

30 

@libentry()
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Stage 1 of the two-stage dot product.

    The N-element inputs are split into n_tasks = ceil(N / BLOCK_SIZE)
    chunks; the launched programs stride over the chunks round-robin and
    write one float32 partial sum per chunk into mid_ptr[task_id].
    """
    n_workers = tle.num_programs(0)
    pid = tle.program_id(0)

    n_tasks = tl.cdiv(N, BLOCK_SIZE)
    tasks_per_worker = tl.cdiv(n_tasks, n_workers)
    for task_index in range(tasks_per_worker):
        task_id = pid + task_index * n_workers
        # BUGFIX: when n_tasks is not a multiple of n_workers, the last
        # round hands some workers a task_id >= n_tasks.  The load mask
        # already protected x/y, but the unconditional store wrote a zero
        # partial past the end of mid_ptr (out-of-bounds).  Skip those.
        if task_id < n_tasks:
            block_start = task_id * BLOCK_SIZE

            offsets = block_start + tl.arange(0, BLOCK_SIZE)

            mask = offsets < N
            # Reduce in float32 regardless of the input dtype.
            x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
            y = tl.load(y_ptr + offsets, mask=mask, other=0.0).to(tl.float32)

            partial_sum = tl.sum(x * y)
            tl.store(mid_ptr + task_id, partial_sum)

52 

53 

@libentry()
@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr):
    """Stage 2 of the two-stage dot product.

    Reduces the M float32 partial sums produced by dot_kernel_1 into a
    single scalar at out_ptr.  Runs as one program; when M exceeds
    BLOCK_MID it walks the partials in BLOCK_MID-sized chunks.
    """
    n_tasks = tl.cdiv(M, BLOCK_MID)
    # BUGFIX: accumulate across chunks.  The original stored tl.sum of
    # each chunk directly to out_ptr inside the loop, so with M > BLOCK_MID
    # every iteration overwrote the previous one and only the last chunk's
    # sum survived.  Keep a running total and store once at the end.
    total = 0.0
    for task_index in range(n_tasks):
        offset = task_index * BLOCK_MID + tl.arange(0, BLOCK_MID)
        mask = offset < M
        mid_val = tl.load(mid_ptr + offset, mask=mask, other=0.0)
        total += tl.sum(mid_val)
    tl.store(out_ptr, total)

65 

66 

def dot(x, y):
    """Compute the dot product of two 1-D tensors on the Ascend backend.

    Dispatches on size: small vectors (N < 4096) run a single fused
    kernel; larger ones use a two-stage reduction (per-chunk partial
    sums, then a final reduce).  Returns a 0-dim tensor in x's dtype.
    """
    logger.debug("GEMS_ASCEND DOT")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Input must be 1D tensors"

    N = x.shape[0]

    # Only when N is less than TRITON_MAX_TENSOR_NUMEL can it be processed
    # with a single kernel, and performance is better when N < 4096.
    if N < 4096:
        # Single-kernel path: one program, one block spanning all of N.
        block_size = triton.next_power_of_2(N)
        out = torch.empty([], dtype=torch.float32, device=x.device)
        with torch_device_fn.device(x.device):
            dot_kernel[(1, 1, 1)](x, y, out, N, block_size)
        return out.to(x.dtype)

    # Two-stage path: chunk size ~sqrt(N), capped at 12032; the number of
    # partials determines the final reduce's block, capped at 16384.
    block_size = min(triton.next_power_of_2(math.ceil(math.sqrt(N))), 12032)
    mid_size = triton.cdiv(N, block_size)
    block_mid = min(triton.next_power_of_2(mid_size), 16384)

    # At most 240 programs for stage 1; stage 2 is a single program.
    grid_1 = (min(mid_size, 240), 1, 1)
    grid_2 = (1, 1, 1)

    mid = torch.empty((mid_size,), dtype=torch.float32, device=x.device)
    out = torch.empty([], dtype=x.dtype, device=x.device)

    with torch_device_fn.device(x.device):
        dot_kernel_1[grid_1](x, y, mid, N, block_size)
        dot_kernel_2[grid_2](mid, out, mid_size, block_mid)
    return out