Scikit-learn Pipeline:测试集的预测大小等于训练集的大小

我正在尝试对测试数据集进行预测。我正在使用带有 MLPRegressor 的 Sklearn 管道。但是,即使我使用“test.csv”,我也只是从训练集中获得预测的大小。


我在哪里可以修改以获得长度与测试数据相同的预测?


train_pipeline.py

# Read training data

data = pd.read_csv(data_path, sep=';', low_memory=False, parse_dates=parse_dates)


# Fill all None records

data[config.TARGET] = data[config.TARGET].fillna(0)

data[config.TARGET] = data[config.TARGET].apply(lambda x: split_join_string(x) if (type(x) == str and len(x.split('.')) > 0) else x)


# Divide train and test

X_train, X_test, y_train, y_test = train_test_split(

    data[config.FEATURES],

    data[config.TARGET],

    test_size=0.1,

    random_state=0)  # we are setting the seed here


# Transform the target

y_train = y_train.apply(lambda x: np.log(float(x)) if x != 0 else 0)

y_test = y_test.apply(lambda x: np.log(float(x)) if x != 0 else 0)


data_test = pd.concat([X_test, y_test], axis=1)

# Save the dataset to a '.csv' file without index

data_test.to_csv(data_path_test, sep=';', index=False)


pipeline.order_pipe.fit(X_train[config.FEATURES],

                        y_train)


save_pipeline(pipeline_to_persist=pipeline.order_pipe)

预测.py

def make_prediction(*, input_data) -> dict:

    """Make a prediction using the saved model pipeline."""


    data = pd.DataFrame(input_data)

    validated_data = validate_inputs(input_data=data)


    prediction = _order_pipe.predict(validated_data[config.FEATURES])

    output = np.exp(prediction)


    #score = _order_pipe.score(validated_data[config.FEATURES], validated_data[config.TARGET])

    results = {'predictions': output, 'version': _version}


    _logger.info(f'Making predictions with model version:  {_version}'

            f'\nInputs:  {validated_data}'

            f'\nPredictions: {results}')


    return results

我希望预测的大小为“test.csv”,但实际预测的大小为“train.csv”。我是否需要将测试数据集拟合或转换为“order_pipe”以做出正确大小的预测?


慕森卡
浏览 108回答 1
1回答

婷婷同学_

我解决了这个问题,删除了一个导致 X_test 大小崩溃的预处理器。因此,X_test 被 X_train 取代,我无法做出正确的预测。此外,还有另一个预处理器(使用 pd.get_dummies() 创建虚拟对象)插入新列并在 X_test 预测期间带来更多问题。groupby()我还替换了那个预处理器,使用and对分类特征进行编码map()。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python