SAC Code Help

I’ve been trying to implement SAC (Soft Actor-Critic) for OpenAI’s Gymnasium. I have a DDPG implementation that works on the Pendulum environment, but my SAC implementation never learns. I suspect the problem is in my logic somewhere. From printing out most of the intermediate steps, my Q-values end up being full of NaN, but I cannot figure out why. If someone could help, that would be appreciated.
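For reference, my understanding of the step I may be getting wrong: SAC is supposed to sample an unsquashed Gaussian action, squash it with tanh, and subtract a correction term from the Gaussian log-probability. A minimal sketch of that, as I understand it (the helper name and shapes are mine, and this may not match what my code below actually does):

import torch

# Sketch only: my understanding of the squashed-Gaussian sampling step in SAC.
def squashed_gaussian_sample(mu, std):
    normal = torch.distributions.Normal(mu, std)
    u = normal.rsample()                         # reparameterized pre-squash sample
    action = torch.tanh(u)                       # squash into (-1, 1)
    log_prob = normal.log_prob(u).sum(-1, keepdim=True)
    log_prob -= torch.log(1 - action.pow(2) + 1e-6).sum(-1, keepdim=True)
    return action, log_prob

Here is my full code: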
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class ReplayBuffer:
    def __init__(self, memory_capacity=1000000, batch_size=64, num_actions=1, num_states=3):
        self.memory_capacity = memory_capacity
        self.num_states = num_states
        self.num_actions = num_actions
        self.batch_size = batch_size
        self.buffer_counter = 0
        self.state_buffer = np.zeros((self.memory_capacity, self.num_states))
        self.action_buffer = np.zeros((self.memory_capacity, self.num_actions))
        self.reward_buffer = np.zeros(self.memory_capacity)
        self.next_state_buffer = np.zeros((self.memory_capacity, self.num_states))
        self.done_buffer = np.zeros(self.memory_capacity)

    def record(self, observation, action, reward, next_observation, done):
        # overwrite the oldest entries once the buffer is full
        index = self.buffer_counter % self.memory_capacity
        self.state_buffer[index] = observation
        self.action_buffer[index] = action
        self.reward_buffer[index] = reward
        self.next_state_buffer[index] = next_observation
        self.done_buffer[index] = done
        self.buffer_counter += 1

    def sample(self):
        # sample uniformly from the filled portion of the buffer
        range1 = min(self.buffer_counter, self.memory_capacity)
        indices = np.random.randint(0, range1, size=self.batch_size)
        states = torch.tensor(self.state_buffer[indices], dtype=torch.float32)
        actions = torch.tensor(self.action_buffer[indices], dtype=torch.float32)
        rewards = torch.tensor(self.reward_buffer[indices], dtype=torch.float32)
        next_states = torch.tensor(self.next_state_buffer[indices], dtype=torch.float32)
        dones = torch.tensor(self.done_buffer[indices], dtype=torch.float32)
        return states, actions, rewards, next_states, dones

class Critic(nn.Module):
    def __init__(self, num_states, num_actions, action_bound, learning_rate):
        super(Critic, self).__init__()
        self.num_actions = num_actions
        self.num_states = num_states
        self.action_bound = action_bound
        self.lC = learning_rate
        self.fc1 = nn.Linear(num_states, 128)
        self.fc2 = nn.Linear(num_actions, 128)
        self.combinedfc1 = nn.Linear(256, 256)
        self.combinedfc2 = nn.Linear(256, 1)

    def forward(self, s, a):
        # separate state and action branches, then a combined head that outputs Q(s, a)
        state_out = F.relu(self.fc1(s))
        action_out = F.relu(self.fc2(a))
        combined = torch.cat([state_out, action_out], dim=-1)
        combined = F.relu(self.combinedfc1(combined))
        x = self.combinedfc2(combined)
        return x

class Actor(nn.Module):
    def __init__(self, num_states, num_actions, learning_rate, action_bound):
        super(Actor, self).__init__()
        self.num_states = num_states
        self.num_actions = num_actions
        self.lA = learning_rate
        self.action_bound = action_bound
        self.fc1 = nn.Linear(num_states, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mu_head = nn.Linear(256, num_actions)
        self.log_std_head = nn.Linear(256, num_actions)
        self.min_log_std = -20
        self.max_log_std = 2

    def forward(self, state):
        state = torch.tensor(state, dtype=torch.float32).clone().detach()
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mu = self.mu_head(x)
        log_std_head = F.relu(self.log_std_head(x))
        log_std_head = torch.clamp(log_std_head, self.min_log_std, self.max_log_std)
        return mu, log_std_head

class Agent:
    def __init__(self, env):
        self.env = env
        self.state_dimension = self.env.observation_space.shape[0]
        self.action_dimension = self.env.action_space.shape[0]
        self.action_bound = self.env.action_space.high[0]
        self.buffer = ReplayBuffer()
        self.learning_rate1 = .0001
        self.learning_rate2 = .001
        self.tau = .01
        self.gamma = .9
        self.alpha = .2
        self.actor = Actor(self.state_dimension, self.action_dimension, self.learning_rate1, self.action_bound)
        self.critic = Critic(self.state_dimension, self.action_dimension, self.action_bound, self.learning_rate2)
        self.target_critic = Critic(self.state_dimension, self.action_dimension, self.action_bound, self.learning_rate2)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.learning_rate1)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=self.learning_rate2)
        self.critic2 = Critic(self.state_dimension, self.action_dimension, self.action_bound, self.learning_rate2)
        self.target_critic2 = Critic(self.state_dimension, self.action_dimension, self.action_bound, self.learning_rate2)
        self.target_critic2.load_state_dict(self.critic2.state_dict())
        self.critic2_optimizer = optim.Adam(self.critic.parameters(), lr=self.learning_rate2)
        self.value_crit = nn.MSELoss()
        self.q_crit = nn.MSELoss()

    def select_action(self, state):
        mu, log_std = self.actor(state)
        std = torch.exp(log_std)
        action = torch.normal(mu, std)
        return action

    def log_probs(self, state):
        state = torch.tensor(state, dtype=torch.float32)
        # action = torch.tensor(action, dtype=torch.float32)
        mu, log_std = self.actor(state)
        std = torch.exp(log_std)
        normal = torch.distributions.Normal(mu, std)
        action = normal.sample()
        log_probs = normal.log_prob(action).sum(axis=-1, keepdim=True)
        log_probs -= torch.log(1 - action.pow(2) + 1e-6)
        return log_probs

    def soft_update(self):
        # Polyak averaging of both target critics
        for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for target_param, param in zip(self.target_critic2.parameters(), self.critic2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def train(self, max_step, max_episode):
        theta_values = []
        time_values = []
        for episode in range(max_episode):
            state, _ = self.env.reset()
            for step in range(max_step):
                action = self.select_action(state).detach().numpy()
                action = np.clip(action, -self.action_bound, self.action_bound)
                next_state, reward, done, trunc, info = self.env.step(action)
                self.buffer.record(state, action, reward, next_state, done)

                states, actions, rewards, next_states, dones = self.buffer.sample()
                states = torch.FloatTensor(states)
                actions = torch.FloatTensor(actions)
                rewards = torch.FloatTensor(rewards)
                next_states = torch.FloatTensor(next_states)
                dones = torch.FloatTensor(dones)

                log_probs = self.log_probs(state)

                # critic update
                q1 = self.critic(states, actions)
                q2 = self.critic2(states, actions)
                with torch.no_grad():
                    next_action = self.select_action(next_states)
                    q1_next_target = self.target_critic(next_states, next_action)
                    q2_next_target = self.target_critic2(next_states, next_action)
                    q_next_target = torch.min(q1_next_target, q2_next_target)
                    next_log_probs = self.log_probs(next_states)
                    value_target = rewards + (1 - dones) * self.gamma * (q_next_target - self.alpha * next_log_probs)
                q1_loss = ((q1 - value_target) ** 2).mean()
                q2_loss = ((q2 - value_target) ** 2).mean()
                loss_q = q2_loss + q1_loss
                self.critic_optimizer.zero_grad()
                self.critic2_optimizer.zero_grad()
                loss_q.backward()
                self.critic_optimizer.step()
                self.critic2_optimizer.step()

                # actor update
                self.actor_optimizer.zero_grad()
                actions_pred = self.select_action(states)
                log_pred = self.log_probs(states)
                q1_pred = self.critic(states, actions_pred)
                q2_pred = self.critic2(states, actions_pred)
                q_pred = torch.min(q1_pred, q2_pred)
                actor_loss = (self.alpha * log_pred - q_pred).mean()
                actor_loss.backward()
                self.actor_optimizer.step()

                self.soft_update()

                if done:
                    break
                state = next_state
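For completeness, this is roughly how I run it. The environment is the Gymnasium Pendulum task; the step and episode counts here are just the values I have been testing with, nothing tuned:

import gymnasium as gym

env = gym.make("Pendulum-v1")
agent = Agent(env)
agent.train(max_step=200, max_episode=100)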

submitted by /u/Spiritual_Basket8332