TensorFlow 示例摄氏度到华氏度的转换(一)

news2025/2/3 4:12:03

TensorFlow 实现神经网络模型来进行摄氏度到华氏度的转换,可以将其作为一个回归问题来处理。我们可以通过神经网络来拟合这个简单的转换公式。

1. 数据准备与预处理

2. 构建模型

3. 编译模型

4. 训练模型

5. 评估模型

6. 模型应用与预测

7. 保存与加载模型

8. 完整代码


1. 数据准备与预处理

你提供了摄氏度和华氏度的数据,并进行了标准化。标准化是为了使数据适应神经网络的训练,因为标准化可以加快训练过程并提高模型性能。

import numpy as np
import tensorflow as tf

# 温度数据:摄氏度到华氏度的转换
celsius = np.array([-50,-40, -10, 0, 8, 22, 35, 45, 55, 65, 75, 95], dtype=float)
fahrenheit = np.array([-58.0,-40.0,14.0,32.0,46.4,71.6,95.0,113.0,131.0,149.0,167.0,203.0], dtype=float)

# 数据标准化:计算均值和标准差
celsius_mean = np.mean(celsius)
celsius_std = np.std(celsius)

fahrenheit_mean = np.mean(fahrenheit)
fahrenheit_std = np.std(fahrenheit)

# 标准化输入和输出数据
celsius_normalized = (celsius - celsius_mean) / celsius_std
fahrenheit_normalized = (fahrenheit - fahrenheit_mean) / fahrenheit_std

2. 构建模型

在构建模型时,使用了一个简单的神经网络结构。神经网络包含了一个隐藏层和一个输出层。隐藏层使用了ReLU激活函数,输出层使用了线性激活函数,适合回归任务。

# 创建模型
model = tf.keras.Sequential([
    # 隐藏层,增加神经元数量,激活函数使用 ReLU
    tf.keras.layers.Dense(16, input_dim=1, activation='relu'),
    # 输出层,线性激活函数用于回归任务
    tf.keras.layers.Dense(1, activation='linear')
])

3. 编译模型

选择了Adam优化器,它在处理回归任务时表现较好,损失函数使用均方误差(MSE),这是回归问题中常用的损失函数。

# 编译模型,使用 Adam 优化器和均方误差损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mean_squared_error')

4. 训练模型

模型通过 fit() 方法进行训练,设置训练轮数(epochs)为5000轮。根据数据的复杂性和模型的表现,增加训练轮数可以帮助模型更好地收敛。

# 训练模型,设置训练轮数(epochs)增加到5000
model.fit(celsius_normalized, fahrenheit_normalized, epochs=5000)

5. 评估模型

训练完成后,你可以对模型进行评估。这里使用了一个测试集(test_celsius),并通过预测得到标准化的结果,然后将其恢复为原始的华氏度值。

# 测试模型
test_celsius = np.array([0, 20, 100], dtype=float)
test_celsius_normalized = (test_celsius - celsius_mean) / celsius_std
predictions_normalized = model.predict(test_celsius_normalized)

# 将预测结果从标准化值恢复到原始华氏度范围
predictions = predictions_normalized * fahrenheit_std + fahrenheit_mean

6. 模型应用与预测

最后,你可以输出预测的华氏度值。模型会对每个输入的摄氏度值返回预测的华氏度

# 输出预测结果
print("预测华氏度:")
for c, f in zip(test_celsius, predictions):
    print(f"{c} 摄氏度 => {f[0]} 华氏度")

7. 保存与加载模型

保存模型可以让你在之后加载并进行预测而不需要重新训练。在TensorFlow中,你可以使用 model.save() 来保存模型,使用 tf.keras.models.load_model() 来加载模型。

# 保存模型
model.save('temperature_conversion_model.h5')

# 加载模型
loaded_model = tf.keras.models.load_model('temperature_conversion_model.h5')

8. 完整代码

最终的完整代码如下:

import numpy as np
import tensorflow as tf

# 温度数据:摄氏度到华氏度的转换
celsius = np.array([-50,-40, -10, 0, 8, 22, 35, 45, 55, 65, 75, 95], dtype=float)
fahrenheit = np.array([-58.0,-40.0,14.0,32.0,46.4,71.6,95.0,113.0,131.0,149.0,167.0,203.0], dtype=float)

# 数据标准化:计算均值和标准差
celsius_mean = np.mean(celsius)
celsius_std = np.std(celsius)

fahrenheit_mean = np.mean(fahrenheit)
fahrenheit_std = np.std(fahrenheit)

# 标准化输入和输出数据
celsius_normalized = (celsius - celsius_mean) / celsius_std
fahrenheit_normalized = (fahrenheit - fahrenheit_mean) / fahrenheit_std

# 创建模型
model = tf.keras.Sequential([
    # 隐藏层,增加神经元数量,激活函数使用 ReLU
    tf.keras.layers.Dense(16, input_dim=1, activation='relu'),
    # 输出层,线性激活函数用于回归任务
    tf.keras.layers.Dense(1, activation='linear')
])

# 编译模型,使用 Adam 优化器和均方误差损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mean_squared_error')

# 训练模型,设置训练轮数(epochs)增加到5000
model.fit(celsius_normalized, fahrenheit_normalized, epochs=5000)

# 测试模型
test_celsius = np.array([0, 20, 100], dtype=float)
test_celsius_normalized = (test_celsius - celsius_mean) / celsius_std
predictions_normalized = model.predict(test_celsius_normalized)

# 将预测结果从标准化值恢复到原始华氏度范围
predictions = predictions_normalized * fahrenheit_std + fahrenheit_mean

# 输出预测结果
print("预测华氏度:")
for c, f in zip(test_celsius, predictions):
    print(f"{c} 摄氏度 => {f[0]} 华氏度")

# 保存模型
model.save('temperature_conversion_model.h5')

# 加载模型
loaded_model = tf.keras.models.load_model('temperature_conversion_model.h5')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2291086.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

win10部署本地deepseek-r1,chatbox,deepseek联网(谷歌网页插件Page Assist)

win10部署本地deepseek-r1,chatbox,deepseek联网(谷歌网页插件Page Assist) 前言一、本地部署DeepSeek-r1step1 安装ollamastep2 下载deepseek-r1step2.1 找到模型deepseek-r1step2.2 cmd里粘贴 后按回车,进行下载 ste…

【memgpt】letta 课程6: 多agent编排

Lab 6: Multi-Agent Orchestration 多代理协作 letta 是作为一个服务存在的,app通过restful api 通信 多智能体之间如何协调与沟通? 相互发送消息共享内存块,让代理同步到不同的服务的内存块

Java 大视界 -- Java 大数据在自动驾驶中的数据处理与决策支持(68)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也…

【Linux】opencv在arm64上提示找不到libjasper-dev

解决opencv在arm64上提示找不到libjasper-dev的问题。 本文首发于❄慕雪的寒舍 问题说明 最近我在尝试编译opencv,安装依赖项libjasper1和libjasper-dev的时候就遇到了这个问题。在amd64平台上,我们可以通过下面的命令安装(ubuntu18.04&…

LabVIEW在电机自动化生产线中的实时数据采集与生产过程监控

在电机自动化生产线中,实时数据采集与生产过程监控是确保生产效率和产品质量的重要环节。LabVIEW作为一种强大的图形化编程平台,可以有效实现数据采集、实时监控和自动化控制。详细探讨如何利用LabVIEW实现这一目标,包括硬件选择、软件架构设…

Baklib揭示内容中台与人工智能技术的创新协同效应

内容概要 在当今信息爆炸的时代,内容的高效生产与分发已成为各行业竞争的关键。内容中台与人工智能技术的结合,为企业提供了一种新颖的解决方案,使得内容创造的流程更加智能化和高效化。 内容中台作为信息流动的核心,能够集中管…

18.Word:数据库培训课程❗【34】

目录 题目 NO1.2.3.4 NO5设置文档内容的格式与样式 NO6 NO7 NO8.9 NO10.11标签邮件合并 题目 NO1.2.3.4 FnF12:打开"Word素材.docx”文件,将其另存为"Word.docx”在考生文件夹下之后到任务9的所有操作均基于此文件:"Word.docx”…

git多人协作

目录 一、项目克隆 二、 1、进入克隆仓库设置 2、协作处理 3、冲突处理 4、多人协作分支的推送拉取删除 1、分支推送(2种) 2、远程分支拉取(2种) 3、远程分支删除 一、项目克隆 git clone 画船听雨眠/test1 (自定义的名…

什么是线性化PDF?

线性化PDF是一种特殊的PDF文件组织方式。 总体而言,PDF是一种极为优雅且设计精良的格式。PDF由大量PDF对象构成,这些对象用于创建页面。相关信息存储在一棵二叉树中,该二叉树同时记录文件中每个对象的位置。因此,打开文件时只需加…

SpringMVC的参数处理

一、参数接收 1.使用servlet API接收参数 在方法参数中添加HttpServletRequest类型的参数,然后就可以像servlet的方法一样来接收参数 2.在方法中定义同名参数 如果url地址中的参数名与方法的参数名不一致时,可以使用RequestParam注解进行重新关联 url地…

一觉醒来全球编码能力下降100000倍,新手小白的我决定科普C语言——函数

1. 函数的概念 数学中我们其实就⻅过函数的概念,⽐如:⼀次函数 y kx b ,k和b都是常数,给⼀个任意的 x,就得到⼀个y值。其实在C语⾔也引⼊函数(function)的概念,有些翻译为&#xf…

台账思维和GIS思维在资产管理中的不同模式

最近一些习惯用台账统计资产的网友聊天引发一些感想和大家分享一下:传统台账思维注重统计资产的数量及信息完整性,而GIS除了关心前两个指标外,更注重数据与现实世界是否能一一对应,即数据的现实准确性! 例如&#xff1…

AI-ISP论文Learning to See in the Dark解读

论文地址:Learning to See in the Dark 图1. 利用卷积网络进行极微光成像。黑暗的室内环境。相机处的照度小于0.1勒克斯。索尼α7S II传感器曝光时间为1/30秒。(a) 相机在ISO 8000下拍摄的图像。(b) 相机在ISO 409600下拍摄的图像。该图像存在噪点和色彩偏差。©…

Unbutu虚拟机+eclipse+CDT编译调试环境搭建

问题1: 安装CDT,直接Help->eclipse Market space-> 搜cdt , install,等待重启即可. 问题2:C变量不识别vector ’could not be resolved 这是库的头文件没加好,右键Properties->C Build->Enviroment,增加…

利用metaGPT多智能体框架实现智能体-1

1.metaGPT简介 MetaGPT 是一个基于大语言模型(如 GPT-4)的多智能体协作框架,旨在通过模拟人类团队的工作模式,让多个 AI 智能体分工合作,共同完成复杂的任务。它通过赋予不同智能体特定的角色(如产品经理、…

[CVPR 2024] AnyDoor: Zero-shot Object-level Image Customization

github.com/ali-vilab/AnyDoor.写在前面: 【论文速读】按照#论文十问#提炼出论文核心知识点,方便相关科研工作者快速掌握论文内容。过程中并不对论文相关内容进行翻译。博主认为翻译难免会损坏论文的原本含义,也鼓励诸位入门级科研人员阅读文…

Microsoft Power BI:融合 AI 的文本分析

Microsoft Power BI 是微软推出的一款功能强大的商业智能工具,旨在帮助用户从各种数据源中提取、分析和可视化数据,以支持业务决策和洞察。以下是关于 Power BI 的深度介绍: 1. 核心功能与特点 Power BI 提供了全面的数据分析和可视化功能&…

如何实现滑动列表功能

文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了沉浸式状态栏相关的内容,本章回中将介绍SliverList组件.闲话休提,让我们一起Talk Flutter吧。 1 概念介绍 我们在这里介绍的SliverList组件是一种列表类组件,类似我们之前介…

Linux——网络(tcp)

文章目录 目录 文章目录 前言 一、TCP逻辑 1. 面向连接 三次握手(建立连接) 四次挥手(关闭连接) 2. 可靠性 3. 流量控制 4. 拥塞控制 5. 基于字节流 6. 全双工通信 7. 状态机 8. TCP头部结构 9. TCP的应用场景 二、编写tcp代码函数…

算法题(54):插入区间

审题: 需要我们把newinterval的区间与interval的区间合并起来,并返回合并后的二维数组地址 思路: 方法一:排序合并区间 我们可以先把newinterval插入到interval中,进行排序然后复用合并区间的代码 方法二:模…