Coverage for src/flag_gems/runtime/backend/_nvidia/ops/add.py: 30%

20 statements

Report generated by coverage.py v7.6.9, created at 2026-03-27 02:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    """Element-wise vector add: output[i] = x[i] + y[i] for i < n_elements.

    Launched on a 1D grid; each program instance handles one contiguous
    chunk of BLOCK_SIZE elements.
    """
    # Identify which of the concurrently running programs this one is
    # (1D launch grid, so only axis 0 is meaningful).
    program_index = tl.program_id(axis=0)
    # This program's chunk starts BLOCK_SIZE elements per preceding program.
    # E.g. with n_elements == 256 and BLOCK_SIZE == 64, the four programs
    # cover [0:64], [64:128], [128:192], and [192:256] respectively.
    start = program_index * BLOCK_SIZE
    # Vector of element offsets this program will touch.
    offs = start + tl.arange(0, BLOCK_SIZE)
    # Guard mask: the last chunk may extend past n_elements when the input
    # length is not a multiple of BLOCK_SIZE.
    in_bounds = offs < n_elements
    # Masked loads so out-of-range lanes never read DRAM.
    lhs = tl.load(x_ptr + offs, mask=in_bounds)
    rhs = tl.load(y_ptr + offs, mask=in_bounds)
    # Masked store of the element-wise sum back to DRAM.
    tl.store(output_ptr + offs, lhs + rhs, mask=in_bounds)

33 

34 

def add(x: torch.Tensor, y: torch.Tensor):
    """Return the element-wise sum of `x` and `y` via the Triton `add_kernel`.

    Parameters
    ----------
    x, y : torch.Tensor
        Input tensors; assumed to be on the same CUDA device with matching
        shape/dtype (not validated here — TODO confirm callers guarantee this).

    Returns
    -------
    torch.Tensor
        A new tensor with the same shape/dtype/device as `x` holding x + y.
        The kernel launch is asynchronous: the result is only guaranteed
        complete after the device stream synchronizes.
    """
    # Preallocate the output buffer on the same device/dtype as x.
    output = torch.empty_like(x)
    n_elements = output.numel()
    # SPMD launch grid: number of kernel instances running in parallel.
    # Analogous to a CUDA launch grid; may be Tuple[int] or a
    # Callable(metaparameters) -> Tuple[int]. Here a 1D grid sized to the
    # number of BLOCK_SIZE chunks needed to cover n_elements.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # NOTE:
    #  - Each torch.Tensor is implicitly converted into a pointer to its
    #    first element.
    #  - `triton.jit`'ed functions are indexed with a launch grid to obtain
    #    a callable GPU kernel.
    #  - Meta-parameters (BLOCK_SIZE) must be passed as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # Return a handle to the output; without a device synchronize the kernel
    # may still be running asynchronously at this point.
    return output