Coverage report for src/flag_gems/runtime/backend/_ascend/ops/diag_embed.py — 0% of 31 statements covered. Generated by coverage.py v7.6.9 at 2026-03-18 02:36 +0800.

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.codegen_config_utils import CodeGenConfig 

8 

# Module-scoped logger for this op, rooted under the
# "flag_gems.runtime._ascend.ops" hierarchy and named after this file.
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1])

10 

11 

# Code-generation configuration for the Ascend pointwise kernel.
# NOTE(review): the positional arguments are passed to CodeGenConfig
# unnamed — presumed to be tile size / grid shape / warp count per
# flag_gems.utils.codegen_config_utils; confirm against its signature.
config_ = CodeGenConfig(
    512,
    (48, 1, 1),  # tuple literal instead of tuple([48, 1, 1])
    32,
    False,
    # Older Triton (major version < 3) prefers 1-D tiling.
    # Parse the major version with split(".") rather than __version__[0]:
    # indexing the first character silently mis-parses any future
    # major version >= 10 (e.g. "10.0" -> 1).
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)

19 

20 

# Identity kernel: copies each input element unchanged. pointwise_dynamic
# generates the launch/indexing code; the caller supplies the destination
# tensor via out0=, so this effectively performs out0[i] = x[i].
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    # Element-wise identity; output dtype follows the "DEFAULT" promotion rule.
    return x

25 

26 

def diag_embed(x, offset=0, dim1=-2, dim2=-1):
    """Embed the last dimension of ``x`` as diagonals of a new tensor.

    The output has one more dimension than ``x``: ``dim1`` and ``dim2``
    select the pair of output dimensions forming the 2-D planes that hold
    the diagonals, and ``offset`` selects which diagonal within each plane.
    All other entries of the output are zero (torch.diag_embed semantics).
    """
    logger.debug("GEMS_ASCEND DIAG_EMBED")

    out_rank = x.ndim + 1

    assert -out_rank <= dim1 < out_rank, f"Invalid dim1: {dim1}"
    assert -out_rank <= dim2 < out_rank, f"Invalid dim2: {dim2}"
    # Normalize negative dims to their non-negative equivalents.
    dim1, dim2 = dim1 % out_rank, dim2 % out_rank

    assert dim1 != dim2, "diagonal dimensions cannot be identical"

    # Per the docs, exchanging the two dims is equivalent to negating the
    # offset, so canonicalize to dim1 < dim2.
    if dim2 < dim1:
        dim1, dim2, offset = dim2, dim1, -offset

    # Both new dimensions get the diagonal length: the size of x's last
    # dimension plus the |offset| shift.
    diag_len = x.size(-1) + abs(offset)

    out_shape = list(x.shape[:-1])
    out_shape.insert(dim1, diag_len)
    out_shape.insert(dim2, diag_len)

    out = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
    # Writing x through the diagonal view scatters it into place; the
    # remaining entries stay zero.
    diag_view = torch.diagonal(out, offset, dim1, dim2)
    copy_func.instantiate(x.ndim)(x, out0=diag_view)
    return out