Coverage for src/flag_gems/experimental_ops/floor_.py: 0%
44 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-21 14:31 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def floor_(
    x_ptr,  # pointer to the tensor buffer, read and written in place
    n_elements,  # total number of elements in the buffer
    BLOCK_SIZE: tl.constexpr,
    IS_FP32: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
):
    """Element-wise in-place floor over a contiguous buffer.

    Each program instance covers one BLOCK_SIZE-wide slice. Half-precision
    values (fp16/bf16) are promoted to fp32 for the floor and cast back
    before the store; any dtype with no flag set passes through unchanged.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds)

    # Floor only the flagged floating-point dtypes; otherwise write back as-is.
    result = vals
    if IS_FP32:
        result = tl.floor(vals)
    elif IS_FP16:
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.float16)
    elif IS_BF16:
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.bfloat16)

    tl.store(x_ptr + offs, result, mask=in_bounds)
# Keep a reference to the kernel before defining the wrapper with the same name
# (the Python wrapper below shadows `floor_`, so launches must go through this
# alias to reach the JIT-compiled kernel).
floor__kernel = floor_
def floor_(*args, **kwargs):
    """In-place ``torch.floor_`` backed by a Triton kernel.

    Accepts the tensor either as the first positional argument or as the
    ``input`` keyword, mirroring the ATen calling convention, and returns
    the same tensor object floored element-wise in place.

    Raises:
        ValueError: if no tensor is given, the tensor is not on CUDA, or it
            is not contiguous.
        TypeError: if the argument is not a tensor, is complex, or has a
            floating dtype the kernel cannot floor (e.g. ``torch.float64``).
    """
    x = args[0] if args else kwargs.get("input", None)
    if x is None:
        raise ValueError(
            "floor_ expects a Tensor as the first positional argument or 'input' keyword."
        )
    if not isinstance(x, torch.Tensor):
        raise TypeError("floor_ expects a torch.Tensor.")
    if not x.is_cuda:
        raise ValueError("floor_ Triton kernel requires a CUDA tensor.")
    if x.is_complex():
        raise TypeError("floor_ is not supported for complex tensors.")
    if not x.is_contiguous():
        raise ValueError(
            "floor_ Triton kernel currently supports only contiguous tensors."
        )

    n_elements = x.numel()
    if n_elements == 0:
        return x

    dtype = x.dtype
    IS_FP32 = dtype == torch.float32
    IS_FP16 = dtype == torch.float16
    IS_BF16 = dtype == torch.bfloat16

    if not (IS_FP32 or IS_FP16 or IS_BF16):
        # The kernel only floors fp32/fp16/bf16. Other floating dtypes
        # (e.g. float64) would previously launch with all flags false and
        # silently return the tensor un-floored — reject them explicitly.
        if dtype.is_floating_point:
            raise TypeError(f"floor_ Triton kernel does not support dtype {dtype}.")
        # Integer/bool tensors are already their own floor; skip the
        # no-op kernel launch and return unchanged.
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    floor__kernel[grid](
        x,  # in-place: the same pointer serves as both input and output
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
        IS_FP32=IS_FP32,
        IS_FP16=IS_FP16,
        IS_BF16=IS_BF16,
    )
    return x