【Pytorch】Fizz Buzz

news2025/1/23 16:52:47

在这里插入图片描述

文章目录

  • 1 数据编码
  • 2 网络搭建
  • 3 网络配置,训练
  • 4 结果预测
  • 5 翻车现场

学习参考来自:

  • Fizz Buzz in Tensorflow
  • https://github.com/wmn7/ML_Practice/tree/master/2019_06_10
  • Fizz Buzz in Pytorch

I need you to print the numbers from 1 to 100, except that if the number is divisible by 3 print “fizz”, if it’s divisible by 5 print “buzz”, and if it’s divisible by 15 print “fizzbuzz”.

编程题很简单,我们用 MLP 实现试试

思路,训练集数据101~1024,对其进行某种规则的编码,标签为经分类 one-hot 编码后的标签
测试集,1~100

don’t say so much, show me the code.

1 数据编码

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data

def binary_encode(i, num_digits):
    """将每个input转换为binary digits(转换为二进制的表示, 最多可是表示2^num_digits)
    :param i:
    :param num_digits:
    :return:
    """
    return np.array([i >> d & 1 for d in range(num_digits)])

编码形式,依次除以 2 0 , 1 , 2 , 3 , . . . 2^{0,1,2,3,...} 20,1,2,3,...,结果按位与 1

m & 1,结果为 0 表示 m 为偶数, 结果为 1 表示 m 为奇数

> > m >> m >>m 右移表示除以 2 m 2^m 2m

第一位就能表示奇偶了,所有数字编码都不一样

eg,101 进行 num_digits=10 编码后结果为 1 0 1 0 0 1 1 0 0 0

步骤

101 / 1 = 101 奇数 1
101 / 2 = 50 偶数 0
101 / 4 = 25 奇数 1
101 / 8 = 12 偶数 0
101 / 16 = 6 偶数 0
101 / 32 = 3 奇数 1
101 / 64 = 1 奇数 1
101 / 128 = 0 偶数 0
101 / 256= 0 偶数 0
101 / 512= 0 偶数 0

标签,0,1,2,3 四个类别

def fizz_buzz_encode(i):
    """将output转换为lebel
    :param i:
    :return:
    """
    if i % 15 == 0:  # fizzbuzz
        return 3
    elif i % 5 == 0:  # buzz
        return 2
    elif i % 3 == 0:  # fizz
        return 1
    else:
        return 0

编码长度设定,数据集 101 ~ 1024

NUM_DIGITS = 10
trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2**NUM_DIGITS)])  # 101~1024
trY = np.array([fizz_buzz_encode(i) for i in range(101, 2**NUM_DIGITS)])

# print(len(trX), len(trY))  # 923 923
# print(trX[:5])
"""
[[1 0 1 0 0 1 1 0 0 0]
 [0 1 1 0 0 1 1 0 0 0]
 [1 1 1 0 0 1 1 0 0 0]
 [0 0 0 1 0 1 1 0 0 0]
 [1 0 0 1 0 1 1 0 0 0]]
"""
# print(trY[:5])  # [0 1 0 0 3]

2 网络搭建

搭建简单的 MLP 网络

class FizzBuzzModel(nn.Module):
    def __init__(self, in_features, out_classes, hidden_size, n_hidden_layers):
        super(FizzBuzzModel,self).__init__()
        layers = []
        for i in range(n_hidden_layers):
            layers.append(nn.Linear(hidden_size,hidden_size))
            # layers.append(nn.Dropout(0.5))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.ReLU())
        self.inputLayer = nn.Linear(in_features, hidden_size)
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(*layers)  # 重复的搭建隐藏层
        self.outputLayer = nn.Linear(hidden_size, out_classes)

    def forward(self, x):
        x = self.inputLayer(x)
        x = self.relu(x)
        x = self.layers(x)
        out = self.outputLayer(x)
        return out

初始化网络,看看网络结构

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# define the model
simpleModel = FizzBuzzModel(NUM_DIGITS, 4, 150, 3).to(device)
print(simpleModel)
"""
FizzBuzzModel(
  (inputLayer): Linear(in_features=10, out_features=150, bias=True)
  (relu): ReLU()
  (layers): Sequential(
    (0): Linear(in_features=150, out_features=150, bias=True)
    (1): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=150, out_features=150, bias=True)
    (4): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=150, out_features=150, bias=True)
    (7): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (outputLayer): Linear(in_features=150, out_features=4, bias=True)
)
"""

输入 10, 输出4,隐藏层维度 150,隐藏层重复了 3 次

3 网络配置,训练

定义下超参数,损失函数,优化器,载入数据训练,输出训练精度与损失

# Loss and optimizer
learning_rate = 0.05
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(simpleModel.parameters(), lr=learning_rate)
optimizer = torch.optim.SGD(simpleModel.parameters(), lr=learning_rate)

# 使用batch进行训练
FizzBuzzDataset = Data.TensorDataset(torch.from_numpy(trX).float().to(device),
                                     torch.from_numpy(trY).long().to(device))

loader = Data.DataLoader(dataset=FizzBuzzDataset,
                         batch_size=128*5,
                         shuffle=True)

# 进行训练
simpleModel.train()
epochs = 3000

for epoch in range(1, epochs):
    for step, (batch_x, batch_y) in enumerate(loader):
        out = simpleModel(batch_x)  # 前向传播
        loss = criterion(out, batch_y)  # 计算损失
        optimizer.zero_grad()  # 梯度清零
        loss.backward()  # 反向传播
        optimizer.step()  # 随机梯度下降
    correct = 0
    total = 0
    _, predicted = torch.max(out.data, 1)
    total += batch_y.size(0)
    correct += (predicted == batch_y).sum().item()
    acc = 100*correct/total
    print('Epoch : {:0>4d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, loss, acc))

"""
Epoch : 0001 | Loss : 1.5343 | Train Accuracy : 14.63 %
Epoch : 0002 | Loss : 1.9779 | Train Accuracy : 42.58 %
Epoch : 0003 | Loss : 2.4198 | Train Accuracy : 53.41 %
Epoch : 0004 | Loss : 1.7360 | Train Accuracy : 53.41 %
Epoch : 0005 | Loss : 1.3161 | Train Accuracy : 49.73 %
Epoch : 0006 | Loss : 1.4866 | Train Accuracy : 22.75 %
Epoch : 0007 | Loss : 1.3993 | Train Accuracy : 25.57 %
Epoch : 0008 | Loss : 1.2428 | Train Accuracy : 28.49 %
Epoch : 0009 | Loss : 1.1906 | Train Accuracy : 44.31 %
Epoch : 0010 | Loss : 1.1929 | Train Accuracy : 52.44 %
...
Epoch : 2990 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2991 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2992 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2993 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2994 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2995 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2996 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2997 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2998 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2999 | Loss : 0.0000 | Train Accuracy : 100.00%
"""

训练集上精度是 OK 的,能到 100%,下面看看测试集上的精度

4 结果预测

把 one-hot 标签转化成 fizz buzz 的形式

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

载入测试集,开始预测

simpleModel.eval()
# 进行预测
testX = np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
predicts = simpleModel(torch.from_numpy(testX).float().to(device))
# 预测的结果
_, res = torch.max(predicts, 1)
print(res)
"""
tensor([0, 0, 0, 1, 0, 0, 0, 2, 1, 0, 1, 3, 3, 1, 1, 0, 0, 0, 0, 0, 0, 3, 1, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 1, 1, 1, 0,
        0, 0, 0, 0], device='cuda:0')
"""

# 格式的转换
predictions = [fizz_buzz_decode(i, prediction) for (i, prediction) in zip(range(1, 101), res)]
print(predictions)
"""
['1', '2', '3', 'fizz', '5', '6', '7', 'buzz', 'fizz', '10', 'fizz', 'fizzbuzz', 'fizzbuzz', 'fizz', 'fizz', '16', '17', '18', '19', '20', '21', 'fizzbuzz', 'fizz', '24', '25', '26', '27', '28', '29', '30', 'fizz', '32', '33', '34', '35', '36', '37', '38', '39', 'fizz', '41', 'fizz', '43', '44', '45', 'fizz', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', 'fizz', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', 'fizz', 'fizz', 'fizz', 'fizz', 'buzz', 'buzz', 'fizz', '80', '81', '82', '83', '84', '85', '86', '87', 'fizzbuzz', '89', '90', 'fizz', '92', 'fizz', 'fizz', 'fizz', '96', '97', '98', '99', '100']
"""

5 翻车现场

对比下标签

labels = []
for i in range(1, 101):
    if i % 15 == 0:  # fizzbuzz
        labels.append("fizzbuzz")
    elif i % 5 == 0:  # buzz
        labels.append("buzz")
    elif i % 3 == 0:  # fizz
        labels.append("fizz")
    else:
        labels.append(str(i))
print(labels)
print(labels == predictions)

"""
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
False
"""

哈哈哈, False 翻车了,尝试了很多次,很难 True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1297824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

快捷切换raw页面到repo页面-Raw2Repo插件

Raw2Repo By Rick &#x1f4d6;快捷切换代码托管平台raw页面到repo页面 &#x1f517;github链接 https://github.com/rickhqh/Raw2Repo ✨Features 功能&#xff1a; ✅单击 Raw2Repo 插件按钮&#xff0c;即可跳转到相应的代码仓库页面。✅支持 GitHub、Gitee、GitCode …

ChatGPT OpenAI API请求限制 尝试解决

1. OpenAI API请求限制 Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.._completion_with_retry in 4.0 seconds as it raised RateLimitError: Rate limit reached for gpt-3.5-turbo-16k in organization org-U7I2eKpAo6xA7RUa2Nq307ae on reques…

Ignis - Interactive Fire System

Ignis - 点火、蔓延、熄灭、定制! 全方位火焰系统。 这个插件在21年的项目中使用过很好用值使用概述 想玩火吗?如果想的话,那么Ignis就是你的最佳工具。有了Ignis,你可以把任何物体、植被或带皮带骨的网状物转换为可燃物体,它就会自动着火。然后,火焰可以蔓延,点燃其他物…

C++_函数重载

前言&#xff1a; 函数重载的意思就是可以有多个同名函数存在&#xff0c;但是这些同名函数的参数列表有着不同情形&#xff0c;以便区分。在C中&#xff0c;支持在同一作用域下可以声明、定义多个同名函数&#xff0c;但是这些函数的形参类型&#xff0c;类型顺序以及参数个数…

dcat admin多后台和自定义登录

多后台按照教程配置 https://learnku.com/docs/dcat-admin/2.x/multi-application-multi-background/8475 自定义登录 我的新后台的登录需要另外一个用户表&#xff0c;所以原来的逻辑要修改一下。 1、首先是模板修改 参考连接 https://learnku.com/docs/dcat-admin/2.x/ba…

UML案例分析

首先需要花大约20分钟来思考解决这个问题&#xff0c;如果对问题不是很熟悉&#xff0c;也可以在完成题目之后&#xff0c;找相关的资料翻阅&#xff08;例如看UML类图的基本情况&#xff0c;UML状态图的基本情况&#xff0c;然后结合这些信息 做一个自我评价&#xff0c;看这个…

Error: Cannot find module ‘E:\Workspace_zwf\mall\build\webpack.dev.conf.js‘

执行&#xff1a;npm run dev E:\Workspace_zwf\zengwenfeng-master>npm run dev> mall-app-web1.0.0 dev E:\Workspace_zwf\zengwenfeng-master > webpack-dev-server --inline --progress --config build/webpack.dev.conf.jsinternal/modules/cjs/loader.js:983thr…

多线程案例-单例模式

单例模式 设计模式的概念 设计模式好比象棋中的"棋谱".红方当头炮,黑方马来跳.针对红方的一些走法,黑方应招的时候有一些固定的套路.按照套路来走局势就不会吃亏. 软件开发中也有很多常见的"问题场景".针对这些问题的场景,大佬们总结出了一些固定的套路.按…

docker容器配置MySQL与远程连接设置(纯步骤)

以下为ubuntu20.04环境&#xff0c;默认已安装docker&#xff0c;没安装的网上随便找个教程就好了 拉去mysql镜像 docker pull mysql这样是默认拉取最新的版本latest 这样是指定版本拉取 docker pull mysql:5.7查看已安装的mysql镜像 docker images通过镜像生成容器 docke…

初识人工智能,一文读懂迁移学习的知识文集(4)

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

低多边形建筑3D模型纹理贴图

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 当谈到游戏角色的3D模型风格时&#xff0c;有几种不同的风格&#xf…

ThinkPHP如何讲链接多个数据库

为什么要使用多个数据库 数据分片&#xff1a; 当数据量非常大时&#xff0c;可能需要将数据分布在不同的数据库中&#xff0c;以提高查询性能。这被称为数据分片&#xff0c;其中不同的数据库负责存储不同范围的数据。 业务分离&#xff1a; 有时&#xff0c;一个大型项目可…

使用Android Studio导入Android源码:基于全志H713 AOSP,方便解决编译、编码问题

文章目录 一、 篇头二、 操作步骤2.1 编译AOSP AS工程文件2.2 将AOSP导入Android Studio2.3 切到Project试图2.4 等待index结束2.5 下载缺失的JDK 1.82.6 导入完成 三、 导入AS的好处3.1 本文案例演示源码编译错误AS对比同文件其余地方的调用AS错误提示依赖AS做错误修正 一、 篇…

Docker网络原理及Cgroup硬件资源占用控制

docker的网络模式 获取容器的进程号 docker inspect -f {{.State.Pid}} 容器id/容器名 docker初始状态下有三种默认的网络模式 &#xff0c;bridg&#xff08;桥接&#xff09;&#xff0c;host&#xff08;主机&#xff09;&#xff0c;none&#xff08;无网络设置&#xff…

QT中时间时区处理总结

最近项目中要做跨国设备时间校正功能&#xff0c;用到了时区时间&#xff0c;在此做一下记录。 目录 1.常见时区名 2.测试代码 3.运行效果 1.常见时区名 "Pacific/Midway": "中途岛 (UTC-11:00)", …

2023中国(海南)国际高尔夫旅游文化博览会 暨国际商界峰层·全球华人高尔夫精英巡回赛 全国颍商自贸港行盛大启幕

2023中国&#xff08;海南&#xff09;国际高尔夫旅游文化博览会&#xff08;以下简称“海高博”&#xff09;暨全国颍商走进海南自贸港于12月7-9日在海口观澜湖盛大开幕。该活动由中国国际贸易促进委员会海南省委员会、海南省旅游和文化广电体育厅主办&#xff0c;中国国际商会…

用Java实现一对一聊天

目录 服务端 客户端 服务端 package 一对一用户; import java.awt.BorderLayout; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.net.ServerSocket; import java.net.Socket; imp…

数据结构中处理散列冲突的四种方法

1 开放定址法 1.1 定义 开放定址法就是一旦发生了冲突&#xff0c;就去寻找下一个空的散列地址 1.2 要求 只要散列表足够大 空的散列地址总能找到&#xff0c;并将记录存入 1.3 线性探测法 使用该公式用于解决冲突的开放定址法称为线性探测法 对于线性探测法&#xff0c…

rpc原理与应用

IPC和RPC&#xff1f; TCP是有三个特点&#xff0c;面向连接、可靠、基于字节流。粘包问题 RPC 而RPC&#xff08;Remote Procedure Call&#xff09;&#xff0c;又叫做远程过程调用。它本身并不是一个具体的协议&#xff0c;而是一种调用方式。 gRPC 是 Google 最近公布的…

GPT4停止订阅付费了怎么办? 怎么升级ChatGPT plus?提供解决方案

11月中旬日OpenAI 暂时关闭所有的升级入口之后&#xff0c;很多小伙伴就真的在排队等待哦。其实有方法可以绕开排队&#xff0c;直接付费订阅升级GPT的。赶紧用起来立马“插队”成功&#xff01;亲测~~~ 一、登录ChatGPT账号 1、没有账号可以直接注册一个&#xff0c;流程超级…