Implementing a Gradient Reversal Layer in PyTorch
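Conceptually, a Gradient Reversal Layer is a pseudo-function $R_\lambda$. In the forward pass it is the identity. In the backward pass it multiplies the incoming gradient by $-\lambda$. Both implementations below encode this with `x.view_as(x)` and a negated, scaled `grad_output`. Here $L$ is any downstream loss, and $\lambda$ corresponds to `lambd` / `lambda_` in the code:

$$
R_\lambda(\mathbf{x}) = \mathbf{x},
\qquad
\frac{\partial L}{\partial \mathbf{x}} = -\lambda \, \frac{\partial L}{\partial R_\lambda(\mathbf{x})}
$$

With $\lambda = 1$, the gradient is simply negated.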

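Legacy (old-style) version. Here `forward` and `backward` are instance methods, and the layer is called as an object. Recent PyTorch releases reject this style of `autograd.Function` with an error, so it only runs on older versions:

```python
import torch

# Legacy autograd.Function API (non-static forward/backward); rejected by recent PyTorch.
class GradReverse(torch.autograd.Function):
    def __init__(self, lambd):
        self.lambd = lambd

    def forward(self, x):
        # Identity in the forward pass
        return x.view_as(x)

    def backward(self, grad_output):
        # Reverse and scale the gradient in the backward pass
        return grad_output * -self.lambd

def grad_reverse(x, lambd=1.0):
    return GradReverse(lambd)(x)
```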
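Because the instance-based `Function` above no longer runs on current PyTorch, here is a hedged alternative that avoids a custom `Function` altogether. It registers a tensor hook with `Tensor.register_hook` that rewrites the gradient. This is a different technique from the author's. The helper name `grad_reverse_hook` is hypothetical, and the sketch assumes the input may or may not require gradients.

```python
import torch


def grad_reverse_hook(x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: identity forward, gradient multiplied by -lambd backward."""
    y = x.clone()  # separate graph node so the hook only affects this branch
    if y.requires_grad:
        # register_hook raises if the tensor does not require grad, hence the guard;
        # the hook's return value replaces the gradient flowing back through y
        y.register_hook(lambda grad: -lambd * grad)
    return y


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = grad_reverse_hook(x, lambd=0.5)
out.sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000])
```

The `Function`-based version below is usually preferred because it packages the behavior as a reusable autograd op. The hook version is handy for a quick one-off experiment.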
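New-style version. It uses `@staticmethod` and a `ctx` object, which is the form current PyTorch expects:

```python
import torch


class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        # Save lambda for the backward pass; the forward pass is the identity
        ctx.save_for_backward(lambda_)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors replaces the deprecated saved_variables
        lambda_, = ctx.saved_tensors
        # x receives the reversed, scaled gradient; lambda_ receives no gradient
        return -lambda_ * grad_output, None


def grad_reverse(x, lambd=1.0):
    lam = torch.tensor(lambd)
    return GradReverse.apply(x, lam)
```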
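Below is a self-contained sketch that checks the sign flip and shows where a GRL typically sits in domain-adversarial training: between a feature extractor and a domain classifier. `DomainClassifier`, the layer sizes, and the toy data are hypothetical. Unlike the snippet above, this variant stores `lambd` as a plain Python float on `ctx` instead of saving a tensor. That design choice of the sketch avoids creating a tensor per call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Identity forward; gradient multiplied by -lambd backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd  # plain float, so no save_for_backward needed
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: x gets -lambd * grad, lambd gets None
        return -ctx.lambd * grad_output, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)


class DomainClassifier(nn.Module):
    """Hypothetical domain-classifier head placed behind the GRL."""

    def __init__(self, in_features: int, lambd: float = 1.0):
        super().__init__()
        self.lambd = lambd
        self.net = nn.Sequential(
            nn.Linear(in_features, 16),
            nn.ReLU(),
            nn.Linear(16, 2),  # two classes: source vs. target domain
        )

    def forward(self, features):
        return self.net(grad_reverse(features, self.lambd))


# 1) Gradient check: forward is the identity, backward is scaled by -lambd
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = grad_reverse(x, 0.5)
y.sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000])

# 2) Toy usage: features from a hypothetical extractor feed the domain head
torch.manual_seed(0)
extractor = nn.Linear(8, 4)
head = DomainClassifier(in_features=4, lambd=1.0)
inputs = torch.randn(5, 8)
domain_labels = torch.randint(0, 2, (5,))
loss = F.cross_entropy(head(extractor(inputs)), domain_labels)
loss.backward()  # head is trained normally; extractor receives the reversed gradient
```

The domain classifier learns to tell the domains apart. The reversed gradient simultaneously pushes the extractor toward features that make the domains harder to distinguish.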


Machine Learning | Deep Learning | https://linktr.ee/yanwei

Yanwei Liu