Skip to content

Commit

Permalink
Fixes memory errors
Browse files — browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Jan 25, 2024
1 parent e8f915c commit 77be6b2
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 27 deletions.
22 changes: 13 additions & 9 deletions rl_core/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
if __name__ == '__main__':
task_type = 'classification'
pipeline_len = 5
n_episodes = 15000
n_episodes = 1000

path_to_agent = os.path.join(str(project_root()), 'MetaFEDOT/rl_core/agent/pretrained/ensemble_5_a2c_128_1000')

data_folder_path = os.path.join(str(project_root()), 'MetaFEDOT/rl_core/data/')

Expand All @@ -19,20 +21,22 @@
'kc1': os.path.join(data_folder_path, 'kc1_train.csv'),
}

primitives = ['scaling', 'simple_imputation', 'normalization', 'dt', 'logit', 'rf']
primitives = ['scaling', 'simple_imputation', 'normalization', 'dt', 'logit', 'rf', 'knn']

gen = Generator(task_type, state_dim=pipeline_len, n_episodes=n_episodes) \
.set_environment(env_name='ensemble', primitives=primitives) \
.set_dataloader(train_datasets) \
.set_agent(
eval_schedule=15,
critic_updates_per_actor=10,
eval_schedule=25,
critic_updates_per_actor=25,
) \
.set_writer()

gen.fit()

gen.save_agent()
if path_to_agent:
gen.load_agent(path=path_to_agent)
else:
gen.fit()
gen.save_agent()

test_datasets = {
'amazon': os.path.join(data_folder_path, 'amazon_test.csv'),
Expand All @@ -43,14 +47,14 @@
'kc1': os.path.join(data_folder_path, 'kc1_test.csv'),
}

for name, dataset in test_datasets:
for name, dataset in test_datasets.items():
valid, not_valid = 0, 0

for _ in range(25):
pipeline, metric_value = gen.generate(path_to_dataset=dataset)

if pipeline:
pipeline.show()
pipeline.build().show()
print('Test metric:', metric_value)
valid += 1
else:
Expand Down
2 changes: 1 addition & 1 deletion rl_core/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _inference_step(self, action):
raise NotImplementedError()

@staticmethod
def _pipeline_constuction_validate(pipeline):
def _pipeline_construction_validate(pipeline):
try:
if pipeline.build():
return True
Expand Down
28 changes: 17 additions & 11 deletions rl_core/environments/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fedot.core.pipelines.pipeline_builder import PipelineBuilder
from gym import spaces

from rl_core.agent.agent import to_tensor
from rl_core.environments.base import PipelineGenerationEnvironment


Expand All @@ -26,8 +27,10 @@ def __init__(self, state_dim: int, primitives: list):
self.meta_model = None
self.branch_idx = 0

self.encoded_pipeline = []

def init_state(self):
self.state = torch.tensor(np.zeros(self.state_dim))
self.state = to_tensor(np.zeros(self.state_dim))
return self

def update_state(self, action):
Expand All @@ -38,6 +41,7 @@ def update_state(self, action):

def reset(self, **kwargs):
self.pipeline = PipelineBuilder()
self.encoded_pipeline = []
self.is_valid = False
self.time_step = 0
self.metric_value = 0
Expand All @@ -50,17 +54,17 @@ def reset(self, **kwargs):

return self.state

def _train_step(self, action):
def _train_step(self, action, return_pipeline=False):
terminated, truncated = False, False
self.last_action = action

if self.primitives[action] == 'eop' or self.position == self.state_dim:
self.time_step += 1
done = True
terminated = True

self.pipeline.join_branches(self.meta_model)

if self._pipeline_constuction_validate(self.pipeline):
if self._pipeline_construction_validate(self.pipeline):
reward = self.pipeline_fitting_and_evaluating()
else:
reward = -0.999
Expand All @@ -73,26 +77,28 @@ def _train_step(self, action):
self.time_step += 1
reward = -0.001

self.update_state(action)

primitive = self.primitives[action]

self.encoded_pipeline.append(primitive)

if self.position == 0:
self.meta_model = action
self.meta_model = primitive
else:
self.pipeline.add_branch(primitive, branch_idx=self.branch_idx)
self.branch_idx += 1

reward, info = self._environment_response(reward)
self.update_state(action)

reward, info = self._environment_response(reward, return_pipeline)

return deepcopy(self.state), reward, terminated, truncated, info

def _inference_step(self, action):
raise NotImplementedError()
return self._train_step(action, return_pipeline=True)

def _environment_response(self, reward: float) -> (int, bool, dict):
def _environment_response(self, reward: float, return_pipeline: bool) -> (int, bool, dict):
info = {
'pipeline': self.pipeline.build(),
'pipeline': self.pipeline if return_pipeline else None,
'time_step': self.time_step,
'metric_value': self.metric_value,
'is_valid': self.is_valid
Expand Down
10 changes: 6 additions & 4 deletions rl_core/environments/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@ def reset(self, **kwargs):
return deepcopy(self.state)

def _train_step(self, action):
return_pipeline = False
terminated, truncated = False, False
self.last_action = action

if self.primitives[action] == 'eop' or self.position == self.state_dim:
self.time_step += 1
terminated = True

if self._pipeline_constuction_validate(self.pipeline):
if self._pipeline_construction_validate(self.pipeline):
reward = self.pipeline_fitting_and_evaluating()
return_pipeline = True
else:
reward = -0.999

Expand All @@ -72,16 +74,16 @@ def _train_step(self, action):
primitive = self.primitives[action]
self.pipeline.add_node(primitive)

reward, info = self._environment_response(reward)
reward, info = self._environment_response(reward, return_pipeline)

return deepcopy(self.state), reward, terminated, truncated, info

def _inference_step(self, action):
raise NotImplementedError()

def _environment_response(self, reward: float) -> (int, bool, dict):
def _environment_response(self, reward: float, return_pipeline: bool) -> (int, bool, dict):
info = {
'pipeline': self.pipeline.build(),
'pipeline': self.pipeline if return_pipeline else None,
'time_step': self.time_step,
'metric_value': self.metric_value,
'is_valid': self.is_valid
Expand Down
6 changes: 4 additions & 2 deletions rl_core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rl_core.agent.agent import ActorCriticAgent
from rl_core.dataloader import DataLoader
from rl_core.environments.embedding import EmbeddingPipelineGenerationEnvironment
from rl_core.environments.ensemble import EnsemblePipelineGenerationEnvironment
from rl_core.environments.linear import LinearPipelineGenerationEnvironment

SUCCESS_RET = 0.75
Expand All @@ -18,6 +19,7 @@
class Generator:
__env_dict = {
'linear': LinearPipelineGenerationEnvironment,
'ensemble': EnsemblePipelineGenerationEnvironment,
'embedding': EmbeddingPipelineGenerationEnvironment,
}

Expand Down Expand Up @@ -76,7 +78,7 @@ def set_environment(self, env_name: str, primitives: list[str] = None):
if not primitives:
primitives = OperationTypesRepository('all').suitable_operation(task_type=self.task_type)

for d_primitves in ['lgbm', 'knn']:
for d_primitves in ['lgbm']:
primitives.remove(d_primitves)

self.env = env(state_dim=self.state_dim, primitives=primitives)
Expand Down Expand Up @@ -185,7 +187,7 @@ def generate(self, path_to_dataset):

while not done:
a = self.agent.act(s)
s_next, r, terminated, truncated, info = self.env.step(a)
s_next, r, terminated, truncated, info = self.env.step(a, mode='inference')
done = terminated or truncated

s = s_next
Expand Down

0 comments on commit 77be6b2

Please sign in to comment.