文章插图
作者 | 李秋键
责编 | Carol
封图 | CSDN 下载自视觉中国
近几年来GAN图像生成应用越来越广泛,其中主要得益于GAN 在博弈下不断提高建模能力,最终实现以假乱真的图像生成 。GAN 由两个神经网络组成,一个生成器和一个判别器组成,其中生成器试图产生欺骗判别器的真实样本,而判别器试图区分真实样本和生成样本 。这种对抗博弈下使得生成器和判别器不断提高性能,在达到纳什平衡后生成器可以实现以假乱真的输出 。
其中GAN 在图像生成应用最为突出,当然在计算机视觉中还有许多其他应用,如图像绘画,图像标注,物体检测和语义分割 。在自然语言处理中应用 GAN 的研究也是一种增长趋势,如文本建模,对话生成,问答和机器翻译 。然而,在 NLP 任务中训练 GAN 更加困难并且需要更多技术,这也使其成为具有挑战性但有趣的研究领域 。
而今天我们就将利用CC-GAN训练将侧脸生成正脸的模型,其中迭代20次结果如下:
文章插图
文章插图
实验前的准备首先我们使用的Python版本是3.6.5所用到的模块如下:tensorflow用来模型训练和网络层建立;numpy模块用来处理矩阵运算;OpenCV用来读取图片和图像处理;os模块用来读取数据集等本地文件操作 。
文章插图
素材准备其中准备训练的不同角度人脸图片放入以下文件夹作为训练集,如下图可见:
文章插图
测试集图片如下可见:
文章插图
文章插图
模型搭建原始GAN(GAN 简介与代码实战)在理论上可以完全逼近真实数据,但它的可控性不强(生成小图片还行,生成的大图片可能是不合逻辑的),因此需要对gan加一些约束,能生成我们想要的图片,这个时候,CGAN就横空出世了 。其中CCGAN整体模型结构如下:
文章插图
1、网络结构参数的搭建:【用 Python 可以实现侧脸转正脸?我也要试一下】首先是定义标准化、激活函数和池化层等函数:Batch_Norm是对其进行规整,是为了防止同一个batch间的梯度相互抵消 。其将不同batch规整到同一个均值0和方差1 。InstanceNorm是将输入在深度方向上减去均值除以标准差,可以加快网络的训练速度 。
def instance_norm(x, scope='instance_norm'):return tf_contrib.layers.instance_norm(x, epsilon=1e-05, center=True, scale=True, scope=scope)def batch_norm(x, scope='batch_norm'):return tf_contrib.layers.batch_norm(x, decay=0.9, epsilon=1e-05, center=True, scale=True, scope=scope)def flatten(x) :return tf.layers.flatten(x)def lrelu(x, alpha=0.2):return tf.nn.leaky_relu(x, alpha)def relu(x):return tf.nn.relu(x)def global_avg_pooling(x):gap = tf.reduce_mean(x, axis=[1, 2], keepdims=True)return gapdef resblock(x_init, c, scope='resblock'):with tf.variable_scope(scope):with tf.variable_scope('res1'):x = slim.conv2d(x_init, c, kernel_size=[3,3], stride=1, activation_fn = None)x = batch_norm(x)x = relu(x)with tf.variable_scope('res2'):x = slim.conv2d(x, c, kernel_size=[3,3], stride=1, activation_fn = None)x = batch_norm(x)return x + x_init
然后是卷积层的定义:def conv(x, c):x1 = slim.conv2d(x, c, kernel_size=[5,5], stride=2, padding = 'SAME', activation_fn=relu)# print(x1.shape)x2 = slim.conv2d(x, c, kernel_size=[3,3], stride=2, padding = 'SAME', activation_fn=relu)# print(x2.shape)x3 = slim.conv2d(x, c, kernel_size=[1,1], stride=2, padding = 'SAME', activation_fn=relu)# print(x3.shape)out = tf.concat([x1, x2, x3],axis = 3)out = slim.conv2d(out, c, kernel_size=[1,1], stride=1, padding = 'SAME', activation_fn=None)# print(out.shape)return out
生成器函数定义:def mixgenerator(x_init, c, org_pose, trg_pose): reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0with tf.variable_scope('generator', reuse = reuse):org_pose = tf.cast(tf.reshape(org_pose, shape=[-1, 1, 1, org_pose.shape[-1]]), tf.float32)print(org_pose.shape)org_pose = tf.tile(org_pose, [1, x_init.shape[1], x_init.shape[2], 1])print(org_pose.shape)x = tf.concat([x_init, org_pose], axis=-1)print(x.shape)x = conv(x, c)x = batch_norm(x, scope='bat_norm_1')x = relu(x)#64print('----------------')print(x.shape)x = conv(x, c*2)x = batch_norm(x, scope='bat_norm_2')x = relu(x)#32print(x.shape)x = conv(x, c*4)x = batch_norm(x, scope='bat_norm_3')x = relu(x)#16print(x.shape)f_org = xx = conv(x, c*8)x = batch_norm(x, scope='bat_norm_4')x = relu(x)#8print(x.shape)x = conv(x, c*8)x = batch_norm(x, scope='bat_norm_5')x = relu(x)#4print(x.shape)for i in range(6):x = resblock(x, c*8, scope = str(i)+"_resblock")trg_pose = tf.cast(tf.reshape(trg_pose, shape=[-1, 1, 1, trg_pose.shape[-1]]), tf.float32)print(trg_pose.shape)trg_pose = tf.tile(trg_pose, [1, x.shape[1], x.shape[2], 1])print(trg_pose.shape)x = tf.concat([x, trg_pose], axis=-1)print(x.shape)x = slim.conv2d_transpose(x, c*8, kernel_size=[3, 3], stride=2, activation_fn=None)x = batch_norm(x, scope='bat_norm_8')x = relu(x)#8print(x.shape)x = slim.conv2d_transpose(x, c*4, kernel_size=[3, 3], stride=2, activation_fn=None)x = batch_norm(x, scope='bat_norm_9')x = relu(x)#16print(x.shape)f_trg =xx = slim.conv2d_transpose(x, c*2, kernel_size=[3, 3], stride=2, activation_fn=None)x = batch_norm(x, scope='bat_norm_10')x = relu(x)#32print(x.shape)x = slim.conv2d_transpose(x, c, kernel_size=[3, 3], stride=2, activation_fn=None)x = batch_norm(x, scope='bat_norm_11')x = relu(x)#64print(x.shape)z = slim.conv2d_transpose(x, 3 , kernel_size=[3,3], stride=2, activation_fn = tf.nn.tanh)f = tf.concat([f_org, f_trg], axis=-1)print(f.shape)return z, f
推荐阅读
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- SpringBoot常用属性配置
- 2020年适用于任何团队的5大数据库文档工具
- C,Java和Python之间的性能比较
- Linux操作系统中常用调度算法
- Python通过MySQLdb访问操作MySQL数据库
- Bash技巧:介绍一个可以增删改查键值对格式配置文件的Shell脚本
- 恢复AD用户误删,给你3种方案
- 使用 Mailmerge 发送定制邮件
- 在Python中使用Torchmoji将文本转换为表情符号
- 适合数据库初级人员 常用的sql语句集合