Coverage for src/flag_gems/runtime/backend/_cambricon/ops/diag.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems.runtime import torch_device_fn
@triton.jit
def diag_1d_to_2d_kernel(
    data_ptr, output_ptr, N, M, stride, diagonal: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    """Scatter a 1D tensor onto one diagonal of an (M, M) matrix.

    Each program handles ``BLOCK_SIZE`` consecutive input elements. Only the
    diagonal positions are written; every other output element is left
    untouched, so the caller must pre-fill the output (``diag_1d_to_2d``
    allocates it with ``torch.zeros``).

    Args:
        data_ptr: Pointer to the 1D input; element ``i`` is read from
            ``data_ptr + i * stride``.
        output_ptr: Pointer to the output, addressed as a row-major (M, M)
            buffer (``row * M + col``), i.e. assumed contiguous.
        N: Number of input elements.
        M: Side length of the square output.
        stride: Element stride of the input.
        diagonal: Diagonal offset; ``>= 0`` selects the main/upper diagonals,
            ``< 0`` the lower ones. It is a compile-time constant, so every
            distinct value compiles a separate kernel.
        BLOCK_SIZE: Number of elements handled per program.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # `diagonal` is constexpr, so this branch is resolved at compile time.
    if diagonal >= 0:
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        col_idx = idx
        row_idx = col_idx - diagonal

    # Lanes past the end of the input (tail of the last block) or outside the
    # output matrix must not store anything.
    mask = (idx < N) & (row_idx < M) & (col_idx < M)

    # Widen to int64 before multiplying: `idx * stride` and `row_idx * M`
    # overflow int32 for large inputs (the output offset already does so once
    # M > 46340).
    diag_value = tl.load(
        data_ptr + idx.to(tl.int64) * stride, mask=idx < N, other=0
    )

    out_offset = row_idx.to(tl.int64) * M + col_idx
    tl.store(output_ptr + out_offset, diag_value, mask=mask)
@triton.jit
def diag_2d_to_1d_kernel(
    data_ptr,
    output_ptr,
    N,
    M,
    stride0,
    stride1,
    diagonal: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather one diagonal of an (N, M) matrix into a 1D output.

    Each program handles ``BLOCK_SIZE`` consecutive diagonal positions.

    Args:
        data_ptr: Pointer to the input; element (r, c) is read from
            ``data_ptr + r * stride0 + c * stride1``.
        output_ptr: Pointer to the 1D output; position ``i`` is written to
            ``output_ptr + i``.
        N: Number of input rows.
        M: Number of input columns.
        stride0: Element stride of the row dimension.
        stride1: Element stride of the column dimension.
        diagonal: Diagonal offset; ``>= 0`` selects the main/upper diagonals,
            ``< 0`` the lower ones. It is a compile-time constant, so every
            distinct value compiles a separate kernel.
        BLOCK_SIZE: Number of diagonal positions handled per program.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # `diagonal` is constexpr, so this branch is resolved at compile time.
    if diagonal >= 0:
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        col_idx = idx
        row_idx = col_idx - diagonal
    # Positions past the end of the diagonal fall outside the matrix.
    mask = (row_idx < N) & (col_idx < M)

    # Compute the strided input offset in int64: `row_idx * stride0` exceeds
    # the int32 range for inputs with more than 2**31 elements.
    in_offset = row_idx.to(tl.int64) * stride0 + col_idx.to(tl.int64) * stride1
    diag_value = tl.load(data_ptr + in_offset, mask=mask, other=0)
    tl.store(output_ptr + idx, diag_value, mask=mask)
def diag_1d_to_2d(x, diagonal=0):
    """Build a square matrix with ``x`` placed on the ``diagonal``-th diagonal.

    The result has side ``x.shape[0] + abs(diagonal)`` and is zero everywhere
    except on the selected diagonal (``torch.diag`` semantics for 1D input).

    Args:
        x: 1D tensor holding the diagonal values.
        diagonal: Offset of the target diagonal; positive is above the main
            diagonal, negative below.

    Returns:
        A new ``(M, M)`` tensor with ``x``'s dtype and device.
    """
    N = x.shape[0]
    M = N + abs(diagonal)
    output = torch.zeros((M, M), dtype=x.dtype, device=x.device)
    if N == 0:
        # Nothing to scatter: skip launching a zero-sized grid, consistent
        # with the empty-diagonal early return in diag_2d_to_1d.
        return output

    BLOCK_SIZE = 128
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    with torch_device_fn.device(x.device):
        diag_1d_to_2d_kernel[grid](
            x, output, N, M, x.stride(0), diagonal, BLOCK_SIZE=BLOCK_SIZE
        )
    return output
def diag_2d_to_1d(x, diagonal=0):
    """Extract the ``diagonal``-th diagonal of a 2D tensor as a 1D tensor.

    Args:
        x: 2D input tensor of shape (rows, cols).
        diagonal: Offset of the diagonal; positive is above the main
            diagonal, negative below.

    Returns:
        A new 1D tensor with ``x``'s dtype and device. It is empty when the
        requested diagonal lies entirely outside the matrix.
    """
    n_rows, n_cols = x.shape
    # An upper diagonal loses `diagonal` columns; a lower one loses rows.
    length = (
        min(n_rows, n_cols - diagonal)
        if diagonal >= 0
        else min(n_rows + diagonal, n_cols)
    )
    if length <= 0:
        return torch.empty(0, dtype=x.dtype, device=x.device)

    result = torch.empty(length, dtype=x.dtype, device=x.device)
    block = 128

    def grid(meta):
        return (triton.cdiv(length, block),)

    with torch_device_fn.device(x.device):
        diag_2d_to_1d_kernel[grid](
            x,
            result,
            n_rows,
            n_cols,
            x.stride(0),
            x.stride(1),
            diagonal,
            BLOCK_SIZE=block,
        )
    return result
def diag(x, diagonal=0):
    """Apply ``torch.diag``-style behavior based on the input's rank.

    A 1D input is expanded into a square matrix carrying ``x`` on the chosen
    diagonal. A 2D input has that diagonal extracted into a 1D tensor.

    Args:
        x: 1D or 2D tensor.
        diagonal: Diagonal offset (positive above, negative below the main one).

    Raises:
        ValueError: If ``x`` is neither 1D nor 2D.
    """
    rank = x.dim()
    if rank not in (1, 2):
        raise ValueError("Input must be a 1D or 2D tensor.")
    if rank == 1:
        return diag_1d_to_2d(x, diagonal)
    return diag_2d_to_1d(x, diagonal)