From d6352a32472e28eb008f3134daaf1f83ee120605 Mon Sep 17 00:00:00 2001 From: James <49045138+ghidracadabra@users.noreply.github.com> Date: Thu, 1 Dec 2022 18:02:28 +0000 Subject: [PATCH] GP-2906_James_exhaust_function_interiors --- .../RandomForestTrainingTask.java | 9 ++++-- .../RandomForestTrainingTaskTest.java | 29 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/Ghidra/Extensions/MachineLearning/src/main/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTask.java b/Ghidra/Extensions/MachineLearning/src/main/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTask.java index a6c539b82f..7313c57f88 100644 --- a/Ghidra/Extensions/MachineLearning/src/main/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTask.java +++ b/Ghidra/Extensions/MachineLearning/src/main/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTask.java @@ -277,15 +277,18 @@ public class RandomForestTrainingTask extends Task { monitor.setMessage( "Selecting " + numEntries * factor + " random addresses within function interiors"); start = System.nanoTime(); - AddressSetView randomFuncInteriors = - RandomSubsetUtils.randomSubset(selectableInteriors, numEntries * factor, monitor); + long numInteriors = numEntries * factor; + + AddressSetView randomFuncInteriors = numInteriors < selectableInteriors.getNumAddresses() + ? RandomSubsetUtils.randomSubset(selectableInteriors, numInteriors, monitor) + : selectableInteriors; end = System.nanoTime(); Msg.info(this, String.format("factor: %d elapsed selecting random interiors: %g seconds", factor, (end - start) / NANOSECONDS_PER_SECOND)); trainingNegative = trainingNegative.union(randomFuncInteriors); if (trainingNegative.isEmpty()) { Msg.showError(this, null, "Data Gathering Error", - "No function interiors in training set"); + "No non-starts in training set for sampling factor " + factor); return null; } if (trainingPositive.intersects(trainingNegative)) { diff --git a/Ghidra/Extensions/MachineLearning/src/test.slow/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTaskTest.java b/Ghidra/Extensions/MachineLearning/src/test.slow/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTaskTest.java index 97d0b94a10..277faf5298 100644 --- a/Ghidra/Extensions/MachineLearning/src/test.slow/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTaskTest.java +++ b/Ghidra/Extensions/MachineLearning/src/test.slow/java/ghidra/machinelearning/functionfinding/RandomForestTrainingTaskTest.java @@ -358,4 +358,33 @@ public class RandomForestTrainingTaskTest extends AbstractProgramBasedTest { assertTrue(data.getTestNegative().contains(definedData)); } + @Test + public void testExhaustingFunctionInteriors() throws CancelledException { + params = new FunctionStartRFParams(program); + params.setMaxStarts(5); + int tooBig = 10; + Address begin = program.getSymbolTable().getSymbols("entry").next().getAddress(); + AddressSet entries = new AddressSet(); + for (int i = 0; i < 10; ++i) { + entries.add(begin.add(i)); + } + AddressSet interiors = new AddressSet(); + for (int i = 10; i < 25; ++i) { + interiors.add(begin.add(i)); + } + AddressSet definedData = new AddressSet(); + for (int i = 25; i < 30; ++i) { + definedData.add(begin.add(i)); + } + RandomForestTrainingTask task = new RandomForestTrainingTask(program, params, null, + RandomForestFunctionFinderPlugin.TEST_SET_MAX_SIZE_DEFAULT); + TrainingAndTestData data = + task.getTrainingAndTestData(entries, interiors, definedData, tooBig, TaskMonitor.DUMMY); + assertTrue(data.getTrainingPositive().getNumAddresses() == 5); + assertTrue(data.getTestPositive().getNumAddresses() == 5); + assertTrue(data.getTestPositive().union(data.getTrainingPositive()).equals(entries)); + assertTrue(data.getTrainingNegative().equals(interiors)); + assertTrue(data.getTestNegative().equals(definedData)); + } + }