Spaces:
Runtime error
Runtime error
update
Browse files- infer_api.py +3 -7
- refine/mesh_refine.py +5 -0
infer_api.py
CHANGED
|
@@ -619,30 +619,26 @@ def infer_refine(meshes, imgs):
|
|
| 619 |
|
| 620 |
# my mesh flow weight by nearest vertexs
|
| 621 |
if fixed_v is not None and fixed_f is not None and level == 1:
|
| 622 |
-
t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
|
| 623 |
-
|
| 624 |
fixed_v_cpu = fixed_v.cpu().numpy()
|
| 625 |
kdtree_anchor = KDTree(fixed_v_cpu)
|
| 626 |
kdtree_mesh_v = KDTree(mesh_v)
|
| 627 |
_, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
|
| 628 |
_, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
|
| 629 |
idx_anchor = idx_anchor.squeeze()
|
| 630 |
-
neighbors = torch.tensor(mesh_v)
|
| 631 |
# calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
|
| 632 |
-
neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v)
|
| 633 |
neighbor_dists[neighbor_dists > 0.06] = 114514.
|
| 634 |
neighbor_weights = torch.exp(-neighbor_dists * 1.)
|
| 635 |
neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
|
| 636 |
anchors = fixed_v[idx_anchor] # V, 3
|
| 637 |
anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
|
| 638 |
-
dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v)
|
| 639 |
vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
|
| 640 |
vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
|
| 641 |
weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
|
| 642 |
mesh_v += weighted_vec_anchor.cpu().numpy()
|
| 643 |
|
| 644 |
-
t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
|
| 645 |
-
|
| 646 |
mesh_v = torch.tensor(mesh_v, dtype=torch.float32)
|
| 647 |
mesh_f = torch.tensor(mesh_f)
|
| 648 |
|
|
|
|
| 619 |
|
| 620 |
# my mesh flow weight by nearest vertexs
|
| 621 |
if fixed_v is not None and fixed_f is not None and level == 1:
|
|
|
|
|
|
|
| 622 |
fixed_v_cpu = fixed_v.cpu().numpy()
|
| 623 |
kdtree_anchor = KDTree(fixed_v_cpu)
|
| 624 |
kdtree_mesh_v = KDTree(mesh_v)
|
| 625 |
_, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
|
| 626 |
_, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
|
| 627 |
idx_anchor = idx_anchor.squeeze()
|
| 628 |
+
neighbors = torch.tensor(mesh_v)[idx_mesh_v] # V, 25, 3
|
| 629 |
# calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
|
| 630 |
+
neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v)[:, None], dim=-1)
|
| 631 |
neighbor_dists[neighbor_dists > 0.06] = 114514.
|
| 632 |
neighbor_weights = torch.exp(-neighbor_dists * 1.)
|
| 633 |
neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
|
| 634 |
anchors = fixed_v[idx_anchor] # V, 3
|
| 635 |
anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
|
| 636 |
+
dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v)) * anchor_normals).sum(-1), min=0) + 0.01
|
| 637 |
vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
|
| 638 |
vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
|
| 639 |
weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
|
| 640 |
mesh_v += weighted_vec_anchor.cpu().numpy()
|
| 641 |
|
|
|
|
|
|
|
| 642 |
mesh_v = torch.tensor(mesh_v, dtype=torch.float32)
|
| 643 |
mesh_f = torch.tensor(mesh_f)
|
| 644 |
|
refine/mesh_refine.py
CHANGED
|
@@ -268,6 +268,11 @@ def run_mesh_refine(vertices, faces, pils: List[Image.Image], fixed_v=None, fixe
|
|
| 268 |
|
| 269 |
def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
|
| 270 |
distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
vertices, faces = geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=expansion_weight, fixed_v=fixed_v, fixed_f=fixed_f,
|
| 272 |
distract_mask=distract_mask, distract_bbox=distract_bbox, thres=thres, no_decompose=no_decompose)
|
| 273 |
vertices, faces = geo_refine_2(vertices, faces, fixed_v=fixed_v)
|
|
|
|
| 268 |
|
| 269 |
def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
|
| 270 |
distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
|
| 271 |
+
print(mesh_v.device, mesh_f.device)
|
| 272 |
+
if fixed_v is not None:
|
| 273 |
+
print('fixed_v', fixed_v.shape, fixed_v.device)
|
| 274 |
+
if fixed_f is not None:
|
| 275 |
+
print('fixed_f', fixed_f.shape, fixed_f.device)
|
| 276 |
vertices, faces = geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=expansion_weight, fixed_v=fixed_v, fixed_f=fixed_f,
|
| 277 |
distract_mask=distract_mask, distract_bbox=distract_bbox, thres=thres, no_decompose=no_decompose)
|
| 278 |
vertices, faces = geo_refine_2(vertices, faces, fixed_v=fixed_v)
|