feat(api): Cancel requests when client disconnects

This commit is contained in:
TheOnlyWayUp
2024-11-30 19:24:33 +00:00
parent 36c73d01e9
commit 6e222c1f55
2 changed files with 40 additions and 2 deletions
+2 -2
View File
@@ -190,8 +190,8 @@ async def fetch_part_content(part_id: int, cookies: Optional[dict] = None) -> st
async def fetch_cover(url: str, cookies: Optional[dict] = None) -> bytes: async def fetch_cover(url: str, cookies: Optional[dict] = None) -> bytes:
"""Fetch image bytes.""" """Fetch image bytes."""
async with CachedSession( async with CachedSession(
headers=headers, cache=None if cookies else cache headers=headers, cache=None
) as session: # Don't cache requests with Cookies. ) as session: # Don't cache cover requests.
async with session.get(url) as response: async with session.get(url) as response:
response.raise_for_status() response.raise_for_status()
+38
View File
@@ -1,6 +1,7 @@
"""WattpadDownloader API Server.""" """WattpadDownloader API Server."""
from typing import Optional from typing import Optional
import asyncio
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from io import BytesIO from io import BytesIO
@@ -28,12 +29,49 @@ headers = {
} }
class RequestCancelledMiddleware:
# Thanks https://github.com/fastapi/fastapi/discussions/11360#discussion-6427734
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
# Let's make a shared queue for the request messages
queue = asyncio.Queue()
async def message_poller(sentinel, handler_task):
nonlocal queue
while True:
message = await receive()
if message["type"] == "http.disconnect":
handler_task.cancel()
return sentinel # Break the loop
# Puts the message in the queue
await queue.put(message)
sentinel = object()
handler_task = asyncio.create_task(self.app(scope, queue.get, send))
asyncio.create_task(message_poller(sentinel, handler_task))
try:
return await handler_task
except asyncio.CancelledError:
print("Cancelling request due to disconnect")
class DownloadMode(Enum): class DownloadMode(Enum):
story = "story" story = "story"
part = "part" part = "part"
collection = "collection" collection = "collection"
app.add_middleware(RequestCancelledMiddleware)
@app.get("/") @app.get("/")
def home(): def home():
return FileResponse(BUILD_PATH / "index.html") return FileResponse(BUILD_PATH / "index.html")