"""Training entry point: trains a SAC agent on the photo-enhancement environment,
logs metrics and sample images to TensorBoard, and saves model checkpoints."""

import argparse
import logging
import os
import random
import re
import time
from datetime import datetime
from pathlib import Path

import multiprocessing as mp
import numpy as np
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from envs.photo_env import PhotoEnhancementEnv, PhotoEnhancementEnvTest
from sac.sac_algorithm import SAC
from sac.utils import *  # provides save_actor_head / save_critic_head


def sanitize_filename(name):
    return re.sub(r'[^\w\-_\. ]', '_', name)


def getdatetime():
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


class Config(object):
    """Wraps a dict so its keys are accessible as attributes."""

    def __init__(self, dictionary):
        self.__dict__.update(dictionary)


def make_dirs_and_open(file_path, mode):
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    return open(file_path, mode)


def main():
    current_dir = Path(__file__).parent.absolute()
    parser = argparse.ArgumentParser()
    parser.add_argument('experiment_tag', help='experiment tag')
    parser.add_argument('sac_config', help='YAML sac config file')
    parser.add_argument('env_config', help='YAML env config file')
    parser.add_argument('outdir', nargs='?', type=str,
                        help='directory to put experiment results',
                        default=os.path.join(current_dir.parent, 'experiments/runs'))
    # Note: argparse's type=bool treats any non-empty string as True.
    parser.add_argument('save_model', nargs='?', type=bool, default=True)
    parser.add_argument('--logger_level', type=int, default=logging.INFO)
    args = parser.parse_args()

    # Configure logging to console
    logger = logging.getLogger(__name__)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(args.logger_level)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    logger.setLevel(args.logger_level)

    with open(args.sac_config) as f:
        config_dict = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.env_config) as f:
        env_config_dict = yaml.load(f, Loader=yaml.FullLoader)
    sac_config = Config(config_dict)
    env_config = Config(env_config_dict)

    exp_name = sanitize_filename(sac_config.exp_name)
    exp_tag = sanitize_filename(args.experiment_tag)
    run_name = f"{exp_name}__{exp_tag}__{getdatetime()}"
    run_name = run_name[:255]  # Truncate to 255 characters to avoid issues with very long paths
    run_dir = os.path.join(args.outdir, run_name)

    # Save a copy of both configs alongside the run for reproducibility
    with make_dirs_and_open(os.path.join(run_dir, 'configs/sac_config.yaml'), 'w') as f:
        yaml.dump(config_dict, f, indent=4, default_flow_style=False)
    with make_dirs_and_open(os.path.join(run_dir, 'configs/env_config.yaml'), 'w') as f:
        yaml.dump(env_config_dict, f, indent=4, default_flow_style=False)

    # Seed everything for reproducibility
    SEED = sac_config.seed
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.backends.cudnn.deterministic = sac_config.torch_deterministic
    torch.autograd.set_detect_anomaly(True)

    env = PhotoEnhancementEnv(
        batch_size=env_config.train_batch_size,
        imsize=env_config.imsize,
        training_mode=True,
        done_threshold=env_config.threshold_psnr,
        edit_sliders=env_config.sliders_to_use,
        features_size=env_config.features_size,
        discretize=env_config.discretize,
        discretize_step=env_config.discretize_step,
        use_txt_features=env_config.use_txt_features,
        augment_data=env_config.augment_data,
        pre_encoding_device=env_config.pre_encoding_device,
        pre_load_images=env_config.pre_load_images,
        preprocessor_agent_path=env_config.preprocessor_agent_path,
        logger=None,
    )
    test_env = PhotoEnhancementEnvTest(
        batch_size=env_config.test_batch_size,
        imsize=env_config.imsize,
        training_mode=False,
        done_threshold=env_config.threshold_psnr,
        edit_sliders=env_config.sliders_to_use,
        features_size=env_config.features_size,
        discretize=env_config.discretize,
        discretize_step=env_config.discretize_step,
        use_txt_features=env_config.use_txt_features,
        augment_data=env_config.augment_data,
        pre_encoding_device=env_config.pre_encoding_device,
        pre_load_images=env_config.pre_load_images,
        preprocessor_agent_path=env_config.preprocessor_agent_path,
        logger=None,
    )

    logger.info(f'Sliders used {env.edit_sliders}')
    logger.info(f'Number of sliders used {env.num_parameters}')
    logger.info(f'Sliders used {test_env.edit_sliders}')
    logger.info(f'Number of sliders used {test_env.num_parameters}')

    writer = SummaryWriter(run_dir)
    writer.add_text(
        "SAC_hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(sac_config).items()])),
    )
    writer.add_text(
        "env_parameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(env_config).items()])),
    )

    agent = None
    try:
        agent = SAC(env, sac_config, writer)
        if env_config.preprocessor_agent_path is not None:
            # Double agent mode: share the same preprocessor agent and start
            # from its backbone weights, frozen until the warmup ends.
            test_env.preprocessor_agent = env.preprocessor_agent
            agent.backbone.model.load_state_dict(env.preprocessor_agent.backbone.model.state_dict())
            agent.backbone.eval().requires_grad_(False)

        agent.start_time = time.time()
        logger.info(f'Start Training at {getdatetime()}')

        for i in tqdm(range(sac_config.total_timesteps), position=0, leave=True):
            episode_count = 0
            agent.reset_env()
            envs_mean_rewards = []
            if agent.global_step > env_config.backbone_warmup:
                agent.backbone.train().requires_grad_(True)

            while True:
                episode_count += 1
                agent.global_step += 1
                rewards, batch_dones = agent.train()
                envs_mean_rewards.append(rewards.mean().item())

                if batch_dones.any():
                    num_env_done = int(batch_dones.sum().item())
                    agent.writer.add_scalar("charts/num_env_done", num_env_done, agent.global_step)
                if agent.global_step % 100 == 0:
                    envs_mean_episodic_return = sum(envs_mean_rewards)
                    agent.writer.add_scalar("charts/mean_episodic_return", envs_mean_episodic_return, agent.global_step)
                if batch_dones.all() or episode_count == sac_config.max_episode_timesteps:
                    episode_count = 0
                    break

            # Periodic evaluation on the held-out test environment
            if agent.global_step % 200 == 0:
                agent.backbone.eval().requires_grad_(False)
                agent.actor.eval().requires_grad_(False)
                agent.qf1.eval().requires_grad_(False)
                agent.qf2.eval().requires_grad_(False)
                with torch.no_grad():
                    n_images = 5
                    obs = test_env.reset()
                    actions = agent.actor.get_action(**obs.to(sac_config.device))
                    _, rewards, dones = test_env.step(actions[0])
                    agent.writer.add_scalar("charts/test_mean_episodic_return", rewards.mean().item(), agent.global_step)
                    if env_config.preprocessor_agent_path is not None:
                        agent.writer.add_images("test_images", test_env.original_image[:n_images], 0)
                        agent.writer.add_images("test_images", test_env.state['source_image'][:n_images], 1)
                        agent.writer.add_images("test_images", test_env.state['enhanced_image'][:n_images], 2)
                        agent.writer.add_images("test_images", test_env.state['target_image'][:n_images], 3)
                    else:
                        agent.writer.add_images("test_images", test_env.state['source_image'][:n_images], 0)
                        agent.writer.add_images("test_images", test_env.state['enhanced_image'][:n_images], 1)
                        agent.writer.add_images("test_images", test_env.state['target_image'][:n_images], 2)
                agent.backbone.train().requires_grad_(True)
                agent.actor.train().requires_grad_(True)
                agent.qf1.train().requires_grad_(True)
                agent.qf2.train().requires_grad_(True)

        logger.info(f'Ended training at {getdatetime()}')
        if args.save_model:
            models_dir = os.path.join(run_dir, 'models')
            os.makedirs(models_dir, exist_ok=True)
            logger.info(f"Saving models in {models_dir}")
            torch.save(agent.backbone.state_dict(), os.path.join(models_dir, 'backbone.pth'))
            save_actor_head(agent.actor, os.path.join(models_dir, 'actor_head.pth'))
            save_critic_head(agent.qf1, os.path.join(models_dir, 'qf1_head.pth'))
            save_critic_head(agent.qf2, os.path.join(models_dir, 'qf2_head.pth'))
        writer.close()
    except Exception:
        logger.exception("An error occurred during training")
        # Save a checkpoint only if the agent was created and trained long enough
        if agent is not None and agent.global_step > 1000 and args.save_model:
            models_dir = os.path.join(run_dir, 'models')
            os.makedirs(models_dir, exist_ok=True)
            logger.info(f"Saving models after exception in {models_dir}")
            torch.save(agent.backbone.state_dict(), os.path.join(models_dir, 'backbone.pth'))
            save_actor_head(agent.actor, os.path.join(models_dir, 'actor_head.pth'))
            save_critic_head(agent.qf1, os.path.join(models_dir, 'qf1_head.pth'))
            save_critic_head(agent.qf2, os.path.join(models_dir, 'qf2_head.pth'))
        writer.close()


if __name__ == "__main__":
    main()
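
# Example invocation (script name and file paths below are hypothetical, shown
# only to illustrate the positional arguments parsed above):
#
#   python train_sac.py my_experiment configs/sac_config.yaml configs/env_config.yaml
#
# The SAC YAML is expected to define at least the keys read from sac_config
# (exp_name, seed, torch_deterministic, device, total_timesteps,
# max_episode_timesteps, plus whatever SAC itself consumes), and the env YAML
# the keys read from env_config (train_batch_size, test_batch_size, imsize,
# threshold_psnr, sliders_to_use, features_size, discretize, discretize_step,
# use_txt_features, augment_data, pre_encoding_device, pre_load_images,
# preprocessor_agent_path, backbone_warmup).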