Coverage for src/flag_gems/experimental_ops/trunc.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-29 04:01 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# Elementwise truncation kernel: round each element toward zero.
#
# Each program instance handles BLOCK_SIZE consecutive elements of the
# flattened input; `mask` guards the ragged tail block.  DTYPE_CODE is a
# compile-time constant (tl.constexpr), so the branch below is resolved
# at specialization time and only one path is compiled per dtype.
@triton.jit
def trunc_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr, DTYPE_CODE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Out-of-range lanes are masked off for both the load and the store.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # DTYPE_CODE:
    # 0 -> integer types (copy)
    # 1 -> float16
    # 2 -> bfloat16
    # 3 -> float32
    # 4 -> float64
    if DTYPE_CODE == 0:
        # Integers are already truncated; pass through unchanged.
        y = x
    elif DTYPE_CODE == 1:
        # fp16 is routed through fp32 before floor/ceil and narrowed
        # back for the store (presumably for precision/ops support in
        # fp32 — the bf16 path below does the same).
        xf = x.to(tl.float32)
        y = tl.where(xf >= 0, tl.floor(xf), tl.ceil(xf)).to(tl.float16)
    elif DTYPE_CODE == 2:
        # Same fp32 round-trip as the fp16 path, narrowing to bf16.
        xf = x.to(tl.float32)
        y = tl.where(xf >= 0, tl.floor(xf), tl.ceil(xf)).to(tl.bfloat16)
    elif DTYPE_CODE == 3:
        # trunc(x): floor for non-negative values, ceil for negative.
        xf = x
        y = tl.where(xf >= 0, tl.floor(xf), tl.ceil(xf))
    elif DTYPE_CODE == 4:
        # float64: same rule; the loaded element type follows the
        # pointer, so no cast is needed.
        xf = x
        y = tl.where(xf >= 0, tl.floor(xf), tl.ceil(xf))
    else:
        # Fallback: copy
        y = x

    tl.store(out_ptr + offsets, y, mask=mask)

42 

43 

44def _dtype_code(t: torch.Tensor) -> int: 

45 if t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64): 

46 return 0 

47 if t.dtype == torch.float16: 

48 return 1 

49 if t.dtype == torch.bfloat16: 

50 return 2 

51 if t.dtype == torch.float32: 

52 return 3 

53 if t.dtype == torch.float64: 

54 return 4 

55 raise NotImplementedError(f"Unsupported dtype: {t.dtype}") 

56 

57 

def _launch_trunc(inp: torch.Tensor, out: torch.Tensor):
    """Launch ``trunc_kernel`` over two equal-sized flat CUDA tensors.

    ``out`` is written in place.  Empty tensors are a no-op (a
    zero-size grid launch is skipped entirely).  The kernel's dtype
    dispatch is derived from ``inp``; callers pass a matching ``out``
    dtype.
    """
    assert inp.numel() == out.numel()
    assert inp.device.type == "cuda" and out.device.type == "cuda"

    total = inp.numel()
    if total == 0:
        return

    block = 1024
    # One program per BLOCK_SIZE-sized chunk of the flattened tensor.
    grid = (triton.cdiv(total, block),)
    trunc_kernel[grid](
        inp, out, total, BLOCK_SIZE=block, DTYPE_CODE=_dtype_code(inp)
    )

70 

71 

def trunc(input: torch.Tensor):
    """Return a new tensor with each element of ``input`` truncated
    toward zero.

    Complex tensors are handled by truncating the real and imaginary
    components independently through a ``view_as_real`` view.  The
    Triton kernel requires dense memory, so non-contiguous tensors are
    staged through contiguous copies and the result is scattered back.
    (``input`` shadows the builtin, matching the torch operator
    signature.)
    """
    out = torch.empty_like(input)

    if input.is_complex():
        src = torch.view_as_real(input)
        dst = torch.view_as_real(out)
        if src.is_contiguous() and dst.is_contiguous():
            # Fast path: operate directly on the real views.
            _launch_trunc(src.view(-1), dst.view(-1))
        else:
            # Stage through dense buffers, then restore dst's layout.
            src_c = src.contiguous()
            dst_c = dst.contiguous()
            _launch_trunc(src_c.view(-1), dst_c.view(-1))
            dst.copy_(dst_c)
        return out

    src = input.contiguous() if not input.is_contiguous() else input
    dst = out.contiguous() if not out.is_contiguous() else out
    _launch_trunc(src.view(-1), dst.view(-1))
    # .contiguous() on a non-contiguous tensor allocates new storage;
    # detect that via data_ptr and copy the result back into out.
    if dst.data_ptr() != out.data_ptr():
        out.copy_(dst)

    return out

95 

96 

def trunc_out(input: torch.Tensor, out: torch.Tensor):
    """Truncate ``input`` toward zero element-wise into ``out``.

    Complex tensors are handled by truncating real and imaginary
    components independently via a ``view_as_real`` view.  The kernel
    requires dense memory, so non-contiguous tensors are staged through
    contiguous copies.  (``input`` shadows the builtin, matching the
    torch ``out=`` operator signature.)

    Args:
        input: CUDA tensor to truncate.
        out: CUDA tensor of the same shape and dtype; written in place.

    Returns:
        ``out``.

    Raises:
        ValueError: on shape/dtype mismatch or non-CUDA tensors.
            (Previously ``assert``, which is silently stripped under
            ``python -O``; explicit raises keep the validation active.)
    """
    if input.shape != out.shape:
        raise ValueError("input and out must have the same shape")
    if input.dtype != out.dtype:
        raise ValueError("input and out must have the same dtype")
    if input.device.type != "cuda" or out.device.type != "cuda":
        raise ValueError("Tensors must be on CUDA device")

    if input.is_complex():
        in_r = torch.view_as_real(input)
        out_r = torch.view_as_real(out)
        if not in_r.is_contiguous() or not out_r.is_contiguous():
            # Stage through dense buffers, then scatter back into out's
            # original layout.
            in_r_c = in_r.contiguous()
            out_r_c = out_r.contiguous()
            _launch_trunc(in_r_c.view(-1), out_r_c.view(-1))
            out_r.copy_(out_r_c)
        else:
            _launch_trunc(in_r.view(-1), out_r.view(-1))
    else:
        inp_c = input if input.is_contiguous() else input.contiguous()
        if out.is_contiguous():
            # Fast path: write straight into the caller's storage.
            _launch_trunc(inp_c.view(-1), out.view(-1))
        else:
            # Compute into a contiguous scratch tensor, then copy back
            # respecting out's strides.
            out_c = out.contiguous()
            _launch_trunc(inp_c.view(-1), out_c.view(-1))
            out.copy_(out_c)

    return out