Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/diag.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems.runtime import torch_device_fn 

6from flag_gems.utils import triton_lang_extension as tle 

7 

8 

@triton.jit
def diag_1d_to_2d_kernel(
    data_ptr, output_ptr, N, M, stride, diagonal: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """Scatter a length-N vector onto the `diagonal`-th diagonal of an MxM output.

    Each program instance handles one BLOCK_SIZE-wide chunk of the diagonal.
    `stride` is the element stride of the 1-D input; the output is assumed
    contiguous row-major (offset = row * M + col).
    """
    offsets = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Map the linear diagonal index to (row, col); the sign of `diagonal`
    # (a compile-time constant) decides which axis receives the shift.
    if diagonal >= 0:
        rows = offsets
        cols = offsets + diagonal
    else:
        cols = offsets
        rows = offsets - diagonal

    in_bounds = (rows < M) & (cols < M)

    values = tl.load(data_ptr + offsets * stride, mask=offsets < N, other=0)

    tl.store(output_ptr + rows * M + cols, values, mask=in_bounds)

28 

29 

@triton.jit
def diag_2d_to_1d_kernel(
    data_ptr,
    output_ptr,
    N,
    M,
    stride0,
    stride1,
    diagonal: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather the `diagonal`-th diagonal of an NxM matrix into a 1-D output.

    `stride0`/`stride1` are the row/column element strides of the input, so
    non-contiguous layouts are supported; the output is written contiguously.
    """
    offsets = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Translate the linear diagonal position into matrix coordinates.
    if diagonal >= 0:
        rows = offsets
        cols = offsets + diagonal
    else:
        cols = offsets
        rows = offsets - diagonal

    in_bounds = (rows < N) & (cols < M)

    values = tl.load(
        data_ptr + rows * stride0 + cols * stride1, mask=in_bounds, other=0
    )
    tl.store(output_ptr + offsets, values, mask=in_bounds)

55 

56 

def diag_1d_to_2d(x, diagonal=0):
    """Embed a 1-D tensor as the `diagonal`-th diagonal of a square matrix.

    Mirrors torch.diag for 1-D input: returns an (N + |diagonal|)-square
    tensor that is zero everywhere except on the requested diagonal.

    Args:
        x: 1-D tensor of length N (any stride).
        diagonal: which diagonal to fill; positive is above the main
            diagonal, negative below.

    Returns:
        A new (M, M) tensor on x's device, M = N + |diagonal|.
    """
    N = x.shape[0]
    M = N + abs(diagonal)
    output = torch.zeros((M, M), dtype=x.dtype, device=x.device)
    # Empty input: nothing to scatter — and the launch math below would
    # collapse BLOCK_SIZE to 0 and divide by zero in the grid computation.
    if N == 0:
        return output

    stride = x.stride(0)
    # Split the diagonal into ~12 chunks, rounded up to a power of two,
    # so the grid stays small for short diagonals.
    BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(N, 12))

    grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)

    with torch_device_fn.device(x.device):
        diag_1d_to_2d_kernel[grid](
            x, output, N, M, stride, diagonal, BLOCK_SIZE=BLOCK_SIZE
        )
    return output

73 

74 

def diag_2d_to_1d(x, diagonal=0):
    """Extract the `diagonal`-th diagonal of a 2-D tensor as a 1-D tensor.

    Mirrors torch.diag for 2-D input; a diagonal that lies entirely outside
    the matrix yields an empty tensor.
    """
    rows, cols = x.shape

    # Length of the requested diagonal; <= 0 means it is out of range.
    if diagonal >= 0:
        diag_len = min(rows, cols - diagonal)
    else:
        diag_len = min(rows + diagonal, cols)
    if diag_len <= 0:
        return torch.empty(0, dtype=x.dtype, device=x.device)

    output = torch.empty(diag_len, dtype=x.dtype, device=x.device)
    BLOCK_SIZE = 128

    grid = lambda meta: (triton.cdiv(diag_len, BLOCK_SIZE),)

    with torch_device_fn.device(x.device):
        diag_2d_to_1d_kernel[grid](
            x,
            output,
            rows,
            cols,
            x.stride(0),
            x.stride(1),
            diagonal,
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return output

95 

96 

def diag(x, diagonal=0):
    """torch.diag semantics: 1-D input builds a matrix, 2-D extracts a diagonal.

    Raises ValueError for any other dimensionality.
    """
    ndim = x.dim()
    if ndim == 1:
        return diag_1d_to_2d(x, diagonal)
    if ndim == 2:
        return diag_2d_to_1d(x, diagonal)
    raise ValueError("Input must be a 1D or 2D tensor.")