博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
『PyTorch』第十二弹_nn.Module和nn.functional
阅读量:6640 次
发布时间:2019-06-25

本文共 2010 字,大约阅读时间需要 6 分钟。

大部分nn中的层class都有nn.function对应,其区别是:

  • nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Parameter
  • nn.functional中的函数更像是纯函数,由def function(input)定义。

由于两者性能差异不大,所以具体使用取决于个人喜好。对于激活函数和池化层,由于没有可学习参数,一般使用nn.functional完成,其他的有学习参数的部分则使用类。但是Droupout由于在训练和测试时操作不同,所以建议使用nn.Module实现,它能够通过model.eval加以区分。

一、nn.functional函数基本使用

import torch as timport torch.nn as nnfrom torch.autograd import Variable as Vinput_ = V(t.randn(2, 3))model = nn.Linear(3, 4)output1 = model(input_)output2 = nn.functional.linear(input_, model.weight, model.bias)print(output1 == output2)b1 = nn.functional.relu(input_)b2 = nn.ReLU()(input_)print(b1 == b2)

 

二、搭配使用nn.Module和nn.functional

并不是什么难事,之前有接触过,nn.functional不需要放入__init__进行构造,所以不具有可学习参数的部分可以使用nn.functional进行代替。

 

# Author : Hellcat# Time   : 2018/2/11 import torch as timport torch.nn as nnimport torch.nn.functional as F class LeNet(nn.Module):    def __init__(self):        super(LeNet,self).__init__()        self.conv1 = nn.Conv2d(3, 6, 5)        self.conv2 = nn.Conv2d(6,16,5)        self.fc1 = nn.Linear(16*5*5,120)        self.fc2 = nn.Linear(120,84)        self.fc3 = nn.Linear(84,10)     def forward(self,x):        x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))        x = F.max_pool2d(F.relu(self.conv2(x)),2)        x = x.view(x.size()[0], -1)        x = F.relu(self.fc1(x))        x = F.relu(self.fc2(x))        x = self.fc3(x)        return x

 

三、nn.functional函数构造nn.Module类

两者主要的区别就是对于可学习参数nn.Parameter的识别能力,所以构造时添加了识别能力即可。

class Linear(nn.Module):    def __init__(self, in_features, out_features):        # nn.Module.__init__(self)        super(Linear, self).__init__()        self.w = nn.Parameter(t.randn(out_features, in_features))  # nn.Parameter是特殊Variable        self.b = nn.Parameter(t.randn(out_features))    def forward(self, x):        # wx+b        return F.linear(x, self.w, self.b)layer = Linear(4, 3)input = V(t.randn(2, 4))output = layer(input)print(output)

Variable containing:

 1.7498 -0.8839  0.5314
-2.4863 -0.6442  1.1036
[torch.FloatTensor of size 2x3]

 

转载地址:http://gcovo.baihongyu.com/

你可能感兴趣的文章
Linux设备驱动开发详解-Note(11)--- Linux 文件系统与设备文件系统(3)
查看>>
实习第一天之数据绑定:<%#Eval("PartyName")%>'
查看>>
POJ 2318 TOYS (计算几何,叉积判断)
查看>>
第四章 Spring与JDBC的整合
查看>>
开源的Android视频播放器
查看>>
Java多线程-概念与原理
查看>>
“无法在web服务器上启动调试,不是Debugger User组成员..."
查看>>
POJ1258Agri-Net
查看>>
使用Frame控件设计Silverlight的导航
查看>>
数据分析师们的不可不读的信息图与数据可视化图书
查看>>
嵌入式开发常用的一些命令
查看>>
产品设计的关键
查看>>
Virtual Treeview 安装以及入门
查看>>
多线程的那点儿事(之多线程调试)
查看>>
数据库记录锁表锁实际研究笔记 --- MSSQLSERVER
查看>>
GPIO实验(一)
查看>>
安装Exchange2010
查看>>
java Socket 获取本地主机ip
查看>>
【经验分享】URL链接地址最长是多少?
查看>>
进度条脚本
查看>>