系列文章目录
MNN createFromBuffer(一)
MNN createRuntime(二)
MNN createSession 之 Schedule(三)
MNN createSession 之创建流水线后端(四)
MNN Session::resize 之流水线编码(五)
MNN Session 创建执行器(六)
文章目录
- 系列文章目录
- 一、MNN 资料
- 二、使用示例
- 三、源码分析
- 1、createFromFile、createFromBuffer
- 1.1 Content
- 1.2 createFromBufferInternal
- 1.3 Net
- 1.4 Interpreter
- 1.5 Interpreter::Interpreter
一、MNN 资料
MNN GitHub
中文文档
二、使用示例
// Create the interpreter from a model file (returns nullptr on failure).
const char* file = "model.mnn"; // path to the .mnn model file
auto net_ = Interpreter::createFromFile(file);
// Create the runtime
ScheduleConfig config;
config.numThread = 4;
auto runtimeInfo = Interpreter::createRuntime({config});
// Create the session
auto session = net_->createSession(config, runtimeInfo);
// Run inference
net_->runSession(session);
三、源码分析
1、createFromFile、createFromBuffer
createFromFile 与 createFromBuffer 负责把模型数据读入内存，存放在结构体 Content 的 buffer 中，随后交给 createFromBufferInternal 校验模型并创建 Interpreter。
// source/core/Interpreter.cpp
/**
 * Load a model file into a new Content and hand it to createFromBufferInternal.
 * Returns nullptr when the file could not be loaded or the model fails validation.
 */
Interpreter* Interpreter::createFromFile(const char* file) {
    Content* loaded = loadModelFile(file);
    // A null Content means loading failed; there is nothing to validate or free.
    return (loaded == nullptr) ? nullptr : createFromBufferInternal(loaded, true);
}
/**
 * Copy a serialized model from `buffer` into a new Content and build an Interpreter from it.
 * The bytes are copied, so the caller keeps ownership of `buffer`.
 * @param buffer serialized model bytes; must not be null.
 * @param size   number of bytes in `buffer`; must be non-zero and fit in an int.
 * @return the created Interpreter, or nullptr on invalid input, allocation failure or invalid model.
 */
Interpreter* Interpreter::createFromBuffer(const void* buffer, size_t size) {
    if (nullptr == buffer || 0 == size) {
        MNN_PRINT("Buffer is null for create interpreter\n");
        return nullptr;
    }
    // The storage is sized with an int; reject sizes that would be truncated by the
    // conversion instead of silently allocating a smaller buffer than we memcpy into.
    const int storageSize = static_cast<int>(size);
    if (storageSize <= 0 || static_cast<size_t>(storageSize) != size) {
        MNN_ERROR("Buffer is too large for create interpreter\n");
        return nullptr;
    }
    auto net = new Content;
    net->buffer.reset(storageSize);
    if (nullptr == net->buffer.get()) {
        MNN_ERROR("Memory not enought!\n");
        // Free the Content we own on this path; previously it was leaked.
        delete net;
        return nullptr;
    }
    ::memcpy(net->buffer.get(), buffer, size);
    // createFromBufferInternal takes ownership of `net` (it deletes it on failure).
    return createFromBufferInternal(net, true);
}
1.1 Content
// source/core/Interpreter.cpp
// Everything an Interpreter owns: the serialized model bytes, the flatbuffer view
// into them, and the sessions created from the model.
// NOTE: members are constructed in declaration order and destroyed in reverse order,
// so do not reorder them casually.
struct Content {
// Raw serialized model bytes (filled by createFromFile / createFromBuffer).
AutoStorage<uint8_t> buffer;
// Flatbuffer view over `buffer`; set by createFromBufferInternal via GetNet(buffer.get()).
const Net* net = nullptr;
// Sessions owned by this model (held through unique_ptr).
std::vector<std::unique_ptr<Session>> sessions;
// Presumably maps a tensor to the session it belongs to — TODO confirm against Interpreter.cpp.
std::map<Tensor*, const Session*> tensorMap;
// Session mode settings; semantics of Session::ModeGroup are not shown here.
Session::ModeGroup modes;
// Presumably backend cache data and the file it is persisted to — verify with the cache APIs.
AutoStorage<uint8_t> cacheBuffer;
std::string cacheFile;
// NOTE(review): presumably guards concurrent access to this Content — confirm its users.
std::mutex lock;
// Presumably the cache size at last save; unverified from this excerpt.
size_t lastCacheSize = 0;
// Copies of the model's bizCode / mnn_uuid, cached by the Interpreter constructor so they
// remain available after `releaseModel` is called.
std::string bizCode;
std::string uuid;
// Presumably a path to external weight data — TODO confirm.
std::string externalFile;
#ifdef MNN_INTERNAL_ENABLED
// Logging key/value data, filled in the Interpreter constructor (includes "ModelVersion").
std::map<std::string, std::string> basicLogginData;
std::map<const Session*, std::tuple<int, int>> sessionInfo;
#endif
};
1.2 createFromBufferInternal
// source/core/Interpreter.cpp
/**
 * Validate the model held by `net` and wrap it in a new Interpreter.
 * Takes ownership of `net`: on any validation failure it is deleted and nullptr is returned.
 * Checks performed: flatbuffer verification (skipped in MNN_BUILD_MINI builds), presence of
 * an op list, and that every op exists and has output indexes.
 */
Interpreter* Interpreter::createFromBufferInternal(Content* net, bool enforceAuth) {
    if (net == nullptr) {
        MNN_PRINT("Buffer is null for create interpreter\n");
        return nullptr;
    }
    // Shared failure path: release the Content we own and report failure.
    auto discard = [net]() -> Interpreter* {
        delete net;
        return nullptr;
    };
#ifndef MNN_BUILD_MINI
    // Structural verification of the flatbuffer before any field is read.
    flatbuffers::Verifier verifier((const uint8_t*)(net->buffer.get()), net->buffer.size());
    if (!VerifyNetBuffer(verifier)) {
        MNN_PRINT("Invalidate buffer to create interpreter\n");
        return discard();
    }
#endif
    net->net = GetNet(net->buffer.get());
    auto ops = net->net->oplists();
    if (ops == nullptr) {
        MNN_ERROR("Model has no oplist\n");
        return discard();
    }
    // Every op must be present and declare its outputs.
    const int opCount = ops->size();
    for (int idx = 0; idx < opCount; ++idx) {
        auto op = ops->GetAs<Op>(idx);
        const bool malformed = (op == nullptr) || (op->outputIndexes() == nullptr);
        if (malformed) {
            MNN_ERROR("Invalid Model, the %d op is empty\n", idx);
            return discard();
        }
    }
    return new Interpreter(net);
}
1.3 Net
// schema/current/MNN_generated.h
// Read-only flatbuffer view of a serialized MNN network. This appears to be
// flatc-generated code (file MNN_generated.h) — regenerate from the schema rather than edit.
// Each accessor reads one field by its vtable offset (4, 6, 8, ...). Pointer accessors
// return nullptr when the field is absent; scalar accessors fall back to the default 0.
struct Net FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef NetT NativeTableType;
static const flatbuffers::TypeTable *MiniReflectTypeTable() {
return NetTypeTable();
}
// Business code / model identifier string (may be null).
const flatbuffers::String *bizCode() const {
return GetPointer<const flatbuffers::String *>(4);
}
const flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>> *extraTensorDescribe() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<TensorDescribe>> *>(6);
}
const ExtraInfo *extraInfo() const {
return GetPointer<const ExtraInfo *>(8);
}
// The op list of the network; createFromBufferInternal rejects models where this is null.
const flatbuffers::Vector<flatbuffers::Offset<Op>> *oplists() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<Op>> *>(10);
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *outputName() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(12);
}
ForwardType preferForwardType() const {
return static_cast<ForwardType>(GetField<int8_t>(14, 0));
}
NetSource sourceType() const {
return static_cast<NetSource>(GetField<int8_t>(16, 0));
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *tensorName() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(18);
}
int32_t tensorNumber() const {
return GetField<int32_t>(20, 0);
}
Usage usage() const {
return static_cast<Usage>(GetField<int8_t>(22, 0));
}
const flatbuffers::Vector<flatbuffers::Offset<SubGraphProto>> *subgraphs() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<SubGraphProto>> *>(24);
}
// Model UUID string (may be null).
const flatbuffers::String *mnn_uuid() const {
return GetPointer<const flatbuffers::String *>(26);
}
// Bounds/shape check of every field, in field order; presumably reached through
// VerifyNetBuffer in createFromBufferInternal — confirm in the generated header.
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, 4) &&
verifier.VerifyString(bizCode()) &&
VerifyOffset(verifier, 6) &&
verifier.VerifyVector(extraTensorDescribe()) &&
verifier.VerifyVectorOfTables(extraTensorDescribe()) &&
VerifyOffset(verifier, 8) &&
verifier.VerifyTable(extraInfo()) &&
VerifyOffset(verifier, 10) &&
verifier.VerifyVector(oplists()) &&
verifier.VerifyVectorOfTables(oplists()) &&
VerifyOffset(verifier, 12) &&
verifier.VerifyVector(outputName()) &&
verifier.VerifyVectorOfStrings(outputName()) &&
VerifyField<int8_t>(verifier, 14) &&
VerifyField<int8_t>(verifier, 16) &&
VerifyOffset(verifier, 18) &&
verifier.VerifyVector(tensorName()) &&
verifier.VerifyVectorOfStrings(tensorName()) &&
VerifyField<int32_t>(verifier, 20) &&
VerifyField<int8_t>(verifier, 22) &&
VerifyOffset(verifier, 24) &&
verifier.VerifyVector(subgraphs()) &&
verifier.VerifyVectorOfTables(subgraphs()) &&
VerifyOffset(verifier, 26) &&
verifier.VerifyString(mnn_uuid()) &&
verifier.EndTable();
}
// Object-API conversions between this read-only view and the mutable NetT.
NetT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
void UnPackTo(NetT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
static flatbuffers::Offset<Net> Pack(flatbuffers::FlatBufferBuilder &_fbb, const NetT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
1.4 Interpreter
/** net data holder. multiple sessions could share same net. */
class MNN_PUBLIC Interpreter {
public:
/**
* @brief create net from file.
* @param file given file.
* @return created net if success, NULL otherwise.
*/
static Interpreter* createFromFile(const char* file);
/**
* @brief create net from buffer.
* @param buffer given data buffer.
* @param size size of data buffer.
* @return created net if success, NULL otherwise.
*/
static Interpreter* createFromBuffer(const void* buffer, size_t size);
~Interpreter();
public:
/**
* @brief create session with schedule config. created session will be managed in net.
* @param config session schedule config.
* @return created session if success, NULL otherwise.
*/
Session* createSession(const ScheduleConfig& config);
/**
* @brief create multi-path session with schedule configs. created session will be managed in net.
* @param configs session schedule configs.
* @return created session if success, NULL otherwise.
*/
Session* createMultiPathSession(const std::vector<ScheduleConfig>& configs);
/**
* @brief release session.
* @param session given session.
* @return true if given session is held by net and is freed.
*/
bool releaseSession(Session* session);
/**
* @brief call this function to get tensors ready. output tensor buffer (host or deviceId) should be retrieved
* after resize of any input tensor.
* @param session given session.
*/
void resizeSession(Session* session);
/**
* @brief call this function if don't need resize or create session any more, it will save a few memory that equal
* to the size of model buffer
*/
void releaseModel();
/**
* @brief Get the model buffer for user to save
* @return std::make_pair(modelBuffer, modelSize).
* @example:
* std::ofstream output("trainResult.alinn")
* auto buffer = net->getModelBuffer();
* output.write((const char*)buffer.first, buffer.second);
*/
std::pair<const void*, size_t> getModelBuffer() const;
/**
* @brief update Session's Tensor to model's Const Op
* @param session given session.
* @return result of running.
*/
ErrorCode updateSessionToModel(Session* session);
/**
* @brief run session.
* @param session given session.
* @return result of running.
*/
ErrorCode runSession(Session* session) const;
/*
* @brief run session.
* @param session given session.
* @param before callback before each op. return true to run the op; return false to skip the op.
* @param after callback after each op. return true to continue running; return false to interrupt the session.
* @param sync synchronously wait for finish of execution or not.
* @return result of running.
*/
ErrorCode runSessionWithCallBack(const Session* session, const TensorCallBack& before, const TensorCallBack& end,
bool sync = false) const;
/*
* @brief run session.
* @param session given session.
* @param before callback before each op. return true to run the op; return false to skip the op.
* @param after callback after each op. return true to continue running; return false to interrupt the session.
* @param sync synchronously wait for finish of execution or not.
* @return result of running.
*/
ErrorCode runSessionWithCallBackInfo(const Session* session, const TensorCallBackWithInfo& before,
const TensorCallBackWithInfo& end, bool sync = false) const;
/**
* @brief get input tensor for given name.
* @param session given session.
* @param name given name. if NULL, return first input.
* @return tensor if found, NULL otherwise.
*/
Tensor* getSessionInput(const Session* session, const char* name);
/**
* @brief get output tensor for given name.
* @param session given session.
* @param name given name. if NULL, return first output.
* @return tensor if found, NULL otherwise.
*/
Tensor* getSessionOutput(const Session* session, const char* name);
/**
* @brief get all input tensors.
* @param session given session.
* @return all input tensors mapped with name.
*/
const std::map<std::string, Tensor*>& getSessionOutputAll(const Session* session) const;
/**
* @brief get all output tensors.
* @param session given session.
* @return all output tensors mapped with name.
*/
const std::map<std::string, Tensor*>& getSessionInputAll(const Session* session) const;
public:
/**
* @brief resize given tensor.
* @param tensor given tensor.
* @param dims new dims. at most 6 dims.
*/
void resizeTensor(Tensor* tensor, const std::vector<int>& dims);
/**
* @brief resize given tensor by nchw.
* @param batch / N.
* @param channel / C.
* @param height / H.
* @param width / W
*/
void resizeTensor(Tensor* tensor, int batch, int channel, int height, int width);
/**
* @brief get backend used to create given tensor.
* @param session given session.
* @param tensor given tensor.
* @return backend used to create given tensor, may be NULL.
*/
const Backend* getBackend(const Session* session, const Tensor* tensor) const;
/**
* @brief get business code (model identifier).
* @return business code.
*/
const char* bizCode() const;
private:
static Interpreter* createFromBufferInternal(Content* net);
Content* mNet = nullptr;
Interpreter(Content* net);
Interpreter(const Interpreter&) = delete;
Interpreter(const Interpreter&&) = delete;
Interpreter& operator=(const Interpreter&) = delete;
Interpreter& operator=(const Interpreter&&) = delete;
};
} // namespace MNN
1.5 Interpreter::Interpreter
构造函数接管 Content 的所有权（保存到成员 mNet），并提前缓存模型的 bizCode 与 uuid，以便在调用 releaseModel 释放模型数据之后仍可使用。
/**
 * Take ownership of a validated Content and cache model metadata that must
 * outlive the serialized buffer.
 */
Interpreter::Interpreter(Content* net) {
    MNN_ASSERT(nullptr != net);
    mNet = net;
    // bizCode and uuid are copied out of the flatbuffer now, because they are still
    // needed after `releaseModel` frees the model buffer. Missing fields become "".
    auto model     = mNet->net;
    auto bizField  = model->bizCode();
    auto uuidField = model->mnn_uuid();
    mNet->bizCode  = std::string(bizField != nullptr ? bizField->c_str() : "");
    mNet->uuid     = std::string(uuidField != nullptr ? uuidField->c_str() : "");
#ifdef MNN_INTERNAL_ENABLED
    auto& logData = mNet->basicLogginData;
    logData       = getBasicLoggingData();
    logData.emplace("ModelVersion", getModelVersion());
#endif
}
☆