Add T5 LM v1.1 encoder (#550)
The encoder shares much of the underlying stack with the decoder. Here
only the encoder is presented as a class.
I have not gone out of my way to strip all decoder-related pieces from
the stack, but things like checkpointing and dropout are stripped.

The author attribution is added to the license header of the T5 model file,
as this seems to be a derivative work. Both are Apache 2.0.

There are a few tests for the various components and 2 tests for the
entire encoder, covering the small and XXL variants. They rely on Hugging Face,
and the models are downloaded on the fly into the cache. The tests
expect the corresponding GGUF files to already be present and available
on the file system.
sogartar authored Nov 20, 2024
1 parent 19a229e commit 9535984
Showing 12 changed files with 1,636 additions and 63 deletions.
59 changes: 59 additions & 0 deletions .github/workflows/ci-sharktank.yml
@@ -70,3 +70,62 @@ jobs:
        if: ${{ !cancelled() }}
        run: |
          pytest -n 4 sharktank/

  test_with_data:
    name: "Data-dependent Tests"
    strategy:
      matrix:
        version: [3.11]
        runs-on: [llama-mi300x-3]
      fail-fast: false
    runs-on: ${{matrix.runs-on}}
    defaults:
      run:
        shell: bash
    env:
      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
      HF_HOME: "/data/huggingface"
      SHARK_PLATFORM_REPO_ROOT: ${{ github.workspace }}
    steps:
      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}

      - name: "Checkout Code"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Cache Pip Packages
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        id: cache-pip
        with:
          path: ${{ env.PIP_CACHE_DIR }}
          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}

      - name: Install sharktank deps
        run: |
          python -m pip install --no-compile --upgrade pip
          # Note: We install in three steps in order to satisfy requirements
          # from non default locations first. Installing the PyTorch CPU
          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
          # Install latest iree-turbine.
          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
          # Try with the latest IREE nightly releases, not what iree-turbine pins.
          # We could also pin to a known working or stable version.
          # This should eventually stabilize. Do the best we can for now.
          pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
            iree-base-compiler \
            iree-base-runtime
      - name: Run tests
        run: |
          pytest \
            --with-t5-data \
            sharktank/tests/models/t5/t5_test.py
42 changes: 42 additions & 0 deletions sharktank/conftest.py
@@ -88,6 +88,16 @@ def pytest_addoption(parser):
help="Enable all llama benchmarking tests",
)

parser.addoption(
"--with-t5-data",
action="store_true",
default=False,
help=(
"Enable tests that use T5 data like models that is not a part of the source "
"code. The user is expected to provide the data"
),
)

    # TODO: Remove all hardcoded paths in CI tests
    parser.addoption(
        "--llama3-8b-tokenizer-path",
@@ -133,6 +143,28 @@ def pytest_addoption(parser):
help="Llama3.1 405b fp8 model path",
)

# To obtain a T5 GGUF file you can use llama.cpp's convert_hf_to_gguf.py.
# https://github.com/ggerganov/llama.cpp/blob/9abe9eeae98b11fa93b82632b264126a010225ff/convert_hf_to_gguf.py
# E.g.
# git lfs install
# git clone https://huggingface.co/google/t5-v1_1-small
# convert_hf_to_gguf.py \
# --outfile t5-v1_1-small.gguf \
# --outtype=f32 \
# t5-v1_1-small
parser.addoption(
"--google-t5-v1-1-small-fp32-model-path",
type=Path,
default="/data/t5/small/google__t5-v1_1-small_fp32.gguf",
help="Google T5 v1.1 small fp32 model path",
)
parser.addoption(
"--google-t5-v1-1-xxl-fp32-model-path",
type=Path,
default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf",
help="Google T5 v1.1 XXL fp32 model path",
)

    parser.addoption(
        "--baseline-perplexity-scores",
        type=Path,
@@ -256,6 +288,16 @@ def get_model_artifacts(request: FixtureRequest):
model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
)
model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-small-fp32-model-path",
"google__t5_v1_1_small_fp32_model",
)
model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option(
request,
"--google-t5-v1-1-xxl-fp32-model-path",
"google__t5_v1_1_xxl_fp32_model",
)
return model_path


73 changes: 71 additions & 2 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -14,11 +14,11 @@
(and indeed, can bootstrap these off of GGUF files).
"""

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Optional
import torch

__all__ = ["LlamaHParams", "LlamaModelConfig"]
__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"]


@dataclass
@@ -179,3 +179,72 @@ class LlamaModelConfig:
    # be the difference of many gigabytes of static data being embedded in
    # the program and not.
    static_tables: bool = True


@dataclass
class T5Config:
    return_dict: bool = True
    output_hidden_states: bool = False
    output_attentions: bool = False
    is_encoder_decoder: bool = True
    is_decoder: bool = False
    vocab_size: int = 32128
    d_model: int = 512
    d_kv: int = 64
    d_ff: int = 2048
    num_layers: int = 6
    num_decoder_layers: int = 6
    num_heads: int = 8
    relative_attention_num_buckets: int = 32
    relative_attention_max_distance: int = 128
    layer_norm_epsilon: float = 1e-6
    feed_forward_proj: str = "relu"
    is_gated_act: bool = field(init=False)
    activation_dtype: torch.dtype = torch.float32
    dense_act_fn: str = field(init=False)
    use_cache: bool = True
    pad_token_id: int = 0
    eos_token_id: int = 1
    decoder_start_token_id: int = 0

    def __post_init__(self):
        self.is_gated_act = self.feed_forward_proj.startswith("gated-")
        self.dense_act_fn = (
            self.feed_forward_proj.split("-")[1]
            if "-" in self.feed_forward_proj
            else self.feed_forward_proj
        )
        if self.dense_act_fn == "gelu":
            self.dense_act_fn = "gelu_new"

    @staticmethod
    def from_gguf_properties(properties: dict[str, Any], **kwargs):
        assert properties["general.architecture"] == "t5"
        assert (
            properties["t5.attention.layer_norm_epsilon"]
            == properties["t5.attention.layer_norm_rms_epsilon"]
        )

        gguf_to_config_names_map = {
            "t5.embedding_length": ["d_model"],
            "t5.feed_forward_length": ["d_ff"],
            "t5.block_count": ["num_layers", "num_decoder_layers"],
            "t5.attention.head_count": ["num_heads"],
            "t5.attention.key_length": ["d_kv"],
            "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
            "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
            "t5.decoder_start_token_id": ["decoder_start_token_id"],
            "tokenizer.ggml.eos_token_id": ["eos_token_id"],
            "tokenizer.ggml.padding_token_id": ["pad_token_id"],
        }
        all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
        all_kwargs.update(
            {
                config_name: properties[gguf_name]
                for gguf_name, config_names in gguf_to_config_names_map.items()
                for config_name in config_names
            }
        )
        all_kwargs.update(kwargs)

        return T5Config(**all_kwargs)
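
For illustration (not part of this change), a minimal sketch of how from_gguf_properties is intended to be used. The property values below are made-up placeholders standing in for metadata that would normally be read from a GGUF file; note that vocab_size and feed_forward_proj are not among the mapped GGUF keys, so they must be supplied as keyword arguments:

from sharktank.layers.configs.llm_configs import T5Config

# Hypothetical GGUF metadata; a real dict would come from the loaded GGUF file.
properties = {
    "general.architecture": "t5",
    "t5.embedding_length": 512,
    "t5.feed_forward_length": 1024,
    "t5.block_count": 8,
    "t5.attention.head_count": 6,
    "t5.attention.key_length": 64,
    "t5.attention.layer_norm_epsilon": 1e-6,
    "t5.attention.layer_norm_rms_epsilon": 1e-6,
    "t5.attention.relative_buckets_count": 32,
    "t5.decoder_start_token_id": 0,
    "tokenizer.ggml.eos_token_id": 1,
    "tokenizer.ggml.padding_token_id": 0,
}

config = T5Config.from_gguf_properties(
    properties,
    vocab_size=32128,
    feed_forward_proj="gated-gelu",
)
assert config.is_gated_act  # derived in __post_init__ from the "gated-" prefix
assert config.dense_act_fn == "gelu_new"  # "gelu" is remapped to "gelu_new"
assert config.num_layers == config.num_decoder_layers == 8  # both come from t5.block_count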
28 changes: 20 additions & 8 deletions sharktank/sharktank/layers/ffn_block.py
@@ -4,11 +4,12 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional
from typing import Optional, Callable

import torch
import torch.nn.functional as F
from .. import ops
from ..types import AnyTensor

from .base import Theta, ThetaLayer
from .linear import LinearLayer
@@ -22,18 +23,29 @@ class FFN(ThetaLayer):
     def __init__(
         self,
         theta: Theta,
+        is_gated: bool = True,
+        activation_fn: Callable[[AnyTensor], AnyTensor] = F.silu,
     ):
         super().__init__(theta)

-        self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
+        self.is_gated = is_gated
+        self.activation_fn = activation_fn
+        if self.is_gated:
+            self.add_module("ffn_gate", LinearLayer(theta("ffn_gate")))
         self.add_module("ffn_up", LinearLayer(theta("ffn_up")))
         self.add_module("ffn_down", LinearLayer(theta("ffn_down")))

     def forward(
         self,
-        h: torch.Tensor,
-    ):
-        ffn_gate = ops.elementwise(F.silu, self.ffn_gate(h))
-        ffn_up = self.ffn_up(h)
-        ffn_down = self.ffn_down(ffn_gate * ffn_up)
-        return ffn_down
+        h: AnyTensor,
+    ) -> AnyTensor:
+        if self.is_gated:
+            ffn_gate = ops.elementwise(self.activation_fn, self.ffn_gate(h))
+            ffn_up = self.ffn_up(h)
+            ffn_down = self.ffn_down(ffn_gate * ffn_up)
+            return ffn_down
+        else:
+            h = self.ffn_up(h)
+            h = ops.elementwise(self.activation_fn, h)
+            h = self.ffn_down(h)
+            return h
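
Similarly for illustration (not part of this change), a minimal sketch of how the generalized FFN layer might be constructed; `theta` is assumed to already hold the ffn_gate/ffn_up/ffn_down weights, and the import path simply follows the file layout:

import torch.nn.functional as F

from sharktank.layers.ffn_block import FFN

# `theta` is assumed to be a sharktank Theta with ffn_gate/ffn_up/ffn_down weights.
gated_silu_ffn = FFN(theta)  # previous hard-coded behavior: ffn_down(silu(ffn_gate(h)) * ffn_up(h))
gated_gelu_ffn = FFN(theta, is_gated=True, activation_fn=F.gelu)  # same gated structure, GELU gate
relu_ffn = FFN(theta, is_gated=False, activation_fn=F.relu)  # ungated: ffn_down(relu(ffn_up(h)))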
