diff --git a/compiler_opt/tools/combine_tfa_policies_lib_test.py b/compiler_opt/tools/combine_tfa_policies_lib_test.py index 030d213f..9fb8bb4b 100644 --- a/compiler_opt/tools/combine_tfa_policies_lib_test.py +++ b/compiler_opt/tools/combine_tfa_policies_lib_test.py @@ -35,8 +35,7 @@ def __init__(self): act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64) - super().__init__( - time_step_spec=time_step_spec, action_spec=act_spec) + super().__init__(time_step_spec=time_step_spec, action_spec=act_spec) def _distribution(self, time_step): pass @@ -59,8 +58,7 @@ def __init__(self): act_spec = tensor_spec.TensorSpec(shape=(1,), dtype=tf.int64) - super().__init__( - time_step_spec=time_step_spec, action_spec=act_spec) + super().__init__(time_step_spec=time_step_spec, action_spec=act_spec) def _distribution(self, time_step): pass