diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h
index e0b93886ee..77c4c8abdd 100644
--- a/src/TiledArray/einsum/tiledarray.h
+++ b/src/TiledArray/einsum/tiledarray.h
@@ -642,7 +642,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
 
   }
   if (!e) {  // hadamard reduction
-    auto &[A, B] = AB;
     TiledRange trange(range_map[i]);
     RangeProduct tiles;
 
@@ -685,75 +684,147 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
                    : element_contract_op.value()(l, r);
         };
 
-    auto pa = A.permutation;
-    auto pb = B.permutation;
+    auto const pa = A.permutation;
+    auto const pb = B.permutation;
+    auto const pc = C.permutation;
+
+    // Each H-tile iteration produces an independent output tile, so the
+    // loop is parallel-safe. Dispatch per-H-tile work to the MADNESS task
+    // queue; pre-size a per-slot result vector so tasks write their own
+    // slot without synchronization, and gather before exiting scope so
+    // captured references stay alive for the task lifetime.
+    //
+    // Input tiles must be resolved BEFORE submitting tasks: calling
+    // .get() on an unready future from inside a task body is unsafe
+    // with the PaRSEC backend ("recursive call to wait"). We instead
+    // issue all find()s up-front (non-blocking, so requests overlap)
+    // and materialize them on the main thread; the submitted tasks
+    // then operate purely on local data.
+    using ATile = typename ArrayA::value_type;
+    using BTile = typename ArrayB::value_type;
+
+    // Per-H-tile job metadata shared by both the in-flight (futures)
+    // representation and the materialized (resolved tiles) representation.
+    struct HJobMeta {
+      Index h;         // H-tile coord in A/B's annotation
+      Index c_target;  // C-tile coord (= apply(pc, h))
+      size_t batch;    // product of H.batch sizes for this h
+    };
+    struct PendingHJob : HJobMeta {
+      std::vector<
+          std::tuple<Index, madness::Future<ATile>, madness::Future<BTile>>>
+          inputs;
+    };
+    struct HJob : HJobMeta {
+      // (i_index, ai, bi) for each non-zero input pair contributing to h
+      std::vector<std::tuple<Index, ATile, BTile>> inputs;
+    };
+
+    // Phase 1: issue all find() calls (non-blocking) so remote requests
+    // are in flight concurrently; collect futures + metadata.
+    std::vector<PendingHJob> pending_jobs;
     for (Index h : H.tiles) {
-      auto const pc = C.permutation;
-      auto const c = apply(pc, h);
-      if (!C.array.is_local(c)) continue;
-      size_t batch = 1;
-      for (size_t i = 0; i < h.size(); ++i) {
-        batch *= H.batch[i].at(h[i]);
+      auto const c_target = apply(pc, h);
+      if (!C.array.is_local(c_target)) continue;
+      PendingHJob pj;
+      pj.h = h;
+      pj.c_target = c_target;
+      pj.batch = 1;
+      for (size_t hi = 0; hi < h.size(); ++hi) {
+        pj.batch *= H.batch[hi].at(h[hi]);
       }
-      ResultTensor tile(TiledArray::Range{batch},
-                        typename ResultTensor::value_type{});
       for (Index i : tiles) {
-        // skip this unless both input tiles exist
        const auto pahi_inv = apply_inverse(pa, h + i);
        const auto pbhi_inv = apply_inverse(pb, h + i);
        if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+        pj.inputs.emplace_back(i, A.array.find(pahi_inv),
+                               B.array.find(pbhi_inv));
+      }
+      pending_jobs.push_back(std::move(pj));
+    }
+
+    // Phase 2: materialize all input tiles on the main thread.
+    auto materialize = [](PendingHJob &&pj) -> HJob {
+      HJob job;
+      static_cast<HJobMeta &>(job) = static_cast<HJobMeta &&>(pj);
+      job.inputs.reserve(pj.inputs.size());
+      for (auto &[i, fa, fb] : pj.inputs)
+        job.inputs.emplace_back(std::move(i), fa.get(), fb.get());
+      return job;
+    };
+    std::vector<HJob> jobs;
+    jobs.reserve(pending_jobs.size());
+    for (auto &pj : pending_jobs) jobs.push_back(materialize(std::move(pj)));
+    pending_jobs.clear();
+
+    // Single inner-product element op subsuming the three template arms
+    // (ToT-ToT, mixed T/ToT, plain T). Writes one batch element of `out`
+    // from one batch slice each of A and B.
+    auto kelement_op = [&](auto &out, auto const &aik, auto const &bik) {
+      if constexpr (AreArrayToT) {
+        auto vol = aik.total_size();
+        TA_ASSERT(vol == bik.total_size());
+        for (auto ii = 0; ii < vol; ++ii)
+          out.add_to(element_product_op(aik.data()[ii], bik.data()[ii]));
+      } else if constexpr (!AreArraySame) {
+        auto vol = aik.total_size();
+        TA_ASSERT(vol == bik.total_size());
+        for (auto ii = 0; ii < vol; ++ii) {
+          if constexpr (IsArrayToT) {
+            out.add_to(aik.data()[ii].scale(bik.data()[ii]));
+          } else {
+            out.add_to(bik.data()[ii].scale(aik.data()[ii]));
+          }
+        }
+      } else {
+        out += aik.dot(bik);
+      }
+    };
+
+    std::vector<std::pair<Index, ResultTensor>> h_results(jobs.size());
 
-        auto ai = A.array.find(pahi_inv).get();
-        auto bi = B.array.find(pbhi_inv).get();
+    // per_h_work: process jobs[slot] and write into h_results[slot].
+    // Captures listed explicitly so the lifetime contract is checkable:
+    // every captured reference must outlive the task queue (gathered
+    // below before this scope exits).
+    auto per_h_work = [&jobs, &trange, &h_results, &kelement_op, pa, pb, pc,
+                       &C](size_t slot) -> bool {
+      auto const &job = jobs[slot];
+      size_t batch = job.batch;
+      ResultTensor tile(TiledArray::Range{batch},
+                        typename ResultTensor::value_type{});
+      for (auto const &[i, ai_in, bi_in] : job.inputs) {
+        ATile ai = ai_in;
+        BTile bi = bi_in;
         if (pa) ai = ai.permute(pa);
         if (pb) bi = bi.permute(pb);
         auto shape = trange.tile(i);
         ai = ai.reshape(shape, batch);
         bi = bi.reshape(shape, batch);
         for (size_t k = 0; k < batch; ++k) {
-          using Ix = ::Einsum::Index<size_t>;
-          if constexpr (AreArrayToT) {
-            auto aik = ai.batch(k);
-            auto bik = bi.batch(k);
-            auto vol = aik.total_size();
-            TA_ASSERT(vol == bik.total_size());
-
-            auto &el = tile({k});
-            using TensorT = std::remove_reference_t<decltype(el)>;
-
-            for (auto i = 0; i < vol; ++i)
-              el.add_to(element_product_op(aik.data()[i], bik.data()[i]));
-
-          } else if constexpr (!AreArraySame) {
-            auto aik = ai.batch(k);
-            auto bik = bi.batch(k);
-            auto vol = aik.total_size();
-            TA_ASSERT(vol == bik.total_size());
-
-            auto &el = tile({k});
-
-            for (auto i = 0; i < vol; ++i)
-              if constexpr (IsArrayToT) {
-                el.add_to(aik.data()[i].scale(bik.data()[i]));
-              } else {
-                el.add_to(bik.data()[i].scale(aik.data()[i]));
-              }
-
-          } else {
-            auto hk = ai.batch(k).dot(bi.batch(k));
-            tile({k}) += hk;
-          }
+          auto &el = tile({k});
+          kelement_op(el, ai.batch(k), bi.batch(k));
         }
       }
       // data is stored as h1 h2 ... but all modes folded as 1 batch dim
       // first reshape to h = (h1 h2 ...)
       // n.b. can't just use shape = C.array.trange().tile(h)
-      auto shape = apply_inverse(pc, C.array.trange().tile(c));
+      auto shape = apply_inverse(pc, C.array.trange().tile(job.c_target));
       tile = tile.reshape(shape);
       // then permute to target C layout c = (c1 c2 ...)
       if (pc) tile = tile.permute(pc);
-      // and move to C_local_tiles
-      C_local_tiles.emplace_back(std::move(c), std::move(tile));
+      h_results[slot] = {job.c_target, std::move(tile)};
+      return true;
+    };
+
+    std::vector<madness::Future<bool>> h_futures;
+    h_futures.reserve(jobs.size());
+    for (size_t slot = 0; slot < jobs.size(); ++slot) {
+      h_futures.push_back(world.taskq.add(per_h_work, slot));
+    }
+    for (auto &fut : h_futures) fut.get();
+    for (auto &r : h_results) {
+      C_local_tiles.emplace_back(std::move(r.first), std::move(r.second));
     }
 
     build_C_array();
@@ -809,17 +880,36 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
       term.local_tiles.clear();
       const Permutation &P = term.permutation;
 
+      using TileType =
+          typename std::decay_t<decltype(term.array)>::value_type;
+      std::vector<std::pair<Index, madness::Future<TileType>>> pending;
+
       for (Index ei : term.tiles) {
         auto idx = apply_inverse(P, h + ei);
         if (!term.array.is_local(idx)) continue;
         if (term.array.is_zero(idx)) continue;
-        // TODO no need for immediate evaluation
-        auto tile = term.array.find_local(idx).get();
-        if (P) tile = tile.permute(P);
+
+        auto tile_future = term.array.find_local(idx);  // non-blocking
         auto shape = term.ei_tiled_range.tile(ei);
-        tile = tile.reshape(shape, batch);
-        term.local_tiles.push_back({ei, tile});
+
+        // Submit per-tile permute-and-reshape as a MADNESS task.
+        // Capture P by value (it's small) and shape by value.
+        madness::Future<TileType> permuted = owners->taskq.add(
+            [P, shape, batch](const TileType &tile) -> TileType {
+              TileType result = P ? tile.permute(P) : tile;
+              return result.reshape(shape, batch);
+            },
+            tile_future);
+
+        pending.emplace_back(ei, permuted);
+      }
+
+      // Wait for all per-tile tasks to complete, then gather into
+      // local_tiles.
+      for (auto &[ei, fut] : pending) {
+        term.local_tiles.push_back({ei, fut.get()});
       }
+
      bool replicated = term.array.pmap()->is_replicated();
      term.ei = TiledArray::make_array(
          *owners, term.ei_tiled_range, term.local_tiles.begin(),