From 948a4d66512ba0ddc4872968635f3acf0c4883fc Mon Sep 17 00:00:00 2001 From: Adnan Munawar Date: Tue, 12 Nov 2024 12:23:51 -0500 Subject: [PATCH] Issue #239 Move ambf_env to tests folder and rename to ambf_gym_env --- .../{ambf_env.py => tests/ambf_gym_env.py} | 26 +++++++++---------- .../ambf_client/python/tests/env_test.py | 5 ++-- .../ambf_client/python/tests/rl_test.py | 4 +-- 3 files changed, 18 insertions(+), 17 deletions(-) rename ambf_ros_modules/ambf_client/python/{ambf_env.py => tests/ambf_gym_env.py} (90%) diff --git a/ambf_ros_modules/ambf_client/python/ambf_env.py b/ambf_ros_modules/ambf_client/python/tests/ambf_gym_env.py similarity index 90% rename from ambf_ros_modules/ambf_client/python/ambf_env.py rename to ambf_ros_modules/ambf_client/python/tests/ambf_gym_env.py index 955c84be6..a2f5405aa 100755 --- a/ambf_ros_modules/ambf_client/python/ambf_env.py +++ b/ambf_ros_modules/ambf_client/python/tests/ambf_gym_env.py @@ -45,10 +45,7 @@ from ambf_client import Client from gym import spaces import numpy as np -import math import time -from ambf_world import World -from ambf_object import Object from numpy import linalg as LA @@ -69,11 +66,11 @@ def cur_observation(self): class AmbfEnv: def __init__(self): - self.obj_handle = Object - self.world_handle = World + self.obj_handle = None + self.world_handle = None self.ambf_client = Client() - self.ambf_client.create_objs_from_rostopics() + self.ambf_client.connect() self.n_skip_steps = 5 self.enable_step_throttling = True self.action = [] @@ -82,8 +79,6 @@ def __init__(self): self.action_lims_high = np.array([30, 30, 30, 2, 2, 2, 1]) self.action_space = spaces.Box(self.action_lims_low, self.action_lims_high) self.observation_space = spaces.Box(-np.inf, np.inf, shape=(13,)) - - self.base_handle = self.ambf_client.get_obj_handle('PegBase') self.prev_sim_step = 0 pass @@ -96,15 +91,20 @@ def set_throttling_enable(self, check): self.enable_step_throttling = check self.world_handle.enable_throttling(check) - def make(self, a_name): - self.obj_handle = self.ambf_client.get_obj_handle(a_name) + def make(self, a_obj_name): + print("INFO! Making environment with object: ", a_obj_name) + self.obj_handle = self.ambf_client.get_obj_handle(a_obj_name) self.world_handle = self.ambf_client.get_world_handle() self.world_handle.enable_throttling(self.enable_step_throttling) self.world_handle.set_num_step_skips(self.n_skip_steps) if self.obj_handle is None or self.world_handle is None: raise Exception + time.sleep(0.2) def reset(self): + print("INFO! Reset called") + self.world_handle.reset() + time.sleep(0.2) action = [0.0, 0.0, 0.0, @@ -131,7 +131,7 @@ def step(self, action): return self.obs.cur_observation() def render(self, mode): - print ' I am a {} POTATO'.format(mode) + print(' I am a {} POTATO'.format(mode)) def _update_observation(self, action): if self.enable_step_throttling: @@ -141,13 +141,13 @@ def _update_observation(self, action): time.sleep(0.00001) self.prev_sim_step = self.obj_handle.get_sim_step() if step_jump > self.n_skip_steps: - print 'WARN: Jumped {} steps, Default skip limit {} Steps'.format(step_jump, self.n_skip_steps) + print('WARN: Jumped {} steps, Default skip limit {} Steps'.format(step_jump, self.n_skip_steps)) else: cur_sim_step = self.obj_handle.get_sim_step() step_jump = cur_sim_step - self.prev_sim_step self.prev_sim_step = cur_sim_step - state = self.obj_handle.get_pose() + self.base_handle.get_pose() + [step_jump] + state = self.obj_handle.get_pose() + [step_jump] self.obs.state = state self.obs.reward = self._calculate_reward(state, action) self.obs.is_done = self._check_if_done() diff --git a/ambf_ros_modules/ambf_client/python/tests/env_test.py b/ambf_ros_modules/ambf_client/python/tests/env_test.py index 4ca3d4feb..1d35cef02 100755 --- a/ambf_ros_modules/ambf_client/python/tests/env_test.py +++ b/ambf_ros_modules/ambf_client/python/tests/env_test.py @@ -43,13 +43,14 @@ # \version 0.1 # */ # //============================================================================== -from ambf_comm import AmbfEnv +from ambf_gym_env import AmbfEnv import time env = AmbfEnv() action = env.action_space env.make('Torus') +env.reset() env.skip_sim_steps(1) time.sleep(1) env.reset() @@ -57,4 +58,4 @@ for i in range(1,total): state, r, d, dict = env.step(env.action_space.sample()) if i % 50 == 0: - print 'Reward: ', r, 'Steps: ', i, ' \ ', total + print('Reward: ', r, 'Steps: ', i, ' \ ', total) diff --git a/ambf_ros_modules/ambf_client/python/tests/rl_test.py b/ambf_ros_modules/ambf_client/python/tests/rl_test.py index d531f2a6a..3f910fdb4 100644 --- a/ambf_ros_modules/ambf_client/python/tests/rl_test.py +++ b/ambf_ros_modules/ambf_client/python/tests/rl_test.py @@ -43,7 +43,7 @@ # */ # //============================================================================== import numpy as np -from ambf_comm import AmbfEnv +from ambf_gym_env import AmbfEnv from keras.models import Sequential, Model from keras.layers import Dense, Activation, Flatten, Input, Concatenate @@ -56,11 +56,11 @@ ENV_NAME = 'Torus' - # Get the environment and extract the number of actions. env = AmbfEnv() env.make(ENV_NAME) env.reset() +time.sleep(0.5) assert len(env.action_space.shape) == 1 nb_actions = env.action_space.shape[0]