diff --git a/README.md b/README.md index 69b3e2f..056080c 100644 --- a/README.md +++ b/README.md @@ -196,8 +196,11 @@ model_params: wbits: 4 cuda_visible_devices: "0" device: "cuda:0" + st_device: 0 ``` +**Note**: `st_device` is only needed for safetensors models; otherwise you can either remove it or set it to `-1` + Example request: ``` diff --git a/app/llms/gptq_llama/gptq_llama.py b/app/llms/gptq_llama/gptq_llama.py index 51c2097..025e81f 100644 --- a/app/llms/gptq_llama/gptq_llama.py +++ b/app/llms/gptq_llama/gptq_llama.py @@ -91,6 +91,7 @@ def __init__(self, params: Dict[str, str]) -> None: wbits = params.get("wbits", 4) cuda_visible_devices = params.get("cuda_visible_devices", "0") dev = params.get("device", "cuda:0") + st_device = params.get("st_device", -1) os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices self.device = torch.device(dev) @@ -99,11 +100,13 @@ def __init__(self, params: Dict[str, str]) -> None: model_path, wbits, group_size, - cuda_visible_devices, + st_device, ) self.model.to(self.device) - self.tokenizer = AutoTokenizer.from_pretrained(self.model, use_fast=False) + self.tokenizer = AutoTokenizer.from_pretrained( + settings.setup_params["repo_id"], use_fast=False + ) def _load_quant( self, model, checkpoint, wbits, groupsize, device diff --git a/config.yaml b/config.yaml index 69ed68c..b246ead 100644 --- a/config.yaml +++ b/config.yaml @@ -8,3 +8,4 @@ model_params: wbits: 4 cuda_visible_devices: "0" device: "cuda:0" + st_device: 0