使用PyTorch實作Gradient Reversal Layer

Yanwei Liu
Dec 11, 2021

--

在採用對抗學習方法的Domain Adaptation程式碼當中,大多數都會使用Gradient Reversal的方式來進行反向傳播。

只不過,舊版PyTorch(如:0.3或0.4)寫法與現在新版(1.3之後)無法相容,會出現RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method.的錯誤,因此需要一些調整。詳細可參考下方的程式碼。

舊版PyTorch實作方式

class GradReverse(torch.autograd.Function):
def __init__(self, lambd):
self.lambd = lambd
def forward(self, x):
return x.view_as(x)
def backward(self, grad_output):
return (grad_output * -self.lambd)
def grad_reverse(x, lambd=1.0):
return GradReverse(lambd)(x)

新版本PyTorch實作方式

class GradReverse(torch.autograd.Function):
def __init__(self):
super(GradReverse, self).__init__()
@ staticmethod
def forward(ctx, x, lambda_):
ctx.save_for_backward(lambda_)
return x.view_as(x)
@ staticmethod
def backward(ctx, grad_output):
lambda_, = ctx.saved_variables
grad_input = grad_output.clone()
return - lambda_ * grad_input, None
def grad_reverse(x, lambd=1.0):
lam = torch.tensor(lambd)
return GradReverse.apply(x,lam)

參考資料

Gradient Reversal Layer指什么? — 知乎 (zhihu.com)

Open Set Domain Adaptation by Backpropagation(OSBP)论文数字数据集复现_且慢-CSDN博客

pytorch 实现Gradient Flipping 各种坑 — 知乎 (zhihu.com)

--

--

No responses yet