Skip to content
This repository has been archived by the owner on Apr 29, 2024. It is now read-only.

Commit

Permalink
Added state saving (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
spietras authored Aug 27, 2022
1 parent 85bf851 commit 33c76df
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 66 deletions.
38 changes: 27 additions & 11 deletions kilroy_face_twitter/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion kilroy_face_twitter/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "kilroy-face-twitter"
version = "0.3.0"
version = "0.4.0"
description = "kilroy face for Twitter 🐦"
readme = "README.md"
authors = ["kilroy <kilroymail@pm.me>"]
Expand All @@ -18,6 +18,7 @@ httpx = "^0.23"
aiostream = "^0.4"
PyYAML = "^6.0"
deepmerge = "^1.0"
platformdirs = "^2.5"

# dev

Expand Down
66 changes: 57 additions & 9 deletions kilroy_face_twitter/src/kilroy_face_twitter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@
from asyncio import FIRST_EXCEPTION
from enum import Enum
from logging import Logger
from pathlib import Path
from typing import Dict, Optional

import typer
from kilroy_face_server_py_sdk import FaceServer
from platformdirs import user_cache_dir
from typer import FileText

from kilroy_face_twitter.config import get_config
from kilroy_face_twitter.face import TwitterFace

cli = typer.Typer() # this is actually callable and thus can be an entry point

DEFAULT_STATE_DIRECTORY = (
Path(user_cache_dir("kilroybot")) / "kilroy-face-twitter" / "state"
)


class Verbosity(str, Enum):
DEBUG = "DEBUG"
Expand All @@ -35,15 +41,40 @@ def get_logger(verbosity: Verbosity) -> Logger:
return logger


async def run(config: Dict, logger: Logger) -> None:
face_cls = TwitterFace.for_category(config.get("faceType"))
async def load_or_init(
face: TwitterFace, state_dir: Path, logger: Logger
) -> None:
if not state_dir.exists() or not any(state_dir.iterdir()):
logger.info("Initializing face...")
await face.init()
logger.info("Initialization complete.")
return

try:
logger.info("Loading state...")
await face.load_saved(state_dir)
logger.info("Loading complete.")
except OSError:
logger.warning("State directory is invalid. Will initialize instead.")
logger.info("Initializing face...")
await face.init()
logger.info("Initialization complete.")


async def run(config: Dict, logger: Logger, state_dir: Path) -> None:
face_type = config["faceType"]
face_cls = TwitterFace.for_category(face_type)
face = await face_cls.build(**config.get("faceParams", {}))
server = FaceServer(face, logger)

tasks = (
asyncio.create_task(face.init()),
asyncio.create_task(server.run(**config.get("serverParams", {}))),
state_dir = state_dir / face_type

server_task = asyncio.create_task(
server.run(**config.get("serverParams", {}))
)
init_task = asyncio.create_task(load_or_init(face, state_dir, logger))

tasks = [server_task, init_task]

try:
done, pending = await asyncio.wait(tasks, return_when=FIRST_EXCEPTION)
Expand All @@ -60,24 +91,41 @@ async def run(config: Dict, logger: Logger) -> None:
for task in done:
task.result()

await face.cleanup()
if (
init_task.done()
and not init_task.cancelled()
and init_task.exception() is None
):
logger.info("Saving state...")
await face.save(state_dir)

logger.info("Cleaning up...")
await face.cleanup()


@cli.command()
def main(
config: Optional[FileText] = typer.Option(
default=None, help="Configuration file"
None, "--config", "-c", dir_okay=False, help="Configuration file"
),
verbosity: Verbosity = typer.Option(
default="INFO", help="Verbosity level."
"INFO", "--verbosity", "-v", help="Verbosity level."
),
state_directory: Optional[Path] = typer.Option(
DEFAULT_STATE_DIRECTORY,
"--state-directory",
"-s",
file_okay=False,
writable=True,
help="Path to state directory.",
),
) -> None:
"""Command line interface for kilroy-face-twitter."""

config = get_config(config)
logger = get_logger(verbosity)

asyncio.run(run(config, logger))
asyncio.run(run(config, logger, state_directory))


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 33c76df

Please sign in to comment.