diff --git a/autokoopman/core/scoring.py b/autokoopman/core/scoring.py index 894a6d1..dace94e 100644 --- a/autokoopman/core/scoring.py +++ b/autokoopman/core/scoring.py @@ -30,7 +30,8 @@ def weighted_score( absdiff = (prediction_data - true_data).abs() end_errors = np.array( - [norm(weights_f[n] * s.states, axis=1) for n, s in absdiff._trajs.items()] + [norm(weights_f[n] * s.states, axis=1) for n, s in absdiff._trajs.items()], + dtype=object, ) return np.sum(np.concatenate(end_errors, axis=0)) @@ -43,7 +44,7 @@ def end_point_score(true_data: TrajectoriesData, prediction_data: TrajectoriesDa @staticmethod def total_score(true_data: TrajectoriesData, prediction_data: TrajectoriesData): errors = (prediction_data - true_data).norm() - end_errors = np.array([s.states.flatten() for s in errors]) + end_errors = np.array([s.states.flatten() for s in errors], dtype=object) return np.mean(np.concatenate(end_errors, axis=0)) @staticmethod