From d4fa4c82dfb976182d1e62df601cf879b262c646 Mon Sep 17 00:00:00 2001 From: taylor howell Date: Mon, 12 Feb 2024 19:54:12 -0700 Subject: [PATCH] set simulation state --- mjpc/grpc/agent.proto | 1 + mjpc/grpc/agent_service.cc | 11 +++++++++++ python/mujoco_mpc/agent.py | 6 +++++- 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/mjpc/grpc/agent.proto b/mjpc/grpc/agent.proto index 79d4b7742..8427a9a3a 100644 --- a/mjpc/grpc/agent.proto +++ b/mjpc/grpc/agent.proto @@ -89,6 +89,7 @@ message GetStateResponse { message SetStateRequest { State state = 1; + bool set_simulation = 2; } message SetStateResponse {} diff --git a/mjpc/grpc/agent_service.cc b/mjpc/grpc/agent_service.cc index 254927a35..a1a2b314a 100644 --- a/mjpc/grpc/agent_service.cc +++ b/mjpc/grpc/agent_service.cc @@ -170,6 +170,17 @@ grpc::Status AgentService::SetState(grpc::ServerContext* context, task->Transition(model, data_); agent_.SetState(data_); + // Set simulation state + if (request->set_simulation()) { + Agent::StepJob job = [&agent_data = data_]( + Agent* agent, const mjModel* model, mjData* data) { + mju_copy(data->qpos, agent_data->qpos, model->nq); + mju_copy(data->qvel, agent_data->qvel, model->nv); + data->time = agent_data->time; + }; + agent_.RunBeforeStep(std::move(job)); + } + return grpc::Status::OK; } diff --git a/python/mujoco_mpc/agent.py b/python/mujoco_mpc/agent.py index 1198fc685..e388c1ce1 100644 --- a/python/mujoco_mpc/agent.py +++ b/python/mujoco_mpc/agent.py @@ -189,6 +189,7 @@ def set_state( mocap_pos: Optional[npt.ArrayLike] = None, mocap_quat: Optional[npt.ArrayLike] = None, userdata: Optional[npt.ArrayLike] = None, + set_simulation: Optional[bool] = False, ): """Set `Agent`'s MuJoCo `data` state. @@ -200,6 +201,7 @@ def set_state( mocap_pos: `data.mocap_pos`. mocap_quat: `data.mocap_quat`. userdata: `data.userdata`. + set_simulation: bool, set the simulation state. """ # if mocap_pos is an ndarray rather than a list, flatten it if hasattr(mocap_pos, "flatten"): @@ -217,7 +219,9 @@ def set_state( userdata=userdata if userdata is not None else [], ) - set_state_request = agent_pb2.SetStateRequest(state=state) + set_state_request = agent_pb2.SetStateRequest( + state=state, set_simulation=set_simulation + ) self.stub.SetState(set_state_request) def get_state(self) -> agent_pb2.State: