Commit 3dda57ff authored by Sebastian Heimann's avatar Sebastian Heimann

restructure problem.evaluate

parent ccbb734c
......@@ -76,7 +76,7 @@ class Analyser(object):
isok_mask = num.logical_not(isbad_mask)
else:
isok_mask = None
ms = wproblem.evaluate(x, mask=isok_mask)[:, 1]
ms = wproblem.misfits(x, mask=isok_mask)[:, 1]
mss[iiter, :] = ms
isbad_mask = num.isnan(ms)
......
......@@ -250,7 +250,7 @@ def forward(rundir_or_config_path, event_names):
events = []
for (problem, x) in payload:
ds.empty_cache()
_, results = problem.evaluate(x, result_mode='full')
results = problem.evaluate(x)
event = problem.get_source(x).pyrocko_event()
events.append(event)
......@@ -376,7 +376,7 @@ def check(
if n_random_synthetics == 0:
x = problem.pack(problem.base_source)
sources.append(problem.base_source)
_, results = problem.evaluate(x, result_mode='full')
results = problem.evaluate(x)
results_list.append(results)
else:
......@@ -391,7 +391,7 @@ def check(
pass
sources.append(problem.get_source(x))
_, results = problem.evaluate(x, result_mode='full')
results = problem.evaluate(x)
results_list.append(results)
if show_waveforms:
......
......@@ -435,7 +435,7 @@ class HighScoreOptimizer(Optimizer):
else:
isok_mask = None
misfits = problem.evaluate(x, mask=isok_mask)
misfits = problem.misfits(x, mask=isok_mask)
isbad_mask_new = num.isnan(misfits[:, 0])
if isbad_mask is not None and num.any(
......
......@@ -903,7 +903,7 @@ def draw_fits_figures_statics(ds, history, optimizer, plt):
source = problem.get_source(xbest)
_, results = problem.evaluate(xbest, result_mode='full')
results = problem.evaluate(xbest)
figures = []
......@@ -1059,7 +1059,7 @@ def draw_fits_ensemble_figures(
model = models[imodel, :]
source = problem.get_source(model)
_, results = problem.evaluate(model, result_mode='full')
results = problem.evaluate(model)
dtraces.append([])
......@@ -1420,7 +1420,7 @@ def draw_fits_figures(ds, history, optimizer, plt):
target_to_result = {}
all_syn_trs = []
all_syn_specs = []
_, results = problem.evaluate(xbest, result_mode='full')
results = problem.evaluate(xbest)
dtraces = []
for target, result in zip(problem.waveform_targets, results):
......
......@@ -11,8 +11,8 @@ from pyrocko import gf, util, guts
from pyrocko.guts import Object, String, Bool, List, Dict, Int
from ..meta import ADict, Parameter, GrondError, xjoin
from ..targets import MisfitTarget, TargetGroup, WaveformMisfitTarget, \
SatelliteMisfitTarget
from ..targets import MisfitResult, MisfitTarget, TargetGroup, \
WaveformMisfitTarget, SatelliteMisfitTarget
guts_prefix = 'grond'
......@@ -353,7 +353,7 @@ class Problem(Object):
return self._family_mask
def evaluate(self, x, mask=None, result_mode='sparse'):
def evaluate(self, x, mask=None, result_mode='full'):
source = self.get_source(x)
engine = self.get_engine()
......@@ -371,16 +371,12 @@ class Problem(Object):
modelling_results = list(resp.results_list[0])
imt = 0
imisfit = 0
misfits = num.zeros((self.nmisfits, 2))
misfits.fill(None)
results = []
for itarget, target in enumerate(self.targets):
nmt_this = len(t2m_map[target])
if mask is None or mask[itarget]:
misfits[imisfit:imisfit+target.nmisfits, :], result = \
target.finalize_modelling(
modelling_results[imt:imt+nmt_this])
result = target.finalize_modelling(
modelling_results[imt:imt+nmt_this])
imt += nmt_this
else:
......@@ -388,12 +384,21 @@ class Problem(Object):
'target was excluded from modelling')
results.append(result)
return results
def misfits(self, x, mask=None):
results = self.evaluate(x, mask=mask, result_mode='sparse')
imisfit = 0
misfits = num.zeros((self.nmisfits, 2))
misfits.fill(None)
for target, result in zip(self.targets, results):
if isinstance(result, MisfitResult):
misfits[imisfit:imisfit+target.nmisfits, :] = result.misfits
imisfit += target.nmisfits
if result_mode == 'full':
return misfits, results
else:
return misfits
return misfits
class InvalidRundir(Exception):
......
......@@ -3,6 +3,7 @@ import copy
import numpy as num
from pyrocko import gf
from pyrocko.guts_array import Array
from pyrocko.guts import Object, Float
......@@ -29,7 +30,9 @@ class TargetAnalysisResult(Object):
class MisfitResult(Object):
pass
misfits = Array.T(
shape=(None, 2),
dtype=num.float)
class MisfitTarget(Object):
......@@ -100,7 +103,7 @@ class MisfitTarget(Object):
def init_modelling(self):
return []
def finalize_modelling(self, results):
def finalize_modelling(self, modelling_results):
raise NotImplemented('must be overloaded in subclass')
......
......@@ -2,7 +2,7 @@ import logging
import numpy as num
from pyrocko import gf
from pyrocko.guts import String, Bool, Dict, List, Object, Float
from pyrocko.guts import String, Bool, Dict, List, Object
from grond.meta import Parameter
......@@ -67,8 +67,6 @@ class SatelliteTargetGroup(TargetGroup):
class SatelliteMisfitResult(gf.Result, MisfitResult):
misfit_value = Float.T()
misfit_norm = Float.T()
statics_syn = Dict.T(optional=True)
statics_obs = Dict.T(optional=True)
......@@ -101,8 +99,7 @@ class SatelliteMisfitTarget(gf.SatelliteTarget, MisfitTarget):
self._target_ranges.pop(k)
return self._target_ranges
@property
def id(self):
def string_id(self):
return self.scene_id
def set_dataset(self, ds):
......@@ -135,8 +132,7 @@ class SatelliteMisfitTarget(gf.SatelliteTarget, MisfitTarget):
num.sum((stat_obs * scene.covariance.weight_vector)**2))
result = SatelliteMisfitResult(
misfit_value=misfit_value,
misfit_norm=misfit_norm)
misfits=num.array([[misfit_value, misfit_norm]], dtype=num.float))
if self._result_mode == 'full':
result.statics_syn = statics
......
......@@ -149,29 +149,10 @@ class WaveformTargetGroup(TargetGroup):
targets.append(target)
if self.limit:
return self.weed(origin, targets, self.limit)[0]
return weed(origin, targets, self.limit)[0]
else:
return targets
@staticmethod
def weed(origin, targets, limit, neighborhood=3):
azimuths = num.zeros(len(targets))
dists = num.zeros(len(targets))
for i, target in enumerate(targets):
_, azimuths[i] = target.azibazi_to(origin)
dists[i] = target.distance_to(origin)
badnesses = num.ones(len(targets), dtype=float)
deleted, meandists_kept = weeding.weed(
azimuths, dists, badnesses,
nwanted=limit,
neighborhood=neighborhood)
targets_weeded = [
target for (delete, target) in zip(deleted, targets) if not delete]
return targets_weeded, meandists_kept, deleted
class TraceSpectrum(Object):
network = String.T()
......@@ -190,8 +171,6 @@ class TraceSpectrum(Object):
class WaveformMisfitResult(gf.Result, MisfitResult):
misfit_value = Float.T()
misfit_norm = Float.T()
processed_obs = Trace.T(optional=True)
processed_syn = Trace.T(optional=True)
filtered_obs = Trace.T(optional=True)
......@@ -217,15 +196,6 @@ class WaveformMisfitTarget(gf.Target, MisfitTarget):
def string_id(self):
return '.'.join(x for x in (self.path,) + self.codes if x)
@property
def id(self):
return '.'.join(self.codes)
def get_plain_modelling_targets(self):
d = dict(
(k, getattr(self, k)) for k in gf.Target.T.propnames)
return [gf.Target(**d)]
def get_combined_weight(self, apply_balancing_weights):
w = self.manual_weight
if apply_balancing_weights:
......@@ -369,16 +339,13 @@ class WaveformMisfitTarget(gf.Target, MisfitTarget):
def prepare_modelling(self):
return [self]
def finalize_modelling(self, results):
result = results[0]
if isinstance(result, gf.SeismosizerError):
misfits = num.array(
[[None, None]], dtype=num.float)
def finalize_modelling(self, modelling_results):
return modelling_results[0]
else:
misfits = num.array(
[[result.misfit_value, result.misfit_norm]], dtype=num.float)
return targets
return misfits, result
def misfit(
......@@ -472,8 +439,7 @@ tautoshift**2 / tautoshift_max**2``
if result_mode == 'full':
result = WaveformMisfitResult(
misfit_value=float(m),
misfit_norm=float(n),
misfits=num.array([[m, n]], dtype=num.float),
processed_obs=tr_proc_obs,
processed_syn=tr_proc_syn,
filtered_obs=tr_obs.copy(),
......@@ -486,8 +452,7 @@ tautoshift**2 / tautoshift_max**2``
elif result_mode == 'sparse':
result = WaveformMisfitResult(
misfit_value=m,
misfit_norm=n)
misfits=num.array([[m, n]], dtype=num.float))
else:
assert False
......@@ -568,6 +533,25 @@ def float_or_none(x):
return float(x)
def weed(origin, targets, limit, neighborhood=3):
azimuths = num.zeros(len(targets))
dists = num.zeros(len(targets))
for i, target in enumerate(targets):
_, azimuths[i] = target.azibazi_to(origin)
dists[i] = target.distance_to(origin)
badnesses = num.ones(len(targets), dtype=float)
deleted, meandists_kept = weeding.weed(
azimuths, dists, badnesses,
nwanted=limit,
neighborhood=neighborhood)
targets_weeded = [
target for (delete, target) in zip(deleted, targets) if not delete]
return targets_weeded, meandists_kept, deleted
__all__ = '''
WaveformTargetGroup
WaveformMisfitConfig
......
......@@ -93,7 +93,7 @@ class ToyProblem(Problem):
[t.obs_distance for t in self.targets],
dtype=num.float)
def evaluate(self, x, mask=None):
def misfits(self, x, mask=None):
self._setup_modelling()
distances = num.sqrt(
num.sum((x[num.newaxis, :]-self._xtargets)**2, axis=1))
......@@ -104,7 +104,7 @@ class ToyProblem(Problem):
* num.mean(num.abs(self._obs_distances))
return misfits
def evaluate_many(self, xs):
def misfits_many(self, xs):
self._setup_modelling()
distances = num.sqrt(
num.sum(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment