Coverage for src/flag_gems/experimental_ops/floor_.py: 0%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def floor_(
    x_ptr,  # pointer to the tensor buffer (read and written in place)
    n_elements,  # total element count of the flat buffer
    BLOCK_SIZE: tl.constexpr,
    IS_FP32: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
):
    # In-place elementwise floor over a contiguous 1-D buffer.  At most one
    # of the dtype flags is set; if none is set the kernel degenerates to a
    # load/store no-op (non-floating dtypes pass through unchanged).
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds)

    # The flags are constexpr, so each branch below is resolved at compile
    # time.  Half-precision inputs round-trip through fp32 for the floor,
    # then are cast back to their original dtype.
    result = vals
    if IS_FP32:
        result = tl.floor(vals)
    elif IS_FP16:
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.float16)
    elif IS_BF16:
        result = tl.cast(tl.floor(tl.cast(vals, tl.float32)), tl.bfloat16)

    tl.store(x_ptr + offs, result, mask=in_bounds)

34 

35 

# Keep a reference to the jitted kernel before the Python wrapper below
# rebinds the name `floor_`; the wrapper launches the kernel via this alias.
floor__kernel = floor_

38 

39 

def floor_(*args, **kwargs):
    """Apply ``floor`` in place to a CUDA tensor via the Triton kernel.

    Mirrors ``torch.floor_``'s calling convention: the tensor may be passed
    as the first positional argument or as the ``input`` keyword.  For
    float32/float16/bfloat16 tensors the kernel floors every element in
    place; all other dtypes are returned unchanged (matching the kernel's
    no-op behavior for non-floating inputs).

    Returns:
        torch.Tensor: the same tensor, modified in place.

    Raises:
        ValueError: if no tensor is supplied, the tensor is not on CUDA,
            or it is not contiguous.
        TypeError: if the argument is not a ``torch.Tensor``, or is complex.
    """
    x = args[0] if len(args) > 0 else kwargs.get("input", None)
    if x is None:
        raise ValueError(
            "floor_ expects a Tensor as the first positional argument or 'input' keyword."
        )
    if not isinstance(x, torch.Tensor):
        raise TypeError("floor_ expects a torch.Tensor.")
    if not x.is_cuda:
        raise ValueError("floor_ Triton kernel requires a CUDA tensor.")
    if x.is_complex():
        raise TypeError("floor_ is not supported for complex tensors.")
    if not x.is_contiguous():
        raise ValueError(
            "floor_ Triton kernel currently supports only contiguous tensors."
        )

    n_elements = x.numel()
    if n_elements == 0:
        return x

    dtype = x.dtype
    IS_FP32 = dtype == torch.float32
    IS_FP16 = dtype == torch.float16
    IS_BF16 = dtype == torch.bfloat16

    # For every other dtype (integers, bool, float64) the kernel would be a
    # pure load/store no-op, so skip the launch entirely: the result is
    # identical and we avoid a wasted kernel dispatch.
    if not (IS_FP32 or IS_FP16 or IS_BF16):
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    floor__kernel[grid](
        x,  # in-place: same pointer serves as both load and store target
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
        IS_FP32=IS_FP32,
        IS_FP16=IS_FP16,
        IS_BF16=IS_BF16,
    )
    return x