-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils_nmf.py
130 lines (102 loc) · 3.26 KB
/
utils_nmf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import hashlib
import multiprocessing
import os
import os.path as op
import shutil
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pandas as pd
# Absolute directory holding this module, with symlinks resolved.
dpath = op.dirname(op.realpath(__file__))
# Default worker count: the CPU count reported by multiprocessing.cpu_count().
DEF_NR_THREADS = multiprocessing.cpu_count()
# Field separator used when reading input tables (see load_table).
SEP = ','
#################################
# #
# General methods #
# #
#################################
def eprint(*args, **kargs):
    """Print ``args`` to standard error; other keywords go to print()."""
    stream = sys.stderr  # looked up per call so stderr redirection works
    print(*args, file=stream, **kargs)
def validate_file(fpath):
    """
    Return ``fpath`` unchanged if it names an existing regular file.

    If it does not, print an error to stderr and terminate the program
    with exit status 1.
    """
    if not op.isfile(fpath):
        eprint('Invalid file', fpath)
        # Non-zero status so shell callers can detect the failure; sys.exit
        # (unlike the site-module `exit`) is always available.
        sys.exit(1)
    return fpath
def check_executable(cmd):
    """
    Abort the program unless ``cmd`` is an executable file found on PATH.

    :param cmd: name of the executable to look for.
    Prints an error to stderr and exits with status 1 when it is not found.
    """
    # shutil.which handles an unset PATH, uses os.pathsep rather than a
    # hard-coded ':', and ignores directories that merely have the X bit.
    if shutil.which(cmd) is None:
        eprint(f'executable {cmd} not found in PATH')
        sys.exit(1)
def pat2name(pat):
    """
    Return the base name of ``pat`` with up to two extensions removed.

    e.g. "/some/dir/sample.pat.gz" -> "sample".
    """
    stem = pat
    for _round in range(2):
        stem = op.splitext(stem)[0]
    return op.basename(stem)
def drop_dup_keep_order(lst):
    """
    Return a list of the items in ``lst`` with duplicates removed.

    The first occurrence of each item is kept, in original order.
    Items must be hashable.
    """
    # dicts preserve insertion order, so the keys are the first occurrences
    return list(dict.fromkeys(lst))
def remove_files(files):
    """
    Delete each path in ``files`` that is an existing regular file.

    Missing paths and directories are skipped silently.
    """
    # filter() is lazy, so each isfile check runs right before its removal
    for path in filter(op.isfile, files):
        os.remove(path)
def mkdir_p(dirpath):
    """
    Create ``dirpath`` and any missing parent directories, like ``mkdir -p``.

    :param dirpath: directory to create; an existing directory is left as-is.
    :return: ``dirpath``, so the call can be used inline.
    :raises FileExistsError: if ``dirpath`` exists but is not a directory.
    """
    # makedirs creates intermediate directories (os.mkdir did not), and
    # exist_ok avoids the race between an isdir() check and the mkdir call.
    os.makedirs(dirpath, exist_ok=True)
    return dirpath
def dump_df(fpath, df, verbose=True):
    """
    Write ``df`` to ``fpath`` as CSV, formatting floats with 5 decimals.

    :param fpath: output path.
    :param df: pandas DataFrame to write (the index is included).
    :param verbose: if True, report the written path on stderr.
    """
    df.to_csv(path_or_buf=fpath, float_format='%.5f')
    if not verbose:
        return
    eprint(f'dumped {fpath}')
#################################
# #
# NMF Specific logic #
# #
#################################
def load_table(table_path, norm_cols=False):
    """
    Load a SEP-delimited table whose first column holds feature names.

    The first column is renamed to 'feature' and becomes the index; all
    remaining columns are kept as data columns.

    :param table_path: path to the table (checked with validate_file).
    :param norm_cols: if True, divide every data column by its own sum.
    :return: pandas DataFrame indexed by 'feature'.
    Terminates via validate_file if the file is missing, and with exit
    status 1 if the table has fewer than 3 columns.
    """
    validate_file(table_path)
    df = pd.read_csv(table_path, sep=SEP, index_col=None)
    # require the feature column plus at least two data columns
    if df.shape[1] < 3:
        eprint(f'Invalid table: {table_path}. Too few columns ({df.shape[1]})')
        sys.exit(1)
    df.columns = ['feature'] + list(df.columns)[1:]
    df = df.set_index('feature')
    if norm_cols:
        # column-wise normalisation: each value over its column total
        df = df.div(df.sum(axis=0), axis=1)
    return df
def parse_cut_str(cstr, maxlen):
    """
    Parse a user column selection written in unix `cut` syntax.

    Columns are numbered from 1 in the input; the result is 0-based.
    e.g. "1-4,5,19-" selects columns 1,2,3,4,5,19,...,maxlen.

    Comma-separated items (whitespace is ignored):
        N      a single column, 1 <= N <= maxlen
        N-M    columns N..M inclusive (M >= N; columns past maxlen are dropped)
        N-     columns N..maxlen
        -M     columns 1..M
    "all" (any case) or a lone "-" selects every column.

    :param cstr: selection string, or None to select no columns.
    :param maxlen: number of available columns.
    :return: sorted list of 0-based column indices.
    Invalid input (bad characters, empty or malformed items, column 0, a
    single column past maxlen, a decreasing range, duplicated columns) is
    reported on stderr and terminates the program with exit status 1.
    """
    if cstr is None:
        return []
    cstr = ''.join(cstr.split())  # remove whitespace
    if cstr.lower() == 'all' or cstr == '-':
        return list(range(maxlen))

    def _fail(*lines):
        # report the offending input plus details, then abort
        eprint('Invalid input:', cstr)
        for line in lines:
            eprint(line)
        sys.exit(1)

    # validate characters
    allowed = set('0123456789,-')
    for c in cstr:
        if c not in allowed:
            _fail('Only digits and [,-] characters are allowed',
                  f'found: "{c}"')

    # Parse every comma-separated item on its own, so open-ended ranges
    # may appear anywhere in the list (e.g. "-3,7-" or "4-,1").
    include = []
    for stub in cstr.split(','):
        if not stub or stub.count('-') > 1:
            _fail(f'malformed column specification: "{stub}"')
        if '-' not in stub:
            col = int(stub)
            # column 0 would otherwise become index -1 (the last column)
            if not 1 <= col <= maxlen:
                _fail(f'column {col} is out of range (1-{maxlen})')
            include.append(col)
            continue
        first, last = stub.split('-')
        start = int(first) if first else 1
        end = int(last) if last else maxlen
        if start < 1:
            _fail('columns are numbered from 1')
        if last and end < start:
            _fail('range must be increasing!')
        # ranges reaching past the last column are clamped to maxlen
        include.extend(range(start, min(end, maxlen) + 1))

    if len(set(include)) != len(include):
        eprint('Invalid input: duplicated columns', cstr)
        sys.exit(1)
    return sorted(x - 1 for x in include)