Spaces:
Runtime error
Runtime error
potato
commited on
Commit
·
a256f81
1
Parent(s):
d7a6e77
Deleted description
Browse files- app.py +0 -10
- src/core.py +3 -2
app.py
CHANGED
|
@@ -44,16 +44,6 @@ if 'reuse_image' not in st.session_state:
|
|
| 44 |
def set_image(img):
|
| 45 |
st.session_state.reuse_image = img
|
| 46 |
|
| 47 |
-
st.title("AI Photo Object Removal")
|
| 48 |
-
|
| 49 |
-
st.image(open("assets/demo.png", "rb").read())
|
| 50 |
-
|
| 51 |
-
st.markdown(
|
| 52 |
-
"""
|
| 53 |
-
So you want to remove an object in your photo? You don't need to learn photo editing skills.
|
| 54 |
-
**Just draw over the parts of the image you want to remove, then our AI will remove them.**
|
| 55 |
-
"""
|
| 56 |
-
)
|
| 57 |
uploaded_file = st.file_uploader("Choose image", accept_multiple_files=False, type=["png", "jpg", "jpeg"])
|
| 58 |
|
| 59 |
if uploaded_file is not None:
|
|
|
|
| 44 |
def set_image(img):
|
| 45 |
st.session_state.reuse_image = img
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
uploaded_file = st.file_uploader("Choose image", accept_multiple_files=False, type=["png", "jpg", "jpeg"])
|
| 48 |
|
| 49 |
if uploaded_file is not None:
|
src/core.py
CHANGED
|
@@ -64,9 +64,10 @@ ENERGY_MASK_CONST = 100000.0 # large energy value for protective ma
|
|
| 64 |
MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
|
| 65 |
USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
|
| 66 |
|
| 67 |
-
device = torch.device("cpu")
|
|
|
|
| 68 |
model_path = "./assets/big-lama.pt"
|
| 69 |
-
model = torch.jit.load(model_path, map_location=
|
| 70 |
model = model.to(device)
|
| 71 |
model.eval()
|
| 72 |
|
|
|
|
| 64 |
MASK_THRESHOLD = 10 # minimum pixel intensity for binary mask
|
| 65 |
USE_FORWARD_ENERGY = True # if True, use forward energy algorithm
|
| 66 |
|
| 67 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 68 |
+
print(device)
|
| 69 |
model_path = "./assets/big-lama.pt"
|
| 70 |
+
model = torch.jit.load(model_path, map_location=device)
|
| 71 |
model = model.to(device)
|
| 72 |
model.eval()
|
| 73 |
|