feat(api): Cancel requests when client disconnects
This commit is contained in:
@@ -190,8 +190,8 @@ async def fetch_part_content(part_id: int, cookies: Optional[dict] = None) -> st
|
|||||||
async def fetch_cover(url: str, cookies: Optional[dict] = None) -> bytes:
|
async def fetch_cover(url: str, cookies: Optional[dict] = None) -> bytes:
|
||||||
"""Fetch image bytes."""
|
"""Fetch image bytes."""
|
||||||
async with CachedSession(
|
async with CachedSession(
|
||||||
headers=headers, cache=None if cookies else cache
|
headers=headers, cache=None
|
||||||
) as session: # Don't cache requests with Cookies.
|
) as session: # Don't cache cover requests.
|
||||||
async with session.get(url) as response:
|
async with session.get(url) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""WattpadDownloader API Server."""
|
"""WattpadDownloader API Server."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import asyncio
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@@ -28,12 +29,49 @@ headers = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RequestCancelledMiddleware:
|
||||||
|
# Thanks https://github.com/fastapi/fastapi/discussions/11360#discussion-6427734
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope, receive, send):
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Let's make a shared queue for the request messages
|
||||||
|
queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def message_poller(sentinel, handler_task):
|
||||||
|
nonlocal queue
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "http.disconnect":
|
||||||
|
handler_task.cancel()
|
||||||
|
return sentinel # Break the loop
|
||||||
|
|
||||||
|
# Puts the message in the queue
|
||||||
|
await queue.put(message)
|
||||||
|
|
||||||
|
sentinel = object()
|
||||||
|
handler_task = asyncio.create_task(self.app(scope, queue.get, send))
|
||||||
|
asyncio.create_task(message_poller(sentinel, handler_task))
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await handler_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
print("Cancelling request due to disconnect")
|
||||||
|
|
||||||
|
|
||||||
class DownloadMode(Enum):
|
class DownloadMode(Enum):
|
||||||
story = "story"
|
story = "story"
|
||||||
part = "part"
|
part = "part"
|
||||||
collection = "collection"
|
collection = "collection"
|
||||||
|
|
||||||
|
|
||||||
|
app.add_middleware(RequestCancelledMiddleware)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def home():
|
def home():
|
||||||
return FileResponse(BUILD_PATH / "index.html")
|
return FileResponse(BUILD_PATH / "index.html")
|
||||||
|
|||||||
Reference in New Issue
Block a user