Coverage for src/flag_gems/runtime/backend/_ascend/ops/multinomial.py: 0%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-25 02:48 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import libentry 

8from flag_gems.utils.random_utils import philox_backend_seed_offset, uniform 

9 

# Module-level logger. Its name is the fixed prefix
# "flag_gems.runtime._ascend.ops." followed by the last component of this
# module's dotted name.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

11 

12 

@libentry()
@triton.jit(do_not_specialize=["K", "N", "philox_seed", "philox_offset"])
def multinomial_with_replacement(
    cdf_ptr, out_ptr, K, N, philox_seed, philox_offset, NBLOCK: tl.constexpr = 128
):
    # Kernel: for every requested sample, locate by binary search the category
    # whose cumulative probability first reaches a random threshold.
    #
    # Launch layout is a 2d grid. Axis 1 selects the distribution (one row of
    # K cumulative probabilities); axis 0 selects a batch of NBLOCK samples
    # drawn from that distribution:
    #
    #          <------------------- grid.x --------------------->
    #          | dist0.batch0 | dist0.batch1 | dist0.batch2 ...
    #   grid.y | dist1.batch0 | dist1.batch1 | dist1.batch2 ...
    #          | dist2.batch0 | dist2.batch1 | dist2.batch2 ...
    dist_id = tl.program_id(1)
    row_base = dist_id * N
    sample_idx = tl.program_id(0) * NBLOCK + tl.arange(0, NBLOCK)
    # Lanes past the N-th sample of this row neither load nor store.
    in_bounds = sample_idx < N

    # Random thresholds, indexed by the flat position of each output element.
    # Only the first of the four values returned by `uniform` is used.
    u, _, _, _ = uniform(philox_seed, philox_offset, row_base + sample_idx)

    # The search below picks the leftmost entry that is >= the threshold. A
    # threshold of exactly zero would therefore pick index 0 even when the
    # first probability is zero, which must never be selected. Nudging the
    # threshold upwards by a small amount (and clamping it back below one)
    # avoids that case.
    u = u + 0.0001
    u = tl.where(u > 0.9999, 0.9999, u)

    # Binary search over this distribution's row of cumulative probabilities.
    row_ptr = cdf_ptr + dist_id * K
    lo = tl.zeros((NBLOCK,), dtype=tl.int32)
    hi = tl.zeros((NBLOCK,), dtype=tl.int32) + K - 1
    n_iters = tl.math.log2(K.to(tl.float32)).to(tl.int32) + 1
    for _ in range(n_iters):
        mid = lo + (hi - lo) // 2
        cdf_mid = tl.load(row_ptr + mid, mask=in_bounds)
        go_right = cdf_mid < u
        lo = tl.where(go_right, mid + 1, lo)
        hi = tl.where(go_right, hi, mid)

    # If the search ran off the end, fall back to the last index.
    lo = tl.where(lo >= K, K - 1, lo)

    tl.store(out_ptr + row_base + sample_idx, lo, mask=in_bounds)

51 

52 

def multinomial(prob, n_samples, with_replacement=False, *, gen=None):
    """Sample category indices from the distribution(s) given by ``prob``.

    Args:
        prob: 1-D or 2-D tensor of dtype float16, bfloat16, float32 or
            float64. The last dimension indexes the categories; for a 2-D
            input each row is treated as a separate distribution.
        n_samples: Number of indices to draw per distribution.
        with_replacement: Whether the same category may be drawn more than
            once. When false, ``n_samples`` must not exceed the number of
            categories.
        gen: Accepted for interface compatibility; not used by this
            implementation.

    Returns:
        An int64 tensor of sampled indices, of shape ``(n_samples,)`` for a
        1-D ``prob`` and ``(prob.size(0), n_samples)`` for a 2-D ``prob``.

    Raises:
        AssertionError: If the dtype, rank, category count (more than 2^24)
            or sample count violates the constraints above.
    """
    logger.debug("GEMS_ASCEND MULTINOMIAL")
    assert prob.dtype in (torch.float16, torch.float32, torch.bfloat16, torch.float64)
    assert 0 < prob.dim() <= 2, "prob_dist must be 1 or 2 dim"
    n_categories = prob.size(-1)
    assert n_categories <= (1 << 24), "number of categories cannot exceed 2^24"
    assert (
        with_replacement or n_samples <= n_categories
    ), "cannot sample n_samples > prob.size(-1) samples without replacement."

    # Sampling without replacement. A single draw is the same with or without
    # replacement, so n_samples == 1 takes this path too.
    if (not with_replacement) or n_samples == 1:
        # Sampling is approximated by selecting the top-k indices of the
        # probabilities after an exponential perturbation:
        #     s = p / q  where q ~ Exp(1)
        # The quotient is written into q's storage to avoid another buffer.
        q = torch.empty_like(prob).exponential_(1.0)
        s = torch.div(prob, q, out=q)
        if n_samples == 1:
            return torch.argmax(s, dim=-1, keepdim=True).to(torch.int64)
        _, indices = torch.topk(s, n_samples, dim=-1)
        return indices.to(torch.int64)

    # Sampling with replacement: binary-search random thresholds against the
    # normalized cumulative distribution of every row.
    # NOTE(review): imported at call time rather than at module scope,
    # presumably to avoid a circular import with the ops package - confirm.
    from flag_gems.runtime.backend._ascend import ops as asd_ops

    cum_prob = asd_ops.normed_cumsum(prob, dim=-1)

    if cum_prob.dim() == 1:
        n_dist = 1
        out = torch.empty((n_samples,), device=prob.device, dtype=torch.int64)
    else:
        n_dist = cum_prob.size(0)
        out = torch.empty((n_dist, n_samples), device=prob.device, dtype=torch.int64)
    # The CTA level parallelism is framed in a 2d grid of blocks with grid.y
    # indexing into distributions and grid.x output sample batches.
    # The Philox offset is advanced by the total number of output elements.
    increment = n_dist * n_samples
    philox_seed, philox_offset = philox_backend_seed_offset(increment)
    grid = lambda META: (triton.cdiv(n_samples, META["NBLOCK"]), n_dist)
    multinomial_with_replacement[grid](
        cum_prob, out, n_categories, n_samples, philox_seed, philox_offset
    )
    return out