Skip to content
Snippets Groups Projects
Commit 145415f7 authored by Joaquin Rives Gambin's avatar Joaquin Rives Gambin
Browse files

reshape_segmented_arrays function added to preprocess_and_segmentaion.py

parent 9d10ad10
No related branches found
No related tags found
No related merge requests found
......@@ -5,8 +5,6 @@ import re
import matplotlib.pyplot as plt
from preprocessor import preprocess_input_data
data_dir = 'sample_of_data/Training_WFDB'
fs = 500 # Hz
lead_labels = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
......@@ -84,12 +82,13 @@ def window_stack(a, win_width, overlap):
def segmenting_data(dict_of_data, seg_width, overlap_perc):
    """Segment every lead signal of one recording into overlapping windows.

    Parameters
    ----------
    dict_of_data : dict
        One recording: lead-name -> 1-D signal array, plus an 'info'
        metadata entry that must not be windowed.
    seg_width : int
        Window width in samples, forwarded to `window_stack`.
    overlap_perc : float
        Overlap fraction between consecutive windows, forwarded to
        `window_stack`.

    Returns
    -------
    dict
        Lead-name -> stacked window array, plus 'info' repeated once per
        segment so labels stay aligned with the windows.
    """
    segmented_signals = {}
    ignore_key = ['info']
    # Window every signal entry; dict keys-view minus a list yields a set
    # of the remaining keys.
    for key in dict_of_data.keys() - ignore_key:
        segmented_signals[key] = window_stack(dict_of_data[key], seg_width, overlap=overlap_perc)
    # Add the info/label back to the dict, repeated once per segment.
    # NOTE: the earlier duplicate assignment using lead_labels[0] was dead
    # code (immediately overwritten) and has been removed; this assumes all
    # leads produce the same number of segments -- TODO confirm.
    segmented_signals['info'] = np.repeat(dict_of_data['info'], len(segmented_signals[lead_labels[1]]))
    return segmented_signals
......@@ -108,8 +107,167 @@ def segment_all_dict_data(data_dict, seg_width, overlap_perc):
return segmented_dict_of_data
# def mad(my_segment, theta=10):
# my_median = np.median(my_segment)
#
# ########################### find the outliers
#
# MedianAD = theta * (np.median(np.abs(my_segment - my_median)))
#
# MedianAD_flag = np.abs(my_segment - my_median)
#
# outliers = np.where(MedianAD_flag > MedianAD, 1, 0)
#
# ########################## get sign of the data
#
# sign_of_my_segment = np.where(my_segment > 0, 1, -1)
#
# ########################## replace the ones positive with my_median and the negative ones with -my_median
#
# cleaned_segment = my_segment.copy()
#
# outliers_to_replace = outliers * sign_of_my_segment
#
# cleaned_segment[np.where(outliers_to_replace == +1)] = abs(MedianAD)
#
# cleaned_segment[np.where(outliers_to_replace == -1)] = -abs(MedianAD)
#
# #########################
#
# return cleaned_segment
def reshape_segmented_arrays(input_dict, shuffle_IDs=True, shuffle_segments=True,  # outlier_rejection_flag=True,
                             segment_standardization_flag=True):
    """Flatten a dict of per-subject segmented recordings into stacked arrays.

    Parameters
    ----------
    input_dict : dict
        subject-ID -> {lead-name: (n_segments, seg_width) array, ...,
        'info': per-segment label array}.  Every non-'info' entry is
        treated as one channel.
    shuffle_IDs : bool
        Shuffle the order of subjects before concatenating.
    shuffle_segments : bool
        Shuffle the segment order within each subject.
    segment_standardization_flag : bool
        Z-score every (segment, channel) trace independently.

    Returns
    -------
    tuple of np.ndarray
        (array_of_segments (total_segs, seg_width, n_leads),
         array_of_labels (total_segs, 1),
         array_of_IDs (total_segs, 1))
    """
    from random import shuffle

    list_of_swapped_stack = []
    list_of_ID_arrays = []
    list_of_label_arrays = []

    for key in input_dict.keys():
        print(key)
        dict_data = input_dict[key]
        ID = key
        # Every entry except the metadata is one channel's segment matrix.
        data_list = [v for k, v in dict_data.items() if k != 'info']
        # -> (n_leads, n_segments, seg_width)
        data_stacked_array = np.stack(data_list, axis=0)

        if shuffle_segments:
            # Shuffle WITHOUT replacement.  The original used
            # np.random.randint, which samples with replacement and so
            # silently duplicated some segments while dropping others.
            random_indices = np.random.permutation(data_stacked_array.shape[1])
            data_stacked_array = data_stacked_array[:, random_indices, :]

        # Reorder axes to (n_segments, seg_width, n_leads).
        swaped_stack = np.swapaxes(np.swapaxes(data_stacked_array, 0, 2), 0, 1)

        # One ID per segment; labels are already per-segment in 'info'.
        ID_for_segments = np.repeat(ID, swaped_stack.shape[0])
        label_for_segments = dict_data['info']

        list_of_swapped_stack.append(swaped_stack)
        list_of_ID_arrays.append(ID_for_segments)
        list_of_label_arrays.append(label_for_segments)

    # Shuffle the order of subjects consistently across all three lists.
    if shuffle_IDs:
        perm = list(range(len(list_of_ID_arrays)))
        shuffle(perm)
        list_of_swapped_stack = [list_of_swapped_stack[index] for index in perm]
        list_of_ID_arrays = [list_of_ID_arrays[index] for index in perm]
        list_of_label_arrays = [list_of_label_arrays[index] for index in perm]

    # Stack all subjects along the first (segment) axis.
    array_of_segments = np.concatenate(list_of_swapped_stack, axis=0)
    array_of_IDs = np.concatenate(list_of_ID_arrays, axis=0)[:, np.newaxis]
    array_of_labels = np.concatenate(list_of_label_arrays, axis=0)[:, np.newaxis]

    # Normalize every segment (per channel) to zero mean / unit variance.
    if segment_standardization_flag:
        def segment_standardization(my_segment):
            # Plain z-score, numerically identical to sklearn's
            # StandardScaler (population std, ddof=0; a constant trace has
            # std 0 and is divided by 1, matching _handle_zeros_in_scale).
            # Avoids the per-segment sklearn import/object of the original.
            std = my_segment.std()
            return (my_segment - my_segment.mean()) / (std if std else 1.0)

        print("Standardizing...")
        array_of_segments = np.apply_along_axis(segment_standardization, 1, array_of_segments)

    print('shape of the array of segments is :', array_of_segments.shape)
    print('shape of the array of IDs is :', array_of_IDs.shape)
    print('shape of the array of labels is :', array_of_labels.shape)

    return (array_of_segments, array_of_labels, array_of_IDs)
def plot_segment(inputarray, seg_indx, axis1, axis2):
    """Plot channels axis1:axis2 of segment `seg_indx` and return the figure.

    `inputarray` is indexed as (segment, sample, channel), matching the
    output of reshape_segmented_arrays.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111)
    axes.plot(inputarray[seg_indx, :, axis1:axis2])
    plt.show()
    return figure
if __name__ == '__main__':
data_dir = 'data/Training_WFDB'
# load data
data = load_data(data_dir)
......@@ -119,8 +277,23 @@ if __name__ == '__main__':
# segment signal
data = segment_all_dict_data(data, 500, 0.5)
subj2_leadII = data['A0002']['II']
# reshape to array
arr_of_segments, arr_of_labels, arr_of_IDs = reshape_segmented_arrays(data,
shuffle_IDs=True,
shuffle_segments=True,
# outlier_rejection_flag=False,
segment_standardization_flag=True
)
# Plot segment examples
# plot_segment(arr_of_segments, 0, 0, 1)
# Check labels
from collections import Counter
labels = [i[0]['Dx'] for i in arr_of_labels]
label_count = Counter(labels)
print(label_count)
plt.plot(subj2_leadII[1])
plt.plot(subj2_leadII[2])
plt.show()
\ No newline at end of file
# plt.bar(label_count.keys(), label_count.values())
# plt.show()
......@@ -3,7 +3,7 @@ import copy
FS = 500
lf_filter = 0.5 # Hz
hf_filter = 20 # Hz
hf_filter = 30 # Hz
order_filter = 4
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment