11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class RedisLockManager(LockManager):
    """
    A lock manager that uses Redis as a backend.

    Each key is locked under the Redis name ``lock:<key>`` using redis-py lock
    primitives, with the holder identifier used as the lock token.

    Attributes:
        host: The host of the Redis server
        port: The port the Redis server is running on
        db: The database to write to and read from
        username: The username to use when connecting to the Redis server
        password: The password to use when connecting to the Redis server
        ssl: Whether to use SSL when connecting to the Redis server
        client: The Redis client used to communicate with the Redis server
        async_client: The asynchronous Redis client used to communicate with
            the Redis server

    Example:
        Use with a cache policy:
        ```python
        from prefect import task
        from prefect.cache_policies import TASK_SOURCE, INPUTS
        from prefect.isolation_levels import SERIALIZABLE

        from prefect_redis import RedisLockManager

        cache_policy = (INPUTS + TASK_SOURCE).configure(
            isolation_level=SERIALIZABLE,
            lock_manager=RedisLockManager(host="my-redis-host"),
        )

        @task(cache_policy=cache_policy)
        def my_cached_task(x: int):
            return x + 42
        ```

        Configure with a `RedisDatabase` block:
        ```python
        from prefect_redis import RedisDatabase, RedisLockManager

        block = RedisDatabase(host="my-redis-host")
        lock_manager = RedisLockManager(**block.as_connection_params())
        ```
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        username: Optional[str] = None,
        password: Optional[str] = None,
        ssl: bool = False,
    ) -> None:
        self.host = host
        self.port = port
        self.db = db
        self.username = username
        self.password = password
        self.ssl = ssl
        # The sync and async clients share identical connection settings;
        # ``ssl`` must be forwarded or TLS-only servers are unreachable.
        connection_params = dict(
            host=self.host,
            port=self.port,
            db=self.db,
            username=self.username,
            password=self.password,
            ssl=self.ssl,
        )
        self.client = Redis(**connection_params)
        self.async_client = AsyncRedis(**connection_params)
        # Locks currently held through this manager, keyed by Redis lock name.
        self._locks: "dict[str, Lock | AsyncLock]" = {}

    @staticmethod
    def _lock_name_for_key(key: str) -> str:
        """Return the Redis key name used to lock ``key``."""
        return f"lock:{key}"

    def acquire_lock(
        self,
        key: str,
        holder: str,
        acquire_timeout: Optional[float] = None,
        hold_timeout: Optional[float] = None,
    ) -> bool:
        """
        Acquire the lock for ``key`` on behalf of ``holder``.

        Args:
            key: The key to lock.
            holder: Identifier of the holder; used as the Redis lock token.
            acquire_timeout: Seconds to wait for the lock (redis-py
                ``blocking_timeout``).
            hold_timeout: Lock expiry in seconds (redis-py ``timeout``).

        Returns:
            True if ``holder`` already holds the lock through this manager;
            otherwise the result of the Redis acquire attempt.
        """
        lock_name = self._lock_name_for_key(key)
        # Re-entrant for the same holder: never block on a lock we already own.
        if self.is_lock_holder(key, holder):
            return True
        # thread_local=False keeps the token on a shared namespace so a later
        # release_lock / is_lock_holder call can read it from any thread.
        lock = Lock(self.client, lock_name, timeout=hold_timeout, thread_local=False)
        lock_acquired = lock.acquire(token=holder, blocking_timeout=acquire_timeout)
        if lock_acquired:
            self._locks[lock_name] = lock
        return lock_acquired

    async def aacquire_lock(
        self,
        key: str,
        holder: str,
        acquire_timeout: Optional[float] = None,
        hold_timeout: Optional[float] = None,
    ) -> bool:
        """
        Asynchronously acquire the lock for ``key`` on behalf of ``holder``.

        Args:
            key: The key to lock.
            holder: Identifier of the holder; used as the Redis lock token.
            acquire_timeout: Seconds to wait for the lock (redis-py
                ``blocking_timeout``).
            hold_timeout: Lock expiry in seconds (redis-py ``timeout``).

        Returns:
            True if ``holder`` already holds the lock through this manager;
            otherwise the result of the Redis acquire attempt.
        """
        lock_name = self._lock_name_for_key(key)
        # Re-entrant for the same holder: never block on a lock we already own.
        if self.is_lock_holder(key, holder):
            return True
        lock = AsyncLock(
            self.async_client, lock_name, timeout=hold_timeout, thread_local=False
        )
        lock_acquired = await lock.acquire(
            token=holder, blocking_timeout=acquire_timeout
        )
        if lock_acquired:
            self._locks[lock_name] = lock
        return lock_acquired

    def release_lock(self, key: str, holder: str) -> None:
        """
        Release the lock for ``key`` held by ``holder``.

        Works for locks taken by either ``acquire_lock`` or ``aacquire_lock``.

        Raises:
            ValueError: If ``holder`` does not hold the lock through this
                manager.
        """
        lock_name = self._lock_name_for_key(key)
        if not self.is_lock_holder(key, holder):
            raise ValueError(f"No lock held by {holder} for transaction with key {key}")
        lock = self._locks[lock_name]
        if isinstance(lock, AsyncLock):
            # An asyncio lock's release() only returns an awaitable; calling it
            # from sync code would never reach Redis. Release through a sync
            # Lock carrying the same name and token instead.
            sync_lock = Lock(self.client, lock_name, thread_local=False)
            sync_lock.local.token = lock.local.token
            lock.local.token = None
            lock = sync_lock
        try:
            lock.release()
        finally:
            # Drop the bookkeeping entry even if Redis reports the lock as no
            # longer owned (e.g. it expired), so no stale entry lingers.
            self._locks.pop(lock_name, None)

    def wait_for_lock(self, key: str, timeout: Optional[float] = None) -> bool:
        """
        Block until the lock for ``key`` is free, without keeping it.

        Args:
            key: The key whose lock to wait on.
            timeout: Seconds to wait (redis-py ``blocking_timeout``).

        Returns:
            True if the lock could be taken (and was immediately released),
            False otherwise.
        """
        lock_name = self._lock_name_for_key(key)
        # Taking the lock proves no one else holds it; release it right away.
        lock = Lock(self.client, lock_name)
        lock_freed = lock.acquire(blocking_timeout=timeout)
        if lock_freed:
            lock.release()
        return lock_freed

    async def await_for_lock(self, key: str, timeout: Optional[float] = None) -> bool:
        """
        Asynchronously wait until the lock for ``key`` is free, without keeping it.

        Args:
            key: The key whose lock to wait on.
            timeout: Seconds to wait (redis-py ``blocking_timeout``).

        Returns:
            True if the lock could be taken (and was immediately released),
            False otherwise.
        """
        lock_name = self._lock_name_for_key(key)
        lock = AsyncLock(self.async_client, lock_name)
        lock_freed = await lock.acquire(blocking_timeout=timeout)
        if lock_freed:
            # The asyncio release() returns an awaitable; without ``await`` the
            # probe lock would stay in Redis (with no expiry) and block others.
            await lock.release()
        return lock_freed

    def is_locked(self, key: str) -> bool:
        """Return whether the Redis lock for ``key`` is currently held by anyone."""
        lock_name = self._lock_name_for_key(key)
        lock = Lock(self.client, lock_name)
        return lock.locked()

    def is_lock_holder(self, key: str, holder: str) -> bool:
        """
        Return whether ``holder`` holds the lock for ``key`` via this manager.

        Only locks recorded in this manager's local bookkeeping are considered;
        the comparison is against the token stored on that lock object.
        """
        lock_name = self._lock_name_for_key(key)
        lock = self._locks.get(lock_name)
        if lock is None:
            return False
        if (token := getattr(lock.local, "token", None)) is None:
            return False
        return token.decode() == holder
|