Commit 2dbb3e55 authored by Marius Kriegerowski

refactor hyperparameter optimization

parent 7214aa65
@@ -20,7 +20,7 @@ def delete_if_exists(dirname):

 class Model(Object):
-    optimizer = Optimizer.T(optional=True)
+    hyperparameter_optimizer = Optimizer.T(optional=True)
     data_generator = DataGeneratorBase.T()
     dropout_rate = Float.T(optional=True)
     batch_size = Int.T(default=10)
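Note on the declaration pattern above: the `.T(...)` attributes look like pyrocko.guts typed properties, where `Float.T(optional=True)` builds a descriptor that is validated and serialized to/from YAML, and `optional=True` allows the field to be absent (None) in the config file. A minimal sketch of that style, assuming pyrocko.guts is the library in use (the class and attribute names below are illustrative, not from this repo):

import the declaration helpers and define a small config object:

    from pyrocko.guts import Object, Float, Int

    class ExampleConfig(Object):
        learning_rate = Float.T(default=1e-3)   # has a default, always set
        dropout_rate = Float.T(optional=True)   # may be None / absent in YAML
        batch_size = Int.T(default=10)

    config = ExampleConfig()
    config.regularize()   # validates and normalizes attribute values
    print(config)         # guts Objects print as their YAML representation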
@@ -180,10 +180,9 @@ class Model(Object):
         return result

     def optimize(self):
-        if self.optimizer is None:
-            print('No optimizer defined')
-            sys.exit()
-        self.optimizer.optimize(self)
+        if self.hyperparameter_optimizer is None:
+            sys.exit('No hyperparameter optimizer defined in config file')
+        self.hyperparameter_optimizer.optimize(self)


 def main():
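The change in `optimize()` folds the print-and-exit pair into a single call. When `sys.exit()` is given a string, Python writes it to stderr and exits with status 1 rather than 0, so the failure is also detectable by calling shell scripts. A small self-contained check of that behavior:

    import subprocess
    import sys

    # Run a child interpreter that exits via sys.exit('<message>').
    proc = subprocess.run(
        [sys.executable, '-c', "import sys; sys.exit('no optimizer defined')"],
        capture_output=True, text=True)

    print(proc.returncode)       # 1 -- a string argument implies failure status
    print(proc.stderr.strip())   # 'no optimizer defined' -- printed to stderr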
@@ -246,18 +245,14 @@ def main():
             print('file exists: %s' % fn_config)
             sys.exit()

-        data_generator = GFData.get_example()
         optimizer = Optimizer.get_example()
         model = Model(
             tf_config=tf_config,
-            data_generator=data_generator,
-            optimizer=optimizer)
+            data_generator=GFSwarmData.get_example(),
+            hyperparameter_optimizer=optimizer)

-        model = Model(tf_config=tf_config, data_generator=data_generator)
         model.regularize()
         # model.dump(filename=fn_config)
         print(model)
         # print('created a fresh "%s"' % fn_config)

     if args.train and args.optimize:
         print('Can only use --train or --optimize')
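The manual `args.train and args.optimize` guard works; for reference, argparse can also enforce this exclusivity itself via a mutually exclusive group. This is a standard-library alternative, not what the commit does:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--train', action='store_true')
    group.add_argument('--optimize', action='store_true')

    args = parser.parse_args(['--train'])
    # parser.parse_args(['--train', '--optimize']) would exit with:
    # "error: argument --optimize: not allowed with argument --train"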