[feat] feature lazy model loading
tamohannes committed Jul 12, 2024
1 parent 338d6bd commit 69d6ddb
Showing 5 changed files with 68 additions and 52 deletions.
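This commit defers model and tokenizer construction from __init__ to first use: the backing fields start as None, and each backend exposes a model (and, where relevant, tokenizer) property that builds the object on first access and caches it. Below is a minimal, self-contained sketch of the pattern (LazyModel and load_weights are hypothetical stand-ins for illustration, not names from this diff):

class LazyModel:
    def __init__(self) -> None:
        # Construction is cheap: no weights are touched here.
        self._model = None

    @property
    def model(self):
        # The first access pays the loading cost; later accesses hit the cache.
        if self._model is None:
            self._model = self.load_weights()
        return self._model

    def load_weights(self):
        # Hypothetical stand-in for an expensive call such as
        # AutoModelForCausalLM.from_pretrained(...).
        print("loading weights ...")
        return object()

m = LazyModel()  # returns immediately; nothing is loaded yet
_ = m.model      # prints "loading weights ..." exactly once
_ = m.model      # served from the cache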
2 changes: 1 addition & 1 deletion llm_roleplay/VERSION
@@ -1 +1 @@
-2.0.3
+2.0.4
9 changes: 4 additions & 5 deletions llm_roleplay/common/model.py
@@ -15,18 +15,17 @@ def __init__(self, cfg: List[Dict[str, Any]], role=None):
         self.conv_template = cfg.conv_template
         self.spec_tokens = None
         self.aim_run = None
-        self.model = None
-        self.tokenizer = None
         self.role = role
         self.history = []
-        self._get_model()
+        self._model = None
 
     @staticmethod
     def get_model(cfg, role):
         return hydra.utils.instantiate(cfg.type, cfg, role)
 
-    def _get_model(self):
-        raise NotImplementedError("method '_get_model' is not implemented")
+    @property
+    def model(self):
+        raise NotImplementedError("property 'model' instantiation is not implemented")
 
     def get_prompt(self, turn, response_msg, persona=None, instructions=None):
         raise NotImplementedError("method 'get_prompt' is not implemented")
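The base class now declares model as a property that raises NotImplementedError, so each backend is forced to supply its own lazy loader. The same contract could also be expressed with the standard library's abc module; a sketch of that alternative (not what the commit does):

import abc

class Base(abc.ABC):
    @property
    @abc.abstractmethod
    def model(self):
        """Subclasses must override this; Base itself cannot be instantiated."""

class Child(Base):
    @property
    def model(self):
        return "loaded"

print(Child().model)  # -> loaded; Base() itself would raise TypeError

The abc variant fails at instantiation time rather than at first property access, which is stricter but changes when the error surfaces.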
39 changes: 23 additions & 16 deletions llm_roleplay/models/causal_lm_model.py
@@ -1,5 +1,4 @@
 import random
-from typing import Tuple
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -17,21 +16,29 @@ class CausalLMModel(Model):
 
     def __init__(self, cfg, role) -> None:
         super().__init__(cfg, role)
-
-    def _get_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.cfg.name,
-            cache_dir=self.cfg.cache_dir,
-            device_map=Device.get_device(),
-            torch_dtype=eval_dtype(self.cfg.dtype),
-            token=self.cfg.api_token,
-        )
-
-        for param in self.model.parameters():
-            param.requires_grad = False
-
-        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.name)
-        self.model.eval()
+        self._tokenizer = None
+
+    @property
+    def model(self):
+        if self._model is None:
+            self._model = AutoModelForCausalLM.from_pretrained(
+                self.cfg.name,
+                cache_dir=self.cfg.cache_dir,
+                device_map=Device.get_device(),
+                torch_dtype=eval_dtype(self.cfg.dtype),
+                token=self.cfg.api_token,
+            )
+
+            for param in self._model.parameters():
+                param.requires_grad = False
+            self._model.eval()
+        return self._model
+
+    @property
+    def tokenizer(self):
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.name)
+        return self._tokenizer
 
     def get_prompt(self, turn, response_msg=None, persona=None, instructions=None):
         if self.role == "model_inquirer":
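The hand-rolled None check above is the classic memoized-property idiom; since Python 3.8 the standard library offers functools.cached_property, which performs the same caching automatically. A sketch of that alternative (not part of the commit):

import functools

class Example:
    @functools.cached_property
    def model(self):
        # Runs once on first access; the result is stored on the instance.
        print("expensive load")
        return {"weights": [1, 2, 3]}

e = Example()
e.model  # prints "expensive load"
e.model  # cached; nothing is printed

One difference: cached_property writes the computed value into the instance's __dict__ under the attribute name, so there is no separate _model field to inspect or reset.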
21 changes: 12 additions & 9 deletions llm_roleplay/models/openai_model.py
@@ -3,7 +3,7 @@
 import tiktoken
 from langchain.schema import AIMessage, HumanMessage, SystemMessage
 from langchain_openai import AzureChatOpenAI
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM
 
 from llm_roleplay.common.model import Model

@@ -12,14 +12,17 @@ class OpenAIModel(Model):
 
     def __init__(self, cfg, role) -> None:
         super().__init__(cfg, role)
 
-    def _get_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
-        self.model = AzureChatOpenAI(
-            deployment_name=self.cfg.name,
-            openai_api_type=self.cfg.openai_api_type,
-            openai_api_version=self.cfg.openai_api_version,
-            azure_endpoint=self.cfg.azure_openai_endpoint,
-            openai_api_key=self.cfg.azure_openai_api_key,
-        )
+    @property
+    def model(self) -> AzureChatOpenAI:
+        if self._model is None:
+            self._model = AzureChatOpenAI(
+                deployment_name=self.cfg.name,
+                openai_api_type=self.cfg.openai_api_type,
+                openai_api_version=self.cfg.openai_api_version,
+                azure_endpoint=self.cfg.azure_openai_endpoint,
+                openai_api_key=self.cfg.azure_openai_api_key,
+            )
+        return self._model
 
     def get_prompt(self, turn, response_msg, persona=None, instructions=None):
         if self.role == "model_inquirer":
49 changes: 28 additions & 21 deletions llm_roleplay/models/pipeline_model.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from urartu.common.device import Device
 from urartu.utils.dtype import eval_dtype
@@ -15,25 +13,34 @@ class PipelineModel(Model):
 
     def __init__(self, cfg, role) -> None:
         super().__init__(cfg, role)
-
-    def _get_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
-        model = AutoModelForCausalLM.from_pretrained(
-            self.cfg.name,
-            cache_dir=self.cfg.cache_dir,
-            device_map=Device.get_device(),
-            torch_dtype=eval_dtype(self.cfg.dtype),
-            token=self.cfg.api_token,
-        )
-        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.name)
-
-        self.model = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=self.tokenizer,
-            torch_dtype=eval_dtype(self.cfg.dtype),
-            device_map=Device.get_device(),
-            eos_token_id=self.tokenizer.eos_token_id,
-        )
+        self._tokenizer = None
+
+    @property
+    def model(self):
+        if self._model is None:
+            clm_model = AutoModelForCausalLM.from_pretrained(
+                self.cfg.name,
+                cache_dir=self.cfg.cache_dir,
+                device_map=Device.get_device(),
+                torch_dtype=eval_dtype(self.cfg.dtype),
+                token=self.cfg.api_token,
+            )
+
+            self._model = pipeline(
+                "text-generation",
+                model=clm_model,
+                tokenizer=self.tokenizer,
+                torch_dtype=eval_dtype(self.cfg.dtype),
+                device_map=Device.get_device(),
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+        return self._model
+
+    @property
+    def tokenizer(self):
+        if self._tokenizer is None:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.name)
+        return self._tokenizer
 
     def get_prompt(self, turn, response_msg=None, persona=None, instructions=None):
         if self.role == "model_inquirer":
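Note that PipelineModel's model property reads self.tokenizer while assembling the pipeline, so the first access to .model transparently triggers the tokenizer's lazy load as well. A toy illustration with hypothetical names:

class Chained:
    def __init__(self) -> None:
        self._tokenizer = None
        self._pipeline = None

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            print("load tokenizer")
            self._tokenizer = "tok"
        return self._tokenizer

    @property
    def model(self):
        if self._pipeline is None:
            print("build pipeline")
            # Reading self.tokenizer here loads it on demand.
            self._pipeline = ("pipe", self.tokenizer)
        return self._pipeline

c = Chained()
c.model  # prints "build pipeline", then "load tokenizer"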
