Skip to content

Files

This branch is 10 commits ahead of, 837 commits behind release/2.4.

cycle_gan

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020
Feb 28, 2020

运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照安装文档中的说明更新PaddlePaddle安装版本。

代码结构

├── data_reader.py  # 读取、处理数据。
├── layers.py   # 封装定义基础的layers。
├── model.py   # 定义基础生成网络和判别网络。
├── trainer.py   # 构造loss和训练网络。
├── train.py     # 训练脚本。
└── infer.py    # 预测脚本。

简介

TODO

数据准备

本教程使用 horse2zebra 数据集 来进行模型的训练测试工作,该数据集是用关键字'wild horse'和'zebra'过滤ImageNet数据集并下载得到的。

horse2zebra训练集包含1069张野马图片,1336张斑马图片。测试集包含121张野马图片和141张斑马图片。

数据下载处理完毕后,并组织为以下路径结构:

data
|-- horse2zebra
|   |-- testA
|   |-- testA.txt
|   |-- testB
|   |-- testB.txt
|   |-- trainA
|   |-- trainA.txt
|   |-- trainB
|   `-- trainB.txt

以上数据文件中,data文件夹需要放在训练脚本train.py同级目录下。testA为存放野马测试图片的文件夹,testB为存放斑马测试图片的文件夹,testA.txttestB.txt分别为野马和斑马测试图片路径列表文件,格式如下:

testA/n02381460_9243.jpg
testA/n02381460_9244.jpg
testA/n02381460_9245.jpg

训练数据组织方式与测试数据相同。

模型训练与预测

训练

在GPU单卡上训练:

env CUDA_VISIBLE_DEVICES=0 python train.py

执行python train.py --help可查看更多使用方式和参数详细说明。

图1为训练152轮的训练损失示意图,其中横坐标轴为训练轮数,纵轴为在训练集上的损失。其中,'g_A_loss','g_B_loss','d_A_loss'和'd_B_loss'分别为生成器A、生成器B、判别器A和判别器B的训练损失。


图 1

预测

执行以下命令读取多张图片进行预测:

env CUDA_VISIBLE_DEVICES=0 python infer.py \
    --init_model="output/checkpoints/1" --input="./data/horse2zebra/trainA/*" \
    --input_style A --output="./output"

训练150轮的模型预测效果如图2和图3所示:


图 2


图 3

在本文示例中,均可通过修改CUDA_VISIBLE_DEVICES改变使用的显卡号。