Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/diag.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems.runtime import torch_device_fn
6from flag_gems.utils import triton_lang_extension as tle
@triton.jit
def diag_1d_to_2d_kernel(
    data_ptr, output_ptr, N, M, stride, diagonal: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    # Scatter a length-N vector onto the `diagonal`-th diagonal of an M x M
    # output matrix (offset computed as row * M + col, i.e. the output is
    # assumed contiguous row-major). Off-diagonal cells are never written, so
    # the caller must pre-zero the output (diag_1d_to_2d uses torch.zeros).
    #
    #   data_ptr   : input vector, N elements, element stride `stride`
    #   output_ptr : M x M output, M = N + |diagonal| (set up by the caller)
    #   diagonal   : >= 0 selects the main/super-diagonal, < 0 a sub-diagonal

    # idx enumerates positions along the diagonal covered by this program.
    idx = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if diagonal >= 0:
        # Super/main diagonal: row i maps to column i + diagonal.
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        # Sub-diagonal: column j maps to row j - diagonal (= j + |diagonal|).
        col_idx = idx
        row_idx = col_idx - diagonal
    # With M = N + |diagonal|, both coordinates in range implies idx < N;
    # the load below still masks on idx < N independently.
    mask = (row_idx < M) & (col_idx < M)
    diag_value = tl.load(data_ptr + idx * stride, mask=idx < N, other=0)
    out_offset = row_idx * M + col_idx
    tl.store(output_ptr + out_offset, diag_value, mask=mask)
@triton.jit
def diag_2d_to_1d_kernel(
    data_ptr,
    output_ptr,
    N,
    M,
    stride0,
    stride1,
    diagonal: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Gather the `diagonal`-th diagonal of an N x M matrix into a contiguous
    # 1-D output (output_ptr + idx, so the output is assumed contiguous).
    #
    #   data_ptr : N x M input, with explicit row/column strides stride0/stride1
    #              (the input need not be contiguous)
    #   diagonal : >= 0 selects the main/super-diagonal, < 0 a sub-diagonal

    # idx enumerates positions along the diagonal covered by this program.
    idx = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if diagonal >= 0:
        # Super/main diagonal: row i maps to column i + diagonal.
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        # Sub-diagonal: column j maps to row j - diagonal (= j + |diagonal|).
        col_idx = idx
        row_idx = col_idx - diagonal
    # In-bounds source coordinates correspond exactly to idx < diag_len
    # (diag_len = min(N, M - diagonal) or min(N + diagonal, M), computed by
    # the caller), so the same mask is safe for the store.
    mask = (row_idx < N) & (col_idx < M)
    diag_value = tl.load(
        data_ptr + row_idx * stride0 + col_idx * stride1, mask=mask, other=0
    )
    tl.store(output_ptr + idx, diag_value, mask=mask)
def diag_1d_to_2d(x, diagonal=0):
    """Embed a 1-D tensor `x` as the `diagonal`-th diagonal of a square matrix.

    Returns a zero-initialized (N + |diagonal|) x (N + |diagonal|) tensor
    whose selected diagonal holds the elements of `x` (torch.diag semantics
    for 1-D input).
    """
    n = x.shape[0]
    side = n + abs(diagonal)
    out = torch.zeros((side, side), dtype=x.dtype, device=x.device)
    # Spread the diagonal over roughly 12 programs, rounded up to a
    # power-of-two block size.
    block_size = triton.next_power_of_2(triton.cdiv(n, 12))
    grid = (triton.cdiv(n, block_size),)
    with torch_device_fn.device(x.device):
        diag_1d_to_2d_kernel[grid](
            x, out, n, side, x.stride(0), diagonal, BLOCK_SIZE=block_size
        )
    return out
def diag_2d_to_1d(x, diagonal=0):
    """Extract the `diagonal`-th diagonal of a 2-D tensor as a 1-D tensor
    (torch.diag semantics for 2-D input)."""
    rows, cols = x.shape
    # Number of elements on the requested diagonal; non-positive means the
    # diagonal lies entirely outside the matrix.
    if diagonal >= 0:
        length = min(rows, cols - diagonal)
    else:
        length = min(rows + diagonal, cols)
    if length <= 0:
        return torch.empty(0, dtype=x.dtype, device=x.device)
    out = torch.empty(length, dtype=x.dtype, device=x.device)
    block_size = 128
    grid = (triton.cdiv(length, block_size),)
    with torch_device_fn.device(x.device):
        diag_2d_to_1d_kernel[grid](
            x,
            out,
            rows,
            cols,
            x.stride(0),
            x.stride(1),
            diagonal,
            BLOCK_SIZE=block_size,
        )
    return out
def diag(x, diagonal=0):
    """torch.diag analogue: a 1-D input yields a diagonal matrix, a 2-D
    input yields the extracted diagonal vector.

    Raises ValueError for any other dimensionality.
    """
    ndim = x.dim()
    if ndim == 1:
        return diag_1d_to_2d(x, diagonal)
    if ndim == 2:
        return diag_2d_to_1d(x, diagonal)
    raise ValueError("Input must be a 1D or 2D tensor.")