Skip to content

Commit

Permalink
combine cpu and gpu wheel
Browse files Browse the repository at this point in the history
  • Loading branch information
jq authored and rhdong committed Sep 2, 2024
1 parent 0dd0fd8 commit 9384a0a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 18 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ jobs:
py-version: '3.8'
- os: 'Linux'
cpu: 'arm64'
- os: 'Linux'
tf-need-cuda: '0'
- py-version: '3.7'
cpu: 'arm64'
- py-version: '3.8'
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ run the following:
```
pip install tensorflow-recommenders-addons
```

By default, CPU version will be installed. To install GPU version, run the following:
Before version 0.8, to install GPU version, run the following:
```
pip install tensorflow-recommenders-addons-gpu
```
Expand Down
4 changes: 0 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,9 @@ def get_project_name_version():

project_name = "tensorflow-recommenders-addons"
version["tf_project_name"] = "tensorflow"
if os.getenv("TF_NEED_CUDA", "0") == "1":
project_name = project_name + "-gpu"

if "--nightly" in sys.argv:
project_name = "tfra-nightly"
if os.getenv("TF_NEED_CUDA", "0") == "1":
project_name = project_name + "-gpu"
version["__version__"] += get_last_commit_time()
sys.argv.remove("--nightly")

Expand Down
21 changes: 9 additions & 12 deletions tensorflow_recommenders_addons/utils/resource_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,15 @@
def get_required_tf_version():
try:
pkg = pkg_resources.get_distribution("tensorflow-recommenders-addons")
except:
try:
pkg = pkg_resources.get_distribution("tensorflow-recommenders-addons-gpu")
except pkg_resources.DistributionNotFound:
# Force return for 'Test with bazel' on CI.
warnings.warn(
"Fail to get TFRA package information, if you are running on "
"bazel test mode, please ignore this warning, \nor you should check "
"TFRA installation.",
UserWarning,
)
return tf.__version__, tf.__version__
except pkg_resources.DistributionNotFound:
# Force return for 'Test with bazel' on CI.
warnings.warn(
"Fail to get TFRA package information, if you are running on "
"bazel test mode, please ignore this warning, \nor you should check "
"TFRA installation.",
UserWarning,
)
return tf.__version__, tf.__version__

pkg_info = pkg.requires()
low_version, high_version = None, None
Expand Down

0 comments on commit 9384a0a

Please sign in to comment.