基于 PyTorch 的 BigGAN 模型

这个项目是对 BigGAN(边界平衡生成对抗网络)的 PyTorch 实现,BigGAN 是一种强大的图像生成模型。它基于 Andrew Brock、Jeff Donahue 和 Karen Simonyan 的研究成果,并由 Andy Brock 和 Alex Andonian 编写代码。

使用方法

环境准备

  • PyTorch 1.0.1 或更高版本
  • tqdm、numpy、scipy 和 h5py
  • ImageNet 训练数据集

数据预处理

  1. (可选)将 ImageNet 训练数据集转换为 HDF5 格式,以加快 I/O 速度。
  2. 使用 sh scripts/utils/prepare_data.sh 计算 FID(Fréchet Inception Distance)所需的 Inception 时刻。脚本默认将 ImageNet 训练数据集的位置设置为 data 文件夹,并将其转换为 128x128 像素分辨率的缓存 HDF5 格式。

训练模型

scripts 文件夹中包含多个 bash 脚本,用于训练不同批处理大小的 BigGAN 模型。由于代码假设您无法访问完整的 TPU 吊舱,因此它使用梯度累积(对多个小批次的平均梯度进行更新)来模拟大批处理训练。

总结

该项目提供了 BigGAN 模型的 PyTorch 实现,并包含用于数据预处理和模型训练的脚本。它为图像生成任务提供了一个强大的工具,并允许用户探索不同的训练配置。