Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced gemma prediction with new flawless logit #51

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a Python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
coverage.xml
*.cover
*.py,cover
.cache
nosetests.xml
coverage/
*.cover
.hypothesis/

# Pytest cache
.pytest_cache/
.cache/

# MyPy cache
.mypy_cache/

# Profiling data
*.lprof
.prof

# Virtual environment directories
venv/
ENV/
env/
.venv/
env.bak/
venv.bak/

# Jupyter Notebook checkpoints
.ipynb_checkpoints

# pyenv
.python-version

# Editor directories and files
.vscode/
.idea/
*.swp
*.swo
*~

# macOS system files
.DS_Store

# Temporary files
*.tmp
*.log
*.bak
*.orig

# Local development overrides
.local/
.env

# Docker-related files
docker-compose.override.yml

# Poetry-specific files
poetry.lock
174 changes: 174 additions & 0 deletions examples/mmlu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
r"""An example showing how to load a checkpoint and sample from it.

Getting Started with Gemma Sampling:

Prerequisites:

1. Download your Gemma checkpoint: Choose the desired checkpoint and download it.
2. Get the Gemma tokenizer: Download the tokenizer file required for your model.
3. Install Gemma: Follow the straightforward instructions in the README to install the Gemma repository.

Ready to Sample!

Here's how to run the sampling.py script:

python mmlu.py --path_checkpoint=${PATH_TO_THE_GEMMA_CHECKPOINT} \
--path_tokenizer=${PATH_TO_THE_GEMMA_TOKENIZER}
"""

import os
import sys
import re
from absl import flags
from absl import app
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib

import sentencepiece as spm
import datasets

# Define flags
FLAGS = flags.FLAGS

_PATH_CHECKPOINT = flags.DEFINE_string(
"path_checkpoint", None, required=True, help="Path to checkpoint."
)
_PATH_TOKENIZER = flags.DEFINE_string(
"path_tokenizer", None, required=True, help="Path to tokenizer."
)
_TOTAL_GENERATION_STEPS = flags.DEFINE_integer(
"total_generation_steps", 1024, help="Maximum number of steps to run when decoding."
)
_PREAMBLE = flags.DEFINE_string(
"preamble",
"The following question is related to machine learning. Please provide a step by step solution to the following question.",
help="Preamble for the prompt.",
)
_PROMPT = flags.DEFINE_string(
"prompt",
"""Q: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.
Subject: abstract_algebra
Choices: [ "0", "1", "2", "3" ]
A: 1

Q: Statement 1 | If aH is an element of a factor group, then |aH| divides |a|. Statement 2 | If H and K are subgroups of G then HK is a subgroup of G.
Subject: abstract_algebra
Choices: [ "True, True", "False, False", "True, False", "False, True" ]
A: 1

Q: Statement 1 | Every element of a group generates a cyclic subgroup of the group. Statement 2 | The symmetric group S_10 has 10 elements.
Subject: abstract_algebra
Choices: [ "True, True", "False, False", "True, False", "False, True" ]
A: 2

Q: Statement 1| Every function from a finite set onto itself must be one to one. Statement 2 | Every subgroup of an abelian group is abelian.
Subject: abstract_algebra
Choices: [ "True, True", "False, False", "True, False", "False, True" ]
A: 0

Q: Find the characteristic of the ring 2Z.
Subject: abstract_algebra
Choices: [ "0", "3", "12", "30" ]
A: 0""",
help="Prompt for the model.",
)

_CACHE_SIZE = 1024

# Load MMLU dataset
mmlu = datasets.load_dataset("cais/mmlu", "machine_learning", cache_dir='/dc/cais_cache')
mmlu_test = mmlu['test']

def _load_and_infer(
*,
path_checkpoint: str,
path_tokenizer: str,
preamble: str,
prompt: str,
total_generation_steps: int,
cache_size: int,
) -> None:
"""Loads and infers a string from a checkpoint."""
print(f"Loading the parameters from {path_checkpoint}")
parameters = params_lib.load_and_format_params(path_checkpoint)
print("Parameters loaded.")

# Create a sampler with the right param shapes.
vocab = spm.SentencePieceProcessor()
vocab.Load(path_tokenizer)
transformer_config = transformer_lib.TransformerConfig.from_params(
parameters,
cache_size=cache_size
)
transformer = transformer_lib.Transformer(transformer_config)
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=parameters["transformer"],
)

TEMPLATE = """
Q: {question}
Subject: {subject}
Choices: {choices}
A:"""

all_correct = 0
all_responses = {}
short_responses = {}
idx = 0
correct = 0

for task_id, problem in enumerate(mmlu_test):

if task_id in all_responses:
continue

# Print Task ID
print(f"task_id {task_id}")

# Formulate and print the full prompt
full_prompt = (preamble + '\n\n' + prompt + '\n' +
TEMPLATE.format(question=problem['question'],
subject=problem['subject'],
choices=problem['choices']))
short_prompt = preamble + '\n' + TEMPLATE.format(question=problem['question'],
subject=problem['subject'],
choices=problem['choices'])

input_batch = [full_prompt]
response = sampler(input_strings=input_batch, total_generation_steps=total_generation_steps)
print(response.text)

all_responses[task_id] = response.text[0].split('\nQ:')[0]
short_responses[task_id] = all_responses[task_id].strip()
print(f"Short answer: {short_responses[task_id]}")

try:
correct += int(problem['answer']) == int(short_responses[task_id])
except ValueError:
correct += problem['answer'] == short_responses[task_id]

print('-'*40)
print(f"Ground truth answer {problem['answer']}")
print(f"Short ground truth answer {problem['answer']}")
print(f"Correct: {correct} out of {idx+1}")
print("="*40)
idx += 1

def main(argv):
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

_load_and_infer(
path_checkpoint=_PATH_CHECKPOINT.value,
path_tokenizer=_PATH_TOKENIZER.value,
preamble=_PREAMBLE.value,
prompt=_PROMPT.value,
total_generation_steps=_TOTAL_GENERATION_STEPS.value,
cache_size=_CACHE_SIZE,
)

if __name__ == "__main__":
app.run(main)
21 changes: 17 additions & 4 deletions gemma/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,13 +313,26 @@ def __call__(
cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch

x = self.final_norm(x)
logits = self.embedder.decode(x)
baseline_logits = self.embedder.decode(x)

# Compute logits for each token in isolation
single_token_logits = []
for token in last_tokens:
single_token_x = self.embedder.encode(token.reshape(1, -1))
single_token_x = self.final_norm(single_token_x)
single_token_logits.append(self.embedder.decode(single_token_x))

# Normalize and adjust
normalized_single_token_logits = jax.nn.softmax(jnp.stack(single_token_logits), axis=-1)
normalized_sum = jnp.sum(normalized_single_token_logits, axis=0)
adjusted_logits = baseline_logits - normalized_sum

# Apply softcap if configured
if self.config.final_logit_softcap is not None:
logits /= self.config.final_logit_softcap
logits = jnp.tanh(logits) * self.config.final_logit_softcap
adjusted_logits /= self.config.final_logit_softcap
adjusted_logits = np.tanh(adjusted_logits) * self.config.final_logit_softcap

return logits, cache # pytype: disable=bad-return-type
return adjusted_logits, cache # pytype: disable=bad-return-type


def make_causal_attn_mask(
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ pytest = {version = "^8.0.0", optional = true}

[tool.poetry.extras]
test = ["pytest"]
scikit-learn = ["scikit-learn"]
huggingface_hub = ["huggingface_hub"]
datasets = ["datasets"]
jupyterlab = ["jupyterlab"]

[build-system]
requires = ["poetry-core"]
Expand Down