Consider refactor for model types #6

Open
darth-veitcher opened this issue Apr 26, 2023 · 0 comments

Comments

@darth-veitcher

At the moment, the LLMs and their associated inference/embeddings code live in class-specific implementations. I'm not sure this is necessary, or DRY, once you start catering for different architectures (e.g. Dolly v2).

Consider refactoring around a common interface driven by the checkpoint's config, i.e. a config = AutoConfig.from_pretrained(path_or_repo) approach. This might allow scaling to different model types without heavy configuration on the user's side or large amounts of boilerplate rewritten for each specific implementation.

E.g., a stub:

import logging

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    PretrainedConfig,
)

logger = logging.getLogger(__name__)


def load_model(path_or_repo: str, **model_kwargs):
    """Load a (tokenizer, model) pair, dispatching on the checkpoint's model_type.

    model_kwargs is forwarded to the model's from_pretrained, e.g.
    load_in_8bit=True, device_map="auto".
    """
    # AutoConfig.from_pretrained returns a PretrainedConfig subclass.
    config: PretrainedConfig = AutoConfig.from_pretrained(path_or_repo)

    if config.model_type == "llama":
        from transformers import LlamaForCausalLM, LlamaTokenizer

        tokenizer = LlamaTokenizer.from_pretrained(path_or_repo)
        model = LlamaForCausalLM.from_pretrained(path_or_repo, **model_kwargs)
    elif config.model_type == "gpt_neox":
        from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

        # Tokenizer and model classes were swapped in the original stub.
        tokenizer = GPTNeoXTokenizerFast.from_pretrained(path_or_repo)
        model = GPTNeoXForCausalLM.from_pretrained(path_or_repo, **model_kwargs)
    else:
        # The type is known here, just not special-cased: fall back to the Auto classes.
        logger.warning(
            "No dedicated loader for model type %r; falling back to AutoModelForCausalLM",
            config.model_type,
        )
        try:
            tokenizer = AutoTokenizer.from_pretrained(path_or_repo)
            model = AutoModelForCausalLM.from_pretrained(path_or_repo, **model_kwargs)
        except Exception:
            logger.exception("Failed to load %s", path_or_repo)
            raise

    return tokenizer, model