Commit 461c4953 authored by Daniel Scheffler

PEP-8 editing. Added style checkers.

Former-commit-id: b6a2f1b6
Former-commit-id: f867b21f
parent 31dc3c42
before_script:
- git lfs pull
stages:
- test
- deploy
test_gms_preprocessing:
stage: test
script:
- source /root/anaconda3/bin/activate
- export GDAL_DATA=/root/anaconda3/share/gdal
......@@ -17,7 +24,24 @@ test_gms_preprocessing:
- nosetests.xml
when: always
pages:
test_styles:
stage: test
script:
- source /root/anaconda3/bin/activate
- export GDAL_DATA=/root/anaconda3/share/gdal
- export PYTHONPATH=$PYTHONPATH:/root # /root <- directory needed later
- pip install flake8 pycodestyle pylint pydocstyle # TODO remove as soon as docker container is rebuilt
- make lint
artifacts:
paths:
- tests/linting/flake8.log
- tests/linting/pycodestyle.log
- tests/linting/pydocstyle.log
when: always
deploy_pages:
stage: deploy
dependencies:
- test_gms_preprocessing
......@@ -28,7 +52,6 @@ pages:
- cp nosetests.* public/nosetests_reports/
- mkdir -p public/doc
- cp -r docs/_build/html/* public/doc/
artifacts:
paths:
- public
......
......@@ -50,7 +50,9 @@ clean-test: ## remove test and coverage artifacts
rm -fr nosetests.xml
lint: ## check style with flake8
flake8 gms_preprocessing tests
flake8 --max-line-length=120 gms_preprocessing tests > ./tests/linting/flake8.log
pycodestyle gms_preprocessing --exclude="*.ipynb,*.ipynb*,envifilehandling.py" --max-line-length=120 > ./tests/linting/pycodestyle.log
-pydocstyle gms_preprocessing > ./tests/linting/pydocstyle.log
test: ## run tests quickly with the default Python
python setup.py test
......
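The new lint target above chains flake8, pycodestyle and pydocstyle and redirects each report into tests/linting/. A minimal sketch of driving the same checks from Python instead of make (tool options copied from the Makefile; the run_checks helper and the log layout are assumptions, not part of the repository):

import subprocess
from pathlib import Path

LOG_DIR = Path("tests/linting")
CHECKS = {
    "flake8.log": ["flake8", "--max-line-length=120", "gms_preprocessing", "tests"],
    "pycodestyle.log": ["pycodestyle", "gms_preprocessing",
                        "--exclude=*.ipynb,*.ipynb*,envifilehandling.py", "--max-line-length=120"],
    "pydocstyle.log": ["pydocstyle", "gms_preprocessing"],
}

def run_checks():
    # mirror the Makefile target: run each checker and capture its output in a log file
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    for logname, cmd in CHECKS.items():
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        (LOG_DIR / logname).write_bytes(result.stdout)

if __name__ == "__main__":
    run_checks()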
......@@ -4,12 +4,12 @@ import os
if 'MPLBACKEND' not in os.environ:
os.environ['MPLBACKEND'] = 'Agg'
from . import algorithms
from . import io
from . import misc
from . import processing
from . import config
from .processing.process_controller import process_controller
from . import algorithms # noqa: E402
from . import io # noqa: E402
from . import misc # noqa: E402
from . import processing # noqa: E402
from . import config # noqa: E402
from .processing.process_controller import process_controller # noqa: E402
__author__ = """Daniel Scheffler"""
__email__ = 'daniel.scheffler@gfz-potsdam.de'
......
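The # noqa: E402 markers added above silence flake8's "module level import not at top of file" warning, which cannot be avoided here: MPLBACKEND has to be set before anything imports matplotlib. A stripped-down sketch of the same pattern (the matplotlib import is illustrative, not taken from the package):

import os

# the environment variable must be exported before the first matplotlib import
if 'MPLBACKEND' not in os.environ:
    os.environ['MPLBACKEND'] = 'Agg'  # force a non-interactive backend

import matplotlib.pyplot as plt  # noqa: E402  (intentionally placed after the os.environ setup)

print(plt.get_backend())  # reports the backend chosen above, unless MPLBACKEND was already set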
......@@ -416,12 +416,13 @@ class L1A_object(GMS_object):
if ds:
sds_md = ds.GetMetadata('SUBDATASETS')
data_arr = None
for bidx, b in enumerate(self.LayerBandsAssignment):
sds_name = [i for i in sds_md.values() if '%s_Band%s:ImageData' % (subsystem_identifier, b) in str(i) or
'%s_Swath:ImageData%s' % (subsystem_identifier, b) in str(i)][0]
data = gdalnumeric.LoadFile(sds_name)
data_arr = np.empty(data.shape + (len(self.LayerBandsAssignment),),
data.dtype) if bidx == 0 else data_arr
if bidx == 0:
data_arr = np.empty(data.shape + (len(self.LayerBandsAssignment),), data.dtype)
data_arr[:, :, bidx] = data
if CFG.job.exec_mode == 'Flink' and path_output is None: # numpy array output
......@@ -442,14 +443,16 @@ class L1A_object(GMS_object):
if subsystem_identifier in str(ds.dimensions()) and 'ImagePixel' in str(ds.dimensions()):
list_matching_dsIdx.append(i)
i += 1
except:
except Exception:
break
list_matching_dsIdx = list_matching_dsIdx[:3] if self.subsystem == 'VNIR1' else \
[list_matching_dsIdx[-1]] if self.subsystem == 'VNIR2' else list_matching_dsIdx
data_arr = None
for i, dsIdx in enumerate(list_matching_dsIdx):
data = hdfFile.select(dsIdx)[:]
data_arr = np.empty(data.shape + (len(self.LayerBandsAssignment),), data.dtype) if i == 0 else data_arr
if i == 0:
data_arr = np.empty(data.shape + (len(self.LayerBandsAssignment),), data.dtype)
data_arr[:, :, i] = data
if CFG.job.exec_mode == 'Flink' and path_output is None: # numpy array output
......@@ -462,7 +465,7 @@ class L1A_object(GMS_object):
self.logger.error('Missing HDF4 support. Reading HDF file failed.')
raise ImportError('No suitable library for reading HDF4 data available.')
ds = None
del ds
def import_metadata(self, v=False):
"""Reads metainformation of the given file from the given ASCII metafile.
......@@ -570,7 +573,7 @@ class L1A_object(GMS_object):
if conv == 'Rad':
"""http://s2tbx.telespazio-vega.de/sen2three/html/r2rusage.html?highlight=quantification182
rToa = (float)(DN_L1C_band / QUANTIFICATION_VALUE);
L = (rToa * e0__SOLAR_IRRADIANCE_For_band * cos(Z__Sun_Angles_Grid_Zenith_Values)) /
L = (rToa * e0__SOLAR_IRRADIANCE_For_band * cos(Z__Sun_Angles_Grid_Zenith_Values)) /
(PI * U__earth_sun_distance_correction_factor);
L = (U__earth_sun_distance_correction_factor * rToa * e0__SOLAR_IRRADIANCE_For_band * cos(
Z__Sun_Angles_Grid_Zenith_Values)) / PI;"""
......@@ -650,7 +653,7 @@ class L1A_object(GMS_object):
if False in [self.GeoAlign_ok, self.GeoTransProj_ok]:
previous_dataname = self.MetaObj.Dataname
if hasattr(self, 'arr') and isinstance(self.arr, (GeoArray, np.ndarray)) and \
self.MetaObj.Dataname.startswith('/vsi'):
self.MetaObj.Dataname.startswith('/vsi'):
outP = os.path.join(self.ExtractedFolder, self.baseN + '__' + self.arr_desc)
# FIXME ineffective but needed as long as georeference_by_TieP_or_inherent_GCPs does not support
# FIXME direct array inputs
......@@ -721,21 +724,26 @@ class L1A_object(GMS_object):
mask_clouds = None # FIXME
else:
# FIXME Landsat cloud mask pixel values are currently not compatible to definition_dicts.get_mask_classdefinition
# append /<GeoMultiSensRepo>/algorithms to PATH in order to properly import py_tools_ah when unpickling cloud classifiers
# FIXME Landsat cloud mask pixel values are currently not compatible to
# FIXME definition_dicts.get_mask_classdefinition
# append /<GeoMultiSensRepo>/algorithms to PATH in order to properly import py_tools_ah when unpickling
# cloud classifiers
sys.path.append(
os.path.join(os.path.dirname(__file__))) # FIXME handle py_tools_ah as normal external dependency
# in_mem = hasattr(self, 'arr') and isinstance(self.arr, np.ndarray)
# if in_mem:
# (rS, rE), (cS, cE) = self.arr_pos if self.arr_pos else ((0,self.shape_fullArr[0]),(0,self.shape_fullArr[1]))
# bands = self.arr.shape[2] if len(self.arr.shape) == 3 else 1
# (rS, rE), (cS, cE) = \
# self.arr_pos if self.arr_pos else ((0,self.shape_fullArr[0]),(0,self.shape_fullArr[1]))
# bands = self.arr.shape[2] if len(self.arr.shape) == 3 else 1
# else:
# subset = subset if subset else ['block', self.arr_pos] if self.arr_pos else ['cube', None]
# bands, rS, rE, cS, cE = list(GEOP.get_subsetProps_from_subsetArg(self.shape_fullArr, subset).values())[2:7]
# bands, rS, rE, cS, cE = \
# list(GEOP.get_subsetProps_from_subsetArg(self.shape_fullArr, subset).values())[2:7]
# arr_isPath = isinstance(self.arr, str) and os.path.isfile(self.arr) # FIXME
# inPath = self.arr if arr_isPath else self.MetaObj.Dataname if \
# (hasattr(self,'MetaObj') and self.MetaObj) else self.meta_odict['Dataname'] # FIXME replace with path generator?
# # FIXME replace with path generator?:
# inPath = self.arr if arr_isPath else self.MetaObj.Dataname if \
# (hasattr(self,'MetaObj') and self.MetaObj) else self.meta_odict['Dataname']
if not self.path_cloud_class_obj or self.satellite == 'Sentinel-2A': # FIXME don't exclude S2 here
self.log_for_fullArr_or_firstTile('Cloud masking is not yet implemented for %s %s...'
......@@ -754,9 +762,10 @@ class L1A_object(GMS_object):
# logger.info("Cloud mask missing -> derive own cloud mask.")
# CldMsk = CloudMask(logger=logger, persistence_file=options["cld_mask"]["persistence_file"],
# processing_tiles=options["cld_mask"]["processing_tiles"])
# s2img.mask_clouds = CldMsk(S2_img=s2img, target_resolution=options["cld_mask"]["target_resolution"],
# majority_filter_options=options["cld_mask"]["majority_mask_filter"],
# nodata_value=options["cld_mask"]['nodata_value_mask'])
# s2img.mask_clouds = \
# CldMsk(S2_img=s2img, target_resolution=options["cld_mask"]["target_resolution"],
# majority_filter_options=options["cld_mask"]["majority_mask_filter"],
# nodata_value=options["cld_mask"]['nodata_value_mask'])
# del CldMsk
self.GMS_identifier['logger'] = self.logger
......@@ -788,21 +797,21 @@ class L1A_object(GMS_object):
for i, class_path in zip(range(0, 2 * len(pathlist_cloud_class_obj), 2), pathlist_cloud_class_obj):
categories_timinggroup_timing[i:i + 1, 0] = os.path.splitext(os.path.basename(class_path))[0]
t1 = time.time()
CLD_obj = CLD_P.GmsCloudClassifier(classifier=class_path)
# CLD_obj = CLD_P.GmsCloudClassifier(classifier=class_path)
categories_timinggroup_timing[i, 1] = "import time"
categories_timinggroup_timing[i, 2] = time.time() - t1
t2 = time.time()
mask_clouds = CLD_obj(self)
# mask_clouds = CLD_obj(self)
categories_timinggroup_timing[i + 1, 1] = "processing time"
categories_timinggroup_timing[i + 1, 2] = time.time() - t2
classifiers = np.unique(categories_timinggroup_timing[:, 0])
categories = np.unique(categories_timinggroup_timing[:, 1])
# classifiers = np.unique(categories_timinggroup_timing[:, 0])
# categories = np.unique(categories_timinggroup_timing[:, 1])
plt.ioff()
fig = plt.figure()
ax = fig.add_subplot(111)
space = 0.3
n = len(classifiers)
width = (1 - space) / (len(classifiers))
# ax = fig.add_subplot(111)
# space = 0.3
# n = len(classifiers)
# width = (1 - space) / (len(classifiers))
# for i,classif in enumerate(classifiers): # FIXME
# vals = dpoints[dpoints[:,0] == cond][:,2].astype(np.float)
# pos = [j - (1 - space) / 2. + i * width for j in range(1,len(categories)+1)]
......@@ -951,7 +960,8 @@ class L1A_object(GMS_object):
self.arr.gt = mapinfo2geotransform(self.MetaObj.map_info)
self.arr.prj = self.MetaObj.projection
self.mask_nodata.gt = self.arr.gt # must be set here because nodata mask has been computed from self.arr without geoinfos
# must be set here because nodata mask has been computed from self.arr without geoinfos:
self.mask_nodata.gt = self.arr.gt
self.mask_nodata.prj = self.arr.prj
def update_spec_vals_according_to_dtype(self, dtype=None):
......
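Two of the L1A_object hunks above replace a one-line conditional allocation (data_arr = np.empty(...) if bidx == 0 else data_arr) with an explicit if block, which is easier to read and keeps lines under the 120-character limit. The underlying band-stacking pattern in isolation, with band loading faked (the real code reads each band from GDAL/HDF4 subdatasets):

import numpy as np

def stack_bands(band_readers):
    # band_readers: sequence of callables, each returning one 2D band array
    data_arr = None
    for bidx, read_band in enumerate(band_readers):
        data = read_band()
        if bidx == 0:
            # shape and dtype are only known once the first band has been read
            data_arr = np.empty(data.shape + (len(band_readers),), data.dtype)
        data_arr[:, :, bidx] = data
    return data_arr

# usage with fake bands standing in for the GDAL/HDF4 reads
bands = [lambda: np.random.randint(0, 255, (10, 10), dtype=np.uint8) for _ in range(3)]
assert stack_bands(bands).shape == (10, 10, 3)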
......@@ -208,7 +208,7 @@ class Scene_finder(object):
self.plusminus_years = plusminus_years
# get temporal constraints
add_years = lambda dt, years: dt.replace(dt.year + years) \
def add_years(dt, years): return dt.replace(dt.year + years) \
if not (dt.month == 2 and dt.day == 29) else dt.replace(dt.year + years, 3, 1)
self.timeStart = add_years(self.src_AcqDate, -plusminus_years)
timeEnd = add_years(self.src_AcqDate, +plusminus_years)
......@@ -242,7 +242,9 @@ class Scene_finder(object):
GDF['acquisitiondate'] = list(GDF['object'].map(lambda scene: scene.acquisitiondate))
GDF['cloudcover'] = list(GDF['object'].map(lambda scene: scene.cloudcover))
GDF['polyLonLat'] = list(GDF['object'].map(lambda scene: scene.polyLonLat))
LonLat2UTM = lambda polyLL: reproject_shapelyGeometry(polyLL, 4326, self.src_prj)
def LonLat2UTM(polyLL): return reproject_shapelyGeometry(polyLL, 4326, self.src_prj)
GDF['polyUTM'] = list(GDF['polyLonLat'].map(LonLat2UTM))
self.GDF_ref_scenes = GDF
......@@ -271,7 +273,7 @@ class Scene_finder(object):
GDF = self.GDF_ref_scenes
if not GDF.empty:
# get overlap parameter
get_OL_prms = lambda poly: get_overlap_polygon(poly, self.src_footprint_poly)
def get_OL_prms(poly): return get_overlap_polygon(poly, self.src_footprint_poly)
GDF['overlapParams'] = list(GDF['polyLonLat'].map(get_OL_prms))
GDF['overlap area'] = list(GDF['overlapParams'].map(lambda OL_prms: OL_prms['overlap area']))
GDF['overlap percentage'] = list(GDF['overlapParams'].map(lambda OL_prms: OL_prms['overlap percentage']))
......@@ -285,9 +287,9 @@ class Scene_finder(object):
GDF = self.GDF_ref_scenes
if not GDF.empty:
# get processing level of reference scenes
query_procL = lambda sceneID: \
DB_T.get_info_from_postgreSQLdb(CFG.job.conn_database, 'scenes_proc', ['proc_level'],
{'sceneid': sceneID})
def query_procL(sceneID):
return DB_T.get_info_from_postgreSQLdb(CFG.job.conn_database, 'scenes_proc', ['proc_level'],
{'sceneid': sceneID})
GDF['temp_queryRes'] = list(GDF['sceneid'].map(query_procL))
GDF['proc_level'] = list(GDF['temp_queryRes'].map(lambda queryRes: queryRes[0][0] if queryRes else None))
GDF.drop('temp_queryRes', axis=1, inplace=True)
......@@ -300,40 +302,46 @@ class Scene_finder(object):
if not GDF.empty:
# get path of binary file and check if the corresponding dataset exists
GDF = self.GDF_ref_scenes
get_path_binary = lambda GDF_row: \
PG.path_generator(scene_ID=GDF_row['sceneid'], proc_level=GDF_row['proc_level']).get_path_imagedata()
check_exists = lambda path: os.path.exists(path)
def get_path_binary(GDF_row):
return PG.path_generator(scene_ID=GDF_row['sceneid'], proc_level=GDF_row['proc_level'])\
.get_path_imagedata()
def check_exists(path): return os.path.exists(path)
GDF['path_ref'] = GDF.apply(lambda GDF_row: get_path_binary(GDF_row), axis=1)
GDF['refDs_exists'] = list(GDF['path_ref'].map(check_exists))
# filter scenes out where the corresponding dataset does not exist on fileserver
self.GDF_ref_scenes = GDF[GDF['refDs_exists'] == True]
self.GDF_ref_scenes = GDF[GDF['refDs_exists']]
def _filter_by_entity_ID_availability(self):
GDF = self.GDF_ref_scenes
if not GDF.empty:
# check if a proper entity ID can be gathered from database
query_eID = lambda sceneID: DB_T.get_info_from_postgreSQLdb(CFG.job.conn_database, 'scenes', ['entityid'],
{'id': sceneID}, records2fetch=1)
def query_eID(sceneID):
return DB_T.get_info_from_postgreSQLdb(CFG.job.conn_database, 'scenes', ['entityid'],
{'id': sceneID}, records2fetch=1)
GDF['temp_queryRes'] = list(GDF['sceneid'].map(query_eID))
GDF['entityid'] = list(GDF['temp_queryRes'].map(lambda queryRes: queryRes[0][0] if queryRes else None))
GDF.drop('temp_queryRes', axis=1, inplace=True)
# filter scenes out that have no entity ID (database errors)
self.GDF_ref_scenes = GDF[GDF['refDs_exists'] == True]
self.GDF_ref_scenes = GDF[GDF['refDs_exists']]
def _filter_by_projection(self):
GDF = self.GDF_ref_scenes
if not GDF.empty:
# compare projections of target and reference image
from ..io.Input_reader import read_ENVIhdr_to_dict
get_prj = lambda path_binary: \
read_ENVIhdr_to_dict(os.path.splitext(path_binary)[0] + '.hdr')['coordinate system string']
is_prj_equal = lambda path_binary: prj_equal(self.src_prj, get_prj(path_binary))
def get_prj(path_binary):
return read_ENVIhdr_to_dict(os.path.splitext(path_binary)[0] + '.hdr')['coordinate system string']
def is_prj_equal(path_binary): return prj_equal(self.src_prj, get_prj(path_binary))
GDF['prj_equal'] = list(GDF['path_ref'].map(is_prj_equal))
# filter scenes out that have a different projection
self.GDF_ref_scenes = GDF[GDF['prj_equal'] == True]
self.GDF_ref_scenes = GDF[GDF['prj_equal']]
class ref_Scene:
......@@ -431,13 +439,14 @@ class L1B_object(L1A_object):
% (date_minmax[0].month, date_minmax[0].day, date_minmax[1].month, date_minmax[1].day)
# TODO add further criteria!
query_scenes = lambda condlist: DB_T.get_overlapping_scenes_from_postgreSQLdb(
CFG.job.conn_database,
table='scenes',
tgt_corners_lonlat=self.trueDataCornerLonLat,
conditions=condlist,
add_cmds='ORDER BY scenes.cloudcover ASC',
timeout=30000)
def query_scenes(condlist):
return DB_T.get_overlapping_scenes_from_postgreSQLdb(
CFG.job.conn_database,
table='scenes',
tgt_corners_lonlat=self.trueDataCornerLonLat,
conditions=condlist,
add_cmds='ORDER BY scenes.cloudcover ASC',
timeout=30000)
conds_descImportance = [dataset_cond, cloudcov_cond, dayrange_cond]
self.logger.info('Querying database in order to find a suitable reference scene for co-registration.')
......@@ -491,7 +500,7 @@ class L1B_object(L1A_object):
break
# start download of scene data not available and start L1A processing
dl_cmd = lambda scene_ID: print('%s %s %s' % (
def dl_cmd(scene_ID): print('%s %s %s' % (
CFG.job.java_commands['keyword'].strip(), # FIXME CFG.job.java_commands is deprecated
CFG.job.java_commands["value_download"].strip(), scene_ID))
......@@ -605,7 +614,9 @@ class L1B_object(L1A_object):
for idx, cwl, fwhm in zip(range(len(shift_cwl)), shift_cwl, shift_fwhm):
if shift_bbl[idx]:
continue # skip cwl if it is declared as bad band above
is_inside = lambda r_cwl, s_cwl, s_fwhm: s_cwl - s_fwhm / 2 < r_cwl < s_cwl + s_fwhm / 2
def is_inside(r_cwl, s_cwl, s_fwhm): return s_cwl - s_fwhm / 2 < r_cwl < s_cwl + s_fwhm / 2
matching_r_cwls = [r_cwl for i, r_cwl in enumerate(ref_cwl) if
is_inside(r_cwl, cwl, fwhm) and not ref_bbl[i]]
if matching_r_cwls:
......
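Most edits in this file replace lambdas bound to a name with small def functions, which is what PEP 8 (flake8 code E731) asks for. A minimal sketch of the rule, using the is_inside helper from the last hunk above:

# E731: do not assign a lambda expression, use a def.
# before (flagged by flake8):
#   is_inside = lambda r_cwl, s_cwl, s_fwhm: s_cwl - s_fwhm / 2 < r_cwl < s_cwl + s_fwhm / 2
# after (as applied above):
def is_inside(r_cwl, s_cwl, s_fwhm):
    # True if the reference center wavelength lies within the shifted band's FWHM window
    return s_cwl - s_fwhm / 2 < r_cwl < s_cwl + s_fwhm / 2

assert is_inside(550, 552, 10)       # the 547-557 nm window contains 550 nm
assert not is_inside(540, 552, 10)   # 540 nm falls outside it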
......@@ -6,15 +6,10 @@ import re
import logging
import dill
import traceback
from typing import List
from typing import List, TypeVar
import numpy as np
try:
from osgeo import osr
except ImportError:
import osr
from geoarray import GeoArray
from py_tools_ds.geo.map_info import mapinfo2geotransform
......@@ -165,7 +160,8 @@ class L1C_object(L1B_object):
:return:
"""
if self._SAA_arr is None:
_ = self.SZA_arr # getter also sets self._SAA_arr
# noinspection PyStatementEffect
self.SZA_arr # getter also sets self._SAA_arr
return self._SAA_arr
@SAA_arr.setter
......@@ -201,6 +197,9 @@ class L1C_object(L1B_object):
del self.dem
_T_list_L1Cobjs = TypeVar(List[L1C_object])
class AtmCorr(object):
def __init__(self, *L1C_objs, reporting=False):
"""Wrapper around atmospheric correction by Andre Hollstein, GFZ Potsdam
......@@ -227,7 +226,7 @@ class AtmCorr(object):
assert len(list(set(scene_IDs))) == 1, \
"Input GMS objects for 'AtmCorr' must all belong to the same scene ID!. Received %s." % scene_IDs
self.inObjs = L1C_objs # type: List[L1C_object]
self.inObjs = L1C_objs # type: _T_list_L1Cobjs
self.reporting = reporting
self.ac_input = {} # set by self.run_atmospheric_correction()
self.results = None # direct output of external atmCorr module (set by run_atmospheric_correction)
......@@ -850,7 +849,7 @@ class AtmCorr(object):
# FIXME really set AC nodata values to GMS outZero?
surf_refl[nodata] = oZ_refl # overwrite AC nodata values with GMS outZero
# apply the original nodata mask (indicating background values)
surf_refl[np.array(inObj.mask_nodata) == False] = oF_refl
surf_refl[np.array(inObj.mask_nodata).astype(np.int8) == 0] = oF_refl
if self.results.bad_data_value is np.nan:
surf_refl[np.isnan(surf_refl)] = oF_refl
......
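The last hunk above replaces a comparison against False (flake8 E712) with an explicit integer cast when applying the nodata mask. A small sketch of that masking step (array contents and fill value are illustrative):

import numpy as np

surf_refl = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)  # stand-in reflectance tile
mask_nodata = np.array([[True, False], [True, True]])             # True marks valid data pixels
oF_refl = -9999.0                                                 # illustrative fill value

# background pixels (mask is False) are overwritten with the fill value
surf_refl[mask_nodata.astype(np.int8) == 0] = oF_refl
# equivalent on a plain boolean array, and what E712 nudges towards:
# surf_refl[~mask_nodata] = oF_refl
print(surf_refl)  # [[0.1, -9999.], [0.3, 0.4]]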
......@@ -7,6 +7,7 @@ import numpy as np
from scipy.interpolate import interp1d
import scipy as sp
import matplotlib.pyplot as plt
from typing import TypeVar
from ..config import GMS_config as CFG
from ..io.Input_reader import SRF
......@@ -14,6 +15,8 @@ from .L2A_P import L2A_object
__author__ = 'Daniel Scheffler'
_T_SRF = TypeVar(SRF)
class L2B_object(L2A_object):
def __init__(self, L2A_obj=None):
......@@ -22,7 +25,7 @@ class L2B_object(L2A_object):
if L2A_obj:
# populate attributes
[setattr(self, key, value) for key,value in L2A_obj.__dict__.items()]
[setattr(self, key, value) for key, value in L2A_obj.__dict__.items()]
self.proc_level = 'L2B'
......@@ -38,7 +41,7 @@ class L2B_object(L2A_object):
# TODO better band names for homogenized product -> include in get_LayerBandsAssignment
self.LayerBandsAssignment = []
self.arr = self.interpolate_cube_linear(self.arr,src_cwls,tgt_cwls) if kind == 'linear' else self.arr
self.arr = self.interpolate_cube_linear(self.arr, src_cwls, tgt_cwls) if kind == 'linear' else self.arr
self.meta_odict['wavelength'] = list(tgt_cwls)
self.meta_odict['bands'] = len(tgt_cwls)
......@@ -64,7 +67,7 @@ class SpectralResampler(object):
"""Class for spectral resampling of a single spectral signature (1D-array) or an image (3D-array)."""
def __init__(self, wvl_src, srf_tgt, wvl_unit='nanometers'):
# type: (np.ndarray, SRF, str) -> None
# type: (np.ndarray, _T_SRF, str) -> None
"""Get an instance of the SpectralResampler1D class.
:param wvl_src: center wavelength positions of the source spectrum
......
......@@ -16,9 +16,8 @@ class GmsCloudClassifier(object):
"""
if type(classifier) is str:
with open(classifier, "rb") as fl:
_ = dill.load(fl)
self.classifier = dill.load(fl)
with open(classifier, "rb") as F:
self.classifier = dill.load(F)
else:
self.classifier = copy(classifier)
......@@ -35,7 +34,7 @@ class GmsCloudClassifier(object):
if __name__ == "__main__":
print("Start Test")
from glob import glob
from matplotlib.pyplot import *
from matplotlib.pyplot import imshow, colorbar, savefig, close, figure
import sys
from datetime import datetime
......
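The constructor above previously called dill.load twice on the same file handle, discarding the first result; the rewritten version opens the file once and loads the classifier exactly once. A minimal round-trip sketch (the file name and dummy class are illustrative):

import dill

class DummyClassifier(object):
    def __call__(self, gms_obj):
        return 0  # stand-in for a real cloud mask prediction

with open("classifier.dill", "wb") as fl:
    dill.dump(DummyClassifier(), fl)

with open("classifier.dill", "rb") as fl:
    classifier = dill.load(fl)  # a single load is enough; a second load here would raise EOFError

print(classifier(None))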
......@@ -2,13 +2,13 @@ from copy import copy
from random import sample
from operator import itemgetter
import random
from inspect import getargspec # FIXME
from inspect import getargspec # FIXME
import numpy as np
from scipy.ndimage.filters import gaussian_filter
from sklearn.ensemble import AdaBoostClassifier
from sklearn.feature_selection import chi2
from sklearn.cross_validation import train_test_split # FIXME
from sklearn.cross_validation import train_test_split # FIXME
__author__ = "Andre Hollstein"
......@@ -113,8 +113,9 @@ class ToClassifierRnd(_ToClassifierBase):
self.chls = np.zeros(self.n_channels, dtype=np.int)
if channel_selection == "equal":
for ii, (i1, i2) in enumerate(zip(*(lambda x: (x[:-1], x[1:]))(
np.linspace(0, self.n_channels_data, self.n_channels + 1, dtype=np.int)))):
for ii, (i1, i2) in \
enumerate(zip(*(lambda x: (x[:-1], x[1:]))(np.linspace(0, self.n_channels_data,
self.n_channels + 1, dtype=np.int)))):
self.chls[ii] = i1 + np.argmax(self.scores[i1:i2])
else:
raise ValueError("wrong value for channel_selection:%s" % channel_selection)
......@@ -127,7 +128,7 @@ class ToClassifierRnd(_ToClassifierBase):
self.classifiers_id = [np.array(sample(
list(self.chls), self.number_of_arguments(self.clf_functions[func_key])), dtype=np.int)
for func_key in self.classifiers_fk]
for func_key in self.classifiers_fk]
@staticmethod
def number_of_arguments(func):
......@@ -246,15 +247,15 @@ class ClassicalBayesian(object):
left = smoth["min"]
right = smoth["max"]
_ = self.__set__(xx_train, yy_train, smoth=left, sample_weight=sample_weight)
_ = self.__set__(xx_train, yy_train, smoth=left, sample_weight=sample_weight) # noqa F841 unused
tst_left = self.__test__(xx_test, yy_test, sample_weight=sample_weight)
_ = self.__set__(xx_train, yy_train, smoth=right, sample_weight=sample_weight)
_ = self.__set__(xx_train, yy_train, smoth=right, sample_weight=sample_weight) # noqa F841 unused
tst_right = self.__test__(xx_test, yy_test, sample_weight=sample_weight)
for i_steps in range(smoth["steps"]):
middle = 0.5 * (left + right)
_ = self.__set__(xx_train, yy_train, smoth=middle, sample_weight=sample_weight)
_ = self.__set__(xx_train, yy_train, smoth=middle, sample_weight=sample_weight) # noqa F841 unused
tst_middle = self.__test__(xx_test, yy_test, sample_weight=sample_weight)
if smoth["debug"] is True:
......@@ -268,10 +269,9 @@ class ClassicalBayesian(object):
left = copy(middle)
tst_left = copy(tst_middle)
if tst_left > tst_middle:
_ = self.__set__(xx_train, yy_train, smoth=left, sample_weight=sample_weight)
_ = self.__set__(xx_train, yy_train, smoth=left, sample_weight=sample_weight) # noqa F841 unused
if tst_right > tst_middle:
_ = self.__set__(xx_train, yy_train, smoth=right, sample_weight=sample_weight)
_ = self.__set__(xx_train, yy_train, smoth=right, sample_weight=sample_weight) # noqa F841 unused
else: # scalar or None
tst = self.__set__(xx, yy, smoth=smoth, sample_weight=sample_weight)
......@@ -373,12 +373,11 @@ if __name__ == "__main__":
chls = np.linspace(-3.0, 3.0, n_channels)
def rr(loc=1.0, s=1.0):
return np.random.normal(loc=loc, scale=s)
def rr(l=1.0, s=1.0):
return np.random.normal(loc=l, scale=s)
def norm(x): return x / np.max(np.abs(x))
norm = lambda x: x / np.max(np.abs(x))
funcs = {
10.0: lambda x: norm(rr(1.0, 1.0) + rr(1.0, 1.0) * chls),
# 39.0: lambda x:norm(rr(1.0,1.0)+rr(1.0,1.0)*chls),
......@@ -419,7 +418,6 @@ if __name__ == "__main__":
print("Train Data:")
_ = print_corr(clf, xx_train, yy_train)
print("Define AdaBoos Classical Bayesian")
adb = AdaBoostClassifier(base_estimator=clf, n_estimators=3)
adb.fit(X=xx_train, y=yy_train)
......
......@@ -9,8 +9,8 @@ import re
import tarfile
import warnings