慕仙森
"""Plot two reward functions over a (distance, velocity) grid with ``imshow``.

``imshow`` can be used to plot a two-dimensional function. The x and y
directions are first created in 1-D, e.g. with ``np.linspace``, and then
combined into a 2-D grid with ``np.meshgrid``. NumPy then lets you write
simple expressions that operate on the entire grid at once, without
explicit Python loops.
"""

import numpy as np


def _terminal_reward(base, distance, velocity, threshold=0.1):
    """Return *base* with the cells closest to zero distance overridden.

    Where ``distance < threshold`` the reward becomes +1 if also
    ``velocity < threshold`` and -1 otherwise; every other cell keeps its
    value from *base*. All inputs broadcast against each other and a new
    array is returned (nothing is modified in place).
    """
    near = distance < threshold
    slow = velocity < threshold
    return np.where(near, np.where(slow, 1.0, -1.0), base)


def distance_reward(distance, velocity):
    """Reward that depends only on distance: ``1 - distance ** 0.4``.

    Cells with ``distance < 0.1`` are overridden: +1 when ``velocity < 0.1``
    as well, -1 otherwise.

    Args:
        distance: array-like of distances, broadcastable with *velocity*.
        velocity: array-like of velocities.

    Returns:
        ``np.ndarray`` of rewards with the broadcast shape of the inputs
        (0-d for scalar inputs).
    """
    distance = np.asarray(distance, dtype=float)
    velocity = np.asarray(velocity, dtype=float)
    return _terminal_reward(1 - distance ** 0.4, distance, velocity)


def discounted_reward(distance, velocity):
    """Distance reward scaled by a velocity-dependent discount.

    The reward is ``(1 - max(velocity, 0.1)) ** (1 / max(distance, 0.1))``
    multiplied by ``1 - distance ** 0.4``. Cells with ``distance < 0.1`` are
    overridden exactly as in :func:`distance_reward`.

    Args:
        distance: array-like of distances, broadcastable with *velocity*.
        velocity: array-like of velocities.

    Returns:
        ``np.ndarray`` of rewards with the broadcast shape of the inputs.
    """
    distance = np.asarray(distance, dtype=float)
    velocity = np.asarray(velocity, dtype=float)
    dist_reward = 1 - distance ** 0.4
    # Flooring velocity at 0.1 caps the base at 0.9; flooring distance at 0.1
    # avoids a division by zero and caps the exponent at 10.
    vel_discount = (1 - np.maximum(velocity, 0.1)) ** (1 / np.maximum(distance, 0.1))
    return _terminal_reward(vel_discount * dist_reward, distance, velocity)


def _plot_reward(ax, reward):
    """Draw *reward* on *ax* as a heat map over [0, 1] x [0, 1] with a colorbar."""
    image = ax.imshow(
        reward,
        # Row 0 of the meshgrid (velocity = 0) goes at the bottom. Matplotlib
        # only accepts 'upper' or 'lower' here.
        origin='lower',
        extent=[0, 1, 0, 1],
        vmin=-1,
        vmax=1,
        cmap='bwr',
        interpolation='nearest',
    )
    ax.set_xlabel('distance')
    ax.set_ylabel('velocity')
    ax.figure.colorbar(image, ax=ax, shrink=0.9)


def main(n_points=50):
    """Compute both rewards on an ``n_points`` x ``n_points`` grid and show them side by side."""
    # Imported locally so the reward functions above need only NumPy.
    from matplotlib import pyplot as plt

    axis = np.linspace(0, 1, n_points)
    # meshgrid(x, y): distance varies along columns, velocity along rows.
    distance, velocity = np.meshgrid(axis, axis)

    # One row of two panels, so the gap between them is wspace;
    # tight_layout() below recomputes the spacing anyway.
    _, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4), gridspec_kw={'wspace': 0.05})
    _plot_reward(ax1, distance_reward(distance, velocity))
    _plot_reward(ax2, discounted_reward(distance, velocity))
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()