Coverage for src/flag_gems/runtime/backend/_ascend/ops/diagonal.py: 0%

17 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.codegen_config_utils import CodeGenConfig 

8 

# Module logger, namespaced under the Ascend backend ops and keyed by this
# module's short (last dotted component) name.
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2])

10 

11 

# Code-generation settings passed to the ``pointwise_dynamic`` decorator on
# ``copy_func``. The meanings given for the positional arguments are inferred,
# not verified. NOTE(review): confirm against CodeGenConfig's field order.
config_ = CodeGenConfig(
    2048,  # presumably the maximum tile size
    (48, 1, 1),  # presumably the maximum launch grid (x, y, z)
    32,  # presumably the maximum number of warps per CTA
    False,  # presumably whether block pointers are preferred
    # Parse the whole major version component ("10.1.0" -> 10) rather than
    # only the first character, which would misread a two-digit major
    # version as 1 and wrongly enable the 1-D tile preference.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)

19 

20 

# Identity scalar function. ``pointwise_dynamic`` wraps it into a pointwise
# kernel over a single tensor argument (``is_tensor=[True]``) using the
# module-level ``config_``.
#
# Callers pass the destination tensor via ``out0=``; for example,
# ``diagonal_backward`` passes a diagonal view. The call therefore acts as an
# elementwise copy into that tensor.
#
# The output dtype is derived from argument 0 with the "DEFAULT" promotion
# method (presumably ordinary type promotion; confirm in flag_gems).
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    return x

25 

26 

def diagonal_backward(grad_output, input_sizes, offset, dim1, dim2):
    """Backward of ``torch.diagonal``: scatter ``grad_output`` onto a diagonal.

    Allocates a zero tensor shaped ``input_sizes``, using the dtype and device
    of ``grad_output``. It then writes ``grad_output`` into the diagonal
    selected by ``offset``, ``dim1`` and ``dim2``. Every other entry stays zero.

    Args:
        grad_output: gradient with respect to the diagonal view. Presumably its
            shape matches ``torch.diagonal(input, offset, dim1, dim2)``; this
            is not checked here.
        input_sizes: shape of the original (pre-diagonal) input.
        offset, dim1, dim2: same meaning as the arguments of ``torch.diagonal``.

    Returns:
        The gradient with respect to the original input, shaped ``input_sizes``.
    """
    logger.debug("GEMS_ASCEND DIAGONAL BACKWARD")
    dtype, device = grad_output.dtype, grad_output.device
    result = torch.zeros(input_sizes, dtype=dtype, device=device)
    # ``torch.diagonal`` returns a view, so writing into it fills ``result``.
    diag_view = torch.diagonal(result, offset=offset, dim1=dim1, dim2=dim2)
    # Specialise the pointwise copy kernel for grad_output's rank, then write
    # directly into the strided diagonal view.
    copy_kernel = copy_func.instantiate(grad_output.ndim)
    copy_kernel(grad_output, out0=diag_view)
    return result