Skip to content
Snippets Groups Projects
Commit 3e71b84d authored by Anton Kullberg's avatar Anton Kullberg
Browse files

py: updated data generation for ex2

parent 69f069ac
No related branches found
No related tags found
No related merge requests found
import numpy as np
def generate_data(trajectory, sensor_model, clutter_model, rng=None):
def generate_data(trajectories, sensor_model, clutter_model, rng=None):
"""Simulates measurements along a state trajectory according to a sensor and clutter model.
The function assumes Gaussian white noise affecting the measurements.
Parameters
----------
trajectory : numpy.ndarray
A nx by N array where nx is the state dimension and N is the number of time steps
trajectories : dict of numpy.ndarrays
A dict with entries with nx by N arrays where nx is the state dimension and N is the number of time steps. Sample time T=1 is assumed.
sensor_model : dict
A dictionary with the following entries:
h : callable
......@@ -26,23 +26,28 @@ def generate_data(trajectory, sensor_model, clutter_model, rng=None):
rng : Generator
A numpy random number generator. Can be constructed by e.g. np.random.default_rng()
Returns
-------
list of numpy.ndarray
Each list item is a numpy.ndarray with zero or more measurements (ny by x)
"""
N = max([T.shape[1] for key, T in trajectories.items()]) # Maximum length of a trajectory interesting for this purpose
if rng is None:
rng = np.random.default_rng()
measurements = []
ny = sensor_model['h'](trajectory[:, 0]).size
for state in trajectory.T:#.T to aid for-loop
ny = sensor_model['h'](trajectories[next(iter(trajectories))][:, 0]).size # Get the dimensionality of the measurements
for n in range(N):
# Determine amount of clutter this time
nclutter = rng.poisson(lam=clutter_model['lam'])
trajs = [T for key, T in trajectories.items() if n<T.shape[1]] # Figure out what trajectories are active right now
Ntrajs = len(trajs) # Calc. number of trajectories present in the current time step
# Initialize an array w/ the number of measurements this time step
cur_measurements = np.empty((ny, nclutter+1))
cur_measurements = np.empty((ny, nclutter+Ntrajs))
cur_measurements[:, :] = np.NaN
if nclutter != 0:
# Calc. clutter states
......@@ -53,15 +58,16 @@ def generate_data(trajectory, sensor_model, clutter_model, rng=None):
high=clutter_model['volume']['ymax'], size=(nclutter,))
]).reshape(-1, nclutter)
cur_measurements[:, :-1] = (sensor_model['h'](clutter_states)+\
cur_measurements[:, :nclutter] = (sensor_model['h'](clutter_states)+\
rng.multivariate_normal(mean=np.zeros((ny,)), cov=sensor_model['R'], size=(nclutter)).squeeze().T).reshape(-1, nclutter)
# Generate measurement of target (possibly)
if rng.uniform() <= sensor_model['PD']:
y = sensor_model['h'](state)+\
rng.multivariate_normal(mean=np.zeros((ny,)), cov=sensor_model['R'])
cur_measurements[:, -1] = y.flatten() # Add actual observation to array
else:
cur_measurements = cur_measurements[:, :-1]
# Generate measurement of target(s) (possibly)
for nt, traj in enumerate(trajs):
if rng.uniform() <= sensor_model['PD']:
y = sensor_model['h'](traj[:, n])+\
rng.multivariate_normal(mean=np.zeros((ny,)), cov=sensor_model['R'])
cur_measurements[:, nclutter+nt] = y.flatten() # Add actual observation to array
cur_measurements = cur_measurements[~np.isnan(cur_measurements)].reshape(ny, -1) # Remove nan measurements (i.e. targets that did not generate a measurement)
measurements.append(cur_measurements)
return measurements
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment