Skip to content
Snippets Groups Projects
Commit 803804a6 authored by Anton Kullberg's avatar Anton Kullberg
Browse files

py: added saving/loading & MHT trajectory recreation

parent 6df4fad5
No related branches found
No related tags found
No related merge requests found
import copy
import numpy as np
import murty as murty_
def match_tracks_to_ground_truth(tracks, ground_truth):
"""Matches tracks to ground truth naively.
......@@ -41,3 +42,92 @@ def match_tracks_to_ground_truth(tracks, ground_truth):
matches[track['identity']] = key
ormse = rmse
return matches
def recreate_trajectories(hypothesis, marginalize=True):
    """Recreate per-target trajectories from an MHT hypothesis history.

    Parameters
    ----------
    hypothesis : dict
        Maps a time index ``t`` to a list of hypotheses. Each hypothesis is a
        dict with a 'probability' and a 'tracks' list; each track is a dict
        with at least 'identity', 'x', 'P', 'Lt', 'associations' and 'stage'.
        NOTE(review): structure inferred from usage below — confirm against
        the tracker that produces this dict.
    marginalize : bool, optional
        If True (default), the state, covariance and track score at each time
        are the hypothesis-probability-weighted averages over all hypotheses.
        If False, the values from the single most probable hypothesis are used.

    Returns
    -------
    tuple of dict_values
        ``(tracks, confirmed_tracks)`` — all recreated tracks, and the subset
        whose 'stage' reached 'confirmed' at some time step.
    """
    confirmed_tracks = dict()
    tracks = dict()
    for t, hyp_t in hypothesis.items():
        t_tracks = dict()
        t_prob = []
        for hyp in hyp_t:
            t_prob.append(hyp['probability'])
            for track in hyp['tracks']:
                # Restructure the tracks (group by identity) to easily marginalize
                if track['identity'] not in t_tracks.keys():
                    t_tracks[track['identity']] = [track]
                else:  # If a track exists in more than one hypothesis it exists in all of them
                    t_tracks[track['identity']].append(track)
        t_prob = np.array(t_prob)
        # Marginalization over the hypotheses at this time step
        for track_list in t_tracks.values():
            # Identify the track in the previous track list
            track_identity = track_list[0]['identity']
            associations = [association for track in track_list for association in track['associations']] # Get associations in the different hypothesis
            # True when this identity appears in every hypothesis at time t
            mult_hyp = len(track_list) == t_prob.size
            if mult_hyp:
                # Compute the track score
                Lt = np.array([track['Lt'] for track in track_list])
                # NOTE(review): uses x[0]/P[0] here but x[-1]/P[-1] in the
                # single-hypothesis branch below — presumably picking the state
                # estimate at time t in both cases; confirm the list layout.
                x = np.vstack([track['x'][0] for track in track_list])
                P = [track['P'][0] for track in track_list]
                if marginalize:
                    # Probability-weighted mean state and score
                    xhat = t_prob@x
                    Lt = Lt@t_prob
                    # Compute the state error covariance (spread-of-means term)
                    err = (x-xhat[None, :]).T
                    Pk = np.stack([col[:,None]@col[None,:] for col in err.T])
                    Phat = np.tensordot(np.stack(P)+Pk, t_prob, (0, 0)) # Dot product over axis 0
                else:
                    # Take the state/covariance/score of the most probable hypothesis
                    most_prob = np.argmax(t_prob)
                    xhat = x[most_prob, :]
                    Phat = P[most_prob]
                    Lt = Lt[most_prob]
            else:
                # Track only present in some hypotheses: use its latest estimate
                xhat = track_list[0]['x'][-1]
                Phat = track_list[0]['P'][-1]
                Lt = track_list[0]['Lt']
            if track_identity in tracks.keys():
                # Existing trajectory: append this time step's estimate
                tracks[track_identity]['x'].append(xhat)
                tracks[track_identity]['P'].append(Phat)
                tracks[track_identity]['Lt'] = Lt
                if t in associations:
                    tracks[track_identity]['associations'].append(t)
                if mult_hyp:
                    stages = [track['stage'] for track in track_list]
                    # The stage is assumed to be the one of the most probable hypothesis
                    tracks[track_identity]['stage'] = stages[np.argmax(t_prob)]
                else:
                    tracks[track_identity]['stage'] = track_list[0]['stage']
            else:
                # First time this identity is seen: start a new trajectory
                track = copy.deepcopy(track_list[0])
                track['x'] = [xhat]
                track['P'] = [Phat]
                if t in associations and t not in track['associations']:
                    track['associations'].append(t)
                track['Lt'] = Lt
                tracks[track_identity] = track
            # Record the track once it first reaches the 'confirmed' stage
            if tracks[track_identity]['stage'] == 'confirmed':
                if track_identity not in confirmed_tracks.keys():
                    confirmed_tracks[track_identity] = tracks[track_identity]
    return tracks.values(), confirmed_tracks.values()
def murty(C):
    """Generate assignment solutions for cost matrix ``C`` in increasing cost
    order, using Murty's algorithm (wraps the ``murty`` extension module).

    Yields
    ------
    (cost, sol)
        The cost and the corresponding assignment, until the underlying
        generator is exhausted.
    """
    solver = murty_.Murty(C)
    ok, cost, sol = solver.draw()
    while ok:
        yield cost, sol
        ok, cost, sol = solver.draw()
def save_result(filename, result):
    """Save a tracking result to disk as a pickled ``.npy`` file.

    The (non-serializable) 'filt' filter object of each track is stripped
    from a *copy* of the result before saving, so only plain data is written
    and the caller's ``result`` — including its filter objects — is left
    untouched. The tracks are materialized as a list, which also makes the
    result picklable when 'tracks' is a ``dict_values`` view.

    Parameters
    ----------
    filename : str
        Target path; numpy appends '.npy' if it is missing.
    result : dict
        Result dict containing a 'tracks' iterable of track dicts.
    """
    sanitized = dict(result)
    # Drop the filter objects so that only data is saved.
    sanitized['tracks'] = [
        {key: val for key, val in track.items() if key != 'filt'}
        for track in result['tracks']
    ]
    np.save(filename, sanitized)
def load_result(filename):
    """Load a tracking result previously stored with ``save_result``.

    Parameters
    ----------
    filename : str
        Path to the saved result; the '.npy' suffix may be omitted.

    Returns
    -------
    dict
        The unpickled result dict.
    """
    # endswith (not substring containment) so names like 'run.npy.bak'
    # or 'my.npy_run' are not mistaken for already-suffixed files.
    if not filename.endswith('.npy'):
        filename += '.npy'
    # allow_pickle is required because the result is an arbitrary dict.
    result = np.load(filename, allow_pickle=True)
    return result.item()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment