clear plot path every start
This commit is contained in:
@@ -9,6 +9,7 @@ from utils import create_predictions
|
|||||||
|
|
||||||
|
|
||||||
from train import train
|
from train import train
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -49,6 +50,13 @@ if __name__ == '__main__':
|
|||||||
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
|
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
|
||||||
save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
|
save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
|
||||||
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
|
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
|
||||||
|
os.makedirs(plot_path, exist_ok=True)
|
||||||
|
for name in os.listdir(plot_path):
|
||||||
|
p = os.path.join(plot_path, name)
|
||||||
|
if os.path.isfile(p) or os.path.islink(p):
|
||||||
|
os.unlink(p)
|
||||||
|
elif os.path.isdir(p):
|
||||||
|
shutil.rmtree(p)
|
||||||
|
|
||||||
# Comment out, if predictions are required
|
# Comment out, if predictions are required
|
||||||
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20, rmse_value=rmse_value)
|
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20, rmse_value=rmse_value)
|
||||||
|
|||||||
Reference in New Issue
Block a user