【tensorflow】连续输入+离散输入的神经网络模型训练代码

news2024/12/25 0:26:50

【tensorflow】连续输入+离散输入的神经网络模型训练代码

  • 离散输入的转化问题
  • 构造词典
  • 创建离散数据、转化字典索引、创建连续数据
  • 创建离散输入+连续输入模型
  • 训练输出
  • 全部代码 - 复制即用

  查看本系列三种模型写法:
  【tensorflow】连续输入的线性回归模型训练代码
  【tensorflow】连续输入的神经网络模型训练代码
  【tensorflow】连续输入+离散输入的神经网络模型训练代码

离散输入的转化问题

  离散输入一般有几种处理方式:

  1、如果是数字的话,可以直接输入到模型,或者正则化到[0-1]之间再输入。

但是离散的数字往往代表一个实体,比如它可能是id,去当数字输入到模型是不合适的。而且,离散的数据也不一定是数据,更多的是字符串。

  2、如果是字符串,可以转化为one-hot编码,但是这样的话,为0的数据占90%以上。

  3、所以此时就需要用到Embedding。在使用Embedding前,需要构造字典。

Embedding层形状为(input_length, dim, vocab_size)。input_length是输入的维度,dim表示一个词我要表征为几维度的向量,vocab_size表示词汇表的大小。输入需要是[0-vocab_size-1]之间的数字,所以我们需要把离散输入转化为数字,此时就要构造字典。

构造词典

  第一步,创建离散的数据集:

import numpy as np

random_numbers = np.random.randint(low=1, high=1000000, size=10000)

// array([781702, 805689, 194619, ..., 268855, 114390, 963977])

  第二步,提取离散的数据中的字典:

np.savetxt('voc.txt', [_ for _ in random_numbers], delimiter='\n', fmt='%d')

创建离散数据、转化字典索引、创建连续数据

# 加载词典
def get_vocab(path):
    vocab_dict = {}

    with open(path, 'r', encoding='utf-8') as file:
        for index, line in enumerate(file):
            word = line.strip()
            vocab_dict[word] = index

    print(f"\n===词典长度==={len(vocab_dict)}===\n")
    
    return vocab_dict

def get_data():
    # 设置随机种子,以确保结果可复现(可选)
    np.random.seed(0)

    # 生成随机数据
    data = np.random.rand(10000, 10)
    
    # 正则化数据
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    random_numbers = np.random.randint(low=1, high=1000000, size=10000)
    np.savetxt('voc.txt', [_ for _ in random_numbers], delimiter='\n', fmt='%d')
    vocab_dict = get_vocab('voc.txt')
    
    discrete = [vocab_dict[str(i)] for i in random_numbers]

    # 生成随机数据
    target = np.random.rand(10000, 1)
    
    

    return train_test_split(data, target, discrete, test_size=0.1, random_state=42)


data_train, data_val, target_train, target_val, discrete_train, discrete_val = get_data()

  get_vocab函数:
  这个函数用于从指定路径的文件中加载词典。它会逐行读取文件内容,并将每一行的单词作为词典的键,行号作为对应的值。最终返回一个包含词典内容的字典对象。

  path参数表示词典文件的路径。
  函数内部使用open函数打开文件,按行读取文件内容。
  对于每一行,使用strip方法去除行末尾的换行符等不需要的字符,并将其作为词典的键。
  行号(即索引值)作为对应的值,并将键值对添加到词典中。
  最后,返回包含词典内容的字典对象。

  get_data函数:
  这个函数用于生成随机数据,并结合词典将随机生成的整数映射为离散值。函数的执行过程如下:

  首先,使用np.random.rand函数生成一个形状为(10000, 10)的随机数据矩阵data。
  接下来,使用StandardScaler对数据进行正则化处理,将其转换为均值为0、标准差为1的数据。
  然后,使用np.random.randint生成一个长度为10000、范围在1到1000000之间的随机整数数组random_numbers。
  使用np.savetxt函数将random_numbers保存为文本文件voc.txt,每个整数占一行。
  调用get_vocab函数,加载词典文件voc.txt,并将其存储在vocab_dict字典中。
  根据词典,将random_numbers中的整数映射为对应的离散值,存储在discrete列表中。
  最后,使用np.random.rand函数生成一个形状为(10000, 1)的随机目标值数组target。

  函数返回了划分好的训练集和验证集数据,包括data_train、data_val、target_train、target_val、discrete_train和discrete_val。这些数据将在后续的模型训练和验证中使用。

创建离散输入+连续输入模型

def create_mlp(dim, regress=False):
    model = Sequential()
    model.add(Dense(64, input_dim=dim, activation="relu"))
    model.add(Dense(64, activation="relu"))
    # check to see if the regression node should be added
    if regress:
        model.add(Dense(1, activation="linear"))
    # return our model
    return model


def create_emb(dim, regress=False):
    model = Sequential()
    model.add(Embedding(input_length= dim, output_dim=8, input_dim=vocabulary_size))
    model.add(LSTM(128))
    model.add(Dense(64, activation="relu"))
    # check to see if the regression node should be added
    if regress:
        model.add(Dense(1, activation="linear"))
    # return our model
    return model

mlp = create_mlp(10, regress=False)
emb = create_emb(1, regress=False)

combined = concatenate([mlp.output, emb.output])

z = Dense(2, activation="relu")(combined)
z = Dense(1, activation="linear")(z)

model = Model(inputs=[mlp.input, emb.input], outputs=z)

model.summary()

  这段代码定义了两个函数create_mlp和create_emb,用于创建MLP(多层感知机)和Embedding-LSTM模型,并将它们结合起来构建一个联合模型。

  create_mlp函数:
  这个函数用于创建一个MLP模型。MLP是一种前馈神经网络,由多个全连接层组成。函数的输入参数dim表示输入维度,regress表示是否是回归任务。

  创建一个Sequential模型对象。
  添加一个具有64个神经元的全连接层,输入维度为dim,激活函数为ReLU。
  添加第二个具有64个神经元的全连接层,激活函数为ReLU。
  如果regress为True,则添加一个具有1个神经元的输出层,激活函数为线性激活函数(用于回归任务)。
  返回构建好的MLP模型对象。

  create_emb函数:
  这个函数用于创建一个包含Embedding和LSTM的模型。Embedding是一种用于将离散的整数序列映射到低维连续向量的技术,而LSTM是一种长短期记忆网络。

  创建一个Sequential模型对象。
  添加一个Embedding层,指定输入长度为dim,输出维度为8,输入维度为vocabulary_size(词汇表大小)。
  添加一个LSTM层,具有128个神经元。
  添加一个具有64个神经元的全连接层,激活函数为ReLU。
  如果regress为True,则添加一个具有1个神经元的输出层,激活函数为线性激活函数(用于回归任务)。
  返回构建好的Embedding-LSTM模型对象。
  接下来的代码将两个模型的输出通过concatenate函数进行合并。然后,构建一个新的模型model,输入为MLP模型的输入和Embedding-LSTM模型的输入,输出为合并后的结果。

  使用Model函数定义一个新的模型对象,指定输入为MLP模型的输入和Embedding-LSTM模型的输入,输出为合并后的结果。
  添加一个具有2个神经元的全连接层,激活函数为ReLU。
  添加一个具有1个神经元的输出层,激活函数为线性激活函数。

  打印模型的摘要信息,包括每层的名称、输出形状和参数数量。
  通过以上步骤,你可以创建一个包含MLP和Embedding-LSTM的联合模型,并输出该模型的摘要信息,包括每层的配置和参数数量。

训练输出

  模型结构如下:

在这里插入图片描述

  模型训练过程中的输出如下:

在这里插入图片描述

全部代码 - 复制即用

from sklearn.model_selection import train_test_split
import tensorflow as tf
import numpy as np
from keras import Input, Model, Sequential
from keras.layers import Dense, concatenate, Embedding, LSTM
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

# 加载词典
def get_vocab(path):
    vocab_dict = {}

    with open(path, 'r', encoding='utf-8') as file:
        for index, line in enumerate(file):
            word = line.strip()
            vocab_dict[word] = index

    print(f"\n===词典长度==={len(vocab_dict)}===\n")
    
    return vocab_dict

def get_data():
    # 设置随机种子,以确保结果可复现(可选)
    np.random.seed(0)

    # 生成随机数据
    data = np.random.rand(10000, 10)
    
    # 正则化数据
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    
    random_numbers = np.random.randint(low=1, high=1000000, size=10000)
    np.savetxt('voc.txt', [_ for _ in random_numbers], delimiter='\n', fmt='%d')
    vocab_dict = get_vocab('voc.txt')
    
    discrete = [vocab_dict[str(i)] for i in random_numbers]
    discrete = np.array(discrete).reshape(-1, 1)

    # 生成随机数据
    target = np.random.rand(10000, 1)

    return train_test_split(data, target, discrete, test_size=0.1, random_state=42)


data_train, data_val, target_train, target_val, discrete_train, discrete_val = get_data()

# 迭代轮次
train_epochs = 10
# 学习率
learning_rate = 0.0001
# 批大小
batch_size = 200


def create_mlp(dim, regress=False):
    model = Sequential()
    model.add(Dense(64, input_dim=dim, activation="relu"))
    model.add(Dense(64, activation="relu"))
    # check to see if the regression node should be added
    if regress:
        model.add(Dense(1, activation="linear"))
    # return our model
    return model


def create_emb(dim, regress=False):
    model = Sequential()
    model.add(Embedding(input_length= dim, output_dim=8, input_dim=100000))
    model.add(LSTM(128))
    model.add(Dense(64, activation="relu"))
    # check to see if the regression node should be added
    if regress:
        model.add(Dense(1, activation="linear"))
    # return our model
    return model

mlp = create_mlp(10, regress=False)
emb = create_emb(1, regress=False)

combined = concatenate([mlp.output, emb.output])

z = Dense(2, activation="relu")(combined)
z = Dense(1, activation="linear")(z)

model = Model(inputs=[mlp.input, emb.input], outputs=z)

model.summary()

model.compile(loss="mse", optimizer=tf.train.GradientDescentOptimizer(learning_rate=learning_rate))

history = model.fit([data_train, discrete_train], target_train, epochs=train_epochs, batch_size=batch_size,
                    validation_data=([data_val, discrete_val], target_val))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/656615.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于知识图谱的电影推荐系统——Neo4jPython

文章目录 1. 数据解下载与配置2. 将处理好的数据导入数据库中3. 执行项目 1. 数据解下载与配置 选择TMDB电影数据集,Netflix Prize 数据集下载。 也可直接从这里下载:链接: https://pan.baidu.com/s/1l6wjwcUzy5G_dIlVDbCkpw 提取码: pkq6 。 执行prep…

【AI】金融FinGPT模型

金融FinGPT模型开源,对标BloombergGPT,训练参数可从61.7亿减少为367万,可预测股价 继Bloomberg提出了500亿参数的BloombergGPT,GPT在金融领域的应用受到了广泛关注,但BloombergGPT是一个非开源的模型,而且…

【ESP32之旅】U8g2 在线仿真和UI调试

前言 几乎每个玩屏幕的电子DIYer都知道万能的屏幕驱动中间件u8g2库,这个库提供了强大的驱动适配和ui设计能力。但是官方没有一个好用的ui设计和仿真软件,在设计UI布局的时候对单片机频繁的烧录调试浪费了大量的时间。最近在论坛看到有一个第三方维护的在…

nginx映射后,公网通过域名无法访问到静态资源

今天发生一件奇怪的事情,首先是阿里云的数字DV证书中pgj.bw580.com和acc.bw580.com无缘无故的消失了, 接着查看https://pgj.bw580.com/css/chunk-ceb11154.aefc15d8.css,在跳板机中可以访问到该资源,但是通过外网能够访问。 通过防…

MySQL 中各种锁的详细介绍

❤ 作者主页:欢迎来到我的技术博客😎 ❀ 个人介绍:大家好,本人热衷于Java后端开发,欢迎来交流学习哦!( ̄▽ ̄)~* 🍊 如果文章对您有帮助,记得关注、点赞、收藏、…

P109认识和改造世界

认识世界的根本目的在于改造世界 认识和改造世界之间的辩证关系 感觉只喜欢考 必然和自由的辩证关系 人类创造历史的两个基本活动 : 认识和改造世界所以认识和改造世界的基础是实践 认识改造和三大界之间的联系 改造客观世界和改造主观世界之间的关系 认识世界…

台电x80HD 安装linux系统,可调电压电源供电,外网访问、3D打印klipper固件

一、系统安装 参照https://blog.csdn.net/gangtieren/article/details/102975027安装 安装过程遇到的问题: 1、试了 linux mint 21 、ubuntu20.04 、ubuntu22.04 都没有直接安装成功,u盘选择安装进入系统后一直黑屏,只有ubuntu18.04 选择后稍…

基于Eclipse+Java+Swing+Mysql实现学生成绩管理系统

基于EclipseJavaSwingMysql实现学生成绩管理系统 一、系统介绍二、功能展示1.登陆2.成绩浏览3.班级添加4.班级维护5.学生添加6、学生维护 三、数据库四、其它1.其他系统实现五.获取源码 一、系统介绍 学生:登陆、成绩浏览 管理员:登陆、班级添加、班级维…

多分支merge忽略文件合并

该文章已同步收录到我的博客网站,欢迎浏览我的博客网站,xhang’s blog 1. .gitattributes 文件的作用 .gitattributes 文件是 Git 版本控制系统中的一个配置文件,它用于指定 Git 如何处理文件的二进制数据,以及如何标识文件的类…

字节月薪23k软件测试工程师:必备的6大技能(建议收藏)

软件测试 随着软件开发行业的日益发展,岗位需求量和行业薪资都不断增长,想要入行的人也是越来越多,但不知道从哪里下手,今天,就给大家分享一下,软件测试行业都有哪些必会的方法和技术知识点,作…

夏天到了,给数据中心泼点“冷水”

气温上升,还有什么能比“工作没了”,更能让人一瞬间心里拔凉拔凉的呢? 这个“薪尽自然凉”的故事,就发生在数据中心。 前不久,某电商平台正在购物高峰期,结果IDC冷冻系统故障,机房设备温度快速升…

智能电动汽车充电桩系统及硬件电路研究 安科瑞 许敏

摘要:随着充电桩技术的发展,以及人们对电动汽车快速充电的需求,很多厂商开始对智能充电桩进行研究。以电动 汽车智能充电桩的发展现状为背景,进行了智能电动汽车充电桩系统硬件电路的研究。 关键词:充电桩&#xff1b…

文件转换工具类—基于jodconverter和pdfbox实现的可以自定义各类文件转换和水印

源码获取&#xff1a;原文地址 概览 需要依赖 <dependency><groupId>org.jodconverter</groupId><artifactId>jodconverter-local</artifactId><version>4.4.6</version> </dependency> <dependency><groupId>or…

【MyBatis学习】占位符,sql注入问题,like模糊匹配等可能出现一定的问题,赶快与我一同去了解,避免入坑吧 ! ! !

前言: 大家好,我是良辰丫,今天还是我们的mybatis的学习,主要内容有两个占位符,sql注入问题,like模糊匹配,以及多表查询等,不断提升我们的编程能力,加油哈! ! !&#x1f48c;&#x1f48c;&#x1f48c; &#x1f9d1;个人主页&#xff1a;良辰针不戳 &#x1f4d6;所属专栏&…

MP地面站下载和回放日志

参考 https://ardupilot.org/dev/docs/common-downloading-and-analyzing-data-logs-in-mission-planner.html#common-downloading-and-analyzing-data-logs-in-mission-planner 下载日志 首先连接上飞控 然后在下图页面下载日志&#xff1a; 点击下图下载日志 下载的日志会…

在CentOS 7上安装Python 3.9

前言 这是我在这个网站整理的笔记&#xff0c;关注我&#xff0c;接下来还会持续更新。 作者&#xff1a;RodmaChen 在CentOS 7上安装Python 3.9 一. 更新系统软件包二. 安装必要的软件包和依赖项三. 下载Python 3.9四. 解压和编译源代码五. 安装Python 3.9六. 验证安装 一. 更…

SpringCloud Alibaba-Seata分布式事务

SpringCloud Alibaba-Seata 1 常用事务解决方案模型1.1 DTP模型1.2 2PC1.3 3PC1.4 TCC 2 Seata2.1 Seata术语2.1 Seata AT模式2.1.1 AT模式及工作流程2.1.2 Seata-Server安装2.1.3 集成springcloud-alibaba 4.2 Seata TCC模式 3 Seata注册中心3.1 服务端注册中心配置3.2 客户端…

全国主要城市建筑轮廓(含层高)矢量数据分享及最新AI提取建筑分布方法介绍

今天要给大家带来的数据就是全国主要大中型城市的城市建筑轮廓矢量数据&#xff01;&#xff01;同时给大家一个傻瓜式的建筑物提取软件&#xff0c;以及其使用方法&#xff01;&#xff01; 第一部分&#xff1a;数据 一、数据基本情况 建筑轮廓数据实际上就是建筑的边界矢量…

easyX绘图设备相关函数(注释版)

0.前言 这里是limou3434的easyX博文系列&#xff0c;感兴趣可以看看我的其他内容。 本次我给您带来的是easyX的绘图设备相关函数&#xff0c;和上一篇一样&#xff0c;对于官方文档我给了一些自认为重要的注释和测试例子&#xff0c;来辅助您理解这些函数。 1.easyX库函数分…

【汤4操作系统】深入掌握操作系统-文件管理篇

第六章 文件管理 文件 数据项&记录&文件 数据项分为&#xff1a; 基本数据项&#xff1a;描述对象的某些属性&#xff0c;例如学生的年龄&#xff0c;姓名学号等组合数据项&#xff1a;由若干个基本数据项组合而成 记录&#xff1a;一组相关数据项的集合&#xff0…