Coverage for src/flag_gems/runtime/backend/_ascend/ops/diagonal.py: 0%
17 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.codegen_config_utils import CodeGenConfig
# Module logger, namespaced under ``flag_gems.runtime._ascend.ops`` and suffixed
# with the last dotted component of this module's ``__name__``.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
# Code-generation settings passed to ``pointwise_dynamic`` for kernels in this file.
# NOTE(review): the positional order presumably is (max_tile_size, max_grid_size,
# max_num_warps_per_cta, prefer_block_pointer) - confirm against CodeGenConfig in
# flag_gems.utils.codegen_config_utils before relying on these names.
config_ = CodeGenConfig(
    2048,
    (48, 1, 1),
    32,
    False,
    # Parse the full major component: indexing ``__version__[0]`` would read
    # only the first character and misreport a two-digit major version.
    prefer_1d_tile=int(triton.__version__.split(".", 1)[0]) < 3,
)
# Identity element-wise function. ``pointwise_dynamic`` wraps it so that a call
# such as ``copy_func(x, out0=y)`` writes ``x`` into the provided output tensor
# ``y`` (used below to fill a strided diagonal view). The output dtype follows
# promotion rule "DEFAULT" applied to argument 0 - presumably the dtype of ``x``;
# confirm against pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")], config=config_)
@triton.jit
def copy_func(x):
    return x
def diagonal_backward(grad_output, input_sizes, offset, dim1, dim2):
    """Compute the gradient of ``torch.diagonal`` with respect to its input.

    Allocates a zero tensor of shape ``input_sizes`` with the dtype and device
    of ``grad_output``, then copies ``grad_output`` onto the diagonal selected
    by ``offset``, ``dim1`` and ``dim2``. Every element off that diagonal keeps
    a zero gradient.

    Args:
        grad_output: Gradient with respect to the diagonal output. Presumably
            it has the shape of ``torch.diagonal(grad_input, offset, dim1, dim2)``;
            TODO confirm whether callers rely on broadcasting.
        input_sizes: Shape of the original input to ``torch.diagonal``.
        offset: Diagonal offset, as passed to ``torch.diagonal``.
        dim1: First dimension the diagonal is taken over.
        dim2: Second dimension the diagonal is taken over.

    Returns:
        A tensor of shape ``input_sizes`` holding ``grad_output`` on the
        selected diagonal and zeros elsewhere.
    """
    logger.debug("GEMS_ASCEND DIAGONAL BACKWARD")
    grad_input = torch.zeros(
        input_sizes, dtype=grad_output.dtype, device=grad_output.device
    )
    # torch.diagonal returns a view, so writing into ``diag`` fills grad_input in place.
    diag = torch.diagonal(grad_input, offset, dim1, dim2)
    # instantiate(ndim) selects the generated copy kernel for this tensor rank
    # (presumably a rank-specialised build; see pointwise_dynamic).
    copy_func.instantiate(grad_output.ndim)(grad_output, out0=diag)
    return grad_input