numpy 数组输入到 tensorflow/keras 神经网络的 dtype 有什么关系吗?

如果我采用 tensorflow.keras 模型并调用model.fit(x, y)(numpy 数组在哪里x),numpy 数组是y什么重要吗?dtype我最好只是尽可能dtype小(例如int8二进制数据),还是这会给 tensorflow/keras 额外的工作以将其转换为浮点数?



慕沐林林
浏览 66回答 1
1回答

慕田峪4524236

您应该将输入转换为np.float32,这是 Keras 的默认数据类型。查一下:import tensorflow as tftf.keras.backend.floatx()'float32'如果你给 Keras 输入 in np.float64,它会抱怨:import tensorflow as tffrom tensorflow.keras.layers import Dense from tensorflow.keras import Modelfrom sklearn.datasets import load_irisiris, target = load_iris(return_X_y=True)X = iris[:, :3]y = iris[:, 3]ds = tf.data.Dataset.from_tensor_slices((X, y)).shuffle(25).batch(8)class MyModel(Model):  def __init__(self):    super(MyModel, self).__init__()    self.d0 = Dense(16, activation='relu')    self.d1 = Dense(32, activation='relu')    self.d2 = Dense(1, activation='linear')  def call(self, x):    x = self.d0(x)    x = self.d1(x)    x = self.d2(x)    return xmodel = MyModel()_ = model(X)警告:tensorflow:Layer my_model 正在将输入张量从 dtype float64 转换为层的 dtype float32,这是 TensorFlow 2 中的新行为。该层具有 dtype float32,因为它的 dtype 默认为 floatx。如果你打算在 float32 中运行这个层,你可以安全地忽略这个警告。如果有疑问,如果您将 TensorFlow 1.X 模型移植到 TensorFlow 2,则此警告可能只是一个问题。要将所有层更改为默认 dtype float64,请调用tf.keras.backend.set_floatx('float64'). 要仅更改这一层,请将 dtype='float64' 传递给层构造函数。如果您是该层的作者,则可以通过将 autocast=False 传递给基础层构造函数来禁用自动转换。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python