A Variational Autoencoder (VAE) is a type of artificial neural network used to generate new data samples similar to the ones it was trained on. Imagine teaching a computer not only to recognize faces but also to create new, realistic faces of people who do not exist. VAEs achieve this by learning to compress data into a compact representation and then reconstruct the data from that representation.
VAEs combine neural networks with probabilistic modeling, which lets a single model both compress and generate data. Traditional autoencoders map each input to a fixed point. VAEs instead map each input to a probability distribution. This added variability makes it possible to generate new, diverse samples. So does a regularization term that keeps those distributions close to a simple prior.
Consider an image of a handwritten digit. A VAE learns to compress this high-dimensional input (a 28×28 image has 784 pixels) into a lower-dimensional "latent space" whose coordinates capture the essential features of the data. It also learns to reconstruct the original image from that compressed form. Nearby points in the latent space decode to similar images. As a result, sampling new points, or moving between the codes of existing images, produces new images that are plausible variations of the training data.
Standard autoencoders encode each input as a single point in latent space. VAEs use probability distributions instead. The encoder outputs the parameters of a Gaussian: a mean and a variance (in practice, the log-variance) for each latent dimension. Sampling from this distribution gives the decoder a different latent code each time. After training, sampling latent codes from the prior is how the model creates new, unseen data.
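To make the contrast concrete, here is a minimal sketch. It assumes hypothetical encoder outputs `mu` and `logvar` for a single input with a 2-dimensional latent space. It uses `torch.distributions.Normal` to stand in for \( q(z|x) \). A standard autoencoder would simply return `mu` as its code, every time.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical encoder outputs for one input, 2-dimensional latent space.
mu = torch.tensor([0.5, -1.0])
logvar = torch.tensor([-2.0, -2.0])
std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)

# A standard autoencoder would use `mu` itself as the code.
# A VAE treats (mu, std) as the parameters of q(z|x) and samples from it.
q = Normal(mu, std)
z1 = q.sample()
z2 = q.sample()
print(z1, z2)  # two different latent codes for the same input
```

Both samples lie near `mu`, and their spread is controlled by `std`. This is the variability the decoder learns to tolerate.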
To train a VAE, we need to backpropagate through the sampling step. However, drawing a random sample is not differentiable with respect to the distribution's parameters. The reparameterization trick solves this by moving the randomness into an auxiliary variable. Instead of sampling \( z \) directly from \( \mathcal{N}(\mu, \sigma^2) \), we sample noise from a standard normal distribution, scale it by the learned standard deviation, and shift it by the learned mean:
$$ z = \mu + \sigma \cdot \epsilon $$
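where \( \epsilon \) is sampled from \( \mathcal{N}(0, I) \) independently of \( \mu \) and \( \sigma \), and the product is taken elementwise. Because \( z \) is now a deterministic function of \( \mu \) and \( \sigma \), gradients flow to both.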
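A quick way to see why the trick matters is to check the gradients directly. This sketch uses arbitrary illustrative values for `mu` and `sigma`. It computes \( z = \mu + \sigma \cdot \epsilon \) and backpropagates a simple sum. The gradient with respect to \( \mu \) is 1, and the gradient with respect to \( \sigma \) is \( \epsilon \). It also contrasts `torch.distributions.Normal.sample()`, which does not track gradients, with `rsample()`, which applies this reparameterization internally.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Learnable distribution parameters (leaf tensors that track gradients).
mu = torch.tensor([0.0, 1.0], requires_grad=True)
sigma = torch.tensor([1.0, 0.5], requires_grad=True)

# All randomness lives in eps, which does not depend on mu or sigma.
eps = torch.randn(2)
z = mu + sigma * eps

# Any downstream scalar works; a plain sum keeps the gradients easy to read.
z.sum().backward()
print(mu.grad)     # dz/dmu = 1 for each element
print(sigma.grad)  # dz/dsigma = eps for each element

# torch.distributions exposes both behaviours.
q = Normal(mu, sigma)
print(q.sample().requires_grad)   # False: sample() blocks gradients
print(q.rsample().requires_grad)  # True: rsample() computes mu + sigma * eps
```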
The core idea of VAEs is to map input data to a latent space described by probability distributions. The encoder maps the input \( x \) to an approximate posterior \( q(z|x) \), typically a Gaussian with mean \( \mu \) and diagonal covariance \( \sigma^2 \). The decoder maps samples \( z \) from this distribution back to the data space by parameterizing the likelihood \( p(x|z) \). The latent variables also have a prior \( p(z) \), usually the standard normal \( \mathcal{N}(0, I) \).
Directly maximizing the likelihood \( p(x) = \int p(x|z)\,p(z)\,dz \) is intractable. The integral runs over all possible latent variables \( z \) and has no closed form when \( p(x|z) \) is a neural network. Instead, we maximize the Evidence Lower Bound (ELBO), which serves as a surrogate objective:
$$ \log p(x) \geq \mathbb{E}_{q(z|x)}[\log p(x|z)] - \text{KL}(q(z|x) || p(z)) $$
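The ELBO consists of two terms:
- The reconstruction term \( \mathbb{E}_{q(z|x)}[\log p(x|z)] \) rewards the decoder for assigning high likelihood to the input, given latent codes sampled from the encoder.
- The KL term \( \text{KL}(q(z|x) || p(z)) \) penalizes encoder distributions that drift away from the prior. This keeps the latent space regular enough to sample from.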
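The bound can be checked numerically in a toy model where \( \log p(x) \) is tractable. The sketch below assumes a one-dimensional linear-Gaussian model chosen for illustration: \( p(z) = \mathcal{N}(0, 1) \) and \( p(x|z) = \mathcal{N}(z, 1) \). Then \( p(x) = \mathcal{N}(0, 2) \), and the true posterior is \( \mathcal{N}(x/2, 1/2) \). With a Gaussian \( q(z|x) = \mathcal{N}(m, s^2) \), both ELBO terms have closed forms. The gap between \( \log p(x) \) and the ELBO equals \( \text{KL}(q(z|x) || p(z|x)) \), so the bound is tight exactly when \( q \) matches the posterior.

```python
import math

def log_px(x):
    # Exact log marginal: x = z + noise with z ~ N(0, 1), noise ~ N(0, 1), so p(x) = N(0, 2).
    return -0.5 * math.log(2 * math.pi * 2.0) - x**2 / 4.0

def elbo(x, m, s2):
    # Reconstruction term E_q[log N(x; z, 1)] for q = N(m, s2), in closed form.
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL(N(m, s2) || N(0, 1)), also in closed form.
    kl = 0.5 * (s2 + m**2 - 1.0 - math.log(s2))
    return recon - kl

x = 1.3
print(log_px(x))
print(elbo(x, m=0.0, s2=1.0))    # q equal to the prior: strictly below log p(x)
print(elbo(x, m=x / 2, s2=0.5))  # q equal to the true posterior: the bound is tight
```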
Training minimizes the negative ELBO. The VAE loss is therefore the negative reconstruction term plus the KL divergence:
$$ \mathcal{L} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \text{KL}(q(z|x) || p(z)) $$
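In practice, the two terms are computed as follows:
- With a Bernoulli decoder, the negative reconstruction term is the binary cross-entropy between the reconstruction and the input, estimated with a single sample of \( z \) per input.
- For a diagonal Gaussian \( q(z|x) \) and a standard normal prior, the KL term has the closed form \( -\tfrac{1}{2}\sum_j (1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2) \):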
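```python
import torch
import torch.nn as nn

def loss_function(recon_x, x, mu, logvar):
    # Reconstruction loss: negative Bernoulli log-likelihood, summed over pixels and batch
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between N(mu, sigma^2) and N(0, I), in closed form
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
```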
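The KL line follows from the closed-form divergence between \( \mathcal{N}(\mu_j, \sigma_j^2) \) and \( \mathcal{N}(0, 1) \), which is \( \tfrac{1}{2}(\sigma_j^2 + \mu_j^2 - 1 - \log\sigma_j^2) \), summed over latent dimensions and the batch. As a sanity check, this sketch compares it against `torch.distributions.kl_divergence` on random, illustrative values of `mu` and `logvar`.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 20)      # illustrative batch of 4 posterior means
logvar = torch.randn(4, 20)  # and log-variances

# Closed form used in loss_function
kld_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference: KL(q || p) with q = N(mu, sigma^2) and p = N(0, I), summed over all elements
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_ref = kl_divergence(q, p).sum()

print(kld_closed.item(), kld_ref.item())  # the two values agree up to float32 rounding
```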
First, we need to import the necessary libraries:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
```
The encoder compresses the input data into the parameters of the latent distribution:
```python
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
```
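Predicting \( \log\sigma^2 \) rather than \( \sigma^2 \) lets `fc_logvar` output any real number, while \( \sigma = \exp(\tfrac{1}{2}\log\sigma^2) \) stays strictly positive. This sketch repeats the `Encoder` class so it runs on its own. It feeds the encoder a random stand-in batch (not real MNIST data) to check the output shapes.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    # Same architecture as the Encoder above, repeated so this snippet is self-contained.
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

torch.manual_seed(0)
encoder = Encoder(input_dim=784, hidden_dim=400, latent_dim=20)
x = torch.rand(8, 784)  # stand-in batch: 8 flattened 28x28 "images" with values in [0, 1]
mu, logvar = encoder(x)
std = torch.exp(0.5 * logvar)  # always positive, whatever value fc_logvar produces

print(mu.shape, logvar.shape)  # torch.Size([8, 20]) torch.Size([8, 20])
```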
This function allows us to sample from the latent distribution in a differentiable way:
```python
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std
    return z
```
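To confirm that `reparameterize` draws from \( \mathcal{N}(\mu, \sigma^2) \), this sketch repeats the function and samples it many times. It uses a fixed, illustrative mean of 2.0 and variance of 0.25. The empirical mean and standard deviation should land close to 2.0 and 0.5.

```python
import math
import torch

def reparameterize(mu, logvar):
    # Same function as above: z = mu + eps * sigma with eps ~ N(0, 1)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

torch.manual_seed(0)
n = 100_000
mu = torch.full((n,), 2.0)
logvar = torch.full((n,), math.log(0.25))  # variance 0.25, so std 0.5

z = reparameterize(mu, logvar)
print(z.mean().item(), z.std().item())  # close to 2.0 and 0.5
```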
The decoder reconstructs the data from the latent vector:
```python
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        h = self.relu(self.fc1(z))
        recon_x = self.sigmoid(self.fc2(h))
        return recon_x
```
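The final sigmoid keeps each output in \( [0, 1] \), which `binary_cross_entropy` requires. A common alternative, not used in this article, is to have the decoder return raw logits and use `torch.nn.functional.binary_cross_entropy_with_logits`. That function fuses the sigmoid into the loss for numerical stability. This sketch uses random, illustrative logits. It shows that the two approaches agree for moderate values and diverge when the sigmoid saturates. In float32, `sigmoid(30.)` rounds to exactly 1.0, and PyTorch's `binary_cross_entropy` clamps \( \log 0 \) to \( -100 \).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 784)  # hypothetical pre-sigmoid decoder outputs
x = torch.rand(8, 784)        # targets in [0, 1]

# Sigmoid followed by BCE, as in the Decoder and loss_function above
bce_sigmoid = F.binary_cross_entropy(torch.sigmoid(logits), x, reduction='sum')
# Fused version that works on logits directly
bce_logits = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
print(bce_sigmoid.item(), bce_logits.item())  # nearly identical

# Saturated case: a confident wrong prediction (logit 30, target 0)
big_logit = torch.tensor([30.0])
target = torch.tensor([0.0])
saturated = F.binary_cross_entropy(torch.sigmoid(big_logit), target)  # log(0) clamped to -100
stable = F.binary_cross_entropy_with_logits(big_logit, target)        # exact value, about 30
print(saturated.item(), stable.item())
```

If you adopt this variant, the decoder would drop its `Sigmoid`, and you would apply `torch.sigmoid` to its output when producing images.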
Now, we integrate the encoder and decoder into the VAE model:
```python
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar
```
We define the training loop, specifying hyperparameters and the optimizer:
```python
# Hyperparameters
input_dim = 784  # MNIST images (28x28)
hidden_dim = 400
latent_dim = 20
learning_rate = 1e-3
epochs = 10
batch_size = 128

# Data Loading
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Initialize VAE and Optimizer
vae = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

# Training Loop
for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.view(-1, input_dim)
        optimizer.zero_grad()
        recon_x, mu, logvar = vae(x)
        loss = loss_function(recon_x, x, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {train_loss / len(train_loader.dataset):.4f}')
```
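Once trained, the VAE generates new digits by sampling \( z \) from the prior \( \mathcal{N}(0, I) \) and decoding it; the encoder is not needed. In this sketch, `generate` is a hypothetical helper. An untrained `nn.Sequential` with the same layers as `Decoder(20, 400, 784)` stands in for the trained decoder so the snippet runs on its own.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def generate(decoder, n_samples, latent_dim):
    # Hypothetical helper: draw latent codes from the prior p(z) = N(0, I)...
    z = torch.randn(n_samples, latent_dim)
    # ...and decode them to pixel probabilities, reshaped to 28x28 images.
    return decoder(z).view(n_samples, 28, 28)

torch.manual_seed(0)
# Untrained stand-in with the same layers as Decoder(20, 400, 784).
decoder = nn.Sequential(nn.Linear(20, 400), nn.ReLU(), nn.Linear(400, 784), nn.Sigmoid())
decoder.eval()

images = generate(decoder, n_samples=16, latent_dim=20)
print(images.shape)  # torch.Size([16, 28, 28])
```

With the trained model, you would call `vae.eval()` and pass `vae.decoder` instead. Then `torchvision.utils.save_image(images.unsqueeze(1), 'samples.png')` can write the samples to disk as a grid.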
Bringing everything together, here's the full code for a Variational Autoencoder:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Encoder Definition
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

# Reparameterization Trick
def reparameterize(mu, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std
    return z

# Decoder Definition
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        h = self.relu(self.fc1(z))
        recon_x = self.sigmoid(self.fc2(h))
        return recon_x

# VAE Definition
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar

# Loss Function
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Hyperparameters
input_dim = 784  # MNIST images (28x28)
hidden_dim = 400
latent_dim = 20
learning_rate = 1e-3
epochs = 10
batch_size = 128

# Data Loading
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Initialize VAE and Optimizer
vae = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

# Training Loop
for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.view(-1, input_dim)
        optimizer.zero_grad()
        recon_x, mu, logvar = vae(x)
        loss = loss_function(recon_x, x, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {train_loss / len(train_loader.dataset):.4f}')
```
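The introduction mentions manipulating the latent space. One simple way to do that is to encode two images, interpolate linearly between their posterior means, and decode each intermediate point. Using the means rather than random samples keeps the path deterministic. The `interpolate` function below is a hypothetical helper. `TinyEncoder` and a one-layer decoder are untrained stand-ins so the sketch is self-contained. With the trained model, you would pass `vae.encoder`, `vae.decoder`, and two images from `train_dataset`.

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # Hypothetical stand-in that returns (mu, logvar) like the Encoder above.
    def __init__(self, input_dim=784, latent_dim=20):
        super().__init__()
        self.fc_mu = nn.Linear(input_dim, latent_dim)
        self.fc_logvar = nn.Linear(input_dim, latent_dim)

    def forward(self, x):
        return self.fc_mu(x), self.fc_logvar(x)

@torch.no_grad()
def interpolate(encoder, decoder, x_a, x_b, steps=8):
    # Use the posterior means as deterministic latent codes for the two inputs.
    mu_a, _ = encoder(x_a.view(1, -1))
    mu_b, _ = encoder(x_b.view(1, -1))
    # Points on the straight line from mu_a to mu_b, endpoints included.
    alphas = torch.linspace(0, 1, steps).unsqueeze(1)  # shape (steps, 1)
    z = (1 - alphas) * mu_a + alphas * mu_b            # shape (steps, latent_dim)
    return decoder(z)                                  # shape (steps, input_dim)

torch.manual_seed(0)
encoder = TinyEncoder()
decoder = nn.Sequential(nn.Linear(20, 784), nn.Sigmoid())  # untrained stand-in decoder
x_a = torch.rand(28, 28)  # stand-ins for two MNIST images
x_b = torch.rand(28, 28)

out = interpolate(encoder, decoder, x_a, x_b, steps=8)
print(out.shape)  # torch.Size([8, 784])
```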
Variational Autoencoders are a powerful tool in the realm of generative models, combining the strengths of neural networks and probabilistic modeling. By understanding the intuition behind compression and reconstruction, delving into the mathematical underpinnings, and implementing them through code, you can harness VAEs for a variety of applications in data generation and beyond.