Faster Deep Ensembles #929

Open
avullo wants to merge 29 commits into develop from alessandro/faster-de

avullo (Collaborator) commented May 21, 2026

Related issue(s)/PRs:

Summary

...

Fully backwards compatible: yes / no

PR checklist

  • The quality checks are all passing
  • The bug case / new feature is covered by tests
  • Any new features are well-documented (in docstrings or notebooks)

Alessandro Vullo and others added 19 commits May 21, 2026 08:16
…er layer

Replace E separate Dense branch paths in KerasEnsemble with a single model
that stacks all E inputs into [E, batch, D], runs one VectorizedEnsembleDenseLayer
(batched matmul [E, batch, D] @ [E, D, H]) per hidden layer, then splits
[E, batch, 2] params into E separate DistributionLambda outputs.
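
To illustrate the batched matmul above, here is a minimal sketch in NumPy rather than TensorFlow. It is not the PR's VectorizedEnsembleDenseLayer: the function name `vectorized_dense`, the bias shape, and the sizes are assumptions chosen for illustration. It only shows that one [E, batch, D] @ [E, D, H] product gives the same result as E separate dense layers.

```python
import numpy as np


def vectorized_dense(x, kernel, bias):
    """Apply E independent dense layers with a single batched matmul.

    x:      [E, batch, D]  stacked per-member inputs
    kernel: [E, D, H]      one weight matrix per ensemble member
    bias:   [E, H]         one bias vector per member (assumed layout)
    """
    # matmul broadcasts over the leading E axis:
    # [E, batch, D] @ [E, D, H] -> [E, batch, H]
    return np.matmul(x, kernel) + bias[:, None, :]


E, batch, D, H = 10, 4, 3, 5  # illustrative sizes only
rng = np.random.default_rng(0)
x = rng.normal(size=(E, batch, D))
kernel = rng.normal(size=(E, D, H))
bias = rng.normal(size=(E, H))

stacked = vectorized_dense(x, kernel, bias)

# Reference: E separate dense layers, one per member.
looped = np.stack([x[e] @ kernel[e] + bias[e] for e in range(E)])
```

The two results agree up to floating-point rounding, so the vectorized form changes how the work is scheduled, not what is computed.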

External interface unchanged: same input/output names, loss=[nll]*E, metrics,
callbacks, compile/fit, verbose=1 all preserved.

Co-authored-by: Cursor <cursoragent@cursor.com>
…eparate calls

Generate all E bootstrap index sets at once with tf.random.uniform([E, N]),
then gather both query_points and observations in a single batched tf.gather
per tensor, yielding [E, N, D] and [E, N, 1]. This replaces E×2 individual
gather calls with 2 calls, and separates the bootstrap and no-bootstrap paths
for clarity.
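
The batched bootstrap can be sketched as follows. This is a NumPy stand-in, not the PR's code: `rng.integers` plays the role of tf.random.uniform with an integer dtype, array indexing plays the role of the batched tf.gather, and the helper name `bootstrap_ensemble_data` is hypothetical.

```python
import numpy as np


def bootstrap_ensemble_data(query_points, observations, ensemble_size, rng):
    """Resample a dataset with replacement once per ensemble member.

    query_points: [N, D], observations: [N, 1]
    Returns indices [E, N], inputs [E, N, D], targets [E, N, 1].
    """
    n = query_points.shape[0]
    # All E index sets at once: one [E, N] matrix of row indices.
    indices = rng.integers(0, n, size=(ensemble_size, n))
    # Indexing with an [E, N] array gathers along axis 0 in a single call
    # per tensor, so inputs and targets stay aligned row by row.
    return indices, query_points[indices], observations[indices]


N, D, E = 6, 2, 10  # illustrative sizes only
rng = np.random.default_rng(0)
query_points = rng.normal(size=(N, D))
observations = rng.normal(size=(N, 1))

indices, x_boot, y_boot = bootstrap_ensemble_data(query_points, observations, E, rng)
```

Using the same index matrix for both tensors is what keeps each resampled input paired with its own observation.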

Co-authored-by: Cursor <cursoragent@cursor.com>
…r step

When N % batch_size == 0, reshape each input/output tensor from [N, ...]
to [n_batches, batch_size, ...] before from_tensor_slices. Each Dataset
element is then a pre-formed [batch_size, ...] slice, so no step has to
stack batch_size individual samples into a batch.
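
The reshape can be sketched in NumPy as below. The helper name `prebatch` is hypothetical, and returning None when N is not divisible by batch_size is an assumption; the commit message does not say what the PR does in that case.

```python
import numpy as np


def prebatch(tensor, batch_size):
    """Reshape [N, ...] to [n_batches, batch_size, ...] when N divides evenly."""
    n = tensor.shape[0]
    if n % batch_size != 0:
        # Assumed fallback: the caller uses ordinary per-sample batching.
        return None
    return tensor.reshape((n // batch_size, batch_size) + tensor.shape[1:])


N, D, batch_size = 12, 3, 4  # illustrative sizes only
x = np.arange(N * D, dtype=np.float64).reshape(N, D)

batches = prebatch(x, batch_size)
uneven = prebatch(x, 5)
```

Slicing the reshaped tensor along its first axis yields whole batches, so `batches[1]` holds the same rows as `x[4:8]`.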

Co-authored-by: Cursor <cursoragent@cursor.com>
…] input and one [batch,E,1] output

_build_vectorized_ensemble now accepts a single 'ensemble_input' of shape [E*D],
reshapes to [E, batch, D] internally, transposes the [E, batch, 2] params to
[batch, E, 2], and wraps in one DistributionLambda giving Normal([batch,E,1]).

prepare_dataset and prepare_query_points branch on len(input_names)==1 to pack/
unpack the stacked representation. ensemble_distributions splits the single
batched Distribution into E Normal distributions for the predict API.
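
A minimal NumPy sketch of the pack/unpack round trip described above. All function names are hypothetical, and the member-major packing order (member 0's D features first, then member 1's, and so on) is an assumption; the commit message does not state the order. Plain arrays stand in for the distribution parameters.

```python
import numpy as np


def pack_inputs(member_inputs):
    """E arrays of [batch, D] -> one [batch, E*D] array (member-major, assumed)."""
    return np.concatenate(member_inputs, axis=-1)


def unpack_inputs(packed, ensemble_size):
    """[batch, E*D] -> [E, batch, D], the layout used by the stacked layers."""
    batch = packed.shape[0]
    per_member = packed.reshape(batch, ensemble_size, -1)  # [batch, E, D]
    return np.transpose(per_member, (1, 0, 2))  # [E, batch, D]


def params_to_output_layout(params):
    """[E, batch, 2] -> [batch, E, 2], batch-first for the single output."""
    return np.transpose(params, (1, 0, 2))


def split_members(stacked):
    """[batch, E, K] -> list of E arrays of shape [batch, K]."""
    return [stacked[:, e, :] for e in range(stacked.shape[1])]


E, batch, D = 3, 4, 2  # illustrative sizes only
rng = np.random.default_rng(0)
member_inputs = [rng.normal(size=(batch, D)) for _ in range(E)]

packed = pack_inputs(member_inputs)
unpacked = unpack_inputs(packed, E)

params = rng.normal(size=(E, batch, 2))
means = params_to_output_layout(params)[..., :1]  # [batch, E, 1]
member_means = split_members(means)
```

Each `unpacked[e]` matches the original input for member e, and each entry of `member_means` matches that member's slice of the parameters, so the stacked representation loses no per-member information.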

The compile() call is adapted to use one loss/metric for single-output models.
_warmup_jit gains a forward-pass fallback to infer output shape when DistributionLambda
does not propagate it statically.
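
For intuition on the single loss, the sketch below computes a Gaussian negative log-likelihood over a stacked [batch, E, 1] output in NumPy. It assumes a mean reduction over all elements, which the commit message does not specify, and `gaussian_nll` is a hypothetical helper, not the PR's nll function. Under that assumption the single loss equals the average of the E per-member losses.

```python
import numpy as np


def gaussian_nll(y, mean, stddev):
    """Elementwise negative log-likelihood of y under Normal(mean, stddev)."""
    return 0.5 * np.log(2.0 * np.pi * stddev**2) + (y - mean) ** 2 / (2.0 * stddev**2)


batch, E = 8, 10  # illustrative sizes only
rng = np.random.default_rng(0)
y = rng.normal(size=(batch, E, 1))
mean = rng.normal(size=(batch, E, 1))
stddev = rng.uniform(0.5, 2.0, size=(batch, E, 1))

# One loss over the whole stacked output (mean reduction assumed).
single_loss = gaussian_nll(y, mean, stddev).mean()

# E separate losses, one per ensemble member.
per_member = [
    gaussian_nll(y[:, e], mean[:, e], stddev[:, e]).mean() for e in range(E)
]
```

Because every member contributes the same number of elements, `single_loss` and `np.mean(per_member)` agree up to rounding.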

Reduces the per-step Dataset tensor count from 2E = 20 (E = 10) to 2, and the
Keras metric scalars from 2E to 2.

Co-authored-by: Cursor <cursoragent@cursor.com>
…,1] output reduces per-step tensor overhead 20→2
@avullo avullo marked this pull request as draft May 23, 2026 14:11
@avullo avullo marked this pull request as ready for review June 2, 2026 15:46
@avullo avullo requested review from hstojic and uri-granta June 2, 2026 15:52
