Coverage for src/flag_gems/runtime/backend/_metax/ops/sigmoid.py: 0%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-23 02:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

8 

# Module-level logger, namespaced under the flag_gems package.
logger = logging.getLogger(f"flag_gems.{__name__}")
# Base-2 exponential from the device shim; used by the forward kernel.
exp2 = tl_extra_shim.exp2

11 

12 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def sigmoid_forward(x):
    """Elementwise sigmoid, 1 / (1 + exp(-x)), computed in float32.

    exp(-x) is evaluated as exp2(-x * log2(e)) because the shim exposes
    a base-2 exponential.
    """
    # math.log2(math.e) written as a numeric literal: triton 3.0.0
    # disallows calling a non-jitted function inside a jitted one, even
    # on the rhs of an assignment to a constexpr.
    LOG2_E: tl.constexpr = 1.4426950408889634
    x_f32 = x.to(tl.float32)
    return 1 / (1 + exp2(-x_f32 * LOG2_E))

21 

22 

@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def sigmoid_backward(y, dy):
    """Elementwise sigmoid gradient in float32: dy * (1 - y) * y.

    `y` is the forward output (the sigmoid activation) and `dy` the
    upstream gradient.
    """
    act = y.to(tl.float32)
    grad = dy.to(tl.float32)
    return grad * (1.0 - act) * act

29 

30 

@triton.jit
def sigmoid_backward_custom_kernel(
    x_ptr: tl.tensor,  # pointer to the saved sigmoid output
    y_ptr: tl.tensor,  # pointer to the upstream gradient (stride 0, single value)
    output_ptr: tl.tensor,  # pointer to the result buffer
    n_elements: int,  # total number of elements to process
    BLOCK_SIZE: tl.constexpr,  # elements per program; constexpr so it can shape arange
):
    """Compute y * (1 - x) * x for one BLOCK_SIZE slice of the input.

    Specialized for an upstream gradient whose every stride is 0: the
    kernel reads it once as a scalar instead of per-element.
    """
    # Each program instance handles one contiguous BLOCK_SIZE chunk of
    # the flattened input (1D launch grid, so axis 0).
    block_id = tl.program_id(axis=0)
    offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Mask out-of-range lanes when n_elements is not a multiple of
    # BLOCK_SIZE.
    in_bounds = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=in_bounds)

    # Upstream gradient has stride 0 everywhere, so no offset or mask is
    # needed — every lane reads the same value.
    y = tl.load(y_ptr)

    # Write the gradient contribution back to DRAM.
    tl.store(output_ptr + offsets, y * (1 - x) * x, mask=in_bounds)

58 

59 

def sigmoid_backward_custom(x: torch.Tensor, y: torch.Tensor):
    """Launch the custom sigmoid-backward kernel and return its output.

    `x` is the saved sigmoid activation; `y` is the upstream gradient,
    expected to have stride 0 in every dimension (the kernel loads a
    single value from it). All tensors must live on a CUDA device.
    """
    # Output must be preallocated before the launch.
    result = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and result.is_cuda

    numel = result.numel()

    # One program per BLOCK_SIZE-sized slice of the flattened tensor.
    def grid(meta):
        return (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    sigmoid_backward_custom_kernel[grid](x, y, result, numel, BLOCK_SIZE=1024)
    return result

69 

70 

class Sigmoid(torch.autograd.Function):
    """Autograd wrapper dispatching to the Triton sigmoid kernels."""

    @staticmethod
    def forward(ctx, A):
        logger.debug("METAX GEMS SIGMOID FORWARD")
        if A.requires_grad:
            # Compute and save the activation in float32 so backward runs
            # at full precision regardless of the input dtype; the result
            # handed back to the caller is cast to the input dtype.
            out = sigmoid_forward(A.to(torch.float32))
            ctx.save_for_backward(out)
            return out.to(A.dtype)
        else:
            out = sigmoid_forward(A)
            return out

    @staticmethod
    def backward(ctx, out_grad):
        logger.debug("METAX GEMS SIGMOID BACKWARD")
        (out,) = ctx.saved_tensors

        # A fully-broadcast upstream gradient has stride 0 in every
        # dimension; the custom kernel exploits that by loading a
        # single value instead of one per element.
        is_grad_stride_0 = all(s == 0 for s in out_grad.stride())

        # Temporary plan: take the fast path only when the element count
        # divides evenly into the kernel's 1024-element blocks.
        if is_grad_stride_0 and out_grad.numel() % 1024 == 0:
            in_grad = sigmoid_backward_custom(out, out_grad)
            return in_grad
        in_grad = sigmoid_backward(out, out_grad)
        return in_grad

100 

101 

def sigmoid(A):
    """Apply the sigmoid activation via the custom autograd function."""
    return Sigmoid.apply(A)