Coverage for src/flag_gems/runtime/backend/_ascend/ops/diag_embed.py: 0%
31 statements
coverage.py v7.6.9, created at 2026-03-25 02:48 +0800
import logging

import torch
import triton

from flag_gems.utils import pointwise_dynamic
from flag_gems.utils.codegen_config_utils import CodeGenConfig
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

# Code-generation settings shared by the ``pointwise_dynamic`` kernels in this
# module. The first four arguments are passed positionally, as before.
# NOTE(review): their meaning (512, (48, 1, 1), 32, False) is defined by
# CodeGenConfig in flag_gems.utils.codegen_config_utils; confirm there before
# tuning.
config_ = CodeGenConfig(
    512,
    (48, 1, 1),
    32,
    False,
    # Parse the full major version component. Indexing ``__version__[0]``
    # would read only the first character, so "10.x" would be treated as 1.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)
# Element-wise identity op. ``pointwise_dynamic`` wraps this scalar Triton
# function into a copy kernel over one tensor input (``is_tensor=[True]``),
# using the codegen settings in ``config_``. ``diag_embed`` below calls it with
# ``out0=`` to write its input into a strided diagonal view.
# NOTE(review): ``promotion_methods=[(0, "DEFAULT")]`` presumably derives the
# result dtype from argument 0 — confirm against pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    # Return the element unchanged; the generated kernel does the load/store.
    return x
def diag_embed(x, offset=0, dim1=-2, dim2=-1):
    """Embed the last dimension of ``x`` as a diagonal of a new zero tensor.

    Follows the signature of ``torch.diag_embed``. The output has one more
    dimension than ``x``. The values along the last dimension of ``x`` are
    written onto the ``offset``-th diagonal of the 2-D planes spanned by
    ``dim1`` and ``dim2``. All other entries are zero.

    Args:
        x: input tensor; its last dimension supplies the diagonal values.
        offset: diagonal to fill (0 = main, > 0 above, < 0 below).
        dim1: first dimension of the diagonal plane in the output; may be
            negative.
        dim2: second dimension of the diagonal plane in the output; may be
            negative.

    Returns:
        A new tensor with ``x.ndim + 1`` dimensions and the dtype and device
        of ``x``. Each of ``dim1`` and ``dim2`` has size
        ``x.size(-1) + abs(offset)``.

    Raises:
        AssertionError: if ``dim1`` or ``dim2`` lies outside
            ``[-rank, rank)`` with ``rank = x.ndim + 1``, or if both refer to
            the same dimension.
    """
    logger.debug("GEMS_ASCEND DIAG_EMBED")

    rank = x.ndim + 1

    # Explicit raises rather than ``assert``: asserts are stripped under
    # ``python -O``, and an out-of-range dim would then be silently wrapped by
    # the ``%`` below. AssertionError and the messages are kept so callers see
    # the same exception as before.
    if not -rank <= dim1 < rank:
        raise AssertionError(f"Invalid dim1: {dim1}")
    if not -rank <= dim2 < rank:
        raise AssertionError(f"Invalid dim2: {dim2}")

    # convert from negative dims
    dim1 = dim1 % rank
    dim2 = dim2 % rank

    if dim1 == dim2:
        raise AssertionError("diagonal dimensions cannot be identical")

    # As per the torch docs, exchanging dims is equivalent to changing the sign
    # of offset. Normalizing to dim1 < dim2 lets the inserts below index the
    # final layout directly.
    if dim1 > dim2:
        offset = -offset
        dim1, dim2 = dim2, dim1

    # As per the docs, both new dims have size last + |offset| so the shifted
    # diagonal fits.
    last_dim = x.size(-1) + abs(offset)

    y_shape = list(x.shape)
    y_shape.pop()
    y_shape.insert(dim1, last_dim)
    # dim2 > dim1, so this insert lands at dim2 of the final shape.
    y_shape.insert(dim2, last_dim)

    y = torch.zeros(y_shape, dtype=x.dtype, device=x.device)
    if x.numel() == 0:
        # Nothing to copy; y is already the correctly shaped all-zero result.
        return y

    # torch.diagonal returns a view that drops dim1/dim2 and appends the
    # diagonal as the last dim, so its shape equals x.shape. Copying x into
    # the view fills the diagonal of y in place.
    y_diagonal_view = torch.diagonal(y, offset, dim1, dim2)
    copy_func.instantiate(x.ndim)(x, out0=y_diagonal_view)

    return y