После того, как я обучил некоторые веса CNN, я решил использовать ту же сетевую архитектуру для прогнозирования. Я установил свои данные batch_size = 64
.
Я могу правильно запустить функцию pred_net.forward()
и получить предсказанные классы из blobs['prob']
.
В моем наборе данных 20000 образцов. Если я вызову функцию forward()
i
раз, я получу 64*i
выборок, отправленных в сеть. Таким образом, я не могу охватить 20000 образцов, не отправив некоторые образцы дважды.
Поэтому я попробовал функцию forward_all()
. Но я получил исключение без какой-либо полезной информации. Я не знаю, что случилось.
Я ожидал, что forward()
и forward_all()
похожи(но нет).
Вот часть моего кода и сообщение об ошибке:
pred_net = caffe.Net(pred_net_proto_file, 'kg_trained.caffemodel', caffe.TEST)
pred_net.forward_all()
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-6-cefd35621a35> in <module>()
----> 1 pred_net.forward_all()
/home/microos/Space/caffe-master/python/caffe/pycaffe.pyc in _Net_forward_all(self, blobs, **kwargs)
197 all_outs[out] = np.asarray(all_outs[out])
198 # Discard padding.
--> 199 pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
200 if pad:
201 for out in all_outs:
StopIteration:
Надеюсь, я ясно описал ситуацию.