GAN
WGAN
WGAN-GP
Lipschitz continuous function
1-Lipschitz continuity
mean value theorem
if is differentiable function then
gradient penalty
by let close to 1 the function can satisfy 1-Lipschitz requires
# https://github.com/Lornatang/WassersteinGAN_GP-PyTorch/blob/f2e2659089a4fe4cb7e1c4edeb5c5b9912e9c348/wgangp_pytorch/utils.py#L39
def calculate_gradient_penalty(model, real_images, fake_images, device,use_refiner):
"""Calculates the gradient penalty loss for WGAN GP"""
# Random weight term for interpolation between real and fake data
alpha = torch.randn((real_images.size(0), 1, 1 , 1), device=device)
# Get random interpolation between real and fake data
interpolates = (alpha * real_images + ((1 - alpha) * fake_images)).requires_grad_(True)
_, interpolates_real = model(interpolates,return_feature=False,use_refiner=use_refiner)
grad_outputs = torch.ones_like(interpolates_real, device=device)
# Get gradient w.r.t. interpolates
gradients = torch.autograd.grad(
outputs=interpolates_real,
inputs=interpolates,
grad_outputs=grad_outputs,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.reshape(gradients.size(0), -1)
gradient_penalty = torch.mean((gradients.norm(2, dim=1) - 1) ** 2)
return gradient_penalty