使用Pytorch实现linear_regression

news2025/1/11 15:56:10

使用Pytorch实现线性回归

# import necessary packages
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Set necessary Hyper-parameters.
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001
# Define a Toy dataset.
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], 
                    [9.779], [6.182], [7.59], [2.167], [7.042], 
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], 
                    [3.366], [2.596], [2.53], [1.221], [2.827], 
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

# Confirm the data shape.
print(x_train.shape, y_train.shape)
(15, 1) (15, 1)
# Linear regression model
model = nn.Linear(input_size, output_size)
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, )
# Train the model
for epoch in range(num_epochs):
    # Convert numpy arrays to torch tensors
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Set an output counter
    if (epoch+1) % 5 == 0:
        print('Epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# Plot the graph
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
Epoch [5/60], loss: 7.1598
Epoch [10/60], loss: 3.0717
Epoch [15/60], loss: 1.4154
Epoch [20/60], loss: 0.7443
Epoch [25/60], loss: 0.4722
Epoch [30/60], loss: 0.3618
Epoch [35/60], loss: 0.3169
Epoch [40/60], loss: 0.2985
Epoch [45/60], loss: 0.2909
Epoch [50/60], loss: 0.2876
Epoch [55/60], loss: 0.2861
Epoch [60/60], loss: 0.2853

在这里插入图片描述

# Save the model checkpoint
torch.save(model.state_dict(), 'model_param.ckpt')
torch.save(model, 'model.ckpt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1236280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity团结引擎使用总结

团结引擎创世版以 Unity 2022 LTS 为研发基础,与 Unity 2022 LTS 兼容、UI 也基本保持一致,使 Unity 开发者可以无缝转换到团结引擎。融入了团结引擎独有功能和优化,未来会加入更多为中国开发者量身定制的功能和优化。 目前正在内测&#xf…

函数模板(成长版)

与普通函数区别&#xff1a;1.多了个template<class T>;2.某些确定类型变不确定类型T 一&#xff1a;引子&#xff1a; #include<iostream> using namespace std; template<typename T> T Max(T a, T b) {return a > b ? a : b; } int main() {int x, …

深信服AC跨三层取mac,绑定ip/mac

拓扑图 目录 拓扑图 1.交换机配置团体名和版本号 2.配置跨三层取mac 3.配置策略 验证&#xff1a; “您的每一次阅读、点赞和分享&#xff0c;都是对我最大的鼓舞和动力。” 如果对亲爱您有所帮助&#xff0c;可以尝试支持一下博主&#xff0c;让博主更有动力 1.交换机配置…

nodejs express vue uniapp新闻发布系统源码

开发技术&#xff1a; node.js&#xff0c;mysql5.7&#xff0c;vscode&#xff0c;HBuilder nodejs express vue uniapp 功能介绍&#xff1a; 用户端&#xff1a; 登录注册 首页显示搜索新闻&#xff0c;新闻分类&#xff0c;新闻列表 点击新闻进入新闻详情&#xff0…

Ubuntu18 Opencv3.4.12 viz 3D显示安装、编译、移植

Opencv3.*主模块默认包括两个3D库 calib3d用于相机校准和三维重建 &#xff0c;viz用于三维图像显示&#xff0c;其中viz是cmake选配。 参考&#xff1a; https://docs.opencv.org/3.4.12/index.html 下载linux版本的源码 sources。 查看cmake apt list --installed | grep…

如何在IAR软件中使用STLINK V2编译下载和调试stm8单片机

安装使用IAR后&#xff0c;如使用系统默认设置&#xff0c;往往很难正常实现用stlink v2来下载和调试stm8芯片&#xff0c;我的解决方法如下&#xff1a; 1、打开项目的options菜单&#xff1a; 2、在项目的选项菜单中选择ST-LINK作为调试工具&#xff1a; 3、选择额外的输出…

【C+进阶之路】第六篇:C++11

文章目录 一、【C】C11&#xff08;1&#xff09;二、【C】C11&#xff08;2&#xff09; 一、【C】C11&#xff08;1&#xff09; 【C】C11&#xff08;1&#xff09; 二、【C】C11&#xff08;2&#xff09; 【C】C11&#xff08;2&#xff09; &#x1f339;&#x1f33…

AT89S52单片机

目录 一.AT89S52单片机的硬件组成 1.CPU(微处理器) (1)运算器 (2)控制器 2.数据存储器 (RAM) (1)片内数据存储器 (2)片外数据存储器 3.程序存储器(Flash ROM) 4.定时器/计数器 5.中断系统 6.串行口 7.P0口、P1口、P2口和P3口 8.特殊功能寄存器 (SFR) 常用的特殊功…

机器视觉系统选型-环形光源分类及应用场景

环形光源主要分为&#xff1f; 1.环形光源(高角度) 照射光线与水平方向成高角度夹角 外观缺陷检测字符识别PCB基板检测二维码读取- 2.环形光源(低角度) 照射光线与水平方向成低角度夹角各种边缘提取字符识别玻璃断面的损伤检测金属表面刻印、损伤 3.环形光源(高亮) 高亮度远…

STM32出现 Invalid Rom Table 芯片锁死解决方案

出现该现象的原因为板子外部晶振为25M&#xff0c;而程序软件上以8M为输入晶振频率&#xff0c;导致芯片超频锁死&#xff0c;无法连接、下载。 解决方案 断电&#xff0c;将芯片原来通过10k电阻接地的BOOT0引脚直接接3.3V&#xff0c;硬件上置1上电&#xff0c;连接目标板&am…

Polygon Miden VM架构总览

1. 计算类型 Programs程序有2种类型&#xff1a; 1&#xff09;Circuit电路&#xff1a;即&#xff0c;程序即电路。将程序转换为电路。2&#xff09;Virtual machine虚拟机&#xff1a;即&#xff0c;程序为电路的输入。【Miden VM属于此类型】 2. 何为ZK virtual machine…

TensorFlow实战教程(十八)-Keras搭建卷积神经网络及CNN原理详解

从本专栏开始,作者正式研究Python深度学习、神经网络及人工智能相关知识。前一篇文章详细讲解了Keras实现分类学习,以MNIST数字图片为例进行讲解。本篇文章详细讲解了卷积神经网络CNN原理,并通过Keras编写CNN实现了MNIST分类学习案例。基础性文章,希望对您有所帮助! 一…

机器视觉技术在现代汽车制造中的应用

原创 | 文 BFT机器人 机器视觉技术&#xff0c;利用计算机模拟人眼视觉功能&#xff0c;从图像中提取信息以用于检测、测量和控制&#xff0c;已广泛应用于现代工业&#xff0c;特别是汽车制造业。其主要应用包括视觉测量、视觉引导和视觉检测。 01 视觉测量 视觉测量技术用于…

2021秋招-面经

面经总结 微软STCA面试-面经 字节AI lab实习面试记录 腾讯PCG-腾讯新闻面试 百度(AIDU)-内容策略部门面试 百度(AIDU)-搜索策略-机器学习算法工程师 百度(AIDU)-知识图谱部门算法工程师(2020-07-08) 百度(AIDU)-NLP部门算法工程师(2020-07-10) 微软STCA面试-面经 2020-…

瑞格心理咨询系统设置多个管理员的操作方法

使用瑞格心理咨询系统&#xff0c;需要设置多个admin权限的管理员账号来管理&#xff0c;咨询厂家答复只能有1个管理员&#xff0c;个人觉得不可能&#xff0c;于是开始折腾。 解决办法&#xff1a; 在没有数据字典的情况下&#xff0c; 通过遍历数据库&#xff0c;发现用户信…

【19年扬大真题】已知a数组int a[ ]={1,2,3,4,5,6,7,8,9,10},编写程序,求a数组中偶数的个数和偶数的平均值

【18年扬大真题】 已知a数组int a[ ]{1,2,3,4,5,6,7,8,9,10}&#xff0c;编写程序&#xff0c;求a数组中偶数的个数和偶数的平均值 int main() {int arr[10] { 1,2,3,4,5,6,7,8,9,10 };int os 0;//偶数个数int sum 0;//偶数和float ave 0;//偶数平均值for (int i 0;i <…

基于Bagging集成学习方法的情绪分类预测模型研究(文末送书)

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

智能座舱架构与芯片- (12) 软件篇 中

三、智能座舱操作系统 3.1 概述 车载智能计算平台自下而上可大致划分为硬件平台、系统软件&#xff08;硬件抽象层OS内核中间件&#xff09;、功能软件&#xff08;库组件中间件&#xff09;和应用算法软件等四个部分。狭义上的OS特指可直接搭载在硬件上的OS内核&#xff1b;…

OpenAI宫斗大戏,奥特曼面临的选择

首先不得不说&#xff0c;这件事情进展真快&#xff0c;三四天时间之内&#xff0c;大量的消息&#xff0c;各种不同渠道的&#xff0c;各种不同角度的&#xff0c;其中也包括各种决策&#xff0c;速度之快真的是应接不暇&#xff0c;仿佛在看真人秀一般 这里简单帮大家梳理一…

单链表——OJ题(一)

目录 ​一.前言 二.移除链表元素 三.返回链表中间节点 四.链表中倒数第K个节点 五.合并两个有序链表 六.反转链表 七.链表分割 八.链表的回文结构 九.相交链表 十.环形链表 十一.环形链表&#xff08;二&#xff09; ​六.结语 一.前言 本文主要对平时的链表OJ进行…