Coverage for src/flag_gems/runtime/backend/_ascend/ops/triu.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module-level logger, namespaced under the Ascend backend ops package
# (e.g. "flag_gems.runtime._ascend.ops.triu").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

13 

14 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,
    Y,
    M,
    N,
    diagonal,
    M_BLOCK_SIZE: tl.constexpr,
    N_BLOCK_SIZE: tl.constexpr,
):
    """Upper-triangular copy of a row-major (M, N) matrix.

    Each program owns a band of M_BLOCK_SIZE rows and sweeps the columns in
    N_BLOCK_SIZE-wide tiles, storing the input where ``col >= row + diagonal``
    and zero elsewhere (the ``torch.triu`` contract for 2D inputs).
    """
    band_id = tle.program_id(0)
    rows = band_id * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE)[:, None]
    row_mask = rows < M
    # Move both pointers to the first element of each row in this band.
    X += rows * N
    Y += rows * N

    # Column offsets within a tile are loop-invariant; hoist them.
    col_lane = tl.arange(0, N_BLOCK_SIZE)[None, :]
    for col_start in range(0, N, N_BLOCK_SIZE):
        cols = col_start + col_lane
        col_mask = cols < N
        tile_mask = row_mask and col_mask

        x = tl.load(X + cols, mask=tile_mask, other=0.0)
        # Keep on/above the shifted diagonal, zero below it.
        y = tl.where(rows + diagonal <= cols, x, 0.0)
        tl.store(Y + cols, y, mask=tile_mask)

41 

42 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("triu_batch"),
    key=["batch", "MN", "N", "diagonal"],
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,
    Y,
    batch,
    MN,
    N,
    diagonal,
    BATCH_BLOCK_SIZE: tl.constexpr,
    MN_BLOCK_SIZE: tl.constexpr,
):
    """Batched upper-triangular copy over a (batch, MN) flattened input.

    X/Y are treated as ``batch`` contiguous rows of ``MN = M * N`` elements;
    a flat column index ``c`` maps back to in-matrix coordinates
    ``(c // N, c % N)``. Axis 0 of the launch grid may hold fewer programs
    than there are batch blocks (the host shrinks it to respect a grid-size
    limit), so each program loops over multiple batch blocks, striding by
    the actual number of axis-0 programs.
    """
    batch_id = tle.program_id(0)
    mn_id = tle.program_id(1)
    # Actual axis-0 grid size after any host-side shrinking.
    batch_workers = tle.num_programs(0)

    # Number of BATCH_BLOCK_SIZE-row blocks needed to cover all batches.
    total_batch_workloads = tl.cdiv(batch, BATCH_BLOCK_SIZE)
    # Per-program iteration count, rounded up to a power of two.
    # NOTE(review): the bound used is cdiv(batch, total_batch_workloads)
    # (roughly BATCH_BLOCK_SIZE), not cdiv(total_batch_workloads,
    # batch_workers); presumably sufficient for the grids emitted by the
    # host-side `grid` function — confirm against the launch logic in
    # `triu`. Out-of-range iterations are masked off below, so extra
    # iterations are redundant but harmless.
    batch_workloads = 1
    while batch_workloads < tl.cdiv(batch, total_batch_workloads):
        batch_workloads *= 2

    for w in range(batch_workloads):
        # Stride over batch blocks: program p handles blocks p, p + P, ...
        # where P is the number of axis-0 programs.
        batch_work_id = batch_id + w * batch_workers
        row = batch_work_id * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
        batch_mask = row < batch
        # Base pointers for the rows handled in this iteration.
        NX = X + row * MN
        NY = Y + row * MN

        cols = mn_id * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
        mn_mask = cols < MN
        mask = batch_mask and mn_mask
        x = tl.load(NX + cols, mask, other=0.0)
        # Recover 2D coordinates inside each (M, N) matrix.
        m = cols // N
        n = cols % N
        # Keep on/above the shifted diagonal, zero below it.
        y = tl.where(m + diagonal <= n, x, 0.0)
        tl.store(NY + cols, y, mask=mask)

83 

84 

# Largest value representable by a signed 32-bit integer (2**31 - 1).
# NOTE(review): not referenced anywhere in this file; possibly kept for
# importers or a planned index-overflow guard — confirm before removing.
INT32_MAX = torch.iinfo(torch.int32).max

86 

87 

def triu(A, diagonal=0):
    """Return the upper-triangular part of ``A`` (``torch.triu`` semantics).

    Applied to the last two dimensions; any leading dimensions are treated
    as a flat batch.

    Args:
        A: tensor with at least 2 dimensions (asserted).
        diagonal: diagonal offset to keep, as in ``torch.triu``.

    Returns:
        A new contiguous tensor of the same shape with elements below the
        chosen diagonal zeroed.
    """
    logger.debug("GEMS_ASCEND TRIU")
    # Validate before doing any work or allocation.
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    A = A.contiguous()
    out = torch.empty_like(A)
    M, N = A.shape[-2:]
    with torch_device_fn.device(A.device):
        if len(A.shape) == 2:
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](A, out, M, N, diagonal)
        else:
            # Collapse all leading dimensions into a single batch axis.
            # Integer floor-division: `int(numel / M / N)` went through
            # floats and could round wrong for very large tensors.
            batch = A.numel() // (M * N)
            B = A.view(batch, -1)

            def grid(meta):
                axis0 = triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"])
                axis1 = triton.cdiv(M * N, meta["MN_BLOCK_SIZE"])
                # Stay under the 65536-program grid limit by halving the
                # batch axis; the kernel loops to re-cover dropped blocks.
                # The `axis0 > 1` guard is essential: without it axis0 can
                # reach 0 (e.g. axis0 == 1 with axis1 >= 65536), launching
                # zero programs and leaving `out` uninitialized.
                while axis0 > 1 and axis0 * axis1 >= 65536:
                    axis0 = axis0 // 2
                return (
                    axis0,
                    axis1,
                )

            triu_batch_kernel[grid](
                B,
                out,
                batch,
                M * N,
                N,
                diagonal,
            )
            out = out.view(A.shape)
    return out