Commit 30dbee81 authored by Daniel Scheffler's avatar Daniel Scheffler

Merge branch 'enhancement/share_mp_variables' into 'master'

Enhancement/share mp variables

See merge request !8
parents 24ab86c9 195f9d4a
Pipeline #3810 failed with stages
in 11 minutes and 30 seconds
...@@ -799,6 +799,24 @@ class SensorMapGeometryTransformer(object): ...@@ -799,6 +799,24 @@ class SensorMapGeometryTransformer(object):
return data_sensorgeo return data_sensorgeo
_global_shared_lats = None
_global_shared_lons = None
_global_shared_data = None
def _initializer(lats, lons, data):
"""Declare global variables needed for SensorMapGeometryTransformer3D.to_map_geometry and to_sensor_geometry.
:param lats:
:param lons:
:param data:
"""
global _global_shared_lats, _global_shared_lons, _global_shared_data
_global_shared_lats = lats
_global_shared_lons = lons
_global_shared_data = data
class SensorMapGeometryTransformer3D(object): class SensorMapGeometryTransformer3D(object):
def __init__(self, lons, lats, resamp_alg='nearest', radius_of_influence=30, **opts): def __init__(self, lons, lats, resamp_alg='nearest', radius_of_influence=30, **opts):
# type: (np.ndarray, np.ndarray, str, int, Any) -> None # type: (np.ndarray, np.ndarray, str, int, Any) -> None
...@@ -854,12 +872,14 @@ class SensorMapGeometryTransformer3D(object): ...@@ -854,12 +872,14 @@ class SensorMapGeometryTransformer3D(object):
@staticmethod @staticmethod
def _to_map_geometry_2D(kwargs_dict): def _to_map_geometry_2D(kwargs_dict):
# type: (dict) -> Tuple[np.ndarray, tuple, str, int] # type: (dict) -> Tuple[np.ndarray, tuple, str, int]
SMGT2D = SensorMapGeometryTransformer(lons=kwargs_dict['lons_2D'], assert [var is not None for var in (_global_shared_lons, _global_shared_lats, _global_shared_data)]
lats=kwargs_dict['lats_2D'],
SMGT2D = SensorMapGeometryTransformer(lons=_global_shared_lons[:, :, kwargs_dict['band_idx']],
lats=_global_shared_lats[:, :, kwargs_dict['band_idx']],
resamp_alg=kwargs_dict['resamp_alg'], resamp_alg=kwargs_dict['resamp_alg'],
radius_of_influence=kwargs_dict['radius_of_influence'], radius_of_influence=kwargs_dict['radius_of_influence'],
**kwargs_dict['init_opts']) **kwargs_dict['init_opts'])
data_mapgeo, out_gt, out_prj = SMGT2D.to_map_geometry(data=kwargs_dict['data_sensor_geo_2D'], data_mapgeo, out_gt, out_prj = SMGT2D.to_map_geometry(data=_global_shared_data[:, :, kwargs_dict['band_idx']],
tgt_prj=kwargs_dict['tgt_prj'], tgt_prj=kwargs_dict['tgt_prj'],
tgt_extent=kwargs_dict['tgt_extent'], tgt_extent=kwargs_dict['tgt_extent'],
tgt_res=kwargs_dict['tgt_res']) tgt_res=kwargs_dict['tgt_res'])
...@@ -905,12 +925,9 @@ class SensorMapGeometryTransformer3D(object): ...@@ -905,12 +925,9 @@ class SensorMapGeometryTransformer3D(object):
del init_opts['nprocs'] # avoid sub-multiprocessing del init_opts['nprocs'] # avoid sub-multiprocessing
args = [dict( args = [dict(
lons_2D=self.lons[:, :, band],
lats_2D=self.lats[:, :, band],
resamp_alg=self.resamp_alg, resamp_alg=self.resamp_alg,
radius_of_influence=self.radius_of_influence, radius_of_influence=self.radius_of_influence,
init_opts=init_opts, init_opts=init_opts,
data_sensor_geo_2D=data[:, :, band],
tgt_prj=tgt_prj, tgt_prj=tgt_prj,
tgt_extent=tgt_extent, tgt_extent=tgt_extent,
tgt_res=tgt_res, tgt_res=tgt_res,
...@@ -918,9 +935,12 @@ class SensorMapGeometryTransformer3D(object): ...@@ -918,9 +935,12 @@ class SensorMapGeometryTransformer3D(object):
) for band in range(data.shape[2])] ) for band in range(data.shape[2])]
if self.opts['nprocs'] > 1 and self.mp_alg == 'bands': if self.opts['nprocs'] > 1 and self.mp_alg == 'bands':
with multiprocessing.Pool(self.opts['nprocs']) as pool: with multiprocessing.Pool(self.opts['nprocs'],
initializer=_initializer,
initargs=(self.lats, self.lons, data)) as pool:
result = pool.map(self._to_map_geometry_2D, args) result = pool.map(self._to_map_geometry_2D, args)
else: else:
_initializer(self.lats, self.lons, data)
result = [self._to_map_geometry_2D(argsdict) for argsdict in args] result = [self._to_map_geometry_2D(argsdict) for argsdict in args]
band_inds = list(np.array(result)[:, -1]) band_inds = list(np.array(result)[:, -1])
...@@ -933,12 +953,14 @@ class SensorMapGeometryTransformer3D(object): ...@@ -933,12 +953,14 @@ class SensorMapGeometryTransformer3D(object):
@staticmethod @staticmethod
def _to_sensor_geometry_2D(kwargs_dict): def _to_sensor_geometry_2D(kwargs_dict):
# type: (dict) -> (np.ndarray, int) # type: (dict) -> (np.ndarray, int)
SMGT2D = SensorMapGeometryTransformer(lons=kwargs_dict['lons_2D'], assert [var is not None for var in (_global_shared_lons, _global_shared_lats, _global_shared_data)]
lats=kwargs_dict['lats_2D'],
SMGT2D = SensorMapGeometryTransformer(lons=_global_shared_lons[:, :, kwargs_dict['band_idx']],
lats=_global_shared_lats[:, :, kwargs_dict['band_idx']],
resamp_alg=kwargs_dict['resamp_alg'], resamp_alg=kwargs_dict['resamp_alg'],
radius_of_influence=kwargs_dict['radius_of_influence'], radius_of_influence=kwargs_dict['radius_of_influence'],
**kwargs_dict['init_opts']) **kwargs_dict['init_opts'])
data_sensorgeo = SMGT2D.to_sensor_geometry(data=kwargs_dict['data_map_geo_2D'], data_sensorgeo = SMGT2D.to_sensor_geometry(data=_global_shared_data[:, :, kwargs_dict['band_idx']],
src_prj=kwargs_dict['src_prj'], src_prj=kwargs_dict['src_prj'],
src_extent=kwargs_dict['src_extent']) src_extent=kwargs_dict['src_extent'])
...@@ -960,21 +982,21 @@ class SensorMapGeometryTransformer3D(object): ...@@ -960,21 +982,21 @@ class SensorMapGeometryTransformer3D(object):
del init_opts['nprocs'] # avoid sub-multiprocessing del init_opts['nprocs'] # avoid sub-multiprocessing
args = [dict( args = [dict(
lons_2D=self.lons[:, :, band],
lats_2D=self.lats[:, :, band],
resamp_alg=self.resamp_alg, resamp_alg=self.resamp_alg,
radius_of_influence=self.radius_of_influence, radius_of_influence=self.radius_of_influence,
init_opts=init_opts, init_opts=init_opts,
data_map_geo_2D=data[:, :, band],
src_prj=src_prj, src_prj=src_prj,
src_extent=src_extent, src_extent=src_extent,
band_idx=band band_idx=band
) for band in range(data.shape[2])] ) for band in range(data.shape[2])]
if self.opts['nprocs'] > 1 and self.mp_alg == 'bands': if self.opts['nprocs'] > 1 and self.mp_alg == 'bands':
with multiprocessing.Pool(self.opts['nprocs']) as pool: with multiprocessing.Pool(self.opts['nprocs'],
initializer=_initializer,
initargs=(self.lats, self.lons, data)) as pool:
result = pool.map(self._to_sensor_geometry_2D, args) result = pool.map(self._to_sensor_geometry_2D, args)
else: else:
_initializer(self.lats, self.lons, data)
result = [self._to_sensor_geometry_2D(argsdict) for argsdict in args] result = [self._to_sensor_geometry_2D(argsdict) for argsdict in args]
band_inds = list(np.array(result)[:, -1]) band_inds = list(np.array(result)[:, -1])
......
__version__ = '0.14.15' __version__ = '0.14.16'
__versionalias__ = '20190322_01' __versionalias__ = '20190322_02'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment