improve baseline

This commit is contained in:
2026-01-24 16:15:34 +01:00
parent 3b1d3c0497
commit 57026695d4
10 changed files with 228 additions and 63 deletions

View File

@@ -81,7 +81,7 @@ def read_compressed_file(file_path: str):
return input_arrays, known_arrays
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20):
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
"""
Here, one might needs to adjust the code based on the used preprocessing
"""
@@ -128,6 +128,11 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
"predictions": predictions
}
# Modify save_path to include RMSE value if provided
if rmse_value is not None:
base_path = save_path.rsplit('.npz', 1)[0]
save_path = f"{base_path}-{rmse_value:.4f}.npz"
np.savez_compressed(save_path, **data)
print(f"Predictions saved at {save_path}")