test_train_split 将字符串类型标签转换为 np.array。

我有一个带有字符串类型标签名称的图像数据集。当我使用 sklearn 库的 test_train_split 拆分数据时,它将标签转换为 np.array 类型。有没有办法找回原来的字符串类型标签名称?


以下代码拆分数据以进行训练和测试:


imgs, y = load_images()

train_img,ytrain_img,test_img,ytest_img = train_test_split(imgs,y, test_size=0.2, random_state=1)

如果我打印 y,它会给我标签名称,但如果我打印拆分的标签值,它会给出一个数组:


for k in y:

    print(k)

    break

for k in ytrain_img:

    print(k)

    break

输出:


 001.Affenpinscher

 [[[ 97 180 165]

  [ 93 174 159]

  [ 91 169 152]

  ...

 [[ 88 171 156]

 [ 88 170 152]

 [ 84 162 145]

 ...

 [130 209 222]

 [142 220 233]

 [152 230 243]]


 [[ 99 181 163]

 [ 98 178 161]

 [ 92 167 151]

 ...

 [130 212 224]

 [137 216 229]

 [143 222 235]]

 ...

 [[ 85 147 158]

 [ 85 147 158]

 [111 173 184]

 ...

 [227 237 244]

 [236 248 250]

 [234 248 247]]


 [[ 94 154 166]

 [ 96 156 168]

 [133 194 204]

 ...

[226 238 244]

[237 249 253]

[237 252 254]]

...

[228 240 246]

[238 252 255]

[241 255 255]]]

有没有办法将数组转换回原始标签名称?


皈依舞
浏览 231回答 1
1回答

一只斗牛犬

不,您是在推断train_test_split错误的输出。train_test_split 以这种方式工作:A_train, A_test, B_train, B_test, C_train, C_test ...                              = train_test_split(A, B, C ..., test_size=0.2)您可以提供尽可能多的数组进行拆分。对于每个给定的数组,它将首先提供训练和测试拆分,然后对下一个数组执行相同操作,然后是第三个数组,依此类推。所以在你的情况下实际上是:train_img, test_img, ytrain_img, ytest_img = train_test_split(imgs, y,                                                               test_size=0.2,                                                               random_state=1)但是您随后混淆了输出的名称并错误地使用了它们。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python