Coverage for src/flag_gems/experimental_ops/frac.py: 0%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-24 15:40 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def frac_kernel(
    x_ptr,  # pointer to the (contiguous) input buffer
    out_ptr,  # pointer to the (contiguous) output buffer
    n_elements,  # total element count; used to mask the tail block
    BLOCK_SIZE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    IS_FP64: tl.constexpr,
):
    """Elementwise fractional part: out = x - trunc(x).

    The sign of the result follows the input (e.g. frac(-1.5) == -0.5),
    matching truncation toward zero. Half-precision inputs (fp16/bf16)
    are upcast to fp32 for the floor/ceil arithmetic and cast back on
    store; fp64 computes in fp64; fp32 is used as-is.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask out-of-range lanes in the final (partial) block.
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)

    # Choose compute dtype. The IS_* flags are constexpr, so these
    # branches are resolved at kernel compile time, not at runtime.
    if IS_FP64:
        x_comp = x.to(tl.float64)
    elif IS_FP16 or IS_BF16:
        x_comp = x.to(tl.float32)
    else:
        x_comp = x  # float32

    # trunc(x): floor for non-negative values, ceil for negative ones
    # (i.e. rounding toward zero), so the fractional part keeps x's sign.
    trunc_val = tl.where(x_comp >= 0, tl.floor(x_comp), tl.ceil(x_comp))
    y_comp = x_comp - trunc_val

    # Cast back to output dtype
    if IS_FP64:
        y = y_comp.to(tl.float64)
    elif IS_FP16:
        y = y_comp.to(tl.float16)
    elif IS_BF16:
        y = y_comp.to(tl.bfloat16)
    else:
        y = y_comp.to(tl.float32)

    tl.store(out_ptr + offsets, y, mask=mask)
def _launch_frac(x: torch.Tensor, out: torch.Tensor):
    """Validate arguments and launch ``frac_kernel`` computing ``out = frac(x)``.

    Args:
        x: CUDA floating-point input tensor.
        out: CUDA tensor with the same dtype and element count as ``x``;
            receives the elementwise fractional part of ``x``.

    Returns:
        ``out`` (for call-chaining convenience).

    Raises:
        NotImplementedError: for complex or non-floating-point dtypes.
        AssertionError: if tensors are not CUDA, or dtype/numel mismatch.
    """
    assert x.is_cuda and out.is_cuda, "Inputs must be CUDA tensors"
    assert (
        x.numel() == out.numel()
    ), "Input and output must have the same number of elements"
    assert x.dtype == out.dtype, "Input and output must have the same dtype"
    # Check complex BEFORE the floating-point check: is_floating_point()
    # returns False for complex dtypes, so with the original ordering the
    # complex branch was unreachable and complex inputs got the misleading
    # "only floating point" message.
    if x.is_complex():
        raise NotImplementedError(
            "frac is not implemented for complex dtypes in this Triton kernel"
        )
    if not x.is_floating_point():
        raise NotImplementedError("frac is only implemented for floating point dtypes")

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; avoid launching an empty grid.
        return out

    # Use contiguous buffers for kernel execution. contiguous() is a no-op
    # (same storage) when the tensor is already contiguous, otherwise a copy.
    x_contig = x.contiguous()
    out_contig = out.contiguous()

    is_fp16 = x_contig.dtype == torch.float16
    is_bf16 = x_contig.dtype == torch.bfloat16
    is_fp64 = x_contig.dtype == torch.float64

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    frac_kernel[grid](
        x_contig,
        out_contig,
        n_elements,
        BLOCK_SIZE=1024,
        IS_FP16=is_fp16,
        IS_BF16=is_bf16,
        IS_FP64=is_fp64,
    )

    # If out was non-contiguous, the kernel wrote into a copy; copy back.
    if out_contig.data_ptr() != out.data_ptr():
        out.copy_(out_contig)
    return out
def frac(input: torch.Tensor):
    """Return a new tensor containing the elementwise fractional part of ``input``."""
    result = torch.empty_like(input)
    # _launch_frac returns its output tensor, so we can return its result directly.
    return _launch_frac(input, result)
def frac_out(input: torch.Tensor, out: torch.Tensor):
    """Write the elementwise fractional part of ``input`` into ``out`` and return it.

    Follows the torch ``.out`` convention: ``out`` must already match
    ``input`` in both shape and dtype.
    """
    assert input.shape == out.shape, "out must have the same shape as input"
    assert input.dtype == out.dtype, "out must have the same dtype as input"
    # _launch_frac returns ``out``; forward that return value.
    return _launch_frac(input, out)