58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class CloudClient:
    """
    Async client for the Prefect Cloud API.

    Must be entered with ``async with CloudClient(...)``; entering it with a
    plain ``with`` raises ``RuntimeError``.
    """

    def __init__(
        self,
        host: str,
        api_key: str,
        httpx_settings: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            host: Base URL for all requests. Also searched with
                ``PARSE_API_URL_REGEX`` for account and workspace IDs.
            api_key: Sent as ``Authorization: Bearer <api_key>`` unless
                ``httpx_settings["headers"]`` already sets ``Authorization``.
            httpx_settings: Extra keyword arguments for the underlying
                ``PrefectHttpxAsyncClient``. The dict and its ``headers``
                mapping are copied, so the caller's objects are never modified.
        """
        # Work on copies: the defaults added below must not leak into a dict
        # (or headers mapping) the caller may reuse for other clients.
        settings: Dict[str, Any] = dict(httpx_settings or {})
        headers = settings.get("headers")
        headers = headers.copy() if headers is not None else {}
        headers.setdefault("Authorization", f"Bearer {api_key}")
        settings["headers"] = headers
        settings.setdefault("base_url", host)
        if not PREFECT_UNIT_TEST_MODE.value():
            settings.setdefault("follow_redirects", True)
        self._client = PrefectHttpxAsyncClient(
            **settings, enable_csrf_support=False
        )

        # Always define the IDs so the base-URL properties raise their intended
        # ValueError (not AttributeError) when neither URL matches the regex.
        self.account_id: Optional[str] = None
        self.workspace_id: Optional[str] = None
        api_url = prefect.settings.PREFECT_API_URL.value() or ""
        if match := (
            re.search(PARSE_API_URL_REGEX, host)
            or re.search(PARSE_API_URL_REGEX, api_url)
        ):
            self.account_id, self.workspace_id = match.groups()

    @property
    def account_base_url(self) -> str:
        """
        Relative route prefix for account-scoped endpoints.

        Raises:
            ValueError: If no account ID was parsed at construction time.
        """
        if not self.account_id:
            raise ValueError("Account ID not set")
        return f"accounts/{self.account_id}"

    @property
    def workspace_base_url(self) -> str:
        """
        Relative route prefix for workspace-scoped endpoints.

        Raises:
            ValueError: If no account ID or no workspace ID was parsed.
        """
        if not self.workspace_id:
            raise ValueError("Workspace ID not set")
        return f"{self.account_base_url}/workspaces/{self.workspace_id}"

    async def api_healthcheck(self) -> None:
        """
        Attempts to connect to the Cloud API and raises the encountered exception if not
        successful.

        If successful, returns `None`.
        """
        # Bound the probe so an unreachable API fails fast instead of hanging.
        with anyio.fail_after(10):
            await self.read_workspaces()

    async def read_workspaces(self) -> List[Workspace]:
        """Return the workspaces visible to the current API key (``/me/workspaces``)."""
        return pydantic.TypeAdapter(List[Workspace]).validate_python(
            await self.get("/me/workspaces")
        )

    async def read_worker_metadata(self) -> Dict[str, Any]:
        """Return the raw JSON from the workspace's ``collections/work_pool_types`` route."""
        response = await self.get(
            f"{self.workspace_base_url}/collections/work_pool_types"
        )
        return cast(Dict[str, Any], response)

    async def read_account_settings(self) -> Dict[str, Any]:
        """Return the raw JSON from the account's ``settings`` route."""
        response = await self.get(f"{self.account_base_url}/settings")
        return cast(Dict[str, Any], response)

    async def update_account_settings(self, settings: Dict[str, Any]) -> None:
        """PATCH the account's ``settings`` route with ``settings`` as the JSON body."""
        await self.request(
            "PATCH",
            f"{self.account_base_url}/settings",
            json=settings,
        )

    async def read_account_ip_allowlist(self) -> IPAllowlist:
        """Return the account's IP allowlist, validated into an ``IPAllowlist``."""
        response = await self.get(f"{self.account_base_url}/ip_allowlist")
        return IPAllowlist.model_validate(response)

    async def update_account_ip_allowlist(self, updated_allowlist: IPAllowlist) -> None:
        """Replace (PUT) the account's IP allowlist with ``updated_allowlist``."""
        await self.request(
            "PUT",
            f"{self.account_base_url}/ip_allowlist",
            json=updated_allowlist.model_dump(mode="json"),
        )

    async def check_ip_allowlist_access(self) -> IPAllowlistMyAccessResponse:
        """Query ``ip_allowlist/my_access`` and validate the response."""
        response = await self.get(f"{self.account_base_url}/ip_allowlist/my_access")
        return IPAllowlistMyAccessResponse.model_validate(response)

    async def __aenter__(self):
        await self._client.__aenter__()
        return self

    async def __aexit__(self, *exc_info):
        return await self._client.__aexit__(*exc_info)

    def __enter__(self):
        raise RuntimeError(
            "The `CloudClient` must be entered with an async context. Use 'async "
            "with CloudClient(...)' not 'with CloudClient(...)'"
        )

    def __exit__(self, *_):
        # Explicit raise instead of `assert False`, which `python -O` strips.
        raise AssertionError(
            "This should never be called but must be defined for __enter__"
        )

    async def get(self, route: str, **kwargs: Any) -> Any:
        """Shorthand for ``request("GET", route, **kwargs)``."""
        return await self.request("GET", route, **kwargs)

    async def request(self, method: str, route: str, **kwargs: Any) -> Any:
        """
        Send a request through the underlying client and decode the response.

        Returns:
            ``None`` for a 204 response; otherwise the parsed JSON body.

        Raises:
            CloudUnauthorizedError: On a 401 or 403 response.
            ObjectNotFound: On a 404 response.
            httpx.HTTPStatusError: On any other error status.
        """
        try:
            res = await self._client.request(method, route, **kwargs)
            res.raise_for_status()
        except httpx.HTTPStatusError as exc:
            if exc.response.status_code in (
                status.HTTP_401_UNAUTHORIZED,
                status.HTTP_403_FORBIDDEN,
            ):
                raise CloudUnauthorizedError(str(exc)) from exc
            elif exc.response.status_code == status.HTTP_404_NOT_FOUND:
                raise ObjectNotFound(http_exc=exc) from exc
            else:
                raise

        if res.status_code == status.HTTP_204_NO_CONTENT:
            return None
        return res.json()
|