Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 144 additions & 54 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,6 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
}

if (!e) { // hadamard reduction

auto &[A, B] = AB;
TiledRange trange(range_map[i]);
RangeProduct tiles;
Expand Down Expand Up @@ -685,75 +684,147 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
: element_contract_op.value()(l, r);
};

auto pa = A.permutation;
auto pb = B.permutation;
auto const pa = A.permutation;
auto const pb = B.permutation;
auto const pc = C.permutation;

// Each H-tile iteration produces an independent output tile, so the
// loop is parallel-safe. Dispatch per-H-tile work to the MADNESS task
// queue; pre-size a per-slot result vector so tasks write their own
// slot without synchronization, and gather before exiting scope so
// captured references stay alive for the task lifetime.
//
// Input tiles must be resolved BEFORE submitting tasks: calling
// .get() on an unready future from inside a task body is unsafe
// with the PaRSEC backend ("recursive call to wait"). We instead
// issue all find()s up-front (non-blocking, so requests overlap)
// and materialize them on the main thread; the submitted tasks
// then operate purely on local data.
using ATile = typename ArrayA::value_type;
using BTile = typename ArrayB::value_type;

// Per-H-tile job metadata shared by both the in-flight (futures)
// representation and the materialized (resolved tiles) representation.
struct HJobMeta {
Index h; // H-tile coord in A/B's annotation
Index c_target; // C-tile coord in C's layout (= apply(pc, h))
size_t batch; // product of H.batch sizes for this h (total batch count)
};
// In-flight form: each input pair is still a madness::Future issued by a
// non-blocking find(), so remote tile requests can overlap before any
// blocking .get() occurs.
struct PendingHJob : HJobMeta {
// (i_index, future<A-tile>, future<B-tile>) per contributing input pair
std::vector<
std::tuple<Index, madness::Future<ATile>, madness::Future<BTile>>>
inputs;
};
// Materialized form: every input future has been resolved to a concrete
// local tile, so task bodies consuming an HJob never have to block.
struct HJob : HJobMeta {
// (i_index, ai, bi) for each non-zero input pair contributing to h
std::vector<std::tuple<Index, ATile, BTile>> inputs;
};

// Phase 1: issue all find() calls (non-blocking) so remote requests
// are in flight concurrently; collect futures + metadata.
std::vector<PendingHJob> pending_jobs;
for (Index h : H.tiles) {
auto const pc = C.permutation;
auto const c = apply(pc, h);
if (!C.array.is_local(c)) continue;
size_t batch = 1;
for (size_t i = 0; i < h.size(); ++i) {
batch *= H.batch[i].at(h[i]);
auto const c_target = apply(pc, h);
if (!C.array.is_local(c_target)) continue;
PendingHJob pj;
pj.h = h;
pj.c_target = c_target;
pj.batch = 1;
for (size_t hi = 0; hi < h.size(); ++hi) {
pj.batch *= H.batch[hi].at(h[hi]);
}
ResultTensor tile(TiledArray::Range{batch},
typename ResultTensor::value_type{});
for (Index i : tiles) {
// skip this unless both input tiles exist
const auto pahi_inv = apply_inverse(pa, h + i);
const auto pbhi_inv = apply_inverse(pb, h + i);
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
pj.inputs.emplace_back(i, A.array.find(pahi_inv),
B.array.find(pbhi_inv));
}
pending_jobs.push_back(std::move(pj));
}

// Phase 2: materialize all input tiles on the main thread.
// Resolve one pending job's input futures into concrete tiles, producing
// an HJob. Called directly (not via the task queue), so the blocking
// .get() calls happen on the main thread — calling .get() on an unready
// future from inside a task body is unsafe with the PaRSEC backend.
auto materialize = [](PendingHJob &&pj) -> HJob {
HJob job;
// Copy the shared metadata (h, c_target, batch) via base-subobject assignment.
static_cast<HJobMeta &>(job) = static_cast<HJobMeta const &>(pj);
job.inputs.reserve(pj.inputs.size());
// .get() may block until the (possibly remote) tile arrives; requests
// were already issued up-front, so waits overlap rather than serialize.
for (auto &[i, fa, fb] : pj.inputs)
job.inputs.emplace_back(std::move(i), fa.get(), fb.get());
return job;
};
std::vector<HJob> jobs;
jobs.reserve(pending_jobs.size());
for (auto &pj : pending_jobs) jobs.push_back(materialize(std::move(pj)));
pending_jobs.clear();

// Single inner-product element op subsuming the three template arms
// (ToT-ToT, mixed T/ToT, plain T). Writes one batch element of `out`
// from one batch slice each of A and B.
// Accumulate into `out` the inner product of one batch slice of A (`aik`)
// and one batch slice of B (`bik`). The three compile-time arms cover:
// both arrays tensor-of-tensor, exactly one ToT, and both plain tensors.
auto kelement_op = [&](auto &out, auto const &aik, auto const &bik) {
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
// ToT x ToT: combine inner tensors pairwise with the user-supplied
// (or default) element_product_op, accumulating into `out`.
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());
for (auto ii = 0; ii < vol; ++ii)
out.add_to(element_product_op(aik.data()[ii], bik.data()[ii]));
} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
// Mixed T/ToT: scale each inner tensor of the ToT side by the
// corresponding scalar from the plain-tensor side.
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());
for (auto ii = 0; ii < vol; ++ii) {
if constexpr (IsArrayToT<ArrayA>) {
out.add_to(aik.data()[ii].scale(bik.data()[ii]));
} else {
out.add_to(bik.data()[ii].scale(aik.data()[ii]));
}
}
} else {
// Plain tensors: a straightforward dot product.
out += aik.dot(bik);
}
};

std::vector<std::pair<Index, ResultTensor>> h_results(jobs.size());

auto ai = A.array.find(pahi_inv).get();
auto bi = B.array.find(pbhi_inv).get();
// per_h_work: process jobs[slot] and write into h_results[slot].
// Captures listed explicitly so the lifetime contract is checkable:
// every captured reference must outlive the task queue (gathered
// below before this scope exits).
auto per_h_work = [&jobs, &trange, &h_results, &kelement_op, pa, pb, pc,
&C](size_t slot) -> bool {
auto const &job = jobs[slot];
size_t batch = job.batch;
ResultTensor tile(TiledArray::Range{batch},
typename ResultTensor::value_type{});
for (auto const &[i, ai_in, bi_in] : job.inputs) {
ATile ai = ai_in;
BTile bi = bi_in;
if (pa) ai = ai.permute(pa);
if (pb) bi = bi.permute(pb);
auto shape = trange.tile(i);
ai = ai.reshape(shape, batch);
bi = bi.reshape(shape, batch);
for (size_t k = 0; k < batch; ++k) {
using Ix = ::Einsum::Index<std::string>;
if constexpr (AreArrayToT<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});
using TensorT = std::remove_reference_t<decltype(el)>;

for (auto i = 0; i < vol; ++i)
el.add_to(element_product_op(aik.data()[i], bik.data()[i]));

} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
auto bik = bi.batch(k);
auto vol = aik.total_size();
TA_ASSERT(vol == bik.total_size());

auto &el = tile({k});

for (auto i = 0; i < vol; ++i)
if constexpr (IsArrayToT<ArrayA>) {
el.add_to(aik.data()[i].scale(bik.data()[i]));
} else {
el.add_to(bik.data()[i].scale(aik.data()[i]));
}

} else {
auto hk = ai.batch(k).dot(bi.batch(k));
tile({k}) += hk;
}
auto &el = tile({k});
kelement_op(el, ai.batch(k), bi.batch(k));
}
}
// data is stored as h1 h2 ... but all modes folded as 1 batch dim
// first reshape to h = (h1 h2 ...)
// n.b. can't just use shape = C.array.trange().tile(h)
auto shape = apply_inverse(pc, C.array.trange().tile(c));
auto shape = apply_inverse(pc, C.array.trange().tile(job.c_target));
tile = tile.reshape(shape);
// then permute to target C layout c = (c1 c2 ...)
if (pc) tile = tile.permute(pc);
// and move to C_local_tiles
C_local_tiles.emplace_back(std::move(c), std::move(tile));
h_results[slot] = {job.c_target, std::move(tile)};
return true;
};

std::vector<madness::Future<bool>> h_futures;
h_futures.reserve(jobs.size());
for (size_t slot = 0; slot < jobs.size(); ++slot) {
h_futures.push_back(world.taskq.add(per_h_work, slot));
}
for (auto &fut : h_futures) fut.get();
for (auto &r : h_results) {
C_local_tiles.emplace_back(std::move(r.first), std::move(r.second));
}

build_C_array();
Expand Down Expand Up @@ -809,17 +880,36 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
term.local_tiles.clear();
const Permutation &P = term.permutation;

using TileType =
typename std::decay_t<decltype(term.array)>::value_type;
std::vector<std::pair<Index, madness::Future<TileType>>> pending;

for (Index ei : term.tiles) {
auto idx = apply_inverse(P, h + ei);
if (!term.array.is_local(idx)) continue;
if (term.array.is_zero(idx)) continue;
// TODO no need for immediate evaluation
auto tile = term.array.find_local(idx).get();
if (P) tile = tile.permute(P);

auto tile_future = term.array.find_local(idx); // non-blocking
auto shape = term.ei_tiled_range.tile(ei);
tile = tile.reshape(shape, batch);
term.local_tiles.push_back({ei, tile});

// Submit per-tile permute-and-reshape as a MADNESS task.
// Capture P by value (it's small) and shape by value.
madness::Future<TileType> permuted = owners->taskq.add(
[P, shape, batch](const TileType &tile) -> TileType {
TileType result = P ? tile.permute(P) : tile;
return result.reshape(shape, batch);
},
tile_future);

pending.emplace_back(ei, permuted);
}

// Wait for all per-tile tasks to complete, then gather into
// local_tiles.
for (auto &[ei, fut] : pending) {
term.local_tiles.push_back({ei, fut.get()});
}

bool replicated = term.array.pmap()->is_replicated();
term.ei = TiledArray::make_array<decltype(term.array)>(
*owners, term.ei_tiled_range, term.local_tiles.begin(),
Expand Down
Loading