Coverage for src/flag_gems/experimental_ops/trace.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-28 12:23 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def trace_kernel(
    x_ptr,
    stride0,
    stride1,
    diag_len,
    out_ptr,
    OUT_TYPE: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Sum the main diagonal of a 2D tensor and store the scalar at ``out_ptr``.

    The launcher (`_launch_trace_kernel`) starts this kernel with a grid of
    one program, so a single program walks the whole diagonal in chunks of
    ``BLOCK`` elements.

    Args:
        x_ptr: pointer to the first element of the 2D input tensor.
        stride0: element stride between consecutive rows.
        stride1: element stride between consecutive columns.
        diag_len: number of diagonal elements, i.e. min(rows, cols).
        out_ptr: pointer to a single-element output buffer.
        OUT_TYPE: compile-time integer code selecting the output dtype cast
            (see `_dtype_to_code` for the code/dtype correspondence).
        BLOCK: compile-time chunk size processed per loop iteration.
    """
    # Accumulate in float32 for numerical stability across input dtypes
    # NOTE(review): for float64 or large-magnitude int64 inputs a float32
    # accumulator can lose precision — confirm this is acceptable upstream.
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    # Advancing one step along the diagonal moves one row AND one column.
    sdiag = stride0 + stride1

    i = 0
    while i < diag_len:
        idx = i + tl.arange(0, BLOCK)
        mask = idx < diag_len  # guard the final partial chunk
        ptrs = x_ptr + idx * sdiag
        # Masked-off lanes load 0 and therefore do not affect the sum.
        vals = tl.load(ptrs, mask=mask, other=0)
        acc += tl.cast(vals, tl.float32)
        i += BLOCK

    # Reduce the BLOCK per-lane partial sums to a single scalar.
    total = tl.sum(acc, axis=0)

    # Cast to desired output dtype based on OUT_TYPE code
    if OUT_TYPE == 0:
        val = tl.cast(total, tl.float16)
    elif OUT_TYPE == 1:
        val = tl.cast(total, tl.bfloat16)
    elif OUT_TYPE == 2:
        val = tl.cast(total, tl.float32)
    elif OUT_TYPE == 3:
        val = tl.cast(total, tl.float64)
    elif OUT_TYPE == 4:
        val = tl.cast(total, tl.int32)
    elif OUT_TYPE == 5:
        val = tl.cast(total, tl.int64)
    elif OUT_TYPE == 6:
        val = tl.cast(total, tl.int16)
    elif OUT_TYPE == 7:
        val = tl.cast(total, tl.int8)
    elif OUT_TYPE == 8:
        val = tl.cast(total, tl.uint8)
    else:
        # Unknown codes fall back to float32 rather than failing to compile.
        val = tl.cast(total, tl.float32)

    tl.store(out_ptr, val)

54 

55 

56def _dtype_to_code(dtype: torch.dtype) -> int: 

57 mapping = { 

58 torch.float16: 0, 

59 torch.bfloat16: 1, 

60 torch.float32: 2, 

61 torch.float64: 3, 

62 torch.int32: 4, 

63 torch.int64: 5, 

64 torch.int16: 6, 

65 torch.int8: 7, 

66 torch.uint8: 8, 

67 } 

68 if dtype not in mapping: 

69 raise ValueError(f"Unsupported dtype for trace kernel: {dtype}") 

70 return mapping[dtype] 

71 

72 

def _launch_trace_kernel(input: torch.Tensor, out: torch.Tensor):
    """Validate arguments and launch ``trace_kernel`` on a 2D CUDA tensor.

    The diagonal sum of ``input`` is written into the single-element tensor
    ``out``, cast to ``out``'s dtype inside the kernel.

    Raises:
        ValueError: if either tensor is not on CUDA, ``input`` is not 2D, or
            ``out`` does not hold exactly one element.
    """
    # Guard clauses: both tensors on the GPU, 2D input, scalar output.
    if not (input.is_cuda and out.is_cuda):
        raise ValueError("trace kernel requires CUDA tensors")
    if input.dim() != 2:
        raise ValueError(f"trace expects a 2D tensor, got {input.dim()}D")
    if out.numel() != 1:
        raise ValueError("out tensor must have a single element (0-dim/scalar)")

    rows, cols = input.shape
    stride_row, stride_col = input.stride()

    # Single-program launch: one program walks the whole diagonal.
    trace_kernel[(1,)](
        input,
        stride_row,
        stride_col,
        min(rows, cols),
        out,
        OUT_TYPE=_dtype_to_code(out.dtype),
        BLOCK=1024,
    )

97 

98 

def trace(input: torch.Tensor):
    """Return the sum of the main diagonal of a 2D CUDA tensor.

    The result is a freshly allocated 0-dim tensor with the same dtype and
    device as ``input``.
    """
    result = torch.empty((), dtype=input.dtype, device=input.device)
    _launch_trace_kernel(input, result)
    return result

103 

104 

def trace_out(input: torch.Tensor, out: torch.Tensor):
    """Compute the diagonal sum of ``input`` into the caller-provided ``out``.

    Returns ``out`` for call-chaining convenience.
    """
    _launch_trace_kernel(input, out)
    return out