diff --git a/src/foraging_gui/Dialogs.py b/src/foraging_gui/Dialogs.py
index f785937e0..6ed55fdee 100644
--- a/src/foraging_gui/Dialogs.py
+++ b/src/foraging_gui/Dialogs.py
@@ -15,7 +15,7 @@
 from PyQt5.QtWidgets import QApplication, QDialog, QVBoxLayout, QHBoxLayout, QMessageBox, QGridLayout
 from PyQt5.QtWidgets import QLabel, QDialogButtonBox,QFileDialog,QInputDialog, QLineEdit
 from PyQt5 import QtWidgets, uic, QtGui
-from PyQt5.QtCore import QThreadPool,Qt, QAbstractTableModel, QItemSelectionModel, QObject, QTimer
+from PyQt5.QtCore import QThreadPool,Qt, QAbstractTableModel, QItemSelectionModel, QObject, QTimer, pyqtSignal
 from PyQt5.QtSvg import QSvgWidget
 
 from foraging_gui.MyFunctions import Worker
@@ -2215,7 +2215,7 @@ def _SelectRigMetadata(self,rig_metadata_file=None):
 
 class AutoTrainDialog(QDialog):
     '''For automatic training'''
-
+    trainingStageChanged = pyqtSignal(str)  # signal to indicate training stage has changed
     def __init__(self, MainWindow, parent=None):
         super().__init__(parent)
         uic.loadUi('AutoTrain.ui', self)
@@ -2651,12 +2651,12 @@ def _override_curriculum_clicked(self, state):
 
     def _update_stage_to_apply(self):
         if self.checkBox_override_stage.isChecked():
             self.stage_in_use = self.comboBox_override_stage.currentText()
+            logger.info(f"Stage overridden to: {self.stage_in_use}")
         elif self.last_session is not None:
             self.stage_in_use = self.last_session['next_stage_suggested']
         else:
             self.stage_in_use = 'unknown training stage'
-
         self.pushButton_apply_auto_train_paras.setText(
             f"Apply and lock\n"
             + '\n'.join(get_curriculum_string(self.curriculum_in_use).split('(')).strip(')')
@@ -2665,7 +2665,7 @@ def _update_stage_to_apply(self):
 
         logger.info(f"Current stage to apply: {self.stage_in_use} @"
                     f"{get_curriculum_string(self.curriculum_in_use)}")
-
+        self.trainingStageChanged.emit(self.stage_in_use)
     def _apply_curriculum(self):
         # Check if a curriculum is selected
         if not hasattr(self, 'selected_curriculum') or self.selected_curriculum is None:
@@ -2885,10 +2885,10 @@ def _set_training_parameters(self, paras_dict, if_apply_and_lock=False):
 
         # Set warmup to off first so that all AutoTrain parameters
         # can be correctly registered in WarmupBackup if warmup is turned on later
-        if paras_dict and paras_dict['warmup'] != self.MainWindow.warmup.currentText():
+        if paras_dict and paras_dict['warmup'] != self.MainWindow.behavior_task_logic_model.task_parameters.warmup:
             widgets_changed.update(
                 {self.MainWindow.warmup:
-                     self.MainWindow.warmup.currentText()
+                     self.MainWindow.behavior_task_logic_model.task_parameters.warmup
                 }
             )  # Track the changes
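The Dialogs.py hunks above rely on PyQt's class-attribute signal pattern: `trainingStageChanged` is declared on the class, emitted at the end of `_update_stage_to_apply`, and connected from the launcher (see the Foraging.py hunk near the end of this diff). A minimal, self-contained sketch of that pattern, with illustrative names (`DemoDialog`, `on_stage_changed`) that are not part of this PR:

```python
from PyQt5.QtCore import QObject, pyqtSignal


class DemoDialog(QObject):
    # signals must be declared as class attributes, not created in __init__
    trainingStageChanged = pyqtSignal(str)

    def set_stage(self, stage: str) -> None:
        # emitting calls every connected slot with the new stage string
        self.trainingStageChanged.emit(stage)


def on_stage_changed(stage: str) -> None:
    print(f"training stage changed to {stage}")


dialog = DemoDialog()
dialog.trainingStageChanged.connect(on_stage_changed)
dialog.set_stage("STAGE_1")  # prints: training stage changed to STAGE_1
```

Because the connection is made in the launcher rather than inside `AutoTrainDialog`, the dialog stays unaware of the task logic model; it only announces the stage change.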
diff --git a/src/foraging_gui/Foraging.py b/src/foraging_gui/Foraging.py
index e30abd628..6d749e4d1 100644
--- a/src/foraging_gui/Foraging.py
+++ b/src/foraging_gui/Foraging.py
@@ -28,10 +28,11 @@
 from pykeepass import PyKeePass
 from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
 from scipy.io import savemat, loadmat
-from PyQt5.QtWidgets import QApplication, QMainWindow, QMessageBox, QSizePolicy
+from PyQt5.QtWidgets import QApplication, QMainWindow, QMessageBox, QSizePolicy, QLineEdit, QComboBox, QPushButton, QDoubleSpinBox
 from PyQt5.QtWidgets import QFileDialog,QVBoxLayout, QGridLayout, QLabel
 from PyQt5 import QtWidgets,QtGui,QtCore, uic
 from PyQt5.QtCore import QThreadPool,Qt,QThread
+from PyQt5.QtGui import QIntValidator, QDoubleValidator
 from pyOSC3.OSC3 import OSCStreamingClient
 import webbrowser
 from pydantic import ValidationError
@@ -54,7 +55,9 @@
 from aind_data_schema.core.session import Session
 from aind_data_schema_models.modalities import Modality
 from aind_behavior_services.session import AindBehaviorSessionModel
-from aind_auto_train.schema.task import TrainingStage
+from aind_behavior_services.task_logic import AindBehaviorTaskLogicModel
+from aind_auto_train.schema.task import TrainingStage, DynamicForagingParas, AdvancedBlockMode
+import aind_auto_train
 
 logger = logging.getLogger(__name__)
 logger.root.handlers.clear()  # clear handlers so console output can be configured
@@ -128,11 +131,17 @@ def __init__(self, parent=None,box_number=1,start_bonsai_ide=True):
             subject=self.ID.text(),
             experiment_version=foraging_gui.__version__,
             notes=self.ShowNotes.toPlainText(),
-            commit_hash= subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip(),
+            commit_hash=subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip(),
             allow_dirty_repo= subprocess.check_output(['git','diff-index','--name-only', 'HEAD']).decode('ascii').strip() != '',
             skip_hardware_validation=True
         )
+        # create AindBehaviorTaskLogicModel to be used and referenced for task parameter info
+        self.behavior_task_logic_model = AindBehaviorTaskLogicModel(
+            name=self.Task.currentText(),
+            task_parameters=self.initialize_task_parameters().dict(),
+            version=aind_auto_train.__version__
+        )
 
         # add warning_widget to layout and set color
         self.warning_widget = WarningWidget(log_tag=self.warning_log_tag,
@@ -213,6 +222,8 @@ def __init__(self, parent=None,box_number=1,start_bonsai_ide=True):
         self._ShowRewardPairs()  # show reward pairs
         self._GetTrainingParameters()  # get initial training parameters
         self.connectSignalsSlots()
+        self.connect_session_model()  # connect relevant widgets to update session model
+        self.connect_task_parameters()  # connect relevant widgets to update task parameters
         self._Task()
         self.keyPressEvent()
         self._WaterVolumnManage2()
@@ -243,6 +254,233 @@ def __init__(self, parent=None,box_number=1,start_bonsai_ide=True):
         self._ReconnectBonsai()
         logging.info('Start up complete')
 
+    def initialize_task_parameters(self) -> DynamicForagingParas:
+        """
+        initialize schema of task parameters based on widgets
+        """
+
+        return DynamicForagingParas(
+            training_stage=TrainingStage.STAGE_1,  # dummy value
+            task=self.Task.currentText(),
+            task_schema_version=aind_auto_train.__version__,
+            BaseRewardSum=float(self.BaseRewardSum.text()),
+            RewardFamily=int(self.RewardFamily.text()),
+            RewardPairsN=int(self.RewardPairsN.text()),
+            UncoupledReward=self.UncoupledReward.text(),
+            # Randomness
+            Randomness=self.Randomness.currentText(),
+            # Block length
+            BlockMin=int(self.BlockMin.text()),
+            BlockMax=int(self.BlockMax.text()),
+            BlockBeta=int(self.BlockBeta.text()),
+            BlockMinReward=int(self.BlockMinReward.text()),
+            # Delay period
+            DelayMin=float(self.DelayMin.text()),
+            DelayMax=float(self.DelayMax.text()),
+            DelayBeta=float(self.DelayBeta.text()),
+            # Reward delay
+            RewardDelay=float(self.RewardDelay.text()),
+            # Auto water
+            AutoReward=self.AutoReward.isChecked(),
+            AutoWaterType=self.AutoWaterType.currentText(),
+            Multiplier=float(self.Multiplier.text()),
+            Unrewarded=int(self.Unrewarded.text()),
+            Ignored=int(self.Ignored.text()),
+            # ITI
+            ITIMin=float(self.ITIMin.text()),
+            ITIMax=float(self.ITIMax.text()),
+            ITIBeta=float(self.ITIBeta.text()),
+            ITIIncrease=float(self.ITIIncrease.text()),
+            # Response time
+            ResponseTime=float(self.ResponseTime.text()),
+            RewardConsumeTime=float(self.RewardConsumeTime.text()),
+            StopIgnores=round(int(self.auto_stop_ignore_win.text())*float(self.auto_stop_ignore_ratio_threshold.text())),
+            # Auto block
+            AdvancedBlockAuto=self.AdvancedBlockAuto.currentText(),
+            SwitchThr=float(self.SwitchThr.text()),
+            PointsInARow=int(self.PointsInARow.text()),
+            # Auto stop
+            MaxTrial=int(self.MaxTrial.text()),
+            MaxTime=int(self.MaxTime.text()),
+            # Reward size
+            RightValue_volume=float(self.RightValue_volume.text()),
+            LeftValue_volume=float(self.LeftValue_volume.text()),
+            # Warmup
+            warmup=self.warmup.currentText(),
+            warm_min_trial=int(self.warm_min_trial.text()),
+            warm_max_choice_ratio_bias=float(self.warm_max_choice_ratio_bias.text()),
+            warm_min_finish_ratio=float(self.warm_min_finish_ratio.text()),
+            warm_windowsize=int(self.warm_windowsize.text())
+        )
+
+    def connect_task_parameters(self) -> None:
+        """
+        Connect relevant widgets to update task parameters in task logic model and add validators
+        """
+        # update parameters in behavior task logic model
+        # self.AutoTrain_dialog.trainingStageChanged.connect(
+        #     lambda stage: setattr(self.behavior_task_logic_model.task_parameters, 'training_stage', stage))
+        self.Task.currentTextChanged.connect(
+            lambda task: setattr(self.behavior_task_logic_model.task_parameters, 'task', task))
+        self.BaseRewardSum.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'BaseRewardSum', float(text)))
+        self.BaseRewardSum.setValidator(QDoubleValidator())
+
+        self.RewardFamily.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'RewardFamily', int(text)))
+        self.RewardFamily.setValidator(QIntValidator())
+
+        self.RewardPairsN.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'RewardPairsN', int(text)))
+        self.RewardPairsN.setValidator(QIntValidator())
+
+        self.UncoupledReward.textChanged.connect(
+            lambda text: setattr(self.behavior_task_logic_model.task_parameters, 'UncoupledReward', text))
+
+        self.Randomness.currentTextChanged.connect(
+            lambda text: setattr(self.behavior_task_logic_model.task_parameters, 'Randomness', text))
+
+        self.BlockMin.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'BlockMin', int(text)))
+        self.BlockMin.setValidator(QIntValidator())
+
+        self.BlockMax.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'BlockMax', int(text)))
+        self.BlockMax.setValidator(QIntValidator())
+
+        self.BlockBeta.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'BlockBeta', int(text)))
+        self.BlockBeta.setValidator(QIntValidator())
+
+        self.BlockMinReward.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'BlockMinReward', int(text)))
+        self.BlockMinReward.setValidator(QIntValidator())
+
+        self.DelayMin.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'DelayMin', float(text)))
+        self.DelayMin.setValidator(QDoubleValidator())
+
+        self.DelayMax.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'DelayMax', float(text)))
+        self.DelayMax.setValidator(QDoubleValidator())
+
+        self.DelayBeta.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'DelayBeta', float(text)))
+        self.DelayBeta.setValidator(QDoubleValidator())
+
+        self.RewardDelay.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'RewardDelay', float(text)))
+        self.RewardDelay.setValidator(QDoubleValidator())
+
+        self.AutoReward.toggled.connect(
+            lambda checked: setattr(self.behavior_task_logic_model.task_parameters, 'AutoReward', checked))
+
+        self.AutoWaterType.currentTextChanged.connect(
+            lambda water: setattr(self.behavior_task_logic_model.task_parameters, 'AutoWaterType', water))
+
+        self.Multiplier.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'Multiplier', float(text)))
+        self.Multiplier.setValidator(QDoubleValidator())
+
+        self.Unrewarded.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'Unrewarded', int(text)))
+        self.Unrewarded.setValidator(QIntValidator())
+
+        self.Ignored.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'Ignored', int(text)))
+        self.Ignored.setValidator(QIntValidator())
+
+        self.ITIMin.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'ITIMin', float(text)))
+        self.ITIMin.setValidator(QDoubleValidator())
+
+        self.ITIMax.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'ITIMax', float(text)))
+        self.ITIMax.setValidator(QDoubleValidator())
+
+        self.ITIBeta.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'ITIBeta', float(text)))
+        self.ITIBeta.setValidator(QDoubleValidator())
+
+        self.ITIIncrease.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'ITIIncrease', float(text)))
+        self.ITIIncrease.setValidator(QDoubleValidator())
+
+        self.ResponseTime.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'ResponseTime', float(text)))
+        self.ResponseTime.setValidator(QDoubleValidator())
+
+        self.RewardConsumeTime.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'RewardConsumeTime', float(text)))
+        self.RewardConsumeTime.setValidator(QDoubleValidator())
+
+        self.auto_stop_ignore_win.textChanged.connect(
+            lambda win: None if win == '' else setattr(self.behavior_task_logic_model.task_parameters, 'StopIgnores',
+                                                       round(int(win) * float(self.auto_stop_ignore_ratio_threshold.text()))))
+        self.auto_stop_ignore_win.setValidator(QIntValidator())
+
+        self.auto_stop_ignore_ratio_threshold.textChanged.connect(
+            lambda ratio: None if ratio in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'StopIgnores',
+                                                                  round(int(self.auto_stop_ignore_win.text()) * float(ratio))))
+        self.auto_stop_ignore_ratio_threshold.setValidator(QDoubleValidator())
+
+        self.AdvancedBlockAuto.currentTextChanged.connect(
+            lambda text: setattr(self.behavior_task_logic_model.task_parameters, 'AdvancedBlockAuto', text))
+
+        self.SwitchThr.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'SwitchThr', float(text)))
+        self.SwitchThr.setValidator(QDoubleValidator())
+
+        self.PointsInARow.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'PointsInARow', int(text)))
+        self.PointsInARow.setValidator(QIntValidator())
+
+        self.MaxTrial.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'MaxTrial', int(text)))
+        self.MaxTrial.setValidator(QIntValidator())
+
+        self.MaxTime.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'MaxTime', int(text)))
+        self.MaxTime.setValidator(QIntValidator())
+
+        self.RightValue_volume.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'RightValue_volume', float(text)))
+
+        self.LeftValue_volume.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'LeftValue_volume', float(text)))
+
+        self.warmup.currentTextChanged.connect(
+            lambda text: setattr(self.behavior_task_logic_model.task_parameters, 'warmup', text))
+
+        self.warm_min_trial.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'warm_min_trial', int(text)))
+        self.warm_min_trial.setValidator(QIntValidator())
+
+        self.warm_max_choice_ratio_bias.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'warm_max_choice_ratio_bias',
+                                                                float(text)))
+        self.warm_max_choice_ratio_bias.setValidator(QDoubleValidator())
+
+        self.warm_min_finish_ratio.textChanged.connect(
+            lambda text: None if text in ['', '.'] else setattr(self.behavior_task_logic_model.task_parameters, 'warm_min_finish_ratio', float(text)))
+        self.warm_min_finish_ratio.setValidator(QDoubleValidator())
+
+        self.warm_windowsize.textChanged.connect(
+            lambda text: None if text == '' else setattr(self.behavior_task_logic_model.task_parameters, 'warm_windowsize', int(text)))
+        self.warm_windowsize.setValidator(QIntValidator())
+
+    def connect_session_model(self) -> None:
+        """
+        Connect relevant widgets to update session model
+        """
+
+        # update parameters in behavior session model if widgets change
+        self.Task.currentTextChanged.connect(lambda task: setattr(self.behavior_session_model, 'experiment', task))
+        self.Experimenter.textChanged.connect(lambda text: setattr(self.behavior_session_model, 'experimenter', [text]))
+        self.ID.textChanged.connect(lambda subject: setattr(self.behavior_session_model, 'subject', subject))
+        self.ShowNotes.textChanged.connect(lambda: setattr(self.behavior_session_model, 'notes',
+                                                           self.ShowNotes.toPlainText()))
+
     def _load_rig_metadata(self):
         '''Load the latest rig metadata'''
@@ -407,14 +645,6 @@ def connectSignalsSlots(self):
         self.Opto_dialog.laser_1_calibration_power.textChanged.connect(self._toggle_save_color)
         self.Opto_dialog.laser_2_calibration_power.textChanged.connect(self._toggle_save_color)
 
-        # update parameters in behavior session model if widgets change
-        self.Task.currentTextChanged.connect(lambda task: setattr(self.behavior_session_model, 'experiment', task))
-        self.Experimenter.textChanged.connect(lambda text: setattr(self.behavior_session_model, 'experimenter', [text]))
-        self.ID.textChanged.connect(lambda subject: setattr(self.behavior_session_model, 'subject', subject))
-        self.ShowNotes.textChanged.connect(lambda: setattr(self.behavior_session_model, 'notes',
-                                                           self.ShowNotes.toPlainText()))
-
-
         # Set manual water volume to earned reward and trigger update if changed
         for side in ['Left', 'Right']:
             reward_volume_widget = getattr(self, f'{side}Value_volume')
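The connect_task_parameters hunk repeats one binding pattern dozens of times: a QLineEdit's textChanged slot writes the parsed value onto the pydantic task_parameters object, a Q*Validator restricts keyboard input, and transient values such as '' or '.' are skipped so float()/int() never raises mid-edit. A stripped-down sketch of that pattern, using a plain pydantic model in place of DynamicForagingParas (the names Params and BaseRewardSum's default are illustrative only):

```python
import sys
from PyQt5.QtWidgets import QApplication, QLineEdit
from PyQt5.QtGui import QDoubleValidator
from pydantic import BaseModel


class Params(BaseModel):
    # stand-in for one field of the real task-parameter schema
    BaseRewardSum: float = 0.8


app = QApplication(sys.argv)
params = Params()
edit = QLineEdit()

# the validator limits keystrokes to numeric input, but '' and '.' can still
# occur while the user is typing, so the slot ignores them
edit.setValidator(QDoubleValidator())
edit.textChanged.connect(
    lambda text: None if text in ['', '.']
    else setattr(params, 'BaseRewardSum', float(text)))

edit.setText('1.5')          # simulate user input; setText emits textChanged
print(params.BaseRewardSum)  # 1.5
```

If validate_assignment were enabled on the model, the same setattr calls would additionally re-run field validation on every keystroke; the diff relies on the later round-trip validation in _Start instead.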
@@ -723,7 +953,7 @@ def _warmup(self):
         to some incorrect parameters when it was turned off.
         '''
         # set warm up parameters
-        if self.warmup.currentText()=='on':
+        if self.behavior_task_logic_model.task_parameters.warmup=='on':
             # get parameters before the warm up is on;WarmupBackup_ stands for Warmup backup, which are parameters before warm-up.
             self._GetTrainingParameters(prefix='WarmupBackup_')
             self.warm_min_trial.setEnabled(True)
@@ -755,7 +985,7 @@ def _warmup(self):
             # turn advanced block auto off
             self.AdvancedBlockAuto.setCurrentIndex(self.AdvancedBlockAuto.findText('off'))
             self._ShowRewardPairs()
-        elif self.warmup.currentText()=='off':
+        elif self.behavior_task_logic_model.task_parameters.warmup=='off':
             # set parameters back to the previous parameters before warm up
             self._revert_to_previous_parameters()
             self.warm_min_trial.setEnabled(False)
@@ -1833,7 +2063,7 @@ def _set_parameters(self,key,widget_dict,parameters):
 
     def _Randomness(self):
         '''enable/disable some fields in the Block/Delay Period/ITI'''
-        if self.Randomness.currentText()=='Exponential':
+        if self.behavior_task_logic_model.task_parameters.Randomness == 'Exponential':
             self.label_14.setEnabled(True)
             self.label_18.setEnabled(True)
             self.label_39.setEnabled(True)
@@ -1842,7 +2072,7 @@ def _Randomness(self):
             self.ITIBeta.setEnabled(True)
             # if self.Task.currentText()!='RewardN':
             #     self.BlockBeta.setStyleSheet("color: black;border: 1px solid gray;background-color: white;")
-        elif self.Randomness.currentText()=='Even':
+        elif self.behavior_task_logic_model.task_parameters.Randomness == 'Even':
             self.label_14.setEnabled(False)
             self.label_18.setEnabled(False)
             self.label_39.setEnabled(False)
@@ -1856,7 +2086,7 @@ def _Randomness(self):
 
     def _AdvancedBlockAuto(self):
         '''enable/disable some fields in the AdvancedBlockAuto'''
-        if self.AdvancedBlockAuto.currentText()=='off':
+        if self.behavior_task_logic_model.task_parameters.AdvancedBlockAuto=='off':
             self.label_54.setEnabled(False)
             self.label_60.setEnabled(False)
             self.SwitchThr.setEnabled(False)
@@ -2032,17 +2262,16 @@ def _CheckFormat(self,child):
         '''Check if the input format is correct'''
         if child.objectName()=='RewardFamily':
             # When we change the RewardFamily, sometimes the RewardPairsN is larger than available reward pairs in this family.
             try:
-                self.RewardFamilies[int(self.RewardFamily.text())-1]
-                if int(self.RewardPairsN.text())>len(self.RewardFamilies[int(self.RewardFamily.text())-1]):
-                    self.RewardPairsN.setText(str(len(self.RewardFamilies[int(self.RewardFamily.text())-1])))
+                if self.behavior_task_logic_model.task_parameters.RewardPairsN>len(self.RewardFamilies[self.behavior_task_logic_model.task_parameters.RewardFamily-1]):
+                    self.RewardPairsN.setText(str(len(self.RewardFamilies[self.behavior_task_logic_model.task_parameters.RewardFamily-1])))
                 return 1
             except Exception as e:
                 logging.error(traceback.format_exc())
                 return 0
         if child.objectName()=='RewardFamily' or child.objectName()=='RewardPairsN' or child.objectName()=='BaseRewardSum':
             try:
-                self.RewardPairs=self.RewardFamilies[int(self.RewardFamily.text())-1][:int(self.RewardPairsN.text())]
-                if int(self.RewardPairsN.text())>len(self.RewardFamilies[int(self.RewardFamily.text())-1]):
+                self.RewardPairs=self.RewardFamilies[self.behavior_task_logic_model.task_parameters.RewardFamily-1][:self.behavior_task_logic_model.task_parameters.RewardPairsN]
+                if self.behavior_task_logic_model.task_parameters.RewardPairsN>len(self.RewardFamilies[self.behavior_task_logic_model.task_parameters.RewardFamily-1]):
                     return 0
                 else:
                     return 1
@@ -2051,7 +2280,7 @@ def _CheckFormat(self,child):
             return 0
         if child.objectName()=='UncoupledReward':
             try:
-                input_string=self.UncoupledReward.text()
+                input_string = self.behavior_task_logic_model.task_parameters.UncoupledReward
                 if input_string=='':  # do not permit empty uncoupled reward
                     return 0
                 # remove any square brackets and spaces from the string
@@ -2229,10 +2458,11 @@ def _ShowRewardPairs(self):
         '''Show reward pairs'''
         try:
             if self.behavior_session_model.experiment in ['Coupled Baiting','Coupled Without Baiting','RewardN']:
-                self.RewardPairs=self.RewardFamilies[int(self.RewardFamily.text())-1][:int(self.RewardPairsN.text())]
-                self.RewardProb=np.array(self.RewardPairs)/np.expand_dims(np.sum(self.RewardPairs,axis=1),axis=1)*float(self.BaseRewardSum.text())
+                self.RewardPairs=self.RewardFamilies[int(self.behavior_task_logic_model.task_parameters.RewardFamily)-1][:int(self.behavior_task_logic_model.task_parameters.RewardPairsN)]
+                self.RewardProb=np.array(self.RewardPairs)/np.expand_dims(np.sum(self.RewardPairs,axis=1),axis=1)*\
+                                self.behavior_task_logic_model.task_parameters.BaseRewardSum
             elif self.behavior_session_model.experiment in ['Uncoupled Baiting','Uncoupled Without Baiting']:
-                input_string=self.UncoupledReward.text()
+                input_string = self.behavior_task_logic_model.task_parameters.UncoupledReward
                 # remove any square brackets and spaces from the string
                 input_string = input_string.replace('[','').replace(']','').replace(',', ' ')
                 # split the remaining string into a list of individual numbers
@@ -2836,7 +3066,8 @@ def _get_folder_structure_new(self):
         self.SaveFileMat=os.path.join(self.behavior_session_model.root_path,f'{id_name}.mat')
         self.SaveFileJson=os.path.join(self.behavior_session_model.root_path,f'{id_name}.json')
         self.SaveFileParJson=os.path.join(self.behavior_session_model.root_path,f'{id_name}_par.json')
-        self.behavior_session_modelJson = os.path.join(self.behavior_session_model.root_path,f'behavior_session_model_{id_name}.json')
+        self.behavior_session_model_json = os.path.join(self.behavior_session_model.root_path, f'behavior_session_model_{id_name}.json')
+        self.behavior_task_logic_model_json = os.path.join(self.behavior_session_model.root_path, f'behavior_task_logic_model_{id_name}.json')
         self.HarpFolder=os.path.join(self.behavior_session_model.root_path,'raw.harp')
         self.VideoFolder=os.path.join(self.SessionFolder,'behavior-videos')
         self.PhotometryFolder=os.path.join(self.SessionFolder,'fib')
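The _ShowRewardPairs hunk above only changes where the inputs come from; the normalization itself is unchanged: each reward pair is divided by its row sum and scaled by BaseRewardSum, so left and right probabilities always add up to BaseRewardSum. A worked example with an invented reward family (the pair values below are illustrative, not the GUI's actual RewardFamilies table):

```python
import numpy as np

# hypothetical reward family: three (left, right) ratio pairs
reward_pairs = [[8, 1], [6, 1], [3, 1]]
base_reward_sum = 0.8  # BaseRewardSum from the task parameters

pairs = np.array(reward_pairs)
# scale each pair so that left + right == base_reward_sum
reward_prob = pairs / np.expand_dims(np.sum(pairs, axis=1), axis=1) * base_reward_sum
print(np.round(reward_prob, 3))
# approximately:
# [[0.711 0.089]
#  [0.686 0.114]
#  [0.6   0.2  ]]
```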
@@ -3546,7 +3777,7 @@ def _StopPhotometry(self,closing=False):
                                 QMessageBox.Ok)
 
     def _AutoReward(self):
-        if self.AutoReward.isChecked():
+        if self.behavior_task_logic_model.task_parameters.AutoReward:
             self.AutoReward.setStyleSheet("background-color : green;")
             self.AutoReward.setText('Auto water On')
             for widget in ['AutoWaterType', 'Multiplier', 'Unrewarded', 'Ignored']:
@@ -4045,9 +4276,18 @@ def _Start(self):
         except ValidationError as e:
             logging.error(str(e), extra={'tags': [self.warning_log_tag]})
         # save behavior session model
-        with open(self.behavior_session_modelJson, "w") as outfile:
+        with open(self.behavior_session_model_json, "w") as outfile:
             outfile.write(self.behavior_session_model.model_dump_json())
 
+        # validate behavior task logic model and document validation errors if any
+        try:
+            AindBehaviorTaskLogicModel(**self.behavior_task_logic_model.model_dump())
+        except ValidationError as e:
+            logging.error(str(e), extra={'tags': [self.warning_log_tag]})
+        # save behavior task logic model
+        with open(self.behavior_task_logic_model_json, "w") as outfile:
+            outfile.write(self.behavior_task_logic_model.model_dump_json())
+
         if (self.StartANewSession == 1) and (self.ANewTrial == 0):
             # If we are starting a new session, we should wait for the last trial to finish
@@ -4976,6 +5216,8 @@ def log_subprocess_output(process, prefix):
 
     # Move creating AutoTrain here to catch any AWS errors
     win.create_auto_train_dialog()
-
+    win.AutoTrain_dialog.trainingStageChanged.connect(
+        lambda stage: setattr(win.behavior_task_logic_model.task_parameters, 'training_stage', stage))
+    # TODO: revisit; connecting the signal here rather than in connectSignalsSlots feels awkward, but the dialog is only created at this point (see AWS-error note above)
 
     # Run your application's event loop and stop after closing all windows
     sys.exit(app.exec())
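The _Start hunk applies the same validate-then-serialize pattern already used for the session model: round-trip the current model through its class to re-run validation, log any ValidationError, then write model_dump_json() to disk. A self-contained sketch of that pattern with a generic pydantic model (TaskLogic, its fields, and the output filename are placeholders, not the real schema):

```python
import json
import logging
from pydantic import BaseModel, ValidationError


class TaskLogic(BaseModel):
    # placeholder fields standing in for the task logic schema
    name: str
    version: str
    max_trial: int


model = TaskLogic(name='demo task', version='0.1.0', max_trial=1000)

# re-validate: rebuilding the model from its dumped dict re-runs field
# validation, which catches values mutated via setattr after construction
try:
    TaskLogic(**model.model_dump())
except ValidationError as e:
    logging.error(str(e))

# serialize to JSON, as the diff does for behavior_task_logic_model_json
with open('task_logic_demo.json', 'w') as outfile:
    outfile.write(model.model_dump_json())

print(json.loads(model.model_dump_json()))
# {'name': 'demo task', 'version': '0.1.0', 'max_trial': 1000}
```

Because the widget slots write values without validation, this round-trip at session start is the point where a malformed parameter set is surfaced to the experimenter.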
diff --git a/src/foraging_gui/MyFunctions.py b/src/foraging_gui/MyFunctions.py
index 266b8dcb1..893c1bda9 100644
--- a/src/foraging_gui/MyFunctions.py
+++ b/src/foraging_gui/MyFunctions.py
@@ -278,7 +278,7 @@ def _get_uncoupled_reward_prob_pool(self):
 
     def _CheckWarmUp(self):
         '''Check if we should turn on warm up'''
-        if self.win.warmup.currentText()=='off':
+        if self.win.behavior_task_logic_model.task_parameters.warmup=='off':
             return
         warmup=self._get_warmup_state()
         if warmup==0 and self.TP_warmup=='on':
@@ -1647,10 +1647,11 @@ def _GetAnimalResponse(self,Channel1,Channel3,Channel4):
             self._SimulateResponse()
             return
         # set the valve time of auto water
+        multiplier = self.win.behavior_task_logic_model.task_parameters.Multiplier
         if self.CurrentAutoRewardTrial[0]==1:
-            self._set_valve_time_left(Channel3,float(self.win.LeftValue.text()),float(self.win.Multiplier.text()))
+            self._set_valve_time_left(Channel3,float(self.win.LeftValue.text()),multiplier)
         if self.CurrentAutoRewardTrial[1]==1:
-            self._set_valve_time_right(Channel3,float(self.win.RightValue.text()),float(self.win.Multiplier.text()))
+            self._set_valve_time_right(Channel3,float(self.win.RightValue.text()),multiplier)
         if self.CurrentStartType==3:
             # no delay timestamp
             ReceiveN=9
@@ -1794,14 +1795,14 @@ def _set_valve_time_right(self,channel3,RightValue=0.01,Multiplier=1):
 
     def _GiveLeft(self,channel3):
         '''manually give left water'''
-        channel3.LeftValue1(float(self.win.LeftValue.text())*1000*float(self.win.Multiplier.text()))
+        channel3.LeftValue1(float(self.win.LeftValue.text())*1000*self.win.behavior_task_logic_model.task_parameters.Multiplier)
         time.sleep(0.01)
         channel3.ManualWater_Left(int(1))
         channel3.LeftValue1(float(self.win.LeftValue.text())*1000)
 
     def _GiveRight(self,channel3):
         '''manually give right water'''
-        channel3.RightValue1(float(self.win.RightValue.text())*1000*float(self.win.Multiplier.text()))
+        channel3.RightValue1(float(self.win.RightValue.text())*1000*self.win.behavior_task_logic_model.task_parameters.Multiplier)
         time.sleep(0.01)
         channel3.ManualWater_Right(int(1))
         channel3.RightValue1(float(self.win.RightValue.text())*1000)
diff --git a/src/foraging_gui/Visualization.py b/src/foraging_gui/Visualization.py
index d64f256b5..5afbb30d1 100644
--- a/src/foraging_gui/Visualization.py
+++ b/src/foraging_gui/Visualization.py
@@ -522,23 +522,23 @@ def __init__(self,GeneratedTrials=None,dpi=100,width=5, height=4):
         FigureCanvas.__init__(self, self.fig)
 
     def _Update(self,win):
         # randomly draw a block length between Min and Max
-        SampleMethods=win.Randomness.currentText()
+        SampleMethods=win.behavior_task_logic_model.task_parameters.Randomness
         # block length
-        Min=float(win.BlockMin.text())
-        Max=float(win.BlockMax.text())
-        Beta=float(win.BlockBeta.text())
+        Min=win.behavior_task_logic_model.task_parameters.BlockMin
+        Max=win.behavior_task_logic_model.task_parameters.BlockMax
+        Beta=win.behavior_task_logic_model.task_parameters.BlockBeta
         DataType='int'
         SampledBlockLen=self._Sample(Min=Min,Max=Max,SampleMethods=SampleMethods,Beta=Beta,DataType=DataType)
         # ITI
-        Min=float(win.ITIMin.text())
-        Max=float(win.ITIMax.text())
-        Beta=float(win.ITIBeta.text())
+        Min=win.behavior_task_logic_model.task_parameters.ITIMin
+        Max=win.behavior_task_logic_model.task_parameters.ITIMax
+        Beta=win.behavior_task_logic_model.task_parameters.ITIBeta
         DataType='float'
         SampledITI=self._Sample(Min=Min,Max=Max,SampleMethods=SampleMethods,Beta=Beta,DataType=DataType)
         # Delay
-        Min=float(win.DelayMin.text())
-        Max=float(win.DelayMax.text())
-        Beta=float(win.DelayBeta.text())
+        Min=win.behavior_task_logic_model.task_parameters.DelayMin
+        Max=win.behavior_task_logic_model.task_parameters.DelayMax
+        Beta=win.behavior_task_logic_model.task_parameters.DelayBeta
         DataType='float'
         SampledDelay=self._Sample(Min=Min,Max=Max,SampleMethods=SampleMethods,Beta=Beta,DataType=DataType)
         self.ax1.cla()
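With this change _Update reads typed Min/Max/Beta values straight from the task logic model instead of re-parsing widget text. The diff does not show _Sample itself, so the sketch below is only a plausible reading of what such a sampler could do given how it is called here; the function name, the truncation at Max, and the uniform fallback are assumptions, not repository code:

```python
import numpy as np


def sample(Min, Max, SampleMethods, Beta, DataType):
    # assumption: 'Exponential' draws Min + Exp(Beta), truncated at Max;
    # 'Even' draws uniformly between Min and Max
    if SampleMethods == 'Exponential':
        value = min(Min + np.random.exponential(Beta), Max)
    elif SampleMethods == 'Even':
        value = np.random.uniform(Min, Max)
    else:
        raise ValueError(f"unknown sampling method: {SampleMethods}")
    return int(round(value)) if DataType == 'int' else float(value)


# block length as an int, ITI/delay as floats, mirroring _Update's calls
print(sample(Min=20, Max=60, SampleMethods='Exponential', Beta=10, DataType='int'))
print(sample(Min=1.0, Max=8.0, SampleMethods='Even', Beta=2.0, DataType='float'))
```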