Coverage for src/flag_gems/experimental_ops/fix.py: 0%
46 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def _fix_trunc_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise truncation toward zero: floor for x >= 0, ceil for x < 0.

    Args:
        x_ptr: pointer to the (contiguous) input tensor.
        out_ptr: pointer to the (contiguous) output tensor, same dtype as input.
        n_elements: total number of elements to process.
        BLOCK_SIZE: compile-time block width; one program handles one block.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the ragged tail block

    x = tl.load(x_ptr + offsets, mask=mask)

    if x.dtype == tl.float64:
        # fp64 must be truncated in native precision: a round-trip through
        # fp32 (the previous behavior) silently loses mantissa bits for
        # values outside fp32's exact-integer range.
        res = tl.where(x >= 0, tl.floor(x), tl.ceil(x))
    else:
        # fp16/bf16/fp32: compute in fp32 (floor/ceil need at least fp32
        # on some backends), then cast back to the input dtype.
        x_f32 = x.to(tl.float32)
        trunc_f32 = tl.where(x_f32 >= 0, tl.floor(x_f32), tl.ceil(x_f32))
        res = trunc_f32.to(x.dtype)

    tl.store(out_ptr + offsets, res, mask=mask)
@triton.jit
def _copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise copy from x_ptr to out_ptr (used for non-float dtypes,
    where fix/trunc is an identity operation).

    Args:
        x_ptr: pointer to the (contiguous) source tensor.
        out_ptr: pointer to the (contiguous) destination tensor.
        n_elements: total number of elements to copy.
        BLOCK_SIZE: compile-time block width; one program handles one block.
    """
    program = tl.program_id(axis=0)
    idx = program * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements  # guard the ragged tail block
    values = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(out_ptr + idx, values, mask=in_bounds)
def _launch_fix_kernel(x: torch.Tensor, out: torch.Tensor, block_size: int = 1024):
    """Validate tensors and launch the appropriate Triton kernel.

    For floating-point ``x`` this truncates every element toward zero into
    ``out``; for non-floating dtypes fix is the identity, so the elements are
    simply copied.

    Args:
        x: contiguous CUDA input tensor.
        out: contiguous CUDA output tensor with the same number of elements
            (and the same device) as ``x``.
        block_size: elements processed per Triton program instance.

    Raises:
        ValueError: if the tensors are not CUDA, differ in element count or
            device, or are not contiguous.
    """
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable every safety check guarding
    # a raw-pointer kernel launch.
    if not (x.is_cuda and out.is_cuda):
        raise ValueError("Input and output must be on CUDA device")
    if x.numel() != out.numel():
        raise ValueError("Input and output must have the same number of elements")
    if x.device != out.device:
        raise ValueError("Input and output must be on the same device")
    if not (x.is_contiguous() and out.is_contiguous()):
        raise ValueError("Only contiguous tensors are supported")

    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # For floating point types, perform truncation toward zero.
    # For non-floating types, it's effectively a no-op (copy).
    if x.is_floating_point():
        _fix_trunc_kernel[grid](x, out, n_elements, BLOCK_SIZE=block_size)
    else:
        _copy_kernel[grid](x, out, n_elements, BLOCK_SIZE=block_size)
def fix(self: torch.Tensor):
    """
    Wrapper for ATen operator: ('fix', <Autograd.disable: False>)
    Rounds elements toward zero (like trunc) for floating tensors.
    Leaves integer tensors unchanged.

    Returns a new (contiguous) tensor; ``self`` is never modified.
    """
    if self.is_complex():
        # Fallback for complex dtypes not supported by Triton: use PyTorch
        return torch.trunc(self)

    # The Triton launcher only accepts contiguous tensors; previously a
    # non-contiguous input (e.g. a transposed view) raised. Stage through a
    # contiguous copy instead — this is backward-compatible since the old
    # behavior on that path was an error.
    src = self if self.is_contiguous() else self.contiguous()
    out = torch.empty_like(src)
    _launch_fix_kernel(src, out)
    return out
def fix_out(self: torch.Tensor, out: torch.Tensor):
    """
    Wrapper for ATen operator: ('fix.out', <Autograd.disable: False>)
    Writes the result into 'out' and returns it.
    """
    if self.is_complex():
        # Fallback for complex dtypes not supported by Triton: use PyTorch
        out.copy_(torch.trunc(self))
        return out

    # The Triton launcher only accepts contiguous tensors; previously a
    # non-contiguous `self` or `out` raised. Stage non-contiguous tensors
    # through contiguous buffers — backward-compatible, since the old
    # behavior on those paths was an error.
    src = self if self.is_contiguous() else self.contiguous()
    if out.is_contiguous():
        _launch_fix_kernel(src, out)
    else:
        staging = torch.empty_like(src)
        _launch_fix_kernel(src, staging)
        out.copy_(staging)
    return out