Coverage for src/flag_gems/experimental_ops/frac.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def frac_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    IS_FP64: tl.constexpr,
):
    # Elementwise fractional-part kernel: out[i] = x[i] - trunc(x[i]).
    # Truncation is emulated as floor for non-negative values and ceil for
    # negative ones, so the result keeps the sign of the input.
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    raw = tl.load(x_ptr + idx, mask=in_bounds, other=0)

    # Pick the compute precision: fp64 stays fp64; half-precision inputs
    # (fp16/bf16) are promoted to fp32 for the floor/ceil math.
    if IS_FP64:
        val = raw.to(tl.float64)
    elif IS_FP16 or IS_BF16:
        val = raw.to(tl.float32)
    else:
        val = raw  # already float32

    whole = tl.where(val >= 0, tl.floor(val), tl.ceil(val))
    frac_part = val - whole

    # Demote back to the storage dtype before writing out.
    if IS_FP64:
        result = frac_part.to(tl.float64)
    elif IS_FP16:
        result = frac_part.to(tl.float16)
    elif IS_BF16:
        result = frac_part.to(tl.bfloat16)
    else:
        result = frac_part.to(tl.float32)

    tl.store(out_ptr + idx, result, mask=in_bounds)

45 

46 

def _launch_frac(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Validate arguments and launch ``frac_kernel`` writing into ``out``.

    Args:
        x: CUDA floating-point input tensor.
        out: CUDA tensor with the same dtype and element count as ``x``;
            receives the elementwise fractional part of ``x``.

    Returns:
        ``out`` (possibly after a copy-back if it was non-contiguous).

    Raises:
        NotImplementedError: for complex or non-floating-point dtypes.
        AssertionError: on device/dtype/numel mismatches.
    """
    assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"
    # Check complex FIRST: torch.is_floating_point() is False for complex
    # dtypes, so the generic floating-point check below would otherwise
    # shadow this branch and the complex-specific message would be dead code.
    if x.is_complex():
        raise NotImplementedError(
            "frac is not implemented for complex dtypes in this Triton kernel"
        )
    if not x.is_floating_point():
        raise NotImplementedError("frac is only implemented for floating point dtypes")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to launch for empty tensors.
        return out

    # Use contiguous buffers for kernel execution
    x_contig = x.contiguous()
    out_contig = out.contiguous()

    is_fp16 = x_contig.dtype == torch.float16
    is_bf16 = x_contig.dtype == torch.bfloat16
    is_fp64 = x_contig.dtype == torch.float64

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    frac_kernel[grid](
        x_contig,
        out_contig,
        n_elements,
        BLOCK_SIZE=1024,
        IS_FP16=is_fp16,
        IS_BF16=is_bf16,
        IS_FP64=is_fp64,
    )

    # If out was non-contiguous, .contiguous() produced a fresh buffer;
    # copy the kernel's results back into the caller's tensor.
    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)
    return out

87 

88 

def frac(input: torch.Tensor) -> torch.Tensor:
    """Return a new tensor holding the elementwise fractional part of ``input``."""
    # _launch_frac returns the output tensor it filled, so forward it directly.
    return _launch_frac(input, torch.empty_like(input))

93 

94 

def frac_out(input: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Compute the elementwise fractional part of ``input`` into ``out``.

    Enforces the ``.out``-variant contract: ``out`` must already match
    ``input`` in both shape and dtype.
    """
    assert out.shape == input.shape, "out must have the same shape as input"
    assert out.dtype == input.dtype, "out must have the same dtype as input"
    # _launch_frac returns the tensor it wrote into, so forward it directly.
    return _launch_frac(input, out)