reproject.py 9.46 KB
Newer Older
Daniel Scheffler's avatar
Daniel Scheffler committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# -*- coding: utf-8 -*-
__author__ = "Daniel Scheffler"


import numpy as np
import warnings

# custom
import rasterio
from rasterio.warp import reproject as rio_reproject
from rasterio.warp import calculate_default_transform as rio_calc_transform
from rasterio.warp import Resampling

from ..projection  import WKT2EPSG, isProjectedOrGeographic
from ..coord_trafo import pixelToLatLon


def warp_ndarray(ndarray, in_gt, in_prj, out_prj, out_gt=None, outRowsCols=None, outUL=None, out_res=None,
                 out_extent=None, out_dtype=None, rsp_alg=0, in_nodata=None, out_nodata=None, outExtent_within=True):

    """Reproject / warp a numpy array with given geo information to target coordinate system.

    :param ndarray:             numpy.ndarray [rows,cols] or [rows,cols,bands]
    :param in_gt:               input gdal GeoTransform
    :param in_prj:              input projection as WKT string
    :param out_prj:             output projection as WKT string
    :param out_gt:              output gdal GeoTransform as float tuple (optional). It is handed to rasterio as the
                                destination transform as-is, i.e. it is interpreted in the coordinate system of
                                out_prj (NOTE(review): earlier docs said "source coordinate system", which only
                                coincides with the code when in_prj and out_prj describe the same system)
    :param outUL:               [X,Y] output upper left coordinates as floats in the source coordinate system
                                (requires outRowsCols)
    :param outRowsCols:         [rows, cols] (optional); the output array is also cropped to this size
    :param out_res:             output resolution as tuple of floats (x,y) in the TARGET coordinate system
    :param out_extent:          [left, bottom, right, top] as floats in the source coordinate system
    :param out_dtype:           output data type as numpy data type (int8 is promoted to int16, because int8 is not
                                supported by rasterio; an int8 input array is promoted the same way)
    :param rsp_alg:             Resampling method to use. One of the following (int, default is 0):
                                0 = nearest neighbour, 1 = bilinear, 2 = cubic, 3 = cubic spline, 4 = lanczos,
                                5 = average, 6 = mode
    :param in_nodata:           no data value of the input image (int or float)
    :param out_nodata:          no data value of the output image (int or float)
    :param outExtent_within:    Ensures that the output extent is within the input extent.
                                Otherwise an exception is raised.
    :return out_arr:            warped numpy array ([rows,cols] or [rows,cols,bands], like the input)
    :return out_gt:             warped gdal GeoTransform
    :return out_prj:            warped projection as WKT string (the out_prj argument, returned unchanged)
    :raises ValueError:         if rsp_alg is not one of the resampling codes listed above
    :raises AssertionError:     if the input checks fail (missing EPSG codes, invalid nodata types, output extent not
                                within the input extent, unexpected output size)
    :raises NotImplementedError: for warping from a geographic into a projected system without out_res
    :raises RuntimeError:       if rasterio raises a KeyError while warping
    """
    # integer resampling codes -> rasterio Resampling enum
    dict_rspInt_rspAlg = \
        {0: Resampling.nearest,      1: Resampling.bilinear, 2: Resampling.cubic,
         3: Resampling.cubic_spline, 4: Resampling.lanczos,  5: Resampling.average, 6: Resampling.mode}
    if rsp_alg not in dict_rspInt_rspAlg:
        # fail early instead of silently returning an empty (all-zero) output array
        raise ValueError('warp_ndarray: Unsupported resampling algorithm %r. Choose one of %s.'
                         % (rsp_alg, sorted(dict_rspInt_rspAlg)))

    if not ndarray.flags['OWNDATA']:
        # deep copy: converts view to its own array in order to avoid wrong output
        ndarray = ndarray.copy(order='K')

    with rasterio.env.Env():
        if outUL is not None:
            assert outRowsCols is not None, 'outRowsCols must be given if outUL is given.'
        outUL = [in_gt[0], in_gt[3]] if outUL is None else outUL

        inEPSG, outEPSG = [WKT2EPSG(prj) for prj in [in_prj, out_prj]]
        assert inEPSG,  'Could not derive input EPSG code.'
        assert outEPSG, 'Could not derive output EPSG code.'
        assert in_nodata  is None or type(in_nodata)  in [int, float]
        assert out_nodata is None or type(out_nodata) in [int, float]

        src_crs = {'init': 'EPSG:%s' % inEPSG}
        dst_crs = {'init': 'EPSG:%s' % outEPSG}

        if ndarray.ndim == 3:
            # convert input array axis order [rows,cols,bands] to rasterio axis order [bands,rows,cols]
            ndarray = np.transpose(ndarray, (2, 0, 1))
            bands, rows, cols = ndarray.shape
        else:
            rows, cols = ndarray.shape
        if outRowsCols:
            rows, cols = outRowsCols

        # set dtypes ensuring at least int16 (int8 is not supported by rasterio)
        # np.dtype() normalises strings, numpy scalar types and dtype objects, so 'int8', np.int8 and
        # np.dtype('int8') are all caught here
        in_dtype = ndarray.dtype
        out_dtype = in_dtype if out_dtype is None else out_dtype
        if np.dtype(out_dtype) == np.int8:
            out_dtype = np.int16
        if in_dtype == np.int8:
            ndarray = ndarray.astype(np.int16)

        def gt2bounds(gt, r, c):
            """Return [left, bottom, right, top] of an r x c grid described by the GDAL GeoTransform gt."""
            return [gt[0], gt[3] + r * gt[5], gt[0] + c * gt[1], gt[3]]

        # get the output bounds (input to the dst_transform calculation)
        if out_gt is None and out_extent is None:
            if outRowsCols:
                # grid starting at outUL, using the pixel size of the input
                left, bottom, right, top = gt2bounds([outUL[0], in_gt[1], 0, outUL[1], 0, in_gt[5]], rows, cols)
            else:  # outRowsCols is None and outUL is None: use in_gt
                left, bottom, right, top = gt2bounds(in_gt, rows, cols)
        elif out_extent:
            left, bottom, right, top = out_extent
        else:  # out_gt is given
            # NOTE(review): bounds are derived from in_gt, not out_gt. Since out_gt is later used as the destination
            # transform as-is, these bounds only influence the computed output size - confirm this is intended.
            left, bottom, right, top = gt2bounds(in_gt, rows, cols)

        if outExtent_within:
            # input array is only a window of the actual input array
            # NOTE(review): rows/cols may already have been replaced by outRowsCols here, so the input bounds are
            # derived from the output size (plus a one pixel tolerance) - confirm this is intended.
            assert left >= in_gt[0] and right <= (in_gt[0] + (cols + 1) * in_gt[1]) and \
                bottom >= in_gt[3] + (rows + 1) * in_gt[5] and top <= in_gt[3], \
                "out_extent has to be completely within the input image bounds."

        if out_res is None:
            # get pixel resolution in target coord system
            prj_in_out = (isProjectedOrGeographic(in_prj), isProjectedOrGeographic(out_prj))
            assert None not in prj_in_out, 'prj_in_out contains None.'
            if prj_in_out[0] == prj_in_out[1]:
                out_res = (in_gt[1], abs(in_gt[5]))
            elif prj_in_out == ('geographic', 'projected'):
                raise NotImplementedError('Different projections are currently not supported.')
            else:  # ('projected','geographic')
                px_size_LatLon = np.array(pixelToLatLon([1, 1], geotransform=in_gt, projection=in_prj)) - \
                                 np.array(pixelToLatLon([0, 0], geotransform=in_gt, projection=in_prj))
                out_res = tuple(reversed(abs(px_size_LatLon)))
                # the original author flagged this estimate for re-checking; tell the caller instead of printing
                warnings.warn('warp_ndarray: The output resolution %s was estimated from the lat/lon offset between '
                              'input pixels (0, 0) and (1, 1) and should be double-checked.' % (out_res,))

        # TODO keyword densify_pts=21 does not exist anymore (moved to transform_bounds()) -> could that be a problem?
        #      check warp outputs
        dst_transform, out_cols, out_rows = rio_calc_transform(
            src_crs, dst_crs, cols, rows, left, bottom, right, top, resolution=out_res)

        # check if calculated output dimensions correspond to expected ones and correct them if necessary
        rows_expected = int(round(abs(top - bottom) / out_res[1], 0))
        cols_expected = int(round(abs(right - left) / out_res[0], 0))
        diff_rows_exp_real, diff_cols_exp_real = abs(out_rows - rows_expected), abs(out_cols - cols_expected)
        if diff_rows_exp_real > 0.1 or diff_cols_exp_real > 0.1:
            assert diff_rows_exp_real < 1.1 and diff_cols_exp_real < 1.1, 'warp_ndarray: The output image size ' \
                                                                          'calculated by rasterio is too far away from the expected output image size.'
            # fixes an issue where rio_calc_transform() does not return quadratic output image although input
            # parameters correspond to a quadratic image and inEPSG equals outEPSG
            out_rows, out_cols = rows_expected, cols_expected

        if not out_gt:
            # convert the Affine (a, b, c, d, e, f, ...) into a GDAL GeoTransform (c, a, b, f, d, e)
            aff = list(dst_transform)
            out_gt = (aff[2], aff[0], aff[1], aff[5], aff[3], aff[4])

        if ndarray.ndim == 3:
            out_arr = np.zeros((bands, out_rows, out_cols), out_dtype)
        else:
            out_arr = np.zeros((out_rows, out_cols), out_dtype)

        # Pass proper Affine transforms instead of deprecated GDAL-style tuples (rejected by rasterio >= 1.0).
        # Affine.from_gdal() is the faithful conversion of a GDAL GeoTransform. An earlier attempt used
        # rasterio.transform.from_origin(..., in_gt[5]); from_origin negates the given pixel height itself, so passing
        # the (negative) GDAL value flipped the image - presumably the "wrong output image" seen back then.
        src_transform = rasterio.transform.Affine.from_gdal(*in_gt)
        dst_transform = rasterio.transform.Affine.from_gdal(*out_gt)

        with warnings.catch_warnings():
            # keep silencing warnings emitted during warping (presumably deprecation warnings, e.g. for the
            # {'init': ...} CRS form), as before
            warnings.simplefilter('ignore')
            try:
                rio_reproject(ndarray, out_arr,
                              src_transform=src_transform, src_crs=src_crs,
                              dst_transform=dst_transform, dst_crs=dst_crs,
                              resampling=dict_rspInt_rspAlg[rsp_alg], src_nodata=in_nodata, dst_nodata=out_nodata)
            except KeyError:
                # do not return the untouched all-zero array as if warping had succeeded
                raise RuntimeError('warp_ndarray: rasterio raised a KeyError while warping (input dtype: %s, output '
                                   'dtype: %s). The data type may not be supported by rasterio.'
                                   % (ndarray.dtype, out_arr.dtype))

        # convert output array axis order [bands,rows,cols] back to GMS axis order [rows,cols,bands]
        if out_arr.ndim == 3:
            out_arr = np.transpose(out_arr, (1, 2, 0))

        if outRowsCols:
            out_arr = out_arr[:outRowsCols[0], :outRowsCols[1]]

    return out_arr, out_gt, out_prj