#ifndef CMCPP_STRING_HPP #define CMCPP_STRING_HPP #include "context.hpp" #include "integer.hpp" #include "util.hpp" namespace cmcpp { namespace string { const uint32_t MAX_STRING_BYTE_LENGTH = (1U << 28) - 1; inline std::pair store_string_copy(LiftLowerContext &cx, const void *src, uint32_t src_code_units, uint32_t dst_code_unit_size, uint32_t dst_alignment, Encoding dst_encoding) { uint32_t dst_byte_length = dst_code_unit_size * src_code_units; trap_if(cx, dst_byte_length > MAX_STRING_BYTE_LENGTH); if (dst_byte_length > 0) { uint32_t ptr = cx.opts.realloc(0, 0, dst_alignment, dst_byte_length); trap_if(cx, ptr != align_to(ptr, dst_alignment)); trap_if(cx, ptr + dst_byte_length > cx.opts.memory.size()); std::memcpy(&cx.opts.memory[ptr], src, dst_byte_length); return std::make_pair(ptr, src_code_units); } return std::make_pair(0, 0); } inline std::pair store_string_to_utf8(LiftLowerContext &cx, Encoding src_encoding, const void *src, uint32_t src_byte_len, uint32_t worst_case_size) { assert(worst_case_size <= MAX_STRING_BYTE_LENGTH); uint32_t ptr = cx.opts.realloc(0, 0, 1, worst_case_size); trap_if(cx, ptr + src_byte_len > cx.opts.memory.size()); auto encoded = cx.convert(&cx.opts.memory[ptr], worst_case_size, src, src_byte_len, src_encoding, Encoding::Utf8); if (worst_case_size > encoded.second) { ptr = cx.opts.realloc(ptr, worst_case_size, 1, checked_uint32(cx, encoded.second)); assert(ptr + encoded.second <= cx.opts.memory.size()); } return std::make_pair(ptr, checked_uint32(cx, encoded.second)); } inline std::pair store_utf16_to_utf8(LiftLowerContext &cx, const void *src, uint32_t src_code_units) { uint32_t worst_case_size = src_code_units * 3; return store_string_to_utf8(cx, Encoding::Utf16, src, src_code_units * 2, worst_case_size); } inline std::pair store_latin1_to_utf8(LiftLowerContext &cx, const void *src, uint32_t src_code_units) { uint32_t worst_case_size = src_code_units * 2; return store_string_to_utf8(cx, Encoding::Latin1, src, src_code_units, worst_case_size); } inline std::pair store_utf8_to_utf16(LiftLowerContext &cx, const void *src, uint32_t src_code_units) { uint32_t worst_case_size = 2 * src_code_units; trap_if(cx, worst_case_size > MAX_STRING_BYTE_LENGTH); uint32_t ptr = cx.opts.realloc(0, 0, 2, worst_case_size); trap_if(cx, ptr != align_to(ptr, 2)); trap_if(cx, ptr + worst_case_size > cx.opts.memory.size()); auto encoded = cx.convert(&cx.opts.memory[ptr], worst_case_size, src, src_code_units, Encoding::Utf8, Encoding::Utf16); if (encoded.second < worst_case_size) { ptr = cx.opts.realloc(ptr, worst_case_size, 2, checked_uint32(cx, encoded.second)); assert(ptr == align_to(ptr, 2)); assert(ptr + encoded.second <= cx.opts.memory.size()); } uint32_t code_units = checked_uint32(cx, encoded.second / 2); return std::make_pair(ptr, code_units); } inline std::pair store_probably_utf16_to_latin1_or_utf16(LiftLowerContext &cx, const void *src, uint32_t src_code_units) { uint32_t src_byte_length = 2 * src_code_units; trap_if(cx, src_byte_length > MAX_STRING_BYTE_LENGTH); uint32_t ptr = cx.opts.realloc(0, 0, 2, src_byte_length); trap_if(cx, ptr != align_to(ptr, 2)); trap_if(cx, ptr + src_byte_length > cx.opts.memory.size()); auto encoded = cx.convert(&cx.opts.memory[ptr], src_byte_length, src, src_code_units, Encoding::Utf16, Encoding::Utf16); const uint8_t *enc_src_ptr = &cx.opts.memory[ptr]; if (std::any_of(enc_src_ptr, enc_src_ptr + encoded.second, [](unsigned c) { return c >= (1 << 8); })) { uint32_t tagged_code_units = checked_uint32(cx, encoded.second / 2) | UTF16_TAG; return std::make_pair(ptr, tagged_code_units); } uint32_t latin1_size = checked_uint32(cx, encoded.second / 2); for (uint32_t i = 0; i < latin1_size; ++i) cx.opts.memory[ptr + i] = cx.opts.memory[ptr + 2 * i]; ptr = cx.opts.realloc(ptr, src_byte_length, 1, latin1_size); trap_if(cx, ptr + latin1_size > cx.opts.memory.size()); return std::make_pair(ptr, latin1_size); } template std::pair store_string_to_latin1_or_utf16(LiftLowerContext &cx, const T &v) { Encoding src_encoding = ValTrait::encoding; const auto *src = v.data(); const size_t src_code_units = v.size(); const size_t src_byte_length = src_code_units * ValTrait::char_size; assert(src_code_units <= MAX_STRING_BYTE_LENGTH); uint32_t ptr = cx.opts.realloc(0, 0, 2, checked_uint32(cx, src_byte_length)); trap_if(cx, ptr != align_to(ptr, 2)); trap_if(cx, ptr + src_code_units > cx.opts.memory.size()); uint32_t dst_byte_length = 0; for (unsigned usv : v) { // Optimistically assume the character will fit in a single byte (Latin1) if (usv < (1 << 8)) { cx.opts.memory[ptr + dst_byte_length] = static_cast(usv); dst_byte_length += 1; } else { // If it doesn't, convert it to a UTF-16 sequence uint32_t worst_case_size = checked_uint32(cx, 2 * src_code_units); trap_if(cx, worst_case_size > MAX_STRING_BYTE_LENGTH, "Worst case size exceeds maximum string byte length"); ptr = cx.opts.realloc(ptr, checked_uint32(cx, src_byte_length), 2, worst_case_size); trap_if(cx, ptr != align_to(ptr, 2), "Pointer misaligned"); trap_if(cx, ptr + worst_case_size > cx.opts.memory.size(), "Out of bounds access"); #ifdef SIMPLE_UTF16_CONVERSION // Convert entire string to UTF-16 in one go, ignoring the previously computed data --- auto encoded = cx.convert(&cx.opts.memory[ptr], worst_case_size, src, src_code_units * ValTrait::char_size, src_encoding, Encoding::Utf16); if (encoded.second < worst_case_size) { ptr = cx.opts.realloc(ptr, worst_case_size, 2, encoded.second * 2); trap_if(cx, ptr != align_to(ptr, 2), "Pointer misaligned"); trap_if(cx, ptr + encoded.second > cx.opts.memory.size(), "Out of bounds access"); } uint32_t tagged_code_units = checked_uint32(cx, encoded.second / 2) | UTF16_TAG; return std::make_pair(ptr, tagged_code_units); #else // Pad out existing non unicode characters --- for (signed j = dst_byte_length - 1; j >= 0; --j) { cx.opts.memory[ptr + 2 * j] = cx.opts.memory[ptr + j]; cx.opts.memory[ptr + 2 * j + 1] = 0; } // Convert the remaining portion --- uint32_t destPtr = ptr + (2 * dst_byte_length); uint32_t destLen = worst_case_size - (2 * dst_byte_length); void *srcPtr = (char *)src + dst_byte_length * ValTrait::char_size; uint32_t srcLen = checked_uint32(cx, (src_code_units - dst_byte_length) * ValTrait::char_size); auto encoded = cx.convert(&cx.opts.memory[destPtr], destLen, srcPtr, srcLen, src_encoding, Encoding::Utf16); // Add special tag to indicate the string is a UTF-16 string --- uint32_t tagged_code_units = checked_uint32(cx, dst_byte_length + encoded.second / 2) | UTF16_TAG; return std::make_pair(ptr, tagged_code_units); #endif } } if (dst_byte_length < src_code_units) { ptr = cx.opts.realloc(ptr, checked_uint32(cx, src_code_units), 2, dst_byte_length); trap_if(cx, ptr != align_to(ptr, 2), "Pointer misaligned"); trap_if(cx, ptr + dst_byte_length > cx.opts.memory.size(), "Out of bounds access"); } return std::make_pair(ptr, dst_byte_length); } template std::pair store_into_range(LiftLowerContext &cx, const T &v) { Encoding src_encoding = ValTrait::encoding; auto *src = v.data(); const size_t src_tagged_code_units = v.size(); Encoding src_simple_encoding; uint32_t src_code_units; if (src_encoding == Encoding::Latin1_Utf16) { if (src_tagged_code_units & UTF16_TAG) { src_simple_encoding = Encoding::Utf16; src_code_units = checked_uint32(cx, src_tagged_code_units ^ UTF16_TAG); } else { src_simple_encoding = Encoding::Latin1; src_code_units = checked_uint32(cx, src_tagged_code_units); } } else { src_simple_encoding = src_encoding; src_code_units = checked_uint32(cx, src_tagged_code_units); } switch (cx.opts.string_encoding) { case Encoding::Latin1: cx.trap("Invalid guest encoding, must be UTF8, UTF16 or Latin1/UTF16"); break; case Encoding::Utf8: switch (src_simple_encoding) { case Encoding::Utf8: return store_string_copy(cx, src, src_code_units, 1, 1, Encoding::Utf8); case Encoding::Utf16: return store_utf16_to_utf8(cx, src, src_code_units); case Encoding::Latin1: return store_latin1_to_utf8(cx, src, src_code_units); } break; case Encoding::Utf16: switch (src_simple_encoding) { case Encoding::Utf8: return store_utf8_to_utf16(cx, src, src_code_units); case Encoding::Utf16: return store_string_copy(cx, src, src_code_units, 2, 2, Encoding::Utf16); case Encoding::Latin1: return store_string_copy(cx, src, src_code_units, 2, 2, Encoding::Utf16); } break; case Encoding::Latin1_Utf16: switch (src_encoding) { case Encoding::Utf8: return store_string_to_latin1_or_utf16(cx, v); case Encoding::Utf16: return store_string_to_latin1_or_utf16(cx, v); case Encoding::Latin1_Utf16: switch (src_simple_encoding) { case Encoding::Latin1: return store_string_copy(cx, src, src_code_units, 1, 2, Encoding::Latin1); case Encoding::Utf16: return store_probably_utf16_to_latin1_or_utf16(cx, src, src_code_units); } } } assert(false); return std::make_pair(0, 0); } template inline void store(LiftLowerContext &cx, const T &v, uint32_t ptr) { auto [begin, tagged_code_units] = store_into_range(cx, v); integer::store(cx, begin, ptr); integer::store(cx, tagged_code_units, ptr + 4); } template inline WasmValVector lower_flat(LiftLowerContext &cx, const T &v) { auto [ptr, packed_length] = store_into_range(cx, v); return {(int32_t)ptr, (int32_t)packed_length}; } template T load_from_range(const LiftLowerContext &cx, uint32_t ptr, uint32_t tagged_code_units) { uint32_t alignment = 0; uint64_t byte_length = 0; Encoding encoding = Encoding::Utf8; switch (cx.opts.string_encoding) { case Encoding::Utf8: alignment = 1; byte_length = tagged_code_units; encoding = Encoding::Utf8; break; case Encoding::Utf16: alignment = 2; byte_length = 2ull * tagged_code_units; encoding = Encoding::Utf16; break; case Encoding::Latin1_Utf16: alignment = 2; if (tagged_code_units & UTF16_TAG) { byte_length = 2ull * (tagged_code_units ^ UTF16_TAG); encoding = Encoding::Utf16; } else { byte_length = tagged_code_units; encoding = Encoding::Latin1; } break; default: trap_if(cx, false); } trap_if(cx, byte_length > MAX_STRING_BYTE_LENGTH, "string byte length exceeds limit"); trap_if(cx, ptr != align_to(ptr, alignment)); trap_if(cx, static_cast(ptr) + byte_length > cx.opts.memory.size()); size_t char_size = ValTrait::char_size; size_t host_byte_length = static_cast(byte_length * 2); T retVal; if constexpr (std::is_same::value) { retVal.encoding = encoding; } retVal.resize(host_byte_length); auto decoded = cx.convert(retVal.data(), host_byte_length, (void *)&cx.opts.memory[ptr], checked_uint32(cx, byte_length), encoding, ValTrait::encoding == Encoding::Latin1_Utf16 ? encoding : ValTrait::encoding); if ((decoded.second / char_size) < host_byte_length) { retVal.resize(decoded.second / char_size); } return retVal; } template T load(const LiftLowerContext &cx, offset offset) { auto begin = integer::load(cx, offset); auto tagged_code_units = integer::load(cx, offset + 4); return load_from_range(cx, begin, tagged_code_units); } template T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi) { auto ptr = vi.next(); auto packed_length = vi.next(); return load_from_range(cx, ptr, packed_length); } } template inline void store(LiftLowerContext &cx, const T &v, uint32_t ptr) { string::store(cx, v, ptr); } template inline WasmValVector lower_flat(LiftLowerContext &cx, const T &v) { return string::lower_flat(cx, v); } template inline T load(const LiftLowerContext &cx, uint32_t ptr) { return string::load(cx, ptr); } template inline T lift_flat(const LiftLowerContext &cx, const CoreValueIter &vi) { return string::lift_flat(cx, vi); } } #endif