Normalizes EIP-712 fixed-length bytes in Ragger client

Also some type-hinting and simplified the integer encoding
This commit is contained in:
Alexandre Paillier
2024-03-26 18:33:34 +01:00
parent a2107b81c4
commit 02efe1df14

View File

@@ -4,7 +4,8 @@ import re
import signal
import sys
import copy
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import struct
from client import keychain
from client.client import EthAppClient, EIP712FieldType
@@ -118,69 +119,63 @@ def send_struct_def_field(typename, keyname):
return (typename, type_enum, typesize, array_lvls)
def encode_integer(value, typesize):
data = bytearray()
def encode_integer(value: Union[str | int], typesize: int) -> bytes:
# Some are already represented as integers in the JSON, but most as strings
if isinstance(value, str):
base = 10
if value.startswith("0x"):
base = 16
value = int(value, base)
value = int(value, 0)
if value == 0:
data.append(0)
data = b'\x00'
else:
if value < 0: # negative number, send it as unsigned
mask = 0
for i in range(typesize): # make a mask as big as the typesize
mask = (mask << 8) | 0xff
value &= mask
while value > 0:
data.append(value & 0xff)
value >>= 8
data.reverse()
# biggest uint type accepted by struct.pack
uint64_mask = 0xffffffffffffffff
data = struct.pack(">QQQQ",
(value >> 192) & uint64_mask,
(value >> 128) & uint64_mask,
(value >> 64) & uint64_mask,
value & uint64_mask)
data = data[len(data) - typesize:]
data = data.lstrip(b'\x00')
return data
def encode_int(value, typesize):
def encode_int(value: str, typesize: int) -> bytes:
return encode_integer(value, typesize)
def encode_uint(value, typesize):
def encode_uint(value: str, typesize: int) -> bytes:
return encode_integer(value, typesize)
def encode_hex_string(value, size):
data = bytearray()
value = value[2:] # skip 0x
byte_idx = 0
while byte_idx < size:
data.append(int(value[(byte_idx * 2):(byte_idx * 2 + 2)], 16))
byte_idx += 1
return data
def encode_hex_string(value: str, size: int) -> bytes:
assert value.startswith("0x")
value = value[2:]
if len(value) < (size * 2):
value = value.rjust(size * 2, "0")
assert len(value) == (size * 2)
return bytes.fromhex(value)
def encode_address(value, typesize):
def encode_address(value: str, typesize: int) -> bytes:
return encode_hex_string(value, 20)
def encode_bool(value, typesize):
return encode_integer(value, typesize)
def encode_bool(value: str, typesize: int) -> bytes:
return encode_integer(value, 1)
def encode_string(value, typesize):
def encode_string(value: str, typesize: int) -> bytes:
data = bytearray()
for char in value:
data.append(ord(char))
return data
def encode_bytes_fix(value, typesize):
def encode_bytes_fix(value: str, typesize: int) -> bytes:
return encode_hex_string(value, typesize)
def encode_bytes_dyn(value, typesize):
def encode_bytes_dyn(value: str, typesize: int) -> bytes:
# length of the value string
# - the length of 0x (2)
# / by the length of one byte in a hex string (2)