diff --git a/src/io/streams.hpp b/src/io/streams.hpp index 643cfd987..c69637570 100644 --- a/src/io/streams.hpp +++ b/src/io/streams.hpp @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -20,7 +21,8 @@ #include #include -namespace srb2::io { +namespace srb2::io +{ using StreamSize = uint64_t; using StreamOffset = int64_t; @@ -62,6 +64,11 @@ inline constexpr const bool IsStreamV = IsStream::value; template inline constexpr const bool IsInputOutputStreamV = IsInputOutputStream::value; +class UnexpectedEof : public std::logic_error +{ + using std::logic_error::logic_error; +}; + template >* = nullptr> void read_exact(I& stream, tcb::span buffer) { std::size_t total = 0; @@ -441,6 +448,28 @@ public: return head_; } + friend void read_exact(SpanStream& stream, tcb::span buffer) + { + const std::size_t remaining = stream.span_.size() - stream.head_; + const std::size_t buffer_size = buffer.size(); + if (buffer_size > remaining) + { + // The span's size will never change, so the generic impl of read_exact will enter an inifinite loop. We can + // throw out early. + throw UnexpectedEof("read buffer size > remaining bytes in span"); + } + if (buffer_size == 0) + { + return; + } + + auto copy_begin = std::next(stream.span_.begin(), stream.head_); + auto copy_end = std::next(stream.span_.begin(), stream.head_ + buffer_size); + stream.head_ += buffer_size; + + std::copy(copy_begin, copy_end, buffer.begin()); + } + private: tcb::span span_; std::size_t head_ {0}; @@ -513,6 +542,28 @@ public: } std::vector& vector() { return vec_; } + + friend void read_exact(VecStream& stream, tcb::span buffer) + { + const std::size_t remaining = stream.vec_.size() - stream.head_; + const std::size_t buffer_size = buffer.size(); + if (buffer_size > remaining) + { + // VecStream is not thread safe, so the generic impl of read_exact would enter an infinite loop under + // correct usage. We know when we've reached the end and can throw out early. + throw UnexpectedEof("read buffer size > remaining bytes in vector"); + } + if (buffer_size == 0) + { + return; + } + + auto copy_begin = std::next(stream.vec_.begin(), stream.head_); + auto copy_end = std::next(stream.vec_.begin(), stream.head_ + buffer_size); + stream.head_ += buffer_size; + + std::copy(copy_begin, copy_end, buffer.begin()); + } }; class ZlibException : public std::exception {