diff --git a/src/api/src/create_book.py b/src/api/src/create_book.py index 5e42436..5e20a2b 100644 --- a/src/api/src/create_book.py +++ b/src/api/src/create_book.py @@ -190,8 +190,8 @@ async def fetch_part_content(part_id: int, cookies: Optional[dict] = None) -> st async def fetch_cover(url: str, cookies: Optional[dict] = None) -> bytes: """Fetch image bytes.""" async with CachedSession( - headers=headers, cache=None if cookies else cache - ) as session: # Don't cache requests with Cookies. + headers=headers, cache=None + ) as session: # Don't cache cover requests. async with session.get(url) as response: response.raise_for_status() diff --git a/src/api/src/main.py b/src/api/src/main.py index f1129d6..091515b 100644 --- a/src/api/src/main.py +++ b/src/api/src/main.py @@ -1,6 +1,7 @@ """WattpadDownloader API Server.""" from typing import Optional +import asyncio import tempfile from pathlib import Path from io import BytesIO @@ -28,12 +29,49 @@ headers = { } +class RequestCancelledMiddleware: + # Thanks https://github.com/fastapi/fastapi/discussions/11360#discussion-6427734 + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + # Let's make a shared queue for the request messages + queue = asyncio.Queue() + + async def message_poller(sentinel, handler_task): + nonlocal queue + while True: + message = await receive() + if message["type"] == "http.disconnect": + handler_task.cancel() + return sentinel # Break the loop + + # Puts the message in the queue + await queue.put(message) + + sentinel = object() + handler_task = asyncio.create_task(self.app(scope, queue.get, send)) + asyncio.create_task(message_poller(sentinel, handler_task)) + + try: + return await handler_task + except asyncio.CancelledError: + print("Cancelling request due to disconnect") + + class DownloadMode(Enum): story = "story" part = "part" collection = "collection" +app.add_middleware(RequestCancelledMiddleware) + + @app.get("/") def home(): return FileResponse(BUILD_PATH / "index.html")