Commit 3307de00 authored by Sebastian Heimann's avatar Sebastian Heimann

selective return of traces and spectra in misfit results

parent 0ffb76c0
......@@ -164,10 +164,13 @@ class CMTProblem(core.Problem):
return out
def evaluate(self, x, return_traces=False):
def evaluate(self, x, result_mode='sparse'):
source = self.unpack(x)
engine = self.get_engine()
for target in self.targets:
target.set_result_mode(result_mode)
resp = engine.process(source, self.targets)
data = []
results = []
......@@ -177,15 +180,15 @@ class CMTProblem(core.Problem):
'%s.%s.%s.%s: %s' % (target.codes + (str(result),)))
data.append((None, None))
if return_traces:
if result_mode == 'sparse':
results.append(None)
else:
data.append((result.misfit_value, result.misfit_norm))
if return_traces:
if result_mode == 'full':
results.append(result)
ms, ns = num.array(data, dtype=num.float).T
if return_traces:
if result_mode == 'full':
return ms, ns, results
else:
return ms, ns
......
......@@ -278,6 +278,7 @@ class MisfitTarget(gf.Target):
def __init__(self, **kwargs):
gf.Target.__init__(self, **kwargs)
self._ds = None
self._result_mode = 'sparse'
def string_id(self):
return '.'.join(x for x in (
......@@ -294,6 +295,9 @@ class MisfitTarget(gf.Target):
def set_dataset(self, ds):
self._ds = ds
def set_result_mode(self, result_mode):
self._result_mode = result_mode
def get_combined_weight(self, apply_balancing_weights):
w = self.manual_weight
if apply_balancing_weights:
......@@ -396,7 +400,8 @@ class MisfitTarget(gf.Target):
tmax_fit + tfade),
domain=config.domain,
exponent=2,
flip=self.flip_norm)
flip=self.flip_norm,
result_mode=self._result_mode)
mr.tobs_shift = float(tobs_shift)
mr.tsyn_pick = float_or_none(tsyn)
......@@ -408,7 +413,8 @@ class MisfitTarget(gf.Target):
raise gf.SeismosizerError('no waveform data, %s' % str(e))
def misfit(tr_obs, tr_syn, taper, domain, exponent, flip):
def misfit(
tr_obs, tr_syn, taper, domain, exponent, flip, result_mode='sparse'):
'''
Calculate misfit between observed and synthetic trace.
......@@ -421,6 +427,8 @@ def misfit(tr_obs, tr_syn, taper, domain, exponent, flip):
:param exponent: exponent of Lx type norms
:param flip: ``bool``, if set to ``True``, normalization factor is
computed against *tr_syn* rather than *tr_obs*
:param result_mode: ``'full'``, include traces and spectra or ``'sparse'``,
include only misfit and normalization factor in result
:returns: object of type :py:class:`MisfitResult`
'''
......@@ -459,18 +467,25 @@ def misfit(tr_obs, tr_syn, taper, domain, exponent, flip):
m, n = trace.Lx_norm(num.abs(a), num.abs(b), norm=exponent)
result = MisfitResult(
misfit_value=m,
misfit_norm=n,
processed_obs=tr_proc_obs,
processed_syn=tr_proc_syn,
filtered_obs=tr_obs,
filtered_syn=tr_syn,
spectrum_obs=trspec_proc_obs,
spectrum_syn=trspec_proc_syn,
taper=taper,
cc_shift=cc_shift,
cc=ctr)
if result_mode == 'full':
result = MisfitResult(
misfit_value=m,
misfit_norm=n,
processed_obs=tr_proc_obs,
processed_syn=tr_proc_syn,
filtered_obs=tr_obs.copy(),
filtered_syn=tr_syn,
spectrum_obs=trspec_proc_obs,
spectrum_syn=trspec_proc_syn,
taper=taper,
cc_shift=cc_shift,
cc=ctr)
elif result_mode == 'sparse':
result = MisfitResult(
misfit_value=m,
misfit_norm=n)
else:
assert False
return result
......@@ -1423,7 +1438,7 @@ def forward(rundir_or_config_path, event_names=None):
events = []
for (problem, x) in payload:
ds.empty_cache()
ms, ns, results = problem.evaluate(x, return_traces=True)
ms, ns, results = problem.evaluate(x, result_mode='full')
event = problem.unpack(x).pyrocko_event()
events.append(event)
......@@ -1526,7 +1541,7 @@ def check(config, event_names=None):
for i in xrange(10):
x = problem.random_uniform(xbounds)
print x
ms, ns, results = problem.evaluate(x, return_traces=True)
ms, ns, results = problem.evaluate(x, result_mode='full')
results_list.append(results)
for itarget, target in enumerate(problem.targets):
......
......@@ -858,13 +858,7 @@ def draw_fits_figures(ds, model, plt):
target_to_result = {}
all_syn_trs = []
ms, ns, results = problem.evaluate(xbest, return_traces=True)
for result in results:
if result is not None:
result.filtered_obs = result.filtered_obs.copy()
result.filtered_syn = result.filtered_syn.copy()
result.processed_obs = result.processed_obs.copy()
result.processed_syn = result.processed_syn.copy()
ms, ns, results = problem.evaluate(xbest, result_mode='full')
dtraces = []
for target, result in zip(problem.targets, results):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment