diff --git a/framework_models/ml_function_classifier/CellClassifier.py b/framework_models/ml_function_classifier/CellClassifier.py
new file mode 100644
index 00000000..280894c8
--- /dev/null
+++ b/framework_models/ml_function_classifier/CellClassifier.py
@@ -0,0 +1,61 @@
+"""Classify notebook code cells into machine-learning workflow steps."""
+import os
+import pickle
+
+try:
+    from .download_from_ghub import download_file_from_github_release
+except ImportError:
+    # Fallback when this file is run directly instead of imported as part of the package.
+    from download_from_ghub import download_file_from_github_release
+
+
+class CellClassifier:
+    """Wrap the pickled cell classifier and map its output to step names.
+
+    Instantiating the class creates ``cell_classifier/`` in the current
+    working directory if needed, downloads ``code_classifier.pkl`` from the
+    GitHub release when it is not already there, and unpickles it into
+    ``self.classifier``.
+    """
+
+    def __init__(self):
+        # create the model folder if needed; exist_ok avoids the
+        # check-then-create race of os.path.exists() + os.mkdir()
+        os.makedirs('cell_classifier', exist_ok=True)
+
+        # if the file does not exist then download the model from a github release
+        if not os.path.exists('cell_classifier/code_classifier.pkl'):
+            download_file_from_github_release()
+
+        # SECURITY: pickle.load can execute arbitrary code stored in the file,
+        # and this file comes from a network download. Only use trusted releases.
+        with open('cell_classifier/code_classifier.pkl', 'rb') as f:
+            self.classifier = pickle.load(f)
+
+        # classes of the model; predict_workflow_step indexes this list with
+        # the label returned by the classifier
+        self.classes = ['helper_functions','load_data','data_preprocessing',
+                        'data_exploration','modelling','evaluation',
+                        'prediction','result_visualization','save_results',
+                        'comment_only']
+
+    def predict_workflow_step(self, codestring):
+        """Return the predicted workflow step of a code cell, e.g. 'load_data'.
+
+        Args:
+            codestring: The source code of the cell.
+
+        Returns:
+            str: The entry of ``self.classes`` selected by the classifier.
+        """
+        # NOTE(review): codestring is handed to predict() unchanged. If the
+        # pickled model expects a batch of samples (as scikit-learn estimators
+        # do), this may need to be predict([codestring]) -- confirm.
+        prediction = self.classifier.predict(codestring)
+        try:
+            # array-like result: use the label of the first sample, because a
+            # list cannot be indexed with an array
+            prediction = prediction[0]
+        except (TypeError, IndexError):
+            pass  # already a scalar label
+        return self.classes[int(prediction)]
diff --git a/framework_models/ml_function_classifier/download_from_ghub.py b/framework_models/ml_function_classifier/download_from_ghub.py
new file mode 100644
index 00000000..b74c5f13
--- /dev/null
+++ b/framework_models/ml_function_classifier/download_from_ghub.py
@@ -0,0 +1,65 @@
+import os
+
+import requests
+
+
+def download_file_from_github_release(repo_owner="suvanshchawla",
+                                      repo_name="cell-classifer",
+                                      release_tag="rel-cell-classifier-ver-0.0.1",
+                                      asset_name="code_classifier.pkl",
+                                      download_path="./cell_classifier/",
+                                      timeout=30):
+    """Download a single asset from a GitHub release.
+
+    Args:
+        repo_owner (str): The owner of the GitHub repository.
+        repo_name (str): The name of the GitHub repository.
+        release_tag (str): The tag name of the release (e.g., 'v1.0.0').
+        asset_name (str): The name of the asset file to download.
+        download_path (str): The local directory the file is saved into;
+            it is created if it does not exist.
+        timeout (float): Seconds to wait on each HTTP request before
+            giving up, so a stalled connection cannot hang the caller.
+
+    Returns:
+        str: The path of the downloaded file.
+
+    Raises:
+        requests.HTTPError: If GitHub answers with an error status.
+        ValueError: If the release has no asset named ``asset_name``.
+    """
+    # API endpoint to get release info
+    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}"
+
+    response = requests.get(url, timeout=timeout)
+    response.raise_for_status()  # Raise an exception for bad response status codes
+    assets = response.json()['assets']
+
+    # Find the download URL of the asset
+    asset_url = next(
+        (asset['browser_download_url'] for asset in assets if asset['name'] == asset_name),
+        None,
+    )
+    if not asset_url:
+        available = ", ".join(asset['name'] for asset in assets)
+        raise ValueError(
+            f"Asset '{asset_name}' not found in the release. Available assets: {available}"
+        )
+
+    if download_path:
+        os.makedirs(download_path, exist_ok=True)
+    file_path = os.path.join(download_path, asset_name)
+
+    # Download under a temporary name and rename at the end, so an interrupted
+    # transfer never leaves a truncated file under the final name.
+    partial_path = file_path + ".part"
+    with requests.get(asset_url, stream=True, timeout=timeout) as response:
+        response.raise_for_status()
+        with open(partial_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                if chunk:
+                    f.write(chunk)
+    os.replace(partial_path, file_path)
+
+    print(f"File downloaded to: {file_path}")
+    return file_path
diff --git a/headergen/headergen.py b/headergen/headergen.py
index 2e4eb699..34fd8a16 100644
--- a/headergen/headergen.py
+++ b/headergen/headergen.py
@@ -21,6 +21,7 @@
 from framework_models import PHASES as PIPELINE_PHASES
 from framework_models import get_high_level_phase, lookup_pipeline_tag
 from headergen.node_visitor import HeaderGenVisitor
+from framework_models.ml_function_classifier.CellClassifier import CellClassifier
 
 # %%
 
@@ -664,13 +665,24 @@ def get_analysis_output(nb_path, out_path="."):
     return analysis_info
 
 
-def start_headergen(nb_path, out_path=".", debug_mode=False, create_linted_file=True):
+def start_headergen(nb_path, out_path=".", debug_mode=False, create_linted_file=True, start_cellclassifer=False):
     # Read ipynb and convert to python for analysis
+
+    if start_cellclassifer:
+        # Constructing the classifier creates ./cell_classifier/ if needed,
+        # downloads code_classifier.pkl on first use, and loads the model.
+        # cell_classifier.predict_workflow_step(cell_code) returns one of the
+        # workflow-step names listed in
+        # framework_models/ml_function_classifier/CellClassifier.py.
+        # TODO: the predictions are not used yet; wire them into the
+        # cell classification below.
+        cell_classifier = CellClassifier()
+
     file_name = Path(nb_path).stem
     if Path(nb_path).suffix == ".py":
         py_ntbk_path = nb_path
         py_ntbk = open(py_ntbk_path).read()
     elif Path(nb_path).suffix == ".ipynb":
         ntbk = jupytext.read(nb_path)
         py_ntbk = jupytext.writes(ntbk, fmt="py:percent")
         py_ntbk_path = "{}/{}-hg-analysis.py".format(Path(nb_path).parent, file_name)