政安晨:【Keras机器学习示例演绎】(三十一)—— 梯度集中,提高训练效果

news2024/11/21 0:33:10

目录

简介

设置

准备数据

使用数据增强

定义模型

实现梯度集中化

训练工具

不使用 GC 训练模型

使用 GC 训练模型

性能比较


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文目标:实施梯度集中化,提高 DNN 的训练性能。

简介


本示例实现了 Yong 等人提出的深度神经网络新优化技术 "梯度集中化"(Gradient Centralization),并在 Laurence Moroney 的 "马或人 "数据集(Horses or Humans Dataset)上进行了演示。

梯度集中化既能加快训练过程,又能提高深度神经网络的最终泛化性能。

它通过将梯度向量集中为零均值,直接对梯度进行操作。

梯度集中化还能改善损失函数及其梯度的 Lipschitzness,从而提高训练过程的效率和稳定性。

此示例需要使用 tensorflow_datasets,可通过此命令安装:

设置

from time import time

import keras
from keras import layers
from keras.optimizers import RMSprop
from keras import ops

from tensorflow import data as tf_data
import tensorflow_datasets as tfds

准备数据


在本例中,我们将使用 "马或人类 "数据集。

num_classes = 2
input_shape = (300, 300, 3)
dataset_name = "horses_or_humans"
batch_size = 128
AUTOTUNE = tf_data.AUTOTUNE

(train_ds, test_ds), metadata = tfds.load(
    name=dataset_name,
    split=[tfds.Split.TRAIN, tfds.Split.TEST],
    with_info=True,
    as_supervised=True,
)

print(f"Image shape: {metadata.features['image'].shape}")
print(f"Training images: {metadata.splits['train'].num_examples}")
print(f"Test images: {metadata.splits['test'].num_examples}")
Image shape: (300, 300, 3)
Training images: 1027
Test images: 256

使用数据增强


我们将把数据比例调整为 [0,1],并对数据进行简单的扩充。

rescale = layers.Rescaling(1.0 / 255)

data_augmentation = [
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomRotation(0.3),
    layers.RandomZoom(0.2),
]


# Helper to apply augmentation
def apply_aug(x):
    for aug in data_augmentation:
        x = aug(x)
    return x


def prepare(ds, shuffle=False, augment=False):
    # Rescale dataset
    ds = ds.map(lambda x, y: (rescale(x), y), num_parallel_calls=AUTOTUNE)

    if shuffle:
        ds = ds.shuffle(1024)

    # Batch dataset
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set
    if augment:
        ds = ds.map(
            lambda x, y: (apply_aug(x), y),
            num_parallel_calls=AUTOTUNE,
        )

    # Use buffered prefecting
    return ds.prefetch(buffer_size=AUTOTUNE)

重新缩放和扩充数据

train_ds = prepare(train_ds, shuffle=True, augment=True)
test_ds = prepare(test_ds)

定义模型


在本文中,我们将定义一个卷积神经网络。

model = keras.Sequential(
    [
        layers.Input(shape=input_shape),
        layers.Conv2D(16, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.Dropout(0.5),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(512, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ]
)

实现梯度集中化


现在,我们将对 RMSProp 优化器类进行子类化,修改 keras.optimizers.Optimizer.get_gradients() 方法,从而实现梯度集中化。从高层次上讲,我们的想法是,假设我们通过密集层或卷积层的反向传播获得梯度,然后计算权重矩阵列向量的平均值,再从每一列向量中去除平均值。

本文在一般图像分类、细粒度图像分类、检测和分割以及人员 ReID 等各种应用中进行的实验表明,GC 可以持续提高 DNN 学习的性能。

此外,为了简单起见,我们目前没有实现渐变剪切功能,但这很容易实现。

目前,我们只是为 RMSProp 优化器创建了一个子类,但你可以很容易地在任何其他优化器或自定义优化器上以同样的方式重现这个子类。在后面的文章中,我们将使用梯度集中法训练模型时使用该类。

class GCRMSprop(RMSprop):
    def get_gradients(self, loss, params):
        # We here just provide a modified get_gradients() function since we are
        # trying to just compute the centralized gradients.

        grads = []
        gradients = super().get_gradients()
        for grad in gradients:
            grad_len = len(grad.shape)
            if grad_len > 1:
                axis = list(range(grad_len - 1))
                grad -= ops.mean(grad, axis=axis, keep_dims=True)
            grads.append(grad)

        return grads


optimizer = GCRMSprop(learning_rate=1e-4)

训练工具


我们还将创建一个回调函数,以便轻松测量总训练时间和每个历时所需的时间,因为我们有兴趣比较梯度集中化技术对上述模型的影响。

class TimeHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.times = []

    def on_epoch_begin(self, batch, logs={}):
        self.epoch_time_start = time()

    def on_epoch_end(self, batch, logs={}):
        self.times.append(time() - self.epoch_time_start)

不使用 GC 训练模型


现在,我们在不使用梯度集中法的情况下训练之前建立的模型,并将其与使用梯度集中法训练的模型的训练性能进行比较。

time_callback_no_gc = TimeHistory()
model.compile(
    loss="binary_crossentropy",
    optimizer=RMSprop(learning_rate=1e-4),
    metrics=["accuracy"],
)

model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 298, 298, 16)      │        448 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 149, 149, 16)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 147, 147, 32)      │      4,640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 147, 147, 32)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 73, 73, 32)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 71, 71, 64)        │     18,496 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout)             │ (None, 71, 71, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 35, 35, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 33, 33, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_3 (MaxPooling2D)  │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_4 (Conv2D)               │ (None, 14, 14, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_4 (MaxPooling2D)  │ (None, 7, 7, 64)          │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ flatten (Flatten)               │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_2 (Dropout)             │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 512)               │  1,606,144 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense)                 │ (None, 1)                 │        513 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 1,704,097 (6.50 MB)
 Trainable params: 1,704,097 (6.50 MB)
 Non-trainable params: 0 (0.00 B)

我们还保存了历史记录,因为我们以后要比较使用梯度集中化训练和未使用梯度集中化训练的模型。

history_no_gc = model.fit(
    train_ds, epochs=10, verbose=1, callbacks=[time_callback_no_gc]
)
Epoch 1/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 24s 778ms/step - accuracy: 0.4772 - loss: 0.7405
Epoch 2/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 597ms/step - accuracy: 0.5434 - loss: 0.6861
Epoch 3/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 700ms/step - accuracy: 0.5402 - loss: 0.6911
Epoch 4/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 586ms/step - accuracy: 0.5884 - loss: 0.6788
Epoch 5/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 588ms/step - accuracy: 0.6570 - loss: 0.6564
Epoch 6/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 591ms/step - accuracy: 0.6671 - loss: 0.6395
Epoch 7/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/step - accuracy: 0.7010 - loss: 0.6161
Epoch 8/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 593ms/step - accuracy: 0.6946 - loss: 0.6129
Epoch 9/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 699ms/step - accuracy: 0.6972 - loss: 0.5987
Epoch 10/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 11s 623ms/step - accuracy: 0.6839 - loss: 0.6197

使用 GC 训练模型


现在,我们将使用梯度集中法训练同一个模型,注意这次使用梯度集中法的是我们的优化器。

time_callback_gc = TimeHistory()
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])

model.summary()

history_gc = model.fit(train_ds, epochs=10, verbose=1, callbacks=[time_callback_gc])

演绎展示:

Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 298, 298, 16)      │        448 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 149, 149, 16)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 147, 147, 32)      │      4,640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout)               │ (None, 147, 147, 32)      │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 73, 73, 32)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (Conv2D)               │ (None, 71, 71, 64)        │     18,496 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout)             │ (None, 71, 71, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 35, 35, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (Conv2D)               │ (None, 33, 33, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_3 (MaxPooling2D)  │ (None, 16, 16, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_4 (Conv2D)               │ (None, 14, 14, 64)        │     36,928 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d_4 (MaxPooling2D)  │ (None, 7, 7, 64)          │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ flatten (Flatten)               │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_2 (Dropout)             │ (None, 3136)              │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (Dense)                   │ (None, 512)               │  1,606,144 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense)                 │ (None, 1)                 │        513 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
 Total params: 1,704,097 (6.50 MB)
 Trainable params: 1,704,097 (6.50 MB)
 Non-trainable params: 0 (0.00 B)
Epoch 1/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 12s 649ms/step - accuracy: 0.7118 - loss: 0.5594
Epoch 2/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 592ms/step - accuracy: 0.7249 - loss: 0.5817
Epoch 3/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/step - accuracy: 0.8060 - loss: 0.4448
Epoch 4/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 693ms/step - accuracy: 0.8472 - loss: 0.4051
Epoch 5/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 594ms/step - accuracy: 0.8386 - loss: 0.3978
Epoch 6/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 593ms/step - accuracy: 0.8442 - loss: 0.3976
Epoch 7/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 585ms/step - accuracy: 0.7409 - loss: 0.6626
Epoch 8/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 587ms/step - accuracy: 0.8191 - loss: 0.4357
Epoch 9/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 9s 587ms/step - accuracy: 0.8248 - loss: 0.3974
Epoch 10/10
 9/9 ━━━━━━━━━━━━━━━━━━━━ 10s 646ms/step - accuracy: 0.8022 - loss: 0.4589

性能比较

print("Not using Gradient Centralization")
print(f"Loss: {history_no_gc.history['loss'][-1]}")
print(f"Accuracy: {history_no_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_no_gc.times)}")

print("Using Gradient Centralization")
print(f"Loss: {history_gc.history['loss'][-1]}")
print(f"Accuracy: {history_gc.history['accuracy'][-1]}")
print(f"Training Time: {sum(time_callback_gc.times)}")
Not using Gradient Centralization
Loss: 0.5345584154129028
Accuracy: 0.7604166865348816
Training Time: 112.48799777030945
Using Gradient Centralization
Loss: 0.4014038145542145
Accuracy: 0.8153935074806213
Training Time: 98.31573963165283

我们鼓励读者在不同领域的不同数据集上尝试梯度集中化,并实验其效果。
 


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1641715.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于php+mysql+html简单图书管理系统

博主介绍: 大家好,本人精通Java、Python、Php、C#、C、C编程语言,同时也熟练掌握微信小程序、Android等技术,能够为大家提供全方位的技术支持和交流。 我有丰富的成品Java、Python、C#毕设项目经验,能够为学生提供各类…

C#语言入门

一、基础知识 1. 程序语言是什么 用于人和计算机进行交流,通过程序语言让计算机能够响应我们发出的指令 2. 开发环境 IDE,集成开发环境。它就是一类用于程序开发的软件,这一类软件一般包括了代码编辑、编译器、调试器、图形用户界面等等工…

springboot 整合 knife4j-openapi3

适用于&#xff1a;项目已使用shiro安全认证框架&#xff0c;整合knife4j-openapi3 1.引入依赖 <!-- knife4j-openapi3 --> <dependency><groupId>com.github.xiaoymin</groupId><artifactId>knife4j-openapi3-spring-boot-starter</artifa…

【C语言】——结构体

【C语言】——结构体 一、结构体类型的声明1.1、结构体的声明1.2、结构体变量的创建和初始化1.3、结构体的特殊声明1.4、结构体的自引用1.5、结构体的重命名 二、 结构体的内存对齐2.1、对齐规则2.2、结构体对齐实践2.3、为什么存在内存对齐2.4、修改默认对齐数 三、结构体传参…

数据库(MySQL)—— 多表查询

数据库&#xff08;MySQL&#xff09;—— 多表查询 多表关系一对多多对多一对一多表查询概述数据准备查询形式笛卡尔积 分类连接查询内连接外连接左外连接右外连接 自连接联合查询 今天我们来进入MySQL中一个非常重要的部分&#xff1a;多表查询&#xff1a; 多表关系 多表关…

【HM】DevEco Studio如何使用代码编程AI助手

大家可能都有用过或了解过github copilot插件&#xff0c;确实为我们编码智能、提升开发效率有很大的帮助。推荐两款国产的ai编程插件&#xff0c;分别是华为的CodeArts Snap和阿里的通义灵码。 DevEco 中如何安装通义灵码&#xff1f; 一、下载通义灵码离线安装包 打开官网…

数组邻接表+堆优化版dijkstra+蓝桥杯2022年第十三届决赛真题-出差

文章目录 邻接表数组实现堆优化版dijkstra蓝桥杯2022年第十三届决赛真题-出差 邻接表数组实现 idx是每条边的地址e保存终点的节点值w保存每条边的权值ne[idx]保存边表&#xff0c;idx的下一个顶点的地址h[a]保存顶点表&#xff0c;a是起点&#xff0c;h[a]是终点的地址 int e…

docker-compose单机容器集群编排工具

前言&#xff1a; docker-compose用来单机上编排容器&#xff08;定义和运行多个容器&#xff0c;使容器能互通&#xff09; Eg&#xff1a;前端和后端部署在一台机器上&#xff0c;现在直接通过编写docker-compose文件对多个服务&#xff08;可定义依赖&#xff0c;按顺序启…

conda环境安装的pyproj包报错

conda环境安装的pyproj包报错 文章目录 conda环境安装的pyproj包报错问题解决参考 问题 在conda创建的Python3.9虚拟环境中安装pyproj包3.6在运行时出现以下报错 UserWarning: pyproj unable to set database path. _pyproj_global_context_initialize()解决 先激活并进入创…

古典密码学简介

目录 C. D. Shannon: 一、置换密码 二、单表代替密码 ① 加法密码 ② 乘法密码 ③密钥词组代替密码 三、多表代替密码 代数密码 四、古典密码的穷举分析 1、单表代替密码分析 五、古典密码的统计分析 1、密钥词组单表代替密码的统计分析 2、英语的统计规…

从零开始学AI绘画,万字Stable Diffusion终极教程(二)

【第2期】关键词 欢迎来到SD的终极教程&#xff0c;这是我们的第二节课 这套课程分为六节课&#xff0c;会系统性的介绍sd的全部功能&#xff0c;让你打下坚实牢靠的基础 1.SD入门 2.关键词 3.Lora模型 4.图生图 5.controlnet 6.知识补充 在第一节课里面&#xff0c;我们…

【数据库原理及应用】期末复习汇总高校期末真题试卷

试卷 一、填空题 1.________是位于用户与操作系统之间的一层数据管理软件。 2.数据库系统的三级模式结构是指________、________、________。 3.数据库系统的三种数据模型是________ 、________、________。 4.若关系中的某一属性组的值能唯一地标识一个元组&#xff0c;则…

【LinuxC语言】信号的基本概念与基本使用

文章目录 前言一、信号的概念二、信号的使用2.1 基本的信号类型2.2 signal函数 总结 前言 在Linux环境下&#xff0c;信号是一种用于通知进程发生了某种事件的机制。这些事件可能是由操作系统、其他进程或进程本身触发的。对于C语言编程者来说&#xff0c;理解信号的基本概念和…

使用 ORPO 微调 Llama 3

原文地址&#xff1a;https://towardsdatascience.com/fine-tune-llama-3-with-orpo-56cfab2f9ada 更便宜、更快的统一微调技术 2024 年 4 月 19 日 ORPO 是一种新的令人兴奋的微调技术&#xff0c;它将传统的监督微调和偏好校准阶段合并为一个过程。这减少了训练所需的计算…

8.MyBatis 操作数据库(进阶)

文章目录 1.动态SQL插入1.1使用注解方式插入数据1.2使用xml方式插入数据1.3何时用注解何时用xml&#xff1f;1.4使用SQL查询中有多个and时&#xff0c;如何自动去除多余and1.4.1方法一&#xff1a;删除and之后的代码如图所示&#xff0c;再次运行1.4.2方法二&#xff1a;加上tr…

C语言——文件相关操作

2.什么是文件 3.文件的打开和关闭 4.文件的顺序读写 5.文件的随机读写 6.文本文件和二进制文件 7.文件读取结束的判定 8.文件缓冲区 一、文件相关介绍 1、为什么使用文件 文件用于永久存储数据。通过使用文件&#xff0c;我们可以在程序关闭后保存数据&#xff0c;以便将来…

Springboot图片上传【本地+oss】

文章目录 1 前端组件页面2 本地上传3 上传到阿里云oss3.1申请开通账号&#xff0c;做好先导准备3.2 开始使用 1 前端组件页面 使用的VueElement组件 在线cdn引入&#xff1a; <script src"https://cdn.bootcdn.net/ajax/libs/vue/2.7.16/vue.js"></script&…

Simulink|【免费】虚拟同步发电机(VSG)惯量阻尼自适应控制仿真模型

目录 主要内容 仿真模型要点 2.1 整体仿真模型 2.2 电压电流双闭环模块 2.3 SVPWM调制策略 2.4 无功电压模块 2.5 自适应控制策略及算法 部分结果 下载链接 主要内容 该模型为simulink仿真模型&#xff0c;主要实现的内容如下&#xff1a; 随着风力发电、…

免费APP分发平台 - 一个指南和解析

数字化时代的APP分发平台 随着数字化进程的加速免费APP分发平台 - 一个指南和解析&#xff0c;移动应用&#xff08;APP&#xff09;市场正迅速扩大。在这个充满竞争的市场中免费APP分发平台 - 一个指南和解析&#xff0c;一个优秀的APP分发平台能够帮助开发者和商家更有效地触…

用keras识别狗狗

一、需求场景 从照片从识别出狗狗 from keras.applications.resnet50 import ResNet50 from keras.preprocessing import image from keras.applications.resnet50 import preprocess_input, decode_predictions import numpy as np# 加载预训练的ResNet50模型 model ResNet5…