From d2c968c7d28f5ec598e6e553859e0f7e6f286c3e Mon Sep 17 00:00:00 2001 From: wanfengcxz <2917021186@qq.com> Date: Mon, 28 Oct 2024 16:11:50 +0800 Subject: [PATCH] fix llava model device bugs --- internlm/model/modeling_llava.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internlm/model/modeling_llava.py b/internlm/model/modeling_llava.py index 4c2bb174..41cb6b89 100644 --- a/internlm/model/modeling_llava.py +++ b/internlm/model/modeling_llava.py @@ -15,6 +15,7 @@ from internlm.model.modules.linear import new_linear from internlm.model.modules.norm import new_layer_norm from internlm.utils.logger import get_logger +from internlm.utils.common import get_current_device logger = get_logger(__file__) @@ -192,7 +193,7 @@ def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs): if hasattr(self, "vit") and hasattr(self, "vision_proj") and hasattr(self, "tok_embeddings"): # vit if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update - images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)] + images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).to(self.device).to(self.dtype)] pure_text = True for image in images: @@ -201,7 +202,7 @@ def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs): x = [] else: assert not isinstance(image, list), image - x = image.to(torch.cuda.current_device()).to(self.dtype) + x = image.to(get_current_device()).to(self.dtype) x = self.vit(x) x = self.vision_proj(x) xs.append(x)