diff --git a/geoarray/baseclasses.py b/geoarray/baseclasses.py index c71b2041133b81e6674afe7029cc66873602b316..4091e3c51b9a099d1d81643da5f519d9059798f8 100644 --- a/geoarray/baseclasses.py +++ b/geoarray/baseclasses.py @@ -681,25 +681,41 @@ class GeoArray(object): :param overwrite: whether to overwrite existing nodata mask that has already been calculated :return: """ - if self._mask_nodata is None or overwrite: assert self.ndim in [2, 3], "Only 2D or 3D arrays are supported. Got a %sD array." % self.ndim arr = self[:, :, fromBand] if self.ndim == 3 and fromBand is not None else self[:] - min_v, max_v = np.min(arr), np.max(arr) - if (min_v == max_v == self.nodata) or (np.isnan(min_v) and np.isnan(max_v) and np.isnan(self.nodata)): - self.mask_nodata = np.full(arr.shape[:2], False) + if self.nodata is None: + mask = np.ones((self.rows, self.cols), np.bool) + + elif np.isnan(self.nodata): + nanmask = np.isnan(arr) + nanbands = np.all(np.all(nanmask, axis=0), axis=0) + + if np.all(nanbands): + mask = np.full(arr.shape[:2], False) + elif arr.ndim == 2: + mask = ~np.isnan(arr) + else: + idx_1st_databand = np.argwhere(~nanbands)[0][0] + mask = ~np.isnan(arr[:, :, idx_1st_databand]) + mask[~mask] = np.any(~np.isnan(arr[~mask]), axis=1) + else: - if self.nodata is None: - self.mask_nodata = np.ones((self.rows, self.cols), np.bool) - elif np.isnan(self.nodata): - self.mask_nodata = \ - np.invert(np.isnan(arr)) if arr.ndim == 2 else \ - np.any(np.invert(np.isnan(arr)), axis=2) + bandmeans = np.mean(np.mean(arr, axis=0), axis=0) + + if np.nanmean(bandmeans) == self.nodata: + mask = np.full(arr.shape[:2], False) + elif arr.ndim == 2: + mask = arr != self.nodata else: - self.mask_nodata = \ - np.ma.masked_not_equal(arr, self.nodata).mask if arr.ndim == 2 else \ - np.any(np.ma.masked_not_equal(arr, self.nodata).mask, axis=2) + idx_1st_databand = np.argwhere(bandmeans != self.nodata)[0][0] + mask = np.array(arr[:, :, idx_1st_databand] != self.nodata) + mask[~mask] = np.any(arr[~mask] != self.nodata, axis=1) + + self.mask_nodata = mask + + return mask def find_noDataVal(self, bandIdx=0, sz=3): """Tries to derive no data value from homogenious corner pixels within 3x3 windows (by default).