Skip to content

Commit

Permalink
_actually_ fix bug in prediction this time
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonBurns committed Feb 8, 2024
1 parent 0a90dd4 commit 6c75b56
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 2 additions & 3 deletions fastprop/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def predict_fastprop(checkpoints_dir, smiles, input_file, output=None):
output (str or None): Either save to a file or just print result.
"""
if input_file:
raise NotImplementedError("TODO")
raise NotImplementedError("TODO: please pass as command line options, loading from file is a WIP")
if type(smiles) is str:
smiles = [smiles]
checkpoint_dir_contents = os.listdir(checkpoints_dir)
Expand All @@ -37,7 +37,7 @@ def predict_fastprop(checkpoints_dir, smiles, input_file, output=None):
with open(os.path.join(checkpoints_dir, "fastprop_config.yml")) as file:
config_dict = yaml.safe_load(file)
except FileNotFoundError:
logger.error("checkpoints directory is missing 'preprocess_config.yml'. Re-execute training.")
logger.error("checkpoints directory is missing 'fastprop_config.yml'. Re-execute training.")

descs = calculate_mordred_desciptors(
mordred_descriptors_from_strings(config_dict["descriptors"]),
Expand All @@ -46,7 +46,6 @@ def predict_fastprop(checkpoints_dir, smiles, input_file, output=None):
strategy="low-memory",
)
descs = pd.DataFrame(data=descs, columns=config_dict["descriptors"])
descs = descs.dropna(axis=1, how="all")

for pickled_scaler in config_dict["feature_scalers"]:
scaler = pickle.loads(pickled_scaler)
Expand Down
5 changes: 4 additions & 1 deletion fastprop/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ def preprocess(
target_scaler = OneHotEncoder(sparse_output=False)
y = target_scaler.fit_transform(targets)

# drop missing features
# drop missing features - if this is direct output from mordred, missing values are
# strings describing why they are missing. If this is output from fastprop.utils.load_saved_desc,
# the missing descriptors are nan. To deal with the former case, force all str->nan
descriptors: pd.DataFrame
descriptors = descriptors.apply(pd.to_numeric, errors="coerce")
descriptors = descriptors.dropna(axis=1, how="all")

if zero_variance_drop:
Expand Down

0 comments on commit 6c75b56

Please sign in to comment.