-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcond_consistency_across_rats_all.py
185 lines (151 loc) · 7.74 KB
/
cond_consistency_across_rats_all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import sys
import os
import numpy as np
import pandas as pd
import torch
import random
from datetime import datetime
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import r2_score
from scipy import stats
import gc
import argparse
import joblib as jl
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from cebra import CEBRA
from hold_out import hold_out
from CSUS_score import CSUS_score
from consistency import consistency
import glob
# Make the Hannahs-CEBRAs project code (cluster and local checkouts) and the
# CEBRA install importable regardless of which machine runs this script.
for _extra_path in (
    '/home/hsw967/Programming/Hannahs-CEBRAs',
    '/home/hsw967/Programming/Hannahs-CEBRAs/scripts',
    '/Users/Hannah/Programming/Hannahs-CEBRAs',
    '/Users/Hannah/anaconda3/envs/CEBRA/lib/python3.8/site-packages/cebra',
):
    sys.path.append(_extra_path)

# This script measures consistency across environments for the same rat.
# Global rat IDs: one entry per animal; order must match the order of the
# trace_data_* lists passed to main().
rat_ids = ['0222', '0307', '0313', '314', '0816']
def save_results(results, base_filename):
    """Write consistency scores to a timestamped CSV file.

    Parameters
    ----------
    results : tuple
        ``(scores, pairs, ids)`` as returned by ``consistency()``. Only the
        scores and their embedding-pair labels are written, one
        ``"<pair>,<score>"`` row per entry; the ids are not persisted.
    base_filename : str
        Prefix for the output file; ``_<timestamp>.csv`` is appended so
        repeated runs never overwrite each other.
    """
    scores_runs, pairs_runs, _ids_runs = results  # ids intentionally unused
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f"{base_filename}_{current_time}.csv"
    with open(filename, 'w') as f:
        for score, pair in zip(scores_runs, pairs_runs):
            pair_str = '-'.join(map(str, pair))  # e.g. (0, 1) -> "0-1"
            f.write(f"{pair_str},{score}\n")
    # Bug fix: the original message contained no placeholder, so the actual
    # output path was never reported.
    print(f"Results saved to {filename}")
def load_files(model_pattern, dimension, base_dir, divisor):
    """Collect saved CEBRA model files for every rat in ``rat_ids``.

    Globs ``<base_dir>/rat<ID>/cebra_variables/models/`` for files named
    ``<model_pattern><dimension>_*_<divisor>.pt`` and returns all matches,
    grouped per rat in ``rat_ids`` order.
    """
    return [
        match
        for rid in rat_ids
        for match in glob.glob(
            f"{base_dir}/rat{rid}/cebra_variables/models/"
            f"{model_pattern}{dimension}_*_{divisor}.pt"
        )
    ]
def calculate_all_models_consistency(model_data_pairs):
    """Embed each dataset with its saved model and score cross-model consistency.

    Parameters
    ----------
    model_data_pairs : list of (str, array-like)
        Each entry pairs a CEBRA model file path with the trace data to embed
        through that model.

    Returns
    -------
    tuple or None
        ``(scores, pairs, ids)`` from ``consistency()`` over all embeddings,
        or ``None`` when ``model_data_pairs`` is empty.
    """
    transformations = []
    for filename, data in model_data_pairs:
        # Bug fix: the original message contained no placeholder, so the
        # model path being loaded was never reported.
        print(f"Loading model from: {filename}")
        model = CEBRA.load(filename)
        transformations.append(model.transform(data))
    if not transformations:
        print("No transformations to process.")
        return None
    scores, pairs, ids = consistency(transformations)
    return scores, pairs, ids
def _select_trace(model_pattern, index, trace_data_A1, trace_data_An,
                  trace_data_B1, trace_data_B2):
    """Return the one-element data list matching the model pattern, or None.

    Checked in the same order as the original if/elif chain; shuffled
    patterns (e.g. "modelA1_shuffled_dim") match via substring just like
    their unshuffled counterparts.
    """
    if "A1" in model_pattern:
        return [trace_data_A1[index]]
    if "An" in model_pattern:
        return [trace_data_An[index]]
    if "B1" in model_pattern:
        return [trace_data_B1[index]]
    if "B2" in model_pattern:
        return [trace_data_B2[index]]
    return None  # unrecognized pattern: skip it


def _process_dimension(model_patterns, dimension, divisor, base_dir,
                       trace_data_A1, trace_data_An,
                       trace_data_B1, trace_data_B2):
    """Gather (model file, data) pairs for one dimension/divisor and score them."""
    all_model_data_pairs = []
    for model_pattern in model_patterns:
        files = load_files(model_pattern, dimension, base_dir, divisor)
        if not files:
            print(f"No files loaded for pattern {model_pattern} and dimension {dimension}.")
            continue
        for file in files:
            # The rat id is encoded in the path as ".../rat<ID>/..."
            rat_id = file.split('/rat')[1].split('/')[0]
            index = rat_ids.index(rat_id)
            data_list = _select_trace(model_pattern, index,
                                      trace_data_A1, trace_data_An,
                                      trace_data_B1, trace_data_B2)
            if data_list is None:
                continue
            all_model_data_pairs.extend((file, data) for data in data_list)
    if not all_model_data_pairs:
        print(f"No model-data pairs to process for dimension {dimension}.")
        return
    results = calculate_all_models_consistency(all_model_data_pairs)
    if results:
        save_results(results, f"{dimension}_{divisor}")
    else:
        print(f"No results to save for dimension {dimension}.")


def main(trace_data_A1, trace_data_An, trace_data_B1, trace_data_B2):
    """Run cross-rat model consistency for every pattern/dimension/divisor.

    Parameters
    ----------
    trace_data_A1, trace_data_An, trace_data_B1, trace_data_B2 : list
        One trace array per rat, in the same order as ``rat_ids``.

    Models are discovered under the current working directory; results for
    each (dimension, divisor) combination are written to a timestamped CSV.
    """
    base_dir = os.getcwd()
    print(f"Using base directory: {base_dir}")
    model_patterns = ["modelA1_dim", "modelAn_dim", "modelB1_dim", "modelB2_dim",
                      "modelA1_shuffled_dim", "modelAn_shuffled_dim",
                      "modelB1_shuffled_dim", "modelB2_shuffled_dim"]
    dimensions = ["2", "3", "5", "7", "10"]
    # The original body duplicated the same ~35-line loop verbatim for the
    # "div2" and "div5" divisors; both passes now share _process_dimension,
    # preserving the original processing order (all dimensions for div2,
    # then all dimensions for div5).
    for divisor in ("div2", "div5"):
        for dimension in dimensions:
            _process_dimension(model_patterns, dimension, divisor, base_dir,
                               trace_data_A1, trace_data_An,
                               trace_data_B1, trace_data_B2)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the CEBRA model evaluation.")
    # One A1/An/B1/B2 trace file per rat (5 rats -> argument indices 1..5,
    # matching the order of the global rat_ids list).
    for i in range(1, 6):
        parser.add_argument(f"--traceR{i}A1", required=True, help=f"File path for traceR{i}A data.")
        parser.add_argument(f"--traceR{i}An", required=True, help=f"File path for traceR{i}A data.")
        parser.add_argument(f"--traceR{i}B1", required=True, help=f"File path for traceR{i}B data.")
        parser.add_argument(f"--traceR{i}B2", required=True, help=f"File path for traceR{i}B data.")
    args = parser.parse_args()

    def load_data(path):
        """Load one trace-data file from disk.

        Bug fix: the original script called ``load_data`` without ever
        defining or importing it, so execution raised NameError before any
        work started. joblib is the most plausible loader given the existing
        ``import joblib as jl`` — NOTE(review): confirm against how the trace
        files were actually saved (e.g. numpy vs joblib).
        """
        return jl.load(path)

    trace_data_A1 = [load_data(args.__dict__[f'traceR{i}A1']) for i in range(1, 6)]
    trace_data_An = [load_data(args.__dict__[f'traceR{i}An']) for i in range(1, 6)]
    trace_data_B1 = [load_data(args.__dict__[f'traceR{i}B1']) for i in range(1, 6)]
    trace_data_B2 = [load_data(args.__dict__[f'traceR{i}B2']) for i in range(1, 6)]
    main(trace_data_A1, trace_data_An, trace_data_B1, trace_data_B2)