# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    """Return an FGSM-perturbed copy of *image*.

    Each pixel is stepped by ``epsilon`` in the direction of the sign of
    the loss gradient, then clipped back into the valid [0, 1] range.
    """
    # Direction of steepest loss increase, element-wise.
    grad_sign = data_grad.sign()
    # Nudge every pixel by epsilon along that direction.
    adv_image = image + epsilon * grad_sign
    # Clip so the result remains a valid image in [0, 1].
    return torch.clamp(adv_image, 0, 1)
# Evaluate robustness: attack every correctly-classified test example
# with FGSM and count how many the model still gets right afterwards.
# NOTE(review): the .item() calls assume test_loader uses batch_size=1 —
# confirm against the DataLoader construction.
for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    # We need gradients w.r.t. the *input* pixels, not just the weights.
    data.requires_grad = True

    # Forward pass the data through the model.
    output = model(data)
    # Get the index of the max log-probability.
    init_pred = output.max(1, keepdim=True)[1]

    # If the initial prediction is wrong, don't bother attacking —
    # the metric only tracks examples the model originally got right.
    if init_pred.item() != target.item():
        continue

    # Calculate the loss and backpropagate to populate data.grad.
    loss = F.nll_loss(output, target)
    # Zero all existing gradients first.
    model.zero_grad()
    loss.backward()

    # Fix: read data.grad directly; the legacy .data attribute bypasses
    # autograd bookkeeping and is discouraged in modern PyTorch.
    data_grad = data.grad

    # Build the adversarial example and re-classify it.
    perturbed_data = fgsm_attack(data, epsilon, data_grad)
    output = model(perturbed_data)

    # Attack failed if the perturbed prediction still matches the label.
    final_pred = output.max(1, keepdim=True)[1]
    if final_pred.item() == target.item():
        correct += 1