Update FileStream API (#3670)
This commit is contained in:
parent
f9d5e12049
commit
4acf5c2c37
@ -178,7 +178,8 @@ class File(Content):
|
||||
return os.path.getsize(self._path)
|
||||
|
||||
def encode(self) -> Stream:
|
||||
return FileStream(self._path)
|
||||
fin = open(self._path, 'rb')
|
||||
return FileStream(self._path, fin)
|
||||
|
||||
def content_type(self) -> str:
|
||||
_, ext = os.path.splitext(self._path)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import io
|
||||
import typing
|
||||
import types
|
||||
import os
|
||||
|
||||
@ -75,32 +76,19 @@ class DuplexStream(Stream):
|
||||
|
||||
|
||||
class FileStream(Stream):
|
||||
def __init__(self, path):
|
||||
def __init__(self, path: str, fin: typing.Any) -> None:
|
||||
self._path = path
|
||||
self._fileobj = None
|
||||
self._size = None
|
||||
self._fin = fin
|
||||
|
||||
async def read(self, size: int=-1) -> bytes:
|
||||
if self._fileobj is None:
|
||||
raise ValueError('I/O operation on unopened file')
|
||||
return self._fileobj.read(size)
|
||||
|
||||
async def open(self):
|
||||
self._fileobj = open(self._path, 'rb')
|
||||
self._size = os.path.getsize(self._path)
|
||||
return self
|
||||
return self._fin.read(size)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._fileobj is not None:
|
||||
self._fileobj.close()
|
||||
self._fin.close()
|
||||
|
||||
@property
|
||||
def size(self) -> int | None:
|
||||
return self._size
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.open()
|
||||
return self
|
||||
return os.path.getsize(self._path)
|
||||
|
||||
|
||||
class HTTPStream(Stream):
|
||||
@ -152,7 +140,7 @@ class MultiPartStream(Stream):
|
||||
# Mutable state...
|
||||
self._form_progress = list(self._form)
|
||||
self._files_progress = list(self._files)
|
||||
self._filestream: FileStream | None = None
|
||||
self._fin: typing.Any = None
|
||||
self._complete = False
|
||||
self._buffer = io.BytesIO()
|
||||
|
||||
@ -196,10 +184,10 @@ class MultiPartStream(Stream):
|
||||
f"\r\n"
|
||||
f"{value}\r\n"
|
||||
).encode("utf-8")
|
||||
elif self._files_progress and self._filestream is None:
|
||||
elif self._files_progress and self._fin is None:
|
||||
# return start of a file item
|
||||
key, value = self._files_progress.pop(0)
|
||||
self._filestream = await FileStream(value).open()
|
||||
self._fin = open(value, 'rb')
|
||||
name = key.translate({10: "%0A", 13: "%0D", 34: "%22"})
|
||||
filename = os.path.basename(value)
|
||||
return (
|
||||
@ -207,15 +195,15 @@ class MultiPartStream(Stream):
|
||||
f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'
|
||||
f"\r\n"
|
||||
).encode("utf-8")
|
||||
elif self._filestream is not None:
|
||||
chunk = await self._filestream.read(64*1024)
|
||||
elif self._fin is not None:
|
||||
chunk = await self._fin.read(64*1024)
|
||||
if chunk != b'':
|
||||
# return some bytes from file
|
||||
return chunk
|
||||
else:
|
||||
# return end of file item
|
||||
await self._filestream.close()
|
||||
self._filestream = None
|
||||
await self._fin.close()
|
||||
self._fin = None
|
||||
return b"\r\n"
|
||||
elif not self._complete:
|
||||
# return final section of multipart
|
||||
@ -225,9 +213,9 @@ class MultiPartStream(Stream):
|
||||
return b""
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._filestream is not None:
|
||||
await self._filestream.close()
|
||||
self._filestream = None
|
||||
if self._fin is not None:
|
||||
await self._fin.close()
|
||||
self._fin = None
|
||||
self._buffer.close()
|
||||
|
||||
@property
|
||||
|
||||
@ -178,7 +178,8 @@ class File(Content):
|
||||
return os.path.getsize(self._path)
|
||||
|
||||
def encode(self) -> Stream:
|
||||
return FileStream(self._path)
|
||||
fin = open(self._path, 'rb')
|
||||
return FileStream(self._path, fin)
|
||||
|
||||
def content_type(self) -> str:
|
||||
_, ext = os.path.splitext(self._path)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import io
|
||||
import typing
|
||||
import types
|
||||
import os
|
||||
|
||||
@ -75,32 +76,19 @@ class DuplexStream(Stream):
|
||||
|
||||
|
||||
class FileStream(Stream):
|
||||
def __init__(self, path):
|
||||
def __init__(self, path: str, fin: typing.Any) -> None:
|
||||
self._path = path
|
||||
self._fileobj = None
|
||||
self._size = None
|
||||
self._fin = fin
|
||||
|
||||
def read(self, size: int=-1) -> bytes:
|
||||
if self._fileobj is None:
|
||||
raise ValueError('I/O operation on unopened file')
|
||||
return self._fileobj.read(size)
|
||||
|
||||
def open(self):
|
||||
self._fileobj = open(self._path, 'rb')
|
||||
self._size = os.path.getsize(self._path)
|
||||
return self
|
||||
return self._fin.read(size)
|
||||
|
||||
def close(self) -> None:
|
||||
if self._fileobj is not None:
|
||||
self._fileobj.close()
|
||||
self._fin.close()
|
||||
|
||||
@property
|
||||
def size(self) -> int | None:
|
||||
return self._size
|
||||
|
||||
def __enter__(self):
|
||||
self.open()
|
||||
return self
|
||||
return os.path.getsize(self._path)
|
||||
|
||||
|
||||
class HTTPStream(Stream):
|
||||
@ -152,7 +140,7 @@ class MultiPartStream(Stream):
|
||||
# Mutable state...
|
||||
self._form_progress = list(self._form)
|
||||
self._files_progress = list(self._files)
|
||||
self._filestream: FileStream | None = None
|
||||
self._fin: typing.Any = None
|
||||
self._complete = False
|
||||
self._buffer = io.BytesIO()
|
||||
|
||||
@ -196,10 +184,10 @@ class MultiPartStream(Stream):
|
||||
f"\r\n"
|
||||
f"{value}\r\n"
|
||||
).encode("utf-8")
|
||||
elif self._files_progress and self._filestream is None:
|
||||
elif self._files_progress and self._fin is None:
|
||||
# return start of a file item
|
||||
key, value = self._files_progress.pop(0)
|
||||
self._filestream = FileStream(value).open()
|
||||
self._fin = open(value, 'rb')
|
||||
name = key.translate({10: "%0A", 13: "%0D", 34: "%22"})
|
||||
filename = os.path.basename(value)
|
||||
return (
|
||||
@ -207,15 +195,15 @@ class MultiPartStream(Stream):
|
||||
f'Content-Disposition: form-data; name="{name}"; filename="{filename}"\r\n'
|
||||
f"\r\n"
|
||||
).encode("utf-8")
|
||||
elif self._filestream is not None:
|
||||
chunk = self._filestream.read(64*1024)
|
||||
elif self._fin is not None:
|
||||
chunk = self._fin.read(64*1024)
|
||||
if chunk != b'':
|
||||
# return some bytes from file
|
||||
return chunk
|
||||
else:
|
||||
# return end of file item
|
||||
self._filestream.close()
|
||||
self._filestream = None
|
||||
self._fin.close()
|
||||
self._fin = None
|
||||
return b"\r\n"
|
||||
elif not self._complete:
|
||||
# return final section of multipart
|
||||
@ -225,9 +213,9 @@ class MultiPartStream(Stream):
|
||||
return b""
|
||||
|
||||
def close(self) -> None:
|
||||
if self._filestream is not None:
|
||||
self._filestream.close()
|
||||
self._filestream = None
|
||||
if self._fin is not None:
|
||||
self._fin.close()
|
||||
self._fin = None
|
||||
self._buffer.close()
|
||||
|
||||
@property
|
||||
|
||||
@ -30,17 +30,17 @@ def test_filestream(tmp_path):
|
||||
path = tmp_path / "example.txt"
|
||||
path.write_bytes(b"hello world")
|
||||
|
||||
with httpx.FileStream(path) as s:
|
||||
with httpx.File(path).encode() as s:
|
||||
assert s.size == 11
|
||||
assert s.read() == b'hello world'
|
||||
|
||||
with httpx.FileStream(path) as s:
|
||||
with httpx.File(path).encode() as s:
|
||||
assert s.read(5) == b'hello'
|
||||
assert s.read(5) == b' worl'
|
||||
assert s.read(5) == b'd'
|
||||
assert s.read(5) == b''
|
||||
|
||||
with httpx.FileStream(path) as s:
|
||||
with httpx.File(path).encode() as s:
|
||||
assert s.read(5) == b'hello'
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user