Commit ce1f3544 authored by Matteo Rossi

Full obs fidelity and piecewise pot

parent 7f8c7ed1
@@ -3,7 +3,7 @@ import gym_stirap
import numpy as np
import matplotlib.pyplot as plt
env = gym_stirap.StirapEnv(initial_noise_std=.05)
env = gym_stirap.StirapEnv()
env.seed(11)
obs = []
observation = env.reset()
@@ -11,24 +11,23 @@ obs.append(observation)
#actions = [3] * 200 + [2] * 90 + [4] * 4 + [6] * 120 + [4] * 300
actions = [4] * 21 + [2] * 90 + [4] * 4 + [6] * 120 + [4] * 300
actions_left = [0.] * 21 + [.5] * 90 + [0.] * 4 + [-.5] * 160 + [0.] * 300
actions_right = [0.] * 21 + [-.5] * 90 + [0.] * 4 + [+.5] * 160 + [0.] * 300
actions_left = [0.] * 21 + [.4] * 90 + [0.] * 4 + [-.4] * 160 + [0.] * 300
actions_right = [0.] * 21 + [-.4] * 90 + [0.] * 4 + [+.4] * 160 + [0.] * 300
actions = np.array([actions_left, actions_right]).T
#actions = [4] * env.timesteps
reward = np.zeros(env.timesteps)
for t in range(env.timesteps):
    env.render()
    plt.pause(0.0000001)
    action = actions[t]
    observation, reward[t], done, info = env.step(action)
    observation, reward[t], done, _ = env.step(action)
    obs.append(observation)
    if done:
        print("Episode finished after {} timesteps".format(t+1))
        break
print(info)
print("Score: ", np.sum(reward))
print(obs[-1])
\ No newline at end of file
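As a minimal sketch (not part of the commit) of how the new full observation used in the script above can be unpacked, assuming reset() returns the (1, n, 3) array that update_state() builds later in this diff, with the default n = 512; the variable names are illustrative:

import numpy as np
import gym_stirap

env = gym_stirap.StirapEnv()
env.seed(11)
obs = env.reset()               # expected shape (1, 512, 3)
re_psi = obs[0, :, 0]           # channel 0: Re(psi)
im_psi = obs[0, :, 1]           # channel 1: Im(psi)
potential = obs[0, :, 2]        # channel 2: trapping potential (masked to |x| < 1.9)
density = re_psi**2 + im_psi**2
print(obs.shape, np.trapz(density, dx=env.dx))  # total norm should be close to 1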
@@ -15,15 +15,3 @@ register(
id='stirap-v0',
entry_point='gym_stirap.envs:StirapEnv',
)
register(
id='stirap-fullobs-v0',
entry_point='gym_stirap.envs:StirapEnv',
kwargs={'full_observation': 'True'}
)
register(
id='stirap-finalreward-v0',
entry_point='gym_stirap.envs:StirapEnv',
kwargs={'final_reward': 'True'}
)
\ No newline at end of file
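If a registered variant is still wanted after removing the fullobs/finalreward ids above, one option (a sketch, not part of this commit; the id below is hypothetical) is to pass the new constructor arguments through gym's register kwargs:

from gym.envs.registration import register

register(
    id='stirap-nonlinear-v0',    # hypothetical id, not defined in this repo
    entry_point='gym_stirap.envs:StirapEnv',
    kwargs={'non_linearity': 0.1, 'initial_noise_std': 0.05},  # real floats, not strings like the removed 'True'
)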
@@ -41,10 +41,10 @@ class StirapEnv(gym.Env):
3 Right well position -2 +2
If full_observation=True:
Type: Box(2 * n + 2)
Type: Box(1, n, 3)
where n is the number of space points.
(re(psi), im(psi), left_well_pos, right_well_pos)
(re(psi), im(psi), potential)
Actions:
Type: Box(2)
@@ -53,8 +53,8 @@
1 Move right well of amount -1 +1
Reward:
Reward at each time step is proportional to the population in the right well times t
Reward -10 is given if the episode terminates before hand
Reward at each time step is proportional to the fidelity to the target state
Reward -5 is given if the episode terminates beforehand
Starting State:
The system is initially in the ground state of the left well.
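Concretely, the per-step reward implemented in reward() below is
    reward = (P_right - P_left) + 2 * |∫ psi_target(x) psi(x) dx|^2
with psi_target the ground state of the right well; an additional -5 is applied when a well is moved out of its allowed range.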
@@ -69,11 +69,7 @@ class StirapEnv(gym.Env):
metadata = {'render.modes': ['human', 'rgb_array']}
def __init__(self,
full_observation=False,
final_reward=False,
initial_noise_std=None,
non_linear=0.0):
def __init__(self, non_linearity=0.0, initial_noise_std=0.1):
# Simulation parameters
self.n = 512 # Number of points in space
self.xlim = 2
@@ -87,7 +83,7 @@
self.dt = self.time[1] - self.time[0]
self.g = non_linear
self.g = non_linearity
self.mass = 2.
self.hbar = 1.
@@ -112,6 +108,8 @@
# Strength of the potential
self.trap_strength = 2.e3
self.omega = np.sqrt(self.trap_strength / self.mass)
self.solver = None
# Definition of the wells
@@ -119,12 +117,11 @@
self.left_well = np.where(self.x < self.left_center + 2 * self.well_size / 3, 1, 0)
self.right_well = np.where(self.x > self.right_center - 2 * self.well_size /3, 1, 0)
# Initial state of the system
# Initial state is the ground state of the left trap
self.omega = np.sqrt(self.trap_strength/self.mass)
self.initial_noise_std = initial_noise_std
# Target state is the ground state of the right well
self.target_psi = gauss_x(self.x, np.sqrt(self.hbar / (self.mass * self.omega)), self.right_center, 0)
self.it = None
self.il = None
self.ir = None
@@ -132,17 +129,10 @@
# RL parameters
self.Delta = 0.02 # How much to change the well positions at each step
self.full_observation = full_observation
self.final_reward = final_reward
self.action_space = spaces.Box(low=-1., high=1., shape=(2,), dtype=np.float32)
if self.full_observation:
self.observation_space = spaces.Box(low=-np.inf,
high=np.inf,shape=(2*self.n + 2,), dtype=np.float32)
else:
self.observation_space = spaces.Box(low=np.array([0, 0, - self.xlim, - self.xlim, 0]),
high=np.array([1, 1, self.xlim, self.xlim, self.totaltime]), dtype=np.float32)
high=np.inf, shape=(1, self.n, 3), dtype=np.float32)
# For the rendering
self.viewer = None
@@ -154,16 +144,14 @@
def reset(self):
"""Reinitialize the system to the initial conditions"""
if self.initial_noise_std is not None:
xn = self.initial_noise_std * self.np_random.randn()
print(xn)
self.psi0 = gauss_x(self.x, np.sqrt(self.hbar / (self.mass * self.omega)), self.left_center + xn, 0)
else:
self.psi0 = gauss_x(self.x, np.sqrt(self.hbar / (self.mass * self.omega)), self.left_center, 0)
xn = self.initial_noise_std * self.np_random.randn(2)
self.il = self.left_center + xn[0]
self.ir = self.right_center + xn[1]
self.psi0 = gauss_x(self.x, np.sqrt(self.hbar / (self.mass * self.omega)), self.il, 0)
self.psi = self.psi0.copy() # We need a deep copy
self.il = self.left_center
self.ir = self.right_center
self.it = 0
@@ -173,6 +161,11 @@
self.solver = Schrodinger(self.x, self.psi0, self.potential(),
m=self.mass, hbar=self.hbar, g=self.g)
# Reset the viewer
if self.viewer is not None:
plt.close(self.viewer)
self.viewer = None
return self.state
def seed(self, seed=None):
@@ -220,23 +213,21 @@
# We interrupt the simulation if the traps are moved too far away
# We punish the agent for moving the traps too far away, or for crossing them
if self.il < (- self.xlim + 1 * self.well_size ) or self.il > 0:
reward -= 10
reward -= 5
done = True
if self.ir < 0 or (self.ir > self.xlim - 1 * self.well_size):
reward -= 10
reward -= 5
done = True
return self.state, reward, done, self.psi
return self.state, reward, done, {}
def update_state(self):
""" Prepares the state to be passed to the agent"""
if self.full_observation:
# If full_observation, pass the real and imaginary part of the wave function, the positions of the wells
# and the time
self.state = np.append(np.real(self.psi), np.append(np.imag(self.psi), (self.il, self.ir)))
else:
# This is the state that the agent sees (left/right population, position of the wells)
self.state = np.array(self.evaluate_populations() + (self.il, self.ir, self.dt * self.it))
pot = self.potential() * np.where(np.abs(self.x) < 1.9, 1, 0)
self.state = np.vstack([np.real(self.psi), np.imag(self.psi), pot])
self.state = np.expand_dims(self.state.T, axis=0) # Add singleton dimension
return self.state
def render(self, mode='human', timeevolution=True):
@@ -283,9 +274,10 @@
self.axins.patch.set_alpha(0.5)
self.axins.set_xlabel('t')
self.leftpop_line, = self.axins.plot(self.dt * self.it, self.evaluate_populations()[0])
self.rightpop_line, = self.axins.plot(self.dt * self.it, self.evaluate_populations()[1])
self.centerpop_line, = self.axins.plot(self.dt * self.it, 1 - np.sum(self.evaluate_populations()))
self.reward_line, = self.axins.plot(self.dt * self.it, self.reward())
#self.leftpop_line, = self.axins.plot(self.dt * self.it, self.evaluate_populations()[0])
#self.rightpop_line, = self.axins.plot(self.dt * self.it, self.evaluate_populations()[1])
#self.centerpop_line, = self.axins.plot(self.dt * self.it, 1 - np.sum(self.evaluate_populations()))
self.rightwell, = self.axins.plot(self.dt * self.it, self.ir, 'k--')
self.leftwell, = self.axins.plot(self.dt * self.it, self.il, 'k--')
@@ -300,18 +292,21 @@
self.probability_line.set_ydata(ys)
self.potential_line.set_ydata(0.5*self.dx * self.potential())
if timeevolution:
self.rightpop_line.set_xdata(np.append(self.rightpop_line.get_xdata(),self.dt * self.it))
self.rightpop_line.set_ydata(np.append(self.rightpop_line.get_ydata(),self.evaluate_populations()[1]))
#self.rightpop_line.set_xdata(np.append(self.rightpop_line.get_xdata(), self.dt * self.it))
#self.rightpop_line.set_ydata(np.append(self.rightpop_line.get_ydata(), self.evaluate_populations()[1]))
self.leftpop_line.set_xdata(np.append(self.leftpop_line.get_xdata(),self.dt * self.it))
self.leftpop_line.set_ydata(np.append(self.leftpop_line.get_ydata(),self.evaluate_populations()[0]))
#self.leftpop_line.set_xdata(np.append(self.leftpop_line.get_xdata(), self.dt * self.it))
#self.leftpop_line.set_ydata(np.append(self.leftpop_line.get_ydata(), self.evaluate_populations()[0]))
self.centerpop_line.set_xdata(np.append(self.centerpop_line.get_xdata(),self.dt * self.it))
self.centerpop_line.set_ydata(np.append(self.centerpop_line.get_ydata(),1 - np.sum(self.evaluate_populations())))
#self.centerpop_line.set_xdata(np.append(self.centerpop_line.get_xdata(), self.dt * self.it))
#self.centerpop_line.set_ydata(np.append(self.centerpop_line.get_ydata(), 1 - np.sum(self.evaluate_populations())))
self.rightwell.set_xdata(np.append(self.rightwell.get_xdata(),self.dt * self.it))
self.rightwell.set_ydata(np.append(self.rightwell.get_ydata(),self.ir))
self.reward_line.set_xdata(np.append(self.reward_line.get_xdata(),self.dt * self.it))
self.reward_line.set_ydata(np.append(self.reward_line.get_ydata(), self.reward()))
self.leftwell.set_xdata(np.append(self.leftwell.get_xdata(),self.dt * self.it))
self.leftwell.set_ydata(np.append(self.leftwell.get_ydata(),self.il))
@@ -328,18 +323,11 @@
return image
def reward(self):
""" Define the reward """
reward = 0.
# We reward for the population in the last 90 % of time
if self.final_reward:
if self.it==self.timesteps:
(pop_left, pop_right) = self.evaluate_populations()
reward = - np.log10(1 - pop_right) #* (self.it > .9*self.timesteps)
else:
(pop_left, pop_right) = self.evaluate_populations()
#reward = 2 * (pop_right * self.it / self.timesteps) ** 2
reward = - np.log10(1 - pop_right)
return reward
""" Returns the overlap with the target wavefunction """
left, right = self.evaluate_populations()
fidelity = np.abs(np.trapz(self.target_psi * self.psi, self.x))**2
return right - left + 2 * fidelity
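A quick sanity check of the overlap term above (a sketch, not part of the commit; the grid and width are illustrative): for a real, normalized target state the integral is ~1 when psi coincides with target_psi, so the reward saturates near right - left + 2.

import numpy as np

x = np.linspace(-2, 2, 512)
sigma = 0.05                                               # illustrative width
target = (np.pi * sigma**2)**-0.25 * np.exp(-(x - 1.0)**2 / (2 * sigma**2))
print(np.abs(np.trapz(target * target, x))**2)             # ~1 for psi == target_psi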
# These are auxiliary functions for the physical system
@@ -349,12 +337,29 @@
def evaluate_populations(self):
""" Returns the integral of |ψ|^2 in the left and right well"""
left = np.sum(np.abs(self.psi * self.left_well)**2) * self.dx
right = np.sum(np.abs(self.psi * self.right_well)**2) * self.dx
return (left, right)
def potential(self):
""" Returns the trapping potential """
v = 0.5 * self.trap_strength * ((self.x-self.il)**2) * (self.x ** 2) * ((self.x-self.ir)**2)
# A piecewise potential
v = (0.5 * self.trap_strength *
     np.piecewise(self.x,
                  [(self.x <= self.il/2) & (self.x >= -1.9),
                   (self.x < self.ir/2) & (self.x > self.il/2),
                   (self.x >= self.ir/2) & (self.x <= 1.9)],
                  [lambda x: (x - self.il)**2,
                   lambda x: x**2,
                   lambda x: (x - self.ir)**2,
                   lambda x: x**6]))
# The last term makes sure that the derivative of the potential
# at the borders is high enough that the wave function is
# near zero at the border.
return v
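To see what the new piecewise potential looks like, a standalone sketch (not part of the commit) with illustrative well positions and the trap_strength used above:

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-2, 2, 512)
il, ir, k = -1.0, 1.0, 2.e3      # illustrative well positions; trap strength from above
v = 0.5 * k * np.piecewise(
    x,
    [(x <= il/2) & (x >= -1.9),
     (x < ir/2) & (x > il/2),
     (x >= ir/2) & (x <= 1.9)],
    [lambda x: (x - il)**2, lambda x: x**2, lambda x: (x - ir)**2, lambda x: x**6])
plt.plot(x, v)
plt.xlabel('x')
plt.ylabel('V(x)')
plt.show()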
@@ -5,9 +5,8 @@ import warnings
import pytest
@pytest.mark.parametrize("name",
["stirap-v0",
"stirap-fullobs-v0",
"stirap-finalreward-v0"])
["stirap-v0"])
def test_make(name):
env = gym.make(name)
assert env is not None
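A possible extra check (a sketch, not part of the commit), assuming reset() returns the full observation built by update_state() with the default n = 512; gym is already imported in this test module:

def test_observation_shape():
    env = gym.make("stirap-v0")
    obs = env.reset()
    assert obs.shape == (1, 512, 3)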