diff --git a/src/utility.py b/src/utility.py index 28f3b50c733ff193f93a13958f6ef4da506236aa..dae556d95075e0380d7315b687aef75c79e1588c 100644 --- a/src/utility.py +++ b/src/utility.py @@ -1,5 +1,6 @@ +import copy import numpy as np - +import murty as murty_ def match_tracks_to_ground_truth(tracks, ground_truth): """Matches tracks to ground truth naively. @@ -41,3 +42,92 @@ def match_tracks_to_ground_truth(tracks, ground_truth): matches[track['identity']] = key ormse = rmse return matches + +def recreate_trajectories(hypothesis, marginalize=True): + confirmed_tracks = dict() + tracks = dict() + for t, hyp_t in hypothesis.items(): + t_tracks = dict() + t_prob = [] + for hyp in hyp_t: + t_prob.append(hyp['probability']) + for track in hyp['tracks']: + # Restructure the tracks to easily marginalize + if track['identity'] not in t_tracks.keys(): + t_tracks[track['identity']] = [track] + else:# If a track exists in more than one hypothesis it exists in all of them + t_tracks[track['identity']].append(track) + t_prob = np.array(t_prob) + # Marginalization over the hypothesis + for track_list in t_tracks.values(): + # Identify the track in the previous track list + track_identity = track_list[0]['identity'] + associations = [association for track in track_list for association in track['associations']] # Get associations in the different hypothesis + mult_hyp = len(track_list) == t_prob.size + if mult_hyp: + # Compute the track score + Lt = np.array([track['Lt'] for track in track_list]) + x = np.vstack([track['x'][0] for track in track_list]) + P = [track['P'][0] for track in track_list] + if marginalize: + xhat = t_prob@x + Lt = Lt@t_prob + # Compute the state error covariance + err = (x-xhat[None, :]).T + Pk = np.stack([col[:,None]@col[None,:] for col in err.T]) + Phat = np.tensordot(np.stack(P)+Pk, t_prob, (0, 0)) # Dot product over axis 0 + else: + most_prob = np.argmax(t_prob) + xhat = x[most_prob, :] + Phat = P[most_prob] + Lt = Lt[most_prob] + else: + xhat = track_list[0]['x'][-1] + Phat = track_list[0]['P'][-1] + Lt = track_list[0]['Lt'] + + if track_identity in tracks.keys(): + tracks[track_identity]['x'].append(xhat) + tracks[track_identity]['P'].append(Phat) + tracks[track_identity]['Lt'] = Lt + if t in associations: + tracks[track_identity]['associations'].append(t) + if mult_hyp: + stages = [track['stage'] for track in track_list] + # The stage is assumed to be the most probable hypothesis + tracks[track_identity]['stage'] = stages[np.argmax(t_prob)] + else: + tracks[track_identity]['stage'] = track_list[0]['stage'] + else: + track = copy.deepcopy(track_list[0]) + track['x'] = [xhat] + track['P'] = [Phat] + if t in associations and t not in track['associations']: + track['associations'].append(t) + track['Lt'] = Lt + tracks[track_identity] = track + if tracks[track_identity]['stage'] == 'confirmed': + if track_identity not in confirmed_tracks.keys(): + confirmed_tracks[track_identity] = tracks[track_identity] + return tracks.values(), confirmed_tracks.values() + +def murty(C): + """Algorithm due to Murty.""" + mgen = murty_.Murty(C) + while True: + ok, cost, sol = mgen.draw() + if not ok: + return None + yield cost, sol + +# Save result +def save_result(filename, result): + for track in result['tracks']: + track.pop('filt', None) # Pop the filter object so it is only data that is saved. + np.save(filename, result) + +def load_result(filename): + if '.npy' not in filename: + filename += '.npy' + result = np.load(filename, allow_pickle=True) + return result.item()