Files
dockipelago/worlds/AutoSNIClient.py
Doug Hoskisson 775065715d SNIClient: new SnesReader interface (#5155)
* SNIClient: new SnesReader interface

* fix Python 3.8 compatibility
`bisect_right`

* move to worlds
because we don't have good separation importable modules and entry points

* `read` gives object that contains data

* remove python 3.10 implementation and update typing

* remove obsolete comment

* freeze _MemRead and assert type of get parameter

* some optimization in `SnesData.get`

* pass context to `read` so that we can have a static instance of `SnesReader`

* add docstring to `SnesReader`

* remove unused import

* break big reads into chunks

* some minor improvements

- `dataclass` instead of `NamedTuple` for `Read`
- comprehension in `SnesData.__init__`
- `slots` for dataclasses

* update chunk size to 2048
2025-11-30 01:22:35 +01:00

222 lines
7.8 KiB
Python

from __future__ import annotations
import abc
from bisect import bisect_right
from dataclasses import dataclass
import enum
import logging
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterable,
Optional, Sequence, Tuple, TypeGuard, TypeVar, Union)
from worlds.LauncherComponents import Component, SuffixIdentifier, Type, components
if TYPE_CHECKING:
from SNIClient import SNIContext
SNES_READ_CHUNK_SIZE = 2048
"""
note: SNI v0.0.101 currently has a bug where reads from
RetroArch >2048 bytes will only return the last ~2048 bytes read.
https://github.com/alttpo/sni/issues/51
"""
component = Component('SNI Client', 'SNIClient', component_type=Type.CLIENT, file_identifier=SuffixIdentifier(".apsoe"),
description="A client for connecting to SNES consoles via Super Nintendo Interface.")
components.append(component)
def valid_patch_suffix(obj: object) -> TypeGuard[Union[str, Iterable[str]]]:
""" make sure this is a valid value for the class variable `patch_suffix` """
def valid_individual(one: object) -> TypeGuard[str]:
""" check an individual suffix """
# TODO: decide: len(one) > 3 and one.startswith(".ap") ?
# or keep it more general?
return isinstance(one, str) and len(one) > 1 and one.startswith(".")
if isinstance(obj, str):
return valid_individual(obj)
if not isinstance(obj, Iterable):
return False
obj_it: Iterable[object] = obj
return all(valid_individual(each) for each in obj_it)
class AutoSNIClientRegister(abc.ABCMeta):
game_handlers: ClassVar[Dict[str, SNIClient]] = {}
def __new__(cls, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoSNIClientRegister:
# construct class
new_class = super().__new__(cls, name, bases, dct)
if "game" in dct:
AutoSNIClientRegister.game_handlers[dct["game"]] = new_class()
if "patch_suffix" in dct:
patch_suffix = dct["patch_suffix"]
assert valid_patch_suffix(patch_suffix), f"class {name} defining invalid {patch_suffix=}"
existing_identifier = component.file_identifier
assert isinstance(existing_identifier, SuffixIdentifier), f"{existing_identifier=}"
new_suffixes = [*existing_identifier.suffixes]
if isinstance(patch_suffix, str):
new_suffixes.append(patch_suffix)
else:
new_suffixes.extend(patch_suffix)
component.file_identifier = SuffixIdentifier(*new_suffixes)
return new_class
@staticmethod
async def get_handler(ctx: SNIContext) -> Optional[SNIClient]:
for _game, handler in AutoSNIClientRegister.game_handlers.items():
try:
if await handler.validate_rom(ctx):
return handler
except Exception as e:
text_file_logger = logging.getLogger()
text_file_logger.exception(e)
return None
class SNIClient(abc.ABC, metaclass=AutoSNIClientRegister):
patch_suffix: ClassVar[Union[str, Iterable[str]]] = ()
"""The file extension(s) this client is meant to open and patch (e.g. ".aplttp")"""
@abc.abstractmethod
async def validate_rom(self, ctx: SNIContext) -> bool:
""" TODO: interface documentation here """
...
@abc.abstractmethod
async def game_watcher(self, ctx: SNIContext) -> None:
""" TODO: interface documentation here """
...
async def deathlink_kill_player(self, ctx: SNIContext) -> None:
""" override this with implementation to kill player """
pass
def on_package(self, ctx: SNIContext, cmd: str, args: Dict[str, Any]) -> None:
""" override this with code to handle packages from the server """
pass
@dataclass(frozen=True, slots=True, order=True)
class Read:
""" snes memory read - address and size in bytes """
address: int
size: int
@dataclass(frozen=True, slots=True)
class _MemRead:
location: Read
data: bytes
_T_Enum = TypeVar("_T_Enum", bound=enum.Enum)
class SnesData(Generic[_T_Enum]):
_ranges: Sequence[_MemRead]
""" sorted by address """
def __init__(self, ranges: Sequence[tuple[Read, bytes]]) -> None:
self._ranges = [_MemRead(r, d) for r, d in ranges]
def get(self, read: _T_Enum) -> bytes:
assert isinstance(read.value, Read), read.value
address = read.value.address
index = bisect_right(self._ranges, address, key=lambda r: r.location.address) - 1
assert index >= 0, (self._ranges, read.value)
mem_read = self._ranges[index]
sub_index = address - mem_read.location.address
return mem_read.data[sub_index:sub_index + read.value.size]
class SnesReader(Generic[_T_Enum]):
"""
how to use:
```
from enum import Enum
from worlds.AutoSNIClient import Read, SNIClient, SnesReader
class MyGameMemory(Enum):
game_mode = Read(WRAM_START + 0x0998, 1)
send_queue = Read(SEND_QUEUE_START, 8 * 127)
...
snes_reader = SnesReader(MyGameMemory)
snes_data = await snes_reader.read(ctx)
if snes_data is None:
snes_logger.info("error reading from snes")
return
game_mode = snes_data.get(MyGameMemory.game_mode)
```
"""
_ranges: Sequence[Read]
""" sorted by address """
def __init__(self, reads: type[_T_Enum]) -> None:
self._ranges = self._make_ranges(reads)
@staticmethod
def _make_ranges(reads: type[enum.Enum]) -> Sequence[Read]:
unprocessed_reads: list[Read] = []
for e in reads:
assert isinstance(e.value, Read), (reads.__name__, e, e.value)
unprocessed_reads.append(e.value)
unprocessed_reads.sort()
ranges: list[Read] = []
for read in unprocessed_reads:
# v end of the previous range
if len(ranges) == 0 or read.address - (ranges[-1].address + ranges[-1].size) > 255:
ranges.append(read)
else: # combine with previous range
chunk_address = ranges[-1].address
assert read.address >= chunk_address, "sort() didn't work? or something"
original_chunk_size = ranges[-1].size
new_size = max((read.address + read.size) - chunk_address,
original_chunk_size)
ranges[-1] = Read(chunk_address, new_size)
logging.debug(f"{len(ranges)=} {max(r.size for r in ranges)=}")
return ranges
async def read(self, ctx: "SNIContext") -> SnesData[_T_Enum] | None:
"""
returns `None` if reading fails,
otherwise returns the data for the registered `Enum`
"""
from SNIClient import snes_read
reads: list[tuple[Read, bytes]] = []
for r in self._ranges:
if r.size < SNES_READ_CHUNK_SIZE: # most common
response = await snes_read(ctx, r.address, r.size)
if response is None:
return None
reads.append((r, response))
else: # big read
# Problems were reported with big reads,
# so we chunk it into smaller pieces.
read_so_far = 0
collection: list[bytes] = []
while read_so_far < r.size:
remaining_size = r.size - read_so_far
chunk_size = min(SNES_READ_CHUNK_SIZE, remaining_size)
response = await snes_read(ctx, r.address + read_so_far, chunk_size)
if response is None:
return None
collection.append(response)
read_so_far += chunk_size
reads.append((r, b"".join(collection)))
return SnesData(reads)