Commit 4a3d5022 authored by Niklas Bohn
Browse files

Updated download link and file size of S2 novelty detector and unpinned scikit-learn version.


Signed-off-by: nbohn <nbohn@gfz-potsdam.de>
parent 8a5be05b
......@@ -8,6 +8,7 @@ History
New features:
* 'make lint' now directly prints errors instead of only logging them to logfiles.
* Automatic retraining of the S2 novelty detector if the pretrained scikit-learn random forest model is out of date.
Bugfixes:
* Pinned gdal to version<=3.1.2 to avoid import error.
......
......@@ -7,6 +7,8 @@ import logging
import numpy as np
import json
import pickle
import gdown
import h5py
class CloudMask(S2cB):
......@@ -53,9 +55,41 @@ class CloudMask(S2cB):
nvc_data["clf"]
))(json.load(fl))
elif file_ext == "pkl":
try:
with open(novelty_detector, "rb") as fl:
nvc = pickle.load(fl)
ncv_clf = pickle.load(fl)
fl.close()
except ModuleNotFoundError:
# download training data base for novelty detector from google drive
logger.info("Novelty detector has to be updated with a newer version of scikit-learn.")
logger.info("Download training data base from google drive.")
url = "https://drive.google.com/uc?export=download&id=1PlJ84GGbQXM5NNSmkOn2Mg3WMy1UbPgI"
db_nv_fn = "noclear_novelty_detector_channel2_difference9_0_index10_1_channel12_index1_8.h5"
db_nv_path = novelty_detector.split("noclear")[0] + db_nv_fn
gdown.download(url, db_nv_path, quiet=False)
# retrain novelty detector
logger.info("Retrain novelty detector with updated version of scikit-learn.")
from sklearn.ensemble import RandomForestClassifier
with h5py.File(db_nv_path, "r") as h5f:
nv = RandomForestClassifier(**dict(h5f["xx"].attrs.items()))
nv.fit(X=h5f["xx"], y=h5f["yy"])
bf = json.loads(h5f.attrs["clf"])
with open(novelty_detector, "wb") as fl:
pickle.dump(nv, fl)
pickle.dump({"id": [tuple(ids) for ids in bf["id"]], "fk": [str(ids) for ids in bf["fk"]]},
fl)
fl.close()
h5f.close()
# reload novelty detector
with open(novelty_detector, "rb") as fl:
nvc = pickle.load(fl)
ncv_clf = pickle.load(fl)
fl.close()
else:
raise ValueError("Novelty detector file type not implemented")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment