Skip to content

Commit

Permalink
Added Nguyen benchmark and a method to list datasets in a benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
smeznar committed Dec 18, 2024
1 parent 48fbd04 commit 4530066
Showing 1 changed file with 24 additions and 30 deletions.
54 changes: 24 additions & 30 deletions SRToolkit/dataset/srbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def create_dataset(self, dataset_name: str):
if dataset_name in self.datasets:
# Check if dataset exists otherwise download it from an url
if os.path.exists(self.datasets[dataset_name]["path"]):
data = np.load(self.datasets[dataset_name]["path"] + ".npy")
data = np.load(self.datasets[dataset_name]["path"])
else:
raise ValueError(f"Could not find dataset {dataset_name} at {self.datasets[dataset_name]['path']}")

Expand All @@ -64,6 +64,23 @@ def create_dataset(self, dataset_name: str):
else:
raise ValueError(f"Dataset {dataset_name} not found")

def list_datasets(self, verbose=True):
datasets = [dataset_name for dataset_name in self.datasets]
sorted(datasets, key= lambda dataset_name: self.datasets[dataset_name]["num_variables"], reverse=True)

if verbose:
for d in datasets:
if self.datasets[d]["num_variables"] == 1:
variable_str = "1 variable"
elif self.datasets[d]["num_variables"] < 1:
variable_str = "Amount of variables unknown"
else:
variable_str = f"{self.datasets[d]['num_variables']} variables"

print(f"{d}:\t{variable_str}, \tExpression: {self.datasets[d]['original_equation']}")
return datasets


@staticmethod
def download_benchmark_data(url, directory_path):
# Check if directory_path exist
Expand Down Expand Up @@ -96,10 +113,10 @@ def nguyen(dataset_directory: str):
sl_1v.add_symbol("sin", symbol_type="fn", precedence=5, np_fn="{} = np.sin({})")
sl_1v.add_symbol("cos", symbol_type="fn", precedence=5, np_fn="{} = np.cos({})")
sl_1v.add_symbol("exp", symbol_type="fn", precedence=5, np_fn="{} = np.exp({})")
sl_1v.add_symbol("ln", symbol_type="fn", precedence=5, np_fn="{} = np.ln({})")
sl_1v.add_symbol("log", symbol_type="fn", precedence=5, np_fn="{} = np.log10({})")
sl_1v.add_symbol("sqrt", symbol_type="fn", precedence=5, np_fn="{} = np.sqrt({})")
sl_1v.add_symbol("^2", symbol_type="fn", precedence=5, np_fn="{} = np.pow({}, 2)")
sl_1v.add_symbol("^3", symbol_type="fn", precedence=5, np_fn="{} = np.pow({}, 3)")
sl_1v.add_symbol("^2", symbol_type="fn", precedence=5, np_fn="{} = np.power({}, 2)")
sl_1v.add_symbol("^3", symbol_type="fn", precedence=5, np_fn="{} = np.power({}, 3)")
sl_1v.add_symbol("X_0", "var", 5, "X[:, 0]")

sl_2v = copy.copy(sl_1v)
Expand Down Expand Up @@ -152,29 +169,6 @@ def nguyen(dataset_directory: str):


if __name__ == '__main__':
# benchmark = SRBenchmark.nguyen("../../data/nguyen")
# a = 0
from SRToolkit.utils.expression_compiler import expr_to_executable_function

equations = [["X_0", "+", "X_0", "^2", "+", "X_0", "^3"],
["X_0", "+", "X_0", "^2", "+", "X_0", "^3", "+", "X_0", "*", "X_0", "^3"],
["X_0", "+", "X_0", "^2", "+", "X_0", "^3", "+", "X_0", "*", "X_0", "^3", "+", "X_0", "^2", "*", "X_0", "^3"],
["X_0", "+", "X_0", "^2", "+", "X_0", "^3", "+", "X_0", "*", "X_0", "^3", "+", "X_0", "^2", "*", "X_0", "^3", "+", "X_0", "^3", "*", "X_0", "^3"],
["sin", "(", "X_0", "^2", ")", "*", "cos", "(", "X_0", ")", "-", "1"],
["sin", "(", "X_0", ")", "+", "sin", "(", "X_0", "+", "X_0", "^2", ")"],
["log", "(", "1", "+", "X_0", ")", "+", "log", "(", "1", "+", "X_0", "^2", ")"],
["sqrt", "(", "X_0", ")"],
["sin", "(", "X_0", ")", "+", "sin", "(", "X_1", "^2", ")"],
["2", "*", "sin", "(", "X_0", ")", "*", "cos", "(", "X_1", ")"]]

bounds = [(-20, 20), (-20, 20), (-20, 20), (-20, 20), (-20, 20), (-20, 20), (1, 100), (0, 100), (-20, 20), (-20, 20)]

for i, eq in enumerate(equations):
exec_fun = expr_to_executable_function(eq)
if i < 8:
x = np.random.random((10000, 1)) * (bounds[i][1] - bounds[i][0]) + bounds[i][0]
else:
x = np.random.random((10000, 2)) * (bounds[i][1] - bounds[i][0]) + bounds[i][0]
y = exec_fun(x, None)

np.save(f"../../data/Nguyen/NG-{i+1}.npy", np.concatenate([x, y[:, np.newaxis]], axis=1))
benchmark = SRBenchmark.nguyen("../../data/nguyen")
benchmark.list_datasets()
a = 0

0 comments on commit 4530066

Please sign in to comment.