Skip to content

Commit

Permalink
restart lix reset with new reward.
Browse files: browse the repository at this point in the history
  • Loading branch information
Kaiyotech committed Feb 15, 2023
1 parent bf338bb commit ebfd7a1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
8 changes: 5 additions & 3 deletions learner_lix.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@
velocity_pb_w=0.01,
wall_touch_w=0.5,
tick_skip=Constants_lix.FRAME_SKIP,
curve_wave_zap_dash_w=0.35,
walldash_w=0.35,
flip_reset_w=5,
# curve_wave_zap_dash_w=0.35,
# walldash_w=0.35,
flip_reset_w=0,
# dash_limit_per_ep=1,
lix_reset_w=5,
),
lambda: CoyoteAction(),
save_every=logger.config.save_every * 3,
Expand Down
19 changes: 16 additions & 3 deletions rewards.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,12 @@ def __init__(
boost_remain_touch_object_w=0,
end_touched: dict = None,
punish_backboard_pinch_w=0,
dash_limit_per_ep=100000000,
lix_reset_w=0,
):
self.lix_reset_w = lix_reset_w
self.dash_limit_per_ep = dash_limit_per_ep
self.dash_count = [0] * 6
self.punish_backboard_pinch_w = punish_backboard_pinch_w
self.end_touched = end_touched
self.boost_remain_touch_object_w = boost_remain_touch_object_w
Expand Down Expand Up @@ -607,6 +612,7 @@ def reset(self, initial_state: GameState):
self.exit_vel_save = [None] * 6
self.previous_action = np.asarray([-1] * len(initial_state.players))
self.last_action_change = np.asarray([0] * len(initial_state.players))
self.dash_count = [0] * 6

# if self.walldash_w != 0 or self.wave_zap_dash_w != 0 or self.curvedash_w != 0:
if self.curve_wave_zap_dash_w != 0 or self.walldash_w != 0:
Expand Down Expand Up @@ -745,6 +751,8 @@ def _update_addl_timers(self, player: PlayerData, state: GameState, prev_actions
self.fliptimes[cid] = min(
78, self.fliptimes[cid])

ret = 0

if dash_timer > 0:
dash_rew = (79 - dash_timer) / 40

Expand All @@ -765,7 +773,12 @@ def _update_addl_timers(self, player: PlayerData, state: GameState, prev_actions
speed_rew = max(float(np.dot(norm_pos_diff, norm_vel)), 0.025)

if player.car_data.position[2] > 100: # wall curve is 256, but curvedashes end their torque very close to 0
return dash_rew * self.walldash_w * speed_rew
self.dash_count[self.n] += 1
ret += dash_rew * self.walldash_w * speed_rew if self.dash_count[self.n] <= self.dash_limit_per_ep else 0
elif player.car_data.position[2] <= 100:
return dash_rew * self.curve_wave_zap_dash_w * speed_rew
return 0.0
self.dash_count[self.n] += 1
ret += dash_rew * self.curve_wave_zap_dash_w * speed_rew if self.dash_count[self.n] <= self.dash_limit_per_ep else 0

if not player.on_ground and self.airtimes[self.n] == 0 and not self.is_jumpings[self.n]:
ret += self.lix_reset_w
return ret
8 changes: 5 additions & 3 deletions worker_lix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
velocity_pb_w=0.01,
wall_touch_w=0.5,
tick_skip=Constants_lix.FRAME_SKIP,
curve_wave_zap_dash_w=0.35,
walldash_w=0.35,
flip_reset_w=5,
# curve_wave_zap_dash_w=0.35,
# walldash_w=0.35,
flip_reset_w=0,
# dash_limit_per_ep=1,
lix_reset_w=5,
)

fps = 120 // frame_skip
Expand Down

0 comments on commit ebfd7a1

Please sign in to comment.