Commit 30dbee81 authored by Daniel Scheffler's avatar Daniel Scheffler

Merge branch 'enhancement/share_mp_variables' into 'master'

Enhancement/share mp variables

See merge request !8
parents 24ab86c9 195f9d4a
Pipeline #3810 failed with stages
in 11 minutes and 30 seconds
......@@ -799,6 +799,24 @@ class SensorMapGeometryTransformer(object):
return data_sensorgeo
_global_shared_lats = None
_global_shared_lons = None
_global_shared_data = None
def _initializer(lats, lons, data):
"""Declare global variables needed for SensorMapGeometryTransformer3D.to_map_geometry and to_sensor_geometry.
:param lats:
:param lons:
:param data:
"""
global _global_shared_lats, _global_shared_lons, _global_shared_data
_global_shared_lats = lats
_global_shared_lons = lons
_global_shared_data = data
class SensorMapGeometryTransformer3D(object):
def __init__(self, lons, lats, resamp_alg='nearest', radius_of_influence=30, **opts):
# type: (np.ndarray, np.ndarray, str, int, Any) -> None
......@@ -854,12 +872,14 @@ class SensorMapGeometryTransformer3D(object):
@staticmethod
def _to_map_geometry_2D(kwargs_dict):
# type: (dict) -> Tuple[np.ndarray, tuple, str, int]
SMGT2D = SensorMapGeometryTransformer(lons=kwargs_dict['lons_2D'],
lats=kwargs_dict['lats_2D'],
assert [var is not None for var in (_global_shared_lons, _global_shared_lats, _global_shared_data)]
SMGT2D = SensorMapGeometryTransformer(lons=_global_shared_lons[:, :, kwargs_dict['band_idx']],
lats=_global_shared_lats[:, :, kwargs_dict['band_idx']],
resamp_alg=kwargs_dict['resamp_alg'],
radius_of_influence=kwargs_dict['radius_of_influence'],
**kwargs_dict['init_opts'])
data_mapgeo, out_gt, out_prj = SMGT2D.to_map_geometry(data=kwargs_dict['data_sensor_geo_2D'],
data_mapgeo, out_gt, out_prj = SMGT2D.to_map_geometry(data=_global_shared_data[:, :, kwargs_dict['band_idx']],
tgt_prj=kwargs_dict['tgt_prj'],
tgt_extent=kwargs_dict['tgt_extent'],
tgt_res=kwargs_dict['tgt_res'])
......@@ -905,12 +925,9 @@ class SensorMapGeometryTransformer3D(object):
del init_opts['nprocs'] # avoid sub-multiprocessing
args = [dict(
lons_2D=self.lons[:, :, band],
lats_2D=self.lats[:, :, band],
resamp_alg=self.resamp_alg,
radius_of_influence=self.radius_of_influence,
init_opts=init_opts,
data_sensor_geo_2D=data[:, :, band],
tgt_prj=tgt_prj,
tgt_extent=tgt_extent,
tgt_res=tgt_res,
......@@ -918,9 +935,12 @@ class SensorMapGeometryTransformer3D(object):
) for band in range(data.shape[2])]
if self.opts['nprocs'] > 1 and self.mp_alg == 'bands':
with multiprocessing.Pool(self.opts['nprocs']) as pool:
with multiprocessing.Pool(self.opts['nprocs'],
initializer=_initializer,
initargs=(self.lats, self.lons, data)) as pool:
result = pool.map(self._to_map_geometry_2D, args)
else:
_initializer(self.lats, self.lons, data)
result = [self._to_map_geometry_2D(argsdict) for argsdict in args]
band_inds = list(np.array(result)[:, -1])
......@@ -933,12 +953,14 @@ class SensorMapGeometryTransformer3D(object):
@staticmethod
def _to_sensor_geometry_2D(kwargs_dict):
# type: (dict) -> (np.ndarray, int)
SMGT2D = SensorMapGeometryTransformer(lons=kwargs_dict['lons_2D'],
lats=kwargs_dict['lats_2D'],
assert [var is not None for var in (_global_shared_lons, _global_shared_lats, _global_shared_data)]
SMGT2D = SensorMapGeometryTransformer(lons=_global_shared_lons[:, :, kwargs_dict['band_idx']],
lats=_global_shared_lats[:, :, kwargs_dict['band_idx']],
resamp_alg=kwargs_dict['resamp_alg'],
radius_of_influence=kwargs_dict['radius_of_influence'],
**kwargs_dict['init_opts'])
data_sensorgeo = SMGT2D.to_sensor_geometry(data=kwargs_dict['data_map_geo_2D'],
data_sensorgeo = SMGT2D.to_sensor_geometry(data=_global_shared_data[:, :, kwargs_dict['band_idx']],
src_prj=kwargs_dict['src_prj'],
src_extent=kwargs_dict['src_extent'])
......@@ -960,21 +982,21 @@ class SensorMapGeometryTransformer3D(object):
del init_opts['nprocs'] # avoid sub-multiprocessing
args = [dict(
lons_2D=self.lons[:, :, band],
lats_2D=self.lats[:, :, band],
resamp_alg=self.resamp_alg,
radius_of_influence=self.radius_of_influence,
init_opts=init_opts,
data_map_geo_2D=data[:, :, band],
src_prj=src_prj,
src_extent=src_extent,
band_idx=band
) for band in range(data.shape[2])]
if self.opts['nprocs'] > 1 and self.mp_alg == 'bands':
with multiprocessing.Pool(self.opts['nprocs']) as pool:
with multiprocessing.Pool(self.opts['nprocs'],
initializer=_initializer,
initargs=(self.lats, self.lons, data)) as pool:
result = pool.map(self._to_sensor_geometry_2D, args)
else:
_initializer(self.lats, self.lons, data)
result = [self._to_sensor_geometry_2D(argsdict) for argsdict in args]
band_inds = list(np.array(result)[:, -1])
......
__version__ = '0.14.15'
__versionalias__ = '20190322_01'
__version__ = '0.14.16'
__versionalias__ = '20190322_02'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment