clear plot path every start

This commit is contained in:
2026-01-24 17:29:58 +01:00
parent 16cdc2ed4e
commit 8f0fb11926

View File

@@ -9,6 +9,7 @@ from utils import create_predictions
from train import train
import shutil
if __name__ == '__main__':
@@ -49,6 +50,13 @@ if __name__ == '__main__':
state_dict_path = os.path.join(config_dict['results_path'], "best_model.pt")
save_path = os.path.join(config_dict['results_path'], "testset", "tikaiz")
plot_path = os.path.join(config_dict['results_path'], "testset", "plots")
os.makedirs(plot_path, exist_ok=True)
for name in os.listdir(plot_path):
p = os.path.join(plot_path, name)
if os.path.isfile(p) or os.path.islink(p):
os.unlink(p)
elif os.path.isdir(p):
shutil.rmtree(p)
# Comment out, if predictions are required
create_predictions(config_dict['network_config'], state_dict_path, testset_path, None, save_path, plot_path, plot_at=20, rmse_value=rmse_value)