Coverage for src/flag_gems/experimental_ops/unsqueeze.py: 0%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-07 22:33 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def unsqueeze_kernel(
    src_ptr,  # *Pointer* to input tensor data.
    dst_ptr,  # *Pointer* to output tensor data.
    out_numel,  # Total number of elements in output (same as input).
    in_strides_ptr,  # *Pointer* to input strides (length = RANK-1).
    out_shape_ptr,  # *Pointer* to output shape (length = RANK).
    BLOCK_SIZE: tl.constexpr,  # Number of elements processed by each program.
    RANK: tl.constexpr,  # Rank of the output tensor.
    UNSQ_DIM: tl.constexpr,  # The dimension at which to unsqueeze (compile-time constant).
):
    # Copy `src` into `dst`, where `dst` has the same elements as `src` but a
    # size-1 axis inserted at UNSQ_DIM.  Each program instance handles
    # BLOCK_SIZE consecutive linear indices of the (contiguous) output.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Promote to 64-bit so the index decomposition below cannot overflow
    # for large tensors.
    offsets = offsets.to(tl.int64)
    mask = offsets < out_numel

    tmp = offsets
    src_offset = tl.zeros([BLOCK_SIZE], dtype=tl.int64)
    # Decompose linear index into multi-dimensional coordinates and map to input offset.
    # RANK and UNSQ_DIM are tl.constexpr, so the compiler can specialize the
    # `k != UNSQ_DIM` test per loop iteration.
    for k in range(RANK - 1, -1, -1):
        s_k = tl.load(out_shape_ptr + k)  # extent of output axis k
        c_k = tmp % s_k                   # coordinate along output axis k
        tmp = tmp // s_k
        if k != UNSQ_DIM:
            # The unsqueezed axis has extent 1 and contributes nothing to the
            # source offset; output axis k maps to input axis k when it is
            # before UNSQ_DIM, and to k-1 when it is after.
            in_k = k if k < UNSQ_DIM else k - 1
            stride_in_k = tl.load(in_strides_ptr + in_k)
            src_offset += c_k * stride_in_k

    # Gather from the (possibly strided) input and store to the contiguous
    # output; `mask` disables the tail lanes of the final block.
    vals = tl.load(src_ptr + src_offset, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)

37 

38 

def unsqueeze(*args, **kwargs):
    """Return a new tensor with a dimension of size one inserted at ``dim``.

    Mirrors ``torch.unsqueeze(input, dim)`` for CUDA tensors, but materializes
    the result by copying through a Triton kernel instead of returning a view.

    The tensor may be passed positionally or as ``self``/``input`` keyword;
    ``dim`` may be positional or keyword.  ``dim`` accepts the usual negative
    indexing and must lie in ``[-(ndims+1), ndims]``.
    """
    # Expect signature: unsqueeze(x, dim) — each argument may be positional
    # or keyword.
    if len(args) >= 2:
        x, dim = args[0], args[1]
    elif len(args) == 1:
        # Bug fix: `unsqueeze(x, dim=1)` previously ignored the positional
        # tensor and looked everything up in kwargs, failing the type check.
        x = args[0]
        dim = kwargs.get("dim", None)
    else:
        x = kwargs.get("self", kwargs.get("input", None))
        dim = kwargs.get("dim", None)
    assert isinstance(
        x, torch.Tensor
    ), "unsqueeze expects a torch.Tensor as the first argument."
    assert isinstance(dim, int), "unsqueeze expects an integer 'dim' argument."
    assert x.is_cuda, "Input tensor must be on CUDA device."

    ndims = x.dim()
    # Normalize negative dims; for unsqueeze the insertion point may be one
    # past the last existing axis, hence `ndims + 1` positions.
    if dim < 0:
        dim += ndims + 1
    assert 0 <= dim <= ndims, f"dim must be in range [0, {ndims}] after normalization."

    out_shape = list(x.shape)
    out_shape.insert(dim, 1)
    out = torch.empty(out_shape, dtype=x.dtype, device=x.device)

    numel = out.numel()
    if numel == 0:
        # Nothing to copy — skip the kernel launch entirely rather than
        # dispatching a zero-sized grid.
        return out

    # Strides and shape travel through small device tensors because the
    # kernel reads them with tl.load (RANK is only known at runtime).
    in_strides = torch.tensor(x.stride(), dtype=torch.int64, device=x.device)
    out_shape_t = torch.tensor(out_shape, dtype=torch.int64, device=x.device)

    RANK = len(out_shape)
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

    unsqueeze_kernel[grid](
        x,
        out,
        numel,
        in_strides,
        out_shape_t,
        BLOCK_SIZE=BLOCK_SIZE,
        RANK=RANK,
        UNSQ_DIM=dim,
    )
    return out