-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
187 additions
and
34 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
from specs import Line | ||
|
||
from scanspec.core import ( | ||
Frames, | ||
Path, | ||
) | ||
|
||
|
||
def reduce_frames(stack: list[Frames[str]], max_frames: int) -> Path: | ||
"""Removes frames from a spec so len(path) < max_frames. | ||
Args: | ||
stack: A stack of Frames created by a spec | ||
max_frames: The maximum number of frames the user wishes to be returned | ||
""" | ||
# Calculate the total number of frames | ||
num_frames = 1 | ||
for frames in stack: | ||
num_frames *= len(frames) | ||
|
||
# Need each dim to be this much smaller | ||
ratio = 1 / np.power(max_frames / num_frames, 1 / len(stack)) | ||
|
||
sub_frames = [sub_sample(f, ratio) for f in stack] | ||
return Path(sub_frames) | ||
|
||
|
||
def sub_sample(frames: Frames[str], ratio: float) -> Frames: | ||
"""Provides a sub-sample Frames object whilst preserving its core structure. | ||
Args: | ||
frames: the Frames object to be reduced | ||
ratio: the reduction ratio of the dimension | ||
""" | ||
num_indexes = int(len(frames) / ratio) | ||
indexes = np.linspace(0, len(frames) - 1, num_indexes, dtype=np.int32) | ||
return frames.extract(indexes, calculate_gap=False) | ||
|
||
|
||
def validate_spec(spec: Line) -> Any: | ||
"""A query used to confirm whether or not the Spec will produce a viable scan.""" | ||
# TODO apischema will do all the validation for us | ||
return spec.serialize() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import strawberry | ||
from fastapi import FastAPI | ||
from resolvers import reduce_frames, validate_spec | ||
from specs import Line, PointsResponse | ||
from strawberry.fastapi import GraphQLRouter | ||
|
||
from scanspec.core import Path | ||
|
||
|
||
@strawberry.type | ||
class Query: | ||
@strawberry.field | ||
def validate(self, spec: Line) -> str: | ||
return validate_spec(spec) | ||
|
||
@strawberry.field | ||
def get_points(self, spec: Line, max_frames: int | None = 10000) -> PointsResponse: | ||
"""Calculate the frames present in the scan plus some metadata about the points. | ||
Args: | ||
spec: The specification of the scan | ||
max_frames: The maximum number of frames the user wishes to receive | ||
""" | ||
|
||
dims = spec.calculate() # Grab dimensions from spec | ||
|
||
path = Path(dims) # Convert to a path | ||
|
||
# TOTAL FRAMES | ||
total_frames = len(path) # Capture the total length of the path | ||
|
||
# MAX FRAMES | ||
# Limit the consumed data by the max_frames argument | ||
if max_frames and (max_frames < len(path)): | ||
# Cap the frames by the max limit | ||
path = reduce_frames(dims, max_frames) | ||
# WARNING: path object is consumed after this statement | ||
chunk = path.consume(max_frames) | ||
|
||
return PointsResponse(chunk, total_frames) | ||
|
||
|
||
schema = strawberry.Schema(Query) | ||
|
||
graphql_app = GraphQLRouter(schema, path="/", graphql_ide="apollo-sandbox") | ||
|
||
app = FastAPI() | ||
app.include_router(graphql_app, prefix="/graphql") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from __future__ import annotations | ||
|
||
from _collections_abc import Callable, Mapping | ||
from typing import Any | ||
|
||
import numpy as np | ||
import strawberry | ||
|
||
from scanspec.core import ( | ||
Axis, | ||
Frames, | ||
gap_between_frames, | ||
) | ||
|
||
|
||
@strawberry.type | ||
class PointsResponse: | ||
"""Information about the points provided by a spec.""" | ||
|
||
total_frames: int | ||
returned_frames: int | ||
|
||
def __init__(self, chunk: Frames[str], total_frames: int): | ||
self.total_frames = total_frames | ||
"""The number of frames present across the entire spec""" | ||
self.returned_frames = len(chunk) | ||
"""The number of frames returned by the getPoints query | ||
(controlled by the max_points argument)""" | ||
self._chunk = chunk | ||
|
||
|
||
@strawberry.interface | ||
class SpecInterface: | ||
def serialize(self) -> Mapping[str, Any]: | ||
"""Serialize the spec to a dictionary.""" | ||
return "serialized" | ||
|
||
|
||
def _dimensions_from_indexes( | ||
func: Callable[[np.ndarray], dict[Axis, np.ndarray]], | ||
axes: list, | ||
num: int, | ||
bounds: bool, | ||
) -> list[Frames[Axis]]: | ||
# Calc num midpoints (fences) from 0.5 .. num - 0.5 | ||
midpoints_calc = func(np.linspace(0.5, num - 0.5, num)) | ||
midpoints = {a: midpoints_calc[a] for a in axes} | ||
if bounds: | ||
# Calc num + 1 bounds (posts) from 0 .. num | ||
bounds_calc = func(np.linspace(0, num, num + 1)) | ||
lower = {a: bounds_calc[a][:-1] for a in axes} | ||
upper = {a: bounds_calc[a][1:] for a in axes} | ||
# Points must have no gap as upper[a][i] == lower[a][i+1] | ||
# because we initialized it to be that way | ||
gap = np.zeros(num, dtype=np.bool_) | ||
dimension = Frames(midpoints, lower, upper, gap) | ||
# But calc the first point as difference between first | ||
# and last | ||
gap[0] = gap_between_frames(dimension, dimension) | ||
else: | ||
# Gap can be calculated in Dimension | ||
dimension = Frames(midpoints) | ||
return [dimension] | ||
|
||
|
||
@strawberry.input | ||
class Line(SpecInterface): | ||
axis: str = strawberry.field(description="An identifier for what to move") | ||
start: float = strawberry.field( | ||
description="Midpoint of the first point of the line" | ||
) | ||
stop: float = strawberry.field(description="Midpoint of the last point of the line") | ||
num: int = strawberry.field(description="Number of frames to produce") | ||
|
||
def axes(self) -> list: | ||
return [self.axis] | ||
|
||
def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: | ||
if self.num == 1: | ||
# Only one point, stop-start gives length of one point | ||
step = self.stop - self.start | ||
else: | ||
# Multiple points, stop-start gives length of num-1 points | ||
step = (self.stop - self.start) / (self.num - 1) | ||
# self.start is the first centre point, but we need the lower bound | ||
# of the first point as this is where the index array starts | ||
first = self.start - step / 2 | ||
return {self.axis: indexes * step + first} | ||
|
||
def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: | ||
return _dimensions_from_indexes( | ||
self._line_from_indexes, self.axes(), self.num, bounds | ||
) |