Update all files for DiffusionSat-Single-256
Browse files- pipeline_diffusionsat.py +8 -0
pipeline_diffusionsat.py
CHANGED
|
@@ -5,6 +5,7 @@ from the checkpoint folder without importing the project package.
|
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
|
|
|
| 8 |
from typing import Any, Callable, Dict, List, Optional, Union
|
| 9 |
|
| 10 |
import torch
|
|
@@ -59,6 +60,13 @@ class DiffusionSatPipeline(DiffusionPipeline):
|
|
| 59 |
|
| 60 |
_optional_components = ["safety_checker", "feature_extractor"]
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
def __init__(
|
| 63 |
self,
|
| 64 |
vae: AutoencoderKL,
|
|
|
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
+
import inspect
|
| 9 |
from typing import Any, Callable, Dict, List, Optional, Union
|
| 10 |
|
| 11 |
import torch
|
|
|
|
| 60 |
|
| 61 |
_optional_components = ["safety_checker", "feature_extractor"]
|
| 62 |
|
| 63 |
+
@classmethod
|
| 64 |
+
def _get_signature_types(cls) -> Dict[str, tuple]:
|
| 65 |
+
"""Return init param names so diffusers type validation does not KeyError on custom pipeline."""
|
| 66 |
+
sig = inspect.signature(cls.__init__)
|
| 67 |
+
empty = (inspect.Signature.empty,)
|
| 68 |
+
return {name: empty for name in sig.parameters}
|
| 69 |
+
|
| 70 |
def __init__(
|
| 71 |
self,
|
| 72 |
vae: AutoencoderKL,
|