Coverage for src/flag_gems/runtime/backend/_ascend/ops/diag.py: 0%
62 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-09 01:57 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as tle
# Module-scoped logger keyed by this op's file name
# (resolves to "flag_gems.runtime._ascend.ops.diag").
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# Scatter a length-N vector onto the `diagonal`-th diagonal of a dense,
# row-major M x M output matrix. Only diagonal cells are written, so the
# caller must pre-zero the output buffer (the host wrapper uses torch.zeros).
#   data_ptr:   input vector; element i lives at data_ptr + i * stride
#   output_ptr: M x M output, assumed contiguous row-major (offset = r * M + c)
#   N:          number of diagonal elements
#   M:          output side length; host passes M = N + |diagonal|
#   diagonal:   constexpr shift; >= 0 is above the main diagonal, < 0 below
@triton.jit
def diag_1d_to_2d_kernel(
    data_ptr, output_ptr, N, M, stride, diagonal: tl.constexpr, BLOCK_SIZE: tl.constexpr
):
    # Each program instance covers BLOCK_SIZE consecutive diagonal positions.
    idx = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Map diagonal position -> (row, col). `diagonal` is constexpr, so only
    # the taken branch is compiled into each kernel specialization.
    if diagonal >= 0:
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        col_idx = idx
        row_idx = col_idx + tl.abs(diagonal)
    # Bounds check on the output; since M = N + |diagonal|, this also
    # implies idx < N, so stores never exceed the vector's length.
    mask = (row_idx < M) & (col_idx < M)
    # Out-of-range lanes load 0 (harmless: their store is masked off anyway).
    diag_value = tl.load(data_ptr + idx * stride, mask=idx < N, other=0)
    # Row-major flat offset into the M x M output.
    out_offset = row_idx * M + col_idx
    tl.store(output_ptr + out_offset, diag_value, mask=mask)
# Gather the `diagonal`-th diagonal of an N x M matrix into a contiguous
# 1-D output vector. The host wrapper sizes the output to the diagonal's
# length, so every in-bounds lane produces exactly one output element.
#   data_ptr:   N x M input; element (r, c) at data_ptr + r*stride0 + c*stride1
#   output_ptr: 1-D output, written contiguously at offset idx
#   N, M:       input row / column counts
#   diagonal:   constexpr shift; >= 0 is above the main diagonal, < 0 below
@triton.jit
def diag_2d_to_1d_kernel(
    data_ptr,
    output_ptr,
    N,
    M,
    stride0,
    stride1,
    diagonal: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance covers BLOCK_SIZE consecutive diagonal positions.
    idx = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Map diagonal position -> (row, col); only the constexpr-selected
    # branch is compiled into each specialization.
    if diagonal >= 0:
        row_idx = idx
        col_idx = row_idx + diagonal
    else:
        col_idx = idx
        row_idx = col_idx + tl.abs(diagonal)
    # One mask bounds both the strided load and the contiguous store.
    mask = (row_idx < N) & (col_idx < M)
    diag_value = tl.load(
        data_ptr + row_idx * stride0 + col_idx * stride1, mask=mask, other=0
    )
    tl.store(output_ptr + idx, diag_value, mask=mask)
def diag_1d_to_2d(x, diagonal=0):
    """Embed the 1-D tensor ``x`` as the ``diagonal``-th diagonal of a
    square zero matrix of side ``len(x) + |diagonal|`` (torch.diag
    semantics for vector input).

    Args:
        x: 1-D tensor (may be non-contiguous; its stride is forwarded).
        diagonal: diagonal offset; >= 0 above the main diagonal, < 0 below.

    Returns:
        A new square tensor on ``x``'s device with ``x`` on the requested
        diagonal and zeros elsewhere.
    """
    n = x.shape[0]
    side = n + abs(diagonal)
    # The kernel writes only the diagonal cells, so start from all zeros.
    result = torch.zeros((side, side), dtype=x.dtype, device=x.device)
    block = 128
    # One program per BLOCK_SIZE diagonal elements.
    grid = (triton.cdiv(n, block),)
    with torch_device_fn.device(x.device):
        diag_1d_to_2d_kernel[grid](
            x, result, n, side, x.stride(0), diagonal, BLOCK_SIZE=block
        )
    return result
def diag_2d_to_1d(x, diagonal=0):
    """Extract the ``diagonal``-th diagonal of the 2-D tensor ``x`` as a
    1-D tensor (torch.diag semantics for matrix input).

    Args:
        x: 2-D tensor (may be non-contiguous; both strides are forwarded).
        diagonal: diagonal offset; >= 0 above the main diagonal, < 0 below.

    Returns:
        A new 1-D tensor on ``x``'s device; empty when the requested
        diagonal lies entirely outside the matrix.
    """
    rows, cols = x.shape
    # Number of elements on the requested diagonal; <= 0 means it falls
    # completely outside the matrix.
    length = min(rows, cols - diagonal) if diagonal >= 0 else min(rows + diagonal, cols)
    if length <= 0:
        return torch.empty(0, dtype=x.dtype, device=x.device)
    result = torch.empty(length, dtype=x.dtype, device=x.device)
    block = 128
    # One program per BLOCK_SIZE diagonal elements.
    grid = (triton.cdiv(length, block),)
    with torch_device_fn.device(x.device):
        diag_2d_to_1d_kernel[grid](
            x, result, rows, cols, x.stride(0), x.stride(1), diagonal, BLOCK_SIZE=block
        )
    return result
def diag(x, diagonal=0):
    """torch.diag entry point: a 1-D input is embedded as a diagonal
    matrix; a 2-D input has the requested diagonal extracted.

    Raises:
        ValueError: if ``x`` is neither 1-D nor 2-D.
    """
    logger.debug("GEMS_ASCEND DIAG")
    ndim = x.dim()
    if ndim == 1:
        return diag_1d_to_2d(x, diagonal)
    if ndim == 2:
        return diag_2d_to_1d(x, diagonal)
    raise ValueError("Input must be a 1D or 2D tensor.")