Source code for redis_allocator.lock

"""Redis-based locking mechanisms for distributed coordination.

This module provides Redis-based locking classes that enable distributed
coordination and ensure data consistency across distributed processes.
"""
import time
import logging
import threading
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from enum import IntEnum
from collections import defaultdict
from functools import cached_property
from datetime import timedelta, datetime
from typing import Iterable, Optional, Sequence, Tuple, Union, Any
from redis import StrictRedis as Redis

# Module logger, named after this module so callers can tune it via logging config.
logger = logging.getLogger(__name__)
# A lock duration: seconds (float), a timedelta, or None for "no expiry".
Timeout = Optional[Union[float, timedelta]]


class LockStatus(IntEnum):
    """States a lock key can be observed in.

    - FREE: nobody holds the key.
    - UNAVAILABLE: the key is held by another program, or has been marked as
      unavailable for some period of time.
    - LOCKED: the key is held by the current program.
    - ERROR: the key is held permanently (no expiry), which usually signals a
      problem in the program that set it.
    """

    FREE = 0
    UNAVAILABLE = 1
    LOCKED = 2
    ERROR = 4
class BaseLock(ABC):
    """Interface shared by the lock backends in this module.

    Subclasses provide the primitive operations (status, lock, unlock, ...)
    and the single ``_conditional_setdel`` primitive on which all of the
    numeric compare-and-set / compare-and-delete helpers below are built.

    Attributes:
        eps: Tolerance for floating point comparisons.
    """

    eps: float

    def __init__(self, eps: float = 1e-6):
        """Create the lock.

        Args:
            eps: Tolerance for floating point comparisons.
        """
        self.eps = eps

    @abstractmethod
    def key_status(self, key: str, timeout: int = 120) -> LockStatus:
        """Report the status of ``key``.

        Args:
            key: Key to inspect.
            timeout: Lock timeout in seconds the caller expects; a remaining
                TTL larger than this is reported differently by implementations.

        Returns:
            The key's current LockStatus.
        """

    @abstractmethod
    def update(self, key: str, value='1', timeout: Timeout = 120):
        """Unconditionally (re)write ``key`` with ``value`` for ``timeout``.

        No check is made whether the key is already held.

        Args:
            key: Key to write.
            value: Value to store under the key.
            timeout: Lock timeout in seconds.
        """

    @abstractmethod
    def lock(self, key: str, value: str = '1', timeout: Timeout = 120) -> bool:
        """Attempt to take ``key`` for ``timeout``.

        Args:
            key: Key to take.
            value: Value to store under the key.
            timeout: Lock timeout in seconds.

        Returns:
            True when ownership was acquired, False otherwise.
        """

    @abstractmethod
    def is_locked(self, key: str) -> bool:
        """Tell whether ``key`` is currently held.

        Args:
            key: Key to inspect.

        Returns:
            True when the key is held, False otherwise.
        """

    @abstractmethod
    def lock_value(self, key: str) -> Optional[str]:
        """Read the value stored under a held key.

        Args:
            key: Key to read.

        Returns:
            The stored value while the key is held, otherwise None.
        """

    @abstractmethod
    def rlock(self, key: str, value: str = '1', timeout=120) -> bool:
        """Reentrant variant of :meth:`lock`.

        Succeeds when the key is free, and also when it is already held with
        exactly ``value``.

        Args:
            key: Key to take.
            value: Value to store under (or expected to be under) the key.
            timeout: Lock timeout in seconds.

        Returns:
            True when ownership was acquired or already held with ``value``.
        """

    @abstractmethod
    def unlock(self, key: str) -> bool:
        """Release ``key`` without checking who holds it.

        Args:
            key: Key to release.

        Returns:
            True when the key was released, False when it was not locked.
        """

    @abstractmethod
    def _conditional_setdel(self, op: str, key: str, value: float, set_value: Optional[float] = None,
                            ex: Optional[int] = None, isdel: bool = False) -> bool:
        """Set or delete ``key`` depending on a numeric comparison.

        The comparison is ``value <op> current_value``. The implementations in
        this module also perform the operation when the key is missing.

        Args:
            op: One of '>', '<', '>=', '<=', '==', '!='.
            key: Key to set or delete.
            value: Number compared against the stored value.
            set_value: Value written on success; defaults to ``value``.
            ex: Optional expiration time in seconds.
            isdel: Delete the key on success instead of writing it.

        Returns:
            Whether the operation was performed.
        """

    # --- compare-and-set helpers -------------------------------------------

    def setgt(self, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None) -> bool:
        """Write ``key`` when ``value`` is greater than the stored value."""
        return self._conditional_setdel('>', key, value, set_value, ex, False)

    def setlt(self, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None) -> bool:
        """Write ``key`` when ``value`` is less than the stored value."""
        return self._conditional_setdel('<', key, value, set_value, ex, False)

    def setge(self, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None) -> bool:
        """Write ``key`` when ``value`` is greater than or equal to the stored value."""
        return self._conditional_setdel('>=', key, value, set_value, ex, False)

    def setle(self, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None) -> bool:
        """Write ``key`` when ``value`` is less than or equal to the stored value."""
        return self._conditional_setdel('<=', key, value, set_value, ex, False)

    def seteq(self, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None) -> bool:
        """Write ``key`` when ``value`` equals the stored value."""
        return self._conditional_setdel('==', key, value, set_value, ex, False)

    def setne(self, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None) -> bool:
        """Write ``key`` when ``value`` differs from the stored value."""
        return self._conditional_setdel('!=', key, value, set_value, ex, False)

    # --- compare-and-delete helpers ----------------------------------------

    def delgt(self, key: str, value: float):
        """Delete ``key`` when ``value`` is greater than the stored value."""
        return self._conditional_setdel('>', key, value, None, None, True)

    def dellt(self, key: str, value: float):
        """Delete ``key`` when ``value`` is less than the stored value."""
        return self._conditional_setdel('<', key, value, None, None, True)

    def delge(self, key: str, value: float):
        """Delete ``key`` when ``value`` is greater than or equal to the stored value."""
        return self._conditional_setdel('>=', key, value, None, None, True)

    def delle(self, key: str, value: float):
        """Delete ``key`` when ``value`` is less than or equal to the stored value."""
        return self._conditional_setdel('<=', key, value, None, None, True)

    def deleq(self, key: str, value: float):
        """Delete ``key`` when ``value`` equals the stored value."""
        return self._conditional_setdel('==', key, value, None, None, True)

    def delne(self, key: str, value: float):
        """Delete ``key`` when ``value`` differs from the stored value."""
        return self._conditional_setdel('!=', key, value, None, None, True)

    def _to_seconds(self, timeout: Timeout) -> float:
        """Normalize a timeout to a number of seconds.

        A timedelta becomes its total seconds and numbers pass through
        unchanged. None maps to the (naive, local-time) epoch timestamp of
        2099-01-01, a number of seconds so large that it acts as "never expires".
        """
        if isinstance(timeout, timedelta):
            return timeout.total_seconds()
        if timeout is None:
            return datetime(2099, 1, 1).timestamp()
        return timeout
class BaseLockPool(BaseLock, metaclass=ABCMeta):
    """Interface for a lock backend that also tracks a named group of keys.

    On top of the BaseLock operations, a pool can grow, shrink, be replaced or
    emptied, and report the lock status of every member.

    Attributes:
        eps: Tolerance for floating point comparisons.
    """

    @abstractmethod
    def extend(self, keys: Optional[Sequence[str]] = None):
        """Add ``keys`` to the pool."""

    @abstractmethod
    def shrink(self, keys: Sequence[str]):
        """Remove ``keys`` from the pool."""

    @abstractmethod
    def assign(self, keys: Optional[Sequence[str]] = None):
        """Replace the pool's members with ``keys``."""

    @abstractmethod
    def clear(self):
        """Remove every member from the pool."""

    @abstractmethod
    def keys(self) -> Iterable[str]:
        """Return the pool's members."""

    @abstractmethod
    def _get_key_lock_status(self, keys: Iterable[str]) -> Iterable[bool]:
        """Return, for each of ``keys`` in order, whether it is locked."""

    def values_lock_status(self) -> Iterable[bool]:
        """Lock status of every pool member, in ``keys()`` order."""
        members = self.keys()
        return self._get_key_lock_status(members)

    def items_locked_status(self) -> Iterable[Tuple[str, bool]]:
        """Pairs of (member, is_locked) for every pool member."""
        # Materialize once so the same key order feeds both sides of the zip.
        members = list(self.keys())
        statuses = self._get_key_lock_status(members)
        return zip(members, statuses)

    def health_check(self) -> Tuple[int, int]:
        """Count locked and free members.

        Returns:
            A tuple of (locked_count, free_count).
        """
        statuses = list(self.values_lock_status())
        locked_count = sum(map(bool, statuses))
        return locked_count, len(statuses) - locked_count

    def __len__(self):
        """Number of members in the pool."""
        return sum(1 for _ in self.keys())

    def __iter__(self):
        """Iterate over the pool's members."""
        members = self.keys()
        return iter(members)
[docs] class RedisLock(BaseLock): """Redis-based lock implementation. Uses standard Redis commands (SET with NX, EX options) for basic locking and Lua scripts for conditional operations (set/del based on value comparison). Attributes: redis: StrictRedis client instance (must decode responses). prefix: Prefix for all Redis keys managed by this lock instance. suffix: Suffix for Redis keys to distinguish lock types (e.g., 'lock'). eps: Epsilon for float comparisons in conditional Lua scripts. """ redis: Redis prefix: str suffix: str
[docs] def __init__(self, redis: Redis, prefix: str, suffix="lock", eps: float = 1e-6): """Initialize a RedisLock instance. Args: redis: Redis client instance. prefix: Prefix for Redis keys. suffix: Suffix for Redis keys. eps: Epsilon value for floating point comparison. """ assert "'" not in prefix and "'" not in suffix, "Prefix and suffix cannot contain single quotes" assert redis.get_encoder().decode_responses, "Redis must be configured to decode responses" super().__init__(eps=eps) self.redis = redis self.prefix = prefix self.suffix = suffix
@property def _lua_required_string(self): """Base Lua script providing the key_str function. - key_str(key: str): Constructs the full Redis key using prefix and suffix. """ return f''' local function key_str(key) return '{self.prefix}|{self.suffix}:' .. key end ''' def _key_str(self, key: str): return f'{self.prefix}|{self.suffix}:{key}'
[docs] def key_status(self, key: str, timeout: int = 120) -> LockStatus: ttl = self.redis.ttl(self._key_str(key)) if ttl > timeout: # If TTL is greater than the required expiration time, it means the usage is incorrect return LockStatus.UNAVAILABLE elif ttl >= 0: return LockStatus.LOCKED elif ttl == -1: return LockStatus.ERROR # Permanent lock return LockStatus.FREE
[docs] def update(self, key: str, value='1', timeout: Timeout = 120): self.redis.set(self._key_str(key), value, ex=timeout)
[docs] def lock(self, key: str, value: str = '1', timeout: Timeout = 120) -> bool: key_str = self._key_str(key) return self.redis.set(key_str, value, ex=timeout, nx=True)
[docs] def is_locked(self, key: str) -> bool: return self.redis.exists(self._key_str(key))
[docs] def lock_value(self, key: str) -> Optional[str]: return self.redis.get(self._key_str(key))
[docs] def rlock(self, key: str, value: str = '1', timeout=120) -> bool: key_str = self._key_str(key) old_value = self.redis.set(key_str, value, ex=timeout, nx=True, get=True) return old_value is None or old_value == value
[docs] def unlock(self, key: str) -> bool: return bool(self.redis.delete(self._key_str(key)))
def _conditional_setdel(self, op: str, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None, isdel: bool = False) -> bool: """Executes the conditional set/delete Lua script. Passes necessary arguments (key, compare_value, set_value, expiry, isdel flag) to the cached Lua script corresponding to the comparison operator (`op`). """ # Convert None to a valid value for Redis (using -1 to indicate no expiration) key_value = self._key_str(key) ex_value = -1 if ex is None else ex isdel_value = '1' if isdel else '0' if set_value is None: set_value = value return self._conditional_setdel_script[op](keys=[key_value], args=[value, set_value, ex_value, isdel_value]) @cached_property def _conditional_setdel_script(self): return {op: self.redis.register_script(self._conditional_setdel_lua_script(op, self.eps)) for op in ('>', '<', '>=', '<=', '==', '!=')} def _conditional_setdel_lua_script(self, op: str, eps: float = 1e-6) -> str: """Generates the Lua script for conditional set/delete operations. Args: op: The comparison operator ('>', '<', '>=', '<=', '==', '!='). eps: Epsilon for floating-point comparisons ('==', '!=', '>=', '<='). Returns: A Lua script string that: 1. Gets the current numeric value of the target key (KEYS[1]). 2. Compares it with the provided compare_value (ARGV[1]) using the specified `op`. 3. If the key doesn't exist or the condition is true: - If `isdel` (ARGV[4]) is true, deletes the key. - Otherwise, sets the key to `new_value` (ARGV[2]) with optional expiry `ex` (ARGV[3]). 4. Returns true if the operation was performed, false otherwise. 
""" match op: case '>': condition = 'compare_value > current_value' case '<': condition = 'compare_value < current_value' case '>=': condition = f'compare_value >= current_value - {eps}' case '<=': condition = f'compare_value <= current_value + {eps}' case '==': condition = f'abs(compare_value - current_value) < {eps}' case '!=': condition = f'abs(compare_value - current_value) > {eps}' case _: raise ValueError(f"Invalid operator: {op}") return f''' {self._lua_required_string} local abs = math.abs local current_key = KEYS[1] local current_value = tonumber(redis.call('GET', current_key)) local compare_value = tonumber(ARGV[1]) local new_value = tonumber(ARGV[2]) local ex = tonumber(ARGV[3]) local isdel = ARGV[4] ~= '0' if current_value == nil or {condition} then if isdel then redis.call('DEL', current_key) else if ex ~= nil and ex > 0 then redis.call('SET', current_key, new_value, 'EX', ex) else redis.call('SET', current_key, new_value) end end return true end return false ''' def __eq__(self, value: Any) -> bool: if isinstance(value, RedisLock): return self.prefix == value.prefix and self.suffix == value.suffix return False def __hash__(self) -> int: return hash((self.prefix, self.suffix))
class RedisLockPool(RedisLock, BaseLockPool):
    """Manages a collection of RedisLock keys as a logical pool.

    Uses a Redis Set (`<prefix>|<suffix>|pool`) to store the identifiers (keys)
    belonging to the pool. Inherits locking logic from RedisLock. Provides
    methods to add (`extend`), remove (`shrink`), replace (`assign`), and query
    (`keys`, `__contains__`) the members of the pool. Also offers methods to
    check the lock status of pool members (`_get_key_lock_status`).
    """

    # Max members passed to a single SADD inside the assign script. Lua's
    # unpack() fails ("too many results to unpack") for very long argument
    # lists, so large assignments are split into batches of this size.
    _ASSIGN_BATCH_SIZE = 1000

    def __init__(self, redis: Redis, prefix: str, suffix='lock-pool', eps: float = 1e-6):
        """Initialize a RedisLockPool instance.

        Args:
            redis: Redis client instance.
            prefix: Prefix for Redis keys.
            suffix: Suffix for Redis keys.
            eps: Epsilon value for floating point comparison.
        """
        super().__init__(redis, prefix, suffix=suffix, eps=eps)
        assert redis.get_encoder().decode_responses, "Redis must be configured to decode responses"

    @property
    def _lua_required_string(self):
        """Base Lua script providing key_str and pool_str functions.

        - key_str(key: str): Inherited from RedisLock.
        - pool_str(): Returns the Redis key for the pool Set.
        """
        return f'''
        {super()._lua_required_string}
        local function pool_str()
            return '{self._pool_str()}'
        end
        '''

    def _pool_str(self):
        """Returns the Redis key for the pool."""
        return f'{self.prefix}|{self.suffix}|pool'

    def extend(self, keys: Optional[Sequence[str]] = None):
        """Extend the pool with the specified keys."""
        if keys is not None and len(keys) > 0:
            self.redis.sadd(self._pool_str(), *keys)

    def shrink(self, keys: Sequence[str]):
        """Shrink the pool by removing the specified keys."""
        if keys is not None and len(keys) > 0:
            self.redis.srem(self._pool_str(), *keys)

    @property
    def _assign_lua_string(self):
        """Lua script to atomically replace the contents of the pool Set.

        1. Deletes the existing pool Set key (KEYS[1]).
        2. Adds all provided keys (ARGV) to the (now empty) pool Set, issuing
           SADD in batches of at most ``_ASSIGN_BATCH_SIZE`` members.
        """
        return f'''
        {self._lua_required_string}
        local _pool_str = KEYS[1]
        redis.call('DEL', _pool_str)
        local batch_size = {self._ASSIGN_BATCH_SIZE}
        for i = 1, #ARGV, batch_size do
            redis.call('SADD', _pool_str, unpack(ARGV, i, math.min(i + batch_size - 1, #ARGV)))
        end
        '''

    @cached_property
    def _assign_lua_script(self):
        """Registered assign script, created on first use."""
        return self.redis.register_script(self._assign_lua_string)

    def assign(self, keys: Optional[Sequence[str]] = None):
        """Assign keys to the pool, replacing any existing keys."""
        if keys is not None and len(keys) > 0:
            self._assign_lua_script(args=keys, keys=[self._pool_str()])
        else:
            self.clear()

    def clear(self):
        """Empty the pool."""
        self.redis.delete(self._pool_str())

    def keys(self) -> Iterable[str]:
        """Get the keys in the pool."""
        return self.redis.smembers(self._pool_str())

    def __contains__(self, key):
        """Check if a key is in the pool."""
        return self.redis.sismember(self._pool_str(), key)

    def _get_key_lock_status(self, keys: Iterable[str]) -> Iterable[bool]:
        """Get the lock status of the specified keys (True where the lock key exists)."""
        redis_keys = [self._key_str(key) for key in keys]
        # MGET with zero keys is rejected by Redis ("wrong number of arguments"),
        # which would make health_check() fail on an empty pool.
        if not redis_keys:
            return []
        return [value is not None for value in self.redis.mget(redis_keys)]
@dataclass
class LockData:
    """One in-memory lock entry: the stored value and when it stops being valid."""

    value: str  # value held under the key
    expiry: float  # expiration timestamp, compared against time.time() by ThreadLock
[docs] class ThreadLock(BaseLock): """In-memory, thread-safe lock implementation conforming to BaseLock. Simulates Redis lock behavior using Python's `threading.RLock` for concurrency control and a `defaultdict` to store lock data (value and expiry timestamp). Suitable for single-process scenarios or testing. Attributes: eps: Epsilon for float comparisons. _locks: defaultdict storing LockData(value, expiry) for each key. _lock: threading.RLock protecting access to _locks. """
[docs] def __init__(self, eps: float = 1e-6): """Initialize a ThreadLock instance. Args: eps: Epsilon value for floating point comparison. """ super().__init__(eps=eps) self._locks = defaultdict(lambda: LockData(value='1', expiry=0)) self._lock = threading.RLock() # Thread lock to protect access to _locks
def _is_expired(self, key: str) -> bool: """Check if a lock has expired.""" return self._get_ttl(key) <= 0 def _get_ttl(self, key: str): """Get the TTL of a lock in seconds.""" return self._locks[key].expiry - time.time()
[docs] def key_status(self, key: str, timeout: int = 120) -> LockStatus: """Get the status of a key.""" ttl = self._get_ttl(key) if ttl <= 0: return LockStatus.FREE elif ttl > timeout: return LockStatus.UNAVAILABLE return LockStatus.LOCKED
[docs] def update(self, key: str, value='1', timeout: Timeout = 120): """Lock a key for a specified duration without checking if already locked.""" expiry = time.time() + self._to_seconds(timeout) self._locks[key] = LockData(value=value, expiry=expiry)
[docs] def lock(self, key: str, value: str = '1', timeout: Timeout = 120) -> bool: """Try to lock a key for a specified duration.""" with self._lock: if not self._is_expired(key): return False expiry = time.time() + self._to_seconds(timeout) self._locks[key] = LockData(value=value, expiry=expiry) return True
[docs] def is_locked(self, key: str) -> bool: """Check if a key is locked.""" return not self._is_expired(key)
[docs] def lock_value(self, key: str) -> Optional[str]: """Get the value of a locked key.""" data = self._locks[key] if data.expiry <= time.time(): return None return str(data.value)
[docs] def rlock(self, key: str, value: str = '1', timeout=120) -> bool: """Try to relock a key for a specified duration.""" with self._lock: data = self._locks[key] if data.expiry > time.time() and data.value != value: return False expiry = time.time() + self._to_seconds(timeout) self._locks[key] = LockData(value=value, expiry=expiry) return True
[docs] def unlock(self, key: str) -> bool: """Forcefully release a key.""" return self._locks.pop(key, None) is not None
def _compare_values(self, op: str, compare_value: float, current_value: float) -> bool: """Compare two values using the specified operator.""" match op: case '>': return compare_value > current_value case '<': return compare_value < current_value case '>=': return compare_value >= current_value - self.eps case '<=': return compare_value <= current_value + self.eps case '==': return abs(compare_value - current_value) < self.eps case '!=': return abs(compare_value - current_value) > self.eps case _: raise ValueError(f"Invalid operator: {op}") def _conditional_setdel(self, op: str, key: str, value: float, set_value: Optional[float] = None, ex: Optional[int] = None, isdel: bool = False) -> bool: """Conditionally set or delete a key's value based on comparison.""" compare_value = float(value) if set_value is None: set_value = value with self._lock: # Get current value if key exists and is not expired current_data = self._locks[key] if current_data.expiry <= time.time(): current_value = None else: current_value = float(current_data.value) # Condition check if current_value is None or self._compare_values(op, compare_value, current_value): if isdel: # Delete the key self._locks.pop(key, None) else: # Set the key expiry = time.time() + self._to_seconds(ex) self._locks[key] = LockData(value=set_value, expiry=expiry) return True return False
class ThreadLockPool(ThreadLock, BaseLockPool):
    """In-memory, thread-safe lock pool implementation.

    Manages a collection of lock keys using a Python `set` for the pool members
    and inherits the locking logic from `ThreadLock`.

    Attributes:
        _pool: Set containing the keys belonging to this pool.
        _lock: threading.RLock protecting access to _locks and _pool.
    """

    def __init__(self, eps: float = 1e-6):
        """Initialize a ThreadLockPool instance.

        ``ThreadLock.__init__`` already creates ``_locks`` and ``_lock``; only
        the pool member set is added here.
        """
        super().__init__(eps=eps)
        self._pool = set()

    def extend(self, keys: Optional[Sequence[str]] = None):
        """Extend the pool with the specified keys."""
        with self._lock:
            if keys is not None:
                self._pool.update(keys)

    def shrink(self, keys: Sequence[str]):
        """Shrink the pool by removing the specified keys (None is a no-op)."""
        with self._lock:
            if keys is not None:
                self._pool.difference_update(keys)

    def assign(self, keys: Optional[Sequence[str]] = None):
        """Assign keys to the pool, replacing any existing keys."""
        # _lock is reentrant, so clear()/extend() may re-acquire it.
        with self._lock:
            self.clear()
            self.extend(keys=keys)

    def clear(self):
        """Empty the pool."""
        with self._lock:
            self._pool.clear()

    def keys(self) -> Iterable[str]:
        """Get a snapshot of the keys in the pool.

        A copy taken under ``_lock`` is returned, so callers can iterate it while
        other threads extend/shrink the pool and cannot mutate the pool through it.
        """
        with self._lock:
            return set(self._pool)

    def __contains__(self, key):
        """Check if a key is in the pool."""
        return key in self._pool

    def _get_key_lock_status(self, keys: Iterable[str]) -> Iterable[bool]:
        """Get the lock status of the specified keys."""
        return [self.is_locked(key) for key in keys]