From 63e4efda03cc10834d276960d1161ca81690bb8d Mon Sep 17 00:00:00 2001 From: hjorvardr Date: Wed, 28 Nov 2018 15:19:00 +0100 Subject: [PATCH] Fixed sarsa, update networks --- Ensembling/ensembling_mountaincar_mixed.py | 28 ++++++++++-------- FrozenLake/SavedNetworks/sarsa_4x4_policy.npy | Bin 256 -> 256 bytes FrozenLake/SavedNetworks/sarsa_8x8_policy.npy | Bin 640 -> 640 bytes .../SavedNetworks/sarsa_8x8d_policy.npy | Bin 640 -> 640 bytes FrozenLake/ql_4x4.py | 4 +-- FrozenLake/ql_4x4_deterministic.py | 4 +-- FrozenLake/ql_8x8.py | 4 +-- FrozenLake/ql_8x8_deterministic.py | 4 +-- FrozenLake/sarsa_4x4.py | 6 ++-- FrozenLake/sarsa_4x4_deterministic.py | 6 ++-- FrozenLake/sarsa_8x8.py | 6 ++-- FrozenLake/sarsa_8x8_deterministic.py | 6 ++-- MountainCar/SavedNetworks/ql_policy.npy | Bin 180128 -> 180128 bytes MountainCar/dqn_mountain_car.py | 16 +++++----- MountainCar/ql_mountain_car.py | 4 +-- MountainCar/sarsa_mountain_car.py | 4 +-- ReinforcementLearningLib/sarsa_lib.py | 24 ++++++++++----- Taxi/SavedNetworks/sarsa_policy.npy | Bin 4128 -> 4128 bytes Taxi/ql_taxi.py | 4 +-- Taxi/sarsa_taxi.py | 8 ++--- 20 files changed, 69 insertions(+), 59 deletions(-) diff --git a/Ensembling/ensembling_mountaincar_mixed.py b/Ensembling/ensembling_mountaincar_mixed.py index 7eab2ce8..a90e7f4b 100644 --- a/Ensembling/ensembling_mountaincar_mixed.py +++ b/Ensembling/ensembling_mountaincar_mixed.py @@ -48,20 +48,20 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False): layer1 = Dense(15, input_dim=input_dim, activation='relu') layer2 = Dense(output_dim) - agent1 = DQNAgent(output_dim, [layer1, layer2], use_ddqn=True, learn_thresh=1000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.01, optimizer=keras.optimizers.RMSprop(0.001), tb_dir=None) - agent2 = QLAgent([n_states, n_states, env.action_space.n], epsilon_decay_function=lambda e: e * 0.6, epsilon_lower_bound=0.1) - agent3 = SARSAAgent([n_states, n_states, env.action_space.n], epsilon_decay_function=lambda e: e * 0.6, epsilon_lower_bound=0.1) + agent1 = DQNAgent(output_dim, [layer1, layer2], use_ddqn=True, learn_thresh=1000, update_rate=300, epsilon_decay_function=lambda e: e - 0.001, epsilon_lower_bound=0.01, optimizer=keras.optimizers.RMSprop(0.001), tb_dir=None) + #agent2 = QLAgent([n_states, n_states, env.action_space.n], epsilon_decay_function=lambda e: e - 0.001, epsilon_lower_bound=0.01) + #agent3 = SARSAAgent([n_states, n_states, env.action_space.n], epsilon_decay_function=lambda e: e - 0.001, epsilon_lower_bound=0.01) + agent4 = DQNAgent(output_dim, [layer1, layer2], use_ddqn=False, learn_thresh=1000, update_rate=300, epsilon_decay_function=lambda e: e - 0.001, epsilon_lower_bound=0.01, optimizer=keras.optimizers.RMSprop(0.001), tb_dir=None) - agents = [agent1, agent2, agent3] - agentE = EnsemblerAgent(env.action_space.n, agents, EnsemblerType.MAJOR_VOTING_BASED) + agents = [agent1, agent4] + agentE = EnsemblerAgent(env.action_space.n, agents, EnsemblerType.TRUST_BASED) evaluate = False for i_episode in tqdm(range(n_episodes + 1), desc="Episode"): state = env.reset() - if (i_episode % 100) == 0: - agent3.extract_policy() + # agent3.extract_policy() discretized_state = obs_to_state(env, state, n_states) cumulative_reward = 0 @@ -88,13 +88,13 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False): r2 = reward + np.sin(3 * original_state[0]) r3 = reward + (original_state[1] * original_state[1]) r4 = abs(new_state[0] - (-0.5)) # r in [0, 1] - reward = r4 new_state = 
np.reshape(new_state, [1, 2])
 
-            agent1.memoise((state, next_action, reward, new_state, end))
-            agent2.update_q((discretized_state[0], discretized_state[1]), (new_discretized_state[0], new_discretized_state[1]), next_action, reward)
-            agent3.update_q((discretized_state[0], discretized_state[1]), (new_discretized_state[0], new_discretized_state[1]), next_action, reward)
+            agent1.memoise((state, next_action, r4, new_state, end))
+            #agent2.update_q((discretized_state[0], discretized_state[1]), (new_discretized_state[0], new_discretized_state[1]), next_action, reward)
+            #agent3.update_q((discretized_state[0], discretized_state[1]), (new_discretized_state[0], new_discretized_state[1]), next_action, reward)
+            agent4.memoise((state, next_action, r4, new_state, end))
 
             if end:
@@ -112,6 +112,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
                 cumulative_reward += reward
 
         agent1.learn()
+        agent4.learn()
 
         cumulative_reward += reward
         scores.append(cumulative_reward)
@@ -136,6 +137,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
             new_state, reward, end, _ = env.step(next_action)
             new_discretized_state = obs_to_state(env, new_state, n_states)
             original_state = new_state
+            new_state = np.reshape(new_state, [1, 2])
 
             if end:
@@ -164,11 +166,11 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
 
 
 # Training
-train_res = experiment(505)
+train_res = experiment(200)
 training_mean_steps = train_res["steps"].mean()
 training_mean_score = train_res["scores"].mean()
 
-np.savetxt("results/ens_mixed_major.csv", train_res["steps"], delimiter=',')
+np.savetxt("results/ens_mixed_trust_cont.csv", train_res["steps"], delimiter=',')
 print("Training episodes:", len(train_res["steps"]), "Training mean score:", training_mean_score, \
       "Training mean steps", training_mean_steps)
diff --git a/FrozenLake/SavedNetworks/sarsa_4x4_policy.npy b/FrozenLake/SavedNetworks/sarsa_4x4_policy.npy
index fc6f19728899f26e86f39c2f95677acd2177854a..a4f8cefcc254b641a4dd35f303e9cd5082a45aa5 100644
GIT binary patch
delta 20
ZcmZo*YG9f$fsuLQMEQviI6&xt0{}|s2kig=

delta 16
WcmZo*YG9f$fstwAM0p^w(H{UP{sjF1

diff --git a/FrozenLake/SavedNetworks/sarsa_8x8_policy.npy b/FrozenLake/SavedNetworks/sarsa_8x8_policy.npy
index a944cfe6c1ae05d1ce5c5301b3d1bd8491e736b0..4be6acf4c61f8f672386ed51f636a3636b60bd66 100644
GIT binary patch
delta 119
zcmZo*ZD5_yG4Y1LwQsVuT3-MJzy~lM{djL1ZVs5C8y`x+Nt5

delta 129
zcmZo*ZD5_yG1-GrV&Vdh$txHoCL1tvfN=w(z{Ck0lM9#xCNXkMV$_(pL15wuiOB*$
xX#pmV$pK6PlNEsM1B@IKJ0QX!E=bn`j)@z9ihxEUQy^^?P%Q!zX9z$fIRIwzBs>5B

diff --git a/FrozenLake/SavedNetworks/sarsa_8x8d_policy.npy b/FrozenLake/SavedNetworks/sarsa_8x8d_policy.npy
index 4738ebd8d55189b865249c8a9b08cc47c80c6189..a53d03bd81b413074c0c96eee1e66b360798e404 100644
GIT binary patch
delta 70
zcmZo*ZD5^nV6p?F#>5E*lbARrFJKgyY`|zRc><6w0OA9T9Ft!F=?g$w0Z4CP6qv-M
Q093{SR-(YjF^Q=G02E^tBme*a

delta 68
zcmZo*ZD5^nVB!Rci3d0)K5&@4fKgzw0i(gh2ON_RFmg;TU=)~mKwxqLP*eekHvn-1
PPu>+(@BYaD$v^)W|LiwE

[Editor's note: the patch is corrupted from here to the MountainCar/dqn_mountain_car.py hunks. The text hunks for the eight FrozenLake ql_* and sarsa_* scripts listed in the diffstat, and the base85 payload of the MountainCar/SavedNetworks/ql_policy.npy binary patch (literal 180128), are not recoverable and have been omitted.]
zKsuzq3)+8opVdC^p8Q$IVeYr8o39^#@VUSzUm(AKyuE3vw?4YRA32bo*y{3N_9bsQ z=}Vq{IWOzkKXLtX@4{EjxoiI?_qo8&7x4Ggt}%V*tv~zHr%uNc)uDRhbm(39QEYQl z&wTTnkIr`$wmq)Oe{!A+e8mNP*Q>MrePthAPS$bAj~~UBlR6H~OP;#Acin~NLUZv% zd(d%xFm?RF=K|mH0>1nFyV!lJ{=0B>AMb(FKIu!YKXv=4H{^TpxxjOQ`!C?{(BJDc zkuUr3vzYnvnr`2$ 0 and (i_episode % 100) == 0 and not default_policy: - agent.save_model("43-tmp_model") - evaluation_result = experiment(500, default_policy=True, policy="43-tmp_model") + agent.save_model("tmp_model") + evaluation_result = experiment(500, default_policy=True, policy="tmp_model") acc = accuracy(evaluation_result["results"]) if acc == 100: break @@ -66,7 +66,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False, agen next_action = agent.act(state) new_state, reward, end, _ = env.step(next_action) - # reward = abs(new_state[0] - (-0.5)) # r in [0, 1] + reward = abs(new_state[0] - (-0.5)) # r in [0, 1] new_state = np.reshape(new_state, [1, 2]) agent.memoise((state, next_action, reward, new_state, end)) @@ -133,7 +133,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False, agen # experiments.append(("model20", 10000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=1000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.1, optimizer=keras.optimizers.RMSprop(0.001), tb_dir=None))) # experiments.append(("model21", 10000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=2000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.1, optimizer=keras.optimizers.RMSprop(0.001), tb_dir=None))) # experiments.append(("model22", 10000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=3000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.1, optimizer=keras.optimizers.RMSprop(0.001), tb_dir=None))) -experiments.append(("43-model23", 25000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=1000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.01, optimizer=keras.optimizers.Adam(0.001), tb_dir=None))) +experiments.append(("model23", 25000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=1000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.01, optimizer=keras.optimizers.Adam(0.001), tb_dir=None))) # experiments.append(("model24", 10000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=2000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.1, optimizer=keras.optimizers.Adam(0.001), tb_dir=None))) # experiments.append(("model25", 10000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=3000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.1, optimizer=keras.optimizers.Adam(0.001), tb_dir=None))) # experiments.append(("model26", 10000, DQNAgent(output_dim, layers, use_ddqn=True, learn_thresh=4000, update_rate=300, epsilon_decay_function=lambda e: e * 0.995, epsilon_lower_bound=0.1, optimizer=keras.optimizers.Adam(0.001), tb_dir=None))) @@ -196,7 +196,7 @@ def train_and_test(experiments): df.loc[len(df)] = [model_name, len(train_res["steps"]), training_mean_score, training_mean_steps, testing_accuracy, testing_mean_score, testing_mean_steps] - df.to_csv('43-experiments.csv') + df.to_csv('experiments.csv') def main(): train_and_test(experiments) diff --git a/MountainCar/ql_mountain_car.py b/MountainCar/ql_mountain_car.py index 2f959d4b..bb909888 100644 --- a/MountainCar/ql_mountain_car.py +++ 
@@ -36,7 +36,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
     else:
         agent = QLAgent([n_states, n_states, env.action_space.n], epsilon_decay_function=lambda e: e * 0.6, epsilon_lower_bound=0.1)
 
-    for i_episode in tqdm(range(n_episodes), desc="Episode"):
+    for _ in tqdm(range(n_episodes), desc="Episode"):
         state = env.reset()
         state = obs_to_state(env, state, n_states)
         cumulative_reward = 0
@@ -45,7 +45,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
             if (render):
                 env.render()
 
-            next_action = agent.act((state[0], state[1]), i_episode)
+            next_action = agent.act((state[0], state[1]))
             new_state, reward, end, _ = env.step(next_action)
             new_state = obs_to_state(env, new_state, n_states)
             if policy is None:
diff --git a/MountainCar/sarsa_mountain_car.py b/MountainCar/sarsa_mountain_car.py
index d7dc06d1..8099a427 100644
--- a/MountainCar/sarsa_mountain_car.py
+++ b/MountainCar/sarsa_mountain_car.py
@@ -36,7 +36,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
     else:
         agent = SARSAAgent([n_states, n_states, env.action_space.n], epsilon_decay_function=lambda e: e * 0.6, epsilon_lower_bound=0.1)
 
-    for i_episode in tqdm(range(n_episodes), desc="Episode"):
+    for _ in tqdm(range(n_episodes), desc="Episode"):
         state = env.reset()
         state = obs_to_state(env, state, n_states)
         cumulative_reward = 0
@@ -47,7 +47,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
             if (render):
                 env.render()
 
-            next_action = agent.act((state[0], state[1]), i_episode)
+            next_action = agent.act((state[0], state[1]))
             new_state, reward, end, _ = env.step(next_action)
             new_state = obs_to_state(env, new_state, n_states)
             if policy is None:
diff --git a/ReinforcementLearningLib/sarsa_lib.py b/ReinforcementLearningLib/sarsa_lib.py
index 4ce1dbe9..126bc4a9 100644
--- a/ReinforcementLearningLib/sarsa_lib.py
+++ b/ReinforcementLearningLib/sarsa_lib.py
@@ -5,20 +5,26 @@ class SARSAAgent(QLAgent):
 
     def __init__(self, shape, alpha=0.8, gamma=0.95, policy=None, epsilon=1,
-                 epsilon_lower_bound=0.01, epsilon_decay_function=lambda e: e * 0.6):
+                 epsilon_lower_bound=0.01, epsilon_decay_function=lambda e: e * 0.6, update_rate=100):
         super().__init__(shape, alpha, gamma, policy, epsilon, epsilon_lower_bound, epsilon_decay_function)
         self.current_policy = None
         if policy is not None:
             self.current_policy = policy
         self.shape = shape
+        self.update_rate = update_rate
+        self.Q_target = None
+        self.total_episodes = 0
 
     def extract_policy(self):
-        policy_shape = self.shape
-        policy_shape = policy_shape[:-1]
-        self.current_policy = np.zeros(policy_shape, dtype=int)
-        for idx, _ in np.ndenumerate(self.current_policy):
-            self.current_policy[idx] = self.next_action(self.Q[idx])
+        if (self.total_episodes % self.update_rate) == 0:
+            policy_shape = self.shape
+            policy_shape = policy_shape[:-1]
+            self.current_policy = np.zeros(policy_shape, dtype=int)
+            for idx, _ in np.ndenumerate(self.current_policy):
+                self.current_policy[idx] = self.next_action(self.Q[idx])
+            self.Q_target = self.Q
+        self.total_episodes += 1
 
     def update_q(self, state, new_state, action, reward):
         """
@@ -34,7 +40,7 @@ def update_q(self, state, new_state, action, reward):
         next_action = self.current_policy[new_state]
         self.Q[state][action] = (1 - self.alpha) * self.Q[state][action] + self.alpha * (reward + self.gamma * self.Q[new_state][next_action])
 
-    def act(self, state, episode_number=None): # TODO: check episode_number
+    def act(self, state, return_prob_dist=False): # TODO: check episode_number
         if (self.policy is not None):
             next_action = self.policy[state]
         else:
@@ -46,4 +52,6 @@ def act(self, state, episode_number=None): # TODO: check episode_number
             else:
                 next_action = np.argmax(np.random.uniform(0, 1, size=self.actions))
 
-        return next_action
\ No newline at end of file
+        if not return_prob_dist:
+            return next_action
+        return next_action, self.Q_target[state]
\ No newline at end of file
diff --git a/Taxi/SavedNetworks/sarsa_policy.npy b/Taxi/SavedNetworks/sarsa_policy.npy
index 7fc7a6c78e2f0d9b9a5bb1ae038084e82cf0debc..b5c2331b3df90f9ca36f939d3bcde0132719378b 100644
GIT binary patch
literal 4128
[Editor's note: base85 payload omitted for readability; binary policy data.]

literal 4128
[Editor's note: base85 payload omitted for readability; binary policy data.]

diff --git a/Taxi/ql_taxi.py b/Taxi/ql_taxi.py
index 634b8a17..0cf22d46 100644
--- a/Taxi/ql_taxi.py
+++ b/Taxi/ql_taxi.py
@@ -25,7 +25,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
     else:
         agent = QLAgent([env.observation_space.n, env.action_space.n])
 
-    for i_episode in tqdm(range(n_episodes)):
+    for _ in tqdm(range(n_episodes)):
         state = env.reset()
         cumulative_reward = 0
 
@@ -34,7 +34,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
             env.render()
             time.sleep(1)
 
-        next_action = agent.act(state, i_episode)
+        next_action = agent.act(state)
         new_state, reward, end, _ = env.step(next_action)
         if policy is None:
             agent.update_q(state, new_state, next_action, reward)
diff --git a/Taxi/sarsa_taxi.py b/Taxi/sarsa_taxi.py
index cec7296d..4f8e65df 100644
--- a/Taxi/sarsa_taxi.py
+++ b/Taxi/sarsa_taxi.py
@@ -23,9 +23,9 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
     if (default_policy):
         agent = SARSAAgent([env.observation_space.n, env.action_space.n], policy=policy)
     else:
-        agent = SARSAAgent([env.observation_space.n, env.action_space.n], epsilon_decay_function=lambda e: e - 0.000016)
+        agent = SARSAAgent([env.observation_space.n, env.action_space.n], epsilon_decay_function=lambda e: e - 0.000016, update_rate=10)
 
-    for i_episode in tqdm(range(n_episodes)):
+    for _ in tqdm(range(n_episodes)):
         state = env.reset()
         cumulative_reward = 0
         if not default_policy:
@@ -36,7 +36,7 @@ def experiment(n_episodes, default_policy=False, policy=None, render=False):
             env.render()
             time.sleep(1)
 
-        next_action = agent.act(state, i_episode)
+        next_action = agent.act(state)
         new_state, reward, end, _ = env.step(next_action)
         if policy is None:
             agent.update_q(state, new_state, next_action, reward)
@@ -61,7 +61,7 @@ def 
experiment(n_episodes, default_policy=False, policy=None, render=False): # Training -train_res = experiment(30000) +train_res = experiment(10000) learnt_policy = np.argmax(train_res["Q"], axis=1) # print("Policy learnt: ", learnt_policy) training_mean_steps = train_res["steps"].mean()
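[Editor's aside: the ReinforcementLearningLib/sarsa_lib.py hunks above give SARSAAgent a greedy target policy that is refreshed only every update_rate episodes; extract_policy is called once per episode, e.g. with update_rate=10 in sarsa_taxi.py. The following simplified sketch of that mechanism is illustrative, not repository code, and np.argmax stands in for the agent's next_action selection.]

import numpy as np

class TargetPolicySketch:
    """Illustrative stand-in for the patched SARSAAgent (not repository code)."""

    def __init__(self, shape, update_rate=100):
        self.shape = shape             # e.g. [n_states, n_actions]
        self.Q = np.zeros(shape)       # live Q-table, updated every step
        self.Q_target = None           # snapshot served by act(..., return_prob_dist=True)
        self.current_policy = None     # greedy policy read by update_q
        self.update_rate = update_rate
        self.total_episodes = 0

    def extract_policy(self):
        # Called once per episode; the policy and the Q snapshot refresh
        # only every `update_rate` calls, as in the patched extract_policy.
        if (self.total_episodes % self.update_rate) == 0:
            self.current_policy = np.argmax(self.Q, axis=-1)
            self.Q_target = self.Q.copy()   # a frozen snapshot (see note below)
        self.total_episodes += 1

agent = TargetPolicySketch([500, 6], update_rate=10)   # Taxi-v2-like dimensions
for episode in range(100):
    agent.extract_policy()   # recomputes on episodes 0, 10, 20, ...

[One subtlety in the diff itself: extract_policy assigns self.Q_target = self.Q, which aliases the live Q-table rather than freezing it, so Q_target tracks every subsequent update; self.Q.copy(), as sketched above, would be needed if Q_target is meant to lag behind the ongoing updates.]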