Coverage for src/flag_gems/experimental_ops/fix_.py: 0%

39 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-13 10:08 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def fix_(x_ptr, n_elements, DO_UPCAST: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # In-place truncation toward zero (torch.fix_ semantics) over a flat
    # buffer of n_elements values starting at x_ptr.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)

    # Compile-time branch: half-precision inputs are widened to fp32 so the
    # floor/ceil math runs in full precision.
    if DO_UPCAST:
        work = vals.to(tl.float32)
    else:
        work = vals

    # Truncate toward zero: negatives round up (ceil), everything else
    # rounds down (floor). NaNs pass through either branch unchanged.
    trunc = tl.where(work < 0, tl.ceil(work), tl.floor(work))

    if DO_UPCAST:
        result = trunc.to(vals.dtype)
    else:
        result = trunc

    tl.store(x_ptr + idx, result, mask=in_bounds)

30 

31 

# Keep a reference to the compiled Triton kernel before the name `fix_` is
# rebound to the Python wrapper defined below; the wrapper launches the
# kernel through this private alias.
_fix_kernel = fix_

34 

35 

def fix_(*args, **kwargs):
    """In-place truncation toward zero (``torch.Tensor.fix_``) via Triton.

    The first positional argument must be a contiguous CUDA tensor; it is
    modified in place and returned. Extra positional/keyword arguments are
    accepted for signature compatibility and ignored.

    Returns:
        The input tensor (modified in place for floating-point dtypes;
        returned unchanged for integer/bool dtypes, where truncation is
        the identity).

    Raises:
        TypeError: if no argument is given, or the first argument is not
            a ``torch.Tensor``.
        ValueError: if the tensor is not on a CUDA device or not contiguous.
    """
    # Guard the empty call explicitly: previously `args[0]` leaked a bare
    # IndexError instead of the intended TypeError.
    if not args:
        raise TypeError("fix_ expects a torch.Tensor as the first argument")
    x = args[0]
    if not isinstance(x, torch.Tensor):
        raise TypeError("fix_ expects a torch.Tensor as the first argument")
    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device for Triton kernel")
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous")

    # In-place fix_ does nothing for non-floating tensors.
    if not x.is_floating_point():
        return x

    n_elements = x.numel()
    if n_elements == 0:
        return x

    # Upcast low-precision types so floor/ceil run in fp32 inside the kernel.
    do_upcast = x.dtype in (torch.float16, torch.bfloat16)

    BLOCK_SIZE = 1024

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    _fix_kernel[grid](x, n_elements, DO_UPCAST=do_upcast, BLOCK_SIZE=BLOCK_SIZE)
    return x