
Commit

fix llava model device bugs
wanfengcxz committed Oct 28, 2024
1 parent 2bae28f commit d2c968c
Showing 1 changed file with 3 additions and 2 deletions.
internlm/model/modeling_llava.py: 5 changes (3 additions & 2 deletions)
@@ -15,6 +15,7 @@
 from internlm.model.modules.linear import new_linear
 from internlm.model.modules.norm import new_layer_norm
 from internlm.utils.logger import get_logger
+from internlm.utils.common import get_current_device
 
 logger = get_logger(__file__)
 
@@ -192,7 +193,7 @@ def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs):
         if hasattr(self, "vit") and hasattr(self, "vision_proj") and hasattr(self, "tok_embeddings"):
             # vit
             if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update
-                images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)]
+                images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).to(self.device).to(self.dtype)]
                 pure_text = True
 
             for image in images:
@@ -201,7 +202,7 @@ def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs):
                     x = []
                 else:
                     assert not isinstance(image, list), image
-                    x = image.to(torch.cuda.current_device()).to(self.dtype)
+                    x = image.to(get_current_device()).to(self.dtype)
                     x = self.vit(x)
                     x = self.vision_proj(x)
                 xs.append(x)
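
The fix replaces hard-coded CUDA placement (.cuda() and torch.cuda.current_device()) with device-agnostic lookups (self.device and get_current_device()), so the dummy image and the real image batches land on whatever accelerator the model is actually running on. A minimal sketch of that pattern follows; the get_current_device() shown here is a hypothetical stand-in for the helper imported from internlm.utils.common, which may resolve devices differently on non-CUDA backends.

import torch

def get_current_device() -> torch.device:
    # Hypothetical stand-in for internlm.utils.common.get_current_device:
    # prefer the local CUDA device when one is available, otherwise fall
    # back to CPU, so callers never hard-code .cuda().
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")

# Usage mirroring the changed lines: place the tensor on the active device,
# then cast it to the model's dtype.
dummy = torch.rand(1, 3, 224, 224).to(get_current_device()).to(torch.bfloat16)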
