Refactor type hash function to use less RAM

This commit is contained in:
Alexandre Paillier
2022-05-02 15:30:41 +02:00
parent 854791324a
commit 83dda443f4
5 changed files with 301 additions and 348 deletions

View File

@@ -1,303 +0,0 @@
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <stdio.h>
#include "mem.h"
#include "mem_utils.h"
#include "eip712.h"
#include "encode_type.h"
/**
*
* @param[in] lvl_ptr pointer to the first array level of a struct field
* @param[in] lvls_count the number of array levels the struct field contains
* @return \ref true it finished correctly, \ref false if it didn't (memory allocation)
*/
static bool format_field_type_array_levels_string(const void *lvl_ptr, uint8_t lvls_count)
{
uint8_t array_size;
while (lvls_count-- > 0)
{
if (mem_alloc_and_copy_char('[') == NULL)
{
return false;
}
switch (struct_field_array_depth(lvl_ptr, &array_size))
{
case ARRAY_DYNAMIC:
break;
case ARRAY_FIXED_SIZE:
mem_alloc_and_format_uint(array_size, NULL);
break;
default:
// should not be in here :^)
break;
}
if (mem_alloc_and_copy_char(']') == NULL)
{
return false;
}
lvl_ptr = get_next_struct_field_array_lvl(lvl_ptr);
}
return true;
}
/**
*
* @param[in] field_ptr pointer to the struct field
* @return \ref true it finished correctly, \ref false if it didn't (memory allocation)
*/
static bool format_field_string(const void *field_ptr)
{
const char *name;
uint8_t length;
uint16_t field_size;
uint8_t lvls_count;
const uint8_t *lvl_ptr;
// field type name
name = get_struct_field_typename(field_ptr, &length);
if (mem_alloc_and_copy(name, length) == NULL)
{
return false;
}
// field type size
if (struct_field_has_typesize(field_ptr))
{
field_size = get_struct_field_typesize(field_ptr);
switch (struct_field_type(field_ptr))
{
case TYPE_SOL_INT:
case TYPE_SOL_UINT:
field_size *= 8; // bytes -> bits
break;
case TYPE_SOL_BYTES_FIX:
break;
default:
// should not be in here :^)
break;
}
mem_alloc_and_format_uint(field_size, NULL);
}
// field type array levels
if (struct_field_is_array(field_ptr))
{
lvl_ptr = get_struct_field_array_lvls_array(field_ptr, &lvls_count);
format_field_type_array_levels_string(lvl_ptr, lvls_count);
}
// space between field type name and field name
if (mem_alloc_and_copy_char(' ') == NULL)
{
return false;
}
// field name
name = get_struct_field_keyname(field_ptr, &length);
if (mem_alloc_and_copy(name, length) == NULL)
{
return false;
}
return true;
}
/**
*
* @param[in] struct_ptr pointer to the structure we want the typestring of
* @param[in] str_length length of the formatted string in memory
* @return pointer of the string in memory, \ref NULL in case of an error
*/
static const char *format_struct_string(const uint8_t *const struct_ptr, uint16_t *const str_length)
{
const char *str_start;
const char *struct_name;
uint8_t struct_name_length;
const uint8_t *field_ptr;
uint8_t fields_count;
// struct name
struct_name = get_struct_name(struct_ptr, &struct_name_length);
if ((str_start = mem_alloc_and_copy(struct_name, struct_name_length)) == NULL)
{
return NULL;
}
// opening struct parenthese
if (mem_alloc_and_copy_char('(') == NULL)
{
return NULL;
}
field_ptr = get_struct_fields_array(struct_ptr, &fields_count);
for (uint8_t idx = 0; idx < fields_count; ++idx)
{
// comma separating struct fields
if (idx > 0)
{
if (mem_alloc_and_copy_char(',') == NULL)
{
return NULL;
}
}
if (format_field_string(field_ptr) == false)
{
return NULL;
}
field_ptr = get_next_struct_field(field_ptr);
}
// closing struct parenthese
if (mem_alloc_and_copy_char(')') == NULL)
{
return NULL;
}
// compute the length
*str_length = ((char*)mem_alloc(0) - str_start);
return str_start;
}
/**
*
*
* @param[in] structs_array pointer to structs array
* @param[in] deps_count count of how many struct dependencies pointers
* @param[in,out] deps pointer to the first dependency pointer
*/
static void sort_dependencies(uint8_t deps_count,
void **deps)
{
bool changed;
void *tmp_ptr;
const char *name1, *name2;
uint8_t namelen1, namelen2;
int str_cmp_result;
do
{
changed = false;
for (size_t idx = 0; (idx + 1) < deps_count; ++idx)
{
name1 = get_struct_name(*(deps + idx), &namelen1);
name2 = get_struct_name(*(deps + idx + 1), &namelen2);
str_cmp_result = strncmp(name1, name2, MIN(namelen1, namelen2));
if ((str_cmp_result > 0) || ((str_cmp_result == 0) && (namelen1 > namelen2)))
{
tmp_ptr = *(deps + idx);
*(deps + idx) = *(deps + idx + 1);
*(deps + idx + 1) = tmp_ptr;
changed = true;
}
}
}
while (changed);
}
/**
*
*
* @param[in] structs_array pointer to structs array
* @param[out] deps_count count of how many struct dependencie pointers
* @param[in] deps pointer to the first dependency pointer
* @param[in] struct_ptr pointer to the struct we are getting the dependencies of
* @return \ref false in case of a memory allocation error, \ref true otherwise
*/
static bool get_struct_dependencies(const void *const structs_array,
uint8_t *const deps_count,
void *const *const deps,
const void *const struct_ptr)
{
uint8_t fields_count;
const void *field_ptr;
const char *arg_structname;
uint8_t arg_structname_length;
const void *arg_struct_ptr;
size_t dep_idx;
const void **new_dep;
field_ptr = get_struct_fields_array(struct_ptr, &fields_count);
for (uint8_t idx = 0; idx < fields_count; ++idx)
{
if (struct_field_type(field_ptr) == TYPE_CUSTOM)
{
// get struct name
arg_structname = get_struct_field_typename(field_ptr, &arg_structname_length);
// from its name, get the pointer to its definition
arg_struct_ptr = get_structn(structs_array, arg_structname, arg_structname_length);
// check if it is not already present in the dependencies array
for (dep_idx = 0; dep_idx < *deps_count; ++dep_idx)
{
// it's a match!
if (*(deps + dep_idx) == arg_struct_ptr)
{
break;
}
}
// if it's not present in the array, add it and recurse into it
if (dep_idx == *deps_count)
{
if ((new_dep = mem_alloc(sizeof(void*))) == NULL)
{
return false;
}
*new_dep = arg_struct_ptr;
*deps_count += 1;
get_struct_dependencies(structs_array, deps_count, deps, arg_struct_ptr);
}
}
field_ptr = get_next_struct_field(field_ptr);
}
return true;
}
/**
*
*
* @param[in] structs_array pointer to structs array
* @param[in] struct_name name of the given struct
* @param[in] struct_name_length length of the name of the given struct
* @param[out] encoded_length length of the returned string
* @return pointer to encoded string or \ref NULL in case of a memory allocation error
*/
const char *encode_type(const void *const structs_array,
const char *const struct_name,
const uint8_t struct_name_length,
uint16_t *const encoded_length)
{
const void *const struct_ptr = get_structn(structs_array,
struct_name,
struct_name_length);
uint8_t deps_count;
void **deps;
uint16_t length;
const char *typestr;
*encoded_length = 0;
deps_count = 0;
// get list of structs (own + dependencies), properly ordered
deps = mem_alloc(0); // get where the first elem will be
if (get_struct_dependencies(structs_array, &deps_count, deps, struct_ptr) == false)
{
return NULL;
}
sort_dependencies(deps_count, deps);
typestr = format_struct_string(struct_ptr, &length);
*encoded_length += length;
// loop over each struct and generate string
for (int idx = 0; idx < deps_count; ++idx)
{
format_struct_string(*deps, &length);
*encoded_length += length;
deps += 1;
}
return typestr;
}

View File

@@ -1,12 +0,0 @@
#ifndef ENCODE_TYPE_H_
#define ENCODE_TYPE_H_
#include <stdint.h>
#include <stdbool.h>
const char *encode_type(const void *const structs_array,
const char *const struct_name,
const uint8_t struct_name_length,
uint16_t *const encoded_length);
#endif // ENCODE_TYPE_H_

View File

@@ -381,7 +381,7 @@ bool handle_apdu(const uint8_t *const data)
switch (data[OFFSET_P2])
{
case P2_NAME:
type_hash(structs_array, (char*)&data[OFFSET_DATA], data[OFFSET_LC]);
type_hash(structs_array, (char*)&data[OFFSET_DATA], data[OFFSET_LC], true);
// set root type
path_set_root((char*)&data[OFFSET_DATA], data[OFFSET_LC]);
break;

View File

@@ -1,63 +1,330 @@
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>
#include <stdio.h>
#include "eip712.h"
#include "mem.h"
#include "encode_type.h"
#include "mem_utils.h"
#include "eip712.h"
#include "type_hash.h"
#include "shared_context.h"
const uint8_t *type_hash(const void *const structs_array,
const char *const struct_name,
const uint8_t struct_name_length)
static inline void hash_nbytes(const uint8_t *b, uint8_t n)
{
const void *const mem_loc_bak = mem_alloc(0); // backup the memory location
const char *typestr;
uint16_t length;
uint8_t *hash_ptr;
typestr = encode_type(structs_array, struct_name, struct_name_length, &length);
if (typestr == NULL)
#ifdef DEBUG
for (int i = 0; i < n; ++i)
{
return NULL;
printf("%c", b[i]);
}
cx_keccak_init((cx_hash_t*)&global_sha3, 256);
#endif
cx_hash((cx_hash_t*)&global_sha3,
0,
(uint8_t*)typestr,
length,
b,
n,
NULL,
0);
}
#ifdef DEBUG
// Print type string
fwrite(typestr, sizeof(char), length, stdout);
printf("\n");
#endif
static inline void hash_byte(uint8_t b)
{
hash_nbytes(&b, 1);
}
// restore the memory location
mem_dealloc(mem_alloc(0) - mem_loc_bak);
/**
*
* @param[in] lvl_ptr pointer to the first array level of a struct field
* @param[in] lvls_count the number of array levels the struct field contains
* @return \ref true it finished correctly, \ref false if it didn't (memory allocation)
*/
static bool format_field_type_array_levels_string(const void *lvl_ptr, uint8_t lvls_count)
{
uint8_t array_size;
char *uint_str_ptr;
uint8_t uint_str_len;
if ((hash_ptr = mem_alloc(KECCAK256_HASH_BYTESIZE + 1)) == NULL)
while (lvls_count-- > 0)
{
hash_byte('[');
switch (struct_field_array_depth(lvl_ptr, &array_size))
{
case ARRAY_DYNAMIC:
break;
case ARRAY_FIXED_SIZE:
uint_str_ptr = mem_alloc_and_format_uint(array_size, &uint_str_len);
hash_nbytes((uint8_t*)uint_str_ptr, uint_str_len);
mem_dealloc(uint_str_len);
break;
default:
// should not be in here :^)
break;
}
hash_byte(']');
lvl_ptr = get_next_struct_field_array_lvl(lvl_ptr);
}
return true;
}
/**
*
* @param[in] field_ptr pointer to the struct field
* @return \ref true it finished correctly, \ref false if it didn't (memory allocation)
*/
static bool encode_and_hash_field(const void *field_ptr)
{
const char *name;
uint8_t length;
uint16_t field_size;
uint8_t lvls_count;
const uint8_t *lvl_ptr;
char *uint_str_ptr;
uint8_t uint_str_len;
// field type name
name = get_struct_field_typename(field_ptr, &length);
hash_nbytes((uint8_t*)name, length);
// field type size
if (struct_field_has_typesize(field_ptr))
{
field_size = get_struct_field_typesize(field_ptr);
switch (struct_field_type(field_ptr))
{
case TYPE_SOL_INT:
case TYPE_SOL_UINT:
field_size *= 8; // bytes -> bits
break;
case TYPE_SOL_BYTES_FIX:
break;
default:
// should not be in here :^)
break;
}
uint_str_ptr = mem_alloc_and_format_uint(field_size, &uint_str_len);
hash_nbytes((uint8_t*)uint_str_ptr, uint_str_len);
mem_dealloc(uint_str_len);
}
// field type array levels
if (struct_field_is_array(field_ptr))
{
lvl_ptr = get_struct_field_array_lvls_array(field_ptr, &lvls_count);
format_field_type_array_levels_string(lvl_ptr, lvls_count);
}
// space between field type name and field name
hash_byte(' ');
// field name
name = get_struct_field_keyname(field_ptr, &length);
hash_nbytes((uint8_t*)name, length);
return true;
}
/**
*
* @param[in] struct_ptr pointer to the structure we want the typestring of
* @param[in] str_length length of the formatted string in memory
* @return pointer of the string in memory, \ref NULL in case of an error
*/
static bool encode_and_hash_type(const uint8_t *const struct_ptr)
{
const char *struct_name;
uint8_t struct_name_length;
const uint8_t *field_ptr;
uint8_t fields_count;
// struct name
struct_name = get_struct_name(struct_ptr, &struct_name_length);
hash_nbytes((uint8_t*)struct_name, struct_name_length);
// opening struct parenthese
hash_byte('(');
field_ptr = get_struct_fields_array(struct_ptr, &fields_count);
for (uint8_t idx = 0; idx < fields_count; ++idx)
{
// comma separating struct fields
if (idx > 0)
{
hash_byte(',');
}
if (encode_and_hash_field(field_ptr) == false)
{
return NULL;
}
field_ptr = get_next_struct_field(field_ptr);
}
// closing struct parenthese
hash_byte(')');
return true;
}
/**
*
*
* @param[in] structs_array pointer to structs array
* @param[in] deps_count count of how many struct dependencies pointers
* @param[in,out] deps pointer to the first dependency pointer
*/
static void sort_dependencies(uint8_t deps_count,
void **deps)
{
bool changed;
void *tmp_ptr;
const char *name1, *name2;
uint8_t namelen1, namelen2;
int str_cmp_result;
do
{
changed = false;
for (size_t idx = 0; (idx + 1) < deps_count; ++idx)
{
name1 = get_struct_name(*(deps + idx), &namelen1);
name2 = get_struct_name(*(deps + idx + 1), &namelen2);
str_cmp_result = strncmp(name1, name2, MIN(namelen1, namelen2));
if ((str_cmp_result > 0) || ((str_cmp_result == 0) && (namelen1 > namelen2)))
{
tmp_ptr = *(deps + idx);
*(deps + idx) = *(deps + idx + 1);
*(deps + idx + 1) = tmp_ptr;
changed = true;
}
}
}
while (changed);
}
/**
*
*
* @param[in] structs_array pointer to structs array
* @param[out] deps_count count of how many struct dependencie pointers
* @param[in] deps pointer to the first dependency pointer
* @param[in] struct_ptr pointer to the struct we are getting the dependencies of
* @return \ref false in case of a memory allocation error, \ref true otherwise
*/
static bool get_struct_dependencies(const void *const structs_array,
uint8_t *const deps_count,
void *const *const deps,
const void *const struct_ptr)
{
uint8_t fields_count;
const void *field_ptr;
const char *arg_structname;
uint8_t arg_structname_length;
const void *arg_struct_ptr;
size_t dep_idx;
const void **new_dep;
field_ptr = get_struct_fields_array(struct_ptr, &fields_count);
for (uint8_t idx = 0; idx < fields_count; ++idx)
{
if (struct_field_type(field_ptr) == TYPE_CUSTOM)
{
// get struct name
arg_structname = get_struct_field_typename(field_ptr, &arg_structname_length);
// from its name, get the pointer to its definition
arg_struct_ptr = get_structn(structs_array, arg_structname, arg_structname_length);
// check if it is not already present in the dependencies array
for (dep_idx = 0; dep_idx < *deps_count; ++dep_idx)
{
// it's a match!
if (*(deps + dep_idx) == arg_struct_ptr)
{
break;
}
}
// if it's not present in the array, add it and recurse into it
if (dep_idx == *deps_count)
{
if ((new_dep = mem_alloc(sizeof(void*))) == NULL)
{
return false;
}
*new_dep = arg_struct_ptr;
*deps_count += 1;
get_struct_dependencies(structs_array, deps_count, deps, arg_struct_ptr);
}
}
field_ptr = get_next_struct_field(field_ptr);
}
return true;
}
/**
*
*
* @param[in] structs_array pointer to structs array
* @param[in] struct_name name of the given struct
* @param[in] struct_name_length length of the name of the given struct
* @param[in] with_deps if hashed typestring should include struct dependencies
* @return pointer to encoded string or \ref NULL in case of a memory allocation error
*/
const uint8_t *type_hash(const void *const structs_array,
const char *const struct_name,
const uint8_t struct_name_length,
bool with_deps)
{
const void *const struct_ptr = get_structn(structs_array,
struct_name,
struct_name_length);
uint8_t deps_count;
void **deps;
uint8_t *hash_ptr;
cx_keccak_init((cx_hash_t*)&global_sha3, 256); // init hash
if (with_deps)
{
deps_count = 0;
// get list of structs (own + dependencies), properly ordered
deps = mem_alloc(0); // get where the first elem will be
if (get_struct_dependencies(structs_array, &deps_count, deps, struct_ptr) == false)
{
return NULL;
}
sort_dependencies(deps_count, deps);
}
if (encode_and_hash_type(struct_ptr) == false)
{
return NULL;
}
if (with_deps)
{
// loop over each struct and generate string
for (int idx = 0; idx < deps_count; ++idx)
{
encode_and_hash_type(*deps);
deps += 1;
}
mem_dealloc(sizeof(void*) * deps_count);
}
#ifdef DEBUG
printf("\n");
#endif
// End progressive hashing
if ((hash_ptr = mem_alloc(KECCAK256_HASH_BYTESIZE)) == NULL)
{
return NULL;
}
// set TypeHash marker
*hash_ptr = EIP712_TYPE_HASH;
// copy hash into memory
cx_hash((cx_hash_t*)&global_sha3,
CX_LAST,
NULL,
0,
hash_ptr + 1,
hash_ptr,
KECCAK256_HASH_BYTESIZE);
#ifdef DEBUG
// print computed hash
printf("-> 0x");
printf("new -> 0x");
for (int idx = 0; idx < KECCAK256_HASH_BYTESIZE; ++idx)
{
printf("%.02x", (hash_ptr + 1)[idx]);
printf("%.02x", hash_ptr[idx]);
}
printf("\n");
#endif

View File

@@ -5,6 +5,7 @@
const uint8_t *type_hash(const void *const structs_array,
const char *const struct_name,
const uint8_t struct_name_length);
const uint8_t struct_name_length,
bool with_deps);
#endif // TYPE_HASH_H_