import heapq
import random
import time
from collections import deque

def print_puzzle(puzzle):
    """Print a flat, row-major 3x3 puzzle as three space-separated rows."""
    # Each row starts at a multiple of 3 in the flat sequence.
    for row_start in (0, 3, 6):
        print(*puzzle[row_start:row_start + 3], sep=" ")

def find_zero(puzzle):
    """Locate the empty tile (0) in a flat puzzle and return it as (row, col)."""
    blank_index = puzzle.index(0)
    # Row-major layout with 3 columns: row = index // 3, column = index % 3.
    return blank_index // 3, blank_index % 3

def getPossibleMoves(emptyTileLocationY, emptyTileLocationX):
    """
    List the positions the empty tile can swap with in a 3x3 puzzle.

    The arguments are the 0-based (row, col) coordinates of the zero tile.
    Returns a list of (y, x) neighbour positions in the fixed order
    up, down, left, right; neighbours outside the board are left out.
    """
    row, col = emptyTileLocationY, emptyTileLocationX

    # (is the neighbour on the board?, neighbour position)
    neighbours = (
        (row > 0, (row - 1, col)),
        (row < 2, (row + 1, col)),
        (col > 0, (row, col - 1)),
        (col < 2, (row, col + 1)),
    )
    return [position for on_board, position in neighbours if on_board]

def moveTile(puzzle, yEmptyTile, xEmptyTile, yMovedTile, xMovedTile):
    """
    Return a new puzzle with one tile slid into the empty space.

    puzzle: flat tuple/list length 9 (not modified)
    (yEmptyTile, xEmptyTile): coordinates of the 0 tile
    (yMovedTile, xMovedTile): coordinates of the tile to move into the empty space

    The result is always a tuple.
    """
    target = 3 * yEmptyTile + xEmptyTile
    source = 3 * yMovedTile + xMovedTile

    tiles = list(puzzle)
    # The moved tile lands on the empty position; its old position becomes 0.
    tiles[target], tiles[source] = tiles[source], 0
    return tuple(tiles)

def generatePuzzle(numberOfMoves, goalState):
    """
    Shuffle goalState by numberOfMoves random legal moves and return the result.

    goalState is a flat tuple/list of length 9 (row-major). A move that would
    put the empty tile straight back where it was one step earlier is never
    chosen. The result is returned as a tuple.
    """
    state = tuple(goalState)

    # where the empty tile was before the most recent move (none yet)
    previousBlank = (-1, -1)

    for _ in range(numberOfMoves):
        blank = find_zero(state)
        candidates = getPossibleMoves(*blank)

        # redraw until the pick is not a plain undo of the last move
        target = random.choice(candidates)
        while target == previousBlank:
            target = random.choice(candidates)

        state = moveTile(state, *blank, *target)
        previousBlank = blank

    return state

def isGoalState(currentState, goalState):
    """
    Check whether currentState and goalState hold the same tiles in the same order.

    Both arguments are flat sequences (tuple or list). They are compared as
    tuples because a list never compares equal to a tuple in Python, while
    states produced by moveTile are tuples and a caller-supplied goal may be
    a list.
    """
    return tuple(currentState) == tuple(goalState)

def numberOfWrongTiles(puzzle, goalState):
    """
    Count the positions at which puzzle differs from goalState.

    Both puzzle and goalState are flat sequences of length 9.
    Every position is compared, so a misplaced empty tile (0) is counted
    as well.
    """
    mismatches = 0
    for have, want in zip(puzzle, goalState):
        if have != want:
            mismatches += 1
    return mismatches

def manhattanDist(puzzle, goalState):
    """
    Sum, over the tile values 0..8, of the Manhattan distance between a
    tile's position in puzzle and its position in goalState.

    puzzle and goalState are flat sequences of length 9 (row-major).
    The empty tile (0) is included in the sum.
    """
    def displacement(tile):
        # |dx| + |dy| between where the tile is and where it belongs
        rowNow, colNow = divmod(puzzle.index(tile), 3)
        rowGoal, colGoal = divmod(goalState.index(tile), 3)
        return abs(colNow - colGoal) + abs(rowNow - rowGoal)

    return sum(displacement(tile) for tile in range(9))

def runSearch(puzzle, goalState, params):
    """
    Run a search from puzzle towards goalState.

    puzzle, goalState: flat tuples/lists of length 9 (row-major). Both are
        converted to tuples first, so lists work with every strategy
        (states are hashed for 'dfs_wo_cycles' and compared with the tuples
        produced by moveTile).
    params: dict with keys:
        'strategy'      : 'dfs' | 'bfs' | 'ids' | 'dfs_wo_cycles' |
                          'heuristic_wrong_tiles' | 'heuristic_manhattan'
        'maxIterations' : int, upper bound on the number of loop iterations
        'silent'        : bool, optional (default False)

    Returns (iteration, maxQueueSize): the number of the loop iteration at
    which the search stopped (goal found, frontier empty, or maxIterations
    reached) and the largest frontier size seen at the end of an iteration.
    Returns (0, 0) if puzzle already equals goalState.

    Raises ValueError if the strategy is unknown. This is checked before any
    work is done.
    """
    strategy = params['strategy']

    heuristics = {
        'heuristic_wrong_tiles': numberOfWrongTiles,
        'heuristic_manhattan': manhattanDist,
    }
    stackStrategies = ('dfs', 'dfs_wo_cycles', 'ids')
    if strategy not in stackStrategies + ('bfs',) + tuple(heuristics):
        raise ValueError(f"Search Strategy {strategy} unknown!")

    puzzle = tuple(puzzle)
    goalState = tuple(goalState)
    if isGoalState(puzzle, goalState):
        return 0, 0

    silent = params.get('silent', False)
    useStack = strategy in stackStrategies
    heuristic = heuristics.get(strategy)  # None unless this is an A* strategy
    found = False

    # search nodes are tuples: (puzzle, lastEmptyY, lastEmptyX, pathCost)
    root = (puzzle, -1, -1, 0)

    # Frontier container depends on the strategy:
    #   A*              -> heap of (estimatedTotalCost, pushCount, node)
    #   dfs-like / ids  -> list used as a stack (LIFO)
    #   bfs             -> deque used as a FIFO queue
    if heuristic is not None:
        queue = [(0, 0, root)]
    elif useStack:
        queue = [root]
    else:
        queue = deque([root])

    # insertion counter for A*: ties on estimated cost are resolved in
    # favour of the node that was queued first
    pushCount = 1

    # if we use IDS, we store the current max depth here
    maxDepthIDS = 1

    # states already queued, used by dfs_wo_cycles to avoid cycles
    visited = {puzzle}

    maxQueueSize = len(queue)

    iteration = 0
    for iteration in range(1, params['maxIterations'] + 1):
        # output a . every 5000 iterations to indicate progress
        if (iteration + 1) % 5000 == 0 and not silent:
            print('.', end='', flush=True)

        if not queue:
            # frontier empty, nothing left to expand
            break

        # --- take the next node from the frontier ---
        if heuristic is not None:
            _, _, node = heapq.heappop(queue)
        elif useStack:
            node = queue.pop()
        else:
            node = queue.popleft()
        state, lastEmptyTileLocationY, lastEmptyTileLocationX, pathCost = node

        emptyTileLocationY, emptyTileLocationX = find_zero(state)

        # --- generate children; the goal test happens at generation time ---
        children = []
        for y, x in getPossibleMoves(emptyTileLocationY, emptyTileLocationX):
            # skip the move that simply reverts the previous one
            if y == lastEmptyTileLocationY and x == lastEmptyTileLocationX:
                continue

            child = moveTile(state, emptyTileLocationY, emptyTileLocationX, y, x)
            if isGoalState(child, goalState):
                found = True
                break
            children.append(child)

        if found:
            break

        # --------- DFS --------- #
        if strategy == 'dfs':
            # Follow a single random child; siblings are discarded and the
            # child carries no memory of the previous empty position.
            if children:
                queue = [(random.choice(children), -1, -1, 0)]

        # --------- DFS without cycles --------- #
        elif strategy == 'dfs_wo_cycles':
            # push only states that were never queued before
            for child in children:
                if child not in visited:
                    visited.add(child)
                    queue.append(
                        (child, emptyTileLocationY, emptyTileLocationX, pathCost + 1)
                    )

        # --------- BFS --------- #
        elif strategy == 'bfs':
            for child in children:
                queue.append((child, emptyTileLocationY, emptyTileLocationX, 0))

        # --------- IDS --------- #
        elif strategy == 'ids':
            for child in children:
                queue.append(
                    (child, emptyTileLocationY, emptyTileLocationX, pathCost + 1)
                )

            # prune nodes deeper than the current depth limit
            while queue and queue[-1][3] > maxDepthIDS:
                queue.pop()

            # if nothing is left, increase the depth and restart from the root
            if not queue:
                maxDepthIDS += 1
                queue.append(root)

        # --------- A* with different heuristics --------- #
        else:
            newCost = pathCost + 1
            for child in children:
                estimatedTotalCost = newCost + heuristic(child, goalState)
                heapq.heappush(
                    queue,
                    (estimatedTotalCost, pushCount,
                     (child, emptyTileLocationY, emptyTileLocationX, newCost))
                )
                pushCount += 1

        maxQueueSize = max(maxQueueSize, len(queue))

    if not silent:
        if found:
            print(f"\n == Solution found after {iteration} iterations! ==")
        else:
            print(f"\n == No solution found after {iteration} iterations! ==")
    else:
        if not found:
            print("WARNING: no solution found (test results invalid)")

    return iteration, maxQueueSize

def testSearch(searchStrategy, numberOfRuns, shuffleSteps, *,
               goalState=(1, 2, 3, 8, 0, 4, 7, 6, 5), maxIterations=2000000):
    """
    Benchmark a search strategy over several random puzzles and print averages.

    searchStrategy: as in params['strategy'] of runSearch
    numberOfRuns  : number of random puzzles (must be at least 1)
    shuffleSteps  : number of random moves to generate each puzzle
    goalState     : keyword-only; flat goal layout the puzzles are shuffled
                    from and solved towards
    maxIterations : keyword-only; iteration cap passed to runSearch

    Searches run in silent mode. The reported time covers puzzle generation
    plus search, averaged per run.

    Raises ValueError if numberOfRuns is smaller than 1, since no average
    can be computed.
    """
    if numberOfRuns < 1:
        raise ValueError("numberOfRuns must be at least 1")

    params = {
        'strategy': searchStrategy,
        'maxIterations': maxIterations,
        'silent': True
    }

    totalIterations = 0
    totalMaxQueueSize = 0
    start = time.perf_counter()

    for _ in range(numberOfRuns):
        puzzle = generatePuzzle(shuffleSteps, goalState)
        usedIterations, maxQueueSize = runSearch(puzzle, goalState, params)
        totalIterations += usedIterations
        totalMaxQueueSize += maxQueueSize

    totalTime = time.perf_counter() - start
    averageIterations = totalIterations / numberOfRuns
    averageMaxQueueSize = totalMaxQueueSize / numberOfRuns
    averageTime = totalTime / numberOfRuns

    print("---- Test Results ---")
    print(f"Strategy      : {params['strategy']}")
    print(f"averaged over : {numberOfRuns} runs")
    print(f"difficulty    : {shuffleSteps} shuffles")
    print(f"Iterations    : {averageIterations:.1f}")
    print(f"MaxQueueSize  : {averageMaxQueueSize:.1f}")
    print(f"Time (ms)     : {averageTime * 1000:.2f}")

def main():
    """
    Interactive demo: shuffle one puzzle, show it, wait for ENTER,
    then run a single search and print its statistics.
    """
    # Goal layout (row-major, 0 is the empty tile) and difficulty,
    # expressed as the number of shuffle moves.
    goalState = (1, 2, 3, 8, 0, 4, 7, 6, 5)
    shuffleSteps = 10

    # Available strategies:
    #   'dfs'                   - depth first search (one random branch)
    #   'dfs_wo_cycles'         - depth first search that skips visited states
    #   'bfs'                   - breadth first search
    #   'ids'                   - iterative deepening search
    #   'heuristic_wrong_tiles' - A*, number of wrong tiles estimate
    #   'heuristic_manhattan'   - A*, manhattan distances estimate
    params = {
        'strategy': 'heuristic_manhattan',
        'maxIterations': 2000000,  # cap to prevent an endless loop
        'silent': False,  # False: print details
    }

    startState = generatePuzzle(shuffleSteps, goalState)

    # describe the run before starting it
    for line in ("\n------- Search ------ ",
                 f"Strategy      : {params['strategy']}",
                 f"difficulty    : {shuffleSteps} shuffles",
                 "Initial puzzle:"):
        print(line)
    print_puzzle(startState)
    input("\nPress ENTER to start the search!\n")

    print("Search started")
    startedAt = time.perf_counter()
    usedIterations, maxQueueSize = runSearch(startState, goalState, params)
    duration = time.perf_counter() - startedAt

    print()
    print(f"Iterations    : {usedIterations:.1f}")
    print(f"MaxQueueSize  : {maxQueueSize:.1f}")
    print(f"Elapsed time  : {duration:.3f} s")
    print("\nTry changing the search strategy or the difficulty!\n")

if __name__ == "__main__":
    """
    Benchmark all search strategies (BFS, DFS, DFS w/o cycles, IDS, A* variants)
    for 3 puzzle difficulties: 5, 10, and 15 shuffles.
    """
    # Call main() here instead to run the interactive single-puzzle demo.

    # Uncomment a strategy to include it in the benchmark.
    enabledStrategies = (
        # "bfs",
        "dfs",
        # "dfs_wo_cycles",
        # "ids",
        # "heuristic_wrong_tiles",
        # "heuristic_manhattan",
    )
    difficulties = (5, 10, 15)
    runsPerConfig = 500

    print("=== 8-Puzzle Search Benchmark ===")
    print(f"Each configuration averaged over {runsPerConfig} random puzzles.\n")

    benchmarkStart = time.perf_counter()

    # one testSearch report per (strategy, difficulty) pair
    for strategyName in enabledStrategies:
        print(f"\n===== Strategy: {strategyName.upper()} =====")
        for shuffleCount in difficulties:
            print(f"\n--- Difficulty: {shuffleCount} shuffles ---")
            testSearch(strategyName, runsPerConfig, shuffleCount)

    benchmarkSeconds = time.perf_counter() - benchmarkStart
    print(f"\n=== All tests finished in {benchmarkSeconds:.2f} s ===")