Commit

Issue #239
Move ambf_env to tests folder and rename to ambf_gym_env
adnanmunawar committed Nov 12, 2024
1 parent 09a28ff commit 948a4d6
Showing 3 changed files with 18 additions and 17 deletions.
26 changes: 13 additions & 13 deletions ambf_ros_modules/ambf_client/python/tests/ambf_gym_env.py

@@ -45,10 +45,7 @@
 from ambf_client import Client
 from gym import spaces
 import numpy as np
 import math
 import time
-from ambf_world import World
-from ambf_object import Object
 from numpy import linalg as LA


@@ -69,11 +66,11 @@ def cur_observation(self):

 class AmbfEnv:
     def __init__(self):
-        self.obj_handle = Object
-        self.world_handle = World
+        self.obj_handle = None
+        self.world_handle = None

         self.ambf_client = Client()
-        self.ambf_client.create_objs_from_rostopics()
+        self.ambf_client.connect()
         self.n_skip_steps = 5
         self.enable_step_throttling = True
         self.action = []
@@ -82,8 +79,6 @@ def __init__(self):
         self.action_lims_high = np.array([30, 30, 30, 2, 2, 2, 1])
         self.action_space = spaces.Box(self.action_lims_low, self.action_lims_high)
         self.observation_space = spaces.Box(-np.inf, np.inf, shape=(13,))
-
-        self.base_handle = self.ambf_client.get_obj_handle('PegBase')
         self.prev_sim_step = 0

         pass
@@ -96,15 +91,20 @@ def set_throttling_enable(self, check):
         self.enable_step_throttling = check
         self.world_handle.enable_throttling(check)

-    def make(self, a_name):
-        self.obj_handle = self.ambf_client.get_obj_handle(a_name)
+    def make(self, a_obj_name):
+        print("INFO! Making environment with object: ", a_obj_name)
+        self.obj_handle = self.ambf_client.get_obj_handle(a_obj_name)
         self.world_handle = self.ambf_client.get_world_handle()
         self.world_handle.enable_throttling(self.enable_step_throttling)
         self.world_handle.set_num_step_skips(self.n_skip_steps)
+        if self.obj_handle is None or self.world_handle is None:
+            raise Exception
+        time.sleep(0.2)

     def reset(self):
+        print("INFO! Reset called")
         self.world_handle.reset()
+        time.sleep(0.2)
         action = [0.0,
                   0.0,
                   0.0,
@@ -131,7 +131,7 @@ def step(self, action):
         return self.obs.cur_observation()

     def render(self, mode):
-        print ' I am a {} POTATO'.format(mode)
+        print(' I am a {} POTATO'.format(mode))

     def _update_observation(self, action):
         if self.enable_step_throttling:
@@ -141,13 +141,13 @@ def _update_observation(self, action):
             time.sleep(0.00001)
         self.prev_sim_step = self.obj_handle.get_sim_step()
         if step_jump > self.n_skip_steps:
-            print 'WARN: Jumped {} steps, Default skip limit {} Steps'.format(step_jump, self.n_skip_steps)
+            print('WARN: Jumped {} steps, Default skip limit {} Steps'.format(step_jump, self.n_skip_steps))
         else:
             cur_sim_step = self.obj_handle.get_sim_step()
             step_jump = cur_sim_step - self.prev_sim_step
             self.prev_sim_step = cur_sim_step

-        state = self.obj_handle.get_pose() + self.base_handle.get_pose() + [step_jump]
+        state = self.obj_handle.get_pose() + [step_jump]
         self.obs.state = state
         self.obs.reward = self._calculate_reward(state, action)
         self.obs.is_done = self._check_if_done()
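With obj_handle and world_handle now initialized to None, the environment is unusable until make() binds both handles (and raises if either lookup fails). A minimal usage sketch, pieced together from this file and the tests below ('Torus' is the object name the tests use):

    from ambf_gym_env import AmbfEnv

    env = AmbfEnv()    # connects the AMBF client; handles are still None
    env.make('Torus')  # binds the object and world handles, raises on failure
    env.reset()        # resets the world and lets the simulation settle

    # step() returns (state, reward, done, info), as unpacked in env_test.py
    state, reward, done, info = env.step(env.action_space.sample())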
5 changes: 3 additions & 2 deletions ambf_ros_modules/ambf_client/python/tests/env_test.py
@@ -43,18 +43,19 @@
 # \version 0.1
 # */
 # //==============================================================================
-from ambf_comm import AmbfEnv
+from ambf_gym_env import AmbfEnv
 import time


 env = AmbfEnv()
 action = env.action_space
 env.make('Torus')
 env.reset()
 env.skip_sim_steps(1)
 time.sleep(1)
 env.reset()
 total = 50000
 for i in range(1,total):
     state, r, d, dict = env.step(env.action_space.sample())
     if i % 50 == 0:
-        print 'Reward: ', r, 'Steps: ', i, ' \ ', total
+        print('Reward: ', r, 'Steps: ', i, ' \ ', total)
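Two small wrinkles survive the Python 3 print fix above: the loop variable dict shadows the built-in, and ' \ ' prints a literal backslash where a forward slash was likely intended. A tidied sketch of the same loop (the renames and the slash are mine, behavior otherwise unchanged):

    for i in range(1, total):
        state, reward, done, info = env.step(env.action_space.sample())
        if i % 50 == 0:
            print('Reward: ', reward, 'Steps: ', i, ' / ', total)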
4 changes: 2 additions & 2 deletions ambf_ros_modules/ambf_client/python/tests/rl_test.py
@@ -43,7 +43,7 @@
 # */
 # //==============================================================================
 import numpy as np
-from ambf_comm import AmbfEnv
+from ambf_gym_env import AmbfEnv

 from keras.models import Sequential, Model
 from keras.layers import Dense, Activation, Flatten, Input, Concatenate
@@ -56,11 +56,11 @@

 ENV_NAME = 'Torus'


 # Get the environment and extract the number of actions.
 env = AmbfEnv()
 env.make(ENV_NAME)
 env.reset()
+time.sleep(0.5)
 assert len(env.action_space.shape) == 1
 nb_actions = env.action_space.shape[0]

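The remainder of rl_test.py is collapsed in this view; the Sequential/Dense/Flatten/Activation imports suggest a keras-rl style actor network built on nb_actions. A sketch of what such an actor might look like here (layer sizes and activations are my assumptions, not taken from the commit):

    actor = Sequential()
    actor.add(Flatten(input_shape=(1,) + env.observation_space.shape))  # keras-rl expects a leading window dimension
    actor.add(Dense(64))
    actor.add(Activation('relu'))
    actor.add(Dense(64))
    actor.add(Activation('relu'))
    actor.add(Dense(nb_actions))
    actor.add(Activation('tanh'))  # bound outputs before scaling to the action limits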
