Skip to content

Commit b1d7b74

Browse files
authored
[REF] Use new function to run LDA commands (#587)
* Add new function for running shell commands. * Add test for new utility function.
1 parent fd46f9c commit b1d7b74

File tree

3 files changed

+64
-13
lines changed

3 files changed

+64
-13
lines changed

nimare/annotate/lda.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
"""Topic modeling with latent Dirichlet allocation via MALLET."""
22
import logging
33
import os
4-
import os.path as op
54
import shutil
6-
import subprocess
75

86
import numpy as np
97
import pandas as pd
@@ -12,6 +10,7 @@
1210
from ..base import NiMAREBase
1311
from ..due import due
1412
from ..extract import download_mallet, utils
13+
from ..utils import run_shell_command
1514

1615
LGR = logging.getLogger(__name__)
1716

@@ -73,12 +72,12 @@ def __init__(
7372
self, text_df, text_column="abstract", n_topics=50, n_iters=1000, alpha="auto", beta=0.001
7473
):
7574
mallet_dir = download_mallet()
76-
mallet_bin = op.join(mallet_dir, "bin/mallet")
75+
mallet_bin = os.path.join(mallet_dir, "bin/mallet")
7776

7877
model_dir = utils._get_dataset_dir("mallet_model")
79-
text_dir = op.join(model_dir, "texts")
78+
text_dir = os.path.join(model_dir, "texts")
8079

81-
if not op.isdir(model_dir):
80+
if not os.path.isdir(model_dir):
8281
os.mkdir(model_dir)
8382

8483
if alpha == "auto":
@@ -90,7 +89,7 @@ def __init__(
9089
self.model_dir = model_dir
9190

9291
# Check for presence of text files and convert if necessary
93-
if not op.isdir(text_dir):
92+
if not os.path.isdir(text_dir):
9493
LGR.info("Texts folder not found. Creating text files...")
9594
os.mkdir(text_dir)
9695

@@ -104,11 +103,11 @@ def __init__(
104103

105104
for id_ in text_df["id"].values:
106105
text = text_df.loc[text_df["id"] == id_, text_column].values[0]
107-
with open(op.join(text_dir, str(id_) + ".txt"), "w") as fo:
106+
with open(os.path.join(text_dir, str(id_) + ".txt"), "w") as fo:
108107
fo.write(text)
109108

110109
# Run MALLET topic modeling
111-
LGR.info("Generating topics...")
110+
LGR.info("Compiling MALLET commands...")
112111
import_str = (
113112
f"{mallet_bin} import-dir "
114113
f"--input {text_dir} "
@@ -142,8 +141,9 @@ def fit(self):
142141
p_word_g_topic_ : :obj:`numpy.ndarray`
143142
Probability of each word given a topic
144143
"""
145-
subprocess.call(self.commands_[0], shell=True)
146-
subprocess.call(self.commands_[1], shell=True)
144+
LGR.info("Generating topics...")
145+
run_shell_command(self.commands_[0])
146+
run_shell_command(self.commands_[1])
147147

148148
# Read in and convert doc_topics and topic_keys.
149149
topic_names = [f"topic_{i:03d}" for i in range(self.params["n_topics"])]
@@ -158,7 +158,7 @@ def fit(self):
158158
# on an individual id basis by the weights.
159159
n_cols = (2 * self.params["n_topics"]) + 1
160160
dt_df = pd.read_csv(
161-
op.join(self.model_dir, "doc_topics.txt"),
161+
os.path.join(self.model_dir, "doc_topics.txt"),
162162
delimiter="\t",
163163
skiprows=1,
164164
header=None,
@@ -194,7 +194,7 @@ def fit(self):
194194

195195
# Topic word weights
196196
p_word_g_topic_df = pd.read_csv(
197-
op.join(self.model_dir, "topic_word_weights.txt"),
197+
os.path.join(self.model_dir, "topic_word_weights.txt"),
198198
dtype=str,
199199
keep_default_na=False,
200200
na_values=[],
@@ -213,7 +213,7 @@ def fit(self):
213213
shutil.rmtree(self.model_dir)
214214

215215
def _clean_str(self, string):
216-
return op.basename(op.splitext(string)[0])
216+
return os.path.basename(os.path.splitext(string)[0])
217217

218218
def _get_sort(self, lst):
219219
return [i[0] for i in sorted(enumerate(lst), key=lambda x: x[1])]

nimare/tests/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import os
44
import os.path as op
5+
import time
56

67
import nibabel as nib
78
import numpy as np
@@ -164,3 +165,27 @@ def test_mm2vox():
164165
img = utils.get_template(space="mni152_2mm", mask=None)
165166
aff = img.affine
166167
assert np.array_equal(utils.mm2vox(test, aff), true)
168+
169+
170+
def test_run_shell_command(caplog):
    """Test run_shell_command.

    Covers three behaviors: standard output is forwarded to the log, a
    non-zero exit status raises with the captured standard error in the
    message, and the call blocks until the command has finished.
    """
    with caplog.at_level(logging.INFO):
        utils.run_shell_command("echo 'output'")
    assert "output" in caplog.text

    # Check that the exception is registered as such
    with pytest.raises(Exception) as execinfo:
        utils.run_shell_command("echo 'Error!' 1>&2;exit 64")
    assert "Error!" in str(execinfo.value)

    # Check that the function actually waits until the command completes
    dur = 3
    # Drop records from the earlier commands so the assertions below only
    # look at output produced by this command.
    caplog.clear()
    # Use a monotonic clock: wall-clock adjustments must not affect the
    # measured duration.
    start = time.monotonic()
    with caplog.at_level(logging.INFO):
        # Plain integer seconds: the "s" suffix is a GNU extension to sleep.
        utils.run_shell_command(f"echo 'hi';sleep {dur};echo 'bye'")
    end = time.monotonic()

    assert "hi" in caplog.text
    assert "bye" in caplog.text
    duration = end - start
    assert duration >= dur

nimare/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import os.path as op
77
import re
8+
import subprocess
89
from functools import wraps
910
from tempfile import mkstemp
1011

@@ -935,3 +936,28 @@ def boolean_unmask(data_array, bool_array):
935936
unmasked_data[bool_array] = data_array
936937
unmasked_data = unmasked_data.T
937938
return unmasked_data
939+
940+
941+
def run_shell_command(command, env=None):
    """Run a shell command, logging its standard output line by line.

    Parameters
    ----------
    command : :obj:`str`
        Command line handed to the shell (the process is started with
        ``shell=True``).
    env : :obj:`dict`, optional
        Extra environment variables for the child process. They are layered
        on top of a copy of the current environment, so ``os.environ`` of the
        calling process is never modified. Default is None.

    Raises
    ------
    RuntimeError
        If the command exits with a non-zero return code. The message holds
        the return code, the command, and the captured standard error.
        ``RuntimeError`` derives from ``Exception``, so existing
        ``except Exception`` handlers keep working.
    """
    # Local import: only this function needs it.
    import threading

    # Copy the environment so per-call overrides do not leak into this
    # process (and into every later subprocess).
    merged_env = os.environ.copy()
    if env:
        merged_env.update(env)

    stderr_chunks = []

    def _drain(stream):
        # Runs in a helper thread so the child can never block on a full
        # stderr pipe while stdout is being consumed below.
        stderr_chunks.append(stream.read())

    # The context manager closes both pipes and waits for the process.
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=True,
        env=merged_env,
    ) as process:
        stderr_reader = threading.Thread(
            target=_drain, args=(process.stderr,), daemon=True
        )
        stderr_reader.start()

        # Iterating the pipe yields lines as they arrive and stops at EOF.
        for raw_line in process.stdout:
            # Strip only the line terminator, never real content.
            LGR.info(raw_line.decode("utf-8", errors="replace").rstrip("\r\n"))

        stderr_reader.join()

    if process.returncode != 0:
        stderr_text = b"".join(stderr_chunks).decode("utf-8", errors="replace")
        stderr_text = stderr_text.rstrip("\r\n")
        raise RuntimeError(
            f"Non zero return code: {process.returncode}\n{command}\n\n{stderr_text}"
        )

0 commit comments

Comments
 (0)