locks.py 5.4 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-
__author__ = 'Daniel Scheffler'

import time
5
from redis_lock import StrictRedis, Lock, NotAcquired, AlreadyAcquired
Daniel Scheffler's avatar
Daniel Scheffler committed
6
from redis.exceptions import ConnectionError as RedisConnectionError
7
import logging
8
import functools
9
import re
10
import random
11
12
13
14

from ..misc.logging import GMS_logger
from ..options.config import GMS_config as CFG

15
16

try:
    redis_conn = StrictRedis(host='localhost', db=0)
    # PING is a constant-time round trip that raises ConnectionError, e.g., if the redis server is not installed
    # or not running; unlike KEYS it does not transfer every key stored on the server just to probe the connection
    redis_conn.ping()
except RedisConnectionError:
    # no usable redis server -> locking is disabled module-wide (consumers check 'if redis_conn')
    redis_conn = None


23
class MultiSlotLock(Lock):
    """Redis lock that admits up to ``allowed_threads`` concurrent holders ("slots").

    Every slot is an individual redis_lock ``Lock`` named ``'GMS_<CFG.ID>__<name>, slot #<i>'``. On construction the
    instance waits until one of its slot names is not present among the currently existing GMS locks (of any job)
    and binds itself to that slot.

    If ``allowed_threads`` is falsy or no redis connection is available, the instance is a no-op: ``acquire``
    returns True and ``release`` does not touch redis.
    """

    def __init__(self, name, allowed_threads=1, logger=None, **kwargs):
        """Pick a free slot (waiting for one if necessary) and initialize the underlying redis lock on it.

        :param name:            base name shared by all slots of this lock
        :param allowed_threads: number of slots; 0/None disables locking
        :param logger:          logger to use; defaults to GMS_logger("RedisLock: '<name>'")
        :param kwargs:          further keyword arguments passed to redis_lock.Lock.__init__
        """
        self.conn = redis_conn
        self.name = name
        self.allowed_threads = allowed_threads or 0
        self.logger = logger or GMS_logger("RedisLock: '%s'" % self.name)
        self.kwargs = kwargs

        self.allowed_slot_names = ['%s, slot #%s' % (self.name, i) for i in range(1, self.allowed_threads + 1)]
        self.final_name = ''
        self._acquired = None

        if allowed_threads and redis_conn:
            logged = False
            while True:
                time.sleep(random.uniform(0, 1.5))  # avoids race conditions in case multiple tasks are waiting
                name_free_slot = self.get_free_slot_name()

                if name_free_slot:
                    break

                if not logged:
                    self.logger.info("Waiting for free '%s' lock." % self.name)
                    logged = True

            self.final_name = 'GMS_%s__' % CFG.ID + name_free_slot
            super().__init__(self.conn, self.final_name, **kwargs)

    @property
    def existing_locks(self):
        """Return the slot names of all GMS locks currently present in redis (job prefix removed, deduplicated).

        Redis keys look like ``b'lock:GMS_<jobid>__<slot name>'``; both the ``'lock:'`` and the
        ``'GMS_<jobid>__'`` prefixes are stripped, so slots held by ANY GMS job are reported.
        """
        key_prefix = 'lock:'
        # let redis filter the keys instead of transferring (and decoding) every key of the database
        names = [k.decode('utf8')[len(key_prefix):] for k in self.conn.keys(key_prefix + 'GMS_*')]

        # split 'GMS_<jobid>' and return; names not matching the pattern are skipped instead of raising
        matches = (re.search(r'GMS_[0-9]*__(.*)', n, re.I) for n in names if n.startswith('GMS_'))
        return list(set(m.group(1) for m in matches if m))

    def get_free_slot_name(self):
        """Return the first of this lock's slot names that is not currently locked, or None if all are taken."""
        existing = set(self.existing_locks)  # query redis once, not once per slot
        return next((sn for sn in self.allowed_slot_names if sn not in existing), None)

    def _reinit(self):
        """Re-run the MultiSlotLock initialization on this instance so that it binds to a currently free slot."""
        # MultiSlotLock.__init__ is called explicitly: self.__init__ would dispatch to a subclass __init__ with a
        # different signature (e.g. one taking 'processes' and forwarding 'allowed_threads' itself), which raises a
        # TypeError for the duplicate 'allowed_threads' keyword.
        MultiSlotLock.__init__(self, self.name, allowed_threads=self.allowed_threads, logger=self.logger,
                               **self.kwargs)

    def acquire(self, blocking=True, timeout=None):
        """Acquire one slot of this lock.

        If acquiring the bound slot raises AlreadyAcquired or returns False, the instance re-selects a free slot and
        retries until it holds one. If locking is disabled (no slots or no redis connection), returns True at once.

        :param blocking: passed to redis_lock.Lock.acquire
        :param timeout:  passed to redis_lock.Lock.acquire
        :return:         truthy value once the lock is held
        :raises AlreadyAcquired: if this instance already holds its slot
        """
        if self.allowed_threads and self.conn:
            if self._acquired:
                raise AlreadyAcquired("Already acquired from this Lock instance.")

            while not self._acquired:
                try:
                    self._acquired = super().acquire(blocking=blocking, timeout=timeout)
                except AlreadyAcquired:
                    # this happens in case the lock has already been acquired by another instance of MultiSlotLock due
                    # to a race condition (time gap between finding the free slot and the call of self.acquire())
                    # -> in that case: re-initialize to get a new free slot
                    self._reinit()

                if self._acquired is False:  # and not None
                    self._reinit()

            # the loop above only terminates once the lock is held
            self.logger.info("Acquired lock '%s'." % self.final_name.split('GMS_%s__' % CFG.ID)[1])
        else:
            self._acquired = True

        return self._acquired

    def release(self):
        """Release the held slot (no redis call if locking is disabled) so this instance can be acquired again."""
        if self.allowed_threads and self.conn:
            super().release()
            self.logger.info("Released lock '%s'." % self.final_name.split('GMS_%s__' % CFG.ID)[1])

        # reset the instance-level flag; otherwise a later acquire() on this instance raises AlreadyAcquired
        self._acquired = None


101
class IOLock(MultiSlotLock):
    """Multi-slot redis lock named 'IOLock', used to limit the number of concurrent I/O-heavy tasks."""

    def __init__(self, processes=1, logger=None, **kwargs):
        """
        :param processes: number of slots, i.e. how many holders may own the lock at the same time
        :param logger:    logger forwarded to MultiSlotLock
        :param kwargs:    further keyword arguments forwarded to MultiSlotLock
        """
        super().__init__(logger=logger, allowed_threads=processes, name='IOLock', **kwargs)


106
107
108
class ProcessLock(MultiSlotLock):
    """Multi-slot redis lock named 'ProcessLock', used to limit the number of concurrently running processes."""

    def __init__(self, processes=1, logger=None, **kwargs):
        """
        :param processes: number of slots, i.e. how many holders may own the lock at the same time
        :param logger:    logger forwarded to MultiSlotLock
        :param kwargs:    further keyword arguments forwarded to MultiSlotLock
        """
        super().__init__(logger=logger, allowed_threads=processes, name='ProcessLock', **kwargs)


def acquire_process_lock(processes=None, logger=None):
    """Return a decorator that runs the decorated function while holding a ProcessLock.

    :param processes: number of ProcessLock slots; None/0 disables locking
    :param logger:    logger passed to ProcessLock; defaults to the stdlib logger 'ProcessLock' set to level INFO
    :return:          decorator
    """
    if not logger:
        logger = logging.getLogger('ProcessLock')
        logger.setLevel('INFO')

    def decorator(func):

        @functools.wraps(func)  # needed to avoid pickling errors
        def wrapped_func(*args, **kwargs):
            # ProcessLock takes the slot count as 'processes'; passing 'allowed_threads' would collide with the
            # 'allowed_threads' keyword ProcessLock itself forwards to MultiSlotLock and raise a TypeError
            with ProcessLock(processes=processes, logger=logger):
                return func(*args, **kwargs)

        return wrapped_func

    return decorator


def release_unclosed_locks(logger=None):
    """Remove all redis locks of the current GMS job (keys starting with 'lock:GMS_<CFG.ID>__') that still exist.

    Nothing happens if no redis connection is available.

    :param logger: logger to use; defaults to GMS_logger('LockReseter')
    """
    if not redis_conn:
        return

    logger = logger or GMS_logger('LockReseter')
    job_key_prefix = 'lock:GMS_%s__' % CFG.ID

    lock_names = []
    for key in redis_conn.keys():
        if key.decode('latin').startswith(job_key_prefix):
            lock_names.append(key.split(b'lock:')[1].decode('latin'))

    if lock_names:
        logger.info("Releasing unclosed locks of job %s." % CFG.ID)

    for lock_name in lock_names:
        leftover = Lock(redis_conn, lock_name)
        try:
            leftover.release()
        except NotAcquired:
            # fallback when release() reports NotAcquired (presumably because this fresh instance is not the owner)
            leftover.reset()