305 lines
11 KiB
Python
305 lines
11 KiB
Python
# Copyright 2016 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""Test Motor's AIOHTTPGridFSHandler."""
|
|
|
|
import asyncio
|
|
import datetime
|
|
import email
|
|
import logging
|
|
import sys
|
|
import test
|
|
import time
|
|
from test.asyncio_tests import AsyncIOTestCase, asyncio_test
|
|
|
|
import pytest
|
|
|
|
# MOTOR-1477 - after libzstd is supported on build hosts
# we can remove this guard.
#
# On Python >= 3.14, skip the whole module when the interpreter was built
# without ``compression.zstd``. The guard sits before the aiohttp imports
# below, so it is evaluated before they are attempted.
if sys.version_info >= (3, 14):
    try:
        import compression.zstd  # noqa: F401
    except ModuleNotFoundError:
        # Pass the reason positionally so the skip shows up in the test
        # report with an explanation instead of an empty message.
        pytest.skip(
            "compression.zstd is not available in this Python build (MOTOR-1477)",
            allow_module_level=True,
        )
|
|
|
|
|
|
import aiohttp
|
|
import aiohttp.web
|
|
import gridfs
|
|
|
|
from motor.aiohttp import AIOHTTPGridFS
|
|
from motor.motor_gridfs import _hash_gridout
|
|
|
|
|
|
def format_date(d):
    """Format datetime *d* as an HTTP date, e.g. ``Sun, 06 Nov 1994 08:49:37 GMT``.

    A naive *d* is taken to already be in UTC; an aware *d* is converted to
    UTC first. Day and month names are always English, independent of the
    process locale (``time.strftime`` with ``%a``/``%b`` is locale-dependent).
    """
    # Imported here because the module header only does ``import email``,
    # which does not guarantee the ``email.utils`` submodule is loaded.
    from email.utils import format_datetime

    utc = datetime.timezone.utc
    if d.utcoffset() is None:
        as_utc = d.replace(tzinfo=utc)
    else:
        as_utc = d.astimezone(utc)
    # usegmt=True emits the literal "GMT" zone and requires a UTC datetime.
    return format_datetime(as_utc, usegmt=True)
|
|
|
|
|
|
def parse_date(d):
    """Parse an RFC 2822 / HTTP date string into a naive ``datetime``.

    The result carries the year..second fields exactly as written in *d*;
    the zone in the string is not applied, so for HTTP headers (always GMT)
    the value is a naive UTC datetime.

    Raises ``ValueError`` if *d* cannot be parsed.
    """
    # Imported here because the module header only does ``import email``,
    # which does not guarantee the ``email.utils`` submodule is loaded.
    from email.utils import parsedate

    date_tuple = parsedate(d)
    if date_tuple is None:
        raise ValueError("Unparsable date string: %r" % (d,))
    # Build the datetime straight from the parsed fields. Round-tripping
    # through time.mktime()/fromtimestamp() interprets them as *local* time,
    # which shifts values that fall into a local DST gap.
    return datetime.datetime(*date_tuple[:6])
|
|
|
|
|
|
def expires(response):
    """Return the ``Expires`` header of *response*, parsed by ``parse_date``.

    Raises ``KeyError`` if the response carries no ``Expires`` header.
    """
    header_value = response.headers["Expires"]
    return parse_date(header_value)
|
|
|
|
|
|
class AIOHTTPGridFSHandlerTestBase(AsyncIOTestCase):
    """Shared fixture for the AIOHTTPGridFS handler tests.

    ``setUpClass`` stores one GridFS file (filename ``"foo"``) through the
    blocking PyMongo client. Each test calls ``start_app`` to serve GridFS
    over HTTP, issues requests with ``get`` / ``head`` / ``request``, and
    ``tearDown`` shuts the server down again via ``stop``.
    """

    # Address of the throwaway HTTP server. ``start_app`` binds to it and
    # ``request`` builds its URLs from it, so the two cannot drift apart.
    HTTP_HOST = "localhost"
    HTTP_PORT = 8088

    fs = None
    file_id = None

    def tearDown(self):
        # Shut the HTTP server down on self.loop before handing over to the
        # parent class's tearDown.
        self.loop.run_until_complete(self.stop())
        super().tearDown()

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        logging.getLogger("aiohttp.web").setLevel(logging.CRITICAL)

        cls.fs = gridfs.GridFS(test.env.sync_cx.motor_test)

        # Make a 500k file in GridFS with filename 'foo'
        cls.contents = b"Jesse" * 100 * 1024

        # Record when we created the file, to check the Last-Modified header.
        # Both bounds are truncated to whole seconds and made naive so they
        # compare against what parse_date() returns for the header.
        cls.put_start = datetime.datetime.now(datetime.timezone.utc).replace(
            microsecond=0, tzinfo=None
        )
        file_id = "id"
        cls.file_id = file_id
        # Remove any file with this _id before storing the fixture.
        cls.fs.delete(cls.file_id)
        cls.fs.put(cls.contents, _id=file_id, filename="foo", content_type="my type")
        item = cls.fs.get(file_id)
        cls.contents_hash = _hash_gridout(item)
        cls.put_end = datetime.datetime.now(datetime.timezone.utc).replace(
            microsecond=0, tzinfo=None
        )
        # Class-level defaults; start_app() sets per-instance values.
        cls.app = cls.srv = cls.app_handler = None

    @classmethod
    def tearDownClass(cls):
        cls.fs.delete(cls.file_id)
        super().tearDownClass()

    async def start_app(self, http_gridfs=None, extra_routes=None):
        """Create the aiohttp application and start listening.

        :param http_gridfs: handler mounted at ``/fs/{filename}`` for GET and
            HEAD; defaults to ``AIOHTTPGridFS(self.db)``.
        :param extra_routes: optional mapping of route pattern to handler;
            each entry is registered for GET only.
        """
        self.app = aiohttp.web.Application()
        resource = self.app.router.add_resource("/fs/{filename}")
        handler = http_gridfs or AIOHTTPGridFS(self.db)
        resource.add_route("GET", handler)
        resource.add_route("HEAD", handler)

        if extra_routes:
            # Distinct loop names so the main handler/resource above are not
            # shadowed.
            for route, extra_handler in extra_routes.items():
                extra_resource = self.app.router.add_resource(route)
                extra_resource.add_route("GET", extra_handler)

        # NOTE(review): make_handler() is the older aiohttp server API;
        # confirm it is still supported by the aiohttp versions under test.
        self.app_handler = self.app.make_handler()
        server = self.loop.create_server(
            self.app_handler, host=self.HTTP_HOST, port=self.HTTP_PORT
        )

        # Bind the listening socket and run the app's startup concurrently.
        self.srv, _ = await asyncio.gather(server, self.app.startup())

    async def request(self, method, path, if_modified_since=None, headers=None):
        """Send one HTTP request to the test server and return the response.

        :param method: name of the ``aiohttp.ClientSession`` method to call,
            e.g. ``"get"``, ``"head"`` or ``"post"``.
        :param path: URL path, including the leading slash.
        :param if_modified_since: optional datetime sent as the
            ``If-Modified-Since`` header (formatted by ``format_date``).
        :param headers: optional dict of extra request headers; not modified.
        """
        # Copy so adding If-Modified-Since never leaks into the caller's dict.
        headers = dict(headers or {})
        if if_modified_since:
            headers["If-Modified-Since"] = format_date(if_modified_since)

        url = "http://%s:%d%s" % (self.HTTP_HOST, self.HTTP_PORT, path)
        async with aiohttp.ClientSession() as session:
            send = getattr(session, method)
            resp = await send(url, headers=headers)
            # Consume the body while the session is still open; the tests
            # call read()/text() on the response after it has been closed.
            await resp.read()
            return resp

    def get(self, path, **kwargs):
        """Shortcut for ``request("get", path, **kwargs)``; returns a coroutine."""
        return self.request("get", path, **kwargs)

    def head(self, path, **kwargs):
        """Shortcut for ``request("head", path, **kwargs)``; returns a coroutine."""
        return self.request("head", path, **kwargs)

    async def stop(self):
        """Shut down the server and application started by ``start_app``.

        Safe to call when ``start_app`` was never run or failed part-way, and
        safe to call more than once.
        """
        # aiohttp.rtfd.io/en/stable/web.html#aiohttp-web-graceful-shutdown
        if self.srv is not None:
            self.srv.close()
            await self.srv.wait_closed()
        if self.app is not None:
            await self.app.shutdown()
            # app_handler stays None if start_app failed while adding routes.
            if self.app_handler is not None:
                await self.app_handler.shutdown(timeout=1)
            await self.app.cleanup()
        # Forget the stopped objects so a repeated stop() is a no-op.
        self.app = self.srv = self.app_handler = None
|
|
|
|
|
|
class AIOHTTPGridFSHandlerTest(AIOHTTPGridFSHandlerTestBase):
    """Requests against an ``AIOHTTPGridFS`` handler with default options."""

    @asyncio_test
    async def test_basic(self):
        await self.start_app()
        # First request
        response = await self.get("/fs/foo")

        self.assertEqual(200, response.status)
        self.assertEqual(self.contents, (await response.read()))
        self.assertEqual(len(self.contents), int(response.headers["Content-Length"]))
        self.assertEqual("my type", response.headers["Content-Type"])
        self.assertEqual("public", response.headers["Cache-Control"])
        self.assertNotIn("Expires", response.headers)

        etag = response.headers["Etag"]
        last_mod_dt = parse_date(response.headers["Last-Modified"])
        self.assertEqual(self.contents_hash, etag.strip('"'))
        # Last-Modified must fall inside the window recorded around the put.
        self.assertLessEqual(self.put_start, last_mod_dt)
        self.assertLessEqual(last_mod_dt, self.put_end)

        # Now check we get 304 NOT MODIFIED responses as appropriate
        for ims_value in (last_mod_dt, last_mod_dt + datetime.timedelta(seconds=1)):
            response = await self.get("/fs/foo", if_modified_since=ims_value)
            self.assertEqual(304, response.status)
            self.assertEqual(b"", (await response.read()))

        # If-Modified-Since in the past, get whole response back
        response = await self.get(
            "/fs/foo", if_modified_since=last_mod_dt - datetime.timedelta(seconds=1)
        )
        self.assertEqual(200, response.status)
        self.assertEqual(self.contents, (await response.read()))

        # Matching Etag
        response = await self.get("/fs/foo", headers={"If-None-Match": etag})
        self.assertEqual(304, response.status)
        self.assertEqual(b"", (await response.read()))

        # Mismatched Etag
        response = await self.get("/fs/foo", headers={"If-None-Match": etag + "a"})
        self.assertEqual(200, response.status)
        self.assertEqual(self.contents, (await response.read()))

    @asyncio_test
    async def test_404(self):
        # No GridFS file named "bar" is stored by the fixture.
        await self.start_app()
        response = await self.get("/fs/bar")
        self.assertEqual(404, response.status)

    @asyncio_test
    async def test_head(self):
        await self.start_app()
        response = await self.head("/fs/foo")

        etag = response.headers["Etag"]
        last_mod_dt = parse_date(response.headers["Last-Modified"])

        self.assertEqual(200, response.status)
        # Empty body for HEAD request.
        self.assertEqual(b"", (await response.read()))
        self.assertEqual(len(self.contents), int(response.headers["Content-Length"]))
        self.assertEqual("my type", response.headers["Content-Type"])
        self.assertEqual(self.contents_hash, etag.strip('"'))
        self.assertLessEqual(self.put_start, last_mod_dt)
        self.assertLessEqual(last_mod_dt, self.put_end)
        self.assertEqual("public", response.headers["Cache-Control"])

    @asyncio_test
    async def test_bad_route(self):
        # The handler is mounted on a route whose placeholder is not
        # "filename"; the request is expected to fail with a 500 that names
        # the offending route.
        handler = AIOHTTPGridFS(self.db)
        await self.start_app(extra_routes={"/x/{wrongname}": handler})
        response = await self.get("/x/foo")
        self.assertEqual(500, response.status)
        msg = 'Bad AIOHTTPGridFS route "/x/{wrongname}"'
        self.assertIn(msg, (await response.text()))

    @asyncio_test
    async def test_content_type(self):
        await self.start_app()
        # Check that GridFSHandler uses file extension to guess Content-Type
        # if not provided
        for filename, expected_type in [
            ("bar", "octet-stream"),
            ("bar.png", "png"),
            ("ht.html", "html"),
        ]:
            # 'fs' is PyMongo's blocking GridFS
            _id = self.fs.put(b"", filename=filename)
            self.addCleanup(self.fs.delete, _id)

            for method in self.get, self.head:
                response = await method("/fs/" + filename)
                self.assertEqual(200, response.status)
                # mimetypes are platform-defined, be fuzzy
                self.assertIn(expected_type, response.content_type)

    @asyncio_test
    async def test_post(self):
        # Only GET and HEAD are served; a POST must be answered with 405.
        # NOTE(review): start_app() registers no POST route, so this does not
        # exercise the handler's own method check - confirm that is intended.
        await self.start_app()
        result = await self.request("post", "/fs/foo")
        self.assertEqual(405, result.status)
|
|
|
|
|
|
class AIOHTTPTZAwareGridFSHandlerTest(AIOHTTPGridFSHandlerTestBase):
    """If-Modified-Since handling when the Motor client is ``tz_aware``."""

    @asyncio_test
    async def test_tz_aware(self):
        client = self.asyncio_client(tz_aware=True)
        try:
            await self.start_app(AIOHTTPGridFS(client.motor_test))
            now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            ago = now - datetime.timedelta(minutes=10)
            hence = now + datetime.timedelta(minutes=10)

            # The fixture file was stored after "ago", so it counts as modified.
            response = await self.get("/fs/foo", if_modified_since=ago)
            self.assertEqual(200, response.status)

            # ...and it was stored before "hence", so it is not modified.
            response = await self.get("/fs/foo", if_modified_since=hence)
            self.assertEqual(304, response.status)
        finally:
            # This client is created by the test itself, so release it here.
            # No further requests are made after this point.
            # NOTE(review): assumes close() is a plain synchronous call on
            # the Motor client - confirm.
            client.close()
|
|
|
|
|
|
class AIOHTTPCustomHTTPGridFSTest(AIOHTTPGridFSHandlerTestBase):
    """``AIOHTTPGridFS`` with all three customization hooks overridden."""

    @asyncio_test
    async def test_get_gridfs_file(self):
        def getter(bucket, filename, request):
            # Test overriding the get_gridfs_file() method, path is
            # interpreted as file_id instead of filename.
            return bucket.open_download_stream(file_id=filename)

        def cache_time(path, modified, mime_type):
            # Fixed cache lifetime, checked against max-age / Expires below.
            return 10

        def extras(response, gridout):
            response.headers["quux"] = "fizzledy"

        await self.start_app(
            AIOHTTPGridFS(
                self.db, get_gridfs_file=getter, get_cache_time=cache_time, set_extra_headers=extras
            )
        )

        # We overrode get_gridfs_file so we expect getting by filename *not* to
        # work now; we'll get a 404. We have to get by file_id now.
        response = await self.get("/fs/foo")
        self.assertEqual(404, response.status)

        response = await self.get("/fs/" + str(self.file_id))
        self.assertEqual(200, response.status)

        self.assertEqual(self.contents, (await response.read()))
        cache_control = response.headers["Cache-Control"]
        self.assertRegex(cache_control, r"max-age=\d+")
        self.assertEqual(10, int(cache_control.split("=")[1]))
        expiration = expires(response)
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

        # It should expire about 10 seconds from now. Two separate
        # assertions so a failure reports the offending value.
        remaining = expiration - now
        self.assertLess(datetime.timedelta(seconds=8), remaining)
        self.assertLess(remaining, datetime.timedelta(seconds=12))

        self.assertEqual("fizzledy", response.headers["quux"])
|