Coverage for src/flag_gems/ops/floor_.py: 54%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import torch 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.runtime import torch_device_fn 

7 

8 

# In-place elementwise floor over a flat, contiguous buffer.
#
# One program instance covers one BLOCK_SIZE-wide slice of the tensor.
# Reduced-precision floats (fp16/bf16) are widened to fp32 before
# tl.floor and narrowed back afterwards; when none of the dtype flags
# is set, the loaded values are stored back unchanged (no-op).
@triton.jit
def floor_kernel_(
    x_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    IS_FP32: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
):
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)

    # Floor only the floating-point dtypes; everything else passes through.
    result = vals
    if IS_FP32:
        result = tl.floor(vals)
    elif IS_FP16:
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.float16)
    elif IS_BF16:
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.bfloat16)

    # Same pointer as the load: the operation is in-place by design.
    tl.store(x_ptr + idx, result, mask=in_bounds)

37 

38 

def floor_(input):
    """In-place floor: overwrite ``input`` with its element-wise floor.

    Mirrors ``torch.Tensor.floor_`` for the dtypes the Triton kernel
    supports (float32, float16, bfloat16). Integer/bool tensors are
    returned unchanged (floor of an integer is itself).

    Args:
        input: A contiguous ``torch.Tensor`` to be floored in place.

    Returns:
        The same tensor object, modified in place.

    Raises:
        TypeError: If ``input`` is not a tensor, is complex, or has a
            floating dtype the kernel does not support (e.g. float64).
        ValueError: If ``input`` is not contiguous.
    """
    x = input
    if not isinstance(x, torch.Tensor):
        raise TypeError("floor_ expects a torch.Tensor.")
    if x.is_complex():
        raise TypeError("floor_ is not supported for complex tensors.")
    if not x.is_contiguous():
        raise ValueError(
            "floor_ Triton kernel currently supports only contiguous tensors."
        )

    dtype = x.dtype
    IS_FP32 = dtype == torch.float32
    IS_FP16 = dtype == torch.float16
    IS_BF16 = dtype == torch.bfloat16

    # BUG FIX: previously an unsupported floating dtype (e.g. float64)
    # reached the kernel with all dtype flags False and was silently
    # returned un-floored — a wrong result. Fail loudly instead.
    if x.is_floating_point() and not (IS_FP32 or IS_FP16 or IS_BF16):
        raise TypeError(f"floor_ Triton kernel does not support dtype {dtype}.")

    n_elements = x.numel()
    if n_elements == 0:
        return x

    # Floor is the identity on integer/bool tensors; skip the no-op
    # kernel launch entirely (output is identical either way).
    if not x.is_floating_point():
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(x.device):
        floor_kernel_[grid](
            x,
            n_elements,
            BLOCK_SIZE=BLOCK_SIZE,
            IS_FP32=IS_FP32,
            IS_FP16=IS_FP16,
            IS_BF16=IS_BF16,
        )
    return x