locks.py 12.3 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-
__author__ = 'Daniel Scheffler'

import time
5
from redis import StrictRedis
6
from redis_semaphore import Semaphore
Daniel Scheffler's avatar
Daniel Scheffler committed
7
from redis.exceptions import ConnectionError as RedisConnectionError
8
import functools
9
from psutil import virtual_memory
10
11
12
13

from ..misc.logging import GMS_logger
from ..options.config import GMS_config as CFG

14
# Probe the local Redis server once at import time. The lock classes in this module consult
# `redis_conn` and disable themselves when it is None.
try:
    redis_conn = StrictRedis(host='localhost', db=0)
    # PING is a constant-time round trip that raises ConnectionError if the redis server is not installed or
    # not running. (KEYS would enumerate the complete keyspace just to test the connection.)
    redis_conn.ping()
except RedisConnectionError:
    redis_conn = None


21
22
class MultiSlotLock(Semaphore):
    """Named lock with a fixed number of slots, built on the Redis ``Semaphore`` class.

    The lock is a no-op (``disabled``) if no Redis connection is available or if ``allowed_slots`` is
    None/False. All public methods silently do nothing in that case.
    """

    def __init__(self, name='MultiSlotLock', allowed_slots=1, logger=None, **kwargs):
        """Set up the lock.

        :param name:           Name of the lock; also used as the semaphore namespace.
        :param allowed_slots:  Number of slots that may be held at the same time (None/False disables the lock).
        :param logger:         Logger instance; a GMS_logger is created if not given.
        :param kwargs:         Further keyword arguments passed to Semaphore.__init__().
        """
        self.disabled = redis_conn is None or allowed_slots in (None, False)
        self.namespace = name
        self.allowed_slots = allowed_slots
        self.logger = logger or GMS_logger("RedisLock: '%s'" % name)

        if self.disabled:
            # no Redis-backed semaphore is created for a disabled lock
            return

        super(MultiSlotLock, self).__init__(client=redis_conn, count=allowed_slots, namespace=name, **kwargs)

    def _slot_suffix(self, token):
        """Return the tail of a log message: '.' for single-slot locks, else the 1-based slot number."""
        if self.allowed_slots == 1:
            return '.'
        return ', slot #%s.' % (int(token) + 1)

    def acquire(self, timeout=0, target=None):
        """Acquire one slot and log it.

        :param timeout:  Passed through to Semaphore.acquire().
        :param target:   Passed through to Semaphore.acquire().
        :return:         The token returned by Semaphore.acquire(), or None if the lock is disabled.
        """
        if self.disabled:
            return None

        if self.available_count == 0:
            self.logger.info("Waiting for free lock '%s'." % self.namespace)

        token = super(MultiSlotLock, self).acquire(timeout=timeout, target=target)
        self.logger.info("Acquired lock '%s'" % self.namespace + self._slot_suffix(token))

        return token

    def release(self):
        """Release a slot via Semaphore.release() and log it if a token was actually released."""
        if self.disabled:
            return

        token = super(MultiSlotLock, self).release()
        if token:
            self.logger.info("Released lock '%s'" % self.namespace + self._slot_suffix(token))

    def delete(self):
        """Delete the Redis keys of this lock (existence flag, available list, grabbed hash)."""
        if self.disabled:
            return

        for redis_key in (self.check_exists_key, self.available_key, self.grabbed_key):
            self.client.delete(redis_key)
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class SharedResourceLock(MultiSlotLock):
    """MultiSlotLock that additionally records the slots grabbed by the current GMS job.

    Every acquired token is stored in a job-specific Redis hash (see grabbed_key_jobID), so that all slots
    still held by the job can be given back at once via release_all_jobID_tokens().
    """

    def acquire(self, timeout=0, target=None):
        """Acquire one slot and register its token for the current job.

        :param timeout:  Passed through to MultiSlotLock.acquire().
        :param target:   Passed through to MultiSlotLock.acquire().
        :return:         The acquired token (as returned by MultiSlotLock.acquire()), or None if disabled.
        """
        if not self.disabled:
            token = super(SharedResourceLock, self).acquire(timeout=timeout, target=target)
            self.client.hset(self.grabbed_key_jobID, token, self.current_time)

            # hand the token back to the caller like the parent class does
            return token

    def release_all_jobID_tokens(self):
        """Give back every slot registered for the current job and delete the job-specific hash."""
        if not self.disabled:
            for token in self.client.hkeys(self.grabbed_key_jobID):
                self.signal(token)

            self.client.delete(self.grabbed_key_jobID)

    @property
    def grabbed_key_jobID(self):
        """Namespaced Redis key of the hash holding the tokens grabbed by the job with the ID CFG.ID."""
        return self._get_and_set_key('_grabbed_key_jobID', 'GRABBED_BY_GMSJOB_%s' % CFG.ID)

    def signal(self, token):
        """Return the given token to the list of available slots.

        :param token:  Token to be given back (None is ignored).
        :return:       The token, or None if no token was given.
        """
        if token is None:
            return None
        with self.client.pipeline() as pipe:
            pipe.multi()
            pipe.hdel(self.grabbed_key, token)
            pipe.hdel(self.grabbed_key_jobID, token)  # only difference to Semaphore.signal()
            pipe.lpush(self.available_key, token)
            pipe.execute()
            return token

    def delete(self):
        """Delete the Redis keys of the lock including the job-specific hash."""
        if not self.disabled:
            super(SharedResourceLock, self).delete()
            self.client.delete(self.grabbed_key_jobID)


class IOLock(SharedResourceLock):
    """SharedResourceLock named 'IOLock' that can be switched off via CFG.disable_IO_locks."""

    def __init__(self, allowed_slots=1, logger=None, **kwargs):
        """Set up the lock unless IO locks are disabled in the configuration.

        :param allowed_slots:  Number of slots, passed to SharedResourceLock.
        :param logger:         Logger instance, passed to SharedResourceLock.
        :param kwargs:         Further keyword arguments passed to SharedResourceLock.
        """
        self.disabled = CFG.disable_IO_locks
        if self.disabled:
            # switched off by configuration: the parent class is not initialised at all
            return

        init_parent = super(IOLock, self).__init__
        init_parent(name='IOLock', allowed_slots=allowed_slots, logger=logger, **kwargs)
97
98


99
class ProcessLock(SharedResourceLock):
    """SharedResourceLock named 'ProcessLock' that can be switched off via CFG.disable_CPU_locks."""

    def __init__(self, allowed_slots=1, logger=None, **kwargs):
        """Set up the lock unless CPU locks are disabled in the configuration.

        :param allowed_slots:  Number of slots, passed to SharedResourceLock.
        :param logger:         Logger instance, passed to SharedResourceLock.
        :param kwargs:         Further keyword arguments passed to SharedResourceLock.
        """
        self.disabled = CFG.disable_CPU_locks
        if self.disabled:
            # switched off by configuration: the parent class is not initialised at all
            return

        init_parent = super(ProcessLock, self).__init__
        init_parent(name='ProcessLock', allowed_slots=allowed_slots, logger=logger, **kwargs)
105
106


107
108
class MemoryReserver(Semaphore):
    """Reserve system memory in whole gigabytes using a Redis semaphore.

    The semaphore is created with one slot per gigabyte of the memory budget (max_usage percent of the total
    system memory). Reserving mem2lock_gb gigabytes grabs that many slots and adds the amount to a global
    and to a job-specific counter stored in Redis.

    The reserver is a no-op (``disabled``) if no Redis connection is available, if memory locks are disabled
    in the configuration or if mem2lock_gb is None/False.
    """

    def __init__(self, mem2lock_gb, max_usage=90, logger=None, **kwargs):
        """Set up the memory reserver.

        :param mem2lock_gb:  Amount of memory to be reserved while the lock is acquired (gigabytes).
        :param max_usage:    Percentage of the total system memory that may be used at most.
        :param logger:       Logger instance; a GMS_logger is created if not given.
        :param kwargs:       Further keyword arguments passed to Semaphore.__init__().
        """
        self.disabled = redis_conn is None or CFG.disable_memory_locks or mem2lock_gb in [None, False]
        self.mem2lock_gb = mem2lock_gb
        self.max_usage = max_usage
        self._waiting = False  # True while the 'waiting for memory' message has already been logged

        if not self.disabled:
            mem_limit = int(virtual_memory().total * max_usage / 100 / 1024**3)
            super(MemoryReserver, self).__init__(client=redis_conn, count=mem_limit, namespace='MemoryReserver',
                                                 **kwargs)
            self.logger = logger or GMS_logger("RedisLock: 'MemoryReserver'")

    @property
    def mem_reserved_gb(self):
        """Return the memory currently reserved by all jobs (gigabytes); 0 if the reserver is disabled."""
        if self.disabled:
            return 0

        # read the same counter that acquire()/release() increment and decrement
        return int(self.client.get(self.reserved_key) or 0)

    @property
    def usable_memory_gb(self):
        """Return the memory budget minus the used and the reserved memory (gigabytes)."""
        vmem = virtual_memory()  # single snapshot so that 'total' and 'used' belong together
        return int((vmem.total * self.max_usage / 100 - vmem.used) / 1024**3) - self.mem_reserved_gb

    @property
    def grabbed_key_jobID(self):
        """Namespaced Redis key of the hash holding the tokens grabbed by the job with the ID CFG.ID."""
        return self._get_and_set_key('_grabbed_key_jobID', 'GRABBED_BY_GMSJOB_%s' % CFG.ID)

    @property
    def reserved_key(self):
        """Namespaced Redis key of the counter of reserved gigabytes (all jobs)."""
        return self._get_and_set_key('_reserved_key', 'MEM_RESERVED')

    @property
    def reserved_key_jobID(self):
        """Namespaced Redis key of the counter of gigabytes reserved by the job with the ID CFG.ID."""
        return self._get_and_set_key('_reserved_key_jobID', 'MEM_RESERVED_BY_GMSJOB_%s' % CFG.ID)

    @property
    def acquisition_key(self):
        """Namespaced Redis key that serializes the 'check usable memory, then reserve' step."""
        return self._get_and_set_key('_acquisition_key', 'ACQUISITION_LOCK')

    def acquire_old(self, timeout=0, target=None):
        """Legacy implementation of acquire() that serializes via MemoryReserverAcquisitionLock.

        :param timeout:  Passed through to Semaphore.acquire() and acquire().
        :param target:   Not used.
        """
        if not self.disabled:
            with MemoryReserverAcquisitionLock():
                if self.usable_memory_gb >= self.mem2lock_gb:
                    for i in range(self.mem2lock_gb):
                        token = super(MemoryReserver, self).acquire(timeout=timeout)
                        self.client.hset(self.grabbed_key_jobID, token, self.current_time)

                    self.client.incr(self.reserved_key, self.mem2lock_gb)
                    self.client.incr(self.reserved_key_jobID, self.mem2lock_gb)

                    self.logger.info('Reserved %s GB of memory.' % self.mem2lock_gb)
                    self._waiting = False

                else:
                    if not self._waiting:
                        self.logger.info('Currently usable memory: %s GB. Waiting until at least %s GB are usable.'
                                         % (self.usable_memory_gb, self.mem2lock_gb))
                        self._waiting = True

                    time.sleep(1)
                    self.acquire(timeout=timeout)

    def _reserve(self, timeout):
        """Grab one semaphore slot per gigabyte and update the global and the job-specific counter."""
        for _ in range(self.mem2lock_gb):
            token = super(MemoryReserver, self).acquire(timeout=timeout)
            self.client.hset(self.grabbed_key_jobID, token, self.current_time)

        self.client.incr(self.reserved_key, self.mem2lock_gb)
        self.client.incr(self.reserved_key_jobID, self.mem2lock_gb)

        self.logger.info('Reserved %s GB of memory.' % self.mem2lock_gb)
        self._waiting = False

    def acquire(self, timeout=0, target=None):
        """Block until mem2lock_gb gigabytes are usable and reserve them.

        The check of the usable memory and the reservation are serialized between processes by a Redis key
        (acquisition_key) that expires after 10 seconds. The method polls once per second and retries in a
        loop; exactly one reservation is made per call.

        :param timeout:  Passed through to Semaphore.acquire() for each grabbed slot.
        :param target:   Not used (accepted for signature compatibility).
        """
        if self.disabled:
            return

        while True:
            # SET with NX + EX takes the acquisition lock and sets its expiry in one atomic step and never
            # touches a lock that is currently held by another process.
            if not self.client.set(self.acquisition_key, self.exists_val, nx=True, ex=10):
                time.sleep(1)
                continue

            try:
                usable_gb = self.usable_memory_gb
                if usable_gb >= self.mem2lock_gb:
                    self._reserve(timeout)
                    return
            finally:
                # always free the acquisition lock before sleeping or returning
                self.client.delete(self.acquisition_key)

            if not self._waiting:
                self.logger.info('Currently usable memory: %s GB. Waiting until at least %s GB are usable.'
                                 % (usable_gb, self.mem2lock_gb))
                self._waiting = True

            time.sleep(1)

    def signal(self, token):
        """Return the given token to the list of available slots and unregister it from the current job.

        :param token:  Token to be given back (None is ignored).
        :return:       The token, or None if no token was given.
        """
        if token is None:
            return None
        with self.client.pipeline() as pipe:
            pipe.multi()
            pipe.hdel(self.grabbed_key, token)
            # also drop the job-specific entry, otherwise release_all_jobID_tokens() would push the same
            # token to the list of available slots a second time
            pipe.hdel(self.grabbed_key_jobID, token)
            pipe.lpush(self.available_key, token)
            pipe.execute()
            return token

    def release(self):
        """Give back all slots held by this instance and decrease the reservation counters accordingly."""
        if self.disabled:
            return

        # detach the tokens first so that a repeated release() cannot give them back twice
        tokens, self._local_tokens = list(self._local_tokens), []
        if not tokens:
            return

        for token in tokens:
            self.signal(token)

        released_gb = len(tokens)  # one token per reserved gigabyte
        self.client.decr(self.reserved_key, released_gb)
        self.client.decr(self.reserved_key_jobID, released_gb)

        self.logger.info('Released %s GB of reserved memory.' % released_gb)

    def release_all_jobID_tokens(self):
        """Give back everything that is still reserved by the current job."""
        if self.disabled:
            # a disabled reserver has neither a Redis client nor a namespace
            return

        mem_reserved = int(self.client.get(self.reserved_key_jobID) or 0)
        if mem_reserved:
            self.client.decr(self.reserved_key, mem_reserved)

        self.client.delete(self.reserved_key_jobID)

        for token in self.client.hkeys(self.grabbed_key_jobID):
            self.signal(token)

        self.client.delete(self.grabbed_key_jobID)

    def delete(self):
        """Delete the Redis keys of the semaphore; the global counter is only deleted if nothing is reserved."""
        if not self.disabled:
            self.client.delete(self.check_exists_key)
            self.client.delete(self.available_key)
            self.client.delete(self.grabbed_key)

            if self.mem_reserved_gb <= 0:
                self.client.delete(self.reserved_key)
236
237


238
239
240
241
242
243
class MemoryReserverAcquisitionLock(Semaphore):
    """Single-slot semaphore used by MemoryReserver.acquire_old() to serialize memory reservations.

    The lock is a no-op (``disabled``) if no Redis connection is available or if memory locks are disabled in
    the configuration. Acquired tokens are registered in a job-specific Redis hash so that they can be given
    back via release_all_jobID_tokens().
    """

    def __init__(self, **kwargs):
        """Set up the lock.

        :param kwargs:  Further keyword arguments passed to Semaphore.__init__().
        """
        self.disabled = redis_conn is None or CFG.disable_memory_locks

        if not self.disabled:
            # NOTE(review): stale_client_timeout=10 presumably lets the semaphore reclaim slots of stale
            #               holders - confirm against the redis_semaphore documentation
            super(MemoryReserverAcquisitionLock, self).__init__(client=redis_conn, count=1,
                                                                namespace='MemoryReserverAcquisitionLock',
                                                                stale_client_timeout=10, **kwargs)

    @property
    def grabbed_key_jobID(self):
        """Namespaced Redis key of the hash holding the tokens grabbed by the job with the ID CFG.ID."""
        return self._get_and_set_key('_grabbed_key_jobID', 'GRABBED_BY_GMSJOB_%s' % CFG.ID)

    def acquire(self, timeout=0, target=None):
        """Acquire the lock and register its token for the current job.

        :param timeout:  Passed through to Semaphore.acquire().
        :param target:   Passed through to Semaphore.acquire().
        :return:         The acquired token, or None if the lock is disabled.
        """
        if not self.disabled:
            token = super(MemoryReserverAcquisitionLock, self).acquire(timeout=timeout, target=target)
            self.client.hset(self.grabbed_key_jobID, token, self.current_time)

            # hand the token back to the caller like Semaphore.acquire() does
            return token

    def release_all_jobID_tokens(self):
        """Give back every token registered for the current job and delete the job-specific hash."""
        if not self.disabled:
            for token in self.client.hkeys(self.grabbed_key_jobID):
                self.signal(token)

            self.client.delete(self.grabbed_key_jobID)

    def signal(self, token):
        """Return the given token to the list of available slots.

        :param token:  Token to be given back (None is ignored).
        :return:       The token, or None if no token was given.
        """
        if token is None:
            return None
        with self.client.pipeline() as pipe:
            pipe.multi()
            pipe.hdel(self.grabbed_key, token)
            pipe.hdel(self.grabbed_key_jobID, token)  # only difference to Semaphore.signal()
            pipe.lpush(self.available_key, token)
            pipe.execute()
            return token

    def delete(self):
        """Delete the Redis keys of the lock including the job-specific hash."""
        if not self.disabled:
            self.client.delete(self.check_exists_key)
            self.client.delete(self.available_key)
            self.client.delete(self.grabbed_key)
            self.client.delete(self.grabbed_key_jobID)


282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
def acquire_process_lock(**processlock_kwargs):
    """Return a decorator that runs the decorated function while a ProcessLock is held.

    :param processlock_kwargs:  Keyword arguments to be passed to the ProcessLock class.
    """

    def decorator(func):
        # functools.wraps is needed to avoid pickling errors
        @functools.wraps(func)
        def _call_with_lock(*args, **kwargs):
            # a new lock instance is created for every call of the decorated function
            process_lock = ProcessLock(**processlock_kwargs)

            with process_lock:
                result = func(*args, **kwargs)

            return result

        return _call_with_lock

    return decorator


def reserve_mem(**memlock_kwargs):
    """Return a decorator that runs the decorated function while a MemoryReserver is held.

    :param memlock_kwargs:  Keyword arguments to be passed to the MemoryReserver class.
    """

    def decorator(func):
        # functools.wraps is needed to avoid pickling errors
        @functools.wraps(func)
        def _call_with_reservation(*args, **kwargs):
            # a new reserver instance is created for every call of the decorated function
            reserver = MemoryReserver(**memlock_kwargs)

            with reserver:
                result = func(*args, **kwargs)

            return result

        return _call_with_reservation

    return decorator


def release_unclosed_locks():
    """Release all lock slots still held by the current job and clean up unused Redis namespaces.

    Handles IOLock, ProcessLock, MemoryReserver and MemoryReserverAcquisitionLock (in this order). Does
    nothing if no Redis connection is available.
    """
    if not redis_conn:
        return

    lock_specs = (
        (IOLock, (), {'allowed_slots': 1}),
        (ProcessLock, (), {'allowed_slots': 1}),
        (MemoryReserver, (1,), {}),
        (MemoryReserverAcquisitionLock, (), {}),
    )

    for lock_cls, args, kwargs in lock_specs:
        lock = lock_cls(*args, **kwargs)

        if lock.disabled:
            # locks disabled by configuration never set up their Redis client or namespace, so there is
            # nothing to release and accessing lock.client would raise an AttributeError
            continue

        lock.release_all_jobID_tokens()

        # delete the complete redis namespace if no lock slot is acquired anymore
        if lock.client.hlen(lock.grabbed_key) == 0:
            lock.delete()