"""
Current Assumptions for Synth and ERA5 forecast
- Both forecasts are downloaded for the same regions (0.25 degree resolution), same altitude band (config)
- Synth Forecast is 1 month @ 12 hour intervals
- ERA5 Forecast is 6 months @ 3 hour intervals
"""
import xarray as xr
import numpy as np
from env.config.env_config import env_params
from utils import constants
from utils.common import convert_range, quarter
from utils import CoordinateTransformations as transform
from line_profiler import profile
from termcolor import colored
import pandas as pd
class Forecast:
    """
    Loads a full ERA5 or synthetic forecast into memory.

    This class handles large-scale climate forecasts used for simulations, supporting operations
    like subsetting, time adjustments, and aligning forecasts for simulating.

    Download from https://cds.climate.copernicus.eu/cdsapp#!/dataset/reanalysis-era5-pressure-levels?tab=form

    Attributes:
        forecast_type (str): Type of forecast ('SYNTH' or 'ERA5').
        ds_original (xarray.Dataset): Forecast dataset after loading and preprocessing.
        LAT_MIN, LAT_MAX (float): First and last latitude of the forecast grid.
        LON_MIN, LON_MAX (float): First and last longitude of the forecast grid.
        LEVEL_MIN, LEVEL_MAX (float): First and last pressure level of the forecast grid.
        TIME_MIN, TIME_MAX (numpy.datetime64): First and last timestamp of the forecast.
        LAT_DIM, LON_DIM, LEVEL_DIM, TIME_DIM (int): Number of entries along each dimension.
    """

    # Accepted values for the forecast_type argument.
    _VALID_FORECAST_TYPES = ("SYNTH", "ERA5")
    # Accepted simulated time steps (hours) for TIMEWARP.
    _VALID_TIMEWARPS = (1, 3, 6, 12)

    def __init__(self, filename, forecast_type=None, month=None, timewarp=None):
        """
        Initialize the Forecast object and load a dataset.

        Args:
            filename (str): Forecast file name, appended to env_params["forecast_directory"].
            forecast_type (str): Type of forecast ('SYNTH' or 'ERA5').
            month (int, optional): Month to keep for ERA5 forecasts (ignored for SYNTH).
            timewarp (int, optional): Simulated time step in hours (1, 3, 6 or 12) used to
                overwrite the forecast timestamps.

        Raises:
            ValueError: If the forecast type is invalid. Raised before any file is opened.
        """
        # Validate up front: load_forecast branches on forecast_type, and opening a large
        # forecast only to reject it afterwards wastes time and memory.
        if forecast_type not in self._VALID_FORECAST_TYPES:
            raise ValueError("Invalid forecast type " + str(forecast_type))
        self.forecast_type = forecast_type
        self.load_forecast(filename, month, timewarp=timewarp)

    def check_nan(self, ds):
        """
        Report NaNs in the dataset and linearly fill them along the time dimension.

        Every data variable is checked. If any NaNs are found, per-variable NaN statistics are
        printed and the 'u', 'v' and 'z' variables (those present and containing NaNs) are
        linearly interpolated along 'time'. Linear interpolation does not extrapolate, so NaNs
        at the start or end of the time axis remain; a warning is printed in that case.

        Args:
            ds (xarray.Dataset): Dataset to check. Must have a 'time' dimension if it contains NaNs.

        Returns:
            xarray.Dataset: The dataset, with NaNs filled where interpolation was possible.
        """
        nan_counts = {var: int(ds[var].isnull().sum().item()) for var in ds.data_vars}
        if not any(nan_counts.values()):
            return ds

        for var, n_nans in nan_counts.items():
            if n_nans:
                print(colored(f"WARNING: {var}: contains NaN? True", "yellow"))
        for var, n_nans in nan_counts.items():
            total_vals = ds[var].size
            pct_nans = (n_nans / total_vals) * 100 if total_vals else 0.0
            print(f"{var}: {n_nans} NaNs out of {total_vals} values ({pct_nans:.2f}%)")

        print(colored("Linear Filling Nans on time dimension", "yellow"))
        for var in ('u', 'v', 'z'):
            # Skip variables that are absent or already NaN-free.
            if nan_counts.get(var):
                ds[var] = ds[var].interpolate_na(dim='time', method='linear', use_coordinate=False)
        print(colored("Linear Interpolation done", "yellow"))

        for var in ('u', 'v', 'z'):
            if nan_counts.get(var):
                remaining = int(ds[var].isnull().sum().item())
                if remaining:
                    print(colored(f"WARNING: {var}: {remaining} NaNs remain after interpolation "
                                  f"(likely at the start/end of the time axis)", "yellow"))
        return ds

    def load_forecast(self, filename, month=None, timewarp=None):
        """
        Load and preprocess the forecast dataset into self.ds_original and set the master
        forecast bounds (LAT_*, LON_*, LEVEL_*, TIME_*).

        Args:
            filename (str): Forecast file name, appended to env_params["forecast_directory"].
            month (int, optional): Month to keep for ERA5 forecasts.
            timewarp (int, optional): Simulated time step in hours (e.g., 3, 6, or 12).
        """
        self.ds_original = xr.open_dataset(env_params["forecast_directory"] + filename)

        # Drop the temperature variable if it exists; it is not used here.
        if 't' in self.ds_original.data_vars:
            self.ds_original = self.ds_original.drop_vars('t')

        # Newer ERA5 downloads use 'valid_time'/'pressure_level'; rename them and reverse the
        # level order so the dataset matches the older format the rest of the code expects.
        if 'valid_time' in self.ds_original.coords:
            self.ds_original = self.ds_original.rename({'valid_time': 'time', 'pressure_level': 'level'})
            self.ds_original = self.ds_original.isel(level=slice(None, None, -1))

        # Check for corrupt (NaN) data. This must run after the rename above because
        # check_nan interpolates along the 'time' dimension.
        self.ds_original = self.check_nan(self.ds_original)

        if self.forecast_type == "ERA5":
            memory_size_gb = self.ds_original.nbytes / 1e9
            print(f"Memory size of the dataset: {memory_size_gb:.6f} GB")

            # Reverse the order of latitude, since ERA5 comes reversed (Synth is set up the same
            # way as the result).
            self.ds_original = self.ds_original.isel(latitude=slice(None, None, -1))

            # Cut off pressure levels beyond the configured altitude (top bound only): find the
            # level whose geopotential height at the first time/grid point is closest to
            # alt_max, then keep that level and every level after it on the level axis.
            column = self.ds_original.isel(time=0, latitude=0, longitude=0)
            idx = np.argmin(np.abs(column.z.values / 9.81 - env_params["alt_max"]))
            max_pres_level = column.level.values[idx]
            self.ds_original = self.ds_original.sel(level=slice(max_pres_level, None))

        # ERA5 must be reduced to the Synth month/interval before timewarping.
        if self.forecast_type == "ERA5" and month is not None:
            self.drop_era5_months(month)

        # Change the simulation timestamps of the forecast.
        if timewarp is not None:
            self.TIMEWARP(timewarp)

        print(colored(self.forecast_type, "green"))
        print(self.ds_original)

        # Set master forecast variables from the final (preprocessed) coordinates.
        latitudes = self.ds_original.latitude.values
        longitudes = self.ds_original.longitude.values
        levels = self.ds_original.level.values
        times = self.ds_original.time.values

        self.LAT_MIN, self.LAT_MAX, self.LAT_DIM = latitudes[0], latitudes[-1], len(latitudes)
        self.LON_MIN, self.LON_MAX, self.LON_DIM = longitudes[0], longitudes[-1], len(longitudes)
        self.LEVEL_MIN, self.LEVEL_MAX, self.LEVEL_DIM = levels[0], levels[-1], len(levels)
        self.TIME_MIN, self.TIME_MAX, self.TIME_DIM = times[0], times[-1], len(times)

    def drop_era5_months(self, month):
        """
        Filter the ERA5 forecast to the specified month and to 12 hour intervals (00:00, 12:00).

        By default, ERA5 forecasts are downloaded in 6 month @ 3 hour intervals. Since many
        simulations require both Synth and ERA5, ERA5 is reduced to 12 hour intervals to match.
        This is a bit hardcoded for now.

        Args:
            month (int): Month to retain in the dataset.

        Raises:
            ValueError: If no timestamps match the month at 00:00/12:00. self.ds_original is left
                unchanged in that case.
        """
        print(colored("DROPPING ERA5 Months except (" + str(month) + ") and Times: (12) hour intervals", "yellow"))
        times = self.ds_original.time
        combined_mask = (times.dt.month == month) & times.dt.hour.isin([0, 12])
        subset = self.ds_original.sel(time=combined_mask)
        # Check before assigning so a failed filter does not leave an empty dataset behind.
        if subset.time.size == 0:
            raise ValueError(
                f"Month {month} is out of range of ERA forecast with time range of "
                f"{times.values[0]} - {times.values[-1]}")
        self.ds_original = subset

    def TIMEWARP(self, timewarp):
        '''
        Overwrite the forecast timestamps with an evenly spaced simulated time axis.

        By default ERA5 forecasts are downloaded in 3 hour intervals, whereas Synth are in 12 hour
        intervals. A "timewarp" overwrites the timestamps, starting from the first original
        timestamp and stepping by `timewarp` hours, while keeping the data in its original order.

        For example:
            Synth (original): 2024-01-01 00:00:00, 2024-01-01 12:00:00, 2024-01-02 00:00:00, 2024-01-02 12:00:00
            Synth (timewarp=3): 2024-01-01 00:00:00, 2024-01-01 03:00:00, 2024-01-01 06:00:00, 2024-01-01 09:00:00
            ERA5: 2024-01-01 00:00:00, 2024-01-01 03:00:00, 2024-01-01 06:00:00, 2024-01-01 09:00:00

        Timewarping will typically only be used for Synth forecasts due to their sparse timing.

        Args:
            timewarp (int): Simulated time step in hours; one of 1, 3, 6, 12.

        Raises:
            ValueError: If timewarp is not one of the accepted intervals.
        '''
        if timewarp not in self._VALID_TIMEWARPS:
            raise ValueError("Timewarp only accepts hour intervals of 1,3,6,12")
        print(colored("TIMEWARPING (" + self.forecast_type + ")", "cyan"))
        time_min = self.ds_original.time.values[0]
        # int() is safe after validation and keeps np.timedelta64 happy for e.g. 3.0.
        offsets = np.arange(self.ds_original.time.size) * np.timedelta64(int(timewarp), "h")
        self.ds_original = self.ds_original.assign_coords(time=time_min + offsets)
class Forecast_Subset:
    """
    Creates a subset of the master forecast for efficient processing and simulation.

    Attributes:
        Forecast (Forecast): Master forecast object.
        lat_central (float): Central latitude of the subset (rounded to 0.25 degrees).
        lon_central (float): Central longitude of the subset (rounded to 0.25 degrees).
        start_time (numpy.datetime64): Start time of the subset.
        ds (xarray.Dataset): Subset dataset (set by subset_forecast).
        forecast_np (numpy.ndarray): Subset data with axes (variable[z, u, v], time, level,
            latitude, longitude) (set by subset_forecast).
        pressure_levels (numpy.ndarray): Pressure levels of the subset (set by subset_forecast).
    """

    # Load from config file for now. Maybe change this later
    def __init__(self, Forecast):
        """
        Initialize the Forecast_Subset object.

        Args:
            Forecast (Forecast): Master forecast object.
        """
        self.Forecast = Forecast

    def assign_coord(self, lat, lon, timestamp):
        """
        Assign central coordinates and timestamp for the subset.

        Args:
            lat (float): Central latitude.
            lon (float): Central longitude.
            timestamp (numpy.datetime64): Start timestamp.
        """
        # Truncate time to hour precision and round lat/lon with quarter().
        self.start_time = np.array(timestamp, dtype='datetime64[h]')
        self.lat_central = quarter(lat)
        self.lon_central = quarter(lon)

    def randomize_coord(self, np_rng):
        """
        Generate a random coordinate to centre the Forecast Subset on, and store it for lookup
        by other classes.

        - Altitude bounds are the same as the PRIMARY FORECAST.
        - Latitude/longitude are drawn at least 2 degrees inside the min/max LAT/LON of the
          PRIMARY FORECAST.
        - The start time is drawn between TIME_MIN and 24 hours before TIME_MAX of the
          PRIMARY FORECAST, leaving room to simulate.

        Pass np_rng so forecasts randomize in the same order when manually setting a seed.

        Args:
            np_rng (numpy.random.Generator): Random number generator.
        """
        # The draw order (lat, lon, time) matters for reproducibility with a seeded np_rng.
        lat = np_rng.uniform(low=self.Forecast.LAT_MIN + 2, high=self.Forecast.LAT_MAX - 2)
        lon = np_rng.uniform(low=self.Forecast.LON_MIN + 2, high=self.Forecast.LON_MAX - 2)
        # Randomize in unix seconds, then convert back to datetime64.
        latest_start = self.Forecast.TIME_MAX - np.timedelta64(24, "h")
        time = np_rng.uniform(low=self.get_unixtime(self.Forecast.TIME_MIN),
                              high=self.get_unixtime(latest_start))
        time = np.datetime64(int(time), 's')
        # Truncate time to hour precision and round lat/lon with quarter().
        self.start_time = np.array(time, dtype='datetime64[h]')
        self.lat_central = quarter(lat)
        self.lon_central = quarter(lon)
        # NOTE(review): attribute name is misspelled ("fourecast"); kept as-is in case other
        # code reads it under this name.
        self.fourecast_error_count = 0

    def get_alt_from_pressure(self, pressure):
        """
        Get the average altitude (geopotential z / 9.81) of a pressure level over the subset.

        Args:
            pressure (float): Pressure level to look up; must match a level of self.ds exactly.

        Returns:
            float or None: Mean altitude for the pressure level, or None if the level does not
            exist in the subset (a warning is printed).
        """
        try:
            alt_array = self.ds.sel({"level": pressure}).z.values / 9.81
        except KeyError:
            print(colored(f"Value {pressure} doesn't exist in the levels.", "yellow"))
            return None
        return np.mean(alt_array)

    def subset_forecast(self, days=1):
        """
        Subset the master forecast around the central coordinate.

        Assumes a random or user-supplied coordinate has already been assigned (randomize_coord
        or assign_coord).

        - Horizontal bounds extend env_params['rel_dist'] from the centre in each direction
          (converted to lat/lon degrees and rounded to 0.25 degrees).
        - Pressure levels are limited to env_params['pres_min'] .. env_params['pres_max'].
        - Time spans start_time .. start_time + `days` days (inclusive label slice).

        The subset is also converted to a numpy array for faster lookups in np_lookup.

        Args:
            days (int): Number of days to include in the subset.
        """
        rel_dist = env_params['rel_dist']
        pres_min = env_params['pres_min']
        pres_max = env_params['pres_max']

        # 1. Lat/lon bounds of the relative-distance box around the central coordinate.
        lat_min, _ = transform.meters_to_latlon_spherical(self.lat_central, self.lon_central, 0, -rel_dist)
        _, lon_min = transform.meters_to_latlon_spherical(self.lat_central, self.lon_central, -rel_dist, 0)
        lat_max, _ = transform.meters_to_latlon_spherical(self.lat_central, self.lon_central, 0, rel_dist)
        _, lon_max = transform.meters_to_latlon_spherical(self.lat_central, self.lon_central, rel_dist, 0)
        # Round to the nearest 0.25 degree, since that's the resolution of the forecasts.
        self.lat_min = quarter(lat_min)
        self.lon_min = quarter(lon_min)
        self.lat_max = quarter(lat_max)
        self.lon_max = quarter(lon_max)

        # 2. Subset the forecast to a smaller dataset.
        self.ds = self.Forecast.ds_original.sel(
            latitude=slice(self.lat_min, self.lat_max),
            longitude=slice(self.lon_min, self.lon_max),
            level=slice(pres_min, pres_max),
            time=slice(self.start_time, self.start_time + np.timedelta64(days, "D")),
        )

        # 3. Actual bounds and dimensions of the subset (grid-aligned values).
        self.lat_min = self.ds.latitude.values[0]
        self.lat_max = self.ds.latitude.values[-1]
        self.lon_min = self.ds.longitude.values[0]
        self.lon_max = self.ds.longitude.values[-1]
        self.start_time = self.ds.time.values[0]
        self.end_time = self.ds.time.values[-1]
        self.lat_dim = len(self.ds.latitude)
        self.lon_dim = len(self.ds.longitude)
        self.level_dim = len(self.ds.level)
        self.time_dim = len(self.ds.time)

        # 4. Convert to numpy. np_lookup indexes forecast_np as
        #    [variable(z=0, u=1, v=2), time, level, latitude, longitude], so enforce both the
        #    variable order and the dimension order regardless of the source file's layout.
        self.ds = self.ds[['z', 'u', 'v']].transpose('time', 'level', 'latitude', 'longitude', ...)
        self.forecast_np = self.ds.to_array().to_numpy()
        self.pressure_levels = self.ds.level.values

    @profile
    def xr_lookup(self, lat, lon, timestamp):
        """
        Look up the z/GRAVITY, u and v columns nearest to (lat, lon, timestamp) using xarray.

        Slower reference counterpart of np_lookup.

        Returns:
            tuple: (z / constants.GRAVITY, u, v) arrays over all levels of the subset.
        """
        point = self.ds.sel(latitude=lat, longitude=lon, time=timestamp, method="nearest")
        z = point['z'].values / constants.GRAVITY
        u = point['u'].values
        v = point['v'].values
        return (z, u, v)

    def get_unixtime(self, dt64):
        """
        Convert numpy.datetime64 to Unix time in seconds.

        Args:
            dt64 (numpy.datetime64): DateTime value.

        Returns:
            int: Unix timestamp in seconds.
        """
        return dt64.astype('datetime64[s]').astype('int')

    @profile
    def np_lookup(self, lat, lon, time):
        """
        Perform a fast lookup for wind data using the numpy copy of the subset.

        Args:
            lat (float): Latitude.
            lon (float): Longitude.
            time (numpy.datetime64): Time.

        Returns:
            tuple: (z / constants.GRAVITY, u, v) arrays over all levels at the selected
            grid cell and time index.
        """
        # NOTE(review): indices are mapped onto [0, dim] (not [0, dim - 1]) and truncated by
        # int(). Assuming convert_range is a linear rescale, lookups snap toward the lower grid
        # cell and the upper edge is clipped back into range below. Confirm before changing,
        # since simulation results depend on it.
        lat_idx = int(convert_range(lat, self.lat_min, self.lat_max, 0, self.lat_dim))
        lon_idx = int(convert_range(lon, self.lon_min, self.lon_max, 0, self.lon_dim))
        time_idx = int(convert_range(self.get_unixtime(np.datetime64(time)),
                                     self.get_unixtime(self.start_time),
                                     self.get_unixtime(self.end_time), 0, self.time_dim))
        # Clip out-of-bounds indices to the nearest edge of the subset.
        lat_idx = int(np.clip(lat_idx, 0, self.lat_dim - 1))
        lon_idx = int(np.clip(lon_idx, 0, self.lon_dim - 1))
        time_idx = int(np.clip(time_idx, 0, self.time_dim - 1))

        z = self.forecast_np[0, time_idx, :, lat_idx, lon_idx] / constants.GRAVITY
        u = self.forecast_np[1, time_idx, :, lat_idx, lon_idx]
        v = self.forecast_np[2, time_idx, :, lat_idx, lon_idx]
        return (z, u, v)

    def windVectorToBearing(self, u, v):
        """
        Helper function to convert u-v wind components to bearing and speed.

        Not being used right now.

        Returns:
            tuple: (bearing in radians from np.arctan2(v, u), speed sqrt(u**2 + v**2)).
        """
        bearing = np.arctan2(v, u)
        speed = np.power((np.power(u, 2) + np.power(v, 2)), .5)
        return bearing, speed

    def interpolate_wind(self, alt, z, u, v):
        '''
        Interpolate the u and v wind components at altitude `alt` from one column.

        Only interpolates in the vertical direction; there is no smoothing across time or
        horizontal position.

        Args:
            alt (float): Altitude to interpolate at.
            z (numpy.ndarray): Altitude of each level, expected in descending order.
            u, v (numpy.ndarray): Wind components at each level.

        Returns:
            tuple: (u_wind, v_wind) at `alt`.
        '''
        # np.interp requires increasing sample points, so reverse the descending column.
        u_wind = np.interp(alt, z[::-1], u[::-1])
        v_wind = np.interp(alt, z[::-1], v[::-1])
        return u_wind, v_wind

    def getNewCoord(self, Balloon, SimulationState, dt):
        """
        Advance the balloon position by dt using the wind at its current position.

        Args:
            Balloon: Object exposing lat, lon and altitude.
            SimulationState: Object exposing timestamp.
            dt (float): Integration time step.

        Returns:
            list: [lat_new, lon_new, x_vel, y_vel, x_new, y_new], where x/y are positions
            relative to the subset's central coordinate.
        """
        # Wind column at the current lat/lon/time, interpolated to the balloon altitude.
        z_col, u_col, v_col = self.np_lookup(Balloon.lat, Balloon.lon, SimulationState.timestamp)
        x_vel, y_vel = self.interpolate_wind(Balloon.altitude, z_col, u_col, v_col)

        # Current position relative to the central coordinate.
        relative_x, relative_y = transform.latlon_to_meters_spherical(self.lat_central,
                                                                      self.lon_central,
                                                                      Balloon.lat, Balloon.lon)
        # If relative distance is NaN from the transform step being too close to the central
        # coordinate, revert back to 0.
        if np.isnan(relative_x):
            relative_x = 0.0
        if np.isnan(relative_y):
            relative_y = 0.0

        # Apply the velocity to the relative position (explicit Euler step).
        x_new = relative_x + x_vel * dt
        y_new = relative_y + y_vel * dt

        # Convert the new relative position back to lat/lon.
        lat_new, lon_new = transform.meters_to_latlon_spherical(self.lat_central,
                                                                self.lon_central,
                                                                x_new, y_new)
        # Similarly, if lat or lon is NaN from the transform step being too close to the
        # central coordinate, revert back to the central coordinate.
        if np.isnan(lat_new):
            lat_new = self.lat_central
        if np.isnan(lon_new):
            lon_new = self.lon_central

        return [lat_new, lon_new, x_vel, y_vel, x_new, y_new]
if __name__ == '__main__':
    # Can't use utils.initialize_forecast here, because it would be a circular import.
    FORECAST_SYNTH = Forecast(env_params['synth_netcdf'], forecast_type="SYNTH", timewarp=3)
    # Get the month associated with the Synth forecast.
    synth_month = pd.to_datetime(FORECAST_SYNTH.TIME_MIN).month
    # Then process ERA5 to span the same timespan as a monthly Synthwinds file.
    FORECAST_ERA5 = Forecast(env_params['era_netcdf'], forecast_type="ERA5", month=synth_month, timewarp=3)
    forecast_subset = Forecast_Subset(FORECAST_ERA5)  # Choose FORECAST_SYNTH or FORECAST_ERA5 here
    # randomize_coord requires a numpy Generator; calling it without one raises TypeError.
    forecast_subset.randomize_coord(np.random.default_rng())
    print("random_coord", forecast_subset.lat_central, forecast_subset.lon_central, forecast_subset.start_time)
    forecast_subset.subset_forecast()
    print(forecast_subset.ds)