Created .npy file should contain a 200 MB dataset, but is almost empty

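I am following this tutorial, whose main goal is to balance the data and save it to a second training data file (the first file contains the unbalanced data). Here is the code: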


import numpy as np
import pandas as pd
from collections import Counter
from random import shuffle

train_data = np.load('training_data.npy')

df = pd.DataFrame(train_data)
print(df.head())
print(Counter(df[1].apply(str)))

lefts = []
rights = []
forwards = []

shuffle(train_data)

for data in train_data:
    img = data[0]
    choice = data[1]

    if choice == [1,0,0]:
        lefts.append([img,choice])
    elif choice == [0,1,0]:
        forwards.append([img,choice])
    elif choice == [0,0,1]:
        rights.append([img,choice])
    else:
        print('no matches')

forwards = forwards[:len(lefts)][:len(rights)]
lefts = lefts[:len(forwards)]
rights = rights[:len(forwards)]

final_data = forwards + lefts + rights
shuffle(final_data)

np.save('training_data_v2.npy', final_data)

I really don't understand why it created a 120 B file when the dataset is 200 MB.


波斯汪
1 Answer

富国沪深

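So the main problem lies in these three lines:

forwards = forwards[:len(lefts)][:len(rights)]
lefts = lefts[:len(forwards)]
rights = rights[:len(forwards)]

You are truncating the lists. To confirm their final sizes, print the lengths before and after those lines:

print(len(forwards), len(lefts), len(rights))
# those 3 lines
print(len(forwards), len(lefts), len(rights))

You will see the difference. Also, try running the code without those three lines, and the array will be 200 MB :)

P.S. I suggest doing the truncation manually, e.g. forwards = forwards[:my_number], and so on.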
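The three slices cut every list down to the length of the shortest one, so if any single class ends up with no samples, all three lists become empty and the saved file holds little more than its header. The following is a minimal sketch of that effect, not the tutorial's code: the helper name `balance` and the tiny stand-in lists are hypothetical, and the case where no "right" samples were matched is only an assumption about what may have happened.

```python
def balance(*groups):
    """Truncate every group to the length of the shortest one."""
    n = min(len(g) for g in groups)
    return [g[:n] for g in groups], n

# Hypothetical stand-in data: assume no 'right' samples were matched at all.
forwards = [["img", [0, 1, 0]]] * 5
lefts = [["img", [1, 0, 0]]] * 3
rights = []

print(len(forwards), len(lefts), len(rights))  # sizes before truncation
(forwards, lefts, rights), n = balance(forwards, lefts, rights)
print(len(forwards), len(lefts), len(rights))  # sizes after truncation

if n == 0:
    # Saving at this point would write an almost empty file.
    print("At least one class is empty; check the label comparison first.")
```

Checking the shortest length before calling np.save makes this failure visible immediately instead of after inspecting the file size.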