Coverage for src/flag_gems/experimental_ops/hardtanh_.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def hardtanh_(
    x_ptr,  # *Pointer* to input/output tensor (in-place).
    n_elements,  # Number of elements.
    min_val,  # Minimum clamp value (scalar).
    max_val,  # Maximum clamp value (scalar).
    BLOCK_SIZE: tl.constexpr,
):
    # In-place hardtanh: clamp every element of the flat tensor into
    # [min_val, max_val]. One program instance handles one BLOCK_SIZE chunk.
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    vals = tl.load(x_ptr + idx, mask=in_bounds, other=0)

    # Materialize the scalar bounds in the tensor's dtype so the
    # comparisons below do not promote.
    lo = tl.full([1], min_val, dtype=vals.dtype)
    hi = tl.full([1], max_val, dtype=vals.dtype)

    # Same order as min-then-max clamping: cap at hi, then floor at lo.
    vals = tl.maximum(tl.minimum(vals, hi), lo)

    tl.store(x_ptr + idx, vals, mask=in_bounds)

29 

30 

# Keep a reference to the Triton kernel before defining the Python wrapper of
# the same name below; the wrapper launches the kernel through this alias.
hardtanh___kernel = hardtanh_

33 

34 

def hardtanh_(*args, **kwargs):
    """In-place hardtanh: clamp each element of a tensor into [min_val, max_val].

    Mirrors ``torch.nn.functional.hardtanh_(input, min_val=-1.0, max_val=1.0)``.
    Contiguous floating-point CUDA tensors are handled by the Triton kernel;
    everything else (CPU tensors, non-floating dtypes, non-contiguous layouts)
    falls back to PyTorch's own in-place hardtanh.

    The tensor may be passed positionally or as the keyword ``input``/``self``;
    clamp bounds may be positional (args[1], args[2]) or keywords
    ``min_val``/``max_val`` (keywords win when both are given).

    Returns:
        The input tensor, mutated in place.

    Raises:
        TypeError: if no tensor argument is supplied, or the first argument is
            not a ``torch.Tensor``.
    """
    # Locate the tensor: positional 0, or the common keyword spellings.
    if len(args) >= 1:
        x = args[0]
    elif "input" in kwargs:
        x = kwargs["input"]
    elif "self" in kwargs:
        x = kwargs["self"]
    else:
        raise TypeError("hardtanh_ expected at least 1 argument: a CUDA tensor")

    # Bounds: defaults, then positional overrides, then keyword overrides.
    min_val = -1.0
    max_val = 1.0
    if len(args) >= 2:
        min_val = args[1]
    if len(args) >= 3:
        max_val = args[2]
    if kwargs.get("min_val") is not None:
        min_val = kwargs["min_val"]
    if kwargs.get("max_val") is not None:
        max_val = kwargs["max_val"]

    if not isinstance(x, torch.Tensor):
        raise TypeError("hardtanh_ expects a torch.Tensor as the first argument")

    # Fallback for everything the Triton kernel does not handle. Delegate to
    # PyTorch's reference in-place hardtanh rather than clamp_: in-place
    # clamp_ with Python-float bounds fails on integer tensors (the promoted
    # float result cannot be cast back), and hardtanh_ guarantees reference
    # semantics for all supported dtypes.
    if not x.is_cuda or not x.is_floating_point() or not x.is_contiguous():
        return torch.nn.functional.hardtanh_(x, min_val, max_val)

    n_elements = x.numel()
    if n_elements == 0:
        return x

    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch the Triton kernel; x is modified in place.
    hardtanh___kernel[grid](
        x, n_elements, float(min_val), float(max_val), BLOCK_SIZE=BLOCK_SIZE
    )
    return x