从零构建深度学习推理框架-4 框架中的算子注册机制

news2025/1/16 7:45:21

今天要讲的这一注册机制用到了设计模式中的工厂模式和单例模式,所以这节课也是对两大设计模式的一个合理应用和实践。KuiperInfer的注册表是一个map数据结构,维护了一组键值对,key是对应的OpType,用来查找对应的value,value是用于创建该层的对应方法(Creator)。我们可以看一下KuiperInfer中的Layer注册表实现:

 typedef std::map<OpType, Creator> CreateRegistry;

OpType就是头文件中对应的索引:

创建该层的对应方法相当于一个工厂(Creator),Creator如下的代码所示,是一个函数指针类型,我们将存放参数的Oprator类传入到该方法中,然后该方法根据Operator内的参数返回具体的Layer.

typedef std::shared_ptr<Layer> (*Creator)(const std::shared_ptr<Operator> &op);

代表的是返回值为std::shared_ptr<Layer> 然后参数为const std::shared_ptr<Operator> &op的一类函数

只要返回值和参数的类型\个数都满足 creator就可以指向对应的函数

对应函数就是创建层layer的一个具体方法

所以目前就是这个样子:

ReluLayer定义完成--->LayerRegistererWrapper ---> RegisterCreator

接下来我们再看这个registor注册方法:

void LayerRegisterer::RegisterCreator(OpType op_type, const Creator &creator) {
  CHECK(creator != nullptr) << "Layer creator is empty";
  CreateRegistry &registry = Registry(); //实现单例的关键
  // 根据operator type
  CHECK_EQ(registry.count(op_type), 0) << "Layer type: " << int(op_type) << " has already registered!";
 //  ReluLayer::CreateInstance 没有被注册过,就塞入到注册表当中
  registry.insert({op_type, creator});
}

先来补充一下单例模式:

单例设计编程模式

全局当中有且只有一个变量

任意次和任意一方去调用都会得到这个唯一的变量

这里的唯一变量是全局的注册表 存的时候是这个,取得时候也需要是这个

这里的Registry也写上了,就是实现单例的关键:

LayerRegisterer::CreateRegistry &LayerRegisterer::Registry() {

  // C++程序员高频面试点
  static  CreateRegistry *kRegistry = new CreateRegistry();
  // 没有static 那就是调用一次初始化一次
  // 不构成单例
  CHECK(kRegistry != nullptr) << "Global layer register init failed!";
  return *kRegistry; // 返回了这个注册表
}

static CreateRegistry *kRegistry = new CreateRegistry();

这个其实只会被初始化一次

简单来说第一次,调用的时候 new CreateRegistry 存放到一个kRegistry (static)

后续调用的时候,只会返回kRegistry (static)

这是一种C++的特性。

如果没有static,那就是调用一次就初始化一次,就自然构不成单例模式。

layer_factory.cpp:


namespace kuiper_infer {
// OpType::kOperatorRelu 就是刚才还说的OpType
// ReluLayer::CreateInstance就是一个函数指针,用来初始化层的方法

// 单例设计编程模式
// 全局当中有且只有一个变量
// 任意次和任意一方去调用都会得到这个唯一的变量
// 这里的唯一变量是全局的注册表 存的时候是这个,取得时候也需要是这个
// typedef std::map<OpType, Creator> CreateRegistry
// 全局当中有且只有一个 CreateRegistry  的实例
// 什么方法来控制这个变量唯一呢
void LayerRegisterer::RegisterCreator(OpType op_type, const Creator &creator) {
  CHECK(creator != nullptr) << "Layer creator is empty";
  CreateRegistry &registry = Registry(); //实现单例的关键
  // 根据operator type
  CHECK_EQ(registry.count(op_type), 0) << "Layer type: " << int(op_type) << " has already registered!";
 //  ReluLayer::CreateInstance 没有被注册过,就塞入到注册表当中
  registry.insert({op_type, creator});
}
  CreateRegistry &registry = Registry(); //实现单例的关键

这个就是上面的Registry,之后作检查,如果已经有过的话,那count之后就该报错啦。

 CHECK_EQ(registry.count(op_type), 0) << "Layer type: " << int(op_type) << " has already registered!";

没有注册机制的create

然后我们来看一下没有注册机制前的createrelu是怎么做到的:

TEST(test_layer, forward_relu1) {
  using namespace kuiper_infer;
  float thresh = 0.f;
  // 初始化一个relu operator 并设置属性
  std::shared_ptr<Operator> relu_op = std::make_shared<ReluOperator>(thresh);

  // 有三个值的一个tensor<float>
  std::shared_ptr<Tensor<float>> input = std::make_shared<Tensor<float>>(1, 1, 3);
  input->index(0) = -1.f; //output对应的应该是0
  input->index(1) = -2.f; //output对应的应该是0
  input->index(2) = 3.f; //output对应的应该是3
  // 主要第一个算子,经典又简单,我们这里开始!

  std::vector<std::shared_ptr<Tensor<float>>> inputs; //作为一个批次去处理

  std::vector<std::shared_ptr<Tensor<float>>> outputs; //放结果
  inputs.push_back(input);
  ReluLayer layer(relu_op);
  // 因为是4.1 所以没有作业 4.2才有
// 一个批次是1
  layer.Forwards(inputs, outputs);
  ASSERT_EQ(outputs.size(), 1);
//记得切换分支!!!!!
  for (int i = 0; i < outputs.size(); ++i) {
    ASSERT_EQ(outputs.at(i)->index(0), 0.f);
    ASSERT_EQ(outputs.at(i)->index(1), 0.f);
    ASSERT_EQ(outputs.at(i)->index(2), 3.f);
  }
}

有了注册机制后的createrelu:

// 有了注册机制后的框架是如何init layer
TEST(test_layer, forward_relu2) {
  using namespace kuiper_infer;
  float thresh = 0.f;
  std::shared_ptr<Operator> relu_op = std::make_shared<ReluOperator>(thresh);
  std::shared_ptr<Layer> relu_layer = LayerRegisterer::CreateLayer(relu_op);

  std::shared_ptr<Tensor<float>> input = std::make_shared<Tensor<float>>(1, 1, 3);
  input->index(0) = -1.f;
  input->index(1) = -2.f;
  input->index(2) = 3.f;
  std::vector<std::shared_ptr<Tensor<float>>> inputs;
  std::vector<std::shared_ptr<Tensor<float>>> outputs;
  inputs.push_back(input);
  relu_layer->Forwards(inputs, outputs);
  ASSERT_EQ(outputs.size(), 1);
  for (int i = 0; i < outputs.size(); ++i) {
    ASSERT_EQ(outputs.at(i)->index(0), 0.f);
    ASSERT_EQ(outputs.at(i)->index(1), 0.f);
    ASSERT_EQ(outputs.at(i)->index(2), 3.f);
  }
}

我们可以看到std::shared_ptr<Operator> relu_op = std::make_shared<ReluOperator>(thresh), 初始化了一个ReluOperator, 其中的参数为thresh=0.f.

因为我们已经在ReluLayer的实现中完成了注册,{kOperatorRelu:ReluLayer::CreateInstance} , 所以现在可以使用 LayerRegisterer::CreateLayer(relu_op) 得到我们ReluLayer中的实例化工厂方法,我们再来看看CreateLayer的实现:

std::shared_ptr<Layer> LayerRegisterer::CreateLayer(const std::shared_ptr<Operator> &op) {
  CreateRegistry &registry = Registry();
  const OpType op_type = op->op_type_;

  LOG_IF(FATAL, registry.count(op_type) <= 0) << "Can not find the layer type: " << int(op_type);
  const auto &creator = registry.find(op_type)->second;

  LOG_IF(FATAL, !creator) << "Layer creator is empty!";
  std::shared_ptr<Layer> layer = creator(op);
  LOG_IF(FATAL, !layer) << "Layer init failed!";
  return layer;
}

可以看到传入的参数为op, 我们首先取得op中的op_type, 此处的op_type为kOperatorRelu, 根据registry.find(op_type), 就得到了层的初始化方法creator, 随后使用传入的op去初始化layer并返回实例。值得注意的是此处也调用了CreateRegistry &registry =Registry() 返回了我们所说的全局有且唯一的Layer注册表。

此处的creator(op)就相当于调用了ReluLayer::CreateInstance.(因为:

LayerRegistererWrapper kReluLayer(OpType::kOperatorRelu, ReluLayer::CreateInstance);
class LayerRegistererWrapper {
 public:
  LayerRegistererWrapper(OpType op_type, const LayerRegisterer::Creator &creator) {
    // 定义之后调用的
    // RegisterCreator
    LayerRegisterer::RegisterCreator(op_type, creator);
  }
};

这样子可能看上去和上面差别不大,但是在实际应用上会便捷很多:

ops:[] = {conv 1 , conv 2 ,conv 3 ,relu ,sigmoid ,linear , conv 3};
layers = [];
for op in ops : 
    layers.append(LayerRegisterer::CreateLayer(op))
    //初始化完毕

因为如果没有这个机制的话,那么语言多少层,他就要写多少层。

创建sigmoid算子:

void SigmoidLayer::Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,
                            std::vector<std::shared_ptr<Tensor<float>>> &outputs) {
  CHECK(this->op_ != nullptr);
  CHECK(this->op_->op_type_ == OpType::kOperatorSigmoid);
  CHECK(!inputs.empty());

  const uint32_t batch_size = inputs.size();
  for (uint32_t i = 0; i < batch_size; ++i) {
    const std::shared_ptr<Tensor<float>> &input_data = inputs.at(i);
    std::shared_ptr<Tensor<float>> output_data = input_data->Clone();
    //补充,y=1/(1+e^{-x})
    output_data->data().transform([&](float value) {
    return(1/(1+exp(-1*value))); 
    });

    outputs.push_back(output_data);
  }
}

cube的transform函数就是对于这个cube中的每一个元素进行lambda表达式中的运算

在这里预先将input_data进行了复制i,所以可以对于output中的数值进行直接的运算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/843864.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

断路器分合闸线圈电流试验

试验目的 仅通过断路器低电压值来分析判断断路器的状态, 不能有效地反映断路器内部潜 在缺陷, 同时无法对故障进行定位, 分、 合闸线圈电流蕴含断路器操作回路的极大信 息, 典型的分、 合闸线圈动作电流暂态波形, 通常有两个波峰和一个波谷, 根据波峰、 波谷出现的时间位置, …

关于外贸跟进客户过程中需要注意的地方

如果你感觉业务进展困难&#xff0c;多去看一些书&#xff0c;多去链接一些人&#xff0c;特别是优秀的人&#xff0c;多交流会让你思维更加开阔&#xff0c;笔记做好实践起来&#xff0c;就会有收获&#xff01; 我记得汪老师说过&#xff1a;跟进客户&#xff0c;当你准备好…

【Maven】依赖范围、依赖传递、依赖排除、依赖原则、依赖继承

【Maven】依赖范围、依赖传递、依赖排除、依赖原则、依赖继承 依赖范围 依赖传递 依赖排除 依赖原则 依赖继承 依赖范围 在Maven中&#xff0c;依赖范围&#xff08;Dependency Scope&#xff09;用于控制依赖项在编译、测试和运行时的可见性和可用性。通过指定适当的依赖…

AWK实战案例——筛选给定时间范围内的日志

时间戳与当地时间 概念&#xff1a; 1.时间戳&#xff1a; 时间戳是指格林威治时间自1970年1月1日&#xff08;00:00:00 GMT&#xff09;至当前时间的总秒数。它也被称为Unix时间戳&#xff08;Unix Timestamp&#xff09;。通俗的讲&#xff0c;时间戳是一份能够表示一份数据…

关于Python 的 Web 自动化测试的实践

Web 测试是软件测试中比较重要的一个分支&#xff0c;而要实现 Web 自动化测试则要求测试人员能熟练掌握自动化测试工具和编程语言。介绍免费开源的 Web 测试工具 Selenium&#xff0c;以及流行的编程语言 Python。根据自动化测试的原理&#xff0c;对网页元素的常用定位方式&a…

教育行业的文件管理方法和实践

信息化浪潮的冲击下&#xff0c;教育行业也正在向建设数据化平台发展。在文件管理方面&#xff0c;教育行业依旧存在文件交互与协作方式传统陈旧的问题。Zoho Workdrive为教育行业提供安全的文件集中存储管理空间&#xff0c;用户可以快速使用、共享文件&#xff0c;帮助教育行…

内容动态展示抽屉组件

知识点 mousemove与mouseenter的区别在于mousemove会触发事件冒泡&#xff0c;mouseenter不会&#xff0c;mouseleave同理。 mousemove会触发事件冒泡&#xff0c;因此鼠标在范围区域内移动时会一直触发。 mouseenter只触发一次&#xff0c;鼠标移入后触发&#xff0c;鼠标移…

绘制曲线python

文章目录 import matplotlib.pyplot as plt# 提供的数据 x= [1,1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,2,2.1,2.2,2.3,2.4,2.5,2.6,2.7,2.8,2.9,3,3.1,3.2,3.3,3.4,3.5,3.6,3.7,3.8,3.9,4,4.1,4.2,4.3,4.4,4.5,4.6,4.7,4.8,4.9,5,5.1,5.2,5.3,5.4,5.5,5.6,5.7,5.8,5.9,6,6.1,6.2…

AI时代,这些绘画软件让你简单上手绘画

即时灵感 即时设计是一款基于云端的矢量编辑工具&#xff0c;可以帮助用户创建网页和移动应用程序原型、界面和可视化设计等。除此之外&#xff0c;即时设计还具备协同。实时预览、团队管理等多种功能&#xff0c;可以实现多个用户同时设计一个项目。 AI 绘画工具和即时设计虽…

日撸java三百行day81-83

文章目录 说明CNN卷积神经网络1. 什么是CNN&#xff08;CNN基础知识&#xff09;1. 基本概念2.输入层3.卷积层3.1 图像3.2 卷积核3.3 偏置数3.4 滑动窗口步长3.5 特征图个数&#xff08;特征图通道数或深度&#xff09;3.6 边缘填充3.7 卷积过程例子 4. 激活函数5. 池化层6.全连…

打破界限,图文档与物料清单完美互联

在现代企业的产品开发过程中&#xff0c;图文档和物料清单是不可或缺的重要信息。然而&#xff0c;由于数据来源多样、格式繁杂&#xff0c;图文档与物料清单之间的信息传递往往存在障碍。而PDM系统&#xff08;Product Data Management&#xff0c;产品数据管理&#xff09;的…

MySQL插入数据的方法

插入数据方法&#xff1a; 1.insert into 表 values(value1, value2, value3....) 2.insert into 表 (字段1&#xff0c; 字段3&#xff0c; 字段5) values(value1, value2, value3) 3.insert into 表 [(字段1&#xff0c; 字段2&#xff0c; 字段3....)] values(value1, val…

41.利用matlab 平衡方程用于图像(matlab程序)

1.简述 白平衡 白平衡的英文为White Balance&#xff0c;其基本概念是“不管在任何光源下&#xff0c;都能将白色物体还原为白色”&#xff0c;对在特定光源下拍摄时出现的偏色现象&#xff0c;通过加强对应的补色来进行补偿。 所谓的白平衡是通过对白色被摄物的颜色还原&…

C++经典排序算法详解

目录 一、选择排序 二、冒泡排序 三、插入排序 一、选择排序 选择排序 选择排序&#xff08;Selection sort&#xff09;是一种简单直观的排序算法。它的工作原理是&#xff1a;第一次从待排序的数据元素中选出最小&#xff08;或最大&#xff09;的一个元素&#xff0c;存…

Linux命令200例:join将两个文件按照指定的键连接起来分析

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;全栈领域新星创作者✌。CSDN专家博主&#xff0c;阿里云社区专家博主&#xff0c;2023年6月csdn上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &…

十二、ESP32控制步进电机

1. 运行效果 2. 步进电机 最大特点:能够控制旋转一定的角度 3. 步进电机原理

P1433 吃奶酪(状态压缩dp)(内附封面)

吃奶酪 题目描述 房间里放着 n n n 块奶酪。一只小老鼠要把它们都吃掉&#xff0c;问至少要跑多少距离&#xff1f;老鼠一开始在 ( 0 , 0 ) (0,0) (0,0) 点处。 输入格式 第一行有一个整数&#xff0c;表示奶酪的数量 n n n。 第 2 2 2 到第 ( n 1 ) (n 1) (n1) 行…

研发工程师玩转Kubernetes——emptyDir

kubernets可以通过emptyDir实现在同一Pod的不同容器间共享文件系统。 正如它的名字&#xff0c;当Pod被创建时&#xff0c;emptyDir卷会被创建&#xff0c;这个时候它是一个空的文件夹&#xff1b;当Pod被删除时&#xff0c;emptyDir卷也会被永久删除。 同一Pod上不同容器之间…

报错Error: error:0308010C:digital envelope routines::unsupported

Error: error:0308010C:digital envelope routines::unsupported 解决&#xff1a; 在原有脚本前加&#xff1a; SET NODE_OPTIONS--openssl-legacy-provider && 原来&#xff1a; 加后&#xff1a;

常规VUE项目优化实践,跟着做就对了!

总结&#xff1a; 主要优化方式&#xff1a; imagemin优化打包大小&#xff08;96M->50M&#xff09;&#xff0c;但是以打包速度为代价&#xff0c;通过在构建过程中压缩图片来实现&#xff0c;可根据需求开启。字体压缩&#xff1a;目前项目内引用为思源字体&#xff0c…