libsnapshot: Refactor cow_reader decompression.

Bug: 162274240
Test: cow_api_test
Change-Id: I12c177f3ebb7bb0550669bd5edbdbbde6f572cfd
This commit is contained in:
David Anderson 2020-08-26 17:20:27 -07:00
parent c7c7252289
commit 511c4bc601
5 changed files with 320 additions and 92 deletions

View File

@ -134,6 +134,7 @@ cc_defaults {
],
export_include_dirs: ["include"],
srcs: [
"cow_decompress.cpp",
"cow_reader.cpp",
"cow_writer.cpp",
],

View File

@ -0,0 +1,211 @@
//
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "cow_decompress.h"
#include <utility>
#include <android-base/logging.h>
#include <zlib.h>
namespace android {
namespace snapshot {
class NoDecompressor final : public IDecompressor {
public:
bool Decompress(size_t) override;
};
bool NoDecompressor::Decompress(size_t) {
size_t stream_remaining = stream_->Size();
while (stream_remaining) {
size_t buffer_size = stream_remaining;
uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
if (!buffer) {
LOG(ERROR) << "Could not acquire buffer from sink";
return false;
}
// Read until we can fill the buffer.
uint8_t* buffer_pos = buffer;
size_t bytes_to_read = std::min(buffer_size, stream_remaining);
while (bytes_to_read) {
size_t read;
if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
return false;
}
if (!read) {
LOG(ERROR) << "Stream ended prematurely";
return false;
}
if (!sink_->ReturnData(buffer_pos, read)) {
LOG(ERROR) << "Could not return buffer to sink";
return false;
}
buffer_pos += read;
bytes_to_read -= read;
stream_remaining -= read;
}
}
return true;
}
std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
return std::unique_ptr<IDecompressor>(new NoDecompressor());
}
// Read chunks of the COW and incrementally stream them to the decoder.
//
// Decompress() drives the loop; subclasses supply the codec through the three
// pure-virtual hooks below.
class StreamDecompressor : public IDecompressor {
  public:
    bool Decompress(size_t output_bytes) override;

    // Called once at the start of every Decompress() call.
    virtual bool Init() = 0;
    // Called with each chunk of input read from the stream.
    virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
    // Queried after each chunk and once more when the input is exhausted.
    virtual bool Done() = 0;

  protected:
    // Requests a new output buffer from the sink, updating output_buffer_ and
    // output_buffer_remaining_. Returns false if the sink provides none.
    bool GetFreshBuffer();

    // Both are assigned by Decompress(); they are zero-initialized so that
    // no member is ever read while indeterminate.
    size_t output_bytes_ = 0;
    size_t stream_remaining_ = 0;
    uint8_t* output_buffer_ = nullptr;
    size_t output_buffer_remaining_ = 0;
};
// Size of the input chunks read from the stream, and the upper bound on the
// output buffer size requested from the sink.
static constexpr size_t kChunkSize = 4096;

// Feeds the whole stream to the codec in chunks of at most kChunkSize bytes.
// Returns false if Init() fails, if reading or decoding fails, if the codec
// reports Done() while input remains, or if it is not Done() once the input
// is exhausted.
bool StreamDecompressor::Decompress(size_t output_bytes) {
    if (!Init()) return false;

    stream_remaining_ = stream_->Size();
    output_bytes_ = output_bytes;

    uint8_t input[kChunkSize];
    while (stream_remaining_ != 0) {
        const size_t wanted = std::min(stream_remaining_, sizeof(input));
        size_t got = wanted;
        if (!stream_->Read(input, wanted, &got)) return false;
        if (got == 0) {
            LOG(ERROR) << "Stream ended prematurely";
            return false;
        }
        if (!DecompressInput(input, got)) return false;

        stream_remaining_ -= got;
        // The codec must not declare itself finished while input is left.
        if (stream_remaining_ != 0 && Done()) {
            LOG(ERROR) << "Decompressor terminated early";
            return false;
        }
    }

    if (Done()) return true;

    LOG(ERROR) << "Decompressor expected more bytes";
    return false;
}
// Asks the sink for a new output buffer of at most kChunkSize bytes (less if
// the expected output is smaller). The size the sink actually provides is
// stored in output_buffer_remaining_.
bool StreamDecompressor::GetFreshBuffer() {
    const size_t wanted = std::min(output_bytes_, kChunkSize);
    void* const raw = sink_->GetBuffer(wanted, &output_buffer_remaining_);
    output_buffer_ = reinterpret_cast<uint8_t*>(raw);
    if (output_buffer_ != nullptr) return true;

    LOG(ERROR) << "Could not acquire buffer from sink";
    return false;
}
// StreamDecompressor backed by zlib's inflate.
class GzDecompressor final : public StreamDecompressor {
  public:
    ~GzDecompressor();

    bool Init() override;
    bool DecompressInput(const uint8_t* data, size_t length) override;
    // True once inflate has reported the end of the compressed stream.
    bool Done() override { return ended_; }

  private:
    z_stream z_ = {};
    bool ended_ = false;
};

// Prepares the zlib stream for inflation; logs and returns false on failure.
bool GzDecompressor::Init() {
    const int rv = inflateInit(&z_);
    if (rv == Z_OK) return true;

    LOG(ERROR) << "inflateInit returned error code " << rv;
    return false;
}

// Releases zlib's inflate state; the return value is not checked.
GzDecompressor::~GzDecompressor() {
    inflateEnd(&z_);
}
bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
z_.avail_in = length;
while (z_.avail_in) {
// If no more output buffer, grab a new buffer.
if (z_.avail_out == 0) {
if (!GetFreshBuffer()) {
return false;
}
z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
z_.avail_out = output_buffer_remaining_;
}
// Remember the position of the output buffer so we can call ReturnData.
auto avail_out = z_.avail_out;
// Decompress.
int rv = inflate(&z_, Z_NO_FLUSH);
if (rv != Z_OK && rv != Z_STREAM_END) {
LOG(ERROR) << "inflate returned error code " << rv;
return false;
}
size_t returned = avail_out - z_.avail_out;
if (!sink_->ReturnData(output_buffer_, returned)) {
LOG(ERROR) << "Could not return buffer to sink";
return false;
}
output_buffer_ += returned;
output_buffer_remaining_ -= returned;
if (rv == Z_STREAM_END) {
if (z_.avail_in) {
LOG(ERROR) << "Gz stream ended prematurely";
return false;
}
ended_ = true;
return true;
}
}
return true;
}
std::unique_ptr<IDecompressor> IDecompressor::Gz() {
return std::unique_ptr<IDecompressor>(new GzDecompressor());
}
} // namespace snapshot
} // namespace android

View File

@ -0,0 +1,56 @@
//
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <libsnapshot/cow_reader.h>
namespace android {
namespace snapshot {
// Abstract source of bytes with a known total size.
class IByteStream {
  public:
    virtual ~IByteStream() = default;

    // Read up to |length| bytes into |buffer|, storing the number of bytes
    // actually read in |read|. At the end of the stream, |read| is set to 0.
    // The boolean result reports whether the read succeeded.
    virtual bool Read(void* buffer, size_t length, size_t* read) = 0;

    // Size of the stream.
    virtual size_t Size() const = 0;
};
// Interface for decoders that pull input from an IByteStream and push their
// output to an IByteSink.
class IDecompressor {
  public:
    virtual ~IDecompressor() = default;

    // Factory methods for decompression methods.
    static std::unique_ptr<IDecompressor> Uncompressed();
    static std::unique_ptr<IDecompressor> Gz();

    // |output_bytes| is the expected total number of bytes to sink.
    virtual bool Decompress(size_t output_bytes) = 0;

    // The stream and sink are stored as raw pointers; this class does not
    // delete them.
    void set_stream(IByteStream* stream) { stream_ = stream; }
    void set_sink(IByteSink* sink) { sink_ = sink; }

  protected:
    IByteStream* stream_ = nullptr;
    IByteSink* sink_ = nullptr;
};
} // namespace snapshot
} // namespace android

View File

@ -17,10 +17,13 @@
#include <sys/types.h>
#include <unistd.h>
#include <limits>
#include <android-base/file.h>
#include <android-base/logging.h>
#include <libsnapshot/cow_reader.h>
#include <zlib.h>
#include "cow_decompress.h"
namespace android {
namespace snapshot {
@ -171,7 +174,7 @@ std::unique_ptr<ICowOpIter> CowReader::GetOpIter() {
return std::make_unique<CowOpIter>(std::move(ops_buffer), header_.ops_size);
}
bool CowReader::GetRawBytes(uint64_t offset, void* buffer, size_t len) {
bool CowReader::GetRawBytes(uint64_t offset, void* buffer, size_t len, size_t* read) {
// Validate the offset, taking care to acknowledge possible overflow of offset+len.
if (offset < sizeof(header_) || offset >= header_.ops_offset || len >= fd_size_ ||
offset + len > header_.ops_offset) {
@ -182,104 +185,63 @@ bool CowReader::GetRawBytes(uint64_t offset, void* buffer, size_t len) {
PLOG(ERROR) << "lseek to read raw bytes failed";
return false;
}
if (!android::base::ReadFully(fd_, buffer, len)) {
PLOG(ERROR) << "read raw bytes failed";
ssize_t rv = TEMP_FAILURE_RETRY(::read(fd_.get(), buffer, len));
if (rv < 0) {
PLOG(ERROR) << "read failed";
return false;
}
*read = rv;
return true;
}
// IByteStream over |data_length| raw bytes of a COW file starting at |offset|,
// read through CowReader::GetRawBytes.
class CowDataStream final : public IByteStream {
  public:
    CowDataStream(CowReader* reader, uint64_t offset, size_t data_length)
        : reader_(reader), offset_(offset), data_length_(data_length), remaining_(data_length) {}

    // Reads at most |length| bytes, never going past the end of the data
    // range. Once the range is exhausted, succeeds with |*read| set to 0.
    bool Read(void* buffer, size_t length, size_t* read) override {
        const size_t wanted = std::min(length, remaining_);
        if (wanted == 0) {
            *read = 0;
            return true;
        }

        if (!reader_->GetRawBytes(offset_, buffer, wanted, read)) {
            return false;
        }

        // Advance by what was actually read, which may be less than asked.
        const size_t got = *read;
        offset_ += got;
        remaining_ -= got;
        return true;
    }

    // Total length of the data range, regardless of how much has been read.
    size_t Size() const override { return data_length_; }

  private:
    CowReader* reader_;
    uint64_t offset_;
    size_t data_length_;
    size_t remaining_;
};
bool CowReader::ReadData(const CowOperation& op, IByteSink* sink) {
uint64_t offset = op.source;
std::unique_ptr<IDecompressor> decompressor;
switch (op.compression) {
case kCowCompressNone: {
size_t remaining = op.data_length;
while (remaining) {
size_t amount = remaining;
void* buffer = sink->GetBuffer(amount, &amount);
if (!buffer) {
LOG(ERROR) << "Could not acquire buffer from sink";
return false;
}
if (!GetRawBytes(offset, buffer, amount)) {
return false;
}
if (!sink->ReturnData(buffer, amount)) {
LOG(ERROR) << "Could not return buffer to sink";
return false;
}
remaining -= amount;
offset += amount;
}
return true;
}
case kCowCompressGz: {
auto input = std::make_unique<Bytef[]>(op.data_length);
if (!GetRawBytes(offset, input.get(), op.data_length)) {
return false;
}
z_stream z = {};
z.next_in = input.get();
z.avail_in = op.data_length;
if (int rv = inflateInit(&z); rv != Z_OK) {
LOG(ERROR) << "inflateInit returned error code " << rv;
return false;
}
while (z.total_out < header_.block_size) {
// If no more output buffer, grab a new buffer.
if (z.avail_out == 0) {
size_t amount = header_.block_size - z.total_out;
z.next_out = reinterpret_cast<Bytef*>(sink->GetBuffer(amount, &amount));
if (!z.next_out) {
LOG(ERROR) << "Could not acquire buffer from sink";
return false;
}
z.avail_out = amount;
}
// Remember the position of the output buffer so we can call ReturnData.
auto buffer = z.next_out;
auto avail_out = z.avail_out;
// Decompress.
int rv = inflate(&z, Z_NO_FLUSH);
if (rv != Z_OK && rv != Z_STREAM_END) {
LOG(ERROR) << "inflate returned error code " << rv;
return false;
}
// Return the section of the buffer that was updated.
if (z.avail_out < avail_out && !sink->ReturnData(buffer, avail_out - z.avail_out)) {
LOG(ERROR) << "Could not return buffer to sink";
return false;
}
if (rv == Z_STREAM_END) {
// Error if the stream has ended, but we didn't fill the entire block.
if (z.total_out != header_.block_size) {
LOG(ERROR) << "Reached gz stream end but did not read a full block of data";
return false;
}
break;
}
CHECK(rv == Z_OK);
// Error if the stream is expecting more data, but we don't have any to read.
if (z.avail_in == 0) {
LOG(ERROR) << "Gz stream ended prematurely";
return false;
}
}
return true;
}
case kCowCompressNone:
decompressor = IDecompressor::Uncompressed();
break;
case kCowCompressGz:
decompressor = IDecompressor::Gz();
break;
default:
LOG(ERROR) << "Unknown compression type: " << op.compression;
return false;
}
CowDataStream stream(this, op.source, op.data_length);
decompressor->set_stream(&stream);
decompressor->set_sink(sink);
return decompressor->Decompress(header_.block_size);
}
} // namespace snapshot

View File

@ -61,9 +61,6 @@ class ICowReader {
// Return an iterator for retrieving CowOperation entries.
virtual std::unique_ptr<ICowOpIter> GetOpIter() = 0;
// Get raw bytes from the data section.
virtual bool GetRawBytes(uint64_t offset, void* buffer, size_t len) = 0;
// Get decoded bytes from the data section, handling any decompression.
// All retrieved data is passed to the sink.
virtual bool ReadData(const CowOperation& op, IByteSink* sink) = 0;
@ -97,9 +94,10 @@ class CowReader : public ICowReader {
// CowOperation objects. Get() returns a unique CowOperation object
// whose lifetime depends on the CowOpIter object
std::unique_ptr<ICowOpIter> GetOpIter() override;
bool GetRawBytes(uint64_t offset, void* buffer, size_t len) override;
bool ReadData(const CowOperation& op, IByteSink* sink) override;
bool GetRawBytes(uint64_t offset, void* buffer, size_t len, size_t* read);
private:
android::base::unique_fd owned_fd_;
android::base::borrowed_fd fd_;