Skip to content

Commit b896bd3

Browse files
committed
Mypy compatibility
1 parent c822df8 commit b896bd3

File tree

6 files changed

+56
-19
lines changed

6 files changed

+56
-19
lines changed

.github/workflows/CI.yml

+26
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,29 @@ jobs:
143143
run: |
144144
pip install coveralls
145145
coveralls --finish
146+
147+
types:
148+
name: Check types
149+
runs-on: ubuntu-latest
150+
defaults:
151+
run:
152+
shell: bash -l {0}
153+
strategy:
154+
matrix:
155+
python-version: ['3.10']
156+
157+
steps:
158+
- uses: actions/checkout@v3
159+
- name: "Set up Python"
160+
uses: actions/setup-python@v4
161+
with:
162+
python-version: ${{ matrix.python-version }}
163+
cache: pip
164+
- name: "Install PySR and all dependencies"
165+
run: |
166+
python -m pip install --upgrade pip
167+
pip install -r requirements.txt
168+
pip install mypy jax jaxlib torch
169+
python setup.py install
170+
- name: "Run mypy"
171+
run: mypy --install-types --non-interactive pysr

mypy.ini

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
[mypy]
2+
warn_return_any = True
3+
4+
[mypy-sklearn.*]
5+
ignore_missing_imports = True
6+
7+
[mypy-julia.*]
8+
ignore_missing_imports = True

pysr/export_latex.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Functions to help export PySR equations to LaTeX."""
2-
from typing import List
2+
from typing import List, Optional, Tuple
33

44
import pandas as pd
55
import sympy
@@ -19,14 +19,16 @@ def _print_Float(self, expr):
1919
return super()._print_Float(reduced_float)
2020

2121

22-
def sympy2latex(expr, prec=3, full_prec=True, **settings):
22+
def sympy2latex(expr, prec=3, full_prec=True, **settings) -> str:
2323
"""Convert sympy expression to LaTeX with custom precision."""
2424
settings["full_prec"] = full_prec
2525
printer = PreciseLatexPrinter(settings=settings, prec=prec)
2626
return printer.doprint(expr)
2727

2828

29-
def generate_table_environment(columns=["equation", "complexity", "loss"]):
29+
def generate_table_environment(
30+
columns: List[str] = ["equation", "complexity", "loss"]
31+
) -> Tuple[str, str]:
3032
margins = "c" * len(columns)
3133
column_map = {
3234
"complexity": "Complexity",
@@ -58,20 +60,20 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]):
5860

5961
def sympy2latextable(
6062
equations: pd.DataFrame,
61-
indices: List[int] = None,
63+
indices: Optional[List[int]] = None,
6264
precision: int = 3,
63-
columns=["equation", "complexity", "loss", "score"],
65+
columns: List[str] = ["equation", "complexity", "loss", "score"],
6466
max_equation_length: int = 50,
6567
output_variable_name: str = "y",
66-
):
68+
) -> str:
6769
"""Generate a booktabs-style LaTeX table for a single set of equations."""
6870
assert isinstance(equations, pd.DataFrame)
6971

7072
latex_top, latex_bottom = generate_table_environment(columns)
7173
latex_table_content = []
7274

7375
if indices is None:
74-
indices = range(len(equations))
76+
indices = list(equations.index)
7577

7678
for i in indices:
7779
latex_equation = sympy2latex(
@@ -126,11 +128,11 @@ def sympy2latextable(
126128

127129
def sympy2multilatextable(
128130
equations: List[pd.DataFrame],
129-
indices: List[List[int]] = None,
131+
indices: Optional[List[List[int]]] = None,
130132
precision: int = 3,
131-
columns=["equation", "complexity", "loss", "score"],
132-
output_variable_names: str = None,
133-
):
133+
columns: List[str] = ["equation", "complexity", "loss", "score"],
134+
output_variable_names: Optional[List[str]] = None,
135+
) -> str:
134136
"""Generate multiple latex tables for a list of equation sets."""
135137
# TODO: Let user specify custom output variable
136138

pysr/export_sympy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@
5151

5252

5353
def create_sympy_symbols(
54-
feature_names_in: Optional[List[str]] = None,
54+
feature_names_in: List[str],
5555
) -> List[sympy.Symbol]:
5656
return [sympy.Symbol(variable) for variable in feature_names_in]
5757

5858

5959
def pysr2sympy(
6060
equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None
61-
) -> sympy.Expr:
61+
):
6262
local_sympy_mappings = {
6363
**(extra_sympy_mappings if extra_sympy_mappings else {}),
6464
**sympy_mappings,

pysr/feature_selection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33

44

5-
def run_feature_selection(X, y, select_k_features, random_state=None) -> np.ndarray:
5+
def run_feature_selection(X, y, select_k_features, random_state=None):
66
"""
77
Find most important features.
88

pysr/sr.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from io import StringIO
1212
from multiprocessing import cpu_count
1313
from pathlib import Path
14+
from typing import List, Optional
1415

1516
import numpy as np
1617
import pandas as pd
@@ -1781,10 +1782,10 @@ def fit(
17811782
y,
17821783
Xresampled=None,
17831784
weights=None,
1784-
variable_names=None,
1785-
X_units=None,
1786-
y_units=None,
1787-
):
1785+
variable_names: Optional[List[str]] = None,
1786+
X_units: Optional[List[str]] = None,
1787+
y_units: Optional[List[str]] = None,
1788+
) -> "PySRRegressor":
17881789
"""
17891790
Search for equations to fit the dataset and store them in `self.equations_`.
17901791
@@ -2371,7 +2372,7 @@ def latex_table(
23712372
return "\n".join(preamble_string + [table_string])
23722373

23732374

2374-
def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
2375+
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
23752376
"""Select an expression and return its index."""
23762377
if model_selection == "accuracy":
23772378
chosen_idx = equations["loss"].idxmin()

0 commit comments

Comments
 (0)