import numpy as np
from typing import Callable, Tuple
import torch
from jaxtyping import Array, Float, Int  # type: ignore


def create_zero_vector(size: int = 10) -> Float[Array, " size"]:
    """
    Exercise 1: Create a zero vector of specified size.

    Args:
        size: Size of the vector (default: 10)

    Returns:
        1-D float64 NumPy array of length ``size`` filled with zeros
    """
    ### BEGIN SOLUTION
    # np.zeros defaults to float64, which matches the Float return annotation.
    return np.zeros(size)
    ### END SOLUTION


def create_diagonal_matrix(
    size: int = 10, diag_value: int = -1
) -> Int[Array, "size size"]:
    """
    Exercise 2: Create an int64 matrix with diagonal values set to specified value.

    Args:
        size: Size of the square matrix (default: 10)
        diag_value: Value to set on diagonal (default: -1)

    Returns:
        int64 matrix of shape (size, size) with ``diag_value`` on the main
        diagonal and zeros everywhere else
    """
    ### BEGIN SOLUTION
    # Allocate with an explicit dtype so the result is int64 on every platform
    # (the default NumPy integer is 32-bit on some systems).
    matrix = np.zeros((size, size), dtype=np.int64)
    # fill_diagonal writes in place and casts the value to the array's dtype.
    np.fill_diagonal(matrix, diag_value)
    return matrix
    ### END SOLUTION


def create_checkerboard(size: int = 10) -> Float[Array, "size size"]:
    """
    Exercise 3: Create a checkerboard pattern matrix.

    The top-left cell is 0; a cell at (row, col) is 1 exactly when
    ``row + col`` is odd.

    Args:
        size: Size of the square matrix (default: 10)

    Returns:
        Float matrix of shape (size, size) with alternating 0s and 1s
    """
    ### BEGIN SOLUTION
    board = np.zeros((size, size))
    # Odd rows get ones in even columns, even rows get ones in odd columns.
    board[1::2, ::2] = 1
    board[::2, 1::2] = 1
    return board
    ### END SOLUTION


def place_random_ones(
    matrix_shape: Tuple[int, int] = (8, 8), num_ones: int = 5
) -> Float[Array, "height width"]:
    """
    Exercise 4: Randomly place specified number of 1's in a zero matrix.

    Positions are drawn without replacement from NumPy's global random
    state, so exactly ``num_ones`` distinct cells are set and results are
    reproducible after ``np.random.seed``.

    Args:
        matrix_shape: Shape of the matrix (default: (8, 8))
        num_ones: Number of 1's to place (default: 5)

    Returns:
        Float matrix of zeros with ``num_ones`` randomly placed 1's

    Raises:
        ValueError: If ``num_ones`` exceeds the number of cells in the matrix
            (raised by ``np.random.choice``)
    """
    ### BEGIN SOLUTION
    matrix = np.zeros(matrix_shape)
    # Sample distinct flat indices; replace=False guarantees no cell is
    # picked twice, so the number of ones is exactly num_ones.
    flat_positions = np.random.choice(matrix.size, size=num_ones, replace=False)
    matrix.flat[flat_positions] = 1
    return matrix
    ### END SOLUTION


def channel_last_to_first(
    image: Float[Array, "height width channels"],
) -> Float[Array, "channels height width"]:
    """
    Exercise 5: Convert channel-last image to channel-first.

    Args:
        image: Image tensor in channel-last format (H, W, C)

    Returns:
        Image tensor in channel-first format (C, H, W). For NumPy input this
        is a view of ``image`` (no data is copied).
    """
    ### BEGIN SOLUTION
    # New axis order: channels (old axis 2), height (old 0), width (old 1).
    return np.transpose(image, (2, 0, 1))
    ### END SOLUTION


def rgb_to_bgr(image: Float[Array, "3 height width"]) -> Float[Array, "3 height width"]:
    """
    Exercise 6: Convert RGB image to BGR by swapping color channels.

    Args:
        image: RGB image in channel-first format (3, H, W)

    Returns:
        BGR image in channel-first format (3, H, W). The result is a new
        array; the input is left untouched.
    """
    ### BEGIN SOLUTION
    # Integer-array indexing on the channel axis picks B, G, R in that order
    # and returns a copy, so the result does not alias the input.
    return image[[2, 1, 0]]
    ### END SOLUTION


def negate_range_inplace(
    array: Float[Array, " n"], min_val: float = 3, max_val: float = 8
) -> Float[Array, " n"]:
    """
    Exercise 7: Negate all elements between min_val and max_val, in place.

    Args:
        array: 1D array to modify
        min_val: Minimum value of range (inclusive, default: 3)
        max_val: Maximum value of range (inclusive, default: 8)

    Returns:
        Modified array (same object, modified in place)
    """
    ### BEGIN SOLUTION
    # Compute the mask once, before any element changes, so that values
    # negated by this call cannot influence which elements are selected.
    in_range = (array >= min_val) & (array <= max_val)
    # Augmented assignment through a boolean mask writes into the caller's
    # buffer instead of creating a new array.
    array[in_range] *= -1
    return array
    ### END SOLUTION


def convert_dtype(array: Array, target_dtype: np.dtype) -> Array:
    """
    Exercise 8: Convert array to specified dtype.

    Args:
        array: Input array
        target_dtype: Target data type

    Returns:
        New array with the same values cast to ``target_dtype`` (the input
        is not modified). Float-to-integer casts truncate toward zero.
    """
    ### BEGIN SOLUTION
    # asarray lets plain sequences through as well; astype always returns a
    # fresh array, so the caller's data is never altered.
    return np.asarray(array).astype(target_dtype)
    ### END SOLUTION


def subtract_row_means(matrix: Float[Array, "rows cols"]) -> Float[Array, "rows cols"]:
    """
    Exercise 9: Subtract the mean of each row from the matrix.

    Args:
        matrix: Input matrix

    Returns:
        New matrix in which every row has zero mean (input is not modified)
    """
    ### BEGIN SOLUTION
    # keepdims=True keeps the means as a (rows, 1) column so they broadcast
    # across the columns of the corresponding row.
    row_means = matrix.mean(axis=1, keepdims=True)
    return matrix - row_means
    ### END SOLUTION


def sort_by_column(
    matrix: Float[Array, "rows cols"], column_idx: int = 1
) -> Float[Array, "rows cols"]:
    """
    Exercise 10: Sort matrix by specified column.

    Whole rows are reordered so that the values in ``column_idx`` are in
    ascending order. Rows with equal keys keep their original relative order.

    Args:
        matrix: Input matrix
        column_idx: Column index to sort by (default: 1 for second column)

    Returns:
        New matrix with rows sorted by the specified column
    """
    ### BEGIN SOLUTION
    # A stable argsort makes the result deterministic when keys tie.
    row_order = np.argsort(matrix[:, column_idx], kind="stable")
    return matrix[row_order]
    ### END SOLUTION


def one_hot_encode(
    indices: Int[Array, " n"], num_classes: int = 10
) -> Int[Array, "n num_classes"]:
    """
    Exercise 11: Convert integer array to one-hot encoding.

    Args:
        indices: Array of integer indices
        num_classes: Number of classes (default: 10)

    Returns:
        One-hot encoded int64 matrix of shape (len(indices), num_classes)

    Raises:
        ValueError: If any index lies outside ``[0, num_classes)``
    """
    ### BEGIN SOLUTION
    labels = np.asarray(indices)
    encoded = np.zeros((labels.shape[0], num_classes), dtype=np.int64)
    if labels.size == 0:
        # Nothing to mark; also avoids min()/max() on an empty array.
        return encoded
    # Reject out-of-range labels explicitly: negative values would otherwise
    # silently wrap around and mark the wrong class.
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(
            f"indices must lie in [0, {num_classes}), got values in "
            f"[{labels.min()}, {labels.max()}]"
        )
    # Pair each row number with its label to set exactly one cell per row.
    encoded[np.arange(labels.shape[0]), labels] = 1
    return encoded
    ### END SOLUTION


def broadcast_multiply_rows(
    matrix: Float[Array, "rows cols"], vector: Float[Array, " rows"]
) -> Float[Array, "rows cols"]:
    """
    Exercise 12: Multiply nth row of matrix with nth element of vector using broadcasting.

    Args:
        matrix: Input matrix
        vector: Vector for multiplication

    Returns:
        Matrix with rows multiplied by corresponding vector elements
    """
    ### BEGIN SOLUTION
    # Reshape the vector to a (rows, 1) column so element n is broadcast
    # across every column of row n.
    row_scales = np.asarray(vector)[:, np.newaxis]
    return matrix * row_scales
    ### END SOLUTION


def pad_with_zeros(
    array: Float[Array, "height width"], pad_width: int = 1
) -> Float[Array, "padded_height padded_width"]:
    """
    Exercise 13: Pad array with zeros without using np.pad.

    Args:
        array: Input array
        pad_width: Width of padding (default: 1)

    Returns:
        New array of shape (height + 2 * pad_width, width + 2 * pad_width)
        with the input in the centre and zeros around it; same dtype as the
        input
    """
    ### BEGIN SOLUTION
    height, width = array.shape
    padded = np.zeros(
        (height + 2 * pad_width, width + 2 * pad_width), dtype=array.dtype
    )
    # Use explicit end indices rather than a negative stop: a slice such as
    # [pad_width:-pad_width] would be empty when pad_width is 0.
    padded[pad_width : pad_width + height, pad_width : pad_width + width] = array
    return padded
    ### END SOLUTION


def numerical_gradient(
    func: Callable[[Float[Array, " n"]], float], x: Float[Array, " n"], h: float = 1e-5
) -> Float[Array, " n"]:
    """
    Exercise 13.a: Implement numerical gradient computation using central difference.

    Each component is approximated as
    ``(func(x + h * e_i) - func(x - h * e_i)) / (2 * h)``,
    where ``e_i`` is the i-th unit vector.

    Args:
        func: Function that takes numpy array and returns scalar
        x: Point at which to evaluate gradient
        h: Step size for finite difference (default: 1e-5)

    Returns:
        Approximate gradient at x (float array with the same shape as x)
    """
    ### BEGIN SOLUTION
    # Work on a float copy so integer inputs are not truncated and the
    # caller's array is never modified.
    point = np.array(x, dtype=float)
    gradient = np.zeros_like(point)
    step = np.zeros_like(point)
    for i in range(point.size):
        step[i] = h
        gradient[i] = (func(point + step) - func(point - step)) / (2 * h)
        # Reset so the next iteration perturbs only its own coordinate.
        step[i] = 0.0
    return gradient
    ### END SOLUTION


def analytical_gradient_f1(x: Float[Array, " n"]) -> Float[Array, " n"]:
    """
    Exercise 13.b: Analytical gradient of f(x) = x^T x.

    Args:
        x: Input vector

    Returns:
        Analytical gradient: 2x
    """
    ### BEGIN SOLUTION
    # asarray guards against a plain list, for which ``2 * x`` would repeat
    # the list instead of scaling it.
    return 2 * np.asarray(x)
    ### END SOLUTION


def analytical_gradient_f2(x: Float[Array, "3"]) -> Float[Array, "3"]:
    """
    Exercise 13.b: Analytical gradient of f(x) = sin(x_1) + cos(x_2) + cos(x_3).

    Args:
        x: Input vector

    Returns:
        Analytical gradient: [cos(x_1), -sin(x_2), -sin(x_3)]
    """
    ### BEGIN SOLUTION
    # Each term depends on a single coordinate, so the gradient is just the
    # three one-dimensional derivatives stacked together.
    return np.array([np.cos(x[0]), -np.sin(x[1]), -np.sin(x[2])])
    ### END SOLUTION


def analytical_gradient_f3(
    x: Float[Array, " n"], A: Float[Array, "n n"]
) -> Float[Array, " n"]:
    """
    Exercise 13.b: Analytical gradient of f(x) = x^T A x.

    Args:
        x: Input vector
        A: Matrix A

    Returns:
        Analytical gradient: (A + A^T) x
    """
    ### BEGIN SOLUTION
    # A is not assumed symmetric, so the general form (A + A^T) x is used
    # rather than the symmetric shortcut 2 A x.
    A = np.asarray(A)
    return (A + A.T) @ np.asarray(x)
    ### END SOLUTION


def complex_function(
    x: Float[Array, " n"], A: Float[Array, "n n"], b: Float[Array, " n"]
) -> float:
    """
    Exercise 13.c: Complex vector function for numerical differentiation.

    f(x) = (x^T A x) * sin(x^T b) + exp(-x^T x)

    Args:
        x: Input vector
        A: Matrix A
        b: Vector b

    Returns:
        Function value
    """
    ### BEGIN SOLUTION
    quadratic_form = np.dot(x, np.dot(A, x))  # x^T A x
    phase = np.dot(x, b)  # x^T b
    squared_norm = np.dot(x, x)  # x^T x
    # Convert the NumPy scalar to a built-in float to match the annotation.
    return float(quadratic_form * np.sin(phase) + np.exp(-squared_norm))
    ### END SOLUTION


def complex_function_torch(
    x: Float[torch.Tensor, " n"],
    A: Float[torch.Tensor, "n n"],
    b: Float[torch.Tensor, " n"],
) -> Float[torch.Tensor, ""]:
    """
    Exercise 13.d: PyTorch version of complex function for automatic differentiation.

    f(x) = (x^T A x) * sin(x^T b) + exp(-x^T x)

    Only differentiable torch operations are used, so autograd can
    propagate gradients back to ``x``.

    Args:
        x: Input tensor
        A: Matrix tensor
        b: Vector tensor

    Returns:
        Function value as tensor (0-dimensional)
    """
    ### BEGIN SOLUTION
    quadratic_form = torch.dot(x, torch.mv(A, x))  # x^T A x
    phase = torch.dot(x, b)  # x^T b
    squared_norm = torch.dot(x, x)  # x^T x
    return quadratic_form * torch.sin(phase) + torch.exp(-squared_norm)
    ### END SOLUTION


def compute_analytical_gradient_torch(
    x_np: Float[Array, " n"], A_np: Float[Array, "n n"], b_np: Float[Array, " n"]
) -> Float[Array, " n"]:
    """
    Exercise 13.d: Compute analytical gradient using PyTorch automatic differentiation.

    Differentiates f(x) = (x^T A x) * sin(x^T b) + exp(-x^T x) with respect
    to x. All tensors are created as float64 so the result is comparable to
    a NumPy finite-difference gradient.

    Args:
        x_np: Input vector as numpy array
        A_np: Matrix as numpy array
        b_np: Vector as numpy array

    Returns:
        Analytical gradient computed by PyTorch, as a numpy array with the
        same shape as ``x_np``
    """
    ### BEGIN SOLUTION
    # torch.tensor copies the data, so the caller's array is never aliased
    # by the leaf tensor that autograd tracks.
    x = torch.tensor(x_np, dtype=torch.float64, requires_grad=True)
    A = torch.as_tensor(A_np, dtype=torch.float64)
    b = torch.as_tensor(b_np, dtype=torch.float64)

    # Same expression as exercise 13.c/13.d, written out here so this
    # function is self-contained.
    value = torch.dot(x, torch.mv(A, x)) * torch.sin(torch.dot(x, b)) + torch.exp(
        -torch.dot(x, x)
    )

    # autograd.grad returns the gradient directly, without accumulating
    # into x.grad.
    (gradient,) = torch.autograd.grad(value, x)
    return gradient.detach().numpy()
    ### END SOLUTION


# Test functions for exercises 13.b
def f1(x: Float[Array, " n"]) -> float:
    """Test function 1: squared Euclidean norm, f(x) = x^T x."""
    squared_norm = np.dot(x, x)
    return squared_norm


def f2(x: Float[Array, "3"]) -> float:
    """Test function 2: f(x) = sin(x_1) + cos(x_2) + cos(x_3), using the first three entries of x."""
    sine_term = np.sin(x[0])
    first_cosine = np.cos(x[1])
    second_cosine = np.cos(x[2])
    # Sum left to right so the floating-point result is unchanged.
    return sine_term + first_cosine + second_cosine


def f3(x: Float[Array, " n"], A: Float[Array, "n n"]) -> float:
    """Test function 3: quadratic form, f(x) = x^T A x."""
    transformed = np.dot(A, x)  # A x
    return np.dot(x, transformed)
