帝王谷资源网 Design By www.wdxyy.com
pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。
那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`
import torch import numpy as np from PIL import Image from torch.autograd import gradcheck class Bicubic(torch.autograd.Function): def basis_function(self, x, a=-1): x_abs = np.abs(x) if x_abs < 1 and x_abs >= 0: y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1 elif x_abs > 1 and x_abs < 2: y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a else: y = 0 return y def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'): # data_in = data_in.detach().numpy() self.grad = np.zeros(data_in.shape,dtype=np.float32) obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2]) data_tmp = data_in.copy() data_obj = np.zeros(shape=obj_shape, dtype=np.float32) data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode) print(data_tmp.shape) for axis0 in range(obj_shape[0]): f_0 = float(axis0) / scale - np.floor(axis0 / scale) int_0 = int(axis0 / scale) + 2 axis0_weight = np.array( [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]]) for axis1 in range(obj_shape[1]): f_1 = float(axis1) / scale - np.floor(axis1 / scale) int_1 = int(axis1 / scale) + 2 axis1_weight = np.array( [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]]) nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32) grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight) for i in range(4): for j in range(4): nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :] for ii in range(data_in.shape[2]): self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j] tmp = np.matmul(axis0_weight, nbr_pixel) data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0] # img = np.transpose(img[0, :, :, :], [1, 2, 0]) return data_obj def forward(self,input): print(type(input)) input_ = input.detach().numpy() output = self.bicubic_interpolate(input_) # return input.new(output) return torch.Tensor(output) def backward(self,grad_output): print(self.grad.shape,grad_output.shape) grad_output.detach().numpy() grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32) for i in range(self.grad.shape[0]): for j in range(self.grad.shape[1]): grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:] grad_input = grad_output_tmp*self.grad print(type(grad_input)) # return grad_output.new(grad_input) return torch.Tensor(grad_input) def bicubic(input): return Bicubic()(input) def main(): hr = Image.open('./baboon/baboon_hr.png').convert('L') hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2)) hr.requires_grad = True lr = bicubic(hr) print(lr.is_leaf) loss=torch.mean(lr) loss.backward() if __name__ =='__main__': main()
要想实现自动求导,必须同时实现forward(),backward()两个函数。
1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。
2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数
以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
帝王谷资源网 Design By www.wdxyy.com
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
帝王谷资源网 Design By www.wdxyy.com
暂无评论...
更新日志
2025年01月08日
2025年01月08日
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]