任务要求:
自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:
import torch
from torch.autograd import Function
from torch.autograd import Variable
定义二值化函数
class BinarizedF(Function):
def forward(self, input):
self.save_for_backward(input)
a = torch.ones_like(input)
b = -torch.ones_like(input)
output = torch.where(input>=0,a,b)
return output
def backward(self, output_grad):
input, = self.saved_tensors
input_abs = torch.abs(input)
ones = torch.ones_like(input)
zeros = torch.zeros_like(input)
input_grad = torch.where(input_abs<=1,ones, zeros)
return input_grad
定义一个module
class BinarizedModule(nn.Module):
def __init__(self):
super(BinarizedModule, self).__init__()
self.BF = BinarizedF()
def forward(self,input):
print(input.shape)
output =self.BF(input)
return output
进行测试
a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)
其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s
class BinarizedF(Function):
def forward(self, input):
self.save_for_backward(input)
output = torch.ones_like(input)
output[input<0] = -1
return output
def backward(self, output_grad):
input, = self.saved_tensors
input_grad = output_grad.clone()
input_abs = torch.abs(input)
input_grad[input_abs>1] = 0
return input_grad
以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持亿速云。
亿速云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。