Coverage for src/flag_gems/ops/trace.py: 58%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-22 16:54 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@libentry()
@triton.jit
def trace_kernel(
    inp_ptr,  # pointer to the input 2D tensor's first element
    out_ptr,  # pointer to a 0-dim output tensor that receives the trace
    num_diag,  # number of main-diagonal elements: min(M, N)
    stride0,  # element stride of the input along dim 0
    stride1,  # element stride of the input along dim 1
    BLOCK_SIZE: tl.constexpr,  # diagonal elements processed per loop iteration
):
    """Sum the main-diagonal elements of a 2D tensor into a scalar.

    A single program instance walks the diagonal in BLOCK_SIZE-sized chunks,
    accumulating in a widened dtype to limit overflow/rounding error:
    integer inputs accumulate in int64, float64 stays float64, and every
    other dtype accumulates in float32. The final reduction is cast to the
    output pointer's element type on store.
    """
    inp_dtype = inp_ptr.type.element_ty
    # Accumulator dtype is chosen at Triton compile time — the element type
    # is a constexpr, so only one of these branches survives specialization.
    if inp_dtype.is_int():
        acc_dtype = tl.int64
        other_val = 0
    elif inp_dtype == tl.float64:
        acc_dtype = tl.float64
        other_val = 0.0
    else:
        acc_dtype = tl.float32
        other_val = 0.0

    acc = tl.zeros((BLOCK_SIZE,), dtype=acc_dtype)

    # Consecutive diagonal elements (i, i) -> (i+1, i+1) are exactly
    # stride0 + stride1 elements apart in memory.
    diag_stride = stride0 + stride1

    for i in range(0, tl.cdiv(num_diag, BLOCK_SIZE)):
        block_start = i * BLOCK_SIZE
        current_indices = block_start + tl.arange(0, BLOCK_SIZE)

        # Mask out lanes past the end of the diagonal in the last iteration.
        mask = current_indices < num_diag

        ptr_offsets = current_indices * diag_stride
        current_ptrs = inp_ptr + ptr_offsets

        # Masked lanes load the additive identity so they don't perturb the sum.
        vals = tl.load(current_ptrs, mask=mask, other=other_val)

        acc += vals.to(acc_dtype)

    # Reduce per-lane partial sums to a scalar and store in the output dtype.
    final_sum = tl.sum(acc, axis=0)
    tl.store(out_ptr, final_sum.to(out_ptr.type.element_ty))

53 

54 

def trace(self):
    """Return the trace (sum of main-diagonal elements) of a 2D tensor.

    Dtype promotion follows the handled cases of ``torch.trace``:
    floating-point inputs keep their dtype; all other inputs (integer,
    bool) produce an int64 result.

    Args:
        self: a 2D tensor. Arbitrary strides are supported — the kernel
            indexes the diagonal via the tensor's strides directly.

    Returns:
        A 0-dim tensor on the input's device holding the diagonal sum.

    Raises:
        RuntimeError: if the input is not 2-dimensional.
    """
    logger.debug("GEMS TRACE")

    if self.ndim != 2:
        raise RuntimeError(
            f"trace: expected a 2D tensor, but got a {self.ndim}D tensor"
        )

    M, N = self.shape
    stride0, stride1 = self.stride()
    num_diag = min(M, N)
    if num_diag == 0:
        # Empty diagonal: return the additive identity without a kernel launch.
        if self.dtype.is_floating_point:
            return torch.tensor(0.0, dtype=self.dtype, device=self.device)
        else:
            return torch.tensor(0, dtype=torch.int64, device=self.device)

    if self.dtype.is_floating_point:
        output_dtype = self.dtype
    else:
        output_dtype = torch.int64
    out = torch.empty((), dtype=output_dtype, device=self.device)

    # Single-program launch; the kernel loops over the diagonal in blocks.
    grid = (1,)
    # num_diag >= 1 here (the zero case returned above), so
    # next_power_of_2(num_diag) >= 1 and no zero-guard is needed; the old
    # "if BLOCK_SIZE == 0" fallback was unreachable dead code.
    BLOCK_SIZE = min(1024, triton.next_power_of_2(num_diag))

    with torch_device_fn.device(self.device):
        trace_kernel[grid](
            self,
            out,
            num_diag,
            stride0,
            stride1,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return out