Coverage for src/flag_gems/experimental_ops/fix_.py: 0%
39 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def fix_(x_ptr, n_elements, DO_UPCAST: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """In-place truncation toward zero (the math "fix" operation) on a flat buffer.

    Each program instance handles one contiguous block of BLOCK_SIZE elements:
    it loads them from x_ptr, replaces each value with floor(x) for x >= 0 and
    ceil(x) for x < 0 (i.e. rounds toward zero), and stores the result back to
    the same locations.

    Args:
        x_ptr: pointer to the tensor data; read and written in place.
        n_elements: total number of elements; out-of-range lanes are masked.
        DO_UPCAST: compile-time flag — when true, the math is done in float32
            and cast back to the input dtype on store (used by the wrapper for
            fp16/bf16 inputs).
        BLOCK_SIZE: compile-time block size per program instance.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask off the tail lanes of the last block.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    # constexpr branch: resolved at compile time, so each specialization
    # contains only one of the two paths.
    if DO_UPCAST:
        x_work = x.to(tl.float32)
    else:
        x_work = x

    y_floor = tl.floor(x_work)
    y_ceil = tl.ceil(x_work)
    # Truncate toward zero: floor for non-negative values, ceil for negative.
    y_work = tl.where(x_work >= 0, y_floor, y_ceil)

    if DO_UPCAST:
        # Cast back to the original element dtype before the in-place store.
        y = y_work.to(x.dtype)
    else:
        y = y_work

    tl.store(x_ptr + offsets, y, mask=mask)
# Keep a reference to the Triton kernel before the name `fix_` is rebound
# below to the Python wrapper; the wrapper launches the kernel via this alias.
_fix_kernel = fix_
def fix_(*args, **kwargs):
    """In-place truncation toward zero, mirroring ``torch.Tensor.fix_``.

    Launches the Triton ``_fix_kernel`` over the tensor's flat storage.

    Args:
        *args: the first positional argument must be a contiguous CUDA
            ``torch.Tensor``; any further positional arguments are ignored.
        **kwargs: accepted for call-signature compatibility but ignored.

    Returns:
        The input tensor (modified in place).

    Raises:
        TypeError: if no argument is given or the first argument is not a
            ``torch.Tensor``.
        ValueError: if the tensor is not on a CUDA device or not contiguous.
    """
    # Bug fix: calling fix_() with no arguments previously raised a raw
    # IndexError from args[0]; report it as the documented TypeError instead.
    if not args:
        raise TypeError("fix_ expects a torch.Tensor as the first argument")
    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("fix_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device for Triton kernel")
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous")

    # In-place fix_ does nothing for non-floating tensors (integers are
    # already "truncated"), matching torch semantics.
    if not x.is_floating_point():
        return x

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to launch for an empty tensor.
        return x

    # Upcast low-precision types so floor/ceil run in float32.
    do_upcast = x.dtype in (torch.float16, torch.bfloat16)

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _fix_kernel[grid](x, n_elements, DO_UPCAST=do_upcast, BLOCK_SIZE=BLOCK_SIZE)
    return x