Coverage for src/flag_gems/ops/diag_embed.py: 97%
29 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-10 02:30 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
# Per-module logger; naming it after the module lets callers enable the
# debug trace for this op selectively.
logger = logging.getLogger(name=__name__)
# Identity scalar function used as an elementwise copy. `pointwise_dynamic`
# (flag_gems.utils) wraps it into a tensor-level op with one tensor input
# (`is_tensor=[True]`); the call site in this file invokes it as
# `copy_func.instantiate(ndim)(src, out0=dst)` to write `src` into a
# pre-allocated (possibly strided) destination.
# NOTE(review): `promotion_methods=[(0, "DEFAULT")]` presumably selects the
# result dtype from argument 0 under the "DEFAULT" promotion rule -- confirm
# against pointwise_dynamic's documentation.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    # Returning the input unchanged makes the generated kernel a plain copy.
    return x
def diag_embed(x, offset=0, dim1=-2, dim2=-1):
    """Place the last dimension of ``x`` on a diagonal of new 2-D planes.

    FlagGems implementation of ``torch.diag_embed``. The result has rank
    ``x.ndim + 1``. The two new dimensions ``dim1`` and ``dim2`` each have
    size ``x.size(-1) + abs(offset)``. The diagonal selected by ``offset`` in
    every ``(dim1, dim2)`` plane holds the values of ``x``, and every other
    element is zero.

    Args:
        x: Input tensor whose last dimension supplies the diagonal values.
        offset: Diagonal to fill, with the same meaning as in
            ``torch.diagonal``: 0 is the main diagonal, positive values lie
            above it and negative values below it.
        dim1: First dimension of the diagonal plane, indexed against the
            OUTPUT rank (``x.ndim + 1``); negative values count from the end.
        dim2: Second dimension of the diagonal plane, indexed the same way.

    Returns:
        A newly allocated tensor with ``x.dtype`` on ``x.device``.

    Raises:
        AssertionError: If ``dim1`` or ``dim2`` is outside
            ``[-rank, rank)``, or if both name the same dimension.
    """
    logger.debug("GEMS DIAG_EMBED")

    rank = x.ndim + 1

    # Validate with explicit raises instead of ``assert``. Asserts are
    # stripped under ``python -O``, and then an out-of-range dim would be
    # silently wrapped by the modulo below into a wrong result.
    # AssertionError (and its messages) are kept so existing callers that
    # catch it keep working.
    if not -rank <= dim1 < rank:
        raise AssertionError(f"Invalid dim1: {dim1}")
    if not -rank <= dim2 < rank:
        raise AssertionError(f"Invalid dim2: {dim2}")

    # Normalize negative dims to their non-negative equivalents.
    dim1 = dim1 % rank
    dim2 = dim2 % rank

    if dim1 == dim2:
        raise AssertionError("diagonal dimensions cannot be identical")

    # Swapping the two plane dimensions is equivalent to negating the offset.
    # Ordering them (dim1 < dim2) lets the two inserts below land at the
    # requested final positions.
    if dim1 > dim2:
        offset = -offset
        dim1, dim2 = dim2, dim1

    # Each new plane dimension must fit the shifted diagonal.
    last_dim = x.size(-1) + abs(offset)

    # Output shape: x's leading dims, with last_dim inserted at dim1 and at
    # dim2. Inserting at the smaller index first keeps the second index valid.
    y_shape = list(x.shape)
    y_shape.pop()
    y_shape.insert(dim1, last_dim)
    y_shape.insert(dim2, last_dim)

    y = torch.zeros(y_shape, dtype=x.dtype, device=x.device)

    # torch.diagonal returns a view. It drops dim1/dim2 and appends the
    # diagonal, whose length is last_dim - abs(offset) == x.size(-1), so the
    # view has exactly x.shape. Copying into it fills the diagonal of y in
    # place.
    y_diagonal_view = torch.diagonal(y, offset, dim1, dim2)
    copy_func.instantiate(x.ndim)(x, out0=y_diagonal_view)

    return y