Commit 5197a22: Add color info
KernelA committed May 31, 2022
1 parent b857487
Showing 9 changed files with 73 additions and 43 deletions.
4 changes: 3 additions & 1 deletion configs/insert_data.yaml
@@ -11,7 +11,9 @@ data:
filepath: ./data/Carola_PointCloud.ply

loader:
_target_: loader.ply_loader.PLYLoader
_target_: loader.ply_loader.PLYColorLoader

drop_db: false

data_splitting:
voxel_size: 2.5
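The `_target_` key follows the Hydra/OmegaConf instantiation convention, so pointing it at loader.ply_loader.PLYColorLoader changes which loader class is built from this config. A minimal sketch of how such a config is typically resolved, assuming hydra.utils.instantiate is used (the actual call site is not part of this diff):

# Minimal sketch, assuming a Hydra-style config where `_target_` names the
# class to instantiate; the real wiring in insert_data.py is not shown here.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/insert_data.yaml")

# Builds loader.ply_loader.PLYColorLoader; path_to_file is assumed to be its
# only constructor argument (see loader/ply_loader.py further down).
loader = instantiate(cfg.loader, path_to_file=cfg.data.filepath)
print(type(loader).__name__)  # -> PLYColorLoader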
31 changes: 21 additions & 10 deletions dashboard.py
@@ -5,8 +5,10 @@
from sqlalchemy import sql
from sqlalchemy.orm import sessionmaker
from geoalchemy2 import func as geosql
from geoalchemy2 import shape
import geopandas as geo
import numpy as np
from shapely import wkb
import dash_vtk
from dash import Dash, html, dcc, Input, Output, State
from dash.long_callback import DiskcacheLongCallbackManager
@@ -20,11 +22,7 @@

FILE_DROPDOWN = "file-selection-id"
CHUNK_DROP_DOWN = "chunk-dropdown-id"
SIGN_INFO_TABLE_ID = "sign-info-table"
PLOT_3D_ID = "3d-scatter-id"
EXTERNAL_DROPDOWN = "external-id-field"
PROGRESS_ID = "progress-id"
RADIUS_SLIDER_ID = "selection-radius-range"
PROGRESS_BAR_ID = "progress-bar-id"
BUTTON_ID = "draw-button-id"

@@ -119,27 +117,40 @@ def select_chunk(set_progress, file_path, chunk_ids, n_click):
Session = sessionmaker(engine)

geom_col_name = "geom"
color_col_name = "color"

with Session.begin() as session:
query_points = sql.select(geosql.ST_DumpPoints(LazPoints.points).geom.label(
geom_col_name)) \
query_points = sql.select(LazPoints.points.label(
geom_col_name), LazPoints.colors.label(color_col_name)) \
.filter(LazPoints.file == file_path) \
.filter(LazPoints.chunk_id.in_(chunk_ids))

points = geo.read_postgis(
points_data = geo.read_postgis(
query_points, session.connection(), geom_col=geom_col_name)

engine.dispose()

set_progress([str(50)])

coords = np.vstack(
(points[geom_col_name].x, points[geom_col_name].y, points[geom_col_name].z)).T.reshape(-1)
xyz = []
colors = []

for row in points_data.itertuples(index=False):
for point in getattr(row, geom_col_name).geoms:
for coord in point.coords:
xyz.extend(coord)

colors.extend(np.array(getattr(row, color_col_name)).reshape(-1))

xyz = np.array(xyz)
xyz -= xyz.mean(axis=0)
xyz = xyz.reshape(-1).tolist()

vtk_view = dash_vtk.View(
[
dash_vtk.PointCloudRepresentation(
xyz=coords,
xyz=xyz,
rgb=colors,
property={"pointSize": 2}
)
],
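The callback now fetches each stored MultiPoint geometry together with its colors array and flattens both into the flat lists that dash_vtk.PointCloudRepresentation expects. The standalone sketch below reproduces that flattening for a single row; the sample data is invented, and the centering is done per axis here, a minor variation on the callback above.

# Standalone sketch of the flattening step: one MultiPoint row plus a
# per-point RGB array becomes flat xyz / rgb lists for dash_vtk.
import numpy as np
from shapely.geometry import MultiPoint

row_geom = MultiPoint([(0.0, 0.0, 1.0), (2.0, 3.0, 4.0), (4.0, 6.0, 7.0)])
row_colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255]])  # one RGB triple per point

xyz = []
colors = []
for point in row_geom.geoms:
    for coord in point.coords:
        xyz.extend(coord)                  # x, y, z appended in order
colors.extend(row_colors.reshape(-1))      # flat r, g, b, r, g, b, ...

xyz = np.array(xyz).reshape(-1, 3)
xyz -= xyz.mean(axis=0)                    # center each axis around zero
xyz = xyz.reshape(-1).tolist()             # dash_vtk takes a flat coordinate list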
3 changes: 2 additions & 1 deletion db/laz_points.py
@@ -1,4 +1,4 @@
from sqlalchemy import Column, BigInteger, String, Integer
from sqlalchemy import Column, BigInteger, String, Integer, ARRAY
from geoalchemy2 import Geometry

from .base import Base
@@ -12,3 +12,4 @@ class LazPoints(Base):
file = Column(String, nullable=False)
points = Column(Geometry("MULTIPOINTZ", dimension=3,
spatial_index=False, nullable=False, use_N_D_index=False))
colors = Column(ARRAY(Integer, zero_indexes=True, dimensions=2))
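For context, each row now stores a 2-D integer array of RGB triples alongside its MULTIPOINTZ geometry. The sketch below shows the assumed full shape of the model after this change; only the columns visible in the diff are confirmed, the table name and primary key are illustrative guesses.

# Assumed full model after this commit; the table name and primary key are
# guesses for illustration, the column definitions follow the diff.
from sqlalchemy import Column, BigInteger, String, Integer, ARRAY
from geoalchemy2 import Geometry

from .base import Base


class LazPoints(Base):
    __tablename__ = "laz_points"                 # assumed

    id = Column(BigInteger, primary_key=True)    # assumed
    chunk_id = Column(Integer, nullable=False)   # assumed
    file = Column(String, nullable=False)
    points = Column(Geometry("MULTIPOINTZ", dimension=3,
                             spatial_index=False, nullable=False, use_N_D_index=False))
    # New: per-point RGB values, one [r, g, b] row per point in `points`.
    colors = Column(ARRAY(Integer, zero_indexes=True, dimensions=2))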
24 changes: 16 additions & 8 deletions insert_data.py
@@ -16,8 +16,8 @@
from db import LazPoints
from db.base import Base

from loader import BasePointCloudLoader
from utils import append_points, DATASET_PATH
from loader import BasePointCloudLoader, COLOR_KEY, COORDS_KEY
from utils import append_points


def insert_points(session: Session, file_path: str, file: h5py.File, chunk_id: int):
@@ -27,10 +27,12 @@ def insert_points(session: Session, file_path: str, file: h5py.File, chunk_id: i
if res is not None:
return

points = file.get(DATASET_PATH)[:]
points: np.ndarray = file.get(COORDS_KEY)[:]
colors: np.ndarray = file.get(COLOR_KEY)[:]

multi_point = MultiPoint(points)
session.add(LazPoints(file=file_path, chunk_id=chunk_id, points=from_shape(multi_point)))

session.add(LazPoints(file=file_path, chunk_id=chunk_id, points=from_shape(multi_point), colors=colors.tolist()))


def generate_bounds(min_value: float, max_value: float, step: float):
@@ -73,18 +75,21 @@ def insert_data(session_factory, path_to_file: str, loader_config, chunk_size: i
chunk_ids = []

with tempfile.TemporaryDirectory() as tmp_dir:
for chunk_xyz in loader.iter_chunks(chunk_size):
for i, chunk_data in enumerate(loader.iter_chunks(chunk_size), 1):
chunk_xyz = chunk_data["xyz"]
colors = chunk_data["color"]

chunk_index_per_point = get_chunk_indices(
chunk_xyz, x_intervals, y_intervals, z_intervals)

for chunk_index in tqdm(set(chunk_index_per_point), desc="Save to hdf"):
for chunk_index in tqdm(set(chunk_index_per_point), desc=f"Save chunk {i} to hdf"):
file_path = os.path.join(tmp_dir, f"chunk_{chunk_index}.h5")
files.append(file_path)
chunk_ids.append(int(chunk_index))

with h5py.File(file_path, "a") as hdf_file:
local_chunk = chunk_xyz[chunk_index_per_point == chunk_index]
append_points(hdf_file, local_chunk)
indices = np.nonzero(chunk_index_per_point == chunk_index)
append_points(hdf_file, {COORDS_KEY: chunk_xyz[indices], COLOR_KEY: colors[indices]})

del file_path

@@ -101,6 +106,9 @@ def main(config):
engine = create_engine(config.db.url)

try:
if config.drop_db:
Base.metadata.drop_all(engine)

Base.metadata.create_all(engine)
CustomSession = sessionmaker(engine)
insert_data(CustomSession, config.data.filepath, config.loader,
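The insert path now writes coordinates and colors side by side: the same index mask derived from get_chunk_indices selects both arrays, so they stay aligned point for point before being appended to the per-chunk HDF5 files. A small self-contained sketch of that selection step (the sample data and mask values are invented, since get_chunk_indices is not part of this diff):

# Sketch of the per-chunk split above: one mask selects both the coordinates
# and the colors, keeping the two arrays aligned per point.
import numpy as np

chunk_data = {
    "xyz": np.random.rand(5, 3),                      # COORDS_KEY
    "color": np.random.randint(0, 256, size=(5, 3)),  # COLOR_KEY
}
chunk_index_per_point = np.array([0, 1, 0, 1, 0])     # assumed get_chunk_indices output

for chunk_index in set(chunk_index_per_point):
    indices = np.nonzero(chunk_index_per_point == chunk_index)
    coords = chunk_data["xyz"][indices]     # points falling into this spatial chunk
    colors = chunk_data["color"][indices]   # matching RGB rows, same order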
2 changes: 1 addition & 1 deletion loader/__init__.py
@@ -1,2 +1,2 @@
from .base_loader import BasePointCloudLoader
from .ply_loader import PLYLoader
from .ply_loader import PLYColorLoader, COLOR_KEY, COORDS_KEY
4 changes: 2 additions & 2 deletions loader/base_loader.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Iterable, Tuple
from typing import Iterable, Tuple, Dict

import numpy as np

@@ -10,7 +10,7 @@ def __init__(self, path_to_file: str):
self.path_to_file = path_to_file

@abstractmethod
def iter_chunks(self, chunk_size: int) -> Iterable[np.ndarray]:
def iter_chunks(self, chunk_size: int) -> Iterable[Dict[str, np.ndarray]]:
pass

@abstractmethod
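With this signature change, every concrete loader now yields a dict of named arrays per chunk instead of a bare coordinate array. A toy in-memory implementation of the updated contract, purely illustrative and assuming the second abstract method is get_bounds (as implemented by PLYColorLoader below):

# Toy loader conforming to the updated interface; not part of the repository.
from typing import Dict, Iterable, Tuple

import numpy as np

from .base_loader import BasePointCloudLoader


class InMemoryLoader(BasePointCloudLoader):
    def __init__(self, xyz: np.ndarray, colors: np.ndarray):
        super().__init__(path_to_file="")
        self._xyz = xyz
        self._colors = colors

    def iter_chunks(self, chunk_size: int) -> Iterable[Dict[str, np.ndarray]]:
        for start in range(0, len(self._xyz), chunk_size):
            stop = start + chunk_size
            yield {"xyz": self._xyz[start:stop], "color": self._colors[start:stop]}

    def get_bounds(self) -> Tuple[np.ndarray, np.ndarray]:
        return self._xyz.min(axis=0), self._xyz.max(axis=0)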
25 changes: 17 additions & 8 deletions loader/ply_loader.py
@@ -1,16 +1,19 @@
from typing import Iterable, Tuple
from typing import Iterable, Tuple, Dict

from plyfile import PlyData
import numpy as np

from .base_loader import BasePointCloudLoader

COORDS_KEY = "xyz"
COLOR_KEY = "color"

class PLYLoader(BasePointCloudLoader):

class PLYColorLoader(BasePointCloudLoader):
def __init__(self, path_to_file: str):
super().__init__(path_to_file)

def iter_chunks(self, chunk_size: int) -> Iterable[np.ndarray]:
def iter_chunks(self, chunk_size: int) -> Iterable[Dict[str, np.ndarray]]:
assert chunk_size > 0
with open(self.path_to_file, "rb") as file:
data = PlyData.read(file)
@@ -22,19 +25,25 @@ def iter_chunks(self, chunk_size: int) -> Iterable[np.ndarray]:
x = data["vertex"]["x"][i:end]
y = data["vertex"]["y"][i:end]
z = data["vertex"]["z"][i:end]
red = data["vertex"]["red"][i:end]
green = data["vertex"]["green"][i:end]
blue = data["vertex"]["blue"][i:end]

colors = np.vstack((red, green, blue)).T

yield np.vstack((x, y, z)).T
yield {COORDS_KEY: np.vstack((x, y, z)).T, COLOR_KEY: colors}

def get_bounds(self) -> Tuple[np.ndarray, np.ndarray]:
min_bounds = None
max_bounds = None

for chunk_data in self.iter_chunks(10_000):
coords = chunk_data[COORDS_KEY]
if min_bounds is None:
min_bounds = chunk_data.min(axis=0)
max_bounds = chunk_data.max(axis=0)
min_bounds = coords.min(axis=0)
max_bounds = coords.max(axis=0)
else:
min_bounds = np.minimum(min_bounds, chunk_data.min(axis=0))
max_bounds = np.maximum(max_bounds, chunk_data.max(axis=0))
min_bounds = np.minimum(min_bounds, coords.min(axis=0))
max_bounds = np.maximum(max_bounds, coords.max(axis=0))

return min_bounds, max_bounds
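A minimal usage sketch of the new loader: each chunk is a dict holding aligned coordinate and color arrays keyed by COORDS_KEY and COLOR_KEY (the .ply path is the one from configs/insert_data.yaml).

# Minimal usage sketch of PLYColorLoader.
from loader import PLYColorLoader, COORDS_KEY, COLOR_KEY

loader = PLYColorLoader("./data/Carola_PointCloud.ply")
min_bounds, max_bounds = loader.get_bounds()   # per-axis extremes of the cloud

for chunk in loader.iter_chunks(chunk_size=10_000):
    xyz = chunk[COORDS_KEY]   # shape (N, 3) float coordinates
    rgb = chunk[COLOR_KEY]    # shape (N, 3) per-vertex red/green/blue
    assert xyz.shape == rgb.shape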
4 changes: 0 additions & 4 deletions main.py

This file was deleted.

19 changes: 11 additions & 8 deletions utils.py
@@ -1,14 +1,17 @@
from typing import Dict

import h5py
import numpy as np

DATASET_PATH = "/points"

def append_points(file: h5py.File, points_data: Dict[str, np.ndarray]) -> None:
for dataset_name in points_data:
dataset = file.get(dataset_name)

def append_points(file: h5py.File, points: np.ndarray) -> None:
dataset = file.get(DATASET_PATH)
points = points_data[dataset_name]

if dataset is None:
file.create_dataset(DATASET_PATH, maxshape=(None, points.shape[1]), data=points)
else:
last_index = dataset.shape[0]
dataset[last_index:] = points
if dataset is None:
file.create_dataset(dataset_name, maxshape=(None, points.shape[1]), data=points)
else:
last_index = dataset.shape[0]
dataset[last_index:] = points
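One caveat with the append path shown above: for a resizable h5py dataset the slice dataset[last_index:] is empty until the dataset is grown, so an explicit resize is needed before rows past the current extent can be written. A general h5py sketch of that pattern (not code from this commit):

# General h5py append pattern: a dataset created with maxshape=(None, width)
# must be resized before rows past the current extent can be written.
import h5py
import numpy as np


def append_rows(file: h5py.File, name: str, rows: np.ndarray) -> None:
    dataset = file.get(name)

    if dataset is None:
        # First write: make the leading dimension resizable for later appends.
        file.create_dataset(name, maxshape=(None, rows.shape[1]), data=rows)
    else:
        last_index = dataset.shape[0]
        dataset.resize(last_index + rows.shape[0], axis=0)
        dataset[last_index:] = rows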
