1.5 神经元实现
分拆数据集
def load_data(filename):
    """Read one CIFAR-10 python-format batch file.

    Args:
        filename: path to a pickled batch file.

    Returns:
        A ``(data, labels)`` pair taken from the ``b'data'`` and
        ``b'labels'`` entries of the unpickled dict.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file. Only load batch files from a trusted source.
    with open(filename, 'rb') as f:
        # encoding='bytes' makes the dict keys come back as bytes (hence the
        # b'data' / b'labels' lookups below); presumably the files were
        # pickled under Python 2 -- confirm against the dataset docs.
        data = pickle.load(f, encoding='bytes')
    return data[b'data'], data[b'labels']


# Minimal in-memory mini-batch iterator, similar in spirit to TensorFlow's DataSet.
class CifarData:
    """Keep only the label-0 and label-1 examples from CIFAR batch files and
    serve them as mini-batches.

    Args:
        filenames: iterable of batch-file paths readable by ``load_data``.
        need_shuffle: if True, shuffle on construction and reshuffle and
            restart whenever a batch would run past the end. If False,
            iterate once in file order and raise when exhausted.

    Raises:
        ValueError: if none of the files contain an example labelled 0 or 1.
    """

    def __init__(self, filenames, need_shuffle):
        all_data = []
        all_labels = []
        for filename in filenames:
            data, labels = load_data(filename)
            for item, label in zip(data, labels):
                # Only labels 0 and 1 are kept, giving a binary task.
                if label in (0, 1):
                    all_data.append(item)
                    all_labels.append(label)
        if not all_data:
            # np.vstack([]) would raise an opaque ValueError; say why instead.
            raise ValueError(
                "no examples with label 0 or 1 found in %r" % (filenames,))
        self._data = np.vstack(all_data)
        # Scale pixel values from [0, 255] to [-1, 1] (x / 127.5 - 1).
        self._data = self._data / 127.5 - 1
        self._labels = np.hstack(all_labels)
        self._num_examples = self._data.shape[0]
        self._need_shuffle = need_shuffle
        self._indicator = 0  # index of the next example to hand out
        if self._need_shuffle:
            self._shuffle_data()

    def _shuffle_data(self):
        """Apply one random permutation to data and labels together."""
        # e.g. [0,1,2,3,4,5] -> [2,1,4,0,3,5]; the same index array keeps
        # each example aligned with its label.
        p = np.random.permutation(self._num_examples)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples as ``(data, labels)``.

        If the remaining examples cannot fill a batch, a shuffling dataset
        reshuffles and restarts from the beginning. The leftover tail of the
        previous pass is skipped. A non-shuffling dataset raises instead.

        Raises:
            Exception: if the dataset is exhausted and ``need_shuffle`` is
                False, or if ``batch_size`` exceeds the number of examples.
        """
        end_indicator = self._indicator + batch_size
        if end_indicator > self._num_examples:
            if not self._need_shuffle:
                raise Exception("have no more examples")
            self._shuffle_data()
            self._indicator = 0
            end_indicator = batch_size
        if end_indicator > self._num_examples:
            raise Exception("batch size is larger than all examples")
        batch_data = self._data[self._indicator:end_indicator]
        batch_labels = self._labels[self._indicator:end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_labels


train_filename = [os.path.join(CIFAR_DIR, 'data_batch_%d' % i)
                  for i in range(1, 6)]
test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]
train_data = CifarData(train_filename, True)
test_data = CifarData(test_filenames, False)
# Smoke test: draw one small batch to make sure loading worked.
batch_data, batch_labels = train_data.next_batch(10)
测试算法准确率
init = tf.global_variables_initializer()
batch_size = 20
train_steps = 100000
test_steps = 100

with tf.Session() as sess:
    sess.run(init)
    for i in range(train_steps):
        batch_data, batch_labels = train_data.next_batch(batch_size)
        loss_val, acc_val, _ = sess.run(
            [loss, accuracy, train_op],
            feed_dict={x: batch_data, y: batch_labels})
        if (i + 1) % 500 == 0:
            print('[Train] Step: %d, loss: %4.5f, acc: %4.5f'
                  % (i + 1, loss_val, acc_val))
        if (i + 1) % 5000 == 0:
            # Rebuild the non-shuffling test set so every evaluation starts
            # from the first example. A non-shuffling CifarData raises once
            # exhausted, so it cannot be reused across evaluations.
            test_data = CifarData(test_filenames, False)
            all_test_acc_val = []
            # test_steps * batch_size = 2000 examples; presumably the whole
            # two-class test subset -- confirm if either value changes.
            for _ in range(test_steps):
                test_batch_data, test_batch_labels = test_data.next_batch(
                    batch_size)
                # Fetch the scalar directly rather than a one-element list.
                test_acc_val = sess.run(
                    accuracy,
                    feed_dict={x: test_batch_data, y: test_batch_labels})
                all_test_acc_val.append(test_acc_val)
            test_acc = np.mean(all_test_acc_val)
            print('[Test ] Step: %d, acc: %4.5f' % (i + 1, test_acc))
作者:Meet相识_bfa5
链接:https://www.jianshu.com/p/fed88fcd3428