Coverage for src/flag_gems/fused/silu_and_mul_with_clamp.py: 53%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

9logger = logging.getLogger(__name__) 

10 

11 

# Scalar body of the fused forward op:
#     out = silu(min(x, limit)) * clamp(y, -limit, limit)
# Every operand is cast to float32 before any arithmetic is performed.
# NOTE(review): the dtype of the returned value is decided by the "DEFAULT"
# promotion rule over arguments 0-2 - confirm the exact rule in pointwise_dynamic.
@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def silu_and_mul_with_clamp_kernel(x, y, limit):
    # Work in float32 regardless of the incoming dtypes.
    gate_in = x.to(tl.float32)
    up_in = y.to(tl.float32)
    bound = limit.to(tl.float32)

    # The gate branch is only capped from above.
    gate = tl.minimum(gate_in, bound)

    # The up branch is clamped on both sides: first raise to -bound, then cap at bound.
    up_floor = tl.maximum(up_in, -bound)
    up = tl.minimum(up_floor, bound)

    # SiLU(gate) written as gate / (1 + exp(-gate)).
    denom = 1.0 + tl.exp(-gate)
    activated = tl.fdiv(gate, denom)

    return activated * up

24 

25 

# Scalar body of the fused backward op. Given the forward inputs (x, y), the
# incoming gradient dgrad and the clamp bound limit, it returns (dx, dy) for
#     out = silu(min(x, limit)) * clamp(y, -limit, limit)
# All arithmetic is done in float32.
# NOTE(review): both outputs use the "DEFAULT" promotion rule over arguments
# 0-3 - confirm the exact rule in pointwise_dynamic.
@pointwise_dynamic(
    promotion_methods=[
        (0, 1, 2, 3, "DEFAULT"),
        (0, 1, 2, 3, "DEFAULT"),
    ],
    num_outputs=2,
)
@triton.jit
def silu_and_mul_with_clamp_grad_kernel(x, y, dgrad, limit):
    # Work in float32 regardless of the incoming dtypes.
    gate_in = x.to(tl.float32)
    up_in = y.to(tl.float32)
    upstream = dgrad.to(tl.float32)
    bound = limit.to(tl.float32)
    neg_bound = -bound

    # 1.0 where the corresponding clamp passes its input through, else 0.0.
    # Boundary values (input exactly equal to a bound) count as passing.
    gate_pass = (gate_in <= bound).to(tl.float32)
    up_pass = ((up_in >= neg_bound) & (up_in <= bound)).to(tl.float32)

    # Recompute the clamped forward operands.
    gate = tl.minimum(gate_in, bound)
    up = tl.minimum(tl.maximum(up_in, neg_bound), bound)

    # sigmoid(gate), SiLU(gate) and d SiLU / d gate.
    sig_gate = 1 / (1 + tl.exp(-gate))
    silu_gate = gate * sig_gate
    silu_slope = sig_gate * (1 + gate * (1 - sig_gate))

    # Chain rule; each mask zeroes the gradient where its clamp was active.
    grad_x = upstream * up * silu_slope * gate_pass
    grad_y = upstream * silu_gate * up_pass

    return grad_x, grad_y

54 

55 

class SiluAndMulWithClamp(torch.autograd.Function):
    """Autograd function pairing the clamped SiLU-and-multiply kernels.

    The forward pass delegates to ``silu_and_mul_with_clamp_kernel`` and the
    backward pass to ``silu_and_mul_with_clamp_grad_kernel``.
    """

    @staticmethod
    def forward(ctx, x, y, limit):
        """Run the forward kernel and stash what the backward pass needs.

        ``limit`` is wrapped in a tensor on ``x``'s device with ``x``'s dtype;
        that tensor is saved together with ``x`` and ``y`` and passed to the
        kernel as its third argument.
        """
        bound = torch.tensor(limit, device=x.device, dtype=x.dtype)
        ctx.save_for_backward(x, y, bound)
        logger.debug("GEMS SILU_AND_MUL_WITH_CLAMP_FORWARD")
        result = silu_and_mul_with_clamp_kernel(x, y, bound)
        return result

    @staticmethod
    def backward(ctx, dgrad):
        """Return the gradients for ``x`` and ``y``; ``limit`` gets ``None``."""
        saved_x, saved_y, bound = ctx.saved_tensors
        logger.debug("GEMS SILU_AND_MUL_WITH_CLAMP_BACKWARD")
        grads = silu_and_mul_with_clamp_grad_kernel(saved_x, saved_y, dgrad, bound)
        grad_x, grad_y = grads
        # One entry per forward input: x, y, limit.
        return grad_x, grad_y, None

70 

71 

def silu_and_mul_with_clamp(x, y, limit):
    """Apply the clamped SiLU-and-multiply op through ``SiluAndMulWithClamp``.

    The arguments are forwarded unchanged to ``SiluAndMulWithClamp.apply`` and
    its result is returned as-is.
    """
    fused = SiluAndMulWithClamp.apply(x, y, limit)
    return fused

74 

75 

def silu_and_mul_with_clamp_out(x, y, out, limit):
    """Run the forward kernel with ``out`` supplied as ``out0`` and return ``out``.

    This path calls the kernel directly rather than going through
    ``SiluAndMulWithClamp``. ``limit`` is wrapped in a tensor on ``x``'s device
    with ``x``'s dtype before the call.
    """
    logger.debug("GEMS SILU_AND_MUL_WITH_CLAMP_OUT")
    bound = torch.tensor(limit, device=x.device, dtype=x.dtype)
    # NOTE(review): out0 presumably makes the kernel write its result into
    # `out` - confirm against pointwise_dynamic.
    silu_and_mul_with_clamp_kernel(x, y, bound, out0=out)
    return out