Wendy-Fly commited on
Commit
d81790e
·
verified ·
1 Parent(s): d27a9f2

Upload infer_1.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. infer_1.py +1 -1
infer_1.py CHANGED
@@ -104,7 +104,7 @@ for batch_idx in tqdm(range(begin, end, batch_size)):
104
  return_tensors="pt",
105
  **video_kwargs,
106
  )
107
- inputs = inputs.to("cuda:1")
108
 
109
  # Inference
110
  generated_ids = model.generate(**inputs, max_new_tokens=128)
 
104
  return_tensors="pt",
105
  **video_kwargs,
106
  )
107
+ inputs = inputs.to(model.device)
108
 
109
  # Inference
110
  generated_ids = model.generate(**inputs, max_new_tokens=128)