Skip to content

Commit

Permalink
Add Line as SpecInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
paula-mg committed Sep 6, 2024
1 parent 18d4d6b commit b239efe
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 34 deletions.
34 changes: 0 additions & 34 deletions src/scanspec/schema.py

This file was deleted.

Empty file added src/scanspec/schema/__init__.py
Empty file.
46 changes: 46 additions & 0 deletions src/scanspec/schema/resolvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Any

import numpy as np
from specs import Line

from scanspec.core import (
Frames,
Path,
)


def reduce_frames(stack: list[Frames[str]], max_frames: int) -> Path:
"""Removes frames from a spec so len(path) < max_frames.
Args:
stack: A stack of Frames created by a spec
max_frames: The maximum number of frames the user wishes to be returned
"""
# Calculate the total number of frames
num_frames = 1
for frames in stack:
num_frames *= len(frames)

# Need each dim to be this much smaller
ratio = 1 / np.power(max_frames / num_frames, 1 / len(stack))

sub_frames = [sub_sample(f, ratio) for f in stack]
return Path(sub_frames)


def sub_sample(frames: Frames[str], ratio: float) -> Frames:
"""Provides a sub-sample Frames object whilst preserving its core structure.
Args:
frames: the Frames object to be reduced
ratio: the reduction ratio of the dimension
"""
num_indexes = int(len(frames) / ratio)
indexes = np.linspace(0, len(frames) - 1, num_indexes, dtype=np.int32)
return frames.extract(indexes, calculate_gap=False)


def validate_spec(spec: Line) -> Any:
"""A query used to confirm whether or not the Spec will produce a viable scan."""
# TODO apischema will do all the validation for us
return spec.serialize()
48 changes: 48 additions & 0 deletions src/scanspec/schema/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import strawberry
from fastapi import FastAPI
from resolvers import reduce_frames, validate_spec
from specs import Line, PointsResponse
from strawberry.fastapi import GraphQLRouter

from scanspec.core import Path


@strawberry.type
class Query:
@strawberry.field
def validate(self, spec: Line) -> str:
return validate_spec(spec)

@strawberry.field
def get_points(self, spec: Line, max_frames: int | None = 10000) -> PointsResponse:
"""Calculate the frames present in the scan plus some metadata about the points.
Args:
spec: The specification of the scan
max_frames: The maximum number of frames the user wishes to receive
"""

dims = spec.calculate() # Grab dimensions from spec

path = Path(dims) # Convert to a path

# TOTAL FRAMES
total_frames = len(path) # Capture the total length of the path

# MAX FRAMES
# Limit the consumed data by the max_frames argument
if max_frames and (max_frames < len(path)):
# Cap the frames by the max limit
path = reduce_frames(dims, max_frames)
# WARNING: path object is consumed after this statement
chunk = path.consume(max_frames)

return PointsResponse(chunk, total_frames)


schema = strawberry.Schema(Query)

graphql_app = GraphQLRouter(schema, path="/", graphql_ide="apollo-sandbox")

app = FastAPI()
app.include_router(graphql_app, prefix="/graphql")
93 changes: 93 additions & 0 deletions src/scanspec/schema/specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

from _collections_abc import Callable, Mapping
from typing import Any

import numpy as np
import strawberry

from scanspec.core import (
Axis,
Frames,
gap_between_frames,
)


@strawberry.type
class PointsResponse:
"""Information about the points provided by a spec."""

total_frames: int
returned_frames: int

def __init__(self, chunk: Frames[str], total_frames: int):
self.total_frames = total_frames
"""The number of frames present across the entire spec"""
self.returned_frames = len(chunk)
"""The number of frames returned by the getPoints query
(controlled by the max_points argument)"""
self._chunk = chunk


@strawberry.interface
class SpecInterface:
def serialize(self) -> Mapping[str, Any]:
"""Serialize the spec to a dictionary."""
return "serialized"


def _dimensions_from_indexes(
func: Callable[[np.ndarray], dict[Axis, np.ndarray]],
axes: list,
num: int,
bounds: bool,
) -> list[Frames[Axis]]:
# Calc num midpoints (fences) from 0.5 .. num - 0.5
midpoints_calc = func(np.linspace(0.5, num - 0.5, num))
midpoints = {a: midpoints_calc[a] for a in axes}
if bounds:
# Calc num + 1 bounds (posts) from 0 .. num
bounds_calc = func(np.linspace(0, num, num + 1))
lower = {a: bounds_calc[a][:-1] for a in axes}
upper = {a: bounds_calc[a][1:] for a in axes}
# Points must have no gap as upper[a][i] == lower[a][i+1]
# because we initialized it to be that way
gap = np.zeros(num, dtype=np.bool_)
dimension = Frames(midpoints, lower, upper, gap)
# But calc the first point as difference between first
# and last
gap[0] = gap_between_frames(dimension, dimension)
else:
# Gap can be calculated in Dimension
dimension = Frames(midpoints)
return [dimension]


@strawberry.input
class Line(SpecInterface):
axis: str = strawberry.field(description="An identifier for what to move")
start: float = strawberry.field(
description="Midpoint of the first point of the line"
)
stop: float = strawberry.field(description="Midpoint of the last point of the line")
num: int = strawberry.field(description="Number of frames to produce")

def axes(self) -> list:
return [self.axis]

def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]:
if self.num == 1:
# Only one point, stop-start gives length of one point
step = self.stop - self.start
else:
# Multiple points, stop-start gives length of num-1 points
step = (self.stop - self.start) / (self.num - 1)
# self.start is the first centre point, but we need the lower bound
# of the first point as this is where the index array starts
first = self.start - step / 2
return {self.axis: indexes * step + first}

def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]:
return _dimensions_from_indexes(
self._line_from_indexes, self.axes(), self.num, bounds
)

0 comments on commit b239efe

Please sign in to comment.