FICO
FICO Xpress Optimization Examples Repository
FICO Optimization Community FICO Xpress Optimization Home
Back to examples browser | Previous example | Next example

Solution enumeration

Description
Example of enumerating the n-best solutions when solving a MIP.

Further explanation of this example: 'Xpress Python Reference Manual'


Source Files
Click a file name to open a preview at the bottom of this page.
solenum.py[download]





solenum.py

# Example of enumerating the n-best solutions when solving a MIP.
#
# The program reads a problem from a file and collects solutions into a pool.
# Depending on the setting of a searchMethod parameter, it will enumerate
# additional solutions. There is an optional parameter to specify the
# maximum number of solutions to collect.
#
# The parameter searchMethod can be 0, 1 or 2. For the value of 0 it just
# collects solutions. For a value of 1 it continues the search until it has
# found the n best solutions that are reachable through the branch-and-bound
# process. The value of 2 ensures the n-best solutions are returned.
#
# The example implements its own solution pool with solutions stored in order
# of objective value, and implements duplication checks. Almost all the
# interesting work happens within the preintsol callback. It collects
# solutions, checks for duplicates and adjusts the cutoff for the remaining
# search. The adjustment of the cutoff is crucial to ensure additional
# solutions are found.
#
# To guarantee that we find the n best solutions (searchMethod 2),
# problem.loadBranchDirs is used to force the Optimizer into exhaustive
# branching on the enumerated columns: the integer columns whose values must
# differ between collected solutions. Solutions that agree on all enumerated
# columns are treated as duplicates.
#
# (C) 2025 Fair Isaac Corporation

import xpress as xp
import numpy as np
import argparse
import math


def HashSolution(solution, colIsEnumerated):
    """Return a hash of the rounded values of the enumerated columns.

    Only positions whose flag in colIsEnumerated is truthy contribute. Each
    such value is passed through round() before hashing, so solutions that
    agree on the rounded values of those columns hash identically.
    """
    rounded = []
    for value, enumerated in zip(solution, colIsEnumerated):
        if enumerated:
            rounded.append(round(value))
    return hash(tuple(rounded))


class Solution:
    """A candidate MIP solution stored in a SolutionPool.

    Two solutions are duplicates (compare equal) when every enumerated column
    has the same rounded value in both; non-enumerated columns are ignored.
    Equality uses the same rounding as HashSolution, so equal solutions always
    have equal hashes, as required for use as dict keys.
    """
    def __init__(self, colIsEnumerated, x, objval):
        self.x = x                              # solution values, indexed by column
        self.colIsEnumerated = colIsEnumerated  # per-column flags; truthy = enumerated
        self.hash = HashSolution(x, colIsEnumerated)
        self.objval = objval                    # objective value of this solution

    def __eq__(self, other):
        """Compare solutions on the rounded values of the enumerated columns.

        Rounding both sides (rather than testing |a - b| <= 0.5) keeps __eq__
        consistent with __hash__: e.g. 0.4 and 0.6 round to different
        integers and would hash differently, so they must not compare equal.
        """
        if not isinstance(other, Solution):
            return NotImplemented
        return all(
            round(mine) == round(theirs)
            for mine, theirs, enumerated in zip(self.x, other.x, self.colIsEnumerated)
            if enumerated
        )

    def __hash__(self):
        return self.hash


class SolutionPool:
    """Keeps up to maxSolutions distinct solutions, ordered from best to worst.

    Duplicates are detected through the solutions' __hash__/__eq__; of two
    duplicates only the one with the better objective value is kept.
    """
    def __init__(self, ncols, isMinimization, maxSolutions):
        self.solList = []       # Solutions ordered from best to worst
        self.solDict = dict()   # For identifying duplicate solutions
        self.ncols = ncols
        self.isMinimization = isMinimization
        self.maxSolutions = maxSolutions

    @property
    def worstsol(self):
        """The worst solution currently in the pool, or None if it is empty."""
        return self.solList[-1] if self.solList else None

    def isSolBetter(self, sol1, sol2):
        """Return True if sol1 has a strictly better objective value than sol2."""
        if self.isMinimization:
            return sol1.objval < sol2.objval

        return sol1.objval > sol2.objval

    def delSolution(self, sol):
        """Remove sol from the pool.

        Raises KeyError if sol is not a key of solDict and ValueError if the
        object itself is not in solList.
        """
        del self.solDict[sol]
        # Match by identity: list.remove() compares with __eq__ and would drop
        # the first *equivalent* entry, which is not necessarily this object.
        # Search from the end because the worst solution is removed most often.
        for idx in range(len(self.solList) - 1, -1, -1):
            if self.solList[idx] is sol:
                del self.solList[idx]
                return
        raise ValueError('solution not in pool')

    def addSolution(self, sol):
        """Add sol unless an equal-or-better duplicate is already stored.

        After insertion, the worst solutions are dropped until at most
        maxSolutions remain.
        """
        duplicateSol = self.solDict.get(sol)
        if duplicateSol is not None:
            # The solution is already known, keep the duplicate with the best objective
            if not self.isSolBetter(sol, duplicateSol):
                return

            # Previous solution had worse objective value, delete it from the pool
            self.delSolution(duplicateSol)

        # Add solution to the pool
        self.solDict[sol] = sol

        # Insert after all solutions that are at least as good (keeps ties in
        # arrival order).
        idx = 0
        while idx < len(self.solList) and not self.isSolBetter(sol, self.solList[idx]):
            idx += 1

        self.solList.insert(idx, sol)

        # Remove the worst solutions if we have too many
        while len(self.solList) > self.maxSolutions:
            self.delSolution(self.solList[-1])


class CbData:
    """State shared between main() and the preintsol callback."""

    def __init__(self):
        # Search mode: 0 = collect only, 1 = extended search, 2 = n-best search.
        self.searchMethod = 2
        # Objective direction of the problem being solved.
        self.isMinimization = True
        # Per-column enumeration flags (1 = enumerated); assigned by main().
        self.colIsEnumerated = None
        # SolutionPool receiving the collected solutions; assigned by main().
        self.pool = None


def cbPreIntSol(prob, cbData, soltype, cutoff):
    """Preintsol callback: record each candidate MIP solution and steer the search.

    Every candidate is added to cbData.pool. For searchMethod 0 the cutoff is
    returned unchanged. Otherwise, while the pool holds fewer than
    maxSolutions entries, the cutoff is relaxed so any solution is accepted.
    Once the pool is full, the cutoff is set just beyond the worst pooled
    objective so only improving solutions are sought.

    Returns:
        (reject, newcutoff) -- reject is always 0 (solutions are never rejected).
    """
    # The candidate has not been installed as the incumbent yet, so it has to
    # be fetched with getCallbackSolution().
    x = prob.getCallbackSolution()

    # NOTE(review): lpobjval is used as the candidate's objective value;
    # confirm this also holds for heuristic solutions (soltype != 0).
    objval = prob.attributes.lpobjval

    pool = cbData.pool
    pool.addSolution(Solution(cbData.colIsEnumerated, x, objval))

    if cbData.searchMethod == 0:
        # We just collect the solutions and don't adjust the search.
        return (0, cutoff)

    if pool.solList and len(pool.solList) >= pool.maxSolutions:
        # Pool is full: ask for something slightly better than the worst
        # solution collected. solList is ordered best to worst, so the worst
        # one is the last entry.
        worstObj = pool.solList[-1].objval
        newcutoff = worstObj - 1e-6 if cbData.isMinimization else worstObj + 1e-6
    else:
        # We don't have enough solutions yet, so any solution is acceptable.
        newcutoff = +1e+40 if cbData.isMinimization else -1e+40

    return (0, newcutoff)


def select_enumeration_columns(prob, ncols, colIsEnumerated):
    """Mark the integer-restricted columns of prob for enumeration.

    For every column index below ncols, writes 1 into colIsEnumerated when
    the column type is 'B' or 'I' and 0 otherwise. colIsEnumerated is
    updated in place.
    """
    integerTypes = ('B', 'I')
    types = prob.getColType()
    for idx in range(ncols):
        colIsEnumerated[idx] = 1 if types[idx] in integerTypes else 0


def enforce_enumeration_columns(prob, ncols, colIsEnumerated):
    """Load branch directives for every enumerated column.

    Passing the enumerated columns to prob.loadBranchDirs makes Xpress keep
    branching on them, which the n-best search relies on to enumerate them
    exhaustively.
    """
    selected = [idx for idx in range(ncols) if colIsEnumerated[idx] == 1]
    prob.loadBranchDirs(selected)


def main():
    """Read a MIP, enumerate solutions through the preintsol callback, print them.

    Command-line arguments:
        filename -- MPS or LP file to solve.
        method   -- 0 (normal solve), 1 (extended search) or 2 (n-best search).
        max_sols -- optional pool size (default 10; values below 1 are raised to 1).
    """
    parser = argparse.ArgumentParser(
        description='Enumerates and collects multiple solutions to a MIP'
    )
    parser.add_argument('filename', help='MPS or LP file to solve')
    parser.add_argument('method', help="""How to collect solutions:
        0 - normal solve;
        1 - extended search;
        2 - n-best search
        """, type=int, choices=[0, 1, 2])
    parser.add_argument('max_sols', help='Maximum number of solutions to collect',
                        type=int, nargs='?', default=10)
    args = parser.parse_args()

    # Create the problem and read it from file
    prob = xp.problem(f'solenum {args.filename}')
    prob.readProb(args.filename)

    # Set up the callback data for the preintsol callback.
    ncols = prob.attributes.cols
    cbData = CbData()
    # A positive objsense is treated as minimization.
    cbData.isMinimization = prob.attributes.objsense > 0.0
    cbData.pool = SolutionPool(ncols, cbData.isMinimization, max(1, args.max_sols))
    cbData.searchMethod = args.method

    # Set up the enumeration columns; only the n-best search (method 2)
    # forces exhaustive branching on them.
    cbData.colIsEnumerated = np.zeros(ncols, dtype=np.int8)
    select_enumeration_columns(prob, ncols, cbData.colIsEnumerated)
    if args.method >= 2:
        enforce_enumeration_columns(prob, ncols, cbData.colIsEnumerated)

    # Add the callback
    prob.addPreIntsolCallback(cbPreIntSol, cbData, 0)

    # Set important controls.
    # NOTE(review): serializepreintsol is presumably what keeps concurrent
    # callback invocations from mutating the shared pool in parallel -- confirm
    # against the Xpress controls reference.
    prob.controls.serializepreintsol = 1
    # Do not stop early on the relative MIP gap.
    prob.controls.miprelstop = 0
    if args.method >= 2:
        # Presumably disables dual reductions that could cut off alternative
        # solutions needed for a true n-best search -- confirm in the docs.
        prob.controls.mipdualreductions = 0

    # Optimize
    prob.optimize()

    # Print out results
    print()
    solutions = cbData.pool.solList
    if not solutions:
        print('No solutions collected!')
    else:
        print(f'Collected {len(solutions)} solutions with objective values:')
        for sol in solutions:
            print(f'   {sol.objval}')


# Only run the example when executed as a script, not when imported.
if __name__ == '__main__':
    main()

Back to examples browser | Previous example | Next example