| | import importlib |
| |
|
| | |
| | import logging |
| |
|
# Module-level handle on the logger named "pytorch_lightning"; ``getLogger``
# hands back the same Logger object for the same name on every call.
logger = logging.getLogger(name="pytorch_lightning")
| |
|
| | from pytorch_lightning.utilities.rank_zero import ( |
| | rank_zero_debug, |
| | rank_zero_info, |
| | rank_zero_only, |
| | ) |
| |
|
| |
|
def find(cls_string):
    """Import and return the object named by a dotted path.

    Args:
        cls_string: Fully qualified name such as ``"package.module.ClassName"``.
            Everything before the last ``.`` is imported as a module and the
            final component is looked up as an attribute of that module.

    Returns:
        The attribute (typically a class) found on the imported module.

    Raises:
        ValueError: If ``cls_string`` has no module part (no ``.`` before the
            final component).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the final name.
    """
    # Split once at the last dot instead of splitting the string twice.
    module_string, _, cls_name = cls_string.rpartition(".")
    if not module_string:
        # Previously importlib raised the opaque "Empty module name" here;
        # keep the same exception type but say what was wrong.
        raise ValueError(
            "Expected a fully qualified name like 'package.module.ClassName', "
            f"got {cls_string!r}"
        )
    module = importlib.import_module(module_string)
    return getattr(module, cls_name)
| |
|
| |
|
# Short public aliases for pytorch_lightning's ``rank_zero_debug`` and
# ``rank_zero_info`` helpers.
debug, info = rank_zero_debug, rank_zero_info
| |
|
| |
|
@rank_zero_only
def warn(*args, **kwargs):
    """Log a warning through the module-level ``logger``.

    The function is wrapped with pytorch_lightning's ``rank_zero_only``
    decorator, which is documented to run the call only on global rank 0.
    All arguments are forwarded unchanged to ``logging.Logger.warning``.
    """
    # ``Logger.warn`` is a deprecated alias of ``Logger.warning``; use the
    # supported method so no DeprecationWarning is emitted.
    logger.warning(*args, **kwargs)
| |
|