Coverage for src/flag_gems/experimental_ops/fix.py: 0%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-12 02:21 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _fix_trunc_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Truncate each element toward zero: floor for x >= 0, ceil for x < 0.

    One program handles BLOCK_SIZE contiguous elements; out-of-range lanes
    are masked off.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # Bug fix: the previous code unconditionally upcast to fp32 before
    # floor/ceil, which silently lost precision for float64 inputs (any
    # value whose integer part needs more than 24 mantissa bits came back
    # wrong). `x.dtype` is known at trace time, so this branch is resolved
    # at compile time: fp64 is truncated natively, while fp16/bf16 are
    # round-tripped through fp32 (a lossless widening) for stable math.
    if x.dtype == tl.float64:
        res = tl.where(x >= 0, tl.floor(x), tl.ceil(x))
    else:
        x_f32 = x.to(tl.float32)
        trunced = tl.where(x_f32 >= 0, tl.floor(x_f32), tl.ceil(x_f32))
        res = trunced.to(x.dtype)

    tl.store(out_ptr + offsets, res, mask=mask)

23 

24 

@triton.jit
def _copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Element-wise identity: copy a flat run of values from x_ptr to out_ptr.

    Each program moves one BLOCK_SIZE-wide tile; lanes past n_elements are
    masked on both the load and the store.
    """
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, vals, mask=in_bounds)

33 

34 

35def _launch_fix_kernel(x: torch.Tensor, out: torch.Tensor, block_size: int = 1024): 

36 assert x.is_cuda and out.is_cuda, "Input and output must be on CUDA device" 

37 assert ( 

38 x.numel() == out.numel() 

39 ), "Input and output must have the same number of elements" 

40 assert x.device == out.device, "Input and output must be on the same device" 

41 assert ( 

42 x.is_contiguous() and out.is_contiguous() 

43 ), "Only contiguous tensors are supported" 

44 

45 n_elements = x.numel() 

46 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

47 

48 # For floating point types, perform truncation toward zero. 

49 # For non-floating types, it's effectively a no-op (copy). 

50 if x.is_floating_point(): 

51 _fix_trunc_kernel[grid](x, out, n_elements, BLOCK_SIZE=block_size) 

52 else: 

53 _copy_kernel[grid](x, out, n_elements, BLOCK_SIZE=block_size) 

54 

55 

def fix(self: torch.Tensor):
    """
    Wrapper for ATen operator: ('fix', <Autograd.disable: False>)
    Rounds elements toward zero (like trunc) for floating tensors.
    Leaves integer tensors unchanged.

    Returns a new (contiguous) tensor; `self` is never modified.
    """
    if self.is_complex():
        # Fallback for complex dtypes not supported by Triton: use PyTorch.
        # NOTE(review): torch.trunc may itself reject complex dtypes —
        # confirm this path against the intended dispatch behavior.
        return torch.trunc(self)

    # Bug fix: torch.empty_like preserves strides for dense non-contiguous
    # inputs (e.g. transposed views), so the launcher's contiguity check
    # would reject them. Take a contiguous copy/view of the input first and
    # allocate a matching contiguous output.
    src = self.contiguous()
    out = torch.empty_like(src)
    _launch_fix_kernel(src, out)
    return out

69 

70 

def fix_out(self: torch.Tensor, out: torch.Tensor):
    """
    Wrapper for ATen operator: ('fix.out', <Autograd.disable: False>)
    Writes the truncate-toward-zero result of `self` into `out` and
    returns `out`.
    """
    if self.is_complex():
        # Fallback for complex dtypes not supported by Triton: use PyTorch.
        # NOTE(review): torch.trunc may itself reject complex dtypes —
        # confirm this path against the intended dispatch behavior.
        out.copy_(torch.trunc(self))
        return out

    # Bug fix: the Triton launcher requires contiguous tensors, so a
    # non-contiguous `self` or `out` previously failed its checks. Feed a
    # contiguous source, and stage through a temporary when `out` itself is
    # non-contiguous, copying back with the dtype/layout-aware copy_().
    src = self.contiguous()
    if out.is_contiguous():
        _launch_fix_kernel(src, out)
    else:
        tmp = torch.empty_like(src)
        _launch_fix_kernel(src, tmp)
        out.copy_(tmp)
    return out