nethack-vae-hmm / training_data.json
CatkinChen's picture
Add training data
6b0ef9c verified
raw
history blame
2.53 kB
{
"train_losses": [
3890.116960449219
],
"test_losses": [
2722.2674072265627
],
"config": {
"hmm_only": false,
"vae_only_with_hmm": false,
"em_rounds": 4,
"m_epochs_per_round": 1,
"hmm_params": {
"alpha": 5.0,
"kappa": 1.0,
"gamma": 5.0,
"K": 40,
"D": 96
},
"niw_params": {
"mu0": 0.0,
"kappa0": 1.0,
"Psi0": 30.0,
"nu0": 106,
"set_mu0_with_global_mean": false,
"set_Psi0_with_global_cov": false
},
"hmm_paths": [
"checkpoints_hmm/hmm_round1.pt",
"checkpoints_hmm/hmm_round2.pt",
"checkpoints_hmm/hmm_round3.pt"
],
"vae_hmm_paths": [
"checkpoints_hmm/vae_with_hmm_round1.pt",
"checkpoints_hmm/vae_with_hmm_round2.pt",
"checkpoints_hmm/vae_with_hmm_round3.pt"
],
"hf_repos": {
"hmm": "CatkinChen/nethack-hmm",
"vae_hmm": "CatkinChen/nethack-vae-hmm"
},
"viz_paths": [
{
"dir": "hmm_analysis/round_01",
"pi_bar": "hmm_analysis/round_01/round01_pi_bar.png",
"A_heatmap": "hmm_analysis/round_01/round01_A_heatmap.png",
"mu_pca": "hmm_analysis/round_01/round01_mu_t-sne.png",
"skill_raster": "hmm_analysis/round_01/round01_skill_raster.png",
"dwell_pmfs": "hmm_analysis/round_01/round01_dwell_pmfs.png",
"diags_json": "hmm_analysis/round_01/round01_diags.json"
},
{
"dir": "hmm_analysis/round_02",
"pi_bar": "hmm_analysis/round_02/round02_pi_bar.png",
"A_heatmap": "hmm_analysis/round_02/round02_A_heatmap.png",
"mu_pca": "hmm_analysis/round_02/round02_mu_t-sne.png",
"skill_raster": "hmm_analysis/round_02/round02_skill_raster.png",
"dwell_pmfs": "hmm_analysis/round_02/round02_dwell_pmfs.png",
"diags_json": "hmm_analysis/round_02/round02_diags.json"
},
{
"dir": "hmm_analysis/round_03",
"pi_bar": "hmm_analysis/round_03/round03_pi_bar.png",
"A_heatmap": "hmm_analysis/round_03/round03_A_heatmap.png",
"mu_pca": "hmm_analysis/round_03/round03_mu_t-sne.png",
"skill_raster": "hmm_analysis/round_03/round03_skill_raster.png",
"dwell_pmfs": "hmm_analysis/round_03/round03_dwell_pmfs.png",
"diags_json": "hmm_analysis/round_03/round03_diags.json"
}
]
},
"final_train_loss": 3890.116960449219,
"final_test_loss": 2722.2674072265627,
"total_epochs": 1,
"best_train_loss": 3890.116960449219,
"best_test_loss": 2722.2674072265627
}