Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow explicit location to module #218

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import importlib
import importlib.util
import logging
import re
from argparse import ArgumentParser, Namespace
import sys
from argparse import ArgumentParser, ArgumentTypeError, Namespace
from pathlib import Path
from types import ModuleType

from pybind11_stubgen.parser.interface import IParser
from pybind11_stubgen.parser.mixins.error_handlers import (
Expand Down Expand Up @@ -77,6 +80,7 @@ class CLIArgs(Namespace):
dry_run: bool
stub_extension: str
module_name: str
location: Path | None


def arg_parser() -> ArgumentParser:
Expand Down Expand Up @@ -215,6 +219,22 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
"Must be 'pyi' (default) or 'py'",
)

def existing_file(path: str) -> Path | None:
if path is None:
return None
try:
return Path(path).resolve(strict=True)
except FileNotFoundError:
raise ArgumentTypeError(f"Path {path!r} does not exist.")

parser.add_argument(
"--location",
type=existing_file,
default=None,
dest="location",
help="Explicit filesytem location for module",
)

parser.add_argument(
"module_name",
metavar="MODULE_NAME",
Expand Down Expand Up @@ -324,6 +344,7 @@ def main():
sub_dir=sub_dir,
dry_run=args.dry_run,
writer=Writer(stub_ext=args.stub_extension),
location=args.location,
)


Expand All @@ -345,6 +366,16 @@ def to_output_and_subdir(
return out_dir.joinpath(*module_path[:-1]), sub_dir


def import_module_from_path(module_name: str, location: Path) -> ModuleType:
spec = importlib.util.spec_from_file_location(module_name, location)
if not (spec and spec.loader):
raise ImportError(f"Can't import {module_name} from {location}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
sys.modules[module_name] = module
return module


def run(
parser: IParser,
printer: Printer,
Expand All @@ -353,10 +384,13 @@ def run(
sub_dir: Path | None,
dry_run: bool,
writer: Writer,
location: Path | None,
):
module = parser.handle_module(
QualifiedName.from_str(module_name), importlib.import_module(module_name)
)
if location:
pymodule = import_module_from_path(module_name, location)
else:
pymodule = importlib.import_module(module_name)
module = parser.handle_module(QualifiedName.from_str(module_name), pymodule)
parser.finalize()

if module is None:
Expand Down