Coverage for src/flag_gems/experimental_ops/trace.py: 0%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-11 02:28 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def trace_kernel(
    x_ptr,
    stride0,
    stride1,
    diag_len,
    out_ptr,
    OUT_TYPE: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Sum the main diagonal of a strided 2-D tensor and store it at ``out_ptr``.

    Designed for a single-program launch: the one program walks the diagonal
    in chunks of ``BLOCK`` elements and writes exactly one scalar.

    Args:
        x_ptr: pointer to element (0, 0) of the input matrix.
        stride0: element stride of the input's first dimension.
        stride1: element stride of the input's second dimension.
        diag_len: number of diagonal elements to sum.
        out_ptr: pointer to the single output element.
        OUT_TYPE: integer dtype code of the output. The result is converted
            to ``out_ptr``'s own element type, which this code is expected to
            describe; the parameter is kept so existing launches still work.
        BLOCK: number of diagonal elements handled per loop iteration.
    """
    # Choose the accumulator from the input element type (resolved at compile
    # time): float64 keeps full double precision, other floating types
    # (float16/bfloat16/float32/...) accumulate in float32, and integer/bool
    # inputs accumulate exactly in int64 instead of rounding through float32.
    if x_ptr.dtype.element_ty == tl.float64:
        acc = tl.zeros([BLOCK], dtype=tl.float64)
    elif x_ptr.dtype.element_ty.is_floating():
        acc = tl.zeros([BLOCK], dtype=tl.float32)
    else:
        acc = tl.zeros([BLOCK], dtype=tl.int64)

    # Diagonal entries (i, i) and (i + 1, i + 1) are stride0 + stride1 apart.
    diag_step = stride0 + stride1
    for start in range(0, diag_len, BLOCK):
        offs = start + tl.arange(0, BLOCK)
        mask = offs < diag_len
        # Widen to int64 before multiplying: for large matrices the element
        # offset offs * diag_step no longer fits in int32.
        vals = tl.load(x_ptr + offs.to(tl.int64) * diag_step, mask=mask, other=0)
        acc += vals.to(acc.dtype)

    total = tl.sum(acc, axis=0)
    # Convert straight to the output pointer's element type; this is the dtype
    # OUT_TYPE encodes, so no per-code dispatch is needed.
    tl.store(out_ptr, total.to(out_ptr.dtype.element_ty))
56def _dtype_to_code(dtype: torch.dtype) -> int:
57 mapping = {
58 torch.float16: 0,
59 torch.bfloat16: 1,
60 torch.float32: 2,
61 torch.float64: 3,
62 torch.int32: 4,
63 torch.int64: 5,
64 torch.int16: 6,
65 torch.int8: 7,
66 torch.uint8: 8,
67 }
68 if dtype not in mapping:
69 raise ValueError(f"Unsupported dtype for trace kernel: {dtype}")
70 return mapping[dtype]
73def _launch_trace_kernel(input: torch.Tensor, out: torch.Tensor):
74 if not input.is_cuda or not out.is_cuda:
75 raise ValueError("trace kernel requires CUDA tensors")
76 if input.dim() != 2:
77 raise ValueError(f"trace expects a 2D tensor, got {input.dim()}D")
78 if out.numel() != 1:
79 raise ValueError("out tensor must have a single element (0-dim/scalar)")
81 n0, n1 = input.shape
82 diag_len = min(n0, n1)
83 s0, s1 = input.stride()
85 out_code = _dtype_to_code(out.dtype)
87 grid = lambda meta: (1,)
88 trace_kernel[grid](
89 input,
90 s0,
91 s1,
92 diag_len,
93 out,
94 OUT_TYPE=out_code,
95 BLOCK=1024,
96 )
99def trace(input: torch.Tensor):
100 out = torch.empty((), device=input.device, dtype=input.dtype)
101 _launch_trace_kernel(input, out)
102 return out
def trace_out(input: torch.Tensor, out: torch.Tensor):
    """Out-variant of ``trace``: write the diagonal sum of ``input`` into ``out``.

    ``out`` must be a single-element CUDA tensor; the sum is converted to
    ``out.dtype``. Validation errors raised by ``_launch_trace_kernel``
    propagate unchanged.

    Returns:
        The same ``out`` tensor that was passed in.
    """
    destination = out
    _launch_trace_kernel(input, destination)
    return destination