Coverage for src/flag_gems/experimental_ops/silu.py: 0%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-19 02:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def silu_kernel(
    x_ptr,  # *Pointer* to input tensor
    y_ptr,  # *Pointer* to output tensor
    n_elements,  # Number of elements
    BLOCK_SIZE: tl.constexpr,
    COMPUTE_IN_FP32: tl.constexpr,
):
    """Elementwise SiLU kernel: y = x / (1 + exp(-x)).

    Each program instance handles one contiguous BLOCK_SIZE-wide slice of the
    flattened tensors. When COMPUTE_IN_FP32 is set, the arithmetic is done in
    float32 and the result is cast back to the input element type.
    """
    block_id = tl.program_id(axis=0)
    offs = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    vals = tl.load(x_ptr + offs, mask=in_bounds)
    if COMPUTE_IN_FP32:
        # Promote for the exp/division, then cast back to the original dtype.
        promoted = vals.to(tl.float32)
        result = (promoted / (1.0 + tl.exp(-promoted))).to(vals.dtype)
    else:
        result = vals / (1.0 + tl.exp(-vals))
    tl.store(y_ptr + offs, result, mask=in_bounds)

27 

28 

29def _silu_impl(x: torch.Tensor, out: torch.Tensor = None): 

30 if not x.is_cuda: 

31 raise ValueError("Input tensor must be on CUDA device.") 

32 if not torch.is_floating_point(x): 

33 raise TypeError("silu expects a floating point tensor.") 

34 if out is None: 

35 out = torch.empty_like(x) 

36 else: 

37 if not out.is_cuda: 

38 raise ValueError("Output tensor must be on CUDA device.") 

39 if out.shape != x.shape: 

40 raise ValueError( 

41 f"Output shape {out.shape} does not match input shape {x.shape}." 

42 ) 

43 if out.dtype != x.dtype: 

44 raise TypeError( 

45 f"Output dtype {out.dtype} does not match input dtype {x.dtype}." 

46 ) 

47 

48 x_contig = x.contiguous() 

49 out_contig = out if out.is_contiguous() else torch.empty_like(x_contig) 

50 

51 n_elements = x_contig.numel() 

52 if n_elements == 0: 

53 if out_contig is not out: 

54 out.copy_(out_contig) 

55 return out 

56 

57 compute_in_fp32 = x_contig.dtype in (torch.float16, torch.bfloat16) 

58 

59 BLOCK_SIZE = 1024 

60 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

61 silu_kernel[grid]( 

62 x_contig, 

63 out_contig, 

64 n_elements, 

65 BLOCK_SIZE=BLOCK_SIZE, 

66 COMPUTE_IN_FP32=compute_in_fp32, 

67 ) 

68 

69 if out_contig.data_ptr() != out.data_ptr(): 

70 out.copy_(out_contig) 

71 return out 

72 

73 

def silu(*args, **kwargs):
    """Dispatch wrapper matching an ``aten.silu(self)``-style signature.

    Accepts the input positionally or via the ``self``/``input`` keywords and
    returns a new tensor with SiLU applied.
    """
    if args:
        tensor = args[0]
    else:
        tensor = kwargs.get("self", kwargs.get("input", None))
    if tensor is None:
        raise TypeError("silu expects a tensor as the first argument.")
    return _silu_impl(tensor)

84 

85 

def silu_out(*args, **kwargs):
    """Dispatch wrapper matching an ``aten.silu.out(self, out)``-style signature.

    Accepts input/output positionally or via the ``self``/``input`` and
    ``out`` keywords; writes SiLU of the input into ``out`` and returns it.
    """
    tensor = args[0] if len(args) >= 1 else kwargs.get("self", kwargs.get("input", None))
    dest = args[1] if len(args) >= 2 else kwargs.get("out", None)

    if tensor is None or dest is None:
        raise TypeError("silu_out expects input and out tensors.")

    _silu_impl(tensor, dest)
    return dest