Update FileStream API (#3670)

This commit is contained in:
Kim Christie 2025-09-18 13:15:06 +01:00 committed by GitHub
parent f9d5e12049
commit 4acf5c2c37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 39 additions and 61 deletions

View File

@ -178,7 +178,8 @@ class File(Content):
return os.path.getsize(self._path)
def encode(self) -> Stream:
return FileStream(self._path)
fin = open(self._path, 'rb')
return FileStream(self._path, fin)
def content_type(self) -> str:
_, ext = os.path.splitext(self._path)

View File

@ -1,4 +1,5 @@
import io
import typing
import types
import os
@ -75,32 +76,19 @@ class DuplexStream(Stream):
class FileStream(Stream):
def __init__(self, path):
def __init__(self, path: str, fin: typing.Any) -> None:
self._path = path
self._fileobj = None
self._size = None
self._fin = fin
async def read(self, size: int=-1) -> bytes:
if self._fileobj is None:
raise ValueError('I/O operation on unopened file')
return self._fileobj.read(size)
async def open(self):
self._fileobj = open(self._path, 'rb')
self._size = os.path.getsize(self._path)
return self
return self._fin.read(size)
async def close(self) -> None:
if self._fileobj is not None:
self._fileobj.close()
self._fin.close()
@property
def size(self) -> int | None:
return self._size
async def __aenter__(self):
await self.open()
return self
return os.path.getsize(self._path)
class HTTPStream(Stream):
@ -152,7 +140,7 @@ class MultiPartStream(Stream):
# Mutable state...
self._form_progress = list(self._form)
self._files_progress = list(self._files)
self._filestream: FileStream | None = None
self._fin: typing.Any = None
self._complete = False
self._buffer = io.BytesIO()
@ -196,10 +184,10 @@ class MultiPartStream(Stream):
f"\r\n"
f"{value}\r\n"
).encode("utf-8")
elif self._files_progress and self._filestream is None:
elif self._files_progress and self._fin is None:
# return start of a file item
key, value = self._files_progress.pop(0)
self._filestream = await FileStream(value).open()
self._fin = open(value, 'rb')
name = key.translate({10: "%0A", 13: "%0D", 34: "%22"})
filename = os.path.basename(value)
return (
@ -207,15 +195,15 @@ class MultiPartStream(Stream):
f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'
f"\r\n"
).encode("utf-8")
elif self._filestream is not None:
chunk = await self._filestream.read(64*1024)
elif self._fin is not None:
chunk = await self._fin.read(64*1024)
if chunk != b'':
# return some bytes from file
return chunk
else:
# return end of file item
await self._filestream.close()
self._filestream = None
await self._fin.close()
self._fin = None
return b"\r\n"
elif not self._complete:
# return final section of multipart
@ -225,9 +213,9 @@ class MultiPartStream(Stream):
return b""
async def close(self) -> None:
if self._filestream is not None:
await self._filestream.close()
self._filestream = None
if self._fin is not None:
await self._fin.close()
self._fin = None
self._buffer.close()
@property

View File

@ -178,7 +178,8 @@ class File(Content):
return os.path.getsize(self._path)
def encode(self) -> Stream:
return FileStream(self._path)
fin = open(self._path, 'rb')
return FileStream(self._path, fin)
def content_type(self) -> str:
_, ext = os.path.splitext(self._path)

View File

@ -1,4 +1,5 @@
import io
import typing
import types
import os
@ -75,32 +76,19 @@ class DuplexStream(Stream):
class FileStream(Stream):
def __init__(self, path):
def __init__(self, path: str, fin: typing.Any) -> None:
self._path = path
self._fileobj = None
self._size = None
self._fin = fin
def read(self, size: int=-1) -> bytes:
if self._fileobj is None:
raise ValueError('I/O operation on unopened file')
return self._fileobj.read(size)
def open(self):
self._fileobj = open(self._path, 'rb')
self._size = os.path.getsize(self._path)
return self
return self._fin.read(size)
def close(self) -> None:
if self._fileobj is not None:
self._fileobj.close()
self._fin.close()
@property
def size(self) -> int | None:
return self._size
def __enter__(self):
self.open()
return self
return os.path.getsize(self._path)
class HTTPStream(Stream):
@ -152,7 +140,7 @@ class MultiPartStream(Stream):
# Mutable state...
self._form_progress = list(self._form)
self._files_progress = list(self._files)
self._filestream: FileStream | None = None
self._fin: typing.Any = None
self._complete = False
self._buffer = io.BytesIO()
@ -196,10 +184,10 @@ class MultiPartStream(Stream):
f"\r\n"
f"{value}\r\n"
).encode("utf-8")
elif self._files_progress and self._filestream is None:
elif self._files_progress and self._fin is None:
# return start of a file item
key, value = self._files_progress.pop(0)
self._filestream = FileStream(value).open()
self._fin = open(value, 'rb')
name = key.translate({10: "%0A", 13: "%0D", 34: "%22"})
filename = os.path.basename(value)
return (
@ -207,15 +195,15 @@ class MultiPartStream(Stream):
f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'
f"\r\n"
).encode("utf-8")
elif self._filestream is not None:
chunk = self._filestream.read(64*1024)
elif self._fin is not None:
chunk = self._fin.read(64*1024)
if chunk != b'':
# return some bytes from file
return chunk
else:
# return end of file item
self._filestream.close()
self._filestream = None
self._fin.close()
self._fin = None
return b"\r\n"
elif not self._complete:
# return final section of multipart
@ -225,9 +213,9 @@ class MultiPartStream(Stream):
return b""
def close(self) -> None:
if self._filestream is not None:
self._filestream.close()
self._filestream = None
if self._fin is not None:
self._fin.close()
self._fin = None
self._buffer.close()
@property

View File

@ -30,17 +30,17 @@ def test_filestream(tmp_path):
path = tmp_path / "example.txt"
path.write_bytes(b"hello world")
with httpx.FileStream(path) as s:
with httpx.File(path).encode() as s:
assert s.size == 11
assert s.read() == b'hello world'
with httpx.FileStream(path) as s:
with httpx.File(path).encode() as s:
assert s.read(5) == b'hello'
assert s.read(5) == b' worl'
assert s.read(5) == b'd'
assert s.read(5) == b''
with httpx.FileStream(path) as s:
with httpx.File(path).encode() as s:
assert s.read(5) == b'hello'