Skip to content

Commit

Permalink
Allow model selection (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-luke authored Apr 28, 2023
1 parent d62c1c7 commit f848751
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 14 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ['3.10']
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

steps:
- uses: actions/checkout@v3
Expand All @@ -42,4 +44,7 @@ jobs:
run: hatch run dev:lint-types

- name: Run format checking
run: hatch run dev:lint-format
run: hatch run dev:lint-format

- name: Run self test
run: hatch run docstring-auditor
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,4 @@ cython_debug/
package-lock.json
package.json
***.draft

***.DS_Store
33 changes: 23 additions & 10 deletions docstring_auditor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def extract_functions(file_path: str) -> List[Optional[str]]:
return functions


def ask_for_critique(function: str) -> Dict[str, str]:
def ask_for_critique(function: str, model: str) -> Dict[str, str]:
"""
Query OpenAI for a critique of the docstring for a function.
Expand All @@ -63,6 +63,8 @@ def ask_for_critique(function: str) -> Dict[str, str]:
function : str
A string containing the code and the docstring for the Python function.
The input should be formatted as a single string, with the code and docstring combined.
model : str
The name of the OpenAI model to use for the query.
Returns
-------
Expand Down Expand Up @@ -110,7 +112,7 @@ def ask_for_critique(function: str) -> Dict[str, str]:
{"role": "user", "content": function},
]
response = openai.ChatCompletion.create(
model="gpt-4", temperature=0.1, messages=messages
model=model, temperature=0.0, messages=messages
)

response_str = response["choices"][0]["message"]["content"]
Expand Down Expand Up @@ -165,7 +167,7 @@ def report_concerns(response_dict: Dict[str, str]) -> Tuple[int, int]:
return error_count, warning_count


def process_file(file_path: str) -> Tuple[int, int]:
def process_file(file_path: str, model: str) -> Tuple[int, int]:
"""
Process a single Python file and analyze its functions' docstrings.
Expand All @@ -178,6 +180,8 @@ def process_file(file_path: str) -> Tuple[int, int]:
----------
file_path : str
The path to the .py file to analyze the functions' docstrings.
model : str
The name of the OpenAI model to use for the analysis.
Returns
-------
Expand All @@ -194,7 +198,7 @@ def process_file(file_path: str) -> Tuple[int, int]:
f"Processing function {idx + 1} of {len(functions)} in file {file_path}..."
)
assert isinstance(function, str)
critique = ask_for_critique(function)
critique = ask_for_critique(function, model)
errors, warnings = report_concerns(critique)
error_count += errors
warning_count += warnings
Expand All @@ -203,7 +207,7 @@ def process_file(file_path: str) -> Tuple[int, int]:


def process_directory(
directory_path: str, ignore_dirs: Optional[List[str]] = None
directory_path: str, model: str, ignore_dirs: Optional[List[str]] = None
) -> Tuple[int, int]:
"""
Recursively process all .py files in a directory and its subdirectories, ignoring specified directories.
Expand All @@ -212,7 +216,8 @@ def process_directory(
----------
directory_path : str
The path to the directory containing .py files to analyze the functions' docstrings.
model : str
The name of the OpenAI model to use for the docstring analysis.
ignore_dirs : Optional[List[str]]
A list of directory names to ignore while processing .py files. By default, it ignores the "tests" directory.
Expand All @@ -233,7 +238,7 @@ def process_directory(
for file in files:
if file.endswith(".py"):
file_path = os.path.join(root, file)
errors, warnings = process_file(file_path)
errors, warnings = process_file(file_path, model)
error_count += errors
warning_count += warnings

Expand All @@ -255,7 +260,13 @@ def process_directory(
default=False,
help="If true, warnings will be treated as errors and included in the exit code count.",
)
def docstring_auditor(path: str, ignore_dirs: List[str], error_on_warnings: bool):
@click.option(
"--model",
type=click.STRING,
default="gpt-4",
help="The OpenAI model to use for docstring analysis. Default is 'gpt-4'.",
)
def docstring_auditor(path: str, ignore_dirs: List[str], error_on_warnings: bool, model: str):
"""
Analyze Python functions' docstrings in a given file or directory and provide critiques and suggestions for improvement.
Expand All @@ -271,16 +282,18 @@ def docstring_auditor(path: str, ignore_dirs: List[str], error_on_warnings: bool
A list of directory names to ignore while processing .py files.
error_on_warnings : bool
If true, warnings will be treated as errors and included in the exit code count.
model : str
The OpenAI model to use for docstring analysis. Default is 'gpt-4'.
Returns
-------
None
The function does not return any value. It prints the critiques and suggestions for the docstrings in the given file or directory.
"""
if os.path.isfile(path):
error_count, warning_count = process_file(path)
error_count, warning_count = process_file(path, model)
elif os.path.isdir(path):
error_count, warning_count = process_directory(path, ignore_dirs)
error_count, warning_count = process_directory(path, model, ignore_dirs)
else:
error_text = "Invalid path. Please provide a valid file or directory path."
click.secho(error_text, fg="red")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "docstring-auditor"
version = "0.1.7"
version = "0.1.8"
authors = [{name = "Rob Luke", email = "rob.luke@ae.studio"}]
description = "A tool to analyze Python functions' docstrings and provide critiques and suggestions for improvement"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_critique.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_ask_for_critique(openai_mock):
" return 'successful'\n"
)

response_dict = ask_for_critique(function)
response_dict = ask_for_critique(function, model="gpt-4")
assert response_dict["function"].startswith("test_function_")

current_test_id = (
Expand Down

0 comments on commit f848751

Please sign in to comment.