pytoch M2芯片测试

news2025/1/11 15:54:59

今天才发现我的新片是M2芯片,而不是M1芯片,有点尴尬
在这里插入图片描述
参考网址
https://www.oldcai.com/ai/pytorch-train-MNIST-with-gpu-on-mac/

测试结果如下

M2_cpu.py

# https://www.oldcai.com/ai/pytorch-train-MNIST-with-gpu-on-mac/
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device
device = "cpu"
device = torch.device(device)
print(f"Using device: {device}")


# Define the CNN model
class HandwritingRecognitionModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Define the convolutional layers
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Define the pooling and dropout layers
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Define the fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Pass the input through the convolutional layers
        x = self.conv1(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.dropout2(x)

        # Reshape the output for the fully connected layers
        x = x.view(-1, 32 * 7 * 7)

        # Pass the output through the fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)

        # Return the final output
        return x


# Load the MNIST dataset
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
test_dataset = MNIST("./data", train=False, download=True, transform=ToTensor())

# Define the data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the model
model = HandwritingRecognitionModel().to(device)

# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

from time import time 
t0 = time()
# Train the model for 10 epochs
for epoch in range(10):
    # Set the model to training mode
    model.train()

    # Iterate over the training data
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Pass the input through the model
        outputs = model(images)

        # Compute the loss
        loss = loss_fn(outputs, labels)

        # Backpropagate the error
        loss.backward()

        # Update the model parameters
        optimizer.step()

    # Set the model to evaluation mode
    model.eval()

    # Evaluate the model on the validation set
    with torch.no_grad():
        correct = 0
        total = 0

        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Pass the input through the model
            outputs = model(images)

            # Get the predicted labels
            _, predicted = torch.max(outputs.data, 1)

            # Update the total and correct counts
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # Print the accuracy
        print(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")


t1 =time()
print("10 epoch cost {}s".format(t1-t0))

结果如下
在这里插入图片描述

M2_MPS.py

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device
device = "mps" if torch.backends.mps.is_available() else "cpu"
device = torch.device(device)
print(f"Using device: {device}")


# Define the CNN model
class HandwritingRecognitionModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Define the convolutional layers
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Define the pooling and dropout layers
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Define the fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Pass the input through the convolutional layers
        x = self.conv1(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.dropout2(x)

        # Reshape the output for the fully connected layers
        x = x.view(-1, 32 * 7 * 7)

        # Pass the output through the fully connected layers
        x = self.fc1(x)
        x = self.fc2(x)

        # Return the final output
        return x


# Load the MNIST dataset
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
test_dataset = MNIST("./data", train=False, download=True, transform=ToTensor())

# Define the data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the model
model = HandwritingRecognitionModel().to(device)

# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

from time import time 
t0 = time()
# Train the model for 10 epochs
for epoch in range(10):
    # Set the model to training mode
    model.train()

    # Iterate over the training data
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        # Pass the input through the model
        outputs = model(images)

        # Compute the loss
        loss = loss_fn(outputs, labels)

        # Backpropagate the error
        loss.backward()

        # Update the model parameters
        optimizer.step()

    # Set the model to evaluation mode
    model.eval()

    # Evaluate the model on the validation set
    with torch.no_grad():
        correct = 0
        total = 0

        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Pass the input through the model
            outputs = model(images)

            # Get the predicted labels
            _, predicted = torch.max(outputs.data, 1)

            # Update the total and correct counts
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # Print the accuracy
        print(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")


t1 =time()
print("10 epoch cost {}s".format(t1-t0))

结果如下
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1082376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WebRTC 系列(四、多人通话,H5、Android、iOS)

WebRTC 系列(三、点对点通话,H5、Android、iOS) 上一篇博客中,我们已经实现了点对点通话,即一对一通话,这一次就接着实现多人通话。多人通话的实现方式呢也有好几种方案,这里我简单介绍两种方案…

Linux开启SSH

Linux开启SSH 1.虚拟机确定连通性 如果是虚拟机的话则需要进行确定和宿主主机之间能正常联通(不能联通还远程个啥) 获取到虚拟机的IP 参考文章:Linux获取本机IP地址使用宿主机ping一下虚拟机的IP查看是否联通 2.安装SSH服务端 安装工具来使得能够通过SSH进行连接 命令 sudo a…

【推荐系统】推荐系统(RS)与大模型(LLM)的结合

【推荐系统】推荐系统(RS)与大模型(LLM)的结合 文章目录 【推荐系统】推荐系统(RS)与大模型(LLM)的结合1. 主流的推荐方法2. 大模型(LLM)可能作用的地方 1. 主…

Spring源码解析——ApplicationContext容器refresh过程

正文 在之前的博文中我们一直以BeanFactory接口以及它的默认实现类XmlBeanFactory为例进行分析,但是Spring中还提供了另一个接口ApplicationContext,用于扩展BeanFactory中现有的功能。 ApplicationContext和BeanFactory两者都是用于加载Bean的&#x…

graphviz 绘制单链表

dot 代码 digraph LinkedList {rankdirLR; // 设置布局方向为从左到右(左侧到右侧)node [fontname"Arial", shaperecord, stylefilled, color"#ffffff", fillcolor"#0077be", fontsize12, width1.5, height0.5];edge [fo…

汉诺塔问题:递归

经典汉诺塔问题 汉诺塔问题是经典的可以用递归解决的问题。 汉诺塔(Hanoi)游戏规则如下:在一块铜板装置上,有三根杆(编号A、B、C),在A杆自下而上、由大到小按顺序放置64个金盘(如下图)。游戏的目标:把A杆上的金盘全部移到C杆上&a…

系统架构师备考倒计时25天(每日知识点)

面向对象设计原则 单一职责原则:设计目的单一的类开放-封闭原则:对扩展开放,对修改封闭李氏(Liskov)替换原则:子类可以替换父类依赖倒置原则:要依赖于抽象,而不是具体实现;针对接口编程&#x…

【Redis】Redis持久化深度解析

原创不易,注重版权。转载请注明原作者和原文链接 文章目录 Redis持久化介绍RDB原理Fork函数与写时复制关于写时复制的思考 RDB相关配置 AOF原理AOF持久化配置AOF文件解读AOF文件修复AOF重写AOF缓冲区与AOF重写缓存区AOF缓冲区可以替代AOF重写缓冲区吗AOF相关配置写后…

机器学习之自训练协同训练

前言 监督学习往往需要大量的标注数据, 而标注数据的成本比较高 . 因此 , 利用大量的无标注数据来提高监督学习的效果有着十分重要的意义. 这种利用少量标注数据和大量无标注数据进行学习的方式称为 半监督学习 ( Semi…

Java 序列化和反序列化为什么要实现 Serializable 接口?

序列化和反序列化 序列化:把对象转换为字节序列的过程称为对象的序列化. 反序列化:把字节序列恢复为对象的过程称为对象的反序列化. 什么时候需要用到序列化和反序列化呢或者对象序列化的两种用途… : (1) 对象持久化:把对象的…

NSSCTF做题(8)

[SWPUCTF 2022 新生赛]js_sign 看到了js代码 有一个base64编码,解密 最后发现这是一个加密方式 去掉空格之后得到了flag NSSCTF{youfindflagbytapcode} [MoeCTF 2022]baby_file 提示说有一个秘密看看你能不能找到 输入?filesecret 出现报错 输入php伪协议读取i…

Simulink仿真之离散系统

最近,为了完成课程作业,需要用到Simulink验证数字控制器的合理性。题目如下所示。 其实这道题在胡寿松老师的《自动控制原理(第七版)》的364页有答案。 这里给出数字控制器的脉冲传递函数为 ​​​​​​​ ​​​​​​​…

【22】c++设计模式——>外观模式

外观模式定义 为复杂系统提供一个简化接口&#xff0c;它通过创建一个高层接口(外观)&#xff0c;将多个子系统的复杂操作封装起来&#xff0c;以便客户端更容易使用。 简单实现 #include<iostream>// 子系统类 class SubsystemA { public:void operationA() {std::co…

经典循环命题:百钱百鸡

翁五钱一只&#xff0c;母三钱&#xff0c;小鸡三只一钱&#xff1b;百钱百鸡百鸡花百钱。 (本笔记适合能熟练应用for循环、会使if条件分支语句、能格式化字符输出的 coder 翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xff1a;https://www.python.org/ Free&#xff1a…

自定义表单工具好用的优点是什么?

如果想提升办公效率&#xff0c;那么就离不开低代码技术平台了。它的轻量级、易掌握、易操作、简洁简便等优势特点深得很多领域客户朋友的喜爱。目前&#xff0c;IBPS开发平台在通信业、制造业、医疗、高校等很多行业中得到了客户的肯定和喜爱&#xff0c;推广价值高&#xff0…

懒人福利:不用动脑就能制作电子画册

对于很多企业来说&#xff0c;想要快速把自己的活动大面积的宣传出去&#xff0c;就要快人一步提前制作电子版画册&#xff0c;通过网络推送出去&#xff0c;让大众及时了解。如何制作电子版画册呢&#xff1f; 我发现了一个懒人福利&#xff01;就是FLBOOK &#xff0c;它简单…

学习函数式编程、可变参数及 defer - GO语言从入门到实战

函数是⼀等公⺠、学习函数式编程、可变参数及 defer - GO语言从入门到实战 函数是⼀等公⺠ 在Go语言中&#xff0c;函数可以分配给一个变量&#xff0c;可以作为函数的参数&#xff0c;也可以作为函数的返回值。这样的行为就可以理解为函数属于一等公民。 与其他主要编程语⾔…

Nodejs内置模块process

文章目录 内置模块process写在前面1. arch()2. cwd()3. argv4. memoryUsage()5. exit()6. kill()7. env【最常用】 内置模块process 写在前面 process是Nodejs操作当前进程和控制当前进程的API&#xff0c;并且是挂载到globalThis下面的全局API。 下面是process的一些常用AP…

linux开发板中的数据存储和读取操作

问题&#xff1a;MQTT远程下发的参数存储在本地linux开发板&#xff0c;开发板依据该参数执行相应功能&#xff0c;当开发板重新上电时依然能继续执行该功能 解决方式&#xff1a; 在linux板中写一个标识并存储为文件&#xff0c;依据读取的文件标识执行相应的功能 1)打开文…

【2023】M1/M2 Mac 导入Flac音频到Pr的终极解决方案

介绍 原作者链接&#xff1a;https://github.com/fnordware/AdobeOgg 很早之前就发现了这个插件&#xff0c;超级好用&#xff0c;在windows上完全没有问题&#xff0c;可惜移植到mac就不行了&#xff08;然后我给作者发了一个Issue&#xff0c;后来就有大佬把m1的编译出来了&…