Skip to content

Commit

Permalink
Avoid crash when no keypoints (#92)
Browse files Browse the repository at this point in the history
Occurs if no keypoints are detected in either of the input images or if
they get all pruned.
  • Loading branch information
sarlinpe authored Jan 24, 2024
1 parent be49528 commit 4959b59
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions lightglue/lightglue.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(self, allow_flash: bool) -> None:
torch.backends.cuda.enable_flash_sdp(allow_flash)

def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if q.shape[-2] == 0 or k.shape[-2] == 0:
return q.new_zeros((*q.shape[:-1], v.shape[-1]))
if self.enable_flash and q.device.type == "cuda":
# use torch 2.0 scaled_dot_product_attention with flash
if self.has_sdp:
Expand Down Expand Up @@ -523,6 +525,8 @@ def _forward(self, data: dict) -> dict:
prune1 = torch.ones_like(ind1)
token0, token1 = None, None
for i in range(self.conf.n_layers):
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
break
desc0, desc1 = self.transformers[i](
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
)
Expand All @@ -531,7 +535,7 @@ def _forward(self, data: dict) -> dict:

if do_early_stop:
token0, token1 = self.token_confidence[i](desc0, desc1)
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
break
if do_point_pruning and desc0.shape[-2] > pruning_th:
scores0 = self.log_assignment[i].get_matchability(desc0)
Expand All @@ -550,7 +554,29 @@ def _forward(self, data: dict) -> dict:
encoding1 = encoding1.index_select(-2, keep1)
prune1[:, ind1] += 1

desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
m0 = desc0.new_full((b, m), -1, dtype=torch.long)
m1 = desc1.new_full((b, n), -1, dtype=torch.long)
mscores0 = desc0.new_zeros((b, m))
mscores1 = desc1.new_zeros((b, n))
matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
mscores = desc0.new_empty((b, 0))
if not do_point_pruning:
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
return {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
"matching_scores1": mscores1,
"stop": i + 1,
"matches": matches,
"scores": mscores,
"prune0": prune0,
"prune1": prune1,
}

desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
scores, _ = self.log_assignment[i](desc0, desc1)
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
matches, mscores = [], []
Expand Down Expand Up @@ -579,7 +605,7 @@ def _forward(self, data: dict) -> dict:
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
prune1 = torch.ones_like(mscores1) * self.conf.n_layers

pred = {
return {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
Expand All @@ -591,8 +617,6 @@ def _forward(self, data: dict) -> dict:
"prune1": prune1,
}

return pred

def confidence_threshold(self, layer_index: int) -> float:
"""scaled confidence threshold"""
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
Expand Down

0 comments on commit 4959b59

Please sign in to comment.