wallet/mpir_and_base58.h

184 lines
6.2 KiB
C
Raw Normal View History

#pragma once
/* This header wraps MPIR arbitrary-precision integers (mpz) and implements fixed-length, checksummed base58 encoding and decoding built on top of them. */
namespace mpir {
class mpz {
mpz_t mpz_m;
public:
mpz() {
mpz_init2(mpz_m, 624);
}
mpz(const char* str, int base) {
mpz_init(mpz_m);
if (mpz_set_str(mpz_m, str, base))throw BadDataException();
}
~mpz() {
mpz_clear(mpz_m);
}
operator mpz_ptr() {
return mpz_m;
}
};
extern mpz ristretto25519_curve_order;
}
// The global thread-local object below is explicitly constructed
// on the heap, on demand, once the thread starts running,
// and destroyed when the thread exits.
// This explicit construction and destruction is a workaround because C++ lacks
// a complete model of concurrent processes, and is therefore unable to
// correctly handle non-POD thread_local objects: it will not correctly
// construct and destruct thread_local objects by itself, because the C++
// abstract machine does not have a model for threading. It has a pile of
// matchsticks and a tub of glue with which you can build your own model.
// It supports all the stuff you need for threads, but has no idea how all
// these moving parts fit together.
// So you have to construct and destruct your non-POD thread-local objects in code.
class thread_local__ {
public:
    // Scratch big integers, reused as temporaries throughout the base58 code
    // in this header. Each mpz heap-allocates when created and frees when
    // destroyed, so keeping one set per thread avoids that churn.
    mpir::mpz q; // receives quotients (e.g. the check value when decoding)
    mpir::mpz r; // receives remainders / intermediate products
    mpir::mpz n; // holds the whole number being encoded or decoded
};
// The calling thread's scratch bundle. It is expected to be nullptr until
// something on this thread creates it with new. Its definition (and
// constructor) live in app.obj, not in this header.
// A raw pointer that is destroyed explicitly is used instead of an
// std::unique_ptr or a thread_local object, because the destructor did not
// seem to get called otherwise. That behavior still needs testing to figure
// out what is going on.
extern thread_local thread_local__* thl;
namespace ro {
using ristretto255::scalar, ristretto255::point;
auto fasthash(uint64_t, std::span<const uint64_t>)->uint32_t;
void right_justify_string(char*, char*);
bool is_alphanumeric_fixed_length(unsigned int, const char*);
template <class T> typename std::enable_if<
ro::is_blob_field_type<T>::value,
decltype(T::type_indentifier, uint32_t())
>::type fasthash(const T& p_blob) {
static_assert(sizeof(T) % 8 == 0, "fasthash assumes a multiple of 8 bytes");
return fasthash(
T::type_indentifier,
std::span< const uint64_t >(
reinterpret_cast<const uint64_t*>(&p_blob.blob[0]),
p_blob.blob.size() / 8
)
);
}
void map_base_from_mpir58(char*);
void map_base_to_mpir58(const char*, char*, size_t);
template <typename T> class base58 : public CompileSizedString<
((sizeof(T) * 8 + 4) * 4943ui64 + 28955ui64) / 28956ui64 - (std::is_same_v<std::remove_cvref_t<T>, scalar> ? 1 : 0)> {
public:
// The rational number 4943 / 28956 is minisculy larger than log(2)/log(58)
// hence rounding up the nearest integer guarantees it will always be big enough.
base58() = default;
~base58() = default;
base58(const T&);
base58(const char* p);
static const char* bin(
typename const decltype(T::type_indentifier, char())* p,
T& sc
);
static void bin(const base58<T>& str, T& sc);
static T bin(const char* str) {
T sc;
bin(str, sc);
return sc;
};
T bin() const {
T sc;
bin(*this, sc);
return sc;
};
operator T() const {
return bin();
}
};
template <class T> typename const decltype(base58<T>::length, T::type_indentifier, uint32_t())
// cannot be consteval or constexpr, because has to be called after the mpir temp values are constructed
check_range() {
if (thl == nullptr)thl = new thread_local__();
mpz_ui_pow_ui(thl->n, 58, base58<T>::length);
if constexpr (std::is_same_v<std::remove_cvref_t<T>, scalar>) {
mpz_fdiv_q(thl->q, thl->n, mpir::ristretto25519_curve_order);
}
else {
mpz_fdiv_q_2exp(thl->q, thl->n, sizeof(T) * 8);
}
assert(mpz_cmp_ui(thl->q, UINT32_MAX) <= 0);
return static_cast<uint32_t>(mpz_get_ui(thl->q));
}
template <class T> const uint32_t check_range_m{ check_range<T>() };
template <class T> const char* base58<T>::bin(
typename const decltype(T::type_indentifier, char())* p,
T& sc
) {
const uint32_t range = check_range_m<T>;
base58 strsc;
char* ps = strsc;
map_base_to_mpir58(p, ps, strsc.length);
if (p[base58::length] > ' ')throw OversizeBase58String();
if (mpz_set_str(thl->n, ps, 58))throw BadStringRepresentationOfCryptoIdException();
if constexpr (std::is_same_v<std::remove_cvref_t<T>, scalar>) {
mpz_fdiv_qr(thl->q, thl->r, thl->n, mpir::ristretto25519_curve_order);
}
else {
mpz_fdiv_q_2exp(thl->q, thl->n, sizeof(sc.blob) * 8);
mpz_fdiv_r_2exp(thl->r, thl->n, sizeof(sc.blob) * 8);
}
size_t count;
mpz_export(&(sc.blob[0]), &count, -1, 1, -1, 0, thl->r);
if (count < sizeof(sc.blob))memset(&sc.blob[count], 0, sizeof(sc.blob) - count);
mpir_ui ck{ (static_cast<uint64_t>(fasthash(sc)) * static_cast<uint64_t>(range)) >> 32 };
if (ck != mpz_get_ui(thl->q)) throw BadStringRepresentationOfCryptoIdException();
return p + base58<T>::length;
}
template <class T> char* to_base58(
// does no string memory allocation, p has to point into a buffer,
// return value points to next position in the buffer, which is now null
typename decltype(check_range_m<T>, char())* p,
const T& sc
) {
mpir_ui ck{ (static_cast<uint64_t>(fasthash(sc)) * static_cast<uint64_t>(check_range_m<T>)) >> 32 };
mpz_import(thl->n, sizeof(sc.blob), -1, 1, -1, 0, &sc.blob[0]);
if constexpr (std::is_same_v<std::remove_cvref_t<T>, scalar>) {
mpz_addmul_ui(thl->n, mpir::ristretto25519_curve_order, ck);
}
else {
mpz_set_ui(thl->r, ck);
mpz_mul_2exp(thl->r, thl->r, sizeof(sc.blob) * 8);
mpz_add(thl->n, thl->n, thl->r);
}
mpz_get_str(p, 58, thl->n);
char* terminal_null{ p + sizeof(base58<T>) - 1 };
right_justify_string(p, terminal_null);
map_base_from_mpir58(p);
return terminal_null;
// return value points to trailing null
}
template <class T> base58<T>::base58(const T& el) {
to_base58<T>(static_cast<char*>(*this), el);
}
template <class T> base58<T>::base58(const char* p) {
memmove(this, p, (this->length));
std::array<char, (this->length)> test;
map_base_to_mpir58(this, test, this->length); //Force an exception for bad char
if (p[this->length] > ' ')throw OversizeBase58String();
this->operator char* [this->length] = '\0';
}
}