Skip to content

Commit 738ffab

Browse files
authored
Merge pull request #3 from yihong0618/main
feat: support cpu
2 parents 5c4a63c + ab30fcf commit 738ffab

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

model_torch.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,11 +238,13 @@ def loadImages(folder):
238238
if __name__ == "__main__":
239239
model = res_skip()
240240
model.load_state_dict(torch.load('erika.pth'))
241-
242-
model.cuda()
241+
is_cuda = torch.cuda.is_available()
242+
if is_cuda:
243+
model.cuda()
244+
else:
245+
model.cpu()
243246
model.eval()
244247

245-
246248
filelists = loadImages(sys.argv[1])
247249

248250
with torch.no_grad():
@@ -255,8 +257,11 @@ def loadImages(folder):
255257
# manually construct a batch. You can change it based on your usecases.
256258
patch = np.ones((1,1,rows,cols),dtype="float32")
257259
patch[0,0,0:src.shape[0],0:src.shape[1]] = src
258-
259-
tensor = torch.from_numpy(patch).cuda()
260+
261+
if is_cuda:
262+
tensor = torch.from_numpy(patch).cuda()
263+
else:
264+
tensor = torch.from_numpy(patch).cpu()
260265
y = model(tensor)
261266
print(imname, torch.max(y), torch.min(y))
262267

@@ -266,8 +271,3 @@ def loadImages(folder):
266271

267272
head, tail = os.path.split(imname)
268273
cv2.imwrite(sys.argv[2]+"/"+tail.replace(".jpg",".png"),yc[0:src.shape[0],0:src.shape[1]])
269-
270-
271-
272-
273-

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch==1.9.1
2+
opencv-python

0 commit comments

Comments
 (0)