locks.py 8.69 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-
__author__ = 'Daniel Scheffler'

import time
5
from redis import StrictRedis
6
from redis_semaphore import Semaphore
7
from redis_lock import Lock
Daniel Scheffler's avatar
Daniel Scheffler committed
8
from redis.exceptions import ConnectionError as RedisConnectionError
9
import functools
10
from psutil import virtual_memory
11
12
13
14

from ..misc.logging import GMS_logger
from ..options.config import GMS_config as CFG

15
try:
    redis_conn = StrictRedis(host='localhost', db=0)
    # PING is a constant-time connectivity check that raises ConnectionError, e.g., if the redis server is not
    # installed or not running (KEYS, used before, scans the complete database and blocks the server meanwhile).
    redis_conn.ping()
except RedisConnectionError:
    # no redis available -> the lock classes of this module run in disabled mode
    redis_conn = None


class MultiSlotLock(Semaphore):
    """Redis-backed semaphore with a configurable number of slots that logs acquisition and release.

    The lock is disabled (acquire/release/delete do nothing) if no redis connection is available or if
    ``allowed_slots`` is None or False.
    """

    def __init__(self, name='MultiSlotLock', allowed_slots=1, logger=None, **kwargs):
        """
        :param name:            redis namespace of the lock (also used in log messages)
        :param allowed_slots:   number of slots that can be acquired at the same time (None/False disables the lock)
        :param logger:          logger to use (a new GMS_logger is created if not given)
        :param kwargs:          further keyword arguments passed to redis_semaphore.Semaphore
        """
        self.disabled = redis_conn is None or allowed_slots in [None, False]
        self.namespace = name
        self.allowed_slots = allowed_slots
        self.logger = logger or GMS_logger("RedisLock: '%s'" % name)

        if not self.disabled:
            super(MultiSlotLock, self).__init__(client=redis_conn, count=allowed_slots, namespace=name, **kwargs)

    def _slot_suffix(self, token):
        """Return the end of a log message, naming the (1-based) slot of the given token for multi-slot locks."""
        return '.' if self.allowed_slots == 1 else ', slot #%s.' % (int(token) + 1)

    def acquire(self, timeout=0, target=None):
        """Acquire one slot of the lock (see redis_semaphore.Semaphore.acquire).

        :param timeout: passed to Semaphore.acquire
        :param target:  passed to Semaphore.acquire
        :return:        the acquired token, or None if the lock is disabled
        """
        if self.disabled:
            return None

        # Before the namespace is initialized (first acquisition, or after delete()) the list of available tokens
        # does not exist yet and has length 0 although all slots are free -> only report waiting if initialized.
        if self.client.exists(self.check_exists_key) and self.available_count == 0:
            self.logger.info("Waiting for free lock '%s'." % self.namespace)

        token = super(MultiSlotLock, self).acquire(timeout=timeout, target=target)

        self.logger.info("Acquired lock '%s'" % self.namespace + self._slot_suffix(token))

        return token

    def release(self):
        """Release the most recently acquired slot of this instance (if any) and log it."""
        if self.disabled:
            return

        token = super(MultiSlotLock, self).release()  # False if this instance holds no slot
        if token:
            self.logger.info("Released lock '%s'" % self.namespace + self._slot_suffix(token))

    def delete(self):
        """Delete the redis keys of this lock namespace (existence marker, available and grabbed tokens)."""
        if not self.disabled:
            self.client.delete(self.check_exists_key)
            self.client.delete(self.available_key)
            self.client.delete(self.grabbed_key)


class SharedResourceLock(MultiSlotLock):
    """Multi-slot lock that additionally records the acquired tokens per GMS job.

    Acquired tokens are stored in the redis hash ``grabbed_key_jobID`` (named after ``CFG.ID``). This lets
    release_all_jobID_tokens() free slots that a job did not release itself.
    """

    def acquire(self, timeout=0, target=None):
        """Acquire one slot and record its token for the current GMS job.

        :return: the acquired token, or None if the lock is disabled
        """
        if self.disabled:
            return None

        token = super(SharedResourceLock, self).acquire(timeout=timeout, target=target)

        # With a target, Semaphore.acquire() has already signalled the token back (via self.signal) - recording it
        # now would leave a stale entry that release_all_jobID_tokens() would push back a second time.
        if target is None:
            self.client.hset(self.grabbed_key_jobID, token, self.current_time)

        return token

    def release_all_jobID_tokens(self):
        """Release all tokens recorded for the current GMS job and delete the job's token hash."""
        if not self.disabled:
            for token in self.client.hkeys(self.grabbed_key_jobID):
                self.signal(token)

            self.client.delete(self.grabbed_key_jobID)

    @property
    def grabbed_key_jobID(self):
        """Name of the redis hash holding the tokens grabbed by the current GMS job."""
        return self._get_and_set_key('_grabbed_key_jobID', 'GRABBED_BY_GMSJOB_%s' % CFG.ID)

    def signal(self, token):
        """Give the token back to the pool of available tokens (None is ignored).

        Same as Semaphore.signal() except that the token is also removed from the job-specific hash.
        """
        if token is None:
            return None
        with self.client.pipeline() as pipe:
            pipe.multi()
            pipe.hdel(self.grabbed_key, token)
            pipe.hdel(self.grabbed_key_jobID, token)  # only difference to Semaphore.signal()
            pipe.lpush(self.available_key, token)
            pipe.execute()
            return token

    def delete(self):
        """Delete all redis keys of this lock namespace including the job-specific token hash."""
        if not self.disabled:
            super(SharedResourceLock, self).delete()
            self.client.delete(self.grabbed_key_jobID)


class IOLock(SharedResourceLock):
    """Shared resource lock using the fixed redis namespace 'IOLock'."""

    def __init__(self, allowed_slots=1, logger=None, **kwargs):
        """
        :param allowed_slots:   number of slots that can be acquired at the same time
        :param logger:          logger to use (a new one is created by the base class if not given)
        :param kwargs:          further keyword arguments forwarded to the base class
        """
        super(IOLock, self).__init__('IOLock', allowed_slots, logger, **kwargs)


class ProcessLock(SharedResourceLock):
    """Shared resource lock using the fixed redis namespace 'ProcessLock'."""

    def __init__(self, allowed_slots=1, logger=None, **kwargs):
        """
        :param allowed_slots:   number of slots that can be acquired at the same time
        :param logger:          logger to use (a new one is created by the base class if not given)
        :param kwargs:          further keyword arguments forwarded to the base class
        """
        super(ProcessLock, self).__init__('ProcessLock', allowed_slots, logger, **kwargs)


class MemoryReserver(Semaphore):
    """Redis-backed semaphore that reserves a given amount of memory (in GB) while it is acquired.

    Each semaphore token stands for 1 GB. The total number of tokens is ``max_usage`` percent of the total system
    memory (psutil.virtual_memory().total), rounded down to full GB. The amount of currently reserved memory is
    also counted in the redis keys ``reserved_key`` (all jobs) and ``reserved_key_jobID`` (current GMS job).
    usable_memory_gb subtracts it, so memory reserved by others but not yet used is not handed out twice.

    The reserver is disabled (acquire/release/delete do nothing) if no redis connection is available.
    """

    def __init__(self, mem2lock_gb, max_usage=90, logger=None, **kwargs):
        """
        :param mem2lock_gb:     Amount of memory to be reserved while the lock is acquired (gigabytes, integer).
        :param max_usage:       Maximum percentage of the total memory that may be used/reserved.
        :param logger:          logger to use (a new GMS_logger is created if not given)
        :param kwargs:          further keyword arguments passed to redis_semaphore.Semaphore
        """
        self.disabled = redis_conn is None
        self.mem2lock_gb = mem2lock_gb
        self.max_usage = max_usage
        self._waiting = False  # True while acquire() is waiting for memory (avoids repeating the log message)

        if not self.disabled:
            mem_limit = int(virtual_memory().total * max_usage / 100 / 1024**3)
            super(MemoryReserver, self).__init__(client=redis_conn, count=mem_limit, namespace='MemoryReserver',
                                                 **kwargs)
            self.logger = logger or GMS_logger("RedisLock: 'MemoryReserver'")

    @property
    def mem_reserved_gb(self):
        """Memory (GB) currently reserved by all MemoryReserver instances (0 if disabled)."""
        if self.disabled:
            return 0
        # must read the same key that acquire()/release() increment/decrement
        return int(self.client.get(self.reserved_key) or 0)

    @property
    def usable_memory_gb(self):
        """Memory (GB) that may still be reserved.

        This is the max_usage share of the total memory minus the memory in use and the memory already reserved.
        """
        vmem = virtual_memory()  # one snapshot for total and used memory
        return int((vmem.total * self.max_usage / 100 - vmem.used) / 1024**3) - int(self.mem_reserved_gb)

    @property
    def grabbed_key_jobID(self):
        """Name of the redis hash holding the tokens grabbed by the current GMS job."""
        return self._get_and_set_key('_grabbed_key_jobID', 'GRABBED_BY_GMSJOB_%s' % CFG.ID)

    @property
    def reserved_key(self):
        """Name of the redis key counting the memory (GB) reserved by all jobs."""
        return self._get_and_set_key('_reserved_key', 'MEM_RESERVED')

    @property
    def reserved_key_jobID(self):
        """Name of the redis key counting the memory (GB) reserved by the current GMS job."""
        return self._get_and_set_key('_reserved_key_jobID', 'MEM_RESERVED_BY_GMSJOB_%s' % CFG.ID)

    def acquire(self, timeout=0, target=None):
        """Reserve mem2lock_gb GB of memory; block (polling once per second) until that much memory is usable.

        :param timeout: passed to Semaphore.acquire for each 1 GB token
        :param target:  accepted for interface compatibility with Semaphore.acquire, but ignored
        """
        if self.disabled:
            return

        while True:
            # The global redis lock serializes the check-and-reserve step across all processes. It is NOT reentrant,
            # so it must be released before waiting and retrying (retrying inside it would deadlock).
            with Lock(self.client, 'GMS_mem_acquire_lock'):
                usable_gb = self.usable_memory_gb

                if usable_gb >= self.mem2lock_gb:
                    for _ in range(self.mem2lock_gb):
                        token = super(MemoryReserver, self).acquire(timeout=timeout)
                        self.client.hset(self.grabbed_key_jobID, token, self.current_time)

                    self.client.incr(self.reserved_key, self.mem2lock_gb)
                    self.client.incr(self.reserved_key_jobID, self.mem2lock_gb)

                    self.logger.info('Reserved %s GB of memory.' % self.mem2lock_gb)
                    self._waiting = False
                    return

            if not self._waiting:
                self.logger.info('Currently usable memory: %s GB. Waiting until at least %s GB are usable.'
                                 % (usable_gb, self.mem2lock_gb))
                self._waiting = True

            time.sleep(1)

    def release(self):
        """Release all memory tokens held by this instance (no-op if it holds none)."""
        if self.disabled or not self._local_tokens:
            return

        # pop the tokens so that a repeated release() (or a later acquire/release cycle) cannot signal them twice
        while self._local_tokens:
            self.signal(self._local_tokens.pop())

        self.client.decr(self.reserved_key, self.mem2lock_gb)
        self.client.decr(self.reserved_key_jobID, self.mem2lock_gb)

        self.logger.info('Released %s GB of reserved memory.' % self.mem2lock_gb)

    def signal(self, token):
        """Give the token back to the pool of available tokens (None is ignored).

        Same as Semaphore.signal(), except that the token is also removed from the job-specific hash. Otherwise
        release_all_jobID_tokens() would push already released tokens back a second time.
        """
        if token is None:
            return None
        with self.client.pipeline() as pipe:
            pipe.multi()
            pipe.hdel(self.grabbed_key, token)
            pipe.hdel(self.grabbed_key_jobID, token)
            pipe.lpush(self.available_key, token)
            pipe.execute()
        return token

    def release_all_jobID_tokens(self):
        """Release all memory reserved by the current GMS job and delete its job-specific redis keys."""
        if self.disabled:
            return

        mem_reserved = int(self.client.get(self.reserved_key_jobID) or 0)
        if mem_reserved:
            self.client.decr(self.reserved_key, mem_reserved)

        self.client.delete(self.reserved_key_jobID)

        for token in self.client.hkeys(self.grabbed_key_jobID):
            self.signal(token)

        self.client.delete(self.grabbed_key_jobID)

    def delete(self):
        """Delete the redis keys of the MemoryReserver namespace.

        The global reservation counter is only deleted if no memory is reserved anymore.
        """
        if not self.disabled:
            self.client.delete(self.check_exists_key)
            self.client.delete(self.available_key)
            self.client.delete(self.grabbed_key)

            if self.mem_reserved_gb <= 0:
                self.client.delete(self.reserved_key)


def acquire_process_lock(**processlock_kwargs):
    """Build a decorator that executes the decorated function while holding a ProcessLock.

    :param processlock_kwargs:  Keyword arguments to be passed to the ProcessLock class.
    """

    def decorator(func):
        @functools.wraps(func)  # keeps the original function's metadata (needed to avoid pickling errors)
        def locked_func(*args, **kwargs):
            with ProcessLock(**processlock_kwargs):
                return func(*args, **kwargs)

        return locked_func

    return decorator


def reserve_mem(**memlock_kwargs):
    """Build a decorator that executes the decorated function while memory is reserved via MemoryReserver.

    :param memlock_kwargs:  Keyword arguments to be passed to the MemoryReserver class.
    """

    def decorator(func):
        @functools.wraps(func)  # keeps the original function's metadata (needed to avoid pickling errors)
        def reserving_func(*args, **kwargs):
            with MemoryReserver(**memlock_kwargs):
                return func(*args, **kwargs)

        return reserving_func

    return decorator


def release_unclosed_locks():
    """Release all lock slots and memory reservations still recorded for the current GMS job.

    Does nothing without a redis connection. Afterwards, each redis namespace (IOLock, ProcessLock,
    MemoryReserver) is deleted completely if no slot in it is grabbed anymore.
    """
    if not redis_conn:
        return

    for lock_cls in (IOLock, ProcessLock):
        shared_lock = lock_cls(allowed_slots=1)
        shared_lock.release_all_jobID_tokens()

        # drop the whole redis namespace once no slot is acquired anymore
        if shared_lock.client.hlen(shared_lock.grabbed_key) == 0:
            shared_lock.delete()

    mem_reserver = MemoryReserver(1)
    mem_reserver.release_all_jobID_tokens()

    # drop the whole redis namespace once no slot is acquired anymore
    if mem_reserver.client.hlen(mem_reserver.grabbed_key) == 0:
        mem_reserver.delete()