一、介绍
Ceres solver 是谷歌开发的一款用于非线性优化的库,在已知一个函数表达式,以及一组观测到的值,利用最小二乘是可以解算得到相关参数。下面举例使用ceres solver解算直线函数以及曲线函数参数。
其过程包括三个步骤:
(1)步骤一:构建代价函数
(2)步骤二:通过代价函数构建待求解的优化问题
(3)步骤三:配置求解器参数并求解。这个步骤不需要知道怎么求解,只要调用solver方法即可完成。
二、直线函数的求解
假如直线方程:,一组观测值[-1,-1; 1,5; 2,8; 3,11]
代码如下:
#include "ceres/ceres.h"
#include "glog/logging.h"
using ceres::AutoDiffCostFunction;
using ceres::CostFunction;
using ceres::Problem;
using ceres::Solve;
using ceres::Solver;
//y=3x+2
const int kNumObservations = 4;
// clang-format off
const double data[] = {
-1.0, -1.0,
1, 5,
2, 8,
3, 11,
};
//(1)构造代价函数
// clang-format on
//y=ax+b=3x+2
struct CostFunc_Line {
CostFunc_Line(double x, double y) : x_(x), y_(y) {}
template <typename T>
bool operator()(const T* const a, const T* const b, T* residual) const {
residual[0] = y_ - a[0] * x_ - b[0];
return true;
}
private:
const double x_;
const double y_;
};
int main(int argc, char** argv) {
// google::InitGoogleLogging(argv[0]);
double a = -1.0;
double b = 0.0;
Problem problem;
//(2)利用代价 函数构建优化问题
for (int i = 0; i < kNumObservations; ++i) {
problem.AddResidualBlock(
new AutoDiffCostFunction<CostFunc_Line, 1, 1, 1>(
new CostFunc_Line(data[2 * i], data[2 * i + 1])),
nullptr,
&a,
&b);
}
Solver::Options options;
options.max_num_iterations = 20;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;
Solver::Summary summary;
//(3)利用solver解算参数
Solve(options, &problem, &summary);
std::cout << summary.BriefReport() << "\n";
std::cout << "Initial a: " << -1.0 << " b: " << 0.0 << "\n";
std::cout << "Final a: " << a << " b: " << b << "\n";
system("pause");
return 0;
}
解算结果正确:a=3 b=2
data:image/s3,"s3://crabby-images/17c7d/17c7d98b0f94b2087679759e25e234ef61c98a85" alt=""
三、抛物线解算参数
假如抛物线方程:,对应的一组值为[-1,2; 0.5,4.25; 1,6; 2,11; 3,18;]
代码如下:
#include "ceres/ceres.h"
#include "glog/logging.h"
using ceres::AutoDiffCostFunction;
using ceres::CostFunction;
using ceres::Problem;
using ceres::Solve;
using ceres::Solver;
//y=3x+2
const int kNumObservations = 4;
// clang-format off
const double data[] = {
-1.0, 2,
0.5, 4.25,
1, 6,
2, 11,
3,18
};
//(1)构建代价函数
// clang-format on
//y=ax^2+bx+c=x^2+2x+3
struct CostFunc_para_curve{
CostFunc_para_curve(double x, double y) : x_(x), y_(y) {}
template <typename T>
bool operator()(const T* const a, const T* const b, const T* const c, T* residual) const {
residual[0] = y_ - a[0] * x_ * x_ - b[0] * x_ - c[0];//函数
return true;
}
private:
const double x_;
const double y_;
};
int main(int argc, char** argv) {
// google::InitGoogleLogging(argv[0]);
double a = 0.0;
double b = 0.0;
double c = 0.0;
Problem problem;
//(2)利用代价函数构建优化问题
for (int i = 0; i < kNumObservations; ++i) {
problem.AddResidualBlock(
//4个参数
new AutoDiffCostFunction<CostFunc_para_curve, 1, 1, 1,1>(
new CostFunc_para_curve(data[2 * i], data[2 * i + 1])),
nullptr,
&a,
&b,
&c);//三个参数
}
Solver::Options options;
options.max_num_iterations = 20;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;
Solver::Summary summary;
//(3)利用solver解算参数
Solve(options, &problem, &summary);
std::cout << summary.BriefReport() << "\n";
std::cout << "Initial a: " << 0.0 << " b: " << 0.0 << "c: " << 0.0 << "\n";
std::cout << "Final a: " << a << " b: " << b << " c: " << c << "\n";
system("pause");
return 0;
}
解算结果为:a=1 b=2 c=3
data:image/s3,"s3://crabby-images/13af5/13af5bae2c0643663a070e98d2e66c39938b7840" alt=""
四、曲线
假设曲线方程,对应的一组值为[0,1.1051; 1,1.4918; 2,2.0137 3,2.71828;4,3.66926 ]
#include "ceres/ceres.h"
#include "glog/logging.h"
using ceres::AutoDiffCostFunction;
using ceres::CostFunction;
using ceres::Problem;
using ceres::Solve;
using ceres::Solver;
//y=e^(mx+c)=e^(0.3x+0.1)
const int kNumObservations = 5;
// clang-format off
const double data[] = {
0,1.1051,
1,1.4918,
2,2.0137,
3,2.71828,
4,3.66926
};
//(1)构造代价函数
// clang-format on
//y=e^(mx+c)=e^(0.3x+0.1)
struct CostFunc_curve {
CostFunc_curve(double x, double y) : x_(x), y_(y) {}
template <typename T>
bool operator()(const T* const m, const T* const c, T* residual) const {
residual[0] = y_ - exp(m[0] * x_ + c[0]);
return true;
}
private:
const double x_;
const double y_;
};
int main(int argc, char** argv) {
// google::InitGoogleLogging(argv[0]);
double m = -1.0;
double c = 0.0;
Problem problem;
//(2)利用代价 函数构建优化问题
for (int i = 0; i < kNumObservations; ++i) {
problem.AddResidualBlock(
new AutoDiffCostFunction<CostFunc_curve, 1, 1, 1>(
new CostFunc_curve(data[2 * i], data[2 * i + 1])),
nullptr,
&m,
&c);
}
Solver::Options options;
options.max_num_iterations = 20;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;
Solver::Summary summary;
//(3)利用solver解算参数
Solve(options, &problem, &summary);
std::cout << summary.BriefReport() << "\n";
std::cout << "Initial m: " << -1.0 << " c: " << 0.0 << "\n";
std::cout << "Final m: " << m << " c: " << c << "\n";
system("pause");
return 0;
}
m=0.3 c=0.0999
data:image/s3,"s3://crabby-images/1ad66/1ad666a6c5e0dd7099f5363e1f552a9987308dc0" alt=""
小结:
相对来说,ceres解算已知函数模型,直接将观测值输入,利用solover函数解算即可,比较方便。