Skip to content

Commit

Permalink
Updates env, tests and utils
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Apr 15, 2024
1 parent a862b4b commit bf6ed01
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
39 changes: 25 additions & 14 deletions rl_core/environments/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,22 +107,23 @@ def __init__(self, primitives: list[str] = None, max_number_of_nodes: int = 10,

def _get_obs(self) -> np.ndarray:
""" Returns current environment's observation """
# graph_structure = self._get_graph_structure()
graph_structure = self._get_graph_structure()

node_structure = self._apply_one_hot_encoding(self._nodes_structure, self.number_of_primitives + 1)
edge_structure = self._edges_structure
# node_structure = self._apply_one_hot_encoding(self._nodes_structure, self.number_of_primitives + 1)
# edge_structure = self._edges_structure

# if self._meta_data is not None:
# obs = np.concatenate((graph_structure, self._meta_data))
# else:
# obs = graph_structure
#
if self._meta_data is not None:
obs = np.concatenate((graph_structure, self._meta_data))
else:
obs = graph_structure

obs = {
'meta': self._meta_data,
'nodes': node_structure,
'edges': edge_structure
}

# For sb3 models
# obs = {
# 'meta': self._meta_data,
# 'nodes': node_structure,
# 'edges': edge_structure
# }

return obs

Expand Down Expand Up @@ -247,8 +248,18 @@ def step(self, action: int, mode: str = 'train') -> (np.ndarray, int, bool, bool

assert action in self.action_space

# Checks if action is not valid
if not action in self._get_available_actions().keys():
terminated = False
truncated = False
reward = -10

self.env_step += 1
observation = self._get_obs()
info = self._get_info()

# Checks if action is from special actions (e.g. eof - End of Pipeline)
if action in self._special_action.keys():
elif action in self._special_action.keys():
terminated, truncated, reward = self._apply_eop_action()

observation = self._get_obs()
Expand Down
10 changes: 5 additions & 5 deletions rl_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.model_selection import train_test_split

from meta_automl.utils import project_root
from rl_core.dataloader import DataLoader_TS
from rl_core.dataloader import TimeSeriesDataLoader

OFFLINE_TRAJECTORIES = [
[2, 0],
Expand All @@ -31,8 +31,8 @@ def define_data_for_experiment(test_size: int = 3):
dataset_names = [name for name in os.listdir(data_folder_path)]

temp = pd.read_csv('pipeline_validation_results.csv', index_col=0)
train = temp[temp['Topo Pipeline'].isna() == False]['Dataset'].to_list()
test = temp[temp['Topo Pipeline'].isna() == True]['Dataset'].to_list()
train = temp[temp['Topo Pipeline'].isna() == True]['Dataset'].to_list()
test = temp[temp['Topo Pipeline'].isna() == False]['Dataset'].to_list()

# if test_size:
# train, test = train_test_split(dataset_names, test_size=3)
Expand All @@ -50,7 +50,7 @@ def define_data_for_experiment(test_size: int = 3):
path_to_meta_data = os.path.join(str(project_root()),
'MetaFEDOT\\data\\knowledge_base_time_series_0\\meta_features_ts.csv')

dataloader_train = DataLoader_TS(train_datasets, path_to_meta_data=path_to_meta_data)
dataloader_test = DataLoader_TS(test_datasets, path_to_meta_data=path_to_meta_data)
dataloader_train = TimeSeriesDataLoader(train_datasets, path_to_meta_data=path_to_meta_data)
dataloader_test = TimeSeriesDataLoader(test_datasets, path_to_meta_data=path_to_meta_data)

return dataloader_train, dataloader_test, train, test,
2 changes: 1 addition & 1 deletion tests/unit/rl_test/test_ts_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_correct_pipelines(trajectory):
def test_max_number_of_actions_in_pipelines(max_number_of_nodes):
train_data, test_data = get_time_series()

env = TimeSeriesPipelineEnvironment(max_number_of_nodes=max_number_of_nodes, metadata_dim=0)
env = TimeSeriesPipelineEnvironment(max_number_of_nodes=max_number_of_nodes, metadata_dim=0, using_number_of_nodes=max_number_of_nodes)
env.load_data(train_data, test_data, meta=None)
env.reset()

Expand Down

0 comments on commit bf6ed01

Please sign in to comment.