// UTF8-CPP   http://utfcpp.sourceforge.net/
//   Slightly simplified (and reformatted) to fit openMSX coding style.

// Copyright 2006 Nemanja Trifunovic

/*
Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:

The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
*/

#ifndef UTF8_CHECKED_HH
#define UTF8_CHECKED_HH

#include "utf8_core.hh"

#include <stdexcept>
#include <string_view>

namespace utf8 {

// Exceptions that may be thrown from the library functions.
// Signals that a 32-bit value is not an acceptable code point.
// The offending value is stored and can be queried via code_point().
class invalid_code_point : public std::exception
{
public:
	explicit invalid_code_point(uint32_t cp_)
		: value(cp_)
	{
	}

	// Fixed, human readable description of the error.
	[[nodiscard]] const char* what() const noexcept override
	{
		return "Invalid code point";
	}

	// The value that was passed to the constructor.
	[[nodiscard]] uint32_t code_point() const
	{
		return value;
	}

private:
	uint32_t value;
};

// Signals malformed UTF-8 input.
// The octet handed to the constructor is stored and can be queried via
// utf8_octet().
class invalid_utf8 : public std::exception
{
public:
	explicit invalid_utf8(uint8_t u)
		: octet(u)
	{
	}

	// Fixed, human readable description of the error.
	[[nodiscard]] const char* what() const noexcept override
	{
		return "Invalid UTF-8";
	}

	// The octet that was passed to the constructor.
	[[nodiscard]] uint8_t utf8_octet() const
	{
		return octet;
	}

private:
	uint8_t octet;
};

// Signals malformed UTF-16 input.
// The 16-bit unit handed to the constructor is stored and can be queried
// via utf16_word().
class invalid_utf16 : public std::exception
{
public:
	explicit invalid_utf16(uint16_t u)
		: word(u)
	{
	}

	// Fixed, human readable description of the error.
	[[nodiscard]] const char* what() const noexcept override
	{
		return "Invalid UTF-16";
	}

	// The 16-bit unit that was passed to the constructor.
	[[nodiscard]] uint16_t utf16_word() const
	{
		return word;
	}

private:
	uint16_t word;
};

// Signals that the input range ended before a complete item could be read.
// Carries no payload besides the fixed message.
class not_enough_room : public std::exception
{
public:
	[[nodiscard]] const char* what() const noexcept override
	{
		return "Not enough space";
	}
};

// The library API - functions intended to be called by the users

// Forward declaration. append() is defined further down in this header, but
// replace_invalid() calls it with a uint32_t and an arbitrary output iterator.
// For such arguments (e.g. 'char*' or a std:: iterator) argument-dependent
// lookup does not search namespace utf8, so the name must already be visible
// at this point.
template<typename octet_iterator>
octet_iterator append(uint32_t cp, octet_iterator result);

// Copies the octets in [start, end) to 'out', substituting 'replacement' for
// every sequence that internal::validate_next() rejects.
//
//   start, end   input range of octets
//   out          output iterator receiving the resulting octets
//   replacement  code point that is written (via append()) once for each
//                rejected sequence
//
// Returns the output iterator past the last octet written.
// Throws not_enough_room when validate_next() reports NOT_ENOUGH_ROOM, and
// whatever append() throws for 'replacement'.
template<typename octet_iterator, typename output_iterator>
output_iterator replace_invalid(octet_iterator start, octet_iterator end,
                                output_iterator out, uint32_t replacement)
{
	while (start != end) {
		auto sequence_start = start;
		internal::utf_error err_code = internal::validate_next(start, end);
		switch (err_code) {
		using enum internal::utf_error;
		case OK:
			// Accepted sequence: copy the consumed octets verbatim.
			for (auto it = sequence_start; it != start; ++it) {
				*out++ = *it;
			}
			break;
		case NOT_ENOUGH_ROOM:
			throw not_enough_room();
		case INVALID_LEAD:
			append(replacement, out);
			++start;
			break;
		case INCOMPLETE_SEQUENCE:
		case OVERLONG_SEQUENCE:
		case INVALID_CODE_POINT:
			append(replacement, out);
			++start;
			// just one replacement mark for the sequence
			// (test for the end of the range BEFORE dereferencing,
			// the rejected sequence may be the last thing in the input)
			while (start != end && internal::is_trail(*start)) {
				++start;
			}
			break;
		}
	}
	return out;
}

// Convenience overload of replace_invalid() that uses the code point 0xfffd
// as replacement for rejected sequences.
template<typename octet_iterator, typename output_iterator>
inline output_iterator replace_invalid(octet_iterator start, octet_iterator end,
                                       output_iterator out)
{
	constexpr uint32_t default_replacement = 0xfffd;
	return replace_invalid(start, end, out, default_replacement);
}

// Encodes the code point 'cp' as 1 to 4 octets and writes them to 'result'.
//
// Returns the output iterator past the last octet written.
// Throws invalid_code_point when internal::is_code_point_valid() rejects
// 'cp' or when 'cp' exceeds internal::CODE_POINT_MAX.
template<typename octet_iterator>
octet_iterator append(uint32_t cp, octet_iterator result)
{
	if (!internal::is_code_point_valid(cp)) {
		throw invalid_code_point(cp);
	}

	// Continuation octet, 0b10..'....: carries the 6 bits of 'cp' that
	// start at bit position 'shift'.
	auto continuation = [cp](unsigned shift) {
		return ((cp >> shift) & 0x3f) | 0x80;
	};

	if (cp < 0x80) {
		// one octet
		*result++ = cp;
		return result;
	}
	if (cp < 0x800) {
		// two octets, lead is 0b110.'....  (5 payload bits)
		*result++ = ((cp >> 6) & 0x1f) | 0xc0;
		*result++ = continuation(0);
		return result;
	}
	if (cp < 0x10000) {
		// three octets, lead is 0b1110'....  (4 payload bits)
		*result++ = ((cp >> 12) & 0x0f) | 0xe0;
		*result++ = continuation(6);
		*result++ = continuation(0);
		return result;
	}
	if (cp > internal::CODE_POINT_MAX) {
		throw invalid_code_point(cp);
	}
	// four octets, lead is 0b1111'0...  (3 payload bits)
	*result++ = ((cp >> 18) & 0x07) | 0xf0;
	*result++ = continuation(12);
	*result++ = continuation(6);
	*result++ = continuation(0);
	return result;
}

template<typename octet_iterator>
uint32_t next(octet_iterator& it, octet_iterator end)
{
	uint32_t cp = 0;
	internal::utf_error err_code = internal::validate_next(it, end, &cp);
	switch (err_code) {
	using enum internal::utf_error;
	case OK:
		break;
	case NOT_ENOUGH_ROOM:
		throw not_enough_room();
	case INVALID_LEAD:
	case INCOMPLETE_SEQUENCE:
	case OVERLONG_SEQUENCE:
		throw invalid_utf8(*it);
	case INVALID_CODE_POINT:
		throw invalid_code_point(cp);
	}
	return cp;
}

// Like next(), but works on a private copy of the iterator: the caller's
// iterator is not modified. Throws whatever next() throws.
template<typename octet_iterator>
[[nodiscard]] uint32_t peek_next(octet_iterator it, octet_iterator end)
{
	auto& cursor = it; // 'it' was received by value, so it already is a copy
	return next(cursor, end);
}

template<typename octet_iterator>
uint32_t prior(octet_iterator& it, octet_iterator start)
{
	auto end = it;
	while (internal::is_trail(*(--it))) {
		if (it < start) {
			// error - no lead byte in the sequence
			throw invalid_utf8(*it);
		}
	}
	auto temp = it;
	return next(temp, end);
}

// Applies next(it, end) 'n' times (using the repeat() helper), so 'it' is
// moved forward by 'n' code points. Throws whatever next() throws.
template<typename octet_iterator, typename distance_type>
void advance(octet_iterator& it, distance_type n, octet_iterator end)
{
	auto step = [&it, &end] { next(it, end); };
	repeat(n, step);
}

// Counts how many times next() has to be applied to get from 'first' to (or
// past) 'last', i.e. the number of code points in the range.
// The result has the iterator's difference_type.
// Throws whatever next() throws.
template<typename octet_iterator>
[[nodiscard]] auto distance(octet_iterator first, octet_iterator last)
{
	using diff_t = typename std::iterator_traits<octet_iterator>::difference_type;
	diff_t count = 0;
	for (/**/; first < last; ++count) {
		next(first, last);
	}
	return count;
}

template<typename u16bit_iterator, typename octet_iterator>
octet_iterator utf16to8(u16bit_iterator start, u16bit_iterator end,
                        octet_iterator result)
{
	while (start != end) {
		uint32_t cp = *start++;
		// Take care of surrogate pairs first
		if (internal::is_surrogate(cp)) {
			if (start == end) {
				throw invalid_utf16(*start);
			}
			auto trail_surrogate = *start++;
			if (trail_surrogate < internal::TRAIL_SURROGATE_MIN ||
			    trail_surrogate > internal::TRAIL_SURROGATE_MAX) {
				throw invalid_utf16(trail_surrogate);
			}
			cp = (cp << 10) + trail_surrogate + internal::SURROGATE_OFFSET;
		}
		result = append(cp, result);
	}
	return result;
}

// Converts the UTF-8 octets in [start, end) to UTF-16 units written to
// 'result'. Code points above 0xffff are written as two units (a surrogate
// pair), all others as a single unit.
//
// Returns the output iterator past the last unit written.
// Throws whatever next() throws.
template<typename u16bit_iterator, typename octet_iterator>
u16bit_iterator utf8to16(octet_iterator start, octet_iterator end,
                         u16bit_iterator result)
{
	while (start != end) {
		uint32_t cp = next(start, end);
		if (cp <= 0xffff) {
			// fits in one unit
			*result++ = cp;
			continue;
		}
		// make a surrogate pair
		*result++ = (cp >> 10)   + internal::LEAD_OFFSET;
		*result++ = (cp & 0x3ff) + internal::TRAIL_SURROGATE_MIN;
	}
	return result;
}

// Converts the code points in [start, end) to UTF-8 octets, which are
// written to 'result' via append().
//
// Returns the output iterator past the last octet written.
// Throws whatever append() throws.
template<typename octet_iterator, typename u32bit_iterator>
octet_iterator utf32to8(u32bit_iterator start, u32bit_iterator end,
                        octet_iterator result)
{
	for (/**/; start != end; ++start) {
		result = append(*start, result);
	}
	return result;
}

// Decodes the UTF-8 octets in [start, end) with next() and writes each
// resulting code point to 'result'.
//
// Returns the output iterator past the last code point written.
// Throws whatever next() throws.
template<typename octet_iterator, typename u32bit_iterator>
u32bit_iterator utf8to32(octet_iterator start, octet_iterator end,
                         u32bit_iterator result)
{
	while (start < end) {
		*result = next(start, end);
		++result;
	}
	return result;
}

// The iterator class
// Bidirectional iterator adaptor: wraps an octet iterator plus the range
// [range_start, range_end] it lives in, and yields code points (uint32_t)
// decoded with next() / prior().
template<typename octet_iterator>
class iterator
{
	// Value-initialized, so that also a default-initialized iterator
	// ('iterator<T> x;') has well-defined members. Without this, pointer
	// members would be indeterminate and copying or comparing such an
	// iterator would be undefined behavior.
	octet_iterator it{};
	octet_iterator range_start{};
	octet_iterator range_end{};

public:
	using iterator_category = std::bidirectional_iterator_tag;
	using difference_type   = ptrdiff_t;
	using value_type        = uint32_t;
	using pointer           = uint32_t*;
	using reference         = uint32_t&;

	iterator() = default;

	// Creates an iterator at position 'octet_it' inside the range
	// [range_start_, range_end_].
	// Throws std::out_of_range when 'octet_it' lies outside that range.
	iterator(const octet_iterator& octet_it,
	         const octet_iterator& range_start_,
	         const octet_iterator& range_end_)
		: it(octet_it)
		, range_start(range_start_)
		, range_end(range_end_)
	{
		if (it < range_start || it > range_end) {
			throw std::out_of_range("Invalid utf-8 iterator position");
		}
	}
	// the default "big three" are OK

	// The wrapped octet iterator at its current position.
	[[nodiscard]] octet_iterator base() const { return it; }

	// Decodes the code point at the current position; works on a copy, so
	// the position itself does not change. Throws whatever next() throws.
	[[nodiscard]] uint32_t operator*() const
	{
		auto temp = it;
		return next(temp, range_end);
	}

	// Compares positions. Both iterators must have been created for the
	// same range, otherwise std::logic_error is thrown.
	[[nodiscard]] bool operator==(const iterator& rhs) const
	{
		if ((range_start != rhs.range_start) ||
		    (range_end   != rhs.range_end)) {
			throw std::logic_error(
				"Comparing utf-8 iterators defined with different ranges");
		}
		return it == rhs.it;
	}

	// Step forward by one code point (via next()).
	iterator& operator++()
	{
		next(it, range_end);
		return *this;
	}
	iterator operator++(int)
	{
		auto temp = *this;
		next(it, range_end);
		return temp;
	}

	// Step backward by one code point (via prior()).
	iterator& operator--()
	{
		prior(it, range_start);
		return *this;
	}
	iterator operator--(int)
	{
		auto temp = *this;
		prior(it, range_start);
		return temp;
	}
};

// String based conversion functions, only declared when _WIN32 is defined.
// These are declarations only, the definitions are not part of this header.
// Judging by names and signatures: utf8to16() takes UTF-8 text and returns
// it as a std::wstring, utf16to8() does the reverse and returns a
// std::string.
#ifdef _WIN32
[[nodiscard]] std::wstring utf8to16  (std::string_view utf8);
[[nodiscard]] std::string  utf16to8  (std::wstring_view utf16);
#endif

} // namespace utf8

#endif
