locks.py 1.93 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-
__author__ = 'Daniel Scheffler'

import time
5
from redis_lock import StrictRedis, Lock
6
7
8
import logging

# redis-py reports connection failures with its own ConnectionError class, which
# does NOT derive from the builtin ConnectionError, so both must be caught below.
# The redis package is already required here: StrictRedis is provided via redis_lock.
from redis.exceptions import ConnectionError as _RedisConnectionError

# Shared Redis connection used by every MultiSlotLock. It is set to None when the
# server cannot be reached, in which case MultiSlotLock skips locking entirely.
try:
    redis_conn = StrictRedis(host='localhost')
    # Constructing the client does not contact the server; a cheap round-trip is
    # needed to detect e.g. a Redis server that is not installed or not running.
    # ping() is used instead of keys(), which would transfer the whole keyspace.
    redis_conn.ping()
except (ConnectionError, _RedisConnectionError):
    redis_conn = None


15
class MultiSlotLock(Lock):
    """Redis-backed lock that can be held by up to ``allowed_threads`` holders at once.

    With ``allowed_threads > 1`` the lock is spread over the slot names
    ``'<name>, slot #1'`` ... ``'<name>, slot #<allowed_threads>'``, and each
    instance picks a slot whose Redis lock key does not currently exist.
    With ``allowed_threads == 0`` or without a reachable Redis server
    (module-level ``redis_conn`` is None), entering/exiting is a no-op.
    """

    def __init__(self, name, allowed_threads=1, logger=None, **kwargs):
        """Create the lock, choosing a free slot if more than one holder is allowed.

        :param name:            base name of the lock
        :param allowed_threads: maximum number of concurrent holders; 0 disables locking
        :param logger:          logger to use; defaults to a logger named after the lock
        :param kwargs:          forwarded to ``redis_lock.Lock.__init__``
        """
        self.conn = redis_conn
        self.allowed_threads = allowed_threads
        self.allowed_slot_names = ['%s, slot #%s' % (name, i) for i in range(1, allowed_threads + 1)]

        if allowed_threads != 0 and redis_conn:
            if allowed_threads > 1:
                # Poll until at least one slot has no lock key in Redis.
                # NOTE(review): the slot is only chosen here and acquired later in
                # __enter__, so another process may take it in between; this
                # instance then waits on that slot even if others are free.
                while True:
                    name_free_slot = self.get_free_slot_name()
                    if name_free_slot:
                        break
                    time.sleep(0.2)

                name = name_free_slot

            super().__init__(self.conn, name, **kwargs)

        self.name = name
        self.logger = logger or logging.getLogger("RedisLock: '%s'" % name)

    def get_existing_locks(self):
        """Return the names of all locks whose ``lock:<name>`` key exists in Redis.

        Keys without the ``lock:`` prefix are ignored rather than raising
        ``IndexError``. Returns an empty list if there is no Redis connection.
        """
        if not self.conn:
            return []

        prefix = 'lock:'
        lock_names = []
        for key in self.conn.keys(prefix + '*'):
            if isinstance(key, bytes):
                key = key.decode('utf8')
            # Re-check the prefix in case the server-side pattern matched loosely.
            if key.startswith(prefix):
                lock_names.append(key[len(prefix):])
        return lock_names

    def get_free_slot_name(self):
        """Return the first slot name that is not currently locked, or None if all are taken."""
        # Query Redis once rather than once per slot.
        existing = set(self.get_existing_locks())
        for slot_name in self.allowed_slot_names:
            if slot_name not in existing:
                return slot_name
        return None

    def __enter__(self):
        """Acquire the lock (blocking) unless locking is disabled; return the lock itself."""
        if self.allowed_threads != 0 and self.conn:
            super().__enter__()
            self.logger.info("Acquired lock '%s'.", self.name)
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Release the lock unless locking is disabled; exceptions are not suppressed."""
        if self.allowed_threads != 0 and self.conn:
            super().__exit__(exc_type=exc_type, exc_value=exc_value, traceback=traceback)
            self.logger.info("Released lock '%s'.", self.name)