diff --git a/ols/src/quota/quota_limiter.py b/ols/src/quota/quota_limiter.py index 69747752..3f3213e7 100644 --- a/ols/src/quota/quota_limiter.py +++ b/ols/src/quota/quota_limiter.py @@ -7,15 +7,17 @@ class QuotaLimiter(ABC): """Abstract class that is parent for all quota limiter implementations.""" @abstractmethod - def available_quota(self, user_id: str) -> int: + def available_quota(self) -> int: """Retrieve available quota for given user.""" @abstractmethod - def revoke_quota(self, user_id: str) -> None: + def revoke_quota(self) -> None: """Revoke quota for given user.""" @abstractmethod - def consume_tokens( - self, user_id: str, input_tokens: int, output_tokens: int - ) -> None: + def increase_quota(self) -> None: + """Increase quota for given user.""" + + @abstractmethod + def consume_tokens(self, input_tokens: int, output_tokens: int) -> None: """Consume tokens by given user.""" diff --git a/ols/src/quota/user_quota_limiter.py b/ols/src/quota/user_quota_limiter.py index 470d7e19..71b11f2c 100644 --- a/ols/src/quota/user_quota_limiter.py +++ b/ols/src/quota/user_quota_limiter.py @@ -37,15 +37,24 @@ class UserQuotaLimiter(QuotaLimiter): WHERE user_id=%s LIMIT 1 """ - UPDATE_AVAILABLE_QUOTA_FOR_USER = """ + SET_AVAILABLE_QUOTA_FOR_USER = """ UPDATE user_quota_limiter SET available=%s, updated_at=%s WHERE user_id=%s """ - def __init__(self, config: PostgresConfig, initial_quota: int) -> None: + UPDATE_AVAILABLE_QUOTA_FOR_USER = """ + UPDATE user_quota_limiter + SET available=available+%s, updated_at=%s + WHERE user_id=%s + """ + + def __init__( + self, config: PostgresConfig, initial_quota: int = 0, increase_by: int = 0 + ) -> None: """Initialize quota limiter storage.""" self.initial_quota = initial_quota + self.increase_by = increase_by # initialize connection to DB self.connection = psycopg2.connect( @@ -66,7 +75,7 @@ def __init__(self, config: PostgresConfig, initial_quota: int) -> None: logger.exception("Error initializing Postgres database:\n%s", e) raise - def available_quota(self, user_id: str) -> int: + def available_quota(self, user_id: str = "") -> int: """Retrieve available quota for given user.""" with self.connection.cursor() as cursor: cursor.execute( @@ -79,20 +88,32 @@ def available_quota(self, user_id: str) -> int: return self.initial_quota return value[0] - def revoke_quota(self, user_id: str) -> None: + def revoke_quota(self, user_id: str = "") -> None: """Revoke quota for given user.""" # timestamp to be used updated_at = datetime.now() with self.connection.cursor() as cursor: cursor.execute( - UserQuotaLimiter.UPDATE_AVAILABLE_QUOTA_FOR_USER, + UserQuotaLimiter.SET_AVAILABLE_QUOTA_FOR_USER, (self.initial_quota, updated_at, user_id), ) self.connection.commit() + def increase_quota(self, user_id: str = "") -> None: + """Increase quota for given user.""" + # timestamp to be used + updated_at = datetime.now() + + with self.connection.cursor() as cursor: + cursor.execute( + UserQuotaLimiter.UPDATE_AVAILABLE_QUOTA_FOR_USER, + (self.increase_by, updated_at, user_id), + ) + self.connection.commit() + def consume_tokens( - self, user_id: str, input_tokens: int, output_tokens: int + self, input_tokens: int = 0, output_tokens: int = 0, user_id: str = "" ) -> None: """Consume tokens by given user.""" to_be_consumed = input_tokens + output_tokens @@ -104,16 +125,13 @@ def consume_tokens( logger.exception("Quota exceed: %s", e) raise e - # update available tokens for user - available -= to_be_consumed - # timestamp to be used updated_at = datetime.now() with self.connection.cursor() as cursor: cursor.execute( UserQuotaLimiter.UPDATE_AVAILABLE_QUOTA_FOR_USER, - (available, updated_at, user_id), + (-to_be_consumed, updated_at, user_id), ) self.connection.commit()