From fe4bfba776d845b65298afb5f80aa6f5afc1304d Mon Sep 17 00:00:00 2001 From: Jakob Lell Date: Tue, 28 Sep 2021 14:01:47 +0200 Subject: [PATCH] Fix issue with sizeof() handling for construct Union in erofs_tool.py --- erofs_tool.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/erofs_tool.py b/erofs_tool.py index a798569..23837b2 100755 --- a/erofs_tool.py +++ b/erofs_tool.py @@ -21,6 +21,8 @@ import argparse import mmap import os + +import construct from construct import Struct, Int32ul, Int16ul, Int8ul, Int64ul, Array, Union from enum import Enum from typing import List, Set @@ -96,6 +98,34 @@ def command_file(args): f.write(data) +def recursive_union_sizeof(struct) -> int: + """ + Calculates the size of a construct struct, recursing through `subcons`. + Will also work for unions with the assumption that all elements of the union have the same size + """ + if isinstance(struct, Union): + union_size = None + for item in struct.subcons: + item_size = recursive_union_sizeof(item) + if union_size is None: + union_size = item_size + elif union_size != item_size: + raise ValueError(f"Inconsistent Union size: {union_size} <=> {item_size}") + return union_size + elif isinstance(struct, Struct): + result = 0 + for item in struct.subcons: + item_size = recursive_union_sizeof(item) + if item_size is None: + breakpoint() + result += recursive_union_sizeof(item) + return result + elif isinstance(struct, construct.Renamed): + return recursive_union_sizeof(struct.subcon) + else: + return struct.sizeof() + + # noinspection PyUnresolvedReferences struct_erofs_super = Struct( "magic" / Int32ul, @@ -180,7 +210,7 @@ struct_z_erofs_vle_decompressed_index = Struct( "delta" / Struct("delta0" / Int16ul, "delta1" / Int16ul) ) ) -assert struct_z_erofs_vle_decompressed_index.sizeof() == 8 +assert recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) == 8 # noinspection PyUnresolvedReferences @@ -311,7 +341,7 @@ class Inode: prev_blkaddr = 0 prev_reserved_blkaddr = 0 for di_number in range(num_decompressed_blocks): - buf = self.erofs.mmap[decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * di_number: decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * (di_number + 1)] + buf = self.erofs.mmap[decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * di_number: decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * (di_number + 1)] # print(" %s" % codecs.encode(buf, 'hex').decode()) di = struct_z_erofs_vle_decompressed_index.parse(buf) if debug: