1 Star 0 Fork 0

陈狗翔 / adeptRL

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
custom_agent_stub.py 2.79 KB
一键复制 编辑 原始数据 按行查看 历史
Joe Tatusko 提交于 2019-01-23 16:59 . Docs (#40)
"""
Use a custom agent.
"""
from adept.agents import AgentModule, AgentRegistry
from adept.scripts.local import parse_args, main
class MyCustomAgent(AgentModule):
    """
    Skeleton for a user-defined agent.

    Only the constructor does real work (it forwards to AgentModule). Every
    other member is a placeholder that returns None until it is filled in.
    """

    # You will be prompted for these when training script starts
    args = {'example_arg1': True, 'example_arg2': 5}

    def __init__(
        self, network, device, reward_normalizer, gpu_preprocessor, engine,
        action_space, nb_env, *args, **kwargs
    ):
        """
        Pass the seven named constructor arguments on to AgentModule.

        Any additional positional or keyword arguments are accepted but are
        not used by this stub.
        """
        parent_args = (
            network,
            device,
            reward_normalizer,
            gpu_preprocessor,
            engine,
            action_space,
            nb_env,
        )
        super(MyCustomAgent, self).__init__(*parent_args)

    @classmethod
    def from_args(
        cls, args, network, device, reward_normalizer, gpu_preprocessor,
        engine, action_space, **kwargs
    ):
        """
        Alternate constructor taking an argument dictionary.

        :param args: Dict[str, Any], keyed by argument name
        :param network: BaseNetwork
        :param device: torch.device
        :param reward_normalizer: Callable[[float], float]
        :param gpu_preprocessor: ObsPreprocessor
        :param engine: env_registry.Engines
        :param action_space: Dict[str, torch.Tensor], keyed by action key
        :param kwargs: additional keyword arguments
        :return: MyCustomAgent (this stub returns None until implemented)
        """
        # Placeholder: build and return the agent here.
        return None

    @property
    def exp_cache(self):
        """
        The agent's experience cache, probably a RolloutCache or
        ExperienceReplay.

        :return: BaseExperience (this stub returns None until implemented)
        """
        # Placeholder: return the experience container here.
        return None

    @staticmethod
    def output_space(action_space):
        """
        Combine the action space with any agent-specific outputs to form the
        output space.

        :param action_space: Dict[str, Tuple[int, ...]], action key -> shape
        :return: not specified by the stub (None until implemented)
        """
        # Placeholder: derive the output space here.
        return None

    def compute_loss(self, experience, next_obs):
        """
        Calculate the losses.

        :param experience: Tuple[Any, ...]
        :param next_obs: Dict[str, torch.Tensor], keyed by observation key
        :return: Dict[str, torch.Tensor (0D)], keyed by loss key
            (this stub returns None until implemented)
        """
        # Placeholder: compute and return the loss dictionary here.
        return None

    def act(self, obs):
        """
        Produce an action.

        :param obs: Dict[str, torch.Tensor], keyed by observation key
        :return: Dict[str, np.ndarray], keyed by action key
            (this stub returns None until implemented)
        """
        # Placeholder: choose actions here.
        return None

    def act_eval(self, obs):
        """
        Produce an action during an evaluation.

        :param obs: Dict[str, torch.Tensor], keyed by observation key
        :return: Dict[str, np.ndarray], keyed by action key
            (this stub returns None until implemented)
        """
        # Placeholder: choose evaluation-time actions here.
        return None
if __name__ == '__main__':
    # Parse the command line first, exactly as the local script would.
    cli_args = parse_args()

    # Build a registry holding the custom agent and pass it to the local
    # script's main() together with the parsed arguments.
    registry = AgentRegistry()
    registry.register_agent(MyCustomAgent)
    main(cli_args, agent_registry=registry)
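# Call the script like this to train the agent (note: `-m` takes a module
# name, so there is no `.py` suffix):
# python -m custom_agent_stub --agent MyCustomAgent
# or run the file directly:
# python custom_agent_stub.py --agent MyCustomAgent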
1
https://gitee.com/ChenGouXiang/adeptRL.git
git@gitee.com:ChenGouXiang/adeptRL.git
ChenGouXiang
adeptRL
adeptRL
master

搜索帮助