diff --git a/Patch.py b/Patch.py index 113d0658c6..588993cfd4 100644 --- a/Patch.py +++ b/Patch.py @@ -2,13 +2,13 @@ from __future__ import annotations import os import sys -from typing import Tuple, Optional, TypedDict +from typing import Tuple, Optional, TypedDict, Union if __name__ == "__main__": import ModuleUpdate ModuleUpdate.update() -from worlds.Files import AutoPatchRegister, APDeltaPatch +from worlds.Files import AutoPatchRegister, APDeltaPatch, APProcedurePatch class RomMeta(TypedDict): @@ -20,7 +20,7 @@ class RomMeta(TypedDict): def create_rom_file(patch_file: str) -> Tuple[RomMeta, str]: auto_handler = AutoPatchRegister.get_handler(patch_file) if auto_handler: - handler: APDeltaPatch = auto_handler(patch_file) + handler: Union[APDeltaPatch, APProcedurePatch] = auto_handler(patch_file) target = os.path.splitext(patch_file)[0]+handler.result_file_ending handler.patch(target) return {"server": handler.server, diff --git a/worlds/Files.py b/worlds/Files.py index ac1acbf322..da892cbd9e 100644 --- a/worlds/Files.py +++ b/worlds/Files.py @@ -1,9 +1,10 @@ from __future__ import annotations import json +import struct import zipfile -from typing import ClassVar, Dict, Tuple, Any, Optional, Union, BinaryIO +from typing import ClassVar, Dict, Tuple, Any, Optional, Union, BinaryIO, List import bsdiff4 @@ -33,6 +34,24 @@ class AutoPatchRegister(type): current_patch_version: int = 5 +class AutoPatchExtensionRegister(type): + patch_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {} + + def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchExtensionRegister: + # construct class + new_class = super().__new__(mcs, name, bases, dct) + if "game" in dct: + AutoPatchExtensionRegister.patch_types[dct["game"]] = new_class + return new_class + + @staticmethod + def get_handler(game: str) -> Optional[AutoPatchExtensionRegister]: + for patch_type, handler in AutoPatchExtensionRegister.patch_types.items(): + if patch_type == game: + return handler + return None + + class APContainer: """A zipfile containing at least archipelago.json""" version: int = current_patch_version @@ -154,3 +173,102 @@ class APDeltaPatch(APContainer, metaclass=AutoPatchRegister): result = bsdiff4.patch(self.get_source_data_with_cache(), self.delta) with open(target, "wb") as f: f.write(result) + + +class APProcedurePatch(APContainer, metaclass=AutoPatchRegister): + """ + An APContainer that defines a procedure to produce the desired file. + """ + procedure: List[str] + tokens: List[Tuple[int, bytes]] + hash: Optional[str] # base checksum of source file + source_data: bytes + patch_file_ending: str = "" + result_file_ending: str = ".sfc" + token_data: Optional[bytes] = None + + + @classmethod + def get_source_data(cls) -> bytes: + """Get Base data""" + raise NotImplementedError() + + @classmethod + def get_source_data_with_cache(cls) -> bytes: + if not hasattr(cls, "source_data"): + cls.source_data = cls.get_source_data() + return cls.source_data + + def __init__(self, *args: Any, **kwargs: Any): + self.tokens = list() + super(APProcedurePatch, self).__init__(*args, **kwargs) + + def get_manifest(self) -> Dict[str, Any]: + manifest = super(APProcedurePatch, self).get_manifest() + manifest["base_checksum"] = self.hash + manifest["result_file_ending"] = self.result_file_ending + manifest["patch_file_ending"] = self.patch_file_ending + return manifest + + def read_tokens(self) -> None: + if not self.token_data: + self.read() + token_count = struct.unpack("I", self.token_data[0:4])[0] + bpr = 4 + for _ in range(token_count): + offset = struct.unpack("I", self.token_data[bpr:bpr+4])[0] + size = struct.unpack("I", self.token_data[bpr+4:bpr+8])[0] + data = self.token_data[bpr+8:bpr+8+size] + self.tokens.append((offset, data)) + bpr += 8 + size + + def write_token(self, offset: int, data: bytes) -> None: + self.tokens.append((offset,data)) # lazy write these when we go to generate the patch + + def get_token_binary(self) -> bytes: + data = bytearray() + data.extend(struct.pack("I", len(self.tokens))) + for offset, bin_data in self.tokens: + data.extend(struct.pack("I", offset)) + data.extend(struct.pack("I", len(bin_data))) + data.extend(bin_data) + return data + + def process_token_binary(self, data: bytes) -> bytes: + self.read_tokens() + data = bytearray(data) + for offset, token_data in self.tokens: + data[offset:offset+len(token_data)] = token_data + return data + + def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None: + super(APProcedurePatch, self).read_contents(opened_zipfile) + if "token_data.bin" in opened_zipfile.namelist(): + self.token_data = opened_zipfile.read("token_data.bin") + + def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None: + super(APProcedurePatch, self).write_contents(opened_zipfile) + if len(self.tokens) > 0: + opened_zipfile.writestr("token_data.bin", self.get_token_binary()) + + def patch(self, target: str): + base_data = self.get_source_data_with_cache() + patch_extender = AutoPatchExtensionRegister.get_handler(self.game) + for step in self.procedure: + if step == "apply_tokens": + base_data = self.process_token_binary(base_data) + elif patch_extender is not None: + extension = getattr(patch_extender, step, None) + if extension is not None: + base_data = extension(base_data) + else: + raise NotImplementedError(f"Unknown procedure {step} for {self.game}.") + else: + raise NotImplementedError(f"Unknown procedure {step} for {self.game}.") + with open(target, 'wb') as f: + f.write(base_data) + + +class APPatchExtension(metaclass=AutoPatchExtensionRegister): + game: str +