Coverage for src/flag_gems/ops/__init__.py: 100%

175 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-03-18 02:36 +0800

1from flag_gems.ops.abs import abs, abs_ 

2from flag_gems.ops.acos import acos 

3from flag_gems.ops.add import add, add_ 

4from flag_gems.ops.addcdiv import addcdiv 

5from flag_gems.ops.addcmul import addcmul 

6from flag_gems.ops.addmm import addmm, addmm_out 

7from flag_gems.ops.addmv import addmv, addmv_out 

8from flag_gems.ops.addr import addr 

9from flag_gems.ops.all import all, all_dim, all_dims 

10from flag_gems.ops.amax import amax 

11from flag_gems.ops.angle import angle 

12from flag_gems.ops.any import any, any_dim, any_dims 

13from flag_gems.ops.arange import arange, arange_start 

14from flag_gems.ops.argmax import argmax 

15from flag_gems.ops.argmin import argmin 

16from flag_gems.ops.atan import atan, atan_ 

17from flag_gems.ops.attention import ( 

18 ScaleDotProductAttention, 

19 flash_attention_forward, 

20 flash_attn_varlen_func, 

21 scaled_dot_product_attention, 

22 scaled_dot_product_attention_backward, 

23 scaled_dot_product_attention_forward, 

24) 

25from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward 

26from flag_gems.ops.baddbmm import baddbmm 

27from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward 

28from flag_gems.ops.bitwise_and import ( 

29 bitwise_and_scalar, 

30 bitwise_and_scalar_, 

31 bitwise_and_scalar_tensor, 

32 bitwise_and_tensor, 

33 bitwise_and_tensor_, 

34) 

35from flag_gems.ops.bitwise_left_shift import bitwise_left_shift 

36from flag_gems.ops.bitwise_not import bitwise_not, bitwise_not_ 

37from flag_gems.ops.bitwise_or import ( 

38 bitwise_or_scalar, 

39 bitwise_or_scalar_, 

40 bitwise_or_scalar_tensor, 

41 bitwise_or_tensor, 

42 bitwise_or_tensor_, 

43) 

44from flag_gems.ops.bitwise_right_shift import bitwise_right_shift 

45from flag_gems.ops.bmm import bmm, bmm_out 

46from flag_gems.ops.cat import cat 

47from flag_gems.ops.ceil import ceil, ceil_, ceil_out 

48from flag_gems.ops.celu import celu, celu_ 

49from flag_gems.ops.clamp import ( 

50 clamp, 

51 clamp_, 

52 clamp_min, 

53 clamp_min_, 

54 clamp_tensor, 

55 clamp_tensor_, 

56) 

57from flag_gems.ops.contiguous import contiguous 

58from flag_gems.ops.conv1d import conv1d 

59from flag_gems.ops.conv2d import conv2d 

60from flag_gems.ops.conv3d import conv3d 

61from flag_gems.ops.conv_depthwise2d import _conv_depthwise2d 

62from flag_gems.ops.copy import copy, copy_ 

63from flag_gems.ops.cos import cos, cos_ 

64from flag_gems.ops.count_nonzero import count_nonzero 

65from flag_gems.ops.cummax import cummax 

66from flag_gems.ops.cummin import cummin 

67from flag_gems.ops.cumsum import cumsum, cumsum_out, normed_cumsum 

68from flag_gems.ops.diag import diag 

69from flag_gems.ops.diag_embed import diag_embed 

70from flag_gems.ops.diagonal import diagonal_backward 

71from flag_gems.ops.div import ( 

72 div_mode, 

73 div_mode_, 

74 floor_divide, 

75 floor_divide_, 

76 remainder, 

77 remainder_, 

78 true_divide, 

79 true_divide_, 

80 true_divide_out, 

81) 

82from flag_gems.ops.dot import dot 

83from flag_gems.ops.dropout import dropout, dropout_backward 

84from flag_gems.ops.elu import elu, elu_, elu_backward 

85from flag_gems.ops.embedding import embedding, embedding_backward 

86from flag_gems.ops.eq import eq, eq_scalar, equal 

87from flag_gems.ops.erf import erf, erf_ 

88from flag_gems.ops.exp import exp, exp_, exp_out 

89from flag_gems.ops.exp2 import exp2, exp2_ 

90from flag_gems.ops.exponential_ import exponential_ 

91from flag_gems.ops.eye import eye 

92from flag_gems.ops.eye_m import eye_m 

93from flag_gems.ops.fill import ( 

94 fill_scalar, 

95 fill_scalar_, 

96 fill_scalar_out, 

97 fill_tensor, 

98 fill_tensor_, 

99 fill_tensor_out, 

100) 

101from flag_gems.ops.flip import flip 

102from flag_gems.ops.full import full 

103from flag_gems.ops.full_like import full_like 

104from flag_gems.ops.gather import gather, gather_backward 

105from flag_gems.ops.ge import ge, ge_scalar 

106from flag_gems.ops.gelu import gelu, gelu_, gelu_backward 

107from flag_gems.ops.get_scheduler_metadata import get_scheduler_metadata 

108from flag_gems.ops.glu import glu, glu_backward 

109from flag_gems.ops.groupnorm import group_norm, group_norm_backward 

110from flag_gems.ops.gt import gt, gt_scalar 

111from flag_gems.ops.hstack import hstack 

112from flag_gems.ops.index import index 

113from flag_gems.ops.index_add import index_add, index_add_ 

114from flag_gems.ops.index_put import index_put, index_put_ 

115from flag_gems.ops.index_select import index_select 

116from flag_gems.ops.isclose import allclose, isclose 

117from flag_gems.ops.isfinite import isfinite 

118from flag_gems.ops.isin import isin 

119from flag_gems.ops.isinf import isinf 

120from flag_gems.ops.isnan import isnan 

121from flag_gems.ops.kron import kron 

122from flag_gems.ops.layernorm import layer_norm, layer_norm_backward 

123from flag_gems.ops.le import le, le_scalar 

124from flag_gems.ops.lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_ 

125from flag_gems.ops.linspace import linspace 

126from flag_gems.ops.log import log 

127from flag_gems.ops.log_sigmoid import log_sigmoid 

128from flag_gems.ops.log_softmax import log_softmax, log_softmax_backward 

129from flag_gems.ops.logical_and import logical_and, logical_and_ 

130from flag_gems.ops.logical_not import logical_not 

131from flag_gems.ops.logical_or import logical_or, logical_or_ 

132from flag_gems.ops.logical_xor import logical_xor 

133from flag_gems.ops.logspace import logspace 

134from flag_gems.ops.lt import lt, lt_scalar 

135from flag_gems.ops.masked_fill import masked_fill, masked_fill_ 

136from flag_gems.ops.masked_scatter import masked_scatter, masked_scatter_ 

137from flag_gems.ops.masked_select import masked_select 

138from flag_gems.ops.max import max, max_dim 

139from flag_gems.ops.max_pool2d_with_indices import ( 

140 max_pool2d_backward, 

141 max_pool2d_with_indices, 

142) 

143from flag_gems.ops.maximum import maximum 

144from flag_gems.ops.mean import mean, mean_dim 

145from flag_gems.ops.min import min, min_dim 

146from flag_gems.ops.minimum import minimum 

147from flag_gems.ops.mm import mm, mm_out 

148from flag_gems.ops.mse_loss import mse_loss 

149from flag_gems.ops.mul import mul, mul_ 

150from flag_gems.ops.multinomial import multinomial 

151from flag_gems.ops.mv import mv 

152from flag_gems.ops.nan_to_num import nan_to_num 

153from flag_gems.ops.ne import ne, ne_scalar 

154from flag_gems.ops.neg import neg, neg_ 

155from flag_gems.ops.nllloss import ( 

156 nll_loss2d_backward, 

157 nll_loss2d_forward, 

158 nll_loss_backward, 

159 nll_loss_forward, 

160) 

161from flag_gems.ops.nonzero import nonzero 

162from flag_gems.ops.normal import ( 

163 normal_, 

164 normal_float_tensor, 

165 normal_tensor_float, 

166 normal_tensor_tensor, 

167) 

168from flag_gems.ops.one_hot import one_hot 

169from flag_gems.ops.ones import ones 

170from flag_gems.ops.ones_like import ones_like 

171from flag_gems.ops.pad import constant_pad_nd, pad 

172from flag_gems.ops.per_token_group_quant_fp8 import ( 

173 SUPPORTED_FP8_DTYPE, 

174 per_token_group_quant_fp8, 

175) 

176from flag_gems.ops.polar import polar 

177from flag_gems.ops.pow import ( 

178 pow_scalar, 

179 pow_tensor_scalar, 

180 pow_tensor_scalar_, 

181 pow_tensor_tensor, 

182 pow_tensor_tensor_, 

183) 

184from flag_gems.ops.prod import prod, prod_dim 

185from flag_gems.ops.quantile import quantile 

186from flag_gems.ops.rand import rand 

187from flag_gems.ops.rand_like import rand_like 

188from flag_gems.ops.randn import randn 

189from flag_gems.ops.randn_like import randn_like 

190from flag_gems.ops.randperm import randperm 

191from flag_gems.ops.reciprocal import reciprocal, reciprocal_ 

192from flag_gems.ops.relu import relu, relu_ 

193from flag_gems.ops.repeat import repeat 

194from flag_gems.ops.repeat_interleave import ( 

195 repeat_interleave_self_int, 

196 repeat_interleave_self_tensor, 

197 repeat_interleave_tensor, 

198) 

199from flag_gems.ops.replication_pad3d import replication_pad3d 

200from flag_gems.ops.resolve_conj import resolve_conj 

201from flag_gems.ops.resolve_neg import resolve_neg 

202from flag_gems.ops.rms_norm import rms_norm, rms_norm_backward, rms_norm_forward 

203from flag_gems.ops.rsqrt import rsqrt, rsqrt_ 

204from flag_gems.ops.scaled_softmax import scaled_softmax_backward, scaled_softmax_forward 

205from flag_gems.ops.scatter import scatter, scatter_ 

206from flag_gems.ops.scatter_add_ import scatter_add_ 

207from flag_gems.ops.select_scatter import select_scatter 

208from flag_gems.ops.sigmoid import sigmoid, sigmoid_, sigmoid_backward 

209from flag_gems.ops.silu import silu, silu_, silu_backward 

210from flag_gems.ops.sin import sin, sin_ 

211from flag_gems.ops.slice_scatter import slice_scatter 

212from flag_gems.ops.softmax import softmax, softmax_backward 

213from flag_gems.ops.softplus import softplus 

214from flag_gems.ops.sort import sort, sort_stable 

215from flag_gems.ops.sqrt import sqrt, sqrt_ 

216from flag_gems.ops.stack import stack 

217from flag_gems.ops.std import std 

218from flag_gems.ops.sub import sub, sub_ 

219from flag_gems.ops.sum import sum, sum_dim, sum_dim_out, sum_out 

220from flag_gems.ops.tan import tan, tan_ 

221from flag_gems.ops.tanh import tanh, tanh_, tanh_backward 

222from flag_gems.ops.threshold import threshold, threshold_backward 

223from flag_gems.ops.tile import tile 

224from flag_gems.ops.to import to_copy 

225from flag_gems.ops.topk import topk 

226from flag_gems.ops.trace import trace 

227from flag_gems.ops.triu import triu, triu_ 

228from flag_gems.ops.unfold_backward import unfold_backward 

229from flag_gems.ops.uniform import uniform_ 

230from flag_gems.ops.unique import _unique2 

231from flag_gems.ops.upsample_bicubic2d_aa import _upsample_bicubic2d_aa 

232from flag_gems.ops.upsample_linear1d import upsample_linear1d 

233from flag_gems.ops.upsample_nearest1d import upsample_nearest1d 

234from flag_gems.ops.upsample_nearest2d import upsample_nearest2d 

235from flag_gems.ops.upsample_nearest3d import upsample_nearest3d 

236from flag_gems.ops.var_mean import var_mean 

237from flag_gems.ops.vdot import vdot 

238from flag_gems.ops.vector_norm import vector_norm 

239from flag_gems.ops.vstack import vstack 

240from flag_gems.ops.weightnorm import ( 

241 weight_norm_interface, 

242 weight_norm_interface_backward, 

243) 

244from flag_gems.ops.where import ( 

245 where_scalar_other, 

246 where_scalar_self, 

247 where_self, 

248 where_self_out, 

249) 

250from flag_gems.ops.zeros import zero_, zeros 

251from flag_gems.ops.zeros_like import zeros_like 

252 

253__all__ = [ 

254 "_conv_depthwise2d", 

255 "_unique2", 

256 "_upsample_bicubic2d_aa", 

257 "abs", 

258 "abs_", 

259 "acos", 

260 "add", 

261 "add_", 

262 "addcdiv", 

263 "addcmul", 

264 "addmm", 

265 "addmm_out", 

266 "addmv", 

267 "addmv_out", 

268 "addr", 

269 "all", 

270 "all_dim", 

271 "all_dims", 

272 "allclose", 

273 "amax", 

274 "angle", 

275 "any", 

276 "any_dim", 

277 "any_dims", 

278 "arange", 

279 "arange_start", 

280 "argmax", 

281 "argmin", 

282 "atan", 

283 "atan_", 

284 "avg_pool2d", 

285 "avg_pool2d_backward", 

286 "baddbmm", 

287 "batch_norm", 

288 "batch_norm_backward", 

289 "bitwise_and_scalar", 

290 "bitwise_and_scalar_", 

291 "bitwise_and_scalar_tensor", 

292 "bitwise_and_tensor", 

293 "bitwise_and_tensor_", 

294 "bitwise_left_shift", 

295 "bitwise_not", 

296 "bitwise_not_", 

297 "bitwise_or_scalar", 

298 "bitwise_or_scalar_", 

299 "bitwise_or_scalar_tensor", 

300 "bitwise_or_tensor", 

301 "bitwise_or_tensor_", 

302 "bitwise_right_shift", 

303 "bmm", 

304 "bmm_out", 

305 "cat", 

306 "ceil", 

307 "ceil_", 

308 "ceil_out", 

309 "celu", 

310 "celu_", 

311 "clamp", 

312 "clamp_", 

313 "clamp_min", 

314 "clamp_min_", 

315 "clamp_tensor", 

316 "clamp_tensor_", 

317 "constant_pad_nd", 

318 "contiguous", 

319 "conv1d", 

320 "conv2d", 

321 "conv3d", 

322 "copy", 

323 "copy_", 

324 "cos", 

325 "cos_", 

326 "count_nonzero", 

327 "cummax", 

328 "cummin", 

329 "cumsum", 

330 "cumsum_out", 

331 "diag", 

332 "diag_embed", 

333 "diagonal_backward", 

334 "div_mode", 

335 "div_mode_", 

336 "dot", 

337 "dropout", 

338 "dropout_backward", 

339 "elu", 

340 "elu_", 

341 "elu_backward", 

342 "embedding", 

343 "embedding_backward", 

344 "eq", 

345 "eq_scalar", 

346 "equal", 

347 "erf", 

348 "erf_", 

349 "exp", 

350 "exp_", 

351 "exp_out", 

352 "exp2", 

353 "exp2_", 

354 "exponential_", 

355 "eye", 

356 "eye_m", 

357 "fill_scalar", 

358 "fill_scalar_", 

359 "fill_scalar_out", 

360 "fill_tensor", 

361 "fill_tensor_", 

362 "fill_tensor_out", 

363 "flash_attention_forward", 

364 "flash_attn_varlen_func", 

365 "flip", 

366 "floor_divide", 

367 "floor_divide_", 

368 "full", 

369 "full_like", 

370 "gather", 

371 "gather_backward", 

372 "ge", 

373 "ge_scalar", 

374 "gelu", 

375 "gelu_", 

376 "gelu_backward", 

377 "get_scheduler_metadata", 

378 "glu", 

379 "glu_backward", 

380 "group_norm", 

381 "group_norm_backward", 

382 "gt", 

383 "gt_scalar", 

384 "hstack", 

385 "index", 

386 "index_add", 

387 "index_add_", 

388 "index_put", 

389 "index_put_", 

390 "index_select", 

391 "isclose", 

392 "isfinite", 

393 "isin", 

394 "isinf", 

395 "isnan", 

396 "kron", 

397 "layer_norm", 

398 "layer_norm_backward", 

399 "le", 

400 "le_scalar", 

401 "lerp_scalar", 

402 "lerp_scalar_", 

403 "lerp_tensor", 

404 "lerp_tensor_", 

405 "linspace", 

406 "log", 

407 "log_sigmoid", 

408 "log_softmax", 

409 "log_softmax_backward", 

410 "logical_and", 

411 "logical_and_", 

412 "logical_not", 

413 "logical_or", 

414 "logical_or_", 

415 "logical_xor", 

416 "logspace", 

417 "lt", 

418 "lt_scalar", 

419 "masked_fill", 

420 "masked_fill_", 

421 "masked_scatter", 

422 "masked_scatter_", 

423 "masked_select", 

424 "max", 

425 "max_dim", 

426 "max_pool2d_with_indices", 

427 "max_pool2d_backward", 

428 "maximum", 

429 "mean", 

430 "mean_dim", 

431 "min", 

432 "min_dim", 

433 "minimum", 

434 "mm", 

435 "mm_out", 

436 "mse_loss", 

437 "mul", 

438 "mul_", 

439 "multinomial", 

440 "mv", 

441 "nan_to_num", 

442 "ne", 

443 "ne_scalar", 

444 "neg", 

445 "neg_", 

446 "nll_loss_backward", 

447 "nll_loss_forward", 

448 "nll_loss2d_backward", 

449 "nll_loss2d_forward", 

450 "nonzero", 

451 "normal_float_tensor", 

452 "normal_tensor_float", 

453 "normal_tensor_tensor", 

454 "normal_", 

455 "normed_cumsum", 

456 "ones", 

457 "ones_like", 

458 "one_hot", 

459 "pad", 

460 "per_token_group_quant_fp8", 

461 "polar", 

462 "pow_scalar", 

463 "pow_tensor_scalar", 

464 "pow_tensor_scalar_", 

465 "pow_tensor_tensor", 

466 "pow_tensor_tensor_", 

467 "prod", 

468 "prod_dim", 

469 "quantile", 

470 "rand", 

471 "rand_like", 

472 "randn", 

473 "randn_like", 

474 "randperm", 

475 "reciprocal", 

476 "reciprocal_", 

477 "relu", 

478 "relu_", 

479 "remainder", 

480 "remainder_", 

481 "repeat", 

482 "repeat_interleave_self_int", 

483 "repeat_interleave_self_tensor", 

484 "repeat_interleave_tensor", 

485 "replication_pad3d", 

486 "resolve_conj", 

487 "resolve_neg", 

488 "rms_norm", 

489 "rms_norm_backward", 

490 "rms_norm_forward", 

491 "rsqrt", 

492 "rsqrt_", 

493 "scaled_dot_product_attention", 

494 "scaled_dot_product_attention_backward", 

495 "scaled_dot_product_attention_forward", 

496 "scaled_softmax_backward", 

497 "scaled_softmax_forward", 

498 "scatter", 

499 "scatter_", 

500 "scatter_add_", 

501 "select_scatter", 

502 "sigmoid", 

503 "sigmoid_", 

504 "sigmoid_backward", 

505 "silu", 

506 "silu_", 

507 "silu_backward", 

508 "sin", 

509 "sin_", 

510 "slice_scatter", 

511 "softmax", 

512 "softmax_backward", 

513 "softplus", 

514 "sort", 

515 "sort_stable", 

516 "sqrt", 

517 "sqrt_", 

518 "stack", 

519 "std", 

520 "sub", 

521 "sub_", 

522 "sum", 

523 "sum_dim", 

524 "sum_dim_out", 

525 "sum_out", 

526 "ScaleDotProductAttention", 

527 "SUPPORTED_FP8_DTYPE", 

528 "tan", 

529 "tan_", 

530 "tanh", 

531 "tanh_", 

532 "tanh_backward", 

533 "threshold", 

534 "threshold_backward", 

535 "tile", 

536 "to_copy", 

537 "topk", 

538 "trace", 

539 "triu", 

540 "triu_", 

541 "true_divide", 

542 "true_divide_", 

543 "true_divide_out", 

544 "unfold_backward", 

545 "uniform_", 

546 "upsample_linear1d", 

547 "upsample_nearest1d", 

548 "upsample_nearest2d", 

549 "upsample_nearest3d", 

550 "var_mean", 

551 "vdot", 

552 "vector_norm", 

553 "vstack", 

554 "weight_norm_interface", 

555 "weight_norm_interface_backward", 

556 "where_scalar_other", 

557 "where_scalar_self", 

558 "where_self", 

559 "where_self_out", 

560 "zeros", 

561 "zero_", 

562 "zeros_like", 

563]