Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/triu.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import builtins 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8# from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Child logger under the package-wide "flag_gems" logger.  lstrip(".") strips
# any leading dots from __name__ — presumably guarding against a relative-style
# module name so the logger hierarchy stays well-formed; TODO confirm.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

14 

15 

def heur_m_block_size(args):
    """Row-tile size heuristic: split M rows evenly across the 12 clusters,
    rounded up to the next power of two (Triton block sizes are powers of 2)."""
    per_cluster = triton.cdiv(args["M"], 12)  # 12 == cluster_num on Kunlunxin
    return triton.next_power_of_2(per_cluster)

18 

19 

def heur_n_block_size(args):
    """Column-tile size heuristic: the full row width, capped at 8192."""
    n = args["N"]
    return n if n < 8192 else 8192

22 

23 

@libentry()
# @triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.heuristics(
    values={
        "M_BLOCK_SIZE": heur_m_block_size,
        "N_BLOCK_SIZE": heur_n_block_size,
    },
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,  # in-pointer: contiguous (M, N) matrix
    Y,  # out-pointer: same shape and layout as X
    M,  # number of rows
    N,  # number of columns (also the row stride, since X is contiguous)
    diagonal,  # offset of the kept diagonal, as in torch.triu
    M_BLOCK_SIZE: tl.constexpr,
    N_BLOCK_SIZE: tl.constexpr,
):
    """Upper-triangular copy of one 2-D matrix:
    Y[i, j] = X[i, j] when i + diagonal <= j, else 0.

    Each program owns M_BLOCK_SIZE consecutive rows and sweeps the columns
    in N_BLOCK_SIZE-wide strips.
    """
    pid = tle.program_id(0)
    # Column vector of absolute row indices handled by this program.
    row = pid * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE)[:, None]
    m_mask = row < M
    # Advance base pointers to the first element of each owned row.
    X += row * N
    Y += row * N

    for n_offset in range(0, N, N_BLOCK_SIZE):
        cols = n_offset + tl.arange(0, N_BLOCK_SIZE)[None, :]
        n_mask = cols < N
        # Triton compiles `and` to a non-short-circuit logical-and on tensors;
        # broadcasting (M_BLOCK, 1) with (1, N_BLOCK) yields the 2-D tile mask.
        mask = m_mask and n_mask

        x = tl.load(X + cols, mask, other=0.0)
        # Keep the element when it lies on or above the shifted diagonal.
        y = tl.where(row + diagonal <= cols, x, 0.0)
        tl.store(Y + cols, y, mask=mask)

56 

57 

def heur_batch_block_size(args):
    """Batch-tile size heuristic: split the batch evenly across the 12
    clusters, rounded up to the next power of two."""
    per_cluster = triton.cdiv(args["batch"], 12)  # 12 == cluster_num
    return triton.next_power_of_2(per_cluster)

60 

61 

def heur_mn_block_size(args):
    """Flattened-matrix tile size heuristic: MN elements, capped at 8192."""
    mn = args["MN"]
    return mn if mn < 8192 else 8192

64 

65 

@libentry()
# @triton.autotune(
#     configs=runtime.get_tuned_config("triu_batch"),
#     key=["batch", "MN", "N", "diagonal"],
# )
@triton.heuristics(
    {
        "BATCH_BLOCK_SIZE": heur_batch_block_size,
        "MN_BLOCK_SIZE": heur_mn_block_size,
    }
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,  # in-pointer: (batch, M*N) flattened batch of contiguous matrices
    Y,  # out-pointer: same layout as X
    batch,  # number of matrices in the batch
    MN,  # elements per matrix (M * N)
    N,  # columns per matrix, needed to recover 2-D coordinates
    diagonal,  # offset of the kept diagonal, as in torch.triu
    BATCH_BLOCK_SIZE: tl.constexpr,
    MN_BLOCK_SIZE: tl.constexpr,
):
    """Batched upper-triangular copy over flattened matrices.

    Grid axis 0 tiles the batch dimension, axis 1 tiles the flattened
    matrix; the (row, col) coordinates inside each matrix are recovered
    from the flat offset via div/mod by N.
    """
    batch_id = tle.program_id(0)
    mn_id = tle.program_id(1)
    # Column vector of batch indices handled by this program.
    row = batch_id * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
    batch_mask = row < batch
    # Advance base pointers to the start of each owned matrix.
    X += row * MN
    Y += row * MN

    cols = mn_id * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
    mn_mask = cols < MN
    # Triton compiles `and` to a non-short-circuit logical-and on tensors.
    mask = batch_mask and mn_mask
    x = tl.load(X + cols, mask, other=0.0)
    # Recover in-matrix (row, column) from the flat offset.
    m = cols // N
    n = cols % N
    # Keep the element when it lies on or above the shifted diagonal.
    y = tl.where(m + diagonal <= n, x, 0.0)
    tl.store(Y + cols, y, mask=mask)

103 

104 

# Largest value representable in int32.  Unused in the code visible here —
# presumably a 32-bit indexing threshold for callers/importers; TODO confirm.
INT32_MAX = torch.iinfo(torch.int32).max

106 

107 

def triu(A, diagonal=0):
    """Return the upper-triangular part of *A* (``torch.triu`` equivalent).

    Elements below the ``diagonal``-th diagonal of the last two dimensions
    are zeroed; leading dimensions are treated as a batch.

    Args:
        A (torch.Tensor): input tensor with at least 2 dimensions.
        diagonal (int): diagonal offset, as in ``torch.triu``. Default 0.

    Returns:
        torch.Tensor: a new contiguous tensor with the same shape as ``A``.

    Raises:
        AssertionError: if ``A`` has fewer than 2 dimensions.
    """
    logger.debug("GEMS TRIU")
    # Validate before doing any work — the original asserted only after
    # contiguous() and empty_like() had already run/allocated.
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    A = A.contiguous()  # kernels assume a dense row-major layout
    out = torch.empty_like(A)
    M, N = A.shape[-2:]
    with torch_device_fn.device(A.device):
        if len(A.shape) == 2:
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](A, out, M, N, diagonal)
        else:
            # Exact integer arithmetic; int(numel / M / N) goes through float
            # division and can be off for very large tensors.
            batch = A.numel() // (M * N)
            B = A.view(batch, -1)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                B,
                out,
                batch,
                M * N,
                N,
                diagonal,
            )
            out = out.view(A.shape)
    return out