机器学习 - 选择模型

news2025/1/16 11:31:21

接着这一篇博客做进一步说明:
机器学习 - 准备数据

PyTorch moduleExplain
torch.nnContains all of the building blocks for computational graphs (essentially a series of computations executed in a particular way). nn 模块为用户提供了丰富的神经网络组件,包括各种层,激活函数,损失函数以及其他辅助功能。
torch.nn.ParameterStores tensors that can be used with nn.Module. If requires_grad=True gradients (used for updating model parameters via gradient descent) are calculated automatically, this is often referred to as “autograd”. 通常在定义神经网络模型时用于表示权重 (weights) 和 偏置 (biases) 等参数
torch.nn.ModuleThe base class for all neural network modules, all the building blocks for neural networks are subclasses. If you’re building a neural network in PyTorch, your models should subclass nn.Module. Requires a forward() method be implemented
torch.optimContains various optimization algorithms (these tell the model parameters stored in nn.Parameter how to best change to improve gradient descent and in turn reduce the loss).
def forward()All nn.Module subclasses require a forward() method, this defines the computation that will take place on the data passed to the particular nn.Module (e.g. the linear regression formula above).

可以这么理解,almost everything in a PyTorch neural network comes from torch.nn .

  • nn.Module contains the larger building blocks (layers)
  • nn.Parameter contains the smaller parameters like weights and biases (put these together to make nn.Module )
  • forward() tells the larger blocks how to make calculations on inputs (tensors full of data) within nn.Module(s)
  • torch.optim contains optimization methods on how to improve the parameters within nn.Parameter to better represent input data.

大概可以这么理解:module 里包含各种参数 (parameter),在 module 里做计算 (forward) 甚至可以通过修改参数来优化 (torch.optim)。

这里稍微介绍 Neural Network Block。
Neural Network Block 通常指的是神经网络中的一个模块化组件,它可以包含一个或多个层 (layers) 以及一些额外的操作,被设计用来完成特定的功能或实现特定的神经网络结构。
Neural Network Block的设计旨在简化神经网络模型的构建和管理,提高代码的可读性和可维护性。通过将神经网络模型划分为多个块,可以将模型的不同部分进行分离,使得每个部分都可以独立地设计,调整和复用。这种模块化的设计使得构建复杂的神经网络变得更加灵活和高效。
比如:卷积神经网络中的卷积块。

代码如下所示

import torch 

class LinearRegressionModel(nn.Module):  # child class nn.Module
  def __init__(self):
    super().__init__()

    # Initialize model parameters
    self.weights = nn.Parameter(torch.randn(1,
                                            dtype=torch.float),
                                requires_grad = True)
    self.bias = nn.Parameter(torch.randn(1,
                                         dtype=torch.float),
                             requires_grad = True)  # requires_grad=True means PyTorch will track the gradients of this specific parameter for use with torch.autograd and gradient descent (for many torch.nn modules, requires_grad=True is set by default)

  # Any child class of nn.Module needs to override forward()
  # This defines the forward computation of the model
  def forward(self, x: torch.Tensor) -> torch.tensor:
    return self.weights * x + self.bias

# Set manual seed since nn.Parameter are randomly initizalized
torch.manual_seed(42)

# Create an instance of the model (this is a subclass of nn.Module that contains nn.Parameter(s))
model_0 = LinearRegressionModel()

# Check the nn.Parameter(s) within the nn.Module subclass
print(f"Check the nn.Parameter(s): {list(model_0.parameters())}")

# List named parameters
print(f"List named parameters: {model_0.state_dict()}")

# 输出结果如下
Check the nn.Parameter(s): [Parameter containing:
tensor([0.3367], requires_grad=True), Parameter containing:
tensor([0.1288], requires_grad=True)]
List named parameters: OrderedDict([('weights', tensor([0.3367])), ('bias', tensor([0.1288]))])


使用 torch.inference_mode() 来做预测。
The data is passed to our model. It will go through the model’s forward() method and produce a result using the computation.

# Make predictions with model
with torch.inference_mode():
  y_test_preds = model_0(X_test)

As the name suggests, torch.inference_mode() is used when using a model for inference (making predictions). torch.inference_mode() turns off a bunch of things (like gradient tracking, which is necessary for training but not for inference) to make forward-passes (data going through the forward() method) faster.

# Check the predictions
print(f"Number of testing samples: {len(X_test)}")
print(f"Number of predictions made: {len(y_test_preds)}")
print(f"Predicted values (X_test):\n {y_test_preds}")

def plot_predictions(train_data = X_train,
                     train_labels = y_train,
                     test_data = X_test,
                     test_labels = y_test,
                     predictions = None):
  """
  Plots training data, test data and compares predictions
  """
  plt.figure(figsize=(10, 7))

  # Plot training data in blue
  plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")

  # Plot test data in green
  plt.scatter(test_data, test_labels, c="g", s=4, label="Test data")

  if predictions is not None:
    plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")

  plt.legend(prop={"size": 14})

plot_predictions(predictions=y_test_preds)

print(f"check the difference:\n {y_test - y_test_preds}")  # 可以发现两者之间的差距是很大的

# 结果如下
Number of testing samples: 10
Number of predictions made: 10
Predicted values (X_test):
 tensor([[0.3982],
        [0.4049],
        [0.4116],
        [0.4184],
        [0.4251],
        [0.4318],
        [0.4386],
        [0.4453],
        [0.4520],
        [0.4588]])
check the difference:
 tensor([[0.4618],
        [0.4691],
        [0.4764],
        [0.4836],
        [0.4909],
        [0.4982],
        [0.5054],
        [0.5127],
        [0.5200],
        [0.5272]])

将数据显示到图里
效果图

看到这了,给个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1530409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【软件】项目管理工具focalboard使用docker部署

github官方网址 使用宝塔进入docker从官方进行镜像仓库拉去mattermost/focalboard 容器》添加容器》容器名》镜像》暴露端口》加号》添加 注意的是原始容器端口号为8000和9092

Vue使用qrcodejs2实现生成二维码

Vue使用qrcodejs2实现生成二维码示例 业务需求 比如说我们需要对下方的列表数据访问地址列进行生成二维码,扫描后跳转对应的地址。 安装qrcodejs2依赖 npm i qrcodejs2引用 在我们需要使用的页面进行引用qrcodejs2 import QRCode from qrcodejs2定义我们的二维…

业务服务:redisson

文章目录 前言一、配置1. 添加依赖2. 配置文件/类3. 注入redission3. 封装工具类 二、应用1. RedisUtils工具类的基本使用 三、队列1. 工具类2. 普通队列2. 有界队列(限制数据量) 前言 redission是一个开源的java redis的客户端,在其基础上进…

备战蓝桥杯---牛客寒假训练营2VP

题挺好的,收获了许多 1.暴力枚举(许多巧妙地处理细节方法) n是1--9,于是我们可以直接暴力,对于1注意特判开头0但N!1,对于情报4,我们可以把a,b,c,d的所有取值枚举一遍,那么如何判断有…

机器学习——编程实现从零构造训练集的决策树

自己搭建一棵决策树【长文预警】 忙了一个周末就写到了“构建决策树”这一步,还没有考虑划分测试集、验证集、“缺失值、连续值”,预剪枝、后剪枝的部分,后面再补吧(挖坑) 第二节内容:验证集划分\k折交叉…

Docker-安装

Docker ⛅Docker-安装🌠各平台支持情况🌠Server 版本安装☃️Ubuntu☃️Centos 🌠Docker 镜像源修改🌠Docker 目录修改 ⛅Docker-安装 🌠各平台支持情况 🌠Server 版本安装 ☃️Ubuntu 🍂安装…

地脚螺栓的介绍

地脚螺栓简单来说,它是一种机械构件。通常用于铁路、公路、电力、桥梁、锅炉钢结构、塔吊、大型建筑等。一头预埋在地底下,另一头穿过设备用螺母拧紧,用来固定设备,钢结构设施,铁塔设施等,所以地脚螺栓拥有…

上海亚商投顾:沪指震荡调整 北向资金全天净卖出超70亿

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 沪指昨日震荡调整,创业板指尾盘跌超1%。猪肉股集体反弹,播恩集团、湘佳股份、傲农生物…

vue项目:使用xlsx导出Excel数据

文章目录 一、安装xlsx二、报错及解决三、编写公共方法四、方法使用 一、安装xlsx 执行命令:npm i xlsx file-saver --save 二、报错及解决 使用时:import XLSX from "xlsx"; 发现如下报错信息 报错原因:xlsx版本不兼容。 解…

幼犬狗粮和成年犬狗粮该怎么挑选?

亲爱的狗友们,我们都知道,给狗狗选择适合的狗粮是非常重要的。那么,面对市面上琳琅满目的幼犬狗粮和成年犬狗粮,我们该如何挑选呢?别担心,接下来就让我来给大家支支招。 🐶 幼犬狗粮挑选篇 &…

Linux 网络接口管理

为了更深入的了解linux系统,为此做出网络接口管理的知识总结。看起来麻烦,其实一点都不难,相信多看多了解总会是没错的!❤️❤️ 一起加油吧!✨✨🎉🎉 文章目录 前言一、网络配置的文件介绍二、…

路由器怎么做端口映射

路由器在网络中起到了连接不同设备和提供网络服务的重要作用。端口映射是一项常见的操作,它允许外部网络中的设备通过路由器访问内部网络中的设备。我们将介绍如何在路由器上进行端口映射的设置。 理解端口映射 在开始操作之前,我们需要了解一些基本概念…

JJJ:改善ubuntu网速慢的方法

Ubuntu 系统默认的软件下载源由于服务器的原因, 在国内的下载速度往往比较慢,这时我 们可以将 Ubuntu 系统的软件下载源更改为国内软件源,譬如阿里源、中科大源、清华源等等, 下载速度相比 Ubuntu 官方软件源会快很多!…

机器学习 - 训练模型

接着这一篇博客做进一步说明: 机器学习 - 选择模型 为了解决测试和预测之间的差距,可以通过更新 internal parameters, the weights set randomly use nn.Parameter() and bias set randomly use torch.randn(). Much of the time you won’t know what…

Python内置对象

Python是一种强大的、动态类型的高级编程语言,其内置对象是构成程序的基础元素。Python的内置对象包括数字、字符串、列表、元组、字典、集合、布尔值和None等,每种对象都有特定的类型和用途。 01 什么是内置对象 这些对象是编程语言的基础构建块&…

C语言 指针练习

一、 a、b是两个浮点型变量&#xff0c;给a、b赋值&#xff0c;建立两个指针分别指向a的地址和b的地址&#xff0c;输出两个指针的值。 #include<stdio.h> int main() {float a,b,*p1,*p2;a10.2;b2.3;p1&a;p2&b;printf("a%f,b%f\n",a,b);printf("…

软考高级:类的分类(边界类、控制类、实体类)概念和例题

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

【Java初阶(二)】分支与循环

❣博主主页: 33的博客❣ ▶文章专栏分类: Java从入门到精通◀ &#x1f69a;我的代码仓库: 33的代码仓库&#x1f69a; 目录 1.前言2.顺序结构3.分支循环3.1if语句3.2switch语句 4.循环结构4.1while循环4.2 break和continue4.3 for循环4.4 do while循环 5.输入输出5.1输出5.2输…

记录C++中,子类同名属性并不能完全覆盖父类属性的问题

问题代码&#xff1a; 首先看一段代码&#xff1a;很简单&#xff0c;就是BBB继承自AAA&#xff0c;然后BBB重写定义了同名属性&#xff0c;然后调用父类AAA的打印函数&#xff1a; #include <iostream> using namespace std;class AAA { public:AAA() {}~AAA() {}void …

Django单表数据库操作

单表操作 测试脚本 当你只想测试django某一个py文件的内容,可以不用书写前后端的交互,直接写一个测试脚本即可 单表删除 数据库操作方法: 1.all():查询所有的数据 2.filter():带有过滤条件的查询 3.get():直接拿数据对象,不存在则报错 4.first():拿queryset里面的第一个元素…