Skip to content

Commit

Permalink
Merge pull request #125 from sovrasov/vs/upd_setup
Browse files Browse the repository at this point in the history
Update setup script and license headers
  • Loading branch information
sovrasov authored Nov 27, 2023
2 parents 316cda9 + 4245bae commit 1fc2ed5
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 23 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,13 @@ jobs:
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f test_requirements.txt ]; then pip install -r test_requirements.txt; fi
- name: Install ptflops
run: |
python setup.py develop
pip install .[dev]
- name: Testing with pytest
run: |
python -m pytest . -s
- name: Linting with flake8
run: |
python -m flake8 .
python -m isort -rc --check-only --diff .
python -m isort -rc --check-only --diff ./ptflops ./tests
22 changes: 13 additions & 9 deletions ptflops/flops_counter.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
'''
Copyright (C) 2019-2021 Sovrasov V. - All Rights Reserved
Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''

import sys
from typing import Any, Callable, Dict, TextIO, Tuple, Union

import torch.nn as nn

from .pytorch_engine import get_flops_pytorch
from .utils import flops_to_string, params_to_string


def get_model_complexity_info(model, input_res,
print_per_layer_stat=True,
as_strings=True,
input_constructor=None, ost=sys.stdout,
verbose=False, ignore_modules=[],
custom_modules_hooks={}, backend='pytorch',
flops_units=None, param_units=None,
output_precision=2):
def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...],
print_per_layer_stat: bool = True,
as_strings: bool = True,
input_constructor: Union[Callable, None] = None,
ost: TextIO = sys.stdout,
verbose: bool = False, ignore_modules=[],
custom_modules_hooks: Dict[nn.Module, Any] = {},
backend: str = 'pytorch',
flops_units: Union[str, None] = None,
param_units: Union[str, None] = None,
output_precision: int = 2):
assert type(input_res) is tuple
assert len(input_res) >= 1
assert isinstance(model, nn.Module)
Expand Down
2 changes: 1 addition & 1 deletion ptflops/pytorch_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''
Copyright (C) 2021 Sovrasov V. - All Rights Reserved
Copyright (C) 2021-2023 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
Expand Down
2 changes: 1 addition & 1 deletion ptflops/pytorch_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''
Copyright (C) 2021 Sovrasov V. - All Rights Reserved
Copyright (C) 2021-2023 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
Expand Down
32 changes: 23 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,49 @@
from setuptools import find_packages, setup
'''
Copyright (C) 2018-2023 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''

from pathlib import Path

readme = open('README.md').read()
from setuptools import find_packages, setup

VERSION = '0.7.1.2'

requirements = [
'torch',
]

SETUP_DIR = Path(__file__).resolve().parent

TEST_BASE_EXTRAS = (SETUP_DIR / 'test_requirements.txt').read_text()
EXTRAS_REQUIRE = {
'dev': TEST_BASE_EXTRAS,
}

setup(
# Metadata
name='ptflops',
version=VERSION,
author='Vladislav Sovrasov',
author_email='sovrasov.vlad@gmail.com',
url='https://github.com/sovrasov/flops-counter.pytorch',
description='Flops counter for convolutional networks in'
'pytorch framework',
long_description=readme,
long_description=(SETUP_DIR / 'README.md').read_text(),
long_description_content_type='text/markdown',
license='MIT',

# Package info
packages=find_packages(exclude=('*test*',)),
packages=find_packages(SETUP_DIR, exclude=('*test*',)),
package_dir={'ptflops': str(SETUP_DIR / 'ptflops')},

#
zip_safe=True,
install_requires=requirements,
extras_require=EXTRAS_REQUIRE,
python_requires='>=3.7',

# Classifiers
classifiers=[
'Programming Language :: Python :: 3',
'MIT Software License :: Programming Language :: Python :: 3',
],
)

0 comments on commit 1fc2ed5

Please sign in to comment.