diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..064817f --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,21 @@ +## Changes + +> What changes does this PR introduce? Link to the issue if applicable. + +## Blocked by + +> Which PRs need to be merged before this PR can be merged? + +## Testing + +> How was this PR tested? Make sure to mark the checkboxes and add more based on your testing. It may be helpful to write a testing plan prior to finishing the PR. + +- [ ] + +## Reviewing + +> How should the reviewer review and test this PR? + +## Documentation + +> How have any major architectural decisions and changes in this PR been documented? diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..da8db79 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,47 @@ +name: Cartesia Python CI + +on: + push: + branches: [ main ] + paths: + # all python files + - '**/*.py' + # This file + - '**/lint.yml' + pull_request: + branches: [ main ] + paths: + - '**/*.py' + # This file + - '**/lint.yml' + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + + Linting: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip + + - name: Install Dependencies + run: | + make install-dev + + - name: Lint with ruff + run: | + make lint diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e0ad882 --- /dev/null +++ b/.gitignore @@ -0,0 +1,184 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +*.dylib +*.metallib +.Python +cartesia-mlx/build/ +cartesia-metal/build +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Logs +logs/ +speed.py + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Ruff cache +.ruff_cache/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# JetBrains IDEs +.idea/ +.vscode/ + +# dependencies +node_modules +.pnp +.pnp.js +.yarn + +# testing +coverage + +# next.js +.next/ +out/ +.swc/ + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env +.env*.local + +# turbo +.turbo + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + +# secrets +*.secrets + +# Local .terraform directories +**/.terraform/* + +# .tfstate files +*.tfstate +*.tfstate.* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f0f1b03 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +# NOTE: These versions should match those in requirements-dev.txt +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.4.10 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format +- repo: https://github.com/timothycrosley/isort + rev: 5.13.2 + hooks: + - id: isort diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..1460ebb --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +@ad12 @nimz @tobiaskatsch diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..c99c96f --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,69 @@ +# Contributing + +Much of this guide was adopted from [AXLearn](https://github.com/apple/axlearn/blob/main/CONTRIBUTING.md). + +## General Guidelines +1. Please do not commit broken code to any branch. +1. Only commit stable config files. Config files that are in development should not be committed. +1. Use pull requests (PR) to the merge in code. Do not develop on the main branch. + +## Coding Style + +We follow [the Google Python Style Guide](https://google.github.io/styleguide/pyguide.html). + +To avoid confusion around naming, respect code consistency in the repo, including but not limited to naming conventions and code structure. + +If you have not already, we recommend [setting up `pre-commit`](docs/01-start.md#optional-additional-setup-for-developers), which runs some of the linters/formatters prior to each commit. The same checks will be required to pass in CI, so this will help make the development process smoother. + +### Type Annotations + +Functions and methods must be annotated with types. We use [pytype](https://google.github.io/pytype/user_guide.html) for type checking. + +## Code Review Process + +### Author + +All PRs must be made with the default template provided in the repository. + +Before embarking on a major PR, send out a sketch PR (including the high level design notes in the PR description) to solicit feedback first. + +It is more manageable for both the author and reviewers to have small PRs that can be quickly reviewed and merged. + +When selecting multiple reviewers, use "Assignees" to indicate that approvals from specific +reviewers are required before merging. + +The PR authors are expected to reply to each comment they have addressed, e.g., with "Done". 
+However, they should refrain from resolving the comments -- instead, let the reviewers do so. + +When addressing a comment, pay attention to other places where the same comment may apply, so that +the reviewer does not have to repeat the comment. + +When a PR is ready for another round of review, click [**re-request review**](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/requesting-a-pull-request-review) to notify the reviewer. + +Although not strictly enforced, a general etiquette is to wait for all reviewers who have left comments to approve the PR prior to merging, even if the PR has already received an approval. + +In some cases, the comments may be "nits" and appropriate for addressing in follow-up PRs. We encourage authors to give a heads up to reviewers prior to merging a PR with unresolved comments by leaving PR comments (e.g. "Will address in a follow-up") as well as [leaving a TODO](#leaving-todos) where applicable. + +### Reviewer + +People on the team should feel free to add themselves as reviewers and/or to assignees. + +Consider prioritizing reviews over writing one's own code. +This will make the entire team more productive. + +Code review does not end with merge of the PR. +Reviewers should feel free to add comments after the merge, which can be addressed in follow-up PRs. + +## Attributions + +Code that refers to (or is adapted from) other sources must explicitly reference the original source by providing a [link in the docstring](https://github.com/apple/axlearn/blob/669f0cae6249e165caa1a94cf64b12e77bf4cfdf/axlearn/common/attention.py#L360-L365) of the corresponding function, class, etc. + +Code that is adapted from papers should have clear `Reference:` section in the docstring that provides a complete reference (with link) to the original paper. Links to ArXiv papers are preferred to avoid paywalls. + +## Leaving TODOs + +In some cases it's useful to track future work ("TODOs"). +TODOs should have the format `TODO(username1,username2)` indicating the contributor(s) responsible for addressing the TODO. +Please use your actual GitHub username as opposed to an alias to avoid ambiguity. + +For larger work items, consider creating a GitHub issue to track progress. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..9d5c427 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +install-dev: + pip install -r requirements-dev.txt + +# Pytype can be quite expensive to run. +# TODO: Determine if we want to use a lazier type checker. +# cd cartesia-mlx && pytype --disable=import-error . && cd .. +# cd cartesia-metal && pytype --disable=import-error . && cd .. 
+# cd cartesia-pytorch && pytype --disable=import-error . && cd .. +lint: + isort -c . + ruff check . + ruff format --check . + + +autoformat: + isort --atomic . + ruff format . diff --git a/README.md b/README.md new file mode 100644 index 0000000..8042189 --- /dev/null +++ b/README.md @@ -0,0 +1,30 @@ +# 📏 Edge + +Edge is an open-source library to support the research and release of efficient state space models (SSMs) for on-device applications across multiple accelerators and environments. + +Current state-of-the-art models are often too large and slow to run on-device, requiring cloud-based solutions that are expensive and high-latency. Edge aims to change this by providing a library of efficient SSM models that can run on-device in real time. + +SSMs offer a fundamentally more computationally efficient, yet high-quality, approach to building AI models than Transformer models. They provide an elegant foundation for training efficient real-time models that can natively stream information with constant tokens per second (tok/s) and constant memory consumption, making them a game changer for on-device applications. + +With Edge, you have: + +📏 **Custom inference kernels:** Edge supports custom hardware-specialized inference kernels for SSM architectures, such as Mamba. + +📏 **Local SSM models:** Edge brings together open-weights SSM models and optimizes them for multiple hardware platforms. + +## What's New +- **[08/27/2024]**: We are excited to release Edge and our first openly available small language model, [Rene-v0.1](https://huggingface.co/cartesia-ai/Rene-v0.1-1.3b-pytorch). Rene is available in both PyTorch and MLX, with quantization support and custom hardware kernels. + +## Available Packages +Edge currently hosts packages for multiple backends: +- [`cartesia-pytorch`](cartesia-pytorch/): This package contains the PyTorch implementations of our models and can be used on any hardware where PyTorch is supported (e.g., CPU, CUDA GPU). +- [`cartesia-metal`](cartesia-metal/): This package contains custom Metal kernels for fast SSM inference on Apple silicon laptops and mobile devices. +- [`cartesia-mlx`](cartesia-mlx/): A Python package containing models that run on [MLX](https://github.com/ml-explore/mlx). + +## Run Rene on MLX +For a full list of the available models in MLX and an example of how to run these models, see the [cartesia-mlx README](cartesia-mlx/README.md). + +![Language Model](cartesia-mlx/assets/lm-demo.gif) + +## Looking for custom support from Cartesia? +We've got your back! Contact us [here](https://bit.ly/cartesiaondevice).
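For context, the following is a minimal sketch of what running Rene through `cartesia-mlx` looks like. It assumes the package exposes a `from_pretrained` loader and a streaming `generate` method as described in the cartesia-mlx README; the 4-bit model id and the sampling parameters shown here are illustrative assumptions, and the [cartesia-mlx README](cartesia-mlx/README.md) remains the authoritative example.

```python
import mlx.core as mx
import cartesia_mlx as cmx

# Load a quantized Rene checkpoint (model id assumed; see the
# cartesia-mlx README for the published model names).
model = cmx.from_pretrained("cartesia-ai/Rene-v0.1-1.3b-4bit-mlx")
model.set_dtype(mx.float32)

prompt = "Rene Descartes was"

# generate() streams text chunks as they are produced, so output can be
# printed incrementally instead of waiting for the full completion.
print(prompt, end="", flush=True)
for text in model.generate(prompt, max_tokens=250, temperature=0.85, top_p=0.99):
    print(text, end="", flush=True)
```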
diff --git a/cartesia-metal/CMakeLists.txt b/cartesia-metal/CMakeLists.txt new file mode 100644 index 0000000..8079b4b --- /dev/null +++ b/cartesia-metal/CMakeLists.txt @@ -0,0 +1,82 @@ +cmake_minimum_required(VERSION 3.27) + +project(_ext LANGUAGES CXX) + +# ----------------------------- Setup ----------------------------- +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON) + +# ----------------------------- Dependencies ----------------------------- +find_package(MLX CONFIG REQUIRED) +find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE NB_DIR) +list(APPEND CMAKE_PREFIX_PATH "${NB_DIR}") +find_package(nanobind CONFIG REQUIRED) + +# ----------------------------- Extensions ----------------------------- + +# Add library +add_library(mlx_ext) + +# Add sources +target_sources( + mlx_ext + PUBLIC + ${CMAKE_CURRENT_LIST_DIR}/src/conv1d_update.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/conv1d_forward.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/ssm_update.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/ssd_update.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/ssd_update_no_z.cpp +) + +# Add include headers +target_include_directories( + mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR} +) + +# Link to mlx +target_link_libraries(mlx_ext PUBLIC mlx) + +# ----------------------------- Metal ----------------------------- + +# Build metallib +if(MLX_BUILD_METAL) + mlx_build_metallib( + TARGET mlx_ext_metallib + TITLE mlx_ext + SOURCES + ${CMAKE_CURRENT_LIST_DIR}/src/conv1d_update.metal + ${CMAKE_CURRENT_LIST_DIR}/src/conv1d_forward.metal + ${CMAKE_CURRENT_LIST_DIR}/src/ssm_update.metal + ${CMAKE_CURRENT_LIST_DIR}/src/ssd_update.metal + ${CMAKE_CURRENT_LIST_DIR}/src/ssd_update_no_z.metal + INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} + OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY} + ) + + + add_dependencies( + mlx_ext + mlx_ext_metallib + ) + +endif() + +# ----------------------------- Python Bindings ----------------------------- +nanobind_add_module( + _ext + NB_STATIC STABLE_ABI LTO NOMINSIZE + NB_DOMAIN mlx + ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp +) +target_link_libraries(_ext PRIVATE mlx_ext) + +if(BUILD_SHARED_LIBS) + target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path) +endif() + diff --git a/cartesia-metal/MANIFEST.in b/cartesia-metal/MANIFEST.in new file mode 100644 index 0000000..f4e63d8 --- /dev/null +++ b/cartesia-metal/MANIFEST.in @@ -0,0 +1,2 @@ +include README.md +recursive-include cartesia_metal * diff --git a/cartesia-metal/README.md b/cartesia-metal/README.md new file mode 100644 index 0000000..6be7a8e --- /dev/null +++ b/cartesia-metal/README.md @@ -0,0 +1,18 @@ +# Cartesia Metal + +This package contains Metal kernels for fast on-device SSM inference on Apple silicon. + +## Installation +To install this package, first install Xcode, which can be downloaded from https://developer.apple.com/xcode/. 
+Accept the license agreement with: +```shell +sudo xcodebuild -license +``` + +Install the dependency `nanobind` (exact version required), followed by `cartesia-metal`: +```shell +pip install nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4 +pip install git+https://github.com/cartesia-ai/edge.git#subdirectory=cartesia-metal +``` + +Note: This package has been tested on macOS Sonoma 14.1 with the M3 chip. diff --git a/cartesia-metal/bindings.cpp b/cartesia-metal/bindings.cpp new file mode 100644 index 0000000..4d15c91 --- /dev/null +++ b/cartesia-metal/bindings.cpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "src/conv1d_update.h" +#include "src/conv1d_forward.h" +#include "src/ssm_update.h" +#include "src/ssd_update.h" +#include "src/ssd_update_no_z.h" + +namespace nb = nanobind; +using namespace nb::literals; + +using namespace mlx::core; + +NB_MODULE(_ext, m) { + + m.def( + "conv1d_update", + [](const array& x, const array& w, const array& b, const array& state, StreamOrDevice s) { + return conv1d_update(x, w, b, state, s); + }, + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none() + ); + + m.def( + "conv1d_forward_", + [](const array& x, const array& w, const array& b, StreamOrDevice s) { + return conv1d_forward(x, w, b, s); + }, + nb::arg(), + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none() + ); + + m.def( + "ssm_update", + [](const array& x, const array& dt, const array& A, const array& B, const array& C, const array& D, const array& z, const array& state, StreamOrDevice s) { + return ssm_update(x, dt, A, B, C, D, z, state, s); + }, + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none() + ); + + m.def( + "ssd_update_", + [](const array& x, const array& dt, const array& decay, const array& B, const array& C, const array& D, const array& z, const array& state, StreamOrDevice s) { + return ssd_update(x, dt, decay, B, C, D, z, state, s); + }, + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none() + ); + + m.def( + "ssd_update_no_z_", + [](const array& x, const array& dt, const array& decay, const array& B, const array& C, const array& D, const array& state, StreamOrDevice s) { + return ssd_update_no_z(x, dt, decay, B, C, D, state, s); + }, + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::arg(), + nb::kw_only(), + "stream"_a = nb::none() + ); +} diff --git a/cartesia-metal/cartesia_metal/__init__.py b/cartesia-metal/cartesia_metal/__init__.py new file mode 100644 index 0000000..eef1805 --- /dev/null +++ b/cartesia-metal/cartesia_metal/__init__.py @@ -0,0 +1,7 @@ +import mlx.core # noqa: F401 + +from cartesia_metal.interface import conv1d_forward, ssd_update, ssd_update_no_z + +from ._ext import conv1d_update + +__all__ = ["conv1d_forward", "conv1d_update", "ssd_update", "ssd_update_no_z"] diff --git a/cartesia-metal/cartesia_metal/interface.py b/cartesia-metal/cartesia_metal/interface.py new file mode 100644 index 0000000..922dae3 --- /dev/null +++ b/cartesia-metal/cartesia_metal/interface.py @@ -0,0 +1,125 @@ +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn + +from ._ext import conv1d_forward_, ssd_update_, ssd_update_no_z_ + + +@mx.compile +def conv1d_forward( + x: mx.array, + w: mx.array, + b: 
mx.array, + state: mx.array, +) -> Tuple[mx.array, mx.array]: + """Performs a 1D depthwise convolution forward pass. + + Args: + x: Input tensor of shape (b, l, d). + w: Convolution weights of shape (d, k). + b: Convolution bias of shape (d,). + state: Optional tensor of shape (b, d, k-1) to maintain state across inputs. + + Returns: + Output tensor of shape (b, l, d) and the updated state tensor. + """ + kernel_size = w.shape[1] + x = x.swapaxes(1, 2) + x = mx.concatenate([state, x], axis=-1) + next_state = x[:, :, -kernel_size + 1 :] + next_state = mx.as_strided(next_state) + y = conv1d_forward_(x, w, b)[0] + y = y[..., : -kernel_size + 1] + y = y.swapaxes(1, 2) + return y, next_state + + +@mx.compile +def ssd_update( + x: mx.array, + dt: mx.array, + A: mx.array, + B: mx.array, + C: mx.array, + D: mx.array, + z: mx.array, + dt_bias: mx.array, + dt_min: float, + dt_max: float, + state: mx.array, +) -> Tuple[mx.array, mx.array]: + """Updates the state space model with the given inputs and parameters. + + Args: + x: Input tensor of shape (b, h, dh). + dt: Time deltas tensor of shape (b, h). + A: State transition tensor of shape (h). + B: Input mixing tensor of shape (b, g, n). + C: Output mixing tensor of shape (b, g, n). + D: Residual connection tensor of shape (h). + z: Gating tensor of shape (b, h*dh). + dt_bias: Bias for the time deltas of shape (h). + dt_min: Minimum value for time deltas after clipping. + dt_max: Maximum value for time deltas after clipping. + state: State tensor of shape (b, h, dh, n). + + Returns: + tuple: Output tensor reshaped to (b, h*dh) and the updated state tensor. + """ + b, h, dh = x.shape + + if dt_bias is not None: + dt = dt + dt_bias.reshape(1, -1) + + dt = nn.softplus(dt) + dt = mx.clip(dt, a_min=dt_min, a_max=dt_max).astype(x.dtype) + decay = mx.exp(dt * A.reshape(1, -1)) # (b, h) + + x, state = ssd_update_(x, dt, decay, B, C, D, z, state) + x = x.reshape(b, h * dh) + return x, state + + +@mx.compile +def ssd_update_no_z( + x: mx.array, + dt: mx.array, + A: mx.array, + B: mx.array, + C: mx.array, + D: mx.array, + dt_bias: mx.array, + dt_min: float, + dt_max: float, + state: mx.array, +) -> Tuple[mx.array, mx.array]: + """Updates the state space model without the gating. + + Args: + x: Input tensor of shape (b, h, dh). + dt: Time deltas tensor of shape (b, h). + A: State transition tensor of shape (h). + B: Input mixing tensor of shape (b, g, n). + C: Output mixing tensor of shape (b, g, n). + D: Residual connection tensor of shape (h). + dt_bias: Bias for the time deltas of shape (h). + dt_min: Minimum value for time deltas after clipping. + dt_max: Maximum value for time deltas after clipping. + state: State tensor of shape (b, h, dh, n). + + Returns: + tuple: Output tensor reshaped to (b, h*dh) and the updated state tensor. 
+ """ + b, h, dh = x.shape + + if dt_bias is not None: + dt = dt + dt_bias.reshape(1, -1) + + dt = nn.softplus(dt) + dt = mx.clip(dt, a_min=dt_min, a_max=dt_max).astype(x.dtype) + decay = mx.exp(dt * A.reshape(1, -1)) # (b, h) + + x, state = ssd_update_no_z_(x, dt, decay, B, C, D, state) + x = x.reshape(b, h * dh) + return x, state diff --git a/cartesia-metal/cartesia_metal/version.py b/cartesia-metal/cartesia_metal/version.py new file mode 100644 index 0000000..5f66355 --- /dev/null +++ b/cartesia-metal/cartesia_metal/version.py @@ -0,0 +1 @@ +__version__ = "0.0.1rc2" diff --git a/cartesia-metal/requirements.txt b/cartesia-metal/requirements.txt new file mode 100644 index 0000000..15000c0 --- /dev/null +++ b/cartesia-metal/requirements.txt @@ -0,0 +1,2 @@ +mlx==0.16.1 # TODO: upgrade to latest mlx version once nanobind for mlx-extensions is upgraded: https://github.com/ml-explore/mlx/issues/1342 +cmake>=3.24 diff --git a/cartesia-metal/setup.py b/cartesia-metal/setup.py new file mode 100644 index 0000000..e349fc3 --- /dev/null +++ b/cartesia-metal/setup.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import io +import os + +# Note: To use the 'upload' functionality of this file, you must: +# $ pipenv install twine --dev +import subprocess +import sys +from distutils.util import convert_path +from shutil import rmtree + +from setuptools import Command, find_packages, setup + +PACKAGE_DIR = "cartesia_metal" +main_ns = {} +ver_path = convert_path(os.path.join(PACKAGE_DIR, "version.py")) +with open(ver_path) as ver_file: + exec(ver_file.read(), main_ns) + +# Package meta-data. +NAME = "cartesia-metal" +DESCRIPTION = "The official Cartesia metal bindings." +URL = "" +EMAIL = "support@cartesia.ai" +AUTHOR = "Cartesia, Inc." +REQUIRES_PYTHON = ">=3.9.0" +VERSION = main_ns["__version__"] + + +# What packages are required for this module to be executed? +def get_requirements(path): + with open(path, "r") as f: + out = f.read().splitlines() + + out = [line.strip() for line in out] + return out + + +REQUIRED = get_requirements("requirements.txt") +EXTRAS = {} + +# Ensure Xcode and mlx is installed +here = os.path.abspath(os.path.dirname(__file__)) +subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "-r", + os.path.join(here, "requirements.txt"), + ] +) + + +# Import the README and use it as the long-description. +# Note: this will only work if 'README.md' is present in your MANIFEST.in file! +try: + with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: + long_description = "\n" + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + +# Load the package's __version__.py module as a dictionary. +about = {} +if not VERSION: + project_slug = NAME.lower().replace("-", "_").replace(" ", "_") + with open(os.path.join(here, project_slug, "__version__.py")) as f: + exec(f.read(), about) +else: + about["__version__"] = VERSION + + +class UploadCommand(Command): + """Support setup.py upload.""" + + description = "Build and publish the package." 
+ user_options = [("skip-upload", "u", "skip git tagging and pypi upload")] + boolean_options = ["skip-upload"] + + @staticmethod + def status(s): + """Prints things in bold.""" + print("\033[1m{0}\033[0m".format(s)) + + def initialize_options(self): + self.skip_upload = False + + def finalize_options(self): + self.skip_upload = bool(self.skip_upload) + + def run(self): + try: + self.status("Removing previous builds…") + rmtree(os.path.join(here, "dist")) + rmtree(os.path.join(here, "build")) + except OSError: + pass + + self.status("Building Source and Wheel (universal) distribution…") + os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) + + if self.skip_upload: + self.status("Skipping git tagging and pypi upload") + sys.exit() + + self.status("Uploading the package to PyPI via Twine…") + os.system("twine upload dist/*") + + self.status("Pushing git tags…") + os.system("git tag v{0}".format(about["__version__"])) + os.system("git push --tags") + + sys.exit() + + +from mlx import extension + +if __name__ == "__main__": + setup( + name=NAME, + version=about["__version__"], + description=DESCRIPTION, + long_description=long_description, + long_description_content_type="text/markdown", + author=AUTHOR, + author_email=EMAIL, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=find_packages( + exclude=[ + "scratch", + "tests", + "*.tests", + "*.tests.*", + "tests.*", + ] + ), + install_requires=REQUIRED, + extras_require=EXTRAS, + include_package_data=True, + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + # $ setup.py publish support. 
+ cmdclass={"upload": UploadCommand, "build_ext": extension.CMakeBuild}, + ext_modules=[extension.CMakeExtension("cartesia_metal._ext")], + package_data={"cartesia_metal": ["*.so", "*.dylib", "*.metallib"]}, + zip_safe=False, + ) diff --git a/cartesia-metal/src/conv1d_forward.cpp b/cartesia-metal/src/conv1d_forward.cpp new file mode 100644 index 0000000..9275243 --- /dev/null +++ b/cartesia-metal/src/conv1d_forward.cpp @@ -0,0 +1,99 @@ +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include "src/conv1d_forward.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#endif + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" + +namespace mlx::core { + +std::vector conv1d_forward( + const array& x, + const array& w, + const array& b, + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + auto y_dtype = x.dtype(); + // TODO: Also make sure state is of the same dtype + + auto y_shape = x.shape(); + + return array::make_arrays( + {y_shape}, + {y_dtype}, + std::make_shared(to_stream(s)), + {x, w, b}); +} + +void Conv1dForward::eval(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval not implemented!"); +} + +#ifdef ACCELERATE_NEW_LAPACK +void Conv1dForward::eval_cpu(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval_cpu not implemented!"); +} +#endif + +void Conv1dForward::eval_gpu(const std::vector& inputs, std::vector& outputs) { + + assert(inputs.size() == 3); + assert(outputs.size() == 1); + + auto x = inputs[0]; // (b, d, l) + auto w = inputs[1]; // (d, k) + auto b = inputs[2]; // (d) + + auto y = outputs[0]; + + auto& s = stream(); + auto& d = metal::device(s.device); + + y.set_data( + allocator::malloc_or_wait(x.data_size() * y.itemsize()), + x.data_size(), + x.strides(), + x.flags() + ); + + std::ostringstream kname; + kname << "conv1d_forward_kernel_"; + kname << type_to_name(x); + + d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + auto kernel_size = w.shape(1); + auto batch_size = x.shape(0); + auto n_channels = x.shape(1); + auto seq_len = x.shape(2); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(b, 2); + compute_encoder.set_output_array(y, 3); + compute_encoder->setBytes(x.strides().data(), 3 * sizeof(size_t), 4); + compute_encoder->setBytes(&kernel_size, kernel_size * sizeof(int), 5); + + + // https://developer.apple.com/documentation/metal/compute_passes/calculating_threadgroup_and_grid_sizes + MTL::Size grid_dims = MTL::Size(batch_size, n_channels, seq_len); + size_t width = kernel->threadExecutionWidth(); + size_t height = kernel->maxTotalThreadsPerThreadgroup() / width; + MTL::Size group_dims = MTL::Size(width, height, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/conv1d_forward.h b/cartesia-metal/src/conv1d_forward.h new file mode 100644 index 0000000..a1c566c --- /dev/null +++ b/cartesia-metal/src/conv1d_forward.h @@ -0,0 +1,30 @@ +#pragma once + +#include "mlx/ops.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +std::vector conv1d_forward( + const array& x, + const array& w, + const array& b, + StreamOrDevice s = 
{} // Stream on which to schedule the operation +); + +class Conv1dForward : public Primitive { + public: + explicit Conv1dForward(Stream stream) + : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) override; + + void print(std::ostream& os) override { + os << "Conv1dForward"; + } + + void eval(const std::vector& inputs, std::vector& outputs); +}; + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/conv1d_forward.metal b/cartesia-metal/src/conv1d_forward.metal new file mode 100644 index 0000000..77f894c --- /dev/null +++ b/cartesia-metal/src/conv1d_forward.metal @@ -0,0 +1,64 @@ +#include +#include +#include +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct SILU { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? (1 - y) * x : y * x; + } +}; + +template +[[kernel]] void conv1d_forward_kernel( + device const T* x [[buffer(0)]], // (b, d, l) + device const T* w [[buffer(1)]], // (d, k) + device const T* b [[buffer(2)]], // (d) + device T* y [[buffer(3)]], + constant const size_t* x_strides [[buffer(4)]], + constant const int& kernel_size [[buffer(5)]], + uint3 grid_idx [[thread_position_in_grid]], + uint3 grid_size [[threads_per_grid]]) { + + const int length_size = grid_size.z; + const int batch_idx = grid_idx.x; + const int channel_idx = grid_idx.y; + const int length_idx = grid_idx.z; + + const int y_idx = batch_idx * x_strides[0] + channel_idx * x_strides[1] + length_idx; + const int w_start_idx = channel_idx * kernel_size; + + T acc = 0; + + if (length_idx + kernel_size - 1 < length_size) { // prevents out of bounds access and assumes last k-1 elements at end are dropped + #pragma unroll + for (int i = 0; i < kernel_size; ++i) { + acc = acc + w[w_start_idx + i] * x[y_idx + i]; + } + acc = acc + b[channel_idx]; + acc = SILU()(acc); + } + + y[y_idx] = acc; +} + + +#define instantiate_conv1d_forward_kernel(type_name, type) \ + template [[host_name("conv1d_forward_kernel_" #type_name)]] \ + [[kernel]] void conv1d_forward_kernel( \ + device const type* x [[buffer(0)]], \ + device const type* w [[buffer(1)]], \ + device const type* b [[buffer(2)]], \ + device type* y [[buffer(3)]], \ + constant const size_t* x_strides [[buffer(4)]], \ + constant const int& kernel_size [[buffer(5)]], \ + uint3 grid_idx [[thread_position_in_grid]], \ + uint3 grid_size [[threads_per_grid]]); + +instantiate_conv1d_forward_kernel(float32, float); +instantiate_conv1d_forward_kernel(float16, half); +//instantiate_conv1d_forward_kernel(bfloat16, bfloat16_t); +//instantiate_conv1d_forward_kernel(complex64, complex64_t); \ No newline at end of file diff --git a/cartesia-metal/src/conv1d_update.cpp b/cartesia-metal/src/conv1d_update.cpp new file mode 100644 index 0000000..7383c1e --- /dev/null +++ b/cartesia-metal/src/conv1d_update.cpp @@ -0,0 +1,118 @@ +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include "src/conv1d_update.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#endif + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" + +namespace mlx::core { + +std::vector conv1d_update( + const array& x, + const array& w, + const array& b, + const array& state, + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + auto y_dtype = 
x.dtype(); + // TODO: Also make sure state is of the same dtype + + auto y_shape = x.shape(); + auto state_shape = state.shape(); + + return array::make_arrays( + {y_shape, state_shape}, + {y_dtype, y_dtype}, + std::make_shared(to_stream(s)), + {x, w, b, state}); +} + + +void Conv1dUpdate::eval(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval not implemented!"); +} + + +#ifdef ACCELERATE_NEW_LAPACK + +void Conv1dUpdate::eval_cpu(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval_cpu not implemented!"); +} + +#endif + + +void Conv1dUpdate::eval_gpu(const std::vector& inputs, std::vector& outputs) { + + assert(inputs.size() == 4); + assert(outputs.size() == 2); + + auto x = inputs[0]; + auto w = inputs[1]; + auto b = inputs[2]; + auto state = inputs[3]; + + auto y = outputs[0]; + auto next_state = outputs[1]; + + auto& s = stream(); + auto& d = metal::device(s.device); + + y.set_data( + allocator::malloc_or_wait(x.data_size() * y.itemsize()), + x.data_size(), + x.strides(), + x.flags() + ); + + next_state.set_data( + allocator::malloc_or_wait(state.data_size() * state.itemsize()), + state.data_size(), + state.strides(), + state.flags() + ); + + std::ostringstream kname; + kname << "conv1d_update_kernel_"; + kname << type_to_name(x); + + d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + auto kernel_size = w.shape(1); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(w, 1); + compute_encoder.set_input_array(b, 2); + compute_encoder.set_input_array(state, 3); + compute_encoder.set_output_array(y, 4); + compute_encoder.set_output_array(next_state, 5); + compute_encoder->setBytes(&kernel_size, kernel_size * sizeof(int), 6); + compute_encoder->setBytes(x.strides().data(), 2 * sizeof(size_t), 7); + compute_encoder->setBytes(state.strides().data(), 3 * sizeof(size_t), 8); + + auto batch_size = x.shape(0); + auto n_channels = x.shape(1); + + // https://developer.apple.com/documentation/metal/compute_passes/calculating_threadgroup_and_grid_sizes + MTL::Size grid_dims = MTL::Size(batch_size, n_channels, 1); + size_t width = kernel->threadExecutionWidth(); + size_t height = kernel->maxTotalThreadsPerThreadgroup() / width; + MTL::Size group_dims = MTL::Size(width, height, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/conv1d_update.h b/cartesia-metal/src/conv1d_update.h new file mode 100644 index 0000000..ce8956b --- /dev/null +++ b/cartesia-metal/src/conv1d_update.h @@ -0,0 +1,31 @@ +#pragma once + +#include "mlx/ops.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +std::vector conv1d_update( + const array& x, + const array& w, + const array& b, + const array& state, + StreamOrDevice s = {} // Stream on which to schedule the operation +); + +class Conv1dUpdate : public Primitive { + public: + explicit Conv1dUpdate(Stream stream) + : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) override; + + void print(std::ostream& os) override { + os << "Conv1dUpdate"; + } + + void eval(const std::vector& inputs, std::vector& outputs); +}; + +} // namespace mlx::core \ No newline at end of file diff --git 
a/cartesia-metal/src/conv1d_update.metal b/cartesia-metal/src/conv1d_update.metal new file mode 100644 index 0000000..b5c87d2 --- /dev/null +++ b/cartesia-metal/src/conv1d_update.metal @@ -0,0 +1,72 @@ +#include +#include +#include +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +struct SILU { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? (1 - y) * x : y * x; + } +}; + +template +[[kernel]] void conv1d_update_kernel( + device const T* x [[buffer(0)]], // b, d + device const T* w [[buffer(1)]], // d, k + device const T* b [[buffer(2)]], // d + device const T* state [[buffer(3)]], // b, d, k-1 + device T* y [[buffer(4)]], + device T* next_state [[buffer(5)]], + constant const int& kernel_size [[buffer(6)]], + constant const size_t* x_strides [[buffer(7)]], + constant const size_t* state_strides [[buffer(8)]], + uint3 grid_idx [[thread_position_in_grid]], + uint3 grid_size [[threads_per_grid]]) { + + const int batch_idx = grid_idx.x; + const int channel_idx = grid_idx.y; + const int x_idx = batch_idx * x_strides[0] + channel_idx; + const int w_start_idx = channel_idx * kernel_size; + const int state_start_idx = batch_idx * state_strides[0] + channel_idx * state_strides[1]; + + T temp = 0; + #pragma unroll + for (int i = 0; i < kernel_size - 1; ++i) { + temp = temp + w[w_start_idx + i] * state[state_start_idx + i]; + } + + temp = temp + w[w_start_idx + kernel_size - 1] * x[x_idx]; + temp = temp + b[channel_idx]; + temp = SILU()(temp); + y[x_idx] = temp; + + #pragma unroll + for (int i = 0; i < kernel_size - 2; ++i) { + next_state[state_start_idx + i] = state[state_start_idx + i + 1]; + } + next_state[state_start_idx + kernel_size - 2] = x[x_idx]; +} + + +#define instantiate_conv1d_update_kernel(type_name, type) \ + template [[host_name("conv1d_update_kernel_" #type_name)]] \ + [[kernel]] void conv1d_update_kernel( \ + device const type* x [[buffer(0)]], \ + device const type* w [[buffer(1)]], \ + device const type* b [[buffer(2)]], \ + device const type* state [[buffer(3)]], \ + device type* y [[buffer(4)]], \ + device type* next_state [[buffer(5)]], \ + constant const int& kernel_size [[buffer(6)]], \ + constant const size_t* x_strides [[buffer(7)]], \ + constant const size_t* state_strides [[buffer(8)]], \ + uint3 grid_idx [[thread_position_in_grid]], \ + uint3 grid_size [[threads_per_grid]]); + +instantiate_conv1d_update_kernel(float32, float); +instantiate_conv1d_update_kernel(float16, half); +//instantiate_conv1d_update_kernel(bfloat16, bfloat16_t); +//instantiate_conv1d_update_kernel(complex64, complex64_t); \ No newline at end of file diff --git a/cartesia-metal/src/ssd_update.cpp b/cartesia-metal/src/ssd_update.cpp new file mode 100644 index 0000000..2caf6fe --- /dev/null +++ b/cartesia-metal/src/ssd_update.cpp @@ -0,0 +1,137 @@ +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include "src/ssd_update.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#endif + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" + +namespace mlx::core { + +std::vector ssd_update( + const array& x, + const array& dt, + const array& decay, + const array& B, + const array& C, + const array& D, + const array& z, + const array& state, + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + auto y_dtype = x.dtype(); + // TODO: Also make sure state is of the same dtype + + auto 
y_shape = x.shape(); + auto state_shape = state.shape(); + + return array::make_arrays( + {y_shape, state_shape}, + {y_dtype, y_dtype}, + std::make_shared(to_stream(s)), + {x, dt, decay, B, C, D, z, state}); +} + +void SSDUpdate::eval(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval not implemented!"); +} + + +#ifdef ACCELERATE_NEW_LAPACK + +void SSDUpdate::eval_cpu(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval_cpu not implemented!"); +} + +#endif + + +void SSDUpdate::eval_gpu(const std::vector& inputs, std::vector& outputs) { + + assert(inputs.size() == 8); + assert(outputs.size() == 2); + + auto x = inputs[0]; + auto dt = inputs[1]; + auto decay = inputs[2]; + auto B = inputs[3]; + auto C = inputs[4]; + auto D = inputs[5]; + auto z = inputs[6]; + auto state = inputs[7]; + + auto y = outputs[0]; + auto next_state = outputs[1]; + + auto& s = stream(); + auto& d = metal::device(s.device); + + y.set_data( + allocator::malloc_or_wait(x.data_size() * y.itemsize()), + x.data_size(), + x.strides(), + x.flags() + ); + + next_state.set_data( + allocator::malloc_or_wait(state.data_size() * state.itemsize()), + state.data_size(), + state.strides(), + state.flags() + ); + + std::ostringstream kname; + kname << "ssd_update_kernel_"; + kname << type_to_name(x); + + d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(dt, 1); + compute_encoder.set_input_array(decay, 2); + compute_encoder.set_input_array(B, 3); + compute_encoder.set_input_array(C, 4); + compute_encoder.set_input_array(D, 5); + compute_encoder.set_input_array(z, 6); + compute_encoder.set_input_array(state, 7); + compute_encoder.set_output_array(y, 8); + compute_encoder.set_output_array(next_state, 9); + + + auto b = x.shape(0); + auto h = x.shape(1); + auto dh = x.shape(2); + auto state_size = state.shape(3); + auto g = B.shape(1); + auto group_size = h / g; + + compute_encoder->setBytes(&group_size, sizeof(size_t), 10); + compute_encoder->setBytes(&state_size, sizeof(size_t), 11); + + compute_encoder->setBytes(x.strides().data(), 3 * sizeof(size_t), 12); + compute_encoder->setBytes(dt.strides().data(), 2 * sizeof(size_t), 13); + compute_encoder->setBytes(B.strides().data(), 3 * sizeof(size_t), 14); + compute_encoder->setBytes(state.strides().data(), 4 * sizeof(size_t), 15); + + // https://developer.apple.com/documentation/metal/compute_passes/calculating_threadgroup_and_grid_sizes + MTL::Size grid_dims = MTL::Size(b, h, dh); + // size_t width = kernel->threadExecutionWidth(); + // size_t height = kernel->maxTotalThreadsPerThreadgroup() / width; + MTL::Size group_dims = MTL::Size(32, 32, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/ssd_update.h b/cartesia-metal/src/ssd_update.h new file mode 100644 index 0000000..fac6d40 --- /dev/null +++ b/cartesia-metal/src/ssd_update.h @@ -0,0 +1,36 @@ +#pragma once + +#include "mlx/ops.h" +#include "mlx/primitives.h" + +namespace mlx::core { + + +std::vector ssd_update( + const array& x, + const array& dt, + const array& A, + const array& B, + const array& C, + const array& D, + const array& z, + const array& state, + StreamOrDevice s = {} // Stream on which to schedule 
the operation +); + +class SSDUpdate : public Primitive { + public: + explicit SSDUpdate(Stream stream) + : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) override; + + void print(std::ostream& os) override { + os << "SSDUpdate"; + } + + void eval(const std::vector& inputs, std::vector& outputs); +}; + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/ssd_update.metal b/cartesia-metal/src/ssd_update.metal new file mode 100644 index 0000000..0ec9f25 --- /dev/null +++ b/cartesia-metal/src/ssd_update.metal @@ -0,0 +1,99 @@ +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +// #include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/utils.h" + + + +struct SILU { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? (1 - y) * x : y * x; + } +}; + +template +[[kernel]] void ssd_update_kernel( + // Input + device const T* x [[buffer(0)]], // (b, h, dh) + device const T* dt [[buffer(1)]], // (b, h) + device const T* decay [[buffer(2)]], // (b, h) + device const T* B [[buffer(3)]], // (b, g, n) + device const T* C [[buffer(4)]], // (b, g, n) + device const T* D [[buffer(5)]], // (h) + device const T* z [[buffer(6)]], // (b, d) + device const T* state [[buffer(7)]], // (b, h, dh, n) + device T* y [[buffer(8)]], + device T* next_state [[buffer(9)]], + // Parameters + constant const int& group_size [[buffer(10)]], // h = group_size * g + constant const int& state_size [[buffer(11)]], + // Strides + constant const size_t* x_strides [[buffer(12)]], + constant const size_t* dt_strides [[buffer(13)]], + constant const size_t* CB_strides [[buffer(14)]], + constant const size_t* state_strides [[buffer(15)]], + // Grid + uint3 grid_idx [[thread_position_in_grid]], + uint3 grid_size [[threads_per_grid]]) { + + const int b_idx = grid_idx.x; + const int h_idx = grid_idx.y; + const int dh_idx = grid_idx.z; + + const int CB_start_idx = b_idx * CB_strides[0] + (h_idx / group_size) * CB_strides[1]; + const int x_idx = b_idx * x_strides[0] + h_idx * x_strides[1] + dh_idx * x_strides[2]; + const int dt_idx = b_idx * dt_strides[0] + h_idx * dt_strides[1]; + const int state_start_idx = b_idx * state_strides[0] + h_idx * state_strides[1] + dh_idx * state_strides[2]; + + // load data + T this_x = x[x_idx]; // Assumed to be a already SILU activated by conv kernel + T this_dt = dt[dt_idx]; + T this_decay = decay[dt_idx]; + T this_D = D[h_idx]; + T this_z = z[x_idx]; + this_z = SILU()(this_z); // SILU activation + + T temp = 0; + #pragma unroll + for (int i = 0; i < state_size; ++i) { + int CB_idx = CB_start_idx + i; + int state_idx = state_start_idx + i; + T this_new_state = state[state_idx] * this_decay + B[CB_idx] * this_dt * this_x; + next_state[state_idx] = this_new_state; + temp = temp + this_new_state * C[CB_idx]; + } + temp = temp + this_D * this_x; // Skip connection + temp = temp * this_z; // Out gate with z + y[x_idx] = temp; +} + +#define instantiate_ssd_update_kernel(type_name, type) \ + template [[host_name("ssd_update_kernel_" #type_name)]] \ + [[kernel]] void ssd_update_kernel( \ + device const type* x [[buffer(0)]], \ + device const type* dt [[buffer(1)]], \ + device const type* decay [[buffer(2)]], \ + device const type* B [[buffer(3)]], \ + device const type* C [[buffer(4)]], \ + device const type* D [[buffer(5)]], \ + device const type* z [[buffer(6)]], \ + device const 
type* state [[buffer(7)]], \ + device type* y [[buffer(8)]], \ + device type* next_state [[buffer(9)]], \ + constant const int& group_size [[buffer(10)]], \ + constant const int& state_size [[buffer(11)]], \ + constant const size_t* x_strides [[buffer(12)]], \ + constant const size_t* dt_strides [[buffer(13)]], \ + constant const size_t* CB_strides [[buffer(14)]], \ + constant const size_t* state_strides [[buffer(15)]], \ + uint3 grid_idx [[thread_position_in_grid]], \ + uint3 grid_size [[threads_per_grid]]); + +instantiate_ssd_update_kernel(float32, float); +instantiate_ssd_update_kernel(float16, half); +//instantiate_ssd_update_kernel(bfloat16, bfloat16_t); +//instantiate_ssd_update_kernel(complex64, complex64_t); diff --git a/cartesia-metal/src/ssd_update_no_z.cpp b/cartesia-metal/src/ssd_update_no_z.cpp new file mode 100644 index 0000000..0e5580d --- /dev/null +++ b/cartesia-metal/src/ssd_update_no_z.cpp @@ -0,0 +1,134 @@ +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include "src/ssd_update_no_z.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#endif + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" + +namespace mlx::core { + +std::vector ssd_update_no_z( + const array& x, + const array& dt, + const array& decay, + const array& B, + const array& C, + const array& D, + const array& state, + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + auto y_dtype = x.dtype(); + // TODO: Also make sure state is of the same dtype + + auto y_shape = x.shape(); + auto state_shape = state.shape(); + + return array::make_arrays( + {y_shape, state_shape}, + {y_dtype, y_dtype}, + std::make_shared(to_stream(s)), + {x, dt, decay, B, C, D, state}); +} + +void SSDUpdateNoZ::eval(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval not implemented!"); +} + + +#ifdef ACCELERATE_NEW_LAPACK + +void SSDUpdateNoZ::eval_cpu(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval_cpu not implemented!"); +} + +#endif + + +void SSDUpdateNoZ::eval_gpu(const std::vector& inputs, std::vector& outputs) { + + assert(inputs.size() == 7); + assert(outputs.size() == 2); + + auto x = inputs[0]; + auto dt = inputs[1]; + auto decay = inputs[2]; + auto B = inputs[3]; + auto C = inputs[4]; + auto D = inputs[5]; + auto state = inputs[6]; + + auto y = outputs[0]; + auto next_state = outputs[1]; + + auto& s = stream(); + auto& d = metal::device(s.device); + + y.set_data( + allocator::malloc_or_wait(x.data_size() * y.itemsize()), + x.data_size(), + x.strides(), + x.flags() + ); + + next_state.set_data( + allocator::malloc_or_wait(state.data_size() * state.itemsize()), + state.data_size(), + state.strides(), + state.flags() + ); + + std::ostringstream kname; + kname << "ssd_update_no_z_kernel_"; + kname << type_to_name(x); + + d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(dt, 1); + compute_encoder.set_input_array(decay, 2); + compute_encoder.set_input_array(B, 3); + compute_encoder.set_input_array(C, 4); + compute_encoder.set_input_array(D, 5); + compute_encoder.set_input_array(state, 6); + compute_encoder.set_output_array(y, 7); + compute_encoder.set_output_array(next_state, 8); 
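For orientation, the per-step recurrence that the `ssd_update` kernel above implements can be written in a few lines of Python with mlx; the `ssd_update_no_z` variant being wired up here is the same computation minus the final `silu(z)` gate. This is an illustrative reference only, not part of the extension; shapes follow the buffer comments in the kernel signatures.

```python
import mlx.core as mx
import mlx.nn as nn

def ssd_update_reference(x, dt, decay, B, C, D, z, state, group_size):
    # Shapes (from the kernel's buffer comments):
    #   x, z: (b, h, dh)   dt, decay: (b, h)   B, C: (b, g, n)
    #   D: (h,)            state: (b, h, dh, n)
    b, h, dh = x.shape
    n = B.shape[-1]
    # Broadcast the grouped B/C over the heads of each group (h = g * group_size).
    B = mx.repeat(B, group_size, axis=1).reshape(b, h, 1, n)
    C = mx.repeat(C, group_size, axis=1).reshape(b, h, 1, n)
    dtx = (dt.reshape(b, h, 1) * x).reshape(b, h, dh, 1)
    next_state = state * decay.reshape(b, h, 1, 1) + B * dtx
    y = mx.sum(next_state * C, axis=-1)       # (b, h, dh)
    y = y + D.reshape(1, h, 1) * x            # skip connection with D
    y = y * nn.silu(z)                        # output gate (omitted in the no_z variant)
    return y, next_state
```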
+ + + auto b = x.shape(0); + auto h = x.shape(1); + auto dh = x.shape(2); + auto state_size = state.shape(3); + auto g = B.shape(1); + auto group_size = h / g; + + compute_encoder->setBytes(&group_size, sizeof(size_t), 9); + compute_encoder->setBytes(&state_size, sizeof(size_t), 10); + + compute_encoder->setBytes(x.strides().data(), 3 * sizeof(size_t), 11); + compute_encoder->setBytes(dt.strides().data(), 2 * sizeof(size_t), 12); + compute_encoder->setBytes(B.strides().data(), 3 * sizeof(size_t), 13); + compute_encoder->setBytes(state.strides().data(), 4 * sizeof(size_t), 14); + + // https://developer.apple.com/documentation/metal/compute_passes/calculating_threadgroup_and_grid_sizes + MTL::Size grid_dims = MTL::Size(b, h, dh); + // size_t width = kernel->threadExecutionWidth(); + // size_t height = kernel->maxTotalThreadsPerThreadgroup() / width; + MTL::Size group_dims = MTL::Size(32, 32, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/ssd_update_no_z.h b/cartesia-metal/src/ssd_update_no_z.h new file mode 100644 index 0000000..b8563a7 --- /dev/null +++ b/cartesia-metal/src/ssd_update_no_z.h @@ -0,0 +1,35 @@ +#pragma once + +#include "mlx/ops.h" +#include "mlx/primitives.h" + +namespace mlx::core { + + +std::vector ssd_update_no_z( + const array& x, + const array& dt, + const array& A, + const array& B, + const array& C, + const array& D, + const array& state, + StreamOrDevice s = {} // Stream on which to schedule the operation +); + +class SSDUpdateNoZ : public Primitive { + public: + explicit SSDUpdateNoZ(Stream stream) + : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) override; + + void print(std::ostream& os) override { + os << "SSDUpdateNoZ"; + } + + void eval(const std::vector& inputs, std::vector& outputs); +}; + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/ssd_update_no_z.metal b/cartesia-metal/src/ssd_update_no_z.metal new file mode 100644 index 0000000..2c56f75 --- /dev/null +++ b/cartesia-metal/src/ssd_update_no_z.metal @@ -0,0 +1,94 @@ +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +// #include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/utils.h" + + + +struct SILU { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? 
(1 - y) * x : y * x; + } +}; + +template +[[kernel]] void ssd_update_no_z_kernel( + // Input + device const T* x [[buffer(0)]], // (b, h, dh) + device const T* dt [[buffer(1)]], // (b, h) + device const T* decay [[buffer(2)]], // (b, h) + device const T* B [[buffer(3)]], // (b, g, n) + device const T* C [[buffer(4)]], // (b, g, n) + device const T* D [[buffer(5)]], // (h) + device const T* state [[buffer(6)]], // (b, h, dh, n) + device T* y [[buffer(7)]], + device T* next_state [[buffer(8)]], + // Parameters + constant const int& group_size [[buffer(9)]], // h = group_size * g + constant const int& state_size [[buffer(10)]], + // Strides + constant const size_t* x_strides [[buffer(11)]], + constant const size_t* dt_strides [[buffer(12)]], + constant const size_t* CB_strides [[buffer(13)]], + constant const size_t* state_strides [[buffer(14)]], + // Grid + uint3 grid_idx [[thread_position_in_grid]], + uint3 grid_size [[threads_per_grid]]) { + + const int b_idx = grid_idx.x; + const int h_idx = grid_idx.y; + const int dh_idx = grid_idx.z; + + const int CB_start_idx = b_idx * CB_strides[0] + (h_idx / group_size) * CB_strides[1]; + const int x_idx = b_idx * x_strides[0] + h_idx * x_strides[1] + dh_idx * x_strides[2]; + const int dt_idx = b_idx * dt_strides[0] + h_idx * dt_strides[1]; + const int state_start_idx = b_idx * state_strides[0] + h_idx * state_strides[1] + dh_idx * state_strides[2]; + + // load data + T this_x = x[x_idx]; // Assumed to be a already SILU activated by conv kernel + T this_dt = dt[dt_idx]; + T this_decay = decay[dt_idx]; + T this_D = D[h_idx]; + + T temp = 0; + #pragma unroll + for (int i = 0; i < state_size; ++i) { + int CB_idx = CB_start_idx + i; + int state_idx = state_start_idx + i; + T this_new_state = state[state_idx] * this_decay + B[CB_idx] * this_dt * this_x; + next_state[state_idx] = this_new_state; + temp = temp + this_new_state * C[CB_idx]; + } + temp = temp + this_D * this_x; // Skip connection + y[x_idx] = temp; +} + +#define instantiate_ssd_update_no_z_kernel(type_name, type) \ + template [[host_name("ssd_update_no_z_kernel_" #type_name)]] \ + [[kernel]] void ssd_update_no_z_kernel( \ + device const type* x [[buffer(0)]], \ + device const type* dt [[buffer(1)]], \ + device const type* decay [[buffer(2)]], \ + device const type* B [[buffer(3)]], \ + device const type* C [[buffer(4)]], \ + device const type* D [[buffer(5)]], \ + device const type* state [[buffer(6)]], \ + device type* y [[buffer(7)]], \ + device type* next_state [[buffer(8)]], \ + constant const int& group_size [[buffer(9)]], \ + constant const int& state_size [[buffer(10)]], \ + constant const size_t* x_strides [[buffer(11)]], \ + constant const size_t* dt_strides [[buffer(12)]], \ + constant const size_t* CB_strides [[buffer(13)]], \ + constant const size_t* state_strides [[buffer(14)]], \ + uint3 grid_idx [[thread_position_in_grid]], \ + uint3 grid_size [[threads_per_grid]]); + +instantiate_ssd_update_no_z_kernel(float32, float); +instantiate_ssd_update_no_z_kernel(float16, half); +//instantiate_ssd_update_kernel(bfloat16, bfloat16_t); +//instantiate_ssd_update_kernel(complex64, complex64_t); diff --git a/cartesia-metal/src/ssm_update.cpp b/cartesia-metal/src/ssm_update.cpp new file mode 100644 index 0000000..792cb9b --- /dev/null +++ b/cartesia-metal/src/ssm_update.cpp @@ -0,0 +1,126 @@ +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include "src/ssm_update.h" + +#ifdef ACCELERATE_NEW_LAPACK 
+#include +#endif + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" + +namespace mlx::core { + +std::vector ssm_update( + const array& x, + const array& dt, + const array& A, + const array& B, + const array& C, + const array& D, + const array& z, + const array& state, + StreamOrDevice s /* = {} */ // Stream on which to schedule the operation +) { + auto y_dtype = x.dtype(); + // TODO: Also make sure state is of the same dtype + + auto y_shape = x.shape(); + auto state_shape = state.shape(); + + return array::make_arrays( + {y_shape, state_shape}, + {y_dtype, y_dtype}, + std::make_shared(to_stream(s)), + {x, dt, A, B, C, D, z, state}); +} + +void SSMUpdate::eval(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval not implemented!"); +} + + +#ifdef ACCELERATE_NEW_LAPACK + +void SSMUpdate::eval_cpu(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("eval_cpu not implemented!"); +} + +#endif + + +void SSMUpdate::eval_gpu(const std::vector& inputs, std::vector& outputs) { + + assert(inputs.size() == 8); + assert(outputs.size() == 2); + + auto x = inputs[0]; + auto dt = inputs[1]; + auto A = inputs[2]; + auto B = inputs[3]; + auto C = inputs[4]; + auto D = inputs[5]; + auto z = inputs[6]; + auto state = inputs[7]; + + auto y = outputs[0]; + auto next_state = outputs[1]; + + auto& s = stream(); + auto& d = metal::device(s.device); + + y.set_data( + allocator::malloc_or_wait(x.data_size() * y.itemsize()), + x.data_size(), + x.strides(), + x.flags() + ); + + next_state.set_data( + allocator::malloc_or_wait(state.data_size() * state.itemsize()), + state.data_size(), + state.strides(), + state.flags() + ); + + std::ostringstream kname; + kname << "ssm_update_kernel_"; + kname << type_to_name(x); + + d.register_library("mlx_ext", metal::get_colocated_mtllib_path); + auto kernel = d.get_kernel(kname.str(), "mlx_ext"); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + compute_encoder.set_input_array(x, 0); + compute_encoder.set_input_array(dt, 1); + compute_encoder.set_input_array(A, 2); + compute_encoder.set_input_array(B, 3); + compute_encoder.set_input_array(C, 4); + compute_encoder.set_input_array(D, 5); + compute_encoder.set_input_array(z, 6); + compute_encoder.set_input_array(state, 7); + + compute_encoder.set_output_array(y, 8); + compute_encoder.set_output_array(next_state, 9); + + auto batch_size = x.shape(0); + auto n_channels = x.shape(1); + //auto n_state = A.shape(0); + + // https://developer.apple.com/documentation/metal/compute_passes/calculating_threadgroup_and_grid_sizes + MTL::Size grid_dims = MTL::Size(batch_size, n_channels, 1); + // size_t width = kernel->threadExecutionWidth(); + // size_t height = kernel->maxTotalThreadsPerThreadgroup() / width; + MTL::Size group_dims = MTL::Size(32, 32, 1); + + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/ssm_update.h b/cartesia-metal/src/ssm_update.h new file mode 100644 index 0000000..8239fae --- /dev/null +++ b/cartesia-metal/src/ssm_update.h @@ -0,0 +1,36 @@ +#pragma once + +#include "mlx/ops.h" +#include "mlx/primitives.h" + +namespace mlx::core { + + +std::vector ssm_update( + const array& x, + const array& dt, + const array& A, + const array& B, + const array& C, + const array& D, + const array& z, + const array& state, + StreamOrDevice s = {} // Stream on which to schedule the operation 
+); + +class SSMUpdate : public Primitive { + public: + explicit SSMUpdate(Stream stream) + : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, std::vector& outputs) override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) override; + + void print(std::ostream& os) override { + os << "SSMUpdate"; + } + + void eval(const std::vector& inputs, std::vector& outputs); +}; + +} // namespace mlx::core \ No newline at end of file diff --git a/cartesia-metal/src/ssm_update.metal b/cartesia-metal/src/ssm_update.metal new file mode 100644 index 0000000..dc2826a --- /dev/null +++ b/cartesia-metal/src/ssm_update.metal @@ -0,0 +1,102 @@ +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/utils.h" + + +struct SILU { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(-metal::abs(x))); + return (x < 0) ? (1 - y) * x : y * x; + } +}; + +// Stable softplus https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html +struct Softplus { + template + T operator()(T x) { + auto y = log1p(metal::fast::exp(x)); + return (x > 20) ? x : y; + }; +}; + +// Could also be included from #include "mlx/backend/metal/kernels/unary.h" but somehow this doesnt work for multiple files +struct Exp { + template + T operator()(T x) { + return metal::fast::exp(x); + }; +}; + + +template +[[kernel]] void ssm_update_kernel( + device const T* x [[buffer(0)]], + device const T* dt [[buffer(1)]], + device const T* A [[buffer(2)]], + device const T* B [[buffer(3)]], + device const T* C [[buffer(4)]], + device const T* D [[buffer(5)]], + device const T* z [[buffer(6)]], + device const T* state [[buffer(7)]], + device T* y [[buffer(8)]], + device T* next_state [[buffer(9)]], + uint3 grid_idx [[thread_position_in_grid]], + uint3 grid_size [[threads_per_grid]]) { + + + const int state_size = 16; + //const int batch_size = grid_size.x; + const int channel_size = grid_size.y; + const int batch_idx = grid_idx.x; + const int channel_idx = grid_idx.y; + const int cb_start_idx = batch_idx * state_size; // CB are data controlled + + const int x_idx = batch_idx * channel_size + channel_idx; + const int state_start_idx = x_idx * state_size; + + T this_x = x[x_idx]; + this_x = SILU()(this_x); // SILU activation + + T this_z = z[x_idx]; + this_z = SILU()(this_z); // SILU activation + + T delta = Softplus()(dt[x_idx]); // Softplus log(1 + exp(dt)) + + T temp = 0; + #pragma unroll + for (int i = 0; i < state_size; ++i) { + int cb_idx = cb_start_idx + i; + int state_idx = state_start_idx + i; + T this_new_state = state[state_idx] * Exp()(A[i] * delta) + B[cb_idx] * delta * this_x; + next_state[state_idx] = this_new_state; + temp = temp + this_new_state * C[cb_idx]; + } + temp = temp + D[channel_idx] * this_x; // Skip connection + temp = temp * this_z; // Out gate with z + y[x_idx] = temp; +} + +#define instantiate_ssm_update_kernel(type_name, type) \ + template [[host_name("ssm_update_kernel_" #type_name)]] \ + [[kernel]] void ssm_update_kernel( \ + device const type* x [[buffer(0)]], \ + device const type* dt [[buffer(1)]], \ + device const type* A [[buffer(2)]], \ + device const type* B [[buffer(3)]], \ + device const type* C [[buffer(4)]], \ + device const type* D [[buffer(5)]], \ + device const type* z [[buffer(6)]], \ + device const type* state [[buffer(7)]], \ + device type* y [[buffer(8)]], \ + device type* next_state [[buffer(9)]], \ + uint3 grid_idx [[thread_position_in_grid]], \ + uint3 grid_size 
[[threads_per_grid]]); + +instantiate_ssm_update_kernel(float32, float); +instantiate_ssm_update_kernel(float16, half); +//instantiate_ssm_update_kernel(bfloat16, bfloat16_t); +//instantiate_ssm_update_kernel(complex64, complex64_t); diff --git a/cartesia-mlx/MANIFEST.in b/cartesia-mlx/MANIFEST.in new file mode 100644 index 0000000..bc6dcd3 --- /dev/null +++ b/cartesia-mlx/MANIFEST.in @@ -0,0 +1,2 @@ +include README.md +recursive-include cartesia_mlx * diff --git a/cartesia-mlx/README.md b/cartesia-mlx/README.md new file mode 100644 index 0000000..1270034 --- /dev/null +++ b/cartesia-mlx/README.md @@ -0,0 +1,42 @@ +# Cartesia MLX + +This package contains implementations for fast on-device SSM inference on Apple silicon. + +## Installation +To install this package, first install Xcode, which can be downloaded from https://developer.apple.com/xcode/. +Accept the license agreement with: +```shell +sudo xcodebuild -license +``` + +Install the required dependencies: the exact version of `nanobind`, followed by `cartesia-metal`, and finally `cartesia-mlx`, with the following commands: +```shell +pip install nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4 +pip install git+https://github.com/cartesia-ai/edge.git#subdirectory=cartesia-metal +pip install cartesia-mlx +``` + +Note: This package has been tested on macOS Sonoma 14.1 with the M3 chip. + +## Models + +### Language Models +- `cartesia-ai/Rene-v0.1-1.3b-4bit-mlx` +- `cartesia-ai/mamba2-130m-8bit-mlx` +- `cartesia-ai/mamba2-130m-mlx` +- `cartesia-ai/mamba2-370m-8bit-mlx` +- `cartesia-ai/mamba2-780m-8bit-mlx` +- `cartesia-ai/mamba2-1.3b-4bit-mlx` +- `cartesia-ai/mamba2-2.7b-4bit-mlx` + +## Usage +A simple example script for generation can be found in `cartesia-mlx/example.py`. +Usage example (clone this repo and run the below from within the `cartesia-mlx` directory): +```shell +python example.py --model cartesia-ai/Rene-v0.1-1.3b-4bit-mlx --prompt "Rene Descartes was" +``` + +You can pass any of the models listed above to the `--model` argument; for a full list of command-line options, pass `--help`. 
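You can also drive generation directly from Python; the snippet below mirrors the example in the `LM` docstring, and any of the model IDs listed above can be substituted:

```python
import mlx.core as mx
import cartesia_mlx as cmx

model = cmx.from_pretrained("cartesia-ai/Rene-v0.1-1.3b-4bit-mlx")
model.set_dtype(mx.float32)

prompt = "Rene Descartes was"
print(prompt, end="", flush=True)
for text in model.generate(prompt, max_tokens=500, eval_every_n=5, top_p=0.99, temperature=0.85):
    print(text, end="", flush=True)
```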
+ +## Rene in MLX +![Language Model](assets/lm-demo.gif) diff --git a/cartesia-mlx/assets/lm-demo.gif b/cartesia-mlx/assets/lm-demo.gif new file mode 100644 index 0000000..14e4e7e Binary files /dev/null and b/cartesia-mlx/assets/lm-demo.gif differ diff --git a/cartesia-mlx/cartesia_mlx/__init__.py b/cartesia-mlx/cartesia_mlx/__init__.py new file mode 100644 index 0000000..1f48d34 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/__init__.py @@ -0,0 +1,6 @@ +from cartesia_mlx.models.lm import LM +from cartesia_mlx.models.mamba2 import Mamba2 +from cartesia_mlx.models.rene import Rene +from cartesia_mlx.utils.load import from_pretrained + +__all__ = ["LM", "Rene", "Mamba2", "from_pretrained"] diff --git a/cartesia-mlx/cartesia_mlx/layers/__init__.py b/cartesia-mlx/cartesia_mlx/layers/__init__.py new file mode 100644 index 0000000..d33a241 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/__init__.py @@ -0,0 +1,18 @@ +from cartesia_mlx.layers.embedding import Embedding +from cartesia_mlx.layers.ffn import FFN, SwiGLU +from cartesia_mlx.layers.linear import Linear +from cartesia_mlx.layers.residual_block import ResidualBlock +from cartesia_mlx.layers.sa import SelfAttention +from cartesia_mlx.layers.sequence_model import SequenceModel +from cartesia_mlx.layers.ssd.ssd import SSD + +__all__ = [ + "Embedding", + "FFN", + "Linear", + "ResidualBlock", + "SelfAttention", + "SequenceModel", + "SSD", + "SwiGLU", +] diff --git a/cartesia-mlx/cartesia_mlx/layers/embedding.py b/cartesia-mlx/cartesia_mlx/layers/embedding.py new file mode 100644 index 0000000..d25d452 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/embedding.py @@ -0,0 +1,34 @@ +import mlx.core as mx +import mlx.nn as nn + +from cartesia_mlx.utils.configure import Inherit, set_cfg + + +class Embedding(nn.Module): + """An embedding lookup layer.""" + + base_cfg = dict( + _class_="layers.embedding.Embedding", + d_model=Inherit(default=1024), + n_tokens=50288, + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + self.encoder = nn.Embedding(self.n_tokens, self.d_model) + + def encode(self, x: mx.array) -> mx.array: + """Encode the input tensor. + + Args: + x: The input tensor. Shape (...). + + Returns: + The encoded tensor. Shape (..., d_model). 
+ """ + return self.encoder(x) + + def __call__(self, x: mx.array) -> mx.array: + """See :meth:`encode`.""" + return self.encode(x) diff --git a/cartesia-mlx/cartesia_mlx/layers/ffn.py b/cartesia-mlx/cartesia_mlx/layers/ffn.py new file mode 100644 index 0000000..4deabce --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/ffn.py @@ -0,0 +1,83 @@ +from functools import partial + +import mlx.core as mx +import mlx.nn as nn + +from cartesia_mlx.utils.configure import Inherit, set_cfg + + +class FFN(nn.Module): + """A base class for feed-forward networks.""" + + base_cfg = dict( + _class_="layers.ffn.FFN", + quantization_kwargs=Inherit(default=None), + d_model=Inherit(default=1024), + expand=4, + activation="swish", # or 'gelu' + glu=False, + bias=False, + norm=False, + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + self.d_inner = int(round(self.expand * self.d_model)) + Linear = ( + partial(nn.QuantizedLinear, **self.quantization_kwargs) + if self.quantization_kwargs + else nn.Linear + ) + self.in_proj = Linear(self.d_model, self.d_inner, bias=self.bias) + self.out_proj = Linear(self.d_inner, self.d_model, bias=self.bias) + if self.glu: + self.gate_proj = nn.Linear(self.d_model, self.d_inner, bias=self.bias) + if self.norm: + self.norm = nn.RMSNorm(self.d_inner) + assert self.activation in ["swish", "gelu"] + self.activation = nn.silu if self.activation == "swish" else nn.gelu + + def __call__(self, x: mx.array, *args, **kwargs) -> mx.array: + """Forward pass. + + Args: + x: The input tensor. + + Returns: + The output of the feed-forward network. + Shape (batch_size, seq_len, d_model). + """ + y = self.in_proj(x) + y = self.activation(y) + if self.glu: + g = self.gate_proj(x) + y = y * g + if self.norm: + y = self.norm(y) + x = self.out_proj(y) + return x + + def step(self, x: mx.array, *args, **kwargs) -> mx.array: + """See :meth:`__call__`.""" + x = self(x, *args, **kwargs) + return x + + +class SwiGLU(FFN): + """A feed-forward network with Swish-GLU. + + Reference: + Shazeer et al. GLU Variants Improve Transformer. ArXiv 2020. + `https://arxiv.org/pdf/2002.05202v1`. + """ + + base_cfg = dict( + _class_="layers.ffn.SwiGLU", + quantization_kwargs=Inherit(default=None), + d_model=Inherit(default=1024), + expand=2, + glu=True, + bias=False, + norm=True, + ) diff --git a/cartesia-mlx/cartesia_mlx/layers/linear.py b/cartesia-mlx/cartesia_mlx/layers/linear.py new file mode 100644 index 0000000..78cdcca --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/linear.py @@ -0,0 +1,39 @@ +from functools import partial + +import mlx.core as mx +import mlx.nn as nn + +from cartesia_mlx.utils.configure import Inherit, set_cfg + + +class Linear(nn.Module): + """A linear layer.""" + + base_cfg = dict( + _class_="layers.linear.Linear", + quantization_kwargs=Inherit(default=None), + input_dims=None, + output_dims=None, + bias=True, + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + Linear = ( + partial(nn.QuantizedLinear, **self.quantization_kwargs) + if self.quantization_kwargs + else nn.Linear + ) + self.linear = Linear(self.input_dims, self.output_dims, bias=self.bias) + + def __call__(self, x: mx.array) -> mx.array: + """Forward pass. + + Args: + x: Input tensor. + + Returns: + The linear transformation of the input tensor. 
+ """ + return self.linear(x) diff --git a/cartesia-mlx/cartesia_mlx/layers/residual_block.py b/cartesia-mlx/cartesia_mlx/layers/residual_block.py new file mode 100644 index 0000000..911e052 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/residual_block.py @@ -0,0 +1,104 @@ +from typing import Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from cartesia_mlx.layers.sa import SelfAttention +from cartesia_mlx.utils.configure import Inherit, instantiate, set_cfg + + +class ResidualBlock(nn.Module): + """A residual block composed of a single layer. + + A residual block is a layer that applies a layer to the input tensor + and adds the input tensor to the output tensor. The layer does not + need to have a state, but it can have a state if needed. + """ + + base_cfg = dict( + _class_="layers.residual_block.ResidualBlock", + quantization_kwargs=Inherit(default=None), + d_model=Inherit(default=1024), + layer=SelfAttention.base_cfg, + norm_point="pre", # None, pre, pre_resid, post + stateful=True, + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + if self.norm_point: + self.norm = nn.RMSNorm(self.d_model) + self.layer = instantiate(self.layer, parent=self) + + def __call__( + self, x: mx.array, *args, state: Optional[Union[mx.array, Tuple[mx.array]]] = None, **kwargs + ): + """Forward pass for the residual block. + + Args: + x: The input tensor. + *args: Additional arguments to pass to the layers. + state: The state of the layer. + **kwargs: Additional keyword arguments. + + Returns: + The output tensor and the next state if the layer is stateful. + Otherwise, just the output tensor. + """ + r = x + if self.norm_point == "pre": + if len(x.shape) == 3 and x.shape[1] == 1: + b, l, d = x.shape + x = self.norm(x.reshape(b, d)).reshape(b, l, d) + else: + x = self.norm(x) + z = self.layer(x, *args, state=state, **kwargs) + if self.stateful: + x, state = z + else: + x, state = z, None + if self.norm_point == "pre_resid": + if len(x.shape) == 3 and x.shape[1] == 1: + b, l, d = x.shape + x = self.norm(x.reshape(b, d)).reshape(b, l, d) + else: + x = self.norm(x) + x = x + r + if self.norm_point == "post": + if len(x.shape) == 3 and x.shape[1] == 1: + b, l, d = x.shape + x = self.norm(x.reshape(b, d)).reshape(b, l, d) + else: + x = self.norm(x) + if self.stateful: + return x, state + return x + + def step( + self, x: mx.array, *args, state: Optional[Union[mx.array, Tuple[mx.array]]] = None, **kwargs + ): + """Step function for the residual block. + + Args: + x: The input tensor. + *args: Additional arguments to pass to the layers. + state: The state of the layer. + **kwargs: Additional keyword arguments. 
+ """ + r = x + if self.norm_point == "pre": + x = self.norm(x) + z = self.layer.step(x, *args, state=state, **kwargs) + if self.stateful: + x, state = z + else: + x, state = z, None + if self.norm_point == "pre_resid": + x = self.norm(x) + x = x + r + if self.norm_point == "post": + x = self.norm(x) + if self.stateful: + return x, state + return x diff --git a/cartesia-mlx/cartesia_mlx/layers/sa.py b/cartesia-mlx/cartesia_mlx/layers/sa.py new file mode 100644 index 0000000..cbaee5c --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/sa.py @@ -0,0 +1,168 @@ +from functools import partial +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from cartesia_mlx.utils.configure import Inherit, set_cfg + +SelfAttentionLayerState = Tuple[mx.array, mx.array, mx.array] # (keys, values, mask) + + +class SelfAttention(nn.Module): + """Implements Self-Attention https://arxiv.org/abs/1706.03762 and Multi-Query-Attention https://arxiv.org/abs/1911.02150.""" + + base_cfg = dict( + _class_="layers.sa.SelfAttention", + quantization_kwargs=Inherit(default=None), + d_model=Inherit(default=1024), + d_head=64, + kv_heads=1, + expand=None, + bias=False, + causal=True, + max_context_len=4096, + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + if self.expand is None: + self.d_inner = self.d_model + else: + self.d_inner = int(round(self.expand * self.d_model)) + self.n_heads = self.d_inner // self.d_head + self.kv_heads = self.n_heads if self.kv_heads is None else self.kv_heads + self.softmax_scale = 1 / (self.d_head**0.5) + self.d_proj = (self.n_heads + 2 * self.kv_heads) * self.d_head + Linear = ( + partial(nn.QuantizedLinear, **self.quantization_kwargs) + if self.quantization_kwargs + else nn.Linear + ) + self.qkv_proj = Linear(self.d_model, self.d_proj, bias=self.bias) + self.out_proj = Linear(self.n_heads * self.d_head, self.d_model, bias=self.bias) + + def __call__( + self, + x: mx.array, + *args, + state: Optional[SelfAttentionLayerState] = None, + mask: Optional[mx.array] = None, + **kwargs, + ) -> Tuple[mx.array, SelfAttentionLayerState]: + """ + Args: + x (mx.array): Input tensor of shape (batch_size, seq_len, d_model). + state: Tuple containing previous keys, values, and mask state (optional). + mask: Attention mask (optional). + + Returns: + tuple: Output tensor and updated state. 
+ """ + b, l, _ = x.shape + qkv = self.qkv_proj(x) + queries, keys, values = mx.split( + qkv, + indices_or_sections=[ + self.n_heads * self.d_head, + (self.n_heads + self.kv_heads) * self.d_head, + ], + axis=-1, + ) + + queries = queries.reshape(b, l, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(b, l, self.kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(b, l, self.kv_heads, -1).transpose(0, 2, 1, 3) + + if state is not None: + key_state, value_state, mask_state = state + keys = mx.concatenate([key_state, keys], axis=2) + values = mx.concatenate([value_state, values], axis=2) + if mask is None and mask_state is not None: + mask_state = mx.concatenate( + [mask_state, mx.ones([b, l], dtype=x.dtype)], axis=1 + ) # (b, l+l_state) + if mask is not None and mask_state is None: + l_state = key_state.shape[2] + mask_state = mx.concatenate([mx.ones([b, l_state], dtype=x.dtype), mask], axis=1) + elif mask is not None and mask_state is not None: + mask_state = mx.concatenate(mask_state, mask, axis=1) + else: + mask_state = mask + + l_q, l_kv = l, keys.shape[2] + mask_lq_x_lkv = _construct_mask( + b, mask, mask_state, l_q, l_kv, x.dtype, self.causal, self.max_context_len + ) + + dtype = queries.dtype + + x = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.softmax_scale, mask=mask_lq_x_lkv + ) + + if x.dtype != dtype: + x = x.astype(dtype) + + if self.kv_heads == self.n_heads: + x = x[0] # scaled_dot_product_attention outputs different shapes for mha / mqa + x = x.transpose(0, 2, 1, 3).reshape(b, l, -1) + + x = self.out_proj(x) + state = (keys, values, mask_state) + return x, state + + def step( + self, x: mx.array, *args, state: Optional[SelfAttentionLayerState] = None, **kwargs + ) -> Tuple[mx.array, SelfAttentionLayerState]: + """ + Args: + x (mx.array): Input tensor of shape (batch_size, d_model). + state: Tuple containing previous keys, values, and mask state (optional). + + Returns: + tuple: Output tensor and updated state. 
+ """ + b, d = x.shape + x = x.reshape(b, 1, d) + + # Truncate state at max_context_len + if self.max_context_len is not None and state[1].shape[2] > self.max_context_len: + keys = state[0][:, :, -self.max_context_len - 1 :] + values = state[1][:, :, -self.max_context_len - 1 :] + mask_state = None if state[2] is None else state[2][: -self.max_context_len - 1 :] + state = (keys, values, mask_state) + + x, state = self.__call__(x, *args, state=state, **kwargs) + return x[:, 0, :], state + + +def _construct_mask( + b, mask_q, mask_kv, l_q, l_kv, dtype, causal, max_context_len, max_context_len_forces_caual=True +): + if mask_q is None and mask_kv is None and not causal and not max_context_len: + return None + if mask_q is None: + mask_q = mx.ones([b, l_q], dtype=dtype) + if mask_kv is None: + mask_kv = mx.ones([b, l_kv], dtype=dtype) + mask_lq_x_lkv = mask_q.reshape(b, l_q, 1).astype(dtype) @ mask_kv.reshape(b, 1, l_kv).astype( + dtype + ) # (b, l_q, l_kv) + mask_lq_x_lkv = mx.where(mask_lq_x_lkv == 0, float("-inf"), 0) + diag_offset = l_kv - l_q + if causal or max_context_len and max_context_len_forces_caual: + causal_mask = mx.tri(l_q, l_kv, k=diag_offset, dtype=dtype) + causal_mask = mx.where(causal_mask == 0, float("-inf"), 0) + mask_lq_x_lkv += causal_mask.reshape(1, l_q, l_kv) + if max_context_len: + context_mask = mx.tri(l_q, l_kv, k=-max_context_len + diag_offset, dtype=dtype) + context_mask = mx.where(context_mask == 1, float("-inf"), 0) + if not causal: + context_mask_upper = mx.tri(l_q, l_kv, k=max_context_len + diag_offset - 1) + context_mask_upper = mx.where(context_mask_upper == 0, float("-inf"), 0) + context_mask += context_mask_upper + mask_lq_x_lkv += context_mask.reshape(1, l_q, l_kv) + mask_lq_x_lkv = mask_lq_x_lkv.reshape(b, 1, 1, l_q, l_kv) + return mask_lq_x_lkv diff --git a/cartesia-mlx/cartesia_mlx/layers/sequence_model.py b/cartesia-mlx/cartesia_mlx/layers/sequence_model.py new file mode 100644 index 0000000..95af111 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/sequence_model.py @@ -0,0 +1,120 @@ +import copy +from functools import partial +from typing import Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from cartesia_mlx.layers.residual_block import ResidualBlock +from cartesia_mlx.layers.sa import SelfAttention +from cartesia_mlx.layers.ssd.ssd import SSD +from cartesia_mlx.utils.configure import Inherit, instantiate, set_cfg, sub_cfg + +State = Union[mx.array, Tuple[mx.array]] + + +class SequenceModel(nn.Module): + """A base class for a sequence model. + + A sequence model is composed of a sequence of stateful layers (e.g. ssd, attention, etc.). + These layers can be configured in the config displayed below. 
+ """ + + base_cfg = dict( + _class_="layers.sequence_model.SequenceModel", + quantization_kwargs=Inherit(default=None), + d_model=Inherit(default=512), + n_layer_repeats=1, + unique_layers=[ + sub_cfg(ResidualBlock.base_cfg, layer=SelfAttention.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + ], + post_norm=True, + final_proj=False, + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + if self.post_norm: + self.norm = nn.RMSNorm(self.d_model) + layers = [self.unique_layers] * self.n_layer_repeats + layers = flatten([copy.deepcopy(layer) for layer in layers]) + self.layers = [instantiate(layer, parent=self) for layer in layers] + + if self.final_proj: + Linear = ( + partial(nn.QuantizedLinear, **self.quantization_kwargs) + if self.quantization_kwargs + else nn.Linear + ) + self.out_proj = Linear(self.d_model, self.d_model) + + def __call__( + self, x: mx.array, *args, state: Optional[State] = None, **kwargs + ) -> Tuple[mx.array, State]: + """Forward pass on the sequence model. + + Args: + x: The input tensor. Shape (batch_size, seq_len, ...). + *args: Additional arguments to pass to the layers. + state: The state of the model. + **kwargs: Additional keyword arguments to pass to the layers. + + Returns: + The output tensor and the next state. + """ + next_state = [] + for i, layer in enumerate(self.layers): + state_i = state[i] if state else None + z = layer(x, *args, state=state_i, **kwargs) + if layer.stateful is True: + x, state_i = z + else: + x, state_i = z, None + next_state.append(state_i) + if self.post_norm: + x = self.norm(x) + if self.final_proj: + x = self.out_proj(x) + return x, next_state + + def step(self, x: mx.array, *args, state=None, **kwargs) -> Tuple[mx.array, State]: + """A single step on the sequence model. + + Args: + x: Input tensor. + state: State of the model. + + Returns: + The output and the next state. + """ + next_state = [] + for i, layer in enumerate(self.layers): + state_i = state[i] if state else None + z = layer.step(x, *args, state=state_i, **kwargs) + if layer.stateful is True: + x, state_i = z + else: + x, state_i = z, None + next_state.append(state_i) + if self.post_norm: + x = self.norm(x) + if self.final_proj: + x = self.out_proj(x) + return x, next_state + + +def flatten(x) -> list: + """Flattens a list of list-likes (e.g. list, array, tensor). + + Args: + x: A list-like object that can be selected like, x[0], x[1], ... + + Returns: + List: A flattened list of all elements in x. + """ + y = [] + for xi in x: + y.extend(xi) + return y diff --git a/cartesia-mlx/cartesia_mlx/layers/ssd/__init__.py b/cartesia-mlx/cartesia_mlx/layers/ssd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cartesia-mlx/cartesia_mlx/layers/ssd/ops.py b/cartesia-mlx/cartesia_mlx/layers/ssd/ops.py new file mode 100644 index 0000000..1608979 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/ssd/ops.py @@ -0,0 +1,91 @@ +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn + + +@mx.compile +def ssd_forward_attn( + x: mx.array, + dt: mx.array, + A: mx.array, + B: mx.array, + C: mx.array, + D: mx.array, + dt_bias: mx.array, + dt_min: float, + dt_max: float, +) -> Tuple[mx.array, mx.array]: + """SSD-SSM forward pass. + + Args: + x: Input of shape (batch_size, num_heads, head_dim). + dt: Time deltas of shape (num_heads,). + A: State transition of shape (num_heads,). + B: Input mixing of shape (batch_size, num_groups, n). 
+ C: Output mixing of shape (batch_size, num_groups, n). + D: Residual connection. + dt_bias: Bias for time deltas of shape (num_heads,). + dt_min: Minimum value for time deltas after clipping. + dt_max: Maximum value for time deltas after clipping. + state: Previous state for recurrent connections. + Shape (batch_size, n_heads, d_head, d_state). + + Returns: + Output and next state. + """ + b, l, h, dh = x.shape + _, _, g, _ = B.shape + + if dt_bias is not None: + dt = dt + dt_bias.reshape(1, 1, -1) + + dt = nn.softplus(dt) + dt = mx.clip(dt, a_min=dt_min, a_max=dt_max).astype(x.dtype) + + B = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2) + C = mx.swapaxes(C, 1, 2) + + CB = C @ B + CB = mx.repeat(CB, repeats=h // g, axis=1) + + dtA = dt * A.reshape(1, 1, -1) + dtA = mx.swapaxes(dtA, 1, 2) + + decay = mx.exp(segsum(dtA)) + + surrogate_attention_matrix = mx.tril(CB * decay, 0) + + dtx = dt.reshape(b, l, h, 1) * x + y = surrogate_attention_matrix @ dtx.swapaxes(1, 2) + y = mx.swapaxes(y, 1, 2) + + decay = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1) + B = mx.repeat(B, h // g, axis=1).swapaxes(2, 3) + dtxdecay = dtx * decay + dtxdecay = dtxdecay.swapaxes(1, 2).swapaxes(2, 3) + next_state = dtxdecay @ B + + if D is not None: + y += x * D.reshape(1, 1, h, 1) + + y = y.reshape(b, l, h * dh) + + return y, next_state + + +@mx.compile +def segsum(x): + """Compute the segmented cumulative sum of a tensor. + + Args: + x: Input tensor. + + Returns: + Segmented cumulative sum. + """ + l = x.shape[-1] + x = mx.repeat(x[..., None], l, axis=-1) + x = mx.tril(x, -1) + x_segsum = mx.cumsum(x, axis=-2) + return x_segsum diff --git a/cartesia-mlx/cartesia_mlx/layers/ssd/ssd.py b/cartesia-mlx/cartesia_mlx/layers/ssd/ssd.py new file mode 100644 index 0000000..c35bef1 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/layers/ssd/ssd.py @@ -0,0 +1,237 @@ +from functools import partial +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from cartesia_metal import conv1d_forward, conv1d_update, ssd_update, ssd_update_no_z +from cartesia_mlx.layers.ssd.ops import ssd_forward_attn +from cartesia_mlx.utils.configure import Inherit, set_cfg + +uniform_init = nn.init.uniform() + + +SSDLayerState = Tuple[mx.array, mx.array] # (conv_state, ssm_state) + + +class SSD(nn.Module): + """State Space Duality (Mamba 2) layer. + + Reference: + Dao* & Gu* Transformers are SSMs: Generalized Models and Efficient Algorithms + Through Structured State Space Duality. ArXiv 2024. + https://arxiv.org/abs/2405.21060. 
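To make the `segsum` helper above concrete: given the per-position log-decays, it builds a lower-triangular matrix of summed decays between positions, which `ssd_forward_attn` exponentiates into the decay factors of the surrogate attention matrix. A small worked example (illustrative values):

```python
import mlx.core as mx
from cartesia_mlx.layers.ssd.ops import segsum

x = mx.array([1.0, 2.0, 3.0])   # e.g. dt * A per position
print(segsum(x))
# [[0, 0, 0],
#  [2, 0, 0],
#  [5, 3, 0]]
# Entry [i, j] = x[j+1] + ... + x[i]: the total log-decay a contribution injected
# at step j accumulates by step i (the upper triangle is masked out later with
# mx.tril in ssd_forward_attn).
```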
+ """ + + base_cfg = dict( + _class_="layers.ssd.ssd.SSD", + quantization_kwargs=Inherit(default=None), + d_model=Inherit(default=1024), + expand=2, + kernel_size=4, + d_state=64, + d_head=64, + n_groups=1, + A_init_range=(1, 16), + dt_init_range=(0.0001, 0.1), + dt_limit=(0.0, float("inf")), + norm_before_gate=False, + bias=False, + conv_bias=False, + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + + self.d_inner = self.d_model * self.expand + assert self.d_inner % self.d_head == 0 + self.n_heads = self.d_inner // self.d_head + + in_proj_dim = 2 * self.d_inner + 2 * self.d_state * self.n_groups + self.n_heads + Linear = ( + partial(nn.QuantizedLinear, **self.quantization_kwargs) + if self.quantization_kwargs + else nn.Linear + ) + self.in_proj = Linear(self.d_model, in_proj_dim, bias=self.bias) + self.out_proj = Linear(self.d_inner, self.d_model, bias=self.bias) + + self.conv_dim = self.d_inner + 2 * self.d_state * self.n_groups + self.conv_weight = mx.zeros([self.conv_dim, self.kernel_size]) + self.conv_bias = mx.zeros([self.conv_dim]) + + self.A = ( + -uniform_init(mx.zeros([self.n_heads])) * (self.A_init_range[1] - self.A_init_range[0]) + + self.A_init_range[0] + ) + + self.dt_bias = ( + uniform_init(mx.zeros([self.n_heads])) * (self.dt_init_range[1] - self.dt_init_range[0]) + + self.dt_init_range[0] + ) + self.dt_bias = softplus_inverse(self.dt_bias) + + self.D = mx.ones([self.n_heads]) + + self.rms_norm = nn.RMSNorm(self.d_inner) + + def __call__( + self, x: mx.array, state: Optional[SSDLayerState] = None, **kwargs + ) -> Tuple[mx.array, SSDLayerState]: + """ + Args: + x: Input tensor of shape (batch_size, seq_len, d_model). + state: Tuple containing previous convolution and SSD states. + See :meth:`get_default_state` for the shape information. + + Returns: + A tuple of the output and updated state. + """ + b, l, _ = x.shape + + conv_state, ssm_state = ( + state if state is not None else self.get_default_state(batch_size=b, dtype=x.dtype) + ) + assert ssm_state is None, "SSD state passing is not supported." + + zxBCdt = self.in_proj(x) + + z, xBC, dt = mx.split( + zxBCdt, [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], axis=-1 + ) + + xBC, conv_state = conv1d_forward( + xBC, self.conv_weight, self.conv_bias, conv_state + ) # Fused silu + + x, B, C = mx.split( + xBC, [self.d_inner, self.d_inner + self.d_state * self.n_groups], axis=-1 + ) + + x, B, C = mx.split( + xBC, [self.d_inner, self.d_inner + self.d_state * self.n_groups], axis=-1 + ) + + x, ssm_state = ssd_forward_attn( + x.reshape(b, l, self.n_heads, self.d_head), + dt, + self.A, + B.reshape(b, l, self.n_groups, -1), + C.reshape(b, l, self.n_groups, -1), + self.D, + self.dt_bias, + dt_min=self.dt_limit[0], + dt_max=self.dt_limit[1], + ) + + if self.norm_before_gate is True: + x = self.rms_norm(x) + x = x * nn.silu(z) + else: + x = x * nn.silu(z) + x = self.rms_norm(x) + + x = self.out_proj(x) + + return x, (conv_state, ssm_state) + + def step( + self, x: mx.array, state: SSDLayerState, *args, **kwargs + ) -> Tuple[mx.array, SSDLayerState]: + """ + Args: + x: An input tensor of shape (batch_size, d_model). + state: A tuple containing previous convolution and SSD states. + See :meth:`get_default_state` for the shape information. + + Returns: + A tuple of the the following: + - output: The output tensor of shape (batch_size, seq_len, d_model). + - updated_state: The updated state. + See :meth:`get_default_state` for the shape information. 
+ """ + b, _ = x.shape + conv_state, ssm_state = state + + zxBCdt = self.in_proj(x) + + z, xBC, dt = mx.split( + zxBCdt, [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], axis=-1 + ) + + xBC, conv_state = conv1d_update( + xBC, self.conv_weight, self.conv_bias, conv_state + ) # Fused silu + + x, B, C = mx.split( + xBC, [self.d_inner, self.d_inner + self.d_state * self.n_groups], axis=-1 + ) + + B = B.reshape(b, self.n_groups, -1) # (b, g, n) + C = C.reshape(b, self.n_groups, -1) # (b, g, n) + x = x.reshape(b, self.n_heads, self.d_head) # (b, h, d_head) + + if self.norm_before_gate is True: + x, ssm_state = ssd_update_no_z( + x, + dt, + self.A, + B, + C, + self.D, + self.dt_bias, + dt_min=self.dt_limit[0], + dt_max=self.dt_limit[1], + state=ssm_state, + ) + x = self.rms_norm(x) + x = x * nn.silu(z) + else: + x, ssm_state = ssd_update( + x, + dt, + self.A, + B, + C, + self.D, + z, + self.dt_bias, + dt_min=self.dt_limit[0], + dt_max=self.dt_limit[1], + state=ssm_state, + ) + x = self.rms_norm(x) + + x = self.out_proj(x) + + return x, (conv_state, ssm_state) + + def get_default_state(self, batch_size=1, dtype=mx.float16): + """Get the default state for the layer. + + Args: + batch_size: The batch size. + dtype: The data type. + + Returns: + A tuple representing the state with components: + - conv_state: The convolution state. Shape (batch_size, conv_dim, kernel_size - 1). + - ssm_state: The SSD state. Shape (batch_size, n_heads, d_head, d_state). + """ + conv_state = mx.zeros([batch_size, self.conv_dim, self.kernel_size - 1], dtype=dtype) + ssm_state = None + state = (conv_state, ssm_state) + return state + + +@mx.compile +def softplus_inverse(x: mx.array): + """Computes the inverse softplus. + + :math:`x = softplus_inverse(softplus(x))` + + Args: + x: The input tensor. + """ + return x + mx.log(-mx.exp(-x) + 1) diff --git a/cartesia-mlx/cartesia_mlx/models/__init__.py b/cartesia-mlx/cartesia_mlx/models/__init__.py new file mode 100644 index 0000000..817b780 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/models/__init__.py @@ -0,0 +1,5 @@ +from cartesia_mlx.models.lm import LM +from cartesia_mlx.models.mamba2 import Mamba2 +from cartesia_mlx.models.rene import Rene + +__all__ = ["LM", "Rene", "Mamba2"] diff --git a/cartesia-mlx/cartesia_mlx/models/lm.py b/cartesia-mlx/cartesia_mlx/models/lm.py new file mode 100644 index 0000000..1a4a292 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/models/lm.py @@ -0,0 +1,205 @@ +import time + +import mlx.core as mx +import mlx.nn as nn + +from cartesia_mlx.layers.embedding import Embedding +from cartesia_mlx.layers.ffn import SwiGLU +from cartesia_mlx.layers.linear import Linear +from cartesia_mlx.layers.residual_block import ResidualBlock +from cartesia_mlx.layers.sa import SelfAttention +from cartesia_mlx.layers.sequence_model import SequenceModel +from cartesia_mlx.layers.ssd.ssd import SSD +from cartesia_mlx.utils.configure import Inherit, set_cfg, sub_cfg +from cartesia_mlx.utils.sample_utils import categorical_sampling, min_p_sampling, top_p_sampling +from cartesia_mlx.utils.tokenizer import Tokenizer + + +class LM(nn.Module): + """Generic language model backbone. 
+ + Example: + import mlx.core as mx + import cartesia_mlx as cmx + + model = cmx.from_pretrained('cartesia-ai/mamba2-130m-8bit-mlx') + model.set_dtype(mx.float32) + + prompt = "Rene Descartes was" + + print(prompt, end="", flush=True) + for text in model.generate( + prompt, + max_tokens=500, + eval_every_n=5, + verbose=True, + top_p=0.99, + temperature=0.85, + ): + print(text, end="", flush=True) + """ + + base_cfg = dict( + _class_="models.lm.LM", + quantization_kwargs=None, + d_model=2048, + n_tokens=50288, + eos=50279, + tokenizer=sub_cfg(Tokenizer.base_cfg), + embedding=sub_cfg(Embedding.base_cfg, n_tokens=Inherit(), d_model=Inherit()), + model=sub_cfg( + SequenceModel.base_cfg, + n_layer_repeats=4, + unique_layers=[ + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SwiGLU.base_cfg, stateful=False), + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SwiGLU.base_cfg, stateful=False), + sub_cfg(ResidualBlock.base_cfg, layer=SelfAttention.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SwiGLU.base_cfg, stateful=False), + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SSD.base_cfg), + sub_cfg(ResidualBlock.base_cfg, layer=SwiGLU.base_cfg, stateful=False), + ], + ), + head=sub_cfg( + Linear.base_cfg, + output_dim=Inherit(from_key="n_tokens"), + input_dim=Inherit(from_key="d_model"), + ), + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent, instantiate_children=True) + + def generate( + self, + prompt, + max_tokens=500, + eval_every_n=1, + verbose=False, + **sampling_kwargs, + ): + """Generates text given an initial prompt. + + Args: + prompt: The initial text prompt to start the generation. + max_tokens: The maximum number of tokens to generate. Default is 500. + eval_every_n: Number of tokens between yielding intermediate generated text. Default is 1. + verbose: If True, prints the tokens per second during generation. + **sampling_kwargs: Additional keyword arguments to control the sampling strategy. + - temperature: Sampling temperature to control the randomness of predictions. Default is 1.0. + - top_p: Nucleus sampling probability. If specified, the sampling will be restricted to the top p cumulative probability mass. + - min_p: Minimum probability sampling. If specified, tokens with a probability lower than this threshold will be discarded. + - min_tokens_to_keep: The minimum number of tokens to keep during `min_p` sampling. Only valid when `min_p` is specified. + + Yields: + str: Generated text at intervals specified by eval_every_n. + """ + prompt_ids = mx.array(self.tokenizer.tokenize([prompt])[0]) + tokens = [] + for n_tokens, token in enumerate( + self.generate_tokens(prompt_ids, max_tokens, verbose, **sampling_kwargs) + ): + tokens.append(token) + if (n_tokens + 1) % eval_every_n == 0: + yield "".join(self.tokenizer.detokenize(tokens)) + tokens = [] + if tokens: + yield "".join(self.tokenizer.detokenize(tokens)) + + def generate_tokens( + self, + prompt_ids, + max_tokens=500, + verbose=False, + **sampling_kwargs, + ): + """Generates tokens given the tokenized prompt. + + Args: + prompt_ids: Array of tokenized prompt IDs. + max_tokens: The maximum number of tokens to generate. Default is 1000. 
+ verbose: If True, prints the tokens per second during generation. + **sampling_kwargs: Additional keyword arguments to control the sampling strategy. + - temperature: Sampling temperature to control the randomness of predictions. Default is 1.0. + - top_p: Nucleus sampling probability. If specified, the sampling will be restricted to the top p cumulative probability mass. + - min_p: Minimum probability sampling. If specified, tokens with a probability lower than this threshold will be discarded. + - min_tokens_to_keep: The minimum number of tokens to keep during `min_p` sampling. Only valid when `min_p` is specified. + + Yields: + int: Generated token IDs. + """ + logits, state = self.prefill(prompt_ids.reshape(1, -1)) + y = sample(logits, **sampling_kwargs) + + mx.async_eval(y) + + def _step(y): + nonlocal state + logits, state = self.step(y, state) + return sample(logits, **sampling_kwargs) + + next_y = _step(y) + mx.async_eval(next_y) + + yield y.item() + + if verbose: + start_time = time.time() + + y = next_y + + for n_tokens in range(max_tokens - 1): + y = next_y + next_y = _step(y) + mx.async_eval(next_y) + if y.item() == self.eos: + break + yield y.item() + + if verbose: + elapsed_time = time.time() - start_time + print("\n" + "-" * 50) + print(f"Tok/s: {n_tokens / elapsed_time:.2f}") + + def prefill(self, x): + """Prefills the model with given prompt.""" + x = self.embedding.encode(x) + x, state = self.model(x) + x_last = x[:, -1, :] + x_last = self.head(x_last) + return x_last, state + + def step(self, x, state=None): + """Autoregressive step for generating model response.""" + x = self.embedding.encode(x) + x, state = self.model.step(x, state=state) + x = self.head(x) + return x, state + + +def sample(logits, **sampling_kwargs): + """Sample from the logits using the specified sampling strategy.""" + if "min_p" in sampling_kwargs: + allowed_keys = {"min_p", "min_tokens_to_keep", "temperature"} + assert set(sampling_kwargs.keys()).issubset( + allowed_keys + ), f"keys {sampling_kwargs.keys()} not in {allowed_keys}" + return min_p_sampling(logits, **sampling_kwargs) + elif "top_p" in sampling_kwargs: + allowed_keys = {"top_p", "temperature"} + assert set(sampling_kwargs.keys()).issubset( + allowed_keys + ), f"keys {sampling_kwargs.keys()} not in {allowed_keys}" + return top_p_sampling(logits, **sampling_kwargs) + else: + allowed_keys = {"temperature"} + assert set(sampling_kwargs.keys()).issubset( + allowed_keys + ), f"keys {sampling_kwargs.keys()} not in {allowed_keys}" + return categorical_sampling(logits, **sampling_kwargs) diff --git a/cartesia-mlx/cartesia_mlx/models/mamba2.py b/cartesia-mlx/cartesia_mlx/models/mamba2.py new file mode 100644 index 0000000..57292f4 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/models/mamba2.py @@ -0,0 +1,29 @@ +from cartesia_mlx.models.lm import LM +from cartesia_mlx.utils.configure import sub_cfg + + +class Mamba2(LM): + """Mamba2 language model. 
+ + Example: + import mlx.core as mx + import cartesia_mlx as cmx + + model = cmx.from_pretrained('cartesia-ai/mamba2-130m-8bit-mlx') + model.set_dtype(mx.float32) + + prompt = "Rene Descartes was" + + print(prompt, end="", flush=True) + for text in model.generate( + prompt, + max_tokens=500, + eval_every_n=5, + verbose=True, + top_p=0.99, + temperature=0.85, + ): + print(text, end="", flush=True) + """ + + base_cfg = sub_cfg(LM.base_cfg, _class_="models.mamba2.Mamba2") diff --git a/cartesia-mlx/cartesia_mlx/models/rene.py b/cartesia-mlx/cartesia_mlx/models/rene.py new file mode 100644 index 0000000..6ea7ada --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/models/rene.py @@ -0,0 +1,29 @@ +from cartesia_mlx.models.lm import LM +from cartesia_mlx.utils.configure import sub_cfg + + +class Rene(LM): + """Rene Language Model. + + Example: + import mlx.core as mx + import cartesia_mlx as cmx + + model = cmx.from_pretrained('cartesia-ai/Rene-v0.1-1.3b-4bit-mlx') + model.set_dtype(mx.float32) + + prompt = "Rene Descartes was" + + print(prompt, end="", flush=True) + for text in model.generate( + prompt, + max_tokens=500, + eval_every_n=5, + verbose=True, + top_p=0.99, + temperature=0.85, + ): + print(text, end="", flush=True) + """ + + base_cfg = sub_cfg(LM.base_cfg, _class_="models.rene.Rene") diff --git a/cartesia-mlx/cartesia_mlx/utils/configure.py b/cartesia-mlx/cartesia_mlx/utils/configure.py new file mode 100644 index 0000000..c36efe7 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/utils/configure.py @@ -0,0 +1,160 @@ +import collections +import copy +import importlib +from typing import Any, Dict, Optional + +from mlx.utils import tree_flatten, tree_unflatten + +from cartesia_mlx.utils.filter import filter_startswith_strip + + +def set_cfg(obj, cfg, parent=None, instantiate_children=False): + """Sets the configuration and optionally instantiates child objects. + + Args: + cfg: Configuration dictionary to update the base configuration with. + parent: Parent object to inherit configuration values from (optional). + instantiate_children: Whether to instantiate child objects specified in the configuration (default is False). + + Returns: + dict: The updated configuration dictionary. + """ + cfg = _resolve_inheritance(copy.deepcopy(cfg)) + this_cfg = _resolve_inheritance(copy.deepcopy(obj.base_cfg)) + this_cfg = _recursive_update(this_cfg, cfg) + this_cfg = _parent_update(this_cfg, parent) + obj.cfg = this_cfg + for k, v in obj.cfg.items(): + if isinstance(v, dict) and instantiate_children is True and "_class_" in v: + setattr(obj, k, instantiate(v, parent=obj)) + else: + setattr(obj, k, v) + return cfg + + +class Inherit: + """Inherit class for inheriting values from parent config.""" + + def __init__(self, default: Any = None, from_key: Optional[str] = None): + self.default = default + self.from_key = from_key + + def __repr__(self): + return f"Inherit(default={self.default}, from_key={self.from_key})" + + +def sub_cfg(base_cfg: Dict[str, Any], **kwargs): + """Overwrites the base_cfg with the kwargs provided.""" + cfg = copy.deepcopy(base_cfg) + cfg.update(kwargs) + return cfg + + +def class_from_path(path: str, package_name="cartesia_mlx"): + """Retrives class from path to the class.""" + path = package_name + "." 
+ path + module_path, class_name = path.rsplit(".", 1) + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + return cls + + +def instantiate(cfg: Dict[str, Any], parent: Optional[Any] = None): + """Instantiates a class from the config.""" + this_cfg = copy.deepcopy(cfg) + assert "_class_" in this_cfg, f"Class not found in config: {this_cfg}" + class_path = this_cfg.pop("_class_") + cls = class_from_path(class_path) + instance = cls(cfg=this_cfg, parent=parent) + return instance + + +def _resolve_inheritance(cfg): + """Removes all Inheritance objects from the configuration and by replacing them with inherited values.""" + + def resolve_value(flat_cfg, key): + value = flat_cfg[key] + if isinstance(value, Inherit): + # Replace the key with the key to inherit from + inherit_key = key.split(".")[-1] if value.from_key is None else value.from_key + + full_inherit_key = key.split(".") + + # No parent key to inherit from -> return default value + if len(full_inherit_key) < 2: + return value.default + + # If the second last key is an integer, remove it (this skips the lists in inheritance) + if full_inherit_key and _can_convert_to_int(full_inherit_key[-2]): + full_inherit_key.pop(-2) + + # Build the full key to inherit from + full_inherit_key = full_inherit_key[:-2] + + # Build final key to inherit from + full_inherit_key.append(inherit_key) + full_inherit_key = ".".join(full_inherit_key) + + # If the key to inherit from exists, return its value + if full_inherit_key in flat_cfg: + return flat_cfg[full_inherit_key] + + # Else, check if inheriting a whole subtree is possible + elif sub_tree := filter_startswith_strip(flat_cfg, full_inherit_key + "."): + return sub_tree + + # Otherwise, return the default value + return value.default + return value + + # Flatten the configuration dictionary + flat_cfg = dict(tree_flatten(copy.deepcopy(cfg))) + + # Sort the keys by the number of nodes (levels) in the key + sorted_keys = sorted(flat_cfg.keys(), key=lambda k: len(k.split("."))) + + # Resolve inheritance in the sorted order + for key in sorted_keys: + flat_cfg[key] = resolve_value(flat_cfg, key) + + # Unflatten the dictionary back to the original nested structure + + resolved_cfg = tree_unflatten(list(flat_cfg.items())) + return resolved_cfg + + +def _recursive_update(d: Dict[str, Any], u: Dict[str, Any]): + if u is None: + return d + for k, v in u.items(): + t = d.get(k, {}) + if isinstance(v, collections.abc.Mapping) and t is not None: + d[k] = _recursive_update(t, v) + else: + d[k] = v + return d + + +def _parent_update(this_cfg: Dict[str, Any], parent: Optional[Any]): + for k, v in this_cfg.items(): + if isinstance(v, Inherit): + kp = v.from_key if v.from_key is not None else k + if parent is not None and hasattr(parent, kp): + this_cfg[k] = copy.deepcopy(getattr(parent, kp)) + elif parent is None and v.default is not None: + this_cfg[k] = copy.deepcopy(v.default) + elif parent is not None and not hasattr(parent, kp): + this_cfg[k] = copy.deepcopy(v.default) + else: + raise ValueError( + f"Inhertitance error: Parent attribute {k} not found and no default provided." 
+ ) + return this_cfg + + +def _can_convert_to_int(s: str) -> bool: + try: + int(s) + return True + except ValueError: + return False diff --git a/cartesia-mlx/cartesia_mlx/utils/filter.py b/cartesia-mlx/cartesia_mlx/utils/filter.py new file mode 100644 index 0000000..46c3c48 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/utils/filter.py @@ -0,0 +1,26 @@ +import copy +import re + + +def filter_startswith(dictionary: dict, prefix: str, invert: bool = False): + """Returns the subtree of a flattened dictionary that starts with a given prefix.""" + filtered_dict = copy.deepcopy(dictionary) + pattern = re.compile(re.escape(prefix) + ".*") + + def matches_pattern(key): + if invert: + return not pattern.match(key) + else: + return pattern.match(key) + + filtered_dict = {key: value for key, value in filtered_dict.items() if matches_pattern(key)} + return filtered_dict + + +def filter_startswith_strip(dictionary: dict, prefix: str): + """Returns the subtree of a flattened dictionary that starts with a given prefix and strips the prefix from the keys.""" + prefix_len = len(prefix) + filtered_dict = copy.deepcopy(dictionary) + filtered_dict = filter_startswith(filtered_dict, prefix, invert=False) + filtered_dict = {k[prefix_len:]: v for k, v in filtered_dict.items()} + return filtered_dict diff --git a/cartesia-mlx/cartesia_mlx/utils/load.py b/cartesia-mlx/cartesia_mlx/utils/load.py new file mode 100644 index 0000000..57a51c0 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/utils/load.py @@ -0,0 +1,84 @@ +import json +import logging +import os +import pathlib +from pathlib import Path +from typing import Optional, Union + +import mlx.core as mx +import mlx.nn as nn +from huggingface_hub import snapshot_download +from mlx.utils import tree_unflatten + +from cartesia_mlx.utils.configure import instantiate + + +class ModelNotFoundError(Exception): + """Exception raised when the model is not found.""" + + def __init__(self, message): + """ + Args: + message: The error message. + """ + self.message = message + super().__init__(self.message) + + +def get_model_path( + path_or_hf_repo: Union[str, os.PathLike], + revision: Optional[str] = None, +) -> pathlib.Path: + """Get the local path of the model.""" + model_path = Path(path_or_hf_repo) + + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + revision=revision, + allow_patterns=["*.safetensors"], + ) + ) + model_path = next(model_path.glob("*.safetensors"), None) + if model_path is None: + raise ValueError(f"No model path found in '{path_or_hf_repo}'.") + return model_path + + +def instantiate_from_file(file_path) -> nn.Module: + """Instantiates a model from the safetensors checkpoint file. + + Args: + file_path: The path to the safetensors checkpoint file. + + Returns: + The mlx module. + """ + state_dict, cfg = mx.load(str(file_path), return_metadata=True) + cfg = [(k, json.loads(v)) for k, v in cfg.items()] + cfg = tree_unflatten(cfg) + model = instantiate(cfg) + state_dict_nested = tree_unflatten(list(state_dict.items())) + model.update(state_dict_nested) + mx.eval(model.parameters()) + model.eval() + return model + + +def from_pretrained( + path_or_hf_repo: Union[str, os.PathLike], + revision: Optional[str] = None, +): + """Load a model from a local path or huggingface repository. + + Args: + path_or_hf_repo: The path to the model or the huggingface repository. + revision: The revision of the model. 
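+
+    Example (illustrative; the repo id matches the one used in example.py, and the
+    function is also exposed as `cartesia_mlx.from_pretrained`):
+        model = from_pretrained("cartesia-ai/Rene-v0.1-1.3b-4bit-mlx")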
+ """ + model_path = get_model_path(path_or_hf_repo, revision) + if not model_path: + logging.error(f"No safetensors file found in {model_path}") + raise FileNotFoundError(f"No safetensors file found in {model_path}") + model = instantiate_from_file(model_path) + return model diff --git a/cartesia-mlx/cartesia_mlx/utils/sample_utils.py b/cartesia-mlx/cartesia_mlx/utils/sample_utils.py new file mode 100644 index 0000000..cf2d91e --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/utils/sample_utils.py @@ -0,0 +1,103 @@ +# Adapted from: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/sample_utils.py +# Copyright © 2023-2024 Apple Inc. + +from functools import partial + +import mlx.core as mx + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def min_p_sampling( + logits: mx.array, + min_p: float, + min_tokens_to_keep: int = 1, + temperature=1.0, +) -> mx.array: + """ + Apply min-p sampling to the logits. + + Min-p keeps all tokens that are above a minimum probability, scaled by the + probability of the most likely token. As a result, the filter is more + aggressive given a very high-probability token. + + Args: + logits: The logits from the model's output. + min_p (float): Minimum token probability. Typical values are in the + 0.01-0.2 range, comparably selective as setting `top_p` in the + 0.99-0.8 range. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered. Default: ``1``. + + """ + if not (0 <= min_p <= 1.0): + raise ValueError(f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}") + if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1): + raise ValueError( + f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}" + ) + # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 + + # Softmax probabilities + probs = mx.softmax(logits * (1 / temperature), axis=-1) + + # Indices sorted in decreasing order + sorted_indices = mx.argsort(-logits).squeeze(0) + sorted_probs = probs[..., sorted_indices] + + # Top probability + top_probs = probs[..., sorted_indices[0]] + + # Calculate the min_p threshold + scaled_min_p = min_p * top_probs + + # Mask tokens that have a probability less than the scaled min_p + tokens_to_remove = sorted_probs < scaled_min_p + tokens_to_remove[..., :min_tokens_to_keep] = False + + # Create pool of tokens with probability less than scaled min_p + selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) + + # Return sampled token + sorted_token = mx.random.categorical(mx.log(selected_probs)) + return sorted_indices[sorted_token] + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array: + """ + Apply top-p (nucleus) sampling to logits. + + Args: + logits: The logits from the model's output. + top_p: The cumulative probability threshold for top-p filtering. + temperature: Temperature parameter for softmax distribution reshaping. + + Returns: + token selected based on the top-p criterion. 
+ """ + # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 + probs = mx.softmax(logits * (1 / temperature), axis=-1) + + # sort probs in ascending order + sorted_indices = mx.argsort(probs, axis=-1) + sorted_probs = probs[..., sorted_indices.squeeze(0)] + + cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + + # select tokens with cumulative probs below threshold + top_probs = mx.where( + cumulative_probs > 1 - top_p, + sorted_probs, + 0, + ) + + sorted_token = mx.random.categorical(mx.log(top_probs)) + token = sorted_indices.squeeze(0)[sorted_token] + + return token + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def categorical_sampling(logits, temperature=1.0): + """Standard categorical_sampling.""" + return mx.random.categorical(logits * (1 / temperature)) diff --git a/cartesia-mlx/cartesia_mlx/utils/tokenizer.py b/cartesia-mlx/cartesia_mlx/utils/tokenizer.py new file mode 100644 index 0000000..1d99549 --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/utils/tokenizer.py @@ -0,0 +1,34 @@ +import warnings + +from transformers import AutoTokenizer + +from cartesia_mlx.utils.configure import set_cfg + +warnings.filterwarnings( + "ignore", category=FutureWarning, module="transformers.tokenization_utils_base" +) + + +class Tokenizer: + """Tokenizer Wrapper for Huggingface AutoTokenizer.""" + + base_cfg = dict( + _class_="utils.tokenizer.Tokenizer", + name="allenai/OLMo-1B-hf", + ) + + def __init__(self, cfg=None, parent=None): + super().__init__() + set_cfg(self, cfg, parent) + self.name: str + self.tokenizer = AutoTokenizer.from_pretrained(self.name) + self.tokenizer.decode([0]) # warm up + + def tokenize(self, x: list) -> list: + """Tokenize list of strings.""" + tokens = self.tokenizer(x).input_ids + return tokens + + def detokenize(self, tokens: list) -> list: + """Detokenize list of tokens.""" + return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) diff --git a/cartesia-mlx/cartesia_mlx/version.py b/cartesia-mlx/cartesia_mlx/version.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/cartesia-mlx/cartesia_mlx/version.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/cartesia-mlx/example.py b/cartesia-mlx/example.py new file mode 100644 index 0000000..52abf8b --- /dev/null +++ b/cartesia-mlx/example.py @@ -0,0 +1,42 @@ +import argparse + +import mlx.core as mx + +import cartesia_mlx as cmx + +parser = argparse.ArgumentParser( + description="Simple example script to run a Cartesia MLX language model." +) +parser.add_argument( + "--model", default="cartesia-ai/Rene-v0.1-1.3b-4bit-mlx", help="The model name." +) +parser.add_argument("--prompt", default="Rene Descartes was") +parser.add_argument( + "--max-tokens", type=int, default=500, help="Maximum number of tokens to generate." +) +parser.add_argument( + "--top-p", + type=float, + default=0.99, + help="Top-p sampling (a value of 1 is equivalent to standard sampling).", +) +parser.add_argument( + "--temperature", type=float, default=0.85, help="Temperature scaling parameter." 
+) +args = parser.parse_args() + +model = cmx.from_pretrained(args.model) +model.set_dtype(mx.float32) + +prompt = args.prompt + +print(prompt, end="", flush=True) +for text in model.generate( + prompt, + max_tokens=args.max_tokens, + eval_every_n=5, + verbose=True, + top_p=args.top_p, + temperature=args.temperature, +): + print(text, end="", flush=True) diff --git a/cartesia-mlx/requirements.txt b/cartesia-mlx/requirements.txt new file mode 100644 index 0000000..c13df16 --- /dev/null +++ b/cartesia-mlx/requirements.txt @@ -0,0 +1,9 @@ +mlx==0.16.1 +transformers>=4.40.0 +setuptools>=42 +datasets +torch +numpy<2 # some packages require numpy<2 +requests +# cartesia-metal +huggingface_hub diff --git a/cartesia-mlx/setup.py b/cartesia-mlx/setup.py new file mode 100644 index 0000000..3c89a36 --- /dev/null +++ b/cartesia-mlx/setup.py @@ -0,0 +1,144 @@ +import io +import os + +# Note: To use the 'upload' functionality of this file, you must: +# $ pipenv install twine --dev +import sys +from distutils.util import convert_path +from shutil import rmtree + +from setuptools import Command, find_packages, setup + +PACKAGE_DIR = "cartesia_mlx" +main_ns = {} +ver_path = convert_path(os.path.join(PACKAGE_DIR, "version.py")) +with open(ver_path) as ver_file: + exec(ver_file.read(), main_ns) + + +# Package metadata. +NAME = "cartesia-mlx" +DESCRIPTION = "The official Cartesia MLX library." +URL = "https://github.com/cartesia-ai/edge" +EMAIL = "support@cartesia.ai" +AUTHOR = "Cartesia, Inc." +REQUIRES_PYTHON = ">=3.9.0" +VERSION = main_ns["__version__"] + + +# What packages are required for this module to be executed? +def get_requirements(path): + with open(path, "r") as f: + out = f.read().splitlines() + + out = [line.strip() for line in out] + return out + + +REQUIRED = get_requirements("requirements.txt") +# REQUIRED_DEV = get_requirements("requirements-dev.txt") +# What packages are optional? +EXTRAS = {} + + +# The rest you shouldn't have to touch too much :) +# ------------------------------------------------ +# Except, perhaps the License and Trove Classifiers! +# If you do change the License, remember to change the Trove Classifier for that! + +here = os.path.abspath(os.path.dirname(__file__)) + +# Import the README and use it as the long-description. +# Note: this will only work if 'README.md' is present in your MANIFEST.in file! +try: + with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: + long_description = "\n" + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + +# Load the package's __version__.py module as a dictionary. +about = {} +if not VERSION: + project_slug = NAME.lower().replace("-", "_").replace(" ", "_") + with open(os.path.join(here, project_slug, "__version__.py")) as f: + exec(f.read(), about) +else: + about["__version__"] = VERSION + + +class UploadCommand(Command): + """Support setup.py upload.""" + + description = "Build and publish the package." 
+ user_options = [("skip-upload", "u", "skip git tagging and pypi upload")] + boolean_options = ["skip-upload"] + + @staticmethod + def status(s): + """Prints things in bold.""" + print("\033[1m{0}\033[0m".format(s)) + + def initialize_options(self): + self.skip_upload = False + + def finalize_options(self): + self.skip_upload = bool(self.skip_upload) + + def run(self): + try: + self.status("Removing previous builds…") + rmtree(os.path.join(here, "dist")) + rmtree(os.path.join(here, "build")) + except OSError: + pass + + self.status("Building Source and Wheel (universal) distribution…") + os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) + + if self.skip_upload: + self.status("Skipping git tagging and pypi upload") + sys.exit() + + self.status("Uploading the package to PyPI via Twine…") + os.system("twine upload dist/*") + + self.status("Pushing git tags…") + os.system("git tag v{0}".format(about["__version__"])) + os.system("git push --tags") + + sys.exit() + + +# Where the magic happens: +setup( + name=NAME, + version=about["__version__"], + description=DESCRIPTION, + long_description=long_description, + long_description_content_type="text/markdown", + author=AUTHOR, + author_email=EMAIL, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=find_packages( + exclude=[ + "scratch", + "tests", + "*.tests", + "*.tests.*", + "tests.*", + ] + ), + install_requires=REQUIRED, + extras_require=EXTRAS, + include_package_data=True, + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + # $ setup.py publish support. + cmdclass={"upload": UploadCommand}, +) diff --git a/cartesia-pytorch/README.md b/cartesia-pytorch/README.md new file mode 100644 index 0000000..75f5b8f --- /dev/null +++ b/cartesia-pytorch/README.md @@ -0,0 +1,68 @@ +--- +license: apache-2.0 +language: +- en +datasets: +- allenai/dolma +tags: +- rene +- mamba +- cartesia +--- + +# Model Card for Rene + +Rene is a 1.3 billion-parameter language model trained by [Cartesia](https://cartesia.ai). +Rene has a hybrid architecture based on [Mamba-2](https://arxiv.org/abs/2405.21060), with feedforward and sliding window attention layers interspersed. +It uses the [allenai/OLMo-1B-hf](https://huggingface.co/allenai/OLMo-1B-hf) tokenizer. +Rene was pretrained on 1.5 trillion tokens of the [Dolma-1.7](https://huggingface.co/datasets/allenai/dolma) dataset. +For more details, see our [blog post](https://cartesia.ai/blog/on-device). + +## Usage +### Installation +The Rene model depends on the `cartesia-pytorch` package, which can be installed with `pip` as follows: +```shell +pip install --no-binary :all: cartesia-pytorch +``` + +### Generation example +```python +from cartesia_pytorch import ReneLMHeadModel +from transformers import AutoTokenizer + +model = ReneLMHeadModel.from_pretrained("cartesia-ai/Rene-v0.1-1.3b-pytorch").half().cuda() +tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf") +in_message = ["Rene Descartes was"] +inputs = tokenizer(in_message, return_tensors="pt") +outputs = model.generate(inputs.input_ids.cuda(), max_length=50, top_k=100, top_p=0.99) +out_message = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] +print(out_message) +# Example output: "Rene Descartes was a French mathematician, philosopher, and scientist. 
Descartes is famously credited for creating the Cartesian coordinate system: a 3 dimensional representation of points, vectors, and directions. This work is, for the most part" ... +``` + +### Evaluation example +You can use our `cartesia_lm_eval` wrapper around the [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main) to evaluate our model on standard text benchmarks. Example command (clone this repo and run the below from within the `cartesia-pytorch` directory): +```shell +python -m evals.cartesia_lm_eval --model rene_ssm --model_args pretrained=cartesia-ai/Rene-v0.1-1.3b-pytorch,trust_remote_code=True --trust_remote_code --tasks copa,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --cache_requests true --batch_size auto:4 --output_path outputs/rene_evals/ +``` +## Results on common benchmarks +| Model | Params (B) | Train Tokens | COPA | HellaSwag | MMLU (5-shot) | PIQA | ARC-e | ARC-c | WinoGrande | OpenBookQA | Average | +|------------------------------------------------|------------|--------------|------|-----------|---------------|------|-------|-------|------------|------------|---------| +| allenai/OLMo-1B-hf | 1.2 | 3.0 | 82.0 | 62.9 | 26.2 | 75.1 | 57.4 | 31.1 | 60.0 | 36.2 | 53.9 | +| apple/OpenELM-1\_1B | 1.1 | 1.5 | 81.0 | 64.8 | 27.1 | 75.6 | 55.4 | 32.3 | 61.9 | 36.2 | 54.3 | +| state-spaces/mamba2-1.3b | 1.3 | 0.3 | 82.0 | 60.0 | 25.8 | 73.7 | 64.2 | 33.3 | 61.0 | 37.8 | 54.7 | +| microsoft/phi-1\_5 | 1.4 | 0.15 | 79.0 | 62.6 | 42.5 | 75.5 | 73.2 | 48.0 | 72.8 | 48.0 | 62.7 | +| Qwen/Qwen2-1.5B | 1.5 | 7.0 | 80.0 | 65.4 | 56.0 | 75.5 | 60.4 | 35.0 | 65.8 | 36.4 | 59.3 | +| RWKV/rwkv-6-world-1b6 | 1.6 | 1.1 | 84.0 | 58.3 | 25.9 | 73.5 | 56.7 | 34.1 | 60.0 | 37.4 | 53.7 | +| stabilityai/stablelm-2-1\_6b | 1.6 | 4.0 | 86.0 | 69.0 | 38.1 | 76.7 | 68.1 | 38.9 | 63.6 | 38.8 | 59.9 | +| HuggingFaceTB/SmolLM-1.7B | 1.7 | 1.0 | 76.0 | 65.8 | 29.9 | 76.1 | 73.5 | 46.4 | 60.9 | 42.0 | 58.8 | +| h2oai/h2o-danube2-1.8b-base | 1.8 | 3.0 | 82.0 | 72.4 | 39.9 | 77.3 | 69.0 | 39.9 | 63.9 | 41.4 | 60.7 | +| google/recurrentgemma-2b | 2.7 | 2.0 | 62.0 | 61.8 | 32.3 | 68.8 | 46.4 | 29.9 | 57.1 | 29.0 | 48.4 | +| cognitivecomputations/TinyDolphin-2.8.1-1.1b | 1.1 | | 71.0 | 59.9 | 25.7 | 73.1 | 55.8 | 33.0 | 59.7 | 36.6 | 51.9 | +| cartesia-ai/Rene-v0.1-1.3b-pytorch (OUR MODEL) | 1.3 | 1.5 | 82.0 | 69.4 | 32.6 | 77.5 | 61.7 | 34.4 | 62.9 | 39.2 | 57.5 | + +## Bias, Risks, and Limitations +Rene is a pretrained base model which has not undergone any alignment or instruction tuning, and therefore does not have any moderation or safety guarantees. Users should implement appropriate guardrails and moderation mechanisms based on their particular needs in order to ensure responsible and ethical usage. + +## About Cartesia +At [Cartesia](https://cartesia.ai/), we're building real-time multimodal intelligence for every device. 
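+
+## Scoring example (illustrative)
+The forward pass returns a `CausalLMOutput` with next-token logits, so the model can also score text directly. The snippet below is an illustrative sketch that reuses the model and tokenizer from the generation example above; it is not an official benchmark script, and exact numbers will vary.
+```python
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer
+
+from cartesia_pytorch import ReneLMHeadModel
+
+model = ReneLMHeadModel.from_pretrained("cartesia-ai/Rene-v0.1-1.3b-pytorch").half().cuda()
+tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
+ids = tokenizer(["Rene Descartes was a French philosopher."], return_tensors="pt").input_ids.cuda()
+with torch.no_grad():
+    logits = model(ids).logits.float()
+# Shift by one so each token is predicted from its preceding context, then average the NLL.
+nll = F.cross_entropy(logits[:, :-1].transpose(1, 2), ids[:, 1:], reduction="mean")
+print(f"perplexity: {torch.exp(nll).item():.2f}")
+```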
diff --git a/cartesia-pytorch/cartesia_pytorch/__init__.py b/cartesia-pytorch/cartesia_pytorch/__init__.py new file mode 100644 index 0000000..48cd8cb --- /dev/null +++ b/cartesia-pytorch/cartesia_pytorch/__init__.py @@ -0,0 +1,4 @@ +from cartesia_pytorch.configuration_rene import ReneConfig +from cartesia_pytorch.rene import ReneLMHeadModel + +__all__ = ["ReneConfig", "ReneLMHeadModel"] diff --git a/cartesia-pytorch/cartesia_pytorch/configuration_rene.py b/cartesia-pytorch/cartesia_pytorch/configuration_rene.py new file mode 100644 index 0000000..5bb803a --- /dev/null +++ b/cartesia-pytorch/cartesia_pytorch/configuration_rene.py @@ -0,0 +1,103 @@ +from typing import Dict, List, Optional + +from transformers.configuration_utils import PretrainedConfig + + +class ReneConfig(PretrainedConfig): + r"""Configuration class for the Rene model. + + This is the configuration class to store the configuration of a [`ReneLMHeadModel`]. + It is used to instantiate a Rene model according to the specified arguments, + defining the model architecture. Instantiating a configuration with the defaults will yield + a similar configuration to that of the Rene-v0.1-1.3b-pytorch model. + [cartesia-ai/Rene-v0.1-1.3b-pytorch](https://huggingface.co/cartesia-ai/Rene-v0.1-1.3b-pytorch) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + d_model (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + n_layer (`int`, *optional*, defaults to 48): + Number of architecture blocks. + vocab_size (`int`, *optional*, defaults to 50280): + Vocabulary size of the Rene model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ReneModel`]. + ssm_cfg (`dict`, *optional*): + Configuration parameters for the SSM layers. + attn_layer_idx (`List[int]`, *optional*): + Indices of the architecture blocks that should have attention layers. + attn_cfg (`dict`, *optional*): + Configuration parameters for the attention layers. + mlp_layer_idx (`List[int]`, *optional*): + Indices of the architecture blocks that should have MLP layers. + mlp_cfg (`dict`, *optional*): + Configuration parameters for the MLP layers. + rms_norm (`bool`, *optional*, defaults to `True`): + Whether to use RMSNorm (instead of LayerNorm). + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether to keep residual values in fp32. + pad_vocab_size_multiple (`int`, *optional*, defaults to 16): + Pad the vocabulary size up to the next multiple of this value. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token. + bos_token_id (`int`, *optional*): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 50279): + The id of the "end-of-sequence" token. 
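+
+    Example (values are illustrative; the defaults already match Rene-v0.1-1.3b-pytorch):
+        config = ReneConfig(d_model=2048, n_layer=48, vocab_size=50280)
+        model = ReneLMHeadModel(config)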
+ """ + + model_type = "rene" + + def __init__( + self, + d_model: int = 2048, + n_layer: int = 48, + vocab_size: int = 50280, + ssm_cfg: Optional[Dict] = None, + attn_layer_idx: Optional[List] = None, + attn_cfg: Optional[Dict] = None, + mlp_layer_idx: Optional[List] = None, + mlp_cfg: Optional[Dict] = None, + rms_norm: bool = True, + residual_in_fp32: bool = True, + pad_vocab_size_multiple: int = 16, + tie_word_embeddings: bool = True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + **kwargs, + ): + if ssm_cfg is None: + ssm_cfg = {} + if attn_layer_idx is None: + attn_layer_idx = [] + if attn_cfg is None: + attn_cfg = {} + if mlp_layer_idx is None: + mlp_layer_idx = [] + if mlp_cfg is None: + mlp_cfg = {} + + self.d_model = d_model + self.n_layer = n_layer + self.vocab_size = vocab_size + self.ssm_cfg = ssm_cfg + self.attn_layer_idx = attn_layer_idx + self.attn_cfg = attn_cfg + self.mlp_layer_idx = mlp_layer_idx + self.mlp_cfg = mlp_cfg + self.rms_norm = rms_norm + self.residual_in_fp32 = residual_in_fp32 + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.tie_word_embeddings = tie_word_embeddings + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/cartesia-pytorch/cartesia_pytorch/rene.py b/cartesia-pytorch/cartesia_pytorch/rene.py new file mode 100644 index 0000000..90ace8b --- /dev/null +++ b/cartesia-pytorch/cartesia_pytorch/rene.py @@ -0,0 +1,435 @@ +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from flash_attn import flash_attn_with_kvcache +from mamba_ssm.models.mixer_seq_simple import _init_weights +from mamba_ssm.modules.mamba2 import Mamba2 +from mamba_ssm.modules.mha import _update_kv_cache +from mamba_ssm.utils.generation import GenerationMixin as MambaGenerationMixin +from transformers.modeling_outputs import CausalLMOutput +from transformers.modeling_utils import PreTrainedModel + +from .configuration_rene import ReneConfig + + +class ReneMLP(nn.Module): + """One-hidden-layer network with GELU activation. + + Args: + d_input: Block input dimension. + d_output: Block output dimension. + expand: Block expansion factor. + bias: Use biases in linear layers. + """ + + def __init__(self, d_input, d_output=None, expand=3, bias=True, device=None, dtype=None): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.d_input = d_input + self.d_output = d_input if d_output is None else d_output + self.d_inner = int(round(expand * d_input)) + self.in_proj = nn.Linear(self.d_input, self.d_inner, bias=bias, **factory_kwargs) + self.activation = nn.GELU() + self.out_proj = nn.Linear(self.d_inner, self.d_input, bias=bias, **factory_kwargs) + + def forward(self, x, inference_params=None): + """Forward pass through the MLP module.""" + y = self.in_proj(x) + y = self.activation(y) + y = self.out_proj(y) + return y + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for ReneMLP. (There is nothing to cache for this module).""" + return None + + +class ReneMHA(nn.Module): + """Multi-head self-attention. 
Adapted from mamba_ssm MHA class.""" + + def __init__( + self, + embed_dim, + num_heads, + num_heads_kv=None, + head_dim=None, # If None, use embed_dim // num_heads + qkv_proj_bias=True, + out_proj_bias=True, + softmax_scale=None, + causal=True, + sliding_window_length=None, # If None, infinite context + layer_idx=None, + device=None, + dtype=None, + ) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.embed_dim = embed_dim + self.layer_idx = layer_idx + self.softmax_scale = softmax_scale + self.causal = causal + assert self.causal, "Rene does not yet support non-causal modeling" + + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + if head_dim is None: + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + out_dim = self.head_dim * self.num_heads + + self.sliding_window_length = sliding_window_length + if self.sliding_window_length is None: + self.window_size = (-1, -1) + else: + self.window_size = (self.sliding_window_length - 1, 0) # for flash_attn + + self.in_proj = nn.Linear(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + self.out_proj = nn.Linear(out_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + """Allocate inference cache for the multi-head self-attention module.""" + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + kv_cache = torch.empty( + batch_size, + max_seqlen, + 2, + self.num_heads_kv, + self.head_dim, + dtype=dtype, + device=device, + ) + return kv_cache, None + + def _pytorch_attn(self, q, kv): + k, v = kv.unbind(dim=-3) + k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv) + v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv) + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + L, S = q.size(-2), k.size(-2) + if S > self.sliding_window_length: + attn_mask = ( + torch.ones(L, S, dtype=torch.bool) + .tril(diagonal=0) + .triu(-self.window_size[0]) + .to(device=q.device) + ) + # Since we pass in an attn_mask explicitly, we need to pass is_causal=False to + # `scaled_dot_product_attention` (even though the attn_mask itself is in fact causal). + is_causal_arg = False + else: + # The previous branch would also handle this case correctly, but it is more efficient + # to omit the attn_mask when we don't need it. 
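+            # (When S <= sliding_window_length, every key is already inside the window,
+            # so plain causal attention is equivalent and no explicit mask is needed.)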
+ attn_mask = None + is_causal_arg = True + return F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=is_causal_arg, scale=self.softmax_scale + ).transpose(1, 2) + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim).""" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _update_kvcache_attention(self, q, kv, inference_params): + """Write kv to inference_params, then compute attention.""" + if inference_params.seqlen_offset == 0 or flash_attn_with_kvcache is None: + # TODO: this only uses seqlen_offset and not lengths_per_sample. + kv = self._update_kv_cache(kv, inference_params) + return self._pytorch_attn(q, kv) + else: + batch = q.shape[0] + kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx] + kv_cache = kv_cache[:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + return flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + cache_seqlens=cache_seqlens, + softmax_scale=self.softmax_scale, + causal=self.causal, + window_size=self.window_size, + ) + + def forward(self, x, inference_params=None): + """Forward pass through the multi-head self-attention module.""" + if ( + inference_params is not None + and self.layer_idx not in inference_params.key_value_memory_dict + ): + inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache( + x.shape[0], inference_params.max_seqlen, dtype=x.dtype + ) + qkv = self.in_proj(x) + q, kv = qkv.split( + [self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1 + ) + q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) + kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) + if inference_params is None: + context = self._pytorch_attn(q, kv) + else: + context = self._update_kvcache_attention(q, kv, inference_params) + context = rearrange(context, "... h d -> ... (h d)") + out = self.out_proj(context) + return out + + +class Block(nn.Module): + """Simple residual block with normalization that wraps an inner "mixer" module.""" + + def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, residual_in_fp32=False): + """ + dim: The dimension of the input data. + mixer_cls: The class of the mixer module. + norm_cls: The class of the normalization module. + residual_in_fp32: Whether to keep residuals in fp32. 
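+
+        Example (illustrative; mirrors how `_create_block` composes mixers below):
+            block = Block(2048, partial(ReneMLP, expand=3), norm_cls=nn.LayerNorm)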
+ """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.norm = norm_cls(dim) + self.mixer = mixer_cls(dim) + + def forward(self, x, inference_params=None, **mixer_kwargs): + """Forward pass through the block.""" + y = self.norm(x.to(dtype=self.norm.weight.dtype)) + y = self.mixer(y, inference_params=inference_params, **mixer_kwargs) + + residual = x + if self.residual_in_fp32: + residual = residual.to(torch.float32) + y = y + residual + y = y.to(dtype=x.dtype) + + return y + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for the mixer module.""" + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + +def _create_block( + d_model, + norm_cls, + ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, + mlp_layer_idx=None, + mlp_cfg=None, + residual_in_fp32=False, + layer_idx=None, + device=None, + dtype=None, +): + factory_kwargs = {"device": device, "dtype": dtype} + if ssm_cfg is None: + ssm_cfg = {} + if attn_layer_idx is None: + attn_layer_idx = [] + if attn_cfg is None: + attn_cfg = {} + if mlp_layer_idx is None: + mlp_layer_idx = [] + if mlp_cfg is None: + mlp_cfg = {} + if layer_idx in attn_layer_idx: + mixer_cls = partial(ReneMHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs) + elif layer_idx in mlp_layer_idx: + mixer_cls = partial(ReneMLP, **mlp_cfg, **factory_kwargs) + else: + mixer_cls = partial(Mamba2, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + return Block(d_model, mixer_cls, norm_cls=norm_cls, residual_in_fp32=residual_in_fp32) + + +class MixerModel(nn.Module): + """Adapted from mamba_ssm.models.mixer_seq_simple.MixerModel.""" + + def __init__( + self, + d_model: int, + n_layer: int, + vocab_size: int, + ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, + mlp_layer_idx=None, + mlp_cfg=None, + norm_epsilon: float = 1e-5, + rms_norm: bool = False, + initializer_cfg=None, + residual_in_fp32=False, + device=None, + dtype=None, + ) -> None: + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.residual_in_fp32 = residual_in_fp32 + + if rms_norm: + from mamba_ssm.ops.triton.layer_norm import RMSNorm as norm_cls_base + else: + norm_cls_base = nn.LayerNorm + norm_cls = partial(norm_cls_base, eps=norm_epsilon, **factory_kwargs) + + self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) + + self.layers = nn.ModuleList( + [ + _create_block( + d_model, + norm_cls=norm_cls, + ssm_cfg=ssm_cfg, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, + mlp_layer_idx=mlp_layer_idx, + mlp_cfg=mlp_cfg, + residual_in_fp32=residual_in_fp32, + layer_idx=i, + **factory_kwargs, + ) + for i in range(n_layer) + ] + ) + + self.norm_f = norm_cls(d_model) + + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + n_residuals_per_layer=1, + ) + ) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache for all layers.""" + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def forward(self, input_ids, inference_params=None, **mixer_kwargs): + """Forward pass through the model.""" + hidden_states = self.embedding(input_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, inference_params=inference_params, **mixer_kwargs) + hidden_states = 
self.norm_f(hidden_states.to(dtype=self.norm_f.weight.dtype)) + return hidden_states + + +class ReneLMHeadModel(PreTrainedModel, MambaGenerationMixin): + """ + Rene language model architecture. + Based on mamba_ssm.models.mixer_seq_simple.MambaLMHeadModel, with several adaptations. + """ + + config_class = ReneConfig + base_model_prefix = "backbone" + _no_split_modules = ["Block", "Mamba2"] + supports_gradient_checkpointing = True + _is_stateful = True + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config: ReneConfig, + initializer_cfg=None, + device=None, + dtype=None, + ) -> None: + super().__init__(config) + d_model = config.d_model + n_layer = config.n_layer + vocab_size = config.vocab_size + ssm_cfg = config.ssm_cfg + attn_layer_idx = config.attn_layer_idx + attn_cfg = config.attn_cfg + mlp_layer_idx = config.mlp_layer_idx + mlp_cfg = config.mlp_cfg + rms_norm = config.rms_norm + residual_in_fp32 = config.residual_in_fp32 + pad_vocab_size_multiple = config.pad_vocab_size_multiple + factory_kwargs = {"device": device, "dtype": dtype} + + if set(attn_layer_idx).intersection(mlp_layer_idx): + raise ValueError(f"Conflicting {attn_layer_idx=} and {mlp_layer_idx=}") + + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + + self.backbone = MixerModel( + d_model=d_model, + n_layer=n_layer, + vocab_size=vocab_size, + ssm_cfg=ssm_cfg, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, + mlp_layer_idx=mlp_layer_idx, + mlp_cfg=mlp_cfg, + rms_norm=rms_norm, + initializer_cfg=initializer_cfg, + residual_in_fp32=residual_in_fp32, + **factory_kwargs, + ) + self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) + + # Initialize weights + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + self.tie_weights() + + def tie_weights(self): + """Tie embeddings and softmax layer weights if specified by config.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.backbone.embedding.weight + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + """Allocate inference cache.""" + return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward( + self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs + ): + """ + "position_ids" is just to be compatible with Transformer generation. We don't use it. + num_last_tokens: if > 0, only return the logits for the last n tokens. + """ + hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs) + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] + lm_logits = self.lm_head(hidden_states) + + return CausalLMOutput(logits=lm_logits) + + def generate(self, *args, **kwargs): + """ + Calls the custom `generate` method from `mamba_ssm.utils.generation.GenerationMixin`. + Refer to that method for argument names and defaults. 
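+
+        Example (illustrative; `input_ids` is a LongTensor of token ids on the model device):
+            out = model.generate(input_ids, max_length=50, top_k=100, top_p=0.99)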
+ """ + return MambaGenerationMixin.generate(self, *args, **kwargs) diff --git a/cartesia-pytorch/cartesia_pytorch/version.py b/cartesia-pytorch/cartesia_pytorch/version.py new file mode 100644 index 0000000..4abfd8a --- /dev/null +++ b/cartesia-pytorch/cartesia_pytorch/version.py @@ -0,0 +1 @@ +__version__ = "0.0.1rc1" diff --git a/cartesia-pytorch/evals/cartesia_lm_eval.py b/cartesia-pytorch/evals/cartesia_lm_eval.py new file mode 100644 index 0000000..1a18b87 --- /dev/null +++ b/cartesia-pytorch/evals/cartesia_lm_eval.py @@ -0,0 +1,72 @@ +""" +This is a wrapper for the lm-evaluation-harness that enables evaluating the Rene model. +Other standard models can still be evaluated with this script. +The command-line interface is the same as that of the standard lm-evaluation-harness. +To evaluate a Rene-class model, pass `rene_ssm` to the `--model` argument. +""" + +from typing import Optional, Union + +import lm_eval.models.utils +import torch +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.registry import register_model +from lm_eval.models.huggingface import HFLM + + +@register_model("rene_ssm") +class ReneLMWrapper(HFLM): + """Wrapper for Rene model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + # rene currently only supports causal models + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + tokenizer=kwargs.pop("tokenizer", "allenai/OLMo-1B-hf"), + max_length=kwargs.pop("max_length", 4096), + **kwargs, + ) + + def _get_config(self, pretrained: str, **kwargs) -> None: + from cartesia_pytorch.configuration_rene import ReneConfig + + self._config = ReneConfig.from_pretrained(pretrained) + + def _create_model( + self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs + ) -> None: + from cartesia_pytorch.rene import ReneLMHeadModel + + self._model = ReneLMHeadModel.from_pretrained( + pretrained, + device=self._device, + dtype=torch.float16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + for key in ("do_sample", "attention_mask"): + if key in generation_kwargs: + generation_kwargs.pop(key) + + # The custom GenerationMixin imported from mamba_ssm currently does not support + # passing stopping criteria. + # For the time being, we simply generate to max length, then truncate (equivalent result). 
+ # This should be revisited to speed up generation + # stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, context.shape[0]) + + return self.model.generate( + input_ids=context, + max_length=max_length, + # stopping_criteria=stopping_criteria, + # pad_token_id=self.tokenizer.pad_token_id, + # use_cache=True, + **generation_kwargs, + ) + + +if __name__ == "__main__": + cli_evaluate() diff --git a/cartesia-pytorch/setup.py b/cartesia-pytorch/setup.py new file mode 100644 index 0000000..8433898 --- /dev/null +++ b/cartesia-pytorch/setup.py @@ -0,0 +1,154 @@ +import io +import os +import subprocess + +# Note: To use the 'upload' functionality of this file, you must: +# $ pipenv install twine --dev +import sys +from distutils.util import convert_path +from shutil import rmtree + +from setuptools import Command, find_packages, setup + +PACKAGE_DIR = "cartesia_pytorch" +main_ns = {} +ver_path = convert_path(os.path.join(PACKAGE_DIR, "version.py")) +with open(ver_path) as ver_file: + exec(ver_file.read(), main_ns) + + +# Package metadata. +NAME = "cartesia-pytorch" +DESCRIPTION = "The official Cartesia PyTorch library." +URL = "https://github.com/cartesia-ai/edge" +EMAIL = "support@cartesia.ai" +AUTHOR = "Cartesia, Inc." +REQUIRES_PYTHON = ">=3.9.0" +VERSION = main_ns["__version__"] + +# What packages are required for this module to be executed? +def get_requirements(path): + with open(path, "r") as f: + out = f.read().splitlines() + + out = [line.strip() for line in out] + return out + + + +# The rest you shouldn't have to touch too much :) +# ------------------------------------------------ +# Except, perhaps the License and Trove Classifiers! +# If you do change the License, remember to change the Trove Classifier for that! + +here = os.path.abspath(os.path.dirname(__file__)) + +# Import the README and use it as the long-description. +# Note: this will only work if 'README.md' is present in your MANIFEST.in file! +try: + with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: + long_description = "\n" + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + +# Load the package's __version__.py module as a dictionary. +about = {} +if not VERSION: + project_slug = NAME.lower().replace("-", "_").replace(" ", "_") + with open(os.path.join(here, project_slug, "__version__.py")) as f: + exec(f.read(), about) +else: + about["__version__"] = VERSION + + +# handle dependencies manually to circumvent potential package conflicts +dependency_commands = [ + "pip install --upgrade torch==2.4.0", + "pip install causal-conv1d==1.4.0", + "pip install lm-eval==0.4.3", + "pip install --upgrade mamba-ssm==2.2.2", + "pip install --upgrade flash-attn==2.6.3", + "pip install --upgrade torchaudio==2.4.0", + "pip install --upgrade torchvision==0.19.0", + "pip uninstall -y transformer_engine", +] +for dependency_command in dependency_commands: + subprocess.check_call(f"{sys.executable} -m {dependency_command}", shell=True) + + +class UploadCommand(Command): + """Support setup.py upload.""" + + description = "Build and publish the package." 
+ user_options = [("skip-upload", "u", "skip git tagging and pypi upload")] + boolean_options = ["skip-upload"] + + @staticmethod + def status(s): + """Prints things in bold.""" + print("\033[1m{0}\033[0m".format(s)) + + def initialize_options(self): + self.skip_upload = False + + def finalize_options(self): + self.skip_upload = bool(self.skip_upload) + + def run(self): + try: + self.status("Removing previous builds…") + rmtree(os.path.join(here, "dist")) + rmtree(os.path.join(here, "build")) + except OSError: + pass + + self.status("Building Source and Wheel (universal) distribution…") + os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) + + if self.skip_upload: + self.status("Skipping git tagging and pypi upload") + sys.exit() + + self.status("Uploading the package to PyPI via Twine…") + os.system("twine upload dist/*") + + self.status("Pushing git tags…") + os.system("git tag v{0}".format(about["__version__"])) + os.system("git push --tags") + + sys.exit() + + + +# Where the magic happens: +setup( + name=NAME, + version=about["__version__"], + description=DESCRIPTION, + long_description=long_description, + long_description_content_type="text/markdown", + author=AUTHOR, + author_email=EMAIL, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=find_packages( + exclude=[ + "scratch", + "tests", + "*.tests", + "*.tests.*", + "tests.*", + ] + ), + include_package_data=True, + zip_safe=False, + classifiers=[ + # Trove classifiers + # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + # $ setup.py publish support. + cmdclass={"upload": UploadCommand}, +) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..fa61416 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,63 @@ +[tool.ruff] +# Add more rule codes as needed +lint.extend-select = [ + "D", # pydocstyle - to replace docformatter +] + +# Ignore specific rules +lint.ignore = [ + "E731", # Do not assign a lambda expression, use a def + "E203", # Whitespace before ':' + "E741", # Ambiguous variable name: `l` + "W605", # Invalid escape sequence + # "F401", # Ignore unused imports in __init__.py + + # Docstring Rules to Ignore + "D100", # Missing docstring in public module + # "D101", # Missing docstring in public class + # "D102", # Missing docstring in public method + # "D103", # Missing docstring in public function + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic method + # "D106", # Missing docstring in public nested class + "D107", # Missing docstring in __init__ + "D205", # 1 Blank line required between summary line and description + "D212", # Multi-line docstring summary should start at the first line (to allow google style docstrings) + "D405", # Capitalize section name + "D417", # Missing argument description +] + +# Exclude specific files and patterns +exclude = [ + "setup.py", + "^.*https?://.*$", # Long URLs in comments + "^.*figure.*$", +] + +# Set the maximum line length +line-length = 100 + +# Enable the count of violations +output-format = "full" + +# [tool.ruff.lint.isort] +# force-wrap-aliases = true +# combine-as-imports = true +# force-sort-within-sections = true +# known-first-party = ["cartesia_mlx", "cartesia_metal"] +# known-third-party = [] +# known-local-folder = [] +# lines-after-imports = 2 + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.isort] 
+profile = "black" +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true +line_length = 100 +known_first_party = ["cartesia_metal", "cartesia_mlx", "cartesia_pytorch"] diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..aa5c7b1 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,5 @@ +pytype==2024.4.11 +isort==5.13.2 +ruff==0.4.10 +pytest +pytest-cov