Coverage for src/flag_gems/experimental_ops/sigmoid.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-26 15:32 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def sigmoid_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Apply the logistic sigmoid element-wise over a flat buffer.

    Each program instance processes one ``BLOCK_SIZE``-wide slice of the
    buffer; lanes whose offset is ``>= n_elements`` are masked out of both
    the load and the store.

    Args:
        x_ptr: Pointer to the first input element.
        out_ptr: Pointer to the first output element.
        n_elements: Total number of elements to process.
        BLOCK_SIZE: Compile-time number of elements handled per program.
    """
    # Widen the program id before scaling: with 32-bit arithmetic the product
    # pid * BLOCK_SIZE wraps once the buffer holds 2**31 or more elements,
    # which would make the kernel read and write the wrong addresses.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Masked-off lanes read 0.0; their results are never stored.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Evaluate in float32 regardless of the input element type, then cast
    # the result back to the input element type.
    x_f32 = x.to(tl.float32)
    y = 1.0 / (1.0 + tl.exp(-x_f32))
    y = y.to(x.dtype)
    tl.store(out_ptr + offsets, y, mask=mask)

18 

19 

def _sigmoid_common(x: torch.Tensor, out: torch.Tensor = None):
    """Validate arguments and launch ``sigmoid_kernel`` for ``x``.

    Args:
        x: Input tensor. Must be a CUDA tensor of dtype float16, bfloat16
            or float32.
        out: Optional destination tensor. When given it must be a CUDA
            tensor on the same device as ``x`` with the same shape and
            dtype. When ``None`` a new tensor is allocated with
            ``torch.empty_like(x)``.

    Returns:
        The tensor holding the result: ``out`` if it was supplied, otherwise
        the newly allocated tensor.

    Raises:
        TypeError: If ``x`` or ``out`` is not a ``torch.Tensor``.
        ValueError: If ``x`` or ``out`` is not on a CUDA device, or ``out``
            differs from ``x`` in device, shape or dtype.
        NotImplementedError: If the dtype of ``x`` is not supported.
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError("sigmoid: expected a torch.Tensor as input")
    if not x.is_cuda:
        raise ValueError("sigmoid: input tensor must be on CUDA device")
    if x.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        raise NotImplementedError(
            f"sigmoid: dtype {x.dtype} is not supported (supported: float16, bfloat16, float32)"
        )

    if out is None:
        out = torch.empty_like(x)
    else:
        if not isinstance(out, torch.Tensor):
            raise TypeError("sigmoid.out: 'out' must be a torch.Tensor")
        if not out.is_cuda:
            raise ValueError("sigmoid.out: 'out' tensor must be on CUDA device")
        # The kernel receives raw pointers for both tensors, so they have to
        # live on the same device as the launch.
        if out.device != x.device:
            raise ValueError(
                f"sigmoid.out: 'out' device {out.device} must match input device {x.device}"
            )
        if out.shape != x.shape:
            raise ValueError(
                f"sigmoid.out: 'out' shape {out.shape} does not match input shape {x.shape}"
            )
        if out.dtype != x.dtype:
            raise ValueError(
                f"sigmoid.out: 'out' dtype {out.dtype} must match input dtype {x.dtype}"
            )

    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return out

    # The kernel indexes both buffers linearly, so it needs contiguous
    # storage. A non-contiguous ``out`` is computed into a contiguous
    # scratch tensor and copied back afterwards.
    x_contig = x.contiguous()
    if out.is_contiguous():
        out_contig = out
    else:
        out_contig = torch.empty_like(out, memory_format=torch.contiguous_format)

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Launch on the device that owns the data rather than whichever device
    # happens to be current; otherwise a tensor on a non-default GPU would be
    # dereferenced from the wrong device.
    with torch.cuda.device(x.device):
        sigmoid_kernel[grid](x_contig, out_contig, n_elements, BLOCK_SIZE=1024)

    if out_contig is not out:
        out.copy_(out_contig)
    return out

63 

64 

def sigmoid(self: torch.Tensor):
    """Return the element-wise sigmoid of ``self``.

    Delegates to ``_sigmoid_common`` without a destination tensor, so the
    result is written to a tensor allocated by that helper.
    """
    # ``out`` defaults to None in the helper, which makes it allocate the result.
    return _sigmoid_common(self)

67 

68 

def sigmoid_out(self: torch.Tensor, out: torch.Tensor):
    """Compute the element-wise sigmoid of ``self`` into ``out``.

    Delegates to ``_sigmoid_common``, which validates ``out`` against
    ``self`` and returns ``out`` after filling it.
    """
    result = _sigmoid_common(self, out)
    return result