Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/multinomial.py: 0%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-07 22:33 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import libentry
8from flag_gems.utils.random_utils import philox_backend_seed_offset, uniform
# Module-level logger: a child of the "flag_gems" logger, named after this
# module's ``__name__`` with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit(do_not_specialize=["K", "N", "philox_seed", "philox_offset"])
def multinomial_with_replacement(
    cdf_ptr, out_ptr, K, N, philox_seed, philox_offset, NBLOCK: tl.constexpr = 128
):
    # Kernel: for every distribution, draw N category indices by binary
    # searching a row of cumulative probabilities with random numbers.
    #
    # Arguments (as used below):
    #   cdf_ptr       - base pointer of the cumulative probabilities; the row of
    #                   distribution d is read starting at cdf_ptr + d * K
    #                   (assumes rows are contiguous with K entries each --
    #                   TODO confirm against the caller).
    #   out_ptr       - base pointer of the output; sample n of distribution d
    #                   is written to out_ptr + d * N + n.
    #   K             - number of categories per distribution.
    #   N             - number of samples to draw per distribution.
    #   philox_seed,
    #   philox_offset - forwarded unchanged to `uniform`.
    #   NBLOCK        - compile-time number of samples handled by one program.
    #
    # The computation is arranged in a 2d grid of blocks, each producing
    # a batch of samples for a particular distribution.
    #            <------------------- grid.x --------------------->
    #           |   dist0.batch0 | dist0.batch1 | dist0.batch2 ...
    #   grid.y  |   dist1.batch0 | dist1.batch1 | dist1.batch2 ...
    #           |   dist2.batch0 | dist2.batch1 | dist2.batch2 ...

    # Offset of this distribution's first output element.
    y_off = tl.program_id(1) * N
    # Sample indices (within the distribution) covered by this program.
    n = tl.program_id(0) * NBLOCK + tl.arange(0, NBLOCK)
    # `uniform` returns four values; only the first is used here. It is
    # indexed by the flat output position so each sample gets its own draw.
    # NOTE(review): presumably values lie in [0, 1) -- confirm in random_utils.
    rv, _, _, _ = uniform(philox_seed, philox_offset, y_off + n)

    # Do a binary search for each random number on the cumulative probabilities.
    # Each random number always selects the leftmost index of the data greater
    # than or equal to itself. However, this is likely to give a wrong result
    # in case the first probability is zero, which is not expected to be
    # selected. This error happens when the tossed random number is also zero.
    # To avoid this mistake, we simply perturb the random variable with a small
    # number, then clamp it so it never exceeds 0.9999.
    rv += 0.0001
    rv = tl.where(rv > 0.9999, 0.9999, rv)

    # Advance to the CDF row of the distribution selected by grid.y.
    cdf_ptr += tl.program_id(1) * K
    # Per-lane search interval [start, end], initially [0, K - 1].
    start = tl.zeros((NBLOCK,), dtype=tl.int32)
    end = tl.zeros((NBLOCK,), dtype=tl.int32) + K - 1
    # Iteration count: log2(K) converted to int32, plus one.
    steps = tl.extra.xpu.libdevice.log2(K.to(tl.float32)).to(tl.int32) + 1
    for _ in range(steps):
        mid = start + (end - start) // 2
        # Lanes beyond N are masked off; their loaded value is never stored.
        x = tl.load(cdf_ptr + mid, mask=n < N)
        # cdf[mid] < rv  -> answer lies right of mid; otherwise keep mid as
        # the upper bound.
        start = tl.where(x < rv, mid + 1, start)
        end = tl.where(x < rv, end, mid)

    # Returns the last index in case of an overflow
    start = tl.where(start >= K, K - 1, start)

    # Write the selected category index for every in-range sample.
    tl.store(out_ptr + y_off + n, start, mask=n < N)
def multinomial(prob, n_samples, with_replacement=False, *, gen=None):
    """Sample category indices from the distribution(s) described by ``prob``.

    Args:
        prob: 1-D or 2-D floating point tensor (float16, bfloat16, float32 or
            float64) whose last dimension holds at most 2^24 category weights.
        n_samples: number of indices to draw per distribution.
        with_replacement: when false, ``n_samples`` may not exceed the number
            of categories.
        gen: forwarded as ``generator`` to ``philox_backend_seed_offset``; it
            is only consulted on the with-replacement kernel path.

    Returns:
        An int64 tensor of indices. On the top-k / argmax path the last
        dimension of ``prob`` is replaced by ``n_samples``; on the kernel path
        the shape is ``(n_samples,)`` for 1-D input and
        ``(n_dist, n_samples)`` for 2-D input.

    Raises:
        AssertionError: if any of the input checks below fails.
    """
    logger.debug("GEMS MULTINOMIAL")
    assert prob.dtype in (torch.float16, torch.float32, torch.bfloat16, torch.float64)
    assert 0 < prob.dim() <= 2, "prob_dist must be 1 or 2 dim"
    n_categories = prob.size(-1)
    assert n_categories <= (1 << 24), "number of categories cannot exceed 2^24"
    assert (
        with_replacement or n_samples <= n_categories
    ), "cannot sample n_samples > prob.size(-1) samples without replacement."

    # Without replacement (or when a single sample is requested) the draw is
    # approximated by perturbing the probabilities with exponential noise and
    # keeping the largest entries: s = argmax(p / q) with q ~ Exp(1).
    if not with_replacement or n_samples == 1:
        noise = torch.empty_like(prob).exponential_(1.0)
        # The quotient is written into the noise buffer to avoid a new tensor.
        scores = torch.div(prob, noise, out=noise)
        if n_samples == 1:
            picked = torch.argmax(scores, dim=-1, keepdim=True)
        else:
            _, picked = torch.topk(scores, n_samples, dim=-1)
        return picked.to(torch.int64)

    from _kunlunxin.ops import normed_cumsum

    # Build the normalised cumulative distribution for each row.
    # NOTE(review): 2-D inputs wider than 8192 columns bypass normed_cumsum;
    # the reason for this threshold is not visible in this file.
    if prob.dim() == 2 and prob.size(1) > 8192:
        cdf = torch.cumsum(prob, dim=-1) / prob.sum(dim=-1, keepdim=True)
    else:
        cdf = normed_cumsum(prob, dim=-1)

    if cdf.dim() == 1:
        n_dist = 1
        out_shape = (n_samples,)
    else:
        n_dist = cdf.size(0)
        out_shape = (n_dist, n_samples)
    samples = torch.empty(out_shape, device=prob.device, dtype=torch.int64)

    # One random number is consumed per output element.
    philox_seed, philox_offset = philox_backend_seed_offset(
        n_dist * n_samples, generator=gen
    )

    # 2-D launch grid: grid.x walks batches of NBLOCK samples, grid.y walks
    # the distributions.
    def grid(meta):
        return (triton.cdiv(n_samples, meta["NBLOCK"]), n_dist)

    multinomial_with_replacement[grid](
        cdf, samples, n_categories, n_samples, philox_seed, philox_offset
    )
    return samples