feat(api): Cancel requests when client disconnects
This commit is contained in:
@@ -190,8 +190,8 @@ async def fetch_part_content(part_id: int, cookies: Optional[dict] = None) -> st
|
||||
async def fetch_cover(url: str, cookies: Optional[dict] = None) -> bytes:
|
||||
"""Fetch image bytes."""
|
||||
async with CachedSession(
|
||||
headers=headers, cache=None if cookies else cache
|
||||
) as session: # Don't cache requests with Cookies.
|
||||
headers=headers, cache=None
|
||||
) as session: # Don't cache cover requests.
|
||||
async with session.get(url) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""WattpadDownloader API Server."""
|
||||
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
@@ -28,12 +29,49 @@ headers = {
|
||||
}
|
||||
|
||||
|
||||
class RequestCancelledMiddleware:
|
||||
# Thanks https://github.com/fastapi/fastapi/discussions/11360#discussion-6427734
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
# Let's make a shared queue for the request messages
|
||||
queue = asyncio.Queue()
|
||||
|
||||
async def message_poller(sentinel, handler_task):
|
||||
nonlocal queue
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "http.disconnect":
|
||||
handler_task.cancel()
|
||||
return sentinel # Break the loop
|
||||
|
||||
# Puts the message in the queue
|
||||
await queue.put(message)
|
||||
|
||||
sentinel = object()
|
||||
handler_task = asyncio.create_task(self.app(scope, queue.get, send))
|
||||
asyncio.create_task(message_poller(sentinel, handler_task))
|
||||
|
||||
try:
|
||||
return await handler_task
|
||||
except asyncio.CancelledError:
|
||||
print("Cancelling request due to disconnect")
|
||||
|
||||
|
||||
class DownloadMode(Enum):
|
||||
story = "story"
|
||||
part = "part"
|
||||
collection = "collection"
|
||||
|
||||
|
||||
app.add_middleware(RequestCancelledMiddleware)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def home():
|
||||
return FileResponse(BUILD_PATH / "index.html")
|
||||
|
||||
Reference in New Issue
Block a user