"""
Utility functions.
"""
import os, sys, time, json, pickle, requests, importlib, pathlib
import json
from urllib.parse import urlparse, unquote
from ast import literal_eval
import pickle
import requests
import math, numbers
import numpy as np
import mlflow
import torch
import torchvision
from torchvision import datasets as tv_datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from modlee.client import ModleeClient
[docs]
def safe_mkdir(target_path):
    """
    Safely make a directory, creating any missing parent directories.

    If ``target_path`` has a file extension it is treated as a path to a file,
    and the directory that would contain that file is created instead.
    Calling this for a directory that already exists does nothing.

    :param target_path: The path to the target directory, or to a file whose
        parent directory should be created.
    """
    _, ext = os.path.splitext(target_path)
    # A non-empty extension means the path names a file; only the directory
    # that holds it has to exist.
    directory = os.path.dirname(target_path) if ext else target_path
    if not directory:
        # A bare filename such as "model.pt" has no directory component, so
        # there is nothing to create.
        return
    # makedirs also creates intermediate directories, and exist_ok avoids the
    # race between a separate existence check and the creation call.
    os.makedirs(directory, exist_ok=True)
[docs]
def get_fashion_mnist(batch_size=64, num_output_channels=1):
    """
    Build shuffled train and test dataloaders for Fashion MNIST.

    The dataset is downloaded into the local ``data`` directory if needed, and
    every image is passed through ``Grayscale`` followed by ``ToTensor``.

    :param batch_size: The batch size used by both dataloaders, defaults to 64.
    :param num_output_channels: Passed to torchvision.transforms.Grayscale. 1 = grayscale, 3 = RGB. Defaults to 1.
    :return: A tuple of train and test dataloaders.
    """
    pipeline = torchvision.transforms.Compose(
        [
            transforms.Grayscale(num_output_channels=num_output_channels),
            transforms.ToTensor(),
        ]
    )

    def build_loader(is_train):
        # Both splits share the same root, transform and loader settings;
        # only the train flag differs.
        split = tv_datasets.FashionMNIST(
            root="data", train=is_train, download=True, transform=pipeline
        )
        return DataLoader(split, batch_size=batch_size, shuffle=True)

    # Tuple elements are evaluated left to right, so the training split is
    # still fetched before the test split.
    return build_loader(True), build_loader(False)
[docs]
def uri_to_path(uri):
    """
    Extract the path component of a URI and decode its percent-escapes.

    :param uri: The URI to convert.
    :return: The decoded path portion of the URI.
    """
    return unquote(urlparse(uri).path)
[docs]
def is_cacheable(x):
    """
    Check if an object is cacheable, i.e. serializable with ``json.dumps``.

    :param x: The object to check cacheability, probably a dictionary.
    :return: A boolean of whether the object is cacheable or not.
    """
    try:
        json.dumps(x)
    except Exception:
        # Any serialization failure (unsupported type, circular reference,
        # excessive nesting, ...) means "not cacheable". Catching Exception
        # rather than using a bare except lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return True
[docs]
def get_model_size(model, as_MB=True):
    """
    Estimate a model's size from its parameters and buffers.

    Each tensor contributes its element count times its per-element size.

    :param model: The model for which to get the size; must provide ``parameters()`` and ``buffers()``.
    :param as_MB: Whether to return the size in MB (bytes / 1024**2), defaults to True.
    :return: The model size, in MB if ``as_MB`` is set and in bytes otherwise.
    """

    def total_bytes(tensors):
        return sum(t.nelement() * t.element_size() for t in tensors)

    size = total_bytes(model.parameters()) + total_bytes(model.buffers())
    return size / 1024 ** 2 if as_MB else size
[docs]
def quantize(x):
    """
    Quantize a number into a coarse bucket.

    The bucket depends on the magnitude of the value:

    * ``0`` stays ``0.0``;
    * below 0.1, the value is rounded at the position of its first
      significant decimal digit (e.g. ``0.0123 -> 0.01``);
    * from 0.1 up to 1, it is rounded to two decimals;
    * from 1 up to 10, it is truncated to an integer;
    * from 10 upwards, it is replaced by the nearest power of 2 (nearest in
      log space).

    Negative values are quantized by magnitude and keep their sign.

    :param x: The object to quantize; anything accepted by ``float()``.
    :return: The object, quantized.
    """
    value = float(x)
    if value == 0:
        # log10(0) is undefined and there is no significant digit to keep.
        return 0.0
    if value < 0:
        return -quantize(-value)
    if value < 0.1:
        # Derive the number of decimals from the order of magnitude instead
        # of scanning str(x), which breaks on integers and on values printed
        # in scientific notation (e.g. "1e-05").
        decimals = -math.floor(math.log10(value))
        return np.around(value, decimals)
    if value < 1.0:
        return np.around(value, 2)
    if value < 10.0:
        return int(value)
    return int(2 ** round(math.log2(value)))


_discretize = quantize
[docs]
def convert_to_scientific(x):
    """
    Render a number in scientific notation with a single-digit mantissa.

    :param x: The number to convert; anything accepted by ``float()``.
    :return: The number in scientific notation as a string, e.g. ``"1e+04"``.
    """
    as_float = float(x)
    return format(as_float, "0.0e")
[docs]
def closest_power_of_2(x):
    """
    Round the magnitude of a number to its closest power of 2, i.e.
    y = 2**round(log_2(|x|)).

    :param x: The number; its sign is discarded.
    :return: The closest power of 2 of the number.
    """
    # The sign is dropped first because log2 is only defined for positive
    # values.
    magnitude = abs(x)
    return 2 ** round(math.log2(magnitude))
def _is_number(x):
"""
Check if an object is a number.
:param x: The object to check.
:return: Whether the object is a number.
"""
# if isinstance(n,list):
# return all([_is_number(num) for num in n])
try:
float(x) # Type-casting the string to `float`.
# If string is not a valid `float`,
# it'll raise `ValueError` exception
# except ValueError, TypeError:
except:
return False
return True
[docs]
def quantize_dict(base_dict, quantize_fn=quantize):
    """
    Quantize the numeric values of a dictionary, in place.

    Nested dictionaries are processed recursively. Values that are numbers,
    or that can be converted with ``float()`` (e.g. numeric strings), are
    replaced by ``quantize_fn`` of the value. Booleans and all other values
    are left untouched.

    :param base_dict: The dictionary to quantize; it is modified in place.
    :param quantize_fn: The function to use for quantization, defaults to quantize.
    :return: The quantized dictionary (the same object as ``base_dict``).
    """
    # Only existing keys are reassigned, so the dictionary never changes size
    # while it is being iterated.
    for key, value in base_dict.items():
        if isinstance(value, dict):
            base_dict[key] = quantize_dict(value, quantize_fn)
        elif isinstance(value, (bool, np.bool_)):
            # bool is a subclass of int and also converts with float(), so
            # without this guard flags would be rewritten as numbers.
            continue
        elif isinstance(value, (int, float)):
            base_dict[key] = quantize_fn(value)
        elif _is_number(value):
            base_dict[key] = quantize_fn(float(value))
    return base_dict
[docs]
def typewriter_print(text, sleep_time=0.001, max_line_length=150, max_lines=20):
    """
    Print text one character at a time, like a typewriter.

    Output is capped at ``max_lines`` lines, followed by a ``...`` marker
    line when text was cut, and each line longer than ``max_line_length`` is
    truncated and suffixed with `` ...``.

    :param text: The text to print; non-strings are converted with ``str()``.
    :param sleep_time: The time to sleep between letters, defaults to 0.001.
    :param max_line_length: The maximum line length to truncate to, defaults to 150.
    :param max_lines: The maximum number of lines to print, defaults to 20.
    """
    content = text if isinstance(text, str) else str(text)
    rows = content.split("\n")
    if len(rows) > max_lines:
        rows = rows[:max_lines] + ["...\n"]

    rendered = []
    for row in rows:
        if len(row) > max_line_length:
            rendered.append(row[:max_line_length] + " ...\n")
        else:
            rendered.append(row + "\n")

    # Emit the rendered lines as a single character stream, flushing after
    # every character so the typing effect is visible immediately.
    for char in "".join(rendered):
        print(char, end="")
        sys.stdout.flush()
        time.sleep(sleep_time)
# ---------------------------------------------
[docs]
def discretize(n: list[float, int]) -> list[float, int]:
    """
    Discretize an input or a collection of inputs.

    A string is first parsed with ``ast.literal_eval``. Lists and tuples are
    then discretized element by element (a tuple stays a tuple); any other
    value is discretized directly. This is best-effort: if parsing or
    discretizing fails, the value is returned as it was at that point
    (the original string if parsing failed, otherwise the parsed value).

    :param n: The input to discretize: a number, a list or tuple of numbers,
        or a string holding a Python literal of one of those.
    :return: The discretized input, or the input itself if it could not be
        discretized.
    """
    try:
        if isinstance(n, str):
            n = literal_eval(n)
        if isinstance(n, list):
            return [_discretize(item) for item in n]
        if isinstance(n, tuple):
            return tuple(_discretize(item) for item in n)
        return _discretize(n)
    except Exception:
        # Deliberately forgiving, but no longer a bare except, so that
        # KeyboardInterrupt and SystemExit are not swallowed.
        return n
[docs]
def apply_discretize_to_summary(text, info):
    """
    Discretize every key and value token of a summary string.

    The summary is split into layers, each layer into parameters, and each
    parameter into tokens, using the separators carried by ``info``. Every
    token is passed through ``discretize`` and converted back to a string,
    and the pieces are re-joined with the same separators.

    :param text: The text to discretize.
    :param info: An object providing ``layer_seperator``, ``parameter_seperator`` and ``key_val_seperator``.
    :return: The discretized summary.
    """
    layer_sep = info.layer_seperator
    param_sep = info.parameter_seperator
    kv_sep = info.key_val_seperator

    layers_out = []
    for layer in text.split(layer_sep):
        params_out = []
        for param in layer.split(param_sep):
            tokens = [str(discretize(token)) for token in param.split(kv_sep)]
            params_out.append(kv_sep.join(tokens))
        layers_out.append(param_sep.join(params_out))
    return layer_sep.join(layers_out)
[docs]
def save_run(*args, **kwargs):
    """
    Save the current run.

    Builds a ``ModleeClient`` with the API key read from the
    ``MODLEE_API_KEY`` environment variable and forwards all arguments to its
    ``post_run`` method.

    :param args: Positional arguments forwarded to ``ModleeClient.post_run``.
    :param kwargs: Keyword arguments forwarded to ``ModleeClient.post_run``.
    """
    client = ModleeClient(api_key=os.getenv("MODLEE_API_KEY"))
    client.post_run(*args, **kwargs)
[docs]
def save_run_as_json(*args, **kwargs):
    """
    Save the current run as a JSON.

    Builds a ``ModleeClient`` with the API key read from the
    ``MODLEE_API_KEY`` environment variable and forwards all arguments to its
    ``post_run_as_json`` method.

    :param args: Positional arguments forwarded to ``ModleeClient.post_run_as_json``.
    :param kwargs: Keyword arguments forwarded to ``ModleeClient.post_run_as_json``.
    """
    client = ModleeClient(api_key=os.getenv("MODLEE_API_KEY"))
    client.post_run_as_json(*args, **kwargs)
[docs]
def last_run_path(*args, **kwargs):
    """
    Return the path to the last / most recent run.

    The path is the parent directory of the path component of the run's
    artifact URI, as reported by ``mlflow.last_active_run()``. Any arguments
    are accepted and ignored.

    :return: The path to the last run.
    :raises AttributeError: Presumably when MLflow has no last active run
        (``last_active_run()`` returning ``None``) -- confirm against the
        MLflow documentation.
    """
    artifact_uri = mlflow.last_active_run().info.artifact_uri
    # Decode percent-escapes (e.g. "%20" for a space) so the result is a
    # usable filesystem path, matching what uri_to_path does.
    artifact_path = unquote(urlparse(artifact_uri).path)
    return os.path.dirname(artifact_path)