Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addr.py: 0%
44 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit(do_not_specialize=["beta", "alpha"])
def addr_kernel(
    input_ptr,
    vec1_ptr,
    vec2_ptr,
    output_ptr,
    beta,
    alpha,
    M,
    N,
    stride_input_m,
    stride_input_n,
    stride_vec1,
    stride_vec2,
    stride_output_m,
    stride_output_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # Each program instance owns one BLOCK_SIZE_M x BLOCK_SIZE_N tile of the
    # (M, N) output and computes: beta * input + alpha * outer(vec1, vec2).
    tile_row = tl.program_id(0)
    tile_col = tl.program_id(1)

    row_idx = tile_row * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    col_idx = tile_col * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_valid = row_idx < M
    col_valid = col_idx < N

    # Load the vector slices for this tile; accumulate in float32 to keep
    # precision for low-precision input dtypes.
    v1 = tl.load(vec1_ptr + row_idx * stride_vec1, mask=row_valid, other=0.0).to(
        tl.float32
    )
    v2 = tl.load(vec2_ptr + col_idx * stride_vec2, mask=col_valid, other=0.0).to(
        tl.float32
    )

    tile_mask = row_valid[:, None] & col_valid[None, :]
    in_tile = tl.load(
        input_ptr
        + row_idx[:, None] * stride_input_m
        + col_idx[None, :] * stride_input_n,
        mask=tile_mask,
        other=0.0,
    ).to(tl.float32)

    # Rank-1 update of the input tile.
    acc = beta * in_tile + alpha * (v1[:, None] * v2[None, :])

    tl.store(
        output_ptr
        + row_idx[:, None] * stride_output_m
        + col_idx[None, :] * stride_output_n,
        acc,
        mask=tile_mask,
    )
def addr(input, vec1, vec2, *, beta=1, alpha=1):
    """Compute ``beta * input + alpha * outer(vec1, vec2)`` (like ``torch.addr``).

    ``vec1`` (length M) and ``vec2`` (length N) define an M x N outer
    product; ``input`` is broadcast to shape (M, N).

    Args:
        input: Tensor broadcastable to shape (M, N).
        vec1: 1-D tensor of length M.
        vec2: 1-D tensor of length N.
        beta: Scale applied to ``input`` (keyword-only, default 1).
        alpha: Scale applied to the outer product (keyword-only, default 1).

    Returns:
        New tensor of shape (M, N) with ``input``'s device and dtype.

    Raises:
        ValueError: If either vector is not 1-D, or ``input`` cannot be
            broadcast to (M, N).
    """
    logger.debug("GEMS ADDR")
    if vec1.dim() != 1 or vec2.dim() != 1:
        raise ValueError("addr: expected 1-D vectors")

    M = vec1.shape[0]
    N = vec2.shape[0]
    output_shape = (M, N)

    try:
        # broadcast_to returns a (possibly stride-0) view; the kernel reads it
        # through explicit strides, so no materialized copy is needed.
        input_broadcasted = torch.broadcast_to(input, output_shape)
    except RuntimeError as e:
        # Chain the original broadcast error so its detail is not lost.
        raise ValueError(
            f"addr: input tensor of shape {input.shape} cannot be broadcast to output shape {output_shape}"
        ) from e

    out = torch.empty(output_shape, device=input.device, dtype=input.dtype)
    # Empty result: nothing to compute — skip the kernel launch entirely
    # rather than dispatching an empty grid.
    if M == 0 or N == 0:
        return out

    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 32
    grid = lambda META: (
        triton.cdiv(M, BLOCK_SIZE_M),
        triton.cdiv(N, BLOCK_SIZE_N),
    )
    with torch_device_fn.device(input.device):
        addr_kernel[grid](
            input_broadcasted,
            vec1,
            vec2,
            out,
            beta,
            alpha,
            M,
            N,
            input_broadcasted.stride(0),
            input_broadcasted.stride(1),
            vec1.stride(0),
            vec2.stride(0),
            out.stride(0),
            out.stride(1),
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
        )
    return out