diff --git a/src/utils/rate_limiter.py b/src/utils/rate_limiter.py index 2b3a647..431af60 100644 --- a/src/utils/rate_limiter.py +++ b/src/utils/rate_limiter.py @@ -21,7 +21,7 @@ from collections import deque from contextlib import AbstractContextManager from threading import BoundedSemaphore, Lock, Thread from time import sleep, time -from typing import Any, Optional, Sized +from typing import Any, Sized class PickHistory(Sized): @@ -79,16 +79,23 @@ class PickHistory(Sized): # pylint: disable=too-many-instance-attributes class RateLimiter(AbstractContextManager): - """Rate limiter implementing the token bucket algorithm""" + """ + Base rate limiter implementing the token bucket algorithm. + + Do not use directly, create a child class to tailor the rate limiting to the + underlying service's limits. + + Subclasses must provide values to the following attributes: + * refill_period_seconds - Period in which we have a max amount of tokens + * refill_period_tokens - Number of tokens allowed in this period + * burst_tokens - Max number of tokens that can be consumed instantly + """ - # Period in which we have a max amount of tokens refill_period_seconds: int - # Number of tokens allowed in this period refill_period_tokens: int - # Max number of tokens that can be consumed instantly burst_tokens: int - pick_history: Optional[PickHistory] = None # TODO: Geoff: make this required + pick_history: PickHistory bucket: BoundedSemaphore queue: deque[Lock] queue_lock: Lock @@ -107,23 +114,20 @@ class RateLimiter(AbstractContextManager): with self.__n_tokens_lock: self.__n_tokens = value - def __init__( - self, - refill_period_seconds: Optional[int] = None, - refill_period_tokens: Optional[int] = None, - burst_tokens: Optional[int] = None, - ) -> None: + def _init_pick_history(self) -> None: + """ + Initialize the tocken pick history + (only for use in this class and its children) + + By default, creates an empty pick history. + Should be overriden or extended by subclasses. + """ + self.pick_history = PickHistory(self.refill_period_seconds) + + def __init__(self) -> None: """Initialize the limiter""" - # Initialize default values - if refill_period_seconds is not None: - self.refill_period_seconds = refill_period_seconds - if refill_period_tokens is not None: - self.refill_period_tokens = refill_period_tokens - if burst_tokens is not None: - self.burst_tokens = burst_tokens - if self.pick_history is None: - self.pick_history = PickHistory(self.refill_period_seconds) + self._init_pick_history() # Create synchronization data self.__n_tokens_lock = Lock() diff --git a/src/utils/steam.py b/src/utils/steam.py index 456869c..bdc8d84 100644 --- a/src/utils/steam.py +++ b/src/utils/steam.py @@ -28,7 +28,7 @@ import requests from requests.exceptions import HTTPError from src import shared -from src.utils.rate_limiter import PickHistory, RateLimiter +from src.utils.rate_limiter import RateLimiter class SteamError(Exception): @@ -72,14 +72,16 @@ class SteamRateLimiter(RateLimiter): refill_period_tokens = 200 burst_tokens = 100 - def __init__(self) -> None: - # Load pick history from schema - # (Remember API limits through restarts of Cartridges) + def _init_pick_history(self) -> None: + """ + Load the pick history from schema. + + Allows remembering API limits through restarts of Cartridges. + """ + super()._init_pick_history() timestamps_str = shared.state_schema.get_string("steam-limiter-tokens-history") - self.pick_history = PickHistory(self.refill_period_seconds) self.pick_history.add(*json.loads(timestamps_str)) self.pick_history.remove_old_entries() - super().__init__() def acquire(self) -> None: """Get a token from the bucket and store the pick history in the schema""" @@ -91,9 +93,7 @@ class SteamRateLimiter(RateLimiter): class SteamFileHelper: """Helper for steam file formats""" - def get_manifest_data( - self, manifest_path: Path - ) -> SteamManifestData: # TODO: Geoff: fix typing issue + def get_manifest_data(self, manifest_path: Path) -> SteamManifestData: """Get local data for a game from its manifest""" with open(manifest_path, "r", encoding="utf-8") as file: @@ -107,7 +107,11 @@ class SteamFileHelper: raise SteamInvalidManifestError() data[key] = match.group(1) - return SteamManifestData(**data) + return SteamManifestData( + name=data["name"], + appid=data["appid"], + stateflags=data["stateflags"], + ) class SteamAPIHelper: