Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/trace.py: 0%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-16 02:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
# Child logger under the package-wide "flag_gems" logger so handlers/levels
# configured on "flag_gems" apply here; lstrip(".") presumably strips a
# leading dot from __name__ (relative-import residue) — TODO confirm.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def trace_kernel(
    inp_ptr,    # pointer to the 2D input tensor
    out_ptr,    # pointer to a scalar output
    num_diag,   # number of main-diagonal elements, min(M, N); caller guarantees > 0
    stride0,    # input stride along dim 0 (elements)
    stride1,    # input stride along dim 1 (elements)
    BLOCK_SIZE: tl.constexpr,  # lanes per accumulation step, power of two
):
    """Sum the main diagonal of a 2D tensor into a single scalar.

    Launched by `trace` below with a grid of (1,), so one program walks the
    whole diagonal in BLOCK_SIZE-sized chunks and writes one value.
    """
    # Pick an accumulator wider than (or equal to) the input to limit
    # overflow / rounding: integers accumulate in int64, float64 stays
    # float64, and every other dtype (e.g. fp16/bf16/fp32) accumulates
    # in float32. `other_val` is the neutral fill for masked-off loads.
    inp_dtype = inp_ptr.type.element_ty
    if inp_dtype.is_int():
        acc_dtype = tl.int64
        other_val = 0
    elif inp_dtype == tl.float64:
        acc_dtype = tl.float64
        other_val = 0.0
    else:
        acc_dtype = tl.float32
        other_val = 0.0

    # Per-lane partial sums; reduced to a scalar after the loop.
    acc = tl.zeros((BLOCK_SIZE,), dtype=acc_dtype)

    # Diagonal element i lives at offset i * (stride0 + stride1).
    diag_stride = stride0 + stride1

    for i in range(0, tl.cdiv(num_diag, BLOCK_SIZE)):
        block_start = i * BLOCK_SIZE
        current_indices = block_start + tl.arange(0, BLOCK_SIZE)

        # Mask off lanes past the end of the diagonal in the last chunk.
        mask = current_indices < num_diag

        ptr_offsets = current_indices * diag_stride
        current_ptrs = inp_ptr + ptr_offsets

        vals = tl.load(current_ptrs, mask=mask, other=other_val)

        acc += vals.to(acc_dtype)

    # Collapse the per-lane partials and narrow to the output dtype.
    final_sum = tl.sum(acc, axis=0)
    tl.store(out_ptr, final_sum.to(out_ptr.type.element_ty))
def trace(self):
    """Return the sum of the main diagonal of a 2D tensor as a 0-dim tensor.

    Floating-point inputs keep their dtype; integer/bool inputs produce an
    int64 result. Raises RuntimeError for non-2D inputs.
    """
    logger.debug("GEMS TRACE")

    if self.ndim != 2:
        raise RuntimeError(
            f"trace: expected a 2D tensor, but got a {self.ndim}D tensor"
        )

    rows, cols = self.shape
    row_stride, col_stride = self.stride()
    diag_len = min(rows, cols)

    # Non-floating inputs (int/bool) accumulate and return as int64.
    is_fp = self.dtype.is_floating_point
    result_dtype = self.dtype if is_fp else torch.int64

    # Empty diagonal: nothing to launch — return a zero scalar directly.
    if diag_len == 0:
        zero = 0.0 if is_fp else 0
        return torch.tensor(zero, dtype=result_dtype, device=self.device)

    result = torch.empty((), dtype=result_dtype, device=self.device)

    # Shrink the block to the next power of two when the diagonal is short
    # (never below 1 lane).
    block = 1024
    if diag_len < block:
        block = max(triton.next_power_of_2(diag_len), 1)

    # Single-program launch: the kernel walks the whole diagonal itself.
    with torch_device_fn.device(self.device):
        trace_kernel[(1,)](
            self,
            result,
            diag_len,
            row_stride,
            col_stride,
            BLOCK_SIZE=block,
        )

    return result