Coverage for src/flag_gems/ops/segment_reduce.py: 35%
591 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
# Module logger; the public entry points emit one debug record per call.
logger = logging.getLogger(__name__)

# Block size returned by _get_block_size for every device type except "npu".
_BLOCK_SIZE = 1024
# Block size returned by _get_block_size when the device type is "npu".
_NPU_BLOCK_SIZE = 256
# Inputs with fewer elements than this never take the uniform-length fast path.
_UNIFORM_FAST_PATH_MIN_NUMEL = 1 << 20
# Longest per-segment length handled by the dedicated uniform-length kernels.
_UNIFORM_KERNEL_MAX_SEGMENT_LENGTH = 1024
# Memo table owned by _all_lengths_equal.
_UNIFORM_LENGTHS_CACHE = {}
# Reduction names accepted by _check_reduce_and_dtype.
_SUPPORTED_REDUCES = ("sum", "mean", "max", "min", "prod")
# Floating-point dtypes accepted for the data tensor.
_SUPPORTED_DATA_DTYPES = (
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float64,
)
# Integer dtypes accepted for lengths / offsets tensors.
_SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64)
def _prod(shape):
    """Return the product of the entries of ``shape``; an empty shape gives 1."""
    if not shape:
        return 1
    return math.prod(shape)
def _get_block_size(device):
    """Return the 1-D block size for ``device``: the NPU value for "npu", else the default."""
    if device.type == "npu":
        return _NPU_BLOCK_SIZE
    return _BLOCK_SIZE
def _get_uniform_kernel_config(device, inner_size):
    """Pick the ``(BLOCK_M, BLOCK_K)`` tile for the uniform forward kernel.

    Args:
        device: object exposing a ``type`` string (e.g. a ``torch.device``).
        inner_size: number of elements after the reduced axis.

    Returns:
        A ``(block_m, block_k)`` tuple of ints.
    """
    if device.type == "npu":
        # The conditional applies to the second tuple element only.
        block_k = 16 if inner_size > 1 else 1
        return 4, block_k
    if inner_size == 1:
        return 16, 1
    return 4, 64
def _get_uniform_backward_tile_config(device, inner_size, reduce, dtype):
    """Pick the ``(BLOCK_M, BLOCK_K)`` tile for the uniform max/min/prod backward kernel.

    Args:
        device: object exposing a ``type`` string (e.g. a ``torch.device``).
        inner_size: number of elements after the reduced axis.
        reduce: reduction name.
        dtype: dtype of the data tensor.

    Returns:
        A ``(block_m, block_k)`` tuple of ints.
    """
    has_inner = inner_size > 1
    if device.type == "npu":
        return 1, (16 if has_inner else 1)
    # Half-precision "prod" with a real inner dimension gets a wider K tile.
    if reduce == "prod" and has_inner and dtype in (torch.float16, torch.bfloat16):
        return 4, 256
    return 4, (64 if has_inner else 1)
@triton.jit
def _mul_combine(a, b):
    # Binary combine function handed to tl.reduce to build a product reduction.
    product = a * b
    return product
def _all_lengths_equal(lengths, value):
    """Return True when every entry of ``lengths`` equals ``value``.

    The answer requires copying ``lengths`` to the host, so results are
    memoized in ``_UNIFORM_LENGTHS_CACHE``. The key (device type, data
    pointer, shape, version counter, value) is not enough on its own: once a
    tensor is freed, a different tensor can be allocated at the same address
    with the same shape and version and would be served the stale answer.
    Each entry therefore also stores a weak reference to the tensor that
    produced it, and a hit only counts when it is that very tensor object.

    Args:
        lengths: integer tensor of segment lengths.
        value: scalar every entry is compared against.

    Returns:
        A Python bool.
    """
    import weakref

    cache_key = (
        lengths.device.type,
        lengths.data_ptr(),
        tuple(lengths.shape),
        getattr(lengths, "_version", None),
        value,
    )
    entry = _UNIFORM_LENGTHS_CACHE.get(cache_key)
    if entry is not None:
        owner_ref, cached = entry
        # A dead reference yields None, so entries of freed tensors never match.
        if owner_ref() is lengths:
            return cached

    is_equal = torch.all(lengths.detach().cpu() == value).item()
    # Crude bound on the table size: drop everything once it grows past 128.
    if len(_UNIFORM_LENGTHS_CACHE) > 128:
        _UNIFORM_LENGTHS_CACHE.clear()
    _UNIFORM_LENGTHS_CACHE[cache_key] = (weakref.ref(lengths), is_equal)
    return is_equal
def _wrap_axis(axis, ndim):
    """Normalize a possibly negative ``axis`` into ``[0, ndim)``.

    Args:
        axis: requested axis, negative values count from the end.
        ndim: number of dimensions of the tensor.

    Returns:
        The non-negative axis index.

    Raises:
        IndexError: if ``ndim`` is 0 or ``axis`` lies outside ``[-ndim, ndim)``.
    """
    if ndim == 0:
        raise IndexError(
            "segment_reduce(): input tensor must have at least one dimension."
        )
    if not (-ndim <= axis < ndim):
        raise IndexError(
            f"segment_reduce(): axis {axis} is out of bounds for tensor of dimension {ndim}."
        )
    return axis + ndim if axis < 0 else axis
def _check_reduce_and_dtype(data, reduce):
    """Validate the reduction name and the dtype of ``data``.

    Raises:
        RuntimeError: if ``reduce`` is not in ``_SUPPORTED_REDUCES``.
        NotImplementedError: if ``data.dtype`` is not in ``_SUPPORTED_DATA_DTYPES``.
    """
    reduce_is_known = reduce in _SUPPORTED_REDUCES
    if not reduce_is_known:
        raise RuntimeError(
            "segment_reduce(): reduce must be one of 'sum', 'mean', 'max', 'min', or 'prod'."
        )
    dtype = data.dtype
    if dtype not in _SUPPORTED_DATA_DTYPES:
        raise NotImplementedError(f'"segment_reduce" not implemented for {dtype}.')
def _check_index_tensor(data, index_tensor, name, axis):
    """Validate a lengths/offsets tensor against ``data``.

    Args:
        data: tensor being reduced.
        index_tensor: the lengths or offsets tensor.
        name: label used in error messages ("lengths" or "offsets").
        axis: already-normalized reduction axis.

    Raises:
        NotImplementedError: if the index dtype is not int32/int64.
        RuntimeError: on a device mismatch, if the index tensor has more
            dimensions than ``data``, or if ``axis`` is not its last dimension.
    """
    if index_tensor.dtype not in _SUPPORTED_INDEX_DTYPES:
        raise NotImplementedError(f"segment_reduce(): {name} must be int32 or int64.")
    if index_tensor.device != data.device:
        raise RuntimeError(
            f"segment_reduce(): Expected data and {name} on the same device."
        )
    data_rank = data.dim()
    index_rank = index_tensor.dim()
    if data_rank < index_rank:
        raise RuntimeError(
            f"segment_reduce(): Expected data.dim() >= {name}.dim(), got "
            f"{data_rank} and {index_rank}."
        )
    last_index_dim = index_tensor.dim() - 1
    if axis != last_index_dim:
        raise RuntimeError(
            f"segment_reduce(): Expected axis to be the last dimension of {name} "
            f"but got {axis}."
        )
def _validate_lengths(data, lengths, axis, unsafe):
    """Check ``lengths`` structurally and, unless ``unsafe``, by value.

    The value checks copy ``lengths`` to the host.

    Raises:
        RuntimeError: if any length is negative, or if a row of lengths does
            not sum to ``data.size(axis)`` (both only when ``unsafe`` is false).
    """
    _check_index_tensor(data, lengths, "lengths", axis)
    if unsafe:
        return

    host_lengths = lengths.detach().cpu()
    has_negative = torch.any(host_lengths < 0).item()
    if has_negative:
        raise RuntimeError("lengths contains negative value!")

    row_totals = host_lengths.sum(dim=-1)
    if not torch.all(row_totals == data.size(axis)).item():
        raise RuntimeError(
            "segment_reduce(): Expected all rows of lengths along axis to sum to "
            "data.size(lengths.dim()-1) when !unsafe."
        )
def _make_initial(reduce, initial):
    """Resolve the starting accumulator value for a reduction.

    Args:
        reduce: reduction name.
        initial: user-supplied initial value, or None.

    Returns:
        ``(has_initial, value)``: ``has_initial`` tells whether the caller
        supplied ``initial``; otherwise ``value`` is the identity of the
        reduction (-inf for max, +inf for min, 1.0 for prod, 0.0 otherwise).
    """
    if initial is not None:
        return True, initial
    identities = {
        "max": float("-inf"),
        "min": float("inf"),
        "prod": 1.0,
    }
    return False, identities.get(reduce, 0.0)
def _get_uniform_segment_length(data, lengths, axis):
    """Return the common segment length if the uniform fast path applies.

    The fast path needs a large enough input, leading dimensions of
    ``lengths`` that match those of ``data``, a segment count that evenly
    divides ``data.shape[axis]``, and every length equal to that quotient.

    Returns:
        The shared segment length, or None when any condition fails.
    """
    if data.numel() < _UNIFORM_FAST_PATH_MIN_NUMEL:
        return None
    if tuple(lengths.shape[:-1]) != tuple(data.shape[:axis]):
        return None

    num_segments = lengths.shape[-1]
    if num_segments <= 0:
        return None

    axis_size = data.shape[axis]
    if axis_size % num_segments != 0:
        return None

    per_segment = axis_size // num_segments
    if per_segment <= 0:
        return None

    # Only now pay for the host-side comparison of the actual lengths.
    return per_segment if _all_lengths_equal(lengths, per_segment) else None
def _is_unit_lengths(data, lengths, axis):
    """Return True when every segment along ``axis`` has length exactly 1.

    Requires the leading dimensions of ``lengths`` to match ``data`` and the
    segment count to equal ``data.shape[axis]`` before comparing values.
    """
    leading_match = tuple(lengths.shape[:-1]) == tuple(data.shape[:axis])
    if not leading_match or lengths.shape[-1] != data.shape[axis]:
        return False
    return _all_lengths_equal(lengths, 1)
@libentry()
@triton.jit
def _segment_reduce_uniform_other_backward_kernel(
    grad,
    output,
    data,
    grad_input,
    total_rows,
    segment_count,
    segment_length,
    inner_size,
    data_size_axis,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    INITIAL_PROD_VALUE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Backward of max/min/prod when all segments share one length.

    A program handles a (BLOCK_M, BLOCK_N, BLOCK_K) tile: BLOCK_M output rows
    (outer index x segment index), BLOCK_N positions inside each segment and
    BLOCK_K inner elements. The segment is read in a single tile without a
    loop, so positions at or beyond BLOCK_N are never visited.
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    m_range = tl.arange(0, BLOCK_M)
    k_range = tl.arange(0, BLOCK_K)

    # 3-D index grids for the data / grad_input tile.
    rows = pid_m * BLOCK_M + m_range[:, None, None]
    seg_offsets = tl.arange(0, BLOCK_N)[None, :, None]
    k_offsets = pid_k * BLOCK_K + k_range[None, None, :]
    # 2-D index grids for the per-segment grad / output values.
    output_rows = pid_m * BLOCK_M + m_range[:, None]
    output_k_offsets = pid_k * BLOCK_K + k_range[None, :]

    row_mask = rows < total_rows
    seg_mask = seg_offsets < segment_length
    k_mask = k_offsets < inner_size
    mask = row_mask & seg_mask & k_mask
    output_mask = (output_rows < total_rows) & (output_k_offsets < inner_size)

    # A row index encodes (outer index, segment index).
    outer_idx = rows // segment_count
    dim_idx = rows - outer_idx * segment_count
    data_offsets = (
        outer_idx * data_size_axis * inner_size
        + (dim_idx * segment_length + seg_offsets) * inner_size
        + k_offsets
    )
    output_offsets = output_rows * inner_size + output_k_offsets

    values = tl.load(data + data_offsets, mask=mask, other=0.0).to(compute_dtype)
    grad_value = tl.load(grad + output_offsets, mask=output_mask, other=0.0).to(
        compute_dtype
    )
    output_value = tl.load(output + output_offsets, mask=output_mask, other=0.0).to(
        compute_dtype
    )

    if IS_MAX or IS_MIN:
        # Elements that are NaN or equal to the reduced value receive gradient.
        is_nan = values != values
        is_extreme = values == output_value[:, None, :]
        match = (is_nan | is_extreme) & mask
        counter = tl.sum(match.to(tl.int64), axis=1)
        # Ties split the gradient evenly, but only when it is positive.
        store_value = tl.where(
            (counter >= 2) & (grad_value > 0),
            grad_value / counter,
            grad_value,
        )
        # Only matching positions are written; others keep what grad_input held.
        tl.store(grad_input + data_offsets, store_value[:, None, :], mask=match)
    elif IS_PROD:
        nan_mask = (values != values) & mask
        zero_mask = (values == 0) & mask & ~nan_mask
        zero_count = tl.sum(zero_mask.to(tl.int64), axis=1)
        nan_count = tl.sum(nan_mask.to(tl.int64), axis=1)
        # Product over the "ordinary" elements only (NaN, zero and padding -> 1).
        product_values = tl.where(nan_mask | zero_mask | ~mask, 1.0, values)
        product = tl.reduce(product_values, axis=1, combine_fn=_mul_combine)
        product *= INITIAL_PROD_VALUE

        zero_scalar = tl.full((BLOCK_M, BLOCK_K), 0.0, dtype=compute_dtype)
        nan_scalar = zero_scalar / zero_scalar
        # Ordinary elements: grad * output / value.
        normal_prefix = grad_value * output_value
        normal_grad = normal_prefix[:, None, :] / values
        # Product of all *other* elements, as seen from a zero element ...
        zero_exclusive = tl.where(
            nan_count > 0,
            nan_scalar,
            tl.where(zero_count > 1, zero_scalar, product),
        )
        # ... and as seen from a NaN element.
        nan_exclusive = tl.where(
            nan_count > 1,
            nan_scalar,
            tl.where(zero_count > 0, zero_scalar, product),
        )
        exclusive = tl.where(
            nan_mask, nan_exclusive[:, None, :], zero_exclusive[:, None, :]
        )
        grad_result = tl.where(
            nan_mask | zero_mask,
            grad_value[:, None, :] * exclusive,
            normal_grad,
        )
        tl.store(grad_input + data_offsets, grad_result, mask=mask)
@libentry()
@triton.jit
def _segment_reduce_uniform_sum_mean_backward_kernel(
    grad,
    grad_input,
    total_numel,
    segment_count,
    segment_length,
    inner_size,
    data_size_axis,
    IS_MEAN: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward of sum/mean when all segments share one length.

    Each program fills BLOCK_SIZE consecutive elements of ``grad_input`` with
    the gradient of the segment they belong to (divided by the segment length
    for mean).
    """
    pid = tle.program_id(0)
    flat = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = flat < total_numel

    # Split the flat input index into (outer, axis, inner) coordinates.
    outer_idx = flat // (data_size_axis * inner_size)
    axis_idx = (flat // inner_size) % data_size_axis
    inner_idx = flat % inner_size
    segment_idx = axis_idx // segment_length
    grad_offsets = (outer_idx * segment_count + segment_idx) * inner_size + inner_idx

    grad_value = tl.load(grad + grad_offsets, mask=in_bounds, other=0.0)
    if IS_MEAN:
        grad_value = grad_value / segment_length
    tl.store(grad_input + flat, grad_value, mask=in_bounds)
@libentry()
@triton.jit
def _segment_reduce_uniform_inner1_forward_kernel(
    data,
    output,
    total_rows,
    segment_count,
    segment_length,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Forward reduction for equal-length segments with no inner dimension.

    A program reduces BLOCK_M segments, each read as one row of a
    (BLOCK_M, BLOCK_N) tile. There is no loop over the segment, so positions
    at or beyond BLOCK_N are never read.
    """
    pid = tle.program_id(0)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    row_ids = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    col_ids = tl.arange(0, BLOCK_N)[None, :]
    mask = (row_ids < total_rows) & (col_ids < segment_length)

    # A row index encodes (outer index, segment index).
    outer_idx = row_ids // segment_count
    dim_idx = row_ids - outer_idx * segment_count
    data_offsets = outer_idx * data_size_axis + dim_idx * segment_length + col_ids

    if IS_SUM or IS_MEAN:
        values = tl.load(data + data_offsets, mask=mask, other=0.0).to(compute_dtype)
        result = tl.sum(values, axis=1)
        if IS_MEAN:
            result = result / segment_length
    elif IS_PROD:
        values = tl.load(data + data_offsets, mask=mask, other=1.0).to(compute_dtype)
        result = tl.reduce(values, axis=1, combine_fn=_mul_combine)
    elif IS_MAX:
        values = tl.load(data + data_offsets, mask=mask, other=float("-inf")).to(
            compute_dtype
        )
        # NaNs are excluded from the max and re-injected afterwards.
        nan_mask = (values != values) & mask
        has_nan = tl.sum(nan_mask.to(tl.int32), axis=1) > 0
        nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=1)
        result = tl.max(tl.where(mask & ~nan_mask, values, float("-inf")), axis=1)
        result = tl.where(has_nan, nan_value, result)
    elif IS_MIN:
        values = tl.load(data + data_offsets, mask=mask, other=float("inf")).to(
            compute_dtype
        )
        # NaNs are excluded from the min and re-injected afterwards.
        nan_mask = (values != values) & mask
        has_nan = tl.sum(nan_mask.to(tl.int32), axis=1) > 0
        nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=1)
        result = tl.min(tl.where(mask & ~nan_mask, values, float("inf")), axis=1)
        result = tl.where(has_nan, nan_value, result)

    output_offsets = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(output + output_offsets, result, mask=output_offsets < total_rows)
@libentry()
@triton.jit
def _segment_reduce_uniform_forward_kernel(
    data,
    output,
    total_rows,
    segment_count,
    segment_length,
    inner_size,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Forward reduction for equal-length segments with an inner dimension.

    A program owns a (BLOCK_M, BLOCK_K) tile of the output and walks the
    segment one position at a time, accumulating into that tile.
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    k_offsets = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)[None, :]
    mask = (rows < total_rows) & (k_offsets < inner_size)

    # A row index encodes (outer index, segment index).
    outer_idx = rows // segment_count
    dim_idx = rows - outer_idx * segment_count
    segment_start = dim_idx * segment_length
    base_offsets = (
        outer_idx * data_size_axis * inner_size + segment_start * inner_size + k_offsets
    )

    # Start from the identity of the selected reduction.
    if IS_MAX:
        acc = tl.full((BLOCK_M, BLOCK_K), float("-inf"), dtype=compute_dtype)
    elif IS_MIN:
        acc = tl.full((BLOCK_M, BLOCK_K), float("inf"), dtype=compute_dtype)
    elif IS_PROD:
        acc = tl.full((BLOCK_M, BLOCK_K), 1.0, dtype=compute_dtype)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=compute_dtype)

    # NaN bookkeeping, only consulted for max/min.
    has_nan = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.int1)
    nan_value = tl.zeros((BLOCK_M, BLOCK_K), dtype=compute_dtype)

    pos = 0
    while pos < segment_length:
        data_offsets = base_offsets + pos * inner_size
        if IS_SUM or IS_MEAN:
            values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                compute_dtype
            )
            acc += values
        elif IS_PROD:
            values = tl.load(data + data_offsets, mask=mask, other=1.0).to(
                compute_dtype
            )
            acc *= values
        elif IS_MAX:
            values = tl.load(data + data_offsets, mask=mask, other=float("-inf")).to(
                compute_dtype
            )
            nan_mask = (values != values) & mask
            has_nan |= nan_mask
            nan_value = tl.where(nan_mask, values, nan_value)
            acc = tl.maximum(acc, tl.where(mask & ~nan_mask, values, float("-inf")))
        elif IS_MIN:
            values = tl.load(data + data_offsets, mask=mask, other=float("inf")).to(
                compute_dtype
            )
            nan_mask = (values != values) & mask
            has_nan |= nan_mask
            nan_value = tl.where(nan_mask, values, nan_value)
            acc = tl.minimum(acc, tl.where(mask & ~nan_mask, values, float("inf")))
        pos += 1

    if IS_MEAN:
        acc = acc / segment_length
    if IS_MAX or IS_MIN:
        # Any NaN seen along the segment overrides the numeric extreme.
        acc = tl.where(has_nan, nan_value, acc)

    output_offsets = rows * inner_size + k_offsets
    tl.store(output + output_offsets, acc, mask=mask)
def _segment_reduce_uniform_lengths(data, reduce, lengths, axis):
    """Forward fast path for the case where every segment has the same length.

    Args:
        data: contiguous tensor being reduced.
        reduce: reduction name ("sum", "mean", "max", "min" or "prod").
        lengths: segment lengths tensor.
        axis: already-normalized reduction axis.

    Returns:
        The reduced tensor, or None when the fast path does not apply and the
        caller must use the generic kernel.
    """
    segment_count = lengths.shape[-1]
    segment_length = _get_uniform_segment_length(data, lengths, axis)
    if segment_length is None:
        return None

    device = data.device
    trailing_shape = data.shape[axis + 1 :]
    output_shape = lengths.shape + trailing_shape
    inner_size = _prod(trailing_shape)

    if segment_length <= _UNIFORM_KERNEL_MAX_SEGMENT_LENGTH:
        output = torch.empty(output_shape, dtype=data.dtype, device=device)
        if output.numel() == 0:
            return output
        total_rows = _prod(lengths.shape)

        # The inner-size-1 kernel reads a whole segment in one BLOCK_N-wide
        # tile and does not loop. It is only valid when the tile covers the
        # segment; with the smaller NPU block size a long segment would be
        # truncated, so such cases use the looping kernel below instead.
        block_n = triton.next_power_of_2(segment_length)
        if inner_size == 1 and block_n <= _get_block_size(device):
            block_m = 4 if device.type == "npu" else 32
            grid = (triton.cdiv(total_rows, block_m),)
            with torch_device_fn.device(device):
                _segment_reduce_uniform_inner1_forward_kernel[grid](
                    data,
                    output,
                    total_rows,
                    segment_count,
                    segment_length,
                    data.shape[axis],
                    reduce == "sum",
                    reduce == "mean",
                    reduce == "max",
                    reduce == "min",
                    reduce == "prod",
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                )
            return output

        block_m, block_k = _get_uniform_kernel_config(device, inner_size)
        grid = (triton.cdiv(total_rows, block_m), triton.cdiv(inner_size, block_k))
        with torch_device_fn.device(device):
            _segment_reduce_uniform_forward_kernel[grid](
                data,
                output,
                total_rows,
                segment_count,
                segment_length,
                inner_size,
                data.shape[axis],
                reduce == "sum",
                reduce == "mean",
                reduce == "max",
                reduce == "min",
                reduce == "prod",
                BLOCK_M=block_m,
                BLOCK_K=block_k,
            )
        return output

    # Long segments: no dedicated kernel. NPU falls back to the generic path.
    if device.type == "npu":
        return None

    # Expose the segments as their own dimension and reduce it with torch ops.
    # segment_length exceeds _UNIFORM_KERNEL_MAX_SEGMENT_LENGTH here, so a
    # length-1 special case cannot occur on this path.
    view_shape = (
        data.shape[:axis] + (segment_count, segment_length) + data.shape[axis + 1 :]
    )
    reshaped = data.reshape(view_shape)
    reduce_dim = axis + 1

    if reduce == "sum":
        return torch.sum(reshaped, dim=reduce_dim)
    if reduce == "mean":
        return torch.mean(reshaped, dim=reduce_dim)
    if reduce == "max":
        return torch.amax(reshaped, dim=reduce_dim)
    if reduce == "min":
        return torch.amin(reshaped, dim=reduce_dim)
    return torch.prod(reshaped, dim=reduce_dim)
def _segment_reduce_uniform_sum_mean_backward(data, grad, reduce, lengths, axis):
    """Backward fast path for sum/mean with equal-length segments.

    Args:
        data: contiguous forward input (used for shape and device).
        grad: contiguous gradient of the forward output.
        reduce: "sum" or "mean".
        lengths: segment lengths tensor.
        axis: already-normalized reduction axis.

    Returns:
        The gradient w.r.t. ``data`` (in ``grad``'s dtype), or None when the
        uniform fast path does not apply.
    """
    segment_length = _get_uniform_segment_length(data, lengths, axis)
    if segment_length is None:
        return None

    grad_input = torch.empty_like(data, dtype=grad.dtype)
    if grad_input.numel() == 0:
        return grad_input

    total_numel = data.numel()
    num_segments = lengths.shape[-1]
    inner_size = _prod(data.shape[axis + 1 :])
    block_size = _get_block_size(data.device)
    # One program per BLOCK_SIZE input elements.
    grid = (triton.cdiv(total_numel, block_size),)
    with torch_device_fn.device(data.device):
        _segment_reduce_uniform_sum_mean_backward_kernel[grid](
            grad,
            grad_input,
            total_numel,
            num_segments,
            segment_length,
            inner_size,
            data.shape[axis],
            reduce == "mean",
            BLOCK_SIZE=block_size,
        )
    return grad_input
def _segment_reduce_uniform_other_backward(
    data, output, grad, reduce, lengths, axis, initial
):
    """Backward fast path for max/min/prod with equal-length segments.

    Args:
        data: contiguous forward input.
        output: contiguous forward output.
        grad: contiguous gradient of the forward output.
        reduce: "max", "min" or "prod".
        lengths: segment lengths tensor.
        axis: already-normalized reduction axis.
        initial: optional initial value used by the forward pass.

    Returns:
        The gradient w.r.t. ``data`` (in ``grad``'s dtype), or None when the
        fast path does not apply and the generic kernel must be used.
    """
    segment_count = lengths.shape[-1]
    segment_length = _get_uniform_segment_length(data, lengths, axis)
    if segment_length is None or segment_length > _UNIFORM_KERNEL_MAX_SEGMENT_LENGTH:
        return None

    # The kernel reads each segment in one BLOCK_N-wide tile and never loops.
    # If the device block size cannot cover the segment (possible with the
    # smaller NPU block size), bail out rather than drop the segment tail.
    block_n = triton.next_power_of_2(segment_length)
    if block_n > _get_block_size(data.device):
        return None

    if reduce in ("max", "min"):
        # max/min only writes the matching positions, so the rest must be 0.
        grad_input = torch.zeros_like(data, dtype=grad.dtype)
    else:
        grad_input = torch.empty_like(data, dtype=grad.dtype)
    if grad_input.numel() == 0:
        return grad_input

    inner_size = _prod(data.shape[axis + 1 :])
    total_rows = _prod(lengths.shape)
    block_m, block_k = _get_uniform_backward_tile_config(
        data.device, inner_size, reduce, data.dtype
    )
    _, initial_prod_value = _make_initial("prod", initial)
    grid = (triton.cdiv(total_rows, block_m), triton.cdiv(inner_size, block_k))
    with torch_device_fn.device(data.device):
        _segment_reduce_uniform_other_backward_kernel[grid](
            grad,
            output,
            data,
            grad_input,
            total_rows,
            segment_count,
            segment_length,
            inner_size,
            data.shape[axis],
            reduce == "max",
            reduce == "min",
            reduce == "prod",
            initial_prod_value,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
        )
    return grad_input
@libentry()
@triton.jit
def _lengths_to_offsets_kernel(
    lengths,
    offsets,
    outer_count,
    segment_count,
):
    """Turn one row of segment lengths into segment_count + 1 running offsets.

    One program per row: it writes 0 first and then the running sum after
    each length. ``outer_count`` is accepted but not read in the body.
    """
    pid = tle.program_id(0)
    running = tl.full((), 0, dtype=tl.int64)
    lengths_row = lengths + pid * segment_count
    offsets_row = offsets + pid * (segment_count + 1)
    tl.store(offsets_row, running)

    idx = 0
    while idx < segment_count:
        running += tl.load(lengths_row + idx)
        tl.store(offsets_row + idx + 1, running)
        idx += 1
@libentry()
@triton.jit
def _segment_reduce_forward_kernel(
    data,
    offsets,
    output,
    segment_count,
    inner_size,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    HAS_INITIAL: tl.constexpr,
    INITIAL_VALUE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Generic forward reduction over arbitrary (offset-delimited) segments.

    One program computes one output element, identified by
    (outer index, segment index, inner index). The segment bounds come from
    ``offsets``, which holds segment_count + 1 entries per outer row.
    """
    pid = tle.program_id(0)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    inner_idx = pid % inner_size
    row_idx = pid // inner_size
    dim_idx = row_idx % segment_count
    outer_idx = row_idx // segment_count

    offsets_base = outer_idx * (segment_count + 1) + dim_idx
    segment_start = tl.load(offsets + offsets_base)
    segment_end = tl.load(offsets + offsets_base + 1)
    segment_length = segment_end - segment_start

    acc = tl.full((), INITIAL_VALUE, dtype=compute_dtype)
    if IS_PROD:
        # Product is accumulated one element at a time.
        pos = segment_start
        while pos < segment_end:
            data_offset = (
                outer_idx * data_size_axis * inner_size + pos * inner_size + inner_idx
            )
            value = tl.load(data + data_offset).to(compute_dtype)
            acc *= value
            pos += 1
    else:
        # Other reductions consume the segment in BLOCK_SIZE-wide chunks.
        pos = segment_start
        while pos < segment_end:
            segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
            mask = segment_offsets < segment_end
            data_offsets = (
                outer_idx * data_size_axis * inner_size
                + segment_offsets * inner_size
                + inner_idx
            )

            if IS_SUM or IS_MEAN:
                values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                    compute_dtype
                )
                acc += tl.sum(tl.where(mask, values, 0.0), axis=0)
            elif IS_MAX:
                values = tl.load(
                    data + data_offsets, mask=mask, other=float("-inf")
                ).to(compute_dtype)
                nan_mask = (values != values) & mask
                has_nan = tl.sum(nan_mask.to(tl.int32), axis=0) > 0
                nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=0)
                chunk = tl.max(
                    tl.where(mask & ~nan_mask, values, float("-inf")), axis=0
                )
                # tl.maximum with its default NaN handling returns the non-NaN
                # operand, so a NaN carried over from an earlier chunk must be
                # kept explicitly or it would vanish on a later NaN-free chunk.
                carried = tl.where(acc != acc, acc, tl.maximum(acc, chunk))
                acc = tl.where(has_nan, nan_value, carried)
            elif IS_MIN:
                values = tl.load(data + data_offsets, mask=mask, other=float("inf")).to(
                    compute_dtype
                )
                nan_mask = (values != values) & mask
                has_nan = tl.sum(nan_mask.to(tl.int32), axis=0) > 0
                nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=0)
                chunk = tl.min(tl.where(mask & ~nan_mask, values, float("inf")), axis=0)
                # Same NaN carry-over rule as the max branch.
                carried = tl.where(acc != acc, acc, tl.minimum(acc, chunk))
                acc = tl.where(has_nan, nan_value, carried)
            pos += BLOCK_SIZE

    if IS_MEAN:
        acc_is_nan = acc != acc
        # For an empty segment without an initial value acc is 0, so 0 / 0
        # yields the NaN written for that segment.
        nan_value = acc / acc
        if not HAS_INITIAL:
            acc = tl.where(segment_length == 0, nan_value, acc)
        acc = tl.where((segment_length > 0) & ~acc_is_nan, acc / segment_length, acc)

    tl.store(output + pid, acc)
@libentry()
@triton.jit
def _segment_reduce_backward_kernel(
    grad,
    output,
    data,
    offsets,
    grad_input,
    segment_count,
    inner_size,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    INITIAL_PROD_VALUE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Generic backward over arbitrary (offset-delimited) segments.

    One program handles one output element, identified by
    (outer index, segment index, inner index), and writes the gradient of
    every input element of that segment. Empty segments write nothing, and
    max/min writes only the matching positions.
    """
    pid = tle.program_id(0)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    row_idx = pid // inner_size
    inner_idx = pid % inner_size
    outer_idx = row_idx // segment_count
    dim_idx = row_idx % segment_count

    offsets_base = outer_idx * (segment_count + 1) + dim_idx
    segment_start = tl.load(offsets + offsets_base)
    segment_end = tl.load(offsets + offsets_base + 1)
    segment_length = segment_end - segment_start

    if segment_length > 0:
        grad_value = tl.load(grad + pid).to(compute_dtype)
        output_value = tl.load(output + pid).to(compute_dtype)
        # Offset of position 0 along the axis for this (outer, inner) pair.
        slab_base = outer_idx * data_size_axis * inner_size + inner_idx

        if IS_SUM or IS_MEAN:
            if IS_MEAN:
                grad_value = grad_value / segment_length
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = slab_base + segment_offsets * inner_size
                tl.store(grad_input + data_offsets, grad_value, mask=mask)
                pos += BLOCK_SIZE
        elif IS_MAX or IS_MIN:
            # Pass 1: count elements that are NaN or equal to the output.
            counter = tl.full((), 0, dtype=tl.int64)
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = slab_base + segment_offsets * inner_size
                values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                    compute_dtype
                )
                match = ((values != values) | (values == output_value)) & mask
                counter += tl.sum(match.to(tl.int64), axis=0)
                pos += BLOCK_SIZE

            # Ties split the gradient evenly, but only when it is positive.
            store_value = tl.where(
                (counter >= 2) & (grad_value > 0),
                grad_value / counter,
                grad_value,
            )
            # Pass 2: write the gradient to the matching positions only.
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = slab_base + segment_offsets * inner_size
                values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                    compute_dtype
                )
                match = ((values != values) | (values == output_value)) & mask
                tl.store(grad_input + data_offsets, store_value, mask=match)
                pos += BLOCK_SIZE
        elif IS_PROD:
            # Pass 1: count NaNs and zeros and multiply the remaining elements.
            zero_count = tl.full((), 0, dtype=tl.int64)
            nan_count = tl.full((), 0, dtype=tl.int64)
            product = tl.full((), INITIAL_PROD_VALUE, dtype=compute_dtype)
            pos = segment_start
            while pos < segment_end:
                data_offset = slab_base + pos * inner_size
                value = tl.load(data + data_offset).to(compute_dtype)
                if value != value:
                    nan_count += 1
                elif value == 0:
                    zero_count += 1
                else:
                    product *= value
                pos += 1

            zero_scalar = tl.full((), 0.0, dtype=compute_dtype)
            nan_scalar = zero_scalar / zero_scalar
            # Ordinary elements: grad * output / value.
            normal_prefix = grad_value * output_value
            # Pass 2: write the gradient of every element of the segment.
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = slab_base + segment_offsets * inner_size
                values = tl.load(data + data_offsets, mask=mask, other=1.0).to(
                    compute_dtype
                )
                nan_mask = (values != values) & mask
                zero_mask = (values == 0) & mask & ~nan_mask
                normal_grad = normal_prefix / values
                # Product of all *other* elements, as seen from a zero ...
                zero_exclusive = tl.where(
                    nan_count > 0,
                    nan_scalar,
                    tl.where(zero_count > 1, zero_scalar, product),
                )
                # ... and as seen from a NaN.
                nan_exclusive = tl.where(
                    nan_count > 1,
                    nan_scalar,
                    tl.where(zero_count > 0, zero_scalar, product),
                )
                exclusive = tl.where(nan_mask, nan_exclusive, zero_exclusive)
                grad_result = tl.where(
                    nan_mask | zero_mask,
                    grad_value * exclusive,
                    normal_grad,
                )
                tl.store(grad_input + data_offsets, grad_result, mask=mask)
                pos += BLOCK_SIZE
def _lengths_to_offsets(lengths):
    """Build per-row running offsets from a contiguous ``lengths`` tensor.

    Returns:
        A tensor shaped like ``lengths`` with one extra entry in the last
        dimension, with the same dtype and device as ``lengths``.
    """
    segment_count = lengths.shape[-1]
    leading_shape = lengths.shape[:-1]
    offsets = torch.empty(
        leading_shape + (segment_count + 1,),
        dtype=lengths.dtype,
        device=lengths.device,
    )
    if offsets.numel() == 0:
        return offsets

    # One program per row of lengths.
    outer_count = _prod(leading_shape)
    with torch_device_fn.device(lengths.device):
        _lengths_to_offsets_kernel[(outer_count,)](
            lengths,
            offsets,
            outer_count,
            segment_count,
        )
    return offsets
def _prepare_common(data, reduce, lengths, offsets, indices, axis, unsafe):
    """Validate the arguments and derive offsets and the output shape.

    Returns:
        ``(axis, offsets_contig, output_shape, from_offsets)`` where ``axis``
        is normalized, ``offsets_contig`` is a contiguous offsets tensor
        (given, or computed from ``lengths``) and ``from_offsets`` tells
        whether the caller supplied ``offsets``.

    Raises:
        RuntimeError: if ``indices`` is given, or neither ``lengths`` nor
            ``offsets`` is given. Validation helpers may raise as well.
    """
    _check_reduce_and_dtype(data, reduce)
    axis = _wrap_axis(axis, data.dim())
    if indices is not None:
        raise RuntimeError(
            "segment_reduce(): indices based reduction is not supported yet."
        )

    if offsets is None:
        if lengths is None:
            raise RuntimeError(
                "segment_reduce(): Either lengths or offsets must be defined."
            )
        _validate_lengths(data, lengths, axis, unsafe)
        lengths_contig = lengths.contiguous()
        derived_offsets = _lengths_to_offsets(lengths_contig)
        shape_from_lengths = lengths_contig.shape + data.shape[axis + 1 :]
        return axis, derived_offsets, shape_from_lengths, False

    # Offsets take precedence over lengths when both are supplied.
    _check_index_tensor(data, offsets, "offsets", axis)
    offsets_contig = offsets.contiguous()
    num_segments = offsets_contig.shape[-1] - 1
    shape_from_offsets = (
        offsets_contig.shape[:-1] + (num_segments,) + data.shape[axis + 1 :]
    )
    return axis, offsets_contig, shape_from_offsets, True
def segment_reduce(
    data,
    reduce,
    *,
    lengths=None,
    indices=None,
    offsets=None,
    axis=0,
    unsafe=False,
    initial=None,
):
    """Reduce ``data`` segment-wise along ``axis``.

    Args:
        data: floating-point tensor to reduce.
        reduce: one of "sum", "mean", "max", "min", "prod".
        lengths: segment lengths (int32/int64); used when ``offsets`` is None.
        indices: not supported; must be None.
        offsets: segment boundaries (int32/int64); takes precedence over
            ``lengths``.
        axis: reduction axis; must be the last dimension of lengths/offsets.
        unsafe: skip the value checks on ``lengths``.
        initial: optional starting value of each segment's accumulator.

    Returns:
        A new tensor holding one reduced value per segment.

    Raises:
        RuntimeError: for an unknown ``reduce``, a given ``indices``, missing
            lengths/offsets or invalid lengths.
        NotImplementedError: for unsupported data or index dtypes.
        IndexError: for an out-of-range ``axis``.
    """
    logger.debug("GEMS SEGMENT_REDUCE")
    _check_reduce_and_dtype(data, reduce)
    axis = _wrap_axis(axis, data.dim())
    if indices is not None:
        raise RuntimeError(
            "segment_reduce(): indices based reduction is not supported yet."
        )

    data_contig = None
    if initial is None and lengths is not None and offsets is None:
        _check_index_tensor(data, lengths, "lengths", axis)
        if _is_unit_lengths(data, lengths, axis):
            # Every segment is a single element, so the result equals the
            # input. Return a copy: data.contiguous() would hand back the
            # input tensor itself when it is already contiguous, making the
            # output alias the input.
            return data.clone(memory_format=torch.contiguous_format)

        data_contig = data.contiguous()
        uniform_result = _segment_reduce_uniform_lengths(
            data_contig, reduce, lengths, axis
        )
        if uniform_result is not None:
            return uniform_result

    axis, offsets_contig, output_shape, _ = _prepare_common(
        data, reduce, lengths, offsets, indices, axis, unsafe
    )

    # Reuse the contiguous copy made for the fast path, if any.
    if data_contig is None:
        data_contig = data.contiguous()
    output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
    if output.numel() == 0:
        return output

    segment_count = output_shape[axis]
    inner_size = _prod(data_contig.shape[axis + 1 :])
    data_size_axis = data_contig.shape[axis]
    has_initial, initial_value = _make_initial(reduce, initial)
    # One program per output element.
    grid = (output.numel(),)

    with torch_device_fn.device(data.device):
        _segment_reduce_forward_kernel[grid](
            data_contig,
            offsets_contig,
            output,
            segment_count,
            inner_size,
            data_size_axis,
            reduce == "sum",
            reduce == "mean",
            reduce == "max",
            reduce == "min",
            reduce == "prod",
            has_initial,
            initial_value,
            BLOCK_SIZE=_get_block_size(data.device),
        )
    return output
def segment_reduce_out(
    data,
    reduce,
    *,
    lengths=None,
    indices=None,
    offsets=None,
    axis=0,
    unsafe=False,
    initial=None,
    out,
):
    """Out-variant of ``segment_reduce``: compute the result and copy it into ``out``.

    ``out`` is resized first when its shape differs from the result's.

    Returns:
        ``out``.
    """
    logger.debug("GEMS SEGMENT_REDUCE_OUT")
    forwarded = dict(
        lengths=lengths,
        indices=indices,
        offsets=offsets,
        axis=axis,
        unsafe=unsafe,
        initial=initial,
    )
    result = segment_reduce(data, reduce, **forwarded)
    if out.shape != result.shape:
        out.resize_(result.shape)
    out.copy_(result)
    return out
def _segment_reduce_backward(
    grad,
    output,
    data,
    reduce,
    *,
    lengths=None,
    offsets=None,
    axis=0,
    initial=None,
):
    """Compute the gradient of ``segment_reduce`` with respect to ``data``.

    Args:
        grad: gradient of the forward output.
        output: forward output.
        data: forward input.
        reduce: one of "sum", "mean", "max", "min", "prod".
        lengths: segment lengths; used when ``offsets`` is None.
        offsets: segment boundaries; takes precedence over ``lengths``.
        axis: reduction axis.
        initial: optional initial value used by the forward pass.

    Returns:
        A tensor shaped like ``data`` in ``grad``'s dtype.

    Raises:
        The same validation errors as the forward pass.
    """
    logger.debug("GEMS _SEGMENT_REDUCE_BACKWARD")

    # Contiguous copies are created at most once and shared between the
    # fast paths and the generic kernel.
    data_contig = None
    grad_contig = None
    output_contig = None

    if lengths is not None and offsets is None and reduce in _SUPPORTED_REDUCES:
        _check_reduce_and_dtype(data, reduce)
        axis = _wrap_axis(axis, data.dim())
        _check_index_tensor(data, lengths, "lengths", axis)

        # Segments of length 1 without an initial value: the gradient passes
        # straight through.
        # NOTE(review): this returns grad itself when it is already
        # contiguous; confirm that aliasing is acceptable to callers.
        if initial is None and _is_unit_lengths(data, lengths, axis):
            return grad.contiguous()

        data_contig = data.contiguous()
        grad_contig = grad.contiguous()
        if reduce in ("sum", "mean"):
            uniform_result = _segment_reduce_uniform_sum_mean_backward(
                data_contig, grad_contig, reduce, lengths, axis
            )
        else:
            output_contig = output.contiguous()
            uniform_result = _segment_reduce_uniform_other_backward(
                data_contig, output_contig, grad_contig, reduce, lengths, axis, initial
            )
        if uniform_result is not None:
            return uniform_result

    # Generic path; lengths are not value-checked again (unsafe=True).
    axis, offsets_contig, output_shape, _ = _prepare_common(
        data, reduce, lengths, offsets, None, axis, True
    )
    if data_contig is None:
        data_contig = data.contiguous()
    if grad_contig is None:
        grad_contig = grad.contiguous()
    if output_contig is None:
        output_contig = output.contiguous()
    # Zero-filled because the kernel skips empty segments and, for max/min,
    # writes only the matching positions.
    grad_input = torch.zeros(data_contig.shape, dtype=grad.dtype, device=grad.device)

    if output_contig.numel() == 0:
        return grad_input

    segment_count = output_shape[axis]
    inner_size = _prod(data_contig.shape[axis + 1 :])
    data_size_axis = data_contig.shape[axis]
    _, initial_prod_value = _make_initial("prod", initial)
    # One program per output element.
    grid = (output_contig.numel(),)

    with torch_device_fn.device(data.device):
        _segment_reduce_backward_kernel[grid](
            grad_contig,
            output_contig,
            data_contig,
            offsets_contig,
            grad_input,
            segment_count,
            inner_size,
            data_size_axis,
            reduce == "sum",
            reduce == "mean",
            reduce == "max",
            reduce == "min",
            reduce == "prod",
            initial_prod_value,
            BLOCK_SIZE=_get_block_size(data.device),
        )
    return grad_input
def _segment_reduce_backward_out(
    grad,
    output,
    data,
    reduce,
    *,
    lengths=None,
    offsets=None,
    axis=0,
    initial=None,
    out,
):
    """Out-variant of ``_segment_reduce_backward``: compute the gradient and copy it into ``out``.

    ``out`` is resized first when its shape differs from the result's.

    Returns:
        ``out``.
    """
    logger.debug("GEMS _SEGMENT_REDUCE_BACKWARD_OUT")
    forwarded = dict(lengths=lengths, offsets=offsets, axis=axis, initial=initial)
    result = _segment_reduce_backward(grad, output, data, reduce, **forwarded)
    if out.shape != result.shape:
        out.resize_(result.shape)
    out.copy_(result)
    return out