Commit 60a2a353 authored by Daniel Scheffler

Fixed behaviour of GeoArray.__getitem__() that differed from numpy behaviour (caused issue #18).

Fixed behaviour of GeoArray.__getitem__() that differed from numpy behaviour (caused issue #18). Added tests.
parent d33ac744
Pipeline #3197 passed with stages
in 1 minute and 26 seconds
......@@ -759,6 +759,7 @@ class GeoArray(object):
# populate rS, rE, cS, cE, bS, bE, bL
if getitem_params:
# populate rS, rE, cS, cE
if len(getitem_params) >= 2:
givenR, givenC = getitem_params[:2]
if isinstance(givenR, slice):
......@@ -773,6 +774,8 @@ class GeoArray(object):
elif isinstance(givenC, int):
cS = givenC
cE = givenC
# populate bS, bE, bL
if len(getitem_params) in [1, 3]:
givenB = getitem_params[2] if len(getitem_params) == 3 else getitem_params[0]
if isinstance(givenB, slice):
......@@ -823,9 +826,38 @@ class GeoArray(object):
# NOTE: # bandlist must be string because truth value of an array with more than one element is ambiguous
arr_pos = dict(rS=rS, rE=rE, cS=cS, cE=cE, bS=bS, bE=bE, bL=bL)
def _ensure_np_shape_3D_2D(arr):
    """Match numpy output shape according to the given indexing parameters.

    This may require 3D to 2D conversion in case the array can be represented by a 2D array AND the index of the
    collapsed dimension has been provided as a number (avoids shapes like (1, 2, 2)). It may also require 2D to 3D
    conversion in case only one band has been requested and the 3rd dimension has been provided as a slice.

    NOTE: numpy also drops a dimension that is indexed with a number.
    NOTE: A dimension that is indexed with a slice, or not indexed at all, is kept.

    The indexing parameters are read from ``getitem_params`` of the enclosing scope.

    :param arr:  numpy array to be adjusted
    :return:     the input array, reshaped where the indexing parameters require it
    """
    # getitem_params may be empty/None if no indexing parameters were given -> nothing to match in that case
    has_params = isinstance(getitem_params, (tuple, list))

    # 2D -> 3D
    if arr.ndim == 2 and has_params and len(getitem_params) == 3 and isinstance(getitem_params[2], slice):
        arr = arr[:, :, np.newaxis]

    # 3D -> 2D
    if has_params and 1 in arr.shape:
        # assign the parameters to the array axes: a single parameter is the index of the band dimension,
        # otherwise the parameters are given in the order rows, columns(, bands)
        if len(getitem_params) == 1:
            axis_params = (None, None, getitem_params[0])
        else:
            axis_params = tuple(getitem_params)

        # numpy scalars must be listed explicitly because np.integer is not a subclass of int in Python 3
        number_types = (int, float, np.integer, np.floating)

        # only drop those dimensions of size 1 that have been indexed with a number
        # NOTE: axes without a corresponding parameter are kept (avoids an IndexError on getitem_params)
        outshape = tuple(
            sh for i, sh in enumerate(arr.shape)
            if not (sh == 1 and i < len(axis_params) and isinstance(axis_params[i], number_types))
        )

        # pass the shape as tuple -> also works if all dimensions are dropped (empty shape)
        arr = arr.reshape(outshape)

    return arr
# check if the requested array position is already in cache -> if yes, return it from there
if self._arr_cache is not None and self._arr_cache['pos'] == arr_pos:
out_arr = self._arr_cache['arr_cached']
out_arr = _ensure_np_shape_3D_2D(out_arr)
else:
# TODO insert a multiprocessing.Lock here in order to prevent IO bottlenecks?
......@@ -854,10 +886,7 @@ class GeoArray(object):
del ds
# 3D to 2D conversion in case out_arr can be represented by a 2D array (avoids shapes like (1,2,2
# NOTE: -> numpy also returns a 2D array in that case
if 1 in out_arr.shape:
out_arr = out_arr.reshape(*[i for i in out_arr.shape if i != 1])
out_arr = _ensure_np_shape_3D_2D(out_arr)
# only set self.arr if the whole cube has been read (in order to avoid sudden shape changes)
if out_arr.shape == self.shape:
......
......@@ -309,21 +309,24 @@ class Test_GeoarrayAppliedOnPathArray(unittest.TestCase):
R, C, B = self.testtiff.shape # (10, 11, 2)
# test full array
validate(self.testtiff[:], (R, C, B))
# test row/col subset
validate(self.testtiff[:1, :3, :], (1, 3, B)) # only one row is requested, given as a slice
validate(self.testtiff[0, :3, :], (3, B)) # only one row is requested, given as an int
validate(self.testtiff[2:5, :3], (3, 3, B)) # third dimension is not given
validate(self.testtiff[2:5, :3, :], (3, 3, B))
# test band subset
validate(self.testtiff[:, :, 0:1], (R, C, 1)) # band slice # FIXME returns 3D array
validate(self.testtiff[1], (R, C)) # only band is given # FIXME returns 2D
validate(self.testtiff[:, :, 0:1], (R, C, 1)) # band slice # returns 3D array
validate(self.testtiff[:, :, 0], (R, C)) # band indexing # returns 2D array
validate(self.testtiff[1], (R, C)) # only band is given # returns 2D
validate(self.testtiff['B1'], (R, C)) # only bandname is given
# test wrong inputs
self.assertRaises(ValueError, self.testtiff.__getitem__, 'B01')
# test full array # NOTE: This sets self.testtiff.arr!
validate(self.testtiff[:], (R, C, B))
# TODO: add further tests
def test___getitem__consistency(self):
......@@ -484,6 +487,10 @@ class Test_GeoarrayFunctions(unittest.TestCase):
sub_gA = self.testtiff.get_subset(xslice=slice(2, 5), yslice=slice(None, 3))
self.assertIsInstance(sub_gA, GeoArray)
# test requesting only one column
sub_gA = self.testtiff.get_subset(xslice=slice(0, 1), yslice=slice(None, 3))
self.assertIsInstance(sub_gA, GeoArray)
# test with resetting band names
sub_gA = self.testtiff.get_subset(xslice=slice(2, 5), yslice=slice(None, 3), zslice=slice(1, 2),
reset_bandnames=True)
......@@ -537,6 +544,10 @@ class Test_GeoarrayFunctions(unittest.TestCase):
sub_gA = gA_2D.get_subset(xslice=slice(2, 5), yslice=slice(None, 3))
self.assertIsInstance(sub_gA, GeoArray)
# test requesting only one column
sub_gA = self.testtiff.get_subset(xslice=slice(0, 1), yslice=slice(None, 3))
self.assertIsInstance(sub_gA, GeoArray)
# test with resetting band names
sub_gA = gA_2D.get_subset(xslice=slice(2, 5), yslice=slice(None, 3), reset_bandnames=True)
self.assertTrue(list(sub_gA.bandnames), ['B1'])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment