Skip to content

Commit

Permalink
sacrifice speed for mem
Browse files Browse the repository at this point in the history
  • Loading branch information
jytime committed Jun 26, 2024
1 parent 89052b2 commit 8c47df8
Showing 1 changed file with 57 additions and 18 deletions.
75 changes: 57 additions & 18 deletions vggsfm/utils/triangulation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,32 @@ def triangulate_multi_view_point_batched(
A = torch.einsum("bnij,bnik->bjk", terms, terms)

# Compute eigenvalues and eigenvectors
try:
# try:
# _, eigenvectors = torch.linalg.eigh(A)
# except:
# print("Meet CUSOLVER_STATUS_INVALID_VALUE ERROR during torch.linalg.eigh()")
# print("SWITCH TO torch.linalg.eig()")
# _, eigenvectors = torch.linalg.eig(A)
# eigenvectors = torch.real(eigenvectors)


# Compute eigenvalues and eigenvectors
num_A_batch = len(A)
MAX_CUSOLVER_STATUS_INVALID_VALUE = 1024000
if num_A_batch>MAX_CUSOLVER_STATUS_INVALID_VALUE:
print("A too big matrix for torch.linalg.eigh(); It will meet CUSOLVER_STATUS_INVALID_VALUE ERROR; Make it happy by spliting the matrix to several ones")
num_runs = math.ceil(num_A_batch/MAX_CUSOLVER_STATUS_INVALID_VALUE)
eigenvectors_list = []
for run_idx in range(num_runs):
start_idx = run_idx * MAX_CUSOLVER_STATUS_INVALID_VALUE
end_idx = (run_idx+1) * MAX_CUSOLVER_STATUS_INVALID_VALUE
_, eigenvectors = torch.linalg.eigh(A[start_idx:end_idx])
eigenvectors_list.append(eigenvectors)
eigenvectors = torch.cat(eigenvectors_list)
else:
_, eigenvectors = torch.linalg.eigh(A)
except:
print("Meet CUSOLVER_STATUS_INVALID_VALUE ERROR during torch.linalg.eigh()")
print("SWITCH TO torch.linalg.eig()")
_, eigenvectors = torch.linalg.eig(A)
eigenvectors = torch.real(eigenvectors)



# Select the first eigenvector
first_eigenvector = eigenvectors[:, :, 0]
Expand Down Expand Up @@ -384,15 +403,15 @@ def generate_combinations(N):
return comb_array


def local_refinement_tri(points1, extrinsics, inlier_mask, sorted_indices, lo_num=50):

def local_refinement_tri(points1, extrinsics, inlier_mask, sorted_indices, lo_num=50, low_mem=True):
"""
Local Refinement for triangulation
"""
B, N, _ = points1.shape
batch_index = torch.arange(B).unsqueeze(-1).expand(-1, lo_num)

points1_expand = points1.unsqueeze(1).expand(-1, lo_num, -1, -1)
extrinsics_expand = extrinsics.unsqueeze(1).expand(-1, lo_num, -1, -1, -1)

# The sets selected for local refinement
lo_indices = sorted_indices[:, :lo_num]
Expand All @@ -402,18 +421,38 @@ def local_refinement_tri(points1, extrinsics, inlier_mask, sorted_indices, lo_nu
lo_points1 = torch.zeros_like(points1_expand)
lo_points1[lo_mask] = points1_expand[lo_mask]

lo_points1 = lo_points1.reshape(B * lo_num, N, -1)
lo_mask = lo_mask.reshape(B * lo_num, N)
lo_extrinsics = extrinsics_expand.reshape(B * lo_num, N, 3, 4)

# triangulate the inliers
triangulated_points, tri_angles, invalid_che_mask = triangulate_multi_view_point_batched(
lo_extrinsics, lo_points1, mask=lo_mask, compute_tri_angle=True, check_cheirality=True
)

triangulated_points = triangulated_points.reshape(B, lo_num, 3)
tri_angles = tri_angles.reshape(B, lo_num, -1)
if low_mem:
all_triangulated_points = []
all_tri_angles = []
all_invalid_che_mask = []

invalid_che_mask = invalid_che_mask.reshape(B, lo_num)
for loidx in range(lo_num):
triangulated_points, tri_angles, invalid_che_mask = triangulate_multi_view_point_batched(
extrinsics, lo_points1[:, loidx], mask=lo_mask[:, loidx], compute_tri_angle=True, check_cheirality=True
)
# Append the outputs to the respective lists
all_triangulated_points.append(triangulated_points[:, None])
all_tri_angles.append(tri_angles[:, None])
all_invalid_che_mask.append(invalid_che_mask[:,None])

triangulated_points = torch.cat(all_triangulated_points, dim=1)
tri_angles = torch.cat(all_tri_angles, dim=1)
invalid_che_mask = torch.cat(all_invalid_che_mask, dim=1)
else:
extrinsics_expand = extrinsics.unsqueeze(1).expand(-1, lo_num, -1, -1, -1)
lo_points1 = lo_points1.reshape(B * lo_num, N, -1)
lo_mask = lo_mask.reshape(B * lo_num, N)
lo_extrinsics = extrinsics_expand.reshape(B * lo_num, N, 3, 4)

# triangulate the inliers
triangulated_points, tri_angles, invalid_che_mask = triangulate_multi_view_point_batched(
lo_extrinsics, lo_points1, mask=lo_mask, compute_tri_angle=True, check_cheirality=True
)

triangulated_points = triangulated_points.reshape(B, lo_num, 3)
tri_angles = tri_angles.reshape(B, lo_num, -1)
invalid_che_mask = invalid_che_mask.reshape(B, lo_num)

return triangulated_points, tri_angles, invalid_che_mask

0 comments on commit 8c47df8

Please sign in to comment.