Coverage for src/flag_gems/ops/diag_embed.py: 97%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

# Per-module logger; diag_embed below emits a debug record through it.
logger = logging.getLogger(name=__name__)

9 

10 

# Identity scalar function. `pointwise_dynamic` (project helper) wraps it into
# a callable; diag_embed invokes it as
# `copy_func.instantiate(rank)(x, out0=view)`, which presumably copies x
# elementwise into the (possibly strided) out0 view.
# promotion_methods=[(0, "DEFAULT")]: output dtype follows argument 0 under
# pointwise_dynamic's DEFAULT promotion rule (NOTE(review): confirm semantics
# in flag_gems.utils.pointwise_dynamic).
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(src):
    # Output element == input element.
    return src

15 

16 

def diag_embed(x, offset=0, dim1=-2, dim2=-1):
    """Embed the last dimension of ``x`` as diagonals of a new zero tensor.

    FlagGems implementation of ``torch.diag_embed``. The last dimension of
    ``x`` is written onto the ``offset``-th diagonal of the 2-D planes
    spanned by output dimensions ``dim1`` and ``dim2``. Every other element
    of the result is zero.

    Args:
        x: Input tensor. Its last dimension supplies the diagonal values.
        offset: Which diagonal to fill, relative to ``(dim1, dim2)``:
            0 is the main diagonal, > 0 above it, < 0 below it.
        dim1: First output dimension of the diagonal plane (may be negative).
        dim2: Second output dimension of the diagonal plane (may be negative).

    Returns:
        A new tensor of rank ``x.ndim + 1`` with the dtype and device of
        ``x``. Its size along ``dim1`` and ``dim2`` is
        ``x.size(-1) + abs(offset)``. Its remaining dimensions are
        ``x.shape[:-1]``, in order.

    Raises:
        AssertionError: if ``dim1`` or ``dim2`` lies outside
            ``[-rank, rank)`` with ``rank = x.ndim + 1``, or if both refer to
            the same output dimension.
    """
    logger.debug("GEMS DIAG_EMBED")

    rank = x.ndim + 1

    # Use explicit raises rather than ``assert``. ``python -O`` strips
    # asserts, and the modulo below would then silently wrap an out-of-range
    # dim into a valid one. AssertionError and the messages are kept so
    # existing callers see the same exception.
    if not -rank <= dim1 < rank:
        raise AssertionError(f"Invalid dim1: {dim1}")
    if not -rank <= dim2 < rank:
        raise AssertionError(f"Invalid dim2: {dim2}")

    # Normalize negative dims to their non-negative equivalents.
    dim1 %= rank
    dim2 %= rank

    if dim1 == dim2:
        raise AssertionError("diagonal dimensions cannot be identical")

    # Exchanging the two diagonal dims is equivalent to negating the offset:
    # torch.diagonal(y, k, a, b) selects the same elements, in the same
    # order, as torch.diagonal(y, -k, b, a). Ordering dim1 < dim2 keeps the
    # shape construction below simple.
    if dim1 > dim2:
        offset = -offset
        dim1, dim2 = dim2, dim1

    # Each diagonal plane must be large enough that its offset-th diagonal
    # holds exactly x.size(-1) elements.
    diag_size = x.size(-1) + abs(offset)

    # Output shape is x.shape[:-1] with diag_size inserted at dim1 and dim2.
    # Inserting at the smaller index first leaves dim2 pointing at the
    # correct final slot.
    y_shape = list(x.shape[:-1])
    y_shape.insert(dim1, diag_size)
    y_shape.insert(dim2, diag_size)

    y = torch.zeros(y_shape, dtype=x.dtype, device=x.device)

    if x.numel() == 0:
        # Nothing to copy and y is already all zeros, so skip the kernel
        # launch on an empty view.
        return y

    # torch.diagonal returns a view of y. Its leading dims are y's non-plane
    # dims in order, followed by the diagonal (length x.size(-1)), so its
    # shape matches x. Writing into the view fills y's diagonals in place.
    y_diagonal_view = torch.diagonal(y, offset, dim1, dim2)
    copy_func.instantiate(x.ndim)(x, out0=y_diagonal_view)

    return y