Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/normal.py: 0%
63 statements
coverage.py v7.6.9, created at 2026-03-26 15:32 +0800
import logging

import torch
import triton

from flag_gems.runtime import torch_device_fn
from flag_gems.utils.random_utils import philox_backend_seed_offset
from flag_gems.utils.shape_utils import broadcast_shapes, volume

from ..utils.pointwise_dynamic import pointwise_dynamic
from .randn import randn_kernel

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Elementwise transforms mapping standard-normal samples to the requested mean
# and std via val * std + mean, specialized for the tensor/scalar combinations
# of mean and std accepted by the normal_* entry points below.
@pointwise_dynamic(
    is_tensor=[True, True, True], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_tensor_tensor(val, std, mean):
    return val * std + mean


@pointwise_dynamic(
    is_tensor=[True, False, True], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_tensor_float(val, std, mean):
    return val * std + mean


@pointwise_dynamic(
    is_tensor=[True, True, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_float_tensor(val, std, mean):
    return val * std + mean


@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def transform_func_float_float(val, std, mean):
    return val * std + mean

UNROLL = 4

def normal_distribution(shape, device, *, generator=None, out=None):
    if out is None:
        out = torch.empty(shape, device=device, dtype=torch.float32)
    N = volume(shape)
    # grid_fn = lambda meta: (triton.cdiv(N, meta["BLOCK"] * UNROLL),)
    cluster_num = 12
    # Choose BLOCK_SIZE so that roughly cluster_num programs cover N elements
    # (UNROLL values per lane), capped at 1024.
    BLOCK_SIZE = min(triton.next_power_of_2(triton.cdiv(N, cluster_num * UNROLL)), 1024)
    # BLOCK_SIZE = min(triton.next_power_of_2(triton.cdiv(N, cluster_num * UNROLL)), triton.cdiv(32768, UNROLL))
    grid_fn = triton.cdiv(N, BLOCK_SIZE * UNROLL)
    # One Philox draw covers UNROLL output values, so advance the offset by cdiv(N, UNROLL).
    increment = triton.cdiv(N, UNROLL)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(device):
        randn_kernel[(grid_fn,)](out, N, philox_seed, philox_offset, BLOCK_SIZE)
    return out
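
# A worked sizing example (illustrative only, not part of the original source):
# for N = 1_000_000 with cluster_num = 12 and UNROLL = 4,
#   triton.cdiv(N, 12 * 4) = 20834 -> triton.next_power_of_2 -> 32768 -> capped to 1024,
# so BLOCK_SIZE = 1024, the launch grid is triton.cdiv(N, 1024 * 4) = 245 programs,
# and the Philox offset advances by triton.cdiv(N, 4) = 250_000.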

def normal_tensor_tensor(mean, std, *, generator=None):
    logger.debug("GEMS NORMAL_TENSOR_TENSOR")
    shape = broadcast_shapes([mean.shape, std.shape])
    device = mean.device
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_tensor_tensor(out, std, mean)

def normal_tensor_float(mean, std, *, generator=None):
    logger.debug("GEMS NORMAL_TENSOR_FLOAT")
    shape = mean.shape
    device = mean.device
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_tensor_float(out, std, mean)

def normal_float_tensor(mean, std, *, generator=None):
    logger.debug("GEMS NORMAL_FLOAT_TENSOR")
    shape = std.shape
    device = std.device
    out = normal_distribution(shape, device, generator=generator)
    return transform_func_float_tensor(out, std, mean)

def normal_(self, mean=0, std=1, *, generator=None):
    logger.debug("GEMS NORMAL_")
    shape = self.shape
    device = self.device
    self = normal_distribution(shape, device, generator=generator, out=self)
    transform_func_float_float(self, std, mean, out0=self)
    return self
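
# Illustrative usage sketch (an assumption, not part of the original module):
# requires a device backend that flag_gems' Kunlunxin runtime recognizes; the
# device string "xpu" and the shapes below are placeholders.
if __name__ == "__main__":
    dev = torch.device("xpu")
    mean = torch.full((1024,), 2.0, device=dev)
    std = torch.full((1024,), 0.5, device=dev)
    # Tensor mean / tensor std: draw standard-normal samples, then scale and shift.
    sample = normal_tensor_tensor(mean, std)
    print(sample.mean().item(), sample.std().item())  # expected near 2.0 and 0.5
    # In-place variant: fill an existing tensor with N(0, 1) samples.
    buf = torch.empty(1024, device=dev)
    normal_(buf, mean=0.0, std=1.0)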