Coverage for src/flag_gems/experimental_ops/hardtanh.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-17 02:35 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def hardtanh_kernel(
    x_ptr, out_ptr, n_elements, min_val, max_val, BLOCK_SIZE: tl.constexpr
):
    """Elementwise hardtanh: clamp each input value into [min_val, max_val].

    One program instance handles one contiguous tile of BLOCK_SIZE elements.
    min_val / max_val are expected to be float32 scalars.
    """
    tile_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds)
    # Clamp in float32 for uniform behavior across input dtypes,
    # then cast back to the source dtype before storing.
    # Order (min first, then max) is kept from the original: when
    # min_val > max_val this yields min_val everywhere.
    clamped = tl.maximum(tl.minimum(vals.to(tl.float32), max_val), min_val)

    tl.store(out_ptr + idx, clamped.to(vals.dtype), mask=in_bounds)

24 

25 

def _launch_hardtanh(
    input: torch.Tensor, output: torch.Tensor, min_val: float, max_val: float
):
    """Launch the hardtanh Triton kernel, writing clamped values of `input`
    into `output`.

    Both tensors must be CUDA tensors on the same device with the same dtype,
    the same number of elements, and contiguous layout — the kernel addresses
    them as flat 1-D buffers.

    Returns `output` for convenience.
    """
    assert input.is_cuda and output.is_cuda, "Tensors must be on CUDA device"
    assert input.device == output.device, "Input and output must be on the same device"
    assert input.dtype == output.dtype, "Input and output must have the same dtype"
    # The kernel writes output[i] from input[i] over flat offsets; a size
    # mismatch or a strided layout would read/write the wrong memory.
    assert input.numel() == output.numel(), "Input and output must have the same number of elements"
    assert input.is_contiguous() and output.is_contiguous(), "Input and output must be contiguous"
    n_elements = input.numel()
    if n_elements == 0:
        # Nothing to launch for empty tensors; a zero-sized grid is invalid.
        return output
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    hardtanh_kernel[grid](
        input,
        output,
        n_elements,
        float(min_val),
        float(max_val),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output

46 

47 

def hardtanh(self: torch.Tensor, min_val: float = -1.0, max_val: float = 1.0):
    """Return a new tensor with every element of `self` clamped to
    [min_val, max_val] (functional hardtanh).

    The input must live on a CUDA device. Non-contiguous inputs are
    handled by computing on a contiguous copy; the result always matches
    the input's shape and dtype.
    """
    assert self.is_cuda, "Input tensor must be on CUDA device"
    src = self.contiguous()
    result = torch.empty_like(src)
    _launch_hardtanh(src, result, min_val, max_val)
    # Reshape back so the returned tensor mirrors the original's shape.
    return result.view_as(self)

56 

57 

def hardtanh_out(
    self: torch.Tensor,
    min_val: float = -1.0,
    max_val: float = 1.0,
    out: torch.Tensor = None,
):
    """Out-of-place hardtanh writing into `out` (allocated if None).

    `self` must be a CUDA tensor; `out`, when given, must be CUDA with the
    same shape and dtype. Returns `out`.
    """
    x = self
    assert x.is_cuda, "Input tensor must be on CUDA device"
    if out is None:
        out = torch.empty_like(x)
    assert out.is_cuda, "Output tensor must be on CUDA device"
    assert out.shape == x.shape, "Output tensor must have the same shape as input"
    assert out.dtype == x.dtype, "Output tensor must have the same dtype as input"
    if not out.is_contiguous():
        # The kernel needs flat contiguous buffers: compute into a fresh
        # contiguous scratch tensor, then copy back into the strided `out`.
        # empty_like with contiguous_format avoids the wasteful
        # copy that `torch.empty_like(out.contiguous())` would perform.
        out_contig = torch.empty_like(out, memory_format=torch.contiguous_format)
        _launch_hardtanh(x.contiguous(), out_contig, min_val, max_val)
        out.copy_(out_contig)
        return out
    _launch_hardtanh(x.contiguous(), out, min_val, max_val)
    return out