from dataclasses import dataclass
import socket
import time
import dill
from dotmap import DotMap
from verifai.features.features import *
from verifai.samplers.feature_sampler import *
def _merge_bucket_params(params, overrides, extra_keys=()):
    """Copy the recognised overrides for a bucketed sampler onto ``params``.

    Only ``cont.buckets``, ``cont.dist``, ``disc.dist``, ``alpha``, ``thres``
    and any names listed in ``extra_keys`` are copied; every other entry of
    ``params`` keeps its default value. ``params`` is modified in place and
    returned. ``overrides`` may be ``None``, in which case nothing is copied.
    """
    if overrides is None:
        return params
    if 'cont' in overrides:
        cont = overrides.cont
        if 'buckets' in cont:
            params.cont.buckets = cont.buckets
        if 'dist' in cont:
            params.cont.dist = cont.dist
    # Test for 'disc' before reading it: a DotMap (by default) creates and
    # stores an empty child when a missing attribute is read, which would
    # mutate the caller's params; a non-dynamic mapping would raise instead.
    if 'disc' in overrides and 'dist' in overrides.disc:
        params.disc.dist = overrides.disc.dist
    for key in ('alpha', 'thres') + tuple(extra_keys):
        if key in overrides:
            setattr(params, key, getattr(overrides, key))
    return params


def choose_sampler(sample_space, sampler_type,
                   sampler_params=None):
    """Build a sampler of the requested type over ``sample_space``.

    Args:
        sample_space: the space to sample from; passed unchanged to the
            ``FeatureSampler`` factory for the chosen type.
        sampler_type: one of ``'random'``, ``'grid'``, ``'halton'``,
            ``'ce'``, ``'mab'``, ``'eg'`` or ``'bo'``.
        sampler_params: optional overrides layered on top of
            ``default_sampler_params(sampler_type)``. For ``'halton'`` and
            ``'bo'`` they are merged wholesale with ``update``; for ``'ce'``,
            ``'eg'`` and ``'mab'`` only the recognised keys are copied (see
            ``_merge_bucket_params``; ``'mab'`` also accepts
            ``priority_graph``). Ignored for ``'random'`` and ``'grid'``.

    Returns:
        A ``(sampler_type, sampler)`` tuple.

    Raises:
        ValueError: if ``sampler_type`` is not one of the names above.
    """
    if sampler_type == 'random':
        return 'random', FeatureSampler.randomSamplerFor(sample_space)
    if sampler_type == 'grid':
        return 'grid', FeatureSampler.gridSamplerFor(sample_space)
    if sampler_type == 'halton':
        halton_params = default_sampler_params('halton')
        if sampler_params is not None:
            halton_params.update(sampler_params)
        sampler = FeatureSampler.haltonSamplerFor(
            sample_space, halton_params=halton_params)
        return 'halton', sampler
    if sampler_type == 'ce':
        ce_params = _merge_bucket_params(
            default_sampler_params('ce'), sampler_params)
        sampler = FeatureSampler.crossEntropySamplerFor(
            sample_space, ce_params=ce_params)
        return 'ce', sampler
    if sampler_type == 'mab':
        mab_params = _merge_bucket_params(
            default_sampler_params('mab'), sampler_params,
            extra_keys=('priority_graph',))
        sampler = FeatureSampler.multiArmedBanditSamplerFor(
            sample_space, mab_params=mab_params)
        return 'mab', sampler
    if sampler_type == 'eg':
        eg_params = _merge_bucket_params(
            default_sampler_params('eg'), sampler_params)
        sampler = FeatureSampler.epsilonGreedySamplerFor(
            sample_space, eg_params=eg_params)
        return 'eg', sampler
    if sampler_type == 'bo':
        bo_params = default_sampler_params('bo')
        if sampler_params is not None:
            bo_params.update(sampler_params)
        sampler = FeatureSampler.bayesianOptimizationSamplerFor(
            sample_space, BO_params=bo_params)
        return 'bo', sampler
    raise ValueError(f'unknown sampler type "{sampler_type}"')
class Server:
    """Generic server for communicating with an external simulator.

    The server listens on a TCP socket bound to 127.0.0.1. For each sample it
    accepts one client connection, sends the serialised sample, reads the
    client's reply until EOF, and scores that reply with ``monitor``.
    """

    def __init__(self, sampling_data, monitor, options=None):
        """Open the listening socket and set up the sampler.

        Args:
            sampling_data: attribute-accessible config (presumably a
                ``DotMap``; it must also support ``in``) providing
                ``sampler``, ``sampler_type``, ``sample_space`` and,
                optionally, ``sampler_params``.
            monitor: object whose ``evaluate(simulation_data)`` scores a
                simulation result, or ``None`` to score every result as 0.
            options: optional overrides for ``port`` (default 8888),
                ``bufsize`` (recv chunk size, default 4096) and ``maxreqs``
                (listen backlog, default 5).

        Raises:
            OSError: if the socket cannot be bound or put into listen mode.
        """
        defaults = DotMap(port=8888, bufsize=4096, maxreqs=5)
        if options is not None:
            defaults.update(options)
        self.monitor = monitor
        # Feedback handed to the sampler on the next run_server() call.
        self.lastValue = None
        self.port = defaults.port
        self.bufsize = defaults.bufsize
        self.maxreqs = defaults.maxreqs
        self.host = '127.0.0.1'
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.socket.bind((self.host, self.port))
            self.socket.listen(self.maxreqs)
            self._setup_sampler(sampling_data)
        except BaseException:
            # Release the descriptor/port if construction fails part-way,
            # so a retry in the same process does not hit "address in use".
            self.socket.close()
            raise

    @staticmethod
    def _feature_space(sample_space):
        """Wrap every named domain of ``sample_space`` in a ``Feature``."""
        return FeatureSpace({name: Feature(sample_space[name])
                             for name in sample_space})

    def _setup_sampler(self, sampling_data):
        """Set ``sampler``, ``sampler_type`` and ``sample_space`` from config."""
        if sampling_data.sampler is not None:
            # A ready-made sampler was supplied: use it as-is.
            self.sampler_type = ('random' if sampling_data.sampler_type is None
                                 else sampling_data.sampler_type)
            self.sampler = sampling_data.sampler
            self.sample_space = (self.sampler.space
                                 if sampling_data.sample_space is None
                                 else sampling_data.sample_space)
        elif sampling_data.sampler_type is None:
            self.sample_space = self._feature_space(sampling_data.sample_space)
            self.sampler_type = 'random'
            self.sampler = FeatureSampler.samplerFor(self.sample_space)
            self.sample_space = self.sampler.space
        else:
            self.sample_space = self._feature_space(sampling_data.sample_space)
            params = (None if 'sampler_params' not in sampling_data
                      else sampling_data.sampler_params)
            self.sampler_type, self.sampler = choose_sampler(
                sample_space=self.sample_space,
                sampler_type=sampling_data.sampler_type,
                sampler_params=params,
            )

    def listen(self):
        """Block until a client connects; keep its socket as ``client_socket``."""
        client_socket, _addr = self.socket.accept()
        self.client_socket = client_socket

    def receive(self):
        """Read the client's message until EOF and return it decoded."""
        chunks = []
        while True:
            chunk = self.client_socket.recv(self.bufsize)
            if not chunk:
                break
            chunks.append(chunk)
        return self.decode(b"".join(chunks))

    def send(self, sample):
        """Send the encoded ``sample`` to the client, then half-close.

        ``sendall`` is required because ``send`` may transmit only part of
        the buffer; shutting down the write side marks the end of the message.
        """
        self.client_socket.sendall(self.encode(sample))
        self.client_socket.shutdown(socket.SHUT_WR)

    def encode(self, sample):
        """Serialise ``sample`` with dill."""
        return dill.dumps(sample)

    def decode(self, data):
        """Deserialise bytes received from the client with dill.

        NOTE(review): ``dill.loads`` can execute arbitrary code, and any local
        process able to connect to this port supplies ``data``. Only run the
        server on a host whose local users are trusted.
        """
        return dill.loads(data)

    def terminate(self):
        """Close the listening socket."""
        self.socket.close()

    def close_connection(self):
        """Close the current client connection."""
        self.client_socket.close()

    def get_sample(self, feedback):
        """Ask the sampler for the next sample, given the previous ``feedback``."""
        return self.sampler.nextSample(feedback)

    def flatten_sample(self, sample):
        """Flatten ``sample`` using the sampler's space."""
        return self.sampler.space.flatten(sample)

    def evaluate_sample(self, sample):
        """Run one simulation for ``sample`` and return the monitor's score.

        Accepts one client connection, sends the sample, and reads the reply.
        The connection is always closed, even if sending or receiving fails.
        Returns 0 when no monitor is configured.
        """
        self.listen()
        try:
            self.send(sample)
            simulation_data = self.receive()
        finally:
            self.close_connection()
        if self.monitor is None:
            return 0
        return self.monitor.evaluate(simulation_data)

    def run_server(self):
        """Run one sample -> simulate -> evaluate round.

        Returns:
            ``(sample, value, timings)``. ``value`` is the score, which is also
            stored as ``lastValue`` and fed back to the sampler on the next
            call. ``timings`` is a ``ServerTimings`` of elapsed seconds.
        """
        # perf_counter is monotonic, so durations cannot go negative if the
        # system clock is adjusted mid-run (unlike time.time()).
        start = time.perf_counter()
        sample = self.get_sample(self.lastValue)
        after_sampling = time.perf_counter()
        self.lastValue = self.evaluate_sample(sample)
        after_simulation = time.perf_counter()
        timings = ServerTimings(sample_time=(after_sampling - start),
                                simulate_time=(after_simulation - after_sampling))
        return sample, self.lastValue, timings
try:
    import ray
except ImportError as _err:
    # ray is optional. Catch ImportError (not only ModuleNotFoundError) so a
    # broken ray install does not make this whole module unimportable; the
    # original error is kept and chained when ParallelServer is constructed.
    _RAY_IMPORT_ERROR = _err

    class ParallelServer(Server):
        """Stand-in used when ray cannot be imported; refuses construction."""

        def __init__(self, *args, **kwargs):
            raise RuntimeError(
                'ParallelServer requires ray to be installed'
            ) from _RAY_IMPORT_ERROR
else:
    @ray.remote
    class ParallelServer(Server):
        """``Server`` wrapped as a ray remote actor via ``@ray.remote``."""
@dataclass
class ServerTimings:
    """Elapsed durations, in seconds, for one ``Server.run_server`` round."""

    # Time spent asking the sampler for the next sample.
    sample_time: float
    # Time spent sending the sample and waiting for the simulator's reply.
    simulate_time: float