Coverage for src/flag_gems/experimental_ops/trunc.py: 0%
85 statements
« prev ^ index » next — coverage.py v7.6.9, created at 2026-03-15 02:11 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def trunc_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr, DTYPE_CODE: tl.constexpr
):
    """Truncate toward zero, elementwise, over a flat 1-D buffer.

    Each program instance handles one BLOCK_SIZE-wide tile. DTYPE_CODE is a
    compile-time switch selecting the per-dtype specialization:
      0 -> integer types (truncation is the identity: plain copy)
      1 -> float16  (compute in float32, cast back)
      2 -> bfloat16 (compute in float32, cast back)
      3 -> float32
      4 -> float64
    Truncation is expressed as floor() for non-negative values and ceil()
    for negative ones, which rounds toward zero for every input.
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements

    x = tl.load(x_ptr + offs, mask=in_bounds)

    if DTYPE_CODE == 0:
        result = x
    elif DTYPE_CODE == 1:
        f = x.to(tl.float32)
        result = tl.where(f >= 0, tl.floor(f), tl.ceil(f)).to(tl.float16)
    elif DTYPE_CODE == 2:
        f = x.to(tl.float32)
        result = tl.where(f >= 0, tl.floor(f), tl.ceil(f)).to(tl.bfloat16)
    elif DTYPE_CODE == 3:
        result = tl.where(x >= 0, tl.floor(x), tl.ceil(x))
    elif DTYPE_CODE == 4:
        result = tl.where(x >= 0, tl.floor(x), tl.ceil(x))
    else:
        # Unknown code: pass values through unchanged.
        result = x

    tl.store(out_ptr + offs, result, mask=in_bounds)
44def _dtype_code(t: torch.Tensor) -> int:
45 if t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
46 return 0
47 if t.dtype == torch.float16:
48 return 1
49 if t.dtype == torch.bfloat16:
50 return 2
51 if t.dtype == torch.float32:
52 return 3
53 if t.dtype == torch.float64:
54 return 4
55 raise NotImplementedError(f"Unsupported dtype: {t.dtype}")
def _launch_trunc(inp: torch.Tensor, out: torch.Tensor):
    """Launch the Triton trunc kernel over two flat, equal-length CUDA tensors.

    Callers are responsible for passing contiguous 1-D views. Empty tensors
    are a no-op (no kernel launch).
    """
    assert inp.numel() == out.numel()
    assert inp.device.type == "cuda" and out.device.type == "cuda"
    total = inp.numel()
    if total == 0:
        return

    dtype_code = _dtype_code(inp)
    block = 1024
    # One program per BLOCK_SIZE-wide tile of the flat buffer.
    grid = (triton.cdiv(total, block),)
    trunc_kernel[grid](inp, out, total, BLOCK_SIZE=block, DTYPE_CODE=dtype_code)
def trunc(input: torch.Tensor):
    """Return a new tensor with each element of `input` truncated toward zero.

    Complex tensors are processed through a real view, truncating the real
    and imaginary components independently. Non-contiguous tensors are
    staged through contiguous copies (the kernel only handles flat, dense
    buffers) and the result is copied back into the original layout.
    """
    out = torch.empty_like(input)

    if input.is_complex():
        src = torch.view_as_real(input)
        dst = torch.view_as_real(out)
        if src.is_contiguous() and dst.is_contiguous():
            _launch_trunc(src.view(-1), dst.view(-1))
        else:
            src_c = src.contiguous()
            dst_c = dst.contiguous()
            _launch_trunc(src_c.view(-1), dst_c.view(-1))
            dst.copy_(dst_c)
        return out

    src = input if input.is_contiguous() else input.contiguous()
    dst = out if out.is_contiguous() else out.contiguous()
    _launch_trunc(src.view(-1), dst.view(-1))
    # If the kernel wrote into a staging copy, fold it back into `out`.
    if dst.data_ptr() != out.data_ptr():
        out.copy_(dst)
    return out
def trunc_out(input: torch.Tensor, out: torch.Tensor):
    """Truncate `input` toward zero, writing the result into `out`.

    `input` and `out` must agree in shape and dtype and both live on a CUDA
    device. Returns `out` for convenience. Complex tensors are handled via
    a real view; non-contiguous tensors are staged through contiguous
    copies and copied back into place.
    """
    assert input.shape == out.shape, "input and out must have the same shape"
    assert input.dtype == out.dtype, "input and out must have the same dtype"
    assert (
        input.device.type == "cuda" and out.device.type == "cuda"
    ), "Tensors must be on CUDA device"

    if input.is_complex():
        src = torch.view_as_real(input)
        dst = torch.view_as_real(out)
        if src.is_contiguous() and dst.is_contiguous():
            _launch_trunc(src.view(-1), dst.view(-1))
        else:
            src_c = src.contiguous()
            dst_c = dst.contiguous()
            _launch_trunc(src_c.view(-1), dst_c.view(-1))
            dst.copy_(dst_c)
        return out

    src = input if input.is_contiguous() else input.contiguous()
    if out.is_contiguous():
        _launch_trunc(src.view(-1), out.view(-1))
    else:
        staged = out.contiguous()
        _launch_trunc(src.view(-1), staged.view(-1))
        out.copy_(staged)
    return out