Skip to content

Commit

Permalink
Merge pull request #52 from oksanadanilova/feature/score_out_of_contr…
Browse files Browse the repository at this point in the history
…ol_bugfix

right percentage distance + tests
  • Loading branch information
Evgeny-Egorov-Projects authored Jun 8, 2020
2 parents e3cab00 + 3d03f63 commit 15e6e30
Show file tree
Hide file tree
Showing 2 changed files with 210 additions and 50 deletions.
188 changes: 138 additions & 50 deletions topicnet/cooking_machine/cubes/controller_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
We assume that if that metric is 'sort of decreasing', then everything is OK
and we are allowed to change tau coefficient further; otherwise we revert back
to the last "safe" value and stop
'sort of decreasing' performs best with `PerplexityScore`, and all scores which
behave like perplexity (nonnegative, and which should decrease when a model gets better).
If you want to track a different kind of score, it is recommended to use `score_controller` parameter
More formal definition of "sort of decreasing": if we divide a curve into two parts like so:
Expand Down Expand Up @@ -46,12 +50,16 @@
| right part |
then the right part is no higher than 5% of global minimum
(you can change 5% if you like by adjusting `fraction_threshold`
in `is_score_out_of_control` function)
(you can change 5% if you like by adjusting `fraction_threshold` parameter)
If score_to_track is None, then `ControllerAgent` will never stop
If `score_to_track` is None and `score_controller` is None, then `ControllerAgent` will never stop
(useful for e.g. decaying coefficients)
fraction_threshold: float
Threshold to control a score by 'sort of decreasing' metric
score_controller: BaseScoreController
Custom score controller
If 'sort of decreasing' is not suitable for controlling your score, you can create a custom score controller
that inherits from `BaseScoreController`.
tau_converter: str or callable
Notably, def-style functions and lambda functions are allowed
If it is function, then it should accept four arguments:
Expand Down Expand Up @@ -109,58 +117,98 @@
that way agent will continue operating even outside this `RegularizationControllerCube`
""" # noqa: W291

from .base_cube import BaseCube
from ..rel_toolbox_lite import count_vocab_size, handle_regularizer

import numexpr as ne
import warnings
from dill.source import getsource
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional

import numexpr as ne
import numpy as np
from dill.source import getsource

from .base_cube import BaseCube
from ..rel_toolbox_lite import count_vocab_size, handle_regularizer

# Warning template emitted when an agent stops adjusting tau;
# the `{}` placeholder is filled through `str.format` by the caller.
W_HALT_CONTROL = "Process of dynamically changing tau was stopped at {} iteration"
# Warning text for exceeding the iteration limit
# NOTE(review): its point of use is outside the visible part of the file
W_MAX_ITERS = "Maximum number of iterations is exceeded; turning off"


def is_score_out_of_control(model, score_name, fraction_threshold=0.05):
"""
Returns True if score isn't 'sort of decreasing' anymore.
@dataclass
class OutOfControlAnswer:
    """
    Verdict produced by a score controller's `is_out_of_control` check.
    """
    # True when the score is judged to be out of control
    answer: bool
    # Optional human-readable explanation accompanying the verdict
    error_message: Optional[str] = None

See docstring for RegularizationControllerCube for details

Parameters
----------
model : TopicModel
score_name : str or None
fraction_threshold : float
class BaseScoreController:
    """
    Base class for callables deciding whether a tracked score is out of control.

    Subclasses implement `is_out_of_control`; calling an instance with a model
    fetches the score values from `model.scores` and returns the boolean verdict.
    """

    def __init__(self, score_name):
        """
        Parameters
        ----------
        score_name : str or None
            key used to look the score up in `model.scores`
        """
        self.score_name = score_name

    def get_score_values(self, model):
        """
        Fetches values of the tracked score from the model.

        Parameters
        ----------
        model
            object exposing a `scores` mapping

        Returns
        -------
        score values, or None if the score is absent from `model.scores`
        or has no values yet
        """
        if self.score_name not in model.scores:  # case of None is handled here as well
            return None

        vals = model.scores[self.score_name]
        if len(vals) == 0:
            return None

        return vals

    def __call__(self, model):
        """
        Returns True if the tracked score is out of control.

        Parameters
        ----------
        model
            object exposing a `scores` mapping

        Returns
        -------
        bool
            False when there are no values to check;
            otherwise the `answer` given by `is_out_of_control`

        Raises
        ------
        ValueError
            if `is_out_of_control` raised; the original exception is chained as the cause
        """
        values = self.get_score_values(model)

        if values is None:
            return False

        try:
            out_of_control_result = self.is_out_of_control(values)
        except Exception as ex:
            message = (f"An error occured while controlling {self.score_name}. Message: {ex}. Score values: {values}")
            # explicit chaining keeps the original traceback attached as the direct cause
            raise ValueError(message) from ex

        if out_of_control_result.error_message is not None:
            warnings.warn(out_of_control_result.error_message)

        return out_of_control_result.answer

    def is_out_of_control(self, values: List[float]) -> "OutOfControlAnswer":
        """
        Decides whether the given score values are out of control.

        Must be overridden; the result needs `answer` and `error_message` attributes.
        """
        raise NotImplementedError


class PerplexityScoreController(BaseScoreController):
    """
    Controller suited to the Perplexity score and to scores behaving like it
    (positive, expected to decrease). For other scores, check applicability yourself.

    The score is out of control when the maximum reached at or after the global
    minimum exceeds that minimum by more than `fraction_threshold` (relative to the minimum).
    """
    DEFAULT_FRACTION_THRESHOLD = 0.05

    def __init__(self, score_name, fraction_threshold=DEFAULT_FRACTION_THRESHOLD):
        """
        Parameters
        ----------
        score_name : str
            name of the score to control
        fraction_threshold : float
            allowed relative rise above the global minimum
        """
        super().__init__(score_name)
        self.fraction_threshold = fraction_threshold

    def is_out_of_control(self, values: List[float]) -> OutOfControlAnswer:
        """
        Checks whether the score rose too much after reaching its global minimum.

        Parameters
        ----------
        values : list of float
            non-empty history of score values

        Returns
        -------
        OutOfControlAnswer
            `answer` is True when the rise exceeds the threshold;
            `error_message` is set only in that case

        Raises
        ------
        ValueError
            if the global minimum is not positive
        """
        idxmin = np.argmin(values)
        minval = values[idxmin]

        # validated before any early exit, so nonpositive scores are always rejected
        if minval <= 0:
            err_message = (
                f"Score {self.score_name} has min_value = {minval} which is <= 0.\n"
                "This control scheme is using to control scores acting like Perplexity.\n"
                "Ensure you control the Perplexity score or write your own controller"
            )
            raise ValueError(err_message)

        # argmin is at most len(values) - 1: the minimum is the latest value,
        # so nothing to the right of it could have risen
        if idxmin == len(values) - 1:
            return OutOfControlAnswer(answer=False)

        right_maxval = max(values[idxmin:])
        answer = (right_maxval - minval) / minval > self.fraction_threshold

        if answer:
            message = (f"Score {self.score_name} is too high! Right max value: {right_maxval}, min value: {minval}")
            return OutOfControlAnswer(answer=answer, error_message=message)

        return OutOfControlAnswer(answer=answer)


class ControllerAgentException(Exception):
    """Error raised by `ControllerAgent` when it is given inconsistent parameters."""


class ControllerAgent:
Expand All @@ -172,8 +220,10 @@ class ControllerAgent:
Each agent is described by:
* reg_name: the name of regularizer having `tau` which needs to be changed
* score_to_track: score providing control of the callback execution
* tau_converter: function or string describing how to get new `tau` from old `tau`
* score_to_track: score name providing control of the callback execution
* fraction_threshold: threshold to control score_to_track
* score_controller: custom score controller providing control of the callback execution
* local_dict: dictionary containing values of several variables,
most notably, `user_value`
* is_working:
Expand All @@ -183,31 +233,64 @@ class ControllerAgent:
See top-level docstring for details.
"""
def __init__(self, reg_name, score_to_track, tau_converter, max_iters, local_dict=None):

def __init__(self, reg_name, tau_converter, max_iters, score_to_track=None, fraction_threshold=None,
score_controller=None, local_dict=None):
"""
Parameters
----------
reg_name : str
score_to_track : str, list of str or None
tau_converter : callable or str
local_dict : dict
max_iters : int or float
Agent will stop changing tau after `max_iters` iterations
`max_iters` could be `float("NaN")` and `float("inf")` values:
that way agent will continue operating even outside this `RegularizationControllerCube`
score_to_track : str, list of str or None
Name of score to track.
Please use this parameter only to track scores of type PerplexityScore.
In other cases we recommend that you write your own ScoreController
fraction_threshold : float, list of float of the same length as score_to_track or None
Used to define the threshold for controlling PerplexityScore
Default value is 0.05
score_controller : BaseScoreController, list of BaseScoreController or None
local_dict : dict
"""
if local_dict is None:
local_dict = dict()

self.reg_name = reg_name
self.tau_converter = tau_converter

self.score_controllers = []
if isinstance(score_to_track, list):
self.score_to_track = score_to_track
if fraction_threshold is None:
controller_params = [(name, PerplexityScoreController.DEFAULT_FRACTION_THRESHOLD) for name in
score_to_track]
elif isinstance(fraction_threshold, list) and len(score_to_track) == len(fraction_threshold):
controller_params = list(zip(score_to_track, fraction_threshold))
else:
err_message = """Length of score_to_track and fraction_threshold must be same.
Otherwise fraction_threshold must be None"""
raise ControllerAgentException(err_message)

self.score_controllers.append(
[PerplexityScoreController(name, threshold) for (name, threshold) in controller_params])

elif isinstance(score_to_track, str):
self.score_to_track = [score_to_track]
else:
self.score_to_track = []
self.score_controllers.append([PerplexityScoreController(
score_to_track,
fraction_threshold or PerplexityScoreController.DEFAULT_FRACTION_THRESHOLD
)])

if isinstance(score_controller, BaseScoreController):
self.score_controllers.append(score_controller)
elif isinstance(score_controller, list):
if not all(isinstance(score, BaseScoreController) for score in score_controller):
err_message = """score_controller must be of type BaseScoreController or list of BaseScoreController"""
raise ControllerAgentException(err_message)

self.score_controllers.extend(score_controller)

self.is_working = True
self.local_dict = local_dict
Expand Down Expand Up @@ -258,7 +341,7 @@ def invoke(self, model, cur_iter):

if self.is_working:
should_stop = any(
is_score_out_of_control(model, score) for score in self.score_to_track
score_controller(model) for score_controller in self.score_controllers
)
if should_stop:
warnings.warn(W_HALT_CONTROL.format(len(self.tau_history)))
Expand All @@ -283,26 +366,31 @@ def __init__(self, num_iter: int, parameters,
regularizers params
each dict should contain the following fields:
("reg_name" or "regularizer"),
"score_to_track" (optional),
"tau_converter",
"score_to_track" (optional),
"fraction_threshold" (optional),
"score_controller" (optional),
"user_value_grid"
See top-level docstring for details.
Examples:
>> {"regularizer": artm.regularizers.<...>,
>> "score_to_track": "PerplexityScore@all",
>> "tau_converter": "prev_tau * user_value",
>> "score_to_track": "PerplexityScore@all",
>> "fraction_threshold": 0.1,
>> "user_value_grid": [0.5, 1, 2]}
-----------
>> {"reg_name": "decorrelator_for_ngramms",
>> "score_to_track": None,
>> "tau_converter": (
>> lambda initial_tau, prev_tau, cur_iter, user_value:
>> initial_tau * (cur_iter % 2) + user_value
>> )
>> "score_to_track": None,
>> "fraction_threshold": None,
>> "score_controller": [PerplexityScoreController("PerplexityScore@all", 0.1)],
>> "user_value_grid": [0, 1]}
reg_search : str
Expand Down
72 changes: 72 additions & 0 deletions topicnet/tests/test_cube_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest

from topicnet.cooking_machine.cubes.controller_cube import PerplexityScoreController, ControllerAgent

# Cases for `test_perplexity_controller`: (score values, fraction threshold, expected answer).
# In every series the global minimum is the 4th value (about 105.28).
DATA_REG_CONTROLLER_SORT_OF_DECREASING = [
# after the minimum the score rises to about 132.88 (roughly 26% above it): over the 10% threshold
([246.77072143554688,
124.72193908691406,
107.95775604248047,
105.27597045898438,
112.46900939941406,
132.88259887695312], 0.1, True),
# after the minimum the score rises to about 112.47 (roughly 6.8% above it): within the 10% threshold
([246.77072143554688,
124.72193908691406,
107.95775604248047,
105.27597045898438,
112.46900939941406], 0.1, False),
# same series as above, but the roughly 6.8% rise is over the stricter 5% threshold
([246.77072143554688,
124.72193908691406,
107.95775604248047,
105.27597045898438,
112.46900939941406], 0.05, True),

]
# Cases for `test_controllers_length`:
# (keyword arguments for `ControllerAgent`, expected length of `agent.score_controllers`).
DATA_AGENT_CONTROLLER_LEN_CHECK = [
# score name given as a string
({
"reg_name": "decorrelation",
"score_to_track": "PerplexityScore@all",
"tau_converter": "prev_tau * user_value",
"max_iters": float("inf")
}, 1),
# score name given as a one-element list
({
"reg_name": "decorrelation",
"score_to_track": ["PerplexityScore@all"],
"tau_converter": "prev_tau + user_value",
"max_iters": float("inf")
}, 1),
# nothing to track and no custom controller
({
"reg_name": "decorrelation",
"score_to_track": None,  # never stop working
"tau_converter": "prev_tau * user_value",
"max_iters": float("inf")
}, 0),
# only a custom controller
({
"reg_name": "decorrelation",
"score_to_track": None,  # no score tracked by name; only the explicit controller below
"score_controller": PerplexityScoreController("PerplexityScore@all", 0.1),
"tau_converter": "prev_tau * user_value",
"max_iters": float("inf")
}, 1),
# score tracked by name plus a custom controller
({
"reg_name": "decorrelation",
"score_to_track": "PerplexityScore@all",  # tracked by name, in addition to the explicit controller below
"score_controller": PerplexityScoreController("PerplexityScore@all", 0.1),
"tau_converter": "prev_tau * user_value",
"max_iters": float("inf")
}, 2)
]


@pytest.mark.parametrize('values, fraction, answer_true', DATA_REG_CONTROLLER_SORT_OF_DECREASING)
def test_perplexity_controller(values, fraction, answer_true):
    """Verdict of `PerplexityScoreController` on a score series matches the expected answer."""
    verdict = PerplexityScoreController('test', fraction).is_out_of_control(values)

    assert verdict.answer == answer_true


@pytest.mark.parametrize('agent_blueprint, answer_true', DATA_AGENT_CONTROLLER_LEN_CHECK)
def test_controllers_length(agent_blueprint, answer_true):
    """`ControllerAgent` built from a blueprint holds the expected number of score controllers."""
    controllers = ControllerAgent(**agent_blueprint).score_controllers

    assert len(controllers) == answer_true

0 comments on commit 15e6e30

Please sign in to comment.