'''
A test string at the top of module ``prime``.
'''
import importlib.resources
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as ks
from sklearn.preprocessing import RobustScaler #RobustScaler is used to scale the input/target data but is not called directly below
import joblib
__all__ = ['prime', 'crps_loss', 'mse_metric'] # Here __all__ is defined so that docs tools can read the appropriate docstrings
class prime:
    '''
    This class wraps an instance of PRIME for solar wind prediction.
    When instantiating a ``prime`` object, one can specify a predefined ``model`` to be used instead of the automatically-loaded PRIME model.
    In that case, the scaling functions for the input and target datasets (``in_scaler`` and ``tar_scaler``), the input and target features (``in_keys`` and ``tar_keys``), and the output features (``out_keys``) must be specified.
    The full list of arguments that can be passed to ``prime`` is given below, but they are not recommended for general use.
    '''

    # Wind data keys included in the input dataset by default.
    _DEFAULT_IN_KEYS = (
        'B_xgsm',
        'B_ygsm',
        'B_zgsm',
        'Vi_xgse',
        'Vi_ygse',
        'Vi_zgse',
        'Ni',
        'Vth',
        'R_xgse',
        'R_ygse',
        'R_zgse',
        'target_R_xgse',
        'target_R_ygse',
        'target_R_zgse',
    )
    # Targets from the MMS dataset matched with the input data by default.
    _DEFAULT_TAR_KEYS = (
        'B_xgsm',
        'B_ygsm',
        'B_zgsm',
        'Vi_xgse',
        'Vi_ygse',
        'Vi_zgse',
        'Ne',
    )
    # Default output features: each target followed by its uncertainty column ('_sig').
    _DEFAULT_OUT_KEYS = (
        'B_xgsm',
        'B_xgsm_sig',
        'B_ygsm',
        'B_ygsm_sig',
        'B_zgsm',
        'B_zgsm_sig',
        'Vi_xgse',
        'Vi_xgse_sig',
        'Vi_ygse',
        'Vi_ygse_sig',
        'Vi_zgse',
        'Vi_zgse_sig',
        'Ne',
        'Ne_sig',
    )
    # Input columns that hold the target location (the 'target_R_*' keys of the default in_keys).
    _LOC_COLUMNS = slice(11, 14)

    def __init__(self, model = None, in_scaler = None, tar_scaler = None, in_keys = None, tar_keys = None, out_keys = None, hps = [60, 15, 5.0/60.0]):
        '''
        :param model: Keras model for predictions. If None, PRIME is loaded from the package.
        :param in_scaler: Scikitlearn preprocessing scaler for input arrays. If None, pre-fit RobustScaler is loaded from the package.
        :param tar_scaler: Scikitlearn preprocessing scaler for output arrays. If None, pre-fit RobustScaler is loaded from the package.
        :param in_keys: Features used as inputs. If None, defaults are loaded from the package.
        :param tar_keys: Features used as targets. If None, defaults are loaded from the package.
        :param out_keys: Features used as outputs. If None, defaults are loaded from the package.
        :param hps: Sequence ``[window, stride, fraction]``; it is only indexed, never modified.
        '''
        super().__init__()
        if in_scaler is None:
            with importlib.resources.path('primesw', 'primeinsc_v0.1.0.pkl') as in_scaler_file:
                in_scaler = joblib.load(in_scaler_file)
        self.in_scaler = in_scaler
        if tar_scaler is None:
            with importlib.resources.path('primesw', 'primetarsc_v0.1.0.pkl') as tar_scaler_file:
                tar_scaler = joblib.load(tar_scaler_file)
        self.tar_scaler = tar_scaler
        # Defaults are copied into fresh lists so instances never share state.
        self.in_keys = list(self._DEFAULT_IN_KEYS) if in_keys is None else in_keys
        self.tar_keys = list(self._DEFAULT_TAR_KEYS) if tar_keys is None else tar_keys
        self.out_keys = list(self._DEFAULT_OUT_KEYS) if out_keys is None else out_keys
        self.window = hps[0]
        """Length of input timeseries, in 100s units. Specified via `hps` argument."""
        self.stride = hps[1]
        """Prediction lead time, in 100s units. Specified via `hps` argument."""
        self.fraction = hps[2]
        """Maximum fraction of input timeseries that can be interpolated, in 100s units. Specified via `hps` argument."""
        if model is None:
            with importlib.resources.path('primesw', 'prime_v0.1.0.keras') as model_weights_file:
                # The saved model references the custom loss/metric by name.
                model = ks.models.load_model(model_weights_file, custom_objects = {'crps_loss' : crps_loss, 'mse_metric' : mse_metric})
        self.model = model

    def predict(self, input = None, start = None, stop = None, pos = [13.25, 0, 0]):
        """
        Method that produces a dataframe of PRIME solar wind predictions.

        :param input: DataFrame containing the columns in ``in_keys`` (and optionally ``'Epoch'``), or an array whose columns are ordered as ``in_keys``.
        :param start: Start of the interval to predict; used with ``stop`` when ``input`` is None.
        :param stop: End of the interval to predict; used with ``start`` when ``input`` is None.
        :param pos: Target position forwarded to ``build_real_input`` when ``input`` is None.
        :returns: DataFrame with the columns in ``out_keys``. An ``'Epoch'`` column (input epoch plus the lead time) is added when the input is a DataFrame that has one.
        :raises RuntimeWarning: If neither ``input`` nor both ``start`` and ``stop`` are given.
        :raises TypeError: If ``input`` is neither a DataFrame nor an ndarray.
        """
        if input is None:
            if start is None or stop is None:
                raise RuntimeWarning('Must specify either input or (start and stop).')
            # NOTE(review): build_real_input is not defined in the class as visible here -- confirm it is provided elsewhere.
            input = self.build_real_input(start = start, stop = stop, pos = pos)
        if isinstance(input, pd.DataFrame):
            input_arr = input[self.in_keys].to_numpy() # Select and order the feature columns
        elif isinstance(input, np.ndarray):
            input_arr = input
        else:
            raise TypeError('input must be a pandas DataFrame or a numpy ndarray.')
        output_arr = self.predict_raw(input_arr) # Predict with the keras model
        #TO DO: Throw a warning or include a flag if any predictions were made with data that was "too interpolated"
        output = pd.DataFrame(output_arr, columns = self.out_keys)
        # Arrays carry no timestamps, so the epoch column is only added for DataFrame input.
        if isinstance(input, pd.DataFrame) and 'Epoch' in input.columns:
            # The first window-1 rows have no complete window and therefore no prediction.
            epoch = input['Epoch'].iloc[(self.window-1):]
            # Add the lead time out-of-place so the caller's DataFrame is never modified.
            output['Epoch'] = (epoch + pd.Timedelta(seconds = 100*self.stride)).to_numpy()
        return output

    def predict_raw(self, input):
        '''
        Generates PRIME predictions from an input array. Assumes that the columns of `input` are ordered as `prime.in_keys`. It is generally recommended to use `prime.predict` instead.

        :param input: 2D array-like of shape (samples, len(in_keys)).
        :returns: Array of shape (samples - window + 1, len(out_keys)) with means in the even columns and uncertainties in the odd columns.
        :raises ValueError: If `input` has fewer rows than the window length.
        '''
        input_scaled = np.asarray(self.in_scaler.transform(input), dtype = float) # Rescale the input data
        n_windows = len(input_scaled) - (self.window - 1)
        if n_windows < 1:
            raise ValueError('input must contain at least `window` (%d) rows, got %d.' % (self.window, len(input_scaled)))
        # Row i of window_index lists the input rows of the i-th sliding window.
        window_index = np.arange(n_windows)[:, None] + np.arange(self.window)
        input_arr = input_scaled[window_index] # Shape (n_windows, window, features)
        output_scaled = self.model.predict(input_arr) # Use stored keras model to make prediction
        return self._unscale_output(output_scaled)

    def _unscale_output(self, scaled):
        '''
        Converts scaled model output (interleaved mean/sigma columns) to physical units.
        '''
        scaled = np.asarray(scaled)
        mean_scaled = scaled[:, ::2]
        mean = self.tar_scaler.inverse_transform(mean_scaled)
        # Sigma is recovered as the distance between the unscaled (mean + sigma) and the unscaled mean.
        upper = self.tar_scaler.inverse_transform(mean_scaled + scaled[:, 1::2])
        output = np.zeros((len(scaled), len(self.out_keys))) # 2x target dimensions (to account for uncertainties)
        output[:, ::2] = mean
        output[:, 1::2] = np.abs(upper - mean)
        return output

    def predict_grid(self, gridsize, x_extent, framenum, bx, by, bz, vx, vy, vz, ni, vt, rx, ry, rz, y_extent=None, z_extent=None, y = 0, z = 0, subtract_ecliptic=False):
        """
        Generate predictions from PRIME on a grid of points in GSE coordinates.

        Parameters:
        -----------
        - gridsize (float): Spacing of grid points (RE)
        - x_extent (list): Range of x values to calculate on (GSE RE).
        - framenum (int): Number of frames to calculate.
        - bx (float, array-like): IMF Bx value (nT). If array like, must be of length framenum.
        - by (float, array-like): IMF By value (nT). If array like, must be of length framenum.
        - bz (float, array-like): IMF Bz value (nT). If array like, must be of length framenum.
        - vx (float, array-like): Solar wind Vx value (km/s). If array like, must be of length framenum.
        - vy (float, array-like): Solar wind Vy value (km/s). If array like, must be of length framenum.
        - vz (float, array-like): Solar wind Vz value (km/s). If array like, must be of length framenum.
        - ni (float, array-like): Solar wind ion density value (cm^-3). If array like, must be of length framenum.
        - vt (float, array-like): Solar wind ion thermal speed value (km/s). If array like, must be of length framenum.
        - rx (float, array-like): Wind spacecraft position x value (GSE RE). If array like, must be of length framenum.
        - ry (float, array-like): Wind spacecraft position y value (GSE RE). If array like, must be of length framenum.
        - rz (float, array-like): Wind spacecraft position z value (GSE RE). If array like, must be of length framenum.
        - y_extent (list): Range of y values to calculate on (GSE RE). If None, z_extent must be specified.
        - z_extent (list): Range of z values to calculate on (GSE RE). If None, y_extent must be specified.
        - y (float, array-like): Y position (GSE RE) that is held constant if y_extent is not specified. Default 0.
        - z (float, array-like): Z position (GSE RE) that is held constant if z_extent is not specified. Default 0.
        - subtract_ecliptic (bool): Whether or not to subtract the Earth's motion in the ecliptic from Vy. Default False.

        Returns:
        --------
        - output_grid (ndarray): Array of predicted values on the grid. Shape (framenum, n_x, n_y, n_z, len(out_keys)). Features as in `prime.out_keys`.

        Raises:
        -------
        - ValueError: If neither y_extent nor z_extent is specified.
        """
        if y_extent is None and z_extent is None:
            raise ValueError("Must specify y_extent or z_extent")
        x_arr = np.arange(x_extent[0], x_extent[1], gridsize)
        # An axis without an extent collapses to the single constant position given by y / z.
        y_arr = np.asarray([y]) if y_extent is None else np.arange(y_extent[0], y_extent[1], gridsize)
        z_arr = np.asarray([z]) if z_extent is None else np.arange(z_extent[0], z_extent[1], gridsize)
        x_grid, y_grid, z_grid = np.meshgrid(x_arr, y_arr, z_arr) # Default 'xy' indexing: grids have shape (n_y, n_x, n_z)
        n_points = x_grid.size
        n_rows = n_points * framenum
        input_seed = np.zeros((n_rows, len(self.in_keys))) # One row per (frame, grid point), frame-major
        for idx, element in enumerate([bx, by, bz, vx, vy, vz, ni, vt, rx, ry, rz]):
            if np.ndim(element) == 0:
                input_seed[:, idx] = element # Scalars apply to every grid point and frame
            else:
                input_seed[:, idx] = np.repeat(element, n_points) # One value per frame, repeated for each grid point
        loc_arr = np.column_stack([np.tile(grid.flatten(), framenum) for grid in (x_grid, y_grid, z_grid)])
        # The target location is scaled together with the other inputs by in_scaler.
        # NOTE(review): assumes in_scaler was fit with the target_R_* columns in the same units as the grid -- confirm.
        input_seed[:, self._LOC_COLUMNS] = loc_arr
        input_seed_scaled = self.in_scaler.transform(input_seed) # Scale the input data
        # A separately attached location scaler still takes precedence; none is created by __init__.
        loc_scaler = getattr(self, 'loc_scaler', None)
        if loc_scaler is not None:
            input_seed_scaled[:, self._LOC_COLUMNS] = loc_scaler.transform(loc_arr)
        input_seed_scaled = np.repeat(input_seed_scaled, self.window, axis=0) # Repeat each row `window` times to make static timeseries
        input_arr = input_seed_scaled.reshape(n_rows, self.window, len(self.in_keys))
        output = self._unscale_output(self.model.predict(input_arr))
        output_grid = output.reshape((framenum,) + x_grid.shape + (len(self.out_keys),))
        output_grid = np.swapaxes(output_grid, 1, 2) # Move the x axis ahead of y (new order is frame, x, y, z, param)
        if subtract_ecliptic:
            out_keys = list(self.out_keys)
            # Index 8 is 'Vi_ygse' in the default out_keys and is kept as the fallback.
            vy_index = out_keys.index('Vi_ygse') if 'Vi_ygse' in out_keys else 8
            output_grid[..., vy_index] -= 29.8 # Earth's orbital speed (km/s)
        return output_grid

    def build_model(self, units = [352, 192, 48, 48], activation = 'elu', dropout = 0.20, lr = 1e-4):
        '''
        Builds the underlying PRIME model with no weights or biases loaded. Deprecated as of keras introducing the `.keras` model save routine.
        Units are the layer size of the GRU layer and three dense layers.
        Layer normalization and dropout are applied once, after the last hidden dense layer.

        :param units: Sizes of the GRU layer and the three hidden dense layers (only indexed, never modified).
        :param activation: Activation of the hidden dense layers.
        :param dropout: Dropout rate.
        :param lr: Learning rate of the Adamax optimizer.
        :returns: Compiled and built keras model.
        '''
        model = ks.Sequential([ks.layers.GRU(units=units[0]),
                               ks.layers.Dense(units=units[1], activation=activation),
                               ks.layers.Dense(units=units[2], activation=activation),
                               ks.layers.Dense(units=units[3], activation=activation),
                               ks.layers.LayerNormalization(),
                               ks.layers.Dropout(dropout),
                               # One mean and one sigma per target, as crps_loss and predict_raw expect.
                               ks.layers.Dense(len(self.out_keys),activation='linear')
                               ])
        model.compile(optimizer=tf.optimizers.Adamax(learning_rate=lr), loss=crps_loss)
        model.build(input_shape = (1, self.window, len(self.in_keys)))
        return model
# Custom loss function (Continuous Ranked Probability Score, CRPS) and associated helper functions
def crps_loss(y_true, y_pred):
    """
    Continuous Rank Probability Score loss, implemented with Tensorflow.

    The score is evaluated per output feature (Gaussian predictive
    distribution), averaged over the batch, and the seven per-feature scores
    are then averaged. For the underlying closed-form expression, see
    `primesw.crps_f`.

    Parameters
    ----------
    - y_true (tf.Tensor): Ground truth values of predicted variable.
    - y_pred (tf.Tensor): mu and sigma^2 values of predicted distribution.

    Returns
    -------
    - crps (tf.Tensor): Continuous rank probability score.
    """
    columns = unstack_helper(y_true, y_pred)
    # The helper returns 14 interleaved (mu, sigma) columns followed by 7 truth columns.
    predicted, observed = columns[:14], columns[14:]
    centres, spreads = predicted[::2], predicted[1::2]
    # Batch-mean CRPS of every feature
    per_feature = [
        tf.math.reduce_mean(crps_f(ep_f(truth, centre), spread))
        for truth, centre, spread in zip(observed, centres, spreads)
    ]
    # Accumulate left to right, then average over the seven features
    total = per_feature[0]
    for score in per_feature[1:]:
        total = total + score
    return total / 7.0
def mse_metric(y_true, y_pred):
    """
    Tensorflow implementation of Mean Squared Error compatible with PRIME's output layer. Not suitable for use as a loss function.

    The prediction tensor interleaves a mean and a spread for every feature
    along its last axis (mu0, sg0, mu1, sg1, ...); only the means enter the
    metric. Every feature contributes the same number of elements, so the mean
    over all elements equals the average of the per-feature MSEs.

    Parameters
    ----------
    - y_true (tf.Tensor): Ground truth values of predicted variable.
    - y_pred (tf.Tensor): mu and sigma^2 values of predicted distribution.

    Returns
    -------
    - mse (tf.Tensor): MSE between mu and y_true.
    """
    # Even positions along the last axis hold the predicted means
    mu = y_pred[..., ::2]
    # Average over batch and features (the previous version divided the sum of seven feature MSEs by 9)
    mse = tf.math.reduce_mean(tf.math.square(y_true - mu))
    return mse
def crps_f(ep, sg):
    '''
    Closed-form continuous rank probability score of a Gaussian forecast.

    `ep` is the absolute error between truth and predicted mean (see `ep_f`)
    and `sg` is the predicted spread.
    '''
    # Error-function term: (ep/sg) * erf(ep / (sqrt(2) * sg))
    erf_term = (ep/sg) * tf.math.erf((ep/(np.sqrt(2)*sg)))
    # Gaussian density term: sqrt(2/pi) * exp(-ep^2 / (2 sg^2))
    density_term = tf.math.sqrt(2/np.pi) * tf.math.exp(-ep**2 / (2*sg**2))
    # Constant offset 1/sqrt(pi)
    offset = 1/tf.math.sqrt(np.pi)
    return sg * (erf_term + density_term - offset)
def ep_f(y, mu):
    '''
    Return epsilon, the absolute error |y - mu| used by the CRPS.
    '''
    residual = y - mu
    return tf.math.abs(residual)
def unstack_helper(y_true, y_pred):
    '''
    Split predictions and targets into per-feature columns for `primesw.crps`.

    Returns a flat 21-tuple: the 14 prediction columns in their original
    interleaved order (mu0, sg0, ..., mu6, sg6) followed by the 7 truth
    columns, each with a trailing axis of length one added.
    '''
    # Prediction columns: interleaved means and spreads
    pred_columns = tf.unstack(y_pred, axis=-1)
    if len(pred_columns) != 14:
        raise ValueError('expected 14 prediction columns, got %d' % len(pred_columns))
    # Ground-truth columns: one per feature
    true_columns = tf.unstack(y_true, axis=-1)
    if len(true_columns) != 7:
        raise ValueError('expected 7 truth columns, got %d' % len(true_columns))
    # Add one dimension to make the right shape
    expanded = [tf.expand_dims(column, -1) for column in (*pred_columns, *true_columns)]
    return tuple(expanded)
# Synthetic MMS orbit segment that ends at the bow shock nose (13.25RE, 0RE, 0RE),
# stride*100s before the end of the window
# (from 2023-01-24 02:46:30+0000 - window - stride to 2023-01-24 02:46:30+0000 - stride).
# Each component array below holds 50 samples; SYNTH_POS stacks them into shape (50, 3).
# NOTE(review): the magnitudes suggest the positions are in km rather than RE -- confirm.
SYNTH_XPOS = np.array([69215.97057508, 69480.44662705, 69706.40911294, 69969.18467343,
70231.11857452, 70454.91674057, 70715.18415549, 70974.62662114,
71196.30325009, 71454.11198298, 71711.11182052, 71930.70839187,
72186.10608653, 72440.71045452, 72658.26693929, 72911.29961567,
73163.55400328, 73379.10893932, 73629.82119345, 73879.76950962,
74093.36015524, 74341.79489098, 74589.47977345, 74801.14208574,
75047.34088887, 75292.80336019, 75502.57231563, 75746.57534655,
75989.85512801, 76197.76442711, 76439.61054906, 76680.7461448 ,
76886.82833966, 77126.55527427, 77365.58400364, 77569.8706008 ,
77807.51485289, 78044.47276422, 78246.99446109, 78482.59130139,
78717.51337097, 78918.29986461, 79151.88352716, 79384.80363257,
79583.88365476, 79815.48746062, 80046.43841324, 80243.83985851,
80473.4959342 , 80702.50977538])
"""Synthetic MMS-1 X position for prediction at bow shock"""
SYNTH_YPOS = np.array([-6242.531374 , -6141.43603983, -6054.73851884, -5953.54260001,
-5852.30035979, -5765.48374707, -5664.15672177, -5562.79111177,
-5475.87538835, -5374.44022431, -5272.97399371, -5185.97841963,
-5084.45744369, -4982.91267323, -4895.85591308, -4794.27072076,
-4692.66879919, -4605.56897563, -4503.9403974 , -4402.3019605 ,
-4315.17661655, -4213.52495494, -4111.87001052, -4024.73625516,
-3923.08129418, -3821.42952425, -3734.30395957, -3632.66493195,
-3531.03538868, -3443.93418239, -3342.32979659, -3240.74102942,
-3153.67997007, -3052.12846626, -2950.598573 , -2863.59314841,
-2762.11237883, -2660.65906956, -2573.72426935, -2472.33166917,
-2370.97225436, -2284.122821 , -2182.83554019, -2081.58702282,
-1994.83737245, -1893.67203926, -1792.55094402, -1705.91518784,
-1604.88808488, -1503.91065025])
"""Synthetic MMS-1 Y position for prediction at bow shock"""
SYNTH_ZPOS = np.array([1428.22663895, 1404.9257232 , 1384.88999865, 1361.43852111,
1337.90827722, 1317.66995298, 1293.97504102, 1270.1942372 ,
1249.73509015, 1225.77542686, 1201.72276543, 1181.02453785,
1156.77888242, 1132.43315908, 1111.47770227, 1086.92488254,
1062.26489785, 1041.03400953, 1016.15277964, 991.1573987 ,
969.6329908 , 944.40232752, 919.05058092, 897.21474963,
871.61379271, 845.88483511, 823.71973863, 797.7278038 ,
771.60099544, 749.08907437, 722.68572634, 696.14074783,
673.26462524, 646.42970201, 619.4464814 , 596.18909891,
568.90278803, 541.46164749, 517.80622717, 490.04905226,
462.13065323, 438.0606711 , 409.81356785, 381.39897532,
356.89834051, 328.14264026, 299.21335127, 274.26632698,
244.98385268, 215.52178328])
"""Synthetic MMS-1 Z position for prediction at bow shock"""
SYNTH_POS = np.array([SYNTH_XPOS, SYNTH_YPOS, SYNTH_ZPOS]).T
"""Synthetic MMS-1 orbit for prediction at bow shock"""