|
|
51CTO旗下网站
|
|
移动端

一文读懂自动编码器的前世今生

变分自动编码器(VAE)可以说是最实用的自动编码器,但是在讨论VAE之前,还必须了解一下用于数据压缩或去噪的传统自动编码器。

作者:读芯术来源:头条科技|2019-05-22 17:34

代码详解:一文读懂自动编码器的前世今生

变分自动编码器(VAE)可以说是最实用的自动编码器,但是在讨论VAE之前,还必须了解一下用于数据压缩或去噪的传统自动编码器。

变分自动编码器的厉害之处

假设你正在开发一款开放性世界端游,且游戏里的景观设定相当复杂。

你聘用了一个图形设计团队来制作一些植物和树木以装饰游戏世界,但是将这些装饰植物放进游戏中之后,你发现它们看起来很不自然,因为同种植物的外观看起来一模一样,这时你该怎么办呢?

首先,你可能会建议使用一些参数化来尝试随机地改变图像,但是多少改变才足够呢?又需要多大的改变呢?还有一个重要的问题:实现这种改变的计算强度如何?

这是使用变分自动编码器的理想情况。我们可以训练一个神经网络,使其学习植物的潜在特征,每当我们将一个植物放入游戏世界中,就可以从“已学习”的特征中随机抽取一个样本,生成独特的植物。事实上,很多开放性世界游戏正在通过这种方法构建他们的游戏世界设定。

再看一个更图形化的例子。假设我们是一个建筑师,想要为任意形状的建筑生成平面图。可以让一个自动编码器网络基于任意建筑形状来学习数据生成分布,它将从数据生成分布中提取样本来生成一个平面图。详见下方的动画。

代码详解:一文读懂自动编码器的前世今生

对于设计师来说,这些技术的潜力无疑是最突出的。

再假设我们为一个时装公司工作,需要设计一种新的服装风格,可以基于“时尚”的服装来训练自动编码器,使其学习时装的数据生成分布。随后,从这个低维潜在分布中提取样本,并以此来创造新的风格。

在该节中我们将研究fashion MNIST数据集。

自动编码器

传统自动编码器

自动编码器其实就是非常简单的神经结构。它们大体上是一种压缩形式,类似于使用MP3压缩音频文件或使用jpeg压缩图像文件。

代码详解:一文读懂自动编码器的前世今生

自动编码器与主成分分析(PCA)密切相关。事实上,如果自动编码器使用的激活函数在每一层中都是线性的,那么瓶颈处存在的潜在变量(网络中最小的层,即代码)将直接对应(PCA/主成分分析)的主要组件。通常,自动编码器中使用的激活函数是非线性的,典型的激活函数是ReLU(整流线性函数)和sigmoid/S函数。

网络背后的数学原理理解起来相对容易。从本质上看,可以把网络分成两个部分:编码器和解码器。

代码详解:一文读懂自动编码器的前世今生

编码器函数用ϕ表示,该函数将原始数据X映射到潜在空间F中(潜在空间F位于瓶颈处)。解码器函数用ψ表示,该函数将瓶颈处的潜在空间F映射到输出函数。此处的输出函数与输入函数相同。因此,我们基本上是在一些概括的非线性压缩之后重建原始图像。

编码网络可以用激活函数传递的标准神经网络函数表示,其中z是潜在维度。

代码详解:一文读懂自动编码器的前世今生

相似地,解码网络可以用相同的方式表示,但需要使用不同的权重、偏差和潜在的激活函数。

代码详解:一文读懂自动编码器的前世今生

随后就可以利用这些网络函数来编写损失函数,我们会利用这个损失函数通过标准的反向传播程序来训练神经网络。

代码详解:一文读懂自动编码器的前世今生

由于输入和输出的是相同的图像,神经网络的训练过程并不是监督学习或无监督学习,我们通常将这个过程称为自我监督学习。自动编码器的目的是选择编码器和解码器函数,这样就可以用最少的信息来编码图像,使其可以在另一侧重新生成。

如果在瓶颈层中使用的节点太少,重新创建图像的能力将受到限制,导致重新生成的图像模糊或者和原图像差别很大。如果使用的节点太多,那么就没必要压缩了。

压缩背后的理论其实很简单,例如,每当你在Netflix下载某些内容时,发送给你的数据都会被压缩。一旦这个内容传输到电脑上就会通解压算法在电脑屏幕显示出来。这类似于zip文件的运行方式,只是这里说的压缩是在后台通过流处理算法完成的。

去噪自动编码器

有几种其它类型的自动编码器。其中最常用的是去噪自动编码器,本教程稍后会和Keras一起进行分析。这些自动编码器在训练前给数据添加一些白噪声,但在训练时会将误差与原始图像进行比较。这就使得网络不会过度拟合图像中出现的任意噪声。稍后,将使用它来清除文档扫描图像中的折痕和暗黑区域。

稀疏自动编码器

与其字义相反的是,稀疏自动编码器具有比输入或输出维度更大的潜在维度。然而,每次网络运行时,只有很小一部分神经元会触发,这意味着网络本质上是“稀疏”的。稀疏自动编码器也是通过一种规则化的形式来减少网络过度拟合的倾向,这一点与去噪自动编码器相似。

收缩自动编码器

收缩编码器与前两个自动编码器的运行过程基本相同,但是在收缩自动编码器中,我们不改变结构,只是在丢失函数中添加一个正则化器。这可以被看作是岭回归的一种神经形式。

现在了解了自动编码器是如何运行的,接下来看看自动编码器的弱项。一些最显著的挑战包括:

· 潜在空间中的间隙

· 潜在空间中的可分性

· 离散潜在空间

这些问题都在以下图中体现。

代码详解:一文读懂自动编码器的前世今生

MNIST数据集的潜在空间表示

这张图显示了潜在空间中不同标记数字的位置。可以看到潜在空间中存在间隙,我们不知道字符在这些空间中是长什么样的。这相当于在监督学习中缺乏数据,因为网络并没有针对这些潜在空间的情况进行过训练。另一个问题就是空间的可分性,上图中有几个数字被很好地分离,但也有一些区域被标签字符是随机分布的,这让我们很难区分字符的独特特征(在这个图中就是数字0-9)。还有一个问题是无法研究连续的潜在空间。例如,我们没有针对任意输入而训练的统计模型(即使我们填补了潜在空间中的所有间隙也无法做到)。

这些传统自动编码器的问题意味着我们还要做出更多努力来学习数据生成分布并生成新的数据与图像。

现在已经了解了传统自动编码器是如何运行的,接下来讨论变分自动编码器。变分自动编码器采用了一种从贝叶斯统计中提取的变分推理形式,因此会比前几种自动编码器稍微复杂一些。我们会在下一节中更深入地讨论变分自动编码器。

变分自动编码器

变分自动编码器延续了传统自动编码器的结构,并利用这一结构来学习数据生成分布,这让我们可以从潜在空间中随机抽取样本。然后,可以使用解码器网络对这些随机样本进行解码,以生成独特的图像,这些图像与网络所训练的图像具有相似的特征。

代码详解:一文读懂自动编码器的前世今生

对于熟悉贝叶斯统计的人来说,编码器正在学习后验分布的近似值。这种分布通常很难分析,因为它没有封闭式的解。这意味着我们要么执行计算上复杂的采样程序,如马尔可夫链蒙特卡罗(MCMC)算法,要么采用变分方法。正如你可能猜测的那样,变分自动编码器使用变分推理来生成其后验分布的近似值。

我们将会用适量的细节来讨论这一过程,但是如果你想了解更深入的分析,建议你阅览一下Jaan Altosaar撰写的博客。变分推理是研究生机器学习课程或统计学课程的一个主题,但是了解其基本概念并不需要拥有一个统计学学位。

若对背后的数学理论不感兴趣,也可以选择跳过这篇变分自动编码器(VAE)编码教程。

首先需要理解的是后验分布以及它无法被计算的原因。先看看下面的方程式:贝叶斯定理。这里的前提是要知道如何从潜变量“z”生成数据“x”。这意味着要搞清p(z|x)。然而,该分布值是未知的,不过这并不重要,因为贝叶斯定理可以重新表达这个概率。但是这还没有解决所有的问题,因为分母(证据)通常很难解。但也不是就此束手无辞了,还有一个挺有意思的办法可以近似这个后验分布值。那就是将这个推理问题转化为一个优化问题。

代码详解:一文读懂自动编码器的前世今生

要近似后验分布值,就必须找出一个办法来评估提议分布与真实后验分布相比是否更好。而要这么做,就需要贝叶斯统计员的最佳伙伴:KL散度。KL散度是两个概率分布相似度的度量。如果它们相等,那散度为零;而如果散度是正值,就代表这两个分布不相等。KL散度的值为非负数,但实际上它不是一个距离,因为该函数不具有对称性。可以采用下面的方式使用KL散度:

代码详解:一文读懂自动编码器的前世今生

这个方程式看起来可能有点复杂,但是概念相对简单。那就是先猜测可能生成数据的方式,并提出一系列潜在分布Q,然后再找出最佳分布q*,从将提议分布和真实分布的距离最小化,然后因其难解性将其近似。但这个公式还是有一个问题,那就是p(z|x)的未知值,所以也无法计算KL散度。那么,应该怎么解决这个问题呢?

这里就需要一些内行知识了。可以先进行一些计算上的修改并针对证据下界(ELBO)和p(x)重写KL散度:

代码详解:一文读懂自动编码器的前世今生

有趣的是ELBO是这个方程中唯一取决于所选分布的变量。而后者由于不取决于q,则不受所选分布的影响。因此,可以在上述方程中通过将ELBO(负值)最大化来使KL散度最小化。这里的重点是ELBO可以被计算,也就是说现在可以进行一个优化流程。

所以现在要做的就是给Q做一个好的选择,再微分ELBO,将其设为零,然后就大功告成了。可是开始的时候就会面临一些障碍,即必须选择最好的分布系列。

一般来说,为了简化定义q的过程,会进行平均场变分推理。每个变分参数实质上是相互独立的。因此,每个数据点都有一个单独的q,可被相称以得到一个联合概率,从而获得一个“平均场”q。

代码详解:一文读懂自动编码器的前世今生

实际上,可以选用任意多的场或者集群。比如在MINIST数据集中,可以选择10个集群,因为可能有10个数字存在。

要做的第二件事通常被称为再参数化技巧,通过把随机变量带离导数完成,因为从随机变量求导数的话会由于它的内在随机性而产生较大的误差。

代码详解:一文读懂自动编码器的前世今生

再参数化技巧较为深奥,但简单来说就是可以将一个正态分布写成均值加标准差,再乘以误差。这样在微分时,我们不是从随机变量本身求导数,而是从它的参数求得。

这个程序没有一个通用的闭型解,所以近似后验分布的能力仍然受到一定限制。然而,指数分布族确实有一个闭型解。这意味着标准分布,如正态分布、二项分布、泊松分布、贝塔分布等。所以,就算真正的后验分布值无法被查出,依然可以利用指数分布族得出最接近的近似值。

变分推理的奥秘在于选择分布区Q,使其足够大以求得后验分布的近似值,但又不需要很长时间来计算。

既然已经大致了解如何训练网络学习数据的潜在分布,那么现在可以探讨如何使用这个分布生成数据。

数据生成过程

观察下图,可以看出对数据生成过程的近似认为应生成数字‘2’,所以它从潜在变量质心生成数值2。但是也许不希望每次都生成一摸一样的数字‘2’,就好像上述端游例子所提的植物,所以我们根据一个随机数和“已学”的数值‘2’分布范围,在潜在空间给这一过程添加了一些随机噪声。该过程通过解码器网络后,我们得到了一个和原型看起来不一样的‘2’。

代码详解:一文读懂自动编码器的前世今生

这是一个非常简化的例子,抽象描述了实际自动编码器网络的体系结构。下图表示了一个真实变分自动编码器在其编码器和解码器网络使用卷积层的结构体系。从这里可以观察到,我们正在分别学习潜在空间中生成数据分布的中心和范围,然后从这些分布“抽样”生成本质上“虚假”的数据。

代码详解:一文读懂自动编码器的前世今生

该学习过程的固有性代表所有看起来很相似的参数(刺激相同的网络神经元放电)都聚集到潜在空间中,而不是随意的分散。如下图所示,可以看到数值2都聚集在一起,而数值3都逐渐地被推开。这一过程很有帮助,因为这代表网络并不会在潜在空间随意摆放字符,从而使数值之间的转换更有真实性。

代码详解:一文读懂自动编码器的前世今生

整个网络体系结构的概述如下图所示。希望读者看到这里,可以比较清晰地理解整个过程。我们使用一组图像训练自动编码器,让它学习潜在空间里均值和标准值的差,从而形成我们的数据生成分布。接下来,当我们要生成一个类似的图像,就从潜在空间的一个质心取样,利用标准差和一些随机误差对它进行轻微的改变,然后使其通过解码器网络。从这个例子可以明显看出,最终的输出看起来与输入图像相似,但却是不一样的。

代码详解:一文读懂自动编码器的前世今生

变分自动编码器编码指南

本节将讨论一个简单的去噪自动编码器,用于去除文档扫描图像上的折痕和污痕,以及去除Fashion MNIST数据集中的噪声。然后,在MNIST数据集训练网络后,就使用变分自动编码器生成新的服装。

去噪自编码器

Fashion MNIST

在第一个练习中,在Fashion MNIST数据集添加一些随机噪声(椒盐噪声),然后使用去噪自编码器尝试移除噪声。首先进行预处理:下载数据,调整数据大小,然后添加噪声。

  1. ## Download the data 
  2. (x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data() 
  3. ## normalize and reshape 
  4. x_train = x_train/255
  5. x_test = x_test/255
  6. x_train = x_train.reshape(-128281
  7. x_test = x_test.reshape(-128281
  8. # Lets add sample noise - Salt and Pepper 
  9. noise = augmenters.SaltAndPepper(0.1
  10. seq_object = augmenters.Sequential([noise]) 
  11. train_x_n = seq_object.augment_images(x_train * 255) / 255 
  12. val_x_n = seq_object.augment_images(x_test * 255) / 255 

接着,给自编码器网络创建结构。这包括多层卷积神经网络、编码器网络的最大池化层和解码器网络上的升级层。

  1. # input layer 
  2. input_layer =Input(shape=(28281)) 
  3.   
  4. # encodingarchitecture 
  5. encoded_layer1= Conv2D(64, (33), activation='relu', padding='same')(input_layer) 
  6. encoded_layer1= MaxPool2D( (22), padding='same')(encoded_layer1) 
  7. encoded_layer2= Conv2D(32, (33), activation='relu', padding='same')(encoded_layer1) 
  8. encoded_layer2= MaxPool2D( (22), padding='same')(encoded_layer2) 
  9. encoded_layer3= Conv2D(16, (33), activation='relu', padding='same')(encoded_layer2) 
  10. latent_view = MaxPool2D( (22),padding='same')(encoded_layer3) 
  11.   
  12. # decodingarchitecture 
  13. decoded_layer1= Conv2D(16, (33), activation='relu', padding='same')(latent_view) 
  14. decoded_layer1= UpSampling2D((22))(decoded_layer1) 
  15. decoded_layer2= Conv2D(32, (33), activation='relu', padding='same')(decoded_layer1) 
  16. decoded_layer2= UpSampling2D((22))(decoded_layer2) 
  17. decoded_layer3= Conv2D(64, (33), activation='relu')(decoded_layer2) 
  18. decoded_layer3= UpSampling2D((22))(decoded_layer3) 
  19. output_layer = Conv2D(1, (33), padding='same',activation='sigmoid')(decoded_layer3) 
  20.   
  21. # compile themodel 
  22. model =Model(input_layer, output_layer) 
  23. model.compile(optimizer='adam',loss='mse'
  24.   
  25. # run themodel 
  26. early_stopping= EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=5,mode='auto'
  27. history =model.fit(train_x_n, x_train, epochs=20, batch_size=2048,validation_data=(val_x_n, x_test), callbacks=[early_stopping]) 

所输入的图像,添加噪声的图像,和输出图像。

代码详解:一文读懂自动编码器的前世今生

从时尚MNIST输入的图像。

代码详解:一文读懂自动编码器的前世今生

添加椒盐噪声的输入图像。

代码详解:一文读懂自动编码器的前世今生

从去噪网络输出的图像。

从这里可以看到,我们成功从噪声图像去除相当的噪声,但同时也失去了一定量的服装细节的分辨率。这是使用稳健网络所需付出的代价之一。可以对该网络进行调优,使最终的输出更能代表所输入的图像。

文本清理

去噪自编码器的第二个例子包括清理扫描图像的折痕和暗黑区域。这是最终获得的输入和输出图像。

代码详解:一文读懂自动编码器的前世今生

输入的有噪声文本数据图像。

代码详解:一文读懂自动编码器的前世今生

经清理的文本图像。

为此进行的数据预处理稍微复杂一些,因此就不在这里进行介绍,预处理过程和相关数据可在GitHub库里获取。网络结构如下:

  1. input_layer= Input(shape=(2585401)) 
  2.   
  3. #encoder 
  4. encoder= Conv2D(64, (33), activation='relu', padding='same')(input_layer) 
  5. encoder= MaxPooling2D((22), padding='same')(encoder) 
  6.   
  7. #decoder 
  8. decoder= Conv2D(64, (33), activation='relu', padding='same')(encoder) 
  9. decoder= UpSampling2D((22))(decoder) 
  10. output_layer= Conv2D(1, (33), activation='sigmoid', padding='same')(decoder) 
  11.   
  12. ae =Model(input_layer, output_layer) 
  13.   
  14. ae.compile(loss='mse',optimizer=Adam(lr=0.001)) 
  15.   
  16. batch_size= 16 
  17. epochs= 200 
  18.   
  19. early_stopping= EarlyStopping(monitor='val_loss',min_delta=0,patience=5,verbose=1,mode='auto'
  20. history= ae.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,validation_data=(x_val, y_val), callbacks=[early_stopping]) 

变分自编码器

最后的压轴戏,是尝试从FashionMNIST数据集现有的服装中生成新图像。

其中的神经结构较为复杂,包含了一个称‘Lambda’层的采样层。

  1. batch_size = 16 
  2. latent_dim = 2 # Number of latent dimension parameters 
  3. # ENCODER ARCHITECTURE: Input -> Conv2D*4 -> Flatten -> Dense 
  4. input_img = Input(shape=(28281)) 
  5. x = Conv2D(323
  6.  padding='same',  
  7.  activation='relu')(input_img) 
  8. x = Conv2D(643
  9.  padding='same',  
  10.  activation='relu'
  11.  strides=(22))(x) 
  12. x = Conv2D(643
  13.  padding='same',  
  14.  activation='relu')(x) 
  15. x = Conv2D(643
  16.  padding='same',  
  17.  activation='relu')(x) 
  18. # need to know the shape of the network here for the decoder 
  19. shape_before_flattening = K.int_shape(x) 
  20. x = Flatten()(x) 
  21. x = Dense(32, activation='relu')(x) 
  22. # Two outputs, latent mean and (log)variance 
  23. z_mu = Dense(latent_dim)(x) 
  24. z_log_sigma = Dense(latent_dim)(x) 
  25. ## SAMPLING FUNCTION 
  26. def sampling(args): 
  27.  z_mu, z_log_sigma = args epsilon = K.random_normal(shape=(K.shape(z_mu)[0], latent_dim), 
  28.  mean=0., stddev=1.) 
  29.  return z_mu + K.exp(z_log_sigma) * epsilon 
  30. # sample vector from the latent distribution 
  31. z = Lambda(sampling)([z_mu, z_log_sigma]) 
  32. ## DECODER ARCHITECTURE 
  33. # decoder takes the latent distribution sample as input 
  34. decoder_input = Input(K.int_shape(z)[1:]) 
  35. # Expand to 784 total pixels 
  36. x = Dense(np.prod(shape_before_flattening[1:]), 
  37.  activation='relu')(decoder_input) 
  38. # reshape 
  39. x = Reshape(shape_before_flattening[1:])(x) 
  40. # use Conv2DTranspose to reverse the conv layers from the encoder 
  41. x = Conv2DTranspose(323
  42.  padding='same',  
  43.  activation='relu'
  44.  strides=(22))(x) 
  45. x = Conv2D(13
  46.  padding='same',  
  47.  activation='sigmoid')(x) 
  48. # decoder model statement 
  49. decoder = Model(decoder_input, x) 
  50. # apply the decoder to the sample from the latent distribution 
  51. z_decoded = decoder(z) 

这就是体系结构,但还是需要插入损失函数再合并KL散度。

  1. # construct a custom layer to calculate the loss 
  2. class CustomVariationalLayer(Layer): 
  3.  def vae_loss(self, x, z_decoded): 
  4.  x = K.flatten(x) 
  5.  z_decoded = K.flatten(z_decoded) 
  6.  # Reconstruction loss 
  7.  xent_loss = binary_crossentropy(x, z_decoded) 
  8.  # KL divergence 
  9.  kl_loss = -5e-4 * K.mean(1 + z_log_sigma - K.square(z_mu) - K.exp(z_log_sigma), axis=-1
  10.  return K.mean(xent_loss + kl_loss) 
  11.  # adds the custom loss to the class 
  12.  def call(self, inputs): 
  13.  x = inputs[0
  14.  z_decoded = inputs[1
  15.  loss = self.vae_loss(x, z_decoded) 
  16.  self.add_loss(loss, inputs=inputs) 
  17.  return x 
  18. # apply the custom loss to the input images and the decoded latent distribution sample 
  19. y = CustomVariationalLayer()([input_img, z_decoded]) 
  20. # VAE model statement 
  21. vae = Model(input_img, y) 
  22. vae.compile(optimizer='rmsprop', loss=None) 
  23. vae.fit(x=train_x, y=None, 
  24.  shuffle=True, 
  25.  epochs=20
  26.  batch_size=batch_size, 
  27.  validation_data=(val_x, None)) 

现在,可以查看重构的样本,看看网络能够学习到什么。

代码详解:一文读懂自动编码器的前世今生

从这里可以清楚看到鞋子、手袋和服装之间的过渡。在此并没有标出所有使画面更清晰的潜在空间。也可以观察到Fashion MNIST数据集现有的10件服装的潜在空间和颜色代码。

代码详解:一文读懂自动编码器的前世今生

可看出这些服饰分成了不同的集群。

【编辑推荐】

  1. 还为模拟流量测试发愁吗?!滴滴开源RDebug流量回放工具!
  2. 开发 7 年,我学到了什么?
  3. 12位中年程序员:代码一敲十年,收入虽高前途摇摆
  4. 26个适用于VMware管理员的强大工具
  5. GitHub上的开源代码到底受不受美国出口管制?
【责任编辑:张燕妮 TEL:(010)68476606】

点赞 0
分享:
大家都在看
猜你喜欢

订阅专栏+更多

Spring Boot 爬虫搜索轻松游

Spring Boot 爬虫搜索轻松游

全栈式开发之旅
共4章 | 美码师

28人订阅学习

Linux性能调优攻略

Linux性能调优攻略

性能调优规范
共15章 | 南非蚂蚁

134人订阅学习

VMware vSphere虚拟化常见故障

VMware vSphere虚拟化常见故障

搞定vSphere虚拟化
共18章 | 王春海

55人订阅学习

读 书 +更多

SOA 原理•方法•实践

本书并不是关于Web服务的又一本开发手册,抑或是开发技术的宝典之类的读物。本书的作者来自于IBM软件开发中心的SOA技术中心,作为最早的一...

订阅51CTO邮刊

点击这里查看样刊

订阅51CTO邮刊

51CTO服务号

51CTO播客