Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/normal.py: 0%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-03-23 02:03 +0800
1import logging
3import torch
4import triton
6from flag_gems.runtime import torch_device_fn
7from flag_gems.utils.random_utils import philox_backend_seed_offset
8from flag_gems.utils.shape_utils import broadcast_shapes, volume
10from ..utils.pointwise_dynamic import pointwise_dynamic
11from .randn import randn_kernel
# Module logger, created as a child of the "flag_gems" logger and named
# after this module (with any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Elementwise affine map `val * std + mean`, used by normal_tensor_tensor
# when both `std` and `mean` are tensors (is_tensor marks all three operands
# as tensors). promotion_methods lists all three operands with "DEFAULT";
# presumably the output dtype is promoted from all of them — confirm against
# pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, True, True], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_tensor_tensor(val, std, mean):
    # val: sampled values; std: scale; mean: shift.
    return val * std + mean
# Elementwise affine map `val * std + mean`, used by normal_tensor_float:
# `mean` is a tensor and `std` is passed as a non-tensor value
# (is_tensor=[True, False, True]). The "tensor_float" suffix names the
# mean type first, then the std type, matching normal_tensor_float(mean, std).
@pointwise_dynamic(
    is_tensor=[True, False, True], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_tensor_float(val, std, mean):
    # val: sampled values; std: scalar scale; mean: tensor shift.
    return val * std + mean
# Elementwise affine map `val * std + mean`, used by normal_float_tensor:
# `std` is a tensor and `mean` is passed as a non-tensor value
# (is_tensor=[True, True, False]). The "float_tensor" suffix names the mean
# type first, then the std type, matching normal_float_tensor(mean, std).
@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_float_tensor(val, std, mean):
    # val: sampled values; std: tensor scale; mean: scalar shift.
    return val * std + mean
# Elementwise affine map `val * std + mean` where both `std` and `mean` are
# non-tensor values (is_tensor=[True, False, False]). No caller is visible in
# this view — presumably a float/float normal variant elsewhere in the file
# uses it (confirm).
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_float_float(val, std, mean):
    # val: sampled values; std: scalar scale; mean: scalar shift.
    return val * std + mean
# Number of elements each kernel lane covers: the launch grid divides N by
# BLOCK_SIZE * UNROLL, and the Philox offset advances by ceil(N / UNROLL).
# Presumably must match the unroll factor inside randn_kernel (confirm).
UNROLL = 4


def normal_distribution(shape, device, *, generator=None):
    """Return a new float32 tensor of ``shape`` on ``device`` filled by ``randn_kernel``.

    ``randn_kernel`` presumably draws standard-normal samples; callers scale
    and shift the result to obtain N(mean, std).

    Args:
        shape: Output shape.
        device: Device on which the tensor is allocated and the kernel runs.
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset`` to choose the Philox seed/offset.

    Returns:
        The filled ``torch.float32`` tensor. An empty tensor is returned
        unchanged when ``shape`` has zero elements.
    """
    out = torch.empty(shape, device=device, dtype=torch.float32)
    N = volume(shape)
    if N == 0:
        # triton.next_power_of_2(0) returns 0, which would make BLOCK_SIZE 0
        # and the grid computation divide by zero. There is nothing to sample.
        return out

    # 12 presumably matches the number of Kunlunxin clusters (TODO confirm).
    # BLOCK_SIZE spreads the N / UNROLL lanes across that many programs,
    # capped at 1024.
    cluster_num = 12
    BLOCK_SIZE = min(
        triton.next_power_of_2(triton.cdiv(N, cluster_num * UNROLL)), 1024
    )
    grid = triton.cdiv(N, BLOCK_SIZE * UNROLL)

    # Reserve one Philox offset per UNROLL elements so that later calls do not
    # reuse the same random stream.
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(device):
        randn_kernel[(grid,)](out, N, philox_seed, philox_offset, BLOCK_SIZE)
    return out
def normal_tensor_tensor(mean, std, *, generator=None):
    """Draw normal samples with tensor ``mean`` and tensor ``std``.

    The output shape is the broadcast of ``mean.shape`` and ``std.shape``.
    Samples are drawn on ``mean.device``.

    Args:
        mean: Tensor of means.
        std: Tensor of standard deviations.
        generator: Optional generator that controls the Philox seed/offset
            used by ``normal_distribution``.

    Returns:
        ``sample * std + mean``, computed elementwise.
    """
    logger.debug("GEMS NORMAL_TENSOR_TENSOR")
    shape = broadcast_shapes([mean.shape, std.shape])
    device = mean.device
    # Forward the caller's generator; it must not be silently ignored.
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_tensor_tensor(out, std, mean)
def normal_tensor_float(mean, std, *, generator=None):
    """Draw normal samples with tensor ``mean`` and scalar ``std``.

    The output takes ``mean.shape`` and is drawn on ``mean.device``.

    Args:
        mean: Tensor of means.
        std: Standard deviation, passed to the transform as a non-tensor
            value (presumably a Python float).
        generator: Optional generator that controls the Philox seed/offset
            used by ``normal_distribution``.

    Returns:
        ``sample * std + mean``, computed elementwise.
    """
    logger.debug("GEMS NORMAL_TENSOR_FLOAT")
    shape = mean.shape
    device = mean.device
    # Forward the caller's generator; it must not be silently ignored.
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_tensor_float(out, std, mean)
def normal_float_tensor(mean, std, *, generator=None):
    """Draw normal samples with scalar ``mean`` and tensor ``std``.

    The output takes ``std.shape`` and is drawn on ``std.device``.

    Args:
        mean: Mean, passed to the transform as a non-tensor value
            (presumably a Python float).
        std: Tensor of standard deviations.
        generator: Optional generator that controls the Philox seed/offset
            used by ``normal_distribution``.

    Returns:
        ``sample * std + mean``, computed elementwise.
    """
    logger.debug("GEMS NORMAL_FLOAT_TENSOR")
    shape = std.shape
    device = std.device
    # Forward the caller's generator; it must not be silently ignored.
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_float_tensor(out, std, mean)