Fix issue with sizeof() handling for construct Union in erofs_tool.py

This commit is contained in:
Jakob Lell 2021-09-28 14:01:47 +02:00
parent 777c42fe86
commit fe4bfba776
1 changed files with 32 additions and 2 deletions

View File

@ -21,6 +21,8 @@
import argparse
import mmap
import os
import construct
from construct import Struct, Int32ul, Int16ul, Int8ul, Int64ul, Array, Union
from enum import Enum
from typing import List, Set
@ -96,6 +98,34 @@ def command_file(args):
f.write(data)
def recursive_union_sizeof(struct) -> int:
"""
Calculates the size of a construct struct, recursing through `subcons`.
Will also work for unions with the assumption that all elements of the union have the same size
"""
if isinstance(struct, Union):
union_size = None
for item in struct.subcons:
item_size = recursive_union_sizeof(item)
if union_size is None:
union_size = item_size
elif union_size != item_size:
raise ValueError(f"Inconsistent Union size: {union_size} <=> {item_size}")
return union_size
elif isinstance(struct, Struct):
result = 0
for item in struct.subcons:
item_size = recursive_union_sizeof(item)
if item_size is None:
breakpoint()
result += recursive_union_sizeof(item)
return result
elif isinstance(struct, construct.Renamed):
return recursive_union_sizeof(struct.subcon)
else:
return struct.sizeof()
# noinspection PyUnresolvedReferences
struct_erofs_super = Struct(
"magic" / Int32ul,
@ -180,7 +210,7 @@ struct_z_erofs_vle_decompressed_index = Struct(
"delta" / Struct("delta0" / Int16ul, "delta1" / Int16ul)
)
)
assert struct_z_erofs_vle_decompressed_index.sizeof() == 8
assert recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) == 8
# noinspection PyUnresolvedReferences
@ -311,7 +341,7 @@ class Inode:
prev_blkaddr = 0
prev_reserved_blkaddr = 0
for di_number in range(num_decompressed_blocks):
buf = self.erofs.mmap[decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * di_number: decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * (di_number + 1)]
buf = self.erofs.mmap[decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * di_number: decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * (di_number + 1)]
# print(" %s" % codecs.encode(buf, 'hex').decode())
di = struct_z_erofs_vle_decompressed_index.parse(buf)
if debug: