“无效的分类数据:期望标签值”

我正在尝试在 java 中使用深度学习来训练模型,当我开始训练训练数据时它会出错


Invalid classification data: expect label value (at label index column = 0) to be in range 0 to 1 inclusive (0 to numClasses-1, with numClasses=2); got label value of 2

我不明白这个错误,因为我是深度学习 4j 的初学者。我正在使用一个查看两个人之间关系的数据集(如果两个人之间存在关系,那么类标签将为 1,否则为 0)。


Java 代码


public class SNA {

private static Logger log = LoggerFactory.getLogger(SNA.class);


public static void main(String[] args) throws Exception {

    int seed = 123;

    double learningRate = 0.01;

    int batchSize = 50;

    int nEpochs = 30;

    int numInputs = 2;

    int numOutputs = 2;

    int numHiddenNodes = 20;


    //load the training data

    RecordReader rr = new CSVRecordReader(0,",");

    rr.initialize(new FileSplit(new File("C:\\Users\\GTS\\Desktop\\SNA project\\experiments\\First experiment\\train\\slashdotTrain.csv")));

    DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize,0, 2);


    // load test data

    RecordReader rrTest = new CSVRecordReader();

    rr.initialize(new FileSplit(new File("C:\\Users\\GTS\\Desktop\\SNA project\\experiments\\First experiment\\test\\slashdotTest.csv")));

    DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize,0, 2);


    log.info("**** Building Model ****");

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

            .seed(seed)

            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)

            .iterations(1)

            .learningRate(learningRate)

            .updater(Updater.NESTEROVS).momentum(0.9)

            .list()

            .layer(0, new DenseLayer.Builder()

                    .nIn(numInputs)

                    .nOut(numHiddenNodes)

                    .activation("relu")

                    .weightInit(WeightInit.XAVIER)

                    .build())

    }

}

}


有什么帮助吗?多谢


一只名叫tom的猫
浏览 149回答 1
1回答

繁花如伊

解决问题:将RecordReaderDataSetIteratorin的第三个参数DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize,0, 2);由0改为2;因为数据集有三列,类标签的索引是 2,因为它是第三列。解决方案:DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize,2, 2);
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Java