Fix issue with sizeof() handling for construct Union in erofs_tool.py

This commit is contained in:
Jakob Lell 2021-09-28 14:01:47 +02:00
parent 777c42fe86
commit fe4bfba776
1 changed files with 32 additions and 2 deletions

View File

@ -21,6 +21,8 @@
import argparse import argparse
import mmap import mmap
import os import os
import construct
from construct import Struct, Int32ul, Int16ul, Int8ul, Int64ul, Array, Union from construct import Struct, Int32ul, Int16ul, Int8ul, Int64ul, Array, Union
from enum import Enum from enum import Enum
from typing import List, Set from typing import List, Set
@ -96,6 +98,34 @@ def command_file(args):
f.write(data) f.write(data)
def recursive_union_sizeof(struct) -> int:
"""
Calculates the size of a construct struct, recursing through `subcons`.
Will also work for unions with the assumption that all elements of the union have the same size
"""
if isinstance(struct, Union):
union_size = None
for item in struct.subcons:
item_size = recursive_union_sizeof(item)
if union_size is None:
union_size = item_size
elif union_size != item_size:
raise ValueError(f"Inconsistent Union size: {union_size} <=> {item_size}")
return union_size
elif isinstance(struct, Struct):
result = 0
for item in struct.subcons:
item_size = recursive_union_sizeof(item)
if item_size is None:
breakpoint()
result += recursive_union_sizeof(item)
return result
elif isinstance(struct, construct.Renamed):
return recursive_union_sizeof(struct.subcon)
else:
return struct.sizeof()
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
struct_erofs_super = Struct( struct_erofs_super = Struct(
"magic" / Int32ul, "magic" / Int32ul,
@ -180,7 +210,7 @@ struct_z_erofs_vle_decompressed_index = Struct(
"delta" / Struct("delta0" / Int16ul, "delta1" / Int16ul) "delta" / Struct("delta0" / Int16ul, "delta1" / Int16ul)
) )
) )
assert struct_z_erofs_vle_decompressed_index.sizeof() == 8 assert recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) == 8
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
@ -311,7 +341,7 @@ class Inode:
prev_blkaddr = 0 prev_blkaddr = 0
prev_reserved_blkaddr = 0 prev_reserved_blkaddr = 0
for di_number in range(num_decompressed_blocks): for di_number in range(num_decompressed_blocks):
buf = self.erofs.mmap[decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * di_number: decompress_index_header_pos + struct_z_erofs_vle_decompressed_index.sizeof() * (di_number + 1)] buf = self.erofs.mmap[decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * di_number: decompress_index_header_pos + recursive_union_sizeof(struct_z_erofs_vle_decompressed_index) * (di_number + 1)]
# print(" %s" % codecs.encode(buf, 'hex').decode()) # print(" %s" % codecs.encode(buf, 'hex').decode())
di = struct_z_erofs_vle_decompressed_index.parse(buf) di = struct_z_erofs_vle_decompressed_index.parse(buf)
if debug: if debug: