
Classification on Two Moons Toy Data#

This toy example shows how to train a variational neural network for binary classification on the two moons dataset using implicit regularization via SGD initialized to the prior, as described in this paper. This approach avoids the computational cost of explicit regularization for non-trivial variational families and preserves beneficial inductive biases.
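
In other words, instead of minimizing an explicitly regularized variational objective such as the negative evidence lower bound \( \mathbb{E}_{q_\theta(w)}[\ell(y, f_w(X))] + \mathrm{KL}(q_\theta(w) \,\|\, p(w)) \), the variational parameters are initialized such that \(q_{\theta_0}(w) = p(w)\) and only the expected loss is minimized; regularization towards the prior then arises implicitly from the trajectory of SGD started at the prior.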

You can run this example yourself via the corresponding standalone script.

Data#

We begin by generating synthetic training and test data from the two moons binary classification problem.

Data
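
The snippet below is a minimal sketch of how such data can be generated, assuming scikit-learn's make_moons; the sample counts, noise level, batch size, and seeds are illustrative and may differ from the standalone script.

import torch
from sklearn.datasets import make_moons
from torch.utils.data import DataLoader, TensorDataset

# Illustrative values; the standalone script may use different ones.
num_train, num_test, noise, batch_size = 256, 1024, 0.15, 32


def two_moons_dataloader(num_samples, seed, shuffle):
    """Generate a noisy two moons dataset and wrap it in a DataLoader."""
    X, y = make_moons(n_samples=num_samples, noise=noise, random_state=seed)
    dataset = TensorDataset(
        torch.as_tensor(X, dtype=torch.float32),
        torch.as_tensor(y, dtype=torch.float32),  # Float targets for BCEWithLogitsLoss.
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


train_dataloader = two_moons_dataloader(num_train, seed=0, shuffle=True)
test_dataloader = two_moons_dataloader(num_test, seed=1, shuffle=False)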

Model#

Next, we define a fully-connected stochastic neural network using the pre-defined models.MLP.

Model
from torch import nn

import inferno
from inferno import bnn, loss_fns

# hidden_width and num_hidden_layers are hyperparameters set earlier in the script.
model = bnn.Sequential(
    inferno.models.MLP(
        in_size=2,
        hidden_sizes=[hidden_width] * num_hidden_layers,
        out_size=1,
        activation_layer=nn.SiLU,
        # Variational (factorized) covariance for the first and last layer,
        # none for the hidden layers in between.
        cov=[bnn.params.FactorizedCovariance()]
        + [None] * (num_hidden_layers - 1)
        + [bnn.params.FactorizedCovariance()],
        bias=True,
    ),
    nn.Flatten(-2, -1),  # (1)
    parametrization=bnn.params.MUP(),
)
  1. PyTorch nn.Modules can be used as part of inferno models.

Training#

We train the model via the expected loss \( \bar{\ell}(\theta) = \mathbb{E}_{q_\theta(w)}[\ell(y, f_w(X))] \), i.e. the average loss of the model when drawing weights from the variational distribution \(q_\theta(w)\). In practice, for efficiency, we use only a single sample per batch during training.
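
Concretely, the expectation is replaced by a Monte Carlo estimate over \(S\) weight samples, \( \bar{\ell}(\theta) \approx \frac{1}{S} \sum_{s=1}^{S} \ell(y, f_{w_s}(X)) \) with \( w_s \sim q_\theta(w) \), where during training we use \(S = 1\), i.e. each forward pass draws one fresh weight sample.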

Training
import torch
import tqdm

# lr, num_epochs, and device are set earlier in the script.

# Loss function
loss_fn = loss_fns.BCEWithLogitsLoss()

# Optimizer
optimizer = torch.optim.SGD(
    params=model.parameters_and_lrs(
        lr=lr, optimizer="SGD"
    ),  # Sets module-specific learning rates
    lr=lr,
    momentum=0.9,
)

# Training loop
for epoch in tqdm.trange(num_epochs):
    model.train()

    for X_batch, y_batch in train_dataloader:
        optimizer.zero_grad()

        X_batch = X_batch.to(device=device)
        y_batch = y_batch.to(device=device)

        # Forward pass draws a single weight sample from the variational distribution.
        logits = model(X_batch)

        loss = loss_fn(logits, y_batch)

        loss.backward()
        optimizer.step()

Results#

Learning Curves#

Learning Curves

Decision Boundary#

Decision Boundary
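
The sketch below shows one way such a plot can be produced with matplotlib: it evaluates the predicted class probability on a grid over the input domain and averages it over several stochastic forward passes. It assumes that the model draws a fresh weight sample on each forward pass and returns one logit per grid point; the grid limits, resolution, and number of samples are illustrative.

import matplotlib.pyplot as plt
import torch

model.eval()

# Grid covering the input domain of the two moons data.
xx, yy = torch.meshgrid(
    torch.linspace(-2.0, 3.0, 200),
    torch.linspace(-1.5, 2.0, 200),
    indexing="xy",
)
grid = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=-1).to(device)

with torch.no_grad():
    # Average the predicted probability over several weight samples.
    probs = torch.stack(
        [torch.sigmoid(model(grid)) for _ in range(32)]
    ).mean(dim=0)

plt.contourf(xx, yy, probs.reshape(xx.shape).cpu(), levels=20, cmap="RdBu")
plt.colorbar(label="Predictive probability of class 1")
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("Decision boundary")
plt.show()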