
Commit b668b59
add a separate parameter for safetensors models
1b5d committed May 5, 2023
1 parent 1b2037b commit b668b59
Showing 3 changed files with 9 additions and 2 deletions.
README.md: 3 additions & 0 deletions

@@ -196,8 +196,11 @@ model_params:
   wbits: 4
   cuda_visible_devices: "0"
   device: "cuda:0"
+  st_device: 0
 ```
 
+**Note**: `st_device` is only needed in the case of a safetensors model; otherwise you can either remove it or set it to `-1`
+
 Example request:
 
 ```
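The new `st_device` key presumably selects the CUDA device that safetensors weights are mapped onto at load time, while `device` keeps selecting where the model runs; the code below passes it into `_load_quant` in place of `cuda_visible_devices`. A minimal, hypothetical sketch of how such a parameter is typically consumed (`load_checkpoint` is an illustrative name, not the repo's actual code):

```python
import torch
from safetensors.torch import load_file

def load_checkpoint(checkpoint: str, st_device: int = -1) -> dict:
    """Return a state dict, honoring an `st_device`-style parameter.

    st_device >= 0: the checkpoint is a .safetensors file whose tensors
    should be mapped straight onto that CUDA device; -1 means a regular
    torch-pickled checkpoint loaded the usual way.
    """
    if st_device >= 0:
        # safetensors can place tensors on the target device at load time
        return load_file(checkpoint, device=f"cuda:{st_device}")
    return torch.load(checkpoint, map_location="cpu")
```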
app/llms/gptq_llama/gptq_llama.py: 5 additions & 2 deletions

@@ -91,6 +91,7 @@ def __init__(self, params: Dict[str, str]) -> None:
         wbits = params.get("wbits", 4)
         cuda_visible_devices = params.get("cuda_visible_devices", "0")
         dev = params.get("device", "cuda:0")
+        st_device = params.get("st_device", -1)
 
         os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
         self.device = torch.device(dev)
@@ -99,11 +100,13 @@ def __init__(self, params: Dict[str, str]) -> None:
             model_path,
             wbits,
             group_size,
-            cuda_visible_devices,
+            st_device,
         )
 
         self.model.to(self.device)
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model, use_fast=False)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            settings.setup_params["repo_id"], use_fast=False
+        )
 
     def _load_quant(
         self, model, checkpoint, wbits, groupsize, device
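Two things change here: `st_device` now reaches `_load_quant` (whose last parameter is named `device`), and the tokenizer is loaded from `settings.setup_params["repo_id"]` rather than from the model object. That second change is a bug fix independent of safetensors: `AutoTokenizer.from_pretrained` expects a Hugging Face repo id or a local path, not a loaded model instance. A minimal sketch of the corrected call, with an illustrative repo id that is not from the commit:

```python
from transformers import AutoTokenizer

# AutoTokenizer.from_pretrained resolves a Hugging Face repo id or a local
# directory; the old code passed the loaded model object, which it cannot
# resolve. "decapoda-research/llama-7b-hf" is illustrative only.
tokenizer = AutoTokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf", use_fast=False
)
```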
config.yaml: 1 addition & 0 deletions

@@ -8,3 +8,4 @@ model_params:
   wbits: 4
   cuda_visible_devices: "0"
   device: "cuda:0"
+  st_device: 0
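Per the README note above, a non-safetensors setup would either drop the new key or pin it to `-1`. A minimal sketch of that variant, using only the keys shown in this hunk:

```yaml
model_params:
  wbits: 4
  cuda_visible_devices: "0"
  device: "cuda:0"
  st_device: -1  # not a safetensors model: disables safetensors device mapping
```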
