Fix issue with sizeof() handling for construct Union in erofs_tool.py
This commit is contained in:
parent
777c42fe86
commit
fe4bfba776
|
@ -21,6 +21,8 @@
|
|||
import argparse
|
||||
import mmap
|
||||
import os
|
||||
|
||||
import construct
|
||||
from construct import Struct, Int32ul, Int16ul, Int8ul, Int64ul, Array, Union
|
||||
from enum import Enum
|
||||
from typing import List, Set
|
||||
|
@ -96,6 +98,34 @@ def command_file(args):
|
|||
f.write(data)
|
||||
|
||||
|
||||
def recursive_union_sizeof(struct) -> int:
|
||||
"""
|
||||
Calculates the size of a construct struct, recursing through `subcons`.
|
||||
Will also work for unions with the assumption that all elements of the union have the same size
|
||||
"""
|
||||
if isinstance(struct, Union):
|
||||
union_size = None
|
||||
for item in struct.subcons:
|
||||
item_size = recursive_union_sizeof(item)
|
||||
if union_size is None:
|
||||
union_size = item_size
|
||||
elif union_size != item_size:
|
||||
raise ValueError(f"Inconsistent Union size: {union_size} <=> {item_size}")
|
||||
return union_size
|
||||
elif isinstance(struct, Struct):
|
||||
result = 0
|
||||
for item in struct.subcons:
|
||||
item_size = recursive_union_sizeof(item)
|
||||
if item_size is None:
|
||||
breakpoint()
|
||||
result += recursive_union_sizeof(item)
|
||||
return result
|
||||
elif isinstance(struct, construct.Renamed):
|
||||
return recursive_union_sizeof(struct.subcon)
|
||||
else:
|
||||
return struct.sizeof()
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
struct_erofs_super = Struct(
|
||||
"magic" / Int32ul,
|
||||
|
@ -180,7 +210,7 @@ struct_z_erofs_vle_decompressed_index = Struct(
|
|||
"delta" / Struct("delta0" / Int16ul, "delta1" / Int16ul)
|
||||
)
|
||||
)
|
||||
assert struct_z_erofs_vle_decompressed_index.sizeof() == 8
|
||||
assert recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) == 8
|
||||
|
||||
|
||||
# noinspection PyUnresolvedReferences
|
||||
|
@ -311,7 +341,7 @@ class Inode:
|
|||
prev_blkaddr = 0
|
||||
prev_reserved_blkaddr = 0
|
||||
for di_number in range(num_decompressed_blocks):
|
||||
buf = self.erofs.mmap[decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * di_number: decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * (di_number + 1)]
|
||||
buf = self.erofs.mmap[decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * di_number: decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * (di_number + 1)]
|
||||
# print(" %s" % codecs.encode(buf, 'hex').decode())
|
||||
di = struct_z_erofs_vle_decompressed_index.parse(buf)
|
||||
if debug:
|
||||
|
|
Loading…
Reference in New Issue