Q-Intro: Optimizing Functions

Exploring Multivariable Functions, Convexity, and Derivatives for Gradient Descent and Beyond

Anonymousket
Apr 21, 2024

In this article, we discuss functions, multivariable functions, convex and concave functions, derivatives, partial derivatives, gradients, and maxima and minima in detail.

Function:

Functions in mathematics are rules that assign to each input value (argument) exactly one output value. They are often denoted as f(x) or y = f(x), where x represents the input and f(x) denotes the output.

Functions can take various forms and shapes, such as linear, quadratic, exponential, trigonometric, etc. They can be continuous or discrete, and their behavior can be described by their domain, range, and various properties like symmetry, periodicity, and extrema.

Below is sample Python code that plots the quadratic function x² − 4x + 4.

import required modules

import numpy as np
import matplotlib.pyplot as plt

quadratic function

# Define the quadratic function
def quadratic_function(x):
    return x**2 - 4*x + 4

Generate x and y values

# Generating x values from -2 to 6
x_values = np.linspace(-2, 6, 100)

# Calculate corresponding y values using the function
y_values = quadratic_function(x_values)

plot the graph

# Plotting the graph
plt.figure(figsize=(8, 6))
plt.plot(x_values, y_values, label='y = x^2 - 4x + 4', color='blue')
plt.title('Graph of a Quadratic Function')
plt.xlabel('x-axis')
plt.ylabel('y-axis')
plt.axhline(0, color='black',linewidth=0.5)
plt.axvline(0, color='black',linewidth=0.5)
plt.grid(color = 'gray', linestyle = '--', linewidth = 0.5)
plt.legend()
plt.show()
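Because x² − 4x + 4 = (x − 2)², the curve should bottom out at x = 2 with value 0. The vertex formula x = −b/(2a) gives the same point. The short sketch below compares that closed form with a brute-force search over the plotted range. The variable names are illustrative, and the grid resolution (8001 points) is an arbitrary choice.

```python
import numpy as np

def quadratic_function(x):
    return x**2 - 4*x + 4

# Vertex of a*x^2 + b*x + c lies at x = -b / (2a); here a = 1, b = -4
a, b = 1.0, -4.0
vertex_x = -b / (2 * a)
vertex_y = quadratic_function(vertex_x)

# Brute-force check: evaluate on a fine grid over the plotted range [-2, 6]
grid = np.linspace(-2, 6, 8001)
grid_min_x = grid[np.argmin(quadratic_function(grid))]

print(vertex_x, vertex_y)  # 2.0 0.0
print(grid_min_x)          # grid point closest to 2
```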

Multivariable Function:

A multivariable function involves multiple variables that contribute to the function's output. In mathematics, a multivariable function f(x, y) or f(x₁, x₂, …, xₙ) takes more than one input and produces a single output.
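Here is a minimal sketch of the n-input case. The hypothetical helper `sum_of_squares` accepts any number of inputs x₁, …, xₙ and returns a single number, x₁² + … + xₙ². With two inputs, it matches the f(x, y) = x² + y² example used in this section.

```python
import numpy as np

def sum_of_squares(*xs):
    """f(x1, ..., xn) = x1^2 + ... + xn^2: n inputs, one scalar output."""
    return float(np.sum(np.square(xs)))

print(sum_of_squares(3, 4))     # 25.0 (the two-variable case x^2 + y^2)
print(sum_of_squares(1, 2, 2))  # 9.0
```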

Below is sample Python code that plots the multivariable function f(x, y) = x² + y².

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

defining the function

# Define the function
def multivariable_function(x, y):
    return x**2 + y**2

generate x, y and z values

# Generate data
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
x, y = np.meshgrid(x, y)
z = multivariable_function(x, y)

plot the graph

# Plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(x, y, z, cmap='viridis')

# Labeling axes
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.set_zlabel('Z-axis (f(x, y))')
ax.set_title('Graph of f(x, y) = x^2 + y^2')
plt.show()

Convex and Concave Functions:

Convex Functions:

A function f(x) defined on an interval is considered convex if, for any two points x₁ and x₂ within the interval and for any λ in the range [0, 1],

$$f(\lambda x_1 + (1 - \lambda) x_2) \le \lambda f(x_1) + (1 - \lambda) f(x_2)$$

In simpler terms, a function is convex if the line segment (chord) between any two points on its graph lies on or above the graph.
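The definition can also be checked numerically. The sketch below evaluates both sides of the inequality for f(x) = x² at many random pairs x₁, x₂ and many values of λ. The helper `satisfies_convexity`, the random seed, and the sample sizes are all illustrative assumptions. A sampling check like this can only find counterexamples; it cannot prove convexity. For x², the inequality holds exactly, because the chord minus the function equals λ(1 − λ)(x₁ − x₂)², which is never negative.

```python
import numpy as np

def convex_function(x):
    return x**2

def satisfies_convexity(f, x1, x2, lambdas, tol=1e-9):
    """True if f(l*x1 + (1-l)*x2) <= l*f(x1) + (1-l)*f(x2) for every l in lambdas."""
    lhs = f(lambdas * x1 + (1 - lambdas) * x2)     # function on the segment
    rhs = lambdas * f(x1) + (1 - lambdas) * f(x2)  # chord between the two points
    return bool(np.all(lhs <= rhs + tol))          # tol absorbs rounding error

rng = np.random.default_rng(0)                     # fixed seed for repeatability
lambdas = np.linspace(0, 1, 101)
pairs = rng.uniform(-5, 5, size=(1000, 2))

print(all(satisfies_convexity(convex_function, a, b, lambdas) for a, b in pairs))  # True
# sin on [0, pi] bends the other way, so the check fails there
print(satisfies_convexity(np.sin, 0.0, np.pi, lambdas))  # False
```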

Below is Python code that plots a sample convex function, f(x) = x².

import required modules

import numpy as np
import matplotlib.pyplot as plt

Define a convex function

# Define a convex function: f(x) = x^2
def convex_function(x):
    return x**2
# Generate x values
x_values = np.linspace(-5, 5, 100)
# Calculate corresponding y values for convex functions
convex_y_values = convex_function(x_values)
# Plotting convex function
plt.figure(figsize=(6, 4))
plt.plot(x_values, convex_y_values, label='Convex: f(x) = x^2')
plt.title('Convex Function')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.show()

Concave Function:

Concave functions are the mirror image of convex functions: f is concave exactly when −f is convex. Equivalently, a function f(x) is concave if, for any two points x₁ and x₂ within the interval and for any λ in the range [0, 1],

$$f(\lambda x_1 + (1 - \lambda) x_2) \ge \lambda f(x_1) + (1 - \lambda) f(x_2)$$

In simpler terms, a function is concave if the line segment between any two points on its graph lies on or below the graph.
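Here is a quick numerical illustration of both statements, using f(x) = −x². The sketch measures the gap between the function value on the segment and the chord value, then checks two things: that the gap is never negative, and that negating f flips its sign, which turns the concave function into the convex x². The helper `chord_gap`, the seed, and the random sampling are illustrative assumptions.

```python
import numpy as np

def concave_function(x):
    return -x**2

def chord_gap(f, x1, x2, lam):
    """f on the segment minus the chord: >= 0 for concave f, <= 0 for convex f."""
    return f(lam * x1 + (1 - lam) * x2) - (lam * f(x1) + (1 - lam) * f(x2))

rng = np.random.default_rng(1)               # fixed seed for repeatability
x1, x2 = rng.uniform(-5, 5, size=(2, 1000))  # 1000 random pairs of points
lam = rng.uniform(0, 1, size=1000)           # one random lambda per pair

gaps = chord_gap(concave_function, x1, x2, lam)
neg_gaps = chord_gap(lambda x: -concave_function(x), x1, x2, lam)  # -f is x^2

print(bool(np.all(gaps >= -1e-9)))         # True: graph on or above every chord
print(bool(np.allclose(neg_gaps, -gaps)))  # True: negating f flips the gap
```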

Below is Python code that plots a sample concave function, f(x) = −x².

import required modules

import numpy as np
import matplotlib.pyplot as plt

Define concave function

# Define a concave function: f(x) = -x^2
def concave_function(x):
    return -x**2
# Generate x values
x_values = np.linspace(-5, 5, 100)
# Calculate corresponding y values for concave functions
concave_y_values = concave_function(x_values)
# Plotting concave function
plt.figure(figsize=(6, 4))
plt.plot(x_values, concave_y_values, label='Concave: f(x) = -x^2')
plt.title('Concave Function')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.show()

Derivatives:

Derivatives in calculus represent the rate of change of a function with respect to its variables: they indicate how the function's output changes as its input changes. The derivative at a specific point on a curve gives the slope of the tangent line to the curve at that point. This helps when optimizing functions. For a differentiable function, local minima and maxima inside its domain can only occur at critical points, where the derivative is zero.

The slope of the tangent line at a point is the limit of the ratio between a small change in y and the small change in x that produces it:

$$f'(x) = \frac{dy}{dx} = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$
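To see the limit in action, the sketch below evaluates the difference quotient (f(x + h) − f(x))/h at x = 2 with shrinking h. It uses the function from this section, f(x) = x³ − 2x² + x + 3. The exact derivative there is f'(2) = 3·4 − 4·2 + 1 = 5, and the quotient approaches it. For small h the error is roughly 4h, since f''(2) = 8. The helper names are illustrative.

```python
def f(x):
    return x**3 - 2*x**2 + x + 3

def f_prime(x):
    # Exact derivative from the power rule: 3x^2 - 4x + 1
    return 3*x**2 - 4*x + 1

def difference_quotient(f, x, h):
    """Slope of the secant line through (x, f(x)) and (x + h, f(x + h))."""
    return (f(x + h) - f(x)) / h

x0 = 2.0
for h in [1e-1, 1e-2, 1e-3, 1e-4]:
    # The secant slope approaches the tangent slope as h shrinks
    print(f"h = {h:g}: slope ~ {difference_quotient(f, x0, h):.6f}")
print("exact f'(2) =", f_prime(x0))  # exact f'(2) = 5.0
```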

Below is sample Python code that computes a derivative with SymPy and plots it with NumPy and Matplotlib.

import required modules

import sympy as sp
import numpy as np
import matplotlib.pyplot as plt

define the variable

# Define the variable
x = sp.symbols('x')

define the function, f(x) = x³ − 2x² + x + 3

# Define a function
f = x**3 - 2*x**2 + x + 3

calculating the derivative of the function

# Calculate the derivative of f(x)
f_derivative = sp.diff(f, x)

convert the SymPy expression into a function that works with NumPy, SciPy, etc.

# Convert the SymPy expression to a callable function
f_derivative_func = sp.lambdify(x, f_derivative, 'numpy')

The primary purpose of lambdify is to provide a bridge from SymPy expressions to numerical libraries such as NumPy and SciPy. In general, SymPy functions do not work with objects from other libraries, such as NumPy arrays, and functions from numeric libraries like NumPy do not work on SymPy expressions. lambdify bridges the two by converting a SymPy expression to an equivalent numeric function.
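Here is a sketch of how these pieces fit together, assuming the same f(x) = x³ − 2x² + x + 3. SymPy can solve f'(x) = 0 exactly for the critical points, and lambdify can turn both f and f' into NumPy functions. Lambdifying f also removes the need to retype its formula by hand for the y values. Since f'(x) = 3x² − 4x + 1 = (3x − 1)(x − 1), the critical points are x = 1/3 and x = 1.

```python
import numpy as np
import sympy as sp

x = sp.symbols('x')
f = x**3 - 2*x**2 + x + 3

f_derivative = sp.diff(f, x)                 # 3*x**2 - 4*x + 1
critical_points = sp.solve(f_derivative, x)  # exact roots of f'(x) = 0

# lambdify converts the symbolic expressions into functions that accept NumPy arrays
f_func = sp.lambdify(x, f, 'numpy')
f_derivative_func = sp.lambdify(x, f_derivative, 'numpy')

print(critical_points)
print(f_derivative_func(np.array([0.0, 1.0, 2.0])))  # [1. 0. 5.]

# Evaluate f at the critical points (converted from SymPy numbers to floats)
cp_values = f_func(np.array([float(c) for c in critical_points]))
print(cp_values)
```

Because f''(x) = 6x − 4 is negative at x = 1/3 and positive at x = 1, the first critical point is a local maximum and the second a local minimum.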

generate x values, function values and derivative values

# Generate x values
x_vals = np.linspace(-2, 3, 400)

# Calculate y values for both the original function and its derivative
y = x_vals**3 - 2*x_vals**2 + x_vals + 3
y_derivative = f_derivative_func(x_vals)

plot the graph for function and derivative

# Plotting the function and its derivative
plt.figure(figsize=(10, 6))

# Plot the original function
plt.plot(x_vals, y, label='f(x) = x^3 - 2x^2 + x + 3', color='blue')

# Plot the derivative function
plt.plot(x_vals, y_derivative, label="f'(x)", color='red')

plt.title("Function and its Derivative")
plt.xlabel("x")
plt.ylabel("y")
plt.axhline(0, color='black',linewidth=0.5)
plt.axvline(0, color='black',linewidth=0.5)
plt.legend()
plt.grid(True)
plt.show()

Partial Derivatives:

Partial derivatives are derivatives of a multivariable function with respect to a single variable while holding the other variables constant. They represent the rate of change of the function along one variable at a specific point while the other variables stay fixed. This concept is crucial in calculus, especially in fields like machine learning and optimization, where functions involve many variables.
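"Holding the other variables constant" can be taken literally: nudge one input, keep the other fixed, and measure how much f changes. The sketch below does this with central differences for f(x, y) = x² + y² at (2, 3), where the exact partial derivatives are ∂f/∂x = 2x = 4 and ∂f/∂y = 2y = 6. The step h = 1e-5 is an arbitrary choice, and the helper names are illustrative.

```python
def f(x, y):
    return x**2 + y**2

def partial_x(f, x, y, h=1e-5):
    """Central difference in x; y is held fixed."""
    return (f(x + h, y) - f(x - h, y)) / (2 * h)

def partial_y(f, x, y, h=1e-5):
    """Central difference in y; x is held fixed."""
    return (f(x, y + h) - f(x, y - h)) / (2 * h)

print(partial_x(f, 2.0, 3.0))  # approximately 4 (exact: 2x = 4)
print(partial_y(f, 2.0, 3.0))  # approximately 6 (exact: 2y = 6)
```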

Below is sample Python code that plots f(x, y) = x² + y² and its partial derivatives at a point.

import required modules

import numpy as np
import matplotlib.pyplot as plt

define the function

# Define the function
def f(x, y):
    return x**2 + y**2

partial derivatives of the function with respect to x and y

# Define the partial derivatives of f with respect to x and y
def partial_derivative_x(x, y):
    return 2 * x

def partial_derivative_y(x, y):
    return 2 * y

Generate the x, y and z values, along with the partial derivatives with respect to x and y at the point (2, 3)

# Generate data for the partial derivatives
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)

# Compute partial derivatives at a specific point (e.g., x=2, y=3)
point_x, point_y = 2, 3
partial_x = partial_derivative_x(point_x, point_y)
partial_y = partial_derivative_y(point_x, point_y)

plot the graph

# Display the function surface
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.set_zlabel('Z-axis')

# Plot the point and its partial derivatives
ax.scatter(
    point_x,
    point_y,
    f(point_x, point_y),
    color='red',
    s=100,
    label='Point (2, 3)'
)
ax.text(
    point_x,
    point_y,
    f(point_x, point_y),
    "(2, 3)",
    color='black',
    fontsize=12
)
ax.quiver(
    point_x,
    point_y,
    f(point_x, point_y),
    partial_x,
    partial_y,
    0.1,
    color='green',
    label='Partial Derivatives'
)
ax.legend()
plt.title('Partial Derivatives of f(x, y) = x^2 + y^2')
plt.show()

Gradient:

The gradient represents the slope, or rate of change, of a function at a particular point. For a multivariable function, the gradient is a vector of partial derivatives with respect to each variable. It points in the direction of steepest increase of the function, and its magnitude is the rate of increase in that direction.

Consider a function f of two variables at a point (x, y). Its gradient is the vector of partial derivatives:

$$\nabla f(x, y) = \left( \frac{\partial f}{\partial x}, \frac{\partial f}{\partial y} \right)$$

Here ∇ is called nabla. It acts as an operator that produces the vector of partial derivatives of the given function with respect to each variable or dimension.
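The sketch below illustrates the "steepest increase" claim for f(x, y) = x² + y² at (2, 3). It computes ∇f = (2x, 2y) = (4, 6), then sweeps 3600 unit directions u and evaluates the directional derivative ∇f · u for each. That formula is valid because f is differentiable. The direction with the largest value should line up with ∇f/‖∇f‖, and that largest value should be ‖∇f‖ = √52 ≈ 7.21. The helper name and the number of directions are illustrative choices.

```python
import numpy as np

def grad_f(x, y):
    # Gradient of f(x, y) = x^2 + y^2 is (df/dx, df/dy) = (2x, 2y)
    return np.array([2 * x, 2 * y])

g = grad_f(2.0, 3.0)

# Directional derivative along a unit vector u is grad f . u; sweep many directions
angles = np.linspace(0, 2 * np.pi, 3600, endpoint=False)
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)
directional = directions @ g

best = directions[np.argmax(directional)]
print(g)                                     # [4. 6.]
print(best, g / np.linalg.norm(g))           # nearly the same unit vector
print(directional.max(), np.linalg.norm(g))  # max rate of increase is about |grad f|
```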

Below is Python code and a plot for the gradient of a single-variable function, f(x) = x². In one dimension, the gradient is simply the derivative f'(x).

import required modules

import numpy as np
import matplotlib.pyplot as plt

define the function f(x) = x² and its derivative

# Define the function: f(x) = x^2
def f(x):
    return x**2

# Define the derivative of the function: f'(x) = 2x
def df(x):
    return 2*x

generate the x values and gradient

# Generate x values from -5 to 5
x = np.linspace(-5, 5, 100)

# Calculate function values and gradient values
y = f(x)
gradient = df(x)

plot the graph

# Plot the function and its gradient
plt.figure(figsize=(6, 4))
plt.plot(x, y, label='f(x) = x^2')
plt.plot(x, gradient, label="Gradient of f(x)")
plt.xlabel('x')
plt.ylabel('y / Gradient')
plt.title('Function and its Gradient')
plt.axhline(0, color='black',linewidth=0.5)
plt.axvline(0, color='black',linewidth=0.5)
plt.legend()
plt.grid(True)
plt.show()

The graph generated by the code illustrates the function f(x) = x² and its gradient f'(x) = 2x.

The x-axis represents the input values. The first curve shows the values of f(x) = x², and the second line (orange, in Matplotlib's default colors) shows the gradient f'(x) = 2x.

This visual representation shows how the gradient relates to the function: it is negative where f decreases (x < 0), zero at the minimum (x = 0), and positive where f increases (x > 0).
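That sign pattern is what gradient descent, mentioned in the subtitle, relies on: stepping against the gradient moves x downhill toward the minimum. Below is a minimal sketch for f(x) = x². It assumes a fixed learning rate of 0.1 and 50 steps (both arbitrary choices), and `gradient_descent` is a hypothetical helper, not a library function. Each update is x ← x − 0.1·2x = 0.8x, so x shrinks geometrically toward 0.

```python
def f(x):
    return x**2

def df(x):
    return 2 * x

def gradient_descent(df, x0, learning_rate=0.1, steps=50):
    """Hypothetical helper: repeatedly apply x <- x - learning_rate * df(x)."""
    x = x0
    history = [x]
    for _ in range(steps):
        x = x - learning_rate * df(x)  # step against the gradient (downhill)
        history.append(x)
    return x, history

x_min, history = gradient_descent(df, x0=4.0)
print(history[:4])      # roughly 4.0, 3.2, 2.56, 2.048 (each step multiplies x by 0.8)
print(x_min, f(x_min))  # both close to 0
```

For this function, each update multiplies x by (1 − 2·learning_rate). With any learning rate above 1, that factor exceeds 1 in absolute value and the iterates grow instead of shrinking.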

Maxima and Minima:

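Maxima and minima are the highest and lowest points, respectively, of a function (or values in a dataset). For a differentiable function, local maxima and minima inside the domain occur at critical points, where the derivative equals zero. The first-derivative test classifies them. If the derivative changes from positive to negative at a critical point, the point is a local maximum. If it changes from negative to positive, the point is a local minimum.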
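Here is a minimal sketch of the first-derivative test, using f(x) = x³ − 3x as an illustrative function (it is not from the original article). Its derivative 3x² − 3 vanishes at x = ±1, and checking the sign of f' just to the left and right of each root classifies it. The second derivative f''(x) = 6x is printed as a cross-check: it is negative at a maximum and positive at a minimum. The offset 1e-3 used for the sign check is an assumption that works here because the critical points are well separated.

```python
import numpy as np

# Illustrative function (not from the article): f(x) = x^3 - 3x
def f(x):
    return x**3 - 3*x

def df(x):
    return 3*x**2 - 3

def d2f(x):
    return 6*x

# Critical points: real roots of f'(x) = 3x^2 + 0x - 3
roots = np.roots([3, 0, -3])
critical_points = np.sort(roots[np.isclose(roots.imag, 0)].real)

results = []
for c in critical_points:
    # First-derivative test: sign of f' slightly left and right of c
    left, right = df(c - 1e-3), df(c + 1e-3)
    if left > 0 > right:
        kind = "local maximum"
    elif left < 0 < right:
        kind = "local minimum"
    else:
        kind = "no sign change (not an extremum by this test)"
    results.append((float(c), kind))
    print(f"x = {c:+.3f}, f(x) = {f(c):+.3f}, f''(x) = {d2f(c):+.3f} -> {kind}")
```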

Below is sample Python code that plots a function and marks its critical points, maxima, and minima.

import required modules

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize_scalar

define the function f(x) = x⁵ − 8x³ + 18x² + 12x + 7 and its derivative

# Define the function
def f(x):
    return x**5 - 8*x**3 + 18*x**2 + 12*x + 7

# Derivative of the function
def derivative_f(x):
    return 5*x**4 - 24*x**2 + 36*x + 12

generate x and y values

# Create an array of x values
x = np.linspace(-2, 5, 1000)

# Calculate corresponding y values
y = f(x)

plot the function

# Plot the function
plt.figure(figsize=(6, 4))
plt.plot(x, y, label='f(x) = x^5 - 8x^3 + 18x^2 + 12x + 7')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.title('Plot of the Function')
plt.axhline(0, color='black',linewidth=0.5)
plt.axvline(0, color='black',linewidth=0.5)
# Keep the figure open so the critical points below are drawn on the same axes

Finding critical points (where derivative = 0)

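# Finding critical points: real roots of f'(x) = 5x^4 - 24x^2 + 36x + 12
# (minimizing f' would return the lowest value of f', not a point where f' = 0)
roots = np.roots([5, 0, -24, 36, 12])
real_roots = roots[np.isclose(roots.imag, 0)].real
critical_points = [r for r in real_roots if -2 <= r <= 5]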

plot the critical points

# Plotting critical points
plt.scatter(
    critical_points,
    f(np.array(critical_points)),
    color='red',
    label='Critical Points'
)

# Classify each critical point with the second-derivative test
# f''(x) = 20x^3 - 48x + 36: f'' < 0 -> local maximum, f'' > 0 -> local minimum
def second_derivative_f(x):
    return 20*x**3 - 48*x + 36

for cp in critical_points:
    if second_derivative_f(cp) < 0:
        plt.scatter(cp, f(cp), color='green', label=f'Local maximum: {round(f(cp), 2)}')
    elif second_derivative_f(cp) > 0:
        plt.scatter(cp, f(cp), color='blue', label=f'Local minimum: {round(f(cp), 2)}')
plt.legend()
plt.grid(True)
plt.show()

Updated 22/04/2024: code style updated to Python.

References: Wikipedia, Stack Overflow, Stack Exchange, open-source materials.
