diff --git a/edm/LICENSE.txt b/edm/LICENSE.txt new file mode 100644 index 0000000..2a272a4 --- /dev/null +++ b/edm/LICENSE.txt @@ -0,0 +1,439 @@ +Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). 
To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. 
Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. 
Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. 
The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the "Licensor." The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/edm/dnnlib/__init__.py b/edm/dnnlib/__init__.py new file mode 100644 index 0000000..a25aca4 --- /dev/null +++ b/edm/dnnlib/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +from .util import EasyDict, make_cache_dir_path diff --git a/edm/dnnlib/util.py b/edm/dnnlib/util.py new file mode 100644 index 0000000..b7f9c01 --- /dev/null +++ b/edm/dnnlib/util.py @@ -0,0 +1,491 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Miscellaneous utility classes and functions.""" + +import ctypes +import fnmatch +import importlib +import inspect +import numpy as np +import os +import shutil +import sys +import types +import io +import pickle +import re +import requests +import html +import hashlib +import glob +import tempfile +import urllib +import urllib.request +import uuid + +from distutils.util import strtobool +from typing import Any, List, Tuple, Union, Optional + + +# Util classes +# ------------------------------------------------------------------------------------------ + + +class EasyDict(dict): + """Convenience class that behaves like a dict but allows access with the attribute syntax.""" + + def __getattr__(self, name: str) -> Any: + try: + return self[name] + except KeyError: + raise AttributeError(name) + + def __setattr__(self, name: str, value: Any) -> None: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + +class Logger(object): + """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" + + def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True): + self.file = None + + if file_name is not None: + self.file = open(file_name, file_mode) + + self.should_flush = should_flush + self.stdout = sys.stdout + self.stderr = sys.stderr + + sys.stdout = self + sys.stderr = self + + def __enter__(self) -> "Logger": + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def write(self, text: Union[str, bytes]) -> None: + """Write text to stdout (and a file) and optionally flush.""" + if isinstance(text, bytes): + text = text.decode() + if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash + return + + if self.file is not None: + self.file.write(text) + + self.stdout.write(text) + + if self.should_flush: + self.flush() + + def flush(self) -> None: + """Flush written text to both stdout and a file, if open.""" + if self.file is not None: + self.file.flush() + + self.stdout.flush() + + def close(self) -> None: + """Flush, close possible files, and remove stdout/stderr mirroring.""" + self.flush() + + # if using multiple loggers, prevent closing in wrong order + if sys.stdout is self: + sys.stdout = self.stdout + if sys.stderr is self: + sys.stderr = self.stderr + + if self.file is not None: + self.file.close() + self.file = None + + +# Cache directories +# ------------------------------------------------------------------------------------------ + +_dnnlib_cache_dir = None + +def set_cache_dir(path: str) -> None: + global _dnnlib_cache_dir + _dnnlib_cache_dir = path + +def make_cache_dir_path(*paths: str) -> str: + if _dnnlib_cache_dir is not None: + return os.path.join(_dnnlib_cache_dir, *paths) + if 'DNNLIB_CACHE_DIR' in os.environ: + return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths) + if 'HOME' in os.environ: + return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths) + if 'USERPROFILE' in os.environ: + return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths) + return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths) + +# Small util functions +# ------------------------------------------------------------------------------------------ + + +def format_time(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable 
string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) + else: + return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60) + + +def format_time_brief(seconds: Union[int, float]) -> str: + """Convert the seconds to human readable string with days, hours, minutes and seconds.""" + s = int(np.rint(seconds)) + + if s < 60: + return "{0}s".format(s) + elif s < 60 * 60: + return "{0}m {1:02}s".format(s // 60, s % 60) + elif s < 24 * 60 * 60: + return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) + else: + return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) + + +def ask_yes_no(question: str) -> bool: + """Ask the user the question until the user inputs a valid answer.""" + while True: + try: + print("{0} [y/n]".format(question)) + return strtobool(input().lower()) + except ValueError: + pass + + +def tuple_product(t: Tuple) -> Any: + """Calculate the product of the tuple elements.""" + result = 1 + + for v in t: + result *= v + + return result + + +_str_to_ctype = { + "uint8": ctypes.c_ubyte, + "uint16": ctypes.c_uint16, + "uint32": ctypes.c_uint32, + "uint64": ctypes.c_uint64, + "int8": ctypes.c_byte, + "int16": ctypes.c_int16, + "int32": ctypes.c_int32, + "int64": ctypes.c_int64, + "float32": ctypes.c_float, + "float64": ctypes.c_double +} + + +def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: + """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" + type_str = None + + if isinstance(type_obj, str): + type_str = type_obj + elif hasattr(type_obj, "__name__"): + type_str = type_obj.__name__ + elif hasattr(type_obj, "name"): + type_str = type_obj.name + else: + raise RuntimeError("Cannot infer type name from input") + + assert type_str in _str_to_ctype.keys() + + my_dtype = np.dtype(type_str) + my_ctype = _str_to_ctype[type_str] + + assert my_dtype.itemsize == ctypes.sizeof(my_ctype) + + return my_dtype, my_ctype + + +def is_pickleable(obj: Any) -> bool: + try: + with io.BytesIO() as stream: + pickle.dump(obj, stream) + return True + except: + return False + + +# Functionality to import modules/objects by name, and call functions by name +# ------------------------------------------------------------------------------------------ + +def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: + """Searches for the underlying module behind the name to some python object. + Returns the module and the object name (original name with module part removed).""" + + # allow convenience shorthands, substitute them by full names + obj_name = re.sub("^np.", "numpy.", obj_name) + obj_name = re.sub("^tf.", "tensorflow.", obj_name) + + # list alternatives for (module_name, local_obj_name) + parts = obj_name.split(".") + name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)] + + # try each alternative in turn + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + return module, local_obj_name + except: + pass + + # maybe some of the modules themselves contain errors? 
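+    # Re-import each candidate module on its own and re-raise the ImportError unless it is
+    # simply "No module named '<candidate>'", i.e. the candidate module does not exist at all.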
+ for module_name, _local_obj_name in name_pairs: + try: + importlib.import_module(module_name) # may raise ImportError + except ImportError: + if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"): + raise + + # maybe the requested attribute is missing? + for module_name, local_obj_name in name_pairs: + try: + module = importlib.import_module(module_name) # may raise ImportError + get_obj_from_module(module, local_obj_name) # may raise AttributeError + except ImportError: + pass + + # we are out of luck, but we have no idea why + raise ImportError(obj_name) + + +def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: + """Traverses the object name and returns the last (rightmost) python object.""" + if obj_name == '': + return module + obj = module + for part in obj_name.split("."): + obj = getattr(obj, part) + return obj + + +def get_obj_by_name(name: str) -> Any: + """Finds the python object with the given name.""" + module, obj_name = get_module_from_obj_name(name) + return get_obj_from_module(module, obj_name) + + +def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: + """Finds the python object with the given name and calls it as a function.""" + assert func_name is not None + func_obj = get_obj_by_name(func_name) + assert callable(func_obj) + return func_obj(*args, **kwargs) + + +def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: + """Finds the python class with the given name and constructs it with the given arguments.""" + return call_func_by_name(*args, func_name=class_name, **kwargs) + + +def get_module_dir_by_obj_name(obj_name: str) -> str: + """Get the directory path of the module containing the given object name.""" + module, _ = get_module_from_obj_name(obj_name) + return os.path.dirname(inspect.getfile(module)) + + +def is_top_level_function(obj: Any) -> bool: + """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" + return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ + + +def get_top_level_function_name(obj: Any) -> str: + """Return the fully-qualified name of a top-level function.""" + assert is_top_level_function(obj) + module = obj.__module__ + if module == '__main__': + module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] + return module + "." + obj.__name__ + + +# File system helpers +# ------------------------------------------------------------------------------------------ + +def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]: + """List all files recursively in a given directory while ignoring given file and directory names. 
+ Returns list of tuples containing both absolute and relative paths.""" + assert os.path.isdir(dir_path) + base_name = os.path.basename(os.path.normpath(dir_path)) + + if ignores is None: + ignores = [] + + result = [] + + for root, dirs, files in os.walk(dir_path, topdown=True): + for ignore_ in ignores: + dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] + + # dirs need to be edited in-place + for d in dirs_to_remove: + dirs.remove(d) + + files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] + + absolute_paths = [os.path.join(root, f) for f in files] + relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] + + if add_base_to_relative: + relative_paths = [os.path.join(base_name, p) for p in relative_paths] + + assert len(absolute_paths) == len(relative_paths) + result += zip(absolute_paths, relative_paths) + + return result + + +def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: + """Takes in a list of tuples of (src, dst) paths and copies files. + Will create all necessary directories.""" + for file in files: + target_dir_name = os.path.dirname(file[1]) + + # will create all intermediate-level directories + if not os.path.exists(target_dir_name): + os.makedirs(target_dir_name) + + shutil.copyfile(file[0], file[1]) + + +# URL helpers +# ------------------------------------------------------------------------------------------ + +def is_url(obj: Any, allow_file_urls: bool = False) -> bool: + """Determine whether the given object is a valid URL string.""" + if not isinstance(obj, str) or not "://" in obj: + return False + if allow_file_urls and obj.startswith('file://'): + return True + try: + res = requests.compat.urlparse(obj) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) + if not res.scheme or not res.netloc or not "." in res.netloc: + return False + except: + return False + return True + + +def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + assert not (return_filename and (not cache)) + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + assert is_url(url) + + # Lookup from cache. 
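+    # Cached downloads are stored as '<md5-of-url>_<sanitized-filename>', so a single glob on
+    # the URL hash is enough to find a previous download regardless of the original filename.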
+ if cache_dir is None: + cache_dir = make_cache_dir_path('downloads') + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + if cache: + cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) + if len(cache_files) == 1: + filename = cache_files[0] + return filename if return_filename else open(filename, "rb") + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Save to cache. + if cache: + safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) + safe_name = safe_name[:min(len(safe_name), 128)] + cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) + temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name) + os.makedirs(cache_dir, exist_ok=True) + with open(temp_file, "wb") as f: + f.write(url_data) + os.replace(temp_file, cache_file) # atomic + if return_filename: + return cache_file + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) diff --git a/edm/generate.py b/edm/generate.py new file mode 100644 index 0000000..45c1241 --- /dev/null +++ b/edm/generate.py @@ -0,0 +1,316 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Generate random images using the techniques described in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" + +import os +import re +import click +import tqdm +import pickle +import numpy as np +import torch +import PIL.Image +import dnnlib +from torch_utils import distributed as dist + +#---------------------------------------------------------------------------- +# Proposed EDM sampler (Algorithm 2). + +def edm_sampler( + net, latents, class_labels=None, randn_like=torch.randn_like, + num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, +): + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Time step discretization. 
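+    # sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for i = 0..N-1,
+    # i.e. noise levels interpolated between sigma_max and sigma_min in 1/rho-space; sigma_N = 0 is appended below.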
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = latents.to(torch.float64) * t_steps[0] + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 + t_hat = net.round_sigma(t_cur + gamma * t_cur) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur) + + # Euler step. + denoised = net(x_hat, t_hat, class_labels).to(torch.float64) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = net(x_next, t_next, class_labels).to(torch.float64) + d_prime = (x_next - denoised) / t_next + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + + return x_next + +#---------------------------------------------------------------------------- +# Generalized ablation sampler, representing the superset of all sampling +# methods discussed in the paper. + +def ablation_sampler( + net, latents, class_labels=None, randn_like=torch.randn_like, + num_steps=18, sigma_min=None, sigma_max=None, rho=7, + solver='heun', discretization='edm', schedule='linear', scaling='none', + epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1, + S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, +): + assert solver in ['euler', 'heun'] + assert discretization in ['vp', 've', 'iddpm', 'edm'] + assert schedule in ['vp', 've', 'linear'] + assert scaling in ['vp', 'none'] + + # Helper functions for VP & VE noise level schedules. + vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5 + vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t)) + vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d + ve_sigma = lambda t: t.sqrt() + ve_sigma_deriv = lambda t: 0.5 / t.sqrt() + ve_sigma_inv = lambda sigma: sigma ** 2 + + # Select default noise level range based on the specified time step discretization. + if sigma_min is None: + vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=epsilon_s) + sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization] + if sigma_max is None: + vp_def = vp_sigma(beta_d=19.9, beta_min=0.1)(t=1) + sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization] + + # Adjust noise levels based on what's supported by the network. + sigma_min = max(sigma_min, net.sigma_min) + sigma_max = min(sigma_max, net.sigma_max) + + # Compute corresponding betas for VP. + vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1) + vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d + + # Define time steps in terms of noise level. 
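+    # Each discretization yields a decreasing sequence of noise levels sigma_0 > ... > sigma_{N-1}:
+    # 'vp' and 've' evaluate their sigma(t) schedules on the corresponding time grids, 'iddpm' resamples
+    # the precomputed u_j table, and 'edm' uses the same rho-spaced interpolation as edm_sampler above.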
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device) + if discretization == 'vp': + orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1) + sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps) + elif discretization == 've': + orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1))) + sigma_steps = ve_sigma(orig_t_steps) + elif discretization == 'iddpm': + u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device) + alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2 + for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1 + u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt() + u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)] + sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)] + else: + assert discretization == 'edm' + sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho + + # Define noise level schedule. + if schedule == 'vp': + sigma = vp_sigma(vp_beta_d, vp_beta_min) + sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min) + sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min) + elif schedule == 've': + sigma = ve_sigma + sigma_deriv = ve_sigma_deriv + sigma_inv = ve_sigma_inv + else: + assert schedule == 'linear' + sigma = lambda t: t + sigma_deriv = lambda t: 1 + sigma_inv = lambda sigma: sigma + + # Define scaling schedule. + if scaling == 'vp': + s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt() + s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3) + else: + assert scaling == 'none' + s = lambda t: 1 + s_deriv = lambda t: 0 + + # Compute final time steps based on the corresponding noise levels. + t_steps = sigma_inv(net.round_sigma(sigma_steps)) + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + t_next = t_steps[0] + x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next)) + for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0 + t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur))) + x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur) + + # Euler step. + h = t_next - t_hat + denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64) + d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised + x_prime = x_hat + alpha * h * d_cur + t_prime = t_hat + alpha * h + + # Apply 2nd order correction. + if solver == 'euler' or i == num_steps - 1: + x_next = x_hat + h * d_cur + else: + assert solver == 'heun' + denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64) + d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised + x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime) + + return x_next + +#---------------------------------------------------------------------------- +# Wrapper for torch.Generator that allows specifying a different random seed +# for each sample in a minibatch. 
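+# Each entry of `seeds` gets its own torch.Generator, so sample i of a batch is reproducible
+# from seeds[i] alone, independent of how the seeds are grouped into batches. Hypothetical usage:
+#   rnd = StackedRandomGenerator(device, seeds=[0, 1, 2, 3])
+#   latents = rnd.randn([4, img_channels, img_resolution, img_resolution], device=device)  # row i <- seed i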
+ +class StackedRandomGenerator: + def __init__(self, device, seeds): + super().__init__() + self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] + + def randn(self, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) + + def randn_like(self, input): + return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) + + def randint(self, *args, size, **kwargs): + assert size[0] == len(self.generators) + return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) + +#---------------------------------------------------------------------------- +# Parse a comma separated list of numbers or ranges and return a list of ints. +# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + +def parse_int_list(s): + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + +#---------------------------------------------------------------------------- + +@click.command() +@click.option('--network', 'network_pkl', help='Network pickle filename', metavar='PATH|URL', type=str, required=True) +@click.option('--outdir', help='Where to save the output images', metavar='DIR', type=str, required=True) +@click.option('--seeds', help='Random seeds (e.g. 1,2,5-10)', metavar='LIST', type=parse_int_list, default='0-63', show_default=True) +@click.option('--subdirs', help='Create subdirectory for every 1000 seeds', is_flag=True) +@click.option('--class', 'class_idx', help='Class label [default: random]', metavar='INT', type=click.IntRange(min=0), default=None) +@click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) + +@click.option('--steps', 'num_steps', help='Number of sampling steps', metavar='INT', type=click.IntRange(min=1), default=18, show_default=True) +@click.option('--sigma_min', help='Lowest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) +@click.option('--sigma_max', help='Highest noise level [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True)) +@click.option('--rho', help='Time step exponent', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=7, show_default=True) +@click.option('--S_churn', 'S_churn', help='Stochasticity strength', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) +@click.option('--S_min', 'S_min', help='Stoch. min noise level', metavar='FLOAT', type=click.FloatRange(min=0), default=0, show_default=True) +@click.option('--S_max', 'S_max', help='Stoch. max noise level', metavar='FLOAT', type=click.FloatRange(min=0), default='inf', show_default=True) +@click.option('--S_noise', 'S_noise', help='Stoch. 
noise inflation', metavar='FLOAT', type=float, default=1, show_default=True) + +@click.option('--solver', help='Ablate ODE solver', metavar='euler|heun', type=click.Choice(['euler', 'heun'])) +@click.option('--disc', 'discretization', help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm'])) +@click.option('--schedule', help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear', type=click.Choice(['vp', 've', 'linear'])) +@click.option('--scaling', help='Ablate signal scaling s(t)', metavar='vp|none', type=click.Choice(['vp', 'none'])) + +def main(network_pkl, outdir, subdirs, seeds, class_idx, max_batch_size, device=torch.device('cuda'), **sampler_kwargs): + """Generate random images using the techniques described in the paper + "Elucidating the Design Space of Diffusion-Based Generative Models". + + Examples: + + \b + # Generate 64 images and save them as out/*.png + python generate.py --outdir=out --seeds=0-63 --batch=64 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl + + \b + # Generate 1024 images using 2 GPUs + torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \\ + --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl + """ + dist.init() + num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() + all_batches = torch.as_tensor(seeds).tensor_split(num_batches) + rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] + + # Rank 0 goes first. + if dist.get_rank() != 0: + torch.distributed.barrier() + + # Load network. + dist.print0(f'Loading network from "{network_pkl}"...') + with dnnlib.util.open_url(network_pkl, verbose=(dist.get_rank() == 0)) as f: + net = pickle.load(f)['ema'].to(device) + + # Other ranks follow. + if dist.get_rank() == 0: + torch.distributed.barrier() + + # Loop over batches. + dist.print0(f'Generating {len(seeds)} images to "{outdir}"...') + for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)): + torch.distributed.barrier() + batch_size = len(batch_seeds) + if batch_size == 0: + continue + + # Pick latents and labels. + rnd = StackedRandomGenerator(device, batch_seeds) + latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) + class_labels = None + if net.label_dim: + class_labels = torch.eye(net.label_dim, device=device)[rnd.randint(net.label_dim, size=[batch_size], device=device)] + if class_idx is not None: + class_labels[:, :] = 0 + class_labels[:, class_idx] = 1 + + # Generate images. + sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None} + have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling']) + sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler + images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like, **sampler_kwargs) + + # Save images. 
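+        # Map network output from roughly [-1, 1] to [0, 255]: x * 127.5 + 128, then clip,
+        # convert to uint8, and reorder NCHW -> NHWC so each image can be handed to PIL.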
+ images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy() + for seed, image_np in zip(batch_seeds, images_np): + image_dir = os.path.join(outdir, f'{seed-seed%1000:06d}') if subdirs else outdir + os.makedirs(image_dir, exist_ok=True) + image_path = os.path.join(image_dir, f'{seed:06d}.png') + if image_np.shape[2] == 1: + PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path) + else: + PIL.Image.fromarray(image_np, 'RGB').save(image_path) + + # Done. + torch.distributed.barrier() + dist.print0('Done.') + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +#---------------------------------------------------------------------------- diff --git a/edm/torch_utils/__init__.py b/edm/torch_utils/__init__.py new file mode 100644 index 0000000..d76c2b6 --- /dev/null +++ b/edm/torch_utils/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +# empty diff --git a/edm/torch_utils/distributed.py b/edm/torch_utils/distributed.py new file mode 100644 index 0000000..92466ef --- /dev/null +++ b/edm/torch_utils/distributed.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +import os +import torch +from . 
import training_stats + +#---------------------------------------------------------------------------- + +def init(): + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = 'localhost' + if 'MASTER_PORT' not in os.environ: + os.environ['MASTER_PORT'] = '29500' + if 'RANK' not in os.environ: + os.environ['RANK'] = '0' + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = '0' + if 'WORLD_SIZE' not in os.environ: + os.environ['WORLD_SIZE'] = '1' + + backend = 'gloo' if os.name == 'nt' else 'nccl' + torch.distributed.init_process_group(backend=backend, init_method='env://') + torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) + + sync_device = torch.device('cuda') if get_world_size() > 1 else None + training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) + +#---------------------------------------------------------------------------- + +def get_rank(): + return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + +#---------------------------------------------------------------------------- + +def get_world_size(): + return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + +#---------------------------------------------------------------------------- + +def should_stop(): + return False + +#---------------------------------------------------------------------------- + +def update_progress(cur, total): + _ = cur, total + +#---------------------------------------------------------------------------- + +def print0(*args, **kwargs): + if get_rank() == 0: + print(*args, **kwargs) + +#---------------------------------------------------------------------------- diff --git a/edm/torch_utils/misc.py b/edm/torch_utils/misc.py new file mode 100644 index 0000000..f0d3184 --- /dev/null +++ b/edm/torch_utils/misc.py @@ -0,0 +1,266 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +import re +import contextlib +import numpy as np +import torch +import warnings +import dnnlib + +#---------------------------------------------------------------------------- +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. + +_constant_cache = dict() + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor + +#---------------------------------------------------------------------------- +# Replace NaN/Inf with specified numerical values. 
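+# torch.nan_to_num is available from PyTorch 1.8 onwards; the fallback below emulates it for older
+# versions (NaN -> 0 via nansum over a singleton dim, +/-Inf clamped to the dtype's finite limits).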
+ +try: + nan_to_num = torch.nan_to_num # 1.8.0a0 +except AttributeError: + def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin + assert isinstance(input, torch.Tensor) + if posinf is None: + posinf = torch.finfo(input.dtype).max + if neginf is None: + neginf = torch.finfo(input.dtype).min + assert nan == 0 + return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) + +#---------------------------------------------------------------------------- +# Symbolic assert. + +try: + symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access +except AttributeError: + symbolic_assert = torch.Assert # 1.7.0 + +#---------------------------------------------------------------------------- +# Context manager to temporarily suppress known warnings in torch.jit.trace(). +# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 + +@contextlib.contextmanager +def suppress_tracer_warnings(): + flt = ('ignore', None, torch.jit.TracerWarning, None, 0) + warnings.filters.insert(0, flt) + yield + warnings.filters.remove(flt) + +#---------------------------------------------------------------------------- +# Assert that the shape of a tensor matches the given list of integers. +# None indicates that the size of a dimension is allowed to vary. +# Performs symbolic assertion when used in torch.jit.trace(). + +def assert_shape(tensor, ref_shape): + if tensor.ndim != len(ref_shape): + raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') + for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): + if ref_size is None: + pass + elif isinstance(ref_size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') + elif isinstance(size, torch.Tensor): + with suppress_tracer_warnings(): # as_tensor results are registered as constants + symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') + elif size != ref_size: + raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') + +#---------------------------------------------------------------------------- +# Function decorator that calls torch.autograd.profiler.record_function(). + +def profiled_function(fn): + def decorator(*args, **kwargs): + with torch.autograd.profiler.record_function(fn.__name__): + return fn(*args, **kwargs) + decorator.__name__ = fn.__name__ + return decorator + +#---------------------------------------------------------------------------- +# Sampler for torch.utils.data.DataLoader that loops over the dataset +# indefinitely, shuffling items as it goes. 
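+# A hypothetical usage sketch: pass an instance as the DataLoader's sampler so iteration never ends, e.g.
+#   sampler = InfiniteSampler(dataset, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=0)
+#   images, labels = next(iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size)))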
+ +class InfiniteSampler(torch.utils.data.Sampler): + def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): + assert len(dataset) > 0 + assert num_replicas > 0 + assert 0 <= rank < num_replicas + assert 0 <= window_size <= 1 + super().__init__(dataset) + self.dataset = dataset + self.rank = rank + self.num_replicas = num_replicas + self.shuffle = shuffle + self.seed = seed + self.window_size = window_size + + def __iter__(self): + order = np.arange(len(self.dataset)) + rnd = None + window = 0 + if self.shuffle: + rnd = np.random.RandomState(self.seed) + rnd.shuffle(order) + window = int(np.rint(order.size * self.window_size)) + + idx = 0 + while True: + i = idx % order.size + if idx % self.num_replicas == self.rank: + yield order[i] + if window >= 2: + j = (i - rnd.randint(window)) % order.size + order[i], order[j] = order[j], order[i] + idx += 1 + +#---------------------------------------------------------------------------- +# Utilities for operating with torch.nn.Module parameters and buffers. + +def params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.parameters()) + list(module.buffers()) + +def named_params_and_buffers(module): + assert isinstance(module, torch.nn.Module) + return list(module.named_parameters()) + list(module.named_buffers()) + +@torch.no_grad() +def copy_params_and_buffers(src_module, dst_module, require_all=False): + assert isinstance(src_module, torch.nn.Module) + assert isinstance(dst_module, torch.nn.Module) + src_tensors = dict(named_params_and_buffers(src_module)) + for name, tensor in named_params_and_buffers(dst_module): + assert (name in src_tensors) or (not require_all) + if name in src_tensors: + tensor.copy_(src_tensors[name]) + +#---------------------------------------------------------------------------- +# Context manager for easily enabling/disabling DistributedDataParallel +# synchronization. + +@contextlib.contextmanager +def ddp_sync(module, sync): + assert isinstance(module, torch.nn.Module) + if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): + yield + else: + with module.no_sync(): + yield + +#---------------------------------------------------------------------------- +# Check DistributedDataParallel consistency across processes. + +def check_ddp_consistency(module, ignore_regex=None): + assert isinstance(module, torch.nn.Module) + for name, tensor in named_params_and_buffers(module): + fullname = type(module).__name__ + '.' + name + if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): + continue + tensor = tensor.detach() + if tensor.is_floating_point(): + tensor = nan_to_num(tensor) + other = tensor.clone() + torch.distributed.broadcast(tensor=other, src=0) + assert (tensor == other).all(), fullname + +#---------------------------------------------------------------------------- +# Print summary table of module hierarchy. + +def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): + assert isinstance(module, torch.nn.Module) + assert not isinstance(module, torch.jit.ScriptModule) + assert isinstance(inputs, (tuple, list)) + + # Register hooks. 
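+    # The forward pre-hook increments a shared nesting counter and the post-hook
+    # decrements it, so submodules nested deeper than max_nesting are excluded;
+    # each recorded entry keeps the module and its tensor outputs for the table below.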
+ entries = [] + nesting = [0] + def pre_hook(_mod, _inputs): + nesting[0] += 1 + def post_hook(mod, _inputs, outputs): + nesting[0] -= 1 + if nesting[0] <= max_nesting: + outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] + outputs = [t for t in outputs if isinstance(t, torch.Tensor)] + entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) + hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] + hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] + + # Run module. + outputs = module(*inputs) + for hook in hooks: + hook.remove() + + # Identify unique outputs, parameters, and buffers. + tensors_seen = set() + for e in entries: + e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] + e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] + e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] + tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} + + # Filter out redundant entries. + if skip_redundant: + entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] + + # Construct table. + rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] + rows += [['---'] * len(rows[0])] + param_total = 0 + buffer_total = 0 + submodule_names = {mod: name for name, mod in module.named_modules()} + for e in entries: + name = '' if e.mod is module else submodule_names[e.mod] + param_size = sum(t.numel() for t in e.unique_params) + buffer_size = sum(t.numel() for t in e.unique_buffers) + output_shapes = [str(list(t.shape)) for t in e.outputs] + output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] + rows += [[ + name + (':0' if len(e.outputs) >= 2 else ''), + str(param_size) if param_size else '-', + str(buffer_size) if buffer_size else '-', + (output_shapes + ['-'])[0], + (output_dtypes + ['-'])[0], + ]] + for idx in range(1, len(e.outputs)): + rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] + param_total += param_size + buffer_total += buffer_size + rows += [['---'] * len(rows[0])] + rows += [['Total', str(param_total), str(buffer_total), '-', '-']] + + # Print table. + widths = [max(len(cell) for cell in column) for column in zip(*rows)] + print() + for row in rows: + print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) + print() + return outputs + +#---------------------------------------------------------------------------- diff --git a/edm/torch_utils/persistence.py b/edm/torch_utils/persistence.py new file mode 100644 index 0000000..fbecbe2 --- /dev/null +++ b/edm/torch_utils/persistence.py @@ -0,0 +1,257 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Facilities for pickling Python code alongside other data. + +The pickled code is automatically imported into a separate Python module +during unpickling. 
This way, any previously exported pickles will remain +usable even if the original code is no longer available, or if the current +version of the code is not consistent with what was originally pickled.""" + +import sys +import pickle +import io +import inspect +import copy +import uuid +import types +import dnnlib + +#---------------------------------------------------------------------------- + +_version = 6 # internal version number +_decorators = set() # {decorator_class, ...} +_import_hooks = [] # [hook_function, ...] +_module_to_src_dict = dict() # {module: src, ...} +_src_to_module_dict = dict() # {src: module, ...} + +#---------------------------------------------------------------------------- + +def persistent_class(orig_class): + r"""Class decorator that extends a given class to save its source code + when pickled. + + Example: + + from torch_utils import persistence + + @persistence.persistent_class + class MyNetwork(torch.nn.Module): + def __init__(self, num_inputs, num_outputs): + super().__init__() + self.fc = MyLayer(num_inputs, num_outputs) + ... + + @persistence.persistent_class + class MyLayer(torch.nn.Module): + ... + + When pickled, any instance of `MyNetwork` and `MyLayer` will save its + source code alongside other internal state (e.g., parameters, buffers, + and submodules). This way, any previously exported pickle will remain + usable even if the class definitions have been modified or are no + longer available. + + The decorator saves the source code of the entire Python module + containing the decorated class. It does *not* save the source code of + any imported modules. Thus, the imported modules must be available + during unpickling, also including `torch_utils.persistence` itself. + + It is ok to call functions defined in the same module from the + decorated class. However, if the decorated class depends on other + classes defined in the same module, they must be decorated as well. + This is illustrated in the above example in the case of `MyLayer`. + + It is also possible to employ the decorator just-in-time before + calling the constructor. For example: + + cls = MyLayer + if want_to_make_it_persistent: + cls = persistence.persistent_class(cls) + layer = cls(num_inputs, num_outputs) + + As an additional feature, the decorator also keeps track of the + arguments that were used to construct each instance of the decorated + class. The arguments can be queried via `obj.init_args` and + `obj.init_kwargs`, and they are automatically pickled alongside other + object state. This feature can be disabled on a per-instance basis + by setting `self._record_init_args = False` in the constructor. 
+ + A typical use case is to first unpickle a previous instance of a + persistent class, and then upgrade it to use the latest version of + the source code: + + with open('old_pickle.pkl', 'rb') as f: + old_net = pickle.load(f) + new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) + misc.copy_params_and_buffers(old_net, new_net, require_all=True) + """ + assert isinstance(orig_class, type) + if is_persistent(orig_class): + return orig_class + + assert orig_class.__module__ in sys.modules + orig_module = sys.modules[orig_class.__module__] + orig_module_src = _module_to_src(orig_module) + + class Decorator(orig_class): + _orig_module_src = orig_module_src + _orig_class_name = orig_class.__name__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + record_init_args = getattr(self, '_record_init_args', True) + self._init_args = copy.deepcopy(args) if record_init_args else None + self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None + assert orig_class.__name__ in orig_module.__dict__ + _check_pickleable(self.__reduce__()) + + @property + def init_args(self): + assert self._init_args is not None + return copy.deepcopy(self._init_args) + + @property + def init_kwargs(self): + assert self._init_kwargs is not None + return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) + + def __reduce__(self): + fields = list(super().__reduce__()) + fields += [None] * max(3 - len(fields), 0) + if fields[0] is not _reconstruct_persistent_obj: + meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) + fields[0] = _reconstruct_persistent_obj # reconstruct func + fields[1] = (meta,) # reconstruct args + fields[2] = None # state dict + return tuple(fields) + + Decorator.__name__ = orig_class.__name__ + Decorator.__module__ = orig_class.__module__ + _decorators.add(Decorator) + return Decorator + +#---------------------------------------------------------------------------- + +def is_persistent(obj): + r"""Test whether the given object or class is persistent, i.e., + whether it will save its source code when pickled. + """ + try: + if obj in _decorators: + return True + except TypeError: + pass + return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck + +#---------------------------------------------------------------------------- + +def import_hook(hook): + r"""Register an import hook that is called whenever a persistent object + is being unpickled. A typical use case is to patch the pickled source + code to avoid errors and inconsistencies when the API of some imported + module has changed. + + The hook should have the following signature: + + hook(meta) -> modified meta + + `meta` is an instance of `dnnlib.EasyDict` with the following fields: + + type: Type of the persistent object, e.g. `'class'`. + version: Internal version number of `torch_utils.persistence`. + module_src Original source code of the Python module. + class_name: Class name in the original Python module. + state: Internal state of the object. + + Example: + + @persistence.import_hook + def wreck_my_network(meta): + if meta.class_name == 'MyNetwork': + print('MyNetwork is being imported. 
I will wreck it!') + meta.module_src = meta.module_src.replace("True", "False") + return meta + """ + assert callable(hook) + _import_hooks.append(hook) + +#---------------------------------------------------------------------------- + +def _reconstruct_persistent_obj(meta): + r"""Hook that is called internally by the `pickle` module to unpickle + a persistent object. + """ + meta = dnnlib.EasyDict(meta) + meta.state = dnnlib.EasyDict(meta.state) + for hook in _import_hooks: + meta = hook(meta) + assert meta is not None + + assert meta.version == _version + module = _src_to_module(meta.module_src) + + assert meta.type == 'class' + orig_class = module.__dict__[meta.class_name] + decorator_class = persistent_class(orig_class) + obj = decorator_class.__new__(decorator_class) + + setstate = getattr(obj, '__setstate__', None) + if callable(setstate): + setstate(meta.state) # pylint: disable=not-callable + else: + obj.__dict__.update(meta.state) + return obj + +#---------------------------------------------------------------------------- + +def _module_to_src(module): + r"""Query the source code of a given Python module. + """ + src = _module_to_src_dict.get(module, None) + if src is None: + src = inspect.getsource(module) + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + return src + +def _src_to_module(src): + r"""Get or create a Python module for the given source code. + """ + module = _src_to_module_dict.get(src, None) + if module is None: + module_name = "_imported_module_" + uuid.uuid4().hex + module = types.ModuleType(module_name) + sys.modules[module_name] = module + _module_to_src_dict[module] = src + _src_to_module_dict[src] = module + exec(src, module.__dict__) # pylint: disable=exec-used + return module + +#---------------------------------------------------------------------------- + +def _check_pickleable(obj): + r"""Check that the given object is pickleable, raising an exception if + it is not. This function is expected to be considerably more efficient + than actually pickling the object. + """ + def recurse(obj): + if isinstance(obj, (list, tuple, set)): + return [recurse(x) for x in obj] + if isinstance(obj, dict): + return [[recurse(x), recurse(y)] for x, y in obj.items()] + if isinstance(obj, (str, int, float, bool, bytes, bytearray)): + return None # Python primitive types are pickleable. + if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: + return None # NumPy arrays and PyTorch tensors are pickleable. + if is_persistent(obj): + return None # Persistent objects are pickleable, by virtue of the constructor check. + return obj + with io.BytesIO() as f: + pickle.dump(recurse(obj), f) + +#---------------------------------------------------------------------------- diff --git a/edm/torch_utils/training_stats.py b/edm/torch_utils/training_stats.py new file mode 100644 index 0000000..727c4e8 --- /dev/null +++ b/edm/torch_utils/training_stats.py @@ -0,0 +1,272 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Facilities for reporting and collecting training statistics across +multiple processes and devices. 
The interface is designed to minimize +synchronization overhead as well as the amount of boilerplate in user +code.""" + +import re +import numpy as np +import torch +import dnnlib + +from . import misc + +#---------------------------------------------------------------------------- + +_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] +_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. +_counter_dtype = torch.float64 # Data type to use for the internal counters. +_rank = 0 # Rank of the current process. +_sync_device = None # Device to use for multiprocess communication. None = single-process. +_sync_called = False # Has _sync() been called yet? +_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor +_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor + +#---------------------------------------------------------------------------- + +def init_multiprocessing(rank, sync_device): + r"""Initializes `torch_utils.training_stats` for collecting statistics + across multiple processes. + + This function must be called after + `torch.distributed.init_process_group()` and before `Collector.update()`. + The call is not necessary if multi-process collection is not needed. + + Args: + rank: Rank of the current process. + sync_device: PyTorch device to use for inter-process + communication, or None to disable multi-process + collection. Typically `torch.device('cuda', rank)`. + """ + global _rank, _sync_device + assert not _sync_called + _rank = rank + _sync_device = sync_device + +#---------------------------------------------------------------------------- + +@misc.profiled_function +def report(name, value): + r"""Broadcasts the given set of scalars to all interested instances of + `Collector`, across device and process boundaries. + + This function is expected to be extremely cheap and can be safely + called from anywhere in the training loop, loss function, or inside a + `torch.nn.Module`. + + Warning: The current implementation expects the set of unique names to + be consistent across processes. Please make sure that `report()` is + called at least once for each unique name by each process, and in the + same order. If a given process has no scalars to broadcast, it can do + `report(name, [])` (empty list). + + Args: + name: Arbitrary string specifying the name of the statistic. + Averages are accumulated separately for each unique name. + value: Arbitrary set of scalars. Can be a list, tuple, + NumPy array, PyTorch tensor, or Python scalar. + + Returns: + The same `value` that was passed in. + """ + if name not in _counters: + _counters[name] = dict() + + elems = torch.as_tensor(value) + if elems.numel() == 0: + return value + + elems = elems.detach().flatten().to(_reduce_dtype) + moments = torch.stack([ + torch.ones_like(elems).sum(), + elems.sum(), + elems.square().sum(), + ]) + assert moments.ndim == 1 and moments.shape[0] == _num_moments + moments = moments.to(_counter_dtype) + + device = moments.device + if device not in _counters[name]: + _counters[name][device] = torch.zeros_like(moments) + _counters[name][device].add_(moments) + return value + +#---------------------------------------------------------------------------- + +def report0(name, value): + r"""Broadcasts the given set of scalars by the first process (`rank = 0`), + but ignores any scalars provided by the other processes. + See `report()` for further details. 
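+
+    Example (illustrative; `loss_value` stands in for any scalar the caller has):
+
+        report0('Loss/value', loss_value)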
+ """ + report(name, value if _rank == 0 else []) + return value + +#---------------------------------------------------------------------------- + +class Collector: + r"""Collects the scalars broadcasted by `report()` and `report0()` and + computes their long-term averages (mean and standard deviation) over + user-defined periods of time. + + The averages are first collected into internal counters that are not + directly visible to the user. They are then copied to the user-visible + state as a result of calling `update()` and can then be queried using + `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the + internal counters for the next round, so that the user-visible state + effectively reflects averages collected between the last two calls to + `update()`. + + Args: + regex: Regular expression defining which statistics to + collect. The default is to collect everything. + keep_previous: Whether to retain the previous averages if no + scalars were collected on a given round + (default: True). + """ + def __init__(self, regex='.*', keep_previous=True): + self._regex = re.compile(regex) + self._keep_previous = keep_previous + self._cumulative = dict() + self._moments = dict() + self.update() + self._moments.clear() + + def names(self): + r"""Returns the names of all statistics broadcasted so far that + match the regular expression specified at construction time. + """ + return [name for name in _counters if self._regex.fullmatch(name)] + + def update(self): + r"""Copies current values of the internal counters to the + user-visible state and resets them for the next round. + + If `keep_previous=True` was specified at construction time, the + operation is skipped for statistics that have received no scalars + since the last update, retaining their previous averages. + + This method performs a number of GPU-to-CPU transfers and one + `torch.distributed.all_reduce()`. It is intended to be called + periodically in the main training loop, typically once every + N training steps. + """ + if not self._keep_previous: + self._moments.clear() + for name, cumulative in _sync(self.names()): + if name not in self._cumulative: + self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + delta = cumulative - self._cumulative[name] + self._cumulative[name].copy_(cumulative) + if float(delta[0]) != 0: + self._moments[name] = delta + + def _get_delta(self, name): + r"""Returns the raw moments that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + assert self._regex.fullmatch(name) + if name not in self._moments: + self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + return self._moments[name] + + def num(self, name): + r"""Returns the number of scalars that were accumulated for the given + statistic between the last two calls to `update()`, or zero if + no scalars were collected. + """ + delta = self._get_delta(name) + return int(delta[0]) + + def mean(self, name): + r"""Returns the mean of the scalars that were accumulated for the + given statistic between the last two calls to `update()`, or NaN if + no scalars were collected. + """ + delta = self._get_delta(name) + if int(delta[0]) == 0: + return float('nan') + return float(delta[1] / delta[0]) + + def std(self, name): + r"""Returns the standard deviation of the scalars that were + accumulated for the given statistic between the last two calls to + `update()`, or NaN if no scalars were collected. 
+ """ + delta = self._get_delta(name) + if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): + return float('nan') + if int(delta[0]) == 1: + return float(0) + mean = float(delta[1] / delta[0]) + raw_var = float(delta[2] / delta[0]) + return np.sqrt(max(raw_var - np.square(mean), 0)) + + def as_dict(self): + r"""Returns the averages accumulated between the last two calls to + `update()` as an `dnnlib.EasyDict`. The contents are as follows: + + dnnlib.EasyDict( + NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), + ... + ) + """ + stats = dnnlib.EasyDict() + for name in self.names(): + stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) + return stats + + def __getitem__(self, name): + r"""Convenience getter. + `collector[name]` is a synonym for `collector.mean(name)`. + """ + return self.mean(name) + +#---------------------------------------------------------------------------- + +def _sync(names): + r"""Synchronize the global cumulative counters across devices and + processes. Called internally by `Collector.update()`. + """ + if len(names) == 0: + return [] + global _sync_called + _sync_called = True + + # Collect deltas within current rank. + deltas = [] + device = _sync_device if _sync_device is not None else torch.device('cpu') + for name in names: + delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) + for counter in _counters[name].values(): + delta.add_(counter.to(device)) + counter.copy_(torch.zeros_like(counter)) + deltas.append(delta) + deltas = torch.stack(deltas) + + # Sum deltas across ranks. + if _sync_device is not None: + torch.distributed.all_reduce(deltas) + + # Update cumulative values. + deltas = deltas.cpu() + for idx, name in enumerate(names): + if name not in _cumulative: + _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) + _cumulative[name].add_(deltas[idx]) + + # Return name-value pairs. + return [(name, _cumulative[name]) for name in names] + +#---------------------------------------------------------------------------- +# Convenience. + +default_collector = Collector() + +#---------------------------------------------------------------------------- diff --git a/edm/train.py b/edm/train.py new file mode 100644 index 0000000..6851604 --- /dev/null +++ b/edm/train.py @@ -0,0 +1,236 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Train diffusion-based generative model using the techniques described in the +paper "Elucidating the Design Space of Diffusion-Based Generative Models".""" + +import os +import re +import json +import click +import torch +import dnnlib +from torch_utils import distributed as dist +from training import training_loop + +import warnings +warnings.filterwarnings('ignore', 'Grad strides do not match bucket view strides') # False warning printed by PyTorch 1.12. + +#---------------------------------------------------------------------------- +# Parse a comma separated list of numbers or ranges and return a list of ints. 
+# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] + +def parse_int_list(s): + if isinstance(s, list): return s + ranges = [] + range_re = re.compile(r'^(\d+)-(\d+)$') + for p in s.split(','): + m = range_re.match(p) + if m: + ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) + else: + ranges.append(int(p)) + return ranges + +#---------------------------------------------------------------------------- + +@click.command() + +# Main options. +@click.option('--outdir', help='Where to save the results', metavar='DIR', type=str, required=True) +@click.option('--data', help='Path to the dataset', metavar='ZIP|DIR', type=str, required=True) +@click.option('--cond', help='Train class-conditional model', metavar='BOOL', type=bool, default=False, show_default=True) +@click.option('--arch', help='Network architecture', metavar='ddpmpp|ncsnpp|adm', type=click.Choice(['ddpmpp', 'ncsnpp', 'adm']), default='ddpmpp', show_default=True) +@click.option('--precond', help='Preconditioning & loss function', metavar='vp|ve|edm', type=click.Choice(['vp', 've', 'edm']), default='edm', show_default=True) + +# Hyperparameters. +@click.option('--duration', help='Training duration', metavar='MIMG', type=click.FloatRange(min=0, min_open=True), default=200, show_default=True) +@click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True) +@click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1)) +@click.option('--cbase', help='Channel multiplier [default: varies]', metavar='INT', type=int) +@click.option('--cres', help='Channels per resolution [default: varies]', metavar='LIST', type=parse_int_list) +@click.option('--lr', help='Learning rate', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=10e-4, show_default=True) +@click.option('--ema', help='EMA half-life', metavar='MIMG', type=click.FloatRange(min=0), default=0.5, show_default=True) +@click.option('--dropout', help='Dropout probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.13, show_default=True) +@click.option('--augment', help='Augment probability', metavar='FLOAT', type=click.FloatRange(min=0, max=1), default=0.12, show_default=True) +@click.option('--xflip', help='Enable dataset x-flips', metavar='BOOL', type=bool, default=False, show_default=True) + +# Performance-related. +@click.option('--fp16', help='Enable mixed-precision training', metavar='BOOL', type=bool, default=False, show_default=True) +@click.option('--ls', help='Loss scaling', metavar='FLOAT', type=click.FloatRange(min=0, min_open=True), default=1, show_default=True) +@click.option('--bench', help='Enable cuDNN benchmarking', metavar='BOOL', type=bool, default=True, show_default=True) +@click.option('--cache', help='Cache dataset in CPU memory', metavar='BOOL', type=bool, default=True, show_default=True) +@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=1), default=1, show_default=True) + +# I/O-related. 
+@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str) +@click.option('--nosubdir', help='Do not create a subdirectory for results', is_flag=True) +@click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=50, show_default=True) +@click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True) +@click.option('--dump', help='How often to dump state', metavar='TICKS', type=click.IntRange(min=1), default=500, show_default=True) +@click.option('--seed', help='Random seed [default: random]', metavar='INT', type=int) +@click.option('--transfer', help='Transfer learning from network pickle', metavar='PKL|URL', type=str) +@click.option('--resume', help='Resume from previous training state', metavar='PT', type=str) +@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True) + +def main(**kwargs): + """Train diffusion-based generative model using the techniques described in the + paper "Elucidating the Design Space of Diffusion-Based Generative Models". + + Examples: + + \b + # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs + torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \\ + --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp + """ + opts = dnnlib.EasyDict(kwargs) + torch.multiprocessing.set_start_method('spawn') + dist.init() + + # Initialize config dict. + c = dnnlib.EasyDict() + c.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=opts.data, use_labels=opts.cond, xflip=opts.xflip, cache=opts.cache) + c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=opts.workers, prefetch_factor=2) + c.network_kwargs = dnnlib.EasyDict() + c.loss_kwargs = dnnlib.EasyDict() + c.optimizer_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=opts.lr, betas=[0.9,0.999], eps=1e-8) + + # Validate dataset options. + try: + dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs) + dataset_name = dataset_obj.name + c.dataset_kwargs.resolution = dataset_obj.resolution # be explicit about dataset resolution + c.dataset_kwargs.max_size = len(dataset_obj) # be explicit about dataset size + if opts.cond and not dataset_obj.has_labels: + raise click.ClickException('--cond=True requires labels specified in dataset.json') + del dataset_obj # conserve memory + except IOError as err: + raise click.ClickException(f'--data: {err}') + + # Network architecture. + if opts.arch == 'ddpmpp': + c.network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard') + c.network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2]) + elif opts.arch == 'ncsnpp': + c.network_kwargs.update(model_type='SongUNet', embedding_type='fourier', encoder_type='residual', decoder_type='standard') + c.network_kwargs.update(channel_mult_noise=2, resample_filter=[1,3,3,1], model_channels=128, channel_mult=[2,2,2]) + else: + assert opts.arch == 'adm' + c.network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4]) + + # Preconditioning & loss function. 
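+# Each --precond choice selects a matching pair of preconditioning wrapper and loss.
+# The class_name strings below follow the same dnnlib convention used for the dataset
+# above and are resolved via dnnlib.util.construct_class_by_name() when the training
+# loop constructs the network and loss.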
+ if opts.precond == 'vp': + c.network_kwargs.class_name = 'training.networks.VPPrecond' + c.loss_kwargs.class_name = 'training.loss.VPLoss' + elif opts.precond == 've': + c.network_kwargs.class_name = 'training.networks.VEPrecond' + c.loss_kwargs.class_name = 'training.loss.VELoss' + else: + assert opts.precond == 'edm' + c.network_kwargs.class_name = 'training.networks.EDMPrecond' + c.loss_kwargs.class_name = 'training.loss.EDMLoss' + + # Network options. + if opts.cbase is not None: + c.network_kwargs.model_channels = opts.cbase + if opts.cres is not None: + c.network_kwargs.channel_mult = opts.cres + if opts.augment: + c.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', p=opts.augment) + c.augment_kwargs.update(xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1) + c.network_kwargs.augment_dim = 9 + c.network_kwargs.update(dropout=opts.dropout, use_fp16=opts.fp16) + + # Training options. + c.total_kimg = max(int(opts.duration * 1000), 1) + c.ema_halflife_kimg = int(opts.ema * 1000) + c.update(batch_size=opts.batch, batch_gpu=opts.batch_gpu) + c.update(loss_scaling=opts.ls, cudnn_benchmark=opts.bench) + c.update(kimg_per_tick=opts.tick, snapshot_ticks=opts.snap, state_dump_ticks=opts.dump) + + # Random seed. + if opts.seed is not None: + c.seed = opts.seed + else: + seed = torch.randint(1 << 31, size=[], device=torch.device('cuda')) + torch.distributed.broadcast(seed, src=0) + c.seed = int(seed) + + # Transfer learning and resume. + if opts.transfer is not None: + if opts.resume is not None: + raise click.ClickException('--transfer and --resume cannot be specified at the same time') + c.resume_pkl = opts.transfer + c.ema_rampup_ratio = None + elif opts.resume is not None: + match = re.fullmatch(r'training-state-(\d+).pt', os.path.basename(opts.resume)) + if not match or not os.path.isfile(opts.resume): + raise click.ClickException('--resume must point to training-state-*.pt from a previous training run') + c.resume_pkl = os.path.join(os.path.dirname(opts.resume), f'network-snapshot-{match.group(1)}.pkl') + c.resume_kimg = int(match.group(1)) + c.resume_state_dump = opts.resume + + # Description string. + cond_str = 'cond' if c.dataset_kwargs.use_labels else 'uncond' + dtype_str = 'fp16' if c.network_kwargs.use_fp16 else 'fp32' + desc = f'{dataset_name:s}-{cond_str:s}-{opts.arch:s}-{opts.precond:s}-gpus{dist.get_world_size():d}-batch{c.batch_size:d}-{dtype_str:s}' + if opts.desc is not None: + desc += f'-{opts.desc}' + + # Pick output directory. + if dist.get_rank() != 0: + c.run_dir = None + elif opts.nosubdir: + c.run_dir = opts.outdir + else: + prev_run_dirs = [] + if os.path.isdir(opts.outdir): + prev_run_dirs = [x for x in os.listdir(opts.outdir) if os.path.isdir(os.path.join(opts.outdir, x))] + prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] + prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] + cur_run_id = max(prev_run_ids, default=-1) + 1 + c.run_dir = os.path.join(opts.outdir, f'{cur_run_id:05d}-{desc}') + assert not os.path.exists(c.run_dir) + + # Print options. 
+ dist.print0() + dist.print0('Training options:') + dist.print0(json.dumps(c, indent=2)) + dist.print0() + dist.print0(f'Output directory: {c.run_dir}') + dist.print0(f'Dataset path: {c.dataset_kwargs.path}') + dist.print0(f'Class-conditional: {c.dataset_kwargs.use_labels}') + dist.print0(f'Network architecture: {opts.arch}') + dist.print0(f'Preconditioning & loss: {opts.precond}') + dist.print0(f'Number of GPUs: {dist.get_world_size()}') + dist.print0(f'Batch size: {c.batch_size}') + dist.print0(f'Mixed-precision: {c.network_kwargs.use_fp16}') + dist.print0() + + # Dry run? + if opts.dry_run: + dist.print0('Dry run; exiting.') + return + + # Create output directory. + dist.print0('Creating output directory...') + if dist.get_rank() == 0: + os.makedirs(c.run_dir, exist_ok=True) + with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f: + json.dump(c, f, indent=2) + dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True) + + # Train. + training_loop.training_loop(**c) + +#---------------------------------------------------------------------------- + +if __name__ == "__main__": + main() + +#---------------------------------------------------------------------------- diff --git a/edm/training/__init__.py b/edm/training/__init__.py new file mode 100644 index 0000000..d76c2b6 --- /dev/null +++ b/edm/training/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +# empty diff --git a/edm/training/augment.py b/edm/training/augment.py new file mode 100644 index 0000000..a8d474d --- /dev/null +++ b/edm/training/augment.py @@ -0,0 +1,330 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Augmentation pipeline used in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models". +Built around the same concepts that were originally proposed in the paper +"Training Generative Adversarial Networks with Limited Data".""" + +import numpy as np +import torch +from torch_utils import persistence +from torch_utils import misc + +#---------------------------------------------------------------------------- +# Coefficients of various wavelet decomposition low-pass filters. 
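+# Standard low-pass (scaling) filter coefficients for the Haar, Daubechies (db1-db8)
+# and symlet (sym2-sym8) families; AugmentPipe below uses 'sym6' for the band-limited
+# 2x upsample/downsample that brackets its geometric transformations.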
+ +wavelets = { + 'haar': [0.7071067811865476, 0.7071067811865476], + 'db1': [0.7071067811865476, 0.7071067811865476], + 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523], + 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125], + 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017], + 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236], + 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161], + 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025], + 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569], + 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427], + 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728], + 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148], + 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255], + 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609], +} + +#---------------------------------------------------------------------------- +# Helpers for constructing transformation matrices. 
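+# The helpers below build 3x3 / 4x4 homogeneous transformation matrices that act on
+# column vectors; scalar arguments produce a single constant matrix, while tensor
+# arguments produce one matrix per batch element.
+#
+# Illustrative composition (not part of the original file): rotate a batch of 8
+# coordinate frames about the point (0.5, 0.5):
+#
+#     theta = torch.rand([8]) * np.pi
+#     G = translate2d(0.5, 0.5) @ rotate2d(theta) @ translate2d_inv(0.5, 0.5)  # [8, 3, 3]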
+ +def matrix(*rows, device=None): + assert all(len(row) == len(rows[0]) for row in rows) + elems = [x for row in rows for x in row] + ref = [x for x in elems if isinstance(x, torch.Tensor)] + if len(ref) == 0: + return misc.constant(np.asarray(rows), device=device) + assert device is None or device == ref[0].device + elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems] + return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) + +def translate2d(tx, ty, **kwargs): + return matrix( + [1, 0, tx], + [0, 1, ty], + [0, 0, 1], + **kwargs) + +def translate3d(tx, ty, tz, **kwargs): + return matrix( + [1, 0, 0, tx], + [0, 1, 0, ty], + [0, 0, 1, tz], + [0, 0, 0, 1], + **kwargs) + +def scale2d(sx, sy, **kwargs): + return matrix( + [sx, 0, 0], + [0, sy, 0], + [0, 0, 1], + **kwargs) + +def scale3d(sx, sy, sz, **kwargs): + return matrix( + [sx, 0, 0, 0], + [0, sy, 0, 0], + [0, 0, sz, 0], + [0, 0, 0, 1], + **kwargs) + +def rotate2d(theta, **kwargs): + return matrix( + [torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], + [0, 0, 1], + **kwargs) + +def rotate3d(v, theta, **kwargs): + vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2] + s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c + return matrix( + [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0], + [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0], + [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0], + [0, 0, 0, 1], + **kwargs) + +def translate2d_inv(tx, ty, **kwargs): + return translate2d(-tx, -ty, **kwargs) + +def scale2d_inv(sx, sy, **kwargs): + return scale2d(1 / sx, 1 / sy, **kwargs) + +def rotate2d_inv(theta, **kwargs): + return rotate2d(-theta, **kwargs) + +#---------------------------------------------------------------------------- +# Augmentation pipeline main class. +# All augmentations are disabled by default; individual augmentations can +# be enabled by setting their probability multipliers to 1. + +@persistence.persistent_class +class AugmentPipe: + def __init__(self, p=1, + xflip=0, yflip=0, rotate_int=0, translate_int=0, translate_int_max=0.125, + scale=0, rotate_frac=0, aniso=0, translate_frac=0, scale_std=0.2, rotate_frac_max=1, aniso_std=0.2, aniso_rotate_prob=0.5, translate_frac_std=0.125, + brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1, + ): + super().__init__() + self.p = float(p) # Overall multiplier for augmentation probability. + + # Pixel blitting. + self.xflip = float(xflip) # Probability multiplier for x-flip. + self.yflip = float(yflip) # Probability multiplier for y-flip. + self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation. + self.translate_int = float(translate_int) # Probability multiplier for integer translation. + self.translate_int_max = float(translate_int_max) # Range of integer translation, relative to image dimensions. + + # Geometric transformations. + self.scale = float(scale) # Probability multiplier for isotropic scaling. + self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation. + self.aniso = float(aniso) # Probability multiplier for anisotropic scaling. + self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation. + self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling. + self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle. 
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling. + self.aniso_rotate_prob = float(aniso_rotate_prob) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame. + self.translate_frac_std = float(translate_frac_std) # Standard deviation of frational translation, relative to image dimensions. + + # Color transformations. + self.brightness = float(brightness) # Probability multiplier for brightness. + self.contrast = float(contrast) # Probability multiplier for contrast. + self.lumaflip = float(lumaflip) # Probability multiplier for luma flip. + self.hue = float(hue) # Probability multiplier for hue rotation. + self.saturation = float(saturation) # Probability multiplier for saturation. + self.brightness_std = float(brightness_std) # Standard deviation of brightness. + self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast. + self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle. + self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation. + + def __call__(self, images): + N, C, H, W = images.shape + device = images.device + labels = [torch.zeros([images.shape[0], 0], device=device)] + + # --------------- + # Pixel blitting. + # --------------- + + if self.xflip > 0: + w = torch.randint(2, [N, 1, 1, 1], device=device) + w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w)) + images = torch.where(w == 1, images.flip(3), images) + labels += [w] + + if self.yflip > 0: + w = torch.randint(2, [N, 1, 1, 1], device=device) + w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w)) + images = torch.where(w == 1, images.flip(2), images) + labels += [w] + + if self.rotate_int > 0: + w = torch.randint(4, [N, 1, 1, 1], device=device) + w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w)) + images = torch.where((w == 1) | (w == 2), images.flip(3), images) + images = torch.where((w == 2) | (w == 3), images.flip(2), images) + images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images) + labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)] + + if self.translate_int > 0: + w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1 + w = torch.where(torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w)) + tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64) + ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64) + b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing='ij') + x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs() + y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs() + images = images.flatten()[(((b * C) + c) * H + y) * W + x] + labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)] + + # ------------------------------------------------ + # Select parameters for geometric transformations. 
+ # ------------------------------------------------ + + I_3 = torch.eye(3, device=device) + G_inv = I_3 + + if self.scale > 0: + w = torch.randn([N], device=device) + w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w)) + s = w.mul(self.scale_std).exp2() + G_inv = G_inv @ scale2d_inv(s, s) + labels += [w] + + if self.rotate_frac > 0: + w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max) + w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w)) + G_inv = G_inv @ rotate2d_inv(-w) + labels += [w.cos() - 1, w.sin()] + + if self.aniso > 0: + w = torch.randn([N], device=device) + r = (torch.rand([N], device=device) * 2 - 1) * np.pi + w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w)) + r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r)) + s = w.mul(self.aniso_std).exp2() + G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r) + labels += [w * r.cos(), w * r.sin()] + + if self.translate_frac > 0: + w = torch.randn([2, N], device=device) + w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w)) + G_inv = G_inv @ translate2d_inv(w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std)) + labels += [w[0], w[1]] + + # ---------------------------------- + # Execute geometric transformations. + # ---------------------------------- + + if G_inv is not I_3: + cx = (W - 1) / 2 + cy = (H - 1) / 2 + cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz] + cp = G_inv @ cp.t() # [batch, xyz, idx] + Hz = np.asarray(wavelets['sym6'], dtype=np.float32) + Hz_pad = len(Hz) // 4 + margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx] + margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1] + margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) + margin = margin.max(misc.constant([0, 0] * 2, device=device)) + margin = margin.min(misc.constant([W - 1, H - 1] * 2, device=device)) + mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) + + # Pad image and adjust origin. + images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect') + G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv + + # Upsample. + conv_weight = misc.constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) + conv_pad = (len(Hz) + 1) // 2 + images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(N, C, images.shape[2], -1)[:, :, :, :-1] + images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0,conv_pad]) + images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(N, C, -1, images.shape[3])[:, :, :-1, :] + images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad,0]) + G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device) + G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device) + + # Execute transformation. 
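+            # The padded, 2x-upsampled image is warped in a single pass: G_inv is
+            # rescaled from pixel coordinates to the normalized [-1, 1] range that
+            # affine_grid() expects, the grid is applied with grid_sample(), and the
+            # result is low-pass filtered, downsampled, and cropped back to H x W.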
+ shape = [N, C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2] + G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False) + images = torch.nn.functional.grid_sample(images, grid, mode='bilinear', padding_mode='zeros', align_corners=False) + + # Downsample and crop. + conv_weight = misc.constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile([images.shape[1], 1, 1]) + conv_pad = (len(Hz) - 1) // 2 + images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1,2], padding=[0,conv_pad])[:, :, :, Hz_pad : -Hz_pad] + images = torch.nn.functional.conv2d(images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2,1], padding=[conv_pad,0])[:, :, Hz_pad : -Hz_pad, :] + + # -------------------------------------------- + # Select parameters for color transformations. + # -------------------------------------------- + + I_4 = torch.eye(4, device=device) + M = I_4 + luma_axis = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) + + if self.brightness > 0: + w = torch.randn([N], device=device) + w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w)) + b = w * self.brightness_std + M = translate3d(b, b, b) @ M + labels += [w] + + if self.contrast > 0: + w = torch.randn([N], device=device) + w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w)) + c = w.mul(self.contrast_std).exp2() + M = scale3d(c, c, c) @ M + labels += [w] + + if self.lumaflip > 0: + w = torch.randint(2, [N, 1, 1], device=device) + w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w)) + M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M + labels += [w] + + if self.hue > 0: + w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max) + w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w)) + M = rotate3d(luma_axis, w) @ M + labels += [w.cos() - 1, w.sin()] + + if self.saturation > 0: + w = torch.randn([N, 1, 1], device=device) + w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w)) + M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M + labels += [w] + + # ------------------------------ + # Execute color transformations. + # ------------------------------ + + if M is not I_4: + images = images.reshape([N, C, H * W]) + if C == 3: + images = M[:, :3, :3] @ images + M[:, :3, 3:] + elif C == 1: + M = M[:, :3, :].mean(dim=1, keepdims=True) + images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:] + else: + raise ValueError('Image must be RGB (3 channels) or L (1 channel)') + images = images.reshape([N, C, H, W]) + + labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1) + return images, labels + +#---------------------------------------------------------------------------- diff --git a/edm/training/dataset.py b/edm/training/dataset.py new file mode 100644 index 0000000..ef4bd02 --- /dev/null +++ b/edm/training/dataset.py @@ -0,0 +1,250 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. 
+# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Streaming images and labels from datasets created with dataset_tool.py.""" + +import os +import numpy as np +import zipfile +import PIL.Image +import json +import torch +import dnnlib + +try: + import pyspng +except ImportError: + pyspng = None + +#---------------------------------------------------------------------------- +# Abstract base class for datasets. + +class Dataset(torch.utils.data.Dataset): + def __init__(self, + name, # Name of the dataset. + raw_shape, # Shape of the raw image data (NCHW). + max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. + use_labels = False, # Enable conditioning labels? False = label dimension is zero. + xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. + random_seed = 0, # Random seed to use when applying max_size. + cache = False, # Cache images in CPU memory? + ): + self._name = name + self._raw_shape = list(raw_shape) + self._use_labels = use_labels + self._cache = cache + self._cached_images = dict() # {raw_idx: np.ndarray, ...} + self._raw_labels = None + self._label_shape = None + + # Apply max_size. + self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) + if (max_size is not None) and (self._raw_idx.size > max_size): + np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) + self._raw_idx = np.sort(self._raw_idx[:max_size]) + + # Apply xflip. + self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) + if xflip: + self._raw_idx = np.tile(self._raw_idx, 2) + self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) + + def _get_raw_labels(self): + if self._raw_labels is None: + self._raw_labels = self._load_raw_labels() if self._use_labels else None + if self._raw_labels is None: + self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) + assert isinstance(self._raw_labels, np.ndarray) + assert self._raw_labels.shape[0] == self._raw_shape[0] + assert self._raw_labels.dtype in [np.float32, np.int64] + if self._raw_labels.dtype == np.int64: + assert self._raw_labels.ndim == 1 + assert np.all(self._raw_labels >= 0) + return self._raw_labels + + def close(self): # to be overridden by subclass + pass + + def _load_raw_image(self, raw_idx): # to be overridden by subclass + raise NotImplementedError + + def _load_raw_labels(self): # to be overridden by subclass + raise NotImplementedError + + def __getstate__(self): + return dict(self.__dict__, _raw_labels=None) + + def __del__(self): + try: + self.close() + except: + pass + + def __len__(self): + return self._raw_idx.size + + def __getitem__(self, idx): + raw_idx = self._raw_idx[idx] + image = self._cached_images.get(raw_idx, None) + if image is None: + image = self._load_raw_image(raw_idx) + if self._cache: + self._cached_images[raw_idx] = image + assert isinstance(image, np.ndarray) + assert list(image.shape) == self.image_shape + assert image.dtype == np.uint8 + if self._xflip[idx]: + assert image.ndim == 3 # CHW + image = image[:, :, ::-1] + return image.copy(), self.get_label(idx) + + def get_label(self, idx): + label = self._get_raw_labels()[self._raw_idx[idx]] + if label.dtype == np.int64: + onehot = np.zeros(self.label_shape, dtype=np.float32) + onehot[label] = 1 + label = onehot + return label.copy() + + def get_details(self, idx): + d = dnnlib.EasyDict() + d.raw_idx = int(self._raw_idx[idx]) + d.xflip 
= (int(self._xflip[idx]) != 0) + d.raw_label = self._get_raw_labels()[d.raw_idx].copy() + return d + + @property + def name(self): + return self._name + + @property + def image_shape(self): + return list(self._raw_shape[1:]) + + @property + def num_channels(self): + assert len(self.image_shape) == 3 # CHW + return self.image_shape[0] + + @property + def resolution(self): + assert len(self.image_shape) == 3 # CHW + assert self.image_shape[1] == self.image_shape[2] + return self.image_shape[1] + + @property + def label_shape(self): + if self._label_shape is None: + raw_labels = self._get_raw_labels() + if raw_labels.dtype == np.int64: + self._label_shape = [int(np.max(raw_labels)) + 1] + else: + self._label_shape = raw_labels.shape[1:] + return list(self._label_shape) + + @property + def label_dim(self): + assert len(self.label_shape) == 1 + return self.label_shape[0] + + @property + def has_labels(self): + return any(x != 0 for x in self.label_shape) + + @property + def has_onehot_labels(self): + return self._get_raw_labels().dtype == np.int64 + +#---------------------------------------------------------------------------- +# Dataset subclass that loads images recursively from the specified directory +# or ZIP file. + +class ImageFolderDataset(Dataset): + def __init__(self, + path, # Path to directory or zip. + resolution = None, # Ensure specific resolution, None = highest available. + use_pyspng = True, # Use pyspng if available? + **super_kwargs, # Additional arguments for the Dataset base class. + ): + self._path = path + self._use_pyspng = use_pyspng + self._zipfile = None + + if os.path.isdir(self._path): + self._type = 'dir' + self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} + elif self._file_ext(self._path) == '.zip': + self._type = 'zip' + self._all_fnames = set(self._get_zipfile().namelist()) + else: + raise IOError('Path must point to a directory or zip') + + PIL.Image.init() + self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) + if len(self._image_fnames) == 0: + raise IOError('No image files found in the specified path') + + name = os.path.splitext(os.path.basename(self._path))[0] + raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) + if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): + raise IOError('Image files do not match the specified resolution') + super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) + + @staticmethod + def _file_ext(fname): + return os.path.splitext(fname)[1].lower() + + def _get_zipfile(self): + assert self._type == 'zip' + if self._zipfile is None: + self._zipfile = zipfile.ZipFile(self._path) + return self._zipfile + + def _open_file(self, fname): + if self._type == 'dir': + return open(os.path.join(self._path, fname), 'rb') + if self._type == 'zip': + return self._get_zipfile().open(fname, 'r') + return None + + def close(self): + try: + if self._zipfile is not None: + self._zipfile.close() + finally: + self._zipfile = None + + def __getstate__(self): + return dict(super().__getstate__(), _zipfile=None) + + def _load_raw_image(self, raw_idx): + fname = self._image_fnames[raw_idx] + with self._open_file(fname) as f: + if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': + image = pyspng.load(f.read()) + else: + image = np.array(PIL.Image.open(f)) + if image.ndim == 2: + image = image[:, :, 
np.newaxis] # HW => HWC + image = image.transpose(2, 0, 1) # HWC => CHW + return image + + def _load_raw_labels(self): + fname = 'dataset.json' + if fname not in self._all_fnames: + return None + with self._open_file(fname) as f: + labels = json.load(f)['labels'] + if labels is None: + return None + labels = dict(labels) + labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] + labels = np.array(labels) + labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) + return labels + +#---------------------------------------------------------------------------- diff --git a/edm/training/loss.py b/edm/training/loss.py new file mode 100644 index 0000000..ff045c5 --- /dev/null +++ b/edm/training/loss.py @@ -0,0 +1,82 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Loss functions used in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" + +import torch +from torch_utils import persistence + +#---------------------------------------------------------------------------- +# Loss function corresponding to the variance preserving (VP) formulation +# from the paper "Score-Based Generative Modeling through Stochastic +# Differential Equations". + +@persistence.persistent_class +class VPLoss: + def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): + self.beta_d = beta_d + self.beta_min = beta_min + self.epsilon_t = epsilon_t + + def __call__(self, net, images, labels, augment_pipe=None): + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) + weight = 1 / sigma ** 2 + y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + + def sigma(self, t): + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() + +#---------------------------------------------------------------------------- +# Loss function corresponding to the variance exploding (VE) formulation +# from the paper "Score-Based Generative Modeling through Stochastic +# Differential Equations". + +@persistence.persistent_class +class VELoss: + def __init__(self, sigma_min=0.02, sigma_max=100): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def __call__(self, net, images, labels, augment_pipe=None): + rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) + sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) + weight = 1 / sigma ** 2 + y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + +#---------------------------------------------------------------------------- +# Improved loss function proposed in the paper "Elucidating the Design Space +# of Diffusion-Based Generative Models" (EDM). 
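For intuition, the EDMLoss class below samples one noise level per image from a log-normal distribution and reweights the squared denoising error so that every noise level contributes comparably. A minimal standalone sketch of that sampling, reusing the class defaults (illustrative only, not part of the patched files):

```python
import torch

# EDM defaults, matching the EDMLoss definition that follows.
P_mean, P_std, sigma_data = -1.2, 1.2, 0.5

# One noise level per image: sigma = exp(P_mean + P_std * N(0, 1)), i.e. log-normal.
rnd_normal = torch.randn([8, 1, 1, 1])
sigma = (rnd_normal * P_std + P_mean).exp()

# Per-sample weight lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2,
# which keeps the weighted denoising error on a comparable scale across noise levels.
weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2

print(sigma.flatten().tolist())
print(weight.flatten().tolist())
```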
+ +@persistence.persistent_class +class EDMLoss: + def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): + self.P_mean = P_mean + self.P_std = P_std + self.sigma_data = sigma_data + + def __call__(self, net, images, labels=None, augment_pipe=None): + rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) + n = torch.randn_like(y) * sigma + D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) + loss = weight * ((D_yn - y) ** 2) + return loss + +#---------------------------------------------------------------------------- diff --git a/edm/training/networks.py b/edm/training/networks.py new file mode 100644 index 0000000..d2326c7 --- /dev/null +++ b/edm/training/networks.py @@ -0,0 +1,673 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Model architectures and preconditioning schemes used in the paper +"Elucidating the Design Space of Diffusion-Based Generative Models".""" + +import numpy as np +import torch +from torch_utils import persistence +from torch.nn.functional import silu + +#---------------------------------------------------------------------------- +# Unified routine for initializing weights and biases. + +def weight_init(shape, mode, fan_in, fan_out): + if mode == 'xavier_uniform': return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) + if mode == 'xavier_normal': return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) + if mode == 'kaiming_uniform': return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) + if mode == 'kaiming_normal': return np.sqrt(1 / fan_in) * torch.randn(*shape) + raise ValueError(f'Invalid init mode "{mode}"') + +#---------------------------------------------------------------------------- +# Fully-connected layer. + +@persistence.persistent_class +class Linear(torch.nn.Module): + def __init__(self, in_features, out_features, bias=True, init_mode='kaiming_normal', init_weight=1, init_bias=0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) + self.weight = torch.nn.Parameter(weight_init([out_features, in_features], **init_kwargs) * init_weight) + self.bias = torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) if bias else None + + def forward(self, x): + x = x @ self.weight.to(x.dtype).t() + if self.bias is not None: + x = x.add_(self.bias.to(x.dtype)) + return x + +#---------------------------------------------------------------------------- +# Convolutional layer with optional up/downsampling. 
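The Conv2d layer that follows builds its up/downsampling kernel by taking the outer product of the 1-D resample_filter with itself and normalizing by the squared sum of its taps. A small sketch of that construction (illustrative only; make_resample_kernel is a hypothetical helper, not part of the patch):

```python
import torch

def make_resample_kernel(resample_filter):
    # Outer product of the 1-D filter with itself, normalized so the 2-D taps sum to 1.
    f = torch.as_tensor(resample_filter, dtype=torch.float32)
    return f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square()

print(make_resample_kernel([1, 1]))        # 2x2 box filter, every tap = 0.25 (DDPM++).
print(make_resample_kernel([1, 3, 3, 1]))  # 4x4 smoothing filter (NCSN++).
```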
+ +@persistence.persistent_class +class Conv2d(torch.nn.Module): + def __init__(self, + in_channels, out_channels, kernel, bias=True, up=False, down=False, + resample_filter=[1,1], fused_resample=False, init_mode='kaiming_normal', init_weight=1, init_bias=0, + ): + assert not (up and down) + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.up = up + self.down = down + self.fused_resample = fused_resample + init_kwargs = dict(mode=init_mode, fan_in=in_channels*kernel*kernel, fan_out=out_channels*kernel*kernel) + self.weight = torch.nn.Parameter(weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) * init_weight) if kernel else None + self.bias = torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) if kernel and bias else None + f = torch.as_tensor(resample_filter, dtype=torch.float32) + f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() + self.register_buffer('resample_filter', f if up or down else None) + + def forward(self, x): + w = self.weight.to(x.dtype) if self.weight is not None else None + b = self.bias.to(x.dtype) if self.bias is not None else None + f = self.resample_filter.to(x.dtype) if self.resample_filter is not None else None + w_pad = w.shape[-1] // 2 if w is not None else 0 + f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 + + if self.fused_resample and self.up and w is not None: + x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=max(f_pad - w_pad, 0)) + x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) + elif self.fused_resample and self.down and w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad+f_pad) + x = torch.nn.functional.conv2d(x, f.tile([self.out_channels, 1, 1, 1]), groups=self.out_channels, stride=2) + else: + if self.up: + x = torch.nn.functional.conv_transpose2d(x, f.mul(4).tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) + if self.down: + x = torch.nn.functional.conv2d(x, f.tile([self.in_channels, 1, 1, 1]), groups=self.in_channels, stride=2, padding=f_pad) + if w is not None: + x = torch.nn.functional.conv2d(x, w, padding=w_pad) + if b is not None: + x = x.add_(b.reshape(1, -1, 1, 1)) + return x + +#---------------------------------------------------------------------------- +# Group normalization. + +@persistence.persistent_class +class GroupNorm(torch.nn.Module): + def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5): + super().__init__() + self.num_groups = min(num_groups, num_channels // min_channels_per_group) + self.eps = eps + self.weight = torch.nn.Parameter(torch.ones(num_channels)) + self.bias = torch.nn.Parameter(torch.zeros(num_channels)) + + def forward(self, x): + x = torch.nn.functional.group_norm(x, num_groups=self.num_groups, weight=self.weight.to(x.dtype), bias=self.bias.to(x.dtype), eps=self.eps) + return x + +#---------------------------------------------------------------------------- +# Attention weight computation, i.e., softmax(Q^T * K). +# Performs all computation using FP32, but uses the original datatype for +# inputs/outputs/gradients to conserve memory. 
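As a reference point for the custom autograd function below: its forward pass computes plain scaled dot-product attention weights, softmax(q^T k / sqrt(C)) over the key axis, carried out in FP32 and cast back to the input dtype; the custom backward exists mainly to keep saved tensors and gradients in the original dtype. A sketch of the equivalent direct computation (hypothetical shapes, not part of the patch):

```python
import torch

n, c, hw = 2, 16, 64                      # hypothetical batch*heads, channels, spatial size.
q = torch.randn(n, c, hw, dtype=torch.float16)
k = torch.randn(n, c, hw, dtype=torch.float16)

# Attention weights in FP32, cast back to the input dtype.
w = torch.einsum('ncq,nck->nqk',
                 q.to(torch.float32),
                 k.to(torch.float32) / (c ** 0.5)).softmax(dim=2).to(q.dtype)
print(w.shape)  # [n, hw, hw]; each query row sums to 1 over the key axis.
```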
+ +class AttentionOp(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k): + w = torch.einsum('ncq,nck->nqk', q.to(torch.float32), (k / np.sqrt(k.shape[1])).to(torch.float32)).softmax(dim=2).to(q.dtype) + ctx.save_for_backward(q, k, w) + return w + + @staticmethod + def backward(ctx, dw): + q, k, w = ctx.saved_tensors + db = torch._softmax_backward_data(grad_output=dw.to(torch.float32), output=w.to(torch.float32), dim=2, input_dtype=torch.float32) + dq = torch.einsum('nck,nqk->ncq', k.to(torch.float32), db).to(q.dtype) / np.sqrt(k.shape[1]) + dk = torch.einsum('ncq,nqk->nck', q.to(torch.float32), db).to(k.dtype) / np.sqrt(k.shape[1]) + return dq, dk + +#---------------------------------------------------------------------------- +# Unified U-Net block with optional up/downsampling and self-attention. +# Represents the union of all features employed by the DDPM++, NCSN++, and +# ADM architectures. + +@persistence.persistent_class +class UNetBlock(torch.nn.Module): + def __init__(self, + in_channels, out_channels, emb_channels, up=False, down=False, attention=False, + num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5, + resample_filter=[1,1], resample_proj=False, adaptive_scale=True, + init=dict(), init_zero=dict(init_weight=0), init_attn=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.emb_channels = emb_channels + self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head + self.dropout = dropout + self.skip_scale = skip_scale + self.adaptive_scale = adaptive_scale + + self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) + self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init) + self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init) + self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) + self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero) + + self.skip = None + if out_channels != in_channels or up or down: + kernel = 1 if resample_proj or out_channels!= in_channels else 0 + self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init) + + if self.num_heads: + self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) + self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init)) + self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero) + + def forward(self, x, emb): + orig = x + x = self.conv0(silu(self.norm0(x))) + + params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) + if self.adaptive_scale: + scale, shift = params.chunk(chunks=2, dim=1) + x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) + else: + x = silu(self.norm1(x.add_(params))) + + x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training)) + x = x.add_(self.skip(orig) if self.skip is not None else orig) + x = x * self.skip_scale + + if self.num_heads: + q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2) + w = AttentionOp.apply(q, k) + a = torch.einsum('nqk,nck->ncq', w, v) + x = self.proj(a.reshape(*x.shape)).add_(x) + x = x * self.skip_scale + return x + 
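A minimal usage sketch of the UNetBlock defined above (hypothetical shapes and import path, not part of the patch): a block with self-attention maps an NCHW feature tensor plus a per-sample embedding to a feature map with the new channel count.

```python
import torch
from training.networks import UNetBlock  # assuming the module layout of this patch.

# Hypothetical sizes; emb_channels must match the embedding fed to forward().
block = UNetBlock(in_channels=64, out_channels=128, emb_channels=512,
                  attention=True, channels_per_head=64, dropout=0.1)

x = torch.randn(4, 64, 32, 32)    # NCHW feature map.
emb = torch.randn(4, 512)         # per-sample conditioning embedding.
y = block(x, emb)
print(y.shape)                    # torch.Size([4, 128, 32, 32])
```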
+#---------------------------------------------------------------------------- +# Timestep embedding used in the DDPM++ and ADM architectures. + +@persistence.persistent_class +class PositionalEmbedding(torch.nn.Module): + def __init__(self, num_channels, max_positions=10000, endpoint=False): + super().__init__() + self.num_channels = num_channels + self.max_positions = max_positions + self.endpoint = endpoint + + def forward(self, x): + freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + +#---------------------------------------------------------------------------- +# Timestep embedding used in the NCSN++ architecture. + +@persistence.persistent_class +class FourierEmbedding(torch.nn.Module): + def __init__(self, num_channels, scale=16): + super().__init__() + self.register_buffer('freqs', torch.randn(num_channels // 2) * scale) + + def forward(self, x): + x = x.ger((2 * np.pi * self.freqs).to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + return x + +#---------------------------------------------------------------------------- +# Reimplementation of the DDPM++ and NCSN++ architectures from the paper +# "Score-Based Generative Modeling through Stochastic Differential +# Equations". Equivalent to the original implementation by Song et al., +# available at https://github.com/yang-song/score_sde_pytorch + +@persistence.persistent_class +class SongUNet(torch.nn.Module): + def __init__(self, + img_resolution, # Image resolution at input/output. + in_channels, # Number of color channels at input. + out_channels, # Number of color channels at output. + label_dim = 0, # Number of class labels, 0 = unconditional. + augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation. + + model_channels = 128, # Base multiplier for the number of channels. + channel_mult = [1,2,2,2], # Per-resolution multipliers for the number of channels. + channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector. + num_blocks = 4, # Number of residual blocks per resolution. + attn_resolutions = [16], # List of resolutions with self-attention. + dropout = 0.10, # Dropout probability of intermediate activations. + label_dropout = 0, # Dropout probability of class labels for classifier-free guidance. + + embedding_type = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++. + channel_mult_noise = 1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++. + encoder_type = 'standard', # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++. + decoder_type = 'standard', # Decoder architecture: 'standard' for both DDPM++ and NCSN++. + resample_filter = [1,1], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++. 
+ ): + assert embedding_type in ['fourier', 'positional'] + assert encoder_type in ['standard', 'skip', 'residual'] + assert decoder_type in ['standard', 'skip'] + + super().__init__() + self.label_dropout = label_dropout + emb_channels = model_channels * channel_mult_emb + noise_channels = model_channels * channel_mult_noise + init = dict(init_mode='xavier_uniform') + init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5) + init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2)) + block_kwargs = dict( + emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6, + resample_filter=resample_filter, resample_proj=True, adaptive_scale=False, + init=init, init_zero=init_zero, init_attn=init_attn, + ) + + # Mapping. + self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True) if embedding_type == 'positional' else FourierEmbedding(num_channels=noise_channels) + self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None + self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None + self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init) + self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init) + + # Encoder. + self.enc = torch.nn.ModuleDict() + cout = in_channels + caux = in_channels + for level, mult in enumerate(channel_mult): + res = img_resolution >> level + if level == 0: + cin = cout + cout = model_channels + self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init) + else: + self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs) + if encoder_type == 'skip': + self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter) + self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init) + if encoder_type == 'residual': + self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init) + caux = cout + for idx in range(num_blocks): + cin = cout + cout = model_channels * mult + attn = (res in attn_resolutions) + self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs) + skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name] + + # Decoder. 
+ self.dec = torch.nn.ModuleDict() + for level, mult in reversed(list(enumerate(channel_mult))): + res = img_resolution >> level + if level == len(channel_mult) - 1: + self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs) + self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs) + else: + self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs) + for idx in range(num_blocks + 1): + cin = cout + skips.pop() + cout = model_channels * mult + attn = (idx == num_blocks and res in attn_resolutions) + self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs) + if decoder_type == 'skip' or level == 0: + if decoder_type == 'skip' and level < len(channel_mult) - 1: + self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter) + self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6) + self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero) + + def forward(self, x, noise_labels, class_labels, augment_labels=None): + # Mapping. + emb = self.map_noise(noise_labels) + emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos + if self.map_label is not None: + tmp = class_labels + if self.training and self.label_dropout: + tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype) + emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) + if self.map_augment is not None and augment_labels is not None: + emb = emb + self.map_augment(augment_labels) + emb = silu(self.map_layer0(emb)) + emb = silu(self.map_layer1(emb)) + + # Encoder. + skips = [] + aux = x + for name, block in self.enc.items(): + if 'aux_down' in name: + aux = block(aux) + elif 'aux_skip' in name: + x = skips[-1] = x + block(aux) + elif 'aux_residual' in name: + x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) + else: + x = block(x, emb) if isinstance(block, UNetBlock) else block(x) + skips.append(x) + + # Decoder. + aux = None + tmp = None + for name, block in self.dec.items(): + if 'aux_up' in name: + aux = block(aux) + elif 'aux_norm' in name: + tmp = block(x) + elif 'aux_conv' in name: + tmp = block(silu(tmp)) + aux = tmp if aux is None else tmp + aux + else: + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + x = block(x, emb) + return aux + +#---------------------------------------------------------------------------- +# Reimplementation of the ADM architecture from the paper +# "Diffusion Models Beat GANS on Image Synthesis". Equivalent to the +# original implementation by Dhariwal and Nichol, available at +# https://github.com/openai/guided-diffusion + +@persistence.persistent_class +class DhariwalUNet(torch.nn.Module): + def __init__(self, + img_resolution, # Image resolution at input/output. + in_channels, # Number of color channels at input. + out_channels, # Number of color channels at output. + label_dim = 0, # Number of class labels, 0 = unconditional. + augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation. + + model_channels = 192, # Base multiplier for the number of channels. + channel_mult = [1,2,3,4], # Per-resolution multipliers for the number of channels. + channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector. 
+ num_blocks          = 3,            # Number of residual blocks per resolution.
+ attn_resolutions    = [32,16,8],    # List of resolutions with self-attention.
+ dropout             = 0.10,         # Dropout probability of intermediate activations.
+ label_dropout       = 0,            # Dropout probability of class labels for classifier-free guidance.
+ ):
+ super().__init__()
+ self.label_dropout = label_dropout
+ emb_channels = model_channels * channel_mult_emb
+ init = dict(init_mode='kaiming_uniform', init_weight=np.sqrt(1/3), init_bias=np.sqrt(1/3))
+ init_zero = dict(init_mode='kaiming_uniform', init_weight=0, init_bias=0)
+ block_kwargs = dict(emb_channels=emb_channels, channels_per_head=64, dropout=dropout, init=init, init_zero=init_zero)
+
+ # Mapping.
+ self.map_noise = PositionalEmbedding(num_channels=model_channels)
+ self.map_augment = Linear(in_features=augment_dim, out_features=model_channels, bias=False, **init_zero) if augment_dim else None
+ self.map_layer0 = Linear(in_features=model_channels, out_features=emb_channels, **init)
+ self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
+ self.map_label = Linear(in_features=label_dim, out_features=emb_channels, bias=False, init_mode='kaiming_normal', init_weight=np.sqrt(label_dim)) if label_dim else None
+
+ # Encoder.
+ self.enc = torch.nn.ModuleDict()
+ cout = in_channels
+ for level, mult in enumerate(channel_mult):
+ res = img_resolution >> level
+ if level == 0:
+ cin = cout
+ cout = model_channels * mult
+ self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
+ else:
+ self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
+ for idx in range(num_blocks):
+ cin = cout
+ cout = model_channels * mult
+ self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs)
+ skips = [block.out_channels for block in self.enc.values()]
+
+ # Decoder.
+ self.dec = torch.nn.ModuleDict()
+ for level, mult in reversed(list(enumerate(channel_mult))):
+ res = img_resolution >> level
+ if level == len(channel_mult) - 1:
+ self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
+ self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
+ else:
+ self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
+ for idx in range(num_blocks + 1):
+ cin = cout + skips.pop()
+ cout = model_channels * mult
+ self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=(res in attn_resolutions), **block_kwargs)
+ self.out_norm = GroupNorm(num_channels=cout)
+ self.out_conv = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero)
+
+ def forward(self, x, noise_labels, class_labels, augment_labels=None):
+ # Mapping.
+ emb = self.map_noise(noise_labels)
+ if self.map_augment is not None and augment_labels is not None:
+ emb = emb + self.map_augment(augment_labels)
+ emb = silu(self.map_layer0(emb))
+ emb = self.map_layer1(emb)
+ if self.map_label is not None:
+ tmp = class_labels
+ if self.training and self.label_dropout:
+ tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype)
+ emb = emb + self.map_label(tmp)
+ emb = silu(emb)
+
+ # Encoder.
+ skips = []
+ for block in self.enc.values():
+ x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
+ skips.append(x)
+
+ # Decoder.
+ for block in self.dec.values(): + if x.shape[1] != block.in_channels: + x = torch.cat([x, skips.pop()], dim=1) + x = block(x, emb) + x = self.out_conv(silu(self.out_norm(x))) + return x + +#---------------------------------------------------------------------------- +# Preconditioning corresponding to the variance preserving (VP) formulation +# from the paper "Score-Based Generative Modeling through Stochastic +# Differential Equations". + +@persistence.persistent_class +class VPPrecond(torch.nn.Module): + def __init__(self, + img_resolution, # Image resolution. + img_channels, # Number of color channels. + label_dim = 0, # Number of class labels, 0 = unconditional. + use_fp16 = False, # Execute the underlying model at FP16 precision? + beta_d = 19.9, # Extent of the noise level schedule. + beta_min = 0.1, # Initial slope of the noise level schedule. + M = 1000, # Original number of timesteps in the DDPM formulation. + epsilon_t = 1e-5, # Minimum t-value used during training. + model_type = 'SongUNet', # Class name of the underlying model. + **model_kwargs, # Keyword arguments for the underlying model. + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.beta_d = beta_d + self.beta_min = beta_min + self.M = M + self.epsilon_t = epsilon_t + self.sigma_min = float(self.sigma(epsilon_t)) + self.sigma_max = float(self.sigma(1)) + self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1).sqrt() + c_noise = (self.M - 1) * self.sigma_inv(sigma) + + F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def sigma(self, t): + t = torch.as_tensor(t) + return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() + + def sigma_inv(self, sigma): + sigma = torch.as_tensor(sigma) + return ((self.beta_min ** 2 + 2 * self.beta_d * (1 + sigma ** 2).log()).sqrt() - self.beta_min) / self.beta_d + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + +#---------------------------------------------------------------------------- +# Preconditioning corresponding to the variance exploding (VE) formulation +# from the paper "Score-Based Generative Modeling through Stochastic +# Differential Equations". + +@persistence.persistent_class +class VEPrecond(torch.nn.Module): + def __init__(self, + img_resolution, # Image resolution. + img_channels, # Number of color channels. + label_dim = 0, # Number of class labels, 0 = unconditional. + use_fp16 = False, # Execute the underlying model at FP16 precision? + sigma_min = 0.02, # Minimum supported noise level. + sigma_max = 100, # Maximum supported noise level. + model_type = 'SongUNet', # Class name of the underlying model. 
+ **model_kwargs, # Keyword arguments for the underlying model. + ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + + F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + +#---------------------------------------------------------------------------- +# Preconditioning corresponding to improved DDPM (iDDPM) formulation from +# the paper "Improved Denoising Diffusion Probabilistic Models". + +@persistence.persistent_class +class iDDPMPrecond(torch.nn.Module): + def __init__(self, + img_resolution, # Image resolution. + img_channels, # Number of color channels. + label_dim = 0, # Number of class labels, 0 = unconditional. + use_fp16 = False, # Execute the underlying model at FP16 precision? + C_1 = 0.001, # Timestep adjustment at low noise levels. + C_2 = 0.008, # Timestep adjustment at high noise levels. + M = 1000, # Original number of timesteps in the DDPM formulation. + model_type = 'DhariwalUNet', # Class name of the underlying model. + **model_kwargs, # Keyword arguments for the underlying model. 
+ ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.C_1 = C_1 + self.C_2 = C_2 + self.M = M + self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels*2, label_dim=label_dim, **model_kwargs) + + u = torch.zeros(M + 1) + for j in range(M, 0, -1): # M, ..., 1 + u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt() + self.register_buffer('u', u) + self.sigma_min = float(u[M - 1]) + self.sigma_max = float(u[0]) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1).sqrt() + c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + + F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * F_x[:, :self.img_channels].to(torch.float32) + return D_x + + def alpha_bar(self, j): + j = torch.as_tensor(j) + return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2 + + def round_sigma(self, sigma, return_index=False): + sigma = torch.as_tensor(sigma) + index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2) + result = index if return_index else self.u[index.flatten()].to(sigma.dtype) + return result.reshape(sigma.shape).to(sigma.device) + +#---------------------------------------------------------------------------- +# Improved preconditioning proposed in the paper "Elucidating the Design +# Space of Diffusion-Based Generative Models" (EDM). + +@persistence.persistent_class +class EDMPrecond(torch.nn.Module): + def __init__(self, + img_resolution, # Image resolution. + img_channels, # Number of color channels. + label_dim = 0, # Number of class labels, 0 = unconditional. + use_fp16 = False, # Execute the underlying model at FP16 precision? + sigma_min = 0, # Minimum supported noise level. + sigma_max = float('inf'), # Maximum supported noise level. + sigma_data = 0.5, # Expected standard deviation of the training data. + model_type = 'DhariwalUNet', # Class name of the underlying model. + **model_kwargs, # Keyword arguments for the underlying model. 
+ ): + super().__init__() + self.img_resolution = img_resolution + self.img_channels = img_channels + self.label_dim = label_dim + self.use_fp16 = use_fp16 + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.sigma_data = sigma_data + self.model = globals()[model_type](img_resolution=img_resolution, in_channels=img_channels, out_channels=img_channels, label_dim=label_dim, **model_kwargs) + + def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs): + x = x.to(torch.float32) + sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) + class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim) + dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 + + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() + c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() + c_noise = sigma.log() / 4 + + F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + +#---------------------------------------------------------------------------- diff --git a/edm/training/training_loop.py b/edm/training/training_loop.py new file mode 100644 index 0000000..109d7d2 --- /dev/null +++ b/edm/training/training_loop.py @@ -0,0 +1,216 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# This work is licensed under a Creative Commons +# Attribution-NonCommercial-ShareAlike 4.0 International License. +# You should have received a copy of the license along with this +# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ + +"""Main training loop.""" + +import os +import time +import copy +import json +import pickle +import psutil +import numpy as np +import torch +import dnnlib +from torch_utils import distributed as dist +from torch_utils import training_stats +from torch_utils import misc + +#---------------------------------------------------------------------------- + +def training_loop( + run_dir = '.', # Output directory. + dataset_kwargs = {}, # Options for training set. + data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. + network_kwargs = {}, # Options for model and preconditioning. + loss_kwargs = {}, # Options for loss function. + optimizer_kwargs = {}, # Options for optimizer. + augment_kwargs = None, # Options for augmentation pipeline, None = disable. + seed = 0, # Global random seed. + batch_size = 512, # Total batch size for one training iteration. + batch_gpu = None, # Limit batch size per GPU, None = no limit. + total_kimg = 200000, # Training duration, measured in thousands of training images. + ema_halflife_kimg = 500, # Half-life of the exponential moving average (EMA) of model weights. + ema_rampup_ratio = 0.05, # EMA ramp-up coefficient, None = no rampup. + lr_rampup_kimg = 10000, # Learning rate ramp-up duration. + loss_scaling = 1, # Loss scaling factor for reducing FP16 under/overflows. + kimg_per_tick = 50, # Interval of progress prints. + snapshot_ticks = 50, # How often to save network snapshots, None = disable. + state_dump_ticks = 500, # How often to dump training state, None = disable. 
+ resume_pkl = None, # Start from the given network snapshot, None = random initialization. + resume_state_dump = None, # Start from the given training state, None = reset training state. + resume_kimg = 0, # Start from the given training progress. + cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? + device = torch.device('cuda'), +): + # Initialize. + start_time = time.time() + np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31)) + torch.manual_seed(np.random.randint(1 << 31)) + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + # Select batch size per GPU. + batch_gpu_total = batch_size // dist.get_world_size() + if batch_gpu is None or batch_gpu > batch_gpu_total: + batch_gpu = batch_gpu_total + num_accumulation_rounds = batch_gpu_total // batch_gpu + assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size() + + # Load dataset. + dist.print0('Loading dataset...') + dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset + dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed) + dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs)) + + # Construct network. + dist.print0('Constructing network...') + interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim) + net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module + net.train().requires_grad_(True).to(device) + if dist.get_rank() == 0: + with torch.no_grad(): + images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device) + sigma = torch.ones([batch_gpu], device=device) + labels = torch.zeros([batch_gpu, net.label_dim], device=device) + misc.print_module_summary(net, [images, sigma, labels], max_nesting=2) + + # Setup optimizer. + dist.print0('Setting up optimizer...') + loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss + optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer + augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe + ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False) + ema = copy.deepcopy(net).eval().requires_grad_(False) + + # Resume training from previous snapshot. 
+ if resume_pkl is not None: + dist.print0(f'Loading network weights from "{resume_pkl}"...') + if dist.get_rank() != 0: + torch.distributed.barrier() # rank 0 goes first + with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f: + data = pickle.load(f) + if dist.get_rank() == 0: + torch.distributed.barrier() # other ranks follow + misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False) + misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False) + del data # conserve memory + if resume_state_dump: + dist.print0(f'Loading training state from "{resume_state_dump}"...') + data = torch.load(resume_state_dump, map_location=torch.device('cpu')) + misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True) + optimizer.load_state_dict(data['optimizer_state']) + del data # conserve memory + + # Train. + dist.print0(f'Training for {total_kimg} kimg...') + dist.print0() + cur_nimg = resume_kimg * 1000 + cur_tick = 0 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - start_time + dist.update_progress(cur_nimg // 1000, total_kimg) + stats_jsonl = None + while True: + + # Accumulate gradients. + optimizer.zero_grad(set_to_none=True) + for round_idx in range(num_accumulation_rounds): + with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)): + images, labels = next(dataset_iterator) + images = images.to(device).to(torch.float32) / 127.5 - 1 + labels = labels.to(device) + loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe) + training_stats.report('Loss/loss', loss) + loss.sum().mul(loss_scaling / batch_gpu_total).backward() + + # Update weights. + for g in optimizer.param_groups: + g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1) + for param in net.parameters(): + if param.grad is not None: + torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) + optimizer.step() + + # Update EMA. + ema_halflife_nimg = ema_halflife_kimg * 1000 + if ema_rampup_ratio is not None: + ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio) + ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8)) + for p_ema, p_net in zip(ema.parameters(), net.parameters()): + p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta)) + + # Perform maintenance tasks once per tick. + cur_nimg += batch_size + done = (cur_nimg >= total_kimg * 1000) + if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): + continue + + # Print status line, accumulating the same information in training_stats. 
+ tick_end_time = time.time() + fields = [] + fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] + fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"] + fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] + fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] + fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] + fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] + fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] + fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] + fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"] + torch.cuda.reset_peak_memory_stats() + dist.print0(' '.join(fields)) + + # Check for abort. + if (not done) and dist.should_stop(): + done = True + dist.print0() + dist.print0('Aborting...') + + # Save network snapshot. + if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0): + data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs)) + for key, value in data.items(): + if isinstance(value, torch.nn.Module): + value = copy.deepcopy(value).eval().requires_grad_(False) + misc.check_ddp_consistency(value) + data[key] = value.cpu() + del value # conserve memory + if dist.get_rank() == 0: + with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f: + pickle.dump(data, f) + del data # conserve memory + + # Save full dump of the training state. + if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0: + torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt')) + + # Update logs. + training_stats.default_collector.update() + if dist.get_rank() == 0: + if stats_jsonl is None: + stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'at') + stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n') + stats_jsonl.flush() + dist.update_progress(cur_nimg // 1000, total_kimg) + + # Update state. + cur_tick += 1 + tick_start_nimg = cur_nimg + tick_start_time = time.time() + maintenance_time = tick_start_time - tick_end_time + if done: + break + + # Done. + dist.print0() + dist.print0('Exiting...') + +#----------------------------------------------------------------------------
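For reference, the EMA schedule in the loop above derives the per-step decay from a half-life measured in training images, clamped early in training by the ramp-up ratio. A standalone sketch of that arithmetic (ema_beta is a hypothetical helper; the tiny Linear stands in for the real network):

```python
import copy
import torch

def ema_beta(batch_size, cur_nimg, ema_halflife_kimg=500, ema_rampup_ratio=0.05):
    # Same arithmetic as the training loop above: halve the EMA weight every
    # ema_halflife_kimg * 1000 images, but never use a half-life longer than
    # ema_rampup_ratio * images_seen_so_far.
    ema_halflife_nimg = ema_halflife_kimg * 1000
    if ema_rampup_ratio is not None:
        ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
    return 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))

net = torch.nn.Linear(4, 4)                       # stand-in for the diffusion model.
ema = copy.deepcopy(net).requires_grad_(False)
beta = ema_beta(batch_size=512, cur_nimg=100_000)
with torch.no_grad():
    for p_ema, p_net in zip(ema.parameters(), net.parameters()):
        p_ema.copy_(p_net.detach().lerp(p_ema, beta))  # ema = beta*ema + (1-beta)*net
print(beta)
```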