Commit 9e31a1cc authored by Marius Kriegerowski's avatar Marius Kriegerowski

util routine

parent 782084ea
......@@ -257,7 +257,8 @@ class PileData(DataGenerator):
class OnTheFlyData(DataGenerator):
class GFData(DataGenerator):
swarm = swarm.Swarm.T()
gf_engine = Engine.T()
n_sources = Int.T(default=100)
onset_phase = String.T(default='p')
......@@ -294,8 +295,6 @@ class OnTheFlyData(DataGenerator):
return (n, e, source.depth)
def generate(self):
swarm = synthi.setup(self.gf_engine, self.n_sources)
sources = swarm.get_effective_sources()
self.tensor_shape = (len(self.targets), self.n_samples_max)
......
......@@ -14,6 +14,11 @@ import shutil
logger = logging.getLogger('pinky.model')
def delete_if_exists(dirname):
if os.path.exists(dirname):
logger.info('deleting directory: %s' % dirname)
shutil.rmtree(dirname)
class Model(Object):
data_generator = DataGeneratorBase.T()
......@@ -29,16 +34,19 @@ class Model(Object):
def __init__(self, tf_config=None, debug=False, **kwargs):
super().__init__(**kwargs)
if self.auto_clear and os.path.exists(self.summary_outdir):
logger.info('deleting directory: %s' % self.summary_outdir)
shutil.rmtree(self.summary_outdir)
logger.info('deleting directory: %s' % self.outdir)
shutil.rmtree(self.outdir)
if self.auto_clear:
delete_if_exists(self.summary_outdir)
delete_if_exists(self.outdir)
self.tf_config = tf_config
self.debug = debug
self.sess = tf.Session(config=tf_config)
# initializer = tf.truncated_normal_initializer(
self.initializer = tf.random_normal_initializer(
# mean=0.5, stddev=0.1)
mean=0.0, stddev=0.1)
def generate_input(self):
dataset = self.data_generator.get_dataset()
dataset = dataset.batch(self.batch_size)
......@@ -55,6 +63,7 @@ class Model(Object):
CNN along horizontal axis
:param cross_channel_kernel: convolution kernel size accross channels
:param n_filters:
(Needs some debugging and checking)
'''
......@@ -64,17 +73,13 @@ class Model(Object):
cross_channel_kernel = n_channels
with tf.variable_scope('conv_layer%s' %name):
# initializer = tf.truncated_normal_initializer(
initializer = tf.random_normal_initializer(
# mean=0.5, stddev=0.1)
mean=0.0, stddev=0.1)
input = tf.layers.conv2d(
inputs=input,
filters=n_filters,
kernel_size=(cross_channel_kernel, kernel_width), # use identity (1) along channels
activation=tf.nn.relu,
bias_initializer=initializer,
bias_initializer=self.initializer,
name=name)
input = tf.layers.batch_normalization(input, training=training)
......@@ -203,7 +208,7 @@ def main():
tf_config = None
if args.cpu:
tf_config = tf.ConfigProto(
device_count = {'GPU': 0}
device_count={'GPU': 0}
)
if args.show_data:
......@@ -237,7 +242,7 @@ def main():
store_superdirs=['/data/stores'],
default_store_id='vogtland_001')
data_generator = OnTheFlyData(fn_stations='stations.pf', gf_engine=gf_engine)
data_generator = GFData(fn_stations='stations.pf', gf_engine=gf_engine)
model = Model(tf_config=tf_config, data_generator=data_generator)
model.regularize()
model.dump(filename=fn_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment