sklearn 分层 k 折 CV 与线性模型,如 ElasticNetCV

使用交叉验证 (CV)sklearn非常简单直接。但是cv=5在线性 CV 模型中设置时的默认实现,例如ElasticNetCV或LassoCV是KFoldCV。出于各种原因,我想使用StratifiedKFold. 从文档来看,似乎任何CV 方法都可以用cv=.


传递cv=KFold(5)按预期工作,但cv=StratifiedKFold(5)会引发错误:


ValueError: 支持的目标类型是: ('binary', 'multiclass')。取而代之的是“连续”。


我知道我可以cross_val_score在拟合后使用,但我想StratifiedKFold作为 CV 直接传递给线性模型。


我的最低工作示例是:


from sklearn.linear_model import ElasticNetCV

from sklearn.model_selection import KFold, StratifiedKFold

import numpy as np


x = np.arange(100, dtype=np.float64).reshape(-1, 1)

y = np.arange(100) + np.random.rand(100)


# KFold default implementation:

model_default = ElasticNetCV(cv=5)

model_default.fit(x, y)  # works fine

# KFold given as cv explicitly:

model_kfexp = ElasticNetCV(cv=KFold(5))

model_kfexp.fit(x, y)  # also works fine


# StratifiedKFold given as cv explicitly:

model_skf = ElasticNetCV(cv=StratifiedKFold(5))

model_skf.fit(x, y)  # THIS RAISES THE ERROR

知道如何StratifiedKFold直接设置为 CV 吗?


杨__羊羊
浏览 260回答 1
1回答

婷婷同学_

你的问题的根源是这一行:y = np.arange(100) + np.random.rand(100)StratifiedKFold无法从连续分布中采样,因此您的错误。尝试更改这一行,您的代码将愉快地执行:from sklearn.linear_model import ElasticNetCVfrom sklearn.model_selection import KFold, StratifiedKFoldimport numpy as npx = np.arange(100, dtype=np.float64).reshape(-1, 1)y = np.random.choice([0,1], size=100)# KFold default implementation:model_default = ElasticNetCV(cv=5)model_default.fit(x, y)  # works fine# KFold given as cv explicitly:model_kfexp = ElasticNetCV(cv=KFold(5))model_kfexp.fit(x, y)  # also works fine# StratifiedKFold given as cv explicitly:model_skf = ElasticNetCV(cv=StratifiedKFold(5))model_skf.fit(x, y)  # no ERROR笔记如果您对连续数据进行采样,请使用KFold. 如果您的目标是明确的,您可以使用两者KFold并 使用StratifiedKFold适合您需要的任何一种。笔记2如果您坚持在连续数据上模拟分层抽样,您可能希望应用pandas.cut到您的数据,然后对该数据进行分层抽样,最后将结果(train_id, test_id)生成器传递给cvparam:x = np.arange(100, dtype=np.float64).reshape(-1, 1)y = np.arange(100) + np.random.rand(100)y_cat = pd.cut(y, 10, labels=range(10))skf_gen = StratifiedKFold(5).split(x, y_cat)model_skf = ElasticNetCV(cv=skf_gen)model_skf.fit(x, y)  # no ERROR
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python