This commit is contained in:
CookieCat
2023-10-18 23:33:05 -04:00
parent 8b98cd7a01
commit d043c0d0f6
9 changed files with 747 additions and 65 deletions

View File

@@ -3,7 +3,7 @@ import Utils
import websockets
import functools
from copy import deepcopy
from typing import List, Any, Iterable
from typing import List, Any, Iterable, Dict
from NetUtils import decode, encode, JSONtoTextParser, JSONMessagePart, NetworkItem
from MultiServer import Endpoint
from CommonClient import CommonContext, gui_enabled, ClientCommandProcessor, logger, \
@@ -70,12 +70,18 @@ class AHITContext(CommonContext):
await self.endpoint.socket.send(msgs)
return True
async def disconnect(self, allow_autoreconnect: bool = False):
await super().disconnect(allow_autoreconnect)
async def disconnect_proxy(self):
if self.endpoint and not self.endpoint.socket.closed:
await self.endpoint.socket.close()
if self.proxy_task is not None:
await self.proxy_task
def is_connected(self) -> bool:
return self.server and self.server.socket.open
def is_proxy_connected(self) -> bool:
return self.endpoint and self.endpoint.socket.open
@@ -91,6 +97,10 @@ class AHITContext(CommonContext):
logger.info(text)
def update_items(self):
# just to be safe - we might still have an inventory from a different room
if not self.is_connected():
return
self.server_msgs.append(encode([{"cmd": "ReceivedItems", "index": 0, "items": self.full_inventory}]))
def on_package(self, cmd: str, args: dict):
@@ -118,6 +128,10 @@ class AHITContext(CommonContext):
if cmd != "PrintJSON":
self.server_msgs.append(encode([args]))
# def on_deathlink(self, data: Dict[str, Any]):
# self.server_msgs.append(encode([data]))
# super().on_deathlink(data)
def run_gui(self):
from kvui import GameManager
@@ -147,13 +161,17 @@ async def proxy(websocket, path: str = "/", ctx: AHITContext = None):
break
if ctx.seed_name:
seed = msg.get("seed", "")
if seed != "" and seed != ctx.seed_name:
seed_name = msg.get("seed_name", "")
if seed_name != "" and seed_name != ctx.seed_name:
logger.info("Aborting proxy connection: seed mismatch from save file")
logger.info(f"Expected: {ctx.seed_name}, got: {seed_name}")
text = encode([{"cmd": "PrintJSON",
"data": [{"text": "Connection aborted - save file to seed mismatch"}]}])
await ctx.send_msgs_proxy(text)
await ctx.disconnect_proxy()
break
if ctx.connected_msg:
if ctx.connected_msg and ctx.is_connected():
await ctx.send_msgs_proxy(ctx.connected_msg)
ctx.update_items()
continue
@@ -174,7 +192,7 @@ async def proxy(websocket, path: str = "/", ctx: AHITContext = None):
async def on_client_connected(ctx: AHITContext):
if ctx.room_info:
if ctx.room_info and ctx.is_connected():
await ctx.send_msgs_proxy(ctx.room_info)
else:
ctx.awaiting_info = True
@@ -186,7 +204,8 @@ async def main():
ctx = AHITContext(args.connect, args.password)
logger.info("Starting A Hat in Time proxy server")
ctx.proxy = websockets.serve(functools.partial(proxy, ctx=ctx), host="localhost", port=11311)
ctx.proxy = websockets.serve(functools.partial(proxy, ctx=ctx),
host="localhost", port=11311, ping_timeout=999999, ping_interval=999999)
ctx.proxy_task = asyncio.create_task(proxy_loop(ctx), name="ProxyLoop")
if gui_enabled: