图像生成器缺少 unet keras 的位置参数

我不断收到以下错误下面的代码,当我尝试训练模型:TypeError: fit_generator() missing 1 required positional argument: 'generator'。对于我的生活,我无法弄清楚是什么导致了这个错误。x_train 是一个形状为 (400, 256, 256, 3) 的 rgb 图像,对于 y_train,我有 10 个输出类使其具有形状 (400, 256, 256, 10)。这里出了什么问题?


如有必要,可通过以下链接下载数据:https : //www49.zippyshare.com/v/5pR3GPv3/file.html


import skimage

from skimage.io import imread, imshow, imread_collection, concatenate_images

from skimage.transform import resize

from skimage.morphology import label

import numpy as np

import matplotlib.pyplot as plt

from keras.models import Model

from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout

from keras.optimizers import Adam

from keras.callbacks import ModelCheckpoint, LearningRateScheduler

from keras import backend as K

from sklearn.metrics import jaccard_similarity_score

from shapely.geometry import MultiPolygon, Polygon

import shapely.wkt

import shapely.affinity

from collections import defaultdict

from keras.preprocessing.image import ImageDataGenerator

from keras.utils.np_utils import to_categorical

from keras import utils as np_utils

import os

from keras.preprocessing.image import ImageDataGenerator

gen = ImageDataGenerator()

#Importing image and labels

labels = skimage.io.imread("ede_subset_293_wegen.tif")

images = skimage.io.imread("ede_subset_293_20180502_planetscope.tif")[...,:-1]



#scaling image

img_scaled = images / images.max()


#Make non-roads 0

labels[labels == 15] = 0


#Resizing image and mask and labels

img_scaled_resized = img_scaled[:6400, :6400 ]

print(img_scaled_resized.shape)

labels_resized = labels[:6400, :6400]

print(labels_resized.shape)


#splitting images

split_img = [

    np.split(array, 25, axis=0) 

    for array in np.split(img_scaled_resized, 25, axis=1)

]


split_img[-1][-1].shape


#splitting labels

split_labels = [

    np.split(array, 25, axis=0) 

    for array in np.split(labels_resized, 25, axis=1)

]



当年话下
浏览 173回答 1
1回答

宝慕林4294392

您的代码中有一些错误,但考虑到您的错误:类型错误:fit_generator() 缺少 1 个必需的位置参数:'generator'这是因为 fit_generator 调用 XYaugmentGenerator 但内部没有调用增强生成器。gen.flow(...不会工作,因为未声明 gen。您应该将 image_datagen 重命名为 gen 为:gen = ImageDataGenerator(**data_gen_args)或者,用 image_datagen 替换 gengenX1 = image_datagen.flow(X1, y, batch_size=batch_size, seed=seed) genX2 = image_datagen.flow(y, X1, batch_size=batch_size, seed=seed)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python