Coverage for src/flag_gems/ops/floor_.py: 54%
41 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-28 12:23 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import torch
3import triton
4import triton.language as tl
6from flag_gems.runtime import torch_device_fn
@triton.jit
def floor_kernel_(
    x_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    IS_FP32: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
):
    """In-place elementwise floor over a flat buffer of ``n_elements``.

    Each program instance handles one BLOCK_SIZE-wide tile. Exactly one of
    the IS_* constexpr flags selects the dtype path at compile time; when
    none is set the loaded values are stored back unchanged (pass-through
    for non-float dtypes).
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    vals = tl.load(x_ptr + offs, mask=in_bounds)

    result = vals  # default: leave values untouched when no float flag is set
    if IS_FP32:
        result = tl.floor(vals)
    elif IS_FP16:
        # Compute floor in fp32, then cast back to the storage dtype.
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.float16)
    elif IS_BF16:
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.bfloat16)

    tl.store(x_ptr + offs, result, mask=in_bounds)
def floor_(input):
    """In-place floor: overwrite ``input`` with ``floor(input)`` elementwise.

    Args:
        input: a contiguous ``torch.Tensor``. Integer tensors are returned
            unchanged (floor is the identity for them, and the kernel's
            pass-through path handles that).

    Returns:
        The same tensor object, modified in place.

    Raises:
        TypeError: if ``input`` is not a tensor, is complex, or has a
            floating dtype the kernel cannot handle (e.g. float64).
        ValueError: if ``input`` is not contiguous.
    """
    x = input
    if not isinstance(x, torch.Tensor):
        raise TypeError("floor_ expects a torch.Tensor.")
    if x.is_complex():
        raise TypeError("floor_ is not supported for complex tensors.")
    if not x.is_contiguous():
        raise ValueError(
            "floor_ Triton kernel currently supports only contiguous tensors."
        )

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to do; preserve in-place semantics by returning the input.
        return x

    dtype = x.dtype
    IS_FP32 = dtype == torch.float32
    IS_FP16 = dtype == torch.float16
    IS_BF16 = dtype == torch.bfloat16
    # Bug fix: previously an unsupported floating dtype (e.g. float64) fell
    # through the kernel with all IS_* flags False, silently returning the
    # tensor un-floored. Fail loudly instead; integer dtypes remain an
    # intentional no-op.
    if x.is_floating_point() and not (IS_FP32 or IS_FP16 or IS_BF16):
        raise TypeError(f"floor_ Triton kernel does not support dtype {dtype}.")

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(x.device):
        floor_kernel_[grid](
            x,
            n_elements,
            BLOCK_SIZE=BLOCK_SIZE,
            IS_FP32=IS_FP32,
            IS_FP16=IS_FP16,
            IS_BF16=IS_BF16,
        )
    return x