Appearance
卷积层与池化层:从 shape 推导理解 fc1 输入维度
针对 MNIST 训练过程中对卷积层与池化层理解不够透彻的问题,整理如下。
核心问题:
为什么全连接层
fc1的输入维度是256,而不是原始图像的784?
答案要从卷积层和池化层对特征图尺寸的改变开始推导。
1. 手算 shape:28×28 如何变成 256
假设输入是一张 28×28 的灰度图,网络结构为:
Conv(5×5)MaxPool(2×2)Conv(5×5)MaxPool(2×2)Flatten
在没有 padding、stride 为 1 的情况下,卷积输出尺寸公式为:
其中:
表示输入的高或宽 表示卷积核大小 表示输出的高或宽
所以:
text
28×28
→ Conv(5×5)
→ 24×24因为:
接着经过 MaxPool(2×2),长宽各缩小一半:
text
24×24
→ MaxPool(2×2)
→ 12×12也就是:
再经过第二个 Conv(5×5):
text
12×12
→ Conv(5×5)
→ 8×8因为:
再经过第二个 MaxPool(2×2):
text
8×8
→ MaxPool(2×2)
→ 4×4也就是:
如果第二个卷积层输出 16 张特征图,那么最终不是只有一个 4×4,而是:
text
16 张 4×4 特征图Flatten 之后的向量长度为:
所以:
完整流程:
text
28×28
→ Conv(5×5)
→ 24×24
→ MaxPool(2×2)
→ 12×12
→ Conv(5×5)
→ 8×8
→ MaxPool(2×2)
→ 4×4
→ Flatten
→ 16 × 4 × 4 = 256因此 fc1 应该写成接收 256 维输入,而不是 784。
这个维度只适用于直接把原始图像展平后送入全连接层的情况;这里图像已经先经过卷积和池化,尺寸发生了变化。
2. Conv:卷积层是特征提取器
可以把卷积核想象成一个小放大镜,在图像上从左到右、从上到下滑动。
每次只看图像中的一小块区域,判断这块区域里有没有它要找的特征。
示例:
text
原图 5×5 卷积核 3×3 输出 3×3
1 0 1 0 1 1 0 1 ? ? ?
0 1 0 1 0 × 0 1 0 = ? ? ?
1 0 1 0 1 1 0 1 ? ? ?
0 1 0 1 0
1 0 1 0 1每次滑动时,把卷积核覆盖到的区域和卷积核对应位置相乘,再求和,得到输出特征图上的一个数。
左上角位置:
text
图像局部区域: 卷积核:
1 0 1 1 0 1
0 1 0 × 0 1 0
1 0 1 1 0 1计算过程可以写成:
这个 5 就是输出特征图左上角的值。
关键理解:
- 一个卷积核就是一种“特征检测器”,可以检测边缘、纹理、形状等局部模式。
Conv2d(1, 6, 5)中的6表示有6个不同卷积核,也就是同时学习6种特征。- 卷积核的权重是训练出来的,网络会自己学会应该检测什么特征。
- 如果不加 padding,卷积后尺寸会变小,因为边缘位置放不下完整的卷积核。
例如:
text
28×28 经过 5×5 卷积后变成 24×24对应的计算是:
3. MaxPool:最大池化是信息压缩器
最大池化没有可训练参数。
它把特征图分成若干小块,每块只保留最大值,丢掉其他值。
示例:
text
输入 4×4 MaxPool(2×2) 输出 2×2
1 3 2 4 6 4
5 6 1 2 → 每个 2×2 取最大值 → 7 8
7 2 8 3
4 1 5 6四个区域分别是:
text
左上:[1, 3, 5, 6],最大值是 6
右上:[2, 4, 1, 2],最大值是 4
左下:[7, 2, 4, 1],最大值是 7
右下:[8, 3, 5, 6],最大值是 8所以输出为:
text
6 4
7 8关键理解:
- MaxPool 没有可训练参数,只负责压缩信息。
- 它能减小特征图尺寸,降低后续计算量。
- 它会保留局部区域中最显著的特征信号。
- 它能带来轻微的平移不变性。
MaxPool(2×2)通常让长宽各缩小一半,例如24×24 → 12×12。
4. MNIST 网络中的完整数据流
以经典 MNIST 卷积网络为例:
text
[输入图像 28×28]
│
▼
Conv1(5×5)
学习识别边缘、笔画方向等底层特征
输出:6 张 24×24 特征图
│
▼
MaxPool(2×2)
压缩并保留最强特征信号
输出:6 张 12×12 特征图
│
▼
Conv2(5×5)
学习识别更复杂的组合特征,例如弯曲、交叉
输出:16 张 8×8 特征图
│
▼
MaxPool(2×2)
再次压缩
输出:16 张 4×4 特征图
│
▼
Flatten
把 16×4×4 = 256 个数拉成一列
│
▼
全连接层
根据这 256 个特征值判断数字属于 0-9 中的哪一类其中 Flatten 的维度计算为:
5. 总结
| 层 | 作用 | 是否有参数 | 尺寸变化 |
|---|---|---|---|
| Conv | 提取特征,检测局部模式 | 有,卷积核权重可训练 | 通常变小,取决于卷积核、padding、stride |
| MaxPool | 压缩信息,保留局部最强信号 | 无 | 通常变小,例如长宽各减半 |
最终结论:
text
28×28
→ Conv(5×5)
→ 24×24
→ MaxPool(2×2)
→ 12×12
→ Conv(5×5)
→ 8×8
→ MaxPool(2×2)
→ 4×4
→ Flatten
→ 16 × 4 × 4 = 256用数学形式表示为:
最终 Flatten 维度为:
所以 fc1 的输入必须是 256,不能是 784。
代码实例:
Python
import torch
import torchvision
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# -------- 数据准备 --------transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=200, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=200, shuffle=False)
# -------- 模型定义 --------class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 6, 5)
self.conv2 = torch.nn.Conv2d(6, 16, 5)
self.fc1 = torch.nn.Linear(256, 128)
self.fc2 = torch.nn.Linear(128, 64)
self.fc3 = torch.nn.Linear(64, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), 2)
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001) #使用Adam优化器
# -------- 记录数据用的列表 -------- ← 新增
train_losses = [] # 每个 epoch 的平均训练 losstest_accuracies = [] # 每个 epoch 结束后的测试准确率
test_losses = [] # 每个 epoch 结束后的测试 loss
# -------- 训练循环 --------epochs = 10
print("Start Training...")
for epoch in range(epochs):
net.train()
running_loss = 0.0
for i, (inputs, labels) in enumerate(dataloader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
# ← 每个 epoch 结束,记录训练 loss epoch_train_loss = running_loss / len(dataloader)
train_losses.append(epoch_train_loss)
# ← 每个 epoch 结束,跑一次测试
net.eval()
correct, total, t_loss = 0, 0, 0.0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = net(inputs)
t_loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
epoch_test_loss = t_loss / len(
testloader)
epoch_test_acc = 100 * correct / total
test_losses.append(epoch_test_loss)
test_accuracies.append(epoch_test_acc)
print(f"Epoch [{epoch+1}/{epochs}] "
f"Train Loss: {epoch_train_loss:.4f} "
f"Test Loss: {epoch_test_loss:.4f} "
f"Test Acc: {epoch_test_acc:.2f}%")
print("Finished Training")
# -------- 可视化 -------- ← 新增
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
# 左图:Loss 曲线
ax1.plot(range(1, epochs+1), train_losses, 'b-o', label='Train Loss')
ax1.plot(range(1, epochs+1), test_losses, 'r-o', label='Test Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training & Test Loss')
ax1.legend()
ax1.grid(True)
# 右图:准确率曲线
ax2.plot(range(1, epochs+
1), test_accuracies, 'g-o', label='Test Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Test Accuracy')
ax2.legend()
ax2.grid(True)
ax2.set_ylim([0, 100])
plt.tight_layout()
plt.savefig('mnist_training_curves.png', dpi=150) # 保存图片
plt.show()
print("曲线已保存为 mnist_training_curves.png")