Coverage for src/flag_gems/runtime/backend/_cambricon/ops/diag.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems.runtime import torch_device_fn 

6 

7 

@triton.jit
def diag_1d_to_2d_kernel(
    data_ptr,
    output_ptr,
    N,
    M,
    stride,
    diagonal: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write a strided 1-D vector onto one diagonal of an M x M output.

    Each program instance handles ``BLOCK_SIZE`` consecutive vector elements.
    Element ``i`` is read from ``data_ptr + i * stride`` and stored at
    ``(i, i + diagonal)`` when ``diagonal >= 0`` or ``(i - diagonal, i)``
    otherwise, with the output addressed as ``row * M + col``. Only diagonal
    entries are written; every other output element is left untouched.

    Args:
        data_ptr: Pointer to the first element of the input vector.
        output_ptr: Pointer to the first element of the output matrix.
        N: Number of elements in the input vector.
        M: Side length of the square output matrix.
        stride: Element stride between consecutive input values.
        diagonal: Compile-time diagonal offset (sign selects the branch).
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Widen to int64 before any multiplication: `idx * stride` and
    # `row_idx * M` wrap around in 32-bit arithmetic once the flat offset
    # exceeds 2**31 - 1, which silently addresses the wrong memory.
    idx = (tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)).to(tl.int64)

    # `diagonal` is constexpr, so only one branch is compiled per offset sign.
    if diagonal >= 0:
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        col_idx = idx
        row_idx = col_idx - diagonal

    # Lanes past the end of the diagonal must not touch the output.
    mask = (row_idx < M) & (col_idx < M)

    diag_value = tl.load(data_ptr + idx * stride, mask=idx < N, other=0)

    out_offset = row_idx * M + col_idx
    tl.store(output_ptr + out_offset, diag_value, mask=mask)

27 

28 

@triton.jit
def diag_2d_to_1d_kernel(
    data_ptr,
    output_ptr,
    N,
    M,
    stride0,
    stride1,
    diagonal: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather one diagonal of a strided N x M matrix into a dense vector.

    Each program instance handles ``BLOCK_SIZE`` consecutive diagonal
    positions. Position ``i`` is read from ``(i, i + diagonal)`` when
    ``diagonal >= 0`` or ``(i - diagonal, i)`` otherwise, and stored at
    ``output_ptr + i``.

    Args:
        data_ptr: Pointer to the first element of the input matrix.
        output_ptr: Pointer to the first element of the output vector.
        N: Number of rows of the input matrix.
        M: Number of columns of the input matrix.
        stride0: Element stride between consecutive rows of the input.
        stride1: Element stride between consecutive columns of the input.
        diagonal: Compile-time diagonal offset (sign selects the branch).
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Widen to int64 before any multiplication: `row_idx * stride0` wraps
    # around in 32-bit arithmetic for inputs with more than 2**31 - 1
    # addressable elements, which silently reads the wrong memory.
    idx = (tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)).to(tl.int64)

    # `diagonal` is constexpr, so only one branch is compiled per offset sign.
    if diagonal >= 0:
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        col_idx = idx
        row_idx = col_idx - diagonal

    # Lanes whose (row, col) falls outside the matrix neither load nor store.
    mask = (row_idx < N) & (col_idx < M)

    diag_value = tl.load(
        data_ptr + row_idx * stride0 + col_idx * stride1, mask=mask, other=0
    )
    tl.store(output_ptr + idx, diag_value, mask=mask)

54 

55 

def diag_1d_to_2d(x, diagonal=0):
    """Build a square matrix whose `diagonal`-th diagonal holds the values of x.

    Args:
        x: 1-D tensor supplying the diagonal values.
        diagonal: Diagonal offset; 0 is the main diagonal, positive values
            select diagonals above it and negative values diagonals below it.

    Returns:
        A zero-filled ``(M, M)`` tensor with ``M = len(x) + abs(diagonal)``,
        of the same dtype and device as ``x``, with ``x`` written onto the
        requested diagonal.
    """
    N = x.shape[0]
    M = N + abs(diagonal)
    output = torch.zeros((M, M), dtype=x.dtype, device=x.device)

    # Nothing to scatter for an empty vector: the zero matrix is already the
    # answer, and returning here avoids launching a kernel on an empty grid.
    if N == 0:
        return output

    stride = x.stride(0)
    BLOCK_SIZE = 128

    def grid(meta):
        # One program per BLOCK_SIZE input elements.
        return (triton.cdiv(N, BLOCK_SIZE),)

    with torch_device_fn.device(x.device):
        diag_1d_to_2d_kernel[grid](
            x, output, N, M, stride, diagonal, BLOCK_SIZE=BLOCK_SIZE
        )
    return output

71 

72 

def diag_2d_to_1d(x, diagonal=0):
    """Extract the `diagonal`-th diagonal of a 2-D tensor as a 1-D tensor.

    Args:
        x: 2-D tensor to read from; arbitrary strides are passed through to
            the kernel.
        diagonal: Diagonal offset; 0 is the main diagonal, positive values
            select diagonals above it and negative values diagonals below it.

    Returns:
        A new 1-D tensor with the same dtype and device as ``x``. It is empty
        when the requested diagonal lies entirely outside the matrix.
    """
    rows, cols = x.shape

    # A non-negative offset shortens the usable columns; a negative offset
    # shortens the usable rows.
    diag_len = (
        min(rows, cols - diagonal) if diagonal >= 0 else min(rows + diagonal, cols)
    )
    if diag_len <= 0:
        return torch.empty(0, dtype=x.dtype, device=x.device)

    output = torch.empty(diag_len, dtype=x.dtype, device=x.device)
    row_stride = x.stride(0)
    col_stride = x.stride(1)
    BLOCK_SIZE = 128

    def grid(meta):
        # One program per BLOCK_SIZE diagonal elements.
        return (triton.cdiv(diag_len, BLOCK_SIZE),)

    with torch_device_fn.device(x.device):
        diag_2d_to_1d_kernel[grid](
            x,
            output,
            rows,
            cols,
            row_stride,
            col_stride,
            diagonal,
            BLOCK_SIZE=BLOCK_SIZE,
        )
    return output

93 

94 

def diag(x, diagonal=0):
    """Dispatch to the matching diagonal routine based on the rank of x.

    Args:
        x: 1-D tensor (expanded into a square matrix) or 2-D tensor (from
            which a diagonal is extracted).
        diagonal: Diagonal offset forwarded unchanged to the chosen routine.

    Returns:
        The result of ``diag_1d_to_2d`` for 1-D input or ``diag_2d_to_1d``
        for 2-D input.

    Raises:
        ValueError: If ``x`` is neither 1-D nor 2-D.
    """
    ndim = x.dim()
    if ndim == 1:
        return diag_1d_to_2d(x, diagonal)
    if ndim == 2:
        return diag_2d_to_1d(x, diagonal)
    raise ValueError("Input must be a 1D or 2D tensor.")