Skip to content

Commit

Permalink
Split out storage delay save (home-assistant#16017)
Browse files Browse the repository at this point in the history
* Split out storage delayed write

* Update code using delayed save

* Fix tests

* Fix typing test

* Add callback decorator
  • Loading branch information
balloob authored Aug 17, 2018
1 parent fdbab3e commit 2ad0bd4
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 38 deletions.
49 changes: 27 additions & 22 deletions homeassistant/auth/auth_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, List, Optional # noqa: F401
import hmac

from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import dt as dt_util

from . import models
Expand All @@ -32,15 +32,15 @@ def __init__(self, hass: HomeAssistant) -> None:
async def async_get_users(self) -> List[models.User]:
"""Retrieve all users."""
if self._users is None:
await self.async_load()
await self._async_load()
assert self._users is not None

return list(self._users.values())

async def async_get_user(self, user_id: str) -> Optional[models.User]:
"""Retrieve a user by id."""
if self._users is None:
await self.async_load()
await self._async_load()
assert self._users is not None

return self._users.get(user_id)
Expand All @@ -52,7 +52,7 @@ async def async_create_user(
credentials: Optional[models.Credentials] = None) -> models.User:
"""Create a new user."""
if self._users is None:
await self.async_load()
await self._async_load()
assert self._users is not None

kwargs = {
Expand All @@ -73,7 +73,7 @@ async def async_create_user(
self._users[new_user.id] = new_user

if credentials is None:
await self.async_save()
self._async_schedule_save()
return new_user

# Saving is done inside the link.
Expand All @@ -84,33 +84,33 @@ async def async_link_user(self, user: models.User,
credentials: models.Credentials) -> None:
"""Add credentials to an existing user."""
user.credentials.append(credentials)
await self.async_save()
self._async_schedule_save()
credentials.is_new = False

async def async_remove_user(self, user: models.User) -> None:
"""Remove a user."""
if self._users is None:
await self.async_load()
await self._async_load()
assert self._users is not None

self._users.pop(user.id)
await self.async_save()
self._async_schedule_save()

async def async_activate_user(self, user: models.User) -> None:
"""Activate a user."""
user.is_active = True
await self.async_save()
self._async_schedule_save()

async def async_deactivate_user(self, user: models.User) -> None:
"""Activate a user."""
user.is_active = False
await self.async_save()
self._async_schedule_save()

async def async_remove_credentials(
self, credentials: models.Credentials) -> None:
"""Remove credentials."""
if self._users is None:
await self.async_load()
await self._async_load()
assert self._users is not None

for user in self._users.values():
Expand All @@ -125,22 +125,22 @@ async def async_remove_credentials(
user.credentials.pop(found)
break

await self.async_save()
self._async_schedule_save()

async def async_create_refresh_token(
self, user: models.User, client_id: Optional[str] = None) \
-> models.RefreshToken:
"""Create a new token for a user."""
refresh_token = models.RefreshToken(user=user, client_id=client_id)
user.refresh_tokens[refresh_token.id] = refresh_token
await self.async_save()
self._async_schedule_save()
return refresh_token

async def async_get_refresh_token(
self, token_id: str) -> Optional[models.RefreshToken]:
"""Get refresh token by id."""
if self._users is None:
await self.async_load()
await self._async_load()
assert self._users is not None

for user in self._users.values():
Expand All @@ -154,7 +154,7 @@ async def async_get_refresh_token_by_token(
self, token: str) -> Optional[models.RefreshToken]:
"""Get refresh token by token."""
if self._users is None:
await self.async_load()
await self._async_load()
assert self._users is not None

found = None
Expand All @@ -166,7 +166,7 @@ async def async_get_refresh_token_by_token(

return found

async def async_load(self) -> None:
async def _async_load(self) -> None:
"""Load the users."""
data = await self._store.async_load()

Expand Down Expand Up @@ -218,11 +218,18 @@ async def async_load(self) -> None:

self._users = users

async def async_save(self) -> None:
@callback
def _async_schedule_save(self) -> None:
"""Save users."""
if self._users is None:
await self.async_load()
assert self._users is not None
return

self._store.async_delay_save(self._data_to_save, 1)

@callback
def _data_to_save(self) -> Dict:
"""Return the data to store."""
assert self._users is not None

users = [
{
Expand Down Expand Up @@ -262,10 +269,8 @@ async def async_save(self) -> None:
for refresh_token in user.refresh_tokens.values()
]

data = {
return {
'users': users,
'credentials': credentials,
'refresh_tokens': refresh_tokens,
}

await self._store.async_save(data, delay=1)
14 changes: 9 additions & 5 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ async def async_remove(self, entry_id):
raise UnknownEntry

entry = self._entries.pop(found)
await self._async_schedule_save()
self._async_schedule_save()

unloaded = await entry.async_unload(self.hass)

Expand Down Expand Up @@ -391,7 +391,7 @@ async def _async_finish_flow(self, context, result):
source=context['source'],
)
self._entries.append(entry)
await self._async_schedule_save()
self._async_schedule_save()

# Setup entry
if entry.domain in self.hass.config.components:
Expand Down Expand Up @@ -439,12 +439,16 @@ async def _async_create_flow(self, handler_key, *, context, data):
flow.init_step = source
return flow

async def _async_schedule_save(self):
def _async_schedule_save(self):
"""Save the entity registry to a file."""
data = {
self._store.async_delay_save(self._data_to_save, SAVE_DELAY)

@callback
def _data_to_save(self):
"""Return data to save."""
return {
'entries': [entry.as_dict() for entry in self._entries]
}
await self._store.async_save(data, delay=SAVE_DELAY)


async def _old_conf_migrator(old_config):
Expand Down
32 changes: 25 additions & 7 deletions homeassistant/helpers/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
import logging
import os
from typing import Dict, Optional
from typing import Dict, Optional, Callable

from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback
Expand Down Expand Up @@ -76,8 +76,13 @@ async def async_load(self):

async def _async_load(self):
"""Helper to load the data."""
# Check if we have a pending write
if self._data is not None:
data = self._data

# If we didn't generate data yet, do it now.
if 'data_func' in data:
data['data'] = data.pop('data_func')()
else:
data = await self.hass.async_add_executor_job(
json.load_json, self.path)
Expand All @@ -95,20 +100,29 @@ async def _async_load(self):
self._load_task = None
return stored

async def async_save(self, data: Dict, *, delay: Optional[int] = None):
"""Save data with an optional delay."""
async def async_save(self, data):
"""Save data."""
self._data = {
'version': self.version,
'key': self.key,
'data': data,
}

self._async_cleanup_delay_listener()
self._async_cleanup_stop_listener()
await self._async_handle_write_data()

@callback
def async_delay_save(self, data_func: Callable[[], Dict],
delay: Optional[int] = None):
"""Save data with an optional delay."""
self._data = {
'version': self.version,
'key': self.key,
'data_func': data_func,
}

if delay is None:
self._async_cleanup_stop_listener()
await self._async_handle_write_data()
return
self._async_cleanup_delay_listener()

self._unsub_delay_listener = async_call_later(
self.hass, delay, self._async_callback_delayed_write)
Expand Down Expand Up @@ -151,6 +165,10 @@ async def _async_callback_stop_write(self, _event):
async def _async_handle_write_data(self, *_args):
"""Handler to handle writing the config."""
data = self._data

if 'data_func' in data:
data['data'] = data.pop('data_func')()

self._data = None

async with self._write_lock:
Expand Down
8 changes: 4 additions & 4 deletions tests/helpers/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def test_loading_parallel(hass, store, hass_storage, caplog):

async def test_saving_with_delay(hass, store, hass_storage):
"""Test saving data after a delay."""
await store.async_save(MOCK_DATA, delay=1)
store.async_delay_save(lambda: MOCK_DATA, 1)
assert store.key not in hass_storage

async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=1))
Expand All @@ -71,7 +71,7 @@ async def test_saving_with_delay(hass, store, hass_storage):
async def test_saving_on_stop(hass, hass_storage):
"""Test delayed saves trigger when we quit Home Assistant."""
store = storage.Store(hass, MOCK_VERSION, MOCK_KEY)
await store.async_save(MOCK_DATA, delay=1)
store.async_delay_save(lambda: MOCK_DATA, 1)
assert store.key not in hass_storage

hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
Expand All @@ -92,7 +92,7 @@ async def test_loading_while_delay(hass, store, hass_storage):
'data': {'delay': 'no'},
}

await store.async_save({'delay': 'yes'}, delay=1)
store.async_delay_save(lambda: {'delay': 'yes'}, 1)
assert hass_storage[store.key] == {
'version': MOCK_VERSION,
'key': MOCK_KEY,
Expand All @@ -105,7 +105,7 @@ async def test_loading_while_delay(hass, store, hass_storage):

async def test_writing_while_writing_delay(hass, store, hass_storage):
"""Test a write while a write with delay is active."""
await store.async_save({'delay': 'yes'}, delay=1)
store.async_delay_save(lambda: {'delay': 'yes'}, 1)
assert store.key not in hass_storage
await store.async_save({'delay': 'no'})
assert hass_storage[store.key] == {
Expand Down

0 comments on commit 2ad0bd4

Please sign in to comment.