当前位置: 首页 > news >正文

强化学习DeepQLearning求最优策略的代码实现

理论基础:

可运行代码:

import numpy as np
import torch
from torch import nn
from torch.utils import data

from env import GridWorldEnv
from utils import drow_policy


class DeepQLearning(object):
    """Deep Q-learning on a grid world with an online and a target network.

    Transitions are collected once with the current (initially uniform)
    policy, then the online network is fitted to one-step TD targets computed
    by a periodically synchronised target network.
    """

    def __init__(self, env: GridWorldEnv, gamma=0.9):
        """Set up the policy table, both Q-networks, optimizer and loss.

        Args:
            env: Grid-world environment. This class reads ``num_actions``,
                ``num_states``, ``size`` and ``terminal`` from it and calls
                ``reset()`` and ``generate_episodes()``.
            gamma: Discount factor used in the TD target.
        """
        self.env = env
        self.action_space_size = self.env.num_actions
        self.state_space_size = self.env.num_states
        self.gamma = gamma
        # Start from a uniform policy: every action equally likely per state.
        self.policy = (
            np.ones((self.state_space_size, self.action_space_size))
            / self.action_space_size
        )
        self.net = self._build_q_network()
        self.target_net = self._build_q_network()
        # Both networks must start from identical weights.
        self.target_net.load_state_dict(self.net.state_dict())
        self.optimizer = torch.optim.SGD(
            self.net.parameters(), lr=0.01, momentum=0.9
        )
        self.loss = nn.MSELoss()

    def _build_q_network(self):
        """Return an MLP mapping a (row, col) pair to one Q-value per action."""
        return nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Linear(16, self.action_space_size),
        )

    def _state_to_coords(self, s):
        """Convert a flat state index into a (row, col) pair using env.size."""
        return (s // self.env.size, s % self.env.size)

    def data_iter(self, episode, batch_size=32, is_train=True):
        """Wrap transitions in a DataLoader of (state, action, reward, next_state).

        Args:
            episode: Iterable of ``(s, a, r, next_s)`` tuples, with ``s`` and
                ``next_s`` given as flat state indices.
            batch_size: Number of transitions per mini-batch.
            is_train: When true, the loader shuffles the transitions.

        Returns:
            A ``DataLoader`` yielding float32 states and next states of shape
            [B, 2], int64 actions of shape [B] and float32 rewards of shape [B].
        """
        reward = []
        state = []
        action = []
        next_state = []
        for s, a, r, next_s in episode:
            reward.append(r)
            # Feed (row, col) coordinates rather than the raw index so the
            # network input carries the spatial position of the cell.
            state.append(self._state_to_coords(s))
            next_state.append(self._state_to_coords(next_s))
            action.append(a)
        reward = torch.tensor(reward, dtype=torch.float32)
        state = torch.tensor(state, dtype=torch.float32)
        next_state = torch.tensor(next_state, dtype=torch.float32)
        action = torch.tensor(action, dtype=torch.long)
        dataset = data.TensorDataset(state, action, reward, next_state)
        return data.DataLoader(
            dataset, batch_size=batch_size, shuffle=is_train, drop_last=False
        )

    def solve(self, epochs, update_frep):
        """Collect experience and train the online Q-network.

        Args:
            epochs: Number of passes over the collected transitions.
            update_frep: Number of optimisation steps between copies of the
                online weights into the target network.
        """
        s = self.env.reset()
        a = np.random.choice(self.action_space_size, p=self.policy[s])

        # Gather a fixed pool of transitions up front; every call starts from
        # the same (s, a) pair and follows the current policy table.
        episodes = []
        for _ in range(self.action_space_size * self.state_space_size + 1):
            episode = self.env.generate_episodes(
                self.policy, s, a, max_steps=1000
            )
            episodes.extend(episode)
        dataloader = self.data_iter(episodes)

        step = 0
        for _ in range(epochs):
            for state, action, reward, next_state in dataloader:
                step += 1
                # The TD target comes from the frozen target network and must
                # not contribute gradients.
                with torch.no_grad():
                    q_value = self.target_net(next_state)  # [B, N]
                    max_q = q_value.max(dim=1).values
                    y_target = reward + self.gamma * max_q  # [B]
                y = self.net(state)
                # Pick the Q-value of the action that was actually taken.
                y_ = y.gather(1, action.unsqueeze(1)).squeeze(1)
                # nn.MSELoss takes (input, target); the value is symmetric.
                loss = self.loss(y_, y_target)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                if step % update_frep == 0:
                    self.target_net.load_state_dict(self.net.state_dict())

    def get_policy(self):
        """Write the greedy policy of the online network into ``self.policy``.

        Returns:
            ``self.policy`` (updated in place) with exactly one 1 per row.
        """
        with torch.no_grad():
            for s in range(self.state_space_size):
                if s in self.env.terminal:
                    # Clear the row first so it stays a valid distribution,
                    # and keep going: later states still need a policy.
                    # NOTE(review): index 4 is presumably the "stay" action;
                    # confirm against GridWorldEnv.
                    self.policy[s] = 0
                    self.policy[s, 4] = 1
                    continue
                s_t = torch.tensor(
                    self._state_to_coords(s), dtype=torch.float32
                )
                q_value = self.net(s_t)
                a = q_value.argmax(dim=0).item()
                self.policy[s] = 0
                self.policy[s, a] = 1
        return self.policy


if __name__ == '__main__':
    env = GridWorldEnv(
        size=5,
        forbidden=[(1, 2), (3, 3)],
        terminal=[(4, 4)],
        r_boundary=-1,
        r_other=0,
        r_terminal=1,
        r_forbidden=-1,
        r_stay=-0.1
    )
    # Note: use a fairly large number of samples, otherwise each state is
    # visited with only a small probability.
    vi = DeepQLearning(env=env, gamma=0.8)
    vi.solve(epochs=50, update_frep=20)
    policy = vi.get_policy()
    print(policy)
    drow_policy(policy, env)
http://www.cnnetsun.cn/news/90366.html

相关文章:

  • 加密PDF处理新进展(Dify进度跟踪深度剖析)
  • 从零构建智能Agent文档系统:Dify配置与最佳实践全揭秘
  • 高负载环境下Docker Offload调度失控?优先级设置不当是元凶!
  • 还在手动校验语音数据?Dify 1.7.0自动检测功能已上线(限时体验)
  • 专家警告:不掌握量子计算镜像缓存技术,你的研发效率已落后同行三年
  • 对标行业高标准,全星研发项目管理系统赋能汽车芯片研发升级:PLM系统更专业化
  • LC.669 | 修剪二叉搜索树 | 树 | 递归与重连
  • DAY29 pipeline管道
  • A29语音模组:100dB消回音黑科技,超大音量下也能清晰通话
  • 1688 拍立淘接口(item_search_img)技术全景解析
  • Dify如何逆向解析加密PDF?深入剖析现代文档安全的攻防博弈
  • 测试工程师必备:利用Apipost AI编写脚本,快速实现多接口串联流程
  • IP 扫盲:不要再迷信家宽
  • 基于协同过滤算法的动漫推荐系统源码 Java+SpringBoot+Vue3
  • 高效量子电路设计秘籍(R驱动的3种前沿优化策略)
  • 分分钟带你杀入Kaggle Top 1%,大牛带队,100%拿牌!
  • IP6808至为芯支持PD快充输入的15W无线充电方案SOC芯片
  • 笔记本重装系统超详细指南(附系统备份还原技巧,告别电脑店花费)
  • 大型地源热泵机组多高
  • 别墅供暖地源热泵
  • Traefik:为云原生而生的自动化反向代理
  • P1043 [NOIP 2003 普及组] 数字游戏
  • Web安全攻防学习图谱:90天从网安小白到漏洞猎人(超详细),看这一篇就够了!
  • 【Docker镜像优化黄金法则】:让边缘Agent更小更快更安全
  • 前端vue3 web端中实现拖拽功能实现列表排序
  • 【音视频开发必看】Dify 1.7.0音频转换避坑指南:5大常见错误及修复方案
  • VSCode+PlatformIO+ESP32-Cam + MB烧录器 入门测试
  • 【加密PDF解析避坑指南】:Dify错误处理的5大核心策略与实战技巧
  • 性能测试入门:使用 Playwright 测量关键 Web 性能指标
  • 从入门到精通:R语言极值分布拟合在气象数据中的4个关键步骤