Coverage for src/flag_gems/experimental_ops/eye.py: 0%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-10 02:30 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def eye_kernel(
    out_ptr,  # *Pointer* to the 2D output tensor
    n_rows,  # number of rows (n)
    n_cols,  # number of columns (m)
    stride_row,  # element stride along the row dimension
    stride_col,  # element stride along the column dimension
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the identity pattern into a (n_rows, n_cols) tensor.

    Each program instance owns one BLOCK_M x BLOCK_N tile: it stores 1
    where the global row index equals the global column index and 0
    everywhere else, masked to the tensor bounds.
    """
    tile_row = tl.program_id(axis=0)
    tile_col = tl.program_id(axis=1)

    # Global row/column indices of this tile, shaped for 2D broadcasting.
    rows = (tile_row * BLOCK_M + tl.arange(0, BLOCK_M))[:, None]
    cols = (tile_col * BLOCK_N + tl.arange(0, BLOCK_N))[None, :]

    within = (rows < n_rows) & (cols < n_cols)

    # 1 on the main diagonal, 0 elsewhere; the store casts to the
    # pointer's element dtype.
    values = tl.where(rows == cols, 1, 0)
    addrs = out_ptr + rows * stride_row + cols * stride_col
    tl.store(addrs, values, mask=within)

32 

33 

34# Shared implementation 

35def _eye_impl(n, m=None, dtype=None, device=None, out: torch.Tensor = None): 

36 if m is None: 

37 m = n 

38 

39 if out is None: 

40 if dtype is None: 

41 dtype = torch.get_default_dtype() 

42 if device is None: 

43 device = ( 

44 torch.device("cuda") 

45 if torch.cuda.is_available() 

46 else torch.device("cpu") 

47 ) 

48 out = torch.empty((n, m), dtype=dtype, device=device) 

49 else: 

50 if out.dim() != 2: 

51 raise ValueError("out tensor must be 2D") 

52 # Resize to expected shape if necessary 

53 if out.shape[0] != n or out.shape[1] != m: 

54 out.resize_(n, m) 

55 

56 # Handle empty tensors 

57 if n == 0 or m == 0: 

58 out.zero_() 

59 return out 

60 

61 # CUDA path uses Triton 

62 if out.is_cuda: 

63 BLOCK_M = 64 

64 BLOCK_N = 64 

65 grid = (triton.cdiv(n, BLOCK_M), triton.cdiv(m, BLOCK_N)) 

66 eye_kernel[grid]( 

67 out, 

68 n, 

69 m, 

70 out.stride(0), 

71 out.stride(1), 

72 BLOCK_M=BLOCK_M, 

73 BLOCK_N=BLOCK_N, 

74 ) 

75 return out 

76 

77 # CPU fallback without calling torch.eye 

78 out.zero_() 

79 k = min(n, m) 

80 if k > 0: 

81 idx = torch.arange(k, device=out.device) 

82 one = torch.ones(k, dtype=out.dtype, device=out.device) 

83 out[idx, idx] = one 

84 return out 

85 

86 

87# Wrappers for ATen operator interfaces 

88 

89 

def eye(n, m=None, dtype=None, device=None):
    """ATen ``eye``: allocate and return an identity matrix of shape (n, m or n)."""
    return _eye_impl(n, m, dtype=dtype, device=device)

92 

93 

def eye_m(n, m, dtype=None, device=None):
    """ATen ``eye.m``: identity matrix with an explicit column count."""
    return _eye_impl(n, m, dtype=dtype, device=device)

96 

97 

def eye_out(n, out: torch.Tensor):
    """ATen ``eye.out``: write an (n, n) identity into ``out`` and return it."""
    return _eye_impl(n, n, dtype=out.dtype, device=out.device, out=out)

101 

102 

def eye_m_out(n, m, out: torch.Tensor):
    """ATen ``eye.m_out``: write an (n, m) identity into ``out`` and return it."""
    return _eye_impl(n, m, dtype=out.dtype, device=out.device, out=out)