Skip to content

Commit

Permalink
Additional functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
KernelA committed Jun 7, 2022
1 parent 22fba4f commit 1f90b25
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 48 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Check connection to the database.

Download data from the Sketchfab or any PLY file with color.

Modify config: [insert_data.yaml](configs/insert_data.yaml).
Modify config: [insert_data.yaml](configs/insert_data.yaml). By default, `drop_tables` is true.

Insert data:
```
Expand Down
2 changes: 1 addition & 1 deletion configs/insert_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ data:
loader:
_target_: loader.ply_loader.PLYColorLoader

drop_db: false
drop_db: true

data_splitting:
voxel_size: 2
Expand Down
55 changes: 38 additions & 17 deletions dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from sqlalchemy import sql
from sqlalchemy.orm import sessionmaker
from geoalchemy2 import func as geosql
from geoalchemy2 import shape as geoshape
from shapely import geometry
import geopandas as geo
import numpy as np
import dash_vtk
Expand All @@ -28,14 +30,12 @@
BUTTON_ID = "draw-button-id"
TABLE_ID = "datatable-id"
ERROR_ID = "error-id"
RADIUS_SLIDER_ID = "slider-id"

CACHE_DIR = "./dashboard-cache"

cache = diskcache.Cache(CACHE_DIR)
long_callback_manager = DiskcacheLongCallbackManager(cache)

app = Dash(__name__, external_stylesheets=[
dbc.themes.BOOTSTRAP], long_callback_manager=long_callback_manager)
dbc.themes.BOOTSTRAP], long_callback_manager=DiskcacheLongCallbackManager(diskcache.Cache(CACHE_DIR)))


app.layout = dbc.Container([
Expand All @@ -44,6 +44,9 @@
dcc.Dropdown(id=FILE_DROPDOWN),
html.Label("Chunk id:"),
dcc.Dropdown(id=CHUNK_DROP_DOWN, multi=True),
html.Label("Radius around mean point [units of the data]:"),
dcc.Slider(0, 4, value=2, step=0.1, id=RADIUS_SLIDER_ID, tooltip={
"placement": "bottom", "always_visible": True}),
html.P(id=ERROR_ID, style={"color": "red"}),
html.Div(children=[html.Button("Query points and draw", id=BUTTON_ID, style={"width": "25%"})]),
html.Progress(id=PROGRESS_BAR_ID, max=str(100),
Expand All @@ -55,7 +58,7 @@
dbc.Col(id=PLOT_3D_ID)
], class_name="h-75"
)
], fluid=True, style={"height": "90vh"}
], fluid=True, style={"height": "80vh"}
)


Expand Down Expand Up @@ -112,7 +115,8 @@ def select_chunk_ids(file):
@ app.long_callback(
output=Output(PLOT_3D_ID, "children"),
inputs=[Input(BUTTON_ID, "n_clicks")],
state=[State(FILE_DROPDOWN, "value"), State(CHUNK_DROP_DOWN, "value")],
state=[State(FILE_DROPDOWN, "value"), State(
CHUNK_DROP_DOWN, "value"), State(RADIUS_SLIDER_ID, "value")],
running=[
(Output(FILE_DROPDOWN, "disabled"), True, False),
(Output(CHUNK_DROP_DOWN, "disabled"), True, False),
Expand All @@ -123,11 +127,11 @@ def select_chunk_ids(file):
{"visibility": "hidden"},
)
],
interval=8000,
interval=4000,
progress=[Output(PROGRESS_BAR_ID, "value")],
prevent_initial_call=True
)
def select_chunk(set_progress, n_click, file_path, chunk_ids):
def select_chunk(set_progress, n_click, file_path, chunk_ids, radius_around_centroid):
if isinstance(chunk_ids, numbers.Number):
chunk_ids = [chunk_ids]

Expand All @@ -145,17 +149,34 @@ def select_chunk(set_progress, n_click, file_path, chunk_ids):

Session = sessionmaker(engine)

geom_col_name = "geom"
geom_col_name = "point"
color_col_name = "color"

with Session.begin() as session:
query_points = sql.select(LazPoints.points.label(
geom_col_name), LazPoints.colors.label(color_col_name)) \
subq = sql.select(geosql.ST_DumpPoints(LazPoints.points).geom.label("point"))\
.filter(LazPoints.file == file_path) \
.filter(LazPoints.chunk_id.in_(chunk_ids))
.filter(LazPoints.chunk_id.in_(chunk_ids)).subquery()

mean_point = session.execute(sql.select(
sql.func.avg(geosql.ST_X(subq.c.point)).label("x"),
sql.func.avg(geosql.ST_Y(subq.c.point)).label("y"),
sql.func.avg(geosql.ST_Z(subq.c.point)).label("z"))).one()

mean_point = geoshape.from_shape(geometry.Point(mean_point.x, mean_point.y, mean_point.z))

set_progress([str(25)])

del subq

first_filtered_points = sql.select(geosql.ST_DumpPoints(LazPoints.points).label("point_info"), LazPoints.colors)\
.where(geosql.ST_3DDWithin(LazPoints.points, mean_point, radius_around_centroid)).subquery()

filtered_points = sql.select(first_filtered_points.c.point_info.geom.label("point"),
first_filtered_points.c.colors[first_filtered_points.c.point_info.path[1]: first_filtered_points.c.point_info.path[1]].label("color")) \
.where(geosql.ST_3DDWithin(first_filtered_points.c.point_info.geom, mean_point, radius_around_centroid))

points_data = geo.read_postgis(
query_points, session.connection(), geom_col=geom_col_name)
filtered_points, session.connection(), geom_col=geom_col_name)

engine.dispose()

Expand All @@ -165,12 +186,12 @@ def select_chunk(set_progress, n_click, file_path, chunk_ids):
colors = []

for row in points_data.itertuples(index=False):
for point in getattr(row, geom_col_name).geoms:
for coord in point.coords:
xyz.extend(coord)

coord = getattr(row, geom_col_name)
xyz.extend(coord.coords)
colors.extend(np.array(getattr(row, color_col_name)).reshape(-1))

set_progress([str(75)])

xyz = np.array(xyz)
xyz -= xyz.mean(axis=0)
scale = 1 / max(np.linalg.norm(xyz.max(axis=0) - xyz.min(axis=0)), 1e-4)
Expand Down
10 changes: 5 additions & 5 deletions db/laz_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@


class LazPoints(Base):
__tablename__ = "LidarPointsPly"
__tablename__ = "LidarPoints"

id = Column(BigInteger, primary_key=True, autoincrement=True)
chunk_id = Column(Integer, nullable=False)
file = Column(String, nullable=False)
chunk_id = Column(Integer, nullable=False, index=True)
file = Column(String, nullable=False, index=True)
points = Column(Geometry("MULTIPOINTZ", dimension=3,
spatial_index=False, nullable=False, use_N_D_index=False))
colors = Column(ARRAY(Integer, zero_indexes=True, dimensions=2))
spatial_index=True, use_N_D_index=True, nullable=False))
colors = Column(ARRAY(Integer, zero_indexes=True, dimensions=2), nullable=False)
8 changes: 2 additions & 6 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@ dependencies:
- numpy
- conda-forge::psycopg2==2.9
- tqdm==4.63
- ipython
- autopep8
- plotly::plotly==5.7.0
- ipykernel
- conda-forge::sqlalchemy==1.4
- conda-forge::sqlalchemy>=1.4.37
- conda-forge::geopandas==0.10
- h5py=3.6
- conda-forge::shapely==1.8
- pip:
- geoalchemy2~=0.11.0
- sqlalchemy-utils~=0.38.0
- geoalchemy2~=0.12.0
- diskcache~=5.4.0
- psutil~=5.9.0
- dash[diskcache]~=2.4.0
Expand Down
19 changes: 7 additions & 12 deletions insert_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,11 @@ def insert_points(session: Session, file_path: str, file: h5py.File, chunk_id: i


def generate_bounds(min_value: float, max_value: float, step: float):
values = []
value = min_value
if max_value < min_value + step:
max_value = min_value + step

while value < max_value:
values.append(value)
value += step

values.append(max_value)

return np.array(values)
values = np.arange(min_value, max_value, step=step, dtype=float)
return np.append(values, max_value)


def get_chunk_indices(points: np.ndarray, x_bounds: np.ndarray, y_bounds: np.ndarray, z_bounds: np.ndarray):
Expand Down Expand Up @@ -77,8 +72,7 @@ def insert_data(session_factory, path_to_file: str, loader_config, chunk_size: i

with tempfile.TemporaryDirectory() as tmp_dir:
for i, chunk_data in enumerate(loader.iter_chunks(chunk_size), 1):
chunk_xyz = chunk_data["xyz"]
colors = chunk_data["color"]
chunk_xyz = chunk_data[COORDS_KEY]

chunk_index_per_point = get_chunk_indices(
chunk_xyz, x_intervals, y_intervals, z_intervals)
Expand All @@ -91,7 +85,7 @@ def insert_data(session_factory, path_to_file: str, loader_config, chunk_size: i
with h5py.File(file_path, "a") as hdf_file:
indices = np.nonzero(chunk_index_per_point == chunk_index)
append_points(
hdf_file, {COORDS_KEY: chunk_xyz[indices], COLOR_KEY: colors[indices]})
hdf_file, {COORDS_KEY: chunk_xyz[indices], **{key: data[indices] for key, data in chunk_data.items() if key != COORDS_KEY}})

del file_path

Expand All @@ -112,6 +106,7 @@ def main(config):
Base.metadata.drop_all(engine)

Base.metadata.create_all(engine)

CustomSession = sessionmaker(engine)
insert_data(CustomSession, config.data.filepath, config.loader,
config.data_splitting.chunk_size, config.data_splitting.voxel_size)
Expand Down
4 changes: 2 additions & 2 deletions loader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .base_loader import BasePointCloudLoader
from .ply_loader import PLYColorLoader, COLOR_KEY, COORDS_KEY
from .base_loader import BasePointCloudLoader, COLOR_KEY, COORDS_KEY
from .ply_loader import PLYColorLoader
3 changes: 3 additions & 0 deletions loader/base_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import numpy as np

COORDS_KEY = "xyz"
COLOR_KEY = "color"


class BasePointCloudLoader(ABC):

Expand Down
5 changes: 1 addition & 4 deletions loader/ply_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from plyfile import PlyData
import numpy as np

from .base_loader import BasePointCloudLoader

COORDS_KEY = "xyz"
COLOR_KEY = "color"
from .base_loader import BasePointCloudLoader, COORDS_KEY, COLOR_KEY


class PLYColorLoader(BasePointCloudLoader):
Expand Down

0 comments on commit 1f90b25

Please sign in to comment.