对抗生成网络GAN系列——EGBAD原理及缺陷检测实战
本文为稀土掘金技术社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究!
??作者简介:秃头小苏,致力于用最通俗的语言描述问题
??往期回顾:对抗生成网络GAN系列——GAN原理及手写数字生成小案例??对抗生成网络GAN系列——DCGAN简介及人脸图像生成案例??对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战
??近期目标:写好专栏的每一篇文章
??支持小苏:点赞????、收藏?、留言??
对抗生成网络GAN系列——EGBAD原理及缺陷检测实战写在前面??在上一篇,我为大家介绍了首次应用在缺陷检测中的GAN网络——ANoGAN。在文末总结了AnoGAN一个显而易见的劣势,即在测试阶段需要花费大量时间来搜索潜在变量z,这在很多应用场景中是难以接受的。本文针对上述所说缺点,介绍一种新的GAN网络——EGBAD,其在训练过程中通过一个巧妙的编码器实现对z的搜索,这样在测试过程中就可以节约大量时间。??????
??阅读本文之前,建议先对AnoGAN有一定了解,可参考下文:
[1]对抗生成网络GAN系列——AnoGAN原理及缺陷检测实战????????如果你准备就绪的话,就让我们一起来学学AnoGAN的改进版EGBAD吧!!!??????
EGBAD原理详解?????一直在说EGBAD,大家肯定一脸懵,到底什么才是EGBAD了?我们先来看看它的英文全称,即EFFICIENT GAN-BASED ANOMALY DETECTION,中文译为基于GAN的高效异常检测。通过说明EGBAD的字面含义,相信大家知道了EGBAD是用来干什么的了。没错,它也是用于缺陷检测的网络,是对AnoGAN的优化。至于具体是怎么优化的,且听下文分解。??????
??我们先来回顾一下AnoGAN是怎么设计的?AnoGAN分为训练和测试两个阶段进行,训练阶段使用正常数据训练一个DCGAN网络,在测试阶段,固定训练阶段的网络权重,不断更新潜在变量z,使得由z生成的假图像尽可能接近真实图片。【如果你对这个过程不熟悉的话,建议看看[1]中内容喔】在介绍EGBAD是怎么设计的前,我们先来看看EGBAD主要解决了AnoGAN什么问题?其实这点我在写在前面已经提及,AnoGAN在测试阶段要不断搜索潜在变量z,这消耗了大量时间,EGBAD的提出就是为了解决AnoGAN时间消耗大的问题。接着我们来就来看看EGBAD具体是怎么做的呢?EGBAD也分为训练和测试两个阶段进行。在训练阶段,不仅要训练生成器和判别器,还会定义一个编码器(encoder)结构并对其训练,encoder主要用于将输入图像通过网络转变成一个潜在变量。在测试阶段,冻结训练阶段的所以权重,之后通过encoder将输入图像变为潜在变量,最后在将潜在变量送入生成器,生成假图像。可以发现EGBAD没有在测试阶段搜索潜在变量,而是直接通过一个encoder结构将输入图像转变成潜在变量,这大大节省了时间成本。
??关于EGBAD训练过程模型示意图如下:【测试过程很简单啦,就不介绍了】
??可以看出判别器的输入有两个,一个是生成器生成的假图像x′,另一个是编码器生成的z′。具体生成器、编码器和判别器的结构如何,将在下章代码实战中介绍。??????
EGBAD代码实战代码下载地址?????同样,我将此部分的源码上传到Github上了,大家可以阅读README文件了解代码的使用,Github地址如下:
EGBAD-pytorch实现
??我认为你阅读README文件后已经对这个项目的结构有所了解,我在下文也会帮大家分析分析源码,但更多的时间大家应该自己动手去亲自调试,这样你会有不一样的收获。??????
数据读取??这部分和AnoGAN中完全一致,就不带大家一行行看调试结果了,不明白的可以阅读AnoGAN教程,这里直接上代码:
#导入相关包
importnumpyasnp
importpandasaspd
"""
mnist数据集读取
"""
##读取训练集数据(60000,785)
train=pd.read_csv(".\data\mnist_train.csv",dtype=np.float32)
##读取测试集数据(10000,785)
test=pd.read_csv(".\data\mnist_test.csv",dtype=np.float32)
#查询训练数据中标签为7、8的数据,并取前400个
train=train.query("labelin[7.0,8.0]").head(400)
#查询训练数据中标签为7、8的数据,并取前400个
test=test.query("labelin[2.0,7.0,8.0]").head(600)
#取除标签后的784列数据
train=train.iloc[:,1:].values.astype('float32')
test=test.iloc[:,1:].values.astype('float32')
#train:(400,784)--(400,28,28)
#test:(600,784)--(600,28,28)
train=train.reshape(train.shape[0],28,28)
test=test.reshape(test.shape[0],28,28)
模型搭建??这部分大家就潜心修行,慢慢调试代码吧,我也会给出每个模型的结构图辅助大家,就让我们一起来看看吧???
生成模型搭建"""定义生成器网络结构"""
classGenerator(nn.Module):
def__init__(self):
super(Generator,self).__init__()
defCBA(in_channel,out_channel,kernel_size=4,stride=2,padding=1,activation=nn.ReLU(inplace=True),bn=True):
seq=[]
seq+=[nn.ConvTranspose2d(in_channel,out_channel,kernel_size=kernel_size,stride=stride,padding=padding)]
ifbnisTrue:
seq+=[nn.BatchNorm2d(out_channel)]
seq+=[activation]
returnnn.Sequential(*seq)
seq=[]
seq+=[CBA(20,64*8,stride=1,padding=0)]
seq+=[CBA(64*8,64*4)]
seq+=[CBA(64*4,64*2)]
seq+=[CBA(64*2,64)]
seq+=[CBA(64,1,activation=nn.Tanh(),bn=False)]
self.generator_network=nn.Sequential(*seq)
defforward(self,z):
out=self.generator_network(z)
returnout
??生成模型的搭建其实很AnoGAN是完全一样的,我也给出生成网络的结构图,如下:
编码器模型搭建"""定义编码器结构"""
classencoder(nn.Module):
def__init__(self):
super(encoder,self).__init__()
defCBA(in_channel,out_channel,kernel_size=4,stride=2,padding=1,activation=nn.LeakyReLU(0.1,inplace=True)):
seq=[]
seq+=[nn.Conv2d(in_channel,out_channel,kernel_size=kernel_size,stride=stride,padding=padding)]
seq+=[nn.BatchNorm2d(out_channel)]
seq+=[activation]
returnnn.Sequential(*seq)
seq=[]
seq+=[CBA(1,64)]
seq+=[CBA(64,64*2)]
seq+=[CBA(64*2,64*4)]
seq+=[CBA(64*4,64*8)]
seq+=[nn.Conv2d(64*8,512,kernel_size=4,stride=1)]
self.feature_network=nn.Sequential(*seq)
self.embedding_network=nn.Linear(512,20)
defforward(self,x):
feature=self.feature_network(x).view(-1,512)
z=self.embedding_network(feature)
returnz
??这部分其实也很简单,就是一系列卷积的堆积,编码器的结构图如下:
判别模型搭建"""定义判别器网络结构"""
classDiscriminator(nn.Module):
def__init__(self):
super(Discriminator,self).__init__()
defCBA(in_channel,out_channel,kernel_size=4,stride=2,padding=1,activation=nn.LeakyReLU(0.1,inplace=True)):
seq=[]
seq+=[nn.Conv2d(in_channel,out_channel,kernel_size=kernel_size,stride=stride,padding=padding)]
seq+=[nn.BatchNorm2d(out_channel)]
seq+=[activation]
returnnn.Sequential(*seq)
seq=[]
seq+=[CBA(1,64)]
seq+=[CBA(64,64*2)]
seq+=[CBA(64*2,64*4)]
seq+=[CBA(64*4,64*8)]
seq+=[nn.Conv2d(64*8,512,kernel_size=4,stride=1)]
self.feature_network=nn.Sequential(*seq)
seq=[]
seq+=[nn.Linear(20,512)]
seq+=[nn.BatchNorm1d(512)]
seq+=[nn.LeakyReLU(0.1,inplace=True)]
self.latent_network=nn.Sequential(*seq)
self.critic_network=nn.Linear(1024,1)
defforward(self,x,z):
feature=self.feature_network(x)
feature=feature.view(feature.size(0),-1)
latent=self.latent_network(z)
out=self.critic_network(torch.cat([feature,latent],dim=1))
returnout,feature
??虽然判别器有两个输入,两个输出,但是结构也非常清晰,如下图所示:
??在模型搭建部分我还想提一点我们需要注意的地方,一般我们设计好一个网络结构后,我们往往会先设计一个tensor来作为网络的输入,看看网络输出是否是是我们预期的,如果是,我们再进行下一步,否则我们需要调整我们的结构以适应我们的输入。通常情况下,tensor的batch维度设为1就行,但是这里设置成1就会报错,提示我们需要设置一个batch大于1的整数,当将batch设置为2时,程序正常,至于产生这种现象的原因我目前也不是很清楚,大家注意一下,知道的也烦请告知一下。关于调试网络结构是否正常的代码如下,仅供参考:
if__name__=='__main__':
x=torch.ones((2,1,64,64))
z=torch.ones((2,20,1,1))
Generator=Generator()
Discriminator=Discriminator()
encoder=encoder()
output_G=Generator(z)
output_D1,output_D2=Discriminator(x,z.view(2,-1))
output_E=encoder(x)
print(output_G.shape)
print(output_D1.shape)
print(output_D2.shape)
print(output_E.shape)
模型训练数据集加载??这部分和AnoGAN一致,注意最终输入网络的图片尺寸都上采样成了64×64.
classimage_data_set(Dataset):
def__init__(self,data):
self.images=data[:,:,:,None]
self.transform=transforms.Compose([
transforms.ToTensor(),
transforms.Resize(64,interpolation=InterpolationMode.BICUBIC),
transforms.Normalize((0.1307,),(0.3081,))
])
def__len__(self):
returnlen(self.images)
def__getitem__(self,idx):
returnself.transform(self.images[idx])
#加载训练数据
train_set=image_data_set(train)
train_loader=DataLoader(train_set,batch_size=batch_size,shuffle=True)
加载模型、定义优化器、损失函数等参数??这部分也基本和AnoGAN类似,只不过添加了encoder网络的定义和优化器定义部分,如下:
#指定设备
device=torch.device(args.deviceiftorch.cuda.is_available()else"cpu")
#batch_size默认128
batch_size=args.batch_size
#加载模型
G=Generator().to(device)
D=Discriminator().to(device)
E=Encoder().to(device)
#训练模式
G.train()
D.train()
E.train()
#设置优化器
optimizerG=torch.optim.Adam(G.parameters(),lr=0.0001,betas=(0.0,0.9))
optimizerD=torch.optim.Adam(D.parameters(),lr=0.0001,betas=(0.0,0.9))
optimizerE=torch.optim.Adam(E.parameters(),lr=0.0004,betas=(0.0,0.9))
#定义损失函数
criterion=nn.BCEWithLogitsLoss(reduction='mean')
训练GAN网络"""
训练
"""
#开始训练
forepochinrange(args.epochs):
#定义初始损失
log_g_loss,log_d_loss,log_e_loss=0.0,0.0,0.0
forimagesintrain_loader:
images=images.to(device)
##训练判别器Discriminator
#定义真标签(全1)和假标签(全0)维度:(batch_size)
label_real=torch.full((images.size(0),),1.0).to(device)
label_fake=torch.full((images.size(0),),0.0).to(device)
#定义潜在变量z维度:(batch_size,20,1,1)
z=torch.randn(images.size(0),20).to(device).view(images.size(0),20,1,1).to(device)
#潜在变量喂入生成网络---fake_images:(batch_size,1,64,64)
fake_images=G(z)
#使用编码器将真实图像变成潜在变量image:(batch_size,1,64,64)--z_real:(batch_size,20)
z_real=E(images)
#真图像和假图像送入判别网络,得到d_out_real、d_out_fake维度:都为(batch_size,1)
d_out_real,_=D(images,z_real)
d_out_fake,_=D(fake_images,z.view(images.size(0),20))
#损失计算
d_loss_real=criterion(d_out_real.view(-1),label_real)
d_loss_fake=criterion(d_out_fake.view(-1),label_fake)
d_loss=d_loss_real+d_loss_fake
#误差反向传播,更新损失
optimizerD.zero_grad()
d_loss.backward()
optimizerD.step()
##训练生成器Generator
#定义潜在变量z维度:(batch_size,20,1,1)
z=torch.randn(images.size(0),20).to(device).view(images.size(0),20,1,1).to(device)
fake_images=G(z)
#假图像喂入判别器,得到d_out_fake维度:(batch_size,1)
d_out_fake,_=D(fake_images,z.view(images.size(0),20))
#损失计算
g_loss=criterion(d_out_fake.view(-1),label_real)
#误差反向传播,更新损失
optimizerG.zero_grad()
g_loss.backward()
optimizerG.step()
##训练编码器Encode
#使用编码器将真实图像变成潜在变量image:(batch_size,1,64,64)--z_real:(batch_size,20)
z_real=E(images)
#真图像送入判别器,记录结果d_out_real:(128,1)
d_out_real,_=D(images,z_real)
#损失计算
e_loss=criterion(d_out_real.view(-1),label_fake)
#误差反向传播,更新损失
optimizerE.zero_grad()
e_loss.backward()
optimizerE.step()
##累计一个epoch的损失,判别器损失、生成器损失、编码器损失分别存放到log_d_loss、log_g_loss、log_e_loss中
log_d_loss+=d_loss.item()
log_g_loss+=g_loss.item()
log_e_loss+=e_loss.item()
##打印损失
print(f'epoch{epoch},D_Loss:{log_d_loss/128:.4f},G_Loss:{log_g_loss/128:.4f},E_Loss:{log_e_loss/128:.4f}')
??这里总结一下上述训练的步骤,不断循环下列过程:
使用生成器从潜在变量z中创建假图像使用编码器从真实图像中创建潜在变量生成器和编码器结果送入判别器,进行训练使用生成器从潜在变量z中创建假图像训练生成器使用编码器从真实图像中创建潜在变量训练编码器关于第3步,我也简单画了个图帮大家理解下,如下:
??最后我们来展示一下生成图片的效果,如下图所示:
image-20220922173918214
缺陷检测??EGBAD缺陷检测非常简单,首先定义一个就算损失的函数,如下:
##定义缺陷计算的得分
defanomaly_score(input_image,fake_image,z_real,D):
#Residualloss计算
residual_loss=torch.sum(torch.abs(input_image-fake_image),(1,2,3))
#Discriminationloss计算
_,real_feature=D(input_image,z_real)
_,fake_feature=D(fake_image,z_real)
discrimination_loss=torch.sum(torch.abs(real_feature-fake_feature),(1))
#结合Residualloss和Discriminationloss计算每张图像的损失
total_loss_by_image=0.9*residual_loss+0.1*discrimination_loss
returntotal_loss_by_image
??接着我们只需要用Encoder网络生成潜在变量,在再用生成器即可得到假图像,最后计算假图像和真图像的损失即可,如下:
#加载测试数据
test_set=image_data_set(test)
test_loader=DataLoader(test_set,batch_size=5,shuffle=False)
input_images=next(iter(test_loader)).to(device)
#通过编码器获取潜在变量,并用生成器生成假图像
z_real=E(input_images)
fake_images=G(z_real.view(input_images.size(0),20,1,1))
#异常计算
anomality=anomaly_score(input_images,fake_images,z_real,D)
print(anomality.cpu().detach().numpy())
??最后可以保存一下真实图像和假图像的结果,如下:
torchvision.utils.save_image(input_images,f"result/Nomal.jpg")
torchvision.utils.save_image(fake_images,f"result/ANomal.jpg")
??我们来看一下结果:
??通过上图你发现了什么呢?是不是发现输入图像为7的图片的生成图像不是7而变成了8呢,究其原因,应该是生成器学到了更多关于数据8的特征,也就是说这个网络的生成效果并没有很好。
??我做了很多实验,发现EGBAD虽然测试时间上比AnoGAN快很多,但是它的稳定性似乎并没有很理想,很容易出现模式崩溃的问题。其实啊,GAN网络普遍存在着训练不稳定的现象,这也是一些大牛不断探索的方向,后面的文章我也会给大家介绍一些增加GAN训练稳定性的文章,敬请期待吧!??????
AnoGAN和EGBAD测试时间对比?????我们一直说EGBAD的测试时间相较AnoGAN短,从原理上来说确实是这样,但是具体是不是这样我们还要以实验为准。测试代码也很简单,只需要在测试过程中使用time.time()函数即可,具体可以参考我上传github中的源码,这里给出我测试两种网络在测试阶段所用时间(以秒为单位),如下图所示:
??通过上图数据可以看出,EGBAD比AnoGAN快的不是一点点,EGBAD的速度将近是AnoGAN的10000倍,这个数字还是很恐怖的。??????
总结??到此,EGBAD的全部内容就为大家介绍完了,如果你明白了AnoGAN的话,这篇文章对你来说应该是小菜一碟了。EGBAD大大的减少了测试所有时间,但是GAN网络普遍存在易模式崩溃、训练不稳定的现象,下一篇博文我将为大家介绍一些让GAN训练更稳定的技巧,敬请期待吧。??????
参考链接EFFICIENT GAN-BASED ANOMALY DETECTION??????
GAN 使用 Pytorch 进行异常检测的方法??????
如若文章对你有所帮助,那就??????
咻咻咻咻~~duang~~点个赞呗
阅读原文
网站开发网络凭借多年的网站建设经验,坚持以“帮助中小企业实现网络营销化”为宗旨,累计为4000多家客户提供品质建站服务,得到了客户的一致好评。如果您有网站建设、网站改版、域名注册、主机空间、手机网站建设、网站备案等方面的需求...
请立即点击咨询我们或拨打咨询热线:13245491521 13245491521 ,我们会详细为你一一解答你心中的疑难。 项目经理在线