当前位置: 首页 > news >正文

强化学习Q-learning求最优策略

理论基础:

on policy:behavior policy=target policy

off policy:behavior policy!=target policy

注意:

behavior policy的初始化最好具有较强的随机性,就能尽可能遍历到所有的(s, a)pair。

强化学习的数据基础这种书中有不同的behavior policy导致的不同的探索路径的图:

代码可运行:

import numpy as np from env import GridWorldEnv from utils import drow_policy class Q_Learning(object): def __init__(self, env: GridWorldEnv, gamma=0.9, alpha=0.001, epsilon=0.1, samples=1, start_state=(0, 0),mode="on policy"): ''' :param env: 定义了网格的基础配置 :param gamma: discount rate :param alpha: learning rate :param samples: 从起点到终点采样的路径数 :param start_state: 起点 :param mode: 模式 ''' self.env = env self.action_space_size = self.env.num_actions # 上下左右原地 self.state_space_size = self.env.num_states self.reward_list = self.env.reward_list self.gamma = gamma self.samples = samples self.alpha = alpha self.epsilon = epsilon self.mode=mode self.start_state = self.env.state_id(start_state[0], start_state[1]) self.behavior_policy = np.ones( (self.state_space_size, self.action_space_size)) / self.action_space_size # 探索性很强 self.target_policy = np.zeros((self.state_space_size, self.action_space_size)) self.qvalues = np.zeros((self.state_space_size, self.action_space_size)) def update_qvalues(self,s_t,a_t,s_next,r_next): max_q_next = np.max(self.qvalues[s_next]) td_target = r_next + self.gamma * max_q_next td_error = td_target - self.qvalues[s_t][a_t] # 负号提出去 self.qvalues[s_t][a_t] += self.alpha * td_error def solve(self): if self.mode=="off policy": for _ in range(self.samples): s = self.start_state a = np.random.choice(self.action_space_size, p=self.behavior_policy[s]) episode = self.env.generate_episodes(self.behavior_policy, s, a) for i in range(len(episode)): s_t, a_t, r_next_t, s_next_t= episode[i] self.update_qvalues(s_t,a_t,s_next_t,r_next_t) # greedy best_a = np.argmax(self.qvalues[s_t]) self.target_policy[s_t] = np.eye(self.action_space_size)[best_a] elif self.mode=="on policy": # target_policy=behavior_policy for _ in range(self.samples): s = self.start_state while s not in self.env.terminal: a = np.random.choice(self.action_space_size, p=self.behavior_policy[s]) # generate at following πt(st) next_s, next_r, _ = self.env.step(s, a) # generate rt+1, st+1 by interacting with the environment # updata q-value for (s_t,a_t) # qt+1(st, at) = qt(st, at) − αt(st, at) [ qt(st, at) − (rt+1 + γ max(qt(st+1, a)))] self.update_qvalues(s,a,next_s,next_r) # update policy for s_t: epsilon greedy 因为要用policy生成数据,因此需要策略具有一定的探索性,因此使用epsilon greedy best_a = np.argmax(self.qvalues[s]) self.behavior_policy[s] = self.epsilon / self.action_space_size self.behavior_policy[s, best_a] += 1 - self.epsilon self.target_policy=self.behavior_policy s = next_s else: raise Exception("Invalid mode") if __name__ == '__main__': env = GridWorldEnv( size=5, forbidden=[(1, 2), (3, 3)], terminal=[(4, 4)], r_boundary=-1, r_other=-0.04, r_terminal=1, r_forbidden=-1, r_stay=-0.1 ) # 注意samples要大一点,否则每个state被访问到的概率很小 vi = Q_Learning(env=env, gamma=0.8, alpha=0.01, samples=1000, start_state=(0, 0),mode="off policy") vi.solve() print("\n state value: ") print(vi.qvalues) drow_policy(vi.target_policy, env)
http://www.cnnetsun.cn/news/53528.html

相关文章:

  • 你对电脑上的【Fn】熟悉多少
  • 计及N-k安全约束的含光热电站电力系统优化调度模型【IEEE14节点、118节点】附Matlab代码
  • 计及需求响应的粒子群算法求解风能、光伏、柴油机、储能容量优化配置附Matlab代码
  • conda使用详细指南
  • 豆包与DeepSeek底层大模型的深度解析:技术架构、设计理念与生态分野
  • Linux系统中的socket激活:先创建监听端口,后启动程序
  • 从零解决pyproject.toml构建失败的实战指南
  • Redis Lua脚本入门:从零写出你的第一个原子操作
  • 旧机转手不再慌!电子产品信息清除新国标落地,核心技术逻辑全解析
  • 安全体验馆好用供应商
  • 第二章——数据分析场景之Python数据可视化:用Matplotlib与Seaborn绘制洞察之图
  • 【Java毕设全套源码+文档】基于springboot的高校毕业生离校管理系统小程序设计与实现(丰富项目+远程调试+讲解+定制)
  • 如何用AI工具jstat优化Java应用性能分析
  • 【Java毕设全套源码+文档】基于springboot的高校毕业生信息管理系统的设计与实现(丰富项目+远程调试+讲解+定制)
  • Day 38 GPU训练及类的call方法
  • 【Python实战】火爆全网的“隔空手势画板”是如何实现的?教你用OpenCV+MediaPipe复刻钢铁侠黑科技!
  • 【学习笔记】如果打造可复现、可评测、可迭代的AI技术体系
  • 【论文自动阅读】See Once, Then Act: Vision-Language-Action Model with Task Learning from One-Shot Video Demo
  • 利用齐次坐标系证明各种几何定理【射影几何】
  • 小程序基于springboot的乡镇普法知识科普宣传系统 律师预约系统设计与实现_qf4cwws6(java毕业设计项目源码)
  • 面向对象编程三大特性:封装、继承、多态的核心要义
  • leetcode 2147. 分隔长廊的方案数 困难
  • 学生党必备!这款桌面课表工具太省心了
  • 深度学习实验14代码
  • 优化及性能-–-behaviac
  • 练题100天——DAY26:汇总区间+丢失的数字+数组交集
  • 当AI芯片不再性感:博通的高增长,为何成了催命符?
  • Vibe Coding:AI驱动的编程新范式
  • AI 数字孪生工厂:西门子与中信特钢的实践,如何降本 11%?
  • Spring IoC的实现机制是什么?