慕姐4208626
对我来说,你的代码看起来完全正确!至少算法是正确的。我已经更改了您的代码以用于numpy快速计算而不是纯Python。另外,我还配置了一些参数,例如改变了动量和学习率,也实现了MSE。然后我用来matplotlib画情节动画。最后,在动画上,看起来您的回归实际上试图将曲线拟合到数据。尽管在最后一次拟合迭代中它sin(x)看起来像线性近似,但仍然尽可能接近二次曲线的数据点。但对于for in来说,它看起来像是理想的近似(它从迭代周围开始拟合)。x[0; 2 * pi]sin(x)x[0; pi]12-thi-th动画帧只是用 进行回归dErr = 0.7 ** (i + 15)。我的动画运行脚本有点慢,但是如果您save像这样添加参数python script.py save,它将渲染/保存以line.gif绘制绘图动画。如果您在没有参数的情况下运行脚本,它将在您的 PC 屏幕上实时绘制/拟合动画。完整的代码在图形之后,代码需要通过运行一次安装一些Python模块python -m pip install numpy matplotlib。接下来是sin(x)在x:(0, pi)接下来是sin(x)在x:(0, 2 * pi)接下来是abs(x)在x:(-1, 1)# Needs: python -m pip install numpy matplotlibimport math, sysimport numpy as np, matplotlib.pyplot as plt, matplotlib.animation as animationfrom matplotlib.animation import FuncAnimationx_range = (0., math.pi, 0.1) # (xmin, xmax, xstep)y_range = (-0.2, 1.2) # (ymin, ymax)num_iterations = 50def f(x): return np.sin(x)def derr(iteration): return 0.7 ** (iteration + 15) def MSE(a, b): return (np.abs(np.array(a) - np.array(b)) ** 2).mean()def quadraticRegression(*, x, data, dErr): x, data = np.array(x), np.array(data) assert x.size == data.size, (x.size, data.size) a = 1 #Starting values b = 1 c = 1 a_momentum = 0.1 #Momentum to counter steady state error b_momentum = 0.1 c_momentum = 0.1 estimate = a*x**2 + b*x + c #Estimate curve error = MSE(data, estimate) #Get errors 'n stuff errorOld = 0. lr = 10. ** -4 #learning rate while abs(error - errorOld) > dErr: #Fit a (dE/da) deda = np.sum(2*x**2 * (a*x**2 + b*x + c - data))/len(data) correction = deda*lr a_momentum = (a_momentum)*0.99 + correction*0.1 #0.99 is to slow down momentum when correction speed changes a = a - correction - a_momentum #fit b (dE/db) dedb = np.sum(2*x*(a*x**2 + b*x + c - data))/len(data) correction = dedb*lr b_momentum = (b_momentum)*0.99 + correction*0.1 b = b - correction - b_momentum #fit c (dE/dc) dedc = np.sum(2*(a*x**2 + b*x + c - data))/len(data) correction = dedc*lr c_momentum = (c_momentum)*0.99 + correction*0.1 c = c - correction - c_momentum #Update model and find errors estimate = a*x**2 +b*x + c errorOld = error #print(error) error = MSE(data, estimate) return a, b, c, error fig, ax = plt.subplots()fig.set_tight_layout(True)x = np.arange(x_range[0], x_range[1], x_range[2])#ax.scatter(x, x + np.random.normal(0, 3.0, len(x)))line0, line1 = None, Nonedo_save = len(sys.argv) > 1 and sys.argv[1] == 'save'def g(x, derr): a, b, c, error = quadraticRegression(x = x, data = f(x), dErr = derr) return a * x ** 2 + b * x + c def dummy(x): return np.ones_like(x, dtype = np.float64) * 100.def update(i): global line0, line1 de = derr(i) if line0 is None: assert line1 is None line0, = ax.plot(x, f(x), 'r-', linewidth=2) line1, = ax.plot(x, g(x, de), 'r-', linewidth=2, color = 'blue') ax.set_ylim(y_range[0], y_range[1]) if do_save: sys.stdout.write(str(i) + ' ') sys.stdout.flush() label = 'iter {0} derr {1}'.format(i, round(de, math.ceil(-math.log(de) / math.log(10)) + 2)) line1.set_ydata(g(x, de)) ax.set_xlabel(label) return line1, axif __name__ == '__main__': anim = FuncAnimation(fig, update, frames = np.arange(0, num_iterations), interval = 200) if do_save: anim.save('line.gif', dpi = 200, writer = 'imagemagick') else: plt.show()