Source code for rendering.renderertriple

import matplotlib.pyplot as plt
import numpy as np

#from Plotting.era5_forecast_visualizations import skip_array
#dfdg
from env.config.env_config import env_params
from env.rendering.Trajectory3DPlotter import Trajectory3DPlotter
import matplotlib as mpl
from termcolor import colored
import os

class MatplotlibRendererTriple():
    """
    Handles the visualization of ERA5 and Synthetic forecasts using 3D quiver plots
    and trajectory visualization.

    The figure is laid out as a full-width altitude profile on the top row and,
    on the bottom row, a 3D trajectory plot followed by one 3D forecast plot for
    ERA5 and one for the synthetic forecast, plus a manually placed colorbar.

    Attributes:
        coordinate_system (str): Coordinate system used for visualization
            ('geographic' or 'cartesian').
        Forecast_visualizer (object): Visualizer for ERA5 forecasts.
        Forecast_visualizer_synth (object): Visualizer for synthetic forecasts.
        render_count (int): ``render()`` only draws on every ``render_count``-th call.
        quiver_skip (int): Quiver skip value forwarded to the forecast visualizers.
        render_mode (str): Rendering mode ('human', etc.). Stored but not read by
            this class.
        render_timestamp: Reference time for the "Time Elapsed" title. Starts as the
            ERA5 forecast subset start time and becomes the simulator timestamp
            after ``reset()``.
        render_timestamp_synth: Start time of the synthetic forecast subset.
        dt (float): Time step of the simulation (from ``env_params['dt']``).
        episode_length (int): Total steps in an episode.
        goal (dict): Goal position (x, y).
        radius (float): Station-keeping radius in meters.
        radius_inner (float): ``0.5 * radius``.
        radius_outer (float): ``1.5 * radius``.
        save_figure (bool): Whether each rendered frame is written to disk.
        frame_save_dir (str): Output directory for frames (only set when
            ``save_figure`` is truthy).
        frame_count (int): Number of frames saved so far.
        trajectory_plotter (Trajectory3DPlotter): Handles trajectory visualization.
    """

    def __init__(self, Forecast_visualizer_ERA5, Forecast_visualizer_SYNTH, render_mode, radius, coordinate_system="geographic"):
        """
        Initialize the renderer and build the initial figure.

        Args:
            Forecast_visualizer_ERA5 (object): Visualizer for ERA5 forecast.
            Forecast_visualizer_SYNTH (object): Visualizer for synthetic forecast.
            render_mode (str): Rendering mode ('human', etc.).
            radius (float): Station-keeping radius in meters.
            coordinate_system (str, optional): Coordinate system
                ('geographic' or 'cartesian').

        Raises:
            ValueError: If ``coordinate_system`` is neither 'geographic' nor 'cartesian'.
            NotImplementedError: If 'cartesian' is requested but no ``init_plot``
                method is available.
        """
        self.coordinate_system = coordinate_system
        self.Forecast_visualizer = Forecast_visualizer_ERA5
        self.Forecast_visualizer_synth = Forecast_visualizer_SYNTH

        self.render_count = env_params['render_count']
        self.quiver_skip = env_params['quiver_skip']
        self.render_mode = render_mode

        self.render_timestamp = self.Forecast_visualizer.forecast_subset.start_time
        self.render_timestamp_synth = self.Forecast_visualizer_synth.forecast_subset.start_time

        self.dt = env_params['dt']
        self.episode_length = env_params['episode_length']

        self.goal = {"x": 0, "y": 0}  # relative
        self.radius = radius  # m
        self.radius_inner = radius * .5  # m
        self.radius_outer = radius * 1.5  # m

        self.save_figure = env_params['save_figure']
        # Always defined so render() can rely on it even if save_figure is
        # switched on after construction.
        self.frame_count = 0
        if self.save_figure:
            self.frame_save_dir = env_params['save_dir']
            os.makedirs(self.frame_save_dir, exist_ok=True)

        self._init_plot_for_coordinate_system()

    def _init_plot_for_coordinate_system(self):
        """Build the figure for ``self.coordinate_system`` or fail with a clear error."""
        if self.coordinate_system == "geographic":
            self.init_plot_geographic()
        elif self.coordinate_system == "cartesian":
            # This class does not define init_plot(); only call it when something
            # else (e.g. a subclass) provides one.
            init_plot = getattr(self, "init_plot", None)
            if init_plot is None:
                raise NotImplementedError(
                    "The cartesian coordinate system is not implemented for this renderer."
                )
            init_plot()
        else:
            raise ValueError(
                "Not a Valid Coordinate System. Can either be geographic or cartesian"
            )

    def init_plot_geographic(self):
        """
        Initialize the plot layout for geographic visualization.

        Creates the figure, the altitude axis, the trajectory axis, the two
        forecast axes and the wind-direction colorbar, and draws the initial
        forecast plots. The statement order below is deliberate; see the inline
        note on the ERA5 axis workaround.
        """
        self.fig = plt.figure(figsize=(18, 10))
        self.gs = self.fig.add_gridspec(nrows=2, ncols=4, height_ratios=[1, 4],
                                        width_ratios=[.9, .9, .9, 0.05])

        # Top altitude plot spans the whole first row.
        self.ax3 = self.fig.add_subplot(self.gs[0, :])

        # 3D plots on the second row. The 'custom3dquiver' projection is not
        # registered in this file; it is expected to be registered elsewhere.
        self.ax = self.fig.add_subplot(self.gs[1, 0], projection='3d')
        self.ax2 = self.fig.add_subplot(self.gs[1, 1], projection='custom3dquiver')  # ERA5
        self.ax4 = self.fig.add_subplot(self.gs[1, 2], projection='custom3dquiver')  # Synth

        # Remove the vertical space between the two rows.
        self.gs.update(hspace=0)

        box2 = self.ax3.get_position()  # position of the altitude axis

        # Reposition the ERA5 axis relative to the altitude axis.
        self.ax2.set_position([box2.x0, box2.y0 - 0.05, box2.width, box2.height])

        # These three lines make it so the first ERA5 forecast visualization is
        # not glitchy during rendering. The root cause is unknown, so the
        # sequence is kept exactly as-is.
        self.ax2.clear()
        self.ax2.remove()
        self.ax2 = self.fig.add_subplot(self.gs[1, 1], projection='custom3dquiver')

        # Initialize the 3D trajectory plotter on the left 3D axis.
        self.trajectory_plotter = Trajectory3DPlotter(self.ax, self.radius, self.goal,
                                                      self.dt, self.episode_length)

        # Initialize the two forecast visualizers: one for ERA5, one for Synth.
        self.Forecast_visualizer.visualize_3d_planar_flow(self.ax2, quiver_skip=self.quiver_skip)
        self.Forecast_visualizer_synth.visualize_3d_planar_flow(self.ax4, quiver_skip=self.quiver_skip)

        # Altitude profile plot setup (altitude limits are divided by 1000 to
        # match the km axis label).
        self.altitude_line, = self.ax3.plot([], [], 'b-')
        self.ax3.set_xlabel('Number of Steps (dt=' + str(self.dt) + 's)')
        self.ax3.set_ylabel('Altitude (km)')
        self.ax3.set_xlim(0, self.episode_length)
        self.ax3.set_ylim(env_params['alt_min'] / 1000., env_params['alt_max'] / 1000.)

        # Manual colorbar: blank the title of the current pyplot axes first.
        plt.title('')

        # Colorbar covering wind directions from -180 to 180 degrees.
        norm = mpl.colors.Normalize(vmin=-180, vmax=180)
        sm = mpl.cm.ScalarMappable(cmap='hsv', norm=norm)
        sm.set_array([])  # no actual data is attached to the mappable

        # Dedicated axis for the colorbar: [left, bottom, width, height] as
        # fractions of the figure size.
        cbar_ax = self.fig.add_axes([0.89, 0.25, 0.01, 0.35])

        cbar = self.fig.colorbar(sm, cax=cbar_ax, orientation='vertical')
        cbar.set_label('Wind Direction (degrees)')

        # Custom ticks and labels (with degree symbols).
        cbar.set_ticks([-180, -90, 0, 90, 180])
        cbar.set_ticklabels([r'$-180^\circ$', r'$-90^\circ$', r'$0^\circ$', r'$90^\circ$', r'$180^\circ$'])

    def reset_forecast_visualization(self):
        """
        Reset the forecast visualization for ERA5 and Synthetic forecasts.

        Each forecast axis is cleared, removed and re-created in its grid cell,
        then the flow array is regenerated at the current simulator timestamp
        and redrawn. Requires ``reset()`` to have set ``self.SimulatorState``.
        """
        # ERA5 forecast (middle 3D axis).
        self.ax2.clear()
        self.ax2.remove()
        self.ax2 = self.fig.add_subplot(self.gs[1, 1], projection='custom3dquiver')
        self.Forecast_visualizer.generate_flow_array(timestamp=self.SimulatorState.timestamp)
        self.Forecast_visualizer.visualize_3d_planar_flow(self.ax2, quiver_skip=self.quiver_skip,
                                                          altitude_quiver_skip=1)

        # Synthetic forecast (right 3D axis).
        self.ax4.clear()
        self.ax4.remove()
        self.ax4 = self.fig.add_subplot(self.gs[1, 2], projection='custom3dquiver')
        self.Forecast_visualizer_synth.generate_flow_array(timestamp=self.SimulatorState.timestamp)
        self.Forecast_visualizer_synth.visualize_3d_planar_flow(self.ax4, quiver_skip=self.quiver_skip)

    def reset(self, goal, Balloon, SimulatorState):
        """
        Reset the renderer for a new episode.

        Closes any existing figure, stores the new episode objects, rebuilds the
        figure and redraws both forecast plots.

        Args:
            goal (dict): Goal position (x, y).
            Balloon (object): Balloon state object.
            SimulatorState (object): Simulator state object.

        Raises:
            ValueError: If ``coordinate_system`` is neither 'geographic' nor 'cartesian'.
            NotImplementedError: If 'cartesian' is requested but no ``init_plot``
                method is available.
        """
        # Close figures if already open; they are rebuilt below.
        if hasattr(self, 'fig'):
            # NOTE: this closes every open pyplot figure, not only this renderer's.
            plt.close('all')
            for stale_attr in ('fig', 'ax', 'ax3', 'goal'):
                delattr(self, stale_attr)

        self.Balloon = Balloon
        self.SimulatorState = SimulatorState
        self.goal = goal

        self.render_step = 1  # for initial reset
        self.hour_count = 0

        # Elapsed time in the plot title is measured from this timestamp.
        self.render_timestamp = self.SimulatorState.timestamp

        if not hasattr(self, 'fig'):
            self._init_plot_for_coordinate_system()

        self.reset_forecast_visualization()

    def is_timestamp_in_interval(self, timestamp, interval):
        """
        Check if a timestamp falls exactly on an hour that is a multiple of ``interval``.

        Args:
            timestamp (datetime.datetime): Timestamp to check; must expose
                ``hour`` and ``minute``.
            interval (int): Interval in hours. Must be 1, 3, 6 or 12.

        Returns:
            bool: True if the minutes are zero and the hour is divisible by
            ``interval``, False otherwise.

        Raises:
            ValueError: If the minutes are zero and ``interval`` is not one of
                1, 3, 6 or 12.
        """
        # Only timestamps exactly on the hour can match. This check comes first,
        # so an unsupported interval is only reported for on-the-hour timestamps.
        if timestamp.minute != 0:
            return False

        if interval not in (1, 3, 6, 12):
            raise ValueError("Interval must be 1, 3, 6, or 12 hours.")

        return timestamp.hour % interval == 0

    def render(self, mode='human'):
        """
        Render the simulation environment.

        Drawing only happens on every ``render_count``-th call; other calls just
        advance the internal step counter. Requires ``reset()`` to have been
        called first.

        Args:
            mode (str, optional): Rendering mode. With 'human' the figure is
                refreshed on screen via ``plt.pause``.
        """
        if self.render_step != self.render_count:
            self.render_step += 1
            return

        path = np.array(self.SimulatorState.trajectory)

        # Update the 3D trajectory plot.
        self.trajectory_plotter.update(self.Balloon, path, self.goal)
        self.trajectory_plotter.update_altitude(path)
        self.trajectory_plotter.draw()

        # Update the altitude profile; column 2 of the path is divided by 1000
        # to match the km axis.
        self.altitude_line.set_data(range(len(path)), path[:, 2] / 1000.)

        timestamp = self.SimulatorState.timestamp
        self.ax3.set_title("Timestamp: " + str(timestamp) + "\nTime Elapsed: "
                           + str(timestamp - self.render_timestamp))

        # Forecast plots refresh every `timewarp` hours, or every 12 hours when
        # no timewarp is configured.
        timewarp = env_params['timewarp']
        forecast_visualizer_update = 12 if timewarp is None else timewarp

        if self.is_timestamp_in_interval(timestamp, forecast_visualizer_update):
            self.reset_forecast_visualization()

        if mode == 'human':
            plt.pause(0.001)

        # Save frame as image.
        if self.save_figure:
            frame_filename = os.path.join(self.frame_save_dir, f"frame_{self.frame_count:05d}.png")
            self.fig.savefig(frame_filename)
            self.frame_count += 1

        self.render_step = 1