从0到1,AI我来了- (1)从AI手写数字识别开始

news2024/11/16 4:25:41

前两篇我们我们把控制台、Python环境Anaconda 搞定了,接下来,我们快速进入主题,把AI 界的“Hello World” 实现一下,有个感觉,再逐步了解一些AI的概念。

1、Pytorch 安装

1) 什么是Pytorch?

        一个深度学习框架,封装了很多深度学习相关函数,目前是pytorch已成为最受欢迎的深度学习框架之一,除了它,目前还有一些在用的TensorFlow、Keras、MXNet、Caffe 等。

2) 为什么用Pytorch?

        我们公司在用它,行业主流也在用它,我没必要一开始学冷门,出了问题,找不到“知音”

3)如何用上Pytorch?

        官网介绍在这里:Start Locally | PyTorch

        macOS  通过Anaconda 安装pytorch

conda install pytorch torchvision -c pytorch

       正常情况下,你会遇到我一样的错误

❯ conda install pytorch torchvision -c pytorch

Channels:
 - pytorch
 - https://mirrors.ustc.edu.cn/anaconda/pkgs/free
 - https://mirrors.ustc.edu.cn/anaconda/pkgs/main
 - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
 - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
 - defaults
Platform: osx-arm64
Collecting package metadata (repodata.json): failed

CondaHTTPError: HTTP 429 TOO MANY REQUESTS for url <https://mirrors.ustc.edu.cn/anaconda/pkgs/free/osx-arm64/repodata.json>
Elapsed: 00:26.256248

An HTTP error occurred when trying to retrieve this URL.
HTTP errors are often intermittent, and a simple retry will get you on your way.
'https//mirrors.ustc.edu.cn/anaconda/pkgs/free/osx-arm64'

       解决办法我放到下方的附录部分了:

重新执行安装

❯ conda install pytorch torchvision -c pytorch

Channels:
 - pytorch
 - defaults
Platform: osx-arm64
Collecting package metadata (repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /opt/anaconda3

  added / updated specs:
    - pytorch
    - torchvision


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    libjpeg-turbo-2.0.0        |       h1a28f6b_0         386 KB  defaults
    pytorch-2.4.0              |         py3.12_0        57.6 MB  pytorch
    torchvision-0.19.0         |        py312_cpu         6.8 MB  pytorch
    ------------------------------------------------------------
                                           Total:        64.8 MB

The following NEW packages will be INSTALLED:

  libjpeg-turbo      anaconda/pkgs/main/osx-arm64::libjpeg-turbo-2.0.0-h1a28f6b_0
  pytorch            pytorch/osx-arm64::pytorch-2.4.0-py3.12_0
  torchvision        pytorch/osx-arm64::torchvision-0.19.0-py312_cpu


Proceed ([y]/n)? y


Downloading and Extracting Packages:

Preparing transaction: done
Verifying transaction: done
Executing transaction: done

验证一下,没有报错,则安装成功了

❯ python
Python 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 10:07:17) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>>

2、开始写程序

先有个体感,从MNIST 手写的一批数字中,训练出一个模型,然后再拿一部分MNIST数据区验证模型准确率。下面的程序打印了一些识别错误的反例。

先跑跑,下篇我们逐行分析一下。

import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader  
from torchvision import datasets, transforms  
from torch.utils.tensorboard import SummaryWriter  
import matplotlib.pyplot as plt  

# 数据加载与预处理  
transform = transforms.Compose([  
    transforms.ToTensor(),  
    transforms.Normalize((0.5,), (0.5,))  
])  

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)  
print("load success")
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)  

# 模型定义  
class SimpleCNN(nn.Module):  
    def __init__(self):  
        super(SimpleCNN, self).__init__()  
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  
        self.pool = nn.MaxPool2d(2, 2)  
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  
        self.fc2 = nn.Linear(128, 10)  

    def forward(self, x):  
        x = self.pool(nn.ReLU()(self.conv1(x)))  
        x = self.pool(nn.ReLU()(self.conv2(x)))  
        x = x.view(-1, 64 * 7 * 7)  
        x = nn.ReLU()(self.fc1(x))  
        x = self.fc2(x)  
        return x  

# 训练设置  
model = SimpleCNN()  
criterion = nn.CrossEntropyLoss()  
optimizer = optim.Adam(model.parameters(), lr=0.001)  
num_epochs = 10  

# TensorBoard 初始化  
writer = SummaryWriter()  

# 记录损失和准确率  
loss_values = []  
accuracy_values = []  

# 训练循环  
for epoch in range(num_epochs):  
    running_loss = 0.0  
    correct = 0  
    total = 0  
    
    for images, labels in train_loader:  
        optimizer.zero_grad()  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
        loss.backward()  
        optimizer.step()  

        running_loss += loss.item()  

        _, predicted = torch.max(outputs.data, 1)  
        total += labels.size(0)  
        correct += (predicted == labels).sum().item()  
    
    avg_loss = running_loss / len(train_loader)  
    accuracy = 100 * correct / total  

    # 记录到 TensorBoard  
    writer.add_scalar('Loss/train', avg_loss, epoch)  
    writer.add_scalar('Accuracy/train', accuracy, epoch)  

    # 记录到列表  
    loss_values.append(avg_loss)  
    accuracy_values.append(accuracy)  

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')  

writer.close()  

# 测试模型  
model.eval()  
test_loss = 0.0  
correct = 0  
total = 0  
i=0

with torch.no_grad():  
    for images, labels in test_loader:  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
        test_loss += loss.item()  

        _, predicted = torch.max(outputs.data, 1)  
        total += labels.size(0)  
        correct += (predicted == labels).sum().item() 
        
        if (predicted[i] != labels[i] ):
            print("我猜是",predicted[i],"实际是",labels[i])
            plt.imshow(images[i].reshape(28,28),cmap='Greys', interpolation='nearest')
            plt.show()
        
        

test_loss /= len(test_loader)  
test_accuracy = 100 * correct / total  
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')  

# 使用 Matplotlib 绘制损失和准确率  
plt.figure(figsize=(12, 4))  

# 绘制损失图  
plt.subplot(1, 2, 1)  
plt.plot(range(1, num_epochs + 1), loss_values, label='Loss')  
plt.title('Training Loss')  
plt.xlabel('Epochs')  
plt.ylabel('Loss')  
plt.legend()  

# 绘制准确率图  
plt.subplot(1, 2, 2)  
plt.plot(range(1, num_epochs + 1), accuracy_values, label='Accuracy', color='orange')  
plt.title('Training Accuracy')  
plt.xlabel('Epochs')  
plt.ylabel('Accuracy (%)')  
plt.legend()  

plt.tight_layout()  
plt.show()
load success
Epoch [1/10], Loss: 0.1447, Accuracy: 95.53%
Epoch [2/10], Loss: 0.0447, Accuracy: 98.61%
Epoch [3/10], Loss: 0.0303, Accuracy: 99.06%
Epoch [4/10], Loss: 0.0218, Accuracy: 99.28%
Epoch [5/10], Loss: 0.0164, Accuracy: 99.47%
Epoch [6/10], Loss: 0.0136, Accuracy: 99.52%
Epoch [7/10], Loss: 0.0105, Accuracy: 99.65%
Epoch [8/10], Loss: 0.0079, Accuracy: 99.75%
Epoch [9/10], Loss: 0.0076, Accuracy: 99.75%
Epoch [10/10], Loss: 0.0078, Accuracy: 99.72%

 

附录

1、Http 429 限制问题解决,找到 用户目录的 .condarc 文件

vim ~/.condarc

vim 编辑后,替换文件内容一下为下方的内容

channels:
  - defaults
show_channel_urls: true
default_channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
custom_channels:
  conda-forge: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  msys2: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  bioconda: https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud
  menpo: https://mirrors.tuna.tsinghua.edu

再执行安装,即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1949987.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C语言常见字符函数和字符串函数精讲

目录 引言 一、字符函数 1.字符分类函数 2.字符转换函数 二、字符串函数 1.gets、puts 2.strlen 3.strcpy 4.strncpy 5.strcat 6.strncat 7.strcmp 8.strncmp 9.strstr 10.strchr 11.strtok 12.strlwr 13.strupr 引言 在C语言编程中&#xff0c;字符函数…

第二证券:股票交易费用有哪些?

出资者生意股票是需求付出生意费用的&#xff0c;一般来说股票的生意费用主要有以下几种&#xff1a; 1、证券公司佣金。这是证券公司收取的一种服务费&#xff0c;用于供给股票生意的途径和服务。证券公司佣金的份额由证券公司自行拟定&#xff0c;但最高不得超越成交金额的0…

elasticsearch8.14.1集群安装部署

elasticsearch安装部署&#xff0c;首先需要准备至少三台服务器&#xff0c;本例再windows11下安装三台vmware虚拟机&#xff0c;利用centOS7系统模拟服务器环境。 本例假设你已经安装了三台vmware和centOS7&#xff0c;且centOS7运行正常。接下来我们直接讲解elasticsearch下载…

通过IEC104转MQTT网关轻松接入阿里云平台

随着智能电网和物联网技术的飞速发展&#xff0c;电力系统中的传统IEC 104协议设备正面临向现代化、智能化转型的迫切需求。阿里云作为全球领先的云计算服务提供商&#xff0c;其强大的物联网平台为IEC 104设备的接入与数据处理提供了强大的支持。本文将深入探讨钡铼网关在MQTT…

Python seaborn超级细节篇-使用配色palette

本文分享Python seaborn中通过配色palette美化图形。配色(palette),用于设置color palette,例如,Set1、#a1c9f4、red等。 内容很多,快速浏览一下,节选自👉Python可视化-seaborn篇 这里展示部分, 5.3 palette设置图形配色 设置图形配色palette目的在于有效地展示数…

【ffmpeg命令入门】视频剪切,倍速与倒放

文章目录 前言1. 视频剪切2. 视频倍速公式说明例子 3. 视频倒放总结 前言 在视频编辑中&#xff0c;剪切、倍速和倒放是常见的操作&#xff0c;能够帮助我们调整视频的长度、播放速度以及播放顺序。掌握 FFmpeg 命令中的相关参数和用法将使视频处理变得更加高效。在这篇文章中…

vLLM——使用PagedAttention加速推理

参考自https://blog.vllm.ai/2023/06/20/vllm.html 介绍 vLLM是一个用于快速LLM推理和服务的开源库。vLLM 利用PagedAttention&#xff0c;可以有效地管理注意力键和值。PagedAttention 的 vLLM 重新定义了 LLM 服务的最新水平&#xff1a;它提供了比 HuggingFace Transforme…

jdk的major version和minor version是啥意思?

写在前面 1&#xff1a;正文 major version是大版本号&#xff0c;minor version是小版本号&#xff0c;但目前minor version都是0&#xff08;也可能是我没有发现&#x1f605;&#xff09;&#xff0c;如jdk8就是52&#xff0c;如下表&#xff1a; 可以看到jdk版本号和ma…

【Java】随机值设置

&#x1f389;欢迎大家收看&#xff0c;请多多支持&#x1f339; &#x1f970;关注小哇&#xff0c;和我一起成长&#x1f680;个人主页&#x1f680; 在Java中设置随机值通常涉及到java.util.Random类或Math.random()方法。 使用Math.random()方法 Math.random()生成的随机…

AI在Facebook的应用:预见智能化社交的新前景

在数字化时代&#xff0c;社交媒体平台已成为我们生活的重要组成部分&#xff0c;而人工智能&#xff08;AI&#xff09;的快速发展正推动着这些平台向更智能、更个性化的方向发展。Facebook&#xff0c;作为全球最大的社交网络平台之一&#xff0c;正不断探索和应用AI技术&…

leetcode日记(54)加一

很简单 class Solution { public:vector<int> plusOne(vector<int>& digits) {int ndigits.size();for(int in-1;i>0;i--){if(digits[i]<9){digits[i];break;}else if(i0){digits[i]0;digits.insert(digits.begin(),1);}else digits[i]0;}return digits…

【React】JSX 实现列表渲染

文章目录 一、基础语法1. 使用 map() 方法2. key 属性的使用 二、常见错误和注意事项1. 忘记使用 key 属性2. key 属性的选择 三、列表渲染的高级用法1. 渲染嵌套列表2. 条件渲染列表项3. 动态生成组件 四、最佳实践 在 React 开发中&#xff0c;列表渲染是一个非常常见的需求。…

家政项目小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;家政人员管理&#xff0c;家政服务管理&#xff0c;咨询信息管理&#xff0c;咨询服务管理&#xff0c;家政预约管理&#xff0c;留言板管理&#xff0c;系统管理 微信端账号功能…

LeetCode 637, 67, 399

文章目录 637. 二叉树的层平均值题目链接标签思路代码 67. 二进制求和题目链接标签思路代码 399. 除法求值题目链接标签思路导入value 属性find() 方法union() 方法query() 方法 代码 637. 二叉树的层平均值 题目链接 637. 二叉树的层平均值 标签 树 深度优先搜索 广度优先…

nginx 启动 ssl 模块

文章目录 前言nginx 启动 ssl 模块1. 下载2. 启动 ssl 模块 步骤3. 验证前言 如果您觉得有用的话,记得给博主点个赞,评论,收藏一键三连啊,写作不易啊^ _ ^。   而且听说点赞的人每天的运气都不会太差,实在白嫖的话,那欢迎常来啊!!! nginx 启动 ssl 模块 1. 下载 下载…

STM32智能家居控制系统教程

目录 引言环境准备智能家居控制系统基础代码实现&#xff1a;实现智能家居控制系统 4.1 数据采集模块 4.2 数据处理与分析模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;家居监测与优化问题解决方案与优化收尾与总结 1. 引言 智能家居控制系统通…

PHP多场地预定小程序系统源码

一键畅游多地&#xff01;多场地预定小程序的超实用指南 段落一&#xff1a;【开篇&#xff1a;告别繁琐&#xff0c;预订新体验】 &#x1f389;&#x1f680; 还在为多个活动或会议的场地预订而头疼不已吗&#xff1f;多场地预定小程序来拯救你啦&#xff01;它像是一位贴心…

GPU虚拟化和池化技术解读

GPU虚拟化到池化技术深度分析 在大型模型的推动下&#xff0c;GPU算力的需求日益增长。然而&#xff0c;企业常常受限于有限的GPU卡资源&#xff0c;即使采用虚拟化技术&#xff0c;也难以充分利用或持续使用这些资源。为解决GPU算力资源的不均衡问题&#xff0c;同时推动国产…

【Qt 】JSON 数据格式详解

文章目录 1. JSON 有什么作用?2. JSON 的特点3. JSON 的两种数据格式3.1 JSON 数组3.2 JSON 对象 4. Qt 中如何使用 JSON 呢&#xff1f;4.1 QJsonObject4.2 QJsonArray4.3 QJsonValue4.4 QJsonDocument 5. 构建 JSON 字符串6. 解析 JSON 字符串 1. JSON 有什么作用? &#x…

C++中的继承与多态1

目录 C中的继承与多态1 1.继承的概念及定义 1.1继承的概念 1.2 继承定义 1.2.1定义格式 1.2.2继承关系和访问限定符 1.2.3继承基类成员访问方式的变化 2.基类和派生类对象赋值转换 3.继承中的作用域 4.派生类的默认成员函数 5.继承与友元 6.继承与静态成员 7.复杂…