Coverage for src/flag_gems/experimental_ops/sigmoid_.py: 0%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-20 02:31 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def sigmoid_(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """In-place, numerically stable logistic sigmoid over a flat buffer.

    Each program instance handles one BLOCK_SIZE-wide chunk of a 1-D launch
    grid. It loads ``x_ptr[offsets]``, computes the sigmoid in float32, and
    writes the result back to the same addresses in the input dtype.

    Args:
        x_ptr: pointer to the first element; read and overwritten in place.
            Offsets are linear, so the buffer is assumed to be contiguous.
        n_elements: number of valid elements; lanes at or beyond it are masked.
        BLOCK_SIZE: elements handled per program (compile-time constant).
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked-off lanes read 0.0 but are never stored back.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    x_fp32 = x.to(tl.float32)

    # Stable form using a single exp: z = exp(-|x|) is at most 1, so it
    # cannot overflow.
    #   x >= 0: sigmoid(x) = 1 / (1 + exp(-x)) = 1 / (1 + z)
    #   x <  0: sigmoid(x) = exp(x) / (1 + exp(x)) = z / (1 + z)
    # -|x| equals -x for x >= 0 and x for x < 0, so this matches evaluating
    # exp(-x) and exp(x) separately and selecting, at half the exp cost.
    z = tl.exp(-tl.abs(x_fp32))
    denom = 1.0 + z
    y_fp32 = tl.where(x_fp32 >= 0, 1.0 / denom, z / denom)

    tl.store(x_ptr + offsets, y_fp32.to(x.dtype), mask=mask)

26 

27 

# At this point `sigmoid_` names the @triton.jit kernel defined above. The
# Python entry point defined below reuses the name `sigmoid_` and shadows the
# kernel. The kernel is therefore bound to a separate module-level name first,
# and the wrapper launches it through `sigmoid___kernel`. This line must stay
# between the two definitions.
sigmoid___kernel = sigmoid_

30 

31 

def sigmoid_(*args, **kwargs):
    """In-place logistic sigmoid following the ``aten.sigmoid_`` schema.

    The tensor is taken from the first positional argument. Otherwise it is
    taken from the ``self`` keyword, or from ``input`` if ``self`` is absent.
    Non-empty, contiguous CUDA tensors of dtype float16, bfloat16 or float32
    are processed by the Triton kernel. Every other case is delegated to
    ``torch.ops.aten.sigmoid_``.

    Returns:
        The input tensor, modified in place (empty tensors are returned
        untouched).

    Raises:
        TypeError: if no ``torch.Tensor`` was supplied.
    """
    if args:
        x = args[0]
    else:
        x = kwargs.get("self", kwargs.get("input", None))

    if not isinstance(x, torch.Tensor):
        raise TypeError("sigmoid_ expects a torch.Tensor as the first argument")

    # Nothing to compute; also avoids launching a zero-sized grid.
    if x.numel() == 0:
        return x

    # NOTE(review): if this function is registered as the CUDA override of
    # aten::sigmoid_, this fallback could dispatch back into it for CUDA
    # inputs (non-contiguous / float64). Confirm how it is registered.
    if (
        not x.is_cuda
        or not x.is_contiguous()
        or x.dtype not in (torch.float16, torch.bfloat16, torch.float32)
    ):
        return torch.ops.aten.sigmoid_(x)

    n_elements = x.numel()
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    # Triton launches on the *current* CUDA device. Make that x's device so
    # tensors living on a non-default GPU are launched where their memory is.
    with torch.cuda.device(x.device):
        sigmoid___kernel[grid](x, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return x