From 6679ada88803faa0c198cb88918432a4fb032fe7 Mon Sep 17 00:00:00 2001 From: Marc Lanctot Date: Mon, 8 Jan 2024 17:33:26 +0000 Subject: [PATCH] Simplify and make Kemeny voting implementation faster. PiperOrigin-RevId: 596618517 Change-Id: I9156c7a7cda417337b616d9ec9d84ca4e597af12 --- open_spiel/python/voting/kemeny_young.py | 41 +++++++++++++----------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/open_spiel/python/voting/kemeny_young.py b/open_spiel/python/voting/kemeny_young.py index 10f9419778..add159dad2 100644 --- a/open_spiel/python/voting/kemeny_young.py +++ b/open_spiel/python/voting/kemeny_young.py @@ -33,38 +33,43 @@ def name(self) -> str: def _score( self, - alternatives: List[base.AlternativeId], pref_mat: np.ndarray, perm: Tuple[int, ...], - ) -> Tuple[List[base.AlternativeId], int, np.ndarray]: + ) -> np.ndarray: # The score of alternative a_i in a ranking R is defined to be: # KemenyScore(a_i) = sum_{a_j s.t. R(a_i) >= R(a_j)} N(a_i, a_j) # The score of ranking R is then sum_i KemenyScore(a_i). num_alts = len(perm) scores = np.zeros(num_alts, dtype=np.int32) - ranking = [] for i in range(num_alts): - alt_idx_i = perm[i] for j in range(i+1, num_alts): - alt_idx_j = perm[j] - value = pref_mat[alt_idx_i, alt_idx_j] - scores[i] += value - ranking.append(alternatives[alt_idx_i]) - return (ranking, scores.sum(), scores) + scores[i] += pref_mat[perm[i], perm[j]] + return scores + + def _permutation_to_ranking( + self, + alternatives: List[base.AlternativeId], + permutation: Tuple[base.AlternativeId, ...]) -> List[base.AlternativeId]: + assert len(permutation) == len(alternatives) + return [alternatives[permutation[i]] for i in range(len(alternatives))] def run_election(self, profile: base.PreferenceProfile) -> base.RankOutcome: assert self.is_valid_profile(profile) pref_mat = profile.pref_matrix() alternatives = profile.alternatives m = profile.num_alternatives() - # ranking info is tuples of (ranking, total_score, scores list) - best_ranking_info = (None, 0, []) - for perm in itertools.permutations(range(m)): - # perm is a permutation of alternative indices - ranking_info = self._score(alternatives, pref_mat, perm) - if ranking_info[1] > best_ranking_info[1]: - best_ranking_info = ranking_info - outcome = base.RankOutcome(rankings=best_ranking_info[0], - scores=list(best_ranking_info[2])) + best_permutation = None + best_score = -1 + best_score_array = None + for permutation in itertools.permutations(range(m)): + scores = self._score(pref_mat, permutation) + total_score = scores.sum() + if total_score > best_score: + best_score = total_score + best_score_array = scores + best_permutation = permutation + best_ranking = self._permutation_to_ranking(alternatives, best_permutation) + outcome = base.RankOutcome(rankings=best_ranking, + scores=list(best_score_array)) return outcome