文章目录
- 问题如图所示
- 运行结果如图
- 代码分析
- 完整代码
- 完结撒花
问题如图所示
运行结果如图
代码分析
% 定义样本数量
n = 500;
这行代码定义了一个变量 n
,它代表样本数量。这个变量在后面的代码中会被用到。
% 将 s 和 z 取值范围分成子区间的个数
num_intervals = 40;
这行代码定义了一个变量 num_intervals
,它代表将 s 和 z 取值范围分成的子区间个数。这个变量在后面的代码中也会被用到。
% 将取值范围 [0,1] 和 [-1,1] 等分为 num_intervals+1 个子区间,存储在一维数组 s_values 和 z_values 中
s_values = linspace(0, 1, num_intervals + 1);
这行代码利用 linspace
函数将取值范围 [0, 1]
等分为 num_intervals+1
个子区间,并将每个子区间的左端点作为一个 num_intervals+1
长度的一维数组 s_values
中的元素。
% 将取值范围 [0,1] 和 [-1,1] 等分为 num_intervals+1 个子区间,存储在一维数组 s_values 和 z_values 中z_values = linspace(-1, 1, num_intervals + 1);
这行代码利用 linspace
函数将取值范围 [-1, 1]
等分为 num_intervals+1
个子区间,并将每个子区间的左端点作为一个 num_intervals+1
长度的一维数组 z_values
中的元素。
% 输出 s 和 z 的取值范围
fprintf('s ranges from %.2f to %.2f\n', s_values(1), s_values(end));
fprintf('z ranges from %.2f to %.2f\n', z_values(1), z_values(end));
这行代码分别输出了 s
和 z
的取值范围,使用了 fprintf
函数对字符串进行格式化输出。
% 定义 r 和 g 函数
r = @(s) n / 4 * s;
g = @(z, s) 1.03 * (1 - exp(-1 * s)) * (1 + z.^2);
这两行代码定义了两个函数,分别为 r
和 g
。其中,r
是关于参数 s
的一次函数,g
是关于参数 z
和 s
的一次函数。这里用到了匿名函数的语法。
% 定义 r 和 g 的一阶、二阶导数
r_prime = n / 4;
g_second = @(a, s, z) (1.03 * exp(-s) * (z.^2 - 1)) / ((1.03 * (1 - exp(-s)) + a).^3);
这两行代码定义了 r
和 g
的一阶、二阶导数。r_prime
是一个常数,等于 n/4
;g_second
是一个与 s
和 z
相关的函数,用了一个 lambda 表达式进行定义。
% 初始化 a_s_z 矩阵
a_s_z = zeros(num_intervals + 1);
这行代码初始化了一个 num_intervals+1
行、num_intervals+1
列的零矩阵 a_s_z
,它将用来存储各个 s
和 z
取值下求得的最小值点处的 a
值。
% 对每个区间端点使用梯度下降法计算最小值点处 a 的值for i = 1 : num_intervals + 1
for j = 1 : num_intervals + 1
% 计算在 a=0 时的 g''(n*a) 值
g_second_0 = g_second(0, s_values(i), z_values(j));
a = 0;
% 进行梯度下降迭代,根据公式更新 a 直到收敛
while true
grad_s = r_prime * s_values(i) + g_second(a, s_values(i), z_values(j)) * (1.03 * exp(-s_values(i)) - a - 1.03);
grad_z = 2 * z_values(j) * g_second(a, s_values(i), z_values(j));
% 根据 i 和 j 的值判断更新 a 的方式
if i == 1 && j == 1
a = a - 0.0005 * grad_s;
elseif i == 1
a = a - 0.0005 * grad_z;
elseif j == 1
a = a - 0.0005 * grad_s;
else
a = a - 0.0005 * (grad_s + grad_z);
end
% 检查是否收敛
if abs(g_second(a, s_values(i), z_values(j))) < 1e-10
break
end
end
% 将求得的 a 值存储到 a_s_z 矩阵中
a_s_z(i, j) = a;
end
end
这部分代码是最主要的部分,它使用了梯度下降法来求解每个子区间端点处对应的最小值点 a
。具体来说,对于矩阵中的每个元素 a_s_z(i,j)
,首先计算在 a=0
时的 g''(n*a)
值,然后进行梯度下降迭代,根据公式更新 a
直到收敛(即 g''(n*a)
的绝对值小于一个很小的数)。在每次更新 a
时,需要分别计算在 s
方向和 z
方向上的梯度并进行更新,具体涉及到一些判断语句,因为对于矩阵中的每个边界点,梯度计算方式不同。最终,每个子区间端点处求得的 a
值都储存在 a_s_z
矩阵中。
% 使用 mesh 函数将 a_s_z 矩阵可视化为一个三维网格图
figure;
[X, Y] = meshgrid(s_values, z_values);
mesh(X, Y, a_s_z');
xlabel('s');
ylabel('z');
zlabel('a');
title('Mesh plot of a(s,z)');
这行代码使用 mesh
函数将 a_s_z
矩阵可视化为一个三维网格图。使用 meshgrid
函数生成一组坐标点 X
和 Y
,然后将 a_s_z
矩阵的转置作为纵坐标值,传入 mesh
函数中即可。最后,添加坐标轴标签和图标题,完成可视化。
完整代码
% 定义样本数量
n = 500;
% 将 s 和 z 取值范围分成子区间的个数
num_intervals = 40;
% 将取值范围 [0,1] 和 [-1,1] 等分为 num_intervals+1 个子区间,存储在一维数组 s_values 和 z_values 中
s_values = linspace(0, 1, num_intervals + 1);
z_values = linspace(-1, 1, num_intervals + 1);
% 输出 s 和 z 的取值范围
fprintf('s 范围从 %.2f 到 %.2f\n', s_values(1), s_values(end));
fprintf('z 范围从 %.2f 到 %.2f\n', z_values(1), z_values(end));
% 定义 r 和 g 函数
r = @(s) n / 4 * s;
g = @(z, s) 1.03 * (1 - exp(-1 * s)) * (1 + z.^2);
% 定义 r 和 g 的一阶、二阶导数
r_prime = n / 4;
g_second = @(a, s, z) (1.03 * exp(-s) * (z.^2 - 1)) / ((1.03 * (1 - exp(-s)) + a).^3);
% 初始化 a_s_z 矩阵
a_s_z = zeros(num_intervals + 1);
% 对每个区间端点使用梯度下降法计算最小值点处 a 的值
for i = 1 : num_intervals + 1
for j = 1 : num_intervals + 1
% 计算在 a=0 时的 g''(n*a) 值
g_second_0 = g_second(0, s_values(i), z_values(j));
a = 0;
% 进行梯度下降迭代,根据公式更新 a 直到收敛
while true
grad_s = r_prime * s_values(i) + g_second(a, s_values(i), z_values(j)) * (1.03 * exp(-s_values(i)) - a - 1.03);
grad_z = 2 * z_values(j) * g_second(a, s_values(i), z_values(j));
% 根据 i 和 j 的值判断更新 a 的方式
if i == 1 && j == 1
a = a - 0.0005 * grad_s;
elseif i == 1
a = a - 0.0005 * grad_z;
elseif j == 1
a = a - 0.0005 * grad_s;
else
a = a - 0.0005 * (grad_s + grad_z);
end
% 检查是否收敛
if abs(g_second(a, s_values(i), z_values(j))) < 1e-10
break
end
end
% 将求得的 a 值存储到 a_s_z 矩阵中
a_s_z(i, j) = a;
end
end
% 使用 mesh 函数将 a_s_z 矩阵可视化为一个三维网格图
figure;
[X, Y] = meshgrid(s_values, z_values);
mesh(X, Y, a_s_z');
xlabel('s');
ylabel('z');
zlabel('a');
title('a(s,z) 的网格图');