From 11636af323899a8651266f6407a9aa7a00e665e2 Mon Sep 17 00:00:00 2001 From: TennesseeTrash Date: Thu, 18 Sep 2025 23:23:33 +0200 Subject: [PATCH] Initial decoder implementation, including a basic two-way test The implementation is still far from perfect, it's at beast a proof of concept. There are many edge cases that are definitely still not covered, and many rough edges in the quality of the code. I am also not convinced that exceptions are the best error handling method for this, particularly for publicly exposes interfaces that may be much more susceptible to DoS attacks due to malformed input (and the resulting overhead in handling exceptions). --- LibCBOR/CMakeLists.txt | 8 + LibCBOR/Include/CBOR/Core.hpp | 69 +- LibCBOR/Include/CBOR/Decoder.hpp | 203 +++++- LibCBOR/Include/CBOR/Encoder.hpp | 2 + LibCBOR/Source/Core.cpp | 30 + LibCBOR/Source/Decoder.cpp | 1003 ++++++++++++++++++++++++++++++ LibCBOR/Source/Encoder.cpp | 22 +- LibCBOR/Source/Utils.hpp | 43 -- Tests/Main.cpp | 163 ++++- 9 files changed, 1431 insertions(+), 112 deletions(-) create mode 100644 LibCBOR/Source/Core.cpp diff --git a/LibCBOR/CMakeLists.txt b/LibCBOR/CMakeLists.txt index 83061dc..76fda8d 100644 --- a/LibCBOR/CMakeLists.txt +++ b/LibCBOR/CMakeLists.txt @@ -5,6 +5,13 @@ target_compile_features(LibCBOR cxx_std_23 ) +target_compile_options(LibCBOR PUBLIC + $<$:/W4> + $<$>:-Wall -Wextra -Wpedantic -Wold-style-cast -Wcast-align -Wunused + -Woverloaded-virtual -Wconversion -Wsign-conversion -Wformat=2 + -Wnull-dereference -Wdouble-promotion -Wimplicit-fallthrough> +) + target_include_directories(LibCBOR PUBLIC "Include" @@ -15,6 +22,7 @@ target_include_directories(LibCBOR target_sources(LibCBOR PRIVATE + "Source/Core.cpp" "Source/Decoder.cpp" "Source/Encoder.cpp" ) diff --git a/LibCBOR/Include/CBOR/Core.hpp b/LibCBOR/Include/CBOR/Core.hpp index b720f35..ab39203 100644 --- a/LibCBOR/Include/CBOR/Core.hpp +++ b/LibCBOR/Include/CBOR/Core.hpp @@ -49,49 +49,52 @@ namespace CBOR String = 0b0110'0000, Array = 0b1000'0000, Map = 0b1010'0000, - Tagged = 0b1100'0000, + Tag = 0b1100'0000, Other = 0b1110'0000, TypeMask = 0b1110'0000, }; + std::string_view ToString(MajorType type); enum class ArgumentPosition: std::uint8_t { - Direct00 = 0b0000'0000, - Direct01 = 0b0000'0001, - Direct02 = 0b0000'0010, - Direct03 = 0b0000'0011, - Direct04 = 0b0000'0100, - Direct05 = 0b0000'0101, - Direct06 = 0b0000'0110, - Direct07 = 0b0000'0111, - Direct08 = 0b0000'1000, - Direct09 = 0b0000'1001, - Direct10 = 0b0000'1010, - Direct11 = 0b0000'1011, - Direct12 = 0b0000'1100, - Direct13 = 0b0000'1101, - Direct14 = 0b0000'1110, - Direct15 = 0b0000'1111, - Direct16 = 0b0001'0000, - Direct17 = 0b0001'0001, - Direct18 = 0b0001'0010, - Direct19 = 0b0001'0011, - Direct21 = 0b0001'0101, - Direct20 = 0b0001'0100, - Direct23 = 0b0001'0111, - Direct22 = 0b0001'0110, + Direct00 = 0b0000'0000, + Direct01 = 0b0000'0001, + Direct02 = 0b0000'0010, + Direct03 = 0b0000'0011, + Direct04 = 0b0000'0100, + Direct05 = 0b0000'0101, + Direct06 = 0b0000'0110, + Direct07 = 0b0000'0111, + Direct08 = 0b0000'1000, + Direct09 = 0b0000'1001, + Direct10 = 0b0000'1010, + Direct11 = 0b0000'1011, + Direct12 = 0b0000'1100, + Direct13 = 0b0000'1101, + Direct14 = 0b0000'1110, + Direct15 = 0b0000'1111, + Direct16 = 0b0001'0000, + Direct17 = 0b0001'0001, + Direct18 = 0b0001'0010, + Direct19 = 0b0001'0011, + Direct20 = 0b0001'0100, + Direct21 = 0b0001'0101, + Direct22 = 0b0001'0110, + Direct23 = 0b0001'0111, - Next1B = 0b0001'1000, - Next2B = 0b0001'1001, - Next4B = 0b0001'1010, - Next8B = 0b0001'1011, + Next1B = 0b0001'1000, + Next2B = 0b0001'1001, + Next4B = 0b0001'1010, + Next8B = 0b0001'1011, - Reserved28 = 0b0001'1100, - Reserved29 = 0b0001'1101, - Reserved30 = 0b0001'1110, + Reserved28 = 0b0001'1100, + Reserved29 = 0b0001'1101, + Reserved30 = 0b0001'1110, - Indefinite = 0b0001'1111, + Indefinite = 0b0001'1111, + + PositionMask = 0b0001'1111, }; enum class MinorType: std::uint8_t diff --git a/LibCBOR/Include/CBOR/Decoder.hpp b/LibCBOR/Include/CBOR/Decoder.hpp index 9bfd0cf..dd62dca 100644 --- a/LibCBOR/Include/CBOR/Decoder.hpp +++ b/LibCBOR/Include/CBOR/Decoder.hpp @@ -3,6 +3,8 @@ #include "Core.hpp" +#include + namespace CBOR { class DecodeError: public Error @@ -11,16 +13,207 @@ namespace CBOR using Error::Error; }; + // Forward decl + class Decoder; + + class Item + { + private: + Item(Decoder &decoder); + + public: + bool Bool (); + Special Special(); + + std::int8_t Int8 (); + std::int16_t Int16 (); + std::int32_t Int32 (); + std::int64_t Int64 (); + + std::uint8_t Uint8 (); + std::uint16_t Uint16 (); + std::uint32_t Uint32 (); + std::uint64_t Uint64 (); + + // Note(3011): float16_t is currently not supported + float Float (); + double Double (); + + class Binary Binary (); + class String String (); + class Array Array (); + class Map Map (); + + private: + friend class Decoder; + friend class Array; + friend class KeyValue; + + Decoder *mDecoder; + }; + + class KeyValue + { + private: + KeyValue(Decoder &decoder); + + public: + Item Key(); + Item Value(); + private: + friend class Decoder; + friend class Map; + + enum class State + { + Initial, KeyPulled, Done, + }; + + State mState; + Decoder *mDecoder; + }; + + class Binary + { + private: + Binary(Decoder &decoder); + + public: + std::span Get(); + + void AllowIndefinite(); + bool Done(); + std::span Next(); + private: + friend class Decoder; + + bool mHeaderParsed; + bool mIndefiniteAllowed; + bool mDone; + bool mCheckedDone; + Decoder *mDecoder; + }; + + class String + { + private: + String(Decoder &decoder); + + public: + std::string_view Get(); + + void AllowIndefinite(); + bool Done(); + std::string_view Next(); + private: + friend class Decoder; + + bool mHeaderParsed; + bool mIndefiniteAllowed; + bool mDone; + bool mCheckedDone; + Decoder *mDecoder; + }; + + class Array + { + private: + Array(Decoder &decoder); + + public: + bool Done(); + Item Next(); + + private: + static constexpr std::size_t Indefinite = std::numeric_limits::max(); + + friend class Decoder; + + bool mHeaderParsed; + bool mDone; + bool mCheckedDone; + std::size_t mCurrent; + std::size_t mSize; + Decoder *mDecoder; + }; + + class Map + { + private: + Map(Decoder &decoder); + + public: + bool Done(); + KeyValue Next(); + + private: + static constexpr std::size_t Indefinite = std::numeric_limits::max(); + + friend class Decoder; + + bool mHeaderParsed; + bool mDone; + bool mCheckedDone; + std::size_t mCurrent; + std::size_t mSize; + Decoder *mDecoder; + }; + class Decoder { public: - private: - }; + Decoder(std::span buffer); + + bool Bool (); + Special Special(); + + std::int8_t Int8 (); + std::int16_t Int16 (); + std::int32_t Int32 (); + std::int64_t Int64 (); + + std::uint8_t Uint8 (); + std::uint16_t Uint16 (); + std::uint32_t Uint32 (); + std::uint64_t Uint64 (); + + // Note(3011): float16_t is currently not supported + float Float (); + double Double (); + + Binary Binary (); + String String (); + Array Array (); + Map Map (); - class Validator - { - public: private: + friend class Binary; + friend class String; + friend class Array; + friend class Map; + + enum class State: std::uint8_t + { + Initial, + HeaderExtracted, + }; + + struct Header + { + MajorType Type; + MinorType Minor; + ArgumentPosition ArgPosition; + std::uint64_t Argument; + }; + Header PeekHeader(); + Header ExtractHeader(); + + std::span ExtractBinary(std::size_t size); + std::string_view ExtractString(std::size_t size); + + State mState; + std::size_t mCurrent; + std::span mBuffer; }; } diff --git a/LibCBOR/Include/CBOR/Encoder.hpp b/LibCBOR/Include/CBOR/Encoder.hpp index a2c7c86..bcde5ec 100644 --- a/LibCBOR/Include/CBOR/Encoder.hpp +++ b/LibCBOR/Include/CBOR/Encoder.hpp @@ -56,6 +56,8 @@ namespace CBOR void BeginIndefiniteMap(); void End(); + + std::size_t EncodedSize() const; private: std::size_t mCurrent; std::span mBuffer; diff --git a/LibCBOR/Source/Core.cpp b/LibCBOR/Source/Core.cpp new file mode 100644 index 0000000..c9fd871 --- /dev/null +++ b/LibCBOR/Source/Core.cpp @@ -0,0 +1,30 @@ +#include "Core.hpp" + +#include + +namespace CBOR +{ + std::string_view ToString(MajorType type) + { + switch (type) { + case MajorType::Unsigned: + return "unsigned"; + case MajorType::Negative: + return "negative"; + case MajorType::Binary: + return "binary"; + case MajorType::String: + return "string"; + case MajorType::Array: + return "array"; + case MajorType::Map: + return "map"; + case MajorType::Tag: + return "tag"; + case MajorType::Other: + return "other"; + } + + std::unreachable(); + } +} diff --git a/LibCBOR/Source/Decoder.cpp b/LibCBOR/Source/Decoder.cpp index e69de29..0cdaf0f 100644 --- a/LibCBOR/Source/Decoder.cpp +++ b/LibCBOR/Source/Decoder.cpp @@ -0,0 +1,1003 @@ +#include "Decoder.hpp" + +#include "Core.hpp" +#include "Utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace CBOR +{ + class ImplementationError: public DecodeError + { + public: + ImplementationError(std::string_view message) + : DecodeError(std::string("internal implementation error: ").append(message)) + {} + }; + + class InvalidUsageError: public DecodeError + { + public: + InvalidUsageError(std::string_view message) + : DecodeError(std::string("invalid decoder usage: ").append(message)) + {} + }; + + class TypeMismatchError: public DecodeError + { + public: + TypeMismatchError(std::string_view expected, std::string_view got) + : DecodeError(std::format("type mismatch error: expected {}, got {} instead", + expected, got)) + {} + + TypeMismatchError(std::string_view message) + : DecodeError(std::string("type mismatch error: ").append(message)) + {} + }; + + class IndefiniteLengthError: public DecodeError + { + public: + IndefiniteLengthError() + : DecodeError("length error: expected definite length, " + "got an indefinite-length item instead") + {} + }; + + class MalformedDataError: public DecodeError + { + public: + MalformedDataError(std::string_view message) + : DecodeError(std::string("malformed data error: ").append(message)) + {} + }; + + namespace + { + std::size_t SpaceLeft(std::span buffer, std::size_t offset) + { + if (offset >= buffer.size()) { + return 0; + } + return buffer.size() - offset; + } + + void EnsureEnoughSpace(std::span buffer, std::size_t offset, + std::size_t spaceRequired) + { + if (SpaceLeft(buffer, offset) < spaceRequired) { + using namespace std::string_view_literals; + constexpr auto format = "expected at least {}B more to be left in the buffer"sv; + throw MalformedDataError(std::format(format, spaceRequired)); + } + } + + std::uint8_t Read1B(std::span buffer, std::size_t current) + { + EnsureEnoughSpace(buffer, current, 1); + return buffer[current]; + } + + std::uint16_t Read2B(std::span buffer, std::size_t current) + { + EnsureEnoughSpace(buffer, current, 2); + std::uint16_t result = 0; + result |= std::uint16_t(buffer[current ]) ; + result |= std::uint16_t(buffer[current + 1]) << 8; + return NetworkToHost(result); + } + + std::uint32_t Read4B(std::span buffer, std::size_t current) + { + EnsureEnoughSpace(buffer, current, 4); + std::uint32_t result = 0; + result |= std::uint32_t(buffer[current ]) ; + result |= std::uint32_t(buffer[current + 1]) << 8; + result |= std::uint32_t(buffer[current + 2]) << 16; + result |= std::uint32_t(buffer[current + 3]) << 24; + return NetworkToHost(result); + } + + std::uint64_t Read8B(std::span buffer, std::size_t current) + { + EnsureEnoughSpace(buffer, current, 8); + std::uint64_t result = 0; + result |= std::uint64_t(buffer[current ]) ; + result |= std::uint64_t(buffer[current + 1]) << 8; + result |= std::uint64_t(buffer[current + 2]) << 16; + result |= std::uint64_t(buffer[current + 3]) << 24; + result |= std::uint64_t(buffer[current + 4]) << 32; + result |= std::uint64_t(buffer[current + 5]) << 40; + result |= std::uint64_t(buffer[current + 6]) << 48; + result |= std::uint64_t(buffer[current + 7]) << 56; + return NetworkToHost(result); + } + + std::uint8_t Consume1B(std::span buffer, std::size_t ¤t) + { + EnsureEnoughSpace(buffer, current, 1); + return buffer[current++]; + } + + std::uint16_t Consume2B(std::span buffer, std::size_t ¤t) + { + EnsureEnoughSpace(buffer, current, 2); + std::uint16_t result = 0; + result |= std::uint16_t(buffer[current++]) ; + result |= std::uint16_t(buffer[current++]) << 8; + return NetworkToHost(result); + } + + std::uint32_t Consume4B(std::span buffer, std::size_t ¤t) + { + EnsureEnoughSpace(buffer, current, 4); + std::uint32_t result = 0; + result |= std::uint32_t(buffer[current++]) ; + result |= std::uint32_t(buffer[current++]) << 8; + result |= std::uint32_t(buffer[current++]) << 16; + result |= std::uint32_t(buffer[current++]) << 24; + return NetworkToHost(result); + } + + std::uint64_t Consume8B(std::span buffer, std::size_t ¤t) + { + EnsureEnoughSpace(buffer, current, 8); + std::uint64_t result = 0; + result |= std::uint64_t(buffer[current++]) ; + result |= std::uint64_t(buffer[current++]) << 8; + result |= std::uint64_t(buffer[current++]) << 16; + result |= std::uint64_t(buffer[current++]) << 24; + result |= std::uint64_t(buffer[current++]) << 32; + result |= std::uint64_t(buffer[current++]) << 40; + result |= std::uint64_t(buffer[current++]) << 48; + result |= std::uint64_t(buffer[current++]) << 56; + return NetworkToHost(result); + } + + MajorType GetMajorType(std::uint8_t header) + { + return MajorType(header & std::to_underlying(MajorType::TypeMask)); + } + + MinorType GetMinorType(std::uint8_t header) + { + return MinorType(header & std::to_underlying(MinorType::TypeMask)); + } + + ArgumentPosition GetArgumentPosition(std::uint8_t header) + { + return ArgumentPosition(header & std::to_underlying(ArgumentPosition::PositionMask)); + } + + template + T ExtractUnsigned(std::span buffer, std::size_t ¤t) + { + static constexpr std::uint64_t maxValue = std::numeric_limits::max(); + + std::uint8_t header = Consume1B(buffer, current); + MajorType major = GetMajorType(header); + if (major != MajorType::Unsigned) { + throw TypeMismatchError("unsigned", ToString(major)); + } + + ArgumentPosition position = GetArgumentPosition(header); + if (std::to_underlying(position) <= 23) { + return static_cast(std::to_underlying(position)); + } + switch (position) { + case ArgumentPosition::Next1B: { + // Note(3011): Assume we're always dealing with at least 1B. + // In addition, this conversion will always be an identity or a promotion. + return static_cast(Consume1B(buffer, current)); + } + case ArgumentPosition::Next2B: { + std::uint16_t value = Consume2B(buffer, current); + if (value <= maxValue) { + return static_cast(value); + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next4B: { + std::uint32_t value = Consume4B(buffer, current); + if (value <= maxValue) { + return static_cast(value); + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next8B: { + std::uint64_t value = Consume8B(buffer, current); + if (value <= maxValue) { + return static_cast(value); + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + default: + throw MalformedDataError("argument position is reserved for future use, incorrect, " + "or the parser is out of date"); + } + } + + // Note(3011): In this case it includes zero, even though zero is not technically positive. + template + T SignedPositive(std::uint8_t header, std::span buffer, std::size_t ¤t) + { + static constexpr std::uint64_t maxValue = std::numeric_limits::max(); + + ArgumentPosition position = GetArgumentPosition(header); + if (std::to_underlying(position) <= 23) { + return static_cast(std::to_underlying(position)); + } + switch (position) { + case ArgumentPosition::Next1B: { + std::uint8_t value = Consume1B(buffer, current); + if (value <= maxValue) { + return static_cast(value); + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next2B: { + std::uint16_t value = Consume2B(buffer, current); + if (value <= maxValue) { + return static_cast(value); + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next4B: { + std::uint32_t value = Consume4B(buffer, current); + if (value <= maxValue) { + return static_cast(value); + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next8B: { + std::uint64_t value = Consume8B(buffer, current); + if (value <= maxValue) { + return static_cast(value); + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + default: + throw MalformedDataError("argument position is reserved for future use, incorrect, " + "or the parser is out of date"); + } + } + + template + T SignedNegative(std::uint8_t header, std::span buffer, std::size_t ¤t) + { + static constexpr auto actualMin = std::numeric_limits::min(); + static constexpr std::uint64_t minValue = -std::int64_t(actualMin + 1); + + ArgumentPosition position = GetArgumentPosition(header); + if (std::to_underlying(position) <= 23) { + return -static_cast(std::to_underlying(position)) - 1; + } + switch (position) { + case ArgumentPosition::Next1B: { + std::uint8_t value = Consume1B(buffer, current); + if (value <= minValue) { + return -static_cast(std::to_underlying(position)) - 1; + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next2B: { + std::uint16_t value = Consume2B(buffer, current); + if (value <= minValue) { + return -static_cast(std::to_underlying(position)) - 1; + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next4B: { + std::uint32_t value = Consume4B(buffer, current); + if (value <= minValue) { + return -static_cast(std::to_underlying(position)) - 1; + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + case ArgumentPosition::Next8B: { + std::uint64_t value = Consume8B(buffer, current); + if (value <= minValue) { + return -static_cast(std::to_underlying(position)) - 1; + } + else { + throw MalformedDataError("type matches, but the stored value is out of range"); + } + } + default: + throw MalformedDataError("argument position is reserved for future use, incorrect, " + "or the parser is out of date"); + } + } + + template + T ExtractSigned(std::span buffer, std::size_t ¤t) + { + std::uint8_t header = Consume1B(buffer, current); + MajorType major = GetMajorType(header); + if (major != MajorType::Unsigned && major != MajorType::Negative) { + throw TypeMismatchError("integer", ToString(major)); + } + + if (major == MajorType::Unsigned) { + return SignedPositive(header, buffer, current); + } + else { + return SignedNegative(header, buffer, current); + } + } + } + + Item::Item(Decoder &decoder) + : mDecoder(&decoder) + {} + + bool Item::Bool() + { + return mDecoder->Bool(); + } + + Special Item::Special() + { + return mDecoder->Special(); + } + + std::int8_t Item::Int8() + { + return mDecoder->Int8(); + } + + std::int16_t Item::Int16() + { + return mDecoder->Int16(); + } + + std::int32_t Item::Int32() + { + return mDecoder->Int32(); + } + + std::int64_t Item::Int64() + { + return mDecoder->Int64(); + } + + std::uint8_t Item::Uint8() + { + return mDecoder->Uint8(); + } + + std::uint16_t Item::Uint16() + { + return mDecoder->Uint16(); + } + + std::uint32_t Item::Uint32() + { + return mDecoder->Uint32(); + } + + std::uint64_t Item::Uint64() + { + return mDecoder->Uint64(); + } + + float Item::Float() + { + return mDecoder->Float(); + } + + double Item::Double() + { + return mDecoder->Double(); + } + + Binary Item::Binary() + { + return mDecoder->Binary(); + } + + String Item::String() + { + return mDecoder->String(); + } + + Array Item::Array() + { + return mDecoder->Array(); + } + + Map Item::Map() + { + return mDecoder->Map(); + } + + KeyValue::KeyValue(Decoder &decoder) + : mState(State::Initial), mDecoder(&decoder) + {} + + Item KeyValue::Key() + { + if (mState != State::Initial) { + throw InvalidUsageError("the key has already been pulled, or this pair is done"); + } + + mState = State::KeyPulled; + return Item(*mDecoder); + } + + Item KeyValue::Value() + { + if (mState != State::KeyPulled) { + throw InvalidUsageError("the key must be pulled first, or this item is done"); + } + + mState = State::Done; + return Item(*mDecoder); + } + + Binary::Binary(Decoder &decoder) + : mHeaderParsed(false) + , mIndefiniteAllowed(false) + , mDone(false) + , mCheckedDone(false) + , mDecoder(&decoder) + {} + + std::span Binary::Get() + { + if (mDone) { + throw InvalidUsageError("this item has already been fully parsed"); + } + + if (mIndefiniteAllowed) { + throw InvalidUsageError("indefinite-length items have been explicitly enabled, " + "use the funcions for those instead"); + } + + Decoder::Header header = mDecoder->ExtractHeader(); + mHeaderParsed = true; + if (header.Type != MajorType::Binary) { + throw TypeMismatchError("binary", ToString(header.Type)); + } + + if (header.ArgPosition == ArgumentPosition::Indefinite) { + throw IndefiniteLengthError(); + } + + mDone = true; + return mDecoder->ExtractBinary(header.Argument); + } + + void Binary::AllowIndefinite() + { + mIndefiniteAllowed = true; + } + + bool Binary::Done() + { + if (!mHeaderParsed) { + return false; + } + + if (mDone) { + return true; + } + + Decoder::Header header = mDecoder->PeekHeader(); + if (header.Type == MajorType::Other && header.Minor == MinorType::Break) { + mDone = true; + mDecoder->ExtractHeader(); + } + mCheckedDone = true; + return mDone; + } + + std::span Binary::Next() + { + if (!mHeaderParsed) { + Decoder::Header header = mDecoder->ExtractHeader(); + mHeaderParsed = true; + if (header.Type != MajorType::Binary) { + throw TypeMismatchError("binary", ToString(header.Type)); + } + + if (header.ArgPosition != ArgumentPosition::Indefinite) { + mDone = true; + mCheckedDone = true; + return mDecoder->ExtractBinary(header.Argument); + } + + return std::span(); + } + + if (mDone) { + throw InvalidUsageError("this item has already been fully parsed"); + } + + if (!mCheckedDone) { + throw InvalidUsageError("check whether the indefinite binary is done first"); + } + mCheckedDone = false; + + Decoder::Header header = mDecoder->ExtractHeader(); + if (header.Type != MajorType::Binary || header.ArgPosition == ArgumentPosition::Indefinite){ + throw MalformedDataError("an indefinite length binary may only contain " + "definite length binaries"); + } + + return mDecoder->ExtractBinary(header.Argument); + } + + String::String(Decoder &decoder) + : mHeaderParsed(false) + , mIndefiniteAllowed(false) + , mDone(false) + , mCheckedDone(false) + , mDecoder(&decoder) + {} + + std::string_view String::Get() + { + if (mDone) { + throw InvalidUsageError("this item has already been fully parsed"); + } + + if (mIndefiniteAllowed) { + throw InvalidUsageError("indefinite-length items have been explicitly enabled, " + "use the funcions for those instead"); + } + + Decoder::Header header = mDecoder->ExtractHeader(); + mHeaderParsed = true; + if (header.Type != MajorType::String) { + throw TypeMismatchError("string", ToString(header.Type)); + } + + if (header.ArgPosition == ArgumentPosition::Indefinite) { + throw IndefiniteLengthError(); + } + + mDone = true; + return mDecoder->ExtractString(header.Argument); + } + + void String::AllowIndefinite() + { + mIndefiniteAllowed = true; + } + + bool String::Done() + { + if (!mHeaderParsed) { + return false; + } + + if (mDone) { + return true; + } + + std::uint8_t header = Consume1B(mDecoder->mBuffer, mDecoder->mCurrent); + if (GetMajorType(header) == MajorType::Other && GetMinorType(header) == MinorType::Break) { + mDone = true; + mDecoder->ExtractHeader(); + } + mCheckedDone = true; + return mDone; + } + + std::string_view String::Next() + { + if (!mHeaderParsed) { + Decoder::Header header = mDecoder->ExtractHeader(); + mHeaderParsed = true; + if (header.Type != MajorType::String) { + throw TypeMismatchError("string", ToString(header.Type)); + } + + if (header.ArgPosition != ArgumentPosition::Indefinite) { + mDone = true; + mCheckedDone = true; + return mDecoder->ExtractString(header.Argument); + } + + return std::string_view(); + } + + if (mDone) { + throw InvalidUsageError("this item has already been fully parsed"); + } + + if (!mCheckedDone) { + throw InvalidUsageError("check whether the indefinite string is done first"); + } + mCheckedDone = false; + + Decoder::Header header = mDecoder->ExtractHeader(); + if (header.Type != MajorType::Binary || header.ArgPosition == ArgumentPosition::Indefinite){ + throw MalformedDataError("an indefinite length string may only contain " + "definite length strings"); + } + + return mDecoder->ExtractString(header.Argument); + } + + Array::Array(Decoder &decoder) + : mHeaderParsed(false) + , mDone(false) + , mCheckedDone(false) + , mCurrent(0) + , mSize(0) + , mDecoder(&decoder) + {} + + bool Array::Done() + { + if (!mHeaderParsed) { + Decoder::Header header = mDecoder->ExtractHeader(); + mHeaderParsed = true; + if (header.Type != MajorType::Array) { + throw TypeMismatchError("array", ToString(header.Type)); + } + + bool indefinite = header.ArgPosition == ArgumentPosition::Indefinite; + mSize = indefinite ? Indefinite : header.Argument; + + if (!mSize) { + mDone = true; + mCheckedDone = true; + } + } + + if (mDone) { + return true; + } + + if (mSize != Indefinite) { + if (mCurrent >= mSize) { + mDone = true; + } + mCheckedDone = true; + return mDone; + } + + Decoder::Header header = mDecoder->PeekHeader(); + if (header.Type == MajorType::Other && header.Minor == MinorType::Break) { + mDone = true; + mDecoder->ExtractHeader(); + } + mCheckedDone = true; + return mDone; + } + + Item Array::Next() + { + if (mDone) { + throw InvalidUsageError("this item has already been fully parsed"); + } + + if (!mCheckedDone) { + throw InvalidUsageError("check whether the indefinite string is done first"); + } + mCheckedDone = false; + + ++mCurrent; + return Item(*mDecoder); + } + + Map::Map(Decoder &decoder) + : mHeaderParsed(false) + , mDone(false) + , mCheckedDone(false) + , mCurrent(0) + , mSize(0) + , mDecoder(&decoder) + {} + + bool Map::Done() + { + if (!mHeaderParsed) { + Decoder::Header header = mDecoder->ExtractHeader(); + mHeaderParsed = true; + if (header.Type != MajorType::Map) { + throw TypeMismatchError("map", ToString(header.Type)); + } + + bool indefinite = header.ArgPosition == ArgumentPosition::Indefinite; + mSize = indefinite ? Indefinite : header.Argument; + + if (!mSize) { + mDone = true; + mCheckedDone = true; + } + } + + if (mDone) { + return true; + } + + if (mSize != Indefinite) { + if (mCurrent >= mSize) { + mDone = true; + } + mCheckedDone = true; + return mDone; + } + + Decoder::Header header = mDecoder->PeekHeader(); + if (header.Type == MajorType::Other && header.Minor == MinorType::Break) { + mDone = true; + mDecoder->ExtractHeader(); + } + mCheckedDone = true; + return mDone; + } + + KeyValue Map::Next() + { + if (mDone) { + throw InvalidUsageError("this item has already been fully parsed"); + } + + if (!mCheckedDone) { + throw InvalidUsageError("check whether the indefinite string is done first"); + } + mCheckedDone = false; + + ++mCurrent; + return KeyValue(*mDecoder); + } + + Decoder::Decoder(std::span buffer) + : mState(State::Initial), mCurrent(0), mBuffer(buffer) + {} + + bool Decoder::Bool() + { + std::uint8_t header = Consume1B(mBuffer, mCurrent); + if (GetMajorType(header) != MajorType::Other) { + throw TypeMismatchError("bool", ToString(GetMajorType(header))); + } + + if (GetMinorType(header) == MinorType::True) { + return true; + } + if (GetMinorType(header) == MinorType::False) { + return false; + } + + throw TypeMismatchError("expected a simple true/false, got some other value instead"); + } + + Special Decoder::Special() + { + std::uint8_t header = Consume1B(mBuffer, mCurrent); + if (GetMajorType(header) != MajorType::Other) { + throw TypeMismatchError("bool", ToString(GetMajorType(header))); + } + + if (GetMinorType(header) == MinorType::Null) { + return Special::Null; + } + if (GetMinorType(header) == MinorType::Undefined) { + return Special::Undefined; + } + + throw TypeMismatchError("expected a null/undefined, got something else instead"); + } + + std::int8_t Decoder::Int8() + { + return ExtractSigned(mBuffer, mCurrent); + } + + std::int16_t Decoder::Int16() + { + return ExtractSigned(mBuffer, mCurrent); + } + + std::int32_t Decoder::Int32() + { + return ExtractSigned(mBuffer, mCurrent); + } + + std::int64_t Decoder::Int64() + { + return ExtractSigned(mBuffer, mCurrent); + } + + std::uint8_t Decoder::Uint8() + { + return ExtractUnsigned(mBuffer, mCurrent); + } + + std::uint16_t Decoder::Uint16() + { + return ExtractUnsigned(mBuffer, mCurrent); + } + + std::uint32_t Decoder::Uint32() + { + return ExtractUnsigned(mBuffer, mCurrent); + } + + std::uint64_t Decoder::Uint64() + { + return ExtractUnsigned(mBuffer, mCurrent); + } + + float Decoder::Float() + { + std::uint8_t header = Consume1B(mBuffer, mCurrent); + MajorType major = GetMajorType(header); + MinorType minor = GetMinorType(header); + if (major == MajorType::Other && minor == MinorType::Half) { + throw MalformedDataError("half precision floating point numbers are not supported"); + } + if (major == MajorType::Other && minor == MinorType::Float) { + std::uint32_t raw = Consume4B(mBuffer, mCurrent); + return std::bit_cast(raw); + } + if (major == MajorType::Other && minor == MinorType::Double) { + throw MalformedDataError("cannot convert a double to a float"); + } + throw TypeMismatchError("float", ToString(major)); + } + + double Decoder::Double() + { + std::uint8_t header = Consume1B(mBuffer, mCurrent); + MajorType major = GetMajorType(header); + MinorType minor = GetMinorType(header); + if (major == MajorType::Other && minor == MinorType::Half) { + throw MalformedDataError("half precision floating point numbers are not supported"); + } + if (major == MajorType::Other && minor == MinorType::Float) { + throw MalformedDataError("cannot convert a float to a double"); + } + if (major == MajorType::Other && minor == MinorType::Double) { + std::uint64_t raw = Consume8B(mBuffer, mCurrent); + return std::bit_cast(raw); + } + throw TypeMismatchError("double", ToString(major)); + } + + Binary Decoder::Binary() + { + return { *this }; + } + + String Decoder::String() + { + return { *this }; + } + + Array Decoder::Array() + { + return { *this }; + } + + Map Decoder::Map() + { + return { *this }; + } + + Decoder::Header Decoder::PeekHeader() + { + EnsureEnoughSpace(mBuffer, mCurrent, 1); + + std::uint8_t rawHeader = mBuffer[mCurrent]; + Header header { + .Type = MajorType(rawHeader & std::to_underlying(MajorType::TypeMask)), + .Minor = MinorType(rawHeader & std::to_underlying(MinorType::TypeMask)), + .ArgPosition = ArgumentPosition(rawHeader & std::to_underlying(ArgumentPosition::PositionMask)), + .Argument = 0, + }; + + if (std::to_underlying(header.ArgPosition) <= 23) { + header.Argument = std::to_underlying(header.ArgPosition); + } + else if (header.ArgPosition == ArgumentPosition::Next1B) { + header.Argument = Read1B(mBuffer, mCurrent + 1); + } + else if (header.ArgPosition == ArgumentPosition::Next2B) { + header.Argument = Read2B(mBuffer, mCurrent + 1); + } + else if (header.ArgPosition == ArgumentPosition::Next4B) { + header.Argument = Read4B(mBuffer, mCurrent + 1); + } + else if (header.ArgPosition == ArgumentPosition::Next8B) { + header.Argument = Read8B(mBuffer, mCurrent + 1); + } + else if (header.ArgPosition == ArgumentPosition::Indefinite) { + // Nothing more needs to happen + } + else { + // bruh, this happends even with the special ending values ... + //throw MalformedDataError("value reserved for future use in the input buffer " + // "(this version may be too old to parse this data)"); + } + return header; + } + + Decoder::Header Decoder::ExtractHeader() + { + EnsureEnoughSpace(mBuffer, mCurrent, 1); + + std::uint8_t rawHeader = mBuffer[mCurrent++]; + Header header { + .Type = MajorType(rawHeader & std::to_underlying(MajorType::TypeMask)), + .Minor = MinorType(rawHeader & std::to_underlying(MinorType::TypeMask)), + .ArgPosition = ArgumentPosition(rawHeader & std::to_underlying(ArgumentPosition::PositionMask)), + .Argument = 0, + }; + + if (std::to_underlying(header.ArgPosition) <= 23) { + header.Argument = std::to_underlying(header.ArgPosition); + } + else if (header.ArgPosition == ArgumentPosition::Next1B) { + header.Argument = Consume1B(mBuffer, mCurrent); + } + else if (header.ArgPosition == ArgumentPosition::Next2B) { + header.Argument = Consume2B(mBuffer, mCurrent); + } + else if (header.ArgPosition == ArgumentPosition::Next4B) { + header.Argument = Consume4B(mBuffer, mCurrent); + } + else if (header.ArgPosition == ArgumentPosition::Next8B) { + header.Argument = Consume8B(mBuffer, mCurrent); + } + else if (header.ArgPosition == ArgumentPosition::Indefinite) { + // Nothing more needs to happen + } + else { + throw MalformedDataError("value reserved for future use in the input buffer " + "(this version may be too old to parse this data)"); + } + return header; + } + + std::span Decoder::ExtractBinary(std::size_t size) + { + EnsureEnoughSpace(mBuffer, mCurrent, size); + std::span result(mBuffer.data() + mCurrent, size); + mCurrent += size; + return result; + } + + std::string_view Decoder::ExtractString(std::size_t size) + { + EnsureEnoughSpace(mBuffer, mCurrent, size); + std::string_view result(reinterpret_cast(mBuffer.data() + mCurrent), size); + mCurrent += size; + return result; + } +} diff --git a/LibCBOR/Source/Encoder.cpp b/LibCBOR/Source/Encoder.cpp index b4cddf9..0cf23a8 100644 --- a/LibCBOR/Source/Encoder.cpp +++ b/LibCBOR/Source/Encoder.cpp @@ -1,4 +1,5 @@ #include "Encoder.hpp" + #include "Core.hpp" #include "Utils.hpp" @@ -133,26 +134,26 @@ namespace CBOR if (value >= 0) { if (value <= 23) { EnsureEnoughSpace(mBuffer, mCurrent, 1); - mBuffer[mCurrent++] = std::to_underlying(MajorType::Unsigned) | value; + mBuffer[mCurrent++] = static_cast(std::to_underlying(MajorType::Unsigned) | value); } else { EnsureEnoughSpace(mBuffer, mCurrent, 2); mBuffer[mCurrent++] = std::to_underlying(MajorType::Unsigned) | std::to_underlying(ArgumentPosition::Next1B); - mBuffer[mCurrent++] = value; + mBuffer[mCurrent++] = static_cast(value); } } else { - std::int8_t actual = std::abs(value + 1); + std::int8_t actual = -(value + 1); if (actual <= 23) { EnsureEnoughSpace(mBuffer, mCurrent, 1); - mBuffer[mCurrent++] = std::to_underlying(MajorType::Negative) | actual; + mBuffer[mCurrent++] = static_cast(std::to_underlying(MajorType::Negative) | actual); } else { EnsureEnoughSpace(mBuffer, mCurrent, 2); mBuffer[mCurrent++] = std::to_underlying(MajorType::Negative) | std::to_underlying(ArgumentPosition::Next1B); - mBuffer[mCurrent++] = actual; + mBuffer[mCurrent++] = static_cast(actual); } } } @@ -166,7 +167,7 @@ namespace CBOR Encode(static_cast(value)); } else if (value >= -256 && value <= -1) { - std::uint8_t actual = static_cast(std::abs(value + 1)); + std::uint8_t actual = static_cast(-(value + 1)); EnsureEnoughSpace(mBuffer, mCurrent, 2); mBuffer[mCurrent++] = std::to_underlying(MajorType::Negative) | std::to_underlying(ArgumentPosition::Next1B); @@ -180,10 +181,10 @@ namespace CBOR } else { EnsureEnoughSpace(mBuffer, mCurrent, 3); - std::int16_t actual = std::abs(value + 1); + std::int16_t actual = -(value + 1); mBuffer[mCurrent++] = std::to_underlying(MajorType::Negative) | std::to_underlying(ArgumentPosition::Next2B); - Write(mBuffer, mCurrent, static_cast(value)); + Write(mBuffer, mCurrent, static_cast(actual)); } } @@ -475,4 +476,9 @@ namespace CBOR mBuffer[mCurrent++] = std::to_underlying(MajorType::Other) | std::to_underlying(MinorType::Break); } + + std::size_t BasicEncoder::EncodedSize() const + { + return mBuffer.size() - (mBuffer.size() - mCurrent); + } } diff --git a/LibCBOR/Source/Utils.hpp b/LibCBOR/Source/Utils.hpp index e811572..91b7dfb 100644 --- a/LibCBOR/Source/Utils.hpp +++ b/LibCBOR/Source/Utils.hpp @@ -3,52 +3,9 @@ #include #include -#include namespace CBOR { - [[nodiscard, deprecated("Prefer HostToNetwork or NetworkToHost")]] constexpr - std::uint8_t FlipBytes(std::uint8_t value) - { - return value; - } - - [[nodiscard, deprecated("Prefer HostToNetwork or NetworkToHost")]] constexpr - std::uint16_t FlipBytes(std::uint16_t value) - { - static constexpr std::uint16_t upperMask = 0b1111'1111'0000'0000; - static constexpr std::uint16_t lowerMask = 0b0000'0000'1111'1111; - return (value & upperMask) >> 8 | (value & lowerMask) << 8; - } - - [[nodiscard, deprecated("Prefer HostToNetwork or NetworkToHost")]] constexpr - std::uint32_t FlipBytes(std::uint32_t value) - { - static constexpr std::uint32_t firstMask = 0xFF'00'00'00; - static constexpr std::uint32_t secondMask = 0x00'FF'00'00; - static constexpr std::uint32_t thirdMask = 0x00'00'FF'00; - static constexpr std::uint32_t fourthMask = 0x00'00'00'FF; - return (value & firstMask) >> 24 | (value & secondMask) >> 8 | - (value & thirdMask) << 8 | (value & fourthMask) << 24; - } - - [[nodiscard, deprecated("Prefer HostToNetwork or NetworkToHost")]] constexpr - std::uint64_t FlipBytes(std::uint64_t value) - { - static constexpr std::uint64_t firstMask = 0xFF'00'00'00'00'00'00'00; - static constexpr std::uint64_t secondMask = 0x00'FF'00'00'00'00'00'00; - static constexpr std::uint64_t thirdMask = 0x00'00'FF'00'00'00'00'00; - static constexpr std::uint64_t fourthMask = 0x00'00'00'FF'00'00'00'00; - static constexpr std::uint64_t fifthMask = 0x00'00'00'00'FF'00'00'00; - static constexpr std::uint64_t sixthMask = 0x00'00'00'00'00'FF'00'00; - static constexpr std::uint64_t seventhMask = 0x00'00'00'00'00'00'FF'00; - static constexpr std::uint64_t eighthMask = 0x00'00'00'00'00'00'00'FF; - return (value & firstMask ) >> 56 | (value & secondMask) >> 40 | - (value & thirdMask ) >> 24 | (value & fourthMask) >> 8 | - (value & fifthMask ) << 8 | (value & sixthMask ) >> 24 | - (value & seventhMask) << 40 | (value & eighthMask) << 56; - } - template [[nodiscard]] constexpr T HostToNetwork(T value) diff --git a/Tests/Main.cpp b/Tests/Main.cpp index f21218f..4094ac6 100644 --- a/Tests/Main.cpp +++ b/Tests/Main.cpp @@ -1,39 +1,156 @@ -#include "CBOR/Core.hpp" +#include "CBOR/Decoder.hpp" #include "CBOR/Encoder.hpp" #include #include #include +#include +#include #include +struct SomeStruct +{ + std::string name; + double speed; + float fov; + std::int8_t thing; + std::int64_t slots; + std::uint32_t times; + std::vector tools; +}; + +std::size_t Encode(const SomeStruct &value, std::span buffer) +{ + CBOR::BasicEncoder enc(buffer); + enc.BeginIndefiniteMap(); + + enc.Encode("name"); + enc.Encode(value.name); + + enc.Encode("speed"); + enc.Encode(value.speed); + + enc.Encode("fov"); + enc.Encode(value.fov); + + enc.Encode("thing"); + enc.Encode(value.thing); + + enc.Encode("slots"); + enc.Encode(value.slots); + + enc.Encode("times"); + enc.Encode(value.times); + + enc.Encode("tools"); + enc.BeginIndefiniteArray(); + for (std::string_view tool: value.tools) { + enc.Encode(tool); + } + enc.End(); + + enc.End(); + + return enc.EncodedSize(); +} + +SomeStruct Decode(std::span buffer) +{ + SomeStruct result; + + CBOR::Decoder dec(buffer); + CBOR::Map object = dec.Map(); + while (!object.Done()) { + CBOR::KeyValue kv = object.Next(); + std::string_view key = kv.Key().String().Get(); + CBOR::Item value = kv.Value(); + if (key == "name") { + result.name = value.String().Get(); + } + else if (key == "speed") { + result.speed = value.Double(); + } + else if (key == "fov") { + result.fov = value.Float(); + } + else if (key == "thing") { + result.thing = value.Int8(); + } + else if (key == "slots") { + result.slots = value.Int64(); + } + else if (key == "times") { + result.times = value.Uint32(); + } + else if (key == "tools") { + CBOR::Array tools = value.Array(); + while(!tools.Done()) { + result.tools.push_back(std::string(tools.Next().String().Get())); + } + } + } + + return result; +} + +void Compare(const SomeStruct &s1, const SomeStruct &s2) +{ + if (s1.name != s2.name) { + throw std::runtime_error("test error: names are not the same"); + } + if (s1.speed != s2.speed) { + throw std::runtime_error("test error: speed is not the same"); + } + if (s1.fov != s2.fov) { + throw std::runtime_error("test error: fovs are not the same"); + } + if (s1.thing != s2.thing) { + throw std::runtime_error("test error: things are not the same"); + } + if (s1.slots != s2.slots) { + throw std::runtime_error("test error: slots are not the same"); + } + if (s1.times != s2.times) { + throw std::runtime_error("test error: times are not the same"); + } + for (const auto &[t1, t2]: std::ranges::views::zip(s1.tools, s2.tools)) { + if (t1 != t2) { + throw std::runtime_error("test error: some tools are not the same"); + } + } +} + int main() { using namespace std::string_view_literals; - std::array buffer = {0}; + std::array buffer = {0}; - std::vector binData(5, 'g'); + SomeStruct expected { + .name = "Player1", + .speed = 5.0, + .fov = 110.0f, + .thing = -15, + .slots = 40'000'000, + .times = 1234567, + .tools = { + "pickaxe", + "sword", + "axe", + "magical arrow", + "iron ore", + }, + }; - CBOR::BasicEncoder enc(buffer); + try { + std::size_t encodedSize = Encode(expected, buffer); + std::println("Encoded size: {}", encodedSize); - enc.BeginMap(7); - enc.Encode("Hello "); - enc.Encode("World! "); - enc.Encode("Behold by new power! "); - enc.Encode("It truly is a sight to see ..."sv); - enc.Encode(std::int64_t(1212121212121212)); - enc.Encode(binData); - enc.Encode("random double"); - enc.Encode(420.69); - enc.Encode("random float"); - enc.Encode(420.69f); - enc.Encode("undefined?"); - enc.Encode(CBOR::Special::Undefined); - enc.Encode("null?"); - enc.Encode(CBOR::Special::Null); - enc.End(); + SomeStruct result = Decode(std::span(buffer.data(), encodedSize)); - for (const auto &byte: buffer) { - std::print("{:02x} ", byte); + Compare(expected, result); + std::println("The test has been completed successfully."); + } + catch (const std::exception &e) { + std::println("Error: {}", e.what()); } - std::println(""); }