Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [sc-22270] Separate random generator for each thread #302

Merged
merged 9 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions generator/generation/categories_generator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import csv
import random

import itertools
import collections
import logging
import random

from operator import itemgetter
from typing import List, Dict, Tuple
Expand All @@ -15,6 +15,8 @@
from .combination_limiter import CombinationLimiter, prod
from ..input_name import InputName, Interpretation
from ..utils import Singleton
from generator.thread_utils import get_random_rng


logger = logging.getLogger('generator')

Expand All @@ -27,6 +29,7 @@ def __init__(self, config: DictConfig) -> None:
for token in tokens:
self.inverted_categories[token].append(category)

random.seed(0)
for tokens in self.categories.values():
random.shuffle(tokens)

Expand Down Expand Up @@ -65,10 +68,12 @@ def generate(self, tokens: Tuple[str, ...]) -> List[Tuple[str, ...]]:

token = ''.join(tokens)

rng = get_random_rng()

iterators = []
for category in self.categories.get_categories(token):
names = self.categories.get_names(category)
start_index = random.randint(0, len(names))
start_index = rng.randint(0, len(names))
iterators.append(
itertools.chain(itertools.islice(names, start_index, None), itertools.islice(names, 0, start_index)))

Expand Down
4 changes: 2 additions & 2 deletions generator/generation/easteregg_generator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Tuple, Iterator
from random import shuffle

from generator.generation import NameGenerator
from ..input_name import InputName, Interpretation
from generator.thread_utils import get_random_rng



Expand All @@ -27,7 +27,7 @@ def generate(self, tokens: Tuple[str, ...]) -> Iterator[Tuple[str, ...]]:
return []

name = ''.join(tokens)
shuffle(self.messages)
get_random_rng().shuffle(self.messages)
return ((m.format(name=name),) for m in self.messages)

def generate2(self, name: InputName, interpretation: Interpretation) -> Iterator[Tuple[str, ...]]:
Expand Down
3 changes: 2 additions & 1 deletion generator/generation/person_name_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .name_generator import NameGenerator
from ..input_name import InputName, Interpretation
from generator.thread_utils import get_numpy_rng


def standardize(a):
Expand Down Expand Up @@ -59,7 +60,7 @@ def generate(self, tokens: Tuple[str, ...], gender: str = None) -> List[Tuple[st
else:
data = self.both

order = np.random.choice(len(data[0]), size=len(data[0]), replace=False, p=data[1])
order = get_numpy_rng().choice(len(data[0]), size=len(data[0]), replace=False, p=data[1])
return (tokens + (data[0][index][0],) if data[0][index][1] == 'suffix' else (data[0][index][0],) + tokens for
index in order)

Expand Down
6 changes: 4 additions & 2 deletions generator/generation/random_available_name_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import random, logging
import logging
from itertools import accumulate
from typing import List, Tuple, Any

Expand All @@ -8,6 +8,8 @@
from .name_generator import NameGenerator
from ..domains import Domains
from ..input_name import InputName, Interpretation
from generator.thread_utils import get_random_rng


logger = logging.getLogger('generator')

Expand Down Expand Up @@ -43,7 +45,7 @@ def generate(self, limit=None) -> List[Tuple[str, ...]]:
limit = self.limit
limit = min(limit * 2, self.limit)
if len(self.domains.only_available) >= limit:
result = random.choices(self.names, cum_weights=self.accumulated_probabilities, k=limit)
result = get_random_rng().choices(self.names, cum_weights=self.accumulated_probabilities, k=limit)
else:
result = self.names
return ((x,) for x in result)
Expand Down
8 changes: 5 additions & 3 deletions generator/generation/rhymes_generator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json
import re
import unicodedata
import random
from typing import Tuple, Iterator

from generator.generation import NameGenerator
from ..input_name import InputName, Interpretation
from generator.thread_utils import get_random_rng


class RhymesGenerator(NameGenerator):
Expand Down Expand Up @@ -35,6 +35,8 @@ def generate(self, tokens: Tuple[str, ...]) -> Iterator[Tuple[str, ...]]:

rhyme_suffix = RhymesGenerator.get_rhyme_suffix(name_vmetaphone_repr)

rng = get_random_rng()

for suffix_len in range(len(rhyme_suffix), 2, -1):
suffix = rhyme_suffix[-suffix_len:]
rhymes = self.suffix2rhymes.get(suffix, None)
Expand All @@ -54,8 +56,8 @@ def generate(self, tokens: Tuple[str, ...]) -> Iterator[Tuple[str, ...]]:

rhymes_top = rhymes[:shuffle_threshold]
rhymes_bottom = rhymes[shuffle_threshold:]
random.shuffle(rhymes_top)
random.shuffle(rhymes_bottom)
rng.shuffle(rhymes_top)
rng.shuffle(rhymes_bottom)
rhymes = rhymes_top + rhymes_bottom

try:
Expand Down
11 changes: 8 additions & 3 deletions generator/meta_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import random
from typing import Type, Callable
from ens_normalize import is_ens_normalized

Expand All @@ -10,6 +9,8 @@
from generator.sampling.round_robin_sampler import RoundRobinSampler
from generator.sampling.sampler import Sampler
from generator.input_name import InputName
from generator.thread_utils import init_seed_for_thread, get_random_rng


logger = logging.getLogger('generator')

Expand Down Expand Up @@ -91,6 +92,8 @@ def sample(
) -> list[GeneratedName]:
min_available_required = int(min_suggestions * min_available_fraction)

init_seed_for_thread(seed_label=name.input_name) # init random generators for a thread

mode = name.params.get('mode', 'full')

types_lang_weights = {}
Expand Down Expand Up @@ -122,17 +125,19 @@ def sample(
all_suggestions_str = set()
joined_input_name = name.input_name.replace(' ', '')

rng = get_random_rng()

while True:
if len(all_suggestions) >= max_suggestions or not types_lang_weights:
break

# sample interpretation
sampled_type_lang = random.choices(
sampled_type_lang = rng.choices(
list(types_lang_weights.keys()),
weights=list(types_lang_weights.values())
)[0]

sampled_interpretation = random.choices(
sampled_interpretation = rng.choices(
list(interpretation_weights[sampled_type_lang].keys()),
weights=list(interpretation_weights[sampled_type_lang].values())
)[0]
Expand Down
17 changes: 8 additions & 9 deletions generator/sampling/weighted_sampler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import random

import numpy.random
from omegaconf import DictConfig
import numpy as np

from generator.pipeline import Pipeline
from generator.sampling.sampler import Sampler
from generator.thread_utils import get_random_rng, get_numpy_rng


class WeightedSorter(Sampler):
Expand All @@ -22,7 +21,7 @@ def __init__(self, config: DictConfig, pipelines: list[Pipeline], weights: dict[

def __next__(self):
if self.weights:
pipeline = random.choices(
pipeline = get_random_rng().choices(
list(self.weights.keys()),
weights=list(self.weights.values())
)[0] # TODO: optimize?
Expand All @@ -48,18 +47,18 @@ def __init__(self, config: DictConfig, pipelines: list[Pipeline], weights: dict[
del self.weights[pipeline]

if self.weights:
normalized_weights = numpy.array(list(self.weights.values()))
normalized_weights = normalized_weights / numpy.sum(normalized_weights)
self.first_pass = numpy.random.choice(list(self.weights.keys()), len(self.weights),
p=normalized_weights, replace=False).tolist()
normalized_weights = np.array(list(self.weights.values()))
normalized_weights = normalized_weights / np.sum(normalized_weights)
self.first_pass = get_numpy_rng().choice(list(self.weights.keys()), len(self.weights),
p=normalized_weights, replace=False).tolist()
else:
self.first_pass = []

def __next__(self):
if self.first_pass:
return self.first_pass.pop(0)
if self.weights:
pipeline = random.choices(list(self.weights.keys()), weights=list(self.weights.values()))[
pipeline = get_random_rng().choices(list(self.weights.keys()), weights=list(self.weights.values()))[
0] # TODO: optimize?
return pipeline
raise StopIteration
Expand Down
2 changes: 2 additions & 0 deletions generator/thread_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .thread_locals import thread_locals, init_seed_for_thread
from .thread_random import get_random_rng, get_numpy_rng
20 changes: 20 additions & 0 deletions generator/thread_utils/thread_locals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import threading
import hashlib
import logging
import random
import numpy.random as np_random


logger = logging.getLogger('generator')

thread_locals = threading.local()


def init_seed_for_thread(seed_label: str):
global thread_locals

hashed = hashlib.md5(seed_label.encode('utf-8')).digest()
seed = int.from_bytes(hashed, 'big') & 0xff_ff_ff_ff
logger.info(f"Setting seed for a thread: {seed}")
thread_locals.random_rng = random.Random(seed)
thread_locals.numpy_rng = np_random.default_rng(seed)
37 changes: 37 additions & 0 deletions generator/thread_utils/thread_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import logging
import random

import numpy.random as np_random

from .thread_locals import thread_locals


# todo: is "sth in globals()" and "globals()['thread_locals']" thread-safe?

logger = logging.getLogger('generator')

def get_random_rng():
"""Returns a `random.Random` object for a thread or a `random` module."""
if 'thread_locals' in globals():
rng_per_thread = getattr(globals()['thread_locals'], 'random_rng', None)
if rng_per_thread is None:
logger.warning(f'Using random module instead of a thread-specific rng!')
rng = rng_per_thread if rng_per_thread is not None else random
else:
logger.warning(f'Using random module instead of a thread-specific rng!')
rng = random
return rng


def get_numpy_rng():
"""Returns a `np.random.default_rng` object for a thread or a `np.random` module."""
if 'thread_locals' in globals():
rng_per_thread = getattr(globals()['thread_locals'], 'numpy_rng', None)
if rng_per_thread is None:
logger.warning(f'Using numpy.random module instead of a thread-specific rng!')
rng = rng_per_thread if rng_per_thread is not None else np_random
else:
logger.warning(f'Using numpy.random module instead of a thread-specific rng!')
rng = np_random

return rng
1 change: 0 additions & 1 deletion generator/xgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import concurrent.futures
import threading
import logging
import random
import time
from itertools import islice, cycle
from typing import List, Any
Expand Down
31 changes: 31 additions & 0 deletions tests/test_web_api_prod.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys
from time import time as get_time
from time import sleep

import pytest
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -748,6 +749,36 @@ def test_prod_grouped_by_category_emojis(self, prod_test_client, label):
elif category['type'] == 'expand':
assert len(category['suggestions']) > 0

# reason is described here - https://app.shortcut.com/ps-web3/story/22270/nondeterministic-behavior-for-vitalik
@pytest.mark.xfail
@pytest.mark.integration_test
@pytest.mark.parametrize("label", ["zeus", "dog", "dogs", "superman"])
def test_prod_deterministic_behavior(self, prod_test_client, label):
client = prod_test_client

request_data = {
"label": label,
"params": {
"user_info": {
"user_wallet_addr": "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa",
"user_ip_addr": "192.168.0.1",
"session_id": "d6374908-94c3-420f-b2aa-6dd41989baef",
"user_ip_country": "us"
},
"mode": "full",
"metadata": True
}
}
response1 = client.post("/suggestions_by_category", json=request_data)
assert response1.status_code == 200

sleep(0.1)

response2 = client.post("/suggestions_by_category", json=request_data)
assert response2.status_code == 200

assert response1.json() == response2.json()


@pytest.mark.integration_test
@pytest.mark.parametrize(
Expand Down