more flexibility

load default procedure for version 5 patches
add args for procedure
add default extension for tokens and bsdiff
allow specifying additional required extensions for generation
This commit is contained in:
Silvris
2023-07-16 21:10:37 -05:00
parent 3fbbb4f361
commit e0bc4cfa20

View File

@@ -31,25 +31,35 @@ class AutoPatchRegister(type):
return None
current_patch_version: int = 5
current_patch_version: int = 6
class AutoPatchExtensionRegister(type):
patch_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {}
extension_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {}
required_extensions: List[str]
def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchExtensionRegister:
# construct class
new_class = super().__new__(mcs, name, bases, dct)
if "game" in dct:
AutoPatchExtensionRegister.patch_types[dct["game"]] = new_class
AutoPatchExtensionRegister.extension_types[dct["game"]] = new_class
return new_class
@staticmethod
def get_handler(game: str) -> Optional[AutoPatchExtensionRegister]:
for patch_type, handler in AutoPatchExtensionRegister.patch_types.items():
if patch_type == game:
return handler
return None
def get_handler(game: str) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]:
for extension_type, handler in AutoPatchExtensionRegister.extension_types.items():
if extension_type == game:
if len(handler.required_extensions) > 0:
handlers = [handler]
for required in handler.required_extensions:
if required in AutoPatchExtensionRegister.extension_types:
handlers.append(AutoPatchExtensionRegister.extension_types[required])
else:
raise NotImplementedError(f"No handler for {required}.")
return handlers
else:
return handler
return APPatchExtension
class APContainer:
@@ -179,14 +189,13 @@ class APProcedurePatch(APContainer, metaclass=AutoPatchRegister):
"""
An APContainer that defines a procedure to produce the desired file.
"""
procedure: List[str]
procedure: List[Tuple[str, List[Any]]]
tokens: List[Tuple[int, bytes]]
hash: Optional[str] # base checksum of source file
source_data: bytes
patch_file_ending: str = ""
result_file_ending: str = ".sfc"
token_data: Optional[bytes] = None
files: Dict[str, bytes] = dict()
@classmethod
def get_source_data(cls) -> bytes:
@@ -199,31 +208,44 @@ class APProcedurePatch(APContainer, metaclass=AutoPatchRegister):
cls.source_data = cls.get_source_data()
return cls.source_data
def __init__(self, *args: Any, **kwargs: Any):
self.tokens = list()
def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any):
super(APProcedurePatch, self).__init__(*args, **kwargs)
self.tokens = list()
def get_manifest(self) -> Dict[str, Any]:
manifest = super(APProcedurePatch, self).get_manifest()
manifest["compatible_version"] = 6
manifest["base_checksum"] = self.hash
manifest["result_file_ending"] = self.result_file_ending
manifest["patch_file_ending"] = self.patch_file_ending
manifest["procedure"] = self.procedure
return manifest
def read_tokens(self) -> None:
if not self.token_data:
self.read()
token_count = struct.unpack("I", self.token_data[0:4])[0]
bpr = 4
for _ in range(token_count):
offset = struct.unpack("I", self.token_data[bpr:bpr+4])[0]
size = struct.unpack("I", self.token_data[bpr+4:bpr+8])[0]
data = self.token_data[bpr+8:bpr+8+size]
self.tokens.append((offset, data))
bpr += 8 + size
def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).read_contents(opened_zipfile)
with opened_zipfile.open("archipelago.json", "r") as f:
manifest = json.load(f)
if manifest["version"] < 6:
# support patching files made before moving to procedures
self.procedure = [("apply_bsdiff4", ["delta.bsdiff4"])]
else:
self.procedure = manifest["procedure"]
for file in opened_zipfile.namelist():
if file not in ["archipelago.json"]:
self.files[file] = opened_zipfile.read(file)
def write_token(self, offset: int, data: bytes) -> None:
self.tokens.append((offset,data)) # lazy write these when we go to generate the patch
def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).write_contents(opened_zipfile)
for file in self.files:
opened_zipfile.writestr(file, self.files[file])
def get_file(self, file: str) -> bytes:
if file not in self.files:
self.read()
return self.files[file]
def write_file(self, file_name: str, file: bytes) -> None:
self.files[file_name] = file
def get_token_binary(self) -> bytes:
data = bytearray()
@@ -234,35 +256,21 @@ class APProcedurePatch(APContainer, metaclass=AutoPatchRegister):
data.extend(bin_data)
return data
def process_token_binary(self, data: bytes) -> bytes:
self.read_tokens()
data = bytearray(data)
for offset, token_data in self.tokens:
data[offset:offset+len(token_data)] = token_data
return data
def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).read_contents(opened_zipfile)
if "token_data.bin" in opened_zipfile.namelist():
self.token_data = opened_zipfile.read("token_data.bin")
def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
super(APProcedurePatch, self).write_contents(opened_zipfile)
if len(self.tokens) > 0:
opened_zipfile.writestr("token_data.bin", self.get_token_binary())
def write_token(self, offset, data):
self.tokens.append((offset, data))
def patch(self, target: str):
self.read()
base_data = self.get_source_data_with_cache()
patch_extender = AutoPatchExtensionRegister.get_handler(self.game)
for step in self.procedure:
if step == "apply_tokens":
base_data = self.process_token_binary(base_data)
elif patch_extender is not None:
for step, args in self.procedure:
if isinstance(patch_extender, List):
extension = next((item for item in [getattr(extender, step, None) for extender in patch_extender]
if item is not None), None)
else:
extension = getattr(patch_extender, step, None)
if extension is not None:
base_data = extension(base_data)
else:
raise NotImplementedError(f"Unknown procedure {step} for {self.game}.")
if extension is not None:
base_data = extension(self, base_data, *args)
else:
raise NotImplementedError(f"Unknown procedure {step} for {self.game}.")
with open(target, 'wb') as f:
@@ -271,4 +279,22 @@ class APProcedurePatch(APContainer, metaclass=AutoPatchRegister):
class APPatchExtension(metaclass=AutoPatchExtensionRegister):
game: str
required_extensions: List[str] = list()
@staticmethod
def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str):
return bsdiff4.patch(rom, caller.get_file(patch))
@staticmethod
def apply_tokens(caller: APProcedurePatch, rom: bytes, token_file: str) -> bytes:
token_data = caller.get_file(token_file)
rom_data = bytearray(rom)
token_count = struct.unpack("I", token_data[0:4])[0]
bpr = 4
for _ in range(token_count):
offset = struct.unpack("I", token_data[bpr:bpr + 4])[0]
size = struct.unpack("I", token_data[bpr + 4:bpr + 8])[0]
data = token_data[bpr + 8:bpr + 8 + size]
rom_data[offset:offset + len(data)] = data
bpr += 8 + size
return rom_data