From be9945e8f36554f6210def772afe4ca16a2ad090 Mon Sep 17 00:00:00 2001 From: Carlo Fisicaro Date: Wed, 18 Sep 2024 14:18:21 +0200 Subject: [PATCH 01/17] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d7d19bd..9ae7054 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gemma +# Gemma - CUSTOM [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini From ce1afcc6601fb0f1148477c783503eccdd9a6f36 Mon Sep 17 00:00:00 2001 From: Carlo Fisicaro Date: Thu, 19 Sep 2024 03:11:35 +0200 Subject: [PATCH 02/17] Create .gitignore --- .gitignore | 102 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d13be0f --- /dev/null +++ b/.gitignore @@ -0,0 +1,102 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a Python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +coverage.xml +*.cover +*.py,cover +.cache +nosetests.xml +coverage/ +*.cover +.hypothesis/ + +# Pytest cache +.pytest_cache/ +.cache/ + +# MyPy cache +.mypy_cache/ + +# Profiling data +*.lprof +.prof + +# Virtual environment directories +venv/ +ENV/ +env/ +.venv/ +env.bak/ +venv.bak/ + +# Jupyter Notebook checkpoints +.ipynb_checkpoints + +# pyenv +.python-version + +# Editor directories and files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# macOS system files +.DS_Store + +# Temporary files +*.tmp +*.log +*.bak +*.orig + +# Local development overrides +.local/ +.env + +# Docker-related files +docker-compose.override.yml + +# Poetry-specific files +poetry.lock From 0251732d93acdebc3cb303e770d4f26c954f0b77 Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Thu, 19 Sep 2024 16:27:21 +0000 Subject: [PATCH 03/17] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9ae7054..21a8d75 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gemma - CUSTOM +# Gemma - CUSTOM MODEL1 [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini From f0f147ed303cf77711012472915c822b032ecc44 Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Thu, 19 Sep 2024 18:15:46 +0000 Subject: [PATCH 04/17] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 21a8d75..e6cd5f8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gemma - CUSTOM MODEL1 +# Gemma - CUSTOM MODEL2 [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini From 9063a3d678861763033ef86345b9c052edbf905f Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Thu, 19 Sep 2024 18:49:26 +0000 Subject: [PATCH 05/17] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e6cd5f8..e924811 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gemma - CUSTOM MODEL2 +# Gemma - CUSTOM MODEL [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini From e87d5cb22a4f7b6377120018f3b266085659d082 Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Thu, 19 Sep 2024 18:53:37 +0000 Subject: [PATCH 06/17] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e924811..7a9b2ce 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gemma - CUSTOM MODEL +# Gemma - CUSTOM MODEL 3 [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini From 008ff127ec485d66457856725a4116a34d0dd728 Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Thu, 19 Sep 2024 18:54:04 +0000 Subject: [PATCH 07/17] update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7a9b2ce..d7d19bd 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Gemma - CUSTOM MODEL 3 +# Gemma [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini From 220b25a01e89e21d7420c04fa4b76e3ff7ba786c Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Sat, 21 Sep 2024 00:28:11 +0000 Subject: [PATCH 08/17] feat: add MMLU benchmarking --- examples/mmlu.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 examples/mmlu.py diff --git a/examples/mmlu.py b/examples/mmlu.py new file mode 100644 index 0000000..e5458d7 --- /dev/null +++ b/examples/mmlu.py @@ -0,0 +1,69 @@ +import os +import sentencepiece as spm +from datasets import load_dataset +from sklearn.metrics import accuracy_score +from gemma import params as params_lib +from gemma import sampler as sampler_lib +from gemma import transformer as transformer_lib +from tqdm import tqdm + + +# Paths to the model and tokenizer +variant = '7b' +weights_dir = "/dc/gemma_models_7b/" +checkpoint_path = os.path.join(weights_dir, variant) +tokenizer_path = os.path.join(weights_dir, 'tokenizer.model') + +# Load the parameters +parameters = params_lib.load_and_format_params(checkpoint_path) + +# Load the tokenizer +vocab = spm.SentencePieceProcessor() +vocab.Load(tokenizer_path) + +# Create the transformer configuration and model +transformer_config = transformer_lib.TransformerConfig.from_params(parameters) +transformer = transformer_lib.Transformer(transformer_config) + +# Create the sampler +sampler = sampler_lib.Sampler( + transformer=transformer, + vocab=vocab, + params=parameters["transformer"], +) + +# List of available configurations +configs = [ + 'machine_learning', +] + +# Evaluate the model on a specific configuration of the MMLU dataset +def evaluate_model_on_config(sampler, config): + dataset = load_dataset("cais/mmlu", config) + predictions = [] + references = [] + + for example in tqdm(dataset['test'], desc=f"Processing {config}"): + question = example['question'] + choices = example['choices'] + reference = example['answer'] + + # Sample the output + sampled_str = sampler( + input_strings=[question], + total_generation_steps=100 # Adjust as needed + ).text[0] + + # Find the choice that matches the sampled output + best_choice = max(choices, key=lambda choice: sampled_str in choice) + predictions.append(best_choice) + references.append(reference) + + # Calculate accuracy + accuracy = accuracy_score(references, predictions) + return accuracy + +# Evaluate the model on all configurations +for config in configs: + accuracy = evaluate_model_on_config(sampler, config) + print(f"Model accuracy on {config} configuration: {accuracy:.2f}") \ No newline at end of file From c5f05fed5e8517f61491d8ed9be7bda991621d13 Mon Sep 17 00:00:00 2001 From: Carlo Fisicaro Date: Sat, 21 Sep 2024 05:06:01 +0200 Subject: [PATCH 09/17] Update mmlu.py --- examples/mmlu.py | 219 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 162 insertions(+), 57 deletions(-) diff --git a/examples/mmlu.py b/examples/mmlu.py index e5458d7..84e2ee8 100644 --- a/examples/mmlu.py +++ b/examples/mmlu.py @@ -1,69 +1,174 @@ +r"""An example showing how to load a checkpoint and sample from it. + +Getting Started with Gemma Sampling: + +Prerequisites: + +1. Download your Gemma checkpoint: Choose the desired checkpoint and download it. +2. Get the Gemma tokenizer: Download the tokenizer file required for your model. +3. Install Gemma: Follow the straightforward instructions in the README to install the Gemma repository. + +Ready to Sample! + +Here's how to run the sampling.py script: + +python mmlu.py --path_checkpoint=${PATH_TO_THE_GEMMA_CHECKPOINT} \ + --path_tokenizer=${PATH_TO_THE_GEMMA_TOKENIZER} +""" + import os -import sentencepiece as spm -from datasets import load_dataset -from sklearn.metrics import accuracy_score +import sys +import re +from absl import flags +from absl import app from gemma import params as params_lib from gemma import sampler as sampler_lib from gemma import transformer as transformer_lib -from tqdm import tqdm +import sentencepiece as spm +import datasets -# Paths to the model and tokenizer -variant = '7b' -weights_dir = "/dc/gemma_models_7b/" -checkpoint_path = os.path.join(weights_dir, variant) -tokenizer_path = os.path.join(weights_dir, 'tokenizer.model') +# Define flags +FLAGS = flags.FLAGS -# Load the parameters -parameters = params_lib.load_and_format_params(checkpoint_path) +_PATH_CHECKPOINT = flags.DEFINE_string( + "path_checkpoint", None, required=True, help="Path to checkpoint." +) +_PATH_TOKENIZER = flags.DEFINE_string( + "path_tokenizer", None, required=True, help="Path to tokenizer." +) +_TOTAL_GENERATION_STEPS = flags.DEFINE_integer( + "total_generation_steps", 1024, help="Maximum number of steps to run when decoding." +) +_PREAMBLE = flags.DEFINE_string( + "preamble", + "The following question is related to machine learning. Please provide a step by step solution to the following question.", + help="Preamble for the prompt.", +) +_PROMPT = flags.DEFINE_string( + "prompt", + """Q: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field. +Subject: abstract_algebra +Choices: [ "0", "1", "2", "3" ] +A: 1 + +Q: Statement 1 | If aH is an element of a factor group, then |aH| divides |a|. Statement 2 | If H and K are subgroups of G then HK is a subgroup of G. +Subject: abstract_algebra +Choices: [ "True, True", "False, False", "True, False", "False, True" ] +A: 1 -# Load the tokenizer -vocab = spm.SentencePieceProcessor() -vocab.Load(tokenizer_path) +Q: Statement 1 | Every element of a group generates a cyclic subgroup of the group. Statement 2 | The symmetric group S_10 has 10 elements. +Subject: abstract_algebra +Choices: [ "True, True", "False, False", "True, False", "False, True" ] +A: 2 -# Create the transformer configuration and model -transformer_config = transformer_lib.TransformerConfig.from_params(parameters) -transformer = transformer_lib.Transformer(transformer_config) +Q: Statement 1| Every function from a finite set onto itself must be one to one. Statement 2 | Every subgroup of an abelian group is abelian. +Subject: abstract_algebra +Choices: [ "True, True", "False, False", "True, False", "False, True" ] +A: 0 -# Create the sampler -sampler = sampler_lib.Sampler( - transformer=transformer, - vocab=vocab, - params=parameters["transformer"], +Q: Find the characteristic of the ring 2Z. +Subject: abstract_algebra +Choices: [ "0", "3", "12", "30" ] +A: 0""", + help="Prompt for the model.", ) -# List of available configurations -configs = [ - 'machine_learning', -] - -# Evaluate the model on a specific configuration of the MMLU dataset -def evaluate_model_on_config(sampler, config): - dataset = load_dataset("cais/mmlu", config) - predictions = [] - references = [] - - for example in tqdm(dataset['test'], desc=f"Processing {config}"): - question = example['question'] - choices = example['choices'] - reference = example['answer'] - - # Sample the output - sampled_str = sampler( - input_strings=[question], - total_generation_steps=100 # Adjust as needed - ).text[0] - - # Find the choice that matches the sampled output - best_choice = max(choices, key=lambda choice: sampled_str in choice) - predictions.append(best_choice) - references.append(reference) - - # Calculate accuracy - accuracy = accuracy_score(references, predictions) - return accuracy - -# Evaluate the model on all configurations -for config in configs: - accuracy = evaluate_model_on_config(sampler, config) - print(f"Model accuracy on {config} configuration: {accuracy:.2f}") \ No newline at end of file +_CACHE_SIZE = 1024 + +# Load MMLU dataset +mmlu = datasets.load_dataset("cais/mmlu", "machine_learning", cache_dir='/dc/cais_cache') +mmlu_test = mmlu['test'] + +def _load_and_infer( + *, + path_checkpoint: str, + path_tokenizer: str, + preamble: str, + prompt: str, + total_generation_steps: int, + cache_size: int, +) -> None: + """Loads and infers a string from a checkpoint.""" + print(f"Loading the parameters from {path_checkpoint}") + parameters = params_lib.load_and_format_params(path_checkpoint) + print("Parameters loaded.") + + # Create a sampler with the right param shapes. + vocab = spm.SentencePieceProcessor() + vocab.Load(path_tokenizer) + transformer_config = transformer_lib.TransformerConfig.from_params( + parameters, + cache_size=cache_size + ) + transformer = transformer_lib.Transformer(transformer_config) + sampler = sampler_lib.Sampler( + transformer=transformer, + vocab=vocab, + params=parameters["transformer"], + ) + + TEMPLATE = """ + Q: {question} + Subject: {subject} + Choices: {choices} + A:""" + + all_correct = 0 + all_responses = {} + short_responses = {} + idx = 0 + correct = 0 + + for task_id, problem in enumerate(mmlu_test): + + if task_id in all_responses: + continue + + # Print Task ID + print(f"task_id {task_id}") + + # Formulate and print the full prompt + full_prompt = (preamble + '\n\n' + prompt + '\n' + + TEMPLATE.format(question=problem['question'], + subject=problem['subject'], + choices=problem['choices'])) + short_prompt = preamble + '\n' + TEMPLATE.format(question=problem['question'], + subject=problem['subject'], + choices=problem['choices']) + + input_batch = [full_prompt] + response = sampler(input_strings=input_batch, total_generation_steps=total_generation_steps) + print(response.text) + + all_responses[task_id] = response.text[0].split('\nQ:')[0] + short_responses[task_id] = all_responses[task_id].strip() + print(f"Short answer: {short_responses[task_id]}") + + try: + correct += int(problem['answer']) == int(short_responses[task_id]) + except ValueError: + correct += problem['answer'] == short_responses[task_id] + + print('-'*40) + print(f"Ground truth answer {problem['answer']}") + print(f"Short ground truth answer {problem['answer']}") + print(f"Correct: {correct} out of {idx+1}") + print("="*40) + idx += 1 + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + _load_and_infer( + path_checkpoint=_PATH_CHECKPOINT.value, + path_tokenizer=_PATH_TOKENIZER.value, + preamble=_PREAMBLE.value, + prompt=_PROMPT.value, + total_generation_steps=_TOTAL_GENERATION_STEPS.value, + cache_size=_CACHE_SIZE, + ) + +if __name__ == "__main__": + app.run(main) From 2ab0824e8c3ee9c92c8f95501c2f41418d60788c Mon Sep 17 00:00:00 2001 From: Gemma Team Date: Thu, 15 Aug 2024 05:55:37 -0700 Subject: [PATCH 10/17] Fixed tokens to string method in tokenizer. PiperOrigin-RevId: 663277444 Change-Id: I8d7030ce586577a433c48f32df7efa7c141b171a --- colabs/fine_tuning_tutorial.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colabs/fine_tuning_tutorial.ipynb b/colabs/fine_tuning_tutorial.ipynb index 9bfc9b4..7543c74 100644 --- a/colabs/fine_tuning_tutorial.ipynb +++ b/colabs/fine_tuning_tutorial.ipynb @@ -299,7 +299,7 @@ "\n", " def to_string(self, tokens: jax.Array) -\u003e str:\n", " \"\"\"Convert an array of tokens to a string.\"\"\"\n", - " return self._spm_processor.EncodeIds(tokens.tolist())" + " return self._spm_processor.DecodeIds(tokens.tolist())" ] }, { From c706509dc3c87d891a6be721d9f0eabce7cccd82 Mon Sep 17 00:00:00 2001 From: Andreas Steiner Date: Fri, 16 Aug 2024 05:05:39 -0700 Subject: [PATCH 11/17] Aligns meaning of `_compute_attention_masks(input_mask)` with `transformer_lib.make_causal_attn_mask(input_mask)` PiperOrigin-RevId: 663692225 Change-Id: Ie2cb6229302087ea1ce5b5c7f442a088207ead07 --- gemma/sampler.py | 12 ++++++------ gemma/sampler_test.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/gemma/sampler.py b/gemma/sampler.py index 4f111c9..a511cdd 100644 --- a/gemma/sampler.py +++ b/gemma/sampler.py @@ -35,7 +35,7 @@ def _compute_attention_masks( """Computes causal attention mask.""" bsz = input_mask.shape[0] batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32) - causal_padding = jnp.greater( + causal_mask = jnp.less_equal( jnp.expand_dims(jnp.arange(seq_len), 0), batch_time_step ) max_seq_len = min(input_mask.shape[-1], seq_len) @@ -45,15 +45,15 @@ def _compute_attention_masks( (bsz, max_seq_len), ) input_mask = ( - jnp.zeros((bsz, seq_len), dtype=jnp.bool_) + jnp.ones((bsz, seq_len), dtype=jnp.bool_) .at[:, :max_seq_len] .set(input_mask) ) - causal_padding = jnp.logical_or(causal_padding, input_mask) - attention_mask = causal_padding[:, jnp.newaxis, :].astype(jnp.bool_) + causal_mask = jnp.logical_and(causal_mask, input_mask) + attention_mask = causal_mask[:, jnp.newaxis, :].astype(jnp.bool_) - return ~attention_mask + return attention_mask @chex.dataclass @@ -133,7 +133,7 @@ def _sample_step( batch_size = sampler_state.token_buffer.shape[0] decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32) last_token = sampler_state.token_buffer[:, decoding_step] - input_mask = sampler_state.token_buffer == self.vocab.pad_id() + input_mask = sampler_state.token_buffer != self.vocab.pad_id() attention_mask = _compute_attention_masks( decoding_step, self.transformer.config.max_cache_length, input_mask ) diff --git a/gemma/sampler_test.py b/gemma/sampler_test.py index ab8b1a0..1b0ddd8 100644 --- a/gemma/sampler_test.py +++ b/gemma/sampler_test.py @@ -320,7 +320,7 @@ def test_sampler_mask_tokens_after_eos_ids(self): def test_compute_attention_mask(self): # Check that the input mask is correctly applied when total sampling steps # is lower than the max cache length. - input_mask = jnp.array([[1, 1, 0, 0, 0], [1, 1, 0, 1, 0]], dtype=jnp.bool_) + input_mask = jnp.array([[0, 0, 1, 1, 1], [0, 0, 1, 0, 1]], dtype=jnp.bool_) seq_len = 8 time_step = jnp.asarray(4, dtype=jnp.int32) attn_mask = sampler_lib._compute_attention_masks( From fe438606090d61d39c6373722f454c4235be8d42 Mon Sep 17 00:00:00 2001 From: Gemma Team Date: Tue, 20 Aug 2024 10:12:05 -0700 Subject: [PATCH 12/17] Fix test_sampler_mask_tokens_after_eos_ids test. PiperOrigin-RevId: 665414923 Change-Id: I42bc41074518e3065f85c7f1a3014fdd09cffe4c --- gemma/sampler_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gemma/sampler_test.py b/gemma/sampler_test.py index 1b0ddd8..9270086 100644 --- a/gemma/sampler_test.py +++ b/gemma/sampler_test.py @@ -315,7 +315,7 @@ def test_sampler_mask_tokens_after_eos_ids(self): ) self.assertListEqual(list(masked_token_buffer[0]), [1, 5, 6, 2, 0, 0]) - self.assertListEqual(list(masked_token_buffer[0]), [1, 5, 6, 2, 0, 0]) + self.assertListEqual(list(masked_token_buffer[1]), [1, 3, 4, 2, 0, 0]) def test_compute_attention_mask(self): # Check that the input mask is correctly applied when total sampling steps From 5b504a7d7fa1ffebb45962d63ede4a06cb3d524a Mon Sep 17 00:00:00 2001 From: Gemma Team Date: Fri, 13 Sep 2024 08:24:05 -0700 Subject: [PATCH 13/17] Fix Feedforward init to enable learning when training from scratch. Currently all weights in FeedForward layers are initialized to zero. This doesn't cause any issues when loading the module with pretrained weights, but if training from scratch it will result in all gradients being zero throughout training so no learning can occur. Changing w_gating be be initialized from a normal distribution fixes this. PiperOrigin-RevId: 674306730 Change-Id: I90800dbe605cdf88f341d103f102357ff278a393 --- gemma/modules.py | 4 ++-- gemma/modules_test.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/gemma/modules.py b/gemma/modules.py index cd77889..d720fbf 100644 --- a/gemma/modules.py +++ b/gemma/modules.py @@ -218,14 +218,14 @@ def __call__(self, x): if self.transpose_gating_einsum: w_gating = self.param( 'gating_einsum', - nn.initializers.zeros_init(), + nn.initializers.normal(), ((2, self.hidden_dim, self.features)), ) w_gating = w_gating.transpose((0, 2, 1)) else: w_gating = self.param( 'gating_einsum', - nn.initializers.zeros_init(), + nn.initializers.normal(), ((2, self.features, self.hidden_dim)), ) ff_gate = jnp.dot(x, w_gating[0]) diff --git a/gemma/modules_test.py b/gemma/modules_test.py index 0ad408a..257cd72 100644 --- a/gemma/modules_test.py +++ b/gemma/modules_test.py @@ -284,6 +284,40 @@ def test_ffw(self, transpose_gating_einsum: bool): np.testing.assert_array_almost_equal(outputs[:, 0, 0], expected_val) self.assertEqual(outputs.shape, expected_shape) + @parameterized.parameters( + dict( + transpose_gating_einsum=False, + expected_grad=[-1.916515e-04, -5.391428e-05, -2.923766e-04], + ), + dict( + transpose_gating_einsum=True, + expected_grad=[1.574128e-05, -1.301362e-04, -1.037612e-04], + ), + ) + def test_ffw_grad(self, transpose_gating_einsum: bool, + expected_grad: list[float]): + features = 2 + hidden_dim = 3 + batch_size = 2 + inputs = jnp.arange(1, batch_size + 1)[:, None, None] + inputs = jnp.repeat(inputs, features, axis=-1) + ffw = modules.FeedForward( + features=features, + hidden_dim=hidden_dim, + transpose_gating_einsum=transpose_gating_einsum, + ) + loss = lambda params, inputs: jnp.square( + ffw.apply(params, inputs) - jnp.ones((batch_size, 1, features)) + ).mean() + + params = ffw.init(jax.random.PRNGKey(0), inputs) + + grad_loss = jax.grad(loss) + grad = grad_loss(params, inputs) + np.testing.assert_array_almost_equal( + grad['params']['linear'][:, 0], expected_grad + ) + class BlockTest(absltest.TestCase): From c13384a0a812980a5d9b56801d54da419955168d Mon Sep 17 00:00:00 2001 From: Gemma Team Date: Fri, 13 Sep 2024 12:31:59 -0700 Subject: [PATCH 14/17] Fix a bug in sliding window attention. PiperOrigin-RevId: 674394389 Change-Id: I25ba5ad4769c3101c2bf572e33723d4a241e3895 --- gemma/modules.py | 45 +++++++++++++++++--- gemma/modules_test.py | 99 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 6 deletions(-) diff --git a/gemma/modules.py b/gemma/modules.py index d720fbf..28d1a83 100644 --- a/gemma/modules.py +++ b/gemma/modules.py @@ -25,6 +25,37 @@ LayerCache = dict[str, jax.Array] +def _create_sliding_mask( + segment_pos: jnp.ndarray, + end_index: int, + cache_len: int, + sliding_window_size: int, +): + """Creates mask for sliding window attention.""" + total_tokens = end_index + segment_pos.shape[1] # cached + processing tokens + + def _reconstruct_rotated_cache_positions(): + cache_positions = jnp.arange(cache_len) + total_tokens - cache_len + cache_positions = ( + jnp.zeros_like(cache_positions) + # kv were placed at index (position_id % cache_len) in the cache. + .at[cache_positions % cache_len].set(cache_positions) + ) + return cache_positions + + # Reconstruct position_ids for cached kv. + cache_positions = jax.lax.cond( + total_tokens <= cache_len, + lambda: jnp.arange(cache_len), + _reconstruct_rotated_cache_positions, + ) + + segment_pos = segment_pos[:, :, None] + sliding_mask = (cache_positions > segment_pos - sliding_window_size) + sliding_mask *= (cache_positions < segment_pos + sliding_window_size) + return sliding_mask + + class AttentionType(enum.Enum): GLOBAL = 1 LOCAL_SLIDING = 2 @@ -150,12 +181,14 @@ def __call__( raise ValueError( 'Sliding_window_size must be set if Local Sliding attention type' ) - - all_ones = jnp.ones_like(attn_mask) - sliding_mask = jnp.triu( - all_ones, -1 * self.sliding_window_size + 1 - ) * jnp.tril(all_ones, self.sliding_window_size - 1) - attn_mask = sliding_mask * attn_mask + sliding_mask = _create_sliding_mask( + segment_pos, + end_index=cache['end_index'][0] if cache is not None else 0, + # Derive cache length from attn_mask shape in case cache is None + cache_len=attn_mask.shape[-1], + sliding_window_size=self.sliding_window_size, + ) + attn_mask *= sliding_mask padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK) probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype) diff --git a/gemma/modules_test.py b/gemma/modules_test.py index 257cd72..75bae30 100644 --- a/gemma/modules_test.py +++ b/gemma/modules_test.py @@ -54,6 +54,105 @@ def test_decodes(self): np.testing.assert_array_equal(output, jnp.array(expected)) +class SlidingWindowTest(absltest.TestCase): + + def test_create_sliding_mask_decode_none_rotated_cache_pos(self): + cache_len = 4 + end_index = 1 + segment_pos = jnp.array([[1]]) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=1 + ) + np.testing.assert_array_equal( + sliding_mask, + [[[False, True, False, False]]], + ) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=2 + ) + np.testing.assert_array_equal( + sliding_mask, + [[[True, True, True, False]]], + ) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=3 + ) + np.testing.assert_array_equal( + sliding_mask, + [[[True, True, True, True]]], + ) + + def test_create_sliding_mask_decode_rotated_cache_pos(self): + cache_len = 4 + end_index = 5 + segment_pos = jnp.array([[5]]) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=1 + ) + np.testing.assert_array_equal( + sliding_mask, + # cache_positions = [ + # 4, 5, 2, 3, + # ] + [[[False, True, False, False]]], + ) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=2 + ) + np.testing.assert_array_equal( + sliding_mask, + [[[True, True, False, False]]], + ) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=3 + ) + np.testing.assert_array_equal( + sliding_mask, + [[[True, True, False, True]]], + ) + + def test_create_sliding_mask_prefill_rotated_cache_pos(self): + cache_len = 4 + end_index = 5 + segment_pos = jnp.array([[5, 6]]) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=1 + ) + np.testing.assert_array_equal( + sliding_mask, + # cache_positions = [ + # 4, 5, 6, 3, + # ] + [[[False, True, False, False], + [False, False, True, False],]], + ) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=2 + ) + np.testing.assert_array_equal( + sliding_mask, + [[[True, True, True, False], + [False, True, True, False],]], + ) + + sliding_mask = modules._create_sliding_mask( + segment_pos, end_index, cache_len, sliding_window_size=3 + ) + np.testing.assert_array_equal( + sliding_mask, + [[[True, True, True, True], + [True, True, True, False],]], + ) + + class AttentionTest(absltest.TestCase): def _get_attn_output( From 2e62333713e826d12063ad4facab891df0adb9cd Mon Sep 17 00:00:00 2001 From: Gemma Team Date: Mon, 16 Sep 2024 09:29:07 -0700 Subject: [PATCH 15/17] Explicitly promote rank when creating sliding mask, as some tests raise errors for implicit rank promotion. PiperOrigin-RevId: 675179053 Change-Id: I55459c1aa99c7d33ae3f03712eaed01ccc5fc9f2 --- gemma/modules.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gemma/modules.py b/gemma/modules.py index 28d1a83..b901386 100644 --- a/gemma/modules.py +++ b/gemma/modules.py @@ -50,7 +50,8 @@ def _reconstruct_rotated_cache_positions(): _reconstruct_rotated_cache_positions, ) - segment_pos = segment_pos[:, :, None] + cache_positions = cache_positions[None, None, :] # [1, 1, cache_len] + segment_pos = segment_pos[:, :, None] # [B, seq_len, 1] sliding_mask = (cache_positions > segment_pos - sliding_window_size) sliding_mask *= (cache_positions < segment_pos + sliding_window_size) return sliding_mask From b58703c3e19f0f4dca45a74da5e12133117d222e Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Sat, 21 Sep 2024 19:44:02 +0000 Subject: [PATCH 16/17] build: add dependeces --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 0998a95..dbba979 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,10 @@ pytest = {version = "^8.0.0", optional = true} [tool.poetry.extras] test = ["pytest"] +scikit-learn = ["scikit-learn"] +huggingface_hub = ["huggingface_hub"] +datasets = ["datasets"] +jupyterlab = ["jupyterlab"] [build-system] requires = ["poetry-core"] From dcc5ee4e580e6b39598e246063aa6c881218231e Mon Sep 17 00:00:00 2001 From: carlofisicaro Date: Sat, 21 Sep 2024 19:51:25 +0000 Subject: [PATCH 17/17] feat: add flawless logit --- gemma/transformer.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/gemma/transformer.py b/gemma/transformer.py index c988624..5bd3969 100644 --- a/gemma/transformer.py +++ b/gemma/transformer.py @@ -313,13 +313,26 @@ def __call__( cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch x = self.final_norm(x) - logits = self.embedder.decode(x) + baseline_logits = self.embedder.decode(x) + # Compute logits for each token in isolation + single_token_logits = [] + for token in last_tokens: + single_token_x = self.embedder.encode(token.reshape(1, -1)) + single_token_x = self.final_norm(single_token_x) + single_token_logits.append(self.embedder.decode(single_token_x)) + + # Normalize and adjust + normalized_single_token_logits = jax.nn.softmax(jnp.stack(single_token_logits), axis=-1) + normalized_sum = jnp.sum(normalized_single_token_logits, axis=0) + adjusted_logits = baseline_logits - normalized_sum + + # Apply softcap if configured if self.config.final_logit_softcap is not None: - logits /= self.config.final_logit_softcap - logits = jnp.tanh(logits) * self.config.final_logit_softcap + adjusted_logits /= self.config.final_logit_softcap + adjusted_logits = np.tanh(adjusted_logits) * self.config.final_logit_softcap - return logits, cache # pytype: disable=bad-return-type + return adjusted_logits, cache # pytype: disable=bad-return-type def make_causal_attn_mask(