# -*- coding: utf-8 -*-
__author__ = 'Daniel Scheffler'

import functools
import logging
import random
import re
import time

from redis.exceptions import ConnectionError as RedisConnectionError
from redis_lock import StrictRedis, Lock, NotAcquired, AlreadyAcquired

from ..misc.logging import GMS_logger
from ..options.config import GMS_config as CFG

try:
    # Module-wide redis connection shared by all locks defined in this module.
    # ping() performs a cheap round trip that raises RedisConnectionError, e.g., if the redis server is not installed
    # or not running. (The previously used KEYS command had the same effect but scans the entire database.)
    redis_conn = StrictRedis(host='localhost', db=0)
    redis_conn.ping()
except RedisConnectionError:
    # no reachable redis server -> the locks in this module fall back to doing nothing
    redis_conn = None


class MultiSlotLock(Lock):
    """Redis lock that lets up to ``allowed_slots`` holders share one logical lock ``name``.

    Each holder occupies one numbered slot named ``'<name>, slot #<i>'`` (i = 1..allowed_slots). The redis lock that
    is actually taken is named ``'GMS_<CFG.ID>__<slot name>'``. Slot occupancy is determined via
    :attr:`existing_locks`, which strips the job ID. The slot limit is therefore shared by all GMS jobs that use the
    same redis server.

    If ``allowed_slots`` is falsy or no redis connection is available, the lock is a no-op: :meth:`acquire` returns
    True immediately and :meth:`release` does nothing.
    """

    def __init__(self, name, allowed_slots=1, logger=None, **kwargs):
        """Pick a free slot (waiting if necessary) and initialize the underlying redis lock for it.

        :param name:          logical lock name shared by all slots
        :param allowed_slots: number of slots that may be held at the same time; falsy disables locking
        :param logger:        logger to use; defaults to GMS_logger("RedisLock: '<name>'")
        :param kwargs:        further keyword arguments passed on to redis_lock.Lock.__init__

        NOTE: if locking is enabled, this polls redis (with a random sleep per round) until a free slot is found.
        """
        self.conn = redis_conn
        self.name = name
        self.allowed_threads = allowed_slots or 0
        self.logger = logger or GMS_logger("RedisLock: '%s'" % self.name)
        self.kwargs = kwargs  # kept so that acquire() can re-initialize with identical arguments

        self.allowed_slot_names = ['%s, slot #%s' % (self.name, i) for i in range(1, self.allowed_threads + 1)]
        self.final_name = ''
        self._acquired = None

        if allowed_slots and redis_conn:
            logged = False
            while True:
                # random jitter reduces the chance that several waiting tasks grab the same free slot at once
                time.sleep(random.uniform(0, 1.5))
                name_free_slot = self.get_free_slot_name()

                if name_free_slot:
                    break

                if not logged:
                    self.logger.info("Waiting for free '%s' lock." % self.name)
                    logged = True

            self.final_name = 'GMS_%s__' % CFG.ID + name_free_slot
            super().__init__(self.conn, self.final_name, **kwargs)

    def _reinit_on_free_slot(self):
        """Re-run MultiSlotLock initialization with the stored arguments to move this instance to a free slot."""
        # Call MultiSlotLock.__init__ explicitly: subclasses (e.g. IOLock/ProcessLock) define __init__ without a
        # ``name`` parameter, so ``self.__init__(self.name, allowed_slots=...)`` would raise a TypeError there.
        MultiSlotLock.__init__(self, self.name, allowed_slots=self.allowed_threads, logger=self.logger,
                               **self.kwargs)

    @property
    def existing_locks(self):
        """Slot names of all GMS locks currently present in redis, with the ``'GMS_<job ID>__'`` prefix stripped.

        Only keys of the form ``lock:GMS_<digits>__<slot name>`` are considered. Locks of all GMS jobs are included,
        and keys that do not match the pattern are ignored. The returned list contains no duplicates.
        """
        slot_names = set()

        # server-side filtering: only fetch keys that can belong to a GMS lock
        for key in self.conn.keys('lock:GMS_*'):
            lock_name = key.decode('utf8')[len('lock:'):]
            match = re.search('GMS_[0-9]*__(.*)', lock_name, re.I)
            if match:  # skip foreign/malformed 'GMS_' keys instead of crashing on match.group()
                slot_names.add(match.group(1))

        return list(slot_names)

    def get_free_slot_name(self):
        """Return the first allowed slot name that is not currently locked, or None if all slots are taken."""
        occupied = set(self.existing_locks)  # query redis once instead of once per slot

        for slot_name in self.allowed_slot_names:
            if slot_name not in occupied:
                return slot_name

        return None

    def acquire(self, blocking=True, timeout=None):
        """Acquire the slot chosen in __init__, switching to another free slot if it was taken meanwhile.

        :param blocking: passed to redis_lock.Lock.acquire
        :param timeout:  passed to redis_lock.Lock.acquire
        :return:         True once a slot is held (always True if locking is disabled)
        :raises AlreadyAcquired: if this instance already holds a slot

        NOTE(review): whenever the underlying acquire() returns False (presumably a non-blocking or timed-out
        attempt), this instance re-initializes onto another free slot and retries. ``blocking=False`` and
        ``timeout`` therefore do not make this method return early. Kept as-is; confirm this is intended.
        """
        if not (self.allowed_threads and self.conn):
            # locking disabled or no redis server -> behave as an always-successful no-op lock
            self._acquired = True
            return self._acquired

        if self._acquired:
            raise AlreadyAcquired("Already acquired from this Lock instance.")

        while not self._acquired:
            try:
                self._acquired = super().acquire(blocking=blocking, timeout=timeout)
            except AlreadyAcquired:
                # The slot was taken by another MultiSlotLock instance in the time gap between finding it free (in
                # __init__) and acquiring it here -> re-initialize to get a new free slot.
                self._reinit_on_free_slot()
            else:
                if self._acquired is False:  # and not None
                    self._reinit_on_free_slot()

        self.logger.info("Acquired lock '%s'." % self.final_name.split('GMS_%s__' % CFG.ID)[1])

        return self._acquired

    def release(self):
        """Release the held slot; does nothing if locking is disabled or no redis server is available."""
        if not (self.allowed_threads and self.conn):
            return

        super().release()
        self.logger.info("Released lock '%s'." % self.final_name.split('GMS_%s__' % CFG.ID)[1])


class IOLock(MultiSlotLock):
    """MultiSlotLock with the fixed logical name 'IOLock', so all IOLock instances compete for the same slots."""

    def __init__(self, allowed_slots=1, logger=None, **kwargs):
        """
        :param allowed_slots: number of IOLock slots that may be held at the same time; falsy disables locking
        :param logger:        logger to use (MultiSlotLock provides a default if None)
        :param kwargs:        further keyword arguments forwarded to MultiSlotLock
        """
        super().__init__(name='IOLock', allowed_slots=allowed_slots, logger=logger, **kwargs)


class ProcessLock(MultiSlotLock):
    """MultiSlotLock with the fixed logical name 'ProcessLock', so all ProcessLock instances share the same slots."""

    def __init__(self, allowed_slots=1, logger=None, **kwargs):
        """
        :param allowed_slots: number of ProcessLock slots that may be held at the same time; falsy disables locking
        :param logger:        logger to use (MultiSlotLock provides a default if None)
        :param kwargs:        further keyword arguments forwarded to MultiSlotLock
        """
        super().__init__(name='ProcessLock', allowed_slots=allowed_slots, logger=logger, **kwargs)


def acquire_process_lock(allowed_slots=None, logger=None):
    """Return a decorator that runs the decorated function while holding a ProcessLock slot.

    :param allowed_slots: number of ProcessLock slots that may be held at the same time. None (default) or 0 makes
                          the lock a no-op, so the decorated function runs without any restriction.
    :param logger:        logger passed to the lock; defaults to the stdlib logger 'ProcessLock' set to level INFO
    :return:              decorator
    """
    if not logger:
        logger = logging.getLogger('ProcessLock')
        logger.setLevel('INFO')

    def decorator(func):

        @functools.wraps(func)  # needed to avoid pickling errors
        def wrapped_func(*args, **kwargs):
            # ProcessLock's parameter is ``allowed_slots``. An ``allowed_threads=`` keyword would leave the slot count
            # at its default of 1 and be forwarded via **kwargs to the redis_lock base class.
            with ProcessLock(allowed_slots=allowed_slots, logger=logger):
                return func(*args, **kwargs)

        return wrapped_func

    return decorator


def release_unclosed_locks(logger=None):
    """Release every redis lock whose name carries the current job's prefix ``'lock:GMS_<CFG.ID>__'``.

    Each matching lock is released through a fresh redis_lock.Lock instance. If release() raises NotAcquired, the
    lock is reset instead. Does nothing if no redis connection is available.

    :param logger: logger to report to; defaults to GMS_logger('LockReseter')
    """
    if not redis_conn:
        return

    logger = logger or GMS_logger('LockReseter')

    job_prefix = 'lock:GMS_%s__' % CFG.ID
    locks2release = []
    for raw_key in redis_conn.keys():
        key_str = raw_key.decode('latin')
        if key_str.startswith(job_prefix):
            # lock name as expected by redis_lock.Lock, i.e. without the 'lock:' key prefix
            locks2release.append(key_str.split('lock:')[1])

    if locks2release:
        logger.info("Releasing unclosed locks of job %s." % CFG.ID)

    for lock_name in locks2release:
        stale_lock = Lock(redis_conn, lock_name)
        try:
            stale_lock.release()
        except NotAcquired:
            # presumably raised because this fresh instance is not the lock owner -> force-reset the lock instead
            stale_lock.reset()