From 075111b127a315f22e7b1acd0484180e76867d62 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Sun, 23 Jul 2023 21:45:02 -0400 Subject: [PATCH 01/86] add bg ensemble --- lab4d/dataloader/data_utils.py | 12 ++- lab4d/engine/model.py | 2 +- lab4d/nnutils/base.py | 159 +++++++++++++++++++++++++++ lab4d/nnutils/bgnerf.py | 146 +++++++++++++++++++++++++ lab4d/nnutils/deformable.py | 7 +- lab4d/nnutils/multifields.py | 9 +- lab4d/nnutils/nerf.py | 87 ++++++++++----- lab4d/nnutils/transformer.py | 191 +++++++++++++++++++++++++++++++++ lab4d/utils/geom_utils.py | 19 ++-- 9 files changed, 583 insertions(+), 49 deletions(-) create mode 100644 lab4d/nnutils/bgnerf.py create mode 100644 lab4d/nnutils/transformer.py diff --git a/lab4d/dataloader/data_utils.py b/lab4d/dataloader/data_utils.py index 2bd134f..976ecaa 100644 --- a/lab4d/dataloader/data_utils.py +++ b/lab4d/dataloader/data_utils.py @@ -315,11 +315,13 @@ def load_small_files(data_path_dict): data_info["rtmat"] = np.stack([rtmat_bg, rtmat_fg], 0) # path to centered mesh files - camera_prefix = data_path_dict["cambg"][0].rsplit("/", 1)[0] - data_info["geom_path"] = [ - "%s/mesh-00-centered.obj" % camera_prefix, - "%s/mesh-01-centered.obj" % camera_prefix, - ] + geom_path_bg = [] + geom_path_fg = [] + for path in data_path_dict["cambg"]: + camera_prefix = path.rsplit("/", 1)[0] + geom_path_bg.append("%s/mesh-00-centered.obj" % camera_prefix) + geom_path_fg.append("%s/mesh-01-centered.obj" % camera_prefix) + data_info["geom_path"] = [geom_path_bg, geom_path_fg] return data_info diff --git a/lab4d/engine/model.py b/lab4d/engine/model.py index 88886c2..69dc77d 100644 --- a/lab4d/engine/model.py +++ b/lab4d/engine/model.py @@ -107,7 +107,7 @@ def set_progress(self, current_steps): alpha = None self.fields.set_alpha(alpha) - # beta_prob: steps(0->2k, 1->0.2), range (0.2,1) + # anneal geometry/appearance code for foreground: steps(0->2k, 1->0.2), range (0.2,1) anchor_x = (0, 2000) anchor_y = (1.0, 0.2) type = "linear" diff --git a/lab4d/nnutils/base.py b/lab4d/nnutils/base.py index 8115002..daad8b6 100644 --- a/lab4d/nnutils/base.py +++ b/lab4d/nnutils/base.py @@ -1,8 +1,11 @@ # Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. import torch import torch.nn as nn +import torch.nn.functional as F +from functorch import vmap, combine_state_for_ensemble from lab4d.nnutils.embedding import InstEmbedding +from lab4d.nnutils.transformer import Transformer class ScaleLayer(nn.Module): @@ -155,3 +158,159 @@ def get_dim_inst(num_inst, inst_channels): return inst_channels else: return 0 + + +class CondTransformerMLP(BaseMLP): + """A MLP that accepts both input `x` and condition `c` + + Args: + num_inst (int): Number of distinct object instances. If --nosingle_inst + is passed, this is equal to the number of videos, as we assume each + video captures a different instance. Otherwise, we assume all videos + capture the same instance and set this to 1. + D (int): Number of linear layers for density (sigma) encoder + W (int): Number of hidden units in each MLP layer + in_channels (int): Number of channels in input `x` + inst_channels (int): Number of channels in condition `c` + out_channels (int): Number of output channels + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. nn.ReLU()) + final_act (bool): If True, apply the activation function to the output + """ + + def __init__( + self, + num_inst, + D=8, + W=256, + in_channels=63, + inst_channels=32, + out_channels=3, + skips=[4], + activation=nn.ReLU(True), + final_act=False, + ): + inst_channels = 768 + super().__init__( + D=D, + W=W, + in_channels=in_channels, + out_channels=out_channels, + skips=skips, + activation=activation, + final_act=final_act, + ) + self.inst_embedding = InstEmbedding(num_inst, inst_channels) + + self.transformer = Transformer( + in_channels, + depth=1, + heads=12, + dim_head=inst_channels // 12, + mlp_dim=inst_channels * 2, + selfatt=False, + kv_dim=inst_channels, + ) + + def forward(self, feat, inst_id): + """ + Args: + feat: (M, ..., self.in_channels) + inst_id: (M,) Instance id, or None to use the average instance + Returns: + out: (M, ..., self.out_channels) + """ + if inst_id is None: + if self.inst_embedding.out_channels > 0: + inst_code = self.inst_embedding.get_mean_embedding() + inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) + # print("inst_embedding exists but inst_id is None, using mean inst_code") + else: + # empty, falls back to single-instance NeRF + inst_code = torch.zeros(feat.shape[:-1] + (0,), device=feat.device) + else: + inst_code = self.inst_embedding(inst_id) + inst_code = inst_code.view( + inst_code.shape[:1] + (1,) * (feat.ndim - 2) + (-1,) + ) + inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) + + # feat = torch.cat([feat, inst_code], -1) + # if both input feature and inst_code are empty, return zeros + if feat.shape[-1] == 0 and inst_code.shape[-1] == 0: + return feat + feat = self.transformer(feat, inst_code) + return super().forward(feat) + + @staticmethod + def get_dim_inst(num_inst, inst_channels): + if num_inst > 1: + return inst_channels + else: + return 0 + + +class MultiMLP(nn.Module): + """Independent MLP for each instance""" + + def __init__(self, num_inst, inst_channels=32, **kwargs): + super(MultiMLP, self).__init__() + self.out_channels = kwargs["out_channels"] + self.num_inst = num_inst + self.nets = [] + for i in range(num_inst): + self.nets.append(BaseMLP(**kwargs)) + self.nets = nn.ModuleList(self.nets) + + def forward(self, feat, inst_id): + """ + Args: + feat: (M, ..., self.in_channels) + inst_id: (M,) Instance id, or None to use the average instance + Returns: + out: (M, ..., self.out_channels) + """ + # rearrange the batch dimension + shape = feat.shape[:-1] + out = torch.zeros(shape + (self.out_channels,), device=feat.device).view( + -1, self.out_channels + ) + if inst_id is None: + return out.view(shape + (self.out_channels,)) + + feat = feat.view(-1, feat.shape[-1]) + inst_id = inst_id.view((-1,) + (1,) * (len(shape) - 1)) + inst_id = inst_id.expand(shape) + inst_id = inst_id.reshape(-1) + + # do the real work + + empty_input = torch.zeros_like(feat[:1]) + for it in range(self.num_inst): + id_sel = inst_id == it + if id_sel.sum() == 0: + # to avoid error in ddp + # Expected to have finished reduction in the prior iteration before starting a new one. + out = out + self.nets[it](empty_input).mean() * 0 + continue + x_sel = feat[id_sel] + out[id_sel] = self.nets[it](x_sel) + out = out.view(shape + (self.out_channels,)) + return out + + +class MixMLP(nn.Module): + """Mixing CondMLP and MultiMLP""" + + def __init__(self, num_inst, inst_channels=32, **kwargs): + super(MixMLP, self).__init__() + self.multimlp = MultiMLP(num_inst, inst_channels=inst_channels, **kwargs) + kwargs["D"] *= 5 # 5 + kwargs["W"] *= 2 # 128 + self.condmlp = CondMLP(num_inst, inst_channels=inst_channels, **kwargs) + + def forward(self, feat, inst_id): + out1 = self.condmlp(feat, inst_id) + out2 = self.multimlp(feat, inst_id) + out = out1 + out2 + return out diff --git a/lab4d/nnutils/bgnerf.py b/lab4d/nnutils/bgnerf.py new file mode 100644 index 0000000..4e72564 --- /dev/null +++ b/lab4d/nnutils/bgnerf.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +from lab4d.nnutils.nerf import NeRF + +import trimesh +from pysdf import SDF + +from lab4d.utils.quat_transform import quaternion_translation_to_se3 +from lab4d.utils.geom_utils import get_near_far +from lab4d.nnutils.base import MixMLP, MultiMLP, CondMLP, CondTransformerMLP +from lab4d.nnutils.visibility import VisField + + +class BGNeRF(NeRF): + """A static neural radiance field with an MLP backbone. Specialized to background.""" + + # def __init__(self, data_info, field_arch=CondTransformerMLP, D=5, W=128, **kwargs): + # def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): + def __init__(self, data_info, field_arch=CondMLP, D=5, W=128, **kwargs): + super(BGNeRF, self).__init__( + data_info, field_arch=field_arch, D=D, W=W, **kwargs + ) + # TODO: update beta + # TODO: update scale + + def init_proxy(self, geom_paths, init_scale): + """Initialize the geometry from a mesh + + Args: + geom_path (Listy(str)): Initial shape mesh + init_scale (float): Geometry scale factor + """ + meshes = [] + for geom_path in geom_paths: + mesh = trimesh.creation.uv_sphere(radius=0.12, count=[4, 4]) + # mesh = trimesh.load(geom_path) + # mesh.vertices = mesh.vertices * init_scale + meshes.append(mesh) + self.proxy_geometry = meshes + + def get_proxy_geometry(self): + """Get proxy geometry + + Returns: + proxy_geometry (Trimesh): Proxy geometry + """ + return self.proxy_geometry[0] + + def init_aabb(self): + """Initialize axis-aligned bounding box""" + self.register_buffer("aabb", torch.zeros(len(self.proxy_geometry), 2, 3)) + self.update_aabb(beta=0) + + def get_init_sdf_fn(self): + """Initialize signed distance function from mesh geometry + + Returns: + sdf_fn_torch (Function): Signed distance function + """ + + def sdf_fn_torch_sphere(pts): + radius = 0.1 + # l2 distance to a unit sphere + dis = (pts).pow(2).sum(-1, keepdim=True) + sdf = torch.sqrt(dis) - radius # negative inside, postive outside + return sdf + + return sdf_fn_torch_sphere + + def update_proxy(self): + """Extract proxy geometry using marching cubes""" + for inst_id in range(self.num_inst): + mesh = self.extract_canonical_mesh(level=0.005, inst_id=inst_id) + if len(mesh.vertices) > 0: + self.proxy_geometry[inst_id] = mesh + + def get_aabb(self, inst_id=None): + """Get axis-aligned bounding box + Args: + inst_id: (N,) Instance id + Returns: + aabb: (2,3) Axis-aligned bounding box if inst_id is None, (N,2,3) otherwise + """ + if inst_id is None: + return self.aabb.mean(0) + return self.aabb[inst_id] + + def update_aabb(self, beta=0.9): + """Update axis-aligned bounding box by interpolating with the current + proxy geometry's bounds + + Args: + beta (float): Interpolation factor between previous/current values + """ + device = self.aabb.device + for inst_id in range(self.num_inst): + bounds = self.proxy_geometry[inst_id].bounds + if bounds is not None: + aabb = torch.tensor(bounds, dtype=torch.float32, device=device) + self.aabb[inst_id] = self.aabb[inst_id] * beta + aabb * (1 - beta) + + def update_near_far(self, beta=0.9): + """Update near-far bounds by interpolating with the current near-far bounds + + Args: + beta (float): Interpolation factor between previous/current values + """ + device = next(self.parameters()).device + with torch.no_grad(): + quat, trans = self.camera_mlp.get_vals() # (B, 4, 4) + rtmat = quaternion_translation_to_se3(quat, trans) + + frame_id_all = list(range(self.num_frames)) + frame_offset = self.frame_offset + near_far_all = [] + for inst_id in range(self.num_inst): + verts = self.proxy_geometry[inst_id].vertices + frame_id = frame_id_all[frame_offset[inst_id] : frame_offset[inst_id + 1]] + proxy_pts = torch.tensor(verts, dtype=torch.float32, device=device) + near_far = get_near_far(proxy_pts, rtmat[frame_id]).to(device) + near_far_all.append( + self.near_far[frame_id].data * beta + near_far * (1 - beta) + ) + self.near_far.data = torch.cat(near_far_all, 0) + + def get_near_far(self, frame_id, field2cam): + device = next(self.parameters()).device + frame_id_all = list(range(self.num_frames)) + frame_offset = self.frame_offset + field2cam_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1]) + + near_far_all = [] + for inst_id in range(self.num_inst): + frame_id_sel = frame_id_all[ + frame_offset[inst_id] : frame_offset[inst_id + 1] + ] + # find the overlap of frame_id and frame_id_sel + id_sel = [i for i, x in enumerate(frame_id) if x in frame_id_sel] + if len(id_sel) == 0: + continue + corners = trimesh.bounds.corners(self.proxy_geometry[inst_id].bounds) + corners = torch.tensor(corners, dtype=torch.float32, device=device) + near_far = get_near_far(corners, field2cam_mat[id_sel], tol_fac=1.5) + near_far_all.append(near_far) + near_far = torch.cat(near_far_all, 0) + return near_far diff --git a/lab4d/nnutils/deformable.py b/lab4d/nnutils/deformable.py index 6410c13..0fe1b24 100644 --- a/lab4d/nnutils/deformable.py +++ b/lab4d/nnutils/deformable.py @@ -205,7 +205,7 @@ def gauss_skin_consistency_loss(self, nsample=2048): Returns: loss: (0,) Skinning consistency loss """ - pts = self.sample_points_aabb(nsample, extend_factor=0.25) + pts, _, _ = self.sample_points_aabb(nsample, extend_factor=0.25) # match the gauss density to the reconstructed density density_gauss = self.warp.get_gauss_density(pts) # (N,1) @@ -244,10 +244,7 @@ def soft_deform_loss(self, nsample=1024): Returns: loss: (0,) Soft deformation loss """ - device = next(self.parameters()).device - pts = self.sample_points_aabb(nsample, extend_factor=1.0) - frame_id = torch.randint(0, self.num_frames, (nsample,), device=device) - inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) + pts, frame_id, inst_id = self.sample_points_aabb(nsample, extend_factor=1.0) dist2 = self.warp.compute_post_warp_dist2(pts[:, None, None], frame_id, inst_id) return dist2.mean() diff --git a/lab4d/nnutils/multifields.py b/lab4d/nnutils/multifields.py index 08c689a..e2447fa 100644 --- a/lab4d/nnutils/multifields.py +++ b/lab4d/nnutils/multifields.py @@ -8,6 +8,7 @@ from lab4d.nnutils.deformable import Deformable from lab4d.nnutils.nerf import NeRF +from lab4d.nnutils.bgnerf import BGNeRF from lab4d.nnutils.pose import ArticulationSkelMLP from lab4d.nnutils.warping import ComposedWarp, SkinningWarp from lab4d.utils.quat_transform import quaternion_translation_to_se3 @@ -84,11 +85,13 @@ def define_field(self, category, data_info, tracklet_id): ) # no directional encoding elif category == "bg": - nerf = NeRF( + # nerf = NeRF( + nerf = BGNeRF( data_info, num_freq_xyz=6, num_freq_dir=0, appr_channels=0, + num_inst=self.num_inst, init_scale=0.1, ) else: # exit with an error @@ -155,7 +158,7 @@ def extract_canonical_meshes( grid_size (int): Marching cubes resolution level (float): Contour value to search for isosurfaces on the signed distance function - inst_id: (M,) Instance id. If None, extract for the average instance + inst_id: (int) Instance id. If None, extract for the average instance use_visibility (bool): If True, use visibility mlp to mask out invisible region. use_extend_aabb (bool): If True, extend aabb by 50% to get a loose proxy. @@ -184,7 +187,7 @@ def export_geometry_aux(self, path): """ for category, field in self.field_params.items(): # print(field.near_far) - mesh_geo = field.proxy_geometry + mesh_geo = field.get_proxy_geometry() quat, trans = field.camera_mlp.get_vals() rtmat = quaternion_translation_to_se3(quat, trans).cpu() # evenly pick max 200 cameras diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index aeaba50..ae1f3ac 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -76,6 +76,7 @@ def __init__( init_beta=0.1, init_scale=0.1, color_act=True, + field_arch=CondMLP, ): rtmat = data_info["rtmat"] frame_info = data_info["frame_info"] @@ -95,7 +96,7 @@ def __init__( # xyz encoding layers # TODO: add option to replace with instNGP - self.basefield = CondMLP( + self.basefield = field_arch( num_inst=self.num_inst, D=D, W=W, @@ -109,7 +110,7 @@ def __init__( # color self.pos_embedding_color = PosEmbedding(3, num_freq_xyz + 2) - self.colorfield = CondMLP( + self.colorfield = field_arch( num_inst=self.num_inst, D=2, W=W, @@ -153,10 +154,9 @@ def __init__( # visibility mlp self.vis_mlp = VisField(self.num_inst) - # load initial mesh + # load initial mesh, define aabb self.init_proxy(geom_path, init_scale) - self.register_buffer("aabb", torch.zeros(2, 3)) - self.update_aabb(beta=0) + self.init_aabb() # non-parameters are not synchronized self.register_buffer("near_far", torch.zeros(len(rtmat), 2), persistent=False) @@ -235,16 +235,29 @@ def mlp_init(self): self.geometry_init(sdf_fn_torch) def init_proxy(self, geom_path, init_scale): - """Initialize the geometry from a mesh + """Initialize proxy geometry as a sphere Args: - geom_path (str): Initial shape mesh - init_scale (float): Geometry scale factor + geom_path (str): Unused + init_scale (float): Unused """ - mesh = trimesh.load(geom_path) + mesh = trimesh.load(geom_path[0]) mesh.vertices = mesh.vertices * init_scale self.proxy_geometry = mesh + def get_proxy_geometry(self): + """Get proxy geometry + + Returns: + proxy_geometry (Trimesh): Proxy geometry + """ + return self.proxy_geometry + + def init_aabb(self): + """Initialize axis-aligned bounding box""" + self.register_buffer("aabb", torch.zeros(2, 3)) + self.update_aabb(beta=0) + def geometry_init(self, sdf_fn, nsample=256): """Initialize SDF using tsdf-fused geometry if radius is not given. Otherwise, initialize sdf using a unit sphere @@ -265,7 +278,7 @@ def geometry_init(self, sdf_fn, nsample=256): inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) # sample points - pts = self.sample_points_aabb(nsample, extend_factor=0.25) + pts, _, _ = self.sample_points_aabb(nsample, extend_factor=0.25) # get sdf from proxy geometry sdf_gt = sdf_fn(pts) @@ -312,7 +325,7 @@ def extract_canonical_mesh( grid_size (int): Marching cubes resolution level (float): Contour value to search for isosurfaces on the signed distance function - inst_id: (M,) Instance id. If None, extract for the average instance + inst_id: (int) Instance id. If None, extract for the average instance use_visibility (bool): If True, use visibility mlp to mask out invisible region. use_extend_aabb (bool): If True, extend aabb by 50% to get a loose proxy. @@ -322,12 +335,13 @@ def extract_canonical_mesh( """ if inst_id is not None: inst_id = torch.tensor([inst_id], device=next(self.parameters()).device) + aabb = self.get_aabb(inst_id=inst_id)[0] # 2,3 + else: + aabb = self.get_aabb() sdf_func = lambda xyz: self.forward(xyz, inst_id=inst_id, get_density=False) vis_func = lambda xyz: self.vis_mlp(xyz, inst_id=inst_id) > 0 if use_extend_aabb: - aabb = extend_aabb(self.aabb, factor=0.5) - else: - aabb = self.aabb + aabb = extend_aabb(aabb, factor=0.5) mesh = marching_cubes( sdf_func, aabb, @@ -338,6 +352,18 @@ def extract_canonical_mesh( ) return mesh + def get_aabb(self, inst_id=None): + """Get axis-aligned bounding box + Args: + inst_id: (N,) Instance id + Returns: + aabb: (2,3) Axis-aligned bounding box if inst_id is None, (N,2,3) otherwise + """ + if inst_id is None: + return self.aabb + else: + return self.aabb[None].repeat(len(inst_id), 1, 1) + def update_aabb(self, beta=0.9): """Update axis-aligned bounding box by interpolating with the current proxy geometry's bounds @@ -380,13 +406,16 @@ def sample_points_aabb(self, nsample, extend_factor=1.0): pts: (nsample, 3) Sampled points """ device = next(self.parameters()).device - aabb = extend_aabb(self.aabb, factor=extend_factor) + frame_id = torch.randint(0, self.num_frames, (nsample,), device=device) + inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) + aabb = self.get_aabb(inst_id=inst_id) + aabb = extend_aabb(aabb, factor=extend_factor) pts = ( torch.rand(nsample, 3, dtype=torch.float32, device=device) - * (aabb[1:] - aabb[:1]) - + aabb[:1] + * (aabb[..., 1, :] - aabb[..., 0, :]) + + aabb[..., 0, :] ) - return pts + return pts, frame_id, inst_id def visibility_decay_loss(self, nsample=512): """Encourage visibility to be low at random points within the aabb. The @@ -398,9 +427,7 @@ def visibility_decay_loss(self, nsample=512): loss: (0,) Visibility decay loss """ # sample random points - device = next(self.parameters()).device - pts = self.sample_points_aabb(nsample) - inst_id = torch.randint(0, self.num_inst, (nsample,), device=device) + pts, _, inst_id = self.sample_points_aabb(nsample) # evaluate loss vis = self.vis_mlp(pts, inst_id=inst_id) @@ -491,7 +518,8 @@ def get_valid_idx(self, xyz, xyz_t=None, vis_score=None, samples_dict={}): valid_idx: (M,N,D) Visibility mask, bool """ # check whether the point is inside the aabb - aabb = extend_aabb(self.aabb) + aabb = self.get_aabb(samples_dict["inst_id"]) + aabb = extend_aabb(aabb) # (M,N,D), whether the point is inside the aabb inside_aabb = check_inside_aabb(xyz, aabb) @@ -505,7 +533,7 @@ def get_valid_idx(self, xyz, xyz_t=None, vis_score=None, samples_dict={}): )[1][0] t_aabb = torch.stack([t_bones.min(0)[0], t_bones.max(0)[0]], 0) t_aabb = extend_aabb(t_aabb, factor=1.0) - inside_aabb = check_inside_aabb(xyz_t, t_aabb) + inside_aabb = check_inside_aabb(xyz_t, t_aabb[None]) valid_idx = valid_idx & inside_aabb # temporally disable visibility mask @@ -546,10 +574,7 @@ def get_samples(self, Kinv, batch): near_far = self.near_far.to(device) near_far = near_far[batch["frameid"]] else: - corners = trimesh.bounds.corners(self.proxy_geometry.bounds) - corners = torch.tensor(corners, dtype=torch.float32, device=device) - field2cam_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1]) - near_far = get_near_far(corners, field2cam_mat, tol_fac=1.5) + near_far = self.get_near_far(frame_id, field2cam) # auxiliary outputs samples_dict = {} @@ -564,6 +589,14 @@ def get_samples(self, Kinv, batch): samples_dict["feature"] = batch["feature"] return samples_dict + def get_near_far(self, frame_id, field2cam): + device = next(self.parameters()).device + corners = trimesh.bounds.corners(self.proxy_geometry.bounds) + corners = torch.tensor(corners, dtype=torch.float32, device=device) + field2cam_mat = quaternion_translation_to_se3(field2cam[0], field2cam[1]) + near_far = get_near_far(corners, field2cam_mat, tol_fac=1.5) + return near_far + def query_field(self, samples_dict, flow_thresh=None): """Render outputs from a neural radiance field. diff --git a/lab4d/nnutils/transformer.py b/lab4d/nnutils/transformer.py new file mode 100644 index 0000000..c010f51 --- /dev/null +++ b/lab4d/nnutils/transformer.py @@ -0,0 +1,191 @@ +# Taken from https://github.com/stelzner/srt/tree/main +import torch +import torch.nn as nn +import numpy as np + +import math +from einops import rearrange + + +class PositionalEncoding(nn.Module): + def __init__(self, num_octaves=8, start_octave=0): + super().__init__() + self.num_octaves = num_octaves + self.start_octave = start_octave + + def forward(self, coords, rays=None): + embed_fns = [] + batch_size, num_points, dim = coords.shape + + octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves) + octaves = octaves.float().to(coords) + multipliers = 2**octaves * math.pi + coords = coords.unsqueeze(-1) + while len(multipliers.shape) < len(coords.shape): + multipliers = multipliers.unsqueeze(0) + + scaled_coords = coords * multipliers + + sines = torch.sin(scaled_coords).reshape( + batch_size, num_points, dim * self.num_octaves + ) + cosines = torch.cos(scaled_coords).reshape( + batch_size, num_points, dim * self.num_octaves + ) + + result = torch.cat((sines, cosines), -1) + return result + + +class RayEncoder(nn.Module): + def __init__( + self, pos_octaves=8, pos_start_octave=0, ray_octaves=4, ray_start_octave=0 + ): + super().__init__() + self.pos_encoding = PositionalEncoding( + num_octaves=pos_octaves, start_octave=pos_start_octave + ) + self.ray_encoding = PositionalEncoding( + num_octaves=ray_octaves, start_octave=ray_start_octave + ) + + def forward(self, pos, rays): + if len(rays.shape) == 4: + batchsize, height, width, dims = rays.shape + pos_enc = self.pos_encoding(pos.unsqueeze(1)) + pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1) + pos_enc = pos_enc.repeat(1, 1, height, width) + rays = rays.flatten(1, 2) + + ray_enc = self.ray_encoding(rays) + ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1]) + ray_enc = ray_enc.permute((0, 3, 1, 2)) + x = torch.cat((pos_enc, ray_enc), 1) + else: + pos_enc = self.pos_encoding(pos) + ray_enc = self.ray_encoding(rays) + x = torch.cat((pos_enc, ray_enc), -1) + + return x + + +# Transformer implementation based on ViT +# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + def __init__( + self, dim, heads=8, dim_head=64, dropout=0.0, selfatt=True, kv_dim=None + ): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + + self.attend = nn.Softmax(dim=-1) + if selfatt: + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + else: + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out + else nn.Identity() + ) + + def forward(self, x, z=None): + if z is None: + qkv = self.to_qkv(x).chunk(3, dim=-1) + else: + q = self.to_q(x) + k, v = self.to_kv(z).chunk(2, dim=-1) + qkv = (q, k, v) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class Transformer(nn.Module): + def __init__( + self, + dim, + depth, + heads, + dim_head, + mlp_dim, + dropout=0.0, + selfatt=True, + kv_dim=None, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PreNorm( + dim, + Attention( + dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + selfatt=selfatt, + kv_dim=kv_dim, + ), + ), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)), + ] + ) + ) + + def forward(self, x, z=None): + shape = x.shape[:-1] + if len(shape) > 2: + x = rearrange(x, "... c d -> (...) c d") + z = rearrange(z, "... c d -> (...) c d") + elif len(shape) == 1: + x = x.unsqueeze(1) + z = z.unsqueeze(1) + for attn, ff in self.layers: + x = attn(x, z=z) + x + x = ff(x) + x + + x = x.view(shape + (x.shape[-1],)) + return x diff --git a/lab4d/utils/geom_utils.py b/lab4d/utils/geom_utils.py index d79f924..0219936 100644 --- a/lab4d/utils/geom_utils.py +++ b/lab4d/utils/geom_utils.py @@ -402,14 +402,15 @@ def extend_aabb(aabb, factor=0.1): If aabb = [-1,1] and factor = 1, the extended aabb will be [-3,3] Args: - aabb: Axis-aligned bounding box, (2,3) + aabb: Axis-aligned bounding box, (...,2,3) factor (float): Amount to extend on each side Returns: - aabb_new: Extended aabb, (2,3) + aabb_new: Extended aabb, (...,2,3) """ aabb_new = aabb.clone() - aabb_new[0] = aabb[0] - (aabb[1] - aabb[0]) * factor - aabb_new[1] = aabb[1] + (aabb[1] - aabb[0]) * factor + size = (aabb[..., 1, :] - aabb[..., 0, :]) * factor + aabb_new[..., 0, :] = aabb[..., 0, :] - size + aabb_new[..., 1, :] = aabb[..., 1, :] + size return aabb_new @@ -498,11 +499,13 @@ def check_inside_aabb(xyz, aabb): """Return a mask of whether the input poins are inside the aabb Args: - xyz: (N,3) Points in object canonical space to query - aabb: (2,3) axis-aligned bounding box + xyz: (N,...,3) Points in object canonical space to query + aabb: (N,2,3) axis-aligned bounding box Returns: - inside_aabb: (N) Inside mask, bool + inside_aabb: (N,...) Inside mask, bool """ # check whether the point is inside the aabb - inside_aabb = ((xyz > aabb[:1]) & (xyz < aabb[1:])).all(-1) + shape = xyz.shape[:-1] + aabb = aabb.view((aabb.shape[0], 2) + (1,) * (len(shape) - 1) + (3,)) + inside_aabb = ((xyz > aabb[:, 0]) & (xyz < aabb[:, 1])).all(-1) return inside_aabb From c226478ce022bdc99544c7b5a0263b81ddf7b770 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Mon, 24 Jul 2023 14:53:34 -0400 Subject: [PATCH 02/86] slightly speed up of mmlp --- lab4d/nnutils/base.py | 59 ++++++++++++++++++++++++++++------------- lab4d/nnutils/bgnerf.py | 4 +-- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/lab4d/nnutils/base.py b/lab4d/nnutils/base.py index daad8b6..eb354af 100644 --- a/lab4d/nnutils/base.py +++ b/lab4d/nnutils/base.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from functorch import vmap, combine_state_for_ensemble from lab4d.nnutils.embedding import InstEmbedding from lab4d.nnutils.transformer import Transformer @@ -255,13 +254,18 @@ class MultiMLP(nn.Module): def __init__(self, num_inst, inst_channels=32, **kwargs): super(MultiMLP, self).__init__() + self.in_channels = kwargs["in_channels"] self.out_channels = kwargs["out_channels"] self.num_inst = num_inst + # ensemble version self.nets = [] for i in range(num_inst): self.nets.append(BaseMLP(**kwargs)) self.nets = nn.ModuleList(self.nets) + # # linear version + # self.linear = nn.Linear(kwargs["in_channels"], num_inst*kwargs["out_channels"]) + def forward(self, feat, inst_id): """ Args: @@ -272,32 +276,51 @@ def forward(self, feat, inst_id): """ # rearrange the batch dimension shape = feat.shape[:-1] - out = torch.zeros(shape + (self.out_channels,), device=feat.device).view( - -1, self.out_channels - ) - if inst_id is None: - return out.view(shape + (self.out_channels,)) - - feat = feat.view(-1, feat.shape[-1]) inst_id = inst_id.view((-1,) + (1,) * (len(shape) - 1)) inst_id = inst_id.expand(shape) - inst_id = inst_id.reshape(-1) - # do the real work - - empty_input = torch.zeros_like(feat[:1]) - for it in range(self.num_inst): + # # linear version: with duplicate computation + # out_stacked = self.linear(feat).view(shape + (self.out_channels, self.num_inst)) # (M, ..., self.out_channels, self.num_inst) + # # Construct an index tensor + # index = inst_id.unsqueeze(-1).expand(shape + (self.out_channels,)).unsqueeze(-1) + # # Gather elements from out_stacked using the index tensor + # out = torch.gather(out_stacked, -1, index).squeeze(-1) + # return out + + # # sequential version: with duplicate computation + # out = torch.zeros(shape + (self.out_channels,), device=feat.device) + # for it, net in enumerate(self.nets): + # id_sel = inst_id == it + # out[id_sel] = net(feat)[id_sel] + # return out + + # sequential version: avoid duplicate computation + out = torch.zeros(shape + (self.out_channels,), device=feat.device) + empty_input = torch.zeros(1,1,self.in_channels, device=feat.device) + for it, net in enumerate(self.nets): id_sel = inst_id == it if id_sel.sum() == 0: - # to avoid error in ddp - # Expected to have finished reduction in the prior iteration before starting a new one. out = out + self.nets[it](empty_input).mean() * 0 continue - x_sel = feat[id_sel] - out[id_sel] = self.nets[it](x_sel) - out = out.view(shape + (self.out_channels,)) + out[id_sel] = net(feat[id_sel]) return out + # # parallel version with issue with ddp; slow + # def wrapper(params, buffers, data): + # return torch.func.functional_call(self.nets[0], (params, buffers), (data,)) + + # params, buffers = torch.func.stack_module_state(self.nets) + # out_stacked = torch.vmap(wrapper, (0, 0, None))(params, buffers, feat) + + # # Construct an index tensor + # index = inst_id.unsqueeze(-1).expand(shape + (self.out_channels,)).unsqueeze(0) + # # Gather elements from out_stacked using the index tensor + # out = torch.gather(out_stacked, 0, index).squeeze(0) + + # empty_input = torch.zeros(1,1,self.in_channels, device=feat.device) + # for net in self.nets: + # out += net(empty_input).mean()*0 + # return out class MixMLP(nn.Module): """Mixing CondMLP and MultiMLP""" diff --git a/lab4d/nnutils/bgnerf.py b/lab4d/nnutils/bgnerf.py index 4e72564..0ae26c5 100644 --- a/lab4d/nnutils/bgnerf.py +++ b/lab4d/nnutils/bgnerf.py @@ -15,8 +15,8 @@ class BGNeRF(NeRF): """A static neural radiance field with an MLP backbone. Specialized to background.""" # def __init__(self, data_info, field_arch=CondTransformerMLP, D=5, W=128, **kwargs): - # def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): - def __init__(self, data_info, field_arch=CondMLP, D=5, W=128, **kwargs): + # def __init__(self, data_info, field_arch=CondMLP, D=5, W=128, **kwargs): + def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): super(BGNeRF, self).__init__( data_info, field_arch=field_arch, D=D, W=W, **kwargs ) From acb391db96807f6a311adc5ed3330475625d84cb Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Mon, 24 Jul 2023 15:08:43 -0400 Subject: [PATCH 03/86] remove unused files --- lab4d/nnutils/base.py | 94 +---------------- lab4d/nnutils/bgnerf.py | 12 +-- lab4d/nnutils/transformer.py | 191 ----------------------------------- 3 files changed, 7 insertions(+), 290 deletions(-) delete mode 100644 lab4d/nnutils/transformer.py diff --git a/lab4d/nnutils/base.py b/lab4d/nnutils/base.py index eb354af..0d2a3e1 100644 --- a/lab4d/nnutils/base.py +++ b/lab4d/nnutils/base.py @@ -4,7 +4,6 @@ import torch.nn.functional as F from lab4d.nnutils.embedding import InstEmbedding -from lab4d.nnutils.transformer import Transformer class ScaleLayer(nn.Module): @@ -159,96 +158,6 @@ def get_dim_inst(num_inst, inst_channels): return 0 -class CondTransformerMLP(BaseMLP): - """A MLP that accepts both input `x` and condition `c` - - Args: - num_inst (int): Number of distinct object instances. If --nosingle_inst - is passed, this is equal to the number of videos, as we assume each - video captures a different instance. Otherwise, we assume all videos - capture the same instance and set this to 1. - D (int): Number of linear layers for density (sigma) encoder - W (int): Number of hidden units in each MLP layer - in_channels (int): Number of channels in input `x` - inst_channels (int): Number of channels in condition `c` - out_channels (int): Number of output channels - skips (List(int)): List of layers to add skip connections at - activation (Function): Activation function to use (e.g. nn.ReLU()) - final_act (bool): If True, apply the activation function to the output - """ - - def __init__( - self, - num_inst, - D=8, - W=256, - in_channels=63, - inst_channels=32, - out_channels=3, - skips=[4], - activation=nn.ReLU(True), - final_act=False, - ): - inst_channels = 768 - super().__init__( - D=D, - W=W, - in_channels=in_channels, - out_channels=out_channels, - skips=skips, - activation=activation, - final_act=final_act, - ) - self.inst_embedding = InstEmbedding(num_inst, inst_channels) - - self.transformer = Transformer( - in_channels, - depth=1, - heads=12, - dim_head=inst_channels // 12, - mlp_dim=inst_channels * 2, - selfatt=False, - kv_dim=inst_channels, - ) - - def forward(self, feat, inst_id): - """ - Args: - feat: (M, ..., self.in_channels) - inst_id: (M,) Instance id, or None to use the average instance - Returns: - out: (M, ..., self.out_channels) - """ - if inst_id is None: - if self.inst_embedding.out_channels > 0: - inst_code = self.inst_embedding.get_mean_embedding() - inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) - # print("inst_embedding exists but inst_id is None, using mean inst_code") - else: - # empty, falls back to single-instance NeRF - inst_code = torch.zeros(feat.shape[:-1] + (0,), device=feat.device) - else: - inst_code = self.inst_embedding(inst_id) - inst_code = inst_code.view( - inst_code.shape[:1] + (1,) * (feat.ndim - 2) + (-1,) - ) - inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) - - # feat = torch.cat([feat, inst_code], -1) - # if both input feature and inst_code are empty, return zeros - if feat.shape[-1] == 0 and inst_code.shape[-1] == 0: - return feat - feat = self.transformer(feat, inst_code) - return super().forward(feat) - - @staticmethod - def get_dim_inst(num_inst, inst_channels): - if num_inst > 1: - return inst_channels - else: - return 0 - - class MultiMLP(nn.Module): """Independent MLP for each instance""" @@ -296,7 +205,7 @@ def forward(self, feat, inst_id): # sequential version: avoid duplicate computation out = torch.zeros(shape + (self.out_channels,), device=feat.device) - empty_input = torch.zeros(1,1,self.in_channels, device=feat.device) + empty_input = torch.zeros(1, 1, self.in_channels, device=feat.device) for it, net in enumerate(self.nets): id_sel = inst_id == it if id_sel.sum() == 0: @@ -322,6 +231,7 @@ def forward(self, feat, inst_id): # out += net(empty_input).mean()*0 # return out + class MixMLP(nn.Module): """Mixing CondMLP and MultiMLP""" diff --git a/lab4d/nnutils/bgnerf.py b/lab4d/nnutils/bgnerf.py index 0ae26c5..3c40f96 100644 --- a/lab4d/nnutils/bgnerf.py +++ b/lab4d/nnutils/bgnerf.py @@ -7,21 +7,20 @@ from lab4d.utils.quat_transform import quaternion_translation_to_se3 from lab4d.utils.geom_utils import get_near_far -from lab4d.nnutils.base import MixMLP, MultiMLP, CondMLP, CondTransformerMLP +from lab4d.nnutils.base import MixMLP, MultiMLP, CondMLP from lab4d.nnutils.visibility import VisField class BGNeRF(NeRF): """A static neural radiance field with an MLP backbone. Specialized to background.""" - # def __init__(self, data_info, field_arch=CondTransformerMLP, D=5, W=128, **kwargs): # def __init__(self, data_info, field_arch=CondMLP, D=5, W=128, **kwargs): def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): super(BGNeRF, self).__init__( data_info, field_arch=field_arch, D=D, W=W, **kwargs ) - # TODO: update beta - # TODO: update scale + # TODO: update per-scene beta + # TODO: update per-scene scale def init_proxy(self, geom_paths, init_scale): """Initialize the geometry from a mesh @@ -32,9 +31,8 @@ def init_proxy(self, geom_paths, init_scale): """ meshes = [] for geom_path in geom_paths: - mesh = trimesh.creation.uv_sphere(radius=0.12, count=[4, 4]) - # mesh = trimesh.load(geom_path) - # mesh.vertices = mesh.vertices * init_scale + mesh = trimesh.load(geom_path) + mesh.vertices = mesh.vertices * init_scale meshes.append(mesh) self.proxy_geometry = meshes diff --git a/lab4d/nnutils/transformer.py b/lab4d/nnutils/transformer.py deleted file mode 100644 index c010f51..0000000 --- a/lab4d/nnutils/transformer.py +++ /dev/null @@ -1,191 +0,0 @@ -# Taken from https://github.com/stelzner/srt/tree/main -import torch -import torch.nn as nn -import numpy as np - -import math -from einops import rearrange - - -class PositionalEncoding(nn.Module): - def __init__(self, num_octaves=8, start_octave=0): - super().__init__() - self.num_octaves = num_octaves - self.start_octave = start_octave - - def forward(self, coords, rays=None): - embed_fns = [] - batch_size, num_points, dim = coords.shape - - octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves) - octaves = octaves.float().to(coords) - multipliers = 2**octaves * math.pi - coords = coords.unsqueeze(-1) - while len(multipliers.shape) < len(coords.shape): - multipliers = multipliers.unsqueeze(0) - - scaled_coords = coords * multipliers - - sines = torch.sin(scaled_coords).reshape( - batch_size, num_points, dim * self.num_octaves - ) - cosines = torch.cos(scaled_coords).reshape( - batch_size, num_points, dim * self.num_octaves - ) - - result = torch.cat((sines, cosines), -1) - return result - - -class RayEncoder(nn.Module): - def __init__( - self, pos_octaves=8, pos_start_octave=0, ray_octaves=4, ray_start_octave=0 - ): - super().__init__() - self.pos_encoding = PositionalEncoding( - num_octaves=pos_octaves, start_octave=pos_start_octave - ) - self.ray_encoding = PositionalEncoding( - num_octaves=ray_octaves, start_octave=ray_start_octave - ) - - def forward(self, pos, rays): - if len(rays.shape) == 4: - batchsize, height, width, dims = rays.shape - pos_enc = self.pos_encoding(pos.unsqueeze(1)) - pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1) - pos_enc = pos_enc.repeat(1, 1, height, width) - rays = rays.flatten(1, 2) - - ray_enc = self.ray_encoding(rays) - ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1]) - ray_enc = ray_enc.permute((0, 3, 1, 2)) - x = torch.cat((pos_enc, ray_enc), 1) - else: - pos_enc = self.pos_encoding(pos) - ray_enc = self.ray_encoding(rays) - x = torch.cat((pos_enc, ray_enc), -1) - - return x - - -# Transformer implementation based on ViT -# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py - - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.norm = nn.LayerNorm(dim) - self.fn = fn - - def forward(self, x, **kwargs): - return self.fn(self.norm(x), **kwargs) - - -class FeedForward(nn.Module): - def __init__(self, dim, hidden_dim, dropout=0.0): - super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim, dim), - nn.Dropout(dropout), - ) - - def forward(self, x): - return self.net(x) - - -class Attention(nn.Module): - def __init__( - self, dim, heads=8, dim_head=64, dropout=0.0, selfatt=True, kv_dim=None - ): - super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == dim) - - self.heads = heads - self.scale = dim_head**-0.5 - - self.attend = nn.Softmax(dim=-1) - if selfatt: - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - else: - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(kv_dim, inner_dim * 2, bias=False) - - self.to_out = ( - nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) - if project_out - else nn.Identity() - ) - - def forward(self, x, z=None): - if z is None: - qkv = self.to_qkv(x).chunk(3, dim=-1) - else: - q = self.to_q(x) - k, v = self.to_kv(z).chunk(2, dim=-1) - qkv = (q, k, v) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) - - dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - - attn = self.attend(dots) - - out = torch.matmul(attn, v) - out = rearrange(out, "b h n d -> b n (h d)") - return self.to_out(out) - - -class Transformer(nn.Module): - def __init__( - self, - dim, - depth, - heads, - dim_head, - mlp_dim, - dropout=0.0, - selfatt=True, - kv_dim=None, - ): - super().__init__() - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - PreNorm( - dim, - Attention( - dim, - heads=heads, - dim_head=dim_head, - dropout=dropout, - selfatt=selfatt, - kv_dim=kv_dim, - ), - ), - PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)), - ] - ) - ) - - def forward(self, x, z=None): - shape = x.shape[:-1] - if len(shape) > 2: - x = rearrange(x, "... c d -> (...) c d") - z = rearrange(z, "... c d -> (...) c d") - elif len(shape) == 1: - x = x.unsqueeze(1) - z = z.unsqueeze(1) - for attn, ff in self.layers: - x = attn(x, z=z) + x - x = ff(x) + x - - x = x.view(shape + (x.shape[-1],)) - return x From 4b58b711e36efe4851b093723a9c46b9293ad446 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Mon, 24 Jul 2023 21:23:44 -0400 Subject: [PATCH 04/86] fix export --- lab4d/export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lab4d/export.py b/lab4d/export.py index 2490afc..304aac5 100644 --- a/lab4d/export.py +++ b/lab4d/export.py @@ -68,7 +68,7 @@ def extract_deformation(field, mesh_rest, inst_id, render_length): field2cam = field.camera_mlp.get_vals(frame_id_torch) samples_dict = {} - if isinstance(field.warp, SkinningWarp): + if hasattr(field, "warp") and isinstance(field.warp, SkinningWarp): ( samples_dict["t_articulation"], samples_dict["rest_articulation"], @@ -115,7 +115,7 @@ def extract_deformation(field, mesh_rest, inst_id, render_length): ) motion_tuples[frame_id] = motion_expl - if isinstance(field.warp, SkinningWarp): + if hasattr(field, "warp") and isinstance(field.warp, SkinningWarp): # modify rest mesh based on instance morphological changes on bones # idendity transformation of cameras field2cam_rot_idn = torch.zeros_like(field2cam[0]) From d6868ebd6eff6eac131a06352818d41dfeae3411 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Wed, 26 Jul 2023 18:03:11 -0400 Subject: [PATCH 05/86] update ensemble --- lab4d/nnutils/base.py | 46 +++++++++++++++++++++++++++++++++++++++++ lab4d/nnutils/bgnerf.py | 4 +++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/lab4d/nnutils/base.py b/lab4d/nnutils/base.py index 0d2a3e1..534d4c6 100644 --- a/lab4d/nnutils/base.py +++ b/lab4d/nnutils/base.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from lab4d.nnutils.embedding import InstEmbedding +from functorch import vmap, combine_state_for_ensemble class ScaleLayer(nn.Module): @@ -158,6 +159,18 @@ def get_dim_inst(num_inst, inst_channels): return 0 +class Ensemble(nn.Module): + def __init__(self, modules, **kwargs): + super().__init__() + fmodel, self.params, self.buffers = combine_state_for_ensemble(modules) + self.vmap_model = vmap(fmodel, **kwargs) + self.params = nn.ParameterList([nn.Parameter(p) for p in self.params]) + + def forward(self, *args, **kwargs): + params = [i for i in self.params] + return self.vmap_model(params, self.buffers, *args, **kwargs) + + class MultiMLP(nn.Module): """Independent MLP for each instance""" @@ -170,6 +183,7 @@ def __init__(self, num_inst, inst_channels=32, **kwargs): self.nets = [] for i in range(num_inst): self.nets.append(BaseMLP(**kwargs)) + # self.nets = Ensemble(self.nets) self.nets = nn.ModuleList(self.nets) # # linear version @@ -185,6 +199,7 @@ def forward(self, feat, inst_id): """ # rearrange the batch dimension shape = feat.shape[:-1] + device = feat.device inst_id = inst_id.view((-1,) + (1,) * (len(shape) - 1)) inst_id = inst_id.expand(shape) @@ -231,6 +246,37 @@ def forward(self, feat, inst_id): # out += net(empty_input).mean()*0 # return out + # # parallel version, slow due to variable size index + # num_elem = inst_id.numel() + # feat = feat.view(num_elem, -1) + # inst_id = inst_id.reshape(-1) + + # # in: feat: K, in_channels + # # in: inst_id: K + # # Get the counts for each instance + # counts = torch.bincount(inst_id) + # max_bs = counts.max() + + # feat_padded = torch.zeros((self.num_inst, max_bs, feat.shape[1]), device=device) + # id_sels = torch.stack([(inst_id == it) for it in range(self.num_inst)], 0) + # # feat_paded: N, M, in_channels + # # id_sels, N, K, 1 + # # feat: K, in_channels + # for it in range(self.num_inst): + # feal_sel = feat[id_sels[it]] + # feat_padded[it, : feal_sel.shape[0]] = feal_sel + + # # run network + # out_padded = self.nets(feat_padded) + + # # out_padded: N,K, ..., out_channels + # out = torch.zeros((num_elem, self.out_channels), device=feat.device) + # for it in range(self.num_inst): + # id_sel = id_sels[it] + # out[id_sel] = out_padded[it][: id_sel.sum()] + # out = out.reshape(shape + (self.out_channels,)) + # return out + class MixMLP(nn.Module): """Mixing CondMLP and MultiMLP""" diff --git a/lab4d/nnutils/bgnerf.py b/lab4d/nnutils/bgnerf.py index 3c40f96..0ba2120 100644 --- a/lab4d/nnutils/bgnerf.py +++ b/lab4d/nnutils/bgnerf.py @@ -14,7 +14,9 @@ class BGNeRF(NeRF): """A static neural radiance field with an MLP backbone. Specialized to background.""" - # def __init__(self, data_info, field_arch=CondMLP, D=5, W=128, **kwargs): + # def __init__( + # self, data_info, field_arch=CondMLP, D=8, W=256, inst_channels=256, **kwargs + # ): def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): super(BGNeRF, self).__init__( data_info, field_arch=field_arch, D=D, W=W, **kwargs From 8d28d061f8e19ab2df93212e83ae0b20ed1b9821 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Sun, 30 Jul 2023 17:09:56 -0400 Subject: [PATCH 06/86] clean up --- lab4d/nnutils/base.py | 79 ------------------------------------ lab4d/nnutils/multifields.py | 2 +- lab4d/nnutils/nerf.py | 8 ++-- lab4d/utils/geom_utils.py | 6 +-- 4 files changed, 8 insertions(+), 87 deletions(-) diff --git a/lab4d/nnutils/base.py b/lab4d/nnutils/base.py index 534d4c6..67f7ef0 100644 --- a/lab4d/nnutils/base.py +++ b/lab4d/nnutils/base.py @@ -159,18 +159,6 @@ def get_dim_inst(num_inst, inst_channels): return 0 -class Ensemble(nn.Module): - def __init__(self, modules, **kwargs): - super().__init__() - fmodel, self.params, self.buffers = combine_state_for_ensemble(modules) - self.vmap_model = vmap(fmodel, **kwargs) - self.params = nn.ParameterList([nn.Parameter(p) for p in self.params]) - - def forward(self, *args, **kwargs): - params = [i for i in self.params] - return self.vmap_model(params, self.buffers, *args, **kwargs) - - class MultiMLP(nn.Module): """Independent MLP for each instance""" @@ -183,12 +171,8 @@ def __init__(self, num_inst, inst_channels=32, **kwargs): self.nets = [] for i in range(num_inst): self.nets.append(BaseMLP(**kwargs)) - # self.nets = Ensemble(self.nets) self.nets = nn.ModuleList(self.nets) - # # linear version - # self.linear = nn.Linear(kwargs["in_channels"], num_inst*kwargs["out_channels"]) - def forward(self, feat, inst_id): """ Args: @@ -203,21 +187,6 @@ def forward(self, feat, inst_id): inst_id = inst_id.view((-1,) + (1,) * (len(shape) - 1)) inst_id = inst_id.expand(shape) - # # linear version: with duplicate computation - # out_stacked = self.linear(feat).view(shape + (self.out_channels, self.num_inst)) # (M, ..., self.out_channels, self.num_inst) - # # Construct an index tensor - # index = inst_id.unsqueeze(-1).expand(shape + (self.out_channels,)).unsqueeze(-1) - # # Gather elements from out_stacked using the index tensor - # out = torch.gather(out_stacked, -1, index).squeeze(-1) - # return out - - # # sequential version: with duplicate computation - # out = torch.zeros(shape + (self.out_channels,), device=feat.device) - # for it, net in enumerate(self.nets): - # id_sel = inst_id == it - # out[id_sel] = net(feat)[id_sel] - # return out - # sequential version: avoid duplicate computation out = torch.zeros(shape + (self.out_channels,), device=feat.device) empty_input = torch.zeros(1, 1, self.in_channels, device=feat.device) @@ -229,54 +198,6 @@ def forward(self, feat, inst_id): out[id_sel] = net(feat[id_sel]) return out - # # parallel version with issue with ddp; slow - # def wrapper(params, buffers, data): - # return torch.func.functional_call(self.nets[0], (params, buffers), (data,)) - - # params, buffers = torch.func.stack_module_state(self.nets) - # out_stacked = torch.vmap(wrapper, (0, 0, None))(params, buffers, feat) - - # # Construct an index tensor - # index = inst_id.unsqueeze(-1).expand(shape + (self.out_channels,)).unsqueeze(0) - # # Gather elements from out_stacked using the index tensor - # out = torch.gather(out_stacked, 0, index).squeeze(0) - - # empty_input = torch.zeros(1,1,self.in_channels, device=feat.device) - # for net in self.nets: - # out += net(empty_input).mean()*0 - # return out - - # # parallel version, slow due to variable size index - # num_elem = inst_id.numel() - # feat = feat.view(num_elem, -1) - # inst_id = inst_id.reshape(-1) - - # # in: feat: K, in_channels - # # in: inst_id: K - # # Get the counts for each instance - # counts = torch.bincount(inst_id) - # max_bs = counts.max() - - # feat_padded = torch.zeros((self.num_inst, max_bs, feat.shape[1]), device=device) - # id_sels = torch.stack([(inst_id == it) for it in range(self.num_inst)], 0) - # # feat_paded: N, M, in_channels - # # id_sels, N, K, 1 - # # feat: K, in_channels - # for it in range(self.num_inst): - # feal_sel = feat[id_sels[it]] - # feat_padded[it, : feal_sel.shape[0]] = feal_sel - - # # run network - # out_padded = self.nets(feat_padded) - - # # out_padded: N,K, ..., out_channels - # out = torch.zeros((num_elem, self.out_channels), device=feat.device) - # for it in range(self.num_inst): - # id_sel = id_sels[it] - # out[id_sel] = out_padded[it][: id_sel.sum()] - # out = out.reshape(shape + (self.out_channels,)) - # return out - class MixMLP(nn.Module): """Mixing CondMLP and MultiMLP""" diff --git a/lab4d/nnutils/multifields.py b/lab4d/nnutils/multifields.py index e2447fa..8763a35 100644 --- a/lab4d/nnutils/multifields.py +++ b/lab4d/nnutils/multifields.py @@ -426,5 +426,5 @@ def get_aabb(self): """ aabb = {} for cate, field in self.field_params.items(): - aabb[cate] = field.aabb / field.logscale.exp() + aabb[cate] = field.get_aabb() / field.logscale.exp() return aabb diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index ae1f3ac..3eb99c9 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -235,11 +235,11 @@ def mlp_init(self): self.geometry_init(sdf_fn_torch) def init_proxy(self, geom_path, init_scale): - """Initialize proxy geometry as a sphere - + """Initialize the geometry from a mesh + Args: - geom_path (str): Unused - init_scale (float): Unused + geom_path (List(str)): paths to initial shape mesh + init_scale (float): Geometry scale factor """ mesh = trimesh.load(geom_path[0]) mesh.vertices = mesh.vertices * init_scale diff --git a/lab4d/utils/geom_utils.py b/lab4d/utils/geom_utils.py index 0219936..d0543dc 100644 --- a/lab4d/utils/geom_utils.py +++ b/lab4d/utils/geom_utils.py @@ -402,10 +402,10 @@ def extend_aabb(aabb, factor=0.1): If aabb = [-1,1] and factor = 1, the extended aabb will be [-3,3] Args: - aabb: Axis-aligned bounding box, (...,2,3) + aabb: Axis-aligned bounding box, ((N,)2,3) factor (float): Amount to extend on each side Returns: - aabb_new: Extended aabb, (...,2,3) + aabb_new: Extended aabb, ((N,)2,3) """ aabb_new = aabb.clone() size = (aabb[..., 1, :] - aabb[..., 0, :]) * factor @@ -502,7 +502,7 @@ def check_inside_aabb(xyz, aabb): xyz: (N,...,3) Points in object canonical space to query aabb: (N,2,3) axis-aligned bounding box Returns: - inside_aabb: (N,...) Inside mask, bool + inside_aabb: (N,...,) Inside mask, bool """ # check whether the point is inside the aabb shape = xyz.shape[:-1] From dfe367e98d64400bc91fac4fd0a47f397dc12283 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Mon, 31 Jul 2023 21:29:23 -0400 Subject: [PATCH 07/86] dict_mlp; independent vis_field --- lab4d/export.py | 3 +- lab4d/nnutils/base.py | 124 ++++++++++++++++++++++++++++++++++++ lab4d/nnutils/bgnerf.py | 11 ++-- lab4d/nnutils/nerf.py | 4 +- lab4d/nnutils/visibility.py | 3 +- 5 files changed, 136 insertions(+), 9 deletions(-) diff --git a/lab4d/export.py b/lab4d/export.py index 304aac5..d9e0d3f 100644 --- a/lab4d/export.py +++ b/lab4d/export.py @@ -38,6 +38,7 @@ class ExportMeshFlags: flags.DEFINE_float( "level", 0.0, "contour value of marching cubes use to search for isosurfaces" ) + flags.DEFINE_boolean("use_visibility", False, "use visibility to remove extra pts") class MotionParamsExpl(NamedTuple): @@ -179,7 +180,7 @@ def extract_motion_params(model, opts, data_info): grid_size=opts["grid_size"], level=opts["level"], inst_id=opts["inst_id"], - use_visibility=False, + use_visibility=opts["use_visibility"], use_extend_aabb=False, ) diff --git a/lab4d/nnutils/base.py b/lab4d/nnutils/base.py index 67f7ef0..28b11b6 100644 --- a/lab4d/nnutils/base.py +++ b/lab4d/nnutils/base.py @@ -159,6 +159,105 @@ def get_dim_inst(num_inst, inst_channels): return 0 +# class PosEncArch(nn.Module): +# def __init__(self, in_channels, N_freqs) -> None: +# super().__init__() +# self.pos_embedding = PosEmbedding(in_channels, N_freqs) + + +class DictMLP(BaseMLP): + """A MLP that accepts both input `x` and condition `c` + + Args: + num_inst (int): Number of distinct object instances. If --nosingle_inst + is passed, this is equal to the number of videos, as we assume each + video captures a different instance. Otherwise, we assume all videos + capture the same instance and set this to 1. + D (int): Number of linear layers for density (sigma) encoder + W (int): Number of hidden units in each MLP layer + in_channels (int): Number of channels in input `x` + inst_channels (int): Number of channels in condition `c` + out_channels (int): Number of output channels + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. nn.ReLU()) + final_act (bool): If True, apply the activation function to the output + """ + + def __init__( + self, + num_inst, + D=8, + W=256, + in_channels=63, + inst_channels=32, + out_channels=3, + skips=[4], + activation=nn.ReLU(True), + final_act=False, + ): + super().__init__( + D=D, + W=W, + in_channels=in_channels + inst_channels, + out_channels=out_channels, + skips=skips, + activation=activation, + final_act=False, + ) + + self.basis = BaseMLP( + D=D, + W=W, + in_channels=in_channels, + out_channels=out_channels, + skips=skips, + activation=activation, + final_act=final_act, + ) + + self.inst_embedding = InstEmbedding(num_inst, inst_channels) + + def forward(self, feat, inst_id): + """ + Args: + feat: (M, ..., self.in_channels) + inst_id: (M,) Instance id, or None to use the average instance + Returns: + out: (M, ..., self.out_channels) + """ + if inst_id is None: + if self.inst_embedding.out_channels > 0: + inst_code = self.inst_embedding.get_mean_embedding() + inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) + # print("inst_embedding exists but inst_id is None, using mean inst_code") + else: + # empty, falls back to single-instance NeRF + inst_code = torch.zeros(feat.shape[:-1] + (0,), device=feat.device) + else: + inst_code = self.inst_embedding(inst_id) + inst_code = inst_code.view( + inst_code.shape[:1] + (1,) * (feat.ndim - 2) + (-1,) + ) + inst_code = inst_code.expand(feat.shape[:-1] + (-1,)) + + out = torch.cat([feat, inst_code], -1) + # if both input feature and inst_code are empty, return zeros + if out.shape[-1] == 0: + return out + coeff = super().forward(out) + coeff = F.normalize(coeff, dim=-1) + basis = self.basis(feat) + out = coeff * basis + return out + + @staticmethod + def get_dim_inst(num_inst, inst_channels): + if num_inst > 1: + return inst_channels + else: + return 0 + + class MultiMLP(nn.Module): """Independent MLP for each instance""" @@ -214,3 +313,28 @@ def forward(self, feat, inst_id): out2 = self.multimlp(feat, inst_id) out = out1 + out2 return out + + +# class Triplane(nn.Module): +# """Triplane""" + +# def __init__(self, num_inst, inst_channels=32, **kwargs) -> None: +# super(Triplane, self).__init__() +# init_scale = 0.1 +# resolution = 128 +# num_components = 24 +# self.plane = nn.Parameter( +# init_scale * torch.randn((3 * resolution * resolution, num_components)) +# ) + +# def forward(self, feat, inst_id): +# """ +# Args: +# feat: (M, ..., self.in_channels) +# inst_id: (M,) Instance id, or None to use the average instance +# Returns: +# out: (M, ..., self.out_channels) +# """ +# # rearrange the batch dimension +# shape = feat.shape[:-1] +# return out diff --git a/lab4d/nnutils/bgnerf.py b/lab4d/nnutils/bgnerf.py index 0ba2120..63203db 100644 --- a/lab4d/nnutils/bgnerf.py +++ b/lab4d/nnutils/bgnerf.py @@ -7,20 +7,21 @@ from lab4d.utils.quat_transform import quaternion_translation_to_se3 from lab4d.utils.geom_utils import get_near_far -from lab4d.nnutils.base import MixMLP, MultiMLP, CondMLP +from lab4d.nnutils.base import MixMLP, MultiMLP, CondMLP, DictMLP from lab4d.nnutils.visibility import VisField class BGNeRF(NeRF): """A static neural radiance field with an MLP backbone. Specialized to background.""" - # def __init__( - # self, data_info, field_arch=CondMLP, D=8, W=256, inst_channels=256, **kwargs - # ): - def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): + # def __init__(self, data_info, field_arch=CondMLP, D=5, W=128, **kwargs): + # def __init__(self, data_info, field_arch=MixMLP, D=1, W=64, **kwargs): + def __init__(self, data_info, field_arch=DictMLP, D=8, W=256, **kwargs): super(BGNeRF, self).__init__( data_info, field_arch=field_arch, D=D, W=W, **kwargs ) + # self.vis_mlp = VisField(self.num_inst, D=D, W=W, field_arch=field_arch) + self.vis_mlp = VisField(self.num_inst, D=1, W=64, field_arch=MixMLP) # TODO: update per-scene beta # TODO: update per-scene scale diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index 3eb99c9..1a39291 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -152,7 +152,7 @@ def __init__( self.camera_mlp = CameraMLP(rtmat, frame_info=frame_info) # visibility mlp - self.vis_mlp = VisField(self.num_inst) + self.vis_mlp = VisField(self.num_inst, field_arch=field_arch) # load initial mesh, define aabb self.init_proxy(geom_path, init_scale) @@ -236,7 +236,7 @@ def mlp_init(self): def init_proxy(self, geom_path, init_scale): """Initialize the geometry from a mesh - + Args: geom_path (List(str)): paths to initial shape mesh init_scale (float): Geometry scale factor diff --git a/lab4d/nnutils/visibility.py b/lab4d/nnutils/visibility.py index 437046c..f70e477 100644 --- a/lab4d/nnutils/visibility.py +++ b/lab4d/nnutils/visibility.py @@ -31,6 +31,7 @@ def __init__( inst_channels=32, skips=[4], activation=nn.ReLU(True), + field_arch=CondMLP, ): super().__init__() @@ -38,7 +39,7 @@ def __init__( self.pos_embedding = PosEmbedding(3, num_freq_xyz) # xyz encoding layers - self.basefield = CondMLP( + self.basefield = field_arch( num_inst=num_inst, D=D, W=W, From 0d6572e60b9f61745f54b034b2d9479bdbd3cfe3 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Tue, 1 Aug 2023 17:13:59 -0400 Subject: [PATCH 08/86] update project ppr --- .gitmodules | 3 +++ projects/ppr/ppr-diffphys | 1 + 2 files changed, 4 insertions(+) create mode 160000 projects/ppr/ppr-diffphys diff --git a/.gitmodules b/.gitmodules index 0c28cdf..dee0880 100644 --- a/.gitmodules +++ b/.gitmodules @@ -12,3 +12,6 @@ [submodule "docs/pytorch_sphinx_theme"] path = docs/pytorch_sphinx_theme url = https://github.com/gengshan-y/pytorch_sphinx_theme +[submodule "projects/ppr/ppr-diffphys"] + path = projects/ppr/ppr-diffphys + url = git@github.com:gengshan-y/ppr-diffphys.git diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys new file mode 160000 index 0000000..b77648d --- /dev/null +++ b/projects/ppr/ppr-diffphys @@ -0,0 +1 @@ +Subproject commit b77648dde75d8e9219324e8cffec3d701a809333 From 9700a999b36c42af6f8ac3d48c3f021246a5a1b7 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Tue, 1 Aug 2023 17:14:47 -0400 Subject: [PATCH 09/86] update ppr --- .gitignore | 1 - lab4d/engine/trainer.py | 1 - projects/ppr/config.py | 19 +++++++++ projects/ppr/ppr-diffphys | 2 +- projects/ppr/train.py | 19 +++++++++ projects/ppr/trainer.py | 90 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 projects/ppr/config.py create mode 100644 projects/ppr/train.py create mode 100644 projects/ppr/trainer.py diff --git a/.gitignore b/.gitignore index 984795d..f263e3b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -projects viewer run.sh run-long.sh diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index f1cc640..0e08aea 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -88,7 +88,6 @@ def define_dataset(self): def init_model(self): """Initialize camera transforms, geometry, articulations, and camera intrinsics from external priors, if this is the first run""" - opts = self.opts # init mlp if get_local_rank() == 0: self.model.mlp_init() diff --git a/projects/ppr/config.py b/projects/ppr/config.py new file mode 100644 index 0000000..9574390 --- /dev/null +++ b/projects/ppr/config.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +import os + +from absl import flags + +opts = flags.FLAGS + + +class PPRConfig: + # configs related to ppr + flags.DEFINE_string( + "urdf_template", "wolf_mod", "whether to use predefined skeleton" + ) + flags.DEFINE_float("ratio_phys_cycle", 0.2, "number of iterations per round") + flags.DEFINE_integer("phys_wdw_len", 24, "length of the physics opt window") + flags.DEFINE_integer("phys_batch", 20, "number of parallel physics sim") + flags.DEFINE_string( + "phys_vid", "0", "whether to optimize selected videos, e.g., 0,1,2" + ) diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index b77648d..ec7d7b3 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit b77648dde75d8e9219324e8cffec3d701a809333 +Subproject commit ec7d7b34a962a81b4faee27ee5ae251a28f3d17f diff --git a/projects/ppr/train.py b/projects/ppr/train.py new file mode 100644 index 0000000..6c48182 --- /dev/null +++ b/projects/ppr/train.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +import os +import sys +import pdb +from absl import app + +sys.path.insert(0, "%s/../../" % os.path.join(os.path.dirname(__file__))) +from lab4d.train import train_ddp + +sys.path.insert(0, "%s/../" % os.path.join(os.path.dirname(__file__))) +from ppr.trainer import PPRTrainer + + +def main(_): + train_ddp(PPRTrainer) + + +if __name__ == "__main__": + app.run(main) diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py new file mode 100644 index 0000000..e840b05 --- /dev/null +++ b/projects/ppr/trainer.py @@ -0,0 +1,90 @@ +# Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +import os, sys +import pdb +import torch +import numpy as np + +from lab4d.engine.trainer import Trainer +from ppr import config + +sys.path.insert(0, "%s/ppr-diffphys" % os.path.join(os.path.dirname(__file__))) +from diffphys.warp_env import phys_model + + +class PPRTrainer(Trainer): + def define_model(self): + super().define_model() + + # define physics model + opts = self.opts + opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] + model_dict = {} + model_dict["bg_rts"] = self.model.fields.field_params["bg"].camera_mlp + model_dict["nerf_root_rts"] = self.model.fields.field_params["fg"].camera_mlp + model_dict["nerf_body_rts"] = self.model.fields.field_params[ + "fg" + ].warp.articulation + model_dict["ks_params"] = self.model.intrinsics + self.phys_model = phys_model(opts, model_dict, use_dr=True) + + def init_model(self): + """Initialize camera transforms, geometry, articulations, and camera + intrinsics from external priors, if this is the first run""" + super().init_model() + + def get_lr_dict(self): + """Return the learning rate for each category of trainable parameters + + Returns: + param_lr_startwith (Dict(str, float)): Learning rate for base model + param_lr_with (Dict(str, float)): Learning rate for explicit params + """ + return super().get_lr_dict() + # opts = self.opts + # lr_base = opts["learning_rate"] + # lr_explicit = lr_base * 10 + + # # only update the following parameters + # param_lr_startwith = { + # "module.intrinsics": lr_base, + # "module.fields.field_params.fg.camera_mlp.base_quat": lr_explicit, + # } + # param_lr_with = { + # "inst_embedding": lr_explicit, + # "time_embedding.mapping1": lr_base, + # ".base_logfocal": lr_explicit, + # ".base_ppoint": lr_explicit, + # } + # return param_lr_startwith, param_lr_with + + def run_one_round(self, round_count): + # super().run_one_round(round_count) + + # transfer pharameters + self.run_phys_cycle() + # transfer pharameters + + def run_phys_cycle(self): + opts = self.opts + torch.cuda.empty_cache() + + # eval + self.phys_model.eval() + self.phys_model.reinit_envs(1, wdw_length=30, is_eval=True) + + # train + self.phys_model.train() + self.phys_model.reinit_envs( + opts["phys_batch"], wdw_length=opts["phys_wdw_len"], is_eval=False + ) + + iters_per_phys_cycle = int(opts["ratio_phys_cycle"] * opts["iters_per_round"]) + for i in range(iters_per_phys_cycle): + self.run_phys_iter() + + def run_phys_iter(self): + """Run physics optimization""" + phys_loss, phys_aux = self.phys_model() + self.phys_model.backward(phys_loss) + grad_list = self.phys_model.update() + phys_aux.update(grad_list) From bf1fc31f2fc3266c9669c862f0cf3731fa94a92a Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Tue, 1 Aug 2023 18:36:06 -0400 Subject: [PATCH 10/86] working forward backward --- lab4d/engine/trainer.py | 16 +++++++++------- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 27 ++++++++++++++++++++++----- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index 0e08aea..2938813 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -38,6 +38,15 @@ def __init__(self, opts): self.define_dataset() self.trainer_init() self.define_model() + + # move model to ddp + self.model = DataParallelPassthrough( + self.model, + device_ids=[get_local_rank()], + output_device=get_local_rank(), + find_unused_parameters=False, + ) + self.optimizer_init(is_resumed=is_resumed) # load model @@ -106,13 +115,6 @@ def define_model(self): self.init_model() - self.model = DataParallelPassthrough( - self.model, - device_ids=[get_local_rank()], - output_device=get_local_rank(), - find_unused_parameters=False, - ) - # cache queue of length 2 self.model_cache = [None, None] self.optimizer_cache = [None, None] diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index ec7d7b3..a94e2d6 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit ec7d7b34a962a81b4faee27ee5ae251a28f3d17f +Subproject commit a94e2d6fb5efe6dc39029a6064d7fd4b1f7dcce8 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index e840b05..f6d625d 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -5,6 +5,7 @@ import numpy as np from lab4d.engine.trainer import Trainer +from lab4d.engine.trainer import get_local_rank from ppr import config sys.path.insert(0, "%s/ppr-diffphys" % os.path.join(os.path.dirname(__file__))) @@ -27,11 +28,25 @@ def define_model(self): model_dict["ks_params"] = self.model.intrinsics self.phys_model = phys_model(opts, model_dict, use_dr=True) + # move model to device + self.device = torch.device("cuda:{}".format(get_local_rank())) + self.phys_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.phys_model) + self.phys_model = self.phys_model.to(self.device) + def init_model(self): """Initialize camera transforms, geometry, articulations, and camera intrinsics from external priors, if this is the first run""" super().init_model() + def trainer_init(self): + super().trainer_init() + + opts = self.opts + self.current_steps_phys = 0 # 0-total_steps + self.iters_per_phys_cycle = int( + opts["ratio_phys_cycle"] * opts["iters_per_round"] + ) + def get_lr_dict(self): """Return the learning rate for each category of trainable parameters @@ -58,7 +73,7 @@ def get_lr_dict(self): # return param_lr_startwith, param_lr_with def run_one_round(self, round_count): - # super().run_one_round(round_count) + super().run_one_round(round_count) # transfer pharameters self.run_phys_cycle() @@ -78,13 +93,15 @@ def run_phys_cycle(self): opts["phys_batch"], wdw_length=opts["phys_wdw_len"], is_eval=False ) - iters_per_phys_cycle = int(opts["ratio_phys_cycle"] * opts["iters_per_round"]) - for i in range(iters_per_phys_cycle): + for i in range(self.iters_per_phys_cycle): + self.phys_model.set_progress(self.current_steps_phys) self.run_phys_iter() + self.current_steps_phys += 1 + print(self.current_steps_phys) def run_phys_iter(self): """Run physics optimization""" - phys_loss, phys_aux = self.phys_model() - self.phys_model.backward(phys_loss) + phys_aux = self.phys_model() + self.phys_model.backward(phys_aux["total_loss"]) grad_list = self.phys_model.update() phys_aux.update(grad_list) From 8be5f97242a20efc58ba1ef942dfaf7f80e4229a Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Tue, 1 Aug 2023 22:09:07 -0400 Subject: [PATCH 11/86] add vis --- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index a94e2d6..32eef20 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit a94e2d6fb5efe6dc39029a6064d7fd4b1f7dcce8 +Subproject commit 32eef20f1f58317d868d73470a94f96d73ea3a51 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index f6d625d..9d061a3 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -10,6 +10,7 @@ sys.path.insert(0, "%s/ppr-diffphys" % os.path.join(os.path.dirname(__file__))) from diffphys.warp_env import phys_model +from diffphys.vis import Logger class PPRTrainer(Trainer): @@ -27,6 +28,7 @@ def define_model(self): ].warp.articulation model_dict["ks_params"] = self.model.intrinsics self.phys_model = phys_model(opts, model_dict, use_dr=True) + self.phys_visualizer = Logger(opts) # move model to device self.device = torch.device("cuda:{}".format(get_local_rank())) @@ -46,6 +48,7 @@ def trainer_init(self): self.iters_per_phys_cycle = int( opts["ratio_phys_cycle"] * opts["iters_per_round"] ) + print("# iterations per phys cycle: ", self.iters_per_phys_cycle) def get_lr_dict(self): """Return the learning rate for each category of trainable parameters @@ -86,6 +89,15 @@ def run_phys_cycle(self): # eval self.phys_model.eval() self.phys_model.reinit_envs(1, wdw_length=30, is_eval=True) + for vidid in opts["phys_vid"]: + frame_start = torch.zeros(1) + self.phys_model.data_offset[vidid] + _ = self.phys_model(frame_start=frame_start.to(self.device)) + img_size = tuple(self.data_info["raw_size"][vidid][::-1]) + img_size = img_size + (0.5,) # scale + data = self.phys_model.query(img_size=img_size) + self.phys_visualizer.show( + "%02d-%05d" % (vidid, self.current_steps_phys), data + ) # train self.phys_model.train() From 46f248bc43cab5552bec645f7c4bc357a4f674e9 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Wed, 2 Aug 2023 13:51:16 -0400 Subject: [PATCH 12/86] update dp_interface --- lab4d/nnutils/multifields.py | 4 +--- lab4d/nnutils/nerf.py | 11 +++++++++++ projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 38 +++++------------------------------- 4 files changed, 18 insertions(+), 37 deletions(-) diff --git a/lab4d/nnutils/multifields.py b/lab4d/nnutils/multifields.py index 08c689a..635db40 100644 --- a/lab4d/nnutils/multifields.py +++ b/lab4d/nnutils/multifields.py @@ -410,9 +410,7 @@ def get_cameras(self, inst_id=None): else: frame_to_vid = field.camera_mlp.time_embedding.frame_to_vid frame_id = (frame_to_vid == inst_id).nonzero() - quat, trans = field.camera_mlp.get_vals(frame_id=frame_id) - trans = trans / field.logscale.exp() - field2cam[cate] = quaternion_translation_to_se3(quat, trans) + field2cam[cate] = field.get_camera(frame_id=frame_id) return field2cam def get_aabb(self): diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index aeaba50..2623fb0 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -982,3 +982,14 @@ def cam_prior_loss(self): """ loss = self.camera_mlp.compute_distance_to_prior() return loss + + def get_camera(self, frame_id=None): + """Compute camera matrices in world units + + Returns: + field2cam (Dict): Maps field names ("fg" or "bg") to (M,4,4) cameras + """ + quat, trans = self.camera_mlp.get_vals(frame_id=frame_id) + trans = trans / self.logscale.exp() + field2cam = quaternion_translation_to_se3(quat, trans) + return field2cam diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index 32eef20..bb3b086 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit 32eef20f1f58317d868d73470a94f96d73ea3a51 +Subproject commit bb3b0866b9809b71863be8640df3166890aa3871 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 9d061a3..468b171 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -9,7 +9,7 @@ from ppr import config sys.path.insert(0, "%s/ppr-diffphys" % os.path.join(os.path.dirname(__file__))) -from diffphys.warp_env import phys_model +from diffphys.dp_interface import phys_interface from diffphys.vis import Logger @@ -21,13 +21,10 @@ def define_model(self): opts = self.opts opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] model_dict = {} - model_dict["bg_rts"] = self.model.fields.field_params["bg"].camera_mlp - model_dict["nerf_root_rts"] = self.model.fields.field_params["fg"].camera_mlp - model_dict["nerf_body_rts"] = self.model.fields.field_params[ - "fg" - ].warp.articulation - model_dict["ks_params"] = self.model.intrinsics - self.phys_model = phys_model(opts, model_dict, use_dr=True) + model_dict["bg_field"] = self.model.fields.field_params["bg"] + model_dict["obj_field"] = self.model.fields.field_params["fg"] + model_dict["intrinsics"] = self.model.intrinsics + self.phys_model = phys_interface(opts, model_dict) self.phys_visualizer = Logger(opts) # move model to device @@ -50,31 +47,6 @@ def trainer_init(self): ) print("# iterations per phys cycle: ", self.iters_per_phys_cycle) - def get_lr_dict(self): - """Return the learning rate for each category of trainable parameters - - Returns: - param_lr_startwith (Dict(str, float)): Learning rate for base model - param_lr_with (Dict(str, float)): Learning rate for explicit params - """ - return super().get_lr_dict() - # opts = self.opts - # lr_base = opts["learning_rate"] - # lr_explicit = lr_base * 10 - - # # only update the following parameters - # param_lr_startwith = { - # "module.intrinsics": lr_base, - # "module.fields.field_params.fg.camera_mlp.base_quat": lr_explicit, - # } - # param_lr_with = { - # "inst_embedding": lr_explicit, - # "time_embedding.mapping1": lr_base, - # ".base_logfocal": lr_explicit, - # ".base_ppoint": lr_explicit, - # } - # return param_lr_startwith, param_lr_with - def run_one_round(self, round_count): super().run_one_round(round_count) From abfaa30cbd5bb08dfdba927f4a7bc61ffde12de4 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Thu, 3 Aug 2023 00:19:21 -0400 Subject: [PATCH 13/86] update scene rectification --- lab4d/nnutils/nerf.py | 28 ++++++++++++++++++++++++++++ lab4d/utils/geom_utils.py | 32 ++++++++++++++++++++++++++++++++ projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 3 +++ 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index 2623fb0..c65aedc 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -20,6 +20,7 @@ marching_cubes, pinhole_projection, check_inside_aabb, + compute_rectification_se3, ) from lab4d.utils.loss_utils import align_vectors from lab4d.utils.quat_transform import ( @@ -161,6 +162,9 @@ def __init__( # non-parameters are not synchronized self.register_buffer("near_far", torch.zeros(len(rtmat), 2), persistent=False) + field2world = torch.eye(4)[None].expand(self.num_inst, -1, -1) + self.register_buffer("field2world", field2world, persistent=False) + def forward(self, xyz, dir=None, frame_id=None, inst_id=None, get_density=True): """ Args: @@ -993,3 +997,27 @@ def get_camera(self, frame_id=None): trans = trans / self.logscale.exp() field2cam = quaternion_translation_to_se3(quat, trans) return field2cam + + def compute_field2world(self): + """Compute SE(3) to transform points in the scene space to world space + For background, this is computed by detecting planes with ransac. + + Returns: + rect_se3: (4,4) SE(3) transform + """ + for inst_id in range(self.num_inst): + # TODO: move this to background nerf, and use each proxy geometry + self.field2world[inst_id] = compute_rectification_se3(self.proxy_geometry) + + def get_field2world(self, inst_id=None): + """Compute SE(3) to transform points in the scene space to world space + For background, this is computed by detecting planes with ransac. + + Returns: + rect_se3: (4,4) SE(3) transform + """ + if inst_id is None: + field2world = self.field2world + else: + field2world = self.field2world[inst_id] + return field2world diff --git a/lab4d/utils/geom_utils.py b/lab4d/utils/geom_utils.py index d79f924..8800a9c 100644 --- a/lab4d/utils/geom_utils.py +++ b/lab4d/utils/geom_utils.py @@ -1,9 +1,11 @@ # Copyright (c) 2023 Gengshan Yang, Carnegie Mellon University. +import cv2 import numpy as np import torch import trimesh from scipy.spatial.transform import Rotation as R from skimage import measure +import open3d as o3d from lab4d.utils.quat_transform import ( dual_quaternion_apply, @@ -506,3 +508,33 @@ def check_inside_aabb(xyz, aabb): # check whether the point is inside the aabb inside_aabb = ((xyz > aabb[:1]) & (xyz < aabb[1:])).all(-1) return inside_aabb + + +def compute_rectification_se3(mesh, threshold=0.01, init_n=3, iter=1000): + # run ransac to get plane + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(mesh.vertices) + best_eq, index = pcd.segment_plane(threshold, init_n, iter) + segmented_points = pcd.select_by_index(index) + trimesh.Trimesh(segmented_points.points).export("tmp/0.obj") + + # point upside + if best_eq[1] < 0: + best_eq = -1 * best_eq + + # get se3 + plane_n = np.asarray(best_eq[:3]) + center = np.asarray(segmented_points.points).mean(0) + dist = (center * plane_n).sum() + best_eq[3] + plane_o = center - plane_n * dist + plane = np.concatenate([plane_o, plane_n]) + bg2xy = trimesh.geometry.plane_transform(origin=plane[:3], normal=plane[3:6]) + # to xz + xy2xz = np.eye(4) + xy2xz[:3, :3] = cv2.Rodrigues(np.asarray([-np.pi / 2, 0, 0]))[0] + xy2xz[:3, :3] = cv2.Rodrigues(np.asarray([0, -np.pi / 2, 0]))[0] @ xy2xz[:3, :3] + bg2world = xy2xz @ bg2xy # coplanar with xy->xz plane + + # mesh.apply_transform(bg2world) # DEBUG only + bg2world = torch.Tensor(bg2world) + return bg2world diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index bb3b086..e6b05da 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit bb3b0866b9809b71863be8640df3166890aa3871 +Subproject commit e6b05da8aec799186f75b32e1b9419748e70feff diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 468b171..359dec0 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -58,6 +58,9 @@ def run_phys_cycle(self): opts = self.opts torch.cuda.empty_cache() + # re-initialize field2world transforms + self.model.fields.field_params["bg"].compute_field2world() + # eval self.phys_model.eval() self.phys_model.reinit_envs(1, wdw_length=30, is_eval=True) From d240be0f0456dba9eaa297b66ffefb60f7e54ae5 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Thu, 3 Aug 2023 15:25:51 -0400 Subject: [PATCH 14/86] fix urdf --- lab4d/nnutils/deformable.py | 4 +- lab4d/nnutils/pose.py | 75 ++++++++++++++++++++++++++++++++++++- lab4d/nnutils/warping.py | 28 +++++++++++--- lab4d/utils/skel_utils.py | 31 ++++++++++++--- projects/ppr/config.py | 4 +- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 1 + 7 files changed, 127 insertions(+), 18 deletions(-) diff --git a/lab4d/nnutils/deformable.py b/lab4d/nnutils/deformable.py index 6410c13..2f40f68 100644 --- a/lab4d/nnutils/deformable.py +++ b/lab4d/nnutils/deformable.py @@ -111,7 +111,7 @@ def sdf_fn_torch_skel(pts): sdf = self.warp.get_gauss_sdf(pts) return sdf - if "skel-" in self.fg_motion: + if "skel-" in self.fg_motion or "urdf-" in self.fg_motion: return sdf_fn_torch_skel else: return sdf_fn_torch_sphere @@ -293,7 +293,7 @@ def mlp_init(self): from an external skeleton """ super().mlp_init() - if self.fg_motion.startswith("skel"): + if "skel-" in self.fg_motion or "urdf-" in self.fg_motion: if hasattr(self.warp.articulation, "init_vals"): self.warp.articulation.mlp_init() diff --git a/lab4d/nnutils/pose.py b/lab4d/nnutils/pose.py index c957525..b504164 100644 --- a/lab4d/nnutils/pose.py +++ b/lab4d/nnutils/pose.py @@ -459,7 +459,7 @@ def forward( local_rest_joints = override_local_rest_joints # run forward kinematics - out = fk_se3(local_rest_joints, so3, self.edges) + out = self.fk_se3(local_rest_joints, so3, self.edges) out = shift_joints_to_bones_dq(out, self.edges, shift=self.shift) return out @@ -475,7 +475,7 @@ def compute_rel_rest_joints(self, inst_id=None, override_log_bone_len=None): rel_rest_joints: Translations from parent to child joints """ # get relative joints - rel_rest_joints = rest_joints_to_local(self.rest_joints, self.edges) + rel_rest_joints = self.rest_joints_to_local(self.rest_joints, self.edges) # match the shape rel_rest_joints = rel_rest_joints[None] @@ -493,6 +493,14 @@ def compute_rel_rest_joints(self, inst_id=None, override_log_bone_len=None): rel_rest_joints = rel_rest_joints * bone_length[..., None] return rel_rest_joints + def fk_se3(self, local_rest_joints, so3, edges): + """Forward kinematics for a skeleton""" + return fk_se3(local_rest_joints, so3, edges) + + def rest_joints_to_local(self, rest_joints, edges): + """Convert rest joints to local coordinates""" + return rest_joints_to_local(rest_joints, edges) + def get_vals(self, frame_id=None, return_so3=False, override_so3=None): """Compute articulation parameters at the given frames. @@ -598,3 +606,66 @@ def skel_prior_loss(self): # loss = (bones_gt - bones_pred).norm(2, -1).mean() # loss = loss * 0.2 return loss + + +class ArticulationURDFMLP(ArticulationSkelMLP): + """Encode a skeleton over time using an MLP + + Args: + frame_info (FrameInfo): Metadata about the frames in a dataset + skel_type (str): Skeleton type ("human" or "quad") + joint_angles: (B, 3) If provided, initial joint angles + num_se3 (int): Number of bones + D (int): Number of linear layers + W (int): Number of hidden units in each MLP layer + num_freq_t (int): Number of frequencies in time Fourier embedding + skips (List(int)): List of layers to add skip connections at + activation (Function): Activation function to use (e.g. nn.ReLU()) + """ + + def __init__( + self, + frame_info, + skel_type, + joint_angles, + D=5, + W=256, + num_freq_t=6, + skips=[], + activation=nn.ReLU(True), + ): + super().__init__( + frame_info, + skel_type, + joint_angles, + D=D, + W=W, + num_freq_t=num_freq_t, + skips=skips, + activation=activation, + ) + + self.urdf = self.get_urdf(skel_type) + + # get local rest rotation matrices, pick the first coordinate in rpy of ball joints + local_rest_rmat = np.stack([i.origin[:3, :3] for i in self.urdf.joints], 0) + local_rest_rmat = torch.tensor(local_rest_rmat[::3], dtype=torch.float32) + self.register_buffer("local_rest_rmat", local_rest_rmat, persistent=False) + + def get_urdf(self, urdf_name): + """Load the URDF file for the skeleton""" + from urdfpy import URDF + + urdf_path = f"projects/ppr/ppr-diffphys/data/urdf_templates/{urdf_name}.urdf" + urdf = URDF.load(urdf_path) + return urdf + + def fk_se3(self, local_rest_joints, so3, edges): + return fk_se3( + local_rest_joints, so3, edges, local_rest_rmat=self.local_rest_rmat + ) + + def rest_joints_to_local(self, rest_joints, edges): + return rest_joints_to_local( + rest_joints, edges, local_rest_rmat=self.local_rest_rmat + ) diff --git a/lab4d/nnutils/warping.py b/lab4d/nnutils/warping.py index 0361c63..5842c85 100644 --- a/lab4d/nnutils/warping.py +++ b/lab4d/nnutils/warping.py @@ -6,7 +6,11 @@ from lab4d.nnutils.base import CondMLP from lab4d.nnutils.embedding import PosEmbedding, TimeEmbedding -from lab4d.nnutils.pose import ArticulationFlatMLP, ArticulationSkelMLP +from lab4d.nnutils.pose import ( + ArticulationFlatMLP, + ArticulationSkelMLP, + ArticulationURDFMLP, +) from lab4d.nnutils.skinning import SkinningField from lab4d.third_party.nvp import NVP from lab4d.utils.geom_utils import dual_quaternion_skinning, marching_cubes, extend_aabb @@ -41,7 +45,13 @@ def create_warp(fg_motion, data_info): elif fg_motion.startswith("skel-"): warp = SkinningWarp( frame_info, - skel_type=fg_motion.split("-")[1], + skel_type=fg_motion, + joint_angles=joint_angles, + ) + elif fg_motion.startswith("urdf-"): + warp = SkinningWarp( + frame_info, + skel_type=fg_motion, joint_angles=joint_angles, ) elif fg_motion.startswith("comp"): @@ -257,10 +267,18 @@ def __init__( if skel_type == "flat": self.articulation = ArticulationFlatMLP(frame_info, num_se3) symm_idx = None - else: + elif skel_type.startswith("skel-"): + skel_type = skel_type.split("-")[1] self.articulation = ArticulationSkelMLP(frame_info, skel_type, joint_angles) num_se3 = self.articulation.num_se3 symm_idx = self.articulation.symm_idx + elif skel_type.startswith("urdf-"): + skel_type = skel_type.split("-")[1] + self.articulation = ArticulationURDFMLP(frame_info, skel_type, joint_angles) + num_se3 = self.articulation.num_se3 + symm_idx = self.articulation.symm_idx + else: + raise NotImplementedError self.skinning_model = SkinningField( num_se3, @@ -427,14 +445,14 @@ def __init__( # e.g., comp_skel-human_dense, limited to skel+another type of field type_list = warp_type.split("_")[1:] assert len(type_list) == 2 - assert type_list[0] in ["skel-human", "skel-quad"] + assert type_list[0] in ["skel-human", "skel-quad", "urdf-human", "urdf-quad"] assert type_list[1] in ["bob", "dense"] if type_list[1] == "bob": raise NotImplementedError super().__init__( frame_info, - skel_type=type_list[0].split("-")[1], + skel_type=type_list[0], joint_angles=joint_angles, ) # self.post_warp = DenseWarp(frame_info, D=2, W=64) diff --git a/lab4d/utils/skel_utils.py b/lab4d/utils/skel_utils.py index 4936073..6110788 100644 --- a/lab4d/utils/skel_utils.py +++ b/lab4d/utils/skel_utils.py @@ -32,31 +32,45 @@ def get_valid_edges(edges): return idx, parent_idx -def rest_joints_to_local(rest_joints, edges): +def rest_joints_to_local(rest_joints, edges, local_rest_rmat=None): """Convert rest joints to local coordinates, where local = current - parent + If local_rest_rmat is given, local = parent^-1 * current Args: rest_joints: (B, 3) Joint locations edges (Dict(int, int)): Maps each joint to its parent joint + local_rest_rmat: (B, 3, 3) Local rotations Returns: local_rest_joints: (B, 3) Translations from parent to child joints """ idx, parent_idx = get_valid_edges(edges) local_rest_joints = rest_joints.clone() local_rest_joints[idx] = rest_joints[idx] - rest_joints[parent_idx] + if local_rest_rmat is not None: + # T_rel = R_parent^T(T-T_parent) + to_global_rmat = local_rest_rmat.clone() + for idx, parent_idx in edges.items(): + if parent_idx > 0: + local_rest_joints[idx - 1] = ( + to_global_rmat[parent_idx - 1].T @ local_rest_joints[idx - 1] + ) + to_global_rmat[idx - 1] = ( + to_global_rmat[parent_idx - 1] @ local_rest_rmat[idx - 1] + ) return local_rest_joints -def fk_se3(local_rest_joints, so3, edges, to_dq=True): - """Compute forward kinematics given joint angles on a skeleton +def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_rmat=None): + """Compute forward kinematics given joint angles on a skeleton. + If local_rest_rmat is None, assuming identity rotation in zero configuration. Args: local_rest_joints: (B, 3) Translations from parent to current joints, - assuming identity rotation in zero configuration so3: (..., B, 3) Axis-angles at each joint edges (Dict(int, int)): Maps each joint to its parent joint to_dq (bool): If True, output link rigid transforms as dual quaternions, otherwise output SE(3) + local_rest_rot: (B, 3, 3) Local rotations Returns: out: Location of each joint. This is written as dual quaternions ((..., B, 4), (..., B, 4)) if to_dq=True, otherwise it is written @@ -73,10 +87,17 @@ def fk_se3(local_rest_joints, so3, edges, to_dq=True): local_to_parent = identity_rt.clone() global_rt = identity_rt.clone() + if local_rest_rmat is None: + local_rest_rmat = so3_to_exp_map(so3) + else: + local_rest_rmat = local_rest_rmat.view((1,) * (len(shape) - 2) + (-1, 3, 3)) + local_rest_rmat = so3_to_exp_map(so3) @ local_rest_rmat + # get local rt transformation: (..., k, 4, 4) + # parent ... child # first rotate around joint i # then translate wrt the relative position of the parent to i - local_to_parent[..., :3, :3] = so3_to_exp_map(so3) + local_to_parent[..., :3, :3] = local_rest_rmat local_to_parent[..., :3, 3] = local_rest_joints for idx, parent_idx in edges.items(): diff --git a/projects/ppr/config.py b/projects/ppr/config.py index 9574390..e7396ab 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -8,9 +8,7 @@ class PPRConfig: # configs related to ppr - flags.DEFINE_string( - "urdf_template", "wolf_mod", "whether to use predefined skeleton" - ) + flags.DEFINE_string("urdf_template", "", "whether to use predefined skeleton") flags.DEFINE_float("ratio_phys_cycle", 0.2, "number of iterations per round") flags.DEFINE_integer("phys_wdw_len", 24, "length of the physics opt window") flags.DEFINE_integer("phys_batch", 20, "number of parallel physics sim") diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index e6b05da..d5d2711 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit e6b05da8aec799186f75b32e1b9419748e70feff +Subproject commit d5d27115ef41410f8c0d8bec626a5e0b77867d1d diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 359dec0..39d665d 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -20,6 +20,7 @@ def define_model(self): # define physics model opts = self.opts opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] + opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] model_dict = {} model_dict["bg_field"] = self.model.fields.field_params["bg"] model_dict["obj_field"] = self.model.fields.field_params["fg"] From 3faa742a5e24e074e8894181981ec0691f8672aa Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Fri, 4 Aug 2023 01:12:20 -0400 Subject: [PATCH 15/86] fix urdf --- lab4d/engine/trainer.py | 1 + lab4d/nnutils/nerf.py | 2 +- lab4d/nnutils/pose.py | 46 +++++++++++++++++++++++++++--------- lab4d/utils/geom_utils.py | 17 +++++++++++++- lab4d/utils/skel_utils.py | 49 ++++++++++++++++++--------------------- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 15 +++++++----- 7 files changed, 85 insertions(+), 47 deletions(-) diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index 2938813..06f1e3a 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -145,6 +145,7 @@ def get_lr_dict(self): ".base_logfocal": lr_explicit, ".base_ppoint": lr_explicit, ".shift": lr_explicit, + ".orient": lr_explicit, } return param_lr_startwith, param_lr_with diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index c65aedc..b107559 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -163,7 +163,7 @@ def __init__( self.register_buffer("near_far", torch.zeros(len(rtmat), 2), persistent=False) field2world = torch.eye(4)[None].expand(self.num_inst, -1, -1) - self.register_buffer("field2world", field2world, persistent=False) + self.register_buffer("field2world", field2world, persistent=True) def forward(self, xyz, dir=None, frame_id=None, inst_id=None, get_density=True): """ diff --git a/lab4d/nnutils/pose.py b/lab4d/nnutils/pose.py index b504164..3f8c6ef 100644 --- a/lab4d/nnutils/pose.py +++ b/lab4d/nnutils/pose.py @@ -377,6 +377,8 @@ def __init__( self.logscale = nn.Parameter(torch.zeros(1)) self.shift = nn.Parameter(torch.zeros(3)) + self.orient = nn.Parameter(torch.tensor([1.0, 0.0, 0.0, 0.0])) + # instance bone length num_inst = len(frame_info["frame_offset"]) - 1 self.log_bone_len = CondMLP( @@ -460,7 +462,9 @@ def forward( # run forward kinematics out = self.fk_se3(local_rest_joints, so3, self.edges) - out = shift_joints_to_bones_dq(out, self.edges, shift=self.shift) + out = shift_joints_to_bones_dq( + out, self.edges, shift=self.shift, orient=self.orient + ) return out def compute_rel_rest_joints(self, inst_id=None, override_log_bone_len=None): @@ -645,27 +649,47 @@ def __init__( activation=activation, ) - self.urdf = self.get_urdf(skel_type) + local_rest_coord, scale_factor, orient, offset = self.parse_urdf(skel_type) + self.logscale.data = torch.log(scale_factor) + self.shift.data = offset + self.orient.data = orient # get local rest rotation matrices, pick the first coordinate in rpy of ball joints - local_rest_rmat = np.stack([i.origin[:3, :3] for i in self.urdf.joints], 0) - local_rest_rmat = torch.tensor(local_rest_rmat[::3], dtype=torch.float32) - self.register_buffer("local_rest_rmat", local_rest_rmat, persistent=False) + # by default: transform points from child to parent + local_rest_coord = torch.tensor(local_rest_coord, dtype=torch.float32) + self.register_buffer("local_rest_coord", local_rest_coord, persistent=False) + self.rest_joints = None - def get_urdf(self, urdf_name): + def parse_urdf(self, urdf_name): """Load the URDF file for the skeleton""" from urdfpy import URDF urdf_path = f"projects/ppr/ppr-diffphys/data/urdf_templates/{urdf_name}.urdf" urdf = URDF.load(urdf_path) - return urdf + + local_rest_coord = np.stack([i.origin for i in urdf.joints], 0)[::3] + + if urdf_name == "human": + offset = torch.tensor([0.0, 0.0, 0.0]) + orient = torch.tensor([1.0, 0.0, 0.0, 0.0]) # wxyz + scale_factor = torch.tensor([0.08]) + elif urdf_name == "quad": + offset = torch.tensor([0.0, -0.02, 0.02]) + orient = torch.tensor([1.0, -0.8, 0.0, 0.0]) + scale_factor = torch.tensor([0.05]) + else: + raise NotImplementedError + + orient = F.normalize(orient, dim=-1) + return local_rest_coord, scale_factor, orient, offset def fk_se3(self, local_rest_joints, so3, edges): return fk_se3( - local_rest_joints, so3, edges, local_rest_rmat=self.local_rest_rmat + local_rest_joints, + so3, + edges, + local_rest_coord=self.local_rest_coord.clone(), ) def rest_joints_to_local(self, rest_joints, edges): - return rest_joints_to_local( - rest_joints, edges, local_rest_rmat=self.local_rest_rmat - ) + return self.local_rest_coord[:, :3, 3].clone() diff --git a/lab4d/utils/geom_utils.py b/lab4d/utils/geom_utils.py index 8800a9c..77d1f62 100644 --- a/lab4d/utils/geom_utils.py +++ b/lab4d/utils/geom_utils.py @@ -516,7 +516,6 @@ def compute_rectification_se3(mesh, threshold=0.01, init_n=3, iter=1000): pcd.points = o3d.utility.Vector3dVector(mesh.vertices) best_eq, index = pcd.segment_plane(threshold, init_n, iter) segmented_points = pcd.select_by_index(index) - trimesh.Trimesh(segmented_points.points).export("tmp/0.obj") # point upside if best_eq[1] < 0: @@ -538,3 +537,19 @@ def compute_rectification_se3(mesh, threshold=0.01, init_n=3, iter=1000): # mesh.apply_transform(bg2world) # DEBUG only bg2world = torch.Tensor(bg2world) return bg2world + + +def se3_inv(rtmat): + """Invert an SE(3) matrix + + Args: + rtmat: (..., 4, 4) SE(3) matrix + Returns: + rtmat_inv: (..., 4, 4) Inverse SE(3) matrix + """ + rmat, tmat = se3_mat2rt(rtmat) + rmat = rmat.transpose(-1, -2) + tmat = -rmat @ tmat[..., None] + rtmat[..., :3, :3] = rmat + rtmat[..., :3, 3] = tmat[..., 0] + return rtmat diff --git a/lab4d/utils/skel_utils.py b/lab4d/utils/skel_utils.py index 6110788..3aaa227 100644 --- a/lab4d/utils/skel_utils.py +++ b/lab4d/utils/skel_utils.py @@ -9,7 +9,7 @@ from lab4d.utils.quat_transform import ( axis_angle_to_quaternion, matrix_to_quaternion, - quaternion_translation_mul, + dual_quaternion_mul, quaternion_translation_to_dual_quaternion, dual_quaternion_to_quaternion_translation, ) @@ -32,35 +32,22 @@ def get_valid_edges(edges): return idx, parent_idx -def rest_joints_to_local(rest_joints, edges, local_rest_rmat=None): +def rest_joints_to_local(rest_joints, edges): """Convert rest joints to local coordinates, where local = current - parent - If local_rest_rmat is given, local = parent^-1 * current Args: rest_joints: (B, 3) Joint locations edges (Dict(int, int)): Maps each joint to its parent joint - local_rest_rmat: (B, 3, 3) Local rotations Returns: local_rest_joints: (B, 3) Translations from parent to child joints """ idx, parent_idx = get_valid_edges(edges) local_rest_joints = rest_joints.clone() local_rest_joints[idx] = rest_joints[idx] - rest_joints[parent_idx] - if local_rest_rmat is not None: - # T_rel = R_parent^T(T-T_parent) - to_global_rmat = local_rest_rmat.clone() - for idx, parent_idx in edges.items(): - if parent_idx > 0: - local_rest_joints[idx - 1] = ( - to_global_rmat[parent_idx - 1].T @ local_rest_joints[idx - 1] - ) - to_global_rmat[idx - 1] = ( - to_global_rmat[parent_idx - 1] @ local_rest_rmat[idx - 1] - ) return local_rest_joints -def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_rmat=None): +def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_coord=None): """Compute forward kinematics given joint angles on a skeleton. If local_rest_rmat is None, assuming identity rotation in zero configuration. @@ -75,6 +62,7 @@ def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_rmat=None): out: Location of each joint. This is written as dual quaternions ((..., B, 4), (..., B, 4)) if to_dq=True, otherwise it is written as (..., B, 4, 4) SE(3) matrices. + link to global transforms X_global = T_1...T_k x X_k """ assert local_rest_joints.shape == so3.shape shape = so3.shape @@ -84,22 +72,22 @@ def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_rmat=None): identity_rt = identity_rt.view((1,) * (len(shape) - 2) + (-1, 4, 4)) identity_rt = identity_rt.expand(*shape[:-1], -1, -1).clone() identity_rt_slice = identity_rt[..., 0, :, :].clone() - local_to_parent = identity_rt.clone() global_rt = identity_rt.clone() - if local_rest_rmat is None: - local_rest_rmat = so3_to_exp_map(so3) + if local_rest_coord is None: + local_rmat = so3_to_exp_map(so3) else: - local_rest_rmat = local_rest_rmat.view((1,) * (len(shape) - 2) + (-1, 3, 3)) - local_rest_rmat = so3_to_exp_map(so3) @ local_rest_rmat + local_rmat = local_rest_coord[:, :3, :3] + local_rmat = local_rmat.view((1,) * (len(shape) - 2) + (-1, 3, 3)) + local_rmat = local_rmat @ so3_to_exp_map(so3) + + local_to_parent = torch.cat([local_rmat, local_rest_joints[..., None]], -1) + local_to_parent = torch.cat([local_to_parent, identity_rt[..., -1:, :]], -2) # get local rt transformation: (..., k, 4, 4) # parent ... child # first rotate around joint i # then translate wrt the relative position of the parent to i - local_to_parent[..., :3, :3] = local_rest_rmat - local_to_parent[..., :3, 3] = local_rest_joints - for idx, parent_idx in edges.items(): if parent_idx > 0: parent_to_global = global_rt[..., parent_idx - 1, :, :].clone() @@ -119,7 +107,7 @@ def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_rmat=None): return global_rt -def shift_joints_to_bones_dq(dq, edges, shift=None): +def shift_joints_to_bones_dq(dq, edges, shift=None, orient=None): """Compute bone centers and orientations from joint locations Args: @@ -131,10 +119,17 @@ def shift_joints_to_bones_dq(dq, edges, shift=None): written as dual quaternions """ quat, joints = dual_quaternion_to_quaternion_translation(dq) - if shift is not None: - joints += shift.reshape((1,) * (joints[0].ndim - 1) + (3,)) joints = shift_joints_to_bones(joints, edges) dq = quaternion_translation_to_dual_quaternion(quat, joints) + if shift is not None and orient is not None: + shift = shift.reshape((1,) * (joints[0].ndim - 1) + (3,)) + shift = shift.expand(*joints.shape[:-1], -1) + orient = orient.reshape((1,) * (joints[0].ndim - 1) + (4,)) + orient = orient.expand(*joints.shape[:-1], -1) + offset_dq = quaternion_translation_to_dual_quaternion(orient, shift) + dq = dual_quaternion_mul(offset_dq, dq) + else: + raise ValueError("shift and orient cannot be both None") return dq diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index d5d2711..f35895f 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit d5d27115ef41410f8c0d8bec626a5e0b77867d1d +Subproject commit f35895f8eb5f5c3c3d567a7eabcedcf981297f72 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 39d665d..c8a05b4 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -22,7 +22,7 @@ def define_model(self): opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] model_dict = {} - model_dict["bg_field"] = self.model.fields.field_params["bg"] + model_dict["scene_field"] = self.model.fields.field_params["bg"] model_dict["obj_field"] = self.model.fields.field_params["fg"] model_dict["intrinsics"] = self.model.intrinsics self.phys_model = phys_interface(opts, model_dict) @@ -46,22 +46,26 @@ def trainer_init(self): self.iters_per_phys_cycle = int( opts["ratio_phys_cycle"] * opts["iters_per_round"] ) - print("# iterations per phys cycle: ", self.iters_per_phys_cycle) + print("# iterations per phys cycle:", self.iters_per_phys_cycle) def run_one_round(self, round_count): + # run dr cycle super().run_one_round(round_count) + # re-initialize field2world transforms + self.model.fields.field_params["bg"].compute_field2world() + # transfer pharameters + self.phys_model.override_states() + # run physics cycle self.run_phys_cycle() # transfer pharameters + self.phys_model.override_states_inv() def run_phys_cycle(self): opts = self.opts torch.cuda.empty_cache() - # re-initialize field2world transforms - self.model.fields.field_params["bg"].compute_field2world() - # eval self.phys_model.eval() self.phys_model.reinit_envs(1, wdw_length=30, is_eval=True) @@ -85,7 +89,6 @@ def run_phys_cycle(self): self.phys_model.set_progress(self.current_steps_phys) self.run_phys_iter() self.current_steps_phys += 1 - print(self.current_steps_phys) def run_phys_iter(self): """Run physics optimization""" From d19d7ca003e64be522f119242fcf715cff05dee0 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Fri, 4 Aug 2023 12:12:13 -0400 Subject: [PATCH 16/86] init plane once --- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index f35895f..8de925e 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit f35895f8eb5f5c3c3d567a7eabcedcf981297f72 +Subproject commit 8de925e01f9def9b8e153fa81a39634bae451d45 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index c8a05b4..29b706f 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -17,14 +17,21 @@ class PPRTrainer(Trainer): def define_model(self): super().define_model() - # define physics model + # opts opts = self.opts opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] + + # re-initialize field2world transforms + self.model.fields.field_params["bg"].compute_field2world() + + # model model_dict = {} model_dict["scene_field"] = self.model.fields.field_params["bg"] model_dict["obj_field"] = self.model.fields.field_params["fg"] model_dict["intrinsics"] = self.model.intrinsics + + # define phys model self.phys_model = phys_interface(opts, model_dict) self.phys_visualizer = Logger(opts) @@ -52,9 +59,6 @@ def run_one_round(self, round_count): # run dr cycle super().run_one_round(round_count) - # re-initialize field2world transforms - self.model.fields.field_params["bg"].compute_field2world() - # transfer pharameters self.phys_model.override_states() # run physics cycle From 9f242fead1995b3f49305e83703f3f23e72a4117 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Fri, 4 Aug 2023 18:02:25 -0400 Subject: [PATCH 17/86] update --- projects/ppr/ppr-diffphys | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index 8de925e..866a44d 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit 8de925e01f9def9b8e153fa81a39634bae451d45 +Subproject commit 866a44d8a1e45ad29fb18e121a9565596e572393 From 3f017dbfd8b11c20db8ea7b6ec6267e88ef4704a Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Sun, 6 Aug 2023 01:29:17 -0400 Subject: [PATCH 18/86] readibility --- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 35 ++++++++++++++++++++--------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index 866a44d..22644f7 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit 866a44d8a1e45ad29fb18e121a9565596e572393 +Subproject commit 22644f763361f4d8d9617e64d1d7157d76fff160 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 29b706f..0f67788 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -28,7 +28,7 @@ def define_model(self): # model model_dict = {} model_dict["scene_field"] = self.model.fields.field_params["bg"] - model_dict["obj_field"] = self.model.fields.field_params["fg"] + model_dict["object_field"] = self.model.fields.field_params["fg"] model_dict["intrinsics"] = self.model.intrinsics # define phys model @@ -58,7 +58,6 @@ def trainer_init(self): def run_one_round(self, round_count): # run dr cycle super().run_one_round(round_count) - # transfer pharameters self.phys_model.override_states() # run physics cycle @@ -72,31 +71,37 @@ def run_phys_cycle(self): # eval self.phys_model.eval() - self.phys_model.reinit_envs(1, wdw_length=30, is_eval=True) - for vidid in opts["phys_vid"]: - frame_start = torch.zeros(1) + self.phys_model.data_offset[vidid] - _ = self.phys_model(frame_start=frame_start.to(self.device)) - img_size = tuple(self.data_info["raw_size"][vidid][::-1]) - img_size = img_size + (0.5,) # scale - data = self.phys_model.query(img_size=img_size) - self.phys_visualizer.show( - "%02d-%05d" % (vidid, self.current_steps_phys), data - ) + self.run_phys_visualization(tag="kinematics") # train self.phys_model.train() self.phys_model.reinit_envs( opts["phys_batch"], wdw_length=opts["phys_wdw_len"], is_eval=False ) - for i in range(self.iters_per_phys_cycle): self.phys_model.set_progress(self.current_steps_phys) self.run_phys_iter() self.current_steps_phys += 1 + # eval again + self.phys_model.eval() + self.run_phys_visualization(tag="phys") + def run_phys_iter(self): """Run physics optimization""" phys_aux = self.phys_model() self.phys_model.backward(phys_aux["total_loss"]) - grad_list = self.phys_model.update() - phys_aux.update(grad_list) + self.phys_model.update() + + def run_phys_visualization(self, tag=""): + opts = self.opts + self.phys_model.reinit_envs(1, wdw_length=30, is_eval=True) + for vidid in opts["phys_vid"]: + frame_start = torch.zeros(1) + self.phys_model.data_offset[vidid] + _ = self.phys_model(frame_start=frame_start.to(self.device)) + img_size = tuple(self.data_info["raw_size"][vidid][::-1]) + img_size = img_size + (0.5,) # scale + data = self.phys_model.query(img_size=img_size) + self.phys_visualizer.show( + "%s-%02d-%05d" % (tag, vidid, self.current_steps_phys), data + ) From e5ad970113a89f9b425334818085a90fc22d2dfa Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Fri, 11 Aug 2023 14:32:08 -0400 Subject: [PATCH 19/86] allow decimal frame id --- lab4d/nnutils/embedding.py | 11 ++++++----- lab4d/nnutils/pose.py | 6 +++--- projects/ppr/config.py | 11 +++++++++++ projects/ppr/ppr-diffphys | 2 +- 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/lab4d/nnutils/embedding.py b/lab4d/nnutils/embedding.py index 57f3bb7..92c33b4 100644 --- a/lab4d/nnutils/embedding.py +++ b/lab4d/nnutils/embedding.py @@ -201,7 +201,7 @@ def forward(self, frame_id=None): if frame_id is None: inst_id, t_sample = self.frame_to_vid, self.frame_to_tid(self.frame_mapping) else: - inst_id = self.raw_fid_to_vid[frame_id] + inst_id = self.raw_fid_to_vid[frame_id.long()] t_sample = self.frame_to_tid(frame_id) if inst_id.ndim == 1: @@ -257,10 +257,11 @@ def forward(self, inst_id): return torch.zeros(inst_id.shape + (0,), device=inst_id.device) else: if self.num_inst == 1: - return self.mapping(torch.zeros_like(inst_id)) - if self.training and self.beta_prob > 0: - inst_id = self.randomize_instance(inst_id) - inst_code = self.mapping(inst_id) + inst_code = self.mapping(torch.zeros_like(inst_id)) + else: + if self.training and self.beta_prob > 0: + inst_id = self.randomize_instance(inst_id) + inst_code = self.mapping(inst_id) return inst_code def randomize_instance(self, inst_id): diff --git a/lab4d/nnutils/pose.py b/lab4d/nnutils/pose.py index 11b79a5..e64bcfd 100644 --- a/lab4d/nnutils/pose.py +++ b/lab4d/nnutils/pose.py @@ -7,7 +7,7 @@ from lab4d.nnutils.base import CondMLP, BaseMLP, ScaleLayer from lab4d.nnutils.time import TimeMLP -from lab4d.utils.geom_utils import so3_to_exp_map +from lab4d.utils.geom_utils import so3_to_exp_map, rot_angle from lab4d.utils.quat_transform import ( axis_angle_to_quaternion, matrix_to_quaternion, @@ -141,7 +141,7 @@ def get_vals(self, frame_id=None): if frame_id is None: inst_id = self.time_embedding.frame_to_vid else: - inst_id = self.time_embedding.raw_fid_to_vid[frame_id] + inst_id = self.time_embedding.raw_fid_to_vid[frame_id.long()] # multiply with per-instance base rotation base_quat = self.base_quat[inst_id] @@ -520,7 +520,7 @@ def get_vals(self, frame_id=None, return_so3=False, override_so3=None): if frame_id is None: inst_id = self.time_embedding.frame_to_vid else: - inst_id = self.time_embedding.raw_fid_to_vid[frame_id] + inst_id = self.time_embedding.raw_fid_to_vid[frame_id.long()] t_embed = self.time_embedding(frame_id) pred = self.forward( t_embed, inst_id, return_so3=return_so3, override_so3=override_so3 diff --git a/projects/ppr/config.py b/projects/ppr/config.py index e7396ab..6a28dbe 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -15,3 +15,14 @@ class PPRConfig: flags.DEFINE_string( "phys_vid", "0", "whether to optimize selected videos, e.g., 0,1,2" ) + + # weights + flags.DEFINE_float("traj_wt", 0.1, "weight for traj matching loss") + flags.DEFINE_float("pos_state_wt", 0.1, "weight for position matching reg") + flags.DEFINE_float("vel_state_wt", 0.0, "weight for velocity matching reg") + + # regs + flags.DEFINE_float("reg_torque_wt", 0.0, "weight for torque regularization") + flags.DEFINE_float("reg_res_f_wt", 0.0, "weight for residual force regularization") + flags.DEFINE_float("reg_foot_wt", 0.0, "weight for foot contact regularization") + flags.DEFINE_float("reg_root_wt", 0.0, "weight for root pose regularization") diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index 22644f7..7d6e2b6 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit 22644f763361f4d8d9617e64d1d7157d76fff160 +Subproject commit 7d6e2b6bdbe8ee176e4890d5dc64c404e4f176c9 From 6cae7680dd54dc70bea64c90daccd88f9b1b7781 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Tue, 15 Aug 2023 16:14:42 -0400 Subject: [PATCH 20/86] frz gauss rel scale; fix dq norm bug; urdf bones --- lab4d/config.py | 4 +-- lab4d/engine/model.py | 17 +++++------ lab4d/engine/trainer.py | 5 +--- lab4d/nnutils/embedding.py | 7 ++++- lab4d/nnutils/multifields.py | 4 ++- lab4d/nnutils/nerf.py | 1 + lab4d/nnutils/pose.py | 56 ++++++++++++++++++++++++++++++------ lab4d/nnutils/skinning.py | 20 +++++++++---- lab4d/nnutils/warping.py | 27 ++++++++--------- lab4d/utils/skel_utils.py | 39 ++++++++++++++++++------- lab4d/utils/transforms.py | 11 +++++-- projects/ppr/config.py | 2 +- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 42 +++++++++++++++++++++++---- 14 files changed, 170 insertions(+), 67 deletions(-) diff --git a/lab4d/config.py b/lab4d/config.py index c6fcaa8..181e244 100644 --- a/lab4d/config.py +++ b/lab4d/config.py @@ -65,9 +65,7 @@ class TrainOptConfig: flags.DEFINE_integer("pixels_per_image", 16, "pixel samples per image") # flags.DEFINE_integer("imgs_per_gpu", 1, "size of minibatches per iter") # flags.DEFINE_integer("pixels_per_image", 4096, "number of pixel samples per image") - flags.DEFINE_boolean( - "freeze_bone_len", False, "do not change bone length of skeleton" - ) + flags.DEFINE_boolean("use_freq_anneal", True, "whether to use frequency annealing") flags.DEFINE_boolean( "reset_steps", True, diff --git a/lab4d/engine/model.py b/lab4d/engine/model.py index b7a53fb..8a07108 100644 --- a/lab4d/engine/model.py +++ b/lab4d/engine/model.py @@ -98,14 +98,15 @@ def set_progress(self, current_steps): Args: current_steps (int): Number of optimization steps so far """ - # positional encoding annealing - anchor_x = (0, 4000) - anchor_y = (0.6, 1) - type = "linear" - alpha = interp_wt(anchor_x, anchor_y, current_steps, type=type) - if alpha >= 1: - alpha = None - self.fields.set_alpha(alpha) + if self.config["use_freq_anneal"]: + # positional encoding annealing + anchor_x = (0, 4000) + anchor_y = (0.6, 1) + type = "linear" + alpha = interp_wt(anchor_x, anchor_y, current_steps, type=type) + if alpha >= 1: + alpha = None + self.fields.set_alpha(alpha) # beta_prob: steps(0->2k, 1->0.2), range (0.2,1) anchor_x = (0, 2000) diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index 06f1e3a..1b37805 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -140,7 +140,7 @@ def get_lr_dict(self): ".logibeta": lr_explicit, ".logsigma": lr_explicit, ".logscale": lr_explicit, - ".log_gauss": lr_explicit, + ".log_gauss": 0.0, ".base_quat": lr_explicit, ".base_logfocal": lr_explicit, ".base_ppoint": lr_explicit, @@ -160,9 +160,6 @@ def optimizer_init(self, is_resumed=False): param_lr_startwith, param_lr_with = self.get_lr_dict() - if opts["freeze_bone_len"]: - param_lr_with[".log_bone_len"] = 0 - params_list = [] lr_list = [] for name, p in self.model.named_parameters(): diff --git a/lab4d/nnutils/embedding.py b/lab4d/nnutils/embedding.py index 92c33b4..0263d67 100644 --- a/lab4d/nnutils/embedding.py +++ b/lab4d/nnutils/embedding.py @@ -150,6 +150,7 @@ def __init__(self, num_freq_t, frame_info, out_channels=128, time_scale=1.0): self.out_channels = out_channels self.frame_offset = frame_info["frame_offset"] + self.frame_offset_raw = frame_info["frame_offset_raw"] self.num_frames = self.frame_offset[-1] self.num_vids = len(self.frame_offset) - 1 @@ -201,7 +202,11 @@ def forward(self, frame_id=None): if frame_id is None: inst_id, t_sample = self.frame_to_vid, self.frame_to_tid(self.frame_mapping) else: - inst_id = self.raw_fid_to_vid[frame_id.long()] + if torch.is_tensor(frame_id): + frame_id = frame_id.long() + else: + frame_id = torch.tensor(frame_id, dtype=torch.long) + inst_id = self.raw_fid_to_vid[frame_id] t_sample = self.frame_to_tid(frame_id) if inst_id.ndim == 1: diff --git a/lab4d/nnutils/multifields.py b/lab4d/nnutils/multifields.py index 1271a0e..bfd76e8 100644 --- a/lab4d/nnutils/multifields.py +++ b/lab4d/nnutils/multifields.py @@ -12,6 +12,7 @@ from lab4d.nnutils.warping import ComposedWarp, SkinningWarp from lab4d.utils.quat_transform import quaternion_translation_to_se3 from lab4d.utils.vis_utils import draw_cams, mesh_cat +from lab4d.utils.geom_utils import extend_aabb class MultiFields(nn.Module): @@ -194,7 +195,8 @@ def export_geometry_aux(self, path): mesh_cam = draw_cams(rtmat) mesh = mesh_cat(mesh_geo, mesh_cam) if category == "fg": - mesh_gauss, mesh_sdf = field.warp.get_template_vis(aabb=field.aabb) + aabb = extend_aabb(field.aabb, factor=0.5) + mesh_gauss, mesh_sdf = field.warp.get_template_vis(aabb=aabb) mesh_gauss.export("%s-%s-gauss.obj" % (path, category)) mesh_sdf.export("%s-%s-sdf.obj" % (path, category)) mesh.export("%s-%s-proxy.obj" % (path, category)) diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index 6a3735c..0e9bc78 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -1026,4 +1026,5 @@ def get_field2world(self, inst_id=None): field2world = self.field2world else: field2world = self.field2world[inst_id] + field2world[..., :3, 3] /= self.logscale.exp() return field2world diff --git a/lab4d/nnutils/pose.py b/lab4d/nnutils/pose.py index e64bcfd..bb3cc72 100644 --- a/lab4d/nnutils/pose.py +++ b/lab4d/nnutils/pose.py @@ -13,7 +13,7 @@ matrix_to_quaternion, quaternion_mul, quaternion_translation_to_dual_quaternion, - dual_quaternion_to_quaternion_translation, + dual_quaternion_mul, quaternion_translation_to_se3, ) from lab4d.utils.skel_utils import ( @@ -21,7 +21,7 @@ get_predefined_skeleton, rest_joints_to_local, shift_joints_to_bones_dq, - shift_joints_to_bones, + apply_root_offset, ) from lab4d.utils.vis_utils import draw_cams @@ -462,11 +462,13 @@ def forward( # run forward kinematics out = self.fk_se3(local_rest_joints, so3, self.edges) - out = shift_joints_to_bones_dq( - out, self.edges, shift=self.shift, orient=self.orient - ) + out = self.shift_joints_to_bones(out) + out = apply_root_offset(out, self.shift, self.orient) return out + def shift_joints_to_bones(self, se3): + return shift_joints_to_bones_dq(se3, self.edges) + def compute_rel_rest_joints(self, inst_id=None, override_log_bone_len=None): """Compute relative position difference from parent to child bone coordinate frames, without scale @@ -649,10 +651,19 @@ def __init__( activation=activation, ) - local_rest_coord, scale_factor, orient, offset = self.parse_urdf(skel_type) + ( + local_rest_coord, + scale_factor, + orient, + offset, + bone_centers, + bone_sizes, + ) = self.parse_urdf(skel_type) self.logscale.data = torch.log(scale_factor) - self.shift.data = offset + self.shift.data = offset # same scale as object field self.orient.data = orient + self.register_buffer("bone_centers", bone_centers, persistent=False) + self.register_buffer("bone_sizes", bone_sizes, persistent=False) # get local rest rotation matrices, pick the first coordinate in rpy of ball joints # by default: transform points from child to parent @@ -679,9 +690,25 @@ def parse_urdf(self, urdf_name): scale_factor = torch.tensor([0.05]) else: raise NotImplementedError - orient = F.normalize(orient, dim=-1) - return local_rest_coord, scale_factor, orient, offset + + # get center/size of each link + bone_centers = [] + bone_sizes = [] + for link in urdf._reverse_topo: + if len(link.visuals) == 0: + continue + bone_bounds = link.collision_mesh.bounds + center = (bone_bounds[1] + bone_bounds[0]) / 2 + size = (bone_bounds[1] - bone_bounds[0]) / 2 + center = torch.tensor(center, dtype=torch.float) + size = torch.tensor(size, dtype=torch.float) + bone_centers.append(center) + bone_sizes.append(size) + + bone_centers = torch.stack(bone_centers, dim=0)[1:] # skip root + bone_sizes = torch.stack(bone_sizes, dim=0)[1:] # skip root + return local_rest_coord, scale_factor, orient, offset, bone_centers, bone_sizes def fk_se3(self, local_rest_joints, so3, edges): return fk_se3( @@ -693,3 +720,14 @@ def fk_se3(self, local_rest_joints, so3, edges): def rest_joints_to_local(self, rest_joints, edges): return self.local_rest_coord[:, :3, 3].clone() + + def shift_joints_to_bones(self, bone_to_obj): + idn_quat = torch.zeros_like(bone_to_obj[0]) + idn_quat[..., 0] = 1.0 + bone_centers = self.bone_centers.expand_as(idn_quat[..., :3]) + bone_centers = bone_centers * self.logscale.exp().clone() + link_transform = quaternion_translation_to_dual_quaternion( + idn_quat, bone_centers + ) + bone_to_obj = dual_quaternion_mul(bone_to_obj, link_transform) + return bone_to_obj diff --git a/lab4d/nnutils/skinning.py b/lab4d/nnutils/skinning.py index 6227c3d..4e45c7d 100644 --- a/lab4d/nnutils/skinning.py +++ b/lab4d/nnutils/skinning.py @@ -58,11 +58,16 @@ def __init__( ): super().__init__() - # 3D gaussians - gaussians = init_scale * torch.ones( - num_coords, 3 - ) # scale of bone skinning field + # 3D gaussians: scale of bone skinning field + if torch.is_tensor(init_scale): + gaussians = init_scale + else: + gaussians = init_scale * torch.ones(num_coords, 3) + # clip minimum radius to 0.01 + gaussians = torch.clamp(gaussians, min=0.01) self.log_gauss = nn.Parameter(torch.log(gaussians)) + # self.register_buffer("log_gauss", torch.log(gaussians), persistent=False) + self.logscale = nn.Parameter(torch.zeros(1)) self.num_coords = num_coords if delta_skin: @@ -115,8 +120,10 @@ def forward(self, xyz, bone2obj, frame_id, inst_id): t_embed = t_embed.expand(xyz.shape[:-1] + (-1,)) xyzt_embed = torch.cat([xyz_embed, t_embed], dim=-1) delta = self.delta_field(xyzt_embed, inst_id) - delta = F.relu(delta) * 0.1 - skin = -(dist2 + delta) + # delta = F.relu(delta) * 0.1 + # skin = -(dist2 + delta) + dist2 = dist2 * (0.1 * delta).exp() + skin = -dist2 else: skin = -dist2 delta = None @@ -150,6 +157,7 @@ def get_gauss(self): log_gauss = self.log_gauss if self.symm_idx is not None: log_gauss = (log_gauss[self.symm_idx] + log_gauss) / 2 + log_gauss = log_gauss + self.logscale return log_gauss.exp() def draw_gaussian(self, articulation, edges): diff --git a/lab4d/nnutils/warping.py b/lab4d/nnutils/warping.py index 5842c85..beb42be 100644 --- a/lab4d/nnutils/warping.py +++ b/lab4d/nnutils/warping.py @@ -277,6 +277,9 @@ def __init__( self.articulation = ArticulationURDFMLP(frame_info, skel_type, joint_angles) num_se3 = self.articulation.num_se3 symm_idx = self.articulation.symm_idx + init_gauss_scale = ( + self.articulation.bone_sizes * self.articulation.logscale.exp() + ) else: raise NotImplementedError @@ -383,19 +386,17 @@ def get_gauss_density(self, xyz, bone2obj=None): if bone2obj is None: bone2obj = self.articulation.get_mean_vals() # 1,K,4,4 - dist2 = get_xyz_bone_distance(xyz, bone2obj) # N,K - dist2 = dist2 / (0.01) ** 2 # assuming spheres of radius 0.01 - - # # gauss bones - # xyz = xyz[:, None, None] # (N,1,1,3) - # bone2obj = ( - # bone2obj[0][None, None].repeat(xyz.shape[0], 1, 1, 1, 1), - # bone2obj[1][None, None].repeat(xyz.shape[0], 1, 1, 1, 1), - # ) # (N,1,1,K,4) - # dist2 = -self.skinning_model.forward( - # xyz, bone2obj, None, None, normalize=False - # )[0][:, 0, 0] - + if isinstance(self.articulation, ArticulationURDFMLP): + # gauss bones + skinning + xyz = xyz[:, None, None] # (N,1,1,3) + bone2obj = ( + bone2obj[0][None, None].repeat(xyz.shape[0], 1, 1, 1, 1), + bone2obj[1][None, None].repeat(xyz.shape[0], 1, 1, 1, 1), + ) # (N,1,1,K,4) + dist2 = -self.skinning_model.forward(xyz, bone2obj, None, None)[0][:, 0, 0] + else: + dist2 = get_xyz_bone_distance(xyz, bone2obj) # N,K + dist2 = dist2 / (0.01) ** 2 # assuming spheres of radius 0.01 score = (-0.5 * dist2).exp() # (N,K) # hard selection diff --git a/lab4d/utils/skel_utils.py b/lab4d/utils/skel_utils.py index 3aaa227..779955b 100644 --- a/lab4d/utils/skel_utils.py +++ b/lab4d/utils/skel_utils.py @@ -4,6 +4,7 @@ import cv2 import numpy as np import torch +import torch.nn.functional as F from lab4d.utils.geom_utils import so3_to_exp_map from lab4d.utils.quat_transform import ( @@ -107,7 +108,7 @@ def fk_se3(local_rest_joints, so3, edges, to_dq=True, local_rest_coord=None): return global_rt -def shift_joints_to_bones_dq(dq, edges, shift=None, orient=None): +def shift_joints_to_bones_dq(dq, edges): """Compute bone centers and orientations from joint locations Args: @@ -121,15 +122,6 @@ def shift_joints_to_bones_dq(dq, edges, shift=None, orient=None): quat, joints = dual_quaternion_to_quaternion_translation(dq) joints = shift_joints_to_bones(joints, edges) dq = quaternion_translation_to_dual_quaternion(quat, joints) - if shift is not None and orient is not None: - shift = shift.reshape((1,) * (joints[0].ndim - 1) + (3,)) - shift = shift.expand(*joints.shape[:-1], -1) - orient = orient.reshape((1,) * (joints[0].ndim - 1) + (4,)) - orient = orient.expand(*joints.shape[:-1], -1) - offset_dq = quaternion_translation_to_dual_quaternion(orient, shift) - dq = dual_quaternion_mul(offset_dq, dq) - else: - raise ValueError("shift and orient cannot be both None") return dq @@ -153,6 +145,31 @@ def shift_joints_to_bones(joints, edges): return joints +def apply_root_offset(dq, shift, orient): + """Compute bone centers and orientations from joint locations + + Args: + dq: ((..., B, 4), (..., B, 4)) Location of each joint, written as dual + quaternions + edges (Dict(int, int)): Maps each joint to its parent joint + Returns: + dq: ((..., B, 4), (..., B, 4)) Bone-to-object SE(3) transforms, + written as dual quaternions + """ + # normliaze the quaternion + orient = F.normalize(orient, 2, dim=-1) + ndim = dq[0].ndim + shape = dq[0].shape + shift = shift.reshape((1,) * (ndim - 1) + (3,)) + shift = shift.expand(*shape[:-1], -1) + orient = orient.reshape((1,) * (ndim - 1) + (4,)) + orient = orient.expand(*shape[:-1], -1) + offset_dq = quaternion_translation_to_dual_quaternion(orient, shift) + dq = dual_quaternion_mul(offset_dq, dq) + + return dq + + def get_predefined_skeleton(skel_type): """Compute pre-defined skeletons @@ -246,7 +263,7 @@ def get_predefined_skeleton(skel_type): 22: 0, # right hip 2: 1, # spine 2 3: 2, # spine 3 - 4: 3, # spine 4 + 4: 3, # head 5: 3, # left shoulder 9: 3, # right shoulder 6: 5, # left elbow diff --git a/lab4d/utils/transforms.py b/lab4d/utils/transforms.py index c25e23d..2a6c7cb 100644 --- a/lab4d/utils/transforms.py +++ b/lab4d/utils/transforms.py @@ -21,6 +21,11 @@ def get_bone_coords(xyz, bone2obj): # reshape xyz = xyz[..., None, :].expand(xyz.shape[:-1] + (bone2obj[0].shape[-2], 3)).clone() + expand_shape = xyz.shape[:-2] + (-1, -1) + obj2bone = ( + obj2bone[0].expand(expand_shape).clone(), + obj2bone[1].expand(expand_shape).clone(), + ) xyz_bone = dual_quaternion_apply(obj2bone, xyz) return xyz_bone @@ -29,11 +34,11 @@ def get_xyz_bone_distance(xyz, bone2obj): """Compute squared distances from points to bone centers Argss: - xyz: (..., 3) Points in object canonical space - bone2obj: ((..., B, 4), (..., B, 4)) Bone-to-object SE(3) transforms, written as dual quaternions + xyz: (M, 3) Points in object canonical space + bone2obj: ((M, B, 4), (M, B, 4)) Bone-to-object SE(3) transforms, written as dual quaternions Returns: - dist2: (..., B) Squared distance to each bone center + dist2: (M, B) Squared distance to each bone center """ _, center = dual_quaternion_to_quaternion_translation(bone2obj) dist2 = (xyz[..., None, :] - center).pow(2).sum(-1) # M, K diff --git a/projects/ppr/config.py b/projects/ppr/config.py index 6a28dbe..b9aa5c1 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -10,7 +10,7 @@ class PPRConfig: # configs related to ppr flags.DEFINE_string("urdf_template", "", "whether to use predefined skeleton") flags.DEFINE_float("ratio_phys_cycle", 0.2, "number of iterations per round") - flags.DEFINE_integer("phys_wdw_len", 24, "length of the physics opt window") + flags.DEFINE_integer("phys_wdw_len", 72, "length of the physics opt window") flags.DEFINE_integer("phys_batch", 20, "number of parallel physics sim") flags.DEFINE_string( "phys_vid", "0", "whether to optimize selected videos, e.g., 0,1,2" diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index 7d6e2b6..196628e 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit 7d6e2b6bdbe8ee176e4890d5dc64c404e4f176c9 +Subproject commit 196628e3bb9f039545e6006101fc09cef63977da diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 0f67788..bc76f38 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -3,6 +3,7 @@ import pdb import torch import numpy as np +import tqdm from lab4d.engine.trainer import Trainer from lab4d.engine.trainer import get_local_rank @@ -40,6 +41,27 @@ def define_model(self): self.phys_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.phys_model) self.phys_model = self.phys_model.to(self.device) + def get_lr_dict(self): + """Return the learning rate for each category of trainable parameters + + Returns: + param_lr_startwith (Dict(str, float)): Learning rate for base model + param_lr_with (Dict(str, float)): Learning rate for explicit params + """ + # define a dict for (tensor_name, learning) pair + param_lr_startwith, param_lr_with = super().get_lr_dict() + opts = self.opts + + param_lr_with.update( + { + "module.fields.field_params.fg.basefield.": 0.0, + "module.fields.field_params.fg.colorfield.": 0.0, + "module.fields.field_params.fg.sdf.": 0.0, + "module.fields.field_params.fg.rgb.": 0.0, + } + ) + return param_lr_startwith, param_lr_with + def init_model(self): """Initialize camera transforms, geometry, articulations, and camera intrinsics from external priors, if this is the first run""" @@ -56,14 +78,14 @@ def trainer_init(self): print("# iterations per phys cycle:", self.iters_per_phys_cycle) def run_one_round(self, round_count): - # run dr cycle - super().run_one_round(round_count) # transfer pharameters self.phys_model.override_states() # run physics cycle self.run_phys_cycle() # transfer pharameters self.phys_model.override_states_inv() + # run dr cycle + super().run_one_round(round_count) def run_phys_cycle(self): opts = self.opts @@ -71,6 +93,7 @@ def run_phys_cycle(self): # eval self.phys_model.eval() + self.phys_model.correct_foot_position() self.run_phys_visualization(tag="kinematics") # train @@ -78,7 +101,7 @@ def run_phys_cycle(self): self.phys_model.reinit_envs( opts["phys_batch"], wdw_length=opts["phys_wdw_len"], is_eval=False ) - for i in range(self.iters_per_phys_cycle): + for i in tqdm.tqdm(range(self.iters_per_phys_cycle)): self.phys_model.set_progress(self.current_steps_phys) self.run_phys_iter() self.current_steps_phys += 1 @@ -92,16 +115,23 @@ def run_phys_iter(self): phys_aux = self.phys_model() self.phys_model.backward(phys_aux["total_loss"]) self.phys_model.update() + if get_local_rank() == 0: + del phys_aux["total_loss"] + self.add_scalar(self.log, phys_aux, self.current_steps_phys) def run_phys_visualization(self, tag=""): opts = self.opts - self.phys_model.reinit_envs(1, wdw_length=30, is_eval=True) + frame_offset_raw = self.phys_model.frame_offset_raw + vid_frame_max = max(frame_offset_raw[1:] - frame_offset_raw[:-1]) + self.phys_model.reinit_envs(1, wdw_length=vid_frame_max, is_eval=True) for vidid in opts["phys_vid"]: - frame_start = torch.zeros(1) + self.phys_model.data_offset[vidid] + frame_start = torch.zeros(1) + frame_offset_raw[vidid] _ = self.phys_model(frame_start=frame_start.to(self.device)) img_size = tuple(self.data_info["raw_size"][vidid][::-1]) img_size = img_size + (0.5,) # scale data = self.phys_model.query(img_size=img_size) self.phys_visualizer.show( - "%s-%02d-%05d" % (tag, vidid, self.current_steps_phys), data + "%s-%02d-%05d" % (tag, vidid, self.current_steps_phys), + data, + fps=1.0 / self.phys_model.frame_interval, ) From 4c130eb4ca281a633a78775f4d9d7a90452072ce Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Tue, 15 Aug 2023 20:22:18 -0400 Subject: [PATCH 21/86] improve bg fitting --- lab4d/nnutils/embedding.py | 3 +- lab4d/nnutils/nerf.py | 3 +- lab4d/utils/geom_utils.py | 89 ++++++++++++++++++++++++++++++++++---- projects/ppr/trainer.py | 12 +++-- 4 files changed, 93 insertions(+), 14 deletions(-) diff --git a/lab4d/nnutils/embedding.py b/lab4d/nnutils/embedding.py index 0263d67..78978e9 100644 --- a/lab4d/nnutils/embedding.py +++ b/lab4d/nnutils/embedding.py @@ -199,13 +199,14 @@ def forward(self, frame_id=None): Returns: t_embed (..., self.W): Output time embeddings """ + device = self.parameters().__next__().device if frame_id is None: inst_id, t_sample = self.frame_to_vid, self.frame_to_tid(self.frame_mapping) else: if torch.is_tensor(frame_id): frame_id = frame_id.long() else: - frame_id = torch.tensor(frame_id, dtype=torch.long) + frame_id = torch.tensor(frame_id, dtype=torch.long, device=device) inst_id = self.raw_fid_to_vid[frame_id] t_sample = self.frame_to_tid(frame_id) diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index 0e9bc78..ea3311c 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -1013,7 +1013,8 @@ def compute_field2world(self): """ for inst_id in range(self.num_inst): # TODO: move this to background nerf, and use each proxy geometry - self.field2world[inst_id] = compute_rectification_se3(self.proxy_geometry) + mesh = self.extract_canonical_mesh(level=0.005) + self.field2world[inst_id] = compute_rectification_se3(mesh) def get_field2world(self, inst_id=None): """Compute SE(3) to transform points in the scene space to world space diff --git a/lab4d/utils/geom_utils.py b/lab4d/utils/geom_utils.py index 77d1f62..79cf9d0 100644 --- a/lab4d/utils/geom_utils.py +++ b/lab4d/utils/geom_utils.py @@ -527,18 +527,91 @@ def compute_rectification_se3(mesh, threshold=0.01, init_n=3, iter=1000): dist = (center * plane_n).sum() + best_eq[3] plane_o = center - plane_n * dist plane = np.concatenate([plane_o, plane_n]) - bg2xy = trimesh.geometry.plane_transform(origin=plane[:3], normal=plane[3:6]) - # to xz - xy2xz = np.eye(4) - xy2xz[:3, :3] = cv2.Rodrigues(np.asarray([-np.pi / 2, 0, 0]))[0] - xy2xz[:3, :3] = cv2.Rodrigues(np.asarray([0, -np.pi / 2, 0]))[0] @ xy2xz[:3, :3] - bg2world = xy2xz @ bg2xy # coplanar with xy->xz plane - - # mesh.apply_transform(bg2world) # DEBUG only + + # xz plane + bg2world = plane_transform(origin=plane[:3], normal=plane[3:6], axis=[0, 1, 0]) + + # mesh.export("tmp/raw.obj") + # mesh.apply_transform(bg2world) # DEBUG only + # mesh.export("tmp/rect.obj") + bg2world = torch.Tensor(bg2world) return bg2world +def plane_transform(origin, normal, axis=[0, 1, 0]): + """ + # modified from https://github.com/mikedh/trimesh/blob/main/trimesh/geometry.py#L14 + Given the origin and normal of a plane find the transform + that will move that plane to be coplanar with the XZ plane. + Parameters + ---------- + origin : (3,) float + Point that lies on the plane + normal : (3,) float + Vector that points along normal of plane + Returns + --------- + transform: (4,4) float + Transformation matrix to move points onto XZ plane + """ + transform = align_vectors(normal, axis) + if origin is not None: + transform[:3, 3] = -np.dot(transform, np.append(origin, 1))[:3] + return transform + + +def align_vectors(a, b, return_angle=False): + """ + # modified from https://github.com/mikedh/trimesh/blob/main/trimesh/geometry.py#L38 + Find the rotation matrix that transforms one 3D vector + to another. + Parameters + ------------ + a : (3,) float + Unit vector + b : (3,) float + Unit vector + return_angle : bool + Return the angle between vectors or not + Returns + ------------- + matrix : (4, 4) float + Homogeneous transform to rotate from `a` to `b` + angle : float + If `return_angle` angle in radians between `a` and `b` + """ + a = np.array(a, dtype=np.float64) + b = np.array(b, dtype=np.float64) + if a.shape != (3,) or b.shape != (3,): + raise ValueError("vectors must be (3,)!") + + # find the SVD of the two vectors + au = np.linalg.svd(a.reshape((-1, 1)))[0] + bu = np.linalg.svd(b.reshape((-1, 1)))[0] + + if np.linalg.det(au) < 0: + au[:, -1] *= -1.0 + if np.linalg.det(bu) < 0: + bu[:, -1] *= -1.0 + + # put rotation into homogeneous transformation + matrix = np.eye(4) + matrix[:3, :3] = bu.dot(au.T) + + if return_angle: + # projection of a onto b + # first row of SVD result is normalized source vector + dot = np.dot(au[0], bu[0]) + # clip to avoid floating point error + angle = np.arccos(np.clip(dot, -1.0, 1.0)) + if dot < -1e-5: + angle += np.pi + return matrix, angle + + return matrix + + def se3_inv(rtmat): """Invert an SE(3) matrix diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index bc76f38..45a1b57 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -23,9 +23,6 @@ def define_model(self): opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] - # re-initialize field2world transforms - self.model.fields.field_params["bg"].compute_field2world() - # model model_dict = {} model_dict["scene_field"] = self.model.fields.field_params["bg"] @@ -58,6 +55,10 @@ def get_lr_dict(self): "module.fields.field_params.fg.colorfield.": 0.0, "module.fields.field_params.fg.sdf.": 0.0, "module.fields.field_params.fg.rgb.": 0.0, + "module.fields.field_params.bg.basefield.": 0.0, + "module.fields.field_params.bg.colorfield.": 0.0, + "module.fields.field_params.bg.sdf.": 0.0, + "module.fields.field_params.bg.rgb.": 0.0, } ) return param_lr_startwith, param_lr_with @@ -65,7 +66,8 @@ def get_lr_dict(self): def init_model(self): """Initialize camera transforms, geometry, articulations, and camera intrinsics from external priors, if this is the first run""" - super().init_model() + # super().init_model() + return def trainer_init(self): super().trainer_init() @@ -78,6 +80,8 @@ def trainer_init(self): print("# iterations per phys cycle:", self.iters_per_phys_cycle) def run_one_round(self, round_count): + # re-initialize field2world transforms + self.model.fields.field_params["bg"].compute_field2world() # transfer pharameters self.phys_model.override_states() # run physics cycle From d2eee3e5979db36c3844100af01f200b92b34850 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Wed, 16 Aug 2023 10:03:46 -0400 Subject: [PATCH 22/86] remove frame filtering --- lab4d/utils/geom_utils.py | 6 ++++-- preprocess/scripts/extract_frames.py | 12 +++++++++--- projects/ppr/config.py | 2 +- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 8 ++++++-- scripts/run_preprocess.py | 2 +- 6 files changed, 22 insertions(+), 10 deletions(-) diff --git a/lab4d/utils/geom_utils.py b/lab4d/utils/geom_utils.py index 79cf9d0..c7b0b4d 100644 --- a/lab4d/utils/geom_utils.py +++ b/lab4d/utils/geom_utils.py @@ -510,12 +510,13 @@ def check_inside_aabb(xyz, aabb): return inside_aabb -def compute_rectification_se3(mesh, threshold=0.01, init_n=3, iter=1000): +def compute_rectification_se3(mesh, threshold=0.01, init_n=3, iter=2000): # run ransac to get plane pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(mesh.vertices) best_eq, index = pcd.segment_plane(threshold, init_n, iter) segmented_points = pcd.select_by_index(index) + print("segmented floor points: ", len(segmented_points.points) / len(mesh.vertices)) # point upside if best_eq[1] < 0: @@ -531,8 +532,9 @@ def compute_rectification_se3(mesh, threshold=0.01, init_n=3, iter=1000): # xz plane bg2world = plane_transform(origin=plane[:3], normal=plane[3:6], axis=[0, 1, 0]) + # # DEBUG only # mesh.export("tmp/raw.obj") - # mesh.apply_transform(bg2world) # DEBUG only + # mesh.apply_transform(bg2world) # mesh.export("tmp/rect.obj") bg2world = torch.Tensor(bg2world) diff --git a/preprocess/scripts/extract_frames.py b/preprocess/scripts/extract_frames.py index 612cfb1..7943f53 100644 --- a/preprocess/scripts/extract_frames.py +++ b/preprocess/scripts/extract_frames.py @@ -6,10 +6,16 @@ import numpy as np -def extract_frames(in_path, out_path): +def extract_frames(in_path, out_path, desired_fps=10): print("extracting frames: ", in_path) # Open the video file reader = imageio.get_reader(in_path) + original_fps = reader.get_meta_data()["fps"] + + # If a desired frame rate is given, calculate the frame skip rate + skip_rate = 1 + if desired_fps: + skip_rate = int(original_fps / desired_fps) # Find the first non-black frame for i, im in enumerate(reader): @@ -17,10 +23,10 @@ def extract_frames(in_path, out_path): start_frame = i break - # Write the video starting from the first non-black frame + # Write the video starting from the first non-black frame, considering the desired frame rate count = 0 for i, im in enumerate(reader): - if i >= start_frame: + if i >= start_frame and i % skip_rate == 0: imageio.imsave("%s/%05d.jpg" % (out_path, count), im) count += 1 diff --git a/projects/ppr/config.py b/projects/ppr/config.py index b9aa5c1..29c3508 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -10,7 +10,7 @@ class PPRConfig: # configs related to ppr flags.DEFINE_string("urdf_template", "", "whether to use predefined skeleton") flags.DEFINE_float("ratio_phys_cycle", 0.2, "number of iterations per round") - flags.DEFINE_integer("phys_wdw_len", 72, "length of the physics opt window") + flags.DEFINE_float("phys_wdw_len", 2.4, "length of the physics opt window in secs") flags.DEFINE_integer("phys_batch", 20, "number of parallel physics sim") flags.DEFINE_string( "phys_vid", "0", "whether to optimize selected videos, e.g., 0,1,2" diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index 196628e..abdcf18 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit 196628e3bb9f039545e6006101fc09cef63977da +Subproject commit abdcf1838010f9d0d22f8de1c6e012d897dd61e7 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 45a1b57..44efde0 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -103,7 +103,10 @@ def run_phys_cycle(self): # train self.phys_model.train() self.phys_model.reinit_envs( - opts["phys_batch"], wdw_length=opts["phys_wdw_len"], is_eval=False + opts["phys_batch"], + frames_per_wdw=int(opts["phys_wdw_len"] / self.phys_model.frame_interval) + + 1, + is_eval=False, ) for i in tqdm.tqdm(range(self.iters_per_phys_cycle)): self.phys_model.set_progress(self.current_steps_phys) @@ -113,6 +116,7 @@ def run_phys_cycle(self): # eval again self.phys_model.eval() self.run_phys_visualization(tag="phys") + torch.cuda.empty_cache() def run_phys_iter(self): """Run physics optimization""" @@ -127,7 +131,7 @@ def run_phys_visualization(self, tag=""): opts = self.opts frame_offset_raw = self.phys_model.frame_offset_raw vid_frame_max = max(frame_offset_raw[1:] - frame_offset_raw[:-1]) - self.phys_model.reinit_envs(1, wdw_length=vid_frame_max, is_eval=True) + self.phys_model.reinit_envs(1, frames_per_wdw=vid_frame_max, is_eval=True) for vidid in opts["phys_vid"]: frame_start = torch.zeros(1) + frame_offset_raw[vidid] _ = self.phys_model(frame_start=frame_start.to(self.device)) diff --git a/scripts/run_preprocess.py b/scripts/run_preprocess.py index 3ee8318..ead99e2 100644 --- a/scripts/run_preprocess.py +++ b/scripts/run_preprocess.py @@ -93,7 +93,7 @@ def run_extract_priors(seqname, outdir, obj_class): # True: manually annotate camera for key frames use_manual_cameras = True if obj_class == "other" else False # True: filter frame based on motion magnitude | False: use all frames - use_filter_frames = True + use_filter_frames = False outdir = "database/processed/" viddir = "database/raw/%s" % vidname From 67769b78b16d4d9d51817e9044b18fafe0fec0cb Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Thu, 17 Aug 2023 14:19:39 -0400 Subject: [PATCH 23/86] update bg nerf; load fgbg --- lab4d/engine/trainer.py | 3 +- lab4d/nnutils/embedding.py | 8 ++--- lab4d/nnutils/multifields.py | 2 +- lab4d/nnutils/nerf.py | 2 +- lab4d/nnutils/pose.py | 2 +- projects/ppr/config.py | 12 ++++--- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 70 ++++++++++++++++++++++++------------ 8 files changed, 64 insertions(+), 37 deletions(-) diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index 1b37805..c9d3beb 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -594,10 +594,11 @@ def check_grad(self, thresh=5.0): if grad_norm > thresh: # clear gradients self.optimizer.zero_grad() + print("large grad: %.2f, clear gradients" % grad_norm) # load cached model from two rounds ago if self.model_cache[0] is not None: if get_local_rank() == 0: - print("large grad: %.2f, resume from cached weights" % grad_norm) + print("fallback to cached model") self.model.load_state_dict(self.model_cache[0]) self.optimizer.load_state_dict(self.optimizer_cache[0]) self.scheduler.load_state_dict(self.scheduler_cache[0]) diff --git a/lab4d/nnutils/embedding.py b/lab4d/nnutils/embedding.py index 78978e9..e3adcfe 100644 --- a/lab4d/nnutils/embedding.py +++ b/lab4d/nnutils/embedding.py @@ -203,11 +203,9 @@ def forward(self, frame_id=None): if frame_id is None: inst_id, t_sample = self.frame_to_vid, self.frame_to_tid(self.frame_mapping) else: - if torch.is_tensor(frame_id): - frame_id = frame_id.long() - else: - frame_id = torch.tensor(frame_id, dtype=torch.long, device=device) - inst_id = self.raw_fid_to_vid[frame_id] + if not torch.is_tensor(frame_id): + frame_id = torch.tensor(frame_id, device=device) + inst_id = self.raw_fid_to_vid[frame_id.long()] t_sample = self.frame_to_tid(frame_id) if inst_id.ndim == 1: diff --git a/lab4d/nnutils/multifields.py b/lab4d/nnutils/multifields.py index bfd76e8..60642d7 100644 --- a/lab4d/nnutils/multifields.py +++ b/lab4d/nnutils/multifields.py @@ -90,7 +90,7 @@ def define_field(self, category, data_info, tracklet_id): num_freq_xyz=6, num_freq_dir=0, appr_channels=0, - init_scale=0.1, + init_scale=0.2, ) else: # exit with an error raise ValueError("Invalid category") diff --git a/lab4d/nnutils/nerf.py b/lab4d/nnutils/nerf.py index ea3311c..fa3be81 100644 --- a/lab4d/nnutils/nerf.py +++ b/lab4d/nnutils/nerf.py @@ -301,7 +301,7 @@ def geometry_init(self, sdf_fn, nsample=256): def update_proxy(self): """Extract proxy geometry using marching cubes""" mesh = self.extract_canonical_mesh(level=0.005) - if mesh is not None: + if len(mesh.vertices) > 3: self.proxy_geometry = mesh @torch.no_grad() diff --git a/lab4d/nnutils/pose.py b/lab4d/nnutils/pose.py index bb3cc72..5e02771 100644 --- a/lab4d/nnutils/pose.py +++ b/lab4d/nnutils/pose.py @@ -83,6 +83,7 @@ def __init__( self.register_buffer( "init_vals", torch.tensor(rtmat, dtype=torch.float32), persistent=False ) + self.base_init() # override the loss function def loss_fn(gt): @@ -103,7 +104,6 @@ def base_init(self): def mlp_init(self): """Initialize camera SE(3) transforms from external priors""" - self.base_init() super().mlp_init() # with torch.no_grad(): diff --git a/projects/ppr/config.py b/projects/ppr/config.py index 29c3508..28935d5 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -10,19 +10,21 @@ class PPRConfig: # configs related to ppr flags.DEFINE_string("urdf_template", "", "whether to use predefined skeleton") flags.DEFINE_float("ratio_phys_cycle", 0.2, "number of iterations per round") - flags.DEFINE_float("phys_wdw_len", 2.4, "length of the physics opt window in secs") - flags.DEFINE_integer("phys_batch", 20, "number of parallel physics sim") + flags.DEFINE_float("secs_per_wdw", 2.4, "length of the physics opt window in secs") flags.DEFINE_string( "phys_vid", "0", "whether to optimize selected videos, e.g., 0,1,2" ) # weights - flags.DEFINE_float("traj_wt", 0.1, "weight for traj matching loss") - flags.DEFINE_float("pos_state_wt", 0.1, "weight for position matching reg") + flags.DEFINE_float("traj_wt", 0.01, "weight for traj matching loss") + flags.DEFINE_float("pos_state_wt", 0.01, "weight for position matching reg") flags.DEFINE_float("vel_state_wt", 0.0, "weight for velocity matching reg") # regs flags.DEFINE_float("reg_torque_wt", 0.0, "weight for torque regularization") - flags.DEFINE_float("reg_res_f_wt", 0.0, "weight for residual force regularization") + flags.DEFINE_float("reg_res_f_wt", 2e-5, "weight for residual force regularization") flags.DEFINE_float("reg_foot_wt", 0.0, "weight for foot contact regularization") flags.DEFINE_float("reg_root_wt", 0.0, "weight for root pose regularization") + + # io-related + flags.DEFINE_string("load_path_bg", "", "path to load pretrained model") diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index abdcf18..f318da4 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit abdcf1838010f9d0d22f8de1c6e012d897dd61e7 +Subproject commit f318da4aae9fdcf9f72bfb72afbc99d301cdd82a diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 44efde0..0b6ca94 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -15,6 +15,31 @@ class PPRTrainer(Trainer): + def __init__(self, opts): + """Train and evaluate a Lab4D model. + + Args: + opts (Dict): Command-line args from absl (defined in lab4d/config.py) + """ + super().__init__(opts) + self.model.fields.field_params["bg"].compute_field2world() + + def trainer_init(self): + super().trainer_init() + + opts = self.opts + self.current_steps_phys = 0 # 0-total_steps + self.iters_per_phys_cycle = int( + opts["ratio_phys_cycle"] * opts["iters_per_round"] + ) + print("# iterations per phys cycle:", self.iters_per_phys_cycle) + + def init_model(self): + """Initialize camera transforms, geometry, articulations, and camera + intrinsics from external priors, if this is the first run""" + # super().init_model() + return + def define_model(self): super().define_model() @@ -38,6 +63,18 @@ def define_model(self): self.phys_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.phys_model) self.phys_model = self.phys_model.to(self.device) + def load_checkpoint_train(self): + """Load a checkpoint at training time and update the current step count + and round count + """ + super().load_checkpoint_train() + # training time + if self.opts["load_path_bg"] != "": + _ = self.load_checkpoint(self.opts["load_path_bg"], self.model) + + # reset near_far + self.model.fields.reset_geometry_aux() + def get_lr_dict(self): """Return the learning rate for each category of trainable parameters @@ -63,25 +100,7 @@ def get_lr_dict(self): ) return param_lr_startwith, param_lr_with - def init_model(self): - """Initialize camera transforms, geometry, articulations, and camera - intrinsics from external priors, if this is the first run""" - # super().init_model() - return - - def trainer_init(self): - super().trainer_init() - - opts = self.opts - self.current_steps_phys = 0 # 0-total_steps - self.iters_per_phys_cycle = int( - opts["ratio_phys_cycle"] * opts["iters_per_round"] - ) - print("# iterations per phys cycle:", self.iters_per_phys_cycle) - def run_one_round(self, round_count): - # re-initialize field2world transforms - self.model.fields.field_params["bg"].compute_field2world() # transfer pharameters self.phys_model.override_states() # run physics cycle @@ -102,10 +121,14 @@ def run_phys_cycle(self): # train self.phys_model.train() + # to use the same amount memory: batch * time_per_wdw = 2.4*20 = 48 + num_envs = int(48 / opts["secs_per_wdw"]) + frames_per_wdw = int(opts["secs_per_wdw"] / self.phys_model.frame_interval) + 1 + print("num_envs:", num_envs) + print("frames_per_wdw:", frames_per_wdw) self.phys_model.reinit_envs( - opts["phys_batch"], - frames_per_wdw=int(opts["phys_wdw_len"] / self.phys_model.frame_interval) - + 1, + num_envs, + frames_per_wdw=frames_per_wdw, is_eval=False, ) for i in tqdm.tqdm(range(self.iters_per_phys_cycle)): @@ -116,7 +139,6 @@ def run_phys_cycle(self): # eval again self.phys_model.eval() self.run_phys_visualization(tag="phys") - torch.cuda.empty_cache() def run_phys_iter(self): """Run physics optimization""" @@ -143,3 +165,7 @@ def run_phys_visualization(self, tag=""): data, fps=1.0 / self.phys_model.frame_interval, ) + + def save_checkpoint(self, round_count): + super().save_checkpoint(round_count) + self.phys_model.save_checkpoint(round_count) From 85a1b0adaf4186c5c70d02001fb206f40287efdf Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Thu, 17 Aug 2023 18:13:51 -0400 Subject: [PATCH 24/86] glue in phys reg; remove param transfer --- lab4d/engine/trainer.py | 4 +-- projects/ppr/config.py | 5 +-- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 69 ++++++++++++++++++++++++++++++++------- 4 files changed, 63 insertions(+), 17 deletions(-) diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index c9d3beb..8eddd8a 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -101,13 +101,13 @@ def init_model(self): if get_local_rank() == 0: self.model.mlp_init() - def define_model(self): + def define_model(self, model=dvr_model): """Define a Lab4D model and wrap it with DistributedDataParallel""" opts = self.opts data_info = self.data_info self.device = torch.device("cuda:{}".format(get_local_rank())) - self.model = dvr_model(opts, data_info) + self.model = model(opts, data_info) # ddp self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model) diff --git a/projects/ppr/config.py b/projects/ppr/config.py index 28935d5..c2230e2 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -16,8 +16,8 @@ class PPRConfig: ) # weights - flags.DEFINE_float("traj_wt", 0.01, "weight for traj matching loss") - flags.DEFINE_float("pos_state_wt", 0.01, "weight for position matching reg") + flags.DEFINE_float("traj_wt", 2e-3, "weight for traj matching loss") + flags.DEFINE_float("pos_state_wt", 0.0, "weight for position matching reg") flags.DEFINE_float("vel_state_wt", 0.0, "weight for velocity matching reg") # regs @@ -25,6 +25,7 @@ class PPRConfig: flags.DEFINE_float("reg_res_f_wt", 2e-5, "weight for residual force regularization") flags.DEFINE_float("reg_foot_wt", 0.0, "weight for foot contact regularization") flags.DEFINE_float("reg_root_wt", 0.0, "weight for root pose regularization") + flags.DEFINE_float("reg_phys_wt", 1e-2, "weight for soft physics regularization") # io-related flags.DEFINE_string("load_path_bg", "", "path to load pretrained model") diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index f318da4..a8884b1 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit f318da4aae9fdcf9f72bfb72afbc99d301cdd82a +Subproject commit a8884b15acdcb3c89ad5b31a928256f936469896 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index 0b6ca94..adfb7c8 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -7,11 +7,53 @@ from lab4d.engine.trainer import Trainer from lab4d.engine.trainer import get_local_rank +from lab4d.engine.model import dvr_model from ppr import config sys.path.insert(0, "%s/ppr-diffphys" % os.path.join(os.path.dirname(__file__))) -from diffphys.dp_interface import phys_interface +from diffphys.dp_interface import phys_interface, query_q from diffphys.vis import Logger +from diffphys.dp_utils import se3_loss + + +class dvr_phys_reg(dvr_model): + """A model that contains a collection of static/deformable neural fields + + Args: + config (Dict): Command-line args + data_info (Dict): Dataset metadata from get_data_info() + """ + + @torch.no_grad() + def copy_phys_traj(self, phys_model): + self.steps_fr = torch.arange(phys_model.total_frames, device=self.device) + self.phys_q = phys_model.root_pose_mlp(self.steps_fr) # N, 7 + self.phys_ja = phys_model.joint_angle_mlp(self.steps_fr) + + def forward(self, batch): + loss_dict = super().forward(batch) + reg_phys = self.compute_kinemaics_phys_diff() + reg_phys = self.config["reg_phys_wt"] * reg_phys + loss_dict["phys_reg"] = reg_phys + return loss_dict + + def compute_kinemaics_phys_diff(self): + """ + compute the difference between the target kinematics and kinematics estimated by physics proxy + """ + object_field = self.fields.field_params["fg"] + scene_field = self.fields.field_params["bg"] + kinematics_q = query_q(self.steps_fr, object_field, scene_field)[0] + kinematics_ja = object_field.warp.articulation.get_vals( + self.steps_fr, return_so3=True + ) + + loss_q = se3_loss(self.phys_q, kinematics_q).mean() + loss_ja = (self.phys_ja - kinematics_ja).pow(2).mean() + # print("loss_q:", loss_q) + # print("loss_ja:", loss_ja) + loss = 1e-2 * loss_q + loss_ja + return loss class PPRTrainer(Trainer): @@ -21,6 +63,9 @@ def __init__(self, opts): Args: opts (Dict): Command-line args from absl (defined in lab4d/config.py) """ + opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] + opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] + super().__init__(opts) self.model.fields.field_params["bg"].compute_field2world() @@ -40,14 +85,10 @@ def init_model(self): # super().init_model() return - def define_model(self): - super().define_model() - + def define_model(self, model=dvr_phys_reg): + super().define_model(model=model) # opts opts = self.opts - opts["phys_vid"] = [int(i) for i in opts["phys_vid"].split(",")] - opts["urdf_template"] = opts["fg_motion"].split("-")[1].split("_")[0] - # model model_dict = {} model_dict["scene_field"] = self.model.fields.field_params["bg"] @@ -101,12 +142,15 @@ def get_lr_dict(self): return param_lr_startwith, param_lr_with def run_one_round(self, round_count): - # transfer pharameters - self.phys_model.override_states() + if round_count == 0: + # initialize control input of phys model to kinematics + self.phys_model.override_states() # run physics cycle self.run_phys_cycle() - # transfer pharameters - self.phys_model.override_states_inv() + # # transfer phys-optimized kinematics to dvr + # self.phys_model.override_states_inv() + # transfer hys-optimized kinematics to dvr as soft constriaints + self.model.copy_phys_traj(self.phys_model) # run dr cycle super().run_one_round(round_count) @@ -117,7 +161,8 @@ def run_phys_cycle(self): # eval self.phys_model.eval() self.phys_model.correct_foot_position() - self.run_phys_visualization(tag="kinematics") + if self.current_round == 0: + self.run_phys_visualization(tag="kinematics") # train self.phys_model.train() From 54df5bc5bba7cf1635a23cd9bbe85d7cf0f5767f Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Thu, 17 Aug 2023 22:40:28 -0400 Subject: [PATCH 25/86] update vis freq --- lab4d/engine/trainer.py | 8 ++++++-- projects/ppr/config.py | 6 +++--- projects/ppr/trainer.py | 34 ++++++++++++++++++++++++---------- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/lab4d/engine/trainer.py b/lab4d/engine/trainer.py index 8eddd8a..7d75b0d 100644 --- a/lab4d/engine/trainer.py +++ b/lab4d/engine/trainer.py @@ -238,9 +238,8 @@ def run_one_round(self, round_count): Args: round_count (int): Current round index """ - self.model.eval() if get_local_rank() == 0: - with torch.no_grad(): + if round_count == 0: self.model_eval() self.model.update_geometry_aux() @@ -251,6 +250,9 @@ def run_one_round(self, round_count): self.current_round += 1 self.save_checkpoint(round_count=self.current_round) + if get_local_rank() == 0: + self.model_eval() + def save_checkpoint(self, round_count): """Save model checkpoint to disk @@ -396,8 +398,10 @@ def print_sum_params(self): sum += p.abs().sum() print(f"{sum:.16f}") + @torch.no_grad() def model_eval(self): """Evaluate the current model""" + self.model.eval() torch.cuda.empty_cache() ref_dict, batch = self.load_batch(self.evalloader.dataset, self.eval_fid) self.construct_eval_batch(batch) diff --git a/projects/ppr/config.py b/projects/ppr/config.py index c2230e2..ec0910c 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -16,8 +16,8 @@ class PPRConfig: ) # weights - flags.DEFINE_float("traj_wt", 2e-3, "weight for traj matching loss") - flags.DEFINE_float("pos_state_wt", 0.0, "weight for position matching reg") + flags.DEFINE_float("traj_wt", 5e-3, "weight for traj matching loss") + flags.DEFINE_float("pos_state_wt", 2e-4, "weight for position matching reg") flags.DEFINE_float("vel_state_wt", 0.0, "weight for velocity matching reg") # regs @@ -25,7 +25,7 @@ class PPRConfig: flags.DEFINE_float("reg_res_f_wt", 2e-5, "weight for residual force regularization") flags.DEFINE_float("reg_foot_wt", 0.0, "weight for foot contact regularization") flags.DEFINE_float("reg_root_wt", 0.0, "weight for root pose regularization") - flags.DEFINE_float("reg_phys_wt", 1e-2, "weight for soft physics regularization") + flags.DEFINE_float("reg_phys_wt", 5e-2, "weight for soft physics regularization") # io-related flags.DEFINE_string("load_path_bg", "", "path to load pretrained model") diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index adfb7c8..ce3a0ce 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -26,9 +26,13 @@ class dvr_phys_reg(dvr_model): @torch.no_grad() def copy_phys_traj(self, phys_model): - self.steps_fr = torch.arange(phys_model.total_frames, device=self.device) - self.phys_q = phys_model.root_pose_mlp(self.steps_fr) # N, 7 - self.phys_ja = phys_model.joint_angle_mlp(self.steps_fr) + phys_traj = {} + phys_traj["steps_fr"] = torch.arange( + phys_model.total_frames, device=self.device + ) + phys_traj["phys_q"] = phys_model.root_pose_mlp(phys_traj["steps_fr"]) # N, 7 + phys_traj["phys_ja"] = phys_model.joint_angle_mlp(phys_traj["steps_fr"]) + self.phys_traj = phys_traj def forward(self, batch): loss_dict = super().forward(batch) @@ -41,15 +45,21 @@ def compute_kinemaics_phys_diff(self): """ compute the difference between the target kinematics and kinematics estimated by physics proxy """ + if not hasattr(self, "phys_traj"): + return torch.zeros(1).to(self.device).mean() + steps_fr = self.phys_traj["steps_fr"] + phys_q = self.phys_traj["phys_q"] + phys_ja = self.phys_traj["phys_ja"] + object_field = self.fields.field_params["fg"] scene_field = self.fields.field_params["bg"] - kinematics_q = query_q(self.steps_fr, object_field, scene_field)[0] + kinematics_q = query_q(steps_fr, object_field, scene_field)[0] kinematics_ja = object_field.warp.articulation.get_vals( - self.steps_fr, return_so3=True + steps_fr, return_so3=True ) - loss_q = se3_loss(self.phys_q, kinematics_q).mean() - loss_ja = (self.phys_ja - kinematics_ja).pow(2).mean() + loss_q = se3_loss(phys_q, kinematics_q).mean() + loss_ja = (phys_ja - kinematics_ja).pow(2).mean() # print("loss_q:", loss_q) # print("loss_ja:", loss_ja) loss = 1e-2 * loss_q + loss_ja @@ -74,6 +84,7 @@ def trainer_init(self): opts = self.opts self.current_steps_phys = 0 # 0-total_steps + self.current_round_phys = 0 # 0-total_rounds self.iters_per_phys_cycle = int( opts["ratio_phys_cycle"] * opts["iters_per_round"] ) @@ -143,6 +154,8 @@ def get_lr_dict(self): def run_one_round(self, round_count): if round_count == 0: + # run dr cycle + super().run_one_round(round_count) # initialize control input of phys model to kinematics self.phys_model.override_states() # run physics cycle @@ -153,15 +166,15 @@ def run_one_round(self, round_count): self.model.copy_phys_traj(self.phys_model) # run dr cycle super().run_one_round(round_count) + self.current_round_phys += 1 def run_phys_cycle(self): opts = self.opts torch.cuda.empty_cache() # eval - self.phys_model.eval() self.phys_model.correct_foot_position() - if self.current_round == 0: + if self.current_round_phys == 0: self.run_phys_visualization(tag="kinematics") # train @@ -182,7 +195,6 @@ def run_phys_cycle(self): self.current_steps_phys += 1 # eval again - self.phys_model.eval() self.run_phys_visualization(tag="phys") def run_phys_iter(self): @@ -194,7 +206,9 @@ def run_phys_iter(self): del phys_aux["total_loss"] self.add_scalar(self.log, phys_aux, self.current_steps_phys) + @torch.no_grad() def run_phys_visualization(self, tag=""): + self.phys_model.eval() opts = self.opts frame_offset_raw = self.phys_model.frame_offset_raw vid_frame_max = max(frame_offset_raw[1:] - frame_offset_raw[:-1]) From 3cc8af94d5b91473901b4aad94d46f7df9c8ba8e Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Fri, 18 Aug 2023 10:17:40 -0400 Subject: [PATCH 26/86] modify ckpt order; large step; mod loss --- projects/ppr/config.py | 5 +++-- projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 23 +++++++++++------------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/projects/ppr/config.py b/projects/ppr/config.py index ec0910c..0ead84e 100644 --- a/projects/ppr/config.py +++ b/projects/ppr/config.py @@ -16,7 +16,7 @@ class PPRConfig: ) # weights - flags.DEFINE_float("traj_wt", 5e-3, "weight for traj matching loss") + flags.DEFINE_float("traj_wt", 0.01, "weight for traj matching loss") flags.DEFINE_float("pos_state_wt", 2e-4, "weight for position matching reg") flags.DEFINE_float("vel_state_wt", 0.0, "weight for velocity matching reg") @@ -25,7 +25,8 @@ class PPRConfig: flags.DEFINE_float("reg_res_f_wt", 2e-5, "weight for residual force regularization") flags.DEFINE_float("reg_foot_wt", 0.0, "weight for foot contact regularization") flags.DEFINE_float("reg_root_wt", 0.0, "weight for root pose regularization") - flags.DEFINE_float("reg_phys_wt", 5e-2, "weight for soft physics regularization") + flags.DEFINE_float("reg_phys_q_wt", 1e-3, "weight for soft physics regularization") + flags.DEFINE_float("reg_phys_ja_wt", 0.1, "weight for soft physics regularization") # io-related flags.DEFINE_string("load_path_bg", "", "path to load pretrained model") diff --git a/projects/ppr/ppr-diffphys b/projects/ppr/ppr-diffphys index a8884b1..7abcc8a 160000 --- a/projects/ppr/ppr-diffphys +++ b/projects/ppr/ppr-diffphys @@ -1 +1 @@ -Subproject commit a8884b15acdcb3c89ad5b31a928256f936469896 +Subproject commit 7abcc8a4f3a22b27c8b493ad0d7c51387a89ad89 diff --git a/projects/ppr/trainer.py b/projects/ppr/trainer.py index ce3a0ce..c0c6c0c 100644 --- a/projects/ppr/trainer.py +++ b/projects/ppr/trainer.py @@ -36,9 +36,9 @@ def copy_phys_traj(self, phys_model): def forward(self, batch): loss_dict = super().forward(batch) - reg_phys = self.compute_kinemaics_phys_diff() - reg_phys = self.config["reg_phys_wt"] * reg_phys - loss_dict["phys_reg"] = reg_phys + reg_phys_q, reg_phys_ja = self.compute_kinemaics_phys_diff() + loss_dict["phys_q_reg"] = self.config["reg_phys_q_wt"] * reg_phys_q + loss_dict["phys_ja_reg"] = self.config["reg_phys_ja_wt"] * reg_phys_ja return loss_dict def compute_kinemaics_phys_diff(self): @@ -46,7 +46,10 @@ def compute_kinemaics_phys_diff(self): compute the difference between the target kinematics and kinematics estimated by physics proxy """ if not hasattr(self, "phys_traj"): - return torch.zeros(1).to(self.device).mean() + return ( + torch.zeros(1).to(self.device).mean(), + torch.zeros(1).to(self.device).mean(), + ) steps_fr = self.phys_traj["steps_fr"] phys_q = self.phys_traj["phys_q"] phys_ja = self.phys_traj["phys_ja"] @@ -62,8 +65,7 @@ def compute_kinemaics_phys_diff(self): loss_ja = (phys_ja - kinematics_ja).pow(2).mean() # print("loss_q:", loss_q) # print("loss_ja:", loss_ja) - loss = 1e-2 * loss_q + loss_ja - return loss + return loss_q, loss_ja class PPRTrainer(Trainer): @@ -107,7 +109,8 @@ def define_model(self, model=dvr_phys_reg): model_dict["intrinsics"] = self.model.intrinsics # define phys model - self.phys_model = phys_interface(opts, model_dict) + self.phys_model = phys_interface(opts, model_dict, dt=1e-3) + # self.phys_model = phys_interface(opts, model_dict) self.phys_visualizer = Logger(opts) # move model to device @@ -119,13 +122,9 @@ def load_checkpoint_train(self): """Load a checkpoint at training time and update the current step count and round count """ - super().load_checkpoint_train() - # training time if self.opts["load_path_bg"] != "": _ = self.load_checkpoint(self.opts["load_path_bg"], self.model) - - # reset near_far - self.model.fields.reset_geometry_aux() + super().load_checkpoint_train() def get_lr_dict(self): """Return the learning rate for each category of trainable parameters From b3c0e5d09947b1395fc51a61a0b6fb96ab17ce99 Mon Sep 17 00:00:00 2001 From: Gengshan Yang Date: Sat, 19 Aug 2023 00:56:30 -0400 Subject: [PATCH 27/86] partially finished camera mlp; --- lab4d/nnutils/embedding.py | 3 + lab4d/nnutils/nerf.py | 3 +- lab4d/nnutils/pose.py | 163 +++++++++++++++++++++++++++++++++++++ projects/ppr/ppr-diffphys | 2 +- projects/ppr/trainer.py | 13 +-- 5 files changed, 173 insertions(+), 11 deletions(-) diff --git a/lab4d/nnutils/embedding.py b/lab4d/nnutils/embedding.py index e3adcfe..6182619 100644 --- a/lab4d/nnutils/embedding.py +++ b/lab4d/nnutils/embedding.py @@ -171,6 +171,9 @@ def __init__(self, num_freq_t, frame_info, out_channels=128, time_scale=1.0): ) # M, in range [0,N-1], M