Coverage for src/flag_gems/runtime/backend/_nvidia/ops/add.py: 30%
20 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-22 16:54 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def add_kernel(
    x_ptr,  # Pointer to the first input vector.
    y_ptr,  # Pointer to the second input vector.
    output_ptr,  # Pointer to the output vector.
    n_elements,  # Number of elements in the vectors.
    BLOCK_SIZE: tl.constexpr,  # Elements handled by one program instance;
    # declared `constexpr` so it can be used as the length of `tl.arange`.
):
    """Write the element-wise sum of two vectors into ``output_ptr``.

    Each program instance handles one run of ``BLOCK_SIZE`` consecutive
    elements, selected by its program id along axis 0. Loads and the store
    are masked so that positions at or beyond ``n_elements`` are not touched.
    """
    # First element index owned by this program instance. With 256 elements
    # and BLOCK_SIZE=64, instances 0..3 start at 0, 64, 128 and 192.
    first_index = tl.program_id(axis=0) * BLOCK_SIZE
    # Element offsets (not pointers) covered by this instance.
    element_offsets = first_index + tl.arange(0, BLOCK_SIZE)
    # The last block may extend past the end when n_elements is not a
    # multiple of BLOCK_SIZE; the mask disables those lanes.
    in_bounds = element_offsets < n_elements
    lhs = tl.load(x_ptr + element_offsets, mask=in_bounds)
    rhs = tl.load(y_ptr + element_offsets, mask=in_bounds)
    # Store the sum back through the output pointer under the same mask.
    tl.store(output_ptr + element_offsets, lhs + rhs, mask=in_bounds)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise sum of ``x`` and ``y`` computed by ``add_kernel``.

    Args:
        x: First operand. The output is allocated with ``torch.empty_like``
            on the contiguous form of this tensor.
        y: Second operand. Must have the same shape as ``x``.

    Returns:
        A newly allocated tensor holding ``x + y``. No synchronization is
        performed here, so the kernel may still be running when this returns.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same shape.
    """
    print("\n.......test for mutibackend specific add........\n")

    # The kernel bounds-checks only against the output's element count, so a
    # differently shaped `y` would be read out of bounds or only partially.
    if x.shape != y.shape:
        raise ValueError(
            f"add expects operands of identical shape, got {tuple(x.shape)} "
            f"and {tuple(y.shape)}"
        )

    # The kernel addresses all three buffers with the same flat offsets, which
    # is only valid when each buffer is laid out contiguously.
    # `.contiguous()` returns the tensor itself when it already is.
    x = x.contiguous()
    y = y.contiguous()

    # Preallocate the output.
    output = torch.empty_like(x)
    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to compute; skip the launch entirely.
        return output

    # 1D launch grid: one program instance per BLOCK_SIZE-sized chunk. The
    # grid is given as a callable so it can read the BLOCK_SIZE meta-parameter.
    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # NOTE:
    #  - Each torch.Tensor argument is implicitly converted into a pointer to
    #    its first element.
    #  - Indexing a `triton.jit`'ed function with a launch grid yields a
    #    callable GPU kernel.
    #  - Meta-parameters must be passed as keyword arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output