Commit bc685a75 authored by Marius Kriegerowski's avatar Marius Kriegerowski

hyper parameter optimization plots (needs testing)

parent 1ddedccd
import os
from .util import delete_if_exists
from skopt import gp_minimize
from skopt import dump as dump_result
from skopt import load as load_result
from skopt.space import Real, Categorical, Integer
from skopt.plots import plot_convergence, plot_objective_2D
from skopt.plots import plot_objective, plot_evaluations
try:
from skopt.plots import plot_histogram
_plot_histogram_error = False
except ImportError as e:
_plot_histogram_error = e
logger.debug(e)
from pyrocko.guts import Object, Int, Float, List, Tuple, String
import logging
logger = logging.getLogger()
def to_skopt_real(x, name, prior):
return Real(low=x[0], high=x[1], prior=prior, name=name)
......@@ -15,12 +27,13 @@ class Optimizer(Object):
learning_rate = Tuple.T(3, Float.T(), default=(1e-3, 1e-5, 1e-4)) # low, high, default
n_calls = Int.T(default=50, help='number of test sets')
log_path = String.T(default='./logs/')
path_best = String.T(default='winner')
path_out = String.T(default='optimizer-results', help='base path where to store results, plots and logs')
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model = None
self.result = None
self.fn_result = self.extend_path('result.optmz')
# self.dimensions = [
# to_skopt_real(self.learning_rate, 'learning_rate', 'log-uniform')]
self.optimizer_defaults = [
......@@ -52,7 +65,14 @@ class Optimizer(Object):
def optimizer_values(self):
return [default for (k, default) in self.optimizer_defaults]
@property
def non_categorical_dimensions(self):
'''Returns a list of non-categorical dimension names.'''
return [dim.name for dim in self.dimensions if not \
isinstance(dim, Categorical)]
def announce_test(self, params):
'''Log a parameter test set. '''
logger.info('+' * 20)
logger.info('evaluating next set of parameters:')
base =' {}: {}\n'
......@@ -72,18 +92,64 @@ class Optimizer(Object):
self.model = model
if self.model.auto_clear:
delete_if_exists(self.log_path)
delete_if_exists(self.path_out)
gp_minimize(
self.result = gp_minimize(
func=self.evaluate,
dimensions=self.dimensions,
acq_func='EI', # Expected Improvement
n_calls=self.n_calls,
x0=self.optimizer_values,
)
x0=self.optimizer_values)
dump_result(self.result, self.fn_result)
self.evaluate_result()
self.plot_results()
def ensure_result(self):
''' Load and set minimizer result.'''
if self.result is None:
self.result = load_result(self.fn_result)
else:
logger.warn(
'Cannot load results from filename: %s' % self.fn_result)
def extend_path(self, *path):
'''Prepend `self.path_out` to `path`.'''
return os.path.join(self.path_out, *path)
def evaluate_result(self):
self.ensure_result()
best = self.result.space.point_to_dict(self.result.x)
logger.info('Best parameter set:')
logger.info(best)
logger.info('Best parameter loss:')
logger.info(self.result.fun)
def plot_results(self):
'''Produce and save result plots. '''
self.ensure_result()
if _plot_histogram_error:
logger.warn(_plot_histogram_error)
else:
for dim_name in self.optimizer_keys:
fig, ax = plot_histogram(result=self.result, dimension_name=dim_name)
fig.savefig(extend_path('plots/histogram_%s.pdf' % dim_name))
fig, ax = plot_objective(
result=self.result,
dimension_names=self.non_categorical_dimensions)
fig.savefig(extend_path('plots/objectives.pdf'))
fig, ax = plot_evaluations(
result=self.result,
dimension_names=self.non_categorical_dimensions)
fig.savefig(extend_path('plots/evaluations.pdf'))
def log_dir_name(self, params):
'''Helper functions to transform `params` into a logging directory
'''Helper function to transform `params` into a logging directory
name.'''
placeholders = '{}_{}_' * len(params)
......@@ -93,12 +159,14 @@ class Optimizer(Object):
identifiers.append(v)
placeholders = placeholders.format(*identifiers)
log_dir = self.log_path + placeholders
log_dir = self.extend_path('tf_logs/' + placeholders)
logger.info('Created new logging directory: %s' % log_dir)
return log_dir
@classmethod
def get_example(cls):
'''Get an example instance of this class.'''
return cls()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment