【深度学习】- 作业1: Softmax实现手写数字识别

news2025/1/18 6:57:12

课程链接: 清华大学驭风计划

代码仓库:Victor94-king/MachineLearning: MachineLearning basic introduction (github.com)


驭风计划是由清华大学老师教授的,其分为四门课,包括: 机器学习(张敏教授) , 深度学习(胡晓林教授), 计算机语言(刘知远教授) 以及数据结构与算法(邓俊辉教授)。本人是综合成绩第一名,除了数据结构与算法其他单科均为第一名。代码和报告均为本人自己实现,由于篇幅限制,只展示任务布置以及关键代码,如果需要报告或者代码可以私聊博主



机器学习部分授课老师为胡晓林教授,主要主要通过介绍回归模型,多层感知机,CNN,优化器,图像分割,RNN & LSTM 以及生成式模型入门深度学习


有任何疑问或者问题,也欢迎私信博主,大家可以相互讨论交流哟~~



任务介绍

本次案例中,你需要用python实现Softmax回归方法,用于MNIST手写数字数据集分类任务。你需要完成前向计算loss和参数更新。

你需要首先实现Softmax函数和交叉熵损失函数的计算。

image.png

在更新参数的过程中,你需要实现参数梯度的计算,并按照随机梯度下降法来更新参数。

image.png

具体计算方法可自行推导,或参照第三章课件。

MNIST数据集

MNIST手写数字数据集是机器学习领域中广泛使用的图像分类数据集。它包含60,000个训练样本和10,000个测试样本。这些数字已进行尺寸规格化,并在固定尺寸的图像中居中。每个样本都是一个784×1的矩阵,是从原始的28×28灰度图像转换而来的。MNIST中的数字范围是0到9。下面显示了一些示例。 注意:在训练期间,切勿以任何形式使用有关测试样本的信息。

image.png

任务要求

  1. 代码清单

a) data/ 文件夹:存放MNIST数据集。你需要下载数据,解压后存放于该文件夹下。下载链接见文末,解压后的数据为 *ubyte 形式;

b) solver.py 这个文件中实现了训练和测试的流程。建议从这个文件开始阅读代码;

c) dataloader.py 实现了数据加载器,可用于准备数据以进行训练和测试;

d) visualize.py 实现了plot_loss_and_acc函数,该函数可用于绘制损失和准确率曲线;

e) optimizer.py 你需要实现带momentum的SGD优化器,可用于执行参数更新;

f) loss.py 你需要实现softmax_cross_entropy_loss,包含loss的计算和梯度计算;

g) runner.ipynb 完成所有代码后的执行文件,执行训练和测试过程。

  1. 要求

我们提供了完整的代码框架,你只需要完成optimizer.py,loss.py 中的 #TODO部分。你需要提交整个代码文件和带有结果的 runner.ipynb ( 不要提交数据集 ) 并且附一个pdf格式报告,内容包括:

a) 记录训练和测试的准确率。画出训练损失和准确率曲线;

b) 比较使用和不使用momentum结果的不同,可以从训练时间,收敛性和准确率等方面讨论差异;

c) 调整其他超参数,如学习率,Batchsize等,观察这些超参数如何影响分类性能。写下观察结果并将这些新结果记录在报告中。

3.其他

  1. 注意代码的执行效率,尽量不要使用for循环;
  2. 不要在pdf报告中粘贴很多代码(只能包含关键代码),对添加的代码作出解释;
  3. 不要使用任何深度学习框架,如TensorFlow,Pytorch等;
  4. 禁止抄袭。

4.参考

  1. 数据集下载:http://yann.lecun.com/exdb/mnist/



核心代码

optimizer.py

    def step(self):
        """One updating step, update weights"""

        layer = self.model
        if layer.trainable:

            ############################################################################
            # TODO: Put your code here
            # Calculate diff_W and diff_b using layer.grad_W and layer.grad_b.
            # You need to add momentum to this.

            # Weight update with momentum
            self.diff_W = self.momentum * self.diff_W - self.learning_rate * layer.grad_W
            self.diff_b = self.momentum * self.diff_b - self.learning_rate * layer.grad_b

            # # Weight update without momentum
            layer.W += self.diff_W
            layer.b += self.diff_b

            ############################################################################

loss.py

class SoftmaxCrossEntropyLoss(object):

    def __init__(self, num_input, num_output, trainable=True):
        """
        Apply a linear transformation to the incoming data: y = Wx + b
        Args:
            num_input: size of each input sample
            num_output: size of each output sample
            trainable: whether if this layer is trainable
        """

        self.num_input = num_input
        self.num_output = num_output
        self.trainable = trainable
        self.XavierInit()

    def onehot_labels(self , labels):
        '''one-hot lable'''
        lable_one = np.zeros((labels.shape[0],10))
        for i in range(len(labels)):
            lable_one[i,labels[i]] = 1 
        return lable_one 

    def forward(self, Input, labels):
        """
          Inputs: (minibatch)
          - Input: (batch_size, 784)
          - labels: the ground truth label, shape (batch_size, )
        """

        ############################################################################
        # TODO: Put your code here
        # Apply linear transformation (WX+b) to Input, and then
        # calculate the average accuracy and loss over the minibatch
        # Return the loss and acc, which will be used in solver.py
        # Hint: Maybe you need to save some arrays for gradient computing.
        self.z = np.exp(np.dot(Input, self.W) + self.b) # z = wx + b 
        self.z = (self.z + EPS) / ((np.sum(self.z, axis = 1).reshape(-1,1)) + EPS) # z = exp(Zi) / sum(exp(Zi))
      
        self.lables = self.onehot_labels(labels)# onehot_labels
        loss = np.mean(-np.log(np.sum( self.z *  self.lables, axis=1)))  # loss = - y lg(z)
        logit = np.nanargmax(self.z, axis=1) # output
        acc = sum([ int(i == j) for (i ,j) in zip(logit, labels)]) / len(labels) # compute acc
        self.Input , self.label = Input , labels #save arrays
        return loss, acc
        ############################################################################

  
    def gradient_computing(self):
      
        ############################################################################
        # TODO: Put your code here
        # Calculate the gradient of W and b.
        grad = self.z - self.lables
        self.grad_b = np.nanmean( grad ,axis = 0) 
        self.grad_W = np.dot(self.Input.T,grad) / self.Input.shape[0]
        ############################################################################

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/556859.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【深度学习】 - 作业7: 图像超分辨率重建

课程链接: 清华大学驭风计划 代码仓库:Victor94-king/MachineLearning: MachineLearning basic introduction (github.com) 驭风计划是由清华大学老师教授的,其分为四门课,包括: 机器学习(张敏教授) , 深度学习(胡晓林教授), 计算…

Linux网络编程—Day10

Linux服务器程序规范 Linux服务器程序一般以后台进程形式运行。后台进程又称守护进程。它没有控制终端,因而也不会意外接收到用户输入。 守护进程的父进程通常是init进程(PID为1的进程);Linux服务器程序通常有一套日志系统&#…

Android 11.0 系统设置显示主菜单添加屏幕旋转菜单实现旋转屏幕功能

1.前言 在android11.0的系统rom定制化开发中,在对系统设置进行定制开发中,有产品需求要求增加 旋转屏幕功能的菜单,就是在点击旋转屏幕菜单后弹窗显示旋转0度,旋转 90度,旋转180度, 旋转270度针对不同分辨率的无重力感应的大屏设备的屏幕旋转功能的实现, 接下来就来分析…

chatgpt赋能Python-python_dill

Python dill - 更高效,更强大的对象序列化工具 在Python编程中,对象序列化是一个非常常见的操作,它将Python对象转换为可以存储或传输的格式。Python默认提供了pickle模块来实现序列化,但是pickle存在一些限制,比如无…

usb摄像头驱动-core层usb设备的注册

usb摄像头驱动-core层driver.c 文章目录 usb摄像头驱动-core层driver.cusb_bus_typeusb_device_matchusb_uevent usb_register_driver 在ubuntu中接入罗技c920摄像头打印的信息如下: 在内核中,/driver/usb/core/driver.c 文件扮演了 USB 核心驱动程序管…

复现ms17-010漏洞

复现ms17-010 复现条件环境操作效果 让我们来看看“hkl”在干嘛 复现 条件 1.靶机必须为win8以下。 2.靶机必须同攻击机处于相同网段(即倒数第二位相同)。 3.靶机的445端口必须要开启。 环境 虚拟机kali为攻击机,win7虚拟机为靶机。 kali…

SELECT LAST_INSERT_ID()自增主键冲突或者为0问题

问题 数据库为mysql; mapper.xml文件为mybatis-generator自动生成的; 连接池使用DruidDataSource; 最终生成的insertSelective如下: 出现问题: 主键冲突:[WMyBatisTraceInterceptor:54][com.mysql.jdbc.…

如何定位OOM

造成OOM的原因 定位OOM 针对第一和第二种情况需要定位OOM 系统已经挂了: 通过堆dump文件定位。 当JVM发生OOM时,自动生成DUMP文件: -XX:HeapDumpOnOutOfMemoryError -XX:HeapDumpPath${目录} java -Xms10M -Xmx10M -XX:HeapDumpOnOutOfMemo…

c#中的json数据

JSON数据 数据传输的语言,用于前后端数据交互的语言,注意xml的区别。 json和xml的区别 xml:可扩展标记语言,是一种用于标记电子文件使其具有结构性的标记语言。 json:(JavaScript Object Notation, JS 对象简谱) 是…

昆仲被投企业「鲲游光电」近期完成新一轮数亿元融资,继续致力于探索前沿光电子领域产品与创新应用|昆仲·链

鲲游光电(North Ocean Photonics)是专注于晶圆级光学、光集成领域的高科技企业,总部位于中国上海。依托自身深厚的技术积累、高精度的制造工艺,鲲游光电通过融合集成光学与集成电路的创新思路,致力于探索前沿光子、光电子领域的实现与创新应用…

MyBatis环境搭建配置、增删改查操作、分页、事务操作、动态SQL、缓存机制、注解开发

MyBatis 文章目录 MyBatisXML语言简介用途各部分注解声明元素属性注释CDATA转义字符 搭建环境读取实体类创建实体与映射关系的文件 配置MyBatis创建工具类接口实现 Mybatis工作流程增删改查指定映射规则指定构造方法字段名称带下划线处理条件查询插入数据复杂查询和事务一对多查…

Codeforces Round 860 (Div. 2)

A Showstopper 题意:给你两个长度为n的数组a和b,每次操作你可以互换a[i]与b[i],问最终能否满足 思路:若a[i]>b[i],我们就进行操作。这样数组b元素都是较大的, 一定比不操作更优。最后判断是否满足条件…

Python中的异常处理机制

什么是异常与异常处理 异常就是错误 异常会导致程序崩溃并停止运行 能监控并捕获到异常&#xff0c;将异常部位的程序进行修理使得程序继续正常运行 异常的语法结构 try:<代码块1> 被try关键字检查并保护的业务代码except <异常的类型>:<代码块2> # 代码…

Mybatis源码细节探究:sqlSessionFactory.openSession()这个方法到底发生了什么?

给自己的每日一句 不从恶人的计谋&#xff0c;不站罪人的道路&#xff0c;不坐亵慢人的座位&#xff0c;惟喜爱耶和华的律法&#xff0c;昼夜思想&#xff0c;这人便为有福&#xff01;他要像一棵树栽在溪水旁&#xff0c;按时候结果子&#xff0c;叶子也不枯干。凡他所做的尽…

【笔记】【Javascript】浅面了解原型和原型链

前言 原型和原型链是学习前端必备知识笔记中有些个人理解后整理的笔记&#xff0c;可能有所偏差&#xff0c;也恳请读者帮忙指出&#xff0c;谢谢。 免责声明 为了方便&#xff0c;本文中使用的部分图片来自于网络&#xff0c;如有侵权,请联系博主进行删除&#xff0c;感谢其…

C++ 二分查找法 LeetCode:704. 二分查找

class Solution { public:int search(vector<int>& nums, int target) {int length nums.size();//计算容器长度int left 0;//0int right length-1;//5int middle 0;/*while(left<right){middle (leftright)/2;//middle (leftright)>>1;if(nums[middl…

mysql子查询嵌套

目录 前言 一、实际需求解决 1.方式1&#xff1a;自连接 2.方式2&#xff1a;子查询 二、单行子查询 1.操作符子查询 三、相关子查询 四、自定义语句 五、子查询的问题 1.空值问题 2.非法使用子查询 六、多行子查询 七、聚合函数的嵌套使用 八、多行子查询空值问题…

Python爬虫实战——获取指定博主所有专栏链接及博文链接

Python爬虫实战——获取指定博主所有专栏链接及博文链接 0. 前言1. 第三方库的安装2. 代码3. 演示效果 0. 前言 本节学习使用爬虫来爬取指定csdn用户的所有专栏下的文章 操作系统&#xff1a;Windows10 专业版 开发环境&#xff1a;Pycahrm Comunity 2022.3 Python解释器版…

带你学C带你飞-P16拾遗

自增运算符 #include <stdio.h> int main() {int i5,j;j i;printf("i%d,j%d",i,j);i5;ji;printf("i%d,j%d",i,j); }i:先使用i的值&#xff0c;再对i自身进行加一 i&#xff1a;先对i自身加一&#xff0c;再赋值给j 逗号运算符 条件运算符 三目运…

【Linux】冯诺依曼体系结构、操作系统概念、进程概念

文章目录 前言一、冯诺依曼体系结构1.简介冯诺依曼体系2.CPU3.存储器3.IO&#xff08;输入输出&#xff09;4.总结 二、操作系统&#xff08;OS&#xff09;1.操作系统是什么&#xff1f;2.为什么有操作系统&#xff1f;&#xff08;功能&#xff09;3.操作系统如何实现功能&am…