diff --git a/test.py b/test.py index 34ab746df1e2cb4c9666289bd5502d6a93397001..0ab2cbe9b07df0092f710f93345dc617d2fe4201 100644 --- a/test.py +++ b/test.py @@ -101,13 +101,12 @@ for s, stain in enumerate(stainsToValidate): for i in range(numberClassesToEvaluate): classEvaluators[s][i].add_example(classInstancePredictionList[i], classInstanceGTList[i]) - prediction = torch.softmax(prediction, 1) STAIN_to_PAS_img = STAIN_to_PAS_img.flip(2) - prediction += torch.softmax(segModel(STAIN_to_PAS_img), 1).flip(2) + prediction += segModel(STAIN_to_PAS_img).flip(2) STAIN_to_PAS_img = STAIN_to_PAS_img.flip(3) - prediction += torch.softmax(segModel(STAIN_to_PAS_img), 1).flip(3).flip(2) + prediction += segModel(STAIN_to_PAS_img).flip(3).flip(2) STAIN_to_PAS_img = STAIN_to_PAS_img.flip(2) - prediction += torch.softmax(segModel(STAIN_to_PAS_img), 1).flip(3) + prediction += segModel(STAIN_to_PAS_img).flip(3) STAIN_to_PAS_img = STAIN_to_PAS_img.flip(3) prediction /= 4.