improve baseline
This commit is contained in:
@@ -81,7 +81,7 @@ def read_compressed_file(file_path: str):
|
||||
return input_arrays, known_arrays
|
||||
|
||||
|
||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20):
|
||||
def create_predictions(model_config, state_dict_path, testset_path, device, save_path, plot_path, plot_at=20, rmse_value=None):
|
||||
"""
|
||||
Here, one might needs to adjust the code based on the used preprocessing
|
||||
"""
|
||||
@@ -128,6 +128,11 @@ def create_predictions(model_config, state_dict_path, testset_path, device, save
|
||||
"predictions": predictions
|
||||
}
|
||||
|
||||
# Modify save_path to include RMSE value if provided
|
||||
if rmse_value is not None:
|
||||
base_path = save_path.rsplit('.npz', 1)[0]
|
||||
save_path = f"{base_path}-{rmse_value:.4f}.npz"
|
||||
|
||||
np.savez_compressed(save_path, **data)
|
||||
|
||||
print(f"Predictions saved at {save_path}")
|
||||
|
||||
Reference in New Issue
Block a user