From 9eb676ab4614a8760188001128022ca2e797afb8 Mon Sep 17 00:00:00 2001 From: Hongbo <12580159+ya0guang@users.noreply.github.com> Date: Sun, 31 Mar 2024 20:25:22 -0400 Subject: [PATCH 1/2] Seems like the `tmpLbl` variable is misused as --- backgroundremover/u2net/data_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backgroundremover/u2net/data_loader.py b/backgroundremover/u2net/data_loader.py index 92991f5..2e3b29a 100644 --- a/backgroundremover/u2net/data_loader.py +++ b/backgroundremover/u2net/data_loader.py @@ -139,7 +139,7 @@ def __call__(self, sample): # change the r,g,b to b,r,g from [0,255] to [0,1] # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) tmpImg = tmpImg.transpose((2, 0, 1)) - tmpLbl = label.transpose((2, 0, 1)) + tmpLbl = tmpLbl.transpose((2, 0, 1)) return { "imidx": torch.from_numpy(imidx), From 3e9804b8ed92576da08ceae19bd476012e99b1a8 Mon Sep 17 00:00:00 2001 From: Ahmad Alobaid Date: Sun, 7 Apr 2024 16:39:13 +0300 Subject: [PATCH 2/2] background remover as a library --- .gitignore | 1 + README.md | 21 +++++++++++++++++ backgroundremover/bg.py | 8 +++---- backgroundremover/github.py | 38 +++++++++++++++++++++++++++++++ backgroundremover/u2net/detect.py | 12 ++++++---- backgroundremover/utilities.py | 35 ---------------------------- 6 files changed, 72 insertions(+), 43 deletions(-) create mode 100644 backgroundremover/github.py diff --git a/.gitignore b/.gitignore index 2a096b1..b788388 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea/ # Created by https://www.toptal.com/developers/gitignore/api/python # Edit at https://www.toptal.com/developers/gitignore?templates=python diff --git a/README.md b/README.md index 118a712..e4bc657 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,27 @@ change the model for different background removal methods between `u2netp`, `u2n backgroundremover -i "/path/to/video.mp4" -m "u2net_human_seg" -fl 150 -tv -o "output.mov" ``` +## As a library +### Remove background image + +``` +from backgroundremover.bg import remove +def remove_bg(src_img_path, out_img_path): + model_choices = ["u2net", "u2net_human_seg", "u2netp"] + f = open(src_img_path, "rb") + data = f.read() + img = remove(data, model_name=model_choices[0], + alpha_matting=True, + alpha_matting_foreground_threshold=240, + alpha_matting_background_threshold=10, + alpha_matting_erode_structure_size=10, + alpha_matting_base_size=1000) + f.close() + f = open(out_img_path, "wb") + f.write(img) + f.close() +``` + ## Todo - convert logic from video to image to utilize more GPU on image removal diff --git a/backgroundremover/bg.py b/backgroundremover/bg.py index 88be680..8abd700 100644 --- a/backgroundremover/bg.py +++ b/backgroundremover/bg.py @@ -13,7 +13,7 @@ import torch.nn.functional from hsh.library.hash import Hasher from .u2net import detect, u2net -from . import utilities +from . import github # closes https://github.com/nadermx/backgroundremover/issues/18 # closes https://github.com/nadermx/backgroundremover/issues/112 @@ -56,7 +56,7 @@ def __init__(self, model_name): if ( not os.path.exists(path) ): - utilities.download_files_from_github( + github.download_files_from_github( path, model_name ) @@ -70,7 +70,7 @@ def __init__(self, model_name): not os.path.exists(path) #or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a" ): - utilities.download_files_from_github( + github.download_files_from_github( path, model_name ) @@ -84,7 +84,7 @@ def __init__(self, model_name): not os.path.exists(path) #or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55" ): - utilities.download_files_from_github( + github.download_files_from_github( path, model_name ) else: diff --git a/backgroundremover/github.py b/backgroundremover/github.py new file mode 100644 index 0000000..738921f --- /dev/null +++ b/backgroundremover/github.py @@ -0,0 +1,38 @@ +import os +import requests + + +def download_files_from_github(path, model_name): + if model_name not in ["u2net", "u2net_human_seg", "u2netp"]: + print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'") + return + print(f"downloading model [{model_name}] to {path} ...") + urls = [] + if model_name == "u2net": + urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa', + 'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab', + 'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac', + 'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad'] + elif model_name == "u2net_human_seg": + urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa', + 'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab', + 'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac', + 'https://github.com/nadermx/backgroundremover/raw/main/models/u2had'] + elif model_name == 'u2netp': + urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth'] + try: + os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True) + except Exception as e: + print(f"Error creating directory: {e}") + return + + try: + + with open(path, 'wb') as out_file: + for i, url in enumerate(urls): + print(f'downloading part {i+1} of {model_name}') + part_content = requests.get(url) + out_file.write(part_content.content) + print(f'finished downloading part {i+1} of {model_name}') + except Exception as e: + print(e) diff --git a/backgroundremover/u2net/detect.py b/backgroundremover/u2net/detect.py index d2426e6..cda5184 100644 --- a/backgroundremover/u2net/detect.py +++ b/backgroundremover/u2net/detect.py @@ -8,7 +8,8 @@ from torchvision import transforms from . import data_loader, u2net -from .. import utilities +from .. import github + def load_model(model_name: str = "u2net"): hasher = Hasher() @@ -38,7 +39,7 @@ def load_model(model_name: str = "u2net"): not os.path.exists(path) #or hasher.md5(path) != "e4f636406ca4e2af789941e7f139ee2e" ): - utilities.download_files_from_github( + github.download_files_from_github( path, model_name ) @@ -48,11 +49,14 @@ def load_model(model_name: str = "u2net"): "U2NET_PATH", os.path.expanduser(os.path.join("~", ".u2net", model_name + ".pth")), ) + + print(f"DEBUG: path to be checked: {path}") + if ( not os.path.exists(path) #or hasher.md5(path) != "09fb4e49b7f785c9f855baf94916840a" ): - utilities.download_files_from_github( + github.download_files_from_github( path, model_name ) @@ -66,7 +70,7 @@ def load_model(model_name: str = "u2net"): not os.path.exists(path) #or hasher.md5(path) != "347c3d51b01528e5c6c071e3cff1cb55" ): - utilities.download_files_from_github( + github.download_files_from_github( path, model_name ) diff --git a/backgroundremover/utilities.py b/backgroundremover/utilities.py index 7fba897..52be205 100644 --- a/backgroundremover/utilities.py +++ b/backgroundremover/utilities.py @@ -328,38 +328,3 @@ def transparentvideooverimage(output, overlay, file_path, except PermissionError: pass return - -def download_files_from_github(path, model_name): - if model_name not in ["u2net", "u2net_human_seg", "u2netp"]: - print("Invalid model name, please use 'u2net' or 'u2net_human_seg' or 'u2netp'") - return - print(f"downloading model [{model_name}] to {path} ...") - urls = [] - if model_name == "u2net": - urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2aa', - 'https://github.com/nadermx/backgroundremover/raw/main/models/u2ab', - 'https://github.com/nadermx/backgroundremover/raw/main/models/u2ac', - 'https://github.com/nadermx/backgroundremover/raw/main/models/u2ad'] - elif model_name == "u2net_human_seg": - urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2haa', - 'https://github.com/nadermx/backgroundremover/raw/main/models/u2hab', - 'https://github.com/nadermx/backgroundremover/raw/main/models/u2hac', - 'https://github.com/nadermx/backgroundremover/raw/main/models/u2had'] - elif model_name == 'u2netp': - urls = ['https://github.com/nadermx/backgroundremover/raw/main/models/u2netp.pth'] - try: - os.makedirs(os.path.expanduser("~/.u2net"), exist_ok=True) - except Exception as e: - print(f"Error creating directory: {e}") - return - - try: - - with open(path, 'wb') as out_file: - for i, url in enumerate(urls): - print(f'downloading part {i+1} of {model_name}') - part_content = requests.get(url) - out_file.write(part_content.content) - print(f'finished downloading part {i+1} of {model_name}') - except Exception as e: - print(e)