猿问

图层权重形状 (1, 1) 与为 keras 模型提供的权重形状 (1,) 不兼容

我使用 Keras 训练了一个模型,但忘记保存模型。该模型是开发了许多其他模型的项目的一部分,但现在我无法继续该项目。幸运的是,我保存了初始和最终训练重量。现在,我正在尝试创建一个具有相同最终权重的模型来获得预测。我正在编译 keras 模型并使用函数 model.set_weights 将丢失模型的最终训练权重设置为新模型。这是代码。


model = Sequential()

model.add(Dense(1,input_dim = 1, activation = 'relu'))

model.add(Dense(1, activation = 'relu'))

model.compile(loss = 'mean_squared_error', optimizer = 'Adam', metrics = ['mse'])

listOfNumpyArrays = [np.array([0.2]),np.array([0.2])]

listOfNumpyArrays1 = listOfNumpyArrays

model.layers[0].set_weights(listOfNumpyArrays)

model.layers[1].set_weights(listOfNumpyArrays1)

追溯


ValueError                                Traceback (most recent call last)

<ipython-input-31-e63437554e30> in <module>()

----> 1 model.layers[0].set_weights(listOfNumpyArrays)

      2 model.layers[1].set_weights(listOfNumpyArrays1)

1 frames

/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py in set_weights(self, weights)

   1124                                  str(pv.shape) +

   1125                                  ' not compatible with '

-> 1126                                  'provided weight shape ' + str(w.shape))

   1127             weight_value_tuples.append((p, w))

   1128         K.batch_set_value(weight_value_tuples)

ValueError: Layer weight shape (1, 1) not compatible with provided weight shape (1,)


呼如林
浏览 126回答 1
1回答

慕的地6264312

您使用创建的 numpy 数组np.array([0.2])有一个形状(1,),而您的权重数组有一个形状(1,1)。虽然它们存储相同数量的数据,但 numpy 将它们视为不同的形状。您可以通过执行以下操作来解决此问题:代替:listOfNumpyArrays = [np.array([0.2]),np.array([0.2])]使用:listOfNumpyArrays = [np.empty(shape = (1,1), dtype = np.float32), np.empty(shape = (1,1), dtype = np.float32)]listOfNumpyArrays[0][0] = 0.2listOfNumpyArrays[1][0] = 0.2无关的说明:在这一行中:listOfNumpyArrays1 = listOfNumpyArrays看起来您想创建两个不同的 numpy 数组列表,它们被初始化为相同的值。listOfNumpyArrays1但是,实际上将引用与 相同的列表listOfNumpyArrays。因此,当您执行set_weightson时listOfNumpyArrays1,它也会修改listOfNumpyArrays。要在创建两个不同的列表时将它们初始化为相同的值,可以使用以下代码:listOfNumpyArrays1 = [np.copy(listOfNumpyArrays[0]), np.copy(listOfNumpyArrays[1])]np.copy创建一个新数组,它是您传递的数组的副本。这可以使用列表理解以更 pythonic 的方式编写,如下所示:listOfNumpyArrays1 = [np.copy(x) for x in listOfNumpyArrays]
随时随地看视频慕课网APP

相关分类

Python
我要回答