从 VAE 到 CVAE:解锁生成模型的 “可控” 魔法

本文介绍变分自编码器原理及实践。

从 VAE 到 CVAE:解锁生成模型的 “可控” 魔法

前言

在生成式人工智能的浪潮里,变分自编码器(Variational Autoencoder,VAE)绝对是绕不开的经典模型。它凭借优雅的概率建模思想,打破了传统自编码器 “确定性生成” 的局限;而它的改进版 —— 条件变分自编码器(Conditional Variational Autoencoder,CVAE),更是给生成任务加上了 “可控开关”,让机器能按我们的指令生成想要的内容。今天,我们就来聊聊 VAE 和 CVAE 的核心原理,以及它们在现实世界中的精彩应用。

本文将使用 PyTorch 实现变分自编码器(VAE)和 条件变分自编码器(CVAE),并在 MNIST 数据集上进行训练与评估。

参考文章:

Pytorch 实现 VAE 和 CVAE

一、先从自编码器说起:生成模型的 “雏形”

要理解 VAE,得先回顾它的 “前辈”——传统自编码器(AE)

自编码器是一种无监督学习模型,结构上分为两部分:编码器(Encoder) 和解码器(Decoder)。编码器负责将高维输入数据(比如一张 MNIST 手写数字图片)压缩成低维的隐空间向量;解码器则负责把这个隐向量还原成和输入数据相似的输出。训练的目标很简单:让输出尽可能接近输入,也就是最小化 “重构误差”。

但传统自编码器有个致命缺陷:隐空间是 “无序” 的。编码器生成的隐向量没有统一的概率分布规律,直接从隐空间随机采样一个向量,解码器大概率会生成一张毫无意义的 “噪声图”。它只能 “复刻” 见过的数据,却没法 “创造” 新数据。

二、变分自编码器(VAE):给隐空间加个 “概率约束”

VAE 的出现,正是为了解决传统自编码器的痛点。它的核心创新,是给隐空间赋予了概率意义,让隐向量不再是孤立的点,而是服从某种分布的随机变量。

1. VAE 的核心原理:概率建模 + 重参数化技巧

VAE 的目标不再是简单的 “重构数据”,而是学习数据的概率分布。它的核心思想可以拆解为三步:

  1. 编码器:学习隐空间的分布参数不同于传统自编码器直接输出一个确定的隐向量,VAE 的编码器会输出一组分布参数—— 通常是正态分布的均值$μ$和对数方差$logσ^2$。也就是说,对于输入数据$x$,编码器会推断出它对应的隐向量$z$服从分布$z∼N(μ(x),σ^2(x))$。

  2. 重参数化技巧:解决 “采样不可导” 难题既然z是从分布中采样得到的,而采样操作是随机的、不可导的,这就导致梯度无法反向传播,模型没法训练。

VAE 的 “神来之笔” 就是重参数化:不直接从$N(μ,σ^2)$采样$z$,而是先从标准正态分布$N(0,1)$采样一个辅助变量$ϵ$,再通过公式 $z=μ+ϵ⋅exp(0.5⋅logσ^2)$ 计算得到$z$。

这样一来,$ϵ$是固定的随机噪声,$μ$和$logσ^2$是可导的模型参数,梯度就能顺利传递了。

  1. 解码器:从隐向量还原数据 + 损失函数约束解码器的任务和传统自编码器类似:把隐向量$z$还原成生成数据$\hat x$。但 VAE 的损失函数有两个部分,缺一不可:
  • 重构损失:衡量$\hat x$和原始输入$x$的差异,确保生成的数据 “像真的”。
  • KL 散度损失:衡量编码器推断的隐分布$N(μ,σ^2)$和标准正态分布$N(0,1)$的差异,确保隐空间是 “规整” 的、可采样的。

两个损失结合,既保证了生成质量,又让隐空间具备了 “创造性”——我们可以直接从标准正态分布采样$z$,输入解码器就能生成全新的、从未见过的数据。

2. VAE 的特点:模糊但多样的生成

VAE 生成的图像,往往带有一点 “朦胧感”,不像 GAN(生成对抗网络)那样锐利清晰。但它的优势也很明显:

  • 基于概率模型,生成过程更稳定,不容易出现模式崩溃;
  • 隐空间具有连续性,在隐空间中插值,能得到平滑的图像过渡效果。

VAE 模型架构

三、条件变分自编码器(CVAE):给生成加个 “可控标签”

VAE 实现了 “无监督生成”,但它有个小遗憾:生成内容不可控。比如用 VAE 生成 MNIST 数字时,你没法指定 “我要生成数字 2”,只能随机生成一个未知的数字。

CVAE 正是为了解决 “可控生成” 而来,它的核心改进很简单:给模型加个 “条件”

1. CVAE 的核心原理:给编码器和解码器都加 “标签”

CVAE 的全称是 Conditional VAE,直译就是 “带条件的 VAE”。这个 “条件” 可以是任何我们想要的信息 —— 比如数字的类别标签、图像的风格描述、文本的关键词等等。 它的结构和 VAE 几乎一致,只做了两处关键修改:

  1. 编码器输入:数据 + 条件编码器不再只接收输入数据x,而是接收 x和条件y的组合(比如 MNIST 的图像 + 数字标签 “2”)。编码器学习的是在条件$y$下,$x$对应的隐分布,即$z∼N(μ(x,y),σ^2(x,y))$。
  2. 解码器输入:隐向量 + 条件解码器同样不再只接收隐向量$z$,而是接收$z$和条件$y$的组合。解码器学习的是在条件$y$下,如何从$z$还原出$x$

这样一来,模型就学会了 “条件与数据” 的关联。当我们想要生成指定内容时,只需要:

  1. 从标准正态分布采样隐向量$z$;
  2. 传入我们想要的条件$y$(比如 “生成数字 2”);
  3. 解码器就会输出符合条件$y$的生成数据。

2. CVAE 的关键:类别嵌入层

对于离散的条件(比如数字标签 0-9),CVAE 通常会用嵌入层(Embedding Layer) 将离散的标签转化为连续的向量。这样做的目的,是让模型更好地学习条件与数据之间的非线性关系,提升生成的准确性。

四、VAE 与 CVAE 的应用场景

从理论到实践,VAE 和 CVAE 凭借独特的优势,在多个领域发挥着重要作用。

1. VAE 的典型应用

  • 无监督图像生成:生成风格统一的图像,比如人脸、风景的初步生成;
  • 数据降噪与修复:利用重构能力,修复老照片、去除图像噪声;
  • 异常检测:通过计算重构误差,识别偏离正常分布的异常数据(比如工业质检中的缺陷产品);
  • 隐空间插值:在两个隐向量之间插值,生成平滑过渡的图像,用于动画制作、风格迁移的初步探索。

2. CVAE 的典型应用

  • 可控图像生成:指定类别生成图像,比如生成手写数字 “5”、生成特定品种的猫;
  • 文本到图像生成:以文本描述为条件,生成符合描述的图像(比如 “生成一只戴帽子的小狗”);
  • 图像到图像翻译:以源图像和目标风格为条件,将图像转换成指定风格(比如素描转油画);
  • 医疗影像生成:指定病灶类型,生成模拟的医疗影像,辅助医生培训和诊断。

五、总结:VAE 与 CVAE 的价值

VAE 和 CVAE 虽然生成的图像清晰度不如 GAN,但它们的概率建模思想稳定的训练特性,在生成模型领域占据着不可替代的地位。

VAE 让自编码器从 “复刻工具” 升级为 “创作工具”,而 CVAE 又给这个工具装上了 “可控开关”。从无监督生成到有条件生成,这两步小小的改进,却为生成式 AI 的落地应用打开了更广阔的空间。随着技术的发展,VAE 和 CVAE 还在不断进化 —— 比如结合卷积神经网络(CNN)的 ConvVAE、结合注意力机制的 VAE 变体,未来还会有更多精彩的突破。

附录

PyTorch 实现变分自编码器 (VAE) - MNIST 数据集

你需要一份基于 PyTorch、针对 MNIST 数据集的 VAE 实现,要求包含重参数化技巧、详细注释、训练误差展示和生成图像存储,以下是完整可运行代码(兼容 Python 3.8、PyTorch 1.7+):

完整代码实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np

# 1. 配置全局超参数
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 优先使用GPU加速
BATCH_SIZE = 128  # 批次大小
EPOCHS = 30  # 训练轮数
LEARNING_RATE = 1e-3  # 学习率
LATENT_DIM = 20  # 潜在变量z的维度(隐空间维度)
IMAGE_SIZE = 784  # MNIST图像展平后的尺寸(28*28)
SAVE_IMAGE_PATH = "./vae_generated_images/"  # 生成图像的存储路径
import os
os.makedirs(SAVE_IMAGE_PATH, exist_ok=True)  # 创建存储目录(不存在则创建)

# 2. 数据预处理与数据集加载
# 定义数据变换:转为Tensor + 归一化(将像素值从[0,255]映射到[0,1])
transform = transforms.Compose([
    transforms.ToTensor(),  # 转为torch.Tensor,形状为(C,H,W),值范围[0,1]
])

# 加载MNIST训练数据集(自动下载到./data目录)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# 构建数据加载器(批量加载数据,自动打乱,多进程加载)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0  # Python 3.8在部分环境下num_workers>0可能报错,设为0更稳定
)

# 3. 定义VAE模型(包含编码器、解码器、重参数化技巧)
class VAE(nn.Module):
    def __init__(self, latent_dim, image_size):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size
        
        # -------------- 编码器(Encoder):将输入图像映射到隐空间的分布参数(μ, log_var)--------------
        self.encoder = nn.Sequential(
            nn.Linear(image_size, 512),  # 输入层:展平图像(784) -> 隐藏层(512)
            nn.ReLU(inplace=True),  # 激活函数:ReLU(加速训练,缓解梯度消失)
            nn.Linear(512, 256),  # 隐藏层:512 -> 256
            nn.ReLU(inplace=True),  # 激活函数:ReLU
            nn.Linear(256, 2 * latent_dim)  # 输出层:256 -> 2*latent_dim(分别对应μ和log_var)
        )
        
        # -------------- 解码器(Decoder):将隐空间变量z映射回生成图像--------------
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),  # 输入层:隐变量z(20) -> 隐藏层(256)
            nn.ReLU(inplace=True),  # 激活函数:ReLU
            nn.Linear(256, 512),  # 隐藏层:256 -> 512
            nn.ReLU(inplace=True),  # 激活函数:ReLU
            nn.Linear(512, image_size),  # 输出层:512 -> 784(与输入图像尺寸一致)
            nn.Sigmoid()  # 激活函数:Sigmoid(将输出映射到[0,1],匹配图像像素分布)
        )
    
    def encode(self, x):
        """
        编码器前向传播:输入图像 -> 隐空间分布参数(μ, log_var)
        """
        h = self.encoder(x)
        # 将输出切分为均值μ和对数方差log_var(各占latent_dim维度)
        mu, log_var = torch.chunk(h, 2, dim=1)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        """
        重参数化技巧(核心):解决隐变量z不可导的问题
        思路:不直接从N(μ, σ²)采样z,而是从N(0,1)采样ε,通过z=μ+ε*σ得到z
        其中σ=exp(0.5*log_var),这样只有μ和log_var参与梯度传播,ε为常数不参与求导
        """
        std = torch.exp(0.5 * log_var)  # 计算标准差σ:exp(0.5*log_var)
        eps = torch.randn_like(std)  # 从标准正态分布N(0,1)采样ε,形状与std一致
        z = mu + eps * std  # 计算最终的隐变量z
        return z
    
    def decode(self, z):
        """
        解码器前向传播:隐变量z -> 生成图像
        """
        return self.decoder(z)
    
    def forward(self, x):
        """
        VAE整体前向传播
        """
        # 1. 编码:得到隐空间分布参数
        mu, log_var = self.encode(x)
        # 2. 重参数化:得到可导的隐变量z
        z = self.reparameterize(mu, log_var)
        # 3. 解码:得到生成图像
        x_recon = self.decode(z)
        # 返回生成图像、μ、log_var(μ和log_var用于计算损失函数)
        return x_recon, mu, log_var

# 4. 定义VAE损失函数(包含重构损失和KL散度损失)
def vae_loss(x_recon, x, mu, log_var):
    """
    VAE损失函数 = 重构损失(Reconstruction Loss) + KL散度损失(KL Divergence Loss)
    1. 重构损失:衡量生成图像与原始图像的差异,使用二元交叉熵(BCE)(匹配Sigmoid输出)
    2. KL散度损失:衡量隐空间分布N(μ, σ²)与标准正态分布N(0,1)的差异,约束隐空间分布
    """
    # 重构损失:二元交叉熵损失(flatten后计算,确保维度匹配)
    # reduction='sum':按样本求和,后续可按需归一化
    recon_loss = nn.functional.binary_cross_entropy(
        x_recon.view(-1, IMAGE_SIZE),  # 生成图像展平
        x.view(-1, IMAGE_SIZE),  # 原始图像展平
        reduction='sum'
    )
    
    # KL散度损失:计算N(μ, σ²)与N(0,1)的KL散度(推导后的简化公式)
    # KL散度公式:0.5 * sum(1 + log_var - mu² - exp(log_var))
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    
    # 总损失 = 重构损失 + KL散度损失(可根据需求调整权重,此处权重均为1)
    total_loss = recon_loss + kl_loss
    return total_loss, recon_loss, kl_loss

# 5. 初始化模型、优化器
model = VAE(latent_dim=LATENT_DIM, image_size=IMAGE_SIZE).to(DEVICE)  # 实例化模型并移至指定设备
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # 使用Adam优化器(收敛更快)

# 6. 训练模型并记录误差
def train_vae(model, train_loader, optimizer, epochs, device):
    """
    VAE训练函数
    """
    # 记录训练过程中的损失值
    train_total_losses = []
    train_recon_losses = []
    train_kl_losses = []
    
    model.train()  # 切换模型到训练模式(启用Dropout、BatchNorm等训练相关层,此处模型无此类层,仅作规范)
    for epoch in range(epochs):
        epoch_total_loss = 0.0
        epoch_recon_loss = 0.0
        epoch_kl_loss = 0.0
        batch_count = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            # 数据预处理:移至指定设备 + 展平图像(28*28 -> 784)
            data = data.to(device).view(-1, IMAGE_SIZE)
            
            # 梯度清零(避免上一批次梯度累积)
            optimizer.zero_grad()
            
            # 前向传播:得到生成图像、μ、log_var
            x_recon, mu, log_var = model(data)
            
            # 计算损失
            total_loss, recon_loss, kl_loss = vae_loss(x_recon, data, mu, log_var)
            
            # 反向传播:计算梯度
            total_loss.backward()
            
            # 梯度更新:更新模型参数
            optimizer.step()
            
            # 累积批次损失
            epoch_total_loss += total_loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_kl_loss += kl_loss.item()
            batch_count += 1
        
        # 计算本轮平均损失(按批次平均,也可按样本数平均)
        avg_total_loss = epoch_total_loss / batch_count
        avg_recon_loss = epoch_recon_loss / batch_count
        avg_kl_loss = epoch_kl_loss / batch_count
        
        # 记录本轮损失
        train_total_losses.append(avg_total_loss)
        train_recon_losses.append(avg_recon_loss)
        train_kl_losses.append(avg_kl_loss)
        
        # 打印本轮训练信息
        print(f"Epoch [{epoch+1}/{epochs}], "
              f"Avg Total Loss: {avg_total_loss:.4f}, "
              f"Avg Recon Loss: {avg_recon_loss:.4f}, "
              f"Avg KL Loss: {avg_kl_loss:.4f}")
        
        # 每轮训练结束后,生成并保存一组示例图像
        generate_and_save_images(model, epoch+1, device)
    
    # 返回训练损失记录
    return train_total_losses, train_recon_losses, train_kl_losses

def generate_and_save_images(model, epoch, device):
    """
    生成图像并保存到指定路径
    思路:从标准正态分布采样隐变量z,通过解码器生成图像并保存
    """
    model.eval()  # 切换模型到评估模式
    with torch.no_grad():  # 禁用梯度计算(节省内存,加快速度)
        # 采样64个隐变量z(对应8*8网格图像),从标准正态分布N(0,1)采样
        z = torch.randn(64, LATENT_DIM).to(device)
        # 解码生成图像
        generated_images = model.decode(z).view(-1, 1, 28, 28)
        # 保存生成的图像(网格形式,文件名包含轮数)
        save_image(
            generated_images,
            os.path.join(SAVE_IMAGE_PATH, f"vae_generated_epoch_{epoch}.png"),
            nrow=8,  # 每行显示8张图像
            normalize=False  # 无需额外归一化(输出已在[0,1]范围内)
        )
    model.train()  # 切换回训练模式

# 7. 启动训练
print("开始训练VAE模型...")
total_losses, recon_losses, kl_losses = train_vae(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    epochs=EPOCHS,
    device=DEVICE
)

# 8. 绘制并展示训练误差曲线
def plot_training_losses(total_losses, recon_losses, kl_losses):
    """
    绘制训练过程中的三种损失曲线
    """
    plt.figure(figsize=(12, 4))
    
    # 绘制总损失曲线
    plt.subplot(1, 3, 1)
    plt.plot(range(1, EPOCHS+1), total_losses, 'b-', label='Total Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('VAE Total Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 绘制重构损失曲线
    plt.subplot(1, 3, 2)
    plt.plot(range(1, EPOCHS+1), recon_losses, 'r-', label='Reconstruction Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('VAE Reconstruction Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 绘制KL散度损失曲线
    plt.subplot(1, 3, 3)
    plt.plot(range(1, EPOCHS+1), kl_losses, 'g-', label='KL Divergence Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('VAE KL Divergence Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 调整子图间距并显示图像
    plt.tight_layout()
    plt.show()

# 调用函数绘制损失曲线
plot_training_losses(total_losses, recon_losses, kl_losses)

# 9. 最终保存一次生成图像(额外保存一组最终结果)
generate_and_save_images(model, "final", DEVICE)
print(f"训练完成!生成的图像已保存至:{os.path.abspath(SAVE_IMAGE_PATH)}")

代码核心说明与运行要求

1. 核心模块亮点

  • 重参数化技巧:在 reparameterize 方法中实现,通过采样标准正态分布的 ε 来间接生成隐变量 z,解决了直接采样的不可导问题,是 VAE 的核心创新点。
  • 双损失函数:VAE 损失由「重构损失」(保证生成图像与原图相似)和「KL 散度损失」(约束隐空间分布接近标准正态分布,提升生成图像的多样性)组成。
  • 详细注释:每个模块、方法都有清晰注释,便于理解 VAE 的工作流程。

2. 运行前置条件

安装所需依赖包(兼容 Python 3.8):

1
pip install torch==1.13.1 torchvision==0.14.1 matplotlib numpy

3. 运行结果说明

  1. 控制台输出:每轮训练结束后,会打印该轮的平均总损失、平均重构损失、平均 KL 散度损失,可观察到损失逐渐下降并趋于稳定。
  2. 生成图像存储:在当前目录下创建 vae_generated_images 文件夹,每轮训练结束后会保存一张 8×8 的生成图像(共 30 张 + 1 张最终结果),图像文件名包含训练轮数。
  3. 损失曲线展示:训练完成后,会弹出一个包含 3 张子图的窗口,分别展示总损失、重构损失、KL 散度损失随训练轮数的变化趋势。

4. 结果预期

  • 训练初期:生成图像模糊,仅能看到大致轮廓。
  • 训练后期(20 轮以后):生成图像清晰,能分辨出数字 0-9 的形状,且具有一定的多样性(不同数字、不同书写风格)。
  • 损失曲线:整体呈下降趋势,最终趋于平缓,无明显震荡(说明模型收敛稳定)。

总结

  1. 本实现完整覆盖了 VAE 的核心组件:编码器、解码器、重参数化技巧,且针对 MNIST 数据集做了优化。
  2. 实现了训练误差的实时打印与可视化,同时自动存储每轮的生成图像,便于观察模型训练进度。
  3. 代码兼容 Python 3.8,无需额外修改即可直接运行,生成的图像存储路径清晰,结果可复现。

PyTorch 实现条件变分自编码器 (CVAE) - MNIST 数据集

你需要一份基于 PyTorch、针对 MNIST 数据集的 CVAE 实现,要求包含重参数化技巧、详细注释、训练误差展示、生成图像存储,且支持手动指定生成数字类型(如数字 2),以下是完整可运行代码(兼容 Python 3.8、PyTorch 1.7+):

完整代码实现

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import os

# 1. 配置全局超参数(兼容Python 3.8,参数清晰可调整)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 优先使用GPU加速
BATCH_SIZE = 128  # 批次大小
EPOCHS = 30  # 训练轮数
LEARNING_RATE = 1e-3  # 学习率
LATENT_DIM = 20  # 潜在变量z的维度(隐空间维度)
IMAGE_SIZE = 784  # MNIST图像展平后的尺寸(28*28)
NUM_CLASSES = 10  # MNIST数据集的类别数(0-9共10个数字)
CONDITION_DIM = 16  # 类别条件的嵌入维度(将类别标签转为固定维度向量)
SAVE_IMAGE_PATH = "./cvae_generated_images/"  # 生成图像的存储路径

# 创建生成图像存储目录(不存在则创建,避免文件保存报错)
os.makedirs(SAVE_IMAGE_PATH, exist_ok=True)

# 2. 数据预处理与数据集加载
# 定义数据变换:转为Tensor + 归一化(将像素值从[0,255]映射到[0,1],匹配后续Sigmoid输出)
transform = transforms.Compose([
    transforms.ToTensor(),  # 转为torch.Tensor,形状为(C,H,W),值范围[0,1]
])

# 加载MNIST训练数据集(自动下载到./data目录,包含图像和对应的类别标签)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# 构建数据加载器(批量加载数据,自动打乱,num_workers=0适配Python 3.8部分环境)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# 3. 定义CVAE模型(包含编码器、解码器、重参数化技巧,核心新增类别条件融入)
class CVAE(nn.Module):
    def __init__(self, latent_dim, image_size, num_classes, condition_dim):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        self.image_size = image_size
        self.num_classes = num_classes
        self.condition_dim = condition_dim
        
        # -------------- 类别标签嵌入层:将离散类别标签转为连续向量(便于与图像/隐变量融合)--------------
        # 输入:类别索引(0-9),输出:固定维度的嵌入向量(condition_dim)
        self.label_embedding = nn.Embedding(num_classes, condition_dim)
        
        # -------------- 编码器(Encoder):输入「图像+类别条件」-> 隐空间分布参数(μ, log_var)--------------
        # 编码器输入维度 = 图像展平维度 + 类别嵌入维度(融合图像信息和类别信息)
        encoder_input_dim = image_size + condition_dim
        self.encoder = nn.Sequential(
            nn.Linear(encoder_input_dim, 512),  # 输入层:融合特征 -> 隐藏层(512)
            nn.ReLU(inplace=True),  # 激活函数:ReLU(加速训练,缓解梯度消失)
            nn.Linear(512, 256),  # 隐藏层:512 -> 256
            nn.ReLU(inplace=True),  # 激活函数:ReLU
            nn.Linear(256, 2 * latent_dim)  # 输出层:256 -> 2*latent_dim(分别对应μ和log_var)
        )
        
        # -------------- 解码器(Decoder):输入「隐变量z+类别条件」-> 生成图像--------------
        # 解码器输入维度 = 隐变量维度 + 类别嵌入维度(融合隐空间信息和类别信息)
        decoder_input_dim = latent_dim + condition_dim
        self.decoder = nn.Sequential(
            nn.Linear(decoder_input_dim, 256),  # 输入层:融合特征 -> 隐藏层(256)
            nn.ReLU(inplace=True),  # 激活函数:ReLU
            nn.Linear(256, 512),  # 隐藏层:256 -> 512
            nn.ReLU(inplace=True),  # 激活函数:ReLU
            nn.Linear(512, image_size),  # 输出层:512 -> 784(与输入图像尺寸一致)
            nn.Sigmoid()  # 激活函数:Sigmoid(将输出映射到[0,1],匹配图像像素分布)
        )
    
    def encode(self, x, labels):
        """
        编码器前向传播:输入「图像+类别标签」-> 隐空间分布参数(μ, log_var)
        核心:先将类别标签嵌入,再与图像展平数据拼接,作为编码器输入
        """
        # 1. 类别标签嵌入:离散标签 -> 连续向量(形状:[batch_size, condition_dim])
        label_emb = self.label_embedding(labels)
        # 2. 图像展平:[batch_size, 1, 28, 28] -> [batch_size, 784]
        x_flat = x.view(-1, self.image_size)
        # 3. 拼接图像特征和类别嵌入特征(维度在第1维拼接,保持批次维度不变)
        encoder_input = torch.cat([x_flat, label_emb], dim=1)
        # 4. 编码器前向传播,得到输出
        h = self.encoder(encoder_input)
        # 5. 将输出切分为均值μ和对数方差log_var(各占latent_dim维度)
        mu, log_var = torch.chunk(h, 2, dim=1)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        """
        重参数化技巧(核心,与VAE一致):解决隐变量z不可导的问题
        思路:不直接从N(μ, σ²)采样z,而是从N(0,1)采样ε,通过z=μ+ε*σ得到z
        其中σ=exp(0.5*log_var),这样只有μ和log_var参与梯度传播,ε为常数不参与求导
        """
        std = torch.exp(0.5 * log_var)  # 计算标准差σ:exp(0.5*log_var)(避免方差为负)
        eps = torch.randn_like(std)  # 从标准正态分布N(0,1)采样ε,形状与std一致
        z = mu + eps * std  # 计算最终的可导隐变量z
        return z
    
    def decode(self, z, labels):
        """
        解码器前向传播:输入「隐变量z+类别标签」-> 生成图像
        核心:先将类别标签嵌入,再与隐变量z拼接,作为解码器输入
        """
        # 1. 类别标签嵌入:离散标签 -> 连续向量(形状:[batch_size, condition_dim])
        label_emb = self.label_embedding(labels)
        # 2. 拼接隐变量z和类别嵌入特征(维度在第1维拼接,保持批次维度不变)
        decoder_input = torch.cat([z, label_emb], dim=1)
        # 3. 解码器前向传播,得到生成图像
        return self.decoder(decoder_input)
    
    def forward(self, x, labels):
        """
        CVAE整体前向传播(端到端)
        """
        # 1. 编码:得到融入类别信息的隐空间分布参数
        mu, log_var = self.encode(x, labels)
        # 2. 重参数化:得到可导的隐变量z
        z = self.reparameterize(mu, log_var)
        # 3. 解码:得到融入类别信息的生成图像
        x_recon = self.decode(z, labels)
        # 返回生成图像、μ、log_var(μ和log_var用于计算损失函数)
        return x_recon, mu, log_var

# 4. 定义CVAE损失函数(与VAE一致,包含重构损失和KL散度损失)
def cvae_loss(x_recon, x, mu, log_var):
    """
    CVAE损失函数 = 重构损失(Reconstruction Loss) + KL散度损失(KL Divergence Loss)
    1. 重构损失:衡量生成图像与原始图像的差异,使用二元交叉熵(BCE)(匹配Sigmoid输出)
    2. KL散度损失:衡量隐空间分布N(μ, σ²)与标准正态分布N(0,1)的差异,约束隐空间分布
    """
    # 重构损失:二元交叉熵损失(展平后计算,确保维度匹配,reduction='sum'按样本求和)
    recon_loss = nn.functional.binary_cross_entropy(
        x_recon.view(-1, IMAGE_SIZE),  # 生成图像展平:[batch_size, 784]
        x.view(-1, IMAGE_SIZE),  # 原始图像展平:[batch_size, 784]
        reduction='sum'
    )
    
    # KL散度损失:计算N(μ, σ²)与N(0,1)的KL散度(推导后的简化公式,避免数值不稳定)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    
    # 总损失 = 重构损失 + KL散度损失(权重均为1,可按需调整平衡生成质量与多样性)
    total_loss = recon_loss + kl_loss
    return total_loss, recon_loss, kl_loss

# 5. 初始化模型、优化器
model = CVAE(
    latent_dim=LATENT_DIM,
    image_size=IMAGE_SIZE,
    num_classes=NUM_CLASSES,
    condition_dim=CONDITION_DIM
).to(DEVICE)  # 实例化CVAE模型并移至指定设备(GPU/CPU)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # 使用Adam优化器(收敛快,稳定性好)

# 6. 定义训练函数(记录训练误差,每轮生成并保存示例图像)
def train_cvae(model, train_loader, optimizer, epochs, device):
    """
    CVAE训练函数:完成模型训练,记录每轮损失,每轮生成示例图像
    """
    # 初始化损失记录列表,用于后续可视化训练误差
    train_total_losses = []
    train_recon_losses = []
    train_kl_losses = []
    
    model.train()  # 切换模型到训练模式(启用BatchNorm/Dropout等训练相关层,此处无但保持规范)
    for epoch in range(epochs):
        # 初始化每轮的损失累积变量
        epoch_total_loss = 0.0
        epoch_recon_loss = 0.0
        epoch_kl_loss = 0.0
        batch_count = 0
        
        for batch_idx, (data, labels) in enumerate(train_loader):
            # 数据预处理:移至指定设备,保持标签为长整型(适配Embedding层输入)
            data = data.to(device)
            labels = labels.to(device, dtype=torch.long)
            
            # 梯度清零:避免上一批次梯度累积影响当前批次训练
            optimizer.zero_grad()
            
            # 前向传播:得到生成图像、μ、log_var
            x_recon, mu, log_var = model(data, labels)
            
            # 计算损失:总损失、重构损失、KL散度损失
            total_loss, recon_loss, kl_loss = cvae_loss(x_recon, data, mu, log_var)
            
            # 反向传播:计算模型参数梯度
            total_loss.backward()
            
            # 梯度更新:更新模型参数
            optimizer.step()
            
            # 累积批次损失(转换为numpy值,避免占用GPU内存)
            epoch_total_loss += total_loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_kl_loss += kl_loss.item()
            batch_count += 1
        
        # 计算本轮平均损失(按批次平均,便于跨轮次对比)
        avg_total_loss = epoch_total_loss / batch_count
        avg_recon_loss = epoch_recon_loss / batch_count
        avg_kl_loss = epoch_kl_loss / batch_count
        
        # 记录本轮平均损失,用于后续可视化
        train_total_losses.append(avg_total_loss)
        train_recon_losses.append(avg_recon_loss)
        train_kl_losses.append(avg_kl_loss)
        
        # 打印本轮训练信息(控制台输出,观察训练进度)
        print(f"Epoch [{epoch+1}/{epochs}], "
              f"Avg Total Loss: {avg_total_loss:.4f}, "
              f"Avg Recon Loss: {avg_recon_loss:.4f}, "
              f"Avg KL Loss: {avg_kl_loss:.4f}")
        
        # 每轮训练结束后,生成并保存一组示例图像(默认生成0-9每个数字各8张)
        generate_and_save_images(model, epoch+1, device, target_digits=None)
    
    # 返回训练损失记录,用于后续可视化
    return train_total_losses, train_recon_losses, train_kl_losses

# 7. 定义图像生成与保存函数(支持手动指定生成数字类型,如生成数字2)
def generate_and_save_images(model, epoch, device, target_digits=None, n_samples_per_digit=8):
    """
    生成图像并保存到指定路径,支持手动指定生成的数字类型
    参数说明:
        model: 训练好的CVAE模型
        epoch: 训练轮数(用于文件名命名)
        device: 计算设备(GPU/CPU)
        target_digits: 手动指定的生成数字列表,None则生成0-9所有数字
        n_samples_per_digit: 每个数字生成的样本数
    """
    model.eval()  # 切换模型到评估模式(禁用BatchNorm/Dropout,固定模型参数)
    with torch.no_grad():  # 禁用梯度计算(节省内存,加快生成速度,无需求导)
        # 1. 确定要生成的数字类别
        if target_digits is None:
            target_digits = list(range(NUM_CLASSES))  # 默认生成0-9所有数字
        else:
            # 验证输入的目标数字是否合法(0-9之间的整数)
            target_digits = [d for d in target_digits if 0 <= d < NUM_CLASSES]
            if not target_digits:
                raise ValueError("目标数字必须是0-9之间的整数")
        
        # 2. 构造类别标签(每个数字生成n_samples_per_digit个样本)
        labels_list = []
        for digit in target_digits:
            labels_list.extend([digit] * n_samples_per_digit)
        labels = torch.tensor(labels_list, dtype=torch.long).to(device)
        
        # 3. 从标准正态分布采样隐变量z(形状与标签匹配)
        z = torch.randn(len(labels), LATENT_DIM).to(device)
        
        # 4. 解码生成图像(融入类别标签信息,得到指定数字的生成图像)
        generated_images = model.decode(z, labels).view(-1, 1, 28, 28)
        
        # 5. 保存生成的图像(网格形式,文件名包含轮数/目标数字)
        if target_digits is not None and len(target_digits) == 1:
            # 单个数字生成:文件名标注具体数字
            filename = f"cvae_generated_epoch_{epoch}_digit_{target_digits[0]}.png"
            nrow = n_samples_per_digit  # 单行显示所有该数字的样本
        else:
            # 多个数字生成:文件名标注轮数
            filename = f"cvae_generated_epoch_{epoch}.png"
            nrow = n_samples_per_digit  # 每行显示单个数字的所有样本
        
        save_image(
            generated_images,
            os.path.join(SAVE_IMAGE_PATH, filename),
            nrow=nrow,  # 网格每行显示的图像数
            normalize=False  # 无需额外归一化(输出已在[0,1]范围内)
        )
    
    model.train()  # 切换回训练模式(不影响后续训练,保持规范)

# 8. 启动CVAE模型训练
print("开始训练CVAE模型...")
total_losses, recon_losses, kl_losses = train_cvae(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    epochs=EPOCHS,
    device=DEVICE
)

# 9. 绘制并展示训练误差曲线(直观观察训练过程中的损失变化)
def plot_training_losses(total_losses, recon_losses, kl_losses):
    """
    绘制训练过程中的三种损失曲线:总损失、重构损失、KL散度损失
    """
    plt.figure(figsize=(12, 4))
    
    # 绘制总损失曲线
    plt.subplot(1, 3, 1)
    plt.plot(range(1, EPOCHS+1), total_losses, 'b-', label='Total Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('CVAE Total Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)  # 添加网格,提升可读性
    
    # 绘制重构损失曲线
    plt.subplot(1, 3, 2)
    plt.plot(range(1, EPOCHS+1), recon_losses, 'r-', label='Reconstruction Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('CVAE Reconstruction Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 绘制KL散度损失曲线
    plt.subplot(1, 3, 3)
    plt.plot(range(1, EPOCHS+1), kl_losses, 'g-', label='KL Divergence Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('CVAE KL Divergence Training Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # 调整子图间距,避免重叠
    plt.tight_layout()
    # 显示损失曲线窗口
    plt.show()

# 调用函数绘制训练误差曲线
plot_training_losses(total_losses, recon_losses, kl_losses)

# 10. 手动指定生成数字类型(示例:生成数字2,可修改为其他数字或多个数字)
print("开始生成指定数字(数字2)的图像...")
generate_and_save_images(
    model=model,
    epoch="final_target_2",
    device=DEVICE,
    target_digits=[2],  # 手动指定生成数字2,可修改为[0,5,9]等多个数字
    n_samples_per_digit=64  # 生成64张数字2的图像,按8×8网格保存
)

# 额外示例:生成数字5和7的图像(可注释掉,按需运行)
# generate_and_save_images(
#     model=model,
#     epoch="final_target_5_7",
#     device=DEVICE,
#     target_digits=[5,7],
#     n_samples_per_digit=32
# )

print(f"训练与生成完成!所有图像已保存至:{os.path.abspath(SAVE_IMAGE_PATH)}")

代码核心说明与运行要求

1. 核心模块亮点(CVAE 与 VAE 的关键区别)

  • 类别条件融入:新增 nn.Embedding 层将离散数字标签转为连续向量,分别在 ** 编码器输入(图像 + 类别)解码器输入(隐变量 z + 类别)** 中拼接,使模型学习到类别相关的生成能力,支持手动指定生成数字。
  • 重参数化技巧:与 VAE 保持一致,在 reparameterize 方法中实现,解决隐变量 z 不可导的问题,保证模型可通过梯度下降训练。
  • 指定数字生成generate_and_save_images 函数支持 target_digits 参数,传入 [2] 即可生成纯数字 2 的图像,满足手动指定需求。
  • 详细注释:每个模块、方法都有清晰注释,便于理解 CVAE 的工作流程和核心创新点。

2. 运行前置条件(兼容 Python 3.8)

安装所需依赖包,推荐指定版本保证兼容性:

bash

运行

1
pip install torch==1.13.1 torchvision==0.14.1 matplotlib numpy

3. 运行结果说明

  1. 控制台输出:每轮训练结束后,打印该轮的平均总损失、平均重构损失、平均 KL 散度损失,可观察到损失逐渐下降并趋于稳定,说明模型收敛。
  2. 默认生成图像:训练过程中,每轮会在 cvae_generated_images 文件夹中保存 0-9 所有数字的生成图像(每数字 8 张),文件名包含训练轮数。
  3. 指定数字生成图像:训练完成后,会自动生成 64 张数字 2 的图像(8×8 网格),文件名为 cvae_generated_epoch_final_target_2_digit_2.png,同时保存最终轮的全数字图像。
  4. 损失曲线展示:训练完成后,弹出包含 3 张子图的窗口,分别展示总损失、重构损失、KL 散度损失随训练轮数的变化趋势,直观观察模型训练效果。

4. 结果预期

  • 训练初期:生成图像模糊,数字轮廓不清晰,但已能区分不同类别(因融入了类别条件)。
  • 训练后期(20 轮以后):生成图像清晰,指定数字 2 的图像具有明显的 “2” 形状,且具有一定的多样性(不同书写风格、粗细变化)。
  • 损失曲线:整体呈下降趋势,最终趋于平缓,无明显震荡,说明模型训练稳定,重构能力和隐空间约束能力均已收敛。

5. 自定义修改指南

  • 修改生成数字:将 target_digits=[2] 改为 [0][9] 即可生成其他单个数字,改为 [1,3,6] 可生成多个数字。
  • 调整生成数量:修改 n_samples_per_digit 参数,如改为 32 则每个数字生成 32 张图像。
  • 调整超参数:修改 LATENT_DIM(隐空间维度)、EPOCHS(训练轮数)可优化生成图像质量,隐空间维度越大,生成多样性越强(但训练难度增加)。

总结

  1. 本实现完整覆盖了 CVAE 的核心组件:类别嵌入层、带条件的编码器 / 解码器、重参数化技巧,且针对 MNIST 数据集做了优化,兼容 Python 3.8。
  2. 实现了训练误差的实时打印与可视化,同时支持默认全数字生成和手动指定数字生成,生成图像存储路径清晰,结果可复现。
  3. 代码结构清晰,注释详细,便于后续扩展(如迁移到其他数据集、添加卷积层提升生成质量)。
comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计