diff --git a/icechunk-python/python/icechunk/_icechunk_python.pyi b/icechunk-python/python/icechunk/_icechunk_python.pyi index c98fed8d..9c13fb6a 100644 --- a/icechunk-python/python/icechunk/_icechunk_python.pyi +++ b/icechunk-python/python/icechunk/_icechunk_python.pyi @@ -983,6 +983,7 @@ class PyRepository: *, tag: str | None = None, snapshot_id: str | None = None, + as_of: datetime.datetime | None = None, ) -> PySession: ... def writable_session(self, branch: str) -> PySession: ... def expire_snapshots( diff --git a/icechunk-python/python/icechunk/repository.py b/icechunk-python/python/icechunk/repository.py index 3d26dfb1..ceecb7d7 100644 --- a/icechunk-python/python/icechunk/repository.py +++ b/icechunk-python/python/icechunk/repository.py @@ -437,6 +437,7 @@ def readonly_session( *, tag: str | None = None, snapshot_id: str | None = None, + as_of: datetime.datetime | None = None, ) -> Session: """ Create a read-only session. @@ -453,6 +454,9 @@ def readonly_session( If provided, the tag to create the session on. snapshot_id : str, optional If provided, the snapshot ID to create the session on. 
+ as_of : datetime.datetime, optional + When combined with the branch argument, it will open the session at the last + snapshot that is at or before this datetime. Returns ------- @@ -465,7 +469,7 @@ def readonly_session( """ return Session( self._repository.readonly_session( - branch=branch, tag=tag, snapshot_id=snapshot_id + branch=branch, tag=tag, snapshot_id=snapshot_id, as_of=as_of ) ) diff --git a/icechunk-python/src/repository.rs b/icechunk-python/src/repository.rs index ed11d0c2..d1275a47 100644 --- a/icechunk-python/src/repository.rs +++ b/icechunk-python/src/repository.rs @@ -517,7 +517,7 @@ impl PyRepository { let repo = Arc::clone(&self.0); // This function calls block_on, so we need to allow other thread python to make progress py.allow_threads(move || { - let version = args_to_version_info(branch, tag, snapshot_id)?; + let version = args_to_version_info(branch, tag, snapshot_id, None)?; let ancestry = pyo3_async_runtimes::tokio::get_runtime() .block_on(async move { repo.ancestry_arc(&version).await }) .map_err(PyIcechunkStoreError::RepositoryError)? 
@@ -701,8 +701,8 @@ impl PyRepository { to_tag: Option, to_snapshot_id: Option, ) -> PyResult { - let from = args_to_version_info(from_branch, from_tag, from_snapshot_id)?; - let to = args_to_version_info(to_branch, to_tag, to_snapshot_id)?; + let from = args_to_version_info(from_branch, from_tag, from_snapshot_id, None)?; + let to = args_to_version_info(to_branch, to_tag, to_snapshot_id, None)?; // This function calls block_on, so we need to allow other thread python to make progress py.allow_threads(move || { @@ -717,17 +717,18 @@ impl PyRepository { }) } - #[pyo3(signature = (*, branch = None, tag = None, snapshot_id = None))] + #[pyo3(signature = (*, branch = None, tag = None, snapshot_id = None, as_of = None))] pub fn readonly_session( &self, py: Python<'_>, branch: Option, tag: Option, snapshot_id: Option, + as_of: Option>, ) -> PyResult { // This function calls block_on, so we need to allow other thread python to make progress py.allow_threads(move || { - let version = args_to_version_info(branch, tag, snapshot_id)?; + let version = args_to_version_info(branch, tag, snapshot_id, as_of)?; let session = pyo3_async_runtimes::tokio::get_runtime().block_on(async move { self.0 @@ -841,6 +842,7 @@ fn args_to_version_info( branch: Option, tag: Option, snapshot: Option, + as_of: Option>, ) -> PyResult { let n = [&branch, &tag, &snapshot].iter().filter(|r| !r.is_none()).count(); if n > 1 { @@ -849,8 +851,18 @@ fn args_to_version_info( )); } - if let Some(branch_name) = branch { - Ok(VersionInfo::BranchTipRef(branch_name)) + if as_of.is_some() && branch.is_none() { + return Err(PyValueError::new_err( + "as_of argument must be provided together with a branch name", + )); + } + + if let Some(branch) = branch { + if let Some(at) = as_of { + Ok(VersionInfo::AsOf { branch, at }) + } else { + Ok(VersionInfo::BranchTipRef(branch)) + } } else if let Some(tag_name) = tag { Ok(VersionInfo::TagRef(tag_name)) } else if let Some(snapshot_id) = snapshot { diff --git 
a/icechunk-python/tests/test_timetravel.py b/icechunk-python/tests/test_timetravel.py index 79303a74..20074be5 100644 --- a/icechunk-python/tests/test_timetravel.py +++ b/icechunk-python/tests/test_timetravel.py @@ -258,3 +258,38 @@ async def test_tag_delete() -> None: with pytest.raises(ValueError): repo.create_tag("tag", snap) + + +async def test_session_with_as_of() -> None: + repo = ic.Repository.create( + storage=ic.in_memory_storage(), + ) + + session = repo.writable_session("main") + store = session.store + + times = [] + group = zarr.group(store=store, overwrite=True) + sid = session.commit("root") + times.append(next(repo.ancestry(snapshot_id=sid)).written_at) + + for i in range(5): + session = repo.writable_session("main") + store = session.store + group = zarr.open_group(store=store) + group.create_group(f"child {i}") + sid = session.commit(f"child {i}") + times.append(next(repo.ancestry(snapshot_id=sid)).written_at) + + ancestry = list(repo.ancestry(branch="main")) + assert len(ancestry) == 7 # initial + root + 5 children + + store = repo.readonly_session("main", as_of=times[-1]).store + group = zarr.open_group(store=store, mode="r") + + for i, time in enumerate(times): + store = repo.readonly_session("main", as_of=time).store + group = zarr.open_group(store=store, mode="r") + expected_children = {f"child {j}" for j in range(i)} + actual_children = {g[0] for g in group.members()} + assert expected_children == actual_children diff --git a/icechunk/src/repository.rs b/icechunk/src/repository.rs index 44ab628e..29c83344 100644 --- a/icechunk/src/repository.rs +++ b/icechunk/src/repository.rs @@ -5,7 +5,9 @@ use std::{ sync::Arc, }; +use async_recursion::async_recursion; use bytes::Bytes; +use chrono::{DateTime, Utc}; use err_into::ErrorInto as _; use futures::{ stream::{FuturesOrdered, FuturesUnordered}, @@ -37,15 +39,13 @@ use crate::{ Storage, StorageError, }; -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, 
Clone, PartialEq, Eq)] #[non_exhaustive] pub enum VersionInfo { - #[serde(rename = "snapshot_id")] SnapshotId(SnapshotId), - #[serde(rename = "tag")] TagRef(String), - #[serde(rename = "branch")] BranchTipRef(String), + AsOf { branch: String, at: DateTime }, } #[derive(Debug, Error)] @@ -60,6 +60,8 @@ pub enum RepositoryErrorKind { #[error("snapshot not found: `{id}`")] SnapshotNotFound { id: SnapshotId }, + #[error("branch `{branch}` does not have any snapshots before or at `{at}`")] + InvalidAsOfSpec { branch: String, at: DateTime }, #[error("invalid snapshot id: `{0}`")] InvalidSnapshotId(String), #[error("tag error: `{0}`")] @@ -404,11 +406,12 @@ impl Repository { } /// Returns the sequence of parents of the snapshot pointed by the given version + #[async_recursion(?Send)] #[instrument(skip(self))] - pub async fn ancestry( - &self, + pub async fn ancestry<'a>( + &'a self, version: &VersionInfo, - ) -> RepositoryResult> + '_> { + ) -> RepositoryResult> + 'a> { let snapshot_id = self.resolve_version(version).await?; self.snapshot_ancestry(&snapshot_id).await } @@ -572,6 +575,24 @@ impl Repository { .await?; Ok(ref_data.snapshot) } + VersionInfo::AsOf { branch, at } => { + let tip = VersionInfo::BranchTipRef(branch.clone()); + let snap = self + .ancestry(&tip) + .await? + .try_skip_while(|parent| ready(Ok(&parent.flushed_at > at))) + .take(1) + .try_collect::>() + .await?; + match snap.into_iter().next() { + Some(snap) => Ok(snap.id), + None => Err(RepositoryErrorKind::InvalidAsOfSpec { + branch: branch.clone(), + at: *at, + } + .into()), + } + } } }