Coverage for src/flag_gems/ops/diag.py: 68%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as tle 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def diag_1d_to_2d_kernel(
    data_ptr, output_ptr, N, M, stride, diagonal: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """Scatter a length-N vector onto the `diagonal`-th diagonal of an M x M matrix.

    Each program instance handles one BLOCK_SIZE chunk of the source vector;
    out-of-range lanes are masked off on both load and store.
    """
    offs = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # `diagonal` is a constexpr, so only one branch survives kernel
    # specialization: positive offsets shift columns, negative shift rows.
    if diagonal >= 0:
        rows = offs
        cols = offs + diagonal
    else:
        cols = offs
        rows = offs - diagonal

    # Read the source vector (lanes past N are masked, filled with 0) ...
    vals = tl.load(data_ptr + offs * stride, mask=offs < N, other=0)

    # ... then write each value to its (row, col) cell of the output matrix.
    in_bounds = (rows < M) & (cols < M)
    tl.store(output_ptr + rows * M + cols, vals, mask=in_bounds)

32 

33 

@triton.jit
def diag_2d_to_1d_kernel(
    data_ptr,
    output_ptr,
    N,
    M,
    stride0,
    stride1,
    diagonal: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather the `diagonal`-th diagonal of an N x M matrix into a 1-D output.

    Each program instance copies one BLOCK_SIZE chunk of diagonal elements;
    the same mask guards both the strided load and the contiguous store.
    """
    offs = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Compile-time branch on the constexpr offset: positive diagonals shift
    # the column index, negative ones shift the row index.
    if diagonal >= 0:
        rows = offs
        cols = offs + diagonal
    else:
        cols = offs
        rows = offs - diagonal

    in_bounds = (rows < N) & (cols < M)
    vals = tl.load(data_ptr + rows * stride0 + cols * stride1, mask=in_bounds, other=0)
    tl.store(output_ptr + offs, vals, mask=in_bounds)

59 

60 

def diag_1d_to_2d(x, diagonal=0):
    """Embed the 1-D tensor ``x`` on the ``diagonal``-th diagonal of a square matrix.

    Mirrors ``torch.diag`` for 1-D input: the result is an
    ``(N + |diagonal|) x (N + |diagonal|)`` tensor with ``x`` on the requested
    diagonal and zeros everywhere else.

    Args:
        x: 1-D tensor of length N (any stride).
        diagonal: which diagonal to fill; 0 = main, >0 above, <0 below.

    Returns:
        A new square tensor on ``x``'s device with ``x``'s dtype.
    """
    N = x.shape[0]
    M = N + abs(diagonal)
    # Zero-filled output: the kernel writes only the diagonal entries.
    output = torch.zeros((M, M), dtype=x.dtype, device=x.device)

    if N == 0:
        # Empty input: nothing to scatter; skip the zero-size kernel launch.
        return output

    stride = x.stride(0)
    BLOCK_SIZE = 128

    # Fix: read the block size from the launch-time meta dict instead of
    # closing over the local constant, so the grid stays correct if
    # BLOCK_SIZE is ever supplied by an autotuner.
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        diag_1d_to_2d_kernel[grid](
            x, output, N, M, stride, diagonal, BLOCK_SIZE=BLOCK_SIZE
        )
    return output

76 

77 

def diag_2d_to_1d(x, diagonal=0):
    """Extract the ``diagonal``-th diagonal of the 2-D tensor ``x``.

    Mirrors ``torch.diag`` for 2-D input: returns a 1-D tensor holding the
    requested diagonal, or an empty tensor when the diagonal lies entirely
    outside the matrix.

    Args:
        x: 2-D tensor of shape (N, M) (any strides).
        diagonal: which diagonal to read; 0 = main, >0 above, <0 below.

    Returns:
        A new 1-D tensor on ``x``'s device with ``x``'s dtype.
    """
    N, M = x.shape
    # Diagonal length matches torch.diag semantics for either sign of offset.
    if diagonal >= 0:
        diag_len = min(N, M - diagonal)
    else:
        diag_len = min(N + diagonal, M)
    if diag_len <= 0:
        # Offset beyond the matrix extent: the diagonal is empty.
        return torch.empty(0, dtype=x.dtype, device=x.device)
    output = torch.empty(diag_len, dtype=x.dtype, device=x.device)
    stride0 = x.stride(0)
    stride1 = x.stride(1)
    BLOCK_SIZE = 128

    # Fix: read the block size from the launch-time meta dict instead of
    # closing over the local constant, so the grid stays correct if
    # BLOCK_SIZE is ever supplied by an autotuner.
    grid = lambda meta: (triton.cdiv(diag_len, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        diag_2d_to_1d_kernel[grid](
            x, output, N, M, stride0, stride1, diagonal, BLOCK_SIZE=BLOCK_SIZE
        )
    return output

98 

99 

def diag(x, diagonal=0):
    """``torch.diag``-style entry point, dispatched on input rank.

    A 1-D input is embedded on the given diagonal of a square matrix;
    a 2-D input has the given diagonal extracted into a 1-D tensor.

    Raises:
        ValueError: if ``x`` is not 1-D or 2-D.
    """
    logger.debug("GEMS DIAG")
    ndim = x.dim()
    if ndim not in (1, 2):
        raise ValueError("Input must be a 1D or 2D tensor.")
    handler = diag_1d_to_2d if ndim == 1 else diag_2d_to_1d
    return handler(x, diagonal)