Spaces:
Running
on
Zero
Running
on
Zero
Update model/segment_anything_2/sam2/modeling/sam2_base.py
Browse files
model/segment_anything_2/sam2/modeling/sam2_base.py
CHANGED
|
@@ -312,6 +312,11 @@ class SAM2Base(torch.nn.Module):
|
|
| 312 |
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
| 313 |
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
| 314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
# b) Handle mask prompts
|
| 316 |
if mask_inputs is not None:
|
| 317 |
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
|
@@ -333,7 +338,7 @@ class SAM2Base(torch.nn.Module):
|
|
| 333 |
sam_mask_prompt = None
|
| 334 |
|
| 335 |
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
| 336 |
-
points=
|
| 337 |
boxes=None,
|
| 338 |
masks=sam_mask_prompt,
|
| 339 |
text_embeds=text_inputs
|
|
|
|
| 312 |
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
| 313 |
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
| 314 |
|
| 315 |
+
sam_point_prompt = (sam_point_coords, sam_point_labels)
|
| 316 |
+
# added by YxZhang to forbid contemporary using text prompt and point prompt
|
| 317 |
+
if text_inputs is not None:
|
| 318 |
+
sam_point_prompt = None
|
| 319 |
+
|
| 320 |
# b) Handle mask prompts
|
| 321 |
if mask_inputs is not None:
|
| 322 |
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
|
|
|
| 338 |
sam_mask_prompt = None
|
| 339 |
|
| 340 |
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
| 341 |
+
points=sam_point_prompt,
|
| 342 |
boxes=None,
|
| 343 |
masks=sam_mask_prompt,
|
| 344 |
text_embeds=text_inputs
|