Skip to content

Commit

Permalink
Rewrote shownet so that it extends the model class now, making it eas…
Browse files Browse the repository at this point in the history
…y to run the model from that script. Added ability to show predictions made by the net.
  • Loading branch information
akrizhevsky committed Oct 3, 2011
1 parent 62149b5 commit abd3461
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 171 deletions.
7 changes: 4 additions & 3 deletions convdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@
class CIFARDataProvider(LabeledMemoryDataProvider):
def __init__(self, data_dir, batch_range, init_epoch=1, init_batchnum=None, dp_params={}, test=False):
LabeledMemoryDataProvider.__init__(self, data_dir, batch_range, init_epoch, init_batchnum, dp_params, test)
data_mean = self.batch_meta['data_mean']

self.data_mean = self.batch_meta['data_mean']
self.num_colors = 3
# Subtract the mean from the data and make sure that both data and
# labels are in single-precision floating point.
for d in self.data_dic:
# This converts the data matrix to single precision and makes sure that it is C-ordered
d['data'] = n.require((d['data'] - data_mean), dtype=n.single, requirements='C')
d['data'] = n.require((d['data'] - self.data_mean), dtype=n.single, requirements='C')
d['labels'] = n.require(d['labels'].reshape((1, d['data'].shape[1])), dtype=n.single)

def get_next_batch(self):
Expand All @@ -55,6 +55,7 @@ def __init__(self, data_dir, batch_range=None, init_epoch=1, init_batchnum=None,
self.multiview = dp_params['multiview_test'] and test
self.num_views = 9
self.data_mult = self.num_views if self.multiview else 1
self.num_colors = 3

for d in self.data_dic:
d['data'] = n.require(d['data'], requirements='C')
Expand Down
16 changes: 8 additions & 8 deletions convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def init_model_state(self):
else:
ms['layers'] = LayerParser.parse_layers(self.layer_def, self.layer_params, self)

if self.op.get_value('multiview_test'):
logreg_name = self.op.get_value('logreg_name')
logreg_name = self.op.get_value('logreg_name')
if logreg_name:
try:
self.logreg_idx = [l['name'] for l in ms['layers']].index(logreg_name)
if ms['layers'][self.logreg_idx]['type'] != 'cost.logreg':
Expand Down Expand Up @@ -136,7 +136,7 @@ def get_options_parser(cls):
op.add_option("layer-def", "layer_def", StringOptionParser, "Layer definition file", set_once=True)
op.add_option("layer-params", "layer_params", StringOptionParser, "Layer parameter file")
op.add_option("check-grads", "check_grads", BooleanOptionParser, "Check gradients and quit?", default=0, excuses=['data_path','save_path','train_batch_range','test_batch_range'])
op.add_option("multiview-test", "multiview_test", BooleanOptionParser, "Cropped DP: test on multiple patches?", default=0)
op.add_option("multiview-test", "multiview_test", BooleanOptionParser, "Cropped DP: test on multiple patches?", default=0, requires=['logreg_name'])
op.add_option("crop-border", "crop_border", IntegerOptionParser, "Cropped DP: crop border size", default=4)
op.add_option("logreg-name", "logreg_name", StringOptionParser, "Cropped DP: logreg layer name", default="")

Expand All @@ -146,16 +146,16 @@ def get_options_parser(cls):
op.options["num_epochs"].default = 50000
op.options['dp_type'].default = None

DataProvider.register_data_provider('cifar', 'CIFAR', CIFARDataProvider)
DataProvider.register_data_provider('dummy-cn-n', 'Dummy ConvNet', DummyConvNetDataProvider)
DataProvider.register_data_provider('cifar-cropped', 'Cropped CIFAR', CroppedCIFARDataProvider)

return op

if __name__ == "__main__":
#nr.seed(5)
op = GPUModel.get_options_parser()

DataProvider.register_data_provider('cifar', 'CIFAR', CIFARDataProvider)
DataProvider.register_data_provider('dummy-cn-n', 'Dummy ConvNet', DummyConvNetDataProvider)
DataProvider.register_data_provider('cifar-cropped', 'Cropped CIFAR', CroppedCIFARDataProvider)


op, load_dic = IGPUModel.parse_options(op)
model = GPUModel("ConvNet", op, load_dic)
model.start()
1 change: 1 addition & 0 deletions include/common/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class Matrix {
* Use transpose() if you want to get the transpose of this matrix.
*/
inline void setTrans(bool trans) {
assert(!isView());
_trans = trans ? CblasTrans : CblasNoTrans;
}

Expand Down
5 changes: 1 addition & 4 deletions include/pyconvnet.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@
#define _QUOTEME(x) #x
#define QUOTEME(x) _QUOTEME(x)

#ifdef EXEC
int main(int argc, char** argv);
#else
extern "C" void INITNAME();

PyObject* initModel(PyObject *self, PyObject *args);
Expand All @@ -41,7 +38,7 @@ PyObject* finishBatch(PyObject *self, PyObject *args);
PyObject* checkGradients(PyObject *self, PyObject *args);
PyObject* syncWithHost(PyObject *self, PyObject *args);
PyObject* startMultiviewTest(PyObject *self, PyObject *args);
#endif
PyObject* startLabeler(PyObject *self, PyObject *args);

#endif /* PYCONVNET3_CUH */

10 changes: 10 additions & 0 deletions include/worker.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,15 @@ public:
void run();
};

class LabelWorker : public Worker {
protected:
CPUData* _data;
Matrix* _preds;
int _logregIdx;
public:
LabelWorker(ConvNet& convNet, CPUData& data, Matrix& preds, int logregIdx);
void run();
};

#endif /* WORKER_CUH */

2 changes: 1 addition & 1 deletion options.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def parse(self, eval_expr_defaults=False):
# check requirements
if o.prefixed_letter in dic:
for o2 in self.get_options_list(sort_order=self.SORT_LETTER):
if o2.name in o.requires and o2.prefixed_letter not in dic and o2.default is None:
if o2.name in o.requires and o2.prefixed_letter not in dic:
raise OptionMissingException("Option %s (%s) requires option %s (%s)" % (o.prefixed_letter, o.desc,
o2.prefixed_letter, o2.desc))
if eval_expr_defaults:
Expand Down
Loading

0 comments on commit abd3461

Please sign in to comment.