From daed085515825d92872849aaf7b1cf9d04674feb Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 13 Oct 2024 10:14:31 -0400 Subject: [PATCH 01/32] getting ademamix docs and notebook running --- docs/api/contrib.rst | 7 + examples/contrib/rosenbrock_ademamix.ipynb | 247 ++++++++++ optax/contrib/__init__.py | 3 + optax/contrib/_ademamix.py | 173 +++++++ optax/contrib/_common_test.py | 497 ++++++++++----------- 5 files changed, 673 insertions(+), 254 deletions(-) create mode 100644 examples/contrib/rosenbrock_ademamix.ipynb create mode 100644 optax/contrib/_ademamix.py diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index fdee8656f..232104c76 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -8,6 +8,7 @@ Experimental features and algorithms that don't meet the .. autosummary:: acprop + ademamix cocob COCOBState dadapt_adamw @@ -82,6 +83,12 @@ Momo .. autofunction:: momo_adam .. autoclass:: MomoAdamState +Multiple EMA AdEMAMix +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: ademamix +.. autofunction:: scale_by_ademamix +.. autoclass:: ScaleByAdemamixState + Prodigy ~~~~~~~ .. autofunction:: prodigy diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb new file mode 100644 index 000000000..756a38eee --- /dev/null +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "2dae2f1f-939a-4bd6-86ee-ec2dec9e6192", + "metadata": {}, + "outputs": [], + "source": [ + "# Recreate AdeMAMix Rosenbrock Plot from Paper\n", + "This notebook attempts to recreate the Figures 2(b) and 2(c) from " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "55182561-ad63-4fb1-ba21-116ca65c21b1", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import optax\n", + "import jax\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "15cd3560-d41c-4a97-83c5-b28df4d5d077", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "\n", + "def rosenbrock(x):\n", + " return jnp.square(1 - x[0]) + 100. * jnp.square(x[1] - jnp.square(x[0]))\n", + "\n", + "# Create a grid of x and y values\n", + "#X, Y = np.meshgrid(np.linspace(-1.3, 1.3, 31), np.linspace(-0.9, 1.7, 31))\n", + "x = jnp.linspace(-5, 10, 1000)\n", + "y = jnp.linspace(-5, 10, 1000)\n", + "X, Y = jnp.meshgrid(x, y)\n", + "\n", + "# Compute the Rosenbrock function values for each point on the grid\n", + "Z = rosenbrock([X, Y])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92b6987c-8ba1-43bc-8083-4c2b6324cb28", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Objective function: 1616.0\n" + ] + } + ], + "source": [ + "all_b1_params = []\n", + "for b1 in [0.9,0.99,0.999,0.9999]:\n", + " solver = optax.adam(\n", + " learning_rate=0.003,\n", + " b1=b1,\n", + " b2=0.9999,\n", + " )\n", + " params = jnp.array([-3.,5.])\n", + " print(\"Objective function: \", rosenbrock(params))\n", + " all_params=[params]\n", + " opt_state = solver.init(params)\n", + " for i in range(100000):\n", + " grad = jax.grad(rosenbrock)(params)\n", + " updates, opt_state = solver.update(grad, opt_state, params)\n", + " params = optax.apply_updates(params, updates)\n", + " all_params.append(params)\n", + " # if i%1000 == 0:\n", + " # print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", + " all_b1_params.append(all_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11a4561a-1d92-44ce-bab0-5af22f028167", + "metadata": {}, + "outputs": [], + "source": [ + "all_ademamix_params = []\n", + "for b3 in [0.999,0.9999]:\n", + " solver = optax.ademamix(\n", + " learning_rate=0.003,\n", + " b1=.99,\n", + " b2=0.999,\n", + " b3=0.9999,\n", + " )\n", + " params = jnp.array([-3.,5.])\n", + " print(\"Objective function: \", rosenbrock(params))\n", + " all_params=[params]\n", + " opt_state = solver.init(params)\n", + " for i in range(100000):\n", + " grad = jax.grad(rosenbrock)(params)\n", + " updates, opt_state = solver.update(grad, opt_state, params)\n", + " params = optax.apply_updates(params, updates)\n", + " all_params.append(params)\n", + " # if i%1000 == 0:\n", + " # print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", + " all_ademamix_params.append(all_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d19aeff7-c908-4c7c-a428-a7846fb3e62a", + "metadata": {}, + "outputs": [], + "source": [ + "all_b1_params_array = jnp.array(all_b1_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60ca6921-fa7c-413f-972f-6b1668c72159", + "metadata": {}, + "outputs": [], + "source": [ + "all_b1_params_array.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78364e82-f634-4d08-b395-18d34bc64329", + "metadata": {}, + "outputs": [], + "source": [ + "all_ademamix_params_array = jnp.array(all_ademamix_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fdfb913-96c7-4b71-9dcb-177317006049", + "metadata": {}, + "outputs": [], + "source": [ + "all_ademamix_params_array.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b0a0467-02e7-4fd3-84e2-9b70df6ed665", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rc('figure', figsize=(20, 10))\n", + "plt.rc('font', size=14)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69d8642f-dfcc-4fac-8f85-3ee1fbfa135f", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a 3D plot\n", + "fig = plt.figure()\n", + "ax = fig.subplots(1,2)\n", + "ax[0].set_xlabel('x')\n", + "ax[0].set_ylabel('y')\n", + "ax[0].set_title('Rosenbrock Function - Adam Trajectories')\n", + "# Show the plot\n", + "ax[0].plot([1], [1], 'x', mew=1, markersize=10, color='cyan')\n", + "ax[0].contourf(X, Y, Z, np.logspace(-1, 3, 100), cmap='jet')\n", + "for i, b1 in enumerate([0.9,0.99,0.999,0.9999]):\n", + " ax[0].plot(all_b1_params_array[i,::100,0], all_b1_params_array[i,::100,1],label=f'Adam b1 = {b1}')\n", + "ax[0].set_xlim(-4,4)\n", + "ax[0].set_ylim(-3.5,7.5)\n", + "ax[0].legend()\n", + "\n", + "ax[1].set_xlabel('x')\n", + "ax[1].set_ylabel('y')\n", + "ax[1].set_title('Rosenbrock Function - Adam Trajectories')\n", + "# Show the plot\n", + "ax[1].plot([1], [1], 'x', mew=1, markersize=10, color='cyan')\n", + "ax[1].contourf(X, Y, Z, np.logspace(-1, 3, 100), cmap='jet')\n", + "for i, b3 in enumerate([0.999,0.9999]):\n", + " ax[1].plot(all_ademamix_params_array[i,::100,0], all_ademamix_params_array[i,::100,1],label=f'AdEMAMix b3 = {b3}')\n", + "ax[1].set_xlim(-4,4)\n", + "ax[1].set_ylim(-3.5,7.5)\n", + "ax[1].legend()\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cf96de0-cb01-4338-87b4-dd80f0498ebd", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " all_ademamix_params_array[0,-1,0],all_ademamix_params_array[0,-1,1],\n", + " all_ademamix_params_array[1,-1,0],all_ademamix_params_array[1,-1,1]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66647d3e-81e2-4987-b5ef-81e08ac048dc", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/optax/contrib/__init__.py b/optax/contrib/__init__.py index a310cc23b..2456f8af7 100644 --- a/optax/contrib/__init__.py +++ b/optax/contrib/__init__.py @@ -18,6 +18,9 @@ from optax.contrib._acprop import acprop from optax.contrib._acprop import scale_by_acprop +from optax.contrib._ademamix import ScaleByAdemamixState +from optax.contrib._ademamix import scale_by_ademamix +from optax.contrib._ademamix import ademamix from optax.contrib._cocob import cocob from optax.contrib._cocob import COCOBState from optax.contrib._cocob import scale_by_cocob diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py new file mode 100644 index 000000000..3e461fb20 --- /dev/null +++ b/optax/contrib/_ademamix.py @@ -0,0 +1,173 @@ +from typing import Any, Callable, NamedTuple, Optional, Union +import chex +import jax.numpy as jnp +import jax.tree_util as jtu +from optax._src import base +from optax._src import combine +from optax._src import transform + + +class ScaleByAdemamixState(NamedTuple): + """State for the Ademamix algorithm.""" + + count: chex.Array + count_m2: chex.Array + m1: base.Updates + m2: base.Updates + nu: base.Updates + + +def scale_by_ademamix( + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.9999, + alpha: float = 5.0, + b3_scheduler: Optional[base.ScalarOrSchedule] = None, + alpha_scheduler: Optional[base.ScalarOrSchedule] = None, + eps: float = 1e-8, + weight_decay: float = 0.0, +) -> base.GradientTransformation: + """Rescale updates according to the Ademamix algorithm. + + References: + [Pagliardini et al, 2024](https://arxiv.org/pdf/2409.03137) + + Args: + b1: Exponential decay rate to track the first moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b2: Exponential decay rate to track the second moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b3: Exponential decay rate to track the first moment of past gradients + for the second EMA. + alpha: the coefficient that "blends" the two EMAs. paper states values in + :math:`[4,10]` work well in practice. + b3_scheduler: The schedule for the b3 parameter + alpha_scheduler: The schedule for the alpha parameter + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + weight_decay: Strength of the weight decay regularization. + + Returns: + A `GradientTransformation` object. + + Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in App. C.1.5, showing how + smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is non-stationary. + """ + + def init_fn(params): + m1 = otu.tree_zeros_like(params) # fast EMA + m2 = otu.tree_zeros_like(params) # slow EMA + nu = otu.tree_zeros_like(params) # second moment estimate + return ScaleByAdemamixState( + count=jnp.zeros([], jnp.int32), + count_m2=jnp.zeros([], jnp.int32), + m1=m1, + m2=m2, + nu=nu, + ) + + def update_fn(updates, state, params=None): + del params + c_b3 = b3_scheduler(state.count_m2) if b3_scheduler is not None else b3 + c_alpha = ( + alpha_scheduler(state.count_m2) if alpha_scheduler is not None else alpha + ) + m1 = otu.tree_update_moment( + updates, state.m1, b1, 1 + ) # m1 = b1 * m1 + (1-b1) * updates + m2 = otu.tree_update_moment(updates, state.m2, c_b3, 1) + nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + count_m2_inc = numerics.safe_int32_increment(state.count_m2) + m1_hat = otu.tree_bias_correction(m1, b1, count_inc) + nu_hat = otu.tree_bias_correction(nu, b2, count_inc) + updates = jtu.tree_map( + lambda m1_, m2_, v_: (m1_ + c_alpha * m2_) / (jnp.sqrt(v_) + eps), + m1_hat, + m2, + nu_hat, + ) + return updates, ScaleByAdemamixState( + count=count_inc, count_m2=count_m2_inc, m1=m1, m2=m2, nu=nu + ) + + return base.GradientTransformation(init_fn, update_fn) + + +def ademamix( + learning_rate: base.ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + b3: float = 0.9999, + alpha: float = 5.0, + b3_scheduler: Optional[base.ScalarOrSchedule] = None, + alpha_scheduler: Optional[base.ScalarOrSchedule] = None, + eps: float = 1e-8, + weight_decay: float = 0.0, +) -> base.GradientTransformation: + """The Ademamix optimiser. + + Description + + Examples: + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function + >>> solver = optax.ademamix(learning_rate=0.003) + >>> params = jnp.array([1., 2., 3.]) + >>> print('Objective function: ', f(params)) + Objective function: 14.0 + >>> opt_state = solver.init(params) + >>> for _ in range(5): + ... grad = jax.grad(f)(params) + ... updates, opt_state = solver.update(grad, opt_state, params) + ... params = optax.apply_updates(params, updates) + ... print('Objective function: {:.2E}'.format(f(params))) + Objective function: 1.40E+01 + Objective function: 1.39E+01 + Objective function: 1.39E+01 + Objective function: 1.39E+01 + Objective function: 1.38E+01 + + References: + Pagliardini et al, 2024: https://arxiv.org/pdf/2409.03137 + + Args: + b1: Exponential decay rate to track the first moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b2: Exponential decay rate to track the second moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b3: Exponential decay rate to track the first moment of past gradients + for the second EMA. + alpha: the coefficient that "blends" the two EMAs. paper states values in + :math:`[4,10]` work well in practice. + b3_scheduler: The schedule for the b3 parameter + alpha_scheduler: The schedule for the alpha parameter + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + weight_decay: Strength of the weight decay regularization. + + Returns: + A `GradientTransformation` object. + + Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in App. C.1.5, showing how + smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is non-stationary. + """ + return combine.chain( + transform.scale_by_ademamix( + b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps + ), + transform.add_decayed_weights(weight_decay), + transform.scale_by_learning_rate(learning_rate), + ) diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index f044d7a11..d440fb786 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -36,81 +36,82 @@ # Testing contributions coded as GradientTransformations _MAIN_OPTIMIZERS_UNDER_TEST = [ - dict(opt_name='acprop', opt_kwargs=dict(learning_rate=1e-3)), - dict(opt_name='cocob', opt_kwargs={}), - dict(opt_name='cocob', opt_kwargs=dict(weight_decay=1e-2)), - dict(opt_name='dadapt_adamw', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='dog', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='dowg', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name="acprop", opt_kwargs=dict(learning_rate=1e-3)), + dict(opt_name="ademamix", opt_kwargs=dict(learning_rate=1e-3)), + dict(opt_name="cocob", opt_kwargs={}), + dict(opt_name="cocob", opt_kwargs=dict(weight_decay=1e-2)), + dict(opt_name="dadapt_adamw", opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name="dog", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="dowg", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="momo", opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name="momo_adam", opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name="prodigy", opt_kwargs=dict(learning_rate=1e-1)), dict( - opt_name='schedule_free_sgd', + opt_name="schedule_free_sgd", opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), ), dict( - opt_name='schedule_free_adamw', + opt_name="schedule_free_adamw", opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), ), ] for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST: - optimizer['wrapper_name'] = None - optimizer['wrapper_kwargs'] = None + optimizer["wrapper_name"] = None + optimizer["wrapper_kwargs"] = None # Testing contributions coded as wrappers # (just with sgd as we just want the behavior of the wrapper) _MAIN_OPTIMIZERS_UNDER_TEST += [ dict( - opt_name='sgd', + opt_name="sgd", opt_kwargs=dict(learning_rate=1e-1), - wrapper_name='mechanize', + wrapper_name="mechanize", wrapper_kwargs=dict(weight_decay=0.0), ), dict( - opt_name='sgd', + opt_name="sgd", opt_kwargs=dict(learning_rate=1e-2), - wrapper_name='schedule_free', + wrapper_name="schedule_free", wrapper_kwargs=dict(learning_rate=1e-2), ), dict( - opt_name='sgd', + opt_name="sgd", opt_kwargs=dict(learning_rate=1e-3), - wrapper_name='reduce_on_plateau', + wrapper_name="reduce_on_plateau", wrapper_kwargs={}, ), ] # Adding here instantiations of wrappers with any base optimizer _BASE_OPTIMIZERS = [ - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name='adam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lion', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), - dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="sgd", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="sgd", opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), + dict(opt_name="adam", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="adamw", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="adamax", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="adamaxw", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="amsgrad", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="lamb", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="lion", opt_kwargs=dict(learning_rate=1.0, b1=0.99)), + dict(opt_name="noisy_sgd", opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), + dict(opt_name="novograd", opt_kwargs=dict(learning_rate=1.0)), dict( - opt_name='optimistic_gradient_descent', + opt_name="optimistic_gradient_descent", opt_kwargs=dict(learning_rate=1.0, alpha=0.7, beta=0.1), ), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), + dict(opt_name="rmsprop", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="rmsprop", opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), + dict(opt_name="adabelief", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="radam", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="sm3", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name="yogi", opt_kwargs=dict(learning_rate=1.0, b1=0.99)), ] # TODO(harshm): make LARS and Fromage work with mechanic. _OTHER_OPTIMIZERS_UNDER_TEST = [ dict( - opt_name=base_opt['opt_name'], - opt_kwargs=base_opt['opt_kwargs'], - wrapper_name='mechanize', + opt_name=base_opt["opt_name"], + opt_kwargs=base_opt["opt_kwargs"], + wrapper_name="mechanize", wrapper_kwargs=dict(weight_decay=0.0), ) for base_opt in _BASE_OPTIMIZERS @@ -123,235 +124,223 @@ def _get_opt_factory(opt_name): - """Get optimizer factory.""" - if hasattr(contrib, opt_name): - return getattr(contrib, opt_name) - if hasattr(alias, opt_name): - return getattr(alias, opt_name) - raise ValueError(f'Unknown optimizer: {opt_name}') + """Get optimizer factory.""" + if hasattr(contrib, opt_name): + return getattr(contrib, opt_name) + if hasattr(alias, opt_name): + return getattr(alias, opt_name) + raise ValueError(f"Unknown optimizer: {opt_name}") def _wrap_opt(opt, wrapper_name, wrapper_kwargs): - if wrapper_name == 'reduce_on_plateau': - return combine.chain(opt, contrib.reduce_on_plateau(**wrapper_kwargs)) - else: - return getattr(contrib, wrapper_name)(opt, **wrapper_kwargs) + if wrapper_name == "reduce_on_plateau": + return combine.chain(opt, contrib.reduce_on_plateau(**wrapper_kwargs)) + else: + return getattr(contrib, wrapper_name)(opt, **wrapper_kwargs) def _setup_parabola(dtype): - """Quadratic function as an optimization target.""" - initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) - final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) + """Quadratic function as an optimization target.""" + initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) + final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) - @jax.value_and_grad - def get_updates(params): - return jnp.sum(numerics.abs_sq(params - final_params)) + @jax.value_and_grad + def get_updates(params): + return jnp.sum(numerics.abs_sq(params - final_params)) - return initial_params, final_params, get_updates + return initial_params, final_params, get_updates def _setup_rosenbrock(dtype): - """Rosenbrock function as an optimization target.""" - a = 1.0 - b = 100.0 + """Rosenbrock function as an optimization target.""" + a = 1.0 + b = 100.0 - initial_params = jnp.array([0.0, 0.0], dtype=dtype) - final_params = jnp.array([a, a**2], dtype=dtype) + initial_params = jnp.array([0.0, 0.0], dtype=dtype) + final_params = jnp.array([a, a**2], dtype=dtype) - @jax.value_and_grad - def get_updates(params): - return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( - params[1] - params[0] ** 2 - ) + @jax.value_and_grad + def get_updates(params): + return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( + params[1] - params[0] ** 2 + ) - return initial_params, final_params, get_updates + return initial_params, final_params, get_updates class ContribTest(chex.TestCase): - - @parameterized.product( - _ALL_OPTIMIZERS_UNDER_TEST, - target=(_setup_parabola, _setup_rosenbrock), - dtype=('float32',), - ) - def test_optimizers( - self, - opt_name, - opt_kwargs, - wrapper_name, - wrapper_kwargs, - target, - dtype, - ): - dtype = jnp.dtype(dtype) - opt = _get_opt_factory(opt_name)(**opt_kwargs) - if wrapper_name is not None: - opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - initial_params, final_params, get_updates = target(dtype) - - @jax.jit - def step(params, state): - value, updates = get_updates(params) - if ( - opt_name in ['momo', 'momo_adam'] - or wrapper_name == 'reduce_on_plateau' - ): - update_kwargs = {'value': value} - else: - update_kwargs = {} - updates, state = opt.update(updates, state, params, **update_kwargs) - params = update.apply_updates(params, updates) - return params, state - - params = initial_params - state = opt.init(params) - with self.subTest('Test that tree_map_params works'): - # A no-op change, to verify that tree map works. - state = _state_utils.tree_map_params(opt, lambda v: v, state) - - with self.subTest('Test that optimization works'): - - def f(params_state, _): - return step(*params_state), None - - (params, state), _ = jax.lax.scan(f, (params, state), length=30_000) - - if ( - opt_name in ['schedule_free_sgd', 'schedule_free_adamw'] - or wrapper_name == 'schedule_free' - ): - params = contrib.schedule_free_eval_params(state, params) - chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) - - @chex.all_variants - @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST) - def test_optimizers_can_be_wrapped_in_inject_hyperparams( - self, opt_name, opt_kwargs, wrapper_name=None, wrapper_kwargs=None - ): - """Checks that optimizers can be wrapped in inject_hyperparams.""" - # See also https://github.com/deepmind/optax/issues/412. - # When debugging this, make sure that options like weight decay or not - # are checked by asserting wehter such a value is None or not (see e.g. the - # logic in schedule_free_adamw). Some hyperparameters may not be supported - # by inject_hyperparams (e.g. warmup_steps). In that case (if you're sure - # you can ignore such hyperparameter), add the exception below. - if wrapper_name == 'reduce_on_plateau': - # TODO(vroulet): discuss adding support for reduce_on_plateau - # so removing all assertions in its definition - self.skipTest('reduce_on_plateau is not supported by inject_hyperparams.') - if wrapper_name is None: - factory = _get_opt_factory(opt_name) - hparams = opt_kwargs - else: - base_opt = _get_opt_factory(opt_name)(**opt_kwargs) - factory = getattr(contrib, wrapper_name) - factory = functools.partial(factory, base_opt) - hparams = wrapper_kwargs - opt = factory(**hparams) - - # Add here the hyperparameters that cannot be injected with - # inject_hyperparams. - static_args = [] - for uninjectable_hparam in ['warmup_steps', 'num_betas']: - if uninjectable_hparam in inspect.signature(factory).parameters.keys(): - static_args.append(uninjectable_hparam) - static_args = tuple(static_args) - opt_inject = _inject.inject_hyperparams(factory, static_args)(**hparams) - - params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))] - grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))] - - if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': - update_kwargs = {'value': jnp.array(1.0)} - else: - update_kwargs = {} - - state = self.variant(opt.init)(params) - updates, new_state = self.variant(opt.update)( - grads, state, params, **update_kwargs + @parameterized.product( + _ALL_OPTIMIZERS_UNDER_TEST, + target=(_setup_parabola, _setup_rosenbrock), + dtype=("float32",), ) - - state_inject = self.variant(opt_inject.init)(params) - updates_inject, new_state_inject = self.variant(opt_inject.update)( - grads, state_inject, params, **update_kwargs - ) - - with self.subTest('Equality of updates.'): - chex.assert_trees_all_close(updates_inject, updates, rtol=1e-5) - with self.subTest('Equality of new optimizer states.'): - chex.assert_trees_all_close( - new_state_inject.inner_state, new_state, rtol=1e-5, atol=1e-5 - ) - - # Not testing with `without_device=True` because without_device set the - # variables to the host which appears to convert then the dtype, so we - # lose control of the dtype and the test fails. - @chex.variants( - with_jit=True, without_jit=True, with_device=True, with_pmap=True - ) - @parameterized.product( - _MAIN_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') - ) - def test_preserve_dtype( - self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None - ): - """Test that the optimizers return updates of same dtype as params.""" - # When debugging this test, note that operations like - # x = 0.5**jnp.asarray(1, dtype=jnp.int32) - # (appearing in e.g. optax.tree_utils.tree_bias_correction) - # are promoted (strictly) to float32 when jitted - # see https://github.com/google/jax/issues/23337 - # This may end up letting updates have a dtype different from params. - # The solution is to fix the dtype of the result to the desired dtype - # (just as done in optax.tree_utils.tree_bias_correction). - # Otherwise, just make sure that all variables defined in the optimizer have - # the same dtype as the parameters. - dtype = jnp.dtype(dtype) - opt = _get_opt_factory(opt_name)(**opt_kwargs) - if wrapper_name is not None: - opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - fun = lambda x: jnp.sum(x**2) - - params = jnp.array([1.0, 2.0], dtype=dtype) - value, grads = jax.value_and_grad(fun)(params) - state = self.variant(opt.init)(params) - if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': - update_kwargs = {'value': value} - else: - update_kwargs = {} - updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) - self.assertEqual(updates.dtype, params.dtype) - - @chex.variants( - with_jit=True, without_jit=True, with_device=True, with_pmap=True - ) - @parameterized.product( - _MAIN_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') - ) - def test_gradient_accumulation( - self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None - ): - """Test that the optimizers can safely be used with optax.MultiSteps.""" - # Checks for issues like https://github.com/google-deepmind/optax/issues/377 - # Should pass as long as test_preserve_dtype passes. - dtype = jnp.dtype(dtype) - opt = _get_opt_factory(opt_name)(**opt_kwargs) - if wrapper_name is not None: - opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - opt = _accumulation.MultiSteps(opt, every_k_schedule=4) - - fun = lambda x: jnp.sum(x**2) - - params = jnp.array([1.0, 2.0], dtype=dtype) - value, grads = jax.value_and_grad(fun)(params) - state = self.variant(opt.init)(params) - if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': - update_kwargs = {'value': value} - else: - update_kwargs = {} - updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) - chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) - - -if __name__ == '__main__': - absltest.main() + def test_optimizers( + self, + opt_name, + opt_kwargs, + wrapper_name, + wrapper_kwargs, + target, + dtype, + ): + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + initial_params, final_params, get_updates = target(dtype) + + @jax.jit + def step(params, state): + value, updates = get_updates(params) + if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": + update_kwargs = {"value": value} + else: + update_kwargs = {} + updates, state = opt.update(updates, state, params, **update_kwargs) + params = update.apply_updates(params, updates) + return params, state + + params = initial_params + state = opt.init(params) + with self.subTest("Test that tree_map_params works"): + # A no-op change, to verify that tree map works. + state = _state_utils.tree_map_params(opt, lambda v: v, state) + + with self.subTest("Test that optimization works"): + + def f(params_state, _): + return step(*params_state), None + + (params, state), _ = jax.lax.scan(f, (params, state), length=30_000) + + if ( + opt_name in ["schedule_free_sgd", "schedule_free_adamw"] + or wrapper_name == "schedule_free" + ): + params = contrib.schedule_free_eval_params(state, params) + chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) + + @chex.all_variants + @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST) + def test_optimizers_can_be_wrapped_in_inject_hyperparams( + self, opt_name, opt_kwargs, wrapper_name=None, wrapper_kwargs=None + ): + """Checks that optimizers can be wrapped in inject_hyperparams.""" + # See also https://github.com/deepmind/optax/issues/412. + # When debugging this, make sure that options like weight decay or not + # are checked by asserting wehter such a value is None or not (see e.g. the + # logic in schedule_free_adamw). Some hyperparameters may not be supported + # by inject_hyperparams (e.g. warmup_steps). In that case (if you're sure + # you can ignore such hyperparameter), add the exception below. + if wrapper_name == "reduce_on_plateau": + # TODO(vroulet): discuss adding support for reduce_on_plateau + # so removing all assertions in its definition + self.skipTest("reduce_on_plateau is not supported by inject_hyperparams.") + if wrapper_name is None: + factory = _get_opt_factory(opt_name) + hparams = opt_kwargs + else: + base_opt = _get_opt_factory(opt_name)(**opt_kwargs) + factory = getattr(contrib, wrapper_name) + factory = functools.partial(factory, base_opt) + hparams = wrapper_kwargs + opt = factory(**hparams) + + # Add here the hyperparameters that cannot be injected with + # inject_hyperparams. + static_args = [] + for uninjectable_hparam in ["warmup_steps", "num_betas"]: + if uninjectable_hparam in inspect.signature(factory).parameters.keys(): + static_args.append(uninjectable_hparam) + static_args = tuple(static_args) + opt_inject = _inject.inject_hyperparams(factory, static_args)(**hparams) + + params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))] + grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))] + + if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": + update_kwargs = {"value": jnp.array(1.0)} + else: + update_kwargs = {} + + state = self.variant(opt.init)(params) + updates, new_state = self.variant(opt.update)( + grads, state, params, **update_kwargs + ) + + state_inject = self.variant(opt_inject.init)(params) + updates_inject, new_state_inject = self.variant(opt_inject.update)( + grads, state_inject, params, **update_kwargs + ) + + with self.subTest("Equality of updates."): + chex.assert_trees_all_close(updates_inject, updates, rtol=1e-5) + with self.subTest("Equality of new optimizer states."): + chex.assert_trees_all_close( + new_state_inject.inner_state, new_state, rtol=1e-5, atol=1e-5 + ) + + # Not testing with `without_device=True` because without_device set the + # variables to the host which appears to convert then the dtype, so we + # lose control of the dtype and the test fails. + @chex.variants(with_jit=True, without_jit=True, with_device=True, with_pmap=True) + @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST, dtype=("bfloat16", "float32")) + def test_preserve_dtype( + self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None + ): + """Test that the optimizers return updates of same dtype as params.""" + # When debugging this test, note that operations like + # x = 0.5**jnp.asarray(1, dtype=jnp.int32) + # (appearing in e.g. optax.tree_utils.tree_bias_correction) + # are promoted (strictly) to float32 when jitted + # see https://github.com/google/jax/issues/23337 + # This may end up letting updates have a dtype different from params. + # The solution is to fix the dtype of the result to the desired dtype + # (just as done in optax.tree_utils.tree_bias_correction). + # Otherwise, just make sure that all variables defined in the optimizer have + # the same dtype as the parameters. + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + value, grads = jax.value_and_grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": + update_kwargs = {"value": value} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + self.assertEqual(updates.dtype, params.dtype) + + @chex.variants(with_jit=True, without_jit=True, with_device=True, with_pmap=True) + @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST, dtype=("bfloat16", "float32")) + def test_gradient_accumulation( + self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None + ): + """Test that the optimizers can safely be used with optax.MultiSteps.""" + # Checks for issues like https://github.com/google-deepmind/optax/issues/377 + # Should pass as long as test_preserve_dtype passes. + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + opt = _accumulation.MultiSteps(opt, every_k_schedule=4) + + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + value, grads = jax.value_and_grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": + update_kwargs = {"value": value} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) + + +if __name__ == "__main__": + absltest.main() From c47261c528329625200b12a68cc9af512563b780 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Sun, 13 Oct 2024 15:28:31 -0400 Subject: [PATCH 02/32] fixed imports --- examples/contrib/rosenbrock_ademamix.ipynb | 12 +++++------- optax/contrib/_ademamix.py | 3 ++- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb index 756a38eee..1e50abaf9 100644 --- a/examples/contrib/rosenbrock_ademamix.ipynb +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -1,14 +1,12 @@ { "cells": [ { - "cell_type": "code", - "execution_count": null, - "id": "2dae2f1f-939a-4bd6-86ee-ec2dec9e6192", + "cell_type": "markdown", + "id": "b1d10c78-11a1-4998-8d50-b48a0e8fb3ae", "metadata": {}, - "outputs": [], "source": [ "# Recreate AdeMAMix Rosenbrock Plot from Paper\n", - "This notebook attempts to recreate the Figures 2(b) and 2(c) from " + "This notebook attempts to recreate the Figures 2(b) and 2(c) from the [AdeMAMix paper](https://arxiv.org/pdf/2409.03137)" ] }, { @@ -93,11 +91,11 @@ "source": [ "all_ademamix_params = []\n", "for b3 in [0.999,0.9999]:\n", - " solver = optax.ademamix(\n", + " solver = optax.contrib.ademamix(\n", " learning_rate=0.003,\n", " b1=.99,\n", " b2=0.999,\n", - " b3=0.9999,\n", + " b3=b3,\n", " )\n", " params = jnp.array([-3.,5.])\n", " print(\"Objective function: \", rosenbrock(params))\n", diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 3e461fb20..138094a8d 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -1,3 +1,4 @@ +import optax.tree_utils as otu from typing import Any, Callable, NamedTuple, Optional, Union import chex import jax.numpy as jnp @@ -165,7 +166,7 @@ def ademamix( distribution shift, or general cases in which the distribution is non-stationary. """ return combine.chain( - transform.scale_by_ademamix( + scale_by_ademamix( b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps ), transform.add_decayed_weights(weight_decay), From bf2d4a88663fd41c0f95b3fb90e9d3d859c7f56a Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 14 Oct 2024 10:44:34 -0400 Subject: [PATCH 03/32] fixed linting errors --- examples/contrib/rosenbrock_ademamix.ipynb | 192 +++++--- optax/contrib/_ademamix.py | 134 +++--- optax/contrib/_common_test.py | 498 +++++++++++---------- 3 files changed, 452 insertions(+), 372 deletions(-) diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb index 1e50abaf9..f89435321 100644 --- a/examples/contrib/rosenbrock_ademamix.ipynb +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -9,6 +9,14 @@ "This notebook attempts to recreate the Figures 2(b) and 2(c) from the [AdeMAMix paper](https://arxiv.org/pdf/2409.03137)" ] }, + { + "cell_type": "markdown", + "id": "c53b3ca1-0372-4671-90a9-e65446695a85", + "metadata": {}, + "source": [ + "## Imports" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -19,7 +27,21 @@ "import matplotlib.pyplot as plt\n", "import optax\n", "import jax\n", - "import jax.numpy as jnp" + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "\n", + "plt.rc('figure', figsize=(20, 10))\n", + "plt.rc('font', size=14)" + ] + }, + { + "cell_type": "markdown", + "id": "ec581f6c-c3e5-4924-bf78-17f57c60cbcd", + "metadata": {}, + "source": [ + "## Functions" ] }, { @@ -29,10 +51,6 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from mpl_toolkits.mplot3d import Axes3D\n", - "\n", "def rosenbrock(x):\n", " return jnp.square(1 - x[0]) + 100. * jnp.square(x[1] - jnp.square(x[0]))\n", "\n", @@ -46,9 +64,17 @@ "Z = rosenbrock([X, Y])" ] }, + { + "cell_type": "markdown", + "id": "152e443e-5697-4eea-97f5-269cd12a2cfd", + "metadata": {}, + "source": [ + "## Generate Adam Trajectories (Baseline)" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "92b6987c-8ba1-43bc-8083-4c2b6324cb28", "metadata": {}, "outputs": [ @@ -56,6 +82,9 @@ "name": "stdout", "output_type": "stream", "text": [ + "Objective function: 1616.0\n", + "Objective function: 1616.0\n", + "Objective function: 1616.0\n", "Objective function: 1616.0\n" ] } @@ -79,15 +108,33 @@ " all_params.append(params)\n", " # if i%1000 == 0:\n", " # print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", - " all_b1_params.append(all_params)" + " all_b1_params.append(all_params)\n", + "all_b1_params_array = jnp.array(all_b1_params)" + ] + }, + { + "cell_type": "markdown", + "id": "75bcfd99-6db6-4f54-ba0c-240863c6162d", + "metadata": {}, + "source": [ + "## Generate AdeMAMix Trajectories" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "11a4561a-1d92-44ce-bab0-5af22f028167", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Objective function: 1616.0\n", + "Objective function: 1616.0\n" + ] + } + ], "source": [ "all_ademamix_params = []\n", "for b3 in [0.999,0.9999]:\n", @@ -108,68 +155,36 @@ " all_params.append(params)\n", " # if i%1000 == 0:\n", " # print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", - " all_ademamix_params.append(all_params)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d19aeff7-c908-4c7c-a428-a7846fb3e62a", - "metadata": {}, - "outputs": [], - "source": [ - "all_b1_params_array = jnp.array(all_b1_params)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60ca6921-fa7c-413f-972f-6b1668c72159", - "metadata": {}, - "outputs": [], - "source": [ - "all_b1_params_array.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78364e82-f634-4d08-b395-18d34bc64329", - "metadata": {}, - "outputs": [], - "source": [ + " all_ademamix_params.append(all_params)\n", "all_ademamix_params_array = jnp.array(all_ademamix_params)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "1fdfb913-96c7-4b71-9dcb-177317006049", - "metadata": {}, - "outputs": [], - "source": [ - "all_ademamix_params_array.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b0a0467-02e7-4fd3-84e2-9b70df6ed665", + "cell_type": "markdown", + "id": "6c55be58-0157-4909-bf01-c8aac0af7044", "metadata": {}, - "outputs": [], "source": [ - "plt.rc('figure', figsize=(20, 10))\n", - "plt.rc('font', size=14)" + "## Plot the Figure" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "69d8642f-dfcc-4fac-8f85-3ee1fbfa135f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABmcAAANlCAYAAACJ1C0sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3gUVRfG301200khgUAgJIHQlE6oAQIiQWpo0qUkKEpXUVGRrqBSFSxACGAQEOSjiaJUkSZVUekk9BpIAVI22fv9MTuzM7Mz21KB83uefXZ35k7Z2d2588459z0axhgDQRAEQRAEQRAEQRAEQRAEQRAEUSQ4FfcOEARBEARBEARBEARBEARBEARBPEtQcIYgCIIgCIIgCIIgCIIgCIIgCKIIoeAMQRAEQRAEQRAEQRAEQRAEQRBEEULBGYIgCIIgCIIgCIIgCIIgCIIgiCKEgjMEQRAEQRAEQRAEQRAEQRAEQRBFCAVnCIIgCIIgCIIgCIIgCIIgCIIgihAKzhAEQRAEQRAEQRAEQRAEQRAEQRQhFJwhCIIgCIIgCIIgCIIgCIIgCIIoQig4QxAEQRAEQRAEQRAEQRAEQRAEUYRQcIYgCIHQ0FBoNBrs2bOnuHcl3wwZMgQajQZTpkwp7l15amjdujU0Gg2WL19e3LvyxKPRaKDRaJCcnFzcu/LUkJycLBzXZx06FgRBEMSzBGkYwhKkYQoO0jD5Y/ny5dBoNGjdunVx70qRM2XKFGg0GgwZMqS4d6XYoWNByKHgDPHMwV+ciR/Ozs7w9fVF48aNMXXqVDx48KC4d5MogfBiydrj5MmTxb2rdjNlyhRMmTIFqampxb0rJZLBgwcL329iYmJx706xwwsLRx5POydPnsSUKVPoBgBBEARRoJCGIRyFNMyzC2mYZxOl/sKWx7MQLFi+fDmmTJnyRJ7viKcXbXHvAEEUF8HBwahUqRIAQK/X48qVKzhy5AiOHDmCxYsXY+/evQgPDy/mvSRKImXLlkXVqlVV53t5eRXh3hQMU6dOBcCJN19fX8U2lSpVQvXq1eHj41OEe1b8ZGRkYP369cL7ZcuWYeDAgcW4R8VPYGAgIiMjzabfuXMH58+fBwDF+YWJTqdD9erVi3SbSpw8eRJTp05FVFRUsQmcknIsCIIgiIKHNAzhKKRhSMM86xqmuPHx8UH16tWFc3hhUbt2beTm5ppNP3XqFNLT01XPBdWqVSu0fQoICED16tVRvnz5QtuGLSxfvhx79+5FaGgo6tWrVyz7UFKOBVFyoOAM8cwSGxtrNlz8t99+Q58+fXDjxg0MHz4cO3fuLJ6dI0o0HTp0eCaz4leuXFncu1AsrFmzBo8fP4avry9SU1OxZ88eJCUlISwsrLh3rdjo0KEDOnToYDZ9+fLlGDp0KADgjz/+KNJ9qlChAs6cOVOk2yyp0LEgCIJ4eiENQzgKaZhnC9IwJY/u3buje/fuhb6dL7/8UnF669atsXfv3mI5F4waNQqjRo0q0m2WVOhYEHLI1owgRLRr1w4zZswAAOzevRt37twp5j0iCKK4WbZsGQDuIqphw4ZgjCEhIaGY94ogCIIgCIKDNAxBEHJIwxAEQTwZUHCGIGQ0b94cAMAYQ1JSktl8xhhWr16Ndu3awd/fHy4uLqhYsSL69++P48ePq653165d6N69O4KCgqDT6eDj44MqVaqge/fuwoWTnMzMTMyfPx+RkZHw8/ODq6srwsLCMHz4cMV9A6QFD+/fv49x48YhNDQUrq6uqFChAl599VXcunXL6nH4559/0Lt3b5QrVw5ubm6oUaMGpk+fjqysLMX24uKAf/75J3r16oVy5crB2dlZkt2n1+vx9ddfo0WLFvDz84ObmxsqV66M1157DRcuXLC4TxcuXMDo0aNRs2ZNeHl5oVSpUqhRowbi4uLw+++/W/1MPOnp6XjxxReh0WjQsGFD3L592+Zl7cFa8UlLRbvz+z0yxrBx40bExMQgKCgIrq6uCAwMRNOmTTF9+nTcvHkTgKkYHU9YWJjEd1b83Vn7PKdPn0ZsbKywn35+foiKisLSpUuRl5enuIz4d3Pq1Cn06dMHgYGBcHV1RfXq1TFt2jTk5OSofs7C5vTp0zh06BAAYNCgQRg8eDAAYMWKFTAYDBaX/fHHH9GiRQt4eXnB19cXrVq1wqZNmywuc+/ePSxZsgTdunVDtWrV4OnpCU9PT9SqVQvvvvuu6s0WcXFJxhgWLlyIevXqwdPTE+XKlcMrr7yCq1evCu137tyJ9u3bw9/fH56enmjZsiX27t1rz6Gxmz179kCj0SA0NBQAsHr1akRFRaF06dKSIr4PHz5EYmIi+vXrh5o1a8LHxwfu7u6oWrUqRowYoXrus/R/4jl+/DgGDx6M0NBQuLm5Cd/L8uXLLX6fqampmDFjBpo0aSKct8LCwhATEyPJxgwNDRVGDu3du9fMx1lePDW//xm1c21hHYvjx49jwIABCAkJgaurK7y8vBAaGoqXXnoJc+bMAWNMdXsEQRBE4UIahoM0TP4hDUMaxl4Nw2Pv9aVcH3z//fdo1qwZvL29ERAQgG7duuG///4T2h87dgw9evRAYGAg3N3d0bBhQ2zYsEFxXxzVFIW5T2LNJubdd98VtqlUP+nWrVsIDAyERqPBZ599prju/MLXppoyZQrS0tLw3nvvoXr16nB3dxeOBQCcO3cOs2bNQps2bRASEiJ8z82bN8cXX3yh+rvn/7Nqts95eXlISEhA27ZtERAQABcXF1SoUAEDBgzAX3/9ZXHfT5w4gdjYWFSpUgXu7u7w9fVF7dq1MWbMGKG2DP+98pp36NChknOG/DtxpM8Uf78GgwFfffUVGjduDB8fH4kWLIxjYTAYEB8fL+hrnU6HgIAAPPfcc4iNjcXu3bstHkOimGEE8YwRFRXFALDJkycrzt+/fz8DwACwU6dOSebp9XrWq1cvYX7FihVZREQE8/HxYQCYs7Mz+/bbb83WuWTJEmEZX19fVrduXVanTh3m5+fHALAKFSqYLXP58mX23HPPMQDMycmJVapUidWtW5d5eHgwAKxUqVJs9+7dqp9v+vTpLDg4mGm1Wla7dm1WrVo15uTkxACwypUrs7S0NLNlQ0JCGAA2a9Ys5uHhwVxdXVmDBg1YeHi4sP/NmjVjDx8+NFuWnz979mym1WqZl5cXa9iwIatWrRqbMmUKY4yx9PR01rJlS6Ft5cqVWcOGDYXP5O7uzjZv3qz4vSQkJDAXFxcGQPhMdevWZd7e3gwAi4qKkrQfPHiw4vd848YNVq9ePQaAtWvXjmVkZChuTwl+nYMHD7apPf9dJCQkKM5PSkoSjoXaso58j48fP2YxMTHCuv39/VlERASrUqUK0+l0kn2Kj49nkZGRQtuIiAgWGRkpPOLj4236PGvXrhW+H09PT9awYUMWFhYmrLddu3bs8ePHZsvx87/99lvm5uYm/G6CgoKEeT179rTpeBcG48ePZwBY8+bNGWOM3bt3T/ic27dvV13uo48+Eva/TJkyLCIigvn7+zMAbP78+cK8pKQkyXJffvklA8BcXFxYpUqVWEREBKtatarwvQUFBbFLly6ZbS8hIUH4H/Tr148BYOHh4ax27drCsqGhoezevXts0aJFTKPRsMDAQNagQQPm5eUlbPOPP/7I1/Hi90PpN717924GgIWEhLBx48YxACwwMJA1atSIBQUFCeezLVu2CP/zChUqsIYNG7IaNWowd3d3BoD5+Piww4cPm63f0v+JMcY+++wzptFohPNn3bp1WcWKFYVlunXrxnJzc82WO3r0KCtfvrzQLjw8nEVERLCyZcuaba9Xr16satWqDADz9vaW/JciIyPZzZs3hbb5/c9YOtcWxrH4+eefhd+Sl5cXq1WrFqtfvz4rU6aMsJxer1fcHkEQBJF/SMOQhiENQxrGVopawzDm2PWlWB+8//77DACrVKkSq1evHnNzc2MAWOnSpdm5c+fYxo0bmaurK/Pz82MNGzYUzkMajYatXbvWbH8c1RSFuU9izSYmJyeHNW7cWPF3k5eXx1588UUGgEVHRzODwaD6/VmD/y8onQv488SoUaNYeHg402g0rGbNmqxBgwasZs2aQruePXsKeiA8PJw1atRIOAfzny07O9ts/ZMnT1bd9v379yXn2KCgIFa/fn1WqlQpBoDpdDq2evVqxc80Y8YM4Xfn5ubG6tWrx55//nnm6ekp2d7x48dZZGSkcO6tWrWq5JwxatQoYZ2O9pn899uqVSvWo0cPBoAFBwezRo0asYCAAOF/UxjHYuDAgZJlIiIiWLVq1YTjMGDAAMXjR5QMKDhDPHNYEzb8hYy3t7fZhdiUKVMYAObh4cE2bNggTM/KymJvvfWWcKI+dOiQMC83N1e4oPniiy/Mbl6dPn2aLViwQDItOzub1a1blwFgMTExLDk5WbKt9957jwFgAQEBLCUlRfHz6XQ6Fh0dzW7cuCHMO378OAsMDGQA2KRJk8w+O9+p6nQ61qlTJ8m69+3bxwICAhgANmLECLNl+Y7A2dmZvf322ywzM1OYxx/HIUOGCBd6+/btE+anpaUJN5S9vLzMLvZ27NghXMyPHDmS3b9/XzL/0KFDbNGiRZJpSsLmzJkzLDQ0VOiccnJyzD6HJYpD2DjyPfL76evry9atW8fy8vKEeY8fP2YrV66UHH/GTN+f0oW2tc9z+vRp4UJ12LBhEuH722+/CRcxI0eONFsnv12dTsfeffddye9m1apVwoXWrl27VPersNDr9cJxFl98de/enQFgffr0UVzut99+Ez7Xp59+Khx/vV7PJk+eLIhLpeN9+PBhtm3bNpaVlSWZfvfuXfbqq68yAKx9+/Zm2+QvBHU6HStXrhw7cOCAMO/ixYvC775r167M3d2dLVmyRLiwf/jwIYuOjmYAWGRkpEPHSr4fSr9pXug4OzszV1dXlpiYKOyDwWAQPvOZM2fY+vXrzW46pKenC4KxZs2aZsLE0v9pzZo1wn9ixYoVkv/En3/+Kdy8mTZtmmS5W7duCb+B1q1bs3PnzknmJycns48++kjxGMhFl5iC+M9YOtcWxrHg+6R3333XrG+8fPmy5LdOEARBFDykYUjDkIYhDWMLxaFhHL2+5PUBHxgVBznv3LnDGjRowACwtm3bMl9fXzZt2jThXKTX64XfTHBwsNl1qKOaojD3yZJOuHjxohA4+Oqrr4TpH3/8MQO4pLZbt26ZLWcPtgRnnJ2dWd26dSW6R9ynbNy4kR0+fNjsuJ0+fZo1bdqUAWAzZ840W7+lgESHDh0YANaiRQtJckFeXh6bN28ec3JyYm5ubuzs2bOS5fjj6eTkxKZNm8YePXokzDMYDOy3335jK1euVDwGauc3xhzrM8X74+zszPz8/CSBUL1eL/xOCvpYnDhxQuj/5ckPBoOB7d27VzFYSJQcKDhDPHMoCRu9Xs8uXLjAPvzwQ+bs7MwALttHzMOHD4XO8rPPPlNcNx/h7ty5szDt5s2bwoWKrfBZahEREaoX3126dBEunJQ+X0BAAHvw4IHZcnPmzGEAWP369c3m8cKmdOnSitlYq1atEi5C5RcG/IVadHS04v4mJSUJ4uSHH34wm6/X64UsJXHWAmNMuAAaNGiQ4rqVkAubgwcPCgJz/PjxDmWc8Ou09Jg3b57QviCEjb3f419//SWsc+fOnTZ/tvwIm9jYWAaA1apVS/G4Ll26VPjdiAWaeLtt27ZV3Cb/O3/zzTdt/iwFxcaNGxnAZeCIv4NNmzYxAMzV1dVMZDPG2AsvvMAALjtMiXbt2tl0vJWoUKEC02g0khEYjEmDIkr/r0WLFgnzlW5M/PPPP8J8pd+brdgSnAHAPv74Y4e3wWdJyi+G1f5Per1eOLeJL67FHD16lGk0Gubr6yvJ9OIvvqtXr66YNamELcGZgvjPqJ1rGSucY+Hq6soAsNTUVNXtEgRBEIUHaRjSMKRhzCENY05Ra5j8XF+K9YHS+emnn34S5nfs2NFs/r1794Rr1JMnT6odEkXUNEVh7pM1nbB69Wrhuzt16hQ7cOAA02q1TKPRWBzxZCu2BGdcXFzs1qg858+fZwBYjRo1zOapBST4oGClSpVUdejo0aMZAPbGG28I07KzswV3A6VgrxrWzm+O9pmMSbXwqlWrVPehoI8F/7tR++8SJR+qOUM8s0ydOlXwl9TpdAgPD8fHH38MPz8/fPbZZ5g4caKk/b59+5Ceng43Nze88cYbiuscP348AGDHjh2C12bZsmXh7u6OtLQ0bNu2zaZ9W7t2LQAgLi4OOp1OsU3Pnj0BcLUjlOjfvz98fX3Npjdr1gwALHojx8XFwcvLy2x6nz59UK5cOej1evz666+qyyrxyy+/wGAwoFKlSsK+i9FqtRg3bhwA4KeffhKmJycnC56eH374oeo+W2Lr1q1o27Yt7t+/j3nz5uHzzz+3WIvBGmXLlkVkZKTio0KFCg6vVwl7v0fe37ZZs2Z44YUXCnRf1OB/1+PGjVM8roMGDULZsmWh1+vx22+/Ka5j5MiRitNt+b0WFnzBzJiYGMl30KFDB5QpUwbZ2dn4/vvvJcs8evRI8LEdM2aM4nr537kaWVlZ+P777zF8+HC89NJLaNmyJVq0aIEWLVogIyMDjDHBO1eOn58fXn75ZbPpDRs2FF6/9tprZvOff/55uLm5AQAuXrxocf8KgmHDhlmcn5eXh02bNmH06NHo1KkTWrVqJRyD8+fPA+C8hW3h8OHDuHz5MsqVK4fu3bsrtmnYsCFCQkKQmpqKY8eOCdN//PFHAMBbb70Fd3d3m7ZnCwXxn1E711oiP8ciJCQEAMx+8wRBEETRQhqGNIyjkIaRQhrGRH40TH6uL8UoaRRrGsbf3x9hYWEAlI91fjVFYeyTJfr27Yu4uDhkZWWhd+/e6NevH3Jzc/HOO+8gOjrarnU5Stu2bSU1ZpS4c+cOvvjiCwwcOBDt2rUT9CpfQ+Xs2bPIzMy0aXt8v9GvXz/Fcwag3G8cOHAAN2/ehKurK95++22btmULjvaZYkqVKoXevXvbvW1HjwWv0w4dOoRLly7ZvV2i+NEW9w4QRHERHByMSpUqAeCKK164cAGZmZnw9fVFmzZtzNqfPXsWAFfwWemiHwBq164NgLu5mpycjGrVqsHJyQnjx4/H9OnT0alTJ9SuXRtt27ZFs2bN0KpVK5QrV85sPXyRr6+++gqJiYmK2+ILxYmLfIupVq2a4vTAwEAAQEZGhuJ8AKhVq5bidGdnZ9SoUQO3bt3C6dOnFds8//zzitP54/fcc8/ByUk5Lswfv6SkJOTk5MDFxQWnTp0CwF3kqH0mS2zevBkzZsyAs7MzVq9ejT59+ti9DjkdOnRQLShZ0Nj7PfLHiy8KW9ikpaUJRT3Vfjc6nQ41atTAnTt3cObMGcU2+fm9FgZ37twRBDZfQJNHp9NhwIABmD9/PpYtWyYRZRcuXBAKh6r9F9SmA1zxzk6dOqkWqORJSUlRnF6lShXF6WXLlhVeh4eHq7a5cuUKHj58aHHb+SUgIECyP3Ju3ryJTp06WQ2+qB0DOfz5NDMzEy1atLC6vqtXr6JZs2bIyMjA5cuXARTs/6mg/jOWfkdqOHosAOC9995DXFwcRowYgTlz5qBdu3Zo1qwZoqKiBEFAEARBFD6kYUjDOAppGBOkYQpOw+Tn+pInICAAPj4+ZsvYqmHOnDljpmHyqykKY59s4YsvvsCBAweEc1Xjxo0xY8YMu9fjKNY0xvr16zF06FCLn40xhvv379sU9OV/Pxs2bMAff/yh2CYrKwuAtN/gzxm1atWCt7e31e3YiqN9ppjq1atDq7X/drujx6Jp06aIiorC3r17Ua1aNbRs2RKtWrVCs2bN0KJFC9XPQZQcKDhDPLPExsZiypQpwvvU1FS8/fbbWLZsGaKjo3Hy5ElB+ACmCyslIcJTvnx5s/YAl+EWEhKCL7/8En/99RdOnTqF+fPnQ6PRoG3btpg9ezbq1q0rtH/w4AEAU4djicePHytO9/T0VJyuJirE8BeTluapXWiqbdeR4+fv74/09HQAUM0csMalS5eQl5cHX19f1KxZ06F1FCf2fo/5PV72Iv4d2PLd2vu74T8nY8zmfRo9erTiRfiHH36IDh062LSOlStXIjc3F+XKlVPMUho8eDDmz5+P48eP4++//0adOnUAmD6fk5MTypQpo7hutf+XwWBAz549kZSUhPr162Pq1Klo2LAhAgIC4OLiAgBo1aoV9u3bB71er7gOteMozga01saeY+0IatvnGTp0KE6cOIHKlSvj448/RvPmzREYGAhXV1cAXBbjd999p3oM5PDn07S0NOzfv99qe/6cyv+XgIL9PxX2f8YSjh4LgOsz/fz8MHv2bBw6dAjffPMNvvnmGwBAkyZNMGvWLLRu3drufSIIgiDsgzSMOqRhSg6kYZ4dDZOf60uewtAw+dUUxaWrPDw8EBkZKQRnhg4dqjoSsTCwpDGSk5MxcOBAZGdno3fv3hgzZgxq1KgBHx8faLVaGAwGODs7A4DdWu38+fPCaCY1xKNxCuuckZ8+k8cRnQY4fiw0Gg22bt2KWbNmYcWKFdizZw/27NkDAHB3d0e/fv3w6aefIiAgwKH9IgofsjUjCCO+vr5YsmQJmjdvjgcPHmDEiBGS+aVKlQIAIcNGiZs3b5q1B7iTZVxcHE6ePIk7d+7gf//7H8aNG4dy5cphx44deOGFF3D9+nWhPR/Z3rVrFxhXG0r1kZycXBAfX8Lt27etzhN/Pltw9PjxWRB8lp29jBkzBoMHD0ZKSgpeeOEFm62QCgJrF2WPHj0q8G3m93jZi/h3YMt3a+/vxhFOnTqF/fv3mz0s/a7l8HYAt27dglarFexD+Ef9+vWFtvHx8cJr/vMZDAbcvXtXcd1q+/Hnn3/i9OnTcHd3x6+//oouXbogKChICMwAto8WeVK5desWtm/fDoDLGO3bty8qVaokiCjA/mPAn09btWpl9XzKGBOG44szsAry/1Sc/xlHjwVP9+7dsX//fty/fx/btm3DhAkTUKVKFRw+fBjt27cXsr0IgiCIooM0jAnSMAUDaRgTpGFMqO1Hfq8vC4PC0BRFxaZNm7B06VIhwPf+++/jypUrxbxXHGvWrEF2djYaN26M1atXIzIyEv7+/sIoEUeOKf/7WbZsmU2/H57COmfkp8/ML44eC37ZGTNm4OrVqzh//jyWL1+OgQMHQqPRYNmyZYiJiRFGxxElDwrOEIQIJycnzJ8/HwDnGcxHmwGgRo0aALhsAbUhnHyWmJubm6pPZ5kyZdCtWzfMmzcPZ8+eRVhYGO7fv481a9YIbfhhkn///Xc+P5Fj/Pvvv4rT8/LyhGGe9mZw8cfvv//+g8FgUGzDH7/KlSsLN6T5bJ6UlBScO3fOrm0C3He6bNkyDBs2DCkpKWjbti2OHDli93ocgc+YULuQdeTzWIM/XgcOHCjwdSvh4+MjZJX8888/im1yc3MFK4CiyPzbs2dPvoTAoUOH8N9//wHgMsTUHn5+fgCAVatWCV6z4eHhQrYQvw45av8v3sqsZs2ailktDx48KJTfTEmCPwalS5dWHFKfm5uLo0eP2rVO/nz677//qp57lChVqpRwHrfn/2TNC744/zOOHgs5Pj4+6NChA2bOnIkzZ86gadOmyMnJwdKlSwtqVwmCIAg7IA3DQRqmYCANw0EaRora/6ugri8LksLQFEXBtWvXEBsbCwCYN28eunfvjtTUVPTv379E3Fjnj2uLFi0UR8EdOnTI7nU62m/w54x//vnHLvtAa1qtoPpMRyioPjQ8PByDBw/Gd999h0OHDkGj0eDAgQOqdWuJ4oeCMwQho1GjRujcuTMAYPLkycL0Fi1awNvbG1lZWfj6668Vl50zZw4AoF27dpJsdzVKlSoldCrirDO+eNiiRYtUh/wXJkuXLlXMiPrhhx9w8+ZN6HQ6tGvXzq51vvTSS3BycsKVK1eEItticnNzsWDBAgBAp06dhOkhISGIiIgAAMycOdOubfI4OTlh8eLFGDFiBB48eIB27do5dOFgL1WrVgUAHDx4UHG+2u8oP/Ts2RMajQYHDx6UCHNreHh4AFC3mLAE/33Nnz9fMcPuu+++w507dxz63RQHy5YtA8Bd8N26dUv1ce7cOeh0OqSkpGDTpk0AODHbqlUrAMCXX36puH7+dy6H/w5u376teBznzZuH3NzcfH++kgx/DNLT0xV/iytXrsSdO3fsWmeLFi0QFBSElJQUSYagLfTq1QsAd+x5f19r2PJfKq7/TH6OhRparRZNmjQBIO3HCIIgiKKFNAxpmIKCNAwHaRgpahqmMK4v80thaIrCJi8vDwMGDMD9+/fRpUsXjBkzBkuXLkVwcDD2798vsbMsLvjjKh45wsMYw+zZs+1eJ99vrFy50q5RYs2bN0dQUBCys7Mxd+5cm5ezds4ojD7TVhw9FpaoXbu2UDuJtFrJhYIzBKEAL2h+//137Nq1CwB3wfLWW28BAKZMmYKNGzcK7XNycvDuu+/i999/h7OzMz788ENh3n///Ye4uDj88ccfZpkkv/32G3bu3AmAE1Q8r776KmrXro3z588jOjpaMXL+77//4qOPPsKWLVsK5kOLyMjIQP/+/QXPS4DLYho3bhwAIC4uzqIHpxIhISEYNGgQAGDUqFGSAmcZGRkYOnQoLl26BC8vL+E483z22WdwcnLC8uXLMXbsWLOhq3/++Se++uori9vXaDRYtGgRxo0bh7S0NERHR9vkiZsfunbtCgDYsmWLJKswKysLH3zwgV3Cw1Zq1aolZFf17NkT//vf/yRiIysrC4mJiWYF5vhihjt27LB7m+PHj4ebmxv++ecfDB8+XCKKd+3ahbfffhsAMHz4cLt/N0XN48ePsXbtWgCcv68lAgIC0KVLFwAmMQQAEyZMAMAV8pszZ47wv8/Ly8P06dOxe/duxfU1a9YMOp0O169fx6RJk4TsKIPBgEWLFuGTTz6Bm5tb/j5gCef5559HQEAAcnNzMWrUKElAZP369Rg9erTdx8DFxQWff/45AM7Le/78+RKPXgB4+PAhfvzxRwwbNkwy/Z133kFgYCDOnDmDTp064cKFC5L5ly9fltwAA0z/pX///Vd1OHxx/WccPRbp6el4+eWXsX37diHDkufYsWPCf0bcjxEEQRBFD2kY0jAFAWkY0jD2aJj8XGsXFoWhKQqb6dOn4/fff0dQUJBgT1e6dGmsWrUKzs7O+OSTTwrlv2cPUVFRAIB169bhp59+EqZnZGRg2LBh+PPPP+1eZ+fOnREdHY379++jTZs2Zv9xgKvB9dlnn0lG6et0OiHwPW3aNHzyySeS3x1jDDt37kRiYqJkXfw5Y/fu3YojvRztMwsCR49FYmIiJk2aJIz049Hr9fj888+RmpoKZ2dnia0hUcJgBPGMERUVxQCwyZMnW2zXqVMnBoC1bNlSmKbX61nPnj0ZAAaABQcHs0aNGjEfHx8GgDk5ObFvv/1Wsp4TJ04I7T08PFidOnVYo0aNWFBQkDA9JiaG5eXlSZa7cuUKq1+/vmRbTZo0YfXq1WO+vr7C9ISEBMXPJ5/Ok5SUJCwrJyQkhAFgs2bNYh4eHszNzY01bNiQVa1aVVimSZMmLD093WxZfn5SUpLqMU1PT2ctWrQQ2oaHh7OIiAjm4eHBADB3d3e2efNmxWWXLVvGdDodA8B0Oh2rU6cOq1u3rnDso6KiJO0HDx6s+j2/++67DADz8vJie/bsUd1fOfw6Bw8ebPMyPXr0ED5vhQoVWEREBCtVqhRzc3NjixcvVv0u8vM9Pn78mHXt2lWY7+/vzxo1asTCw8OFYyhf7+zZs4X2NWrUYK1atWJRUVGSdpb2ae3atczFxUU4rhEREaxy5crCOtu1a8ceP35stpy1301CQoLi91tYrFixggFgLi4u7O7du1bbb926VfjvX7t2TZj+/vvvC5+tbNmyrFGjRiwgIIABYPPnz1f93B999JEwr0yZMiwiIoKVKVOGAWDDhg1T/Q6sHSdLvxce/v+/e/duq59bDX4/lLaze/duBoCFhIRYXEd8fLywDh8fH9awYUNWoUIFBoC1b9+eDRw4UPG/be0zzp8/n2m1WgaAubm5sbp167ImTZqwKlWqMCcnJ9V9O3LkCCtXrpyw7qpVq7KIiAgWGBiouD2DwcBq164tnPMjIiJYVFQUi4qKYjdv3hTaFdZ/pjCOxYMHD4T1ubi4sOeee441btyYhYaGSvqGR48eqe4TQRAEkT9Iw5CGIQ1DGkaN4tYwjDl2rW2LPrB2rNW+X0c1RWHuk9rvYu/evczZ2Zk5OTmxXbt2ma1v8uTJwn/x3r17qvtlDX6/lM4Fls49PHl5eax169bC5w8LC2MNGzZkHh4ezMnJia1cuVL12PCfQWnbDx48YC+++KLZb69BgwaCFlbbt+nTpzONRiOci+vXr89q1arFPD09Fbd3+PBh4fdYoUIFFhkZyaKiotjYsWOFNo70mYzZ/r8v6GMxb948yXmzQYMGrH79+pI+d/bs2Rb3iSheKDhDPHPYKmyOHDkinMh27NghTDcYDGzVqlXshRdeYH5+fkyn07GgoCDWr18/dvToUbP1PHr0iMXHx7P+/fuzGjVqMD8/P6bVallAQAB78cUX2cqVK81EDU92djaLj49n0dHRrEyZMkyr1TJPT09Wo0YNFhsbyzZt2sQyMzMVP19+hM3u3bvZqVOnWK9evVjZsmWZi4sLq1atGpsyZYrixSljtgkbxhjLyclhixYtYs2bN2fe3t7MxcWFhYSEsGHDhrFz585ZXPbMmTNs+PDhrEqVKszNzY15e3uzmjVrsmHDhrF9+/ZJ2lq7uJg4caIgNsXfryUcETbZ2dls+vTprFq1aszFxYUFBASwHj16sL/++svid5Gf75Ex7nf6ww8/sA4dOrCyZcsynU7HAgMDWdOmTdmMGTMkN4kZ4y60Zs+ezerWrSsITfnxs7ZP//77LxsyZAirVKkSc3FxYT4+Pqxly5ZsyZIlLDc3V3GZkiZs+IvNnj172tQ+NzdXuEkxY8YMybwffviBNW/enHl4eDBvb2/WsmVLtnHjRsaY5c+9dOlSVq9ePebq6sq8vb1Z06ZN2dKlSxlj9l/o8zxJwRnGGPvf//7HmjVrxtzd3ZmnpyerW7cu+/zzz5ler1f9b9vyGU+fPs1GjhzJatasyTw9PZlWq2WBgYGsdevW7NNPP1U9B6WkpLDJkyez+vXrMy8vL+bm5sbCwsJYt27dWGJioln7q1evssGDB7Pg4GDhZoLS910Y/5nCOBa5ubls1apVLC4ujtWqVYv5+/szZ2dn5ufnx1q2bMm+/PJLlp2drbotgiAIIv+QhiENwxhpGNIwypQEDcOY/dfahRkIYcwxTVHUwZmUlBRWsWJFBoB98MEHiuvLzc1lLVu2ZABY165dVffLGvkNzjDGBVEnTJjAwsLCmE6nY2XKlGGdOnUSAsaOBGcY4/7L69atYzExMax8+fJMp9Mxd3d3VqVKFdavXz+2evVqlpaWprjsn3/+yQYOHCj8h/38/Fjt2rXZ2LFj2V9//WXWfvPmzax169bM19dXCNTI/6v29pmMFUxwxpFjceXKFTZ79mzWqVMnFhYWxjw9PZmLiwsLDg5mffr0Yb///rvF/SGKHw1jCuaaBEEQBEEQTxCXLl1ClSpV4Ozs/NTX5iEIgiAIgiAIgnhSmDRpEqZPn464uDiJJRdBEFRzhiAIgiCIpwDeX7506dLFvCcEQRAEQRAEQRAED2k1glCHgjMEQRAEQTzxrFu3DgDQoEGDYt4TgiAIgiAIgiAIAgDS0tLw66+/AiCtRhBKaIt7BwiCIAiCIBxl9OjR+Omnn5CUlAQAGDlyZDHvEUEQBEEQBEEQxLPNrVu3EBMTg7NnzyItLQ2VKlVC165di3u3CKLEQSNnCIIgCIJ4Yjl16hRu3LiB+vXrY/Xq1ejSpUtx7xJBEARBEARBEMQzTVZWFv788084OTmhR48e2LVrFzw8PIp7twiixKFhjLHi3gmCIAiCIAiCIAiCIAiCIAiCIIhnBRo5QxAEQRAEQRAEQRAEQRAEQRAEUYRQzRkrGAwG3LhxA6VKlYJGoynu3SEIgiAIgiCIQocxhoyMDAQFBcHJifK5CMuQZiIIgiAIgiCeNQpCM1Fwxgo3btxAcHBwce8GQRAEQRAEQRQ5V69eRcWKFYt7N4gSDmkmgiAIgiAI4lklP5qJgjNWKFWqFADuIHt7exfz3hQfP/v4CK87VDa+eNH43ApgkdzLv0tXAwD8i+dxAVUBAJcQhiRwC11DBdxLqgjc0ABXjMvfAnAHQIrokQbgEYDb/FYzAWQAeABAb3zOBZAF4LHxmW+Xa3zNPwOmn7oWgM74rAXgbnxoAZQC4Gac72d8784t5gbAB4Cn8dnf+IDxuSyAcsb3lYzPQQwBYdcAABVxHQAQhksAgMpIAgCE4zwA4Hn8CwCoc/8cAECzX7Trv4te7+Cefr4EMzqkpZlPJAiCIEok4n4VEPWtgKR/5ZH3swDX1wKQ9LcAkITKuIYKAMD1uYDlfheQ9r1pMHWruAOub80C1w/zrzNh6nP5Z/6hNy7L98N8Hyzuf/k+2A2Ah/G5lOjZHVznClMfHART/+tvmo1y4PpeY78r7nPF/a24rxX6Wb6PVehfn/V+NT09HcHBwcK1MEFYgjQThzXNBEjP5+LzuPwcbnb+tqSZFM/bGTCdu/lztfh8rYe5bhLfHuDP0/w528P47GacXkr22t10vgbMdROvnXjdJNJMAMzO34B1zQTA/HwOCOd0wFw3PevndoIgiCcJmzQTYNbHAtL7k4CyZgJg3ucC0n4XsKHvteWepVw3AdK+GDDdswS4TpV/z/e3vH7yg9m9S7X7luJ7lkbNBJj6XXGfK+5vze5P/g7STAoUhGai4IwV+GH53t7ez7TQ8BC99nY2vnARzTwBsCjAyzjTHTq4whUAoIM7nOEJAHBCKaCUN+CpAWoCSAZ3AnE1rk8HwBlcNSQncOedTMB0ctICSDc2yjA20hkXlt8MgmiaeHkYV6yTPTsB8ILp5hEAGL9zjWifnAGkGpsFGPfdzTitPGD8qEApBidv7s/Jf36dMdjDHxt34355gTtu3vyudwA0e42v+eMM47YB9KkKbDkPCc/y75MgCOJJYotGI+lXAVHf2l40UdSIGU/xfD/7N2rDHcA5VIcrgAuoAh2Aiwg3dqPGi8NS3sA1Ddc3uQO4CfN+9y6kfa/gSJQJ7mJfC4ABMBgffCONsQ0gvdHnLJoGmAdn+GdeYHgYnz3B9bvuAAIhJEhojLvhDFOXz/e95QGEGldv7HfFfa4rXFENZwHoJH2thj+2fB/rbHbIqV81QhZVhC2QZuLowxi2GI/F3ktAl6oAdoM7txtPMOLzOa8FXOEq6ARneMIJpVC2bhruXArmdBN/L4Y/d5cHdz7k5ZDkb6oDl+HG6yYncOdqjXEef2NIC+nNIX5ZHvn52gPSc7crTMF047k7y7gZX5h0k/i8DZjuL90Fd/4uxd0kcvIuhRuogWBctaqZLnnXRL37pwGIzucvAdhlfM13Q4B5f/sM/z4JgiCeNFQ1EyC9VybrY0+Wrgkv4yxxXwtA0t9eRTBXjL2UcUFPDXefUtzv3oPy/UqzvjcT3D1FPYAc43ve5orXTe4waSa9cYW2aCb+viX/4O9dloLQB8vvW8p1Ey+3SjGUrXwVAKebdHBHOC4CcIU7dKiDUwCchfuTGg9w/asLSDNZID+aiQykiQKFv0gGYLwZwlEFF6QNKzI71+yuMl0eX9QpTBOTK3uvlz1nyuZnmp5SwUXHUxVWe1thGsAJKjs4WbqmXe0JgiCIJ58uVS3PZ1GOrVe1D1Lqs1Jh6uPkXaEwQR6I0cvey/tYS+hEz3xWtgLu4G7yWUN2XWF23WFEfJ1CEARRpOwyn8TdAOHgboxYIRBcchiPL7gbLb7yhu6yZ0AaeFFDrzJdaVQkoNBhqHPP9qYXEQ6ASzwAuEQEgEtMsAlRooO8j91CAWeCIIgnApvP1y+oz1LrN/h+xmas9mHpotfivtFejaQGH6zhgzRQeBZRBtz1QgC4aweiREPBGcJu5CM27IWL0DoCH42Vj4KRv1ZCLjTEAkM+T02UqCA/SSebN7kKy0Eai0LDQkcjhoQGQRDEU4TCud/WAL5in5Mse2/zTTL5jTdrfaY1SzNxMEZ+o9CCwChMtnNP4uubLszeJBKCIIiCIxgKeqm8lYUsnjp1stdKCW62wJ/bxX2DlQDNXZj3OTdFr6/lT8M4msBAEARBPFlIgu3tzedb6g/4ID8f9FdE3B/dVJh/F9Jk7UyzF5BqIXFQpiACNGL4fls2akUxYQOK1xCK1xoyBEcfolCh4AxhE4o3Kbart3c4CywAXITXFyZ7MAmOZoGpoXaytDELzE6hoZYFpoRix2IhC4wgCIIo+RREIP3pyQADTBlgPI4HZuTJH+LrD/F1iYBCFjtBEERho3SjQ+w4oEioI1uS6yaloLg9CW7yYLy4X0iHmX4qDscBG5PaCIIgCILHLrcBQL1/M3MbAKy7DVjTVDa6DQDW3QZCzSeJ3QbE1yLkNlC0UHCGyD923NywOwvMF8WcBZau0NbIXRtXTRAEQRAqlLgMMAG1DDAx+Q3QiPtseWcve+8Lm4bn25IBRhAEURJQu/FRsHbQSo4D8nlqKNXxtPTaCG8HrURhOw6IUehTechxgCAI4inhiXcb4Mmn24Ct+W12X1OYQ24DBQ8FZ4gCw6HhbqGObCm/WWBiSzOgUOrOJJteFmUWGAkNgiAIAshHBphqvRkepXoz4mdbULshqDA83xcqI2nBJXeEWt6SUjY6Dc8nCKIwEd+oEG5gWHAcsISqHbTcccAMteK81upzAsp20PJ5DtadEVPYjgMiyHGAIAjiycLivS0LwXcxfDBfntCm6DZgKaHtWXcb4HHwWoawjrUrM4Kwm3r3TwsBhmo4K5wIq+CC9CRYkdnpMewO00lPB+6kp4W5gNApTFNC7ocvRg+7Tnb3IC3OqcBVBCMYV3ER4ZKMuHOojmo4i79R2/KJkKc9hJNil6r5rwFEEATBo9frkZeXV9y78dSyu0YNaENCJNOygkRvfIzPTUyTWA73fNq3CpDFvXZBaQCAF7wAAKXhAgB4CA1uoRwqIhfuecbGGg1wHYALAA9w9+seAPAzbsDJOD3TOC+L33I2gIcADACY8b0BpoQIfro7uD7TA1y/6mJ8djWuh8/4cjYu52x872Zs42pcBqL1ZwPQcE28jZsoBaC0cb+9jZtz4ZqhHAPyAP+sXARCg0q4DMAFXvASjhVDedRMvYgshEDDd/v88TZ+B9oc03HPysrC04azszN0uvxYwRIEkW92wSzpqg5OCTeQwnHR8mhI3nFAHmjng9iSOIm7cQL/rIP0JpGSjrJErmgdvBbj9RLfiajAOw5Y0UsEQTx7kP4glJBrpjahIpniI27IPbFGAIzX8iyL6ywtaSYAYMY+0T0vi9NM3EKczgC4bo3/ad43TteA0yO8fEEWuL40z7QDMMCkeVyM792MbbTGZ3mCm1b0LNZNvF7SghNu7uA0WJ5xuxpuH3jdVBYmzeQHwB+mz6UBkMfgn5WLcrgFQIPSRs3EHy9mvNDIynkMAJxuOgzumJNmKjQoOEM4xJbz+ctAKlv5qnpWr/ii/ZF8pjeUs7P4gIwlgaGHcrYunwWshXkWmA0BGnlg5iZMwumaRnXY4AVUsVqPh0VRhi9BEIVPeno67t27h+zs7OLelaeast98I3nvoQWS+Dfi7sZN9DqFe3J+wPVfOdAhHEAudAgFoIcWkcb3AJBr7DgNuclcl+YFoAq46/dQ47PB+IDoPRM9CzOY6FHKuDKm8IB4QdFrcQKGRvSs9HAybjMVXCKGxjSZf3Y2PvOvec1j1AROSbnQwgNaYyKIDqHQoil3/KBHssF4jcAHxDqA6+pbA49zOS0DAB4hIUhKEr6ZpwpXV1cEBATA29vCTVSCIEoEwbgqtVwJhaINmAQ+DiOZwCedZRin8QEWXjfxGkgtwY3XSfxrsZ7KFL3nk+hEHRqv5Xxlq7wNkz1lMoRRkHcuBauPFlLgZOma5vZwL8Am2+0tGg3ZsRBEMUL6g7CEXDclKbkhK2imHCcdnB9wr8ON/VOo8Vlv7MvEusmQqwWQzMmcXHC6KRSmoAyvm1Q1E4NUXLnAFJCxppnE723RTU6i14/BJbQZ+3ZnWNdMAJBl0kwAoEW4UVUCWjSFs/E6INmg5zRWEDjNBACtuaeyotumpJkKBgrOEDbThTHzoYXboTqk0K4sMMB0gS4eMlgkWWA8allggGqQ5i44OwErFKXQIAiCsIf09HRcv34dXl5eCAgIgE6ng4YsEguF9EfSjANvF9EbL9FrY9/HRHZeWVqucZZRheQYR6ZkG9O2cozPeqPYyMvRAXrj95gDU7eYB1OyVi5MOoIXHUz+RqxE5OJDPB0wFxs8coEhjrjIoy8aCDf6NJCKC16A8SJDB0776BicXbgPqIMeLsiBqzFzzQXZcDNGb9xyuWmaRzDdMHzIPaWLMsC8w8JUPseTC2MMer0eaWlpuH79OgBQgIYgihHNXnMrLrHjQP4ROw6IkeskSwEZXmOJAzNqr2VaiZdQvgqrdsBxgE9qI8cBgng6IP1BWCL9n3/gHyDtKATdpKCZAJNukmsmQFk3STQTwOkmXg/wA2H4AI1YN0k0E4wT+JlircQUnu3VTGLtxD8AkxgyCiW5ZhIbF8g0EwA4u+gFzQQArsiBC7ggqRuypJoJeKZ0U3FpJgrOEAWDwhB9NcyywABupIlSoWIeh7LAIJqenyww/r27+aAdX9l7lSwwHl5oqOGI0JBDWWAEQdjDvXv34OXlhYoVK5IoKmTkg77dxIeb745KmSYxo6bI1LoJJmF5cAYAGGSZX07GFhrokJvtwo1+12i4hCq+OxPfY5PrBUkGGK86+AUMCg9+QUAqNMT9j0bhWS4wxGldTsaddTI1F4sKPutLXGrOBYALg8bVCS7QA3CFMxi0xn3TIReuxvW5GfM4NHoAfJKmcdfE342bmzgN7+nB3d0dpUqVwrVr13Dv3j0KzhBECUBsBy0m/3bQgMlxQMkO2hFLM3lgpugcB2yBHAcI4smB9AdhCblm8nUVvRHfxTYGbJhYP2m5635eM2XDzTjQ3hXOxvdcahifDOYC5Gi4dTGYbg3yP8tccBqElz+8nBG6K72oAR+1UdJMSsEZ+agZawlt4mQ2PvriZB7HEY+W4XWTUTMBgMbVCU5wgrNx+1oYoDP26a5wkmom7iAKu5WabXJ0842IwNNIcWgmJ+tNCMJ2lC6IlYrxSgi1ZwtqF/3yOKO1YpdKhS0tvbYBq0XCTFCBS4IgSgJ6vR7Z2dnw8fEhYVTIpB49qj7TR32WmMdWbnzlKFp3yrC5JqVB9FrpZpmauJAj/10pZYSp4GxxB01lbVRnm3sgazIUGj5DaDQa+Pj4IDs7G3q9I6OMCYKwl/yO1lAdfR9gfJQBlzDmqdyMQ9x/5MdHXe6RL85cU7KeFnHX8mweVetrFZSCW7YmDVosOE0QRKFA+oMoKeRmu1hvBJhG0ojtzQAo6yU12zJHkY+iAUy38u2/pa91zZG8d4PJVtDDUj+eZvemnmiKWjNRcIYoFMwsuYxUwQXphHxkR5mEhTgIY20wmPxPpRSkkWeB2YA8MCMeBWR3lhtBEETRwBffpCLhRY+vlcCCOANMTrZxmH6WMERfYaRHjqjvsdT1KSIuyioeHSMfGWMJpaCMWkBGQWDYeoXqIt0fpYAMALjnPn3FKh2F/79T8V2CKDwUR7Hzo94VbIrFI+fFNSkVR9yXh2mkvhxfyHLZ5EF9ney1UoKbEpYS2MR6KR1m+ukRuHJicm6LXiebzzZzWpDB22dbRWTBTUltBFH8kP4gLGExoU2MglbK1OZj9Lu49JHNyWw8am4C1pLYHIHXUbIsNn7EDN+18w8FXOxJRH+GE9uKUjNRcIZwmOLLAhMPKSuqLDAlz2YjlAVGEMQTDmWtPYPwGWBm2JIBVhDIgzSOX5LKM8BsxpgBlioSY0/r8Hwx9H8niCeYUEcWcpc9i70hIZqmhoOOA5lQDsoAxeY4QBBEyYGuRwhbkCS02ek2oJbQVnhuA4D6yBlrrgOF5Dagg1W3ASWedbcBoGjPURScIezCYhaYAmpZYIoUahaYGrZmgYne80KDssAIgiAIGyjKDDDJ8HxbMsDMhueLZ4hhsH/0jBoaldeA2aWpPAPMBqwOzyexQRBECcCaHXT+HQeUktiUTqS2jJYByHGAIAiCKGlYchuwCWtuAxYHmSi5DQBPmtuA+DW5DRQPFJwhCg6FIfpqmA3RD3Vkg/ZkgYmnK4mJQq47oyA0KAuMIAji2eTJygBTqzdjqailGmo33SwMz5ej1NXDzuH5BEEQJQg1O+iCQa2IrT3JbHKewLoz7c1n85DjAEEQRMmg2BLaJDMUpjlcb6YgKDi3gfwidhsgChYKzhAFjrUsMDOKPAtMThFkgTkIWZsRBEE8exR6BphFCrLejHw4vvi1BaHBeyZbwsrwfKXaMzQ8nyCI4qJQ7KADoWwH7StvKE9os8fSTIwDdWcKyHGAT2rj4ZPabHYcEEGOAwRBEE8G+UloU0MxoS1fQQdrbgP5CdjY4TYAWHUbkFtBW3Ub4Ekzn/QsWEEXJRScIQoNtSwwsyH6DlEcWWAqdWdSYZ4FdluhHWzPAiNrM4IgiCeP5ORkaDQaDBkyRJhWIjPA9EDXdqHo2j6Ue19kGWBKSQP5uBS1MDxfDA3PJwiiOLBoB63gOKBmB23mOABwdtCWMLs3JQ/QAJxuUkpmU5pWvI4DPLzjgCXIcYAgiGcJJf3xrMAntA2NnQJPTW1cTr4uzFNzG7AJUZe3eNEUNKqjwbEjexQaWrMys7XejBwH3AbkiW0F4TZAiW1FBgVniHxRKFlgQAnNAhO9F2eBWSPZfJI8C4yszQiCIIqf2NhYaDQa+Pv7Izu7cMZtl7gMsDylieYZYFlZj5CYuAgTJw5Hr16RaNy4PBo1Ko8bN1T6cVWURs4o+CfbiDwDDAAunbuEuN6jUT2gEfzdI1CvQX98/c16MPGNUmMGmHh4vjwD7Nq1axg+fDgqVaoEFxcXBAUFYejQobh61d7PTBAEUYCEFsRK5DpJLcGtGOrOOAg5DhAE8SRSWPqjdevW0Gg0Fh979uwR2k+ZMkWYPn78eNX1vvfee0K7KVOmqLb7/fffhXbr1q1Tbbd8+XKh3eD33lNt9/V366EJaQRNSCMMGam+XbsopHoz33zzCUaM6IlOneqjRYtQvPjicxg0qD2+/34xsrIeK6yreNwGsrOzMW/aPDSu2hal3RqiStALeO31j3H30n0rK5ViMBiwcOFCNGjQAB4eHvD29karVq2wefNm1WUOHz6MmJgYBAQEwNXVFVWrVsWkSZOQmal8HfHgwQOMHz8e4eHhcHV1RZkyZdCrVy/8+++/du1rSYWCM4TdWMwCU0AtC0yR8uCG6KvhUBZYMdadKeYClyQ0CIIgbCMjIwM//PADNBoN7t+/j40bNxbLfliyNCuoDDDLKNebuX//HhYsmIzt2zcgJycL3t6+sjb5zQBTwBnS4fk25Fxc+e8UOjfujF827UB0h0i8MaY/8gx5GDn6U4yZMNvmDLCLFy+iYcOGWLx4MWrWrImxY8eicePGWLFiBSIiInDxopXrGYIgCBuwZgdt5jhQ4HbQJaDuDDkOEATxjFIU+uPtt9/G5MmTFR+hoaFm7bVaLRITE5Gbay4ecnNzsXLlSmi11vuO+Ph4AIBGo8GyZcusttc6O+OXffuQkppqPtMHiF+7GVqtciTio0/fxvHTmxBUoazV7QB21psxiJ5tdBtYty4ejx8/RNOmUejb91VER3fjAiHzpiA2tiuysvh+0h59VLBuAwaDAbExsZgzeQ78A/wwctxANG9aC/HLNqFZ+1jcvffAplUzxtC7d2+MHj0a6enpiIuLQ9++fXH27FnExMRg4cKFZsts2LABLVq0wPbt29G+fXuMGjUK/v7+mD59Otq1a2cWpExJSUGTJk0wZ84clC1bFqNGjUK7du2wZcsWNG7cGIcPH3b82JQQKDhDFCwKQ/TVMBuiH1oQO6CUBWYLhZQFpkBhFrgkoUEQBOEYa9euxaNHj/Dmm2/CyclJEBT5IT+WZg5T4Blg3BB9X18/LFy4Djt2nMHmzUdRs2ZdKzuilAEmfi8XHc7Sl2pXqFaG57//xvtIT0vHio3fIP67mZjx6Vs49mciWraoj4VLfsDBY39b2W+OsWPH4s6dO1iwYAG2b9+Ozz//HBs3bsTatWtx584djBw50qb1EARB2IKaHbQYVccBQOo4YIbYDlopWGMPhVB3Rkyy+SRyHCAI4mmlMPSHnPHjx2PKlCmKD6XgTIcOHXD79m1s3brVbN62bdtw69YtdOzY0eI209PTsX79etSpUwcvvvgifv31V6sjz19s3hw5ej1+2LYNgNRt4O/T53Hs1Gl0bBOpuGz58mVQvUZl6HVcfydPYOPf2+Q2YHNCG6BWb2bbtlNYvvwXfPTRfIwa9SHeffcTrF27Bx069MT58/9h8+a1Nqy7cN0GNqxYg73b96JHvy7YdmAdps96E+t/+AyLvnwPl5KvY+LHX0vap6oM6vrxxx/x448/IjIyEqdOncKXX36JxYsX499//0VISAjGjx+P5ORkoX1mZiZef/11aDQa7N+/H6tWrcKcOXNw8OBBjBw5Evv378e8efMk25g8eTLOnz+Pt956CwcOHMCcOXPw/fffY8+ePcjOzkZsbCwMBvl38WRBwRmiUHh6s8AcrDuTbL4ILzTUcKTAJUEQBOEY8fHx0Gq1ePfdd9GmTRvs3LkTly9fVmybl5eHTz/9FOHh4XBzc0N4eDhmzpypelG47+hRjJo2DU169YRXk1bwatIKEQMGYfH3GxTbO+kaIerFN3Dj+m0M6f8uqgc0QmipuhjUaRAuX+L26cLpcxjdrS+ala6EyFIBeLvHAKTcVkg/VhIYeRCSuzLSUvHJJ8PRvn05REaWwoABTbB9Oy8YTH2zh4cXmjRpDR8fP8V9NmEpA0ytkKWT+SRLqAzPTzp3EYd/P4wWbZrixQ6mO3MuLjpMf284AGDJ6o1WV5+VlYXt27cjMDAQo0ePlsx7+eWXUa9ePWzfvh2XLl2yYWcJgiA47LWDtuo4AHCOAwGyaRbtoHl0ste2aKeSUXfGbsjajCCIEkph6g9H6dGjB3x9fRVHuyxbtgx+fn7o3r27xXWsXr0ajx8/xqBBgzBo0CAYDAYsX77c4jKN69RBtdBQrFIICi37YTOcnZ0xuGcnYZrYbeC1IR/CU1MbV5KvcfMYwysdX0FFTUX8tPZ/knUxxjCma0c0cnPCrz8aNY8lB08RmzbFo2/f2oiM9EDHjqGYO/ddPHqUAXm9GVdXN9E008iatm07AwCuXUuyeCyUKVi3gXVLVgIAJs4cD0+NqT7n8Nd6oHJoBaxa/wsys7IEK2gxYivoTZs2AQA++OADuLubrjUCAgLw5ptvIjs7GwkJCcL0AwcO4O7du+jWrRsaNmxo+nQaDWbMmAEA+OabbyR21Js2bYKTkxOmTp0q2Y9mzZqhS5cu+O+//7B3r8JN6CcICs4QhYotWWAWKXFZYKL3lurO2DCChoeywAiCKIkwxvA4J7dEP5iSzaYD/Pfffzh06BCio6MRGBgoiAjxhaSY1157DRMmTIDBYMDIkSPRvn17zJ07F2PHjlVsv2DFChw4cQKNaj2HUf1exsDOHXDvQSqGvz8Tb083ZQaJRcaDB+l4scVgJCXdRJ/B3dGsdTPs2rYL/dv1x7l/TqN38054/PARYmIH47mIBtj1vx8xcVB/bmEbM8D0+hyMHPkijh/fi44dB6JLlyG4ffsaJk4cirVrv+H3CubiIj9Y8E22B4Xh+Yf37AcAtI5uIUx3z+XERoum9eDp6Y69h49bXXVKSgpyc3MREhICjcLNurCwMADA7t27Hdt3giCeGSzaQSs4DojtoMWYOQ4AnB20JazaQStVC7ZkB81TdHVnCsRxQAw5DhBEiYb0R8HpD0dxc3NDv3798PPPP+O2KPHr9u3b+Omnn9CvXz+4uVm2Vo6Pj4ezszMGDBiAHj16wMvLCwkJCVaPXf8uXfDv+fM4edp0HzFHr8eqjb+gfaumCAorY7ZMptZ8XzQaDeYmzIV/2TKYOHw8bly+IsxbNf9LHPz1F3QeMATRPfuo74ys61u1ai5mzx6D556LQN++oxEQUA6rVy/E6NFdkJvLN7asl/bv3wkAqFKlhnhvFZ7V6s0ABeE2kJ2VhROHT6BK9SoIDqkgaaPRaNCudRM8epSJo39bv59769YtACZ9JIaftmvXLpva+/r6ws/PD5cvX5Ykwd26dQsBAQHw8vKyaRtPIo4OLyCIAqNs5avmF9583RnxRbun8Vlyve8um6CDSRzwZyO9bL6SwNCK2qq9tlyQWZWbMImnaxoHRglxQsMs0PUCbLKR26LRKAtDgiAIC2Tq8/DcJAsFxUoA/01rDw+X/F/K8BYCr7zyCgAuY2zEiBFISEjApEmT4ORkuurds2cPli1bhrp162L//v3w9OQ6pw8++AD16tVTXP+cCRMQUqGCZHh+rmcuOg4ZhwUJazF2TD9UqlhOssw/f5/DqDdfweS5kwFww/E/GPEBVn69En1bdsGIKR/glbEjkAMdGGMY3bEn9v+yDWdOHkeNmg1sygC7d+8mgoOrIj7+AHQ6LQA9hg59GwMHNscXX0xEmzadUbZsOdlSSqLD1j6GExfHju3DsWP7IRUdzqYm/MMJgBYoXzEUXfoOMVubfHh+8nnuIr5y1VCzts7OzgirFIT/ziYhNyUXWq1WdXi+n58fnJ2dcfnyZTDGzAI0SUlcttu5c+ds+MwEQRCFQCgUR+Yr4w7T6H8dOO3EnfOlKOkkS+TCpL349fJ6KRPSRDoRqcZn8T222zDpv2SY2V1fRbBigOocqqMazuJv1DYLbLEoZTcHgiBKLqQ/Ck5/8MyePVvxprabmxsmTJiguExcXBy+/vprrFy5Eu+88w4AYOXKlcjNzUVcXJzFa+BTp07hyJEjaN++PcqVKyd8tpUrV2LXrl1o27at6rJ9O3bEjK++wrqfNqN1PS7YvunwXty7n4q4Pl0tfk45ZQLL4LMVX2JYx354p38slv6+CxdO/YUvJkxEpfCqeOfzL6UJbVbcBg4d3I4VK46gatVaAPLAWB4++mgwtm//AWvWfI2BA0eZLb5y5UJkZWUiIyMNf/11BKdP/4WmTaPQqVMvKI+EkQZjMjJSsXr1N5C6DThxbeSayfjo9+o4lPL35ZqruA1cvpgMg8GAKlUrme9BBlC1Mnd/9nzSVbSsXl95JUYCArjhu0lJSahZU5ogoaSZxO3lpKWl4cGDB8IyVapUEZa5c+cOHj58aPZbflp0GQVnCIfowpgw9HvLeWPm0XZIspHE1MEpwaYrHBeFUSLBuGpu71UeiplTAvJ4DNxhCp5kwJQFJj67amF+ttVDGk5WCtLIs8BsCNDcg7m9gIw7l4It+0YbURIairSHkIXXpar9tgkEQRDPKnq9Ht999x28vb3RrVs3AICXlxe6d++OxMRE7NixA9HR0UL7lSu5IeCTJk0ShBEAVKhQAWPHjsVHH30EQFpvJqSCNCMJ4Iptvj6gB37bdxi79x3F4H6dJfO9vDwweYbUUiumXwxWfr0Svv5+GDjmDWG6RqNB9Mt9sP+XbTj3z19ccEbxw5pPGjnyE+h0LsLMwMCK6Nv3DXzzzXT8+ut6DBw4EkqFLu1DKj6OHfsDS5bMsnnpBk2j0GXgEKvtHqdxF/PePqXgoZC57e3hCYPBgIzHj+HnLb1pKB6e7+HhgVatWmH37t346quvJPVlNmzYgJMnTwIAUpUKlhIEQdiJZq/56PhqOCuMpq+CC8JIewBcopfdtl/eUB7Rwgdk7AnKiBPY5GTCpK/4gJC7dNOecJiLCEcVXMAFVLHN9s1OKKmNIIiioLD0hxJz5sxRnO7j46ManGnYsCHq1KmDhIQEITiTkJCAunXrokGDBhZvhPNBp0GDBgnTBg0ahJUrVyI+Pt5icCYwIADtIiOx+udfMWf8OLi5umLZ2s0o4++HLi+2wrGLxoRlC3ey+foyWXBF1EttMXDsCHw3fxG+mDAR+7ZuA2MMM1Z+Dw8vL7vqzXTsOAhVq9YBX2dGo2EYOXIqduz4EVu3fi/STKZEtpUrv0Ja2n1hHR069MSECTOh1Vpz++GiLhkZ6XZpJgDo3GeIKTgDKLoNPEzj+mdvH5NtA+82AADepbjfWFrGQ6vb69ChA9asWYNZs2bhhRdeEEZVpaSkYP78+QCkmikyMhLe3t7YuHEjTpw4gfr1TcGfSZMmCa/Fy3To0AEJCQmYOnUqPv/8c2H64cOHhdpIT7ouo+AMUfDsgs3evmaE4gnIAgMUgzSpxmcHssDkQoPPAlOCssAIgigK3HXO+G+aSsS9hOCuc7beyAqbNm3C3bt3ERcXJxmiP2jQICQmJiI+Pl4ijv766y8AQMuWLc3WpTQNADIePcLS1YnYuHsvLl69hkeZ0htkN27Ji5YBVaqGAB6lJdPKli8LAKhR53loNBqhqGVutgv8y3FDNO9dvWFaQC0DzGhN7eysRe3azSAtZslQr15zAMDZszYkBygiH5Yvfq3Ba699gNde+xDcaBktJGPydcaXcu9kfp4sA8zF3voGNjJv3jy0aNECo0aNwpYtW1CnTh1cuHABmzZtQp06dfD3339LMhoJgiDyS737p63acik6DgDKiWGp8gl8hhv/rJTMpnZO5ZPalJLZdMiX44A4sa0oHAcoqY0gSiykPwpOf/DcvHlTGMFiD7GxsRg3bhwOHjwIADh9+jQWLFhgcZns7GwkJiaiVKlSkro0bdq0QXBwMP73v//hwYMH8PMz1bB8LBtBMbBrV/z8++/43849aNWwPn7ddxhjh/aFrrQWjsTk35o1FUf27MPK2ZyV9OiPZ6FmLVNSlq31ZurXFx9nTjuVL18JgYEVcenSaej1OUYnAoAP0OzY8S8A4N692zh69A98+eXHGDKkM778chUCA82T9+QEBYXgyJEMmMSRE4QECPEkJd0kQu42YC9qbgMA0L9/fyxfvhy7d+9G7dq18dJLL0Gv12Pjxo0IDORuhoo1k5eXF+bOnYthw4ahWbNm6NWrF8qVK4cDBw7g2LFjqFGjBs6cOSNZZtq0afjll18we/ZsHDx4EE2bNsXNmzexfv16PPfcc0+FLnuy954o0SgFEMQBhyq4IJ3pwMW36nB5mwtbiim+ujMOQwUuCYIoJDQaDTxctCX6oVQPxF6UsrsAoG3btqhQoQI2bdqE+/dNGU9paWlwcnIShmSL4S9AxeTo9ejy+uuY9u1SODs54ZXOHfHhqFhMHvcqBvfiilpmZ+sl9WYAwNvblBXHZ4DlaT0AAO7ePmbb0Wq5Ps/keWzEQgaYr2+A7EKWExr+/lyWwcOHabCv3oza96E0XeES1NarUoUMMAAoZcz+Sk/LkG7d+Db94SNoNBqU8vCwuom6deviyJEj6N27N44fP44FCxbg7Nmz+PbbbwX7ibJly9q4wwRBEBz2BgKsjg4pD1MimBxf2FB3BuB0kziLV6nujJxCqjujgK11Z3iXBoIgnlxIfxSM/igIBg4cCBcXFyxbtgzLli2Di4sLBgwYYHGZjRs3IiUlBb169ZIUh3dycsKAAQOQlZWF77//3uI6er8QiUD/0li2cTOWb98Kg8GA2D5dFNsq1ZuR4+LqipYd2gEAXN3cEDN0mNVllPIUSpcOBKeV8kRTDShdugwYY3j0KANqeikgoCxeeqkHPvtsKZKSzmH+/GkqG7ampWSBQaXAjA23Qf19uOOWnpah6DaQnsHd4PTRmtvhid0GAE6D/vzzz5gyZQqcnJywePFibNiwATExMVi/fj0Ac80UFxeHbdu2oVmzZti0aRO++uor6HQ67Ny5E+Hh4WbLVKxYEUeOHEFcXBySkpLwxRdf4NChQ5g2bRo++OADxW08adDIGaLQKd4sMKW6M0oUcN0ZubVZYWWBiaEsMIIgCLu4evUqfv31VwBAVFSUarvExESMGTMGAGcBYDAYcO/ePZQpIy1KyRfNzLlnutu0be9e/HXmDOK6x2Dp1IncRGNsZc2vv2LF+p8k67BFZCiSK7qYtzEDLDX1HgwGA7j4jGn0TErKHQCAl5c8ASK/9Wa4x7Fjf+DYsT9E01TqzRgf5UNC0aXfEKtbCKvKFYS8dD4ZQJRkeH5eXh6Srt5AWHCQEMiyRo0aNbB27Vqz6UOGcPsSIRMnBEEQSojtoAV4O2gFxwGxHbQYMzvoUNjhOKCGfLQMOQ6QtRlBEIVJYemPgsbf3x8xMTHCtXC3bt3g7+9vcRk+6JSQkICEhATVNmLLYDlarRaDunTCnJWr8O+lS2hc73nUqh6u2h4Aco337HIUiqz8ffgIEj5fAF9/f6SmpGDWmBGYuWytbfVmDBCkzv378uPMjNPvQKPRwNPTPJAh5/nn68Hb2xfHjx80TrHsNpCRkYbVq7+GVBwZ2znBTDPBCeg3fBxKBfiabVvsNlCpciU4OTnh0vnLivt5/ixXgqFqJduSI1xdXTF58mRMnjxZMn3Pnj0AlDVThw4d0KFDB7Ppr7zyCpycnNCggdSmu0KFCli6dKlZ+ylTpqhu40mCgjNEsSCuO6NKILiLdDm+xmfFhCy+7gygXndGLDoKoe7MXUiFhgp83Rm1Apc8VOCSIAiicFi+fDkMBgNatGiB6tWrm83Pzc3FihUrEB8fL4ijunXr4vjx49i3bx969Oghab9v3z6zdSRduwYAiGnTipsgGvSy79CJfH+G3GwX2xoq3GvLy8vFqVP7UbduU9FUA06ePAAAqF69Ngq63gwAHDu2D0uWzLR5DQ2aRpkFZ5SG5zeLagYA2PPrH3h/wmDJvD8OncSjx5mIasJd6Fsanm+JjIwMbNmyBf7+/mjXrp1jKyEIgihyrNlBW7I0U8LWujP8eyt1Z2yo2WkPVq3NCIIgiomi0B8FRWxsLNatWye8tsTly5exc+dOBAYGonPnzoptdu3ahRMnTgi1RsQ1OiXb7d4Vny//Djfv3MPkca/avd9ZxiDN/Qw93u0fB2etFkv2/IqvJ32CHet/QLM27dG1t+zzWHAbAIATJ/ahU6eBxndcUtvNm1dw+/Z1VK5cAzqdDpJojgKPHz/Cw4fpCAhQG+0k1U0ZGWl2aSYA6NxvCBecMY9TCbi7u6NB4zo4eugkrly+gUohQdzWMwDGGH7bdxieHu6IeN5ykr01Vq1aBQDo27evTe3379+P5ORkdOzYET4+5k4RcvLy8rBmzRpotVr07NkzX/ta3FBwhnAYcRbYlvPcaA0hC0wBm7PAAG6UyU179sYdUkFRjFlgvrJpVrLAeOzJAnMEygIjCIIwwRhDQkICNBoNVqxYgcqVKyu2O3fuHA4ePIijR48iIiICr7zyChISEjBt2jS0b99eKMp5/fp1RR/m4PLcsMk/TvyFLq1bCdP3HjqGJSs3cvuicPGcZ/T44i3N5ORYtJuB5Qww0Wj8RYsmYtGibUaPZIbbt69jzZqv4eLiiujoHgorsYblDDBTzZmPYPJNtlJvhi9pIEOcAeaGbFSpXgXNWjXCH7sPYfvP+9CtXSMAQE6OHh9N/xYAMKxLjGQdKamp0FesiNx79yRWEZmZmdDpdJJRNtnZ2YiLi8P9+/exYMECiUc4QRBEftDs5ZKvxFTDWZwDd+OuCi7gIkSZwxUZNxpfiTJQcBoAODtosY7hX4t1ki1BmgKuO1MIjgNWIccBgiCKgaLSHwVFdHQ0Nm7cCABWk5ISEhJgMBgwfPhwTJ06VbHN4sWLMXz4cMTHx2PhwoVm832NmqhGWCh+XrEAWdk5eLFFY6CUWVOb3Aamj3gLVy8lYcLC+Qiv9Tw+/Gox/j3yJ2a/OwZ1G7ZASKVq0gUsBGi2bVuJPn1GompVLmDBGMOiRVOQl5eHzp37C+2uXUtCqVLe8PHxla46Nwdz506GwWBA8+aW6hOYNFNQUCiOHHkEu+rNiO/yq1hBA8Arr/XF0UMnMen9+VizYipgvLf77fINuHTlOl7r3x3uIq2jz81F0rVr8PfzQ5Uq0kT79PR0eHtLHRfWr1+PZcuWoVGjRmYBRaX2N27cwLBhw6DVajF9+nTJPL1ej9zcXIlVnsFgwPjx43H27Fm8+eabCAoKwpMMBWeIwkFhiL7NhMKOIfqFmQUmH1VTgFlgVOCSIAiiWNm1axeSkpIQFRWlKowAYOjQoTh48CDi4+MRERGBNm3aYOjQoUhISEDt2rXRvXt3ZGdnY+3atWjatCm2bt0qWf6lli0RGhSEzxJW4p8LF1Hr+So4e+kytu78A907tcb6zTtt2t8sS+lPaljJAAsIKI+srEfo168BWrbsiMzMR9ix40ekpd3H+PGfomzZ8hBngM2fPwWpqZz/9cWLZwAACxZMg7u7JwANunXrh3r1msq2YqnejJP5JEvYcAg++2oqOkf2Qd9uY9Hn5XYo7x+An379A/+euYRRg3ujeb26kvZLfvgBny5ZgsmTJwvD4gHg2LFj6NGjB9q1a4fg4GCkp6fjp59+wpUrV/Dqq69i9OjRNuwwQRCEfThsBy1Owr0HabKYxG3AXTZBnIyWC8u6qeQ6DvBJbeQ4QBBESaao9IeY2bNnw8tL2XLrpZdeQtOm8mt3E05OToiJiVGdz2MwGISgE2//q0SfPn0wbtw4rFq1CrNnz7a4zpdaN7e6XQB4rNLPbExchy2Ja9C6Swf0Gfk6crNd4O3ngmlLEvFGpzb4aHh/LNt8EFqNLANMD2lZGSNNm0YjNjYS0dEvw9c3AEeO7Mbp08dRu3Yj9OnzKni9dOLEIcya9S7q1m2MChUqwcfHDykpd/Dnn/tw585NhIVVxYgRExT22IFaRvLAjAJKbgMA0HdwD2xc+xPWrf4ZVy9dQ6tW9XHxzDVs2LobYcFBmDH+DQAmt4Gbd+6gycsvIyQkBMnJyZJ1NWnSBMHBwahZsybc3Nzw559/Ys+ePahcuTLWrVsHZ2dprZwvvvgCiYmJaNGiBcqWLYurV69i06ZNePz4MeLj480szW7fvo3nn38e0dHRCAsLQ05ODrZv344zZ86gU6dOmDnTvtFFJREKzhCFSoFmgaliSxaYJXihUYR1Z4zwQsMaSkKDIAiCcBzeE9mSgAA4ETF27FisXr0ac+fOhbu7O5YsWYJq1aphyZIlWLhwISpWrIi33noLvXv3NhNHFf08sGvpV3hn7hf4/cQJ7Dl6DM9Xq4xV305DYFl/SXDG4XozOaJ+Uy4mLARodDoXLFy4HQsXvodt277Hw4dpCAmphnfe+Rzt2/eAuA4NwLBr11bcvHlNso5du7YJrxs2bKYQnAGkI2d4bInGKGAhA8wVWajxfDXsObwK0yZ+iZ9+3o9HjzJRrUolLJr+Lt54pZcpn8MKlSpVQuvWrbFv3z7cvn0bHh4eaNCgAebOnfvED5snCKJ4ERwHbMSqHbQ1xwF5PEZwHODtoJUcBwBzO2g1+HY8Re84YDM2WpuR4wBBEIVBUekPMXPmzFGd5+vrazE4Yys7duzAlStXEBUVhbCwMNV2Pj4+6NGjB1atWoXvZs/Gyy+9ZNd2mIf1NllwxdWky5gycgLKlC+HScuWSOY3iGyFIWPfx7J5H2PRzA8w9oPPlVfEuw0Yu4IBA95Eq1adsXr1F7h27SK8vf3Qt+8beP31D6DTuYDXTXXrNkb37gNx4sRhnDv3Dx4+TIeHhxdCQ8PRt28cXn55CNzcPGCL24BpGiDU6ORfKkkpG90GAC7w9uOm+ZgzKx5rvtuM+QtWo7SfN+IGdsWMsW+gjNZP+bgo0KdPH2zYsAGHDh2CXq9HWFgYJk6ciHfeecdshAwANG/eHHv37sWWLVvw4MED+Pv7o2PHjnjvvfdQv359s/Y+Pj6IiYnB/v37sXXrVuh0OtSqVQtLlixBbGwsnJwc1JUlCA1jdNVhifT0dPj4+CAtLU3xR/WsIy5uKYgM3tbMOHKGD86cLF1TsDU7h+qCyLiIcMHW7M6lYC44kwxOZPA1Z+4ZH3fBXcg/gkiEZIK723IfnHi4D04kZBrfi5/5u1RykcGfwfizmRaciNCJnkuJXpcGFxQyCg13cOLCEyaRUQZccIYP0ATCFJwJhTByhg/O8FlgVXABAAShwVub8cEZ8cgZIQtMLDK2m14qjZwhoUEQBE9WVhaSkpIQFhZGFk0FhNw72Vc82kNsnSsans+Mr/ngjDgDjLc140fO8O9zoDPVm+GDM9nguje+q8uFKRGaD9rwQsMgfyNWIAbRg4me5fVnxP2JXEDwosJJ5WHn8HwduJEzLkzIAHOBXgjOuCFbeO2BTLjncq81fBk6/jmNexLXnPF9wgtI2out/3u6BibsgX4v1rGmm8SaCeCSs/iEtguoIiS0XUWwSTMBUt2kpJlSIQrO3Da+4fVTpugh1kvizkQpcMMnrsn1kjtMOorXSuLXMNdNvGaC8VmumQCgIrNbMwEm3SQZOWOjbiLNRBCFA+kPAjDXTICKbrKgmQCTbrKkmQBjnU6xZgJMXZ1YO/G6SaKZIJqQC04Dqekm8QMKz5Y0k1w7ORunaSEIJVs1EyDoJj4444osITjjiix4GC8OJLqJNJNAUWqmJz+8RDyxWM1yKg/pEH05ZglY7rJn+cAwnWiaFa9+ANIzNCBNOeMFjQxeABUiilYHYgs5Uc0fe7LyCIIgiPyhVtTSEkzBQ5nHmsiQNbaOmciQ7InsuSBQyggTT5cOcbfqm6yCm00f3kSqfc0JgiCKBLuttxyqxaJ208AWbSRGPCxTD1MARy+bbiP3ZO8VRgKZWbmpoFTjlCAIgijZqCa0FQBCQpsYeRdlxQ6aQy6ilPphpjDd1v666NwGzLacYTaJKEIoOEPkC3FGkZBttF25LQBVay4z7+BQe/ZCHKWRZ3PJX9uCJVGRqfw+E+pBGbHYEAsNu+3bCIIgiCcJW0bN8BSIpZkYmwQGj5LQkGd9WUItA0wtKAOYXYLK4jQSrAzPF+Nh1k/DLANMzLOWAUYQxJMFPxoEMI0UsUoAuNEovjCvhwnAPKFNfJIVJ7MpYelull5hmsI5Wc5d2fvbotfJ5s15xwV+JBHvxsCPMlJCbrNtC+KRTgRBEETBUVAJbfJRM3KsJrTJ9ZJKvRlThpv4vbBnUE9yy0/Sm7wPKpzb9/yoGaJ4oeAMUXjY4OerSLFkgSmJCbUgTT6ywBSwlgXGCw2lLDASGgRBEE8XakUt5RROBpha/2trkMYWFC49bbkadbXeRA5lgBEE8SQitjBWQ7FmZSCk9S55fGGD4wBgrpsccRzgIccBgiAIIp9YcBhQQu42YBNKeqnQ3Qbk9+QKx22At4J2FHIbKDooOEMUOkpD9EteFpgcW7LArGSCFVAWmM28oDyZhAZBEEThU1iWZnIKPwNMzScZsveODM/nBUbRDM8nCIIoqSjVhbSE2A7azHEAMNVocRglO2hbsJQVQI4DBEEQhHXssTQrELcBhxLaAKle4p8dcRsQv7fmNiAjH24DYitoRbcBHnIbKHIoOEMUGU92FphS3RnAVEjTSCqUs8DURtA4KDQUs8DEtLc8myAIgihcfNVGejy1GWBy1Po3Kx7KVnIpbMkAo+H5BEGURBQLzPN20AqOA2p20GaE2rMXBWkHXch1ZxQgxwGCIIgnm4K2NLOGotuAEha7KyURVdD1ZuQ4yZ5RaG4DAuQ6UGxQcIYoVsRZYIqUiCwwa3VnLE82o4ALXJLQIAiCKOEURwaYzRRWvRnxa3k2GCBJ+3KGaXg+j8rwfDlusG+8PQ3PJwiiJKPkOCDGzHEg33bQYscBsU56MuvO2AxZmxEEQZRsbExo4xPY5AltdrsNACqJbPIJhVVvpvjcBsgKuvih4AyRb8RZYMIQ/e3KbQH1LDCzIfqh9uyFpSwwWy3Niq/uDI8jBS4lqFibEQRBEIVHicwAy4UNXZWj9WYcGbYPOJQBpoNZBpiLygdTHJ7Piw0ank8QxBOK2A7aJgJgbgftK2+k1t/IgzTi6daw5DjA152RnadLguMAQRAEUWwUuaWZGDU7M3IbIIoYCs4QhYvCEH2bKNAsMDFFWHcmFXZngdmDPdZmlAVGEARRAihRGWD5qTdjK0qjZorm0pMywAiCeJLJlx20GlbtoJUcB+xJcFPCguNAqg2rJscBgiAIQoatCW2KFHu9mXy4DQDkNvCUQsEZokhQGqIvzgIzG6KvRInMAgMkWWC80HiksJoCygJTEhqOQEKDIAiicKAMMKV9k0+z8xLUzuH5BEEQJR3BccACYscBsR20meMAYIcddEHUnVFzHCj8ujM8heU4QEltBEEQhUNBuQ3wqNXk5BPaVN0G5HqpWOvN2OE24KzQlEfh1ia5DTw5UHCGKFKe3CwwG+vOqCEXGg5kgVkSGo5kgREEQRD5p0RYmhVbBhiPtQwweTuVDDC+e1bpoml4PkEQTzpiO2gB3g7aUccBwE47aEDqOCDG1vqccp4exwGCIAiiaFBNaFPQSpYS2uRuA4pYGwnCx2AksZjiqDfD48DtelfrTYiSCQVniGLnycgCs/baAoVY4FJRaFCBS4IgiJKHg5ZmPIqWZmLynQGmhLUMMFs8kvn3ChlggHoGmBZW7xGKh+crZoDxGDPAaHg+QRBPAnY7DqjZQcsdB8yQJ7SJT7q2JrPJKfmOA2RtRhAE8QwhdhsQ6yObk9l4nk63AbKCLhlQcIYoEMRZYMIQ/e3KbQHpEH2LhNq7JyUwC6wEQ0KDIAiiYCloS7OiywArrHozFrDlKlSWAaY2PF8RC2KDhucTBPE0oOo4EKDQ2Bc2Og6o1ey0pKeeHMcBCWRtRhAEUSSUCLcBs0a27kkJcRvgJ5HbwFMHBWeIwsfGIfo2Z4EBpiwwRZ6QLDCx0CimLDCCIAjCcRwRGQVCic4AkyMWGyqjZ6zhor5/lAFGEMTTii120GbY7DhgK7YEZJ5gxwExZG1GEARRZNhjaaaE3G2Af6/oNiBOaFNyG8hT20pB1ZspILcBJQllp9uALZDbQNFDwRmiyLA2RF+MxSwwcSaYJ2zMAgMKLguMf85nFpgCfBZYgRa4JGszgiCIIuHKjRvQ1GmEIROnqLaxlAFmcZQMCjYDrGvXaujatSYKLgNM7b216fajFJAhCIJ4EhEcBywgdhywaAcdas+WLdlB5zeZDSDHAYIgiKIhOTkZGo0GQ4YMKe5dKTSsuQ28O2QUnteUwvXky7atUEkvqbgNLF48HY0aeeDYMfENzSfXbUDRCppPbEszn0VuA0UDBWeIIqd4ssC0suf8ZIHJ29oQpCnOApdWIKFBEATBERsbC41GA39/f2Rn258y5C2OnTiQAcZTEjPAsrIeIzHxa0ycOBK9erVC48YV0ahREG7c4G8Oaiw8xMiG5zvBfHi+wj1Be4bnn7twGb2Hvo+Aei/CvVoL1O3VH1+vXQ+mVIjbAteuXcPw4cNRqVIluLi4ICgoCEOHDsXVqwoJJAAMBgMWLlyIBg0awMPDA97e3mjVqhU2b95s13YJgni66aJ0LuLtoG10HDDDkuOAKmI7aCXHAaAkOw5YszbjIccBgiBKMvnVH2q0bt0aGo3G4mPPnj1C+ylTpgjTx48fr7re9957T2g3ZcoU1Xa///670G75rFmq7ZZv2gJNSCNoQhqh1+vvCdPlCW1Lvl4LT01tlNGEY9SQd61+fgDqbgMOY1tf+803n2HEiD7o1CkCLVpUwYsvPo9Bg17C999/i6ysxypLFa7bgBI593Iw7bMlqBrVA27VIhHUtgNem/ox7t6/b9d6HNFAhw8fRkxMDAICAuDq6oqqVati0qRJyMxUvr/64MEDjB8/HuHh4XB1dUWZMmXQq1cv/Pvvv6rb+P777xEZGQkvLy94enqiUaNGWL58uV2fraig4AxRbBRPFpgjWMsCA1QDNI9gngVWwgpcEgRBEEBGRgZ++OEHaDQa3L9/Hxs3brTYPr+WZrbUm1EkR6WvsCMDTBnLGWD376dgwYLp2L59I3JysuHt7WvrHkNRYPBBGTUUumxxBph4eL44A+y//y6h8YtDsOnnvejQujnG9O+DvLw8jPj4U7w3e7bNe3zx4kU0bNgQixcvRs2aNTF27Fg0btwYK1asQEREBC5evChpzxhD7969MXr0aKSnpyMuLg59+/bF2bNnERMTg4ULF9q8bYIgCGuOA2Z20ErwjgO8HTTvOCDBXeW1DtITcTHXnbEAOQ4QBPGkYq/+cIS3334bkydPVnyEhoaatddqtUhMTERurrm4yM3NxcqVK6HVWg/cx8fHAwA0Gg0SZTfpJZZm7vx2nbFl5z7cS0lVXN+K+A3Cdg1GEcEntI2fORFbTh+Fb4UQbj8LzG3AAHN9ZN1tYN265Xj8+BGaNo1C377DEB0dg+zsbMybNwWxsV2NAZqidRuQOw8YDAbEDHgbk2ctRoCfL8bF9kWzOrWxdMMmRMfG4t6DBzZtwxENtGHDBrRo0QLbt29H+/btMWrUKPj7+2P69Olo166dWZAyJSUFTZo0wZw5c1C2bFmMGjUK7dq1w5YtW9C4cWMcPnzYbBtvv/02BgwYgEuXLmHAgAEYOnQoUlJSMHToUIvBx+LCkVQYglCkC2PCKIwt540XtNtRDP693jAJAHfjax1MZ19eVNgbOs8VrUdnXK9YxKSbtslv3tfKKm/CbFTQnUvByrZuRs6huqodnIQXYMq+aw8hI69LVdssFAiCIJ4l1q5di0ePHuGtt97C/PnzER8fjz59+hToNgq8qGURZoD5+pbGwoWrUaNGbfj4+GH06AE4dGiPSmulUTMO5AO5Wm8iZ8SoWUhLf4hty+ejQ5tIIA2YPup1tB42Ekt++AG92rdHdGys1fWMHTsWd+7cwYIFCzBmzBhh+rp169C7d2+MHDkSv/zyizD9xx9/xI8//ojIyEj89ttvcHfnvtdPPvkEERERGD9+PDp37qwoggmCINSod/+01VHyZStfNR9FEgjpSH0xvDySTBB3KFrZex0suwjw+kgrmqb22oY+7y64gBLPbXCfB+AcB0Ktr0KNk6VrOubiYGSLRqM88okgCMIBikJ/jB8/HuXKlbO5fYcOHbBlyxZs3boV3bp1k8zbtm0bbt26ha5du1ocFZGeno7169ejTp06CAwMxM6dO3Ht1i1UtLAfHVo3x5Yd+5D4wzaMe6O/MD1T64ZTf5/FiWP/4aWubfHL5p1my/qUD4VP+YJ0GzCozLCt3sy2bcfg6son4mmE50mTRuHnn3/E5s1r0bv3MBSn20DCyq3YvusQ+vVsj1VzpkOj0QBpwDc//Ig3ZszCjK+/xvwPPrC6Lns1UGZmJl5//XVoNBrs378fDRs2BMAFeUaPHo1FixZh3rx5mDBhgrCNyZMn4/z583jrrbcwZ84cYfrBgwfRsmVLxMbG4tSpU3By4vTm0aNHMXfuXISHh+Pw4cMoXbo0AODRo0do06YN5syZg549e6JZs2ZWP19RQSNniKLBGCSwOwtMaYh+vrPAxJSQujPJ5pMKtMClFcjajCCIZ534+HhotVq8++67aNOmDXbu3InLl5V9i/Py8jB/xQo06N4d5SIj0aB7d3z9XQIMzDgyRRxHKQXs3ncUsaOmocbzPVHKtxVK+bZCi4g+WLZ4ndBMXNQyXFMG/Vp3x+3rN/B2/zi0CaiAqABvjOvWGdeSLgEAks6dxvih3dD2udKIqlkK743ohZRbanfixBggHkGTkZGKTz4Zi/btayAysgIGDGiD7ds3mC3l4eGJJk1awcfHz8K6lfoS+bT8Dc+3lAF27txl/L7vBNq0jOACM/wqdDp8OHw4AGClDRmJWVlZ2L59OwIDAzF69GjJvJdffhn16tXD9u3bcenSJWH6pk2bAAAffPCBIEoAICAgAG+++Says7ORkJBgddsEQRDWEDsOKGKzHbQlxwH5a1vJR90ZchwgCOIZw1798emnnyI8PBxubm4IDw/HzJkzYTBYGhlvPz169ICvry+WLVtmNm/ZsmXw8/ND9+7dLa5j9erVePz4MQYNGoSeLVrAYDBg9datFpdp3rAOalQNRcL3W80S2lYu2whnZ2f0GdxDcdkPhgzH85pSuJGcDIC70T+2Syc08tbg1/+tlbRleoYxQzugUTUNfv3ZOM/MbUCZTZtWom/f5oiMDELHjrUxd+5EPHr0ULx2ABAFZqS0bdsFAHDtWpLKForObWBp/EYAwMxJI7nAjJE+XXsgtEIFrP/lF2RmWa/zaa8GOnDgAO7evYtu3boJgRmAG2E1Y8YMAMA333wjsaPetGkTnJycMHXqVMm2mzVrhi5duuC///7D3r17Je0B4M033xQCMwDg6emJDz/8UNhGSYKCM0SJRXH0SKD5JAGzRCz5BLnAKEF1ZxyEhAZBEIUGY0DOo5L9KKDs1f/++w+HDh1CdHQ0AgMDMWjQIBgMBtWb6UN69MDUhQthMBgwrFcvvNC0KeZ+9z3Gzpqj2P7TBSvw+8ETaNTwOYwc8TL6DuyMlHupGD18Gia9/YniMmkPUvFKi2hcT7qMzoMHokGr1tj/yzaMimmHCyf/QVzH5nj86CG69I5FzToR2PXLj5j4Vj8b6s2Y0OtzMHJkDI4fP4COHXujS5f+uH37BiZOfANr1y61cMTkN8XE2V5KGWD8s4KykGeAKWBrBtievccAANFtmpjNb1qvHjzd3bH/+HGr60pJSUFubi5CQkIkYoUnLCwMALB7925h2q1btyTzlNrv2uVoMQmCIJ5mbBnRLraDFpM/O2hAue4MoGxpZg8O1J2xxk3zSdbqzpC1GUE8YZD+UNUfr732GiZMmACDwYCRI0eiffv2mDt3LsaOHVsg+8Pj5uaGfv364eeff8bt26abVrdv38ZPP/2Efv36wc3Nsj1zfHw8nJ2dMWDAAHR54QV4eXhg1ZYtYIxJLc3EuAJDB3TB3/+ex7FjphGOOTl6/LDqJ7Rp3xLlgspa3f/cbBdoNBpM+joBpcuUxczxw3EzyRTwWr1sPg7u/QWdewxBdAdro5RM0ZpVqxZi9uwJeO65+ujbdzgCAgKxevVijB7dB7m5ttkZ7N+/AwBQpUoN45TicRvIysrG4T//RfWqIQjxNWZ0pBn3SKNB6yZN8CgzExds+K3bq4Estff19YWfnx8uX74sSYK7desWAgIC4OXlle9tlFRdRrZmRLGgNEQ/HBctjg5BeShelJvjDpPFGG8/Jh6eLx+qbwtKQ/XV6s7IgkJ8gUtf0bR74Eb/AFJrs2samwp6krUZQRCFjv4x8ElQce+FZT64Abh45ns1vCfyK6+8AoDLGBsxYgQSEhIwadIkYYg0AOzZsweJmzejVtWq+CU+Hp7GDKGpw4eiXu8Biuv/es4EhNaqILzP1LohNzcXMR1HYfGCFXht7BCUqVRZssy5v//BoDdH4s25XMAnN9sFs0aOxI/xX+O1Li3x6jtT0C92LJDLZYC9GdcZ+/dsw5n/jqNG9QY21Zu5d+8WgoMrIz7+F+h0WgAMQ4eOwcCBbfHFF9PRpk1HlC3L2w/YIkSVMpnNM8COHduDYyf2cM2dFB4AF7RxBuDMUDG8IroOGSQs7ybxJzBx/gJ3k7Jq5Upm85ydnVEpKAhnk5KQm5tr0Svbz88Pzs7OuHz5MhhjZgGapCQu2+3cuXPCtICAAGFezZo1rbYnCOLZRmwHLcDbQe+CNFhgKxWZ+ogSsU1Yqnym3OdMbOOcC/t0k9zGTBzkyYRdNUDFeklMMsyCUFcRjGBcxUWEowou4AKqqI4wImszgijBkP5Q1R/Lli1D3bp1sX//fnh6ctv/4IMPUK9ePYvbmT17tuJNbTc3N4l1lJi4uDh8/fXXWLlyJd555x0AwMqVK5Gbm4u4uDiL17SnTp3CkSNH0L59e5QrVw6p166hc5s2WPPTT/j9yBHEtGxsauwjXXZQn474cMZXWLZ8Mxo2rIlMrRu2rtuOe/ceYEDcy5K2fL0Z3n1Ajn/ZQEz+cgXG9euIiSP6Y/G633HhzCks/HwCKoVWxTsTvlT9DBxS7XTo0E6sWLETVas+B8AAxj7ERx+9ge3bN2DNmngMHPi62RpWrvwKWVmZyMhIx19/HcHp03+hadPW6NRJKShk7jaQkZGK1Wvmm+I3SrpJpJkA4JXxI+Dv62nRbeDixWswGAyoWlk5waFKMDf9/PnzaNmypWIbHns1kLi9nLS0NDww1ro5d+4cqlSpIixz584dPHz40Oy3bO82+GnXrl3D48eP4eHhYfHzFRUUnCGKlTo4pTj6IxhXBVsvANwFeLI9a1aqOwNIPZMdDdJYqjsD0zS1ujNqQkMEX3eGhAZBEEThotfr8d1338Hb21vwVfby8kL37t2RmJiIHTt2IDo6Wmi/cuVKAMC7w4YJgRlfV8A3sCzGDuiLjxaKhkgbh+SHhVQwC21otVoMfr0/9vy2H3t2H8PLg6XBGQ8vL4yZMQmAqd5M+5f74cf4r+Hj54++Q011UDQaDaI79sX+Pdtw7uxfXHBGFXFhS2DkyEnQ6VzAC5DAwCD07fsqvvnmU/z66yYMHDjcwrqUUPNNNnHs+B4sWTxVdb6chlEt0WtIP8V54uH5aWmcrYCPt1EwGzPAUo2xnFKenjAYDMjIyICfn7o9m4eHB1q1aoXdu3fjq6++wsiRI4V5GzZswMmTJ7n1pqYK0zt06IA1a9Zg1qxZeOGFF4SMwpSUFMyfP9+sPUEQhDU0e81HxFfDWWE0SBVcEOyPVeE1h9gezNf4rFh3xh1ABpTrzgAmDSS3fban7gy/PQtYqjvjIH+jttkIJBalbLtNEARRmDiqPyZNmiQEZgCgQoUKGDt2LD766CPVbYlrdIjx8fFRDc40bNgQderUQUJCghCcSUhIQN26ddGgQQOLwRk+6DRo0CCkHj0KAOjbsSPW/PQTvtu8WRqckVEuMAAd20VizdpfMefzcYCXG1Yu+x8CypRG+y4v4K9j/6guq0Tzti+h72tjsfrb+Vg4awL+2LkVjDHMmLsaHp5eKlZm8noznG7q2LEvqlatJSyg0WgwcuT72LFjE7ZuXSsKzph01sqVXyEt7YHwvkOHXpgw4TNotTrY4jaQkZFql2YCgG7D+sPf13Lw0KSZzIN2AKeZuHZpVrdnrwaKjIyEt7c3Nm7ciBMnTqB+/frCvEmTJgmv5TorISEBU6dOxeeffy5MP3z4MLYa7fLk7WfNmoX58+ejf//+8PX1BQA8fvwYM2fOFB2HNArOEE8n4iywLeeNQ8GLIgvskdLMws4C46epjJhRohALXBIEQRQoOg8uM6wko8v/xdSmTZtw9+5dxMXFSYboDxo0CImJiYiPj5eIo+MHDwIAmokuJHlaNqinuI2MjEf4fG4iNm3ai4uXruHRI6m9y+0bnGVAlmhMekjVKnD2kKaTBQRywyzDn6/DjeYQDeD0L8PNu3fH0ncmVR/OzlrUrt0IpkKWnJioV4+zBTt7VtlGRxlL9Wako2deGz4Fr42Yol7YUgfT8HwXZrQ1s6PPtrMcnBLz5s1DixYtMGrUKGzZsgV16tTBhQsXsGnTJtSpUwd///23JKOxf//+WL58OXbv3o3atWvjpZdegl6vx8aNGxEYyHX04vYEQRC2ouQ4IKds5atSiy9+VL5dNspPpuMAn9SmBjkOEMQTBOkPRf3x119/AYDiKAZrIxtu3ryJcuXKWWyjRGxsLMaNG4eDRu1z+vRpLFiwwOIy2dnZSExMRKlSpdC9e3dk//svt48REagQGIif9uzBg/R0+Hl7q65j6LCu2Pzz7/jfxj1o3KYpdv56EK+NHQydzo6RlzkmXTJqwiwc+2MPEr+dDQAY/d6nqFmrobRrU603Ywq01K/fDGK9BDCULx+MwMAgXLp0Fnp9tjHhzcSOHdznv3fvLo4e/QNffvkxhgzpgC+//AGBgRUgxdxtICgoFEdOMnXNBJh0k6CZAEv9tnuu9Toy9mCvBvLy8sLcuXMxbNgwNGvWDL169UK5cuVw4MABHDt2DDVq1MCZM2cky0ybNg2//PILZs+ejYMHD6Jp06a4efMm1q9fj+eee85Ml7Vq1QqvvPIKvvvuOzz33HPo2rUrdDodtm3bhtzcXPj4+CAtLa1EabOSsyfEM4NSdpL4grkKLlhfSQCko088wV3Eq9ad4Z+V6s6Iz2pibKk7Iy9waYFUy7MFCrrApUpAjDyUCYKwiEbDDdkvyQ+FeiD2Is7uEtO2bVtUqFABmzZtwv3794Xp6Q8fwsnJCf7GDBwxgf7+ZtNycvRo3e11TJ+xFM7OTuj3She8++Gr+GDyG0JRy+xsc5suL+9SZtOcjVZcnl7mgoa36crN0duQAcY18PX1V7goZfD35zrYhw8zzLajjEb2Wv4esFzJUgUX6XgjS8PzAcDXncv+SktXzNhAxqNH0Gg0KFXK/NjKqVu3Lo4cOYLevXvj+PHjWLBgAc6ePYtvv/1WsJ8oW9bke63VavHzzz9jypQpcHJywuLFi7FhwwbExMRg/fr1Zu0JgiDyg3gkff7qzrhDWndGjB03whSxVndGNE2t7sw9hWkq8K4L/GgiS3bZ1oJd1jCzpCMIomAg/aGoP/gbybxdkxj+BnhBM3DgQLi4uGDZsmVYtmwZXFxcMGCAsoUzz8aNG5GSkoJevXpJisM7OTnh5ZdeQlZ2Nr7f9gs30Ud5HZ06RiIwsDQSlm9G4vJNMBgM6B/LWZrlgAt+5BlvZfOWZgbje95tQIyLqyuav9ABAODq6oaYnsNsPAJSQVW6dBnFVqVLlwFjDI8ePVRdU0BAWbz0Ug989lk8kpLOYv78SbDFbcDqHXuFbtpFdM9SbAUtdhvw8eE1k3SfebeBjEePjO1UviQRjmiguLg4bNu2Dc2aNcOmTZvw1VdfQafTYefOnQgPDzdbpmLFijhy5Aji4uKQlJSEL774AocOHcK0adPwwQcfKG5j+fLlWLBgAcqUKYPly5cjMTERjRo1wr59+5CXlwetVovSpUtb/XxFBY2cIUo0JT8LDJBam4lFh3FaqvGtfGShWhaYEbUsMN7azOYsMDGiLDAlyNqMIIhniatXr+LXX38FAERFRam2S0xMxJgxnI2Yt5cXDAYDUlJTEeDnJylqeTsrxfTGeP9/07a9OH7iDGKHxmDp4onI1HIi4jHc8b81W7F2xQaL+yiIDFEGmMQjTZ6MzGNDBlhqagoMhjyY4jPcvJSUuwAAL69SZsuoCwi1oIw5x47vwbFje2yqN+OkzUNQaIiqrRlgygDjfZPPJ18xa5OXl4crN24gLCzMYr0ZMTVq1MDatWvNpg8ZMgQAEBERIZnu6uqKyZMnY/LkyZLpe/bsUWxPEATBIzgOWEDNDrpg4B0H+GfeaQAw3f2xVTtZcxywMehTRI4DZG1GEERR4oj+8PHxgcFgwL1791CmjDRIcPu2XTfHbMbf3x8xMTHCtXC3bt3gr5CIJoYPOiUkJCAhIUG5zYbNGNm3t+o6tFotXhnYCXPnrcI//yUhonFt1KxVzf4PYAw0/HP8ML77+nP4+Pkj7UEKZk1+AzNnm1/fSxGLKE4H3b9/F3K3AX66RqOBp6eyRZiY55+vD29vXxw/fkA2R9ltICMjFatXz7debwYQdNOQca/DWyGJUEzlyhXg5OQk1OvkraB5Ll411vGsaltGtyMaqEOHDujQoYPZ9FdeeQVOTk5o0EBq012hQgUsXbrUrP2UKVMUt+Hk5IQxY8YI/yGe5ORkPHz4EA0aNLBvNFYhQ8EZothQGqIfjotClpNZ3Rm7cAeQrjKvoOvOAKpCQ6kkDWBT3Rkeed0ZS+S37gxBEMSzxPLly2EwGNCiRQtUr17dbH5ubi5WrFiB+Ph4jBkzBqlHj6JW1ar468wZHDxxAl1ekA5N3PfnSbN1XEy+BgCI6dpKCMzwHNx3RHgtL2ppsJRFZbajtjaURmvy8nJx6tSfqFtX6v188uRhAED16rVs3wcA5gEa8+H5cAKOHd2DJd86VnNGnAEmJyqSu4j/9ffDmDBiiGTeoZMn8Sgz06IItoWMjAxs2bIF/v7+aNeunU3LrFq1CgDQt2/ffG2bIIinC7EdtICddtBmdWfU7KCVdIfqoH++7gwgDdIAJosye+vOQNaW344KqTCv26mE0drMVpTqzkggazOCIAoZe/UHwI3oPn78OPbt24cePXpI2u/bt6/Q9jU2Nhbr1q0TXlvi8uXL2LlzJwIDA9G5c2fk3JMOfXRxAnb9eQQnzpzFidNnUb+p+Wdnxryw2CFdMXvOd7h18y7GTx5j1s4iooS2Rw8zMHFkf2idtfjm+z1YsmAKdvz8A5o1b4+uMbEmcwFBIim7DQDAiRMH0akTH1TiAjQ3b17B7ds3ULlydaOlmVJCm+n58ePHePgwHQEBgbDFbSAjI9UuzQQAPYf0Qxlfk+ZUchvwyHVD4wbP49DRU7h87SZCSpkyxRlj2HP4MDw9PfOdWGavBtq/fz+Sk5PRsWNHm0bt5OXlYc2aNdBqtejZs2eh7FNRQcEZotixOQssFFyWlM3wQ/QtZYEVVN0Zflrx1J2xWuBSLDREkNAgCOJZhTGGhIQEaDQarFixApUrV1Zsd+7cORw8eBBHjx5FOIA+HTti1ZYt+GzpUrzQrBl8Xbnz/vXbd7AgYY3Z8iHB3MXuH/v/wovdTN7R+/ceRuKSH+zb6RzRa3lAxmqAxjwDDAAWLZqBRYvWCZlDt2/fwJo1S+Hi4oro6Bj79k9ALDbMx+O/9sYUvDZyis31ZlxU+moP2Z3F6lVD0ap5few+cBQ/b92PDi0jkZoN5Oj1+PjbbwEAw4ZJrQzu3buHe/fuISAgQGIVkZmZCZ1OJxllk52djbi4ONy/fx8LFiyQeIQDQHp6OrxlHtrr16/HsmXL0KhRIzMxTRAEYQ3NXpldMTg76HMwv6ll5jgAmDSGkj2YvDwn3CHVRnwSG/9sj3ayp+4Mv23k23GAT2rjseQ4kN+kNnIcIAjCXhzRHxEREXjllVeQkJCAadOmoX379vA0Fmy/fv261Tow+SE6OhobN24EAKtJSQkJCTAYDBg+fDimTp2K1KNHJfN9XYHF6zdg+LSZiP9pExY2fZeb4Wa+rho1QvG/n79GdlY2mrzYBoApgc0ePh0/AtcvX8I70xcivHotfDhzCf7960/M/mQM6tZtgZBKlkbkSM/v27atQZ8+caha9TluLmNYtGgm8vLy0LmzaSTQtWuXUaqUN3x8pJZZubl6zJ07EQaDAc2bvyiao56QFxQciiN/M7vqzXC6SbmujLjezGuDu+HQ0VN4/9NFWDV9OlfPFEDChg1Ivn4dr732msSaTq/X4+LFi9DpdKhSRWobaq8GUmp/48YNDBs2DFqtFtOnT5fM0+v1yM3NleyPwWDA+PHjcfbsWbz55psICgqyuo19+/Zh5syZCAkJweuvv654jIoLCs4QBY44C0wYom9nFpgZBZoFxt/9kQ/VtzcLTEloqA2VMZIK9QKXYmzMAiNrM4IgCMfYtWsXkpKSEBUVpSqMAGDo0KE4ePAg4uPjMTMuDi0jIjCgSxes2rIFLfv1Rfe2rZGdo8faX39D0/q1sHXnH5KBlJ17tUTox0H4fPZK/P1fEp6rFY4zZ6/i16270bF7O2xZ/4v1nc1RuWhXCsgw2JgBxhAQUA5ZWY/Qr18UWraMRmbmI+zYsRlpaQ8wfvwMlC1bHmJhMn/+NKSmPgAAXLx4BgCwYMFUuBtrvXTrNgj16jUXbUtpFI2NuDjWD3019T1E9hyGbuPeQZ/27eBXOgC//vEHzly6hFGjRqF58+aS9gsXLsTUqVMxefJkYVg8ABw7dgw9evRAu3btEBwcjPT0dPz000+4cuUKXn31VYwePdps202aNEFwcDBq1qwJNzc3/Pnnn9izZw8qV66MdevWwdnZgbo7BEEQUHYcsEp5cIEMq4gdB+R20PYms8mROw5Y0Er2Og4kQzWpzRbHAYIgiKLGEf0RERGBNm3aYOjQoUhISEDt2rXRvXt3ZGdnY+3atWjatCm2bt2quq7Zs2fDy0vZcuull15C06ZNVZd1cnJCTIz1ZC2DwSAEnYYMGWIWmOHp0z4a4z6bi1Ubf8HsD8fCzc1VsV2m1g3RL7UAwFlBK8EHa3KUHGyygW3rEvHzj4lo2a4Leg8eCeQC3p5+mDY7EW+80gYfTeiPZSsOQuukZG1l7g3dtOkLiI3tgOjobvD1LY0jR/bh9Om/ULt2Q/TpYxpVdOLEIcya9T7q1m2EChVC4OPjh5SUu/jzz324c+cGwsKqYcSIibDFbcAiVhy5LLkNAMDgfp2xdv1vWL15O5KSbyAqoj5OJ1/Dlt27ERIUhBkzZkjaX79+HTVr1kRISAiSk5Ml8+zVQF988QUSExPRokULlC1bFlevXsWmTZvw+PFjxMfHm1ma3b59G88//zyio6MRFhaGnJwcbN++HWfOnEGnTp0wc+ZMs8/Xq1cvZGZmok6dOvD29sapU6fw888/o3Tp0ti4caNNNUiLEgrOEMWCtSww8RD9gs8Cyw98kEZsaQZYrDsjfivPAhNjQxYYLzT4LDAlyNqMIAjCOrwnMl8/RI0+ffpg7Nix+D4xEZMGDIC7mxsWfPghqlSqhMRNG7Fw9TpUDCyLt4b1R+/O7bjgjAgvLw/s/PUrvDvhC+zddxL79hxF9eer4utVc+EbWB5b1v+CXJXLMaWilvlDGvDQ6XRYuPBHLFw4Ddu2rcPDh+kICQnHO+98jPbtu5ktvWvXT7h585rZNJ6GDVuiXr1ImGeAyYISvEeyOAPMCuIh+UrD8zVGB57nq1XB4Y0JmDjzG/y0bz8eZWaiSqVK+Pzdd/H2rFnWN2SkUqVKaN26Nfbt24fbt2/Dw8MDDRo0wNy5c1WHzffp0wcbNmzAoUOHoNfrERYWhokTJ+Kdd94xy9wiCILILxbtoEPhgOMAL1qKou4Mvx0rWHIccBC7HAfI2owgiALEXv2xevVqzJ07F+7u7liyZAmqVauGJUuWYOHChahYsSLeeust9O7d22JwZs6cOarzfH19LQZnbGXHjh24cuUKoqKiEBYWhtSUFMl8vkanTykv9GjbBqt++gUbftmN/t1eEtowhTiNWmDGGtcvJ+Gz90YiILA8Jn2+TDKvQb1WGBL3PpYt+RiLvvwAY8d+LporL9pp0k4DBryBVq2isXr1Yly7lgRvb1/07TsMr7/+ntHSjKNu3Ubo3n0ATpw4jHPn/sXDh+nw8PBCaGhV9O07DC+/HAc3Nz5YZtltwCqyY2ar2wDABd42LZmDWV+vwHfrt2Hed6vh5+2NV7p2xYdvvGFW28gS9mqg5s2bY+/evdiyZQsePHgAf39/dOzYEe+99x7q169v1t7HxwcxMTHYv38/tm7dCp1Oh1q1amHJkiWIjY2Fk5P5sevWrRuWL1+OVatWITMzE8HBwRg9ejTef/99BAbm80KiENAwRinylkhPT4ePjw/S0tJIWNuB2D9ZKG7Z3vhsHDnDB2dOlq4p2JqdQ3VBZFxEuCAy7lwK5kaTJBvXcRPcxfk94+MuuFEpj4zPwrnntvFNJriMMP51LoDHotd6mMSH/ITGixF+xA3/7C561hqfvY3P4tfGJ19wwRlfmERGAEyZYIEwBWdCIYycKVuZG5rPD9Hns8D44Ix45AwvNMTBGUFoiG3NZCNn5EKDRs4QxLNBVlYWkpKSEBYWZmbRREhRygDzFV8Qi21xRYk4vHeyuN4MLzT4jC95vRk+A0wIzuRohKKWku4qF9Kui9cTiiNncsEJDIPowRSe+fO//BlQ8k42VaiUV6p0Nk7TGl8bL5p1UA7OWBmezwdk3JAtvPZApjA8nw/OCGUSjIUtU0VJY7759E1+WrD1f0/XwIQ90O/FMeQ1ZyS6SUEzAVxwgU9ou4AqQkLbVQSbEtp43STXTIBJN12HSDPxWuk+uA7lPrh+I9P4XvwMmPSTHLFu4m+quQPwgElDlYJJR5WGqmYCTLqJ10sBMNdMgEQ3qWkmwKSbLGomQFU3KQVnSDcRhH2Q/nj6UbI0E7CimQCTbpJrJsCKbuLdBrJh6qLEeonvuixqJoPotVwriTWU+Nwvfw9IAy8a2YPXSWLdxL829qPiyZasoI1uA2IraFdkCSNnXJElBGckuok0k1WKUjM5EJojiMJFfAEt9gwG4EA9FvEfgxcI4hEv/BnOVuSWZ2rTrHDXhjZKNm4KKHlOKyK2k2uv2gqAuVAkCIIgpPi6qsxQERk8jmaAWYQXGYqoZ4CZhERB3FgSCw4e2WWmA8PzxRlg1obnEwRBPGlYvLGvUC+SR5ycZZN9lzghjMcXCoNX3GXPcp0k11GWEHtvijVSpux1umlaJkwJd5awYNfGJ/fxQSs+8U8Ju63iCIIgCIuoWZrZijihTY7VhDZLWK3PKUYpWU2tv1aarnY/TSlgI0bBAlkcmLEBW90GiJIFBWeIEoF8aLkilmqwlIE0y0qCu8pruaDQqkwXh9zl05Tm8aNzFOBH9ogRW7MpCA0+A46EBkEQRPFgVWT4WJ6thLWilpJRMzzie1tKAkMxA0zeAFAWEUqiwxaxoSQ+5NPsuNxUC3rJUBqeL6CQAUYQBPGkIhnRYQV+xL0Em9w7lDSS+E5QftzQc2XPgIUiocqI9dJt0etkh3ZIcG0QI7HcVklqE0Y2EQRBEDahmtAmQimhzWHE1/9yvWRTLrVcPxX26EiFejOAxHhAQMUOmncbsAQ/aoYomVBwhigUxFlgwvBvfki4hSwwMYWTBcZjKQvMGkoCAzAXGQpZYHKUauYUMPLaPkqQ0CAIgig8HMkAs4kCywBzNCijNlRf3F4hA8waLtLtK2V9KWIhE4yG5xME8aSjVFNSrQalQHnLs6UoOQ4AUp30jDgO2AE5DhAEQdiIiqWZEna7DeTYeC6WmwqoziC3AaLooOAMUWwoZYGJh+iLKZlZYDwOZoHJhUYhZYFJIGszgiCIfONoBpjDlmYlOgNMqa+wkgEm901WQC0DjIbnEwTxrKPmOJB/O2jAPKFNJ3q2JUhDjgMEQRDPKoVhaSav0clj0dLMmtsAYIfbgC1BGnIbIPIHBWeIYkcpC8wqT3oWWKrsvdoIGmMWmFDcUwYvNJSywEhoEARB5B+7LM1sHJJvzdJM4InOAJNh6YqTL2ppJzQ8nyCIpw0zxwFHUbODDoDUDtpXqVFh1J0pRMeBZCu7oAJZmxEEQRQuhWFpZlFHqQUZ8uU2IH+t1FaMmtuA0kPcntwGnmUoOEOUSMRD9EtmFlgubMsCS5euJhUFWuDSVsjajCAIongosKKWBZoBJsaeDDA5ljLAHLjElAk4teH5FjPACIIgnjDEdtBmWLCDFjsOWLWDDoS5FTSPxUGdT4DjQAmAHAcIgnjWKciENofdBsTY7TZggLlesiVAkx/IbYDgoOAMUSQIWWAWUBuiL6HEZIHxyLPBlERGpvosoEgKXEogazOCIAiHKdailmIKrN6MWlsx8swutXoz8vaA3ZeaLvkUPjQ8nyCIpxAlO2g1zOygbXYcUHMYkL8nxwGCIPIPsxSYJp45LFmaqSEktIndBiwltPGJbHniifIJSnbQBfVbJbeBJ4miPEdRcIYoNBSzwPgh+sWWBVZYdWfE5CMLTIl8FLgkoUEQBOE4xZoBpmRpVuAZYPnFmpeybHi+PAPMBsRZXxYzwGh4PkEQzwBKdtAF6zigZAdtbwKbo3VnSqDjAFmbEUSB4+zMXR/q9TYVTiSeUGxJaLMVNbcBm1BLaFO0hBZTVPVmCs9tgMgf/DmKP2cVJhScIYqVos8Ck2NPFphYTNiSBWYhSFNEBS7FkLUZQRBE0WJPUUseuy3NbMoAk1OY9WYUhufzQRk5KkP01YbnEwRBEBw2OQ5YQuw4YIa7wmt5uq44yl5QdWdE00qK44AdkOMAQdiGTqeDq6sr0tLSaPTMU4KjCW1KbgMOW5qJYxJ2J7QBnH4qrHoz4tcF5DZgAxatoMltwCKMMaSlpcHV1RU6nQPDleykoIcMEIRD1Lt/2myURzguCoGHYFwVghIAuCywZHu24A3TVb678bUO3Fmbf7Y3c4M/42uNy2phWXS4m976yprdg/IIoGQ4VGPnb9S2LNpegGn0UntYLDq6RaOx7IVNEATxjFCYRS0LNwNMPHrG0QwwOQWYAWbloytlgNHwfIIgnla2nDcmTG2HVQtiMVVwQTpKpCIzH4HP6w1xkMPX+Cy5h+MOZW2kVZluD7z+4jdqY/93F1xQyRLXNOo22CLOobrErQHgHAeURiYRBFHwBAQE4Pr167h27Rp8fHyg0+mgoQDnE4tSWlWW+FQs1i6ixkx0OZ+t5fSK3phglgNXYcE8o7YwGLUG47VNjgHQG3834pxpPkctD6ZENjU5BINxO3xwRvxsgFQviRe21NfIAzLi5DUn0Xu+uIzWtNNMtphG9Hn4TWoA6BiQDTCmhwF6uCAHeQBykQsXZEMPINv4oTXG46/JgWlEqnFd4u/Ou1YtZGU92xqLMQa9Xo+0tDQ8fPgQFSpUKJLtUnCGKFHUwan8ZTCJL9hT5TP5oAz/OgPmAoMXB7xosEV8iAM8OuM28lHA7DY4uzYx+RAaBEEQhP2UiKKWJSoDTKnejNp7HjuCNBaG54uxmAFGEATxhNKFMfWRF7sAvMA5DshHwlfDWUVr47KVr5rXXgmEdLSJIu4w2YvxuobXS7ZqI4ja86+1sunyafy2FeBvJPmKpokT227CzFHhzqVglK3MJfcF4youIhxVcAEXUEViA6cGi1JxeBAltXWpaltdVYIgzPH25iwU7927h+vXrxfz3hD55fE9aTEwDy2QIp4gth92M3+d42QK0vMJa7nGZ72xr8gVpnPD8Q25xj5EPDhTHJThgzEG0Xtx7AWQTZC/VgvKwMJrJU2kkT3EgRr+vfG1RjbZWTSbdyFwhuA+4KTlPrwWedAa+1IdcoXXvJ5yMeghuEPzz0ZJ9VikMT2SkkBwuLq6okKFCsK5qrCh4AxRZDiaBWZGicwCkwsNcRDIAqnGZ2tZYEZ4ocFji9CwNwuMhAZBEET+caSopYBSvRklLDmXSYbRiAM0BVV7xlJQBlD2MbOCi4P7xAs+Gp5PEMQzgJLjgFXKw2J9FimWHAcAadDFGrmQugs84Y4DViDHAYKwHW9vb3h7e0Ov1yMvz+JFLVGC2V2jhtm0NqGiNy1Er5uYXrJGptenfTnHnLOoJkxLRhgA4IrRQecKQgAAt1AOAJBy1RiVv6UB+PgeX9M5BcAD4wMA7oPTC3x5M2FwSBaAh+Ay4vjnbOP0LOPrXOPrPEiH5eTCvB/Tip75Qpuuxveu4PpUV3Ceom7GZ1cAXtx7N3BdsDu45L/SAPyMD3+Y7htWAFCOwT+Y69jL4RYq4TIqGevOhYILslTHOdRM5e4Tao4Ylz1sfP6De9qdbNr7NmfOgOBqzBSFlZkYCs4QhYotWWBKiLPAxEP0S04WmJLAUMLCKJpHMPd5tpIFxsNngVlCSWhIssDI2owgCMJmbLE0S0q5gcqlYzD4lU74auUs1XZqRS3trjcDSLPBJBPE7wF5hlfXrg0BMGzebGWEkFWURsrIRszwWV+8RlGpNSPH1aScJK+FLWeYTTLDNyLCeiOCIIgnnIK1gwaUHQcAqQ4CTEGWInIckFubFbDjAFmbEUTRo9PpivxGKFFw5F6+bDbNTSxp0sSNTS+ZqI3GzQMAkGO8GcbdC3wIALhvNN66bdQy14wi4o6zmylpm/fmegzTvcEH4O6v8QGbVJhqPwu5AAxcACYDJk2TAxgNwrjXenDBGb4Ugrgkgrw/zAbXtxkgtewUayNem7nAJI58uUnuMNWD8xU1dTbOywHXpzMAzgyZbloE4ypuA/BCDkobj1kO7gv3At1cuO+HtzYTvo8bxk8g+vrc3GxMJCQKnIKvOEQQdqI4ZNxWVIIXyoiHo/EiQK2wpa2IT8biE7Q484sPz4sKXKYqrKqAC1wSBEEQ9hEbGwuNRoPSvr7IzrFQlL6oLc3EqNWbUUSploy1+jLKN7Sysh4jMfEbTJz4Bnr1ikTjxmXRqFEAbty4AotWZpauNOX1pQFoXS0cdyPW6s1cuHwZQ99/H1VefBHu7u6oW7cuvv76a7uLzl67dg3Dhw9HpUqV4OLigqCgIAwdOhRXryonRxgMBixcuBANGjSAh4cHvL290apVK2zevFl1G4cPH0ZMTAwCAgLg6uqKqlWrYtKkScjMJPs2giAsY3G0B49akCIAXJDDF6abQGbdlXyCXCfl92Yqr50yYVNgJ9W+tfMJfXygik/244NY1pBbyAmIHCC6VDWfrZqYSBAE8QygdF60hKWRoGrna7OEbUA6OvSebF4qTBaZEkT36QDRaz4Awwdj8oO875S/d8A2y4YEBIsYE7PJMafkQMEZosSglKUktuoyGykS6shW3FVeqwVpbBUdubJnAPb44t+13sTMyk0G33Ep+U7ba31gb4dKEATxpJORkYEffvgBGo0GD9LS8NOePQ6vK9fJ3NIrX5ZmDteb4bHlAt5yUcv791OwYMFUbN/+P+TkZMPb29esjZSCqTfjphqlUua/i5fQdsgQbNu7Fy82b44xY8YgLy8PI0aMwJgxY2xez8WLF9GwYUMsXrwYNWvWxNixY9G4cWOsWLECERERuHhRaiXKGEPv3r0xevRopKenIy4uDn379sXZs2cRExODhQsXmm1jw4YNaNGiBbZv34727dtj1KhR8Pf3x/Tp09GuXTtkZ5M/G0E8qwg3TPiR7TbaalXBBcsNAqFsCaYKr4XE2sjWZDZx4pq4UrNaJ2ZBOyndVBPffLPZsk0ZqzVPVdweCIIgnmWsBqPF5QxE51Gl4Lel8zAfXJeMCrWGPEADyEbNyPsc/r1SH1UQARrxvUV32bMRX5gSJsqA669t6LPFfb9SDWohId7CtQQ55RQvFJwhShzFnwVmD3qF10rTLJAqe29FaKhlgSmh1MFJOkKx0LBSB4iywAiCeJpZu3YtHj16hDfffBNOTk5IFI12sMXSjFkZQcNTYJZmEiszSzPk9WaURs+oYTrv+/qWxsKFa7Fjxxls3nwcNWvWV2jHPztQb8YGPGxIenhjxiykP3yIxM8/x7fTpuHTTz/F8ePH0bJlSyxcuBAHDx60aVtjx47FnTt3sGDBAmzfvh2ff/45Nm7ciLVr1+LOnTsYOXKkpP2PP/6IH3/8EZGRkTh16hS+/PJLLF68GP/++y9CQkIwfvx4JCcnC+0zMzPx+uuvQ6PRYP/+/Vi1ahXmzJmDgwcPYuTIkdi/fz/mzZtn1/EhCOLJxpYbI0qOA0o3YgBI6lQCsMNxwB1F6zggpmgdBwoiqY0gCIIoRpJFr62WO1BD3ifpZa/F09T6LyXkyd46qAZmLJktBMKsDzfr44knGgrOEEWKWRaYjZTcLDD5NCjMy4RiJlgqzLPAlKL7AFmbEQRBFDLx8fHQarUYHh2Nlg0bYu+RI7hyUyFC7gPk5eXh069XILx1d7iVj0R4w+6Y+WkCDAZptIS3NPtj90G8Hfs2WlVvhTpeIajjFYLeEa3ww+JlivvSyFuD4TGtcefmdUwc0R8v1g1AVO1SGDe0E65duQQASLp0GuPHdUPbVqUR1bIU3nuvF1JS5PurVG8GkmkZGWn45JN30L59XURGhmHAgGhs377RrLWHhyeaNImCj4+fyhFUqTfDT7JWb8ZFuo9KNWYU4csgpAHnki/j92Mn0DIiAu0iI02rdnHB9OnTAQBLliyxusqsrCxs374dgYGBGD16tGTeyy+/jHr16mH79u24dOmSMH3Tpk0AgA8++ADu7iZ1FRAQgDfffBPZ2dlISEgQph84cAB3795Ft27d0LBhQ2G6RqPBjBkzAADffPON3VZsBEE8G9hdFyXUka2o3SmS32xyNEjDP6toJTXscBxQtL4BWZsRBEEUBgVhacYHza2ep604ywDg+otUa43k/Y8jo2Tk9xXl9xctOfLYYYEdKn0rdhcSuw7xUA21JwcKzhCFjsUsMOOwuqcvCwwwP8nLssCsoRT1t6UDAlmbEQRB2MN///2HQ4cOITo6GmX9/dG3UycYDAZ8v2WLYvvXJnyCCbMWwmAwYGRcL7R/oSnmzf8e496ao9h+/qdLcej3Q6jbqC4GjopDzMCX8eBeCqYOH4s5b7+nuExG2gO82q0FblxJQqdeg9GgSWvs37sNo4a2w4Xz/yDuleZ4nPkQXbrGombNCOza9SMmThwgWoOlejPcs16vx8iRvXH8+EF07NgTXbr0we3bNzBx4kisXascODLHQr/EB2WUUOhq1erNKAVqNBnm7fYcPQYAaNOkCQDANyJCmNeiRQt4enpi717rhe5SUlKQm5uLkJAQaBRusIWFhQEAdu/eLUy7deuWZJ5S+127dtnU3tfXF35+frh8+bIkAEQQBGEJi3bQasgdB8yQZ/haqjtjzQ666B0HeBx1HJBA1mYEQRACBWlpZgm5pZnd9WYALik6VTxBNkpTmMZTkPVmxInh8vIJDtSbsQGb3IiIEkV+/JwIosCpd/+0fUGEUDgwqsQdytERHaQnX63xvQ62GfzzbflnGLdjY92au+AEkg3cuRSMspWv4iqCEYyruIhwVMEFXEAVxYi5HBYlCoi9AJP3ZHtYHNW0RaMhL0qCeEZgjCEzt2QXJXfXuivePP8/e+cd50Sd/vF3stm+LAu7LB2WKoiCBVRURFBQUey9nA3LKdZDTz0Vy/3sFeup6Hn2s5yIiKgUxUZHsQCCLL3tQnaXrckmvz+SSSaTmcmk7WaX5/16+ZrJzGQSuGO+eb7P5/v5RMvUqVMBON2/2uKkUaOY9PDDvDV9Og9cOwG1lmXeD0t49b+fMGTffnw3ayq5udl428Dtd13KgUMv0Ls9D77wID169QhYmgFUux38ddwZvPP0s5xz9U106tEjJG/mj19/5vwrb+KmO5/wHXDDQ3dew4dvv8CVF43gir/ew3nn3wCN4G30ctNNJ/Hdd5+xcuUyBgxQTzAZP7PLyrbTvXsvpk79hPT0DMDLpZdex4UXHseUKf9k1KhxFBd38V+ttS5ToxyzpvlZ8sM8liyYF2zcpAFpXuyORv9LD2n+8diBm14lnTjvkjMAyHYbr6j5Y71vMrJPjx5h59LS0ujVqxe//fYbbrcbh8P4J3C7du1IS0tj/fr1eL3esP+PrVu3DoDVq1cHjhUVFQXODRw4MKrrtVRUVLB79+7Ae/r0sabwFgRh72MwKyI3Fbp5wwVeiuOAdhJLt0xK9x90EKyJ1PtmuAjWRw7VMfUzWC+zUzNh5fRvc1XHytB3TiglqpVCq9nHUBRolfH9JFhZEAQhViKOY0aUGhw3cqQxLG2TmTejxcDWDMLzZvTQRDxEdBlSUOb8/PN9MmalFtKcEVKevqwNLGnszkZrIWDaJkfYQzgb34M3G58nirbAsNqQgdACQylAtMeUz9RBsTYrUB1TFxtbiWJ1UDg/s790zgVBiJpady2Hvn1oc38NUxacv4Cc9Jy47uFyuXjjjTfIz8/nxJE+KVdeTg4nHn00/505k69+XMjYww/zXdwW/vPhDADu/vsEcnODz/WuXYu5/rpzuXvyi0DQ0gygR6/QZkE9WTgccPrVV/L9l3NY/PVcTrro4pBrcnLz+OutPnsrpSY47sTz+PDtF2hbUMi5510fiJex2WyMHXs23333GatXL1c1ZyLnzVx77e3+xoyPjh27cO65l/Pii4/yxRfTuPDCv6qutqn+Ux9Dc8y8SbPkh3m8/NS9pteoOXzkIYHmjBkVe/YAkJ+bq3s+Pz8fj8dDVVUV7doZ2bNBTk4ORx11FHPnzuX5558PyZf56KOPWL58OQBOpzNw/IQTTuDdd9/loYceYvTo0WRl+TKFysvLeeqpp8KuP+KII8jPz+fjjz9m2bJlHHhgMMPn7rvvDuyr3yMIwt7F9D/8K9ln4RNQzcFw9UZ/VgVWzvdhTUBpXNx7Y7jKuCMWfPmz8amK9VDXSVabNApGTZpaQmsl5bOzw0/psR3fn0vNJptxTimYitqWtx8YsIMJEbWpEVGbIAhCgERamkVEz1FGO65ZscAEkpc3o6CsnjERbWejv3q1iPCxDeO8GT2hge74pUHGquZHbM2ElMRSM0Hvx3YR+h1m3R/02twZ7b4ZZrkzrSfgUjyUBUFo7UybNo2dO3dy1llnkZUZXNly7oknAjD1f9NCrv/pd5/MaMTwA9Ey4sgDdD9jT9UeHpv8GCcNOZrBeT0ZZGvDIFsbJp1xLgA7t24JXlzv23Tv1Y+sHFXjyQ1Fxb5Ofd/+g8NWcxQW+s6VlSnr+iPnzaSlOdh//6FhZw84wNeUW7XqF90/Tzjqn5M6PmZK3oyfK2++h0VbvCza6WVRpZdFdR4W1XlY5q3jV28Vv3qrWOPdySbvJjZ5NzFt3tvkRJNHkACefPJJ8vLymDhxIscffzy33norp59+OmeddRaDBw8GwG4P/rnPP/98Ro0axfz589l///257rrruPrqqxk0aBD5+flh1+fl5fHEE0/gcrkYPnw4F154IZMmTeLwww/nxRdfZMCAAWHvEQSh9WNlgsTKRIshUQm+zOygtRNN8ebOQEJyZ3SszZTmlFibCYIgJIZkW5op4uy4Lc2chOc8B1DnnSnbeFbJaPNm9JoxJqtmzIhBrC15My0LWTkjNDmprwJTHqKxqMDUlmZgTeqlwoq1mV8FplibKTSVtZkgCHsH2Y5sFpy/oLm/hinZjih/2OqgWJqddsghIcdHDhtG1+Jips39hl0VFbRv2xaAiqo92O12igoLwu7Vtkv4L+eqBjtnHX0WK5auYN8D9+fUi84mr7ADDoeDjaUbmf76m7jq60MszQBy2+SH1QdpXt/PttzccH9ixabL7XahzZbRy5sBKChoj91uCzteWOgbiPbsMRovtatnFOyhu2mhh0Kslk1QZ8zo5c0EUHJnKvzXZucBUFmtX4VVVlZis9lo06ZNxO8wZMgQFi1axOTJk5k7dy5z586lb9++/Otf/8LpdHLLLbdQXFwcuN7hcDBz5kweeugh3n77bV566SXatm3LaaedxqRJk+jfv3/I9QCXX345Xbp04ZFHHmHatGk0NjYybNgwZs+ezcMPP8zKlSvD3iMIgqCgZwdt6jhQQpx20FYcByK5DyTZcaAZEWszQRCE6InZ0swIPUszJ6r+vyKaVr9WE23ejFFxo14xo5c3A1HNFZaEvlRny1mZ/xNSG2nOCE3CeK83Ynfd9nX04WABOmMaAhlKPqFFhpILo86MUXAQW/dcmzujfJYJTmKyNlNyZ8wQazNBEKLFZrPFbRmW6mzcuJEvvvgCgJOuusrwujc/ncn11/hWubRtk4fH46Gs3EmHonZ4VfP8O7aXA+BW/bz6YtoXrFi6grMuv4AHX3mKenx2Vw2k8/m7/2X662+af0m9IchUWK3XmNHH6dyFx+MJW51RXu6TJOflKU2gxOXNACxZ6M+cgYh5Mw7cdC/pxuWXjAt+YhW69OnuV0Vv2EDB0NAVQY2Njaxbt45evXqZ5s2oGTBgAO+9917Y8UsuuQSAoZrPyMzMZPLkyUyePDnk+Lx583SvB58d2gknnBB2/KKLLsJut3PQQQdZ+q6CIOy9xJw7A8ZNjZC5KsUOGhKfO6Ns9RwHVJ+vvCzQXKaul9TWZqXEnTsj1maCIAjWaFJLM4VS1X5EkbYRLs2+dkxzGxyLFfWcYLjYLpA3Y4SJVaeCzPu1TKQ5I6QcrVMFBoaraJz+bRMFXKoLDSvoqcCk0BAEoTXw73//G4/Hw5FHHkmvwsKw83avm9c/mcHU/30SaM4MGdiPpb+sZP4Pyzh9fOiyz+/nLw27R+naUgCOPSV8An7Z/O9CD9QbfFG9GsCfN+PDo36hg37eTGOjmxUrFjNkSOiqoeXLFwKwzz7qyT51g0Y7wWc9bwZgyffzePmJ6DJnLr9kHNluk1U0wBH+RsbcBQvQ3v3bb7+lurqakSNjVYH4qKqqYvr06RQWFjJmzBhL73nrrbcAOPfccy1d/91331FaWsq4ceNo61+xJQiCEA1qxwFdlCaGUWgyEOo4kE5owyQRuTNN7zigiNrMHAciitrUjgOCIAh7Gc1uaaYnNNCyE/0IgRDUY5pbZ9+lORbNOKdglDeTHbpbQOh8oFFcg4Y+rAns6+XNBFDGLL+YQFZ5ph7SnBFSlr1GBdbEAZdqxNpMEIS9Fa/Xy2uvvYbNZuOZv/2Nkm7dQs4X+ONnVq/fwA8/rWDxz78xdPC+XHT6OF57fzr3PfoKx40eTk4b30N87fYKnn/6rZB71JNFt56++y75dgHHjA9WK4u//oaPXn7V96JRZxzTPuKjqgfUeTPaG4W+fu65h3juuXdJT/cVDdu3b+Hdd18hIyOTsWNP0bw3UjGkkzejw5WT7uHK2+8BJeInw4sjs4EM/x8ykzqy/J2qTOos5830Kynh8AMPZP7ixcycOTOwIqWhoYG77roLgAkTJoS8p6ysjLKyMoqKiigqCv54qK2tJT09PWSVTX19PZdffjm7du3i6aefJisrK+RelZWVgXwZhQ8++IBXX32VYcOGcfrpp0e8fsuWLUyYMAGHw8H9999v6c8tCELrJVY7aDVhdtAJcxyAyAI2K2gdB0xCkxWcxOQ4kGzE2kwQBME6cVuameXNgM8K06k9qMqBBhKfNwMJy5vpSNh4po42iISVjDoRXacG0pwRWhQtSwWmbJX7Wyg0wFgFplNoGKnA9BBrM0EQhCBz5sxh3bp1jBw5Mqwxo+bSU8bzw08rmPreJwwdvC+jjhvKpeeP57W3p7P/Uedy6ilHU1/v4r/vf8WwwwYz89PQX8Fjxo+hW0kPXnrkGVb+soq+++3Ln6vWMv/Tzzj6lNOY/dEH1r90YzR/Qq/Ofui2qKgjdXU1nHfeMYwYMYba2hq++mo6FRW7mTTpfoqLQwedp56ajNO5C7Cxdu1vADz99J1kZ+cBNk49dQIHHHCU72IlbyaNcKtlNRlxFgT+vBmnf9XRY3//O8dPmMCpp57KOeecQ+fOnZkxYwa//vorEydO5PDDDw95+7PPPsu9997L5MmTueeeewLHlyxZwumnn86YMWPo3r07lZWVzJgxgw0bNnDFFVdw3XXXhX2VQw89lO7duzNw4ECysrJYuHAh8+bNo3fv3rz//vukpYU2r6ZMmcKbb77JkUceSXFxMRs3bmTatGnU1NQwdepUsTQThL2UpNpBl5BAxwFtkwaC9U+Scmec/m0zOA6ItZkgCHs7emNTs1uaqTGaB6wN29G8dqm28ebNKEWPOm9GITF5M2r0RNnROOUIqYF1g3BBSCABRZHyQ9ZkabjR8rywjnFUCim1UlV5KKq73OqHZjw9TK2HpQl6XX3TJpM19AY6vQHRDL0BN+JyVkEQhBRm6tSpQDA/RI2yagbgnOPHkJ2VyTufzKK2zmer9fLT/+DBu6/FZrPx3PPv8/ms75l481945Km/h90rNy+XN+Z8xHFnnMQvi5byzrMvsXPLVv7vrX9z1lXXhH8x7Qp6PcLmeEI8zrCSNwOQnp7Os8++y0EHHcZnn33I9OnvUVzcmX/+8znOOWeC5mobc+ZMZ8aMd5kx4x127vRJ1ebM+ZgZM95kxow32LjRLw4w+3WpM6Q6MhsC+5nU6e6HYZA7M7BPHxYuXszJJ5/MjBkzePrpp7Hb7Tz33HNMmTLF5IuF0qNHD44++mjmz5/Pk08+yTvvvEPfvn354IMPeOmll7DpjIHnnHMO27Zt47XXXmPKlCls376dO++8k2XLltGzZ8+w6w8//HC6d+/O9OnTeeyxx/jqq68YN24cixYt4uKLL7b8XQVB2HvRm4BRT9REyqUEfI2NDgRtVQrQmTMymkSyWicpg5tb55jeOQiGNtcGX0ZCL3fA77IQsnqIoGWOYqETFQarlwRBEPZqEmRpZoiea472ub/T2meFz89p5+7cRNeoMUM7VurkzUSB0dguYuyWi6ycEZqMlqMCiyZ3Rm1tBuEqMAh9mGtszdSHCjS3jhRwKdZmgiAIMfP222/z9ttv41y82PS6/K551Kz6NuRYWloat914CX+/6xIAah1Be6udXl+Dop7gse69evLcB68FjjX4xwd3fQaL6jzQYAvJm1m0wf9sdxMyhHTpWsKi5V6dvBk4+OCRLFpUg++k2s5MP2/mk08WBY7dcccj3HHHI/4zenkyNv97luHrvNj9x5SlMcoxiytELV6mh82gKaNmn3324f3337d0v3vuuSdkxYxCjx49+O9//xvVdzO6lxGjR49m9GiZ4RMEIX5itoMuIgbHAaVeSpbjQIy5MxatzRTHATPicRyQvE5BEITImI1ZhnkzaswszZxmn1xL5LwZLUYiAjO0eTMGtmYFBIURHTDOmzGZ+xNaPrJyRkhJmkcF5jDYjwZ1UJi6QNGTepnIvyx3+4MDlTJwmdm+xePpGe1yVUEQhFZHm+Cut43xZWrq/MEq6mYN+BozgK8xEwnL811m+TL6TZpIq2uC6DVt1FjLmwmQGfoyQ/WHzFJ3qvxku01W0QiCIAhhqK2OdT3qtZmWhpg5Dmj340EZB2qJyXFAj9LovkEkxwFDEeFxBscFQRBaCS3O0kwZJ0IszSo1F6nzZmJpwEB4zkwUeTNmWgSdMVo9lqvHeCO3ISAovvYLryUXLTWR5oyQ0lhSLOl1kPU6zSGon4Lah2ciLM3cmi1YW4uPubWZ5fDOUGK2NotQaIi1mSAIrY2CzMjXGFFjQe3bYDaJZWZpZpo340Hf1ky7j4XjZgRX0fiwa7aaU+q8GT0s5M3kWB07BUEQWilGdtB6+SemEzRqorKDhtDayaiOUj/w9cY6I7tno0aMom6uDD3k1LlUXS+ZWJsZIdZmgiAIcZIKlmZRoc2b0e5HQk/crR4DE5A3E/VYHUQ3I02DrOhMHaQ5I7Q4Wo4KTG8FjQvTJk215rWRzUCptW9gtdAIGTCl0BAEYS8hkqUZbSPfQ21ppqBdJROR8IUiPrQNGg86lmbaro3Wzkx9zCqRVsmomzSan5JKU0aLko1pgmnGjB4Vvo3T6O9PEAShhRLNhEnUjgMlBjfSOg6EoZ1M0k5GWUUtYtPLnWmdjgMiahMEQfCRUEszNTuxtqoybJWMS2dfmzdjdUWNdkyMM2+mJPSlekzXizHQ+00gpD7SnBGaDSMVmB5NqwJTvCEVrKrAjCzNlHNqdAIunRG+qoWASyk0BEEQkkQclmaG6FmaaYeLaARcga6NlZUzepg1YWyEr5xRY/EnpWYYdWQ26F6m16gJ5M2Y5M4UDB1q7XsIgiDsjRg5Dui5DugKe5WHuFltFC0ty3FArM0EQdjbSBlLMz30BM1a0XMAbd4MRG66RFWM+dHOKRrYmkF43oweFvJmYs1JE1IDac4ITYoVFZiy/K75VGAK8VqaQeTcGRPUKrAEFBp6JMLaTBAEoSWit2omWZZmykoaxdIskDejJmZLMwXtKhntyplk5M0o2/jyZtRI3owgCEJs6E3MqB0HdLHkOJCNNceBWGsniNpxwEmzOQ6EII4DgiAIybc0UyhV7avFy3rPfycW82bAN95oV8pYIVLeTDqmeTMFOrcsIuLYHHFs1yJ5MymPNGeE1kFSVWBNnDvThAGXasTaTBAEQUVTWZpZxWN2QnsyUtMlyXkzEFW/JiMmRZogCMLeQ5jjgAlGjgNhdtBxOQ6oaQbHgUiI44AgCELSiHbVjJqYLM0i5IYBUVhcqhv/ejWIlQZNtHkzamLLm9GNdEB/zA/kzZi4E0neTGohzRkh5YlpeV5CVWDRohcopqcC01QWcQZcar04lYFNT4UghYYgCEIUJNPSTL1QJJKtseW8GWVfb7VMJKzkzRig5M0ovy6VukSvNsmI/L1yol1xKgiCsLfgn3CxEvirS0msH6xVAMeaO6OQmo4DUVmbqUVt4jggCEIrI+Jcj4XnniXHFquon+vaFTPOSG9Wjy1GeTPaYxB/3owyrxhf3kwkJG+m5SLNGSE1iFIFpl7Gl1gVmHo/VhWY9pjeOQWVCkwPo+6/SaGhqAysItZmgiDsbaScpZkRlheUeNHPm8HkmBYreTPaa6P4Gakzb6fOm1FnzOjlzQRQ8mYqfBtnuAuaIAhCqyAaVWvUdtB6FBFuB53quTMJdhyIydpMEARB8BGFpVnUeTOlBse1DRplXAizNFOPKcp+M+fNQDBvxgLqsVw9xgstn1bdnPnf//7HmDFjKCwsJCsri169enHeeeexcaOFH6dCkxDmedjsKjCFRKnAtFu9ADIDnJrXRj7KUSDWZoIgCBZoakuzuPJm1JZmRnkz6q322khomzLavBmw/HMyjvk7W1XkawqGDo39AwRhL0fqptaBJccBrR10R4xDiANEa2lmleZ1HBBrM0EQBGP0nleJtjRTmuJRWZrpPd8to52PizdvxuprFUreTK7qmFEsg16Eg4aY3IaElKJVNme8Xi9XXXUVp59+OuvWrePcc8/lxhtvZMSIEXz//fesX7++ub/iXs3epQIzIkIH3izgUj0Qlfq3Bh6cYm0mCIIQJ8mwNFMwWvWhrQ0UB7MwSzOjxowZsebNqFH/fIwiXAbQ/tWo82aydP5Cst0mq2gEQYgbqZtaP4lzHNCzg1ZomY4DZoi1mSAIQgRSxdJsJ9ZWUYY1/yNlnumhzZvRjnlRrJrRQyeiQT12q8d0o4w5IJg343cpChPGCylFrMsCUpopU6bw0ksvcc011zBlyhTS0kInDtzuaDuiQnMzmBWRmwndvOFNiiIirDjJxvdDX/26Ft8DVfn/ifJwjWZJo4vgPy+X/x7qYxBeXGQHDxVE8VEqdvzZneLeG9lId2uNKz/L2w+M7E95HJbs5wRBEFKZlLA0a9BpZmt/mkS1il5raZaMvBmttRmEaXyUvJk0zAXUkjcjCCmD1E0th+l/+AVSs/D9Lp+D4Ur3/qyK3jLGEsqY5/LvV+F72Mdi/aLg9t/DTbAGSydYGFkQzDkJrZ/KCFcgl2LJZWENfRJqFzO+X/ik2HSbTcKYBUFoXTSHpZkWrcg5gHYFpl6eTCLzZtShm3rNmghEHdUQxIoLkYw/qUerWzlTW1vLvffeS+/evXn66afDCgwAh6NV9qT2OkxVYKDbcdbHLJQrGSqw1Au4VCPWZoIg7NWoLc0MVs2kjqUZRLYrizdvxuhazU9Io1+U6trEANOMGUEQkobUTamPlQkUq3bQYcKtEoMLLTkOKKhro2T9fyVBjgMKGmszhWRZmwmCILREmsPSLIxIlmZ6QmwnOnkzatR5M2bjSyLzZsB83lGHktCXkfJmIoqthZSm1TVnvvjiC3bv3s2pp55KY2MjH330EQ899BAvvvgia9asiXwDockJKImUVRlzjK6MsGxPTdSd5mzClxwmK3cGmiPgcldpO3p+thGbK2iDI9ZmgiDsTSR61YwecVuaafGYndCzNVPvJyJvRvsfJCJvxpHZoHuZaaNGyZ2p8G2cVv8OBUHQReqm1kUkO+gQ9DzsjfzuA6gnmYxWs1ht0uhlzaj3tcplRfWssjbTy50xsjYzQck30EOszQRBEAxIgqVZWN6MFaJ67ivzcNpxJ9l5MxqVQwFBAUQHjMdeyZvZa2h1UqglS5YAkJaWxuDBg1m9enXgnN1u56abbuKxxx4zfH99fT319cFqv7JS22UVmgrb15GXQ4Kvgxzyo7oE/YZFB51jYT0SZQm9sjxfsSOLFz1rM+W+OnI0J+HWZurl+dsJXxm0yQbdvGHWZq6VmZz30n/Z9/OV2D1eVl7Ul8p/5IR9pFibCYIgRIfW0mxj6UaG9xrO6Refw/3/fgWIwdJMrzbQzZtRo7UzMyJ47uSTDwXgk08WmFyvR+LyZqxiq4p8TcHQobHdXBD2YuKpm6RmSl0s2UFr6YjFgOV8ggWU1g5abQsNoVZlRvWUYmkGwTpJXS/VYtkGxomxtdlW4rKJiRexNhMEoVWTTEszBaO8GbAgaq4l3OnGpbMfqVGTgLwZoyGtiIjuP2r3IKF10epWzuzYsQOAJ554grZt27Jw4UKqqqr45ptv6N+/P48//jgvvPCC4fsffPBB2rZtG/ive/coOraCZaL5Idq0KjA1RpZmZqgf+FpLM1Tn1OiowLTEEHC577LfuPnqKfzfqfey32e/Y/f4/m4GvLGG3YvbG78R69ZmsnpGEITWxGX334et5zAKhxxLfUZwZUeTWJpZOa6LNm9G2U9U3oxvxUxdXS1vvvkMd955CWeeeQCHHJLNsGFpbNlS6rvMat6MigzVHzRLZylRtjs2u7PVq1dz9tlnU1RURHZ2NkOGDOGFF17AG+Uk2KZNm7jqqqvo0aMHGRkZdOnShUsvvZSNG/Uz3TweD88++ywHHXQQOTk55Ofnc9RRR/HJJ58YfsaCBQs45ZRTKCoqIjMzk379+nH33XdTW6u/wnb37t1MmjSJvn37kpmZSYcOHTjzzDP59ddfo/qzCQLEVzdJzZT6qB0HTO2gY3Ic0O5rH/oWcmJCSIDjgBVK/Vs9qxyC1jqK1Y6aqBtesnpGEIQWSlNamhmiPKdLVcfMLM0U55mQIaNSc0BtaWaGXmSBGcnLm9GNcCCCq5DiRuQXVmvFAULq0eqaMx6Pb5IiIyODjz/+mGHDhpGXl8eIESN4//33sdvtPP7444bvv/3226moqAj8Z1SAC01PTMv1Ysqd0T40Y2nSKOjlz0BUuTNOzWs9X83S4O7wZd/z5F9u4fnzbuLAeT/jsdtYMW4gL3x8KWuO7AVA3w+DbxAPZUEQ9gYiWZpVVVfz3xlfYbPZ2OWs4OMZ85L3ZfTsuPTyZgwtzfCf1LMv0+5HakhYy5vZtWsnTz/9D2bNep+Ghjry89sFT5rlzWjJiNwgyYk2m03Fb7/9xiGHHMK0adM44YQTuP7662lsbOSaa67h+uuvt3yftWvXcvDBB/PSSy8xcOBAbrjhBg455BBef/11hg4dytq1oSIRr9fL2WefzXXXXUdlZSWXX3455557LqtWreKUU07h2WefDfuMjz76iCOPPJJZs2Zx3HHHMXHiRAoLC7n//vsZM2ZMyKoEgPLycg499FAef/xxiouLmThxImPGjGH69OkccsghLFgQ7SooYW8nnrpJaqbmIRo7aEuUGBzX5s6Eoa6XtJNS8aIncHOhWz/pWZtB5NwZP4p1TlKszQRBEForCbQ0U5riibc0044Z2tcuYrM0M0I7/iUub8YqVjLoZMVmatLqmjNt2/qSfIcOHUqXLl1Czu2333707t2btWvX4nQ6dd+fmZlJfn5+yH9CatG8KjA10SjCjFRgiQu4zKmr5oH7b2fqPyZw4MKfcaXZmHHqSG797J888sTfqC7Mpedi39/RhrFddT/OcAAVD2VBEFo57339JdU1tdx0+XnY7Xamvmm82gHCLc0A6vFZlzUa2X0lzNJML2sm0g9tq9cphGbNFBQU8uyz0/jqq0188slqBg60aCOmzcUkNG9GnTFjmjdjkb/+9a9UVFTw8ccf88Ybb/Dwww+zdOlSRowYwbPPPssPP/xg6T433HADO3bs4Omnn2bWrFk8+uijfPzxx7z33nvs2LGDa6+9NuT6Dz/8kA8//JAjjjiCFStW8Mwzz/DSSy/x66+/0rNnTyZNmkRpaWng+traWq6++mpsNhvfffcdb731Fo8//jg//PAD1157Ld999x1PPvlkyGdMnjyZP/74g5tvvpnvv/+exx9/nLfffpt58+ZRX1/PZZddFphsFwQrxFM3Sc3UdFiZSFEmZCI5DqgDhXUxchzIDtvB2GM/ntyZKB0H9FBP1qlrJxPHgaZAHAcEQWgJRL1qJuUtzcA45ywWzPJltNZmEHPejAr12K3nIhQxnkBIeVpdc2affXz/2AsKCnTPK8eN7CKE5qHZVGBhc2t6KrBoV8uosVJ0xBlwuRUGblvCR/cfz6k/TMNjg5kH27j+ajvPXD6U7SW+5UNHvvwj6XVudg5pz+ajO0ccGK3k/YAUGoIgtA6mvvcJDkcat179F0aNOJjZ3yxi/catupZmjY2NTHn4XwzrO5puWftyRN8jePbBZ8MmxpW8mR9nfcd9V17GGQftw1Gd8ziqZx5/OXYoH735UvgXccGwQTauuuxodmzfzJ13nM+xxxYxcmQbbrzxJDZt+hOAdetWMmnS2RxzTA9GjuzK3/9+MeXl24nW0qyqqoIHHriV444bzBFHlHDBBccya9ZHYdfl5LTh0ENH07atuS1mognkzSjbCt/GqbP6aPXq1XzzzTeMGjWKE044IXA8IyOD+++/H4CXX3454mfW1dUxa9YsOnbsyHXXXRdy7qyzzuKAAw5g1qxZ/Pnnn4Hj06ZNA+COO+4gOzv4W6KoqIibbrqJ+vp6XnvttcDx77//np07d3Lqqady8MEHB/+8Nhv//Oc/AXjxxRdDrNimTZuG3W7n3nvvDflOw4cPZ/z48fz22298/bUFyZwg+JG6qfViyXFAawcdl+NAumarYKWG0hOxJchxQI9SwOMBg6aXWJsJgiAkjpgszRRKLX6IobWlMtemkKy8GYg7b0aLXmSDhpjchYSUpNU1Z0aNGgXA77+Hdw5dLhdr1qwhNzeXDh300uGFpiQlVGABzFRg6uPR5s5oj0XrXenHqXmtqAW8Hi5fcyvvvn4hPXaWUdYG7rkgjdfGplGeb8NVvC8A7bbtYui7ywCYecNY0DRPpNAQBKE1E8nS7Le1f/LjshWMHXEYHTsU8pdzTsTj8fDaW9N17zfxynu5/7ZH8Xq8XHztJRx93NG89MRL3H/DP3Svf/3xR1j27Xz2PWgYZ105kRPOuhBneRkP3noVT97/N90hoapyN1dceiRbtqzjxBMv5qCDjua772YyceI41qz5lcsvH01NTTXjx1/EwIEHMGfOJ9x55xVR/b24XC6uvfYcli79gXHjzmT8+HPZvn0Ld975V95772VCV89A8Gejzs9HJW/GiMzQl4nMmykY6lvFM2/ePADGjh0bds2RRx5Jbm6upeZFeXk5brebnj17YtMRG/Tq5bMInTt3buDYtm3bQs7pXT9nzhxL1xcUFNCuXTvWr18f0gDatm0bRUVF5OXlWfoMQYiE1E17F82bO5NkxwEn1hwHvB7yy9bS/YtZtPviN8CatZkeibI2E1GbIAgtCvVckMHzLm5LM71cMLO8GdDkzSjiZzXJzJtJV+3HmDejI5BQj9XqMdw0b0ZokSTCFDal6NOnD2PHjuWLL77glVdeYcKECYFzDz30EE6nkwsvvBCHo9X90QWFbt7Qh3lHTL2Gg+QTfGBn+/fT8T2UlS0EH7LK8UjLIt34/qmp75Ou+qwIxUo1vpU+CmVAERS4N/Po/As5Yq1vgmdBv1zuP/9inHnt6bLrIeqzetCY35WN2LjkxTdwNDRSOqw7fw7vSV/+1PskwDeQKs0w70iVb+Vo4l/RJAhCi8Hr9eJNcbW0LTtbd/I8WqbO8K16uOj0cQCcftIorrnlYV57ezp33T8Buz3YjPhm3iL+8+r/GDRkIDO+ew9Hrm8lyRV3/I3xB4zSvf9tU56na9fevhf+PoS71s2NF47jvVef5ry/3ECnjj1C3vPH6p85/4KbuOnmJwIuZg899Fc+/PBFrrzyGK644g7OO++vgAev18NNN53Nd999ycqVPzFgwP7or6AJfV1Wtp3u3Xsxdep00tMzABuXXnojF154LFOm3MuoUeMpLu7mv1rboFG9VDdlFB2DSd7M4nlfs8zfSEnDjcM/vjpwk67se3xbWz2U9OjMJePH69wwlD/+8C3D7dcvfElnWloavXr14rfffsPtdpv+DmzXrh1paWmsX78er9cb9v+xdevWAb6VOgpFRUWBcwMHDozqei0VFRXs3r078J4+ffoE3rNjxw727NkT1qDR+wxBiITUTa2f/qyK3UZGIReDFSnZBOsgh8F+NLgIDh4ugnWW+v9/2t8l2b5D2nmvnfgcE1Q4GvdQtPonMhudAKQ1WP+Oq9knIZNh4/tJMLMgCKlL1JZmKprc0sw0b0aLMnaoRdOJzJsxIwrb16iFEkEC83bKnJ3flUjGnJZBq/yl/fzzz3P44YdzxRVX8PHHHzNgwACWLVvGnDlz6NmzJ48++mhzf0UhRgazImyVRx/WBDruxb03hgaIdSZKf2GlKaPsVxF7gWGGtvhQPk+DE90QzsMaP+DhH+6hQ1UjrjR4bvQoXjp4CnRJo8cOX+BxVfujwGaj369/MPKDbwGYc8NRIatmklloTLfZJGxMEFow3tpaVh10cOQLm5F9li7BlpNjeo3eqhk1LpebN/43k/w2uZw6diS0gTxyOO3Eo3nzvzP5avZCxo45LGBp9vZ/fFk0k+6eSG5uTmDNR6eunbnkhit48q6HgKClmbs+g669ekFD6Oc6HA5Ov+BqFnzzJYt/nMtJp1wccj4nJ4+//vWfIXkzxx13Nh9++CJt27bn3HOvQTlps8HYsafz3Xdfsnr1r/7mjBrjZ/G1197mb8z46NixC+eeewUvvvgQX3zxPy688HrV1TaDfQMMfmUunvcN/7r3wcjv9zPyiIMsNWcqKnyeZ0qOhpb8/Hw8Hg9VVVW0a9fO8D45OTkcddRRzJ07l+effz4kX+ajjz5i+fLlACE5HCeccALvvvsuDz30EKNHjyYry/f/l/Lycp566qmw64844gjy8/P5+OOPWbZsGQceeGDg3N133x3Y137Ga6+9xr333hvyW3bBggV8+umnYdcLghWkbmp5TP/DP1k2C5+CeQ4w2jcxY8WWuDsbQ1eJlBBuHaN2HHASrEfCNBvpmoNq4ZqVGkppwihiNuWYQ7WNwXGgQPW6zEub4nUUVK7Crsps29M98kqZNfQJ8/f/mf2js5I5jqB1tyAIQksmSgeVpFuaOTGxNAP9VZeRss3MMMqbUVbPKK8NbM0gNG9Gj5LQl4nMm5H5udSl1dmagU8FtnjxYi655BKWLFnClClT+OOPP7j22mtZuHAhnTp1au6vKMRJ3A2FAoIPxTD0cme0+5EwChxT70cIuNQRrNt3uLgp7UpeXnoXHaoa2dzOwSWnPMBL3Z4Hu4N2e/5Hfs08vDYHzo6nkL2nhsk3P4DD3ciiMQexYaivCBEPZUEQ9mbUlmbT5n7NzvLdnDXuWLKygif+cs6JALz62rSQ9674yTf+HDZiKPVkhZwbOuIw3c+rrqriX/83mfMPH8JRPfMY1sHGsC42/n7VGQDs3L7Fd6FqiOjeox9Z2aGNp6Ii3++Xvn33w2ZTflz7toWFxQCUlW0z+6OHkJbmYP/9h4YdP+AA359j1Spl8knbiDHzL9OgqWEcmQ1cfc9d/OqtYo13J2u8O9nk3cRO7xp2etdQ7V1BtXcFHtciPK5FeHctYt7b//K92SRvJtE8+eST5OXlMXHiRI4//nhuvfVWTj/9dM466ywGDx4MELKi6vzzz2fUqFHMnz+f/fffn+uuu46rr76aQYMGBYLS1dfn5eXxxBNP4HK5GD58OBdeeCGTJk3i8MMP58UXX2TAgAFh77nvvvvo3Lkzjz32GEceeSSTJk3iggsu4KijjmLfffcNu14QrCB1U8sgmgmVSHbQIeh52kcMJ84mcu6M+uEfTQ2ltTRTttrcAAM0k3SOymo6Fv5Ae8fv2PHg8Y9froxc6goLA24LWmszRfgXiViszSSvUxCEVkHKWppByLxaALfOvnoljd55hUh5M2pM8mYKNJcaxS9I3sxeR6tcOQPQvXv3kOBVIfVpUhWYegmkerFMACMVmFJoWF1Jo7Y0g1AVGOivw9enY/0qnmx3CQf84QTgqz5duKP/G+xp2wWATO+fdC7zKZG3F1xHXW5/br73Orqt38L2zh2Yev9fqCYvxKvSiFiszWT1jCC0LmzZ2eyzdElzfw1TbNnWnp9mTJ3ua7785YxxIcePGTmMrl2LmfbJN+zaVUF2sa8R46yoxm63U1jUPrAYps4fqFLUsTjs/q6GBq4eO4qVy5ayz/4HMu7si2jbtpC0NAdbNpQy44PXcTWoug2Nvk1ubvgS+LQ0h/9cG80Zb8B2yO12YbZSRk1BQXvsdm2jxRZo9OzZU4mlvJlIZEa+RA9bVXTXKytmlBU0WiorK7HZbLRpo/37C2fIkCEsWrSIyZMnM3fuXObOnUvfvn3517/+hdPp5JZbbqG4OPi/t8PhYObMmTz00EO8/fbbvPTSS7Rt25bTTjuNSZMm0b9//5DrAS6//HK6dOnCI488wrRp02hsbGTYsGHMnj2bhx9+mJUrV4a8p1u3boHvNHPmTBYuXEj37t257777KCkp4dxzzw37DEGwgtRNrRM9x4GIWLaDBmuOA+qVNFbsoPWw4DigfI0C5YCXNoWlFHReid3uweNNY3fjvrTJXE9GYyVV7XqEZXBGIlGOA4IgCKlIyluaqTG1NNNr5KvzZszGoWjHKPWKGfX0evx5M2qszOEJLZtW25wRWgbjvV7LSqEDdv0e1oXvy1r9ZZHa3BnwNWb0uuwBlIemldwZiGmZfVjujPbzdHDCsd1e5n7XU7Td7qEuHR7rfxpv5/6fr6jYDraO9XTjFuzUsSfrMMraXsIps6cxZvocGtPs3P/47VQXhAcIgxQagiDoY7PZIlqGpTp6lmbqVTMbt23ji/kLABh59lWG93nt3a+45voLAMhv2waPx0N52S7adOgact3W7b6mgMffwHDXZ/D19PdZuWwpp/zlcu58/BXfhX5x1hefvMuMD15XWZdp8Kh3Ggm/0Guwb+U1OJ278Hg8/tUWNpRGTHn5DgDy8pQGkUHejHJIyZ0xE5FlBD9/+bw5LJo3PyxvBiAdd0jeDEBJcWcuOUvf1qxgaHDlj5I1o2TPqGlsbGTdunX06tXLcn7GgAEDeO+998KOX3LJJQAMHRq66igzM5PJkyczefLkkOPz/Pk62uvBZ1V2wgknhB2/6KKLsNvtHHTQQSHHu3btyiuvvBJ2/T333GP4GYIg7N2oc2cSawcNobkzahKZOwP6jgPq76D65JwaCg/6iax2u3xX1hdS7hxMWvsGMhor8WKnum03n4ivJPI3EmszQRAEmt/STDuX57RyIxfheTPKfjLyZrQ1Rux5M8W9N+peZmnuTsabFoc0Z4QWR8tXgVkPuEyvr+W2Iy7jvB3LAfizMJObs55gde1ony2bn05tniCb1bhpz6YOD9BrWyl3PXc/AK9ddxG/HDQoxKsSpNAQBEH497RP8Xg8HDnsAPbp3TNE5ORNh0a3m9ffmMHrUz8KNGcGDRnIz0t/Zf78nxh3emhzZsn878M+Y9OfvufsUSee4jugGi6WLZwffBH1HJba1sxspYzxucZGNytWLGbIkENUR20sX+5rWO2zz+Bov5QPB6a/MBfNm8/z0WTOHHaQYXMm5LqRPtneF198wW233RZy7ttvv6W6ujpwTaxUVVUxffp0CgsLGTNmjKX3vPXWWwCce+65lq7/7rvvKC0tZdy4cYb5OWoaGxt59913cTgcnHHGGZY+QxCEVoLGcSBmStCfDNPzxDd0HFDqpWTkzigfbKZA9pLXbwPthv6OPb0RT2Mau7cNYI+tJ2CjwO5r3FdndMbjCGatsckG3bzs+LM7xb19Tgzd2cha+orjgCAIewVRr5pRjTfqVTNNbmlWjUmTRj1YqW3M9PatkIC8mUiURP8WUI07mjFHO84IqYsYUwutAvUP57AOs6YDbY1E5s6ol05qt8aDQUnRUv57yIhAY+bjHn05o3weqytVI2EZtMmfS2Hu2wBs4p/YS/N5/OWbyamv5cchh/L2FeeE3Fc8lAVBEMDr9fLaJ9Ox2Wy8/sRkXnnkTl6ZovrvpTt57dV7GH7Y/vzy82qWLv6VGrI5+yJfk+Wp+56ipromYGm2bfNW3nz6+bDP6dyjJwA/ffttyPElP37Nx++8HHpxo+Vvr9nGznPPPYjL1RB4vX37Ft599yUyMjIZO/YMgqtmlG0aln8+6uTNAFx7zx2W82a8uxYx771/Wfq4ffbZh6OOOoq5c+cyc+bMwPGGhgbuuusuACZMmBDynrKyMlauXElZWWi1V1tbi9sdOkbX19dz+eWXs2vXLu6++26yskIzhyorK8O+0wcffMCrr77KsGHDOP300yNev2XLFiZMmIDD4eD+++8POedyuaitDZ0V9Xg8TJo0iVWrVnHdddfRpUuXsHsKgtD6MJpwCUzQ6KAWZYWItoxyZ0yzZ7STTk2ROwP64c6Q1qaW4lMXUjj8F+zpjdTtbM/WBSPYU14CZTbsNhc5dl++257MHlF8l1BituYByesUBGGvJK7nph56lma16h1t3oza0kyPJsybUXKvO2A8xqrGZPVYrZcdp5cxZ4Q0/1MbWTkjpCYWVGDqJfqGlGCcO6MlabkzCmo1mN4HBwuYYwe/wAO8Qt5uL1WZNu7Pv5RPf7klZLUMZeDoup2u3X0TTmXVf2FP3gju/Oo+BmxaRXmb9tx6yyOUre8QogLTQ6zNBEFobUSyNJuzYBHrNm5h5GEH0btHN1DFkHhV+xdcdjo//LiC16d+xINDh3LkqOGcd+mZvPPaBxy7/7Ece9qJNNTXM+O9aQw+bBhff/o54LM0Axgx5mS69CzhP888wtrff6FP//1Yv2YV387+lKPHnsbsmR+EfkkjizM86NuYaVfOWP/hXVTUkbq6Gs47bzQjRoyltraWr776hIqKXUya9DDFxcrKIF8z5qmnbsfpLAdsrF3rW2X59JOTyM7JAxuces4EDjjsyNAPiTFvJoCSO+OPkXHWG14JwPPPP88RRxzBqaeeyjnnnEPnzp2ZMWMGv/76KxMnTuTwww8Puf7ZZ5/l3nvvZfLkyQFrMIAlS5Zw+umnM2bMGLp3705lZSUzZsxgw4YNXHHFFVx33XVhn33ooYfSvXt3Bg4cSFZWFgsXLmTevHn07t2b999/n7S00HyfKVOm8Oabb3LkkUdSXFzMxo0bmTZtGjU1NUydOjXM0mz79u0MGjSIsWPH0qtXLxoaGpg1axYrV67kxBNP5MEHra9GEgSh5RGvHbRlLDsOKHYtzek44CV3/y20P2YN9kw3Hrcd57IBVG0ogQKb77IOkJuzCbvNQ4OnDfVp7Xz2bVEI+GJ2HDBZPaOHrJ4RBCElaUpLMwW1zaZpPIEZ2kk+F8YrZpKRNxOBIiLmzegRlduN0CKQlTNCStEkKjBTsgn1hdRTgamJpb+ppwIDcIFtD9ePvImnGl4mr97LyvY5nLn9DT5dckvYXWz2errvNwmHo4LamoFsr7qBsT/M4vyl7wBw+6UPUdZez4sgiN4AGbVlnIWBWlbPCIKQakz99BMALjnzJNPrzjjneLKzs3j/nZnU1tYB8NDLj3Pbg7dhs9l449mpfD1zDhffPJHbnno47P05eXk8P30Oo086g9+WLeK/rz3Lzu1buO/ptzjromt9F+k1ZAJDmJI3o4fRBE6kvBnfMzk9PZ1nn32Pgw46nM8++4Dp09+huLgz//znS5xzzpUh14KdOXM+YsaM/zBjxuvs3OlTIM/56kNmfPI6M6a9zsZNke1fMlSFTxbhnZZst+/v2FYVdiqMAp18lUGDBrFgwQJOPvlkZsyYwdNPP43dbue5555jypQpkW/qp0ePHhx99NHMnz+fJ598knfeeYe+ffvywQcf8NJLL2HTGdfOOecctm3bxmuvvcaUKVPYvn07d955J8uWLaNnz55h1x9++OF0796d6dOn89hjj/HVV18xbtw4Fi1axMUXXxx2fdu2bTnllFNYunQpzzzzDK+++irt2rXj5Zdf5pNPPiEzM95OmCAIrQ29CZzkOg6oicVxQL1v7jiQlueiw5lrKBq3Enumm7rNBWx9ewRVv/cCbCqrGy9tctYDUOXp6cvsVCj1b/1WOoq1jmK1E7fjgAlWg7YFQRCSSXNbmgVQLM1KLdzEic/SzBS9cSWZ6OXNWLQ2M8mbsWKvKbR8bF6vSDPMqKyspG3btlRUVJCfH0WYkxAV6gEhMBAok/7+h7/y4F/efmCggaCsnFlDn5CHfIhnZSnBrvt2fF33MnzLIZ0EfSrDlkLuIhggVqXaV5bTuwn3rNSi9qEE38NZ6azn+LfZQDY5ubU8dsgdHL1lNwAz2vfjjoXv4GrMDV8GiZdu42+joOdnNLrasHbNO/Ty1vLWjAvIcdcw9bDLePyyWwINKuXhrm5eKQ95pbmlXjmjLuTUSyVDmmRqFZgmd0avySYqMEFIPerq6gIh6VqLppaM3qoZCF05gzrKo03odeqVM7UO399LjerHdT2+Y3WqZSHKsQb/c19ZOUODjUAPQj2/pO7Tuwj2X5ReTKBhoz6g7CurZTyq/7ya/9DZKmitymz+/+yqrXY/LXSrPaQMbcpwp4jJMoEM3+c7MhsCzZlM6gLNmUx8DZkcasObMyYrZ/SaM0JkrP67l9/AQjTI/1+ahmhrJiCkblKEWWvpG+r1r54U20pozQQGdZPaPka77wZqCBZYyjEwniTT+vcrdVI6oTVUFrmD7LQ7tpy0LA9etx3n/P5ULuoNWTZ/rUSgbsrtt4miHj/h8aSxyX0MXtJ9SmVlMqzEvzWom7Q1EwTrpqhrJgipm6RmEgShubHUnFELci00Z7TzdUDI+AMmeTOl/jcoYxFYGI8UthM6Z2c0l6eMSXrzeXq2ZtrxKZvgGJWjet3ef749Ic0ZI1szZeVMZ0LGIvU4pB6DtGOPMu7YviY41vjHGO34ImNL8kjEb2BZOSO0SJpeBZaI3BntseC5kq5L+e/g6zl6y27cdngkaxy3/PCurzGjxQkdhr1IQc/P8HocbFjxBG2qcnl+1jXkuGv4vmQ4Tx19k+9aAxWYHuKhLAhCayakMWOCXmMmEqaNGS1mlsaGlmbKSa19WbJ/ZGubOGDpp6POMKnkzUCwGaMmJ9xbVBAEQTAhMPGiTPb7J2bMHAfUGNkdB2jW3BmFUMcBe46X0ef/xGHjl5GW5aF+Sw5b/z2MyoV9wOu3MXMG352eW0X7rr8AULmnt68xo6C2zLFAJMcBQ9W45HUKgpCiWHrepJKlmV7eTAC9WiJS3oyCWWNGvZ/kvBkVVvNmrI75QmojzRmhRWMpJ6XE4HgHQh+SuisOtTZm6mIjEZZmbo4e8jnvFk2h9y4Xzmw7V9XexL9/upXQyTAC6oC2g2bQ8VBf6PSWJf/AVX4gU5ZcR+e6razLL+Hmw5+k0W79u8VsbSaFhiAILZW2kS8xQmnEJA3DVTNqzLJnrDZsjJ7JNtV/EPpTMS38ciMs2i3rNWoEQRCE+NELCtab4AHC7aAte+BrbVuUfW0tkm6wb0S4Hc3A/hvYcvMlzO53K2+nP8Dur9ux7c0SXOWN6E3K2dLcFO27FHtaI7VVhVTs6aefW1Dq324KHRejtTZTY9XaTBAEIdVoCkuzMKK1NIuIsjpG2Ve2Rnkz0WKUNxMFeg0avWiGBCCrZlIfac4IKYeRCswqTaMCUxPjw9jWyF+Pmsoz9TPIr4M/2mZx5rqH+KF0BKHWAMFNTs9ldB1/FwA7l17K7rVncO+6uznQuZwKRz7XHPs8lZn+WcdS849PpoeyIAhCc6FnaWa6aqaN/mH1qpkane59XbRJ9+p5Jr3FlJYxsiqzep22KaNtxhhdq/nJqLU000PzV5Rh8Q9rxdJMEARhbySWCZaog4ObPXcmdJDMdtTz2Dnv8Nt5N9ExzQnAnqpsKn9s51sto4uXwsEryMjfg7suk7INB0KZ6trtBm9DZbFjQrIdB0TUJghCShCHQ4rZczLM0kwP9XNa21hXLM10UTfrjew0E9GgUaOMdUpRZDFvJoIQQi9vJuoxXWgRSHNGaDEoy/VangosXPl12n4LeGXMA1y3fRVpXviifSfOXvoEWyrCA4MDdy/aSI8Lb8DucFG5cjTbv7+Ryyqmckr5J7hJ4+YDn2B9217hbxRrM0EQBEt4DRo1CnqrZiJamllpKDRGviSYMaPsK9tE2ZvZMG7U2E1fAqF5M2oy9L9bls5fjJI3YwXJmxEEQYgdUztoIwqw6DignaRSH4sGN93b7uLH6x/lbwNmhJw58IXnMVM35HXbQG7vLXg9Nsp+OQhPo4GgIonWZiGitigdBwRBEJKNpawZNSbPMYWEW5qp0bM0C8uNrtU5qXWvUa+kMfObVjBy01GvngFDW7NIaAQRRmOyJdcgocUizRkhJWjdKjAFN9mOBr6/7l6ubv8Fh6+vo9EGT+UexI0/3EG921ghZc+ppOc11+LI203tpn3ZOO1BRrnncfOuJwB4qOft/NDhCFMVmBlibSYIQmtAb9VMGHFYmsVNpLwZ3SaNh1BbM6MVMYlerp64vJlISN6MIAhC8onZDroInx20LsnLnRnRZx1LJz7B4DYbAscaGtM4+KXHqaxXcjmVgOfgpFxGWyftR/4GgPOXfaivaB86qWdmbabBzNosLlGbIAhCCyahlmYKpTrHtM9rp6WPItzq0mz1vtG5BOTNQHjejB4loS8jugH5CeTNKG5DfvehgBuR0GKQ5ozQqohaBRZV7oxWBRadv+SgDtuZfd0j2L7Novt2G3uy4dqG8by09FyCk1+1hPpjAmkuelzzN7I6r8O1qyPr//MMfV2beKT2Fux4ebfNObxdfH7owKWoDUrNv5NYmwmC0NqxammmXjXTrJZmYXkzarSNGaOt9notetZmytaoGRN73owjsyGwLxkzgiAIicPIDtosIFjtOBAyAaTndW8htNjnOKAmfseBK0cs4asLXqDIURVydtIXF7J0a4nhN7FnuuhwylJsaR5qNnSk8o/expN5eqI2jeNA0tA4DoioTRCElKOpLM026TzrIlmamRLuXBPcb+K8Gb35RaOYBYO8GT2XID03ISMkb6ZlIM0ZocUTlwrMkMTmzpy6zwoeOuUt8j7Pp20NbO/g5dwNf+GbPw/2X6E3Q1cLVNDlovvI2+9HGuuyWf/Cs+TvSeN5+1/JpYYf0w7jgaJ/QKQf7ylkbSaFhiCkHl750RaR5rU0g9DVM2C+UiaapoxR1kyEFTNK3owRFvpXpo0aTd6MkDjk37sgtGyimWiJZgInDMt20OCrleJ3HEizeXjunM/51+j3ybA1srShD3tcvrH1g9+G8szC4/1XKitm1HgpPHE5joJaXBXZlH03BKo1Y5zeqpkUsDYTBEFoSlq8pZkTA0szo7wZvf1EoWflGX/ejB6SN9N6keaMkJKkjgos9tyZrJIc0vLT+PvIufxl0Dz6zM/C4YF1fTyct/xK/tyl+Kypwy9Di4yiE9+h/dEf4/XY2fTCo3g29eLp7OvpatvCem8Pbsp+ErctPbTQSJK1mXgoC0LrIy3NN7vuckWdSp9y6Fmaha2aSWVLM0M8hDZbtPvNkDejbsooorEIc28ZKhGCXt5M4JtU6R93Wml0CZZQ/r0r//4FQWj96E3omDoOGNlBFxB0HNAl9tyZdlm1fHXte1wz4FsAntl1Ag5nDXnpDSzb2o2LP55AuKIh6DiQf8gGcvrtwOu2U/a/g/Hu1IjpUs3aTERtgiC0QJrN0iwqtHkzWpKVNwPhq0lNMMmbUY/RQutHmjNCypCaKjCIWgVmh/bHd6THOR2ZesU0Rtn+YMCvaXiAtQc0csX8S9i2R3lga8PJQBlICo78ik5nvQrA1rdvZc9PI3io8DYOdiyl0tuGazwvUGEr0A9GA8sqMKvWZpaJY/mrIAhNS3p6OpmZmVRUVOx9avrmsDQzO96IjqWZ0WoZr8F+okhQ3kyG9e+V7bZud1YwdKjla4VQvF4vFRUVZGZmkp5uZCskCEJrJaG5M9lhO5rX1nNnBnXeyZIbX+PowtVUezP5y+9X0mNnKYOLt7JtTxtOfvdqalzKOBQ+sGZ2K6NgpE+ot+urQTRs16gxnDp/JiNitDZLRl6nIAhCorHU9E1FSzMnFi3NalX7ytZoxUwi82a0+8SUN6OH6dityZsRWiaxpJoLQsowmBVhP4T7sCbQcCjuvTH0h3Vn9JsWBZrXYdnE6f6DDnwP8HTVFpSHui3dRodTutKtTyN3NNxLl6/rKai2UZ0Je4Y1MPGDi9lSpf0wVPfw/ZNsM2QBXS99BoCdn57Lrq9O59YOj3J87ue4vOlcX/MM67J6h9+ijPBVQaX4HvibbNDNy44/u1PceyMb6W4YNLaafawVbxYY3y88kGy6zSbel4KQIhQVFbF582Y2bdpE27ZtSU9Px9bClJqVv/yie7xO/ZjJI/Q3eTAGBa+qL1DvCDZEXH6FbgOZgTc3+psUDWQA9bhIJ+Bf1uABly30/spv/kaCgl+l76J8VFh/RWnOeAmunNHbonpttLpGi02zbye4Ysau2tr9f2Y7vrGpMfS7Kl+hUXVLL8Hh0QZ4vaRluPDWgQcXGTTQCLj9f5d2GsmijnrA5v/fxqb8vblV9yTkfy7q6iS3Jlq8Xi8ul4uKigr27NlD165dm/srCYKQaGbhm1CbQ1QWWt3ZaGp5TBEWFMz56BRQBGsnhXTVa9/+KQes4Y3xH9LGXsdGTwfOnHcxp2Us4JQjf6fO7eCUdyewqVJRUSj1kuI4kI09103RKVuw2WHPLx3Z85Pmz1KNb3JMQV0vbScma5k19NHNAVBY3n5gQEzoHWnu/iAIgtDcWLU0M1o1kxBLM6s40RluzCzNEolZ3ky26UvAOG9GhXqOzmycMRpXtHNvQstAmjNCq6A/qyIvKy8hfLmk8mDUXX2Sjc+7Uv26Ft+D2K3a+rDnZlB8Zgf277yNq8qfpt93dtIbYUehl3YHVXPl25exvqKd/+pgIya0yeMmu+8aul/zDLY0D7u/OZbtH1zOhQVvc0n7fwNwR/kDLMw41He5k2BjKUmFxs/sH7BCMCw0RhPs2AuC0KLIz/et5CsrK2Pz5s3N/G1io6YsfNYoxwHl6gNqu6wsYLfmtZ8Gu+953KBS+bpV+y7/s1s55vZ7fHncDuXi0K26KaNuyCgNGq/qXKCnoteMMfoPna12H8KbMtqttkGjfm0PXqvt4aSpLknz/6cMbw6wO9z+3UYc/gIpHXdgX7E7y/D4iyel76Js/XVWjaqxlrNuHUJsZGZm0rVr18C/e0EQWj7T/9CfWLN97fu9fsCu38OsZ/qyVn/irJs3VMnckSgsk5WZKJd/Xz3waps0AF7+ceIC7jt4Fnabl+9dAznr3TMZmfMrt50xH4DLpp3Jws3d/Ner6y//uGyrocP4PTjy3DTszGLXFwMIjFfKPF2B6iN3ElQvq2unrfhEfKWYqpjX0jfMaiYuUdtxRFQ7i6hNEIQmJcFOKJYtzdQianVpZ+QUAwTzZrTH1GjzZqxamlkhQXkzelEMBkTjIiRjR8tBmjNC6tPsKjAIPuCVQiO0wHC0z6LjWe0Z224xJ615j4ErfJN1G3o3MuKgnRz16uX8sUtZXu8mtDHjCGwzu2yj543PYM9soHL5QWz+90TG5H3DbcWPA/D4zpuZUXsS1PjfrlaB6aEUGhHQKzTiQgoNQWhR5Ofnk5+fj8vlorHRckJ9yjD3hBPCjo0q0Rw4UrV/aHDXOyy4/3tBcKJqDf0D+6X0AmCDakzZQE8AttEJgPKN/oftNv+k0GaCxYTSJdpNsCm0C99wotQUIYtB6vxvqsPXnKnHJ/11+4+5/ceULfjGkUb07TIhpGNCGr4iIg1fZ8oBZKr2c/37uf7jeb7XWQTj2NoA7QFFc9AOKMQ36aUsyujkpbD7VjqxDYAerPdvN1KCr8GyD6sBKHGuxbbI/74F/q0vdoC5paF/klErVyJET1pamliZCUIrYbzXG3UeScIcByJi5DhA4FiWw8XrF0/n7G5LAXitchR/fXUknW27ePHsTwH4v/kjeeeXAww+w1c7FRxZQ1bPejwNNnZO64bXVY1vbNOZHHNikpWjIQrHATVqUZshJqI2PccBQRCERKE3biTSUjFuSzMjnFYv1K6eMVs5E2lVTbrBtnnyZiKOLUKLRpozQsqSOiow8D1w9R7eDjK7eulwRj5XZL3PsIWLKNnsa8xsHlrPmD7l2G2wslLPYFJp0vjUX+ntyyj520s48mqoWdObjc9P5MDMn3ik8z+x27y8vfs8pu6aEPw6aoxUYAqlJN/aTAoNQWjxpKent7jJW6PJqawM1YvjgArVa1Xfwqu6zpaVE9hv8D9IfUXGHgB2+c21fBNZ3hABwI40//Ibry2oAFMa6ZUExQBlBJs2Tnw9FyeqWqIWX8NFWZJi8x9TmjSKlYviqaxVgCljlVFzRlmGr4xBNv+90/A1d5S/kAyCTZwC36FsQr2TlVVARf4/YzY+/zEvAQVYbZYDm38VTx4N9GUte4AGdgFgY6tPAZYRtDYL/G+1xf8nWR/6J8nKykIQBEGIjZgdB0DfM78WrDsOuOjatoJPJ7zFAXkbcHnTuH3DOTzxZn9sjY385+KPyc9s4NsNPbh77jH+e+k7DmT1rqft4b6Bo/zzDrjLTTLhWri1mYjaBEFoFpra0swobwaCNZMuSm2koN2PZnWMNm9Ge0yNSd6MFUoiX5KouAEhtbGQ7ioITUcsPzr1OsjqTrO6Aw2YryRRJpx0UXfKfa9z+qfR7dw87rI/yTFzFlOy2UZtBtydOZi1FT2x++cMRx9g3glKy62k59/+TXqhk7rNnVj/5PX09JbzXNcHybS7mFN1BA/suJlQWxo/ToObRtV8CqI3kKoHXPVArB6gBUEQUgWrvslGWAr1VWE51FKN0+rdtUvzrSjAjAoQbeNNqwBTiEEBFmFiK1EKMJmcEgRBSDwRV4VY8Mk3Hi+Ck1qH9y1l6cTnOCBvA7u8eZyy8Foe/3c/vG4vEw9ZyIieGwBYWVZImwxFkKAQDHdOy/dSdJLvXNXSLGp+zzP+Ws5I39uPskqo1PwyPYueiM0uMzQWQolUsQuCIChYWm3ZEizNQgRtlUS2MdMej5Uo82YKCM4vdsDCGGphLPYTaPYr4mi/a40Iolsu0pwRWg2WOsolOseK8D0sC1THssN2NK/TaXMw9D3Vxb1Vkzn0y220r4KtbW2cu/0kPvj5UK7+dDxv7hgOwNgBij++euLMt2/LqKHHje+R1XUHrvK2rH/8StrV7eGlbg9QkLaHn2r3YdLWu/D4cw1CqNa81rNps2hHEHHgjBYLhUa0dgyCIAhq4i0yjBRgekStANN79qaMAkyvqIDw5k2UCjCT5flqRAEmCIKQWAITMoqtsH/Cxmy1huGqD633fVSrSoxVxJeNWMSc85+h2FHJ743dOfzj65n5eXHg/D6FwaS4CQctY8b5b6jerbLstEOHUzykZUP9Vju75uQSXE2qTNLVhmxCUE/2mVlc+wUXigDD1C5bhZHAI0TUFoVVtyAIQrKIV9CmpvktzSD40NfWSS7VthnzZoqwnDejN0ZHkzcjtCykOSO0aiypwCKiVYFlUTDKzSHHbOAfpfdx0NcNZLhhecdMTv3pUv7Y0QPlAf3lmt4AjCo2aGHbG+h+zafk9ttEY3UWpY9fQvruHJ7v9jzdMnayvqET12y+hzqvxsKllvBBKlKhUerfRlFoqAdYS0pyKTQEQWhmku2brDSyDZ+dpTrHkqIAi7ZBo4eeAixbs1W9LCCyAqwkzq8kCIIgJJyYJ3QsZFeGk416ZWaarZGnz/0vU0e/RqbNzWe1BzP85atZ9XNOyLtu++pYHv1+OCvLCgHYt8MO3bu3G2Ujs4uNxlrY+XEmNGon+bTh0Cqcqn312JxAxwE1kYQfASyo1UXUJghCrMT7/EgJSzNTlAa9sq9sjeolK3kzDs1rSFbejBpxG9g7keaMkNKklgosG9K8FJ1cxRlDZ3Pz0mfY/yff2Q+KO3HBN9dQVR+6pH7Wki4ADMrYSOfiOs09vXS9ZDb5B/yJp8HB+icvwL2liMe7TGW/rA3scrfhqk23s7sxE+MJOz9Og6/dBIWGqbWZFBqCIKQSBr7JahJqaWaEM5pPMFKAKaSAAkyLarxViyTMvPlleb4gCEJ0NIkdtB6K44DSrDccJrIpyK7i84mPcP0+MwF4dMepjH/qIiq2p6FdwVnVkMmtX47m+pknALC5qo3qrG8iLWeAnfyhvmy08hk2Gis9GNZHWiJO7hGX40DU1mYmojaxNhMEIdmEPWfisDQzE7RFxOi5qxW0OdF53NdqDmozOKNF6yxg5DYAcefNqBC3AUGaM0KLpilVYPYsNx3P2ckNvV/g4q8/pddGG3XpcG/2odw9fwJeMsLes31XDivqfRN3xx28KeRcx7N+pN1Rv+FttLHx+VOoWdOJyR0/YmTe79R60rlm841scHWK/MWSbG0Wl4eyBik0BEFIFIn0TU64pZlCyinA9LbJz5sxI5pxXBRggiAIiSEmO+iIuTPZqMeNfTqXsuiGWzi2/U/UeDO5+JeJ3PriWDwNNkLHnVBLzS5tqgDYEmjO+MY8R3svhcf7hHAVP7ipXduoepfWClRzyqk5lgTHATXJyusUUZsgCE1CAjM6DS3NSnUu1j6PnUZ3VcTL2mNqtG4DsTZsICl5MyVxfB2h1SHNGSHlSD0VWDZp+fX0uGA197f5Byd+9TtFFbAz38bFuy/iveUnq24S/rCeu70/AMf2/tN/xEXRCT/R4cRlAGz+9xiqlpdwfdEczixYRKPXxi1bL+Hnui5EVII5Na+bsNBQIx7KgiCkAsn2TbZsaWY11NKJgaVZshRgDp1jCsnLm0nU8nxBEATBAhrHATPUKxpD7KDjdBw48cBvWDDhKvpmbmWTp4hjvryH/3x4MHiNFMgKDirrMwHokFMTOGpz2OhwajvsmXbqNrhxzlcaM9qxURlDU8dxwDIW8joFQRCiJaUzOhXiErRBeO1kVDepz4nbgJA6SHNGaFUkQwWWXlzBkL/M58HqWzhyXiVZLvitYyan/DqZFZsPNPkg34N61u++DxzVbhXgpd1Rq+l0zgIAtr47HOf8gZxfsISrC+cDcO/205mzZ7DqPuqAS1WhYXEFf6wky0NZr9AQFZggCNHQlL7JZjSdpZlCshVgUeTN6FES+VNleb4gCEJyMJqYUSZymiZI2MvfT3yOaeP/Tlt7DQsa9uGw/zzCjz/0InQSy7hJ8+Mmn3PAkE7b6NLGp4xud0w7Mjqk07jHQ9kn1eBVj39RTLClmuOAiNoEQWgGmiqjMyIWn7fmK2jU44F2P5oGjLrhonYWSGDejLgNCAZIc0ZolSRKBZZVspOTz3+Ge1f/kyHLPAB81qUr58x/FmdNoepK4wJjzrJO1HnT6ZK2i/1O+o0ul3wPwM4ZQyj//ECOy1vFHcVfATClbBQfVBxKTBNuTtW+XsClFBqCILRiEumbrKZlKMD09qMlSgVYRHub+BVggiAIgjlN4jigZwet4ziQmVbL2xMu46GhT5Nm8/JmxbGMfOZZNq8vJHSSy7xJs3VPG37c1BW7DUb3Wk12vxzaHJCH1+ulbHoljdUN/ivV46J21amGlmJtJnmdgiAkkJieFwnM6IzZ0kzrNgAGbgNqtCI27fFYSELejLgNCDpIc0ZIeQIqMM0S/bhVYBFyZ3IHbuT6Uydx/Q9v0Wc91DvgsfwxTJr7DI1etdJYXWyEP7zr6h38WN2XBVmZ2E9bhM3uZdfX/dn+/iEckr2Rhzt/ht0Gb+8+kBfLjyJ0maWF5TFOQif69AYyNaX+rRQagiDsLTSXpZkaS6GWEJsCzErejJkCTCEBCjCtCMKESOO3LM8XBEFIHjE5DoBuY75L2018d9PRnNd1Jm6vnb+v+ysXTbmP+j1Zqqv0gpb19mHB5q4AHNx9O4Un+JZrVi6oom6DMnC6NVsFA8cBLU79w81ubaZBrM0EQUg0sQramtzSzBnNTbRjg4I6tzPWJk3z5M2I28DehTRnhJSkeXNnvLQb+jsPHnIh58xeSAcnlLexcVX17by66Gb/hcpDWDvRpeyHNm7er+zK9R074EnzUrG4J1v+PZx9Mst5pusMMuweZlX144EdxwJKg0I9cEQIuDRDTwUWBVJoCIKQaiTLN1lPAWb2DLRkaRZ1qGWiFWB6aq+WmTcjy/MFQRCST4jjgB46k0yHdv+exZeP4ODclez25nHK98/xyH9uAo8y1aAdUyLnzizfVgzAsIG7SMtOo35bPc75Ts11MU66tQJrMxG1CYKQFBIoaIuIledsNRGaNEpDXtlXtkbCtVjzZvQyPJs+byaAuA20OqQ5I7Q64lKB2Rrps99cprQ5h2Pn7iC7AVZ3zOC0lTNYWHqW5g3KA9i8wMjo7OSbA0upsdsZVlPPjlePpKtjDy91+5Q2aQ0srOnK37eOw4Od0EEizoBLPWszi0ihIQhCSyLlGr4JsTRragWY5M0IgiC0GgwcB/QwnAiyYAd98X6vMu/4k+ns2MUqdzeOeH86n311mv+sejyx5jigsGSLL3dm/5wN4HJTNn0HeNQWn9otWHYcUJNAazM9QUciHAdS7jeOIAgtgqYUtOmRfEszrQU0hI4JicybUfabN2/GaCwXt4GWjzRnhL0C09wZgCKw2+sZ1XsqT2y7loOX+B7qs3t05cz5Cyjb08t/YTbGXfHwAiO9/R56TZoFOQ3sU+fmmR07GDdwMy93m0kHRw2r6ttz3eaTaAhRBCco4FIPRZ1Q6t82o7WZFBqCIERLU/gmJ8XSTIvT7KSepZmiADOzNEvdvBlBEAShaYg0QWPVDjpS7oydRp4YeiP/HnITWTYXX9QdwmH/nsPvvx+kuVM65o4DoFdDrbZ1Z483i3xbLR0X/Yx7l159pGdtpjgO6DRqtIec4ZcAMVubxYXkdQqC0AQkaw4mIZZmapzR3MSq80C06NVHkOp5M+I20DKR5ozQskiECkxLR0inigkFt3D3kifptw4a0uCFbsdy3bdf4fZm6bxJ7devfh1UgaW1qaXklhmkF9ZQt6UtJ67sSnYD3OL9jpKMSja78rhy43iqPGmq+xqpwBIccBkFybI2EwRBSATN4ZsctaVZwhRgsRApbyaBCjCDvBmz8TgwjmuW54sCTBAEwRrJsIMOoST0Zb5tNzMPOJGbOr4GwJO7LuCEtz7DuUsZFLIxHkci587YHDban9Sd5R7fOHxAjZWGkt5YGcFxQE0M1maKYCOS44CRylzyOgVBSBbxCtqskHBLs5jdBhTUqyvVxGoPDeEratSI24CQOKQ5I7QIkqkCy2rYyf11Z3PlnC/puAt258FNXe7gmfXP+B6qhs1wYxWYPdNLz5s+J7NzBQ3luax/bCzfr+zDpu/aUVBdy253JldsPIGdjbmq++mrwGzpDRQc+S3dr32Uzhc9Q1bPFaSytVlIoSHWZoIgJIimKDL0sFxkKCQl1DIWBZjVvBkt8SnArGJ13BYEQRCSh94EkNnKx/7231nUcwRjM36gzpvOZRsf4OaZz+Fp1CqLwVxdbDwetRvVgYyiLBbX+8bf4d02aa7QW2EKlsdHJ6GTfjFam+kRSdQWSRhihDgOCIIQL9EI2lLG0syJzrSXImgzGgsiuQ0YoTeOQVx5M+I2IFhEmjNCytIUKrDC8pU8v/o4TppTSk4D/NnZwdldPmSu8yKdN+ipwMJzZ2xpjXSfOIOc3jtxV2ZR+uiJuHflsv/KjVRvy8KW5uGOilGUugpM/iQuHAWVdDxrGvs8eRfdJvybtsMWUXjMF3S5+Dnjt6WAtZkpYm0mCEIyibPIMGtAJ8TSLKpQS/UxveviUYDFmDdjAVmeLwiC0PKwkjtzfMN0FuQeQ3/7RrZ62nNM6Xu89vO1gJl4IhsrjgOBq/vk0uagdgDM/b4AgCO6b0TfwjNGx4EmQvI6BUFoMcQhaGs6SzNlNaQWIxFbLII2pU5Suw6AuA0IyUaaM0KrJKIKzOvlgKXTeXn26Ry2qBY78N3AIk7r9iOb3ftGuLs2d0b1MLel0fXyWbTZfwOeegfrnzyehm0F/K3DQsZkrsNrg25H7KbkwDr/G8K7/fbMaopP/Zr+Dz9OhxPn4ciroWFn+8BVDdu76H8tp+Z1qlibSaEhCEKSaA7f5KaxNIPkKsDURKEA64C++qsk8qfI8nxBEISmITBRY2AHHfXKRa+XSdse4lP3RRTYqlni2YdhW77i+53H+M4XEGzix+A4oEyE2XPTKDyhEwCVi3bxzZJCPF7oV7iL4tw9mvvpOQ5ox01lbI3DcSAOazM1ktcpCEJTYWkuxaKgzYjUsTTTug0oqJv38eTOqJsykMi8GauI28DegTRnhFaNXufZ5nJz9rf/x+Mf3MqANV7cdnh9xHCu6DUfly1X5y5GaCe90ul49mwKDv8dr9vOhmePp3ZdMRe3W85l7X8G4Ne+XcjrUs/Yvn9q3usGu4d2I5fT75GpFJ86H3umi5o/erD+6csofWQijbWZAFQtP5CwQiNJAZdN4aEshYYgCJFIed9khYRZmiVDAQYJy5spwlQBphZDWMp/0yjABEEQhOQTyXGguPdGMly1/OfHi3i05kHSbF7e8YzliD1z2NzQxzcWGHnqR5U746NoXAfSch007Khj99dlOOuy+WVHMQBH9thg8ieJwdoMLE7++Sn1bxNgbZZoRNQmCEIkYp1ziWRppve8S66lGRhbmrmILGiLJW9Guw8x582oELcBQY00Z4QWQ0JUYBUubv/Xhdzy/lt0LoOKXLj93Ot5uO+r4dcWEJUKrPC4b+lwwo8AbH51LHtW9OGkNmv4e/H3ADy28xD+vftAAEYVrgSCD87cfTfR977/0vXS2aS3raF+ewEbnj2HP//vKqqW96fr5W+Tll1P9ap9qFg43PzPqC00zAIuS/1bC9ZmyfJQNkIKDUEQzGgO32RDrFqaRUQbaqkmXgWY3goZiEsBFiPKeB1Ynq9BlucLgiBERzwTMkYrGzvs2M7894/hotrpNHpt3J5xI+dnv0c9eRE99INYy51pc1Ae2b1z8Lg8lE3fCo1ewME3633j71E912M8GQdRW5s5NccS4DgQa15nCFphiclvG0EQBC3NldEZNTFbmmmf7WpBWzw0Ud5MSeRvIm4DeyfSnBFaHUYqsMYNtdx3wwWcP/NncutgQyc7F932H2aU/DX04qhVYNm0PWwJnc/7DIBt743G+f0gDs8p5f86fwXA67sG8+quwcxZWkydN53OaU4G995NeuEeul/7Jb1u/Yysbrtw78lk61ujWHPHFVQu7gfYKBzzHbkD1uCpz2DzK1eB1+SfrVPzOoHWZnHRHAO+IAitgqYuMqKyNCvVuahFKMDCff51i4sCzPNmNMvzRQEmCILQOhj181d88cI5HOL+jQpvLqf1eJmHut0Ldnv4ysmImOfOpBdl0m6UL2fGOW8XrrKGwDvnb+gBwIiQlTN6lmbqrULTWptFIiZrMw3iOCAIQrQkWtCWGpZmEG5pZuQ8YEbz5s2I24CgIM0ZIaVJJFlcdgABAABJREFUlAqs8YdK/jHxRo76sQK7FxYPbsOp//ieNR2Hhb4pKhWY78GcO+hXuk54B4CyWUdQNnM4gzK3MqXrJ6TbPHxa2Y9Hdh4B2KirdfB9TX/qbdD37IX0e+Aj2g5bh7fRRtkX+7H61gso//IgvI0eAHL6/0Gns2cAsO29U2jYmY+hKiAFrM3EQ1kQhKYg1iLDiIRbmqlxGr0hxRVg2l6N5M0IgiC0PAwcB/RQJoaOnjGH1z78O1285ayhK4eP+ozp/c7yXWTkmV9AbLkzaVA0Ph+bw0bN2lqqllaiHpPm+1fODOm4nfzMOozRW0FjgVS0NpO8TkEQYqCpnwNJtzQzRGm8a4/pXReLPTSEW0BDc+bNiNtA60eaM0KrxuuF/Dd+4+q/P8+g1Y247fC/s/rzlwd+oK6PptMdtQoMsnpupMfE57E7GnH+OIRt755Az/RdvNjtA3LsLr6r7s4/to7FizJQenmzvgOndO3C6gFbsWc2Ur2yE2smn8G2tw/HU5OFMmikty+nx8T/YnN4cC4Ywq45I1SfHEEF1kzWZqZIoSEIQjPTZJZmSVOAWZl0SrICTK9BY6AAEwRBEJqHSBM2enbQXpeHE1/8jMcWPUa2rYHZWUM5/oZ3+a3HEOMbxeA44MM3VrUbmUVGsYPG6kbKPysPu8PWPW1Yu6uANLuXw7trxxftmKheUWpia6bg1LxuRmszEbUJgtAkaAVtCcjoTBq16h1l3susKRPJbSBWJG9GSD7SnBFaJhZUYI2NNnLv/YkrnviErjugKhum3n0sT9//jG9JvpoYVGDp7XfS8+YHScuuZ8+v/dn8yoUU2at5qdubFDpq+KWuEzdsPhUXaQBkdKym581fsujgP9mc7qDY3cjWf41g3UPHU7+pfci9bRkuelz/EY78amrXd2bz1LOARmt/N07N6yRZm0mhIQhCstk7LM0gVAEWSf0bqwIMEqYAsyhmMFueHxi3NcvzRQEmCIIQG7FMzCgTQLYyFyc/9hV/2fY5AK90PpULJ71MVbuC4MUlmjfHlDsTJKuXjfxhWQCUfbYHT41H97pv1qutzfQsPdXjpfq41jJUOYa540AzWJslAhG1CYKgJZo5llgzOpWmtGVBm9ptQP28dUb8iipqMbYxi6U+0iPOvBktJZG/ibgN7L1Ic0ZoUVhVgVVXOvBc9SenvreaNrWwsZONf//nTL4991wgtFOtSwQVmC09k+4THyS9bQW1G3qw4ZnryPG4+Fe3V+mesZsNDe24etOZ1HgzsGe66HjmQvr+3we0GbwZj9vOubvrmb5pC8Pc2YCN0MKhga6XfkF2yXbclTlsmHIu3oYM/zkLAZdWSIC1mRopNARBaCpSyjdZIS5LM/WDW88vP1EKML0CQ6Hp8mb01NqCIAhC8+H+2cbVz/2Xw+p/p8abyT8PmcCLV11JV4eqI9FN0/hJQO5M/lDf65rVLur+bNB/Gw7mb/BN+B3Vc73et9dsQb8ho4y7OrQEazMRtQmCYEKrEbRVY6FJE6l2UraRmjTiNiCkDtKcEVKeaFVg634tpOMFyzni2yrsXlgx2MHr0ydSvr/BZFuJ5nVEFZiXLn95mJzef+CuasOGKTeTVpfOlC5TGZi1mTJ3HldsuohdjTm0PXQN/R58mw4nLcfu8FD1U1fW/ONU+qzpRY7Xy/GDSjX3dlN82iIKhq/E22hj4/On4CovIOEBlwotwNpMEIS9k6YuMvRoHkszaHYFWALzZuJBlucLgiAkl22fFnPeh7Pp5i1jo60DT559ASvGDTZ/k5l3fhS5MzWrfM/4rF4OHG3tBCfGQBmb0osyWJLnE1Qc0n0LJecXk90nx+DeRitoTHBqXjeBtVlMeZ0WEVGbIAgKKZnRaZVa7Y4y96W1MVO2kQRtzZQ3I24DgkWkOSO0Kj6ZMYSCq8rY9w8PbjssOi2H5e+egrtNbuCakA51DCqw9se8Q7sRn+L12Nn4/N9xlXVmcqfXGZ67impPJldtuowdHV30uu19uv91Buntq2nYkc/6p45j/ZNjaNjelllrfQ2O0Z1Cly0WHLmK4lMWA7Dl9TFUr+xJQgIum9hD2XKhYVEFJoWGIAgQXZERLUm3NDOlBSvAVOgpwGR5viAIQvMQmLjRsYN2u+3UPpvHCYuXkWNrYHFWf16aeCY1g9qE3Sei4wD4mvgFeieMc2f2/JxG3XoP9nQb7Y9rAzbI6JhOm6H5dDitI92u60OXy3tRccT+bPYWkmFr5Iiemyg+sxOdL+tG7r5ZYNPWSep6KUrHAafBdQm2NrOMiNoEQbBAoudKkprRGZOlmd6qR7WlWTKJMm9GD3EbECyg/X+aILRIPF4brz81jP2nVZJbDxW58NukDqSd1xkbdvqyNvKKj87o/+guCO7m9F1M5/MeBWDbf6+g+vchXNX+P5zedj5ur51bqs6i8sIf6XvEMmx2L556Bzs/HUbZ5/vjDYwbLmYu6QqHwv6ZG+lUWMu28mxy991E10u+AWDHJwez+xv1QOjC989VCbg06dYrhYb/e7f3lLOv5zf23f0bebVVVKXls7umgLn1oynvEdkseiPdLS25/Jn9rQWXjSbY9RcEQdAhEUVGylmaGYZatnAFmFbk4CcWBZggCIKQfHbV5rP55Z4c7vI1z2d3GczGy7rgcWQErunDGv3xsIRQgYJSSkQUf2UT2g1xAC7KZ0GXy7xk98qg+42dsGeEjv+eBg/1W2r5Or8353cp5+Cti/mmYF8yOmRQNL6Y/EMb2PqfcmhU6qR01VZBb/JOM+5V45tcUyjDWJBQiu/vYZMNunnZ8Wd3wwbWGvqYjodavCP1s1QBnyBFNV6O7yfKaUEQ9IkoaGtOSzMjLFuaKWOJ1tbSipBNQb1KUy1kU9tvputcq2BQO8XhNhCPoE3cBlo+snJGaHFoVWAVs7N5c9oQDvmvrzGzsTPUvw5p54WvuVd3pqPNnXHkb6PHlX/D5nDj/GEc5Z9fxvj8OdzQ4R0ApnQdxI77PqTdiKXY7F4qFuzLH3dcys7ph+B1hfZBt+/MZkW9L9zyhKGbyOy2ix4Tv8Tm8OL8oS87PjrEf2WkgMtwa7M03Bya+SOTs+5htn0031YfyUu1V3Lj7qeYsHUqN216kvt+mcyn35zEiWs/jWhtpke01maJWD0jCIJgSoKLDEUBFvYsLPVvE2ppBimjAIOgAswIEwWYGZEUYDLJJAiCEB+RJmjW7OpJ3ar27O9aT403k1mHHkD7K8uxOfQnhkwFWnHlzvhw7wbnt768GXuGDU+9h9q1teyeV87WN9az8ek/2PHeJmYt7ATA8LTf2PTCn1QscAKQUZyBLU3d0Ilmtakfp+Z1M1qbJQJxHBCEvYcWJWhTMHMbAIMFj2aCNi2x2kPrrV2QvBmhaZCVM0KL5jc6s2pdO4atrQNg5RA44dHV/N5jHwazwvIyTF0VmObHeMexz+Bos4va9QPY/No9DM1ewj87PQnA/w7K5MfjfscOVK/sxbb/Hk/tn4UER5bwQWP29v7s32MDwweu5rux1aTluKhe1YnNU0cCNnyDiFoFptxH75+tl4OzlzCuzWzGtPmSIkd54IzHa6M0rYTf7IMoyy0kP72K/apX0L/2Dx795hY25PdgRecI3tJ+1tI3bOnlavYJFHPq1TPL2w9M6FLM6TabKAIEYS/AqMhoCt9kUzaZFD9xW5olQgGmoPbsV16DZQWYuvYoIGoFmNny/GiQ570gCEJiWZQ1mP02bCHb1sBmWyEN4xooPiR8ZszQcaCb13wsVFA393Un2ZRVNL5VLpULoGHnHjzVLhp21OGzGwgd++au6wnAsC5baEMNrp2+hk791nq8DQ3+eyp1krpesuA4oD7tRN+aZju+ZtRWTDN39BwHol09E4LWcUBWzwiCEC/NldEZtaWZUb6yVtAWyW0gGlpG3ozQ+pDmjNAiGO/1hk3YfTZgENlLGxlQ7suXWTfazan3hU8K9WdV2CRcdzYGBxFtodGRMJuatNwy2g6ZCcCWf99Nv7Y/81zRtaR73PwwwMa7Y93Ube7C9vfHU7V8AFAH1OB7qCuDgtJo8fH5b724sqeNzwZuJSMzjfqtbdkwZSxed5rO34CL4AARvEeevZqT87/kvILP6JO5PnB8t7uAL+vG8KVtLMuyDqQm318lFQJF4PC4eGrDDYzeMZcT1n3GisHmzRmr1maWkUJDEIREYlJkROObnDBLMy1hk1ORLM20WFGAaX/SaZs06uMgeTOCIAh7B267nUWegxieuxqAn+0l9Oy/lvy+1Sy38vzHt0IyZCWpnh204jjg1L5bGW9qNceqUJopdX+CmRhhY2Vb/ihvR7/C3YzosYEFvXsBULeuxn+FImpD9VoZ74xyZyJMsiXI2kyNWtRmhKm1mUVE1CYIrZ+mErSpSaqlmSWMGjXafSukRzif3LwZM8RtYO9DbM2EFocXeGPoQXRc30incnDmwp4LKjj5mPDGjBpLiiUDJVT74R9gd7horG1D37Ev8lKvS2jjaWBVV3jy8GI2/fsq1tz1MFXL98e36gXCA5hD9+f+1IGbO3RgTWYaVGdQ+sQYGquVxoyb8Im64ITdPpl/MLnjC8zrM4E7O75An8z11Hiy+ajiZCZsfIWRa7/hnl338V3jkdSQG1oklYHbns6nXccDcOjWhXFbm1lRoEc78Oshy/QFYe8kmiIjWpJiaeY0+jRFAaalhSrAYsibEQRBEJqOD8vzWXnqIIZ38DVmFtT3Y79BK2ib6ZsJUyaA9FY46q2EBMJXThYRsYHvI5vQMUdvFaeRdtTBnFLfB4/utY6skhwAateZTdRFMWY6CZ0cTJK1mRrL1mYJVrkLgiDo0eyWZk6Me+kBS38FbZ0UjaWlQ7PfPHkziXIbEFoH0pwRWhR1GZnsOmswQ5fWkFMP6ztD90M3MHybb6YsWpWR1dyZzPZ/ApDtqOSfq+ZQXAFbcrO4wXkvv9w+ld1fHwueNIIPaodmq8VL4VmL+S43i0yPhxHLe+La2UbnOrXFjZfhOb8ztdsz/K/k/zinYB459nrW1Hfn/u3XMHLtB9y57R6+rzkCt3rwMFAh2Co9AOxJNwsWCEcKDUEQkkVz+iabEoulmaVQS0i8pZkZic+bsYoyAWi0PF8UYIIgCImlZt8eDL+mHftl+fJlllV349A9S7DbjVdURL3iMercGTAeq8ybNHPWlQBwTL/1pOU48NR7qN9Sp3MfbWan4WyfjwinAyirZfUEGir0LH6aI69TRG2CsPcRUdCWpIxOQxJiaaZ9SKsbMZEEbSmSN6Mi0W4DskqydSDNGaHFsLt7MRmje3Pozz6P4d8GwOitq+iWVhN2bcJUYH62fnMHW7/4G1e83pV+W8HpbcPlKz5iw9Lx4MkjWhVY4fG/UHjMSvDCQzvLOSs//Fd+VskOOl84jy7nzea4ghW83/M5pnZ/meG5q3F77cysPIS/bJjMyaVP8Y5zPNUendk0p+a1auKwc8M2ALalWZtpS0ShYYoUGoIg6JDMIkNN01maKWgVYMoxveuaSAEWQ96MWuQgCjBBEITmZ7zXS93xgznlzGq6pO1mk7uInfPSObDut5juFzKRZLBiMgzF5sVwcaZyQjteGe/PK+0BwJCirbSjkrrSavBAJMcBH4qVqDLRp5dj4Mep2jdbNVPq31pwHFCjFohYFrUJgiCoSLSgzQirgramszTTug0kErXbgFIviduA0DRIc0ZoEXzxwYv06Nqe/uu8uNJgcd9GzvhtJRkeawVCvCqwxroCTl7kYeTODbi86Vy38XnWu3pp3qRXSGitZBzkD/uTzucu9r38cj+OralleO4fZGf7BpqM4gp63PA5A+98n7MLl/HGusU82XEag7K2UuNJ543dIzj+z8n8beulLK7tRdBGTYO23nCq9stg8J6fAFif6wvYjMXaTI2VQiMR1maCILROmqrIUGPJ0kxbZMRtaaZVgKknlaAlKsD0EAWYIAhC0+NyNfDD81dy1qGlZNlcLKoqoc2n2+m5dUtU94nZDroDBt772YSPN1YdB3zsqM7jlx2+YJtD7SupXaed0dOuQNWOrxZoYmszy2iFKCJqEwRBg9FzIEBzZXQqxGxpBvqWZsl0G4DQMSn+vBmrRHIbEFon0pwRUhqv18vb91xM0X1P03E37G4DP+dVUTAjfv+TaFRgQ90Lucn+BAAP1N3OEob6z2i752bd9HSy+26l25VzASj/aiC/vH0wWxvbkWVzMeqQzXQ8czH73/1fLqwv5dnnG7nqcw+dd4PTm8FzZSM59s9beHDHeLa426vua0EFpqldcjzVHOX8BoA5HaOXnUuhIQhCShBDkdEkvskxhVrGUlhEomnyZvSW5wuCIAhNS/mOzax+9BiG73gPgBnr+rHlyRV885N/UFImdvwTPcrEj1nwsHpFZIgddInOxU2QOzN3g++DD7f/qtOcUaO1NotALdasSCNYmynCjqRbmwmCsNdieS7EJKMzKYI2LQmzNNOK2JSttvmuraX0nv3RuA2gOQfx5M0kym1ArKBbJ9KcEVKW6ppq3rnkcA58dyHZDVDazcbmLZtou3BzzPeMRQXWwb2DJxr+hoNGpnlO5j3XuZo35KM/2RWqAssodtLz+pnY0xupXNaDrW8dAtiZXd6fT3NzaDhlHtfnLefFl1xcONdDu2ooawPP9+3IsX9cy3PlI3A2KtZlUQZcqtkJo2rmkuWtZ11WCSvbDLBszyOFhiAIycCoyIhoaZZgIvomK1h5ZjqJM9QyWsWvlbwZEwUYJCRvxmycFQWYIAhCcvhj+Xxczx/FoIafqfZm8f433fD8Zwn2KFYixmVDGfUKy8g2ZnpNmvlVPoHFcH6hsVJv5aneFvQn+ZTjOjhV+1FYm+mhJwCJydosRlGbIAiCGQnP6CyN4j2WBW2gb2lm5jZgFXEbEJofac4IKcmfK5cx54xDOHCBE4BfDm7D6E9+IGvLnvCLk6gCc7R38cSumylqLGOVvT/3eu/B0EbMRAVmy/DS44ZPceTXUbuuiI0vHANeO1k9y3ivsJK1S9vx2CsNnPaDl9x62JDn4LkT7Vx1Wnuen30WNd4M//3iDLh0+janVv0PgFntjwP1pGgTWJuFEWHhjhQagiBYQd34NXrmqJ9RZo3lprM0A3MFGDSZAsxq3oyKSAowBbPxGEQBJgiCEA+LPn6e7v87jU6UsdHWhbLzPuOsOb/GfD+9iSJTxwHtykoFy7kzCpEbNgvTBuHx2ujn2ErnvCrVNXqWZuqtGrXjgA4JsjazLPgwIRGiNnEcEITWQ8yCtubK6FSIZGlmirp+SoSlmR5meTPaFZ8RsOg2IHkzghppzggpx5z3n2PTJefTd52HBgesPPsAznprIZk5bUOuszqZE7MKrCPcvPJxDt6zhCpbHjdmP01dxIeyXiHhoOsls8jqWo5rdy7rnzoRe6abo878imcO+JiHPqhg7DIvGY3wu6eQuwZ045aJXub0y+TPF8biqVP+maon8NQDkMWAS3+hMbDxV46o/R43aXzY4czm91DWYlEZL4WGILQOkl1kNIlvsprWrgArsf4t4lJhC4IgCKa4Gur58bkJDFt+O1k2F8uzD6PtDfPpOeDghNw/FXNn6joXscLry/08tvefEb6cUe1kgtPCNRatzfRQC0OMVOkiahMEIZmkpKVZrXYnkqWZlmiaNNHkzehcm8S8GQVxG9j7kOaMkFK8d8+FtLvvWTo4YVc+VN1zNafd907gfDzL9qJVgR237nMuKX0dgDuKH2S9vST0IRyY3zLLnXHQ7ujFFBz+K95GG5tfPpajD1rAf4a/zosr1nHESi92L9R2tNNzdBmvn5TPqtO24fHa2fjiMdRvUrJl9NRfsVmbTWh4BYCZeePYnNkteCLJ1mbqQkOszQRBSCSJLDIiovesjCnUsjkVYNr9CESZNyPL8wVBEJJL+fZNrH7sGA7b+T4AP3a7nMGTPiO/IMKSxxaRO6Met1Sk2cjsnM23nv0AGNtHr3mkl0GgoLd6VXNajVO1nwRrMzWRhCSGiKhNEPYakiFoS1hGZ0ItzfRWNRpZQGtRnvlW3QbUDgN6bgNqkp83I24Dey/SnBFSgtqaat79y2EMfncJWS5Y191O9/+8yZFn3pDwz7KiAmuXv4t7vp8MwCudL2d27rG+h64h+rkz9sx6Op31FTavl4OWZvJG4cdM2fILB6730GiDmd6unFZ6OrO6D+LHkjTWDfY9tLe9M4w9P/Uw+KzYrc16UspY1xe+P1fBBP1CIwprs2g8lE2J0UNZCg1BaJ1EVH4mqMjQI/mWZs2tAIPmyJsRBEEQEsPqpd/gfuEoBjWsYI83m2WHP8dhE57AnpZm+J6kOw5AgnJntMeCjZvMzlnYHHbmVQ8CfCtnbHjRn4wzW5Fq3XEASLi1meR1CoLQXCTieRK1ZWNclmYQ+oxOlIhNjbYpo3Ub0Fv1qUMT580IrQtpzgjNTunKpcw+7RCGLKwAYMWwfI6Z9iM9olmSn2AV2HXfTKFtQyW/5w/k6SE3RvGgDVWBFR+yhONX1fDkvxq5/as97LOjkfo0+CCzO8evvYC/rR7HqvpCpm0r4rYOhXhtsHtuf8q/3BdoUN3LLOASrFqbXd34Amk2D3PTjuaPjP7BEzFam8WDFBqCIFhusFpUhkZLalmaNZUCzGLeTASizZuR5fmCIAjxs/B/z9Bz2ul0pJwN9q6UnzeTA8demPDPaZ7cGXNLs8xuvmu/X9+Z6oZ0OuVVs39HveWsiXMcMMTI2qw58jpF1CYIey3JFrRZyujUErOlmfagS3PSaPVMtPbPRhjlzURBlG4D8SBuA60Lac4Izcq8959h48UX0Ge9h3oH/HbewZz9xgIyc9pYen8yVGD9N63irOU+i4AH9r2dRrt5oRBOOu3TKplYOJ2PdnzKFbM8dNkNezJsvJ7fi2PXXsHdP5/GZldbIB1HQQ3OE5dQa7czvLaW/qu7U3zaMgY+/zZ97vkIbEqTxkrApTH7Vv/KKRmfAPB85rXhqgWIaG2mDMCRVGBSaAiCECvR+qWnpKWZKWaWZtr9eDBTgEFEBVgRTZI3I8vzBUEQIuNqqGfBs5dxyE93kmlzsTxnOAXXz6fngAMN35OIiZumzZ1R0K+9MrtkAbBnayPzSn0OA2Ms585AtI4DQPTWZibEk9cpojZB2LtJaUFbwizNTFYzhojZ9ARtkWoptX2mnnhNK3hD5xoSljcjbgOCFmnOCM3Goi/eoeDe5ymqgLK2UHXPNZwx+c2kfmZEFVhXD7e99yBpXg+flxzPkoHDwm9SgGHuTJ+M9dzT8QVm9/4b1xR9RgF1bHK34+G6kRz9yyQeXnQ25fXBgsSW4abH9bNJb19DYV0aR9bUseeSeRSf8hNpOQ1kl5SRlttAOFEGXNZ6ubXdwwBM94zn16r9gueisDbTI55CQxCEvZeYi4yWYGkWEmoZjaWZmc1ZNGgLEAhVgEnejCAIQkvip2fP59CyDwH4oceVDP7bDPILCuO7qcZxwCohjgN6JDx3xrd17fKNjW0OaM9XG3xCi2DujNlYGpvjQHNbm0UlatMgojZBaP2ktKBNISGWZlq3ASOirZ/07Mz0BG16r/0Y5c2oELcBwQrSnBGajYNGn8XafTJZ29NOz/+8xYgzr7P0vmSqwI75YTaHrVpAnSOTx8ZNCr9Ax+ol3dbAuDYzeL37zUzvdTVnF3xJpt3Fz7W9uWnzZZyw9g5eX38iNV7FzN9faNi8dLtiLjm9fSOW0w6PFrajNrMxcG/nD31p3JPlf2Ul4FJ9LFhojMqdyyFZi6jzZvJU9Y26f/bmKDTCfiBIoSEIQhTEoyRNDUszBSMFmJmlmYJBcHIsCjAjRAEmCILQrLQdfSPltGX5ES8w/LJHTfNlImIw0aNMDOmtfDScUCrRvE5K7oyPiu8rcO1uwNE2nQWdDgVgRI8NZKbpKab1tkm2Niv1bxNkbRYVSVLLC4LQ/LRIQVtclmZgbGmmbtQY10e92zk5uPM2w/NBlAaNUX5njHkzJRY+2o+4DQggzRmhGUlzOBjzyhcc+7+F9NjnoMTdOA4V2DE/zAbg3VHnsaWga+gFmgduD9t6/lbwKHN7H81jXSYxLGcpbq+dL6sO58IN93LuhruYtWcYjaiLp+BDv/DY32g7bF3gdWNGI1keT+B1/da2bHn9UP8rqwGXEGqZA+k0MKn4UQD+U3kxW71dgpfGYW0WCSk0BEEwwqjICGvAJrjIMFOARXy2JcTSDBIXaqltyOgt049SAVZA1HkzVhAFmCAIQnz0O2AEObf8ygFjzo/5HtFO7CQvd8ZsPDK2k/Y2eCn7dCtej5ctfQewpTaf7HQ3R/bYYPgeH01kbWaBprQ2E1GbIAgKzSJoMyKipZn2IWzmLBBeSznsjay9/lUWX/k27bLU9zJqwATe6d82Td6MuA0IaqQ5IzQr+e2Kycoxk+tGQQJUYJ3KfN3137trlpH7GzM9XaVcVv8Kb9vP5fM2x3N5/qu0d+xmq6sTz5RdzbF/vscNW+5iae3+gPLDV18F1m7Ur4FX3kYb5V8N5IAqXyMnzWVn/ZTReOoyNN/YKODSeHLvkvb/pldGKWXuQl6uvCJ4wqm6qBmtzaTQEAShqYnKNzkuSzP1vtVQy3hJgAIsirwZPTW1Mt4q468RogATBEGwTnautUxONamZOwO+cUg98WWUOxParGnYUkfFj+WAjW/TDwBah7VZMvI6BUFo+SRb0GZEQjI647Y0g3BLM6N6Keg28NehPwWO1ri0DgNqIVu6zjHJmxGaB2nOCC2eRKrAOuzyjRqD1v9CV+cmDt62mJP2TOeGVU8xbcV4Zm48gUkNj3OA7Sc8XhvfuEZw7c7nGPvnl7xQfgM73GrJsdpLObzAqPixP67yXJw/9GXNnadSW1rIj219BZTH4aHP3Z9SdILZEke9IiO0qOiavpm/Fj4PwKM7b6Xam+c7YWS/I4WGIAhJxnKRoUXzbGh232Q1hpZmVkMtteesNmmaWAGmItEKMEEQBKHlkJjcGYicO6M9BxXflrPnZyffeny1xfH7r9fc08zaLAqcFq6JwdpMRG2CIKQKkSzN9J5Xic/o1EM9txWdpdmUE+YF9usbjVZj6h1Pbt6MGZI3I0hzRmiRJEsF9ntvX9PgL7Pf4Mvnx/DGzIt45JtbuWrtv+hXuwYXDr7PHs59mXcx2jOXq2tfYm7taBoDD3etCgz0u+8Odn4ylFV/u4BN/xoNaR66Tfg2cNZrg7RsF3n7bSY0c0BPBaanvq4Fariz+P/IstfzY/WhTK8c7zvt1FxuZm2mHWCl0BAEoamIojGbMN9kIxJqaaYXammuAAvHigJMe22MeTNqtBY2fhKhAJPl+YIgCE2IMuGjsYM2W/GYCrkzkA5eKJ+5nc8W+CTKg/M30+coYzs0H5FU1xqSaG2mJtq8zjBE1CYIrZbmErQlFUuWZmarINE5p8/89V1Nz+u7DWhXdCYub8bMbSAS4jbQ+pHmjCCoePyFidz2twfZVNgVV1o6Gwq6s6DTIfyv62n8vfdDjDjoWyZ0fpV3M85nh5mk2FSdHK4CKzh0ne6V5V/uq3PUyNoslGPzvmdk3o80eNK5b/tkwBZ7oVFqck4HKTQEQdAj5iJDQ9J8k5OmAIsv1NIaehk0Ccib0SzPjztvRhAEQWg2rE7wmE0YNV3ujHnDZe2cWn6u9E2+HTdYb2zSW5mqoDcJmBxrM0UAEmm1bkLyOgVB2PtIsKCtZVmauUP29+0Q/OAbZx2tOmfVbUAhdfNmhNaJNGeE1kWcKjCv3c4nx5zC2Ae+5IDnlnP8NV9w6Qmv848hDzB9wClUOtpG8WXSCT7UzVVgFUv6UL2qI7u/7Utmo29yMN1tp2ZtcYTP0A+4zLXXcEfxCwC8susSSl2drH/tBFibqZFCQxCEqEkl32SFqCzNEhdqGR0JUIAlO29GlucLgiC0GKKaQEpo7oy+pZmWL1b1BODo/N/9cZ96q04jrVTVG6v9OA0/OkgEazM9Euo4EEHUJo4DgtDySFVBW/NbmukRfMY/MfbrwP7SrR3RdxvQZsy0rLwZcRtonUhzRmgVRKsCM8ud8b/Aa1f989BTgRVgQQWmPmasAqsr7cC6B8dR80cx9Wm+h63L4aHkls/9V0QXcHlLh9folF7O+oYuvLTrTJ3vhLEKTKG5PZSl0BCEVkVTFBkJtTRTnn0JszSDWEItjWliBVgEolWAyfJ8QRCEpiFZdtDJzZ1RMGvSOPj89xIARjh+wVGg18QxchxoHmuzpOR1CoKw95KKgraosGJphuaY8fP7uL7aDDKr6LkN6EUW+Elg3oyC7evI1witF2nOCEIkjFRghljJnUlHr0njaBtaBaSnezRXmAVc+gavw3J+4+yC2QDcte0GGryZobeoJbS4UO8nqdBQI4WGIAimRFFkREtUlmZ6NJkCzAwreTNRKsCsEGF5fjyIAkwQBKEZSZncGWuOA+r9bzd0p9aTTifbbgb32h3hM8IFbeGDd4zWZnpCDhURs+2iRPI6BaH10uIEbQqRLM2qiUHQBr5nt7YpE96kSbc3BvZv++rICPdXuw0o9ZLRSk4TkuE2oFhBi9vAXoU0Z4QWS8tQgVkJuAxOtO2cMZj1Tx4b/H5b21n4nGChkWOr5r6OUwF4a/cJLK4dbPXLhhKlh3IkrA74UmgIQuskUUWGlqT7JivEZWkWKdRSKS7MQy3N0VMLJz5vRo9ELM8XBEEQmo7AysUIEz9Wg4oTlzujPmYtd6a+0cGPe3zCizF9//QfVa8+1XMc0IojlFy4OKzNFJJobSZ5nYIghJAqgraEWpppL4hUHwWf99cesjxw9KkfDyJUpGY+loSeT628GXEb2DuQ5ozQekkJFRhYf7inQ6OdxpqMwJFNB5SS2dmpc62+tdlNHT6hW0Y5m12FPLnzXIKFRpQqMIUm8FCWQkMQhAARiowm901uckszt8G+FawowOLLm1GLFwzHTYLjbWB5vijABEEQWiwR7aDVxJ07A1YdBxTm7PQJMo4tNprwMrI2s0AKWZtpEVGbILQ+Wq2gTRc9QZv6oat9dptbmj15XNAXrL7RSLhmlDejvkZDgvJmEoG4DbRepDkjtD6aUgVWgIEKTDsBZqQC08mf2dCe9PWFALjtkNNHmQnUZhCEbg/O/oML2n0DwN3bLqPGm6Xzhf00g7WZ1UIjWqTQEITUJln/FmP1TY6ZZgi1DCcaBZia5ObNKFgdX0UBJgiC0LQky3EgIklzHAhOrn1ZNgCAIwrWkOWwqrIG31hsNIA3jbVZPHmdYYioTRD2HppL0GZEJEGbrtuAEdpVj9rnunmTZl5ptwj31zZlmiZvRtwGBDOkOSO0GqKd7EmICsyUbPQDLvWKjeAkm6c+nZ//eQIZHl8Rta8mMkaPHNseHuj0HgDvO4fzQ82gyG+KRBzWZvEUGmE/JKTQEIRWSXP5Jje/pRmaY7Fmz4C5Aqzp8maiXZ4vCIIgtFzUKydDJqBKNBcm1XEgyO/OYrZ425Od5uLoklL/UTNLUb3sGbW1WWX4hzij+EJROA6osSJqizevU0RtgpC6tAhBWzyWZk7isDTTq5eC4uX9ioOdoJtmHa26RqcOCjnXPHkzCuI2IChIc0Zo0aR+7gxYU4GF7nsdXhrsvoGv3ZifDe4bHKz+XjyL7hm72OxqxyM7TyN8aWjTWpupMRropdAQhNZPVP8Go/BNjvaZkTqWZhCLAiwycSjACmjWvBlZni8IgtCMKBNABnbQlh0HtESVOxO74wCA1+VlXuMQAI7vq1Xr6TsORJX11kTWZrEgojZBaP20eEGbLom1NHtibNDSbPm2YkLHC3WdpBWxaffFbUBoHqQ5I+wV2L6OfE3T5c5EDrj01GbQ82ffYFjaqRLSPP4zWmszODr3d84qWIrHa+P2rWdT7YliQHEa7Cep0Ij0A0BBCg1BaN3E65usJqm+yQoJsTSDaBRgxiRJAZbsvBlBEAQhZUi444DaDjrq3BmwljujPefD67LxtecAAE7oa0U00AzWZn4hSCIcB0TUJgitjxYlaFOwKmiLytIMjFc5Gh2DMX02WLivdv4tCXkz4jYgxIg0Z4TWiUYFphC1CiwpuTNgVmAoLP9kv8B+cbsa3WvapVVyX6eZALy++wgW1/bBsgI74pJSmtTaTAoNQWhdJLLIiMc3WY+IvskKCbM0g2gVYOG0cAWYf1wWBZggCELLJTVyZ0L33VUuvvMMwuVNo3/hLnq326Vz72a2NtMhUXmdEX8jiahNEFosKSloKzX5kLgtzdR1krLV1kvukP3MtODrW74cEeEzlBpKO74kLm9GD3EbECIhzRmhVdFycmcUjJs0Wze3x+HPnRl27hKde3u5p+OXFDlqWF3fgafLtL/MU9PaTAoNQdi7ibfIUDdyjZ4h6meOaaO4ySzNIinAYrU0S5ICTI0owARBEFoFiZzYabrcmciOA3WlNTir0lni7Q/A8X2V76ZdjZpka7M4RG3RYipqM1HRK4ioTRBSh1YtaNNF+0CNZGmmR/DZfcNhSwP7UxYcqLrGzG0Awt0GtHN3JojbgJBgpDkjCCRbBQaxqMA8tRnsX1oIgGugekLMNxCdkr+SMW3W4PLauW3riTR404lpks9psN/MHsphSKEhCC2G5igyEu6brBCzpZn2AjMFmPp8tI2aBCjAtEjejCAIwt5HU+TOFBCD44AanSaNF6p/cQZyZ04Iy53RI5LNqELTiNrUQhIreZ1RI6I2QWhxJFLQZkTSBW2GlmZqtwE1eoI28zmuh4/9NrDf0OhA321A7Tqg3YfIjRwdxG1ASCDSnBFaPCmnAjNstkenAgMoXegbIN0ZddjtDYHjXRxV/KN4HgDPlh3OyvqOhAdOJ9jaLE4PZSuFRtTWZlJoCEKLoimKDDWmvslGKM86dVMmYpERbaildt8q0SrAtPsGJChvRhAEQWg5BCZ+ZpleFiAhuTOmGDkORM6d2fNzFV97fM2Z0b1KQ2xujFFP/hk5DmhwGtwq1fM6dRBRmyA0P00laDOyNNMj4YI2J9bmnSJampnPcc38o8TKh6jQE7SpXovbgNCESHNGaP00tQosgKICM5okUxca+k2aBT92A2Cnw87wEZsAsOHl/zrPJi/NxdLazkzdNUz1Dj3/5ARZmynE6KGsRgoNQWi9NHeRYdQIDsNIAWYFZxTXGirAtMfMsJI3o6cAa9q8GcPl+RYn/gRBEITUJxUdB9xOF0vXtWebtx056W5G9FzvP69elao39kYhlDCzNlNoQXmdgiCkLi1C0KaQEEsz5Xg0lmZuDu26JXD0pllHq67RE7Rp3QYgvFbSruJUoZc3I24DQoKQ5ozQ6kg9FRgY585oi4tQhViDsy19KtPx2mz0GPULAH9pt4JDc7ZQ43Fw+9YxeGj0vyeSZY4JToP9KFVgUmgIgmBEIgMto6X5LM2iD7W0jl5T30gBpjpcQLMpwGR5viAIQvOSrAme5s6dAahcWM7XfmuzcQOMBhy9ib9ItVMLzevUitqicBwQUZsgJJ8WL2hLAUuzZ06YG9hfVd4eY0GbFu140rR5MwoBQZuw1yPNGUHw0zQqsGzVvkV+6wlAbccy+mbs5MaixQA8vOMINrry/BcZFRoWUI+JRgOo1tpMCg1BEHSIq8jQoP23r27cGj0zIvkmBxrIRs+sJrM002LleR2PAswkb6aZFGCCIAhCCmLgOKCHJZvLhOXOGFuaKdStq2b27gEAjBu4LvJ3w0VLsDYzI15RW3MKZwRB0KdVCdqcRGlppqCtjfSbNMO6+oq3Oneayb1TL28mbGwVt4G9HmnOCK2ClMqdKSBC0105qVWBGVibfe2bYNzq8PBEz7lk2huZt6c771fsa3B/7VJ9PVWCgbWZGjNrMwtIoSEIgoKlf5MWbAv1iMs3udS/bXJLs1hXyDSvAiyRyPJ8QRCEFCLChJAyodR8uTMQKkYIdxwAB59+2xm3184+eTvo2dbpPx7JVrSJrc1iyOsUUZsgtC6aW9CmJiZBm4LR885Q0KZ3TH1cXSfpPZuDlmZtMuoDRyd+FqmQ1NZQqZ03I24Dex/SnBH2DqJQgVnCTAUWgl7uDEQTcLnt904UuuDk76CvvYLd7kzu3nYUoB4otYWGngpb6+epwWmwX6bZRiCaQkONFBqC0LJJZJGhpVl8k5NuaRaDBWWA5lGAmamkJW9GEASh5RHrBFCqOg5s/aWBpY2+3wAnDvlTc1Y74RfnmByrtZkO8Yja4kVEbYKQOjSVoE1vTiYqQVskSzMwELRp3QYU9J7B5s/lfxy1ILD/2vJBqjNmY4a6sa/Nhk7NvBlh70CaM0LrJmVUYOB74EcfcAk2Dvgpj1N+9H2HydsPp6wxh9CAS6NtnNZmcXooRyo0Iine40EKDUFIDWIpMlqOb3IiLc2ibdSkhgJMGUcjIQowQRCE1k3MuTMJdBzAA7OdPneBcQOsTIppayUjazONKsNpcLsUyOsUUZsgpDatRtCmEJegDcIb5XqWZuH8/YjFgX2P146x24BWxKbdt+g0AAnNm7GKuA3sHUhzRmiVpK4KDKINuMy1N3Da3FrsXvh+P5jv6mbhMxJkbaYQo4eyGik0BKH1kkpFhpqE+CarcUZzk0RYmimIAkwQBEFIDFYneqwEFSckdyZA/I4DAJ+X+mqOkcVrSLc36ny4nuNAkqzN9IQfkNS8zlgQUZsgND8pK2jTopfRqSYuSzMF7TPZrdkPjmNvrzAT/GrHCG1N1Xx5M+I2IKiR5owgRCD+3Bm9TrxRdz68wLi9+AeK3fXsbAv/GpvGyOONOk9WFAYxWpspxOGhrEYKDUHYO0hmkaFG/UxJiKWZGqfmdVQKMCOrSasTQamjAEskogATBEFIQZSJoTmhh5WJJLOVkqnjOOBg8ar27PS2JS+tgSN7rPMfj+Q4oFyTYGszhSRbm0UtZBFRmyA0Cy1S0Fbq35pldMZlaQbhz2QjmzMfx/UpDezf+uVRJl8MgjWUdhxR10sqYYDkzQjNgDRnhFZDtCows9yZuFVgYUSjAguqoI/JW8fpbVfj8cL/jsqmNtNG16Hqp7WeHY5WcdCE1mY6SKEhCK2b5vq3FMkWUffZkxRLM2Vf2/w2mwBSiMXSrHkUYDHlzQiCIAgpT7QTQdFONIWQZMcBgLotdXztGQLAuAHa3Bk9UtvaLBZRW0THAR1E1CYIzUdTCdrUpI6lWaTc5HCeHRcsNjZXtSF0bs1sjDByGzBB3AaEJkCaM0Lrx0AFppAQFVgRCVSB+W+ZVsO9Hb8FYOquIfyy2/fDvKx4N8YNF6O8GcXaTEvyrM1aQqEhDRpBSB66//a0DdQIRYa6QdssRYYap9Eb9J6hRsVFoizNrCjAVK+bK29GlucLgiC0WvQmnprbcQDAU9PI7Epf7szx/Y2aM81gbZaAvE4zRNQmCKlNKgjaEpbRqa6XLAna9I7pWZrpPYu1lmbQt30FANv25Oh9IMEmjdptQM95QNwGhNRAmjNCqyVWFVjUHe6oFWAQWQXm5f5Oc2jvqGNlXXueLTuYBZ/si93rZV2mnf0OMltTqkZPCWZibVaLdWszrdK81L+16KFslWQWGoIgxEeLLDKMSLilGUQTahmOkWWZ0XWpmzcjy/MFQRBSi0RO+DSv44D6XPD1l+v70Oi1sV+77XTL3+U/btXaLEqsWptFiZW8ThG1CULLJ9GCtmiJpymsi1PvYDSWZnoE39Mhpyawf+1nFh5sIRhZY/pRBG1GiNuAkCSkOSMIFjBVgemhqJNDVGDaiTJ10RFqaXZ2218YmbeBek8at24dhYs0qnbm0rPGV3wMOUZtyaYNm9Yei8LaTE20HsomqAd8KTQEofWT0kVGqX+bMEsz0G96RxNqaYXWoQATBEEQWgD+CSOtHXTLyJ2B7RvsLPf6ao7j+1oRFug5DkRpbabet+o4EGdepxYRtQlCatKUcw7quRQjQZup24BVQVvSLM3MVzL+c/R3gf2PflcXnVbcBhQM8mYUlPk8vbFKPaaJ24CQIKQ5I7QqUkYFZkg2+qtmQBkweqbv4tbi+QA8UXYYaxo6Bq51rO8EgLvHNown+4yUB1FYmzkNvn4CrM3UWC00okYKDUFIOi2qyLCKUZHhxAC9VYjqgsJaqGV8NK8CTCGgAIuALM8XBEFIYSxOEKV67kz95hq+bvTnzvQ3as7oWZtFgZXJx0jWZjokUsUuojZBSG32DkGb3jErlmb6TZorD1Y3OmyEuw0Y5W/GkDcDMTnliNuAEC3SnBH2LpKpArOcOwNGk2lpeHi486fk2N38UN2dN3cPDnnXL18NAGB1XiOFHXRHPw16ljoWrM3UOHVu28weylJoCELqYqnIiEDCiwytb7KCnm+yEU1qaabQRAowNREUYArKuKmMowGU5fmiABMEQWgxxDohlJDcmTgdB3yENmtcO+qY5RwEwJg+68hIU4slwFzk1jzWZsnI67SEiNoEIam0ekGbGqfewURZmrmxEaxTnl04xOQ9kdwG0hG3ASGVkOaMsHeQbBVYzLkzyoDgGySuLPyewdlbqWjM5I5tx+IldCD/Y2lnihu8uGw2jjjpN9UZK9ZmMSCFhiAIOsRdZERQgKmx8iywYpcYgpkdY5NamhlhpADTNmQSpACzkDejEO04KQowQRCE1k9zOQ4Y7ztYsLyAbd525DkaGNnzT4PPMMqbSaC1WSQBiIjaBGGvpLkFbbo0WUYnxGJpduHgoCjs7rmHR7i/dhWNsm9QKyluA7pZaCQ+byYC4jawdyHNGaFV02QqMD0MVWD6AZcDMndydeE3ANy/fSzb3W0IX7Jvo8OOAgDa7rue6KzNlIGuiazNIg3sUSKFhiCkPskuMowUYHqkjqWZHkbKXTOMGjZNowCLdXm+IAiCkJpEnPhJcceBSOz5pYK5jQcAMH4/q2NYgqzN9ERtUVibqYk3r9MSImoThKSQyoK25rU0065mtGZp9uJJXwX2d9dlE1oHGaFeZamtlfTm50hI3oyC4ZgpbgOCCmnOCIIJManAOhBBBQbaASAdGw90eo90m4dZVQP4rGogRiqwzYt9E5Ib2lVjtzdG/n7NZW2mg/oHQHMVGmYNGkEQIpNqRYYaS5Zm2iIjaZZm6v0YVy+GkVoKsGgRBZggCEILoIU4DkTKn2msdPH5jn0BOKn/GgjY4agFEkaOA5GaNDo/DPQmJ5spr1NEbYLQ/Jj9O2k1gjY1Tr2DibM0A8hJ922XbCnWuVapk/REbNp97cpMA+LImxG3ASEapDkjtDqaXQVmmaDK+crCLxiQtZVd7hzu334SYDyQ//hZf7I9XsoddoaPWWdwVepbm6lpykLDDCk0BMGcZBcZZiTVN1khoZZm1hRg4agVYEaBxwlUgKmJMW8msDxfFGCCIAgtltR0HNCiVUmn65wL8sWKbtR70+mVu4uBRdsMPtwob8bM2kyD02BfQStqUyj1bxNgbSarZwShlRGFoC1aEiZoc2peJ9HSbGBReWD/2s+imfjRrqqxvgITkLwZoUmQ5oyw95CSKjAYkLmDKwt9X+6fO05jV2Oe/4z+pJy7IZ1eFZkAlAxfTfiS0CRZm6n3I3kol/q3KVpoyOoZQWgmIhQZ8fgmmxYZRug9y5wG+yEYrDpMiqVZkhVgkjcjCIIgJICmyZ0B/caMvuNA2aoGfvD4V88MTAFrMwWT7LtY8zq1RPxNJatnBCFpNLWgTT1nkjqCNqPjepZmavSfwU8fPzewv2BzF9UZvWaLUjtpxwvteJId3BQgbgNCsyHNGaHVk5oqMN8BB24e6PSK385sMJ9XDfZfpx5EtAoxB7WrfINoRSera+X1Brgorc3M0Fqb6ZBKhYYZUmgIgj6pUmSosWKPCMTnm2zZ0ix6BVjsNK0CTPJmBEEQWidWJ4C0jgNmJC93Rj2pZm5ppuDe3cAXVb7fE+P3+1N1RjtWa481g7VZAkRtZlhS3YuoTRBSgyQK2nSxmhcclaBNa2mmfmaaCYq1r337Y/ps0FxjlMepRc9twGh1JknJmxG3ASES0pwRhAgkM3fm4vYzGZC1kd3uXO7ffi4+O7PIk2wLPt0Xm9dLaZaNXgONfvFrC404Jgn1rM3iKDQSSTIKDWnQCEISSFCR0WS+yWEoViZKYZEMSzMFUYAJgiAIzYgyYTRH/7Qy4dTUjgOhY565pZnCpyt9vxuGd1hPu6wqzVm9CcImsjazkNepxoqoTes4IKI2QWh6mlPQZoT6+WHa9C31b5OW0QnWBW3B87npDYH962aOMrm31m1Az3nAgtMAxDFOiduAED3SnBFaJbGqwJoyd6Zr+jauKfwvAI/sPINdjW2wGnC5Y1MBfWp9/3wPPv43YrM20yMKazOFGAoN9Q8CI+V7cxYagiCEkopFhhpLvskKVn2TnTrnAyTL0qyJFWBqEqUAEwRBEFosqew4EERbH5k7DvyxOpOVnu44bB6O6/sn1ojD2qwJ8jpjRURtgtACiEPQZmRppkdLsjS7a+SPgf0XFg2J8AW1NZTRvgXEbUBoIqQ5I+xdRFCBKSRcBVaAptDwcmfx82TbG1hQsy/TKtXdf2sBl1kbiwGw9d5s8ctpBzo9FZjOJXokqdCI9APCCCk0BCHFSUKREbVvsontYmpbmiVRAZaMvBlZni8IgrDXkkzHAf0JtsgTbfWbapjtPgCAkwepJ86SZG2mRyTVeal/2xx5nSJqE4SE0aoEbQoxZXQm1tLs70csDhxp9NoJbcQbrZzUjg/qeknV+Fe7DRTo3EbcBoQmQJozwl5Bk6jAjAIudSxljstbyMi8RTR4HNy77Qp8dmbqQSVywOUvc3yTmWtzXbRpb1QYJMnaTKEJPZSl0BCE5qFFFxmRMHqGObUH4rU0iwYzSzO9a1JTASbL8wVBEFoGzZ47ExVmuTP6E3Rel5dZ23x103G915Bma9BcYWWiMMnWZjokK69TRG2CkMKkkqAtUkanmiRamkFwDHlt2SCD+2ob9tpxQes2YEJT5c0IggppzgiCBWLqhBtYyGTbarit+DEAXtl1NqWunqqz1ifZflnQlU4NHhrsNkaMbznWZlJoCMJeRhMVGbqU+rdWfJOdmteGRUYslmbhCrBwtJZmemjVYS1DASYIgiC0YJojdyY7bEfvJFYdB75bVcxubx7tM2oZ3n2ThS/notmszUTUJggtklQXtOkSj6CtiSzNzhkUHGP+/tUIo2+IseWlepVNCuXNiNuAoEKaM0KrJVVVYBe3/zcd03eyqaETL+06x380m+BAYU0FBnY67GwHQMF+6yJ9dT9xWpslwEM5HqTQEISmpamLjHhInm+yGdFYmsWygsbIxgxakgJMlucLgiC0QKKcOIo7d6YAA1GBIj7QrprRa8zoOw5U/VHDPI8vp2D8ACPxQYKszZwGl2mtgaLI61QjojZBSD2irpn0SLKgzbTJW+rfRitoc+qcT5ilmY+Xxn8V2N9Zk6M6Y+Y2oB0fDGolsxJK3AaEJkSaM8LeR4TcmaSpwAqgMLeMy9tPBeDJshto8GZoLtKGWJoHXG5e4htsN7arxm43UmMn0NpMDz0PUjWl/m0EFZgUGoKQWiSjyNCi/TesbsCa/dtXSJhvshqn9oDa0syomZ0oS7NIiAJMEARBSC7xThTFnDtjCSOXASMxmw/Xznq+qPT9xjhp37VE5zigvLZobabGaX46EjELT0iQql4QhOTQmgRtUVuaaY8ZCdp8x/MzfVaUy7YqA4XWbSDSuKCtlfJDjxUgbgNCsyPNGUGIQCJzZ65t+yy59hp+rt2fmVVjCVeBQTQBl9/P6E9eo4fdDjvDj4u20DBq0mizFfw4Dfa1xOihbJXmWD0jDRpBsICFIsNSA1WHpPkmKzg1ry0XGZAYSzMFUYAJgiAIrYPk5s5oHQcweB1k5m8luL129m27g5ICp4XPSYK1WYLyOo3sXeMStenVTCJqE4SINMeqmYQK2iJh9bll2LDWWppZf7YO6hD88KtnHGtypdZtQM95QG/uTYO4DQjNhDRnhL2GJlGBaVE93LvbNnBm3gcAPLrzFkA9GGrVz9aszdz1DkoqfEs7S4avtviltANhDT3T1zGuzTwmFr7Kk13+jze6T+TDnucxveREXu12KZPz72FE1jfY8DSLh3Jzr54RhL2Nplg1oyXaxmtcRYZCyliaiQJMEARBaD4iTgj5J5S0dtDJdBwIzZ3J17koOseBbb+7Wez11RQn9m9CazMn4cSZ16lGLWDR0hyrZ6RBIwgRSOKqmZgEbaX+bUIszSCypZnea33h8HPjZgf2F25Wd0mMBG0YXGM925kiYnYcELcBIVakOSMIKpKZO3Nhxps4bI3Mrz6SJbXDDO4efcBl1W89ASjvuNvkWyv4Couu6ds5r+AbnuwylW/63MbM3rfxWJcpXFP0Pse1+ZaDc35hYNYq+mT+yWG5P3JOwXv8q/gqPs49hRK7Tr5NDB7Kej8MjKzNtMjqGUFIMWJYNROvb7Ippf6tUZGhLjacBvtA9JZm6ByPldRSgCmEKcAiIAowQRCEFkwyc2ciOA6EE5/jQP3mWr6sHQzAyYP/JGnWZkarb1uCqC0Bq2cEYW8iIfmccdhAWyGpgraYLM3MnARCn8cjSzZrzmsFbVrUbgMK2nEjO7gpMPkqakqCu+I2ICQDac4IrZpoVWAKiVaB5VHF6RkfAvB69SX+o4oKTDtYqAeZyAGX30/flzSvl02ZNvY9aIvqmuDAZqOBA7M3cUuHOUwveYYvez/OXR0/5rg2yylyVFLvcbC8th/vO4/l4R1XcMPmf3DFxke4dOOL3Lb1Qd7cfSFVnjz6pa3h9Zy/0JVNcXsoq5HVM4KQOjRFkaElHmVnQnyT1cRlaRY51NIYM0szvWuiVIDFiDIOKuNiGMryfFGACYIgtBqa23HAnNgcBwBmrPb9kBnZeR256Q0WPisGazM1eqtyI1kElVq7tVVRW9SIqE0QmpVobKATJmjTEnNGp/aY2tIMQp+p5lnIBVl1gf0Jn4wxuEpZKakVsem5DRiswhS3ASEFkOaMsHfSxCqwM3I+JNdWw5rGPnxfd7jOJ2QT66Tb7u159K1OA2D/Mb+jDHA2vByQtZW/d/iW2b1f5a0eb3Fp+0X0ySzH7bWxqKYXT+88gYs23MCha57g/A3/YPL2K3h99wl8uedIvqs5hAU1h/BJ5ak8sOMfHP/nLFY27EMHexnXND4f/AJWPZRL/dsIKo3mLjTER1nYm0nW/8cTvWom4b7JupZmZl2aWCzNjIqPSAowSIgCTPJmBEEQhCYmebkzEIvjwE8rslnvKSbT7ubY3jpuAECzW5vpkMi8ThG1CUJiaFWCtoRldKpXF6rRPkMjW5r93+jvAvuvLtvP5ItpMXIbiEAHwseiRObNREDcBvZupDkj7FUkRQXWLcJDtAMc75oJwFsNFxCaNaMlNhWYbV1XABq6b2Fw1g5u7fADX/V+h7d7/o+L2/9Mp/Rqqhoz+KRiEDdvOZkj1tzKxRsn8K9do1hS25cGr1kzyDfq7m5szz277gXgxPRPaUNlqyw0BEHQp8UUGaX+bUJ8k5vD0kwUYIIgCELTk1K5M8o4FpY7E5/jQN36Wma7DgDg5APW0dKtzYxEbVYCwk0RUZsgmLJXCdriyuiEWC3Nrhn2U+CMNyyvWYtSO2kb9VoRtAWSlTcjbgOCCdKcEYQYMOyU66jA8jxV7Of5BYBv3P7RNqTQ0JJNdAGXaeyZ2ZmLZjdy4+tu3u05g0va/0Ln9Gr2NKbzSUV/rtk0jiPXTuC2bSfxedVAqjzqAUpRKmiriPBC4+eKIWxzdyTD5qJEvQZWCg1BaPGkWpFhRGr4JkNiLc0i0bwKMAVRgAmCIOyFNEfujCmxOw7g8fLZxgEAnNhrNTasjFNJtDaLM6/TKs0lapO6Sdgb2fsEbYm3NEuzeQL7T/xwkH9P6zZg9OzXE7RBSINfcRswzDhTURLcFbcBIVlIc0Zo9TSZCqxE89o/ITasdhFpeCj19mRrThedgkMZJKwHXNrwMiRrG5M6LODL3u/yROP3jF/opUMl1KfZ+bSyNxM3j+HItRdy27ZjmVddgsvrIHxw1KJWgemztdE3s9e5XmdJjJE3qRQagtCiaY4io0l9k8Noakuz1FaARcybEQRBEFodqZ07A7E6Dsz5qRNV3mw6ZlZxSNctBlfFYm2mrLZV4TTYj4YIojY1qSZqE4TWSNR2ZhZJeUGbmiRbmv1VtWpm8jy9WAAFPbcB7b7eXJuKAsLHHnEbEJoYac4Iey+JVoGpUT3MCxt9s4F/Onpb+BTjgEsHjQzP2chdxXOZ2+d13un5EZe1/5mu6Xuo8Tj4pUcWj55h5/FLM7l160jm7CmhwWtcmPgwVywECY6+afhUDG7luzlN3mbmWaqDUaGhRgoNQUg8LaXIUJNQ32Sn3gWxWprFSgtWgMnyfEEQBMGAuHNnEuo44KBiTR3zGgcDcOrgNSTO2kyD+pB69Uy0eZ0RsCpgEVGbIMRPTP9/bkmCNgXLGZ2RsGppFnrNMyfMDRzZ05ChOm8maNM7H4WgTdwGhGZEmjPCXkdT58405PsGk0zqo/gUX/GRbavn2LyV/8/eeYfLcZTp/p2Tj+IoS7bCUbYsWZaRs5wDTjjiCBiT810yCxuAheUa7rK7LOwCS3QAg42xjXPO2ZYsy7IkKx4FK4dRPHnm/jFTMzU1VdVV3dVh5ny/5/EzPd09M32Odabrre+t98MPx96F56b9N3474Q5cN+wtjG44hAN9jXhg31T83bvnYMHq6/Bv2ePx2ow6rBzRhZYB3dx7ic5u9ihzLPAqQp6hPCy1GwCwOze8/JJNo83aC48erg1VtJkICQ2CCEY1iYzqiDQzLXqrEB1gqkiz5DnAaHk+QRBEdZPIvjNl2CcOiGQ7+/DgjvwY5NJZpjeugNFmjIxkX4j9OsnURhDRUROGNt+RZrJ5JFVqiy5toJyHV08qbKlWxfDI0gYsobQBIiaoOEMQPjHqOzMS2Ng8EQBwXN9rmMLcyGlIG1w2ogfzWt7Bp4ffi9+N/xlenPbP+Onht+KSoUswtL4TO3sH4I7MUfj0pkuwYM3H8PUt5+HxA5PRlWvB609OxojeLA7V1eGUS1ZA3f/ANNqskpZUB8Y2bgUAbM2O9ZehLCERQkMy2ew1wKICDVHrRCUybPGVm8yLjYxiG0B0kWYMLweYeE48DjCCIAiiHxNl35lRMOg7A+gSB8ofK7n/rcnoydVj1uBtmDpst+KsEKLNGBH16xTx6llBpjaCUEOGtgJKqeQVaaYyCJdz9uT1xe0vPKj6ZbGVkaKJzTJtIC15a/6e1FbapLQBIkyoOEP0C2LrOwPgjUHH4IXWk9GEHtxZ/358d9C38aGmW/C+Affh2vRt+PTwX+D/jv027pr0cbw+44O4bdK38cVRd+HEgSvRXNeLjd3DcdPuBfjQhg/jjDWfx3e3vRfPHWxDT66+7HNyuTocvnMwAGD0PJObhFkzthIdmN2yFI2pXmzvHYWte8aWDmUMXi4jSUJDATnBiFomLpEhwhdcIxMZPBUiI4pIMxsHGCM+B1jF8nwPBxgtzycIgqh+Iu874+uexd8bZStoKu+xW1f04pXsEQCAy2avQuTRZowI+nXqTG1GxhkytRGEJzVraNPiFTvgL9LsN5c8VtyzZs8w04uBr7QBwLvXmSQlh9IGCNdQcYbo30TRdyaVwj+P+le83jIfLejC1XV/wT+03Ij/N/Ib+PaY7+OLo36Ky4bejyNa1qAx1Ys9vYPxyP4T8C/brsWFa/8J5637R/y/HRdiUccUZFEHnQts++L8pObmkfuAQm+YPLJoM0AdbSYKjXy02TGtiwEAizvmASgMsG0ylKtQaHhBQoOoVcIUGV4FVWvaPY47y00Gwos0s3GAlVZcFp+mQQ4wgiAIIpEY952RkYY0caAcXh95rTxtQO+ebjyy/2gAwKVHmd73AkSbZRTbIrpeeWRqI4jY6HeGNkZGsV1EZWjzH2nWls7PPa3PDC7sEXuH8TDtJCvMO0ob4LDpN0MQNlBxhuiXhN53Rvgy3zp2HD487lZ8uvV/8evsJ/BIz3l4ufNEPLr/XPwlcyV+tvMz+PymH+CcNTdhwZqb8OXNX8HtmfPR3nM48kUQswaXz987C63ZLHY21OHYs9fDf7SZnDMG5RuzvXrohMqDGckLxGgzGSQ0CCIWki4yePjvhPBzk1XoIs1ETCPNvFA5wDzwcoBJIAcYQRAEYYPzxAFxciyNvOlASSvUE3FqMxvjbytmAABOHtWOEa2HFGc5ijaTYduvMyTI1EYQerz+3RrPFUj+jhJraHPWo1MXaVbJlGGZ4vYn7ztXcoYY/SweA+SrZQRDmwzxHtSmOE9A1W+G0gYIG6g4QxABMMrpZ1/yqRSeazgN/5n7Kr7c8RN8rPv3+NKen+I7276PX+z6Ip46eAE297YBGCC8gXmDy85DTZi6txkAMP2UFYY/BWDi9h5VvxbvaX0DAPD4gXPyOzOSEyMUGrR6hiD84UtkGP5tuBIZfGFWSqIizWyLMHxxXZxAisABRv1mCIIgCAWmcdBe+EocYFiZDGR9Z8TnlcWad95qwLLsJNSncrho5mokJtrMsl+nialN1ExxmtpINxG1hPLfu485BS+qx9AGmOmkyv0/Pf+p4vZjaydp3l+WNiBuA6XVlgImaQMclDZAhA0VZwhCIIy+M2Wk4eECAyon4MwbXHaunAgA2DdOvMPKnN6qpmyVQuPcwc8CABZ3HI3tvWMSITRESGgQRHCsVoo5XDVjgrPc5ERGmvGPDMcOMEOo3wxBEARRgc+JJdvEAf99Z8RUAZ3JoQHdmzvwSMc8AMAV89Yafo5NtJkweZlRbNvgN5rIgyhNbQRRTYSZNFC7hja+0G0aaZbff9GMddwx/rp1kWay441wkjZA/WaIiKDiDNFvUE4WMaGRSBeYfYPLF/82G/W5HNa3ANPnboeraLPLhz4MAHh4//nlBzKKbRuqRGhQvBlRCyRVZMSfm1zqr1W5P4pIM9H1JXODAYEcYG2lTXKAEQRBECa4mkiyShwQSUPRd0acfDNPHACA+9bkTSLnjF+FlgaVHrKNNpOMJ1SmtpD7depWz3hBpjaiv+MszswB1WtoA0wjzQY1dRe3/+6hMwpbKn3EI37vGxRleKjfDJEAqDhDED5x5gKrEBoidg0ut28eimkH6wEAx5y/VPGhMtSu71nNqzG7ZSW6s424d997vd8qRqER5uoZL0hoEEmnZkWGCl+5yaIDjMd1pJnsmO57nhxgBEEQRHJITuKALNrMO3HgpTeGYnNuOAbWd+OsKe1wE20WAMf9OnWImolMbQThgH5jaFMhFq7NI83+5YwXi9u/eP1ozWfITGx82gDPEFSkDcjuKdRvhoiZflOc+dGPfoRUKoVUKoWXX3457sshEkDsLjCVy7noAvPf4LJu3WEAgO5Jm4Uj/qLNrkrfDwB4/MApyPQNK53qFW3GiFBoiEQtNKhAQ9QcSRMZMkxykwPBf0+6jDTzcoAx4nOAEQTRvyDNRADV2HcG8JM40LmhA491HwMAuOKYmKLNAvTrNDG18YS2eoZ6dhI1RpiGNluzaCSGNobTHp12kWZfOWlR8Uhvtp47z2slpCxtQLaykiMt2Uf9ZogY6RfFmaVLl+I73/kOBg70tN0QBAB/LrCyL+027oDsSz5teiX+Gly+/tAcAMCaAVmMmZhBkGizIXUHcMmQ/C/kL3svLux1mKHcXrmrmoUGQSSVahYZZbBCbrvh+UxkZLh9GfEkFkEiKg8+NxkIN9IsmQ4w6jdDEP0H0kyEJ3H3nXGcOIBsDvetz5tJLpq8Eimo7mEhRpsxfPTr5HHRr9O3qU0BmdqIasS3ZjI0tIkkwtAWc6RZfSpb3P7Jy8cUtsRIMwjPG6HvLebBKOgNbZQ2QERIzRdnenp6cMMNN2DevHm4/PLL474cImbCdIEZ48sFZtfgct2yUZjSAeRSKZx0afkSSzmqVTTAVeknMKCuE+90teGVQ0fAd4YyCQ2CiI1qFxnagq1fkaFckR9mpBmDHGAEQSQL0kyEjlpOHHj8jbHYn2vF2OZ9OP7wzYg02ixjfmoZAfp1kqmNIOIjFkObLRnFNgB1j04gSKTZZ497s7j97adOVrxO1VNM1EuWqMwAHJQ2QIRNzRdnfvCDH+Dtt9/G7373O9TX13u/gOifxOkCS0PR4FLErsHlwHdHAwAapm8Sjqic35UOh0bsx4fSDwIAbtp9MQDDG3zG7LQK4hYaFG9GECUMRXZtiAwdLiPNyAFGEEQyIc1E+EWVOGBMG7cdQ+LAvjWdeLp3LgDg8mNMjQqm0WaSVbkZyWmO+nWSqY0gglGThrb2wqOXoS1Qj05ZpJnaACwWv392wVPFI/u7m2UfWEDUUryhjb8HCPNqurQBkbbSpjKGE5Q2QLilposzixYtwg9+8AN85zvfwZFHHhn35RAJJBEuMCWiCwywaXD59hOzAQBrBvdgyIiD8BNtduGQ1zCmMYPtvcPw4L5ThNdbCA0VSRMaBFGDOI0zC1Fk8BiLDC985SaD27aJNPNDI+R9Z/hIM3KAEQQRLqSZCBlBEwfYxJVV3xmRCBIHcj05PLAlr5sumakTh7poM1FniZOV3G5GCP06TSFTG0FU4vvfYa0a2nyhmnMC1PqpdK+5950phS3xe5xH1bOTne+hndKovLco0gZkUNoAEQY1W5zp6urChz/8YcybNw/f+MY3rF63b9++sv+I/kmofWdGwdIFZt/gcslLh2N8dxY9qRROuextg88pdzbUoxOfGfEQAODWPRegp/h5DjOUZZDQIAhnOHV/SXApMvjCqxTZd0Moucm6XjK6pfp++83wyHrPMNw4wHR4OsAIgqg5SDMR1iQycUCcjLNLHLhvcRt6cvWYNWgrpg7bDftoM6DcSe5BRrFtQnvlLjK1EUT4JNHQxuPc0JYRd5hGmulSBSr3X35E6d7w+QfPlrxGVYwRv+ctDW2UNkAkiJotznz729/GqlWr8Pvf/95qaf6NN96IoUOHFv+bMMGuMTmRfGLvO+PpaFbdVCwbXCKFEVtHAAAGzW4X3ss72uziIQsxqWkHdvcOwp/2nAajDOWMYtsE8RKRMKFBBRqiv1NtIoORUWwDiCfSzMQBJm67c4AF6jfD7o+KiTlank8Q1QdpJsKUWBMH0jAwI4j3S/PEga3Le/FK9ggAwOXzQo42c9mvk0xtBOGEsOPMwjK08XMkRVwa2nxFmomPekPbLZc/XLr0fYNlH1hAZmLj0wZ4JIY2GZQ2QCSEmizOvPTSS/jxj3+Mf/qnf8KcOXOsXvutb30Le/fuLf63cSP9IfYbonKByUhDuJ/wLjD/DS7XvZgXGevSHWhqkd0Y5ZONDejEZ0Y8BgD43e5zcSjXovgEg4aXCRMa4uSxsdAgiCojaSJDh6fIkOEsN5mPNONFhp9IM5tCjc4Bxki2A4wgiOqGNBPhCj99Z6wSB6zQJQ6wOFF+H5Dt7MNDO/OGksvnBIk2001aasiYnVakvXKX1NAiQKY2gqgkjn93tWdoA7wNbXIGNeWPvbJpbGGPl6ENwnF+nyYSeiDkRRr+ntNW2gyz3wxBiNRccaa3txc33HAD5s6di29+85vWr29ubsaQIUPK/iNqm+S7wAC/DS5fengKRvVk0VFXh1MvX2bwOfkb4yVD38DEpl3Y2TsIf8qIM7BMaFhGm5nQXrnLtdCQYTzJTEKDqBKSKDKqMzcZsIs0s6W6HWC0PJ8gqhfSTIQJYfadMcIocUD2b0+VOCAjf+5dS/JGkROHtWP0wAMwizazuf877NfpgSpxQCS01TMBIN1EJI2qWjUjQzS0MXwb2iCcJEaaqag0sS2Y8G5x++P3vlfyGpmhrRHy6EqzGEv7gn8JV/1mKG2AEKm54syBAwewatUqLF68GE1NTUilUsX/br75ZgDASSedhFQqhXvuuSfeiyUST+h9Z6zw0eASdTh8Z16wjDpGvJHIneCN6MJnhj8NAPjt7rPQkatH4AzlhAkNIycYCQ2iSjH5txb3qhm+oKoSGVoHmNd3RuDcZCDaSDMIxwM6wHjaSpuBHGAEQdQUpJmIQCQqcQBwkTiw6q16LM5OQV0qh8uPsllZajIeiKZfZxSrZ8I2tRFElITdn1OG61UzZZgY2mRRZlbwqwN57CPNbr6sFGn29g5dNV5WpOENbZQ2QFQ33qOUKqO5uRkf//jHpceeffZZrFq1CpdccglGjRqFtra2aC+OSAwX53L6G/GTMJqcn4Y19jdMIH8zkN0UW1G4x/E3F/Gm14j8zYw96tm8cDrwvjewccR+pOr7kOsDyv/0ewvP84+XDX0T45v2YEfvINyeOdnj3ZnQaC1dKrv0g6hcFbQTpZ99JPJCQ9EXAZtS0hujitWYqnQyLMFRnvEKudMlk5BnodIReB6kgvTi6d4u8vtSKXJJELETh8gIbdVMaLnJfiLNTInAAab6XpVg3W9GAX23EUR1QZqJ8MN9q9xMrk/ARm/Xt6iX0oVHzxX6RUGF/D2X3bvZPbUHJe1Tom9/Lx7IHIN5w9fi/fNW439fnqd4f/Z69sjvB+SDDcXEYQaln4vfNqEd5UZAgY2YUFwpuwbTlJOLKzFTazxcPHxW0bChRKaZNHjpJtJMRBQ4N0/GFAPtxNCW4fZlxJNc9OiUM3X4XgDA+ozYa0angWTJAwy2mlJIG5Cl1VDaAJEgaq4409rait/85jfSYx/5yEewatUqfOtb38KJJ54Y8ZURVcEj8DVBORWrK1dyjM+VJhXHoXRjlBVm2M0iI3t3Npjnb4q8IOAfK3nu3pk47oKF2Ftfh5MuWI0X71dPkjamevHpEc8DAH6963R05uok78s/N3AoZBBYaGxfO6Ho+I5UaKgIUKAhiDCpBpEReNWMKRnFtjGqJfp6B5g5XpFmCXWA+XRLEwSRLEgzEWExb/dyLB4+C3PxFpbgKMzAO0qTxugpG0v3/DaU7vVjUL5aZBS8V+KXwQo0Pcjfa/n7OP+cbeeLNX99ewb+8VTgjNGrMLS5E3u7WlDSWcwY1yA8svcxgXeyKdiB/M+rMrVtQeUKoyoztZlABRoiTJwnDRhSvYY22x6dIpU6adbIXcXtG+45v7DFT1HL4vxVJrZGeH63piX7wug3QxA+qLlYM4LwQ6h9Z9o8XpQ2fXdxss6swWVPdyMm7xkAAJh4klisKHeEXzX0TRzWuA/begbjjr3HCefqmlpaZCirlunLXB2WfSZC6T1D8WZEldBvRIaYmyx+p6gmbzrEJ6pIMzFDPmikGcNr4oYcYARBEEQVUpiQ95qYct53ppXfGAL5xJxsVaqat95swTvZ8WhMZXHxkS6jzeLp1xlpJLQKijcjqhDrf5/92tAG+Ik0+/XFjxWPPLNetZJS9t0t7rMwtEXRb4YVqanfDGEBFWcIwoPAfWd4AvedAfw0uMwsmQIA2DY6AyCLyhtlLwbWdeGzI14AAPxy9ynozgFmET4kNHhMBnJUoCHiwLowUwsiw/M7SOYA4xEjzWwLMoCZA0wVaUYOMIIgCCJelBNJUfSdkRVmtP3WWN8Z/v4prk4VH0v07urGgwfeAwC48j264oyq+TUbK4jjCw0ZybbKaFJjpjYv3USaiQiDQP+uAvTnrBpDWxkuenTKWTBxMwCgs7fe4xq80gZEWtWHAOo3QySSflWcuemmm5DL5Wh5PmGGYW6utoKua3ApExtpKFxgJg0u1SmFz95zJFqzWexoTGHe6Ruk53x8+CKMaOjAuu7h+GtmNneEuchlQkMjNjKSba8BQD8RGgThmrjizEIVGTJMc5N5MuIOr9xkQB1pJuIn0kzl4uW/wxPqAFNADjCCqC1IMxFREG7iAOCdOCAey9+H73onP3Y597CVGNDYjdI9Xrzn84+mPelYNFABL1ObLoKI0V65q1pMbSZQgYZwSVKSBryIzNDGcNqjU6RSJx02eH9x+/q7LyhsVabAlOALNLK0ATZnNqT8ZWnIC/qUNkAkjH5VnCEInkhcYCrEm0Ea8oiaCkQXmExgVBZp9u9rxbR9zQCAGWesKOwtCY3RDQdww7A3AAD/seMU9KIedkIDUE54mgiNbZJzJLgQGuLkMQkNotqpdpHBUyYyIslNFokq0ky3mkYszEgizUSo3wxBEAQRESYTTdEnDqhMDTaJA3lefWMINmRHYUBdD86bvs7oNXl04wMPUxtPxvDjvAwrAXFuaqPUAaIKiCrOLDGGtgy3LyOeZPid5SPS7KfnP1U8cucy2S9d7C/D0BXYPRDvIZQ2QCQIKs4QRIHI+s4EjjZrhVpoqF1gnSsmAgD2HlZpLf/CiNfQWteLhYfG4YkD4qRpwAxlRkbzFjxsINFueL4E3eoZE0hoENVCoMKMiohFBl9oldJu9DYlMoptKaqVgD2ILtIM0DvAJJADjCAIgkgSrvvOWCcOAPrEAe9IM0b3lk482DkfAHDVsabRZjITB+80F3HQr1MGZ2wxMbWJmsnZ6hlKHSASiPM4M0Nq09AGBIk0e39ZTy/d/xevSDOLtAFAb2iTwAxtlDZAhA0VZwjCAKd9Z3icNLg04/m/zUFDLod3m4Hp87cW909r2o3Lh+ZX0/x4xymoXLoP7nmADGVGBEJDxM/qmajizahAQ4SJ9t9gTCLDiFBzk70wXSUTJNJMto89ipNKGsgBRhAEQURA5H1neAIlDnhFm8ljdO5ZnR/nXDRhBRrr+uAdbcYf8yKafp0u8L16RgWlDhAxEUrSQC0Z2jwxjTQTvwcrdVK6pbO4/YUHzyxs6SLNeGSGNqBi3kyVNgCo5+C4e1BgQxulDRCWUHGGIHRY9p2xanApI8QGlzu2DMH0A/lma8e8d2lhby++OupV1KdyeGT/VLzZOaq4v/QYUoayCe0+Xxc1GqFBBRoiLEJxf0UgMuLPTWbfVTKRIYs04/G7gobHtqmlBwYrZnjIAUYQBEEkhfATBwC7aLP8uc8uGoHtuTSGNHTirCntFp/Fxgkyk5vDfp3MqOLRr1M6toLd6hkytRHVTlRJAzKqxtCWEXcE6dGpjzT7v2c/Xzzyy9ePlryvGGnGnqsizTT6STu/VqDN43gAKG2AMIWKM0S/xrULjGFUaZe5wIzx1+Ayu3o8AKBzYv6OfcKAzTh90Eb05FL4yY4TNJ8XIEM5I9muRaERwAlGELb0e5EhYpSbbIrrSDNVU0sIxw0dYDK3MH8/8eg34xtygBEEQRAFIus7w2OVOCBO1un7c4p0bOzAw13vAeA32gwoTxyQfop8t0m/ThntmmMwcN1HBZnaiAgJ/O/FQuPXrqFNxLRHp5zPHrukuN2X001Jy/rOiBHQhujSBjhk9yJmaKO0ASJMqDhDEFHRxm07c4HJCjJqXrpvLlK5HNa2AlNmbsfXRr0OALg9cyTW96S5M2UOCPbcMkNZRy0JDQ0kNAiXBC7M1ILI8JWbLOmPVYYqM151LjTHRWTiwqcDDPDlAKN+MwRBEESouOg74yxxAFAnDrBjmv4zWeCe9XMAAJdMXoG6VBbm0WZeqQMh9evk0fWgKFDNpjbSTYRLwoozi9XQpsKXoU3VoxOwjTRrbSid852nTypsmUSayUxsPIKhTYau34zE0CZNxdFBaQNEAKg4QxAcukmnSPvOpKFxgancA94NLjevHYYph/J/9h875mXMbtmFA32N+MWuY6AWGrbRZhyyaLOM4VvVmNCgAg3hgtDcX9UsMhgZ0w9RRZoxVEUZ15Fm/D5GNA4wBjnACIIgCFPC6jsTT+KA7Hn5StbHF41BJjcQo5oOYMGETQafIzrJY+jX2V65S9WvMzJCiDcjCBNCSRoIQKSGNlWkGY9vQxvgN9Lsn09/uXjkR88fJ3lfcS5LZkgWtZNEP6XhnTaggPrNEHFAxRmC8CLMvjOym0NoDS7zzxvbD0djTw5nrdoFAPjV7rnY02cyIegoQ5mRUKFhPPlMTjAiodSEyJARem4yUCkywog048/hj0fnAHPVb4YgCIIgQqGN27ZOHBgi2WefOHBgbQce6y1Em803iTYD7MYMDvt1+jS18YRmalNBpjYiREJLGvBpaPPSTNVlaPMfafatU14rbnf1qUzFsu9o8ftbZlqWkIbe0Nbm/RZ+obQBwgYqzhD9nlj7zoikbV9g0+Ayz6v3zcVFr+WQPgBsy7Xi1j1HKs50lKEcg9AIsnpGhpXQ0GA6aU5ig5CRdJEhYi0yeNp9vs5XbjJQ/j3nJTKCRpqJz/liOlBNDjBank8QBNF/SUzfmTSExAH2aJM4UDlJmOvN4W+bZgMALp+xHEAObqPNFGQk2zamNgkqUxuvmfwQpqmNCjSEX6JMGgiD6jG02UWaNdb1Fbf//aX54lHhOa+XxO9pi6QBhoFeon4zRJxQcYYgQqbsS76NO2DiAvPd4FIdbXZg1QBc8VIWAPDcvEHoyjUg1AxlGZnCY0hCIyhxx5sRhEiUhRm/BF41o3OAeYmMDPcafluKV6SZCteRZqIIIQcYQRAEUeVY9p1JTuKAeCyvpR5aOB4Hc80Y35LB/HEmtnQ/0WaKyVBbUxuj3efrkCxTmylUoCF4TP89uNLksaya4Wn3+TpPQ5sKXY9OfaTZV09eWDzy7adOLmzx81a2kWZA+TxZ4SGtuXwG9ZshEgYVZwhCIIgLzDgqRmxwKYqNwA0uxf2lfZ8bsRgt3cCascCbp2YMLlbEp9Dw81E87Xanu149o4ScYETIRP1vofZEht/c5LAjzfhHcoARBEEQySesCaakJg7sWdmFp3uPBgBcdaxO58mKMiZo+nUyMh5vwQwsHokDUa6eidrURrqJsCFMQ5ttf04RZ4Y2EWtDW5AenXJuPPv54vahHtV3sEmkGdvW6CfZfBr1myESDBVnCMIEQxcYw9MFxhO4waVqIrDyxja56QCuTucn5G49qw6rBmUxbtIezfv7naTUCA02gZrxeIuAQiMogXOUHUBCg3Di/nIcZ2ZLaCKDkTG9ElkUo21usp9IM/F5bTjACIIgCKKIowmpaBIHvCPNGLmeLP62OW9IueKIFTCPNmPHVeMFh/06ZbRrjnlQjaY2gHQTEX3SgMncQGyGNt33hHJVnm2PTv45/11X+b1Xl8oWt3+zaI7H++sizQArU5th2oBRrKYFlDZA2ELFGYJAxH1n2rht6waXKswaXH511CtoSOXw5P4J6Bxdh2wqhZMue6tw1EW0GQmNMkhoEAGIujBjSmJEhu/cZN4BJu5XESTSrDYdYLQ8nyAIgnDdd0aK88QBtk98FA1v+ef3LpyIrlwDpg/YjlkjdUKFoZq45J3oItH064x99YwOioUmApDEpAFbAhvavMhI9gXq0akrQJfPKX322DeLR77+2GmFLV3agGylYwMqv8sFQ5sMA70kg9IGiCih4gxBRIC7Bpe8C8yuweXxrZtx1qAN6M2l8OMdx6F1/dj84emmE3cmGcoiyRIatpPGVkKDCjSEQ6LOSwaqVGQwfOcmA+aRZqYRJYB9U0tygBEEQRA1guu+MzzOEgcA02iz7St68XzfbADA1ceJ1yqurhUnLr3QRK9mJNs2/TrbDT5eQSymNgeQZuqfkKGtgFePTp6M1wf46dGpLtb894VPlT66s0VyhtibE4XnKhOyQj+lIe9Fxt87KG2ASCBUnCEICaH1nVE1uJThsMFlCjl8ffSrAIDbM7PQ3jMUbzw0FwCwemAfRo3fq/kMWeSPCV4ToohcaIiYCI2w482oQEOERn8QGRnuXH67AlVusk2kmXieDK+mlgxVpBlADjCCIAiiP+I2cYCZ2mTI+nPqo83u25LXTe+fxVb9iGMBnXtcHG8Y9utk1OLqGTK1EY4ItTCjoDYNbeH16EyhNAd221szNWcC8iKNGGlmUFhPexxvq9xFaQNEnFBxhiBMcdF3RkUgFxjg5QJ735CVmN2yAwf6GvHzXe8BAKx8cwwmd+aQTaWw4HI/0WaqBnA+os1MiGj1TNTxZjaQ2OgfxOH+ql2RYYIY1ygWZFxGmqkcYCxyhRxgBEEQRLIJKw6a4SZxgMEnDvAHVdpJHrPz14WT0Zurw1GD38W04bs1F6absGSJAzI0/ToZmcKj6IJP0OqZpMSbkWbqH4SeNGDYn1NG/zO08ZTPJX3iPW8Vj/zdQ+yXqos0A3cOf9wiaQDQpw1wmNxzmKHNBEobIPxAxRmCKJDovjMBGlw2p3rxpZGvAAB+tXse9vSVXjdoQz7arG7GBt2PwOGVoSwiZJbKos0yio+KafWMjKQ4wQASG7VOkt1foYoMGaLIYMhEhhLT3GT2nKGKNDMp1NhEmgHGQiPtcbytchc5wAiCIIiwCS1xgMdJ4gAgTxxQPQfY/frd5Tm8UIg2u/a4lYVjMfXrZOja31Szqc0DKtAQgMP/vw6SBqrS0FaGqaEN8BNp9quLHy9u7+qQ6R6xCGOyeoafFys8pCVvLRbyPQxtDDFtgEFpA0RYUHGGICIikAtMiXeDyw8PW4JxjQewuWcQbt1zTOFY/ub25sNzAABrBvVixLj9ms/xm6GsICPZJ0abyUiY0KBGl4RL4nB/mWIrvlUiowxeZLQbvKnuuyHj9WJdbjJgHmmmQx2JonboxucAIwiCIIiwCZQ4EErfGaB84s+s70yuO4t7th4NALhq9grbD4Szfp0Zj4+RmdockgRTmw1UoKlNbP6/9itDm4jO0Cbt0akr0gSPNPvL214iVizKqNIGNAxE5T3BIg6a0gaIuKDiDEEo0LnAmNBw2ndGJjSsXGCVk39D6zrwieH5XjP/tfMEdOXKJw+XLToMbZ059KVSWHDFksJelQtMhM9QFoWGJkNZJjS8cCA0dDnKfoWGEnKCEZY4K8xEFGfmV2SUOcBksAKs1998pvAYKDcZCC/SzKSpZQgOMA1B+s3Q8nyCIAjCCp8TVuEmDqic2KrHcu58fSp6c3WYGyjaTIdi3JKR7FP165TRzm33Q1Mb6abaIpLCTLUa2ky+DzJeVxS0R2f5HNIN894uHvmCdaQZ/6iIf1ahM7S1VZ5OaQNE3FBxhiBssBQa1pV3hlGDS9nNqXw56A3DX8Pg+m680zUC9++bIX23IZtG51850yvazLHQ4MkUHkMUGiKh5yj7dIJRgab/EfqyfAV+M5O9CFVkiDnrgEVuMoMvKjPE1YEQjsn2m+CgqaWNA4z6zRAEQRARkMi+M7L7ZQXiBJ/ozJY5tFm0WTaEaDODfp0MU1MbwzbyyJBITW1UoCE4nBVmdJChTYFXRKNcJ/3+0keL29sPypzHYjFctrJRPEdiaJNhsGJGd68RDW0mkKGN8AsVZwgiAqxcYMYNLgF1g8tGpOsP4vphCwEA/71zAXKKm9zSR/MiY/XgHqTHHCocizBDOUKhoVs9Eym0VJ8okDT3VyJEhilakSFD9p2kKjb7jTQzcYAxHEaatVWeTg4wgiAIIipC6ztjkzhghSpuVI2/aDOVOcRHv05GpvCoMrV59euU9a6A3eqZyExtHlCBpv/gtDDjIGlAhpehTfw7SqahTUTVo9PEsFv6/r5nhU4vyr6DVQVzhX5KQ546w98rLPvNiFC/GSJMqDhDEByxuMCcNrgsCY2PD38ZA+u68XbnGDxxQF2UWPLqeEzqyqIvlcKpV7xp+mEcATOUGZnCowOhkfjVMwA5wYiqcH+FIjJ42iX72N+6SmRkvN/WTW6yq0iz5DjAGDYOMIIgCIIIQqC+MyYYJw7I+nTaRZvdtbAUbTZ1mEm0GWA+nrDs12mCh/HF0zgTAi7jzQDSTf2BSAozlvgxtJkSqqGtDFUxRtajU5UyoI40++BRy4tHPvfA2YUtlaGN10uy72FDU1va7DQeZmijtAEiTqg4QxAaIuk7wxPYBQYAjRhRvx8fSL8MAPjZzlOhd4SlMKwQbdY8a73He/uJNtMslc0YvoVIhKtnAucokxOMkJBEkWGCE5Ehc2z6/ZvOeJ3gJzeZx2WkWfwOMHZ/YlC/GYIgCCI0oug7Y5U4AMjvu+bRZhuXZfFi35EAgGuPYzdIm8QBcRzis19nBnKqYPWMEp+mNoB0Uy3j9P+Xj6QBv4Y2r6QBJ6tm/BjaKuoxppFmuudynfSHKx4ubm85MEhyhizSjMF/94abNhAYShsgHEDFGYKwxWffmSgbXH5i+DNorevBmx2H49mD/MBA7lRY8hiLNuvGsDEHCntthYZIMoWGiJfQkGElNHSQ0OiXRFaYSaLI4GmX7BNFhkim8Bg4NxnQ57/bEHJTy7Tl5YAcYARBEET41FLigE202d1b5wEArp6zXH9yEdm4giUOyLAwtYmJAzICrJ5JlKmNdFO/w/b/UxRJAzJs+3PqcL5qhpHRHeTTT0RDGz8HZNOjs3R/eGDlZM1ni9+9jZAXyoGSfhLSBtKSt3Xcb4YZ2nSQoY0IAhVnCCJifDW4BIwbXI5uyODa9IsAgJ/tPBMAG9TIREdelCx5eSLGd+XQm0rh9CuWeF8fALmDwitDWUJGsq8KhEYU8WYACY1aI4mFGRmhiQzbVTOh5SYnKdKMoSnWkAOMIAiCSDiR9J3hcZQ4UCKMaDNxYlOHR79Ohmm/Ti9Tm4LYTW0+UwdsId2UbJwXZmKMM+uPhrarZ68sbn+2GGnG8BNp5lFAH4jKe4BF2gAztGnvRaB+M0R4UHGGIASCuMBEoWGMV4NLYxcY8MnhT6C5rhcLD7XhxUNHGL8u/W5+9q/JONoMSKTQiGj1jIokFGhIbCSTSHrMANaiNnEiwwutyBBR5SYD4UaaAeaRZswBViAiBxhBEARBRIVt35nwEwd4Q4TMSOEq2kyEN4mE3K+TR2aEUfTrFInc1KbDoWYCqECTVCItzITcn1NE9/cUn6FNhX9D2+1XPlDc3riP6Ry+4GISaQaEHWkmvdeYQGkDhGOoOEMQHtj0nRFhlXdptIymP0ARywaXQ+sO4IqhLwAA/mfX+civmjFzgb316FwAwMrBPRg5lrkoTKLN2PGECA2OIELDeY4yEFmBBiCxkTSc///w4f4KS2To8CUyTHKT+e0yTHOTdQ4wm2KMrqmluG3Z1DJkBxj1myEIgiBCJ+AElq/EAdn9U4qq94yeXHcW92w9GgBw1ewVHmfH3K8zIatnVMRtagNIMyWNpBZmTPEytPFEamgrQxY9765HZ4qLNHtoVZvmAnWRZkDlvJUk0kxGQEObFZQ2QDiCijME4YeENri8Ov08Wut6sLxzPF4+NB02LrC3Xz4cEzvz0WanXfmm+ocow6Y3Q3UJDRmhOsEMILFRndS6yIht1YwSk9xkwN+EiQ6ZA0wWaUYOMIIgCKI2CLvvjBSvxAErZHGk/KOcvy6aht5cHY4esglThu0Rjpr065QRoF9nzKtnyNRGuCAJmklFIlfNmBjaGBWmWBeGNjk3zHu7uP2J+84VjuoizVT7xNWOHGl4p8x49SorwAxtfvrNEERQqDhDEDHgq1LvcdNpRA8+OCyvhG7acw70vWbkDN44FgBQN3ODx5myDOXqFxpJd4IBJDaqCT8Rc/1KZMjwLTJsc5NVEybitm3RRvZ9a9nUUgY5wAiCIIgqwEXfmWgSB3QrXvn7tnyF7MZlWbyYLUSbHct6K9iMGWSJAyJkatNCBZqawrlm8knNrZrJ6E5SpZ2YGtrk33m/v/TR4vbm/YMLW7pIM9X+UkKMJ2nhuUo7adIGvKC0ASJMqDhDEBIi6zvD3xwCusDOGvQaRjfswfbeNB7eNx/lgsIs2mzRA4Vos0F9GDspU9hrmqHMjlWv0DAhbicYQGKjGvDzOw+rkWWQOLNQRQYrrJqKDEZGd9A0N5khPncVaRagqWUazhxgDHKAEQRBELFRWJlp20g53MQBoHLiz8TQlr+f57qyuGfrPADA1XNcR5tZ9uvMCPt1pjYZ1WpqM8CPZiLdFD2hFGYcJg2Y6Cavf//JMrTJcNOjsy6VLR654+0ZitcAlSsWZWkDIpyhTQalDRBVDBVnCMKA0PrOyPDZ4PLyoU8DAO7ZewZ6lO4v/pHfzp+/avFYTOkAsqkUFlzuFW3mN0PZg/4gNLzwOQmvg4RGdMRSmIkgzkwkcpHBR5ppRYZJbrJq9YxfQmhqmRae+3CAqfrNMMgBRhAEQSSdMBIHypFNDLL9av76+tRitNnUYbsLe036dbLEAUf9OhmyRuAifOJAu/fpUZrakpI6AJBuiopQUgaARCQNJNvQtg/e8zL+enR+7rjSHNLnHxR/2SbxZeJ5LNJM0FCtkBuYKW2AqFKoOEMQfgmr7wyPYYPL0Q07sGDgIgDA3XtPR/5upZoc1NPafhgAIDtddtMyFRo8/ESpuL9ys4L+KDQA50v1AXKDRUEohRkvQnB/yQhFZNiSUR2wiTTz2u8q0syyqaUMcoARBEEQCad6EgfEaDMRVcKAOtrs+b7ZAIAPnrgSZjju1+nVi48ZXrwmdSM2tSU9dQCgAk3YhKaZfBZmourPKZJMQ5v/Hp0/u+Cp0mUeGlDY0he6y89hj4aRZrL5Mf6ewN8rNPGY1G+GiBsqzhBETEgr9j6jzc4d9BzqU1ks6piJ9T2HcUd0DS7lN8lX78tHm60ekMWEGbsKe71uxKqcUlnVpTaERrUWaAASG2GRNJEhw1QMRyYy2N+w14o4r+8EAN65yby44PczXEeaqfYpmloyB5jo9iUHGEEQBFFFJDdxQER2P9YlDjAK0WbdWfx18zEAgGtnL/O4UNkKXt25lqa2TOFR7NcpIwGmNhVhpg5QgSY5xKKZNLg2tOmoLkObXY/Ohrq+4pHfvjFb8b6iLjKNNPNA23NMDfWbIZICFWcIQoFLF5g2WsZBg8tTBy4EADy5/xToJwa9o83WLRuF6YdSyKVSOOFS02gzwFtoAHELDX7S2FZoBIk30xJxljKDxIY7/K5ISor7KxKRwaMTGb5yk2UiI2husg0qh21ITS3JAUYQBEFUGwlKHKjEf+LAna9ORleuAbMGbsHsUdsLe02izdjxCPp1kqmtCCUPxE9shZmQkgaq29AmQ2VoE8+p5BsLXi9uf+WRM4SjfiLNgJJ+4tIGTCLNKG2AqDKoOEMQhgRxgTFYZd5lg8umVBeOH5C/ET578NjCCbJJQHOh0bjmcABA9+TN3N6gQkO8qUcvNHR4CQ0ZToQG4GSpPrnB4sHv7zDJ7q9QREa75GRbkZHRnRQ0N1nERaQZ/0gOMIIgCILwi8vEATWNUCcNyO/jW1f04enefOrA9SfbRJuZjDM0E6j8bnGclIDVM7HEmwGhFWgA0k1Bic3MBgQ2Q6qwNW86XTXjxNDm1aNT9qjmB2e9UNze19Vc2PITaQbk57I85q+8eouNq9xFaQNEUqHiDEEEIWDl3EWDy6lNa9BS14U9vUOwunuScLJs2ajssZwX756HulwOawfkMHmOSdMXwD5D2YOQhIZu9Yxu0hkI0QnmheEkPbnBoiPI7y0OkeE3zsyL+FfNiM/95iaHEWnm1dSSHGAEQRBEbZD4xIFWfoPvO2MbbSaYLvpy+MuGeQCAq2e+DUB3fbIxibhPZ2qT6KeM5uNkRLR6RoazeDMq0FQdoWomL3wkDbjqzxnbqpmM7iRXhrZyw25zfel77CcvH6N4vWmkmWHSAKBPG+ChtAEi4VBxhiCSQIAGl5Ob1gEA1nZPATAAcpeBbKm+WmhsWjsMMw7mvx6Ovdg72qwp1Ys5LVtwxdDX8OWRD+E/xt2C30/4H9w56Ue4p+37+OPEf8V/Hfb/8Onhf8bs5reE10uizTIeHyniaPWMiInQcIYDoQGQ2IiCIL+rsAszruPMQl81Y0pGdSBobrJ4ju64DJUbTOUA84AcYARBEEQNkLjEAU/8R5vd/cokHMo1Y3LrThx7OHOhqBIHVOhW0zgwtckmdQ16WdiY2iKLNwNiL9CQbjIjdDMbEEnSgIzqWzUjPndraPvuGS8Vt//hiVOEo34jzQCpoU1ELMgbGtoobYBIElScIQgNflxgTGioXGDGDS5FFC6wMQ35Wc7NveKsnRht1ggboZFaNREAcHASPzLI34yH1B3AWYPW4ZujnsZdk27D69P/A3dMugX/OvYBfHLE8zh/yFKcMGANjmzZhBnNm3FM62qcO3ghvjjqDvyl7Rv448SvYHrTMnhOrAYVGu3yt61loQGQ2AiLoL+fuAozfuPMvHCyakb8WxZFBp+brBQZKpcpIO+FJXN9BYEcYARBEAThSQISB7yxTxzYvaYbj3XnneLXn+QVbSZOeOrGIWzylMejX6cJMoNMO7ft09QWBOepA4YEWZlBmklP6JoJiCzOLBGrZkzJqA6EYWgr8c1TXitud/TaxDrrIs0U+imNyu92lV7iCGxoo7QBIkSoOEMQFpi4wEwJ5ALjGlx255oAAA3Km6dXtJmc5/96NBpyOWxoAebN34DTB27EN0a9ijsn3YUXp92C/z78IXx4+BIc0bITDakcdvUOwAsH2/DHPcfih9svxFc2fwCf2vhJfHzj5/F/3v00frj9Ojyy/wR0ZxtwTOty/HnSl3HKgFcKnxaS0ODRDHyizlFOcoEGILEhI6jASFJhxpTqXDWjy01WiY2oIs0AZaSZCDnACIIgCEJOgMSBPLJoMx6Z0UITbZYF7liX7ztz1fSlSBWjzYL26xTx6NepijSqgtUzvkiIZiLdVE4kZjYgUYY2L83E/93o4gKLBDW0lSEztMkwKRbLXgMMaCwd+/6zJyhebxJpBjgztPH3BpM4TBsobYAIASrOEERQwnCBWQiN3X3DAACTm9YX9qhuaKpoM3GysQEDUj2Ytmc/vvBIDv/3pl784eAT+MX4x/GR4W/jyJZdqEsBa7vS+HNmNr6y+XycueajOHXNZ/DJTdfgB9vPxy17jsfD+2fj+UPT8dKhqXjiwBzcsuc0fHnzZ3HO2v/GcwfnobWuC/922L9iXINiJjBTeFQNNByunhFxmaNczQUaEhsJEhgOqX6RwRNObrIZtpFmmpWLaZADjCAIgqhqYk0cGCM8F00OaXjM+YmJA4BN4sB9Lx2OfbkBGNe0F6dMNr03m67eFSdUDU1turahAVbPuDC11UrqAIN0U56gv4OkFWZk2CYNaAnL0FbxvSAWdmWRZgz+uVmk2b+d+2xx+/vPnOhxkTpDm0xbcYY2GbJ0GQPYvYXSBoikQMUZgkgKPhtcvpg9Gb25ehzRshLHtrJCB78UVIw2q2RQXRdOGrARXxr5Gm6b+De8PP2P+N/xj+HkN7KYtgWoywHruwfjjswMfG3z6Tht9Qfxvvar8L1tp+Ph/VOwrXcwADa40AuNnX3D8IVN38KSjpkYWr8fHx/+p8IRD6GRKTzaCg2eGISGb0hsxIqLn9tpYSZBcWZaXIsMhtPcZB6bYoyMCCPNyAFGEARBVBGRJg7wyBIHjPFKHJAbM/Zv6MaDnfMBANefaBptBpTGLDL9JItsNVw9I+Jw9YyIH1ObirgLNC50U38kUbrJYWHGRZxZ9RjagMrvJt05peOfO67Uo7gnW1/YUqUN8Ki+V2XzWIWnacnb8N/5hmkDplDaABEVVJwhCA8id4G1cdsGDS73Zofh3n2XAgB+evjXcMmQR9CY6hZOLN3wmlPA3JYt+GB6IW4c+zDua7sVr07/LX474T58asRizGvdjoZUDhu7B+G+rsn4xUUpfObz9fjWhOPw3W0L8OD+KdjZN0ByMSrHhUgHegD8584PAAAuHfoIGlMypYBwVs9wk8eJFhomRFSgAfpXkcbFz5nEwowpsYgMEVU0R5F9MM9AFGNE/BZkgkaaARWRZjLIAUYQBEHUKjEnDpQQo81EVIkD4r7SeXesyk/kXj75bTTU9RX2mkabsX18tJmIYtzj0tTWzm0nzNQWRYEGIGObDa6KMknrMQNE0J8zcYY20cgmJgzI9dPIAYeK219+RPVHqip6A5WRZgamNq+Cu9iKGZQ2QCQfkw5NBEFw3LdKPYBIPWM34T4BG+2bajMGojiB+YNt/4jpTStxVOtS/HDcjfjOmGYs65yM3X2D0Z2rQ2vdIQytO4jxjTsxplHuuHq3ZzBeO3QYXj00Fq8eGofNvfkb4zUT/4zdQ3pw9NlvYfGzkwpn9yL/9dGD/A2VPWePDLFA0wF2w33l0Bxk+gYjXb8f05rasbzLIMMng7zI2gH1BOY2VEYbGLIRE7RuitWYWpz4BPIDNDYByliCo4qToTyLh8+STo7mTtc4Ms6C90DgPGgLhQz2bzaow4MNwGvRUe9KSMVdmFGRaJEhOsB4jESGiM6JClQKDvYaU/xEmknEBjnACIIgCMKe8TnvRtqjUFmcaIWixsFPEO4X9pmNDx5+eTR2HjUEIxv34exp6/HIyikGr2JjEJ27nJ3DjyM6UDGuyEBTjOLYicoi1haUJjTbIR1vbF87QTnBuQbTysyHomYCzHWTSjN54qWbDDUTkB/Lk2ZSE7lmAgJHQCeqPydPO7cdmqHNFHtD288vfKK4/dNXjrH4rEaoCzYGpIXnqqkkTdqAytCmhdIGiJCglTME4YKYXWAduQH40IY/4j93fBnbe0eita4L8weswLmDX8NFQ17BWYPewvwBa4uFmV29A/DMgWn4n50L8JlNl+OU1Z/AuWuvxz9sPRv37JuBzb2Di++9/838IGPrYbuQQtbwJ1K7K0qksKIzX+yZ1tzOfhLuhyo8Zgw/MoLVMzJid4IBkbrBGLXiCmM/h6vVMlEVZnQEiTNzIjJ4dCJDhS+R4S432RyHTS3JAUYQBEHUAIlNHBgIs8IFAL/RZh1benD/oWMBAB8yijYT41d5BztbPSOrJHlMtorjKFagko2/HEZCJyJ1IEEraBi1opkAtz+L88JMyHFmXv05RYwNbTKcG9rYd4nOtObP0HbV7FIVM5tj08smaQOq1TN8pBmXNiCTVKJZ18PQxvDqZcbuSbbxmwQRBCrOEEQQDN03QKXQkOKn70waQCvQgyb8evencMaaR3BZ+2/x5Xf/Ad/b9kncuP3D+OetH8VXNn8C16z/Ihas/jZOXfMP+Oy71+J/dp2CZw/OwG5pTFmep/5yFFqzWWxrrMOCi1wKjR7szQ4CAAyo60CkQkNDHEIDqM4CDVC9gsP1dUcpMIBo3F8izlfNqHKTeZTJZaoJCx6b3GQv/ESaWTa1TKP8+z1mBxhBEARBBCX0vjMmpP29zE+02e0r5gAALpn4NprrRfe5aqWu6VhEHPt49OvUEYKpzQQbU1vcsdCAmz40DJeGsKipZt0UVdKAsaFNpZn8GNoYGdUBE0Obqkent6Ft0tC9xe3r7z5fcQ26SDP+uUGkWRqVhjaD8BVmaPN9D5EY2ihtgHANFWcIwgA/yxRFF5gIq9h7Nrj0coFVkMLKrql45MCp+HPmAty65yL8de/peHj/fLzVOQ17+tIAUlDfGMuFRseBZkzdnf+gSacu4873Ehre9OYaCp94iNvbP4WGJyEUaMIo0iRZcIR1jUkvzEQqMniciwxZEddvbrJIWJFmgHVTSx7HDjAbaHk+QRAEEQpRJw4oEfvOyO7XYvSOnidfHo7NueEYUt+JC2et9ThbttJXNkaROd4VprZM4bHKTG1JTh0A3GomoDrMbWHoJmv9GWJhJqr+nNbYGNr4vrxaQ5sXfDHGXBP9/tKSw+uPS2x+n40o/251GGkmSRsgiGqAijME4QOXLjBGeA0uZb0ORPeXXnRseyV/s20ftR+NzV7uLtFxoTq/A0Pr88JiX2EFTTULjUicYFUgNoBkFWrCvJaoBQbgxv0lw5nIaJfs8ysylJhWbXW5yVFFmvloahmiA4z6zRAEQRDVgO/EAbHvpCJxwBt/0WbdO3tw7/7jAAAfOF6WOCCa2mTjETFxQKS6Vs+IY0ybfoehpw7ElDzASJJmAsLXTVZEHAENuOnPKRLPqhlVj06doU2GPtLszMmbAABdvfXIgf1spmkDUJznM9JMBvWbIaoIKs4QhCtc5ff7iTYDDIRGK9TZnnqeuXsmhvdmsbe+DmdeuVRzpkpoyKLNgJH1+aWwe/qGImlCw2syOhYnGFA1BRoGP8iPSnhE8XmhCAwPXLm/qkZkKHOT2XPVChj/uclybCPN+NcYkhaeh+0Ao34zBEEQREi46DsjYpw4wGOUOOCFfbTZn98+EgBw4eHLMLCxu7DXa6xh0q+TTa4yqsPUJiMxqQNArMkDPHFrprA+09fvLMFJA4lcNSNlH+wMbXaRZkeP2V7cvuKOixXva1PkboV2fiot2cd/x3ukDbB7B/WbIZIKFWcIIigGef3JaXAJlPdCUD2W3xizvQ0Yvy3/IcPn8/ZqMdpMRO3ESCGLiU1bAQAbu4d5XnUUQkPnBPMrNFT4doIBoRVowizSMEQREEQIyN4rbDETmsAAnC/Lr02RISI6wHjsRIYZJpFmjppayqB+MwRBEESVUTuJA2a88MoQrMuOwYC6Hlx+tH4i0K5fJ49i0rVGTW2hpg4AiTK2MVzrnKg1E+BTM1VR0oCI+HeQbEMbUGloM+/RecdV9xe3H1w1WX2NFYhF7gYYLmf0nveSGNqk9wobqN8MERFUnCEIQ4L0nfHCfYNLmdAQMYk2K01ErnoqP2hZle7EkOGHJOfyqDKUS0JjQuNGDKjrQme2ERt6xhbOlfWUQGRCQySueDMgngINEF2RhkclPrz+ixpfvxcHhRkdNoUZGYkTGQylyJBNUIj4y002x1GkWRr6SLMYHWC0PJ8gCIIIlTgTB9LwmThgFm3Wu7cXf80cDwD48HG8FgzerzOPQaEmU3gUjS9VamrTEWeBJmrNBFSHborLzAZEmzRg8u/cmMgNbWJRmH8UtxnsNTnMGJG/mLV7hgLWkWaq+SeJoU0G/52uioM2uTcQRIKg4gxB+CR2F5h1tBm72fmLNnvt8TYc1p1FV10KZ1z9JndEJTT0guPIlg0AgNXdE5BFPapBaMgIywkGOCzQ+CzSEHl8i68IRIYNQUWG7d+GL5GRgQJRZLjPTZaTvEgzcoARBEEQtUzoiQNW8EUZ9px/5LdL5/3hjdkAgLNGvYPRAw94fIZJv0425hGPKaLNRDKaY2RqK1ElxrYkk1TNFHXSQPUa2rzTBs6b2l7cvvKO9yne21GkWSu8U2I8DG0Mdg9RpQ1ojdXUb4YIGSrOEIRLwnSBiQ0uRdI2H2ByoxRvkHUYsTF/EY2z11l8llxoHNeanw1cdGhqxbFqEhoyXDnBjCCxERqBijI1KDLKiExkGB0QEIs1UUWaAU4cYCrIAUYQBEFUCUH6zngRTeJAsGiztxY2YVHfVNSncvjg8e9ozjTt1ykiTrxK+nVmPC6STG1yfPaJ7M+aCUi2mS3spAGRWje03Xvd34rbb2z1mqTiCRhpxqNaMSPB9p5B/WaIOKDiDEG4IIq+MzxWLjBVvI4sjkcfbfbG/fMAAKsG9WLs5IzuQyEXE6XBwHEDVgIAXus4QjinNoRGZE4wIBKx0Z8ER6Cf1+Z3TCIjj1ZkiFGHzDUqFmr4+EQvTM5RoVohwztrdXGSUDvALCLNGGE4wAiCIAjCNdWXOMBOUBkw9NFm2Y4+3LH1OADA9fPe5o7479epPsfD1Kbq18mTQFNb7AUaMrYZEYluiijKDPCXNFD9hjbA1NCWQg5N9VkAwJPr+J87aKQZ4GloSwvPKW2AqCGoOEMQFtgsW2RCw9YF5r7BJYNvTA1hvxkr3hiLqYeAbCqFBe9/gzvilaFcLiLGNuzBtOatyOZSWNgxgzuHH0hUl9CI1QkGhF6gAWpfcAT++RJamKldkcE7wHhsc5NVmIiMRviONLOOVikRpgOMlucTBEEQkVBViQN8McYs2uxPr01HT64exwzZiCNGeg2IvPt1ymOKkm9q8xtvpiOSAg1AxjYNgYsyIffl9MJv0oAX8RraxOe6IrD4nWPGJ+e/Vdy+4Z7zFWf5iTSTFGRkNRpHaQPM0EYQSYGKMwQRgFAr56E1uATMb5iVhZuWNeMBAF1T3zX4HLnQOH3gUgDA4s42ZPoaUA1CI/FOMCB0Nxij1gSHk58n5sKMCj9xZiKhiAyRjGynjcgA/OYmmxFSpFlaeE4OMIIgCKLWSWzigBhtJmIXbbbx7Rye6c2P+W9YwE8E+uvXWU71mNpkBE0d8CIJBRqg9jQTkCwzGxB90kCyDW37JAdkkWYML0NbpfH2f9/3eHF7077BhS0Tc5rPSLM09JFmHmkDzNAmpg0wqN8MkRSoOEMQrrFwgYk3Bylt3LbTBpeAn2iz5/96NOpzObS35jDjPbKRhbfQOGNQ/ub39IE5ktebLsNFpELDhNidYEBkYgOofsHhrChj4/wKUJjR4dL9FYnIyBQeD2rOASCfbHCRm+wHk0gzwDrSTCyyy9y+mmK9GGlGEARBEEnAz8RV5IkDrRUbwkGTaDM5ud4c/rz+PQCAD8xaghRMfx+6sY3X6hnv3RUkxNSmwo+pzQhbzeSgSFOtusnZ9Se4MOMiaSBUQ5v1qhkd/gxtrQ2l76Rfvj5XckaQSDMP0sLzkAxturQBMrQRYULFGYJwhUGDS9EFxmCV/OgaXPqPNnu3fThmHsiLkXnve5M74uX2yg8C0vUZnDww32/myQPsps6EBi9CxF4T3G4T/AoNbvLZldCwdYJFXqBxVKSpBsHh9FojFBhAcPeXjEhFhopM4bHsb1tcPWfyh2+TmxxWpJnjppZtlbtEB5gI9ZshCIIgkkgYfWeK+E0cMEYXbSaOGUpFm7++OAn7c62Y2LwbC9o2eXyGvl+nfOxiMCmbKTyGbGoLq2cnkIDUAUZAzQRUj2YCHGumBBRmVLiIM6vAtaHNE3HehK2uc2do+8HZLxS3v/boaYr3sIk0A8rnp7i0gRAjzYxwFbtJEIZQcYYgLPHTd8aW8BtcAkGizXrfbgMA7Bq/HUDW43PKxcR5g5ehMZXFss7DsLZ7BBInNARc5CirCG2pPhCL2ACSWahxfk0RCwzAjfsrcGZyUJGhyk1WohMZDD6Xnd/nEpul+jzJiDSjfjMEQRBEYklk4oC7aLNMezce6DwWAPCRBSu4I3b9OiuRRb3Ga2oT8duzM5bUgQiTBxhJ1ExASLrJhhALM0GSBqwMbSZpHFVoaPvyiYuK2wd7mgpb6jmjErIVMzLDsEBass8i0owhRpoxozQztBFEEqDiDEEEROoCi6rBpUxopG0/xD7a7Inbj0FrNottTSkcf95a7hxvofG+wUsAAPfvO0byGdUnNGSE7QQDQizQOBIcQPngPmrhEdrnRiwwgPDcX74zk1WEJjJkyJxf9rnJ5iS7qaUR1G+GIAiCSApRJg54RZsVYaYKcZ9JtJnCyJEDbnvnaADA+9veQlO9SdIAgzeg8M532djIY+yUKTxGbGqLIt4McFCgASJPHuBJimZy+tl+zGwRF2ZMCdSfs53bjtzQxhPM0DZu0IHi9hcfPsPgFbpIMxPjWwFf81t5TO8R1G+GSAJUnCGImHDS4JInrTrgPtps755WTNszAAAw7YylBq/IDwIOb8xg/oCNyOaAB/bNQS0LjbCdYEBIBRrAqdjgEQf/LgSA7D1DETV+RFjIhZnY4szaue1IRIasaMvjLzdZjWmkGf8YTVNLBjnACIIgiGogEYkDPLJoM8NE0sqijMzUVhlt9sgLI7E5Nxzphg5cPJs3tTG8+3WqESdhZb36DAnB1GaCC1MbEEOBBnBepGGEpW8i0UxAKJoJcNubE4ixP2dkhjY+0ozH3tD2u0tLVYr/eXWe4iIdRZrJ4L+7qd8MUYNYlCwJgvDkESgHI/N2L8fi4bMwF29pB5YTsBEbMQGjp2ysnCAdh9KNfSRKN/FRqCxItMJghQn7CujlnvcIj73I3zjLb+rbXp0JXPAm1o/ei8bmHvR0qYo7vcX3ed/gZQCAVw5Nxo4+0ZnGwz6L3aD3odLJZshOVDrmtqE0CboFpZt6O0oToptSVo71NZimL7BpWIKjlJOqi4fP0ro5cqcbiFk24LVZ0cX+HYfcmyJpS/kr8Cu4YirMxJaZ7EdkMJTfU7J4Q3e5yWbohkkOm1oaQg4wgiAIopq5b1WIY7/xucpCwRiUxii8dgLyZglPs4gIGxeYjyu6d/bg7v3H4/NDHsaHT1iOvy6ZUTjCNJIK8TMauH1sjMRPZnZwz7lttplBfvxxEPmfnT3fgcqilUw/qTSTwPa1E8omRzdiQtn4RaaZVmNqRWzdSsysWDUF6HWTDiPNBOTH8LYpGOchkn5+NambHBVmwoqAdt6fk6eKDG3nT1sPAMjmgL4c8/iHFGnWCm+tNEayj/rNEFUMrZwhCB9E4QIr4ucmk7Z9gWypvj7a7Om/Hol0bxZ76+tw9lX8AFmMNmPk8L4h+eLMffvmSI6z1TO6fjMdlZuZwqPf1TOGeK2ekRGFEwwI0Q0GhOYIqwr8CoyEF2YSIzIyFWfCqN9UBf5yk83xE2lm2dQyAf1mCIIgCCIWDCbCkp044BFtBuAPb84BAJw3dgWGtXiNbcSxDEM3holo9YwqdcBy9YxpvFkssdAAaSZb/P7sIRdmgmIdw1e1hjY97xlXuuiL/3SZwSscRprxiAVjRpv6LQKlDURQcCUIBhVnCMIBofadYbRx29YNLlVCg6cRNg7wbG8Dxm8dBgBIz1/pef7clo2Y2rwLndkGPHbgiMJer8FCsoRG2PFmiS3QAP1LcCRYYNgUZmRUh8hQ5SbLRIa/3GQ1IUea6YjYAUbL8wmCIIhYseg7o4KtyPCMNuP1kqzfm1G0mSyyxyza7LXXBmJZdiKa63pxdZlu8u7XqR7fsLFRj7BPsl2lpjY/ONVMVKTRE+RnTZhu8hNnVjuGNn2k2d3X3FvcfnDVZMV7hxRplhaeGxrabNMGbAxtlDZAhAEVZwgiBsQGl6ILzFmDS0+hwRpcmtxIK4XG8kfnAQDeSXdj1Dh+QrVSaFybzv/MD+0/AgezXl89CRIaCXGCAQkQG0BtC44ECAyXBBYZthniTkWGbBWdCvvcZDNCiDQL4ABj9wXqN0MQBEFUE4lOHPDVbFrUTt707e/FHTuPAwB8eP7bBq+Q9YjgzSmyMY3puMkDMrWVQ5qpkqCaqQoKM9ZJA1VvaJPrpBRymDh0PwBgybaRANjPGWGkmaygLuLK0CaBDG1EFFBxhiBcY+ACMyVwg0srTKPNSrz57ERM6gR6Uimcdu0i5XlD6zpxweD8Xe32zDGFvfxAgR88+HGBGOBXaAhE4QRLvNgAakdwnIfECAwgYSKDp53bjkVkiPDfHZA8itt+cRBpJpIWnjuONKN+MwRBEESSSUziQFr14nCizf74+ixkcymcPGwd2tIZyRmimcR0HCPrMyFJHCBTWxm50yNIHgBqRzMBwX8Wi99j1SYNmJB4Q1uJDxxV0hVX/+V9ivdtELZDjjTjv9PbvN9KNLRpoX4zRExQcYYgfBK7C8w62gxwHW0GAAPWHgYA6JmhnkC8bOgKNNf1YVnnKCzpNLE+ALUmNGycYF4kpkADVK/gcHHdDgUGEG1hxpPYRYbsuMwtKotE9JupHFKkWUwOMN19hxxgBEEQRLUQWuKASMjRZmveTOHFviMBANefzE8SehVhRPe7l6nNa0xlQAJNbWEUaICIkgeA4IawuHBx3ZZmtigLM87jzNq57ao1tJWe/+GKh4vb7+warrnICCLNDBHTBkTYPcWPoY0gwoKKMwThCL8usGgbXKr2+482e/4v81Gfy2FtK3Dkce9y78Fu6t24Jp1fvn975iiUlsLWttAI0wkGhFCgcVWkSbLgcHmNCS3MmBK7yJDCiwxZcZZHVqSRTW5USaQZOcAIgiCI/khVJw7YjQNy3Vn8aWM+QeD6OUsAeJkwdP1mVOfzJMPU5ireTEdkBRoguGYCqks3BSUBmkmFn6SBCqrG0GZPc33p9TctPpI7EmKkmQj/Xe04bYBB/WaIJEDFGYKIiUgaXKZh6AJj2EebbVmXxsz9+XPmvu+NiuMnDNiMtqa9ONDXiAf2zUBVCY12bttSaMhw6QQDHC/XB9yIDSBZgsP1tVgWsqIuzDjPTDbBj8go0xI6kaHKTeYxz022I+KmlhrIAUYQBEFUM9WXOBBOtNlfXpiCjlwTprdux3ETtnJHKvt1lqPST8zUpouFjtfUZoKL1IHICzS1ppvOg3sjW4SaSYfLpIHqM7QB5frJvEfnD85+obj9dw+dqXhvx5FmaZgb2hjUb4aoAag4QxBh4NAFViTB0WaH3pwCANhy+C6kUn3ckV5cm87/vPfum4lDOfErJ+FCQySkeDMgvAINEKPYACoH+mELj7A+y8fvJYnuL2vaue2qFRkmBIk0c9TUkhxgBEEQRD8hsYkDIUeb7VjTh0c63wMA+Nipyww+SxzLiIkDsvN5jVS9pjYVYcVCA5aaCXCrmYDa0ExA5JoJiClpgP933q54E14zxWJo478n7Hp0fvWkhcXt/d3NyvPsDG2As0izNvUhdm8Q0wa0BmlKGyBihIozBBGASF1gjDaDc9KqA+FEmz3x53kY1JfFroY6nHp5qfg0uuEgzh7UDgC4Yy+/FDaBQiPCeDOXS/WBEMWGa8HBkBVsTIWB7rVhiRgfAiPMwgyJDFuRUUWRZgxygBEEQRBEkcgSB6ywjzYDgJtXHA0AuGbKkrLYoBKig11cVSOj+kxtQeLNEmNqA5KrmVy83pYYzGxAOEkDnuj+fev+LoCIDW2AH0NbW3pvcfuzD5zNHdEZ2iAck0WaKb4zHUWamfYgo7QBImlQcYYgHOLaBaa9uSQo2qzjYDOm7hycv6yTVxT3XzP0HTSkcnjt0Fis7GIXKZtMVb4zIhMapsTkBItFbADhiQ0VcRRfVIQkMAD3hZnQM5MjFRkmqIq4fpCJjJiaWrZV7hIjzcgBRhAEQdQUiU0cCCfa7MHnRmBzbjiG1R/CZUev5I6otJFoSpH16xSxNLVBeB6BqU1Gkgo0iSrSyPDSTFHqJp8/e5ILM877c8ZmaBPnXswMbX9+/wPF7V8tVP1NyTSRuG0ZaWZLQEMbpQ0QSYGKMwQRA0xoeLnAGH6jbdTRZoDraLN1T+dXxqwZfgiDhnegMdWHq9L5ycM/7pkN9cSpbADhSGhEsHrGBBdCA4i5QBN1kSZOQhQYgNu8ZCDGzOTQRAZbKScr1IgiQ/Vdocpst0EXaQYkraklOcAIgiCIJFN9iQM6/Eebde9J4fY9JwIAPn6iSbQZYD6eEZ30hqa2g+pDSTC1xZE6AJBu8iSAZgqzMKPC9N+R8/6cPOzvSTYv4YkrQ5sXOZwwPt8Ta/vBVmQroul5gkSaSTBNG2jTXJIfJIY2ShsgooSKMwQRFpLJsEgbXKZVb6K6GXpFm/Guh0qh8eJDM3B4dxaddSmcce0bOH/wOoxs6MSWngF44kCb5PNUk6oyfAoNHbZCo53b9iE0klagIbEhIWSBASTY/eVCZARG9rcsOsC8CFKMAcwjzZLb1JIcYARBEEQ1UDuJA/bRZr9/LW9qO2vkShw+mJ9wFSPMxHGNztSmSxvQmNpEMoXHiExtJvFmKsLWTIBPzQTUtm4K8LOFrZkAuwjoSPpz6rRSpvAYyNAmQ5xr4Z/LIs1K575vxtrS9m2Xcee6jjRrLXsokhaeexjaGIHSBggiZqg4QxABcT25xW4iWtok+8Tl+SKBo828qMOI9aPzr569Dh9M513ct2eOQB+yhXMiFBqZwqOr1TMiITvBSGxESEQCIynurwrEwkw7tx3aqhmG+DesExlAeZHGPDdZTbIjzQJBDjCCIAiiBog/ccB9tNnSRa14pXcm6lM5fPTUtw0uLkRTW6bw6HL1TDu37aGZZCTJ1AYEMLYBtaWbAv4scRZmQjG0tauvschOybavVTM8svkRnaFN9t2h/j6577q/Fbdf26yqhphEmsnOU5AWnouFcw2UNkDUAlScIQjH2LrAmNBgFX2G6ALzbHDJY9zgUteYTYw28xYar9w9D6lcDsj0YW7rTnRn6/CXvTrnio3QYOczHK+eMREd7f4+hqFy5lS92KhGweHg2l3GmEXp/jKOM1PhRGSIf78mIkPMXOcfxW3ZcxvEeBLx+zDcSDMGOcAIgiCIfkGiEwdU+I82y3VncevG4wAAH5nzFgCZ2c/L1CaSoNUzIlWeOsDwrZmA6tVMgBPNVHOFGRGVoU1HpvCYIENbc33pu+Wu5fzvSDYH5GVok80xab5TKW2A6MdQcYYgwiTxDS7FfbpoM8BLaKxZNg4zDtbjgoX5lTIP7p+MPX0tsguAXGiozhMnb0NYPcMTsdDwS2LEBlA9hRoH1+hSYAAkMkqYrHiRneMnS1mGSkSI34HhN7UkBxhBEARRi4SVOMBMDFLaJPsSEG325+cn4WCuGVNbd+CUtvXcEZW5RBzviIkDXuczHKyeCcHUFnfqQCTGNqD6NJMD3WRKVIUZGab//oq0G5xThYa2fzrtleL2p+8/R3I9skgzkx6dfJ9ji7QBWTpMm/ylgNrQpoXSBogEQMUZgoiYWBtchh5tBjQtnICTlueF120ZfrJYlnEqPpdFm4lYCA0dbGCkEhqmOcoxOsEA9wWawEUaIHmiw+H1uBQYQLSFGU/auW1VUTIUkSEeU4kMXlyoXKMql6kMVW4yf1y3LN8g0izippbkACMIgiCqCVeJAwxfiQM8oUWbVY41dq0F7us8HgDwiVNXGFwcYDa+UU3gak7nicnUJiPK1AEgYmMbUPOaKar4Z5uUAcBn0kAQQ5usmJkpPEZqaPOGL87sPDRAcZbDHp2tMI80k3yHs+96bc8xlO4dvg3SBBEyVJwhCAeIk1yJbXCpxF202ahHWtCQBVYeBow8b5fuQzlCFBqZwqM48OEJkqNsQJhOMMBtgQZwJDYYZyF64RHCZ9oWrqIszJhiJTJ4QhcZYja6Ct05QSLMRMSiDP892ADrSLO08NyrP1gBr0gzLeQAIwiCIKqNuBMH0qo3cRFtxh8vcfOyowEA75+0FAMau7kjoqnNpl+neJ4scUCzekZGwkxtOqIu0DjTTaSZlPgtzISWNNDObfvpz2mMa0ObLNKs9J0xaeje4vZXHz2NOzdoj04DkhBpRmkDRIxQcYYgEoCTBpe+o81ahX3+o80akMUVrWsBAA8dW4dxp4kusBiEho6IcpRluHSCAW6X6wOOxQaPKAL8igHV+zgWM35+D1EXZmpTZPCI3wHmucnmmEaaGZIWnvtwgHlBDjCCIAii1gk1ccDLLBFBtNnjzw/FuuwYDKrvwtXHmKyeMe3X6Xj1DE8Mpjab1AEg2gINkHDN5Pq9NEStmXREkjRgQiyGNh3q74/b3v9gcfunrxxj8F62kWaatAHALBK6zeAcx1DaABEFVJwhiLAJ0wXWZnBuWnjuW2h4C45zBrdjdEMHMvWNePmIFFYO68CwURmDzwlZaLBLiFBoBHGChVWgAfyJjVAEh4iu2BKBmBDxW5QJ0/0lo3ZFhira0C43WY3rSDMJ5AAjCIIgiAoSkzjAY5Q4YBptpnKRV/br7M0Af959IgDgYye8rbi4sExtEjLqQ3Ga2lQkrUCTSM2UYN3khdf/p9iTBmI3tImI+snG0JbDyRPyP9CuQy3ozdYrrivCSDNeO2liKb3SBrRGaEobIBICFWcIwhE2FXVnLjAZzqPNAP2y1dKxD6bzxaY/bpuJcb1AV10KZ33gTcV71vbqGRlhL9UHwhEbQISCIwH4+TlNf6dRL8sHql1kAGYiIwgNwrafSLPWsociaYOPbzM4xzHkACMIgiCqFevEAZtoM0CTOKBCtmrG3Nx202uzkM2lcOrwtZic5gdXKrOJzfhHZ2rrUB8CwjW1OYw38zsRb6qZElukSQBxmdmABCQNmBC6oU2MNBOP8Y/idvnzS2euKW5feNvl3DkxR5rJKHynm6YNMLTGaDK0ETFDxRmCCImgDS5VLjC3DS5dRJvlmdWcwfwB29GTS+GOzGwMWXs4AKD7iA0oHyyELTQUhzKFx6BCI0COsgqXS/WB8MQGULuCg/1cYRVmluAop4UZU3zHmalwLjJEZLnJ/DH+kd9WFXtlyERDRJFmHg4wsaklOcAIgiCIfkUYiQOMNoNz0sLzCKLNVi0GnuubAwD4+Kle0WbiWEhlYOuFXi/FvHrGgKSkDgBkbBMJUzMB8RRmrKkxQ9s9195b3H71XZPJpYgizfhCepvBZRkQqlGaIHxCxRmCiJlYGlyGEG32wWH5pfiP7m/Djr4BeOZP89CQy2FdKzBvgZerQSY0RAIKDR02QkOkXf/WJk4wFWEXaAB/YgOoHcER5OcI2/kFhCwy2jXHIhMZ4go4Ga5WyIiI32e840u3LD85kWbUb4YgCIKodUwm0kRTmxZ+3tFX4oDbaLNcdz3+sH4+AODDs99CCt3ce6m0key5KiZadNwbrJ7JFB5NTW0qg1tAU1vSCjT9vUgTVDeZEFdhxnl/zsQZ2hiVumpAY2nfH5ccIbuYAroCtI9IMxEfkWYMlaFNi0F8JqUNEFFBxRmCiAKJCyyyBpdW0WaAn2izdH0vLhqcvxH+cU9+4LVtwwgckWkCAMy68C3IBww6oSFGm4lELDRiijeLqkDT34o0Qa87bIEBRCAyeGIXGR2QiwxeXPDPXcWbyQovsm2DSDORtMHHtxmcYwD1myEIgiCqiaB9Z8TEARFt4oAM42gzhSmjOGkpW3XrvZrmjufHY29uACY078HZ002jerxWC0doauOxMbU5TB0AghVoojK2VZtuCnrdNjoziG6SYRof7jzOjIf9vWRMXxCWoU0dafYf5z1d3P7cg2dz55j06JQ9GpKG4rvWGzFtQIVpDCaD0gaIuKDiDEE4xFVl3WmDSx7PaLNWbp8q2qxSaFw5dDma6/rwdudILO4cXdy/69X8AGr92AyaWk0nUBMqNHS0C88jcoIB7go0gH+xAVSH4HBxjVEJjEjdX7ZUgchQY5ubLB7XEFFTy6CQA4wgCIKoVkJLHOBJC8+tEgcAP9Fm+9q7cc+hEwAAnzhFjDbzinBV9esUSaCpzYDA8VMFTCb2ozC2AcnXTa6uz0YzhaGbZEQSZ6aLSk+goe3T80sFjH1dzYoLV/XodBxpFkfaABnaiARAxRmCCBG/LjAvfDW4TAvnOcpQrkcW16WXAgD+uGc2gNIk8JN3HIGRvVnsq6/D2dctFV7pQmgkaPVMu+R9OMJyggHuCzRBxAZQPqCPW3S4vA6b30tUhRkZnv/W2oXniRcZgLuVMipUTlfHkWYaTN291G+GIAiCqGmqInHAbbQZ0Iib3syPHS8Z/zaGNu9XfTCHKl1AROxREYOpLQHxZoDbAg0QzNjGqFXN5MrMBiTI0Gbbn7MKDG3zx20tbl9z50XcOTY9Oh1EmvHItFOb99uaQP1miKRCxRmCSAChNLjUTQimdW/GhIaMSof5GYM2YlzjAezubcGD+6eAvyH39bXi8HdHAAAGHLMGZpOqNkJD9dxSaGQKj7zQ0E1GM0KINwPCK9BELTYYURVrxM9x9Vm2AiMq5xdg6P4yjTNTEZrIMEWWm+xV3DUhhkgzWVNLcoARBEEQ/ZCoEgek5ofAiQMygkWbPfviQLyTHY/Wuh584ARV7xxbU5vsPEaEpjaRduF5lRdoXOmmKA1ucWsmoMoKMyJ++nMm0ND28IfuKm7f8fYMyRm6SDNZ+oAhaeG5WBhnSL6rxUgzVdpAUEMbpQ0QUULFGYKIioAuMOsGlzz8zc4z2ozf5x1t9sH0mwCAO/ceie5c5U35jXuOBgCsHNSDCbN2e/wAMQkNHtlAKmKhoaNaxQaPrIhiIg78vs4vSREYQK2KDObk1P3dA+5WyiQk0swx5AAjCIIgaoGwEgeKRJY4AOTHD15jjHKyB/tw29bjAQAfn78U8n6dIqZjJIerZzKFR5emNgnVVKABwtFNXtonabrJhigLMzJMUy2KBO3P6Un0hrbGuj6MHNAJAHhx4zjwCSjlqCLNgMrvOotIM13aQFSRZgSREKg4QxCOMaqwa4QGQ3SBMYwaXJoKDSsqo82mNe3EiQM3oS+Xwu2Z2dJXLXtjLGYeTCGXSuHE9y+GXGiIg44YhQaPi3gzx04woDbEho6oRIQKPz9vUIGhw3dhRqRdeJ4okaGKLhRFRhjxZqrcZAeRZjLIAUYQBEEQvgk9cSBQtJk4dlAZPuTRZje/MAO9uTrMH7IBs8bsUn1wAdkYqXZMbTb4LdC4Th4A3KYPmFBtusn0d+qyMOMraaDd+yVKqsTQ9rWTXy9uXyuNNLM1tPmINEt7vyTUSDNKGyASAhVnCCJkTPP+mdDw7QJrs3sZ0tC4wHihIVK68V6Tzl/rEwemYEvvYKiERvat/MXtmrQNqVTW48ISKjR0ROwEA9wUaJJepIkavz+fi8JMKO6voHFmPKGJDBmqRrb8cd1zE9R9tJxGmpEDjCAIgiDs8Zk4IEabaQkl2gxQR5t5s355Fo9351MHPnPGCpCpzTx1wE+BBqgNY1schGlmC7swY500ENTQJptvKCMeQ9v/PfuF4vbGfbrvNIZpj04P0sJzVdqA5jvay9AWFDK0EVFDxRmCiJIoKvP8TUzlAjOONmP7KycyG1CHCwbnJwjvVKyaYTx229EY1JfF9sYUzni/eMOsAaEh0i48D8EJBgQv0AD2YgOoPcERpCgTR2HGufsrVpEhy02WHRNfa7tqRuYAkx3nH0OMNGsze2svyAFGEARBVDNBEweYqU01IVeROJCoaDPJOCMH/H7FewAAH5y2GE31pgYUlZ7qP6Y2IPwCjd8iDemmcM1sKnwlDYi4MLQxMoVHJ6tmgCCGtsnpTHH764+dqnkP0x6dQBSRZtr0GI6gaQMEETVUnCGImNG5wHw1uJThrN9BacLypIHrMLyhA7t6W/HSoYnQOSYO7GvB5B15N8bYk5fDfGI14UIjZicYEHy5PuBPbADVLziCXLvp7yuKwkwg95cKvjATmshQEcWqGX5bFmkmEl6kGSNsBxhBEARBJB3bxAHftEn2RR5tJj/+t6fHYEtuOEY0HMQVx6wTzlX3jyg/RzS88FSfqS3svp1AuMY2oLrNbUE0X9hmNiBgBHTVGNrYPltDm5zbr3yguP1fL7+HO1JbkWaUNkBUC1ScIYgQcOUCM8bLBcaTFrZ9RptdNHgZAODh/bPQh+ayY3nKb9grHp4LAFg+rAvjJu5VfKZfoeFg9YwY0cTjNfhixOAEA+IXG0D1CI6gBSUbgRFLYcYL8d+oSmTIyBQeQ10106s4P0h/GZlYSE6kWVQOMFqeTxAEQVQVUScOqIg42qwr04s/7joJAPCZk5bCbAwkM63JjqvGZQ5NbarxZABTmwqXsdCAnWYKWqRJum5ycZ1hm9kA8whoKa7jzBJpaKucX6lLZXHc4fkfaMv+gejJ1mveh2Hao9ODtPDcR6SZHyhtgEg6VJwhiAhIpAvMd7RZA1pS3TinEGl2/74jjS7tjacnYnIH0JtK4bQPLII8Q1nEVGiI+xgWQkOGSmjE5ARLutgAkik4XF2TK4EBhFiYce3+Yn8DGcnr2T6nIoMh5ibz234LNSFFmukIIdKMHGAEQRBEfyW0xIHYo83KJzh/89KRyOZSOH34KkwdvsfjM3RNwlVaKkRTG08QU5tB6oCKKDQTEMzYBiRPN7nUTHEWZnwlDYgEiTPLFB49+3PGY2j7yLy3i9sX3nY5d0S26s/W0BYw0kwD+w6nfjNELULFGYKIGp8NLhmi0JAicxo4izYDzhi0GgPqerCpeyje7DwMRhnKABqWtwEAdk3ZglQq6/EpYQgNj0lhmdAImqPcLjyvggINEFxsAOUD/ChFh+vPdS0wQukxA7iJM+OR/dv3FONBRUaP4nwZUUSayUSGQBrlIkOMQmEYOMCMGhgXIAcYQRAEUQskInHAhLTqgGm0mS4+lacR77xZh6d782PPz5y5AnJTm6xfp9fYyNHqGRmuTW0SotRMURnbGEnQTUGx1UyxFWZktGuO+YkzMyYMQxtDrqd+e8ljxe3FW0dLzjA1tOnOFTCNNNOkDZjC7glSQxv1myESChVnCCLhiC4wEeMGlzxp4blRtFnppPMHvwkAeGD/XABsQtg7Q/mxP8zDgGwWW5vqcNrl78CN0LBZpi+gExqyfaZCw9IJpqJWxAZDFB1BhYDr9xOx/fmjEhiAA/eXK5ERyaqZIPFmXrnJUOxnDjAZBrZZr0izAqIDjCE6wKwnngTIAUYQBEFUC5EnDvDmCV+JAypY5I95pFmRvhx+t+pYAMANMxejsa7P4wUmpjZxDBXC6pmQe3YC0RVoAHtjW9i6yfX7udZNpgRJGQAiSBpQxZmpMDG0OV81ozO0qXt0jh10oLj9oxeOlbyWYWJoCynSjNFm9/ZWkKGNSBiWf00EQZhycS6H+1KlG/99q4CLpwsnPQngLPnr5+ItfwO8NlQOMEaiNOk6CuUDiDSEyVd+4lEcRTSgAX04eUBeNT1x4Ajkb8xmE6YHdrdg2vYhWDL2AMaesgK4y2tA2IvyrynZ54hfYx0o/Qz89j7kBRS/T8JBVAqwHVC74RnboJ6EbYfn4GL72gnSuIWNmCBdJbUG05Qu+9WYqizmMVZipvHSX/bvMOgEsYqkLOVn2P7dhSUwVDh1f/F4ZSYnVGT4Q4wwcxBpljZ4TVvh0adrlxxgBEEQRL/iEQDnle9KPQPkTte/bBrWYDWmYipWYw2mYQI22jWQ57VTGuVaqRVmXpMyGlEauzQgP9ZpKOxrFJ7nz//r02PxH7PSGN2YwaXz2nHnoqkojVF6uffhxy3sfXQwUxsbyPDvodJR3NMMKsc8sn07UZpw5bdFtqC8QNaOct20KVUxbpLppiRoJiCvI/qLZgKSo5tCKcyIJH7VjLeh7bYrHixuf/fpk7gjQQ1tDiPNJGkDQSLNbFJqCCIuaOUMQSQMZ/0EfDe41HN06zoMqu/C7t6BeLtzInfELEN5xUNH5x+HdWLspAzMV894EfPqGRFHTjAdUbrBAPeOsKTh5+cLKjB0RC4yZGQ8jlegExkm6ESGLbJIER5RfPiMNOOJqKmlFHKAEQRBEP0UNhnuNYnuK3GAJ606oIo2EycxTaPNgM4dvbhtd37y9NMnLTW4OAY/hgq6esYwEponaLyZjBhTBwC75AGgpClIN5Uw/R26TBqowPDfURGb/pw8CTW0pZDDmZM3AQD2dTWhs9dmVZ/DHp1pxXkOI82kaOIxGZQ2QMQFFWcIIg4s+87E2+CyXGgcPyB/DS8fmoYc6mCbobzomQmY2gH0pVI47QNveJzPE4HQyBQeXQmNkJfqA9GLDaD2ijR+fx4XhZnI3V+xiowOeIsMGWKhxnbVjMoBpspNtow0i6GpJTnACIIgiFpCnBCTrgA1mFizpk2yz2m0GaCfxNSvcvn1y7MBAGeNWIlJ6b3CUdX4KEGmNh028WYSVKY2vwWaMIxtQG3pJr9FJ1PNFGoEtIx24XmQOLNM4VE2jyAlekPbDfPeLm6fdfOVmjNte3QyNOkkaeF5yJFmpoZnShsgkgIVZwgibsJocNkmOaYSGoDEvaBeinpUSzsAYHHHZLtr42h4uw0AsHvyFiCVFY7GKDRkZLht2xxlExJQoAGCiY1qFRxBijIunF/OCzNeVIXICNJfRoeXyBDP8YAcYARBEAQRO6EmDvgytenQxajKEgcaseyNejzbMwd1qRw+c5bYr1NHFZjaRBymDvgp0ADhGduA6i7SBLn2qM1sQELizNi+BBnafn/po8XthVvGckdcRZoBVpFmMjRpA6KhzQSpoY3SBogEQsUZgggRIxeYhCRHm81u2QAAeKtzIiqjgEyizRrx6K3zMKgvi61NdTjtioQIDdEJ5jUh7UpoAJEUaMIUG0D1CI4gBSWb34/fHjNWeegi7cLzxIoMHtU+cdu2UGMrMvhjqtxkgbTwnBxgBEEQBOGesBIHnEabsclIr2gzSJ5r6MvhdyvnAwA+csQbaKjrU5zoZ3VxTKa2CFIHgHALNIA/YxtQPea2oNfpwswGxFSYCZo04En0hrYJQ0qf+YPnjjf8nJB7dPJ6SWNoU6XEmKQNEES1QMUZgogLS6HBiD7arCQ0WlJ1GNWQv7Gv7ebf1C5Def/eFkzbnp/sHHPKCs2ZEQoNGWKUExBOvJkFfgs0QPhiA0im4HBxTTZFGdeFmdoTGXxhVdzHUIkM20gzFbpIMxXumloyVJFmOsgBRhAEQfRbXCYOMNok+6yjzbyQTWaa9eu848nDsCM3BGMb9+Kyeesh79cp0oPYTG0ZVKIzsoWUOgBEU6CpJd3kSjO5MLMB3v+fRCJLGuDJaPY5WTXDCGZo+8tV9xe3v//MidwRmx6dsudsH5szkpAWnodkaLNNGxANbZQ2QMQJFWcIokowdgQ4jzYrMaZhFwDgYLYZ+7ODESRDefkDcwEAK9OdGDd5LxIpNHhk+3R4DezahecOhAaQHLEBlA/woxQdLj83KoERqDDjRaJEhohOZKjOscUkNznaSDNtMR2l73vfE00EQRAEUWUkMnEg1Ggz9lx/fsfuPty262QAwOcWLDH4jJhNbTxhxJsBkRVoal03uf5Mm9+Ba91krJnahed+kgZM+nNWwP6ebFfNBDe01aeyOGH8VgDA1gMD0NUn0z0mPTpF7SQryIQTaWaDs3sCQUQIFWcIImSMKvBhuMAYDqPNBtR1AQAO9LUA4Ac79hnKC5+fiOmHgL5UCqdc94bBp9eY0JCRsAIN4EZsMEQBEFQIuH4/HtufO9bCTLvwPNEiQyyoMmTCQvybN/0O8JubDFRVpBk5wAiCIIj+QlyJAyrSsp1e0WY2RZlyfvHCHGRzKZw5fCWmjdwDualNHEuJYy1xf4JNbT5TB4yawHOYrMyw1U2uiEI3ucLWzBZJYUZGu/DcNGmAR9afU4Z2yiFaQ9un5peKuuf/4QrNdUXYo9Mr0qwA+65WpQ3oDMyUNkBUExZ/WQRBuOC+VcDF073Pm7d7ORYPn6U9ZypWYw2mYQI2eveqGIPSgGMkSgOONMoHFq2QDCbyd9Y+NAEAGlJZlL4++IFBI2xc7qmlbcDx7dg9eTNSdVnksj0o3fx7C5/B7wP3/uLXF/+5/PkdKI0M+O19yIsofh/3NINK4cXv24HShOxOlAYY/PY2lA82tqC8WNaOyknbTamKDOztaydIReRGTCgOVETWYJq2WR4bFJvEKAElsRFGpmtSlvID9qLKRLDFWpixjTPLSC8nTyCRIUOMLfTbX8YUMT4k3kgzgiAIgiDcMhdvYQmOwgy8YzamG5/Lj6/aUDmm4vXSKJTGSwMhMW9JBZQEXucAJa3TUNjXKDzP886SOjx5/lyc0/wm/s85y/HFP5/s8TlMRzFUOomdy+shXnsJOsmLg6gcH5loJqBSN4m0o1w3STSTCi/NBOibjK/G1ERoJoB0kzFeEdAiuqQBr/6coqFNmTQQj6Ht5xeVHF1vbhvNHQnSoxOoLERLSAvPTQ1thn/bDNtIM4JIGrRyhiDiJGCDSxFtg0sZltFm23uHAwBGNOzHgFRnYa//DOVHbj0Gg/uy2NZYh3OuVi0/zWFMwwGcPWgtPjbsNfzz6Kfwn4fdi9+N/zN+P+EW/Hr8zbhx7O347IgnMb91LRrAGmWarJ7hB0gSMoVH474bEvw4wSJaQQPYucEAtytpkoSfnyvSwoyMduG5rjBjQobbdi4yZDGE/DERsVBju2pG9Tw5kWbkACMIgiCIcoImDgTGK9pMxLNuoTrBpl9nI5AFfrX8OADA9dMWoaVBFnUEmK+eYa+JaPVMwlMHALcxZ0DtaibAX8KAyWqZSCOgXScNGBNk1Yy9oW3a8D3F7W8/dZLpRRbQGdpk323VE2lGaQNE0qCVMwRRBTAXmDVt0LvAeKQusHIyfUOwpWcExjXuwokD38GTB2ZzR5mw6OGe6wYMjTi4H2jbmsZbh+/DsJOWA3+ejYF1BzGnZS+OatmJuS07MLd1J0Y3HNJfWJEnsKUnjd/uPhN/ypyB0i1WtXpGdlxxCnOCZWC3egawd4IpCGMFDWC/igYI3xUWFX5Ek6kwc1qYCer+ilVk6Ahr1YxtpBmfm5yMSDNygBEEQRD9nTATB0ZP2aged/GJAzxplIoO/HYRcdzQg8pVNf7GOvc8NQob547EhPqduO74Nfj9i0d4vEK3egaoHB95rZ5hqQMKMsj/TvjVM2yfKRGkDgAIrJv8aCaAdJMKr8JYbEkDKjKafbGtmpF/r/z16vuK2//vheO4IzpDW8g9Oh1FmukgQxtRbdDKGYKIAKMGl5LJtmQ1uMxPVj6w71QAwGdHPITGVA5BMpQbkEXmzmk4d1EWpz/XhYdn3o1Xpt2B3094BF8ZtRDnDN6A0Q2H0JtLYUXnCNy/bzp+uWs+vr/tdHxj84X4yuZL8PdbLsJ/7DgTD+ybgz29AzCuMYN/GnM3fjP+fzAgdQC+V88EzVEOwQkG+F9BE8YqGqDknKo2Z5jfa05EYUaGqfuLxyvOjO0LZdVMD3dMxFWhRhQZ7FHXI0sGOcAIgiAIInYcJw4UYRP7bZJjvF7izRjScYBp/JcYD2SSONCInn19uHn7AgDA509YLHnfIKtn+ONeE8eK1TM8QXp2hpw6ALhJHgiim6oJv1rP9HeUqMKMSCJWzTBMDW2lv+Wm+l7MHZP/IdbuGYquPllxxVWPTglp4bmY2qLCMAWm2oudBMFDK2cIokqZhjVYjalqFxjLUOaRucD4DGVA4gLje7QAf8hchKvSj2F2y3r84vCf45+3XoctvYPhlaHcgDq0Ne3EjObdmNOyHUe17MCRLbvQeqiPE1sHgRTwbs8gvNUxEks6R2FJ5ygs7xyNjhwvZHjxUppsbUoBVw59A18Z9SBOGrgSPzv81/jkpi8iC8B69YwMmxxlEUdOMMDfChogvFU0jKSvpgkihlwUZQBHhZl24blfkSEjkatmbCPNVCJDFBZ2zXjJAUYQBEEQycd34gBjHMyMVAxtuxmxDyY/1mHjFrZKxduU8qvnjsA3rrwf84dswHvGb8OiTWNgN54JafWMy56dJrQj0AoaIHjyAGC/ioaR9NU0QQtIYZrZgACFGRE/SQMZyfuwfaGsmvFvaPvKSYuK2xf/6VLpOXlUhjbxmG71jGWkGf/3zuZE2jSXqCFo2gAZ2ogkQMUZgoibRwCcV74r9QyQO11+unGDS0YbnEabbe8dga9s/hL++/B/w8kDV+DRKd/Fy4emY0nHBGzrHYSD2Xo0pnowqO4QRjXsx7jGDKY378SUpgwaU9mK99vX14gtw5qxdM5BbB4L3Puz9+HdA0Mhd5SpyA9IunMNuC1zPN7smIibJv4KJw1cievST+GPmbNhJjQU8WYZuBEaVVKgAdwUaRhxiQ5XzrSqLszELjJ0q2ZkuI4341Gtlokg0szSAaYVGQRBEARRw1ycy+G+VGkMJI02exLAWeW7TKLNrOBNbbx2SsMi2kwcF/HjD9lYJ29qKxVr2PM8G5fn8GDnsbis9RV84axl+NgtKkcI0zss2oy9lwzeBMPOCcnUpoPXTV6aCYikQAMgVGMbkJxCjQvdFLZmAgL05gT8JQ3wZLhtsT+nElm/pvBXzQDAjWc/X9xetoMXKqaGtvgizRiioU2H37QBgkgCVJwhiCQhERqMUFxgKqEBKFxg+TvtS4fm4+r1N+IfR/8GJw5cjgUDV2LBwJWel3CgrxGruodheedwLOkchbc6R6K9ewBS9X0441N3YHtjHY69/h28+4vjFe9gJjTe7pqAH2+/EN8Zezc+PeIR3JE5vTCE8RIaGjLwl6MckhMMCFagAcwGOX4dYTxRFWtcxwTYxBXEUpjxwqThaobbdi4ydIgixNWqGdlzVdPdKo00IwcYQRAEQRTRmdoYvhIHGLxe4hMHpKY27TIaAb5fJ7+KRnN+rgf/++Z7cNmJr+CaSW/iS80nY18X/15inxkeXj/JroOdY7p6xrGpTcSkQGOI3wINEI2xjSHTMtWgm6LQTIDjCGjXcWZKQ5tqn82qGRn6Qs0xY0s/4OceUEwwKVFF1jsytOkozHmo/l7Z35fu74LSBohqhIozBBERRi4wCZE0uLSMNgOANd3j8bFNf4/JTetw4oBlmNW8AcPqD2BgXQe6cyl05Bqwo3cQtvYOwOquNFZ3pbG5dxAqnR9Arq8Ho9eNw/YZ25CduxbA8SgXA/ZC4869x+OzI5/A6IZ9OG3Qm3jiwHwEEho8tvFmITnBAP8NL4FoxQaPiRgQB1xR5jO7FBhAiIUZv80sIxMZJqtmZMWXIKtmTHOTZcc0hZoQI81cQA4wgiAIol8gSRyQwUxtvhMHnEWbyVacAOXjH7NIM8ajzwzFO8eNx8z6TfjYaSvxk8dU5j0/q2fEGLaQTG0JSh0A3BVogP6lm2z77kRWmGkXnkeeNMBwtWqG7wtlZmh79qN3FLd/tXAud8TG0BZSj05d2oAllDZA1ApUnCGIJOAz2oy5wJQwF1gbnEab5cnfoNd1T8a67rEoH1zwE7GyRpOVPP3H+Rj33QewrjWF485Zh9cen6w400xo9KEej+4/Ch8a9gJOGLCqUJwB5ELDUHBkEJ0TrB01LzZ0xNEss2oEBuDG/ZWRnMv2ecaZhbFqRnyd6aoZEdvcZL6RJe8AE74T0sLHOIo0IwcYQRAEQfhAkzjgjMDRZqpZW9lKmfJ+nfJos0ZkO3vwu40n4kdtd+Izx7yBnzw2B0FNbc5Xz/AEiTeTEWOBBjA319SybnKtmYAICzMiJnFmPEbzJCaGNvEcm1UzDPk5Q5q7MKgpf+zRNZPQl6uTnBXE0KbB1NDGiDhtQDS0UdoAkRRkf6UEQSSUwM4A2c2Pv0mmhWMV9Qo2YSkrZLCJTvEGLk6KNgrP82xdPwSzduXfd+r5Swp7+QGHbqKWd5KUBjaLOiYBAI5qWY+S0AB3rohs4rnyKYDygVlGcVm6WClxMt3UnadYtq3L39UNdoH8gNlk0MxYjanWg/IkY/vzmPy+QluSD/h3f/FkuG0jkSEjjFUzfpDlJvMEzE3miSDSjBxgBEEQRH9HnDAzXSGq6znAJsnZ5DqbiC9O1uvMFLxe4s0Z0nGBxQqTijGJ2UTobx6fhgO5Fsxs3Yr3HrnJ4BUmYy7ZWE021jOMbct47OMNQ7oJcp0BidEu2ReCZgLMCg08taSb/PwssRZmZOiSBiI1tAHmq2bE1+sNbT+/8Ini9ofuukDz/jJkhjb23GekmQxZ2oAQaSamDZgUOqWGNoKoAqg4QxARYlSZl1T4TZqbKYWGDNnNUMwBTYsnhCk0CqtwHs0vuV05/BBGjd9v8Dl6obGhO6+kxjXuFo74EBrscMbjkmSRUQwvZ4446d6uOC9hYqMaBYffa49UYABu3V9ecWaZwmOkq2b4/eL7BcGrUMMQRYaEEJpamkAOMIIgCIKQIFk5qpuQs+7Z0WZ3ehGlVBInM5mZTaeV9MaT3e9mccfeEwEAXzxDZ2oTx1ViPBJvpOGJyNSmKtD4MbW1S/b51EyujW1AdRdp/GomEzNbVSUNGPfnZMjmGfysmjE3tKWQwwfnrig+33FoAHc0ph6dpmkDhrDvdGNDG6UNEFUAFWcIIkaC9AlgNyVPBwFzgbVJjqmWjBsvPZcJDfG4OS89OhlTOnLorkvhjOsXFvb6Fxo7evMDhZH1+wHkoBYa/D6F0ODJFB5thYaIbICYgAKNrdgAqqdQ4/caTX8voRZmZLiOM/PERGTYrJoxy01WYyIyVJMgstUzjiLNGIqmlqIDzDrSjCAIgiAILc4SB3jzhSpxgN8uIhlPVCDGBqlW9lYWcn724jwAwPkjl6Ft+F6PzwEiXT0jM7XZrtSOsUADhGNsA6pPM/nVTV54/X5DL8yYJA3wWPXn3CceQPBVM/x+2fvl+dDckrnrzJuvVLx3bUWaSZEY2ggiyVBxhiCSgqELLNnRZkBpIlQmNHTRZo0AUmhYmu81s2vKFqTq+hSfwaMWGt25/PvXpXKoQ7dwVJcF69MJpsLGCSajXbE/JDcY4L9IAyRLdPDX4vd6TAVG6IWZKOLMEigyzDARGexYBJFmbf4+ghxgBEEQBJEnMYkDjNCizQC1a13N4oVNeLp7DupSOXzx3GWFvVGsnvFpalPt05naqqRAUwuaCQh+PS7MbEDEhRkR33FmIq5XzZgZ2m65/OHi9tPtut+z3x6dyYs0M/nOByhtgEg2VJwhiCQSRaU/9GgzwI/QeOSWeRjal8X2xhTO+YALoZGnDuzmayI0DMh47LOJN/MrNIBQxQbgzxHG46I4EtfnxSIwgHBFhoxM4TGhIiMYvMgQC8YykSHgMNLMBtNIM4IgCILoL3glDphEmwVKHDBBKpVk4w1xdS+PWb9OZIFfvH0CAOAj0xdiQKNoRJPhYvWMuC/EeDMTHBRogiYPAMGKNED0msnlZ9r87F5mNueFGS+cxZnJDG1sXw/K/yDCMbTNHFGKcv/eMycA4H9nsqhE2bZtj854Is1kSO8BZGgjqgQqzhBExMTa4LJN8mKTaLNAGcoieqFxcH8L2t4dBgAYeqKZC0IlNAbW5UVKV7YBvcihsj8GP2Di9wOhC40qK9AELdIwRBHgVwy4eh8ZrgQGEFNhpoZEhp4wcpPDizTz09TSC3KAEQRBEP0Swwm3SBMH0hBQOMul+OvXedeTo7A+Oxrp+kO4YcHqwrGoV88YkLE7PXDqgArV+BtukgcAd7rJpdYJSzfZaiZfKQOA9v9bGTL9HEucma4/J49bQ9uDH7y7uH3j88crXqNC9/3jsEcnI2CkWeDvdoJIGFScIYgkEXaDS0YgoQH4z1A2Exov334s6nI5rBiUxZwT3y0csxcag+vzo6n92WbhPH4gI8tR9sAmRzkKJxgQqEATR5FGRCUawizCiLgUGEBIhRkvbAozPAkVGWaElJvMY9yHC+QAIwiCIIgo8RltxjCKNvNKHLAZJ5SZ2kTsxiq9B7L4zbsLAABfOG4hABOThp/VMw5MbZnCo4mpTSQCUxvgztgGhKebbDVTGLrJ9mfzrZkAt705XcSZGaOaYwjH0DaoqRtThuV7T72xZRQ6e3W9rHSGNjF1IECPThmatAGVoU2H37QBMrQRSYOKMwRRhThzCphEm1UQJNrMjDVvj8KsvU0AgLmXLTJ8VaXQGNOwHwCwo3cId45qYtlCaPBkDPcxbJ1gjgo0YYiNsAo1UROGwAitMBPU/cWTsTw/YpGhx0tkyB4BfW6yQCvUhWpZpBk5wAiCIAjCKX4TB3SIiQMV+Ekc4PGMNuMJEm2W3/e/T87AoVwzjhywBWfOeBeVuFg9w5CNBT3MbVWUOgC4NbYB4ZrbosaPZoqsMBNm0gBPQg1t3z/zheL2+++4WPGaiA1tppFm4/WFEvadbW1o43Bx7yCIMKHiDEEkFUsXmHG0mQzf0WZhZCjn921+ci4A4J3RBzBi7H7JZ3sLjXEN+dzVLT1DJOfFJDREIirQAO7FBlC9gsNPgSlWgQG4cX/J4swSKjLMcJmbrFgRqBIZOjwizXSQA4wgCIIgDKiKxIEwo83y7GjP4c79JwIAvnQmM3p4GV/8rp5hyN7fp6lNh5fxKKYCDWBnbAOqVzMB/q7d5PeTuMIMj0zX++7PyQjL0JbDl058o/hsXSbNHfNraGPPHfboZERlaKO0AaKKoOIMQcSASxeYtdBoKzw6iTYzwZ/QeO7eqWjrzKGrLoVzPvp6Ya+d0GhryhdnNvWkudc7FBpe8Wb8/iBOMCByseFXcCRZdAS5xlgFBhCh+0t1QpJWzZhg/51ThskCwQBNLckBRhAEQRDREFnigHG0mczUJh7n0ffrBICfvXwMAOCiMUsxYei+iuMlwl4947hnJ2C/StyhZgrT2JZkzQT4v05TM1uoukmH7t9QxnBfEVl/TkZ0hrZLZ5b6WF55x/sUrzE1tAFOenTK0ESaMZihzQS/hjaCSCJUnCGIpGHoAotXaKhmLm0ylL2ERh0a35oCANg2ZQvqG2UOL/3qmZnN2wEA73SNkpzDvx7wJTR4MupDUmIu0IThCGMkSXQEvRZT0RVaVjJgV5gxIaM7mAyRoUf33aHKTeafyxxgkmJzWrEdZ6QZOcAIgiAIopKwEgfaJC+WmTPSwvMKqcQmM1Uaio1f/PXrfP3VZjzfcyTqUzl88b3LCsdk/TrhsU92jo2pTYNfU5tIGJopBmMbkCzNBAQ3sjnRTElKGuAxThrwMrSp/u78G9ruufbe4vZdy6crz/MmvkgzsecX+25m39VBEA1tlDZAJBEqzhBEkvHZ4NJptBmP7wxlXUM6GfnzH7rpGAzrzWJnYx3O+/CSwjHT1TM5HNGSL86s6OIrUKarZzyXFKhP8+MEkyETGzLaFfs1QgMIV2ww+EF+FMLD1efZCAynhRmvZfki4r+hqhAZDJerZnRCwiTSDHAWaVYgykgzgiAIgqhVbBMHYok240mLO8Lv14m+HH6x/AQAwEdnLkRLg834K6LVMzwZzeUxokwdAAIXaIBguilqzeTqM21+Zl+aCYgvaYDtN5wWKJ0ozjOIWsmdoW3qsD3F7f965RjkwP8evSLNGoVHhmWPTp40t+0w0sw6bYAMbUSVUXPFmXfffRc/+clP8N73vhcTJ05EU1MTxo4di/e///145ZVX4r48gigSRsXe2FnQVniMLEMZqHS0e9N5sAkTNuRnRBvme2X4lA9gxjbsQbq+E725FNZ0pyEf/IgTy7KRl0+h4cIJBphlKQOhFmiA4EUahigC/AoCF+8hw5nAAMIpzAR1f7H9sYoMP71lGLa5yeJ3jSgyJASNNHPQ1NILcoARBOEC0k1E1WM4AZesaLNw+nX+5fEx2JgdheH1B/HRU9lAIejqmRhNbSIJKNBEqZvC1kyudJMJvs1sQDiFGRMyuoMm8wPivnAMbQ998O7i9j8+sUBxlirSTHaejx6daeEUmaHNINLMBGff5QSRMGquOPOzn/0MX/7yl7F27Vq8973vxVe/+lWccsop+Nvf/oaTTz4Zt99+e9yXSBBSYnGBMULNUJbd4M2FxrN/OA6NuRzWDsjh2Pe2F455C42ZzfkR2LruEejOidcgGxzphIY4uNLkKGcklyNi4wST4adAkyCxoUInHsJ2k9kWZZwuyQfM8pJdNbOUErXIEM+1xU9usuy7KECkGSOESDNygBEEEQWkm4iaIqzEgTbJiwJFm3nhr19nz/4sfrnpNADAl45/DSnojBumq2dU5wC+TW1B4s1iLtAA5sY2IDzdlATNFLqZDQivMOO0P6d4YnSGtqHNnZg+IgMAWLx1FA72NHFHgxraLFHNF1lEmolpAzrjMaUNELWG7+LMBRdcgLvvvht9fX0urycwxx9/PJ5++mmsXr0av/nNb3DjjTfizjvvxFNPPYX6+np89rOfRVdXV9yXSRB6HLrAoo02U8GizeyFxqa1aRy5K//eU85b7HF2aYAzpyU/SlvRNYo7ZiM0ZBPLmgabSXCCAVUvNqLG9udwLjAA9+6vjOYYkBCR4TLSjEf8jjEVGj4izTQOMFFkEATRf0mqZgJINxHVg220mQ5marPuZRBptBlg368T+OUj07Av14oZLdtw4dwNhb2uVs/wmJjaIkgdkGFToGlXvIeBZupvusmPZgqUMtAu2W+imUT8FGbYfuk/Xd2/b/5vQdzvYtVM+Xv87IKnituX/OlSxev9GtoMenS2IhJDm3WkGQelDRDVgu/izCOPPIIrr7wS48ePx7e+9S2sXu2dpR4FV1xxBU4//fSK/aeeeirOPPNM7NmzB2+9RUvhiCoiYINLJW2Fx0ijzQA/QmPlA/MAAO8M78DhMzKFvXqhMb/1XQDAokOHK65DHCT1Qt9jI2ShkaACjR+xUU2Cw881BxIYQLiFmaoTGeK5pui+K8TYRJlbzEdusgwfkWbkACOI/ktSNRNAuomoESSmNtmEXXVFm8kmTHXkj+/enMMfdp8CAPjaGW94vMbl6hlL042xUaiAbeqAC93kkTwA2BnbgOos0vi5ZiPN5CJlIEgEtIqM7qCf/pw87lbNNNT14fqjS9ph4z7bORmegIY2niREmlHaAFGF+C7OrF69Gt/4xjdQV1eHH/3oR5g5cybOPvts/PnPf0Z3d7fLa3RGY2P+C6ehwaQxMEGETyIaXDJiy1DWkX/Na0+1YebBFPpSKZz6wdc8XtODxlQf5rVuzb+2YyzMhAZDNjFtiCrerMoKNIB9kQZIdqHG77UZ/x6qoTAjJQ6R4XLVjElusu4cj9xkfpscYARB+KAaNRNAuonoXyQv2gzwGzX0X8/MRV8uhTOGrcTcw9nA0fXqGQemNpt4M5GoCjSAc2MbkGzNBISsm8LSTIBdBHSG2w4UZ8bQzSeojGv+DW1fOH5xcfu0318tnOtlaJM98scte3TaRJoxPCLNdJChjahFfBdnpkyZghtvvBEbNmzA3XffjQsvvBDPPvssPvjBD+Kwww7DV77yFSxbtszltQZiw4YNePzxxzFu3DgcddRRyvO6urqwb9++sv8IIhbCjDbT4SvazDZDWbaMVk/HKzMAABvG78DAISxiQy40jmrZjpa6PuzsbcXa7mGKd+QnlSvfI7DQMMXLwRNmgcawSOOHJIiOoNdgXJSJqjAjYuv+SrjI0BNjbjI5wAiCCEC1aSbATDeRZiISic/EAWOcR5u579e58q16PNBxLADg6+cvUVwcw+/qGYYu3kwk4tQBIPHGNqBcr8Slm4JeQ+xmNsB/BHSohjbV31dQQ1sO/3leyc313IbxyqsuoYs08zK0OezR2aa7xhLsuzmIAZkMbUQ14bs4w6ivr8ell16K++67Dxs2bMD3vvc9pNNp/Nd//ReOOuoonHLKKbj55pvR2dnp4np90dPTg+uvvx5dXV340Y9+hPr6euW5N954I4YOHVr8b8KEZLoZiH6G62gz0QUWONpMBS80dHgLjYf/cBTG9GSxr74OF3z8de27HduaL0It7DgMQApyocEjDpIizFEW8ZqQB+yFRrvmvUIUGwxxwB+G8HD5Gc4ERrvimJ+8ZCAE91eSRIYpLnOTBcRdacnHU1NLgiB8Ug2aCTDXTaSZiKip2sSBiPt1AsB/vnocAODqw9/A2MEHCnuDrJ6xMbWJ+wx7dmYU5/D74y7QxKCbwiAWzRRnYcY2aaDin7NXwVE1jwCUdJJMK/kztJ07ZX1x+7MPnC2ca2NoA+wMbW57dNogM7RJv+PJ0EZUKYGLMzzjxo3D3//93+PGG2/EuHHjkMvl8OKLL+JjH/sYxo8fj3/7t39DNpt1+ZGeZLNZfOQjH8Gzzz6LT37yk7j++uu153/rW9/C3r17i/9t3EjNfInqIf5oM90SWL8Zynmy2QYMW5Ef/O2dtR51dbLJ3vy+YwfkR3KvHxoHM6HB4zhHOcNtR+EES7DY4JEVU4L854LQBQaQQPeX+DwekaFHJzJkmOYme0Sa8UQUaeYFOcAIojZIomYC7HQTaSYiUYSZOKDoJwfAO9osLR6Mpl/n088Pxms909CU6sMXz1Mb+vKYrJ7hcWxqk817y4xGjDgLNICRZgLc6SbXmsmFbrL62bx+X+2K/WEUZkzQJg149edk2CQM2BvaHr3+ruL2rxeqUoFMDG1ikSbkHp2MAJFmBFGrOCvOrFy5Et/4xjcwfvx4XHvttdi9ezeuv/56PP744/jRj36EQYMG4Zvf/Cb+/u//3tVHepLNZvGxj30Mt912Gz70oQ/hl7/8pedrmpubMWTIkLL/CCJMXLjAnDW4ZPiONjNF14NGLTQe/t1xGNyXxZamOpz3oaXSd25O9WJ+od/MK4cO544EERriPkfxZhlu25XQACITG64KNXFi/bMEERiJcX/xyNyM0YsMM2QFFzECRFa44UWGASoHGPteJAcYQRABSKJmAux1E2kmIrG4ThxgtBUebaLNjLDp1+k1likc78vhZ8sXAAA+dcRraG1gYzAXq2dU5/s0tfFkuG2d4cg2FlpFkOSBiIs0ScBaM4WdMmAbAZ3htmX/vnwnDfDniqvLZP05/Rvapg/fU9z+9aI56MvxU7p+esSpohR5HPXobDO7IkobIPojgYoznZ2duPXWW3H66adj1qxZ+PGPf4zhw4fj3//93/Huu+/i5ptvxllnnYWvfe1reOedd7BgwQLccsstrq5dSzabxUc/+lHcfPPNuO6663DTTTehrs7pQiGCCB8Hk3LGDS6dR5vZCA0v8q85sKcFbZvyF9V04goArLBVGtwc17oJrXV92NIzEKu7hxf2qoSGbAWAWLgRhYbDHGWRpBRoDMUGUL2Cw/q6TX4v7Yr9qv8Xsbm/bOLMwhcZ/pBNTshyk8XvF/adNKRyN0/a8DLaCo+um1oSBFFTJFkzAaSbiP5J8qPNAO9eEGr+9Og4bMiOwvD6A/jY6R7uP6erZ8R9lqa2DLcdJBZa1rfTZgUN4MzYBlSvuc3XdfvVTECwwozzpAEv+PkB2WoaL+wNbfd/4O7i9tcfPU1xlszsWn09OnXf0brYSoDSBojqw/eo+wtf+AIOO+wwfOQjH8Err7yCa665Bk899RSWLVuGL33pSxg2rLwJd3NzM8477zzs3GlqY/APExi33HILrrnmGtx6663aPjMEURUYusDYTcy6wSUjUqHBPzfjmVtOQGMuhzUDgJMvqPwZTx20CQDw3MEJqBQSotAQESehvZwxgK8cZRsnWFgFmnbFMYbPIk2SBYdvcWEiMNoVx8IszGS4bd9xZuJ+r5gKL1yvmrHJTbYVGdFFmlk3tSQHGEHUDEnWTADpJqK6iCpxIL5oM69+neLYR92vs/dgFr/YcCoA4O+OfQ0pianNbvVMEFNbDKkDqn1hFGgsNBNQHeY2X9cYxMwGuC3MOEka8NOfk992a2gb3tqBGSPyF71k20js7WoxeI+Qe3TKcBxpZpwSQ2kDRBXjuzjz85//HCNGjMAPf/hDbNq0CbfddhtOP/107WvOOOMMfPvb3/b7kUawJfm33HILrrrqKvzhD38ggUEkHpcNLp3hJTQqCDtDOb9v85qhmLVjEABg/Hlvcsfzg5zTBvLFGR7ZJDE/WBL3MxzmKGe47bgLNIC32ACsxQaQrEJNoGsx+dnbNcfiLsxYZybLXpiUVTMmucni+TKR4UHSIs04yAFGENVJUjUTQLqJqDE8JuaSG21m0q/Tn6ntl49Mw75cK2a0bMNF87z6QnmtnhGxMbVZkuG2g6QOqPDTu7Pd4z0DFGmSoJsCXYsLM1sUhRkVvgozuvkCE71jb2j776tfLG5fcfvFwlEbQxvgrEcnv02RZgThGz+hhACAxx57DGeffbbVaxYsWIAFCxb4/Ugjvve97+Hmm2/GoEGDMGPGDPzrv/5rxTmXXXYZ5s2bF+p1EIQzHgFwnvdpc/EWlkDeEG4qVmMNpmECNmIjJmD0lI35Qdf4XH4g1Yb8YGkcKgdGIyEfBKXhMdgRBYZs1MMGIeyryGyCd8ldxwKfeQbLh3Zj1vwtWL4wf9ef2LgPk5r2oydXh5cOHsa9pzjw6IX86499foPkvA7kfybZ+/Gw8/S7KsjAPEoJyP8/EQdA2yCfPN4CtdO/vfDYpvksNtjWuQYViIP7ogMxJJwIm6BFGSBYYcaLDLft1P0VvcjQoxMZMmQTGCKWkWa6iZa2wmOYTS3JAUYQVU9SNRNAuomoYZ4EcJb3aTPwDlZipv/PGYP82I7XS6NQmhweiNJYrRXC+IyNR0yLGY0oH1M1ID/WYvvZ8xKZrcCtuxbg8yMfx9dPX4j7F08sHOH1jEwXyfaJ47lGbj/b5kUP28/2scd9yP/sHpopg9KY7CBKBhp+P5D/XfOGGlEjyTQT4F83tSmOMapEN0WmmYDwzGyyfboI6FDizBiyxA53hrbm1hyuaysZU9fsGSa+oICJoU0s0oiGNgNijDTzwsvoTBBJxHdxxlZkREV7ezsA4MCBA/jBD34gPaetrY1EBlGdSITGvN3LsXj4LOnp07AGqzHV/nOY0OAJRWjIijLeQmPpy+Nw9XUNWDa0F/Oufh3LF+adI6cOXA8AWHhoNA7lZAsDmdDgizD8Z/ODEZXQgLBPFBoeZCAXGuIxL6Gh2qcTGkBsYoOhEwKmAiQ0Z5kLgQGEKzIyHp8NeBRmVKhEhriiLCmrZsSeMrbDGcNIMxk+I81kkAOMIGqfpGomgHQTUZ1cnMvhvlRpzHbfKuDi6erzU88AOWGxms7UxlCa2nTwRZo0SuM2fhuAREChfLwjHhO1krhPpFTI+cmTR+NTVz2F09KrcHzbNrzarps1FU1oonFNdk1+TG1VXqABqkI39WvNBIQcZyZ7A1EnqbA3tH3r46X/3+f+8SrhqK2hDdAb2gL06PQRaabDOG1AY2ijtAGiGvBdnEkqN910E2666aa4L4MgYiWwC4whWzWThqXQYPuBysELj1gokVESGlsenwu8fxHeGXUAY9v2Ymv7UJw2cDMA4LmD44XXea12Ec9l1wPohYZjJ5iI6wIN4FZsAIEEh0hsy/nDFhhAgtxfJpnJ4t9wOCJDj83wxFFucsiRZtTUkiCIpEG6iag5DBMHdLDEASVtUCcOyOBNbVJkOklEVjARj8tNbauX1+OugyfgmkEv4h/OW4TL/vcC7n1MVs/4NbXx4ke2T4Of1IEoCzSAmbENcFKkkRGLbrKJbWv3OF61hRne0Mb2ybSSO0Nbqg74zog7i88fXy3OdTBk8x0yQ1vEPTrb9J/A0gasDW0EUUP47jlDEIRbwm5wyW56ygaXbYUTQ8lQlmGToVw5UfvcfdMwtQPorkvhjBteRUuqF8cNyI/mSsUZXZNLMUdZHDCJz21zlCPqP6Patw3++tAAZr1oGD7ylRMBu+4onF9VIzL4feGKDHNk/adizE1mtBUeqaklQRAEQSQTyUpU2QQfM0/oJga1sIl9ftwgi/kBJAUH236dDJOVw6Vx0b+9cDwA4OIxSzFz9G6P9zYd2/Erqhmy8Z6Pnp08GW5bLHJloEfW09G2d6dX/852j2tgVKtmAuw1U7vmuO53GoZmUhGoNZJXf04ZKgObt6HtUx8q/cCfePh8APz/C9V3gEwXmRraPEhKpBmlDRA1BBVnCKIacNDg0hrZTZW/+fI3ZanQMLixFxEnYBuE/bJz69D7an510IaJO3DqyE1oqevD5p6BWN2dlrzOVmjwz9mgSbbyRyU0FKgGgmEUaIBgBZp2j3N4bIsdceDnGtvh1vkFxFSY8cKPyBDPd7VqRkWIuclesO9DijQjCIIgiKrDa4WqFxWmNh18kSat2AZQqZXE1b785KmNqa2Sha824ZHOY1CXyuGbF7zJHRGNaSI6U5t4Hj8ONDW1aQo0JqY28ZhsxbmNsU1FGMa2JGsmwL9u0hFFykBGeG6VNBCkPyeE5+LfhCniuTn8cvItxWe/f/UIxetcGtqqLNKMg9IGiGqFijMEUW0YTt4FaaJWhuwmm1ZsA5AXZbyEhniuGQ/fchTGdWexr74OHzgmLzSeO3gYvB1cNkJDti1b1sw/+nCCeR0Lq0DjukgDJEt0+L2WdpgJjDgLM1b4ERm9kudRrJoxyU2WiYwAuclekWYBcfZ9TBAEQRCEPnHAcOWpTeJABW2FR51pQ2Vqk+LH1MYQTW2SAk4W+LeFJwEAPjBpEQ4bst/jM4KY2hgmpjaeAAUanjgLNO0e54gkSTMB4esmFX4LMyIZ4XmgpAFxv6o/J78t/j0w/K+auezSA8Xt77x8GrLSvroiJoY2oNzQJiMZkWaeUNoAUQNQcYYgEkTVRZsZCQ0TsSHLPvUWGn19DRi6bBKQy2HanrzIePrA4ZrPMREasoFVDPFm4jEgnAINYJaX3Q57wQGUD/KjEB5BP68dZj+nrcAA3C/L548FzkxmiIUUWWQF/9owV83IhIX4XWHbOi/8SDMd5AAjCIIgiHgJNXHAxNTRKj6RRZupTG08dr0jnnxmCF7pmYmmVB++fsFS7ojL1TOybVNTmyU2qQOAu2joMIo0QPSaycVntiO4mS1IYcZPb07jpAHdPp2hDXC3aga4e96vi9s/fHyecNTE0Cai+j4J2KOTEUKkGaUNEP0BKs4QRLUQwBEQWrQZj3GGsig0ROyExgO/PhZHbstixD6gpy6FVw6NLRzxKzR4XOQoK0higSbMIg1DFAF+BIGL9xBph3lRJq7CTKhxZl4ig+2H4piKMFbNMCLOTTaMNHPS1JIcYARBEAQRHIeJA+FEm4l4RbKqJlb544Bs/JTryeLfl54CAPjEjFeRbun0uJYoTW0+4s2A8Ao0QHzGNoYrvRO3blIRRDMB/pIGPOVRkP6cvKFNljRgb2g79fTS3+ivVsxHd59JfxnR0NYgPDfBokcnwyLSzAQXhjaCqCaoOEMQ1Yhlg8vABI42002U8pFE9kKj42AzTnl1MABg1aQcuuA10PQSGuKgCpLnXisQDOPNwi7Q2LrBADOxAQQXHDJ04iEsJ1k7zH8OL4Ghcn5FXpjhSZbIMEf2XSDmJquiEUUMcpNlRB1pRg4wgiAIgjAm6sSBCtoKj36izQL362TIekqIlMZJf310FFb0jcegui78n3OXceeEaWoLMd5MRDwWdoHGpkjTbnCuKaaayaVuakcyzGyAXWEmAw+8dLtpnJkOE7Na5TmPnV5aNfO1u08UjsZgaJPhI9KMfaey79hAaAxtlDZAVBNUnCGIKsdWaDBCiTaLSWiMXJzffnF6Hc6/YQl3jh+hISJOVsuEhg7D/jMiXr1FTAo0uv0uxAYQjuCIgnbYFWVcCQzZfueFmaBxZuGJDD02bjBV8cZnbjK/neBIM4IgCIIgDPBYgRp5tFlasQ3AX79OSV+ZCirHVdnOLP5z1RkAgP8z9xW0NJiY1ryO60xtDuPNXJnaAPsCjWvdVE20w61mCrMwI8Ifs04a8CocqoqTuv6c5oa2Oe/pQ3Mqf/zBd2dif3ez4kwTQxs7L6ChzatHp89IM+u0ATK0ETUIFWcIImG4aHApg930rB0KNtFmZegmSFVCg8ds2e3Qui4c1bAbALBoWgp1J74DIKs4WzYQkjm8VEJDtu0lNCA5V7ErIxzTLdUHwi3QAHZiA0i+4GiHW4EBuC/MiHgV6Sowze9WxZlBeO5OZJgjWymnExkR5CY7jDTzhBxgBEEQBOEOh9FmRWyjzYwx7dcJePfrlJ9788PjsTE7EqMa9uMTp6/kzglqags53kzcneG2wy7QAGa6yZR2JFs3tcP++vxqJsBdYSa0pAEvQ5tO/+h6M6n33XPRHcXtj/7pDOGoraEthh6dDItIM2epLwRRhVBxhiCqlbiizXxlKPNLZnWoml0C8gnbRpwycDPqUzmszQ7GwcE5rG8BzrqG/z2YTharzrMRGo6dYEA0BZqwijTsv7jwex2mRZkwCjMZblsnNK2bWXrFmYHbDkdk6AlTZHis4EtL9oUYaUYOMIIgCIKIh6DRZhUTjG2FR2bi8DK1eSYOyJCZ2njs+nV27c3ifzaeAQD42vEvoz7lytQme624zfBpahPJcNsuCzRRGduAZGgmwP91BDWzxVaYsY0z64D3v3kTQ5vqeYm2GTlMrcv/UpfuG4vtBwcqzjQ1tLFzdZFmHkVhL0Mbw0GkmYu0ATK0EdUGFWcIoh/jK9qMEVhoAC6ExumD3gUAPLFnImZsTuc/8bSlmlfohIaYo8wPukxylHX4dIKJyI7JCjSu3WCAP8EBRFesCfo5pj+fH+dX6IUZHj/ur3BFhjnq/lLyc33mJqcV+xMUaUYQBEEQhJ6wEwd8E2q/TkCvnwCVqY3x8wenYVduMCY17cK1J67lzkmaqc0jdUDEVYFGtz8MYxujXfJfGLj4nKjMbID/wowSnYHSJs6M788pQ9WfU3ZOiZuvfry4felNFwpHbQxt7HzT4m2AHp1RRZrxBPiOJ4ikQcUZgqgCXDW4DBxt5jxDWXbMXGjUI4tTB24GADx98HA8f8uJaMzlsHpgDie/j/8ZdQLCC1c5ypCcq9md4bZNGhu6coOFWaRhtBv+5/d8W2yKMmE4v4CAecn8gaBxZuGIDH/wf+/8ahlTkeEzN5kRR6QZBznACIIgCMIRYSQO+I02U5ngiwTp12nG/u1Z/GrL6QCAb57yCgDVz+LC1GbTs9Nx6oDsuIsCDRCNZgLsdJDNuX4JWpQBwivMiAROGrDpz6kztImoj486HDitPj+Ps6+3GWv3pD3ey6xHr9rQJpLwSDNKGyBqFCrOEEQC0U7C+Whw6Ty/M/QMZYZeaMxr3Y6h9d3I9DXhzY6RWL9yOI7cnlc8485bLJztJ0dZtzIgBCeYSIbbdlmg0e0HoivSeNGOcFfauCjKAMEEBuBwWb5snyzOjBGNyPBG/Ds3FRnsOS8yZBiKDEbUkWY85AAjCIIgiFCJPdqMp1V8IiYO8KY2Wb9O+cqY8ufycdVPHpqFA7kWzBnwLi6bv5474trUJtvWTYIzAqQOuC7QuDK2ha2bwsLm+r00U5iFGf64s6QBv/05VYY2/ZzEf39kYXH79N9epbp4yFfKqAxtOlObgx6djIREmhFENULFGYKoZhw4B+KLNgsuNE4buAkA8NzBw5AtfJ0t/NMJqM/l8M7gXhxz5gbNxTNMJp1NhEYITjAgvgINYCY2gGgEhytsr9XrdxBVYUYJ/w/GRmR4OR1tRAZPkFUzMuEgy02WfU/wx/wUghFapJkn5AAjCIIgCN9oo80CEH20GY/NWEYswpiZXLZvSuG3O/KrZ7575ksoXz0T1NTmIt6MJ8ICTZjGNqB6NBNgr5n8mtkiK8yYJg2I53tFPbshPTKHqxtKlYjFW0cLZwQxtHkRoEdnwiLNKG2AqEaoOEMQNYJfF5gxzqLNVEtoRVTRZiXOGJSfGH3mwKTivhWLx2D27hYAwPRLXhde4UdoiK+XCY0QnWAiYRRoXIkNIJmFGj/XZCIwoizMWLu/REzjzPzgctUMQ+YGUxVkVEQfaSaDHGAEQRAEERNhJw6IpjYZXtFmxv06VaY2HpXGko+ffvTQXBzMNePogRtx8dEuTW08OlOb17jVIHXAdYEGCGZsq3Zzmx8jm2szW6iFGdk+mzgzcOe6XTXz408uK26ffeuVkutl+DG0RdSjk0GRZgRhBRVnCCKhBGlwqXMauOqHYJ+hrHNjiEJDdqycsQ0HML05g75cCi8cPLzs2LK/HgsAWJ7uwlEnbdZdVAGV0DDNUWYkxAkGuHWDAfZFGqB8cB+l8AjyuUEFhovCjIgv95dNnJl4zJ3IMEcWvaFygOpEhki0kWbkACMIgiCIhBJgYs8z2ozhJ9qsDJWpTTeJatKvk6e0b0t7Cr/fdRoA4DvnvIzgq2d4eOOPqmen+Fk6U5shGeG5SYHGpbENCGZui7pY4/ezTTVTmGY28bhRYSZonJlpf047hqaz+HhTSQw8uXaCcEZYhjbHPTrb5J8SVqQZQdQKVJwhiGrHUGjoHAnG0WaBMpRFdPFDukJN/vjxA/JFl6WdI7E321x27uIXJ2B2phG5VApHXvWK8B42QkN2LIx4swgLNEC0YoNHHPy7EB+u3jMMgQFEJDLEfckQGd7YiAx2vteqO4MYEFkRWVaUEUWGRaQZOcAIgiAIIrmYJg7EH21mi8lYqZz/98g8dOSaMH/Qelx4lElcq0wLicY10dQm27YZt1qmDmTkp2mP2xRovI4x/Oommb4JqptcvacLzZSYwozXa4L057QztP3fT60sbl/xl0sApBTXVvuGNk84Q5uYNkCGNqJaoeIMQdQQoUebMSLLUGZUCo0TCsWZVw7Jc4dW3nU8AGDZsE7MOeFdg88wERr8uS7jzVTHEX2Bhh0zLdL4LdQwVELB9L8gmP4MXr+PWAsztu4vr8xkdyLDHH3D2vJz+OdMZMgQVuqpTtM5wAwhBxhBEARBxIvrxIHQos1C7dfptXpG3tNz4+oUbt5T6D1jvXpGh42pzXHqAKDv2ykeZ/hJHgizSCMSl2YCzH+GODSTESZJA/w/Iq8EDa90AnMGD8nic60PFp/fvWyqcEbCDW2MtsJjwEgzadoAGdqIGoeKMwRRRSSmwSXDSYayH6GRw4kD8qPMVw6N546Vzn3j2dLqmdnXvCq8h1+hEUa8mcFSfRcFGluxAY9jPK4ER1TYXK+fIpYqKzkj7IvM/cUIkpnsGi+RwX8f8M9ly/R95Cbz2yOFR5+4ijQjBxhBEARBhISDCb5A0WaMtGIbgP9+nQyb/nx5fvjYPHTmGnHckHacN3uTwStcm9p0+EwdANwVaIDgxjbAnbktKmyuNy4zG2AZAR00aUAkuKHt+58sGb2uvesi2K2aUZ3DozvfgaHNo0en30gzguhPUHGGIBKMdlKOd4FJhIatC8x5tFla3OFOaExs3IdxjQfQna3DGx1q5bPqntLqmdkn6nrP6AZOuhxl8Vx+2+FSfZGM8Fw2YBXPAcIVG0CyBYfttYUpMACfhRkeG/eXrciQkaRVMzoC5iYzKNKMIAiCIGoa22izwIkD1iYQVb9OnamNIVtxDOU561ekcEsm33vmu+81XT3jNYZ0ZWrjcVCgCZo8UOu6KQzNFGthRrbPJGkA3Lb4N2DTn1PPwEFZfHHQ/cXnd7w1QzjD1tDGjolGVz7SzKehjRFRpJlX2oAr4zJBJAEqzhBEPyCWaDMltkKDJ//8hAF5R9fizrHozDVAJTQWPT0Bc9jqmat1vWdM9tvkKNs6wbz2SXZnhOcu3GCuxAZQPrCPQ3T4/fywBQYQoDATxP2lizNzKzK8MREZ4nmiyAgpN5kizQiCIAiiKnEdbWaFSbQZwyjaTNevU4duLKXnh48cg65cI04csg5nH2liapPtk5naXMeb+SjQiARJHgC8NVEQ3RQHYWgmaI6rfrcZ4bmTwozfpAFdcdEGvaHtXz7dXtz+8L0XIBd41Yzp37yFoY0RV6QZj+a7nNIGiGqGijMEUWOYTv45jzYLTWjwjo/8YGNe63YAwOuHDvO8vFV351fPLDdePSPbZ7KE2a8TzNFSfcCuQBOV2GCEXawJ+v5RCAwgQF6yiftLddwrzkyGTny4WjXDkImIRu7RRGT46WWFwM4vhqvvU3KAEQRBEETIGK5ctUocEBETB0Lr16kbJ5msnimx7p06/HHvAgDAd9/r0tTGn+vH1BZCLDRgPib3a2wzPUdE1DSudVPQ9w/6c6s0U0bY57QwIztukzQAyXMHq2YG5/DVQfcUn/9x8UzhDD+GNnbcp6FNRtyRZpQ2QPQDqDhDEAknygaXgaLNGGnFdgU+J1PRiCNb8sWZtztHwUtoLHxmAo4KtHrGJkfZjxOMJ8ICDeBGbPgp1AByYRDkPz+Y/gy2AgPwJwIDu7+8mlny56nEsVdmclB0kRv8cdnfs0xkiGhyk2OKNPNsakkOMIIgCIKIhciizRjO+3UyvPp1qig/9wcPz0dXrgGnDF2NM2e5MrUFjTeD5ByvfZJDGcnxoNHQgJkWCqKZgPg1E+CmGOXKzGZdmPEbZ8YQ+3NCcZ6I3tD2T5/eUNz++P3vRTbnNT1rYmjzioP2mINJK7Yj6NHJQ2kDRH+DijMEUSt4OApCjTZzLjRkGcr5gUZzqgdTm3YDAJZ1qWZYywclK/96HIDC6pmTRKHhdxWAi3gzR0v1M8JzFwUawFxIBBUcUePq5zItzJjkW0fi/rKNM5MR1qoZ/m+2UTguizjkz3WUmxxypBlBEARBENESxYpUT1ObjMj6darOlxds1r5Tjz/vPRkA8N3zXhWO+jW1eZ3LbztOHRAPZRBe8oCNtqgW3WRzvV6aKfTCDI+uMCMiFmbE4qFJVB9ga2hrHQx8c+Bfi89vXnSkcIZLQ5uM8Ht06ggaaSZ+t5Ohjah2qDhDEDVI5NFmDN3Nu0gwoTGzeQ8aUjns7G3Ftl4xDFWxeua5iaXVM1eJq2dUmAqNIPFmsmOAswKNjRvMZZEmiYLDVlz4FRgZYZ+J4IvN/eWF61UzPDqRoTtHhePcZAvIAUYQBEEQyUA7QRdS4oAnptFmSmz7depWz8jGVuX7/vXh+ejO1eO0oStx2hFbhXOjMLU5Th2QHcoIz6M2tvHnJk032V6XKzMbELAwY7qqSvfvyW+cmQyPXjOfXV/c/vSD56LPeNWMH0Mbm2+RzcNwpBXbDEtDm0mkmScUaUb0E6g4QxBVQKKjzQJnKMv2qYXGkS15kbC8czTKl/GKqFfPHL3gXeFc1eDJb46yjROMx7RRoeL0jOQcV2ID8C844hAdfj7f5NwgAkN2XuDCjAq/cWZhrpoxyUQXs9PFiYaIcpPbCo9hRJrxkAOMIAiCIKInjsQBhlXigMrUZhoNbRJzVM7qdxpx+96TAAD/ev5Lhq/yY2pTxZvp8Glqkx3KCM9tNZOrIg1/flyFGj+f79LMBjgszMgioE0MbTItpIszg3Ae/1q9ZhoyPIevt5ZWzfz29dma947A0MYTsaFN9l1LhjaiP0LFGYKoJapKaIiDAjOhMbVpFwBgRddoq8tb+NxEzNnThFwqhSOuetnj7KA5yuJ2REv1ATcFGpdFGvE1YQmPIO9vWpRJTGGGJ2icGYT9stfyz8Xz/KDvE1U6R3ec4TM3meFTbFBTS4IgCIKoDVwkDvjq12kUbcajGu+oTG2y8wD5OKx8cvf7Dx+HrlwDTh26Cu+ds0l4HxtTm6oA4zfeTHYMCL1AI57HiEIzudZNQd/fb1EGUGum2AszELZFQxt/XKeRzLjxMyUd8ZG/nSdZNROxoU311ZKkSDOOKOIqCSJqqDhDEDWK02gz5xnKPKrJVbnQOLxxLwBgU0+aOw8wERpL7zgRdbkc3k734Liz1wmfFzRH2STeTPf+IS7VB/yJjTAEh/haV/8FuQYdXgIjI9lvEifnOy+Z33YRZyYr1IQRZ+YlMmS5yTqRIWKZm0yRZgRBEARRk8SROGCMabSZZ79OZmrTGdtMzDByVr3TiFv2nAoAuPG8FwDoVvPqTG0iKlObbNLbNnXAYYEmLGNbEnRT0M/Xofs9ZCT7ApnZxIMmvTlFTAqHpkkDZoa2kWOz+FzjfcXnt745S3FtQCINbXFFmmm+uyltgKgFqDhDEFWKjWMgcLSZCicZyiL6gcPYhv0AgC3F4ow5S186DHN2DQAATLrsNY+zbSamg8Sb8Th2gmWEfbIBMDtXhZfYAOJfim+DjbjwIzC8nF9AwLxk08JMkuLMeHQig48p9BIZDnKTGRRpRhAEQRD9jyQkDvCmNlmcEADP8Y4RonPe29T2L/fNx8FcM94zaD2uOi5MU5ur1AEg8cY2oHo0E2B+rS7MbOxcHl+FGVutLdNFfuPMvPn3j79d3L7urxcia71qRmZoY8cdGtoYFGlGEJFAxRmCqBKMG1wGiMxJRoaynqH1nQCATFZ2vrfQeO2Wk9GQy2HF4CxOu1ScUA0iNFzEm8mOAb6FBmDnBhPPZZiKDSCZhRrba/IqymQk+00FhtNl+SJiYUa2msu9yPDGr8hg+3TFGo/cZBkjhUdLKNKMIAiCIGqL2KLNGGnFdhlDINdP+n6d5Zivonm3vQ6/3H4WAOD7Zz2PulRWOMPWvGOySkHcJ8OrQAOD/ZJDGck5fo1t1Wxus7kuP2Y2IOLCjJ+kARE3hrbDJ/bhww2PFZ/fvnSm5r1sDG26ORRLQxsjQI9OHRRpRhByqDhDEDWMTGjIJg9DizZLc/vSlaeVMBcaA+u6AAAHs03c+eZCY9WSUZi9NT+Zmz5vMQBToeGVo8y/RnTd8K+PeKk+EL3YYMQpOvx8dmwCQzxBtyxftk/l/uL36f6dJnHVTCMqhYaP3GRZpJkjKNKMIAiCIKqDREabRWBqKyGbwPU2td34tznI5AZiZstWfORU01lRk4KLqJGCpg7wiONoBwUaW2Mb4N/cFrVu8vPZJpopI9mvShkQzw1UmOGxTRoAd1z3bxKK43p++uGFxe0rbr8YOaSEMxJgaEtL9lka2pxEmvFQpBnRD6DiDEH0EyKLNrMWGrIMZTXZwtdWCjn4FRrP/GYBWrJZrB2Qw7kffBt6bHOUZdteTrAQlupHITb8Fmpcio+g72vys2Rg//vj8V2YCer+so0zE/d5ORdNCCIydDjOTY460oyDHGAEQRAE4RbjCbs4o80Yvvp1yva5Xz2za0sK/7XpHADAd055Hk31ogYyMbV5xZvJXuMydUA85nEoA3fGNsBeMwHhaCYX7+tXMwHmfXwCF2ZsI6BN4sy8nov7KucKpkzvxRX1zxef371imuQ9GCaGNnY8RENbnJFmlDZA9DOoOEMQVUSQBpemxCs0ZK6wcqHRmc0PMlpTougyFxqb1qQxc9OI/KtOW4pUnW71jGq/bbyZzgkm+4yABRrZ4YziPJ3YUL2G4Udw8KiKK6b/+cXkujOwFxji+ZEWZkzcXyq83F9RrJphx8UJBT43WRQZ0eYmU1NLgiAIgqhNIo02c9avUxZZpMKfqe3H9xyBbdk0JjbuwufOURv78oQRb+aqQKPBRfKA7HweP+Y2nqCaya9uMr3ujGK/qZkNCKEwo0JnUuPxShowN7T96prnitvvu+0yIPCqGdOVcwENbQwHkWYEQaih4gxB1CLcJGAk0WamGcpKzIXGlt78Gx7euLewx5/QePTnJ2NwXxYbm1O48BNveFyfl9AwiTeDsG27VN9hgSYjOU8lNnSv4QkqOKLARlxkFMecCQzxBFfL8sOIM/OLV8FU1sjSBEe5yQyFyNBBTS0JgiAIorqwWakaerQZI3C0mYj71TMHdufwo3XnAgC+dexzGNDYLZxha2oTz9HFm5mmDsiOAZXja0fR0EGKNEDyNRNgfo0Z+DMA8vjqyynuN4mA9koaMNFJ9oa2OfN6cXZ9ab7hgVWTFa8F3BnaRHwa2hxGmvlKG+AMbeJ3OBnaiFqCijME0c9xEm3GcJahLCM/4FjTla8EHd26GUGExo6tgzF5bf69uo5diYZm9XLlQXXdmNeyDVcOXYEvjHgZ3x7zDP517JP4lzGP4aujnsf7hy7F9KYdAHIVr3XnBBO3AxRoAHuxwV4je51Ikgo1NteSgf7ns4kzcFKY8ev+ChpnJuJ61UyjcEwsospEhoiD3GRdURnRRJoRBEEQBBEO2ok7fgVrkqLN0ty+dOVpeYZA3a/TC3+mtv/52xSsz47G6IZ9+NqFSz0+w2s1gmm8WRipA6r34A7ZGNtcFWmSoJtsryUDf2Y28TVONRPgNmnA1tAmP+/Wi+8vbp/6+6vhf9WMuF+HzOxq8B0RQY9O3XcqAIo0I/olVJwhiConaINLGfFEm+mERmlA8uKh6QCAcwctRz1YHJk/oXH//5yEYb1ZbG1K4bLPvQoAqEMWs5p34fphy/HTwx7HE1Nux6vT/4jbJt2L7419Dp8buQjXpt/GFUNX4Kr0cnx8+CJ8f+yT+NvkW/Dw5F/huvQiNKIP5k6wmAs0Gcl+V0UaIHrR4efzMvAuytgIjNAKM6buL36faWFHtU8tRnK572je18+qGd1rWKSZRmSEnJscRssgfhYAANIPSURBVFNLcoARBEEQRHKIPdqMYWVqU03UqvpR8MfN6N6fww9WvBcA8JW5L2BYqzjYNVlxrSrIyExEtqkDDgs0qsMZxbmuNBMQrW7y+1kZJMDMpjon7KQBk2JNOaec1oV5dWsBAFs6BuP5DeM1Z3utmnFsaEsrthmWPTqDQGkDRH/HtgMvQRAxc3Euh/tSottCwpMAzlIfnou3sARHle2bhjVYjanS80dP2Yjtayfkb8KbUvmbcjvyN+ktyN+0tyF/ExczbdOoHJC1ghtLlT1BaWBROdB/+sCR2N07EOObMrgu/Tr+kDlGOKMBpgOmfZkBGPfWRIwctRHzDrTj/LZDOKphNwbXV37ulp4BWN09DO/2DMHu3lZ05prQkMpiWH0HpjZl8J7WzZjYlME/j3kM7x+6BF/dfAnW94wpXAv7qu1BfvDE75PRUfgdsPP5feL2PpQGXPx+xdtCckoG8gHZQciXOfOvg+K1KrwG/15uHZdCJeNx3EtsiXgKDPEkF8vyZeLVT5yZ/eqYVOpfDM6yWTXD79MVa8LPTdZBTS0JgiAIovaZt3s5Fg+fVbZvBt7BSsyUnj8BG7ERE+w/iGmnUagc56ahGa+K+kncx48ReS3SAHUxRa2jfnfveHx5xgTMatiIf7j0LXz9z8cLZ/C6RaZ/+P38PvEc3fnsOdM8vPZh55hoJvGYBNnhDNSaCVDrpgy3LXu9Cp3uMVnh4Eo3ZQzOsSnKAAELM6YFOkiOmyQN6F6v2id7XQ73nnFz8dkFN18qOSdBhjaGT0NblJFmBFFrUHGGIPoBqWeA3Onq45EJjYHQT3gD8BIanbkB+Pmuc/BPY/6Gr49+HAezTbh735EoFwAiJaExuC6Lo1q24ujWnTiudTuOfnAHWuv68qc1bwMA7O9rxKKOUXi9Ywze6BiFVV0jsT/bzL0XoyQqWlM5XDZ0Gb4w4hUc2bINt068DTdsvA7ruseiNFhrhFx8iEKDx3GBRnVKBv7EBnstQ/YeNoTtEssYnheqwADcLstXxUfInIhQHJe9Xr9qRl2csR1aqPLQgdLqOfFYuLnJFGlGEARBELXHfauAi6cXnjwC4Dyz18lMbYypWI01mKZ/gzbITW0y0jAYrw5B+ViS1womME3CNBJ7LjsnT19XFt9ZfA7uOPb3+PyMF/Ef6TnYkhmg+QwvMxp7b1HDqUxtJgUaRsgFGiC4bpK93oYoEgkyBudEambjj5munPKTNGBiaPNeLXbVlQcxLHUAAPDStvF4c9tozdkxGtoYNj06DQgz0ozSBohag4ozBFEDlAkND2QuMBlGQsOGNMoHoxnZSWZC47bMyZjbsh6XDF2MH4y7H9ekF+L+fUdiaecobO9tQVeuAc2pLoxoOISJjbsxoWkf2hozmNOyE1Oa91Z86sGGeiydksWq8cBLj5+MZ1ZPQrYs9ZEfHMmdYB25FP6UORqP7Z+GX42/F0e07MBPD7sL16y/AYdyAyvO9+cEA0Iv0AD+xQb/Hqr3iYOMxbmhCwwgnMKMGJ2nEgymmcl6gq2aaRAeTRHz1IV/wGnJSxzlJkcRaUYQBEEQhFv8JA54mdoYVokDKpiZTZY8wJvalIkDvCaQHZfhf/XMXx4ehdfnTsWxTWvwL1cswqd+d4rkvWVmG9WqGtVzoNzUpir0iLrJVjOJxxQfAckpmcJjWvKaqM1tLskYnmermQCHhRke2b9j10kDXqtqKvfX1edwx+xfFZ9ff/u5ktfFbGhjpCX7vHp0UqQZQTiFijMEUYVohQbvAos72ixEofGtrddiXfdIfGrE05jbugVzW7eof1CBDd2DsKRzJBZ2jMbrh0ZjbfdgXPr5O7F6QA5HTV6C7DcmC69QCY1KdvYNxCc2XYo7J/0ZU5t341MjXsJPdjKFxwshWyeYnwINEIvYEN+HIXu/MBA/1wuv1Vyq94usMKPCxP3lTmQw3KyaaeQe+Ux0MTdZ9u9XkZvMk9YcC5CbTA4wgiAIgug/RBZtxrCKNot+9Qz6gH984Uw8cuYafGT8K/iPw4/CineHCq9xHW+mSx3gCbFAozslA/W404+5DZr3CwPxs73wo5tM/0kaF2Zk0X0ukgZk+IkzA/7uk6VJkNvfmYU1e4Yp3h/QG9q8YsxkeBjaGPy/y4A9Op1HmnGQoY3oD9R5n0IQRC3g5UiQ3TQZumgfI2Tu9bTqZHHiVZaZCuRQh//dfSbOW/tF/HD7uXj+YBu29AxGV7YeANCZrce2noF4/dA43L13Bv5rx3x8ZtM5WLD6Opy/7ip8Y8spuD0zA2u608ihHtvvPQ4AsHTUIcxbIPt5VYOwyv27+wbg+9vOBgDcMOw1DK8/qHmdauCoW7ItYuswMjwlo3nNQe4/UzKS/4IS5D29rl/3fpEWZly4v3j8iQxG9KtmHOUmW4oN20gzHnKAEQRBEEQyKZvoe0R5mhXaFbasn11b4TkzhzCziGx8ki488uOasqGQZlxUhJ/YFU0w7LgK9bFHnx2ORw4ejcZUH/7ziuc178Hw6msojmH557LxqUoz8duySXtx26dmkp2WgZluMiWj+C8IQd/Tr24y1kxRFGa8kgZM4sy8GTgkh/8cc2vx+ZfvP1VyVpBVMwENbWnFtoiHoU1HYEOb5ruZDG1ELUIrZwiiH+I02qwNwTKUlatnZM8B8WtrZ99w3LJnAW7ZcwLYgCqFQ8ghBf3AvpwXH56C685bjLeGdWPqtS9j8Qterje9E+ypgxOxpGMs5rZuxZVDl+BXu0+C3gkmQ7eCRrRvOXCDQXJapvCY1rzWZjWNSMbzDPeYiKOMYr8v5xfgvjDj5f4SCZaZzHC7aoa9RiUyRAxzk9l2xJFm1NSSIAiCIJKHq2gz34kDpsj6dTLS0IyZZfoJMBu0iqtlDFbPAPjqg6fhrCuX4vzhb+O9R8/Do28eLrzGVbyZbf8Z1WfpVtAAVtHQutMyhce04nW8BqkG3RS7ZhKPBy3MgNvnN2nAe17h+19YX9z+t5ePw5YDg6Tn5aleQ5tJj04vyNBGEHlo5QxBVCmiY0DpAvOI1JE5GGROB0Yx8sfLMaG7uRsNRmUDC4PBMoAcmgpbsoZ6UO5bdPMCNORyeGdwH864WjbRaj6JDaRwR2Y2AOD8wey9bJ1ggHnTQ/G5DzeY7rQMvAWBn9U0UWF6bRmoXV++l+TbFGZk7+PH/RV1nBmPqcjwigpkucn83/0QWIkMhpibTJFmBEEQBEEYIDNfyFbQMjNH4MQBRtrkJN6MJes3AehXzzBkE8DqSeG3l7XidztPAwD85/lPoT6V9bhOr8lt3QqGHsk+2XOv1AHVGByoHKsHWEUDVLduCqqZAIdmNtvCjIj4b8Vl0oCcw9py+HLjXcXn33vqBMlZCTG0MUI0tAWNNCOI/gIVZwiiH5GIaLM0ty9deVoeV0JDhnow9M7iMZi9Of/ZrWe/iVRdn+Qs83izJw5MQV8uhSNadmBswz7J68TJdUie87go0EQgNoDygX0cosPm8zNwKDBsl+QDbtxfEJ57Ob10++ToCzMJFRkBc5NlUKQZQRAEQVQ3ka9g9RNtxjCKNlOZ2rzGZzL9ZGZq++e75yOTG4gjWzbjM2evkLxGHGN6GYNk8WbidlQFGtlxBbrTMrDXTVFjq9ky0Gsm5ykD/HOvwoxXBLS4T8Qmzky+/2cfeq24/dkHzsaB7ibpeXlcG9p4hlSeIpKW7DM0tOlwGWlGaQNEf4GKMwTRTzF1LDjPUGb4EhpsX3hC47Gfn4bBfVlsbAYu+dxCj88BdAWavdlWLO/KzxLPa31XOMdP/xlxmy/Q6PKUA4gNryJNxuytKgb+LsWH3/fOQH/9zgSGeNxVYUa2EsskD9nfqhkzIhQZMtLep1SIDA3OIs0IgiAIgogN7QrVEBIHmMlDtyJXCtNOMid7WvUiWb9OGW5Xz+zYUocfrTsfAPDdE5/BkOYuyVmqSXDdfq/+M17GopgKNC6KNEAyNRPguCjjqjDDYxIBHW6c2fEn9eCK+lIfpl8vPEpylktDm4iYNKIwt/HzMD4NbXFFmlHaAFGrUHGGIGqIyKPNTHEmNHTNLt0Ije3vDsbElYcBADLzVmPIiEOSs8xXGyzpyFeqZrdslrzez1J9cTtkscFOdSU4eGQCwc9/tnhdr5XAYC/gcVmYETFxf7kTGQz3q2bY8wAiwzY3eYzkHCCaSDONA4xEBkEQBEEkDz+JA6amDmPSkn1aU5tsBY1XuoB/U9t/3D0Nq/vGYWT9fnzn8jc0n8EwHYeqnruIhRb3m4zjHRdpMmZvVyRuzaS73sCaSRf/7FWYCTMCGh77RHK445zbi8/ed9tl6MuJ06281nFhaGN/84aGtrRi2wEUaUYQwaDiDEFUMX4m9WKNNmOkJfuUQkP1PDyhcc9PT8bY7ix2NdThfX/3rOL9zeLN1nQPBwBMbtylOS/OAk0IRZqM+VtGRgYhFWX8ZCUD5nnJXoW8MEVGHnerZvjzVAUdVW8pA5HBkBWBLaBIM4IgCIIgeLwmFHWmtiKmiQO6fp1p1Zubmtr446YGNrUJp/tADt9afAEA4PMzXsSU0aLWAPzFm8leb5M6ANgXaBwlD8jeTkYGydRNGZhdV6xmNn7bbwS0F3ZJAzdcuweT6rYDANbtT+OBVZMNPkOFl6FNhmb+xK+hLaoenTwUaUb0U6g4QxD9mEijzXRCQ4nQ/FsrNNytnunqaEL98/llyMsn7sH0o7d5XShUk+EbuwcDAA5r3K84z2upftgFGtk5HtgKjozd2zvB5vOtxQV7EY/O+SU+t81Ltl2Wr8I+zgxwuWrGpkgj+1v3EBlpydsEyE2mSDOCIAiCqF1cJg7I8B1txmBmk7TkmJWpDShfpczrpEbhOCTH9PvufGgknumcjeZUL/79/S9JXgO4iTcD7FIHALsCjex5yMY2RgbVpZmcmtmA8AozMq0tvrdN0oCcluZe3DTzpuLzK/90EYCUcJbfVTOqeQzVSrlkGNq8KDO0eXznMihtgKhlqDhDEDWGUmh4EHq0GUMmNNKVp9kLDRn+hcYjt8zGEfvr0ZNKYd7Hnpe8BjAZrGWyzQCAofWdmvNNHD78Pj8DWBM3WAhFGkZG8p8r/L6376KMbWSci8KMn+gH2ft4vVc5+lUzNiKDP082MaATGQrSkn0Bc5NlUFNLgiAIgqh+okocMDJ3tBUe/fbrVCIbR+lW0MiwM7WhD/jqY2ehL5fCZaPfxOmztxp8hsm4VFWg4bfjKNDIzvHApqiRkfznCr/v7UszsRfyeK1Icl2Ykb1XkKQBuW76/ufWF7fvXjUDi7aospRNUM1VODC0MdKaj1cZ2jRQpBlBBIeKMwRR5RgLDW6yMPRoM1Fo6OCFhlI7RCE0xIFQHd65bQEacjksH9qDM69UDSb08WZ7+1oAAEPru7j9ppPnuol524Gs7LkDscFe4mfQnnH0XxTXWnyxiJfAcFmY8ev+kp3nTS73HY9VM6boBIXqfP7vW3CA2YoMC6KONCMHGEEQBEFUB0GizYomEM2K3TIC9+uU9aEQV8+Ihhn/praFi1rxhz2nAAB+ctHTSEH2c9rGm8me61IHghZovJIHHBjb/L4s4+g/P9fpWzPZrkIKUpgRXxdd0sCUti78f/buPE6Ous7/+Lt7ztydkJADkkxIuG8PDpH7iKLcKCiu4uqqq7u6iu4KP491xWt3ddVVcb1PROSQWxC5BEFB7ksg3AkJ5Jrck8lM//6Y6aSmpo7vt+pb1dXdr+fjkcfMVFdV9wyTUO/+fL6f+vjEK7Z+/eGrDw/Yq84NbTUmI83CMNIMyBTFGaDFZTLazM//P/1CBo3RHrpzlvZ8aZIkqfPY+9XWaTuftl9rBoZWzowt96tDA1u3J1+q793mqkDjKGx4D014eCZSv6awgGHa+SW5KcwEncuk+yvZOLNoaUOG9/NayPDz3+Q2gIO5yUEYaQYAQPPLa7RZapXhj6FNbVHXTDmsnpF07mWv1ZrqGO035nn9/dFhhSrT8WZhb76HPe6iQON/zKTA4D/GQstkJin9lAHv51HTJbzHu5w0EO6777hz6+efvuV1enHNBKPjgjluaKupBHzOSDOgUCjOAE0o69FmqWco11QCttU1aIwu2vzuG4epMjCoxV0lnfzhu0LOG37xtnawc+vn49s2+B5NulTfu820QBNVOJDCw0aKxLBR9Qkezp7XJmDkUZhJ2/1V9FUzQcXX7Ocm1/4dCwoZjDQDAKB5FGK0WZr7ddZUwh7wXzdFjUhyu3rmpcXt+uozx0qSPn/wTRrbZXLdaXKdGjd1IK7gk6ZAE/S148Y27+H1KNY4e96wzJRVYSbo9yDud8PtOLOFx63TsW33bv36v+94dcBedWxoC7pHZxDHDW2MNAPsUZwBmkBeo82czFAOChq1i4VK2EnzChqjrVg2XjMe7JEkLd79eW0/N6iIIYWtVBgM7dIJuzg07QTzbjMp0ATtZxI2/Mel4A8eaU8bdD4nLzWqKGMbMPo9210XZhTydfKQURNdmMk6ZPgLrzEhoxLxUv0hw4BpyPBipBkAAM3PyWgzW0H366yJbGoLur6y6cq3b2r7yiW76LmBaZrZtlqfPi2smSXqmtT2/jNJtoUVaGwmD0iZNLb5T9PwuSlqP//XaQszitnmdtJAua2qXx/0i61fn3bxm7Vpi9n7C8Hq0NBW73t0AtiK4gyAVJ0M1jOUaxooaFz+rQPUs6mqdW1lHf2hWwyfZ+hCruyZuTxYLfn2SbtU37stSYEm6Gsp07ARdtokfzJ7MX6mHXI2AUOyK8x4xRX4/Pv5RRdmqtXPRj5uro4hIyxs9Ax/tP33KgAdYAAANL5CjTabHr9LZAe8pOCmNn9+yqaprW9jWefdc5wk6V8W3KoFs9aG7Gm6mtt0FFXUtXNcgcZm8oBNYxu5aeS+UV+7KMwENTnW9nM/aeDjH3hZk0pDUzEeWrG9Lnts54C9GqihzUImI82YNoAWR3EGwAi2o82cKXDQGBho18orXytJemj7DTr4jWEriEZ3gnn/ka16to88xvSCM2i7d1tUgSauG8w2bGRylV8nUd+PyUoi284vyb4wY7ssP6txZq5Dhl/EzS2DVvFXIl6qAecjzbwIGQAAFE6hR5vVJL5fZ9xNwrNvarvwd7N068bd1V3q1/+99WZ5E1C4qOtW06akpAUa/+euGttq+zZTZpLcN7O5Ksz496ntZ5OPzCYNTJ0xqK9M/eXWr8+6eKEkfxOmjRwb2moyukdnFiPNmDaAVkBxBmhSJl1gdZ+hbBw0/PIPGn+6doH2fmWsJGnySXerrT1s9cHIi7pSybNyZsRFW94FGv/+cRfLNWFFmrBzNJK4ooxtwJDSFWb853DU/fXZQ6VPvT7kuUJ86jDps4eb7RvINGSErYaLkWZucoTUI824qSUAAE2tLqPNairDH73XQaGXTd4HgprapCya2qSSPnjlMeqrduioSY/rrEOfDdnPdLxZ1DG1bSarz10XaEwb28L2byRRzXmmY7FtVylJ9vk4yQho335bc1N4Zvrqex7Z+vn37t1HD73vtIDcVNCGtpT36AzCSDMgPYozQJNI+2afk9FmSVWGP4YGjZDlurkGDekP/3ukJg4M6vnukk796J+MjunQ5q2fb6kO+h41Xarv/dq0QGOyXF8yCxuSWeBohNAR91rTBIwkhRkvm3nJ3n3CzjdsoCp9/nDpU6/T3I61+uWc3+tfp90nfyfj1lUznzpM+vyRQ8eNkGXI8P9dznBu8nAROc2/W4w0AwCgeYQ2tcVwMtqsZ/ijzf06ayphD3ivo8Ka2kwzkH1T26OPj9HXXjxGkvTVw2/QpDF9Ied2cf8Z/z4uCzRZNbY1QmaSkmemNM1sLgszCSYNbM1NhynI0e8Zp3e2/37r1+fu+/GQ3GSqsRraov4ti2toG4GRZsAIFGcAjOJ0tFkTBY2Xn5+gGQ/MkyQ9u/uLmrPrypBzb7vom9A2VJzZUi1pU7VN4W+imwQImwKNlD5s2AYO73FFCh0mr8lVwPA/HleYsZmXrJCvY0LG+bdLn75V+vyROvd1z2r/MSt09pS/6dFdf711z1GFmU/fLJ1/m5JxFTLqNzfZ5UgzAADQoPKYOGAq8f06Q3cclk1T23/8ei8tGpih6W29+s8z747Y03T8roupA2HbTCcPBD3eTJlJStfI5rqZrXaMd1vUf1svy0kDNeffPpSDPn/kqAJNW4f0kx0u2Pr1hwbfp5WfPykgNzV/Q1vkv20e3oa2uH9DgzBtAK2C4gzQxAox2ixOgwWNy791oBZskDaWyzroA7dE7Dl0kTehPFScWTfYqW2zaJPMvrUdfVWTJmzUtiUJHN5j8w4dG2X+3GkChmS+JD+og8tmXrL/2Kj9/Nul2V+5Sse9+6t6+g2vHbH93p1/s+2LyMKMy5Dh5zBk5DA32ctkpBkdYAAAFEvWEweimtpSTxyoieqIlzS6qc1/DZZtU9um9SX98+1vliS9d/adOniXVwyfy+T+M1H7uirQxE0esC3SmBZqipqb4lYExW2zbWbzHmPbqJhg0oB3//NvCyzQnPeJFdqxtFyStLg6Rd/99y8bFmbC0NAGYAjFGaCJ1CNo1LRK0KgOlvX8RYeorVrVo5M368gzH404f7/GDxdn1g5Evaa4pfphb9onLdDYhg3/ObxMAof3vK7DR5JzRr1m0+8/rvPLZEm+d1uSZfkK2M+777bH/neH2/X1O3+gs3/ykxGPdJcH9ORxf1KpFkCMCzNhTEJGuxKHjJpKzFPEsJ2b7MVNLQEAQI2TNyJ7hj+muV+nVVObf1s2TW3X3TJFF/cepHKpqu+efKPaSv4RzzWm4828j6WZOuDfFjQaWhqdmdI0tklmmcl7DtcFG9vzxuU8l81s0sj/blkWZgwKNr4Czc4HdOg/On669eGT3nepBj97m+GkgQI0tNWYNrRFSNLQNgIjzYBRKM4ACJTrDOUGCxr33jZbey2eJEkqH/2AxkzYHLrvhLahC76hlTOmnWBhBZqgi9AkBRr/12Fhw7QjrMa0UBN0zqR/TCUJF2HbXXR+ebeFFWbCni9u+2hRxcH+557XT5/5keUoM5chwysmZCSZmxzBaDSj6AADAKDZ5TlxYGuTiO1os5rK8EerpjYp76Y2SfqX37xOq6vjtM+YF/SxE6IaXGwKNP7taQs0/u22kweSFmnyzExJclPca4nbFtXM5v/adPxz1L5pCzO+/O0p0Dzxxq9s3XzBq96tv363t7Ea2oLeb4ni+B6djDQDolGcAZpcXUeb2aoMf6xb0Ah6ozn4PFd/7UhN6x/Uso6STv74LaFnn1jeIElaM9g5vCVNgca7T1yBJmy5vk3YCNrHex7XhRrXbMYImGwPChguCzN+7kLG91fuEXD+bV57zz36nx99KeCRrEKGlz90WIaMuLnJPcMf8xppRgcYAACF53LiQDHv1+m/tor6usZ9U9tLi9v1mceOlyR9Zt8/aMft1hkfG870OttlgSZJY1vU9pqiZaa4RjaTCQNSfDObd8qAi8KMIr4O2h43Mm/Y+bfpJ586fcSmc+8+urgNbTU5NLQx0gxwh+IM0GQKP0M596BhsnrGZNbrSL0rx2jMH4fe8H5szgrt//rg731q+9CF5/It3tflepZy0GP+bVmEjbjHavwX/FmED5vniAsXcUUZKfpnmLQwk133123rZynOwgkv6D+m/yV2v2Qhwy9oJGGKkGGBkWYAACBP9blfZ41pU5tJHkrW1Paty+bp7s0LNL68Sd858/aI8ycZbxb0mHdb1Jv8aSYP1I6zLdKYFmqyLNjYPIdtDkzbzObdblOYCWt4M580EGTWrC16V/vvt379thO/p97z7w3YM+rvQdDXGTW0VQJOVeeGthFoaAMCUZwBECpV54M/aMR1ZGQWNGyZB41rf7av9ljdoS2lknY86w61tY/uuJnavkmS9MqWsRHP6WKpftz+/u0uw4b3MdMl80HhI80fE0mKTHE/A3/AiFu9lLYwo5DtI1/na8Ys0akTn9bE4XsejSv36+m+CSHn2eb0ytMaU6qd12XIqD0etSw/6GuZhYwEc5MZaQYAALyymjgQxPn9OismO7tYPWPf1FYdkP7x2oXaUi3rhKkP6uQDo773NAWasDfmw97sD3rMf96ga3/T6QMucpPrzGR7z5uox73imtn8X2dRmPFLOM7MY/E/fHPr5w9O2V0X7f9W6VOHhe4/kv/vSpJVMyka2ipmrzIMI82A/FCcAVpAXWYop1WIoBFXoCnpjm8drfEDg3pmjHT6OX8ctce0tuGVMwOdvkfCLwKTL9WP29+/3TRshBUoXBVqsmIyZzmqKOM6YLhYlh/f/bVb10r9ZPZNOn/mX3TT/Ct1zrT7dcv8K7RT19rYYyWpszQQ8WiSkOHnDxmB1deRm9OGjIi5yYw0AwCgNRVutFnP8EfT+3VWAs5hdL/ONKtnvMyb2v76wDhdsOwoSdI3j7lBY7uirmnrWaAxmTwQ9HVQdqjtF5WJipibbB43aegznTKQtjDj4D4zHh/+uyUjvj79cwdJn7lF+vyRvgJN1Ioyl6tmor5Wrg1tjDQD3KI4AzSheo42i2QaNILUJWjEW/xURTPuny9JemrXpdp1/5EXcbWxZq9sGaP0S/XjCjQ2y/Vtwob//P59TQNH1sHD9Hmi9nEVMLzbo4pnSZflh//ubN++TuXS0Odjy1v0nimPa1w5qhi4zfn7v1W9g11yGzJMVs2EdIBVIk4bFzIimIYML0aaAQAAE7ndr7MmdvSr9zrL4E3eTJrapPMu2k8vDmyn2e0r9IUz74t70RGyLNBEbU/a2Obd1yY3ZcUmn6XJTNLozORiyoB/37BzRG33PzZSuatN39jpoq1f/8etB+rJlZOH7jXz6Zs9BRrT9xnafR+L3dAWhJFmQHYozgCIFNcB4XSGck1l+GPioJHn6hnpt//7Wu26rqy+ckl7vPePUmlw62PTRt1zJssCjXdbXDeY/ziTsBF0Dv/+JkEiKBDYdGvZHBd0jiBpA0btHN7HvNuCAoj7woy0RY/3VTQ4/Ffvd2tmy8bmU/aUPnWE4d6uV80EzE32qgx/jCrievUMf7QMGV50gAEA0FpcTBwwlep+nUHXQ5Xhj8ZNbd6vs29qW9db1r/86XhJ0ofm/lH777YqYu+4N9JdFmhsJg9IwfnAtEgTtn/Ufq5zk83z+kVlwaBz1NgUT+IKM0nuzWln9r/sP+LrL91+wLYvRhRoXu/ZK6yhLazBzS+jhrYaf0NbBFf36HT1byXQ7CjOAC2icDOUw4JGkMrwx4KunpHKevB7R6h7cFBPjqvq1A/ftfWRqW1D95xZvqXbs389CzRR2yWzsFE7R1zgSNrx5Wq1jck5bIoyUQHDZEm+Ah4L2+Z/jWHPO9rLW8bqxnU7SpLeMDH+7+LbnjtGe/ztDB321Em65MGp0ucP9wQNk5CR8aqZ2CKtGdORi4w0AwCgdRR+tFmcyvDHBlk9c+lNM3X16n3VURrQj07+vdrbN0e8ZvOVDsFMCzRBj/kfD5o84KpIU7TcFMRmikKSZjYXhZko5r9Lk889SK8vD/1dHqyW9PofnaFNW3y/z+ffJn36Vk9uMnnvwNWqGcuGtrD3WQp6j06mDaAVUZwBEMj5aLOe4Y9xQcM/Q7lBgsaTD26vnR7dQZL04r7Pau7uy9WhAU1u75MkvTIQsgw5UJoCTdLl+knCRu08YYHDe1w9luiH6Vfw6zYtyiTt/HI5LznItsd/vHK3mH23OXTcS5JKWj4wRvrCnZ6gYXKzS9uQEfZ1TMioyXFushcjzQAAgA3r+3WmHW1WUwnaWJzVM5L0j786TKsHx2m/rmf1+bc9ErO3y6a2oO1hBRqTyQNSfHHCe27b3JQVm3xmW5RJ2swmRf+3qW2P+m/hYJzZ/ztUMwbX6csd35ckfesv++qOF3YI3vn82z256eDhjXVcNZO2oc3xPTpHCGloAzCE4gzQpFy+GehktFlalaCNxQoal3ztEM3fIG0ol/WaD96s7dqHbsDeXy2pd6DLt3fapfpB20yX60eFjbRFmqjA4T+H7VL7NEv0o15f2PH+bbYBIOi/R9LCjF/0788Dm6KWoo303imPaWLZ07VYCxpttUsElyHD/3czJGTU1EJGJeapwqSYm8xIMwAAWpPr0WaZ36+zEnBsYLNLMZraXnx5jD5y19B4s4/P+4MO2vOVwP22yaNAk6axTQrPTGlyUz0yU9zry6KZzft51oWZaLO3bNS/tl+kHUvL9ezais77w+tD9hz+PT//dunTf/TkpjAFbWiLkLahzeTfSKYNAEMozgCtqp6jzZo0aAwOtmvRjw5R12BVf5swoLeedY8k6ZUtY1VVKeAI1wWaoO1BF71h5wp7HTZFmtr5bIo1fkmDRNTriHqeuO1BASOu2GUSMPzbo1bG2HV/lTU4aluYzvKgDhq7TCN+x8//s/S5O2KOzCBk1FQinjYsZEQwHWkWh5FmAAA0n6xHmxlLer/OmsimFtOmttq2bJvafvb7ebp81avUXhrUz0+6Tt2dmyyfz3WBxrstKjPZNrZFbfeesyiZKaogY5IH435e3sf8++RRmAnPW1PeMFevLT2us9tvkCS99/Kjtb6/M2BP39+N8+8czk0N2NDWM/yxIPfoZNoAWhXFGaCF2L5p6Gy0WSGChsnqmaDHwi6mgrc/fM9szX9sliRp84yXJUnLtkS1teRVoEkbNqToi3LTMWJJgocJ/3PEFWRMg1OSi38XhZnk3V9jSlt0405XhT5+zpKD9dbnjtUnlhyk3oFOrRto19ObJ4fsnXPI8D7s/3sfdONbL4cjzbwhg5FmAAAgCevRZlHimtqCJGpq8+ehbJraJOkffnGIlg1WtKBjqb521n2KblSKuwbOokAT1IwV91xSfNYwbXDLKzeFsc1MJiuNwprZXBdmgkSMMxvXrqn7TtJXOr4nSfrBvXvpD8/MjTmfFJ+ROtSMDW2MNAPcojgDNLGsRps1ZtDwCwoaUnDQiAofo/3mqwdr1/VlTVg39PXLo0aa+cWvhAh/PI+w4SpweJ/D9Z8oca/LNGDE/cxMO7/825MVZoZWx4z+XZnctkkzOsL/G0wo9+vhTdvpmrU9ev1TJ+uQp07RU5srnj0KHDJsZDk3GQAANDXb0WZ1u19nTWX4Y6KmtqCvXa6eCd6+YuUYvf+PJ0qS/nHOH3Xc/ktlV6Bx3dRmO3kgSWOb/7FGyk22mUkBj/k/jyvoJC3M2OXr2f+0nz7afol2Ki/V4vUT9PEbwu69afp3wnS/xm5o82KkGZAcxRmgldVjhnLP8MekQSOQi9UzQZ1gQeI7waqDnXrggiM0fe3QaKlJ+65WfMHFz2ZJdpICjf/zqIvq2uNJA0eS5fVpmRZk0gYMf+eXdx+XhZltprev10VzrtPDu16o3827Wu+f8ogqbX2SpFnt63TDTlcHHidJ/770Nbp8zbytXw+orH55i4emv/tRgv6O+c8bEzJqKsMf6zQ32YuRZgAAwEQh7tdZY9TU1uH7Os3qGbumtitumaofLx+6r8ePj79Ok8b0qT4FmqAGq7Dr/KRFmrSFmiyYZDaTzCTF/5xqH6N+1v7t/m1Bz+VnV5iZdup87VNapH9ou0aS9I9XHqXevu7IY4Y0V0NbEEaaAfmgOAO0mDSjzZzMUE6rMvwx9eqZuIsjyaQQE+aph7fXrOfGS5KenrNB+x62WPZL9V0XaEy6wdIUaWxvTOmK6bmjHrcNWmGdX1kVZrY99pGpD2ifMSskSXM61+kj0x7S7+ZdrXdO/pvePeVvKgfd3mjYxb3z1V9tC99hKxchY4xGhwrDkBFZjI3QM/wxxb83LkeaAQCAxpDVG4N1uV+n8XVU0tUz7prapHZ9+Bev0bMD22tW+ypd8M4/R5yvJm2Bpj/gGO9xcUUF0+cMeg6bvFKvzOTf1yusKFOvZrag5zPXPqlTU3Yeq//p+I7aSlVd+NCuuuqJ+WF7ez532dCWYtVMTWX4Y4qGNkaaAfVDcQZocnkGDevRZnUJGmEhwv0c5b6nh573lQkl7XDWnzRm3Ga5naUsmRdo4rrBXBRpJLsgEXYTS9s/ps/ht0XhAcO/n/cxm84v/2NB57R5TNq7e7kk6RNLDtK/vXSQHt9U0cS2fn1y+/t01uTo6uu+3St8W7IKGf5tXoYhoyYuZESo90gzOsAAAGhcdRttlvZ+nTWV4Y8jrrPCLrpsVs8ESd7Utq63U+/+w4kaqJb0thn36IzDXlT81IE0BZra41E5wHVjW5rclEdmins9pkWZPJvZ/I/HZeXRdvjAPvp/7b/U/PJLWrJ+vP75uqNC9oz7va/TqpkMGtoYaQbkj+IM0OrqOdosqcrwx9igEbRM38t7kWTCbo7y9PZNkqT+sVU931XSKf/v5uFHXC7V9+8TVqAJe8x/0Rx3MZw0cOS1TN/kOaPChW3AkMIDRu0xm067+P/eAxpaGrNqoEtXrenR6c8dp88te3XAc2+zcXBotczbKt4r4SxDhsEqmTC1kFEx3D+juclejDQDAAA2kow2c3K/zorBi5MUPnEgzeqZ5E1tt9w5Ud9cfLQk6TuHX6uZ0zYp3VhoyTxXmVybp21sC9rPq4i5KaqRzfT7D2pmk8L/2+RXmNnu+B4dWb5Pf9d+oyTpXZct1MqNJpkly1UzXsVuaGOkGeAOxRmgBWU52iwyaESNGnIeNPJcPROkqu3bN0iSNjy4iyTpwVlrdezbH4k5TsqvQBO22iPqYjtof/++pqEjTfiwPZdpuFDAflEBI6rza0vIdv/zmzy2bZ+VW7okSdOGi3+DKuv9Ux4N2HfIhxcfone9MNQF9oYJL2hsKei8eayaCQn+tU8rEaeeGvGYn6O5yYw0AwCgtdRrtJnT+3XWRDa7RN2v07vde11nMhLay66pTWrXJ3++lx7ZPFtTyuv047/7o6Sq7KcOuCzQJGlsC8oQNs1tNfXKTVGvLWpCQNDjcUWuoMeCzut/LOhxv+jHy2PbNXfvDv1nx/9Jkv7nrlfpxqfnhuxt2tDmb2Lzfh5036aaqDHQERq0oQ3AaBRngBZQ7xnKgXqGP2YWNLyyWj0T3gk2odyvseUBSdLlv91b+y4eKhqtP/oR7bhgldwv1ffvU7vwjRvblTRsePdJUqjxymJZ/hYlCxe2AcO/3fuY+8KMJD21eZIkabeuVVsf6SwNBuw/5MSJz2pp/1g9v3m8OsuDOmDsywr/3U0aMryiVs0EdIB5+f9+T/M9Xu+5yRboAAMAoMlkNNoskMn982ya2qzu15l3U9uQzZvL+rtLF6qv2q6FEx7SP55cu5arV4HG/5hJY5t/X+8+aTOTlH1u8osrMHn38x/j3y/qv1Wawoxt4Uaa+uYefaXje5pWWqOHXp6qc298fcieYY1nYZMG/PtGFXaiVs3UWDa02ajTPTqZNgCMRnEGQL6jzdIEjSBOg4a/myUsaMQXcqYPr5rpHehUX7Vdl37xCM3bVNWatrJe+9GbVSoPKvsCjXefuG6wqLARV6QJ2s//muKKJWmZnN8mXJgGDBeFmaDX6bft+Ic2TZEk7T1m5dZtJ609LvRsx0xYrBt3ukpzOtdJkuZ3rQvZMypkBBVpvOKW5qdYNWOiZ/hj1nOT6QADAKClpHkj0fsGZib364zjZPWMVzZNbd597ntioj7/6LGSpP/c52rtsqBv+LEiFGj8j0dlptq+SZvb8spMcbkp7Niw/Wya2WpTBlwWZuJNPGCG3rvgHh3Tdp/6Btp01qVvVN9AmtVgUnRWslk1s/VVRr+UsIa2sJFmlg1t3KMTyBfFGaBF1W20WRDToFEZ/phZ0EgiOGjURpot2zJWUrs2re/U0z88TGMGB/XkuEGdfu7tw3vWs0DjP960QFHb1zZw+F+jyz9hTAKQf9+wr6MChvdxkxVHQcfFnXfIfRuHrr737l6hceV+SYM67pN/0Ns/0TZiv8OfOkmnPbtQ923cTp3lbStrHhku7gwxDRl+GYSMmsrwxxzmJnulmZtMBxgAAM0hjzcKrZvaehw9cUGb2kZq15cu3VV3rpuv8aVN+unpv1dbZ+2xLAo0QdfhcStfbBrbws7n3TcqN9UjMyVZCZRHM1uSHDxS58xxetXh0qfafyFJ+uSNr9dDL4d1grq6P2fQ8c3R0OYV2tAGIBbFGaBFFHqGclqFChpD26e3Dy0hX7al9nzteuTu6Zp9z06SpL8tWKaD3/TM8GN5F2hMw4b/saCvw87p3de0YONC3POZdHzJ97VJwDDp/PKf17+PWcB4sX+8nt08QR2lqg4au0ynfuRuPTRls9Qm/cPsA3RFb49Of/Y4vTIwRo/1TdZZzx+rtz53rP775X313heO0l0bglqn6hAyairDH8eNfihSwrnJXnSAAQAAKw5Gm3lZN7XVmN6vswGa2vwGq2W94xdHad1gtw7qekJfedfDhudPUqDx7+PPRVHPkTQzmeSmrCXJTN7jwr5O0szmPy7J40H7jFTqbNPME+foG90XaExps36/aI6+cderQvaOuo9m3L5+zdvQZvRvneffTBragGAUZwAMKcoM5UIGDfs5yrWVMy/3jx2x/bLvvFr7rujSYKmk7pPv1vQdayOmsirQpAkbLgOH/7g0ASTsHGHnMQ0X/m1h3XT+cwc9ZluYsXP7+hmSpFN2flLP7Pe8JGm3x2fqjht30rlLD9KjfVNG7P/wpu30o1V7608bZnm2ppkN7iBkBBVVK8Mfw0JGFMO5yYw0AwAANrIcbWasZ/ij7f06/QrY1DZau55eNkEfvPV4SdI5s27QW45eNvyYbWYKOsa2sS3oOWwb22xzk23eCeMyM8UVZUya2aTw4xTwuAweD9pntO0WztHHp16tfcrPaMXGbp19xUJVVYo9Lvr+nN59wu7P6VXMhrY4NLQB2aI4A7SwQsxQ7hn+WOigESZ8jvJ27ZskSSsGun37lHT1F47WzM1VvdJR1mH/doPK5YHhx23frDcpNoTtF/Z8URfaQY+bBA6T7ysqgNgGkqjnjgoXpp1w3udQyGN+bkPGH9cP/YXZr22Z+iXt2duuX3/ldaH7jxY2qkwKnplckJCR0dxkRpoBAICaQo42C2Lb1Bao/k1tYX5+W48ueP5QSdIPDrlMu/ckbWoLOsY2M6Ut0oRt8x5b79wU9jxR2+KKMop4PPvMVDNu7+30xr1e1AfbrpQkvf+qY7Rk7YSQvW3GmZn8Phe/oY2RZkB9UZwB4JSToFGTa9AIK8gk6wSb1DZ04bh6oMv3SLt6V47Vsp++Xl2DVT02YVBnfsp7NRO36sTP5KI2aD+bsBG0X9TFuknoMA0fcUzOGfa6TIouNgHD+5j/eYLOH/R41Hm2uWfDVG1pk6aslfZZNqgbv3isqpH/S7edmRwng5DhZxIueoY/ZhAykqIDDACAJmc5cSCz+3Xaqgx/LFhTW7ChfT78s1fpzo0LNLG0UZe9/WqNH1e7Rs6zQBPV2OY/T9rM5D1H1pkpTW4ynTBQv2a2mo6p3drr2LH6Rse3VS5V9X9/3VuXPrZLyN5Jxpk1d0ObFyPNgGxQnAFaSOybhvWeoVy7eDCYjzpCZfhjqqDh5b3ZuY1t+1fKfZKCijNDHrxjpmb9ZYEk6aF5K3TMmY96Hs2zQGMSNmwDh3d7XMdWWFAw/RMm6vnDvieTokxUwIhbyZT0v9NoJ//rn/XwnKFl+PMum6sVy6Ku1BskZNS+9t+T0z832XHI8GKkGQAACOLqjcVE9+usNaH0RBxYiNUz7sebbRlo0+k/XqilAxXt1rFEP3vPHZJqmTbLAk1WjW1RhZqo3JRHZjIpIplmwTTNbEH7JMtM5e427XDqXH133Lc0ubRO9yyZro9cd2TscUNMxpnFyaGhzUTP8EfDhjavrEaaAdiG4gzQ4vKaoRz5P/0egyfLdfWMV7I5ypPahoozvQOdAa9r6Nirvru/9lk6XtVSSeuPfVgL9l7h2SdJgSarsFHbJ0ng8D9musTehsn504QLyX3ACNonbL+RDj3pKT2621I9sNNQcaZnycbYY7YxDRlRhcmChYwEGGkGAACi1Gu0mbG6NrV5rwPTN7XFWfJKl864+mT1V9t0yuS/6twzvRdgrgo0eTW2effJOzclzUySXVHGppktbL+ox4P2CVCSpp64kz437RLtV16klRu7dfrFb1bfQNxEAZPtJg1t7b79M2po8xdmEza0eTm5R2cMpg0A21CcAeCccdAIGm3mDxomo42kDFfPBImfo7y1ODMYvHKmtt+l/3GUdtoorW0ra7cP3ayxEzd79rEt0AQdY7OfiyKNd5+oQBEUDpL+MTm/l21RxnXACNonbL+R5u2xQoMn3quBUknrpwzdz+jVY5arszQQckTSkCHP5wUKGUEYaQYAAPJWr9FmPYavTxrd1BapHk1t5uPNbrt/ij754PGSpP/Y5Vodd/BKzz4uCjRR+7lsbLMp1ATt4zozRb2GoNfsupkto8KMpMrhO+rM+Q/pne2/12BVOuuyN+q53kkhe5tMGggb6ef/PGwfqcgNbc7v0clIM8AYxRmgxbgebdaaQSPIyKAxrjx00bhuMPqqqm9Dpx76xlGaNDCoZ7ulk/79ekmDnj1MlqH7Je0Gq+2b5Lmjls1n2QEW9hx+ScKFbcCo7Ru3X7KQMW5Cn/b88M3qbStr3qaqfvLN4/Tylm51lwe0/5jlAUfYhIyo7q+gY+oUMmrF257RD7mcm8xIMwAA4JfFaDMn9+v0N7WFNbf5m2Iyb2pzN95Mkr722wW6+JXXqL00qF8cc5nmzu7z7JNlgSZoX9vGNttCTV6ZyaYgk6SZTTL7OZtOgTAzbo8peu1B/fpSxw8lSeffdqB+99Q8w6PjJg1ErfoKW0lGQxuAcBRnABRjhnKQzINGh8KDRtg9NryP+Y/Ztm/H8EqGzYNtIefYduwzj09R25X7q71a1UNT+nTmuX/07VePbjCTsOHdNyxwmASVLDrBwl6DTbhQyH5pOr+G9nn1mGX62szbdOnca/SjHW/Qh7Z7SHt0rVS4QZ34uRv0bLdUGRjUI988WuvWdOvO9TMkSQePXRpxrJRuZrLjkFFjGjKi1HFuMh1gAAA0t6QrYkc0eVhIfL9OZ9I2tXlXXJuyGYdW0tk/eJ0e7ttR08prdMlZ16lrjGlTm+S+QFPb1zY3+c9Zz8wU9RriCkQmzWwupwyE7TtS5/SxmvOG7XVBx9c1ttSnGxbN1eduPTjiCJfjzLwK0NAWoAj36GTaADASxRkAmfB2XBQzaPi3hQWQ2td2QaOjNBQU+mP/mR1eqn/FztrpgTmSpEd2eVkLz3rYt18eYSNNkca7f1ToMDlPElHnN+k+M1lVYxowavsG73PqxKf08zm/1xsmPq/du1fpoHEv60NTH9ElPTfo57Nv1JHjFmvbjU6HvO3c2/XQlD61V6vSFa/W049OkSTduWGognnQ2GW+5ytwyKiEvLQwOcxNdoGQAQBAiwmZOOAVNnEgVVNbT8SB/iaXsPt11r522tQWxO3qmY2b23XKz96o1YNj9Zqup/Xd99zj2y9pZrKZPGAyfcC2uc17XJ6ZKW1usslMtf39+8btE7XvSOWx7Zp26k76r7E/0PzyS3q+d4LefunxGqyGZfIk48yiFKyhrWf4Y8qGNu7RCWSL4gzQgvIebWasJ+Ix50Gjxn+BFBc0zOYoby3ObL0QjL+ou+zrB2rvpeNVLZW06uhHtd+hi317uCzQpO0IS1qo8T6Pyz+2z+9/Hf7jovbxP0/cvtv2mdm+Xp+afrck6be9PfrAi4fps0tfo+vXzlZ/taxXj12ub+/4R/109k1bV9Icd9ZjemSXoeLLvAdm6/Yr52893z0bhv4i7N69Wh2q3XcmaciIGtmXc8iI0mOwjxhpBgAA3MvijcZUTST++3WmFnct576pzezaddt+Ty2ZoHdef5IGqyWdvd3t+sfTXvTtlyQzhR2XtLEtbJ+g89tMA8gyM9nkpqjXG8RkyoAC9ok7r0e5pGknz9f7J9+kN7X9RZsHynrLb96sFRtNl6GYjjMLa2gLum9nDQ1tAMJRnAEgqYAzlHMJGlHhw39xFXX/Dclf1KkVZ7aM6NKJDhqS9JvPHqNd1pW1oVzWdu+8QzPmrPPt66obLGzfuP1tA4f3ONMLf1tJzm9SlJFcB4wzK0+ouzyguzdM03lLD9Rt62fpN70L9NElh+joRSfo+yt216bBNr127Cu6pOcGfXffm1Q94EENlkraa9k4Xf71A0ecb8mWcVoz0KGO0qDmda1VdLiNCxneY+ocMjKam8xIMwAAYMPFaDNvE0gcJ/frzHT1jHebm6a20dvjXfWXmfriY0dJkv5nz8t18GvX+/bIo0ATtb/tBIKgY4uUmUyKMqbNbLX9TfYL23e0KUfP1nFzn9a57RdKkj56/RH6y+KokRw2v3NhhZmg/aLGAJq8D+HRhA1tTBsARqM4AyAXxQwaNXFBQ76vg8KFV0l9g22SpO5SyC6jDJ1rS1+7bvvCcZqxeVDLOso68JPXa+yEzb598wwbNkUam6X3JjORTf+YCgsX+QWMQ8YtkST9evUCSSN/OZYPjNH/LN9Xxz9zvK7snStJOmzTy/rK96t6381VXfnZo0YdI5X09Oah3+k5HWtDXq9kN84saL+okBG03aBDrTL8cVrUTsNquapn9EOZzU22QMgAAKBFGYw283I22ixIXZra/NeJYU1tbsebSdJnfrOPfrdyT3WV+vXrN1yqGXP9+5quWjE5zmQiQNxzNEpm8r7WoO8jar+g1x63f9zPNt74fafqta/epG93fEPtpUH99P499J279404wrahLeo8Qfep9Te0BZ2LhjYAQyjOAC0qy9FmzR00/IKDxvrBoW1jy/0yCyDbLH9pvJZ+/1CNHxjUorFVnfD561QuD/j2cl2gSVOkSVuoyVLQ63ERrvzHBO070pjSFu3StVqS9NeN4RWJpVvG6Uv9++qbbyvrbztI3f3SMXcN6PKZN+jo8S/Kfz+a9YNDv19jyv6/Q0lnJtuGDMtVM+NG7yJpdPHVQcjwsp6bzEgzAAAQIOvRZtZNbVGLA0yb2gKZNLV5RRVkTNiNN6uqpDN/cISe2TxNs8vLddlZ12rc1Dbfvqar/P1sJw9EHVM7rlEyU9pGttqxQceY7Be1/2hj5k/Sbgsn6ked/6kJpU26+Zkd9Q9XHavRTW01JoWZoP2Dxpn59/P//bCd3hGgMvyxzg1tALJDcQbAVvWYoZx50IhcPWOyAsB/4RU3R3lo3/WDnUMvpxx0wRnfCfbw3TPVecWr1FGt6pFKv952/u8lDfr2T9MN5rJI4z02LHTkFTzCnjOuayxpwDAPGXM7V6m9VNWKLV1atmVs6DO1d27Rws9er9t7yvrq20r6atueWtY/RrM71+t/d7hdF8+9QQvHP6+u0tBzjy8PfRyoegNImpnJ8u2XQciwUbC5yXSAAQDQWqxWyDq+X2fipjaTsUdeoy7f8mtqSzPerHdjl0762fFaO9itgzse16Xv+b06x7sq0IQd28yZSUqfmVyMMTP7OXTtMF5zTtpBP+76b80qrdRjr0zRqRefqP5B/++AiaiGyqgVXv4M1a749xkar6GNkWZAdijOALDmcobyCFkEjVHCbmwZ9Hlc91bQHOUhtZUzE9s2e/aJOtfo57ztygXa4Y5dJEkPzFqnt/3bHQH7R1+4Tiz36ejxT+tfp/1F39nh97ps7m91+/wLdcf8C3XNvF/r2ztcr7+f8oB26FjjOSpu2bvpRXnYMnYXAcTkXCbL+F0HjPCQsV3bJknSK1uiiheDOvPzN+jRiQPqGqyq79cH6YeP7q03PXO8vrtiD20cbNNe3av0Pzv8SX9acLl+23Od9h2zQoNV6f7A1Tim3V9etiGjxiJk1L72z02O+vvdE/GYh5O5yRYIGQAAtLgUK2zj7tdpLKqprSbz1TOmTW02BZr4zPTQ4u100sUna1O1XQu779fP33+7yp3+t7rSFmhcF2m8x5tmpiS5yTYzZVGUsS3MmOmYOkYzTp+nb4/9tvYqP6uX14/Rmy48Was3dUccZZLJw34X/RkqroAT17xGQxuAIRRngBaWdLRZmLDRZkH/s4/r2Ngq86CRZvWM18hOsGVbhl7EjPaNvn1MbNvv6h/sq90emyVJemj3ZTr5ffcG7D/yInZq2wb93eTHdNGc6/SnBb/R/+5wq86e8oiOGP+idutepSntfZrc3qd5nWt05Pjn9fFpf9Hvd/q1vj7r95rb0Tt8FpPZxCYhweZi3+ZP3POYBKU8AkbtGGm79qHizIqBrtA9z/zsLbp/+gaVqlVtd/PuuvvGoeHZG6od+ubyfXT00yfqO8v31NL+MRpTHtAuXb0arErfXL6flmwZP3yWJOPM0oSMDFbN1IqxQX//85ibzEgzAAAQIes3HlPdr9Pf1Gba3JZ69Yz3c5OCjKn4As3Nf5ult111krZUy3rr+Dt1wfvvlUYtnrDJLkFsRx/7n9e0UOMyN5k8VxTTn1nYsWHPbbp/sLaJndr+rQv0ufEX6ui2+7Sxv10n/OpkPbO6EnGUq3Fm/nNkuGqm9nUdGtoA5CfJ/xkBNLGrnpRO2Dm788/XU1qkBSO2bb/TC3r56dkjd+yR9GzISaZKWu75GGeMpI2RG3zbvJ+3a9uFYu2fzP6Qr7d5vn+ypOc1r3O1hi7Kwsabxd+c8ZKvHKwzv3CjHt6hV88evEhH9Y7VTb/ebcQ+ZW3WYeNe1tsm/02vG7tUbaVtb2Av6puov2yYocf7Juul/olatmWsqippStsm7dq9UoePe1EHjV2i4yY8q8PHvaCvvHKQLlq9u4bm9NZeX9Sb+94L6bj/rcR/vyOfz3T/OKYX+0nmHpsdM2Z4/FhtVZXfWz9xhx6et1KStOv9s3XJz/catc/qgS59a8Xe+taKvbRzZ69mdWzSos2T9GL/hOE9ko4z82qMkFGPucl0gAEA0JpOqFZ1VSnsHhY+N0k6aujT0q1S9fChz/db+Zjun7K7pKGmtge1t6ShJpIntGvo6WbrBb0gX1basSq96Hs9MyW9FPPapkl6RUPXY6s1dH223vN1oImS1mhkRvLmmLDtta+l4LzQr5FZy3tsu8yu37ft99v75uofxr1ZPz76Sr1vyk1a8Z4xOu/7u/pv1zi8v0lmCbpm9+dCk2OCjg87h/98JoqWm9Jlpjjlse2afsYuen/lD3pX++81WJXOuuyN+sviqK7OqJ+13f1hg+/P6c9POa2acdDQ5sVIM6A+WDkDIJFcZyjXLjai5qnmsnrG+3n0HOXHNg29oL26Xw7Yxyu+E0wq6defOlp7ruzSllJJvQsf0iFvGloVMKncp3dPflzXzbtG39nxFh067iW1laq6f+NUnb/stTr8qVN1wrMn6vMvH6Df9O6s2zdM15ObJ+ipzZP1l40z9fNVe+q9Ly7Uyc+erDvWz1JXeUCfmX6HPjntLo1MMiadVZK7ecmmz+fqtSRZxm93TMfwz3NLdfT/ek/54F/16J5DaXqvJ6fpkm8cFPN6S3py81Tdun7HmMKMArZF/e4WKGQESRAyvKxDBgAAgAlHK21T3a8ziNPVM0HbTFfPxN+zc4jteLORfnL7Ap1z1xskSefOvEb/+o7nQvZMM+Ys6nibDNNsuSlutU7Uc5spdZa1/Vt21punPqT/1/5LSdInfn+YLn/ctrvUdJxZkvtzBp2vuA1tjDQD6o/iDNDiXI8287KZoVzcoDFGwbNmpahRUPdtHKoo7TtmmcaVNyv+jfMg2/arVsu67Nw3ao81beorlzTt8Hv07VfdpJvmX6lPbH+/ZneuV+9Ap364cjctfPpNevvzb9CFq3fVKwNhN54feYH81ObJ+ocXj9N/v/waSdI7pzysj0/7S8hxSQKHi+Bh+3wmks5Wtg8Y7aVBSdKARnY5Hv+uh/X0a5+WJO394iRd9IVDI8699WwG+0jxIbggISNIrSjbM/ohm7nJ1iPNvOgAAwAAIfIcbea0qa2mMvwx16a28Ht2xjMp7khfu34Pnf/QkZKkr+x0qd57athyoqwauLzHJc1NWUqS0UzGrUUdF/VaDLWVNO2UBTpg5kv6ese3VS5V9Z2799XX7nx1zIEm48ziCjP+4/2/iy3Y0AbAGYozAEapxwzlEYKCRk2uQcPfCRYULPxfD12YPdc/Wc9srqijNKhjxz8dcIxpJ9i2/fr72rTiP16jL/xyQF/8YVVHrn9ZY8oDenxTRZ9a+loduehEffWV/fRC/wQlCxsl/WjV3vrU0kMkSX8/5UGdPPGJmGNtu7RsZyKbHmsbYtIGrajXGKxWlCl5ViQddcbjeumIxzRQKmmvV8boV58+Wsn+15x2nFmHRgeQHENGrUhjGTK8GGkGAAAKw9PUZjJxINemtmlhOw7LtKlNvseD7kUT1dRmVqD59GX76YKnD5YkXbDXr3X6cWGzsF2NPk5yL8soaXNPHpnJ+zxhxyY5LkBJmvrmedpn3hr9sOO/NKa0Wdc8MU8fvu7IoQdD2d5nJkxQU2aLN7Qx0gxwhuIMADspgkaQ2A6OnuGPQXNUMw0aURdaZkHjt71DM6XPnnK/2jQY8+LCCzRdpS16y6SndFXPdfrf7e7Qzs9XNViS/rxrSf/1tpI+s/2+uqx3vjZV/a8jWdi4rHcXfWv5fpKk/zf9Ds1qX2VwvIsZxy6LL0FMb7AZdXyU6Ne4eXDorqRdwytojn7r41q98CH1l0rac1WHLj5voRQw8my0pDezDDtXWKFRqnvIiJHp3GQAAAAPq4kDjqRqPjFpaqvJrakt7XizKCPHQn/o5wfp10tfpfbSoH5+0K+08Mi1Ice5GINscp6kzW1Rz5VFbjJ5ndk0s4WZcuwc7bv7Bl3Ycb6mltfqr0u215mXvkkDkdkpyaSBpPfnrAlraAv6u5OASUNbDBragOKhOAMgt6Dh7cwIuiiI6+zYKs3qmcDroaiLpeRzlC9evb96B7q0S9dKvWfKfcNbzTvBdu5cpU9Ou0c3z79Sn5txj+Z3rdG6gXb9dOUuemvv0brk+Dbd3dOmye//o/Y+OGq5vn3Y+O6KffXXDdM1rrxFX5x5u0rqNziPy8DhStpw4T1H0uOHbB4OD12lAR391se16g0PaXO5pD1Xd+iyT75Jg/0mAcK0MBM2RzlsmX5QyPB2QSrg8xxCRs/oTbnNTaYDDAAAxEjyRmRu9+sM4m9qq4Tsl1lTm5/L8Wa18w2pqqR3fP9QXbdyT3WX+nXxob/Q649yUTQwadhqtNxk+nqybWYbpSRNecNc7f+qzfpV5/maVl6je1/aXsf+/DSt29wZcWDYODL/46bjzMLuz1kzRvF/H3yyamhLMNKMhjagvijOAAiUZ8dDJkHDWNAKAJOg4f08eI5y7+AY/dfLr5ckfXjqn/WeKfeqrEFFFWh6Ojbo7yc/oovmXKcr5l2jd055XJW2zXpx8zh9+eX9deTTJ+krr7xKj740TXeff5zmbqpqVXtZE9/7R+1/2IsR36fdPVgGVdZ5S1+vDYPtOmDsUp04sfbmt2lHVr/qEzpsntfFCALzgNFXHVo5M2PGWq0eLszs0duhy/7tTervS1KYCZNmZrJ/v4xCRhRHISOMN2SYogMMAABYC5k4ECaT0Wa15heT+3QmXj2jgG02TW0ux5uNtGWwTadecJRuXz1fE0sbdfnrf6pXH90WdYTRed1kDf+58sxNts+Zb26SJJVLmnriTtp/v76hwkxpqDBzzM9O06pNts1hWY0z8za0eZ+HhrYaGtqAeBRnANgrwgzlNEEj8roobrWA3Rzly9bsp5+s3F/lknTOtLt0Zc9F+uB2d+vo8c/ooLGLddT4Z/X3Ux7Sf868VdfNu0TX7nSZPr79fdpnzAr1V0u6Ye1sfeDFI/WGZ07Sz1btqfWD2y4MX35xov78+eM1b6PU21ZW97v+pFcfE/WGtc0Fcb9e6J+o76zYT5L0z1PvU2fJf7zNsnl/AEgbQNKcy7S4ZHIec7XijCatU1+5pN17O3T5v6YpzJh2fwUdl2RmssOQ4R9JGKRn9Kakc5PDulLpAAMAADbqPdqsGKtnXDe1+bm9/8ymLR06/oI36P41O2hqaY2uOORH2vO4sSHHSskyThTbcWNFy0wuijJ2uanUXtb2py3QPrtv1IUdX9C00hrd99I0w8JMkkkDpuPMwp4n7B5MIRqgoS0JGtqAZCjOAJDUYkFjlIj7ZwQGDbs5yv/5ypH67NKj1DvQpZ26Vuufpt6t/93hRv1o9rX61g6/18en3aM3T3xaczvXqr9a1u3rZ+nzy16rIxedpn9ZcrhuW7+DBrf+cz3yYnP54nG6/bPHa/4GaW1bWe1vv1MHH/9MxPdqFwx+uWqBlvaP1ayO9TqjEnb/oLTzjaMKN67Cie2KH5Nz2dl/4dB/l/YBaffVHbri3453XJgJ43JmckYho1ZsNQwZXpnMTaYDDAAAGDJ9Q9LF/Tq96rd6psZtU1t4sSV9gWbt5i4d+38n6Il107RDaYWuOPCH2vNNQdnPy66xzc0KFJPnaJTMVDufnVJnm7Z/687aa/46/arjfG1f7h0qzPz89JSFmbDtJuPMwr4OK0TWNG5DGyPNgPxQnAEQyuUM5cIFjcSrZ8LmKIePN5NK+k3vXjr26bP16aVH6ao1O+uBjdP1ZN8UPbhxmq5ds5O+/spr9IEXj9Hrnnqb3vfiQv1q9V5aOdBt8I1Jq14eq5s/9Sbtsr6k9eWytpx+t447M250k9mFcl+1Xd9esbck6R+mPKiO0oDBeV3ciDIt25tiZhcwJOktH7tTL+49dF+gCX1lXf5vb9LmTbbztWuSdn95H4ubmRy0vab+IcP53OQIdIABAIDEUryRGXe/zhHqunqmJq6pLel4s6R8TW0bxuro75+sFzZM0vzyS/r9q7+pw07rkEpR57DNNbZFmqLlpjjZFqHKY9o1/W27aM85a0YVZlZutC3MhD1uM2kg7Gsa2iTR0AY4QnEGQDKWM5S9ChE0RvEHjbixTvadYOsGu3Rp7176t5eO1dueP00nPXumznz+Lfr4S0freyv3123r52n9YKfvXEFGX0SuWTlGN5z7Ju26tk0by2W9vPBhnfLBv4YcX2N20fzb3vl6qX+sprZv0vETnlSypfBZB4+kz5N1l9ug3v7pW/TYPou1uX0o+W1cNkZbjFbMSOm7v/z7BIXgoM/rEDKCWISMVHOT6QADAACG6jFxwCt2TFFPxGMmTW01leGPzpravJ/nMd5s9L4vrpmgw75/uh7rnaaZpZW6Zs9v6IyzNkrlyAqN7HOA7cqVPIs1SZ7LdhyavbYJHZpx1q7aY+bqrYWZ+5emKcy4mDSQpKEt6n0FCwkb2rxsRpqFNbSZoqENSI7iDICtmj5oWK2eqbGdo1wTtDTfpBMs6Zv20ro13bri4ydojxXd2lIq6ckDntHbzrtN0mDMuaIvzAdU1oWrd5UkvXPy45KqcrNU3qZ4k/S4IEnmLdtra9+it3/5ej04f7kkafbSSZKkrlLcf48a0+4vk32iZiZ7w0gdQ0ZU0dUj87nJdIABAABLWY02S3S/ziBBzTBhTW3jRu86UvZNbeH7B233P+Y3ct9nV0/S6777Ft2ydJ7Glzbplzt9V594z2KV2k0KNFnnDv9z2eafpMeFcX1fz2DtlS7NOGs37TZ1pS5s9xRmfua6MONinFm74kf6hWzLqqHNI+jfhLCGNpPpJow0A7JHcQZApKYKGrFM3oA2DRr+z4Mu+IL29zMPGpK0pa9dF3/8eO39/FAh4KFdXtZZX75e7R3pRnZdsnqBNg62affuVXr1mJc9j6S9SWXQa8hixU2S15n8uSvTNurU/7lKD85Yr1K1ql0emqVrv3eAJKnNqDiTd8iQChMyekZvynRucgQ6wAAAQCLeZg9Hb2g6u19n4ZvaXN9/ZrTVm7p13PdP1M+e2k9tpar+c9ZF+vYHH1Tb2DaDo5NmBJeZyfs6XK+4SdLIlvy5O6aO0YyzdtMBled0UfvnNb2tVw8snapjfna6VuRSmPEfazvOLKqhLSHbhraAv/s0tAGNg+IMgOSaMmhEvRltM0e5Pp1gkqRqWb/6zLHa7eGZKlWremDGep3yP1dp2sx1EeepCb647h3s0tVr5kmSTp0UVlhzXahJK/j1TG7bpF27VmnnzlWaWO4LOC5dwFiw1wq97vyr9diEAXUPDmrWzXvosq++TgPDA63bS3EXqq4LM80ZMpzNTfb+O0bIAAAABoKuCRrmfp1BcmtqM7lnp/9rm6kD5mOh+wfb9K5fHqnP3nOkJOkfJ/9eV334Jk2a3Tlq32BpizRFyUxSuhU+yY2ZP0kzztpVZ0z6k37Veb6mta3RvS9tr6N/9haDwkwQk6Ken4txZjS0AUiO4gyAEbIYbeb9n72N+gSNGn/QsJ2jXJNHJ1hwYLnkvw/RjJv3UPdgVY+NH9SrPnet9j5wacR5vEaHjct750uSFk54XmNLcRfu/co/eIQ/55S2Tfrw1Pv1u3m/1R0LLtHlPdfoinnX6K6df6Mre67SJ6b9Vbt3vaK0AeP1xz+jWf/yBz3fVdKULYPSLw7S9T/bU5K0pTr0v9z2yDFzaQszUcfWIWREcRQyvExChilCBgAAcKZo9+s0uZ9FJk1tCvg8rKkt7GvzFTLR55Skkv7jmv31jutPUl+1XW/s+qtuftdvNO/AbovzpmnsqkdmcvG8KVfqlKTKYTtoxunz9anxv9Z/d/yfukpbdOmjC3TYj99qWJiJy8th+ch20kDQ51G/+y3Q0AbAKYozAGKlHW3mFTbarJhBI+jCynaOctDn3q9ddIKFn+f3P9tT/T96nab1D2pxZ1nj3nebFr790Yjz+G278L5/01Q9u3mCxpa36NgJz1ucQxodANIEELtznTJxka7f6bf6wHYPa07nOg1WpeVburVyS5ckaUFXr9495TFd2nO9frjjzTporGkBy2tQb/3YnVr3lnu0sr2sOX1VPf/fx+iem+Zs3WNLdWjlTFvoypm4wozpdtuZyTVRBciEIaNWHI36+2gYMrzC5iabYG4yAABwoUj366SpLewxv+DM9Mu75uuYX5yulVvGaf+2Rbpt4fd08CldBveh8XIxWiws5yTJTS7PVZP+eyyPbdf0M3bRjgdP0o86/kvva79GkvS5Ww/SW35zgtb3m6xcsrk3Z9pxZnH35/RuC1lJVoBVM04xbQBwiuIMAHsZzFD2Kk7QqAkKFf7HJLvxZmFfx3X82IeNB27fQY9+/o2av76kdW1lLTn2Eb3jczeqvdPmwnqLpAFd0buTJOnEic9YHBslKjSkCxMdGtD5M+7UF2beqXHlLXpo4xR9bMmhOvCpt+qwRafr9YtO0cFPnqKPLXmdrlszW1uqJR08bpl+NPsW/e+sP2rHDpMxcNL4SX16x39fq0f3WawtpZL2WNWpP37yzXr28Skj9hsY/l9u8Fgzk8KMbfeX6TizsLBckzBkBKkVaQxDhhdzkwEAQNHlcb/O5mtqSzrezG2B5vanZ+nAC96qReu3046l5frd3t/Qae+pqm2i6ZizGpf3f/HKJjPFc3dPm64dx2vm2Xtot571uqzjMzqi7QFt6G/XW37zZv37La9TVSbFMFeTBpKOM4srPqaQU0MbI82AYqE4A2AUV29C2gaNWA0RNJJ2gtks1bcPG8uen6DrzjlRe700QdVSSffPXa2T/ucKzd1ldcxzjXT12tmSpNeOXaZK2yarY/NU1qC+PPNPOnXSIg1US/rGK/vqzOffqN+tnav1g2XVwkXvYJd+t3aOznnpEL3h6Tfrl6t21pZqSUdPWKyreq7Vuyc/rnLEGLI9D1iqI79ype6fuknlalW7PTpDF3/0zVq7avQFemdpQJK0uer/X69tYca2+6tOIcPBqhnXc5NNETIAAIBzGTe1BeoJ2JZrU5vJPTvl2demqc1//vQFmqdWTtaB336r/rRstiaWNujC6d/QR963XN09SVaRZ1WkyYvb1z/htdM1/W276vBJf9PlHZ/Wzm1L9ELveL3+R2fokkd3MTxL0sKMf/+048wcNrQFrZopWkObxT06AdijOAPASOSblSlmKHt5OzsaN2go4HOTTjA/k/uJhAnef/OmDl107kLN/uMuGjM4qL+Nq2rOv96g48583PjMi/vH69FNk9Vequqo8c+pmKGjqs9O/4veOPE59VfL+tDiI/R/K3dXVQOKeq1LtozTF15+tU559g360/rp6ioP6hPb36+fzL45YBXNoE7/8J815gO36ZnukiYNDGri5fvrkv98vcL+19o9XJzZNNjm2Wry3zZt91fQ/nEhI+Ymr3VcNZNmbjIjzQAAgEtZ36/T23wSZ0RTS12b2qKmDbgYb5btWOgVG8foqO+fol89voc6SgP66rgf6D/ffp8mHbR9xLmiuFt5kj33r7XU2aapJ8/XlKN21NkdN+inHV/R5PIG3fnCTL32+2fpvqVRId7LZgR02LE2kwZM789Z/1Uz9WpoC8K0AcAexRkAyeTYMdG4QcN0vJnrpfpB+29z3Q/3Ue//Hq7Zm6RV7WUtecPDeseXr9WkKRtjzjnkhrU7SpKOG/+iZ2txQsfpk57SWypPaaBa0jlLDtZt600v+Ics2jxJ733xCH1m6Wu1frBdrxn7in7b8zudNmmRpKqm77hOZ379Sj3+qhe0qVzWLuvKevZLx+r2KxdEnndMebg4U60VZ8L+G7nu/ko6MzmFpKtmPJibDAAAGpHL+3V6hY02K15TW03UKgPb8WY2UwfSF2j6Btr19l8v1Of/eJAk6Z87rtBvjvmN9n7Ldip3twUeY6Y4mWmb7F5Tx7Qxmvmu3VXZdYK+2PYDfa7jp2ovDeqn9++hI3/6Fi1bPy7+JJKymzTg/Tzod69dZiPOi7dqJo+GNqYNAG5QnAEQKOjNyLxnKDdO0Eg73kwRX2dToHnivum6+ZwTtNeSoQvJ+2ds0IFfvkqHn/h0zDmlG4ZHmx08bpkmlDcH7LFF9QkeW7RL13Kdt/09kqT/Wb6Pblw3O+G5Srqkd75OfvYNumfDNI0tb9HnZ9ytiw66Tq/6t2v0cGWL2qtV7f7YDF3+zydr8VOTY884rX2o+LV8yxilL8zYdH8FfW4yziyDVTNxmJsMAACaWcYrd2ObW3JvaosbnRs33izJ1AHbx8Ouy0v6zE2v09m/XajNg206tu2vumP383XuP72kCXvFX/vHq19myvx5S9L4/adpxt/tpnnbrdXPy+fr7R03a7AqnXPDYTr7ioXqGzCdDuGiMOP/OqiZ0r9/1KSBuHvRanRmitIkDW0AkqE4A8CNlg0aLsabZbVUP/pcm9Z36aLzjtPkK/bVdv2DWtJZ1qpT/qq/++L1mjx1Q+hxz/ZP1JN9k9RRGtQR4xfHPL80OgCkDQLB52vXoL484y51lwd027qZ+vHK3VI8x5DF/eN19gtH6gcDu2qgLO2zao3O+2lVhy0aVOknB+s3X3m9qqPuIRNsZvt6SdJL/WEdYra/A3HdX0HnMp2Z7N/H92Ul4pCoVTOGHWDev++Zz00GAABIwdVoszRNbYFiml7yaWqrSdrUpoCvk0wd8D/uF37t/dMH9tRr/u/t+vOyHTShtFHnj/uFbj3lpzry7LHqmOpo1bmkvDJT1rp2GK8Z79xd04/bQR/svkY3dHxcB3U+od5NnXrzhSfra3e+RlLJ8GyuCjNx9zMyGWfm3xZ1j9oANqtm4hS0oY1pA0AyFGcAJNesQSO2QBN1kM14s6SdYO4KNJL0x8t31n3nvll7vDJGg6WS7pu1Vvt++Wqd9N77JQ0GHvP7wNFmtsICQ9yfYO+a8jft1r1aq7Z06tylB6pqfNEfrlwe0Kn//Ff9+dwndd4727RkirTdWumDvx7U4Y+sUHvIzyfIgq41kqQX+8cHPBr139BmpVXacWYpQ0aQoJARMzc5SJqQEYmRZgAAwLEkEwfS8DazBDa19Xg+j+rEL0RTm3fElMupA/7H/cIz00MvT9Pr/u8t+uC1R6u3v0v7l5/SDXO+qG+87wHNOmZ7lTqzfGvNbWbKStv4Dm335nma8Y7ddPDMxbq64zx9suMijS336+ZndtRrv3+WrntqJ4szpinM+I+JKsSEjTOriWpoM5w0EIWGNqDlUZwBEMrVaLM0UgeNqNUzleGPJm8wjxAWNPLqBDN53C+6QLNm+Vhd/Ik3aeo1e2v65qqWt5e16PVP6W3/+1vtc9CyUfvfuG6oOPO6cUvVXar/zOQdO9bpg9s9LEn6z1f216qB7tTnPPT4p3XCty/X469+Tr1tZW2ZXNJ57QfrN6t3UrkkvW+7x/SLOTdqbscao/O9duzLkqR7N/pvJuqi+yvNODP/YylCRkFXzTDSDAAAFEaCNzzDmtpimTa1BV27OWlqS3LPTv82/+dxBZtsCjSD1bIuuHtf7f7Nd+nSJ3ZRe2lQ/9hxte46+Kt6+z8OauxuLkadNaC2kiYeOEOz/mEv7bBnl77Y/n1d1vXv2q3tRb2yfozeefkbdNTP3qInV9r8fNIWZlxMGjBpaDOUdNVMDg1txhhpBmSG4gwAd/LsrHCxeiZKojnK/n2z7AQzedwvfq7vLb/ZVbd/9CTtvmiq2qtVPTRhUKX336a/+8q16tl59db9Hu+raHH/WI0pD+jgsaOLN3n7yNQHNaY8oLvWb68r1vSkOtfOe63QO756pVa99V4tGiONHRzUbo/M1PX/fLL+8se5+uyyA/SRxYeod6BT+4xZqd/2/E7vn/KIOjQQes6ejvWa17lWA9WS/jqiOOO6MOP//Qo63mRmsoEcQ0YYb8gwRgcYAABwLM1oM5P7dYYJu19nbFNbFKdNbTVJm9pMpw7YjIUOOlfc/tu8tG68Tv/Vm3XCr07SC+snanb5Ff2i8jX96vRrtefbZ6h9clfk8c2ke6dJmvX3e2ryETvo1O479YeOc/T29pslST+8b0/t9u2z9fMH95DdGDPXhZmo+8zEjTNzeH/OIAVoaGOkGVB/FGcApJPDaLPYi40es9eQTdCoieqmyaoTzORxv/gCTd/6Tv3m80do7dcP1269HRoolXTf9A2adu4NetfnbtS0WRsklXTzuh0kSUca3XcmO7t0rdabJj4vaWjVjPnF/0g77b5K7/jydRp3zk26f7vNkqQ9Xx6r5//jDbrkvw7RYP+2n93v183WKc8u1O3rZ6irPKiPTHtIV8z7nd484VmVR406a9NHpj0gSbp1/SytHewc3p50WX7Y10G/N2HjzOTbVoyQ4eX9ex/WAeaVaG4yI80AAEBGIlfk1rupLWrigFfdm9r826L28YrLTFHHhu0/0tVPzNfu33invnbXqzVQLenEtjt1x4J/18fe97ImHzpDpY7mfbutvdKlaact0PS37Kxdtluln5e/oK93fkdT29bq0Vem6NAfv1XvvXKhVm60aQIL+5mb/LcMeyzuPjP+z23vz2mhMvyxDg1tidDQBuSmef9vAcAJ69Fmnv+Juxpt5jWiCywuaJgs0w9SGf6Yao5yTZJOMNsCjV/6Ao0kLXpge13ykZM0/jevVs+GkjaWy/rr3NXa6fyr9Y7/+L0emVSRJB0xfnFAQSI/H95u6E3569bM1uN99uMEdt1nud75xd+p8okbdf+M9dpSKmmXdW1q/8mB+vW/Hq+lzwbff2XplnF634uH6xNLDtLyLV3q6Vyr/5x1l3437xqdM+1+HTpuiQ4Z+7L+a+btWjjheQ1US/r28n2Hj86y+yvoPCa/qwlUhj86ChmBXZ4uRYQMRpoBAIBMFbWpLaiJppBNbWEZKelY6KB94vYfaX1/p865/nAd8P23695lMzSptEFf6f6hfnfED3XkP22nyUfPbqqVNO2TOlU5YkfNes+eqiwYpw+XL9HvOv5Vh3Y+qo39bTrvD4dov+/+nW5/fkfbM4dst2lMdDVpwOb+nI3R0MY9OoFiM3uHDgAc20cP6UHtLWnoYuEJ7Spp6CLiKc2XNHRxsUgLJA1ddLyg2eEn7JH0rMETT5W03LetImm1hi6e1sedYIykjRq6IAu710htn9rHdm27KWPt8w5J/b5tflH7xB3vfTxI2HOOdtc186Rr5uqosx5Vx2GPaXFXWffP6dXYf/2rNn1Tmqo+7dO9UvdvirqDaDb26V6hoyYs1kC1pG+t2NviyEEdfsKz2uGoh/RIZbPuLZUklTR/fVmvXLefLrva9GaVJV2ztkc3r9tBZ01+Un8/5XHt2Lle75nyuN4z5fGtew1US/rcsgP0WN8UJSvMmHwdFFyjQmhzhgxjzE0GAAAOnVCt6qpSshXcpVul6uFuX8/2O72gl58ezk87VqUXfa9tpqSXfAcFZaVpkl4JeZKKhnJULfZs5c9M/nxUyyreA/05qXYd26+R2SYqB0VlpqDHg/bxqr2G6Nx070vTdcD/nal/PvA+ff6oO3VAx9903YRP6fsHvUk/fc1CPfdMWWvvfUUbF62WGu197LaSxu5c0fh9p2lMz0S1aUBHle/Vv1V/oQWdQ/fU/N1Tc/Wha4/W06sqCZ4gj8JM1DGO7s8ZpTL8scANbSMaa2loA3LVtCtn7r77bh1//PGqVCoaN26cDjroIF188cX1fllAQ8pihnIaVqtnapytnvELW5HguhPMf1zcY/7Hg9jU58u66Zd76foPnKYZN+ypeRtK2tBR1j0LhkLeB079g97yT3ersl2fxTnT+8jUByVJV6zp0TObg1e4eM3YcZ3e9tG7dNIFl2rFaffqwcn9GiiVtPPaNlWu2FdXfOgU/cm4MLPNhmqHvr9yDx216ER9dPGhunpNj/7WV9GTfZN01Zoevf35hbqkd2clL8y46P5KEDKifv8rwx/ThAwPVyEjydzkIHSAAWgV5CYgO65Hm5msnonVE7At6eqZSGnGm2U5dSCb3DRQLevrd71ae3zrnbrybzupszSgD7VfqTu7/km/2uX/9K63rNC89++miQfNUHlM8fuk26d0q3Lkjtrxg/to2knztfO8TTqn7de6ve2f9P3Or2lB18t6ae04nXHJm/TGX55ax8JM1HbTSQMO7s/ZrA1tADJX/P8jJHDzzTdr4cKF6u7u1plnnqkJEybo0ksv1RlnnKEXXnhB55xzTr1fItDwrnpSOmHnkAdvknSU3flMVs/E6tHo1TPTJfnvV+989UxYJ5hXmk6wqG4uFytoFHBciGpZN164u3Thbnr1wme1ZOf7pEc3a/4zVX3ruOe0+/7PaP6qsVpx3zzddtUCrV+T3TL+A8Ys08Hjlqm/WtZ3lu8Zut/EyX06/KQnVNl3kR6d1K+HykOrZDoHq9p5+Tg9fPn+uvzOmMqBoY3Vbl2/bq6uXzc34FGXhZmgz026vyxDRk1l+KPrkBEzN9kkZBijAwwARiE3ATm7XtLC+N32W/mY7p+yu6SREwdMeCcOjFg9EyTT1TNx/KtnbKcO+LcpYB//cynk8aB9/MwmD7ywZqJOuugknbzbU/r4wX/VIXOW6Ki2+3VU2/1atf14/fboQ/SbQw/VXx6dpLX3vazNS2LDZ25K7SWN3XWyxu87Td2zJ6hT/Tq2/FedoT/osM5Htu63fEO3fnTfXvriHw9Qb193gmeKeivStjATN97OdNKAg3FmQSrDHwu2aoZ7dALFUapWm+tv0pYtW7TbbrvpxRdf1F133aX99ttPktTb26sDDjhAzz77rJ544gnNnRv0xtloa9as0aRJk9Tb26uJE+O7soFmFbREf1Rxxhs0PMUZ/xL9WtCQNCJo1IozkrYWZySNKM54R5uNCBq1ZfrPep6oFjS8xZnlvo/StqCx2rNtvW/bqKBR27DG93XtY3/I9i2ex7f49g3aVjvG/7WXP0QEhYaooBF1XLTx5c26Y8Hl6ihV9bm/L+uR6dsWZI4bHNSuvV3asGgHPXDzPD35yBRJyUY9jFbVL+fcqP3HrNAvV+2sL7z86hGP7fmq5drr0KfVttMSPTGhX5vK217XrD5p3KKZuv3CffTKixMcvR7JvPPLv2+SwkxQ91dUyBijVCGjMvzRP2PcGzKm+j5K24oz3pDR4/l8OGh4Q4bLucmhy/N9ISOoOEPQQKvjGrj5ucxN/L4A28TmppDMJI3MTZlmJmlbbvIWZ2q5yZuVap97izOrhz/GZibvRpPcFJSZal/3R2wLOsa7TQGPBT0etI+fXWbaZbuVOnu/R/TOfR7TDhPXbd3+6OBcXTJwmC5e+iotfrqqviXrtHnJeg2sM8lt7pQ6y+rcfqzG7jZZ4/bcTm3d7VpQelFnlG/WKaU/amr7ttd8w6K5+sG9e+mKv83X5oGkvd55FWbiJg14HzedhmE5Brri2aeWm7xZqfa5t6Gtlpt6PNtCijNBuSmqoS0sN0WONPPkJjITMJqLa+CmWzlz0003adGiRXr3u9+9NWBI0qRJk3Teeefp7LPP1k9/+lN95jOfqd+LBBpQ0WYoh+pR9OqZWidY1OqZIKnnKBehE0xy0Q3mtW6wU3/ZsL0OGbdM476yl8Yc2K35hz6pJduv0cr2su6d3C+95ll1veZZHdc/oJm9Y9X30lQtfWJ7PXHf9lry4nglKdgcNu4l7T9mhTYOtumq7rk65rQntf1OL6s8c4VemrBRyzratK23q6yp/VVNWzpZj16/u264fVai5wxnEzD8+9usponax3acmYXK8Meom79msGomLdORZoQMAK2K3AQUgOHEgTT368x09Uxt6kBFCVfP+PkzU9TUgbCMZHrfzqDHg/bxs5s88MSKKTrvD4fqUzcdouPmP6d37/eITtptkfZoe06fKf9cn9zxQt0061X6zcBhumVwP/Wt2aK+JevVt2S9Ni9Zp76lG6QBN9em5THt6pw+dsSf9sldKpVKGqNNelPbHTqj+ge9tnNbg9TiNeP0o/v30o/u20vPrp6U8hVkUZgJ+jpqBJ5/m+PCTJAMxkCHNbSFiWpoC8U9OoFcNF1x5pZbbpEkHXfccaMeW7hwqEXl1ltTDKEHsJWL0Wa5Bw2vWugwCRqBktzossYbKmohQgHb0hRopKzDhiTdtG5HHTJumY4ct0Q/uu4Y3X3dTlJpUPse8aIWHPaUNs9Ypee6B7Wko01LpvZJUxdLey/WlNOkPfsHVNncrq6+DrVt6FZ1fbc2ru/WwOZ2VfvLGtzSpsFyVZ3d/ers3qy27i0qj92gs69bLa2Qbj9wQIPH3KKlkpZufUVtaqtWNWdjWV2Lp+qxW3bWLXfMUDa3Wcu6MGPb/VXTXCEj9dzkBDPdAaDZkZuAOokYbeZtavOONnNmx+q21TM9Sj8SOkgmTW01tX1txkJnUaAJOy7cYLWs3z01T797ap4md2/U2/b+m979qkf0mhnL9Ia2u/WGtrv1cnWSLp/yel056RAt2m2mNmm2qgOD2rxs49DKmqXrNdg3IA1K1WpVGqyqOiipWlV1sCpVNbxt6PP2StfIQszEzq2vZ7w2aG7pJfWUlunAgYd0cuedmti2SZK0ZbCkq5/YST+4d2/97qkeDVTT5qi4tx7TFGZMRkDH3WfG/5hhZgrSpA1tQWhoA9xouuLMk08OtcDuvPPod4xnzJih8ePHb90nSF9fn/r6tt3Qes2aNaH7AvBplKARFDBqKhodNGrbMu0Eq13YB22rsS3QBO0Ttp+fedi4ed0sfXr6X7X/mOWa0rZJKwe6pWpZD9w8Rw/cPEeSNLayUfsf+YKm771Ymtqr5WP7tbK9pJc62vRSR1Uat1maslnbxh2EO+ixQc1YUdWGTulXr2uTJE3tr2q7DZ0qL5+kxQ/uqL/+YbYeWJfd/W7sA4b/mLSFGf9xNjezbM6Q4WpuMgC0ijS5icwEhAuaOFCP+3XWbfVMoKQFGml0PjItxuRVoFHAsdFWbRqj79y9n75z937aa/tX9O79HtE79nlM24/r1fvbr9H726+RJC0fnKAlmqolc6dq8ZypWlLdTi9Wp2pJdejzFZqoqIkAE7RBc0tLNa/0rOaWlmleeanmDLykuaVl2r5j7aj9F62cpB/ct7d+ev8eemndeKvvKVyaZragfdKMgPafI+WkgRZqaOMenUB2mq4409vbK2loOX6QiRMnbt0nyJe+9CV97nOfy+S1AY2uaYKGl8nqmUhpO8GCur7ClurLt58UHDT8x2QbNpZuGadHNk3Wnt2rdMS4JbpszU6j9tmweozuuHwX6fJdtm4bU9moXV+zTDN7VmvMlPVqm7BRA2P7pI5+DZarGihVtaVcVXtVahsoqzxQVtvmNv3dTUNB7dZp09R36c565t6penjV2JjvxSXXhZmgY9N0f6UcYxY2M9nLUcjwsg0ZxixDBh1gAFpFmtxEZgJSimhq8/I2tXknDqSSpKnNdPVMbVvqpraasKY2k6kDQY1uWYyGDjvWzMMvT9M5NxyhT954qI7f+Rm9e79HdNS85zWhq19Ty2s1VWu1j54JPHbjYIeWDEzR4oHt9OLgdlpRnaiZbavU0/ay5rUt03bl0QUYtW37dNm6sXpyZUWPLZ+iXz20m255draqzsY/p81MQfuYFmbCtmUwacCLhjYACTRdcSatc889Vx/72Me2fr1mzRrNnh3xxi+ARAoTNLxqBZqKEgaNJJ1gNf5QkWapvv+YqH0kF2Hj5nU7aM/uVTpy/OLA4kyQjavH6P4be3S/0d5DTpr4jKbN/LNWD3TqM7ccpvWDUYUO17IIGN594gJFku4vByGjphKwLWXI8HaA2WJuMgDUD5kJyE6S+3WaNLXFapimtqipA1FjoaVkBZqg/YIkW0VT0z/Ypiv+tkBX/G2BpKoq3X2aM2mt5k5aozmT1ng+X6u5lTWaNWG9xpT7Nb+8TPM7/LPotlm6bqyeXDFZT62s6MmVleGPk7Vo5SSt3ZzVpIG8CzP+fVxPGohQGf5IQxuAhJquOFPr/Arr8lqzZo0mT54cenxXV5e6urIchQM0OcPRZqZyDRpeFQ0VY4KCRugcZVtRnWBpl+pLeYaNm9btoH+a+rBeN26puktbtKnq/n8vHRrQh7Z7WJL0g5W7N1lhJujrnLu/vCrDHwsUMkylmZsMAK0kTW4iMwHRspo4EHa/ThMjJg40XFOb6dSBuLHQQYIykwLOZZKZascnK9BsU9LqTd1avalbDy4LuviWOtu2aIcJ6zS3MlSwmTNpraaPW6/FaycMFWBWVLRoVUXrNncGHp8NkwzoojATdE5/PgralnTSQIs1tAHIVRZ3R66r2szkoPnIS5cu1bp16wLnKgMwE9QhETl/1PBG3N6LBttOjxEXK96lvz0BO3svgoIukoKufaOWJ28V9sZ33BvlYReMHSHbZLDNK+rCNmq/MO2B53i8r6LF/WM1pjygg8eGd26lcWrlae3YuV6vbOnWhavy+nc8+PsdyUVhJm5ZftTvg6NCWNTM5KBtpiHDK0XIyGNuMh1gAFoJuQmos4gVvd5mk1zfUI1rrgm6/qsMfwxq6hn1JrbJm+L+3BQk6I342vag/OPPTEmaqGr7meQmkwyRzuaBdj2zuqJbnp2jnz2wp86/7SD983VH68u3H6BLHt1FDyzbPsfCjGlmCvpvk6QwE5ajbEZA09AWimkDQK6arjhz+OFDbfk33HDDqMeuv/76EfsAyIjh/8zrFjSC3kCOChpBjJpt0hZogrZlXaCxKdJ4lXTzuh0kSUeNf9HwHOa6S1v0gSmPSpK+u2LPTFbmjJQ0YNSO9e8XtY9JYcZ/XIYho8ZVyOiJf6o0mJsMAMmQm4BsWTd9pGxq875B633j1vuGbmM0tYXtH9XU5v3o/7y2b1wmMi3QhO0bJPsiTX2Zfn82xa+wfaKysvfzqAY228wUcYpKwGO1bd6/GzmvmqGhDWg8TVecOfroo7XTTjvpwgsv1P333791e29vr774xS+qs7NT73znO+v3AoEWZzpyKNOg4ZVp0DDZN+pN8rALTFcFGrdh4w/rdpQkHTPhRXWUBgzPYeadk/+m6R0btbh/rC7pNbunTTJpAkbt+Kj9/Oc3nZfssvsrQFYhwyskZNjOTXYRMoIQMgC0GnITkL/IiQMRGmb1TNC2yJ4gl1MHFLGttr0eBZraOZqpSGOTmUyb2ZIWZrz/XV1PGjAYZxb1PsHUkM9rLFbNpMU9OoHiarriTHt7u37wgx9ocHBQhx12mN73vvfpnHPO0b777qsnnnhCX/ziF9XT01Pvlwk0NJejzbwXCcZvuiblavVMbVtg0DAdb+Zn2gnmskATtl9tX7sizd0bZmpp/xhNauvXEeOWGB4bb3LbJr13ytDvyddf2Uf91TZn597GRcCwWZJfezxqm4vuryA5hoyeiJfhQNI3KpK+MQIAzYTcBBSA4WizKIVdPRPJ5FrV9dSB2naTa3TTxjabzBR1nkZh8/qTNrP59zH57xy0raDjzHJqaIvCPTqBYmm64owkHXnkkbr99tt1yCGH6Ne//rUuuOACTZ8+XRdddJHOOeecer88oDUUPWh4pQkauYw3M3mTPosCTdj+wQZV1lVrhla1nDTxWePj4vzjdo9ofNsWPbJpsq5dO9fZebeFC9cBI2hfV4UZ2+4vQoYkOsAAIAS5CchWHvfrzERdm9qi+HOTl0lTW227ydjnrBrbaudppCKNbWZK2swmJS/CBG3LYNJAkErAtqC/I14ZN7R5/42IbGhjpBlQd01ZnJGkAw44QNddd516e3u1YcMG/fnPf9YZZ5xR75cFNI08ZihnwnXQCBR1ceeyQBMWKkwKNLZjzszCxlVr5kmSDh2/RFPb+iPOaWZOx1qdURl60/6/X9lPVZVSnW+IbRiyDRhJCzP+z113fwUwGWfmVaCQEYmQAQDGyE1AnRk2tZmuGE7V1BYns6Y20/FmfkmmDtS2uy7QhO0fxbZZLE9JXluaZjb/fjaFmTQjoIPQ0AYgP01bnAGQv6xnKNd19Uxl+OO4gG2pO8GSFmj8nwftW9vmKmxEB46nNlf0wMap6ihV9ZZK7Rcieej4xLT71VGq6rZ1M/XnDXFXs1GShgvXASOqMOMNia4KM14W48y86hgyonj/3SBkAACAVuNtWjF9YzZWj+fztE1tVvfsDNrmeuqA6wKNq1U0/nPWs1CT9DWkbWaTRuemuG0m0yVMGOT2FmxoC0JDG5ANijMAspNwtFkmQcPLJmjk0gkWxWbZtun2oPP790veEfaLVbtKks6oPKl2DYacO/6i/6jxL+roCYvVXy3pv17ZL3Lf6OdwFS685w86Jmq/sNDgcl5yEItxZnHb6hAyvH//M19VBwAAkIGs7tdpKvfVM5WAx2vbjJraghShQOO+sS1cmiyT5/O4amZzVZjJaZxZgza02eAenUB+KM4ASKzhgkZPzAmnhnxeUzHctpWLG12GMQkaUdv953IXNm5YO0evbOnW9u0btXDCcxHnqD3H6D9jS1X9v+3vlST9ZOVuWrR5kvGxyZkUZVx1fgVtjyq8mXxfjR8yTDE3GQAANJUM79dprcfzeWGa2oJO4LpAE3Qt7pe0sS1pkcb/HEmyT5pjw8QVZZI0s3n3yaMwE8RynFll9G6B7yN4/+7UsaGNaQNAMVGcAZCtIgUNr7ig4ZUqaJiMN8uyEyxqe9jr8O9rEja2HduvNl04vHrmg1MfVNuo1TPxPrH9vZrZsUEvbh6nC1bsK7dhwi9JUaZ2XNy+toWZuODmYJyZV2X4Y1DI8P7e5xQymJsMAAAg46Y205XFjdnUZrNaPIvcZJMBvOcI46JIE/a8WRRggiSdMGBamInLUmkLM14Ox5nFNbTFoaENaGkUZwDky3CuafMEDa96dYJFbfez7YDyHzt0/M9X7aaVW7o0r3OtTpr4dMxxIx09/gWdUXlSg1XpM8sO0qZqVkv5TQJS2oBhsj2L7i+HM5O96hQyvP8eJA0ZAAAA9WI9ccCnNZvagh4zaUzKIzdF5QHTzJRFoSYLJq83TSGrdu6sMpNXxpMGaGgDkADFGQCppA0aXlFvuuYaNIIuqrwXXZXhj0GrDSKDRtA2k24eFwWasOX6NmEj6piRx2+ojtX3V+4jSfrwtAc0qdwXc8yQ2R1r9fkZd0qSfrxqD921YWbMEbZMw1DaTrm8CjNeFt1fXs0aMnzoAAMAAIWW8A3SJPfly72pbVzANqvxZl5xUwe8itzYVju+qEWatJkpyykDYZkpiKNJA3Hb4v4+ePV4PqehDWh5FGcAZC/haLO6BQ0v74VV7p1g3s+TFGhchI3kgeNXq3fTor5J2r59o/7f9L9G7itJk8p9+r8db1KlbbMe2jhF31y+b+wxZmy609IGDMl9yHDU/eVVMdxmEzK8DEKGqST/DkiiAwwAABRe1vfr9Da/JBmLtFXapjbFbNvKe73rYupA0LY0BZosGtu856hnoaZD5q8j6vvKu5nNy/GkAS/T+3N6hTW0pZCkoc0GDW1A/ijOAMhfowYNr4rhtkBBF4ImF5A2BRpXYSNoX7PjNlfbdd7SQzVQLenNE5/WaZOeUdgF/3ZtG/W9Hf+gns61WtI/Th9afIT6q20RzxvFJljEfx92AaO2JD9oqb73GNvCTNDrsez+Chpn5jJk9MQc5+P9e2u6Go65yQAAoFEZXXs4aGozfcO2+E1tQduCVpGbFGjk22ZToMmnsW30ebIu2GSRmepRmEkzaSBiFyl6YoZXhqtmTEU1tDHSDCg2ijMAUjMabVbkoBHXxWITNCrDH2ODRpKl+pJ5gSbosbjt3seCXpN9keahTdP07RX7SZL+ffqfdGblMUm1/w5D59yjq1cXzb1ee49ZqdUDXfrAi8dq+cDYiOfyvp404aVd+QaMNIWZjLu/KgGPOw4Zprx/zwkZAACgpbVCU1vqqQNBj0WNvbIp0NhOHqjtG1ekMS3U+M9pm3tcZCYpXWaKa1rzPubdHvbfyX9+ya54NzHgsQCVmG1B7wt4/y44WjWTZ0MbgPqgOAOgPooUNLzCgobpHGWvxJ1gpo+lLdC0y64brHYekyLNtnN8d8W+unj1LmorVfWZ6Xfp57Ov099NfkRnVh7TN2b9Qb+ee7V26FinZzdP1Nuee5Oe2jxZ0UEibfdYXCCy7XozDRjebXGFGS+TsXeOxpllGDLC5iYTMgAAQKtIe79O0/vtJRl3VLemtli2UwdcFmiiHjPJFFGSFGmCniPrzJSmec+/v/e8YY95v476b5dkBHTQY9GbEt2fM0yP5/OUq2ayaGhj2gBQHxRnADjR9EHDy3sxVgl4PGjbCLadYFE3u4y6kM06bJgGjg79+7LX6b9ffo36Btv06rHLdO72f9Fnpt+lYyc8r7ZSVVev2Ulve+5Neq5/Usw5kzLpUksSMDpCzhv0s7cpzCRdlu9onJnjkGHK9O+36b8XEiEDAAA0mIQrgKPeoA1raksyPmmruja1eUWtLE9aoImaPGCTm2wyU9pCjQuuMpNNM5vrwoyDSQOVgEODtnmxagZAQhRnAOSnGYOGl/F4My+TTjDvNpsCTdxjCngsiyLN0H4/WrW/Fj5zhr7+ymt0w9q5+sPaOfruin104jMn619fOly9g90G5zHVLvOg4zJgBD1m8t8jqjAjy8eiNyUeZ5bzqpmov/eRGGkGAACaScKJA7k0tXmvD4PYNLUZjzczeUPeZCy0d1vQqo6oXBSVqYKYrmaxyTAu5JGZ4qYMeB9LUpgJOqejSQNpG9rC7s9JQxsAD4ozAOqnGYNGok4w0wKN/7E0BZokYSN94Hh5yzh9b+X++pclx+mflyzUN5cfqKc2T4s5dxR/oDANMiavN0nASFsosw2bUfcu8qjEbAsaOxGmDiGDDjAAANAs8rpfZxRnTW1eYQ0+zsebeZk2M9kUaIL2V8hjto1tYceFSZp1TM9lwrSoFHSc93njHktamHE4acDLRUObJRragNZEcQaAM1mONmuYoOFV8XzufKm+aYHG5CLXKyps+I8PYxM4/Oe2/WPLpmMt6Niwx+M6v6IKM2mX5YcwHWfmlWPIiMLcZAAAAA8HTW2mY5FSNbXFrSSoeD7PtKnNtEATVhgI2l++x9I0ttWOK2pusl3pE3SsAh5TwGOmmdW0MJNi0kAl4BDvtrj8HzZpoMfzOQ1tAHwozgDIl79DoxGDRlDoqHg+t+oES7pU3/t51MWrd5vJcn3bx6N0KF3ocMHmNUQVZUwDhm3Asy3MZND9VaeQ4f17S8gAAACtrima2rzCmtoqAft6t+VaoAnaFtZYlbSxzdWq/qzZvAaTokzaKQNB25L8N08wacBmnFlcQdIQDW1A66I4A6BQGiJoeIUFDeNOMK+4pfouCzTex9OGDdsl8VkGjyTPYVqUMe0KS1qY8Z/bz1H3l1czhQwfQgYAACgqo2sSi7FDkc0sHoVoavNK1NQWt81Fgcb7uW1jW5oijfccWRdrXOcm7z4mj9sWZryiCjNehpMGvGx+J8MmDdDQBsACxRkATqWdoezXEEEjbuVBxfN5qk6wsGOSFGiiHvcyCRv+/Uz5A4FpOEh6nF/SoowUHjBMVynVsfurotFMxpllFDKimP79l8TcZAAA0Nws3mD1NrtEvaHbHE1tUSvQwx63LdD4H49qbPMf59/HZWaKyj9Jjwt7zS6a2WxXKXnPZzquLMGkgUrAKb3bwhraHClaQxuAfFGcAVB/CTs5ChM0vCqez512gsV1B8UVaGy6wVwUaWxDh5+rAoyfabhQyD6S3c/Vv827PU1hJoTLcWYpmYYM799jqxta0gEGAAAaWNqmNps3YL0K09RWCXhS7zbjpjbv9rACTdDxcQWaqNXwYefPOzN5nyevzOR9TkXsl+TnGvazznAEdMXzedJxZkVeNeNn0NDGtAEgXxRnADhnFDQiJB1tVsigYdUJFncxmaZA492eNGzU9jEJHN59XYSONGzChSL2jVtxFPe497w2/11rUnZ/eRUkZNhgbjIAAGh5Du7X6Ve3pjavsGYiqwJN0ONRb+BL8as1wq7zkzS25VGoScPkdZgWZUyb2RTyeJrCjGK2eVQ8n1s1VrrjetXMKDS0AYVHcQZAffg7NooYNHIfb+aVdYHGu902bNjs598/j+Bh+lxJizJJilxBAcP7edx/TwfdX97tGUuyasaPuckAAADRmrapzbs9kM1YaNMCjXffsCyVprGttp9pc1vRMlNcUUYy+/nEFW68zWxJCjMJJg14VTyfN1tDmw8NbUAxUJwBUEiFCBppVDyfO+kEM73ZpffzuAte//aofWyKNCbL6P1hIEkASXKOsHDhImB49wnrvktSmFHMthAm48wKFDKYmwwAAFpNnvfrLGxTWyXm3LFTB8K2uSjQhDVlpWls8+5r2+CWV2byvkb/eeL2Nf25eY+Vb3vcf0eFPO5wnJmXbWYylHtDG/foBAqJ4gyATDR00OjxfF6YTjDv9ribXfo/DysY2IQN/3P5JQ0cYecx+WMq6LUkXbJvGsy8z1OTtDAT1A0YsmsleJet0hRmDGUeMvyYmwwAAFqJ7w3Xhm9q87JqavOKmzogJSvQRG1P09iWNjdlkZmCXotJI1van5V/u0lhJu7enA4mDcRNxYjS4/m8oA1tAIqD4gyA4ihK0MiTUSeYywKNf3/TC2iTjjDvvmGBw7ZYk0TY89ks21fAvkkDhpS+MBMiz3FmPZ7P6xkyGGkGAACaHE1tAdsTTx0IK9AE5aokmSkqO9g0t3mPySMz+Z8vaSObf3/bVUbe7RkVZsJkPWkgQr0b2hhpBhQHxRkAuUozcijzoFGvOcrWBZqgx10WaIIupmv7mRZpvPtHhY60ASTuPKZdY3FdX1E/E+9+QdtdFGYcd38RMgAAAAoh0bWKRbNKIZra0uSmWEkLNN7t3mKNSRHBnwmiChs2zW1RDW6uM1OS3BRXxHHdzObfPyg4GxRmbCYNuERDGwADFGcAZMYoaPi7wOoZNEzVrUATtM1lgcakIyzoHLaFmrhl8TZ/kjxP1GtPErqiOuaCAp+DwoyJiudzk8KMqRxDxiiEDAAAgFH8b8wWrqktjYrn88RTB/yf593YFvR4mLg8k0VmStLIlnUzm3//hPfmrFdDW0QRs4gNbQDqh+IMgEKzmZuaW9BIouL5PPHIKZdBI6wAERUggi7IkwQO//G24SDpcWGvM+x7ido/aD9/wEg6gi5F95dVl6FPAUOGzd9/QgYAAGhUie7XSVObT94FGpvskKS5zX98vTKTTcHJtOnNVWHG8Qho28KMoUZoaGPaAFA/FGcAZMp10PC/eZtr0Egz3sxELkHDddgI2z/NsnvvOZOGCT+TcKGAfUx/LlkEjDp1fxmqe8gwmL1OyAAAAK2i4ZraGrZAY9PYZpo9kuamPDKT93n8+4V9bZozM85NWWvAhjYAxUNxBkDhFSZoRCl8J1hWYcN0NU3aYo2tqOc1KSbZ/CwasDATpU4hw4+QAQAAWl1sU1sEm7FHdWlqSyv3Ao1NY5tJkcZ1oSapuMyUpgBlMmVAKnRuMhlnFsH7dybqPYi8Gtq4RydQPBRnANRFEYKG9+IoMmj0eD5PMke5rgWarMNG2Lawji1/0cRV+Ig7p+kYNtuAUafCTFoOQ0YUm5DB3GQAANDKEr1B6nsjNqq5xX/t1fBNbcaynDyQpkjjPSaqUJNlZrJ9PWFfx33vUZmpQQozUXo8n0c0tHnR0AaghuIMgMwVNWhEMryoMhpvZqri+dxZgUYyu8iV0oWNsG0my+uDQoLtH7+o501alMmj88tQxfN5VuPMejyf13vVDHOTAQAArDVEU1vmUwck95MHwpq3JPPChm1uqldmSluUMW1mK0BhJorJ/Tl9slg1Q0Mb0HwozgBoSIUJGn5pOsH8UhdoknaDJSnSmAQO7/Fp5yHbntP0daYNGGkLMwUcZ+bjImT4WYUMH0IGAABoFq7v1+lnMx4pt6a2KBXP57kXaKL2k+wyQ9A+3m155KakmUkh24K+v7DHbZrZMijMmErS0OZV71UzNLQBDYniDIBc5B00kq6eGVGgibq4ymO8WSTTC9UkF76S2cW2beAIW44fFBRs/gSJet6oTjX/fv59auJ+fnUszERJMs4sg5AR9cZAbMgwGIFIyAAAAK3Kfy0V1QTjv0YrXFNblIrnc+MCTdhjNqs40jS2effJOzcFMclMQd9P0mY2/9dhP2f/1ylHQLscZ9agDW0AioviDICGYRM0/EyDxiimc5STjDereD637gTzP5ikQON/zDZsBO3j3S8oVJgsr7dlcs648QFB+4btE/fzzbkw45ckZHj1hD9EyAAAAMhfEVfPOG9qy23qQNrR0EFfm+aJeuamemSmJFMG/F8bFGb8Kp7Psxpn5uf5OxBVvHTV0DZKTENb0LQBGtqAYqA4A6CuYoOGhUyCRhTXnWC5Fmhsw4ZpkSYqcIQFiqCgYPPH5Jw2r9W/X41NwFDAvjUpCzN+3sfCCjN+TRwyAAAAGpmLN0zzWD0zStqmNr8kUwecjjiT4icPRDW21Y43KXR4940qwrjOTHHPa/M9eaVpZktQmEk7aSDJOLMez+eG9+f0S9PQFnVPXgCNheIMgNwkChq+N2cbJmhECQsaft7HEhdoXIUNyezCvLZfkqXzSZmED5tutaD94gpYcT/LmgSFGb8k3V8tFDLoAAMAAIiWS1Nbj+fzJE1tfhXP56nHQkvJ7t0ZdJ64xrbaOUwbxrz7FzE3Be1b46qZTXJamPFLMmkgwTgzPxraAAShOAOgqTRM0DBdqu99LFGBxv9YVFEh6HGbjrCowGFarHHZBWYSeIL29zIJGEmW5AftG7K54vk8bfeXX8Yhw88qZPgRMgAAQItwcb9Om6aXTJra/NKON/NLPXVASjd5wKSxzba5LSwzeY/NMzfF5Twvm2Y2/+OOCzN+3scynDTgR0MbABMUZwDUXVMHjbSdYH7OCjQ2YSPoiZMsxzcp1qQRd37bcJEmYPi/NizM+FU8nyfp/vLLOWT4//55xYYMi9npNYQMAACAYP5rr9yb2vzSjjfz8z6WSYFGsm9sk+zvNePPNK5zk8n5bSYkJGlmy7gwk9ekAZ+6rJoB0PAozgDIVRZvnhY6aPil7QSTHBVogr5OEzaSzE0OCgZp/gSxnaEcVpSx7fxKWJgxXZbvfaygIcOPkAEAAJBOUzW1+bloavM+lqpAEzd5wKSxzUVzm//YrDJT1OsIO9Z1M5uUqDDjl+WkAb8irJpJMNKMhjagWCjOACiEhgwaPZ7PTYOGX1QnmLMCjW03WJKwIeV/r5kkz2MbLmx/PlEBI+jxkIcq4btFFmYKFDKiVs34uQgZAAAAzYSmNs/n/mveiufzTAo0/sdNm7VcNrflmZlMG9kks+8z6JxxUwYSFmYqns+j8nOSSQN+PeEPsWoGQFIUZwDkrqGDRhbjzTIv0Ejx3WBpwoZN4PCe33YWctJjo15PWKEprigTtE9GhZkk3V9+UePMejyf+36/s1o14//76gIdYAAAoBXYNrXZSNPUNuK60aapLWq8WZKpA1JGBZqgr00yQm1bVHObaYObbeHG1b1npHQTBjKYMiAlGwHtFzVpwPu72uN7zPM77s9MrlbN+NHQBjQnijMAiquIQcOvJ+JJTTvB/CoWj9UlbCQNHKbzkpPeyDLseYPYfB+mhSqvqGJYxKEV32NJu78yGGfGqhkAAIB8uWg+8V9rZdXUNoqL8WZ+pk1tfs4zU5K8ELVdcp+ZTHJTXFYLa8YzmTCggH0cNbNJyUdA5zxpIM2qGRragNZAcQZAYdi+6Zpl0HA2RzlpJ1jF83nUxaaUU9iwDRwmocPVDS5Nzxn12tIEjIRL8v0PV3yPuej+8sshZPgRMgAAABpHIZra6jJ1IO6aPk1jm+kEghrb+8WYyDszFaAw45fzODM/Vs0ACEJxBkBd1OPNVJug4WcVNFx0gkk5F2jiLp6l4O6rqFARFzpqsrqpZdxrCHssh4AR97Cr7q86hAybVTMAAACwl/Z+nYVvavOzmTqQuEATtIOrxrawfb3bY1+c6pOZol572PcZd6yjwoyf/7GkkwZYNQMgRxRnABRKwwaNnogT2wSNpLOUpZTdYEEnSBI2TAo1JsEjKZPniQoXOQSMoIcrns+z6v7yhowe32N1WjVDBxgAAEA0mtoCvo5qavPzP+akQGPa2JamuS3L3GSTmUy/N9NmtoTjn6XkI6D9bCYN9Hg+T3F/zjSrZgA0N4ozAOqmiEEj6o3nURdfWYw3k+xmKfsfzyxshBVpkhRq/PvYhBAXx/nZBKe4gBG0T8zDFc/ncf99Tbu//Cy6v9KEjHqsmqEDDAAAtKKWaGqry9SBoB2SNrZJ6TKTfz+bwk1Wmcm0KJNhM5tkNwI6g0kDfpF/B3xsV83Q0AY0N4ozAAqnnkHDL9Uc5aSdYH4V39eZF2jCTmIbOLzH2HZ9JQkgUccHsS3KJClsBeziVfF8blOY8YsKGVHdXz7+329CBgAAQP21bFObX9zUgYrnc+eZSUre2CaZZ6Y0mSftOYIkGdPm5biZTUo+AtovxTizRls1Q0MbUFwUZwC0nDRBY5R6dIJJOXaDJSnSmIYO24JNFNPzmqz4CdruZxkwgnapeD63Lcy46v6KCso+hAwAAICCa8WmtszGQtd2yKqxLSwzeY91nZtszpv03jlepj+riIcrvq/TjIBOOs7Mh4Y2AC5RnAFQV8ZvrjoOGjbjl2KDRj06waScw4ZN4JDMQof3HGn/RDEtGplsT9D5FbRLJWJf/2NZdX/5EDIAAAAaS9bXSIVtast16kDQTi4a2yTzzOQ9Tx6ZyUVRxvGUAcntCGhHDW1x9+ekoQ1AHIozAAop7zdj44KGzRvVmXaCVXxf5xY2avvZBg7JLnS4EBcspOiQ4jhg2IQM/2NxhZk03V8NHjIAAABaSeI3WJulqa2uUwfCdkqSmUwKNfXITWFsM2CC1TJBu1R8X2c1AlpyNs4sTtYNbQAaE8UZAHXXCEHDLzZo9EQcbNsJVtew4TJwSKNDR9rwYXuuqNeXQ8Co+LYlvZFl0NdR3V9+TRAy6AADAAAwuF+nY7k2teU9dSDx5AGbzGT6ZFlnJtOxamGP+SVsZlPALhXf12lGQPvFFWZ6Io71KVpDG9MGgMZEcQZAYRU9aFi9oW3bCVa3Ao1pcSJqf+9jRk+s8MAQ98dE3GupQ8CQ7OYl+9l2f/WEnypunBkhAwAAoBhcNacUuqmtnlMHpBwb27yPNXJmklI1s2VdmHE4acAm//szEw1tAMJQnAFQCEUJGqneeLYNGvUo0DgPGyYntpl17ILJ80U9HhWsDJ/eqxKwTz27v6LGScTIPGQYImQAAABsE9vUlvCaK0zqpjab69E8MpN/n9SNbVFFGpN8UpTMpJjHHTazSdkXZhxOGohraMsbDW1A46I4A6DQ8g4afqmDRo/vhFGdYElUfF877QZLWqSJewJ/EEgaQJKcxzZcxB0TsKtXJWCfLLu//Hp8Xxc9ZPj+fhMyAAAAsuG6qS3VdWOP7+u8m9qC9jGOJ2lyk8l565mZkuQmw5fkV/F9nXVhxmLSQJy4SQM0tAGIQnEGQNPJOmhYzVGOkzZoSBkWaGo72oYN73E2ASIsOKQNJknDRe1Yi6fwqvi+Hqf8Q4YF23FmqUNGQoQMAADQyoyvhXJuavNz3tRWrwKNVW4K4qK5LeyYPDOTlFkzW8W3rd6FGcuGttzR0AY0FYozAAqjUYKGX92DhpThcn3vzlFFGtNCTZGW6KftZvPs6lfxfW0S/uocMnIXMzeZkAEAAGDOxbVT1k1tqcabmciiQCM5aGyTGj8zOSjKZDFlwLWUkwaK0tAGoHFQnAFQeE0RNHp8L6geBZqgfayv/dMEDv850oaPJN1hccUky9fTJIUZQgYAAEDjcNXUVrhrtB7f17aZKYhJgcbp5IHaznGZKS43pV0JE3euODlnJilZYaaO48xykbChjWkDQOOgOAOgUJo2aJioV4FGSnCd7yJwBJ0viyX6Nqt7LF+uV0XZFGaaUUzICEPIAAAACEdTW8DXUnyBRjKfPOCsSCM1Xm5KUCByVZgJQkMbgCZAcQZAQ2iJoBEk7wKN0yKNNPIC3yZ0pGXbkWYhacAI2s+kMNOCIYORZgAAANESN60UoKmtYQo0Yfs5z0xS/TOTw0a2sEMqSr5yqeCTBjJBQxvQEijOAGgeOQSN3DvBgiQt0Jh0g0kJV8vbLo93FT6CzmczIsBC2CGVgG1JCjNBCBlbETIAAADijWpyMbzW8sq6qc2JIhRoEhdpbAo1WeamKClGqbnMTFL6wkwOaGgDkBTFGQCFE/YmbBGChhM9vq+TzFJOUqCRMg4b3gOTLp+3/ZPl6/Id6ldR8pEIWcxLzgAhAwAAoAll0NTWEFMHpHQFmqB9k94GpvC5KQHXzWySfWEmSJ0nDSRCQxvQMijOAGguDoJG5uPNghShQBO0b4prczc3r0zDwXOHHVoJ2JZ0Sb7kpvurx/d1I4QMAAAAJJb4fp0GbJvanHAxdcBlgSbzxjb/CeqdmTIoylQCtptOGUhSmHE8acAFGtoARKE4A6CQ6hk0kkjdCRYk7wJN2L6So5yQZbEmzY0vI07nV5Hbzi+pLt1fLjgJGb6/v2Ehgw4wAAAAc0Zv3BaxqS1Ij+/rLAs0lYD9TK/zJUcxx3GuiT2/g1MFqQRsMy14mWSmILaFGQM0tAHIGsUZAA2lqEHDSBZL9YO22RRocltFEyQoeCT94/glBamEbM+6MJOy+ysIIQMAAKA5NH1TWxCTAo2fSYFGSj95QGr+3BSXmSoB201/hqaFGRcjoHO4PycNbQDiUJwBUFhFDxqF6QQL2hYWNCoB25OsoqnHivssJQ0Y9S7MGKhHyDDC3GQAAIDMNFJTW92mDkjZNbZJzZeb4r6fSsh2l5kpaFsGI6BN2N6f0wgNbUDLoTgDoOFkFTTi5ij7L75MZNYJZlqgySNsNHLgSFKUkex+VlkWZnp8XxckZIz6+2YQMpibDAAAkEyjN7UZ6fF9HXRtnHWBRrJvbJOaOzNJ+TWzBW0Lysp1GGdmwsWqmTA0tAGNi+IMgObgIGiYsJ2jHMikE8xVgUZyFzbCjpEaK3CYFJUqIdttAobkrjATpMf3dQbjzEwQMgAAAIqpyE1ticab9fi+zqNAUwnYnjYzNVpuClORmyJW3oUZy3FmJvJaNUNDG9B8KM4AKLSwN2frFTRMOBlvFiTPAk2zBQ7T11WRu4DhsjCToPuLkAEAAIBRCtLUFsTJ1IEgaTKTZLcqpLZ/0DFeRcxMknkjWyXkMduJDEkLM0Ec3GcmCA1tALJEcQZA82j0oNHj+9qkE0zKphtMCg8bccfV1LtQY/P8FWUfMLIszDRhyAAAAIC5Rmhqa5ipA3k1tkkjM0vRc1NF0d+P7ci3NIUZByOggyS5PycNbQDSoDgDoPCaJWgUtkAjJQsbtePCjvXyh44sgoft+StKVpSpHRvEZcCQEnV/BWnEkEEHGAAAQEYyamrLZLxZkJ6AbS4LNFI2jW1hx3rlnZnSNrJJyZrZ8i7MBEgyaSDu/pxBaGgDEIXiDIDmUuCgESjpUv2susEqIa/BtEgTdnyQoGCQ5o8pk9cZ970GyaMw0+P7OuD3h5ABAADQmurd1GYis6kDQbIo0EjJM1Pt2LDjg9Q7M0W9VlfNbFK63OTXE7Atp0kDNLQBsEVxBkBDKGLQcDLeTEq2VF/KP2xIdoEj6jx5q8i8KGM7lsB155fkrPsrCCEDAAAA9WxqczLeTHI/dcB28kAl5DEyUzDT8c8K2eZoBHQQF5MGgtDQBiAOxRkAzcfgYsZ/UZQkaARJNN4sSE/AtrwLNJWQxySzwOE9T9S5slCRfbhwFTCkfAozBt1fhAwAAIDW4rLJJaumtiBG482ynjoQtj2sGUtq7Mxk+9wm0xSC2DSzhW13OALaxaSBIP6/C0b3t6WhDWh5FGcANIxUFx8GFz3+i6esgkYg06X6WXWDZV2k8Z7L+8eVJOc17WgLU6DCTBBCBgAAAGqSTBwIklVTW5C6TB2I2p42M9kWauLOayvpudPch9QmM4Vtz3AEdJCs7s9JQxuAIBRnADS8wKCRYPVMEP9FVlDQMBlvlrgTrCfgReUdNqT0q0+izuvij420s6CjOr9czkqWjAszJt1fJuPMCBkAAAAtJGFmyrOpLZepA7aTB5IWaaTGy0xJizJSfQszAZJOGnBxf85ANLQBEMUZAK2kETvBpPwLNGnChpS8UJMl09dUUX0ChuRsWX6QrMaZBSJkAAAA1JXL+3VK9W1qC+Q6M0l2kwek+Ma2SsTjUvNnJtspA000aSAIDW0AwlCcAdBQUgWNAPWcoywZdoLZsC3QZBU2auoVOsYp2fiAMLYBI2y7TWGmJ+L1DDPp/gpCyAAAAICk4Gu1Vmxqk9w2tknJmtvqmZviVJQuM9msUHJcmEkyaSAIDW0AXKM4A6B5FTBoZN4JJtkt11fEdtOwUYnYx8sfOlwFjzTnrch9wFDI9rSFmYTdX0ULGQAAAMheIzW1+a9XjZvailCgkaIzk2Q/ViyrYk3S3FRR/OuP+hnY/kwzLswEMRlnFoSGNgBpUZwB0HAaOWgEcR40JPdhw3XgqAkKCLZ/bFWUrigj2a9GyqAw06ghgw4wAACAgsm5qc3p1IG8CzRJG9skN/d/qUduiuKymU3KpTCTdNIADW0AskBxBkBza9WgIbkNG1K2gSNrFZm/rriijIuAIREyAAAAkKmiNLVlOnUgTE/AtrACjavJA5JZZpKKmZkk89yUtJkt6rE0hZkAppkp6aQBGtoAuEBxBkBDasagUYgCTdxjkn3gqBjsmwXb588rYEi5FGYIGQAAADCWYVNbkMynDkjFb2yTRmaWisH+rtk+v0lmSjJlwLQwE8ZgBHQQl5MGaGgDkATFGQDNz2HQcDneLJMCje1y/aRhQzIPHNLoi/6K4XGmkp5/mrIJGFLhCjOEDAAAgNZT5Ka2zKcOhCliY1tNRdkWa5Ke3+T7yLqZTWqoSQM0tAEwQXEGQMPKI2i4HG9m0gkWyrRAI9kt15fiL6JdB46aisM/tkzDRdIglkFhxpRpYYaQAQAAgBFybmoLkktTm5SsQOOisc02N1Uc/7Fh+przaGaTGmrSQCAa2gAEoDgDoDUkDBpBko43C2IcNML0hGx3GTZMHpeSB4682Ly+uO81KmCELclPMS9ZSt79FYSQAQAA0DqaoamtEAUaKX1mkoqdmSS7zJRHM5tUmEkDQWhoA5AGxRkADa0oQSNILkFDyi9smDxeU5RCje3rSFOkchEwJEIGAAAA6qcOTW1Jr1eljAs0WTa2SSOzSiPlJpPvryCFmSBpJg0kvj8nDW0AQlCcAdA6ChI0GqJAYxI2khRqsg4dSZ8rbcBqoMIMIQMAAAA1WTe1ZT3eLJSLAo2UvrHNNDNJ+RZqkuQm06KM7ZQBKXVhJkzSSQNBEt+fMwgNbQCGUZwB0LTqETSSjjeT6lSgSRM2avvYBA5pdBBIEkBcnMO0yBRXlClAYcaUaWGGkAEAANBcrK69DJvajFZVB8ilqU2yL9Bk0dhmuo9fWN6xyTwuziG5+R6jfoYOCjN5NLSZoKENgC2KMwAaXpGCRhDToBEm0wKNlD5s2OwXJSo8uFx94+p7chUwpFTL8iW348yCEDIAAACaV9KmtiBpmtrqXqCR0jW2mRYw8spNadi81jTNbA1SmKGhDUBWKM4AaGrNEDSknAo0Los0aQOHa7bhIm6/jAszYQgZAAAASKroTW02nBRoekJOnrSxTTLPQUXMTbavKatmtp6A7SkLM2mYZiYa2gAkQXEGQFPIK2jUs0ATKEmBJm3YSBI46hE6XIcLyW3nl9QQ3V+EDAAAgObXaE1tYawKNFJ9G9v8+zdCZvIeEyWHZrYsJg2kGVNuhIY2AD4UZwA0vcCLnaACTYGk7gSTki3Xl+LDhpQsPGRZrPGf23W4kJIFDCmzwkyhWIQMAAAA1E8rNLXlUqCR8slMLnNT2nOnbWaTMi/M0NAGoNFQnAHQNFJ3lhQoaEh1LtBI8WFDShcYgsJB0j9pnj9O0oAhZVqYadSQQQcYAABAMRW9uaZuBRpXjW31zk1pnzuKSWYqcGHGuaDMxKoZAAEozgBoCcarZxx3txSmQNMT8gJdhA2pmDOTg9i8zrQBoyfksTovyw8qzAQxLsywagYAAKChNOrqmTCZF2gkN41tUmNkJsnudcZ97xk3s0nmhZkwzhvaAMAQxRkATSWLDpM0QcNGprOUpfRhwzZwFCV02L4ek+81ScCQcg8ZpoJ+n7NABxgAAECxuW5qq+fUASmDAo2LxjapeJlJcp+b4n5ePSHbHWSmMHWbNEBDG4AQFGcAtIw0QSOvTrBMl+pL6cKGZBc4pPqFjiTPa1qUKWBhhpABAAAAG3k1taVVtwJNT8SLyjIz5Zmbkj6vi2a2npDHHBVm0kwaqDca2oDWQnEGQNMp4sVMYZbqS9FBQ4oPG5J94JCyu7ll2vOafi9xP5eeiMcK0P1Vb0X8ewkAAIDRitTUFibzAo2Uf2NbTRa5ycU5G6CZTUp/nxka2gDkieIMgJZSr6ARJvNOsDTdYFkVabzqdWNLya4ok7TzSyJkAAAAoJDq3TyTxdSBKHVvbKtnbkrK5rUXoJktr/vMZKXefycB5I/iDICmlOdFTdpOsEwLNFL6sJFX4MiD7etMGzDqUJghZAAAACCtPJrawmQ1FlrKsEBjkpmkxshMkn1myrEwEybtCGgbNLQBcIXiDICW4zpo2KjLUn0pXdiQzMOGVLxCTZLXk1HAkPLv/rJByAAAAGgtRWxqC+NiLLSUokDTE/OkSYo0RclMUrJGNpPM1BPxeILCTNoR0GFSTRpwgIY2oDVRnAHQtLK6uEk73szFUv1MCjQ94Q9LsgsbNfUKHUmfN8OAIbkrzNh0fzkfZ2aJkAEAANCY8mpqy3vqgJRwNLRk3tiWtLktz9yU9HlNv7+eiMcSTBmQ8p00YIWGNgAJUZwB0JLSBA0bLjrBcivQSNmEjRr/xb+r4OHivC4ChlS3wkyYTMaZETIAAACaSr2b2sJkOXVAyrixTSpWbso7M/VEPJ4gM0n5TxqgoQ1AHijOAGhqTi5yUs5RTtsJJmVQoHEZNpIEjpqgkGD7Jw2XAaOOhZm03V+EDAAAAATJoqktq/FmuRZoJLPMJKXPTFJjZCYpk2Y2qRiTBqzQ0AbAEMUZAC3LOGhYyKoTLEqiAo1kFjZ6DF+Ei8CRF5uiUo9SBQzJvjBjK5PurzCEDAAAgKaUZ1NbmHo0tUkOCjQ90bts5aK5LS+2r7VHhS7MhMmkoc3yPQUa2oDWRnEGQNOzvtjJaY6y5Ob+M1J0gSZV2JDMw4ZU3MCR5HX1GOwTEzCSFGay6v6yQsgAAADAsLRNbVmNN8u1QOOySCMVMzNJyTJTT8w+CacMSNlOGgiTuqEtAA1tAMJQnAHQ0tJeJKUdbxbGZYFGqkPYkOpfqEn6/D2qS8CQsu3+ImQAAAAgSlZNbWFsxpvVvUAjmTe29cTvttVM1Tc3JX3+HqVuZpPqW5hJPc6MhjYADlCcAdASihA0wrgKGlIBw0aN/6LfdfBwcf4e1S1gSPl3f4UiZAAAAMAni6a2MLYFmjBJCzSpG9skd7nJtTwzU8IpA1J9CjNh0t6fk4Y2AFEozgBoeVYXSzmON5PqVKDJMmx4BRVUkv5Jo0dOAoZU/8KMTfcXIQMAAAB+RWlqczF1QEpWoJEcNbZJ6XOTy8yUJjf1yPz7SJGZJPvCjK3MJg1Y3p+ThjYAEsUZAC3ESdAI4WK8WZ4FGudho8ds98LokfOAUcTCTBhCBgAAAGxl1dSW5dQBKeMCTV5FmnrpkV1mqkNhphCTBgAgIYozAKD0QSOMi04wyX2BRnIcNqTGCBw9snuNGQYMyc09ZiS77q9QhAwAAAAo/6a2MK6a2qToAk2qyQNSsiJNj9nuddEj50UZKX7KQB6FGRraABQNxRkALSXqIijw4slB0HDRCRYl0wKNlLxIU/tTbz3KpCgjZVOYieKi+4uQAQAAgKSK0tTmskAjOZg8INllJqnxM5NknJlcTxmQ3BVm8r4/JwB4UZwBgCRCgkbROsEkswJNJkWamh7lGzp6lC5cOAgYUv1DRpDQ309CBgAAADyK3tQm1Sc3xXKRmXrsD8/9+erczJZ1YSbL+3PS0AbAi+IMgJbjLGg0WCeYk7AhJQ8c0ugQUPtT73NZfE8mAaOoISMQIQMAAAAuBOWmDJvaojRsY1tNT8CfIpyrwM1strKaNAAANijOAICJDDvB8irQSA7DhpQ+cHj1JPjjgmVRJk3AkArY/RWGkAEAANDyrJvawtShqU3KrkAj5dTY5teT8I8LDpvZpGwKMy4mDYSioQ1ARijOAGhJWQcNqzfCI9SjQCNZhA3JbeDIi+VrThswJLeFGVtW48xCEDIAAAAQq0BNbVL2kwesizSNlJssX7NpM1sRCjP1aGgjMwEIQnEGAEw5WE1gGzSi5FGgSVSkKWrgSPD6TH8GWRRmomQ6zoxVMwAAABhWr6a2IhRopAwa26Ri56aEr62emSnzwoyjhjYACEJxBkDLaoWgkbYbTEpQpJGKEzhShAvTgNGwISNISGGGVTMAAAAw5uCenUnUu0BjnZmkYuSmFK+h3s1smY4yk5w1tJGZAIShOAMANiyDRr0LNJKbsCE5ChxZh46Uz2XzPZoEjMIWZhhnBgAAAEPOmtpCuMpMUrYFmkyLNFJ+ucnB8xS5mS1KvRraACAMxRkALS1R0Mh47FO9CjSZF2lq/GHANhikPT6AbVEmTcCQ6tz9FYZxZgAAAHClTk1tUnYFGimHxjavNLkng8wkuW1mk7IpzNDQBqCRUJwB0PKKNt4sSpYFGsk8bEiOAodfVIjIoIPM9ntIGzCkAnR/ETIAAABgqZGa2qTsRkNLOTe2hckxM0nZNLM1ZGGGhjYAjlGcAYAkGrATTDIv0NS9SJOxJEWZehZmCBkAAACotyI2tSVdXV6vxrZGyk1JXnMRM1PWaGgDkAbFGQBQ9nOUo9SjQOM6bEjFDxxJX59pwChKYcZagpnJhAwAAAAYybipTUo2FlqqT2ObVPzmtqSZKevCTJSo7ExDG4AiozgDAEk5ChpRsirQSNmFDak4hZo0r8NVwJDy7f5yMc4MAAAACJLHeLO8xkJLbgo0UroiTb1zU5rXYfN9py3MJLk3Z9aFGRraAKRVqlb51yLKmjVrNGnSJPX29mrixIn1fjkAMnZVqRT62Ak7hzywMGT7UcGbq4cHb79/yu6hz/2g9g597AntGvqYJD2l+ZGPL9KCyMe9XtBs433DvPx0+nOEcRFqbEJVVgFDImQAqC+ugWGD3xegtdQzM0nhuSnLzCSZ5yYXmUnKLje5KgS5KspI2eSmqGKei9wUN2GD3AQ0PxfXwBRnYhA0gNaSR9CQilegkfIv0gQxCSBZdZW5LMpIBSrMSIQMANa4BoYNfl+A1tOKTW1S42QmqbVzU9aFGYmGNgAUZ3JB0ABaT1E7waRihQ0pu8CRJ9vxA0UszEiEDABucQ0MG/y+AK0nKjNJIbmpSZrapNbLTVlkJqlADW0OM5NEbgJahYtrYO45AwAWrOcoJ7j/TNIbvaedpyyZ3djeK+k9aYrA9rWb/myatTADAAAA1CR689kyM0WJujZOet9Oaeha3tX9O70aNTcled2mmakwhRnHKMwAsNFUxZn+/n5deumlete73qXdd99d48eP14QJE3TggQfqggsu0MDAQL1fIoAGkNfFVJKLw7gbxLso0EjJw0bRA0fS15lH51dShAwAgA0yE4A81LupLU2BRsqmsU1qjCJNmszkqpkti4a2UEwaAFBHTTXW7PHHH98aMI4++mjtuuuu6u3t1VVXXaUlS5bozW9+s6688kqVYpbferFEH2hNiZbpS4VYqi/FL9eXsluy71WE5ftpwo+roowUHwLreZ8ZiZABYCSugZsXmQmAS05HQkuFGgstkZlM2BSo0jazFX3SALkJaC3cc8Zn8eLFuuKKK/Sud71L48aN27p9/fr1OuKII3TPPffo4osv1lve8hbjcxI0gNaVKGhIhbj/jOQ2bEjpAkdNHsHDRSeay4AhETIANB6ugZsXmQmAa3ncs1OqX25qxswk5ZubCpmZJBraAKTCPWd8dthhB33wgx8cETIkady4cfrYxz4mSbr11pzmvwBoeFEXV4nuzZHjUn3J3XL9miTL9v28S+RdLOl3fT7b77GQISPBzO4ohAwAaC5kJgB5cjXeTEo+ztfFaGib8dBFy0xZnNPm+yxkZpK4PyeAQmiv9wvIS0dHhySpvb1lvmUAGbvqyZBOsOsVvVQ/QOnW8E6w/VY+FtoJto8eiuwE20V/i+0Eq10sm3aE1S7CXXSFSW46ttKyDVCm4awuISMMIQMAEIPMBCCJE6rV2LHQVm5S5AqaIFGZSXKXm1o5M0nZ5KY0mSkxy8JMHBraACTVVCtnovzoRz+SJB133HGR+/X19WnNmjUj/gBoXYkvshx3gqVdQeN6FY3kpius3pJ8D6YBoxG6v+IQMgCgtZCZAGTBevVMhKSZSXI3eSDv6QNFkGTCQB6FmbwmDdDQBiArLVGc+d73vqfrrrtORx11lI4//vjIfb/0pS9p0qRJW//Mnl3/G7MBKC6X483ipCnQSNmEDakxA0fSooyLgCHlXJiJQMgAANSQmQCk0QhNbZKbAo2UvLGtkXJT0tecx5QBqTiTBmhoA5BGqVot3r8i55xzjvr6+oz3/8hHPqKddw6+y9zVV1+tU089VbNmzdKdd96pmTNnRp6rr69vxHOvWbNGs2fP5uaWQIuLW6ZfhBtdSvE3u5Tib3hZY3PjSz9Xy/ddShOEXAUMKXlhRkrY/UXIAJAAN3gvPjITgKJxnpmkuuUm08wkNVduaoTMJOU7aSAqN5GZgNbmIjMVsjgzfvx4rV+/3nj/m2++WUccccSo7ddee61OPfVUTZs2Tbfeeqt22mkn69dCMAVQExU2QoOG1NAFGild2JDqGzjSdqbZdMQRMgA0E66Bi4/MBKCImqlAI7VGbmr5zCTR0AYgkaYtzrhwzTXX6LTTTtPUqVN1yy23aMGCZP+TI2gAqGm2oCHlGzZqsgwdrsYEuA4YEiEDQGPhGrg1kJkAZKEoTW1S4xVovLLKTS5Hq1GYAdDKKM6EqIWMKVOm6NZbbw1dvm+CoAHAy3nQkDIr0Ejuw4bkNnD4mQSQLOc0286OdhEwJEIGgOLhGrj5kZkAZKVITW1S/gUaKdvMJMXnpqzvbVO0ZjYp30kDErkJAMWZQNddd51OOeUUTZ48Wbfccot23dXuf6B+BA0AXomDhpSoE0wqZoFGyj5w5Mm2KCM1ZmFGImQAMMM1cHMjMwHIWjMWaKTiFWnylkUzm1SnwoxEQxuAVCjO+Dz++OPab7/91NfXpzPPPDMwZPT09Ojss882PidBA4BfowUNKbuwITV24MiqKCNlXJiRCBkAMsU1cPMiMwHIQyZNbVJDFmik1stMUuMWZiTuzwnADMUZn1tuuUVHHnlk5D6HH364brnlFuNzEjQABCnSeDOp/gUaqbECR70DhkT3F4Bi4xq4eZGZAOSlaFMHJHcFGqn5c1MrZiaJ3ATAHMWZHBA0AARpxE4wKZ+wIRUzcCQNF5J5wJAIGQCaA9fAsMHvC4AweU8dkBqjsa2maLmpaTKTREMbgMy5uAYuO35NANAS4i7KIi/qIt5Aj7qAjLv4jLt4lcwugmt20d+sLrC9FmjR1j/1lPZ12PwM9tFDFGYAAAAAFxJmJinf3JQmM0nFyE0uXkMzFGYAoB5YORODLjAAYeqxekZy0wkm5beKJkgWHWKuA43rgCFlWJiR6P4C4BTXwLDB7wuAKI2cm2wyk+Q2N2W1qsZlbrItTBW5MENuAmCLsWY5IGgAiNLIQUOqb9gIExVC8ugoq0fAkAgZAIqFa2DY4PcFQJS4zCQlHG8mpRoLLTVuboor3GSdm7LITBKFGQCNheJMDggaAOIUtUAjNW7YqId6BQyJkAGgeLgGhg1+XwDEySwzSYUt0EjNl5uSjHArRGFGIjcBcI57zgBAA8ji/jOSwcWpzAoDkt29aKT0s5WLJMn30gyFGQAAACAvmd2zM0Y9M5PUPLkpaWaiMAMA0SjOAEBKmV6sFbhAI227SG+0wJHmdRemMJMSIQMAAABF0ihNba1UpMk6M0nmP/vEKMwAKDCKMwDgQKadYA0QNqTGCBxpXqPLzi+J7i8AAAC0ltTXoAUp0EjJGtukxmhuS9vIZpOZmDQAoNVxz5kYzE8GYCrVjS6lVLOUJbfzlGuSzFX2q+ecZRehxzZ4Fb0wI1GcARCPa2DY4PcFgI163X9GMstMkl1uavTMJOWfm0wLYVkVZiQa2gCk5+IauN3xawKAlnVCtWpUoAl1veLDRoTSrfFho3YRbBo29tFDqcOG90I/j9DhsgutcAFDSjVzWyJkAAAAoNiuejKiQBOXmW5SZIHGJDNJQ9f2NplJSlek8WeYrHNTvTKT5KiZTUo1AprCDICiYOVMDLrAANiqZyeYlE03mOSmIyxMkvCR5SiALAKG5CBk0P0FICdcA8MGvy8AbDXC1AHJPjNJxcpNRcpMUo6FGSYNAMgBK2cAoAFl2QkmZdMNJrnpCAtTlJnLWQUMqf7dXwAAAEBRZD51wGAFjeR+8oDkZvpAmCLkpsJnJiYNAGgg5Xq/AABoNiYXc5FvpMddTBq8gW90USvzmzB62dzksVEk+Z5sfnZ0fwEAAAB2Yq9xc85NNmr5oplyU9LvpyiZSWLSAIDioTgDABlIfVGXY9CQ7MOG1BxFmqwDhkT3FwAAABAkdVObiToWaGoaPTelyUyNVJgBgHqgOAMAdVKkTjApfdholMCR9vUWrTBDyAAAAECjynzqgFT3yQM1jZSZpHSvN/fMFINJAwCKinvOAEBGTOYoR95/RnJ2DxrJ/D40UrIbX0oj5w9neSNMWy5CkPOAIeVSmCFkAAAAoKnFZSbJ6b07Jfv7d3oVNTNJ6XOTbeHKWWGGSQMAGlSpWuVfoChr1qzRpEmT1Nvbq4kTJ9b75QBoQCY3uows0EjxYSMmaNSYho2apIHDrx6hw1VXWiYBQ6IwA6DQuAaGDX5fAKSVS2aSjHJTK2UmqeC5yUFhhtwEICsuroFZOQMAzcCgE0yy6waT0nWEefkv+LMIHlmMCKhbYSYGo8wAAADQTHKZOiA5nzwgpZ8+UBOUZ1znpqbLTBRmADQ4Vs7EoAsMgAtF6gST7LvBJHcdYaa8QSTv2cxJ5kgTMgA0E66BYYPfFwAumGQmidzk5S/e5JmbyEwAWp2La2CKMzEIGgBcKVqBRip+2MhbpgFDImQAaBhcA8MGvy8AXHGSmaSWKtDkLUlmkopXmJHITQDScXENXHb8mgAAKcReRJrc6NBiZJZVYWHYfisfS3xBXlRJvycKMwAAAIA7JteyRm+8O8xNaTJTM+WmNJkpz8KMKXITgCKgOAMAOTG9+GuEAo3UHEWaNN9DEQszAAAAQKNzVqAxYVGgadXc1EiZSaKhDUBjaa/3CwCAVmJyo0sjpje7lIyW69ve9NLLe6HeCMv30wYj5wFDYlk+AAAAYOmqJ2NGnJlkJmnomt1wxFnp1mSZSdqWQxohM0kFzE0UZgA0Ie45E4P5yQCykOssZSnz+9AEKVLocNWpVo/CjETIAJA/roFhg98XAFkwbWpzct9OySozSW5yU5Eyk+QmN1mvMMqxMCORmwC44+IamOJMDIIGgCw4CxpSocNGTT1Ch8vRAZkEDInCDIDC4hoYNvh9AZCVVijQ1JCZQlCYAVBQLq6BGWsGAHVgOt4sdqm+lMlyfSndqDM//0V/FsEjiznOieZK1yFkAAAAAM0o17HQktVoaCnbzCS5z01Z3fuGwgwAJMPKmRh0gQHIUiOsoJHcdoTZuH/K7nW5eWbSm30SMgA0C66BYYPfFwBZqktmkhomN9UKOA2Tm8hMAJqEi2vgsuPXBACwYHqRaHTRaXDxKmnoYtj0gnhY6dYUBYsUGipgEDIAAAAA5+qSmSTrzCTVJzftt/Kx3HNT4u/TYWYCgGZAcQYAGkSrho28ZB4wJAozAAAAQAKNVKCRmjc3pcpMjgsz5CYAzYDiDADUmc3FImHDrdr3kfkYM4l7zAAAAAApOC/QZDh5oKbZclMijjOTRGEGQPOgOAMABeD8otG2QJMybDRa4Ej9mm1/Zo4LM4QMAAAAtCKnBRopl8Y2qTEzk+QoN5miMAOgBVGcAYCCqGvQkFKFDan4hRpnr8/250RhBgAAAMhdZgUaB0WaomYmydFrzKCZTaIwA6D5UJwBgAIpRIEmZZFGKk7ocPo6kgQMxzeyJGQAAACg1TkfCy3l3tgmFSczOX8tGTSzSYyABtCcStUq7/REWbNmjSZNmqTe3l5NnDix3i8HQIu4qlQy2u+EnS1OutDyRRxlub+h6uHZnFfKMNhkFDAkur8AFBPXwLDB7wuAejDNTJJFbrLNTFImuSnLzCRllJuSFKwyKMyQmwDkxcU1cLvj1wQAyNFVT1oEjetlFzZqF9eOw0ZQEEgSPnLpMMswYEh0fwEAAABJnVCtGhdojHOTbWaSMslNYVnHNje1WmaiMAOg0bByJgZdYADqJZNOMClZN5iU2UqaQko6poCQAaBJcA0MG/y+AKinQuWmVspMEoUZAC3NxTUw95wBgILKZJaylPw+KI7uR1Noab5HQgYAAACQu0LlplbITFLy75PMBAAjUJwBgALLNGhQpNkmbVGGkAEAAAA0BBrbUqCZDQCcojgDAAWXWYFGSh42pOYIHGm/B8uAQcgAAAAA3LO9dqaxzcJNKmwzGwA0OoozANAAbAs0uYUNqfECR9pwIWUeMCjMAAAAAHYyLdBIbjJTI+amNCx/ZuQmAK2G4gwANIhChw2p+IHD1WsjYAAAAACFVPjMJBU7N7l6bQkaAMlNAFpRe71fAADA3AnVqq4qlYz3v+pJ6YSdLZ6gdgG90Opljea9mD8q5blcvAYXEgQxAgYAAACQr4bJTBK5aRi5CUCrojgDAE3OOmxIQxfULsKGFHyxn0XwyLLzjIABAAAANIzMCzSS2yKNNDrPZFWsySo35ZCZJHITgOZCcQYAGoxt0JAKEja8ogJBVAjJe+l/wrEFFGYAAACA+kpSoJHq3Njm1SiZSaIwAwAJUZwBgAaUW4FGyrZIE6QIs5dTzJKmMAMAAAAUQ1M0tgUpQmaScmtmk8hNAJpTud4vAACQTJKL06ueTHYhLMnNzS+LLsGNK2uS/GwJGAAAAEC2kuamRFLkiYaSMjfZIjcBaFYUZwCggSW9SCVs+KT8vggYAAAAQHHlWqCRyE0BkjYKkpsANDOKMwDQ4HIv0EjbLsobPXA4+B4IGAAAAEDx5T55QGqOzCTVpZlNIjcBaH7ccwYAmkCSWcpSipteeuU9X9kFBwGJgAEAAAA0ljS5icyUDLkJAMJRnAGAJlG7eK1L2JBGXrwXMXQ47FgjYAAAAACNqRCNbVIxM5NU96KMRG4C0DoozgBAk6lr2KgpSuhwPEKAgAEAAAA0vqSZSXLU2CY1bWaSyE0AYIriDAA0oUKEjRr/xX6WwSPDec4EDAAAAKB5pM1MksPclGdmCno+R1Ldn0fkJgCth+IMADSpQoUNr7AgYBNAcrypJgEDAAAAaE5pMpOUQWNbTYNlJolmNgBIguIMADSxwoaNIDmHhzhpizISIQMAAAAoOheZScopNxUsM0k0swFAGhRnAKDJNVTYKAgKMwAAAEDrqF27k5vMkZkAID2KMwDQAtIWaKTWCBsEDAAAAKB1kZviuchMErkJACSKMwDQMlx0g0nNGTYIGAAAAAAkNwUaqflyk6vMJJGbAKCG4gwAtBjXYUNqzMDhMlxIBAwAAACgWbjKTFLjF2koygBAdijOAEALchk2pMYKHK6LMhIhAwAAAGg2riYP1JCZyEwA4EdxBgBalOuwIRV3NU0W4UIiYAAAAADNLqvGNqk1MpNEbgKAMBRnAKDFuQ4bNfUOHVmGC4mAAQAAALSKLBrbpNGZJe/cRGYCgPqiOAMAyCxs1ARd9LsMHlmHCi8CBgAAANCasmpsq8m6WENuAoBioTgDANgq67DhlWcwcIWAAQAAALS2rBvbvMhMANDcKM4AAEbIM2w0CgIGAAAAAK88G9saAZkJAOxRnAEABKJIQ8AAAAAAEI7MNITcBADJlOv9AgAAxdaqF9qt+n0DAAAAsHNCtdqS+aFVv28AcIWVMwCAWK3UEUa4AAAAAJBEq+QmMhMAuEFxBgBgrFnDBuECAAAAgCvkJgCACYozAABr3ovyRg4chAsAAAAAWWmWIg25CQCyQXEGAJBKIwYOwgUAAACAvDRicxuZCQCyR3EGAOBE0QMH4QIAAABAvRW5uY3MBAD5ojgDAHCuKIUawgUAAACAIvJnlXrlJjITANQPxRkAQKaCLvazCB6ECgAAAACNKq9iDbkJAIqD4gwAIHcEAgAAAAAIR2YCgOZXrvcLAAAAAAAAAAAAaCUUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAAByRHEGAAAAAAAAAAAgRxRnAAAAAAAAAAAAckRxBgAAAAAAAAAAIEcUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAAByRHEGAAAAAAAAAAAgRxRnAAAAAAAAAAAAckRxBgAAAAAAAAAAIEcUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAAByRHEGAAAAAAAAAAAgRxRnAAAAAAAAAAAAckRxBgAAAAAAAAAAIEcUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAABy1F7vF1B01WpVkrRmzZo6vxIAAAAgH7Vr39q1MBCFzAQAAIBW4yIzUZyJsWLFCknS7Nmz6/xKAAAAgHytWLFCkyZNqvfLQMGRmQAAANCq0mQmijMxpkyZIkl6/vnnCaYZWLNmjWbPnq0XXnhBEydOrPfLaTr8fLPFzzdb/Hyzxc83W/x8s8fPOFu9vb2aM2fO1mthIAqZKXv8m5ctfr7Z4uebLX6+2eLnmy1+vtni55stF5mJ4kyMcnnotjyTJk3ilzhDEydO5OebIX6+2eLnmy1+vtni55stfr7Z42ecrdq1MBCFzJQf/s3LFj/fbPHzzRY/32zx880WP99s8fPNVprMRNoCAAAAAAAAAADIEcUZAAAAAAAAAACAHFGcidHV1aXPfvaz6urqqvdLaUr8fLPFzzdb/Hyzxc83W/x8s8XPN3v8jLPFzxc2+H3JHj/jbPHzzRY/32zx880WP99s8fPNFj/fbLn4+Zaq1WrV4WsCAAAAAAAAAABABFbOAAAAAAAAAAAA5IjiDAAAAAAAAAAAQI4ozgAAAAAAAAAAAOSI4gwAAAAAAAAAAECOKM6k9PTTT2v8+PEqlUr6wAc+UO+X0/B++ctf6pRTTtH8+fM1YcIEjR8/Xnvuuac++tGPavHixfV+eQ2tv79fl156qd71rndp99131/jx4zVhwgQdeOCBuuCCCzQwMFDvl9jw7r//fp133nlauHChpk2bplKppCOOOKLeL6vh3H333Tr++ONVqVQ0btw4HXTQQbr44ovr/bKawi9+8Qu9//3v12te8xp1dXWpVCrpJz/5Sb1fVtNYvHixvv71r+u4447TnDlz1NnZqRkzZui0007Tn//853q/vIa3adMmfexjH9Nhhx2mWbNmqbu7WzNmzNAhhxyiH//4x+rv76/3S2w6X/nKV1QqlVQqlXTXXXfV++WggZGZ3CIzZYfMlD0ykzvkpuyQm7JDZsoWmak+0uSm9oxeU0sYHBzU2WefXe+X0VQuuugiPfnkkzrooIM0c+ZMVatV3X///frGN76hn/zkJ7r99tu155571vtlNqRFixbp9NNP1/jx43X00UfrxBNPVG9vr6666ip98IMf1LXXXqsrr7xSpVKp3i+1Yf32t7/Vl770JXV2dmqXXXbR8uXL6/2SGs7NN9+shQsXqru7W2eeeaYmTJigSy+9VGeccYZeeOEFnXPOOfV+iQ3tU5/6lJ77/+3de0xX9R/H8ZeE4AZhFiyZbuAF0krDy7wsFbHy1k1T8zIRNJsz25xWTs2iXEQsl39UK5p5KeeFZbFS0CwlplBrQg5tbNjUVl7+IMy+pF9IPr8/miwH/ZbwPefzPYfn47/v5xzna2dnX85r73PO9+xZxcfHKzExUWfPnrUdyVfefvtt5efnq1+/fpo4caISEhJUW1uroqIiFRUVaceOHZo9e7btmJ4VCAT03nvvacSIEXr44YeVkJCg+vp6lZSUaNGiRdq1a5dKSkoUEcG9R6Fw4sQJ5eTkKCYmRg0NDbbjwMPoTKFHZ3IOncl5dKbQoDc5i97kHDqTs+hM7utwbzJotw0bNpjIyEizceNGI8ksWbLEdiTPu3LlSpvrmzZtMpLMzJkzXU7kH7/88ot59913TSAQuGE9EAiY4cOHG0mmsLDQUjp/OHHihDl27JhpbGw058+fN5JMenq67Vie0dTUZPr162eio6NNVVVVy/qlS5dMamqqiYqKMmfOnLEX0AcOHjzYcgzz8vKMJLNlyxa7oXxkz549prS0tNV6WVmZ6dq1q+nRo4e5evWqhWT+cO3aNRMMBlutNzU1mfHjxxtJZu/evRaS+U9jY6MZOnSoGTlypJk/f76RZCoqKmzHgkfRmUKPzuQcOpPz6EwdR29yHr3JOXQmZ9GZ3BWK3sSYrJ1qamq0bt06rVmzRmlpabbj+Ea3bt3aXJ81a5Yk6dSpU27G8ZVevXrpmWeeUUxMzA3rMTExWrlypSTpm2++sRHNN+655x4NHTpUXbt2tR3Fkw4dOqSffvpJ8+bNu+F7tXv37lq7dq0aGxu1bds2ewF94MEHH1RSUpLtGL71xBNPKD09vdX62LFjlZGRofr6elVXV1tI5g8RERGKiopqtR4ZGanp06dL4johVHJzc3Xy5Elt3rxZt9xyi+048DA6kzPoTM6hMzmPztRx9Cbn0ZucQ2dyFp3JXaHoTQxn2uHatWvKyspSSkqK1q1bZztOp7Bv3z5J0r333ms5iT9dvzCOjORNh7CntLRUkjRx4sRW2yZNmiSJMgzv4nvWOc3Nzdq/f78krhNCobKyUrm5ucrJydHdd99tOw48jM7kPjqTs/hbjnBBb4Jf8T3rHDpT6IWqN3G2t0NeXp4qKyv17bfftjmNRMcVFhbqxx9/1J9//qmTJ0/qwIED6tOnj9avX287mi9t3rxZUtsXd4BbamtrJUkpKSmttvXs2VOxsbEt+wBe8vPPP+urr75SYmKiBg0aZDuO5zU2Nur111+XMUZ1dXX6+uuvVVNTo4ULF+qBBx6wHc/TgsGgFixYoLS0NK1atcp2HHgcncl5dCZ30ZkQLuhN8CM6U2jRmZwVyt7EcOYmHT9+XOvXr9cLL7ygYcOG2Y7jW4WFhdqzZ0/L5+HDh2vXrl3q06ePxVT+9MEHH6ikpEQTJkzQ1KlTbcdBJ/b7779L+vtx/LbExcW17AN4RVNTkzIzMxUMBpWfn88rokKgsbFRr776asvnLl266Pnnn1deXp7FVP7w8ssvq7a2VseOHeNcRYfQmdxBZ3IPnQnhhN4Ev6EzhR6dyVmh7E2dcjjz3HPPKRgM/uf9ly9frpSUFDU2NiorK0v9+/dXTk6Ogwm9rb3H958++eQTSdKlS5dUVVWlF198UcOGDdOnn36qCRMmhDSv14Ti+F63d+9ePfvss0pKStL27dtDFdHTQnl8AXRuzc3Nys7OVllZmZ5++mllZmbajuQLsbGxMsaoublZ586d0xdffKG1a9eqoqJCxcXFiouLsx3RkyoqKrRhwwa98sorvOoAkuhMTqMzOYvO5Cw6E4BQoTM5g87knFD3pk45nCkoKFBDQ8N/3n/mzJlKSUlRXl6eqqurVV5erujoaAcTelt7j29bbrvtNmVkZGj//v266667tGDBAp0+fbpT/3hgqI5vcXGxZs6cqTvvvFOHDh1SYmJiKGN6VijPX9yc63d+/dtdXpcvX1aPHj3cjAS0W3NzsxYtWqQdO3Zo/vz5ev/9921H8p2IiAj17t1bS5cuVXx8vJ588knl5uYqPz/fdjTP+euvv5SVlaXBgwdr9erVtuMgTNCZnEVnchadyVl0JrvoTfALOpPz6Eyh5URv6pTDmUAg0K5/V1VVpebmZo0aNarN7QUFBSooKNDjjz+uoqKiDiT0tvYe3/8nLi5Oo0aNUlFRkU6dOqWBAweG/P/wilAc33379mnGjBmKj4/X4cOH1bdv3xAk8wcnzl/8N9cLW21tbatXoFy4cEGBQEAjRoywEQ24Kc3NzVq4cKE++ugjzZ07V1u3blVERITtWL52/f3/138gFzcnEAi0vJv+334bZPTo0ZKkzz77TNOmTXMrGiyiMzmLzuQsOpOz6Ex20ZvgB3Qm99GZOs6J3tQphzPt9dBDDyk+Pr7V+vnz51VcXKwBAwbo/vvv15AhQyyk879z585JUqe+AywUrpeM22+/XYcPH1b//v1tRwIkSenp6crLy9OXX36pOXPm3LDtwIEDLfsA4eyfJWP27Nn6+OOPeWeyC7hG6Jjo6Gg99dRTbW4rKytTbW2tHnvsMSUkJCg5OdndcPAcOpNdfB+GBp0J4YzeBK+jM9nBNULHOdKbDDrs8OHDRpJZsmSJ7SiedvnyZVNTU9Pmtg8//NBIMikpKS6n8pfi4mITHR1tevbs+a/HGqFx/vx5I8mkp6fbjuIZTU1Npm/fviY6OtpUVVW1rF+6dMmkpqaaqKgoc/r0aWv5/CYvL89IMlu2bLEdxTeuXbtmsrKyjCQza9Ys09TUZDuSr5w8edI0NDS0Wm9oaDCTJ082kkxubq6FZP52/ZyuqKiwHQUeR2cKDTqT8+hM7qEztQ+9yV30ptCiMzmLzmRPe3sTT84gbNTV1WngwIEaPny4BgwYoF69eqm+vl7ff/+9KisrFRcXp23bttmO6Vk1NTWaPn26gsGgxo8fr507d7baJzk5WdnZ2e6H84mamhq98cYbkqQrV660rP3zmG7dutVCMm+IjIzUpk2bNGnSJI0bN05z5szRrbfeqj179ujs2bPasGEDd2x30KZNm3TkyBFJUnV1dcva9ceax4wZo8WLF9uK53nr16/Xtm3bFBsbq9TUVL322mut9pk2bZrS0tLcD+cDhYWFeuuttzRmzBglJycrLi5Ov/76q0pKSlRXV6exY8dqxYoVtmMCgKPoTM6iMzmPztRx9Cbn0ZucQ2dyFp3JexjOIGwkJCTopZdeUmlpqQ4ePKi6ujpFRUUpOTlZK1as0MqVK9W7d2/bMT3rwoULCgaDkqRdu3a1uU96ejpFowMuXLjQqgxfvHjxhjWKxv+XkZGhI0eOKCcnR7t371ZTU5MGDRqk/Px8zZ4923Y8zzty5Eirc/To0aM6evRoy2dKRvudOXNG0t/voc3NzW1zn+TkZIpGOz3yyCM6d+6cysvLVVFRoUAgoO7du2vw4MGaM2eOFi1apMhILm0B+BudyVl0JufRmUKD3uQsepNz6EzOojN5TxdjjLEdAgAAAAAAAAAAoLOIsB0AAAAAAAAAAACgM2E4AwAAAAAAAAAA4CKGMwAAAAAAAAAAAC5iOAMAAAAAAAAAAOAihjMAAAAAAAAAAAAuYjgDAAAAAAAAAADgIoYzAAAAAAAAAAAALmI4AwAAAAAAAAAA4CKGMwAAAAAAAAAAAC5iOAMAAAAAAAAAAOAihjMAAAAAAAAAAAAuYjgDAAAAAAAAAADgIoYzAAAAAAAAAAAALmI4AwAIC8YYTZ06VV26dNHu3btbbZsyZUqb2wAAAACgs6A3AYB/dDHGGNshAACQpIsXL2rw4MEKBoM6fvy4kpKSJEkbN27UypUrlZ2drS1btlhOCQAAAAD20JsAwB8YzgAAwsr+/fs1depUjR49WmVlZaqurtbIkSOVlJSkyspKxcbG2o4IAAAAAFbRmwDA+3itGQAgrEyePFnLly9XeXm5Vq9erblz58oYo507d1IwAAAAAED0JgDwA56cAQCEnWAwqFGjRumHH36QJOXn52vVqlV2QwEAAABAGKE3AYC38eQMACDsREdHa8qUKZKkbt26afHixZYTAQAAAEB4oTcBgLcxnAEAhJ3vvvtOb775pu644w5dvXpVS5cutR0JAAAAAMIKvQkAvI3hDAAgrPzxxx+aN2+eIiMjVVpaqhkzZqiwsFCbN2+2HQ0AAAAAwgK9CQC8j9+cAQCElczMTG3fvl3vvPOOli1bpvr6et1333367bffVFlZqdTUVNsRAQAAAMAqehMAeB/DGQBA2Ni+fbsyMzP16KOP6vPPP29ZLysrU0ZGhoYMGaKKigp17drVYkoAAAAAsIfeBAD+wGvNAABh4fTp01q2bJkSExNbPYo/btw4rVmzRseOHdPatWstJQQAAAAAu+hNAOAfPDkDAAAAAAAAAADgIp6cAQAAAAAAAAAAcBHDGQAAAAAAAAAAABcxnAEAAAAAAAAAAHARwxkAAAAAAAAAAAAXMZwBAAAAAAAAAABwEcMZAAAAAAAAAAAAFzGcAQAAAAAAAAAAcBHDGQAAAAAAAAAAABcxnAEAAAAAAAAAAHARwxkAAAAAAAAAAAAXMZwBAAAAAAAAAABwEcMZAAAAAAAAAAAAF/0Pd7FYUZ+2Qy0AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# Create a 3D plot\n", "fig = plt.figure()\n", "ax = fig.subplots(1,2)\n", "ax[0].set_xlabel('x')\n", @@ -186,7 +201,7 @@ "\n", "ax[1].set_xlabel('x')\n", "ax[1].set_ylabel('y')\n", - "ax[1].set_title('Rosenbrock Function - Adam Trajectories')\n", + "ax[1].set_title('Rosenbrock Function - Ademamix Trajectories')\n", "# Show the plot\n", "ax[1].plot([1], [1], 'x', mew=1, markersize=10, color='cyan')\n", "ax[1].contourf(X, Y, Z, np.logspace(-1, 3, 100), cmap='jet')\n", @@ -199,24 +214,69 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "fbdc5a55-e55a-47d1-9980-2b602de6ee3b", + "metadata": {}, + "source": [ + "## Print out final values" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "2cf96de0-cb01-4338-87b4-dd80f0498ebd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "AdeMAMix Values:\n", + "Final value with b3 = 0.999: ((1.0000168085098267, 0.9999828934669495))\n", + "Final value with b3 = 0.9999: ((1.0000070333480835, 0.9999932050704956))\n" + ] + } + ], "source": [ - "print(\n", - " all_ademamix_params_array[0,-1,0],all_ademamix_params_array[0,-1,1],\n", - " all_ademamix_params_array[1,-1,0],all_ademamix_params_array[1,-1,1]\n", - ")" + "print(\"Adam Values:\")\n", + "[0.9,0.99,0.999,0.9999]\n", + "print(f\"Final value with b1 = 0.9: ({float(all_ademamix_params_array[0,-1,0]),float(all_ademamix_params_array[0,-1,1])})\")\n", + "print(f\"Final value with b1 = 0.99: ({float(all_ademamix_params_array[1,-1,0]),float(all_ademamix_params_array[1,-1,1])})\")\n", + "print(f\"Final value with b1 = 0.999: ({float(all_ademamix_params_array[0,-1,0]),float(all_ademamix_params_array[0,-1,1])})\")\n", + "print(f\"Final value with b1 = 0.9999: ({float(all_ademamix_params_array[1,-1,0]),float(all_ademamix_params_array[1,-1,1])})\")\n", + "\n", + "print(\"AdeMAMix Values:\")\n", + "print(f\"Final value with b3 = 0.999: ({float(all_ademamix_params_array[0,-1,0]),float(all_ademamix_params_array[0,-1,1])})\")\n", + "print(f\"Final value with b3 = 0.9999: ({float(all_ademamix_params_array[1,-1,0]),float(all_ademamix_params_array[1,-1,1])})\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "66647d3e-81e2-4987-b5ef-81e08ac048dc", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(2, 100001, 2)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_b1_params_array.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da0bc169-0f2e-43d5-a1a4-a9e407102cbf", + "metadata": {}, "outputs": [], "source": [] } diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 138094a8d..7b31ffcb1 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -1,10 +1,18 @@ +"""AdeMAMix. + +Implementation of +"THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER" +(https://arxiv.org/pdf/2409.03137) by Matteo Pagliardini, +Pierre Ablin and David Grangier. +""" import optax.tree_utils as otu -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import NamedTuple, Optional import chex import jax.numpy as jnp import jax.tree_util as jtu from optax._src import base from optax._src import combine +from optax._src import numerics from optax._src import transform @@ -26,38 +34,37 @@ def scale_by_ademamix( b3_scheduler: Optional[base.ScalarOrSchedule] = None, alpha_scheduler: Optional[base.ScalarOrSchedule] = None, eps: float = 1e-8, - weight_decay: float = 0.0, ) -> base.GradientTransformation: """Rescale updates according to the Ademamix algorithm. References: - [Pagliardini et al, 2024](https://arxiv.org/pdf/2409.03137) + [Pagliardini et al, 2024](https://arxiv.org/pdf/2409.03137) Args: - b1: Exponential decay rate to track the first moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b2: Exponential decay rate to track the second moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b3: Exponential decay rate to track the first moment of past gradients - for the second EMA. - alpha: the coefficient that "blends" the two EMAs. paper states values in - :math:`[4,10]` work well in practice. - b3_scheduler: The schedule for the b3 parameter - alpha_scheduler: The schedule for the alpha parameter - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - weight_decay: Strength of the weight decay regularization. + b1: Exponential decay rate to track the first moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b2: Exponential decay rate to track the second moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b3: Exponential decay rate to track the first moment of past gradients + for the second EMA. + alpha: the coefficient that "blends" the two EMAs. paper states values in + :math:`[4,10]` work well in practice. + b3_scheduler: The schedule for the b3 parameter + alpha_scheduler: The schedule for the alpha parameter + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. Returns: - A `GradientTransformation` object. + A `GradientTransformation` object. Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, - the method is best suited to settings where the number of iterations is - important. The paper reports on this effect in App. C.1.5, showing how - smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is non-stationary. + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in App. C.1.5, showing how + smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is + non-stationary. """ def init_fn(params): @@ -116,54 +123,55 @@ def ademamix( Description Examples: - >>> import optax - >>> import jax - >>> import jax.numpy as jnp - >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> solver = optax.ademamix(learning_rate=0.003) - >>> params = jnp.array([1., 2., 3.]) - >>> print('Objective function: ', f(params)) - Objective function: 14.0 - >>> opt_state = solver.init(params) - >>> for _ in range(5): - ... grad = jax.grad(f)(params) - ... updates, opt_state = solver.update(grad, opt_state, params) - ... params = optax.apply_updates(params, updates) - ... print('Objective function: {:.2E}'.format(f(params))) - Objective function: 1.40E+01 - Objective function: 1.39E+01 - Objective function: 1.39E+01 - Objective function: 1.39E+01 - Objective function: 1.38E+01 + > import optax + > import jax + > import jax.numpy as jnp + > def f(x): return jnp.sum(x ** 2) # simple quadratic function + > solver = optax.ademamix(learning_rate=0.003) + > params = jnp.array([1., 2., 3.]) + > print('Objective function: ', f(params)) + Objective function: 14.0 + > opt_state = solver.init(params) + > for _ in range(5): + ... grad = jax.grad(f)(params) + ... updates, opt_state = solver.update(grad, opt_state, params) + ... params = optax.apply_updates(params, updates) + ... print('Objective function: {:.2E}'.format(f(params))) + Objective function: 1.40E+01 + Objective function: 1.39E+01 + Objective function: 1.39E+01 + Objective function: 1.39E+01 + Objective function: 1.38E+01 References: - Pagliardini et al, 2024: https://arxiv.org/pdf/2409.03137 + Pagliardini et al, 2024: https://arxiv.org/pdf/2409.03137 Args: - b1: Exponential decay rate to track the first moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b2: Exponential decay rate to track the second moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b3: Exponential decay rate to track the first moment of past gradients - for the second EMA. - alpha: the coefficient that "blends" the two EMAs. paper states values in - :math:`[4,10]` work well in practice. - b3_scheduler: The schedule for the b3 parameter - alpha_scheduler: The schedule for the alpha parameter - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - weight_decay: Strength of the weight decay regularization. + b1: Exponential decay rate to track the first moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b2: Exponential decay rate to track the second moment of past gradients for + the first Exponential Moving Average (EMA) - same as AdamW + b3: Exponential decay rate to track the first moment of past gradients + for the second EMA. + alpha: the coefficient that "blends" the two EMAs. paper states values in + :math:`[4,10]` work well in practice. + b3_scheduler: The schedule for the b3 parameter + alpha_scheduler: The schedule for the alpha parameter + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + weight_decay: Strength of the weight decay regularization. Returns: - A `GradientTransformation` object. + A `GradientTransformation` object. Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, - the method is best suited to settings where the number of iterations is - important. The paper reports on this effect in App. C.1.5, showing how - smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is non-stationary. + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in App. C.1.5, showing how + smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is + non-stationary. """ return combine.chain( scale_by_ademamix( diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index d440fb786..55824cafb 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -36,82 +36,82 @@ # Testing contributions coded as GradientTransformations _MAIN_OPTIMIZERS_UNDER_TEST = [ - dict(opt_name="acprop", opt_kwargs=dict(learning_rate=1e-3)), - dict(opt_name="ademamix", opt_kwargs=dict(learning_rate=1e-3)), - dict(opt_name="cocob", opt_kwargs={}), - dict(opt_name="cocob", opt_kwargs=dict(weight_decay=1e-2)), - dict(opt_name="dadapt_adamw", opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name="dog", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="dowg", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="momo", opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name="momo_adam", opt_kwargs=dict(learning_rate=1e-1)), - dict(opt_name="prodigy", opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='acprop', opt_kwargs=dict(learning_rate=1e-3)), + dict(opt_name='ademamix', opt_kwargs=dict(learning_rate=1e-3)), + dict(opt_name='cocob', opt_kwargs={}), + dict(opt_name='cocob', opt_kwargs=dict(weight_decay=1e-2)), + dict(opt_name='dadapt_adamw', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='dog', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='dowg', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)), + dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)), dict( - opt_name="schedule_free_sgd", + opt_name='schedule_free_sgd', opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), ), dict( - opt_name="schedule_free_adamw", + opt_name='schedule_free_adamw', opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), ), ] for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST: - optimizer["wrapper_name"] = None - optimizer["wrapper_kwargs"] = None + optimizer['wrapper_name'] = None + optimizer['wrapper_kwargs'] = None # Testing contributions coded as wrappers # (just with sgd as we just want the behavior of the wrapper) _MAIN_OPTIMIZERS_UNDER_TEST += [ dict( - opt_name="sgd", + opt_name='sgd', opt_kwargs=dict(learning_rate=1e-1), - wrapper_name="mechanize", + wrapper_name='mechanize', wrapper_kwargs=dict(weight_decay=0.0), ), dict( - opt_name="sgd", + opt_name='sgd', opt_kwargs=dict(learning_rate=1e-2), - wrapper_name="schedule_free", + wrapper_name='schedule_free', wrapper_kwargs=dict(learning_rate=1e-2), ), dict( - opt_name="sgd", + opt_name='sgd', opt_kwargs=dict(learning_rate=1e-3), - wrapper_name="reduce_on_plateau", + wrapper_name='reduce_on_plateau', wrapper_kwargs={}, ), ] # Adding here instantiations of wrappers with any base optimizer _BASE_OPTIMIZERS = [ - dict(opt_name="sgd", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="sgd", opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name="adam", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="adamw", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="adamax", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="adamaxw", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="amsgrad", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="lamb", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="lion", opt_kwargs=dict(learning_rate=1.0, b1=0.99)), - dict(opt_name="noisy_sgd", opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), - dict(opt_name="novograd", opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), + dict(opt_name='adam', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='lion', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), + dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), + dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1.0)), dict( - opt_name="optimistic_gradient_descent", + opt_name='optimistic_gradient_descent', opt_kwargs=dict(learning_rate=1.0, alpha=0.7, beta=0.1), ), - dict(opt_name="rmsprop", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="rmsprop", opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name="adabelief", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="radam", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="sm3", opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name="yogi", opt_kwargs=dict(learning_rate=1.0, b1=0.99)), + dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), + dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), ] # TODO(harshm): make LARS and Fromage work with mechanic. _OTHER_OPTIMIZERS_UNDER_TEST = [ dict( - opt_name=base_opt["opt_name"], - opt_kwargs=base_opt["opt_kwargs"], - wrapper_name="mechanize", + opt_name=base_opt['opt_name'], + opt_kwargs=base_opt['opt_kwargs'], + wrapper_name='mechanize', wrapper_kwargs=dict(weight_decay=0.0), ) for base_opt in _BASE_OPTIMIZERS @@ -124,223 +124,235 @@ def _get_opt_factory(opt_name): - """Get optimizer factory.""" - if hasattr(contrib, opt_name): - return getattr(contrib, opt_name) - if hasattr(alias, opt_name): - return getattr(alias, opt_name) - raise ValueError(f"Unknown optimizer: {opt_name}") + """Get optimizer factory.""" + if hasattr(contrib, opt_name): + return getattr(contrib, opt_name) + if hasattr(alias, opt_name): + return getattr(alias, opt_name) + raise ValueError(f'Unknown optimizer: {opt_name}') def _wrap_opt(opt, wrapper_name, wrapper_kwargs): - if wrapper_name == "reduce_on_plateau": - return combine.chain(opt, contrib.reduce_on_plateau(**wrapper_kwargs)) - else: - return getattr(contrib, wrapper_name)(opt, **wrapper_kwargs) + if wrapper_name == 'reduce_on_plateau': + return combine.chain(opt, contrib.reduce_on_plateau(**wrapper_kwargs)) + else: + return getattr(contrib, wrapper_name)(opt, **wrapper_kwargs) def _setup_parabola(dtype): - """Quadratic function as an optimization target.""" - initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) - final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) + """Quadratic function as an optimization target.""" + initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) + final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) - @jax.value_and_grad - def get_updates(params): - return jnp.sum(numerics.abs_sq(params - final_params)) + @jax.value_and_grad + def get_updates(params): + return jnp.sum(numerics.abs_sq(params - final_params)) - return initial_params, final_params, get_updates + return initial_params, final_params, get_updates def _setup_rosenbrock(dtype): - """Rosenbrock function as an optimization target.""" - a = 1.0 - b = 100.0 + """Rosenbrock function as an optimization target.""" + a = 1.0 + b = 100.0 - initial_params = jnp.array([0.0, 0.0], dtype=dtype) - final_params = jnp.array([a, a**2], dtype=dtype) + initial_params = jnp.array([0.0, 0.0], dtype=dtype) + final_params = jnp.array([a, a**2], dtype=dtype) - @jax.value_and_grad - def get_updates(params): - return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( - params[1] - params[0] ** 2 - ) + @jax.value_and_grad + def get_updates(params): + return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( + params[1] - params[0] ** 2 + ) - return initial_params, final_params, get_updates + return initial_params, final_params, get_updates class ContribTest(chex.TestCase): - @parameterized.product( - _ALL_OPTIMIZERS_UNDER_TEST, - target=(_setup_parabola, _setup_rosenbrock), - dtype=("float32",), + + @parameterized.product( + _ALL_OPTIMIZERS_UNDER_TEST, + target=(_setup_parabola, _setup_rosenbrock), + dtype=('float32',), + ) + def test_optimizers( + self, + opt_name, + opt_kwargs, + wrapper_name, + wrapper_kwargs, + target, + dtype, + ): + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + initial_params, final_params, get_updates = target(dtype) + + @jax.jit + def step(params, state): + value, updates = get_updates(params) + if ( + opt_name in ['momo', 'momo_adam'] + or wrapper_name == 'reduce_on_plateau' + ): + update_kwargs = {'value': value} + else: + update_kwargs = {} + updates, state = opt.update(updates, state, params, **update_kwargs) + params = update.apply_updates(params, updates) + return params, state + + params = initial_params + state = opt.init(params) + with self.subTest('Test that tree_map_params works'): + # A no-op change, to verify that tree map works. + state = _state_utils.tree_map_params(opt, lambda v: v, state) + + with self.subTest('Test that optimization works'): + + def f(params_state, _): + return step(*params_state), None + + (params, state), _ = jax.lax.scan(f, (params, state), length=30_000) + + if ( + opt_name in ['schedule_free_sgd', 'schedule_free_adamw'] + or wrapper_name == 'schedule_free' + ): + params = contrib.schedule_free_eval_params(state, params) + chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) + + @chex.all_variants + @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST) + def test_optimizers_can_be_wrapped_in_inject_hyperparams( + self, opt_name, opt_kwargs, wrapper_name=None, wrapper_kwargs=None + ): + """Checks that optimizers can be wrapped in inject_hyperparams.""" + # See also https://github.com/deepmind/optax/issues/412. + # When debugging this, make sure that options like weight decay or not + # are checked by asserting wehter such a value is None or not (see e.g. the + # logic in schedule_free_adamw). Some hyperparameters may not be supported + # by inject_hyperparams (e.g. warmup_steps). In that case (if you're sure + # you can ignore such hyperparameter), add the exception below. + if wrapper_name == 'reduce_on_plateau': + # TODO(vroulet): discuss adding support for reduce_on_plateau + # so removing all assertions in its definition + self.skipTest('reduce_on_plateau is not supported by inject_hyperparams.') + if wrapper_name is None: + factory = _get_opt_factory(opt_name) + hparams = opt_kwargs + else: + base_opt = _get_opt_factory(opt_name)(**opt_kwargs) + factory = getattr(contrib, wrapper_name) + factory = functools.partial(factory, base_opt) + hparams = wrapper_kwargs + opt = factory(**hparams) + + # Add here the hyperparameters that cannot be injected with + # inject_hyperparams. + static_args = [] + for uninjectable_hparam in ['warmup_steps', 'num_betas']: + if uninjectable_hparam in inspect.signature(factory).parameters.keys(): + static_args.append(uninjectable_hparam) + static_args = tuple(static_args) + opt_inject = _inject.inject_hyperparams(factory, static_args)(**hparams) + + params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))] + grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))] + + if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': + update_kwargs = {'value': jnp.array(1.0)} + else: + update_kwargs = {} + + state = self.variant(opt.init)(params) + updates, new_state = self.variant(opt.update)( + grads, state, params, **update_kwargs ) - def test_optimizers( - self, - opt_name, - opt_kwargs, - wrapper_name, - wrapper_kwargs, - target, - dtype, - ): - dtype = jnp.dtype(dtype) - opt = _get_opt_factory(opt_name)(**opt_kwargs) - if wrapper_name is not None: - opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - initial_params, final_params, get_updates = target(dtype) - - @jax.jit - def step(params, state): - value, updates = get_updates(params) - if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": - update_kwargs = {"value": value} - else: - update_kwargs = {} - updates, state = opt.update(updates, state, params, **update_kwargs) - params = update.apply_updates(params, updates) - return params, state - - params = initial_params - state = opt.init(params) - with self.subTest("Test that tree_map_params works"): - # A no-op change, to verify that tree map works. - state = _state_utils.tree_map_params(opt, lambda v: v, state) - - with self.subTest("Test that optimization works"): - - def f(params_state, _): - return step(*params_state), None - - (params, state), _ = jax.lax.scan(f, (params, state), length=30_000) - - if ( - opt_name in ["schedule_free_sgd", "schedule_free_adamw"] - or wrapper_name == "schedule_free" - ): - params = contrib.schedule_free_eval_params(state, params) - chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) - - @chex.all_variants - @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST) - def test_optimizers_can_be_wrapped_in_inject_hyperparams( - self, opt_name, opt_kwargs, wrapper_name=None, wrapper_kwargs=None - ): - """Checks that optimizers can be wrapped in inject_hyperparams.""" - # See also https://github.com/deepmind/optax/issues/412. - # When debugging this, make sure that options like weight decay or not - # are checked by asserting wehter such a value is None or not (see e.g. the - # logic in schedule_free_adamw). Some hyperparameters may not be supported - # by inject_hyperparams (e.g. warmup_steps). In that case (if you're sure - # you can ignore such hyperparameter), add the exception below. - if wrapper_name == "reduce_on_plateau": - # TODO(vroulet): discuss adding support for reduce_on_plateau - # so removing all assertions in its definition - self.skipTest("reduce_on_plateau is not supported by inject_hyperparams.") - if wrapper_name is None: - factory = _get_opt_factory(opt_name) - hparams = opt_kwargs - else: - base_opt = _get_opt_factory(opt_name)(**opt_kwargs) - factory = getattr(contrib, wrapper_name) - factory = functools.partial(factory, base_opt) - hparams = wrapper_kwargs - opt = factory(**hparams) - - # Add here the hyperparameters that cannot be injected with - # inject_hyperparams. - static_args = [] - for uninjectable_hparam in ["warmup_steps", "num_betas"]: - if uninjectable_hparam in inspect.signature(factory).parameters.keys(): - static_args.append(uninjectable_hparam) - static_args = tuple(static_args) - opt_inject = _inject.inject_hyperparams(factory, static_args)(**hparams) - - params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))] - grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))] - - if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": - update_kwargs = {"value": jnp.array(1.0)} - else: - update_kwargs = {} - - state = self.variant(opt.init)(params) - updates, new_state = self.variant(opt.update)( - grads, state, params, **update_kwargs - ) - - state_inject = self.variant(opt_inject.init)(params) - updates_inject, new_state_inject = self.variant(opt_inject.update)( - grads, state_inject, params, **update_kwargs - ) - - with self.subTest("Equality of updates."): - chex.assert_trees_all_close(updates_inject, updates, rtol=1e-5) - with self.subTest("Equality of new optimizer states."): - chex.assert_trees_all_close( - new_state_inject.inner_state, new_state, rtol=1e-5, atol=1e-5 - ) - - # Not testing with `without_device=True` because without_device set the - # variables to the host which appears to convert then the dtype, so we - # lose control of the dtype and the test fails. - @chex.variants(with_jit=True, without_jit=True, with_device=True, with_pmap=True) - @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST, dtype=("bfloat16", "float32")) - def test_preserve_dtype( - self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None - ): - """Test that the optimizers return updates of same dtype as params.""" - # When debugging this test, note that operations like - # x = 0.5**jnp.asarray(1, dtype=jnp.int32) - # (appearing in e.g. optax.tree_utils.tree_bias_correction) - # are promoted (strictly) to float32 when jitted - # see https://github.com/google/jax/issues/23337 - # This may end up letting updates have a dtype different from params. - # The solution is to fix the dtype of the result to the desired dtype - # (just as done in optax.tree_utils.tree_bias_correction). - # Otherwise, just make sure that all variables defined in the optimizer have - # the same dtype as the parameters. - dtype = jnp.dtype(dtype) - opt = _get_opt_factory(opt_name)(**opt_kwargs) - if wrapper_name is not None: - opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - fun = lambda x: jnp.sum(x**2) - - params = jnp.array([1.0, 2.0], dtype=dtype) - value, grads = jax.value_and_grad(fun)(params) - state = self.variant(opt.init)(params) - if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": - update_kwargs = {"value": value} - else: - update_kwargs = {} - updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) - self.assertEqual(updates.dtype, params.dtype) - - @chex.variants(with_jit=True, without_jit=True, with_device=True, with_pmap=True) - @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST, dtype=("bfloat16", "float32")) - def test_gradient_accumulation( - self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None - ): - """Test that the optimizers can safely be used with optax.MultiSteps.""" - # Checks for issues like https://github.com/google-deepmind/optax/issues/377 - # Should pass as long as test_preserve_dtype passes. - dtype = jnp.dtype(dtype) - opt = _get_opt_factory(opt_name)(**opt_kwargs) - if wrapper_name is not None: - opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) - opt = _accumulation.MultiSteps(opt, every_k_schedule=4) - - fun = lambda x: jnp.sum(x**2) - - params = jnp.array([1.0, 2.0], dtype=dtype) - value, grads = jax.value_and_grad(fun)(params) - state = self.variant(opt.init)(params) - if opt_name in ["momo", "momo_adam"] or wrapper_name == "reduce_on_plateau": - update_kwargs = {"value": value} - else: - update_kwargs = {} - updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) - chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) - - -if __name__ == "__main__": - absltest.main() + + state_inject = self.variant(opt_inject.init)(params) + updates_inject, new_state_inject = self.variant(opt_inject.update)( + grads, state_inject, params, **update_kwargs + ) + + with self.subTest('Equality of updates.'): + chex.assert_trees_all_close(updates_inject, updates, rtol=1e-5) + with self.subTest('Equality of new optimizer states.'): + chex.assert_trees_all_close( + new_state_inject.inner_state, new_state, rtol=1e-5, atol=1e-5 + ) + + # Not testing with `without_device=True` because without_device set the + # variables to the host which appears to convert then the dtype, so we + # lose control of the dtype and the test fails. + @chex.variants( + with_jit=True, without_jit=True, with_device=True, with_pmap=True + ) + @parameterized.product( + _MAIN_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') + ) + def test_preserve_dtype( + self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None + ): + """Test that the optimizers return updates of same dtype as params.""" + # When debugging this test, note that operations like + # x = 0.5**jnp.asarray(1, dtype=jnp.int32) + # (appearing in e.g. optax.tree_utils.tree_bias_correction) + # are promoted (strictly) to float32 when jitted + # see https://github.com/google/jax/issues/23337 + # This may end up letting updates have a dtype different from params. + # The solution is to fix the dtype of the result to the desired dtype + # (just as done in optax.tree_utils.tree_bias_correction). + # Otherwise, just make sure that all variables defined in the optimizer have + # the same dtype as the parameters. + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + value, grads = jax.value_and_grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': + update_kwargs = {'value': value} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + self.assertEqual(updates.dtype, params.dtype) + + @chex.variants( + with_jit=True, without_jit=True, with_device=True, with_pmap=True + ) + @parameterized.product( + _MAIN_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') + ) + def test_gradient_accumulation( + self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None + ): + """Test that the optimizers can safely be used with optax.MultiSteps.""" + # Checks for issues like https://github.com/google-deepmind/optax/issues/377 + # Should pass as long as test_preserve_dtype passes. + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + opt = _accumulation.MultiSteps(opt, every_k_schedule=4) + + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + value, grads = jax.value_and_grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': + update_kwargs = {'value': value} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) + + +if __name__ == '__main__': + absltest.main() From 7cb270a6fa78219a9ab6cdddacfabe0521a92571 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 14 Oct 2024 11:32:26 -0400 Subject: [PATCH 04/32] ran notebook in order --- examples/contrib/rosenbrock_ademamix.ipynb | 85 +++++++++++++++------- 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb index f89435321..825e928a0 100644 --- a/examples/contrib/rosenbrock_ademamix.ipynb +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "source": [ "# Recreate AdeMAMix Rosenbrock Plot from Paper\n", - "This notebook attempts to recreate the Figures 2(b) and 2(c) from the [AdeMAMix paper](https://arxiv.org/pdf/2409.03137)" + "This notebook attempts to recreate Figure 2 from the [AdeMAMix paper](https://arxiv.org/pdf/2409.03137)" ] }, { @@ -169,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "id": "69d8642f-dfcc-4fac-8f85-3ee1fbfa135f", "metadata": {}, "outputs": [ @@ -214,6 +214,51 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "e244977b-6418-41cb-aec4-0ced34aca838", + "metadata": {}, + "source": [ + "## Plot Figure 2a from Paper" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9db496a1-7b7d-44b3-a5f8-a662ea10bb5a", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABqgAAAMyCAYAAAAR60BPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxV1d7H8c8ZmQ6ggAIKIs5DzmbOw600LUutnO5VKystNW9d7XZ9Si3v0/SUpaXeVETSTE29Ug6pKThnOZvmjCgqKCIznPn5Y3MOHA8qqAjq7/167dfeZ62191kbieB891pLZbfb7QghhBBCCCGEEEIIIYQQQghxl6jLuwNCCCGEEEIIIYQQQgghhBDiwSIBlRBCCCGEEEIIIYQQQgghhLirJKASQgghhBBCCCGEEEIIIYQQd5UEVEIIIYQQQgghhBBCCCGEEOKukoBKCCGEEEIIIYQQQgghhBBC3FUSUAkhhBBCCCGEEEIIIYQQQoi7SgIqIYQQQgghhBBCCCGEEEIIcVdpy7sD9zuLxcK+ffsIDg5GrZY8UAghhBBCCCGEEEIIIYR4kNlsNlJSUmjRogVa7YMb0zy4d36X7Nu3jzZt2pR3N4QQQgghhBBCCCGEEEIIUYH89ttvPPzww+XdjXIjAVUZCw4OBpRvtNDQ0HLujRBCCCGEEEIIIYQQQgghytPFixdp06aNMz94UElAVcYc0/qFhoYSFhZWzr0RQgghhBBCCCGEEEIIIURF8KAvC/Rg370QQgghhBBCCCGEEEIIIYS46ySgEkIIIYQQQgghhBBCCCGEEHeVBFRCCCGEEEIIIYQQQgghhBDirpKAqoSmTZtGREQEnp6edOzYkQMHDpR3l4QQQgghhBBCCCGEEEIIIe5JElCVwKJFi/jnP//JlClT2LNnD3Xq1KFHjx5kZmaWd9eEEEIIIYQQQgghhBBCCCHuORJQlcAXX3zByJEjGTp0KI0bN2bu3LlYLBYWLVpU3l0TQgghhBBCCCGEEEIIIYS452hv5+T//ve/zJw5k71795KTk0NoaCht27bl008/JTw8/E718aYWLlzI1q1b2bNnD4cOHcJkMhEdHc0LL7xw3XN+//13Jk2axI4dOzCbzTRp0oS33nqL/v37u7QzmUzs27ePSZMmOcu0Wi1du3Zl586djBw5sqxuSwghhBBCCCGEEEIIUQyz2YzVai3vbgghBAAajQadTlfe3bjn3FJAZbfbGTlyJLNnz6Z27doMHDgQX19fLly4wObNm0lMTLyrAdW7775LYmIiQUFBhIaGkpiYeMP2cXFx9OjRA09PT2ffly9fzoABAzh37hz/+Mc/nG1TU1OxWq0EBwe7XKNq1aqcOnWqTO5HCCGEEEIIIYQQQgjhLjMzk9TUVIxGY3l3RQghXHh4eBAUFISfn195d+WecUsB1fTp05k9ezavv/4606dPR6PRuNRbLJabXmPBggV07tyZiIiIYuutVivTpk1j9OjR6PX6G15r7ty51K1bl4iICD7++GP+9a9/XbetxWLhlVdeQa1Ws2XLFpo3bw7AxIkTadOmDRMmTOC55567br+EEEIIIYQQQgghhBB3X2ZmJufPn8dgMBAUFIROp0OlUpV3t4QQDzi73Y7ZbCYjI4Pz588DSEhVQqUOqPLy8nj//fepVasW06ZNcwunQJkC70aSkpJ45ZVXCA0NJT4+3i0MstlsDBs2jO+++w69Xs/o0aNveL3HHnusxP3ftGkTp06d4sUXX3SGUwD+/v5MmDCBF154gZiYGCZOnAhAUFAQGo2GlJQUl+tcunSJkJCQEr+vEEIIIYQQQgghhBDi1qWmpmIwGAgLC5NgSghRoXh5eeHr60tSUhKpqakSUJWQurQnrF+/nqtXr9KnTx+sVisrVqzg448/5j//+Q8nT54s0TXCwsL4/vvvSUpKolu3bpw9e9ZZVzScGjp0KK+//nppu3hD8fHxAHTv3t2trkePHgBs3rzZWabX62nRogUbN250llksFuLj42nXrt0d7ZsQQgghhBBCCCGEEMKd2WzGaDTi7+8v4ZQQokJSqVT4+/tjNBoxm83l3Z17QqlHUO3ZswdQFv1q2rQpx48fd9ap1WrefPNNPvvss5tep2/fvnz//fcMGjSIrl27Eh8fT1hYGC+88AILFy7kr3/9K9HR0ajVpc7QbujEiRMA1K1b160uJCQEg8HgbOPw5ptvMnz4cFq1akXLli357LPP0Gq1DB48+LrvM2PGDGbMmIHJZLqj/RdCCCGEEEIIIYQQ4kFjtVoB0Ol05dwTIYS4PsfPKKvVKj+vSqDU6c+lS5cAmDp1Kv7+/vz2229kZWWxZcsW6tWrx+eff86sWbNKdK3nnnuOhQsXcvbsWbp168agQYNYsGABAwcOJCYm5o6HUwAZGRmAMqVfcfz8/JxtHAYPHszHH3/MhAkTaNGiBceOHWPdunU3HKY3atQojhw54hyxJYQQQgghhBBCCCGEuD0yekoIUZHJz6jSKfUIKpvNBihT361cuZJq1aoB0KlTJ3744QeaNWvG559/zmuvvVai6w0YMACLxcLf/vY3Tp8+TZ8+fVi4cGGxa1uVp7FjxzJ27Njy7oYQQgghhBBCCCGEEEIIIcQ9r9RDlBwjj1q3bu0MpxweeughatWqxalTp0hPTy/R9ex2O5s2bXK+Pnz4MCkpKaXtVok5+n/tKCmHzMzM646uEkIIIYQQQgghhBBCCCGEELev1AFV/fr1AahUqVKx9Y7yvLy8m17Lbrfz6quvMm/ePAYMGMDChQs5ffo03bp148KFC6XtWok41p66dp0pgOTkZLKzs4tdn0oIIYQQQgghhBBCCCHuB2fOnEGlUvHCCy+Ud1duqmbNmtSsWbO8uyGEKAOlDqi6desGwJ9//ulWZzabOXnyJD4+PlSpUuWG17Hb7YwYMYK5c+fSv39/vvvuO/7617+yYMECTp06Rbdu3bh48WJpu3dTXbp0AWD9+vVudevWrXNpI4QQQgghhBBCCCGEEBXNSy+9hEqlIjAwEKPRWN7dqVByc3P5/PPPGTx4MA0aNECtVqNSqThz5kx5d83N8ePH6d+/P0FBQXh5edGsWTNmzZqF3W4v1XWSkpIYMWIENWrUQK/XU61aNV588UXOnTtXRj0Xd0q3bt1o1KgRM2bMKO+ulItSB1S1a9eme/funDx5krlz57rUffzxx6Snp9O3b1+02usvb2W323nttdeYM2eOM5xyrDk1aNAgl5AqOTm5tF28oUcffZRatWqxaNEi9u/f7yzPyMjgww8/RK/XM3To0Dv6nkIIIYQQQgghhBBCCHEnZGVlsXTpUlQqFWlpaaxcubK8u1ShXLp0iXHjxvH999+Tn59P5cqVy7tLxTpy5Aht2rQhNjaWnj178sYbb2C1Wnn99dd54403SnydU6dO0apVK2bPnk3Dhg0ZO3Ysbdq0ISYmhtatW3Pq1KkyvAtxu+Li4jhy5AijRo0q766Ui+unSDcwc+ZM2rdvzyuvvMLKlStp0KAB+/btY9OmTURERPB///d/Nzz/woUL/Pe//+X555/nu+++cwuzBg0ahN1uZ+jQofzyyy/87W9/u+H15s6dy7Zt2wA4dOiQsyw+Ph6Ajh078vLLLys3rNUyd+5cevToQefOnRk4cCC+vr4sX76cxMREPvvsMxkyKoQQQgghhBBCCCGEqJCWLFlCTk4Ob731Fl9++SVRUVEMGDCgvLtVYQQFBbF+/XpatWpFQEAATzzxhHPmrIrktddeIyMjgzVr1tCzZ08ApkyZwmOPPcbXX3/N4MGDadeu3U2vM3bsWC5dusS0adNcgq0ffviB/v37M2rUKH7++ecyuw8hbkepR1CBMopq9+7dvPDCC+zZs4fp06dz4sQJRo0axW+//UZISMgNz69evTo7d+5k0aJF1x1pNXjwYI4cOXLTcApg27ZtxMTEEBMTw969ewHYvn27s8wRXjl069aNbdu20aFDB5YsWcKsWbMIDg5m8eLF/OMf/yjhV0EIIYQQQgghhBBCCCHurqioKLRaLW+//TbdunVj48aNJCYmFtvWarXyySefUKdOHTw9PalTpw4fffQRNput2PZxcXG89NJL1K9fH4PBgMFgoHXr1syePbvY9iqViq5du3L+/HkGDx5MUFAQvr6+PPnkk5w+fRpQlorp06cPAQEB+Pr68txzz5GSklLq+05PT2fEiBGEhITg6elJixYt+P77793aGQwGHn/8cQICAkr9HnfL8ePH2bJlC926dXOGUwB6vZ4pU6YAMGfOnJteJz8/n3Xr1hEcHMyYMWNc6p5//nmaN2/OunXrnP8WQlQ0tzSCCiA8PJzo6OhbfuNatWrdtE29evVKdK358+czf/78Ur1/mzZtWLt2banOEUIIIYQQQgghhBBCiPJy5MgRfv31V3r16kVwcDBDhw5l48aNREdHM3nyZLf2r776KvPmzSMyMpJRo0aRn5/P1KlT2bFjR7HX/+STTzh58iRt27alb9++pKen8/PPPzNixAiOHTvG559/7nbO1atX6dixIyEhIQwbNozjx4+zatUqjh49SmxsLJ06daJVq1a89NJL7Nmzh+XLl5OWlsamTZtKfN8mk4nHHnuM7OxshgwZQk5ODkuXLmXw4MGkpqa6hTMVnWPmr+7du7vVdezYER8fHzZv3nzT61y5cgWLxUJERAQqlcqtPjIykv379xMXF1eiz+OFuNtuOaASQgghhBBCCCGEEEKI8mS328kzW8u7GzflpdMUGyCUVlRUFABDhgwBoF+/frz++utER0czceJE1OrCCbPi4+OZN28ezZo1Y/v27fj4+AAwYcIEmjdvXuz1Z82aRWRkpEuZxWKhV69eTJs2jbFjx1KjRg2X+oMHD/Lmm28ydepUZ9nrr7/OrFmz6NSpE5MnT2bs2LGA8u/11FNPsWbNGvbu3UvLli1LdN8XL16kbt267NixA71e77yPFi1aMH78ePr160f16tVLdK3SiI+Pd4ZJJVGzZk1eeOGFm7Y7ceIEAHXr1nWr02g0REZGcuTIESwWy3VnIAOoXLkyGo2GxMRE7Ha72/dYQkICoIzYEqIikoBKCCGEEEIIIYQQQghxT8ozW2k0seKtL3StIx/0wFt/ex/Fms1mFixYgJ+fH3369AGU6ez69u3LwoUL+eWXX1xG5Hz77bcATJw40RlOgbL8ytixY3nvvffc3uPacApAq9UycuRINmzYQFxcHMOGDXOpNxgM/Pvf/3YpGzRoELNmzSIwMNBlXSSVSsXAgQNZs2YNBw4cKHFABfDhhx86wymAsLAw532U1dIt8fHxvP/++yVu36VLlxIFVBkZGQD4+/sXW+/n54fNZiMrK4vKlStf9zre3t507tyZuLg4Zs6cyahRo5x1K1asYP/+/YAyPaIQFdEtrUElhBBCCCGEEEIIIYQQ4u6JjY3l8uXLPP/883h6ejrLhw4dChSOrnI4cOAAAJ06dXK7VnFlAFlZWUyaNIlmzZphMBhQqVSoVCqeffZZAC5cuOB2Tt26dfH29nYpCw0NBaBp06Zuo3ocdcVd63q0Wi3t2rW77n3s27evxNcqjcmTJ2O320u8lWa01Z3yxRdfYDAYGD16NE888QRvv/02/fr14/nnn6dp06YALiPrhKhIZASVEEIIIYQQQgghhBDinuSl03Dkgx7l3Y2b8tJpbvsajgDKEUg5PProo1SvXp3Y2FjS0tIICAgAlFE6arWaoKAgt2sFBwe7lZlMJrp27crevXtp0aIFQ4YMITAwEK1Wy5kzZ4iJicFoNLqd5+fn51bmmJbuRnVms/lmt+wUFBRUbMjiuA/HiKR7hWPk1PX6nZmZiUqlwtfX96bXatasGb///juTJk0iLi6OuLg46tSpwzfffEN6ejrjx4+natWqd7T/QtwpElAJIYQQQgghhBBCCCHuSSqV6ranzrsXnDt3jvXr1wPKNHLXs3DhQueUev7+/thsNlJTU6lSpYpLu5SUFLdzY2Nj2bt3L8OHD2fu3LkudYsXLyYmJuZ2b+OWpaamYrPZ3EIqx31cb6q821VWa1A51p5yrEVVlNVqJSEhgcjIyBuuP1VUgwYNWLJkiVu5oy+tW7cu0XWEuNvu/5/eQgghhBBCCCGEEEIIcQ+bP38+NpuNjh07Ur9+fbd6i8VCTEwMUVFRzoCqWbNm7N27l61bt9KvXz+X9lu3bnW7xqlTpwB45pln3OqKa383WSwWdu7cSYcOHVzKHf1q0aJFmbxvWa1B5QgZ169fzzvvvONSt23bNnJycm4YRJZEVlYWP/30E4GBgTz++OO3dS0hyooEVEIIIYQQQgghhBBCCFFB2e12oqOjUalUxMTEUKtWrWLbHT9+nJ07d7J7925at27NkCFDiI6O5oMPPqBHjx74+PgAcP78eaZNm+Z2fkREBKAEJL1793aWb968mTlz5pTBnZXOhAkT2LBhA3q9HoCkpCSmTZuGh4cHAwcOLJP3nDx5MpMnT77j161fvz6dO3cmLi6OtWvX0rNnT0CZZvG9994D4OWXX3Y5JzU1ldTUVIKCglymbczLy0On07mMtjIajQwfPpy0tDSmTZvmsmaZEBWJBFRCCCGEEEIIIYQQQghRQW3atImEhAS6dOly3XAK4MUXX2Tnzp1ERUXRunVrunXrxosvvkh0dDRNmjShb9++GI1GlixZQtu2bVm1apXL+b1796ZmzZp8+umn/PHHHzz00EMcO3aMVatW0bdvX5YtW1bWt3pdoaGh5OTk0LRpU3r37k1OTg5Lly7lypUrTJ8+nerVq7u0HzduHKmpqQAcOnTIWWYwGAAl/OnYsePdvYlrzJw5kw4dOtCnTx8GDBhAaGgoq1ev5vDhw4wePZr27du7tP/66695//33mTRpkktotmfPHvr168fjjz9OeHg4mZmZrF69mrNnz/LKK68wZsyYu3xnQpScBFRCCCGEEEIIIYQQQghRQUVFRQHcdOq4AQMGMHbsWL7//numTp2Kl5cXc+bMoV69esyZM4evv/6asLAw3nrrLfr37+8WUBkMBjZt2sT48ePZsmUL8fHxNG7cmO+++47g4OByDaj0ej0bNmzgnXfeYcGCBaSnp9OgQQO++uorBg0a5NZ+2bJlJCYmupQtX77cedy1a9dyD6gaN27Mrl27ePfdd1m9ejU5OTnUq1ePGTNm8Nprr5X4OjVq1KBr165s3bqVlJQUvL29admyJVOnTuXZZ58twzsQ4vap7Ha7vbw7cT9LSkoiPDycc+fOERYWVt7dEUIIIYQQQgghhBDinpOfn09CQgKRkZEyXZkQosIq6c8qyQ0UMoJKVEgmi42MXBOZ2dlkZ2eSk51JbnYmedmZ5OdlYcrNxpKfjcmYh81sRIcFHRb0KjM+GisGrQ1PnQYPvQ5fTw98vTzw9fbAZleRZUbZTJBrUZFr1ZBrVZFrUZFj1WCyqTHZNZjQKnu7GpNNgwkNVnTYVBpsah02tRa7WocFZY9KjVqtRqVSoVaBVqNCo1ajVauUTaNCo1ahLSjTaArKi7zWqdVKG02R8iKvNWoVuqLX0ajQadToNYXHuoK9VqNCr1E7j3VqNWq1qrz/aYUQQgghhBBCCCGEEEIICajE3We32Tj1v63R2K2osaC221BjRWO3osGK2m5Fh5nKGKmist3x9zcAoXf8qgqTXYMZLRaUvRktFrsGK2qsqLGhxoYKW8FrR5nj2I4Kq921vLBehQU1RuWrhsWuxYoaMxqsaFz2yntq3OqsKg12tRZUWuwaHag02NU61BqtUq7WodJoUWl0ULBXaxxlejQaHSqtFrVWh0ajR6PVotLq0Wm16NQqdFolVNNr1WjVSljmONZrlU2nUeGhVaPXaJyvHXUeGg06raogcFOX0b+SEEIIIYQQQgghhBBCiPImAZW461RqNZGW02hU15ldsphBPma0GFVemDWeWDReWLVe2LXeqHSeoPUg36YlLR+u5EOaEXKtGgA02NAUREIa7GhUtoJxUFa0WNFiQY9FOVZZ0WPBQ23DQ21DX/BaqypobzcroZrdgsZuRoV7/5VzrDe9nwrBkf1ZAfPtXcpqV2FBg+U6YZl7nZpctGQW1JnQYSrYm+1aTAXhnlWtx6rWKSPWVHrsGh1WtR40euwaPXa1HrR6UOtRaZUNrQcqrR611gON1gO1zgO11gO13gOt1gON3gO9VusMxfQaNbqCvV6rxkOrjDrz0Krx0Knx0GqUY62EZkIIIYQQQgghhBBCCHGnSEAlysXhbnNRq5WROGqtFrVGGcWj1erx0OsxeHvj7euHxsMHdD7oNFp011wjITWHH3afY93hZE5dznGp02lUNK7mz0PV/ahTxUDtqgaqV/LCYrWRnGnkfHoeSVdzOX81j/PpeZy/mkdyZj62UqzIpsaGriDc8lBZ8NOr8NXb8dfZMejtGLRg0IFBa8NHB15aFV46Fd5a8NKCj16Fj1aFt06Fjw68dUq9XgXYrWCzuu7t9iJlFmWzFuxt5iKvlWO71YzdquxtVsdrc2G5zQJWx3nmwmvaLKhsFlT2gn3BsdpmQWVXxnRdS6Oyo8GCBxb3L9SdCOiKhml3gLlgCkczWkzoMNp15KPHiLLPuOa10a7DiA6TSo9F7YG1YLOoPbBrPLBpPbBrPLFrPQsCMk/QeaHWeaDSeaPWeaLx8EKj90Kv0xWEX47gS9k7wrHC8muOtRp0GhUqVUVNPIUQQgghhBBCCCGEEKLkJKAS5aJp1+du6bwco4XVhy7yw+5z/H7mqrNco1bRKqIyXetX4ZHIQBpX88NTpyn2GvVCir+22WojOSOfpKt5XMzIIy3HxNVcE2k5ZtJzTaTlmMjIM5NjspBrtJJrspJnVmMEcuzKyC2Mt3RbLvRaNX6eOvy8tPh76QqOdfh5FrwuKFOOtc56fy8dvp5adAWjfFQUZkPFfyVukc123VBMCbqsxb++JgRzCcesRuwWEzaLEatZ2WyOvcWE3aLU2y1GsDr2ZlTWwr3KZkZtNaG2mVDbzKhtZjR2MxqbsnpYUbqCUXHOf7DSZj52lLDsFkafmewa8tGTjwe5dg/y0JNXcJyJB3l4kGf3IBcP8tGT6zxW6ixqT8waL6waL6w6L2wab+xaT2w6b9D5oNZ5oNdqXMItT50GT50aT61GOdZr8HSWF9TpNAX1yrFHkTIJxoQQQgghhBBCCCGEEHeaBFSiwrPb7ew9e5Ulv59j1cGL5JqUsEGtgi71qtC3ZRhd6lXB3+vaMValo9OoCQ/wJjzAu8TnWG128sxWco0Wck1Wsgv2jhArx2hRjh11RgvZBeVZRjMZeWYy8yxk5pvJzDNjs4PJYiM120hq9q2lXd56TZFgS1skzFJCLr8iIde1IZivhxa1+iZBhFoNaj2gv6X+XY8KJUi7o2Gag80GVpPrVhByYclXXpvzlDJLwd6cB5Z8bOY8rKZ8rKY8rKY8bKZ8bOZc7GYj9oI2WPJROTarCbU1H43ViNpmRGszorEXjixTpoHMw4+82xtd5gjIri22q5SQCz15dg9y8CIbT3LshfscPLla5DjbsceLHLunS1kOnqhURcIsrSPAKhp6uYZdHtprg69r6q8Jw7x0Grz0GuexBGJCCCGEEEIIIYQQQtz/JKASFdalzHxW7DvP0t3nOF1kCr/IIB+ebx1GvxZhhPh7lmMPlZFbBg8tBo/b/0/JbreTbbSQmW8hM88RXpldX+crgVbhsZmsfOV1tlEJQXJNyuiuixn5pe6DSgW+HtobjtLy89Ti66mM1vIrGLXl56m0N3hq0dws4CoPajWoPUFX+u8XdcF2W/GnzVoQZBUEX+Y8MOcW7HOUvSm3yHGOs43dlIPNlIvNqOyVOuVclTkXtSUPtSUXtU0ZyqVR2TGQj4H8O7b+WY4j6DJ7kmP2JCfPi2y7J5n4kGX3IhMfMu3eZOFNht3b5XVmwWtzKf53o1Gr8CoItLz0aiXAcr7WFL4uclw04HKcUzT8uvZ8T52mYn6vCiGEEEIIIYQQQgjxgJCASlQouSYLG46ksGLvebaeuOxcE8pbr+HJJqH0fzic1hGV78vRFSqVqiD40VG9klepz7dYbWQbLW6jsq4XbGXmW4qEYGbyzTbsdpRALN8C5N3SfRg8tPh6ap3BlW9BoOXnVSTYKrJ3lDvKvPWa++/fV60BvY+ylVKJR5ZZLQXBVcHmCLNM2cpmdOyzbvI6G0xZ2I3ZqOzKEC0flREfjFS9jX8Wk0pPjspAjsqHLBzhlQ8Zdi/SbV6kWb3IsHuTbjdwFQPpJgPpRgNpGMjBkzuWthWh16qLCbjUroHWdQKu4gIy74JNOdbiJSGYEEIIIYQQQgghhBDXJQGVKBdmq42sfAtZ+WYSr+Ty58VMtp1MZVdCGiaLzdmudURl+rcOp1fT0DsySul+ptWoqeStp5L3rU29Z7RYiwm2LM4Ayxl8Fbx2/PtlFuzzzcq/W7bRQrbRwsWMW7sPjVrlDLh8PdyDLb8bBV4FI7o8tGUyUWDFptGCxg88/e7I5VR2uzLqqyCwujbAwpgF+ZlgzIT8jIKt6HGGUmfMBEBvN6G3p1GZtGLejBv+38im1mHWV8Kk8yNfV4l8rR85Gj+y1X5kq33JVPmSURBspdkNpFl9uGz1IdOiIc9sJd9kJc+sbI7vU1Cm0zRZbGTklXIhsVLw0KoLgiutM8Ty0mncy/QavHXaIgGXo1x73XMc680JIYQQQgghhBBCCHEvkk/8xV1nstio9+7a69bXCPCmb4vq9G1RnZpBpR9xIm6Nh1ZDFV8NVXw9bul8k8VGVkFw5RJgOUKvgtdZBaFXVr6yDldmXmG5xWbHarOTnmsmPdfMrY7i0mvVRcKswuCquMCr6AivCj9V4d2kUoHOS9mocuvXsVkLQqxM1+Cq2FArHfLSIS8N8q5CbhpYjahtZjzyL+ORfxnf0ry33gA+QeAXBD5VwCcIu3cVzF6BmDwCyddXJlenbNkaf3It6oIQy0pe0VCryHGeyabUF21TsM81WckzWcg1W7EXjP40WmwYLTau5t75EEynURUEV67Blpdei7dO416mLwy5vPUaWtaoXO7TpAohhBBCCCGEEEKIB5cEVOKu02vVeGjVGC02vHQaQv09aRjqR4salehavwq1qxjuvyneHgB6rZpAgweBhlsLuOx2O3lmqzPAyrxmhJZLsFWkvGjAlVWwDpfJYiM120hqtvGW78dHrykMtjyLX3vLUX7tCC5fTx0+9+NUhbdCrQGvyspWWna7Ml2hI6zKu+oaXuVdvX6d3Vo4veHVM85LqgB9wWa49v28KoN3YZil7KtAQDAYQsC3YPOpqoxYu2637eSbbeSaLEpoVRBe5Zos5JkcQZbyOtdsdZY5Ay6Xc4qUmazkmq1YC+Y+NVvtmK2OKTlLb9ZfW9KzSegtnSuEEEIIIYQQQgghxO2SgEqUi10THsXHQytTVAknlUpVMLJDS7DfrY3qsNrsZBvdg6vMogGX0XVk17UjvhxTwOWYrOSYrLc1VaHBQ6uM2PIoDLD8vZRRWv5eyqgt52tv13IvnQRcqFSFa3f5h5X8PLtdGaWVkwq5VyDncpEttchxQV1uKththYHXlRM365gSXPleE1wZgsE3FJVvCF6GYLwMwQQavG/rS+B+a3ZMVts1oZZ72OUItQrDriIhV0H7qn63FiYLIYQQQgghhBBCCHEnSEAlysWtrpMkxI1o1Cr8C0IgbmHADhROVVjcyK3rTVV47Zpc1oKpCjMK1vK6lakKdRqVMuWgV8HmWRBm3SjkKtJW+yCHvyoVePorW2Dtm7e3FYRTjuAqN7UwyMpOgawUyE4u2Kcoo7NyLikbh258be9A8KuuBGx+1cG/OviHFx77hoJGV4pbU+Gh1eCh1VDpzmZfQgghhBAKmw2sJmWz2wA7zvmLoeDYfs0eUGuVUeZqLah1yu84D/oDV0IIIYQQ4oYkoBJCiCLKcqpC5bUSWmXmXftaaZORZ8Zqs2O22rmSY+JKjumW+uGj1xQGVjcJtBzllbz0VPLW4anT3NJ73rPUavAJVDYa3LitzaqMyspKVrbs5CLHKZB1sTDIspmVtrlXIPlg8ddTqZWRV44QyyXICoNKNcE7QD7cEUIIIYTCnF9k/c6CzZwDpoLNnAum3IKy3ILXRcoteWA1FwZQjmOLqUgoZb1z/VWpC8MqtUY5dqxzqvMCnXfB5lW41/soew+/woeOPP3BqxJ4Vip8XYqHfIQQQrg7c+YMkZGRDBs2jPnz55d3d26oZs2agNJnIcT9RQIqIYS4g253qkK73U6uyaqEVvlmMnILg6vMghFZxYVcjrock/KBgmOKwgsZ+aXug6dO7QyrKnnrqOytHPt76alcUFbJW08lL2Vf2VuZotBD+wAEW2oNGKoqW2jT67ez2ZQ1sbIuQsZ5yEwq2J9X9hnnIPOCEmJlXVS287uLv5beAJUioHJEwb5mkeMI5UMcIYQQQtw7XKYjTlNGb+deUV7npbkHUPmZhcfWW19jtVzYbUqfy6LfOh9lDdGia4caqhQeO8uClb36AZ5hQAhxX3rppZeIjo4mICCACxcu4OEh07g75ObmMmvWLPbs2cPevXs5fvw4drudhIQEZ9hVURw/fpx3332XTZs2kZOTQ7169Rg5ciQjR44s1dIPSUlJTJkyhbVr15KcnExQUBA9evTggw8+IDw83K29zWZj5syZzJs3j6NHj6LVamnevDnjxo3j6aefvpO3KMQNSUAlhBAViEqlwsdDi4+Hlmp4lfp8i9XmHK11bZiVcc2IrQzHCK88M+kFr602O/lmG8nmfJIzSxduees1VPLS4e9dXJBVeFzZR1/QTgm/7su16NTqgg9LgiCkSfFtbDZlGkGX8CrJNcTKugimbLh0WNmK4x10TWhVEwLrKJuhqoy+EkIIIe4GqxmyLxWZFjhZeZ2TWiSAulI4utpmvo03c0xn7Ace/gVrdnoXjDwqONY5XhfdF5Rr9EU2XfHHWr0y2kmlLvK7hKrguGB/7e8YNivYLMrXwma55tis1FvNyigus2PLLRzp5XjtKDNmuod1eelgylLez5yjbJlJN/+SqXXgFwp+YcpI9WunYK5UQwm7hBDiHpGVlcXSpUtRqVSkpaWxcuVKBgwYUN7dqjAuXbrEuHHjAIiIiKBy5cqkpaWVc6/cHTlyhPbt25OXl0f//v2pVq0aq1ev5vXXX+fIkSN89dVXJbrOqVOnaN++PZcuXaJ79+4MGDCAEydOEBMTw5o1a9ixYwe1axcugWC32+nfvz/Lly+ndu3aDB8+HKPRSGxsLM888wxfffUVo0ePLqvbFtfo1q0bOp2OUaNGMWrUqPLuzl0nAZUQQtxHtBo1AT56AnxKv86b3W4ny2ghI9dMeq6Zq7km0vPMpOeanK8zcpUwy3F8NddERp4Zmx1yTVZyb2HUlp+nlkCDh7PfAd56Agx6AgteV/YpPA708cBLf5+M1FKrwTdY2aq3Kr6NOV8Jqq4mQvoZuHqm4DhR2eenF3zolVr8CCy9QVmHK6B2YWgVWFvZ5EMYIYQQ4uZsVmUq38zzRabyLQihsi4WTPGbrIRO2G96ORc6H2WKYe9A5YET70Blat+i09gVt+kNFXM0kFqjbNoyfoLfaikMrxwj0LIvFa4n6tiyi6wvajND+lllux6vylA5EgJqFdkKXvtUkYd+hBAVypIlS8jJyeGtt97iyy+/JCoqSgKqIoKCgli/fj2tWrUiICCAJ554gnXr1pV3t9y89tprZGRksGbNGnr27AnAlClTeOyxx/j6668ZPHgw7dq1u+l1xo4dy6VLl5g2bRpvvPGGs/yHH36gf//+jBo1ip9//tlZvnz5cpYvX06HDh3YsGEDXl7KA9IffvghrVu3Zty4cTz11FMVbrTZ/SouLo6wsLDy7ka5kYBKCCEEoIze8vNU1qcKDyj5eTabEmy5BFl5RUKu3IKQq6DMcZyRZ8ZuRxnxlW8hITWnRO/nqVMT6FMk0CpmcwZaBg/8PLWlGhZfoeg8IaiushUnL70wrHLs005D2inlAxhTNlw8oGzX8gooDK2C6kCVBspWuaby4ZIQQghxv7PblYAjM0kZxeyclrfgOCNJCaFKuiaTWgs+VZWHTwwhyt6nSpEAKkAZXe1dEErpSj9aXgAarfK19A5QAqSbsZoLQ8aio9Uzi4xgz7kMeVeV7cJe92vofaFKfaha8PtSlYbKsV91Ca6EEOUiKioKrVbL22+/zYEDB9i4cSOJiYlERES4tbVarXz22WfMmTOHpKQkwsLCGD58+HUDrbi4OBYsWMD27ds5f/48AA0aNODVV1/l1VdfdWuvUqno0qUL3333HePHj2f9+vUYjUY6d+7MV199Ra1atfjzzz/517/+xZYtWzCbzfTo0YMZM2YQHBxcqvtOT0/nn//8J7GxsaSnp9OwYUPefvttBg0a5NLOYDDw+OOPl+rad9vx48fZsmUL3bp1c4ZTAHq9nilTptC1a1fmzJlz04AqPz+fdevWERwczJgxY1zqnn/+eZo3b866des4ffo0tWrVAiA2NhaACRMmOMMpUIK9N998k7///e9ER0fz/vvv36nbFeK6JKASQghxW9RqFf5eOvy9dEQElvw8q81Oeq6Jq7kmrmSbSMsxkZZrIi3bxJUc5bVLXY4Jk9VGvtnG+fQ8zqfnleh99Fo1VQweBPl6UMWgJ8jgQRVfD4IMHkWO9QT5euDrcY+FWV6VlC20mXudxagEVldOKoHVlZNw5ZSyZV1Q1rhI+k3ZitJ4KIGYI7CqUl/ZB0TKYuRCCCHuLY4A6uqZIiORzygPcThCKEsJfp9Qa8G3mjJFnCEYfEOK7EMKAynvwIo5sulBp9FBpXBlux5TjvK9kXa6YEso3GecU6YVPL/bfcS6h1/h70qhzaBaCwhuLOGjEKJMHTlyhF9//ZVevXoRHBzM0KFD2bhxI9HR0UyePNmt/auvvsq8efOIjIxk1KhR5OfnM3XqVHbs2FHs9T/55BNOnjxJ27Zt6du3L+np6fz888+MGDGCY8eO8fnnn7udc/XqVTp27EhISAjDhg3j+PHjrFq1iqNHjxIbG0unTp1o1aoVL730Env27GH58uWkpaWxadOmEt+3yWTiscceIzs7myFDhpCTk8PSpUsZPHgwqampbuFMRRcfHw9A9+7d3eo6duyIj48Pmzdvvul1rly5gsViISIiotjPMyIjI9m/fz9xcXHOgCo5OdlZV1x7gE2bNklAJe4KCaiEEEKUC41aRaDBg0CDB3Wq3ry93W4nx2QtCLCMbuGVY7tSJNjKNlowWUoeaHlo1S4BVhVfvTPccpRX9fUg2M8TT10FH2Wk9YAq9ZTtWqYc5UOXKyeVLfUkXP4TLh9XPqhL+UPZilLrlNFWjieHgx+CkIeUda/upVBPCCHE/cWcrwROzvApsfD4amLhekU34lNVWY/Iv3rBGkUFx/7hyggZQ1UZXXy/0/sowVJwY/c6i1H5venyUbh0VPmd6dJR5QEgYyYk/a5s+xYo7VWawsDKuTVV3kMIUTbsdmX9uopO531H/naKiooCYMiQIQD069eP119/nejoaCZOnIi6yMMS8fHxzJs3j2bNmrF9+3Z8fJSfRRMmTKB58+bFXn/WrFluwYXFYqFXr15MmzaNsWPHUqNGDZf6gwcP8uabbzJ16lRn2euvv86sWbPo1KkTkydPZuzYsYDyt/1TTz3FmjVr2Lt3Ly1btizRfV+8eJG6deuyY8cO9Hq98z5atGjB+PHj6devH9WrVy/RtUojPj7eGSaVRM2aNXnhhRdu2u7EiRMA1K3rPmOKRqMhMjKSI0eOYLFY0Gqv/xF+5cqV0Wg0JCYmYrfb3UKqhIQEQBmx5RAUFOSsa9iw4U3bC1GWJKASQghxT1CpVBg8tBg8tNQI9C7ROflmK5ezjKRmG0nNNjmPC8scx0qYZSxFmOXvpSPEz5Oqfh6E+HkS7OdJsJ9Hwd6TEH9PAn30aDUV8ElqvQ+ENFG2omw2yDhb8OHLUbh8rDC4MucUHP/peo6HX8EHOgWBVUgTqNpInhwWQghx51gtSvDkeLDCuZ1Spmi7Gd9QZQpbx1aphhI++VdXAqiyXjNJ3Nu0HlC1obIVza8sJuX78PJR5cGeiwfgwn5lzatLh5XtwCKlrUqj/I4U/gjUeETZ+z+4a00IcceZc+HDauXdi5ubcOG2w2qz2cyCBQvw8/OjT58+gDKdXd++fVm4cCG//PKLy4icb7/9FoCJEyc6wymA6tWrM3bsWN577z239yhuVI1Wq2XkyJFs2LCBuLg4hg0b5lJvMBj497//7VI2aNAgZs2aRWBgoMu6SCqVioEDB7JmzRoOHDhQ4oAKlDWSHOEUQFhYmPM+Fi9ezD/+8Y8SX6uk4uPjSzWSqEuXLiUKqDIyMgDw9/cvtt7Pzw+bzUZWVhaVK19/DWlvb286d+5MXFwcM2fOZNSoUc66FStWsH//fkCZHtGhZ8+eLF68mI8//pi//OUveHp6AsporC+//NKtvRBlSQIqIYQQ9y1PnYbwAG/CA24eaOWZrEpgVSTAcgZZWSZneUpmPkaLjYyCdbSOpVz/yWy1CoIMHoT4e1LV15MQfw+CfQvCLH9Pqvl7ElrJC4NHBfnfsVpd+OFd/ScKy202ZU2Oy8fg0p/KlnJICbKMmXB2p7I5qNTKaCtHaBXaDKq1VNaKEEIIIYpjtyvrBF0bQF05CVcTwGa5/rl6wzUBVESR43B5aEKUDa0eghsp20P9lDK7XVm37ML+wnVAL+5Xyi7uV7bfvlHa+lUvCKzaQs1OSgAmo9KFEDcRGxvL5cuXGT58uDNUABg6dCgLFy4kKirKJaA6cEBZj7hTp05u1yquDCArK4vPPvuMlStXcurUKXJyXNeLvnDhgts5devWxdvb9e/u0NBQAJo2beo2qsdRV9y1rker1Ra7HpPjPvbt21fia5XG5MmTi506sSL54osv6NixI6NHj+ann36iadOmnDx5ktjYWJo2bcrBgwddRtYNHjyY+fPnExcXR5MmTXjiiScwm82sXLnSuS6YWqYtFndJBflETAghhChfXvqShVl2u53MPAspWfmkZOaTnJHPpSwjyRnK65QsIykZ+VzONmK12bmUZeRSlhHIuO41/Ty1VKvkRfVKXoRW8qRaJS+q+Xsp+0pKoKUrz5FYarXytHmlGlC3yEKzFhOkHleeGk4+VLD/Q3lyOPW4sh1eUdi+UoSyNkO1FlC9pRJceRb/tJgQQoj7lM2mrOnjHKVb8PBD6nEwZV//PK0XBNYu2OoUbgG1lLWf5IN9URGoVOBXTdka9Cosz0iCc7vg7C5ln3xIGf13eEXh70o+VSCyc8HWRQlY5ftaiJLReSujkyo6XclmArkRx/R+Q4cOdSl/9NFHqV69OrGxsaSlpREQoDwcmJGRgVqtdk7pVpQjiCjKZDLRtWtX9u7dS4sWLRgyZAiBgYFotVrOnDlDTEwMRqPR7Tw/Pz+3Mse0dDeqM5vNN7tlp6CgoGJDE8d9OEYk3SscI6eu1+/MzExUKhW+vr43vVazZs34/fffmTRpEnFxccTFxVGnTh2++eYb0tPTGT9+PFWrFq6toNVqWbt2LR9//DGLFi1i9uzZ+Pv707dvX8aNG0e9evVc2gtRliSgEkIIIUpBpVLh763D31tHveDr/6Jotdm5km0kJVMZdZWcmc+lzHxSMo0kZyph1oX0PDLzLcqWnMXR5OJHY6lVUNXXk2qVlBFX1St5OUdfVa/kRXiAN/5eurK65evT6gum9XsImg1Uyux2yE5Rgqrkg0podWG/sk5DeqKyHVlZeI3AOsroKkdwFdoM9Lf/h5sQQohyZrMpP/OLBlGXjxZOG1sclQYqRxQJoIqEUb7VlAcmhLgX+ResbfbQs8prUw6c3wvnfoXEHZC4E3Iuwx/LlQ3AvwbU6gJ1u0PtbuBx8w8ohXhgqVQPxDpv586dY/369YAyjdz1LFy40Dmlnr+/PzabjdTUVKpUqeLSLiUlxe3c2NhY9u7dy/Dhw5k7d65L3eLFi4mJibnd27hlqamp2Gw2t5DKcR/XmyrvdpXVGlSOtacca1EVZbVaSUhIIDIy8obrTxXVoEEDlixZ4lbu6Evr1q1dyj08PJg0aRKTJk1yKXfc67XthSgrElAJIYQQZUCjVlHVz5Oqfp404fq/KGflm7mYoYRVF9IL9hl5ztfJGfmYrDaSC0IuzqYXex1/Lx3hAV6EV1ZGgYVX9iIswJsaAd5Ur+SFp+4uLe6uUoFviLLVfaywPC9dmdbmwj7lA5kL+5X1rhzTOB1aWnB+kTUawtsom3+4PEEshBAVWW5awUjaw8qDCSl/KEGU5TprOmr0EFgXqtSHKg2gagMIqq+MhtLqiz9HiPuJ3gciOykbgMUISbshYTMkbIGk35Xfk/YtUDa1DiLaQ70noF4PJbwVQjxw5s+fj81mo2PHjtSvX9+t3mKxEBMTQ1RUlDOgatasGXv37mXr1q3069fPpf3WrVvdrnHq1CkAnnnmGbe64trfTRaLhZ07d9KhQweXcke/WrRoUSbvW1ZrUDlCxvXr1/POO++41G3bto2cnJwbBpElkZWVxU8//URgYCCPP/74zU8AvvvuOwAGDhx4W+8tRElJQCWEEEKUI19PHb6e1x+NZbPZSc0xciE9n4vpeZwvCK4uZijHSVfzSMsxKWtinTfzx/nMYq8T7OfhEl45pjOsEeBNiJ8nanUZB0BelaBWV2VzyElVAivHdn4vZCe7r9HgGwphDxeEVo9AaFNZ0F4IIcqD1aKMiC06rWvKYci6zrRKGg8IqucaRFVpAJUjQSN/igrhpPWAmh2UrdsEMGbD2V/h1EY4vk757y5hs7Kt+xcE1Ib6PaFRHwhrLQ/yCPEAsNvtREdHo1KpiImJoVatWsW2O378ODt37mT37t20bt2aIUOGEB0dzQcffECPHj3w8VFGmp0/f55p06a5nR8REQEoAUnv3r2d5Zs3b2bOnDllcGelM2HCBDZs2IBerzzQkpSUxLRp0/Dw8CizQKWs1qCqX78+nTt3Ji4ujrVr19KzZ09AmWbxvffeA+Dll192OSc1NZXU1FSCgoJcpm3My8tDp9O5jLYyGo0MHz6ctLQ0pk2b5rJmGShTCF47/eKyZcuYN28eDz/8sFugKURZkb8KhBBCiApMrVZR1deTqr6eNA+vVGybHKOFc1dzOZeWx7m0XOdx0tVczqXlkmOyFkw1aGR34lW38z20amoG+hAR6E1kkA81g3yoGehDZJAPVX09yi688glS1rRyrGtltytrNCT9Bud+K1yjIesi/PmjsoHygWe15kpYVbOjsri4rGUlhBB3lilH+Rl88YAyZWvyH8oUfZb84ttXrgnBDymjYIMbQ9VGytqDEkQJUXoeBmUket3H4ImPIPUknFinhFWJO5TAaufXyuYXBo2eLgirHpapMIW4T23atImEhAS6dOly3XAK4MUXX2Tnzp1ERUXRunVrunXrxosvvkh0dDRNmjShb9++GI1GlixZQtu2bVm1apXL+b1796ZmzZp8+umn/PHHHzz00EMcO3aMVatW0bdvX5YtW1bWt3pdoaGh5OTk0LRpU3r37k1OTg5Lly7lypUrTJ8+nerVq7u0HzduHKmpqQAcOnTIWWYwGAAl/OnYsePdvYlrzJw5kw4dOtCnTx8GDBhAaGgoq1ev5vDhw4wePZr27du7tP/66695//33mTRpkktotmfPHvr168fjjz9OeHg4mZmZrF69mrNnz/LKK68wZswYt/d+5JFHCA8Pp2HDhnh6evLbb78RHx9PrVq1+OGHH9Bo7tIsLOKBJ38tCCGEEPc4Hw8tDUL8aBDivvis3W4nLcfEuavFh1dJV/MwWmwcS8niWIr7GlieOiW8qhmoBFeRQd5EFAmvVHfyiV2VCiqFK5tzjYZcZXRV0dAq94qyP7cLdkwHlVr5QDSio/LkcY124B1w5/olhBD3O0cYdWF/4XSsqcfBbnNvq/OB4EYFYdRDENwEqjYET/f/Bwkh7pCgOsrWbhTkZ8KpTfDnT3D8Z8hMgl9nKptvKDR8GpoOgOotZWSVEPeRqKgogJtOHTdgwADGjh3L999/z9SpU/Hy8mLOnDnUq1ePOXPm8PXXXxMWFsZbb71F//793QIqg8HApk2bGD9+PFu2bCE+Pp7GjRvz3XffERwcXK4BlV6vZ8OGDbzzzjssWLCA9PR0GjRowFdffcWgQYPc2i9btozExESXsuXLlzuPu3btWu4BVePGjdm1axfvvvsuq1evJicnh3r16jFjxgxee+21El+nRo0adO3ala1bt5KSkoK3tzctW7Zk6tSpPPvss8WeM2DAAFasWMGvv/6K2WwmMjKSd999l/Hjx7uNrBKiLKnsdru9vDtxP0tKSiI8PJxz584RFhZW3t0RQgghXFisNs6n55GQmsOZ1BzOXMnlzBXl+NzVPKy26/+a4K3XUKuKD7WrGKhTxUDtqgbqVDVQM9AHvbaMnt612yHttBJOJe6AxO3Kaxcq5YPTmh0gomDzCSyb/gghxL3GLYzaD6nHig+jDCHKiNWQpgVh1EPK9HwyQkOIisGcr0wDeCQWjq0FY5GpnoPqQbOBSljlL59FiPtDfn4+CQkJREZGuk1XJoQQFUVJf1ZJbqCQgKqMyTeaEEKIe5XZaiPpap4zsDqTmkPClVzOpOaQdDWX62VXGrWKGgHe1K5ioHZVH+pUUYKr2lUN+Hnq7nxHMy8oYdWZbcp25YR7m6qNC9fAimivTJ0jhBD3O5sNrpyEpN8Ltt1w6fCNw6hqLSC0uXLsG3KXOyyEuGUWI5yKgz+WwZ+rwJJXUKGCyE7QbLAyFaDep1y7KcTtkIBKCHEvkICqdCSgKmPyjSaEEOJ+ZLLYOHc1l1OXsjl5OZtTl3IK9tlkGy3XPa+qr4cy4qqqgbrBBuoH+1I/xJdK3vo717msFGVkVeJ2OLMdLv/pWq/WQXibwsCqWktZI0UIcX/IuwpJewoDqfO7IT/DvZ0jjAptrgRSEkYJcX/Jz1TW7jywGM5sLSz38FdGVbV+Cao2KL/+CXGLJKASQtwLJKAqHQmoyph8owkhhHiQ2O12LmUZOXkpm1OXs132KZnG655X1deD+iG+1A/2pV7Bvm6wAW/9HQiOclIhYQucjofTcZB+1rXeww9qdioMrILqypoNQoiKz2ZTpuY7u1MZGXXut+JHkGq9lBAqrDWEPazs/ard/f4KIcrH1UQ4uBT2L4SrZwrLIzooQVXDp0F7Bx8UEqIMSUAlhLgXSEBVOhJQlTH5RhNCCCEUWflmTl3O4eQlJbA6kZLFsZQskq7mFdtepYIaAd7UC/Z1jrSqH+JLZJAPOs0trn9it8PVhIKwKh5Ob4b8dNc2vtWgzl+gzuNKYOVV6dbeSwgh7iSrWVkv6uzOwi3vqnu7gFoQ1qYwkApuDJoymF5VCHFvsdmUB3V2z4Njawqn+vSpAg+/rGw+QeXbRyFuQgIqIcS9QAKq0pGAqozJN5oQQghxY9lGC8dTsjierARWx1OyOJacRWq2qdj2Oo2KWkEGGob60qiaH41C/WkY6kugwaP0b26zQvJBJaw6FQdnfwVrkZFeKo0yHWCdR5XAKqQpqG8xHBNCiNIw5SjT9CXuULak3UXWlCmg81aCqPBHlDCqemvwCSyf/goh7h0Z52Hvt7A3BrIuKmVaT2g2CNqNUkaTC1EBSUAlhLgXSEBVOhJQlTH5RhNCCCFuTWq20SW4OpacxfGU669xFeznQcNQPxqF+tGomh8NQ/2oGeiDRl2K6frMecoHwSc3wslflOmzivKpWhBWPQa1/wLeAbdxh0IIUYQpRwnJE7Yoa8ZcPAC2a37eeVWGGu2ULaI9hDaT0VFCiFtnNStrVe34Ci7sKyhUQf2e0P4NiGhXrt0T4loSUAkh7gUSUJWOBFRlTL7RhBBCiDvHbrdzPj2PY8lZ/HkxkyMXM/nzYhYJqTnFtvfSaagf4hhppYRWDUJ88fEo4dpWVxPh1EY48QskbAZTdpFKFVRvBXUfVwKrai1Arbn9mxRCPBgsJji/WwmkErYoa0jZzK5t/MKUD4gdgVRQfRnFKYS48+x25QGdnV8r0/851OwEXd+Bmh3Lr29CFCEBlRDiXiABVelIQFXG5BtNCCGEKHvZRgvHkjM5cjGLIxcy+fNiJkeTM8k329zaqlRQM9CnILAqnCYw2M8DleoGo60sJjj3qzKy6sQvcOmwa71XgDKqqu7jyt5Q9Q7fpRDinuacUnSzEkid3QnmXNc2ftUhsgtEdoaaHaBSjfLpqxDiwZV6QhlRtX9RYWge0RG6/lMJrG70u5IQZUwCKiHEvUACqtKRgKqMyTeaEEIIUT6sNjtnruRw5IJjpFUmRy5kcinLWGz7AB+9c3pAx75WkA9azXVGK2ReKAirNihrWBkzXetDm0HtgukAw9vINFxCPIiykuHUJuVnxalNkHfVtd47UAmjIjsrwVRALfnwVwhRMaSfg21fwL4FYC1YFzSiAzw6EWq0Ld++iQeWBFRCiHuBBFSlIwFVGZNvNCGEEKJiSc02OsMqR3B16nIOVpv7r0QeWrUyRWCR4KpBqB+Ga6cItJohaTec3KB8EH3xgGu93hdqdVHCqjqPyqgIIe5XVjOc26X8HDj5CyQfcq338FM+4HWEUlUbyZR9QoiKLSMJtn0Je2MKg6oGT8Gjk6BKvXLtmnjwSEAlhLgXSEBVOhJQlTH5RhNCCCEqvnyzleMpWc7QyjFNYI7JWmz7moHeLiOt3KYIzL4Ep+IKRk1shNwrrhcIqq8EVXUeVT6s1nmV8R0KIcpMRhKcWA8nNyrT95myXOurtSgIpx+D6q1BU8I18IQQoiLJOA+bP1FGVNltoNJAy6HKGlW+IeXdO/GAkIBKCHEvkICqdCSgKmPyjSaEEELcm2w2O2fTcp2BlWOfnJlfbPvrThGoAi7uVz68PvkLJP2mfLDjoPVUFh93fIAdWEem+BKiIrPblVGSx9Yo27WjpLyDCgLox6BWNzBUKZ9+CiFEWbh0FDa+r/z8A9B5Q6e3oP0boPUo376J+54EVEKIe4EEVKUjAVUZk280IYQQ4v5yJdvInxezOHIxwxlclWqKwEo2DOe3KSOrTm6EzPOuJ1WqoaxFE9lZCa78qt2lOxNCXJfFCAlbC0KptZB1obBOpYawh6Hu40ooFdJMpu0TQtz/EnfA+vfg/G7ldUAteOITqNe9fPsl7msSUAkh7gUSUJWOBFRlTL7RhBBCiPvfLU8RGOLLw4ZLNMz5Hd+keFSJOwrXd3AIqA2RnaBmweYbfBfuSAhBfgYc+xmOrVbCZFN2YZ3OB+r8Ber3grrdwSeo/PophBDlxW6HQz/A+nchO0Upq98LenwIAZHl2zdxX5KA6s46c+YMkZGRDBs2jPnz55d3d26oZs2agNJnISo6CahKRyZAF0IIIYS4TZ46DU3DKtE0rJKz7EZTBJ65ksuZK7msOZRc0Loe/l6NaVJlLI95n6SV/RA1s/ZhuHoYVdopSDsFe+YrTYPqKyOrarSD8DbKiCuZElCIOyMvXRkhdWQlnNrkGhgbQqB+T+XD18jOoJMPxoQQDziVCpr2h3pPKOtT7fqPMtL05Ebo+k9l2j+Nrrx7KcR966WXXiI6OpqAgAAuXLiAh4dMs+mQm5vLrFmz2LNnD3v37uX48ePY7XYSEhKcYVdFcfz4cd599102bdpETk4O9erVY+TIkYwcObJwjeMSSEpKYsqUKaxdu5bk5GSCgoLo0aMHH3zwAeHh4W7tbTYbM2fOZN68eRw9ehStVkvz5s0ZN24cTz/9dLHvsWvXLj788EO2b99OVlYWNWrUYNCgQfzrX//Cy0vWVRa3RkZQlTFJQoUQQghRVGmmCPQjhy6eJ+jufYJWtj8IzT+JimvaGYKVoCqsDYQ/AqHN5INzIUoj7yocXVMQSsWBzVxYF1QfGj2thFKhzWXqPiGEuJFLR2HteEjYorwOaQrPzIDQpuXbL3HfkBFUhbKysggNDSU3Nxe73c7ixYsZMGBAqa5xP4+gctwbQEREBFlZWaSlpVW4gOrIkSO0b9+evLw8+vfvT7Vq1Vi9ejWHDx9m9OjRfPXVVyW6zqlTp2jfvj2XLl2ie/fuNG3alBMnTvDjjz9SpUoVduzYQe3atZ3t7XY7zz//PMuXL6d27dr07NkTo9FIbGwsly5d4quvvmL06NEu77FixQoGDBiARqPh2WefJSQkhO3bt7Nr1y46dOjAxo0bJSQtICOoSkcCqjIm32hCCCGEuJl8s5XTl3M4cSmLEynZHE/J4uSlbM5cyaFobuVPNm3Vf/KI+k9aa0/SiAS0uE4jaNfoUYU2g2otoVoLqNYcguqBWnN3b0qIisyYBX+ugj+Wwel4sFkK66o0hMZ9oNEzULVhefVQCCHuTXY7HFgMP78D+emg1kKHv0Pn8fIAjbhtElAVmjt3Lq+88gpvvfUWX375JY8++ijr168v1TXu54AqOzubnTt30qpVKwICAnjiiSdYt25dhQuounTpwpYtW1izZg09e/YEwGQy8dhjj7F161Z27NhBu3btbnqdp556itWrVzNt2jTeeOMNZ/kPP/xA//796dGjBz///LOzfNmyZTz//PN06NCBDRs2OEc/paam0rp1a5KTkzl69Kjza5WXl0dERATp6enOrysoQdeYMWOYMWMGH330Ee+8886d+tLc0ySgKh15BFAIIYQQopx56jQ0qubHM82rM65HfWYPbc2mcV058sETrB3biWkDmzO6Wx3aNq7NiYCu/Ns6lKfzP6BxfhTPGSfykXkQ662tuGz3Q2U1QdLv8Ns3sHIkzGyL7cPq2KK6w9p/Kh8aXToKtuLXxxLivmU1w/H1sGw4/F9d5b+Pk78o4VTVxtDtf2DUbzDqV+j6joRTQghxK1QqaD4IRv+uBP02C2z9DL7pDBf2l3fvhLhvREVFodVqefvtt+nWrRsbN24kMTGx2LZWq5VPPvmEOnXq4OnpSZ06dfjoo4+w2WzFto+Li+Oll16ifv36GAwGDAYDrVu3Zvbs2cW2V6lUdO3alfPnzzN48GCCgoLw9fXlySef5PTp0wD8+eef9OnTh4CAAHx9fXnuuedISUkp9X2np6czYsQIQkJC8PT0pEWLFnz//fdu7QwGA48//jgBAQGlfo+75fjx42zZsoVu3bo5wykAvV7PlClTAJgzZ85Nr5Ofn8+6desIDg5mzJgxLnXPP/88zZs3Z926dc5/C4DY2FgAJkyY4DI1X1BQEG+++SZGo5Ho6Ghn+Y4dO7h8+TJ9+vRxhlOg/Nv/+9//BuA///kPMg5G3ApZg0oIIYQQooLy1GloGOpHw1A/l3KjRRlxdfJSNicvNebkpWziL2WTkJpNiC2ZlqoTNFEn8JA6gYdUCfhY8uDcLmUrYNF4YQx6CI8ardCGt4bqLSGglqxnJe4vdjuc3wsHl8AfyyE3tbAusA406Q8P9YOguuXXRyGEuB8ZqkL/b+HIj7D6H5B6DOY+Bo++B+3GyJSp4o6y2+3kWfLKuxs35aX1KtWaQtdz5MgRfv31V3r16kVwcDBDhw5l48aNREdHM3nyZLf2r776KvPmzSMyMpJRo0aRn5/P1KlT2bFjR7HX/+STTzh58iRt27alb9++pKen8/PPPzNixAiOHTvG559/7nbO1atX6dixIyEhIQwbNozjx4+zatUqjh49SmxsLJ06daJVq1a89NJL7Nmzh+XLl5OWlsamTZtKfN+OkUXZ2dkMGTKEnJwcli5dyuDBg0lNTXULZyq6+Ph4ALp37+5W17FjR3x8fNi8efNNr3PlyhUsFgsRERHFfn9FRkayf/9+4uLiqFWrFgDJycnOuuLaA2zatIn333//pu0rVapE5cqVSUxM5PTp0y5TCQpREhJQCSGEEELcYzy0xQdXFquNs2m5SnB1OZsll7L5KCUD6+UT1LGcLBJancHbmoc25XdI+R1+V87P1fiREdAEe/VW+Ndui09kGzBUKYc7FOI2ZSXD/kWw/zu4crKw3DsImjwHTfsr02BKICuEEGWr0dNQsyP8OAaOroINE5XRq33+A/7Vy7t34j6RZ8njkUWPlHc3bmrX4F1467xv+zpRUVEADBkyBIB+/frx+uuvEx0dzcSJE1EXCYDj4+OZN28ezZo1Y/v27fj4+ADKyJnmzZsXe/1Zs2a5BREWi4VevXoxbdo0xo4dS40aNVzqDx48yJtvvsnUqVOdZa+//jqzZs2iU6dOTJ48mbFjxwJKoPjUU0+xZs0a9u7dS8uWLUt03xcvXqRu3brs2LEDvV7vvI8WLVowfvx4+vXrR/Xqd/7nSnx8vDNMKomaNWvywgsv3LTdiRMnAKhb1/1BKY1GQ2RkJEeOHMFisaDVXv8j/MqVK6PRaEhMTMRut7uFVAkJCYAyYsshKCjIWdewYcNStb9WRkYGV69edZ4jAZUoLQmohBBCCCHuE1qNmlpVDNSqYqDoc3h2e2cuZuRz8lI2hy5lszIlg7yLR/G5coja5hM0U5+isSoRb2sm3pe3w+XtsH86AJc0wVz0a0pOtfZ41OlKZN3GBBhk8VtRAVktcGI97FsAx9eBvWAaS60XNHwKmg6AWl1BoyvXbgohxAPHOwAGLFR+Pq/9JyRsgVnt4ZmvoWHv8u6dEPcUs9nMggUL8PPzo0+fPoAynV3fvn1ZuHAhv/zyi8uInG+//RaAiRMnOsMpgOrVqzN27Fjee+89t/cobpSMVqtl5MiRbNiwgbi4OIYNG+ZSbzAYnFO9OQwaNIhZs2YRGBjosi6SSqVi4MCBrFmzhgMHDpQ4oAL48MMPneEUQFhYmPM+Fi9ezD/+8Y8SX6uk4uPjnSOJSqJLly4lCqgyMjIA8Pf3L7bez88Pm81GVlYWlStXvu51vL296dy5M3FxccycOZNRo0Y561asWMH+/fsBZXpEh549e7J48WI+/vhj/vKXvzjXSbpy5QpffvmlW/sOHTrg5+fHypUr2bdvHy1atHDWTZw40Xlc9BwhSkoCKiGEEEKI+5xKpaJaJS+qVfKicz3HiKjmwECuZBs5eSmbFSlXyUw8gEfKPoIy/qCe5Th1VBeoak2h6tUNcHUDHH6fJHsQu3RNuVy1A/oG3XmodgT1Q3zRaWSqHlFOrpyCfQuVEVPZyYXl4W2h5RBlDRQP3/LrnxBCCGXEasuhENEBlr8MF/bCkr9B+zHw6GTQyMdT4tZ5ab3YNXjXzRuWMy+t180b3URsbCyXL19m+PDhzlABYOjQoSxcuJCoqCiXgOrAgQMAdOrUye1axZUBZGVl8dlnn7Fy5UpOnTpFTk6OS/2FCxfczqlbty7e3q6jw0JDQwFo2rSp26geR11x17oerVZLu3btrnsf+/btK/G1SmPy5MnFTp1YkXzxxRd07NiR0aNH89NPP9G0aVNOnjxJbGwsTZs25eDBgy4j6wYPHsz8+fOJi4ujSZMmPPHEE5jNZlauXElwcDCAS3uDwcDUqVN5+eWXadeuHc899xwhISHs2LGDPXv20KBBA44ePepyjhAlJb8BCCGEEEI8wAINHgQaPHikViC0qwM8C0BWvpk/zieTfmIXmnPbqZL6G7WMfxKmSiXMsgkubMJ8/kN+3dCQj1WtuVTtMRo2aEiH2kE8VN0fjVqmThNlyGqB4z/Db7Mhocjc/N5B0HwQtBgKVeqVX/+EEEIUL7A2DF8PG9+HHV8p2/m98Nw88A0p796Je5RKpbojU+fdCxzT+w0dOtSl/NFHH6V69erExsaSlpZGQEAAoIzSUavVzinainIEEUWZTCa6du3K3r17adGiBUOGDCEwMBCtVsuZM2eIiYnBaDS6nefn5+dW5piW7kZ1ZrP5ZrfsFBQUVGwA4rgPx4ike4Vj5NT1+p2ZmYlKpcLX9+YPWjVr1ozff/+dSZMmERcXR1xcHHXq1OGbb74hPT2d8ePHU7VqVWd7rVbL2rVr+fjjj1m0aBGzZ8/G39+fvn37Mm7cOOrVq+fSHmD48OFUq1aNTz/9lNjYWKxWKw8//DAbN27kk08+4ejRo27nCFESElAJIYQQQgg3vp46mtYOh9rhwHNKoSmHnBPbufLHenzO/EJgXgKdNH/QiT+wXYxh5/lGxKzvxHZ9e1rUDuMvDaryaMOqBMqUgOJOyUmFvd/C7nmQca6gUAV1HlNGS9XrCVr9DS8hhBCinGl00P3fENYGVr4OidvhP53g+WhlvSohRLHOnTvH+vXrAWUauetZuHChc0o9f39/bDYbqampVKniurZsSkqK27mxsbHs3buX4cOHM3fuXJe6xYsXExMTc7u3cctSU1Ox2WxuIZXjPq43Vd7tKqs1qBxrTznWoirKarWSkJBAZGTkDdefKqpBgwYsWbLErdzRl9atW7uUe3h4MGnSJCZNmuRS7rjXa9uDMjVgz5493cqHDBmCWq0u1XSNQjhIQCWEEEIIIUpG74NP4+74NC6YNuTKKWx/rsb4x494Jf9OB81hOmgOk2OPZtWxdkQf6cE7RPBwzQC6Nw7hiYdCqF7p9qc2EQ+g83vhtznwx3KwFjy16xUArYZB65egUo0bny+EEKLiafQ0VG0ES4fApSPw7TPQ6/+Un+tCCDfz58/HZrPRsWNH6tev71ZvsViIiYkhKirKGVA1a9aMvXv3snXrVvr16+fSfuvWrW7XOHXqFADPPPOMW11x7e8mi8XCzp076dChg0u5o19F10W6k8pqDSpHyLh+/Xreeecdl7pt27aRk5NzwyCyJLKysvjpp58IDAzk8ccfL9E53333HQADBw4sUfvt27dz5swZevXqVWYhobi/SUAlhBBCCCFuTWBt1B3fwKvjG5B+Fg4swX7ge3zSTjFAG88AbTw7rY2Yl/gE/05oyZRVR2hbK4B+LcPo1SQUg4f8KipuwGaD42th+3Q492theWhzeGQENO4HOs/rni6EEOIeEFQHXt4IP70Bh36AVW/C5ePKCCtZl0oIJ7vdTnR0NCqVipiYGGrVqlVsu+PHj7Nz5052795N69atGTJkCNHR0XzwwQf06NEDHx8fAM6fP8+0adPczo+IiACUgKR3797O8s2bNzNnzpwyuLPSmTBhAhs2bECvV0bMJyUlMW3aNDw8PEocqJRWWa1BVb9+fTp37kxcXBxr1651jkwymUy89957ALz88ssu56SmppKamkpQUJDLtI15eXnodDqX0VZGo5Hhw4eTlpbGtGnTXNYsA2UKwWunX1y2bBnz5s3j4Ycfdgs0i2t/4cIFXn75ZbRaLVOmTLnFr4R40Mn/7YUQQgghxO2rVAO6jEfVeRyc2wW7voEjsbTTHKGd5ghntRF8kvcMa0634dfTaUyKPUzPh0IY0i6CFjUql3fvRUViMcLBJUowdaVgyhO1Dhr3hTavQlhrUMkaZ0IIcd/Qe0O/OVClAWyaArtmwZWT8FwUeMrT+EIAbNq0iYSEBLp06XLdcArgxRdfZOfOnURFRdG6dWu6devGiy++SHR0NE2aNKFv374YjUaWLFlC27ZtWbVqlcv5vXv3pmbNmnz66af88ccfPPTQQxw7doxVq1bRt29fli1bVta3el2hoaHk5OTQtGlTevfuTU5ODkuXLuXKlStMnz6d6tWru7QfN24cqampABw6dMhZZjAYACX86dixfKcVnTlzJh06dKBPnz4MGDCA0NBQVq9ezeHDhxk9ejTt27d3af/111/z/vvvM2nSJJfQbM+ePfTr14/HH3+c8PBwMjMzWb16NWfPnuWVV15hzJgxbu/9yCOPEB4eTsOGDfH09OS3334jPj6eWrVq8cMPP6DRaFzaT58+nYULF9KxY0eqVq3KuXPniI2NJTc3l6ioKJneT9wyCaiEEEIIIcSdo1JBjbbKlpEEv8+F3+dRw5jIDN10UivV4UtrfxamN2bFvvOs2HeeZmH+vNChJr2ahOKh1dz8PcT9KS9dWVtq138gu2BNBA9/aP0iPDIS/ELLtXtCCCHKkEoFncdBUF1YMQJOboCoHvDXH6BSeHn3TohyFxUVBXDTqeMGDBjA2LFj+f7775k6dSpeXl7MmTOHevXqMWfOHL7++mvCwsJ466236N+/v1tAZTAY2LRpE+PHj2fLli3Ex8fTuHFjvvvuO4KDg8s1oNLr9WzYsIF33nmHBQsWkJ6eToMGDfjqq68YNGiQW/tly5aRmJjoUrZ8+XLncdeuXcs9oGrcuDG7du3i3XffZfXq1eTk5FCvXj1mzJjBa6+9VuLr1KhRg65du7J161ZSUlLw9vamZcuWTJ06lWeffbbYcwYMGMCKFSv49ddfMZvNREZG8u677zJ+/Hi3kVIA7du3Z/Pmzfz0009cvXqVwMBAevXqxT//+c8ym15RPBhUdrvdXt6duJ8lJSURHh7OuXPnCAsLK+/uCCGEEELcfXnp8Oss+HUmGDMByAxtz2zvV5h91AuT1QZAFV8PXu1Ui7+2rYG3Xp6jemDkpinfG7/+B0xZSplfdWj7GrQcBp7ufyALIYS4j13YB98PgqyLyv8P/rYCqjYo716JCiA/P5+EhAQiIyPdpisTQoiKoqQ/qyQ3UKjLuwNCCCGEEOI+51UJuv0Lxh6Ajm+BxgO/izsYd3o4Bx5ex/90CyHYz4PLWUb+d82fdPwkjlnxp8gxWsq756Is5abBpn/Dl01hy/8p4VTVRtDnP/DGfmg/RsIpIYR4EFVrAS//AkH1IfM8zOsB534r714JIYQQogxIQCWEEEIIIe4O7wB4bBKM/g0aPQN2G177o3nl0CC2PZPPp882pUaAN2k5Jj75+SidPo3j251nMBeMsBL3ieKCqeCHoP8CGLkdmg8Crb68eymEEKI8+YfBSz9D2MOQnw4xT8Px9eXdKyGEEELcYRJQCSGEEEKIu6tyTej/LQz7CQLrQnYKuh/+Sv8zE9k0shGfP9+MyCAf0nJMTIw9zBNfbmHjnynIzNT3OHMebJ0K05q7BlMDFsKIrdDoaVDLnydCCCEKeAfA0Fio8zhY8uD7gXD4v+XdKyGEEOK+Nm3aNCIiIvD09KRjx44cOHCgTN9P/gIUQgghhBDlI7IzjNwGHd8ElQYOr0D7TQee9T/G+jc7M+WZxgT46Dl1OYfhMbsZOu83Eq/klHevRWnZrLB3AUxvCRvfB2OGazDVsLcEU0IIIYqn94FB30OT/mC3wrLhcGhZefdKCCGEuC8tWrSIf/7zn0yZMoU9e/ZQp04devToQWZmZpm9p/wlKIQQQgghyo/OEx6bDK9sVNYfyrkMC/uh2zSZIW2qEz++KyO71EavVbP1RCrdv9jCjLiTMu3fvcBuh2M/w6z28ONoyLoA/uHQ9xsJpoQQQpScRgd9/wPN/6qEVCtegYNLy7tXQgghxH3niy++YOTIkQwdOpTGjRszd+5cLBYLixYtKrP3lL8IhRBCCCFE+avWAl7ZBA+/rLzePg3m9cAvP5l3ejZg/d8706FOIEaLjf9bd4ynpm/jUFJG+fZZXN/l47CwH3w/AC4fBc9K0P3fMHo3NBsowZQQQojSUWvg6a+hxRCw2+C/I2D/9+XdKyGEEOKuWrhwISNGjKB169Z4eHigUqmYP3/+Dc/5/fff6dWrF5UqVcLHx4e2bduydKn7gx4mk4l9+/bx2GOPOcu0Wi1du3Zl586dd/pWnOQvQyGEEEIIUTHovODJz5Wp3zz94fwemN0VEndQM8iHhcMfYWr/ZgT46DmWkkXfmduZEXcSq03Wpqow8jNh/bswqx2c2gQaPbR/A8buh/ZjlBFzQgghxK1Qq6H3dGj1ghJSrXwN/lhR3r0SQggh7pp3332X2bNnk5iYSGho6E3bx8XF0aFDB7Zt20b//v0ZOXIkycnJDBgwgM8//9ylbWpqKlarleDgYJfyqlWrkpycfEfvoygJqIQQQgghRMXSsLeyNlVIU8hNhZje8HsUKpWKfi3D+OWtLvR8KASLzc7/rTvGwNk7OZeWW969frDZ7XBgCXzdGnZ8BTYL1HsCXv8Vuk8Br8rl3UMhhBD3A7UanvxCCamww4pX4cQv5d0rIYQQ4q6YO3cuZ86c4fLly4wcOfKGbS0WC6+88gpqtZotW7Ywe/ZsPv/8cw4cOEC9evWYMGECiYmJd6nn1ycBlRBCCCGEqHgq1YCX1sFDzyphx+q3YNVbYLMS4KNn5l9b8n/PNcVHr+H3M1d5cvpWNh1NKe9eP5jSTsO3T8N/X4XsFAioBYN/gMFLILB2efdOCCHE/UathienQuN+YDPDkr9BYtlNPSSEEEJUFI899hgRERElartp0yZOnTrF4MGDad68ubPc39+fCRMmYDKZiImJcZYHBQWh0WhISXH9u/rSpUuEhITckf4XRwIqIYQQQghRMem94dkoeGwyoILdUbBkCJhyUalUPN86nLVjO9M8vBKZ+RZemr+bLzYcxyZT/t0dVgtsnw4z20PCFtB6waMTlVFT9brf1a4YLVauZBtJvJLD0eRMDl/I4I/zGRxMSmf/uXSOXMgkITWHlMx8MvLMmK22u9o/IYQQd5haA32/gTqPgyUPFvWHiwfKu1dCCCFEqWVlZZGZmencjEbjHblufHw8AN27u/9t1qNHDwA2b97sLNPr9bRo0YKNGzc6yywWC/Hx8bRr1+6O9Kk42jK7shBCCCGEELdLpYKOb0JAbVj+MhxbDd8+o4zO8Q6gRqA3S0a05d+r/mTBr4lM23iCA0npTBvYAn8vXXn3/v518SD8OAYu7ldeR3aG3tOU0VN3WI7RQkJqDqcuZ5N4JZfkzHxSMvKVfaaRzDwzplsInPy9dAQZ9AQZPAjy9SDUz5OIIB8iAryJCPSmeiUvtBp5nk8IISosrR76fwsLn4WzO2Dhc/DKRmUUthBCCHGPaNSokcvrSZMmMXny5Nu+7okTJwCoW7euW11ISAgGg8HZxuHNN99k+PDhtGrVipYtW/LZZ5+h1WoZPHjwbffneiSgEkIIIYQQFV+jp8EnFr4fAEm/wbweMGQl+FfHQ6thSp+HaB5eiQn/PUT8scs8N2sH0S8+TFhl7/Lu+f3FZoVtX0D8R8rUi57+0P1/ocXflDDxNqVk5nPgXDoHkzI4eD6DkylZXMjIL/H5PnoNXnoNapWqYAOVSoXRYiPfbCXXZMExwC4jz0xGnplTl3OKvZZWraJOVQONqvnRKNSPRtX8aFzNX4JPIYSoSPTeMHgxRPeClD/gu/4wfJ3y/ychhBDiHnDkyBGqV6/ufO3h4XFHrpuRkQEoU/oVx8/Pz9nGYfDgwVy+fJkJEyaQkpJC69atWbduHX5+fnekT8WRgEoIIYQQQtwbItrBS+uVJ6VTj8P8XjDsJ+eT0s+2CqN+iC/DY37nxKVs+s7cQdSw1jQNq1S+/b5fXD0DK0bAuV+V1w2egic/B99bn4886WouO05eYfupVHadTiM5s/gwKsigp1aQgZpB3oT4exHi50mIvwdVfT2p7KPH4KHF4KFFo75xSGa32zFZbWTnW7iSYyI1y0hqjonLWUbOX83jbFoOiVdyOZuWi9Fi42hyFkeTs1jBeUDJ4OoH+9ImMoCHawbQJjKAYD/PW75/IYQQd4CnvzKyeu5jcPlPWDoU/roMNPJAgRBCiIrP19e3TAOg0ho7dixjx469a+8nAZUQQgghhLh3VG0AL/0MMb3hagJEPwkv/ASVawLwUHV/Vo7qwIvRv3M0OYsB3/zKzL+2pFuDquXb73uZ3Q4Hvoc1b4MpC/S+0OtTaDao1KOmrDY7u8+k8fPhZOKOXuLMlVyXerUK6gX70jTMn6ZhlWgY6kftKj5U8tbfkVtRqVR4aDV4GDQEGjyoF+xbbDubzc7FzHz+vJDJkYuZHLmQyeGLGZxLy3OGVt/uTASgQYgv3RpUpVv9qrSsUUmmBRRCiPLgH6aEVPN6wul4WPV3ePrrOzK6V4j71ZkzZ4iMjGTYsGHMnz+/vLtzQzVr1gSUPgshSsYxcuraUVIOmZmZVK5c+W52qVjy15MQQgghhLi3VAqHF9co61JlnFWm9blyylkd6u/FDyPb0bleFfLMVl75djdrDl0sxw7fw0w5sOJVWPmaEk6Ft4XXtkHzwSX+0M9ms7PjZCr/XHaQNv/7CwNm/0r09jOcuZKLRq2iRY1KjO5Wh0UvP8If7/fg57935tPnmvG3thG0iqh8x8Kp0lCrVVSv5MVjjYJ549G6/GdIK7a+/Rd++59HmTG4JS+0r0mjUD9UKjianMWs+FP0/2YnLads4M0l+9l0NAWTpfTrYgkhhLgNoc3g+WhQqWHfQtgxvbx7JESZeumll1CpVAQGBmI0Gsu7OxVKbm4un3/+OYMHD6ZBgwao1WpUKlWFDLiOHz9O//79CQoKwsvLi2bNmjFr1izsdnuprpOUlMSIESOoUaMGer2eatWq8eKLL3Lu3Lli29tsNr7++mtatmyJt7c3fn5+dO7cmR9//PG677Fr1y6eeeYZgoKC8PDwoG7dukycOJG8vLxi21+9epVx48ZRp04dPDw8qFKlCs899xyHDx8u1b2JW+NYe+radaYAkpOTyc7OLnZ9qrtNRlCVkRkzZjBjxgxMJlN5d0UIIYQQ4v7jV00JqWJ6K9P9ffuMMrLKPwwAX08dUcNa84+lB/jxwAVGL9rLZ883o1/LsHLu+D3k8nFYOgQuHwWVBrr9Czq+BWpNiU5PvJLD8j1JLN97nvPphX+0+nvpeLRhVbo3CqFDnUB8Pe+dKZiq+nryZNNQnmwaCsDVHBNbTlwm7ugl4o9fJj3XzH/3nee/+85TyVtHz4dCeLpZddrWCkAlT/ELIUTZq9cDnvgE1o6HXyZDSBOo/Zfy7pUQd1xWVhZLly5FpVKRlpbGypUrGTBgQHl3q8K4dOkS48aNAyAiIoLKlSuTlpZWzr1yd+TIEdq3b09eXh79+/enWrVqrF69mtdff50jR47w1Vdfleg6p06don379ly6dInu3bszYMAATpw4QUxMDGvWrGHHjh3Url3b2d5ut9O/f3+WL19O7dq1GT58OEajkdjYWJ555hm++uorRo8e7fIeK1asYMCAAWg0Gp599llCQkLYvn07U6ZMYdOmTWzcuNFl7aQrV67Qrl07Tpw4Qbt27XjmmWe4ePEiy5cvZ+3atWzatIlHHnnkznwhRbG6dOnCRx99xPr16xk4cKBL3bp165xtypvKXto4VpRKUlIS4eHhnDt3jrAw+UBECCGEEOKOyr4E0T3hykkIqgcvrgWfIGe11WZnwopDLNl9DpUKPuzbhEFtapRjh+8Rh5bBj2+AOQcMIcoT6RHtb3qa3W5n64lUorYlsPn4ZWe5r6eWp5pW46mmobSJDEB3H06DZ7XZ2Xf2KqsOXmT1oYtczip8krlWkA+D2tTg2VZhBPjc/RFhQgjxQLHb4cfRyigqr8rwarxzKmBxb8vPzychIYHIyEg8PR/sNSDnzp3LK6+8wltvvcWXX37Jo48+yvr160t1jft5ir/s7Gx27txJq1atCAgI4IknnmDdunUkJCQ4r1URdOnShS1btrBmzRp69uwJgMlk4rHHHmPr1q3s2LGDdu3a3fQ6Tz31FKtXr2batGm88cYbzvIffviB/v3706NHD37++Wdn+bJly3j++efp0KEDGzZswMvLC4DU1FRat25NcnIyR48edX6t8vLyiIiIID093fl1BeV3/zFjxjBjxgw++ugj3nnnHed7jB49mhkzZvDWW2/x+eefO8t37txJp06dqF+/PocOHUKtvv/+Lijpz6rbzQ0+/vhj/vWvfxEdHc0LL7zgVm+xWKhfvz7nz5/n119/pXnz5oAy5V+bNm04c+YMx44dK/f/Ju6/7wAhhBBCCPHgMFSFISvBL0wZSbXwWcjPdFZr1Co+6teEF9rXxG6HCf89xH/3JZVffys6mxXW/Q8sH66EUzU7wYgtNw2nTBYbi387S/cvtjB03m9sPn4ZlQo616vC9EEt+P1/HuOjfk3oUCfovgynQPlea10zgMlPN+bXfz3KopcfYUDrcHz0Gk6n5vC/a/6k7YcbGbt4H3+cL34eeCGEEHeASgW9PodqLSHvKiz5G5hyb36eEPeQqKgotFotb7/9Nt26dWPjxo0kJiYW29ZqtfLJJ59Qp04dPD09qVOnDh999BE2W/HTEcfFxfHSSy9Rv359DAYDBoOB1q1bM3v27GLbq1Qqunbtyvnz5xk8eDBBQUH4+vry5JNPcvr0aQD+/PNP+vTpQ0BAAL6+vjz33HOkpKSU+r7T09MZMWIEISEheHp60qJFC77//nu3dgaDgccff5yAgIBSv8fdcvz4cbZs2UK3bt2c4RSAXq9nypQpAMyZM+em18nPz2fdunUEBwczZswYl7rnn3+e5s2bs27dOue/BUBsbCwAEyZMcIZTAEFBQbz55psYjUaio6Od5Tt27ODy5cv06dPHGU6B8m//73//G4D//Oc/LtMSxsbGolaref/991361K5dO3r37s2RI0fYvHnzTe9PuJo7dy4vvPACL7zwAj/88INb2dy5c51ttVotc+fOxWaz0blzZ1599VX+8Y9/0KxZM44fP86HH35Y7uEUyBR/QgghhBDiXlcpHIauhHk94OJ++H4QDFkBWmWKCbVaxaTejbDb7cTsTGTcDwfx1mvp0TikXLtd4RizYPnLcLzg6cqOb0G3/wHN9f9kMFttLN+TxFebTjqn8fPRa+j/cDgvtK9JRKDP3eh5haNRq2hfJ4j2dYJ4r3cjfjpwgUW7znLofAax+y8Qu/8CHesE8WrnWnSqGyTT/wkhxJ2m84QBC+CbLpB8CH4aC/1ml3j9RHFvsdvt2K+zBk5FovLyuiP/zz9y5Ai//vorvXr1Ijg4mKFDh7Jx40aio6OZPHmyW/tXX32VefPmERkZyahRo8jPz2fq1Kns2LGj2Ot/8sknnDx5krZt29K3b1/S09P5+eefGTFiBMeOHXMZDeNw9epVOnbsSEhICMOGDeP48eOsWrWKo0ePEhsbS6dOnWjVqhUvvfQSe/bsYfny5aSlpbFp06YS37djZFF2djZDhgwhJyeHpUuXMnjwYFJTU93CmYouPj4egO7du7vVdezYER8fnxIFOFeuXMFisRAREVHs91dkZCT79+8nLi6OWrVqAcr6Q4664toDbNq0yRku3ah9pUqVqFy5MomJiZw+fdo5lWBycjJBQUEYDIYbvke3bt1ueo/3u27duqHT6Rg1ahSjRo26Ydtt27YRExPjUrZ9+3a2b9/ufP3yyy+7XHvbtm1MmjSJJUuWYDabadKkCZ988kmFmRZUAiohhBBCCHHvC6oLf1uhrEmVuE2Znq7vf5wfRKlUKib1bky20cryvUmMWbSPeS88TMe6QTe58AMi/SwsGgiXDoPWE/rMhIeevW5zm83Oyv3n+eKX45xLUz4QquLrwaudajGgTTh+99C6UmXN4KFlUJsaDGpTg4NJ6czblsBPBy+y7WQq206m0riaH289Xo+/NKgqQZUQQtxJ/mHQPwZinoZDSyGyE7QcWt69EmXAnpfHsZatbt6wnNXfuweVt/dtXycqKgqAIUOGANCvXz9ef/11oqOjmThxosuUafHx8cybN49mzZqxfft2fHyUh4cmTJjgnO7rWrNmzXILIiwWC7169WLatGmMHTuWGjVcp8w+ePAgb775JlOnTnWWvf7668yaNYtOnToxefJkxo4dCyiB4lNPPcWaNWvYu3cvLVu2LNF9X7x4kbp167Jjxw70er3zPlq0aMH48ePp168f1atXL9G1SiM+Pt4ZJpVEzZo1i51u7VonTpwAoG7dum51Go2GyMhIjhw5gsViQau9/kf4lStXRqPRkJiYiN1ud/t9MiEhAVBGbDkEBQU56xo2bFiq9tfKyMjg6tWrznMcAVVQUBCXLl0iOzvbLaQq7j0eZHFxcSWe4m/+/PmlnpKzTZs2rF279hZ6dnfcn/NrCCGEEEKIB0+15vD8fFBp4OBi2PKZS7VareKTZ5vQ86EQTFYbIxbs5s+LmcVe6oFyfi/M+YsSThmC4YU1Nwyn9p29Sr9ZO3hr6QHOpeURZPDgvacasfXtbrzSuZaEUzfQNKwSXw5swebxXXmxQ028dBoOX8hkeMxu+s3awY5TqeXdRSGEuL/U7AiPvqccr3kbLh0t3/4IcZvMZjMLFizAz8+PPn36AMp0dn379uXs2bP88ssvLu2//fZbACZOnOgMpwCqV6/uDIyuVdwoGa1Wy8iRI7FarcTFxbnVGwwG51RvDoMGDQIgMDDQZV0klUrFwIEDAThw4MDNbtnFhx9+6AynAMLCwhg7dixGo5HFixeX6lolFR8fz/vvv1/iraThQUaGMuWzv79/sfV+fn7YbDaysrJueB1vb286d+5MSkoKM2fOdKlbsWIF+/fvB5TpER0cUwp+/PHH5OfnO8uvXLnCl19+6da+Q4cO+Pn5sXLlSvbt2+fyHhMnTnQeX/seNpvNbYq/Xbt2sWrVKrf24sElI6iEEEIIIcT9o86j8ORnsOpNiPs3BNZyCVu0GjXTBrYgfd5v7Dx9heHzf2flqA5U9XtAF9o+HQ+L/wqmbAhpAoMWK0+cF+NKtpH/XfMnK/aeB5Sp/Eb9pQ4vto/ES6+5i52+94VV9mZS78a88Ze6zN56mujtCew7m87gObvoVDeI955qRL1g3/LuphBC3B/aj4XTm+F0HCx7EV7ZBDqvm58n7hkqLy/q791T3t24KZXX7X/fxcbGcvnyZYYPH46nZ+Hvr0OHDmXhwoVERUW5TBnnCIA6derkdq3iygCysrL47LPPWLlyJadOnSInJ8el/sKFC27n1K1bF+9rRoeFhoYC0LRpU7dRPY664q51PVqtlnbt2l33Pq4NTu6UyZMnFzt1YkXyxRdf0LFjR0aPHs1PP/1E06ZNOXnyJLGxsTRt2pSDBw+6jKwbPHgw8+fPJy4ujiZNmvDEE09gNptZuXIlwcHBAC7tDQYDU6dO5eWXX6Zdu3Y899xzhISEsGPHDvbs2UODBg04evSoyzkffPABP//8M5999hk7d+6kbdu2XLx4kWXLltGoUSO3PokHlwRUQgghhBDi/tL6JUg9Cb/OgP++BpVrQvXCaV/0WjX/+Vsr+s3azqnLOQyP2c2SEW3x1j9gvxofiVXWnLKaILILDPwOPNxDEbvdzqqDF5n042HSckwAPNcqjLd71H9wg707pLKPnn8+0YAX29dkRtxJFv12lq0nUuk5bSvD2tVk7GN18feSEWlCCHFb1Gpl/alZHeDSEVj3P/DU1JufJ+4ZKpXqjkyddy9wTO83dKjrdJWPPvoo1atXJzY2lrS0NAICAgBllI5arXZO0VaUI4goymQy0bVrV/bu3UuLFi0YMmQIgYGBaLVazpw5Q0xMDEaj0e08Pz8/tzLHtHQ3qjObzTe7ZaegoKBiAw3HfThGJN0rHCOnrtfvzMxMVCoVvr43f2ipWbNm/P7770yaNIm4uDji4uKoU6cO33zzDenp6YwfP56qVas622u1WtauXcvHH3/MokWLmD17Nv7+/vTt25dx48ZRr149l/YAw4cPp1q1anz66afExsZitVp5+OGH2bhxI5988glHjx51OScsLMzZp7Vr1/Lbb78RHh7OBx98QM2aNRk4cKDbe4gH0wP2V7gQQgghhHggdJ8Caafg+M+wdBi8uhl8Ap3V/t46ol9oQ5+Z2zl0PoN/LD3AzL+2fHDWANoTA6v+DnYbNHwanp0LWg+3ZpezjPzPfw+x/kgKAA1CfPmoXxNa1Kh8lzt8f6vq58n7zzzE8I61+PfqI6w/ksK87QnE7j/POz0b8FyrsAfne1MIIcqCoSr0+wYW9IXdUVD7L9DwqfLulRClcu7cOdavXw9Aly5drttu4cKFzin1/P39sdlspKamUqVKFZd2KSkpbufGxsayd+9ehg8fzty5c13qFi9eTExMzO3exi1LTU3FZrO5hVSO+7jeVHm3q6zWoHKsPeVYi6ooq9VKQkICkZGRN1x/qqgGDRqwZMkSt3JHX1q3bu1S7uHhwaRJk5g0aZJLueNer20PyrR9jukBixoyZAhqtdptPbHq1au7fR8BzhFpxb2HePBIQCWEEEIIIe4/ao3ytPTsrpB2Gla8DH9dppQXqBHozZyhrRg0exdr/0hmztbTvNq5dvn1+W7ZHa2EUwCtXoAnp7p8XRy2n0xl7OL9pGYb0apVjP5LHV7vWge9VqbiKCs1Ar2ZPbQ1W45f5v2fDnPqcg7jlx3kp4MX+ahfE6pXkimphBDiltX+C3QYC9unwU9joUZb8HEfVSJERTV//nxsNhsdO3akfv36bvUWi4WYmBiioqKcAVWzZs3Yu3cvW7dupV+/fi7tt27d6naNU6dOAfDMM8+41RXX/m6yWCzs3LmTDh06uJQ7+tWiRYsyeV/HGlQl1aVLlxIFVI6Qcf369bzzzjsuddu2bSMnJ+eGQWRJZGVl8dNPPxEYGMjjjz9eonO+++47AOc6YTezfft2zpw5Q69evUoUElqtVhYvXoxWq+XZZ6+/7q14cMhfl0IIIYQQ4v7k6Q/9F4DWC05tgs2fuDVpFRHAxN6NAPjk52PsOn3lbvfy7tr7bWE41XYUPPWlWzhlsdqYuv4Yf4vaRWq2kfrBvvw0piN/f6yehFN3Sed6VVg7tjP/fKIBeq2aLccv0+OLLXy3KxG73V7e3RNCiHtXt/+Bqo0hN1VZr1J+pop7hN1uJzo6GpVKRUxMDHPnznXb5s+fT7t27Th48CC7d+8GlJEtoKwHVHQtqfPnzzNt2jS394mIiACUgKSozZs3M2fOnLK6vRKbMGECJpPJ+TopKYlp06bh4eFR4kCltCZPnozdbi/xVtLRVvXr16dz587ExcWxdu1aZ7nJZOK9994D4OWXX3Y5JzU1laNHj5KamupSnpeXh8VicSkzGo0MHz6ctLQ0Jk6c6LJmGShTCF5r2bJlzJs3j4cfftgt0Cyu/YULF3j55ZfRarVMmTLFpc5sNpOXl+dSZrPZGDduHMeOHWPMmDFUq1bN7ZriwSMjqIQQQgghxP0r5CHoPQ3++6oSUIU/AnUedWny10dqsCfxKv/dd57R3+9j9ZiO9+faSnsXwI/K07Q88hr0+F+4Ztq4jFwzoxbtZdtJ5Y/eQW3CmfhUY7z07iOsRNnSa9W81rU23RsH8/ayg+xJvMr//PcPNhxJ4fPnmxFocJ+SUQghxE1oPaDvLJjzF/jzR/hjOTR5rrx7JcRNbdq0iYSEBLp06UKtWrWu2+7FF19k586dREVF0bp1a7p168aLL75IdHQ0TZo0oW/fvhiNRpYsWULbtm1ZtWqVy/m9e/emZs2afPrpp/zxxx889NBDHDt2jFWrVtG3b1+WLVtW1rd6XaGhoeTk5NC0aVN69+5NTk4OS5cu5cqVK0yfPp3q1au7tB83bpwzyDl06JCzzGAwAEr407Fjx7t7E9eYOXMmHTp0oE+fPgwYMIDQ0FBWr17N4cOHGT16NO3bt3dp//XXX/P+++8zadIk5zR5AHv27KFfv348/vjjhIeHk5mZyerVqzl79iyvvPIKY8aMcXvvRx55hPDwcBo2bIinpye//fYb8fHx1KpVix9++AGNxvX3/+nTp7Nw4UI6duxI1apVOXfuHLGxseTm5hIVFeU2vV9KSgqNGzeme/fuREZGYjKZWLduHUePHuXJJ5/ko48+unNfSHFPk0cghRBCCCHE/a3ZAGj9knK88jXIcR0lpVKp+N++D1E/2JfLWUbeXLofm+0+e6L68H/hxzGAHdqMgCc+cgunTl/Opu/M7Ww7mYq3XsO0gc35qF9TCafKWe0qBpaOaMd7TzXCQ6sm/thlek7byo5TqTc/WQghhLvQZtB5vHK8+h+QlVy+/RGiBKKiogBuOnXcgAED8PLy4vvvv3eOXpkzZw4fffQRKpWKr7/+mrVr1/LWW2/x5Zdfup1vMBjYtGkTzz77LL///jtff/01Fy5c4LvvvmPUqFF3+rZKRa/Xs2HDBrp06cKCBQuYN28eYWFhLFq0qNgAZtmyZcTExBATE8OFCxcAWL58ubPs5MmTd/sW3DRu3Jhdu3bx9NNPs3r1aqZNm4ZarWbGjBlMnz69xNepUaMGXbt2ZevWrXzxxRd8//331KlTh2XLljF79uxi1zIdMGAAycnJREdHM336dFJSUnj33XfZt2+fcyRdUe3btyc8PJyffvqJzz77jF9++YVevXrx+++/M2zYMLf2/v7+PPPMM+zdu5evvvqKefPmUblyZebMmcOPP/6Ih4c8bOXQrVs3GjVqxIwZM8q7K+VCZZc5IspUUlIS4eHhnDt3jrCwsPLujhBCCCHEg8mcB990gdRj0OApGLDQLaA5dTmbJ6dvJd9s472nGjG8Y2Q5dfYOS9gKC/uB1aQEdU9Odbv3naeuMGLBbjLzLVTz92TusIdpVM2vnDosrudociajF+3j5KVsVCoY060OYx+rh0bt/qGDEEKIG7CaYe6jcPEA1H8SBi0q7x6JEsjPzychIYHIyEi36cqEEKKiKOnPKskNFDKCSgghhBBC3P90XvDsXFDr4OgqZS2ma9SuYuB/nnSsR3WUY8lZd7uXd17yH7B4sBJONewNvT5zC6d+OZLCsOjfyMy30LJGJWJHd5RwqoJqEOLHj6M7MPDhcOx2mL7pJK98u5vMfHN5d00IIe4tGh30+Y/ye8Gx1fDnqpufI4QQQog7TgIqIYQQQgjxYAhtCo8qCw7z8zuQdtqtyd8eqUG3+lUwWWyMXbwPo8V6lzt5B2UkwcJnwZgJER2g31xQu07XF7v/PCMW7sFksfF4o2AWvdKWKr4y3UZF5q3X8vGzTflyQHM8tGo2Hb1E3xnbSUjNufnJQgghCgU3gg4FazOuGQ/G++DBFCGEEOIeIwGVEEIIIYR4cLQbAzU7gTkXfvo7XDPbtUql4pPnmhLgo+dochb/iXcPse4Jplxl5FR2MlRtBAMXgc51eokfdp/j70v2Y7XZ6duiOjP/2hJPnaw3da/o06I6P4xsR4ifJ6cu5/DM19vYcVLWpRJCiFLpPB4q14SsC7Dpf8u7N0IIIcQDRwIqIYQQQgjx4FCrofc00HpCwmbY/51bk6q+nkzqrUz1NyPuJKcuZ9/tXt4eux1+HKOsq+EdCIOXgFcllyarD17kn8sPYrfD39rW4PPnm6HTyJ8G95qmYZX4cUwHWtaoRGa+hWHRv7Hq4IXy7pYQQtw7dF7K2owAv30DF/aVb3+EEEKIB4z8FSqEEEIIIR4sgbWh2wTleN0EyEpxa/J0s2p0qVcFk9XGhBWHsF8z0qpC2z4N/lgGai30/xYq1XCp3nQ0hbGL92Gzw8CHw5nyzEOo1arrXExUdFV9Pfn+1bb0ahKC2WpnzPf7mL89oby7JYQQ9446j0KT58FuU0ZX22zl3SMhhBDigSEBlRBCCCGEePC0HQWhzSE/A9a+7VatUqn4d5+H8NJp2JWQxg+7k+5+H2/F6c3wy2TluOcnULOjS/W+s1d5beFeLDY7zzSvxv/2bYJKJeHUvc5Dq+GrQS0Z2i4Cux0m/3SEqRuO31vBqhBClKceH4KHH1zcDwcWlXdvhBBCiAeGBFRCCCGEEOLBo9HC01+BSgNHVirBzjXCA7x58/G6AHz881Ey8sx3uZOllH0ZVrwC2KHF36D1cJfqc2m5vPLtbowWG39pUJXPnm+GRkZO3Tc0ahXvP92YfzxeD4DpG09ISCWEECVlqApdCh5Y+eV9yM8s3/4IIYQQDwgJqIQQQgghxIMptCk8XBDi/PwOWC1uTV7sEEntKj6k5ZiYEXfyLnewFGw2WPkaZKdAlQbQ8/+gyMiorHwzL8fsJjXbRKNQP74a1ELWnLoPqVQqxjxal3efbAjAV5tOSkglhBAl1WYEBNSGnEuw9fPy7o0QQgjxQJC/SoUQQgghxIOr67/AKwAuHYHd89yqdRo17z7ZCIDo7QkkXsm52z0smV9nwMkNoPWE56JB7+2sstvtvLX0AMdSsqjq60HUC63x8dCWY2dFWXu5Uy2XkOqLX06Uc4+EEOIeoNX/P3v3HR5VnbZx/DuTmfQGCZBQQ+gBBCkCigiioKBYWUAFBdeyYlmVdVcWddXdRVcEV9G1UJWmr4qoWJCmdEGKdEILECAVEtLLzPvHmQkZUkggyaTcn+ua6wxzzpzzBGPIzD3P84PB/zLub3wPkg+7tx4REakTBgwYQFRUFO+++667S3ELBVQiIiIiUnf51ofrJxn3V/0T0pOKHNK/XQOubRNKbr6dyd/tq+ICyyB+L6x4xbh/02vQKMpl98y1R/hpTxyeHmZm3N+D8CAfNxQpVa1wSPX2imjmrj/q3oJERGqCtjdB5ADIz4FlL7i7GhERqQNWrVrFnj17GD9+vLtLcQsFVCIiIiJSt3V/ABp1hqwU+OWNIrtNJhOThkZhNsEPu0+z9diZqq+xJPl58NVjxhtpbW8yvpZCfos5w2vfG6HaC7dGcUXT4KqvUdzmj9dGFqxJ9Y9vdvPt7yfdXJGISDVnMsHgf4PJDPu+heO/ursiERGRWk0BlYiIiIjUbWYPGOToQNoyE84eL3JIu7AA7urWFIBpPx2oyupKt2E6nNwKXkFwyzSXdadSMnN5YsFW8mx2bu3SmPt6NXdjoeIuj1/fmjF9WmC3wzOf7mD9wUR3lyQiUr01ioKu9xj3l78MWsdPRESk0iigEhERERGJHAAR1xqdSD+/XuwhTw5sg8VsYk10IpuPJldxgcVIjIZV/zbu3zQZAhu77P7nt3s4mZJFixBfJt/ZGVOh8ErqDpPJxEu3dmRI5zBy8m08Ou83DiekubssEZHq7bq/gYcnxKyFQyvcXY1IpTh69Cgmk4kHHnjA3aW4TUREBBEREZVy7v79++v3b5EyUEAlIiIiImIywcCXjPvbFxjhzwWa1fdleI9mAExd5uYuKrsdvpsA+dnQ+obzn/R2WLE3jv/77QQmE7w5vAv+XhY3FSrVgYfZxLQRXeneoh6pWXn88eMtpGblurssEZHqK7gZ9HzIuL/iFbDZ3FuPSDHGjRuHyWQiJCSE7Oxst9XhDGJMJhPffvtticf16tWr4LjVq1eXeNwrr7yCyWTCarVy+vTpEo974IEHCs43ffr0Eo8bMWJEwXFz5swpy5fkFnFxcTz++OP06tWLRo0a4eXlRdOmTRk4cCBffvkl9mrUzXnq1CkefPBBwsPD8fb2pl27dvzrX/8iN7d8v1+eOXOGCRMm0Lp1a7y8vGjQoAF33303u3fvLvE5CxYs4JprrsHf3x8/Pz969uxZ6n/XvXv3cu+99xIWFoaXlxctWrTgqaeeIjm5GnzgUAAFVCIiIiIihmY9od0QsOef70y6wOPXt8bTw8yGw0lsOpxUxQUWsvdrOLwaPLxgyBTX0X4ZuTz/5U4A/ti3JT0i6rupSKlOvCwe/O++boQHeXM4IZ0nF24j31Z93ugQEal2rn0GPP3h1A7Y85W7qxFxce7cOT777DNMJhPJycl89dVX7i4Ji8XCrFmzit23e/dufv31VyyW0j80ZbfbmT17NiaTiby8PObOnXtZ101OTmbJkiUlXnfFihWsWFE9uiSPHz/Oxx9/TFBQEHfccQfPPvssN910E7t27eKuu+7i4YcfdneJAJw+fZpevXoxe/Zsrr76av785z9Tv359Jk2axN13313mIC0pKYlevXrx5ptv0rBhQx5//HFuvPFGvvnmG6666io2bdpU5DnPPvss9957L4cPH+bee+9l7NixJCUlMXbsWCZMmFDk+I0bN9KzZ08WLVrE1VdfzZNPPknbtm15++236dOnD0lJbnw9JwUUUImIiIiIOA34u7HdvRgSDxbZ3STYh7t7GGtRfbTmcFVWdl5OBvww0bh/zVNQv6XL7inL9hN/LpvIBn48O6idGwqU6qphgDcfjemBt9XM6v0JTFm2390liYhUX36hcPUTxv1V/wZbvnvrESnk008/JT09naeffhqz2czMmTPdXRI333wz3377LQkJCUX2zZw5E7PZzODBg0s9x4oVKzh69CgPPfQQgYGBJQZPF15327Zt7Nixo8i+efPmkZ2dzZAhQ4p9bqtWrWjVqtVFr1EVunTpwpkzZ1i2bBnvv/8+//73v5kxYwYHDx6kQ4cOzJgxo9TOoqry17/+lePHj/Pee+/xxRdf8Nprr7F+/XpGjhzJ119/zaJFi8p0npdeeono6GieeeYZ1q9fz5tvvsmCBQtYvXo12dnZjBs3Dluh7tUtW7YwdepUWrduze7du/nggw+YPn06O3fupGfPnrz55pts2LDB5RoPPfQQ6enpLF68mC+//JI33niDn376if/85z8cOHCAv//97xX6dyOXRgGViIiIiIhTWCdoezNgh/X/LfaQP/ZtickEy/fGczD+XNXWB7B2GqSegKBm0Pdpl127YlOYvykGgH/e3glvq0fV1yfVWqcmQfzn7i4A/G/1IVbvj3dzRSIi1Vif8eAdDEnR6qKSamXmzJlYLBaee+45BgwYwIoVK4iJiSn22Pz8fF5//XVat26Nt7c3rVu3ZvLkyS5v/l8oPj6ep59+umD0WmhoKHfddRe7du0q8Tnjxo0jNzeXTz75xOXx3Nxc5s2bx6BBg2jatOlFvy6Ahx9+mOHDh3PgwAHWrFlT6nPuv/9+PDw8ig3pZs+eTYcOHejTp0+xz71wDaqkpCSaNm1KQEAABw+6flittH2lycrK4m9/+xvNmzfH29ubDh068M477xTpNLJarXh4FP3dPSAggJtuugmgXNetDOfOnePTTz8lMjKSRx55pOBxk8nEa6+9BsBHH31UpnMtWbIEs9nMyy+/7PJ4nz59uPXWW9mzZw8///yzy/EATz/9NPXrn58Q4efnVxA0vf/++wWPHzp0iF27dtGzZ0+GDRvmco1nn32WkJAQPvnkE9LT08tUr1QeBVQiIiIiIoU5Q5/tCyH1ZJHdkQ38ubFDIwBmrDlSlZXBudOwwTFjf9A/wdO3YJfNZufFJbuw2eGWK8K5ulVo1dYmNcawLo0Z3bsFAM98toO41Cw3VyQiUk15BUDvx4z7v0zRWlRSLezZs4eNGzcyaNAgGjVqxJgxY7DZbMyePbvY4x9++GH+9re/YbPZGD9+PIMHD2bq1Kk89dRTxR5/6NAhunfvzltvvUWrVq144oknGDJkCD/88AO9e/cudvQaQO/evYmKiipSxzfffENCQgLjxo0r9etKTk5m8eLFREVF0b17d8aMGQNw0e6wJk2aMGjQIBYsWEBOTk7B41u3bmX79u2MHTu21OcXFhISwscff0xGRgb33HOPy5pKDz74ILGxsUyfPp3WrVuX+Zx/+MMfmD9/PnfeeSePPvooaWlpPPnkk8WOpCtOVlYWK1euxGQy0bFjxzJftzJs2LCB7OxsbrzxRkyFRowDtGjRgnbt2rFu3Try8y/ecXr69GlCQ0Px9/cvsq9lS2NCxMqVK12OL7zvco43m800b96cjIwMNm7ceNFapXJptWQRERERkcKa94IW10DMOtjwLgz+V5FDHrkukmV74vhyayzPDGpLwwDvqqntlzcgNwOa9ICo21x2LdkRy9ZjZ/H19ODvQztUTT1SY/19aAe2xJxh76lUnlq0jfl/7I2H2XTxJ4qI1DW9Hob170D8Htj/HXS4xd0VyQXsdjuZudV/BKOP1aPIm/qXwhnYjB49GoA777yTxx57jNmzZ/Piiy9iNp/vR1i9ejWzZs2iS5curFu3Dj8/PwAmTpxI165diz3/mDFjOHXqFD/88IPLSL5JkybRo0cPHnroIX7//fdinztu3DgmTJjA5s2b6dmzZ0G9ISEh3HbbbaWu9zR//nyys7MLvq5rr72WiIgI/u///o+3336bwMDAEp/74IMP8v3337NkyRKGDx9ecF2LxcKYMWNKDO+Kc/311/Pcc8/x2muvMWnSJF5//XXee+89lixZwqhRo7j//vvLfC6AAwcOsGvXLoKCggB4+eWX6dWrF9OmTWPUqFH06NHD5fj4+Hjee+89bDYb8fHxfPfddxw/fpyXXnqpzMHY9u3by7UuWXBwMH/+858velx0dDQAbdq0KXZ/mzZt2L9/PzExMURGRpZ6rtDQUOLj40lLSysSUh05YnwI8MCBAy7HF95X3PEnTpwgIyMDX1/fUo+32WwcO3as4BoDBw4stVapXAqoREREREQu1PdpI6DaMhuuew68g1x2d29Rn27Ng9l67CwLNx3nqRuKf5FWoZIPw29zjPs3/AMKvcGRnZfPm8uMF3DjB7QmPMin8uuRGs3b6sG791zJLe+sZePhZN7/+RDjB5T908AiInWGTz0jpFrzpvFBkfZDXf4NFvfLzM0n6sUf3V3GRe15ZTC+npf3VqxzhF5gYCC33347AP7+/txxxx3MmzeP5cuXM2jQoILjP/74YwBefPHFgnAKjK6jp556ihdeeMHl/Nu2bWP9+vWMGzeuyHpRbdu25aGHHmLq1Kns2rWLTp06Falv9OjRPP/888yaNYuePXty8uRJfvzxRx5//HE8PT1L/dqc61Tdd999gDE27r777uOf//wnixYt4uGHHy7xucOGDSM0NJRZs2YxfPhwsrKyWLhwIUOHDqVRo0alXrc4r7zyCitWrGDKlCk0bdqUv/71r0RERLiMkCurF154oSCcAggKCmLSpEmMHj2auXPnFhtQFR57Z7VaeeONN3j22WfLfM3t27cXGZ1XmhYtWpQpoEpJSQFw+XoKc4aIzuNKc/PNNzN79mxefvll3njjjYLHN23axLfffgvA2bNnXY5/7bXXeOutt7jnnnsIDg4GICMjg8mTJ7vU6OvrS9u2bYmMjGTz5s0sXbqUoUOHFhzz1ltvkZSUVOQa4h4a8SciIiIicqHWN0CDDpCbboz6K8b9V0cAsGjzMfLyq2Dkz6p/gy0PWg2Elte67Fq46RgnzmTSMMCLcdcUHWMhUpzIBv68PMwYFfPf5dHsP+2GNdVERGqC3o+B1RdObYeDJXeAiFS2JUuWkJCQwPDhw/H2Pt/BX9I4vB07dgBGN9KFinvMOe4sLi6Of/zjH0Vu+/btAyjYXqhhw4YMHTqURYsWkZWVxdy5c8nPz7/oeL8tW7awY8cOBgwY4LJOVVnH/FmtVu677z6WLVtGbGwsixcv5syZMxe9bmnnW7hwIb6+vjz55JPk5OQwf/78Uru4SlLa3/22bduK7OvUqRN2u528vDyOHDnCyy+/zN///nfuuusu8vLyynTNBx54ALvdXubb0aNHy/11Xa5XXnmF8PBwpkyZQt++fZkwYQL33nsv/fr1IyoqCsClG7Bfv36MHj2a6OhooqKiePTRR3niiSfo3Lkzp06dKgjNnM8xmUy89957WK1Whg0bxl133cVzzz3H4MGDefbZZ+ncuXORa7jLgAEDiIqK4t1333V3KW6hDioRERERkQuZTHDVH2Hps7D5I7jqYbjgxctNncKo52vlVEoWq/cncENU+T+dWWZJh2Dn58b9gS+67ErPzmP6KmPB5CcHtsHHs+jiyiIlubt7U77fdZqV++L5y+c7+PJPV2PxcP8LdRGRasUvFHqMM9aBXDsV2tzg7oqkEB+rB3teGXzxA93Mx3r5v6M5gxpncOM0cOBAmjRpwpIlS0hOTqZ+/fqA0U1iNpsLxp0VVlxnUXJyMgBLly5l6dKlJdaRnp5e4r5x48bx1Vdf8cUXXzB79my6d+/OFVdccUlfV5s2bejduzcbN25k9+7dpa7BNG7cON566y3mzJnD6tWrCQsLY8iQIaVetzSRkZEFoxG7d+/O1VdffUnnKe7v2flYaZ1GHh4eRERE8Pzzz2OxWHjuuef46KOP+NOf/nRJdVQEZwhUUt2pqakux5WmadOmbN68mZdeeonvv/+eX3/9lWbNmvHKK68QERHByJEjadiwoctz5syZQ48ePZg5cyZz5szBx8eHwYMH85///IeOHTtisVgKvvcBBg8ezJo1a3j11VdZuXIlS5cupVOnTixevJgVK1awc+fOItdwh1WrVrkEs3WNAioRERERkeJcMRKWvwxJB+HwKmjtOpvcy+LB3d2b8tGaIyz49VjlBlTr3gLs0PYmaNzVZdfcDUdJTMuhRYgvI3o2q7wapFYymUxMvrMzN079md9PpPDBL4c16k9EpDh9xsOm940RwLFboUk3d1ckDiaT6bJH59UEx48fZ9myZQBcd911JR43b948nnzyScAICmw2G4mJiTRo0MDluLi4uCLPdXYIvfPOOzz++OOXVOeQIUMIDw/nr3/9K7Gxsbz33nulHp+ZmcnChcbEgvvvv7/ENZ5mzpzJ1KlTSzxP586d6dmzJ++++y5xcXFMmDABi+XSvy+mTp3KunXrCAkJ4ddff+W9997jscceK/d54uLiaN68eZHHoGxBDsCgQYN47rnnWL16dZkCqspag8q59pRzLaoLRUdH4+npWeTrLUmTJk2YMWNGkcf/8Y9/ABQZf2g2m3nyyScLvr+djh49SlpaGt26dcNqtbrs69WrV8HIwMLeeuutYq8hVa/2//QWEREREbkUXv7Q9R7jzahfPyoSUAGMuqo5H605wqr98Zw4k0HTer4VX0fqyfNjBvs+7bIrKzefWWuPAvDk9W2wqvNFLkGjQG9eurUjz/7fDv67PJqbO4UR2cD/4k8UEalLAhtDxzth52ew8T24q+ibqiKVac6cOdhsNvr27Uu7du2K7M/Ly2Pu3LnMnDmz4A38Ll26sHXrVtasWcOdd97pcvyaNWuKnKNXr14AbNiw4ZIDKg8PD8aMGcPrr7+Ot7c3o0aNKvX4zz//nJSUFLp27Ur37t2LPWb+/Pl88sknvPbaa6WuZTVu3LiCAOdSx/uBMXpv4sSJtGvXjlWrVnHNNdcwYcIErrvuulK7uIqzZs0a7r333iKPAVx55ZVlOsfJkycBioQvJamsNah69+6Np6cnP/30E3a7HVOh9fhiYmLYv38/AwYMuKxgMD8/n0WLFmGxWLjrrrvK9Jz58+cDMHLkyDIdHxMTw9q1a4mKiioY9Sfuo1ewIiIiIiIl6flHY3vgB0g5UWR3ZAN/rm4Vgt0OX26NrZwaNrwLtlxofjU07+2y68utsSSmZdM4yJthXRtXzvWlTrizWxP6tW1ATr6Nl77ejd1ud3dJIiLVTx9H98TuxZBSSf/uixTDbrcze/ZsTCYTc+fOZcaMGUVuc+bMoU+fPvz+++9s2bIFgNGjRwPGej+Fx/LFxsby3//+t8h1rrrqKnr16sXChQv59NNPi+y32Wz8/PPPF633mWeeYfHixfz4448EBweXeqxzvN/UqVOL/bpmzJjBHXfcQWJiIl9//XWp57rvvvtYvHgx33//fbEhXlmkp6cXhGoLFy4kPDycBQsWkJuby6hRo8jKyirX+V599VWXkXgpKSn885//xGQyuXSL7dixg9zc3CLPT05OZuLEiQBlHllYWWtQBQYGMnLkSA4fPswHH3xQ8Ljdbuf5558H4KGHHnJ5TkpKCvv27ePUqVMuj+fm5pKZmenymM1mY8KECezfv58nnniCxo1dX984RwgWtmbNGiZPnkyLFi149NFHXfalpaUV+Z02JSWF0aNHk5+fz+TJk8v0dUvlUgeViIiIiEhJQttAi2uMcT6/fwrXPlvkkDuubML6Q0l8tT2WJ65v7fJJwsuWlQq/zTHuX9A9lW+z8+EvhwB48NpIdU/JZTGZTLw8rCODp/3CmuhEftx9mps6hbu7LBGR6qXxldCiL8SshV8/hBvL3qEgcjlWrlzJkSNHuO6664iMjCzxuLFjx7JhwwZmzpxJjx49GDBgAGPHjmX27Nl07tyZO+64g+zsbD799FN69+5d7OizhQsXMmDAAEaOHMlbb71Ft27d8PHx4dixY2zYsIGEhISLhjQNGzbk9ttvv+jXdfDgQX755RciIiLo379/qV/XwoULmTlzJnfffXeJx/n7+5fpuqV56qmn2L9/P1OmTCnocOrduzcvvfQSL7zwAn/5y1945513yny+tm3b0qlTp4JuoC+++IITJ07wzDPPuIyXmzZtGt9++y3XXHMNzZs3x8fHh5iYGJYuXUp6ejrDhw+/aDdaVXjttddYtWoVjz32GMuXL6d169b8/PPPbNy4kVtvvbVIF9PixYsZO3Ys999/P3PmzCl4PC4ujo4dOzJo0CBatmxJTk4OP/74I/v27WPo0KHFhkd33303mZmZXHHFFQQGBrJz506+//576tevz1dffUVAQIDL8V999RUTJ07k+uuvp3HjxsTHx/P111+TkJDAq6++yrBhwyrl70jKR69iRURERERK0/UeY7t9ARTTVXJTpzC8LGYOJ6Sz+2TRT/Vdlh2LICcNQttCmxtddv205zRHkzII8rEyUmtPSQVoGerHI9cZb3q98s0eMnLy3FyRiEg11Ge8sf1tNmSnubcWqTOcXUYPPPBAqceNGDECHx8fFi5cWNCd8tFHHzF58mRMJhPTp0/n+++/55lnnilYg+dCLVu2ZNu2bUyaNIm0tDRmz57NBx98wPbt2+nXr1/BelEVYdasWdjtdu6///5SP+Q1cOBAmjVrxrJlyzh+/HiFXf9CX3zxBTNnzuTGG2/kmWeecdk3ceJE+vXrx/Tp04sN9kry2Wefcc899/Dll1/yv//9Dz8/P95++22mTJnictzo0aMZMmQI+/btY+7cuUybNo2ff/6Zfv36sWjRIj777LOK/SDcJQoPD2fTpk2MHTuWtWvXMm3aNJKSknj11Vf5/PPPy1xjUFAQt912G1u3buWdd95h1qxZ1KtXj48++oivv/4aLy+vIs+5/fbbyc7OZv78+UydOpV9+/bxxBNPsGvXLrp27Vrk+M6dO9OlSxeWLVvGlClTWLJkCb169WLlypVMmjTpcv8qpIKY7JrdUKlOnDhBs2bNOH78OE2bNnV3OSIiIiJSXtnnYEpbyM2AB5dDs55FDhk/fytLd57ij31bMumWqIq5rt0O7/aCxP1w83+g1yMuu++dsZF1B5N4rH8rnrupfcVcU+q8zJx8bpj6M7FnM3ny+tY8M+jSxuOIiNRaNhtM7w7Jh2HIFLjqoYs/RypEVlYWR44coWXLlnh7e7u7HBGRYpX1Z5VyA4M6qERERERESuMVAB0c4x+2zy/2kNsc6z99veMk+bYK+vzX0TVGOGX1gy6uozIOJ6Sx7mASJhPc06t5xVxPBPDx9GDS0A4AfLTmCPGp5VtnQUSk1jOb4aqHjftbZhfbXS0iIiJlo4BKRERERORiujrmve/6EvJyiuzu364hQT5W4s9l8+uR5Iq55q8fGdsuI8A7yGXXwl+PATCgXUOa1vOtmOuJONzUKYwrmweTmZvPf1dEu7scEZHqp8tIsHhD/G44sdnd1YiIiNRYCqhERERERC4m4lrwbwTZKXDk5yK7PS1mbujQCICf9sRd/vXSk2D/d8b9Hg+67MrKzef/fjsBwL3qnpJKYDKZ+JtjbOSizcc5lKA1VkREXPjUg053Gfe3zHJvLSIiIjWYAioRERERkYsxe0CHW437e74q9pBBHR0B1d7TXPYyr7u/BFsehHeBsE4uu37aE8fZjFwaB3nTv13Dy7uOSAl6RYYwsH1D8m12pvy4393liIhUPz3GGdtdX0JGBXVPi4iI1DEKqEREREREyiLqNmO7bynk5xbZfW2bULwsZo4nZ7Lv9LnLu9aOhcb2ipFFdi3ZHgvAHd2a4GE2Xd51RErx3E3tMZng+12n2X+539M1SF6+jTPpORxPzuBoYjrx57LIyMlzd1kiUt006Q5hnSE/+/y/2yIiIlIuFncXICIiIiJSI7S4BnxDISMRjq6BVte77Pb1tHBtmwYs3xvHst1xdAgPvLTrJEZD7G9g8oDOd7vsOpOew+r9CQDc1rXJpZ1fpIzahQVwc6cwvtt5mndXHeTtUVe6u6QKl2+zszM2hQ2HkvgtJpnDCekcS84gz1a0C7K+nyfN6/sS1TiQHi3q0SsyhCbBPm6oWkSqBZPJ6KL69mnYMht6P2Y8JiIiImWmDioRERERkbJwGfO3pNhDCo/5u2Q7Fhnb1jeAv+sIv+92nSLPZqdDeCBtGwVc+jVEymj8gNYAfPv7SQ7XorWo9p1O5dVv99Bn8gpuf3cdr/+wj+V74zmcmF4QTnlZzPh5ehS835ycnsP242dZsOkYz3y2g2teW8mw6Wv53+pDnE7JcuNXIyJu03k4WH0hKRpObHF3NSIiUgMNGDCAqKgo3n33XXeX4hbqoBIRERERKasOt8Bvs+HAMrDbi3xS+vr2RqC0KzaVxLRsQv29ynd+u/38GldX/KHI7iXbTwJwW9fG5S5d5FJ0bBzEwPYNWbEvnv+tPsQbw7u4u6TLsvloMu+uOljQiQgQ6G2hd2QIV7WsT1R4IC0b+BHq74XVw/g8p91uJy07j2PJGRxNzGD78TP8evQMO0+c5fcTKfx+IoU3l+1ncKcwHuzbkm7N67nryxORquYVAB2Gwe+LYMcCaNbT3RWJiEgNs2rVKpo2beruMtxGAZWIiIiISFm16Gt8UvrcSYjbDWGdXHaH+nvRITyQvadSWXcwsfxj+BL2Q9JB8PCENoNcd53LZvNRYxH2W7sooJKq8/j1rVmxL57F22J5dlA7woK83V1SuZ04k8G/v9vLdzuN7kazCQZ3DOOubk3p17YBnpaSh4uYTCYCvK10bBxEx8ZBDL0iHIDEtGyW7Y7jq22x/Ho0maW/n2Lp76e4vn1DJgxqR1TjSxzzKSI1S5eRRkC16wsYPBmsNe9npIiIiLtoxJ+IiIiISFlZvSHiWuN+9LJiD+nXJhSANdGJ5T//vm+MbWR/8HZ9c3vVvnjsdujcJEjr3kiVurJ5Pa5qWZ88m51PNh51dznlYrfbmb8phhum/sx3O09jNsGoq5qx8tn+/O++7twQ1ajUcKo0of5e3NOrOZ892ofvnryWu7s3xcNsYuW+eIa+s4a/L95JSmZuBX9FIlLttOwHgU0gKwUOfO/uakRERGoUBVQiIiIiIuXR5kZje3B5sbv7OgKqtdGJ2O328p17ryOgan9LkV3L9sQBcEOHRuU7p0gFGHdNBAALNh0jKzffvcWUUUpmLg9/8ht/X7yLrFwbV7Wsz9Inr2XynVcQEepXodeKahzIlOFd+OnpftxyRTh2O8zfdIwbpv7Mj7svY006Ean+zB5wxQjj/vaF7q1FRESkhlFAJSIiIiJSHs6A6thG49PSF+gZUR9Pi5nTqVkcjE8r+3nPHoNTO8BkhnZDXHZl5uSz9qCxZs6NUQqopOrdGBVGk2AfzmTk8rVjLbTqLCYpnTvfW8dPe+Lw9DAzaWgHFj3Umw7hlTt2L7KBP9Pv6cbCh3oTGepHwrlsHvnkN174aleNCfZE5BJ0vcfYHlwOafHurUVERKQGUUAlIiIiIlIe9SIgpA3Y8+Hwz0V2e1s9uCqiPgBrD5ZjzN+BH41ts97g38Bl17qDiWTl2mgS7EOH8IBLrVzkknmYTdx/dQsAZq07Uv7uwCq080QKd7y3nkMJ6YQHefPlY1fzx2sjMZtNVVZDn1YhfPfUtTzSLxKATzbGcMd76zmenFFlNYhIFQptA016GL8b7PrS3dWIiIjUGAqoRERERETKK7K/sT26ttjdvSONgGrL0TNlP+ehlca2zQ1Fdq0+YHwae2CHhphMVfcmu0hhI3o0x8fqwb7T59h2/Ky7yynWzhMp3DtjI8npOXRuEsSS8dfQqUmQW2rxtnrw/JAOzBnbkxA/T/aeSuWO99ax9Vg5fi6ISM3R6S5ju+crt5YhUlZHjx7FZDLxwAMPuLsUt4mIiCAiIqJSzt2/f3/93i5SBgqoRERERETKK6KvsS0hoOrh6KDaEpNctk6T/Fw48otxv9X1RXavP5gEQN/WoeWvVaSCBPlaGdI5HIDPNh93czVF7T2Vyr0zNpKalUf3FvVY+HBvGgZ6u7ss+rdryLdP9qVDeCCJaTmM/HAjP+zSulQitU7Ubcb22EZIrf6jUKXmGzduHCaTiZCQELKzs91WhzOIMZlMfPvttyUe16tXr4LjVq9eXeJxr7zyCiaTCavVyunTJf97+cADDxScb/r06SUeN2LEiILj5syZU5YvyS3i4uJ4/PHH6dWrF40aNcLLy4umTZsycOBAvvzyy2rVvX7q1CkefPBBwsPD8fb2pl27dvzrX/8iNze3XOc5c+YMEyZMoHXr1nh5edGgQQPuvvtudu/eXeJzFixYwDXXXIO/vz9+fn707Nmz1P+ue/fu5d577yUsLAwvLy9atGjBU089RXJycrHHZ2Vl8eqrrxIVFYW3tzf16tXj5ptvZt26deX62qTsFFCJiIiIiJRXi2uMbfxuyCj64qZL02CsHibiUrM5cSbz4uc7sRly0sA3BMK6uOw6lZLJ4cR0zCboFRlSEdWLXLI/9GgKwDc7TpKRk+fmas6LT83iwTmbC8KpueOuwt/L4u6yCoQH+fD5o30Y2L4hOXk2xi/Yytc79Aa2SK0S1MQY04sd9nzt7mqkljt37hyfffYZJpOJ5ORkvvrqK3eXhMViYdasWcXu2717N7/++isWS+n/NtvtdmbPno3JZCIvL4+5c+de1nWTk5NZsmRJidddsWIFK1asuOg1qsLx48f5+OOPCQoK4o477uDZZ5/lpptuYteuXdx11108/PDD7i4RgNOnT9OrVy9mz57N1VdfzZ///Gfq16/PpEmTuPvuu8scpCUlJdGrVy/efPNNGjZsyOOPP86NN97IN998w1VXXcWmTZuKPOfZZ5/l3nvv5fDhw9x7772MHTuWpKQkxo4dy4QJE4ocv3HjRnr27MmiRYu4+uqrefLJJ2nbti1vv/02ffr0ISkpyeX4rKwsBg4cyIsvvojVauVPf/oTt99+O+vWreO6665jyZIll/aXJqVSQCUiIiIiUl7+DaBBe+N+TNFP0/l4etCxsTFWbEtM8Z/Oc+Ec7xc5AMyuv6Kvc3RPXdE0mCAf66XXLFIBrmpZn4gQX9Jz8ln6+yl3lwNAZk4+D87dwsmULCIb+DHr/p7VKpxy8vOy8OGYHtzZrQn5Njt/XrSNL3474e6yRKQidbzd2O5e7NYypPb79NNPSU9P5+mnn8ZsNjNz5kx3l8TNN9/Mt99+S0JCQpF9M2fOxGw2M3jw4FLPsWLFCo4ePcpDDz1EYGBgicHThdfdtm0bO3bsKLJv3rx5ZGdnM2TIkGKf26pVK1q1anXRa1SFLl26cObMGZYtW8b777/Pv//9b2bMmMHBgwfp0KEDM2bMKLWzqKr89a9/5fjx47z33nt88cUXvPbaa6xfv56RI0fy9ddfs2jRojKd56WXXiI6OppnnnmG9evX8+abb7JgwQJWr15NdnY248aNw2azFRy/ZcsWpk6dSuvWrdm9ezcffPAB06dPZ+fOnfTs2ZM333yTDRs2uFzjoYceIj09ncWLF/Pll1/yxhtv8NNPP/Gf//yHAwcO8Pe//93l+OnTp7N+/XqGDx/O1q1bmTZtGrNnz2bbtm34+fnx0EMPce7cucv/SxQXCqhERERERC6Fs4vqaPHjHnpG1ANgc1nWoXIGVMWO90sE4JrW6p4S9zOZTAzv0QyA/9tSPcKVl7/Zzc7YFOr5Wpn9QE+CfKtvkOthNjHl7i6M7NkMmx3+8vkOjfsTqU2cY/6Ob4SUWPfWIrXazJkzsVgsPPfccwwYMIAVK1YQExNT7LH5+fm8/vrrtG7dGm9vb1q3bs3kyZNd3vy/UHx8PE8//XTB6LXQ0FDuuusudu3aVeJzxo0bR25uLp988onL47m5ucybN49BgwbRtGnTi35dAA8//DDDhw/nwIEDrFmzptTn3H///Xh4eBQb0s2ePZsOHTrQp0+fYp974RpUSUlJNG3alICAAA4ePOhybGn7SpOVlcXf/vY3mjdvjre3Nx06dOCdd94p0mlktVrx8PAo8vyAgABuuukmgHJdtzKcO3eOTz/9lMjISB555JGCx00mE6+99hoAH330UZnOtWTJEsxmMy+//LLL43369OHWW29lz549/Pzzzy7HAzz99NPUr1+/4HE/P7+CoOn9998vePzQoUPs2rWLnj17MmzYMJdrPPvss4SEhPDJJ5+Qnp5e5Br/+Mc/XP5btGrVinHjxpGQkMDnn39epq9Pyk4BlYiIiIjIpXCuQxVT+jpUv10soMpOg5PbjfuR17nsstvtrD9kdFBd00rrT0n1cHf3pphN8OvRZI4nZ7i1lm92nGTR5uOYTDD9nm60CPFzaz1lYTab+PcdnRnRwwipnly0jU2Hky7+RBGp/gIbQ3PHG+F7NApKKseePXvYuHEjgwYNolGjRowZMwabzcbs2bOLPf7hhx/mb3/7GzabjfHjxzN48GCmTp3KU089Vezxhw4donv37rz11lu0atWKJ554giFDhvDDDz/Qu3fvYkevAfTu3ZuoqKgidXzzzTckJCQwbty4Ur+u5ORkFi9eTFRUFN27d2fMmDEAF+0Oa9KkCYMGDWLBggXk5OQUPL5161a2b9/O2LFjS31+YSEhIXz88cdkZGRwzz33uKyp9OCDDxIbG8v06dNp3bp1mc/5hz/8gfnz53PnnXfy6KOPkpaWxpNPPlnsSLriZGVlsXLlSkwmEx07dizzdSvDhg0byM7O5sYbb8RkMrnsa9GiBe3atWPdunXk5+df9FynT58mNDQUf3//IvtatmwJwMqVK12OL7zvco43m800b96cjIwMNm7ceMnXkIpR/eYeiIiIiIjUBM16Gdu4PZCTAZ6+Lru7NgsGIDr+HJk5+fh4Fv1EJACxW8CeD0HNIcj1U6UnzmRyOjULi9lEtxb1KvorELkkjQK96R0ZwvpDSXz7+yn+1N89o3Fiz2Yy8cudAIzv35prWtecENdsNvGvOzqRnJHDT3vi+OPHW/jyT1fTplGAu0sTkcsVdRsc2wD7lkKfx9xdTd1gt0Ouez8wUSZWX7jgTf1L4QxsRo8eDcCdd97JY489xuzZs3nxxRcxFxoXvXr1ambNmkWXLl1Yt24dfn7GBzkmTpxI165diz3/mDFjOHXqFD/88IPLSL5JkybRo0cPHnroIX7//fdinztu3DgmTJjA5s2b6dmzZ0G9ISEh3HbbbaWu9zR//nyys7MLvq5rr72WiIgI/u///o+3336bwMDAEp/74IMP8v3337NkyRKGDx9ecF2LxcKYMWNKDO+Kc/311/Pcc8/x2muvMWnSJF5//XXee+89lixZwqhRo7j//vvLfC6AAwcOsGvXLoKCjPHfL7/8Mr169WLatGmMGjWKHj16uBwfHx/Pe++9h81mIz4+nu+++47jx4/z0ksvlTkY2759e7nWJQsODubPf/7zRY+Ljo4GoE2bNsXub9OmDfv37ycmJobIyMhSzxUaGkp8fDxpaWlFQqojR44Axt9d4eML7yvu+BMnTpCRkYGvr2+px9tsNo4dO1ZwjYEDBxZc4+DBgxw5coSoqKiL1iQVQwGViIiIiMilCGwM/mGQdhpO7YAWrqNDGgV60yDAi4Rz2ew5lUr3kgKmY45P7TXvXWTX1mNG91XHxoF4W0sIuETc4JYrGjsCqpNuCajsdjsvfLWLc9l5dGsezJ9vKP6NkurM4mHmnVFXMnrmJjYfPcPDn/zGV+Ov0VpzIjVdu5vhh78ZIVVGMvjWv/hz5PLkZsC/G7u7ioubeBI8L6/T1zlCLzAwkNtvvx0Af39/7rjjDubNm8fy5csZNGhQwfEff/wxAC+++GJBOAVG19FTTz3FCy+84HL+bdu2sX79esaNG1dkvai2bdvy0EMPMXXqVHbt2kWnTp2K1Dd69Gief/55Zs2aRc+ePTl58iQ//vgjjz/+OJ6enqV+bc51qu677z7AGBt333338c9//pNFixbx8MMPl/jcYcOGERoayqxZsxg+fDhZWVksXLiQoUOH0qhRo1KvW5xXXnmFFStWMGXKFJo2bcpf//pXIiIiXEbIldULL7xQEE4BBAUFMWnSJEaPHs3cuXOLDagKj72zWq288cYbPPvss2W+5vbt24uMzitNixYtyhRQpaSkALh8PYU5Q0TncaW5+eabmT17Ni+//DJvvPFGweObNm3i22+/BeDs2bMux7/22mu89dZb3HPPPQQHBwOQkZHB5MmTXWr09fWlbdu2REZGsnnzZpYuXcrQoUMLjnnrrbdISkoq9hobN27klVdeYf78+QVj/o4cOVIQchY+XiqGRvyJiIiIiFwKkwmaOl5Qxm4p9pDOTYwXb7tiS3mRdsyxmG8xAdW2Y2cBuLK5uqekermpUxgeZhO7T6ZyJDH94k+oYN/tPM3KffFYPUz85+4rsHjUzJe23lYP3r+vO02CfTiSmM6fF20j32a/+BNFpPqqFwENo4zu6IPL3V2N1DJLliwhISGB4cOH4+3tXfB4SePwduzYARjdSBcq7jHnuLO4uDj+8Y9/FLnt27cPoGB7oYYNGzJ06FAWLVpEVlYWc+fOJT8//6Lj/bZs2cKOHTsYMGCAyzpVZR3zZ7Vaue+++1i2bBmxsbEsXryYM2fOXPS6pZ1v4cKF+Pr68uSTT5KTk8P8+fNL7eIqSWl/99u2bSuyr1OnTtjtdvLy8jhy5Agvv/wyf//737nrrrvIy8sr0zUfeOAB7HZ7mW9Hjx4t99d1uV555RXCw8OZMmUKffv2ZcKECdx7773069evoHupcDdgv379GD16NNHR0URFRfHoo4/yxBNP0LlzZ06dOlUQmjmfYzKZeO+997BarQwbNoy77rqL5557jsGDB/Pss8/SuXPnItd4+umniYqK4tNPP6V79+4888wzjBs3jq5du9KiRYsix0vFUAeViIiIiMilatIN9n0Lsb8Vu7tTkyBW7otnZ0kBVX4eHN9s3G9edPFmZweVxvtJdVPfz5NrWofyy4EEvt1xkicGVl0HU1p2Hv/4ZjcAj/VvTeuGNXssXoi/Fx+M7s5d/1vPqv0JvL0imqdvbOvuskTkcrS7GeL3wP7v4Yo/uLua2s/qa3QnVXdW34sfcxHOoMYZ3DgNHDiQJk2asGTJEpKTk6lf3+jcS0lJwWw2F4w7K6y4zqLk5GQAli5dytKlS0usIz295A+njBs3jq+++oovvviC2bNn0717d6644opL+rratGlD79692bhxI7t37y51DaZx48bx1ltvMWfOHFavXk1YWBhDhgwp9bqliYyMLBiN2L17d66++upLOk9xf8/Ox0rrNPLw8CAiIoLnn38ei8XCc889x0cffcSf/vSnS6qjIjhDoJLqTk1NdTmuNE2bNmXz5s289NJLfP/99/z66680a9aMV155hYiICEaOHEnDhg1dnjNnzhx69OjBzJkzmTNnDj4+PgwePJj//Oc/dOzYEYvFUvC9DzB48GDWrFnDq6++ysqVK1m6dCmdOnVi8eLFrFixgp07d7pcIyAggHXr1vHKK6+wePFipk+fTsOGDXn00Ue55ZZb6NevX5Ga5PIpoBIRERERuVRNuhvbE8UHVBftoIrbCbnp4B0EDdq77MrKzWfPSeNFXrfmwRVSrkhFuqVzOL8cSGDpzlNVGlB99MthEs5lExHiy2MD3LP+VUXr1CSIyXd25pnPdvDOymiubRNKjwiNBROpsdreDGveNDqo8nLAUvpoM7lMJtNlj86rCY4fP86yZcsAuO6660o8bt68eTz55JOAERTYbDYSExNp0KCBy3FxcXFFnuvsEHrnnXd4/PHHL6nOIUOGEB4ezl//+ldiY2N57733Sj0+MzOThQsXAnD//feXuMbTzJkzmTp1aonn6dy5Mz179uTdd98lLi6OCRMmYLFc+lvfU6dOZd26dYSEhPDrr7/y3nvv8dhj5V9XLi4ujubNmxd5DMoW5AAMGjSI5557jtWrV5cpoKqsNaica08516K6UHR0NJ6enkW+3pI0adKEGTNmFHn8H//4B0CR8Ydms5knn3yy4Pvb6ejRo6SlpdGtWzesVtdRyb169SoYGVjYW2+9Vew1goODmTp1apHvtTlz5hR7fEUYMGAAVquV8ePHM378+Ao/f3WngEpERERE5FI1vhIwQcoxSIsHf9dP1HVqYrzIj45PIys3v+g6Us7Oq6Y94YJxETtjU8iz2WkY4EWTYJ/K+gpELtmNUY0wfwn7Tp/jxJkMmta7/E+GX0z8uSw+WnMYgOduao+XpfaszXZnt6asPZjIl1tj+fOn2/nuqWsJ9NZ6VCI1UpPu4NcA0hPg2HqI7O/uiqQWmDNnDjabjb59+9KuXbsi+/Py8pg7dy4zZ84seAO/S5cubN26lTVr1nDnnXe6HL9mzZoi5+jVqxcAGzZsuOSAysPDgzFjxvD666/j7e3NqFGjSj3+888/JyUlha5du9K9e/dij5k/fz6ffPIJr732WqlrWY0bN64gwLnU8X5gjN6bOHEi7dq1Y9WqVVxzzTVMmDCB6667rtQuruKsWbOGe++9t8hjAFdeeWWZznHypNEheGH4UpLKWoOqd+/eeHp68tNPP2G32zGZTAX7YmJi2L9/PwMGDLisYDA/P59FixZhsVi46667yvSc+fPnAzBy5MgyHR8TE8PatWuJiooqGPVX0dcoj1WrVrmMtqxrFFCJiIiIiFwq7yAIaQVJB+H0Tmg90GV3WKA3of6eJKblsO/0Obo2C3Z9/iljXQDCuxY5tbPr6oqmwS4v/kSqi3p+nnRvUY/NR8+wal88o/tEVPo1314RTUZOPl2bBXNzp7BKv15Ve3lYR7YcPcOx5Az+sWQ3U0d0dXdJInIpzGZoOxi2zTPG/Cmgkstkt9uZPXs2JpOJuXPnEhkZWexxBw4cYMOGDWzZsoUePXowevRoZs+ezSuvvMLgwYPx8zM6zWJjY/nvf/9b5PlXXXUVvXr1YuHChQwbNowRI0a47LfZbKxZs6bUDi6AZ555ht69e1O/fn2Cg4NLPdY53m/q1KkMGDCg2GMyMjJYuHAhX3/9NXfffXeJ57rvvvsICwvD29u72BCvLNLT0wtCtYULFxIeHs6CBQu49tprGTVqFL/++qvL+l8X8+qrr3LLLbe4jMf75z//iclkcukW27FjB1FRUUVCqOTkZCZOnAhQ5pGFDzzwAA888ECZayyrwMBARo4cyccff8wHH3zAo48+Chjfn88//zwADz30kMtzUlJSCtaICg8PL3g8NzeXvLw8fHzOfxDPZrMxYcIE9u/fz9NPP03jxo1dzpWamlpkHbA1a9YwefJkWrRoUVCPU1paGn5+fi6vpVJSUhg9ejT5+flMnjy5yNdY3DWmTZvG8uXLueOOO+jZs+dF/56kfBRQiYiIiIhcjkadjIAqbneRgMpkMtG2UQCJaUlExxUTUJ3cbmzDuxQ5rXO8X1Tj8i/GLFJVrm/fiM1Hz7CiCgKq0ylZfLr5OADP39y+Vga3Ad5W3hrZlbv/t54vt8Vya9fGDGintQ5EaqS2NxkB1cHl7q5EaoGVK1dy5MgRrrvuuhLDKYCxY8eyYcMGZs6cSY8ePRgwYABjx45l9uzZdO7cmTvuuIPs7Gw+/fRTevfuXezos4ULFzJgwABGjhzJW2+9Rbdu3fDx8eHYsWNs2LCBhIQEsrKySq23YcOG3H777Rf9ug4ePMgvv/xCREQE/fv3L/XrWrhwITNnziw1oPL39y/TdUvz1FNPsX//fqZMmVLQ4dS7d29eeuklXnjhBf7yl7/wzjvvlPl8bdu2pVOnTgXdQF988QUnTpzgmWeecRkXN23aNL799luuueYamjdvjo+PDzExMSxdupT09HSGDx9+0W60qvDaa6+xatUqHnvsMZYvX07r1q35+eef2bhxI7feemuRDqPFixczduxY7r///oIxeWCMOezYsSODBg2iZcuW5OTk8OOPP7Jv3z6GDh1abHh09913k5mZyRVXXEFgYCA7d+7k+++/p379+nz11VcEBLiuS/rVV18xceJErr/+eho3bkx8fDxff/01CQkJvPrqqwwbNqzINZo0acKAAQNo06YNJpOJ1atX89tvvxWsfSUVz3zxQ0REREREpESNOhnbuN3F7m7T0B8wxvy5yMuG+L3G/WICqr2nHQFVuAIqqb4GdjDCk/WHksjIyavUa81ad4TcfDtXtaxPr8iQSr2WO3VrXo8H+7YEYNLiXaRnV+7fq4hUkpb9wORhfIjl7DF3VyM1nPON8Yt1xYwYMQIfHx8WLlxIZmYmAB999BGTJ0/GZDIxffp0vv/+e5555pmCNXgu1LJlS7Zt28akSZNIS0tj9uzZfPDBB2zfvp1+/foVrBdVEWbNmoXdbuf+++8v9YMnAwcOpFmzZixbtozjx49X2PUv9MUXXzBz5kxuvPFGnnnmGZd9EydOpF+/fkyfPr3YYK8kn332Gffccw9ffvkl//vf//Dz8+Ptt99mypQpLseNHj2aIUOGsG/fPubOncu0adP4+eef6devH4sWLeKzzz6rFh/OCQ8PZ9OmTYwdO5a1a9cybdo0kpKSePXVV/n888/LXGNQUBC33XYbW7du5Z133mHWrFnUq1ePjz76iK+//hovL68iz7n99tvJzs5m/vz5TJ06lX379vHEE0+wa9cuunbtWuT4zp0706VLF5YtW8aUKVNYsmQJvXr1YuXKlUyaNKnYuu677z4OHDjA+++/z/vvv4/NZuONN95g7dq11KtXr1x/V1I2Jrvdbnd3EbXZiRMnaNasGcePH6/TsyRFREREaq1938GiUdCoM/xpbZHd8zbGMOmrXfRv14A5Y686v+PkNviwP/jUg+eOGAt8O+Tm2+j44o/k5Nv45S8DaB5S+Wv7iFwKu93Otf9ZxYkzmXw0pgc3RjWqlOukZORy9WsrSM/JZ/bYnrW+qygjJ48bp/5C7NlM/ti3JZNuiXJ3SSJyKWbcCCd+hVvfhu73X/x4KVVWVhZHjhyhZcuW5RqxJiJSlcr6s0q5gUEdVCIiIiIil6ORY6HkhH2Qn1tkd9tGxqiJ6LgLOqgK1p/q4hJOARxOSCcn30aAl4Wm9XwQqa5MJhMD2xth0ar98ZV2nXmbYkjPyad9WAD92zaotOtUF76eFv55h9GdOWvdEfY5OipFpIZpdb2xPbzKvXWIiIhUUwqoREREREQuR3Bz8AoEWy4kRhfZ7RzxF3s203VUV+GA6gJ7TqUA0D48ALPZ/aM8RErTt40RGG08lFQp58+32Vn4qzEe64/XRlaL8TZVYUC7htzcKQybHf61dC8afiJSA7UaYGwPrwZbvltLERERqY4UUImIiIiIXA6T6XwXVdyuIrvr+XkS6m/MUD9YeB2q+H3GtlHnIs/Ze+ocAB20/pTUAL0i62M2weHEdE6lZFb4+X+JTuDEmUyCfKzcckV4hZ+/Ovvbze3x9DCzJjqR1fsT3F2OiJRXk+7gGQCZZ85/MEVEREQKKKASEREREblcDdob28QDxe52dlFFOwMqux0S9jqe27bI8c4gq41jPKBIdRbobaVz02AA1h+s+C6qBZuM7qk7uzXB2+pR4eevzlqE+PHANREA/HPpHnLzbe4tSETKx8MKLfsZ9w+tdG8tIiIi1ZACKhERERGRyxXS2tgWM+IPoGUDPwBiktKNB9ITjU9TY4KQNkWOP5xgBFStQv0qvFSRynB1qxAA1lfwmL/41CxW7jPWtrq3V/MKPXdN8fj1ranv58mhhHQWb411dzkiUl7OMX9HfnFvHSIiItWQAioRERERkcvlDKiSDhW7u0V9XwBikjKMBxIc4/3qtQBPX5djc/JsHD9jjEmLbOBf8bWKVAJnQLXhUGKFrpW0dOcp8m12rmweTOuGdbOjMNDbyp+uawXAO6ui1UUlUtO0uMbYntgM+bnurUVERKSaUUAlIiIiInK5Qh1dUMmHwFb0zeMWIc6AytFBlbjf8bx2RY49lpxBvs2On6cHjQK9KqVckYrWo0V9rB4mTqZkcSw5o8LO+82OkwAM69K4ws5ZE93buzmh/p4cT87ky60n3F2OiJRHg/bgUw9yM7QOVQWpyA9CiIhUNP2MKh8FVCIiIiIilyu4OZgtxptP504W2d28vmPEn/ON+wRHQNWgaEDlHO/XsoEfJpOpcuoVqWA+nh50bBwEwNZjZyrknMeTM9h67CwmEwztHF4h56ypfD0tPOrsolp5UF1UIjWJ2QzN+xj3Y9a5t5YazsPDWIcwN1edaCJSfTl/Rjl/ZknpFFCJiIiIiFwuDyvUizDuJx0sstvZQXU2I5eUzNzzI/4atC9y7OFEo8sqMlTj/aRm6da8HgBbY85WyPmW7jwFQO+WITQM9K6Qc9Zk9/ZqQai/FyfOZBZ0lolIDdHiamMbs8G9ddRwVqsVLy8vUlJS1KEgItWS3W4nJSUFLy8vrFaru8upESzuLkBEREREpFYIaW2EU0kHIbK/yy4/Lwuh/l4kpmVzLCmDzkmHjR3O0YCFODuoIhv4VXbFIhWqW4tgZq2ruA6qH3adBmDoFXW7e8rJx9ODsddE8MaP+5mx5gh3XNlEXZYiNUVzR0B1bL0xCtisz4tfqtDQUGJjYzlx4gRBQUFYrVb9LBQRt7Pb7eTm5pKSkkJaWhpNmjRxd0k1hgIqEREREZGKENLa2CYdKnZ3ixBfEtOyOZ5whs6pscaD9VoWOS4myRgD2DJUAZXULM4Oqn2nz5GRk4ev56W/3ExKy2bHibMA3NChUUWUVyvc26s501ceZM+pVDYcTuLqVqHuLklEyiL8CrD6QVYKxO+BsE7urqjGCgwMBCAxMZHY2Fg3VyMi4srLy4smTZoU/KySi1NAJSIiIiJSEYJbGNuzx4rd3aK+L7/FnOHsqcOA3Xijyq/om8snzmQC0LSeT2VVKlIpGgf7EBbozenULH4/kULvyJBLPtfPBxKw2yEqPJCwII33cwr29eTu7k35ZGMMM9ccUUAlUlN4WKHZVXB4FcSsV0B1mQIDAwkMDCQ3N5f8/Hx3lyMiAhhrTmmsX/kpoBIRERERqQjBzYxtCQFVeLDxJnuec7xfvQi4YCRNXr6N06lZADSt51spZYpUpm4tgvlu52m2HjtzWQHVyn3xAFzfvmFFlVZrjL0mgnmbYlixL57DCWlENtB6dSI1QvPeRkAVuwV42N3V1ApWq1VvBouI1HAaeisiIiIiUhGCmxvblOPF7g4LMjqiPM7GGA/UiyhyzOnULPJtdjw9zDTw96qMKkUqVZemwQDsjk295HPk5dv45UACAAPaN6iIsmqVyAb+DGhnBHefbi7+542IVENNehjb2N/cW4eIiEg1ooBKRERERKQiBDk6qDLPQPa5IrsbO8aU+aQ73lAuJqCKdYz3Cw/2xmzWgt9S80Q1Nubt7zl16QHVrpOppGblEehtoWuzehVVWq0ysqfx8+bz306Qk2dzczUiUiZNuhnbpIPG7woiIiKigEpEREREpEJ4B4J3sHH/bNGuhnBHB1VwtmNB7+ICqrNaf0pqtqhwI6A6kphOWnbeJZ1j0+EkAK5qGYKHgtpiXd++IQ0DvEhKz2H53jh3lyMiZeFbH+q1NO7HbnVvLSIiItWEAioRERERkYriXIeqmDF/4Y4OqrB8x5vJxQRUJxwdVE2CFVBJzRTi70VYoPG9vu8Su6g2OgKq3pH1K6yu2sbiYWZ4j6YALPy1+HXvRKQaatLd2CqgEhERhwEDBhAVFcW7777r7lLcQgGViIiIiEhFCW5hbM8WfcM42NeKt9VME1OC44HmRY6JLQiofCutRJHK1vEyxvzl2+xsOWqMvurVMqRC66ptRvQwfoasPZhY0H0pItVcU61DJSIirlatWsWePXsYP368u0txC4u7CxARERERqTWCSu6gMplMRASaCUrPMB4IDC9yjPNN5iYa8Sc1WFTjQFbsi2d3bPkDqj0nUzmXnUeAl6VgPSspXvMQX3q1rM+mI8l8u+Mkj1zXyt0lFctut7P12BlW7I1ny9EzHE1KJyEtGxPg52UhIsSPqPBA+rYJpX+7BgR4W91dskjlKeig+g3sdjBpjKmIiNRtCqhERERERCpKQJixPXe62N1R/mmQDnkePli8ir75HpeaBUBjxzhAkZrI2UG1+1RKuZ/769FkAHpE1NP6U2UwrGtjNh1J5utqGFDl2+x8sfUEH/x8iEMJ6UX224FzWXnsjE1hZ2wKn245jo/Vg2FdGvNQv0haN/Sv+qJFKltYZzBbID3e+DBLMd3UIiIidYkCKhERERGRinKRgKqVl9FRkubZkOBiPjUdfy4bgAYBXpVTn0gVaBdmBFQH49Ow2eyYyxE0/X7iLABXNq9XGaXVOkM6hfPSkt3sPpnKwfi0ahPq7Dh+luc+/539cecA8PX0YFBUI65uHUr7sAAaBXpjAlIyczmcmM6Wo8ms2BvP4cR0Pt1ynP/77TjDuzfjuZvaEeKvn4dSi1h9oGEHOL0TTv2ugEpEROo8BVQiIiIiIhXFGVClxRW7u6nlLABnLKEEX7AvOy+flMxcABoGqINKaq5m9Xzw9DCTlWsj9mwmzeqXfU21nSeMrqsrmgZVVnm1Sj0/T65tE8qq/Ql8veMkz9zY1q312O12/vfzId5cdoB8m51gXyvj+7dmVK/m+HsVffuhYaA3bRoFMLhjGBOHdGDz0TN8+Mthlu+N49Mtx1m+N45/3t6JmzsXHYkqUmM16mwEVHG7oMMt7q5GRETErczuLkBEREREpNbwL72DqhHG+LJEc0iRfQmO7ilPDzOBPvocmdRcFg8zEaFGKHUoIa3Mz3N20wBc0TS4MkqrlW7r2gSAb3ecdGsdufk2nvv8d/7zw37ybXaGXhHO6gn9eahfZLHh1IVMJhNXtazPjPt78PmjfWgfFkBSeg5/mr+Vf3y9m9x8WxV8FSJVoFFHY3t6p3vrEBERqQYUUImIiIiIVJSARsY26yzkZhbZHWpPAuC0rej4soRC4/1MWjRdarhWDYxRcwfjyx5Q7Y41uqea1vOhvp9npdRVGw3s0BCrh4nDienlCgQrUl6+jScWbOP/fjuBh9nEP2/vxPRRVxLse2n/HXtE1GfJ49fwWH9jXa05649y/6xfSc3KrciyRdwjrJOxjdvl3jpERESqAQVUIiIiIiIVxTsYPBzrpRQz5i8oNwGA43nBRfZp/SmpTZxrIR1KSC/zc3Y4xvt1UfdUuQR4W+kdaXRlLt9T/HjRymS323n+y538sPs0nh5mPhrTnft6t7jsoN3L4sFzN7Xng9Hd8fP0YP2hJO6bsYmzGTkVVLmImzTqbGzPHIXsc24tRURExN0UUImIiIiIVBST6XwX1bmibxT75RgB1dGcouvrJCigklrE2UF1qBwdVLscHVSdtf5Uud0YZfzcWb636gOqD385XNA59c49V3J9+0YVev7BHcP47NE+1Pfz5PcTKYz8cKNCKqnZ/EIgwLGuWtwe99YiIiLiZgqoREREREQqkvNNp7Si61B5ZRoB1aGsAGw2u8s+ZwdVQwVUUgs4O6gOlmPk3L7TqQBEhQdWSk212cAORij0W8wZktOrLrzZcCiJ13/YB8A/hnVkcMewSrlOx8ZBLHq4Nw0CvNh3+hwPf/wbWbn5lXItkSrRyDnmzw3rUNnyIX4v7P0GtsyGXz+CbfPgwI+QGA35GqUpIiJVR6svi4iIiIhUJP8SOqjsdsyZiQDE2QJIycylXqF1dtRBJbVJZAM/AJLTczibkXPRtYiy8/I5mpQBQNtGAZVeX23TJNiHqPBA9pxKZdW+eO7q3rTSr5malcvTn27HZoc7uzXhvl7NK/V6bRsF8MmDVzH8fxv49Wgyz3y2nemjumE2a80+qYHCOsHBn+B0Fa1DZbfDoRWwfaERROWUMlrQ4g1Ne0KLq6HdEAjvYnSIi4iIVAJ1UImIiIiIVCS/BsY2I9H18Zx0THlZACTbA0lKz3bZnZimgEpqD19PS8H38vHkzIsefyQxnXybnQBvC40C9f/ApbihQ0MAVu2Pr5LrTf5uH6dTs4gI8eVft3e+7DWnyqJ9WCAfjOmOp4eZ73ae5n8/H6r0a4pUioIOqt2Vf62j6+CjATDvLtj1uRFOWf2gSXdoezN0GAatbzTWxrL6QV4WHF0DP78OH14H/+0CP70ESfr/TUREKp46qEREREREKpJviLHNSHJ93BFYZeNJBl4knMuhdcPzu884xnLVv0iniUhN0ayeDwnnsjmWnHHRdaX2nzY+zd+2UUCVBB21Ud82DXh75UE2HErCZrNXamfRxsNJLPz1GACv3XUFPp4elXatC13dKpRXb+/IX7/YyZvL9nNl82CubhVaZdcXqRAN2hnbxANGd1Nl/NzLz4Xl/4AN7wJ2I3y68j7oPBwaXwkexbwlaLcbNcWsh8OrIPonOBsD694ybpH9occ4aH8LmKvu/3sREam91EElIiIiIlKRSgqo0o0/nzMHAiZSMl3XiTmbaaz5EORrrewKRapE8/q+ABxLzrjosdFxxlpVbRv5V2pNtVnXZsH4enqQlJ7DvtOljO+6TDabnX8u3QPAqKua0zsypNKuVZIRPZszvHtTbHZ4cuG2gg5UkRojpDVggqyzkJ54saPLLzsNFo6CDdMBO1w5Gp7aAUP+A816Fh9OgRGUNWgHPcbCHz6GvxyE4XOhzSCj3sOr4bMxML0H/DYH8vT/noiIXB4FVCIiIiIiFcm3vrHNSHZ93NFBlW4JBuBshusi5GczjMCqnjqopJZwBlTHz1w8oDoQd76DSi6Np8VMr5bGz591ByvhDW+Hb34/ya7YVPy9LEwY1LbSrnMxr9zWifZhASSm5fDCV7uw2+1uq0Wk3Kw+EOxYty3xQMWeOy8bFo401riy+MAfPoHbpoN/g/Kfy9MPOt4O9/6fEXD1fQZ86kHyYfjmKXjrClj3NuSkV+zXICIidYYCKhERERGRilRSQOX4hHSmtR4AKZnnAyq73V4QWCmgktqiqTOgKkMH1cEEo4OqTUMFVJfjmtbGqLt1hyonoMrNt/HGj/sB+FP/VoT4u2+9MB9PD978QxcsZhPf7zrN0p2n3FaLyCUJdQS8ifsr7px2O3z1mLGGlGcA3P8NRA2rmHPXawE3vAR/3gWDJ0NgE0g7DT+9AP/tCps+UEeViIiUmwIqEREREZGKdJE1qHI8iwZUadl55NmMT/8Ha8Sf1BLNyxhQ5dvsnEjOBKBFiG+l11WbOQOqTYeTycmzVfj5v9lxkhNnMgn192LcNS0r/Pzl1bFxEI8NaA3AC1/tKljLT6RGKFiHKrrizrl5Buz6HMxWGPGJMc6vonn5Q5/H4MntcNu7ENwC0uPh++fg7W7G6L/83IudRUREBFBAVWb//e9/adGiBd7e3vTt25cdO3a4uyQRERERqY4KB1SFR045OqjyvI39hQMqZ/eUt9WMt1WLjkvt0MwRUJ04k0m+reTxa6dTs8jJt2H1MNE42KeqyquV2jUKoL6fJ5m5+ew6mVKh57bb7Xzw82EAxl4TgY9n9fhZ9fiA1rQPC+BMRi5v/lSBnSgilS20jbGtqBF/Cfth2STj/qBXodWAijlvSSyecOV98PgWuGUaBDSG1BPG6L/pPWDHIrDlV24NIiJS4ymgKoMFCxbw17/+lVdffZXffvuN1q1bM3jwYFJTU91dmoiIiIhUNz6OEX/52ZBbqHPE0VFl9ysaUJ1xrD8V7KPxflJ7hAV6Y/UwkWezczo1q8TjYpKMtUua1vPFw2yqqvJqJbPZRLfmRpfmb0fPVOi5Vx9IYH/cOfw8PbivV4sKPffl8LSY+cewjgAs2HSM3RUczIlUGueIv4QKCKjsdvhuAuRlQesboNejl3/OsrJ4Qo9x8OQ2Y/SfXwM4cxQWPwL/uxr2LAFbxXd0iohI7aCAqgymTZvGo48+ypgxY+jYsSMzZswgLy+PBQsWuLs0EREREaluPP3Aw7EuS+Exf477Jt+SO6g03k9qEw+ziYYB3gCcTik5oDqWZAS5zpGAcnl6RDgCqpiKDajmb4wBYORVzQmqZj+rekeGcMsV4djs8PLXe7DbS+7YE6k2Qh0j/lKOQc7F1+or1d6v4cgvYPGGoVPB5Iaw3+p9fvTfwBfBOxgS9sFnY+DD6+DAMtfOchERESoooHr99dcxmUyYTCY2btxYEacsl3nz5vHII4/Qo0cPvLy8MJlMzJkzp9TnbN68mSFDhhAcHIyfnx+9e/fms88+K3JcTk4O27Zt44Ybbih4zGKx0L9/fzZs2FDRX4qIiIiI1HQmU/HrUGWeBcDqX3IHVT1fdVBJ7RIWZARUcaV1UDnWqNL6UxWjewsjoNoSc6bCgprTKVms3BcPwD29mlfIOSvaxCEd8Laa+fVoMqv2x7u7HJGL8wsB7yDj/pmjl34eWz6seMW4f81TUM/NHY5e/nDts/DUDrjur+DpD6d/hwXDYdZgI0gTERFxuOyAateuXbz00kv4+flVRD2XZNKkSXz44YfExMQQHh5+0eNXrVrFNddcw9q1a/nDH/7Ao48+yunTpxkxYgRvvvmmy7GJiYnk5+fTqFEjl8cbNmzI6dOnK/TrEBEREZFawifY2DpCKQCyjfHQnv7Gm8fqoJK6ICxQHVRVrXOTIKweJhLTsjmenFkh5/xi6wlsdugZUY9WDfwr5JwVrXGwD/dfHQHAm8sOYCtl3TORaqNeS2N7OQHVniWQdNDoWLr6iYqoqmL4BMOAifDU73D1k2DxgeObYO6tMHcYHN/s7gpFRKQauKyAKjc3l/vvv5+uXbtyxx13lOu5n3zyCTExMSXuz8/PZ+rUqeTk5Fz0XDNmzODo0aMkJCTw6KOlz9nNy8vjoYcewmw288svv/Dhhx/y5ptvsmPHDtq2bcvEiRNLrUtERERE5KK8Ao1tdqE1S7OMdVF8AooGVKmO+0E+CqikdmkUWJYOKmMNqhYh7vvQY23ibfWgUxOjK+O3Y8mXfT673c5nW44DMKJn9eyecnq0Xyv8vSzsPpnKj7v1gVKpAepFGNtLDajsdlgz1bjf+0/gFVARVVUsvxAY9Co8tR16PgRmKxz5GWbeAAtGwKnf3V2hiIi40WUFVP/617/YvXs3s2bNwsPDo8zPO3HiBA899BD9+/cvNgyy2Wzcf//9PPvss3z44YcXPd8NN9xAixZla2FeuXIlhw4d4p577qFr164FjwcFBTFx4kRycnKYO3duweOhoaF4eHgQFxfncp74+HjCwsLKdE0RERERqWO8nQHVufOPZRlhlV+gMeIvNTO34BP+aTl5APh7WaquRpEqEBZkrMd2upSAKvaM0eXTrL5PldRUF3RvXnHrUO2MTSEmKQMfqwdDOlfv18D1/DwZ19foSJm2XF1UUgNcbkB1fBPE7QSrL1z1cEVVVTkCwmDoFHhyK1w5GkwecOAH+OBaWDASjm1yd4UiIuIGlxxQbd26lX/961+89NJLREVFleu5TZs2ZeHChZw4cYIBAwZw7Nixgn3OcGr+/PmMGTOGxx577FJLLNbq1asBGDRoUJF9gwcPBuDnn38ueMzT05Mrr7ySFStWFDyWl5fH6tWr6dOnT4XWJiIiIiK1hPMTzI5QCls+5BhhlV+QEVDZ7JDuCKbSsoytnwIqqWUaXWTEX1ZuPmccIy7DgxRQVZQrmgUDsCs2tfQDy+C7nUYn0vXtG+LrWf1/Rj3YtyUBXhYOxKWx+oDWopJq7nIDqq0fG9uOd4Jv/YqoqPIFN4fbpsP4X6HT3YAJDnwPswbBrJvhwI9GZ5iIiNQJlxRQZWdnM2bMGLp27cpzzz13SRe+4447WLhwIcePH6d///4cO3YMm83GAw88wLx587j33nuZPXs2ZvNlL5PlIjo6GoA2bdoU2RcWFoa/v3/BMU5PP/0077//PvPmzWPPnj08/PDDWCwW7rnnnhKv8+677xIVFUX//v0rtH4RERERqQEuHPFXaNSfl38wFrMJgPTsfMfWCKgCvKv/m78i5RF2kRF/zuDKx+pBoL7/K0ynxsbPoH2nU8nLt13yeex2O9/vOgXAzdW8e8opyMfKqF7GKMIPfzns5mpELuJyAqqsFNi92LjfbUxFVVR1QlvD3TPh8S1G/R6ecGw9LPgD/O8a2LEI8rLdXaWIiFSyS0p/XnzxRaKjo5k9e3a5Rvtd6O6772bevHkcO3aMAQMGMGrUKD755BNGjhzJ3LlzKzycAkhJMWb/BwUFFbs/MDCw4Bine+65h9dee42JEydy5ZVXsn//fn788UcCAwNLvM748ePZs2dPQceWiIiIiNQhzhF/zg4qx/pTWHwwWbzw9TR+hy7ooMrWiD+pncKCHB1UqVnYi/lEvHP0X3iQNyaTqUprq80iQvzw8/QgK9fG4cT0Sz7P3lPniEnKwMtiZkC7hhVYYeV64OoILGYTGw8n8/uJs+4uR6RkzoDqbAzYyhkm71sKuRkQ2haaXVXhpVWZ0NYw7B146ne4+knw9If43bD4EZjWEVb+E1Ji3V2liIhUknInQBs2bGDKlClMmjSJTp06XXYBI0aMYO7cuRw+fJjPPvuM22+/nXnz5l1W8FUZnnrqKY4dO0Z2djbr1q2jS5cu7i5JRERERKqrgg4qRzDlDKocwZUziHJ2TjkDKo34k9rGOeIvK9dGamZekf3ODirncVIxzGYTUY4uql2xKRc5umQr9hprMfdr26BG/XxqHOzDsC6NAfhozRE3VyNSiqBmxlpMeVmQFnfx4wvb+42x7Xgn1IaAPzAcBr0KT++C61+AgMaQngC/vAFvdYbP7oej6zT+T0SklilXQJWXl8f999/PFVdcwd/+9rcKKcBut7Ny5cqCP+/evZu4uHL+o1wOzs6pC7uknFJTU0vsrhIRERERKRNvx++TF3ZQOR73LQiojBF/BR1UGnEmtYy31aNgdGVCWtFRTadSzndQScXq2Nj4ebP75KWvQ7UmOhGA/u0aVEhNVWlc35YA/LDrFInFfO+JVAseFghuZtw/U44wNTsNDjrWSu9wa8XX5U4+9aDfBPjz7zB8LrS4Buz5sOcrmDMEpveEtW/BudPurlRERCpAuQKqtLQ0oqOj2b59O56enphMpoLb3LlzAejTpw8mk4mvvvrqouez2+08/PDDzJo1ixEjRjBv3jwOHz7MgAEDOHny5CV9QRfjXHvqwnWmAE6fPk1aWlqx61OJiIiIiJRZSWtQOQIqvws6qJxBlUb8SW0U4ucJQHJ6TpF9zrWpGimgqnAdL7OD6lxWLluPnQGgX5uaF1B1ahJEl2bB5Obb+eK3E+4uR6RkQY6AKqUc36cHf4L8bKjXEhp1rJy63M3DCh1vh7HfwaProPsDYPWFpGhY/hJMjYIFI4xOsryi/76IiEjNUK5XwF5eXjz44IPF7vvll1+Ijo5m2LBhNGjQgIiIiFLPZbfbeeSRR5gxYwZ/+MMfmD9/Ph4eHpjNZkaPHs2AAQNYvXo14eHh5Snxoq677jomT57MsmXLGDlypMu+H3/8seAYEREREZFLVtIaVI7gyu+CNajOZWkNKqm9Qvy9OJqUQVKxHVSZgDqoKkOnJkYgvudUKna7vdxrfG08nEyezU5EiC/N6vtWRomV7p6rmrHj+FkW/nqMh/tFap0zqZ4Cmxjb1HKss3RwubFtP7R2jPe7mLBOcOt/YdA/Yfdi2DYPjm+CAz8YN99Q6HgHdL4bml4FlbCmvYhIZRkwYABWq5Xx48czfvx4d5dT5cr1CtjHx4cZM2YUu++BBx4gOjqa559/nt69e5d6Hrvdzp/+9Cc++ugjl3AKYNSoUQAuIVVYWFh5yizVwIEDiYyMZMGCBTz55JN07doVMEb+/fvf/8bT05MxY8ZU2PVEREREpA4q6KA6Z2yzSuqgyndsFVBJ7VXf0UGVVEwH1elUI7TSGlQVL7KBHx5mE+ey8og/l13uv+M10QmAsf5UTXXLFY159du9HE3KYMPhJK5uFerukkSKCnIGVOWYJHTkF2Mb2b/Cy6nWvAKg2xjjlnAAts+HHQuN9bs2f2TcgppBpzuh090Q1rluBHgiUqOtWrWKpk2bursMt3HLK+CTJ0+yePFihg8fzvz587FYXMsYNWoUdrudMWPGsHz5cu67775SzzdjxgzWrl0LwM6dOwseW716NQB9+/blj3/8IwAWi4UZM2YwePBg+vXrx8iRIwkICOCLL74gJiaGKVOmXLT7S0RERESkVN4ljfi7oIMqO4+8fBuZuRrxJ7VXaSP+nF1Vof5eVVpTXeBl8aBFfV8OJ6YTHZdW7oBq0+FkgBod6vh5Wbita2PmbzrGZ5uP1+ivRWqxwMbGNqWMHVTJR+DsMTBboHmfyqurumvQFm58Ga5/AQ6vhl2fw95vIeU4rPuvcQttC53uMm6hWs5DRKQ6cssr4CZNmrBhwwaaN29eJJxyuueee+jRowdt27a96PnWrl1bsAaW07p161i3bl3Bn50BFRhtc2vXruWll17i008/JTc3l86dO/P6668zYsSIS/yqREREREQcPP2NbXaasc1Jc3m8oIMqJ4/0nPyCp/kpoJJaqH6pAZXxWKi/Z5XWVFe0buhvBFTx5+jbpuzhTEpmLgfijQ7QHhH1Kqu8KnFX96bM33SMZXviyMzJx8fxAQGRaiPQ8an5so74c3ZPNekBXv6VU1NN4mGBNjcYt1syIXoZ7PwcDvwIiQdg9WTj1qgTRN1urGulsEpEpNqosFfAc+bMYc6cOWU+PjIy8qLHlCWcupRrA1x11VV8//335XqOiIiIiEiZePoZ29wMsNshJ8P4s9VYx+X8iL880hzj/TwtZjwtWjNBap+SRvxl5uQXdA86j5GK1aaRP8v2xBEdn1au520/fha7HSJCfGt8d9uVzYJpVt+H48mZLN8bx61dGru7JBFXzg6qso74i3F8GLvltZVTT01m9YGo24xbVirsW2p0Vh1eDXG7jNuqf0LDjkZQFXW70YklIiJuo1fAIiIiIiIVzerjuGOHvCwjqALwdARUns4OqnwyHAGVnz7VL7VUiL+zgyrb5fEkx589Pcwab1lJ2jQMAOBgXPkCqt9izgDQrUXN7p4CMJlMDHOEUku2l2ONH5GqEuhYgyo9HvKySz8WIPY3Y9usV+XVVBt4B0LXUXDfFzAhGoZNh9Y3GKMR43fDqn/Buz3hvT6w+nVI2O/uikVE6iQFVCIiIiIiFc3RKQUY3VM56Y7Hjc4qP6/za1A5O0h8rAqopHYK8TM6cJzj/JycI/9C/D0xaRH7StG6oTH+K9oxrq+stjoCqu61IKACuK2rEQD8fCCelIxcN1cjcgHf+mBxrBF37lTpx2aehaSDxv3G3Sq1rFrFtz50G30+rLrtXWh9oyOs2gOr/w3vXqWwSkTEDRRQiYiIiIhUNLPH+TebctOLdFD5OjuosvPJyrUB4K0OKqmlShrx5wysNN6v8rRq4I/JBGcycklMK0NnBmCz2dl+/CwA3ZrXjoCqbaMA2ocFkJtv58c9p91djogrk+n8mL+Ui6xDdXKbsa0XAX4hlVpWreVbH668D+77HP5yEG57D9oMArPVNax6tzesfg3i97m7YhGRWk0BlYiIiIhIZXB2UeVkFFmDyttq/BqenXd+DR5viwIqqZ2cAdTZjBzsdnvB487ASgFV5fHx9KBxkDFyNCYpvUzPiUnOIC07D2+rmbaNAiqzvCp1U6cwAJbviXNzJSLFcI75u1gHlXO8X5PulVtPXeFTD668F+79P/hLtGtYlbAXVk+G93rBu71g1WSI3+vuikVEah0FVCIiIiIilcEZUOVmGF1UAJ7GiD9vxzi/7FwbWc6AyqpfzaV2CvSxApCbbyc7z1bwuHNNqlB/L7fUVVc0q28EVMeSM8p0/J6TqQC0CwvEw1x7Ri/eGNUIgDXRiQU/d0WqDf+GxjYtvvTjnB1UGu9X8S4Mq27/H7QZ7Air9sHPr8F7vRVWiYhUML0KFhERERGpDJ6FAqoLOqi8LMav4Vl5+YUCKnVQSe3k5+mBM+dIzTy//o9G/FWNFvWNYDwmqYwB1akUAKLCAyutJneICg+kcZA3mbn5rD+U6O5yRFz5NTC26RcJqBIc4+Yadazceuo6n3rQ9R649zNjDODt70Pbm4qGVdOvglX/hrg9UKhDWEREyk4BlYiIiIhIZXAZ8efsoHKO+CvaQeWjgEpqKZPJVNBFlVI4oNKIvyrRPMT4uXOsrAGVo4MqKrz2jPcD4/vwBkcX1U97LhICiFQ1Z0CVllDyMblZkHzYuN+wQ+XXJAafYOg6Cu751DWs8vCExP3w8+vwvz7GulUr/6WwSkSknBRQiYiIiIhUhoIRf+nnR/xZnSP+CndQ2RyPKaCS2ivQ2wioUrPOB1TOsCrY1+qWmuqK5vUdAVVZR/ydcgRUjWtXBxXADR2MgGr53jiX9dBE3M454q+0DqqkaLDbwDsI/BtVTV3i6sKw6o4PoO3NjrDqAPzynwvCqt0Kq0RELsLi7gJERERERGqlghF/medH/Hk6R/wZYVRW7vkRf15ag0pqsSBHB1VqZl7BY85xf87wSipHC0cHVUwZAqrk9BziUo21wdqF1b6AqldkfXysHiScy2Z/3Dna18KvUWoovzKsQRXvGO/XoAOYas/6cDWWdxB0GWncslJg/w+w5ys4uPx8WPXLfyCkDXS8HaJuN0Yz6r+diIgLvQoWEREREakMzg6q7HOQn+147IIOqlybOqikTgj0MT4bWbiDKjUrz7FPAVVlcq5BlXAum8yc/FKPPZSQBkCTYB/8vWrf51m9LB5c1bI+AGujtQ6VVCP+zjWoShnxl7DX2DZsX/n1SPl4B0GXETBqoaOz6kNoN8TorEqKhl/egPevgek9YMWrcHqXOqtERBwUUImIiIiIVAZP401h0gu9CXpBB1V2Xj6ZWoNK6oCCEX+F1qA6l+XsoKp9QUh1EuRrLfg7vtiYv8OOgCqygV+l1+UufVuHArDuoAIqqUacHVTpCSUHF4kHjG0DBVTVmktYdQju/AjaDQUPL0g6CGumXBBW7VRYJSJ1mgIqEREREZHKYPUxtgWfhjaBxRs43y1ldFDlOx7Tr+ZSe51fg6qYEX/qoKp0zR1j/k6cuUhAlWisl9eqgX+l1+Qu1zgCqk1HksnJs7m5GhEHP0cHVX4OZJ0t/pgzR41t/ciqqEgqgncgXPEHGLXA6Ky6cwa0v+WCsKovvNMdVrwCp35XWCUidY5eBYuIiIiIVAbniL+MxPN/dqw7UDiMcr5J721RB5XUXgUj/hzf7zabnXPZRlgVoA6qShcWaATmp1KySj3ucIIRULUMrb0dVO3DAqjv50lGTj47Tpx1dzkiBqs3eAUZ99OKGfNnt8OZGON+vYgqK0sqkHcgXDEcRs4vGlYlH4I1b8IH18I73WD5ywqrRKTO0CsBEREREZHK4AyoMs8aW4tXwS6vQmHUWWdApRF/Uoud76Ayvt/TcvIK3ndz7pPKEx5kdG+evmhAVftH/JnNJq5uFcK3v59ibXQiPSPqV9m107Lz+HbHSVbvT+BIYjrZefnU9/OkS7NgBncMo1fL+pgcH2SQOsi/AWSnQHo8NGjrui/zDGSnGveDm1d9bVKxnGHVFcONtUoP/Ai7F8PB5ZB8GNZONW71IyHqduh4O4RdUfBBJxGR2kQBlYiIiIhIZbB4Gtvsc44/nw+orB4mzCaw2eFsRg4A3p4KqKT2co7xS83Mc2yNoMrTYlY4WwXCHAFVaR1Uefm2gjWqImvxiD+AXpFGQPVbzJkquZ7NZufjDUd5a0U0ZzNyXfYdTcpg67GzzF53lA7hgbwwtANXO8YQSh3j18AY+5ZeTAfVmSPG1j/s/AhhqR28AqDz3cbNGVbt+Qqif3INq+q1NIKqqNshvIvCKhGpNRRQiYiIiIhUBg9HIOX8xHOhgMpkMuFt9SAjJ/98B5VF07el9vJ1BLAZOc6Aytiqe6pqFHRQpWaWeEzs2Uxy8+14WcyEB3pXVWlu0aNFPQC2HTtDXr4Ni0fl/fxNy87jsflb+eWAETq0DPXjziubcEWzYHw9PTh5NpN1BxP59vdT7D2Vyj0zNnFvr+a8cEuUwtu6xsfRzefsvC7Muf6UxvvVbi5hVRoc+OF8WHXmCKydZtzqtYSo26DjHQqrRKTGU0AlIiIiIlIZLI43eJ0dVB5eLrudAZXzjXpPBVRSi/l6Gi8903PygfOj/pxrU0nlKksH1fFkI7xqVt8Xs7l2v9nZtlEAAV4WzmXnsT/uHB0bB1XKdVIychkzaxM7TqTgbTUzcUgH7u3VAo8L/n5v69qE52/uwLTlB/h4QwzzNx1j18lU5o7tSbCvZ6XUJtWQjxGckllMZ58CqrrHy981rIr+EXZ/dT6sWveWcasXcX4MYHhXhVUiUuPoVbCIiIiISGVwjvjLSnX9s4PVw3gDwdlR4lmJn+AXcTdfL9cOqnNZ6qCqSo2DjJFgp1OysDsX/7pA7FljvF+T4No/PszDbKJr82CAShvzl5tv40/zf2PHiRTq+Vr57JE+jOkTUSSccqrn58krt3ViztieBPta2XH8LCM+2EhiWnal1CfVkE+wsS0uoDp73Nhq/am6ycsfOt0FIz6BvxyEu2cbHVQWHyO8XPcWfNgf3u4KP70IJ7dBCT/rRUSqG70KFhERERGpDM6OqXzHm4sW15FZVkcgleHoKLEqoJJazM/RQZWRbXy/n3N0UAV4q4OqKjg7qDJy8kl1hIMXij1rdFc1qVf7AyqA7o4xf5UVUE3+bh/rDyXh5+nBgod6c0XT4DI9r3+7hnz2SB8aBnixP+4cD87dQqbj3wmp5UrroDp32tgGNq66eqR68vKHTnfCHz6G5w4VE1b91wir/tvFCKtityqsEpFqTa+CRUREREQqg8V1pN+FI/4u7JiyeGgki9Re59egynfZ+miNnSrhbfWgnq/RrXa6hDF/sWeMEX91oYMKoEcLY72fLUcrPqDacCiJWeuOADBtRFc6hAeW6/ltGwWw8OHeBZ1Uf/50W4mdb1KLlNZBde6UsQ0Ir7JypAbw9HMNq4bPMcb9WX3hbIwRVn00wAirlr0Asb8prBKRakcBlYiIiIhIZbgwoLpgxN+FgZRG/Elt5gyo0h0j/pwdIc7HpfI1CjS6qE6nlhBQ1aERfwBdmwdjMkHs2UwSzlXcGL2s3Hz++sXvAIy6qjmDOoZd0nlaNfDnozE98PQw8+PuOGavO1phNUo1VdBBdbboPmcHVUCjKitHahhPP+h4B/xhrjEGcPhc17Bq/dvw0fXw3ytg2SSFVSJSbehVsIiIiIhIZfC4MKAqfsRfwZ8t+tVcai8/L8eIv5x87Hb7+Q4qT434qyqh/sbPpKQS1jSKPevooKojI/78vSy0DPUDYPfJlAo778cbjnIsOYOwQG8mDml/WefqGVGfvw/tAMDk7/eyK7bi6pRqyBlQZZ11fTw/D9LjjfvqoJKy8PSDjrc7wqpDRljV8Q5HWHUM1r/jGladUFglIu6jV8EiIiIiIpXhgo4pPFz/fGFAZTFrxJ/UXs5OqXybnew8Gxm5eS6PS+UL8Td+BiWl5RTZl2+zF4z+a1xHOqgAOjUOAmD3ydQKOV9KZi7vrjoEwDOD2hLgbb3sc47p04JBUY3Izbfzl89/JzffdtnnlGqqpDWo0hPAbgOTGfwaVH1dUrN5+hph1fA5Rlj1h4+LhlUzroe3roAf/w4ntiisEpEqpYBKRERERKQyXNAxVbSDynTBn/WrudRevoU6pTJy8snINjqo/BRQVZkQP6ODKjG9aAdVwrlscvPteJhNNArwKrK/turUxFgbqqI6k+auP0pKZi5tG/lzV7emFXJOk8nE5Ds7E+xrZe+pVGatPVIh55VqqKSAyrn+lH8jMOtnplwGT1+Iuu2CsOpOsPpByjHYMB1mDIS3OiusEpEqo1fBIiIiIiKV4YKOqQs7qi4MpDw14k9qMQ+zCW+r8T2enp2nEX9uEBpg/AxKPFe0gyrOsS5VA38vLHUoLHd2UO2qgBF/2Xn5fLwhBoDxA1rjUYFdsSH+Xky82Rj1N235AU46xjFKLeMMqHIzILfQWnEF609d2npmIsUqCKtmG2tW/eET6HSXI6w6XjSsOr5ZYZWIVIq685uniIiIiEhVslzQhXDBmlQa8Sd1jbOLKjM3n0yN+KtyoY4OqqQSOqgAGgbWne4pgI6OgOp4ciYpGbmXda6vt58kMS2b8CBvhnSu+HWChvdoSs+IemTl2pj204EKP79UA54Bxhg/cF2HKs0RUPkroJJK4ukLUcPg7lnw3CEYMa9oWDXzBpjWCX6YCMd/BZvGjYpIxVBAJSIiIiJSGS4MqCwXBlQa8Sd1izOMcu2gUkBVVUpbgyohzQioGvjXrYAqyNdKs/rGmlu7L7OLav6mYwCM6RNRKT/PTSYTzw8xuqi+2HqC/afPVfg1xM3MZvAONu4XHvOXnmRs/UKrvCSpg6w+0OHWC8Kqu8HTH1JPwMZ3YeaN8FYn+OF5hVUictn0KlhEREREpDJ4XCyg0og/qVv8HB1U6dn5BQGVOqiqTqgjfEpKK7mDKrSOBVQAUeHGOlT7LiPwOZyQxvbjZ/Ewm7i7e8WsPVWcbs3rcXOnMGx2eOPH/ZV2HXEjrwBjm512/rHMZGPrW7/q65G6rSCsmmmMARwxHzoPd4RVsbDxPdew6tgmhVUiUm56FSwiIiIiUhk04k/EhbcjjMrKzSdTAVWVc3ZQJablYL9gHRFnQNUgoO4FVG0aGoFAdHzaRY4s2VfbYgG4tk1opf8dThjcDrMJlu+NY++p1Eq9lrhBQUBV6L9thiOg8lFAJW5k9YEOt8BdM+Avh2DkAkdYFXA+rJo1CKZ1hO//Bsc2KqwSkTJRQCUiIiIiUhkuMuLPcuGIP3VQSS3n5fgez86zkZFjrEHlY7W4s6Q6JcSxBlVOvo3UrDyXfYlpdTigauQPwMH4S+ugstvtfLX9JAB3XNmkwuoqSasG/tzsWOPqg58PVfr1pIp5Gt+P5BQKTDMcI/7UQSXVhdUb2g91hFUHHWHVH4yw6txJ2PQ/mDXYEVb9VWGViJRKr4JFRERERCrDRUb8eV7QQWU161dzqd3OB1TnO6j8vNRBVVV8PD3wsRp/3ykZuS776nIHVeuGRiBwIC6tSGdZWeyPO8ex5Ay8LGZujGpU0eUV60/XtQLgm99PcTw5o0quKVWktBF/6qCS6qggrPrIEVYthCtGgFegI6x63zWsitmgsEpEXOhVsIiIiIhIZfCwuv75ImtQWT004k9qNy+LEY5k59nIyNWIP3cI8jF+LqVkXhBQ1eEOqlYN/DGZjL+TxLSccj9/xd54AK5pHYqvZ9V0BHZqEsS1bULJt9mZs/5olVxTqoiXo4Mqu1BHn3PEn29I1dcjUh5Wb2g/BO780AirRi0qGlbNvgmmRcF3z0HMeoVVIqKASkRERESkUphMYPE+/2ePkkf8mUzgoTWopJbztjo6qHLzyXB0UPlU0Rv6YigxoHJ2UPnXvYDK2+pB8/q+AERfwpi/FXvjABjYoWGF1nUx465pCcDnv50gyxH4Si1QMOKv0Peis4NKI/6kJrF4QbubLwirRjrCqlPw6wcw+2aY2gG++4vCKpE6TAGViIiIiEhlKRxKlTLiz2o2YzIpoJLazdlBlZlrIyfP5nhML0mrUnEBVVahwLC+v6db6nK3Ng2d61ClXeRIV0lp2Ww7fhaA69tXbUDVr20DmgT7kJKZy9LfT1XptaUSeQUaW+eIv/w8yEox7mvEn9RUBWHVB46w6lPoMgq8giDtNPz6oWtYdXQd2BS8i9QVejUgIiIiIlJZCo/583B947fwiD+N95O6wMvRQZWWfT4cUUBVtQKLCaic9z3MJgK86mZHW+uGxro/0XHlC6g2HE7Cbof2YQGEB/lURmkl8jCbuKdXcwDmb4qp0mtLJbpwxF/W2fP7fOpVeTkiFc7iBe1ugjveh79Ewz2fFQ2r5gwxwqqlE+DoWoVVIrVc3fztU0RERESkKlhK7qAqPOLPqjfppQ5whlGpmXkFj3nqe79KFddBdTYjt2BfXe3kjAz1A+BoUnq5nvfrEWP0Wu9I96wNNLxHU6b9dICtx86y73Qq7cMC3VKHVKCCEX+OsDQjydh6BYGH3sKTWsbiBW0HG7e8bDi8GnZ/BfuWQlocbP7IuPk3gvZDocMwiOhbdJ1XEanR9K+biIiIiEhlKdw1dcGLaUuhNacsZr1JL7Wfc8Rfatb5cKTwqEupfMUHVDkABPvU3Tf8mocYa1AdS84o1/OcAVWvlu4ZvdYwwJvr2zdk2Z44lmw/SfubFFDVeF5GN1/BiL+C8X5B7qlHpKq4hFU5Rli15yvY960RVm2ZZdx86kG7IUZY1WpAkQ+AiUjNo1cDIiIiIiKVpfCL5gtG/HkUCqU8NeJP6gBnB9W5LKODytOitdeqWrEBleN+kG/dDahaOAKq2DOZ5OXbyvScsxk57DttjGHr6aaACuC2rk0A+Hr7SWw2u9vqkApSEFClOraOUX9eCh+lDrF4QttBcPt7MOEg3PsFdBsDviGQeQa2z4eFI+A/reDzcUbXVXb5RrSKSPWhDioRERERkcri0kF1YUBV6L4CKqkDnGtQnXN0UHmpe6rKBfkYbwGkFl6DyjHiry53UDUK8MbTYiYnz8bJs1kFHVWl2Xz0DACtGvgR6u++T/AP7NAQfy8LsWcz+e3YGXpGuC8skwpw4Yi/goAqwD31iLibxRPa3GDchk6DYxtg7zfG7dxJ2PWFcbN4Q+sboMOt0PYm8Al2d+UiUkZ6RSAiIiIiUllcOqhc3/w1mzTiT+oW54i/wh1UUrWcXVKuHVSOEX++nsU+py4wm000q+cDQExy2dah2n7cCKjcHQh5Wz0Y3DEMgK+2xbq1FqkAF474U0Alcp6HBVpeC0P+A0/vhgeXw9VPQr0IyMsyxgEufgTeaAWf3Am/zYH0RHdXLVJrffnll9x4443Ur18fk8nE0aNHL+k8ekUgIiIiIlJZPEoe8Vd4DSpNOZO6wDniz9m946WAqsoVvwZVrsu+uqpFiB9Q9nWodsYaI9g6NXH/2kC3dW0MwA+7TpOvMX81m6fxfUiOIyh1dlI5O6tExGA2Q7OeMOhVeHI7PLoW+j0HDdqDLQ8OrYBvnoIpbWD2UNj0AaQoxBepSOnp6fTr149XXnnlss6jEX8iIiIiIpXFUnjEn+sIKI9CAZWHEiqpA7yt6qByt1LXoKrjAVXz+sZYv2NJFw+o7HY7u2NTgOoRUPVpFUKAt4Wk9By2Hz9D9xYa81djWY1OPvIyja06qEQuzmSCsM7G7fq/Q8IB2Pu1MQbw1HaIWWvcvn8OmvSAqGHGKMD6ke6uXKRGGz16NAC7du26rPPoFYGIiIiISGVxWYPK9c1fj0Jj/QqHVSK1lbNj6ly2Aip38fcyfg6lO/4bQKE1qHzrdkDVwrHuVEwZAqrTqVkkpefgYTbRPsz9wYHVw8yAdg0BWLYnzs3VyGVxBlS5CqhELlmDttBvAjzyMzz1Owz+NzTrDZggdgv89CK8fSX8ry+sfh3i94Jd3adSM8ybN49HHnmEHj164OXlhclkYs6cOaU+Z/PmzQwZMoTg4GD8/Pzo3bs3n332WdUUXAbqoBIRERERqSwuAZXriD+PQu/Nm9VBJXXAhYGUc00qqTq+nsbfeVrhgCpTARVAs3pGQBV7NvOix+5yjPdr09C/oDPQ3W6MasTXO06yfE8cz9/cwd3lyKWyODuossBmU0AlcrnqtYA+443budPGOlV7voajayFup3Fb/W8IaWN0VUUNg/Cumr8t1dakSZOIiYkhNDSU8PBwYmJiSj1+1apVDB48GG9vb0aOHElAQABffPEFI0aM4Pjx4zz77LNVVHnJ9JE1EREREZHKUrhr6oKAqnAopQ4qqQs8PVxffqqDqur5exmfUc3Os5GXbwPOB1SB3nU7oAoP9gbgVEpZAqrqM97P6bp2DbB6mDiUkM7hhDR3lyOXytlBBUZIpYBKpOIEhEHPP8L9X8NfDsJt70Lbm4zf0ZOiYe1U+LA/vHUF/DARYjYYQbFINTJjxgyOHj1KQkICjz76aKnH5uXl8dBDD2E2m/nll1/48MMPefPNN9mxYwdt27Zl4sSJRQKuv/3tb5hMplJvFU0dVCIiIiIilabQL/AXjPizeJzfp3xK6oLC3/NQNLCSyufrdb7bJz0nnyAfc8G4Pz+vuv32QHiQEQwkpuWQnZdfaoffwXgjAGrXqPqEBoHeVnq1DGHtwURW7osnsoG/u0uSS3FhQJXjCBsVUIlULN/6cOV9xi0rFaKXGetWRf8EKcdg47vGzb8RtL/F6K6K6Fvk93mRqnbDDTeU+diVK1dy6NAhxo4dS9euXQseDwoKYuLEiTzwwAPMnTuXF198sWDfs88+ywMPPFCBFV9c3f4NVERERESkMhX+hFkpHVRmJVRSB1jMF4z4syqgqmpeFg+sHiZy8+1k5OQR5GMlPccIqPzreEBVz9eKl8VMdp6NuJRsmjvWpCrOIUeHUuuG1SsE6tc2lLUHE1l/KIk/Xhvp7nLkUpg9wGwFWy7kZpzvoPKsXt9rIrWKdyB0vtu45WTAoZVGWLX/B0iLgy0zjZtPPWg3BDoMg8j+YPV2d+VSS5w7d47U1NSCP3t5eeHl5XXZ5129ejUAgwYNKrJv8ODBAPz8888ujzdo0IAGDRpc9rXLQ68IREREREQqTckBVeGxfh6acy91gFUdVNWCs1PK2TmVnp3v8nhdZTKZCA+6+Ji/fJudw4npALSqZl1K17QOBWDj4SRy8zWWqsayOsLR3CzIdrxhqQ4qkarh6QsdboE7PzTGAN77BXQbA74hkHkGts+HhSPgjdbw+TjY/RVka6yqXJ6oqCiCgoIKbpMnT66Q80ZHRwPQpk2bIvvCwsLw9/cvOOZSJCcns337dvbv3w/Anj172L59O8nJyeU6T93+DVREREREpDIVDp4u6B6xmNVBJXWLRWtQVQt+nhbOZuSSlp2P3W4v6KDy8yp5pF1dER7kw9GkDE6lZJV4TOyZTHLybHhazDSp51Pice7QISyQ+n6eJKfnsP34WXpG1Hd3SXIprN6QnWJ0UOVkGI95+rm3JpG6yOIJbW4wbkOnwbENsPcb43buJOz6wrhZvKH1DcYYwLY3gU+wuyuXGmbPnj00adKk4M8V0T0FkJJirJkZFFT8mpmBgYEFx1yKr7/+mrFjxxb8eejQoQDMnj27XGMCFVCJiIiIiLhB4RF/6qCSusByQRBb2ho/UnmcQVRGdh4ZOfnY7cbjdX3EH1Cog6rkgMo53i8y1M+lE7Y6MJtNXN0qhG9/P8Xa6EQFVDWVcx2qvCzjBsYb4CLiPh4WaHmtcbvpNYj9zRgDuPdrOHMU9n1r3MwWaHkdRA0z1q7yC3V35VIDBAQEEBgY6O4yyu2BBx6okPWq9JE1ERERERE38HDpoHJjISJVxHpBB9WFI/+kavh6GkFUWnZewZg/kwl8rAoMw4MvPuLPGVBVt/F+Tn0dY/7WHkx0cyVyySyOgCo3A3Id34vW6tWtJ1Knmc3QrCcMehWe3A6ProV+z0GD9mDLg0Mr4JunYEobyw4S2gABAABJREFUmD0UNn0AKbHurlrqIGfnVEldUqmpqSV2V1UlfURKRERERKTSlPwGvEtApQ4qqQMsFwRSF/5ZqoazUyojJ5/0HMf6U54WTPo5RFiQEQKcPFtaB5Wx/lRkg+o5cu3qVkZA9fuJs2Tl5uOt4LHmcYZRuYU6qBRQiVRPJhOEdTZu1/8dEg44Oqu+gVPbIWatcfv+OWjSw+is6nAr1I90d+VSBzjXnoqOjqZ79+4u+06fPk1aWhpXXXWVO0pzoc9qioiIiIhUlvycEncVDqiq25gokcpgLbIOm16OuoOvpxFYFO6g0vpThkYBxpoPCWnZJR5z4oyxJlDz+r5VUlN5NavvQ8MAL3Lz7ew4ftbd5cilsBbTQWVRQCVSIzRoC/0mwCM/w1O/w+B/Q7PegAlit8BPL8LbV8L/+sLq1yF+LwWzdkUq2HXXXQfAsmXLiuz78ccfXY5xJ70iEBERERGpLHklv8npoTWopI65sGNKwax7nO+gyiOtIKDScBWAUEdAlXiu5J/dsWeMwKBpveoZUJlMJrq3qAfAb8fOuLkauSTOgCorBXC8cW3VGlQiNU69FtBnPDz4Izy7D4a+aaxPZfKAuJ2w+t/wXm+Y3hOWvwwntymskgo1cOBAIiMjWbBgAdu3by94PCUlhX//+994enoyZswY9xXooN9CRUREREQqS78JcHA5dCv6i3/hN+c1WkvqgiIj/hRQuYWvl7ODKr+gg8pfARUAoX6OgCotG7vdXuRns81m58RZZ0BVfTtaureox/e7TvPbUQVUNZLFEUZlJhd6rPp+v4lIGQSEQc8/GreMZNj/nTEG8NBKSIqGtVONW1BzYwRgh1uhWS8tVCtFzJgxg7Vr1wKwc+fOgsdWr14NQN++ffnjH/8IgMViYcaMGQwePJh+/foxcuRIAgIC+OKLL4iJiWHKlClERES448twod9CRUREREQqS5Nu8JeD4OlfZJfriL+qLErEPYqM+NMaVG7hbTECquzc/PMdVJ56awAgNMATgOw8G+k5+UWCu8T0bHLybJhNEBZUfTtaCndQFRe0STXn7KDKdASMJg/wsLqvHhGpWL714cr7jFtWKkQvM9ativ4JUo7BxneNm38jaD/UCKsirtXPgVpswIABWK1Wxo8fz/jx40s9du3atcydO9flsXXr1rFu3bqCPzsDKue5165dy0svvcSnn35Kbm4unTt35vXXX2fEiBEV+4VcIv0WKiIiIiJSmbwDi33YrDWopI4pOuJPyaw7eFuNgCorN5+MnHxAa1A5+Xpa8PX0ICMnn8Rz2UUCqhOO8X7hQT5Yq/EnCzo2DsLLYuZsRi6HEtJp3bDohySkGvMwOvnISjW2Vh9QyChSO3kHQue7jVtOhtFRtfdr2P8DpMXBllnGzTsI2t4MHW6BVgPBs3qOmZVLs2rVKpo2bVqmY+fMmcOcOXPKdf6rrrqK77///hIqqxoKqERERERE3KDweDOz3niSOuDCN/Q14s89vK3Gf4esXBuZjoDKRx1UBUL8PclIziQpPZuIUD+Xfc6Aqklw9R635mkx06VpML8eTWbbsTMKqGoaZ5dEVoqxtVTfbj0RqUCevkYA1eEWyMuBI784wqrvID0Bfl9k3Cw+0Hqg0VnVdjD41HN35SKXRb+FioiIiIi4gQl1UEndcmEgpe979yjooMrLJyvPCKi8LNW3G6iqhfp7cTw5k4RzOUX2xZ6p/utPOXVuGsSvR5PZfTKV4e4uRsrH4uygcgRU1ur//SYiFcziCW1uMG62aXB8E+z9FvZ9A2ePwb5vjZvZYoz/63CrMQ4wIMzdlYuUmwIqERERERE3KNw0pQ4qqQsuDKTUQeUeXoVG/GXn2oDzXVUCIX5GOJCYll1k3+kUI6CqzutPOXVqYoyX3X0yxc2VSLmpg0pECjN7QIurjdvgf8Hp342wau83kLAXDq8ybkufhaY9jbCqwy1QP9LdlYuUiQIqERERERE3KBxKKaCSusBkMmH1MJGbbwfUQeUu3o5uqew8W0EHlbdFa1A5NQjwBCAprWgHVYIjtGoY4FWlNV2KTo2DANh9MpV8m13/v9UkHsb3INmF1qASEQHjE27hXYzb9X+HpENGULX3G4jdAid+NW4/vQCNOkF7x8jARp20lp1UWwqoRERERETcoPBrRA81L0gdYfUwk5tvhCLqoHIP72I6qLzUQVWgtA6q+FRHQBVY/TtaIhv44201k5GTz5HEdK1DVZN4aMSfiJRRSCvo+2fjlnoS9i01wqqjayFul3H7+TWoF+EIq4YZXVZm/bsv1YcCKhERERERNyjcNaVPtktdUTiU8lAy6xbnAyob2eqgKiLE39FBlV40oHJ2UDWoAR1UHmYTUeGBbD12lt0nUxRQ1SQa8ScilyKwMVz1kHHLSIYDPxijAA+tgDNHYcN04+bfCNoNMUYBRlxrrHcl4kYKqERERERE3MCsNaikDrIWCqXUQeUezvWmsnLzySpYg0oBlVOwrxEOpGTmujxut9sLOqga+Ff/gAqgU5Mgth47y67YFG7r2qRCz51vs7Ns92mW7Ykj9mwmgd4WekbU545uTWgYoEDlslgc31/5jjGT6qASkfLyrQ9d7zFuOelwcLnRWXXgR0iLg99mGzevIGh3k9Fd1XogePq5u3KpgxRQiYiIiIi4gUlrUEkdZPE4/72ugMo9vBzdUtl5NrJyjQ4qjfg7L9jH+CT5hQFVek4+mY6/r5rQQQXQsXEgAHtPnavQ8x5JTOeJhVvZFZvq8vjyvfG8tTyap25ow8PXRmLW/+OXxsOz9D+LiJSHpx9E3Wbc8nLg6C9GWLXvO0iPh98/NW4WHyOkan+LEVr51HN35XXGgAEDsFqtjB8/nvHjx7u7nCqngEpERERExA1c16DSm3hSN1gKrXlQOKySqlO4gyo7z9FBpRF/BQJ9iu+gSjhndE/5eXrg51Uz3kpp0ygAgIPxaRV2zl2xKdw3cxNnM3IJ8LZwT6/mdGocRMK5bJZsj2XHiRRe+34fW2PO8PaoK/+fvf+Oj+Ow74T/z9Rt6ABBgATBXkSxqFBUL4xly5adnFssuZxb7EQxn5zPJU8cJRc/d77Y8V3sPPnd6ZzIOls+W3LsPI7tyLZKbFHV6hIp9k4CLABRF8D2Kb8/ZmZ3drEAFsDuzpbPW6957e7M7OwsCJGL+eD7/bI6byGmBVSKN+dBRLVHVoF1t1vLO78J9L8MHPmFFViNn7XuH/kFIMrAqpusNoAb3wk0dXt95jVtz5496Onp8fo0PFMdn6qIiIiIiGqMyAoqqkPZwSyrdryQmUGls4Iqj2YnoIpmB1SXJuIAqqd6CkB67tTARBzhWCr93hbqYjiGTz74CsajKWxf0YJv//ur0dmUaef3iRtX4Z9e6ceXf34QTxwaxO6HXse3PnI1VJnfX/PCCioiKgdRAlZeby1v+6/AwH47rPoFcOkgcOopa/nlF4Cea6ywatO7gPa1Xp851Rh+SiAiIiIi8oCYdaHeu/MgKid3tSBb/HnDqZaKp1wt/lhBlebMoJqIa9ANM71+aMqqoKqm+UpNfgVddoC02Coq0zTxp//8Ji5NJrBxaSO+/wc7s8IpwGpd+8GdvXjwk9fAJ4v4zZFL+OqvDi/qdetSbsUUK6iIqNQEAejeBuy6F/jMb4E/eR1463+xgikAOPcK8G9/BfyPq4D/dQOw56vAxTcB05z9uEQF4I/CREREREQeEOCqoOKFeqoTkquEiq0tvZFu8ae5WvyxgirNXWU0Gc9UUY1GkgCA1lB1hQXrl1pVVCcuLW4O1f/32jk8d2IYPlnEtz5yFZr8M38dbljbgf/5oasAAA/+9gwe2XdhUa9dd+ScKj2xur7niKgGtK8Fbvws8KlfA58/ArzzG8Ca26zWf5cOAk9/HfjHm4G/3w48/hfA2RcAQ/f6rKlK8VMoEREREZEH3K3O2OKP6oXICirPOe3WTBOIJq2LSZwTlKFIIoKq9fVwz6Eat1v+tQarq93a+k5rDtWxwYVXUMVTOr7xxDEAwOfeugFrljTM+Zy3bl6Kz9xmtYH6Tz8/gBG7Ao0KwBZ/RFRJmrqBaz4FfPTnwBePA+/+B6vVnxyw5la98D+B774d+MYm4JHPAid+DWhJr8+aqggDKiIiIiIiD7gv1EsMqKhOZLe25Pe9FxRXT1GnQogBVbYWu4pqPDo9oGqptoDKrqA6vogWfz98uQ8DE3F0N/vx8RtWFfy8z711Ay7rbsJ4NIWv/OLQgl+/7kwLqFhBRUQVItgGXPFB4O6HgP/7JPCB7wPb7gJ8zUDkEvDag8AP3gf893XATz4FHPo5kIx4fdZU4RhQERERERF5wH1pnvkU1Qt3taAs8sdRL2QHVBoAwCfzz8KtyQ6osiqoYtZvgzszqqrFBiegGlxYiz/dMPHAs6cBALt3rZtXmKlIIv7mvVshCsDP9l7Aa2fHFnQOdYcBFRFVAzUEbP494L33A396AvjIvwA7Pgk0LAUSYWD/PwM//ijw39YAP/wQMLDf6zOmCsVPoUREREREHnBfqGc+RfXCXTXFCipvKFLm656ZQcUKKrfmPAFV2KmgClRXWLCmwwqoLobjiKfmPx/kmWNDOD8eQ3NAwfuv7pn387evaMHvX70CAPD1x47ANM15H6PusMUfEVUbWQXWvQV4199ZM6s++QRw/f8FtK4CtDhw9JeAwM8alB8DKiIiIiIiD4gsoaI65A6lZInf914QBGHa/C9WUGVzAqrxrAoqp8VfdQVULUEFjX4ZANA3Gp338x966SwA4P1X9yw4yPzs7euhyiJePj2KZ44PL+gYdUXM+TqLsjfnQUS0EKII9F4L3PHXwH/YC9zzHPC2vwY6L/P6zKhC8VMoEREREZEXhLx3iWpados/fud7xd3mL9/jetfot0KoKbsFIgCMRa0Wf82B6qpmEQQBK9uDAICzI/MLqMLRFJ46OgQA+ODOFQs+h2UtAXzk2pUAgH946uSCj1M3cgMpVlARUbUSBKBrK3DD/8VfyKMZ8VMoEREREZEHslr88ec1qhPuTIozqLyj5FSv5T6udw0+q4IlksgEVOkWf1VWQQUAK9tCAICzI/MbVP9vhwehGSY2Lm3Eus7GRZ3Dp25eDVkU8MKpEew/F17UsWoeAyoiorqya9cubN68Gffdd5/Xp+IJ/kRAREREROQBd0DFkRxUL9wt/phPeUfNaekns4IqS8hnBQSRpBVQmaaZbvHXGqy+sKB3gRVUjx24CAB4+5auRZ/DspYA3rWtGwBw/7OnFn28mjYtoGKLPyKiWrZnzx4cOnQIu3fv9vpUPMFPoUREREREHnBXknBoPNULdzArsnTQM9Nb/PHPwi0dUNkVVFMJDbph/T1dnRVUdkA1jxlUsaSenhf1jq2LD6gA4FM3rwFgBV8jU4miHLMm5c6gYgUVERHVMAZUREREREQeEDh5iupQVgUVAyrPTAuoWM6WpSEdUOkAgHG7vZ8qi/Ar0ozPq1ROBVXfPFr8vXp2FEnNQHezHxuXLq69n2PL8mZs62lGSjfx0zfOF+WYNSk3oBKrLxQlIiIqFD+FEhERERF5QHB9Emf9FNWL7AoqD0+kzrkrpkQBEPmHkcWpoJqyK6gm49Ztk786g4JV7dYMqnNjMWi6UdBznj8xAgC4YW0HhCKGyXddswIA8E+v9LN6eCbTWvxV5/cdERFRIRhQERERERF5gDOoqB65gxAWUHnHXUHF+VPTNfisChanxZ8zi8pZX226mvxQJAGaYWJwsrDWer89abX3u3Fde1HP5fe2L0NAkXDi0hTe6B8v6rFrxrSAii3+iIiodvGTKBERERGRB3htnuqRe9RRMasyaH7cAZXKgGqaoJpdQeXcOpVV1UYUBSxt8gMALo7H5tx/Mp7C/vNhAFYFVTE1+hW87fKlAIBfvnmxqMeuGaygIiKiOsJPokREREREHsiqoGKTP6oTnEFVGdwt/mSJfw65nCDKqZyKVHlABQDdzVZAdSEcn3Pf/efCME2gpzWALvt5xfSubcsAWAGVYfDfv2k4g4qIiOoIAyoiIiIiIg+4r82zxR/VC4EzqCpCVos/kZcFcjU4AVVCt2+1rPXVqLs5AAAYCM9dQbX33DgAYPuKlpKcy83rO9DokzEwEcfrfWMleY2qNq2Cqnq/74iIiObCT6JERERERB5g8QjVI8n1jS+w0aVnVDlzKUBhBdU0IXvWVKbFn26vr96goLvFrqAan7uC6s1+q73f9p7mkpyLX5Hw1s1Wm79fsM3fdLkBlVCds8+IiIgKwYCKiIiIiMgD2S3+iOqDu8UfQ1rvuCuoFM6gmsaplEpqBlK64aqgqt6gYJldQXWxgAqqfU4FVU9Lyc7nHVu7AQC/PjwIk2XE2aYFVPx/lIiIahf/lSMiIiIi8kBWQMVrc1QnRPcMKvb48wxnUM3OXSkVSWjpgCqoVnEFlT1L6uIcM6iGJhO4GI5DFIAty0tTQQUAN65rhyqJODcWw8mhqZK9TlXKrZjKnUlFREQ1ZdeuXdi8eTPuu+8+r0/FE9X76YqIiIiIqIqJWTOomFBRfXB/3zMW8Y7srqDiDKppFEmEJArQDRPxlJFu9VfVLf7SFVSzB1THBicBACvbQyV9v0FVxrVr2vDs8WHsOTKEdZ2NJXutqiOKVtWUaViP2eKPiKim7dmzBz09PV6fhmf4SZSIiIiIyAMC+5tRHXLPoBL5/4Bn3H8OrKDKz2/P6Yqn9Jpo8efMoBqeSiCpGTPud9wOqNZ1NpT8nHZt7AQA7Dl6qeSvVXXcbf4YIhMRUQ3jv3JERERERB5j/RTVi6wWf8xFPCO7vvicQZWfX7HCqLimYyqhA6juCqr2kApVEmGawODEzFVUxy5Z7fY2LC1DQLXJCqhePj2arlIjmzugYgUVERHVMH4SJSIiIiLyGK/TU71wV+6witA7YlZAxT+HfNIBVcpwVVBVb0AlCAI6GlQAVhXVTJwKqvVlaLm3uiOEntYANMPEq2dGS/56VcUdSnEGFRER1TAGVEREREREXuP1YaoT7k5VzKe8466gktk+LC+fkmnxF01ZFVRBtXoDKgDoaPQBAIanknm3m6aJY4NWBdX6MlRQAcB1a9oBAC+eYkCVxR1KsYKKiIhqGD+JEhERERERUZlwBlUlcFdQcQZVfn7ZqaDSkbADqoBS3UFBR4MTUOWvoBqJJBGOpSAIwNol5Q2oXjo9UpbXqxpZM6iq+/uOiIhoNgyoiIiIiIiIqEwyE9c4g8o7nEE1N3+6gspA3A6onKqqapVu8TeZP6A6OxIFAHQ3+dMtDkvt2tVtAIA3z4XTrRQJOTOoqvv7joiIaDb8V46IiIiIyGMCe/xRHeL3vXekrBZ//HPIxwloEpqOeMqw1snVXckyVwXVuTEroOppC5btnFa0BbG8JQDdMPHa2bGyvW7FY0BFRER1gv/KERERERERUdnxmqt3JFd7RYkBVV5OQBVP6Uhour2uur9pMwFV/hlU/aNWQLWitXwBFeCeQ8U2f2nu2XBs8UdERDWsuj9dERERERHVAI7ioXphZjr8cQaVhySJs8Dmkt3iz66gqvYZVI1WQDU0QwVV/2gMALCiLVC2cwKAa1a1AgD29o+X9XUrmiDlv09ERFRjGFARERERERFR2bFwxzvuCiqRfxB5Oe384ikdcbuCyidX9yWU9AyqmQKqMW8qqLavaAFgzaEyDHP2neuFwAoqIiKqD9X96YqIiIiIiIiqEmdQecc9d4r5VH5+1QoFJuNauvLPV+UVVEucFn+TswdUve3lDajWdzYgoEiYSmg4NTxV1teuWO6AihVURERUwxhQERERERERUVm4W/yxs5x33FVTEv8g8nIqqMKxVGZdjcygmohrSGpG1jbdMHFhPA4A6Gktb4s/WRKxdXkzAGBvf7isr12x3P9fsoKKiIhqWHV/uiIiIiIiqgG8PEz1iLOPvOOuoBL455CXE0aN2wGVIACqVN2XUJoDSjr3cAdvADAylYBumBAFoLPRX/Zz277CCqj2cQ7VdEJ1f98REdHsdu3ahc2bN+O+++7z+lQ8IXt9AkRERERERFR/mIt4RxJF130PT6SC+e12fuPRJABr/lS1h3miKKDJryAcSyEcS2JJoy+97ZLd9q+jwQfJg76PzhyqfefGy/7aFY8BFRFRTduzZw96enq8Pg3P8F85IiIiIiKPVfk1T6KCmcj0+GMFlXfcoRT/HPLzydYXacKuNPJX+fwpR0tQAQCMR7MrqAYnrPZ+nU2+ac8ph23LWwAARy5OIqUbs+9cD9z9UNnij4iIahgDKiIiIiIiIio7D4o0yOauoBL5B5GXYqd4k3ENQGYmVbVrCeQPqJwKqqUetPcDrLlXDT4ZSd3A6eGIJ+dQWdwD+2rje4+IiCgfBlRERERERB4TOIWK6oS7KKDa26VVM8n1pWc+lZ/iVFDZAZVPqY3LJ81BFUBmtpbD6woqURSwsasRAHBkYNKTc6gorKAiIqI6URufsIiIiIiIiIioIJKrx5/EoDAv1U7xphJ2i78aqaBqTldQJbPWOxVUnR5VUAHIBFQXJzw7h8rBCioiIqoPDKiIiIiIiDzG68NEVE7uUIqVbPk5Lf7iKWsekr9GKqicFn/hnAqqSxN2QOVRBRUAXGYHVEdZQcUKKiIiqhu18QmLiIiIiIiIKp459y5UBrKrr5/EHn95KVL25RJfjVRQtQRnCKgmrRZ/Xs2gAoCNXU0A2OJvGobIRERUwxhQEREREREREdUR9/VuBlT55QZUqlwbl08yLf6yA6qRKavlX0ejdxVUTou/8+OxaQEaERER1aba+IRFRERERERERAURs1r8eXgiFUyVs78wilQbX6iWoAoAGM8JgMbsmVStdoWVF5oDCpY1WxVcxwbrvYqK9aZERFQfGFARERERERFRWZi85loRRNeVAIkJVV65FVSyVBuXTzIVVMn0uoSmI5rUAWQCLK+sW2pVUZ0amvL0PDzHvyyJiKhO1MYnLCIiIiKiKsbLw0RUTu4KKpEBVV6ymH25pFYqqBr9MgBgKq6l1znt/iRRQJO93StrOkIAgFPDEU/Pw3sMqIiIqD4woCIiIiIiIiKqI4I7oOIMqrxyW/zlBlbVqsFnB1SJTEDltPdrDihZ3xteWG0HVKeH6jygYj5FRER1ojY+YRERERERVTNWMBBRGbkzKeZT+eW2+Mt9XK2cgCqSmF5B1eLh/ClHOqCq+woqIiKi+uBt7TYRERERERHVDZNlARXB3daPM6jymx5Q1cbXKeQEVEkdhmFCFIX0PKpWj+dPAZmA6uxIFLphQlpkghpL6ni9bwxJzcDly5vQ2egvxmkSERFRkTCgIiIiIiLyWG1c9iSiapFVQcUSqrxyAyq5RgIqp4IKACJJDY1+BWN2BVVrBVRQLWsJQJVFJDUDF8ZjWNEWXNBxTNPEj17px988diRrxtb7rlqOv3zXZjT5vX+vs6qNbzciIqI51UaNOhEREREREREVJGsGFSuo8lJzA6oamUHlV8R0VVIkoQPIzKBqqYAKKkkUsKrdCqVOLaLN3/948gS+9C/7MR5NoavJj41LG6EbJn786jm8/1u/xfBUolinXCL8/5KIiOpDbXzCIiIiIiKqYrw+THWDHf4qgpgVUHl4IhVMkbO/MLXS4k8QBIRUCQAwZc+hSs+gClRGVdGajgYAwKmhqQU9//GDA/jmvx0DAHz+rRvw/Jd+B49/7hb88z3XY2mTD8cGp/CJ776CeEov2jkXncDLdURE9WLXrl3YvHkz7rvvPq9PxRP8F4+IiIiIiIjKgvlUZXCHUoud8VOrps+gqp3LJ06bv0g6oLJnUIW8r6ACgF67gurcWGzez52Mp/BXPz8AAPjUTavxH96yPv09fs2qNvzTH16P1qCC/efD+OtfHi7eSRcbf3OFiKhu7NmzB4cOHcLu3bu9PhVP1M4nLCIiIiIiIiKak/vaN1v85Td9BlXtXD4J5QRUTiVVo78yxpQvbwkAAM4vIKD6znNnMDiRwKr2IL54x8Zp21d3hPB3d10BAPj+i2fx6pnRRZ1rybCCioiI6gT/xSMiIiIiIqKyYBRSGQS2+JtTbks/pYa+UA12EDVpB1OTces2pFZWQHVuPDqv500lNHzn+dMAgC+8bSP8ipR3v9s2duKuHSsAAP/p5wehG5VY21k7329ERESzYUBFREREROQxgReiqF7wW70iuKum2OIvv9yvSy1VUOW2+HNuGyqlgqp1YRVUj+y7gHAshdUdIdy5tXvWff/sHZvQ5Jdx+OIEHj1wccHnSkRERItTO5+wiIiIiIiIiGhO7uxFYIu/vGQxdwZV7XydnEqp3BZ/TnDlNSegGoum0udYiH9+tR8AcPc1K+YMXttCKj5502oAwP/4zQkYFVlFRUREVPsYUBERERERERHVEVZQzS33y5I7k6qaOTOophI6ACBi34YqJKBq8itosqu5zo8XVkV1ejiC1/vGIYkC3nPV8oKe84kbVqPRJ+Po4CSePja04PMtCQbHRERUJ2rnExYRERERUZXidSiqF2xnWRncf+cwn8pPEISs8E6uoQqqBp81m6lSK6gAYHlrEEDhbf5+fWgQAHDD2nZ0NvoLek5zUMFd11izqB566ewCzpKIiIgWiwEVERERERERlQXD2MrgrqAS+YcyI3dApYi1c/nEr1oBVTylwzTNigyoeuw2f+cKrKB68sglAMDvbOqc1+t86Nre9PMLrdYiIiKi4qmdT1hERERERFWKl4eJqJwYUBVGEmqzgsovWwFVLKUjoRnQ7flLIbuyqhIsb7ECqkIqqCbjKbxyZhQAsGvj/AKqNUsacOO6dhgm8ONX+ud/okRERLQoDKiIiIiIiIioLGrnEn91c7f14wyqmcnuCqoamkHlV5wKKgOTcS29PqRWTgXV0iarTd+lyfic+75wcgSaYWJVexCrOkLzfq3fv9pq8/fImxdgmua8n18a/P+SiIjqQ+18wiqxv//7v8fKlSvh9/tx0003Yd++fV6fEhERERHVCBYwEFE5Ca6/dPj3z8zErICqdr5QAcW6FBTX9PQcqpAqZb1fry1t8gEALk0k5tz3tb4xAMB1a9oX9Fq3b14Knyzi1FAEhy5OLOgYREREtDAMqArw8MMP48/+7M/wla98Ba+99hrWrVuHO+64AxMT/OBCRERERERE1cWdQ7DF38zcFVS19HVKV1Al9cz8KX/lVE8BQGejVUE1ODF3BdUbZ8cBAFetbF3QazX45PTsqkf2XVzQMYiIiGhhGFAV4O/+7u9wzz334KMf/Sguv/xyPPDAA9A0DQ8//LDXp0ZERERERFQ1augaf1XjDKrCuNsf1lIrxHRApWUCqpCvsgKqdAXV5OwVVEnNwL5z4wCAqxcYUAHAu7YtAwA8doABFRERUTnNO6CKx+P4/Oc/j1tuuQXLli2D3+9HV1cXbrzxRnz3u99FKpUqxXnO6gc/+AH+6I/+CDt27IDP54MgCHjwwQdnfc4rr7yCO++8Ey0tLQiFQrjuuuvw4x//eNp+yWQSb7zxBm6//fb0OlmWcdttt+GFF14o9lshIiIiIiIiKimGUoWRar2CKmUgltIBAAF7XaVwKqjCsRTi9jnmc/jiBBKagZaggjULmD/luGVDBxRJwJmRKE4PRxZ8HCIiIpqfeQdUU1NT+Na3vgVBEPDOd74Tn//85/Ge97wH58+fxyc/+Um8613vgmEYpTjXGf3lX/4l7r//fpw9exbd3d1z7r9nzx7ceOONeO655/CBD3wA99xzDwYGBnDXXXfhG9/4Rta+w8PD0HUdS5cuzVrf2dmJgYGBor4PIiIiIqpPQg1d+CSiyuf+K4d//cwsK6CqqQoq61JQLKkjYYc//goLqJoCMnyydZ6zzaE6eMEavbB1efOi/i1t9Cu4ZlUbAGDPkUsLPk7R8H9MIiKqE/MOqNra2hAOh/H000/j29/+Nr761a/iW9/6Fk6cOIHbbrsNTzzxBB599NE5j/P9738fZ8+enXG7ruv45je/iWQyOeexHnjgAZw5cwZDQ0O45557Zt1X0zR8+tOfhiiKeOaZZ3D//ffjG9/4Bvbt24cNGzbg3nvvnfW8iIiIiIiIaGEE8KJrJailaqBSyq6g8vBEiszd4i+eMux1lTUBQhAELG2y51BNzjyH6uiAFVBd1t206Nd05lDtOVoBARUREVGdmPcnEFEUoarqtPWyLOM973kPAODEiROzHuPcuXP49Kc/jdtuuy1vGGQYBj72sY/hC1/4Au6///45z+n222/HypUrCzr/J598EidPnsSHPvQhXHHFFen1zc3NuPfee5FMJvG9730vvb6jowOSJGFwcDDrOJcuXUJXV1dBr0lEREREREQsCqgUoutKAP9IZpY1g6qGvnmdgCqRMtLt8/xyZVVQAUBnoz2HapYKqiMDkwCAjUsbF/16t220AqqXTo3O2laQiIiIiqdovyJjGAYee+wxAMCWLVtm3benpwc//OEPce7cOezatQt9fX1Zx/nYxz6Ghx56CB/96Efxmc98plinCAB46qmnAABve9vbpm274447AABPP/10ep2qqrjyyivxm9/8Jr1O0zQ89dRTuP7664t6bkRERERERLWshq7xVzVWshVGdgVUtdSK1Zk3FUvpmYCqwlr8AUBnkxVQDU7kr6AyTTMTUHUtPqBauySEpU0+JHUDr/eNLfp4RERENDd5oU9MJpP46le/CtM0MTIygt/85jc4cuQIPvGJT+Atb3nLnM9/z3vegx/+8If44Ac/iNtuuw1PPfUUenp68PGPfxw/+MEP8OEPfxjf/e53IYrFLTM/fvw4AGD9+vXTtnV1daGhoSG9j+Nzn/sc/uAP/gBXX301rrrqKvzt3/4tZFnGhz70oRlf57777sN9991XUItCIiIiIiIionLhDKrCuFshSjXU489p5xdP6YhrVos/X4W1+AOA9pAVUI1F819XGZxIIBxLQRIFrOtsWPTrCYKA69a04+d7L+DFU6O4YW3Hoo9JREREs1tUQPWf//N/Tj8WBAFf/OIX8bWvfa3gY7z//e+Hruv48Ic/jF27dmHHjh348Y9/jLvvvhvf+973ih5OAUA4HAZgtfTLp6mpKb2P40Mf+hCGhoZw7733YnBwEDt27MDjjz+OpqaZexzv3r0bu3fvxrlz57BixYrivQEiIiIiIiKiRaidqKW0ZKnGZ1BVeAVVa8gaLzEayR9QnRqaAgD0tgWLdv6ZgGqkKMdbuBr6hiMiIprFggOqhoYGmKYJwzBw4cIFPPLII7j33nvxwgsv4Fe/+tWs4Y3bXXfdBU3T8JGPfASnTp3Cu9/9bvzgBz+AJFXWh6PPfvaz+OxnP+v1aRAREREREVUxXnStPPwzmYl77pRYQwmVE+YYJjAZ16x1FTiDqn2OgOrsaBQAsLI9WLTXvG5NOwBgb9844im9IoM7IiKiWrLoEiVRFNHT04M//uM/xv3334/nn38ef/3Xf13w803TxJNPPpl+fPDgQQwODi72tGbkVE7lVkk5JiYmZqyuIiIiIiIiIqp2bPFXGHdbP7GGvlB+Vzu/8Whq2rpKMVcF1dkRO6BqK15Atao9mJ5Dtbd/vGjHJSIimsmuXbuwefNm3HfffV6fiieK+gnkbW97GwDgqaeeKmh/0zTxh3/4h/jOd76Du+66Cz/4wQ9w6tQp7Nq1CxcuXCjmqaU5s6dy50wBwMDAAKampvLOpyIiIiIiIqLFqaFr/FWOfxCFcAdUUg1986pS5lLQRNwJqCqvUqgtaAVUM82g6huNAAB620NFe01BEHBVbysAMKAiIqKy2LNnDw4dOoTdu3d7fSqeKGpA5YRKiqLMua9pmvijP/ojPPDAA/jABz6Ahx56CB/+8Ifx/e9/HydPnsSuXbtw8eLFYp4eAODWW28FADzxxBPTtj3++ONZ+xARERERERHVmqwKKu9Oo+K5q6ZqKJ+CIAjpkGoyXrkVVG3pCqpU3u2lqKACgO0rWgAA+7wMqGrpG46IiGgW8/4EcujQIUSj0Wnro9EoPv/5zwMA7rzzzlmPYZom/viP/xjf/va30+GUM3Pqgx/8YFZINTAwMN9TnNVb3vIWrFmzBg8//DD27t2bXh8Oh/HVr34Vqqriox/9aFFfk4iIiIiIiIiqizugkmpoBhUAKJL1ftIzqCqxgiqUqaAyTTNrm2ma6LMDqt4izqACgCsqIaAiIiKqE/J8n/DjH/8Y3/zmN3HTTTdh1apVaGpqwvnz5/Hoo49iZGQEN998Mz73uc/NeowLFy7gpz/9KX7/938fDz30EGQ5+zQ++MEPwjRNfPSjH8Wvf/1rfOQjH5n1eA888ACee+45AMD+/fvT65xWgzfddBM+9alPWW9YlvHAAw/gjjvuwC233IK7774bjY2N+MlPfoKzZ8/ib//2b7Fq1ar5flmIiIiIiIhoDrV1ib96uf8cBFZqzEh0/UpvLc2gAgBFFoGknm7x55Mrr4KqNWR159ENExMxDc3BTLee8WgKkwkrXOstcgXV1uXNEAXgQjiOSxNxdDb5i3p8IiIiyph3QPWud70LFy5cwG9/+1u88MILmJqaQnNzM7Zt24a7774bn/zkJ6cFTrmWL1+OF154Ab29vTPu+6EPfQg7duzAhg0b5jyn5557Dt/73vey1j3//PN4/vnn04+dgAqwBo8999xz+PKXv4wf/ehHSKVS2Lp1K77+9a/jrrvumvP1iIiIiIiIaP5q7Bp/1WIoVZjsCioPT6QEMi3+KreCyidLaPDJmEpoGI0mswKqi+E4AKCjQS36uYd8MtZ3NuLo4CT2nQvjrZsZUBEREZXKvAOqHTt2YMeOHYt+4TVr1sy5TyHhFAA8+OCDePDBB+f1+jt37sSjjz46r+cQERERERHRwuV06SKPCDPcp2xC1gyq2vpKKTkBlU+uvIAKsKqophIaRiNJrO4IpdcPTloBVWdjacKjbT3NODo4if3nw3jr5qUleQ0iIiJawAwqIiIiIiIiIqJa5x47JdVYQKXaLf10w0qN/UplXh5qC9pzqCLJrPWDdgXV0iZfSV73su4mAMDRgYmSHJ+IiIgslfkJhIiIiIiojtTYdU8iqnDuv3P498/M3C3+am4GlZT9ftQKnEEFAC1OQBXNCagmEgCApSWaD7WpuxEAcGRgsiTHJyIiIktlfgIhIiIiIiIiopIQ2NivIO4KKrHGrp7kBlJqhQ7ZavRbkymcVoQOp8VfyQKqLquC6uxIFJGENsfeREREtFCV+QmEiIiIiIiIag5HUFUGVlAVRqjpCqrsy0FyxQZUCoDpAdWlidIGVG0hFZ2NVvvAo4OsoiIiIiqVyvwEQkRERERERETkoawZVGJtB1S5Lf8qRVPAqqCaiKey1mda/JVmBhUAbErPofIioKrMPw8iIqJiY0BFREREREREVKfY7m9m7qqpGiugmtbSLzewqhRN6Qqq3ICqtBVUALCpy55DdXGiZK9BRERU7yrzEwgRERERERHVnBq7xl+1ssIW/qHMyP11kmosocqtmKrcgMquoIplWvyZpomRSBIA0NFQugqqjUutgIot/oiIiEqnMj+BEBERERERUc3hDKrKINRY2FIOtTaDSpVzZlBVaAvD9AyqRKaCaiKuQTesv01aQ0rJXnvNkhAA4MxwtGSvQUREVO8YUBERERERERHVqcqMJSqPWKEBzkLlVkzlBlaVotGuoJqMZyqoxuzqqaAqwSdLJXvt1R1WQDUwEUckoc2xNxERES1EZX4CISIiIiIiIqKSqK2opXRMV8lfjeVT02ZQVWoFVVPAqpCaiGUqqMaiVkDVGlRL+totQRVtIes1Tg9HSvpa09RYxR4REdFMGFAREREREXlseUvA61MgojrivvbNdn+FkSo0wFmo3AoquUJnUOWroBqPWmFVKdv7OdbYVVRlD6iIiIjqRGV+AiEiIiIiqgPf/fg1+Mxta/G725Z5fSpEVEcE1lAVJLuCqra+Zoqc/X5yK6oqRXoGVVyDaf+BlKuCCsi0+WNARUREVBqy1ydARERERFSvdm3qxK5NnV6fBhHVsdqKXUqn1gIqWcytoKrM99dkV1AldQMJzYBfkTBqz6BqKUdAtcSjgMqdjhIREdWwyvwVGSIiIiIiIiIqiRrLWkrGRCYkqLEOf9NaFlbqDKqgmvm96qmE1eYv3eIvWL4Wf6eGpkr+WkRERPWIARURERERERFRHXFHEQyrClNrM6jc70eRhIqdRSaJAvyKdekqltQBlLfF34q2IACgfyxW8tciIiKqRwyoiIiIiIiIiIhyuLusVWqAs1DugCq33V+lCdlVVFE7oCpnBVVPqxVQjUaSiCa1kr9eBlv8ERFRfajsTyFERERERERUMzhWpUII7ru1FbxQYSQhu4KqkgVUCQAQsQOi8Vj5ZlA1BxQ0+qyA7MI4q6iIiIiKjQEVERERERERUR1hKFWYWs5Ts1v8VfalIaeCymnxNxW3gqoGnzzjc4ppeWsAAHBugW3+UrqBxw8O4PsvnsW+/vEinhkREVH1K8+/5kRERERERERUEdzd6mqscx0VSK6igCpdQZWwgqkp+7bBX6aAqiWAIwOTOL+ACqozwxH8wfdewcmhSHrd725fhv/+/m3wK9LMT2S5KRER1YnK/hRCRERERERERCXDfGpmtfy1Ed0zqCq8xV/QDqhiKbuCKuFNBdX5eVZQhaMp/PvvvISTQxG0h1TcsmEJJFHAI/su4A+//xp0gyEUERERAyoiIiIiIiKiOlLZcUTlqOXqsmqqoAraLf4iCT3rNlSugKplYS3+vvqrw+gfjWFFWwCP/seb8X8+uRMPf+paBBQJzxwbwj8+c3KWZzO8IiKi+lDZn0KIiIiIiIiIqKgEV/LCy+AzE2s4oXLPoHKHVZXIqaCKJjUYholIsrwVVD2tQQCYV4u/Y4OT+PFr/QCAv/vAFehs9AMArl3Tjv/y7y4HAPz9r4+jfzSa/wBs8UdERHWCARURERERERFRHansOKJy1HA+lRVQVXoQF/LZLf6SOqIpPZ3dVHKLv+88dxqmCbz98i7sWNWWte39V/fghrXtSGgG/ubRI0U9VyIiomrDgIqIiIiIiIiIqI64q6bECq+gCih2i7+kjog9f0oUAL9Snktay5qt6qdLk/GC5kaFYyn8bO95AMAf3Lx62nZBEPBXv7sZAPCrAxdxejhSxLMlIiKqLgyoiIiIiIiIiOpIhRfMVAyhhmvN3KFUhY+gclVQaZhKZNr7CWX6Rm5v8EEUAMMERqYSc+7/m8ODiKcMrOtswI6VrXn32dTVhLds6oRpWtVWRERUv3bt2oXNmzfjvvvu8/pUPFHhH0OIiIiIiIiodnCuSiVwBy8cdVOf5Cpq8RewZ1BFkjqm4uWdPwVY7RDbG3wAgEuTcwdUvz48CAB4x5auWUO0T9xoVVf9fO95xFN6zlb+j0lEVC/27NmDQ4cOYffu3V6fiicYUBERERERERHVk8rOI6gMJDFzOajSA6qQaoVRMVeLv1AZAyoA6Gx0Aqr4rPslNB3PHBsGANx+2dJZ971hbTuWNfsxEdfwxKHB7I1MjomIqE4woCIiIiIiIiIiylXZuc2iuNv6SRU/g8qqoIq6W/z5PQqoJmavoHr1zBimEhqWNPqwdXnzrPuKooD3XtUDAHhk34XinCgREVGVYUBFREREREREZVLZF8LrRYUXzFAZZFdQeXgiBfAp1rkmdQORpF1BpZY7oPIDmLvF36tnxgBY1VFiAV/YO7d2AwCePT6EWDK3zR8REVHtY0BFREREREREVEfcl81NzrqZUYXnNotSTTOofLIdUGkGYkkDAOC3q6rKpbOpsBZ/r/dZAdVVva0FHfey7kb0tAYQTxl4+tiQawv/vyQiovrAgIqIiIiIiIjKhBddK4FQ4YEElZ47lKr0Fn+qHVAlNAPxlFVl5FfKezmrkBZ/hmGmA6qrVxYWUAmCgDsu7wIAPHFwYJFnSUREVH0YUBERERERERHVKaGm64RoJu4KqkoPqHyyVS2VSBmIa05AVd4KqiUFtPg7MTSFybiGgCJhU1djwcd+6+alAICnjg3BMBjiExFRfWFARURERERERFRH2OKP3KFUpVfU+dIVVDriKafFX5krqOwWf0OzBFQHL4QBAFuWN0GWCj+/q3pbEVQljEaSODIwubgTJSIiqjIMqIiIiIiIiIjqSIXnERWj0oObxXAHVFKFv810BZVmIOG0+JPLXEHVYAdUUwmYZv5Q1wmXNnU1zevYqixi5+o2AMDzJ4atlTO8BhERUa1hQEVERERERERlwWuuRJWhqlr82dVSSc1AQnMqqMobULWF1PQ5xOyQLNdRO6DaOI/2fo4b13YAAJ4/aQdUrGwkIqI6wYCKiIiIiIiIqI5w7hSJVdTiT5WcFn8G4k4FVZlb/AVVKX0eo5Fk3n2Opiuo5h9Q3bCuHQDw8ulRJO0QjoiIqB4woCIiIiIiIqKyqPDr4HXD/efAqrb6lFVBVeH/YzoVVNYMKiegKm8FlSAIaA0pAIDxaGra9nAshYvhOABg/dL5B1SXdTWhJaggmtRx+OLE4k6WiIioijCgIiIiIiIiIiLKUdmxzeK4q6YqvsWfPW8qpZuIJq2AylfmgAoAWoNWm7+x6PQKqrMjEQBAR4MPzQFl3scWRQFXrmgBALzRN8bkmIiI6gYDKiIiIiIiIiKiOuIumhIrPqDKXLqaiKemrSsXJ6DK1+KvfzQGAOhtCyz4+Ff2tgIA3ugfX/AxiIiIqg0DKiIiIiIiIioLFgVUhgrv6EZlILq+CSo8n8oKo8IxDUD5W/wBmLXFX99oFADQ2xZc8PGv7G0BALzRNw6Af1kSEVF9YEBFRERERERERFRH3KFUpc+gkiUxfb4TMSsc8ntQQdUyWwXVmBVQrVhEQLWtpwWAFXbpBgMqIiKqDwyoiIiIiIiIiOqIUNPTlYqnwnObRcmqoKr0Eipk5lCFnYDKgwqqNjugGs8zg6p/dPEBVXNAwQq7RWBKNxZ8HCIiomrCgIqIiIiIiIioTtVyCEMzy5pBVQXfAz7Funw1lfCuxV9L0GrxN5qnxV86oGpdeEAFAJd1NQEAkgyoiIioTjCgIiIiIiIiIiLKUQW5zYK5q+ikKkioVCn78pVfKf/lrLZQ/goqwzBxfjwGAOkKqIW6rNsKqFIaAyoiIqoPDKiIiIiIiIiI6girpkh0XQ0Sq+AbQpkWUJW/gqrVbvE3lhNQjUaTSOkmBAHoavIv6jXSARUrqIiIqE4woCIiIiIiIiIiqiPuUKoaKqhkKfsccyuqyqHRLwMApuJa1vqhyQQAoD2kQl7keW1mQEVERHWGARURERERERFRHan8OKIyCFVQWbRQYtYMqsp/n3JOiJYbWJVDgx1QTc4QUHU0+Bb9Gj2tATT4ZMBc9KGIiIiqAgMqIiIiIiIiKguTF10rjsC4qi65w7fqCKiyL1/ltvwrh0a/AmDmgGpJ4+IDKlEUsLGrEYLAvyyJiKg+MKAiIiIiIiIiIqoj7kjKg6xn3nIrpnIrqsrBafGX1A3EU3p6/dCUHVAVoYIKANYuCRXlOERERNWgCj6GEBEREREREVGx1HLrOiqMWG0VVDkpmiKX/3JWgyqn708lMlVUxaygAoA1SxqKchwiIqJqwICKiIiIiIiIiChH5cc2C5cVUHlQjTRfSs45KmL5L2eJomDNh0J2m7+iB1QdrKAiIqL6wYCKiIiIiIiIyqIKCjXqAv8YyP3/olQF/2NOa/EneXPOTpu/yXgqvY4VVERERAvHgIqIiIiIiIjKwjS9PgMiArKrpqqggApKTos/L2ZQAe6AylVBZc+g6ijSDKretiBE8C9LIiKqDwyoiIiIiIiIiIjqiDveqYYWf+5ASpEEz+aoZVr8ZSqoxiJJAEBbSC3Ka6iyCA9GbBEREXmC/+QRERERERER1ZEq6OhWGWr465Q1g6oKviFkVwWV7MH8KUejXwGQqaAyTRPhmBVWtQSVor2Ol++RiIionPgvHhERERERERFRHXEXTVVBAdW0Ciqv5Lb4iyR1aIbVjq85ULyASqqGPxQiIqIiYEBFREREREREZWFyrkpF8Ko9WrURariEqtq+B9wVVLnzqMopt4LKqZ5SJREBRSra67DFHxER1Qv+k0dERERERERUp6osp6AicRfoVENYpbhOWK6ICiormBqPWvOnmgJKUb+OrKAiIqJ6wYCKiIiIiIiIyqKWK1KIqkk1hFJu7lDKy/lMIdUKqKIpHQBKMn8KADzM4IiIiMqKARURERERERERUR1xF+iYZuW33nS3+FM97H8XUK3XjiXtgCpqBVTFnD8FZFdQJTWjqMcmIiKqJAyoiIiIiIiIiOoUCzVmVmVFRvPirqAyKj+fym7x52H7u4BdQeUEVONOBVWRAyrhzv8GALhP+z1cDMeKemwiIqJKInt9AkRERERERFQfTFTBlXCiOpBdQeXdeRTKXUHlvl9uAUUCML3FX3ORW/wJm/8dfq/hYbw5DFw5FsPK9tC8nh9P6RgIx7GsJeBpxRkREdFcGFAREREREREREdUR0VVBVQ3BsbtqSvVwQFNQtQKqWFIDAIyXqMUfALS0LQGGh3BufH4VVD974zz+4qf7EUnq6Gry4yvv3oK3bl5a9PMjIiIqBv4aBRERERERERFRjhru8JfVvrA6KqhcLf68rKByAiq7gmoybgVUjf7iB1TLWwIAgPNjhQdUjx24iP/4o72I2C0IBybiuOcHr2HP0UtFPz8iIqJiYEBFRERERERERFRHsiuoKp8sulr8eTmDymnxZwdAzm2DTyr6a/W02gFVgRVU4VgKf/4v+wEAH7muF4f/y9vxniuXQzdMfPHH+zBhh2lERESVhAEVEREREREREVEOoYZLqLLeWxWUUClZFVTet/iL28FUJKHZ64s/QaO72Q8AGAjHC9r/n1/tx1g0hTVLQvird12OgCrhb963FWuXhDASSeL+p08V/RyJiIgWiwEVEREREREREVEdEVBdFVSiq2pK9DA5TFdQ2S3+nFZ/oRJUUHU0+AAAw1OJOffVDRP/54WzAIBP3bQGqmxd7vPJEv70jk0AgP/93OmCjkVERFRODKiIiIiIiIiIiOpItVWHuUMpycsWf84MqjJUUC1ptAKqocm5Q6U3+sbQNxpFo1/Gu69clrXtjsuXYuvyZsRSOv751XNFP08iIqLFYEBFRERERERERFRHqiyfgiRURgWVE0QlNAO6YaZnUDmt/4rJCahGo0loujHrvr8+fAkAsGtj57SwTBAE/PvrVwIA/umVPhhGNdTMERFRvWBARURERERERERURwRXyFMFI6gqrsUfYLX3iyRLV0HVGlQhCtafz2gkOeu+Tx4ZBAC85bLOvNvfta0bjT4ZZ0eieOXMaNHPlYiI6s/XvvY17NixA42NjVi6dCk+8IEP4MyZM/M+DgMqIiIiIiIiKotquBBed6qtlIaKotr+2N1d/Tzs8Ae/krmMFkvq6VZ/pZhBJYkC2u05VEOzzI7qH43i2OAUJFHAbRvyB1RBVcbbLu8CADx6YKDo50pERPXn6aefxp/8yZ/gpZdewmOPPYbR0VG84x3vgKZp8zoOAyoiIiIiIiIiohz1EqhWwzwq99wpL2dQCYKQrqKKJXVEEnZAVYIKKgDoaJh7DtUzx4cAADtWtqI5qMy43zu2WAHV4wcH2OaPiIgW7bHHHsPHPvYxbN68GVdeeSW+/e1v48iRIzh06NC8jsOAioiIiIiIiIiojlRDKOUmVsgMKgDw2VVUcU1HLFW6GVRAZg7V8NTMLf72nwsDAK5e2TrrsW5a34GQKuFiOI4DF8LFO0kiIiqbH/zgB/ijP/oj7NixAz6fD4Ig4MEHH5z1Oa+88gruvPNOtLS0IBQK4brrrsOPf/zjop9bOGz929LW1jav55XmVzyIiIiIiIiIclTbRXGiWlV1M6jcAZWXPf4AqJIVUI1HU+l1IV+pKqhUALNXUO0/b10Q3Lq8edZj+RUJ169tx68PX8KLp0awraelaOdJRETl8Zd/+Zc4e/YsOjo60N3djbNnz866/549e3DHHXfA7/fj7rvvRmNjI37yk5/grrvuQn9/P77whS8U5bx0XccXv/hF3Hnnnejp6ZnXc1lBRURERERERGVRDRfCiajySK6rVx7nU+kKqrGoVdUkCIBPLs3ltSV2i7+RGWZQJTQdxwYnAQBb5gioAOC6Ne0AgBdPjRbpDImIqJweeOABnDlzBkNDQ7jnnntm3VfTNHz605+GKIp45plncP/99+Mb3/gG9u3bhw0bNuDee++dFnB96UtfgiAIsy65TNPEPffcg76+vjmrufJhBRURERERERFRnRLAsjaqfO4KKsnjUsxMBZUVUIVUOe8Fu2JoCVoVVGOuai23owOTSOkmWoIKeloDcx7PCaheOT0KTTcgS/y9dSKianL77bcXvO+TTz6JkydP4hOf+ASuuOKK9Prm5mbce++9+PjHP47vfe97+Ku/+qv0ti984Qv4+Mc/XvBrmKaJz3zmM/j1r3+NZ555BkuWLCn4uQ4GVEREREREREREdcpE5Zc2ugOqUoVBhfLJ1ryp0YgVGgVKNH8KAFqDCgAgHMs/g+rA+QkAVnu/Qr4ul3U3odEvYzKu4dDFCbb5IyKqAJOTk5iYmEg/9vl88Pl8iz7uU089BQB429veNm3bHXfcAQB4+umns9YvWbKk4JDJNE3s3r0bv/zlL/H0009jxYoVCzpP/qoEERERERERERFVLMnV18/roh/VbucXjtkBlVK6gKrFDqhmqqA6OTQFANiwtLGg40migGtXW8PrXzw1UoQzJCKixdq8eTOam5vTy9e+9rWiHPf48eMAgPXr10/b1tXVhYaGhvQ+C7F792788Ic/xMMPP4xAIICBgQEMDAwgmcz/SxUzYQUVEREREREREVGOyq8rKo5qaPMougIq0fMKKiugmoynsh6XQqbFX/6LfWeGIwCAVR2hgo953Zp2/PrwJbx4ahR/eMvaxZ8kEREtyqFDh7B8+fL042JUTwFAOBwGYLX0y6epqSm9z0J861vfAgDcfPPNWev37NmD2267reDjMKAiIiIiIiKisqiXC/5EVFyufCorrPKCmg6oNACATyllQGW3+Juhgur0iBVQrW4vPKDaaVdQvXZ2DKZpet4ykYio3jU2NqKpqcnr05g30yzOJ3u2+CMiIiIiIiIiqlPVMINKEtwVVB6eCDIzqJwKKrWEPQdb7Qqq8Vhq2oVATTfQPxoFAKzqCBZ8zI1djZBFAeFYChfC8eKdLBERVRSncmqmKqmJiYkZq6vKiQEVERERERERlQV/T5+IFsJdNSVVSIu/CaeCSi7dDKrmgFVBpRsmJhNa1rYL43GkdBOqLGJZc6DgY/pkCes6GwAAhy5MFO9kiYioojizp/LNmRoYGMDU1FTe+VTlxoCKiIiIiIiIiIgqlnvulNct6XJnUKklnEHlVyQEFCsAG49kt/k7Y7f3W9kWnHfbw83dViupwxcZUBER1apbb70VAPDEE09M2/b4449n7eMlBlRERERERERUFpXfSKz+SF73SyMqgLuLntffs9NmUJUwoAIyc6jGY8ms9RfGYwCAntbCq6ccm5dZARUrqIiIatdb3vIWrFmzBg8//DD27t2bXh8Oh/HVr34Vqqriox/9qHcnaJO9PgEiIiIiIiIiKq8P7OjB6eEIrl7Z6vWpVKwizf6mIhAragZVTkCllK7FHwC0BFVcDMcxFs2uoHLmR3W3LCCgsiuoDrGCioioqjzwwAN47rnnAAD79+9Pr3vqqacAADfddBM+9alPAQBkWcYDDzyAO+64A7fccgvuvvtuNDY24ic/+QnOnj2Lv/3bv8WqVau8eBtZGFARERERERER1Zn/9v7tXp8CUcGyAqoKqaCasmdCqVJpK6hanQqqaHYF1UW7gqq7yT/vY15mB1R9o1FMxFNo8iuLPEsiIlqoXbt2QVEU7N69G7t375513+eeew7f+973stY9//zzeP7559OPnYDKOfZzzz2HL3/5y/jRj36EVCqFrVu34utf/zruuuuu4r6RBWJARUREREREREREFcvd1k/0fAZVdsWUTylPi7+xSE5AtYgKqtaQis5GHy5NJnDy0hSu7GUlJRGRV/bs2YOenp6C9n3wwQfx4IMPzuv4O3fuxKOPPrqAMysPzqAiIiIiIiIiIqpT1dDK0J1JSR4HVGrOzKlSV1A1B6yAymkp6LgYtiqoljXPv4IKAFZ1hAAAZ0eiizg7IiKixWFARURERERERESUw0QVJDd1QqqgGVS5AVWpK6gafFbzI6elIACYprmoCioAWN1uBVRnRiKLPEMiIqKFY0BFREREREREREQVK6vFn8cJlZzz+rkt/4qtwWdXULkCqom4hmhSBwB0L7CCamVHEAArqIiIyFsMqIiIiIiIiIiIqGIJQuXMoJoeUJW4gspvV1C5Wvw586hCqgS/srCAbBUrqIiIqAIwoCIiIiIiIiIioorlrqCSvK6gypk5VeqAqtFu8TcZT6XXjUatgKo1pC74uCvbWUFFRETeY0BFREREREREZWFypA9VE36/Vgz3DCqPC6igSNknkDuTqtga/dNnUDkVVG2LCqisCqrRSBLhWGqOvYmIiEqDARUREREREREREVUsdygled7iL/tSmiqVp8XfpKvF36gdULUGFx5QNfhkdDT4AAB9rKIiIiKPMKAiIiIiIiIiIqpT1VAo5s6kPJ9BlVNBldvyr9gafHkqqKKLr6ACgFV2mz/OoSIi8s6uXbuwefNm3HfffV6fiidkr0+AiIiIiIiI6oPXrbmIqDoJyPzlIXo9gyqngiq35V+x5WvxNxqxWvItpoIKsNr8vXp2DGcZUBEReWbPnj3o6enx+jQ8wwoqIiIiIiIiKgvOoCKihciuoPLuPIDpFVRSiU+owacAAKbiGkz7L9HMDCplUcde3eFUULHFHxEReYMBFRERERERERERVaysGVQeJ1S5FVNyqQMqu4JKM0zEUwYAYNRu8de6yBZ/K9tDAMAKKiIi8gwDKiIiIiIiIiIiqljuuVOC1zOoclr8SWJpL60FFSkd0E0mrNZ+6QqqRbb4622zKqj6R2OLOg4REdFCcQYV1Q3TNGFEojAmwjANA4IoQvD5ILW0QJAkr0+PiIiIiIiIKgg7UlYOdyQleR1Q5VZQlXgGlSgKaFBlTCY0TMY1dDYCY3YFVcsiA6ruZj8AYGgqAU03IEv8PXYiIiovBlRUk/SpCMZf/i0GX3gKkeNHIZ09D3VoApJmTNvXFAWguQm+7mUIrF8Pdd06+DduRGD7dkjNzR6cPRERERERERE5KmoGVU7FVKlb/AFA0CdhMqEhltQBAFMJDQDQ6F/cZb32Bh8kUYBumBieSqLLDqyIiIjKhQEV1YzIwHkc/NE/IvXrp9F88hIkw/otq4ac/VISYAiAaAKKDgiGCYyFkRwLI3nocGZHQYBv/XoErr4KoeuvR+iGGyA15B6NiIiIiIgKZbImhajimFXwv6W7rZ9YYRVU5ZiJFVCsri/xlB1QxYsTUEmigKWNPlwIx3ExHJtXQLW3fxxtQRUr2gKet10kIqLqxYCKqloiPIa9//wPmPrVo+g6PIRG1wfrgRbg5EoF0Z5GCMvaIC5bArWtDWawAbqvAZPQMTI1hMmh8whfOAPf0ARWDAMrhkysHjCxbMxE4tgxJI4dw/gP/wmQZQR37EDDrbei8a23Q+3p8ex9ExEREREREdULYcYH5adMq6AqfVs8vx1QxVI6dMNExK6kavAt/rJeV7MfF8JxDE7EC9pfN0zcff8LeOXMGADgw9f24q/fs3XR50FERPWJARVVnWQ8gr0//y7Gfv4TdL05gCYNaLK3ne4Gxtcl0dEdwxVKHLt060MbJgEczTmQpAKNXUDTcpg9mzCyqQNHVBmvG1P4P9ELOHP+JNadS2Fzn4krTppYNqYh+uKLiL74Ii59/evwb92KpjvvRNPb74DS3V3GrwARERERUXUSvL6yTDQPZjWUFhVBNRS/VFKFjicVVKpTQWUgktTS6xsWWUEFIF01dWG8sIDqgWdPpcMpAHjopT6sbA/iD29Zu+hzISKi+sOAiqqClohj78//AZce+QmWvjmMxgTQaG8baAPG1qWwpmcSb/drEEOdQOMKwN8MiAogyoCRAuITQGISiI4AkSFATwLjfcB4HwQAHQBushcAiAkCXm4NYc+qNvzX20XIYzquOmFixwkTm/tMxPfvR3z/flz6+tcR2LgCzTdfgaabtkNqarB6JJhG9mLorsf2fcN+DACCaP1kIAj2fRGA674gAIIEiM4iuxb7sSBNX5e+n/sce33ucyrogz8RERERERGVVjXkcO4MyOufWJWcgCr3cSn45UwFldPeT5VE+Oz1i9HZaAVUQ1OJOfc9MjCBrz16BACwuiOEd23rxv948gS+/thRXL2yDVevbF30+RAR1Ztdu3ZBURTs3r0bu3fv9vp0yo4BFVWs4ZNv4OC//E9MvbIXS45H0RgDVtvbxhqAS+s0rNi2BDfveCvk5VcAnZcDbWsAqYBvaz0FTA4AExeAiXPAxEVg4jwQPmevO4/A5ABujUzh1sgUDAD7fSoe2xjE/7oqhFRCwnVHTFx/2MCmfiB2tB+xo/0Y/M6/orEnhubVMYSWJqyMqdpMC7nmCsPybS/ic9LPmyFgmxbEyXM8L9855L6WDJShTQMRERFRveEMKiJaiEqqvpRyflYsawVVUsdUojjzpxxLGn0AgKHJuQOqH77UB8AK5X71H25GQJVw4tIUHj0wgK/+6jD+v3uur6hqNyKiarBnzx701PEoGQZU5JlkNIzY2ACiYwMYPb0PY6cOYqrvDPTTg2i8mMCSMaAT1gIAE0FgcIOMrus345o7PgZl9c2Ar3G2l5iZpAAtK6wF1+bfJx1inYcYPoftsTFsj4fxxXgYL0+dxS+b+/Hfrx6FOqnjhsMmbnvTwMohARN9QUz0BSE3Smje0oTmbS3wdfgzlVCi5KqIciqlkKm6gmnfd1VhwXRVYelW5ZWhuRY9c9/Up6/LeqxbFWUzMXVA1wF97g+ntU0oQRiWZ7swSzA3U+WbIM7xnIW+Vr73JbKqjoiIiIiIPOX+kcTrAETOCaTKMYMq4JpBNWlXUBWjvR8AdDSoAIDhOSqowrEUfvL6eQDAAx+7Jh2a/effuxxPHrmE186O4bcnR3Djuo6inBcREdUHBlRUdlo8isNXXQ3ZyKwTAbTbi9vFTiC2oQWdN9yIK9/7eagty8p3olkhlms1gOvt5S+1OJ4+9zR+cdkv8Oc7n8WKAQ23vWngpkMmGid1jLwwhpEXxhC44go0v/c9aHrHOyA1LjBUK7ZpIZdmh2AzBF/GLMGXOVMYNp/n5TvGLMc1823P97ycfXLfo2nM8AUyrSBvtjCvXswabMnW/yuSAojK/O5Lquv5s91XrMrIOe+7X0O1tjn32b6SiIiIiOaJ9X6UjyKVv4LK7wqonAqqBl+xAiqrgmqugOrR/RcxldCwYWkDbnaFUJ1Nfrz/6h489FIf/nXvBQZUREQ0LwyoqOxkfxCm6/ObIQCTQWCyWUCyRYXYuwTt267E2tvuxmVrrvLuRAvgl/24Y9UduGPVHRiLj+GxM4/hF1t+ge8P7MPVJ0zc9qaJK0+ZiO3di9jevRj86tfQ+Na3ouW970Hw2msheNlGThQBUQWgencOlcAw8lSd5Qm+ShLeafmDQmefWavh5hvezfEeTX3mr1GtVNXlC65y7+cNuxZzX7Xvy7Pfl32ZQM7ZJvns82G7SSIiIiIqnWpovSmWIQQqVO6plGMGVUC1fiaIJTMzqIoeUE0mZ91vb/84AOD2y5ZO+/N4+5YuPPRSH548egmGYVbUnxcREVU2BlTkic4f/yN8Te0ItnRBCbZAlBY/2NNrrf5WfHDTB/HBTR/EmfAZ/PL0L/GDHY/gHwfO4eYDJna9aaBnJI6JRx7BxCOPQF7WjZZ3vxvN73431N5er0+/fokiANEKBOqZ01JyXiFcynqspwA9aT3WtXnetxfDPoauFed+vso4pyKu2oriBCkTXMnuAEvNWZQ8+82x72zBWN7j5TsWK9SIiIiIqLSEGe57ITd8KcsMKruCKp7SMZWwfqAp1gyqDnsG1UgkAdM0Z2yh+PLpUQDA9hUt07btXN2GoCphaDKBgxcmsLWnuSjnRkREtY8BFXmi6/JbvD6FklrVvAq7r9iNz2z/DPYO7cUjVz2C/+f0Y1hydgK73jRw4yEToQsXMfy/voXh//UtBHfsQPN73oOmt98BMRTy+vSpHglCZo4WfF6fzeIZhiu4SlmB2pz3cwOz+d5PLjygc85BT1jn5GbqgBazlootYhPyBF7ObZ4QbFow5jzXufXZoZj71pe937Tn5HkuK9CIiIiI5iR4HvnMrZJ+F0oSKmQGVZEqqNpDVleVlG4iHEuhJTi9y8r58RhODUcgiQKuX5s7nAHwyRJuWteBJw4N4skjlxhQERFRwRhQEZWQIAi4svNKXNl5Jb6080t45twz+MV1v8BDZ57ClUdSuO1NE9tOm4i++iqir76Ki//1K2h++9vR/J73ILhjh7ctAImqmSgCoh1qVJvccE1LuAKsZJ7FXq8lcvaxA6/cdbMez7VOy7POvWQx7deqwARNVOYIshYYkE0Lw5TCn1tJV1iIiIio7lVFiz/X5yevP0qJQvkrqPyqHVAldUSTVnv2gFqcS3p+RUKjX8ZkXMPwVCJvQPXa2TEAwJZlTWjy5+8+8jubOq2A6uglfPb29UU5NyIiqn0MqIjKRJVU3L7ydty+8naEE2E8fuZxPHLyEfzDyTdw6wETt75pYNlYHOGf/gzhn/4MWncHWt71u1jyu++Gb/36GcvsiajGVEO4ZpqZ6rPFBmPp5ycy+2TdJuywLN9tnv2NnB6ORgpIVlhfRyk3wHIFWbLfvnXdz7vNfiyp9vo822Z8rt+af8Z/V4iIiKhKVNKnltzfIy3LDCpXBVUspWetK4YlDT5MxjUMTSaxrnP69v3nxgHkb+/nuHXjkvS+E/HUjEEWERGRGwMqIg80+5rxgY0fwAc2fgBD0SE8c+4Z/Kx/D0Ze/i1u2BvH9YdNBC8OY+rb38XUt7+LyWUtEN9yI3rf/UF0Xn6116dPRPVOEOwWfQqgVlhbUtMsMOSaIwxLh2qFPnem10pOD82ccG72OdSlJYg54ZXqCrH8BYZieYKvadtyAjT3cSWFIRkREREVxvWRweuPD15UULlnUMXtgMqvFK/jSkeDD6eGIxieyt8R4djgFABgU1fTjMfobg5gdUcIp4cjeProEH53+7KinR8REdUuBlREHlsSXIL3bXgf3rfhfYjfGsdLF1/CE6f2IP7k01j7+iCuOGWi8cI48P1fYuT7v8SBTgUXd66GeusNWHPNW7ChbSMa1Uav3wYRUWUQhExgUikMww6l5gi50gFX3H4ct/Z3P9adx7Nty/NcPZHdmtE0MrPNvJQvFFP8rlDLP//Hhewj8SMwkWcqv5MYUZrJ79eKkRsKeSk3kCrHDCp/OqAyEE8ZWeuKoaPRaus3U0B14pIVUK1f2jDrce7c2oX79pzE4wcHGFAREVFB+NM5UQXxy37cuuJW3LriVuBWYDg2jFdPPoOLj/0rmp7bjw3Houi+lEL3L44BvziG0YYH8eA6ASc3t0K7ejNWdW7EutZ1WNu8FisaV6DZ18zWgEREXhNFQLRDEi8ZhisAywmv0qHWXMFXYn6hWNb+eeaUOa+JcHm/FqK8uBBszpDMBygBO3wLZB6ztSIREdGCVNK/nl5UUKmyFYIlNQOJErT4awtZAdVYdHpr7FhSx/lx65ea1i6ZPaC6fk0H7ttzEvvsloBERERzYUBFVME6Ah14+5b3AlveC3wRCA9fwIlHHkZsz9No2ncabVM6bt9r4va9o0j+6Dkc6n0eL68S8J2VAs4sBUK+RvQ09qCnoSd92xnsRGewEx2BDrQH2iGL/GuAiKguiCIgBqygxCtONVne4CthV3XZ61Ix17oFPNbiQCqeWeeuIDM0IDllLeUkiNmBVW6A5TxW/K5bP6AEXevsZdp6+7Hsz96HgRgREdUA9y9eel3ZlptHyWUIqBTJCqgSupGeQVXMFn/NAWte1ERsekDVNxoFADT5ZbQGZ58rtbWnGQDQPxrDyFQC7Q0V1NWAiKhC7dq1C4qiYPfu3di9e7fXp1N2vDJNVEWaO5bh6k98EfjEF2Ekk4i+/ApGn3wCk0/ugTowhCtOm7jitPVpfcoPHOwN48DKCRxceRi/6QDMnItUAgS0+duwJLgEHYEONPua0aQ2ZRafdduoNsIn+eCTfPDLfutW8sMnW+tEofQtDYiIqAZ4WU1mGK4qsbgrwCrxY3fVmGkAqYi1lKu7ohNYZQVf7pBrnoFX+n6efSQOQyciotKopF+3yK2YEstYQZXSjPQMKl8RK6icgGo8On1I6unhCABgVUdozg4tzQEF6zsbcPzSFF47O4a3Xd5VtHMkIqpVe/bsQU9Pj9en4RkGVERVSlRVNNx0IxpuuhHmf/p/kDxxApHf/haRF19C9JVX0DA1hWuPmbj2mBVYJYMqBlY14lSPjAPdKbzeMYUp1cBIfAQj8ZHFnYsgWgtESKKUvi+KIiRBggABkiClH0uCtY+zThbkzGPBOoYsygjKQYSUEIJyEEHFXuQgmtQmtPnb0OpvRau/Fe3+dgTkANsZEhHRzEQRUIPWUk7pqrH5VILFXbdRO/iKWeuy7ufuYz/PcP32c7qN4ljp36sg5an2misUmynwCtrbgpl9FPvPTw5wjlg148c1ospTBbO2xAqqoPLi507VrqBK6qWZQeUEVOE8FVRnR+yAqj1U0LF2rGplQEVERAXjT3ZENUAQBPjWr4dv/Xq0fexjMDUN8UOHrLDqxRcRfeMNqNEYeg+NoPcQcJv1JEhrV0O/bA0m13RipLcFQ8uCGDcjmEhOpJfJxCQmU5NIaAnE9TgSegIJLQHN1NKvb5gGDNOwH3jxFQB8kg/t/nZ0hbrQ3dCN7pC1dIW60B3qxvKG5QgqZb4oSURE5EXVmK7ZIZYdXjnBVSqWWT8t5MoXeOUeY4bjOUwdSE5aS6lJak6AlRNkKQFACeVZl2c/NZR/m6SyRWIpVMGFcCKqQK6/juvxrxH3DKpYCWZQzRZQnUkHVIX9PH1lbyt++HI/9vaPF+38iIiodjGgIqpBgiwjsG0bAtu2AX/4aSuwOnoUsb17Edu7D7G9e5Hq74d+4hRw4hQaATQCWCVJ8K1bB//mzfBv3gb/5Zvh37QJYnD6B1HN0JDQE4hpsXRANdOim3rWfd3UoRvWOs3U0ts0Q0vvk9STiGkxxLQYoqkoIqkIopp1G06GMRYfw1h8DKPxUSs00xO4ELmAC5ELwKX8X5fOYCdWNa3CyqaV6aW3qRcrGlZAYVsiIiKqFZIMSI2Ar7H0r2WaduVXNKfya6bAK18o5gRhsex90ks0c+tcltST1hIPl+69CWKeQCs4QxiWp9Jr1v2Cmcoxka2SiSpVPQYhlcr9+wKm1yVUHkhXULla/BV3BpUKIH9A5W7xV4hNXdbnj5NDkSKdHRER1TIGVER1QJBlBC6/HIHLLwc+/GEAgDY8bAVW+95E/PBhxA8ehD42hsTRo0gcPYrwT39qP1mAunKlVaG1YYO1rF8PdWUvQkoIIaWwD6mlYpomYloMo/FRDMeGMRAZwMXIRWuZsm4vRC5gMjmJS9FLuBS9hJcHXs46hiRIWNawDKuaVmFV8yqsalqF1c2rsappFToCHWwdSERENBNBsFvzlaFCzB2GuUOrrNucdcnozNuy1tmPkxGrEgywZoYlp6yllOTcECsn5FJDObdBqzosa3sgzzr7ePwcQ0Q1oN7/JkvPoNINJLRyt/iLAgBWFtjizwmyhqcSmIin0OTnL4MSEdHMGFAR1Sm5owONt9+OxttvB2AFPdrgIOKHDiF+8JB9exDapUtInjmD5JkzmPy3f0s/X/D5oK5dA//6DfBtyIRXcmdnWQMdQRDS86l6GmceKDgeH8fZybM4O5FZ+ib6cGbiDGJaDP2T/eif7Mez55/Nel5ICaWDq5VNK7G6aXX6fkAOlPrtERERkSMrDGsr3evoqdmDr2Rk9lAsvW6W/bR45vU0u3IsNlqCNyPMEHYFMy0QswKvPCFX7nr3Os4CI6IyqfdfGsxq8ZcsQYu/YCagMk0z/fWOp3RcDFv/ZhXa4q/Jr6CjwYfhqQTODEewraelaOdJRES1hz9REBEA6wO/0tUFpasLjb/zO+n12vAwEseOIX7sGBLHjiNx/DgSJ07AjMWQOHQYiUOHs44jNjRAXb0avjWroa5eDXWVc7sSos9X7reV1uJvQYu/BduXbM9ab5omhmJDODtxFqfDp3Fm4gzOhM/gzMQZnJ86j0gqgoMjB3Fw5OC0Y3aFuqzwyg6wnPCqK9QFUWC7HiIioqokKYDUDPibS/cahuFqZ5hb6RV1rYtktiUj+de7n+/skw7ATDsoK1GbJUktoLormBVumcPLAdgXOY89kb2vsyghQFZLc85EVPXqr8FfJqBK6AYkrRQt/qyAKqWbiKV0BFXrcuHQZMJ6fUlEW6jwv5fXLAlheCqB0wyoiIhoDgyoiGhWckcH5I4OhG64Ib3ONAykzp1D4tgxJI4fT4dXyTNnYExNIb5/P+L792cfSBCgLF9uhVWrV8G3ejWU3l6oPT1QurshqN5chBAEAZ3BTnQGO3FN1zVZ25J6Ev2T/TgTPoPTE6fTwdWZiTMIJ8IYiAxgIDKAFy++mPU8v+RHb1Nvet7V6ubVWNawDEuDS7E0uJTzroiIiOqdKGbCmFIw9AJCrjmCr7zr7cow02ovlZkFNl74uSX/GMDN1v2Hf3/m/UTF/ho1WCFW+r6r5aHzWA267tsBVzrwcj2fVV80T/U466hS1Xf9FKBI1lcgqRnpr4VPLl4FVUiVIIkCdMNEOJZKB1SjkSQAoC2kzquKbU1HCC+fHuUcKiIimhM/nRPRvAmiCLW3F2pvb7pFIAAYySRSZ88icfo0kqfPIHnqFBJnTiN56jSMyUmkzp1D6tw5RJ7NbqMHUYTctRRqzwooK3qs0KpnBZSe5VCWL4fc0QHBgwHiqqRibctarG1ZO23beHwcZybOTKu66pvsQ1yP49jYMRwbO5b3uO3+dnSFurA0uNS6DS1FR6ADbf42tPpb0eazbv1yGeZ5EBERUe0RJcDXaC3F5p4Flg6x5gi5XKGYcKwHGLOP1b19+r66dTEURsoKvuYTfhVC9tsBlyvQyg241Iacqq6ZQjLXfQ8+qxIVC2O4yueTMmFUKWZQCYKAloCCkUgS4VgK3c1WO3t3QDUfq+05VKeHGVAREdHsGFARUdGIqgrf+vXwrV+ftd40TegjI0iePm2FV6dOW3OtzvUjde48zHgc2oWL0C5cBF5+efqBZRly5xIoS7sgL10KZelSyF1dULqWQl66FHLnUsgd7RD95Qt0WvwtuMJ/Ba7ovCJrvWZouDB1ISu8OjtxFgORAQxGBpE0khiJj2AkPpK3baBbUA5agZW/Dc2+ZjQqjQipIetWCaFBbUCjat9XGuCX/fBJvvSiSip8kg9+yQ9ZlOu+bzsREREVgXsWWHABs8B+tBcYO2/d/6Nnpm/XkpmQKxkBklOuAMy+n29JzbDeOYap28ePW0uxZ37lBlqFBly+Bnudcz+UecxqL6K0ev9Rxmnx51bMFn+A1eZvJJLEeDSVXucEVO0NCw2opop3gkREVJP4iZeISk4QhHSrwOA12W30TNOEPjyMZP85pM6fQ7K/H6l+q9Iqee4ctIEBQNMyAdYsxGAQUns75LY267a9DVJb7m0bpNZWyC0tJWkrKIsyept60dvUi1t6bpn2XscSY+mwaiCauR2NjWI0Poqx+BhGE6PQDA1RLYroVBTnp84v+rwECPBJPiiiAlEUIQkSREGEKIiQBRmiIEISrXWSIKW3S4IESXQ9tu+nF3H6frnPVUUVASUAv+SHX/YjIAcQkLMfB5UgmtQmNKlNCMgBhmlEREQ1as5KDVm1lkBrEV/UtCqz0oHXfMKvqUyFV3of12PnHTmVYpGh4p237M8EWr7GTHjlDrXUkP240XU/ZD1231dDgOzjVX6qWkKdN/nLF1DlW7cYTfYcqnBsekA13wqqNUsaAACnhyIwTZM/3xER0YwYUBGRpwRBgLxkCeQlS4Crrpy23dQ0aMPD0AYGkBoYhHZp0LodGEDq0iC0gUFog4MwUykY0SiMaBSp/v6CXltsbLTCqtbWTHDV1gqptRVSaxukNve2Noih4KI+WAuCgDZ/G9r8bdjcvnnG/UzTxFRqKhNYxUcRToQRSUUwmZpEJBnBVGrKWpKZ24SemLakjwkTcT2OuB6f8XUrhSRIaFQb0aQ2Zd02qo1o8bWg1d9qLb7WrPsMtoiIiCgvQbDCGdm3sKqvmZimPZcrmhNeTc1Q2ZUnHHNuE1NActK677Q5dKq9osPFOV9RzhNeNeSEXnOFYY3Z7RBr/LMXW99RpZBEIT0jyqEUubVoc56AamSBAdWKtgBEAYgkdQxNJdDZyPb1RESUHwMqIqpogixD6eqC0tWFwAz7mKYJY2oK+sgItNFRaCMj0EdGoY26bkfHoI+OQBsZhT4+DhgGjMlJazZWX19h56IoOUGWdV9qa7Wqtlpas0OtlhYI0vz7gguCkA5kVjatnPfzHaZpImkkkdATSOpJxLU4UkYKhmlAN/X0rW7o09YZhgHN1PLuk++xYRrQjOn7a6aGlJ5CTIshpsWskEyzlpges261GCKpCCaSE9AMDbqpYzwxjvHE+Lzer0/yZQVXLb4WtPnb0qGWc7/F14JmXzOafc1QpeJX0REREVGdEAS7dV8QCHUU77ha0g6upuzgKmKFVwl3qOVsy90vZ71TIQYAhgbEw9ZSFEJ228JZK7nssMvXZN+31/sa7X0aWeFFs8r61qjT5FCRMgGVJAoQxeL+/9Lgsy4RRhNaet1oxPqlx7bg/H5u8skSlrUEcG4shjPDUQZURESz2LVrFxRFwe7du7F7926vT6fsGFARUdUTBAFSYyOkxkaoq1bNub+p69AnJqCPjUEfHYU2NmYFWGNj0MdGoTn3R0ehjVvbzHgcZioFbdCq2ErM+SoABAFSUxOkJR1Qli2zl+VQlmfuy0s6IJRoqLYgCOl5VNXANK0qr4nEBCaTk5hMTWIiMYGJZGYJJ8IYi49hLD6G8cR4usrMCeIGIgMYiAwU/JoBOWCFVWpzOrRqUpvS95vVZjT5mhBSQtYihxBUgggp1q0iKiX8ihAREVFdklVAbitetZeh56/UKij8igCJyenhF0xrcfYvBlG2w6smV3jlDrOaXEGXK9jK2td+ruwvSthVL3HZ9WvbvT4FKoAiiYinDPt+8b87g6r1y5WRpJ5eNxqxqqna5jmDCrDmUJ0bi+HU0BR2ri5i9SoRUY3Zs2cPenp6vD4NzzCgIqK6I0gSZLu1H9asKeg5RjQKfWzMCq/G7fBqhlBLHxuDHg4Dpgk9HIYeDiN54mT+c1EUKMuXQ121Curq1fbtKvhWr4bU0VFXLesEQUjPp1oaWlrw80zTREyLYTQ+mg6txhPj6faI6XXxcYwlxhBOhDGRnIBhGunKrvmEWm6qqKbDqqASREAKQJEU+CQfVFGFKlmLM//LJ/nS62RRzprXlb4Vsx/nW5f3Vpxje85xcrfJogxVUiEKpQlMiYiIyCOiBPibraUYTDMzlysxmRN+zVT5Za9LTNj37fWJSasNImBXeI1by2Klw67GPEFXvnArN+iytpuGPvdrVbGX7n0LTlyawg0MqKqCImU+pxe7vR8AhOwKqlhWQGX9amb7PFv8AcC6zgY8e3wYp4YjxTlBIiKqSQyoiIgKIAaDEINBKMuXF7S/qWnQx8ehj40hdekSUhcuIHX+vHVrL9qANTsreeYMkmfOAE89lf2aDQ1QV62Cb/16+DZugH/jRvg2boTcxt8+cxMEIR0Q9TQW9hsnhmlgMmlVaIWTYYQT9uK671RshRNhRLQIoqkoIqkIIqkIUob1m4RJI4lkIomxxFgp32JZqaIVqPlkX7oCL7241vklv/V1l4MIKAEE5WD6cdb9nFtFYtUZEVE9u35NO376xnmvT4MWQxAyM6gaOhd/PEPPDrsSk5kl72N7XXIyO+hyV3MVK+xK/gmA6637f7fFFWDlC7qaMoGYvzlz371enH/771Ja2uTH0qbqaL1WR7+3NyPJ1dJPkYsfUGUqqNwt/pwZVPPvyrG8xWrSfzFc+XOQiYjIOwyoiIhKQJBlyB0dkDs64Fu/Pu8+ZiqF1OAlpPr7kDxzBonTp5E8bYVVqfPnYUxNIX7gAOIHDmQ9T1rSAf8GK6zyb9wA/+WXQ129ekHzruqVKIjpNn4rsGLez0/pKUQ1K7CKpqKIaFZwldASVmilZ2Z/JfVk1iwwZ5t7ZteMt4ZR0D4zbZ9rn3yShnW+k6nJxX6Z85JFecYwK6BYFXR+yQ+/bC0BKQCfbAViATmQXp/ex3WrSioUUYEsynVVfUhEVE3ef3UPAqqEK3tbvD4VqhSiBPibrGWxDCMTas1YtZWniiuZE3TZtx+WfoNfGtfjWuEQEO5f/Pm5q7qyAq0m1+M51qsNQIlahFNlU9wBVQlb/EUTmZ8TRtIB1fwrqJzwc5ABFRERzYIBFRGRRwRFgdqzHGrPcoSuvz5rm5FMItXXh8SpU0gcO47E0aOIHzuKVF8/9KFhRIaGEXn++fT+YjAI/5Yt8G/dgsDWrfBv2Qpl+TJepC8RRVLQLFkBVzVzB1aaoaVDtLgWR0JPZC9a9uO4FkdMiyGqRRFNRbNvnfuu9UnD+uFWM7T0TLFSkgUZsiinA6v0raRAFnJu7XaLkihBERRIopRue+hsc9+695dFGbIgW/ddt7Iop4+jiDnHdPab4ZiznoN9fP6/TUTVShQF/O72ZV6fBtUqUSxq2HVDcgq/HR5Fp7IWSP2OK9zKreJyhV+JSes2PpHZptsTbJ0qr8mLizgxwRVyLSboClVVWZLgmghmwvTwTLwjuUIpuQQhZVC1LhE6FVQp3cBk3Lq/kBZ/3c1WQDUwwYCKiIhmxoCKiKgCiaoK37p18K1bB7ztben1RiSCxIkTiB89isTRY4gfPoz44cMwolFEX34Z0ZdfTu8rtbXBv+VyBLZvR/CqqxDYtg1iKOTF26EK5cyiAgCf5ENIKd33R8pIWYGWHVrFUtPDrUgqgoSeQEyLIa7H00FZTIshrsXT65zH6aBMj0EztKzX00wNmq4hrtfmD8TOfDKf5MvMPbPnmzltGmebgebsM9PMtKz7s8xTU0SFYRkREdUmO+xa1lOEsEtL5A+uZgq08q6fsFoXwrS3L/KXfQQxJ8jKDbScdU2ZtoVO0OVvAvwt1n2pPJeV+HEjO5RSS9DiL+SzK6jsGVRjdvWUKADNgfm36XYqqAYm4jBNk58ZiYgoLwZURERVRAyFENi+HYHt29PrTE1D4uQpxA/sR2z/fsTf3I/4sWPQR0cReeZZRJ551tpRkuDftAmBq65C8OqrELjyKihLizA3gKgAiqhAURU0qUW4yJOHZmhIGan0bUpPQTM169ZZ59ruXqcbuhVoGRp0Q4du6un1uqlDMzRoprVNM7TMOvv+tMdGznNct+7XmPYcU8s6vrMuH+d1k0YSSJXkS1qwaWGXE2aJ1n13BZt7Sa+bodpt2n456xRRST93rtfI2mZXsBEREZWN7LOWUMfCj2GagBbPhFjxcPZ8Lie0SkzOHXaZOmAa1jHi4cW9NyXoCq2aM/ez1uWGWzn78d/lgsglb/FnXSKM2hVUTnu/1qAKUZz/6zkBVVIzMBZNLahNIBER1T4GVEREVU6QZWsW1cYNaHnf+wAARiKBxJEjiL25H7G9exF943VoFy4ifvAg4gcPYuz73wcAKMuXpwOr4DXXQF2zhr/ZRlXJCR9qjWma6TBKN+zgzHXfmXHmnm+WMlLZM9CcmWizzEdL6amsfWY6lrPOzZlb5nVQNh+iIGaFWzOFWbnB1mwBmlNRlr4VVSiSkr3Orl5TRCVrm7NvvufU4vc1EREtgCAASsBaGhbxS2amCaRihVVvpcMue594OHM/FbGOl4pay9TAws9JbZgl3LLuC0oTALs96MgpYHDS3tZcN3O5JLG0Lf5yK6hGFzF/CrCqvNpDKkYiSQyE4wyoiIgoL/7ES0RUg0SfL1Np9e8/AgBIXbyI6OuvI/b6G4i+8ToSR44idf48UufPY+KRRwAAUns7gjuvQejaaxHcuRPq6tUMrIg8JAiCFYxABirkl4sN00iHY4WEW041m1MxllvJ5l7S68zZ9532XFe1nGZOf16+95A0k+nZaJVMFMRpgZcTcLmrydwBmzsAy7uPtIDnzLLdfS5O21AiIqpQggCoQWtp7Fr4cXTNDqtcodW0ICucP9xybrWYdSxnNhfOz3zapgDgIevBc38HvPiUe2tOuJVTsTVtXXPO0gIo/oV/LcpEkTL/xiolaPGXnkGVsD47LTagAoCuZj9GIkkMTsSxedncnRRSuoEfvtyHlG7iEzesWlDlFhERVRcGVEREdULp7kbzO9+J5ne+EwCgT0UQ27fXCqxeew2xN96APjKCyUcfw+SjjwEApI4OhHZeg+DOnQjuvBbq6lUMrIjqnCiI8Ek++CQfGtHo9enMyalCmyncSpmpaSFavn3zhmWuIC2pJ9O3mqGlK9RSesq6tVtPpvdzb9NTWRVxboZpIKEnkNATVVGlJglS3uBrptBLluT8YVmeMM1dWZZVhWaHd87zs/aZ4Tmcn0ZEtEiSDATbrGWhtGQBQVbYXj8B7LWeZoaWAFKHtV1PwprLFbaWBb8fX3ZoFWjJH2Q593Pfh1z66iB3BZVSguAmqGZXUI3HrA8ercFFBFRNfhy8MIGL4cLmwt7x/z6DU0NWdd5XfnEIJ/76HZAl/vILEVEtY0BFRFSnpIYQGm68EQ033ggAMJJJxN98E5GXX0b05VeswGp4GBO/ehQTv3rUes6SDoSu2WkFVtfuhLqKgRURVbZ0FVqVtMozTTNdEZYbfDnBlmZktufOVHPWZ63Lmc+Wuz19PDMFTZ8+sy091y3nmCkjBcM0ss5fN3Xoug7oHn0B58EdiOWGWk7LRXdIllvFNlPwlbUtT4g263Fy2jvy31giqmmyCsgdBc3mEgwT2Psr68Fb/grY+YB1PxXPDrIS4enh1rTAyw7BYuPWrWkAegKIXLKWgjycuftfl1izuHJDrELDLn9zQXO4smdQla6CygmoonYlVdC38DL+pc1WZdrAxNwB1X9//Eg6nHLc+PUn8dK9ty/49YmIqPJVx0/qRERUcqKqIrhjB4I7dgCfsQOrffuyA6uhYUz86leY+JX1w6G8ZIldXbUToWt3Qlm5khfTiIgWQRAEKIIVVFQD3dCzArVpoZa5+DAtd/+Ztrkr1fLtp5vZqZlTCRdDzKOv3txmqhZzh1nuOWY+yQdVVLPWq9Ls2/I+nmEbWzgSUcVR/Nay0LlchmG1F3SCq/i46769xPKsO5tzHGcW1+TFhZ2H2jhnkCXHlgGwPh8oegQY77NncDUWZQZXZgaVFUxF7KAqpC780mFXkxVQXZojoDJNE/ftOTlt/eBEAoMTcSxtqvwWjEREtDAMqIiIKC9RVRG85hoEr7kG2A0YiQRi+/Yh+vIriL78MmJ790IbGsLEL3+JiV/+EgAgL11qB1bXILRzJ5TeXgZWREQ1TBIlSJDgk3xen8qcdEOfNdTSDC2r5aKzONVr7haOWc93qt1yn6vPfIyZ9jFhZp2zsw3TR6l5wqkAyxtiudflbFNEOyArMAgrZJtUQLUBEdWOkv1IIYrWbCp/E4AVhT/vS7/M3P+zMzMHWXMFXim7Yig5aS3h/hlfUk7+OYCtAACl/7fA//tea4Mg5p+tFWjJU9FlPw60Zrbb87ecCqqUbiKpGYjZQZXT+m8hnPlVzjyrmTx+cDB9/7H/eDMafDJu+voeAMBn/+kN/NMfXr/gcyAiosrGgIqIiAoi+nwI7dyJ0M6dAHZbgdXefYi+/HImsBocxMQjj2DikUcAAHJXVzqsCl57LZSeHgZWRETkCUmUIIkS/Kjc38J2V5Olw7A81WFJI2kFXXrmfkJPZGaZ5Tyets1IpGeiOc+fdj/PTDSn6iyqRT36CmVIgjR3sFVoADbTfgWGcNXSQpSomlX0zxCBVmtpXcBz9ZQ9Y2t8epiVE3jJR1qAKetpiiwBsh/Q4laLQuf5CyEHgEALgr42AH8BAIj+9D8icmE7gC4Eh/YCbx7IBFrpcKsZkGav+HYCqrHo7AHVQy9ZJWk3rmvHpq4mAMCnb16Nbz97Gi+eGoVpmpX9PUBERAvGT9JERLQgos+H0LVWaz8AMOLxdGAVefklxPa9CW1gABP/+ggm/tUOrLq7Edp5DYI7r7VmWPX0ePkWiIiIKoosWvPSAgh4fSoArJZL6WArT3g12/2EnsgK0Ao+Ru4+rnXuCjPd1BHTYohp3rdoFAURPsmXabtYpCqx+eznk3zWnDOBs8uIqoqkAKF2a5mD/L1XgMPWjCx509uBD91rzeDKW7U1PkPYZd/Gxqxb0wC0GDAZgzJ5ESpSSEJB5MAvEE2FAHQhePxfgdO/yn9S6daELdZtTojVOtkJoA2j4+PA+dcz4ZavOast4ZGBSQDAPbeuTa/7/Fs34tvPngYAXAzHsaylMv5tJCKi4mJARURERSH6/Qhddy1C112LJfgTGLEYYnv3ZmZYvfkmtIsXEf75vyL8838FACjLlmXPsFq+3ON3QURERA5BENJBiNdM04RmaoWFZAUEaAvdljJSSOgJGKaRPjfDNKywrALmmQkQZgyv8oZe82jVOOM20fUarseyyLCMSsOce5eaJImZ/58UyQ53nBlcjUvnf0DDsNoKxsbsAGscgQenkEwCseu+gOihJmAICHZvAhqm0vsgNg4kJqxjzNGasM3oAfDfMDYeBr79QdcWId2G8LS0GkOTnwYAXHXwb4CzjUCgFQF/C9Y1t+NEWMA3f/4i/vbfrbUCLrWhhD0fiYio3BhQERFRSYiBAELXX4/Q9Va/cCMazQRWL72M2P79SF24gPDPfobwz34GAFCWL7cCq2t3Inj11WwJSERERACssEwRFCiigpAS8vp0rJll86j+WkyAltKtUCxvYKanoJmZIWUmTCT0BBJ6wsOvTka+8GqmSjBnTplP8sEn+7IfuxZVUuGX/DNuc99yVhnVElnKVBzJYhF+RhLFzGwquz2h3/9rhJMJxLd9BNHzR4ChYQRv+APgyi9nP1fXrJAqHW7Zt7GxTIgVG0frZBw4CIyhEXpjD6T4KJCKAjDTlV7fS90CAFgrnEdo7wNZL3Nr6iM4gTtx8sg+4PS77fOWM3O00pVbc9232zAqldvml4jq165du6AoCnbv3o3du3d7fTplx4CKiIjKQgwGEbrhBoRuuAGAFVhF33gD0ZfsGVYHDiB1/jzCP/0pwj/9KQBAWtKB4JVXIXDVlQheeSX8l10GQfX+t7iJiIiovjntGINK0OtTgW7oBVWCOdVfiw3Q0oFZnm2aoWWdW9KwZ5mlvPnayKI8Y4hV0OMZgjLnsTsoy92H4RgVmzuUEkv0S3w+2fq+TWgGoknr/+egmufSoSQDwTZrmUWrbgB/8ShMCAjf84Y1k0pL2iGWFWod/pdhYAC4a3MIWPHnWRVdd42P4H/3AcexAobog2gkAEMDoiPWMl9KMBNWOUuwzfW4Lc+6VkD2zf+1iIgKtGfPHvTU8QgMBlREROQJMRhEw403ouHGGwEARiSC6OtvIPqyHVgdOgR9aBiTTzyBySeeAAAIPh8CW7cicOWV6dBKamnx8F0QEREReUsSJQTEAAKy9/NZDNOYVuHlVH/lVoLlqwxzgi+nCsxZknoScT2etc+0x5r12F1RphkaNENDJBUp+9dCFmSr2kvODrFyQ63cajC/7E+HX05AlvvYue+X/FnP8Uk+dh+oYe4Wf2IxKqjy8MlWlVZC0xFN6gCAUL6AqkCKJKLJL2MirmE0krQCKlkFGjqtBUBf7DcA4rjm1juB3tas56/WDahffhxTmh/HPn0cm9qVrHBrXvdNw6reSkWBifPzfCMhO7RqmSXIyrNOUhb8tSMiqhcMqIiIqCKIoRAabr4JDTffBAAw4nHEDxxA9I03EHv9DcTeeAP6+Diir76K6Kuvpp+nrlmDwPbt8G/dgsDWrfBt3AiRVVZEREREZScKIvyyH35410bL3X4xX6iVLwRLP9byrCvwcUJPZFWQaaYGTdMQ1aJlff/5Aq3Zwq8Zw66cdU6YlnscL2aNmXU6hEoRMy3+XN3+isqnOAGVkQ6oAuriqgHbQiom4hrGoslp2+IpHRfDcQDAyvbp7VsVScQNa9vx1NEh/GzvRXzpHZsANQg0LZvfSTjztqKjdnDlWtLrRqevi4/bwVYECEdmnLU1I7XRDq1a56jScq3zt1gVakREdYJ/4xERUUUS/X4Ed+xAcMcOANZw9OTp04i98Qair7+O2OtvIHn6NJKnTiF56lS6LaCgKPBt2oTA1q3wb92KwLatUFevhiCW6Kc4IiIiIqoYXrZf1A09K7wqpBosvbjCsbgeR0JLpAM19/24Fs/sZ9/XTT19Ds62CUyU5T2Lgpg36Jqpwmu2bbkBWu4+jnotEpOkzBuXSt3iL5Vp8RfyLS6gag2pODMSxcjU9IDq9bNjAIAGn4zWYP5qo9svW4qnjg7h2ODkwk/CPW8Lqwt/nmEAibAdWo3NHGTlrouHAZhWKJacBMJ98ztfX1OeFoR2S8Vgu32/1XW/HVBD9fs/BxFVNQZURERUFQRBgG/NGvjWrEHL+94HANDGxhB7Yy/iB/Yj9uZ+xPfvhx4OI77fuu8QQyH4L78c/q1b4L9sM/ybL4O6ciUEibMBiIiIiKg4JFFCUAyWPRxLGal0m0N3wJUv7MoXcLnXzfV8577DMA3EtBhiWgxIlPqd/g0A4CsvfAX/vxMH5gy4csOudBVYTvjll/0IyIH044AUSK/3ST6IQmX8ops7lCpV1Vq+Fn9BZXGXDtuCVneLfBVUjx8cAABsXd4843ta02FVVp24NLWo81gQUcwERbOP28pm6FZINVOQlVXJNZoJwBJh6/mJCWsZP1v4a0qqK7CyFye8ygq22u2KrjYrsGOoRUQeY0BFRERVS25tRePv7ELj7+wCYFVZpc6dQ+zNNxHffwCx/fsRP3TImm9lz7ZyCH4/fBs3wL/pMvgvuwz+yzbBt2EDxID38xuIiIiIiAqliAoUUUFImd4irRRM00TSyARiTmiVbqM4VyA2S4XYtMf2/ZSRSr++buqYSk1hKlWewCIdZElWkOWX/QA+nt7+p0//aXq7E2w592cKvrL2t2/nCsLcY6ekEs+giqdcAdUiK6haZgmo+sdiAIB3bO2a8fmXL28GAPSNRjE8lUBHg2/GfSuGKGVCovnQNTvYyhNkRUes9dFR+769LjoK6AlATwKTF62l4POU7SqtQoMtuwUhu5MQURExoCIiopohCALUFSugrliB5ne+EwBgahoSJ08hvv9NxA4cQOLIUcSPHoUZiyG+703E972ZOYAoQl21Cv5Nm+DffBl8GzfCt2ED5M5ODpwmIiIiIoLd2cCuWioX3dCx9t7HAACfu/pzePu2xhkrxNzzxgoJy+KavbjuJ41MmOI8J4xw3nN77MxjRXmP7kqwgByYFngdubAFwAoAwOuXXsX/fOPJ6fu6n+MKxXySLx2WSeLMgZPT4m88mgkEg4ucQeW07nMf0zFgz59a0Tpz1WFzQMHm7iYcujiBPUcu4fd3rFjU+VQ0SQZC7dZSKNMEUtFMWJUbXsVG829LRQFDAyJD1lIoQbRCqnxVWblhlnM/0Mq5WkQ0I/7tQERENU2QZfg3boB/4wa0vP/9AABT15E824fEkcOIHz6C+JEjiB85DH1oOD3TauJXv0ofQ2xqgm/9evjWr7NvrUVubfXqbRERERER1Q13qNLsa8aq5t6Svp4zT8wdWsX0WPr+hw+Pp/f90s4vIabF0sFXTItlB145IZh734Se6YuYnh+WzD8/LD7ZACegOjiyHyfeXFgwporqjGHWoZFrASzDz4/9GkAPAOB/H/gHBJTs/XODr3xVYc6fWYsdUIXzBFSXJq33v6Rx9rBz5+o2HLo4gZNDkQW955omCNb8KTUEtMzj/4tUPE945VRo5T62g63EBGAadoXXKDByovDX8zfnD6+cYCu9dGT2mSVMJaLawYCKiIjqjiBJ8K1ZDd+a1Wi68870em1oCPEjRxE/fBjxw4eQOH4cydNnYExMIPbaa4i99lrWceQlSzKB1Qb7du1aiKHytFchIiIiIqo35ehrMPc8sV+m7334sg8v+HUM05g5zLJDMacK7F9eMPHcqPW8rR1bcNXGlmkBmrNvOiiz77uDsKSRRDKZxASmB2HxWA+AZTg+ch5ADyAkcf/+f1zQe1NEBQE5gMTodQB+B0+cfhYDj37D1fowgJGpGwEIeOTsQ3hxVEVADqQXv+RHQLHu+/xWpdXxoTFEUpGsAIwWSPEDyjKgaVnhz9GSOe0G3VVaucGWvS0+bj03HraWsdMFvpiQaT8Y6sgEWKGOTIgVcoVaoQ5AYbt+omrEgIqIiMgmL1mChiVL0HDzTel1RjKJ5OkzSBw/jsSxY9bt8eNInTsHbWgI2tAQIr/9bdZxlJ4eK6xatxbq2rXwrV0H39o1EIPlHZhNRERERESVSxREBJXZgrCMoycO4zmcAgDcsPw6fOG6jQW/jmEa6aqt3DArXfGlxfHQsxqeDwPrm7bj0DjgVyTcvfHuTHWYq4pspgoxR8pIIZVMIaVb7eMm4km8fun1zDmlmmDiJgAGHjr2jxAEY8bz1yYvA/Ax7Dm5H9c9/IcArEqwgBJIzwbLtzhhmPt+UA5mrc/dzwnHGIDlIatA41JrKZSuWSFVvvAqt4IrOgJEhu1Qy3RVah0v7LWUoB1W5VRjhdz3nYCL87SIKgUDKiIiolmIqppuEQi8M73eiESQOHnSDq6s0Cp+/Bj0oWGkzp1D6tw5TO3Zk3UsZflyqOucwGptOsCSGhrK/K6IiIiIiKiaiK6ZuOI85+OKgpgOX2bzxuFDeB6n0e7rATCMZn8Qf3HdXxT8OoZpWDPA7AqumB7DniMj+MpPL2Flwyb8+W3ftNanYnj1pIF/PgF0tCTx3svustbbS/r59hI2NcQAGMkOmKYIQTCsSrBEcsbZYIvlbl+YW9UVlIPZwdgsQVluSOYEZKJQJ8GIJFuhUKij8OfomhVMRYbt4GrYvj/quj+SHWoZKWuuVrjPWgohSDmVWW2Zaqys9a6AS1YX9nUgohkxoCIiIloAMRRCYNs2BLZty1qvjY1ZgdXJE0ieOGmFWCdPQh8eRur8eaTOn0fk6WeyniN3dWUCq3Xr0hVXUnNzOd8SERERERFVKGkRAVWhfLJVNRSOpbIeF0oUxPQ8qha0AACGljQDuAToAbx15W3pfS+dPw7gGH5n3Trce+37Zj2uphu49qu/wUgE+NaNj2HLCn9WkBXVolmhVm7AlRV+2VVjsVTOY3txODPBkJjlxBbBmdc1UxXXTCFXUA5at0ow7+OAHIAsVvnlXkkGGjqtpRCmac3Hio4AETvQcoKrqB1sZYVdI0ByEjB1IHLJWoYKPDdfU552g205rQddYZev0ZoTRkQzqvK/sYiIiCqL3NoK+dqdCF27M2u9NjaG5KlTSJw4icSJE0iePIHEiZPQLl2CNjAAbWAAkeefzz7WkiWZiqt1a+FbawVYcmtrOd8SEREREVHFqNdrvaLrfUslKr5R7ANPxTUAgF9Z/As1+hUAwIR9TMdoNAkAaG/wzXkMWRKxeVkTnj0+jIGwhpvXt6IVxf+ZyDTNrLaHM4ZcOWFYVIvOuC03GHPEdasl4nhivOjvQxXVGQOsoBxMV4HNuj7PY0VSin6uRSEIgL/ZWtrWFPYcLeEKsUZy7g9PD7uiI4BpWEFYYqLwWVqSmqnKCi2ZfhvMeaxyLADVHwZUREREZSC3tkK++moEr746a70+MYHEyZNInjxphVcnTyJx8gS0CxfTM66iL7yY9RyprQ2+deuyZ1ytWwupvR1Cvf7ETkRERERUw9yf80WxNJ/5Zck67lTCCpPmW0GVT6PfuvQ4GU9lrR+LWAFVW7CwlmnLW6z2hOfHY3PsuXCCIBTUCnGhDNNIh1ZxPT6tiiu3EixfUBbVounnOfejWhRRLQrDtOZ4Oe0Pix1+yaJcUBVXQY9d4ZdP8pX/51jZBzQts5ZCGIY9S8tdmTUyc+vByDCgxQA9CUxesJZCKEFXiLUkJ9xaYs3Tcm9jy0GqAQyoiIiIPCQ1NSF45ZUIXnll1np9KoLkKSe0yrQLTJ07B310FNGXX0b05Zezj9XcbLcIXOsKsNZB7lzC4IqIiIiIqIpJYulb/CnTAqrFV1A12RVUCc1AQtPTodeIHVC1hgq7wL7MDqgujscXfU5eEQURQSWIoFL8KhnTNJE0kunAKqbFEE1FM0FWzuPZtuU+1gzr+0EzNEwmJzGZnCzquTsz0goJtoJKMH0bUkKZdc562VofkAOQxMUHrJmTFO22fW1Ax/rCnpOMZoIrJ9SKDNmLs34oc6snrDla433WUghf88zVWblVWsE2oJhfE6IiYUBFRERUgaSG/DOujGgUidOnkTxxIqviKtXXDz0cRuy11xB77bWs54iNjXZ7wOx2gXJ3N4MrIiIiIqIqkNXir0Sf4WXRCqSiSR0A4CtCi78Gf+bS42Rcg6/BukA+5rT4KzCg6m72AwAuhEtXQVXNBEGAT/LBJ/nS87+KJaWnMsFVvkArTyiWt9or53kJ3RrwZZgGIqkIIqlIUc/bL/mzAi0nvHICL3fA5dwPKAGE5NC0/YNyEH7ZD1GYx/8TahBQe4GW3rn3NU0gOZUdWOXeRnNCLVMHEmFrGT1ZwAkJmRlZ0wIsd5WWvc3fXL89VamsGFARERFVETEYRODyyxG4/PKs9UY8juSZM9kVVydOINnXB2NyErG9exHbu3fasdTVq6GuWQN19Sr41qyBunoN1FUrIfrm7gVPRERERETlUY4Wf04FlcNfhBZ/kiigwSdjKqFhMq6hw545NTI1vwqqFW1W1dGRgUkYhlmyrwFNp0gKmqVmNPuai3pc3dCzw6xZKrkiqUh6fSQVybQ2TEUR0SKIpqLpx7ppBazOnK9RjBblfAUI6Youd6XWtGquPIFY7v7O43R7Q0EAfI3WUsgcLaflYMRVlZUVYA1ZM7Sc+7ExAKbdmnAYGDoy92tIqivI6rTuNzgBVqe1vqEzE3JJjBloYfidQ0REVANEvx/+TZvg37Qpa72RTCJ55kzWjKvkyRNInDkLIxpF/OBBxA8ezD6YIEBZvhzq6tXwrVlthVirrRBLXsJ2gURERETkHQH1+Vk0u8VfaV5DkbKrQ4pRQQVYc6isgMqaQ6XpBgYnrFZ9TmXUXK7sbUFQlTA0mcCJoSlsWNpYlHMj70iihAa1AQ1qQ9GO6bQ6jKasIMsJrdIBlh1wuQMt962zzQnFnG2m/Z9zjGIRBXHm1oU57QvztjRUggg2dyLUsTr9WJXyhL66BsRGZ2gxOGTPzXJtS0xY87MmzltLIQJtmcDKWdyB1srrgUBr0b52VDsYUBEREdUwUVXh37AB/g0bstabqRSSfX1Inj5ttQw8dRrJU6eQOH0axsQEUufOIXXuHCLPPpt9vIYGqGvWwLd6lR1aWSGWsnIlRJUDWomIiIiISiGrxV+JEio5N6AqQgUVAIR81uXHSMKqbBmaSsAwAVkU0hVVc/HJEla0BnF0cBKDE3EGVJSXu9Vhq784YYhpmlktC90VW3OGXc5z3GGZXR0GWO0Np1JTmEpNFeVcAUAW5XSg5VRwhWTXfXt9KBRCsHktQsq27G1yCCFBRjAZRSgxBTk2BkxdcgVYQ/bjYSByyQq3TMMKwGKjM1dnffpJYPnVRXufVDsYUBEREdUhQVHgW2vNonL/aGeaJvTRUSu4OnXKCq7sECt17hyMqSnE33wT8TffzD6gKELp7oa6shfKypVQe1dCXdkLtbcXyooVbBlIRERERLQIorvFX4k6GuS2+PPJxamgCqpW0BVNagCAC+NW9dTSJv+8wrYljT4cHZzE0GSiKOdFVAhBENJVSwgU55i6oSOux7MCrnQbw9zqr5m25VSHOTO9NEPDRHICE8mJopyrT/Klq7dCSgihYAjB5uUIKRus9VIAIYgImQZCuoZgKolQKoZQMoJQfArBWBih6Bhagx1QinJGVGsYUBEREVGaIAiQ29sht7cjuGNH1jYjmUTq7Fkk7NAqefpU+r4xNYXU+fNInT8P/PaF3INC7uqC2tsLdaUVXCm9vVaI1bsCYqBIn/KJiIiIiGqUO5QqWQWVmB1I+ZXiVFA5AVUkaVVQDYTn197P0d5gdWwYjSSLcl5EXpFECSHRqmQqlpSRSodYTpAV0SLpECySiqQXJ/TK2qZl75cyrJacCT2BhJ5Y3CwvH/CwNoGtRXqvVFsYUBEREVFBRFWFb/16+Navz1pvmia0oSGk+vuRPHPWah3Ydxaps31Inj0LIxKBdvEitIsXEX3ppWnHlTs7rUqrlZnQSlnRC7V3BaSmpnK9PSIiIiKqBvU5giqrxV+pZlDJORVUuRVVCxVSrcuPMbuCajxmBUytofm1CG8NWvuPMKAimkYRFSiqgia1OD9Dp/RUwSHXbEFXNBVFRIsUNYyrNbt27YKiKNi9ezd2797t9emUHQMqIiIiWhRBEKB0dkLp7ETw6uye0qZpQh8bQ/LsWaT6+pC0Q6tknx1eTUxAu3QJ2qVLwKuvTju21NJiVVutWGFVXtnBldrbC6mjA0KJ2psQERERUYUyvT4Bb4hi6Vv8qTkzqHJnUi1UMGcGVSRhBVUNvvldlmy3A63RKQZURKWmSApapBa0oGXRxzLNOv2Lu0B79uxBT0+P16fhGQZUREREVDKCIEBua4Pc1gZceeW07fr4uB1WWYFVqr8Pyb5+JPv7oQ8PQx8fhz4+Pn3mFQAhGITa0wOld4Wr8soKr5TubggyP+YQERERUW0oS4u/nIopuUivE1SyZ1BN2UFVyDe/FoJtTou/KAMqomrCXyyl2fDKDREREXlGamlBoKUFgW3bpm3TpyJInetHsq/Pah94tg/J/j6k+vqRungRZjSKxLFjSBw7Nv3Asgxl+TKoK3qtwMquulJ7e6H09ED0z6/fPRERERFViDq9zukOqEpVQZU7g6pYQVjQ5wRUVjAVtSuoQvOsoHJmVp0ejhTlvIiIyHsMqIiIiKgiSQ0hSJs2wb9p07RtZjKJ5PnzVnDV12/NvLIrr1L9/TCTSaTO9iF1tg/5fnyVly6FumKFNffKbhuorOiFurKXc6+IiIiIqOK4u+2JJaqgyp05pRSrxZ+aHVBF7EqqBnV+lyW39bQAAE5cmkI8pcOvzK8Ci4iIKg8DKiIiIqo6gqrCt3o1fKtXT9tmGga0wUEk+/ozLQP7+qwZWP39MCYnoQ0OQhsczD/3qrnZmnvlVF45AVZvL+QlS9iegIiIiIjKzv0ZVCpVBZVUmgqqgB0kxVNWQJVp8Tf/GVSSKEA3TIxHU+hqZkBFRFTtGFARERFRTRFEEUp3N5TubuDanVnbTNOEPj5uhVV9/VbLwLNWcJXs74M+NAw9HIa+fz/i+/dPP3YgYM+96rVbBtqVV6tWQVnWDUEszm+ZEhERERG5Zc+gKs1r5M6cKtYMKqfSKWYHVBG7xV/DPAMqQRDQElAwEkliLJpEVzPbdhMRVTsGVERERFQ3BEGA3NoKubUVge3bp203IhEkz52zK66c+VdWmJW6cAFmLIbE8eNIHD8+/djBIHxr18K3bh1865zbdZC7GVwRERERFUu91rK7Q6lSVfSrcvZn1mIHVE4F1WQ8BQBo9M//smRL0AqoxqOpopwbERF5iwEVERERkU0MheDfuBH+jRunbTOTSaQuXLCqrVwBlrOY0SjieSqvxGAQ6tpMYOVbtxa+9eut4IrtAomIiIioAGI5WvzlBFJSkUq1AukKKgMAEI5Z4VJTQJn3sZznTMQZUBER1QIGVEREREQFEFQV6qpVUFetmrbNTKWQ7O9H4vgJJE4cR+LECSRPnETizBkYMwVXzc3wb9oE/6aN8G26zLpduxaCqpbpHRERERFRtciaQVWkyqZcSk4gpZSogmoiZrX4a15AQOU8ZyLGgIqIqBYwoCIiIiJaJEFR4FuzBr41a4A73pZeb6ZSSPb1IXHipCu4OoHE6TMwwmFEX3oJ0ZdeyhxIUeBbuxb+jRvhu2wT/Js2wbdxI+TWVg/eFRERERFVCnfVVKmK8GUpp4KqSAFVQLWCr3RAZVc/NfkXUEFlPyc8S0AVT+n49P95Fc8eH562beeqNnznE9fMe/4VERGVBv82JiIiIioRwQ6cfGvXZgVXRjKJ5IkTiB85iviRw0gcPoL40aMwJiaQOHIEiSNHgJ//PL2/3N2NwJbL4d+yFf4tlyNw+eWQWlo8eEdERERE5AV3ViSWrMVfzgwqqUgVVHKmgiqlG4gmraCqKTD/y5LOcybiWt7tKd3Apv/02IzPf/nMKLZ8+XGc/OqdJatEIyKiwjGgIiIiIiozUVXh37wZ/s2bAbwHAGCaJrQLFxA/cgRxO6SKHzmKVH8/tIsXMXnxIib/7dfpYygrViCwdQv8l2+Bf8sW+C/fDKmhwaN3RERERFQe9TrDUyhDBZWa0+IvN7BaKL/qzKDSs1rzNS6ggspp8ReOJvNuPz44VdBxTg9HsK6Tn52JiLzGgIqIiIioAgiCAGX5cijLl6PxLW9Jr9cnJxE/fBjxAwcRP7AfsQMHkerrQ6q/H6n+fkz86tH0vurq1fBv3YLAFju0uuwyiIGAF2+HiIiIiIrIXewjoEQVVDkVU3KxZlClK6iMdOVTo09eUAXT8pYgAOD0SDTv9q8/diR9/8/evgkJzarWUiQRumHim/92DADwyL4L+NxbN8z79YmIqLgYUBERERFVMKmxEaGdOxHauTO9Th8fR/zQIcQOHET8wAHEDuyHduEikqdPI3n6NCb+9RH7yRL8GzfCv30bAtu3I7B9O9RVq+r2N4+JiIiIqlU5KqimBVRScSqoAnYFVTypp2dHNQXmXz0FABuWWlVPp4byV0o9fWwIALBleRP++La107Y/8OwpTMQ1/P1vjjOgIiKqAAyoiIiIiKqM1NKC0A03IHTDDel12sgI4gcPInbggFVttX8/tKEhxA8dQvzQIYz/8J+s5zY3uwKrKxDYthVSU5NXb4WIiIiIClCOXy9ScmdQFauCSrGOG9cyLf4a/Qu7JNkaUgEgq1WgI6Ub6fv/9x2b8j7/i3dsxF/9/OCCXpuIiIqPARURERFRDZDb29Fwyy1ouOUWAPZMq4EBxPbtQ2zvPsTefBPxAwegh8OIPPMsIs88m36uumZNusIqcMV2+NatgyDzYyIRERFVnnqtAxeyWvyVhigKEAXAMK3HC2nBl09AsSqoUrqJMXt2VNMC5k8BmWBrKqHBMEyIrnMcnIgDsGZp3bSuI+/z79zajb/6+UEIApDQdPjs9oNEROQNXnkgIiIiqkGCIEDp7obS3Y2mt78dAGAmk4gfPWaFVvaS6utD8tQpJE+dQvinP7WeGwwisGULAq7WgPKSJV6+HSIiIiIAgOn1CXhELEdCBautX1Iz7PvFqqDKhEDjUavyyWn7N19OsGWYQDSlo8GXubR5ejgCAOhu8WcFV27tIRVBVUI0qeOZY8N46+alCzoPIiIqDgZURERERHVCUFUEtm5BYOsW4CMfBgBoo6P4/7d35/FR1ff+x9+zz2RjCxAgEAiEhE32oCIgWqHazbqBvXXrrdVbvG1/9fba9udV23ur9V67/bxee1vqUrdqq9W2YhUVUFxAVtkSQgKBEAh79tnP749JJjPZyDKTyfJ6Ph48Mvmec77zPePJZDzvfL7f+k8/Vf3OnXLv3Kn6nZ8qWFurus2bVbd5c/hY2+jRcs2aKecFodDKOXWqzA5Hok4FAABgYInKp+KXUNkjAypzbNagclib+mmcms/WxfWtHFazrGaT/EFD1W5fVEB10+9Cn11LT9e1ebzJZNKQJLvqvPV69uNSAioASDACKgAAgAHMOnSoUi+9VKmXXipJMgIBeUtKGiqsPlX9jh3yHDggX3m5fOXlqlrzRuhAm03OvLyoqQFtmZlRC3gDAADE2kD9pBF53vH8uBVZeBSrCiqTySSXzaJ6X0BV7lBAZbd2rW+TyaRUp1Vn63yqcfulQS33mTg8ud0+br4oSw+9UaDTtZ4ujQEAEDsEVB30q1/9Sj//+c9VUVGhefPm6bHHHtPMmTMTPSwAAICYMlkscuTkyJGTo8HXXSdJCtTUyL17d3gtq/qdOxU4fVruXbvk3rVLZ599VpJkGTIkHFa5Zs6Uc8YMWVJSEnk6AAAA/UJP/RFQ5NR41hitQSVJSfZQQFXZzQoqSUp12nS2zqcqtz+qfWSaQxVVHv1ixax2j1+UM1wPvVGg3UerujwGAEBsEFB1wPPPP6977rlHv/nNbzR37lz913/9l5YvX679+/crLS0t0cMDAACIK0tKipIvvFDJF14oSTIMQ76jRxsCq9BaVu69+xQ4e1Y169erZv360IEmkxyTJso5c6ZcF1wg18xZckyaKJOFxagBAAA6w9TG41izRARhlhgGVI3rUFXVh0IlezcCqsZp/Wo80QFVTUNg1bhOVVvGDHaFH+8pr9S00a2UYQEAegQBVQf84he/0J133qmbb75ZkrR69WplZGTo+eef15133png0QEAAPQsk8kke2am7JmZGvT5z0mSgh6PPPv2hSqsdoRCK9/Ro/IUHZCn6IAq//SyJMmclBRax+qCC0KVVhdcIGt6eiJPBwAAoNeLLKCKZzVVZN/dqXJqzmUPBVThCiprdyqoQrczqxumC5SkQNBQrTcQtb0tg5KaAqyXtx4loAKABOr0b4OjR4/ql7/8pZYtW6Zx48bJbrcrIyND1157rTZt2hSPMZ7Xs88+qzvuuEPz5s2Tw+GQyWTSU0891e4xn3zyia666ioNHjxYycnJuvDCC/XSSy+12M/r9Wr79u36zGc+E26zWq269NJL9dFHH8X6VAAAAPoks8Mh16xZGnrzzRrz859p0jtvK2fj+8r8n8c07BvfUNKCBTIlJSlYV6e6jz/W6d/8RmXfXKWiSxbpwOWf0dHv3q0zv/+96rZvV7C+PtGnAwAAeqmButylKaJuKp6vQWQmFcsKqsaKqVpv9yuomgKqpgqqyLAq9TwVVJJ0/dxMSdLJGtahAoBE6nQF1aOPPqqHH35YEydO1LJlyzR8+HAVFRXp1Vdf1auvvqrnn39eK1asiMdY23TvvfeqtLRU6enpGjVqlEpLS9vdf926dVq+fLmcTqdWrlyp1NRUvfzyy1qxYoWOHDmiu+++O7zvqVOnFAgENHLkyKg+RowYoeLi4ricDwAAQH9gTU9X6mWXKfWyyyRJRiAgz4EDqt+5M/zPW1wi39Gj8h09qqo1a0IHWixyTJwo57Rpck6bJtf0aXLk5cnsdCbwbAAAABKnp4I5syk+a1DZLKG+ahum5Wv8visaA6iaiIDq4Kna8GN7B6qzFk8erj9uLVNFpbvL4wAAdF+nA6r8/HytX79eS5YsiWp///33dfnll+uf/umfdPXVV8vhcLTbzzPPPKPFixcrKyur1e2BQEC/+tWvdNddd8lut7fb1+rVq5WTk6OsrCz99Kc/1Q9+8IM29/X7/br99ttlNpv13nvvadasWZKk++67T/n5+frhD3+o6667rs1xAQAAoGtMFoucubly5uZqyA03SJIC1dVy79oVCqx27FT97t0KnD4tz/798uzfr8o//zl0cGRoNX2aXNMIrQAAwMDRU2tQRQZU5himYtaGiqn6hmn4ujN9YGtT/P34b3s71cfItNBnyM2HzsgwjLhOmwgAaFunA6prrrmm1fZFixZp6dKleuutt7Rr1y7NmzevzT7Kysp0++23a9SoUVq/fn2LMCgYDOqWW27Rc889J7vdrrvuuqvdMUVOv3c+7777roqLi3XbbbeFwylJGjRokH74wx/q1ltv1dNPP6377rtPkpSeni6LxaKKioqofk6cOKGMjIwOPy8AAABasqSmKvnii5V88cWSJMMw5K+okHvPHtXv3i33nj1y79nbdmg1aVIotJoyRc68XDlyc2VJS0vgGQEAAMRB1BpU8XsasznycQwDqoa+6nyxC6iqIiqoGqcM7Oi0hHmjUsOPD5yoUc7I1Hb2BgDES6cDqvbYbKESW6u1/W4zMzP1wgsv6IYbbtDSpUu1fv16jRs3TlJ0OHXzzTfrm9/8ZiyHqPXr10uSli1b1mLb8uXLJUkbNmwIt9ntds2ePVvvvPOOPv/5z0sKVWGtX79e//Ef/xHTsQEAAAx0JpNJtowM2TIylHr55ZLOE1oVFspTWKjKiD5so0fLkZfXEFiFvtrGjpXJHLuFvgEAAHqSqYdqqCKrpiwxTMIaA6m6hgqqjkzD15b0lNCsTRVVTdPzjRnskiTd89ncDvWRFrFOVfHJWgIqAEiQmAVUhw8f1ttvv61Ro0ZpxowZ593/y1/+sl544QXdeOONuvTSS7V+/XplZmbq1ltv1bPPPqt/+Id/0JNPPilzjG8kFBUVSZJycnJabMvIyFBKSkp4n0b/5//8H/3jP/6j5s6dqzlz5uiRRx6R1WrVV77ylTaf57HHHtNjjz0mr9cb0/EDAAAMNO2GVrt3q37PHnn2FchdWCj/sWPylZfLV16umnffDfdhTkqSY/JkOfJCUww6cvPkyJkkSyo3IwAA6EsG6kxspp6qoIqa4i92/Vob1pzy+oOSmiqeumJocmgpkMr6pin+ahrWtkqyd/xW57KpI/XW3gqdrPF0eSwAgO6JSUDl8/l00003yePx6OGHH5bFYunQcdddd50CgYD+4R/+QUuXLtW8efP00ksvaeXKlXr66adjHk5JUmVl6O9rBw0a1Or2tLS08D6NvvKVr+jkyZP64Q9/qIqKCs2bN09vvvmm0tqZPmbVqlVatWqVysrKNHbs2NidAAAAAKJDq4jpngPnzslduF+ewlBg5SkolKeoSMG6OtXv2KH6HTui+rFmZMgxaVLoX07oq33iRFlSUnr4jAAAANoWy/Wg2n+epsexXJfJ2uwen83S9b6dttB9x/qG6QKlpsqsFEfHb3UOTw1VYh08WdvlsQAAuqfbAVUwGNStt96q9957T7fffrtuuummTh2/YsUK+f1+ffWrX1VJSYmuvvpqPfvssx0OuXrKt7/9bX37299O9DAAAADQDsvgwUpekK/kBfnhNsPvl/fQIbkLCkPBVUGhPAUF8p88Kf/x4/IfP67ajRuj+rGOGtUUXDWGVxMnypyc3NOnBAAAEF1BFdfniU8FVfNAytaNKf6S7A0BlbcpoGqqoOr4/cQRqU5J0t93H9N9X5ja5fEAALquWwFVMBjU1772NT3//PP66le/ql//+ted7sMwDL0bMf3Knj17VFFRodGjR3dnaG1qrJxqXiXVqKqqSkOGDInLcwMAAKDnmazWcNCkz38u3B6orJSnuFieogPyHDggz4EieQ4cUODkKfmPHZP/2DHVvv9+VF+20aNlnzhRjuwJsk/Ilj17ghzZ2bIMGxbTvzIGAACIFLUCVRw/c0T2HMuqLauleQVV9wOqOm9kBVUooEruRAXV2KGhdavKK93n2RMA4mfp0qWy2WzhGdkGmi4HVMFgULfddpt+//vf68Ybb9RTTz3V6Sn5DMPQN77xDT3xxBNasWKFvvCFL+iWW27R0qVLtW7duriEVI1rTxUVFWnu3LlR244fP66amhrl5+e3digAAAD6EcugQUqaM0dJc+ZEtQfOnWsIrIobvob+BU6dCq9v1Ty4MqelyTFhguzZTaGVfUK27GMzZbLZBAAAYsMU1/qh3qunKqgixTKgsjUrx3J0o4KqcYo/d8QUf7We0OPOVFBdmjsi/NjtC4T7BYCetG7dOmVmZiZ6GAnTpYAqMpxasWKFnnnmmU5PyWcYhu644w6tXr1aN9xwg5577jlZLBaZzWbddNNNWrp0qdavX69Ro0Z1ZYhtWrJkiR566CG99dZbWrlyZdS2N998M7wPAAAABibL4MFKmjdPSfPmRbX7z56V98ABeUoOyltSIs/BEnlLDsp39KiCVVWq37lT9Tt3Rndmtco+blwotJqQLfuECaHqq+xsWdpZzxQAALTOkJHoISRIzwdzsVwa3tp8ir9uVVCFbmdGTvFX21BB1Zk1qIYk2WQ1m+QPGjpT69Xowa529/cFgvrvdw/oV+8UaeX8sXrwyzNkjuU8iAAwAHU6oGqc1u/3v/+9rr/++i6tF2UYhv7pn/5Jv/3tb6PCKUm68cYbJSkqpMrIyOjsMNt0+eWXKzs7W88//7y+9a1vadasWZJCU/49+OCDstvtuvnmm2P2fAAAAOgfrEOGyDp/vpLmz49qD7rd8pYelvdgiTwlodAqFGAdlFFfL29JibwlJarRO1HHWdLTw1VXjaGVfUK2bKNHyRTLO0IAAKDPi6qg6qFMpLdO8edqqHSqj6igqmuooOrMFH8mk0lDku06We3pUED1u40H9at3iiRJf/jkiJbmjdDyabG7ZwkAA1GnA6of//jHevrpp5WSkqLJkyfrP/7jP1rsc/XVV4eDn9aUl5frz3/+s66//no999xzslqjh3HjjTfKMAzdfPPNevvtt/XVr3613TGtXr1aGxsWtt61a1e4bf369ZKkSy65RF//+tclSVarVatXr9by5cu1ePFirVy5UqmpqXr55ZdVWlqqRx55ROPHj+/gqwEAAICBzux0ypk7Wc7cyVHtRjAof0VFU2h1sCRcfeU/cUKBU6dUd+qU6j75JOo4k9PZME3gRDkmZofWvJo4UfZx45guEAAw4A3YKf6iHvfMaxDPKf5slq737WqYxs8fNOT1ByVJ3kDoa7K9c7c6hzUEVHvKKzV9zKA296vx+PXTNwqi2u54ZqsOPnQV65ACQDd0OqA6dOiQJKmmpkY/+clPWt1n/Pjx7QZUY8aM0UcffaRx48a1CKcafeUrX9G8efM0efLkVrdH2rhxo55++umotg8++EAffPBB+PvGgEoKLTy2ceNG3X///XrxxRfl8/k0Y8YMPfzww1qxYsV5nw8AAAA4H5PZLNuoUbKNGiUtXBi1LVBTI+/Bg/IePBgdYB0qleF2y7N3nzx790V3aLXKnpUVWt9qYrYcjcHVhAkyu9r/i18AANC3RYYg8cxD4lWp1byCyh6DCiopVEXlbwinJCnJ0blZnoJGaMrITSVntGL+uDb321p6ttX2Zz4u1c0Xje/UcwIAmnQ6oHrqqaf01FNPdfuJs7Ozz7tPR8IpqWtjys/P1xtvvNGpYwAAAIBYsKSkyDVjhlwzZkS1G36/vEeOhKYILC6Rt/iAPMWhqQONujp5i4vlLS6W1kYcZDLJNnp0Q2g1KVR11VB9ZRnU9l8CAwCAviMRNTqxneIvui+7tesBld1qDq8dVe8N6FSNR5KUnmLv9NSBCyYM0/6KGtV4/O3ud67O22r7fa/tIaACgG7odEAFAAAAID5MVqscEybIMWGCUi+/PNxuGIb8x46FQquS4obQqljeA8UKnDsn39Gj8h09qtr33o/qzzI8PTRVYE6OHJNz5Jw8WY6cHJmTk3v61AAAQDckYhY5izmWU/w1q6DqRkAlhaqoqj1+1fsCKj9XL0nnXUOqNfkThuqZj0t1rs7X7n7f/sOO8ONHrp+pf/njzk4/FwCgJQIqAAAAoJczNVRJ2UaPlhZdErXNf+aMvMXF8hQXN1RdFctTUiL/8eMKnDylupOnVLdpU9QxtsxMOSZPbgqtJk+WPSuLNa4AAL3eQF3uJ3LdqZ56DWKYT8W0gkoKrUNV7fGrzutXRZVbkpSR5ux0P0OT7ZKkyvq2AyqjYRrARtfNzdQnB8/oxS1HwscOcvEZCgC6goAKAAAA6MOsQ4fKOnSokubPj2oP1NSEpgo8UCzP/v3y7N8vd9F+BU6ekq+sTL6yMtW8+254f5PNJnt2dovgypqRweLfAAAkWNTaUD004V8sf/83n3qvO2tQSaGASpLcvoDqvAFJUoqj87c5G4Olc/WtT+EnRYdX2emhKvR7Pz8lHFD99I19euiaCzr93AAAAioAAACgX7KkpMh1wQVyXRB9w8R/9qw8+4vCoZVn/355iooUrKuTp7BQnsLCqP3NaWmhwCpvipxTpsg5JU/2SZNkttt78nQAAECDvvh3I1ZzjCuobKGAqs4bkMcflCQ5Gto6Y3BSKKCqqPK0uc9LDUGUJP18xSxJ0WHYC5uPEFABQBcRUAEAAAADiHXIEFkX5Ct5QX64zQgG5Ssvjwqt3Pv3y3vwkIJVVarfslX1W7ZGdGKVY+JEOfPy5Jw6RY68KXLm5coyaFACzggAgP7P3EOpVLyqs5qvZxWLKf4kqd4bkNsXqqBydKHP9BRH+HHp6VplDWu5TueDawrCj2eNHSypZXVZrcev5C5UcAHAQMc7JwAAADDAmcxm2TMzZc/MVOpll4Xbg15vaJrAwkK59+6Tu6BA7oICBSsrw9VWla+9Ft7fNnq0HFOmhIMrZ16erKNHM0UgAADdFDXFXx/8tdo8YHNYOl/tFCmpMaDyNVVQObtQQeW0WeSyWVTvC2jTwTOtBlRt+fkNM/Xdl3ZKkp784KDuuiyn088PAAMdARUAAACAVpnt9lDYlJenQV/6kqTQQuH+Y8dCYdW+fXLv2yfPvgL5jh6Vr7xcvvJy1bzzTlMfaWmhPqbkhSqtpuTJMXGiTDYWEwcAoKN6ag2qeIVfzSuoHLbuVVDVekJVU+8XnZKzoa+uVFBJ0rAUu8rO1qu+YS2rSB8eOBV+/L3luVHbLshsqhw/VdP2GlYAgLYRUAEAAADoMJPJJNvo0bKNHh1VbRWoqpK7oECeggK594XCK8+BAwpWValu82bVbd7c1IfNJkdOjhxT8uScMjVUbZWbK3Nyx/9qGQCAgSQylOqTFVTNp/izdC+gOl7pliQ5bWa5fY1rUHWtz4UT0/XiliOqdvtabPvF2/vDj5dNHRm1LWOQK/w4zcUf3gBAVxBQAQAAAOg2S1qakvPzlZzftLZV0OuVt7g4FFgVhCqt3AUFClZXy713r9x796pSr4R2NplkHzdOjqlTQqHVlFC1lTU9PUFnBABA7xFdQdX3WCJOwGo2tQisOuuG+WP1/94pkqSmKf6sXZs2MNUZuj1a7fG32HbodF3EftEhVIrDqvHDknTodJ0Onart0nMDwEBHQAUAAAAgLsx2e0PQNEXSlyWFpgj0lZU1TQ+4N/TVf+KEvKWl8paWqvqNv4f7sA4f3hBaTZEzb4qcU6fINnYs61oBAAaUvv5bL7Jgyt7FqfgiNU7n5/UH5fGFpubragVVY/BU7W4ZUAWCRvjx4KSWVVJXTB2p375/UH/ZWa6Hr71ALnv31tYCgIGGgAoAAABAjzGZTLKPHSv72LFKW7Ys3O4/fbphasC98uzbJ/e+AnkPHZL/5En5N5xU7Yb3wvuaU1LkzMuToyH8ck6dwrpWADBA8AcKfXSKv4hBxzKg8viD4QoqRxcrqFIaKqhOVHmi2qvdPp2pbVpbymlr2X9txLpVp2o8Gjs0qUtjAICBioAKAAAAQMJZhw1TyiULlXLJwnBbsLZW7sL9odCqoEDuvfvk2b9fwZoa1W3ZorotW8L7mmw22XMmNVRshda1ckzOlSWFda0AoD8xDOP8O/V7fS+hskRM6dfd9aekiIDKF5S7oYLK2eUKqtDt0Y9LTke1P7fpcPjxf153QavH+hrCMan1CiwAQPsIqAAAAAD0SubkZCXNma2kObPDbYbPJ09JSWh6wH375N67L7yulWdvaMrAtte1ypNzyhTWtQIA9DlRa1DFMZ8KxikAjAyobDEIqBqrsLyB7ldQjUh1SJJqmq1Bda7OF36c5mz9FupNF2Xpj1vLJIUqqAAAnUNABQAAAKDPMNlscubmypmbK119taSGda2OHpV7796mda0KCuSvqGh/Xau8pikCbZmZMpm7f8MMABBfTPEX3/qpeBWoRU7xZ7V0/wysDb+z/UEjIqDq2u/xWWMHhx9X1vs0yBWaMjhyKkKXvfVbqBdkDpbTZpbbF9TNT2zWoZ9+rktjAICBioAKAAAAQJ9mMplkz8yUPTOz1XWtPAUNlVb79rW7rpUjL7eh0qphXavsbJns9kScEgAAzfRMMBevCRQjK6giH3dVY8gVCAbl8Yem+OtqQDU4qel3fVVEQBXZn62dMbt9wTa3AQDaR0AFAAAAoF9qc12r/fujpghsXNeqfstW1W/ZGt63xbpWU/LkyM1jXSsAQELFs4osXlP8RVVQxSCgagy5/AFDnoaAyGHr2hR/kjQs2a7TtV7VeQPhNnMHX+dUh1XVHtafAoCuIKACAAAAMGCYk5OVNHu2kmY3X9fqoNz79oZCq30Fcu/b1/66VlOmyJmXF6q6ysuTdeRIpp0CAPSIuP62iVMJVXQFVfen1G0MuQJBI1xBZe/G2lZJDotO10q13qagye1rCqvmjh/S5rFPfW2+rn38I0nSwVO1mpDOH7IAQEcRUAEAAAAY0ELrWk2WM3dyq+taeQoKwlMERq1r9femda0sgwfLkZcnZ25u6OuUPKYIBIA4GKh/ChD5NxDx/HuIeE3xF1k1FZsKqlAY5QsaqnaHQqVUZ9dvcybZQsfWeZpCqcaA6vZFE+Swtl2ddUHm4PDjf/njTr38Txd3eRwAMNAQUAEAAABAM5HrWilyXaszZ0LTAhYWhNa3KiyQp+SgAufOqe7jj1X38cdNndhscmRny5mXK0felIavebIOafuvsAEASCQjXlP8xXoNqoY+aj1+efyhKf4GJ9m63F+SIxRARVZQ1TcEVK7zTB1oi6jc2lp6tstjAICBiIAKAAAAADrIOnRoy3WtPB55DhwIVVoVFDZ8LQhNEVhYKE9hofTaX5r6GDkyNDVgbl44vLJnjZPJ0vW1MwCgv8sdmarCimpdMik90UNJCFPU4/iVUMWrgsoS4zWorJZQH6drPOE+Uxxdv82ZbA8dWx+xBtX7RackSU47v58BIF4IqAAAAACgG8wOh1zTpsk1bVq4zTAM+cvL5S4slLugQJ59BXIXFsp3+LD8FRXyV1SodsN74f1NLpccOTmhKQKn5IXWt5qcK0sK61gAgCSt+fYiefwBJdm5ldUXlzyMXHYqFhVUjX2crfNJClVPdWctyKSGEGpn2TldPXuMvP6gDp6qlSS5I0IrAEBs8VsdAAAAAGLMZDLJNmaMbGPGKPWyy8LtgZpaefY3hFYFhXIXFshTuF9Gfb3cn34q96efRvVjGzeuYV2rXDnzQsGVdfTobt2EA4C+yGI2EU71YVEVVJZYTPFnjvo+zdX16f0kqcrta+g3NLZz9d7wtmqPv9VjIn33isn6+dr9kiR/ICirxXyeIwAAEgEVAAAAAPQYS0qykubMUdKcOeE2IxCQt/RweF0rd2EovPJXVMh3+LB8hw+reu3a8P7mtDQ5J0+WY0rDula5eXLkTJLZ4UjEKQEAekDkHyb0xb9RsEStQdX98KZ5Fdb51ok6n8vyRujjkjN6t+CE/u/npsrbsK6VJN10YdZ5j589bnD48UtbyvSVBeO6NR4AGCgIqAAAAAAggUwWixzZE+TInqC0K68Mt/vPnm2xrpWnpETBqirVbdmiui1bmjpp7KNxXavcPDmn5MmaPjDXagEAdE28si+zOcZrUDXrw9bNiqXGiqzkhnWsPBEBVfbwlPMe74wIyDYdPE1ABQAdREAFAAAAAL2QdcgQWS+6SMkXXRRuM7xeeUpKota18hQUKHDunDxFB+QpOqCqv/0tvL8lPT1iisBQxZV9wgSZrPyvIAD0JZFxTF+c5jVyir9YrkHVyNbNaQNzM1IlSR5fMOrr8NSOVSdHVnA5rd2r5gKAgYT/KwEAAACAPsJkt4fXotLVoTbDMOSvqIhe12pfgbylpQqcOqXaU6dU+8EHUX04cnJCoVVuXnh9K0taWmJOCgDQKX0vnmo2xV8MArbm61g1X5Oqs5y20PFuf0CS5Gn46rB2rN/G45s/BgC0j4AKAAAAAPowk8kkW0aGbBkZSr300nB7sK5OnqKiqHWtPIWFCtbVyb1nj9x79qgyoh/b6NFy5DVMEdgQgtkyM2WKwVohAIDYiWcBVbyqs8wx7rd5IGXrYJDUFkdD1ZPb1xhQBRvaO9Zv5Lpazm6uhwUAAwkBFQAAAAD0Q+akJLlmzpRr5sxwmxEMynfkSGhdq8KCcHjlLz8mX3m5fOXlqnn33aY+UlLknDZNzunT5Jo2Tc7p02UbO7ZPTi8FAH1ZX3/bjaygisW5tFiDqpvTBoYrqBqn+AsHVB0LmzLSnOHH7xSc0A+umtKt8QDAQEFABQAAAAADhMlslj0rS/asLGn5snB7oLIyvJ6VuyD01VNUpGBNjeo2bVLdpk3hfc2DBsk1baqc06bLOX26nNOmyTZmNKEVAPQQUx+c5M8SUYgUi18Xzdegaj7lX2c1r6DyNgRU9g5WULnsFl2aO1zrC0/qwIka1Xr8SnZw2xUAzod3SgAAAAAY4CyDBik5P1/J+fnhNsPnk6e4WO49e1S/e7fcu/fIU1CgYGWlaj/8SLUfftR0/ODBobBq+jS5poeCK+vIkYRWABAHcZ3iL079Rk7xF4uArcUaVJburkEVCqg8/qAMw+j0GlSSlOa0hR/7AsFujQcABgoCKgAAAABACyabTc6GtagGX3utJMnweuUuKpJ79x65d+8OrWW1f78C586pduNG1W7cGD7ekp4u57SpcjVWWk2fJtuIEYk6HQDo0yJDnb4Y/UdVPMWhgsre7YCq6XiPPyhPw1R/jk6sJxUwjPDjxgosAED7CKgAAAAAAB1istvlmhZaj0orbpAkBT0eefbvl3v37qZKqwMHFDh1SrUb3lPthvfCx1tHjAitaTVjelOl1dChiTodAOib4plQxanvqDWoYtCf1Wxu9n1316BqCqI8vmDEGlQdD758EaGUt49UUFW7fZr947XyB5vCtVdXLdSssYMTNygAAwoBFQAAAACgy8wOh1wzZsg1Y4aGNLQF3W55CgpUH6602i1PcYn8J06o5sQJ1axbFz7eNnp0qMKqMbSaNk2WtLTEnAwA9FJ9fcbUqIAqBicT6zWobBazLGaTAkFDbn+gS1P8RYZSfaWCau3eiqhwSpL+8+8Fev72CxM0IgADDQEVAAAAACCmzE6nXLNmyTVrVrgtWFcnd0GB3Lt2hYMr78GD8pWXy1deruq33grva8/Kig6tpk6VOSkpAWcCAL1PLNZw6mkWU2wrqGzN16Ayd2+KP0kKNAQ1R87URVRQdXyKv2R7021WX8BoZ8/e42dv7W/RVn6uPgEjATBQEVABAAAAAOLOnJSkpDlzlDRnTrgtUF0t9569cu/Zrfpdu+XevVu+sjJ5S0vlLS1V1euvNxxslmNitpzTZ8g5fZpc06fLkZcns8ORoLMBgMSJZzVVvLo2R1VQdb+/5hVUzb/viiS7RXXegGo8fp2r80mS7J2ooPrussl6fdcxSdLmQ2eUm5Ha7THF29FWwqjSM3UJGAmAgYqACgAAAACQEJbUVCVfuEDJFy4It/nPnpV79x65dzdVWvkrKuQpOiBP0QFV/vnPoR2tVjkm58jVGFrNmCHHpEky2WwJOhsA6Bl9r34quoLKiEFxUfOKqVgEVLPGDtaHxad18FStfr2hWJK0t7yyw8dPHJ4Sfvxvr+7WTRdmdXtM8RQMtv4fwjCkk9UeDU/lj0AAxB8BFQAAAACg17AOGaKURZcoZdEl4TZfxQm59zSGVrvl3rVbgbNn5dm7T569+6SXQvuZ7HY5puQ1hFbT5ZoxXfYJE2SydHyKJgBA7EVWUAVjkFA1z6OsMQiofA1rSJ2s9oTbdpZ1PKDqa+p9gTa3lZysIaAC0CMIqAAAAAAAvZpt5AjZRo5Q6mVLJUmGYchfXt5QYdUQWu3eo2B1tdw7P5V756fhY81JSXJOnRpa06ohtLKNGydTPOfIAoA4iuf7V7z6jqxwikUFlclkktVskr+hCigWFVTDkh0NfTe1fW3hhG7321u9tfd4m9sCbVRXAUCsEVABAAAAAPoUk8kk25gxso0Zo7TlyyRJRjAo3+HDodBq1y7V79kt9959CtbVqW7LFtVt2RI+3pyWJtf0aXJOawqtrKNGEVoB6LUiQ52++E4VOcVfLCqopFAo1RhQxaKCatywJEmS1x+U3WqW1x/UTRf17mn6uuP/vLizzW1nG9bgAoC2/OIXv9Dq1atVWloqq9WqOXPm6KGHHtKCBQvOf3AEAioAAAAAQJ9nMptlHz9e9vHjNejzn5MkGYGAvCUlqt+1W+7du1W/Z7c8+woUrKpS7YcfqfbDj8LHW4YOlXPGdLmmTQ99nT5d1uHDE3U6ANCmvpilR445RvmUrGaTGifjM8cgoLJbQutaef3B8BSCNksffLG7YP2/XKqgYejyn2+QYUirnt+mz13wuUQPC0AvlpWVpZ///OeaNGmSPB6PfvnLX2r58uUqLi7WsGHDOtwPARUAAAAAoF8yWSxy5OTIkZMjXfNlSZLh9cpz4EBTaLV7tzxFRQqcOaPaDe+pdsN74eOtI0fKOWO6nHlT5MzLlSMvT7YxY6i0ApBQpjjWUMWrZ3PE+6ah2CRUVotZUmgdpVhUUNmtDQFVIBie4s5qNne7375gfHqypNiFh/2dLxDUD17ZpYsnDtM1czITPZyYeuiNffrfDSVRbdfMGaOf3zArMQNCr3XNNddEff/II4/ot7/9rXbv3q0lS5Z0uB8CKgAAAADAgGGy20NrUk2dKq24QZIUdLvlKSyMCK12yVtcIn9FhWoqKlTz9jvh480pKXLk5cqZmxf6mpcnR06OzE5nok4JAHo9c5wqqBpZYhAkNQZUHn9QvkDs1rZC//PnbUf1p61l+tPWsn4XUDUPpyTplW1HCah6iWeffVbvv/++tm7dql27dsnr9erJJ5/Urbfe2uYxn3zyie6//359+OGH8vl8mjFjhr773e/qhhtuiNm4vF6vfvOb32jIkCGaMWNGp44loAIAAAAADGhmp1OumTPlmjkz3BasrZV7375QhdW+ArkLC+UpLlawpkb1W7aqfsvWiA5C0ws683LlyM0LV1tZR4yg2gpATERVHfXBt5XoCqrYiAyPYlJB1TDFn9sX6HK/V88arVd3lHd7LD1h1tjB2nHknL59eU64bcnk4dqw/6QkyTAMfoe14UydN9FD6HFcD73Dvffeq9LSUqWnp2vUqFEqLS1td/9169Zp+fLlcjqdWrlypVJTU/Xyyy9rxYoVOnLkiO6+++5ujef999/XlVdeqfr6emVkZGjt2rUaOnRop/ogoAIAAAAAoBlzcrKS5s1T0rx54TbD65Xn4EF5CgrkLiiUpzD0NXDmjLwlJfKWlEhr3mjqY9AgOSZNkmPixNDXSRNlnzRJ1uHDuckDoFMiq47iWdQTr7emeK1B1Sgma1A1VFDVeSMCqk6uQfXNpZPCAdWaXcd01YxR3R5XvPgCQUnSrHGDw233f2GqLvvZBklS8ckaTRqRmoih9XoD8Te4P2gMmDXZerPVq1crJydHWVlZ+ulPf6of/OAHbe7r9/t1++23y2w267333tOsWbMkSffdd5/y8/P1wx/+UNddd52ysrLCx3z/+9/Xww8/3O4YjIg38Xnz5mnHjh06ffq0fvvb3+qGG27Qpk2blJ6e3uFzIqACAAAAAKADTHa7nLm5cubmatCXQm2GYch/8qQ8hYVyFxTIU1Aod2GBvAcPKVhZqfqtW1W/dWtUP+a0tKbgKmeS7BMnyjEpR9YRBFcAWheZ6fTF94noMccmobJYYlxB1VpA1cmpAyPH8aO/7um1AdW5Oq/2lFdJkpJslnD7+GHJ4cc/eGWX/njnxT0+tr7gk0Nnwo/P1Xk1OMmewNH0jEDQUMSlIn8g2LAOHHrSZz7zmQ7v++6776q4uFi33XZbOJySpEGDBumHP/yhbr31Vj399NO67777wtvuvvvudqcLbM7lcmnSpEmaNGmSFixYoJycHD355JP63ve+1+E+CKgAAAAAAOgik8kk24gRso0YoZRFi8LtQY9H3oMH5Sk6IE/xAXkOHJD3QLG8hw8rWFWl+m3bVL9tW1Rf5tRU2SdMkH18luzjx8sxfrzs48fLnpUlc3Jy86cGMIBE/sV6XCuoeqA2JHYVVE03x2OxVpQjHFD5u9yvLeKGva0X37z/f+8cCD9OsjfdHo6sRPvk0NkeHVNf8va+E+HHP1+7Xz/+0vQEjqZn+AJBORsSqksefldlZ+v1wBem6taFExI8sr6vurpaVVVV4e8dDoccDke3+12/fr0kadmyZS22LV++XJK0YcOGqPbhw4dr+PDhXX5OwzDk8Xg6dQwBFQAAAAAAMWZ2OOTMy5MzLy+qPejxyHvoUDi48h44IE9jcFVdLfenn8r96act+rOOGBEKqxoCK/uEhseZmTLZ+/9fbgMDXTBqCaq+V0EVKVZrUB09Wx9+HIsKKkerFVT9M6A6VdN0A9llt7SzJ87nZHXnbsb3VcFg0+Oyhp+9B/66l4AqBqZOnRr1/f33368HHnig2/0WFRVJknJyclpsy8jIUEpKSnifrrjnnnv0xS9+UZmZmTpz5oz+53/+R2VlZbr22ms71Q8BFQAAAAAAPcTscISnCYwU9HrlPXhI3tJD8h4qlffQofC/wJkz8p84If+JE6rbvLlZh2bZMjJkGztWtswxso8dK1vmWNnHZso2dqwsQ4b0yenAADTXFOv09R9pI0YlVN5A0x3zWK5BVd8QUJlNne83cs2qWFR1xYsrYq62JAKqbonFtddbtPezGWjY5vYF2twHXbN3716NGTMm/H0sqqckqbKyUlJoSr/WpKWlhffpivLycq1cuVInTpzQ0KFDNX/+fL3//vuaMmVKp/ohoAIAAAAAIMHMdrucuZPlzJ3cYlugslLe0ujQynMoFGQZdXXylZfLV14ubWql36SkUHg1NlP2zIavDSGWbcxomWN0EwRAfEXeN+7zAVWM+hk9yKnySrekGK1BZQkFNbUef0Ofna+Aiqya6s3/mSKDtOYB1R1LsvW/G0p0YfbQnh5WnxSrwLU3CLZzKv6GEqo/bD7cbh9v7jmuFzYflmFIK+eP1ZW9dB223iQ1NVVpaWmJHkanPfPMMzHph4AKAAAAAIBezDJokFwXXCDXBRdEtRuGIf/Jk/KVHZWv7Ii8R47Id6RM3rLQV39FhYJ1dfIUFspTWNhq39YRI2QbM0a2zMxQBVZmZtP3GRkyWbltAPQGkfeNzXFMqIyYxUftPEeMnmL6mEHhgCoW1Ur2ZlP8daVPW0Tw05tjixRH03t78yn+5mUN1f+qRB+XnOnpYfVJ/kBv/i/dOYF2EqrGKf4q6/3N2o1wFdmBE9W645mt4W0b9p/UG99epCmj+l740h80Vk61VSVVVVWlIUOG9OSQWsUnTQAAAAAA+iCTySTbiBGyjRghzZndYnvQ45HvaHnL8KrsqHyHDytYVxeeOrB++/aWT2CxhKYPDIdWDQFWw/fW4cNl6kKFAYDOi6qgStwwYiJWt/MdEdPUdaXaqbnGgMrfcJM+ssqoo3rzulORkhsCqjSnVQ5rdEAVOf3f1tKzmpuV+BvYvYnHHz3Fnb+9sqM+xu1ve/q+Go9PklO7jkaHHd5AUE5z6JrZX1HT4riC41UEVAnSuPZUUVGR5s6dG7Xt+PHjqqmpUX5+fiKGFoWACgAAAACAfsjscMiRPUGO7JYLmBuGocC5c6Gw6miZfGVl8paVNQRaZfIdPSrD65Xv6FH5jh5ttX+T3S7b6NFtBlisfwXETjAioYpnBVVPiNWUaPaIMCgWuZDDGt1JV6YNjMVUgz3B6w+Vw1wzJ7PFtsiKqkOnagmomnn6w0NR30cGen3di5uPtLnt0XcP6FcrZ+vtfRVR7ZEBXWtXv6nPR+p915IlS/TQQw/prbfe0sqVK6O2vfnmm+F9Eo2ACgAAAACAAcZkMsk6ZIisQ4bINWN6i+1GMCj/yVOh8Oro0aYAqywUWPmOHZPh9YbXxGqNOSmpabrA1gKs1NQ4nyXQf/TUGlR9aTmdyOn0LDGooHI2Cxq60mdkKJ8zIqXbY4qX/153QFLr0xhGXl9pLltPDanPKDtbH/X91NH9pzqo7Gxdm9sOnapttT3Yl940BpjLL79c2dnZev755/Wtb31Ls2bNkhSa8u/BBx+U3W7XzTffnNhBioAKAAAAAAA0YzKbZRs5QraRI6Q5c1psN/x++Y5XhKutfEcjAqyyMvlPnAitf1VUJE9RUavPYR40SPZ2Aiyz0xnv0wT6jMi1ofp6ZWKs7mebI8KVWFQuJTVbi8nWhSn+JOmrF47Tsx8f1uEzbd/s7y0Kjle1aJsUEazZrX1jysKe1PyqaG/dpoFg66GzWpo3ItHDGDBWr16tjRs3SpJ27doVblu/fr0k6ZJLLtHXv/51SZLVatXq1au1fPlyLV68WCtXrlRqaqpefvlllZaW6pFHHtH48eMTcRpRCKgAAAAAAECnmKxW2TPHyJ45ptXtQY9HvvLypikEjx6NCrACZ88qWFkpd2Wl3Hv3ttqHJT29KcDKzJRtzOimAGvUKJls/GU/Bo70FEeihxAzRoxWoYrMpFqrBOosl715BVXX+swckiRJ2lNepSNn6jR2aFK3xxZLgagp2VqeY5rTpgnpyTp4qlb+QLAnh9YnNA+IB8xr1EYw/qO/7iGg6qalS5fKZrNp1apVWrVqVbv7bty4UU8//XRU2wcffKAPPvgg/H1jQNXY98aNG3X//ffrxRdflM/n04wZM/Twww9rxYoVsT2RLiKgAgAAAAAAMWV2OOSYMEGOCS3Xv5KkYG2tvEePhgMr39GyqO+DNTUKnDql+lOnVL9zZ8sOLBbZMzNlz86WY2K27NkTG75mM3Ug+qWRaU49ces8Jdv7/q28mFVQmWJbQdV8LaGu9hl5XOHx6nBAFQwaev/AKVVUupt2Nkkt8rrGw5u1m0ytvHad2Dc3I1Uzxw6WLyJQMbdxjkOT7Tp4qla+wMCuDmpNVb0v6vvyyP+efVxb14MktVVQGHmFePwtwzq3L9DNUfV/69atU2Zmy/XgWvPUU0/pqaee6lT/+fn5euONN7owsp7R93+rAQAAAACAPsWcnCzn5MlyTp7cYpthGApWVsrbuN5VY4BVVibf0XL5yspkeDzylpbKW1qqmnXroo63Dh8u+8SJcmRnhwMsx6RJsqSn9/mp0TCwXZY3MtFDiIlYrVkTGVC1d2O9o2yW6OnsrJauTW8XWXllibir/8r2o/qXP7YSuPegt7+7WCPSmqZPbetla5ze0DdQqoM64ZXtR6O+/9PWMj1y/cwEjSa2cke2/QceF00c1mp76emmqSy/8+KOFtu//8ourcwf1+2xof8ioAIAAAAAAL2GyWSSZfBguQYPlmv6tBbbDcOQ/8RJeQ+WyFNcLG9xiTwlJfKWlMh/4oT8J0/Kf/Kk6j7+OOo4y9ChcuROlnNyrhy5uXLkTpZj0iSZHf1n6jSgu3qiXiZWFVSReXM81krqagVVVEAVMcjXdhxtbfceVVRRoyFJ9vD3bVXkNYZ1BFQDyyBX09S5935uiiTpP17fJ6lp6kog1gioAAAAAABAn2EymWQbOUK2kSOUfOGFUdsC1dXylpTIU1zSEGCVyFtcLO/hwwqcOaO6jz5W3UcRwZXZLPuECXJMzpEzN1eOybly5uXKOmoU1VZAnMQqBIusoLJ3sdqpucwhLpWdrZfU9TWoghFrPEX20RveUwwpatq+byzObnU/OwHVgBRoSI8XTBiqry8KXRufHDqjN/dURK1dBsQSARUAAAAAAOgXLKmpcs2cKdfM6OmWgvX18hwolmd/odyFhfIU7penoECByspQgFVcrOo3/t7Uz5Ahck6b1vBvqlzTpxNaYUAwYlXe1O6TxKYbcxwqqCIDJXMXf94j7+NH9tEb3j2ChhEOnUwmaebYwa3u11hBdawfra+E82sMoaKqABseG4YhP4El4oCACgAAAAAA9Gtml0uuGdPlmjE93NY4VaBnf6E8+/eHgquCQnlKShQ4e1a1GzeqduPG8P6EVkBsGDFKqCLDn+brR3VV5I35XUcru9RH5Bpb0RVUXR9XrBiG5G0IGVIdbd8WfntfhSTpl28X6TufablWIPqnxks3KlhteBwIGuFrp7nTNR4NS3Fo4vBkFZ+sjdo2dqgrPoNFv0FABQAAAAAABpzIqQJTFi0Ktwc9nlBgtXu36vfskXvPXnmKiloPrYYODVdsuWbNlHP6DFlSkhNxOkCfEbs1qCIDqtikP5YYpEiRAVXk416QTzVM8RcKGdqrOvMznduA1FhBZW5lHbWAIfn8rV8XR8/Va1iKQyNSnSo+WauHr50hi9msf/njTmWkOeM/cPRpBFQAAAAAAAANzA6HXDNmyDVjhoY0tLUZWp05o5p161Szbl3DwWY5Jk2Sa9ascGhlnzBBJnNsqjuAeOuJWCJ2a1A1PY7VGlRdXXcqUmSRSbCXBT2hadpCY4pV1Rn6j8Y1qCJ/DBofG4ahl7YcafW4xkDzo5LTkqQztb5wAPrJobNxGi36CwIqAAAAAACAdrQZWhUUqH7HDtXv3Kn6HTvlKy+XZ/9+efbv17mXXgodm5oq1wUXhAKr2bPlmj1LlpSUxJ0MkGCxWucqspdYhS1dXXcqUmTVVCCygqo3zPEn6YmNByWxvhRaavzZjKwkbKymCgQN/debha0eF2gWxO48ck5JDktUv73l+u+Nli5dKpvNplWrVmnVqlWJHk6PI6ACAAAAAADoJLPDEZ7er5HvxAnV79wpd0NgVb97t4LV1ar94APVfvBBw4FmOXJzlTR3rpLmzZVrzhzZRoxI0FkAPS9WNUX+iFIlWzvT1XWGNQZTBUZWTUVOldcbbs8bhvTK9qOJHkaftfPIuVbb+0sA8+i7ByRJ7xScCLdtKjkjSXrojYI2KwybZ87+YFD+QNPPZNCQYjQLZ7+0bt06ZWZmJnoYCUNABQAAAAAAEAO2ESNku+IKpV1xhSTJ8PnkKSpqqLDaobpt2+U7ckSeffvk2bdPZ599NnTcuHGhwGruHLnmzpV9/Ph+cbMTfVAPzEgXqzWoajyB8OMkm6WdPTsuFhVUkev3RIZVveFHOhirF3+AagxwmusvAUzZ2foWbUfPNbW1Vf1oGEbUtT52aJKS7E0/k4GgEZPpM9E/EVABAAAAAADEgclmk3PqVDmnTtWQG2+UJPkqTqh+21bVbdmquq1b5SkslO/wYVUePqzKP/9ZkmQZNkxJc+Y0VFjNlXNKnkxWbuGgf4hVUBN5Q9wco5vfsbiJfs2cMeGp0JpPfZZoHR3Oz2+Yqe++tLNXhGq9SVuvx0APYAxFVwveuWSibBazHltXLKn3/Rygd+HTDQAAAAAAQA+xjRwh25VXKu3KKyVJgepq1W/frrqt21S3dYvcn+5S4PRpVa9dq+q1ayVJ5pSUUIVVfr6SFiwIBVaW2FSMAH2VOTaz+kWJRcgwapBLs8cN1vbD55rdmE98gBHsYFCQNSxZkjRuaFI8h9NvDJTKNKfNojpvoEW7YUS/BikOa9TPUmCAvD7oGgIqAAAAAACABLGkpipl8WKlLF4sSQp6vXLv3t1QYbVF9du2K1hdrZoNG1SzYYOkhsBq3rxQYJWfT2CFmOlbt5FjH/hYYlQyZLeE0rPIG/O9oRqpo0FBY7jgD/StKyJRBkqF0DVzxujZjw+3aD9R7Y6qoLKYTdEBFdcR2kFABQAAAAAA0EuY7fbQ9H5z5ki6XUYgIHdBgeo2f6K6zZtVt2VLKLBav14169eHjklNjQis5suZR2CF/i8egU/kTfV7Pzel2/1EBhe9IJ+KChHa0xjUDZTKoI5q67/hQKkQsjUEr6uWTlSa06aH3iiQJH37Dzu08/5l4f0sZlNU2DtQXh90DQEVAAAAAABAL2WyWOSaNk2uadM07LZbmwKrTZujA6t161Szbp2k6MAqeUG+HLm5BFZAB0SuZTVr7OAu99NaQNUbdHSKv8bpE3vb+BOtrVC0o69rX3foVK0kyWI2KzcjNWpb5LViMZlkNptkMoWm//MHgz06TvQtBFQAAAAAAAB9RFRg9bXbQoHVvoJQWLVpk+q2bm0ZWKWlNQRW85W8YEEosIrHAj7o84wBXulgjQioulOh1VpA1RuqkXyBjgUF1ob3BwKqjhkoL9O6wpOSpN1HKzU3a0jUtsZrxWRqCnqtZpN8AUPkU2gPARUAAAAAAEAfZbJY5Jo+Ta7pDYGV398UWDVWWFVVqebdd1Xz7ruSJPOgQUqaP0/J+QuUtGCBHDmTCKzQ58Qj7zFHpVJdT6gapzfz+EN35k9UufX2vhPdGVpM/Mfr+5Q7MlWFFdX61cpZbe5naayg6gWhWm/y5p6K8ON/+/xU/fvf9kqSPjhwSl+YOTpRw4qZcUOTdPhMnRblpIfbbrowS898XBq136dl51ocu+toqC3ykvE1rD21/fBZXTljVMzHi/6BTx8AAAAAAAD9hMlqlWvGdA37x69p7P/+WpM3fazxf/yjRnzvX5S8eJHMSUkKVlaq5u13VPHggzr4pS+paOElKvvWt3XmuefkKSoa8FU0GLgsEXdKu1NBVV7pliTd++puSdLafRXt7d6jGkOn4SmONvexUEF1XpfljQg//ucXtidwJLEzbXSaJGnZ1JHhtosnDmuxn8VsahHffv/lXW32+69/+jQm4+uvli5dqqlTp+qxxx5L9FASggoqAAAAAACAfqoxsAqFVv8YqrDas0e1mxqmBNy2TYGzZ1X91luqfustSZJl2LDQdID5+UpasED2CRNk6s7d+n7mh1fl6cE1Bfq/V01J9FBibqDHEdaISsLuXPH7jlVFfd+bMl9/wzR/FnPbZ9hYAUZA1bZ2Xr4+q3Eaysj3+9be+61U3MbUunXrlJmZmehhJAwBFQAAAAAAwABhslrlmjlTrpkzpW/cLsPnU/2u3arbvCk0JeC27QqcPq3qN/6u6jf+LkmyDE9X8vxQWJW8IF+2rKwBHVh9Y/FEXTsnU8PaqUBB32Q2t39jvj9onHbNamknoLIQUJ2PuR9eH43/uSPPrbUg02oxtagwbPfl6H8vFWKIgAoAAAAAAGCAMtlsSpozW0lzZkt33qmg1yv3p5+qdvNm1W3arPrt2xU4eUpVa9aoas0aSZJ15Egl5ecreUEotLJlZvbbm/ltIZzqDWIfnkRmNrG8onvTj4cvXEHVdhUMFVTnZ+6HJVRGuIKqqa2107S0ckG3F9j1xzAPsUNABQAAAAAAAEmS2W5X0rx5Spo3T/rmNxX0eFS/c6fqGqYErN+5U/6KClX99a+q+utfJUnWUaPC0wEm5efLnjkmwWeBrupNU9ElQnQFVez6bbliT+KcqPZIaj1kaNRYNRMY6BdEO/phPhX++TdHBVQtTzS0BlV0e/sBVUyGh36KgAoAAAAAAACtMjscSs7PV3J+vvTPdynodqt+xw7Vbtqkus2fqP7TT+U/dkyVr72mytdekyTZxoxpCKvmK3nBAtlGjUrwWQAdY40MqHpRqBQP7a5B1bDNMKRg0OiX1ULd1V7A11e1tgZVa//t//WzeXLYoivwlk/L0BMfHGy1389fMDqGo0R/Q0AFAAAAAACADjE7nUq+8EIlX3ihJClYV6e67dtDFVabN6t+9275jh5V5SuvqPKVVyRJtnHjwmFVUv4C2UaOSOQpAG2yxKuCqhdmGe2uQRXxOviDhuwEVC30x2lNG2d0jDyz1v7Tzx8/RGlOW1TbiLTQtKfXzG6qoL1+bqb+uLVMGYOcsR4q+hECKgAAAAAAAHSJOSlJKQsXKmXhQklSsLZWddu2qW7zZtVu2iz37t3yHT6sysOHVfmnlyVJ9qysUIXVglBllnX48ESeAvqoeMw+F6+1cnpjlNFeBVVkJVmQaf5a1R8zu8b/1pE/B639TFgt5haVVY3rlUVeV40haJC1zNAOAioAAAAAAADEhDk5WSmLFill0SJJUqCmRvVbt6q2YQ0r97598paWyltaqnMvvSRJsmdnh8OqpPx8WYcNS+QpDGiGBvaNZGt/TB3a0N65Nq+gQkvtBXx9nTli9r7WMtvWrh3DaBlQNYZbXEJoDwEVAAAAAAAA4sKSkqKUJUuUsmSJJClQVaW6LVtVt2mTajdvlqegQN6SEnlLSnTuhT9Ikhw5k5SUv0BJ+flKyp8v65AhiTwFxFhvLsgxD6Ap/jqyBpXUVBmDaP1zir+WFVStrbXVWkD11IeHJEklJ2vDbaWn6yRJv9tYom9/JieWQ0U/QkAFAAAAAACAHmFJS1PqZUuVetlSSVLg3DnVbdmi2s2bVbdpszyFhfIUHZCn6IDOPvecpFBg5Zo7V0lz5ylp3lzZRo1K5CmgH4u8GW+K4cR8sewrVqyRZTLNRL4Obl9Ag1y2NvcdqPpjAVUw2LKt+VR+UmiKv+ZO1XglSZsPnQm3bTxwSpJU5fbHaITojwioAAAAAAAAkBCWwYOV+pnPKPUzn5Ek+c+eVd0nn6hu02bVbd4UDqs8RQd07g8vSpJsY8Yoad7cUGg1b77sE8b3y2qGROiJ6qZYPUU8xmqJUwVVb9ROPhUVSvzy7SI9dM2MHhhR71Z8sibq+/44xd9HJadbtLV3msOS7Tpd643jiDAQEFABAAAAAACgV7AOGaK0ZcuUtmyZJMl/5ozqtm5V/Zatqtu6Ve69e+U7elSVR4+q8rW/SJIsw4Ypac6chtBqnpx5uTJZueWFzotbQNULs4z2Kqgivb2vQg+JgOqeP30a9b25HyeYO46c05dmjZEkna5pO4DiDwNiY+nSpbLZbFq1apVWrVqV6OH0OH5bAwAAAAAAoFeyDh2qtCuuUNoVV0iSAjW1qt+xQ3Vbt6h+y1bV79ypwOnTql67VtVr10qSzMnJcs2eraR5c5U0d66cF1wgs8ORyNNAHxEVUMV0ir/ep6MVQP2wUKhL6n2BRA+hxwQj1h1rr1CRfCo21q1bp8zMzEQPI2EIqAAAAAAAANAnWFKSlXLJQqVcslCSFPR65d69W3VbtoZCq23bFayuVu3GjarduFGSZLLZ5LzgAiXNnRuqspo9W5bU1ESeRq/VE1P89WaRVTHdufn+uQtG6fVPj8VgRPHjslkSPYQ+pfn1YI1I7iYOT+7h0cRX5NtAe5VirYWXkeuVjUxzqKLKE8ORoT8ioAIAAAAAAECfZLbbQ9P7zZkj6XYZgYA8+/c3BFah0Cpw8pTqt25V/datOv0bSWazHHm5Spo7T0lzZss1Z45sI0cm+lTQC1ijKqi67vufzYsKqCJv+M8eN1g3zh+noGFEhR4mmWTICIeEkdsMQx3et7nGfe95eVdUu93asSn+YllJ1p9YLWbddGGWnvm4VAuyhyV6OHHTXgVd5LWRl5GqguPVevDLTdNB/vhL03XHM1s1eWRKPIeIPo6ACgAAAAAAAP2CyWKRc8oUOadM0dCbvirDMOQ7fFh1W7aEQyvf4cPy7N0nz959OvvMM5Ik6+hRSpo9R645s5U0e7YcubkyWagwGWjMMVqDymELhT+N3RkRpWnXzx2rG+aP7XrnXRQZUA1Ntnf4OKb4C2ktqBuZFpo61OhnpYeRp9NeBVXkpsbA02VvCj6dDVV6NkvHwlAMTARUAAAAAAAA6JdMJpPsWVmyZ2Vp8LXXSpJ8FSdUv22r6rZsVf327XIXFMhffkxV5a+r6vXXJUnmpCS5Zs2Ua/YcuWbPlmvWTFlSqALoTYx2V8fpmuh1mbqezDTe1G9cyqcv5xemfrLQ0Mlqj378t736Sv44XTSx8xVPrb0Mja9NMNjd0fUuUT9b7VZQNQk2XOSR14spvC12Y0P/Q0AFAAAAAACAAcM2coRsV16ptCuvlCQFa2tV/+mnqtu2TfXbd6h+xw4Fa2pU++FHqv3wo9BBZrMckyfLNXuWkubMkWv2HNnGjO43N+8RYonRGlSRhxqGEXWDnksmMR746x69/ukx/XVnuQ799HOdPr61/2xNQWTfT2D++YXtrba3X0HVtG330aoW+zc+7m8VZogtAioAAAAAAAAMWObkZCVfdJGSL7pIkkLrWB04oPrt20Oh1bbt8pWVyVNQIE9Bgc698AdJknXECLlmzw6vY+XMy5PJZkvkqaCbLDFagyryxr1h9O0A43MXjEr0EGKi7Gx99zpoJahpvFz6Q4XQX3eWhx9HT/HX9jGtVpVFPG56ffrBC4S4IaACAAAAAAAAGpgsFjlzc+XMzdWQlSslSb4TJ0LVVdu2qW77drn37pX/xAlVv/mmqt98M3Sc0ynXjBlyzZkTqrSaNUuWwYMTeCadN9ArHaICqm6UOkVVUDX8a21bonTkv/ON+eP0wubDSrb3j9vH3X3d26ug6m8/N9HXa9uvXGvVVWZTy5+hfvbyIMb6xzsMAAAAAAAAECe2ESNkW75MacuXSZKCbrfcu3apbtt21W8P/QtUVqruk09U98kn4ePs2dlyzZwZ+jdrphyTJslk5XZcb2WOWQVV9Pd9McDob9Uv7VUCdUTra1CFvvaX16hRdyqoIvfvb9cQ4oPfiAAAAAAAAEAnmJ1OJc2fr6T58yVJRjAo78GDTetYbdsm76FD8paUyFtSoso//1mSZEpKkmv69HBg5Zo5U9b09ESeSp8Vj3vekTfXY7VWlGEYCvayOeA6Uh0Wrg6K92B6SHtrKXVE+2tQdavrXqjphNq7VlrbYqKCCp1EQAUAAAAAAAB0g8lslmPiRDkmTtSQ66+XJPnPnlX9zp2q37lT7p07Vb/zUwVra1W3ebPqNm8OH2vLzIyqsnLm5clktyfkPAb6feR3950IP25varPziTy2+RR/88YP6XK/PakxZ+iL1V+tiVXgKEmLJw+X1H8rhC7LGxl+PGlESpv7tRb6maigQicRUAEAAAAAAAAxZh0yRKmXXqrUSy+VJBmBgLwlJeHQqn7HTnkOHJCvrEy+sjJVvf66JMlkt8s5dWpTldWsWbJmZHRrTSR0zNk6b/hxt17uiGMNo6nCZvLIFE0akdqNjnuOuZ9Vv3T35yfy+Hs/N0VS05SQ/eU1ajQ3qylETU+JDsvv+Wxe0zetTvHXsoKq/1WYIZYIqAAAAAAAAIA4M1kscuTkyJGTo8HXXSdJClRXy71rVziwqt+5U4Fz51S/Y4fqd+yQng4dax0xoqnCavoMOadNkyUlOXEn00/F6j56ZBZiyAhXIU0dlRajZ+g5vbn6ZcuhM7ru1x+Fv//qheP071+a3moYFTl9YyBoyNLJRam2lp4NP7Y2HNv4PK/vOqbHOtVb7xb5yjR/LW2W9tdpa20NKmPA12aiPQRUAAAAAAAAQAJYUlOVfPHFSr74Ykmh6dR8hw83BFY7VL9jp9yFhfKfOKHqtWtVvXZt6ECTSfbsbLmmTwsFVtOnyTllisxOZ7fG04uziB4RuVZUdwpuIg9d/ov3dOh0naTur4PUkxrH+vSHhzRxeIrsVrNOVns0erBT5+p8WjF/bMKr+iLDKUl69uPD+sIFo7Uge1iLfSOnXXx5W5lumDe2y8/b+NoUn6gJt43//utd7q+3ae8/a+R7ROtT/LWsoDpypl4VVW6NTOve+1N/tXTpUtlsNq1atUqrVq1K9HB6HAEVAAAAAAAA0AuYTCbZs7Jkz8rSoC9+UZIUrK+Xe8+epiqr3bvlP3ZM3uJieYuLVfnaX0IHN1RouWZMbwqtJk+WyWZL4Bm1ND9i+rDeZuroQdpZVimp+1PCNWoMpySpyu2PSZ/d1ZF1pRpPv9Yb0N1/3Nli++jBrvBaTL3JmVpvq+1mc9Pj45Xubj1H42tTVe/rVj+9VUfXX2vtR6S1CipJOl3jJaBqw7p165SZmZnoYSQMARUAAAAAAADQS5ldLiXNm6ekefPCbf5Tp1S/e7fcu/eEpgjcvVuB06flKSiQp6BA+uOfJIXWs3Lk5ck1fbqc06fLNWO67NnZMlksPX4e6/7lUq3de1w3XTg+Jv3Fo9jrpguz9MLmw5Jan76so9oKt9y+QDd67VnnmwFvf0V1rwyo2hJZ7dPJ2f3a7CvRFWRx014FVcRPXmtBlinqdW563NkpFTFwEFABAAAAAAAAfYg1PV2pl16q1EsvlRSqiPEfPx4KrXbtlnv3btXv2aNgZaXcn34q96efho81JSXJOXmyHFPy5MybIufUKXLk5MjsdMZ1rZgJ6cn6xuKJces/FuzWpjKbWE3x11f1t/Cltannut5X6OtAz1zOV0FlaqOaCohEQAUAAAAAAAD0YSaTSbZRo2QbNUppV1whqWE9qyNHVL9rV1Ol1d69MurqGta32tHUgdkse/YEBaZ9TTI5JEn+s2dlHdJ7p+OLh8gb6h2d5ux8/XSkvTc631j72nplkafT3bXAGgOuvrSmWGd09LRaC/oiX5PInyEzCRXaQEAFAAAAAAAA9DMmk0n2ceNkHzdOgz73OUmSEQjIe+iQ3PsK5CnYJ/e+Arn37VPgzBl5DxRrfvI2vTHhIo2oO6uiiy6WdeRIOfPy5Jg6Rc7cXDkmTZI9K6vXrWsVK5G30LtXQdW7b8Z3JFs63znEs9ouHiLzEYu57f0605e5m/30Vu39l48MJlvbL6pqKuL1sfTTMA/dR0AFAAAAAAAADAAmi0WOiRPlmDhR+nxDaGUY8p84KU/BPv1wT6GmFe/TvOIPJUn+igrVVFSoZsOGpk5sNjnGZ8mRkyP7pElyTJokx6Qc2ceNlcnac7ca41HBE1390XW9/V58Ryp/+lvBS1uVPd3pq79Ng9iovfOK/LFrLaAzt7EGVX+tNkP3EVABAAAAAAAAA5TJZJJt5AjZRo5QypIluquhPVBTK8/+Qrn37ZN73z55iorkLTqgYF2dPEUH5Ck6EN2P3S77hAmhwCpnkuwTJ8qelSV7VpbMDkfPn1gXDJR76B0Jn/rdFH9Ra1B1s6+Gr/0txGvU0dNqLehra92p/lpthu4joAIAAAAAAAAQxZKSrKQ5c5Q0Z064zTAM+cvL5TlwIPSvqOFrcbGM+np5CgvlKSyM7qhhfSz7+PGyj89q+Br6Zxs9ukerrs4n6oZ7Pw0fQjpSQdW/XoDoKf5Yg6o9HV+DqmVb9GvS9Li7rzn6r97zGwAAAAAAAABAr2UymWQbM0a2MWOUsmRJuN0IBuUrL5enqKghuCqS91CpvAcPKlhdLV95uXzl5ar98MPoDm022TMzZRs3VvYxmeG+bWPGyJY5RpbBg3t0GjVTVD7V9eft7blFR8bXy0+h02I53Vx4Dare/h+6izp67bf2sxnZElVB1U9fK3QfARUAAAAAAACALjOZzbJnZsqemanUpUvD7YZhKHD2rLyHDsl78JC8paWhx4dCjw2PR96DB+U9eFC1rfRrTkqKDq1Gj5J1xEhZR45QoMYb+/Mwtf640/20cYO/t6xZ1LEp/trfqY/N8Bc1xVx3i3kaX5te8p8z5to7r7yM1PDjiycO084j59o8ljWo0BEEVAAAAAAAAABizmQyyTp0qKxDh0ZNFSiFqq78FRWhgKqsTL6j5fIdPRr6V1Ym/8mTDetdFclTVNSi7+o5K6Vx8yRJxZ/7vKwjhss2YqSsI0aE/qUPk2XIEFmGDJV16BBZhgzp1HSC3bmd3tvvxdut518Q6Hzn8NM3CnTd3Eylp/SN9cXe238q/PjfXtujf3ttT5f76u8VVO1ZMnl4+PHXFk7Q4+uLo7a3tTYZU/yhLQRUAAAAAAAAAHqUyWyWbdQo2UaNUnIr24MeT2hqwMbgqqxMvmPH5D9xQv4TJ2SyWsL7eouL5S0ubqWXaOZBg2QdEgqrLEMbgqvBQ2ROTZUlLVU11qbqEO+hg/INHypLaopMLlevqX6Khc9MGXnefep9gfPu883ntumlOy6KxZDirsbjj1lfTWtQxazLPiPy58BuaRl0Gm08HoivFTqGgAoAAAAAAABAr2J2OOSYMEGOCRNa3Z724g5p+1FJ0rgnn5CvokL+EydDAVZFhfxnzyhw5qwCZ84oUFkpGYaClZXyVlZKhw612udJ1yBp+b9Jkg7f+A8656sLbbBaZUlJkTk1VWaXS2aXS6Ykl8yuJJmTkkJtSS6ZXKG2oNMlaWiL/oM1NfIeOiSTK0lml1Nmp1Oy2Xo8/Pre8tzz7uPxBc+7z+aDZ2IxnD7H1M8rqDp8Wq3sF4wooYp8bCahQhsIqAAAAAAAAAD0WckXtV/FYwQCClRWKnDmjPxnzihw9pwCZxsenzunYHWNAtVVqqltqhqypKVKZ91SMCj5/QqcO6fAuXMdGk9AJunq/2rRXr99u4r/Z1V0o8Uis9MZCrecTpldTpmcroY2p8xOV1Oby9WszRlqczbfzxnVp8npjHpKl80idJ05vAZV/wxd2lpDrbnWMidHRGWjI2IqSZv5/NNKYmAioAIAAAAAAADQb5kslvBaWO2tmGStrJceeleSlLN2rdJcVhl1dQpUVytYXR36WlevYH2djPr6hsetf9/qOOw2mVNSFKyvlwINYVggoGBtrVRbq/NPqtcNVz8Sflh81ecaQixXU5CVnCTL4MHhf77Tg+I5mj6taQ2qxI4jXjqau7UW0I0flhR+nDkkSXcsyVaK3SqXnVAUrSOgAgAAAAAAAIBIptANeFNysszJyVJGRueO//7rLZqSF1yo3P/9lgzDkHw+Bd1uBevdMtz1CrrdoZDL7Vawvl5Gw7agu15GfWNbfXSb2x3R1nR8+KvH0+rQfAcPnnf4tdO/KE1afN79Tj7637IOHy7riOENX0fIOmyYTNb+e9u5scKo307x1439modWP7hySrfHg/6t/75TAAAAAAAAAEAXxDN7MJlMkt0ui90uS1pa3J7HCARCQZfbLf3Xx+H2rGd+3zIIq6kOT2MYOFcpmzGyQ89x6rHHWjZaLLKNHCnbmDGhf6NHNzxu+JqRIZPNFqvT7HGmfl9B1bET66f5XI9bunSpbDabVq1apVWrVp3/gH6GgAoAAAAAAADAgNfRtXe63H8P39A3WSxNFWARkubPP++xaX/bK208f6XV4BtukP/kyaZ/p05JgYB85eXylZdLn3zS8iCzWbbRo2WfMEH28eNlnzBejvHjZZ8wQdaRI2Xq5esV9f81qDq6X/88/562bt06ZWZmJnoYCUNABQAAAAAAAAARDCPRI0isjkYPo378o6jvjWBQ/pOn5Dt6NBRSHT0a/bi8XIbHI19ZmXxlZap9//3o53W5ZM/KkiM7W47Jk+XInSxnbq6so0b1mkCoqYKqd4wn1jq+BlV8x4GBgYAKAAAAAAAAQJ8Sj/yIG+5NuvpamMxm2UaOkG3kCGnO7BbbDcNQ4NQpeQ8dkufgQXkPlcp78GDoX1mZjPp6eQoK5CkokNasCR9nTk0NBVaTc+TMzZVjcm5XT63bGoMppviL80AwIBBQAQAAAAAAAECc9bf7+V0JKEwmk6zDh8s6fHiLqQYNn0/esjJ5Dx6Sp/iAPIX75dm/X56SEgWrq1W/davqt25tOuDqR1r0b/j9nR9UJzUGU+b+mlB1EFP8IRYIqAAAAAAAAAAgUhxKtIL9bNrAWE+DaLLZ5JgwQY4JE5R62dKm5/F65Tl4MBRWFRbKvX+/3Pv2tdpH+b/+qw4O8ss5Y7qSZs2Sa84c2WK8vo8pvAZVTLvtcwb6+SM2CKgAAAAAAAAADHhOqyX82GqJ/d33DftPxrzPgcBkt8uZmytnbq70hS80bfj+6y32NQIBuXfvlnv3bp174Q+SJOvw4dLCe2I+rv66BlVHDeyzR6wQUAEAAAAAAADoU4xYl+9IGpRk0799fqrMJinZMbBvm8bh5e0RGffdpzH1R1T/6S7Vb9um+r175T8Zn2BwgM/w1+G1qoD2DOx3WgAAAAAAAABo8I+XTEj0ENANlmHDlDZjutKuvFKSFHS75d61S/ZXKuQ1YhuoDPQKqoEe0CE2CKgAAAAAAAAAAJ22t7xKRsOCXaaGSd8MGa0+bm/b+fqI3BY0jDbDob3lVcoalqQUh1VBQ/IHgkrOmS6vcSJ2J91goFcQDfTzR2wQUAEAAAAAAAAAwjo6w99V/+/9uI6js/573QH997oDPfJcDqu5R54H6M8IqAAAAAAAAAD0KX1xiaSX7rgo0UPosDqvP9FD6LUOrbxRyYsu0WX5Fyd6KECfR0AFAAAAAAAAAHGWP2Fowp770Rtn659f2N7h/UcNcsVxNH3X/938tOrLd6l+xw65Lb+RvvBgeFvxDxfJkpaWwNF13fjvv57oIWCAIqACAAAAAAAAAOA8Rv3oAWUc36Pa9zfK+/EnUdv2X7xQSfPnKXXpZUq5bKnsmZkJGiXQdxBQAQAAAAAAAEA/ZjJ1bn+jL86h2AOsQ4ZoyKLrNeT66zW01i39+ztNG/1+1X30seo++lgVDz4oR06OUi67TKmXLZVzxgyZzKxZBTRHQAUAAAAAAACgTyFAiS+jT67y1bPMdlvU9xPf/Luq161TzbvrVLd1qzxFRfIUFen0//6vLOnpSl16qVIuv1zJF18ss92emEEDvQwBFQAAAAAAAAAA3WDPytKwW2/VsFtvVeDcOdW8/76q331Xte+9r8CpUzr3xz/p3B//JHNKilIuvVSpy65QyqJFMrtY7wsDFwEVAAAAAAAAACCMCrXusQwerEFf+IIGfeELMrxe1X7yiWreeVfVa9fKf/Kkqv72N1X97W8yOZ1KWbRIqcuWKWXppbKkpCR66ECPIqACAAAAAAAAAOC8mhbz6ui6Xia7XSkLFypl4UKNvPf/qn7HTlWvXavqt96S7+jR0OO1a2Wy2ZR08UVKW7ZMKZddJuuQIXE6B6D3IKACAAAAAAAA0KdQ4BNfvL7xYTKblTRntpLmzNaIf/2e3Hv3qvqtUFjlPXhQtRveU+2G9ySLRUn585W2fLlSly2TdejQRA8diAsCKgAAAAAAAADox0zqYLlPI+b4O69Ov6bNjzeZ5Jo2Ta5p0zT8O9+W98ABVa1dq+q31spTUKC6jz5W3Ucf6/i//4eSFyxQ2lVXKvUzn5Fl8ODYnADQCxBQAQAAAAAAAABwHpHT+nV0ir+O9WuSIydHw3NyNPyb35S3tFRVb72l6jf+Lvfevar98EPVfvihjv3ox0q++CKlXXmlUi+/XJbU1NgNAkgAAioAAAAAAAAAAHoJe1aW0m+/Xem33y7voUOq+vvfVbXmDXn27w9PA3jcZlPy4sWhsGrppTInJyd62OiCpUuXymazadWqVVq1alWih9PjCKgAAAAAAAAAAGFM8Hd+MSygapd9/Hil33mn0u+8U57iYlWteUNVb7whb0mJat55RzXvvCOTw6GUJUuUdtWVSlmyRGaXq4dGh+5at26dMjMzEz2MhDEnegAAAAAAAAAA0Bn/tGSiJOm6uQP3xm483TBvbKKH0OuZIub4+/6VeT3ynI6JEzX8n+9S9ut/04TXXtOwO++QLWucDI9H1W+9paPf+T/av/ASHf3u3ap++20FPZ4eGRfQVVRQAQAAAAAAAOhTpo5O094fL5fLZkn0UPqlsUOTEj2EXqmtqqkFE4b27DhMJjlzJ8uZO1nDv/1tuffuVfUbb6jqjb/Ld/SoqtasUdWaNTInJyv1M5cr9corlXLxxTLZ7TEdR3qKXadqvJIkm6WnasrQnxBQAQAAAAAAAOhzkuzc2kTi9JY4xmQyyTVtmlzTpmn43XfLvWtXaBrAv/9d/uPHVfnaX1T52l9kHjRIqVd8RoOuukpJ+fkyWbv/82M29ZZXAX0V7+IAAAAAAAAAAHRCb8xmTCaTXBdcINcFF2jEv35P9Tt2hMOqwKlTqvzTy6r808uyDB2qtM8uV9qVV8o1d26Xn4+1ytBdBFQAAAAAAAAAAJyHqTemUm0wmc1KmjNHSXPmaOQPvq+6T7aoas0aVb/1lgJnzujs8y/o7PMvyDpihHTxvyZ6uBigzIkeAAAAAAAAAAAgfvpQrtJn9KmwymJR8oULNOrHP1LO++9p7G9/q0Ff/rLMqanynzgRte+JRx5R/Z49MozO1Ud1cndAEhVUAAAAAAAAAAAMCCabTSmLLlHKoksU/NEDqt24UXqrPrz99Orf6fTq38melaXUq66UNKnNvgil0F1UUAEAAAAAAAAAMMCY7XalXnZZVFvqsmUyORzylpbq9OO/jtrmPXSoB0eHgYAKKgAAAAAAAAAAzqPvTOrXdZn/71cK1NSqZt06Va1ZE7Wt+LNXyjl1qtKuulKpn71SEiVU6B4qqAAAAAAAAAAAgCTJkpKsQV/4vMY+/j/NNljk3rtXJx75mYo/8xkFKisTM0D0GwRUAAAAAAAAAACgXTnvv6eMBx5Q0oIFkskkw+cPbzMCAZ194QX5T59O4AjR1xBQAQAAAAAAAEA/NhCmpkP8WYcO1ZCVK5T19FOatGG9TMnJTRsNQ8d/9GMVLVqsw1/7R537058UOHcuYWNF30BABQAAAAAAAADAeZhI+sJsI0bI7HQ2NZjNck6fLgWDqv3wQx2799+0f9FiHbnjTrn37k3cQNGrEVABAAAAAAAAQD82I3NQoofQr2UNSz7/Tv3Q95bnhh/fuXSSJvzpj5r45t81/DvfkSM3V/L5VLNhg0w2WwJHid7MmugBAAAAAAAAAEB/8r3lufqvNwsTPYywzCFJevu7izXIZU/0UGLuVytnadPBM3p+0+Eef+6Pf3C56n0BDU3u26/rVy8cp2c/bv/1uyxvRIu2lfnjZLWYZTFLX5w5RpJkz8pS+p13KP3OO+QpLlbtBx/IkZMTl3Gj7yOgAgAAAAAAAIAYSrJbEj2EFiaNSE30EOLiS7PGKCPN2SMBVfMp/jIGOVvfsY8ZnnL+87CYW5/f8Lq5mW0e45g4UY6JE7s8LvR/TPEHAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAMWQYiR4BAPR+BFQAAAAAAAAAAJyHSa2vwzQQDNwzRzwRUAEAAAAAAAAA+iwK1oC+iYAKAAAAAAAAAIDzoYwIiCkCKgAAAAAAAAAAAPQoAioAAAAAAAAAAAD0KAIqAAAAAAAAAAAA9CgCKgAAAAAAAABAn2UYiR5B/2di/S3EAQEVAAAAAAAAAMQQeUn/REYDxBYBFQAAAAAAAAAAAHoUARUAAAAAAAAAAOdhYp47xNjSpUs1depUPfbYY4keSkJYEz0AAAAAAAAAAACAgWbdunXKzMxM9DAShgoqAAAAAAAAAAAA9CgCKgAAAAAAAABAn2XISPQQ+j2TmN4QsUdABQAAAAAAAAAAgB5FQAUAAAAAAAAAMWQYVPQAwPkQUAEAAAAAAAAAcB5McgfEFgEVAAAAAAAAAADnYSKhAmKKgAoAAAAAAAAA0HcxoyLQJxFQAQAAAAAAAACANlE9hnggoAIAAAAAAAAAAECPIqACAAAAAAAAAABAjyKgAgAAAAAAAAAAQI8ioAIAAAAAAAAA4DxMYiEmIJYIqAAAAAAAAAAAfZbRQ89jIp8CYoqACgAAAAAAAAAAtIlwDvFAQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5lTfQA+rtgMChJOnbsWIJHAgAAAAAAAKAnnD19SkFPXVRbWVlZgkbTNc3H31uVlZXpRMW5HhnviYrjKnO64/48Pa3ybNP12vw6bWyvrarsc9dwb9aYFzTmBwOVyTAMI9GD6M8++eQT5efnJ3oYAAAAAAAAAACgF9m8ebPmz5+f6GEkDAFVnPn9fm3fvl0jR46U2cyMio2qq6s1depU7d27V6mpqYkeDtApXL/oy7h+0Zdx/aKv4xpGX8b1i76M6xd9Gdcv+jKu37YFg0FVVFRo9uzZsloH7kR3BFRIiKqqKg0aNEiVlZVKS0tL9HCATuH6RV/G9Yu+jOsXfR3XMPoyrl/0ZVy/6Mu4ftGXcf3ifCjpAQAAAAAAAAAAQI8ioAIAAAAAAAAAAECPIqBCQjgcDt1///1yOByJHgrQaVy/6Mu4ftGXcf2ir+MaRl/G9Yu+jOsXfRnXL/oyrl+cD2tQAQAAAAAAAAAAoEdRQQUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVAhR71ySef6KqrrtLgwYOVnJysCy+8UC+99FKih4V+6ujRo/rlL3+pZcuWady4cbLb7crIyNC1116rTZs2tdj/gQcekMlkavPfoUOHWn2eN998U0uWLFFqaqrS0tK0dOlSvfPOO22Oa//+/brhhhuUnp4ul8ulmTNn6vHHH5dhGLE6dfQT48ePb/N6vPTSS1vs7/F49OMf/1g5OTlyOp0aPXq0vvGNb+jEiRNtPsdzzz2n/Px8JScna8iQIfr85z+vbdu2tbk/7+PoiKeeeqrd91OTyaTLL788vD/vv0iUZ599VnfccYfmzZsnh8Mhk8mkp556qs39q6qq9N3vfldZWVlyOBwaP368vve976mmpqbV/YPBoB599FHNmDFDLpdLw4cP14033qiSkpI2n4PrGh3V0evX5/Pp5Zdf1i233KIpU6YoJSVFqampWrBggR5//HEFAoEWxxw6dKjd9+UHHnig1TEdO3ZM//iP/6hRo0bJ6XQqNzdXP/nJT+Tz+VrdvyufXdA/dOb9t7d+Tujs7wT0H525fs/3mdhkMunIkSPh/Xn/Rbx19l6ZxGdgxJc10QPAwLFu3TotX75cTqdTK1euVGpqql5++WWtWLFCR44c0d13353oIaKfefTRR/Xwww9r4sSJWrZsmYYPH66ioiK9+uqrevXVV/X8889rxYoVLY675ZZbNH78+BbtgwcPbtH27LPP6qabbtLw4cN16623SpJefPFFXXHFFXrppZd03XXXRe2/d+9eXXzxxaqvr9cNN9yg0aNH6/XXX9c3v/lN7d27V48++mgsTh39yKBBg/Sd73ynRXvzazQYDOpLX/qS3nzzTV144YW69tprVVRUpNWrV+udd97Rxx9/rOHDh0cd85Of/ET33nuvsrKydOedd6q6ulp/+MMfdPHFF+udd97RwoULo/bnfRwdNWvWLN1///2tbvvTn/6kPXv2aPny5S228f6LnnbvvfeqtLRU6enpGl65d7QAABE4SURBVDVqlEpLS9vct7a2VkuWLNGOHTu0bNky3Xjjjdq+fbseeeQRbdiwQe+9956cTmfUMXfccYdWr16tadOm6Vvf+pbKy8v10ksv6a233tLHH3+snJycqP25rtEZHb1+i4uLdd111yklJUWXX365vvjFL6qyslJ//etf9c1vflNr1qzRX/7yF5lMphbHzpw5U1dffXWL9tb+UOb48eNasGCBysrK9OUvf1k5OTnasGGD7r33Xm3evFmvvvpq1HN05bML+o/OvP826k2fE7ryOwH9R2eu37Y+Ex84cEDPPfecpk6dqrFjx7bYzvsv4qWz98r4DIy4M4Ae4PP5jIkTJxoOh8PYvn17uP3cuXPG5MmTDbvdbhw6dChxA0S/9PLLLxvr169v0f7ee+8ZNpvNGDJkiOF2u8Pt999/vyHJWLduXYf6P3PmjDF48GAjPT3dOHLkSLj9yJEjRnp6upGenm5UVVVFHbN48WJDkrFmzZpwm8fjMRYtWmRIMj788MNOniX6s6ysLCMrK6tD+z7xxBOGJOPGG280gsFguP3xxx83JBnf+MY3ovbfv3+/YbVajcmTJxvnzp0Lt2/fvt1wOBzGlClTjEAgEG7nfRyx4PF4jGHDhhlWq9U4fvx4uJ33XyTK2rVrw+9dDz30kCHJePLJJ1vd97777jMkGffcc09U+z333GNIMh588MGo9nfffdeQZCxevNjweDzh9jVr1hiSjGXLlkXtz3WNzuro9VtWVmY89thjRk1NTVR7TU2NMW/ePEOS8dJLL0VtO3jwoCHJuOWWWzo8nptvvtmQZDz++OPhtmAwaKxcudKQZDz//PNR+3f2swv6l868//bGzwmd/Z2A/qUz129b7rrrLkOS8bOf/SyqnfdfxFtn75XxGRjxRkCFHvHmm28akozbbrutxbannnrKkGT86Ec/SsDIMFAtW7bMkGR88skn4bbO/o/P//7v/7Z57T7wwAOGJOPpp58OtxUWFhqSjKVLl7bYf/369W3+jGDg6kxAddFFFxmSWoREwWDQyM7ONpKTk426urpw+w9+8IMW12ijW2+91ZBkbNiwIdzG+zhi4cUXXzQkGVdffXVUO++/6A3au8EUDAaN0aNHGykpKa3e5E9JSTGys7Oj2m+88cYW76WNLr30UkOSUVpaGm7jukZ3dPUG6fPPP29IMlatWhXV3tkbpFVVVYbD4TCys7OjbnYahmEcOnSo1Wu1s59d0H/FOqCK9/tpV34noP/qyvtvfX29MWTIEMNutxsnTpyI2sb7LxKp+b0yPgOjJ7AGFXrE+vXrJUnLli1rsa1xip8NGzb05JAwwNlsNkmS1dpyptP33ntPDz/8sP7rv/5Lr776aptz6nb2um5v/0suuUTJycn8HKAFj8ejp556Sg8++KD++7//u9U5od1utzZt2qTc3FxlZWVFbTOZTLriiitUW1urLVu2hNtjef3yPo6OWr16tSTp61//eqvbef9Fb1VUVKTy8nItXLhQycnJUduSk5O1cOFClZSURK0hsX79+vC25mLxPst1jVho7zOxJJWXl+uxxx7Tgw8+qN/97ncqLi5udb+PPvpIHo9HV1xxRYupArOyspSbm6sPPvggvN5VVz67AL3lc0JXficAkV555RWdPXtWX/ziF9ucSo/3XyRC888FfAZGT2ANKvSIoqIiSWoxx6gkZWRkKCUlJbwPEG+HDx/W22+/rVGjRmnGjBkttjefI3rw4MH61a9+pZtvvjmqvb3rurEt8rpub3+LxaIJEyZo79698vv9bd4kwMBz/Phx3XbbbVFt8+fP1wsvvKCJEydKCq0tEQwGW722pOjrcdGiReHHKSkpysjIaHf/RryPo7tKS0v1zjvvKDMzU5/97Gdb3Yf3X/RW7V1Dje1vvvmmioqKNHbsWNXW1urYsWOaPn26LBZLq/tH9nu+5+C6Rrw88cQTklq/ySNJa9eu1dq1a8Pfm0wm/cM//IN+/etfR92o6sjPSGFhoUpLS5Wdnd2lzy5Ab/mc0NnfCUBzv/vd7yS1/UdbEu+/6Hmt3SvjMzB6AhVU6BGVlZWSpEGDBrW6PS0tLbwPEE8+n0833XSTPB6PHn744ahfmDNnztQTTzyhkpIS1dfX6+DBg3r00UdlMpl066236i9/+UtUX+1d12lpaVH7nG//xmOCwaCqq6u7d5LoN2677Ta98847qqioUG1trbZv366bbrpJn3zyiS6//PLwtdKRaytyv8bHnd3/fM/B+zja8+STTyoYDOrWW29t8T8rvP+it+vs+2xX35fbOobrGvHwm9/8Rm+88YYuu+wyXXXVVVHbkpKS9G//9m/aunWrzp07pzNnzujtt99Wfn6+nn322RaBQE/8jGDg6m2fE7h+0R0HDx7UunXrNG7cOF1xxRUttvP+i0Ro614Zn4HRE4gRAQwYjTdG33vvPd1+++266aaborZ/+ctfjvp+/PjxuuuuuzRlyhRdccUVuvfee/XFL36xJ4eMAa75X4nOmjVLv//97yVJzzzzjH7729/qu9/9biKGBnRKMBjUk08+KZPJpK997WsttvP+CwA9629/+5vuuusuZWVl6dlnn22xfcSIEfrxj38c1Xb55Zfroosu0pw5c/TKK69o27ZtmjNnTk8NGQMYnxPQnzzxxBMyDEO33XabzOaWdQO8/6Knne9eGRBvVFChRzSm2m39BUZVVVWbyTcQC8FgUF/72tf0/PPP66tf/ap+/etfd/jYyy+/XBMnTtSuXbtUVVUVbm/vum7cL/K67sjPgclkUmpqaofHhoHpjjvukCR98MEHkjp2bUXu1/i4s/uf7zl4H0db3n77bR0+fFiXXXaZJkyY0OHjeP9Fb9HZ99muvi+3dQzXNWJpzZo1uu666zRy5Ei9++67GjVqVIePTUpKCt+4avwcIvXMzwjQXKI+J3D9oquCwaCeeuopmc3mVv9oqz28/yIeznevjM/A6AkEVOgRrc0Z2uj48eOqqalpcz5ToLuCwaBuu+02Pf3007rxxhvDHwg7Iz09XZJUV1cXbmvvum5tTtz29g8EAjp48KAmTJjAHLk4r8brsba2VpKUnZ0ts9nc5hpQbV2PNTU1On78eIf3j9wWifdxnM/q1asltT/Pflt4/0Vv0N41FNneuF9ycrJGjRqlgwcPhhclb2//8z0H1zVi5fXXX9c111yj9PR0rVu3TtnZ2Z3uo/nnEKljPyN2u13jxo2T1LXPLkBrEvE5obO/E4BGf//731VWVqYrrrgi/H7YGbz/IpY6cq+Mz8DoCQRU6BFLliyRJL311lsttr355ptR+wCx1PgL9/e//71WrFihZ555ptWFGttTW1urPXv2KDk5OfyBUOr8dd3e/hs3blRtbS0/B+iQTZs2SQpNbyJJLpdL+fn54cVvIxmGobVr1yo5OVnz5s0Lt8fy+uV9HO05ffq0XnvtNQ0dOrTFFD3nw/sveoucnByNHj1aH3zwQdRNISl0nX7wwQeaMGGCxo4dG25fsmRJeFtzjdfp4sWLo/aXuK4RP6+//rquvfZaDR06VOvWrdOkSZO61E/zzyGSdOGFF8put2vt2rUyDCNq/9LSUhUWFmrhwoXhm0Vd+ewCNJeozwld+Z0ASNLvfvc7SV37oy2J91/ETkfvlfEZGD3CAHqAz+czsrOzDYfDYWzfvj3cfu7cOWPy5MmG3W43Dh48mLDxoX8KBALGLbfcYkgyrr/+esPn87W5b1VVlVFYWNiiva6uzrjxxhsNScZtt90Wte3MmTPGoEGDjPT0dOPIkSPh9iNHjhjp6elGenq6UVVVFXXM4sWLDUnGmjVrwm0ej8dYtGiRIcn44IMPunq66Gf27dtn1NbWttqekZFhSDI2bNgQbn/iiScMScaNN95oBIPBcPvjjz9uSDK+8Y1vRPVTWFhoWK1WY/Lkyca5c+fC7du3bzccDocxZcoUIxAIhNt5H0dX/eIXvzAkGd/61rda3c77L3qLhx56yJBkPPnkk61uv++++wxJxj333BPVfs899xiSjAcffDCq/d133zUkGYsXLzY8Hk+4fc2aNYYkY9myZVH7c12jO853/a5Zs8ZwOBxGRkaGUVBQcN7+tm3bFvV5otHLL79smM1mY8iQIVGfHwzDMG6++WZDkvH444+H24LBYPi9/Pnnn4/av7OfXdB/tXf99tbPCZ39nYD+63zvv41OnDhh2Gw2Y/jw4VGfC5rj/Rfx1pl7ZYbBZ2DEn8kwmsXrQJysW7dOy5cvl9Pp1MqVK5WamqqXX35ZpaWleuSRR3T33XcneojoZx544AH96Ec/UkpKir797W+3Wt579dVXa9asWTp06JCys7M1f/58TZkyRRkZGaqoqNDbb7+tsrIyzZgxQ+vWrdOwYcOijn/22Wd10003afjw4VqxYoUk6cUXX9SpU6f04osv6vrrr4/af8+ePVq4cKHq6+u1YsUKjRo1Sq+//rr27Nmju+66S48++mj8XhD0KQ888IB+/vOfa/HixcrKylJycrL279+vNWvWyOfz6Qc/+IEefPDB8P7BYFBXXXWV3nzzTV144YVasmSJDhw4oFdeeUXjx4/Xpk2bNHz48Kjn+MlPfqJ7771XWVlZuvbaa1VdXa0//OEP8nq9euedd7Rw4cKo/XkfR1fMmDFDu3fv1qeffqoZM2a02M77LxJp9erV2rhxoyRp165d2rZtmxYuXBiuLLnkkkvCf+VcW1urhQsXaufOnVq2bJnmzJmjbdu26a233tL8+fO1YcMGuVyuqP5vv/12rV69WtOmTdPnPvc5HTt2TC+++KJSUlL00UcfafLkyVH7c12jMzp6/RYUFGjWrFnyeDxauXKlcnNzW/Q1fvx43XrrreHvL730UhUXF+uiiy5SZmamAoGAtm3bpo0bN8rhcOill17SF7/4xag+jh07pgULFqisrEzXXHONJk2apA0bNujjjz/WF77wBb322msymUzh/bvy2QX9R0ev3976OaErvxPQf3Tm80Ojn/3sZ/qXf/kXffe739XPfvazNvvm/Rfx1pl7ZRKfgdEDEp2QYWDZtGmT8dnPftZIS0szXC6XkZ+fb/zhD39I9LDQTzX+RUh7/xr/yqmystJYtWqVMX/+fGP48OGG1Wo1UlNTjfz8fOM///M/jbq6ujaf54033jAWLVpkJCcnGykpKcaSJUuMtWvXtrl/QUGBcd111xlDhw41HA6HMWPGDOOxxx5r9a+kMHCtX7/euOGGG4ycnBwjLS3NsFqtRkZGhvGlL33JePPNN1s9xu12Gw888IAxceJEw263GxkZGcbXv/514/jx420+z7PPPmvMmzfPcLlcxqBBg4yrrrrK2Lp1a5v78z6Ozti0aZMhycjPz29zH95/kUjn+6xwyy23RO1/7tw54zvf+Y4xduxYw2azGePGjTPuvvvuFn/V2SgQCBi/+tWvjGnTphkOh8MYNmyYsWLFCuPAgQNtjonrGh3V0et33bp15/1MvGTJkqi+f/vb3xqf/exnjbFjxxoul8twOBxGdna28fWvf93Yt29fm2MqLy83vva1rxkjR4407Ha7kZOTY/z7v/97m9UCXfnsgv6ho9dvb/6c0NnfCeg/Ovv5wTAMY8qUKYYkY+/eve32zfsv4q0z98oa8RkY8UQFFQAAAAAAAAAAAHqUOdEDAAAAAAAAAAAAwMBCQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHvX/ATVThAbN3UY/AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "N = 20000\n", + "fig, ax = plt.subplots()\n", + "lns = ax.semilogy(range(5000), jnp.linalg.norm(all_b1_params_array[0,:5000,:]-jnp.ones(2,),axis=1),label=f'Adam b1 = 0.9')\n", + "for i,b1 in enumerate([0.99,0.999,0.9999]):\n", + " lns += ax.semilogy(\n", + " range(N), \n", + " jnp.sqrt(jnp.linalg.norm(all_b1_params_array[i+1,:N,:]-jnp.ones(2,),axis=1)),label=f'Adam b1 = {b1}'\n", + " )\n", + "ax1 = ax.twinx()\n", + "for i,b3 in enumerate([0.999,0.9999]):\n", + " lns += ax1.semilogy(\n", + " range(N), \n", + " jnp.sqrt(jnp.linalg.norm(all_ademamix_params_array[i,:N,:]-jnp.ones(2,),axis=1)),label=f'AdeMAMix b3 = {b3}'\n", + " )\n", + "labs = [l.get_label() for l in lns]\n", + "ax.legend(lns, labs, loc=0)\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "id": "fbdc5a55-e55a-47d1-9980-2b602de6ee3b", @@ -224,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 7, "id": "2cf96de0-cb01-4338-87b4-dd80f0498ebd", "metadata": {}, "outputs": [ @@ -232,6 +277,11 @@ "name": "stdout", "output_type": "stream", "text": [ + "Adam Values:\n", + "Final value with b1 = 0.9: ((0.9999862909317017, 0.9999725818634033))\n", + "Final value with b1 = 0.99: ((0.9999871850013733, 0.9999743103981018))\n", + "Final value with b1 = 0.999: ((1.0000061988830566, 1.0000123977661133))\n", + "Final value with b1 = 0.9999: ((0.9527199268341064, 0.9080769419670105))\n", "AdeMAMix Values:\n", "Final value with b3 = 0.999: ((1.0000168085098267, 0.9999828934669495))\n", "Final value with b3 = 0.9999: ((1.0000070333480835, 0.9999932050704956))\n" @@ -241,37 +291,16 @@ "source": [ "print(\"Adam Values:\")\n", "[0.9,0.99,0.999,0.9999]\n", - "print(f\"Final value with b1 = 0.9: ({float(all_ademamix_params_array[0,-1,0]),float(all_ademamix_params_array[0,-1,1])})\")\n", - "print(f\"Final value with b1 = 0.99: ({float(all_ademamix_params_array[1,-1,0]),float(all_ademamix_params_array[1,-1,1])})\")\n", - "print(f\"Final value with b1 = 0.999: ({float(all_ademamix_params_array[0,-1,0]),float(all_ademamix_params_array[0,-1,1])})\")\n", - "print(f\"Final value with b1 = 0.9999: ({float(all_ademamix_params_array[1,-1,0]),float(all_ademamix_params_array[1,-1,1])})\")\n", + "print(f\"Final value with b1 = 0.9: ({float(all_b1_params_array[0,-1,0]),float(all_b1_params_array[0,-1,1])})\")\n", + "print(f\"Final value with b1 = 0.99: ({float(all_b1_params_array[1,-1,0]),float(all_b1_params_array[1,-1,1])})\")\n", + "print(f\"Final value with b1 = 0.999: ({float(all_b1_params_array[2,-1,0]),float(all_b1_params_array[2,-1,1])})\")\n", + "print(f\"Final value with b1 = 0.9999: ({float(all_b1_params_array[3,-1,0]),float(all_b1_params_array[3,-1,1])})\")\n", "\n", "print(\"AdeMAMix Values:\")\n", "print(f\"Final value with b3 = 0.999: ({float(all_ademamix_params_array[0,-1,0]),float(all_ademamix_params_array[0,-1,1])})\")\n", "print(f\"Final value with b3 = 0.9999: ({float(all_ademamix_params_array[1,-1,0]),float(all_ademamix_params_array[1,-1,1])})\")" ] }, - { - "cell_type": "code", - "execution_count": 13, - "id": "66647d3e-81e2-4987-b5ef-81e08ac048dc", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(2, 100001, 2)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "all_b1_params_array.shape" - ] - }, { "cell_type": "code", "execution_count": null, From c252b51f91c2282d53c867ff3939cd34ca9ed8ce Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Fri, 18 Oct 2024 11:45:17 -0400 Subject: [PATCH 05/32] implementing pr feedback --- docs/api/contrib.rst | 12 ++++++------ optax/contrib/_ademamix.py | 13 +++++-------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst index 232104c76..94262726d 100644 --- a/docs/api/contrib.rst +++ b/docs/api/contrib.rst @@ -38,6 +38,12 @@ Experimental features and algorithms that don't meet the split_real_and_imaginary SplitRealAndImaginaryState +AdEMAMix +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: ademamix +.. autofunction:: scale_by_ademamix +.. autoclass:: ScaleByAdemamixState + Asynchronous-centering-Prop ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: acprop @@ -83,12 +89,6 @@ Momo .. autofunction:: momo_adam .. autoclass:: MomoAdamState -Multiple EMA AdEMAMix -~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: ademamix -.. autofunction:: scale_by_ademamix -.. autoclass:: ScaleByAdemamixState - Prodigy ~~~~~~~ .. autofunction:: prodigy diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 7b31ffcb1..4ffa002e3 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -1,4 +1,4 @@ -"""AdeMAMix. +"""AdEMAMix. Implementation of "THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER" @@ -19,7 +19,7 @@ class ScaleByAdemamixState(NamedTuple): """State for the Ademamix algorithm.""" - count: chex.Array + count: chex.Array # shape=(), dtype=jnp.int32. count_m2: chex.Array m1: base.Updates m2: base.Updates @@ -29,10 +29,8 @@ class ScaleByAdemamixState(NamedTuple): def scale_by_ademamix( b1: float = 0.9, b2: float = 0.999, - b3: float = 0.9999, - alpha: float = 5.0, - b3_scheduler: Optional[base.ScalarOrSchedule] = None, - alpha_scheduler: Optional[base.ScalarOrSchedule] = None, + b3: base.ScalarOrSchedule = 0.9999, + alpha: base.ScalarOrSchedule = 5.0, eps: float = 1e-8, ) -> base.GradientTransformation: """Rescale updates according to the Ademamix algorithm. @@ -49,8 +47,6 @@ def scale_by_ademamix( for the second EMA. alpha: the coefficient that "blends" the two EMAs. paper states values in :math:`[4,10]` work well in practice. - b3_scheduler: The schedule for the b3 parameter - alpha_scheduler: The schedule for the alpha parameter eps: A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. @@ -93,6 +89,7 @@ def update_fn(updates, state, params=None): count_inc = numerics.safe_int32_increment(state.count) count_m2_inc = numerics.safe_int32_increment(state.count_m2) m1_hat = otu.tree_bias_correction(m1, b1, count_inc) + # NOTE: AdEMAMix does not perform bias correction on b2. nu_hat = otu.tree_bias_correction(nu, b2, count_inc) updates = jtu.tree_map( lambda m1_, m2_, v_: (m1_ + c_alpha * m2_) / (jnp.sqrt(v_) + eps), From f9b6559bdf7b32111f09b54b9ed5a04cfecb0139 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 12:01:53 -0400 Subject: [PATCH 06/32] updated ademamix with author docstrings --- optax/contrib/_ademamix.py | 173 +++++++++++++++++++++++++++---------- 1 file changed, 126 insertions(+), 47 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 4ffa002e3..e290396ff 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -5,8 +5,7 @@ (https://arxiv.org/pdf/2409.03137) by Matteo Pagliardini, Pierre Ablin and David Grangier. """ -import optax.tree_utils as otu -from typing import NamedTuple, Optional + import chex import jax.numpy as jnp import jax.tree_util as jtu @@ -14,16 +13,73 @@ from optax._src import combine from optax._src import numerics from optax._src import transform +import optax.tree_utils as otu +from jax.lax import rsqrt +from typing import NamedTuple, Optional, Tuple + + +def alpha_scheduler(alpha, alpha_start: float = 0, T_alpha: int = 0) -> base.Schedule: + """The alpha scheduler from the paper. + + This is a progressive increase in alpha using a linear scheduler. + + Args: + alpha: The current value of alpha (the coefficient that "blends" the two EMAs) + alpha_start: The starting value of alpha + T_alpha: The warmup time for alpha to reach it's final value. + + Returns: + A `base.Schedule` object. + + """ + + def schedule(step: int) -> float: + is_warmup: float = (step < T_alpha).astype(jnp.float32) + a: float = step / float(T_alpha) + return is_warmup * ((1.0 - a) * alpha_start + a * alpha) + alpha * (1.0 - is_warmup) + + return schedule + + +def b3_scheduler(beta_end: float, beta_start: float = 0, T_b3: int = 0): + """The b3 scheduler from the paper. + + This is a progressive increase in b3 attempting to increase t_half linearly + (Appendix A.1 of the paper derives the scheduler.) + + Args: + beta_end: The current value of b3 (the exponential decay rate to track the + first moment of past gradients for the second EMA) + beta_start: The starting value of b3 + T_b3: The warmup time for b3 to reach it's maximal value. + + Returns: + A `base.Schedule` object. + + """ + + def f(beta: float) -> float: + return jnp.log(0.5) / jnp.log(beta) - 1 + + def f_inv(t: float) -> float: + return rsqrt(t + 1) + + def schedule(step: int) -> float: + is_warmup = (step < T_b3).astype(jnp.float32) + alpha = step / float(T_b3) + return is_warmup * f_inv((1.0 - alpha) * f(beta_start) + alpha * f(beta_end)) + beta_end * (1.0 - is_warmup) + + return schedule class ScaleByAdemamixState(NamedTuple): - """State for the Ademamix algorithm.""" + """State for the Ademamix algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - count_m2: chex.Array - m1: base.Updates - m2: base.Updates - nu: base.Updates + count: chex.Array # shape=(), dtype=jnp.int32. + count_m2: chex.Array # shape=(), dtype=jnp.int32. + m1: base.Updates + m2: base.Updates + nu: base.Updates def scale_by_ademamix( @@ -40,15 +96,15 @@ def scale_by_ademamix( Args: b1: Exponential decay rate to track the first moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW + the first Exponential Moving Average (EMA) - same as AdamW b2: Exponential decay rate to track the second moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW + the first Exponential Moving Average (EMA) - same as AdamW b3: Exponential decay rate to track the first moment of past gradients - for the second EMA. + for the second EMA. alpha: the coefficient that "blends" the two EMAs. paper states values in - :math:`[4,10]` work well in practice. + :math:`[4,10]` work well in practice. eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. + (as in the Adam paper) to avoid dividing by zero when rescaling. Returns: A `GradientTransformation` object. @@ -59,7 +115,7 @@ def scale_by_ademamix( smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations scenarios. Moreover, retaining gradient information over many thousands steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is + distribution shift, or general cases in which the distribution is non-stationary. """ @@ -75,21 +131,24 @@ def init_fn(params): nu=nu, ) - def update_fn(updates, state, params=None): + def update_fn( + updates: jtu.tree_map, state, params=None + ) -> Tuple[jtu.tree_map, ScaleByAdemamixState]: del params - c_b3 = b3_scheduler(state.count_m2) if b3_scheduler is not None else b3 + c_b3 = b3_scheduler(state.count_m2) if callable(b3_scheduler) else b3 c_alpha = ( - alpha_scheduler(state.count_m2) if alpha_scheduler is not None else alpha + alpha_scheduler(state.count_m2) if callable(alpha_scheduler) else alpha ) m1 = otu.tree_update_moment( - updates, state.m1, b1, 1 + updates, state.m1, b1, order=1 ) # m1 = b1 * m1 + (1-b1) * updates - m2 = otu.tree_update_moment(updates, state.m2, c_b3, 1) - nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, b2, 2) + m2 = otu.tree_update_moment(updates, state.m2, c_b3, order=1) + nu = otu.tree_update_moment_per_elem_norm(updates, state.nu, b2, order=2) count_inc = numerics.safe_int32_increment(state.count) count_m2_inc = numerics.safe_int32_increment(state.count_m2) m1_hat = otu.tree_bias_correction(m1, b1, count_inc) - # NOTE: AdEMAMix does not perform bias correction on b2. + # NOTE: AdEMAMix does not perform bias correction on b2 to let the momentum + # buffer fill itself slowly. nu_hat = otu.tree_bias_correction(nu, b2, count_inc) updates = jtu.tree_map( lambda m1_, m2_, v_: (m1_ + c_alpha * m2_) / (jnp.sqrt(v_) + eps), @@ -108,14 +167,12 @@ def ademamix( learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, - b3: float = 0.9999, - alpha: float = 5.0, - b3_scheduler: Optional[base.ScalarOrSchedule] = None, - alpha_scheduler: Optional[base.ScalarOrSchedule] = None, + b3: base.ScalarOrSchedule = 0.9999, + alpha: base.ScalarOrSchedule = 5.0, eps: float = 1e-8, weight_decay: float = 0.0, ) -> base.GradientTransformation: - """The Ademamix optimiser. + """The Ademamix optimizer. Description @@ -123,12 +180,12 @@ def ademamix( > import optax > import jax > import jax.numpy as jnp - > def f(x): return jnp.sum(x ** 2) # simple quadratic function + > def f(x): return jnp.sum(x ** 2) # simple quadratic functio > solver = optax.ademamix(learning_rate=0.003) > params = jnp.array([1., 2., 3.]) > print('Objective function: ', f(params)) Objective function: 14.0 - > opt_state = solver.init(params) + > opt_state = solver.init(params > for _ in range(5): ... grad = jax.grad(f)(params) ... updates, opt_state = solver.update(grad, opt_state, params) @@ -145,35 +202,57 @@ def ademamix( Args: b1: Exponential decay rate to track the first moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW + the first Exponential Moving Average (EMA) - same as AdamW b2: Exponential decay rate to track the second moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW + the first Exponential Moving Average (EMA) - same as AdamW b3: Exponential decay rate to track the first moment of past gradients - for the second EMA. + for the second EMA. alpha: the coefficient that "blends" the two EMAs. paper states values in - :math:`[4,10]` work well in practice. - b3_scheduler: The schedule for the b3 parameter - alpha_scheduler: The schedule for the alpha parameter + :math:`[4,10]` work well in practice. eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. + (as in the Adam paper) to avoid dividing by zero when rescaling. weight_decay: Strength of the weight decay regularization. Returns: A `GradientTransformation` object. Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, - the method is best suited to settings where the number of iterations is - important. The paper reports on this effect in App. C.1.5, showing how - smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is - non-stationary. + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in App. C.1.5, showing how + smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands of + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is + non-stationary. """ return combine.chain( - scale_by_ademamix( - b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps - ), - transform.add_decayed_weights(weight_decay), - transform.scale_by_learning_rate(learning_rate), + scale_by_ademamix(b1, b2, b3, alpha, eps), + transform.add_decayed_weights(weight_decay), + transform.scale_by_learning_rate(learning_rate), ) + + +if __name__ == "__main__": # dummy test + import jax + import jax.numpy as jnp + def f(x): + return jnp.sum(x**2) # simple quadratic function + + alpha = 8.0 + b1, b2, b3 = 0.9, 0.999, 0.9999 + + f_a = alpha_scheduler(alpha, alpha_start=0, T_alpha=10) + f_b3 = b3_scheduler(b3, beta_start=b1, T_b3=10) + + solver = ademamix(learning_rate=0.01, b1=b1, b2=b2, b3=f_b3, alpha=f_a, weight_decay=0.01) + + params = jnp.array([1.0, 2.0, 3.0]) + print("Objective function: {:.2f}".format(f(params))) + opt_state = solver.init(params) + for itr in range(100): + grad = jax.grad(f)(params) + updates, opt_state = solver.update(grad, opt_state, params) + params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates) + if itr % 5 == 0: + print("Objective function: {:.2f}".format(f(params))) + print(params) From 4b621aaca6561ba0d7315916f1f33677fe0a181e Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 16:26:35 -0400 Subject: [PATCH 07/32] added docstrings and matched adamw api --- optax/contrib/_ademamix.py | 245 +++++++++++++++++++++---------------- 1 file changed, 140 insertions(+), 105 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index e290396ff..fba06150a 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -15,18 +15,23 @@ from optax._src import transform import optax.tree_utils as otu from jax.lax import rsqrt -from typing import NamedTuple, Optional, Tuple +from typing import NamedTuple, Tuple -def alpha_scheduler(alpha, alpha_start: float = 0, T_alpha: int = 0) -> base.Schedule: +def alpha_scheduler( + alpha_value, + alpha_start: float = 0, + warmup_alpha: int = 0 +) -> base.Schedule: """The alpha scheduler from the paper. This is a progressive increase in alpha using a linear scheduler. Args: - alpha: The current value of alpha (the coefficient that "blends" the two EMAs) + alpha_value: The current value of alpha (the coefficient that "blends" + the two EMAs) alpha_start: The starting value of alpha - T_alpha: The warmup time for alpha to reach it's final value. + warmup_alpha: The warmup time for alpha to reach it's final value. Returns: A `base.Schedule` object. @@ -34,60 +39,75 @@ def alpha_scheduler(alpha, alpha_start: float = 0, T_alpha: int = 0) -> base.Sch """ def schedule(step: int) -> float: - is_warmup: float = (step < T_alpha).astype(jnp.float32) - a: float = step / float(T_alpha) - return is_warmup * ((1.0 - a) * alpha_start + a * alpha) + alpha * (1.0 - is_warmup) - + is_warmup: float = jnp.array(step < warmup_alpha).astype(jnp.float32) + a: float = step / float(warmup_alpha) + return ( + is_warmup * ((1.0 - a) * alpha_start + a * alpha_value) + + alpha_value * (1.0 - is_warmup) + ) return schedule -def b3_scheduler(beta_end: float, beta_start: float = 0, T_b3: int = 0): +def b3_scheduler( + beta_end: float, + beta_start: float = 0, + warmup_b3: int = 0 +) -> base.Schedule: """The b3 scheduler from the paper. - This is a progressive increase in b3 attempting to increase t_half linearly + This is a progressive increase in b3 attempting to increase the number + of iterations corresponding to where half of the mass of the second EMA + is concentrated (denoted ``t_half`` in the paper). This scheduler attempts + to increase this value linearly. Note for ``b3 = 0.9999, t_half`` is + approximately ``6930.`` (Appendix A.1 of the paper derives the scheduler.) Args: - beta_end: The current value of b3 (the exponential decay rate to track the - first moment of past gradients for the second EMA) + beta_end: The desired ending value of b3 (the exponential decay rate to + track the first moment of past gradients for the second EMA) beta_start: The starting value of b3 - T_b3: The warmup time for b3 to reach it's maximal value. + warmup_b3: The warmup time for b3 to reach it's maximal value. Returns: A `base.Schedule` object. """ - def f(beta: float) -> float: + def fun(beta: float) -> float: return jnp.log(0.5) / jnp.log(beta) - 1 def f_inv(t: float) -> float: return rsqrt(t + 1) def schedule(step: int) -> float: - is_warmup = (step < T_b3).astype(jnp.float32) - alpha = step / float(T_b3) - return is_warmup * f_inv((1.0 - alpha) * f(beta_start) + alpha * f(beta_end)) + beta_end * (1.0 - is_warmup) - + is_warmup = jnp.array(step < warmup_b3).astype(jnp.float32) + step_over_warmup = step / float(warmup_b3) + return ( + is_warmup * f_inv((1.0 - step_over_warmup) * fun(beta_start) + + step_over_warmup * fun(beta_end)) + + beta_end * (1.0 - is_warmup) + ) return schedule class ScaleByAdemamixState(NamedTuple): - """State for the Ademamix algorithm.""" + """State for the Ademamix algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - count_m2: chex.Array # shape=(), dtype=jnp.int32. - m1: base.Updates - m2: base.Updates - nu: base.Updates + count: chex.Array # shape=(), dtype=jnp.int32. + count_m2: chex.Array # shape=(), dtype=jnp.int32. + m1: base.Updates + m2: base.Updates + nu: base.Updates def scale_by_ademamix( - b1: float = 0.9, - b2: float = 0.999, - b3: base.ScalarOrSchedule = 0.9999, - alpha: base.ScalarOrSchedule = 5.0, + b1: float, + b2: float, + b3: base.ScalarOrSchedule, + alpha: base.ScalarOrSchedule, eps: float = 1e-8, + eps_root: float = 0.0, + weight_decay: float=0.0, ) -> base.GradientTransformation: """Rescale updates according to the Ademamix algorithm. @@ -105,18 +125,14 @@ def scale_by_ademamix( :math:`[4,10]` work well in practice. eps: A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. + eps_root: Term added to the denominator inside of the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + weight_decay: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. Returns: A `GradientTransformation` object. - Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, - the method is best suited to settings where the number of iterations is - important. The paper reports on this effect in App. C.1.5, showing how - smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is - non-stationary. """ def init_fn(params): @@ -135,9 +151,9 @@ def update_fn( updates: jtu.tree_map, state, params=None ) -> Tuple[jtu.tree_map, ScaleByAdemamixState]: del params - c_b3 = b3_scheduler(state.count_m2) if callable(b3_scheduler) else b3 + c_b3 = b3_scheduler(state.count_m2) if callable(b3) else b3 c_alpha = ( - alpha_scheduler(state.count_m2) if callable(alpha_scheduler) else alpha + alpha_scheduler(state.count_m2) if callable(alpha) else alpha ) m1 = otu.tree_update_moment( updates, state.m1, b1, order=1 @@ -170,89 +186,108 @@ def ademamix( b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 5.0, eps: float = 1e-8, + eps_root: float = 0.0, weight_decay: float = 0.0, ) -> base.GradientTransformation: - """The Ademamix optimizer. +r"""AdEMAMix. + + AdEMAMix (Adaptive EMA Mixture) is AdamW with a mixture of two momentum + terms to better take advantage of historical gradients. + + Both SGD with momemtum (SGD+M) and Adam incorporate momentum using + Exponential Moving Averages (EMAs) of past gradients + + Let :math:`\eta` represent the learning rate and :math:`\beta_1, \beta_2`, + :math:`\beta_3, \alpha, \varepsilon, \bar{\varepsilon}`, represent the + arguments ``b1``, ``b2``, ``b3``, ``alpha``, ``eps`` and ``eps_root`` + respectively. Let :math:`\lambda` be the weight decay and :math:`\theta_t` + the parameter vector at time :math:`t`. + + The ``init`` function of this optimizer initializes an internal state + :math:`S_0 := (m1_0, m2_0, v_0) = (0, 0, 0)`, representing initial estimates + for the first and second moments. In practice these values are stored as pytrees + containing all zeros, with the same shape as the model updates. + At step :math:`t`, the ``update`` function of this optimizer takes as + arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t` + and the parameters :math:`\theta_t` and computes updates :math:`\theta_{t+1}` and + new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have, + + .. math:: + + \begin{align*} + m1_t &\leftarrow \beta_1 \cdot m1_{t-1} + (1-\beta_1) \cdot g_t \\ + m2_t &\leftarrow \beta_3 \cdot m2_{t-1} + (1-\beta_3) \cdot g_t \\ + v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ + \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ + \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ + \theta_t &\leftarrow \theta_{t-1} - \eta \cdot \left( (\hat{m1}_t + + \alpha m2_t) / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon\right) + + \lambda \theta_{t-1} \right).\\ + S_t &\leftarrow (m1_t, m2_t, v_t). + \end{align*} - Description + Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in Appendix C.1.5, showing how + smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands of + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is + non-stationary. Examples: - > import optax - > import jax - > import jax.numpy as jnp - > def f(x): return jnp.sum(x ** 2) # simple quadratic functio - > solver = optax.ademamix(learning_rate=0.003) - > params = jnp.array([1., 2., 3.]) - > print('Objective function: ', f(params)) - Objective function: 14.0 - > opt_state = solver.init(params - > for _ in range(5): - ... grad = jax.grad(f)(params) - ... updates, opt_state = solver.update(grad, opt_state, params) - ... params = optax.apply_updates(params, updates) - ... print('Objective function: {:.2E}'.format(f(params))) - Objective function: 1.40E+01 - Objective function: 1.39E+01 - Objective function: 1.39E+01 - Objective function: 1.39E+01 - Objective function: 1.38E+01 + >>> import optax + >>> import jax + >>> import jax.numpy as jnp + >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function + >>> solver = optax.adamw(learning_rate=0.003) + >>> params = jnp.array([1., 2., 3.]) + >>> print('Objective function: ', f(params)) + Objective function: 14.0 + >>> opt_state = solver.init(params) + >>> for _ in range(5): + ... grad = jax.grad(f)(params) + ... updates, opt_state = solver.update(grad, opt_state, params) + ... params = optax.apply_updates(params, updates) + ... print('Objective function: {:.2E}'.format(f(params))) + Objective function: 1.40E+01 + Objective function: 1.39E+01 + Objective function: 1.39E+01 + Objective function: 1.39E+01 + Objective function: 1.38E+01 References: - Pagliardini et al, 2024: https://arxiv.org/pdf/2409.03137 + "THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER" + (https://arxiv.org/pdf/2409.03137) by Matteo Pagliardini, + Pierre Ablin and David Grangier. Args: - b1: Exponential decay rate to track the first moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b2: Exponential decay rate to track the second moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b3: Exponential decay rate to track the first moment of past gradients - for the second EMA. - alpha: the coefficient that "blends" the two EMAs. paper states values in - :math:`[4,10]` work well in practice. + learning_rate: A global scaling factor, either fixed or evolving along + iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. + b1: Exponential decay rate to track the fast EMA. + b2: Exponential decay rate to track the second moment of past gradients. + b3: Exponenital decay rate to track the slow EMA. + alpha: Mixing coefficient in the linear combination fo the fast and slow EMAs. eps: A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. - weight_decay: Strength of the weight decay regularization. + eps_root: A small constant applied to denominator inside the square root (as + in RMSProp), to avoid dividing by zero when rescaling. This is needed for + instance when computing (meta-)gradients through Adam. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent + with other frameworks such as PyTorch, but different from + (Loshchilov et al, 2019) where the weight decay is only multiplied with + the "schedule multiplier", but not the base learning rate. Returns: - A `GradientTransformation` object. + The corresponding `GradientTransformation`. - Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, - the method is best suited to settings where the number of iterations is - important. The paper reports on this effect in App. C.1.5, showing how - smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands of - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is - non-stationary. + .. seealso:: + See the related functions :func:`optax.adam`, :func:`optax.nadamw`, as well + as the example :doc:`../_collections/examples/contrib/rosenbrock_ademamix` for a use case. """ return combine.chain( - scale_by_ademamix(b1, b2, b3, alpha, eps), + scale_by_ademamix(b1, b2, b3, alpha, eps, eps_root), transform.add_decayed_weights(weight_decay), transform.scale_by_learning_rate(learning_rate), ) - - -if __name__ == "__main__": # dummy test - import jax - import jax.numpy as jnp - def f(x): - return jnp.sum(x**2) # simple quadratic function - - alpha = 8.0 - b1, b2, b3 = 0.9, 0.999, 0.9999 - - f_a = alpha_scheduler(alpha, alpha_start=0, T_alpha=10) - f_b3 = b3_scheduler(b3, beta_start=b1, T_b3=10) - - solver = ademamix(learning_rate=0.01, b1=b1, b2=b2, b3=f_b3, alpha=f_a, weight_decay=0.01) - - params = jnp.array([1.0, 2.0, 3.0]) - print("Objective function: {:.2f}".format(f(params))) - opt_state = solver.init(params) - for itr in range(100): - grad = jax.grad(f)(params) - updates, opt_state = solver.update(grad, opt_state, params) - params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates) - if itr % 5 == 0: - print("Objective function: {:.2f}".format(f(params))) - print(params) From 1a51332b3a7337a5a647fae5238bee978edd4772 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 16:35:10 -0400 Subject: [PATCH 08/32] removed unneeded alpha scheduler --- optax/contrib/_ademamix.py | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index fba06150a..ce8126437 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -18,36 +18,6 @@ from typing import NamedTuple, Tuple -def alpha_scheduler( - alpha_value, - alpha_start: float = 0, - warmup_alpha: int = 0 -) -> base.Schedule: - """The alpha scheduler from the paper. - - This is a progressive increase in alpha using a linear scheduler. - - Args: - alpha_value: The current value of alpha (the coefficient that "blends" - the two EMAs) - alpha_start: The starting value of alpha - warmup_alpha: The warmup time for alpha to reach it's final value. - - Returns: - A `base.Schedule` object. - - """ - - def schedule(step: int) -> float: - is_warmup: float = jnp.array(step < warmup_alpha).astype(jnp.float32) - a: float = step / float(warmup_alpha) - return ( - is_warmup * ((1.0 - a) * alpha_start + a * alpha_value) + - alpha_value * (1.0 - is_warmup) - ) - return schedule - - def b3_scheduler( beta_end: float, beta_start: float = 0, From 4eb60651f997f90eb06b481dd18661abfc07c360 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 16:50:06 -0400 Subject: [PATCH 09/32] added alpha as a scheduler --- optax/contrib/_ademamix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index ce8126437..8e4ecd0a7 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -123,7 +123,7 @@ def update_fn( del params c_b3 = b3_scheduler(state.count_m2) if callable(b3) else b3 c_alpha = ( - alpha_scheduler(state.count_m2) if callable(alpha) else alpha + alpha(state.count_m2) if callable(alpha) else alpha ) m1 = otu.tree_update_moment( updates, state.m1, b1, order=1 @@ -159,7 +159,7 @@ def ademamix( eps_root: float = 0.0, weight_decay: float = 0.0, ) -> base.GradientTransformation: -r"""AdEMAMix. + r"""AdEMAMix. AdEMAMix (Adaptive EMA Mixture) is AdamW with a mixture of two momentum terms to better take advantage of historical gradients. @@ -210,7 +210,7 @@ def ademamix( >>> import jax >>> import jax.numpy as jnp >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> solver = optax.adamw(learning_rate=0.003) + >>> solver = optax.ademamix(learning_rate=0.003) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 From 4eb618a1c6fa22f72de18478ec3e9d7bec87856c Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 17:00:13 -0400 Subject: [PATCH 10/32] removed b3_scheduler --- optax/contrib/_ademamix.py | 43 -------------------------------------- 1 file changed, 43 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 8e4ecd0a7..ab2850b7f 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -17,49 +17,6 @@ from jax.lax import rsqrt from typing import NamedTuple, Tuple - -def b3_scheduler( - beta_end: float, - beta_start: float = 0, - warmup_b3: int = 0 -) -> base.Schedule: - """The b3 scheduler from the paper. - - This is a progressive increase in b3 attempting to increase the number - of iterations corresponding to where half of the mass of the second EMA - is concentrated (denoted ``t_half`` in the paper). This scheduler attempts - to increase this value linearly. Note for ``b3 = 0.9999, t_half`` is - approximately ``6930.`` - (Appendix A.1 of the paper derives the scheduler.) - - Args: - beta_end: The desired ending value of b3 (the exponential decay rate to - track the first moment of past gradients for the second EMA) - beta_start: The starting value of b3 - warmup_b3: The warmup time for b3 to reach it's maximal value. - - Returns: - A `base.Schedule` object. - - """ - - def fun(beta: float) -> float: - return jnp.log(0.5) / jnp.log(beta) - 1 - - def f_inv(t: float) -> float: - return rsqrt(t + 1) - - def schedule(step: int) -> float: - is_warmup = jnp.array(step < warmup_b3).astype(jnp.float32) - step_over_warmup = step / float(warmup_b3) - return ( - is_warmup * f_inv((1.0 - step_over_warmup) * fun(beta_start) - + step_over_warmup * fun(beta_end)) - + beta_end * (1.0 - is_warmup) - ) - return schedule - - class ScaleByAdemamixState(NamedTuple): """State for the Ademamix algorithm.""" From 94420857f60b0685d4d56401de74ae3411cbd128 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 17:03:10 -0400 Subject: [PATCH 11/32] removed b3_scheduler --- optax/contrib/_ademamix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index ab2850b7f..75f783980 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -78,7 +78,7 @@ def update_fn( updates: jtu.tree_map, state, params=None ) -> Tuple[jtu.tree_map, ScaleByAdemamixState]: del params - c_b3 = b3_scheduler(state.count_m2) if callable(b3) else b3 + c_b3 = b3(state.count_m2) if callable(b3) else b3 c_alpha = ( alpha(state.count_m2) if callable(alpha) else alpha ) From 420771f867ceaeb567837dca31e31ec4609c3dc3 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 19:41:28 -0400 Subject: [PATCH 12/32] fixing tests with new docstrings --- examples/contrib/rosenbrock_ademamix.ipynb | 124 +++++++++++++++++++-- optax/contrib/_ademamix.py | 34 +++--- 2 files changed, 133 insertions(+), 25 deletions(-) diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb index 825e928a0..8835b03a0 100644 --- a/examples/contrib/rosenbrock_ademamix.ipynb +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -22,7 +22,15 @@ "execution_count": 1, "id": "55182561-ad63-4fb1-ba21-116ca65c21b1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n" + ] + } + ], "source": [ "import matplotlib.pyplot as plt\n", "import optax\n", @@ -36,6 +44,17 @@ "plt.rc('font', size=14)" ] }, + { + "cell_type": "code", + "execution_count": 2, + "id": "298cb49c-5d9f-43ae-befd-066fc7d2773e", + "metadata": {}, + "outputs": [], + "source": [ + "from optax.schedules import linear_schedule\n", + "from optax._src import base" + ] + }, { "cell_type": "markdown", "id": "ec581f6c-c3e5-4924-bf78-17f57c60cbcd", @@ -46,7 +65,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "15cd3560-d41c-4a97-83c5-b28df4d5d077", "metadata": {}, "outputs": [], @@ -64,6 +83,16 @@ "Z = rosenbrock([X, Y])" ] }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a153b4de-331c-4c78-aca6-63864e1551e0", + "metadata": {}, + "outputs": [], + "source": [ + "num_iterations = 7500" + ] + }, { "cell_type": "markdown", "id": "152e443e-5697-4eea-97f5-269cd12a2cfd", @@ -74,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "92b6987c-8ba1-43bc-8083-4c2b6324cb28", "metadata": {}, "outputs": [ @@ -83,9 +112,35 @@ "output_type": "stream", "text": [ "Objective function: 1616.0\n", + "Objective function at iteration 0 = 1599.2254638671875\n", + "Objective function at iteration 1000 = 11.406792640686035\n", + "Objective function at iteration 2000 = 11.379987716674805\n", + "Objective function at iteration 3000 = 11.344354629516602\n", + "Objective function at iteration 4000 = 11.301033020019531\n", + "Objective function at iteration 5000 = 11.250575065612793\n", + "Objective function at iteration 6000 = 11.193216323852539\n", + "Objective function at iteration 7000 = 11.129084587097168\n", "Objective function: 1616.0\n", + "Objective function at iteration 0 = 1599.2254638671875\n", + "Objective function at iteration 1000 = 11.371587753295898\n", + "Objective function at iteration 2000 = 11.315672874450684\n", + "Objective function at iteration 3000 = 11.277731895446777\n", + "Objective function at iteration 4000 = 11.231626510620117\n", + "Objective function at iteration 5000 = 11.177850723266602\n", + "Objective function at iteration 6000 = 11.116741180419922\n", + "Objective function at iteration 7000 = 11.04842472076416\n", "Objective function: 1616.0\n", - "Objective function: 1616.0\n" + "Objective function at iteration 0 = 1599.2254638671875\n", + "Objective function at iteration 1000 = 17.629518508911133\n", + "Objective function at iteration 2000 = 66.02627563476562\n", + "Objective function at iteration 3000 = 36.76356887817383\n", + "Objective function at iteration 4000 = 11.747029304504395\n", + "Objective function at iteration 5000 = 13.123286247253418\n", + "Objective function at iteration 6000 = 11.120031356811523\n", + "Objective function at iteration 7000 = 10.75074577331543\n", + "Objective function: 1616.0\n", + "Objective function at iteration 0 = 1599.2281494140625\n", + "Objective function at iteration 1000 = 44.154151916503906\n" ] } ], @@ -101,13 +156,13 @@ " print(\"Objective function: \", rosenbrock(params))\n", " all_params=[params]\n", " opt_state = solver.init(params)\n", - " for i in range(100000):\n", + " for i in range(num_iterations):\n", " grad = jax.grad(rosenbrock)(params)\n", " updates, opt_state = solver.update(grad, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", " all_params.append(params)\n", - " # if i%1000 == 0:\n", - " # print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", + " if i%1000 == 0:\n", + " print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", " all_b1_params.append(all_params)\n", "all_b1_params_array = jnp.array(all_b1_params)" ] @@ -120,9 +175,60 @@ "## Generate AdeMAMix Trajectories" ] }, + { + "cell_type": "markdown", + "id": "57e329a7-5737-4a74-bda3-936290b004f9", + "metadata": {}, + "source": [ + "### Create `alpha` scheduler" + ] + }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, + "id": "01fe6b99-cb4e-4203-8490-75be300448ee", + "metadata": {}, + "outputs": [], + "source": [ + "alpha = 0.8\n", + "alpha = linear_schedule(0, alpha, num_iterations)" + ] + }, + { + "cell_type": "markdown", + "id": "62b1a5c0-e588-4ffd-a7e2-b046a4daec7c", + "metadata": {}, + "source": [ + "### Create `b3` scheduler" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e652d62f-4135-478b-8995-a34d7729c30a", + "metadata": {}, + "outputs": [], + "source": [ + "def b3_scheduler(beta_end: float, beta_start: float = 0, warmup: int = 0):\n", + " def f(beta):\n", + " return jnp.log(0.5) / jnp.log(beta) - 1\n", + "\n", + " def f_inv(t):\n", + " return jnp.power(0.5, 1 / (t + 1))\n", + "\n", + " def schedule(step):\n", + " is_warmup = jnp.array(step < warmup).astype(jnp.float32)\n", + " alpha = step / float(warmup)\n", + " return is_warmup * f_inv(\n", + " (1.0 - alpha) * f(beta_start) + alpha * f(beta_end)\n", + " ) + beta_end * (1.0 - is_warmup)\n", + "\n", + " return schedule" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "id": "11a4561a-1d92-44ce-bab0-5af22f028167", "metadata": {}, "outputs": [ @@ -138,11 +244,13 @@ "source": [ "all_ademamix_params = []\n", "for b3 in [0.999,0.9999]:\n", + " b3 = b3_scheduler(b3, 0, num_iterations)\n", " solver = optax.contrib.ademamix(\n", " learning_rate=0.003,\n", " b1=.99,\n", " b2=0.999,\n", " b3=b3,\n", + " alpha=alpha,\n", " )\n", " params = jnp.array([-3.,5.])\n", " print(\"Objective function: \", rosenbrock(params))\n", diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 75f783980..7b19fc532 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -14,7 +14,6 @@ from optax._src import numerics from optax._src import transform import optax.tree_utils as otu -from jax.lax import rsqrt from typing import NamedTuple, Tuple class ScaleByAdemamixState(NamedTuple): @@ -34,7 +33,6 @@ def scale_by_ademamix( alpha: base.ScalarOrSchedule, eps: float = 1e-8, eps_root: float = 0.0, - weight_decay: float=0.0, ) -> base.GradientTransformation: """Rescale updates according to the Ademamix algorithm. @@ -54,8 +52,6 @@ def scale_by_ademamix( (as in the Adam paper) to avoid dividing by zero when rescaling. eps_root: Term added to the denominator inside of the square-root to improve numerical stability when backpropagating gradients through the rescaling. - weight_decay: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. Returns: A `GradientTransformation` object. @@ -94,7 +90,8 @@ def update_fn( # buffer fill itself slowly. nu_hat = otu.tree_bias_correction(nu, b2, count_inc) updates = jtu.tree_map( - lambda m1_, m2_, v_: (m1_ + c_alpha * m2_) / (jnp.sqrt(v_) + eps), + lambda m1_, m2_, v_: ((m1_ + c_alpha * m2_) / (jnp.sqrt(v_+eps_root) + + eps)), m1_hat, m2, nu_hat, @@ -131,13 +128,14 @@ def ademamix( the parameter vector at time :math:`t`. The ``init`` function of this optimizer initializes an internal state - :math:`S_0 := (m1_0, m2_0, v_0) = (0, 0, 0)`, representing initial estimates - for the first and second moments. In practice these values are stored as pytrees - containing all zeros, with the same shape as the model updates. - At step :math:`t`, the ``update`` function of this optimizer takes as - arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t` - and the parameters :math:`\theta_t` and computes updates :math:`\theta_{t+1}` and - new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have, + :math:`S_0 := (m1_0, m2_0, v_0) = (0, 0, 0)`, representing initial + estimates for the first and second moments. In practice these values are + stored as pytrees containing all zeros, with the same shape as the model + updates. At step :math:`t`, the ``update`` function of this optimizer takes + as arguments the incoming gradients :math:`g_t`, the optimizer state + :math:`S_t` and the parameters :math:`\theta_t` and computes updates + :math:`\theta_{t+1}` and new state :math:`S_{t+1}`. Thus, for + :math:`t > 0`, we have, .. math:: @@ -147,9 +145,9 @@ def ademamix( v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ - \theta_t &\leftarrow \theta_{t-1} - \eta \cdot \left( (\hat{m1}_t + - \alpha m2_t) / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon\right) - + \lambda \theta_{t-1} \right).\\ + \theta_t &\leftarrow \theta_{t-1} - \eta \cdot \left( + (\hat{m1}_t + \alpha m2_t) / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + + \varepsilon\right) + \lambda \theta_{t-1} \right).\\ S_t &\leftarrow (m1_t, m2_t, v_t). \end{align*} @@ -194,7 +192,8 @@ def ademamix( b1: Exponential decay rate to track the fast EMA. b2: Exponential decay rate to track the second moment of past gradients. b3: Exponenital decay rate to track the slow EMA. - alpha: Mixing coefficient in the linear combination fo the fast and slow EMAs. + alpha: Mixing coefficient in the linear combination fo the fast and + slow EMAs. eps: A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. eps_root: A small constant applied to denominator inside the square root (as @@ -211,7 +210,8 @@ def ademamix( .. seealso:: See the related functions :func:`optax.adam`, :func:`optax.nadamw`, as well - as the example :doc:`../_collections/examples/contrib/rosenbrock_ademamix` for a use case. + as the example :doc:`../_collections/examples/contrib/rosenbrock_ademamix` + for a use case. """ return combine.chain( scale_by_ademamix(b1, b2, b3, alpha, eps, eps_root), From 96903115014187dec4b51bdc178337674ce3c609 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 19:53:07 -0400 Subject: [PATCH 13/32] fixed docstring --- optax/contrib/_ademamix.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 7b19fc532..f5d5bd97a 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -164,8 +164,8 @@ def ademamix( >>> import optax >>> import jax >>> import jax.numpy as jnp - >>> def f(x): return jnp.sum(x ** 2) # simple quadratic function - >>> solver = optax.ademamix(learning_rate=0.003) + >>> def f(x): return jnp.sum(jnp.square(x)) # simple quadratic function + >>> solver = optax.contrib.ademamix(learning_rate=0.01) >>> params = jnp.array([1., 2., 3.]) >>> print('Objective function: ', f(params)) Objective function: 14.0 @@ -175,11 +175,11 @@ def ademamix( ... updates, opt_state = solver.update(grad, opt_state, params) ... params = optax.apply_updates(params, updates) ... print('Objective function: {:.2E}'.format(f(params))) - Objective function: 1.40E+01 - Objective function: 1.39E+01 - Objective function: 1.39E+01 Objective function: 1.39E+01 Objective function: 1.38E+01 + Objective function: 1.36E+01 + Objective function: 1.35E+01 + Objective function: 1.34E+01 References: "THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER" From f892e2933cb34ea1da7634013413e176c8a8cf4e Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Mon, 21 Oct 2024 20:17:32 -0400 Subject: [PATCH 14/32] updated notebook --- examples/contrib/rosenbrock_ademamix.ipynb | 316 +++++++++++++++++---- 1 file changed, 265 insertions(+), 51 deletions(-) diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb index 8835b03a0..4ccac1917 100644 --- a/examples/contrib/rosenbrock_ademamix.ipynb +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -85,12 +85,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 21, "id": "a153b4de-331c-4c78-aca6-63864e1551e0", "metadata": {}, "outputs": [], "source": [ - "num_iterations = 7500" + "num_iterations = 100000" ] }, { @@ -103,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "92b6987c-8ba1-43bc-8083-4c2b6324cb28", "metadata": {}, "outputs": [ @@ -112,35 +112,49 @@ "output_type": "stream", "text": [ "Objective function: 1616.0\n", - "Objective function at iteration 0 = 1599.2254638671875\n", - "Objective function at iteration 1000 = 11.406792640686035\n", - "Objective function at iteration 2000 = 11.379987716674805\n", - "Objective function at iteration 3000 = 11.344354629516602\n", - "Objective function at iteration 4000 = 11.301033020019531\n", - "Objective function at iteration 5000 = 11.250575065612793\n", - "Objective function at iteration 6000 = 11.193216323852539\n", - "Objective function at iteration 7000 = 11.129084587097168\n", + "Objective function for b1=0.9 at iteration 0 = 1599.2254638671875\n", + "Objective function for b1=0.9 at iteration 10000 = 10.89592456817627\n", + "Objective function for b1=0.9 at iteration 20000 = 9.620516777038574\n", + "Objective function for b1=0.9 at iteration 30000 = 7.285767555236816\n", + "Objective function for b1=0.9 at iteration 40000 = 3.306288242340088\n", + "Objective function for b1=0.9 at iteration 50000 = 0.26169437170028687\n", + "Objective function for b1=0.9 at iteration 60000 = 0.009876935742795467\n", + "Objective function for b1=0.9 at iteration 70000 = 9.95625596260652e-05\n", + "Objective function for b1=0.9 at iteration 80000 = 6.432726706862013e-08\n", + "Objective function for b1=0.9 at iteration 90000 = 5.157154703283595e-10\n", "Objective function: 1616.0\n", - "Objective function at iteration 0 = 1599.2254638671875\n", - "Objective function at iteration 1000 = 11.371587753295898\n", - "Objective function at iteration 2000 = 11.315672874450684\n", - "Objective function at iteration 3000 = 11.277731895446777\n", - "Objective function at iteration 4000 = 11.231626510620117\n", - "Objective function at iteration 5000 = 11.177850723266602\n", - "Objective function at iteration 6000 = 11.116741180419922\n", - "Objective function at iteration 7000 = 11.04842472076416\n", + "Objective function for b1=0.99 at iteration 0 = 1599.2254638671875\n", + "Objective function for b1=0.99 at iteration 10000 = 10.799932479858398\n", + "Objective function for b1=0.99 at iteration 20000 = 9.439836502075195\n", + "Objective function for b1=0.99 at iteration 30000 = 6.946890830993652\n", + "Objective function for b1=0.99 at iteration 40000 = 2.7601280212402344\n", + "Objective function for b1=0.99 at iteration 50000 = 0.17759834229946136\n", + "Objective function for b1=0.99 at iteration 60000 = 0.005802110303193331\n", + "Objective function for b1=0.99 at iteration 70000 = 4.045083551318385e-05\n", + "Objective function for b1=0.99 at iteration 80000 = 1.2925656989182244e-08\n", + "Objective function for b1=0.99 at iteration 90000 = 6.390479256879189e-10\n", "Objective function: 1616.0\n", - "Objective function at iteration 0 = 1599.2254638671875\n", - "Objective function at iteration 1000 = 17.629518508911133\n", - "Objective function at iteration 2000 = 66.02627563476562\n", - "Objective function at iteration 3000 = 36.76356887817383\n", - "Objective function at iteration 4000 = 11.747029304504395\n", - "Objective function at iteration 5000 = 13.123286247253418\n", - "Objective function at iteration 6000 = 11.120031356811523\n", - "Objective function at iteration 7000 = 10.75074577331543\n", + "Objective function for b1=0.999 at iteration 0 = 1599.2254638671875\n", + "Objective function for b1=0.999 at iteration 10000 = 10.194862365722656\n", + "Objective function for b1=0.999 at iteration 20000 = 9.375121116638184\n", + "Objective function for b1=0.999 at iteration 30000 = 7.936856746673584\n", + "Objective function for b1=0.999 at iteration 40000 = 5.422780513763428\n", + "Objective function for b1=0.999 at iteration 50000 = 1.4608842134475708\n", + "Objective function for b1=0.999 at iteration 60000 = 0.057731419801712036\n", + "Objective function for b1=0.999 at iteration 70000 = 0.0010820545721799135\n", + "Objective function for b1=0.999 at iteration 80000 = 6.941367587387504e-07\n", + "Objective function for b1=0.999 at iteration 90000 = 3.984723662142642e-11\n", "Objective function: 1616.0\n", - "Objective function at iteration 0 = 1599.2281494140625\n", - "Objective function at iteration 1000 = 44.154151916503906\n" + "Objective function for b1=0.9999 at iteration 0 = 1599.2281494140625\n", + "Objective function for b1=0.9999 at iteration 10000 = 29.86247444152832\n", + "Objective function for b1=0.9999 at iteration 20000 = 9.297667503356934\n", + "Objective function for b1=0.9999 at iteration 30000 = 7.363901138305664\n", + "Objective function for b1=0.9999 at iteration 40000 = 3.581587553024292\n", + "Objective function for b1=0.9999 at iteration 50000 = 0.872508704662323\n", + "Objective function for b1=0.9999 at iteration 60000 = 1.0354793071746826\n", + "Objective function for b1=0.9999 at iteration 70000 = 0.3354209363460541\n", + "Objective function for b1=0.9999 at iteration 80000 = 0.09372159093618393\n", + "Objective function for b1=0.9999 at iteration 90000 = 0.09824670851230621\n" ] } ], @@ -161,8 +175,8 @@ " updates, opt_state = solver.update(grad, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", " all_params.append(params)\n", - " if i%1000 == 0:\n", - " print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", + " if i%10000 == 0:\n", + " print(f\"Objective function for b1={b1} at iteration {i} = {rosenbrock(params)}\")\n", " all_b1_params.append(all_params)\n", "all_b1_params_array = jnp.array(all_b1_params)" ] @@ -185,7 +199,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "01fe6b99-cb4e-4203-8490-75be300448ee", "metadata": {}, "outputs": [], @@ -204,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "e652d62f-4135-478b-8995-a34d7729c30a", "metadata": {}, "outputs": [], @@ -228,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 25, "id": "11a4561a-1d92-44ce-bab0-5af22f028167", "metadata": {}, "outputs": [ @@ -237,7 +251,207 @@ "output_type": "stream", "text": [ "Objective function: 1616.0\n", - "Objective function: 1616.0\n" + "Objective function for b3=0.0 at iteration 0 = 1599.227294921875\n", + "Objective function for b3=0.9047933220863342 at iteration 1000 = 11.408196449279785\n", + "Objective function for b3=0.9512062072753906 at iteration 2000 = 11.29694652557373\n", + "Objective function for b3=0.9672003984451294 at iteration 3000 = 11.221567153930664\n", + "Objective function for b3=0.9752980470657349 at iteration 4000 = 11.092667579650879\n", + "Objective function for b3=0.9801891446113586 at iteration 5000 = 10.87605094909668\n", + "Objective function for b3=0.98346346616745 at iteration 6000 = 10.514174461364746\n", + "Objective function for b3=0.9858089685440063 at iteration 7000 = 9.911301612854004\n", + "Objective function for b3=0.9875717759132385 at iteration 8000 = 8.91100025177002\n", + "Objective function for b3=0.9889450073242188 at iteration 9000 = 7.268399715423584\n", + "Objective function for b3=0.9900450110435486 at iteration 10000 = 4.631922721862793\n", + "Objective function for b3=0.9909458756446838 at iteration 11000 = 1.0591912269592285\n", + "Objective function for b3=0.9916972517967224 at iteration 12000 = 0.0565548874437809\n", + "Objective function for b3=0.9923334717750549 at iteration 13000 = 0.0014951552730053663\n", + "Objective function for b3=0.9928791522979736 at iteration 14000 = 2.036277919614804e-06\n", + "Objective function for b3=0.9933522939682007 at iteration 15000 = 4.5986325858393684e-11\n", + "Objective function for b3=0.9937664866447449 at iteration 16000 = 1.4210854715202004e-12\n", + "Objective function for b3=0.9941320419311523 at iteration 17000 = 1.4210854715202004e-12\n", + "Objective function for b3=0.9944571256637573 at iteration 18000 = 1.4210854715202004e-12\n", + "Objective function for b3=0.9947481155395508 at iteration 19000 = 0.0\n", + "Objective function for b3=0.9950100779533386 at iteration 20000 = 0.0\n", + "Objective function for b3=0.9952471256256104 at iteration 21000 = 0.0\n", + "Objective function for b3=0.9954626560211182 at iteration 22000 = 0.0\n", + "Objective function for b3=0.9956595301628113 at iteration 23000 = 0.0\n", + "Objective function for b3=0.9958399534225464 at iteration 24000 = 0.0\n", + "Objective function for b3=0.9960060715675354 at iteration 25000 = 0.0\n", + "Objective function for b3=0.9961593747138977 at iteration 26000 = 0.0\n", + "Objective function for b3=0.9963013529777527 at iteration 27000 = 0.0\n", + "Objective function for b3=0.9964331984519958 at iteration 28000 = 0.0\n", + "Objective function for b3=0.9965559840202332 at iteration 29000 = 0.0\n", + "Objective function for b3=0.9966706037521362 at iteration 30000 = 0.0\n", + "Objective function for b3=0.9967778325080872 at iteration 31000 = 0.0\n", + "Objective function for b3=0.9968783855438232 at iteration 32000 = 0.0\n", + "Objective function for b3=0.9969727993011475 at iteration 33000 = 0.0\n", + "Objective function for b3=0.9970617294311523 at iteration 34000 = 0.0\n", + "Objective function for b3=0.9971455335617065 at iteration 35000 = 0.0\n", + "Objective function for b3=0.997224748134613 at iteration 36000 = 0.0\n", + "Objective function for b3=0.9972996115684509 at iteration 37000 = 0.0\n", + "Objective function for b3=0.9973706007003784 at iteration 38000 = 0.0\n", + "Objective function for b3=0.9974379539489746 at iteration 39000 = 0.0\n", + "Objective function for b3=0.9975019097328186 at iteration 40000 = 0.0\n", + "Objective function for b3=0.9975627660751343 at iteration 41000 = 0.0\n", + "Objective function for b3=0.997620701789856 at iteration 42000 = 0.0\n", + "Objective function for b3=0.9976760149002075 at iteration 43000 = 0.0\n", + "Objective function for b3=0.9977287650108337 at iteration 44000 = 0.0\n", + "Objective function for b3=0.9977791905403137 at iteration 45000 = 0.0\n", + "Objective function for b3=0.997827410697937 at iteration 46000 = 0.0\n", + "Objective function for b3=0.9978735446929932 at iteration 47000 = 0.0\n", + "Objective function for b3=0.9979178309440613 at iteration 48000 = 0.0\n", + "Objective function for b3=0.9979602694511414 at iteration 49000 = 0.0\n", + "Objective function for b3=0.9980010390281677 at iteration 50000 = 0.0\n", + "Objective function for b3=0.9980401992797852 at iteration 51000 = 0.0\n", + "Objective function for b3=0.9980778098106384 at iteration 52000 = 0.0\n", + "Objective function for b3=0.9981140494346619 at iteration 53000 = 0.0\n", + "Objective function for b3=0.9981489777565002 at iteration 54000 = 0.0\n", + "Objective function for b3=0.9981825947761536 at iteration 55000 = 0.0\n", + "Objective function for b3=0.9982150197029114 at iteration 56000 = 0.0\n", + "Objective function for b3=0.9982463121414185 at iteration 57000 = 0.0\n", + "Objective function for b3=0.9982765316963196 at iteration 58000 = 0.0\n", + "Objective function for b3=0.9983056783676147 at iteration 59000 = 0.0\n", + "Objective function for b3=0.9983339309692383 at iteration 60000 = 0.0\n", + "Objective function for b3=0.9983612298965454 at iteration 61000 = 0.0\n", + "Objective function for b3=0.9983876347541809 at iteration 62000 = 0.0\n", + "Objective function for b3=0.9984132051467896 at iteration 63000 = 0.0\n", + "Objective function for b3=0.9984379410743713 at iteration 64000 = 0.0\n", + "Objective function for b3=0.9984619617462158 at iteration 65000 = 0.0\n", + "Objective function for b3=0.998485267162323 at iteration 66000 = 0.0\n", + "Objective function for b3=0.9985078573226929 at iteration 67000 = 0.0\n", + "Objective function for b3=0.9985297918319702 at iteration 68000 = 0.0\n", + "Objective function for b3=0.998551070690155 at iteration 69000 = 0.0\n", + "Objective function for b3=0.9985717535018921 at iteration 70000 = 0.0\n", + "Objective function for b3=0.9985918402671814 at iteration 71000 = 0.0\n", + "Objective function for b3=0.9986113905906677 at iteration 72000 = 0.0\n", + "Objective function for b3=0.9986304044723511 at iteration 73000 = 0.0\n", + "Objective function for b3=0.9986488819122314 at iteration 74000 = 0.0\n", + "Objective function for b3=0.9986668825149536 at iteration 75000 = 0.0\n", + "Objective function for b3=0.9986844062805176 at iteration 76000 = 0.0\n", + "Objective function for b3=0.9987015128135681 at iteration 77000 = 0.0\n", + "Objective function for b3=0.9987181425094604 at iteration 78000 = 0.0\n", + "Objective function for b3=0.9987343549728394 at iteration 79000 = 0.0\n", + "Objective function for b3=0.9987501502037048 at iteration 80000 = 0.0\n", + "Objective function for b3=0.9987655878067017 at iteration 81000 = 0.0\n", + "Objective function for b3=0.9987806081771851 at iteration 82000 = 0.0\n", + "Objective function for b3=0.9987953305244446 at iteration 83000 = 0.0\n", + "Objective function for b3=0.9988096356391907 at iteration 84000 = 0.0\n", + "Objective function for b3=0.9988236427307129 at iteration 85000 = 0.0\n", + "Objective function for b3=0.9988372921943665 at iteration 86000 = 0.0\n", + "Objective function for b3=0.9988507032394409 at iteration 87000 = 0.0\n", + "Objective function for b3=0.9988637566566467 at iteration 88000 = 0.0\n", + "Objective function for b3=0.9988765120506287 at iteration 89000 = 0.0\n", + "Objective function for b3=0.9988889694213867 at iteration 90000 = 0.0\n", + "Objective function for b3=0.9989011883735657 at iteration 91000 = 0.0\n", + "Objective function for b3=0.9989131093025208 at iteration 92000 = 0.0\n", + "Objective function for b3=0.9989247918128967 at iteration 93000 = 0.0\n", + "Objective function for b3=0.9989362359046936 at iteration 94000 = 0.0\n", + "Objective function for b3=0.9989473819732666 at iteration 95000 = 0.0\n", + "Objective function for b3=0.9989583492279053 at iteration 96000 = 0.0\n", + "Objective function for b3=0.9989690780639648 at iteration 97000 = 0.0\n", + "Objective function for b3=0.9989796280860901 at iteration 98000 = 0.0\n", + "Objective function for b3=0.9989899396896362 at iteration 99000 = 0.0\n", + "Objective function: 1616.0\n", + "Objective function for b3=0.0 at iteration 0 = 1599.227294921875\n", + "Objective function for b3=0.9900476932525635 at iteration 1000 = 11.411341667175293\n", + "Objective function for b3=0.9950113892555237 at iteration 2000 = 11.296905517578125\n", + "Objective function for b3=0.9966714978218079 at iteration 3000 = 11.221532821655273\n", + "Objective function for b3=0.9975025653839111 at iteration 4000 = 11.092639923095703\n", + "Objective function for b3=0.9980015754699707 at iteration 5000 = 10.876032829284668\n", + "Objective function for b3=0.9983343482017517 at iteration 6000 = 10.514177322387695\n", + "Objective function for b3=0.9985721111297607 at iteration 7000 = 9.911361694335938\n", + "Objective function for b3=0.9987505078315735 at iteration 8000 = 8.911260604858398\n", + "Objective function for b3=0.9988892674446106 at iteration 9000 = 7.269254684448242\n", + "Objective function for b3=0.9990003108978271 at iteration 10000 = 4.634438514709473\n", + "Objective function for b3=0.9990911483764648 at iteration 11000 = 1.0626317262649536\n", + "Objective function for b3=0.9991668462753296 at iteration 12000 = 0.0512956939637661\n", + "Objective function for b3=0.9992309212684631 at iteration 13000 = 0.00029214631649665534\n", + "Objective function for b3=0.9992858171463013 at iteration 14000 = 0.0002549797063693404\n", + "Objective function for b3=0.999333381652832 at iteration 15000 = 6.654856406385079e-05\n", + "Objective function for b3=0.99937504529953 at iteration 16000 = 1.4632985767093487e-05\n", + "Objective function for b3=0.9994118213653564 at iteration 17000 = 3.8770863284298684e-06\n", + "Objective function for b3=0.9994444847106934 at iteration 18000 = 1.1391578027541982e-06\n", + "Objective function for b3=0.9994736909866333 at iteration 19000 = 3.6230431987860356e-07\n", + "Objective function for b3=0.999500036239624 at iteration 20000 = 1.222501424535949e-07\n", + "Objective function for b3=0.9995238184928894 at iteration 21000 = 4.350654592144565e-08\n", + "Objective function for b3=0.9995454549789429 at iteration 22000 = 1.6131096458593674e-08\n", + "Objective function for b3=0.9995652437210083 at iteration 23000 = 6.2786540411252645e-09\n", + "Objective function for b3=0.9995833039283752 at iteration 24000 = 2.7355170573173382e-08\n", + "Objective function for b3=0.9995999932289124 at iteration 25000 = 1.852578179750708e-08\n", + "Objective function for b3=0.9996153712272644 at iteration 26000 = 1.0457262078489293e-07\n", + "Objective function for b3=0.9996296167373657 at iteration 27000 = 4.015987542516086e-08\n", + "Objective function for b3=0.9996428489685059 at iteration 28000 = 1.5677557030358003e-09\n", + "Objective function for b3=0.9996551275253296 at iteration 29000 = 7.655046374566155e-09\n", + "Objective function for b3=0.9996666312217712 at iteration 30000 = 5.6290506478262614e-08\n", + "Objective function for b3=0.9996774196624756 at iteration 31000 = 2.7529409862836474e-09\n", + "Objective function for b3=0.9996874928474426 at iteration 32000 = 7.927974365884438e-08\n", + "Objective function for b3=0.9996969699859619 at iteration 33000 = 1.0755715607047023e-07\n", + "Objective function for b3=0.9997058510780334 at iteration 34000 = 9.467848371969012e-08\n", + "Objective function for b3=0.9997142553329468 at iteration 35000 = 7.07339040673105e-08\n", + "Objective function for b3=0.9997221827507019 at iteration 36000 = 2.70986788564187e-08\n", + "Objective function for b3=0.9997296929359436 at iteration 37000 = 3.4848568475354114e-08\n", + "Objective function for b3=0.9997368454933167 at iteration 38000 = 4.460630975700042e-09\n", + "Objective function for b3=0.9997435808181763 at iteration 39000 = 1.674882810220879e-08\n", + "Objective function for b3=0.9997499585151672 at iteration 40000 = 5.5214698591044e-08\n", + "Objective function for b3=0.9997560977935791 at iteration 41000 = 2.1872224920116423e-08\n", + "Objective function for b3=0.9997618794441223 at iteration 42000 = 1.8436061566262651e-09\n", + "Objective function for b3=0.9997674226760864 at iteration 43000 = 5.354401366730599e-08\n", + "Objective function for b3=0.9997727274894714 at iteration 44000 = 8.93862761586206e-12\n", + "Objective function for b3=0.9997777342796326 at iteration 45000 = 7.074592645039957e-08\n", + "Objective function for b3=0.9997825622558594 at iteration 46000 = 1.0837551656095457e-07\n", + "Objective function for b3=0.9997872114181519 at iteration 47000 = 6.424252774195338e-08\n", + "Objective function for b3=0.9997916221618652 at iteration 48000 = 2.5165093120449455e-08\n", + "Objective function for b3=0.9997959136962891 at iteration 49000 = 4.583888824072346e-08\n", + "Objective function for b3=0.9997999668121338 at iteration 50000 = 4.14928891245836e-09\n", + "Objective function for b3=0.999803900718689 at iteration 51000 = 1.0028271901774133e-07\n", + "Objective function for b3=0.9998076558113098 at iteration 52000 = 4.8693053145143494e-08\n", + "Objective function for b3=0.9998112916946411 at iteration 53000 = 7.591589223920892e-08\n", + "Objective function for b3=0.9998148083686829 at iteration 54000 = 2.0316193172220665e-08\n", + "Objective function for b3=0.9998181462287903 at iteration 55000 = 1.0680969353416003e-07\n", + "Objective function for b3=0.9998214244842529 at iteration 56000 = 2.4416024757556443e-08\n", + "Objective function for b3=0.9998245239257812 at iteration 57000 = 4.2578651715530214e-08\n", + "Objective function for b3=0.9998275637626648 at iteration 58000 = 1.0644058079378738e-08\n", + "Objective function for b3=0.9998304843902588 at iteration 59000 = 6.184544076859311e-08\n", + "Objective function for b3=0.9998332858085632 at iteration 60000 = 5.135669667311049e-08\n", + "Objective function for b3=0.9998360276222229 at iteration 61000 = 2.991244230088341e-08\n", + "Objective function for b3=0.9998387098312378 at iteration 62000 = 7.362999099314038e-08\n", + "Objective function for b3=0.9998412728309631 at iteration 63000 = 7.823453529454127e-08\n", + "Objective function for b3=0.9998437166213989 at iteration 64000 = 2.4600041115263593e-08\n", + "Objective function for b3=0.9998461604118347 at iteration 65000 = 9.106253173740697e-08\n", + "Objective function for b3=0.999848484992981 at iteration 66000 = 9.104553555516759e-09\n", + "Objective function for b3=0.9998507499694824 at iteration 67000 = 4.70359395876585e-09\n", + "Objective function for b3=0.9998528957366943 at iteration 68000 = 7.890244546615577e-08\n", + "Objective function for b3=0.9998550415039062 at iteration 69000 = 9.323736094302149e-08\n", + "Objective function for b3=0.9998571276664734 at iteration 70000 = 1.139520122706017e-08\n", + "Objective function for b3=0.9998591542243958 at iteration 71000 = 1.6440836247966217e-08\n", + "Objective function for b3=0.9998610615730286 at iteration 72000 = 7.427800596815359e-08\n", + "Objective function for b3=0.9998629689216614 at iteration 73000 = 9.9156537203271e-08\n", + "Objective function for b3=0.9998648166656494 at iteration 74000 = 6.971504262764938e-09\n", + "Objective function for b3=0.9998666644096375 at iteration 75000 = 8.126028205879265e-08\n", + "Objective function for b3=0.9998683929443359 at iteration 76000 = 1.2932197535064915e-07\n", + "Objective function for b3=0.9998701214790344 at iteration 77000 = 1.196456196339568e-07\n", + "Objective function for b3=0.9998717904090881 at iteration 78000 = 6.57613128396406e-08\n", + "Objective function for b3=0.9998733997344971 at iteration 79000 = 9.470069528560998e-08\n", + "Objective function for b3=0.999875009059906 at iteration 80000 = 1.3248055097392353e-08\n", + "Objective function for b3=0.9998764991760254 at iteration 81000 = 4.3568547880568076e-08\n", + "Objective function for b3=0.9998780488967896 at iteration 82000 = 1.0066288780308241e-07\n", + "Objective function for b3=0.9998794794082642 at iteration 83000 = 1.3248055097392353e-08\n", + "Objective function for b3=0.9998809099197388 at iteration 84000 = 4.3568547880568076e-08\n", + "Objective function for b3=0.9998823404312134 at iteration 85000 = 1.0448232501403254e-07\n", + "Objective function for b3=0.9998837113380432 at iteration 86000 = 6.760933501936961e-08\n", + "Objective function for b3=0.9998850226402283 at iteration 87000 = 8.262901474154205e-08\n", + "Objective function for b3=0.9998863339424133 at iteration 88000 = 3.462270470322437e-08\n", + "Objective function for b3=0.9998876452445984 at iteration 89000 = 8.820779839879833e-08\n", + "Objective function for b3=0.9998888969421387 at iteration 90000 = 1.028013230097713e-10\n", + "Objective function for b3=0.9998900890350342 at iteration 91000 = 1.3802278431285231e-08\n", + "Objective function for b3=0.9998912811279297 at iteration 92000 = 8.891692004908691e-08\n", + "Objective function for b3=0.9998924732208252 at iteration 93000 = 6.484791015282099e-08\n", + "Objective function for b3=0.9998936057090759 at iteration 94000 = 3.8966874171819654e-08\n", + "Objective function for b3=0.9998947381973267 at iteration 95000 = 1.4512920643028338e-08\n", + "Objective function for b3=0.9998958110809326 at iteration 96000 = 8.160476738794387e-08\n", + "Objective function for b3=0.9998968839645386 at iteration 97000 = 8.40046254779736e-08\n", + "Objective function for b3=0.9998979568481445 at iteration 98000 = 4.609432835422922e-08\n", + "Objective function for b3=0.9998989701271057 at iteration 99000 = 8.32615398849157e-09\n" ] } ], @@ -256,13 +470,13 @@ " print(\"Objective function: \", rosenbrock(params))\n", " all_params=[params]\n", " opt_state = solver.init(params)\n", - " for i in range(100000):\n", + " for i in range(num_iterations):\n", " grad = jax.grad(rosenbrock)(params)\n", " updates, opt_state = solver.update(grad, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", " all_params.append(params)\n", - " # if i%1000 == 0:\n", - " # print(f\"Objective function at iteration {i} = {rosenbrock(params)}\")\n", + " if i%1000 == 0:\n", + " print(f\"Objective function for b3={b3(i)} at iteration {i} = {rosenbrock(params)}\")\n", " all_ademamix_params.append(all_params)\n", "all_ademamix_params_array = jnp.array(all_ademamix_params)" ] @@ -277,13 +491,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 26, "id": "69d8642f-dfcc-4fac-8f85-3ee1fbfa135f", "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABmcAAANlCAYAAACJ1C0sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3gUVRfG301200khgUAgJIHQlE6oAQIiQWpo0qUkKEpXUVGRrqBSFSxACGAQEOSjiaJUkSZVUekk9BpIAVI22fv9MTuzM7Mz21KB83uefXZ35k7Z2d2588459z0axhgDQRAEQRAEQRAEQRAEQRAEQRAEUSQ4FfcOEARBEARBEARBEARBEARBEARBPEtQcIYgCIIgCIIgCIIgCIIgCIIgCKIIoeAMQRAEQRAEQRAEQRAEQRAEQRBEEULBGYIgCIIgCIIgCIIgCIIgCIIgiCKEgjMEQRAEQRAEQRAEQRAEQRAEQRBFCAVnCIIgCIIgCIIgCIIgCIIgCIIgihAKzhAEQRAEQRAEQRAEQRAEQRAEQRQhFJwhCIIgCIIgCIIgCIIgCIIgCIIoQig4QxAEQRAEQRAEQRAEQRAEQRAEUYRQcIYgCIHQ0FBoNBrs2bOnuHcl3wwZMgQajQZTpkwp7l15amjdujU0Gg2WL19e3LvyxKPRaKDRaJCcnFzcu/LUkJycLBzXZx06FgRBEMSzBGkYwhKkYQoO0jD5Y/ny5dBoNGjdunVx70qRM2XKFGg0GgwZMqS4d6XYoWNByKHgDPHMwV+ciR/Ozs7w9fVF48aNMXXqVDx48KC4d5MogfBiydrj5MmTxb2rdjNlyhRMmTIFqampxb0rJZLBgwcL329iYmJx706xwwsLRx5POydPnsSUKVPoBgBBEARRoJCGIRyFNMyzC2mYZxOl/sKWx7MQLFi+fDmmTJnyRJ7viKcXbXHvAEEUF8HBwahUqRIAQK/X48qVKzhy5AiOHDmCxYsXY+/evQgPDy/mvSRKImXLlkXVqlVV53t5eRXh3hQMU6dOBcCJN19fX8U2lSpVQvXq1eHj41OEe1b8ZGRkYP369cL7ZcuWYeDAgcW4R8VPYGAgIiMjzabfuXMH58+fBwDF+YWJTqdD9erVi3SbSpw8eRJTp05FVFRUsQmcknIsCIIgiIKHNAzhKKRhSMM86xqmuPHx8UH16tWFc3hhUbt2beTm5ppNP3XqFNLT01XPBdWqVSu0fQoICED16tVRvnz5QtuGLSxfvhx79+5FaGgo6tWrVyz7UFKOBVFyoOAM8cwSGxtrNlz8t99+Q58+fXDjxg0MHz4cO3fuLJ6dI0o0HTp0eCaz4leuXFncu1AsrFmzBo8fP4avry9SU1OxZ88eJCUlISwsrLh3rdjo0KEDOnToYDZ9+fLlGDp0KADgjz/+KNJ9qlChAs6cOVOk2yyp0LEgCIJ4eiENQzgKaZhnC9IwJY/u3buje/fuhb6dL7/8UnF669atsXfv3mI5F4waNQqjRo0q0m2WVOhYEHLI1owgRLRr1w4zZswAAOzevRt37twp5j0iCKK4WbZsGQDuIqphw4ZgjCEhIaGY94ogCIIgCIKDNAxBEHJIwxAEQTwZUHCGIGQ0b94cAMAYQ1JSktl8xhhWr16Ndu3awd/fHy4uLqhYsSL69++P48ePq653165d6N69O4KCgqDT6eDj44MqVaqge/fuwoWTnMzMTMyfPx+RkZHw8/ODq6srwsLCMHz4cMV9A6QFD+/fv49x48YhNDQUrq6uqFChAl599VXcunXL6nH4559/0Lt3b5QrVw5ubm6oUaMGpk+fjqysLMX24uKAf/75J3r16oVy5crB2dlZkt2n1+vx9ddfo0WLFvDz84ObmxsqV66M1157DRcuXLC4TxcuXMDo0aNRs2ZNeHl5oVSpUqhRowbi4uLw+++/W/1MPOnp6XjxxReh0WjQsGFD3L592+Zl7cFa8UlLRbvz+z0yxrBx40bExMQgKCgIrq6uCAwMRNOmTTF9+nTcvHkTgKkYHU9YWJjEd1b83Vn7PKdPn0ZsbKywn35+foiKisLSpUuRl5enuIz4d3Pq1Cn06dMHgYGBcHV1RfXq1TFt2jTk5OSofs7C5vTp0zh06BAAYNCgQRg8eDAAYMWKFTAYDBaX/fHHH9GiRQt4eXnB19cXrVq1wqZNmywuc+/ePSxZsgTdunVDtWrV4OnpCU9PT9SqVQvvvvuu6s0WcXFJxhgWLlyIevXqwdPTE+XKlcMrr7yCq1evCu137tyJ9u3bw9/fH56enmjZsiX27t1rz6Gxmz179kCj0SA0NBQAsHr1akRFRaF06dKSIr4PHz5EYmIi+vXrh5o1a8LHxwfu7u6oWrUqRowYoXrus/R/4jl+/DgGDx6M0NBQuLm5Cd/L8uXLLX6fqampmDFjBpo0aSKct8LCwhATEyPJxgwNDRVGDu3du9fMx1lePDW//xm1c21hHYvjx49jwIABCAkJgaurK7y8vBAaGoqXXnoJc+bMAWNMdXsEQRBE4UIahoM0TP4hDUMaxl4Nw2Pv9aVcH3z//fdo1qwZvL29ERAQgG7duuG///4T2h87dgw9evRAYGAg3N3d0bBhQ2zYsEFxXxzVFIW5T2LNJubdd98VtqlUP+nWrVsIDAyERqPBZ599prju/MLXppoyZQrS0tLw3nvvoXr16nB3dxeOBQCcO3cOs2bNQps2bRASEiJ8z82bN8cXX3yh+rvn/7Nqts95eXlISEhA27ZtERAQABcXF1SoUAEDBgzAX3/9ZXHfT5w4gdjYWFSpUgXu7u7w9fVF7dq1MWbMGKG2DP+98pp36NChknOG/DtxpM8Uf78GgwFfffUVGjduDB8fH4kWLIxjYTAYEB8fL+hrnU6HgIAAPPfcc4iNjcXu3bstHkOimGEE8YwRFRXFALDJkycrzt+/fz8DwACwU6dOSebp9XrWq1cvYX7FihVZREQE8/HxYQCYs7Mz+/bbb83WuWTJEmEZX19fVrduXVanTh3m5+fHALAKFSqYLXP58mX23HPPMQDMycmJVapUidWtW5d5eHgwAKxUqVJs9+7dqp9v+vTpLDg4mGm1Wla7dm1WrVo15uTkxACwypUrs7S0NLNlQ0JCGAA2a9Ys5uHhwVxdXVmDBg1YeHi4sP/NmjVjDx8+NFuWnz979mym1WqZl5cXa9iwIatWrRqbMmUKY4yx9PR01rJlS6Ft5cqVWcOGDYXP5O7uzjZv3qz4vSQkJDAXFxcGQPhMdevWZd7e3gwAi4qKkrQfPHiw4vd848YNVq9ePQaAtWvXjmVkZChuTwl+nYMHD7apPf9dJCQkKM5PSkoSjoXaso58j48fP2YxMTHCuv39/VlERASrUqUK0+l0kn2Kj49nkZGRQtuIiAgWGRkpPOLj4236PGvXrhW+H09PT9awYUMWFhYmrLddu3bs8ePHZsvx87/99lvm5uYm/G6CgoKEeT179rTpeBcG48ePZwBY8+bNGWOM3bt3T/ic27dvV13uo48+Eva/TJkyLCIigvn7+zMAbP78+cK8pKQkyXJffvklA8BcXFxYpUqVWEREBKtatarwvQUFBbFLly6ZbS8hIUH4H/Tr148BYOHh4ax27drCsqGhoezevXts0aJFTKPRsMDAQNagQQPm5eUlbPOPP/7I1/Hi90PpN717924GgIWEhLBx48YxACwwMJA1atSIBQUFCeezLVu2CP/zChUqsIYNG7IaNWowd3d3BoD5+Piww4cPm63f0v+JMcY+++wzptFohPNn3bp1WcWKFYVlunXrxnJzc82WO3r0KCtfvrzQLjw8nEVERLCyZcuaba9Xr16satWqDADz9vaW/JciIyPZzZs3hbb5/c9YOtcWxrH4+eefhd+Sl5cXq1WrFqtfvz4rU6aMsJxer1fcHkEQBJF/SMOQhiENQxrGVopawzDm2PWlWB+8//77DACrVKkSq1evHnNzc2MAWOnSpdm5c+fYxo0bmaurK/Pz82MNGzYUzkMajYatXbvWbH8c1RSFuU9izSYmJyeHNW7cWPF3k5eXx1588UUGgEVHRzODwaD6/VmD/y8onQv488SoUaNYeHg402g0rGbNmqxBgwasZs2aQruePXsKeiA8PJw1atRIOAfzny07O9ts/ZMnT1bd9v379yXn2KCgIFa/fn1WqlQpBoDpdDq2evVqxc80Y8YM4Xfn5ubG6tWrx55//nnm6ekp2d7x48dZZGSkcO6tWrWq5JwxatQoYZ2O9pn899uqVSvWo0cPBoAFBwezRo0asYCAAOF/UxjHYuDAgZJlIiIiWLVq1YTjMGDAAMXjR5QMKDhDPHNYEzb8hYy3t7fZhdiUKVMYAObh4cE2bNggTM/KymJvvfWWcKI+dOiQMC83N1e4oPniiy/Mbl6dPn2aLViwQDItOzub1a1blwFgMTExLDk5WbKt9957jwFgAQEBLCUlRfHz6XQ6Fh0dzW7cuCHMO378OAsMDGQA2KRJk8w+O9+p6nQ61qlTJ8m69+3bxwICAhgANmLECLNl+Y7A2dmZvf322ywzM1OYxx/HIUOGCBd6+/btE+anpaUJN5S9vLzMLvZ27NghXMyPHDmS3b9/XzL/0KFDbNGiRZJpSsLmzJkzLDQ0VOiccnJyzD6HJYpD2DjyPfL76evry9atW8fy8vKEeY8fP2YrV66UHH/GTN+f0oW2tc9z+vRp4UJ12LBhEuH722+/CRcxI0eONFsnv12dTsfeffddye9m1apVwoXWrl27VPersNDr9cJxFl98de/enQFgffr0UVzut99+Ez7Xp59+Khx/vV7PJk+eLIhLpeN9+PBhtm3bNpaVlSWZfvfuXfbqq68yAKx9+/Zm2+QvBHU6HStXrhw7cOCAMO/ixYvC775r167M3d2dLVmyRLiwf/jwIYuOjmYAWGRkpEPHSr4fSr9pXug4OzszV1dXlpiYKOyDwWAQPvOZM2fY+vXrzW46pKenC4KxZs2aZsLE0v9pzZo1wn9ixYoVkv/En3/+Kdy8mTZtmmS5W7duCb+B1q1bs3PnzknmJycns48++kjxGMhFl5iC+M9YOtcWxrHg+6R3333XrG+8fPmy5LdOEARBFDykYUjDkIYhDWMLxaFhHL2+5PUBHxgVBznv3LnDGjRowACwtm3bMl9fXzZt2jThXKTX64XfTHBwsNl1qKOaojD3yZJOuHjxohA4+Oqrr4TpH3/8MQO4pLZbt26ZLWcPtgRnnJ2dWd26dSW6R9ynbNy4kR0+fNjsuJ0+fZo1bdqUAWAzZ840W7+lgESHDh0YANaiRQtJckFeXh6bN28ec3JyYm5ubuzs2bOS5fjj6eTkxKZNm8YePXokzDMYDOy3335jK1euVDwGauc3xhzrM8X74+zszPz8/CSBUL1eL/xOCvpYnDhxQuj/5ckPBoOB7d27VzFYSJQcKDhDPHMoCRu9Xs8uXLjAPvzwQ+bs7MwALttHzMOHD4XO8rPPPlNcNx/h7ty5szDt5s2bwoWKrfBZahEREaoX3126dBEunJQ+X0BAAHvw4IHZcnPmzGEAWP369c3m8cKmdOnSitlYq1atEi5C5RcG/IVadHS04v4mJSUJ4uSHH34wm6/X64UsJXHWAmNMuAAaNGiQ4rqVkAubgwcPCgJz/PjxDmWc8Ou09Jg3b57QviCEjb3f419//SWsc+fOnTZ/tvwIm9jYWAaA1apVS/G4Ll26VPjdiAWaeLtt27ZV3Cb/O3/zzTdt/iwFxcaNGxnAZeCIv4NNmzYxAMzV1dVMZDPG2AsvvMAALjtMiXbt2tl0vJWoUKEC02g0khEYjEmDIkr/r0WLFgnzlW5M/PPPP8J8pd+brdgSnAHAPv74Y4e3wWdJyi+G1f5Per1eOLeJL67FHD16lGk0Gubr6yvJ9OIvvqtXr66YNamELcGZgvjPqJ1rGSucY+Hq6soAsNTUVNXtEgRBEIUHaRjSMKRhzCENY05Ra5j8XF+K9YHS+emnn34S5nfs2NFs/r1794Rr1JMnT6odEkXUNEVh7pM1nbB69Wrhuzt16hQ7cOAA02q1TKPRWBzxZCu2BGdcXFzs1qg858+fZwBYjRo1zOapBST4oGClSpVUdejo0aMZAPbGG28I07KzswV3A6VgrxrWzm+O9pmMSbXwqlWrVPehoI8F/7tR++8SJR+qOUM8s0ydOlXwl9TpdAgPD8fHH38MPz8/fPbZZ5g4caKk/b59+5Ceng43Nze88cYbiuscP348AGDHjh2C12bZsmXh7u6OtLQ0bNu2zaZ9W7t2LQAgLi4OOp1OsU3Pnj0BcLUjlOjfvz98fX3Npjdr1gwALHojx8XFwcvLy2x6nz59UK5cOej1evz666+qyyrxyy+/wGAwoFKlSsK+i9FqtRg3bhwA4KeffhKmJycnC56eH374oeo+W2Lr1q1o27Yt7t+/j3nz5uHzzz+3WIvBGmXLlkVkZKTio0KFCg6vVwl7v0fe37ZZs2Z44YUXCnRf1OB/1+PGjVM8roMGDULZsmWh1+vx22+/Ka5j5MiRitNt+b0WFnzBzJiYGMl30KFDB5QpUwbZ2dn4/vvvJcs8evRI8LEdM2aM4nr537kaWVlZ+P777zF8+HC89NJLaNmyJVq0aIEWLVogIyMDjDHBO1eOn58fXn75ZbPpDRs2FF6/9tprZvOff/55uLm5AQAuXrxocf8KgmHDhlmcn5eXh02bNmH06NHo1KkTWrVqJRyD8+fPA+C8hW3h8OHDuHz5MsqVK4fu3bsrtmnYsCFCQkKQmpqKY8eOCdN//PFHAMBbb70Fd3d3m7ZnCwXxn1E711oiP8ciJCQEAMx+8wRBEETRQhqGNIyjkIaRQhrGRH40TH6uL8UoaRRrGsbf3x9hYWEAlI91fjVFYeyTJfr27Yu4uDhkZWWhd+/e6NevH3Jzc/HOO+8gOjrarnU5Stu2bSU1ZpS4c+cOvvjiCwwcOBDt2rUT9CpfQ+Xs2bPIzMy0aXt8v9GvXz/Fcwag3G8cOHAAN2/ehKurK95++22btmULjvaZYkqVKoXevXvbvW1HjwWv0w4dOoRLly7ZvV2i+NEW9w4QRHERHByMSpUqAeCKK164cAGZmZnw9fVFmzZtzNqfPXsWAFfwWemiHwBq164NgLu5mpycjGrVqsHJyQnjx4/H9OnT0alTJ9SuXRtt27ZFs2bN0KpVK5QrV85sPXyRr6+++gqJiYmK2+ILxYmLfIupVq2a4vTAwEAAQEZGhuJ8AKhVq5bidGdnZ9SoUQO3bt3C6dOnFds8//zzitP54/fcc8/ByUk5Lswfv6SkJOTk5MDFxQWnTp0CwF3kqH0mS2zevBkzZsyAs7MzVq9ejT59+ti9DjkdOnRQLShZ0Nj7PfLHiy8KW9ikpaUJRT3Vfjc6nQ41atTAnTt3cObMGcU2+fm9FgZ37twRBDZfQJNHp9NhwIABmD9/PpYtWyYRZRcuXBAKh6r9F9SmA1zxzk6dOqkWqORJSUlRnF6lShXF6WXLlhVeh4eHq7a5cuUKHj58aHHb+SUgIECyP3Ju3ryJTp06WQ2+qB0DOfz5NDMzEy1atLC6vqtXr6JZs2bIyMjA5cuXARTs/6mg/jOWfkdqOHosAOC9995DXFwcRowYgTlz5qBdu3Zo1qwZoqKiBEFAEARBFD6kYUjDOAppGBOkYQpOw+Tn+pInICAAPj4+ZsvYqmHOnDljpmHyqykKY59s4YsvvsCBAweEc1Xjxo0xY8YMu9fjKNY0xvr16zF06FCLn40xhvv379sU9OV/Pxs2bMAff/yh2CYrKwuAtN/gzxm1atWCt7e31e3YiqN9ppjq1atDq7X/drujx6Jp06aIiorC3r17Ua1aNbRs2RKtWrVCs2bN0KJFC9XPQZQcKDhDPLPExsZiypQpwvvU1FS8/fbbWLZsGaKjo3Hy5ElB+ACmCyslIcJTvnx5s/YAl+EWEhKCL7/8En/99RdOnTqF+fPnQ6PRoG3btpg9ezbq1q0rtH/w4AEAU4djicePHytO9/T0VJyuJirE8BeTluapXWiqbdeR4+fv74/09HQAUM0csMalS5eQl5cHX19f1KxZ06F1FCf2fo/5PV72Iv4d2PLd2vu74T8nY8zmfRo9erTiRfiHH36IDh062LSOlStXIjc3F+XKlVPMUho8eDDmz5+P48eP4++//0adOnUAmD6fk5MTypQpo7hutf+XwWBAz549kZSUhPr162Pq1Klo2LAhAgIC4OLiAgBo1aoV9u3bB71er7gOteMozga01saeY+0IatvnGTp0KE6cOIHKlSvj448/RvPmzREYGAhXV1cAXBbjd999p3oM5PDn07S0NOzfv99qe/6cyv+XgIL9PxX2f8YSjh4LgOsz/fz8MHv2bBw6dAjffPMNvvnmGwBAkyZNMGvWLLRu3drufSIIgiDsgzSMOqRhSg6kYZ4dDZOf60uewtAw+dUUxaWrPDw8EBkZKQRnhg4dqjoSsTCwpDGSk5MxcOBAZGdno3fv3hgzZgxq1KgBHx8faLVaGAwGODs7A4DdWu38+fPCaCY1xKNxCuuckZ8+k8cRnQY4fiw0Gg22bt2KWbNmYcWKFdizZw/27NkDAHB3d0e/fv3w6aefIiAgwKH9IgofsjUjCCO+vr5YsmQJmjdvjgcPHmDEiBGS+aVKlQIAIcNGiZs3b5q1B7iTZVxcHE6ePIk7d+7gf//7H8aNG4dy5cphx44deOGFF3D9+nWhPR/Z3rVrFxhXG0r1kZycXBAfX8Lt27etzhN/Pltw9PjxWRB8lp29jBkzBoMHD0ZKSgpeeOEFm62QCgJrF2WPHj0q8G3m93jZi/h3YMt3a+/vxhFOnTqF/fv3mz0s/a7l8HYAt27dglarFexD+Ef9+vWFtvHx8cJr/vMZDAbcvXtXcd1q+/Hnn3/i9OnTcHd3x6+//oouXbogKChICMwAto8WeVK5desWtm/fDoDLGO3bty8qVaokiCjA/mPAn09btWpl9XzKGBOG44szsAry/1Sc/xlHjwVP9+7dsX//fty/fx/btm3DhAkTUKVKFRw+fBjt27cXsr0IgiCIooM0jAnSMAUDaRgTpGFMqO1Hfq8vC4PC0BRFxaZNm7B06VIhwPf+++/jypUrxbxXHGvWrEF2djYaN26M1atXIzIyEv7+/sIoEUeOKf/7WbZsmU2/H57COmfkp8/ML44eC37ZGTNm4OrVqzh//jyWL1+OgQMHQqPRYNmyZYiJiRFGxxElDwrOEIQIJycnzJ8/HwDnGcxHmwGgRo0aALhsAbUhnHyWmJubm6pPZ5kyZdCtWzfMmzcPZ8+eRVhYGO7fv481a9YIbfhhkn///Xc+P5Fj/Pvvv4rT8/LyhGGe9mZw8cfvv//+g8FgUGzDH7/KlSsLN6T5bJ6UlBScO3fOrm0C3He6bNkyDBs2DCkpKWjbti2OHDli93ocgc+YULuQdeTzWIM/XgcOHCjwdSvh4+MjZJX8888/im1yc3MFK4CiyPzbs2dPvoTAoUOH8N9//wHgMsTUHn5+fgCAVatWCV6z4eHhQrYQvw45av8v3sqsZs2ailktDx48KJTfTEmCPwalS5dWHFKfm5uLo0eP2rVO/nz677//qp57lChVqpRwHrfn/2TNC744/zOOHgs5Pj4+6NChA2bOnIkzZ86gadOmyMnJwdKlSwtqVwmCIAg7IA3DQRqmYCANw0EaRora/6ugri8LksLQFEXBtWvXEBsbCwCYN28eunfvjtTUVPTv379E3Fjnj2uLFi0UR8EdOnTI7nU62m/w54x//vnHLvtAa1qtoPpMRyioPjQ8PByDBw/Gd999h0OHDkGj0eDAgQOqdWuJ4oeCMwQho1GjRujcuTMAYPLkycL0Fi1awNvbG1lZWfj6668Vl50zZw4AoF27dpJsdzVKlSoldCrirDO+eNiiRYtUh/wXJkuXLlXMiPrhhx9w8+ZN6HQ6tGvXzq51vvTSS3BycsKVK1eEItticnNzsWDBAgBAp06dhOkhISGIiIgAAMycOdOubfI4OTlh8eLFGDFiBB48eIB27do5dOFgL1WrVgUAHDx4UHG+2u8oP/Ts2RMajQYHDx6UCHNreHh4AFC3mLAE/33Nnz9fMcPuu+++w507dxz63RQHy5YtA8Bd8N26dUv1ce7cOeh0OqSkpGDTpk0AODHbqlUrAMCXX36puH7+dy6H/w5u376teBznzZuH3NzcfH++kgx/DNLT0xV/iytXrsSdO3fsWmeLFi0QFBSElJQUSYagLfTq1QsAd+x5f19r2PJfKq7/TH6OhRparRZNmjQBIO3HCIIgiKKFNAxpmIKCNAwHaRgpahqmMK4v80thaIrCJi8vDwMGDMD9+/fRpUsXjBkzBkuXLkVwcDD2798vsbMsLvjjKh45wsMYw+zZs+1eJ99vrFy50q5RYs2bN0dQUBCys7Mxd+5cm5ezds4ojD7TVhw9FpaoXbu2UDuJtFrJhYIzBKEAL2h+//137Nq1CwB3wfLWW28BAKZMmYKNGzcK7XNycvDuu+/i999/h7OzMz788ENh3n///Ye4uDj88ccfZpkkv/32G3bu3AmAE1Q8r776KmrXro3z588jOjpaMXL+77//4qOPPsKWLVsK5kOLyMjIQP/+/QXPS4DLYho3bhwAIC4uzqIHpxIhISEYNGgQAGDUqFGSAmcZGRkYOnQoLl26BC8vL+E483z22WdwcnLC8uXLMXbsWLOhq3/++Se++uori9vXaDRYtGgRxo0bh7S0NERHR9vkiZsfunbtCgDYsmWLJKswKysLH3zwgV3Cw1Zq1aolZFf17NkT//vf/yRiIysrC4mJiWYF5vhihjt27LB7m+PHj4ebmxv++ecfDB8+XCKKd+3ahbfffhsAMHz4cLt/N0XN48ePsXbtWgCcv68lAgIC0KVLFwAmMQQAEyZMAMAV8pszZ47wv8/Ly8P06dOxe/duxfU1a9YMOp0O169fx6RJk4TsKIPBgEWLFuGTTz6Bm5tb/j5gCef5559HQEAAcnNzMWrUKElAZP369Rg9erTdx8DFxQWff/45AM7Le/78+RKPXgB4+PAhfvzxRwwbNkwy/Z133kFgYCDOnDmDTp064cKFC5L5ly9fltwAA0z/pX///Vd1OHxx/WccPRbp6el4+eWXsX37diHDkufYsWPCf0bcjxEEQRBFD2kY0jAFAWkY0jD2aJj8XGsXFoWhKQqb6dOn4/fff0dQUJBgT1e6dGmsWrUKzs7O+OSTTwrlv2cPUVFRAIB169bhp59+EqZnZGRg2LBh+PPPP+1eZ+fOnREdHY379++jTZs2Zv9xgKvB9dlnn0lG6et0OiHwPW3aNHzyySeS3x1jDDt37kRiYqJkXfw5Y/fu3YojvRztMwsCR49FYmIiJk2aJIz049Hr9fj888+RmpoKZ2dnia0hUcJgBPGMERUVxQCwyZMnW2zXqVMnBoC1bNlSmKbX61nPnj0ZAAaABQcHs0aNGjEfHx8GgDk5ObFvv/1Wsp4TJ04I7T08PFidOnVYo0aNWFBQkDA9JiaG5eXlSZa7cuUKq1+/vmRbTZo0YfXq1WO+vr7C9ISEBMXPJ5/Ok5SUJCwrJyQkhAFgs2bNYh4eHszNzY01bNiQVa1aVVimSZMmLD093WxZfn5SUpLqMU1PT2ctWrQQ2oaHh7OIiAjm4eHBADB3d3e2efNmxWWXLVvGdDodA8B0Oh2rU6cOq1u3rnDso6KiJO0HDx6s+j2/++67DADz8vJie/bsUd1fOfw6Bw8ebPMyPXr0ED5vhQoVWEREBCtVqhRzc3NjixcvVv0u8vM9Pn78mHXt2lWY7+/vzxo1asTCw8OFYyhf7+zZs4X2NWrUYK1atWJRUVGSdpb2ae3atczFxUU4rhEREaxy5crCOtu1a8ceP35stpy1301CQoLi91tYrFixggFgLi4u7O7du1bbb926VfjvX7t2TZj+/vvvC5+tbNmyrFGjRiwgIIABYPPnz1f93B999JEwr0yZMiwiIoKVKVOGAWDDhg1T/Q6sHSdLvxce/v+/e/duq59bDX4/lLaze/duBoCFhIRYXEd8fLywDh8fH9awYUNWoUIFBoC1b9+eDRw4UPG/be0zzp8/n2m1WgaAubm5sbp167ImTZqwKlWqMCcnJ9V9O3LkCCtXrpyw7qpVq7KIiAgWGBiouD2DwcBq164tnPMjIiJYVFQUi4qKYjdv3hTaFdZ/pjCOxYMHD4T1ubi4sOeee441btyYhYaGSvqGR48eqe4TQRAEkT9Iw5CGIQ1DGkaN4tYwjDl2rW2LPrB2rNW+X0c1RWHuk9rvYu/evczZ2Zk5OTmxXbt2ma1v8uTJwn/x3r17qvtlDX6/lM4Fls49PHl5eax169bC5w8LC2MNGzZkHh4ezMnJia1cuVL12PCfQWnbDx48YC+++KLZb69BgwaCFlbbt+nTpzONRiOci+vXr89q1arFPD09Fbd3+PBh4fdYoUIFFhkZyaKiotjYsWOFNo70mYzZ/r8v6GMxb948yXmzQYMGrH79+pI+d/bs2Rb3iSheKDhDPHPYKmyOHDkinMh27NghTDcYDGzVqlXshRdeYH5+fkyn07GgoCDWr18/dvToUbP1PHr0iMXHx7P+/fuzGjVqMD8/P6bVallAQAB78cUX2cqVK81EDU92djaLj49n0dHRrEyZMkyr1TJPT09Wo0YNFhsbyzZt2sQyMzMVP19+hM3u3bvZqVOnWK9evVjZsmWZi4sLq1atGpsyZYrixSljtgkbxhjLyclhixYtYs2bN2fe3t7MxcWFhYSEsGHDhrFz585ZXPbMmTNs+PDhrEqVKszNzY15e3uzmjVrsmHDhrF9+/ZJ2lq7uJg4caIgNsXfryUcETbZ2dls+vTprFq1aszFxYUFBASwHj16sL/++svid5Gf75Ex7nf6ww8/sA4dOrCyZcsynU7HAgMDWdOmTdmMGTMkN4kZ4y60Zs+ezerWrSsITfnxs7ZP//77LxsyZAirVKkSc3FxYT4+Pqxly5ZsyZIlLDc3V3GZkiZs+IvNnj172tQ+NzdXuEkxY8YMybwffviBNW/enHl4eDBvb2/WsmVLtnHjRsaY5c+9dOlSVq9ePebq6sq8vb1Z06ZN2dKlSxlj9l/o8zxJwRnGGPvf//7HmjVrxtzd3ZmnpyerW7cu+/zzz5ler1f9b9vyGU+fPs1GjhzJatasyTw9PZlWq2WBgYGsdevW7NNPP1U9B6WkpLDJkyez+vXrMy8vL+bm5sbCwsJYt27dWGJioln7q1evssGDB7Pg4GDhZoLS910Y/5nCOBa5ubls1apVLC4ujtWqVYv5+/szZ2dn5ufnx1q2bMm+/PJLlp2drbotgiAIIv+QhiENwxhpGNIwypQEDcOY/dfahRkIYcwxTVHUwZmUlBRWsWJFBoB98MEHiuvLzc1lLVu2ZABY165dVffLGvkNzjDGBVEnTJjAwsLCmE6nY2XKlGGdOnUSAsaOBGcY4/7L69atYzExMax8+fJMp9Mxd3d3VqVKFdavXz+2evVqlpaWprjsn3/+yQYOHCj8h/38/Fjt2rXZ2LFj2V9//WXWfvPmzax169bM19dXCNTI/6v29pmMFUxwxpFjceXKFTZ79mzWqVMnFhYWxjw9PZmLiwsLDg5mffr0Yb///rvF/SGKHw1jCuaaBEEQBEEQTxCXLl1ClSpV4Ozs/NTX5iEIgiAIgiAIgnhSmDRpEqZPn464uDiJJRdBEFRzhiAIgiCIpwDeX7506dLFvCcEQRAEQRAEQRAED2k1glCHgjMEQRAEQTzxrFu3DgDQoEGDYt4TgiAIgiAIgiAIAgDS0tLw66+/AiCtRhBKaIt7BwiCIAiCIBxl9OjR+Omnn5CUlAQAGDlyZDHvEUEQBEEQBEEQxLPNrVu3EBMTg7NnzyItLQ2VKlVC165di3u3CKLEQSNnCIIgCIJ4Yjl16hRu3LiB+vXrY/Xq1ejSpUtx7xJBEARBEARBEMQzTVZWFv788084OTmhR48e2LVrFzw8PIp7twiixKFhjLHi3gmCIAiCIAiCIAiCIAiCIAiCIIhnBRo5QxAEQRAEQRAEQRAEQRAEQRAEUYRQzRkrGAwG3LhxA6VKlYJGoynu3SEIgiAIgiCIQocxhoyMDAQFBcHJifK5CMuQZiIIgiAIgiCeNQpCM1Fwxgo3btxAcHBwce8GQRAEQRAEQRQ5V69eRcWKFYt7N4gSDmkmgiAIgiAI4lklP5qJgjNWKFWqFADuIHt7exfz3hQfP/v4CK87VDa+eNH43ApgkdzLv0tXAwD8i+dxAVUBAJcQhiRwC11DBdxLqgjc0ABXjMvfAnAHQIrokQbgEYDb/FYzAWQAeABAb3zOBZAF4LHxmW+Xa3zNPwOmn7oWgM74rAXgbnxoAZQC4Gac72d8784t5gbAB4Cn8dnf+IDxuSyAcsb3lYzPQQwBYdcAABVxHQAQhksAgMpIAgCE4zwA4Hn8CwCoc/8cAECzX7Trv4te7+Cefr4EMzqkpZlPJAiCIEok4n4VEPWtgKR/5ZH3swDX1wKQ9LcAkITKuIYKAMD1uYDlfheQ9r1pMHWruAOub80C1w/zrzNh6nP5Z/6hNy7L98N8Hyzuf/k+2A2Ah/G5lOjZHVznClMfHART/+tvmo1y4PpeY78r7nPF/a24rxX6Wb6PVehfn/V+NT09HcHBwcK1MEFYgjQThzXNBEjP5+LzuPwcbnb+tqSZFM/bGTCdu/lztfh8rYe5bhLfHuDP0/w528P47GacXkr22t10vgbMdROvnXjdJNJMAMzO34B1zQTA/HwOCOd0wFw3PevndoIgiCcJmzQTYNbHAtL7k4CyZgJg3ucC0n4XsKHvteWepVw3AdK+GDDdswS4TpV/z/e3vH7yg9m9S7X7luJ7lkbNBJj6XXGfK+5vze5P/g7STAoUhGai4IwV+GH53t7ez7TQ8BC99nY2vnARzTwBsCjAyzjTHTq4whUAoIM7nOEJAHBCKaCUN+CpAWoCSAZ3AnE1rk8HwBlcNSQncOedTMB0ctICSDc2yjA20hkXlt8MgmiaeHkYV6yTPTsB8ILp5hEAGL9zjWifnAGkGpsFGPfdzTitPGD8qEApBidv7s/Jf36dMdjDHxt34355gTtu3vyudwA0e42v+eMM47YB9KkKbDkPCc/y75MgCOJJYotGI+lXAVHf2l40UdSIGU/xfD/7N2rDHcA5VIcrgAuoAh2Aiwg3dqPGi8NS3sA1Ddc3uQO4CfN+9y6kfa/gSJQJ7mJfC4ABMBgffCONsQ0gvdHnLJoGmAdn+GdeYHgYnz3B9bvuAAIhJEhojLvhDFOXz/e95QGEGldv7HfFfa4rXFENZwHoJH2thj+2fB/rbHbIqV81QhZVhC2QZuLowxi2GI/F3ktAl6oAdoM7txtPMOLzOa8FXOEq6ARneMIJpVC2bhruXArmdBN/L4Y/d5cHdz7k5ZDkb6oDl+HG6yYncOdqjXEef2NIC+nNIX5ZHvn52gPSc7crTMF047k7y7gZX5h0k/i8DZjuL90Fd/4uxd0kcvIuhRuogWBctaqZLnnXRL37pwGIzucvAdhlfM13Q4B5f/sM/z4JgiCeNFQ1EyC9VybrY0+Wrgkv4yxxXwtA0t9eRTBXjL2UcUFPDXefUtzv3oPy/UqzvjcT3D1FPYAc43ve5orXTe4waSa9cYW2aCb+viX/4O9dloLQB8vvW8p1Ey+3SjGUrXwVAKebdHBHOC4CcIU7dKiDUwCchfuTGg9w/asLSDNZID+aiQykiQKFv0gGYLwZwlEFF6QNKzI71+yuMl0eX9QpTBOTK3uvlz1nyuZnmp5SwUXHUxVWe1thGsAJKjs4WbqmXe0JgiCIJ58uVS3PZ1GOrVe1D1Lqs1Jh6uPkXaEwQR6I0cvey/tYS+hEz3xWtgLu4G7yWUN2XWF23WFEfJ1CEARRpOwyn8TdAOHgboxYIRBcchiPL7gbLb7yhu6yZ0AaeFFDrzJdaVQkoNBhqHPP9qYXEQ6ASzwAuEQEgEtMsAlRooO8j91CAWeCIIgnApvP1y+oz1LrN/h+xmas9mHpotfivtFejaQGH6zhgzRQeBZRBtz1QgC4aweiREPBGcJu5CM27IWL0DoCH42Vj4KRv1ZCLjTEAkM+T02UqCA/SSebN7kKy0Eai0LDQkcjhoQGQRDEU4TCud/WAL5in5Mse2/zTTL5jTdrfaY1SzNxMEZ+o9CCwChMtnNP4uubLszeJBKCIIiCIxgKeqm8lYUsnjp1stdKCW62wJ/bxX2DlQDNXZj3OTdFr6/lT8M4msBAEARBPFlIgu3tzedb6g/4ID8f9FdE3B/dVJh/F9Jk7UyzF5BqIXFQpiACNGL4fls2akUxYQOK1xCK1xoyBEcfolCh4AxhE4o3Kbart3c4CywAXITXFyZ7MAmOZoGpoXaytDELzE6hoZYFpoRix2IhC4wgCIIo+RREIP3pyQADTBlgPI4HZuTJH+LrD/F1iYBCFjtBEERho3SjQ+w4oEioI1uS6yaloLg9CW7yYLy4X0iHmX4qDscBG5PaCIIgCILHLrcBQL1/M3MbAKy7DVjTVDa6DQDW3QZCzSeJ3QbE1yLkNlC0UHCGyD923NywOwvMF8WcBZau0NbIXRtXTRAEQRAqlLgMMAG1DDAx+Q3QiPtseWcve+8Lm4bn25IBRhAEURJQu/FRsHbQSo4D8nlqKNXxtPTaCG8HrURhOw6IUehTechxgCAI4inhiXcb4Mmn24Ct+W12X1OYQ24DBQ8FZ4gCw6HhbqGObCm/WWBiSzOgUOrOJJteFmUWGAkNgiAIAshHBphqvRkepXoz4mdbULshqDA83xcqI2nBJXeEWt6SUjY6Dc8nCKIwEd+oEG5gWHAcsISqHbTcccAMteK81upzAsp20PJ5DtadEVPYjgMiyHGAIAjiycLivS0LwXcxfDBfntCm6DZgKaHtWXcb4HHwWoawjrUrM4Kwm3r3TwsBhmo4K5wIq+CC9CRYkdnpMewO00lPB+6kp4W5gNApTFNC7ocvRg+7Tnb3IC3OqcBVBCMYV3ER4ZKMuHOojmo4i79R2/KJkKc9hJNil6r5rwFEEATBo9frkZeXV9y78dSyu0YNaENCJNOygkRvfIzPTUyTWA73fNq3CpDFvXZBaQCAF7wAAKXhAgB4CA1uoRwqIhfuecbGGg1wHYALAA9w9+seAPAzbsDJOD3TOC+L33I2gIcADACY8b0BpoQIfro7uD7TA1y/6mJ8djWuh8/4cjYu52x872Zs42pcBqL1ZwPQcE28jZsoBaC0cb+9jZtz4ZqhHAPyAP+sXARCg0q4DMAFXvASjhVDedRMvYgshEDDd/v88TZ+B9oc03HPysrC04azszN0uvxYwRIEkW92wSzpqg5OCTeQwnHR8mhI3nFAHmjng9iSOIm7cQL/rIP0JpGSjrJErmgdvBbj9RLfiajAOw5Y0UsEQTx7kP4glJBrpjahIpniI27IPbFGAIzX8iyL6ywtaSYAYMY+0T0vi9NM3EKczgC4bo3/ad43TteA0yO8fEEWuL40z7QDMMCkeVyM792MbbTGZ3mCm1b0LNZNvF7SghNu7uA0WJ5xuxpuH3jdVBYmzeQHwB+mz6UBkMfgn5WLcrgFQIPSRs3EHy9mvNDIynkMAJxuOgzumJNmKjQoOEM4xJbz+ctAKlv5qnpWr/ii/ZF8pjeUs7P4gIwlgaGHcrYunwWshXkWmA0BGnlg5iZMwumaRnXY4AVUsVqPh0VRhi9BEIVPeno67t27h+zs7OLelaeast98I3nvoQWS+Dfi7sZN9DqFe3J+wPVfOdAhHEAudAgFoIcWkcb3AJBr7DgNuclcl+YFoAq46/dQ47PB+IDoPRM9CzOY6FHKuDKm8IB4QdFrcQKGRvSs9HAybjMVXCKGxjSZf3Y2PvOvec1j1AROSbnQwgNaYyKIDqHQoil3/KBHssF4jcAHxDqA6+pbA49zOS0DAB4hIUhKEr6ZpwpXV1cEBATA29vCTVSCIEoEwbgqtVwJhaINmAQ+DiOZwCedZRin8QEWXjfxGkgtwY3XSfxrsZ7KFL3nk+hEHRqv5Xxlq7wNkz1lMoRRkHcuBauPFlLgZOma5vZwL8Am2+0tGg3ZsRBEMUL6g7CEXDclKbkhK2imHCcdnB9wr8ON/VOo8Vlv7MvEusmQqwWQzMmcXHC6KRSmoAyvm1Q1E4NUXLnAFJCxppnE723RTU6i14/BJbQZ+3ZnWNdMAJBl0kwAoEW4UVUCWjSFs/E6INmg5zRWEDjNBACtuaeyotumpJkKBgrOEDbThTHzoYXboTqk0K4sMMB0gS4eMlgkWWA8allggGqQ5i44OwErFKXQIAiCsIf09HRcv34dXl5eCAgIgE6ng4YsEguF9EfSjANvF9EbL9FrY9/HRHZeWVqucZZRheQYR6ZkG9O2cozPeqPYyMvRAXrj95gDU7eYB1OyVi5MOoIXHUz+RqxE5OJDPB0wFxs8coEhjrjIoy8aCDf6NJCKC16A8SJDB0776BicXbgPqIMeLsiBqzFzzQXZcDNGb9xyuWmaRzDdMHzIPaWLMsC8w8JUPseTC2MMer0eaWlpuH79OgBQgIYgihHNXnMrLrHjQP4ROw6IkeskSwEZXmOJAzNqr2VaiZdQvgqrdsBxgE9qI8cBgng6IP1BWCL9n3/gHyDtKATdpKCZAJNukmsmQFk3STQTwOkmXg/wA2H4AI1YN0k0E4wT+JlircQUnu3VTGLtxD8AkxgyCiW5ZhIbF8g0EwA4u+gFzQQArsiBC7ggqRuypJoJeKZ0U3FpJgrOEAWDwhB9NcyywABupIlSoWIeh7LAIJqenyww/r27+aAdX9l7lSwwHl5oqOGI0JBDWWAEQdjDvXv34OXlhYoVK5IoKmTkg77dxIeb745KmSYxo6bI1LoJJmF5cAYAGGSZX07GFhrokJvtwo1+12i4hCq+OxPfY5PrBUkGGK86+AUMCg9+QUAqNMT9j0bhWS4wxGldTsaddTI1F4sKPutLXGrOBYALg8bVCS7QA3CFMxi0xn3TIReuxvW5GfM4NHoAfJKmcdfE342bmzgN7+nB3d0dpUqVwrVr13Dv3j0KzhBECUBsBy0m/3bQgMlxQMkO2hFLM3lgpugcB2yBHAcI4smB9AdhCblm8nUVvRHfxTYGbJhYP2m5635eM2XDzTjQ3hXOxvdcahifDOYC5Gi4dTGYbg3yP8tccBqElz+8nBG6K72oAR+1UdJMSsEZ+agZawlt4mQ2PvriZB7HEY+W4XWTUTMBgMbVCU5wgrNx+1oYoDP26a5wkmom7iAKu5WabXJ0842IwNNIcWgmJ+tNCMJ2lC6IlYrxSgi1ZwtqF/3yOKO1YpdKhS0tvbYBq0XCTFCBS4IgSgJ6vR7Z2dnw8fEhYVTIpB49qj7TR32WmMdWbnzlKFp3yrC5JqVB9FrpZpmauJAj/10pZYSp4GxxB01lbVRnm3sgazIUGj5DaDQa+Pj4IDs7G3q9I6OMCYKwl/yO1lAdfR9gfJQBlzDmqdyMQ9x/5MdHXe6RL85cU7KeFnHX8mweVetrFZSCW7YmDVosOE0QRKFA+oMoKeRmu1hvBJhG0ojtzQAo6yU12zJHkY+iAUy38u2/pa91zZG8d4PJVtDDUj+eZvemnmiKWjNRcIYoFMwsuYxUwQXphHxkR5mEhTgIY20wmPxPpRSkkWeB2YA8MCMeBWR3lhtBEETRwBffpCLhRY+vlcCCOANMTrZxmH6WMERfYaRHjqjvsdT1KSIuyioeHSMfGWMJpaCMWkBGQWDYeoXqIt0fpYAMALjnPn3FKh2F/79T8V2CKDwUR7Hzo94VbIrFI+fFNSkVR9yXh2mkvhxfyHLZ5EF9ney1UoKbEpYS2MR6KR1m+ukRuHJicm6LXiebzzZzWpDB22dbRWTBTUltBFH8kP4gLGExoU2MglbK1OZj9Lu49JHNyWw8am4C1pLYHIHXUbIsNn7EDN+18w8FXOxJRH+GE9uKUjNRcIZwmOLLAhMPKSuqLDAlz2YjlAVGEMQTDmWtPYPwGWBm2JIBVhDIgzSOX5LKM8BsxpgBlioSY0/r8Hwx9H8niCeYUEcWcpc9i70hIZqmhoOOA5lQDsoAxeY4QBBEyYGuRwhbkCS02ek2oJbQVnhuA4D6yBlrrgOF5Dagg1W3ASWedbcBoGjPURScIezCYhaYAmpZYIoUahaYGrZmgYne80KDssAIgiAIGyjKDDDJ8HxbMsDMhueLZ4hhsH/0jBoaldeA2aWpPAPMBqwOzyexQRBECcCaHXT+HQeUktiUTqS2jJYByHGAIAiCKGlYchuwCWtuAxYHmSi5DQBPmtuA+DW5DRQPFJwhCg6FIfpqmA3RD3Vkg/ZkgYmnK4mJQq47oyA0KAuMIAji2eTJygBTqzdjqailGmo33SwMz5ej1NXDzuH5BEEQJQg1O+iCQa2IrT3JbHKewLoz7c1n85DjAEEQRMmg2BLaJDMUpjlcb6YgKDi3gfwidhsgChYKzhAFjrUsMDOKPAtMThFkgTkIWZsRBEE8exR6BphFCrLejHw4vvi1BaHBeyZbwsrwfKXaMzQ8nyCI4qJQ7KADoWwH7StvKE9os8fSTIwDdWcKyHGAT2rj4ZPabHYcEEGOAwRBEE8G+UloU0MxoS1fQQdrbgP5CdjY4TYAWHUbkFtBW3Ub4Ekzn/QsWEEXJRScIQoNtSwwsyH6DlEcWWAqdWdSYZ4FdluhHWzPAiNrM4IgiCeP5ORkaDQaDBkyRJhWIjPA9EDXdqHo2j6Ue19kGWBKSQP5uBS1MDxfDA3PJwiiOLBoB63gOKBmB23mOABwdtCWMLs3JQ/QAJxuUkpmU5pWvI4DPLzjgCXIcYAgiGcJJf3xrMAntA2NnQJPTW1cTr4uzFNzG7AJUZe3eNEUNKqjwbEjexQaWrMys7XejBwH3AbkiW0F4TZAiW1FBgVniHxRKFlgQAnNAhO9F2eBWSPZfJI8C4yszQiCIIqf2NhYaDQa+Pv7Izu7cMZtl7gMsDylieYZYFlZj5CYuAgTJw5Hr16RaNy4PBo1Ko8bN1T6cVWURs4o+CfbiDwDDAAunbuEuN6jUT2gEfzdI1CvQX98/c16MPGNUmMGmHh4vjwD7Nq1axg+fDgqVaoEFxcXBAUFYejQobh61d7PTBAEUYCEFsRK5DpJLcGtGOrOOAg5DhAE8SRSWPqjdevW0Gg0Fh979uwR2k+ZMkWYPn78eNX1vvfee0K7KVOmqLb7/fffhXbr1q1Tbbd8+XKh3eD33lNt9/V366EJaQRNSCMMGam+XbsopHoz33zzCUaM6IlOneqjRYtQvPjicxg0qD2+/34xsrIeK6yreNwGsrOzMW/aPDSu2hal3RqiStALeO31j3H30n0rK5ViMBiwcOFCNGjQAB4eHvD29karVq2wefNm1WUOHz6MmJgYBAQEwNXVFVWrVsWkSZOQmal8HfHgwQOMHz8e4eHhcHV1RZkyZdCrVy/8+++/du1rSYWCM4TdWMwCU0AtC0yR8uCG6KvhUBZYMdadKeYClyQ0CIIgbCMjIwM//PADNBoN7t+/j40bNxbLfliyNCuoDDDLKNebuX//HhYsmIzt2zcgJycL3t6+sjb5zQBTwBnS4fk25Fxc+e8UOjfujF827UB0h0i8MaY/8gx5GDn6U4yZMNvmDLCLFy+iYcOGWLx4MWrWrImxY8eicePGWLFiBSIiInDxopXrGYIgCBuwZgdt5jhQ4HbQJaDuDDkOEATxjFIU+uPtt9/G5MmTFR+hoaFm7bVaLRITE5Gbay4ecnNzsXLlSmi11vuO+Ph4AIBGo8GyZcusttc6O+OXffuQkppqPtMHiF+7GVqtciTio0/fxvHTmxBUoazV7QB21psxiJ5tdBtYty4ejx8/RNOmUejb91VER3fjAiHzpiA2tiuysvh+0h59VLBuAwaDAbExsZgzeQ78A/wwctxANG9aC/HLNqFZ+1jcvffAplUzxtC7d2+MHj0a6enpiIuLQ9++fXH27FnExMRg4cKFZsts2LABLVq0wPbt29G+fXuMGjUK/v7+mD59Otq1a2cWpExJSUGTJk0wZ84clC1bFqNGjUK7du2wZcsWNG7cGIcPH3b82JQQKDhDFCwKQ/TVMBuiH1oQO6CUBWYLhZQFpkBhFrgkoUEQBOEYa9euxaNHj/Dmm2/CyclJEBT5IT+WZg5T4Blg3BB9X18/LFy4Djt2nMHmzUdRs2ZdKzuilAEmfi8XHc7Sl2pXqFaG57//xvtIT0vHio3fIP67mZjx6Vs49mciWraoj4VLfsDBY39b2W+OsWPH4s6dO1iwYAG2b9+Ozz//HBs3bsTatWtx584djBw50qb1EARB2IKaHbQYVccBQOo4YIbYDlopWGMPhVB3Rkyy+SRyHCAI4mmlMPSHnPHjx2PKlCmKD6XgTIcOHXD79m1s3brVbN62bdtw69YtdOzY0eI209PTsX79etSpUwcvvvgifv31V6sjz19s3hw5ej1+2LYNgNRt4O/T53Hs1Gl0bBOpuGz58mVQvUZl6HVcfydPYOPf2+Q2YHNCG6BWb2bbtlNYvvwXfPTRfIwa9SHeffcTrF27Bx069MT58/9h8+a1Nqy7cN0GNqxYg73b96JHvy7YdmAdps96E+t/+AyLvnwPl5KvY+LHX0vap6oM6vrxxx/x448/IjIyEqdOncKXX36JxYsX499//0VISAjGjx+P5ORkoX1mZiZef/11aDQa7N+/H6tWrcKcOXNw8OBBjBw5Evv378e8efMk25g8eTLOnz+Pt956CwcOHMCcOXPw/fffY8+ePcjOzkZsbCwMBvl38WRBwRmiUHh6s8AcrDuTbL4ILzTUcKTAJUEQBOEY8fHx0Gq1ePfdd9GmTRvs3LkTly9fVmybl5eHTz/9FOHh4XBzc0N4eDhmzpypelG47+hRjJo2DU169YRXk1bwatIKEQMGYfH3GxTbO+kaIerFN3Dj+m0M6f8uqgc0QmipuhjUaRAuX+L26cLpcxjdrS+ala6EyFIBeLvHAKTcVkg/VhIYeRCSuzLSUvHJJ8PRvn05REaWwoABTbB9Oy8YTH2zh4cXmjRpDR8fP8V9NmEpA0ytkKWT+SRLqAzPTzp3EYd/P4wWbZrixQ6mO3MuLjpMf284AGDJ6o1WV5+VlYXt27cjMDAQo0ePlsx7+eWXUa9ePWzfvh2XLl2yYWcJgiA47LWDtuo4AHCOAwGyaRbtoHl0ste2aKeSUXfGbsjajCCIEkph6g9H6dGjB3x9fRVHuyxbtgx+fn7o3r27xXWsXr0ajx8/xqBBgzBo0CAYDAYsX77c4jKN69RBtdBQrFIICi37YTOcnZ0xuGcnYZrYbeC1IR/CU1MbV5KvcfMYwysdX0FFTUX8tPZ/knUxxjCma0c0cnPCrz8aNY8lB08RmzbFo2/f2oiM9EDHjqGYO/ddPHqUAXm9GVdXN9E008iatm07AwCuXUuyeCyUKVi3gXVLVgIAJs4cD0+NqT7n8Nd6oHJoBaxa/wsys7IEK2gxYivoTZs2AQA++OADuLubrjUCAgLw5ptvIjs7GwkJCcL0AwcO4O7du+jWrRsaNmxo+nQaDWbMmAEA+OabbyR21Js2bYKTkxOmTp0q2Y9mzZqhS5cu+O+//7B3r8JN6CcICs4QhYotWWAWKXFZYKL3lurO2DCChoeywAiCKIkwxvA4J7dEP5iSzaYD/Pfffzh06BCio6MRGBgoiAjxhaSY1157DRMmTIDBYMDIkSPRvn17zJ07F2PHjlVsv2DFChw4cQKNaj2HUf1exsDOHXDvQSqGvz8Tb083ZQaJRcaDB+l4scVgJCXdRJ/B3dGsdTPs2rYL/dv1x7l/TqN38054/PARYmIH47mIBtj1vx8xcVB/bmEbM8D0+hyMHPkijh/fi44dB6JLlyG4ffsaJk4cirVrv+H3CubiIj9Y8E22B4Xh+Yf37AcAtI5uIUx3z+XERoum9eDp6Y69h49bXXVKSgpyc3MREhICjcLNurCwMADA7t27Hdt3giCeGSzaQSs4DojtoMWYOQ4AnB20JazaQStVC7ZkB81TdHVnCsRxQAw5DhBEiYb0R8HpD0dxc3NDv3798PPPP+O2KPHr9u3b+Omnn9CvXz+4uVm2Vo6Pj4ezszMGDBiAHj16wMvLCwkJCVaPXf8uXfDv+fM4edp0HzFHr8eqjb+gfaumCAorY7ZMptZ8XzQaDeYmzIV/2TKYOHw8bly+IsxbNf9LHPz1F3QeMATRPfuo74ys61u1ai5mzx6D556LQN++oxEQUA6rVy/E6NFdkJvLN7asl/bv3wkAqFKlhnhvFZ7V6s0ABeE2kJ2VhROHT6BK9SoIDqkgaaPRaNCudRM8epSJo39bv59769YtACZ9JIaftmvXLpva+/r6ws/PD5cvX5Ykwd26dQsBAQHw8vKyaRtPIo4OLyCIAqNs5avmF9583RnxRbun8Vlyve8um6CDSRzwZyO9bL6SwNCK2qq9tlyQWZWbMImnaxoHRglxQsMs0PUCbLKR26LRKAtDgiAIC2Tq8/DcJAsFxUoA/01rDw+X/F/K8BYCr7zyCgAuY2zEiBFISEjApEmT4ORkuurds2cPli1bhrp162L//v3w9OQ6pw8++AD16tVTXP+cCRMQUqGCZHh+rmcuOg4ZhwUJazF2TD9UqlhOssw/f5/DqDdfweS5kwFww/E/GPEBVn69En1bdsGIKR/glbEjkAMdGGMY3bEn9v+yDWdOHkeNmg1sygC7d+8mgoOrIj7+AHQ6LQA9hg59GwMHNscXX0xEmzadUbZsOdlSSqLD1j6GExfHju3DsWP7IRUdzqYm/MMJgBYoXzEUXfoOMVubfHh+8nnuIr5y1VCzts7OzgirFIT/ziYhNyUXWq1WdXi+n58fnJ2dcfnyZTDGzAI0SUlcttu5c+ds+MwEQRCFQCgUR+Yr4w7T6H8dOO3EnfOlKOkkS+TCpL349fJ6KRPSRDoRqcZn8T222zDpv2SY2V1fRbBigOocqqMazuJv1DYLbLEoZTcHgiBKLqQ/Ck5/8MyePVvxprabmxsmTJiguExcXBy+/vprrFy5Eu+88w4AYOXKlcjNzUVcXJzFa+BTp07hyJEjaN++PcqVKyd8tpUrV2LXrl1o27at6rJ9O3bEjK++wrqfNqN1PS7YvunwXty7n4q4Pl0tfk45ZQLL4LMVX2JYx354p38slv6+CxdO/YUvJkxEpfCqeOfzL6UJbVbcBg4d3I4VK46gatVaAPLAWB4++mgwtm//AWvWfI2BA0eZLb5y5UJkZWUiIyMNf/11BKdP/4WmTaPQqVMvKI+EkQZjMjJSsXr1N5C6DThxbeSayfjo9+o4lPL35ZqruA1cvpgMg8GAKlUrme9BBlC1Mnd/9nzSVbSsXl95JUYCArjhu0lJSahZU5ogoaSZxO3lpKWl4cGDB8IyVapUEZa5c+cOHj58aPZbflp0GQVnCIfowpgw9HvLeWPm0XZIspHE1MEpwaYrHBeFUSLBuGpu71UeiplTAvJ4DNxhCp5kwJQFJj67amF+ttVDGk5WCtLIs8BsCNDcg7m9gIw7l4It+0YbURIairSHkIXXpar9tgkEQRDPKnq9Ht999x28vb3RrVs3AICXlxe6d++OxMRE7NixA9HR0UL7lSu5IeCTJk0ShBEAVKhQAWPHjsVHH30EQFpvJqSCNCMJ4Iptvj6gB37bdxi79x3F4H6dJfO9vDwweYbUUiumXwxWfr0Svv5+GDjmDWG6RqNB9Mt9sP+XbTj3z19ccEbxw5pPGjnyE+h0LsLMwMCK6Nv3DXzzzXT8+ut6DBw4EkqFLu1DKj6OHfsDS5bMsnnpBk2j0GXgEKvtHqdxF/PePqXgoZC57e3hCYPBgIzHj+HnLb1pKB6e7+HhgVatWmH37t346quvJPVlNmzYgJMnTwIAUpUKlhIEQdiJZq/56PhqOCuMpq+CC8JIewBcopfdtl/eUB7Rwgdk7AnKiBPY5GTCpK/4gJC7dNOecJiLCEcVXMAFVLHN9s1OKKmNIIiioLD0hxJz5sxRnO7j46ManGnYsCHq1KmDhIQEITiTkJCAunXrokGDBhZvhPNBp0GDBgnTBg0ahJUrVyI+Pt5icCYwIADtIiOx+udfMWf8OLi5umLZ2s0o4++HLi+2wrGLxoRlC3ey+foyWXBF1EttMXDsCHw3fxG+mDAR+7ZuA2MMM1Z+Dw8vL7vqzXTsOAhVq9YBX2dGo2EYOXIqduz4EVu3fi/STKZEtpUrv0Ja2n1hHR069MSECTOh1Vpz++GiLhkZ6XZpJgDo3GeIKTgDKLoNPEzj+mdvH5NtA+82AADepbjfWFrGQ6vb69ChA9asWYNZs2bhhRdeEEZVpaSkYP78+QCkmikyMhLe3t7YuHEjTpw4gfr1TcGfSZMmCa/Fy3To0AEJCQmYOnUqPv/8c2H64cOHhdpIT7ouo+AMUfDsgs3evmaE4gnIAgMUgzSpxmcHssDkQoPPAlOCssAIgigK3HXO+G+aSsS9hOCuc7beyAqbNm3C3bt3ERcXJxmiP2jQICQmJiI+Pl4ijv766y8AQMuWLc3WpTQNADIePcLS1YnYuHsvLl69hkeZ0htkN27Ji5YBVaqGAB6lJdPKli8LAKhR53loNBqhqGVutgv8y3FDNO9dvWFaQC0DzGhN7eysRe3azSAtZslQr15zAMDZszYkBygiH5Yvfq3Ba699gNde+xDcaBktJGPydcaXcu9kfp4sA8zF3voGNjJv3jy0aNECo0aNwpYtW1CnTh1cuHABmzZtQp06dfD3339LMhoJgiDyS737p63acik6DgDKiWGp8gl8hhv/rJTMpnZO5ZPalJLZdMiX44A4sa0oHAcoqY0gSiykPwpOf/DcvHlTGMFiD7GxsRg3bhwOHjwIADh9+jQWLFhgcZns7GwkJiaiVKlSkro0bdq0QXBwMP73v//hwYMH8PMz1bB8LBtBMbBrV/z8++/43849aNWwPn7ddxhjh/aFrrQWjsTk35o1FUf27MPK2ZyV9OiPZ6FmLVNSlq31ZurXFx9nTjuVL18JgYEVcenSaej1OUYnAoAP0OzY8S8A4N692zh69A98+eXHGDKkM778chUCA82T9+QEBYXgyJEMmMSRE4QECPEkJd0kQu42YC9qbgMA0L9/fyxfvhy7d+9G7dq18dJLL0Gv12Pjxo0IDORuhoo1k5eXF+bOnYthw4ahWbNm6NWrF8qVK4cDBw7g2LFjqFGjBs6cOSNZZtq0afjll18we/ZsHDx4EE2bNsXNmzexfv16PPfcc0+FLnuy954o0SgFEMQBhyq4IJ3pwMW36nB5mwtbiim+ujMOQwUuCYIoJDQaDTxctCX6oVQPxF6UsrsAoG3btqhQoQI2bdqE+/dNGU9paWlwcnIShmSL4S9AxeTo9ejy+uuY9u1SODs54ZXOHfHhqFhMHvcqBvfiilpmZ+sl9WYAwNvblBXHZ4DlaT0AAO7ePmbb0Wq5Ps/keWzEQgaYr2+A7EKWExr+/lyWwcOHabCv3oza96E0XeES1NarUoUMMAAoZcz+Sk/LkG7d+Db94SNoNBqU8vCwuom6deviyJEj6N27N44fP44FCxbg7Nmz+PbbbwX7ibJly9q4wwRBEBz2BgKsjg4pD1MimBxf2FB3BuB0kziLV6nujJxCqjujgK11Z3iXBoIgnlxIfxSM/igIBg4cCBcXFyxbtgzLli2Di4sLBgwYYHGZjRs3IiUlBb169ZIUh3dycsKAAQOQlZWF77//3uI6er8QiUD/0li2cTOWb98Kg8GA2D5dFNsq1ZuR4+LqipYd2gEAXN3cEDN0mNVllPIUSpcOBKeV8kRTDShdugwYY3j0KANqeikgoCxeeqkHPvtsKZKSzmH+/GkqG7ampWSBQaXAjA23Qf19uOOWnpah6DaQnsHd4PTRmtvhid0GAE6D/vzzz5gyZQqcnJywePFibNiwATExMVi/fj0Ac80UFxeHbdu2oVmzZti0aRO++uor6HQ67Ny5E+Hh4WbLVKxYEUeOHEFcXBySkpLwxRdf4NChQ5g2bRo++OADxW08adDIGaLQKd4sMKW6M0oUcN0ZubVZYWWBiaEsMIIgCLu4evUqfv31VwBAVFSUarvExESMGTMGAGcBYDAYcO/ePZQpIy1KyRfNzLlnutu0be9e/HXmDOK6x2Dp1IncRGNsZc2vv2LF+p8k67BFZCiSK7qYtzEDLDX1HgwGA7j4jGn0TErKHQCAl5c8ASK/9Wa4x7Fjf+DYsT9E01TqzRgf5UNC0aXfEKtbCKvKFYS8dD4ZQJRkeH5eXh6Srt5AWHCQEMiyRo0aNbB27Vqz6UOGcPsSIRMnBEEQSojtoAV4O2gFxwGxHbQYMzvoUNjhOKCGfLQMOQ6QtRlBEIVJYemPgsbf3x8xMTHCtXC3bt3g7+9vcRk+6JSQkICEhATVNmLLYDlarRaDunTCnJWr8O+lS2hc73nUqh6u2h4Aco337HIUiqz8ffgIEj5fAF9/f6SmpGDWmBGYuWytbfVmDBCkzv378uPMjNPvQKPRwNPTPJAh5/nn68Hb2xfHjx80TrHsNpCRkYbVq7+GVBwZ2znBTDPBCeg3fBxKBfiabVvsNlCpciU4OTnh0vnLivt5/ixXgqFqJduSI1xdXTF58mRMnjxZMn3Pnj0AlDVThw4d0KFDB7Ppr7zyCpycnNCggdSmu0KFCli6dKlZ+ylTpqhu40mCgjNEsSCuO6NKILiLdDm+xmfFhCy+7gygXndGLDoKoe7MXUiFhgp83Rm1Apc8VOCSIAiicFi+fDkMBgNatGiB6tWrm83Pzc3FihUrEB8fL4ijunXr4vjx49i3bx969Oghab9v3z6zdSRduwYAiGnTipsgGvSy79CJfH+G3GwX2xoq3GvLy8vFqVP7UbduU9FUA06ePAAAqF69Ngq63gwAHDu2D0uWzLR5DQ2aRpkFZ5SG5zeLagYA2PPrH3h/wmDJvD8OncSjx5mIasJd6Fsanm+JjIwMbNmyBf7+/mjXrp1jKyEIgihyrNlBW7I0U8LWujP8eyt1Z2yo2WkPVq3NCIIgiomi0B8FRWxsLNatWye8tsTly5exc+dOBAYGonPnzoptdu3ahRMnTgi1RsQ1OiXb7d4Vny//Djfv3MPkca/avd9ZxiDN/Qw93u0fB2etFkv2/IqvJ32CHet/QLM27dG1t+zzWHAbAIATJ/ahU6eBxndcUtvNm1dw+/Z1VK5cAzqdDpJojgKPHz/Cw4fpCAhQG+0k1U0ZGWl2aSYA6NxvCBecMY9TCbi7u6NB4zo4eugkrly+gUohQdzWMwDGGH7bdxieHu6IeN5ykr01Vq1aBQDo27evTe3379+P5ORkdOzYET4+5k4RcvLy8rBmzRpotVr07NkzX/ta3FBwhnAYcRbYlvPcaA0hC0wBm7PAAG6UyU179sYdUkFRjFlgvrJpVrLAeOzJAnMEygIjCIIwwRhDQkICNBoNVqxYgcqVKyu2O3fuHA4ePIijR48iIiICr7zyChISEjBt2jS0b99eKMp5/fp1RR/m4PLcsMk/TvyFLq1bCdP3HjqGJSs3cvuicPGcZ/T44i3N5ORYtJuB5Qww0Wj8RYsmYtGibUaPZIbbt69jzZqv4eLiiujoHgorsYblDDBTzZmPYPJNtlJvhi9pIEOcAeaGbFSpXgXNWjXCH7sPYfvP+9CtXSMAQE6OHh9N/xYAMKxLjGQdKamp0FesiNx79yRWEZmZmdDpdJJRNtnZ2YiLi8P9+/exYMECiUc4QRBEftDs5ZKvxFTDWZwDd+OuCi7gIkSZwxUZNxpfiTJQcBoAODtosY7hX4t1ki1BmgKuO1MIjgNWIccBgiCKgaLSHwVFdHQ0Nm7cCABWk5ISEhJgMBgwfPhwTJ06VbHN4sWLMXz4cMTHx2PhwoVm832NmqhGWCh+XrEAWdk5eLFFY6CUWVOb3Aamj3gLVy8lYcLC+Qiv9Tw+/Gox/j3yJ2a/OwZ1G7ZASKVq0gUsBGi2bVuJPn1GompVLmDBGMOiRVOQl5eHzp37C+2uXUtCqVLe8PHxla46Nwdz506GwWBA8+aW6hOYNFNQUCiOHHkEu+rNiO/yq1hBA8Arr/XF0UMnMen9+VizYipgvLf77fINuHTlOl7r3x3uIq2jz81F0rVr8PfzQ5Uq0kT79PR0eHtLHRfWr1+PZcuWoVGjRmYBRaX2N27cwLBhw6DVajF9+nTJPL1ej9zcXIlVnsFgwPjx43H27Fm8+eabCAoKwpMMBWeIwkFhiL7NhMKOIfqFmQUmH1VTgFlgVOCSIAiiWNm1axeSkpIQFRWlKowAYOjQoTh48CDi4+MRERGBNm3aYOjQoUhISEDt2rXRvXt3ZGdnY+3atWjatCm2bt0qWf6lli0RGhSEzxJW4p8LF1Hr+So4e+kytu78A907tcb6zTtt2t8sS+lPaljJAAsIKI+srEfo168BWrbsiMzMR9ix40ekpd3H+PGfomzZ8hBngM2fPwWpqZz/9cWLZwAACxZMg7u7JwANunXrh3r1msq2YqnejJP5JEvYcAg++2oqOkf2Qd9uY9Hn5XYo7x+An379A/+euYRRg3ujeb26kvZLfvgBny5ZgsmTJwvD4gHg2LFj6NGjB9q1a4fg4GCkp6fjp59+wpUrV/Dqq69i9OjRNuwwQRCEfThsBy1Owr0HabKYxG3AXTZBnIyWC8u6qeQ6DvBJbeQ4QBBESaao9IeY2bNnw8tL2XLrpZdeQtOm8mt3E05OToiJiVGdz2MwGISgE2//q0SfPn0wbtw4rFq1CrNnz7a4zpdaN7e6XQB4rNLPbExchy2Ja9C6Swf0Gfk6crNd4O3ngmlLEvFGpzb4aHh/LNt8EFqNLANMD2lZGSNNm0YjNjYS0dEvw9c3AEeO7Mbp08dRu3Yj9OnzKni9dOLEIcya9S7q1m2MChUqwcfHDykpd/Dnn/tw585NhIVVxYgRExT22IFaRvLAjAJKbgMA0HdwD2xc+xPWrf4ZVy9dQ6tW9XHxzDVs2LobYcFBmDH+DQAmt4Gbd+6gycsvIyQkBMnJyZJ1NWnSBMHBwahZsybc3Nzw559/Ys+ePahcuTLWrVsHZ2dprZwvvvgCiYmJaNGiBcqWLYurV69i06ZNePz4MeLj480szW7fvo3nn38e0dHRCAsLQ05ODrZv344zZ86gU6dOmDnTvtFFJREKzhCFSoFmgaliSxaYJXihUYR1Z4zwQsMaSkKDIAiCcBzeE9mSgAA4ETF27FisXr0ac+fOhbu7O5YsWYJq1aphyZIlWLhwISpWrIi33noLvXv3NhNHFf08sGvpV3hn7hf4/cQJ7Dl6DM9Xq4xV305DYFl/SXDG4XozOaJ+Uy4mLARodDoXLFy4HQsXvodt277Hw4dpCAmphnfe+Rzt2/eAuA4NwLBr11bcvHlNso5du7YJrxs2bKYQnAGkI2d4bInGKGAhA8wVWajxfDXsObwK0yZ+iZ9+3o9HjzJRrUolLJr+Lt54pZcpn8MKlSpVQuvWrbFv3z7cvn0bHh4eaNCgAebOnfvED5snCKJ4ERwHbMSqHbQ1xwF5PEZwHODtoJUcBwBzO2g1+HY8Re84YDM2WpuR4wBBEIVBUekPMXPmzFGd5+vrazE4Yys7duzAlStXEBUVhbCwMNV2Pj4+6NGjB1atWoXvZs/Gyy+9ZNd2mIf1NllwxdWky5gycgLKlC+HScuWSOY3iGyFIWPfx7J5H2PRzA8w9oPPlVfEuw0Yu4IBA95Eq1adsXr1F7h27SK8vf3Qt+8beP31D6DTuYDXTXXrNkb37gNx4sRhnDv3Dx4+TIeHhxdCQ8PRt28cXn55CNzcPGCL24BpGiDU6ORfKkkpG90GAC7w9uOm+ZgzKx5rvtuM+QtWo7SfN+IGdsWMsW+gjNZP+bgo0KdPH2zYsAGHDh2CXq9HWFgYJk6ciHfeecdshAwANG/eHHv37sWWLVvw4MED+Pv7o2PHjnjvvfdQv359s/Y+Pj6IiYnB/v37sXXrVuh0OtSqVQtLlixBbGwsnJwc1JUlCA1jdNVhifT0dPj4+CAtLU3xR/WsIy5uKYgM3tbMOHKGD86cLF1TsDU7h+qCyLiIcMHW7M6lYC44kwxOZPA1Z+4ZH3fBXcg/gkiEZIK723IfnHi4D04kZBrfi5/5u1RykcGfwfizmRaciNCJnkuJXpcGFxQyCg13cOLCEyaRUQZccIYP0ATCFJwJhTByhg/O8FlgVXABAAShwVub8cEZ8cgZIQtMLDK2m14qjZwhoUEQBE9WVhaSkpIQFhZGFk0FhNw72Vc82kNsnSsans+Mr/ngjDgDjLc140fO8O9zoDPVm+GDM9nguje+q8uFKRGaD9rwQsMgfyNWIAbRg4me5fVnxP2JXEDwosJJ5WHn8HwduJEzLkzIAHOBXgjOuCFbeO2BTLjncq81fBk6/jmNexLXnPF9wgtI2out/3u6BibsgX4v1rGmm8SaCeCSs/iEtguoIiS0XUWwSTMBUt2kpJlSIQrO3Da+4fVTpugh1kvizkQpcMMnrsn1kjtMOorXSuLXMNdNvGaC8VmumQCgIrNbMwEm3SQZOWOjbiLNRBCFA+kPAjDXTICKbrKgmQCTbrKkmQBjnU6xZgJMXZ1YO/G6SaKZIJqQC04Dqekm8QMKz5Y0k1w7ORunaSEIJVs1EyDoJj4444osITjjiix4GC8OJLqJNJNAUWqmJz+8RDyxWM1yKg/pEH05ZglY7rJn+cAwnWiaFa9+ANIzNCBNOeMFjQxeABUiilYHYgs5Uc0fe7LyCIIgiPyhVtTSEkzBQ5nHmsiQNbaOmciQ7InsuSBQyggTT5cOcbfqm6yCm00f3kSqfc0JgiCKBLuttxyqxaJ208AWbSRGPCxTD1MARy+bbiP3ZO8VRgKZWbmpoFTjlCAIgijZqCa0FQBCQpsYeRdlxQ6aQy6ilPphpjDd1v666NwGzLacYTaJKEIoOEPkC3FGkZBttF25LQBVay4z7+BQe/ZCHKWRZ3PJX9uCJVGRqfw+E+pBGbHYEAsNu+3bCIIgiCcJW0bN8BSIpZkYmwQGj5LQkGd9WUItA0wtKAOYXYLK4jQSrAzPF+Nh1k/DLANMzLOWAUYQxJMFPxoEMI0UsUoAuNEovjCvhwnAPKFNfJIVJ7MpYelull5hmsI5Wc5d2fvbotfJ5s15xwV+JBHvxsCPMlJCbrNtC+KRTgRBEETBUVAJbfJRM3KsJrTJ9ZJKvRlThpv4vbBnUE9yy0/Sm7wPKpzb9/yoGaJ4oeAMUXjY4OerSLFkgSmJCbUgTT6ywBSwlgXGCw2lLDASGgRBEE8XakUt5RROBpha/2trkMYWFC49bbkadbXeRA5lgBEE8SQitjBWQ7FmZSCk9S55fGGD4wBgrpsccRzgIccBgiAIIp9YcBhQQu42YBNKeqnQ3Qbk9+QKx22At4J2FHIbKDooOEMUOkpD9EteFpgcW7LArGSCFVAWmM28oDyZhAZBEEThU1iWZnIKPwNMzScZsveODM/nBUbRDM8nCIIoqSjVhbSE2A7azHEAMNVocRglO2hbsJQVQI4DBEEQhHXssTQrELcBhxLaAKle4p8dcRsQv7fmNiAjH24DYitoRbcBHnIbKHIoOEMUGU92FphS3RnAVEjTSCqUs8DURtA4KDQUs8DEtLc8myAIgihcfNVGejy1GWBy1Po3Kx7KVnIpbMkAo+H5BEGURBQLzPN20AqOA2p20GaE2rMXBWkHXch1ZxQgxwGCIIgnm4K2NLOGotuAEha7KyURVdD1ZuQ4yZ5RaG4DAuQ6UGxQcIYoVsRZYIqUiCwwa3VnLE82o4ALXJLQIAiCKOEURwaYzRRWvRnxa3k2GCBJ+3KGaXg+j8rwfDlusG+8PQ3PJwiiJKPkOCDGzHEg33bQYscBsU56MuvO2AxZmxEEQZRsbExo4xPY5AltdrsNACqJbPIJhVVvpvjcBsgKuvih4AyRb8RZYMIQ/e3KbQH1LDCzIfqh9uyFpSwwWy3Niq/uDI8jBS4lqFibEQRBEIVHicwAy4UNXZWj9WYcGbYPOJQBpoNZBpiLygdTHJ7Piw0ank8QxBOK2A7aJgJgbgftK2+k1t/IgzTi6daw5DjA152RnadLguMAQRAEUWwUuaWZGDU7M3IbIIoYCs4QhYvCEH2bKNAsMDFFWHcmFXZngdmDPdZmlAVGEARRAihRGWD5qTdjK0qjZorm0pMywAiCeJLJlx20GlbtoJUcB+xJcFPCguNAqg2rJscBgiAIQoatCW2KFHu9mXy4DQDkNvCUQsEZokhQGqIvzgIzG6KvRInMAgMkWWC80HiksJoCygJTEhqOQEKDIAiicKAMMKV9k0+z8xLUzuH5BEEQJR3BccACYscBsR20meMAYIcddEHUnVFzHCj8ujM8heU4QEltBEEQhUNBuQ3wqNXk5BPaVN0G5HqpWOvN2OE24KzQlEfh1ia5DTw5UHCGKFKe3CwwG+vOqCEXGg5kgVkSGo5kgREEQRD5p0RYmhVbBhiPtQwweTuVDDC+e1bpoml4PkEQTzpiO2gB3g7aUccBwE47aEDqOCDG1vqccp4exwGCIAiiaFBNaFPQSpYS2uRuA4pYGwnCx2AksZjiqDfD48DtelfrTYiSCQVniGLnycgCs/baAoVY4FJRaFCBS4IgiJKHg5ZmPIqWZmLynQGmhLUMMFs8kvn3ChlggHoGmBZW7xGKh+crZoDxGDPAaHg+QRBPAnY7DqjZQcsdB8yQJ7SJT7q2JrPJKfmOA2RtRhAE8QwhdhsQ6yObk9l4nk63AbKCLhlQcIYoEMRZYMIQ/e3KbQHpEH2LhNq7JyUwC6wEQ0KDIAiiYCloS7OiywArrHozFrDlKlSWAaY2PF8RC2KDhucTBPE0oOo4EKDQ2Bc2Og6o1ey0pKeeHMcBCWRtRhAEUSSUCLcBs0a27kkJcRvgJ5HbwFMHBWeIwsfGIfo2Z4EBpiwwRZ6QLDCx0CimLDCCIAjCcRwRGQVCic4AkyMWGyqjZ6zhor5/lAFGEMTTii120GbY7DhgK7YEZJ5gxwExZG1GEARRZNhjaaaE3G2Af6/oNiBOaFNyG8hT20pB1ZspILcBJQllp9uALZDbQNFDwRmiyLA2RF+MxSwwcSaYJ2zMAgMKLguMf85nFpgCfBZYgRa4JGszgiCIIuHKjRvQ1GmEIROnqLaxlAFmcZQMCjYDrGvXaujatSYKLgNM7b216fajFJAhCIJ4EhEcBywgdhywaAcdas+WLdlB5zeZDSDHAYIgiKIhOTkZGo0GQ4YMKe5dKTSsuQ28O2QUnteUwvXky7atUEkvqbgNLF48HY0aeeDYMfENzSfXbUDRCppPbEszn0VuA0UDBWeIIqd4ssC0suf8ZIHJ29oQpCnOApdWIKFBEATBERsbC41GA39/f2Rn258y5C2OnTiQAcZTEjPAsrIeIzHxa0ycOBK9erVC48YV0ahREG7c4G8Oaiw8xMiG5zvBfHi+wj1Be4bnn7twGb2Hvo+Aei/CvVoL1O3VH1+vXQ+mVIjbAteuXcPw4cNRqVIluLi4ICgoCEOHDsXVqwoJJAAMBgMWLlyIBg0awMPDA97e3mjVqhU2b95s13YJgni66aJ0LuLtoG10HDDDkuOAKmI7aCXHAaAkOw5YszbjIccBgiBKMvnVH2q0bt0aGo3G4mPPnj1C+ylTpgjTx48fr7re9957T2g3ZcoU1Xa///670G75rFmq7ZZv2gJNSCNoQhqh1+vvCdPlCW1Lvl4LT01tlNGEY9SQd61+fgDqbgMOY1tf+803n2HEiD7o1CkCLVpUwYsvPo9Bg17C999/i6ysxypLFa7bgBI593Iw7bMlqBrVA27VIhHUtgNem/ox7t6/b9d6HNFAhw8fRkxMDAICAuDq6oqqVati0qRJyMxUvr/64MEDjB8/HuHh4XB1dUWZMmXQq1cv/Pvvv6rb+P777xEZGQkvLy94enqiUaNGWL58uV2fraig4AxRbBRPFpgjWMsCA1QDNI9gngVWwgpcEgRBEEBGRgZ++OEHaDQa3L9/Hxs3brTYPr+WZrbUm1EkR6WvsCMDTBnLGWD376dgwYLp2L59I3JysuHt7WvrHkNRYPBBGTUUumxxBph4eL44A+y//y6h8YtDsOnnvejQujnG9O+DvLw8jPj4U7w3e7bNe3zx4kU0bNgQixcvRs2aNTF27Fg0btwYK1asQEREBC5evChpzxhD7969MXr0aKSnpyMuLg59+/bF2bNnERMTg4ULF9q8bYIgCGuOA2Z20ErwjgO8HTTvOCDBXeW1DtITcTHXnbEAOQ4QBPGkYq/+cIS3334bkydPVnyEhoaatddqtUhMTERurrm4yM3NxcqVK6HVWg/cx8fHAwA0Gg0SZTfpJZZm7vx2nbFl5z7cS0lVXN+K+A3Cdg1GEcEntI2fORFbTh+Fb4UQbj8LzG3AAHN9ZN1tYN265Xj8+BGaNo1C377DEB0dg+zsbMybNwWxsV2NAZqidRuQOw8YDAbEDHgbk2ctRoCfL8bF9kWzOrWxdMMmRMfG4t6DBzZtwxENtGHDBrRo0QLbt29H+/btMWrUKPj7+2P69Olo166dWZAyJSUFTZo0wZw5c1C2bFmMGjUK7dq1w5YtW9C4cWMcPnzYbBtvv/02BgwYgEuXLmHAgAEYOnQoUlJSMHToUIvBx+LCkVQYglCkC2PCKIwt540XtNtRDP693jAJAHfjax1MZ19eVNgbOs8VrUdnXK9YxKSbtslv3tfKKm/CbFTQnUvByrZuRs6huqodnIQXYMq+aw8hI69LVdssFAiCIJ4l1q5di0ePHuGtt97C/PnzER8fjz59+hToNgq8qGURZoD5+pbGwoWrUaNGbfj4+GH06AE4dGiPSmulUTMO5AO5Wm8iZ8SoWUhLf4hty+ejQ5tIIA2YPup1tB42Ekt++AG92rdHdGys1fWMHTsWd+7cwYIFCzBmzBhh+rp169C7d2+MHDkSv/zyizD9xx9/xI8//ojIyEj89ttvcHfnvtdPPvkEERERGD9+PDp37qwoggmCINSod/+01VHyZStfNR9FEgjpSH0xvDySTBB3KFrZex0suwjw+kgrmqb22oY+7y64gBLPbXCfB+AcB0Ktr0KNk6VrOubiYGSLRqM88okgCMIBikJ/jB8/HuXKlbO5fYcOHbBlyxZs3boV3bp1k8zbtm0bbt26ha5du1ocFZGeno7169ejTp06CAwMxM6dO3Ht1i1UtLAfHVo3x5Yd+5D4wzaMe6O/MD1T64ZTf5/FiWP/4aWubfHL5p1my/qUD4VP+YJ0GzCozLCt3sy2bcfg6son4mmE50mTRuHnn3/E5s1r0bv3MBSn20DCyq3YvusQ+vVsj1VzpkOj0QBpwDc//Ig3ZszCjK+/xvwPPrC6Lns1UGZmJl5//XVoNBrs378fDRs2BMAFeUaPHo1FixZh3rx5mDBhgrCNyZMn4/z583jrrbcwZ84cYfrBgwfRsmVLxMbG4tSpU3By4vTm0aNHMXfuXISHh+Pw4cMoXbo0AODRo0do06YN5syZg549e6JZs2ZWP19RQSNniKLBGCSwOwtMaYh+vrPAxJSQujPJ5pMKtMClFcjajCCIZ534+HhotVq8++67aNOmDXbu3InLl5V9i/Py8jB/xQo06N4d5SIj0aB7d3z9XQIMzDgyRRxHKQXs3ncUsaOmocbzPVHKtxVK+bZCi4g+WLZ4ndBMXNQyXFMG/Vp3x+3rN/B2/zi0CaiAqABvjOvWGdeSLgEAks6dxvih3dD2udKIqlkK743ohZRbanfixBggHkGTkZGKTz4Zi/btayAysgIGDGiD7ds3mC3l4eGJJk1awcfHz8K6lfoS+bT8Dc+3lAF27txl/L7vBNq0jOACM/wqdDp8OHw4AGClDRmJWVlZ2L59OwIDAzF69GjJvJdffhn16tXD9u3bcenSJWH6pk2bAAAffPCBIEoAICAgAG+++Says7ORkJBgddsEQRDWEDsOKGKzHbQlxwH5a1vJR90ZchwgCOIZw1798emnnyI8PBxubm4IDw/HzJkzYTBYGhlvPz169ICvry+WLVtmNm/ZsmXw8/ND9+7dLa5j9erVePz4MQYNGoSeLVrAYDBg9datFpdp3rAOalQNRcL3W80S2lYu2whnZ2f0GdxDcdkPhgzH85pSuJGcDIC70T+2Syc08tbg1/+tlbRleoYxQzugUTUNfv3ZOM/MbUCZTZtWom/f5oiMDELHjrUxd+5EPHr0ULx2ABAFZqS0bdsFAHDtWpLKForObWBp/EYAwMxJI7nAjJE+XXsgtEIFrP/lF2RmWa/zaa8GOnDgAO7evYtu3boJgRmAG2E1Y8YMAMA333wjsaPetGkTnJycMHXqVMm2mzVrhi5duuC///7D3r17Je0B4M033xQCMwDg6emJDz/8UNhGSYKCM0SJRXH0SKD5JAGzRCz5BLnAKEF1ZxyEhAZBEIUGY0DOo5L9KKDs1f/++w+HDh1CdHQ0AgMDMWjQIBgMBtWb6UN69MDUhQthMBgwrFcvvNC0KeZ+9z3Gzpqj2P7TBSvw+8ETaNTwOYwc8TL6DuyMlHupGD18Gia9/YniMmkPUvFKi2hcT7qMzoMHokGr1tj/yzaMimmHCyf/QVzH5nj86CG69I5FzToR2PXLj5j4Vj8b6s2Y0OtzMHJkDI4fP4COHXujS5f+uH37BiZOfANr1y61cMTkN8XE2V5KGWD8s4KykGeAKWBrBtievccAANFtmpjNb1qvHjzd3bH/+HGr60pJSUFubi5CQkIkYoUnLCwMALB7925h2q1btyTzlNrv2uVoMQmCIJ5mbBnRLraDFpM/O2hAue4MoGxpZg8O1J2xxk3zSdbqzpC1GUE8YZD+UNUfr732GiZMmACDwYCRI0eiffv2mDt3LsaOHVsg+8Pj5uaGfv364eeff8bt26abVrdv38ZPP/2Efv36wc3Nsj1zfHw8nJ2dMWDAAHR54QV4eXhg1ZYtYIxJLc3EuAJDB3TB3/+ex7FjphGOOTl6/LDqJ7Rp3xLlgspa3f/cbBdoNBpM+joBpcuUxczxw3EzyRTwWr1sPg7u/QWdewxBdAdro5RM0ZpVqxZi9uwJeO65+ujbdzgCAgKxevVijB7dB7m5ttkZ7N+/AwBQpUoN45TicRvIysrG4T//RfWqIQjxNWZ0pBn3SKNB6yZN8CgzExds+K3bq4Estff19YWfnx8uX74sSYK7desWAgIC4OXlle9tlFRdRrZmRLGgNEQ/HBctjg5BeShelJvjDpPFGG8/Jh6eLx+qbwtKQ/XV6s7IgkJ8gUtf0bR74Eb/AFJrs2samwp6krUZQRCFjv4x8ElQce+FZT64Abh45ns1vCfyK6+8AoDLGBsxYgQSEhIwadIkYYg0AOzZsweJmzejVtWq+CU+Hp7GDKGpw4eiXu8Biuv/es4EhNaqILzP1LohNzcXMR1HYfGCFXht7BCUqVRZssy5v//BoDdH4s25XMAnN9sFs0aOxI/xX+O1Li3x6jtT0C92LJDLZYC9GdcZ+/dsw5n/jqNG9QY21Zu5d+8WgoMrIz7+F+h0WgAMQ4eOwcCBbfHFF9PRpk1HlC3L2w/YIkSVMpnNM8COHduDYyf2cM2dFB4AF7RxBuDMUDG8IroOGSQs7ybxJzBx/gJ3k7Jq5Upm85ydnVEpKAhnk5KQm5tr0Svbz88Pzs7OuHz5MhhjZgGapCQu2+3cuXPCtICAAGFezZo1rbYnCOLZRmwHLcDbQe+CNFhgKxWZ+ogSsU1Yqnym3OdMbOOcC/t0k9zGTBzkyYRdNUDFeklMMsyCUFcRjGBcxUWEowou4AKqqI4wImszgijBkP5Q1R/Lli1D3bp1sX//fnh6ctv/4IMPUK9ePYvbmT17tuJNbTc3N4l1lJi4uDh8/fXXWLlyJd555x0AwMqVK5Gbm4u4uDiL17SnTp3CkSNH0L59e5QrVw6p166hc5s2WPPTT/j9yBHEtGxsauwjXXZQn474cMZXWLZ8Mxo2rIlMrRu2rtuOe/ceYEDcy5K2fL0Z3n1Ajn/ZQEz+cgXG9euIiSP6Y/G633HhzCks/HwCKoVWxTsTvlT9DBxS7XTo0E6sWLETVas+B8AAxj7ERx+9ge3bN2DNmngMHPi62RpWrvwKWVmZyMhIx19/HcHp03+hadPW6NRJKShk7jaQkZGK1Wvmm+I3SrpJpJkA4JXxI+Dv62nRbeDixWswGAyoWlk5waFKMDf9/PnzaNmypWIbHns1kLi9nLS0NDww1ro5d+4cqlSpIixz584dPHz40Oy3bO82+GnXrl3D48eP4eHhYfHzFRUUnCGKlTo4pTj6IxhXBVsvANwFeLI9a1aqOwNIPZMdDdJYqjsD0zS1ujNqQkMEX3eGhAZBEEThotfr8d1338Hb21vwVfby8kL37t2RmJiIHTt2IDo6Wmi/cuVKAMC7w4YJgRlfV8A3sCzGDuiLjxaKhkgbh+SHhVQwC21otVoMfr0/9vy2H3t2H8PLg6XBGQ8vL4yZMQmAqd5M+5f74cf4r+Hj54++Q011UDQaDaI79sX+Pdtw7uxfXHBGFXFhS2DkyEnQ6VzAC5DAwCD07fsqvvnmU/z66yYMHDjcwrqUUPNNNnHs+B4sWTxVdb6chlEt0WtIP8V54uH5aWmcrYCPt1EwGzPAUo2xnFKenjAYDMjIyICfn7o9m4eHB1q1aoXdu3fjq6++wsiRI4V5GzZswMmTJ7n1pqYK0zt06IA1a9Zg1qxZeOGFF4SMwpSUFMyfP9+sPUEQhDU0e81HxFfDWWE0SBVcEOyPVeE1h9gezNf4rFh3xh1ABpTrzgAmDSS3fban7gy/PQtYqjvjIH+jttkIJBalbLtNEARRmDiqPyZNmiQEZgCgQoUKGDt2LD766CPVbYlrdIjx8fFRDc40bNgQderUQUJCghCcSUhIQN26ddGgQQOLwRk+6DRo0CCkHj0KAOjbsSPW/PQTvtu8WRqckVEuMAAd20VizdpfMefzcYCXG1Yu+x8CypRG+y4v4K9j/6guq0Tzti+h72tjsfrb+Vg4awL+2LkVjDHMmLsaHp5eKlZm8noznG7q2LEvqlatJSyg0WgwcuT72LFjE7ZuXSsKzph01sqVXyEt7YHwvkOHXpgw4TNotTrY4jaQkZFql2YCgG7D+sPf13Lw0KSZzIN2AKeZuHZpVrdnrwaKjIyEt7c3Nm7ciBMnTqB+/frCvEmTJgmv5TorISEBU6dOxeeffy5MP3z4MLYa7fLk7WfNmoX58+ejf//+8PX1BQA8fvwYM2fOFB2HNArOEE8n4iywLeeNQ8GLIgvskdLMws4C46epjJhRohALXBIEQRQoOg8uM6wko8v/xdSmTZtw9+5dxMXFSYboDxo0CImJiYiPj5eIo+MHDwIAmokuJHlaNqinuI2MjEf4fG4iNm3ai4uXruHRI6m9y+0bnGVAlmhMekjVKnD2kKaTBQRywyzDn6/DjeYQDeD0L8PNu3fH0ncmVR/OzlrUrt0IpkKWnJioV4+zBTt7VtlGRxlL9Wako2deGz4Fr42Yol7YUgfT8HwXZrQ1s6PPtrMcnBLz5s1DixYtMGrUKGzZsgV16tTBhQsXsGnTJtSpUwd///23JKOxf//+WL58OXbv3o3atWvjpZdegl6vx8aNGxEYyHX04vYEQRC2ouQ4IKds5atSiy9+VL5dNspPpuMAn9SmBjkOEMQTBOkPRf3x119/AYDiKAZrIxtu3ryJcuXKWWyjRGxsLMaNG4eDRu1z+vRpLFiwwOIy2dnZSExMRKlSpdC9e3dk//svt48REagQGIif9uzBg/R0+Hl7q65j6LCu2Pzz7/jfxj1o3KYpdv56EK+NHQydzo6RlzkmXTJqwiwc+2MPEr+dDQAY/d6nqFmrobRrU603Ywq01K/fDGK9BDCULx+MwMAgXLp0Fnp9tjHhzcSOHdznv3fvLo4e/QNffvkxhgzpgC+//AGBgRUgxdxtICgoFEdOMnXNBJh0k6CZAEv9tnuu9Toy9mCvBvLy8sLcuXMxbNgwNGvWDL169UK5cuVw4MABHDt2DDVq1MCZM2cky0ybNg2//PILZs+ejYMHD6Jp06a4efMm1q9fj+eee85Ml7Vq1QqvvPIKvvvuOzz33HPo2rUrdDodtm3bhtzcXPj4+CAtLa1EabOSsyfEM4NSdpL4grkKLlhfSQCko088wV3Eq9ad4Z+V6s6Iz2pibKk7Iy9waYFUy7MFCrrApUpAjDyUCYKwiEbDDdkvyQ+FeiD2Is7uEtO2bVtUqFABmzZtwv3794Xp6Q8fwsnJCf7GDBwxgf7+ZtNycvRo3e11TJ+xFM7OTuj3She8++Gr+GDyG0JRy+xsc5suL+9SZtOcjVZcnl7mgoa36crN0duQAcY18PX1V7goZfD35zrYhw8zzLajjEb2Wv4esFzJUgUX6XgjS8PzAcDXncv+SktXzNhAxqNH0Gg0KFXK/NjKqVu3Lo4cOYLevXvj+PHjWLBgAc6ePYtvv/1WsJ8oW9bke63VavHzzz9jypQpcHJywuLFi7FhwwbExMRg/fr1Zu0JgiDyg3gkff7qzrhDWndGjB03whSxVndGNE2t7sw9hWkq8K4L/GgiS3bZ1oJd1jCzpCMIomAg/aGoP/gbybxdkxj+BnhBM3DgQLi4uGDZsmVYtmwZXFxcMGCAsoUzz8aNG5GSkoJevXpJisM7OTnh5ZdeQlZ2Nr7f9gs30Ud5HZ06RiIwsDQSlm9G4vJNMBgM6B/LWZrlgAt+5BlvZfOWZgbje95tQIyLqyuav9ABAODq6oaYnsNsPAJSQVW6dBnFVqVLlwFjDI8ePVRdU0BAWbz0Ug989lk8kpLOYv78SbDFbcDqHXuFbtpFdM9SbAUtdhvw8eE1k3SfebeBjEePjO1UviQRjmiguLg4bNu2Dc2aNcOmTZvw1VdfQafTYefOnQgPDzdbpmLFijhy5Aji4uKQlJSEL774AocOHcK0adPwwQcfKG5j+fLlWLBgAcqUKYPly5cjMTERjRo1wr59+5CXlwetVovSpUtb/XxFBY2cIUo0JT8LDJBam4lFh3FaqvGtfGShWhaYEbUsMN7azOYsMDGiLDAlyNqMIIhniatXr+LXX38FAERFRam2S0xMxJgxnI2Yt5cXDAYDUlJTEeDnJylqeTsrxfTGeP9/07a9OH7iDGKHxmDp4onI1HIi4jHc8b81W7F2xQaL+yiIDFEGmMQjTZ6MzGNDBlhqagoMhjyY4jPcvJSUuwAAL69SZsuoCwi1oIw5x47vwbFje2yqN+OkzUNQaIiqrRlgygDjfZPPJ18xa5OXl4crN24gLCzMYr0ZMTVq1MDatWvNpg8ZMgQAEBERIZnu6uqKyZMnY/LkyZLpe/bsUWxPEATBIzgOWEDNDrpg4B0H+GfeaQAw3f2xVTtZcxywMehTRI4DZG1GEERR4oj+8PHxgcFgwL1791CmjDRIcPu2XTfHbMbf3x8xMTHCtXC3bt3gr5CIJoYPOiUkJCAhIUG5zYbNGNm3t+o6tFotXhnYCXPnrcI//yUhonFt1KxVzf4PYAw0/HP8ML77+nP4+Pkj7UEKZk1+AzNnm1/fSxGLKE4H3b9/F3K3AX66RqOBp6eyRZiY55+vD29vXxw/fkA2R9ltICMjFatXz7debwYQdNOQca/DWyGJUEzlyhXg5OQk1OvkraB5Ll411vGsaltGtyMaqEOHDujQoYPZ9FdeeQVOTk5o0EBq012hQgUsXbrUrP2UKVMUt+Hk5IQxY8YI/yGe5ORkPHz4EA0aNLBvNFYhQ8EZothQGqIfjotClpNZ3Rm7cAeQrjKvoOvOAKpCQ6kkDWBT3Rkeed0ZS+S37gxBEMSzxPLly2EwGNCiRQtUr17dbH5ubi5WrFiB+Ph4jBkzBqlHj6JW1ar468wZHDxxAl1ekA5N3PfnSbN1XEy+BgCI6dpKCMzwHNx3RHgtL2ppsJRFZbajtjaURmvy8nJx6tSfqFtX6v188uRhAED16rVs3wcA5gEa8+H5cAKOHd2DJd86VnNGnAEmJyqSu4j/9ffDmDBiiGTeoZMn8Sgz06IItoWMjAxs2bIF/v7+aNeunU3LrFq1CgDQt2/ffG2bIIinC7EdtICddtBmdWfU7KCVdIfqoH++7gwgDdIAJosye+vOQNaW344KqTCv26mE0drMVpTqzkggazOCIAoZe/UHwI3oPn78OPbt24cePXpI2u/bt6/Q9jU2Nhbr1q0TXlvi8uXL2LlzJwIDA9G5c2fk3JMOfXRxAnb9eQQnzpzFidNnUb+p+Wdnxryw2CFdMXvOd7h18y7GTx5j1s4iooS2Rw8zMHFkf2idtfjm+z1YsmAKdvz8A5o1b4+uMbEmcwFBIim7DQDAiRMH0akTH1TiAjQ3b17B7ds3ULlydaOlmVJCm+n58ePHePgwHQEBgbDFbSAjI9UuzQQAPYf0Qxlfk+ZUchvwyHVD4wbP49DRU7h87SZCSpkyxRlj2HP4MDw9PfOdWGavBtq/fz+Sk5PRsWNHm0bt5OXlYc2aNdBqtejZs2eh7FNRQcEZotixOQssFFyWlM3wQ/QtZYEVVN0Zflrx1J2xWuBSLDREkNAgCOJZhTGGhIQEaDQarFixApUrV1Zsd+7cORw8eBBHjx5FOIA+HTti1ZYt+GzpUrzQrBl8Xbnz/vXbd7AgYY3Z8iHB3MXuH/v/wovdTN7R+/ceRuKSH+zb6RzRa3lAxmqAxjwDDAAWLZqBRYvWCZlDt2/fwJo1S+Hi4oro6Bj79k9ALDbMx+O/9sYUvDZyis31ZlxU+moP2Z3F6lVD0ap5few+cBQ/b92PDi0jkZoN5Oj1+PjbbwEAw4ZJrQzu3buHe/fuISAgQGIVkZmZCZ1OJxllk52djbi4ONy/fx8LFiyQeIQDQHp6OrxlHtrr16/HsmXL0KhRIzMxTRAEYQ3NXpldMTg76HMwv6ll5jgAmDSGkj2YvDwn3CHVRnwSG/9sj3ayp+4Mv23k23GAT2rjseQ4kN+kNnIcIAjCXhzRHxEREXjllVeQkJCAadOmoX379vA0Fmy/fv261Tow+SE6OhobN24EAKtJSQkJCTAYDBg+fDimTp2K1KNHJfN9XYHF6zdg+LSZiP9pExY2fZeb4Wa+rho1QvG/n79GdlY2mrzYBoApgc0ePh0/AtcvX8I70xcivHotfDhzCf7960/M/mQM6tZtgZBKlkbkSM/v27atQZ8+caha9TluLmNYtGgm8vLy0LmzaSTQtWuXUaqUN3x8pJZZubl6zJ07EQaDAc2bvyiao56QFxQciiN/M7vqzXC6SbmujLjezGuDu+HQ0VN4/9NFWDV9OlfPFEDChg1Ivn4dr732msSaTq/X4+LFi9DpdKhSRWobaq8GUmp/48YNDBs2DFqtFtOnT5fM0+v1yM3NleyPwWDA+PHjcfbsWbz55psICgqyuo19+/Zh5syZCAkJweuvv654jIoLCs4QBY44C0wYom9nFpgZBZoFxt/9kQ/VtzcLTEloqA2VMZIK9QKXYmzMAiNrM4IgCMfYtWsXkpKSEBUVpSqMAGDo0KE4ePAg4uPjMTMuDi0jIjCgSxes2rIFLfv1Rfe2rZGdo8faX39D0/q1sHXnH5KBlJ17tUTox0H4fPZK/P1fEp6rFY4zZ6/i16270bF7O2xZ/4v1nc1RuWhXCsgw2JgBxhAQUA5ZWY/Qr18UWraMRmbmI+zYsRlpaQ8wfvwMlC1bHmJhMn/+NKSmPgAAXLx4BgCwYMFUuBtrvXTrNgj16jUXbUtpFI2NuDjWD3019T1E9hyGbuPeQZ/27eBXOgC//vEHzly6hFGjRqF58+aS9gsXLsTUqVMxefJkYVg8ABw7dgw9evRAu3btEBwcjPT0dPz000+4cuUKXn31VYwePdps202aNEFwcDBq1qwJNzc3/Pnnn9izZw8qV66MdevWwdnZgbo7BEEQUHYcsEp5cIEMq4gdB+R20PYms8mROw5Y0Er2Og4kQzWpzRbHAYIgiKLGEf0RERGBNm3aYOjQoUhISEDt2rXRvXt3ZGdnY+3atWjatCm2bt2quq7Zs2fDy0vZcuull15C06ZNVZd1cnJCTIz1ZC2DwSAEnYYMGWIWmOHp0z4a4z6bi1Ubf8HsD8fCzc1VsV2m1g3RL7UAwFlBK8EHa3KUHGyygW3rEvHzj4lo2a4Leg8eCeQC3p5+mDY7EW+80gYfTeiPZSsOQuukZG1l7g3dtOkLiI3tgOjobvD1LY0jR/bh9Om/ULt2Q/TpYxpVdOLEIcya9T7q1m2EChVC4OPjh5SUu/jzz324c+cGwsKqYcSIibDFbcAiVhy5LLkNAMDgfp2xdv1vWL15O5KSbyAqoj5OJ1/Dlt27ERIUhBkzZkjaX79+HTVr1kRISAiSk5Ml8+zVQF988QUSExPRokULlC1bFlevXsWmTZvw+PFjxMfHm1ma3b59G88//zyio6MRFhaGnJwcbN++HWfOnEGnTp0wc+ZMs8/Xq1cvZGZmok6dOvD29sapU6fw888/o3Tp0ti4caNNNUiLEgrOEMWCtSww8RD9gs8Cyw98kEZsaQZYrDsjfivPAhNjQxYYLzT4LDAlyNqMIAjCOrwnMl8/RI0+ffpg7Nix+D4xEZMGDIC7mxsWfPghqlSqhMRNG7Fw9TpUDCyLt4b1R+/O7bjgjAgvLw/s/PUrvDvhC+zddxL79hxF9eer4utVc+EbWB5b1v+CXJXLMaWilvlDGvDQ6XRYuPBHLFw4Ddu2rcPDh+kICQnHO+98jPbtu5ktvWvXT7h585rZNJ6GDVuiXr1ImGeAyYISvEeyOAPMCuIh+UrD8zVGB57nq1XB4Y0JmDjzG/y0bz8eZWaiSqVK+Pzdd/H2rFnWN2SkUqVKaN26Nfbt24fbt2/Dw8MDDRo0wNy5c1WHzffp0wcbNmzAoUOHoNfrERYWhokTJ+Kdd94xy9wiCILILxbtoEPhgOMAL1qKou4Mvx0rWHIccBC7HAfI2owgiALEXv2xevVqzJ07F+7u7liyZAmqVauGJUuWYOHChahYsSLeeust9O7d22JwZs6cOarzfH19LQZnbGXHjh24cuUKoqKiEBYWhtSUFMl8vkanTykv9GjbBqt++gUbftmN/t1eEtowhTiNWmDGGtcvJ+Gz90YiILA8Jn2+TDKvQb1WGBL3PpYt+RiLvvwAY8d+LporL9pp0k4DBryBVq2isXr1Yly7lgRvb1/07TsMr7/+ntHSjKNu3Ubo3n0ATpw4jHPn/sXDh+nw8PBCaGhV9O07DC+/HAc3Nz5YZtltwCqyY2ar2wDABd42LZmDWV+vwHfrt2Hed6vh5+2NV7p2xYdvvGFW28gS9mqg5s2bY+/evdiyZQsePHgAf39/dOzYEe+99x7q169v1t7HxwcxMTHYv38/tm7dCp1Oh1q1amHJkiWIjY2Fk5P5sevWrRuWL1+OVatWITMzE8HBwRg9ejTef/99BAbm80KiENAwRinylkhPT4ePjw/S0tJIWNuB2D9ZKG7Z3vhsHDnDB2dOlq4p2JqdQ3VBZFxEuCAy7lwK5kaTJBvXcRPcxfk94+MuuFEpj4zPwrnntvFNJriMMP51LoDHotd6mMSH/ITGixF+xA3/7C561hqfvY3P4tfGJ19wwRlfmERGAEyZYIEwBWdCIYycKVuZG5rPD9Hns8D44Ix45AwvNMTBGUFoiG3NZCNn5EKDRs4QxLNBVlYWkpKSEBYWZmbRREhRygDzFV8Qi21xRYk4vHeyuN4MLzT4jC95vRk+A0wIzuRohKKWku4qF9Kui9cTiiNncsEJDIPowRSe+fO//BlQ8k42VaiUV6p0Nk7TGl8bL5p1UA7OWBmezwdk3JAtvPZApjA8nw/OCGUSjIUtU0VJY7759E1+WrD1f0/XwIQ90O/FMeQ1ZyS6SUEzAVxwgU9ou4AqQkLbVQSbEtp43STXTIBJN12HSDPxWuk+uA7lPrh+I9P4XvwMmPSTHLFu4m+quQPwgElDlYJJR5WGqmYCTLqJ10sBMNdMgEQ3qWkmwKSbLGomQFU3KQVnSDcRhH2Q/nj6UbI0E7CimQCTbpJrJsCKbuLdBrJh6qLEeonvuixqJoPotVwriTWU+Nwvfw9IAy8a2YPXSWLdxL829qPiyZasoI1uA2IraFdkCSNnXJElBGckuok0k1WKUjM5EJojiMJFfAEt9gwG4EA9FvEfgxcI4hEv/BnOVuSWZ2rTrHDXhjZKNm4KKHlOKyK2k2uv2gqAuVAkCIIgpPi6qsxQERk8jmaAWYQXGYqoZ4CZhERB3FgSCw4e2WWmA8PzxRlg1obnEwRBPGlYvLGvUC+SR5ycZZN9lzghjMcXCoNX3GXPcp0k11GWEHtvijVSpux1umlaJkwJd5awYNfGJ/fxQSs+8U8Ju63iCIIgCIuoWZrZijihTY7VhDZLWK3PKUYpWU2tv1aarnY/TSlgI0bBAlkcmLEBW90GiJIFBWeIEoF8aLkilmqwlIE0y0qCu8pruaDQqkwXh9zl05Tm8aNzFOBH9ogRW7MpCA0+A46EBkEQRPFgVWT4WJ6thLWilpJRMzzie1tKAkMxA0zeAFAWEUqiwxaxoSQ+5NPsuNxUC3rJUBqeL6CQAUYQBPGkIhnRYQV+xL0Em9w7lDSS+E5QftzQc2XPgIUiocqI9dJt0etkh3ZIcG0QI7HcVklqE0Y2EQRBEDahmtAmQimhzWHE1/9yvWRTLrVcPxX26EiFejOAxHhAQMUOmncbsAQ/aoYomVBwhigUxFlgwvBvfki4hSwwMYWTBcZjKQvMGkoCAzAXGQpZYHKUauYUMPLaPkqQ0CAIgig8HMkAs4kCywBzNCijNlRf3F4hA8waLtLtK2V9KWIhE4yG5xME8aSjVFNSrQalQHnLs6UoOQ4AUp30jDgO2AE5DhAEQdiIiqWZEna7DeTYeC6WmwqoziC3AaLooOAMUWwoZYGJh+iLKZlZYDwOZoHJhUYhZYFJIGszgiCIfONoBpjDlmYlOgNMqa+wkgEm901WQC0DjIbnEwTxrKPmOJB/O2jAPKFNJ3q2JUhDjgMEQRDPKoVhaSav0clj0dLMmtsAYIfbgC1BGnIbIPIHBWeIYkcpC8wqT3oWWKrsvdoIGmMWmFDcUwYvNJSywEhoEARB5B+7LM1sHJJvzdJM4InOAJNh6YqTL2ppJzQ8nyCIpw0zxwFHUbODDoDUDtpXqVFh1J0pRMeBZCu7oAJZmxEEQRQuhWFpZlFHqQUZ8uU2IH+t1FaMmtuA0kPcntwGnmUoOEOUSMRD9EtmFlgubMsCS5euJhUFWuDSVsjajCAIongosKKWBZoBJsaeDDA5ljLAHLjElAk4teH5FjPACIIgnjDEdtBmWLCDFjsOWLWDDoS5FTSPxUGdT4DjQAmAHAcIgnjWKciENofdBsTY7TZggLlesiVAkx/IbYDgoOAMUSQIWWAWUBuiL6HEZIHxyLPBlERGpvosoEgKXEogazOCIAiHKdailmIKrN6MWlsx8swutXoz8vaA3ZeaLvkUPjQ8nyCIpxAlO2g1zOygbXYcUHMYkL8nxwGCIPIPsxSYJp45LFmaqSEktIndBiwltPGJbHniifIJSnbQBfVbJbeBJ4miPEdRcIYoNBSzwPgh+sWWBVZYdWfE5CMLTIl8FLgkoUEQBOE4xZoBpmRpVuAZYPnFmpeybHi+PAPMBsRZXxYzwGh4PkEQzwBKdtAF6zigZAdtbwKbo3VnSqDjAFmbEUSB4+zMXR/q9TYVTiSeUGxJaLMVNbcBm1BLaFO0hBZTVPVmCs9tgMgf/DmKP2cVJhScIYqVos8Ck2NPFphYTNiSBWYhSFNEBS7FkLUZQRBE0WJPUUseuy3NbMoAk1OY9WYUhufzQRk5KkP01YbnEwRBEBw2OQ5YQuw4YIa7wmt5uq44yl5QdWdE00qK44AdkOMAQdiGTqeDq6sr0tLSaPTMU4KjCW1KbgMOW5qJYxJ2J7QBnH4qrHoz4tcF5DZgAxatoMltwCKMMaSlpcHV1RU6nQPDleykoIcMEIRD1Lt/2myURzguCoGHYFwVghIAuCywZHu24A3TVb678bUO3Fmbf7Y3c4M/42uNy2phWXS4m976yprdg/IIoGQ4VGPnb9S2LNpegGn0UntYLDq6RaOx7IVNEATxjFCYRS0LNwNMPHrG0QwwOQWYAWbloytlgNHwfIIgnla2nDcmTG2HVQtiMVVwQTpKpCIzH4HP6w1xkMPX+Cy5h+MOZW2kVZluD7z+4jdqY/93F1xQyRLXNOo22CLOobrErQHgHAeURiYRBFHwBAQE4Pr167h27Rp8fHyg0+mgoQDnE4tSWlWW+FQs1i6ixkx0OZ+t5fSK3phglgNXYcE8o7YwGLUG47VNjgHQG3834pxpPkctD6ZENjU5BINxO3xwRvxsgFQviRe21NfIAzLi5DUn0Xu+uIzWtNNMtphG9Hn4TWoA6BiQDTCmhwF6uCAHeQBykQsXZEMPINv4oTXG46/JgWlEqnFd4u/Ou1YtZGU92xqLMQa9Xo+0tDQ8fPgQFSpUKJLtUnCGKFHUwan8ZTCJL9hT5TP5oAz/OgPmAoMXB7xosEV8iAM8OuM28lHA7DY4uzYx+RAaBEEQhP2UiKKWJSoDTKnejNp7HjuCNBaG54uxmAFGEATxhNKFMfWRF7sAvMA5DshHwlfDWUVr47KVr5rXXgmEdLSJIu4w2YvxuobXS7ZqI4ja86+1sunyafy2FeBvJPmKpokT227CzFHhzqVglK3MJfcF4youIhxVcAEXUEViA6cGi1JxeBAltXWpaltdVYIgzPH25iwU7927h+vXrxfz3hD55fE9aTEwDy2QIp4gth92M3+d42QK0vMJa7nGZ72xr8gVpnPD8Q25xj5EPDhTHJThgzEG0Xtx7AWQTZC/VgvKwMJrJU2kkT3EgRr+vfG1RjbZWTSbdyFwhuA+4KTlPrwWedAa+1IdcoXXvJ5yMeghuEPzz0ZJ9VikMT2SkkBwuLq6okKFCsK5qrCh4AxRZDiaBWZGicwCkwsNcRDIAqnGZ2tZYEZ4ocFji9CwNwuMhAZBEET+caSopYBSvRklLDmXSYbRiAM0BVV7xlJQBlD2MbOCi4P7xAs+Gp5PEMQzgJLjgFXKw2J9FimWHAcAadDFGrmQugs84Y4DViDHAYKwHW9vb3h7e0Ov1yMvz+JFLVGC2V2jhtm0NqGiNy1Er5uYXrJGptenfTnHnLOoJkxLRhgA4IrRQecKQgAAt1AOAJBy1RiVv6UB+PgeX9M5BcAD4wMA7oPTC3x5M2FwSBaAh+Ay4vjnbOP0LOPrXOPrPEiH5eTCvB/Tip75Qpuuxveu4PpUV3Ceom7GZ1cAXtx7N3BdsDu45L/SAPyMD3+Y7htWAFCOwT+Y69jL4RYq4TIqGevOhYILslTHOdRM5e4Tao4Ylz1sfP6De9qdbNr7NmfOgOBqzBSFlZkYCs4QhYotWWBKiLPAxEP0S04WmJLAUMLCKJpHMPd5tpIFxsNngVlCSWhIssDI2owgCMJmbLE0S0q5gcqlYzD4lU74auUs1XZqRS3trjcDSLPBJBPE7wF5hlfXrg0BMGzebGWEkFWURsrIRszwWV+8RlGpNSPH1aScJK+FLWeYTTLDNyLCeiOCIIgnnIK1gwaUHQcAqQ4CTEGWInIckFubFbDjAFmbEUTRo9PpivxGKFFw5F6+bDbNTSxp0sSNTS+ZqI3GzQMAkGO8GcbdC3wIALhvNN66bdQy14wi4o6zmylpm/fmegzTvcEH4O6v8QGbVJhqPwu5AAxcACYDJk2TAxgNwrjXenDBGb4Ugrgkgrw/zAbXtxkgtewUayNem7nAJI58uUnuMNWD8xU1dTbOywHXpzMAzgyZbloE4ypuA/BCDkobj1kO7gv3At1cuO+HtzYTvo8bxk8g+vrc3GxMJCQKnIKvOEQQdqI4ZNxWVIIXyoiHo/EiQK2wpa2IT8biE7Q484sPz4sKXKYqrKqAC1wSBEEQ9hEbGwuNRoPSvr7IzrFQlL6oLc3EqNWbUUSploy1+jLKN7Sysh4jMfEbTJz4Bnr1ikTjxmXRqFEAbty4AotWZpauNOX1pQFoXS0cdyPW6s1cuHwZQ99/H1VefBHu7u6oW7cuvv76a7uLzl67dg3Dhw9HpUqV4OLigqCgIAwdOhRXryonRxgMBixcuBANGjSAh4cHvL290apVK2zevFl1G4cPH0ZMTAwCAgLg6uqKqlWrYtKkScjMJPs2giAsY3G0B49akCIAXJDDF6abQGbdlXyCXCfl92Yqr50yYVNgJ9W+tfMJfXygik/244NY1pBbyAmIHCC6VDWfrZqYSBAE8QygdF60hKWRoGrna7OEbUA6OvSebF4qTBaZEkT36QDRaz4Awwdj8oO875S/d8A2y4YEBIsYE7PJMafkQMEZosSglKUktuoyGykS6shW3FVeqwVpbBUdubJnAPb44t+13sTMyk0G33Ep+U7ba31gb4dKEATxpJORkYEffvgBGo0GD9LS8NOePQ6vK9fJ3NIrX5ZmDteb4bHlAt5yUcv791OwYMFUbN/+P+TkZMPb29esjZSCqTfjphqlUua/i5fQdsgQbNu7Fy82b44xY8YgLy8PI0aMwJgxY2xez8WLF9GwYUMsXrwYNWvWxNixY9G4cWOsWLECERERuHhRaiXKGEPv3r0xevRopKenIy4uDn379sXZs2cRExODhQsXmm1jw4YNaNGiBbZv34727dtj1KhR8Pf3x/Tp09GuXTtkZ5M/G0E8qwg3TPiR7TbaalXBBcsNAqFsCaYKr4XE2sjWZDZx4pq4UrNaJ2ZBOyndVBPffLPZsk0ZqzVPVdweCIIgnmWsBqPF5QxE51Gl4Lel8zAfXJeMCrWGPEADyEbNyPsc/r1SH1UQARrxvUV32bMRX5gSJsqA669t6LPFfb9SDWohId7CtQQ55RQvFJwhShzFnwVmD3qF10rTLJAqe29FaKhlgSmh1MFJOkKx0LBSB4iywAiCeJpZu3YtHj16hDfffBNOTk5IFI12sMXSjFkZQcNTYJZmEiszSzPk9WaURs+oYTrv+/qWxsKFa7Fjxxls3nwcNWvWV2jHPztQb8YGPGxIenhjxiykP3yIxM8/x7fTpuHTTz/F8ePH0bJlSyxcuBAHDx60aVtjx47FnTt3sGDBAmzfvh2ff/45Nm7ciLVr1+LOnTsYOXKkpP2PP/6IH3/8EZGRkTh16hS+/PJLLF68GP/++y9CQkIwfvx4JCcnC+0zMzPx+uuvQ6PRYP/+/Vi1ahXmzJmDgwcPYuTIkdi/fz/mzZtn1/EhCOLJxpYbI0qOA0o3YgBI6lQCsMNxwB1F6zggpmgdBwoiqY0gCIIoRpJFr62WO1BD3ifpZa/F09T6LyXkyd46qAZmLJktBMKsDzfr44knGgrOEEWKWRaYjZTcLDD5NCjMy4RiJlgqzLPAlKL7AFmbEQRBFDLx8fHQarUYHh2Nlg0bYu+RI7hyUyFC7gPk5eXh069XILx1d7iVj0R4w+6Y+WkCDAZptIS3NPtj90G8Hfs2WlVvhTpeIajjFYLeEa3ww+JlivvSyFuD4TGtcefmdUwc0R8v1g1AVO1SGDe0E65duQQASLp0GuPHdUPbVqUR1bIU3nuvF1JS5PurVG8GkmkZGWn45JN30L59XURGhmHAgGhs377RrLWHhyeaNImCj4+fyhFUqTfDT7JWb8ZFuo9KNWYU4csgpAHnki/j92Mn0DIiAu0iI02rdnHB9OnTAQBLliyxusqsrCxs374dgYGBGD16tGTeyy+/jHr16mH79u24dOmSMH3Tpk0AgA8++ADu7iZ1FRAQgDfffBPZ2dlISEgQph84cAB3795Ft27d0LBhQ2G6RqPBjBkzAADffPON3VZsBEE8G9hdFyXUka2o3SmS32xyNEjDP6toJTXscBxQtL4BWZsRBEEUBgVhacYHza2ep604ywDg+otUa43k/Y8jo2Tk9xXl9xctOfLYYYEdKn0rdhcSuw7xUA21JwcKzhCFjsUsMOOwuqcvCwwwP8nLssCsoRT1t6UDAlmbEQRB2MN///2HQ4cOITo6GmX9/dG3UycYDAZ8v2WLYvvXJnyCCbMWwmAwYGRcL7R/oSnmzf8e496ao9h+/qdLcej3Q6jbqC4GjopDzMCX8eBeCqYOH4s5b7+nuExG2gO82q0FblxJQqdeg9GgSWvs37sNo4a2w4Xz/yDuleZ4nPkQXbrGombNCOza9SMmThwgWoOlejPcs16vx8iRvXH8+EF07NgTXbr0we3bNzBx4kisXascODLHQr/EB2WUUOhq1erNKAVqNBnm7fYcPQYAaNOkCQDANyJCmNeiRQt4enpi717rhe5SUlKQm5uLkJAQaBRusIWFhQEAdu/eLUy7deuWZJ5S+127dtnU3tfXF35+frh8+bIkAEQQBGEJi3bQasgdB8yQZ/haqjtjzQ666B0HeBx1HJBA1mYEQRACBWlpZgm5pZnd9WYALik6VTxBNkpTmMZTkPVmxInh8vIJDtSbsQGb3IiIEkV+/JwIosCpd/+0fUGEUDgwqsQdytERHaQnX63xvQ62GfzzbflnGLdjY92au+AEkg3cuRSMspWv4iqCEYyruIhwVMEFXEAVxYi5HBYlCoi9AJP3ZHtYHNW0RaMhL0qCeEZgjCEzt2QXJXfXuivePP8/e+cd50Sd/vF3stm+LAu7LB2WKoiCBVRURFBQUey9nA3LKdZDTz0Vy/3sFeup6Hn2s5yIiKgUxUZHsQCCLL3tQnaXrckmvz+SSSaTmcmk7WaX5/16+ZrJzGQSuGO+eb7P5/v5RMvUqVMBON2/2uKkUaOY9PDDvDV9Og9cOwG1lmXeD0t49b+fMGTffnw3ayq5udl428Dtd13KgUMv0Ls9D77wID169QhYmgFUux38ddwZvPP0s5xz9U106tEjJG/mj19/5vwrb+KmO5/wHXDDQ3dew4dvv8CVF43gir/ew3nn3wCN4G30ctNNJ/Hdd5+xcuUyBgxQTzAZP7PLyrbTvXsvpk79hPT0DMDLpZdex4UXHseUKf9k1KhxFBd38V+ttS5ToxyzpvlZ8sM8liyYF2zcpAFpXuyORv9LD2n+8diBm14lnTjvkjMAyHYbr6j5Y71vMrJPjx5h59LS0ujVqxe//fYbbrcbh8P4J3C7du1IS0tj/fr1eL3esP+PrVu3DoDVq1cHjhUVFQXODRw4MKrrtVRUVLB79+7Ae/r0sabwFgRh72MwKyI3Fbp5wwVeiuOAdhJLt0xK9x90EKyJ1PtmuAjWRw7VMfUzWC+zUzNh5fRvc1XHytB3TiglqpVCq9nHUBRolfH9JFhZEAQhViKOY0aUGhw3cqQxLG2TmTejxcDWDMLzZvTQRDxEdBlSUOb8/PN9MmalFtKcEVKevqwNLGnszkZrIWDaJkfYQzgb34M3G58nirbAsNqQgdACQylAtMeUz9RBsTYrUB1TFxtbiWJ1UDg/s790zgVBiJpady2Hvn1oc38NUxacv4Cc9Jy47uFyuXjjjTfIz8/nxJE+KVdeTg4nHn00/505k69+XMjYww/zXdwW/vPhDADu/vsEcnODz/WuXYu5/rpzuXvyi0DQ0gygR6/QZkE9WTgccPrVV/L9l3NY/PVcTrro4pBrcnLz+OutPnsrpSY47sTz+PDtF2hbUMi5510fiJex2WyMHXs23333GatXL1c1ZyLnzVx77e3+xoyPjh27cO65l/Pii4/yxRfTuPDCv6qutqn+Ux9Dc8y8SbPkh3m8/NS9pteoOXzkIYHmjBkVe/YAkJ+bq3s+Pz8fj8dDVVUV7doZ2bNBTk4ORx11FHPnzuX5558PyZf56KOPWL58OQBOpzNw/IQTTuDdd9/loYceYvTo0WRl+TKFysvLeeqpp8KuP+KII8jPz+fjjz9m2bJlHHhgMMPn7rvvDuyr3yMIwt7F9D/8K9ln4RNQzcFw9UZ/VgVWzvdhTUBpXNx7Y7jKuCMWfPmz8amK9VDXSVabNApGTZpaQmsl5bOzw0/psR3fn0vNJptxTimYitqWtx8YsIMJEbWpEVGbIAhCgERamkVEz1FGO65ZscAEkpc3o6CsnjERbWejv3q1iPCxDeO8GT2hge74pUHGquZHbM2ElMRSM0Hvx3YR+h1m3R/02twZ7b4ZZrkzrSfgUjyUBUFo7UybNo2dO3dy1llnkZUZXNly7oknAjD1f9NCrv/pd5/MaMTwA9Ey4sgDdD9jT9UeHpv8GCcNOZrBeT0ZZGvDIFsbJp1xLgA7t24JXlzv23Tv1Y+sHFXjyQ1Fxb5Ofd/+g8NWcxQW+s6VlSnr+iPnzaSlOdh//6FhZw84wNeUW7XqF90/Tzjqn5M6PmZK3oyfK2++h0VbvCza6WVRpZdFdR4W1XlY5q3jV28Vv3qrWOPdySbvJjZ5NzFt3tvkRJNHkACefPJJ8vLymDhxIscffzy33norp59+OmeddRaDBw8GwG4P/rnPP/98Ro0axfz589l///257rrruPrqqxk0aBD5+flh1+fl5fHEE0/gcrkYPnw4F154IZMmTeLwww/nxRdfZMCAAWHvEQSh9WNlgsTKRIshUQm+zOygtRNN8ebOQEJyZ3SszZTmlFibCYIgJIZkW5op4uy4Lc2chOc8B1DnnSnbeFbJaPNm9JoxJqtmzIhBrC15My0LWTkjNDmprwJTHqKxqMDUlmZgTeqlwoq1mV8FplibKTSVtZkgCHsH2Y5sFpy/oLm/hinZjih/2OqgWJqddsghIcdHDhtG1+Jips39hl0VFbRv2xaAiqo92O12igoLwu7Vtkv4L+eqBjtnHX0WK5auYN8D9+fUi84mr7ADDoeDjaUbmf76m7jq60MszQBy2+SH1QdpXt/PttzccH9ixabL7XahzZbRy5sBKChoj91uCzteWOgbiPbsMRovtatnFOyhu2mhh0Kslk1QZ8zo5c0EUHJnKvzXZucBUFmtX4VVVlZis9lo06ZNxO8wZMgQFi1axOTJk5k7dy5z586lb9++/Otf/8LpdHLLLbdQXFwcuN7hcDBz5kweeugh3n77bV566SXatm3LaaedxqRJk+jfv3/I9QCXX345Xbp04ZFHHmHatGk0NjYybNgwZs+ezcMPP8zKlSvD3iMIgqCgZwdt6jhQQpx20FYcByK5DyTZcaAZEWszQRCE6InZ0swIPUszJ6r+vyKaVr9WE23ejFFxo14xo5c3A1HNFZaEvlRny1mZ/xNSG2nOCE3CeK83Ynfd9nX04WABOmMaAhlKPqFFhpILo86MUXAQW/dcmzujfJYJTmKyNlNyZ8wQazNBEKLFZrPFbRmW6mzcuJEvvvgCgJOuusrwujc/ncn11/hWubRtk4fH46Gs3EmHonZ4VfP8O7aXA+BW/bz6YtoXrFi6grMuv4AHX3mKenx2Vw2k8/m7/2X662+af0m9IchUWK3XmNHH6dyFx+MJW51RXu6TJOflKU2gxOXNACxZ6M+cgYh5Mw7cdC/pxuWXjAt+YhW69OnuV0Vv2EDB0NAVQY2Njaxbt45evXqZ5s2oGTBgAO+9917Y8UsuuQSAoZrPyMzMZPLkyUyePDnk+Lx583SvB58d2gknnBB2/KKLLsJut3PQQQdZ+q6CIOy9xJw7A8ZNjZC5KsUOGhKfO6Ns9RwHVJ+vvCzQXKaul9TWZqXEnTsj1maCIAjWaFJLM4VS1X5EkbYRLs2+dkxzGxyLFfWcYLjYLpA3Y4SJVaeCzPu1TKQ5I6QcrVMFBoaraJz+bRMFXKoLDSvoqcCk0BAEoTXw73//G4/Hw5FHHkmvwsKw83avm9c/mcHU/30SaM4MGdiPpb+sZP4Pyzh9fOiyz+/nLw27R+naUgCOPSV8An7Z/O9CD9QbfFG9GsCfN+PDo36hg37eTGOjmxUrFjNkSOiqoeXLFwKwzz7qyT51g0Y7wWc9bwZgyffzePmJ6DJnLr9kHNluk1U0wBH+RsbcBQvQ3v3bb7+lurqakSNjVYH4qKqqYvr06RQWFjJmzBhL73nrrbcAOPfccy1d/91331FaWsq4ceNo61+xJQiCEA1qxwFdlCaGUWgyEOo4kE5owyQRuTNN7zigiNrMHAciitrUjgOCIAh7Gc1uaaYnNNCyE/0IgRDUY5pbZ9+lORbNOKdglDeTHbpbQOh8oFFcg4Y+rAns6+XNBFDGLL+YQFZ5ph7SnBFSlr1GBdbEAZdqxNpMEIS9Fa/Xy2uvvYbNZuOZv/2Nkm7dQs4X+ONnVq/fwA8/rWDxz78xdPC+XHT6OF57fzr3PfoKx40eTk4b30N87fYKnn/6rZB71JNFt56++y75dgHHjA9WK4u//oaPXn7V96JRZxzTPuKjqgfUeTPaG4W+fu65h3juuXdJT/cVDdu3b+Hdd18hIyOTsWNP0bw3UjGkkzejw5WT7uHK2+8BJeInw4sjs4EM/x8ykzqy/J2qTOos5830Kynh8AMPZP7ixcycOTOwIqWhoYG77roLgAkTJoS8p6ysjLKyMoqKiigqCv54qK2tJT09PWSVTX19PZdffjm7du3i6aefJisrK+RelZWVgXwZhQ8++IBXX32VYcOGcfrpp0e8fsuWLUyYMAGHw8H9999v6c8tCELrJVY7aDVhdtAJcxyAyAI2K2gdB0xCkxWcxOQ4kGzE2kwQBME6cVuameXNgM8K06k9qMqBBhKfNwMJy5vpSNh4po42iISVjDoRXacG0pwRWhQtSwWmbJX7Wyg0wFgFplNoGKnA9BBrM0EQhCBz5sxh3bp1jBw5Mqwxo+bSU8bzw08rmPreJwwdvC+jjhvKpeeP57W3p7P/Uedy6ilHU1/v4r/vf8WwwwYz89PQX8Fjxo+hW0kPXnrkGVb+soq+++3Ln6vWMv/Tzzj6lNOY/dEH1r90YzR/Qq/Ofui2qKgjdXU1nHfeMYwYMYba2hq++mo6FRW7mTTpfoqLQwedp56ajNO5C7Cxdu1vADz99J1kZ+cBNk49dQIHHHCU72IlbyaNcKtlNRlxFgT+vBmnf9XRY3//O8dPmMCpp57KOeecQ+fOnZkxYwa//vorEydO5PDDDw95+7PPPsu9997L5MmTueeeewLHlyxZwumnn86YMWPo3r07lZWVzJgxgw0bNnDFFVdw3XXXhX2VQw89lO7duzNw4ECysrJYuHAh8+bNo3fv3rz//vukpYU2r6ZMmcKbb77JkUceSXFxMRs3bmTatGnU1NQwdepUsTQThL2UpNpBl5BAxwFtkwaC9U+Scmec/m0zOA6ItZkgCHs7emNTs1uaqTGaB6wN29G8dqm28ebNKEWPOm9GITF5M2r0RNnROOUIqYF1g3BBSCABRZHyQ9ZkabjR8rywjnFUCim1UlV5KKq73OqHZjw9TK2HpQl6XX3TJpM19AY6vQHRDL0BN+JyVkEQhBRm6tSpQDA/RI2yagbgnOPHkJ2VyTufzKK2zmer9fLT/+DBu6/FZrPx3PPv8/ms75l481945Km/h90rNy+XN+Z8xHFnnMQvi5byzrMvsXPLVv7vrX9z1lXXhH8x7Qp6PcLmeEI8zrCSNwOQnp7Os8++y0EHHcZnn33I9OnvUVzcmX/+8znOOWeC5mobc+ZMZ8aMd5kx4x127vRJ1ebM+ZgZM95kxow32LjRLw4w+3WpM6Q6MhsC+5nU6e6HYZA7M7BPHxYuXszJJ5/MjBkzePrpp7Hb7Tz33HNMmTLF5IuF0qNHD44++mjmz5/Pk08+yTvvvEPfvn354IMPeOmll7DpjIHnnHMO27Zt47XXXmPKlCls376dO++8k2XLltGzZ8+w6w8//HC6d+/O9OnTeeyxx/jqq68YN24cixYt4uKLL7b8XQVB2HvRm4BRT9REyqUEfI2NDgRtVQrQmTMymkSyWicpg5tb55jeOQiGNtcGX0ZCL3fA77IQsnqIoGWOYqETFQarlwRBEPZqEmRpZoiea472ub/T2meFz89p5+7cRNeoMUM7VurkzUSB0dguYuyWi6ycEZqMlqMCiyZ3Rm1tBuEqMAh9mGtszdSHCjS3jhRwKdZmgiAIMfP222/z9ttv41y82PS6/K551Kz6NuRYWloat914CX+/6xIAah1Be6udXl+Dop7gse69evLcB68FjjX4xwd3fQaL6jzQYAvJm1m0wf9sdxMyhHTpWsKi5V6dvBk4+OCRLFpUg++k2s5MP2/mk08WBY7dcccj3HHHI/4zenkyNv97luHrvNj9x5SlMcoxiytELV6mh82gKaNmn3324f3337d0v3vuuSdkxYxCjx49+O9//xvVdzO6lxGjR49m9GiZ4RMEIX5itoMuIgbHAaVeSpbjQIy5MxatzRTHATPicRyQvE5BEITImI1ZhnkzaswszZxmn1xL5LwZLUYiAjO0eTMGtmYFBIURHTDOmzGZ+xNaPrJyRkhJmkcF5jDYjwZ1UJi6QNGTepnIvyx3+4MDlTJwmdm+xePpGe1yVUEQhFZHm+Cut43xZWrq/MEq6mYN+BozgK8xEwnL811m+TL6TZpIq2uC6DVt1FjLmwmQGfoyQ/WHzFJ3qvxku01W0QiCIAhhqK2OdT3qtZmWhpg5Dmj340EZB2qJyXFAj9LovkEkxwFDEeFxBscFQRBaCS3O0kwZJ0IszSo1F6nzZmJpwEB4zkwUeTNmWgSdMVo9lqvHeCO3ISAovvYLryUXLTWR5oyQ0lhSLOl1kPU6zSGon4Lah2ciLM3cmi1YW4uPubWZ5fDOUGK2NotQaIi1mSAIrY2CzMjXGFFjQe3bYDaJZWZpZpo340Hf1ky7j4XjZgRX0fiwa7aaU+q8GT0s5M3kWB07BUEQWilGdtB6+SemEzRqorKDhtDayaiOUj/w9cY6I7tno0aMom6uDD3k1LlUXS+ZWJsZIdZmgiAIcZIKlmZRoc2b0e5HQk/crR4DE5A3E/VYHUQ3I02DrOhMHaQ5I7Q4Wo4KTG8FjQvTJk215rWRzUCptW9gtdAIGTCl0BAEYS8hkqUZbSPfQ21ppqBdJROR8IUiPrQNGg86lmbaro3Wzkx9zCqRVsmomzSan5JKU0aLko1pgmnGjB4Vvo3T6O9PEAShhRLNhEnUjgMlBjfSOg6EoZ1M0k5GWUUtYtPLnWmdjgMiahMEQfCRUEszNTuxtqoybJWMS2dfmzdjdUWNdkyMM2+mJPSlekzXizHQ+00gpD7SnBGaDSMVmB5NqwJTvCEVrKrAjCzNlHNqdAIunRG+qoWASyk0BEEQkkQclmaG6FmaaYeLaARcga6NlZUzepg1YWyEr5xRY/EnpWYYdWQ26F6m16gJ5M2Y5M4UDB1q7XsIgiDsjRg5Dui5DugKe5WHuFltFC0ty3FArM0EQdjbSBlLMz30BM1a0XMAbd4MRG66RFWM+dHOKRrYmkF43oweFvJmYs1JE1IDac4ITYoVFZiy/K75VGAK8VqaQeTcGRPUKrAEFBp6JMLaTBAEoSWit2omWZZmykoaxdIskDejJmZLMwXtKhntyplk5M0o2/jyZtRI3owgCEJs6E3MqB0HdLHkOJCNNceBWGsniNpxwEmzOQ6EII4DgiAIybc0UyhV7avFy3rPfycW82bAN95oV8pYIVLeTDqmeTMFOrcsIuLYHHFs1yJ5MymPNGeE1kFSVWBNnDvThAGXasTaTBAEQUVTWZpZxWN2QnsyUtMlyXkzEFW/JiMmRZogCMLeQ5jjgAlGjgNhdtBxOQ6oaQbHgUiI44AgCELSiHbVjJqYLM0i5IYBUVhcqhv/ejWIlQZNtHkzamLLm9GNdEB/zA/kzZi4E0neTGohzRkh5YlpeV5CVWDRohcopqcC01QWcQZcar04lYFNT4UghYYgCEIUJNPSTL1QJJKtseW8GWVfb7VMJKzkzRig5M0ovy6VukSvNsmI/L1yol1xKgiCsLfgn3CxEvirS0msH6xVAMeaO6OQmo4DUVmbqUVt4jggCEIrI+Jcj4XnniXHFquon+vaFTPOSG9Wjy1GeTPaYxB/3owyrxhf3kwkJG+m5SLNGSE1iFIFpl7Gl1gVmHo/VhWY9pjeOQWVCkwPo+6/SaGhqAysItZmgiDsbaScpZkRlheUeNHPm8HkmBYreTPaa6P4Gakzb6fOm1FnzOjlzQRQ8mYqfBtnuAuaIAhCqyAaVWvUdtB6FBFuB53quTMJdhyIydpMEARB8BGFpVnUeTOlBse1DRplXAizNFOPKcp+M+fNQDBvxgLqsVw9xgstn1bdnPnf//7HmDFjKCwsJCsri169enHeeeexcaOFH6dCkxDmedjsKjCFRKnAtFu9ADIDnJrXRj7KUSDWZoIgCBZoakuzuPJm1JZmRnkz6q322khomzLavBmw/HMyjvk7W1XkawqGDo39AwRhL0fqptaBJccBrR10R4xDiANEa2lmleZ1HBBrM0EQBGP0nleJtjRTmuJRWZrpPd8to52PizdvxuprFUreTK7qmFEsg16Eg4aY3IaElKJVNme8Xi9XXXUVp59+OuvWrePcc8/lxhtvZMSIEXz//fesX7++ub/iXs3epQIzIkIH3izgUj0Qlfq3Bh6cYm0mCIIQJ8mwNFMwWvWhrQ0UB7MwSzOjxowZsebNqFH/fIwiXAbQ/tWo82aydP5Cst0mq2gEQYgbqZtaP4lzHNCzg1ZomY4DZoi1mSAIQgRSxdJsJ9ZWUYY1/yNlnumhzZvRjnlRrJrRQyeiQT12q8d0o4w5IJg343cpChPGCylFrMsCUpopU6bw0ksvcc011zBlyhTS0kInDtzuaDuiQnMzmBWRmwndvOFNiiIirDjJxvdDX/26Ft8DVfn/ifJwjWZJo4vgPy+X/x7qYxBeXGQHDxVE8VEqdvzZneLeG9lId2uNKz/L2w+M7E95HJbs5wRBEFKZlLA0a9BpZmt/mkS1il5raZaMvBmttRmEaXyUvJk0zAXUkjcjCCmD1E0th+l/+AVSs/D9Lp+D4Ur3/qyK3jLGEsqY5/LvV+F72Mdi/aLg9t/DTbAGSydYGFkQzDkJrZ/KCFcgl2LJZWENfRJqFzO+X/ik2HSbTcKYBUFoXTSHpZkWrcg5gHYFpl6eTCLzZtShm3rNmghEHdUQxIoLkYw/qUerWzlTW1vLvffeS+/evXn66afDCgwAh6NV9qT2OkxVYKDbcdbHLJQrGSqw1Au4VCPWZoIg7NWoLc0MVs2kjqUZRLYrizdvxuhazU9Io1+U6trEANOMGUEQkobUTamPlQkUq3bQYcKtEoMLLTkOKKhro2T9fyVBjgMKGmszhWRZmwmCILREmsPSLIxIlmZ6QmwnOnkzatR5M2bjSyLzZsB83lGHktCXkfJmIoqthZSm1TVnvvjiC3bv3s2pp55KY2MjH330EQ899BAvvvgia9asiXwDockJKImUVRlzjK6MsGxPTdSd5mzClxwmK3cGmiPgcldpO3p+thGbK2iDI9ZmgiDsTSR61YwecVuaafGYndCzNVPvJyJvRvsfJCJvxpHZoHuZaaNGyZ2p8G2cVv8OBUHQReqm1kUkO+gQ9DzsjfzuA6gnmYxWs1ht0uhlzaj3tcplRfWssjbTy50xsjYzQck30EOszQRBEAxIgqVZWN6MFaJ67ivzcNpxJ9l5MxqVQwFBAUQHjMdeyZvZa2h1UqglS5YAkJaWxuDBg1m9enXgnN1u56abbuKxxx4zfH99fT319cFqv7JS22UVmgrb15GXQ4Kvgxzyo7oE/YZFB51jYT0SZQm9sjxfsSOLFz1rM+W+OnI0J+HWZurl+dsJXxm0yQbdvGHWZq6VmZz30n/Z9/OV2D1eVl7Ul8p/5IR9pFibCYIgRIfW0mxj6UaG9xrO6Refw/3/fgWIwdJMrzbQzZtRo7UzMyJ47uSTDwXgk08WmFyvR+LyZqxiq4p8TcHQobHdXBD2YuKpm6RmSl0s2UFr6YjFgOV8ggWU1g5abQsNoVZlRvWUYmkGwTpJXS/VYtkGxomxtdlW4rKJiRexNhMEoVWTTEszBaO8GbAgaq4l3OnGpbMfqVGTgLwZoyGtiIjuP2r3IKF10epWzuzYsQOAJ554grZt27Jw4UKqqqr45ptv6N+/P48//jgvvPCC4fsffPBB2rZtG/ive/coOraCZaL5Idq0KjA1RpZmZqgf+FpLM1Tn1OiowLTEEHC577LfuPnqKfzfqfey32e/Y/f4/m4GvLGG3YvbG78R69ZmsnpGEITWxGX334et5zAKhxxLfUZwZUeTWJpZOa6LNm9G2U9U3oxvxUxdXS1vvvkMd955CWeeeQCHHJLNsGFpbNlS6rvMat6MigzVHzRLZylRtjs2u7PVq1dz9tlnU1RURHZ2NkOGDOGFF17AG+Uk2KZNm7jqqqvo0aMHGRkZdOnShUsvvZSNG/Uz3TweD88++ywHHXQQOTk55Ofnc9RRR/HJJ58YfsaCBQs45ZRTKCoqIjMzk379+nH33XdTW6u/wnb37t1MmjSJvn37kpmZSYcOHTjzzDP59ddfo/qzCQLEVzdJzZT6qB0HTO2gY3Ic0O5rH/oWcmJCSIDjgBVK/Vs9qxyC1jqK1Y6aqBtesnpGEIQWSlNamhmiPKdLVcfMLM0U55mQIaNSc0BtaWaGXmSBGcnLm9GNcCCCq5DiRuQXVmvFAULq0eqaMx6Pb5IiIyODjz/+mGHDhpGXl8eIESN4//33sdvtPP7444bvv/3226moqAj8Z1SAC01PTMv1Ysqd0T40Y2nSKOjlz0BUuTNOzWs9X83S4O7wZd/z5F9u4fnzbuLAeT/jsdtYMW4gL3x8KWuO7AVA3w+DbxAPZUEQ9gYiWZpVVVfz3xlfYbPZ2OWs4OMZ85L3ZfTsuPTyZgwtzfCf1LMv0+5HakhYy5vZtWsnTz/9D2bNep+Ghjry89sFT5rlzWjJiNwgyYk2m03Fb7/9xiGHHMK0adM44YQTuP7662lsbOSaa67h+uuvt3yftWvXcvDBB/PSSy8xcOBAbrjhBg455BBef/11hg4dytq1oSIRr9fL2WefzXXXXUdlZSWXX3455557LqtWreKUU07h2WefDfuMjz76iCOPPJJZs2Zx3HHHMXHiRAoLC7n//vsZM2ZMyKoEgPLycg499FAef/xxiouLmThxImPGjGH69OkccsghLFgQ7SooYW8nnrpJaqbmIRo7aEuUGBzX5s6Eoa6XtJNS8aIncHOhWz/pWZtB5NwZP4p1TlKszQRBEForCbQ0U5riibc0044Z2tcuYrM0M0I7/iUub8YqVjLoZMVmatLqmjNt2/qSfIcOHUqXLl1Czu2333707t2btWvX4nQ6dd+fmZlJfn5+yH9CatG8KjA10SjCjFRgiQu4zKmr5oH7b2fqPyZw4MKfcaXZmHHqSG797J888sTfqC7Mpedi39/RhrFddT/OcAAVD2VBEFo57339JdU1tdx0+XnY7Xamvmm82gHCLc0A6vFZlzUa2X0lzNJML2sm0g9tq9cphGbNFBQU8uyz0/jqq0188slqBg60aCOmzcUkNG9GnTFjmjdjkb/+9a9UVFTw8ccf88Ybb/Dwww+zdOlSRowYwbPPPssPP/xg6T433HADO3bs4Omnn2bWrFk8+uijfPzxx7z33nvs2LGDa6+9NuT6Dz/8kA8//JAjjjiCFStW8Mwzz/DSSy/x66+/0rNnTyZNmkRpaWng+traWq6++mpsNhvfffcdb731Fo8//jg//PAD1157Ld999x1PPvlkyGdMnjyZP/74g5tvvpnvv/+exx9/nLfffpt58+ZRX1/PZZddFphsFwQrxFM3Sc3UdFiZSFEmZCI5DqgDhXUxchzIDtvB2GM/ntyZKB0H9FBP1qlrJxPHgaZAHAcEQWgJRL1qJuUtzcA45ywWzPJltNZmEHPejAr12K3nIhQxnkBIeVpdc2affXz/2AsKCnTPK8eN7CKE5qHZVGBhc2t6KrBoV8uosVJ0xBlwuRUGblvCR/cfz6k/TMNjg5kH27j+ajvPXD6U7SW+5UNHvvwj6XVudg5pz+ajO0ccGK3k/YAUGoIgtA6mvvcJDkcat179F0aNOJjZ3yxi/catupZmjY2NTHn4XwzrO5puWftyRN8jePbBZ8MmxpW8mR9nfcd9V17GGQftw1Gd8ziqZx5/OXYoH735UvgXccGwQTauuuxodmzfzJ13nM+xxxYxcmQbbrzxJDZt+hOAdetWMmnS2RxzTA9GjuzK3/9+MeXl24nW0qyqqoIHHriV444bzBFHlHDBBccya9ZHYdfl5LTh0ENH07atuS1mognkzSjbCt/GqbP6aPXq1XzzzTeMGjWKE044IXA8IyOD+++/H4CXX3454mfW1dUxa9YsOnbsyHXXXRdy7qyzzuKAAw5g1qxZ/Pnnn4Hj06ZNA+COO+4gOzv4W6KoqIibbrqJ+vp6XnvttcDx77//np07d3Lqqady8MEHB/+8Nhv//Oc/AXjxxRdDrNimTZuG3W7n3nvvDflOw4cPZ/z48fz22298/bUFyZwg+JG6qfViyXFAawcdl+NAumarYKWG0hOxJchxQI9SwOMBg6aXWJsJgiAkjpgszRRKLX6IobWlMtemkKy8GYg7b0aLXmSDhpjchYSUpNU1Z0aNGgXA77+Hdw5dLhdr1qwhNzeXDh300uGFpiQlVGABzFRg6uPR5s5oj0XrXenHqXmtqAW8Hi5fcyvvvn4hPXaWUdYG7rkgjdfGplGeb8NVvC8A7bbtYui7ywCYecNY0DRPpNAQBKE1E8nS7Le1f/LjshWMHXEYHTsU8pdzTsTj8fDaW9N17zfxynu5/7ZH8Xq8XHztJRx93NG89MRL3H/DP3Svf/3xR1j27Xz2PWgYZ105kRPOuhBneRkP3noVT97/N90hoapyN1dceiRbtqzjxBMv5qCDjua772YyceI41qz5lcsvH01NTTXjx1/EwIEHMGfOJ9x55xVR/b24XC6uvfYcli79gXHjzmT8+HPZvn0Ld975V95772VCV89A8Gejzs9HJW/GiMzQl4nMmykY6lvFM2/ePADGjh0bds2RRx5Jbm6upeZFeXk5brebnj17YtMRG/Tq5bMInTt3buDYtm3bQs7pXT9nzhxL1xcUFNCuXTvWr18f0gDatm0bRUVF5OXlWfoMQYiE1E17F82bO5NkxwEn1hwHvB7yy9bS/YtZtPviN8CatZkeibI2E1GbIAgtCvVckMHzLm5LM71cMLO8GdDkzSjiZzXJzJtJV+3HmDejI5BQj9XqMdw0b0ZokSTCFDal6NOnD2PHjuWLL77glVdeYcKECYFzDz30EE6nkwsvvBCHo9X90QWFbt7Qh3lHTL2Gg+QTfGBn+/fT8T2UlS0EH7LK8UjLIt34/qmp75Ou+qwIxUo1vpU+CmVAERS4N/Po/As5Yq1vgmdBv1zuP/9inHnt6bLrIeqzetCY35WN2LjkxTdwNDRSOqw7fw7vSV/+1PskwDeQKs0w70iVb+Vo4l/RJAhCi8Hr9eJNcbW0LTtbd/I8WqbO8K16uOj0cQCcftIorrnlYV57ezp33T8Buz3YjPhm3iL+8+r/GDRkIDO+ew9Hrm8lyRV3/I3xB4zSvf9tU56na9fevhf+PoS71s2NF47jvVef5ry/3ECnjj1C3vPH6p85/4KbuOnmJwIuZg899Fc+/PBFrrzyGK644g7OO++vgAev18NNN53Nd999ycqVPzFgwP7or6AJfV1Wtp3u3Xsxdep00tMzABuXXnojF154LFOm3MuoUeMpLu7mv1rboFG9VDdlFB2DSd7M4nlfs8zfSEnDjcM/vjpwk67se3xbWz2U9OjMJePH69wwlD/+8C3D7dcvfElnWloavXr14rfffsPtdpv+DmzXrh1paWmsX78er9cb9v+xdevWAb6VOgpFRUWBcwMHDozqei0VFRXs3r078J4+ffoE3rNjxw727NkT1qDR+wxBiITUTa2f/qyK3UZGIReDFSnZBOsgh8F+NLgIDh4ugnWW+v9/2t8l2b5D2nmvnfgcE1Q4GvdQtPonMhudAKQ1WP+Oq9knIZNh4/tJMLMgCKlL1JZmKprc0sw0b0aLMnaoRdOJzJsxIwrb16iFEkEC83bKnJ3flUjGnJZBq/yl/fzzz3P44YdzxRVX8PHHHzNgwACWLVvGnDlz6NmzJ48++mhzf0UhRgazImyVRx/WBDruxb03hgaIdSZKf2GlKaPsVxF7gWGGtvhQPk+DE90QzsMaP+DhH+6hQ1UjrjR4bvQoXjp4CnRJo8cOX+BxVfujwGaj369/MPKDbwGYc8NRIatmklloTLfZJGxMEFow3tpaVh10cOQLm5F9li7BlpNjeo3eqhk1LpebN/43k/w2uZw6diS0gTxyOO3Eo3nzvzP5avZCxo45LGBp9vZ/fFk0k+6eSG5uTmDNR6eunbnkhit48q6HgKClmbs+g669ekFD6Oc6HA5Ov+BqFnzzJYt/nMtJp1wccj4nJ4+//vWfIXkzxx13Nh9++CJt27bn3HOvQTlps8HYsafz3Xdfsnr1r/7mjBrjZ/G1197mb8z46NixC+eeewUvvvgQX3zxPy688HrV1TaDfQMMfmUunvcN/7r3wcjv9zPyiIMsNWcqKnyeZ0qOhpb8/Hw8Hg9VVVW0a9fO8D45OTkcddRRzJ07l+effz4kX+ajjz5i+fLlACE5HCeccALvvvsuDz30EKNHjyYry/f/l/Lycp566qmw64844gjy8/P5+OOPWbZsGQceeGDg3N133x3Y137Ga6+9xr333hvyW3bBggV8+umnYdcLghWkbmp5TP/DP1k2C5+CeQ4w2jcxY8WWuDsbQ1eJlBBuHaN2HHASrEfCNBvpmoNq4ZqVGkppwihiNuWYQ7WNwXGgQPW6zEub4nUUVK7Crsps29M98kqZNfQJ8/f/mf2js5I5jqB1tyAIQksmSgeVpFuaOTGxNAP9VZeRss3MMMqbUVbPKK8NbM0gNG9Gj5LQl4nMm5H5udSl1dmagU8FtnjxYi655BKWLFnClClT+OOPP7j22mtZuHAhnTp1au6vKMRJ3A2FAoIPxTD0cme0+5EwChxT70cIuNQRrNt3uLgp7UpeXnoXHaoa2dzOwSWnPMBL3Z4Hu4N2e/5Hfs08vDYHzo6nkL2nhsk3P4DD3ciiMQexYaivCBEPZUEQ9mbUlmbT5n7NzvLdnDXuWLKygif+cs6JALz62rSQ9674yTf+HDZiKPVkhZwbOuIw3c+rrqriX/83mfMPH8JRPfMY1sHGsC42/n7VGQDs3L7Fd6FqiOjeox9Z2aGNp6Ii3++Xvn33w2ZTflz7toWFxQCUlW0z+6OHkJbmYP/9h4YdP+AA359j1Spl8knbiDHzL9OgqWEcmQ1cfc9d/OqtYo13J2u8O9nk3cRO7xp2etdQ7V1BtXcFHtciPK5FeHctYt7b//K92SRvJtE8+eST5OXlMXHiRI4//nhuvfVWTj/9dM466ywGDx4MELKi6vzzz2fUqFHMnz+f/fffn+uuu46rr76aQYMGBYLS1dfn5eXxxBNP4HK5GD58OBdeeCGTJk3i8MMP58UXX2TAgAFh77nvvvvo3Lkzjz32GEceeSSTJk3iggsu4KijjmLfffcNu14QrCB1U8sgmgmVSHbQIeh52kcMJ84mcu6M+uEfTQ2ltTRTttrcAAM0k3SOymo6Fv5Ae8fv2PHg8Y9froxc6goLA24LWmszRfgXiViszSSvUxCEVkHKWppByLxaALfOvnoljd55hUh5M2pM8mYKNJcaxS9I3sxeR6tcOQPQvXv3kOBVIfVpUhWYegmkerFMACMVmFJoWF1Jo7Y0g1AVGOivw9enY/0qnmx3CQf84QTgqz5duKP/G+xp2wWATO+fdC7zKZG3F1xHXW5/br73Orqt38L2zh2Yev9fqCYvxKvSiFiszWT1jCC0LmzZ2eyzdElzfw1TbNnWnp9mTJ3ua7785YxxIcePGTmMrl2LmfbJN+zaVUF2sa8R46yoxm63U1jUPrAYps4fqFLUsTjs/q6GBq4eO4qVy5ayz/4HMu7si2jbtpC0NAdbNpQy44PXcTWoug2Nvk1ubvgS+LQ0h/9cG80Zb8B2yO12YbZSRk1BQXvsdm2jxRZo9OzZU4mlvJlIZEa+RA9bVXTXKytmlBU0WiorK7HZbLRpo/37C2fIkCEsWrSIyZMnM3fuXObOnUvfvn3517/+hdPp5JZbbqG4OPi/t8PhYObMmTz00EO8/fbbvPTSS7Rt25bTTjuNSZMm0b9//5DrAS6//HK6dOnCI488wrRp02hsbGTYsGHMnj2bhx9+mJUrV4a8p1u3boHvNHPmTBYuXEj37t257777KCkp4dxzzw37DEGwgtRNrRM9x4GIWLaDBmuOA+qVNFbsoPWw4DigfI0C5YCXNoWlFHReid3uweNNY3fjvrTJXE9GYyVV7XqEZXBGIlGOA4IgCKlIyluaqTG1NNNr5KvzZszGoWjHKPWKGfX0evx5M2qszOEJLZtW25wRWgbjvV7LSqEDdv0e1oXvy1r9ZZHa3BnwNWb0uuwBlIemldwZiGmZfVjujPbzdHDCsd1e5n7XU7Td7qEuHR7rfxpv5/6fr6jYDraO9XTjFuzUsSfrMMraXsIps6cxZvocGtPs3P/47VQXhAcIgxQagiDoY7PZIlqGpTp6lmbqVTMbt23ji/kLABh59lWG93nt3a+45voLAMhv2waPx0N52S7adOgact3W7b6mgMffwHDXZ/D19PdZuWwpp/zlcu58/BXfhX5x1hefvMuMD15XWZdp8Kh3Ggm/0Guwb+U1OJ278Hg8/tUWNpRGTHn5DgDy8pQGkUHejHJIyZ0xE5FlBD9/+bw5LJo3PyxvBiAdd0jeDEBJcWcuOUvf1qxgaHDlj5I1o2TPqGlsbGTdunX06tXLcn7GgAEDeO+998KOX3LJJQAMHRq66igzM5PJkyczefLkkOPz/Pk62uvBZ1V2wgknhB2/6KKLsNvtHHTQQSHHu3btyiuvvBJ2/T333GP4GYIg7N2oc2cSawcNobkzahKZOwP6jgPq76D65JwaCg/6iax2u3xX1hdS7hxMWvsGMhor8WKnum03n4ivJPI3EmszQRAEmt/STDuX57RyIxfheTPKfjLyZrQ1Rux5M8W9N+peZmnuTsabFoc0Z4QWR8tXgVkPuEyvr+W2Iy7jvB3LAfizMJObs55gde1ony2bn05tniCb1bhpz6YOD9BrWyl3PXc/AK9ddxG/HDQoxKsSpNAQBEH497RP8Xg8HDnsAPbp3TNE5ORNh0a3m9ffmMHrUz8KNGcGDRnIz0t/Zf78nxh3emhzZsn878M+Y9OfvufsUSee4jugGi6WLZwffBH1HJba1sxspYzxucZGNytWLGbIkENUR20sX+5rWO2zz+Bov5QPB6a/MBfNm8/z0WTOHHaQYXMm5LqRPtneF198wW233RZy7ttvv6W6ujpwTaxUVVUxffp0CgsLGTNmjKX3vPXWWwCce+65lq7/7rvvKC0tZdy4cYb5OWoaGxt59913cTgcnHHGGZY+QxCEVoLGcSBmStCfDNPzxDd0HFDqpWTkzigfbKZA9pLXbwPthv6OPb0RT2Mau7cNYI+tJ2CjwO5r3FdndMbjCGatsckG3bzs+LM7xb19Tgzd2cha+orjgCAIewVRr5pRjTfqVTNNbmlWjUmTRj1YqW3M9PatkIC8mUiURP8WUI07mjFHO84IqYsYUwutAvUP57AOs6YDbY1E5s6ol05qt8aDQUnRUv57yIhAY+bjHn05o3weqytVI2EZtMmfS2Hu2wBs4p/YS/N5/OWbyamv5cchh/L2FeeE3Fc8lAVBEMDr9fLaJ9Ox2Wy8/sRkXnnkTl6ZovrvpTt57dV7GH7Y/vzy82qWLv6VGrI5+yJfk+Wp+56ipromYGm2bfNW3nz6+bDP6dyjJwA/ffttyPElP37Nx++8HHpxo+Vvr9nGznPPPYjL1RB4vX37Ft599yUyMjIZO/YMgqtmlG0aln8+6uTNAFx7zx2W82a8uxYx771/Wfq4ffbZh6OOOoq5c+cyc+bMwPGGhgbuuusuACZMmBDynrKyMlauXElZWWi1V1tbi9sdOkbX19dz+eWXs2vXLu6++26yskIzhyorK8O+0wcffMCrr77KsGHDOP300yNev2XLFiZMmIDD4eD+++8POedyuaitDZ0V9Xg8TJo0iVWrVnHdddfRpUuXsHsKgtD6MJpwCUzQ6KAWZYWItoxyZ0yzZ7STTk2ROwP64c6Q1qaW4lMXUjj8F+zpjdTtbM/WBSPYU14CZTbsNhc5dl++257MHlF8l1BituYByesUBGGvJK7nph56lma16h1t3oza0kyPJsybUXKvO2A8xqrGZPVYrZcdp5cxZ4Q0/1MbWTkjpCYWVGDqJfqGlGCcO6MlabkzCmo1mN4HBwuYYwe/wAO8Qt5uL1WZNu7Pv5RPf7klZLUMZeDoup2u3X0TTmXVf2FP3gju/Oo+BmxaRXmb9tx6yyOUre8QogLTQ6zNBEFobUSyNJuzYBHrNm5h5GEH0btHN1DFkHhV+xdcdjo//LiC16d+xINDh3LkqOGcd+mZvPPaBxy7/7Ece9qJNNTXM+O9aQw+bBhff/o54LM0Axgx5mS69CzhP888wtrff6FP//1Yv2YV387+lKPHnsbsmR+EfkkjizM86NuYaVfOWP/hXVTUkbq6Gs47bzQjRoyltraWr776hIqKXUya9DDFxcrKIF8z5qmnbsfpLAdsrF3rW2X59JOTyM7JAxuces4EDjjsyNAPiTFvJoCSO+OPkXHWG14JwPPPP88RRxzBqaeeyjnnnEPnzp2ZMWMGv/76KxMnTuTwww8Puf7ZZ5/l3nvvZfLkyQFrMIAlS5Zw+umnM2bMGLp3705lZSUzZsxgw4YNXHHFFVx33XVhn33ooYfSvXt3Bg4cSFZWFgsXLmTevHn07t2b999/n7S00HyfKVOm8Oabb3LkkUdSXFzMxo0bmTZtGjU1NUydOjXM0mz79u0MGjSIsWPH0qtXLxoaGpg1axYrV67kxBNP5MEHra9GEgSh5RGvHbRlLDsOKHYtzek44CV3/y20P2YN9kw3Hrcd57IBVG0ogQKb77IOkJuzCbvNQ4OnDfVp7Xz2bVEI+GJ2HDBZPaOHrJ4RBCElaUpLMwW1zaZpPIEZ2kk+F8YrZpKRNxOBIiLmzegRlduN0CKQlTNCStEkKjBTsgn1hdRTgamJpb+ppwIDcIFtD9ePvImnGl4mr97LyvY5nLn9DT5dckvYXWz2errvNwmHo4LamoFsr7qBsT/M4vyl7wBw+6UPUdZez4sgiN4AGbVlnIWBWlbPCIKQakz99BMALjnzJNPrzjjneLKzs3j/nZnU1tYB8NDLj3Pbg7dhs9l449mpfD1zDhffPJHbnno47P05eXk8P30Oo086g9+WLeK/rz3Lzu1buO/ptzjromt9F+k1ZAJDmJI3o4fRBE6kvBnfMzk9PZ1nn32Pgw46nM8++4Dp09+huLgz//znS5xzzpUh14KdOXM+YsaM/zBjxuvs3OlTIM/56kNmfPI6M6a9zsZNke1fMlSFTxbhnZZst+/v2FYVdiqMAp18lUGDBrFgwQJOPvlkZsyYwdNPP43dbue5555jypQpkW/qp0ePHhx99NHMnz+fJ598knfeeYe+ffvywQcf8NJLL2HTGdfOOecctm3bxmuvvcaUKVPYvn07d955J8uWLaNnz55h1x9++OF0796d6dOn89hjj/HVV18xbtw4Fi1axMUXXxx2fdu2bTnllFNYunQpzzzzDK+++irt2rXj5Zdf5pNPPiEzM95OmCAIrQ29CZzkOg6oicVxQL1v7jiQlueiw5lrKBq3Enumm7rNBWx9ewRVv/cCbCqrGy9tctYDUOXp6cvsVCj1b/1WOoq1jmK1E7fjgAlWg7YFQRCSSXNbmgVQLM1KLdzEic/SzBS9cSWZ6OXNWLQ2M8mbsWKvKbR8bF6vSDPMqKyspG3btlRUVJCfH0WYkxAV6gEhMBAok/7+h7/y4F/efmCggaCsnFlDn5CHfIhnZSnBrvt2fF33MnzLIZ0EfSrDlkLuIhggVqXaV5bTuwn3rNSi9qEE38NZ6azn+LfZQDY5ubU8dsgdHL1lNwAz2vfjjoXv4GrMDV8GiZdu42+joOdnNLrasHbNO/Ty1vLWjAvIcdcw9bDLePyyWwINKuXhrm5eKQ95pbmlXjmjLuTUSyVDmmRqFZgmd0avySYqMEFIPerq6gIh6VqLppaM3qoZCF05gzrKo03odeqVM7UO399LjerHdT2+Y3WqZSHKsQb/c19ZOUODjUAPQj2/pO7Tuwj2X5ReTKBhoz6g7CurZTyq/7ya/9DZKmitymz+/+yqrXY/LXSrPaQMbcpwp4jJMoEM3+c7MhsCzZlM6gLNmUx8DZkcasObMyYrZ/SaM0JkrP67l9/AQjTI/1+ahmhrJiCkblKEWWvpG+r1r54U20pozQQGdZPaPka77wZqCBZYyjEwniTT+vcrdVI6oTVUFrmD7LQ7tpy0LA9etx3n/P5ULuoNWTZ/rUSgbsrtt4miHj/h8aSxyX0MXtJ9SmVlMqzEvzWom7Q1EwTrpqhrJgipm6RmEgShubHUnFELci00Z7TzdUDI+AMmeTOl/jcoYxFYGI8UthM6Z2c0l6eMSXrzeXq2ZtrxKZvgGJWjet3ef749Ic0ZI1szZeVMZ0LGIvU4pB6DtGOPMu7YviY41vjHGO34ImNL8kjEb2BZOSO0SJpeBZaI3BntseC5kq5L+e/g6zl6y27cdngkaxy3/PCurzGjxQkdhr1IQc/P8HocbFjxBG2qcnl+1jXkuGv4vmQ4Tx19k+9aAxWYHuKhLAhCayakMWOCXmMmEqaNGS1mlsaGlmbKSa19WbJ/ZGubOGDpp6POMKnkzUCwGaMmJ9xbVBAEQTAhMPGiTPb7J2bMHAfUGNkdB2jW3BmFUMcBe46X0ef/xGHjl5GW5aF+Sw5b/z2MyoV9wOu3MXMG352eW0X7rr8AULmnt68xo6C2zLFAJMcBQ9W45HUKgpCiWHrepJKlmV7eTAC9WiJS3oyCWWNGvZ/kvBkVVvNmrI75QmojzRmhRWMpJ6XE4HgHQh+SuisOtTZm6mIjEZZmbo4e8jnvFk2h9y4Xzmw7V9XexL9/upXQyTAC6oC2g2bQ8VBf6PSWJf/AVX4gU5ZcR+e6razLL+Hmw5+k0W79u8VsbSaFhiAILZW2kS8xQmnEJA3DVTNqzLJnrDZsjJ7JNtV/EPpTMS38ciMs2i3rNWoEQRCE+NELCtab4AHC7aAte+BrbVuUfW0tkm6wb0S4Hc3A/hvYcvMlzO53K2+nP8Dur9ux7c0SXOWN6E3K2dLcFO27FHtaI7VVhVTs6aefW1Dq324KHRejtTZTY9XaTBAEIdVoCkuzMKK1NIuIsjpG2Ve2Rnkz0WKUNxMFeg0avWiGBCCrZlIfac4IKYeRCswqTaMCUxPjw9jWyF+Pmsoz9TPIr4M/2mZx5rqH+KF0BKHWAMFNTs9ldB1/FwA7l17K7rVncO+6uznQuZwKRz7XHPs8lZn+WcdS849PpoeyIAhCc6FnaWa6aqaN/mH1qpkane59XbRJ9+p5Jr3FlJYxsiqzep22KaNtxhhdq/nJqLU000PzV5Rh8Q9rxdJMEARhbySWCZaog4ObPXcmdJDMdtTz2Dnv8Nt5N9ExzQnAnqpsKn9s51sto4uXwsEryMjfg7suk7INB0KZ6trtBm9DZbFjQrIdB0TUJghCShCHQ4rZczLM0kwP9XNa21hXLM10UTfrjew0E9GgUaOMdUpRZDFvJoIQQi9vJuoxXWgRSHNGaDEoy/VangosXPl12n4LeGXMA1y3fRVpXviifSfOXvoEWyrCA4MDdy/aSI8Lb8DucFG5cjTbv7+Ryyqmckr5J7hJ4+YDn2B9217hbxRrM0EQBEt4DRo1CnqrZiJamllpKDRGviSYMaPsK9tE2ZvZMG7U2E1fAqF5M2oy9L9bls5fjJI3YwXJmxEEQYgdUztoIwqw6DignaRSH4sGN93b7uLH6x/lbwNmhJw58IXnMVM35HXbQG7vLXg9Nsp+OQhPo4GgIonWZiGitigdBwRBEJKNpawZNSbPMYWEW5qp0bM0C8uNrtU5qXWvUa+kMfObVjBy01GvngFDW7NIaAQRRmOyJdcgocUizRkhJWjdKjAFN9mOBr6/7l6ubv8Fh6+vo9EGT+UexI0/3EG921ghZc+ppOc11+LI203tpn3ZOO1BRrnncfOuJwB4qOft/NDhCFMVmBlibSYIQmtAb9VMGHFYmsVNpLwZ3SaNh1BbM6MVMYlerp64vJlISN6MIAhC8onZDroInx20LsnLnRnRZx1LJz7B4DYbAscaGtM4+KXHqaxXcjmVgOfgpFxGWyftR/4GgPOXfaivaB86qWdmbabBzNosLlGbIAhCCyahlmYKpTrHtM9rp6WPItzq0mz1vtG5BOTNQHjejB4loS8jugH5CeTNKG5DfvehgBuR0GKQ5ozQqohaBRZV7oxWBRadv+SgDtuZfd0j2L7Novt2G3uy4dqG8by09FyCk1+1hPpjAmkuelzzN7I6r8O1qyPr//MMfV2beKT2Fux4ebfNObxdfH7owKWoDUrNv5NYmwmC0NqxammmXjXTrJZmYXkzarSNGaOt9notetZmytaoGRN73owjsyGwLxkzgiAIicPIDtosIFjtOBAyAaTndW8htNjnOKAmfseBK0cs4asLXqDIURVydtIXF7J0a4nhN7FnuuhwylJsaR5qNnSk8o/expN5eqI2jeNA0tA4DoioTRCElKOpLM026TzrIlmamRLuXBPcb+K8Gb35RaOYBYO8GT2XID03ISMkb6ZlIM0ZocUTlwrMkMTmzpy6zwoeOuUt8j7Pp20NbO/g5dwNf+GbPw/2X6E3Q1cLVNDlovvI2+9HGuuyWf/Cs+TvSeN5+1/JpYYf0w7jgaJ/QKQf7ylkbSaFhiCkHl750RaR5rU0g9DVM2C+UiaapoxR1kyEFTNK3owRFvpXpo0aTd6MkDjk37sgtGyimWiJZgInDMt20OCrleJ3HEizeXjunM/51+j3ybA1srShD3tcvrH1g9+G8szC4/1XKitm1HgpPHE5joJaXBXZlH03BKo1Y5zeqpkUsDYTBEFoSlq8pZkTA0szo7wZvf1EoWflGX/ejB6SN9N6keaMkJKkjgos9tyZrJIc0vLT+PvIufxl0Dz6zM/C4YF1fTyct/xK/tyl+Kypwy9Di4yiE9+h/dEf4/XY2fTCo3g29eLp7OvpatvCem8Pbsp+ErctPbTQSJK1mXgoC0LrIy3NN7vuckWdSp9y6Fmaha2aSWVLM0M8hDZbtPvNkDejbsooorEIc28ZKhGCXt5M4JtU6R93Wml0CZZQ/r0r//4FQWj96E3omDoOGNlBFxB0HNAl9tyZdlm1fHXte1wz4FsAntl1Ag5nDXnpDSzb2o2LP55AuKIh6DiQf8gGcvrtwOu2U/a/g/Hu1IjpUs3aTERtgiC0QJrN0iwqtHkzWpKVNwPhq0lNMMmbUY/RQutHmjNCypCaKjCIWgVmh/bHd6THOR2ZesU0Rtn+YMCvaXiAtQc0csX8S9i2R3lga8PJQBlICo78ik5nvQrA1rdvZc9PI3io8DYOdiyl0tuGazwvUGEr0A9GA8sqMKvWZpaJY/mrIAhNS3p6OpmZmVRUVOx9avrmsDQzO96IjqWZ0WoZr8F+okhQ3kyG9e+V7bZud1YwdKjla4VQvF4vFRUVZGZmkp5uZCskCEJrJaG5M9lhO5rX1nNnBnXeyZIbX+PowtVUezP5y+9X0mNnKYOLt7JtTxtOfvdqalzKOBQ+sGZ2K6NgpE+ot+urQTRs16gxnDp/JiNitDZLRl6nIAhCorHU9E1FSzMnFi3NalX7ytZoxUwi82a0+8SUN6OH6dityZsRWiaxpJoLQsowmBVhP4T7sCbQcCjuvTH0h3Vn9JsWBZrXYdnE6f6DDnwP8HTVFpSHui3dRodTutKtTyN3NNxLl6/rKai2UZ0Je4Y1MPGDi9lSpf0wVPfw/ZNsM2QBXS99BoCdn57Lrq9O59YOj3J87ue4vOlcX/MM67J6h9+ijPBVQaX4HvibbNDNy44/u1PceyMb6W4YNLaafawVbxYY3y88kGy6zSbel4KQIhQVFbF582Y2bdpE27ZtSU9Px9bClJqVv/yie7xO/ZjJI/Q3eTAGBa+qL1DvCDZEXH6FbgOZgTc3+psUDWQA9bhIJ+Bf1uABly30/spv/kaCgl+l76J8VFh/RWnOeAmunNHbonpttLpGi02zbye4Ysau2tr9f2Y7vrGpMfS7Kl+hUXVLL8Hh0QZ4vaRluPDWgQcXGTTQCLj9f5d2GsmijnrA5v/fxqb8vblV9yTkfy7q6iS3Jlq8Xi8ul4uKigr27NlD165dm/srCYKQaGbhm1CbQ1QWWt3ZaGp5TBEWFMz56BRQBGsnhXTVa9/+KQes4Y3xH9LGXsdGTwfOnHcxp2Us4JQjf6fO7eCUdyewqVJRUSj1kuI4kI09103RKVuw2WHPLx3Z85Pmz1KNb3JMQV0vbScma5k19NHNAVBY3n5gQEzoHWnu/iAIgtDcWLU0M1o1kxBLM6s40RluzCzNEolZ3ky26UvAOG9GhXqOzmycMRpXtHNvQstAmjNCq6A/qyIvKy8hfLmk8mDUXX2Sjc+7Uv26Ft+D2K3a+rDnZlB8Zgf277yNq8qfpt93dtIbYUehl3YHVXPl25exvqKd/+pgIya0yeMmu+8aul/zDLY0D7u/OZbtH1zOhQVvc0n7fwNwR/kDLMw41He5k2BjKUmFxs/sH7BCMCw0RhPs2AuC0KLIz/et5CsrK2Pz5s3N/G1io6YsfNYoxwHl6gNqu6wsYLfmtZ8Gu+953KBS+bpV+y7/s1s55vZ7fHncDuXi0K26KaNuyCgNGq/qXKCnoteMMfoPna12H8KbMtqttkGjfm0PXqvt4aSpLknz/6cMbw6wO9z+3UYc/gIpHXdgX7E7y/D4iyel76Js/XVWjaqxlrNuHUJsZGZm0rVr18C/e0EQWj7T/9CfWLN97fu9fsCu38OsZ/qyVn/irJs3VMnckSgsk5WZKJd/Xz3waps0AF7+ceIC7jt4Fnabl+9dAznr3TMZmfMrt50xH4DLpp3Jws3d/Ner6y//uGyrocP4PTjy3DTszGLXFwMIjFfKPF2B6iN3ElQvq2unrfhEfKWYqpjX0jfMaiYuUdtxRFQ7i6hNEIQmJcFOKJYtzdQianVpZ+QUAwTzZrTH1GjzZqxamlkhQXkzelEMBkTjIiRjR8tBmjNC6tPsKjAIPuCVQiO0wHC0z6LjWe0Z224xJ615j4ErfJN1G3o3MuKgnRz16uX8sUtZXu8mtDHjCGwzu2yj543PYM9soHL5QWz+90TG5H3DbcWPA/D4zpuZUXsS1PjfrlaB6aEUGhHQKzTiQgoNQWhR5Ofnk5+fj8vlorHRckJ9yjD3hBPCjo0q0Rw4UrV/aHDXOyy4/3tBcKJqDf0D+6X0AmCDakzZQE8AttEJgPKN/oftNv+k0GaCxYTSJdpNsCm0C99wotQUIYtB6vxvqsPXnKnHJ/11+4+5/ceULfjGkUb07TIhpGNCGr4iIg1fZ8oBZKr2c/37uf7jeb7XWQTj2NoA7QFFc9AOKMQ36aUsyujkpbD7VjqxDYAerPdvN1KCr8GyD6sBKHGuxbbI/74F/q0vdoC5paF/klErVyJET1pamliZCUIrYbzXG3UeScIcByJi5DhA4FiWw8XrF0/n7G5LAXitchR/fXUknW27ePHsTwH4v/kjeeeXAww+w1c7FRxZQ1bPejwNNnZO64bXVY1vbNOZHHNikpWjIQrHATVqUZshJqI2PccBQRCERKE3biTSUjFuSzMjnFYv1K6eMVs5E2lVTbrBtnnyZiKOLUKLRpozQsqSOiow8D1w9R7eDjK7eulwRj5XZL3PsIWLKNnsa8xsHlrPmD7l2G2wslLPYFJp0vjUX+ntyyj520s48mqoWdObjc9P5MDMn3ik8z+x27y8vfs8pu6aEPw6aoxUYAqlJN/aTAoNQWjxpKent7jJW6PJqawM1YvjgArVa1Xfwqu6zpaVE9hv8D9IfUXGHgB2+c21fBNZ3hABwI40//Ibry2oAFMa6ZUExQBlBJs2Tnw9FyeqWqIWX8NFWZJi8x9TmjSKlYviqaxVgCljlVFzRlmGr4xBNv+90/A1d5S/kAyCTZwC36FsQr2TlVVARf4/YzY+/zEvAQVYbZYDm38VTx4N9GUte4AGdgFgY6tPAZYRtDYL/G+1xf8nWR/6J8nKykIQBEGIjZgdB0DfM78WrDsOuOjatoJPJ7zFAXkbcHnTuH3DOTzxZn9sjY385+KPyc9s4NsNPbh77jH+e+k7DmT1rqft4b6Bo/zzDrjLTTLhWri1mYjaBEFoFpra0swobwaCNZMuSm2koN2PZnWMNm9Ge0yNSd6MFUoiX5KouAEhtbGQ7ioITUcsPzr1OsjqTrO6Aw2YryRRJpx0UXfKfa9z+qfR7dw87rI/yTFzFlOy2UZtBtydOZi1FT2x++cMRx9g3glKy62k59/+TXqhk7rNnVj/5PX09JbzXNcHybS7mFN1BA/suJlQWxo/ToObRtV8CqI3kKoHXPVArB6gBUEQUgWrvslGWAr1VWE51FKN0+rdtUvzrSjAjAoQbeNNqwBTiEEBFmFiK1EKMJmcEgRBSDwRV4VY8Mk3Hi+Ck1qH9y1l6cTnOCBvA7u8eZyy8Foe/3c/vG4vEw9ZyIieGwBYWVZImwxFkKAQDHdOy/dSdJLvXNXSLGp+zzP+Ws5I39uPskqo1PwyPYueiM0uMzQWQolUsQuCIChYWm3ZEizNQgRtlUS2MdMej5Uo82YKCM4vdsDCGGphLPYTaPYr4mi/a40Iolsu0pwRWg2WOsolOseK8D0sC1THssN2NK/TaXMw9D3Vxb1Vkzn0y220r4KtbW2cu/0kPvj5UK7+dDxv7hgOwNgBij++euLMt2/LqKHHje+R1XUHrvK2rH/8StrV7eGlbg9QkLaHn2r3YdLWu/D4cw1CqNa81rNps2hHEHHgjBYLhUa0dgyCIAhq4i0yjBRgekStANN79qaMAkyvqIDw5k2UCjCT5flqRAEmCIKQWAITMoqtsH/Cxmy1huGqD633fVSrSoxVxJeNWMSc85+h2FHJ743dOfzj65n5eXHg/D6FwaS4CQctY8b5b6jerbLstEOHUzykZUP9Vju75uQSXE2qTNLVhmxCUE/2mVlc+wUXigDD1C5bhZHAI0TUFoVVtyAIQrKIV9CmpvktzSD40NfWSS7VthnzZoqwnDejN0ZHkzcjtCykOSO0aiypwCKiVYFlUTDKzSHHbOAfpfdx0NcNZLhhecdMTv3pUv7Y0QPlAf3lmt4AjCo2aGHbG+h+zafk9ttEY3UWpY9fQvruHJ7v9jzdMnayvqET12y+hzqvxsKllvBBKlKhUerfRlFoqAdYS0pyKTQEQWhmku2brDSyDZ+dpTrHkqIAi7ZBo4eeAixbs1W9LCCyAqwkzq8kCIIgJJyYJ3QsZFeGk416ZWaarZGnz/0vU0e/RqbNzWe1BzP85atZ9XNOyLtu++pYHv1+OCvLCgHYt8MO3bu3G2Ujs4uNxlrY+XEmNGon+bTh0Cqcqn312JxAxwE1kYQfASyo1UXUJghCrMT7/EgJSzNTlAa9sq9sjeolK3kzDs1rSFbejBpxG9g7keaMkNKklgosG9K8FJ1cxRlDZ3Pz0mfY/yff2Q+KO3HBN9dQVR+6pH7Wki4ADMrYSOfiOs09vXS9ZDb5B/yJp8HB+icvwL2liMe7TGW/rA3scrfhqk23s7sxE+MJOz9Og6/dBIWGqbWZFBqCIKQSBr7JahJqaWaEM5pPMFKAKaSAAkyLarxViyTMvPlleb4gCEJ0NIkdtB6K44DSrDccJrIpyK7i84mPcP0+MwF4dMepjH/qIiq2p6FdwVnVkMmtX47m+pknALC5qo3qrG8iLWeAnfyhvmy08hk2Gis9GNZHWiJO7hGX40DU1mYmojaxNhMEIdmEPWfisDQzE7RFxOi5qxW0OdF53NdqDmozOKNF6yxg5DYAcefNqBC3AUGaM0KLpilVYPYsNx3P2ckNvV/g4q8/pddGG3XpcG/2odw9fwJeMsLes31XDivqfRN3xx28KeRcx7N+pN1Rv+FttLHx+VOoWdOJyR0/YmTe79R60rlm841scHWK/MWSbG0Wl4eyBik0BEFIFIn0TU64pZlCyinA9LbJz5sxI5pxXBRggiAIiSEmO+iIuTPZqMeNfTqXsuiGWzi2/U/UeDO5+JeJ3PriWDwNNkLHnVBLzS5tqgDYEmjO+MY8R3svhcf7hHAVP7ipXduoepfWClRzyqk5lgTHATXJyusUUZsgCE1CAjM6DS3NSnUu1j6PnUZ3VcTL2mNqtG4DsTZsICl5MyVxfB2h1SHNGSHlSD0VWDZp+fX0uGA197f5Byd+9TtFFbAz38bFuy/iveUnq24S/rCeu70/AMf2/tN/xEXRCT/R4cRlAGz+9xiqlpdwfdEczixYRKPXxi1bL+Hnui5EVII5Na+bsNBQIx7KgiCkAsn2TbZsaWY11NKJgaVZshRgDp1jCsnLm0nU8nxBEATBAhrHATPUKxpD7KDjdBw48cBvWDDhKvpmbmWTp4hjvryH/3x4MHiNFMgKDirrMwHokFMTOGpz2OhwajvsmXbqNrhxzlcaM9qxURlDU8dxwDIW8joFQRCiJaUzOhXiErRBeO1kVDepz4nbgJA6SHNGaFUkQwWWXlzBkL/M58HqWzhyXiVZLvitYyan/DqZFZsPNPkg34N61u++DxzVbhXgpd1Rq+l0zgIAtr47HOf8gZxfsISrC+cDcO/205mzZ7DqPuqAS1WhYXEFf6wky0NZr9AQFZggCNHQlL7JZjSdpZlCshVgUeTN6FES+VNleb4gCEJyMJqYUSZymiZI2MvfT3yOaeP/Tlt7DQsa9uGw/zzCjz/0InQSy7hJ8+Mmn3PAkE7b6NLGp4xud0w7Mjqk07jHQ9kn1eBVj39RTLClmuOAiNoEQWgGmiqjMyIWn7fmK2jU44F2P5oGjLrhonYWSGDejLgNCAZIc0ZolSRKBZZVspOTz3+Ge1f/kyHLPAB81qUr58x/FmdNoepK4wJjzrJO1HnT6ZK2i/1O+o0ul3wPwM4ZQyj//ECOy1vFHcVfATClbBQfVBxKTBNuTtW+XsClFBqCILRiEumbrKZlKMD09qMlSgVYRHub+BVggiAIgjlN4jigZwet4ziQmVbL2xMu46GhT5Nm8/JmxbGMfOZZNq8vJHSSy7xJs3VPG37c1BW7DUb3Wk12vxzaHJCH1+ulbHoljdUN/ivV46J21amGlmJtJnmdgiAkkJieFwnM6IzZ0kzrNgAGbgNqtCI27fFYSELejLgNCDpIc0ZIeQIqMM0S/bhVYBFyZ3IHbuT6Uydx/Q9v0Wc91DvgsfwxTJr7DI1etdJYXWyEP7zr6h38WN2XBVmZ2E9bhM3uZdfX/dn+/iEckr2Rhzt/ht0Gb+8+kBfLjyJ0maWF5TFOQif69AYyNaX+rRQagiDsLTSXpZkaS6GWEJsCzErejJkCTCEBCjCtCMKESOO3LM8XBEFIHjE5DoBuY75L2018d9PRnNd1Jm6vnb+v+ysXTbmP+j1Zqqv0gpb19mHB5q4AHNx9O4Un+JZrVi6oom6DMnC6NVsFA8cBLU79w81ubaZBrM0EQUg0sQramtzSzBnNTbRjg4I6tzPWJk3z5M2I28DehTRnhJSkeXNnvLQb+jsPHnIh58xeSAcnlLexcVX17by66Gb/hcpDWDvRpeyHNm7er+zK9R074EnzUrG4J1v+PZx9Mst5pusMMuweZlX144EdxwJKg0I9cEQIuDRDTwUWBVJoCIKQaiTLN1lPAWb2DLRkaRZ1qGWiFWB6aq+WmTcjy/MFQRCST4jjgB46k0yHdv+exZeP4ODclez25nHK98/xyH9uAo8y1aAdUyLnzizfVgzAsIG7SMtOo35bPc75Ts11MU66tQJrMxG1CYKQFBIoaIuIledsNRGaNEpDXtlXtkbCtVjzZvQyPJs+byaAuA20OqQ5I7Q64lKB2Rrps99cprQ5h2Pn7iC7AVZ3zOC0lTNYWHqW5g3KA9i8wMjo7OSbA0upsdsZVlPPjlePpKtjDy91+5Q2aQ0srOnK37eOw4Od0EEizoBLPWszi0ihIQhCSyLlGr4JsTRragWY5M0IgiC0GgwcB/QwnAiyYAd98X6vMu/4k+ns2MUqdzeOeH86n311mv+sejyx5jigsGSLL3dm/5wN4HJTNn0HeNQWn9otWHYcUJNAazM9QUciHAdS7jeOIAgtgqYUtOmRfEszrQU0hI4JicybUfabN2/GaCwXt4GWjzRnhL0C09wZgCKw2+sZ1XsqT2y7loOX+B7qs3t05cz5Cyjb08t/YTbGXfHwAiO9/R56TZoFOQ3sU+fmmR07GDdwMy93m0kHRw2r6ttz3eaTaAhRBCco4FIPRZ1Q6t82o7WZFBqCIERLU/gmJ8XSTIvT7KSepZmiADOzNEvdvBlBEAShaYg0QWPVDjpS7oydRp4YeiP/HnITWTYXX9QdwmH/nsPvvx+kuVM65o4DoFdDrbZ1Z483i3xbLR0X/Yx7l159pGdtpjgO6DRqtIec4ZcAMVubxYXkdQqC0AQkaw4mIZZmapzR3MSq80C06NVHkOp5M+I20DKR5ozQskiECkxLR0inigkFt3D3kifptw4a0uCFbsdy3bdf4fZm6bxJ7devfh1UgaW1qaXklhmkF9ZQt6UtJ67sSnYD3OL9jpKMSja78rhy43iqPGmq+xqpwBIccBkFybI2EwRBSATN4ZsctaVZwhRgsRApbyaBCjCDvBmz8TgwjmuW54sCTBAEwRrJsIMOoST0Zb5tNzMPOJGbOr4GwJO7LuCEtz7DuUsZFLIxHkci587YHDban9Sd5R7fOHxAjZWGkt5YGcFxQE0M1maKYCOS44CRylzyOgVBSBbxCtqskHBLs5jdBhTUqyvVxGoPDeEratSI24CQOKQ5I7QIkqkCy2rYyf11Z3PlnC/puAt258FNXe7gmfXP+B6qhs1wYxWYPdNLz5s+J7NzBQ3luax/bCzfr+zDpu/aUVBdy253JldsPIGdjbmq++mrwGzpDRQc+S3dr32Uzhc9Q1bPFaSytVlIoSHWZoIgJIimKDL0sFxkKCQl1DIWBZjVvBkt8SnArGJ13BYEQRCSh94EkNnKx/7231nUcwRjM36gzpvOZRsf4OaZz+Fp1CqLwVxdbDwetRvVgYyiLBbX+8bf4d02aa7QW2EKlsdHJ6GTfjFam+kRSdQWSRhihDgOCIIQL9EI2lLG0syJzrSXImgzGgsiuQ0YoTeOQVx5M+I2IFhEmjNCytIUKrDC8pU8v/o4TppTSk4D/NnZwdldPmSu8yKdN+ipwMJzZ2xpjXSfOIOc3jtxV2ZR+uiJuHflsv/KjVRvy8KW5uGOilGUugpM/iQuHAWVdDxrGvs8eRfdJvybtsMWUXjMF3S5+Dnjt6WAtZkpYm0mCEIyibPIMGtAJ8TSLKpQS/UxveviUYDFmDdjAVmeLwiC0PKwkjtzfMN0FuQeQ3/7RrZ62nNM6Xu89vO1gJl4IhsrjgOBq/vk0uagdgDM/b4AgCO6b0TfwjNGx4EmQvI6BUFoMcQhaGs6SzNlNaQWIxFbLII2pU5Suw6AuA0IyUaaM0KrJKIKzOvlgKXTeXn26Ry2qBY78N3AIk7r9iOb3ftGuLs2d0b1MLel0fXyWbTZfwOeegfrnzyehm0F/K3DQsZkrsNrg25H7KbkwDr/G8K7/fbMaopP/Zr+Dz9OhxPn4ciroWFn+8BVDdu76H8tp+Z1qlibSaEhCEKSaA7f5KaxNIPkKsDURKEA64C++qsk8qfI8nxBEISmITBRY2AHHfXKRa+XSdse4lP3RRTYqlni2YdhW77i+53H+M4XEGzix+A4oEyE2XPTKDyhEwCVi3bxzZJCPF7oV7iL4tw9mvvpOQ5ox01lbI3DcSAOazM1ktcpCEJTYWkuxaKgzYjUsTTTug0oqJv38eTOqJsykMi8GauI28DegTRnhFaNXufZ5nJz9rf/x+Mf3MqANV7cdnh9xHCu6DUfly1X5y5GaCe90ul49mwKDv8dr9vOhmePp3ZdMRe3W85l7X8G4Ne+XcjrUs/Yvn9q3usGu4d2I5fT75GpFJ86H3umi5o/erD+6csofWQijbWZAFQtP5CwQiNJAZdN4aEshYYgCJFIed9khYRZmiVDAQYJy5spwlQBphZDWMp/0yjABEEQhOQTyXGguPdGMly1/OfHi3i05kHSbF7e8YzliD1z2NzQxzcWGHnqR5U746NoXAfSch007Khj99dlOOuy+WVHMQBH9thg8ieJwdoMLE7++Sn1bxNgbZZoRNQmCEIkYp1ziWRppve8S66lGRhbmrmILGiLJW9Guw8x582oELcBQY00Z4QWQ0JUYBUubv/Xhdzy/lt0LoOKXLj93Ot5uO+r4dcWEJUKrPC4b+lwwo8AbH51LHtW9OGkNmv4e/H3ADy28xD+vftAAEYVrgSCD87cfTfR977/0vXS2aS3raF+ewEbnj2HP//vKqqW96fr5W+Tll1P9ap9qFg43PzPqC00zAIuS/1bC9ZmyfJQNkIKDUEQzGgO32RDrFqaRUQbaqkmXgWY3goZiEsBFiPKeB1Ynq9BlucLgiBERzwTMkYrGzvs2M7894/hotrpNHpt3J5xI+dnv0c9eRE99INYy51pc1Ae2b1z8Lg8lE3fCo1ewME3633j71E912M8GQdRW5s5NccS4DgQa15nCFphiclvG0EQBC3NldEZNTFbmmmf7WpBWzw0Ud5MSeRvIm4DeyfSnBFaHUYqsMYNtdx3wwWcP/NncutgQyc7F932H2aU/DX04qhVYNm0PWwJnc/7DIBt743G+f0gDs8p5f86fwXA67sG8+quwcxZWkydN53OaU4G995NeuEeul/7Jb1u/Yysbrtw78lk61ujWHPHFVQu7gfYKBzzHbkD1uCpz2DzK1eB1+SfrVPzOoHWZnHRHAO+IAitgqYuMqKyNCvVuahFKMDCff51i4sCzPNmNMvzRQEmCILQOhj181d88cI5HOL+jQpvLqf1eJmHut0Ldnv4ysmImOfOpBdl0m6UL2fGOW8XrrKGwDvnb+gBwIiQlTN6lmbqrULTWptFIiZrMw3iOCAIQrQkWtCWGpZmEG5pZuQ8YEbz5s2I24CgIM0ZIaVJJFlcdgABAABJREFUlAqs8YdK/jHxRo76sQK7FxYPbsOp//ieNR2Hhb4pKhWY78GcO+hXuk54B4CyWUdQNnM4gzK3MqXrJ6TbPHxa2Y9Hdh4B2KirdfB9TX/qbdD37IX0e+Aj2g5bh7fRRtkX+7H61gso//IgvI0eAHL6/0Gns2cAsO29U2jYmY+hKiAFrM3EQ1kQhKYg1iLDiIRbmqlxGr0hxRVg2l6N5M0IgiC0PAwcB/RQJoaOnjGH1z78O1285ayhK4eP+ozp/c7yXWTkmV9AbLkzaVA0Ph+bw0bN2lqqllaiHpPm+1fODOm4nfzMOozRW0FjgVS0NpO8TkEQYqCpnwNJtzQzRGm8a4/pXReLPTSEW0BDc+bNiNtA60eaM0KrxuuF/Dd+4+q/P8+g1Y247fC/s/rzlwd+oK6PptMdtQoMsnpupMfE57E7GnH+OIRt755Az/RdvNjtA3LsLr6r7s4/to7FizJQenmzvgOndO3C6gFbsWc2Ur2yE2smn8G2tw/HU5OFMmikty+nx8T/YnN4cC4Ywq45I1SfHEEF1kzWZqZIoSEIQjPTZJZmSVOAWZl0SrICTK9BY6AAEwRBEJqHSBM2enbQXpeHE1/8jMcWPUa2rYHZWUM5/oZ3+a3HEOMbxeA44MM3VrUbmUVGsYPG6kbKPysPu8PWPW1Yu6uANLuXw7trxxftmKheUWpia6bg1LxuRmszEbUJgtAkaAVtCcjoTBq16h1l3susKRPJbSBWJG9GSD7SnBFaJhZUYI2NNnLv/YkrnviErjugKhum3n0sT9//jG9JvpoYVGDp7XfS8+YHScuuZ8+v/dn8yoUU2at5qdubFDpq+KWuEzdsPhUXaQBkdKym581fsujgP9mc7qDY3cjWf41g3UPHU7+pfci9bRkuelz/EY78amrXd2bz1LOARmt/N07N6yRZm0mhIQhCstk7LM0gVAEWSf0bqwIMEqYAsyhmMFueHxi3NcvzRQEmCIIQG7FMzCgTQLYyFyc/9hV/2fY5AK90PpULJ71MVbuC4MUlmjfHlDsTJKuXjfxhWQCUfbYHT41H97pv1qutzfQsPdXjpfq41jJUOYa540AzWJslAhG1CYKgJZo5llgzOpWmtGVBm9ptQP28dUb8iipqMbYxi6U+0iPOvBktJZG/ibgN7L1Ic0ZoUVhVgVVXOvBc9SenvreaNrWwsZONf//nTL4991wgtFOtSwQVmC09k+4THyS9bQW1G3qw4ZnryPG4+Fe3V+mesZsNDe24etOZ1HgzsGe66HjmQvr+3we0GbwZj9vOubvrmb5pC8Pc2YCN0MKhga6XfkF2yXbclTlsmHIu3oYM/zkLAZdWSIC1mRopNARBaCpSyjdZIS5LM/WDW88vP1EKML0CQ6Hp8mb01NqCIAhC8+H+2cbVz/2Xw+p/p8abyT8PmcCLV11JV4eqI9FN0/hJQO5M/lDf65rVLur+bNB/Gw7mb/BN+B3Vc73et9dsQb8ho4y7OrQEazMRtQmCYEKrEbRVY6FJE6l2UraRmjTiNiCkDtKcEVKeaFVg634tpOMFyzni2yrsXlgx2MHr0ydSvr/BZFuJ5nVEFZiXLn95mJzef+CuasOGKTeTVpfOlC5TGZi1mTJ3HldsuohdjTm0PXQN/R58mw4nLcfu8FD1U1fW/ONU+qzpRY7Xy/GDSjX3dlN82iIKhq/E22hj4/On4CovIOEBlwotwNpMEIS9k6YuMvRoHkszaHYFWALzZuJBlucLgiAkl22fFnPeh7Pp5i1jo60DT559ASvGDTZ/k5l3fhS5MzWrfM/4rF4OHG3tBCfGQBmb0osyWJLnE1Qc0n0LJecXk90nx+DeRitoTHBqXjeBtVlMeZ0WEVGbIAgKKZnRaZVa7Y4y96W1MVO2kQRtzZQ3I24DgkWkOSO0Kj6ZMYSCq8rY9w8PbjssOi2H5e+egrtNbuCakA51DCqw9se8Q7sRn+L12Nn4/N9xlXVmcqfXGZ67impPJldtuowdHV30uu19uv91Buntq2nYkc/6p45j/ZNjaNjelllrfQ2O0Z1Cly0WHLmK4lMWA7Dl9TFUr+xJQgIum9hD2XKhYVEFJoWGIAgQXZERLUm3NDOlBSvAVOgpwGR5viAIQvMQmLjRsYN2u+3UPpvHCYuXkWNrYHFWf16aeCY1g9qE3Sei4wD4mvgFeieMc2f2/JxG3XoP9nQb7Y9rAzbI6JhOm6H5dDitI92u60OXy3tRccT+bPYWkmFr5Iiemyg+sxOdL+tG7r5ZYNPWSep6KUrHAafBdQm2NrOMiNoEQbBAoudKkprRGZOlmd6qR7WlWTKJMm9GD3EbECyg/X+aILRIPF4brz81jP2nVZJbDxW58NukDqSd1xkbdvqyNvKKj87o/+guCO7m9F1M5/MeBWDbf6+g+vchXNX+P5zedj5ur51bqs6i8sIf6XvEMmx2L556Bzs/HUbZ5/vjDYwbLmYu6QqHwv6ZG+lUWMu28mxy991E10u+AWDHJwez+xv1QOjC989VCbg06dYrhYb/e7f3lLOv5zf23f0bebVVVKXls7umgLn1oynvEdkseiPdLS25/Jn9rQWXjSbY9RcEQdAhEUVGylmaGYZatnAFmFbk4CcWBZggCIKQfHbV5rP55Z4c7vI1z2d3GczGy7rgcWQErunDGv3xsIRQgYJSSkQUf2UT2g1xAC7KZ0GXy7xk98qg+42dsGeEjv+eBg/1W2r5Or8353cp5+Cti/mmYF8yOmRQNL6Y/EMb2PqfcmhU6qR01VZBb/JOM+5V45tcUyjDWJBQiu/vYZMNunnZ8Wd3wwbWGvqYjodavCP1s1QBnyBFNV6O7yfKaUEQ9IkoaGtOSzMjLFuaKWOJ1tbSipBNQb1KUy1kU9tvputcq2BQO8XhNhCPoE3cBlo+snJGaHFoVWAVs7N5c9oQDvmvrzGzsTPUvw5p54WvuVd3pqPNnXHkb6PHlX/D5nDj/GEc5Z9fxvj8OdzQ4R0ApnQdxI77PqTdiKXY7F4qFuzLH3dcys7ph+B1hfZBt+/MZkW9L9zyhKGbyOy2ix4Tv8Tm8OL8oS87PjrEf2WkgMtwa7M03Bya+SOTs+5htn0031YfyUu1V3Lj7qeYsHUqN216kvt+mcyn35zEiWs/jWhtpke01maJWD0jCIJgSoKLDEUBFvYsLPVvE2ppBimjAIOgAswIEwWYGZEUYDLJJAiCEB+RJmjW7OpJ3ar27O9aT403k1mHHkD7K8uxOfQnhkwFWnHlzvhw7wbnt768GXuGDU+9h9q1teyeV87WN9az8ek/2PHeJmYt7ATA8LTf2PTCn1QscAKQUZyBLU3d0Ilmtakfp+Z1M1qbJQJxHBCEvYcWJWhTMHMbAIMFj2aCNi2x2kPrrV2QvBmhaZCVM0KL5jc6s2pdO4atrQNg5RA44dHV/N5jHwazwvIyTF0VmObHeMexz+Bos4va9QPY/No9DM1ewj87PQnA/w7K5MfjfscOVK/sxbb/Hk/tn4UER5bwQWP29v7s32MDwweu5rux1aTluKhe1YnNU0cCNnyDiFoFptxH75+tl4OzlzCuzWzGtPmSIkd54IzHa6M0rYTf7IMoyy0kP72K/apX0L/2Dx795hY25PdgRecI3tJ+1tI3bOnlavYJFHPq1TPL2w9M6FLM6TabKAIEYS/AqMhoCt9kUzaZFD9xW5olQgGmoPbsV16DZQWYuvYoIGoFmNny/GiQ570gCEJiWZQ1mP02bCHb1sBmWyEN4xooPiR8ZszQcaCb13wsVFA393Un2ZRVNL5VLpULoGHnHjzVLhp21OGzGwgd++au6wnAsC5baEMNrp2+hk791nq8DQ3+eyp1krpesuA4oD7tRN+aZju+ZtRWTDN39BwHol09E4LWcUBWzwiCEC/NldEZtaWZUb6yVtAWyW0gGlpG3ozQ+pDmjNAiGO/1hk3YfTZgENlLGxlQ7suXWTfazan3hU8K9WdV2CRcdzYGBxFtodGRMJuatNwy2g6ZCcCWf99Nv7Y/81zRtaR73PwwwMa7Y93Ube7C9vfHU7V8AFAH1OB7qCuDgtJo8fH5b724sqeNzwZuJSMzjfqtbdkwZSxed5rO34CL4AARvEeevZqT87/kvILP6JO5PnB8t7uAL+vG8KVtLMuyDqQm318lFQJF4PC4eGrDDYzeMZcT1n3GisHmzRmr1maWkUJDEIREYlJkROObnDBLMy1hk1ORLM20WFGAaX/SaZs06uMgeTOCIAh7B267nUWegxieuxqAn+0l9Oy/lvy+1Sy38vzHt0IyZCWpnh204jjg1L5bGW9qNceqUJopdX+CmRhhY2Vb/ihvR7/C3YzosYEFvXsBULeuxn+FImpD9VoZ74xyZyJMsiXI2kyNWtRmhKm1mUVE1CYIrZ+mErSpSaqlmSWMGjXafSukRzif3LwZM8RtYO9DbM2EFocXeGPoQXRc30incnDmwp4LKjj5mPDGjBpLiiUDJVT74R9gd7horG1D37Ev8lKvS2jjaWBVV3jy8GI2/fsq1tz1MFXL98e36gXCA5hD9+f+1IGbO3RgTWYaVGdQ+sQYGquVxoyb8Im64ITdPpl/MLnjC8zrM4E7O75An8z11Hiy+ajiZCZsfIWRa7/hnl338V3jkdSQG1oklYHbns6nXccDcOjWhXFbm1lRoEc78Oshy/QFYe8kmiIjWpJiaeY0+jRFAaalhSrAYsibEQRBEJqOD8vzWXnqIIZ38DVmFtT3Y79BK2ib6ZsJUyaA9FY46q2EBMJXThYRsYHvI5vQMUdvFaeRdtTBnFLfB4/utY6skhwAateZTdRFMWY6CZ0cTJK1mRrL1mYJVrkLgiDo0eyWZk6Me+kBS38FbZ0UjaWlQ7PfPHkziXIbEFoH0pwRWhR1GZnsOmswQ5fWkFMP6ztD90M3MHybb6YsWpWR1dyZzPZ/ApDtqOSfq+ZQXAFbcrO4wXkvv9w+ld1fHwueNIIPaodmq8VL4VmL+S43i0yPhxHLe+La2UbnOrXFjZfhOb8ztdsz/K/k/zinYB459nrW1Hfn/u3XMHLtB9y57R6+rzkCt3rwMFAh2Co9AOxJNwsWCEcKDUEQkkVz+iabEoulmaVQS0i8pZkZic+bsYoyAWi0PF8UYIIgCImlZt8eDL+mHftl+fJlllV349A9S7DbjVdURL3iMercGTAeq8ybNHPWlQBwTL/1pOU48NR7qN9Sp3MfbWan4WyfjwinAyirZfUEGir0LH6aI69TRG2CsPcRUdCWpIxOQxJiaaZ9SKsbMZEEbSmSN6Mi0W4DskqydSDNGaHFsLt7MRmje3Pozz6P4d8GwOitq+iWVhN2bcJUYH62fnMHW7/4G1e83pV+W8HpbcPlKz5iw9Lx4MkjWhVY4fG/UHjMSvDCQzvLOSs//Fd+VskOOl84jy7nzea4ghW83/M5pnZ/meG5q3F77cysPIS/bJjMyaVP8Y5zPNUendk0p+a1auKwc8M2ALalWZtpS0ShYYoUGoIg6JDMIkNN01maKWgVYMoxveuaSAEWQ96MWuQgCjBBEITmZ7zXS93xgznlzGq6pO1mk7uInfPSObDut5juFzKRZLBiMgzF5sVwcaZyQjteGe/PK+0BwJCirbSjkrrSavBAJMcBH4qVqDLRp5dj4Mep2jdbNVPq31pwHFCjFohYFrUJgiCoSLSgzQirgramszTTug0kErXbgFIviduA0DRIc0ZoEXzxwYv06Nqe/uu8uNJgcd9GzvhtJRkeawVCvCqwxroCTl7kYeTODbi86Vy38XnWu3pp3qRXSGitZBzkD/uTzucu9r38cj+OralleO4fZGf7BpqM4gp63PA5A+98n7MLl/HGusU82XEag7K2UuNJ543dIzj+z8n8beulLK7tRdBGTYO23nCq9stg8J6fAFif6wvYjMXaTI2VQiMR1maCILROmqrIUGPJ0kxbZMRtaaZVgKknlaAlKsD0EAWYIAhC0+NyNfDD81dy1qGlZNlcLKoqoc2n2+m5dUtU94nZDroDBt772YSPN1YdB3zsqM7jlx2+YJtD7SupXaed0dOuQNWOrxZoYmszy2iFKCJqEwRBg9FzIEBzZXQqxGxpBvqWZsl0G4DQMSn+vBmrRHIbEFon0pwRUhqv18vb91xM0X1P03E37G4DP+dVUTAjfv+TaFRgQ90Lucn+BAAP1N3OEob6z2i752bd9HSy+26l25VzASj/aiC/vH0wWxvbkWVzMeqQzXQ8czH73/1fLqwv5dnnG7nqcw+dd4PTm8FzZSM59s9beHDHeLa426vua0EFpqldcjzVHOX8BoA5HaOXnUuhIQhCShBDkdEkvskxhVrGUlhEomnyZvSW5wuCIAhNS/mOzax+9BiG73gPgBnr+rHlyRV885N/UFImdvwTPcrEj1nwsHpFZIgddInOxU2QOzN3g++DD7f/qtOcUaO1NotALdasSCNYmynCjqRbmwmCsNdieS7EJKMzKYI2LQmzNNOK2JSttvmuraX0nv3RuA2gOQfx5M0kym1ArKBbJ9KcEVKW6ppq3rnkcA58dyHZDVDazcbmLZtou3BzzPeMRQXWwb2DJxr+hoNGpnlO5j3XuZo35KM/2RWqAssodtLz+pnY0xupXNaDrW8dAtiZXd6fT3NzaDhlHtfnLefFl1xcONdDu2ooawPP9+3IsX9cy3PlI3A2KtZlUQZcqtkJo2rmkuWtZ11WCSvbDLBszyOFhiAIycCoyIhoaZZgIvomK1h5ZjqJM9QyWsWvlbwZEwUYJCRvxmycFQWYIAhCcvhj+Xxczx/FoIafqfZm8f433fD8Zwn2KFYixmVDGfUKy8g2ZnpNmvlVPoHFcH6hsVJv5aneFvQn+ZTjOjhV+1FYm+mhJwCJydosRlGbIAiCGQnP6CyN4j2WBW2gb2lm5jZgFXEbEJofac4IKcmfK5cx54xDOHCBE4BfDm7D6E9+IGvLnvCLk6gCc7R38cSumylqLGOVvT/3eu/B0EbMRAVmy/DS44ZPceTXUbuuiI0vHANeO1k9y3ivsJK1S9vx2CsNnPaDl9x62JDn4LkT7Vx1Wnuen30WNd4M//3iDLh0+janVv0PgFntjwP1pGgTWJuFEWHhjhQagiBYQd34NXrmqJ9RZo3lprM0A3MFGDSZAsxq3oyKSAowBbPxGEQBJgiCEA+LPn6e7v87jU6UsdHWhbLzPuOsOb/GfD+9iSJTxwHtykoFy7kzCpEbNgvTBuHx2ujn2ErnvCrVNXqWZuqtGrXjgA4JsjazLPgwIRGiNnEcEITWQ8yCtubK6FSIZGlmirp+SoSlmR5meTPaFZ8RsOg2IHkzghppzggpx5z3n2PTJefTd52HBgesPPsAznprIZk5bUOuszqZE7MKrCPcvPJxDt6zhCpbHjdmP01dxIeyXiHhoOsls8jqWo5rdy7rnzoRe6abo878imcO+JiHPqhg7DIvGY3wu6eQuwZ045aJXub0y+TPF8biqVP+maon8NQDkMWAS3+hMbDxV46o/R43aXzY4czm91DWYlEZL4WGILQOkl1kNIlvsprWrgArsf4t4lJhC4IgCKa4Gur58bkJDFt+O1k2F8uzD6PtDfPpOeDghNw/FXNn6joXscLry/08tvefEb6cUe1kgtPCNRatzfRQC0OMVOkiahMEIZmkpKVZrXYnkqWZlmiaNNHkzehcm8S8GQVxG9j7kOaMkFK8d8+FtLvvWTo4YVc+VN1zNafd907gfDzL9qJVgR237nMuKX0dgDuKH2S9vST0IRyY3zLLnXHQ7ujFFBz+K95GG5tfPpajD1rAf4a/zosr1nHESi92L9R2tNNzdBmvn5TPqtO24fHa2fjiMdRvUrJl9NRfsVmbTWh4BYCZeePYnNkteCLJ1mbqQkOszQRBSCSJLDIiovesjCnUsjkVYNr9CESZNyPL8wVBEJJL+fZNrH7sGA7b+T4AP3a7nMGTPiO/IMKSxxaRO6Met1Sk2cjsnM23nv0AGNtHr3mkl0GgoLd6VXNajVO1nwRrMzWRhCSGiKhNEPYakiFoS1hGZ0ItzfRWNRpZQGtRnvlW3QbUDgN6bgNqkp83I24Dey/SnBFSgtqaat79y2EMfncJWS5Y191O9/+8yZFn3pDwz7KiAmuXv4t7vp8MwCudL2d27rG+h64h+rkz9sx6Op31FTavl4OWZvJG4cdM2fILB6730GiDmd6unFZ6OrO6D+LHkjTWDfY9tLe9M4w9P/Uw+KzYrc16UspY1xe+P1fBBP1CIwprs2g8lE2J0UNZCg1BaJ1EVH4mqMjQI/mWZs2tAIPmyJsRBEEQEsPqpd/gfuEoBjWsYI83m2WHP8dhE57AnpZm+J6kOw5AgnJntMeCjZvMzlnYHHbmVQ8CfCtnbHjRn4wzW5Fq3XEASLi1meR1CoLQXCTieRK1ZWNclmYQ+oxOlIhNjbYpo3Ub0Fv1qUMT580IrQtpzgjNTunKpcw+7RCGLKwAYMWwfI6Z9iM9olmSn2AV2HXfTKFtQyW/5w/k6SE3RvGgDVWBFR+yhONX1fDkvxq5/as97LOjkfo0+CCzO8evvYC/rR7HqvpCpm0r4rYOhXhtsHtuf8q/3BdoUN3LLOASrFqbXd34Amk2D3PTjuaPjP7BEzFam8WDFBqCIFhusFpUhkZLalmaNZUCzGLeTASizZuR5fmCIAjxs/B/z9Bz2ul0pJwN9q6UnzeTA8demPDPaZ7cGXNLs8xuvmu/X9+Z6oZ0OuVVs39HveWsiXMcMMTI2qw58jpF1CYIey3JFrRZyujUErOlmfagS3PSaPVMtPbPRhjlzURBlG4D8SBuA60Lac4Izcq8959h48UX0Ge9h3oH/HbewZz9xgIyc9pYen8yVGD9N63irOU+i4AH9r2dRrt5oRBOOu3TKplYOJ2PdnzKFbM8dNkNezJsvJ7fi2PXXsHdP5/GZldbIB1HQQ3OE5dQa7czvLaW/qu7U3zaMgY+/zZ97vkIbEqTxkrApTH7Vv/KKRmfAPB85rXhqgWIaG2mDMCRVGBSaAiCECvR+qWnpKWZKWaWZtr9eDBTgEFEBVgRTZI3I8vzBUEQIuNqqGfBs5dxyE93kmlzsTxnOAXXz6fngAMN35OIiZumzZ1R0K+9MrtkAbBnayPzSn0OA2Ms585AtI4DQPTWZibEk9cpojZB2LtJaUFbwizNTFYzhojZ9ARtkWoptX2mnnhNK3hD5xoSljcjbgOCFmnOCM3Goi/eoeDe5ymqgLK2UHXPNZwx+c2kfmZEFVhXD7e99yBpXg+flxzPkoHDwm9SgGHuTJ+M9dzT8QVm9/4b1xR9RgF1bHK34+G6kRz9yyQeXnQ25fXBgsSW4abH9bNJb19DYV0aR9bUseeSeRSf8hNpOQ1kl5SRlttAOFEGXNZ6ubXdwwBM94zn16r9gueisDbTI55CQxCEvZeYi4yWYGkWEmoZjaWZmc1ZNGgLEAhVgEnejCAIQkvip2fP59CyDwH4oceVDP7bDPILCuO7qcZxwCohjgN6JDx3xrd17fKNjW0OaM9XG3xCi2DujNlYGpvjQHNbm0UlatMgojZBaP2ktKBNISGWZlq3ASOirZ/07Mz0BG16r/0Y5c2oELcBwQrSnBGajYNGn8XafTJZ29NOz/+8xYgzr7P0vmSqwI75YTaHrVpAnSOTx8ZNCr9Ax+ol3dbAuDYzeL37zUzvdTVnF3xJpt3Fz7W9uWnzZZyw9g5eX38iNV7FzN9faNi8dLtiLjm9fSOW0w6PFrajNrMxcG/nD31p3JPlf2Ul4FJ9LFhojMqdyyFZi6jzZvJU9Y26f/bmKDTCfiBIoSEIQhTEoyRNDUszBSMFmJmlmYJBcHIsCjAjRAEmCILQrLQdfSPltGX5ES8w/LJHTfNlImIw0aNMDOmtfDScUCrRvE5K7oyPiu8rcO1uwNE2nQWdDgVgRI8NZKbpKab1tkm2Niv1bxNkbRYVSVLLC4LQ/LRIQVtclmZgbGmmbtQY10e92zk5uPM2w/NBlAaNUX5njHkzJRY+2o+4DQggzRmhGUlzOBjzyhcc+7+F9NjnoMTdOA4V2DE/zAbg3VHnsaWga+gFmgduD9t6/lbwKHN7H81jXSYxLGcpbq+dL6sO58IN93LuhruYtWcYjaiLp+BDv/DY32g7bF3gdWNGI1keT+B1/da2bHn9UP8rqwGXEGqZA+k0MKn4UQD+U3kxW71dgpfGYW0WCSk0BEEwwqjICGvAJrjIMFOARXy2JcTSDBIXaqltyOgt049SAVZA1HkzVhAFmCAIQnz0O2AEObf8ygFjzo/5HtFO7CQvd8ZsPDK2k/Y2eCn7dCtej5ctfQewpTaf7HQ3R/bYYPgeH01kbWaBprQ2E1GbIAgKzSJoMyKipZn2IWzmLBBeSznsjay9/lUWX/k27bLU9zJqwATe6d82Td6MuA0IaqQ5IzQr+e2Kycoxk+tGQQJUYJ3KfN3137trlpH7GzM9XaVcVv8Kb9vP5fM2x3N5/qu0d+xmq6sTz5RdzbF/vscNW+5iae3+gPLDV18F1m7Ur4FX3kYb5V8N5IAqXyMnzWVn/ZTReOoyNN/YKODSeHLvkvb/pldGKWXuQl6uvCJ4wqm6qBmtzaTQEAShqYnKNzkuSzP1vtVQy3hJgAIsirwZPTW1Mt4q468RogATBEGwTnautUxONamZOwO+cUg98WWUOxParGnYUkfFj+WAjW/TDwBah7VZMvI6BUFo+SRb0GZEQjI647Y0g3BLM6N6Keg28NehPwWO1ri0DgNqIVu6zjHJmxGaB2nOCC2eRKrAOuzyjRqD1v9CV+cmDt62mJP2TOeGVU8xbcV4Zm48gUkNj3OA7Sc8XhvfuEZw7c7nGPvnl7xQfgM73GrJsdpLObzAqPixP67yXJw/9GXNnadSW1rIj219BZTH4aHP3Z9SdILZEke9IiO0qOiavpm/Fj4PwKM7b6Xam+c7YWS/I4WGIAhJxnKRoUXzbGh232Q1hpZmVkMtteesNmmaWAGmItEKMEEQBKHlkJjcGYicO6M9BxXflrPnZyffeny1xfH7r9fc08zaLAqcFq6JwdpMRG2CIKQKkSzN9J5Xic/o1EM9txWdpdmUE+YF9usbjVZj6h1Pbt6MGZI3I0hzRmiRJEsF9ntvX9PgL7Pf4Mvnx/DGzIt45JtbuWrtv+hXuwYXDr7PHs59mXcx2jOXq2tfYm7taBoDD3etCgz0u+8Odn4ylFV/u4BN/xoNaR66Tfg2cNZrg7RsF3n7bSY0c0BPBaanvq4Fariz+P/IstfzY/WhTK8c7zvt1FxuZm2mHWCl0BAEoamIojGbMN9kIxJqaaYXammuAAvHigJMe22MeTNqtBY2fhKhAJPl+YIgCE2IMuGjsYM2W/GYCrkzkA5eKJ+5nc8W+CTKg/M30+coYzs0H5FU1xqSaG2mJtq8zjBE1CYIrZbmErQlFUuWZmarINE5p8/89V1Nz+u7DWhXdCYub8bMbSAS4jbQ+pHmjCCoePyFidz2twfZVNgVV1o6Gwq6s6DTIfyv62n8vfdDjDjoWyZ0fpV3M85nh5mk2FSdHK4CKzh0ne6V5V/uq3PUyNoslGPzvmdk3o80eNK5b/tkwBZ7oVFqck4HKTQEQdAj5iJDQ9J8k5OmAIsv1NIaehk0Ccib0SzPjztvRhAEQWg2rE7wmE0YNV3ujHnDZe2cWn6u9E2+HTdYb2zSW5mqoDcJmBxrM0UAEmm1bkLyOgVB2PtIsKCtZVmauUP29+0Q/OAbZx2tOmfVbUAhdfNmhNaJNGeE1kWcKjCv3c4nx5zC2Ae+5IDnlnP8NV9w6Qmv848hDzB9wClUOtpG8WXSCT7UzVVgFUv6UL2qI7u/7Utmo29yMN1tp2ZtcYTP0A+4zLXXcEfxCwC8susSSl2drH/tBFibqZFCQxCEqEkl32SFqCzNEhdqGR0JUIAlO29GlucLgiC0GKKaQEpo7oy+pZmWL1b1BODo/N/9cZ96q04jrVTVG6v9OA0/OkgEazM9Euo4EEHUJo4DgtDySFVBW/NbmukRfMY/MfbrwP7SrR3RdxvQZsy0rLwZcRtonUhzRmgVRKsCM8ud8b/Aa1f989BTgRVgQQWmPmasAqsr7cC6B8dR80cx9Wm+h63L4aHkls/9V0QXcHlLh9folF7O+oYuvLTrTJ3vhLEKTKG5PZSl0BCEVkVTFBkJtTRTnn0JszSDWEItjWliBVgEolWAyfJ8QRCEpiFZdtDJzZ1RMGvSOPj89xIARjh+wVGg18QxchxoHmuzpOR1CoKw95KKgraosGJphuaY8fP7uL7aDDKr6LkN6EUW+Elg3oyC7evI1witF2nOCEIkjFRghljJnUlHr0njaBtaBaSnezRXmAVc+gavw3J+4+yC2QDcte0GGryZobeoJbS4UO8nqdBQI4WGIAimRFFkREtUlmZ6NJkCzAwreTNRKsCsEGF5fjyIAkwQBKEZSZncGWuOA+r9bzd0p9aTTifbbgb32h3hM8IFbeGDd4zWZnpCDhURs+2iRPI6BaH10uIEbQqRLM2qiUHQBr5nt7YpE96kSbc3BvZv++rICPdXuw0o9ZLRSk4TkuE2oFhBi9vAXoU0Z4QWS8tQgVkJuAxOtO2cMZj1Tx4b/H5b21n4nGChkWOr5r6OUwF4a/cJLK4dbPXLhhKlh3IkrA74UmgIQuskUUWGlqT7JivEZWkWKdRSKS7MQy3N0VMLJz5vRo9ELM8XBEEQmo7AysUIEz9Wg4oTlzujPmYtd6a+0cGPe3zCizF9//QfVa8+1XMc0IojlFy4OKzNFJJobSZ5nYIghJAqgraEWpppL4hUHwWf99cesjxw9KkfDyJUpGY+loSeT628GXEb2DuQ5ozQekkJFRhYf7inQ6OdxpqMwJFNB5SS2dmpc62+tdlNHT6hW0Y5m12FPLnzXIKFRpQqMIUm8FCWQkMQhAARiowm901uckszt8G+FawowOLLm1GLFwzHTYLjbWB5vijABEEQWiwR7aDVxJ07A1YdBxTm7PQJMo4tNprwMrI2s0AKWZtpEVGbILQ+Wq2gTRc9QZv6oat9dptbmj15XNAXrL7RSLhmlDejvkZDgvJmEoG4DbRepDkjtD6aUgVWgIEKTDsBZqQC08mf2dCe9PWFALjtkNNHmQnUZhCEbg/O/oML2n0DwN3bLqPGm6Xzhf00g7WZ1UIjWqTQEITUJln/FmP1TY6ZZgi1DCcaBZia5ObNKFgdX0UBJgiC0LQky3EgIklzHAhOrn1ZNgCAIwrWkOWwqrIG31hsNIA3jbVZPHmdYYioTRD2HppL0GZEJEGbrtuAEdpVj9rnunmTZl5ptwj31zZlmiZvRtwGBDOkOSO0GqKd7EmICsyUbPQDLvWKjeAkm6c+nZ//eQIZHl8Rta8mMkaPHNseHuj0HgDvO4fzQ82gyG+KRBzWZvEUGmE/JKTQEIRWSXP5Jje/pRmaY7Fmz4C5Aqzp8maiXZ4vCIIgtFzUKydDJqBKNBcm1XEgyO/OYrZ425Od5uLoklL/UTNLUb3sGbW1WWX4hzij+EJROA6osSJqizevU0RtgpC6tAhBWzyWZk7isDTTq5eC4uX9ioOdoJtmHa26RqcOCjnXPHkzCuI2IChIc0Zo0aR+7gxYU4GF7nsdXhrsvoGv3ZifDe4bHKz+XjyL7hm72OxqxyM7TyN8aWjTWpupMRropdAQhNZPVP8Go/BNjvaZkTqWZhCLAiwycSjACmjWvBlZni8IgtCMKBNABnbQlh0HtESVOxO74wCA1+VlXuMQAI7vq1Xr6TsORJX11kTWZrEgojZBaP20eEGbLom1NHtibNDSbPm2YkLHC3WdpBWxaffFbUBoHqQ5I+wV2L6OfE3T5c5EDrj01GbQ82ffYFjaqRLSPP4zWmszODr3d84qWIrHa+P2rWdT7YliQHEa7Cep0Ij0A0BBCg1BaN3E65usJqm+yQoJsTSDaBRgxiRJAZbsvBlBEAQhZUi444DaDjrq3BmwljujPefD67LxtecAAE7oa0U00AzWZn4hSCIcB0TUJgitjxYlaFOwKmiLytIMjFc5Gh2DMX02WLivdv4tCXkz4jYgxIg0Z4TWiUYFphC1CiwpuTNgVmAoLP9kv8B+cbsa3WvapVVyX6eZALy++wgW1/bBsgI74pJSmtTaTAoNQWhdJLLIiMc3WY+IvskKCbM0g2gVYOG0cAWYf1wWBZggCELLJTVyZ0L33VUuvvMMwuVNo3/hLnq326Vz72a2NtMhUXmdEX8jiahNEFosKSloKzX5kLgtzdR1krLV1kvukP3MtODrW74cEeEzlBpKO74kLm9GD3EbECIhzRmhVdFycmcUjJs0Wze3x+HPnRl27hKde3u5p+OXFDlqWF3fgafLtL/MU9PaTAoNQdi7ibfIUDdyjZ4h6meOaaO4ySzNIinAYrU0S5ICTI0owARBEFoFiZzYabrcmciOA3WlNTir0lni7Q/A8X2V76ZdjZpka7M4RG3RYipqM1HRK4ioTRBSh1YtaNNF+0CNZGmmR/DZfcNhSwP7UxYcqLrGzG0Awt0GtHN3JojbgJBgpDkjCCRbBQaxqMA8tRnsX1oIgGugekLMNxCdkr+SMW3W4PLauW3riTR404lpks9psN/MHsphSKEhCC2G5igyEu6brBCzpZn2AjMFmPp8tI2aBCjAtEjejCAIwt5HU+TOFBCD44AanSaNF6p/cQZyZ04Iy53RI5LNqELTiNrUQhIreZ1RI6I2QWhxJFLQZkTSBW2GlmZqtwE1eoI28zmuh4/9NrDf0OhA321A7Tqg3YfIjRwdxG1ASCDSnBFaPCmnAjNstkenAgMoXegbIN0ZddjtDYHjXRxV/KN4HgDPlh3OyvqOhAdOJ9jaLE4PZSuFRtTWZlJoCEKLoimKDDWmvslGKM86dVMmYpERbaildt8q0SrAtPsGJChvRhAEQWg5BCZ+ZpleFiAhuTOmGDkORM6d2fNzFV97fM2Z0b1KQ2xujFFP/hk5DmhwGtwq1fM6dRBRmyA0P00laDOyNNMj4YI2J9bmnSJampnPcc38o8TKh6jQE7SpXovbgNCESHNGaP00tQosgKICM5okUxca+k2aBT92A2Cnw87wEZsAsOHl/zrPJi/NxdLazkzdNUz1Dj3/5ARZmynE6KGsRgoNQWi9NHeRYdQIDsNIAWYFZxTXGirAtMfMsJI3o6cAa9q8GcPl+RYn/gRBEITUJxUdB9xOF0vXtWebtx056W5G9FzvP69elao39kYhlDCzNlNoQXmdgiCkLi1C0KaQEEsz5Xg0lmZuDu26JXD0pllHq67RE7Rp3QYgvFbSruJUoZc3I24DQoKQ5ozQ6kg9FRgY585oi4tQhViDsy19KtPx2mz0GPULAH9pt4JDc7ZQ43Fw+9YxeGj0vyeSZY4JToP9KFVgUmgIgmBEIgMto6X5LM2iD7W0jl5T30gBpjpcQLMpwGR5viAIQvOSrAme5s6dAahcWM7XfmuzcQOMBhy9ib9ItVMLzevUitqicBwQUZsgJJ8WL2hLAUuzZ06YG9hfVd4eY0GbFu140rR5MwoBQZuw1yPNGUHw0zQqsGzVvkV+6wlAbccy+mbs5MaixQA8vOMINrry/BcZFRoWUI+JRgOo1tpMCg1BEHSIq8jQoP23r27cGj0zIvkmBxrIRs+sJrM002LleR2PAswkb6aZFGCCIAhCCmLgOKCHJZvLhOXOGFuaKdStq2b27gEAjBu4LvJ3w0VLsDYzI15RW3MKZwRB0KdVCdqcRGlppqCtjfSbNMO6+oq3Oneayb1TL28mbGwVt4G9HmnOCK2ClMqdKSBC0105qVWBGVibfe2bYNzq8PBEz7lk2huZt6c771fsa3B/7VJ9PVWCgbWZGjNrMwtIoSEIgoKlf5MWbAv1iMs3udS/bXJLs1hXyDSvAiyRyPJ8QRCEFCLChJAyodR8uTMQKkYIdxwAB59+2xm3184+eTvo2dbpPx7JVrSJrc1iyOsUUZsgtC6aW9CmJiZBm4LR885Q0KZ3TH1cXSfpPZuDlmZtMuoDRyd+FqmQ1NZQqZ03I24Dex/SnBH2DqJQgVnCTAUWgl7uDEQTcLnt904UuuDk76CvvYLd7kzu3nYUoB4otYWGngpb6+epwWmwX6bZRiCaQkONFBqC0LJJZJGhpVl8k5NuaRaDBWWA5lGAmamkJW9GEASh5RHrBFCqOg5s/aWBpY2+3wAnDvlTc1Y74RfnmByrtZkO8Yja4kVEbYKQOjSVoE1vTiYqQVskSzMwELRp3QYU9J7B5s/lfxy1ILD/2vJBqjNmY4a6sa/Nhk7NvBlh70CaM0LrJmVUYOB74EcfcAk2Dvgpj1N+9H2HydsPp6wxh9CAS6NtnNZmcXooRyo0Iine40EKDUFIDWIpMlqOb3IiLc2ibdSkhgJMGUcjIQowQRCE1k3MuTMJdBzAA7OdPneBcQOsTIppayUjazONKsNpcLsUyOsUUZsgpDatRtCmEJegDcIb5XqWZuH8/YjFgX2P146x24BWxKbdt+g0AAnNm7GKuA3sHUhzRmiVpK4KDKINuMy1N3Da3FrsXvh+P5jv6mbhMxJkbaYQo4eyGik0BKH1kkpFhpqE+CarcUZzk0RYmimIAkwQBEFIDFYneqwEFSckdyZA/I4DAJ+X+mqOkcVrSLc36ny4nuNAkqzN9IQfkNS8zlgQUZsgND8pK2jTopfRqSYuSzMF7TPZrdkPjmNvrzAT/GrHCG1N1Xx5M+I2IKiR5owgRCD+3Bm9TrxRdz68wLi9+AeK3fXsbAv/GpvGyOONOk9WFAYxWpspxOGhrEYKDUHYO0hmkaFG/UxJiKWZGqfmdVQKMCOrSasTQamjAEskogATBEFIQZSJoTmhh5WJJLOVkqnjOOBg8ar27PS2JS+tgSN7rPMfj+Q4oFyTYGszhSRbm0UtZBFRmyA0Cy1S0Fbq35pldMZlaQbhz2QjmzMfx/UpDezf+uVRJl8MgjWUdhxR10sqYYDkzQjNgDRnhFZDtCows9yZuFVgYUSjAguqoI/JW8fpbVfj8cL/jsqmNtNG16Hqp7WeHY5WcdCE1mY6SKEhCK2b5vq3FMkWUffZkxRLM2Vf2/w2mwBSiMXSrHkUYDHlzQiCIAgpT7QTQdFONIWQZMcBgLotdXztGQLAuAHa3Bk9UtvaLBZRW0THAR1E1CYIzUdTCdrUpI6lWaTc5HCeHRcsNjZXtSF0bs1sjDByGzBB3AaEJkCaM0Lrx0AFppAQFVgRCVSB+W+ZVsO9Hb8FYOquIfyy2/fDvKx4N8YNF6O8GcXaTEvyrM1aQqEhDRpBSB66//a0DdQIRYa6QdssRYYap9Eb9J6hRsVFoizNrCjAVK+bK29GlucLgiC0WvQmnprbcQDAU9PI7Epf7szx/Y2aM81gbZaAvE4zRNQmCKlNKgjaEpbRqa6XLAna9I7pWZrpPYu1lmbQt30FANv25Oh9IMEmjdptQM95QNwGhNRAmjNCqyVWFVjUHe6oFWAQWQXm5f5Oc2jvqGNlXXueLTuYBZ/si93rZV2mnf0OMltTqkZPCWZibVaLdWszrdK81L+16KFslWQWGoIgxEeLLDKMSLilGUQTahmOkWWZ0XWpmzcjy/MFQRBSi0RO+DSv44D6XPD1l+v70Oi1sV+77XTL3+U/btXaLEqsWptFiZW8ThG1CULLJ9GCtmiJpymsi1PvYDSWZnoE39Mhpyawf+1nFh5sIRhZY/pRBG1GiNuAkCSkOSMIFjBVgemhqJNDVGDaiTJ10RFqaXZ2218YmbeBek8at24dhYs0qnbm0rPGV3wMOUZtyaYNm9Yei8LaTE20HsomqAd8KTQEofWT0kVGqX+bMEsz0G96RxNqaYXWoQATBEEQWgD+CSOtHXTLyJ2B7RvsLPf6ao7j+1oRFug5DkRpbabet+o4EGdepxYRtQlCatKUcw7quRQjQZup24BVQVvSLM3MVzL+c/R3gf2PflcXnVbcBhQM8mYUlPk8vbFKPaaJ24CQIKQ5I7QqUkYFZkg2+qtmQBkweqbv4tbi+QA8UXYYaxo6Bq51rO8EgLvHNown+4yUB1FYmzkNvn4CrM3UWC00okYKDUFIOi2qyLCKUZHhxAC9VYjqgsJaqGV8NK8CTCGgAIuALM8XBEFIYSxOEKV67kz95hq+bvTnzvQ3as7oWZtFgZXJx0jWZjokUsUuojZBSG32DkGb3jErlmb6TZorD1Y3OmyEuw0Y5W/GkDcDMTnliNuAEC3SnBH2LpKpArOcOwNGk2lpeHi486fk2N38UN2dN3cPDnnXL18NAGB1XiOFHXRHPw16ljoWrM3UOHVu28weylJoCELqYqnIiEDCiwytb7KCnm+yEU1qaabQRAowNREUYArKuKmMowGU5fmiABMEQWgxxDohlJDcmTgdB3yENmtcO+qY5RwEwJg+68hIU4slwFzk1jzWZsnI67SEiNoEIam0ekGbGqfewURZmrmxEaxTnl04xOQ9kdwG0hG3ASGVkOaMsHeQbBVYzLkzyoDgGySuLPyewdlbqWjM5I5tx+IldCD/Y2lnihu8uGw2jjjpN9UZK9ZmMSCFhiAIOsRdZERQgKmx8iywYpcYgpkdY5NamhlhpADTNmQSpACzkDejEO04KQowQRCE1k9zOQ4Y7ztYsLyAbd525DkaGNnzT4PPMMqbSaC1WSQBiIjaBGGvpLkFbbo0WUYnxGJpduHgoCjs7rmHR7i/dhWNsm9QKyluA7pZaCQ+byYC4jawdyHNGaFV02QqMD0MVWD6AZcDMndydeE3ANy/fSzb3W0IX7Jvo8OOAgDa7rue6KzNlIGuiazNIg3sUSKFhiCkPskuMowUYHqkjqWZHkbKXTOMGjZNowCLdXm+IAiCkJpEnPhJcceBSOz5pYK5jQcAMH4/q2NYgqzN9ERtUVibqYk3r9MSImoThKSQyoK25rU0065mtGZp9uJJXwX2d9dlE1oHGaFeZamtlfTm50hI3oyC4ZgpbgOCCmnOCIIJManAOhBBBQbaASAdGw90eo90m4dZVQP4rGogRiqwzYt9E5Ib2lVjtzdG/n7NZW2mg/oHQHMVGmYNGkEQIpNqRYYaS5Zm2iIjaZZm6v0YVy+GkVoKsGgRBZggCEILoIU4DkTKn2msdPH5jn0BOKn/GgjY4agFEkaOA5GaNDo/DPQmJ5spr1NEbYLQ/Jj9O2k1gjY1Tr2DibM0A8hJ922XbCnWuVapk/REbNp97cpMA+LImxG3ASEapDkjtDqaXQVmmaDK+crCLxiQtZVd7hzu334SYDyQ//hZf7I9XsoddoaPWWdwVepbm6lpykLDDCk0BMGcZBcZZiTVN1khoZZm1hRg4agVYEaBxwlUgKmJMW8msDxfFGCCIAgtltR0HNCiVUmn65wL8sWKbtR70+mVu4uBRdsMPtwob8bM2kyD02BfQStqUyj1bxNgbSarZwShlRGFoC1aEiZoc2peJ9HSbGBReWD/2s+imfjRrqqxvgITkLwZoUmQ5oyw95CSKjAYkLmDKwt9X+6fO05jV2Oe/4z+pJy7IZ1eFZkAlAxfTfiS0CRZm6n3I3kol/q3KVpoyOoZQWgmIhQZ8fgmmxYZRug9y5wG+yEYrDpMiqVZkhVgkjcjCIIgJICmyZ0B/caMvuNA2aoGfvD4V88MTAFrMwWT7LtY8zq1RPxNJatnBCFpNLWgTT1nkjqCNqPjepZmavSfwU8fPzewv2BzF9UZvWaLUjtpxwvteJId3BQgbgNCsyHNGaHVk5oqMN8BB24e6PSK385sMJ9XDfZfpx5EtAoxB7WrfINoRSera+X1Brgorc3M0Fqb6ZBKhYYZUmgIgj6pUmSosWKPCMTnm2zZ0ix6BVjsNK0CTPJmBEEQWidWJ4C0jgNmJC93Rj2pZm5ppuDe3cAXVb7fE+P3+1N1RjtWa481g7VZAkRtZlhS3YuoTRBSgyQK2nSxmhcclaBNa2mmfmaaCYq1r337Y/ps0FxjlMepRc9twGh1JknJmxG3ASES0pwRhAgkM3fm4vYzGZC1kd3uXO7ffi4+O7PIk2wLPt0Xm9dLaZaNXgONfvFrC404Jgn1rM3iKDQSSTIKDWnQCEISSFCR0WS+yWEoViZKYZEMSzMFUYAJgiAIzYgyYTRH/7Qy4dTUjgOhY565pZnCpyt9vxuGd1hPu6wqzVm9CcImsjazkNepxoqoTes4IKI2QWh6mlPQZoT6+WHa9C31b5OW0QnWBW3B87npDYH962aOMrm31m1Az3nAgtMAxDFOiduAED3SnBFaJbGqwJoyd6Zr+jauKfwvAI/sPINdjW2wGnC5Y1MBfWp9/3wPPv43YrM20yMKazOFGAoN9Q8CI+V7cxYagiCEkopFhhpLvskKVn2TnTrnAyTL0qyJFWBqEqUAEwRBEFosqew4EERbH5k7DvyxOpOVnu44bB6O6/sn1ojD2qwJ8jpjRURtgtACiEPQZmRppkdLsjS7a+SPgf0XFg2J8AW1NZTRvgXEbUBoIqQ5I+xdRFCBKSRcBVaAptDwcmfx82TbG1hQsy/TKtXdf2sBl1kbiwGw9d5s8ctpBzo9FZjOJXokqdCI9APCCCk0BCHFSUKREbVvsontYmpbmiVRAZaMvBlZni8IgrDXkkzHAf0JtsgTbfWbapjtPgCAkwepJ86SZG2mRyTVeal/2xx5nSJqE4SE0aoEbQoxZXQm1tLs70csDhxp9NoJbcQbrZzUjg/qeknV+Fe7DRTo3EbcBoQmQJozwl5Bk6jAjAIudSxljstbyMi8RTR4HNy77Qp8dmbqQSVywOUvc3yTmWtzXbRpb1QYJMnaTKEJPZSl0BCE5qFFFxmRMHqGObUH4rU0iwYzSzO9a1JTASbL8wVBEFoGzZ47ExVmuTP6E3Rel5dZ23x103G915Bma9BcYWWiMMnWZjokK69TRG2CkMKkkqAtUkanmiRamkFwDHlt2SCD+2ob9tpxQes2YEJT5c0IggppzgiCBWLqhBtYyGTbarit+DEAXtl1NqWunqqz1ifZflnQlU4NHhrsNkaMbznWZlJoCMJeRhMVGbqU+rdWfJOdmteGRUYslmbhCrBwtJZmemjVYS1DASYIgiC0YJojdyY7bEfvJFYdB75bVcxubx7tM2oZ3n2ThS/notmszUTUJggtklQXtOkSj6CtiSzNzhkUHGP+/tUIo2+IseWlepVNCuXNiNuAoEKaM0KrJVVVYBe3/zcd03eyqaETL+06x380m+BAYU0FBnY67GwHQMF+6yJ9dT9xWpslwEM5HqTQEISmpamLjHhInm+yGdFYmsWygsbIxgxakgJMlucLgiC0QKKcOIo7d6YAA1GBIj7QrprRa8zoOw5U/VHDPI8vp2D8ACPxQYKszZwGl2mtgaLI61QjojZBSD2irpn0SLKgzbTJW+rfRitoc+qcT5ilmY+Xxn8V2N9Zk6M6Y+Y2oB0fDGolsxJK3AaEJkSaM8LeR4TcmaSpwAqgMLeMy9tPBeDJshto8GZoLtKGWJoHXG5e4htsN7arxm43UmMn0NpMDz0PUjWl/m0EFZgUGoKQWiSjyNCi/TesbsCa/dtXSJhvshqn9oDa0syomZ0oS7NIiAJMEARBSC7xThTFnDtjCSOXASMxmw/Xznq+qPT9xjhp37VE5zigvLZobabGaX46EjELT0iQql4QhOTQmgRtUVuaaY8ZCdp8x/MzfVaUy7YqA4XWbSDSuKCtlfJDjxUgbgNCsyPNGUGIQCJzZ65t+yy59hp+rt2fmVVjCVeBQTQBl9/P6E9eo4fdDjvDj4u20DBq0mizFfw4Dfa1xOihbJXmWD0jDRpBsICFIsNSA1WHpPkmKzg1ry0XGZAYSzMFUYAJgiAIrYPk5s5oHQcweB1k5m8luL129m27g5ICp4XPSYK1WYLyOo3sXeMStenVTCJqE4SINMeqmYQK2iJh9bll2LDWWppZf7YO6hD88KtnHGtypdZtQM95QG/uTYO4DQjNhDRnhL2GJlGBaVE93LvbNnBm3gcAPLrzFkA9GGrVz9aszdz1DkoqfEs7S4avtviltANhDT3T1zGuzTwmFr7Kk13+jze6T+TDnucxveREXu12KZPz72FE1jfY8DSLh3Jzr54RhL2Nplg1oyXaxmtcRYZCyliaiQJMEARBaD4iTgj5J5S0dtDJdBwIzZ3J17koOseBbb+7Wez11RQn9m9CazMn4cSZ16lGLWDR0hyrZ6RBIwgRSOKqmZgEbaX+bUIszSCypZnea33h8HPjZgf2F25Wd0mMBG0YXGM925kiYnYcELcBIVakOSMIKpKZO3Nhxps4bI3Mrz6SJbXDDO4efcBl1W89ASjvuNvkWyv4Couu6ds5r+AbnuwylW/63MbM3rfxWJcpXFP0Pse1+ZaDc35hYNYq+mT+yWG5P3JOwXv8q/gqPs49hRK7Tr5NDB7Kej8MjKzNtMjqGUFIMWJYNROvb7Ippf6tUZGhLjacBvtA9JZm6ByPldRSgCmEKcAiIAowQRCEFkwyc2ciOA6EE5/jQP3mWr6sHQzAyYP/JGnWZkarb1uCqC0Bq2cEYW8iIfmccdhAWyGpgraYLM3MnARCn8cjSzZrzmsFbVrUbgMK2nEjO7gpMPkqakqCu+I2ICQDac4IrZpoVWAKiVaB5VHF6RkfAvB69SX+o4oKTDtYqAeZyAGX30/flzSvl02ZNvY9aIvqmuDAZqOBA7M3cUuHOUwveYYvez/OXR0/5rg2yylyVFLvcbC8th/vO4/l4R1XcMPmf3DFxke4dOOL3Lb1Qd7cfSFVnjz6pa3h9Zy/0JVNcXsoq5HVM4KQOjRFkaElHmVnQnyT1cRlaRY51NIYM0szvWuiVIDFiDIOKuNiGMryfFGACYIgtBqa23HAnNgcBwBmrPb9kBnZeR256Q0WPisGazM1eqtyI1kElVq7tVVRW9SIqE0QmpVobKATJmjTEnNGp/aY2tIMQp+p5lnIBVl1gf0Jn4wxuEpZKakVsem5DRiswhS3ASEFkOaMsHfSxCqwM3I+JNdWw5rGPnxfd7jOJ2QT66Tb7u159K1OA2D/Mb+jDHA2vByQtZW/d/iW2b1f5a0eb3Fp+0X0ySzH7bWxqKYXT+88gYs23MCha57g/A3/YPL2K3h99wl8uedIvqs5hAU1h/BJ5ak8sOMfHP/nLFY27EMHexnXND4f/AJWPZRL/dsIKo3mLjTER1nYm0nW/8cTvWom4b7JupZmZl2aWCzNjIqPSAowSIgCTPJmBEEQhCYmebkzEIvjwE8rslnvKSbT7ubY3jpuAECzW5vpkMi8ThG1CUJiaFWCtoRldKpXF6rRPkMjW5r93+jvAvuvLtvP5ItpMXIbiEAHwseiRObNREDcBvZupDkj7FUkRQXWLcJDtAMc75oJwFsNFxCaNaMlNhWYbV1XABq6b2Fw1g5u7fADX/V+h7d7/o+L2/9Mp/Rqqhoz+KRiEDdvOZkj1tzKxRsn8K9do1hS25cGr1kzyDfq7m5szz277gXgxPRPaUNlqyw0BEHQp8UUGaX+bUJ8k5vD0kwUYIIgCELTk1K5M8o4FpY7E5/jQN36Wma7DgDg5APW0dKtzYxEbVYCwk0RUZsgmLJXCdriyuiEWC3Nrhn2U+CMNyyvWYtSO2kb9VoRtAWSlTcjbgOCCdKcEYQYMOyU66jA8jxV7Of5BYBv3P7RNqTQ0JJNdAGXaeyZ2ZmLZjdy4+tu3u05g0va/0Ln9Gr2NKbzSUV/rtk0jiPXTuC2bSfxedVAqjzqAUpRKmiriPBC4+eKIWxzdyTD5qJEvQZWCg1BaPGkWpFhRGr4JkNiLc0i0bwKMAVRgAmCIOyFNEfujCmxOw7g8fLZxgEAnNhrNTasjFNJtDaLM6/TKs0lapO6Sdgb2fsEbYm3NEuzeQL7T/xwkH9P6zZg9OzXE7RBSINfcRswzDhTURLcFbcBIVlIc0Zo9TSZCqxE89o/ITasdhFpeCj19mRrThedgkMZJKwHXNrwMiRrG5M6LODL3u/yROP3jF/opUMl1KfZ+bSyNxM3j+HItRdy27ZjmVddgsvrIHxw1KJWgemztdE3s9e5XmdJjJE3qRQagtCiaY4io0l9k8Noakuz1FaARcybEQRBEFodqZ07A7E6Dsz5qRNV3mw6ZlZxSNctBlfFYm2mrLZV4TTYj4YIojY1qSZqE4TWSNR2ZhZJeUGbmiRbmv1VtWpm8jy9WAAFPbcB7b7eXJuKAsLHHnEbEJoYac4Iey+JVoGpUT3MCxt9s4F/Onpb+BTjgEsHjQzP2chdxXOZ2+d13un5EZe1/5mu6Xuo8Tj4pUcWj55h5/FLM7l160jm7CmhwWtcmPgwVywECY6+afhUDG7luzlN3mbmWaqDUaGhRgoNQUg8LaXIUJNQ32Sn3gWxWprFSgtWgMnyfEEQBMGAuHNnEuo44KBiTR3zGgcDcOrgNSTO2kyD+pB69Uy0eZ0RsCpgEVGbIMRPTP9/bkmCNgXLGZ2RsGppFnrNMyfMDRzZ05ChOm8maNM7H4WgTdwGhGZEmjPCXkdT58405PsGk0zqo/gUX/GRbavn2LyV/8/eeYfLcZTp/p2Tj+IoS7bCUbYsWZaRs5wDTjjiCBiT810yCxuAheUa7rK7LOwCS3QAg42xjXPO2ZYsy7IkKx4FK4dRPHnm/jFTMzU1VdVV3dVh5ny/5/EzPd09M32Odabrre+t98MPx96F56b9N3474Q5cN+wtjG44hAN9jXhg31T83bvnYMHq6/Bv2ePx2ow6rBzRhZYB3dx7ic5u9ihzLPAqQp6hPCy1GwCwOze8/JJNo83aC48erg1VtJkICQ2CCEY1iYzqiDQzLXqrEB1gqkiz5DnAaHk+QRBEdZPIvjNl2CcOiGQ7+/DgjvwY5NJZpjeugNFmjIxkX4j9OsnURhDRUROGNt+RZrJ5JFVqiy5toJyHV08qbKlWxfDI0gYsobQBIiaoOEMQPjHqOzMS2Ng8EQBwXN9rmMLcyGlIG1w2ogfzWt7Bp4ffi9+N/xlenPbP+Onht+KSoUswtL4TO3sH4I7MUfj0pkuwYM3H8PUt5+HxA5PRlWvB609OxojeLA7V1eGUS1ZA3f/ANNqskpZUB8Y2bgUAbM2O9ZehLCERQkMy2ew1wKICDVHrRCUybPGVm8yLjYxiG0B0kWYMLweYeE48DjCCIAiiHxNl35lRMOg7A+gSB8ofK7n/rcnoydVj1uBtmDpst+KsEKLNGBH16xTx6llBpjaCUEOGtgJKqeQVaaYyCJdz9uT1xe0vPKj6ZbGVkaKJzTJtIC15a/6e1FbapLQBIkyoOEP0C2LrOwPgjUHH4IXWk9GEHtxZ/358d9C38aGmW/C+Affh2vRt+PTwX+D/jv027pr0cbw+44O4bdK38cVRd+HEgSvRXNeLjd3DcdPuBfjQhg/jjDWfx3e3vRfPHWxDT66+7HNyuTocvnMwAGD0PJObhFkzthIdmN2yFI2pXmzvHYWte8aWDmUMXi4jSUJDATnBiFomLpEhwhdcIxMZPBUiI4pIMxsHGCM+B1jF8nwPBxgtzycIgqh+Iu874+uexd8bZStoKu+xW1f04pXsEQCAy2avQuTRZowI+nXqTG1GxhkytRGEJzVraNPiFTvgL9LsN5c8VtyzZs8w04uBr7QBwLvXmSQlh9IGCNdQcYbo30TRdyaVwj+P+le83jIfLejC1XV/wT+03Ij/N/Ib+PaY7+OLo36Ky4bejyNa1qAx1Ys9vYPxyP4T8C/brsWFa/8J5637R/y/HRdiUccUZFEHnQts++L8pObmkfuAQm+YPLJoM0AdbSYKjXy02TGtiwEAizvmASgMsG0ylKtQaHhBQoOoVcIUGV4FVWvaPY47y00Gwos0s3GAlVZcFp+mQQ4wgiAIIpEY952RkYY0caAcXh95rTxtQO+ebjyy/2gAwKVHmd73AkSbZRTbIrpeeWRqI4jY6HeGNkZGsV1EZWjzH2nWls7PPa3PDC7sEXuH8TDtJCvMO0ob4LDpN0MQNlBxhuiXhN53Rvgy3zp2HD487lZ8uvV/8evsJ/BIz3l4ufNEPLr/XPwlcyV+tvMz+PymH+CcNTdhwZqb8OXNX8HtmfPR3nM48kUQswaXz987C63ZLHY21OHYs9fDf7SZnDMG5RuzvXrohMqDGckLxGgzGSQ0CCIWki4yePjvhPBzk1XoIs1ETCPNvFA5wDzwcoBJIAcYQRAEYYPzxAFxciyNvOlASSvUE3FqMxvjbytmAABOHtWOEa2HFGc5ijaTYduvMyTI1EYQerz+3RrPFUj+jhJraHPWo1MXaVbJlGGZ4vYn7ztXcoYY/SweA+SrZQRDmwzxHtSmOE9A1W+G0gYIG6g4QxABMMrpZ1/yqRSeazgN/5n7Kr7c8RN8rPv3+NKen+I7276PX+z6Ip46eAE297YBGCC8gXmDy85DTZi6txkAMP2UFYY/BWDi9h5VvxbvaX0DAPD4gXPyOzOSEyMUGrR6hiD84UtkGP5tuBIZfGFWSqIizWyLMHxxXZxAisABRv1mCIIgCAWmcdBe+EocYFiZDGR9Z8TnlcWad95qwLLsJNSncrho5mokJtrMsl+nialN1ExxmtpINxG1hPLfu485BS+qx9AGmOmkyv0/Pf+p4vZjaydp3l+WNiBuA6XVlgImaQMclDZAhA0VZwhCIIy+M2Wk4eECAyon4MwbXHaunAgA2DdOvMPKnN6qpmyVQuPcwc8CABZ3HI3tvWMSITRESGgQRHCsVoo5XDVjgrPc5ERGmvGPDMcOMEOo3wxBEARRgc+JJdvEAf99Z8RUAZ3JoQHdmzvwSMc8AMAV89Yafo5NtJkweZlRbNvgN5rIgyhNbQRRTYSZNFC7hja+0G0aaZbff9GMddwx/rp1kWay441wkjZA/WaIiKDiDNFvUE4WMaGRSBeYfYPLF/82G/W5HNa3ANPnboeraLPLhz4MAHh4//nlBzKKbRuqRGhQvBlRCyRVZMSfm1zqr1W5P4pIM9H1JXODAYEcYG2lTXKAEQRBECa4mkiyShwQSUPRd0acfDNPHACA+9bkTSLnjF+FlgaVHrKNNpOMJ1SmtpD7depWz3hBpjaiv+MszswB1WtoA0wjzQY1dRe3/+6hMwpbKn3EI37vGxRleKjfDJEAqDhDED5x5gKrEBoidg0ut28eimkH6wEAx5y/VPGhMtSu71nNqzG7ZSW6s424d997vd8qRqER5uoZL0hoEEmnZkWGCl+5yaIDjMd1pJnsmO57nhxgBEEQRHJITuKALNrMO3HgpTeGYnNuOAbWd+OsKe1wE20WAMf9OnWImolMbQThgH5jaFMhFq7NI83+5YwXi9u/eP1ozWfITGx82gDPEFSkDcjuKdRvhoiZflOc+dGPfoRUKoVUKoWXX3457sshEkDsLjCVy7noAvPf4LJu3WEAgO5Jm4Uj/qLNrkrfDwB4/MApyPQNK53qFW3GiFBoiEQtNKhAQ9QcSRMZMkxykwPBf0+6jDTzcoAx4nOAEQTRvyDNRADV2HcG8JM40LmhA491HwMAuOKYmKLNAvTrNDG18YS2eoZ6dhI1RpiGNluzaCSGNobTHp12kWZfOWlR8Uhvtp47z2slpCxtQLaykiMt2Uf9ZogY6RfFmaVLl+I73/kOBg70tN0QBAB/LrCyL+027oDsSz5teiX+Gly+/tAcAMCaAVmMmZhBkGizIXUHcMmQ/C/kL3svLux1mKHcXrmrmoUGQSSVahYZZbBCbrvh+UxkZLh9GfEkFkEiKg8+NxkIN9IsmQ4w6jdDEP0H0kyEJ3H3nXGcOIBsDvetz5tJLpq8Eimo7mEhRpsxfPTr5HHRr9O3qU0BmdqIasS3ZjI0tIkkwtAWc6RZfSpb3P7Jy8cUtsRIMwjPG6HvLebBKOgNbZQ2QERIzRdnenp6cMMNN2DevHm4/PLL474cImbCdIEZ48sFZtfgct2yUZjSAeRSKZx0afkSSzmqVTTAVeknMKCuE+90teGVQ0fAd4YyCQ2CiI1qFxnagq1fkaFckR9mpBmDHGAEQSQL0kyEjlpOHHj8jbHYn2vF2OZ9OP7wzYg02ixjfmoZAfp1kqmNIOIjFkObLRnFNgB1j04gSKTZZ497s7j97adOVrxO1VNM1EuWqMwAHJQ2QIRNzRdnfvCDH+Dtt9/G7373O9TX13u/gOifxOkCS0PR4FLErsHlwHdHAwAapm8Sjqic35UOh0bsx4fSDwIAbtp9MQDDG3zG7LQK4hYaFG9GECUMRXZtiAwdLiPNyAFGEEQyIc1E+EWVOGBMG7cdQ+LAvjWdeLp3LgDg8mNMjQqm0WaSVbkZyWmO+nWSqY0gglGThrb2wqOXoS1Qj05ZpJnaACwWv392wVPFI/u7m2UfWEDUUryhjb8HCPNqurQBkbbSpjKGE5Q2QLilposzixYtwg9+8AN85zvfwZFHHhn35RAJJBEuMCWiCwywaXD59hOzAQBrBvdgyIiD8BNtduGQ1zCmMYPtvcPw4L5ThNdbCA0VSRMaBFGDOI0zC1Fk8BiLDC985SaD27aJNPNDI+R9Z/hIM3KAEQQRLqSZCBlBEwfYxJVV3xmRCBIHcj05PLAlr5sumakTh7poM1FniZOV3G5GCP06TSFTG0FU4vvfYa0a2nyhmnMC1PqpdK+5950phS3xe5xH1bOTne+hndKovLco0gZkUNoAEQY1W5zp6urChz/8YcybNw/f+MY3rF63b9++sv+I/kmofWdGwdIFZt/gcslLh2N8dxY9qRROuextg88pdzbUoxOfGfEQAODWPRegp/h5DjOUZZDQIAhnOHV/SXApMvjCqxTZd0Moucm6XjK6pfp++83wyHrPMNw4wHR4OsAIgqg5SDMR1iQycUCcjLNLHLhvcRt6cvWYNWgrpg7bDftoM6DcSe5BRrFtQnvlLjK1EUT4JNHQxuPc0JYRd5hGmulSBSr3X35E6d7w+QfPlrxGVYwRv+ctDW2UNkAkiJotznz729/GqlWr8Pvf/95qaf6NN96IoUOHFv+bMMGuMTmRfGLvO+PpaFbdVCwbXCKFEVtHAAAGzW4X3ss72uziIQsxqWkHdvcOwp/2nAajDOWMYtsE8RKRMKFBBRqiv1NtIoORUWwDiCfSzMQBJm67c4AF6jfD7o+KiTlank8Q1QdpJsKUWBMH0jAwI4j3S/PEga3Le/FK9ggAwOXzQo42c9mvk0xtBOGEsOPMwjK08XMkRVwa2nxFmomPekPbLZc/XLr0fYNlH1hAZmLj0wZ4JIY2GZQ2QCSEmizOvPTSS/jxj3+Mf/qnf8KcOXOsXvutb30Le/fuLf63cSP9IfYbonKByUhDuJ/wLjD/DS7XvZgXGevSHWhqkd0Y5ZONDejEZ0Y8BgD43e5zcSjXovgEg4aXCRMa4uSxsdAgiCojaSJDh6fIkOEsN5mPNONFhp9IM5tCjc4Bxki2A4wgiOqGNBPhCj99Z6wSB6zQJQ6wOFF+H5Dt7MNDO/OGksvnBIk2001aasiYnVakvXKX1NAiQKY2gqgkjn93tWdoA7wNbXIGNeWPvbJpbGGPl6ENwnF+nyYSeiDkRRr+ntNW2gyz3wxBiNRccaa3txc33HAD5s6di29+85vWr29ubsaQIUPK/iNqm+S7wAC/DS5fengKRvVk0VFXh1MvX2bwOfkb4yVD38DEpl3Y2TsIf8qIM7BMaFhGm5nQXrnLtdCQYTzJTEKDqBKSKDKqMzcZsIs0s6W6HWC0PJ8gqhfSTIQJYfadMcIocUD2b0+VOCAjf+5dS/JGkROHtWP0wAMwizazuf877NfpgSpxQCS01TMBIN1EJI2qWjUjQzS0MXwb2iCcJEaaqag0sS2Y8G5x++P3vlfyGpmhrRHy6EqzGEv7gn8JV/1mKG2AEKm54syBAwewatUqLF68GE1NTUilUsX/br75ZgDASSedhFQqhXvuuSfeiyUST+h9Z6zw0eASdTh8Z16wjDpGvJHIneCN6MJnhj8NAPjt7rPQkatH4AzlhAkNIycYCQ2iSjH5txb3qhm+oKoSGVoHmNd3RuDcZCDaSDMIxwM6wHjaSpuBHGAEQdQUpJmIQCQqcQBwkTiw6q16LM5OQV0qh8uPsllZajIeiKZfZxSrZ8I2tRFElITdn1OG61UzZZgY2mRRZlbwqwN57CPNbr6sFGn29g5dNV5WpOENbZQ2QFQ33qOUKqO5uRkf//jHpceeffZZrFq1CpdccglGjRqFtra2aC+OSAwX53L6G/GTMJqcn4Y19jdMIH8zkN0UW1G4x/E3F/Gm14j8zYw96tm8cDrwvjewccR+pOr7kOsDyv/0ewvP84+XDX0T45v2YEfvINyeOdnj3ZnQaC1dKrv0g6hcFbQTpZ99JPJCQ9EXAZtS0hujitWYqnQyLMFRnvEKudMlk5BnodIReB6kgvTi6d4u8vtSKXJJELETh8gIbdVMaLnJfiLNTInAAab6XpVg3W9GAX23EUR1QZqJ8MN9q9xMrk/ARm/Xt6iX0oVHzxX6RUGF/D2X3bvZPbUHJe1Tom9/Lx7IHIN5w9fi/fNW439fnqd4f/Z69sjvB+SDDcXEYQaln4vfNqEd5UZAgY2YUFwpuwbTlJOLKzFTazxcPHxW0bChRKaZNHjpJtJMRBQ4N0/GFAPtxNCW4fZlxJNc9OiUM3X4XgDA+ozYa0angWTJAwy2mlJIG5Cl1VDaAJEgaq4409rait/85jfSYx/5yEewatUqfOtb38KJJ54Y8ZURVcEj8DVBORWrK1dyjM+VJhXHoXRjlBVm2M0iI3t3Npjnb4q8IOAfK3nu3pk47oKF2Ftfh5MuWI0X71dPkjamevHpEc8DAH6963R05uok78s/N3AoZBBYaGxfO6Ho+I5UaKgIUKAhiDCpBpEReNWMKRnFtjGqJfp6B5g5XpFmCXWA+XRLEwSRLEgzEWExb/dyLB4+C3PxFpbgKMzAO0qTxugpG0v3/DaU7vVjUL5aZBS8V+KXwQo0Pcjfa/n7OP+cbeeLNX99ewb+8VTgjNGrMLS5E3u7WlDSWcwY1yA8svcxgXeyKdiB/M+rMrVtQeUKoyoztZlABRoiTJwnDRhSvYY22x6dIpU6adbIXcXtG+45v7DFT1HL4vxVJrZGeH63piX7wug3QxA+qLlYM4LwQ6h9Z9o8XpQ2fXdxss6swWVPdyMm7xkAAJh4klisKHeEXzX0TRzWuA/begbjjr3HCefqmlpaZCirlunLXB2WfSZC6T1D8WZEldBvRIaYmyx+p6gmbzrEJ6pIMzFDPmikGcNr4oYcYARBEEQVUpiQ95qYct53ppXfGAL5xJxsVaqat95swTvZ8WhMZXHxkS6jzeLp1xlpJLQKijcjqhDrf5/92tAG+Ik0+/XFjxWPPLNetZJS9t0t7rMwtEXRb4YVqanfDGEBFWcIwoPAfWd4AvedAfw0uMwsmQIA2DY6AyCLyhtlLwbWdeGzI14AAPxy9ynozgFmET4kNHhMBnJUoCHiwLowUwsiw/M7SOYA4xEjzWwLMoCZA0wVaUYOMIIgCCJelBNJUfSdkRVmtP3WWN8Z/v4prk4VH0v07urGgwfeAwC48j264oyq+TUbK4jjCw0ZybbKaFJjpjYv3USaiQiDQP+uAvTnrBpDWxkuenTKWTBxMwCgs7fe4xq80gZEWtWHAOo3QySSflWcuemmm5DL5Wh5PmGGYW6utoKua3ApExtpKFxgJg0u1SmFz95zJFqzWexoTGHe6Ruk53x8+CKMaOjAuu7h+GtmNneEuchlQkMjNjKSba8BQD8RGgThmrjizEIVGTJMc5N5MuIOr9xkQB1pJuIn0kzl4uW/wxPqAFNADjCCqC1IMxFREG7iAOCdOCAey9+H73onP3Y597CVGNDYjdI9Xrzn84+mPelYNFABL1ObLoKI0V65q1pMbSZQgYZwSVKSBryIzNDGcNqjU6RSJx02eH9x+/q7LyhsVabAlOALNLK0ATZnNqT8ZWnIC/qUNkAkjH5VnCEInkhcYCrEm0Ea8oiaCkQXmExgVBZp9u9rxbR9zQCAGWesKOwtCY3RDQdww7A3AAD/seMU9KIedkIDUE54mgiNbZJzJLgQGuLkMQkNotqpdpHBUyYyIslNFokq0ky3mkYszEgizUSo3wxBEAQRESYTTdEnDqhMDTaJA3lefWMINmRHYUBdD86bvs7oNXl04wMPUxtPxvDjvAwrAXFuaqPUAaIKiCrOLDGGtgy3LyOeZPid5SPS7KfnP1U8cucy2S9d7C/D0BXYPRDvIZQ2QCQIKs4QRIHI+s4EjjZrhVpoqF1gnSsmAgD2HlZpLf/CiNfQWteLhYfG4YkD4qRpwAxlRkbzFjxsINFueL4E3eoZE0hoENVCoMKMiohFBl9oldJu9DYlMoptKaqVgD2ILtIM0DvAJJADjCAIgkgSrvvOWCcOAPrEAe9IM0b3lk482DkfAHDVsabRZjITB+80F3HQr1MGZ2wxMbWJmsnZ6hlKHSASiPM4M0Nq09AGBIk0e39ZTy/d/xevSDOLtAFAb2iTwAxtlDZAhA0VZwjCAKd9Z3icNLg04/m/zUFDLod3m4Hp87cW909r2o3Lh+ZX0/x4xymoXLoP7nmADGVGBEJDxM/qmajizahAQ4SJ9t9gTCLDiFBzk70wXSUTJNJMto89ipNKGsgBRhAEQURA5H1neAIlDnhFm8ljdO5ZnR/nXDRhBRrr+uAdbcYf8yKafp0u8L16RgWlDhAxEUrSQC0Z2jwxjTQTvwcrdVK6pbO4/YUHzyxs6SLNeGSGNqBi3kyVNgCo5+C4e1BgQxulDRCWUHGGIHRY9p2xanApI8QGlzu2DMH0A/lma8e8d2lhby++OupV1KdyeGT/VLzZOaq4v/QYUoayCe0+Xxc1GqFBBRoiLEJxf0UgMuLPTWbfVTKRIYs04/G7gobHtqmlBwYrZnjIAUYQBEEkhfATBwC7aLP8uc8uGoHtuTSGNHTirCntFp/Fxgkyk5vDfp3MqOLRr1M6toLd6hkytRHVTlRJAzKqxtCWEXcE6dGpjzT7v2c/Xzzyy9ePlryvGGnGnqsizTT6STu/VqDN43gAKG2AMIWKM0S/xrULjGFUaZe5wIzx1+Ayu3o8AKBzYv6OfcKAzTh90Eb05FL4yY4TNJ8XIEM5I9muRaERwAlGELb0e5EhYpSbbIrrSDNVU0sIxw0dYDK3MH8/8eg34xtygBEEQRAFIus7w2OVOCBO1un7c4p0bOzAw13vAeA32gwoTxyQfop8t0m/ThntmmMwcN1HBZnaiAgJ/O/FQuPXrqFNxLRHp5zPHrukuN2X001Jy/rOiBHQhujSBjhk9yJmaKO0ASJMqDhDEFHRxm07c4HJCjJqXrpvLlK5HNa2AlNmbsfXRr0OALg9cyTW96S5M2UOCPbcMkNZRy0JDQ0kNAiXBC7M1ILI8JWbLOmPVYYqM151LjTHRWTiwqcDDPDlAKN+MwRBEESouOg74yxxAFAnDrBjmv4zWeCe9XMAAJdMXoG6VBbm0WZeqQMh9evk0fWgKFDNpjbSTYRLwoozi9XQpsKXoU3VoxOwjTRrbSid852nTypsmUSayUxsPIKhTYau34zE0CZNxdFBaQNEAKg4QxAcukmnSPvOpKFxgancA94NLjevHYYph/J/9h875mXMbtmFA32N+MWuY6AWGrbRZhyyaLOM4VvVmNCgAg3hgtDcX9UsMhgZ0w9RRZoxVEUZ15Fm/D5GNA4wBjnACIIgCFPC6jsTT+KA7Hn5StbHF41BJjcQo5oOYMGETQafIzrJY+jX2V65S9WvMzJCiDcjCBNCSRoIQKSGNlWkGY9vQxvgN9Lsn09/uXjkR88fJ3lfcS5LZkgWtZNEP6XhnTaggPrNEHFAxRmC8CLMvjOym0NoDS7zzxvbD0djTw5nrdoFAPjV7rnY02cyIegoQ5mRUKFhPPlMTjAiodSEyJARem4yUCkywog048/hj0fnAHPVb4YgCIIgQqGN27ZOHBgi2WefOHBgbQce6y1Em803iTYD7MYMDvt1+jS18YRmalNBpjYiREJLGvBpaPPSTNVlaPMfafatU14rbnf1qUzFsu9o8ftbZlqWkIbe0Nbm/RZ+obQBwgYqzhD9nlj7zoikbV9g0+Ayz6v3zcVFr+WQPgBsy7Xi1j1HKs50lKEcg9AIsnpGhpXQ0GA6aU5ig5CRdJEhYi0yeNp9vs5XbjJQ/j3nJTKCRpqJz/liOlBNDjBank8QBNF/SUzfmTSExAH2aJM4UDlJmOvN4W+bZgMALp+xHEAObqPNFGQk2zamNgkqUxuvmfwQpqmNCjSEX6JMGgiD6jG02UWaNdb1Fbf//aX54lHhOa+XxO9pi6QBhoFeon4zRJxQcYYgQqbsS76NO2DiAvPd4FIdbXZg1QBc8VIWAPDcvEHoyjUg1AxlGZnCY0hCIyhxx5sRhEiUhRm/BF41o3OAeYmMDPcafluKV6SZCteRZqIIIQcYQRAEUeVY9p1JTuKAeCyvpR5aOB4Hc80Y35LB/HEmtnQ/0WaKyVBbUxuj3efrkCxTmylUoCF4TP89uNLksaya4Wn3+TpPQ5sKXY9OfaTZV09eWDzy7adOLmzx81a2kWZA+TxZ4SGtuXwG9ZshEgYVZwhCIIgLzDgqRmxwKYqNwA0uxf2lfZ8bsRgt3cCascCbp2YMLlbEp9Dw81E87Xanu149o4ScYETIRP1vofZEht/c5LAjzfhHcoARBEEQySesCaakJg7sWdmFp3uPBgBcdaxO58mKMiZo+nUyMh5vwQwsHokDUa6eidrURrqJsCFMQ5ttf04RZ4Y2EWtDW5AenXJuPPv54vahHtV3sEmkGdvW6CfZfBr1myESDBVnCMIEQxcYw9MFxhO4waVqIrDyxja56QCuTucn5G49qw6rBmUxbtIezfv7naTUCA02gZrxeIuAQiMogXOUHUBCg3Di/nIcZ2ZLaCKDkTG9ElkUo21usp9IM/F5bTjACIIgCKKIowmpaBIHvCPNGLmeLP62OW9IueKIFTCPNmPHVeMFh/06ZbRrjnlQjaY2gHQTEX3SgMncQGyGNt33hHJVnm2PTv45/11X+b1Xl8oWt3+zaI7H++sizQArU5th2oBRrKYFlDZA2ELFGYJAxH1n2rht6waXKswaXH511CtoSOXw5P4J6Bxdh2wqhZMue6tw1EW0GQmNMkhoEAGIujBjSmJEhu/cZN4BJu5XESTSrDYdYLQ8nyAIgnDdd0aK88QBtk98FA1v+ef3LpyIrlwDpg/YjlkjdUKFoZq45J3oItH064x99YwOioUmApDEpAFbAhvavMhI9gXq0akrQJfPKX322DeLR77+2GmFLV3agGylYwMqv8sFQ5sMA70kg9IGiCih4gxBRIC7Bpe8C8yuweXxrZtx1qAN6M2l8OMdx6F1/dj84emmE3cmGcoiyRIatpPGVkKDCjSEQ6LOSwaqVGQwfOcmA+aRZqYRJYB9U0tygBEEQRA1guu+MzzOEgcA02iz7St68XzfbADA1ceJ1yqurhUnLr3QRK9mJNs2/TrbDT5eQSymNgeQZuqfkKGtgFePTp6M1wf46dGpLtb894VPlT66s0VyhtibE4XnKhOyQj+lIe9Fxt87KG2ASCBUnCEICaH1nVE1uJThsMFlCjl8ffSrAIDbM7PQ3jMUbzw0FwCwemAfRo3fq/kMWeSPCV4ToohcaIiYCI2w482oQEOERn8QGRnuXH67AlVusk2kmXieDK+mlgxVpBlADjCCIAiiP+I2cYCZ2mTI+nPqo83u25LXTe+fxVb9iGMBnXtcHG8Y9utk1OLqGTK1EY4ItTCjoDYNbeH16EyhNAd221szNWcC8iKNGGlmUFhPexxvq9xFaQNEnFBxhiBMcdF3RkUgFxjg5QJ735CVmN2yAwf6GvHzXe8BAKx8cwwmd+aQTaWw4HI/0WaqBnA+os1MiGj1TNTxZjaQ2OgfxOH+ql2RYYIY1ygWZFxGmqkcYCxyhRxgBEEQRLIJKw6a4SZxgMEnDvAHVdpJHrPz14WT0Zurw1GD38W04bs1F6absGSJAzI0/ToZmcKj6IJP0OqZpMSbkWbqH4SeNGDYn1NG/zO08ZTPJX3iPW8Vj/zdQ+yXqos0A3cOf9wiaQDQpw1wmNxzmKHNBEobIPxAxRmCKJDovjMBGlw2p3rxpZGvAAB+tXse9vSVXjdoQz7arG7GBt2PwOGVoSwiZJbKos0yio+KafWMjKQ4wQASG7VOkt1foYoMGaLIYMhEhhLT3GT2nKGKNDMp1NhEmgHGQiPtcbytchc5wAiCIIiwCS1xgMdJ4gAgTxxQPQfY/frd5Tm8UIg2u/a4lYVjMfXrZOja31Szqc0DKtAQgMP/vw6SBqrS0FaGqaEN8BNp9quLHy9u7+qQ6R6xCGOyeoafFys8pCVvLRbyPQxtDDFtgEFpA0RYUHGGICIikAtMiXeDyw8PW4JxjQewuWcQbt1zTOFY/ub25sNzAABrBvVixLj9ms/xm6GsICPZJ0abyUiY0KBGl4RL4nB/mWIrvlUiowxeZLQbvKnuuyHj9WJdbjJgHmmmQx2JonboxucAIwiCIIiwCZQ4EErfGaB84s+s70yuO4t7th4NALhq9grbD4Szfp0Zj4+RmdockgRTmw1UoKlNbP6/9itDm4jO0Cbt0akr0gSPNPvL214iVizKqNIGNAxE5T3BIg6a0gaIuKDiDEEo0LnAmNBw2ndGJjSsXGCVk39D6zrwieH5XjP/tfMEdOXKJw+XLToMbZ059KVSWHDFksJelQtMhM9QFoWGJkNZJjS8cCA0dDnKfoWGEnKCEZY4K8xEFGfmV2SUOcBksAKs1998pvAYKDcZCC/SzKSpZQgOMA1B+s3Q8nyCIAjCCp8TVuEmDqic2KrHcu58fSp6c3WYGyjaTIdi3JKR7FP165TRzm33Q1Mb6abaIpLCTLUa2ky+DzJeVxS0R2f5HNIN894uHvmCdaQZ/6iIf1ahM7S1VZ5OaQNE3FBxhiBssBQa1pV3hlGDS9nNqXw56A3DX8Pg+m680zUC9++bIX23IZtG51850yvazLHQ4MkUHkMUGiKh5yj7dIJRgab/EfqyfAV+M5O9CFVkiDnrgEVuMoMvKjPE1YEQjsn2m+CgqaWNA4z6zRAEQRARkMi+M7L7ZQXiBJ/ozJY5tFm0WTaEaDODfp0MU1MbwzbyyJBITW1UoCE4nBVmdJChTYFXRKNcJ/3+0keL29sPypzHYjFctrJRPEdiaJNhsGJGd68RDW0mkKGN8AsVZwgiAqxcYMYNLgF1g8tGpOsP4vphCwEA/71zAXKKm9zSR/MiY/XgHqTHHCocizBDOUKhoVs9Eym0VJ8okDT3VyJEhilakSFD9p2kKjb7jTQzcYAxHEaatVWeTg4wgiAIIipC6ztjkzhghSpuVI2/aDOVOcRHv05GpvCoMrV59euU9a6A3eqZyExtHlCBpv/gtDDjIGlAhpehTfw7SqahTUTVo9PEsFv6/r5nhU4vyr6DVQVzhX5KQ546w98rLPvNiFC/GSJMqDhDEByxuMCcNrgsCY2PD38ZA+u68XbnGDxxQF2UWPLqeEzqyqIvlcKpV7xp+mEcATOUGZnCowOhkfjVMwA5wYiqcH+FIjJ42iX72N+6SmRkvN/WTW6yq0iz5DjAGDYOMIIgCIIIQqC+MyYYJw7I+nTaRZvdtbAUbTZ1mEm0GWA+nrDs12mCh/HF0zgTAi7jzQDSTf2BSAozlvgxtJkSqqGtDFUxRtajU5UyoI40++BRy4tHPvfA2YUtlaGN10uy72FDU1va7DQeZmijtAEiTqg4QxAaIuk7wxPYBQYAjRhRvx8fSL8MAPjZzlOhd4SlMKwQbdY8a73He/uJNtMslc0YvoVIhKtnAucokxOMkJBEkWGCE5Ehc2z6/ZvOeJ3gJzeZx2WkWfwOMHZ/YlC/GYIgCCI0oug7Y5U4AMjvu+bRZhuXZfFi35EAgGuPYzdIm8QBcRzis19nBnKqYPWMEp+mNoB0Uy3j9P+Xj6QBv4Y2r6QBJ6tm/BjaKuoxppFmuudynfSHKx4ubm85MEhyhizSjMF/94abNhAYShsgHEDFGYKwxWffmSgbXH5i+DNorevBmx2H49mD/MBA7lRY8hiLNuvGsDEHCntthYZIMoWGiJfQkGElNHSQ0OiXRFaYSaLI4GmX7BNFhkim8Bg4NxnQ57/bEHJTy7Tl5YAcYARBEET41FLigE202d1b5wEArp6zXH9yEdm4giUOyLAwtYmJAzICrJ5JlKmNdFO/w/b/UxRJAzJs+3PqcL5qhpHRHeTTT0RDGz8HZNOjs3R/eGDlZM1ni9+9jZAXyoGSfhLSBtKSt3Xcb4YZ2nSQoY0IAhVnCCJifDW4BIwbXI5uyODa9IsAgJ/tPBMAG9TIREdelCx5eSLGd+XQm0rh9CuWeF8fALmDwitDWUJGsq8KhEYU8WYACY1aI4mFGRmhiQzbVTOh5SYnKdKMoSnWkAOMIAiCSDiR9J3hcZQ4UCKMaDNxYlOHR79Ohmm/Ti9Tm4LYTW0+UwdsId2UbJwXZmKMM+uPhrarZ68sbn+2GGnG8BNp5lFAH4jKe4BF2gAztGnvRaB+M0R4UHGGIASCuMBEoWGMV4NLYxcY8MnhT6C5rhcLD7XhxUNHGL8u/W5+9q/JONoMSKTQiGj1jIokFGhIbCSTSHrMANaiNnEiwwutyBBR5SYD4UaaAeaRZswBViAiBxhBEARBRIVt35nwEwd4Q4TMSOEq2kyEN4mE3K+TR2aEUfTrFInc1KbDoWYCqECTVCItzITcn1NE9/cUn6FNhX9D2+1XPlDc3riP6Ry+4GISaQaEHWkmvdeYQGkDhGOoOEMQHtj0nRFhlXdptIymP0ARywaXQ+sO4IqhLwAA/mfX+civmjFzgb316FwAwMrBPRg5lrkoTKLN2PGECA2OIELDeY4yEFmBBiCxkTSc///w4f4KS2To8CUyTHKT+e0yTHOTdQ4wm2KMrqmluG3Z1DJkBxj1myEIgiBCJ+AElq/EAdn9U4qq94yeXHcW92w9GgBw1ewVHmfH3K8zIatnVMRtagNIMyWNpBZmTPEytPFEamgrQxY9765HZ4qLNHtoVZvmAnWRZkDlvJUk0kxGQEObFZQ2QDiCijME4YeENri8Ov08Wut6sLxzPF4+NB02LrC3Xz4cEzvz0WanXfmm+ocow6Y3Q3UJDRmhOsEMILFRndS6yIht1YwSk9xkwN+EiQ6ZA0wWaUYOMIIgCKI2CLvvjBSvxAErZHGk/KOcvy6aht5cHY4esglThu0Rjpr065QRoF9nzKtnyNRGuCAJmklFIlfNmBjaGBWmWBeGNjk3zHu7uP2J+84VjuoizVT7xNWOHGl4p8x49SorwAxtfvrNEERQqDhDEDHgq1LvcdNpRA8+OCyvhG7acw70vWbkDN44FgBQN3ODx5myDOXqFxpJd4IBJDaqCT8Rc/1KZMjwLTJsc5NVEybitm3RRvZ9a9nUUgY5wAiCIIgqwEXfmWgSB3QrXvn7tnyF7MZlWbyYLUSbHct6K9iMGWSJAyJkatNCBZqawrlm8knNrZrJ6E5SpZ2YGtrk33m/v/TR4vbm/YMLW7pIM9X+UkKMJ2nhuUo7adIGvKC0ASJMqDhDEBIi6zvD3xwCusDOGvQaRjfswfbeNB7eNx/lgsIs2mzRA4Vos0F9GDspU9hrmqHMjlWv0DAhbicYQGKjGvDzOw+rkWWQOLNQRQYrrJqKDEZGd9A0N5khPncVaRagqWUazhxgDHKAEQRBELFRWJlp20g53MQBoHLiz8TQlr+f57qyuGfrPADA1XNcR5tZ9uvMCPt1pjYZ1WpqM8CPZiLdFD2hFGYcJg2Y6Cavf//JMrTJcNOjsy6VLR654+0ZitcAlSsWZWkDIpyhTQalDRBVDBVnCMKA0PrOyPDZ4PLyoU8DAO7ZewZ6lO4v/pHfzp+/avFYTOkAsqkUFlzuFW3mN0PZg/4gNLzwOQmvg4RGdMRSmIkgzkwkcpHBR5ppRYZJbrJq9YxfQmhqmRae+3CAqfrNMMgBRhAEQSSdMBIHypFNDLL9av76+tRitNnUYbsLe036dbLEAUf9OhmyRuAifOJAu/fpUZrakpI6AJBuiopQUgaARCQNJNvQtg/e8zL+enR+7rjSHNLnHxR/2SbxZeJ5LNJM0FCtkBuYKW2AqFKoOEMQfgmr7wyPYYPL0Q07sGDgIgDA3XtPR/5upZoc1NPafhgAIDtddtMyFRo8/ESpuL9ys4L+KDQA50v1AXKDRUEohRkvQnB/yQhFZNiSUR2wiTTz2u8q0syyqaUMcoARBEEQCad6EgfEaDMRVcKAOtrs+b7ZAIAPnrgSZjju1+nVi48ZXrwmdSM2tSU9dQCgAk3YhKaZfBZmourPKZJMQ5v/Hp0/u+Cp0mUeGlDY0he6y89hj4aRZrL5Mf6ewN8rNPGY1G+GiBsqzhBETEgr9j6jzc4d9BzqU1ks6piJ9T2HcUd0DS7lN8lX78tHm60ekMWEGbsKe71uxKqcUlnVpTaERrUWaAASG2GRNJEhw1QMRyYy2N+w14o4r+8EAN65yby44PczXEeaqfYpmloyB5jo9iUHGEEQBFFFJDdxQER2P9YlDjAK0WbdWfx18zEAgGtnL/O4UNkKXt25lqa2TOFR7NcpIwGmNhVhpg5QgSY5xKKZNLg2tOmoLkObXY/Ohrq+4pHfvjFb8b6iLjKNNPNA23NMDfWbIZICFWcIQoFLF5g2WsZBg8tTBy4EADy5/xToJwa9o83WLRuF6YdSyKVSOOFS02gzwFtoAHELDX7S2FZoBIk30xJxljKDxIY7/K5ISor7KxKRwaMTGb5yk2UiI2husg0qh21ITS3JAUYQBEFUGwlKHKjEf+LAna9ORleuAbMGbsHsUdsLe02izdjxCPp1kqmtCCUPxE9shZmQkgaq29AmQ2VoE8+p5BsLXi9uf+WRM4SjfiLNgJJ+4tIGTCLNKG2AqDKoOEMQhgRxgTFYZd5lg8umVBeOH5C/ET578NjCCbJJQHOh0bjmcABA9+TN3N6gQkO8qUcvNHR4CQ0ZToQG4GSpPrnB4sHv7zDJ7q9QREa75GRbkZHRnRQ0N1nERaQZ/0gOMIIgCILwi8vEATWNUCcNyO/jW1f04enefOrA9SfbRJuZjDM0E6j8bnGclIDVM7HEmwGhFWgA0k1Bic3MBgQ2Q6qwNW86XTXjxNDm1aNT9qjmB2e9UNze19Vc2PITaQbk57I85q+8eouNq9xFaQNEUqHiDEEEIWDl3EWDy6lNa9BS14U9vUOwunuScLJs2ajssZwX756HulwOawfkMHmOSdMXwD5D2YOQhIZu9Yxu0hkI0QnmheEkPbnBoiPI7y0OkeE3zsyL+FfNiM/95iaHEWnm1dSSHGAEQRBEbZD4xIFWfoPvO2MbbSaYLvpy+MuGeQCAq2e+DUB3fbIxibhPZ2qT6KeM5uNkRLR6RoazeDMq0FQdoWomL3wkDbjqzxnbqpmM7iRXhrZyw25zfel77CcvH6N4vWmkmWHSAKBPG+ChtAEi4VBxhiCSQIAGl5Ob1gEA1nZPATAAcpeBbKm+WmhsWjsMMw7mvx6Ovdg72qwp1Ys5LVtwxdDX8OWRD+E/xt2C30/4H9w56Ue4p+37+OPEf8V/Hfb/8Onhf8bs5reE10uizTIeHyniaPWMiInQcIYDoQGQ2IiCIL+rsAszruPMQl81Y0pGdSBobrJ4ju64DJUbTOUA84AcYARBEEQNkLjEAU/8R5vd/cokHMo1Y3LrThx7OHOhqBIHVOhW0zgwtckmdQ16WdiY2iKLNwNiL9CQbjIjdDMbEEnSgIzqWzUjPndraPvuGS8Vt//hiVOEo34jzQCpoU1ELMgbGtoobYBIElScIQgNflxgTGioXGDGDS5FFC6wMQ35Wc7NveKsnRht1ggboZFaNREAcHASPzLI34yH1B3AWYPW4ZujnsZdk27D69P/A3dMugX/OvYBfHLE8zh/yFKcMGANjmzZhBnNm3FM62qcO3ghvjjqDvyl7Rv448SvYHrTMnhOrAYVGu3yt61loQGQ2AiLoL+fuAozfuPMvHCyakb8WxZFBp+brBQZKpcpIO+FJXN9BYEcYARBEAThSQISB7yxTxzYvaYbj3XnneLXn+QVbSZOeOrGIWzylMejX6cJMoNMO7ft09QWBOepA4YEWZlBmklP6JoJiCzOLBGrZkzJqA6EYWgr8c1TXitud/TaxDrrIs0U+imNyu92lV7iCGxoo7QBIkSoOEMQFpi4wEwJ5ALjGlx255oAAA3Km6dXtJmc5/96NBpyOWxoAebN34DTB27EN0a9ijsn3YUXp92C/z78IXx4+BIc0bITDakcdvUOwAsH2/DHPcfih9svxFc2fwCf2vhJfHzj5/F/3v00frj9Ojyy/wR0ZxtwTOty/HnSl3HKgFcKnxaS0ODRDHyizlFOcoEGILEhI6jASFJhxpTqXDWjy01WiY2oIs0AZaSZCDnACIIgCEJOgMSBPLJoMx6Z0UITbZYF7liX7ztz1fSlSBWjzYL26xTx6NepijSqgtUzvkiIZiLdVE4kZjYgUYY2L83E/93o4gKLBDW0lSEztMkwKRbLXgMMaCwd+/6zJyhebxJpBjgztPH3BpM4TBsobYAIASrOEERQwnCBWQiN3X3DAACTm9YX9qhuaKpoM3GysQEDUj2Ytmc/vvBIDv/3pl784eAT+MX4x/GR4W/jyJZdqEsBa7vS+HNmNr6y+XycueajOHXNZ/DJTdfgB9vPxy17jsfD+2fj+UPT8dKhqXjiwBzcsuc0fHnzZ3HO2v/GcwfnobWuC/922L9iXINiJjBTeFQNNByunhFxmaNczQUaEhsJEhgOqX6RwRNObrIZtpFmmpWLaZADjCAIgqhqYk0cGCM8F00OaXjM+YmJA4BN4sB9Lx2OfbkBGNe0F6dMNr03m67eFSdUDU1turahAVbPuDC11UrqAIN0U56gv4OkFWZk2CYNaAnL0FbxvSAWdmWRZgz+uVmk2b+d+2xx+/vPnOhxkTpDm0xbcYY2GbJ0GQPYvYXSBoikQMUZgkgKPhtcvpg9Gb25ehzRshLHtrJCB78UVIw2q2RQXRdOGrARXxr5Gm6b+De8PP2P+N/xj+HkN7KYtgWoywHruwfjjswMfG3z6Tht9Qfxvvar8L1tp+Ph/VOwrXcwADa40AuNnX3D8IVN38KSjpkYWr8fHx/+p8IRD6GRKTzaCg2eGISGb0hsxIqLn9tpYSZBcWZaXIsMhtPcZB6bYoyMCCPNyAFGEARBVBGRJg7wyBIHjPFKHJAbM/Zv6MaDnfMBANefaBptBpTGLDL9JItsNVw9I+Jw9YyIH1ObirgLNC50U38kUbrJYWHGRZxZ9RjagMrvJt05peOfO67Uo7gnW1/YUqUN8Ki+V2XzWIWnacnb8N/5hmkDplDaABEVVJwhCA8id4G1cdsGDS73Zofh3n2XAgB+evjXcMmQR9CY6hZOLN3wmlPA3JYt+GB6IW4c+zDua7sVr07/LX474T58asRizGvdjoZUDhu7B+G+rsn4xUUpfObz9fjWhOPw3W0L8OD+KdjZN0ByMSrHhUgHegD8584PAAAuHfoIGlMypYBwVs9wk8eJFhomRFSgAfpXkcbFz5nEwowpsYgMEVU0R5F9MM9AFGNE/BZkgkaaARWRZjLIAUYQBEHUKjEnDpQQo81EVIkD4r7SeXesyk/kXj75bTTU9RX2mkabsX18tJmIYtzj0tTWzm0nzNQWRYEGIGObDa6KMknrMQNE0J8zcYY20cgmJgzI9dPIAYeK219+RPVHqip6A5WRZgamNq+Cu9iKGZQ2QCQfkw5NBEFw3LdKPYBIPWM34T4BG+2bajMGojiB+YNt/4jpTStxVOtS/HDcjfjOmGYs65yM3X2D0Z2rQ2vdIQytO4jxjTsxplHuuHq3ZzBeO3QYXj00Fq8eGofNvfkb4zUT/4zdQ3pw9NlvYfGzkwpn9yL/9dGD/A2VPWePDLFA0wF2w33l0Bxk+gYjXb8f05rasbzLIMMng7zI2gH1BOY2VEYbGLIRE7RuitWYWpz4BPIDNDYByliCo4qToTyLh8+STo7mTtc4Ms6C90DgPGgLhQz2bzaow4MNwGvRUe9KSMVdmFGRaJEhOsB4jESGiM6JClQKDvYaU/xEmknEBjnACIIgCMKe8TnvRtqjUFmcaIWixsFPEO4X9pmNDx5+eTR2HjUEIxv34exp6/HIyikGr2JjEJ27nJ3DjyM6UDGuyEBTjOLYicoi1haUJjTbIR1vbF87QTnBuQbTysyHomYCzHWTSjN54qWbDDUTkB/Lk2ZSE7lmAgJHQCeqPydPO7cdmqHNFHtD288vfKK4/dNXjrH4rEaoCzYGpIXnqqkkTdqAytCmhdIGiJCglTME4YKYXWAduQH40IY/4j93fBnbe0eita4L8weswLmDX8NFQ17BWYPewvwBa4uFmV29A/DMgWn4n50L8JlNl+OU1Z/AuWuvxz9sPRv37JuBzb2Di++9/838IGPrYbuQQtbwJ1K7K0qksKIzX+yZ1tzOfhLuhyo8Zgw/MoLVMzJid4IBkbrBGLXiCmM/h6vVMlEVZnQEiTNzIjJ4dCJDhS+R4S432RyHTS3JAUYQBEHUAIlNHBgIs8IFAL/RZh1benD/oWMBAB8yijYT41d5BztbPSOrJHlMtorjKFagko2/HEZCJyJ1IEEraBi1opkAtz+L88JMyHFmXv05RYwNbTKcG9rYd4nOtObP0HbV7FIVM5tj08smaQOq1TN8pBmXNiCTVKJZ18PQxvDqZcbuSbbxmwQRBCrOEEQQDN03QKXQkOKn70waQCvQgyb8evencMaaR3BZ+2/x5Xf/Ad/b9kncuP3D+OetH8VXNn8C16z/Ihas/jZOXfMP+Oy71+J/dp2CZw/OwG5pTFmep/5yFFqzWWxrrMOCi1wKjR7szQ4CAAyo60CkQkNDHEIDqM4CDVC9gsP1dUcpMIBo3F8izlfNqHKTeZTJZaoJCx6b3GQv/ESaWTa1TKP8+z1mBxhBEARBBCX0vjMmpP29zE+02e0r5gAALpn4NprrRfe5aqWu6VhEHPt49OvUEYKpzQQbU1vcsdCAmz40DJeGsKipZt0UVdKAsaFNpZn8GNoYGdUBE0Obqkent6Ft0tC9xe3r7z5fcQ26SDP+uUGkWRqVhjaD8BVmaPN9D5EY2ihtgHANFWcIwgA/yxRFF5gIq9h7Nrj0coFVkMLKrql45MCp+HPmAty65yL8de/peHj/fLzVOQ17+tIAUlDfGMuFRseBZkzdnf+gSacu4873Ehre9OYaCp94iNvbP4WGJyEUaMIo0iRZcIR1jUkvzEQqMniciwxZEddvbrJIWJFmgHVTSx7HDjAbaHk+QRAEEQpRJw4oEfvOyO7XYvSOnidfHo7NueEYUt+JC2et9ThbttJXNkaROd4VprZM4bHKTG1JTh0A3GomoDrMbWHoJmv9GWJhJqr+nNbYGNr4vrxaQ5sXfDHGXBP9/tKSw+uPS2x+n40o/251GGkmSRsgiGqAijME4QOXLjBGeA0uZb0ORPeXXnRseyV/s20ftR+NzV7uLtFxoTq/A0Pr88JiX2EFTTULjUicYFUgNoBkFWrCvJaoBQbgxv0lw5nIaJfs8ysylJhWbXW5yVFFmvloahmiA4z6zRAEQRDVgO/EAbHvpCJxwBt/0WbdO3tw7/7jAAAfOF6WOCCa2mTjETFxQKS6Vs+IY0ybfoehpw7ElDzASJJmAsLXTVZEHAENuOnPKRLPqhlVj06doU2GPtLszMmbAABdvfXIgf1spmkDUJznM9JMBvWbIaoIKs4QhCtc5ff7iTYDDIRGK9TZnnqeuXsmhvdmsbe+DmdeuVRzpkpoyKLNgJH1+aWwe/qGImlCw2syOhYnGFA1BRoGP8iPSnhE8XmhCAwPXLm/qkZkKHOT2XPVChj/uclybCPN+NcYkhaeh+0Ao34zBEEQREi46DsjYpw4wGOUOOCFfbTZn98+EgBw4eHLMLCxu7DXa6xh0q+TTa4yqsPUJiMxqQNArMkDPHFrprA+09fvLMFJA4lcNSNlH+wMbXaRZkeP2V7cvuKOixXva1PkboV2fiot2cd/x3ukDbB7B/WbIZIKFWcIIigGef3JaXAJlPdCUD2W3xizvQ0Yvy3/IcPn8/ZqMdpMRO3ESCGLiU1bAQAbu4d5XnUUQkPnBPMrNFT4doIBoRVowizSMEQREEQIyN4rbDETmsAAnC/Lr02RISI6wHjsRIYZJpFmjppayqB+MwRBEESVUTuJA2a88MoQrMuOwYC6Hlx+tH4i0K5fJ49i0rVGTW2hpg4AiTK2MVzrnKg1E+BTM1VR0oCI+HeQbEMbUGloM+/RecdV9xe3H1w1WX2NFYhF7gYYLmf0nveSGNqk9wobqN8MERFUnCEIQ4L0nfHCfYNLmdAQMYk2K01ErnoqP2hZle7EkOGHJOfyqDKUS0JjQuNGDKjrQme2ERt6xhbOlfWUQGRCQySueDMgngINEF2RhkclPrz+ixpfvxcHhRkdNoUZGYkTGQylyJBNUIj4y002x1GkWRr6SLMYHWC0PJ8gCIIIlTgTB9LwmThgFm3Wu7cXf80cDwD48HG8FgzerzOPQaEmU3gUjS9VamrTEWeBJmrNBFSHborLzAZEmzRg8u/cmMgNbWJRmH8UtxnsNTnMGJG/mLV7hgLWkWaq+SeJoU0G/52uioM2uTcQRIKg4gxB+CR2F5h1tBm72fmLNnvt8TYc1p1FV10KZ1z9JndEJTT0guPIlg0AgNXdE5BFPapBaMgIywkGOCzQ+CzSEHl8i68IRIYNQUWG7d+GL5GRgQJRZLjPTZaTvEgzcoARBEEQtUzoiQNW8EUZ9px/5LdL5/3hjdkAgLNGvYPRAw94fIZJv0425hGPKaLNRDKaY2RqK1ElxrYkk1TNFHXSQPUa2rzTBs6b2l7cvvKO9yne21GkWSu8U2I8DG0Mdg9RpQ1ojdXUb4YIGSrOEIRLwnSBiQ0uRdI2H2ByoxRvkHUYsTF/EY2z11l8llxoHNeanw1cdGhqxbFqEhoyXDnBjCCxERqBijI1KDLKiExkGB0QEIs1UUWaAU4cYCrIAUYQBEFUCUH6zngRTeJAsGiztxY2YVHfVNSncvjg8e9ozjTt1ykiTrxK+nVmPC6STG1yfPaJ7M+aCUi2mS3spAGRWje03Xvd34rbb2z1mqTiCRhpxqNaMSPB9p5B/WaIOKDiDEG4IIq+MzxWLjBVvI4sjkcfbfbG/fMAAKsG9WLs5IzuQyEXE6XBwHEDVgIAXus4QjinNoRGZE4wIBKx0Z8ER6Cf1+Z3TCIjj1ZkiFGHzDUqFmr4+EQvTM5RoVohwztrdXGSUDvALCLNGGE4wAiCIAjCNdWXOMBOUBkw9NFm2Y4+3LH1OADA9fPe5o7479epPsfD1Kbq18mTQFNb7AUaMrYZEYluiijKDPCXNFD9hjbA1NCWQg5N9VkAwJPr+J87aKQZ4GloSwvPKW2AqCGoOEMQFtgsW2RCw9YF5r7BJYNvTA1hvxkr3hiLqYeAbCqFBe9/gzvilaFcLiLGNuzBtOatyOZSWNgxgzuHH0hUl9CI1QkGhF6gAWpfcAT++RJamKldkcE7wHhsc5NVmIiMRviONLOOVikRpgOMlucTBEEQkVBViQN8McYs2uxPr01HT64exwzZiCNGeg2IvPt1ymOKkm9q8xtvpiOSAg1AxjYNgYsyIffl9MJv0oAX8RraxOe6IrD4nWPGJ+e/Vdy+4Z7zFWf5iTSTFGRkNRpHaQPM0EYQSYGKMwQRgFAr56E1uATMb5iVhZuWNeMBAF1T3zX4HLnQOH3gUgDA4s42ZPoaUA1CI/FOMCB0Nxij1gSHk58n5sKMCj9xZiKhiAyRjGynjcgA/OYmmxFSpFlaeE4OMIIgCKLWSWzigBhtJmIXbbbx7Rye6c2P+W9YwE8E+uvXWU71mNpkBE0d8CIJBRqg9jQTkCwzGxB90kCyDW37JAdkkWYML0NbpfH2f9/3eHF7077BhS0Tc5rPSLM09JFmHmkDzNAmpg0wqN8MkRSoOEMQrrFwgYk3Bylt3LbTBpeAn2iz5/96NOpzObS35jDjPbKRhbfQOGNQ/ub39IE5ktebLsNFpELDhNidYEBkYgOofsHhrChj4/wKUJjR4dL9FYnIyBQeD2rOASCfbHCRm+wHk0gzwDrSTCyyy9y+mmK9GGlGEARBEEnAz8RV5IkDrRUbwkGTaDM5ud4c/rz+PQCAD8xaghRMfx+6sY3X6hnv3RUkxNSmwo+pzQhbzeSgSFOtusnZ9Se4MOMiaSBUQ5v1qhkd/gxtrQ2l76Rfvj5XckaQSDMP0sLzkAxturQBMrQRYULFGYJwhUGDS9EFxmCV/OgaXPqPNnu3fThmHsiLkXnve5M74uX2yg8C0vUZnDww32/myQPsps6EBi9CxF4T3G4T/AoNbvLZldCwdYJFXqBxVKSpBsHh9FojFBhAcPeXjEhFhopM4bHsb1tcPWfyh2+TmxxWpJnjppZtlbtEB5gI9ZshCIIgkkgYfWeK+E0cMEYXbSaOGUpFm7++OAn7c62Y2LwbC9o2eXyGvl+nfOxiMCmbKTyGbGoLq2cnkIDUAUZAzQRUj2YCHGumBBRmVLiIM6vAtaHNE3HehK2uc2do+8HZLxS3v/boaYr3sIk0A8rnp7i0gRAjzYxwFbtJEIZQcYYgLPHTd8aW8BtcAkGizXrfbgMA7Bq/HUDW43PKxcR5g5ehMZXFss7DsLZ7BBInNARc5CirCG2pPhCL2ACSWahxfk0RCwzAjfsrcGZyUJGhyk1WohMZDD6Xnd/nEpul+jzJiDSjfjMEQRBEYklk4oC7aLNMezce6DwWAPCRBSu4I3b9OiuRRb3Ga2oT8duzM5bUgQiTBxhJ1ExASLrJhhALM0GSBqwMbSZpHFVoaPvyiYuK2wd7mgpb6jmjErIVMzLDsEBass8i0owhRpoxozQztBFEEqDiDEEEROoCi6rBpUxopG0/xD7a7Inbj0FrNottTSkcf95a7hxvofG+wUsAAPfvO0byGdUnNGSE7QQDQizQOBIcQPngPmrhEdrnRiwwgPDcX74zk1WEJjJkyJxf9rnJ5iS7qaUR1G+GIAiCSApRJg54RZsVYaYKcZ9JtJnCyJEDbnvnaADA+9veQlO9SdIAgzeg8M532djIY+yUKTxGbGqLIt4McFCgASJPHuBJimZy+tl+zGwRF2ZMCdSfs53bjtzQxhPM0DZu0IHi9hcfPsPgFbpIMxPjWwFf81t5TO8R1G+GSAJUnCGImHDS4JInrTrgPtps755WTNszAAAw7YylBq/IDwIOb8xg/oCNyOaAB/bNQS0LjbCdYEBIBRrAqdjgEQf/LgSA7D1DETV+RFjIhZnY4szaue1IRIasaMvjLzdZjWmkGf8YTVNLBjnACIIgiGogEYkDPLJoM8NE0sqijMzUVhlt9sgLI7E5Nxzphg5cPJs3tTG8+3WqESdhZb36DAnB1GaCC1MbEEOBBnBepGGEpW8i0UxAKJoJcNubE4ixP2dkhjY+0ozH3tD2u0tLVYr/eXWe4iIdRZrJ4L+7qd8MUYNYlCwJgvDkESgHI/N2L8fi4bMwF29pB5YTsBEbMQGjp2ysnCAdh9KNfSRKN/FRqCxItMJghQn7CujlnvcIj73I3zjLb+rbXp0JXPAm1o/ei8bmHvR0qYo7vcX3ed/gZQCAVw5Nxo4+0ZnGwz6L3aD3odLJZshOVDrmtqE0CboFpZt6O0oToptSVo71NZimL7BpWIKjlJOqi4fP0ro5cqcbiFk24LVZ0cX+HYfcmyJpS/kr8Cu4YirMxJaZ7EdkMJTfU7J4Q3e5yWbohkkOm1oaQg4wgiAIopq5b1WIY7/xucpCwRiUxii8dgLyZglPs4gIGxeYjyu6d/bg7v3H4/NDHsaHT1iOvy6ZUTjCNJIK8TMauH1sjMRPZnZwz7lttplBfvxxEPmfnT3fgcqilUw/qTSTwPa1E8omRzdiQtn4RaaZVmNqRWzdSsysWDUF6HWTDiPNBOTH8LYpGOchkn5+NambHBVmwoqAdt6fk6eKDG3nT1sPAMjmgL4c8/iHFGnWCm+tNEayj/rNEFUMrZwhCB9E4QIr4ucmk7Z9gWypvj7a7Om/Hol0bxZ76+tw9lX8AFmMNmPk8L4h+eLMffvmSI6z1TO6fjMdlZuZwqPf1TOGeK2ekRGFEwwI0Q0GhOYIqwr8CoyEF2YSIzIyFWfCqN9UBf5yk83xE2lm2dQyAf1mCIIgCCIWDCbCkp044BFtBuAPb84BAJw3dgWGtXiNbcSxDEM3holo9YwqdcBy9YxpvFkssdAAaSZb/P7sIRdmgmIdw1e1hjY97xlXuuiL/3SZwSscRprxiAVjRpv6LQKlDURQcCUIBhVnCMIBofadYbRx29YNLlVCg6cRNg7wbG8Dxm8dBgBIz1/pef7clo2Y2rwLndkGPHbgiMJer8FCsoRG2PFmiS3QAP1LcCRYYNgUZmRUh8hQ5SbLRIa/3GQ1IUea6YjYAUbL8wmCIIhYseg7o4KtyPCMNuP1kqzfm1G0mSyyxyza7LXXBmJZdiKa63pxdZlu8u7XqR7fsLFRj7BPsl2lpjY/ONVMVKTRE+RnTZhu8hNnVjuGNn2k2d3X3FvcfnDVZMV7hxRplhaeGxrabNMGbAxtlDZAhAEVZwgiBsQGl6ILzFmDS0+hwRpcmtxIK4XG8kfnAQDeSXdj1Dh+QrVSaFybzv/MD+0/AgezXl89CRIaCXGCAQkQG0BtC44ECAyXBBYZthniTkWGbBWdCvvcZDNCiDQL4ABj9wXqN0MQBEFUE4lOHPDVbFrUTt707e/FHTuPAwB8eP7bBq+Q9YjgzSmyMY3puMkDMrWVQ5qpkqCaqQoKM9ZJA1VvaJPrpBRymDh0PwBgybaRANjPGWGkmaygLuLK0CaBDG1EFFBxhiBcY+ACMyVwg0srTKPNSrz57ERM6gR6Uimcdu0i5XlD6zpxweD8Xe32zDGFvfxAgR88+HGBGOBXaAhE4QRLvNgAakdwnIfECAwgYSKDp53bjkVkiPDfHZA8itt+cRBpJpIWnjuONKN+MwRBEESSSUziQFr14nCizf74+ixkcymcPGwd2tIZyRmimcR0HCPrMyFJHCBTWxm50yNIHgBqRzMBwX8Wi99j1SYNmJB4Q1uJDxxV0hVX/+V9ivdtELZDjjTjv9PbvN9KNLRpoX4zRExQcYYgfBK7C8w62gxwHW0GAAPWHgYA6JmhnkC8bOgKNNf1YVnnKCzpNLE+ALUmNGycYF4kpkADVK/gcHHdDgUGEG1hxpPYRYbsuMwtKotE9JupHFKkWUwOMN19hxxgBEEQRLUQWuKASMjRZmveTOHFviMBANefzE8SehVhRPe7l6nNa0xlQAJNbWEUaICIkgeA4IawuHBx3ZZmtigLM87jzNq57ao1tJWe/+GKh4vb7+warrnICCLNDBHTBkTYPcWPoY0gwoKKMwThCL8usGgbXKr2+482e/4v81Gfy2FtK3Dkce9y78Fu6t24Jp1fvn975iiUlsLWttAI0wkGhFCgcVWkSbLgcHmNCS3MmBK7yJDCiwxZcZZHVqSRTW5USaQZOcAIgiCI/khVJw7YjQNy3Vn8aWM+QeD6OUsAeJkwdP1mVOfzJMPU5ireTEdkBRoguGYCqks3BSUBmkmFn6SBCqrG0GZPc33p9TctPpI7EmKkmQj/Xe04bYBB/WaIJEDFGYKIiUgaXKZh6AJj2EebbVmXxsz9+XPmvu+NiuMnDNiMtqa9ONDXiAf2zUBVCY12bttSaMhw6QQDHC/XB9yIDSBZgsP1tVgWsqIuzDjPTDbBj8go0xI6kaHKTeYxz022I+KmlhrIAUYQBEFUM9WXOBBOtNlfXpiCjlwTprdux3ETtnJHKvt1lqPST8zUpouFjtfUZoKL1IHICzS1ppvOg3sjW4SaSYfLpIHqM7QB5frJvEfnD85+obj9dw+dqXhvx5FmaZgb2hjUb4aoAag4QxBh4NAFViTB0WaH3pwCANhy+C6kUn3ckV5cm87/vPfum4lDOfErJ+FCQySkeDMgvAINEKPYACoH+mELj7A+y8fvJYnuL2vaue2qFRkmBIk0c9TUkhxgBEEQRD8hsYkDIUeb7VjTh0c63wMA+Nipyww+SxzLiIkDsvN5jVS9pjYVYcVCA5aaCXCrmYDa0ExA5JoJiClpgP933q54E14zxWJo478n7Hp0fvWkhcXt/d3NyvPsDG2As0izNvUhdm8Q0wa0BmlKGyBihIozBBGASF1gjDaDc9KqA+FEmz3x53kY1JfFroY6nHp5qfg0uuEgzh7UDgC4Yy+/FDaBQiPCeDOXS/WBEMWGa8HBkBVsTIWB7rVhiRgfAiPMwgyJDFuRUUWRZgxygBEEQRBEkcgSB6ywjzYDgJtXHA0AuGbKkrLYoBKig11cVSOj+kxtQeLNEmNqA5KrmVy83pYYzGxAOEkDnuj+fev+LoCIDW2AH0NbW3pvcfuzD5zNHdEZ2iAck0WaKb4zHUWamfYgo7QBImlQcYYgHOLaBaa9uSQo2qzjYDOm7hycv6yTVxT3XzP0HTSkcnjt0Fis7GIXKZtMVb4zIhMapsTkBItFbADhiQ0VcRRfVIQkMAD3hZnQM5MjFRkmqIq4fpCJjJiaWrZV7hIjzcgBRhAEQdQUiU0cCCfa7MHnRmBzbjiG1R/CZUev5I6otJFoSpH16xSxNLVBeB6BqU1Gkgo0iSrSyPDSTFHqJp8/e5ILM877c8ZmaBPnXswMbX9+/wPF7V8tVP1NyTSRuG0ZaWZLQEMbpQ0QSYGKMwQRA0xoeLnAGH6jbdTRZoDraLN1T+dXxqwZfgiDhnegMdWHq9L5ycM/7pkN9cSpbADhSGhEsHrGBBdCA4i5QBN1kSZOQhQYgNu8ZCDGzOTQRAZbKScr1IgiQ/Vdocpst0EXaQYkraklOcAIgiCIJFN9iQM6/Eebde9J4fY9JwIAPn6iSbQZYD6eEZ30hqa2g+pDSTC1xZE6AJBu8iSAZgqzMKPC9N+R8/6cPOzvSTYv4YkrQ5sXOZwwPt8Ta/vBVmQroul5gkSaSTBNG2jTXJIfJIY2ShsgooSKMwQRFpLJsEgbXKZVb6K6GXpFm/Guh0qh8eJDM3B4dxaddSmcce0bOH/wOoxs6MSWngF44kCb5PNUk6oyfAoNHbZCo53b9iE0klagIbEhIWSBASTY/eVCZARG9rcsOsC8CFKMAcwjzZLb1JIcYARBEEQ1UDuJA/bRZr9/LW9qO2vkShw+mJ9wFSPMxHGNztSmSxvQmNpEMoXHiExtJvFmKsLWTIBPzQTUtm4K8LOFrZkAuwjoSPpz6rRSpvAYyNAmQ5xr4Z/LIs1K575vxtrS9m2Xcee6jjRrLXsokhaeexjaGIHSBggiZqg4QxABcT25xW4iWtok+8Tl+SKBo828qMOI9aPzr569Dh9M513ct2eOQB+yhXMiFBqZwqOr1TMiITvBSGxESEQCIynurwrEwkw7tx3aqhmG+DesExlAeZHGPDdZTbIjzQJBDjCCIAiiBog/ccB9tNnSRa14pXcm6lM5fPTUtw0uLkRTW6bw6HL1TDu37aGZZCTJ1AYEMLYBtaWbAv4scRZmQjG0tauvschOybavVTM8svkRnaFN9t2h/j6577q/Fbdf26yqhphEmsnOU5AWnouFcw2UNkDUAlScIQjH2LrAmNBgFX2G6ALzbHDJY9zgUteYTYw28xYar9w9D6lcDsj0YW7rTnRn6/CXvTrnio3QYOczHK+eMREd7f4+hqFy5lS92KhGweHg2l3GmEXp/jKOM1PhRGSIf78mIkPMXOcfxW3ZcxvEeBLx+zDcSDMGOcAIgiCIfkGiEwdU+I82y3VncevG4wAAH5nzFgCZ2c/L1CaSoNUzIlWeOsDwrZmA6tVMgBPNVHOFGRGVoU1HpvCYIENbc33pu+Wu5fzvSDYH5GVok80xab5TKW2A6MdQcYYgwiTxDS7FfbpoM8BLaKxZNg4zDtbjgoX5lTIP7p+MPX0tsguAXGiozhMnb0NYPcMTsdDwS2LEBlA9hRoH1+hSYAAkMkqYrHiRneMnS1mGSkSI34HhN7UkBxhBEARRi4SVOMBMDFLaJPsSEG325+cn4WCuGVNbd+CUtvXcEZW5RBzviIkDXuczHKyeCcHUFnfqQCTGNqD6NJMD3WRKVIUZGab//oq0G5xThYa2fzrtleL2p+8/R3I9skgzkx6dfJ9ji7QBWTpMm/ylgNrQpoXSBogEQMUZgoiYWBtchh5tBjQtnICTlueF120ZfrJYlnEqPpdFm4lYCA0dbGCkEhqmOcoxOsEA9wWawEUaIHmiw+H1uBQYQLSFGU/auW1VUTIUkSEeU4kMXlyoXKMql6kMVW4yf1y3LN8g0izippbkACMIgiCqCVeJAwxfiQM8oUWbVY41dq0F7us8HgDwiVNXGFwcYDa+UU3gak7nicnUJiPK1AEgYmMbUPOaKar4Z5uUAcBn0kAQQ5usmJkpPEZqaPOGL87sPDRAcZbDHp2tMI80k3yHs+96bc8xlO4dvg3SBBEyVJwhCAeIk1yJbXCpxF202ahHWtCQBVYeBow8b5fuQzlCFBqZwqM48OEJkqNsQJhOMMBtgQZwJDYYZyF64RHCZ9oWrqIszJhiJTJ4QhcZYja6Ct05QSLMRMSiDP892ADrSLO08NyrP1gBr0gzLeQAIwiCIKqNuBMH0qo3cRFtxh8vcfOyowEA75+0FAMau7kjoqnNpl+neJ4scUCzekZGwkxtOqIu0DjTTaSZlPgtzISWNNDObfvpz2mMa0ObLNKs9J0xaeje4vZXHz2NOzdoj04DkhBpRmkDRIxQcYYgEoCTBpe+o81ahX3+o80akMUVrWsBAA8dW4dxp4kusBiEho6IcpRluHSCAW6X6wOOxQaPKAL8igHV+zgWM35+D1EXZmpTZPCI3wHmucnmmEaaGZIWnvtwgHlBDjCCIAii1gk1ccDLLBFBtNnjzw/FuuwYDKrvwtXHmKyeMe3X6Xj1DE8Mpjab1AEg2gINkHDN5Pq9NEStmXREkjRgQiyGNh3q74/b3v9gcfunrxxj8F62kWaatAHALBK6zeAcx1DaABEFVJwhiLAJ0wXWZnBuWnjuW2h4C45zBrdjdEMHMvWNePmIFFYO68CwURmDzwlZaLBLiFBoBHGChVWgAfyJjVAEh4iu2BKBmBDxW5QJ0/0lo3ZFhira0C43WY3rSDMJ5AAjCIIgiAoSkzjAY5Q4YBptpnKRV/br7M0Af959IgDgYye8rbi4sExtEjLqQ3Ga2lQkrUCTSM2UYN3khdf/p9iTBmI3tImI+snG0JbDyRPyP9CuQy3ozdYrrivCSDNeO2liKb3SBrRGaEobIBICFWcIwhE2FXVnLjAZzqPNAP2y1dKxD6bzxaY/bpuJcb1AV10KZ33gTcV71vbqGRlhL9UHwhEbQISCIwH4+TlNf6dRL8sHql1kAGYiIwgNwrafSLPWsociaYOPbzM4xzHkACMIgiCqFevEAZtoM0CTOKBCtmrG3Nx202uzkM2lcOrwtZic5gdXKrOJzfhHZ2rrUB8CwjW1OYw38zsRb6qZElukSQBxmdmABCQNmBC6oU2MNBOP8Y/idvnzS2euKW5feNvl3DkxR5rJKHynm6YNMLTGaDK0ETFDxRmCCImgDS5VLjC3DS5dRJvlmdWcwfwB29GTS+GOzGwMWXs4AKD7iA0oHyyELTQUhzKFx6BCI0COsgqXS/WB8MQGULuCg/1cYRVmluAop4UZU3zHmalwLjJEZLnJ/DH+kd9WFXtlyERDRJFmHg4wsaklOcAIgiCIfkUYiQOMNoNz0sLzCKLNVi0GnuubAwD4+Kle0WbiWEhlYOuFXi/FvHrGgKSkDgBkbBMJUzMB8RRmrKkxQ9s9195b3H71XZPJpYgizfhCepvBZRkQqlGaIHxCxRmCiJlYGlyGEG32wWH5pfiP7m/Djr4BeOZP89CQy2FdKzBvgZerQSY0RAIKDR02QkOkXf/WJk4wFWEXaAB/YgOoHcER5OcI2/kFhCwy2jXHIhMZ4go4Ga5WyIiI32e840u3LD85kWbUb4YgCIKodUwm0kRTmxZ+3tFX4oDbaLNcdz3+sH4+AODDs99CCt3ce6m0key5KiZadNwbrJ7JFB5NTW0qg1tAU1vSCjT9vUgTVDeZEFdhxnl/zsQZ2hiVumpAY2nfH5ccIbuYAroCtI9IMxEfkWYMlaFNi0F8JqUNEFFBxRmCiAKJCyyyBpdW0WaAn2izdH0vLhqcvxH+cU9+4LVtwwgckWkCAMy68C3IBww6oSFGm4lELDRiijeLqkDT34o0Qa87bIEBRCAyeGIXGR2QiwxeXPDPXcWbyQovsm2DSDORtMHHtxmcYwD1myEIgiCqiaB9Z8TEARFt4oAM42gzhSmjOGkpW3XrvZrmjufHY29uACY078HZ002jerxWC0doauOxMbU5TB0AghVoojK2VZtuCnrdNjoziG6SYRof7jzOjIf9vWRMXxCWoU0dafYf5z1d3P7cg2dz55j06JQ9GpKG4rvWGzFtQIVpDCaD0gaIuKDiDEE4xFVl3WmDSx7PaLNWbp8q2qxSaFw5dDma6/rwdudILO4cXdy/69X8AGr92AyaWk0nUBMqNHS0C88jcoIB7go0gH+xAVSH4HBxjVEJjEjdX7ZUgchQY5ubLB7XEFFTy6CQA4wgCIKoVkJLHOBJC8+tEgcAP9Fm+9q7cc+hEwAAnzhFjDbzinBV9esUSaCpzYDA8VMFTCb2ozC2AcnXTa6uz0YzhaGbZEQSZ6aLSk+goe3T80sFjH1dzYoLV/XodBxpFkfaABnaiARAxRmCCBG/LjAvfDW4TAvnOcpQrkcW16WXAgD+uGc2gNIk8JN3HIGRvVnsq6/D2dctFV7pQmgkaPVMu+R9OMJyggHuCzRBxAZQPqCPW3S4vA6b30tUhRkZnv/W2oXniRcZgLuVMipUTlfHkWYaTN291G+GIAiCqGmqInHAbbQZ0Iib3syPHS8Z/zaGNu9XfTCHKl1AROxREYOpLQHxZoDbAg0QzNjGqFXN5MrMBiTI0Gbbn7MKDG3zx20tbl9z50XcOTY9Oh1EmvHItFOb99uaQP1miKRCxRmCSAChNLjUTQimdW/GhIaMSof5GYM2YlzjAezubcGD+6eAvyH39bXi8HdHAAAGHLMGZpOqNkJD9dxSaGQKj7zQ0E1GM0KINwPCK9BELTYYURVrxM9x9Vm2AiMq5xdg6P4yjTNTEZrIMEWWm+xV3DUhhkgzWVNLcoARBEEQ/ZCoEgek5ofAiQMygkWbPfviQLyTHY/Wuh584ARV7xxbU5vsPEaEpjaRduF5lRdoXOmmKA1ucWsmoMoKMyJ++nMm0ND28IfuKm7f8fYMyRm6SDNZ+oAhaeG5WBhnSL6rxUgzVdpAUEMbpQ0QUULFGYKIioAuMOsGlzz8zc4z2ozf5x1t9sH0mwCAO/ceie5c5U35jXuOBgCsHNSDCbN2e/wAMQkNHtlAKmKhoaNaxQaPrIhiIg78vs4vSREYQK2KDObk1P3dA+5WyiQk0swx5AAjCIIgaoGwEgeKRJY4AOTHD15jjHKyB/tw29bjAQAfn78U8n6dIqZjJIerZzKFR5emNgnVVKABwtFNXtonabrJhigLMzJMUy2KBO3P6Un0hrbGuj6MHNAJAHhx4zjwCSjlqCLNgMrvOotIM13aQFSRZgSREKg4QxCOMaqwa4QGQ3SBMYwaXJoKDSsqo82mNe3EiQM3oS+Xwu2Z2dJXLXtjLGYeTCGXSuHE9y+GXGiIg44YhQaPi3gzx04woDbEho6oRIQKPz9vUIGhw3dhRqRdeJ4okaGKLhRFRhjxZqrcZAeRZjLIAUYQBEEQvgk9cSBQtJk4dlAZPuTRZje/MAO9uTrMH7IBs8bsUn1wAdkYqXZMbTb4LdC4Th4A3KYPmFBtusn0d+qyMOMraaDd+yVKqsTQ9rWTXy9uXyuNNLM1tPmINEt7vyTUSDNKGyASAhVnCCJkTPP+mdDw7QJrs3sZ0tC4wHihIVK68V6Tzl/rEwemYEvvYKiERvat/MXtmrQNqVTW48ISKjR0ROwEA9wUaJJepIkavz+fi8JMKO6voHFmPKGJDBmqRrb8cd1zE9R9tJxGmpEDjCAIgiDs8Zk4IEabaQkl2gxQR5t5s355Fo9351MHPnPGCpCpzTx1wE+BBqgNY1schGlmC7swY500ENTQJptvKCMeQ9v/PfuF4vbGfbrvNIZpj04P0sJzVdqA5jvay9AWFDK0EVFDxRmCiJIoKvP8TUzlAjOONmP7KycyG1CHCwbnJwjvVKyaYTx229EY1JfF9sYUzni/eMOsAaEh0i48D8EJBgQv0AD2YgOoPcERpCgTR2HGufsrVpEhy02WHRNfa7tqRuYAkx3nH0OMNGsze2svyAFGEARBVDNBEweYqU01IVeROJCoaDPJOCMH/H7FewAAH5y2GE31pgYUlZ7qP6Y2IPwCjd8iDemmcM1sKnwlDYi4MLQxMoVHJ6tmgCCGtsnpTHH764+dqnkP0x6dQBSRZtr0GI6gaQMEETVUnCGImNG5wHw1uJThrN9BacLypIHrMLyhA7t6W/HSoYnQOSYO7GvB5B15N8bYk5fDfGI14UIjZicYEHy5PuBPbADVLziCXLvp7yuKwkwg95cKvjATmshQEcWqGX5bFmkmEl6kGSNsBxhBEARBJB3bxAHftEn2RR5tJj/+t6fHYEtuOEY0HMQVx6wTzlX3jyg/RzS88FSfqS3svp1AuMY2oLrNbUE0X9hmNiBgBHTVGNrYPltDm5zbr3yguP1fL7+HO1JbkWaUNkBUC1ScIYgQcOUCM8bLBcaTFrZ9RptdNHgZAODh/bPQh+ayY3nKb9grHp4LAFg+rAvjJu5VfKZfoeFg9YwY0cTjNfhixOAEA+IXG0D1CI6gBSUbgRFLYcYL8d+oSmTIyBQeQ10106s4P0h/GZlYSE6kWVQOMFqeTxAEQVQVUScOqIg42qwr04s/7joJAPCZk5bCbAwkM63JjqvGZQ5NbarxZABTmwqXsdCAnWYKWqRJum5ycZ1hm9kA8whoKa7jzBJpaKucX6lLZXHc4fkfaMv+gejJ1mveh2Hao9ODtPDcR6SZHyhtgEg6VJwhiAhIpAvMd7RZA1pS3TinEGl2/74jjS7tjacnYnIH0JtK4bQPLII8Q1nEVGiI+xgWQkOGSmjE5ARLutgAkik4XF2TK4EBhFiYce3+Yn8DGcnr2T6nIoMh5ibz234LNSFFmukIIdKMHGAEQRBEfyW0xIHYo83KJzh/89KRyOZSOH34KkwdvsfjM3RNwlVaKkRTG08QU5tB6oCKKDQTEMzYBiRPN7nUTHEWZnwlDYgEiTPLFB49+3PGY2j7yLy3i9sX3nY5d0S26s/W0BYw0kwD+w6nfjNELULFGYKIGp8NLhmi0JAicxo4izYDzhi0GgPqerCpeyje7DwMRhnKABqWtwEAdk3ZglQq6/EpYQgNj0lhmdAImqPcLjyvggINEFxsAOUD/ChFh+vPdS0wQukxA7iJM+OR/dv3FONBRUaP4nwZUUSayUSGQBrlIkOMQmEYOMCMGhgXIAcYQRAEUQskInHAhLTqgGm0mS4+lacR77xZh6d782PPz5y5AnJTm6xfp9fYyNHqGRmuTW0SotRMURnbGEnQTUGx1UyxFWZktGuO+YkzMyYMQxtDrqd+e8ljxe3FW0dLzjA1tOnOFTCNNNOkDZjC7glSQxv1myESChVnCCLhiC4wEeMGlzxp4blRtFnppPMHvwkAeGD/XABsQtg7Q/mxP8zDgGwWW5vqcNrl78CN0LBZpi+gExqyfaZCw9IJpqJWxAZDFB1BhYDr9xOx/fmjEhiAA/eXK5ERyaqZIPFmXrnJUOxnDjAZBrZZr0izAqIDjCE6wKwnngTIAUYQBEFUC5EnDvDmCV+JAypY5I95pFmRvhx+t+pYAMANMxejsa7P4wUmpjZxDBXC6pmQe3YC0RVoAHtjW9i6yfX7udZNpgRJGQAiSBpQxZmpMDG0OV81ozO0qXt0jh10oLj9oxeOlbyWYWJoCynSjNFm9/ZWkKGNSBiWf00EQZhycS6H+1KlG/99q4CLpwsnPQngLPnr5+ItfwO8NlQOMEaiNOk6CuUDiDSEyVd+4lEcRTSgAX04eUBeNT1x4Ajkb8xmE6YHdrdg2vYhWDL2AMaesgK4y2tA2IvyrynZ54hfYx0o/Qz89j7kBRS/T8JBVAqwHVC74RnboJ6EbYfn4GL72gnSuIWNmCBdJbUG05Qu+9WYqizmMVZipvHSX/bvMOgEsYqkLOVn2P7dhSUwVDh1f/F4ZSYnVGT4Q4wwcxBpljZ4TVvh0adrlxxgBEEQRL/iEQDnle9KPQPkTte/bBrWYDWmYipWYw2mYQI22jWQ57VTGuVaqRVmXpMyGlEauzQgP9ZpKOxrFJ7nz//r02PxH7PSGN2YwaXz2nHnoqkojVF6uffhxy3sfXQwUxsbyPDvodJR3NMMKsc8sn07UZpw5bdFtqC8QNaOct20KVUxbpLppiRoJiCvI/qLZgKSo5tCKcyIJH7VjLeh7bYrHixuf/fpk7gjQQ1tDiPNJGkDQSLNbFJqCCIuaOUMQSQMZ/0EfDe41HN06zoMqu/C7t6BeLtzInfELEN5xUNH5x+HdWLspAzMV894EfPqGRFHTjAdUbrBAPeOsKTh5+cLKjB0RC4yZGQ8jlegExkm6ESGLbJIER5RfPiMNOOJqKmlFHKAEQRBEP0UNhnuNYnuK3GAJ606oIo2EycxTaPNgM4dvbhtd37y9NMnLTW4OAY/hgq6esYwEponaLyZjBhTBwC75AGgpClIN5Uw/R26TBqowPDfURGb/pw8CTW0pZDDmZM3AQD2dTWhs9dmVZ/DHp1pxXkOI82kaOIxGZQ2QMQFFWcIIg4s+87E2+CyXGgcPyB/DS8fmoYc6mCbobzomQmY2gH0pVI47QNveJzPE4HQyBQeXQmNkJfqA9GLDaD2ijR+fx4XhZnI3V+xiowOeIsMGWKhxnbVjMoBpspNtow0i6GpJTnACIIgiFpCnBCTrgA1mFizpk2yz2m0GaCfxNSvcvn1y7MBAGeNWIlJ6b3CUdX4KEGmNh028WYSVKY2vwWaMIxtQG3pJr9FJ1PNFGoEtIx24XmQOLNM4VE2jyAlekPbDfPeLm6fdfOVmjNte3QyNOkkaeF5yJFmpoZnShsgkgIVZwgibsJocNkmOaYSGoDEvaBeinpUSzsAYHHHZLtr42h4uw0AsHvyFiCVFY7GKDRkZLht2xxlExJQoAGCiY1qFRxBijIunF/OCzNeVIXICNJfRoeXyBDP8YAcYARBEAQRO6EmDvgytenQxajKEgcaseyNejzbMwd1qRw+c5bYr1NHFZjaRBymDvgp0ADhGduA6i7SBLn2qM1sQELizNi+BBnafn/po8XthVvGckdcRZoBVpFmMjRpA6KhzQSpoY3SBogEQsUZgggRIxeYhCRHm81u2QAAeKtzIiqjgEyizRrx6K3zMKgvi61NdTjtioQIDdEJ5jUh7UpoAJEUaMIUG0D1CI4gBSWb34/fHjNWeegi7cLzxIoMHtU+cdu2UGMrMvhjqtxkgbTwnBxgBEEQBOGesBIHnEabsclIr2gzSJ5r6MvhdyvnAwA+csQbaKjrU5zoZ3VxTKa2CFIHgHALNIA/YxtQPea2oNfpwswGxFSYCZo04En0hrYJQ0qf+YPnjjf8nJB7dPJ6SWNoU6XEmKQNEES1QMUZgogLS6HBiD7arCQ0WlJ1GNWQv7Gv7ebf1C5Def/eFkzbnp/sHHPKCs2ZEQoNGWKUExBOvJkFfgs0QPhiA0im4HBxTTZFGdeFmdoTGXxhVdzHUIkM20gzFbpIMxXumloyVJFmOsgBRhAEQfRbXCYOMNok+6yjzbyQTWaa9eu848nDsCM3BGMb9+Kyeesh79cp0oPYTG0ZVKIzsoWUOgBEU6CpJd3kSjO5MLMB3v+fRCJLGuDJaPY5WTXDCGZo+8tV9xe3v//MidwRmx6dsudsH5szkpAWnodkaLNNGxANbZQ2QMQJFWcIokowdgQ4jzYrMaZhFwDgYLYZ+7ODESRDefkDcwEAK9OdGDd5LxIpNHhk+3R4DezahecOhAaQHLEBlA/woxQdLj83KoERqDDjRaJEhohOZKjOscUkNznaSDNtMR2l73vfE00EQRAEUWUkMnEg1Ggz9lx/fsfuPty262QAwOcWLDH4jJhNbTxhxJsBkRVoal03uf5Mm9+Ba91krJnahed+kgZM+nNWwP6ebFfNBDe01aeyOGH8VgDA1gMD0NUn0z0mPTpF7SQryIQTaWaDs3sCQUQIFWcIImSMKvBhuMAYDqPNBtR1AQAO9LUA4Ac79hnKC5+fiOmHgL5UCqdc94bBp9eY0JCRsAIN4EZsMEQBEFQIuH4/HtufO9bCTLvwPNEiQyyoMmTCQvybN/0O8JubDFRVpBk5wAiCIIj+QlyJAyrSsp1e0WY2RZlyfvHCHGRzKZw5fCWmjdwDualNHEuJYy1xf4JNbT5TB4yawHOYrMyw1U2uiEI3ucLWzBZJYUZGu/DcNGmAR9afU4Z2yiFaQ9un5peKuuf/4QrNdUXYo9Mr0qwA+65WpQ3oDMyUNkBUExZ/WQRBuOC+VcDF073Pm7d7ORYPn6U9ZypWYw2mYQI2eveqGIPSgGMkSgOONMoHFq2QDCbyd9Y+NAEAGlJZlL4++IFBI2xc7qmlbcDx7dg9eTNSdVnksj0o3fx7C5/B7wP3/uLXF/+5/PkdKI0M+O19yIsofh/3NINK4cXv24HShOxOlAYY/PY2lA82tqC8WNaOyknbTamKDOztaydIReRGTCgOVETWYJq2WR4bFJvEKAElsRFGpmtSlvID9qLKRLDFWpixjTPLSC8nTyCRIUOMLfTbX8YUMT4k3kgzgiAIgiDcMhdvYQmOwgy8YzamG5/Lj6/aUDmm4vXSKJTGSwMhMW9JBZQEXucAJa3TUNjXKDzP886SOjx5/lyc0/wm/s85y/HFP5/s8TlMRzFUOomdy+shXnsJOsmLg6gcH5loJqBSN4m0o1w3STSTCi/NBOibjK/G1ERoJoB0kzFeEdAiuqQBr/6coqFNmTQQj6Ht5xeVHF1vbhvNHQnSoxOoLERLSAvPTQ1thn/bDNtIM4JIGrRyhiDiJGCDSxFtg0sZltFm23uHAwBGNOzHgFRnYa//DOVHbj0Gg/uy2NZYh3OuVi0/zWFMwwGcPWgtPjbsNfzz6Kfwn4fdi9+N/zN+P+EW/Hr8zbhx7O347IgnMb91LRrAGmWarJ7hB0gSMoVH474bEvw4wSJaQQPYucEAtytpkoSfnyvSwoyMduG5rjBjQobbdi4yZDGE/DERsVBju2pG9Tw5kWbkACMIgiCIcoImDgTGK9pMxLNuoTrBpl9nI5AFfrX8OADA9dMWoaVBFnUEmK+eYa+JaPVMwlMHALcxZ0DtaibAX8KAyWqZSCOgXScNGBNk1Yy9oW3a8D3F7W8/dZLpRRbQGdpk323VE2lGaQNE0qCVMwRRBTAXmDVt0LvAeKQusHIyfUOwpWcExjXuwokD38GTB2ZzR5mw6OGe6wYMjTi4H2jbmsZbh+/DsJOWA3+ejYF1BzGnZS+OatmJuS07MLd1J0Y3HNJfWJEnsKUnjd/uPhN/ypyB0i1WtXpGdlxxCnOCZWC3egawd4IpCGMFDWC/igYI3xUWFX5Ek6kwc1qYCer+ilVk6Ahr1YxtpBmfm5yMSDNygBEEQRD9nTATB0ZP2aged/GJAzxplIoO/HYRcdzQg8pVNf7GOvc8NQob547EhPqduO74Nfj9i0d4vEK3egaoHB95rZ5hqQMKMsj/TvjVM2yfKRGkDgAIrJv8aCaAdJMKr8JYbEkDKjKafbGtmpF/r/z16vuK2//vheO4IzpDW8g9Oh1FmukgQxtRbdDKGYKIAKMGl5LJtmQ1uMxPVj6w71QAwGdHPITGVA5BMpQbkEXmzmk4d1EWpz/XhYdn3o1Xpt2B3094BF8ZtRDnDN6A0Q2H0JtLYUXnCNy/bzp+uWs+vr/tdHxj84X4yuZL8PdbLsJ/7DgTD+ybgz29AzCuMYN/GnM3fjP+fzAgdQC+V88EzVEOwQkG+F9BE8YqGqDknKo2Z5jfa05EYUaGqfuLxyvOjO0LZdVMD3dMxFWhRhQZ7FHXI0sGOcAIgiAIInYcJw4UYRP7bZJjvF7izRjScYBp/JcYD2SSONCInn19uHn7AgDA509YLHnfIKtn+ONeE8eK1TM8QXp2hpw6ALhJHgiim6oJv1rP9HeUqMKMSCJWzTBMDW2lv+Wm+l7MHZP/IdbuGYquPllxxVWPTglp4bmY2qLCMAWm2oudBMFDK2cIokqZhjVYjalqFxjLUOaRucD4DGVA4gLje7QAf8hchKvSj2F2y3r84vCf45+3XoctvYPhlaHcgDq0Ne3EjObdmNOyHUe17MCRLbvQeqiPE1sHgRTwbs8gvNUxEks6R2FJ5ygs7xyNjhwvZHjxUppsbUoBVw59A18Z9SBOGrgSPzv81/jkpi8iC8B69YwMmxxlEUdOMMDfChogvFU0jKSvpgkihlwUZQBHhZl24blfkSEjkatmbCPNVCJDFBZ2zXjJAUYQBEEQycd34gBjHMyMVAxtuxmxDyY/1mHjFrZKxduU8qvnjsA3rrwf84dswHvGb8OiTWNgN54JafWMy56dJrQj0AoaIHjyAGC/ioaR9NU0QQtIYZrZgACFGRE/SQMZyfuwfaGsmvFvaPvKSYuK2xf/6VLpOXlUhjbxmG71jGWkGf/3zuZE2jSXqCFo2gAZ2ogkQMUZgoibRwCcV74r9QyQO11+unGDS0YbnEabbe8dga9s/hL++/B/w8kDV+DRKd/Fy4emY0nHBGzrHYSD2Xo0pnowqO4QRjXsx7jGDKY378SUpgwaU9mK99vX14gtw5qxdM5BbB4L3Puz9+HdA0Mhd5SpyA9IunMNuC1zPN7smIibJv4KJw1cievST+GPmbNhJjQU8WYZuBEaVVKgAdwUaRhxiQ5XzrSqLszELjJ0q2ZkuI4341Gtlokg0szSAaYVGQRBEARRw1ycy+G+VGkMJI02exLAWeW7TKLNrOBNbbx2SsMi2kwcF/HjD9lYJ29qKxVr2PM8G5fn8GDnsbis9RV84axl+NgtKkcI0zss2oy9lwzeBMPOCcnUpoPXTV6aCYikQAMgVGMbkJxCjQvdFLZmAgL05gT8JQ3wZLhtsT+nElm/pvBXzQDAjWc/X9xetoMXKqaGtvgizRiioU2H37QBgkgCVJwhiCQhERqMUFxgKqEBKFxg+TvtS4fm4+r1N+IfR/8GJw5cjgUDV2LBwJWel3CgrxGruodheedwLOkchbc6R6K9ewBS9X0441N3YHtjHY69/h28+4vjFe9gJjTe7pqAH2+/EN8Zezc+PeIR3JE5vTCE8RIaGjLwl6MckhMMCFagAcwGOX4dYTxRFWtcxwTYxBXEUpjxwqThaobbdi4ydIgixNWqGdlzVdPdKo00IwcYQRAEQRTRmdoYvhIHGLxe4hMHpKY27TIaAb5fJ7+KRnN+rgf/++Z7cNmJr+CaSW/iS80nY18X/15inxkeXj/JroOdY7p6xrGpTcSkQGOI3wINEI2xjSHTMtWgm6LQTIDjCGjXcWZKQ5tqn82qGRn6Qs0xY0s/4OceUEwwKVFF1jsytOkozHmo/l7Z35fu74LSBohqhIozBBERRi4wCZE0uLSMNgOANd3j8bFNf4/JTetw4oBlmNW8AcPqD2BgXQe6cyl05Bqwo3cQtvYOwOquNFZ3pbG5dxAqnR9Arq8Ho9eNw/YZ25CduxbA8SgXA/ZC4869x+OzI5/A6IZ9OG3Qm3jiwHwEEho8tvFmITnBAP8NL4FoxQaPiRgQB1xR5jO7FBhAiIUZv80sIxMZJqtmZMWXIKtmTHOTZcc0hZoQI81cQA4wgiAIol8gSRyQwUxtvhMHnEWbyVacAOXjH7NIM8ajzwzFO8eNx8z6TfjYaSvxk8dU5j0/q2fEGLaQTG0JSh0A3BVogP6lm2z77kRWmGkXnkeeNMBwtWqG7wtlZmh79qN3FLd/tXAud8TG0BZSj05d2oAllDZA1ApUnCGIJOAz2oy5wJQwF1gbnEab5cnfoNd1T8a67rEoH1zwE7GyRpOVPP3H+Rj33QewrjWF485Zh9cen6w400xo9KEej+4/Ch8a9gJOGLCqUJwB5ELDUHBkEJ0TrB01LzZ0xNEss2oEBuDG/ZWRnMv2ecaZhbFqRnyd6aoZEdvcZL6RJe8AE74T0sLHOIo0IwcYQRAEQfhAkzjgjMDRZqpZW9lKmfJ+nfJos0ZkO3vwu40n4kdtd+Izx7yBnzw2B0FNbc5Xz/AEiTeTEWOBBjA319SybnKtmYAICzMiJnFmPEbzJCaGNvEcm1UzDPk5Q5q7MKgpf+zRNZPQl6uTnBXE0KbB1NDGiDhtQDS0UdoAkRRkf6UEQSSUwM4A2c2Pv0mmhWMV9Qo2YSkrZLCJTvEGLk6KNgrP82xdPwSzduXfd+r5Swp7+QGHbqKWd5KUBjaLOiYBAI5qWY+S0AB3rohs4rnyKYDygVlGcVm6WClxMt3UnadYtq3L39UNdoH8gNlk0MxYjanWg/IkY/vzmPy+QluSD/h3f/FkuG0jkSEjjFUzfpDlJvMEzE3miSDSjBxgBEEQRH9HnDAzXSGq6znAJsnZ5DqbiC9O1uvMFLxe4s0Z0nGBxQqTijGJ2UTobx6fhgO5Fsxs3Yr3HrnJ4BUmYy7ZWE021jOMbct47OMNQ7oJcp0BidEu2ReCZgLMCg08taSb/PwssRZmZOiSBiI1tAHmq2bE1+sNbT+/8Ini9ofuukDz/jJkhjb23GekmQxZ2oAQaSamDZgUOqWGNoKoAqg4QxARYlSZl1T4TZqbKYWGDNnNUMwBTYsnhCk0CqtwHs0vuV05/BBGjd9v8Dl6obGhO6+kxjXuFo74EBrscMbjkmSRUQwvZ4446d6uOC9hYqMaBYffa49UYABu3V9ecWaZwmOkq2b4/eL7BcGrUMMQRYaEEJpamkAOMIIgCIKQIFk5qpuQs+7Z0WZ3ehGlVBInM5mZTaeV9MaT3e9mccfeEwEAXzxDZ2oTx1ViPBJvpOGJyNSmKtD4MbW1S/b51EyujW1AdRdp/GomEzNbVSUNGPfnZMjmGfysmjE3tKWQwwfnrig+33FoAHc0ph6dpmkDhrDvdGNDG6UNEFUAFWcIIkaC9AlgNyVPBwFzgbVJjqmWjBsvPZcJDfG4OS89OhlTOnLorkvhjOsXFvb6Fxo7evMDhZH1+wHkoBYa/D6F0ODJFB5thYaIbICYgAKNrdgAqqdQ4/caTX8voRZmZLiOM/PERGTYrJoxy01WYyIyVJMgstUzjiLNGIqmlqIDzDrSjCAIgiAILc4SB3jzhSpxgN8uIhlPVCDGBqlW9lYWcn724jwAwPkjl6Ft+F6PzwEiXT0jM7XZrtSOsUADhGNsA6pPM/nVTV54/X5DL8yYJA3wWPXn3CceQPBVM/x+2fvl+dDckrnrzJuvVLx3bUWaSZEY2ggiyVBxhiCSgqELLNnRZkBpIlQmNHTRZo0AUmhYmu81s2vKFqTq+hSfwaMWGt25/PvXpXKoQ7dwVJcF69MJpsLGCSajXbE/JDcY4L9IAyRLdPDX4vd6TAVG6IWZKOLMEigyzDARGexYBJFmbf4+ghxgBEEQBJEnMYkDjNCizQC1a13N4oVNeLp7DupSOXzx3GWFvVGsnvFpalPt05naqqRAUwuaCQh+PS7MbEDEhRkR33FmIq5XzZgZ2m65/OHi9tPtut+z3x6dyYs0M/nOByhtgEg2VJwhiCQSRaU/9GgzwI/QeOSWeRjal8X2xhTO+YALoZGnDuzmayI0DMh47LOJN/MrNIBQxQbgzxHG46I4EtfnxSIwgHBFhoxM4TGhIiMYvMgQC8YykSHgMNLMBtNIM4IgCILoL3glDphEmwVKHDBBKpVk4w1xdS+PWb9OZIFfvH0CAOAj0xdiQKNoRJPhYvWMuC/EeDMTHBRogiYPAMGKNED0msnlZ9r87F5mNueFGS+cxZnJDG1sXw/K/yDCMbTNHFGKcv/eMycA4H9nsqhE2bZtj854Is1kSO8BZGgjqgQqzhBExMTa4LJN8mKTaLNAGcoieqFxcH8L2t4dBgAYeqKZC0IlNAbW5UVKV7YBvcihsj8GP2Di9wOhC40qK9AELdIwRBHgVwy4eh8ZrgQGEFNhpoZEhp4wcpPDizTz09TSC3KAEQRBEP0Swwm3SBMH0hBQOMul+OvXedeTo7A+Oxrp+kO4YcHqwrGoV88YkLE7PXDqgArV+BtukgcAd7rJpdYJSzfZaiZfKQOA9v9bGTL9HEucma4/J49bQ9uDH7y7uH3j88crXqNC9/3jsEcnI2CkWeDvdoJIGFScIYgkEXaDS0YgoQH4z1A2Exov334s6nI5rBiUxZwT3y0csxcag+vzo6n92WbhPH4gI8tR9sAmRzkKJxgQqEATR5FGRCUawizCiLgUGEBIhRkvbAozPAkVGWaElJvMY9yHC+QAIwiCIIgo8RltxjCKNvNKHLAZJ5SZ2kTsxiq9B7L4zbsLAABfOG4hABOThp/VMw5MbZnCo4mpTSQCUxvgztgGhKebbDVTGLrJ9mfzrZkAt705XcSZGaOaYwjH0DaoqRtThuV7T72xZRQ6e3W9rHSGNjF1IECPThmatAGVoU2H37QBMrQRSYOKMwRRhThzCphEm1UQJNrMjDVvj8KsvU0AgLmXLTJ8VaXQGNOwHwCwo3cId45qYtlCaPBkDPcxbJ1gjgo0YYiNsAo1UROGwAitMBPU/cWTsTw/YpGhx0tkyB4BfW6yQCvUhWpZpBk5wAiCIAjCKX4TB3SIiQMV+Ekc4PGMNuMJEm2W3/e/T87AoVwzjhywBWfOeBeVuFg9w5CNBT3MbVWUOgC4NbYB4ZrbosaPZoqsMBNm0gBPQg1t3z/zheL2+++4WPGaiA1tppFm4/WFEvadbW1o43Bx7yCIMKHiDEEkFUsXmHG0mQzf0WZhZCjn921+ci4A4J3RBzBi7H7JZ3sLjXEN+dzVLT1DJOfFJDREIirQAO7FBlC9gsNPgSlWgQG4cX/J4swSKjLMcJmbrFgRqBIZOjwizXSQA4wgCIIgDKiKxIEwo83y7GjP4c79JwIAvnQmM3p4GV/8rp5hyN7fp6lNh5fxKKYCDWBnbAOqVzMB/q7d5PeTuMIMj0zX++7PyQjL0JbDl058o/hsXSbNHfNraGPPHfboZERlaKO0AaKKoOIMQcSASxeYtdBoKzw6iTYzwZ/QeO7eqWjrzKGrLoVzPvp6Ya+d0GhryhdnNvWkudc7FBpe8Wb8/iBOMCByseFXcCRZdAS5xlgFBhCh+0t1QpJWzZhg/51ThskCwQBNLckBRhAEQRDREFnigHG0mczUJh7n0ffrBICfvXwMAOCiMUsxYei+iuMlwl4947hnJ2C/StyhZgrT2JZkzQT4v05TM1uoukmH7t9QxnBfEVl/TkZ0hrZLZ5b6WF55x/sUrzE1tAFOenTK0ESaMZihzQS/hjaCSCJUnCGIpGHoAotXaKhmLm0ylL2ERh0a35oCANg2ZQvqG2UOL/3qmZnN2wEA73SNkpzDvx7wJTR4MupDUmIu0IThCGMkSXQEvRZT0RVaVjJgV5gxIaM7mAyRoUf33aHKTeafyxxgkmJzWrEdZ6QZOcAIgiAIopKwEgfaJC+WmTPSwvMKqcQmM1Uaio1f/PXrfP3VZjzfcyTqUzl88b3LCsdk/TrhsU92jo2pTYNfU5tIGJopBmMbkCzNBAQ3sjnRTElKGuAxThrwMrSp/u78G9ruufbe4vZdy6crz/MmvkgzsecX+25m39VBEA1tlDZAJBEqzhBEkvHZ4NJptBmP7wxlXUM6GfnzH7rpGAzrzWJnYx3O+/CSwjHT1TM5HNGSL86s6OIrUKarZzyXFKhP8+MEkyETGzLaFfs1QgMIV2ww+EF+FMLD1efZCAynhRmvZfki4r+hqhAZDJerZnRCwiTSDHAWaVYgykgzgiAIgqhVbBMHYok240mLO8Lv14m+HH6x/AQAwEdnLkRLg834K6LVMzwZzeUxokwdAAIXaIBguilqzeTqM21+Zl+aCYgvaYDtN5wWKJ0ozjOIWsmdoW3qsD3F7f965RjkwP8evSLNGoVHhmWPTp40t+0w0sw6bYAMbUSVUXPFmXfffRc/+clP8N73vhcTJ05EU1MTxo4di/e///145ZVX4r48gigSRsXe2FnQVniMLEMZqHS0e9N5sAkTNuRnRBvme2X4lA9gxjbsQbq+E725FNZ0pyEf/IgTy7KRl0+h4cIJBphlKQOhFmiA4EUahigC/AoCF+8hw5nAAMIpzAR1f7H9sYoMP71lGLa5yeJ3jSgyJASNNHPQ1NILcoARBOEC0k1E1WM4AZesaLNw+nX+5fEx2JgdheH1B/HRU9lAIejqmRhNbSIJKNBEqZvC1kyudJMJvs1sQDiFGRMyuoMm8wPivnAMbQ998O7i9j8+sUBxlirSTHaejx6daeEUmaHNINLMBGff5QSRMGquOPOzn/0MX/7yl7F27Vq8973vxVe/+lWccsop+Nvf/oaTTz4Zt99+e9yXSBBSYnGBMULNUJbd4M2FxrN/OA6NuRzWDsjh2Pe2F455C42ZzfkR2LruEejOidcgGxzphIY4uNLkKGcklyNi4wST4adAkyCxoUInHsJ2k9kWZZwuyQfM8pJdNbOUErXIEM+1xU9usuy7KECkGSOESDNygBEEEQWkm4iaIqzEgTbJiwJFm3nhr19nz/4sfrnpNADAl45/DSnojBumq2dU5wC+TW1B4s1iLtAA5sY2IDzdlATNFLqZDQivMOO0P6d4YnSGtqHNnZg+IgMAWLx1FA72NHFHgxraLFHNF1lEmolpAzrjMaUNELWG7+LMBRdcgLvvvht9fX0urycwxx9/PJ5++mmsXr0av/nNb3DjjTfizjvvxFNPPYX6+np89rOfRVdXV9yXSRB6HLrAoo02U8GizeyFxqa1aRy5K//eU85b7HF2aYAzpyU/SlvRNYo7ZiM0ZBPLmgabSXCCAVUvNqLG9udwLjAA9+6vjOYYkBCR4TLSjEf8jjEVGj4izTQOMFFkEATRf0mqZgJINxHVg220mQ5marPuZRBptBlg368T+OUj07Av14oZLdtw4dwNhb2uVs/wmJjaIkgdkGFToGlXvIeBZupvusmPZgqUMtAu2W+imUT8FGbYfuk/Xd2/b/5vQdzvYtVM+Xv87IKnituX/OlSxev9GtoMenS2IhJDm3WkGQelDRDVgu/izCOPPIIrr7wS48ePx7e+9S2sXu2dpR4FV1xxBU4//fSK/aeeeirOPPNM7NmzB2+9RUvhiCoiYINLJW2Fx0ijzQA/QmPlA/MAAO8M78DhMzKFvXqhMb/1XQDAokOHK65DHCT1Qt9jI2ShkaACjR+xUU2Cw881BxIYQLiFmaoTGeK5pui+K8TYRJlbzEdusgwfkWbkACOI/ktSNRNAuomoESSmNtmEXXVFm8kmTHXkj+/enMMfdp8CAPjaGW94vMbl6hlL042xUaiAbeqAC93kkTwA2BnbgOos0vi5ZiPN5CJlIEgEtIqM7qCf/pw87lbNNNT14fqjS9ph4z7bORmegIY2niREmlHaAFGF+C7OrF69Gt/4xjdQV1eHH/3oR5g5cybOPvts/PnPf0Z3d7fLa3RGY2P+C6ehwaQxMEGETyIaXDJiy1DWkX/Na0+1YebBFPpSKZz6wdc8XtODxlQf5rVuzb+2YyzMhAZDNjFtiCrerMoKNIB9kQZIdqHG77UZ/x6qoTAjJQ6R4XLVjElusu4cj9xkfpscYARB+KAaNRNAuonoXyQv2gzwGzX0X8/MRV8uhTOGrcTcw9nA0fXqGQemNpt4M5GoCjSAc2MbkGzNBISsm8LSTIBdBHSG2w4UZ8bQzSeojGv+DW1fOH5xcfu0318tnOtlaJM98scte3TaRJoxPCLNdJChjahFfBdnpkyZghtvvBEbNmzA3XffjQsvvBDPPvssPvjBD+Kwww7DV77yFSxbtszltQZiw4YNePzxxzFu3DgcddRRyvO6urqwb9++sv8IIhbCjDbT4SvazDZDWbaMVk/HKzMAABvG78DAISxiQy40jmrZjpa6PuzsbcXa7mGKd+QnlSvfI7DQMMXLwRNmgcawSOOHJIiOoNdgXJSJqjAjYuv+SrjI0BNjbjI5wAiCCEC1aSbATDeRZiISic/EAWOcR5u579e58q16PNBxLADg6+cvUVwcw+/qGYYu3kwk4tQBIPHGNqBcr8Slm4JeQ+xmNsB/BHSohjbV31dQQ1sO/3leyc313IbxyqsuoYs08zK0OezR2aa7xhLsuzmIAZkMbUQ14bs4w6ivr8ell16K++67Dxs2bMD3vvc9pNNp/Nd//ReOOuoonHLKKbj55pvR2dnp4np90dPTg+uvvx5dXV340Y9+hPr6euW5N954I4YOHVr8b8KEZLoZiH6G62gz0QUWONpMBS80dHgLjYf/cBTG9GSxr74OF3z8de27HduaL0It7DgMQApyocEjDpIizFEW8ZqQB+yFRrvmvUIUGwxxwB+G8HD5Gc4ERrvimJ+8ZCAE91eSRIYpLnOTBcRdacnHU1NLgiB8Ug2aCTDXTaSZiKip2sSBiPt1AsB/vnocAODqw9/A2MEHCnuDrJ6xMbWJ+wx7dmYU5/D74y7QxKCbwiAWzRRnYcY2aaDin7NXwVE1jwCUdJJMK/kztJ07ZX1x+7MPnC2ca2NoA+wMbW57dNogM7RJv+PJ0EZUKYGLMzzjxo3D3//93+PGG2/EuHHjkMvl8OKLL+JjH/sYxo8fj3/7t39DNpt1+ZGeZLNZfOQjH8Gzzz6LT37yk7j++uu153/rW9/C3r17i/9t3EjNfInqIf5oM90SWL8Zynmy2QYMW5Ef/O2dtR51dbLJ3vy+YwfkR3KvHxoHM6HB4zhHOcNtR+EES7DY4JEVU4L854LQBQaQQPeX+DwekaFHJzJkmOYme0Sa8UQUaeYFOcAIojZIomYC7HQTaSYiUYSZOKDoJwfAO9osLR6Mpl/n088Pxms909CU6sMXz1Mb+vKYrJ7hcWxqk817y4xGjDgLNICRZgLc6SbXmsmFbrL62bx+X+2K/WEUZkzQJg149edk2CQM2BvaHr3+ruL2rxeqUoFMDG1ikSbkHp2MAJFmBFGrOCvOrFy5Et/4xjcwfvx4XHvttdi9ezeuv/56PP744/jRj36EQYMG4Zvf/Cb+/u//3tVHepLNZvGxj30Mt912Gz70oQ/hl7/8pedrmpubMWTIkLL/CCJMXLjAnDW4ZPiONjNF14NGLTQe/t1xGNyXxZamOpz3oaXSd25O9WJ+od/MK4cO544EERriPkfxZhlu25XQACITG64KNXFi/bMEERiJcX/xyNyM0YsMM2QFFzECRFa44UWGASoHGPteJAcYQRABSKJmAux1E2kmIrG4ThxgtBUebaLNjLDp1+k1likc78vhZ8sXAAA+dcRraG1gYzAXq2dU5/s0tfFkuG2d4cg2FlpFkOSBiIs0ScBaM4WdMmAbAZ3htmX/vnwnDfDniqvLZP05/Rvapg/fU9z+9aI56MvxU7p+esSpohR5HPXobDO7IkobIPojgYoznZ2duPXWW3H66adj1qxZ+PGPf4zhw4fj3//93/Huu+/i5ptvxllnnYWvfe1reOedd7BgwQLccsstrq5dSzabxUc/+lHcfPPNuO6663DTTTehrs7pQiGCCB8Hk3LGDS6dR5vZCA0v8q85sKcFbZvyF9V04goArLBVGtwc17oJrXV92NIzEKu7hxf2qoSGbAWAWLgRhYbDHGWRpBRoDMUGUL2Cw/q6TX4v7Yr9qv8Xsbm/bOLMwhcZ/pBNTshyk8XvF/adNKRyN0/a8DLaCo+um1oSBFFTJFkzAaSbiP5J8qPNAO9eEGr+9Og4bMiOwvD6A/jY6R7uP6erZ8R9lqa2DLcdJBZa1rfTZgUN4MzYBlSvuc3XdfvVTECwwozzpAEv+PkB2WoaL+wNbfd/4O7i9tcfPU1xlszsWn09OnXf0brYSoDSBojqw/eo+wtf+AIOO+wwfOQjH8Err7yCa665Bk899RSWLVuGL33pSxg2rLwJd3NzM8477zzs3GlqY/APExi33HILrrnmGtx6663aPjMEURUYusDYTcy6wSUjUqHBPzfjmVtOQGMuhzUDgJMvqPwZTx20CQDw3MEJqBQSotAQESehvZwxgK8cZRsnWFgFmnbFMYbPIk2SBYdvcWEiMNoVx8IszGS4bd9xZuJ+r5gKL1yvmrHJTbYVGdFFmlk3tSQHGEHUDEnWTADpJqK6iCpxIL5oM69+neLYR92vs/dgFr/YcCoA4O+OfQ0pianNbvVMEFNbDKkDqn1hFGgsNBNQHeY2X9cYxMwGuC3MOEka8NOfk992a2gb3tqBGSPyF71k20js7WoxeI+Qe3TKcBxpZpwSQ2kDRBXjuzjz85//HCNGjMAPf/hDbNq0CbfddhtOP/107WvOOOMMfPvb3/b7kUawJfm33HILrrrqKvzhD38ggUEkHpcNLp3hJTQqCDtDOb9v85qhmLVjEABg/Hlvcsfzg5zTBvLFGR7ZJDE/WBL3MxzmKGe47bgLNIC32ACsxQaQrEJNoGsx+dnbNcfiLsxYZybLXpiUVTMmucni+TKR4UHSIs04yAFGENVJUjUTQLqJqDE8JuaSG21m0q/Tn6ntl49Mw75cK2a0bMNF87z6QnmtnhGxMbVZkuG2g6QOqPDTu7Pd4z0DFGmSoJsCXYsLM1sUhRkVvgozuvkCE71jb2j776tfLG5fcfvFwlEbQxvgrEcnv02RZgThGz+hhACAxx57DGeffbbVaxYsWIAFCxb4/Ugjvve97+Hmm2/GoEGDMGPGDPzrv/5rxTmXXXYZ5s2bF+p1EIQzHgFwnvdpc/EWlkDeEG4qVmMNpmECNmIjJmD0lI35Qdf4XH4g1Yb8YGkcKgdGIyEfBKXhMdgRBYZs1MMGIeyryGyCd8ldxwKfeQbLh3Zj1vwtWL4wf9ef2LgPk5r2oydXh5cOHsa9pzjw6IX86499foPkvA7kfybZ+/Gw8/S7KsjAPEoJyP8/EQdA2yCfPN4CtdO/vfDYpvksNtjWuQYViIP7ogMxJJwIm6BFGSBYYcaLDLft1P0VvcjQoxMZMmQTGCKWkWa6iZa2wmOYTS3JAUYQVU9SNRNAuomoYZ4EcJb3aTPwDlZipv/PGYP82I7XS6NQmhweiNJYrRXC+IyNR0yLGY0oH1M1ID/WYvvZ8xKZrcCtuxbg8yMfx9dPX4j7F08sHOH1jEwXyfaJ47lGbj/b5kUP28/2scd9yP/sHpopg9KY7CBKBhp+P5D/XfOGGlEjyTQT4F83tSmOMapEN0WmmYDwzGyyfboI6FDizBiyxA53hrbm1hyuaysZU9fsGSa+oICJoU0s0oiGNgNijDTzwsvoTBBJxHdxxlZkREV7ezsA4MCBA/jBD34gPaetrY1EBlGdSITGvN3LsXj4LOnp07AGqzHV/nOY0OAJRWjIijLeQmPpy+Nw9XUNWDa0F/Oufh3LF+adI6cOXA8AWHhoNA7lZAsDmdDgizD8Z/ODEZXQgLBPFBoeZCAXGuIxL6Gh2qcTGkBsYoOhEwKmAiQ0Z5kLgQGEKzIyHp8NeBRmVKhEhriiLCmrZsSeMrbDGcNIMxk+I81kkAOMIGqfpGomgHQTUZ1cnMvhvlRpzHbfKuDi6erzU88AOWGxms7UxlCa2nTwRZo0SuM2fhuAREChfLwjHhO1krhPpFTI+cmTR+NTVz2F09KrcHzbNrzarps1FU1oonFNdk1+TG1VXqABqkI39WvNBIQcZyZ7A1EnqbA3tH3r46X/3+f+8SrhqK2hDdAb2gL06PQRaabDOG1AY2ijtAGiGvBdnEkqN910E2666aa4L4MgYiWwC4whWzWThqXQYPuBysELj1gokVESGlsenwu8fxHeGXUAY9v2Ymv7UJw2cDMA4LmD44XXea12Ec9l1wPohYZjJ5iI6wIN4FZsAIEEh0hsy/nDFhhAgtxfJpnJ4t9wOCJDj83wxFFucsiRZtTUkiCIpEG6iag5DBMHdLDEASVtUCcOyOBNbVJkOklEVjARj8tNbauX1+OugyfgmkEv4h/OW4TL/vcC7n1MVs/4NbXx4ke2T4Of1IEoCzSAmbENcFKkkRGLbrKJbWv3OF61hRne0Mb2ybSSO0Nbqg74zog7i88fXy3OdTBk8x0yQ1vEPTrb9J/A0gasDW0EUUP47jlDEIRbwm5wyW56ygaXbYUTQ8lQlmGToVw5UfvcfdMwtQPorkvhjBteRUuqF8cNyI/mSsUZXZNLMUdZHDCJz21zlCPqP6Patw3++tAAZr1oGD7ylRMBu+4onF9VIzL4feGKDHNk/adizE1mtBUeqaklQRAEQSQTyUpU2QQfM0/oJga1sIl9ftwgi/kBJAUH236dDJOVw6Vx0b+9cDwA4OIxSzFz9G6P9zYd2/Erqhmy8Z6Pnp08GW5bLHJloEfW09G2d6dX/852j2tgVKtmAuw1U7vmuO53GoZmUhGoNZJXf04ZKgObt6HtUx8q/cCfePh8APz/C9V3gEwXmRraPEhKpBmlDRA1BBVnCKIacNDg0hrZTZW/+fI3ZanQMLixFxEnYBuE/bJz69D7an510IaJO3DqyE1oqevD5p6BWN2dlrzOVmjwz9mgSbbyRyU0FKgGgmEUaIBgBZp2j3N4bIsdceDnGtvh1vkFxFSY8cKPyBDPd7VqRkWIuclesO9DijQjCIIgiKrDa4WqFxWmNh18kSat2AZQqZXE1b785KmNqa2Sha824ZHOY1CXyuGbF7zJHRGNaSI6U5t4Hj8ONDW1aQo0JqY28ZhsxbmNsU1FGMa2JGsmwL9u0hFFykBGeG6VNBCkPyeE5+LfhCniuTn8cvItxWe/f/UIxetcGtqqLNKMg9IGiGqFijMEUW0YTt4FaaJWhuwmm1ZsA5AXZbyEhniuGQ/fchTGdWexr74OHzgmLzSeO3gYvB1cNkJDti1b1sw/+nCCeR0Lq0DjukgDJEt0+L2WdpgJjDgLM1b4ERm9kudRrJoxyU2WiYwAuclekWYBcfZ9TBAEQRCEPnHAcOWpTeJABW2FR51pQ2Vqk+LH1MYQTW2SAk4W+LeFJwEAPjBpEQ4bst/jM4KY2hgmpjaeAAUanjgLNO0e54gkSTMB4esmFX4LMyIZ4XmgpAFxv6o/J78t/j0w/K+auezSA8Xt77x8GrLSvroiJoY2oNzQJiMZkWaeUNoAUQNQcYYgEkTVRZsZCQ0TsSHLPvUWGn19DRi6bBKQy2HanrzIePrA4ZrPMREasoFVDPFm4jEgnAINYJaX3Q57wQGUD/KjEB5BP68dZj+nrcAA3C/L548FzkxmiIUUWWQF/9owV83IhIX4XWHbOi/8SDMd5AAjCIIgiHgJNXHAxNTRKj6RRZupTG08dr0jnnxmCF7pmYmmVB++fsFS7ojL1TOybVNTmyU2qQOAu2joMIo0QPSaycVntiO4mS1IYcZPb07jpAHdPp2hDXC3aga4e96vi9s/fHyecNTE0Cai+j4J2KOTEUKkGaUNEP0BKs4QRLUQwBEQWrQZj3GGsig0ROyExgO/PhZHbstixD6gpy6FVw6NLRzxKzR4XOQoK0higSbMIg1DFAF+BIGL9xBph3lRJq7CTKhxZl4ig+2H4piKMFbNMCLOTTaMNHPS1JIcYARBEAQRHIeJA+FEm4l4RbKqJlb544Bs/JTryeLfl54CAPjEjFeRbun0uJYoTW0+4s2A8Ao0QHzGNoYrvRO3blIRRDMB/pIGPOVRkP6cvKFNljRgb2g79fTS3+ivVsxHd59JfxnR0NYgPDfBokcnwyLSzAQXhjaCqCaoOEMQ1Yhlg8vABI42002U8pFE9kKj42AzTnl1MABg1aQcuuA10PQSGuKgCpLnXisQDOPNwi7Q2LrBADOxAQQXHDJ04iEsJ1k7zH8OL4Ghcn5FXpjhSZbIMEf2XSDmJquiEUUMcpNlRB1pRg4wgiAIgjAm6sSBCtoKj36izQL362TIekqIlMZJf310FFb0jcegui78n3OXceeEaWoLMd5MRDwWdoHGpkjTbnCuKaaayaVuakcyzGyAXWEmAw+8dLtpnJkOE7Na5TmPnV5aNfO1u08UjsZgaJPhI9KMfaey79hAaAxtlDZAVBNUnCGIKsdWaDBCiTaLSWiMXJzffnF6Hc6/YQl3jh+hISJOVsuEhg7D/jMiXr1FTAo0uv0uxAYQjuCIgnbYFWVcCQzZfueFmaBxZuGJDD02bjBV8cZnbjK/neBIM4IgCIIgDPBYgRp5tFlasQ3AX79OSV+ZCirHVdnOLP5z1RkAgP8z9xW0NJiY1ryO60xtDuPNXJnaAPsCjWvdVE20w61mCrMwI8Ifs04a8CocqoqTuv6c5oa2Oe/pQ3Mqf/zBd2dif3ez4kwTQxs7L6ChzatHp89IM+u0ATK0ETUIFWcIImG4aHApg930rB0KNtFmZegmSFVCg8ds2e3Qui4c1bAbALBoWgp1J74DIKs4WzYQkjm8VEJDtu0lNCA5V7ErIxzTLdUHwi3QAHZiA0i+4GiHW4EBuC/MiHgV6Sowze9WxZlBeO5OZJgjWymnExkR5CY7jDTzhBxgBEEQBOEOh9FmRWyjzYwx7dcJePfrlJ9788PjsTE7EqMa9uMTp6/kzglqags53kzcneG2wy7QAGa6yZR2JFs3tcP++vxqJsBdYSa0pAEvQ5tO/+h6M6n33XPRHcXtj/7pDOGoraEthh6dDItIM2epLwRRhVBxhiCqlbiizXxlKPNLZnWoml0C8gnbRpwycDPqUzmszQ7GwcE5rG8BzrqG/z2YTharzrMRGo6dYEA0BZqwijTsv7jwex2mRZkwCjMZblsnNK2bWXrFmYHbDkdk6AlTZHis4EtL9oUYaUYOMIIgCIKIh6DRZhUTjG2FR2bi8DK1eSYOyJCZ2njs+nV27c3ifzaeAQD42vEvoz7lytQme624zfBpahPJcNsuCzRRGduAZGgmwP91BDWzxVaYsY0z64D3v3kTQ5vqeYm2GTlMrcv/UpfuG4vtBwcqzjQ1tLFzdZFmHkVhL0Mbw0GkmYu0ATK0EdUGFWcIoh/jK9qMEVhoAC6ExumD3gUAPLFnImZsTuc/8bSlmlfohIaYo8wPukxylHX4dIKJyI7JCjSu3WCAP8EBRFesCfo5pj+fH+dX6IUZHj/ur3BFhjnq/lLyc33mJqcV+xMUaUYQBEEQhJ6wEwd8E2q/TkCvnwCVqY3x8wenYVduMCY17cK1J67lzkmaqc0jdUDEVYFGtz8MYxujXfJfGLj4nKjMbID/wowSnYHSJs6M788pQ9WfU3ZOiZuvfry4felNFwpHbQxt7HzT4m2AHp1RRZrxBPiOJ4ikQcUZgqgCXDW4DBxt5jxDWXbMXGjUI4tTB24GADx98HA8f8uJaMzlsHpgDie/j/8ZdQLCC1c5ypCcq9md4bZNGhu6coOFWaRhtBv+5/d8W2yKMmE4v4CAecn8gaBxZuGIDH/wf+/8ahlTkeEzN5kRR6QZBznACIIgCMIRYSQO+I02U5ngiwTp12nG/u1Z/GrL6QCAb57yCgDVz+LC1GbTs9Nx6oDsuIsCDRCNZgLsdJDNuX4JWpQBwivMiAROGrDpz6kztImoj486HDitPj+Ps6+3GWv3pD3ey6xHr9rQJpLwSDNKGyBqFCrOEEQC0U7C+Whw6Ty/M/QMZYZeaMxr3Y6h9d3I9DXhzY6RWL9yOI7cnlc8485bLJztJ0dZtzIgBCeYSIbbdlmg0e0HoivSeNGOcFfauCjKAMEEBuBwWb5snyzOjBGNyPBG/Ds3FRnsOS8yZBiKDEbUkWY85AAjCIIgiFCJPdqMp1V8IiYO8KY2Wb9O+cqY8ufycdVPHpqFA7kWzBnwLi6bv5474trUJtvWTYIzAqQOuC7QuDK2ha2bwsLm+r00U5iFGf64s6QBv/05VYY2/ZzEf39kYXH79N9epbp4yFfKqAxtOlObgx6djIREmhFENULFGYKoZhw4B+KLNgsuNE4buAkA8NzBw5AtfJ0t/NMJqM/l8M7gXhxz5gbNxTNMJp1NhEYITjAgvgINYCY2gGgEhytsr9XrdxBVYUYJ/w/GRmR4OR1tRAZPkFUzMuEgy02WfU/wx/wUghFapJkn5AAjCIIgCN9oo80CEH20GY/NWEYswpiZXLZvSuG3O/KrZ7575ksoXz0T1NTmIt6MJ8ICTZjGNqB6NBNgr5n8mtkiK8yYJg2I53tFPbshPTKHqxtKlYjFW0cLZwQxtHkRoEdnwiLNKG2AqEaoOEMQNYJfF5gxzqLNVEtoRVTRZiXOGJSfGH3mwKTivhWLx2D27hYAwPRLXhde4UdoiK+XCY0QnWAiYRRoXIkNIJmFGj/XZCIwoizMWLu/REzjzPzgctUMQ+YGUxVkVEQfaSaDHGAEQRAEERNhJw6IpjYZXtFmxv06VaY2HpXGko+ffvTQXBzMNePogRtx8dEuTW08OlOb17jVIHXAdYEGCGZsq3Zzmx8jm2szW6iFGdk+mzgzcOe6XTXz408uK26ffeuVkutl+DG0RdSjk0GRZgRhBRVnCCKhBGlwqXMauOqHYJ+hrHNjiEJDdqycsQ0HML05g75cCi8cPLzs2LK/HgsAWJ7uwlEnbdZdVAGV0DDNUWYkxAkGuHWDAfZFGqB8cB+l8AjyuUEFhovCjIgv95dNnJl4zJ3IMEcWvaFygOpEhki0kWbkACMIgiCIhBJgYs8z2ozhJ9qsDJWpTTeJatKvk6e0b0t7Cr/fdRoA4DvnvIzgq2d4eOOPqmen+Fk6U5shGeG5SYHGpbENCGZui7pY4/ezTTVTmGY28bhRYSZonJlpf047hqaz+HhTSQw8uXaCcEZYhjbHPTrb5J8SVqQZQdQKVJwhiGrHUGjoHAnG0WaBMpRFdPFDukJN/vjxA/JFl6WdI7E321x27uIXJ2B2phG5VApHXvWK8B42QkN2LIx4swgLNEC0YoNHHPy7EB+u3jMMgQFEJDLEfckQGd7YiAx2vteqO4MYEFkRWVaUEUWGRaQZOcAIgiAIIrmYJg7EH21mi8lYqZz/98g8dOSaMH/Qelx4lElcq0wLicY10dQm27YZt1qmDmTkp2mP2xRovI4x/Oommb4JqptcvacLzZSYwozXa4L057QztP3fT60sbl/xl0sApBTXVvuGNk84Q5uYNkCGNqJaoeIMQdQQoUebMSLLUGZUCo0TCsWZVw7Jc4dW3nU8AGDZsE7MOeFdg88wERr8uS7jzVTHEX2Bhh0zLdL4LdQwVELB9L8gmP4MXr+PWAsztu4vr8xkdyLDHH3D2vJz+OdMZMgQVuqpTtM5wAwhBxhBEARBxIvrxIHQos1C7dfptXpG3tNz4+oUbt5T6D1jvXpGh42pzXHqAKDv2ykeZ/hJHgizSCMSl2YCzH+GODSTESZJA/w/Iq8EDa90AnMGD8nic60PFp/fvWyqcEbCDW2MtsJjwEgzadoAGdqIGoeKMwRRRSSmwSXDSYayH6GRw4kD8qPMVw6N546Vzn3j2dLqmdnXvCq8h1+hEUa8mcFSfRcFGluxAY9jPK4ER1TYXK+fIpYqKzkj7IvM/cUIkpnsGi+RwX8f8M9ly/R95Cbz2yOFR5+4ijQjBxhBEARBhISDCb5A0WaMtGIbgP9+nQyb/nx5fvjYPHTmGnHckHacN3uTwStcm9p0+EwdANwVaIDgxjbAnbktKmyuNy4zG2AZAR00aUAkuKHt+58sGb2uvesi2K2aUZ3DozvfgaHNo0en30gzguhPUHGGIBKMdlKOd4FJhIatC8x5tFla3OFOaExs3IdxjQfQna3DGx1q5bPqntLqmdkn6nrP6AZOuhxl8Vx+2+FSfZGM8Fw2YBXPAcIVG0CyBYfttYUpMACfhRkeG/eXrciQkaRVMzoC5iYzKNKMIAiCIGoa22izwIkD1iYQVb9OnamNIVtxDOU561ekcEsm33vmu+81XT3jNYZ0ZWrjcVCgCZo8UOu6KQzNFGthRrbPJGkA3Lb4N2DTn1PPwEFZfHHQ/cXnd7w1QzjD1tDGjolGVz7SzKehjRFRpJlX2oAr4zJBJAEqzhBEPyCWaDMltkKDJ//8hAF5R9fizrHozDVAJTQWPT0Bc9jqmat1vWdM9tvkKNs6wbz2SXZnhOcu3GCuxAZQPrCPQ3T4/fywBQYQoDATxP2lizNzKzK8MREZ4nmiyAgpN5kizQiCIAiiKnEdbWaFSbQZwyjaTNevU4duLKXnh48cg65cI04csg5nH2liapPtk5naXMeb+SjQiARJHgC8NVEQ3RQHYWgmaI6rfrcZ4bmTwozfpAFdcdEGvaHtXz7dXtz+8L0XIBd41Yzp37yFoY0RV6QZj+a7nNIGiGqGijMEUWOYTv45jzYLTWjwjo/8YGNe63YAwOuHDvO8vFV351fPLDdePSPbZ7KE2a8TzNFSfcCuQBOV2GCEXawJ+v5RCAwgQF6yiftLddwrzkyGTny4WjXDkImIRu7RRGT46WWFwM4vhqvvU3KAEQRBEETIGK5ctUocEBETB0Lr16kbJ5msnimx7p06/HHvAgDAd9/r0tTGn+vH1BZCLDRgPib3a2wzPUdE1DSudVPQ9w/6c6s0U0bY57QwIztukzQAyXMHq2YG5/DVQfcUn/9x8UzhDD+GNnbcp6FNRtyRZpQ2QPQDqDhDEAknygaXgaLNGGnFdgU+J1PRiCNb8sWZtztHwUtoLHxmAo4KtHrGJkfZjxOMJ8ICDeBGbPgp1AByYRDkPz+Y/gy2AgPwJwIDu7+8mlny56nEsVdmclB0kRv8cdnfs0xkiGhyk2OKNPNsakkOMIIgCIKIhciizRjO+3UyvPp1qig/9wcPz0dXrgGnDF2NM2e5MrUFjTeD5ByvfZJDGcnxoNHQgJkWCqKZgPg1E+CmGOXKzGZdmPEbZ8YQ+3NCcZ6I3tD2T5/eUNz++P3vRTbnNT1rYmjzioP2mINJK7Yj6NHJQ2kDRH+DijMEUSt4OApCjTZzLjRkGcr5gUZzqgdTm3YDAJZ1qWZYywclK/96HIDC6pmTRKHhdxWAi3gzR0v1M8JzFwUawFxIBBUcUePq5zItzJjkW0fi/rKNM5MR1qoZ/m+2UTguizjkz3WUmxxypBlBEARBENESxYpUT1ObjMj6darOlxds1r5Tjz/vPRkA8N3zXhWO+jW1eZ3LbztOHRAPZRBe8oCNtqgW3WRzvV6aKfTCDI+uMCMiFmbE4qFJVB9ga2hrHQx8c+Bfi89vXnSkcIZLQ5uM8Ht06ggaaSZ+t5Ohjah2qDhDEDVI5NFmDN3Nu0gwoTGzeQ8aUjns7G3Ftl4xDFWxeua5iaXVM1eJq2dUmAqNIPFmsmOAswKNjRvMZZEmiYLDVlz4FRgZYZ+J4IvN/eWF61UzPDqRoTtHhePcZAvIAUYQBEEQyUA7QRdS4oAnptFmSmz7depWz8jGVuX7/vXh+ejO1eO0oStx2hFbhXOjMLU5Th2QHcoIz6M2tvHnJk032V6XKzMbELAwY7qqSvfvyW+cmQyPXjOfXV/c/vSD56LPeNWMH0Mbm2+RzcNwpBXbDEtDm0mkmScUaUb0E6g4QxBVQKKjzQJnKMv2qYXGkS15kbC8czTKl/GKqFfPHL3gXeFc1eDJb46yjROMx7RRoeL0jOQcV2ID8C844hAdfj7f5NwgAkN2XuDCjAq/cWZhrpoxyUQXs9PFiYaIcpPbCo9hRJrxkAOMIAiCIKInjsQBhlXigMrUZhoNbRJzVM7qdxpx+96TAAD/ev5Lhq/yY2pTxZvp8Glqkx3KCM9tNZOrIg1/flyFGj+f79LMBjgszMgioE0MbTItpIszg3Ae/1q9ZhoyPIevt5ZWzfz29dma947A0MYTsaFN9l1LhjaiP0LFGYKoJapKaIiDAjOhMbVpFwBgRddoq8tb+NxEzNnThFwqhSOuetnj7KA5yuJ2REv1ATcFGpdFGvE1YQmPIO9vWpRJTGGGJ2icGYT9stfyz8Xz/KDvE1U6R3ec4TM3meFTbFBTS4IgCIKoDVwkDvjq12kUbcajGu+oTG2y8wD5OKx8cvf7Dx+HrlwDTh26Cu+ds0l4HxtTm6oA4zfeTHYMCL1AI57HiEIzudZNQd/fb1EGUGum2AszELZFQxt/XKeRzLjxMyUd8ZG/nSdZNROxoU311ZKkSDOOKOIqCSJqqDhDEDWK02gz5xnKPKrJVbnQOLxxLwBgU0+aOw8wERpL7zgRdbkc3k734Liz1wmfFzRH2STeTPf+IS7VB/yJjTAEh/haV/8FuQYdXgIjI9lvEifnOy+Z33YRZyYr1IQRZ+YlMmS5yTqRIWKZm0yRZgRBEARRk8SROGCMabSZZ79OZmrTGdtMzDByVr3TiFv2nAoAuPG8FwDoVvPqTG0iKlObbNLbNnXAYYEmLGNbEnRT0M/Xofs9ZCT7ApnZxIMmvTlFTAqHpkkDZoa2kWOz+FzjfcXnt745S3FtQCINbXFFmmm+uyltgKgFqDhDEFWKjWMgcLSZCicZyiL6gcPYhv0AgC3F4ow5S186DHN2DQAATLrsNY+zbSamg8Sb8Th2gmWEfbIBMDtXhZfYAOJfim+DjbjwIzC8nF9AwLxk08JMkuLMeHQig48p9BIZDnKTGRRpRhAEQRD9jyQkDvCmNlmcEADP8Y4RonPe29T2L/fNx8FcM94zaD2uOi5MU5ur1AEg8cY2oHo0E2B+rS7MbOxcHl+FGVutLdNFfuPMvPn3j79d3L7urxcia71qRmZoY8cdGtoYFGlGEJFAxRmCqBKMG1wGiMxJRoaynqH1nQCATFZ2vrfQeO2Wk9GQy2HF4CxOu1ScUA0iNFzEm8mOAb6FBmDnBhPPZZiKDSCZhRrba/IqymQk+00FhtNl+SJiYUa2msu9yPDGr8hg+3TFGo/cZBkjhUdLKNKMIAiCIGqL2KLNGGnFdhlDINdP+n6d5Zivonm3vQ6/3H4WAOD7Zz2PulRWOMPWvGOySkHcJ8OrQAOD/ZJDGck5fo1t1Wxus7kuP2Y2IOLCjJ+kARE3hrbDJ/bhww2PFZ/fvnSm5r1sDG26ORRLQxsjQI9OHRRpRhByqDhDEDWMTGjIJg9DizZLc/vSlaeVMBcaA+u6AAAHs03c+eZCY9WSUZi9NT+Zmz5vMQBToeGVo8y/RnTd8K+PeKk+EL3YYMQpOvx8dmwCQzxBtyxftk/l/uL36f6dJnHVTCMqhYaP3GRZpJkjKNKMIAiCIKqDREabRWBqKyGbwPU2td34tznI5AZiZstWfORU01lRk4KLqJGCpg7wiONoBwUaW2Mb4N/cFrVu8vPZJpopI9mvShkQzw1UmOGxTRoAd1z3bxKK43p++uGFxe0rbr8YOaSEMxJgaEtL9lka2pxEmvFQpBnRD6DiDEH0EyKLNrMWGrIMZTXZwtdWCjn4FRrP/GYBWrJZrB2Qw7kffBt6bHOUZdteTrAQlupHITb8Fmpcio+g72vys2Rg//vj8V2YCer+so0zE/d5ORdNCCIydDjOTY460oyDHGAEQRAE4RbjCbs4o80Yvvp1yva5Xz2za0sK/7XpHADAd055Hk31ogYyMbV5xZvJXuMydUA85nEoA3fGNsBeMwHhaCYX7+tXMwHmfXwCF2ZsI6BN4sy8nov7KucKpkzvxRX1zxef371imuQ9GCaGNnY8RENbnJFmlDZA9DOoOEMQVUSQBpemxCs0ZK6wcqHRmc0PMlpTougyFxqb1qQxc9OI/KtOW4pUnW71jGq/bbyZzgkm+4yABRrZ4YziPJ3YUL2G4Udw8KiKK6b/+cXkujOwFxji+ZEWZkzcXyq83F9RrJphx8UJBT43WRQZ0eYmU1NLgiAIgqhNIo02c9avUxZZpMKfqe3H9xyBbdk0JjbuwufOURv78oQRb+aqQKPBRfKA7HweP+Y2nqCaya9uMr3ujGK/qZkNCKEwo0JnUuPxShowN7T96prnitvvu+0yIPCqGdOVcwENbQwHkWYEQaih4gxB1CLcJGAk0WamGcpKzIXGlt78Gx7euLewx5/QePTnJ2NwXxYbm1O48BNveFyfl9AwiTeDsG27VN9hgSYjOU8lNnSv4QkqOKLARlxkFMecCQzxBFfL8sOIM/OLV8FU1sjSBEe5yQyFyNBBTS0JgiAIorqwWakaerQZI3C0mYj71TMHdufwo3XnAgC+dexzGNDYLZxha2oTz9HFm5mmDsiOAZXja0fR0EGKNEDyNRNgfo0Z+DMA8vjqyynuN4mA9koaMNFJ9oa2OfN6cXZ9ab7hgVWTFa8F3BnaRHwa2hxGmvlKG+AMbeJ3OBnaiFqCijME0c9xEm3GcJahLCM/4FjTla8EHd26GUGExo6tgzF5bf69uo5diYZm9XLlQXXdmNeyDVcOXYEvjHgZ3x7zDP517JP4lzGP4aujnsf7hy7F9KYdAHIVr3XnBBO3AxRoAHuxwV4je51Ikgo1NteSgf7ns4kzcFKY8ev+ChpnJuJ61UyjcEwsospEhoiD3GRdURnRRJoRBEEQBBEO2ok7fgVrkqLN0ty+dOVpeYZA3a/TC3+mtv/52xSsz47G6IZ9+NqFSz0+w2s1gmm8WRipA6r34A7ZGNtcFWmSoJtsryUDf2Y28TVONRPgNmnA1tAmP+/Wi+8vbp/6+6vhf9WMuF+HzOxq8B0RQY9O3XcqAIo0I/olVJwhiConaINLGfFEm+mERmlA8uKh6QCAcwctRz1YHJk/oXH//5yEYb1ZbG1K4bLPvQoAqEMWs5p34fphy/HTwx7HE1Nux6vT/4jbJt2L7419Dp8buQjXpt/GFUNX4Kr0cnx8+CJ8f+yT+NvkW/Dw5F/huvQiNKIP5k6wmAs0Gcl+V0UaIHrR4efzMvAuytgIjNAKM6buL36faWFHtU8tRnK572je18+qGd1rWKSZRmSEnJscRssgfhYAANIPSURBVFNLcoARBEEQRHKIPdqMYWVqU03UqvpR8MfN6N6fww9WvBcA8JW5L2BYqzjYNVlxrSrIyExEtqkDDgs0qsMZxbmuNBMQrW7y+1kZJMDMpjon7KQBk2JNOaec1oV5dWsBAFs6BuP5DeM1Z3utmnFsaEsrthmWPTqDQGkDRH/HtgMvQRAxc3Euh/tSottCwpMAzlIfnou3sARHle2bhjVYjanS80dP2Yjtayfkb8KbUvmbcjvyN+ktyN+0tyF/ExczbdOoHJC1ghtLlT1BaWBROdB/+sCR2N07EOObMrgu/Tr+kDlGOKMBpgOmfZkBGPfWRIwctRHzDrTj/LZDOKphNwbXV37ulp4BWN09DO/2DMHu3lZ05prQkMpiWH0HpjZl8J7WzZjYlME/j3kM7x+6BF/dfAnW94wpXAv7qu1BfvDE75PRUfgdsPP5feL2PpQGXPx+xdtCckoG8gHZQciXOfOvg+K1KrwG/15uHZdCJeNx3EtsiXgKDPEkF8vyZeLVT5yZ/eqYVOpfDM6yWTXD79MVa8LPTdZBTS0JgiAIovaZt3s5Fg+fVbZvBt7BSsyUnj8BG7ERE+w/iGmnUagc56ahGa+K+kncx48ReS3SAHUxRa2jfnfveHx5xgTMatiIf7j0LXz9z8cLZ/C6RaZ/+P38PvEc3fnsOdM8vPZh55hoJvGYBNnhDNSaCVDrpgy3LXu9Cp3uMVnh4Eo3ZQzOsSnKAAELM6YFOkiOmyQN6F6v2id7XQ73nnFz8dkFN18qOSdBhjaGT0NblJFmBFFrUHGGIPoBqWeA3Onq45EJjYHQT3gD8BIanbkB+Pmuc/BPY/6Gr49+HAezTbh735EoFwAiJaExuC6Lo1q24ujWnTiudTuOfnAHWuv68qc1bwMA7O9rxKKOUXi9Ywze6BiFVV0jsT/bzL0XoyQqWlM5XDZ0Gb4w4hUc2bINt068DTdsvA7ruseiNFhrhFx8iEKDx3GBRnVKBv7EBnstQ/YeNoTtEssYnheqwADcLstXxUfInIhQHJe9Xr9qRl2csR1aqPLQgdLqOfFYuLnJFGlGEARBELXHfauAi6cXnjwC4Dyz18lMbYypWI01mKZ/gzbITW0y0jAYrw5B+ViS1womME3CNBJ7LjsnT19XFt9ZfA7uOPb3+PyMF/Ef6TnYkhmg+QwvMxp7b1HDqUxtJgUaRsgFGiC4bpK93oYoEgkyBudEambjj5munPKTNGBiaPNeLXbVlQcxLHUAAPDStvF4c9tozdkxGtoYNj06DQgz0ozSBohag4ozBFEDlAkND2QuMBlGQsOGNMoHoxnZSWZC47bMyZjbsh6XDF2MH4y7H9ekF+L+fUdiaecobO9tQVeuAc2pLoxoOISJjbsxoWkf2hozmNOyE1Oa91Z86sGGeiydksWq8cBLj5+MZ1ZPQrYs9ZEfHMmdYB25FP6UORqP7Z+GX42/F0e07MBPD7sL16y/AYdyAyvO9+cEA0Iv0AD+xQb/Hqr3iYOMxbmhCwwgnMKMGJ2nEgymmcl6gq2aaRAeTRHz1IV/wGnJSxzlJkcRaUYQBEEQhFv8JA54mdoYVokDKpiZTZY8wJvalIkDvCaQHZfhf/XMXx4ehdfnTsWxTWvwL1cswqd+d4rkvWVmG9WqGtVzoNzUpir0iLrJVjOJxxQfAckpmcJjWvKaqM1tLskYnmermQCHhRke2b9j10kDXqtqKvfX1edwx+xfFZ9ff/u5ktfFbGhjpCX7vHp0UqQZQTiFijMEUYVohQbvAos72ixEofGtrddiXfdIfGrE05jbugVzW7eof1CBDd2DsKRzJBZ2jMbrh0ZjbfdgXPr5O7F6QA5HTV6C7DcmC69QCY1KdvYNxCc2XYo7J/0ZU5t341MjXsJPdjKFxwshWyeYnwINEIvYEN+HIXu/MBA/1wuv1Vyq94usMKPCxP3lTmQw3KyaaeQe+Ux0MTdZ9u9XkZvMk9YcC5CbTA4wgiAIgug/RBZtxrCKNot+9Qz6gH984Uw8cuYafGT8K/iPw4/CineHCq9xHW+mSx3gCbFAozslA/W404+5DZr3CwPxs73wo5tM/0kaF2Zk0X0ukgZk+IkzA/7uk6VJkNvfmYU1e4Yp3h/QG9q8YsxkeBjaGPy/y4A9Op1HmnGQoY3oD9R5n0IQRC3g5UiQ3TQZumgfI2Tu9bTqZHHiVZaZCuRQh//dfSbOW/tF/HD7uXj+YBu29AxGV7YeANCZrce2noF4/dA43L13Bv5rx3x8ZtM5WLD6Opy/7ip8Y8spuD0zA2u608ihHtvvPQ4AsHTUIcxbIPt5VYOwyv27+wbg+9vOBgDcMOw1DK8/qHmdauCoW7ItYuswMjwlo3nNQe4/UzKS/4IS5D29rl/3fpEWZly4v3j8iQxG9KtmHOUmW4oN20gzHnKAEQRBEEQyKZvoe0R5mhXaFbasn11b4TkzhzCziGx8ki488uOasqGQZlxUhJ/YFU0w7LgK9bFHnx2ORw4ejcZUH/7ziuc178Hw6msojmH557LxqUoz8duySXtx26dmkp2WgZluMiWj+C8IQd/Tr24y1kxRFGa8kgZM4sy8GTgkh/8cc2vx+ZfvP1VyVpBVMwENbWnFtoiHoU1HYEOb5ruZDG1ELUIrZwiiH+I02qwNwTKUlatnZM8B8WtrZ99w3LJnAW7ZcwLYgCqFQ8ghBf3AvpwXH56C685bjLeGdWPqtS9j8Qterje9E+ypgxOxpGMs5rZuxZVDl+BXu0+C3gkmQ7eCRrRvOXCDQXJapvCY1rzWZjWNSMbzDPeYiKOMYr8v5xfgvjDj5f4SCZaZzHC7aoa9RiUyRAxzk9l2xJFm1NSSIAiCIJKHq2gz34kDpsj6dTLS0IyZZfoJMBu0iqtlDFbPAPjqg6fhrCuX4vzhb+O9R8/Do28eLrzGVbyZbf8Z1WfpVtAAVtHQutMyhce04nW8BqkG3RS7ZhKPBy3MgNvnN2nAe17h+19YX9z+t5ePw5YDg6Tn5aleQ5tJj04vyNBGEHlo5QxBVCmiY0DpAvOI1JE5GGROB0Yx8sfLMaG7uRsNRmUDC4PBMoAcmgpbsoZ6UO5bdPMCNORyeGdwH864WjbRaj6JDaRwR2Y2AOD8wey9bJ1ggHnTQ/G5DzeY7rQMvAWBn9U0UWF6bRmoXV++l+TbFGZk7+PH/RV1nBmPqcjwigpkucn83/0QWIkMhpibTJFmBEEQBEEYIDNfyFbQMjNH4MQBRtrkJN6MJes3AehXzzBkE8DqSeG3l7XidztPAwD85/lPoT6V9bhOr8lt3QqGHsk+2XOv1AHVGByoHKsHWEUDVLduCqqZAIdmNtvCjIj4b8Vl0oCcw9py+HLjXcXn33vqBMlZCTG0MUI0tAWNNCOI/gIVZwiiH5GIaLM0ty9deVoeV0JDhnow9M7iMZi9Of/ZrWe/iVRdn+Qs83izJw5MQV8uhSNadmBswz7J68TJdUie87go0EQgNoDygX0cosPm8zNwKDBsl+QDbtxfEJ57Ob10++ToCzMJFRkBc5NlUKQZQRAEQVQ3ka9g9RNtxjCKNlOZ2rzGZzL9ZGZq++e75yOTG4gjWzbjM2evkLxGHGN6GYNk8WbidlQFGtlxBbrTMrDXTVFjq9ky0Gsm5ykD/HOvwoxXBLS4T8Qmzky+/2cfeq24/dkHzsaB7ibpeXlcG9p4hlSeIpKW7DM0tOlwGWlGaQNEf4GKMwTRTzF1LDjPUGb4EhpsX3hC47Gfn4bBfVlsbAYu+dxCj88BdAWavdlWLO/KzxLPa31XOMdP/xlxmy/Q6PKUA4gNryJNxuytKgb+LsWH3/fOQH/9zgSGeNxVYUa2EsskD9nfqhkzIhQZMtLep1SIDA3OIs0IgiAIgogN7QrVEBIHmMlDtyJXCtNOMid7WvUiWb9OGW5Xz+zYUocfrTsfAPDdE5/BkOYuyVmqSXDdfq/+M17GopgKNC6KNEAyNRPguCjjqjDDYxIBHW6c2fEn9eCK+lIfpl8vPEpylktDm4iYNKIwt/HzMD4NbXFFmlHaAFGrUHGGIGqIyKPNTHEmNHTNLt0Ije3vDsbElYcBADLzVmPIiEOSs8xXGyzpyFeqZrdslrzez1J9cTtkscFOdSU4eGQCwc9/tnhdr5XAYC/gcVmYETFxf7kTGQz3q2bY8wAiwzY3eYzkHCCaSDONA4xEBkEQBEEkDz+JA6amDmPSkn1aU5tsBY1XuoB/U9t/3D0Nq/vGYWT9fnzn8jc0n8EwHYeqnruIhRb3m4zjHRdpMmZvVyRuzaS73sCaSRf/7FWYCTMCGh77RHK445zbi8/ed9tl6MuJ06281nFhaGN/84aGtrRi2wEUaUYQwaDiDEFUMX4m9WKNNmOkJfuUQkP1PDyhcc9PT8bY7ix2NdThfX/3rOL9zeLN1nQPBwBMbtylOS/OAk0IRZqM+VtGRgYhFWX8ZCUD5nnJXoW8MEVGHnerZvjzVAUdVW8pA5HBkBWBLaBIM4IgCIIgeLwmFHWmtiKmiQO6fp1p1Zubmtr446YGNrUJp/tADt9afAEA4PMzXsSU0aLWAPzFm8leb5M6ANgXaBwlD8jeTkYGydRNGZhdV6xmNn7bbwS0F3ZJAzdcuweT6rYDANbtT+OBVZMNPkOFl6FNhmb+xK+hLaoenTwUaUb0U6g4QxD9mEijzXRCQ4nQ/FsrNNytnunqaEL98/llyMsn7sH0o7d5XShUk+EbuwcDAA5r3K84z2upftgFGtk5HtgKjozd2zvB5vOtxQV7EY/O+SU+t81Ltl2Wr8I+zgxwuWrGpkgj+1v3EBlpydsEyE2mSDOCIAiCqF1cJg7I8B1txmBmk7TkmJWpDShfpczrpEbhOCTH9PvufGgknumcjeZUL/79/S9JXgO4iTcD7FIHALsCjex5yMY2RgbVpZmcmtmA8AozMq0tvrdN0oCcluZe3DTzpuLzK/90EYCUcJbfVTOqeQzVSrlkGNq8KDO0eXznMihtgKhlqDhDEDWGUmh4EHq0GUMmNNKVp9kLDRn+hcYjt8zGEfvr0ZNKYd7Hnpe8BjAZrGWyzQCAofWdmvNNHD78Pj8DWBM3WAhFGkZG8p8r/L6376KMbWSci8KMn+gH2ft4vVc5+lUzNiKDP082MaATGQrSkn0Bc5NlUFNLgiAIgqh+okocMDJ3tBUe/fbrVCIbR+lW0MiwM7WhD/jqY2ehL5fCZaPfxOmztxp8hsm4VFWg4bfjKNDIzvHApqiRkfznCr/v7UszsRfyeK1Icl2Ykb1XkKQBuW76/ufWF7fvXjUDi7aospRNUM1VODC0MdKaj1cZ2jRQpBlBBIeKMwRR5RgLDW6yMPRoM1Fo6OCFhlI7RCE0xIFQHd65bQEacjksH9qDM69UDSb08WZ7+1oAAEPru7j9ppPnuol524Gs7LkDscFe4mfQnnH0XxTXWnyxiJfAcFmY8ev+kp3nTS73HY9VM6boBIXqfP7vW3CA2YoMC6KONCMHGEEQBEFUB0GizYomEM2K3TIC9+uU9aEQV8+Ihhn/praFi1rxhz2nAAB+ctHTSEH2c9rGm8me61IHghZovJIHHBjb/L4s4+g/P9fpWzPZrkIKUpgRXxdd0sCUti78f/buPE6Ous7/+Lt7ztydkJADkkxIuG8PDpH7iKLcKCiu4uqqq7u6iu4KP491xWt3ddVVcb1PROSQWxC5BEFB7ksg3AkJ5Jrck8lM//6Y6aSmpo7vt+pb1dXdr+fjkcfMVFdV9wyTUO/+fL6f+vjEK7Z+/eGrDw/Yq84NbTUmI83CMNIMyBTFGaDFZTLazM//P/1CBo3RHrpzlvZ8aZIkqfPY+9XWaTuftl9rBoZWzowt96tDA1u3J1+q793mqkDjKGx4D014eCZSv6awgGHa+SW5KcwEncuk+yvZOLNoaUOG9/NayPDz3+Q2gIO5yUEYaQYAQPPLa7RZapXhj6FNbVHXTDmsnpF07mWv1ZrqGO035nn9/dFhhSrT8WZhb76HPe6iQON/zKTA4D/GQstkJin9lAHv51HTJbzHu5w0EO6777hz6+efvuV1enHNBKPjgjluaKupBHzOSDOgUCjOAE0o69FmqWco11QCttU1aIwu2vzuG4epMjCoxV0lnfzhu0LOG37xtnawc+vn49s2+B5NulTfu820QBNVOJDCw0aKxLBR9Qkezp7XJmDkUZhJ2/1V9FUzQcXX7Ocm1/4dCwoZjDQDAKB5FGK0WZr7ddZUwh7wXzdFjUhyu3rmpcXt+uozx0qSPn/wTRrbZXLdaXKdGjd1IK7gk6ZAE/S148Y27+H1KNY4e96wzJRVYSbo9yDud8PtOLOFx63TsW33bv36v+94dcBedWxoC7pHZxDHDW2MNAPsUZwBmkBeo82czFAOChq1i4VK2EnzChqjrVg2XjMe7JEkLd79eW0/N6iIIYWtVBgM7dIJuzg07QTzbjMp0ATtZxI2/Mel4A8eaU8bdD4nLzWqKGMbMPo9210XZhTydfKQURNdmMk6ZPgLrzEhoxLxUv0hw4BpyPBipBkAAM3PyWgzW0H366yJbGoLur6y6cq3b2r7yiW76LmBaZrZtlqfPi2smSXqmtT2/jNJtoUVaGwmD0iZNLb5T9PwuSlqP//XaQszitnmdtJAua2qXx/0i61fn3bxm7Vpi9n7C8Hq0NBW73t0AtiK4gyAVJ0M1jOUaxooaFz+rQPUs6mqdW1lHf2hWwyfZ+hCruyZuTxYLfn2SbtU37stSYEm6Gsp07ARdtokfzJ7MX6mHXI2AUOyK8x4xRX4/Pv5RRdmqtXPRj5uro4hIyxs9Ax/tP33KgAdYAAANL5CjTabHr9LZAe8pOCmNn9+yqaprW9jWefdc5wk6V8W3KoFs9aG7Gm6mtt0FFXUtXNcgcZm8oBNYxu5aeS+UV+7KMwENTnW9nM/aeDjH3hZk0pDUzEeWrG9Lnts54C9GqihzUImI82YNoAWR3EGwAi2o82cKXDQGBho18orXytJemj7DTr4jWEriEZ3gnn/ka16to88xvSCM2i7d1tUgSauG8w2bGRylV8nUd+PyUoi284vyb4wY7ssP6txZq5Dhl/EzS2DVvFXIl6qAecjzbwIGQAAFE6hR5vVJL5fZ9xNwrNvarvwd7N068bd1V3q1/+99WZ5E1C4qOtW06akpAUa/+euGttq+zZTZpLcN7O5Ksz496ntZ5OPzCYNTJ0xqK9M/eXWr8+6eKEkfxOmjRwb2moyukdnFiPNmDaAVkBxBmhSJl1gdZ+hbBw0/PIPGn+6doH2fmWsJGnySXerrT1s9cHIi7pSybNyZsRFW94FGv/+cRfLNWFFmrBzNJK4ooxtwJDSFWb853DU/fXZQ6VPvT7kuUJ86jDps4eb7RvINGSErYaLkWZucoTUI824qSUAAE2tLqPNairDH73XQaGXTd4HgprapCya2qSSPnjlMeqrduioSY/rrEOfDdnPdLxZ1DG1bSarz10XaEwb28L2byRRzXmmY7FtVylJ9vk4yQho335bc1N4Zvrqex7Z+vn37t1HD73vtIDcVNCGtpT36AzCSDMgPYozQJNI+2afk9FmSVWGP4YGjZDlurkGDekP/3ukJg4M6vnukk796J+MjunQ5q2fb6kO+h41Xarv/dq0QGOyXF8yCxuSWeBohNAR91rTBIwkhRkvm3nJ3n3CzjdsoCp9/nDpU6/T3I61+uWc3+tfp90nfyfj1lUznzpM+vyRQ8eNkGXI8P9dznBu8nAROc2/W4w0AwCgeYQ2tcVwMtqsZ/ijzf06ayphD3ivo8Ka2kwzkH1T26OPj9HXXjxGkvTVw2/QpDF9Ied2cf8Z/z4uCzRZNbY1QmaSkmemNM1sLgszCSYNbM1NhynI0e8Zp3e2/37r1+fu+/GQ3GSqsRraov4ti2toG4GRZsAIFGcAjOJ0tFkTBY2Xn5+gGQ/MkyQ9u/uLmrPrypBzb7vom9A2VJzZUi1pU7VN4W+imwQImwKNlD5s2AYO73FFCh0mr8lVwPA/HleYsZmXrJCvY0LG+bdLn75V+vyROvd1z2r/MSt09pS/6dFdf711z1GFmU/fLJ1/m5JxFTLqNzfZ5UgzAADQoPKYOGAq8f06Q3cclk1T23/8ei8tGpih6W29+s8z747Y03T8roupA2HbTCcPBD3eTJlJStfI5rqZrXaMd1vUf1svy0kDNeffPpSDPn/kqAJNW4f0kx0u2Pr1hwbfp5WfPykgNzV/Q1vkv20e3oa2uH9DgzBtAK2C4gzQxAox2ixOgwWNy791oBZskDaWyzroA7dE7Dl0kTehPFScWTfYqW2zaJPMvrUdfVWTJmzUtiUJHN5j8w4dG2X+3GkChmS+JD+og8tmXrL/2Kj9/Nul2V+5Sse9+6t6+g2vHbH93p1/s+2LyMKMy5Dh5zBk5DA32ctkpBkdYAAAFEvWEweimtpSTxyoieqIlzS6qc1/DZZtU9um9SX98+1vliS9d/adOniXVwyfy+T+M1H7uirQxE0esC3SmBZqipqb4lYExW2zbWbzHmPbqJhg0oB3//NvCyzQnPeJFdqxtFyStLg6Rd/99y8bFmbC0NAGYAjFGaCJ1CNo1LRK0KgOlvX8RYeorVrVo5M368gzH404f7/GDxdn1g5Evaa4pfphb9onLdDYhg3/ObxMAof3vK7DR5JzRr1m0+8/rvPLZEm+d1uSZfkK2M+777bH/neH2/X1O3+gs3/ykxGPdJcH9ORxf1KpFkCMCzNhTEJGuxKHjJpKzFPEsJ2b7MVNLQEAQI2TNyJ7hj+muV+nVVObf1s2TW3X3TJFF/cepHKpqu+efKPaSv4RzzWm4828j6WZOuDfFjQaWhqdmdI0tklmmcl7DtcFG9vzxuU8l81s0sj/blkWZgwKNr4Czc4HdOg/On669eGT3nepBj97m+GkgQI0tNWYNrRFSNLQNgIjzYBRKM4ACJTrDOUGCxr33jZbey2eJEkqH/2AxkzYHLrvhLahC76hlTOmnWBhBZqgi9AkBRr/12Fhw7QjrMa0UBN0zqR/TCUJF2HbXXR+ebeFFWbCni9u+2hRxcH+557XT5/5keUoM5chwysmZCSZmxzBaDSj6AADAKDZ5TlxYGuTiO1os5rK8EerpjYp76Y2SfqX37xOq6vjtM+YF/SxE6IaXGwKNP7taQs0/u22kweSFmnyzExJclPca4nbFtXM5v/adPxz1L5pCzO+/O0p0Dzxxq9s3XzBq96tv363t7Ea2oLeb4ni+B6djDQDolGcAZpcXUeb2aoMf6xb0Ah6ozn4PFd/7UhN6x/Uso6STv74LaFnn1jeIElaM9g5vCVNgca7T1yBJmy5vk3YCNrHex7XhRrXbMYImGwPChguCzN+7kLG91fuEXD+bV57zz36nx99KeCRrEKGlz90WIaMuLnJPcMf8xppRgcYAACF53LiQDHv1+m/tor6usZ9U9tLi9v1mceOlyR9Zt8/aMft1hkfG870OttlgSZJY1vU9pqiZaa4RjaTCQNSfDObd8qAi8KMIr4O2h43Mm/Y+bfpJ586fcSmc+8+urgNbTU5NLQx0gxwh+IM0GQKP0M596BhsnrGZNbrSL0rx2jMH4fe8H5szgrt//rg731q+9CF5/It3tflepZy0GP+bVmEjbjHavwX/FmED5vniAsXcUUZKfpnmLQwk133123rZynOwgkv6D+m/yV2v2Qhwy9oJGGKkGGBkWYAACBP9blfZ41pU5tJHkrW1Paty+bp7s0LNL68Sd858/aI8ycZbxb0mHdb1Jv8aSYP1I6zLdKYFmqyLNjYPIdtDkzbzObdblOYCWt4M580EGTWrC16V/vvt379thO/p97z7w3YM+rvQdDXGTW0VQJOVeeGthFoaAMCUZwBECpV54M/aMR1ZGQWNGyZB41rf7av9ljdoS2lknY86w61tY/uuJnavkmS9MqWsRHP6WKpftz+/u0uw4b3MdMl80HhI80fE0mKTHE/A3/AiFu9lLYwo5DtI1/na8Ys0akTn9bE4XsejSv36+m+CSHn2eb0ytMaU6qd12XIqD0etSw/6GuZhYwEc5MZaQYAALyymjgQxPn9OismO7tYPWPf1FYdkP7x2oXaUi3rhKkP6uQDo773NAWasDfmw97sD3rMf96ga3/T6QMucpPrzGR7z5uox73imtn8X2dRmPFLOM7MY/E/fHPr5w9O2V0X7f9W6VOHhe4/kv/vSpJVMyka2ipmrzIMI82A/FCcAVpAXWYop1WIoBFXoCnpjm8drfEDg3pmjHT6OX8ctce0tuGVMwOdvkfCLwKTL9WP29+/3TRshBUoXBVqsmIyZzmqKOM6YLhYlh/f/bVb10r9ZPZNOn/mX3TT/Ct1zrT7dcv8K7RT19rYYyWpszQQ8WiSkOHnDxmB1deRm9OGjIi5yYw0AwCgNRVutFnP8EfT+3VWAs5hdL/ONKtnvMyb2v76wDhdsOwoSdI3j7lBY7uirmnrWaAxmTwQ9HVQdqjtF5WJipibbB43aegznTKQtjDj4D4zHh/+uyUjvj79cwdJn7lF+vyRvgJN1Ioyl6tmor5Wrg1tjDQD3KI4AzSheo42i2QaNILUJWjEW/xURTPuny9JemrXpdp1/5EXcbWxZq9sGaP0S/XjCjQ2y/Vtwob//P59TQNH1sHD9Hmi9nEVMLzbo4pnSZflh//ubN++TuXS0Odjy1v0nimPa1w5qhi4zfn7v1W9g11yGzJMVs2EdIBVIk4bFzIimIYML0aaAQAAE7ndr7MmdvSr9zrL4E3eTJrapPMu2k8vDmyn2e0r9IUz74t70RGyLNBEbU/a2Obd1yY3ZcUmn6XJTNLozORiyoB/37BzRG33PzZSuatN39jpoq1f/8etB+rJlZOH7jXz6Zs9BRrT9xnafR+L3dAWhJFmQHYozgCIFNcB4XSGck1l+GPioJHn6hnpt//7Wu26rqy+ckl7vPePUmlw62PTRt1zJssCjXdbXDeY/ziTsBF0Dv/+JkEiKBDYdGvZHBd0jiBpA0btHN7HvNuCAoj7woy0RY/3VTQ4/Ffvd2tmy8bmU/aUPnWE4d6uV80EzE32qgx/jCrievUMf7QMGV50gAEA0FpcTBwwlep+nUHXQ5Xhj8ZNbd6vs29qW9db1r/86XhJ0ofm/lH777YqYu+4N9JdFmhsJg9IwfnAtEgTtn/Ufq5zk83z+kVlwaBz1NgUT+IKM0nuzWln9r/sP+LrL91+wLYvRhRoXu/ZK6yhLazBzS+jhrYaf0NbBFf36HT1byXQ7CjOAC2icDOUw4JGkMrwx4KunpHKevB7R6h7cFBPjqvq1A/ftfWRqW1D95xZvqXbs389CzRR2yWzsFE7R1zgSNrx5Wq1jck5bIoyUQHDZEm+Ah4L2+Z/jWHPO9rLW8bqxnU7SpLeMDH+7+LbnjtGe/ztDB321Em65MGp0ucP9wQNk5CR8aqZ2CKtGdORi4w0AwCgdRR+tFmcyvDHBlk9c+lNM3X16n3VURrQj07+vdrbN0e8ZvOVDsFMCzRBj/kfD5o84KpIU7TcFMRmikKSZjYXhZko5r9Lk889SK8vD/1dHqyW9PofnaFNW3y/z+ffJn36Vk9uMnnvwNWqGcuGtrD3WQp6j06mDaAVUZwBEMj5aLOe4Y9xQcM/Q7lBgsaTD26vnR7dQZL04r7Pau7uy9WhAU1u75MkvTIQsgw5UJoCTdLl+knCRu08YYHDe1w9luiH6Vfw6zYtyiTt/HI5LznItsd/vHK3mH23OXTcS5JKWj4wRvrCnZ6gYXKzS9uQEfZ1TMioyXFushcjzQAAgA3r+3WmHW1WUwnaWJzVM5L0j786TKsHx2m/rmf1+bc9ErO3y6a2oO1hBRqTyQNSfHHCe27b3JQVm3xmW5RJ2swmRf+3qW2P+m/hYJzZ/ztUMwbX6csd35ckfesv++qOF3YI3vn82z256eDhjXVcNZO2oc3xPTpHCGloAzCE4gzQpFy+GehktFlalaCNxQoal3ztEM3fIG0ol/WaD96s7dqHbsDeXy2pd6DLt3fapfpB20yX60eFjbRFmqjA4T+H7VL7NEv0o15f2PH+bbYBIOi/R9LCjF/0788Dm6KWoo303imPaWLZ07VYCxpttUsElyHD/3czJGTU1EJGJeapwqSYm8xIMwAAWpPr0WaZ36+zEnBsYLNLMZraXnx5jD5y19B4s4/P+4MO2vOVwP22yaNAk6axTQrPTGlyUz0yU9zry6KZzft51oWZaLO3bNS/tl+kHUvL9ezais77w+tD9hz+PT//dunTf/TkpjAFbWiLkLahzeTfSKYNAEMozgCtqp6jzZo0aAwOtmvRjw5R12BVf5swoLeedY8k6ZUtY1VVKeAI1wWaoO1BF71h5wp7HTZFmtr5bIo1fkmDRNTriHqeuO1BASOu2GUSMPzbo1bG2HV/lTU4aluYzvKgDhq7TCN+x8//s/S5O2KOzCBk1FQinjYsZEQwHWkWh5FmAAA0n6xHmxlLer/OmsimFtOmttq2bJvafvb7ebp81avUXhrUz0+6Tt2dmyyfz3WBxrstKjPZNrZFbfeesyiZKaogY5IH435e3sf8++RRmAnPW1PeMFevLT2us9tvkCS99/Kjtb6/M2BP39+N8+8czk0N2NDWM/yxIPfoZNoAWhXFGaCF2L5p6Gy0WSGChsnqmaDHwi6mgrc/fM9szX9sliRp84yXJUnLtkS1teRVoEkbNqToi3LTMWJJgocJ/3PEFWRMg1OSi38XhZnk3V9jSlt0405XhT5+zpKD9dbnjtUnlhyk3oFOrRto19ObJ4fsnXPI8D7s/3sfdONbL4cjzbwhg5FmAAAgCevRZlHimtqCJGpq8+ehbJraJOkffnGIlg1WtKBjqb521n2KblSKuwbOokAT1IwV91xSfNYwbXDLKzeFsc1MJiuNwprZXBdmgkSMMxvXrqn7TtJXOr4nSfrBvXvpD8/MjTmfFJ+ROtSMDW2MNAPcojgDNLGsRps1ZtDwCwoaUnDQiAofo/3mqwdr1/VlTVg39PXLo0aa+cWvhAh/PI+w4SpweJ/D9Z8oca/LNGDE/cxMO7/825MVZoZWx4z+XZnctkkzOsL/G0wo9+vhTdvpmrU9ev1TJ+uQp07RU5srnj0KHDJsZDk3GQAANDXb0WZ1u19nTWX4Y6KmtqCvXa6eCd6+YuUYvf+PJ0qS/nHOH3Xc/ktlV6Bx3dRmO3kgSWOb/7FGyk22mUkBj/k/jyvoJC3M2OXr2f+0nz7afol2Ki/V4vUT9PEbwu69afp3wnS/xm5o82KkGZAcxRmgldVjhnLP8MekQSOQi9UzQZ1gQeI7waqDnXrggiM0fe3QaKlJ+65WfMHFz2ZJdpICjf/zqIvq2uNJA0eS5fVpmRZk0gYMf+eXdx+XhZltprev10VzrtPDu16o3827Wu+f8ogqbX2SpFnt63TDTlcHHidJ/770Nbp8zbytXw+orH55i4emv/tRgv6O+c8bEzJqKsMf6zQ32YuRZgAAwEQh7tdZY9TU1uH7Os3qGbumtitumaofLx+6r8ePj79Ok8b0qT4FmqAGq7Dr/KRFmrSFmiyYZDaTzCTF/5xqH6N+1v7t/m1Bz+VnV5iZdup87VNapH9ou0aS9I9XHqXevu7IY4Y0V0NbEEaaAfmgOAO0mDSjzZzMUE6rMvwx9eqZuIsjyaQQE+aph7fXrOfGS5KenrNB+x62WPZL9V0XaEy6wdIUaWxvTOmK6bmjHrcNWmGdX1kVZrY99pGpD2ifMSskSXM61+kj0x7S7+ZdrXdO/pvePeVvKgfd3mjYxb3z1V9tC99hKxchY4xGhwrDkBFZjI3QM/wxxb83LkeaAQCAxpDVG4N1uV+n8XVU0tUz7prapHZ9+Bev0bMD22tW+ypd8M4/R5yvJm2Bpj/gGO9xcUUF0+cMeg6bvFKvzOTf1yusKFOvZrag5zPXPqlTU3Yeq//p+I7aSlVd+NCuuuqJ+WF7ez532dCWYtVMTWX4Y4qGNkaaAfVDcQZocnkGDevRZnUJGmEhwv0c5b6nh573lQkl7XDWnzRm3Ga5naUsmRdo4rrBXBRpJLsgEXYTS9s/ps/ht0XhAcO/n/cxm84v/2NB57R5TNq7e7kk6RNLDtK/vXSQHt9U0cS2fn1y+/t01uTo6uu+3St8W7IKGf5tXoYhoyYuZESo90gzOsAAAGhcdRttlvZ+nTWV4Y8jrrPCLrpsVs8ESd7Utq63U+/+w4kaqJb0thn36IzDXlT81IE0BZra41E5wHVjW5rclEdmins9pkWZPJvZ/I/HZeXRdvjAPvp/7b/U/PJLWrJ+vP75uqNC9oz7va/TqpkMGtoYaQbkj+IM0OrqOdosqcrwx9igEbRM38t7kWTCbo7y9PZNkqT+sVU931XSKf/v5uFHXC7V9+8TVqAJe8x/0Rx3MZw0cOS1TN/kOaPChW3AkMIDRu0xm067+P/eAxpaGrNqoEtXrenR6c8dp88te3XAc2+zcXBotczbKt4r4SxDhsEqmTC1kFEx3D+juclejDQDAAA2kow2c3K/zorBi5MUPnEgzeqZ5E1tt9w5Ud9cfLQk6TuHX6uZ0zYp3VhoyTxXmVybp21sC9rPq4i5KaqRzfT7D2pmk8L/2+RXmNnu+B4dWb5Pf9d+oyTpXZct1MqNJpkly1UzXsVuaGOkGeAOxRmgBWU52iwyaESNGnIeNPJcPROkqu3bN0iSNjy4iyTpwVlrdezbH4k5TsqvQBO22iPqYjtof/++pqEjTfiwPZdpuFDAflEBI6rza0vIdv/zmzy2bZ+VW7okSdOGi3+DKuv9Ux4N2HfIhxcfone9MNQF9oYJL2hsKei8eayaCQn+tU8rEaeeGvGYn6O5yYw0AwCgtdRrtJnT+3XWRDa7RN2v07vde11nMhLay66pTWrXJ3++lx7ZPFtTyuv047/7o6Sq7KcOuCzQJGlsC8oQNs1tNfXKTVGvLWpCQNDjcUWuoMeCzut/LOhxv+jHy2PbNXfvDv1nx/9Jkv7nrlfpxqfnhuxt2tDmb2Lzfh5036aaqDHQERq0oQ3AaBRngBZQ7xnKgXqGP2YWNLyyWj0T3gk2odyvseUBSdLlv91b+y4eKhqtP/oR7bhgldwv1ffvU7vwjRvblTRsePdJUqjxymJZ/hYlCxe2AcO/3fuY+8KMJD21eZIkabeuVVsf6SwNBuw/5MSJz2pp/1g9v3m8OsuDOmDsywr/3U0aMryiVs0EdIB5+f9+T/M9Xu+5yRboAAMAoMlkNNoskMn982ya2qzu15l3U9uQzZvL+rtLF6qv2q6FEx7SP55cu5arV4HG/5hJY5t/X+8+aTOTlH1u8osrMHn38x/j3y/qv1Wawoxt4Uaa+uYefaXje5pWWqOHXp6qc298fcieYY1nYZMG/PtGFXaiVs3UWDa02ajTPTqZNgCMRnEGQL6jzdIEjSBOg4a/myUsaMQXcqYPr5rpHehUX7Vdl37xCM3bVNWatrJe+9GbVSoPKvsCjXefuG6wqLARV6QJ2s//muKKJWmZnN8mXJgGDBeFmaDX6bft+Ic2TZEk7T1m5dZtJ609LvRsx0xYrBt3ukpzOtdJkuZ3rQvZMypkBBVpvOKW5qdYNWOiZ/hj1nOT6QADAKClpHkj0fsGZib364zjZPWMVzZNbd597ntioj7/6LGSpP/c52rtsqBv+LEiFGj8j0dlptq+SZvb8spMcbkp7Niw/Wya2WpTBlwWZuJNPGCG3rvgHh3Tdp/6Btp01qVvVN9AmtVgUnRWslk1s/VVRr+UsIa2sJFmlg1t3KMTyBfFGaBF1W20WRDToFEZ/phZ0EgiOGjURpot2zJWUrs2re/U0z88TGMGB/XkuEGdfu7tw3vWs0DjP960QFHb1zZw+F+jyz9hTAKQf9+wr6MChvdxkxVHQcfFnXfIfRuHrr737l6hceV+SYM67pN/0Ns/0TZiv8OfOkmnPbtQ923cTp3lbStrHhku7gwxDRl+GYSMmsrwxxzmJnulmZtMBxgAAM0hjzcKrZvaehw9cUGb2kZq15cu3VV3rpuv8aVN+unpv1dbZ+2xLAo0QdfhcStfbBrbws7n3TcqN9UjMyVZCZRHM1uSHDxS58xxetXh0qfafyFJ+uSNr9dDL4d1grq6P2fQ8c3R0OYV2tAGIBbFGaBFFHqGclqFChpD26e3Dy0hX7al9nzteuTu6Zp9z06SpL8tWKaD3/TM8GN5F2hMw4b/saCvw87p3de0YONC3POZdHzJ97VJwDDp/PKf17+PWcB4sX+8nt08QR2lqg4au0ynfuRuPTRls9Qm/cPsA3RFb49Of/Y4vTIwRo/1TdZZzx+rtz53rP775X313heO0l0bglqn6hAyairDH8eNfihSwrnJXnSAAQAAKw5Gm3lZN7XVmN6vswGa2vwGq2W94xdHad1gtw7qekJfedfDhudPUqDx7+PPRVHPkTQzmeSmrCXJTN7jwr5O0szmPy7J40H7jFTqbNPME+foG90XaExps36/aI6+cderQvaOuo9m3L5+zdvQZvRvneffTBragGAUZwAMKcoM5UIGDfs5yrWVMy/3jx2x/bLvvFr7rujSYKmk7pPv1vQdayOmsirQpAkbLgOH/7g0ASTsHGHnMQ0X/m1h3XT+cwc9ZluYsXP7+hmSpFN2flLP7Pe8JGm3x2fqjht30rlLD9KjfVNG7P/wpu30o1V7608bZnm2ppkN7iBkBBVVK8Mfw0JGFMO5yYw0AwAANrIcbWasZ/ij7f06/QrY1DZau55eNkEfvPV4SdI5s27QW45eNvyYbWYKOsa2sS3oOWwb22xzk23eCeMyM8UVZUya2aTw4xTwuAweD9pntO0WztHHp16tfcrPaMXGbp19xUJVVYo9Lvr+nN59wu7P6VXMhrY4NLQB2aI4A7SwQsxQ7hn+WOigESZ8jvJ27ZskSSsGun37lHT1F47WzM1VvdJR1mH/doPK5YHhx23frDcpNoTtF/Z8URfaQY+bBA6T7ysqgNgGkqjnjgoXpp1w3udQyGN+bkPGH9cP/YXZr22Z+iXt2duuX3/ldaH7jxY2qkwKnplckJCR0dxkRpoBAICaQo42C2Lb1Bao/k1tYX5+W48ueP5QSdIPDrlMu/ckbWoLOsY2M6Ut0oRt8x5b79wU9jxR2+KKMop4PPvMVDNu7+30xr1e1AfbrpQkvf+qY7Rk7YSQvW3GmZn8Phe/oY2RZkB9UZwB4JSToFGTa9AIK8gk6wSb1DZ04bh6oMv3SLt6V47Vsp++Xl2DVT02YVBnfsp7NRO36sTP5KI2aD+bsBG0X9TFuknoMA0fcUzOGfa6TIouNgHD+5j/eYLOH/R41Hm2uWfDVG1pk6aslfZZNqgbv3isqpH/S7edmRwng5DhZxIueoY/ZhAykqIDDACAJmc5cSCz+3Xaqgx/LFhTW7ChfT78s1fpzo0LNLG0UZe9/WqNH1e7Rs6zQBPV2OY/T9rM5D1H1pkpTW4ynTBQv2a2mo6p3drr2LH6Rse3VS5V9X9/3VuXPrZLyN5Jxpk1d0ObFyPNgGxQnAFaSOybhvWeoVy7eDCYjzpCZfhjqqDh5b3ZuY1t+1fKfZKCijNDHrxjpmb9ZYEk6aF5K3TMmY96Hs2zQGMSNmwDh3d7XMdWWFAw/RMm6vnDvieTokxUwIhbyZT0v9NoJ//rn/XwnKFl+PMum6sVy6Ku1BskZNS+9t+T0z832XHI8GKkGQAACOLqjcVE9+usNaH0RBxYiNUz7sebbRlo0+k/XqilAxXt1rFEP3vPHZJqmTbLAk1WjW1RhZqo3JRHZjIpIplmwTTNbEH7JMtM5e427XDqXH133Lc0ubRO9yyZro9cd2TscUNMxpnFyaGhzUTP8EfDhjavrEaaAdiG4gzQ4vKaoRz5P/0egyfLdfWMV7I5ypPahoozvQOdAa9r6Nirvru/9lk6XtVSSeuPfVgL9l7h2SdJgSarsFHbJ0ng8D9musTehsn504QLyX3ACNonbL+RDj3pKT2621I9sNNQcaZnycbYY7YxDRlRhcmChYwEGGkGAACi1Gu0mbG6NrV5rwPTN7XFWfJKl864+mT1V9t0yuS/6twzvRdgrgo0eTW2effJOzclzUySXVHGppktbL+ox4P2CVCSpp64kz437RLtV16klRu7dfrFb1bfQNxEAZPtJg1t7b79M2po8xdmEza0eTm5R2cMpg0A21CcAeCccdAIGm3mDxomo42kDFfPBImfo7y1ODMYvHKmtt+l/3GUdtoorW0ra7cP3ayxEzd79rEt0AQdY7OfiyKNd5+oQBEUDpL+MTm/l21RxnXACNonbL+R5u2xQoMn3quBUknrpwzdz+jVY5arszQQckTSkCHP5wUKGUEYaQYAAPJWr9FmPYavTxrd1BapHk1t5uPNbrt/ij754PGSpP/Y5Vodd/BKzz4uCjRR+7lsbLMp1ATt4zozRb2GoNfsupkto8KMpMrhO+rM+Q/pne2/12BVOuuyN+q53kkhe5tMGggb6ef/PGwfqcgNbc7v0clIM8AYxRmgxbgebdaaQSPIyKAxrjx00bhuMPqqqm9Dpx76xlGaNDCoZ7ulk/79ekmDnj1MlqH7Je0Gq+2b5Lmjls1n2QEW9hx+ScKFbcCo7Ru3X7KQMW5Cn/b88M3qbStr3qaqfvLN4/Tylm51lwe0/5jlAUfYhIyo7q+gY+oUMmrF257RD7mcm8xIMwAA4JfFaDMn9+v0N7WFNbf5m2Iyb2pzN95Mkr722wW6+JXXqL00qF8cc5nmzu7z7JNlgSZoX9vGNttCTV6ZyaYgk6SZTTL7OZtOgTAzbo8peu1B/fpSxw8lSeffdqB+99Q8w6PjJg1ErfoKW0lGQxuAcBRnABRjhnKQzINGh8KDRtg9NryP+Y/Ztm/H8EqGzYNtIefYduwzj09R25X7q71a1UNT+nTmuX/07VePbjCTsOHdNyxwmASVLDrBwl6DTbhQyH5pOr+G9nn1mGX62szbdOnca/SjHW/Qh7Z7SHt0rVS4QZ34uRv0bLdUGRjUI988WuvWdOvO9TMkSQePXRpxrJRuZrLjkFFjGjKi1HFuMh1gAAA0t6QrYkc0eVhIfL9OZ9I2tXlXXJuyGYdW0tk/eJ0e7ttR08prdMlZ16lrjGlTm+S+QFPb1zY3+c9Zz8wU9RriCkQmzWwupwyE7TtS5/SxmvOG7XVBx9c1ttSnGxbN1eduPTjiCJfjzLwK0NAWoAj36GTaADASxRkAmfB2XBQzaPi3hQWQ2td2QaOjNBQU+mP/mR1eqn/FztrpgTmSpEd2eVkLz3rYt18eYSNNkca7f1ToMDlPElHnN+k+M1lVYxowavsG73PqxKf08zm/1xsmPq/du1fpoHEv60NTH9ElPTfo57Nv1JHjFmvbjU6HvO3c2/XQlD61V6vSFa/W049OkSTduWGognnQ2GW+5ytwyKiEvLQwOcxNdoGQAQBAiwmZOOAVNnEgVVNbT8SB/iaXsPt11r522tQWxO3qmY2b23XKz96o1YNj9Zqup/Xd99zj2y9pZrKZPGAyfcC2uc17XJ6ZKW1usslMtf39+8btE7XvSOWx7Zp26k76r7E/0PzyS3q+d4LefunxGqyGZfIk48yiFKyhrWf4Y8qGNu7RCWSL4gzQgvIebWasJ+Ix50Gjxn+BFBc0zOYoby3ObL0QjL+ou+zrB2rvpeNVLZW06uhHtd+hi317uCzQpO0IS1qo8T6Pyz+2z+9/Hf7jovbxP0/cvtv2mdm+Xp+afrck6be9PfrAi4fps0tfo+vXzlZ/taxXj12ub+/4R/109k1bV9Icd9ZjemSXoeLLvAdm6/Yr52893z0bhv4i7N69Wh2q3XcmaciIGtmXc8iI0mOwjxhpBgAA3MvijcZUTST++3WmFnct576pzezaddt+Ty2ZoHdef5IGqyWdvd3t+sfTXvTtlyQzhR2XtLEtbJ+g89tMA8gyM9nkpqjXG8RkyoAC9ok7r0e5pGknz9f7J9+kN7X9RZsHynrLb96sFRtNl6GYjjMLa2gLum9nDQ1tAMJRnAEgqYAzlHMJGlHhw39xFXX/Dclf1KkVZ7aM6NKJDhqS9JvPHqNd1pW1oVzWdu+8QzPmrPPt66obLGzfuP1tA4f3ONMLf1tJzm9SlJFcB4wzK0+ouzyguzdM03lLD9Rt62fpN70L9NElh+joRSfo+yt216bBNr127Cu6pOcGfXffm1Q94EENlkraa9k4Xf71A0ecb8mWcVoz0KGO0qDmda1VdLiNCxneY+ocMjKam8xIMwAAYMPFaDNvE0gcJ/frzHT1jHebm6a20dvjXfWXmfriY0dJkv5nz8t18GvX+/bIo0ATtb/tBIKgY4uUmUyKMqbNbLX9TfYL23e0KUfP1nFzn9a57RdKkj56/RH6y+KokRw2v3NhhZmg/aLGAJq8D+HRhA1tTBsARqM4AyAXxQwaNXFBQ76vg8KFV0l9g22SpO5SyC6jDJ1rS1+7bvvCcZqxeVDLOso68JPXa+yEzb598wwbNkUam6X3JjORTf+YCgsX+QWMQ8YtkST9evUCSSN/OZYPjNH/LN9Xxz9zvK7snStJOmzTy/rK96t6381VXfnZo0YdI5X09Oah3+k5HWtDXq9kN84saL+okBG03aBDrTL8cVrUTsNquapn9EOZzU22QMgAAKBFGYw283I22ixIXZra/NeJYU1tbsebSdJnfrOPfrdyT3WV+vXrN1yqGXP9+5quWjE5zmQiQNxzNEpm8r7WoO8jar+g1x63f9zPNt74fafqta/epG93fEPtpUH99P499J279404wrahLeo8Qfep9Te0BZ2LhjYAQyjOAC0qy9FmzR00/IKDxvrBoW1jy/0yCyDbLH9pvJZ+/1CNHxjUorFVnfD561QuD/j2cl2gSVOkSVuoyVLQ63ERrvzHBO070pjSFu3StVqS9NeN4RWJpVvG6Uv9++qbbyvrbztI3f3SMXcN6PKZN+jo8S/Kfz+a9YNDv19jyv6/Q0lnJtuGDMtVM+NG7yJpdPHVQcjwsp6bzEgzAAAQIOvRZtZNbVGLA0yb2gKZNLV5RRVkTNiNN6uqpDN/cISe2TxNs8vLddlZ12rc1Dbfvqar/P1sJw9EHVM7rlEyU9pGttqxQceY7Be1/2hj5k/Sbgsn6ked/6kJpU26+Zkd9Q9XHavRTW01JoWZoP2Dxpn59/P//bCd3hGgMvyxzg1tALJDcQbAVvWYoZx50IhcPWOyAsB/4RU3R3lo3/WDnUMvpxx0wRnfCfbw3TPVecWr1FGt6pFKv952/u8lDfr2T9MN5rJI4z02LHTkFTzCnjOuayxpwDAPGXM7V6m9VNWKLV1atmVs6DO1d27Rws9er9t7yvrq20r6atueWtY/RrM71+t/d7hdF8+9QQvHP6+u0tBzjy8PfRyoegNImpnJ8u2XQciwUbC5yXSAAQDQWqxWyDq+X2fipjaTsUdeoy7f8mtqSzPerHdjl0762fFaO9itgzse16Xv+b06x7sq0IQd28yZSUqfmVyMMTP7OXTtMF5zTtpBP+76b80qrdRjr0zRqRefqP5B/++AiaiGyqgVXv4M1a749xkar6GNkWZAdijOALDmcobyCFkEjVHCbmwZ9Hlc91bQHOUhtZUzE9s2e/aJOtfo57ztygXa4Y5dJEkPzFqnt/3bHQH7R1+4Tiz36ejxT+tfp/1F39nh97ps7m91+/wLdcf8C3XNvF/r2ztcr7+f8oB26FjjOSpu2bvpRXnYMnYXAcTkXCbL+F0HjPCQsV3bJknSK1uiiheDOvPzN+jRiQPqGqyq79cH6YeP7q03PXO8vrtiD20cbNNe3av0Pzv8SX9acLl+23Od9h2zQoNV6f7A1Tim3V9etiGjxiJk1L72z02O+vvdE/GYh5O5yRYIGQAAtLgUK2zj7tdpLKqprSbz1TOmTW02BZr4zPTQ4u100sUna1O1XQu779fP33+7yp3+t7rSFmhcF2m8x5tmpiS5yTYzZVGUsS3MmOmYOkYzTp+nb4/9tvYqP6uX14/Rmy48Was3dUccZZLJw34X/RkqroAT17xGQxuAIRRngBaWdLRZmLDRZkH/s4/r2Ngq86CRZvWM18hOsGVbhl7EjPaNvn1MbNvv6h/sq90emyVJemj3ZTr5ffcG7D/yInZq2wb93eTHdNGc6/SnBb/R/+5wq86e8oiOGP+idutepSntfZrc3qd5nWt05Pjn9fFpf9Hvd/q1vj7r95rb0Tt8FpPZxCYhweZi3+ZP3POYBKU8AkbtGGm79qHizIqBrtA9z/zsLbp/+gaVqlVtd/PuuvvGoeHZG6od+ubyfXT00yfqO8v31NL+MRpTHtAuXb0arErfXL6flmwZP3yWJOPM0oSMDFbN1IqxQX//85ibzEgzAAAQIes3HlPdr9Pf1Gba3JZ69Yz3c5OCjKn4As3Nf5ult111krZUy3rr+Dt1wfvvlUYtnrDJLkFsRx/7n9e0UOMyN5k8VxTTn1nYsWHPbbp/sLaJndr+rQv0ufEX6ui2+7Sxv10n/OpkPbO6EnGUq3Fm/nNkuGqm9nUdGtoA5CfJ/xkBNLGrnpRO2Dm788/XU1qkBSO2bb/TC3r56dkjd+yR9GzISaZKWu75GGeMpI2RG3zbvJ+3a9uFYu2fzP6Qr7d5vn+ypOc1r3O1hi7Kwsabxd+c8ZKvHKwzv3CjHt6hV88evEhH9Y7VTb/ebcQ+ZW3WYeNe1tsm/02vG7tUbaVtb2Av6puov2yYocf7Juul/olatmWsqippStsm7dq9UoePe1EHjV2i4yY8q8PHvaCvvHKQLlq9u4bm9NZeX9Sb+94L6bj/rcR/vyOfz3T/OKYX+0nmHpsdM2Z4/FhtVZXfWz9xhx6et1KStOv9s3XJz/catc/qgS59a8Xe+taKvbRzZ69mdWzSos2T9GL/hOE9ko4z82qMkFGPucl0gAEA0JpOqFZ1VSnsHhY+N0k6aujT0q1S9fChz/db+Zjun7K7pKGmtge1t6ShJpIntGvo6WbrBb0gX1basSq96Hs9MyW9FPPapkl6RUPXY6s1dH223vN1oImS1mhkRvLmmLDtta+l4LzQr5FZy3tsu8yu37ft99v75uofxr1ZPz76Sr1vyk1a8Z4xOu/7u/pv1zi8v0lmCbpm9+dCk2OCjg87h/98JoqWm9Jlpjjlse2afsYuen/lD3pX++81WJXOuuyN+sviqK7OqJ+13f1hg+/P6c9POa2acdDQ5sVIM6A+WDkDIJFcZyjXLjai5qnmsnrG+3n0HOXHNg29oL26Xw7Yxyu+E0wq6defOlp7ruzSllJJvQsf0iFvGloVMKncp3dPflzXzbtG39nxFh067iW1laq6f+NUnb/stTr8qVN1wrMn6vMvH6Df9O6s2zdM15ObJ+ipzZP1l40z9fNVe+q9Ly7Uyc+erDvWz1JXeUCfmX6HPjntLo1MMiadVZK7ecmmz+fqtSRZxm93TMfwz3NLdfT/ek/54F/16J5DaXqvJ6fpkm8cFPN6S3py81Tdun7HmMKMArZF/e4WKGQESRAyvKxDBgAAgAlHK21T3a8ziNPVM0HbTFfPxN+zc4jteLORfnL7Ap1z1xskSefOvEb/+o7nQvZMM+Ys6nibDNNsuSlutU7Uc5spdZa1/Vt21punPqT/1/5LSdInfn+YLn/ctrvUdJxZkvtzBp2vuA1tjDQD6o/iDNDiXI8287KZoVzcoDFGwbNmpahRUPdtHKoo7TtmmcaVNyv+jfMg2/arVsu67Nw3ao81beorlzTt8Hv07VfdpJvmX6lPbH+/ZneuV+9Ap364cjctfPpNevvzb9CFq3fVKwNhN54feYH81ObJ+ocXj9N/v/waSdI7pzysj0/7S8hxSQKHi+Bh+3wmks5Wtg8Y7aVBSdKARnY5Hv+uh/X0a5+WJO394iRd9IVDI8699WwG+0jxIbggISNIrSjbM/ohm7nJ1iPNvOgAAwAAIfIcbea0qa2mMvwx16a28Ht2xjMp7khfu34Pnf/QkZKkr+x0qd57athyoqwauLzHJc1NWUqS0UzGrUUdF/VaDLWVNO2UBTpg5kv6ese3VS5V9Z2799XX7nx1zIEm48ziCjP+4/2/iy3Y0AbAGYozAEapxwzlEYKCRk2uQcPfCRYULPxfD12YPdc/Wc9srqijNKhjxz8dcIxpJ9i2/fr72rTiP16jL/xyQF/8YVVHrn9ZY8oDenxTRZ9a+loduehEffWV/fRC/wQlCxsl/WjV3vrU0kMkSX8/5UGdPPGJmGNtu7RsZyKbHmsbYtIGrajXGKxWlCl5ViQddcbjeumIxzRQKmmvV8boV58+Wsn+15x2nFmHRgeQHENGrUhjGTK8GGkGAAAKw9PUZjJxINemtmlhOw7LtKlNvseD7kUT1dRmVqD59GX76YKnD5YkXbDXr3X6cWGzsF2NPk5yL8soaXNPHpnJ+zxhxyY5LkBJmvrmedpn3hr9sOO/NKa0Wdc8MU8fvu7IoQdD2d5nJkxQU2aLN7Qx0gxwhuIMADspgkaQ2A6OnuGPQXNUMw0aURdaZkHjt71DM6XPnnK/2jQY8+LCCzRdpS16y6SndFXPdfrf7e7Qzs9XNViS/rxrSf/1tpI+s/2+uqx3vjZV/a8jWdi4rHcXfWv5fpKk/zf9Ds1qX2VwvIsZxy6LL0FMb7AZdXyU6Ne4eXDorqRdwytojn7r41q98CH1l0rac1WHLj5voRQw8my0pDezDDtXWKFRqnvIiJHp3GQAAAAPq4kDjqRqPjFpaqvJrakt7XizKCPHQn/o5wfp10tfpfbSoH5+0K+08Mi1Ice5GINscp6kzW1Rz5VFbjJ5ndk0s4WZcuwc7bv7Bl3Ycb6mltfqr0u215mXvkkDkdkpyaSBpPfnrAlraAv6u5OASUNbDBragOKhOAMgt6Dh7cwIuiiI6+zYKs3qmcDroaiLpeRzlC9evb96B7q0S9dKvWfKfcNbzTvBdu5cpU9Ou0c3z79Sn5txj+Z3rdG6gXb9dOUuemvv0brk+Dbd3dOmye//o/Y+OGq5vn3Y+O6KffXXDdM1rrxFX5x5u0rqNziPy8DhStpw4T1H0uOHbB4OD12lAR391se16g0PaXO5pD1Xd+iyT75Jg/0mAcK0MBM2RzlsmX5QyPB2QSrg8xxCRs/oTbnNTaYDDAAAxEjyRmRu9+sM4m9qq4Tsl1lTm5/L8Wa18w2pqqR3fP9QXbdyT3WX+nXxob/Q649yUTQwadhqtNxk+nqybWYbpSRNecNc7f+qzfpV5/maVl6je1/aXsf+/DSt29wZcWDYODL/46bjzMLuz1kzRvF/H3yyamhLMNKMhjagvijOAAiUZ8dDJkHDWNAKAJOg4f08eI5y7+AY/dfLr5ckfXjqn/WeKfeqrEFFFWh6Ojbo7yc/oovmXKcr5l2jd055XJW2zXpx8zh9+eX9deTTJ+krr7xKj740TXeff5zmbqpqVXtZE9/7R+1/2IsR36fdPVgGVdZ5S1+vDYPtOmDsUp04sfbmt2lHVr/qEzpsntfFCALzgNFXHVo5M2PGWq0eLszs0duhy/7tTervS1KYCZNmZrJ/v4xCRhRHISOMN2SYogMMAABYC5k4ECaT0Wa15heT+3QmXj2jgG02TW0ux5uNtGWwTadecJRuXz1fE0sbdfnrf6pXH90WdYTRed1kDf+58sxNts+Zb26SJJVLmnriTtp/v76hwkxpqDBzzM9O06pNts1hWY0z8za0eZ+HhrYaGtqAeBRnANgrwgzlNEEj8roobrWA3Rzly9bsp5+s3F/lknTOtLt0Zc9F+uB2d+vo8c/ooLGLddT4Z/X3Ux7Sf868VdfNu0TX7nSZPr79fdpnzAr1V0u6Ye1sfeDFI/WGZ07Sz1btqfWD2y4MX35xov78+eM1b6PU21ZW97v+pFcfE/WGtc0Fcb9e6J+o76zYT5L0z1PvU2fJf7zNsnl/AEgbQNKcy7S4ZHIec7XijCatU1+5pN17O3T5v6YpzJh2fwUdl2RmssOQ4R9JGKRn9Kakc5PDulLpAAMAADbqPdqsGKtnXDe1+bm9/8ymLR06/oI36P41O2hqaY2uOORH2vO4sSHHSskyThTbcWNFy0wuijJ2uanUXtb2py3QPrtv1IUdX9C00hrd99I0w8JMkkkDpuPMwp4n7B5MIRqgoS0JGtqAZCjOAJDUYkFjlIj7ZwQGDbs5yv/5ypH67NKj1DvQpZ26Vuufpt6t/93hRv1o9rX61g6/18en3aM3T3xaczvXqr9a1u3rZ+nzy16rIxedpn9ZcrhuW7+DBrf+cz3yYnP54nG6/bPHa/4GaW1bWe1vv1MHH/9MxPdqFwx+uWqBlvaP1ayO9TqjEnb/oLTzjaMKN67Cie2KH5Nz2dl/4dB/l/YBaffVHbri3453XJgJ43JmckYho1ZsNQwZXpnMTaYDDAAAGDJ9Q9LF/Tq96rd6psZtU1t4sSV9gWbt5i4d+38n6Il107RDaYWuOPCH2vNNQdnPy66xzc0KFJPnaJTMVDufnVJnm7Z/687aa/46/arjfG1f7h0qzPz89JSFmbDtJuPMwr4OK0TWNG5DGyPNgPxQnAEQyuUM5cIFjcSrZ8LmKIePN5NK+k3vXjr26bP16aVH6ao1O+uBjdP1ZN8UPbhxmq5ds5O+/spr9IEXj9Hrnnqb3vfiQv1q9V5aOdBt8I1Jq14eq5s/9Sbtsr6k9eWytpx+t447M250k9mFcl+1Xd9esbck6R+mPKiO0oDBeV3ciDIt25tiZhcwJOktH7tTL+49dF+gCX1lXf5vb9LmTbbztWuSdn95H4ubmRy0vab+IcP53OQIdIABAIDEUryRGXe/zhHqunqmJq6pLel4s6R8TW0bxuro75+sFzZM0vzyS/r9q7+pw07rkEpR57DNNbZFmqLlpjjZFqHKY9o1/W27aM85a0YVZlZutC3MhD1uM2kg7Gsa2iTR0AY4QnEGQDKWM5S9ChE0RvEHjbixTvadYOsGu3Rp7176t5eO1dueP00nPXumznz+Lfr4S0freyv3123r52n9YKfvXEFGX0SuWTlGN5z7Ju26tk0by2W9vPBhnfLBv4YcX2N20fzb3vl6qX+sprZv0vETnlSypfBZB4+kz5N1l9ug3v7pW/TYPou1uX0o+W1cNkZbjFbMSOm7v/z7BIXgoM/rEDKCWISMVHOT6QADAACG6jFxwCt2TFFPxGMmTW01leGPzpravJ/nMd5s9L4vrpmgw75/uh7rnaaZpZW6Zs9v6IyzNkrlyAqN7HOA7cqVPIs1SZ7LdhyavbYJHZpx1q7aY+bqrYWZ+5emKcy4mDSQpKEt6n0FCwkb2rxsRpqFNbSZoqENSI7iDICtmj5oWK2eqbGdo1wTtDTfpBMs6Zv20ro13bri4ydojxXd2lIq6ckDntHbzrtN0mDMuaIvzAdU1oWrd5UkvXPy45KqcrNU3qZ4k/S4IEnmLdtra9+it3/5ej04f7kkafbSSZKkrlLcf48a0+4vk32iZiZ7w0gdQ0ZU0dUj87nJdIABAABLWY02S3S/ziBBzTBhTW3jRu86UvZNbeH7B233P+Y3ct9nV0/S6777Ft2ydJ7Glzbplzt9V594z2KV2k0KNFnnDv9z2eafpMeFcX1fz2DtlS7NOGs37TZ1pS5s9xRmfua6MONinFm74kf6hWzLqqHNI+jfhLCGNpPpJow0A7JHcQZApKYKGrFM3oA2DRr+z4Mu+IL29zMPGpK0pa9dF3/8eO39/FAh4KFdXtZZX75e7R3pRnZdsnqBNg62affuVXr1mJc9j6S9SWXQa8hixU2S15n8uSvTNurU/7lKD85Yr1K1ql0emqVrv3eAJKnNqDiTd8iQChMyekZvynRucgQ6wAAAQCLeZg9Hb2g6u19n4ZvaXN9/ZrTVm7p13PdP1M+e2k9tpar+c9ZF+vYHH1Tb2DaDo5NmBJeZyfs6XK+4SdLIlvy5O6aO0YyzdtMBled0UfvnNb2tVw8snapjfna6VuRSmPEfazvOLKqhLSHbhraAv/s0tAGNg+IMgOSaMmhEvRltM0e5Pp1gkqRqWb/6zLHa7eGZKlWremDGep3yP1dp2sx1EeepCb647h3s0tVr5kmSTp0UVlhzXahJK/j1TG7bpF27VmnnzlWaWO4LOC5dwFiw1wq97vyr9diEAXUPDmrWzXvosq++TgPDA63bS3EXqq4LM80ZMpzNTfb+O0bIAAAABoKuCRrmfp1BcmtqM7lnp/9rm6kD5mOh+wfb9K5fHqnP3nOkJOkfJ/9eV334Jk2a3Tlq32BpizRFyUxSuhU+yY2ZP0kzztpVZ0z6k37Veb6mta3RvS9tr6N/9haDwkwQk6Ken4txZjS0AUiO4gyAEbIYbeb9n72N+gSNGn/QsJ2jXJNHJ1hwYLnkvw/RjJv3UPdgVY+NH9SrPnet9j5wacR5vEaHjct750uSFk54XmNLcRfu/co/eIQ/55S2Tfrw1Pv1u3m/1R0LLtHlPdfoinnX6K6df6Mre67SJ6b9Vbt3vaK0AeP1xz+jWf/yBz3fVdKULYPSLw7S9T/bU5K0pTr0v9z2yDFzaQszUcfWIWREcRQyvExChilCBgAAcKZo9+s0uZ9FJk1tCvg8rKkt7GvzFTLR55Skkv7jmv31jutPUl+1XW/s+qtuftdvNO/AbovzpmnsqkdmcvG8KVfqlKTKYTtoxunz9anxv9Z/d/yfukpbdOmjC3TYj99qWJiJy8th+ch20kDQ51G/+y3Q0AbAKYozAGKlHW3mFTbarJhBI+jCynaOctDn3q9ddIKFn+f3P9tT/T96nab1D2pxZ1nj3nebFr790Yjz+G278L5/01Q9u3mCxpa36NgJz1ucQxodANIEELtznTJxka7f6bf6wHYPa07nOg1WpeVburVyS5ckaUFXr9495TFd2nO9frjjzTporGkBy2tQb/3YnVr3lnu0sr2sOX1VPf/fx+iem+Zs3WNLdWjlTFvoypm4wozpdtuZyTVRBciEIaNWHI36+2gYMrzC5iabYG4yAABwoUj366SpLewxv+DM9Mu75uuYX5yulVvGaf+2Rbpt4fd08CldBveh8XIxWiws5yTJTS7PVZP+eyyPbdf0M3bRjgdP0o86/kvva79GkvS5Ww/SW35zgtb3m6xcsrk3Z9pxZnH35/RuC1lJVoBVM04xbQBwiuIMAHsZzFD2Kk7QqAkKFf7HJLvxZmFfx3X82IeNB27fQY9+/o2av76kdW1lLTn2Eb3jczeqvdPmwnqLpAFd0buTJOnEic9YHBslKjSkCxMdGtD5M+7UF2beqXHlLXpo4xR9bMmhOvCpt+qwRafr9YtO0cFPnqKPLXmdrlszW1uqJR08bpl+NPsW/e+sP2rHDpMxcNL4SX16x39fq0f3WawtpZL2WNWpP37yzXr28Skj9hsY/l9u8Fgzk8KMbfeX6TizsLBckzBkBKkVaQxDhhdzkwEAQNHlcb/O5mtqSzrezG2B5vanZ+nAC96qReu3046l5frd3t/Qae+pqm2i6ZizGpf3f/HKJjPFc3dPm64dx2vm2Xtot571uqzjMzqi7QFt6G/XW37zZv37La9TVSbFMFeTBpKOM4srPqaQU0MbI82AYqE4A2AUV29C2gaNWA0RNJJ2gtks1bcPG8uen6DrzjlRe700QdVSSffPXa2T/ucKzd1ldcxzjXT12tmSpNeOXaZK2yarY/NU1qC+PPNPOnXSIg1US/rGK/vqzOffqN+tnav1g2XVwkXvYJd+t3aOznnpEL3h6Tfrl6t21pZqSUdPWKyreq7Vuyc/rnLEGLI9D1iqI79ype6fuknlalW7PTpDF3/0zVq7avQFemdpQJK0uer/X69tYca2+6tOIcPBqhnXc5NNETIAAIBzGTe1BeoJ2JZrU5vJPTvl2demqc1//vQFmqdWTtaB336r/rRstiaWNujC6d/QR963XN09SVaRZ1WkyYvb1z/htdM1/W276vBJf9PlHZ/Wzm1L9ELveL3+R2fokkd3MTxL0sKMf/+048wcNrQFrZopWkObxT06AdijOAPASOSblSlmKHt5OzsaN2go4HOTTjA/k/uJhAnef/OmDl107kLN/uMuGjM4qL+Nq2rOv96g48583PjMi/vH69FNk9Vequqo8c+pmKGjqs9O/4veOPE59VfL+tDiI/R/K3dXVQOKeq1LtozTF15+tU559g360/rp6ioP6hPb36+fzL45YBXNoE7/8J815gO36ZnukiYNDGri5fvrkv98vcL+19o9XJzZNNjm2Wry3zZt91fQ/nEhI+Ymr3VcNZNmbjIjzQAAgEtZ36/T23wSZ0RTS12b2qKmDbgYb5btWOgVG8foqO+fol89voc6SgP66rgf6D/ffp8mHbR9xLmiuFt5kj33r7XU2aapJ8/XlKN21NkdN+inHV/R5PIG3fnCTL32+2fpvqVRId7LZgR02LE2kwZM789Z/1Uz9WpoC8K0AcAexRkAyeTYMdG4QcN0vJnrpfpB+29z3Q/3Ue//Hq7Zm6RV7WUtecPDeseXr9WkKRtjzjnkhrU7SpKOG/+iZ2txQsfpk57SWypPaaBa0jlLDtZt600v+Ics2jxJ733xCH1m6Wu1frBdrxn7in7b8zudNmmRpKqm77hOZ379Sj3+qhe0qVzWLuvKevZLx+r2KxdEnndMebg4U60VZ8L+G7nu/ko6MzmFpKtmPJibDAAAGpHL+3V6hY02K15TW03UKgPb8WY2UwfSF2j6Btr19l8v1Of/eJAk6Z87rtBvjvmN9n7Ldip3twUeY6Y4mWmb7F5Tx7Qxmvmu3VXZdYK+2PYDfa7jp2ovDeqn9++hI3/6Fi1bPy7+JJKymzTg/Tzod69dZiPOi7dqJo+GNqYNAG5QnAEQKOjNyLxnKDdO0Eg73kwRX2dToHnivum6+ZwTtNeSoQvJ+2ds0IFfvkqHn/h0zDmlG4ZHmx08bpkmlDcH7LFF9QkeW7RL13Kdt/09kqT/Wb6Pblw3O+G5Srqkd75OfvYNumfDNI0tb9HnZ9ytiw66Tq/6t2v0cGWL2qtV7f7YDF3+zydr8VOTY884rX2o+LV8yxilL8zYdH8FfW4yziyDVTNxmJsMAACaWcYrd2ObW3JvaosbnRs33izJ1AHbx8Ouy0v6zE2v09m/XajNg206tu2vumP383XuP72kCXvFX/vHq19myvx5S9L4/adpxt/tpnnbrdXPy+fr7R03a7AqnXPDYTr7ioXqGzCdDuGiMOP/OqiZ0r9/1KSBuHvRanRmitIkDW0AkqE4A8CNlg0aLsabZbVUP/pcm9Z36aLzjtPkK/bVdv2DWtJZ1qpT/qq/++L1mjx1Q+hxz/ZP1JN9k9RRGtQR4xfHPL80OgCkDQLB52vXoL484y51lwd027qZ+vHK3VI8x5DF/eN19gtH6gcDu2qgLO2zao3O+2lVhy0aVOknB+s3X3m9qqPuIRNsZvt6SdJL/WEdYra/A3HdX0HnMp2Z7N/H92Ul4pCoVTOGHWDev++Zz00GAABIwdVoszRNbYFiml7yaWqrSdrUpoCvk0wd8D/uF37t/dMH9tRr/u/t+vOyHTShtFHnj/uFbj3lpzry7LHqmOpo1bmkvDJT1rp2GK8Z79xd04/bQR/svkY3dHxcB3U+od5NnXrzhSfra3e+RlLJ8GyuCjNx9zMyGWfm3xZ1j9oANqtm4hS0oY1pA0AyFGcAJNesQSO2QBN1kM14s6SdYO4KNJL0x8t31n3nvll7vDJGg6WS7pu1Vvt++Wqd9N77JQ0GHvP7wNFmtsICQ9yfYO+a8jft1r1aq7Z06tylB6pqfNEfrlwe0Kn//Ff9+dwndd4727RkirTdWumDvx7U4Y+sUHvIzyfIgq41kqQX+8cHPBr139BmpVXacWYpQ0aQoJARMzc5SJqQEYmRZgAAwLEkEwfS8DazBDa19Xg+j+rEL0RTm3fElMupA/7H/cIz00MvT9Pr/u8t+uC1R6u3v0v7l5/SDXO+qG+87wHNOmZ7lTqzfGvNbWbKStv4Dm335nma8Y7ddPDMxbq64zx9suMijS336+ZndtRrv3+WrntqJ4szpinM+I+JKsSEjTOriWpoM5w0EIWGNqDlUZwBEMrVaLM0UgeNqNUzleGPJm8wjxAWNPLqBDN53C+6QLNm+Vhd/Ik3aeo1e2v65qqWt5e16PVP6W3/+1vtc9CyUfvfuG6oOPO6cUvVXar/zOQdO9bpg9s9LEn6z1f216qB7tTnPPT4p3XCty/X469+Tr1tZW2ZXNJ57QfrN6t3UrkkvW+7x/SLOTdqbscao/O9duzLkqR7N/pvJuqi+yvNODP/YylCRkFXzTDSDAAAFEaCNzzDmtpimTa1BV27OWlqS3LPTv82/+dxBZtsCjSD1bIuuHtf7f7Nd+nSJ3ZRe2lQ/9hxte46+Kt6+z8OauxuLkadNaC2kiYeOEOz/mEv7bBnl77Y/n1d1vXv2q3tRb2yfozeefkbdNTP3qInV9r8fNIWZlxMGjBpaDOUdNVMDg1txhhpBmSG4gwAd/LsrHCxeiZKojnK/n2z7AQzedwvfq7vLb/ZVbd/9CTtvmiq2qtVPTRhUKX336a/+8q16tl59db9Hu+raHH/WI0pD+jgsaOLN3n7yNQHNaY8oLvWb68r1vSkOtfOe63QO756pVa99V4tGiONHRzUbo/M1PX/fLL+8se5+uyyA/SRxYeod6BT+4xZqd/2/E7vn/KIOjQQes6ejvWa17lWA9WS/jqiOOO6MOP//Qo63mRmsoEcQ0YYb8gwRgcYAABwLM1oM5P7dYYJu19nbFNbFKdNbTVJm9pMpw7YjIUOOlfc/tu8tG68Tv/Vm3XCr07SC+snanb5Ff2i8jX96vRrtefbZ6h9clfk8c2ke6dJmvX3e2ryETvo1O479YeOc/T29pslST+8b0/t9u2z9fMH95DdGDPXhZmo+8zEjTNzeH/OIAVoaGOkGVB/FGcApJPDaLPYi40es9eQTdCoieqmyaoTzORxv/gCTd/6Tv3m80do7dcP1269HRoolXTf9A2adu4NetfnbtS0WRsklXTzuh0kSUca3XcmO7t0rdabJj4vaWjVjPnF/0g77b5K7/jydRp3zk26f7vNkqQ9Xx6r5//jDbrkvw7RYP+2n93v183WKc8u1O3rZ6irPKiPTHtIV8z7nd484VmVR406a9NHpj0gSbp1/SytHewc3p50WX7Y10G/N2HjzOTbVoyQ4eX9ex/WAeaVaG4yI80AAEBGIlfk1rupLWrigFfdm9r826L28YrLTFHHhu0/0tVPzNfu33invnbXqzVQLenEtjt1x4J/18fe97ImHzpDpY7mfbutvdKlaact0PS37Kxdtluln5e/oK93fkdT29bq0Vem6NAfv1XvvXKhVm60aQIL+5mb/LcMeyzuPjP+z23vz2mhMvyxDg1tidDQBuSmef9vAcAJ69Fmnv+Juxpt5jWiCywuaJgs0w9SGf6Yao5yTZJOMNsCjV/6Ao0kLXpge13ykZM0/jevVs+GkjaWy/rr3NXa6fyr9Y7/+L0emVSRJB0xfnFAQSI/H95u6E3569bM1uN99uMEdt1nud75xd+p8okbdf+M9dpSKmmXdW1q/8mB+vW/Hq+lzwbff2XplnF634uH6xNLDtLyLV3q6Vyr/5x1l3437xqdM+1+HTpuiQ4Z+7L+a+btWjjheQ1US/r28n2Hj86y+yvoPCa/qwlUhj86ChmBXZ4uRYQMRpoBAIBMFbWpLaiJppBNbWEZKelY6KB94vYfaX1/p865/nAd8P23695lMzSptEFf6f6hfnfED3XkP22nyUfPbqqVNO2TOlU5YkfNes+eqiwYpw+XL9HvOv5Vh3Y+qo39bTrvD4dov+/+nW5/fkfbM4dst2lMdDVpwOb+nI3R0MY9OoFiM3uHDgAc20cP6UHtLWnoYuEJ7Spp6CLiKc2XNHRxsUgLJA1ddLyg2eEn7JH0rMETT5W03LetImm1hi6e1sedYIykjRq6IAu710htn9rHdm27KWPt8w5J/b5tflH7xB3vfTxI2HOOdtc186Rr5uqosx5Vx2GPaXFXWffP6dXYf/2rNn1Tmqo+7dO9UvdvirqDaDb26V6hoyYs1kC1pG+t2NviyEEdfsKz2uGoh/RIZbPuLZUklTR/fVmvXLefLrva9GaVJV2ztkc3r9tBZ01+Un8/5XHt2Lle75nyuN4z5fGtew1US/rcsgP0WN8UJSvMmHwdFFyjQmhzhgxjzE0GAAAOnVCt6qpSshXcpVul6uFuX8/2O72gl58ezk87VqUXfa9tpqSXfAcFZaVpkl4JeZKKhnJULfZs5c9M/nxUyyreA/05qXYd26+R2SYqB0VlpqDHg/bxqr2G6Nx070vTdcD/nal/PvA+ff6oO3VAx9903YRP6fsHvUk/fc1CPfdMWWvvfUUbF62WGu197LaSxu5c0fh9p2lMz0S1aUBHle/Vv1V/oQWdQ/fU/N1Tc/Wha4/W06sqCZ4gj8JM1DGO7s8ZpTL8scANbSMaa2loA3LVtCtn7r77bh1//PGqVCoaN26cDjroIF188cX1fllAQ8pihnIaVqtnapytnvELW5HguhPMf1zcY/7Hg9jU58u66Zd76foPnKYZN+ypeRtK2tBR1j0LhkLeB079g97yT3ersl2fxTnT+8jUByVJV6zp0TObg1e4eM3YcZ3e9tG7dNIFl2rFaffqwcn9GiiVtPPaNlWu2FdXfOgU/cm4MLPNhmqHvr9yDx216ER9dPGhunpNj/7WV9GTfZN01Zoevf35hbqkd2clL8y46P5KEDKifv8rwx/ThAwPVyEjydzkIHSAAWgV5CYgO65Hm5msnonVE7At6eqZSGnGm2U5dSCb3DRQLevrd71ae3zrnbrybzupszSgD7VfqTu7/km/2uX/9K63rNC89++miQfNUHlM8fuk26d0q3Lkjtrxg/to2knztfO8TTqn7de6ve2f9P3Or2lB18t6ae04nXHJm/TGX55ax8JM1HbTSQMO7s/ZrA1tADJX/P8jJHDzzTdr4cKF6u7u1plnnqkJEybo0ksv1RlnnKEXXnhB55xzTr1fItDwrnpSOmHnkAdvknSU3flMVs/E6tHo1TPTJfnvV+989UxYJ5hXmk6wqG4uFytoFHBciGpZN164u3Thbnr1wme1ZOf7pEc3a/4zVX3ruOe0+/7PaP6qsVpx3zzddtUCrV+T3TL+A8Ys08Hjlqm/WtZ3lu8Zut/EyX06/KQnVNl3kR6d1K+HykOrZDoHq9p5+Tg9fPn+uvzOmMqBoY3Vbl2/bq6uXzc34FGXhZmgz026vyxDRk1l+KPrkBEzN9kkZBijAwwARiE3ATm7XtLC+N32W/mY7p+yu6SREwdMeCcOjFg9EyTT1TNx/KtnbKcO+LcpYB//cynk8aB9/MwmD7ywZqJOuugknbzbU/r4wX/VIXOW6Ki2+3VU2/1atf14/fboQ/SbQw/VXx6dpLX3vazNS2LDZ25K7SWN3XWyxu87Td2zJ6hT/Tq2/FedoT/osM5Htu63fEO3fnTfXvriHw9Qb193gmeKeivStjATN97OdNKAg3FmQSrDHwu2aoZ7dALFUapWm+tv0pYtW7TbbrvpxRdf1F133aX99ttPktTb26sDDjhAzz77rJ544gnNnRv0xtloa9as0aRJk9Tb26uJE+O7soFmFbREf1Rxxhs0PMUZ/xL9WtCQNCJo1IozkrYWZySNKM54R5uNCBq1ZfrPep6oFjS8xZnlvo/StqCx2rNtvW/bqKBR27DG93XtY3/I9i2ex7f49g3aVjvG/7WXP0QEhYaooBF1XLTx5c26Y8Hl6ihV9bm/L+uR6dsWZI4bHNSuvV3asGgHPXDzPD35yBRJyUY9jFbVL+fcqP3HrNAvV+2sL7z86hGP7fmq5drr0KfVttMSPTGhX5vK217XrD5p3KKZuv3CffTKixMcvR7JvPPLv2+SwkxQ91dUyBijVCGjMvzRP2PcGzKm+j5K24oz3pDR4/l8OGh4Q4bLucmhy/N9ISOoOEPQQKvjGrj5ucxN/L4A28TmppDMJI3MTZlmJmlbbvIWZ2q5yZuVap97izOrhz/GZibvRpPcFJSZal/3R2wLOsa7TQGPBT0etI+fXWbaZbuVOnu/R/TOfR7TDhPXbd3+6OBcXTJwmC5e+iotfrqqviXrtHnJeg2sM8lt7pQ6y+rcfqzG7jZZ4/bcTm3d7VpQelFnlG/WKaU/amr7ttd8w6K5+sG9e+mKv83X5oGkvd55FWbiJg14HzedhmE5Brri2aeWm7xZqfa5t6Gtlpt6PNtCijNBuSmqoS0sN0WONPPkJjITMJqLa+CmWzlz0003adGiRXr3u9+9NWBI0qRJk3Teeefp7LPP1k9/+lN95jOfqd+LBBpQ0WYoh+pR9OqZWidY1OqZIKnnKBehE0xy0Q3mtW6wU3/ZsL0OGbdM476yl8Yc2K35hz6pJduv0cr2su6d3C+95ll1veZZHdc/oJm9Y9X30lQtfWJ7PXHf9lry4nglKdgcNu4l7T9mhTYOtumq7rk65rQntf1OL6s8c4VemrBRyzratK23q6yp/VVNWzpZj16/u264fVai5wxnEzD8+9usponax3acmYXK8Meom79msGomLdORZoQMAK2K3AQUgOHEgTT368x09Uxt6kBFCVfP+PkzU9TUgbCMZHrfzqDHg/bxs5s88MSKKTrvD4fqUzcdouPmP6d37/eITtptkfZoe06fKf9cn9zxQt0061X6zcBhumVwP/Wt2aK+JevVt2S9Ni9Zp76lG6QBN9em5THt6pw+dsSf9sldKpVKGqNNelPbHTqj+ge9tnNbg9TiNeP0o/v30o/u20vPrp6U8hVkUZgJ+jpqBJ5/m+PCTJAMxkCHNbSFiWpoC8U9OoFcNF1x5pZbbpEkHXfccaMeW7hwqEXl1ltTDKEHsJWL0Wa5Bw2vWugwCRqBktzossYbKmohQgHb0hRopKzDhiTdtG5HHTJumY4ct0Q/uu4Y3X3dTlJpUPse8aIWHPaUNs9Ypee6B7Wko01LpvZJUxdLey/WlNOkPfsHVNncrq6+DrVt6FZ1fbc2ru/WwOZ2VfvLGtzSpsFyVZ3d/ers3qy27i0qj92gs69bLa2Qbj9wQIPH3KKlkpZufUVtaqtWNWdjWV2Lp+qxW3bWLXfMUDa3Wcu6MGPb/VXTXCEj9dzkBDPdAaDZkZuAOokYbeZtavOONnNmx+q21TM9Sj8SOkgmTW01tX1txkJnUaAJOy7cYLWs3z01T797ap4md2/U2/b+m979qkf0mhnL9Ia2u/WGtrv1cnWSLp/yel056RAt2m2mNmm2qgOD2rxs49DKmqXrNdg3IA1K1WpVGqyqOiipWlV1sCpVNbxt6PP2StfIQszEzq2vZ7w2aG7pJfWUlunAgYd0cuedmti2SZK0ZbCkq5/YST+4d2/97qkeDVTT5qi4tx7TFGZMRkDH3WfG/5hhZgrSpA1tQWhoA9xouuLMk08OtcDuvPPod4xnzJih8ePHb90nSF9fn/r6tt3Qes2aNaH7AvBplKARFDBqKhodNGrbMu0Eq13YB22rsS3QBO0Ttp+fedi4ed0sfXr6X7X/mOWa0rZJKwe6pWpZD9w8Rw/cPEeSNLayUfsf+YKm771Ymtqr5WP7tbK9pJc62vRSR1Uat1maslnbxh2EO+ixQc1YUdWGTulXr2uTJE3tr2q7DZ0qL5+kxQ/uqL/+YbYeWJfd/W7sA4b/mLSFGf9xNjezbM6Q4WpuMgC0ijS5icwEhAuaOFCP+3XWbfVMoKQFGml0PjItxuRVoFHAsdFWbRqj79y9n75z937aa/tX9O79HtE79nlM24/r1fvbr9H726+RJC0fnKAlmqolc6dq8ZypWlLdTi9Wp2pJdejzFZqoqIkAE7RBc0tLNa/0rOaWlmleeanmDLykuaVl2r5j7aj9F62cpB/ct7d+ev8eemndeKvvKVyaZragfdKMgPafI+WkgRZqaOMenUB2mq4409vbK2loOX6QiRMnbt0nyJe+9CV97nOfy+S1AY2uaYKGl8nqmUhpO8GCur7ClurLt58UHDT8x2QbNpZuGadHNk3Wnt2rdMS4JbpszU6j9tmweozuuHwX6fJdtm4bU9moXV+zTDN7VmvMlPVqm7BRA2P7pI5+DZarGihVtaVcVXtVahsoqzxQVtvmNv3dTUNB7dZp09R36c565t6penjV2JjvxSXXhZmgY9N0f6UcYxY2M9nLUcjwsg0ZxixDBh1gAFpFmtxEZgJSimhq8/I2tXknDqSSpKnNdPVMbVvqpraasKY2k6kDQY1uWYyGDjvWzMMvT9M5NxyhT954qI7f+Rm9e79HdNS85zWhq19Ty2s1VWu1j54JPHbjYIeWDEzR4oHt9OLgdlpRnaiZbavU0/ay5rUt03bl0QUYtW37dNm6sXpyZUWPLZ+iXz20m255draqzsY/p81MQfuYFmbCtmUwacCLhjYACTRdcSatc889Vx/72Me2fr1mzRrNnh3xxi+ARAoTNLxqBZqKEgaNJJ1gNf5QkWapvv+YqH0kF2Hj5nU7aM/uVTpy/OLA4kyQjavH6P4be3S/0d5DTpr4jKbN/LNWD3TqM7ccpvWDUYUO17IIGN594gJFku4vByGjphKwLWXI8HaA2WJuMgDUD5kJyE6S+3WaNLXFapimtqipA1FjoaVkBZqg/YIkW0VT0z/Ypiv+tkBX/G2BpKoq3X2aM2mt5k5aozmT1ng+X6u5lTWaNWG9xpT7Nb+8TPM7/LPotlm6bqyeXDFZT62s6MmVleGPk7Vo5SSt3ZzVpIG8CzP+fVxPGohQGf5IQxuAhJquOFPr/Arr8lqzZo0mT54cenxXV5e6urIchQM0OcPRZqZyDRpeFQ0VY4KCRugcZVtRnWBpl+pLeYaNm9btoH+a+rBeN26puktbtKnq/n8vHRrQh7Z7WJL0g5W7N1lhJujrnLu/vCrDHwsUMkylmZsMAK0kTW4iMwHRspo4EHa/ThMjJg40XFOb6dSBuLHQQYIykwLOZZKZascnK9BsU9LqTd1avalbDy4LuviWOtu2aIcJ6zS3MlSwmTNpraaPW6/FaycMFWBWVLRoVUXrNncGHp8NkwzoojATdE5/PgralnTSQIs1tAHIVRZ3R66r2szkoPnIS5cu1bp16wLnKgMwE9QhETl/1PBG3N6LBttOjxEXK96lvz0BO3svgoIukoKufaOWJ28V9sZ33BvlYReMHSHbZLDNK+rCNmq/MO2B53i8r6LF/WM1pjygg8eGd26lcWrlae3YuV6vbOnWhavy+nc8+PsdyUVhJm5ZftTvg6NCWNTM5KBtpiHDK0XIyGNuMh1gAFoJuQmos4gVvd5mk1zfUI1rrgm6/qsMfwxq6hn1JrbJm+L+3BQk6I342vag/OPPTEmaqGr7meQmkwyRzuaBdj2zuqJbnp2jnz2wp86/7SD983VH68u3H6BLHt1FDyzbPsfCjGlmCvpvk6QwE5ajbEZA09AWimkDQK6arjhz+OFDbfk33HDDqMeuv/76EfsAyIjh/8zrFjSC3kCOChpBjJpt0hZogrZlXaCxKdJ4lXTzuh0kSUeNf9HwHOa6S1v0gSmPSpK+u2LPTFbmjJQ0YNSO9e8XtY9JYcZ/XIYho8ZVyOiJf6o0mJsMAMmQm4BsWTd9pGxq875B633j1vuGbmM0tYXtH9XU5v3o/7y2b1wmMi3QhO0bJPsiTX2Zfn82xa+wfaKysvfzqAY228wUcYpKwGO1bd6/GzmvmqGhDWg8TVecOfroo7XTTjvpwgsv1P333791e29vr774xS+qs7NT73znO+v3AoEWZzpyKNOg4ZVp0DDZN+pN8rALTFcFGrdh4w/rdpQkHTPhRXWUBgzPYeadk/+m6R0btbh/rC7pNbunTTJpAkbt+Kj9/Oc3nZfssvsrQFYhwyskZNjOTXYRMoIQMgC0GnITkL/IiQMRGmb1TNC2yJ4gl1MHFLGttr0eBZraOZqpSGOTmUyb2ZIWZrz/XV1PGjAYZxb1PsHUkM9rLFbNpMU9OoHiarriTHt7u37wgx9ocHBQhx12mN73vvfpnHPO0b777qsnnnhCX/ziF9XT01Pvlwk0NJejzbwXCcZvuiblavVMbVtg0DAdb+Zn2gnmskATtl9tX7sizd0bZmpp/xhNauvXEeOWGB4bb3LbJr13ytDvyddf2Uf91TZn597GRcCwWZJfezxqm4vuryA5hoyeiJfhQNI3KpK+MQIAzYTcBBSA4WizKIVdPRPJ5FrV9dSB2naTa3TTxjabzBR1nkZh8/qTNrP59zH57xy0raDjzHJqaIvCPTqBYmm64owkHXnkkbr99tt1yCGH6Ne//rUuuOACTZ8+XRdddJHOOeecer88oDUUPWh4pQkauYw3M3mTPosCTdj+wQZV1lVrhla1nDTxWePj4vzjdo9ofNsWPbJpsq5dO9fZebeFC9cBI2hfV4UZ2+4vQoYkOsAAIAS5CchWHvfrzERdm9qi+HOTl0lTW227ydjnrBrbaudppCKNbWZK2swmJS/CBG3LYNJAkErAtqC/I14ZN7R5/42IbGhjpBlQd01ZnJGkAw44QNddd516e3u1YcMG/fnPf9YZZ5xR75cFNI08ZihnwnXQCBR1ceeyQBMWKkwKNLZjzszCxlVr5kmSDh2/RFPb+iPOaWZOx1qdURl60/6/X9lPVZVSnW+IbRiyDRhJCzP+z113fwUwGWfmVaCQEYmQAQDGyE1AnRk2tZmuGE7V1BYns6Y20/FmfkmmDtS2uy7QhO0fxbZZLE9JXluaZjb/fjaFmTQjoIPQ0AYgP01bnAGQv6xnKNd19Uxl+OO4gG2pO8GSFmj8nwftW9vmKmxEB46nNlf0wMap6ihV9ZZK7Rcieej4xLT71VGq6rZ1M/XnDXFXs1GShgvXASOqMOMNia4KM14W48y86hgyonj/3SBkAACAVuNtWjF9YzZWj+fztE1tVvfsDNrmeuqA6wKNq1U0/nPWs1CT9DWkbWaTRuemuG0m0yVMGOT2FmxoC0JDG5ANijMAspNwtFkmQcPLJmjk0gkWxWbZtun2oPP790veEfaLVbtKks6oPKl2DYacO/6i/6jxL+roCYvVXy3pv17ZL3Lf6OdwFS685w86Jmq/sNDgcl5yEItxZnHb6hAyvH//M19VBwAAkIGs7tdpKvfVM5WAx2vbjJraghShQOO+sS1cmiyT5/O4amZzVZjJaZxZgza02eAenUB+KM4ASKzhgkZPzAmnhnxeUzHctpWLG12GMQkaUdv953IXNm5YO0evbOnW9u0btXDCcxHnqD3H6D9jS1X9v+3vlST9ZOVuWrR5kvGxyZkUZVx1fgVtjyq8mXxfjR8yTDE3GQAANJUM79dprcfzeWGa2oJO4LpAE3Qt7pe0sS1pkcb/HEmyT5pjw8QVZZI0s3n3yaMwE8RynFll9G6B7yN4/+7UsaGNaQNAMVGcAZCtIgUNr7ig4ZUqaJiMN8uyEyxqe9jr8O9rEja2HduvNl04vHrmg1MfVNuo1TPxPrH9vZrZsUEvbh6nC1bsK7dhwi9JUaZ2XNy+toWZuODmYJyZV2X4Y1DI8P7e5xQymJsMAAAg46Y205XFjdnUZrNaPIvcZJMBvOcI46JIE/a8WRRggiSdMGBamInLUmkLM14Ox5nFNbTFoaENaGkUZwDky3CuafMEDa96dYJFbfez7YDyHzt0/M9X7aaVW7o0r3OtTpr4dMxxIx09/gWdUXlSg1XpM8sO0qZqVkv5TQJS2oBhsj2L7i+HM5O96hQyvP8eJA0ZAAAA9WI9ccCnNZvagh4zaUzKIzdF5QHTzJRFoSYLJq83TSGrdu6sMpNXxpMGaGgDkADFGQCppA0aXlFvuuYaNIIuqrwXXZXhj0GrDSKDRtA2k24eFwWasOX6NmEj6piRx2+ojtX3V+4jSfrwtAc0qdwXc8yQ2R1r9fkZd0qSfrxqD921YWbMEbZMw1DaTrm8CjNeFt1fXs0aMnzoAAMAAIWW8A3SJPfly72pbVzANqvxZl5xUwe8itzYVju+qEWatJkpyykDYZkpiKNJA3Hb4v4+ePV4PqehDWh5FGcAZC/haLO6BQ0v74VV7p1g3s+TFGhchI3kgeNXq3fTor5J2r59o/7f9L9G7itJk8p9+r8db1KlbbMe2jhF31y+b+wxZmy609IGDMl9yHDU/eVVMdxmEzK8DEKGqST/DkiiAwwAABRe1vfr9Da/JBmLtFXapjbFbNvKe73rYupA0LY0BZosGtu856hnoaZD5q8j6vvKu5nNy/GkAS/T+3N6hTW0pZCkoc0GDW1A/ijOAMhfowYNr4rhtkBBF4ImF5A2BRpXYSNoX7PjNlfbdd7SQzVQLenNE5/WaZOeUdgF/3ZtG/W9Hf+gns61WtI/Th9afIT6q20RzxvFJljEfx92AaO2JD9oqb73GNvCTNDrsez+Chpn5jJk9MQc5+P9e2u6Go65yQAAoFEZXXs4aGozfcO2+E1tQduCVpGbFGjk22ZToMmnsW30ebIu2GSRmepRmEkzaSBiFyl6YoZXhqtmTEU1tDHSDCg2ijMAUjMabVbkoBHXxWITNCrDH2ODRpKl+pJ5gSbosbjt3seCXpN9keahTdP07RX7SZL+ffqfdGblMUm1/w5D59yjq1cXzb1ee49ZqdUDXfrAi8dq+cDYiOfyvp404aVd+QaMNIWZjLu/KgGPOw4Zprx/zwkZAACgpbVCU1vqqQNBj0WNvbIp0NhOHqjtG1ekMS3U+M9pm3tcZCYpXWaKa1rzPubdHvbfyX9+ya54NzHgsQCVmG1B7wt4/y44WjWTZ0MbgPqgOAOgPooUNLzCgobpHGWvxJ1gpo+lLdC0y64brHYekyLNtnN8d8W+unj1LmorVfWZ6Xfp57Ov099NfkRnVh7TN2b9Qb+ee7V26FinZzdP1Nuee5Oe2jxZ0UEibfdYXCCy7XozDRjebXGFGS+TsXeOxpllGDLC5iYTMgAAQKtIe79O0/vtJRl3VLemtli2UwdcFmiiHjPJFFGSFGmCniPrzJSmec+/v/e8YY95v476b5dkBHTQY9GbEt2fM0yP5/OUq2ayaGhj2gBQHxRnADjR9EHDy3sxVgl4PGjbCLadYFE3u4y6kM06bJgGjg79+7LX6b9ffo36Btv06rHLdO72f9Fnpt+lYyc8r7ZSVVev2Ulve+5Neq5/Usw5kzLpUksSMDpCzhv0s7cpzCRdlu9onJnjkGHK9O+36b8XEiEDAAA0mIQrgKPeoA1raksyPmmruja1eUWtLE9aoImaPGCTm2wyU9pCjQuuMpNNM5vrwoyDSQOVgEODtnmxagZAQhRnAOSnGYOGl/F4My+TTjDvNpsCTdxjCngsiyLN0H4/WrW/Fj5zhr7+ymt0w9q5+sPaOfruin104jMn619fOly9g90G5zHVLvOg4zJgBD1m8t8jqjAjy8eiNyUeZ5bzqpmov/eRGGkGAACaScKJA7k0tXmvD4PYNLUZjzczeUPeZCy0d1vQqo6oXBSVqYKYrmaxyTAu5JGZ4qYMeB9LUpgJOqejSQNpG9rC7s9JQxsAD4ozAOqnGYNGok4w0wKN/7E0BZokYSN94Hh5yzh9b+X++pclx+mflyzUN5cfqKc2T4s5dxR/oDANMiavN0nASFsosw2bUfcu8qjEbAsaOxGmDiGDDjAAANAs8rpfZxRnTW1eYQ0+zsebeZk2M9kUaIL2V8hjto1tYceFSZp1TM9lwrSoFHSc93njHktamHE4acDLRUObJRragNZEcQaAM1mONmuYoOFV8XzufKm+aYHG5CLXKyps+I8PYxM4/Oe2/WPLpmMt6Niwx+M6v6IKM2mX5YcwHWfmlWPIiMLcZAAAAA8HTW2mY5FSNbXFrSSoeD7PtKnNtEATVhgI2l++x9I0ttWOK2pusl3pE3SsAh5TwGOmmdW0MJNi0kAl4BDvtrj8HzZpoMfzOQ1tAHwozgDIl79DoxGDRlDoqHg+t+oES7pU3/t51MWrd5vJcn3bx6N0KF3ocMHmNUQVZUwDhm3Asy3MZND9VaeQ4f17S8gAAACtrima2rzCmtoqAft6t+VaoAnaFtZYlbSxzdWq/qzZvAaTokzaKQNB25L8N08wacBmnFlcQdIQDW1A66I4A6BQGiJoeIUFDeNOMK+4pfouCzTex9OGDdsl8VkGjyTPYVqUMe0KS1qY8Z/bz1H3l1czhQwfQgYAACgqo2sSi7FDkc0sHoVoavNK1NQWt81Fgcb7uW1jW5oijfccWRdrXOcm7z4mj9sWZryiCjNehpMGvGx+J8MmDdDQBsACxRkATqWdoezXEEEjbuVBxfN5qk6wsGOSFGiiHvcyCRv+/Uz5A4FpOEh6nF/SoowUHjBMVynVsfurotFMxpllFDKimP79l8TcZAAA0Nws3mD1NrtEvaHbHE1tUSvQwx63LdD4H49qbPMf59/HZWaKyj9Jjwt7zS6a2WxXKXnPZzquLMGkgUrAKb3bwhraHClaQxuAfFGcAVB/CTs5ChM0vCqez512gsV1B8UVaGy6wVwUaWxDh5+rAoyfabhQyD6S3c/Vv827PU1hJoTLcWYpmYYM799jqxta0gEGAAAaWNqmNps3YL0K09RWCXhS7zbjpjbv9rACTdDxcQWaqNXwYefPOzN5nyevzOR9TkXsl+TnGvazznAEdMXzedJxZkVeNeNn0NDGtAEgXxRnADhnFDQiJB1tVsigYdUJFncxmaZA492eNGzU9jEJHN59XYSONGzChSL2jVtxFPe497w2/11rUnZ/eRUkZNhgbjIAAGh5Du7X6Ve3pjavsGYiqwJN0ONRb+BL8as1wq7zkzS25VGoScPkdZgWZUyb2RTyeJrCjGK2eVQ8n1s1VrrjetXMKDS0AYVHcQZAffg7NooYNHIfb+aVdYHGu902bNjs598/j+Bh+lxJizJJilxBAcP7edx/TwfdX97tGUuyasaPuckAAADRmrapzbs9kM1YaNMCjXffsCyVprGttp9pc1vRMlNcUUYy+/nEFW68zWxJCjMJJg14VTyfN1tDmw8NbUAxUJwBUEiFCBppVDyfO+kEM73ZpffzuAte//aofWyKNCbL6P1hIEkASXKOsHDhImB49wnrvktSmFHMthAm48wKFDKYmwwAAFpNnvfrLGxTWyXm3LFTB8K2uSjQhDVlpWls8+5r2+CWV2byvkb/eeL2Nf25eY+Vb3vcf0eFPO5wnJmXbWYylHtDG/foBAqJ4gyATDR00OjxfF6YTjDv9ribXfo/DysY2IQN/3P5JQ0cYecx+WMq6LUkXbJvGsy8z1OTtDAT1A0YsmsleJet0hRmDGUeMvyYmwwAAFqJ7w3Xhm9q87JqavOKmzogJSvQRG1P09iWNjdlkZmCXotJI1van5V/u0lhJu7enA4mDcRNxYjS4/m8oA1tAIqD4gyA4ihK0MiTUSeYywKNf3/TC2iTjjDvvmGBw7ZYk0TY89ks21fAvkkDhpS+MBMiz3FmPZ7P6xkyGGkGAACaHE1tAdsTTx0IK9AE5aokmSkqO9g0t3mPySMz+Z8vaSObf3/bVUbe7RkVZsJkPWkgQr0b2hhpBhQHxRkAuUozcijzoFGvOcrWBZqgx10WaIIupmv7mRZpvPtHhY60ASTuPKZdY3FdX1E/E+9+QdtdFGYcd38RMgAAAAoh0bWKRbNKIZra0uSmWEkLNN7t3mKNSRHBnwmiChs2zW1RDW6uM1OS3BRXxHHdzObfPyg4GxRmbCYNuERDGwADFGcAZMYoaPi7wOoZNEzVrUATtM1lgcakIyzoHLaFmrhl8TZ/kjxP1GtPErqiOuaCAp+DwoyJiudzk8KMqRxDxiiEDAAAgFH8b8wWrqktjYrn88RTB/yf593YFvR4mLg8k0VmStLIlnUzm3//hPfmrFdDW0QRs4gNbQDqh+IMgEKzmZuaW9BIouL5PPHIKZdBI6wAERUggi7IkwQO//G24SDpcWGvM+x7ido/aD9/wEg6gi5F95dVl6FPAUOGzd9/QgYAAGhUie7XSVObT94FGpvskKS5zX98vTKTTcHJtOnNVWHG8Qho28KMoUZoaGPaAFA/FGcAZMp10PC/eZtr0Egz3sxELkHDddgI2z/NsnvvOZOGCT+TcKGAfUx/LlkEjDp1fxmqe8gwmL1OyAAAAK2i4ZraGrZAY9PYZpo9kuamPDKT93n8+4V9bZozM85NWWvAhjYAxUNxBkDhFSZoRCl8J1hWYcN0NU3aYo2tqOc1KSbZ/CwasDATpU4hw4+QAQAAWl1sU1sEm7FHdWlqSyv3Ao1NY5tJkcZ1oSapuMyUpgBlMmVAKnRuMhlnFsH7dybqPYi8Gtq4RydQPBRnANRFEYKG9+IoMmj0eD5PMke5rgWarMNG2Lawji1/0cRV+Ig7p+kYNtuAUafCTFoOQ0YUm5DB3GQAANDKEr1B6nsjNqq5xX/t1fBNbcaynDyQpkjjPSaqUJNlZrJ9PWFfx33vUZmpQQozUXo8n0c0tHnR0AaghuIMgMwVNWhEMryoMhpvZqri+dxZgUYyu8iV0oWNsG0my+uDQoLtH7+o501alMmj88tQxfN5VuPMejyf13vVDHOTAQAArDVEU1vmUwck95MHwpq3JPPChm1uqldmSluUMW1mK0BhJorJ/Tl9slg1Q0Mb0HwozgBoSIUJGn5pOsH8UhdoknaDJSnSmAQO7/Fp5yHbntP0daYNGGkLMwUcZ+bjImT4WYUMH0IGAABoFq7v1+lnMx4pt6a2KBXP57kXaKL2k+wyQ9A+3m155KakmUkh24K+v7DHbZrZMijMmErS0OZV71UzNLQBDYniDIBc5B00kq6eGVGgibq4ymO8WSTTC9UkF76S2cW2beAIW44fFBRs/gSJet6oTjX/fv59auJ+fnUszERJMs4sg5AR9cZAbMgwGIFIyAAAAK3Kfy0V1QTjv0YrXFNblIrnc+MCTdhjNqs40jS2effJOzcFMclMQd9P0mY2/9dhP2f/1ylHQLscZ9agDW0AioviDICGYRM0/EyDxiimc5STjDereD637gTzP5ikQON/zDZsBO3j3S8oVJgsr7dlcs648QFB+4btE/fzzbkw45ckZHj1hD9EyAAAAMhfEVfPOG9qy23qQNrR0EFfm+aJeuamemSmJFMG/F8bFGb8Kp7Psxpn5uf5OxBVvHTV0DZKTENb0LQBGtqAYqA4A6CuYoOGhUyCRhTXnWC5Fmhsw4ZpkSYqcIQFiqCgYPPH5Jw2r9W/X41NwFDAvjUpCzN+3sfCCjN+TRwyAAAAGpmLN0zzWD0zStqmNr8kUwecjjiT4icPRDW21Y43KXR4940qwrjOTHHPa/M9eaVpZktQmEk7aSDJOLMez+eG9+f0S9PQFnVPXgCNheIMgNwkChq+N2cbJmhECQsaft7HEhdoXIUNyezCvLZfkqXzSZmED5tutaD94gpYcT/LmgSFGb8k3V8tFDLoAAMAAIiWS1Nbj+fzJE1tfhXP56nHQkvJ7t0ZdJ64xrbaOUwbxrz7FzE3Be1b46qZTXJamPFLMmkgwTgzPxraAAShOAOgqTRM0DBdqu99LFGBxv9YVFEh6HGbjrCowGFarHHZBWYSeIL29zIJGEmW5AftG7K54vk8bfeXX8Yhw88qZPgRMgAAQItwcb9Om6aXTJra/NKON/NLPXVASjd5wKSxzba5LSwzeY/NMzfF5Twvm2Y2/+OOCzN+3scynDTgR0MbABMUZwDUXVMHjbSdYH7OCjQ2YSPoiZMsxzcp1qQRd37bcJEmYPi/NizM+FU8nyfp/vLLOWT4//55xYYMi9npNYQMAACAYP5rr9yb2vzSjjfz8z6WSYFGsm9sk+zvNePPNK5zk8n5bSYkJGlmy7gwk9ekAZ+6rJoB0PAozgDIVRZvnhY6aPil7QSTHBVogr5OEzaSzE0OCgZp/gSxnaEcVpSx7fxKWJgxXZbvfaygIcOPkAEAAJBOUzW1+bloavM+lqpAEzd5wKSxzUVzm//YrDJT1OsIO9Z1M5uUqDDjl+WkAb8irJpJMNKMhjagWCjOACiEhgwaPZ7PTYOGX1QnmLMCjW03WJKwIeV/r5kkz2MbLmx/PlEBI+jxkIcq4btFFmYKFDKiVs34uQgZAAAAzYSmNs/n/mveiufzTAo0/sdNm7VcNrflmZlMG9kks+8z6JxxUwYSFmYqns+j8nOSSQN+PeEPsWoGQFIUZwDkrqGDRhbjzTIv0Ejx3WBpwoZN4PCe33YWctJjo15PWKEprigTtE9GhZkk3V9+UePMejyf+36/s1o14//76gIdYAAAoBXYNrXZSNPUNuK60aapLWq8WZKpA1JGBZqgr00yQm1bVHObaYObbeHG1b1npHQTBjKYMiAlGwHtFzVpwPu72uN7zPM77s9MrlbN+NHQBjQnijMAiquIQcOvJ+JJTTvB/CoWj9UlbCQNHKbzkpPeyDLseYPYfB+mhSqvqGJYxKEV32NJu78yGGfGqhkAAIB8uWg+8V9rZdXUNoqL8WZ+pk1tfs4zU5K8ELVdcp+ZTHJTXFYLa8YzmTCggH0cNbNJyUdA5zxpIM2qGRragNZAcQZAYdi+6Zpl0HA2RzlpJ1jF83nUxaaUU9iwDRwmocPVDS5Nzxn12tIEjIRL8v0PV3yPuej+8sshZPgRMgAAABpHIZra6jJ1IO6aPk1jm+kEghrb+8WYyDszFaAw45fzODM/Vs0ACEJxBkBd1OPNVJug4WcVNFx0gkk5F2jiLp6l4O6rqFARFzpqsrqpZdxrCHssh4AR97Cr7q86hAybVTMAAACwl/Z+nYVvavOzmTqQuEATtIOrxrawfb3bY1+c6pOZol572PcZd6yjwoyf/7GkkwZYNQMgRxRnABRKwwaNnogT2wSNpLOUpZTdYEEnSBI2TAo1JsEjKZPniQoXOQSMoIcrns+z6v7yhowe32N1WjVDBxgAAEA0mtoCvo5qavPzP+akQGPa2JamuS3L3GSTmUy/N9NmtoTjn6XkI6D9bCYN9Hg+T3F/zjSrZgA0N4ozAOqmiEEj6o3nURdfWYw3k+xmKfsfzyxshBVpkhRq/PvYhBAXx/nZBKe4gBG0T8zDFc/ncf99Tbu//Cy6v9KEjHqsmqEDDAAAtKKWaGqry9SBoB2SNrZJ6TKTfz+bwk1Wmcm0KJNhM5tkNwI6g0kDfpF/B3xsV83Q0AY0N4ozAAqnnkHDL9Uc5aSdYH4V39eZF2jCTmIbOLzH2HZ9JQkgUccHsS3KJClsBeziVfF8blOY8YsKGVHdXz7+329CBgAAQP21bFObX9zUgYrnc+eZSUre2CaZZ6Y0mSftOYIkGdPm5biZTUo+AtovxTizRls1Q0MbUFwUZwC0nDRBY5R6dIJJOXaDJSnSmIYO24JNFNPzmqz4CdruZxkwgnapeD63Lcy46v6KCso+hAwAAICCa8WmtszGQtd2yKqxLSwzeY91nZtszpv03jlepj+riIcrvq/TjIBOOs7Mh4Y2AC5RnAFQV8ZvrjoOGjbjl2KDRj06waScw4ZN4JDMQof3HGn/RDEtGplsT9D5FbRLJWJf/2NZdX/5EDIAAAAaS9bXSIVtast16kDQTi4a2yTzzOQ9Tx6ZyUVRxvGUAcntCGhHDW1x9+ekoQ1AHIozAAop7zdj44KGzRvVmXaCVXxf5xY2avvZBg7JLnS4EBcspOiQ4jhg2IQM/2NxhZk03V8NHjIAAABaSeI3WJulqa2uUwfCdkqSmUwKNfXITWFsM2CC1TJBu1R8X2c1AlpyNs4sTtYNbQAaE8UZAHXXCEHDLzZo9EQcbNsJVtew4TJwSKNDR9rwYXuuqNeXQ8Co+LYlvZFl0NdR3V9+TRAy6AADAAAwuF+nY7k2teU9dSDx5AGbzGT6ZFlnJtOxamGP+SVsZlPALhXf12lGQPvFFWZ6Io71KVpDG9MGgMZEcQZAYRU9aFi9oW3bCVa3Ao1pcSJqf+9jRk+s8MAQ98dE3GupQ8CQ7OYl+9l2f/WEnypunBkhAwAAoBhcNacUuqmtnlMHpBwb27yPNXJmklI1s2VdmHE4acAm//szEw1tAMJQnAFQCEUJGqneeLYNGvUo0DgPGyYntpl17ILJ80U9HhWsDJ/eqxKwTz27v6LGScTIPGQYImQAAABsE9vUlvCaK0zqpjab69E8MpN/n9SNbVFFGpN8UpTMpJjHHTazSdkXZhxOGohraMsbDW1A46I4A6DQ8g4afqmDRo/vhFGdYElUfF877QZLWqSJewJ/EEgaQJKcxzZcxB0TsKtXJWCfLLu//Hp8Xxc9ZPj+fhMyAAAAsuG6qS3VdWOP7+u8m9qC9jGOJ2lyk8l565mZkuQmw5fkV/F9nXVhxmLSQJy4SQM0tAGIQnEGQNPJOmhYzVGOkzZoSBkWaGo72oYN73E2ASIsOKQNJknDRe1Yi6fwqvi+Hqf8Q4YF23FmqUNGQoQMAADQyoyvhXJuavNz3tRWrwKNVW4K4qK5LeyYPDOTlFkzW8W3rd6FGcuGttzR0AY0FYozAAqjUYKGX92DhpThcn3vzlFFGtNCTZGW6KftZvPs6lfxfW0S/uocMnIXMzeZkAEAAGDOxbVT1k1tqcabmciiQCM5aGyTGj8zOSjKZDFlwLWUkwaK0tAGoHFQnAFQeE0RNHp8L6geBZqgfayv/dMEDv850oaPJN1hccUky9fTJIUZQgYAAEDjcNXUVrhrtB7f17aZKYhJgcbp5IHaznGZKS43pV0JE3euODlnJilZYaaO48xykbChjWkDQOOgOAOgUJo2aJioV4FGSnCd7yJwBJ0viyX6Nqt7LF+uV0XZFGaaUUzICEPIAAAACEdTW8DXUnyBRjKfPOCsSCM1Xm5KUCByVZgJQkMbgCZAcQZAQ2iJoBEk7wKN0yKNNPIC3yZ0pGXbkWYhacAI2s+kMNOCIYORZgAAANESN60UoKmtYQo0Yfs5z0xS/TOTw0a2sEMqSr5yqeCTBjJBQxvQEijOAGgeOQSN3DvBgiQt0Jh0g0kJV8vbLo93FT6CzmczIsBC2CGVgG1JCjNBCBlbETIAAADijWpyMbzW8sq6qc2JIhRoEhdpbAo1WeamKClGqbnMTFL6wkwOaGgDkBTFGQCFE/YmbBGChhM9vq+TzFJOUqCRMg4b3gOTLp+3/ZPl6/Id6ldR8pEIWcxLzgAhAwAAoAll0NTWEFMHpHQFmqB9k94GpvC5KQHXzWySfWEmSJ0nDSRCQxvQMijOAGguDoJG5uPNghShQBO0b4prczc3r0zDwXOHHVoJ2JZ0Sb7kpvurx/d1I4QMAAAAJJb4fp0GbJvanHAxdcBlgSbzxjb/CeqdmTIoylQCtptOGUhSmHE8acAFGtoARKE4A6CQ6hk0kkjdCRYk7wJN2L6So5yQZbEmzY0vI07nV5Hbzi+pLt1fLjgJGb6/v2Ehgw4wAAAAc0Zv3BaxqS1Ij+/rLAs0lYD9TK/zJUcxx3GuiT2/g1MFqQRsMy14mWSmILaFGQM0tAHIGsUZAA2lqEHDSBZL9YO22RRocltFEyQoeCT94/glBamEbM+6MJOy+ysIIQMAAKA5NH1TWxCTAo2fSYFGSj95QGr+3BSXmSoB201/hqaFGRcjoHO4PycNbQDiUJwBUFhFDxqF6QQL2hYWNCoB25OsoqnHivssJQ0Y9S7MGKhHyDDC3GQAAIDMNFJTW92mDkjZNbZJzZeb4r6fSsh2l5kpaFsGI6BN2N6f0wgNbUDLoTgDoOFkFTTi5ij7L75MZNYJZlqgySNsNHLgSFKUkex+VlkWZnp8XxckZIz6+2YQMpibDAAAkEyjN7UZ6fF9HXRtnHWBRrJvbJOaOzNJ+TWzBW0Lysp1GGdmwsWqmTA0tAGNi+IMgObgIGiYsJ2jHMikE8xVgUZyFzbCjpEaK3CYFJUqIdttAobkrjATpMf3dQbjzEwQMgAAAIqpyE1ticab9fi+zqNAUwnYnjYzNVpuClORmyJW3oUZy3FmJvJaNUNDG9B8KM4AKLSwN2frFTRMOBlvFiTPAk2zBQ7T11WRu4DhsjCToPuLkAEAAIBRCtLUFsTJ1IEgaTKTZLcqpLZ/0DFeRcxMknkjWyXkMduJDEkLM0Ec3GcmCA1tALJEcQZA82j0oNHj+9qkE0zKphtMCg8bccfV1LtQY/P8FWUfMLIszDRhyAAAAIC5Rmhqa5ipA3k1tkkjM0vRc1NF0d+P7ci3NIUZByOggyS5PycNbQDSoDgDoPCaJWgUtkAjJQsbtePCjvXyh44sgoft+StKVpSpHRvEZcCQEnV/BWnEkEEHGAAAQEYyamrLZLxZkJ6AbS4LNFI2jW1hx3rlnZnSNrJJyZrZ8i7MBEgyaSDu/pxBaGgDEIXiDIDmUuCgESjpUv2susEqIa/BtEgTdnyQoGCQ5o8pk9cZ970GyaMw0+P7OuD3h5ABAADQmurd1GYis6kDQbIo0EjJM1Pt2LDjg9Q7M0W9VlfNbFK63OTXE7Atp0kDNLQBsEVxBkBDKGLQcDLeTEq2VF/KP2xIdoEj6jx5q8i8KGM7lsB155fkrPsrCCEDAAAA9WxqczLeTHI/dcB28kAl5DEyUzDT8c8K2eZoBHQQF5MGgtDQBiAOxRkAzcfgYsZ/UZQkaARJNN4sSE/AtrwLNJWQxySzwOE9T9S5slCRfbhwFTCkfAozBt1fhAwAAIDW4rLJJaumtiBG482ynjoQtj2sGUtq7Mxk+9wm0xSC2DSzhW13OALaxaSBIP6/C0b3t6WhDWh5FGcANIxUFx8GFz3+i6esgkYg06X6WXWDZV2k8Z7L+8eVJOc17WgLU6DCTBBCBgAAAGqSTBwIklVTW5C6TB2I2p42M9kWauLOayvpudPch9QmM4Vtz3AEdJCs7s9JQxuAIBRnADS8wKCRYPVMEP9FVlDQMBlvlrgTrCfgReUdNqT0q0+izuvij420s6CjOr9czkqWjAszJt1fJuPMCBkAAAAtJGFmyrOpLZepA7aTB5IWaaTGy0xJizJSfQszAZJOGnBxf85ANLQBEMUZAK2kETvBpPwLNGnChpS8UJMl09dUUX0ChuRsWX6QrMaZBSJkAAAA1JXL+3VK9W1qC+Q6M0l2kwek+Ma2SsTjUvNnJtspA000aSAIDW0AwlCcAdBQUgWNAPWcoywZdoLZsC3QZBU2auoVOsYp2fiAMLYBI2y7TWGmJ+L1DDPp/gpCyAAAAICk4Gu1Vmxqk9w2tknJmtvqmZviVJQuM9msUHJcmEkyaSAIDW0AXKM4A6B5FTBoZN4JJtkt11fEdtOwUYnYx8sfOlwFjzTnrch9wFDI9rSFmYTdX0ULGQAAAMheIzW1+a9XjZvailCgkaIzk2Q/ViyrYk3S3FRR/OuP+hnY/kwzLswEMRlnFoSGNgBpUZwB0HAaOWgEcR40JPdhw3XgqAkKCLZ/bFWUrigj2a9GyqAw06ghgw4wAACAgsm5qc3p1IG8CzRJG9skN/d/qUduiuKymU3KpTCTdNIADW0AskBxBkBza9WgIbkNG1K2gSNrFZm/rriijIuAIREyAAAAkKmiNLVlOnUgTE/AtrACjavJA5JZZpKKmZkk89yUtJkt6rE0hZkAppkp6aQBGtoAuEBxBkBDasagUYgCTdxjkn3gqBjsmwXb588rYEi5FGYIGQAAADCWYVNbkMynDkjFb2yTRmaWisH+rtk+v0lmSjJlwLQwE8ZgBHQQl5MGaGgDkATFGQDNz2HQcDneLJMCje1y/aRhQzIPHNLoi/6K4XGmkp5/mrIJGFLhCjOEDAAAgNZT5Ka2zKcOhCliY1tNRdkWa5Ke3+T7yLqZTWqoSQM0tAEwQXEGQMPKI2i4HG9m0gkWyrRAI9kt15fiL6JdB46aisM/tkzDRdIglkFhxpRpYYaQAQAAgBFybmoLkktTm5SsQOOisc02N1Uc/7Fh+przaGaTGmrSQCAa2gAEoDgDoDUkDBpBko43C2IcNML0hGx3GTZMHpeSB4682Ly+uO81KmCELclPMS9ZSt79FYSQAQAA0DqaoamtEAUaKX1mkoqdmSS7zJRHM5tUmEkDQWhoA5AGxRkADa0oQSNILkFDyi9smDxeU5RCje3rSFOkchEwJEIGAAAA6qcOTW1Jr1eljAs0WTa2SSOzSiPlJpPvryCFmSBpJg0kvj8nDW0AQlCcAdA6ChI0GqJAYxI2khRqsg4dSZ8rbcBqoMIMIQMAAAA1WTe1ZT3eLJSLAo2UvrHNNDNJ+RZqkuQm06KM7ZQBKXVhJkzSSQNBEt+fMwgNbQCGUZwB0LTqETSSjjeT6lSgSRM2avvYBA5pdBBIEkBcnMO0yBRXlClAYcaUaWGGkAEAANBcrK69DJvajFZVB8ilqU2yL9Bk0dhmuo9fWN6xyTwuziG5+R6jfoYOCjN5NLSZoKENgC2KMwAaXpGCRhDToBEm0wKNlD5s2OwXJSo8uFx94+p7chUwpFTL8iW348yCEDIAAACaV9KmtiBpmtrqXqCR0jW2mRYw8spNadi81jTNbA1SmKGhDUBWKM4AaGrNEDSknAo0Los0aQOHa7bhIm6/jAszYQgZAAAASKroTW02nBRoekJOnrSxTTLPQUXMTbavKatmtp6A7SkLM2mYZiYa2gAkQXEGQFPIK2jUs0ATKEmBJm3YSBI46hE6XIcLyW3nl9QQ3V+EDAAAgObXaE1tYawKNFJ9G9v8+zdCZvIeEyWHZrYsJg2kGVNuhIY2AD4UZwA0vcCLnaACTYGk7gSTki3Xl+LDhpQsPGRZrPGf23W4kJIFDCmzwkyhWIQMAAAA1E8rNLXlUqCR8slMLnNT2nOnbWaTMi/M0NAGoNFQnAHQNFJ3lhQoaEh1LtBI8WFDShcYgsJB0j9pnj9O0oAhZVqYadSQQQcYAABAMRW9uaZuBRpXjW31zk1pnzuKSWYqcGHGuaDMxKoZAAEozgBoCcarZxx3txSmQNMT8gJdhA2pmDOTg9i8zrQBoyfksTovyw8qzAQxLsywagYAAKChNOrqmTCZF2gkN41tUmNkJsnudcZ97xk3s0nmhZkwzhvaAMAQxRkATSWLDpM0QcNGprOUpfRhwzZwFCV02L4ek+81ScCQcg8ZpoJ+n7NABxgAAECxuW5qq+fUASmDAo2LxjapeJlJcp+b4n5ePSHbHWSmMHWbNEBDG4AQFGcAtIw0QSOvTrBMl+pL6cKGZBc4pPqFjiTPa1qUKWBhhpABAAAAG3k1taVVtwJNT8SLyjIz5Zmbkj6vi2a2npDHHBVm0kwaqDca2oDWQnEGQNMp4sVMYZbqS9FBQ4oPG5J94JCyu7ll2vOafi9xP5eeiMcK0P1Vb0X8ewkAAIDRitTUFibzAo2Uf2NbTRa5ycU5G6CZTUp/nxka2gDkieIMgJZSr6ARJvNOsDTdYFkVabzqdWNLya4ok7TzSyJkAAAAoJDq3TyTxdSBKHVvbKtnbkrK5rUXoJktr/vMZKXefycB5I/iDICmlOdFTdpOsEwLNFL6sJFX4MiD7etMGzDqUJghZAAAACCtPJrawmQ1FlrKsEBjkpmkxshMkn1myrEwEybtCGgbNLQBcIXiDICW4zpo2KjLUn0pXdiQzMOGVLxCTZLXk1HAkPLv/rJByAAAAGgtRWxqC+NiLLSUokDTE/OkSYo0RclMUrJGNpPM1BPxeILCTNoR0GFSTRpwgIY2oDVRnAHQtLK6uEk73szFUv1MCjQ94Q9LsgsbNfUKHUmfN8OAIbkrzNh0fzkfZ2aJkAEAANCY8mpqy3vqgJRwNLRk3tiWtLktz9yU9HlNv7+eiMcSTBmQ8p00YIWGNgAJUZwB0JLSBA0bLjrBcivQSNmEjRr/xb+r4OHivC4ChlS3wkyYTMaZETIAAACaSr2b2sJkOXVAyrixTSpWbso7M/VEPJ4gM0n5TxqgoQ1AHijOAGhqTi5yUs5RTtsJJmVQoHEZNpIEjpqgkGD7Jw2XAaOOhZm03V+EDAAAAATJoqktq/FmuRZoJLPMJKXPTFJjZCYpk2Y2qRiTBqzQ0AbAEMUZAC3LOGhYyKoTLEqiAo1kFjZ6DF+Ei8CRF5uiUo9SBQzJvjBjK5PurzCEDAAAgKaUZ1NbmHo0tUkOCjQ90bts5aK5LS+2r7VHhS7MhMmkoc3yPQUa2oDWRnEGQNOzvtjJaY6y5Ob+M1J0gSZV2JDMw4ZU3MCR5HX1GOwTEzCSFGay6v6yQsgAAADAsLRNbVmNN8u1QOOySCMVMzNJyTJTT8w+CacMSNlOGgiTuqEtAA1tAMJQnAHQ0tJeJKUdbxbGZYFGqkPYkOpfqEn6/D2qS8CQsu3+ImQAAAAgSlZNbWFsxpvVvUAjmTe29cTvttVM1Tc3JX3+HqVuZpPqW5hJPc6MhjYADlCcAdASihA0wrgKGlIBw0aN/6LfdfBwcf4e1S1gSPl3f4UiZAAAAMAni6a2MLYFmjBJCzSpG9skd7nJtTwzU8IpA1J9CjNh0t6fk4Y2AFEozgBoeVYXSzmON5PqVKDJMmx4BRVUkv5Jo0dOAoZU/8KMTfcXIQMAAAB+RWlqczF1QEpWoJEcNbZJ6XOTy8yUJjf1yPz7SJGZJPvCjK3MJg1Y3p+ThjYAEsUZAC3ESdAI4WK8WZ4FGudho8ds98LokfOAUcTCTBhCBgAAAGxl1dSW5dQBKeMCTV5FmnrpkV1mqkNhphCTBgAgIYozAKD0QSOMi04wyX2BRnIcNqTGCBw9snuNGQYMyc09ZiS77q9QhAwAAAAo/6a2MK6a2qToAk2qyQNSsiJNj9nuddEj50UZKX7KQB6FGRraABQNxRkALSXqIijw4slB0HDRCRYl0wKNlLxIU/tTbz3KpCgjZVOYieKi+4uQAQAAgKSK0tTmskAjOZg8INllJqnxM5NknJlcTxmQ3BVm8r4/JwB4UZwBgCRCgkbROsEkswJNJkWamh7lGzp6lC5cOAgYUv1DRpDQ309CBgAAADyK3tQm1Sc3xXKRmXrsD8/9+erczJZ1YSbL+3PS0AbAi+IMgJbjLGg0WCeYk7AhJQ8c0ugQUPtT73NZfE8mAaOoISMQIQMAAAAuBOWmDJvaojRsY1tNT8CfIpyrwM1strKaNAAANijOAICJDDvB8irQSA7DhpQ+cHj1JPjjgmVRJk3AkArY/RWGkAEAANDyrJvawtShqU3KrkAj5dTY5teT8I8LDpvZpGwKMy4mDYSioQ1ARijOAGhJWQcNqzfCI9SjQCNZhA3JbeDIi+VrThswJLeFGVtW48xCEDIAAAAQq0BNbVL2kwesizSNlJssX7NpM1sRCjP1aGgjMwEIQnEGAEw5WE1gGzSi5FGgSVSkKWrgSPD6TH8GWRRmomQ6zoxVMwAAABhWr6a2IhRopAwa26Ri56aEr62emSnzwoyjhjYACEJxBkDLaoWgkbYbTEpQpJGKEzhShAvTgNGwISNISGGGVTMAAAAw5uCenUnUu0BjnZmkYuSmFK+h3s1smY4yk5w1tJGZAIShOAMANiyDRr0LNJKbsCE5ChxZh46Uz2XzPZoEjMIWZhhnBgAAAEPOmtpCuMpMUrYFmkyLNFJ+ucnB8xS5mS1KvRraACAMxRkALS1R0Mh47FO9CjSZF2lq/GHANhikPT6AbVEmTcCQ6tz9FYZxZgAAAHClTk1tUnYFGimHxjavNLkng8wkuW1mk7IpzNDQBqCRUJwB0PKKNt4sSpYFGsk8bEiOAodfVIjIoIPM9ntIGzCkAnR/ETIAAABgqZGa2qTsRkNLOTe2hckxM0nZNLM1ZGGGhjYAjlGcAYAkGrATTDIv0NS9SJOxJEWZehZmCBkAAACotyI2tSVdXV6vxrZGyk1JXnMRM1PWaGgDkAbFGQBQ9nOUo9SjQOM6bEjFDxxJX59pwChKYcZagpnJhAwAAAAYybipTUo2FlqqT2ObVPzmtqSZKevCTJSo7ExDG4AiozgDAEk5ChpRsirQSNmFDak4hZo0r8NVwJDy7f5yMc4MAAAACJLHeLO8xkJLbgo0UroiTb1zU5rXYfN9py3MJLk3Z9aFGRraAKRVqlb51yLKmjVrNGnSJPX29mrixIn1fjkAMnZVqRT62Ak7hzywMGT7UcGbq4cHb79/yu6hz/2g9g597AntGvqYJD2l+ZGPL9KCyMe9XtBs433DvPx0+nOEcRFqbEJVVgFDImQAqC+ugWGD3xegtdQzM0nhuSnLzCSZ5yYXmUnKLje5KgS5KspI2eSmqGKei9wUN2GD3AQ0PxfXwBRnYhA0gNaSR9CQilegkfIv0gQxCSBZdZW5LMpIBSrMSIQMANa4BoYNfl+A1tOKTW1S42QmqbVzU9aFGYmGNgAUZ3JB0ABaT1E7waRihQ0pu8CRJ9vxA0UszEiEDABucQ0MG/y+AK0nKjNJIbmpSZrapNbLTVlkJqlADW0OM5NEbgJahYtrYO45AwAWrOcoJ7j/TNIbvaedpyyZ3djeK+k9aYrA9rWb/myatTADAAAA1CR689kyM0WJujZOet9Oaeha3tX9O70aNTcled2mmakwhRnHKMwAsNFUxZn+/n5deumlete73qXdd99d48eP14QJE3TggQfqggsu0MDAQL1fIoAGkNfFVJKLw7gbxLso0EjJw0bRA0fS15lH51dShAwAgA0yE4A81LupLU2BRsqmsU1qjCJNmszkqpkti4a2UEwaAFBHTTXW7PHHH98aMI4++mjtuuuu6u3t1VVXXaUlS5bozW9+s6688kqVYpbferFEH2hNiZbpS4VYqi/FL9eXsluy71WE5ftpwo+roowUHwLreZ8ZiZABYCSugZsXmQmAS05HQkuFGgstkZlM2BSo0jazFX3SALkJaC3cc8Zn8eLFuuKKK/Sud71L48aN27p9/fr1OuKII3TPPffo4osv1lve8hbjcxI0gNaVKGhIhbj/jOQ2bEjpAkdNHsHDRSeay4AhETIANB6ugZsXmQmAa3ncs1OqX25qxswk5ZubCpmZJBraAKTCPWd8dthhB33wgx8cETIkady4cfrYxz4mSbr11pzmvwBoeFEXV4nuzZHjUn3J3XL9miTL9v28S+RdLOl3fT7b77GQISPBzO4ohAwAaC5kJgB5cjXeTEo+ztfFaGib8dBFy0xZnNPm+yxkZpK4PyeAQmiv9wvIS0dHhySpvb1lvmUAGbvqyZBOsOsVvVQ/QOnW8E6w/VY+FtoJto8eiuwE20V/i+0Eq10sm3aE1S7CXXSFSW46ttKyDVCm4awuISMMIQMAEIPMBCCJE6rV2LHQVm5S5AqaIFGZSXKXm1o5M0nZ5KY0mSkxy8JMHBraACTVVCtnovzoRz+SJB133HGR+/X19WnNmjUj/gBoXYkvshx3gqVdQeN6FY3kpius3pJ8D6YBoxG6v+IQMgCgtZCZAGTBevVMhKSZSXI3eSDv6QNFkGTCQB6FmbwmDdDQBiArLVGc+d73vqfrrrtORx11lI4//vjIfb/0pS9p0qRJW//Mnl3/G7MBKC6X483ipCnQSNmEDakxA0fSooyLgCHlXJiJQMgAANSQmQCk0QhNbZKbAo2UvLGtkXJT0tecx5QBqTiTBmhoA5BGqVot3r8i55xzjvr6+oz3/8hHPqKddw6+y9zVV1+tU089VbNmzdKdd96pmTNnRp6rr69vxHOvWbNGs2fP5uaWQIuLW6ZfhBtdSvE3u5Tib3hZY3PjSz9Xy/ddShOEXAUMKXlhRkrY/UXIAJAAN3gvPjITgKJxnpmkuuUm08wkNVduaoTMJOU7aSAqN5GZgNbmIjMVsjgzfvx4rV+/3nj/m2++WUccccSo7ddee61OPfVUTZs2Tbfeeqt22mkn69dCMAVQExU2QoOG1NAFGild2JDqGzjSdqbZdMQRMgA0E66Bi4/MBKCImqlAI7VGbmr5zCTR0AYgkaYtzrhwzTXX6LTTTtPUqVN1yy23aMGCZP+TI2gAqGm2oCHlGzZqsgwdrsYEuA4YEiEDQGPhGrg1kJkAZKEoTW1S4xVovLLKTS5Hq1GYAdDKKM6EqIWMKVOm6NZbbw1dvm+CoAHAy3nQkDIr0Ejuw4bkNnD4mQSQLOc0286OdhEwJEIGgOLhGrj5kZkAZKVITW1S/gUaKdvMJMXnpqzvbVO0ZjYp30kDErkJAMWZQNddd51OOeUUTZ48Wbfccot23dXuf6B+BA0AXomDhpSoE0wqZoFGyj5w5Mm2KCM1ZmFGImQAMMM1cHMjMwHIWjMWaKTiFWnylkUzm1SnwoxEQxuAVCjO+Dz++OPab7/91NfXpzPPPDMwZPT09Ojss882PidBA4BfowUNKbuwITV24MiqKCNlXJiRCBkAMsU1cPMiMwHIQyZNbVJDFmik1stMUuMWZiTuzwnADMUZn1tuuUVHHnlk5D6HH364brnlFuNzEjQABCnSeDOp/gUaqbECR70DhkT3F4Bi4xq4eZGZAOSlaFMHJHcFGqn5c1MrZiaJ3ATAHMWZHBA0AARpxE4wKZ+wIRUzcCQNF5J5wJAIGQCaA9fAsMHvC4AweU8dkBqjsa2maLmpaTKTREMbgMy5uAYuO35NANAS4i7KIi/qIt5Aj7qAjLv4jLt4lcwugmt20d+sLrC9FmjR1j/1lPZ12PwM9tFDFGYAAAAAFxJmJinf3JQmM0nFyE0uXkMzFGYAoB5YORODLjAAYeqxekZy0wkm5beKJkgWHWKuA43rgCFlWJiR6P4C4BTXwLDB7wuAKI2cm2wyk+Q2N2W1qsZlbrItTBW5MENuAmCLsWY5IGgAiNLIQUOqb9gIExVC8ugoq0fAkAgZAIqFa2DY4PcFQJS4zCQlHG8mpRoLLTVuboor3GSdm7LITBKFGQCNheJMDggaAOIUtUAjNW7YqId6BQyJkAGgeLgGhg1+XwDEySwzSYUt0EjNl5uSjHArRGFGIjcBcI57zgBAA8ji/jOSwcWpzAoDkt29aKT0s5WLJMn30gyFGQAAACAvmd2zM0Y9M5PUPLkpaWaiMAMA0SjOAEBKmV6sFbhAI227SG+0wJHmdRemMJMSIQMAAABF0ihNba1UpMk6M0nmP/vEKMwAKDCKMwDgQKadYA0QNqTGCBxpXqPLzi+J7i8AAAC0ltTXoAUp0EjJGtukxmhuS9vIZpOZmDQAoNVxz5kYzE8GYCrVjS6lVLOUJbfzlGuSzFX2q+ecZRehxzZ4Fb0wI1GcARCPa2DY4PcFgI163X9GMstMkl1uavTMJOWfm0wLYVkVZiQa2gCk5+IauN3xawKAlnVCtWpUoAl1veLDRoTSrfFho3YRbBo29tFDqcOG90I/j9DhsgutcAFDSjVzWyJkAAAAoNiuejKiQBOXmW5SZIHGJDNJQ9f2NplJSlek8WeYrHNTvTKT5KiZTUo1AprCDICiYOVMDLrAANiqZyeYlE03mOSmIyxMkvCR5SiALAKG5CBk0P0FICdcA8MGvy8AbDXC1AHJPjNJxcpNRcpMUo6FGSYNAMgBK2cAoAFl2QkmZdMNJrnpCAtTlJnLWQUMqf7dXwAAAEBRZD51wGAFjeR+8oDkZvpAmCLkpsJnJiYNAGgg5Xq/AABoNiYXc5FvpMddTBq8gW90USvzmzB62dzksVEk+Z5sfnZ0fwEAAAB2Yq9xc85NNmr5oplyU9LvpyiZSWLSAIDioTgDABlIfVGXY9CQ7MOG1BxFmqwDhkT3FwAAABAkdVObiToWaGoaPTelyUyNVJgBgHqgOAMAdVKkTjApfdholMCR9vUWrTBDyAAAAECjynzqgFT3yQM1jZSZpHSvN/fMFINJAwCKinvOAEBGTOYoR95/RnJ2DxrJ/D40UrIbX0oj5w9neSNMWy5CkPOAIeVSmCFkAAAAoKnFZSbJ6b07Jfv7d3oVNTNJ6XOTbeHKWWGGSQMAGlSpWuVfoChr1qzRpEmT1Nvbq4kTJ9b75QBoQCY3uows0EjxYSMmaNSYho2apIHDrx6hw1VXWiYBQ6IwA6DQuAaGDX5fAKSVS2aSjHJTK2UmqeC5yUFhhtwEICsuroFZOQMAzcCgE0yy6waT0nWEefkv+LMIHlmMCKhbYSYGo8wAAADQTHKZOiA5nzwgpZ8+UBOUZ1znpqbLTBRmADQ4Vs7EoAsMgAtF6gST7LvBJHcdYaa8QSTv2cxJ5kgTMgA0E66BYYPfFwAumGQmidzk5S/e5JmbyEwAWp2La2CKMzEIGgBcKVqBRip+2MhbpgFDImQAaBhcA8MGvy8AXHGSmaSWKtDkLUlmkopXmJHITQDScXENXHb8mgAAKcReRJrc6NBiZJZVYWHYfisfS3xBXlRJvycKMwAAAIA7JteyRm+8O8xNaTJTM+WmNJkpz8KMKXITgCKgOAMAOTG9+GuEAo3UHEWaNN9DEQszAAAAQKNzVqAxYVGgadXc1EiZSaKhDUBjaa/3CwCAVmJyo0sjpje7lIyW69ve9NLLe6HeCMv30wYj5wFDYlk+AAAAYOmqJ2NGnJlkJmnomt1wxFnp1mSZSdqWQxohM0kFzE0UZgA0Ie45E4P5yQCykOssZSnz+9AEKVLocNWpVo/CjETIAJA/roFhg98XAFkwbWpzct9OySozSW5yU5Eyk+QmN1mvMMqxMCORmwC44+IamOJMDIIGgCw4CxpSocNGTT1Ch8vRAZkEDInCDIDC4hoYNvh9AZCVVijQ1JCZQlCYAVBQLq6BGWsGAHVgOt4sdqm+lMlyfSndqDM//0V/FsEjiznOieZK1yFkAAAAAM0o17HQktVoaCnbzCS5z01Z3fuGwgwAJMPKmRh0gQHIUiOsoJHcdoTZuH/K7nW5eWbSm30SMgA0C66BYYPfFwBZqktmkhomN9UKOA2Tm8hMAJqEi2vgsuPXBACwYHqRaHTRaXDxKmnoYtj0gnhY6dYUBYsUGipgEDIAAAAA5+qSmSTrzCTVJzftt/Kx3HNT4u/TYWYCgGZAcQYAGkSrho28ZB4wJAozAAAAQAKNVKCRmjc3pcpMjgsz5CYAzYDiDADUmc3FImHDrdr3kfkYM4l7zAAAAAApOC/QZDh5oKbZclMijjOTRGEGQPOgOAMABeD8otG2QJMybDRa4Ej9mm1/Zo4LM4QMAAAAtCKnBRopl8Y2qTEzk+QoN5miMAOgBVGcAYCCqGvQkFKFDan4hRpnr8/250RhBgAAAMhdZgUaB0WaomYmydFrzKCZTaIwA6D5UJwBgAIpRIEmZZFGKk7ocPo6kgQMxzeyJGQAAACg1TkfCy3l3tgmFSczOX8tGTSzSYyABtCcStUq7/REWbNmjSZNmqTe3l5NnDix3i8HQIu4qlQy2u+EnS1OutDyRRxlub+h6uHZnFfKMNhkFDAkur8AFBPXwLDB7wuAejDNTJJFbrLNTFImuSnLzCRllJuSFKwyKMyQmwDkxcU1cLvj1wQAyNFVT1oEjetlFzZqF9eOw0ZQEEgSPnLpMMswYEh0fwEAAABJnVCtGhdojHOTbWaSMslNYVnHNje1WmaiMAOg0bByJgZdYADqJZNOMClZN5iU2UqaQko6poCQAaBJcA0MG/y+AKinQuWmVspMEoUZAC3NxTUw95wBgILKZJaylPw+KI7uR1Noab5HQgYAAACQu0LlplbITFLy75PMBAAjUJwBgALLNGhQpNkmbVGGkAEAAAA0BBrbUqCZDQCcojgDAAWXWYFGSh42pOYIHGm/B8uAQcgAAAAA3LO9dqaxzcJNKmwzGwA0OoozANAAbAs0uYUNqfECR9pwIWUeMCjMAAAAAHYyLdBIbjJTI+amNCx/ZuQmAK2G4gwANIhChw2p+IHD1WsjYAAAAACFVPjMJBU7N7l6bQkaAMlNAFpRe71fAADA3AnVqq4qlYz3v+pJ6YSdLZ6gdgG90Opljea9mD8q5blcvAYXEgQxAgYAAACQr4bJTBK5aRi5CUCrojgDAE3OOmxIQxfULsKGFHyxn0XwyLLzjIABAAAANIzMCzSS2yKNNDrPZFWsySo35ZCZJHITgOZCcQYAGoxt0JAKEja8ogJBVAjJe+l/wrEFFGYAAACA+kpSoJHq3Njm1SiZSaIwAwAJUZwBgAaUW4FGyrZIE6QIs5dTzJKmMAMAAAAUQ1M0tgUpQmaScmtmk8hNAJpTud4vAACQTJKL06ueTHYhLMnNzS+LLsGNK2uS/GwJGAAAAEC2kuamRFLkiYaSMjfZIjcBaFYUZwCggSW9SCVs+KT8vggYAAAAQHHlWqCRyE0BkjYKkpsANDOKMwDQ4HIv0EjbLsobPXA4+B4IGAAAAEDx5T55QGqOzCTVpZlNIjcBaH7ccwYAmkCSWcpSipteeuU9X9kFBwGJgAEAAAA0ljS5icyUDLkJAMJRnAGAJlG7eK1L2JBGXrwXMXQ47FgjYAAAAACNqRCNbVIxM5NU96KMRG4C0DoozgBAk6lr2KgpSuhwPEKAgAEAAAA0vqSZSXLU2CY1bWaSyE0AYIriDAA0oUKEjRr/xX6WwSPDec4EDAAAAKB5pM1MksPclGdmCno+R1Ldn0fkJgCth+IMADSpQoUNr7AgYBNAcrypJgEDAAAAaE5pMpOUQWNbTYNlJolmNgBIguIMADSxwoaNIDmHhzhpizISIQMAAAAoOheZScopNxUsM0k0swFAGhRnAKDJNVTYKAgKMwAAAEDrqF27k5vMkZkAID2KMwDQAtIWaKTWCBsEDAAAAKB1kZviuchMErkJACSKMwDQMlx0g0nNGTYIGAAAAAAkNwUaqflyk6vMJJGbAKCG4gwAtBjXYUNqzMDhMlxIBAwAAACgWbjKTFLjF2koygBAdijOAEALchk2pMYKHK6LMhIhAwAAAGg2riYP1JCZyEwA4EdxBgBalOuwIRV3NU0W4UIiYAAAAADNLqvGNqk1MpNEbgKAMBRnAKDFuQ4bNfUOHVmGC4mAAQAAALSKLBrbpNGZJe/cRGYCgPqiOAMAyCxs1ARd9LsMHlmHCi8CBgAAANCasmpsq8m6WENuAoBioTgDANgq67DhlWcwcIWAAQAAALS2rBvbvMhMANDcKM4AAEbIM2w0CgIGAAAAAK88G9saAZkJAOxRnAEABKJIQ8AAAAAAEI7MNITcBADJlOv9AgAAxdaqF9qt+n0DAAAAsHNCtdqS+aFVv28AcIWVMwCAWK3UEUa4AAAAAJBEq+QmMhMAuEFxBgBgrFnDBuECAAAAgCvkJgCACYozAABr3ovyRg4chAsAAAAAWWmWIg25CQCyQXEGAJBKIwYOwgUAAACAvDRicxuZCQCyR3EGAOBE0QMH4QIAAABAvRW5uY3MBAD5ojgDAHCuKIUawgUAAACAIvJnlXrlJjITANQPxRkAQKaCLvazCB6ECgAAAACNKq9iDbkJAIqD4gwAIHcEAgAAAAAIR2YCgOZXrvcLAAAAAAAAAAAAaCUUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAAByRHEGAAAAAAAAAAAgRxRnAAAAAAAAAAAAckRxBgAAAAAAAAAAIEcUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAAByRHEGAAAAAAAAAAAgRxRnAAAAAAAAAAAAckRxBgAAAAAAAAAAIEcUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAAByRHEGAAAAAAAAAAAgRxRnAAAAAAAAAAAAckRxBgAAAAAAAAAAIEcUZwAAAAAAAAAAAHJEcQYAAAAAAAAAACBHFGcAAAAAAAAAAABy1F7vF1B01WpVkrRmzZo6vxIAAAAgH7Vr39q1MBCFzAQAAIBW4yIzUZyJsWLFCknS7Nmz6/xKAAAAgHytWLFCkyZNqvfLQMGRmQAAANCq0mQmijMxpkyZIkl6/vnnCaYZWLNmjWbPnq0XXnhBEydOrPfLaTr8fLPFzzdb/Hyzxc83W/x8s8fPOFu9vb2aM2fO1mthIAqZKXv8m5ctfr7Z4uebLX6+2eLnmy1+vtni55stF5mJ4kyMcnnotjyTJk3ilzhDEydO5OebIX6+2eLnmy1+vtni55stfr7Z42ecrdq1MBCFzJQf/s3LFj/fbPHzzRY/32zx880WP99s8fPNVprMRNoCAAAAAAAAAADIEcUZAAAAAAAAAACAHFGcidHV1aXPfvaz6urqqvdLaUr8fLPFzzdb/Hyzxc83W/x8s8XPN3v8jLPFzxc2+H3JHj/jbPHzzRY/32zx880WP99s8fPNFj/fbLn4+Zaq1WrV4WsCAAAAAAAAAABABFbOAAAAAAAAAAAA5IjiDAAAAAAAAAAAQI4ozgAAAAAAAAAAAOSI4gwAAAAAAAAAAECOKM6k9PTTT2v8+PEqlUr6wAc+UO+X0/B++ctf6pRTTtH8+fM1YcIEjR8/Xnvuuac++tGPavHixfV+eQ2tv79fl156qd71rndp99131/jx4zVhwgQdeOCBuuCCCzQwMFDvl9jw7r//fp133nlauHChpk2bplKppCOOOKLeL6vh3H333Tr++ONVqVQ0btw4HXTQQbr44ovr/bKawi9+8Qu9//3v12te8xp1dXWpVCrpJz/5Sb1fVtNYvHixvv71r+u4447TnDlz1NnZqRkzZui0007Tn//853q/vIa3adMmfexjH9Nhhx2mWbNmqbu7WzNmzNAhhxyiH//4x+rv76/3S2w6X/nKV1QqlVQqlXTXXXfV++WggZGZ3CIzZYfMlD0ykzvkpuyQm7JDZsoWmak+0uSm9oxeU0sYHBzU2WefXe+X0VQuuugiPfnkkzrooIM0c+ZMVatV3X///frGN76hn/zkJ7r99tu155571vtlNqRFixbp9NNP1/jx43X00UfrxBNPVG9vr6666ip98IMf1LXXXqsrr7xSpVKp3i+1Yf32t7/Vl770JXV2dmqXXXbR8uXL6/2SGs7NN9+shQsXqru7W2eeeaYmTJigSy+9VGeccYZeeOEFnXPOOfV+iQ3tU5/6lJ77/+3de0xX9R/H8ZeE4AZhFiyZbuAF0krDy7wsFbHy1k1T8zIRNJsz25xWTs2iXEQsl39UK5p5KeeFZbFS0CwlplBrQg5tbNjUVl7+IMy+pF9IPr8/miwH/ZbwPefzPYfn47/v5xzna2dnX85r73PO9+xZxcfHKzExUWfPnrUdyVfefvtt5efnq1+/fpo4caISEhJUW1uroqIiFRUVaceOHZo9e7btmJ4VCAT03nvvacSIEXr44YeVkJCg+vp6lZSUaNGiRdq1a5dKSkoUEcG9R6Fw4sQJ5eTkKCYmRg0NDbbjwMPoTKFHZ3IOncl5dKbQoDc5i97kHDqTs+hM7utwbzJotw0bNpjIyEizceNGI8ksWbLEdiTPu3LlSpvrmzZtMpLMzJkzXU7kH7/88ot59913TSAQuGE9EAiY4cOHG0mmsLDQUjp/OHHihDl27JhpbGw058+fN5JMenq67Vie0dTUZPr162eio6NNVVVVy/qlS5dMamqqiYqKMmfOnLEX0AcOHjzYcgzz8vKMJLNlyxa7oXxkz549prS0tNV6WVmZ6dq1q+nRo4e5evWqhWT+cO3aNRMMBlutNzU1mfHjxxtJZu/evRaS+U9jY6MZOnSoGTlypJk/f76RZCoqKmzHgkfRmUKPzuQcOpPz6EwdR29yHr3JOXQmZ9GZ3BWK3sSYrJ1qamq0bt06rVmzRmlpabbj+Ea3bt3aXJ81a5Yk6dSpU27G8ZVevXrpmWeeUUxMzA3rMTExWrlypSTpm2++sRHNN+655x4NHTpUXbt2tR3Fkw4dOqSffvpJ8+bNu+F7tXv37lq7dq0aGxu1bds2ewF94MEHH1RSUpLtGL71xBNPKD09vdX62LFjlZGRofr6elVXV1tI5g8RERGKiopqtR4ZGanp06dL4johVHJzc3Xy5Elt3rxZt9xyi+048DA6kzPoTM6hMzmPztRx9Cbn0ZucQ2dyFp3JXaHoTQxn2uHatWvKyspSSkqK1q1bZztOp7Bv3z5J0r333ms5iT9dvzCOjORNh7CntLRUkjRx4sRW2yZNmiSJMgzv4nvWOc3Nzdq/f78krhNCobKyUrm5ucrJydHdd99tOw48jM7kPjqTs/hbjnBBb4Jf8T3rHDpT6IWqN3G2t0NeXp4qKyv17bfftjmNRMcVFhbqxx9/1J9//qmTJ0/qwIED6tOnj9avX287mi9t3rxZUtsXd4BbamtrJUkpKSmttvXs2VOxsbEt+wBe8vPPP+urr75SYmKiBg0aZDuO5zU2Nur111+XMUZ1dXX6+uuvVVNTo4ULF+qBBx6wHc/TgsGgFixYoLS0NK1atcp2HHgcncl5dCZ30ZkQLuhN8CM6U2jRmZwVyt7EcOYmHT9+XOvXr9cLL7ygYcOG2Y7jW4WFhdqzZ0/L5+HDh2vXrl3q06ePxVT+9MEHH6ikpEQTJkzQ1KlTbcdBJ/b7779L+vtx/LbExcW17AN4RVNTkzIzMxUMBpWfn88rokKgsbFRr776asvnLl266Pnnn1deXp7FVP7w8ssvq7a2VseOHeNcRYfQmdxBZ3IPnQnhhN4Ev6EzhR6dyVmh7E2dcjjz3HPPKRgM/uf9ly9frpSUFDU2NiorK0v9+/dXTk6Ogwm9rb3H958++eQTSdKlS5dUVVWlF198UcOGDdOnn36qCRMmhDSv14Ti+F63d+9ePfvss0pKStL27dtDFdHTQnl8AXRuzc3Nys7OVllZmZ5++mllZmbajuQLsbGxMsaoublZ586d0xdffKG1a9eqoqJCxcXFiouLsx3RkyoqKrRhwwa98sorvOoAkuhMTqMzOYvO5Cw6E4BQoTM5g87knFD3pk45nCkoKFBDQ8N/3n/mzJlKSUlRXl6eqqurVV5erujoaAcTelt7j29bbrvtNmVkZGj//v266667tGDBAp0+fbpT/3hgqI5vcXGxZs6cqTvvvFOHDh1SYmJiKGN6VijPX9yc63d+/dtdXpcvX1aPHj3cjAS0W3NzsxYtWqQdO3Zo/vz5ev/9921H8p2IiAj17t1bS5cuVXx8vJ588knl5uYqPz/fdjTP+euvv5SVlaXBgwdr9erVtuMgTNCZnEVnchadyVl0JrvoTfALOpPz6Eyh5URv6pTDmUAg0K5/V1VVpebmZo0aNarN7QUFBSooKNDjjz+uoqKiDiT0tvYe3/8nLi5Oo0aNUlFRkU6dOqWBAweG/P/wilAc33379mnGjBmKj4/X4cOH1bdv3xAk8wcnzl/8N9cLW21tbatXoFy4cEGBQEAjRoywEQ24Kc3NzVq4cKE++ugjzZ07V1u3blVERITtWL52/f3/138gFzcnEAi0vJv+334bZPTo0ZKkzz77TNOmTXMrGiyiMzmLzuQsOpOz6Ex20ZvgB3Qm99GZOs6J3tQphzPt9dBDDyk+Pr7V+vnz51VcXKwBAwbo/vvv15AhQyyk879z585JUqe+AywUrpeM22+/XYcPH1b//v1tRwIkSenp6crLy9OXX36pOXPm3LDtwIEDLfsA4eyfJWP27Nn6+OOPeWeyC7hG6Jjo6Gg99dRTbW4rKytTbW2tHnvsMSUkJCg5OdndcPAcOpNdfB+GBp0J4YzeBK+jM9nBNULHOdKbDDrs8OHDRpJZsmSJ7SiedvnyZVNTU9Pmtg8//NBIMikpKS6n8pfi4mITHR1tevbs+a/HGqFx/vx5I8mkp6fbjuIZTU1Npm/fviY6OtpUVVW1rF+6dMmkpqaaqKgoc/r0aWv5/CYvL89IMlu2bLEdxTeuXbtmsrKyjCQza9Ys09TUZDuSr5w8edI0NDS0Wm9oaDCTJ082kkxubq6FZP52/ZyuqKiwHQUeR2cKDTqT8+hM7qEztQ+9yV30ptCiMzmLzmRPe3sTT84gbNTV1WngwIEaPny4BgwYoF69eqm+vl7ff/+9KisrFRcXp23bttmO6Vk1NTWaPn26gsGgxo8fr507d7baJzk5WdnZ2e6H84mamhq98cYbkqQrV660rP3zmG7dutVCMm+IjIzUpk2bNGnSJI0bN05z5szRrbfeqj179ujs2bPasGEDd2x30KZNm3TkyBFJUnV1dcva9ceax4wZo8WLF9uK53nr16/Xtm3bFBsbq9TUVL322mut9pk2bZrS0tLcD+cDhYWFeuuttzRmzBglJycrLi5Ov/76q0pKSlRXV6exY8dqxYoVtmMCgKPoTM6iMzmPztRx9Cbn0ZucQ2dyFp3JexjOIGwkJCTopZdeUmlpqQ4ePKi6ujpFRUUpOTlZK1as0MqVK9W7d2/bMT3rwoULCgaDkqRdu3a1uU96ejpFowMuXLjQqgxfvHjxhjWKxv+XkZGhI0eOKCcnR7t371ZTU5MGDRqk/Px8zZ4923Y8zzty5Eirc/To0aM6evRoy2dKRvudOXNG0t/voc3NzW1zn+TkZIpGOz3yyCM6d+6cysvLVVFRoUAgoO7du2vw4MGaM2eOFi1apMhILm0B+BudyVl0JufRmUKD3uQsepNz6EzOojN5TxdjjLEdAgAAAAAAAAAAoLOIsB0AAAAAAAAAAACgM2E4AwAAAAAAAAAA4CKGMwAAAAAAAAAAAC5iOAMAAAAAAAAAAOAihjMAAAAAAAAAAAAuYjgDAAAAAAAAAADgIoYzAAAAAAAAAAAALmI4AwAAAAAAAAAA4CKGMwAAAAAAAAAAAC5iOAMAAAAAAAAAAOAihjMAAAAAAAAAAAAuYjgDAAAAAAAAAADgIoYzAAAAAAAAAAAALmI4AwAIC8YYTZ06VV26dNHu3btbbZsyZUqb2wAAAACgs6A3AYB/dDHGGNshAACQpIsXL2rw4MEKBoM6fvy4kpKSJEkbN27UypUrlZ2drS1btlhOCQAAAAD20JsAwB8YzgAAwsr+/fs1depUjR49WmVlZaqurtbIkSOVlJSkyspKxcbG2o4IAAAAAFbRmwDA+3itGQAgrEyePFnLly9XeXm5Vq9erblz58oYo507d1IwAAAAAED0JgDwA56cAQCEnWAwqFGjRumHH36QJOXn52vVqlV2QwEAAABAGKE3AYC38eQMACDsREdHa8qUKZKkbt26afHixZYTAQAAAEB4oTcBgLcxnAEAhJ3vvvtOb775pu644w5dvXpVS5cutR0JAAAAAMIKvQkAvI3hDAAgrPzxxx+aN2+eIiMjVVpaqhkzZqiwsFCbN2+2HQ0AAAAAwgK9CQC8j9+cAQCElczMTG3fvl3vvPOOli1bpvr6et1333367bffVFlZqdTUVNsRAQAAAMAqehMAeB/DGQBA2Ni+fbsyMzP16KOP6vPPP29ZLysrU0ZGhoYMGaKKigp17drVYkoAAAAAsIfeBAD+wGvNAABh4fTp01q2bJkSExNbPYo/btw4rVmzRseOHdPatWstJQQAAAAAu+hNAOAfPDkDAAAAAAAAAADgIp6cAQAAAAAAAAAAcBHDGQAAAAAAAAAAABcxnAEAAAAAAAAAAHARwxkAAAAAAAAAAAAXMZwBAAAAAAAAAABwEcMZAAAAAAAAAAAAFzGcAQAAAAAAAAAAcBHDGQAAAAAAAAAAABcxnAEAAAAAAAAAAHARwxkAAAAAAAAAAAAXMZwBAAAAAAAAAABwEcMZAAAAAAAAAAAAF/0Pd7FYUZ+2Qy0AAAAASUVORK5CYII=", + "image/png": "iVBORw0KGgoAAAANSUhEUgAABmcAAANlCAYAAACJ1C0sAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3gUVRfG301200khgUAgJIHQlE6oAQIiQWpo0qUkKEpXUVGRrqBSFSxACGAQEOSjiaJUkSZVUekk9BpIAVI22fv9MTuzM7Mz21KB83uefXZ35k7Z2d2588459z0axhgDQRAEQRAEQRAEQRAEQRAEQRAEUSQ4FfcOEARBEARBEARBEARBEARBEARBPEtQcIYgCIIgCIIgCIIgCIIgCIIgCKIIoeAMQRAEQRAEQRAEQRAEQRAEQRBEEULBGYIgCIIgCIIgCIIgCIIgCIIgiCKEgjMEQRAEQRAEQRAEQRAEQRAEQRBFCAVnCIIgCIIgCIIgCIIgCIIgCIIgihAKzhAEQRAEQRAEQRAEQRAEQRAEQRQhFJwhCIIgCIIgCIIgCIIgCIIgCIIoQig4QxAEQRAEQRAEQRAEQRAEQRAEUYRQcIYgCIHQ0FBoNBrs2bOnuHcl3wwZMgQajQZTpkwp7l15amjdujU0Gg2WL19e3LvyxKPRaKDRaJCcnFzcu/LUkJycLBzXZx06FgRBEMSzBGkYwhKkYQoO0jD5Y/ny5dBoNGjdunVx70qRM2XKFGg0GgwZMqS4d6XYoWNByKHgDPHMwV+ciR/Ozs7w9fVF48aNMXXqVDx48KC4d5MogfBiydrj5MmTxb2rdjNlyhRMmTIFqampxb0rJZLBgwcL329iYmJx706xwwsLRx5POydPnsSUKVPoBgBBEARRoJCGIRyFNMyzC2mYZxOl/sKWx7MQLFi+fDmmTJnyRJ7viKcXbXHvAEEUF8HBwahUqRIAQK/X48qVKzhy5AiOHDmCxYsXY+/evQgPDy/mvSRKImXLlkXVqlVV53t5eRXh3hQMU6dOBcCJN19fX8U2lSpVQvXq1eHj41OEe1b8ZGRkYP369cL7ZcuWYeDAgcW4R8VPYGAgIiMjzabfuXMH58+fBwDF+YWJTqdD9erVi3SbSpw8eRJTp05FVFRUsQmcknIsCIIgiIKHNAzhKKRhSMM86xqmuPHx8UH16tWFc3hhUbt2beTm5ppNP3XqFNLT01XPBdWqVSu0fQoICED16tVRvnz5QtuGLSxfvhx79+5FaGgo6tWrVyz7UFKOBVFyoOAM8cwSGxtrNlz8t99+Q58+fXDjxg0MHz4cO3fuLJ6dI0o0HTp0eCaz4leuXFncu1AsrFmzBo8fP4avry9SU1OxZ88eJCUlISwsrLh3rdjo0KEDOnToYDZ9+fLlGDp0KADgjz/+KNJ9qlChAs6cOVOk2yyp0LEgCIJ4eiENQzgKaZhnC9IwJY/u3buje/fuhb6dL7/8UnF669atsXfv3mI5F4waNQqjRo0q0m2WVOhYEHLI1owgRLRr1w4zZswAAOzevRt37twp5j0iCKK4WbZsGQDuIqphw4ZgjCEhIaGY94ogCIIgCIKDNAxBEHJIwxAEQTwZUHCGIGQ0b94cAMAYQ1JSktl8xhhWr16Ndu3awd/fHy4uLqhYsSL69++P48ePq653165d6N69O4KCgqDT6eDj44MqVaqge/fuwoWTnMzMTMyfPx+RkZHw8/ODq6srwsLCMHz4cMV9A6QFD+/fv49x48YhNDQUrq6uqFChAl599VXcunXL6nH4559/0Lt3b5QrVw5ubm6oUaMGpk+fjqysLMX24uKAf/75J3r16oVy5crB2dlZkt2n1+vx9ddfo0WLFvDz84ObmxsqV66M1157DRcuXLC4TxcuXMDo0aNRs2ZNeHl5oVSpUqhRowbi4uLw+++/W/1MPOnp6XjxxReh0WjQsGFD3L592+Zl7cFa8UlLRbvz+z0yxrBx40bExMQgKCgIrq6uCAwMRNOmTTF9+nTcvHkTgKkYHU9YWJjEd1b83Vn7PKdPn0ZsbKywn35+foiKisLSpUuRl5enuIz4d3Pq1Cn06dMHgYGBcHV1RfXq1TFt2jTk5OSofs7C5vTp0zh06BAAYNCgQRg8eDAAYMWKFTAYDBaX/fHHH9GiRQt4eXnB19cXrVq1wqZNmywuc+/ePSxZsgTdunVDtWrV4OnpCU9PT9SqVQvvvvuu6s0WcXFJxhgWLlyIevXqwdPTE+XKlcMrr7yCq1evCu137tyJ9u3bw9/fH56enmjZsiX27t1rz6Gxmz179kCj0SA0NBQAsHr1akRFRaF06dKSIr4PHz5EYmIi+vXrh5o1a8LHxwfu7u6oWrUqRowYoXrus/R/4jl+/DgGDx6M0NBQuLm5Cd/L8uXLLX6fqampmDFjBpo0aSKct8LCwhATEyPJxgwNDRVGDu3du9fMx1lePDW//xm1c21hHYvjx49jwIABCAkJgaurK7y8vBAaGoqXXnoJc+bMAWNMdXsEQRBE4UIahoM0TP4hDUMaxl4Nw2Pv9aVcH3z//fdo1qwZvL29ERAQgG7duuG///4T2h87dgw9evRAYGAg3N3d0bBhQ2zYsEFxXxzVFIW5T2LNJubdd98VtqlUP+nWrVsIDAyERqPBZ599prju/MLXppoyZQrS0tLw3nvvoXr16nB3dxeOBQCcO3cOs2bNQps2bRASEiJ8z82bN8cXX3yh+rvn/7Nqts95eXlISEhA27ZtERAQABcXF1SoUAEDBgzAX3/9ZXHfT5w4gdjYWFSpUgXu7u7w9fVF7dq1MWbMGKG2DP+98pp36NChknOG/DtxpM8Uf78GgwFfffUVGjduDB8fH4kWLIxjYTAYEB8fL+hrnU6HgIAAPPfcc4iNjcXu3bstHkOimGEE8YwRFRXFALDJkycrzt+/fz8DwACwU6dOSebp9XrWq1cvYX7FihVZREQE8/HxYQCYs7Mz+/bbb83WuWTJEmEZX19fVrduXVanTh3m5+fHALAKFSqYLXP58mX23HPPMQDMycmJVapUidWtW5d5eHgwAKxUqVJs9+7dqp9v+vTpLDg4mGm1Wla7dm1WrVo15uTkxACwypUrs7S0NLNlQ0JCGAA2a9Ys5uHhwVxdXVmDBg1YeHi4sP/NmjVjDx8+NFuWnz979mym1WqZl5cXa9iwIatWrRqbMmUKY4yx9PR01rJlS6Ft5cqVWcOGDYXP5O7uzjZv3qz4vSQkJDAXFxcGQPhMdevWZd7e3gwAi4qKkrQfPHiw4vd848YNVq9ePQaAtWvXjmVkZChuTwl+nYMHD7apPf9dJCQkKM5PSkoSjoXaso58j48fP2YxMTHCuv39/VlERASrUqUK0+l0kn2Kj49nkZGRQtuIiAgWGRkpPOLj4236PGvXrhW+H09PT9awYUMWFhYmrLddu3bs8ePHZsvx87/99lvm5uYm/G6CgoKEeT179rTpeBcG48ePZwBY8+bNGWOM3bt3T/ic27dvV13uo48+Eva/TJkyLCIigvn7+zMAbP78+cK8pKQkyXJffvklA8BcXFxYpUqVWEREBKtatarwvQUFBbFLly6ZbS8hIUH4H/Tr148BYOHh4ax27drCsqGhoezevXts0aJFTKPRsMDAQNagQQPm5eUlbPOPP/7I1/Hi90PpN717924GgIWEhLBx48YxACwwMJA1atSIBQUFCeezLVu2CP/zChUqsIYNG7IaNWowd3d3BoD5+Piww4cPm63f0v+JMcY+++wzptFohPNn3bp1WcWKFYVlunXrxnJzc82WO3r0KCtfvrzQLjw8nEVERLCyZcuaba9Xr16satWqDADz9vaW/JciIyPZzZs3hbb5/c9YOtcWxrH4+eefhd+Sl5cXq1WrFqtfvz4rU6aMsJxer1fcHkEQBJF/SMOQhiENQxrGVopawzDm2PWlWB+8//77DACrVKkSq1evHnNzc2MAWOnSpdm5c+fYxo0bmaurK/Pz82MNGzYUzkMajYatXbvWbH8c1RSFuU9izSYmJyeHNW7cWPF3k5eXx1588UUGgEVHRzODwaD6/VmD/y8onQv488SoUaNYeHg402g0rGbNmqxBgwasZs2aQruePXsKeiA8PJw1atRIOAfzny07O9ts/ZMnT1bd9v379yXn2KCgIFa/fn1WqlQpBoDpdDq2evVqxc80Y8YM4Xfn5ubG6tWrx55//nnm6ekp2d7x48dZZGSkcO6tWrWq5JwxatQoYZ2O9pn899uqVSvWo0cPBoAFBwezRo0asYCAAOF/UxjHYuDAgZJlIiIiWLVq1YTjMGDAAMXjR5QMKDhDPHNYEzb8hYy3t7fZhdiUKVMYAObh4cE2bNggTM/KymJvvfWWcKI+dOiQMC83N1e4oPniiy/Mbl6dPn2aLViwQDItOzub1a1blwFgMTExLDk5WbKt9957jwFgAQEBLCUlRfHz6XQ6Fh0dzW7cuCHMO378OAsMDGQA2KRJk8w+O9+p6nQ61qlTJ8m69+3bxwICAhgANmLECLNl+Y7A2dmZvf322ywzM1OYxx/HIUOGCBd6+/btE+anpaUJN5S9vLzMLvZ27NghXMyPHDmS3b9/XzL/0KFDbNGiRZJpSsLmzJkzLDQ0VOiccnJyzD6HJYpD2DjyPfL76evry9atW8fy8vKEeY8fP2YrV66UHH/GTN+f0oW2tc9z+vRp4UJ12LBhEuH722+/CRcxI0eONFsnv12dTsfeffddye9m1apVwoXWrl27VPersNDr9cJxFl98de/enQFgffr0UVzut99+Ez7Xp59+Khx/vV7PJk+eLIhLpeN9+PBhtm3bNpaVlSWZfvfuXfbqq68yAKx9+/Zm2+QvBHU6HStXrhw7cOCAMO/ixYvC775r167M3d2dLVmyRLiwf/jwIYuOjmYAWGRkpEPHSr4fSr9pXug4OzszV1dXlpiYKOyDwWAQPvOZM2fY+vXrzW46pKenC4KxZs2aZsLE0v9pzZo1wn9ixYoVkv/En3/+Kdy8mTZtmmS5W7duCb+B1q1bs3PnzknmJycns48++kjxGMhFl5iC+M9YOtcWxrHg+6R3333XrG+8fPmy5LdOEARBFDykYUjDkIYhDWMLxaFhHL2+5PUBHxgVBznv3LnDGjRowACwtm3bMl9fXzZt2jThXKTX64XfTHBwsNl1qKOaojD3yZJOuHjxohA4+Oqrr4TpH3/8MQO4pLZbt26ZLWcPtgRnnJ2dWd26dSW6R9ynbNy4kR0+fNjsuJ0+fZo1bdqUAWAzZ840W7+lgESHDh0YANaiRQtJckFeXh6bN28ec3JyYm5ubuzs2bOS5fjj6eTkxKZNm8YePXokzDMYDOy3335jK1euVDwGauc3xhzrM8X74+zszPz8/CSBUL1eL/xOCvpYnDhxQuj/5ckPBoOB7d27VzFYSJQcKDhDPHMoCRu9Xs8uXLjAPvzwQ+bs7MwALttHzMOHD4XO8rPPPlNcNx/h7ty5szDt5s2bwoWKrfBZahEREaoX3126dBEunJQ+X0BAAHvw4IHZcnPmzGEAWP369c3m8cKmdOnSitlYq1atEi5C5RcG/IVadHS04v4mJSUJ4uSHH34wm6/X64UsJXHWAmNMuAAaNGiQ4rqVkAubgwcPCgJz/PjxDmWc8Ou09Jg3b57QviCEjb3f419//SWsc+fOnTZ/tvwIm9jYWAaA1apVS/G4Ll26VPjdiAWaeLtt27ZV3Cb/O3/zzTdt/iwFxcaNGxnAZeCIv4NNmzYxAMzV1dVMZDPG2AsvvMAALjtMiXbt2tl0vJWoUKEC02g0khEYjEmDIkr/r0WLFgnzlW5M/PPPP8J8pd+brdgSnAHAPv74Y4e3wWdJyi+G1f5Per1eOLeJL67FHD16lGk0Gubr6yvJ9OIvvqtXr66YNamELcGZgvjPqJ1rGSucY+Hq6soAsNTUVNXtEgRBEIUHaRjSMKRhzCENY05Ra5j8XF+K9YHS+emnn34S5nfs2NFs/r1794Rr1JMnT6odEkXUNEVh7pM1nbB69Wrhuzt16hQ7cOAA02q1TKPRWBzxZCu2BGdcXFzs1qg858+fZwBYjRo1zOapBST4oGClSpVUdejo0aMZAPbGG28I07KzswV3A6VgrxrWzm+O9pmMSbXwqlWrVPehoI8F/7tR++8SJR+qOUM8s0ydOlXwl9TpdAgPD8fHH38MPz8/fPbZZ5g4caKk/b59+5Ceng43Nze88cYbiuscP348AGDHjh2C12bZsmXh7u6OtLQ0bNu2zaZ9W7t2LQAgLi4OOp1OsU3Pnj0BcLUjlOjfvz98fX3Npjdr1gwALHojx8XFwcvLy2x6nz59UK5cOej1evz666+qyyrxyy+/wGAwoFKlSsK+i9FqtRg3bhwA4KeffhKmJycnC56eH374oeo+W2Lr1q1o27Yt7t+/j3nz5uHzzz+3WIvBGmXLlkVkZKTio0KFCg6vVwl7v0fe37ZZs2Z44YUXCnRf1OB/1+PGjVM8roMGDULZsmWh1+vx22+/Ka5j5MiRitNt+b0WFnzBzJiYGMl30KFDB5QpUwbZ2dn4/vvvJcs8evRI8LEdM2aM4nr537kaWVlZ+P777zF8+HC89NJLaNmyJVq0aIEWLVogIyMDjDHBO1eOn58fXn75ZbPpDRs2FF6/9tprZvOff/55uLm5AQAuXrxocf8KgmHDhlmcn5eXh02bNmH06NHo1KkTWrVqJRyD8+fPA+C8hW3h8OHDuHz5MsqVK4fu3bsrtmnYsCFCQkKQmpqKY8eOCdN//PFHAMBbb70Fd3d3m7ZnCwXxn1E711oiP8ciJCQEAMx+8wRBEETRQhqGNIyjkIaRQhrGRH40TH6uL8UoaRRrGsbf3x9hYWEAlI91fjVFYeyTJfr27Yu4uDhkZWWhd+/e6NevH3Jzc/HOO+8gOjrarnU5Stu2bSU1ZpS4c+cOvvjiCwwcOBDt2rUT9CpfQ+Xs2bPIzMy0aXt8v9GvXz/Fcwag3G8cOHAAN2/ehKurK95++22btmULjvaZYkqVKoXevXvbvW1HjwWv0w4dOoRLly7ZvV2i+NEW9w4QRHERHByMSpUqAeCKK164cAGZmZnw9fVFmzZtzNqfPXsWAFfwWemiHwBq164NgLu5mpycjGrVqsHJyQnjx4/H9OnT0alTJ9SuXRtt27ZFs2bN0KpVK5QrV85sPXyRr6+++gqJiYmK2+ILxYmLfIupVq2a4vTAwEAAQEZGhuJ8AKhVq5bidGdnZ9SoUQO3bt3C6dOnFds8//zzitP54/fcc8/ByUk5Lswfv6SkJOTk5MDFxQWnTp0CwF3kqH0mS2zevBkzZsyAs7MzVq9ejT59+ti9DjkdOnRQLShZ0Nj7PfLHiy8KW9ikpaUJRT3Vfjc6nQ41atTAnTt3cObMGcU2+fm9FgZ37twRBDZfQJNHp9NhwIABmD9/PpYtWyYRZRcuXBAKh6r9F9SmA1zxzk6dOqkWqORJSUlRnF6lShXF6WXLlhVeh4eHq7a5cuUKHj58aHHb+SUgIECyP3Ju3ryJTp06WQ2+qB0DOfz5NDMzEy1atLC6vqtXr6JZs2bIyMjA5cuXARTs/6mg/jOWfkdqOHosAOC9995DXFwcRowYgTlz5qBdu3Zo1qwZoqKiBEFAEARBFD6kYUjDOAppGBOkYQpOw+Tn+pInICAAPj4+ZsvYqmHOnDljpmHyqykKY59s4YsvvsCBAweEc1Xjxo0xY8YMu9fjKNY0xvr16zF06FCLn40xhvv379sU9OV/Pxs2bMAff/yh2CYrKwuAtN/gzxm1atWCt7e31e3YiqN9ppjq1atDq7X/drujx6Jp06aIiorC3r17Ua1aNbRs2RKtWrVCs2bN0KJFC9XPQZQcKDhDPLPExsZiypQpwvvU1FS8/fbbWLZsGaKjo3Hy5ElB+ACmCyslIcJTvnx5s/YAl+EWEhKCL7/8En/99RdOnTqF+fPnQ6PRoG3btpg9ezbq1q0rtH/w4AEAU4djicePHytO9/T0VJyuJirE8BeTluapXWiqbdeR4+fv74/09HQAUM0csMalS5eQl5cHX19f1KxZ06F1FCf2fo/5PV72Iv4d2PLd2vu74T8nY8zmfRo9erTiRfiHH36IDh062LSOlStXIjc3F+XKlVPMUho8eDDmz5+P48eP4++//0adOnUAmD6fk5MTypQpo7hutf+XwWBAz549kZSUhPr162Pq1Klo2LAhAgIC4OLiAgBo1aoV9u3bB71er7gOteMozga01saeY+0IatvnGTp0KE6cOIHKlSvj448/RvPmzREYGAhXV1cAXBbjd999p3oM5PDn07S0NOzfv99qe/6cyv+XgIL9PxX2f8YSjh4LgOsz/fz8MHv2bBw6dAjffPMNvvnmGwBAkyZNMGvWLLRu3drufSIIgiDsgzSMOqRhSg6kYZ4dDZOf60uewtAw+dUUxaWrPDw8EBkZKQRnhg4dqjoSsTCwpDGSk5MxcOBAZGdno3fv3hgzZgxq1KgBHx8faLVaGAwGODs7A4DdWu38+fPCaCY1xKNxCuuckZ8+k8cRnQY4fiw0Gg22bt2KWbNmYcWKFdizZw/27NkDAHB3d0e/fv3w6aefIiAgwKH9IgofsjUjCCO+vr5YsmQJmjdvjgcPHmDEiBGS+aVKlQIAIcNGiZs3b5q1B7iTZVxcHE6ePIk7d+7gf//7H8aNG4dy5cphx44deOGFF3D9+nWhPR/Z3rVrFxhXG0r1kZycXBAfX8Lt27etzhN/Pltw9PjxWRB8lp29jBkzBoMHD0ZKSgpeeOEFm62QCgJrF2WPHj0q8G3m93jZi/h3YMt3a+/vxhFOnTqF/fv3mz0s/a7l8HYAt27dglarFexD+Ef9+vWFtvHx8cJr/vMZDAbcvXtXcd1q+/Hnn3/i9OnTcHd3x6+//oouXbogKChICMwAto8WeVK5desWtm/fDoDLGO3bty8qVaokiCjA/mPAn09btWpl9XzKGBOG44szsAry/1Sc/xlHjwVP9+7dsX//fty/fx/btm3DhAkTUKVKFRw+fBjt27cXsr0IgiCIooM0jAnSMAUDaRgTpGFMqO1Hfq8vC4PC0BRFxaZNm7B06VIhwPf+++/jypUrxbxXHGvWrEF2djYaN26M1atXIzIyEv7+/sIoEUeOKf/7WbZsmU2/H57COmfkp8/ML44eC37ZGTNm4OrVqzh//jyWL1+OgQMHQqPRYNmyZYiJiRFGxxElDwrOEIQIJycnzJ8/HwDnGcxHmwGgRo0aALhsAbUhnHyWmJubm6pPZ5kyZdCtWzfMmzcPZ8+eRVhYGO7fv481a9YIbfhhkn///Xc+P5Fj/Pvvv4rT8/LyhGGe9mZw8cfvv//+g8FgUGzDH7/KlSsLN6T5bJ6UlBScO3fOrm0C3He6bNkyDBs2DCkpKWjbti2OHDli93ocgc+YULuQdeTzWIM/XgcOHCjwdSvh4+MjZJX8888/im1yc3MFK4CiyPzbs2dPvoTAoUOH8N9//wHgMsTUHn5+fgCAVatWCV6z4eHhQrYQvw45av8v3sqsZs2ailktDx48KJTfTEmCPwalS5dWHFKfm5uLo0eP2rVO/nz677//qp57lChVqpRwHrfn/2TNC744/zOOHgs5Pj4+6NChA2bOnIkzZ86gadOmyMnJwdKlSwtqVwmCIAg7IA3DQRqmYCANw0EaRora/6ugri8LksLQFEXBtWvXEBsbCwCYN28eunfvjtTUVPTv379E3Fjnj2uLFi0UR8EdOnTI7nU62m/w54x//vnHLvtAa1qtoPpMRyioPjQ8PByDBw/Gd999h0OHDkGj0eDAgQOqdWuJ4oeCMwQho1GjRujcuTMAYPLkycL0Fi1awNvbG1lZWfj6668Vl50zZw4AoF27dpJsdzVKlSoldCrirDO+eNiiRYtUh/wXJkuXLlXMiPrhhx9w8+ZN6HQ6tGvXzq51vvTSS3BycsKVK1eEItticnNzsWDBAgBAp06dhOkhISGIiIgAAMycOdOubfI4OTlh8eLFGDFiBB48eIB27do5dOFgL1WrVgUAHDx4UHG+2u8oP/Ts2RMajQYHDx6UCHNreHh4AFC3mLAE/33Nnz9fMcPuu+++w507dxz63RQHy5YtA8Bd8N26dUv1ce7cOeh0OqSkpGDTpk0AODHbqlUrAMCXX36puH7+dy6H/w5u376teBznzZuH3NzcfH++kgx/DNLT0xV/iytXrsSdO3fsWmeLFi0QFBSElJQUSYagLfTq1QsAd+x5f19r2PJfKq7/TH6OhRparRZNmjQBIO3HCIIgiKKFNAxpmIKCNAwHaRgpahqmMK4v80thaIrCJi8vDwMGDMD9+/fRpUsXjBkzBkuXLkVwcDD2798vsbMsLvjjKh45wsMYw+zZs+1eJ99vrFy50q5RYs2bN0dQUBCys7Mxd+5cm5ezds4ojD7TVhw9FpaoXbu2UDuJtFrJhYIzBKEAL2h+//137Nq1CwB3wfLWW28BAKZMmYKNGzcK7XNycvDuu+/i999/h7OzMz788ENh3n///Ye4uDj88ccfZpkkv/32G3bu3AmAE1Q8r776KmrXro3z588jOjpaMXL+77//4qOPPsKWLVsK5kOLyMjIQP/+/QXPS4DLYho3bhwAIC4uzqIHpxIhISEYNGgQAGDUqFGSAmcZGRkYOnQoLl26BC8vL+E483z22WdwcnLC8uXLMXbsWLOhq3/++Se++uori9vXaDRYtGgRxo0bh7S0NERHR9vkiZsfunbtCgDYsmWLJKswKysLH3zwgV3Cw1Zq1aolZFf17NkT//vf/yRiIysrC4mJiWYF5vhihjt27LB7m+PHj4ebmxv++ecfDB8+XCKKd+3ahbfffhsAMHz4cLt/N0XN48ePsXbtWgCcv68lAgIC0KVLFwAmMQQAEyZMAMAV8pszZ47wv8/Ly8P06dOxe/duxfU1a9YMOp0O169fx6RJk4TsKIPBgEWLFuGTTz6Bm5tb/j5gCef5559HQEAAcnNzMWrUKElAZP369Rg9erTdx8DFxQWff/45AM7Le/78+RKPXgB4+PAhfvzxRwwbNkwy/Z133kFgYCDOnDmDTp064cKFC5L5ly9fltwAA0z/pX///Vd1OHxx/WccPRbp6el4+eWXsX37diHDkufYsWPCf0bcjxEEQRBFD2kY0jAFAWkY0jD2aJj8XGsXFoWhKQqb6dOn4/fff0dQUJBgT1e6dGmsWrUKzs7O+OSTTwrlv2cPUVFRAIB169bhp59+EqZnZGRg2LBh+PPPP+1eZ+fOnREdHY379++jTZs2Zv9xgKvB9dlnn0lG6et0OiHwPW3aNHzyySeS3x1jDDt37kRiYqJkXfw5Y/fu3YojvRztMwsCR49FYmIiJk2aJIz049Hr9fj888+RmpoKZ2dnia0hUcJgBPGMERUVxQCwyZMnW2zXqVMnBoC1bNlSmKbX61nPnj0ZAAaABQcHs0aNGjEfHx8GgDk5ObFvv/1Wsp4TJ04I7T08PFidOnVYo0aNWFBQkDA9JiaG5eXlSZa7cuUKq1+/vmRbTZo0YfXq1WO+vr7C9ISEBMXPJ5/Ok5SUJCwrJyQkhAFgs2bNYh4eHszNzY01bNiQVa1aVVimSZMmLD093WxZfn5SUpLqMU1PT2ctWrQQ2oaHh7OIiAjm4eHBADB3d3e2efNmxWWXLVvGdDodA8B0Oh2rU6cOq1u3rnDso6KiJO0HDx6s+j2/++67DADz8vJie/bsUd1fOfw6Bw8ebPMyPXr0ED5vhQoVWEREBCtVqhRzc3NjixcvVv0u8vM9Pn78mHXt2lWY7+/vzxo1asTCw8OFYyhf7+zZs4X2NWrUYK1atWJRUVGSdpb2ae3atczFxUU4rhEREaxy5crCOtu1a8ceP35stpy1301CQoLi91tYrFixggFgLi4u7O7du1bbb926VfjvX7t2TZj+/vvvC5+tbNmyrFGjRiwgIIABYPPnz1f93B999JEwr0yZMiwiIoKVKVOGAWDDhg1T/Q6sHSdLvxce/v+/e/duq59bDX4/lLaze/duBoCFhIRYXEd8fLywDh8fH9awYUNWoUIFBoC1b9+eDRw4UPG/be0zzp8/n2m1WgaAubm5sbp167ImTZqwKlWqMCcnJ9V9O3LkCCtXrpyw7qpVq7KIiAgWGBiouD2DwcBq164tnPMjIiJYVFQUi4qKYjdv3hTaFdZ/pjCOxYMHD4T1ubi4sOeee441btyYhYaGSvqGR48eqe4TQRAEkT9Iw5CGIQ1DGkaN4tYwjDl2rW2LPrB2rNW+X0c1RWHuk9rvYu/evczZ2Zk5OTmxXbt2ma1v8uTJwn/x3r17qvtlDX6/lM4Fls49PHl5eax169bC5w8LC2MNGzZkHh4ezMnJia1cuVL12PCfQWnbDx48YC+++KLZb69BgwaCFlbbt+nTpzONRiOci+vXr89q1arFPD09Fbd3+PBh4fdYoUIFFhkZyaKiotjYsWOFNo70mYzZ/r8v6GMxb948yXmzQYMGrH79+pI+d/bs2Rb3iSheKDhDPHPYKmyOHDkinMh27NghTDcYDGzVqlXshRdeYH5+fkyn07GgoCDWr18/dvToUbP1PHr0iMXHx7P+/fuzGjVqMD8/P6bVallAQAB78cUX2cqVK81EDU92djaLj49n0dHRrEyZMkyr1TJPT09Wo0YNFhsbyzZt2sQyMzMVP19+hM3u3bvZqVOnWK9evVjZsmWZi4sLq1atGpsyZYrixSljtgkbxhjLyclhixYtYs2bN2fe3t7MxcWFhYSEsGHDhrFz585ZXPbMmTNs+PDhrEqVKszNzY15e3uzmjVrsmHDhrF9+/ZJ2lq7uJg4caIgNsXfryUcETbZ2dls+vTprFq1aszFxYUFBASwHj16sL/++svid5Gf75Ex7nf6ww8/sA4dOrCyZcsynU7HAgMDWdOmTdmMGTMkN4kZ4y60Zs+ezerWrSsITfnxs7ZP//77LxsyZAirVKkSc3FxYT4+Pqxly5ZsyZIlLDc3V3GZkiZs+IvNnj172tQ+NzdXuEkxY8YMybwffviBNW/enHl4eDBvb2/WsmVLtnHjRsaY5c+9dOlSVq9ePebq6sq8vb1Z06ZN2dKlSxlj9l/o8zxJwRnGGPvf//7HmjVrxtzd3ZmnpyerW7cu+/zzz5ler1f9b9vyGU+fPs1GjhzJatasyTw9PZlWq2WBgYGsdevW7NNPP1U9B6WkpLDJkyez+vXrMy8vL+bm5sbCwsJYt27dWGJioln7q1evssGDB7Pg4GDhZoLS910Y/5nCOBa5ubls1apVLC4ujtWqVYv5+/szZ2dn5ufnx1q2bMm+/PJLlp2drbotgiAIIv+QhiENwxhpGNIwypQEDcOY/dfahRkIYcwxTVHUwZmUlBRWsWJFBoB98MEHiuvLzc1lLVu2ZABY165dVffLGvkNzjDGBVEnTJjAwsLCmE6nY2XKlGGdOnUSAsaOBGcY4/7L69atYzExMax8+fJMp9Mxd3d3VqVKFdavXz+2evVqlpaWprjsn3/+yQYOHCj8h/38/Fjt2rXZ2LFj2V9//WXWfvPmzax169bM19dXCNTI/6v29pmMFUxwxpFjceXKFTZ79mzWqVMnFhYWxjw9PZmLiwsLDg5mffr0Yb///rvF/SGKHw1jCuaaBEEQBEEQTxCXLl1ClSpV4Ozs/NTX5iEIgiAIgiAIgnhSmDRpEqZPn464uDiJJRdBEFRzhiAIgiCIpwDeX7506dLFvCcEQRAEQRAEQRAED2k1glCHgjMEQRAEQTzxrFu3DgDQoEGDYt4TgiAIgiAIgiAIAgDS0tLw66+/AiCtRhBKaIt7BwiCIAiCIBxl9OjR+Omnn5CUlAQAGDlyZDHvEUEQBEEQBEEQxLPNrVu3EBMTg7NnzyItLQ2VKlVC165di3u3CKLEQSNnCIIgCIJ4Yjl16hRu3LiB+vXrY/Xq1ejSpUtx7xJBEARBEARBEMQzTVZWFv788084OTmhR48e2LVrFzw8PIp7twiixKFhjLHi3gmCIAiCIAiCIAiCIAiCIAiCIIhnBRo5QxAEQRAEQRAEQRAEQRAEQRAEUYRQzRkrGAwG3LhxA6VKlYJGoynu3SEIgiAIgiCIQocxhoyMDAQFBcHJifK5CMuQZiIIgiAIgiCeNQpCM1Fwxgo3btxAcHBwce8GQRAEQRAEQRQ5V69eRcWKFYt7N4gSDmkmgiAIgiAI4lklP5qJgjNWKFWqFADuIHt7exfz3hQfP/v4CK87VDa+eNH43ApgkdzLv0tXAwD8i+dxAVUBAJcQhiRwC11DBdxLqgjc0ABXjMvfAnAHQIrokQbgEYDb/FYzAWQAeABAb3zOBZAF4LHxmW+Xa3zNPwOmn7oWgM74rAXgbnxoAZQC4Gac72d8784t5gbAB4Cn8dnf+IDxuSyAcsb3lYzPQQwBYdcAABVxHQAQhksAgMpIAgCE4zwA4Hn8CwCoc/8cAECzX7Trv4te7+Cefr4EMzqkpZlPJAiCIEok4n4VEPWtgKR/5ZH3swDX1wKQ9LcAkITKuIYKAMD1uYDlfheQ9r1pMHWruAOub80C1w/zrzNh6nP5Z/6hNy7L98N8Hyzuf/k+2A2Ah/G5lOjZHVznClMfHART/+tvmo1y4PpeY78r7nPF/a24rxX6Wb6PVehfn/V+NT09HcHBwcK1MEFYgjQThzXNBEjP5+LzuPwcbnb+tqSZFM/bGTCdu/lztfh8rYe5bhLfHuDP0/w528P47GacXkr22t10vgbMdROvnXjdJNJMAMzO34B1zQTA/HwOCOd0wFw3PevndoIgiCcJmzQTYNbHAtL7k4CyZgJg3ucC0n4XsKHvteWepVw3AdK+GDDdswS4TpV/z/e3vH7yg9m9S7X7luJ7lkbNBJj6XXGfK+5vze5P/g7STAoUhGai4IwV+GH53t7ez7TQ8BC99nY2vnARzTwBsCjAyzjTHTq4whUAoIM7nOEJAHBCKaCUN+CpAWoCSAZ3AnE1rk8HwBlcNSQncOedTMB0ctICSDc2yjA20hkXlt8MgmiaeHkYV6yTPTsB8ILp5hEAGL9zjWifnAGkGpsFGPfdzTitPGD8qEApBidv7s/Jf36dMdjDHxt34355gTtu3vyudwA0e42v+eMM47YB9KkKbDkPCc/y75MgCOJJYotGI+lXAVHf2l40UdSIGU/xfD/7N2rDHcA5VIcrgAuoAh2Aiwg3dqPGi8NS3sA1Ddc3uQO4CfN+9y6kfa/gSJQJ7mJfC4ABMBgffCONsQ0gvdHnLJoGmAdn+GdeYHgYnz3B9bvuAAIhJEhojLvhDFOXz/e95QGEGldv7HfFfa4rXFENZwHoJH2thj+2fB/rbHbIqV81QhZVhC2QZuLowxi2GI/F3ktAl6oAdoM7txtPMOLzOa8FXOEq6ARneMIJpVC2bhruXArmdBN/L4Y/d5cHdz7k5ZDkb6oDl+HG6yYncOdqjXEef2NIC+nNIX5ZHvn52gPSc7crTMF047k7y7gZX5h0k/i8DZjuL90Fd/4uxd0kcvIuhRuogWBctaqZLnnXRL37pwGIzucvAdhlfM13Q4B5f/sM/z4JgiCeNFQ1EyC9VybrY0+Wrgkv4yxxXwtA0t9eRTBXjL2UcUFPDXefUtzv3oPy/UqzvjcT3D1FPYAc43ve5orXTe4waSa9cYW2aCb+viX/4O9dloLQB8vvW8p1Ey+3SjGUrXwVAKebdHBHOC4CcIU7dKiDUwCchfuTGg9w/asLSDNZID+aiQykiQKFv0gGYLwZwlEFF6QNKzI71+yuMl0eX9QpTBOTK3uvlz1nyuZnmp5SwUXHUxVWe1thGsAJKjs4WbqmXe0JgiCIJ58uVS3PZ1GOrVe1D1Lqs1Jh6uPkXaEwQR6I0cvey/tYS+hEz3xWtgLu4G7yWUN2XWF23WFEfJ1CEARRpOwyn8TdAOHgboxYIRBcchiPL7gbLb7yhu6yZ0AaeFFDrzJdaVQkoNBhqHPP9qYXEQ6ASzwAuEQEgEtMsAlRooO8j91CAWeCIIgnApvP1y+oz1LrN/h+xmas9mHpotfivtFejaQGH6zhgzRQeBZRBtz1QgC4aweiREPBGcJu5CM27IWL0DoCH42Vj4KRv1ZCLjTEAkM+T02UqCA/SSebN7kKy0Eai0LDQkcjhoQGQRDEU4TCud/WAL5in5Mse2/zTTL5jTdrfaY1SzNxMEZ+o9CCwChMtnNP4uubLszeJBKCIIiCIxgKeqm8lYUsnjp1stdKCW62wJ/bxX2DlQDNXZj3OTdFr6/lT8M4msBAEARBPFlIgu3tzedb6g/4ID8f9FdE3B/dVJh/F9Jk7UyzF5BqIXFQpiACNGL4fls2akUxYQOK1xCK1xoyBEcfolCh4AxhE4o3Kbart3c4CywAXITXFyZ7MAmOZoGpoXaytDELzE6hoZYFpoRix2IhC4wgCIIo+RREIP3pyQADTBlgPI4HZuTJH+LrD/F1iYBCFjtBEERho3SjQ+w4oEioI1uS6yaloLg9CW7yYLy4X0iHmX4qDscBG5PaCIIgCILHLrcBQL1/M3MbAKy7DVjTVDa6DQDW3QZCzSeJ3QbE1yLkNlC0UHCGyD923NywOwvMF8WcBZau0NbIXRtXTRAEQRAqlLgMMAG1DDAx+Q3QiPtseWcve+8Lm4bn25IBRhAEURJQu/FRsHbQSo4D8nlqKNXxtPTaCG8HrURhOw6IUehTechxgCAI4inhiXcb4Mmn24Ct+W12X1OYQ24DBQ8FZ4gCw6HhbqGObCm/WWBiSzOgUOrOJJteFmUWGAkNgiAIAshHBphqvRkepXoz4mdbULshqDA83xcqI2nBJXeEWt6SUjY6Dc8nCKIwEd+oEG5gWHAcsISqHbTcccAMteK81upzAsp20PJ5DtadEVPYjgMiyHGAIAjiycLivS0LwXcxfDBfntCm6DZgKaHtWXcb4HHwWoawjrUrM4Kwm3r3TwsBhmo4K5wIq+CC9CRYkdnpMewO00lPB+6kp4W5gNApTFNC7ocvRg+7Tnb3IC3OqcBVBCMYV3ER4ZKMuHOojmo4i79R2/KJkKc9hJNil6r5rwFEEATBo9frkZeXV9y78dSyu0YNaENCJNOygkRvfIzPTUyTWA73fNq3CpDFvXZBaQCAF7wAAKXhAgB4CA1uoRwqIhfuecbGGg1wHYALAA9w9+seAPAzbsDJOD3TOC+L33I2gIcADACY8b0BpoQIfro7uD7TA1y/6mJ8djWuh8/4cjYu52x872Zs42pcBqL1ZwPQcE28jZsoBaC0cb+9jZtz4ZqhHAPyAP+sXARCg0q4DMAFXvASjhVDedRMvYgshEDDd/v88TZ+B9oc03HPysrC04azszN0uvxYwRIEkW92wSzpqg5OCTeQwnHR8mhI3nFAHmjng9iSOIm7cQL/rIP0JpGSjrJErmgdvBbj9RLfiajAOw5Y0UsEQTx7kP4glJBrpjahIpniI27IPbFGAIzX8iyL6ywtaSYAYMY+0T0vi9NM3EKczgC4bo3/ad43TteA0yO8fEEWuL40z7QDMMCkeVyM792MbbTGZ3mCm1b0LNZNvF7SghNu7uA0WJ5xuxpuH3jdVBYmzeQHwB+mz6UBkMfgn5WLcrgFQIPSRs3EHy9mvNDIynkMAJxuOgzumJNmKjQoOEM4xJbz+ctAKlv5qnpWr/ii/ZF8pjeUs7P4gIwlgaGHcrYunwWshXkWmA0BGnlg5iZMwumaRnXY4AVUsVqPh0VRhi9BEIVPeno67t27h+zs7OLelaeast98I3nvoQWS+Dfi7sZN9DqFe3J+wPVfOdAhHEAudAgFoIcWkcb3AJBr7DgNuclcl+YFoAq46/dQ47PB+IDoPRM9CzOY6FHKuDKm8IB4QdFrcQKGRvSs9HAybjMVXCKGxjSZf3Y2PvOvec1j1AROSbnQwgNaYyKIDqHQoil3/KBHssF4jcAHxDqA6+pbA49zOS0DAB4hIUhKEr6ZpwpXV1cEBATA29vCTVSCIEoEwbgqtVwJhaINmAQ+DiOZwCedZRin8QEWXjfxGkgtwY3XSfxrsZ7KFL3nk+hEHRqv5Xxlq7wNkz1lMoRRkHcuBauPFlLgZOma5vZwL8Am2+0tGg3ZsRBEMUL6g7CEXDclKbkhK2imHCcdnB9wr8ON/VOo8Vlv7MvEusmQqwWQzMmcXHC6KRSmoAyvm1Q1E4NUXLnAFJCxppnE723RTU6i14/BJbQZ+3ZnWNdMAJBl0kwAoEW4UVUCWjSFs/E6INmg5zRWEDjNBACtuaeyotumpJkKBgrOEDbThTHzoYXboTqk0K4sMMB0gS4eMlgkWWA8allggGqQ5i44OwErFKXQIAiCsIf09HRcv34dXl5eCAgIgE6ng4YsEguF9EfSjANvF9EbL9FrY9/HRHZeWVqucZZRheQYR6ZkG9O2cozPeqPYyMvRAXrj95gDU7eYB1OyVi5MOoIXHUz+RqxE5OJDPB0wFxs8coEhjrjIoy8aCDf6NJCKC16A8SJDB0776BicXbgPqIMeLsiBqzFzzQXZcDNGb9xyuWmaRzDdMHzIPaWLMsC8w8JUPseTC2MMer0eaWlpuH79OgBQgIYgihHNXnMrLrHjQP4ROw6IkeskSwEZXmOJAzNqr2VaiZdQvgqrdsBxgE9qI8cBgng6IP1BWCL9n3/gHyDtKATdpKCZAJNukmsmQFk3STQTwOkmXg/wA2H4AI1YN0k0E4wT+JlircQUnu3VTGLtxD8AkxgyCiW5ZhIbF8g0EwA4u+gFzQQArsiBC7ggqRuypJoJeKZ0U3FpJgrOEAWDwhB9NcyywABupIlSoWIeh7LAIJqenyww/r27+aAdX9l7lSwwHl5oqOGI0JBDWWAEQdjDvXv34OXlhYoVK5IoKmTkg77dxIeb745KmSYxo6bI1LoJJmF5cAYAGGSZX07GFhrokJvtwo1+12i4hCq+OxPfY5PrBUkGGK86+AUMCg9+QUAqNMT9j0bhWS4wxGldTsaddTI1F4sKPutLXGrOBYALg8bVCS7QA3CFMxi0xn3TIReuxvW5GfM4NHoAfJKmcdfE342bmzgN7+nB3d0dpUqVwrVr13Dv3j0KzhBECUBsBy0m/3bQgMlxQMkO2hFLM3lgpugcB2yBHAcI4smB9AdhCblm8nUVvRHfxTYGbJhYP2m5635eM2XDzTjQ3hXOxvdcahifDOYC5Gi4dTGYbg3yP8tccBqElz+8nBG6K72oAR+1UdJMSsEZ+agZawlt4mQ2PvriZB7HEY+W4XWTUTMBgMbVCU5wgrNx+1oYoDP26a5wkmom7iAKu5WabXJ0842IwNNIcWgmJ+tNCMJ2lC6IlYrxSgi1ZwtqF/3yOKO1YpdKhS0tvbYBq0XCTFCBS4IgSgJ6vR7Z2dnw8fEhYVTIpB49qj7TR32WmMdWbnzlKFp3yrC5JqVB9FrpZpmauJAj/10pZYSp4GxxB01lbVRnm3sgazIUGj5DaDQa+Pj4IDs7G3q9I6OMCYKwl/yO1lAdfR9gfJQBlzDmqdyMQ9x/5MdHXe6RL85cU7KeFnHX8mweVetrFZSCW7YmDVosOE0QRKFA+oMoKeRmu1hvBJhG0ojtzQAo6yU12zJHkY+iAUy38u2/pa91zZG8d4PJVtDDUj+eZvemnmiKWjNRcIYoFMwsuYxUwQXphHxkR5mEhTgIY20wmPxPpRSkkWeB2YA8MCMeBWR3lhtBEETRwBffpCLhRY+vlcCCOANMTrZxmH6WMERfYaRHjqjvsdT1KSIuyioeHSMfGWMJpaCMWkBGQWDYeoXqIt0fpYAMALjnPn3FKh2F/79T8V2CKDwUR7Hzo94VbIrFI+fFNSkVR9yXh2mkvhxfyHLZ5EF9ney1UoKbEpYS2MR6KR1m+ukRuHJicm6LXiebzzZzWpDB22dbRWTBTUltBFH8kP4gLGExoU2MglbK1OZj9Lu49JHNyWw8am4C1pLYHIHXUbIsNn7EDN+18w8FXOxJRH+GE9uKUjNRcIZwmOLLAhMPKSuqLDAlz2YjlAVGEMQTDmWtPYPwGWBm2JIBVhDIgzSOX5LKM8BsxpgBlioSY0/r8Hwx9H8niCeYUEcWcpc9i70hIZqmhoOOA5lQDsoAxeY4QBBEyYGuRwhbkCS02ek2oJbQVnhuA4D6yBlrrgOF5Dagg1W3ASWedbcBoGjPURScIezCYhaYAmpZYIoUahaYGrZmgYne80KDssAIgiAIGyjKDDDJ8HxbMsDMhueLZ4hhsH/0jBoaldeA2aWpPAPMBqwOzyexQRBECcCaHXT+HQeUktiUTqS2jJYByHGAIAiCKGlYchuwCWtuAxYHmSi5DQBPmtuA+DW5DRQPFJwhCg6FIfpqmA3RD3Vkg/ZkgYmnK4mJQq47oyA0KAuMIAji2eTJygBTqzdjqailGmo33SwMz5ej1NXDzuH5BEEQJQg1O+iCQa2IrT3JbHKewLoz7c1n85DjAEEQRMmg2BLaJDMUpjlcb6YgKDi3gfwidhsgChYKzhAFjrUsMDOKPAtMThFkgTkIWZsRBEE8exR6BphFCrLejHw4vvi1BaHBeyZbwsrwfKXaMzQ8nyCI4qJQ7KADoWwH7StvKE9os8fSTIwDdWcKyHGAT2rj4ZPabHYcEEGOAwRBEE8G+UloU0MxoS1fQQdrbgP5CdjY4TYAWHUbkFtBW3Ub4Ekzn/QsWEEXJRScIQoNtSwwsyH6DlEcWWAqdWdSYZ4FdluhHWzPAiNrM4IgiCeP5ORkaDQaDBkyRJhWIjPA9EDXdqHo2j6Ue19kGWBKSQP5uBS1MDxfDA3PJwiiOLBoB63gOKBmB23mOABwdtCWMLs3JQ/QAJxuUkpmU5pWvI4DPLzjgCXIcYAgiGcJJf3xrMAntA2NnQJPTW1cTr4uzFNzG7AJUZe3eNEUNKqjwbEjexQaWrMys7XejBwH3AbkiW0F4TZAiW1FBgVniHxRKFlgQAnNAhO9F2eBWSPZfJI8C4yszQiCIIqf2NhYaDQa+Pv7Izu7cMZtl7gMsDylieYZYFlZj5CYuAgTJw5Hr16RaNy4PBo1Ko8bN1T6cVWURs4o+CfbiDwDDAAunbuEuN6jUT2gEfzdI1CvQX98/c16MPGNUmMGmHh4vjwD7Nq1axg+fDgqVaoEFxcXBAUFYejQobh61d7PTBAEUYCEFsRK5DpJLcGtGOrOOAg5DhAE8SRSWPqjdevW0Gg0Fh979uwR2k+ZMkWYPn78eNX1vvfee0K7KVOmqLb7/fffhXbr1q1Tbbd8+XKh3eD33lNt9/V366EJaQRNSCMMGam+XbsopHoz33zzCUaM6IlOneqjRYtQvPjicxg0qD2+/34xsrIeK6yreNwGsrOzMW/aPDSu2hal3RqiStALeO31j3H30n0rK5ViMBiwcOFCNGjQAB4eHvD29karVq2wefNm1WUOHz6MmJgYBAQEwNXVFVWrVsWkSZOQmal8HfHgwQOMHz8e4eHhcHV1RZkyZdCrVy/8+++/du1rSYWCM4TdWMwCU0AtC0yR8uCG6KvhUBZYMdadKeYClyQ0CIIgbCMjIwM//PADNBoN7t+/j40bNxbLfliyNCuoDDDLKNebuX//HhYsmIzt2zcgJycL3t6+sjb5zQBTwBnS4fk25Fxc+e8UOjfujF827UB0h0i8MaY/8gx5GDn6U4yZMNvmDLCLFy+iYcOGWLx4MWrWrImxY8eicePGWLFiBSIiInDxopXrGYIgCBuwZgdt5jhQ4HbQJaDuDDkOEATxjFIU+uPtt9/G5MmTFR+hoaFm7bVaLRITE5Gbay4ecnNzsXLlSmi11vuO+Ph4AIBGo8GyZcusttc6O+OXffuQkppqPtMHiF+7GVqtciTio0/fxvHTmxBUoazV7QB21psxiJ5tdBtYty4ejx8/RNOmUejb91VER3fjAiHzpiA2tiuysvh+0h59VLBuAwaDAbExsZgzeQ78A/wwctxANG9aC/HLNqFZ+1jcvffAplUzxtC7d2+MHj0a6enpiIuLQ9++fXH27FnExMRg4cKFZsts2LABLVq0wPbt29G+fXuMGjUK/v7+mD59Otq1a2cWpExJSUGTJk0wZ84clC1bFqNGjUK7du2wZcsWNG7cGIcPH3b82JQQKDhDFCwKQ/TVMBuiH1oQO6CUBWYLhZQFpkBhFrgkoUEQBOEYa9euxaNHj/Dmm2/CyclJEBT5IT+WZg5T4Blg3BB9X18/LFy4Djt2nMHmzUdRs2ZdKzuilAEmfi8XHc7Sl2pXqFaG57//xvtIT0vHio3fIP67mZjx6Vs49mciWraoj4VLfsDBY39b2W+OsWPH4s6dO1iwYAG2b9+Ozz//HBs3bsTatWtx584djBw50qb1EARB2IKaHbQYVccBQOo4YIbYDlopWGMPhVB3Rkyy+SRyHCAI4mmlMPSHnPHjx2PKlCmKD6XgTIcOHXD79m1s3brVbN62bdtw69YtdOzY0eI209PTsX79etSpUwcvvvgifv31V6sjz19s3hw5ej1+2LYNgNRt4O/T53Hs1Gl0bBOpuGz58mVQvUZl6HVcfydPYOPf2+Q2YHNCG6BWb2bbtlNYvvwXfPTRfIwa9SHeffcTrF27Bx069MT58/9h8+a1Nqy7cN0GNqxYg73b96JHvy7YdmAdps96E+t/+AyLvnwPl5KvY+LHX0vap6oM6vrxxx/x448/IjIyEqdOncKXX36JxYsX499//0VISAjGjx+P5ORkoX1mZiZef/11aDQa7N+/H6tWrcKcOXNw8OBBjBw5Evv378e8efMk25g8eTLOnz+Pt956CwcOHMCcOXPw/fffY8+ePcjOzkZsbCwMBvl38WRBwRmiUHh6s8AcrDuTbL4ILzTUcKTAJUEQBOEY8fHx0Gq1ePfdd9GmTRvs3LkTly9fVmybl5eHTz/9FOHh4XBzc0N4eDhmzpypelG47+hRjJo2DU169YRXk1bwatIKEQMGYfH3GxTbO+kaIerFN3Dj+m0M6f8uqgc0QmipuhjUaRAuX+L26cLpcxjdrS+ala6EyFIBeLvHAKTcVkg/VhIYeRCSuzLSUvHJJ8PRvn05REaWwoABTbB9Oy8YTH2zh4cXmjRpDR8fP8V9NmEpA0ytkKWT+SRLqAzPTzp3EYd/P4wWbZrixQ6mO3MuLjpMf284AGDJ6o1WV5+VlYXt27cjMDAQo0ePlsx7+eWXUa9ePWzfvh2XLl2yYWcJgiA47LWDtuo4AHCOAwGyaRbtoHl0ste2aKeSUXfGbsjajCCIEkph6g9H6dGjB3x9fRVHuyxbtgx+fn7o3r27xXWsXr0ajx8/xqBBgzBo0CAYDAYsX77c4jKN69RBtdBQrFIICi37YTOcnZ0xuGcnYZrYbeC1IR/CU1MbV5KvcfMYwysdX0FFTUX8tPZ/knUxxjCma0c0cnPCrz8aNY8lB08RmzbFo2/f2oiM9EDHjqGYO/ddPHqUAXm9GVdXN9E008iatm07AwCuXUuyeCyUKVi3gXVLVgIAJs4cD0+NqT7n8Nd6oHJoBaxa/wsys7IEK2gxYivoTZs2AQA++OADuLubrjUCAgLw5ptvIjs7GwkJCcL0AwcO4O7du+jWrRsaNmxo+nQaDWbMmAEA+OabbyR21Js2bYKTkxOmTp0q2Y9mzZqhS5cu+O+//7B3r8JN6CcICs4QhYotWWAWKXFZYKL3lurO2DCChoeywAiCKIkwxvA4J7dEP5iSzaYD/Pfffzh06BCio6MRGBgoiAjxhaSY1157DRMmTIDBYMDIkSPRvn17zJ07F2PHjlVsv2DFChw4cQKNaj2HUf1exsDOHXDvQSqGvz8Tb083ZQaJRcaDB+l4scVgJCXdRJ/B3dGsdTPs2rYL/dv1x7l/TqN38054/PARYmIH47mIBtj1vx8xcVB/bmEbM8D0+hyMHPkijh/fi44dB6JLlyG4ffsaJk4cirVrv+H3CubiIj9Y8E22B4Xh+Yf37AcAtI5uIUx3z+XERoum9eDp6Y69h49bXXVKSgpyc3MREhICjcLNurCwMADA7t27Hdt3giCeGSzaQSs4DojtoMWYOQ4AnB20JazaQStVC7ZkB81TdHVnCsRxQAw5DhBEiYb0R8HpD0dxc3NDv3798PPPP+O2KPHr9u3b+Omnn9CvXz+4uVm2Vo6Pj4ezszMGDBiAHj16wMvLCwkJCVaPXf8uXfDv+fM4edp0HzFHr8eqjb+gfaumCAorY7ZMptZ8XzQaDeYmzIV/2TKYOHw8bly+IsxbNf9LHPz1F3QeMATRPfuo74ys61u1ai5mzx6D556LQN++oxEQUA6rVy/E6NFdkJvLN7asl/bv3wkAqFKlhnhvFZ7V6s0ABeE2kJ2VhROHT6BK9SoIDqkgaaPRaNCudRM8epSJo39bv59769YtACZ9JIaftmvXLpva+/r6ws/PD5cvX5Ykwd26dQsBAQHw8vKyaRtPIo4OLyCIAqNs5avmF9583RnxRbun8Vlyve8um6CDSRzwZyO9bL6SwNCK2qq9tlyQWZWbMImnaxoHRglxQsMs0PUCbLKR26LRKAtDgiAIC2Tq8/DcJAsFxUoA/01rDw+X/F/K8BYCr7zyCgAuY2zEiBFISEjApEmT4ORkuurds2cPli1bhrp162L//v3w9OQ6pw8++AD16tVTXP+cCRMQUqGCZHh+rmcuOg4ZhwUJazF2TD9UqlhOssw/f5/DqDdfweS5kwFww/E/GPEBVn69En1bdsGIKR/glbEjkAMdGGMY3bEn9v+yDWdOHkeNmg1sygC7d+8mgoOrIj7+AHQ6LQA9hg59GwMHNscXX0xEmzadUbZsOdlSSqLD1j6GExfHju3DsWP7IRUdzqYm/MMJgBYoXzEUXfoOMVubfHh+8nnuIr5y1VCzts7OzgirFIT/ziYhNyUXWq1WdXi+n58fnJ2dcfnyZTDGzAI0SUlcttu5c+ds+MwEQRCFQCgUR+Yr4w7T6H8dOO3EnfOlKOkkS+TCpL349fJ6KRPSRDoRqcZn8T222zDpv2SY2V1fRbBigOocqqMazuJv1DYLbLEoZTcHgiBKLqQ/Ck5/8MyePVvxprabmxsmTJiguExcXBy+/vprrFy5Eu+88w4AYOXKlcjNzUVcXJzFa+BTp07hyJEjaN++PcqVKyd8tpUrV2LXrl1o27at6rJ9O3bEjK++wrqfNqN1PS7YvunwXty7n4q4Pl0tfk45ZQLL4LMVX2JYx354p38slv6+CxdO/YUvJkxEpfCqeOfzL6UJbVbcBg4d3I4VK46gatVaAPLAWB4++mgwtm//AWvWfI2BA0eZLb5y5UJkZWUiIyMNf/11BKdP/4WmTaPQqVMvKI+EkQZjMjJSsXr1N5C6DThxbeSayfjo9+o4lPL35ZqruA1cvpgMg8GAKlUrme9BBlC1Mnd/9nzSVbSsXl95JUYCArjhu0lJSahZU5ogoaSZxO3lpKWl4cGDB8IyVapUEZa5c+cOHj58aPZbflp0GQVnCIfowpgw9HvLeWPm0XZIspHE1MEpwaYrHBeFUSLBuGpu71UeiplTAvJ4DNxhCp5kwJQFJj67amF+ttVDGk5WCtLIs8BsCNDcg7m9gIw7l4It+0YbURIairSHkIXXpar9tgkEQRDPKnq9Ht999x28vb3RrVs3AICXlxe6d++OxMRE7NixA9HR0UL7lSu5IeCTJk0ShBEAVKhQAWPHjsVHH30EQFpvJqSCNCMJ4Iptvj6gB37bdxi79x3F4H6dJfO9vDwweYbUUiumXwxWfr0Svv5+GDjmDWG6RqNB9Mt9sP+XbTj3z19ccEbxw5pPGjnyE+h0LsLMwMCK6Nv3DXzzzXT8+ut6DBw4EkqFLu1DKj6OHfsDS5bMsnnpBk2j0GXgEKvtHqdxF/PePqXgoZC57e3hCYPBgIzHj+HnLb1pKB6e7+HhgVatWmH37t346quvJPVlNmzYgJMnTwIAUpUKlhIEQdiJZq/56PhqOCuMpq+CC8JIewBcopfdtl/eUB7Rwgdk7AnKiBPY5GTCpK/4gJC7dNOecJiLCEcVXMAFVLHN9s1OKKmNIIiioLD0hxJz5sxRnO7j46ManGnYsCHq1KmDhIQEITiTkJCAunXrokGDBhZvhPNBp0GDBgnTBg0ahJUrVyI+Pt5icCYwIADtIiOx+udfMWf8OLi5umLZ2s0o4++HLi+2wrGLxoRlC3ey+foyWXBF1EttMXDsCHw3fxG+mDAR+7ZuA2MMM1Z+Dw8vL7vqzXTsOAhVq9YBX2dGo2EYOXIqduz4EVu3fi/STKZEtpUrv0Ja2n1hHR069MSECTOh1Vpz++GiLhkZ6XZpJgDo3GeIKTgDKLoNPEzj+mdvH5NtA+82AADepbjfWFrGQ6vb69ChA9asWYNZs2bhhRdeEEZVpaSkYP78+QCkmikyMhLe3t7YuHEjTpw4gfr1TcGfSZMmCa/Fy3To0AEJCQmYOnUqPv/8c2H64cOHhdpIT7ouo+AMUfDsgs3evmaE4gnIAgMUgzSpxmcHssDkQoPPAlOCssAIgigK3HXO+G+aSsS9hOCuc7beyAqbNm3C3bt3ERcXJxmiP2jQICQmJiI+Pl4ijv766y8AQMuWLc3WpTQNADIePcLS1YnYuHsvLl69hkeZ0htkN27Ji5YBVaqGAB6lJdPKli8LAKhR53loNBqhqGVutgv8y3FDNO9dvWFaQC0DzGhN7eysRe3azSAtZslQr15zAMDZszYkBygiH5Yvfq3Ba699gNde+xDcaBktJGPydcaXcu9kfp4sA8zF3voGNjJv3jy0aNECo0aNwpYtW1CnTh1cuHABmzZtQp06dfD3339LMhoJgiDyS737p63acik6DgDKiWGp8gl8hhv/rJTMpnZO5ZPalJLZdMiX44A4sa0oHAcoqY0gSiykPwpOf/DcvHlTGMFiD7GxsRg3bhwOHjwIADh9+jQWLFhgcZns7GwkJiaiVKlSkro0bdq0QXBwMP73v//hwYMH8PMz1bB8LBtBMbBrV/z8++/43849aNWwPn7ddxhjh/aFrrQWjsTk35o1FUf27MPK2ZyV9OiPZ6FmLVNSlq31ZurXFx9nTjuVL18JgYEVcenSaej1OUYnAoAP0OzY8S8A4N692zh69A98+eXHGDKkM778chUCA82T9+QEBYXgyJEMmMSRE4QECPEkJd0kQu42YC9qbgMA0L9/fyxfvhy7d+9G7dq18dJLL0Gv12Pjxo0IDORuhoo1k5eXF+bOnYthw4ahWbNm6NWrF8qVK4cDBw7g2LFjqFGjBs6cOSNZZtq0afjll18we/ZsHDx4EE2bNsXNmzexfv16PPfcc0+FLnuy954o0SgFEMQBhyq4IJ3pwMW36nB5mwtbiim+ujMOQwUuCYIoJDQaDTxctCX6oVQPxF6UsrsAoG3btqhQoQI2bdqE+/dNGU9paWlwcnIShmSL4S9AxeTo9ejy+uuY9u1SODs54ZXOHfHhqFhMHvcqBvfiilpmZ+sl9WYAwNvblBXHZ4DlaT0AAO7ePmbb0Wq5Ps/keWzEQgaYr2+A7EKWExr+/lyWwcOHabCv3oza96E0XeES1NarUoUMMAAoZcz+Sk/LkG7d+Db94SNoNBqU8vCwuom6deviyJEj6N27N44fP44FCxbg7Nmz+PbbbwX7ibJly9q4wwRBEBz2BgKsjg4pD1MimBxf2FB3BuB0kziLV6nujJxCqjujgK11Z3iXBoIgnlxIfxSM/igIBg4cCBcXFyxbtgzLli2Di4sLBgwYYHGZjRs3IiUlBb169ZIUh3dycsKAAQOQlZWF77//3uI6er8QiUD/0li2cTOWb98Kg8GA2D5dFNsq1ZuR4+LqipYd2gEAXN3cEDN0mNVllPIUSpcOBKeV8kRTDShdugwYY3j0KANqeikgoCxeeqkHPvtsKZKSzmH+/GkqG7ampWSBQaXAjA23Qf19uOOWnpah6DaQnsHd4PTRmtvhid0GAE6D/vzzz5gyZQqcnJywePFibNiwATExMVi/fj0Ac80UFxeHbdu2oVmzZti0aRO++uor6HQ67Ny5E+Hh4WbLVKxYEUeOHEFcXBySkpLwxRdf4NChQ5g2bRo++OADxW08adDIGaLQKd4sMKW6M0oUcN0ZubVZYWWBiaEsMIIgCLu4evUqfv31VwBAVFSUarvExESMGTMGAGcBYDAYcO/ePZQpIy1KyRfNzLlnutu0be9e/HXmDOK6x2Dp1IncRGNsZc2vv2LF+p8k67BFZCiSK7qYtzEDLDX1HgwGA7j4jGn0TErKHQCAl5c8ASK/9Wa4x7Fjf+DYsT9E01TqzRgf5UNC0aXfEKtbCKvKFYS8dD4ZQJRkeH5eXh6Srt5AWHCQEMiyRo0aNbB27Vqz6UOGcPsSIRMnBEEQSojtoAV4O2gFxwGxHbQYMzvoUNjhOKCGfLQMOQ6QtRlBEIVJYemPgsbf3x8xMTHCtXC3bt3g7+9vcRk+6JSQkICEhATVNmLLYDlarRaDunTCnJWr8O+lS2hc73nUqh6u2h4Aco337HIUiqz8ffgIEj5fAF9/f6SmpGDWmBGYuWytbfVmDBCkzv378uPMjNPvQKPRwNPTPJAh5/nn68Hb2xfHjx80TrHsNpCRkYbVq7+GVBwZ2znBTDPBCeg3fBxKBfiabVvsNlCpciU4OTnh0vnLivt5/ixXgqFqJduSI1xdXTF58mRMnjxZMn3Pnj0AlDVThw4d0KFDB7Ppr7zyCpycnNCggdSmu0KFCli6dKlZ+ylTpqhu40mCgjNEsSCuO6NKILiLdDm+xmfFhCy+7gygXndGLDoKoe7MXUiFhgp83Rm1Apc8VOCSIAiicFi+fDkMBgNatGiB6tWrm83Pzc3FihUrEB8fL4ijunXr4vjx49i3bx969Oghab9v3z6zdSRduwYAiGnTipsgGvSy79CJfH+G3GwX2xoq3GvLy8vFqVP7UbduU9FUA06ePAAAqF69Ngq63gwAHDu2D0uWzLR5DQ2aRpkFZ5SG5zeLagYA2PPrH3h/wmDJvD8OncSjx5mIasJd6Fsanm+JjIwMbNmyBf7+/mjXrp1jKyEIgihyrNlBW7I0U8LWujP8eyt1Z2yo2WkPVq3NCIIgiomi0B8FRWxsLNatWye8tsTly5exc+dOBAYGonPnzoptdu3ahRMnTgi1RsQ1OiXb7d4Vny//Djfv3MPkca/avd9ZxiDN/Qw93u0fB2etFkv2/IqvJ32CHet/QLM27dG1t+zzWHAbAIATJ/ahU6eBxndcUtvNm1dw+/Z1VK5cAzqdDpJojgKPHz/Cw4fpCAhQG+0k1U0ZGWl2aSYA6NxvCBecMY9TCbi7u6NB4zo4eugkrly+gUohQdzWMwDGGH7bdxieHu6IeN5ykr01Vq1aBQDo27evTe3379+P5ORkdOzYET4+5k4RcvLy8rBmzRpotVr07NkzX/ta3FBwhnAYcRbYlvPcaA0hC0wBm7PAAG6UyU179sYdUkFRjFlgvrJpVrLAeOzJAnMEygIjCIIwwRhDQkICNBoNVqxYgcqVKyu2O3fuHA4ePIijR48iIiICr7zyChISEjBt2jS0b99eKMp5/fp1RR/m4PLcsMk/TvyFLq1bCdP3HjqGJSs3cvuicPGcZ/T44i3N5ORYtJuB5Qww0Wj8RYsmYtGibUaPZIbbt69jzZqv4eLiiujoHgorsYblDDBTzZmPYPJNtlJvhi9pIEOcAeaGbFSpXgXNWjXCH7sPYfvP+9CtXSMAQE6OHh9N/xYAMKxLjGQdKamp0FesiNx79yRWEZmZmdDpdJJRNtnZ2YiLi8P9+/exYMECiUc4QRBEftDs5ZKvxFTDWZwDd+OuCi7gIkSZwxUZNxpfiTJQcBoAODtosY7hX4t1ki1BmgKuO1MIjgNWIccBgiCKgaLSHwVFdHQ0Nm7cCABWk5ISEhJgMBgwfPhwTJ06VbHN4sWLMXz4cMTHx2PhwoVm832NmqhGWCh+XrEAWdk5eLFFY6CUWVOb3Aamj3gLVy8lYcLC+Qiv9Tw+/Gox/j3yJ2a/OwZ1G7ZASKVq0gUsBGi2bVuJPn1GompVLmDBGMOiRVOQl5eHzp37C+2uXUtCqVLe8PHxla46Nwdz506GwWBA8+aW6hOYNFNQUCiOHHkEu+rNiO/yq1hBA8Arr/XF0UMnMen9+VizYipgvLf77fINuHTlOl7r3x3uIq2jz81F0rVr8PfzQ5Uq0kT79PR0eHtLHRfWr1+PZcuWoVGjRmYBRaX2N27cwLBhw6DVajF9+nTJPL1ej9zcXIlVnsFgwPjx43H27Fm8+eabCAoKwpMMBWeIwkFhiL7NhMKOIfqFmQUmH1VTgFlgVOCSIAiiWNm1axeSkpIQFRWlKowAYOjQoTh48CDi4+MRERGBNm3aYOjQoUhISEDt2rXRvXt3ZGdnY+3atWjatCm2bt0qWf6lli0RGhSEzxJW4p8LF1Hr+So4e+kytu78A907tcb6zTtt2t8sS+lPaljJAAsIKI+srEfo168BWrbsiMzMR9ix40ekpd3H+PGfomzZ8hBngM2fPwWpqZz/9cWLZwAACxZMg7u7JwANunXrh3r1msq2YqnejJP5JEvYcAg++2oqOkf2Qd9uY9Hn5XYo7x+An379A/+euYRRg3ujeb26kvZLfvgBny5ZgsmTJwvD4gHg2LFj6NGjB9q1a4fg4GCkp6fjp59+wpUrV/Dqq69i9OjRNuwwQRCEfThsBy1Owr0HabKYxG3AXTZBnIyWC8u6qeQ6DvBJbeQ4QBBESaao9IeY2bNnw8tL2XLrpZdeQtOm8mt3E05OToiJiVGdz2MwGISgE2//q0SfPn0wbtw4rFq1CrNnz7a4zpdaN7e6XQB4rNLPbExchy2Ja9C6Swf0Gfk6crNd4O3ngmlLEvFGpzb4aHh/LNt8EFqNLANMD2lZGSNNm0YjNjYS0dEvw9c3AEeO7Mbp08dRu3Yj9OnzKni9dOLEIcya9S7q1m2MChUqwcfHDykpd/Dnn/tw585NhIVVxYgRExT22IFaRvLAjAJKbgMA0HdwD2xc+xPWrf4ZVy9dQ6tW9XHxzDVs2LobYcFBmDH+DQAmt4Gbd+6gycsvIyQkBMnJyZJ1NWnSBMHBwahZsybc3Nzw559/Ys+ePahcuTLWrVsHZ2dprZwvvvgCiYmJaNGiBcqWLYurV69i06ZNePz4MeLj480szW7fvo3nn38e0dHRCAsLQ05ODrZv344zZ86gU6dOmDnTvtFFJREKzhCFSoFmgaliSxaYJXihUYR1Z4zwQsMaSkKDIAiCcBzeE9mSgAA4ETF27FisXr0ac+fOhbu7O5YsWYJq1aphyZIlWLhwISpWrIi33noLvXv3NhNHFf08sGvpV3hn7hf4/cQJ7Dl6DM9Xq4xV305DYFl/SXDG4XozOaJ+Uy4mLARodDoXLFy4HQsXvodt277Hw4dpCAmphnfe+Rzt2/eAuA4NwLBr11bcvHlNso5du7YJrxs2bKYQnAGkI2d4bInGKGAhA8wVWajxfDXsObwK0yZ+iZ9+3o9HjzJRrUolLJr+Lt54pZcpn8MKlSpVQuvWrbFv3z7cvn0bHh4eaNCgAebOnfvED5snCKJ4ERwHbMSqHbQ1xwF5PEZwHODtoJUcBwBzO2g1+HY8Re84YDM2WpuR4wBBEIVBUekPMXPmzFGd5+vrazE4Yys7duzAlStXEBUVhbCwMNV2Pj4+6NGjB1atWoXvZs/Gyy+9ZNd2mIf1NllwxdWky5gycgLKlC+HScuWSOY3iGyFIWPfx7J5H2PRzA8w9oPPlVfEuw0Yu4IBA95Eq1adsXr1F7h27SK8vf3Qt+8beP31D6DTuYDXTXXrNkb37gNx4sRhnDv3Dx4+TIeHhxdCQ8PRt28cXn55CNzcPGCL24BpGiDU6ORfKkkpG90GAC7w9uOm+ZgzKx5rvtuM+QtWo7SfN+IGdsWMsW+gjNZP+bgo0KdPH2zYsAGHDh2CXq9HWFgYJk6ciHfeecdshAwANG/eHHv37sWWLVvw4MED+Pv7o2PHjnjvvfdQv359s/Y+Pj6IiYnB/v37sXXrVuh0OtSqVQtLlixBbGwsnJwc1JUlCA1jdNVhifT0dPj4+CAtLU3xR/WsIy5uKYgM3tbMOHKGD86cLF1TsDU7h+qCyLiIcMHW7M6lYC44kwxOZPA1Z+4ZH3fBXcg/gkiEZIK723IfnHi4D04kZBrfi5/5u1RykcGfwfizmRaciNCJnkuJXpcGFxQyCg13cOLCEyaRUQZccIYP0ATCFJwJhTByhg/O8FlgVXABAAShwVub8cEZ8cgZIQtMLDK2m14qjZwhoUEQBE9WVhaSkpIQFhZGFk0FhNw72Vc82kNsnSsans+Mr/ngjDgDjLc140fO8O9zoDPVm+GDM9nguje+q8uFKRGaD9rwQsMgfyNWIAbRg4me5fVnxP2JXEDwosJJ5WHn8HwduJEzLkzIAHOBXgjOuCFbeO2BTLjncq81fBk6/jmNexLXnPF9wgtI2out/3u6BibsgX4v1rGmm8SaCeCSs/iEtguoIiS0XUWwSTMBUt2kpJlSIQrO3Da+4fVTpugh1kvizkQpcMMnrsn1kjtMOorXSuLXMNdNvGaC8VmumQCgIrNbMwEm3SQZOWOjbiLNRBCFA+kPAjDXTICKbrKgmQCTbrKkmQBjnU6xZgJMXZ1YO/G6SaKZIJqQC04Dqekm8QMKz5Y0k1w7ORunaSEIJVs1EyDoJj4444osITjjiix4GC8OJLqJNJNAUWqmJz+8RDyxWM1yKg/pEH05ZglY7rJn+cAwnWiaFa9+ANIzNCBNOeMFjQxeABUiilYHYgs5Uc0fe7LyCIIgiPyhVtTSEkzBQ5nHmsiQNbaOmciQ7InsuSBQyggTT5cOcbfqm6yCm00f3kSqfc0JgiCKBLuttxyqxaJ208AWbSRGPCxTD1MARy+bbiP3ZO8VRgKZWbmpoFTjlCAIgijZqCa0FQBCQpsYeRdlxQ6aQy6ilPphpjDd1v666NwGzLacYTaJKEIoOEPkC3FGkZBttF25LQBVay4z7+BQe/ZCHKWRZ3PJX9uCJVGRqfw+E+pBGbHYEAsNu+3bCIIgiCcJW0bN8BSIpZkYmwQGj5LQkGd9WUItA0wtKAOYXYLK4jQSrAzPF+Nh1k/DLANMzLOWAUYQxJMFPxoEMI0UsUoAuNEovjCvhwnAPKFNfJIVJ7MpYelull5hmsI5Wc5d2fvbotfJ5s15xwV+JBHvxsCPMlJCbrNtC+KRTgRBEETBUVAJbfJRM3KsJrTJ9ZJKvRlThpv4vbBnUE9yy0/Sm7wPKpzb9/yoGaJ4oeAMUXjY4OerSLFkgSmJCbUgTT6ywBSwlgXGCw2lLDASGgRBEE8XakUt5RROBpha/2trkMYWFC49bbkadbXeRA5lgBEE8SQitjBWQ7FmZSCk9S55fGGD4wBgrpsccRzgIccBgiAIIp9YcBhQQu42YBNKeqnQ3Qbk9+QKx22At4J2FHIbKDooOEMUOkpD9EteFpgcW7LArGSCFVAWmM28oDyZhAZBEEThU1iWZnIKPwNMzScZsveODM/nBUbRDM8nCIIoqSjVhbSE2A7azHEAMNVocRglO2hbsJQVQI4DBEEQhHXssTQrELcBhxLaAKle4p8dcRsQv7fmNiAjH24DYitoRbcBHnIbKHIoOEMUGU92FphS3RnAVEjTSCqUs8DURtA4KDQUs8DEtLc8myAIgihcfNVGejy1GWBy1Po3Kx7KVnIpbMkAo+H5BEGURBQLzPN20AqOA2p20GaE2rMXBWkHXch1ZxQgxwGCIIgnm4K2NLOGotuAEha7KyURVdD1ZuQ4yZ5RaG4DAuQ6UGxQcIYoVsRZYIqUiCwwa3VnLE82o4ALXJLQIAiCKOEURwaYzRRWvRnxa3k2GCBJ+3KGaXg+j8rwfDlusG+8PQ3PJwiiJKPkOCDGzHEg33bQYscBsU56MuvO2AxZmxEEQZRsbExo4xPY5AltdrsNACqJbPIJhVVvpvjcBsgKuvih4AyRb8RZYMIQ/e3KbQH1LDCzIfqh9uyFpSwwWy3Niq/uDI8jBS4lqFibEQRBEIVHicwAy4UNXZWj9WYcGbYPOJQBpoNZBpiLygdTHJ7Piw0ank8QxBOK2A7aJgJgbgftK2+k1t/IgzTi6daw5DjA152RnadLguMAQRAEUWwUuaWZGDU7M3IbIIoYCs4QhYvCEH2bKNAsMDFFWHcmFXZngdmDPdZmlAVGEARRAihRGWD5qTdjK0qjZorm0pMywAiCeJLJlx20GlbtoJUcB+xJcFPCguNAqg2rJscBgiAIQoatCW2KFHu9mXy4DQDkNvCUQsEZokhQGqIvzgIzG6KvRInMAgMkWWC80HiksJoCygJTEhqOQEKDIAiicKAMMKV9k0+z8xLUzuH5BEEQJR3BccACYscBsR20meMAYIcddEHUnVFzHCj8ujM8heU4QEltBEEQhUNBuQ3wqNXk5BPaVN0G5HqpWOvN2OE24KzQlEfh1ia5DTw5UHCGKFKe3CwwG+vOqCEXGg5kgVkSGo5kgREEQRD5p0RYmhVbBhiPtQwweTuVDDC+e1bpoml4PkEQTzpiO2gB3g7aUccBwE47aEDqOCDG1vqccp4exwGCIAiiaFBNaFPQSpYS2uRuA4pYGwnCx2AksZjiqDfD48DtelfrTYiSCQVniGLnycgCs/baAoVY4FJRaFCBS4IgiJKHg5ZmPIqWZmLynQGmhLUMMFs8kvn3ChlggHoGmBZW7xGKh+crZoDxGDPAaHg+QRBPAnY7DqjZQcsdB8yQJ7SJT7q2JrPJKfmOA2RtRhAE8QwhdhsQ6yObk9l4nk63AbKCLhlQcIYoEMRZYMIQ/e3KbQHpEH2LhNq7JyUwC6wEQ0KDIAiiYCloS7OiywArrHozFrDlKlSWAaY2PF8RC2KDhucTBPE0oOo4EKDQ2Bc2Og6o1ey0pKeeHMcBCWRtRhAEUSSUCLcBs0a27kkJcRvgJ5HbwFMHBWeIwsfGIfo2Z4EBpiwwRZ6QLDCx0CimLDCCIAjCcRwRGQVCic4AkyMWGyqjZ6zhor5/lAFGEMTTii120GbY7DhgK7YEZJ5gxwExZG1GEARRZNhjaaaE3G2Af6/oNiBOaFNyG8hT20pB1ZspILcBJQllp9uALZDbQNFDwRmiyLA2RF+MxSwwcSaYJ2zMAgMKLguMf85nFpgCfBZYgRa4JGszgiCIIuHKjRvQ1GmEIROnqLaxlAFmcZQMCjYDrGvXaujatSYKLgNM7b216fajFJAhCIJ4EhEcBywgdhywaAcdas+WLdlB5zeZDSDHAYIgiKIhOTkZGo0GQ4YMKe5dKTSsuQ28O2QUnteUwvXky7atUEkvqbgNLF48HY0aeeDYMfENzSfXbUDRCppPbEszn0VuA0UDBWeIIqd4ssC0suf8ZIHJ29oQpCnOApdWIKFBEATBERsbC41GA39/f2Rn258y5C2OnTiQAcZTEjPAsrIeIzHxa0ycOBK9erVC48YV0ahREG7c4G8Oaiw8xMiG5zvBfHi+wj1Be4bnn7twGb2Hvo+Aei/CvVoL1O3VH1+vXQ+mVIjbAteuXcPw4cNRqVIluLi4ICgoCEOHDsXVqwoJJAAMBgMWLlyIBg0awMPDA97e3mjVqhU2b95s13YJgni66aJ0LuLtoG10HDDDkuOAKmI7aCXHAaAkOw5YszbjIccBgiBKMvnVH2q0bt0aGo3G4mPPnj1C+ylTpgjTx48fr7re9957T2g3ZcoU1Xa///670G75rFmq7ZZv2gJNSCNoQhqh1+vvCdPlCW1Lvl4LT01tlNGEY9SQd61+fgDqbgMOY1tf+803n2HEiD7o1CkCLVpUwYsvPo9Bg17C999/i6ysxypLFa7bgBI593Iw7bMlqBrVA27VIhHUtgNem/ox7t6/b9d6HNFAhw8fRkxMDAICAuDq6oqqVati0qRJyMxUvr/64MEDjB8/HuHh4XB1dUWZMmXQq1cv/Pvvv6rb+P777xEZGQkvLy94enqiUaNGWL58uV2fraig4AxRbBRPFpgjWMsCA1QDNI9gngVWwgpcEgRBEEBGRgZ++OEHaDQa3L9/Hxs3brTYPr+WZrbUm1EkR6WvsCMDTBnLGWD376dgwYLp2L59I3JysuHt7WvrHkNRYPBBGTUUumxxBph4eL44A+y//y6h8YtDsOnnvejQujnG9O+DvLw8jPj4U7w3e7bNe3zx4kU0bNgQixcvRs2aNTF27Fg0btwYK1asQEREBC5evChpzxhD7969MXr0aKSnpyMuLg59+/bF2bNnERMTg4ULF9q8bYIgCGuOA2Z20ErwjgO8HTTvOCDBXeW1DtITcTHXnbEAOQ4QBPGkYq/+cIS3334bkydPVnyEhoaatddqtUhMTERurrm4yM3NxcqVK6HVWg/cx8fHAwA0Gg0SZTfpJZZm7vx2nbFl5z7cS0lVXN+K+A3Cdg1GEcEntI2fORFbTh+Fb4UQbj8LzG3AAHN9ZN1tYN265Xj8+BGaNo1C377DEB0dg+zsbMybNwWxsV2NAZqidRuQOw8YDAbEDHgbk2ctRoCfL8bF9kWzOrWxdMMmRMfG4t6DBzZtwxENtGHDBrRo0QLbt29H+/btMWrUKPj7+2P69Olo166dWZAyJSUFTZo0wZw5c1C2bFmMGjUK7dq1w5YtW9C4cWMcPnzYbBtvv/02BgwYgEuXLmHAgAEYOnQoUlJSMHToUIvBx+LCkVQYglCkC2PCKIwt540XtNtRDP693jAJAHfjax1MZ19eVNgbOs8VrUdnXK9YxKSbtslv3tfKKm/CbFTQnUvByrZuRs6huqodnIQXYMq+aw8hI69LVdssFAiCIJ4l1q5di0ePHuGtt97C/PnzER8fjz59+hToNgq8qGURZoD5+pbGwoWrUaNGbfj4+GH06AE4dGiPSmulUTMO5AO5Wm8iZ8SoWUhLf4hty+ejQ5tIIA2YPup1tB42Ekt++AG92rdHdGys1fWMHTsWd+7cwYIFCzBmzBhh+rp169C7d2+MHDkSv/zyizD9xx9/xI8//ojIyEj89ttvcHfnvtdPPvkEERERGD9+PDp37qwoggmCINSod/+01VHyZStfNR9FEgjpSH0xvDySTBB3KFrZex0suwjw+kgrmqb22oY+7y64gBLPbXCfB+AcB0Ktr0KNk6VrOubiYGSLRqM88okgCMIBikJ/jB8/HuXKlbO5fYcOHbBlyxZs3boV3bp1k8zbtm0bbt26ha5du1ocFZGeno7169ejTp06CAwMxM6dO3Ht1i1UtLAfHVo3x5Yd+5D4wzaMe6O/MD1T64ZTf5/FiWP/4aWubfHL5p1my/qUD4VP+YJ0GzCozLCt3sy2bcfg6son4mmE50mTRuHnn3/E5s1r0bv3MBSn20DCyq3YvusQ+vVsj1VzpkOj0QBpwDc//Ig3ZszCjK+/xvwPPrC6Lns1UGZmJl5//XVoNBrs378fDRs2BMAFeUaPHo1FixZh3rx5mDBhgrCNyZMn4/z583jrrbcwZ84cYfrBgwfRsmVLxMbG4tSpU3By4vTm0aNHMXfuXISHh+Pw4cMoXbo0AODRo0do06YN5syZg549e6JZs2ZWP19RQSNniKLBGCSwOwtMaYh+vrPAxJSQujPJ5pMKtMClFcjajCCIZ534+HhotVq8++67aNOmDXbu3InLl5V9i/Py8jB/xQo06N4d5SIj0aB7d3z9XQIMzDgyRRxHKQXs3ncUsaOmocbzPVHKtxVK+bZCi4g+WLZ4ndBMXNQyXFMG/Vp3x+3rN/B2/zi0CaiAqABvjOvWGdeSLgEAks6dxvih3dD2udKIqlkK743ohZRbanfixBggHkGTkZGKTz4Zi/btayAysgIGDGiD7ds3mC3l4eGJJk1awcfHz8K6lfoS+bT8Dc+3lAF27txl/L7vBNq0jOACM/wqdDp8OHw4AGClDRmJWVlZ2L59OwIDAzF69GjJvJdffhn16tXD9u3bcenSJWH6pk2bAAAffPCBIEoAICAgAG+++Says7ORkJBgddsEQRDWEDsOKGKzHbQlxwH5a1vJR90ZchwgCOIZw1798emnnyI8PBxubm4IDw/HzJkzYTBYGhlvPz169ICvry+WLVtmNm/ZsmXw8/ND9+7dLa5j9erVePz4MQYNGoSeLVrAYDBg9datFpdp3rAOalQNRcL3W80S2lYu2whnZ2f0GdxDcdkPhgzH85pSuJGcDIC70T+2Syc08tbg1/+tlbRleoYxQzugUTUNfv3ZOM/MbUCZTZtWom/f5oiMDELHjrUxd+5EPHr0ULx2ABAFZqS0bdsFAHDtWpLKForObWBp/EYAwMxJI7nAjJE+XXsgtEIFrP/lF2RmWa/zaa8GOnDgAO7evYtu3boJgRmAG2E1Y8YMAMA333wjsaPetGkTnJycMHXqVMm2mzVrhi5duuC///7D3r17Je0B4M033xQCMwDg6emJDz/8UNhGSYKCM0SJRXH0SKD5JAGzRCz5BLnAKEF1ZxyEhAZBEIUGY0DOo5L9KKDs1f/++w+HDh1CdHQ0AgMDMWjQIBgMBtWb6UN69MDUhQthMBgwrFcvvNC0KeZ+9z3Gzpqj2P7TBSvw+8ETaNTwOYwc8TL6DuyMlHupGD18Gia9/YniMmkPUvFKi2hcT7qMzoMHokGr1tj/yzaMimmHCyf/QVzH5nj86CG69I5FzToR2PXLj5j4Vj8b6s2Y0OtzMHJkDI4fP4COHXujS5f+uH37BiZOfANr1y61cMTkN8XE2V5KGWD8s4KykGeAKWBrBtievccAANFtmpjNb1qvHjzd3bH/+HGr60pJSUFubi5CQkIkYoUnLCwMALB7925h2q1btyTzlNrv2uVoMQmCIJ5mbBnRLraDFpM/O2hAue4MoGxpZg8O1J2xxk3zSdbqzpC1GUE8YZD+UNUfr732GiZMmACDwYCRI0eiffv2mDt3LsaOHVsg+8Pj5uaGfv364eeff8bt26abVrdv38ZPP/2Efv36wc3Nsj1zfHw8nJ2dMWDAAHR54QV4eXhg1ZYtYIxJLc3EuAJDB3TB3/+ex7FjphGOOTl6/LDqJ7Rp3xLlgspa3f/cbBdoNBpM+joBpcuUxczxw3EzyRTwWr1sPg7u/QWdewxBdAdro5RM0ZpVqxZi9uwJeO65+ujbdzgCAgKxevVijB7dB7m5ttkZ7N+/AwBQpUoN45TicRvIysrG4T//RfWqIQjxNWZ0pBn3SKNB6yZN8CgzExds+K3bq4Estff19YWfnx8uX74sSYK7desWAgIC4OXlle9tlFRdRrZmRLGgNEQ/HBctjg5BeShelJvjDpPFGG8/Jh6eLx+qbwtKQ/XV6s7IgkJ8gUtf0bR74Eb/AFJrs2samwp6krUZQRCFjv4x8ElQce+FZT64Abh45ns1vCfyK6+8AoDLGBsxYgQSEhIwadIkYYg0AOzZsweJmzejVtWq+CU+Hp7GDKGpw4eiXu8Biuv/es4EhNaqILzP1LohNzcXMR1HYfGCFXht7BCUqVRZssy5v//BoDdH4s25XMAnN9sFs0aOxI/xX+O1Li3x6jtT0C92LJDLZYC9GdcZ+/dsw5n/jqNG9QY21Zu5d+8WgoMrIz7+F+h0WgAMQ4eOwcCBbfHFF9PRpk1HlC3L2w/YIkSVMpnNM8COHduDYyf2cM2dFB4AF7RxBuDMUDG8IroOGSQs7ybxJzBx/gJ3k7Jq5Upm85ydnVEpKAhnk5KQm5tr0Svbz88Pzs7OuHz5MhhjZgGapCQu2+3cuXPCtICAAGFezZo1rbYnCOLZRmwHLcDbQe+CNFhgKxWZ+ogSsU1Yqnym3OdMbOOcC/t0k9zGTBzkyYRdNUDFeklMMsyCUFcRjGBcxUWEowou4AKqqI4wImszgijBkP5Q1R/Lli1D3bp1sX//fnh6ctv/4IMPUK9ePYvbmT17tuJNbTc3N4l1lJi4uDh8/fXXWLlyJd555x0AwMqVK5Gbm4u4uDiL17SnTp3CkSNH0L59e5QrVw6p166hc5s2WPPTT/j9yBHEtGxsauwjXXZQn474cMZXWLZ8Mxo2rIlMrRu2rtuOe/ceYEDcy5K2fL0Z3n1Ajn/ZQEz+cgXG9euIiSP6Y/G633HhzCks/HwCKoVWxTsTvlT9DBxS7XTo0E6sWLETVas+B8AAxj7ERx+9ge3bN2DNmngMHPi62RpWrvwKWVmZyMhIx19/HcHp03+hadPW6NRJKShk7jaQkZGK1Wvmm+I3SrpJpJkA4JXxI+Dv62nRbeDixWswGAyoWlk5waFKMDf9/PnzaNmypWIbHns1kLi9nLS0NDww1ro5d+4cqlSpIixz584dPHz40Oy3bO82+GnXrl3D48eP4eHhYfHzFRUUnCGKlTo4pTj6IxhXBVsvANwFeLI9a1aqOwNIPZMdDdJYqjsD0zS1ujNqQkMEX3eGhAZBEEThotfr8d1338Hb21vwVfby8kL37t2RmJiIHTt2IDo6Wmi/cuVKAMC7w4YJgRlfV8A3sCzGDuiLjxaKhkgbh+SHhVQwC21otVoMfr0/9vy2H3t2H8PLg6XBGQ8vL4yZMQmAqd5M+5f74cf4r+Hj54++Q011UDQaDaI79sX+Pdtw7uxfXHBGFXFhS2DkyEnQ6VzAC5DAwCD07fsqvvnmU/z66yYMHDjcwrqUUPNNNnHs+B4sWTxVdb6chlEt0WtIP8V54uH5aWmcrYCPt1EwGzPAUo2xnFKenjAYDMjIyICfn7o9m4eHB1q1aoXdu3fjq6++wsiRI4V5GzZswMmTJ7n1pqYK0zt06IA1a9Zg1qxZeOGFF4SMwpSUFMyfP9+sPUEQhDU0e81HxFfDWWE0SBVcEOyPVeE1h9gezNf4rFh3xh1ABpTrzgAmDSS3fban7gy/PQtYqjvjIH+jttkIJBalbLtNEARRmDiqPyZNmiQEZgCgQoUKGDt2LD766CPVbYlrdIjx8fFRDc40bNgQderUQUJCghCcSUhIQN26ddGgQQOLwRk+6DRo0CCkHj0KAOjbsSPW/PQTvtu8WRqckVEuMAAd20VizdpfMefzcYCXG1Yu+x8CypRG+y4v4K9j/6guq0Tzti+h72tjsfrb+Vg4awL+2LkVjDHMmLsaHp5eKlZm8noznG7q2LEvqlatJSyg0WgwcuT72LFjE7ZuXSsKzph01sqVXyEt7YHwvkOHXpgw4TNotTrY4jaQkZFql2YCgG7D+sPf13Lw0KSZzIN2AKeZuHZpVrdnrwaKjIyEt7c3Nm7ciBMnTqB+/frCvEmTJgmv5TorISEBU6dOxeeffy5MP3z4MLYa7fLk7WfNmoX58+ejf//+8PX1BQA8fvwYM2fOFB2HNArOEE8n4iywLeeNQ8GLIgvskdLMws4C46epjJhRohALXBIEQRQoOg8uM6wko8v/xdSmTZtw9+5dxMXFSYboDxo0CImJiYiPj5eIo+MHDwIAmokuJHlaNqinuI2MjEf4fG4iNm3ai4uXruHRI6m9y+0bnGVAlmhMekjVKnD2kKaTBQRywyzDn6/DjeYQDeD0L8PNu3fH0ncmVR/OzlrUrt0IpkKWnJioV4+zBTt7VtlGRxlL9Wako2deGz4Fr42Yol7YUgfT8HwXZrQ1s6PPtrMcnBLz5s1DixYtMGrUKGzZsgV16tTBhQsXsGnTJtSpUwd///23JKOxf//+WL58OXbv3o3atWvjpZdegl6vx8aNGxEYyHX04vYEQRC2ouQ4IKds5atSiy9+VL5dNspPpuMAn9SmBjkOEMQTBOkPRf3x119/AYDiKAZrIxtu3ryJcuXKWWyjRGxsLMaNG4eDRu1z+vRpLFiwwOIy2dnZSExMRKlSpdC9e3dk//svt48REagQGIif9uzBg/R0+Hl7q65j6LCu2Pzz7/jfxj1o3KYpdv56EK+NHQydzo6RlzkmXTJqwiwc+2MPEr+dDQAY/d6nqFmrobRrU603Ywq01K/fDGK9BDCULx+MwMAgXLp0Fnp9tjHhzcSOHdznv3fvLo4e/QNffvkxhgzpgC+//AGBgRUgxdxtICgoFEdOMnXNBJh0k6CZAEv9tnuu9Toy9mCvBvLy8sLcuXMxbNgwNGvWDL169UK5cuVw4MABHDt2DDVq1MCZM2cky0ybNg2//PILZs+ejYMHD6Jp06a4efMm1q9fj+eee85Ml7Vq1QqvvPIKvvvuOzz33HPo2rUrdDodtm3bhtzcXPj4+CAtLa1EabOSsyfEM4NSdpL4grkKLlhfSQCko088wV3Eq9ad4Z+V6s6Iz2pibKk7Iy9waYFUy7MFCrrApUpAjDyUCYKwiEbDDdkvyQ+FeiD2Is7uEtO2bVtUqFABmzZtwv3794Xp6Q8fwsnJCf7GDBwxgf7+ZtNycvRo3e11TJ+xFM7OTuj3She8++Gr+GDyG0JRy+xsc5suL+9SZtOcjVZcnl7mgoa36crN0duQAcY18PX1V7goZfD35zrYhw8zzLajjEb2Wv4esFzJUgUX6XgjS8PzAcDXncv+SktXzNhAxqNH0Gg0KFXK/NjKqVu3Lo4cOYLevXvj+PHjWLBgAc6ePYtvv/1WsJ8oW9bke63VavHzzz9jypQpcHJywuLFi7FhwwbExMRg/fr1Zu0JgiDyg3gkff7qzrhDWndGjB03whSxVndGNE2t7sw9hWkq8K4L/GgiS3bZ1oJd1jCzpCMIomAg/aGoP/gbybxdkxj+BnhBM3DgQLi4uGDZsmVYtmwZXFxcMGCAsoUzz8aNG5GSkoJevXpJisM7OTnh5ZdeQlZ2Nr7f9gs30Ud5HZ06RiIwsDQSlm9G4vJNMBgM6B/LWZrlgAt+5BlvZfOWZgbje95tQIyLqyuav9ABAODq6oaYnsNsPAJSQVW6dBnFVqVLlwFjDI8ePVRdU0BAWbz0Ug989lk8kpLOYv78SbDFbcDqHXuFbtpFdM9SbAUtdhvw8eE1k3SfebeBjEePjO1UviQRjmiguLg4bNu2Dc2aNcOmTZvw1VdfQafTYefOnQgPDzdbpmLFijhy5Aji4uKQlJSEL774AocOHcK0adPwwQcfKG5j+fLlWLBgAcqUKYPly5cjMTERjRo1wr59+5CXlwetVovSpUtb/XxFBY2cIUo0JT8LDJBam4lFh3FaqvGtfGShWhaYEbUsMN7azOYsMDGiLDAlyNqMIIhniatXr+LXX38FAERFRam2S0xMxJgxnI2Yt5cXDAYDUlJTEeDnJylqeTsrxfTGeP9/07a9OH7iDGKHxmDp4onI1HIi4jHc8b81W7F2xQaL+yiIDFEGmMQjTZ6MzGNDBlhqagoMhjyY4jPcvJSUuwAAL69SZsuoCwi1oIw5x47vwbFje2yqN+OkzUNQaIiqrRlgygDjfZPPJ18xa5OXl4crN24gLCzMYr0ZMTVq1MDatWvNpg8ZMgQAEBERIZnu6uqKyZMnY/LkyZLpe/bsUWxPEATBIzgOWEDNDrpg4B0H+GfeaQAw3f2xVTtZcxywMehTRI4DZG1GEERR4oj+8PHxgcFgwL1791CmjDRIcPu2XTfHbMbf3x8xMTHCtXC3bt3gr5CIJoYPOiUkJCAhIUG5zYbNGNm3t+o6tFotXhnYCXPnrcI//yUhonFt1KxVzf4PYAw0/HP8ML77+nP4+Pkj7UEKZk1+AzNnm1/fSxGLKE4H3b9/F3K3AX66RqOBp6eyRZiY55+vD29vXxw/fkA2R9ltICMjFatXz7debwYQdNOQca/DWyGJUEzlyhXg5OQk1OvkraB5Ll411vGsaltGtyMaqEOHDujQoYPZ9FdeeQVOTk5o0EBq012hQgUsXbrUrP2UKVMUt+Hk5IQxY8YI/yGe5ORkPHz4EA0aNLBvNFYhQ8EZothQGqIfjotClpNZ3Rm7cAeQrjKvoOvOAKpCQ6kkDWBT3Rkeed0ZS+S37gxBEMSzxPLly2EwGNCiRQtUr17dbH5ubi5WrFiB+Ph4jBkzBqlHj6JW1ar468wZHDxxAl1ekA5N3PfnSbN1XEy+BgCI6dpKCMzwHNx3RHgtL2ppsJRFZbajtjaURmvy8nJx6tSfqFtX6v188uRhAED16rVs3wcA5gEa8+H5cAKOHd2DJd86VnNGnAEmJyqSu4j/9ffDmDBiiGTeoZMn8Sgz06IItoWMjAxs2bIF/v7+aNeunU3LrFq1CgDQt2/ffG2bIIinC7EdtICddtBmdWfU7KCVdIfqoH++7gwgDdIAJosye+vOQNaW344KqTCv26mE0drMVpTqzkggazOCIAoZe/UHwI3oPn78OPbt24cePXpI2u/bt6/Q9jU2Nhbr1q0TXlvi8uXL2LlzJwIDA9G5c2fk3JMOfXRxAnb9eQQnzpzFidNnUb+p+Wdnxryw2CFdMXvOd7h18y7GTx5j1s4iooS2Rw8zMHFkf2idtfjm+z1YsmAKdvz8A5o1b4+uMbEmcwFBIim7DQDAiRMH0akTH1TiAjQ3b17B7ds3ULlydaOlmVJCm+n58ePHePgwHQEBgbDFbSAjI9UuzQQAPYf0Qxlfk+ZUchvwyHVD4wbP49DRU7h87SZCSpkyxRlj2HP4MDw9PfOdWGavBtq/fz+Sk5PRsWNHm0bt5OXlYc2aNdBqtejZs2eh7FNRQcEZotixOQssFFyWlM3wQ/QtZYEVVN0Zflrx1J2xWuBSLDREkNAgCOJZhTGGhIQEaDQarFixApUrV1Zsd+7cORw8eBBHjx5FOIA+HTti1ZYt+GzpUrzQrBl8Xbnz/vXbd7AgYY3Z8iHB3MXuH/v/wovdTN7R+/ceRuKSH+zb6RzRa3lAxmqAxjwDDAAWLZqBRYvWCZlDt2/fwJo1S+Hi4oro6Bj79k9ALDbMx+O/9sYUvDZyis31ZlxU+moP2Z3F6lVD0ap5few+cBQ/b92PDi0jkZoN5Oj1+PjbbwEAw4ZJrQzu3buHe/fuISAgQGIVkZmZCZ1OJxllk52djbi4ONy/fx8LFiyQeIQDQHp6OrxlHtrr16/HsmXL0KhRIzMxTRAEYQ3NXpldMTg76HMwv6ll5jgAmDSGkj2YvDwn3CHVRnwSG/9sj3ayp+4Mv23k23GAT2rjseQ4kN+kNnIcIAjCXhzRHxEREXjllVeQkJCAadOmoX379vA0Fmy/fv261Tow+SE6OhobN24EAKtJSQkJCTAYDBg+fDimTp2K1KNHJfN9XYHF6zdg+LSZiP9pExY2fZeb4Wa+rho1QvG/n79GdlY2mrzYBoApgc0ePh0/AtcvX8I70xcivHotfDhzCf7960/M/mQM6tZtgZBKlkbkSM/v27atQZ8+caha9TluLmNYtGgm8vLy0LmzaSTQtWuXUaqUN3x8pJZZubl6zJ07EQaDAc2bvyiao56QFxQciiN/M7vqzXC6SbmujLjezGuDu+HQ0VN4/9NFWDV9OlfPFEDChg1Ivn4dr732msSaTq/X4+LFi9DpdKhSRWobaq8GUmp/48YNDBs2DFqtFtOnT5fM0+v1yM3NleyPwWDA+PHjcfbsWbz55psICgqyuo19+/Zh5syZCAkJweuvv654jIoLCs4QBY44C0wYom9nFpgZBZoFxt/9kQ/VtzcLTEloqA2VMZIK9QKXYmzMAiNrM4IgCMfYtWsXkpKSEBUVpSqMAGDo0KE4ePAg4uPjMTMuDi0jIjCgSxes2rIFLfv1Rfe2rZGdo8faX39D0/q1sHXnH5KBlJ17tUTox0H4fPZK/P1fEp6rFY4zZ6/i16270bF7O2xZ/4v1nc1RuWhXCsgw2JgBxhAQUA5ZWY/Qr18UWraMRmbmI+zYsRlpaQ8wfvwMlC1bHmJhMn/+NKSmPgAAXLx4BgCwYMFUuBtrvXTrNgj16jUXbUtpFI2NuDjWD3019T1E9hyGbuPeQZ/27eBXOgC//vEHzly6hFGjRqF58+aS9gsXLsTUqVMxefJkYVg8ABw7dgw9evRAu3btEBwcjPT0dPz000+4cuUKXn31VYwePdps202aNEFwcDBq1qwJNzc3/Pnnn9izZw8qV66MdevWwdnZgbo7BEEQUHYcsEp5cIEMq4gdB+R20PYms8mROw5Y0Er2Og4kQzWpzRbHAYIgiKLGEf0RERGBNm3aYOjQoUhISEDt2rXRvXt3ZGdnY+3atWjatCm2bt2quq7Zs2fDy0vZcuull15C06ZNVZd1cnJCTIz1ZC2DwSAEnYYMGWIWmOHp0z4a4z6bi1Ubf8HsD8fCzc1VsV2m1g3RL7UAwFlBK8EHa3KUHGyygW3rEvHzj4lo2a4Leg8eCeQC3p5+mDY7EW+80gYfTeiPZSsOQuukZG1l7g3dtOkLiI3tgOjobvD1LY0jR/bh9Om/ULt2Q/TpYxpVdOLEIcya9T7q1m2EChVC4OPjh5SUu/jzz324c+cGwsKqYcSIibDFbcAiVhy5LLkNAMDgfp2xdv1vWL15O5KSbyAqoj5OJ1/Dlt27ERIUhBkzZkjaX79+HTVr1kRISAiSk5Ml8+zVQF988QUSExPRokULlC1bFlevXsWmTZvw+PFjxMfHm1ma3b59G88//zyio6MRFhaGnJwcbN++HWfOnEGnTp0wc+ZMs8/Xq1cvZGZmok6dOvD29sapU6fw888/o3Tp0ti4caNNNUiLEgrOEMWCtSww8RD9gs8Cyw98kEZsaQZYrDsjfivPAhNjQxYYLzT4LDAlyNqMIAjCOrwnMl8/RI0+ffpg7Nix+D4xEZMGDIC7mxsWfPghqlSqhMRNG7Fw9TpUDCyLt4b1R+/O7bjgjAgvLw/s/PUrvDvhC+zddxL79hxF9eer4utVc+EbWB5b1v+CXJXLMaWilvlDGvDQ6XRYuPBHLFw4Ddu2rcPDh+kICQnHO+98jPbtu5ktvWvXT7h585rZNJ6GDVuiXr1ImGeAyYISvEeyOAPMCuIh+UrD8zVGB57nq1XB4Y0JmDjzG/y0bz8eZWaiSqVK+Pzdd/H2rFnWN2SkUqVKaN26Nfbt24fbt2/Dw8MDDRo0wNy5c1WHzffp0wcbNmzAoUOHoNfrERYWhokTJ+Kdd94xy9wiCILILxbtoEPhgOMAL1qKou4Mvx0rWHIccBC7HAfI2owgiALEXv2xevVqzJ07F+7u7liyZAmqVauGJUuWYOHChahYsSLeeust9O7d22JwZs6cOarzfH19LQZnbGXHjh24cuUKoqKiEBYWhtSUFMl8vkanTykv9GjbBqt++gUbftmN/t1eEtowhTiNWmDGGtcvJ+Gz90YiILA8Jn2+TDKvQb1WGBL3PpYt+RiLvvwAY8d+LporL9pp0k4DBryBVq2isXr1Yly7lgRvb1/07TsMr7/+ntHSjKNu3Ubo3n0ATpw4jHPn/sXDh+nw8PBCaGhV9O07DC+/HAc3Nz5YZtltwCqyY2ar2wDABd42LZmDWV+vwHfrt2Hed6vh5+2NV7p2xYdvvGFW28gS9mqg5s2bY+/evdiyZQsePHgAf39/dOzYEe+99x7q169v1t7HxwcxMTHYv38/tm7dCp1Oh1q1amHJkiWIjY2Fk5P5sevWrRuWL1+OVatWITMzE8HBwRg9ejTef/99BAbm80KiENAwRinylkhPT4ePjw/S0tJIWNuB2D9ZKG7Z3vhsHDnDB2dOlq4p2JqdQ3VBZFxEuCAy7lwK5kaTJBvXcRPcxfk94+MuuFEpj4zPwrnntvFNJriMMP51LoDHotd6mMSH/ITGixF+xA3/7C561hqfvY3P4tfGJ19wwRlfmERGAEyZYIEwBWdCIYycKVuZG5rPD9Hns8D44Ix45AwvNMTBGUFoiG3NZCNn5EKDRs4QxLNBVlYWkpKSEBYWZmbRREhRygDzFV8Qi21xRYk4vHeyuN4MLzT4jC95vRk+A0wIzuRohKKWku4qF9Kui9cTiiNncsEJDIPowRSe+fO//BlQ8k42VaiUV6p0Nk7TGl8bL5p1UA7OWBmezwdk3JAtvPZApjA8nw/OCGUSjIUtU0VJY7759E1+WrD1f0/XwIQ90O/FMeQ1ZyS6SUEzAVxwgU9ou4AqQkLbVQSbEtp43STXTIBJN12HSDPxWuk+uA7lPrh+I9P4XvwMmPSTHLFu4m+quQPwgElDlYJJR5WGqmYCTLqJ10sBMNdMgEQ3qWkmwKSbLGomQFU3KQVnSDcRhH2Q/nj6UbI0E7CimQCTbpJrJsCKbuLdBrJh6qLEeonvuixqJoPotVwriTWU+Nwvfw9IAy8a2YPXSWLdxL829qPiyZasoI1uA2IraFdkCSNnXJElBGckuok0k1WKUjM5EJojiMJFfAEt9gwG4EA9FvEfgxcI4hEv/BnOVuSWZ2rTrHDXhjZKNm4KKHlOKyK2k2uv2gqAuVAkCIIgpPi6qsxQERk8jmaAWYQXGYqoZ4CZhERB3FgSCw4e2WWmA8PzxRlg1obnEwRBPGlYvLGvUC+SR5ycZZN9lzghjMcXCoNX3GXPcp0k11GWEHtvijVSpux1umlaJkwJd5awYNfGJ/fxQSs+8U8Ju63iCIIgCIuoWZrZijihTY7VhDZLWK3PKUYpWU2tv1aarnY/TSlgI0bBAlkcmLEBW90GiJIFBWeIEoF8aLkilmqwlIE0y0qCu8pruaDQqkwXh9zl05Tm8aNzFOBH9ogRW7MpCA0+A46EBkEQRPFgVWT4WJ6thLWilpJRMzzie1tKAkMxA0zeAFAWEUqiwxaxoSQ+5NPsuNxUC3rJUBqeL6CQAUYQBPGkIhnRYQV+xL0Em9w7lDSS+E5QftzQc2XPgIUiocqI9dJt0etkh3ZIcG0QI7HcVklqE0Y2EQRBEDahmtAmQimhzWHE1/9yvWRTLrVcPxX26EiFejOAxHhAQMUOmncbsAQ/aoYomVBwhigUxFlgwvBvfki4hSwwMYWTBcZjKQvMGkoCAzAXGQpZYHKUauYUMPLaPkqQ0CAIgig8HMkAs4kCywBzNCijNlRf3F4hA8waLtLtK2V9KWIhE4yG5xME8aSjVFNSrQalQHnLs6UoOQ4AUp30jDgO2AE5DhAEQdiIiqWZEna7DeTYeC6WmwqoziC3AaLooOAMUWwoZYGJh+iLKZlZYDwOZoHJhUYhZYFJIGszgiCIfONoBpjDlmYlOgNMqa+wkgEm901WQC0DjIbnEwTxrKPmOJB/O2jAPKFNJ3q2JUhDjgMEQRDPKoVhaSav0clj0dLMmtsAYIfbgC1BGnIbIPIHBWeIYkcpC8wqT3oWWKrsvdoIGmMWmFDcUwYvNJSywEhoEARB5B+7LM1sHJJvzdJM4InOAJNh6YqTL2ppJzQ8nyCIpw0zxwFHUbODDoDUDtpXqVFh1J0pRMeBZCu7oAJZmxEEQRQuhWFpZlFHqQUZ8uU2IH+t1FaMmtuA0kPcntwGnmUoOEOUSMRD9EtmFlgubMsCS5euJhUFWuDSVsjajCAIongosKKWBZoBJsaeDDA5ljLAHLjElAk4teH5FjPACIIgnjDEdtBmWLCDFjsOWLWDDoS5FTSPxUGdT4DjQAmAHAcIgnjWKciENofdBsTY7TZggLlesiVAkx/IbYDgoOAMUSQIWWAWUBuiL6HEZIHxyLPBlERGpvosoEgKXEogazOCIAiHKdailmIKrN6MWlsx8swutXoz8vaA3ZeaLvkUPjQ8nyCIpxAlO2g1zOygbXYcUHMYkL8nxwGCIPIPsxSYJp45LFmaqSEktIndBiwltPGJbHniifIJSnbQBfVbJbeBJ4miPEdRcIYoNBSzwPgh+sWWBVZYdWfE5CMLTIl8FLgkoUEQBOE4xZoBpmRpVuAZYPnFmpeybHi+PAPMBsRZXxYzwGh4PkEQzwBKdtAF6zigZAdtbwKbo3VnSqDjAFmbEUSB4+zMXR/q9TYVTiSeUGxJaLMVNbcBm1BLaFO0hBZTVPVmCs9tgMgf/DmKP2cVJhScIYqVos8Ck2NPFphYTNiSBWYhSFNEBS7FkLUZQRBE0WJPUUseuy3NbMoAk1OY9WYUhufzQRk5KkP01YbnEwRBEBw2OQ5YQuw4YIa7wmt5uq44yl5QdWdE00qK44AdkOMAQdiGTqeDq6sr0tLSaPTMU4KjCW1KbgMOW5qJYxJ2J7QBnH4qrHoz4tcF5DZgAxatoMltwCKMMaSlpcHV1RU6nQPDleykoIcMEIRD1Lt/2myURzguCoGHYFwVghIAuCywZHu24A3TVb678bUO3Fmbf7Y3c4M/42uNy2phWXS4m976yprdg/IIoGQ4VGPnb9S2LNpegGn0UntYLDq6RaOx7IVNEATxjFCYRS0LNwNMPHrG0QwwOQWYAWbloytlgNHwfIIgnla2nDcmTG2HVQtiMVVwQTpKpCIzH4HP6w1xkMPX+Cy5h+MOZW2kVZluD7z+4jdqY/93F1xQyRLXNOo22CLOobrErQHgHAeURiYRBFHwBAQE4Pr167h27Rp8fHyg0+mgoQDnE4tSWlWW+FQs1i6ixkx0OZ+t5fSK3phglgNXYcE8o7YwGLUG47VNjgHQG3834pxpPkctD6ZENjU5BINxO3xwRvxsgFQviRe21NfIAzLi5DUn0Xu+uIzWtNNMtphG9Hn4TWoA6BiQDTCmhwF6uCAHeQBykQsXZEMPINv4oTXG46/JgWlEqnFd4u/Ou1YtZGU92xqLMQa9Xo+0tDQ8fPgQFSpUKJLtUnCGKFHUwan8ZTCJL9hT5TP5oAz/OgPmAoMXB7xosEV8iAM8OuM28lHA7DY4uzYx+RAaBEEQhP2UiKKWJSoDTKnejNp7HjuCNBaG54uxmAFGEATxhNKFMfWRF7sAvMA5DshHwlfDWUVr47KVr5rXXgmEdLSJIu4w2YvxuobXS7ZqI4ja86+1sunyafy2FeBvJPmKpokT227CzFHhzqVglK3MJfcF4youIhxVcAEXUEViA6cGi1JxeBAltXWpaltdVYIgzPH25iwU7927h+vXrxfz3hD55fE9aTEwDy2QIp4gth92M3+d42QK0vMJa7nGZ72xr8gVpnPD8Q25xj5EPDhTHJThgzEG0Xtx7AWQTZC/VgvKwMJrJU2kkT3EgRr+vfG1RjbZWTSbdyFwhuA+4KTlPrwWedAa+1IdcoXXvJ5yMeghuEPzz0ZJ9VikMT2SkkBwuLq6okKFCsK5qrCh4AxRZDiaBWZGicwCkwsNcRDIAqnGZ2tZYEZ4ocFji9CwNwuMhAZBEET+caSopYBSvRklLDmXSYbRiAM0BVV7xlJQBlD2MbOCi4P7xAs+Gp5PEMQzgJLjgFXKw2J9FimWHAcAadDFGrmQugs84Y4DViDHAYKwHW9vb3h7e0Ov1yMvz+JFLVGC2V2jhtm0NqGiNy1Er5uYXrJGptenfTnHnLOoJkxLRhgA4IrRQecKQgAAt1AOAJBy1RiVv6UB+PgeX9M5BcAD4wMA7oPTC3x5M2FwSBaAh+Ay4vjnbOP0LOPrXOPrPEiH5eTCvB/Tip75Qpuuxveu4PpUV3Ceom7GZ1cAXtx7N3BdsDu45L/SAPyMD3+Y7htWAFCOwT+Y69jL4RYq4TIqGevOhYILslTHOdRM5e4Tao4Ylz1sfP6De9qdbNr7NmfOgOBqzBSFlZkYCs4QhYotWWBKiLPAxEP0S04WmJLAUMLCKJpHMPd5tpIFxsNngVlCSWhIssDI2owgCMJmbLE0S0q5gcqlYzD4lU74auUs1XZqRS3trjcDSLPBJBPE7wF5hlfXrg0BMGzebGWEkFWURsrIRszwWV+8RlGpNSPH1aScJK+FLWeYTTLDNyLCeiOCIIgnnIK1gwaUHQcAqQ4CTEGWInIckFubFbDjAFmbEUTRo9PpivxGKFFw5F6+bDbNTSxp0sSNTS+ZqI3GzQMAkGO8GcbdC3wIALhvNN66bdQy14wi4o6zmylpm/fmegzTvcEH4O6v8QGbVJhqPwu5AAxcACYDJk2TAxgNwrjXenDBGb4Ugrgkgrw/zAbXtxkgtewUayNem7nAJI58uUnuMNWD8xU1dTbOywHXpzMAzgyZbloE4ypuA/BCDkobj1kO7gv3At1cuO+HtzYTvo8bxk8g+vrc3GxMJCQKnIKvOEQQdqI4ZNxWVIIXyoiHo/EiQK2wpa2IT8biE7Q484sPz4sKXKYqrKqAC1wSBEEQ9hEbGwuNRoPSvr7IzrFQlL6oLc3EqNWbUUSploy1+jLKN7Sysh4jMfEbTJz4Bnr1ikTjxmXRqFEAbty4AotWZpauNOX1pQFoXS0cdyPW6s1cuHwZQ99/H1VefBHu7u6oW7cuvv76a7uLzl67dg3Dhw9HpUqV4OLigqCgIAwdOhRXryonRxgMBixcuBANGjSAh4cHvL290apVK2zevFl1G4cPH0ZMTAwCAgLg6uqKqlWrYtKkScjMJPs2giAsY3G0B49akCIAXJDDF6abQGbdlXyCXCfl92Yqr50yYVNgJ9W+tfMJfXygik/244NY1pBbyAmIHCC6VDWfrZqYSBAE8QygdF60hKWRoGrna7OEbUA6OvSebF4qTBaZEkT36QDRaz4Awwdj8oO875S/d8A2y4YEBIsYE7PJMafkQMEZosSglKUktuoyGykS6shW3FVeqwVpbBUdubJnAPb44t+13sTMyk0G33Ep+U7ba31gb4dKEATxpJORkYEffvgBGo0GD9LS8NOePQ6vK9fJ3NIrX5ZmDteb4bHlAt5yUcv791OwYMFUbN/+P+TkZMPb29esjZSCqTfjphqlUua/i5fQdsgQbNu7Fy82b44xY8YgLy8PI0aMwJgxY2xez8WLF9GwYUMsXrwYNWvWxNixY9G4cWOsWLECERERuHhRaiXKGEPv3r0xevRopKenIy4uDn379sXZs2cRExODhQsXmm1jw4YNaNGiBbZv34727dtj1KhR8Pf3x/Tp09GuXTtkZ5M/G0E8qwg3TPiR7TbaalXBBcsNAqFsCaYKr4XE2sjWZDZx4pq4UrNaJ2ZBOyndVBPffLPZsk0ZqzVPVdweCIIgnmWsBqPF5QxE51Gl4Lel8zAfXJeMCrWGPEADyEbNyPsc/r1SH1UQARrxvUV32bMRX5gSJsqA669t6LPFfb9SDWohId7CtQQ55RQvFJwhShzFnwVmD3qF10rTLJAqe29FaKhlgSmh1MFJOkKx0LBSB4iywAiCeJpZu3YtHj16hDfffBNOTk5IFI12sMXSjFkZQcNTYJZmEiszSzPk9WaURs+oYTrv+/qWxsKFa7Fjxxls3nwcNWvWV2jHPztQb8YGPGxIenhjxiykP3yIxM8/x7fTpuHTTz/F8ePH0bJlSyxcuBAHDx60aVtjx47FnTt3sGDBAmzfvh2ff/45Nm7ciLVr1+LOnTsYOXKkpP2PP/6IH3/8EZGRkTh16hS+/PJLLF68GP/++y9CQkIwfvx4JCcnC+0zMzPx+uuvQ6PRYP/+/Vi1ahXmzJmDgwcPYuTIkdi/fz/mzZtn1/EhCOLJxpYbI0qOA0o3YgBI6lQCsMNxwB1F6zggpmgdBwoiqY0gCIIoRpJFr62WO1BD3ifpZa/F09T6LyXkyd46qAZmLJktBMKsDzfr44knGgrOEEWKWRaYjZTcLDD5NCjMy4RiJlgqzLPAlKL7AFmbEQRBFDLx8fHQarUYHh2Nlg0bYu+RI7hyUyFC7gPk5eXh069XILx1d7iVj0R4w+6Y+WkCDAZptIS3NPtj90G8Hfs2WlVvhTpeIajjFYLeEa3ww+JlivvSyFuD4TGtcefmdUwc0R8v1g1AVO1SGDe0E65duQQASLp0GuPHdUPbVqUR1bIU3nuvF1JS5PurVG8GkmkZGWn45JN30L59XURGhmHAgGhs377RrLWHhyeaNImCj4+fyhFUqTfDT7JWb8ZFuo9KNWYU4csgpAHnki/j92Mn0DIiAu0iI02rdnHB9OnTAQBLliyxusqsrCxs374dgYGBGD16tGTeyy+/jHr16mH79u24dOmSMH3Tpk0AgA8++ADu7iZ1FRAQgDfffBPZ2dlISEgQph84cAB3795Ft27d0LBhQ2G6RqPBjBkzAADffPON3VZsBEE8G9hdFyXUka2o3SmS32xyNEjDP6toJTXscBxQtL4BWZsRBEEUBgVhacYHza2ep604ywDg+otUa43k/Y8jo2Tk9xXl9xctOfLYYYEdKn0rdhcSuw7xUA21JwcKzhCFjsUsMOOwuqcvCwwwP8nLssCsoRT1t6UDAlmbEQRB2MN///2HQ4cOITo6GmX9/dG3UycYDAZ8v2WLYvvXJnyCCbMWwmAwYGRcL7R/oSnmzf8e496ao9h+/qdLcej3Q6jbqC4GjopDzMCX8eBeCqYOH4s5b7+nuExG2gO82q0FblxJQqdeg9GgSWvs37sNo4a2w4Xz/yDuleZ4nPkQXbrGombNCOza9SMmThwgWoOlejPcs16vx8iRvXH8+EF07NgTXbr0we3bNzBx4kisXascODLHQr/EB2WUUOhq1erNKAVqNBnm7fYcPQYAaNOkCQDANyJCmNeiRQt4enpi717rhe5SUlKQm5uLkJAQaBRusIWFhQEAdu/eLUy7deuWZJ5S+127dtnU3tfXF35+frh8+bIkAEQQBGEJi3bQasgdB8yQZ/haqjtjzQ666B0HeBx1HJBA1mYEQRACBWlpZgm5pZnd9WYALik6VTxBNkpTmMZTkPVmxInh8vIJDtSbsQGb3IiIEkV+/JwIosCpd/+0fUGEUDgwqsQdytERHaQnX63xvQ62GfzzbflnGLdjY92au+AEkg3cuRSMspWv4iqCEYyruIhwVMEFXEAVxYi5HBYlCoi9AJP3ZHtYHNW0RaMhL0qCeEZgjCEzt2QXJXfXuivePP8/e+cdHkW5vuF7N72QBAihQygiiAIioCgWVFBQ7AV7L8de0J96VCzn2Cv2guXY6xEREQtwxEbHDggSeocNIT3Z/f2xO8ns7NTtSd77us41szOzswtH5tv3e5/veZwyefJkAE4KrLY4duRIJjz4IG9Nncp9V16MWssy+8eFvPL+pwzcaw++nzGZnJwsfK3g1jsuYN8hZ+ndnvufu59uPbo1WJoBlNel8o+xJ/POk09z+uXX06Fbt6C8mb9+/4UzL72e629/zH+gDh64/Qo+evs5Lj3nYC75x12ccea1UA++eh/XX38s33//OUuXLqZvX/UEk/Eze9u2zXTt2oPJkz8lLS0d8HHBBVdz9tlHMWnSvxg5cixFRZ0CV2uty9Qox+xpfhb+OJuFc2c3Nm5SgBQf7tT6wEsvKYHxOJU6ehR34IzzTwYgq854Rc1fq/2Tkb26dQs5l5KSQo8ePfjjjz+oq6sjNdX4J3Dr1q1JSUlh9erV+Hy+kP/GVq1aBcDy5csbjhUWFjac69evn6PrtZSWlrJz586G9/TqZU/hLQhCy2MAv1o3Fbr4QgVeiuOAdhJLt0xKCxxMpbEmUu+bUUtjfZSqOqZ+ButldmomrDyBbY7q2Db0nRNKcLRSaDl7GooC7TJuDwlWFgRBCBfLccyIEoPjRo40hqVtLPNmtBjYmkFo3owemogHS5chBWXOLzDfJ2NWciHNGSHp6c3KhiWNXVlrLwRM2+QIeQhn4X/wZuH3RNEWGHYbMhBcYCgFiPaY8pk6KNZmBapj6mJjIw5WB4XyC/tI51wQBMdU1lWy/9v7J/prmDL3zLlkp2VHdI/a2lreeOMN8vLyOOZQv5QrNzubYw47jPenT+frn+Yx+sAD/Bfnw38+mgbAnf93MTk5jc/1zp2LuObq8dw58Xmg0dIMoFuP4GZBNZmkpsJJl1/KD1/NZMH/ZnHsOecFXZOdk8s/bvbbWyk1wVHHnMFHbz9HfkFbxp9xTUO8jMvlYvTo0/j++89ZvnyJqjljnTdz5ZW3Bhozftq378T48Rfx/PMP8+WXUzj77H+ornap/qc+huaYeZNm4Y+zeemJu02vUXPgocMamjNmlO7eDUBeTo7u+by8PLxeL2VlZbRubWTPBtnZ2RxyyCHMmjWLZ599Nihf5uOPP2bJkiUAeDyehuNjxozh3Xff5YEHHuDwww8nM9OfKbR9+3aeeOKJkOsPOugg8vLy+OSTT1i8eDH77tuY4XPnnXc27KvfIwhCy2LqX4GV7DPwC6hmYrh6ow/LGlbO92JFg9K4qOfaUJVxe2z48mfhVxXroa6T7DZpFIyaNJUE10rKZ2eFntJjM/4/l5p1LuOcUjAVtS1p06/BDiZI1KZGRG2CIAgNRNPSzBI9RxntuGbHAhOIXd6MgrJ6xkS0nYX+6tVCQsc2jPNm9IQGuuOXBhmrEo/YmglJia1mgt6P7UL0O8y6P+i1uTPafTPMcmeaT8CleCgLgtDcmTJlClu3buXUU08lM6NxZcv4Y44BYPJ/pwRd//OffpnRwcP3RcvBIwbpfsbust08MvERjh14GANyu9Pf1Yr+rlZMOHk8AFs3bmi8uNq/6dpjDzKzVY2nOigs8nfqe/cZELKao21b/7lt25R1/dZ5Mykpqeyzz5CQs4MG+Ztyy5b9pvvnCUX9c1LHx0zJmwlw6Q13MX+Dj/lbfczf5WN+lZf5VV4W+6r43VfG774yVvi2ss63jnW+dUyZ/TbZTvIIosDjjz9Obm4uV111FUcffTQ333wzJ510EqeeeioDBgwAwO1u/HOfeeaZjBw5kjlz5rDPPvtw9dVXc/nll9O/f3/y8vJCrs/NzeWxxx6jtraW4cOHc/bZZzNhwgQOPPBAnn/+efr27RvyHkEQmj92JkjsTLQY4kjwZWYHrZ1oijR3BqKSO6NjbaY0p8TaTBAEITrE2tJMEWdHbGnmITTnuQF13pmyjWSVjDZvRq8ZY7JqxowwxNqSN9O0kJUzQtxJfhWY8hANRwWmtjQDe1IvFXaszQIqMMXaTCFe1maCILQMslKzmHvm3ER/DVOyUh3+sNVBsTQ7cdiwoOOHDh1K56Iipsz6lh2lpbTJzwegtGw3brebwrYFIffK7xT6y7msxs2ph53Kr4t+Za999+GEc04jt207UlNTWVuylqmvv0ltdXWQpRlATqu8kPogxef/2ZaTE+pPrNh01dXVos2W0cubASgoaIPb7Qo53ratfyDavdtovNSunlFwB++mBB8Kslo2QZ0xo5c304CSO1MauDYrF4Bd5fpV2K5du3C5XLRq1cryOwwcOJD58+czceJEZs2axaxZs+jduzcvvPACHo+Hm266iaKioobrU1NTmT59Og888ABvv/02L774Ivn5+Zx44olMmDCBPn36BF0PcNFFF9GpUyceeughpkyZQn19PUOHDuWbb77hwQcfZOnSpSHvEQRBUNCzgzZ1HCgmQjtoO44DVu4DMXYcSCBibSYIguCcsC3NjNCzNPOg6v8romn1azVO82aMihv1ihm9vBlwNFdYHPxSnS1nZ/5PSG6kOSPEhXE+n2V33fU/5+FgDXTENAQymDyCiwwlF0adGaOQSnjdc23ujPJZJngIy9pMyZ0xQ6zNBEFwisvlitgyLNlZu3YtX375JQDHXnaZ4XVvfjada67wr3LJb5WL1+tl23YP7Qpb41PN82/ZvB2AOtXPqy+nfMmvi37l1IvO4v6Xn6Aav91VDWl88e77TH39TfMvqTcEmQqr9Roz+ng8O/B6vSGrM7Zv90uSc3OVJlD08mYAFs4LZM6AZd5MKnV0Le7CReePbfzEMnTp1TWgil6zhoIhwSuC6uvrWbVqFT169DDNm1HTt29f3nvvvZDj559/PgBDNJ+RkZHBxIkTmThxYtDx2bNn614Pfju0MWPGhBw/55xzcLvdDB482NZ3FQSh5RJ27gwYNzWC5qoUO2iIfu6MstVzHFB9vvKyQHOZul5SW5uVEHHujFibCYIg2COulmYKJap9S5G2EbWafe2YVmdwLFzUc4KhYruGvBkjTKw6FWTer2kizRkh6WieKjAwXEXjCWzjFHCpLjTsoKcCk0JDEITmwGuvvYbX62XEiBH0aNs25LzbV8frn05j8n8/bWjODOy3B4t+W8qcHxdz0rjgZZ8/zFkUco+SlSUAHHl86AT84jnfBx+oNviiejVAIG/Gj1f9Qgf9vJn6+jp+/XUBAwcGrxpasmQeAHvuqZ7sUzdotBN89vNmABb+MJuXHnOWOXPR+WPJqjNZRQMcFGhkzJo7F+3dv/vuO8rLyzn00HBVIH7KysqYOnUqbdu2ZdSoUbbe89ZbbwEwfvx4W9d///33lJSUMHbsWPIDK7YEQRCcoHYc0EVpYhiFJgPBjgNpBDdMopE7E3/HAUXUZuY4YClqUzsOCIIgtDASbmmmJzTQshX9CIEg1GNanc5+reaYk3FOwShvJit4t4Dg+UCjuAYNvVjRsK+XN9OAMmYFxASyyjP5kOaMkLS0GBVYnAMu1Yi1mSAILRWfz8err76Ky+XiqRtvpLhLl6DzBYH4meWr1/Djz7+y4Jc/GDJgL845aSyvfjCVex5+maMOH052K/9DfOXmUp598q2ge1STSZfu/vsu/G4uR4xrrFYW/O9bPn7pFf+Lep1xTPuId1QPqPNmtDcKfv3MMw/wzDPvkpbmLxo2b97Au+++THp6BqNHH695r1UxpJM3o8OlE+7i0lvvAiXiJ91HakYN6YE/ZAZVZAY6VRlU2c6b2aO4mAP33Zc5CxYwffr0hhUpNTU13HHHHQBcfPHFQe/Ztm0b27Zto7CwkMLCxh8PlZWVpKWlBa2yqa6u5qKLLmLHjh08+eSTZGZmBt1r165dDfkyCh9++CGvvPIKQ4cO5aSTTrK8fsOGDVx88cWkpqZy77332vpzC4LQfAnXDlpNiB101BwHwFrAZget44BJaLKCh7AcB2KNWJsJgiDYJ2JLM7O8GfBbYXq0B1U50ED082Ygankz7QkZz9TRBlbYyagT0XVyIM0ZoUnRtFRgyla5v41CA4xVYDqFhpEKTA+xNhMEQWhk5syZrFq1ikMPPTSkMaPmguPH8ePPvzL5vU8ZMmAvRh41hAvOHMerb09ln0PGc8Lxh1FdXcv7H3zN0AMGMP2z4F/Bo8aNoktxN1586CmW/raM3nvvxd/LVjLns8857PgT+ebjD+1/6Xonf0Kfzn7wtrCwPVVVFZxxxhEcfPAoKisr+PrrqZSW7mTChHspKgoedJ54YiIezw7AxcqVfwDw5JO3k5WVC7g44YSLGTToEP/FSt5MCqFWy2rSIywIAnkznsCqo0f+7/84+uKLOeGEEzj99NPp2LEj06ZN4/fff+eqq67iwAMPDHr7008/zd13383EiRO56667Go4vXLiQk046iVGjRtG1a1d27drFtGnTWLNmDZdccglXX311yFfZf//96dq1K/369SMzM5N58+Yxe/ZsevbsyQcffEBKSnDzatKkSbz55puMGDGCoqIi1q5dy5QpU6ioqGDy5MliaSYILZSY2kEXE0XHAW2TBhrrnxjlzngC2wQ4Doi1mSAILR29sSnhlmZqjOYBK0N2NK9rVdtI82aUokedN6MQnbwZNXqibCdOOUJyYN8gXBCiSIOiSPkha7I03Gh5XkjH2JFCSq1UVR6K6i63+qEZSQ9T62Fpgl5X37TJZA+9gU5vQDRDb8C1XM4qCIKQxEyePBlozA9Ro6yaATj96FFkZWbwzqczqKzy22q99OQ/uf/OK3G5XDzz7Ad8MeMHrrrhXB564v9C7pWTm8MbMz/mqJOP5bf5i3jn6RfZumEj/37rNU697IrQL6ZdQa9HyBxPkMcZdvJmANLS0nj66XcZPPgAPv/8I6ZOfY+ioo7861/PcPrpF2uudjFz5lSmTXuXadPeYetWv1Rt5sxPmDbtTaZNe4O1awPiALNflzpDampGTcN+BlW6+yEY5M7069WLeQsWcNxxxzFt2jSefPJJ3G43zzzzDJMmTTL5YsF069aNww47jDlz5vD444/zzjvv0Lt3bz788ENefPFFXDpj4Omnn86mTZt49dVXmTRpEps3b+b2229n8eLFdO/ePeT6Aw88kK5duzJ16lQeeeQRvv76a8aOHcv8+fM577zzbH9XQRBaLnoTMOqJGqtcSsDf2GhHo61KATpzRkaTSHbrJGVwq9M5pncOGkObKxtfWqGXOxBwWQhaPUSjZY5ioeMIg9VLgiAILZooWZoZoueao33ub7X3WaHzc9q5uzqcNWrM0I6VOnkzDjAa20WM3XSRlTNC3Gg6KjAnuTNqazMIVYFB8MNcY2umPlSgubVVwKVYmwmCIITN22+/zdtvv41nwQLT6/I651Kx7LugYykpKdxy3fn83x3nA1CZ2mhvtdXnb1BU03isa4/uPPPhqw3HagLjQ111OvOrvFDjCsqbmb8m8GyvI2gI6dS5mPlLfDp5M7Dffocyf34F/pNqOzP9vJlPP53fcOy22x7ittseCpzRy5NxBd6zGH/nxR04piyNUY7ZXCFq8zI9XAZNGTV77rknH3zwga373XXXXUErZhS6devG+++/7+i7Gd3LiMMPP5zDD5cZPkEQIidsO+hCwnAcUOqlWDkOhJk7Y9PaTHEcMCMSxwHJ6xQEQbDGbMwyzJtRY2Zp5jH75Eqs82a0GIkIzNDmzRjYmhXQKIxoh3HejMncn9D0kZUzQlKSGBVYqsG+E9RBYeoCRU/qZSL/st3tbxyolIHLzPYtEk9Pp8tVBUEQmh2tGnd9rYwvU1MVCFZRN2vA35gB/I0ZK2zPd5nly+g3aaxW1zSi17RRYy9vpoGM4Jfpqj9kprpTFSCrzmQVjSAIghCC2upY16Nem2lpiJnjgHY/EpRxoJKwHAf0KHH2DawcBwxFhEcZHBcEQWgmNDlLM2WcCLI026W5SJ03E04DBkJzZhzkzZhpEXTGaPVYrh7jjdyGgEbxdUB4LbloyYk0Z4SkxpZiSa+DrNdpDkL9FNQ+PKNhaVan2YK9tfiYW5vZDu8MJmxrM4tCQ6zNBEFobhRkWF9jRIUNtW+N2SSWmaWZad6MF31bM+0+No6b0biKxo9bs9WcUufN6GEjbybb7tgpCILQTDGyg9bLPzGdoFHjyA4agmsnozpK/cDXG+uM7J6NGjGKunlX8CGPzqXqesnE2swIsTYTBEGIkGSwNHOENm9Gu2+FnrhbPQZGIW/G8VjdiG5GmgZZ0Zk8SHNGaHI0HRWY3gqaWkybNOWa10Y2AyX2voHdQiNowJRCQxCEFoKVpRn51vdQW5opaFfJWBK6UMSPtkHjRcfSTNu10dqZqY/ZxWqVjLpJo/kpqTRltCjZmCaYZszoUerfeIz+/gRBEJooTiZMHDsOFBvcSOs4EIJ2Mkk7GWUXtYhNL3emeToOiKhNEATBT1QtzdRsxd6qypBVMrU6+9q8GbsrarRjYoR5M8XBL9Vjul6Mgd5vAiH5keaMkDCMVGB6xFcFpnhDKthVgRlZminn1OgEXHosvqqNgEspNARBEGJEBJZmhuhZmmmHCycCroaujZ2VM3qYNWFchK6cUWPzJ6VmGE3NqNG9TK9R05A3Y5I7UzBkiL3vIQiC0BIxchzQcx3QFfYqD3Gz2sgpTctxQKzNBEFoaSSNpZkeeoJmrei5AW3eDFg3XRwVYwG0c4oGtmYQmjejh428mXBz0oTkQJozQlyxowJTlt8lTgWmEKmlGVjnzpigVoFFodDQIxrWZoIgCE0RvVUzsbI0U1bSKJZmDXkzasK2NFPQrpLRrpyJRd6Mso0sb0aN5M0IgiCEh97EjNpxQBdbjgNZ2HMcCLd2AseOAx4S5jgQhDgOCIIgxN7STKFEta8WL+s9/z3YzJsB/3ijXSljB6u8mTRM82YKdG5ZiOXYbDm2a5G8maRHmjNC8yCmKrA4587EMeBSjVibCYIgqIiXpZldvGYntCetmi4xzpsBR/2a9LAUaYIgCC2HEMcBE4wcB0LsoCNyHFCTAMcBK8RxQBAEIWY4XTWjJixLM4vcMMCBxaW68a9Xg9hp0DjNm1ETXt6MbqQD+mN+Q96MiTuR5M0kF9KcEZKesJbnRVUF5hS9QDE9FZimsogw4FLrxakMbHoqBCk0BEEQHBBLSzP1QhErW2PbeTPKvt5qGSvs5M0YoOTNKL8ulbpErzZJt/5e2U5XnAqCILQUAhMudgJ/dSkO94O1CuBwc2cUktNxwJG1mVrUJo4DgiA0Myznemw892w5tthF/VzXrpjxWL1ZPbYY5c1oj0HkeTPKvGJkeTNWSN5M00WaM0Jy4FAFpl7GF10VmHo/XBWY9pjeOQWVCkwPo+6/SaGhqAzsItZmgiC0NJLO0swI2wtKfOjnzWByTIudvBnttQ5+RurM26nzZtQZM3p5Mw0oeTOl/o0n1AVNEAShWeBE1erYDlqPQkLtoJM9dybKjgNhWZsJgiAIfhxYmjnOmykxOK5t0CjjQoilmXpMUfYTnDcDjXkzNlCP5eoxXmj6NOvmzH//+19GjRpF27ZtyczMpEePHpxxxhmsXWvjx6kQF0I8DxOuAlOIlgpMu9ULIDPAo3lt5KPsALE2EwRBsEG8Lc0iyptRW5oZ5c2ot9prrdA2ZbR5M2D752QE83euMutrCoYMCf8DBKGFI3VT88CW44DWDro9xiHEDTi1NLNLYh0HxNpMEATBGL3nVbQtzZSmuCNLM73nu22083GR5s3Yfa1CyZvJUR0zimXQi3DQEJbbkJBUNMvmjM/n47LLLuOkk05i1apVjB8/nuuuu46DDz6YH374gdWrVyf6K7ZoWpYKzAiLDrxZwKV6ICoJbA08OMXaTBAEIUJiYWmmYLTqQ1sbKA5mIZZmRo0ZM8LNm1Gj/vnoIFwG0P7VqPNmMnX+QrLqTFbRCIIQMVI3NX+i5zigZwet0DQdB8wQazNBEAQLksXSbCv2VlGGNP+tMs/00ObNaMc8B6tm9NCJaFCP3eox3ShjDmjMmwm4FIUI44WkItxlAUnNpEmTePHFF7niiiuYNGkSKSnBEwd1dU47okKiGcCv1s2ELr7QJkUhFitOsvD/0Fe/rsT/QFX+O1Eerk6WNNbS+M+rNnAP9TEILS6yGg8VOPgoFVv+7kpRz7Wspau9xlWAJW36WftTHoUt+zlBEIRkJikszWp0mtnanyaOVtFrLc1ikTejtTaDEI2PkjeTgrmAWvJmBCFpkLqp6TD1r4BAagb+3+UzMVzp3odlzi1jbKGMebWB/TL8D/twrF8U6gL3qKOxBkujsTCyIZjzEFw/bSNUgVyCLZeFFfSKql3MuD1CJ8WmulwSxiwIQvMiEZZmWrQi5wa0KzD18mSimTejDt3Ua9ZY4DiqoRE7LkQy/iQfzW7lTGVlJXfffTc9e/bkySefDCkwAFJTm2VPqsVhqgID3Y6zPmahXLFQgSVfwKUasTYTBKFFo7Y0M1g1kzyWZmBtVxZp3ozRtZqfkEa/KNW1iQGmGTOCIMQMqZuSHzsTKHbtoEOEW8UGF9pyHFBQ10ax+m8lSo4DChprM4VYWZsJgiA0RRJhaRaClaWZnhDbg07ejBp13ozZ+BLNvBkwn3fUoTj4pVXejKXYWkhqml1z5ssvv2Tnzp2ccMIJ1NfX8/HHH/PAAw/w/PPPs2LFCusbCHGnQUmkrMqYaXSlxbI9NY47zVmELjmMVe4MJCLgckdJa7p/vhZXbaMNjlibCYLQkoj2qhk9IrY00+I1O6Fna6bej0bejPZ/EI28mdSMGt3LTBs1Su5MqX/jsft3KAiCLlI3NS+s7KCD0POwN/K7b0A9yWS0msVuk0Yva0a9r1UuK6pnlbWZXu6MkbWZCUq+gR5ibSYIgmBADCzNQvJm7ODoua/Mw2nHnVjnzWhUDgU0CiDaYTz2St5Mi6HZSaEWLlwIQEpKCgMGDGD58uUN59xuN9dffz2PPPKI4furq6uprm6s9nft0nZZhXjh+p/1ckjwd5CDflQXo9+waKdzLKRHoiyhV5bnK3ZkkaJnbabcV0eO5iHU2ky9PH8zoSuD1rmgiy/E2qx2aQZnvPg+e32xFLfXx9JzerPrn9khHynWZoIgCM7QWpqtLVnL8B7DOem807n3tZeBMCzN9GoD3bwZNVo7MyMazx133P4AfPrpXJPr9Yhe3oxdXGXW1xQMGRLezQWhBRNJ3SQ1U/Jiyw5aS3tsBizn0VhAae2g1bbQEGxVZlRPKZZm0FgnqeulSmzbwHgwtjbbSEQ2MZEi1maCIDRrYmlppmCUNwM2RM2VhDrd1OrsWzVqopA3YzSkFWLp/qN2DxKaF81u5cyWLVsAeOyxx8jPz2fevHmUlZXx7bff0qdPHx599FGee+45w/fff//95OfnN/yva1cHHVvBNk5+iMZXBabGyNLMDPUDX2tphuqcGh0VmJYwAi73WvwHN1w+iX+fcDd7f/4nbq//76bvGyvYuaCN8Ruxb20mq2cEQWhOXHjvPbi6D6XtwCOpTm9c2REXSzM7x3XR5s0o+9HKm/GvmKmqquTNN5/i9tvP55RTBjFsWBZDh6awYUOJ/zK7eTMq0lV/0EydpURZdeHZnS1fvpzTTjuNwsJCsrKyGDhwIM899xw+h5Ng69at47LLLqNbt26kp6fTqVMnLrjgAtau1c9083q9PP300wwePJjs7Gzy8vI45JBD+PTTTw0/Y+7cuRx//PEUFhaSkZHBHnvswZ133kllpf4K2507dzJhwgR69+5NRkYG7dq145RTTuH333939GcTBIisbpKaKflROw6Y2kGH5Tig3dc+9G3kxAQRBccBO5QEtnpWOTRa6yhWO2ocN7xk9YwgCE2UeFqaGaI8p0tUx8wszRTnmaAhY5fmgNrSzAy9yAIzYpc3oxvhgIWrkOJGFBBWa8UBQvLR7JozXq9/kiI9PZ1PPvmEoUOHkpuby8EHH8wHH3yA2+3m0UcfNXz/rbfeSmlpacP/jApwIf6EtVwvrNwZ7UMznCaNgl7+DDjKnfFoXuv5apY07g5f/AOPn3sTz55xPfvO/gWv28WvY/vx3CcXsGJEDwB6f9T4BvFQFgShJWBlaVZWXs77077G5XKxw1PKJ9Nmx+7L6Nlx6eXNGFqaETipZ1+m3bdqSNjLm9mxYytPPvlPZsz4gJqaKvLyWjeeNMub0ZJu3SDJdprNpuKPP/5g2LBhTJkyhTFjxnDNNddQX1/PFVdcwTXXXGP7PitXrmS//fbjxRdfpF+/flx77bUMGzaM119/nSFDhrByZbBIxOfzcdppp3H11Veza9cuLrroIsaPH8+yZcs4/vjjefrpp0M+4+OPP2bEiBHMmDGDo446iquuuoq2bdty7733MmrUqKBVCQDbt29n//3359FHH6WoqIirrrqKUaNGMXXqVIYNG8bcuU5XQQktnUjqJqmZEoMTO2hbFBsc1+bOhKCul7STUpGiJ3CrRbd+0rM2A+vcmQCKdU5MrM0EQRCaK1G0NFOa4tG3NNOOGdrXtYRnaWaEdvyLXt6MXexk0MmKzeSk2TVn8vP9Sb5DhgyhU6dOQef23ntvevbsycqVK/F4PLrvz8jIIC8vL+h/QnKRWBWYGieKMCMVWPQCLrOryrnv3luZ/M+L2XfeL9SmuJh2wqHc/Pm/eOixGylvm0P3Bf6/ozWjO+t+nOEAKh7KgiA0c97731eUV1Ry/UVn4Ha7mfym8WoHCLU0A6jGb11Wb2T3FTVLM72sGasf2navUwjOmikoaMvTT0/h66/X8emny+nXz6aNmDYXk+C8GXXGjGnejE3+8Y9/UFpayieffMIbb7zBgw8+yKJFizj44IN5+umn+fHHH23d59prr2XLli08+eSTzJgxg4cffphPPvmE9957jy1btnDllVcGXf/RRx/x0UcfcdBBB/Hrr7/y1FNP8eKLL/L777/TvXt3JkyYQElJScP1lZWVXH755bhcLr7//nveeustHn30UX788UeuvPJKvv/+ex5//PGgz5g4cSJ//fUXN9xwAz/88AOPPvoob7/9NrNnz6a6upoLL7ywYbJdEOwQSd0kNVP8sDORokzIWDkOqAOFdTFyHMgK2cHYYz+S3BmHjgN6qCfr1LWTieNAPBDHAUEQmgKOV80kvaUZGOechYNZvozW2gzCzptRoR679VyELOMJhKSn2TVn9tzT/4+9oKBA97xy3MguQkgMCVOBhcyt6anAnK6WUWOn6Igw4HIj9Nu0kI/vPZoTfpyC1wXT93NxzeVunrpoCJuL/cuHRrz0E2lVdWwd2Ib1h3W0HBjt5P2AFBqCIDQPJr/3KampKdx8+bmMPHg/vvl2PqvXbtS1NKuvr2fSgy8wtPfhdMnci4N6H8TT9z8dMjGu5M38NON77rn0Qk4evCeHdMzlkO65nHvkED5+88XQL1ILQ/u7uOzCw9iyeT2333YmRx5ZyKGHtuK6645l3bq/AVi1aikTJpzGEUd049BDO/N//3ce27dvxqmlWVlZKffddzNHHTWAgw4q5qyzjmTGjI9DrsvObsX++x9Ofr65LWa0acibUbal/o1HZ/XR8uXL+fbbbxk5ciRjxoxpOJ6ens69994LwEsvvWT5mVVVVcyYMYP27dtz9dVXB5079dRTGTRoEDNmzODvv/9uOD5lyhQAbrvtNrKyGn9LFBYWcv3111NdXc2rr77acPyHH35g69atnHDCCey3336Nf16Xi3/9618APP/880FWbFOmTMHtdnP33XcHfafhw4czbtw4/vjjD/73PxuSOUEIIHVT88WW44DWDjoix4E0zVbBTg2lJ2KLkuOAHiXmp8XaTBAEIXqEZWmmUGLzQwytLZW5NoVY5c1AxHkzWvQiGzSE5S4kJCXNrjkzcuRIAP78M7RzWFtby4oVK8jJyaFdO710eCGeJIUKrAEzFZj6uNPcGe0xp96VATya14pawOflohU38+7rZ9Nt6za2tYK7zkrh1dEpbM9zUVu0FwCtN+1gyLuLAZh+7WjQNE+k0BAEoTljZWn2x8q/+Wnxr4w++ADat2vLuacfg9fr5dW3pure76pL7+beWx7G5/Vx3pXnc9hRh/HiYy9y77X/1L3+9UcfYvF3c9hr8FBOvfQqxpx6Np7t27j/5st4/N4bdYeEsl07ueSCEWzYsIpjjjmPwYMP4/vvp3PVVWNZseJ3LrrocCoqyhk37hz69RvEzJmfcvvtlzj6e6mtreXKK09n0aIfGTv2FMaNG8/mzRu4/fZ/8N57LxG8egYafzbq/HxU8maMyAh+Gc28mYIh/lU8s2fPBmD06NEh14wYMYKcnBxbzYvt27dTV1dH9+7dcemIDXr08FuEzpo1q+HYpk2bgs7pXT9z5kxb1xcUFNC6dWtWr14d1ADatGkThYWF5Obm2voMQbBC6qaWRWJzZ2LsOODBkeMA0JBnYMfaTI9oWZuJqE0QhCaFei7I4HkXsaWZXi6YWd4MaPJmFPGzmljmzaSp9sPMm9ERSKjHavUYbpo3IzRJomEKm1T06tWL0aNH8+WXX/Lyyy9z8cUXN5x74IEH8Hg8nH322aSmNrs/uqDQxRf8MG+PqddwI3k0PrCzAvtp+B/KyhYaH7LKcatlkXX4/6mp75Om+iyLYqUc/0ofhW1AIRTUrefhOWdz0Er/BM/cPXK498zz8OS2odOOB6jO7EZ9XmfW4uL8598gtaaekqFd+Xt4d3rzt94nAf6BVGmG+Q5V+VYeTuQrmgRBaDL4fD58Sa6WdmVl6U6eO2XyNP+qh3NOGgvASceO5IqbHuTVt6dyx70X43Y3NiO+nT2f/7zyX/oP7Me0798jNce/kuSS225k3KCRuve/ZdKzdO7c0/8i0Ieoq6zjurPH8t4rT3LGudfSoX23oPf8tfwXzjzreq6/4bEGF7MHHvgHH330PJdeegSXXHIbZ5zxD8CLz+fl+utP4/vvv2Lp0p/p23cf9FfQBL/etm0zXbv2YPLkqaSlpQMuLrjgOs4++0gmTbqbkSPHUVTUJXC1tkGjeqluyig6BpO8mQWz/8fiQCMlhTpSA+NrKnWkKfte/9ZVDcXdOnL+uHE6Nwzmr7/8y3D32CN0SWdKSgo9evTgjz/+oK6uzvR3YOvWrUlJSWH16tX4fL6Q/8ZWrVoF+FfqKBQWFjac69evn6PrtZSWlrJz586G9/Tq1avhPVu2bGH37t0hDRq9zxAEK6Ruav70YVn4NjIKORisSMmisQ5KNdh3Qi2Ng0ctjXWW+r8/7e+SLP8h7bzXVvyOCVo2EkYjym/FE43JsHF7SDCzIAjJi2NLMxVxtzQzzZvRoowdatF0NPNmzHBg+xrG+KTQMG+nzNkFXIlkzGkaNMtf2s8++ywHHnggl1xyCZ988gl9+/Zl8eLFzJw5k+7du/Pwww8n+isKYTKAX0NWefRiRUPHvajn2uAAsY449BdWmjLKfhnhFxhmaIsP5fM0eNAN4Tyg/kMe/PEu2pXVU5sCzxw+khf3mwSdUui2xR94XNbmEHC52OP3vzj0w+8AmHntIUGrZmJZaEx1uSRsTBCaML7KSpYN3s/6wgSy56KFuLKzTa/RWzWjpra2jjf+O528VjmcMPpQaAW5ZHPiMYfx5vvT+fqbeYwedUCDpdnb//Fn0Uy48ypycrIb1nx06NyR86+9hMfveABotDSrq06nc48eUBP8uampqZx01uXM/fYrFvw0i2OPPy/ofHZ2Lv/4x7+C8maOOuo0PvroefLz2zB+/BUoJ10uGD36JL7//iuWL/890JxRY/wsvvLKWwKNGT/t23di/PhLeP75B/jyy/9y9tnXqK52GewbYPArc8Hsb3nh7vut3x/g0IMG22rOlJb6Pc+UHA0teXl5eL1eysrKaN26teF9srOzOeSQQ5g1axbPPvtsUL7Mxx9/zJIlSwCCcjjGjBnDu+++ywMPPMDhhx9OZqb/v5ft27fzxBNPhFx/0EEHkZeXxyeffMLixYvZd999G87deeedDfvaz3j11Ve5++67g37Lzp07l88++yzkekGwg9RNTY+pfwUmy2bgVzDPBA73T8zYsSXuytrgVSLFhFrHqB0HPDTWIyGajTTNQbVwzU4NpTRhFDGbcixVtQ3DcaBA9TogbAuiBFuByyvoFeLv/wv7OLOSOYpG625BEISmjEMHlZhbmnkwsTQD/VWXVtlmZhjlzSirZ5TXBrZmEJw3o0dx8Mto5s3I/Fzy0uxszcCvAluwYAHnn38+CxcuZNKkSfz1119ceeWVzJs3jw4dOiT6KwoREnFDoYDGh2IIerkz2n0rjALH1PsWAZc6gnX3llquT7mUlxbdQbuyeta3TuX84+/jxS7PgjuV1rv/S17FbHyuVDztjydrdwUTb7iP1Lp65o8azJoh/kJMPJQFQWjJqC3Npsz6H1u37+TUsUeSmdl44tzTjwHglVenBL3315/9488BBw+hmsygc0MOPkD388rLynjh3xM588CBHNI9l6HtXAzt5OL/LjsZgK2bN/gvVA0RXbvtQWZWcOOpsND/+6V3771xuZQf1/5t27ZFAGzbtsnsjx5ESkoq++wzJOT4oEH+P8eyZcrkk7YRY+ZfpkFTw6Rm1HD5XXfwu6+MFb6trPBtZZ1vHVt9K9jqW0G571fKfb/irZ2Pt3Y+vh3zmf32C/43m+TNRJvHH3+c3NxcrrrqKo4++mhuvvlmTjrpJE499VQGDBgAELSi6swzz2TkyJHMmTOHffbZh6uvvprLL7+c/v37NwSlq6/Pzc3lscceo7a2luHDh3P22WczYcIEDjzwQJ5//nn69u0b8p577rmHjh078sgjjzBixAgmTJjAWWedxSGHHMJee+0Vcr0g2EHqpqaBkwkVKzvoIPQ87S3DibOwzp1RP/yd1FBaSzNlq80NMCBK1maK8M+KcKzNJK9TEIRmQdJamkHQvFoDdTr76pU0eucVrPJm1JjkzRRoLjWKX5C8mRZHs1w5A9C1a9eg4FUh+YmrCky9BFK9WKYBIxWYUmjYXUmjtjSDYBUY6K/D16d99TIeb30+g/7yAPB1r07c1ucNdud3AiDD9zcdt/mVyJsLrqYqpw833H01XVZvYHPHdky+91zKyQ3yqjQiHGszWT0jCM0LV1YWey5amOivYYory97z04zJU/3Nl3NPHht0/IhDh9K5cxFTPv2WHTtKySryN2I8peW43W7aFrZpWAxTFQhUKWxfFHL/2poaLh89kqWLF7HnPvsy9rRzyM9vS0pKKhvWlDDtw9eprVF1G+r9m5yc0CXwKSmpgXOtNGd8DbZDdXW1mK2UUVNQ0Aa3W9tocTU0enbv3oWtvBkrMqwv0cNV5ux6ZcWMsoJGy65du3C5XLRqpf37C2XgwIHMnz+fiRMnMmvWLGbNmkXv3r154YUX8Hg83HTTTRQVNf7/nZqayvTp03nggQd4++23efHFF8nPz+fEE09kwoQJ9OnTJ+h6gIsuuohOnTrx0EMPMWXKFOrr6xk6dCjffPMNDz74IEuXLg16T5cuXRq+0/Tp05k3bx5du3blnnvuobi4mPHjx4d8hiDYQeqm5ome44Altu2gwZ7jgHoljR07aD1sOA4oX6MgjNvbJFqOA4IgCMlI0luaqTG1NNNr5KvzZszGIadjlHrFjHp6PfK8GTV25vCEpk2zbc4ITYNxPp9tpdCgHX+GdOF7s1J/WaQ2dwb8jRm9LnsDykPTTu4MhLXMPiR3Rvt5OnjgyC4vcW/tE+Rv9lKVBo/0OZG3c/7t97LZDK721XThJtxUsTvzALbln8/x30xh1NSZ1Ke4uffRWykvCA0QBik0BEHQx+VyWVqGJTt6lmbqVTNrN23iyzlzATj0tMsM7/Pqu19zxTVnAZCX3wqv18v2bTto1a5z0HUbN/ubAt5AA6OuOp3/Tf2ApYsXcfy5F3H7oy/7LwyIs7789F2mffi6yrpMg1e9U0/ohT6DfTuvwePZgdfrDay2cKE0YrZv3wJAbq7SIDLIm1EOKbkzZiKy9MbPXzJ7JvNnzwnJmwFIoy4obwaguKgj55+qb2tWMKRx5Y+SNaNkz6ipr69n1apV9OjRw3Z+Rt++fXnvvfdCjp9//vkADBkSvOooIyODiRMnMnHixKDjswP5OtrrwW9VNmbMmJDj55xzDm63m8GDBwcd79y5My+//HLI9XfddZfhZwiC0LJR585E1w4agnNn1EQzdwb0HQfU30EHde6M2tpMyZ0pQazNBEEQ7JJoSzPtXJ7Hzo1qCc2bUfZjkTejrTHCz5sp6rlW9zJbc3cy3jQ5pDkjNDmavgrMfsBlWnUltxx0IWdsWQLA320zuCHzMZZXHu63ZQvQodVjZLGcOtqwrt199NhUwh3P3AvAq1efw2+D+wd5VYIUGoIgCK9N+Qyv18uIoYPYs2f3IJGTLw3q6+p4/Y1pvD7544bmTP+B/fhl0e/MmfMzY08Kbs4snPNDyGes+9v/nD3kmOP9B1TDxeJ5cxpfOJ7DUtuama2UMT5XX1/Hr78uYODAYaqjLpYs8Tes9txzgNMv5ScV01+Y82fP4VknmTMHDDZszgRdd6hftvfll19yyy23BJ377rvvKC8vb7gmXMrKypg6dSpt27Zl1KhRtt7z1ltvATB+/Hhb13///feUlJQwduxYw/wcNfX19bz77rukpqZy8skn2/oMQRCaCRrHgbApRn8yTM8T39BxQKmXYpE7o3ywzRWzHuyvolnngi4+tvzdlaKefieGrqxlJb3FcUAQhBaB41UzqvFGvWom7pZm5Zg0adSDldrGTG/fDlHIm7Gi2PlbQDXuaMYc7TgjJC9iTC00C9Q/nEM6zJoOtD2imTujXjqp3RoPBsWFi3h/2MENjZlPuvXm5O2zWb5LNRJug1Z5s2ib8zYA6/gX7pI8Hn3pBrKrK/lp4P68fcnpQfcVD2VBEATw+Xy8+ulUXC4Xrz82kZcfup2XJ6n+9+LtvPrKXQw/YB9++2U5ixb8TgVZnHaOv8nyxD1PUFFe0WBptmn9Rt588tmQz+nYrTsAP3/3XdDxhT/9j0/eeSn44nrb316zDZ9nnrmf2tqahtebN2/g3XdfJD09g9GjT6Zx1YyyTcH2z0edvBmAK++6zXbejG/HfGa/94Ktj9tzzz055JBDmDVrFtOnT284XlNTwx133AHAxRdfHPSebdu2sXTpUrZtC672KisrqasLHqOrq6u56KKL2LFjB3feeSeZmcGZQ7t27Qr5Th9++CGvvPIKQ4cO5aSTTrK8fsOGDVx88cWkpqZy7733Bp2rra2lsjJ4VtTr9TJhwgSWLVvG1VdfTadOnULuKQhC88NowqVhgkYHtSgrSLRllDtjmj2jnXSKR+4M6Ic7q/DgPHfGAWFb84DkdQqC0CKJ6Lmph56lWaV6R5s3o7Y00yOOeTNK7nU7jMdY1ZisHqv1suP0MuaMkOZ/ciMrZ4TkxIYKTL1E35BijHNntMQsd0ZBrQbT++DGAubIAc9xHy+Tu9NHWYaLe/Mu4LPfbgpaLcM2SO28mc5d/RNO28rPZXfuwdz+9T30XbeM7a3acPNND7FtdbsgFZgeYm0mCEJzw8rSbObc+axau4FDDxhMz25dQBVD4lPtn3XhSfz406+8Pvlj7h8yhBEjh3PGBafwzqsfcuQ+R3LkicdQU13NtPemMOCAofzvsy8Av6UZwMGjjqNT92L+89RDrPzzN3r12ZvVK5bx3TefcdjoE/lm+ofBX9LI4gwv+jZm2pUz9n94Fxa2p6qqgjPOOJyDDx5NZWUlX3/9KaWlO5gw4UGKipSVQf5mzBNP3IrHsx1wsXKlf5Xlk49PICs7F1xwwukXM+iAEcEfEmbeTANK7kwgRsZTbXglAM8++ywHHXQQJ5xwAqeffjodO3Zk2rRp/P7771x11VUceOCBQdc//fTT3H333UycOLHBGgxg4cKFnHTSSYwaNYquXbuya9cupk2bxpo1a7jkkku4+uqrQz57//33p2vXrvTr14/MzEzmzZvH7Nmz6dmzJx988AEpKcH5PpMmTeLNN99kxIgRFBUVsXbtWqZMmUJFRQWTJ08OsTTbvHkz/fv3Z/To0fTo0YOamhpmzJjB0qVLOeaYY7j/fvurkQRBaHpEagdtG9uOA4pdS3I4DugurLGyNrNJ2I4DJqtn9JDVM4IgJCXxtDRTUNtsmsYTmKEdK2oxXjETi7wZCwqxzJvRw5HbjdAkkJUzQlIRFxWYKVkE+0LqqcDUhNPf1FOBAdSCazfXHHo9T9S8RG61j6Vtsjll8xt8tvCmkLu43NV03XsCqamlVFb0Y3PZtYz+cQZnLnoHgFsveIBtbfS8CBrRGyAdW8bZGKhl9YwgCMnG5M8+BeD8U441ve7k048mKyuTD96ZTmVlFQAPvPQot9x/Cy6Xizeensz/ps/kvBuu4pYnHgx5f3ZuLs9Oncnhx57MH4vn8/6rT7N18wbuefItTj3nSv9Feg2ZhiFMyZvRw2gCxypvxv9MTktL4+mn32Pw4AP5/PMPmTr1HYqKOvKvf73I6adfGnQtuJk582OmTfsP06a9ztatGwCY+fVHTPv0daZNeZ2166ztX9JVhU8moZ2WrDr/37GrLORUCAU6+Sr9+/dn7ty5HHfccUybNo0nn3wSt9vNM888w6RJk6xvGqBbt24cdthhzJkzh8cff5x33nmH3r178+GHH/Liiy/i0hnXTj/9dDZt2sSrr77KpEmT2Lx5M7fffjuLFy+me/fuIdcfeOCBdO3alalTp/LII4/w9ddfM3bsWObPn895550Xcn1+fj7HH388ixYt4qmnnuKVV16hdevWvPTSS3z66adkZETaCRMEobmhN4ETW8cBNeE4Dqj37TsO6OKxcU1JYBuw0lGsdRSrnYgdB0ywG7QtCIIQSxJtadaAYmlWYuMmHkJXSYagN67EEr28GZvWZiZ5M3bsNYWmj8vnE2mGGbt27SI/P5/S0lLy8hyEOQmOUA8IDQOBMukfePgrD/4lbfo1NBCUlTMr6BX0kA/yrCyhseu+GX/XfRt+JZWHRp/KkKWQO2gMECtT7SvL6esI9azUovahBP/DWemsZwe2WUAW2TmVPDLsNg7bsBOAaW324LZ571BbnxO6DBIfXcbdQkH3z6mvbcXKFe/Qw1fJW9POIruugskHXMijF97U0KBSHu7q5pXykFeaW+qVM+pCTr1UMqhJplaBaXJn9JpsogIThOSjqqqqISRda9HUlNFbNQPBK2dQR3m0Cr5OvXKmMtX/91Kh+nFdjf9YlWpZiHKsJvDcV1bOUOOioQehnl9S9+lraey/KL2YhoaN+oCyr6yW8ar+59P8D52tgtaqzBX4n1u11e6nBG+1h5ShTRnuFDFZBpDu//zUjJqG5kwGVQ3NmQz8DZlsKkObMyYrZ/SaM4I1dv/dy29gwQny30t8cFozAUF1kyLMWknvYK9/9aTYRoJrJjCom9T2Mdr9OqCCxgJLOQbGk2Ra/36lTkojuIZSxHRZmn0aayYIto+BYKu29jROhhUHtgZ1k7Zmgsa6yXHNBEF1k9RMgiAkGlvNGbUg10ZzRjtfBwSNP2CSN1MSeIMyFoGN8UhhM8FzdkZzecqYpDefp2drph2fsmgcj7JVr9sEzrdBd2zS2popK2c6EjQWqcch9RikHXuUccf1PxrHmsAYox1fZGyJHdH4DSwrZ4QmSfxVYNHIndEeazxX3HkR7w+4hsM27KTODQ9ljuWmH9/1N2a0eKDd0Ocp6P45Pm8qa359jFZlOTw74wqy6yr4oXg4Txx2vf9aAxWYHuKhLAhCcyaoMWOCXmPGCtPGjBYzS2NDSzPlpNa+LNY/srVNHLD101FnmFTyZqCxGaMmO9RbVBAEQTChYeJFmewPTMyYOQ6oMbI7biChuTMKJo4DeuNGJaErZvTyCSDYMscGVo4DhqpxyesUBCFJsfW8SSZLM6PnOaA7JljmzSiYNWbU+zHOm1FhN2/G7pgvJDfSnBGaNLZyUooNjrcj+CGpu+JQa2OmLjaiUWDUcdjAL3i3cBI9d9TiyXJzWeX1vPbzzQRPhtGgDsjvP432+/tDpzcs/Ce12/dl0sKr6Vi1kVV5xdxw4OPUu+1/t7CtzaTQEAShqZJvfYkRSiMmZhiumlFjlj1jt2Fj9Ex2qf4HwT8VU0IvN8Km3bJeo0YQBEGIHL2gYL0JHiDUDtq2B77WtkXZ19YiaQb7RhjZnGknzvTCn1V4NK/1cgtKAtt1weOiU2szNXatzQRBEJKNeFiaheDU0swSdQNfPYYY5c04xShvxgF6DRq9aIYoIKtmkh9pzghJh5EKzC7xUYGpCfNh7KrnH4dM5qnqaeRVwV/5mZyy6gF+LDmYkEIjsMnuvpjO4+4AYOuiC9i58mTuXnUn+3qWUJqaxxVHPsuujMCsY4n5x8fSQ1kQBCFR6Fmama6aaaV/WL1qpkKne1/lNOlePc+kt5jSNkZWZXav0zZltM0Yo2s1Pxm1lmZ6aP6K0m3+Ye1YmgmCILREwplgcRwcnPDcGTPHATBsxGjR5hGoGzObMaTBYseEWDsOiKhNEISkIAKHFLPnZIilmR7q57S2sa5YmumiHiOM7DSj0aBRo4x1SlFkM2/GQgihlzfjeEwXmgTSnBGaDMpyvaanAgtVfp2491xeHnUfV29eRooPvmzTgdMWPcaG0tDA4Ia7F66l29nX4k6tZdfSw9n8w3VcWDqZ47d/Sh0p3LDvY6zO7xH6RrE2EwRBsIXPoFGjoLdqxtLSzE5Dod76ksaMGWVf2UbL3syFcaPGbfoSCM6bUZOu/90ydf5ilLwZO0jejCAIQviY2kEbUYBNxwHtJJX6mBPqKCyo4KOrJ3PTyV8TPJmm5AYY4LH5ETG0NgsStTl0HBAEQYg1trJm1Jg8xxSibmmmRs/SLCQ3ulLnpNa9xmg1plHDxshNR716BgxtzazQCCKMxmRbrkFCk0WaM0JS0LxVYAp1ZKXW8MPVd3N5my85cHUV9S54Imcw1/14G9V1xgopd/Yuul9xJam5O6lctxdrp9zPyLrZ3LDjMQAe6H4rP7Y7yFQFZoZYmwmC0BzQWzUTQgSWZhFjlTej26TxEmxrZrQiJtrL1aOXN2OF5M0IgiDEnrDtoAvx20HrErvcmSOGrGHrtbdyUpu53NP//cBRbUNGCXgOdRwIQj2pZ2ZtpsHM2iwiUZsgCEITJqqWZgolOse0z2uPrY8idCAwW71vdC4KeTMQmjejR3HwS0s3oAANeTOK21DAfajBjUhoMkhzRmhWOFaBOcqd0arAnPlL9m+3mW+ufgjXd5l03exidxZcWTOOFxeNp3HyS1GBqQaTlFq6XXEjmR1XUbujPav/8xS9a9fxUOVNuPHxbqvTebvozOCBS1EblJh/J7E2EwShuWPX0ky9aiahlmYheTNqtI0Zo632ei161mbK1qgZE37eTGpGTcO+ZMwIgiBEDyM7aLOAYLXjQNAEkJ7XvY3QYr/jgJrIHAfcbi+PnDebr495suHoY38fo7rGoR2Nx+C4nqhN4zgQMzSOAyJqEwQh6YiXpdk6nWedlaWZKXqZZcp+nPNm9OYXjWIWDPJm9FyC9NyEjJC8maaBNGeEJk9EKjBDops7c8Kev/LA8W+R+0Ue+RWwuZ2P8WvO5du/9wtcoTdDVwmU0umce8jd+yfqq7JY/dzT5O1O4Vn3P8ihgp9SDuC+wn+C1Y/3JLI2k0JDEJIPn/xosySxlmYQvHoGzFfKOGnKGGXNWKyYUfJmjLDRvzJt1GjyZoToIf/eBaFp42SixckETgi27aDBXytF7jiQl1vD9Gs/4MbiLxqO3b70FP755lhCc2dsrLo0y51RSAJrM0EQhHjS5C3NPBhYmhnlzejtRws9K8/I82b0kLyZ5os0Z4SkJHlUYOHnzmQWZ5OSl8L/HTqLc/vPptecTFK9sKqXlzOWXMrfOxSfNXX4ZXCRUXjMO7Q57BN8XjfrnnsY77oePJl1DZ1dG1jt68b1WY9T50qzHXBphngoC0LLIyXFP7teW+s4lT7p0LM0C1k1k8yWZoZ4CW62aPcTkDejbsooojGLubd01aSaXt5Mwzcp0z/usdPoEmyh/HtX/v0LgtD80ZvQMXUcMLKDLqDRcUCX8HNnenUpY+41/2F0XuN3faRkDP9+72BCMwMUNI4DCpWErphJNmszEbUJgtAESZilmSO0eTNaYpU3A6GrSU0wyZtRj9FC80eaM0LSkJwqMHCsAnNDm6Pb0+309ky+ZAojXX/R9/cUvMDKQfVcMud8Nu1WHth6hYZ/ICkY8TUdTn0FgI1v38zunw/mgba3sF/qInb5WnGF9zlKXQX6wWhgWwVm19rMNhEsfxUEIb6kpaWRkZFBaWlpy1PTJ8LSzOx4PTqWZkarZXwG+9EiSnkz6fa/V1adfbuzgiFDbF8rBOPz+SgtLSUjI4O0NCNbIUEQmitRzZ3JCtnRvLafOzNy8GbmXvASfdPWNxx7fcvB3PT6EaqrajVbNXph0Co8+od1CdPaLBZ5nYIgCNHGVtM3GS3NPNi0NKtU7StboxUz0cyb0e4TVt6MHqZjtyZvRmiahJNqLghJwwB+Dfkh3IsVDQ2Hop5rg39Yd0S/aVGgeR3yuz4tcDAV/wM8TbUF5aHuSnPR7vjOdOlVz201d9Ppf9UUlLsoz4DdQ2u46sPz2FCm/TBU9/D/k2w1cC6dL3gKgK2fjWfH1ydxc7uHOTrnC2p9aVxT8RSrMnuG3mIboauCSvA/8Ne5oIuPLX93pajnWtbS1TBobDl72ivebDBuj9BAsqkul3hfCkKSUFhYyPr161m3bh35+fmkpaXhamJKzV2//aZ7vEr9mMkl+Dd5YwwKPlVfoDq1sSFSG/AcqyGj4c31gSZFDelANbWk0eBfVuOFWlfw/ZXf/PU0WpgpfRflo0L6K0pzxkfjyhm9LarXRqtrtLg0+24aV8y4VVt34M/sxj821Qd/V+Ur1Ktu6aNxeHQBPh8p6bX4qsBLLenUUA/UBf4u3dSTSRXVgCvw/41L+XurU92ToP+7qKqS3Bqn+Hw+amtrKS0tZffu3XTu3DnRX0kQhGgzA/+E2kwcWWh1Za2p5TGF2FAw56HfGFFqJ4U01Wv//qXHLOWp/T4k3dXo8/nB9mFc+MIo/IOB8p5U1VZxHLBhHVOOf3JMQV0vbSYsa5kV9NLNAVBY0qZfg5jQd6i5+4MgCEKisWtpZrRqJiqWZnbxoDPcmFmaRROzvJks05eAcd6MCvUcndk4YzSuaOfehKaBNGeEZkEfllkvKy8mdLmk8mDUXX2ShV+BpX5dif9BXKfa+nHnpFN0Sjv26biJy7Y/yR7fu0mrhy1tfbQeXM6lb1/I6tLWgasbGzHBTZ46snqvoOsVT+FK8bLz2yPZ/OFFnF3wNue3eQ2A27bfx7z0/f2Xe2hsLMWo0PiFfRqsEAwLjcNp7NgLgtCkyMvzr+Tbtm0b69evt7g6OanYFjprlJ0K29UH1HZZmcBOzesANW7/87hGpfKtU+3XBp7dyrG6gMeXty5VuTh4q27KqBsySoPGpzrX0FPRa8YY/Q+drXYfQpsy2q22QaN+7W68VtvDSVFdkhL4nzK8pYI7tS6wW09qoEBKo65hX7E7S/cGiiel76JsA3VWhaqxlr1qFUJ4ZGRk0Llz54Z/94IgNH2m/qU/seb6n//3+qAdf4ZYz/Rmpf7EWRdfsJK5PQ4sk5WZqNrAvnrg1TZpwO328vh533FNt2+Cjk8rHcQZzx2D16us2KyjsSGjrsOgcTJOZxZMOVWgOraVRvWyunbaiF/EV4KpinklvUOsZiIStR2FpdpZRG2CIMSVKDuh2LY0U4uo1aWdkVMM0LhqUntMjTZvxq6lmR2ilDejF8VggBMXIRk7mg7SnBGSn4SrwCD4h38Z2gIjtU0m7U9tw+jWCzh2xXv0+9U/WbemZz0HD97KIa9cxF87lLADpcAArQoso9Mmul/3FO6MGnYtGcz6165iVO633FL0KACPbr2BaZXHQkXg7WoVmB5KoWGBXqEREVJoCEKTIi8vj7y8PGpra6mvt51QnzTMGjMm5NjIYs2BEar9/Rt3fUMb9/8saJyoWkGfhv0SegCwRjWmrKE7AJvoAMD2tYGH7abApNZ6GosJpUu0k8am0A78w4lSUwQtBqkKvKkKf3OmGr/0ty5wrC5wTNmCfxypx9iXX9UxIQV/EZGCvzOVCmSo9nMC+zmB47n+15k0xrG1AtoAiuagNdAW/6SXsiijg4+2XTfSgU0AdGN1YLuWYvwNlj1ZDkCxZyWu+YH3zQ1sv/NvZpUE/0lGLl2K4JyUlBSxMhOEZsI4n89xHknUHAcsMXIcoOFYq9wa3rtkCmPyfg5590nPHkt9vVUmllroBqGTcTqTYx5MsnI0OHAcUKMWtRliImrTcxwQBEGIFnrjRjQtFSO2NDPCY/dC7eoZs5UzVqtq0gy2icmbsRxbhCaNNGeEpCV5VGDgf+DqPbxTyejso93JeVyS+QFD582neL2/mFg/pJpRvbbjdsHSXXoGk8EqsLQ22yi+8UVScyuoWNGTtc9exb4ZP/NQx3/hdvl4e+cZTN5xcePXUWOkAlMoIfbWZlJoCEKTJy0trclN3hpNTmWmq14cBZSqXqv6Fj7Vda7M7Ib9msCD1F9k7AZgR8Bcyz+R5QsSAGxJCSy/8bkaFWBKI30XjWKAbTQ2bTz4ey4eVLVEJf6Gi7IkxRU4pjRpFCsXxVNZqwBTxiqj5oyyDF8Zg1yBe6fgb+4ofyHpNDZxCvyHsgj2TlZWARUG/oxZ+P3HfDQowCozU3EFVvHkUkNvVrIbqGEHAC42+hVg6Y3WZg3/X20I/ElWB/9JMjMzEQRBEMIjbMcB0PfMrwT7jgO19Oi8i8/O/Q97pa8LudXiqh7U1GgbM8aOA7Zp4tZmImoTBCEhxNvSzChvBhprJl2U2khBu+9kdYw2b0Z7TI1J3owdiq0viVbcgJDc2Eh3FYT4Ec6PTr0OsrrTrO5AA+YrSZQJJ13UnXL/6+w+KXQZn8sd7sc5YuYCite7qEyHOzMGsLK0O+7AnOHhg8w7QSk5u+h+42uktfVQtb4Dqx+/hu6+7TzT+X4y3LXMLDuI+7bcQLAtTQCPwU0dNZ8a0RtI1QOueiBWD9CCIAjJgl3fZCNshfqqsB1qqcZj9+5aNbAdBZhRAaJtvGkVYAphKMAsJraipQCTySlBEIToY7kqxIZPvvF40TipdcigNcy98Gn2Sl/HZm8+F/10Fl5f49hZVq803/XGs1rNFkJzBnTwWH3vAMoqoRLzy/QseiybXWZoLISiqWIXBEFQsLXasilYmgUJ2nZhbWOmPR4uDvNmCmicX2yHjTHUxlgcoKHZr4ijA641IohuukhzRmg22OooF+scK8T/sCxQHcsK2dG8TqPVftD7hFruLpvI/l9tok0ZbMx3MX7zsXz4y/5c/tk43twyHIDRfRV/fHWh4d93pVfQ7br3yOy8hdrt+ax+9FJaV+3mxS73UZCym58r92TCxjvworO8v1zzWs+mzaYdgeXA6RQbhYZTOwZBEAQ1kRYZRgowPRwrwPSevUmjANMrKiC0eeNQAWayPF+NKMAEQRCiS8OEjGIrHJiwMVutYbjqQ+t972hVibGK+KIxi/jyuCdp597FbzXd2P/ly3llRj/aPXkzg9+5GoBh2SvIytQb37SWndprlNWkyiRdZdAmCPVkn5nFdUBwoQgwTO2yVRgJPIJEbQ6sugVBEGJFpII2NYm3NIPGh77eGKFsE5g3U4jtvBm9MdpJ3ozQtJDmjNCssaUCs0SrAsukYGQdw45Ywz9L7mHw/2pIr4Ml7TM44ecL+GtLN5QH9FcregIwssighe2uoesVn5GzxzrqyzMpefR80nZm82yXZ+mSvpXVNR24Yv1dVPk0Fi6VhA5SVoVGSWDroNBQD7C2lORSaAiCkGBi7ZusNLINn50lOsdiogBz2qDRQ08BlqXZql4WYK0AK47wKwmCIAhRJ+wJHRvZlaFkoV6Z6aKeR8//Ly8Pe5kMVx3TSgdzwJOXsHpjKwB2lGayeHlrNtUXkOmqZeS+myzur7eCRo02HFqFR7WvHpuj6Digxkr40YANtbqI2gRBCJdInx9JYWlmitKgV/aVrVG9ZCdvJlXzGmKVN6NG3AZaJtKcEZKa5FKBZUGKj8Ljyjh5yDfcsOgp9glkWH5Y1IGzvr2CsurcoHfNWNgJgP7pa+lYVKW5p4/O539D3qC/8daksvrxs6jbUMijnSazd+YadtS14rJ1t7KzPgPjCbsAHoOvHYdCw9TaTAoNQRCSCQPfZDVRtTQzwuPkE4wUYApJoADTohpv1SIJM29+WZ4vCILgjLjYQeuhOA4ozXrDYSKLVrkVTLvucW7o/ikAj5Ycw7gnzqG8Qjvx5WLWzr4AHNVPCRrTWz2q3So5bDawnNwjIscBx9ZmJqI2sTYTBCHWhDxnIrA0MxO0WWL03NUK2jzoPO611pbaDE6naJ0FjNwGIOK8GRXiNiBIc0Zo0sRTBebOrKP96Vu5tudznPe/z+ix1kVVGtydtT93zrkYH+kh79m8I5tfq/0Td0ftFxx62f7Un2h9yB/46l2sffZ4KlZ0YGL7jzk0908qvWlcsf461tR2sP5iMbY2i8hDWYMUGoIgRIto+iZH3dJMIekUYHrb2OfNmOFkHBcFmCAIQnQIyw7aMncmC/W40af7auZdcxNj8hdQ5UvjonmXMuH1Y/DhJnjc8e9/GXAcOLz9cs19jSzN1GOh1gpUc8qjORYDxwE1scrrFFGbIAhxIYoZnYaWZiU6F2ufxx6juyriZe0xNVq3gXAbNhCTvJniCL6O0OyQ5oyQdCSfCiyLlLxqup21nHtb/ZNjvv6TwlLYmufivJ3n8N6S41Q3CX1Yz9rcB4Aje/4dOFJL4ZifaXfMYgDWvzaKsiXFXFM4k1MK5lPvc3HTxvP5paoTjgMu41hoqBEPZUEQkoFY+ybbtjSzG2rpwcDSLFYKsFSdYwqxy5uJ1vJ8QRAEwQYaxwEz1Csag+ygI3QcOPqA75h77mX0TVvHpvrWHPHJnbwyfTjGCmSYvrALAHunr6F9oc3VMCFjozKGJo/jgG1s5HUKgiA4JakzOhUiErRBaO1kVDepz4nbgJA8SHNGaFbEQgWWVlTKwHPncH/5TYyYvYvMWvijfQbH/z6RX9fva/JB/gf1jD/9Hziy9TLAR+tDltPh9LkAbHx3OJ45/TizYCGXt50DwN2bT2Lm7gGq+6gDLlWFht2aJUxi5aGsV2iICkwQBCfE0zfZjPhZminEWgHmIG9Gj2LrT5Xl+YIgCLHBaGJGmciJT5Cwl5tOfoapo2+mwF3OoqreDH3pEX74ZQ+CJ7FCmzSbt2XxW003AMZoHAf86NmcgaMJtmRzHBBRmyAICSBeGZ2W2Hzemq+gMRobnOZzqhsuameBKObNiNuAYIA0Z4RmSbRUYJnFWznuzKe4e/m/GLjYC8DnnTpz+pyn8VS0VV1prAKbubgDVb40OqXsYO9j/6DT+T8AsHXaQLZ/sS9H5S7jtqKvAZi0bSQflu5PWBNuHtW+XsClFBqCIDRjoumbrKZpKMD09p3iUAFmaW8TuQJMEARBMCcujgN6dtA6jgPp6RX85x8X89DeT5Lq8vLutpEc+PhzrNtcSPAkl3GT5ptNfseB0b0bHQf8qMc3be4MhK461dBUrM0kr1MQhCgS1vMiihmdYVuaad0GwMBtQI1WxKY9Hg4xyJsRtwFBB2nOCElPgwpMs0Q/YhWYRe5MTr+1XHPCBK758S16rYbqVHgkbxQTZj1FvU+tNFYXG6EP76rqVH4q783czAzcJ87H5fax43992PzBMIZlreXBjp/jdsHbO/fl+e2H4Djg0kPwRJ/eQKamJLCVQkMQhJZCoizN1NgKtYTwFGB28mbMFGAKUVCAaUUQJliN37I8XxAEIXaE5TgAuo35Dm3WMue6Izin6HPqfS5u/fNyznjm31TXZKqu0gtaDt5vdBxYCuiNJ2a5M2DoOKDFo3844dZmGsTaTBCEaBOuoC3ulmYeJzdRnvV6Y4KyDbdJk5i8GXEbaFlIc0ZIShKbO+Oj9ZA/uX/Y2Zz+zTzaeWB7KxeXld/KK/NvCFyoPIS1E13KfnDj5oNdnbmmfTu8KT5KF3Rnw2vD2TNjO091nka628uMsj24b8uRgNKgcBBwaYaeCswBUmgIgpBsxMo3WU8BZvYMtGVp5jjUMtoKMD21V9PMm5Hl+YIgCLEnyHFAD51JpmE9v2PhBYcyLOsPSn3ZnPDN8zzw/g00TjVoxxQTx4FF7anypdEhxcPAPh6LbxvmpFszsDYTUZsgCDEhioI2S+w8Z8uxaNIoDXllX9kaCdfCzZvRy/CMf95MA+I20OyQ5ozQ7IhIBeaqp9fes5jU6nSOnLWFrBpY3j6dE5dOY17JqZo3KA9g4wIDIL2jh2/3LaHC7WZoRTVbXhlB59TdvNjlM1ql1DCvojP/t3EsXtzoL9lXcKgC07M2s4kUGoIgNCWSruEbFUuzeCvAJG9GEASh2WDgOKCH4USQDTvoc/Z9mdlHnEinlO2sqO3M8De/4LPvjw+cVY8n9hwHqmtS+bHcP6iPGbgmcFQ9maa1NNNam1ng0byOorWZnqAjGo4DSfcbRxCEJkE8BW16xN7STGsBDWHnkVnmzSj7ic2bMRrLxW2g6SPNGaFFYJo7A1AIbnc1I3tO5rFNV7LfQv9D/ZtunTllzly27e4RuDAL4654aIGR1mY3PSbMgOwa9qyq46ktWxjbbz0vdZlOu9QKllW34er1x1ITpAiOUsClHoo6oSSwTaC1mRQagiA4JR6+yTGxNNPiMTupZ2mmKMDMLM2SN29GEARBiA9WEzR27aCtcmdc1PPo/tfwn71uJMtVw1flwxjy0hz+/Hug5k5pmDsOgLaG+mqtv0g4omvoyks/ZtZmiuOATqNGe8hjcPswrc0iQvI6BUGIA7Gag4mKpZkaj5Ob2HUecIpefQTJnjcjbgNNE2nOCE2LaKjAtLSHNMq4uOAm7lz4OHusgpoUeK7LkVz93dfU+TJ13qT261e/blSBpbSqpPimaaS1raBqQz7HLO1MVg3c5Pue4vRdrK/N5dK14yjzpqjuG6eASwfEytpMEAQhGiTCN9mxpVnUFGDhYJU3E0UFmEHejNl43DCOa5bniwJMEATBHrGwgw6iOPhlnnsH0/cdyw2FrwPwxJZzOPqd6ZTubhu4IgvjccQ6d+aLn7sDcFDuctLT643+CBr0xkoLxwE1YVibKYINK8cBI5W55HUKghArIhW02SHqlmZhuw0oqFdXqgnXHhpCV9SoEbcBIXpIc0ZoEsRSBZZZs5V7q07j0plf0X4H7MyF6zvdxlOrn/I/VA2b4cYqMHeGj+7Xf0FGx1Jqtuew+pHR/LC0F+u+b01BeSU76zK4ZO0YttbnqO6nrwJzpdVQMOI7ul75MB3PeYrM7r+SzNZmQYWGWJsJghAl4lFk6GG7yFCISahlOAowu3kzWiJTgNnF7rgtCIIgxA69CSCzlY97un5jfrdDOCrtJ6p8aVy46kGu/+ppvF698cVMXWw8Hi1eVsBmbz5ZrhoOG6RXwOitMAXb46OH4Em/MK3N9LAStVkJQ4wQxwFBECLFiaAtaSzNPOhMeymCNqOxwMptwAjtChmFCPJmxG1AsIk0Z4SkJR4qsLbbl/Ls8qM4dmYJ2TXwd8dUTuv0EbM85+i8QU8FFpo740qpp+tV08juuZW6XZmUPHwMdTty2GfpWso3ZeJK8XJb6UhKagtM/iS1pBbsov2pU9jz8TvocvFr5A+dT9sjvqTTec8Yvy0JrM1MEWszQRBiSYRFhlkDOiqWZo5CLdXH9K6LRAEWZt6MDWR5viAIQtPDTu7MmOop/JQ7mj7utWzytuHwlR/x6h+XW9w5CzuOA8G4mL2jLwBH71USOKZn4Rmm40CckLxOQRCaDBEI2uJnaaashtRiJGILR9Cm1Elq1wEQtwEh1khzRmiWWKrAfD4GLZrKS9+cxAHzK3ED3/cr5MQuP7G+bi+Lu2tzZ1QPc1cKnS+aQat91uCtTmX140dTs6mAG9vNY1TGKnwu6HLQTor3rQq8IbTb784op+iE/9HnwUdpd8xsUnMrqNnapuGqms2d9L+WR/M6WazNpNAQBCFGJMI3OT6WZhBbBZgaBwqwduirv4qtP0WW5wuCIMSHhokaAztoxysXvV5u3vAvptafR4GrnAX1fRmyfjY/7gioHgpobOKH4Tig16T58m//OHxEh+UG99NzHNCOm8rYGoHjQATWZmokr1MQhHhhay7FpqDNiOSxNNO6DSiom/eR5M6omzIQzbwZu4jbQMtAmjNCs0av8+yqreO07/7Nox/eTN8VPurc8PrBw7mkxxxqXTk6dzFCO+mVRvvTvqHgwD/x1blZ8/TRVK4q4rzWS7iwzS8A/N67E7mdqhnd+2/Ne+vA7aX1oUvY46HJFJ0wB3dGLRV/dWP1kxdS8tBV1FdmAFC2ZF9CCo0YBVzGw0NZCg1BEKxIet9khahZmsVCAQZRy5spxFQBphZD2Mp/0yjABEEQhNhj5ThQ1HMt6TUVvPnjWTxY9TApLh9v+8Ywonwm62u7+8cCI099R7kzagK5Mws6A7B3+hqK2lSZXA9hWZuBzcm/ACWBbRSszaKNiNoEQbAi3DkXK0szveddbC3NwNjSrBZrQVs4eTPafQg7b0aFuA0IaqQ5IzQZoqICK63l1hfO5qYP3qLjNijNgVvHX8ODvV8JvbYARyqwtkd9R7sxPwGw/pXR7P61F8e2WsH/Ff0AwCNbh/Hazn0BGNl2KdD44MzZax2973mfzhd8Q1p+BdWbC1jz9On8/e/LKFvSh84XvU1KVjXly/akdN5w8z+jttAwC7gsCWxtWJvFykPZCCk0BEEwIxG+yYbYtTSzRBtqqSZSBZjeChmISAEWJsp43bA8X4MszxcEQXBGJBMyRisb223ayHcfHMFZ1Z9T73Nxa+aNnJX5NtXkWHroNxJe7syGrdn8XtMFt8vH0UPWBY4aTcZpz9mwNvNojkXBcSDcvM4gtMISk982giAIWhKV0emYsC3NtM92taAtEuKUN1Ns/U3EbaBlIs0ZodlhpAKrX1PJPdeexZnTfyGnCtZ0cHPOLf9hWvE/gi92rALLIv+AhXQ843MANr13OJ4f+nNgdgn/7vg1AK/vGMArOwYwc1ERVb40OqZ4GNBzJ2ltd9P1yq/ocfPnZHbZQd3uDDa+NZIVt13CrgV7AC7ajvqenL4r8Fans/7ly8Bn8s/Wo3kdRWuziEjEgC8IQrMg3kWGI0uzEp2LmoQCTM/nX6e4KMA8b0azPF8UYIIgCM2Dwxd/yVfPj2do/Z94fDkc3+MVHuh0J7jdoSsnLQkndwZmbvY3NnQdB0y3CvG1NrMiLGszDeI4IAiCU6ItaEsOSzMItTQzch4wI7F5M+I2IChIc0ZIaqKlAqv/cRf/vOo6DvmpFLcPFgxoxQn//IEV7YcGv8mRCsz/YM7p/zudL34HgG0zDmLb9OH0z9jIpM6fkuby8tmuPXho60GAi6rKVH6o6EO1C3qfNo897vuY/KGr8NW72Pbl3iy/+Sy2fzUYX70XgOw+f9HhtGkAbHrveGq25mGoCkgCazPxUBYEIR6EW2QYEXVLMzUeozckuQJM26uRvBlBEISmh4HjgB7KxNDhU7/i1U9uoSM7+MvVheFHfsG0nif5LzLyzC8girkz/uMzlhYDMLLNMtSOA/roraCxQTJam0lepyAIYRDv50DMLc0MURrv2mN614VjDw2hFtCQyLwZcRto/khzRmjW+HyQ98YfXP5/z9J/eT11bvjvqX04974fqeql6XQ7VoFBZve1dLvqWdyp9Xh+Gsimd8fQPW0Hz3f5kGx3Ld+Xd+WfG0fjQxkofbxZ3Y7jO3died+NuDPqKV/agRUTT2bT2wfirchEGTTS2myn21Xv40r14pk7kB0zD1Z9soUKLEHWZqZIoSEIQoKJm6VZzBRgdiadYqwA02vQGCjABEEQhMRgNWGjZwftq/Ny7Auf89DCx8l01fJ11hDGXPsOSzubjJdhOA74Mcud8TNzUQeqfGl0StnB3r1Lda7QjonqFaUmtmYKHs3rBFqbiahNEIS4oBW0RSGjM2ZUqneUeS+zpoyV20C4SN6MEHukOSM0TWyowOrrXeTc/TOXPPYpnbdAWRZMvvNInrz3Kf+SfDVhqMDS2myl+w33k5JVze7f+7D+5bMpdJfzYpc3aZtawW9VHbh2/QnUkgJAevtyut/wFfP3+5v1aakU1dWz8YWDWfXA0VSvaxN0b1d6Ld2u+ZjUvHIqV3dk/eRTgXp7fzcezesYWZtJoSEIQqxpGZZmEKwAs1L/hqsAg6gpwGyKGcyW5zeM25rl+aIAEwRBCI9wJmaUCSD39hqOe+Rrztn4BQAvdTyBc258ibKCgsaLizVvDit3RkuqZttIZVUqcyv8zY6xg1YHjupZeqrHS/VxrWWocgxzx4EEWJtFAxG1CYKgxckcS7gZnUpT2ragTe02oH7eeiy/oopKjG3MwqmP9Igwb0ZLsfU3EbeBlos0Z4QmhV0VWPmuVLyX/c0J7y2nVSWs7eDitf+cwnfjxwPBnWpdLFRgrrQMutHmXgABAABJREFUul51P2n5pVSu6caap64m21vLC11eoWv6TtbUtObydadQ4UvHnVFL+1Pm0fvfH9JqwHq8dW7G76xm6roNDK3LAlwEFw41dL7gS7KKN1O3K5s1k8bjq0kPnLMRcGmHKFibqZFCQxCEeJFUvskKEVmaqR/cen750VKA6RUYCvHLm9FTawuCIAiJo/43uOzpDzig6k8qfBn8a+jFvHDZpXROVXUkumgaP1HPndHn67X+Qf+IrkbNfr28Gb2GjDLu6tAUrM1E1CYIggnNRtBWjo0mjVXtpGytmjTiNiAkD9KcEZIepyqwVb+3pf1ZSzjouzLcPvh1QCqvT72K7fsYTLYVa15bqsB8dDr3QbJ7/kVdWSvWTLqBlKo0JnWaTL/M9Wyry+WSdeewoz6b/P1XsMf9b9Pu2CW4U72U/dyZFf88gV4repDt83F0/xLNvesoOnE+BcOX4qt3sfbZ46ndXkDUAy4VmoC1mSAILZN4Fxl6JMbSDBKuAIti3kwkyPJ8QRCE2LJpWjvGfzCTLr5trHUV8vgpZ/HrMQPM32TmnR/l3JnpP3cH4KBWy0lLtXIRMFpBY4JH8zoO1mZh5XXaRERtgiAoJGVGp10qtTvK3JfWxkzZWgnaEpQ3I24Dgk2kOSM0Kz6dNpCCy7ax119e6tww/8Rslrx7PHWtchquCepQh6ECa3PEO7Q++DN8Xjdrn/0/ard1ZGKH1xmes4xybwaXrbuQLe1r6XHLB3T9xzTS2pRTsyWP1U8cxerHR1GzOZ8ZK/0NjsM7BC9bLBixjKLjFwCw4fVRlC/tTlQCLuPsoWy70LCpApNCQxAEcFZkOCXmlmamNGEFmAo9BZgszxcEQUgMDRM3OnbQdXVuqp7OZsz8JWS7apif2YcXrzqVin1ahdzH0nEA/E38Ar0TdnJnjPNnFi1tzVZvHjmuag7dd0vgqHqyTVsnqeslh44DHoPromxtZhsRtQmCYINoz5XENKMzLEszvVWPakuzWOIwb0YPcRsQbCDNGaFZ4PW5eO2JYXR+pJqO26A0B+bf0Q7f/b1xud2mnegGbOTOZO+zgI5nPAzApvcvofzPgVzW5kNOyp9Dnc/NTWWnsuvsn+h9z8vk9F2LtzqVzR8N569/jqdsSXfAb082fWFnAPbJWEuHtv7KIGevdXQ+/1sAtny6Hzu/VQ+EDgIuKwka3Np4tzOibg6X7nyBG9Y+wiUbXuSUNe/Tdo29Lk04HsqmSKEhCIIF0Sgyks7SzDDUsokrwLQihwDhKMAEQRCE2OOpbMXSl/szfJu/e/NNxwH8PqEn3rbpDdfoTRgB+o4DtrJntOOM3urN0CaNDxezd/YF4Oi9SgzurbU2C8NxQLvK1axMUr5GDKzNJK9TEIRoYCloS6SlmRG2Lc2UZ7j2mW9HyKagHoPUQja1/aaRFScY1k4RuA1EImgTt4GmjzRnhCaHVgVW+k0Wb04ZyLD3d5FTDWs7QvXrkHJGaLdFXWg4zZ1JzdtEt0tvxJVah+fHsWz/4kLG5c3k2nbvADCpc3+23PMRrQ9ehMvto3TuXvx12wVsnToMX21wsbF5axa/VncDYMyQdWR02UG3q77ClerD82Nvtnw8LHClVcBlaKGRQh37Z/zExMy7+MZ9ON+Vj+DFyku5bucTXLxxMteve5x7fpvIZ98eyzErP7O0NtNDCg1BEJKOKBcZSmM65FlYEthG1dIMkkYBBo2iBCNMFGBmWCnAZHm+IAhCZFhN0Kzc0ZWKZW3Zp2Y1Fb50vhg2iDaXbceVqj8xZOqJH1HujBbjJs2XAceBIzout/EZTlabBvBoXifQ2iwaiOOAILQcmpSgTcHMbQAMtMhmgjYt4dpD663ilLwZIT5Ic0Zo0vxBR2au6sHQuVW4gaUD4ZDXl5O3t6thOaAtijWvdR6m7Uc/RWqrHVSu7sv6V+9iSNYi/tXhcQD+OziDn877E3d6LeVLe7Dynn+w9rmTqN2eh9FS/W829wFgeL/ldL/+a1Kyaylf1oH1kw8FXNgLuFTwsV/WQu4ouptZvQ7j1fYXcHr6e3R0bcTrc/G3uwefpR7La/nn8XHhSSzP2oP82lIe/vYm9tn6i62/IpBCQxCE2GP0bz0evsmmrDN5BkVsaRYNBZiC2rNfeQ22FWDq2qMAxwows+X5ThAFmCAIQnRZkLkPndZU0sm1g3Wutmwam0n7saEzY4YrHw1WSoYQce5M8P4Xi/yOAwPSV1NYUKVzPytrMxPiYG0W1bxOEbUJghApicrodGxpZrTaUStos3IbcELTyJsRmh/GBq+CkESM8/lCJuw+79ufrEX19N3uz5dZdXgdJ9wTOinUh2Uhk3BdWds4iHTxBU+6tSfEpiYlZxv5A6cDsOG1O9kj/xeeKbySNG8dP/Z18e7oOqrWd2LzB+MoW9IXqAIq8D/UlUEhFfWg8cUfPbi0u4vP+20kPSOF6o35rJk0Gl9dis7fQC2NA0TjPXLd5RyX9xVnFHxOr4zVDcd31hXwVdUovnKNZnHmvlTkBSTQbYFCSPXW8sSaazl8yyzGrPqcXweYB3+upWt0u/uH0zjAgL/QUA004/YQBbUgCA4wKTKc+CZHzdJMS8jckJWlmRY7CjDtTzptk0Z9HCRvRhAEoWVQ53azwDuYA3L8K09+cRXTfc+V5O1RzhLsiamKeq4NXknakdBGheI44NG+WxlvKjXHyvCPU7WqbSjrNufwZ01n+qWv5+ih63nzq1401kZ1NNZY6jGvjsbxzqhBYzHJtg3jsa8Ev0hhnQu6+Njyd1dbq0iXs6fl2Og7VDUhFyZTXS4ROQhCMydegjY1MbU0s4VRo0a7bwejlZwKsc2bMUPcBloesnJGaHL4gDeGDKb96no6bAdPDuw+q5TjjjDwRg4QSe5Mm+Ef4k6tpb6yFb1HP8+LPc6nlbeGZZ3h8QOLWPfaZay440HKluyDf9ULWKnAZv3cjhvatWNFRgqUp1Py2Cjqy5XGjHnA5Z4ZfzGx/XPM7nUxt7d/jl4Zq6nwZvFx6XFcvPZlDl35LXftuIfv60dQQU5wkbQN6txpfNZ5HAD7b5wXsbWZHQW604FfD1k9IwgtEydFhlNiYmnmMfo0RQGmpYkqwMLImxEEQRDix0fb81h6Qn8OaOdvzPxY3Yf+e/9GfoZ/JkyZANJb4Rib3Bn1mGMvdwZg5hZ/rXHUHlbjix3HAR08BE8OxsjaTI1txwHJ6xQEIQ4k3NLMg8lix1rNSb1sMWVr1aRJ1ewnJm8mWm4DQvNAmjNCk6IqPYMdpw5gyKIKsqthdUfouv8ahm/yz5Q5VRnZzZ3JaPM3AFmpu/jXspkUlcKGnEyu9dzNb7dOZuf/jgRvCo0P6lTNVouPtqcu4PucTDK8Xg5e0p3ara10rlNb3PgYnv0nk7s8xX+L/83pBbPJdlezoror926+gkNXfsjtm+7ih4qDqFMPHgYqBNcuLwC708yCBUKRQkMQhFiRSN9kU8KxNLMVagnRtzQzI/p5M3ZRJgCNlueLAkwQBCG6VPbvxvArWrN35moqfBksKu/G8N0LSHF7Dd/jeMWj49wZCCd3ZvqfPQA4os1S/FI9M7SZnQ6tzYyIo7VZNPI6RdQmCC0PS0FbjDI6DYmKpZn2Ia1uxFgJ2pIkb0ZFtN0GZJVk80CaM0KTYWfXItIP78n+v9QA8EdfOHzjMrqkVIRcGzUVWICN397Gxi9v5JLXO7PHRvD4WnHRrx+zZtE48ObiVAXW9ujfaHvEUvDBA1u3c2pe6K/8zOItdDx7Np3O+IajCn7lg+7PMLnrSwzPWU6dz830XcM4d81Ejit5gnc84yj36symeTSvVROHHWs2AbApxd5MW9Q9lLVIoSEIgg6xLDLUxM/STEGrAFOO6V0XJwVYGHkzapGDKMAEQRASzzifj+qjB3DcyeV0StnJuvpCtsxOZ3DVb2HdL2giKUG5M18v6EClL52OKTsZ1McTOKeefDPLnYFGK1Flok8vxyCAR7VvtmqmJLC14TigJl55nYIgNF+iLWgzwq6gLX6WZlq3gWiidhtQ6iVxGxDigzRnhCbBlx8+T7fObeizykdtCizoXc/Jfywl3WuvQIhUBVZfVcBx870cunUNtb40rl77LKtre2jepFdUaK1kUskb+jcdxy/wv/xqb46sqGR4zl9kZfkHmvSiUrpd+wX9bv+A09ou5o1VC3i8/RT6Z26kwpvGGzsP5ui/J3LjxgtYUNmDRhs1DWYBl9tgwO6fAVid091/LAxrMzV2Co1oWJsJgtA8iVeRocaWpZm2yIjY0kyrAFNPKkFTVIDpIQowQRCE+FNXW8NPz17KKfuXkOmqZX5ZMa2mbqZ443pH9wnbDrodBt77WYSON3YdB6C6JpXvd/cB4Jh9Swyu0q5A1Y6vNoiztZlttEIUEbUJgqDB6DnQQKIyOhXCtjQDfUuzWLoNQPCYFHnejF2s3AaE5ok0Z4Skxufz8fZd51F4z5O03wk7W8EvuWUUTIvc/8SJCmxI3Tyudz8GwH1Vt7KQIYEz2u65WTc9jazeG+ly6SwAtn/dj9/e3o+N9a3JdNUycth62p+ygH3ufJ+zq0t4+tl6LvvCS8ed4PGl88y2Qzny75u4f8s4NtS1Ud3XhgpMo0bI9pZziOdbAGa2dy47l0JDEISkIIwiIy6+yWGFWoZTWFgRn7wZveX5giAIQnzZsWU9yx4+ggO2vAfAtFV92PDEb3z7c2BQUiZ2AhM9ysSPWfCwekVkkB10sc7FMc6d+XK1vzAY3XW5jc/RWptZUIk9K1ILazNF2BFzazNBEFostudCTDI6YyJo0xI1SzOtiE3Zapvv2lpK79nvxG0AzTmIJG8mWm4DYgXdPJHmjJC0lFeU8875B7Lvu/PIqoGSLi7Wb1hH/jxnyi814ajA2tVt4bGaG0mlnine43ivdrzmDXnoT3YFq8DSizx0v2Y67rR6di3uxsa3hgFuvtneh89ysqk5fjbX5C7h+RdrOXuWl9blsK0VPNu7PUf+dSXPbD8YT71iXeYw4FLNVhhZMYtMXzWrMotZ2qqvbXseKTQEQYgFRkWGpaVZlLH0TVaw88z0EGGopVPFr528GRMFGEQlb8ZsnBUFmCAIQmz4a8kcap49hP41v1Duy+SDOV3x/mcBbq9xvoyWiGwoHa+wNLcx8xPcpPlskX+1//7ZK8jNrgkcNZq0057Tm+RTjuvgUe07sDbTQ08AEpa1WZiiNkEQBDOintFZ4uA9tgVtoG9pZuY2YBdxGxASjzRnhKTk76WLmXnyMPad6wHgt/1acfinP5K5YXfoxTFUgaW2qeWxHTdQWL+NZe4+3O27C0MbMRMVmCvdR7drPyM1r4rKVYWsfe4I8LnJ7L6N99ruYuWi1jzycg0n/ugjpxrW5KbyzDFuLjuxDc9+cyoVvvTA/SIMuPT4NyeU/ReAGW2OAvWkaByszUKwWLgjhYYgCHZQN36NnjnqZ5RZYzl+lmZgrgCDuCnA7ObNqLBSgCmYjccgCjBBEIRImD/lWbr+90Q6sI21rk5sG/85p34TXr4M6E8UmToOaFdWKtjOnVEwb9j8uSqP1XXtyHDVMXqIdkDWszRTb9WoHQd0iJK1mW3BhwnRELWJ44AgNB/CFrQlKqNTwcrSzBR1/RQNSzM9zPJmtCs+LbDpNiB5M4Iaac4IScfMD55h3fln0nuVl5pUWHraIE59ax4Z2flB19mdzAlbBdYeblj6KPvtXkiZK5frsp6kyvKhrFdUpNL5/Blkdt5O7c4cVj9xDO6MOg455WueGvQJD3xYyujFPtLr4U9vW+7o24WbrvIxc48M/n5uNN4q5Z+pegJPPQDZDLgMFBr96n/noMofqCOFj9qdkngPZS02lfFSaAhC8yDWRUZcfJPVNHcFWLH9bxGRClsQBEEwpbammp+euZihi28l01XLz1n7k3fNHLr32y8q90+23Blw8c22vgAc3dfOpJZR7WSCx8Y1Nq3N9FALQ4xU6SJqEwQhliSlpVmldsfK0kyLkyaNk7wZnWtjmDejIG4DLQ9pzghJxXt3nU3re56mnQd25EHZXZdz4j3vNJyPZNmeUxXYUau+4PyS1wG4reh+VruLgx/CDfNbZrkzqbQ+bAEFB/6Or97F+peO5LDBc/nP8Nd5/tdVHLTUh9sHle3ddD98G68fm8eyEzfh9blZ+/wRVK9TsmX01F/hWZtdXPMyANNzx7I+o0vjiRhbm6kLDbE2EwQhmkSzyLBE71kZVqhlIhVg2n0LHObNyPJ8QRCE2LJ98zqWP3IEB2z9AICfulzEPhOmk9/aYsljk8idUY9bwUxf2gOAI9sZjTN6GQQKeqtXNafVeFT7MbA2U2MlJDFERG2C0GKIhaAtahmdUbU001vVaGQBrUV55tt1G1A7DOi5DaiJfd6MuA20XKQ5IyQFlRXlvHvuAQx4dyGZtbCqq5uu/3mTEadcG/XPsqMCa523g7t+mAjAyx0v4pucI/0PXUP0c2fcGdV0OPVrXD4fgxdl8EbbT5i04Tf2Xe2l3gXTfZ05seQkZnTtz0/FKawa4H9ob3pnKLt/7mbwWeFbm3WnhNG1X/r/XAUX6xcaDqzNnHgomxKmh7IUGoLQPLFUfkapyNAj9pZmiVaAQSLyZgRBEITo8Nfib6l77hD61/zKbl8Wiw98hgMufgx3Sorhe2LuOABRyp3RHgtu3MxY2JkaXwo9UrewZ3dlAk9vMs5sRap9xwEg6tZmktcpCEKiiMbzxLFlY0SWZhD8jI6WiE2NtimjdRvQW/WpQ5zzZoTmhTRnhIRTsnQR35w4jIHzSgH4dWgeR0z5iW59HSzJj7IK7OpvJ5Ffs4s/8/rx5MDrHDxog1VgRcMWcvSyCh5/oZ5bv97NnlvqqU6BDzO6cvTKs7hx+ViWVbdlyqZCbmnXFp8Lds7qw/av9gJqVPcyC7gEu9Zml9c/R4rLy6yUw/grvU/jiTCtzSJBCg1BEGw3WG0qQ52SXJZm8VKA2cybscBp3owszxcEQYicef99im6fnER7trPW1YntZ3zOvqPPjvrnJCZ3xszSzE/Z7nTmVfjH7mP3KzG4KnqOA4YYWZslIq9TRG2C0GKJtaDNVkanlrAtzbQHazUnjVbPOLV/NsIob8YBDt0GIkHcBpoX0pwREsrsD55i7Xln0Wu1l+pU+OOM/TjtjblkZLey9f5YqMD6rFvGqUv8FgH37XUr9W7rQiGYNNqk7OKqtlP5eMtnXDLDS6edsDvdxet5PThy5SXc+cuJrK/NB9JILajAc8xCKt1uhldW0md5V4pOXEy/Z9+m110fg0tp0tgJuDRmr/LfOT79UwCezbgyVLUAltZmygBspQKTQkMQhHBx6peelJZmpphZmmn3I8FMAQaWCrBC4pI3I8vzBUEQrKmtqWbu0xcy7OfbyXDVsiR7OPnXfkf3voMN3xONiZv45s4omNdeX671/1AY3d1YGNBI+I4DgHNrMxMiyesUUZsgtGySWtAWNUszk9WMQWI2PUGbVS2lXoWpJ17TCt7QuYao5c2I24CgRZozQsKY/+U7FNz9LIWlsC0fyu66gpMnvhnTz7RUgXX2cst795Pi8/JF8dEs7Dc09CYFGObO9EpfzV3tn+ObnjdyReHnFFDFurrWPFh1KIf9NoEH55/G9urGgsSVXke3a74hrU0FbatSGFFRxe7zZ1N0/M+kZNeQVbyNlJwaQnEYcFnp4+bWDwIw1TuO38v2bjznwNpMj0gKDUEQWi5hFxlNwdIsKNTSiaWZmc2ZE7QFCAQrwCRvRhAEoSnx89Nnsv+2jwD4sdulDLhxGnkFbSO7qcZxwC5BjgN6RD13JrhZM+1nf+7MiFbLSU+vDxw1G0vDcxxItLWZI1GbBhG1CULzJ6kFbQpRsTTTug0Y4bR+0rMz0xO06b0OYJQ3o0LcBgQ7SHNGSBiDDz+VlXtmsLK7m+7/eYuDT7na1vtiqQI74sdvOGDZXKpSM3hk7ITQC3SsXtJcNYxtNY3Xu97A1B6Xc1rBV2S4a/mlsifXr7+QMStv4/XVx1DhU8z8AwWGy0eXS2aR3dM/Ynnc8HDb1lRm1Dfc2/Njb+p3ZwZe2Qm4VB9rLDRG5sxiWOZ8qnwZPFF+ne6fPRGFRsgPBCk0BEFwQCRK0uSwNFMwUoCZWZopGAUnh6EAM0IUYIIgCAkl//Dr2E4+Sw56juEXPmyaL2OJwUSPMjGkt/LRcEKpWPM6JrkzwfuLlrZmU30B2a5qDh+8QXM/I6cBh9Y3HhvXaK3NSgLbKFmbOSJGanlBEBJPkxS0RWRpBsaWZupGTTTcBpQGjVF+Z5h5M8X2v4G4DQggzRkhgaSkpjLq5S858r/z6Lan8ZJ8x0SgAjvix28AeHfkGWwo6Bx8geaB2821mhsLHmZWz8N4pNMEhmYvos7n5quyAzl7zd2MX3MHM3YPpR518dT40G975B/kD13V8Lo+vZ5Mr7fhdfXGfDa8vn/gld2ASwi2zIE0aphQ9DAA/9l1Hht9nRovjcDazAopNARBMMKoyAhpwEa5yDBTgFk+26JiaQbRC7XUNmT0luk7VIAV4Dhvxg6iABMEQYiMPQYdTPZNvzNo1Jlh38PpxE7scmfMxiM7dtIuvtnRF4Cxe5XYuD5O1mY2iKe1mYjaBEFQSIigzQhLSzPtQ9jMWcBJLWXUgFGIb96MuA0IaqQ5IySUvNZFZGabyXUdEAUVWIdtmwD4s6tmGXmgMdO9toQLq1/mbfd4vmh1NBflvUKb1J1srO3AU9su58i/3+PaDXewqHIfQPnhq68Caz3y94ZXvnoX27/ux6AyfyMnpdbN6kmH461K13xjo4BL4wHp/Dav0SO9hG11bXlp1yWNJzyqixJobSaFhiAI8caRb3JElmbqfbuhlpESBQWYg7wZPTW1Mt4q468RogATBEGwT1aOvUxONcmZOwP+cUg98WWUO6PfrPnir54AHNlePbnVdK3NYpHXKQhC0yfWgjYjopLRGbGlGYRamhnVS2ZuA1o3Ae38XKpmX31OheTNCDFEmjNCkyeaKrB2O/yjRv/Vv9HZs479Ni3g2N1TuXbZE0z5dRzT145hQs2jDHL9jNfn4tvag7ly6zOM/vsrntt+LVvq1JJjtZdyaIFR+lMfarfn4PmxNytuP4HKkrb8lO8voLypXnrd+RmFY8yWOOoVGcFFRee09fyj7bMAPLz1Zsp9uf4TRvY7UmgIghBjbBcZWjTPhoT7JqsxtDSzG2qpPZekCjAV0VaACYIgCE2H6OTOgHXujPacn+nzu1Lvc9EvfT3dOmoHYTNrMwd4bFwThrWZiNoEQUgWrCzN9J5X0c/o1EM9txULSzO9xn9s82bMkLwZQZozQpMkViqwP3v6mwbnfvMGXz07ijemn8ND397MZStfYI/KFdSSyg9Zw7kn4w4O987i8soXmVV5OPUND3etCgz0u++pbP10CMtuPIt1LxwOKV66XPxdw1mfC1Kyasndez3BKgA9FZie+roSqOD2on+T6a7mp/L9mbprnP+0R3O5mbWZdoCVQkMQhHjhoDEbNd9kI6JqaaYXahkLBZj22jDzZtRoLWwCREMBJsvzBUEQ4ogy4aOxgzZb8ZhMuTPbPZksqvavnjl2yGobn2GlutYQQ2szNU7zOkMQUZsgNFsSJWiLKbYszcxWQaJzLlz03Aa0Kzqjlzdj5jZghbgNNH+kOSMIKh597ipuufF+1rXtTG1KGmsKujK3wzD+2/lE/q/nAxw8+Dsu7vgK76afyRYzSbGpOjlUBVaw/yrdK7d/tZfOUSNrs2COzP2BQ3N/osabxj2bJwKu8AuNEpNzOkihIQiCHmEXGRpi5pscMwVYPEIt9TJoopA3o1meH3HejCAIgpAw7E7wmE0YxS93xjx/5qv1/h8PR/U0EgjorUxV0JsEjI21mSIAsVqtG5W8TkEQWh5RFrQ1LUuzOoN9LXbdBhSSN29GaJ5Ic0ZoXkSoAvO53Xx6xPGMvu8rBj2zhKOv+JILxrzOPwfex9S+x7MrNd/Bl0mj8aFurgIrXdiL8mXt2fldbzLq/ZODaXVuKlYWWXyGfsBljruC24qeA+DlHedTUtvB/teOgrWZGik0BEFwTDL5Jis4sjSLR6ilHlFQgMU6b0aW5wuCIDQZHE0gRTV3xtjSTGHaL8UAHJq/lBS3N3BUb9Wp1UpVvbE6gEf/cBAW1mZ6RNVxwELUJo4DgtD0SFZBW+ItzfQwW0Gj5zagzZhpWnkz4jbQPJHmjNAscKoCM8udCbzA51b989BTgRVgQwWmPmasAqsqaceq+8dS8VcR1Sn+h21tqpfim74IXOEs4PKmdq/SIW07q2s68eKOU3S+E8YqMIVEeyhLoSEIzYp4FBlRtTRTnn1RszSD6IRaKsRZAWaBUwWYLM8XBEGID7Gyg45t7oyCcZPmp9/ascObS76rgoMHbTK4t5HjQGKszWKS1ykIQsslGQVtjrBjaYbmWKRuA3rouQ3oRRYEiGLejEKD24DQIpHmjCBYYaQCM8RO7kwaek2a1PzgKiAtzau5wizg0j94HZD9B6cVfAPAHZuupcaXEXyLSoKLC/V+jAoNNVJoCIJgioMiwymOLM30iJsCzAw7eTMOFWB2sFieHwmiABMEQUggSZM7Y89xQNn3et3M8vQFYOzeTnJnFPRWzIRpbaYn5FBhmW3nEMnrFITmS5MTtClYWZqVE4agDfzPbm1TJppuA0q9ZLSS04RYuA0oVtDiNtCikOaM0GRpGiowOwGXjRNtW6cNYPXjRzZ+v42tbXxOY6GR7SrnnvaTAXhr5xgWVA6w+2WDceihbIXdAV8KDUFonkSryNASc99khYgszaxCLZXiIpJQSz3Ll+jnzegRjeX5giAIQvxoWLloMfFjN6g4erkz6mP2cme+WOEf40d3Uq/eVK8+1XMc0E7oKblwEVibKcTQ2kzyOgVBCCJZBG1RtTTTXmBVH5m5DahFauZjSfD55MqbEbeBloE0Z4TmS1KowMD+wz0N6t3UV6Q3HFk3qISMjh6da/Wtza5v9yld0rezvrYtj28dT2Oh4VAFphAHD2UpNARBaMCiyIi7b3LcLc3shlrqYUcBFlnejFq8YDhu0jjeNizPFwWYIAhCk8XSDlpNxLkzYNdxAGDagi4ADMxYTYfCMoPPMLI2s0ESWZtpEVGbIDQ/mq2gTRc9QZv6oat9dke6WsYsb0Z9jYYo5c1EA3EbaL5Ic0ZofsRTBVaAgQpMOwFmpALTyZ9Z04a01W0BqHNDdi9lJlCrCgje7pf1F2e1/haAOzddSIUvU+cLB0iAtZndQsMpUmgIQnITq3+L4fomh01ShVraUYCpiW3ejILd8VUUYIIgCPElVo4DlsTMccB/fOPWXH6u7g7AuGHrbHyOeoytw3gAj4+1WSR5nSGIqE0QWg6JErQZYSVo03UbMEK76lFbG0XL0kwhPnkz4jYgmCHNGaHZ4HSyJyoqMFOy0A+41Cs2GifZvNVp/PKvMaR7/UXUXprIGD2yXbu5r8N7AHzgGc6PFf2t32RFBNZmkRQaIT8kpNAQhGZJonyTE29phuZYrBRg8cubcbo8XxAEQWi6qFdOBk1AFWsujKnjQCNfrPfnzhzTW72i08xSVC97Rm1ttiv0QzwOvpADxwE1dkRtkeZ1iqhNEJKXJiFoi8TSzEMElmZ69ZKZpZmCTh0UdC4xeTMK4jYgKEhzRmjSJH/uDNhTgQXv+1J91Lj9A1/rUb8Y3LdxsPq/ohl0Td/B+trWPLT1REKXhsbX2kyN0UAvhYYgNH8c/Rt04Jvs9JmRPJZmkHQKsAISmjcjy/MFQRASiDIBZGAHbdtxQIuj3JnwHQcApv5cDMBhBUtJTa3WnNV3HHCU9RYna7NwEFGbIDR/mrygTZdYW5qlava183Da8wriNiAkBmnOCC2Chk60CfHLnbEOuPRWptP9F/9gWNJhF6R4A2dC1QGH5fzJqQWL8Ppc3LrxNMq9DgYUj8F+jAoNqx8AClJoCELzJlLfZDUx9U1WiIqlGSS1AizWeTOCIAhC0hB1xwG1HbTj3BmwlzujPefnh1/as83binxXBYfta6eISYC1WUAIEg3HARG1CULzo0kJ2hTsCtocWZqB8SpHo2N20c6/xSBvRtwGhDCR5ozQPNGowBQcq8BikjsDZgWGwpJP927YL2pdoXtN65Rd3NNhOgCv7zyIBZW9sK0qsFxSSlytzaTQEITmRTSLjEh8k/Ww9E1WiJqlGbR4BVhgXBYFmCAIQtMlOXJngvd9uPhmh7+OOHbvVQb3TrC1mQ7Ryuu0/I0kojZBaLIkpaCtxORDIrY0U9dJylZbL9UZ7NtBqaG040v08mb0ELcBwQppzgjNiqaTO6Ng3KTZuL4NqYHcmaHjF+rc28dd7b+iMLWC5dXteHKb9pd5clqbSaEhCC2bSIsMdSPX6BmifuaYNorjZmlmpQAL19IsRgowNaIAEwRBaBZEc2Infrkz1o4DANP+8o/1R3VcqjqqXY0aY2uzCERtTjEVtZmo6BVE1CYIyUOzFrTpon2gWlma6WHn2W3mNgChbgPauTsTxG1AiDLSnBEEYq0Cg3BUYN7KdPYpaQtAbT/1hJh/IDo+bymjWq2g1ufmlo3HUONLI6xJPo/BfoI9lEOQQkMQmgyJKDKi7pusELalmfYCMwWY+nwCFGBaJG9GEASh5RGP3JkCwnAcUKPfpJn2UyfqfG76pq2nR+dSG1/OymZUIT6iNrWQxE5ep2NE1CYITY5oCtqMiLmgzdDSTO02oEZP0BYNtwG164B2H6wbOTqI24AQRaQ5IzR5kk4FZthsd64CK5nnHyDr0qtwu2sajndKLeOfRbMBeHrbgSytbk9o4HSUrc0i9FC2U2g4tjaTQkMQmhTxKDLUmPomG6E869RNGcsiw2mopXbfLk4VYNp9A6KUNyMIgiA0HRomfmaYXtZAVHJnTDFyHLDOndmxK4f5Vf5a47iha6w+KIB68s/IcUCDx+BWyZ7XqYOI2gQh8cRL0GZkaaZH1AVtHuzNO1lamkXiNqCHnqBN9VrcBoQ4Is0ZofkTbxVYA4oKzGiSTF1o6Ddp5v7UBYCtqW6GH7wOABc+/t3xG3JTallU2ZHJO4aq3qHnnxwlazOFMD2U1UihIQjNl0QXGUaN4BCMFGB28Di41lABpj1mhp28GT0FWHzzZgyX59uc+BMEQRCSn2R0HACYsaYPAGN6qOXG6lWpemOvg4k+M2szhSaU1ykIQvLSJARtClGxNFOOO7E0M7tGT9CmdRuA0FpJu4pThV7ejLgNCFFCmjNCsyP5VGBgnDujLS6CFWI1nnx67UrD53LRbeRvAJzb+lf2z95AhTeVWzeOwkt94D1WljkmeAz2HarApNAQBMGIaAZaOiVxlmaxDLXUa+obKcBUhwtImAJMlucLgiAkllhN8CRD7szUJd0BOKTVMrIyqwyu0pv4s6qdmmhep1bU5sBxQERtghB7mrygrclYmlldB/HOm1FoELQJLR5pzghCgPiowLJU+zb5w19oVLbfRu/0rVxXuACAB7ccxNra3MBFRoWGDdRjotEAqrU2k0JDEAQdIioyNGj/7asbt0bPDCvf5IYGstEzK26WZlrCDbW0qwAzyZtJkAJMEARBSEIMHAf0sGVzGbXcGWNLM4VFSwvYUN+GLFcNRw4x8mNWU0tTsDYzI1JRWyKFM4Ig6NOsBG0eHFqaKWhro0gszZIvbyZkbBW3gRaPNGeEZkFS5c4UYNF0V05qVWAG1mb/808wbkz18lj3WWS465m9uysflO5lcH/tUn09VYKBtZkaM2szG0ihIQiCgq1/kzZsC/WIyDe5JLCNu6VZ01SARRNZni8IgpBEWEwIKRNKicudgWAxQqjjAKTx5TZ/DXFsv79V77OyFY2ztVkYeZ0iahOE5kWiBW1qwhK0KRg97wwFbXrH1MfVdZLes9mOpZke2hoqufNmxG2g5SHNGaFl4EAFZgszFVgQerkz4CTgctOfHWhbC8d9D73dpeysy+DOTYcA6oFSW2joqbC1fp4aPAb72zRbC5wUGmqk0BCEpk00iwwtCfFNjrmlWdNTgJmppCVvRhAEoekR7gRQsjoOTPvTX2eMLvqT0HpIO+EX4ZgcrrWZDpGI2iJFRG2CkDzES9CmNyfjSNBmZWkGBoI2rduAgt4zONxayWzMUDf2tdnQyZk3I7QMpDkjNG+SRgUG/gd+OAGXLgb9nMvxP/m/w8TNB7KtPptQ1YDeNkJrswg9lK0KDSvFeyRIoSEIyUE4RUbT8U2OpqVZ01SAKeOoFaIAEwRBaN6EnTsTRceBL+Z3pNqXSnHqVvr30rEkC0FbKxlZm2lUGR6D2yVBXqeI2gQhuWk2gjaFiARtENoo17M0s4OR24BWxKbdt+k0AFHNm7GLuA20DKQ5IzRLklcFBk4DLnPcNZw4qxK3D37YG+bUdrHxGVGyNlMI00NZjRQagtB8SaYiQ01UfJPVeJzcJBqWZgqiABMEQRCig92JHjtBxVHJnWkgcseB3RWp/FDeB4Bx+60x+EJ6jgMxsjYzir6JYV5nOIioTRAST9IK2rToZXSqicjSTEH7TK4zOWeGdp5NW1MlLm9G3AYENdKcEQQLIs+d0evEG3XnQ5s0txb9SFFdNVvz4YXRKRx6tFHnyY7CIExrM4UIPJTVSKEhCC2DWBYZatTPlKhYmqnxaF47UoAZWU3anQhKHgVYNBEFmCAIQhKiTAzNDD6sTCSZrZRMHseBVL5Y7f/xMbbbckLH4zhbmynE2NrMsZBFRG2CkBCapKCtJLA1y+iMyNIMQp/JRjZn4aDUUNpxRF0vqYQBkjcjJABpzgjNBqcqMLPcmYhVYCE4UYE1qqCPyF3FSfnL8frgv4dkUZnhovMQ9dNazw5HqziIo7WZDlJoCELzJlH/lqxsEXWfPTGxNFP2tc1vswkghXAszRKjAAsrb0YQBEFIepxOBDmdaAoixo4DAFMWdAdg/+y/yMu1Uwclt7VZOKI2S8cBHUTUJgiJI16CNjXJY2lmlZtsB/XcmtkYYeQ2YIK4DQhxQJozQvPHQAWmEBUVWCFRVIEFbplSwd3tvwNg8o6B/LbT/8N8W9FOjAcpo7wZxdpMS+yszZpCoSENGkGIHbr/9rQNVIsiQ92gTUiRocZj9Aa9Z6hRcREtSzM7CjDV60TlzcjyfEEQhGaL3sRToh0HAJaVtOLvuvaku+oZO2yDwXsTYG0WhbxOM0TUJgjJTTII2qKW0amul2wJ2vSO6Vma6T2Lw7E0U5o0arcBPecBcRsQkgNpzgjNlnBVYI473I4VYGCtAvNxb4eZtEmtYmlVG57eth9zP90Lt8/Hqgw3ew82W1OqRk8JZmJtVol9azOt0rwksLXpoWyXWBYagiBERpMsMoyIuqUZRDfU0uq65M2bkeX5giAIyUU0J3wS6zigPtf4esZmf/1wzJ5/49zazCF2rc0cYievU0RtgtD0ibagzSmRNIV18egddGJppke4lmZajKwxAyiCNiPEbUCIEdKcEQQbmKrA9FDUyUEqMO1EmbroCLY0Oy3/Nw7NXUO1N4WbN46klhTKtubQvcJffAw8Qm3Jpg2b1h5zYG2mxqmHsgnqAV8KDUFo/iR1kVES2EbN0gz0m97RDrVsHgowQRAEoQkQmDDS2kE3jdwZ+Oz3HgAc2fYPwE4jSs9xwKG1mXrfruNAhHmdWkTUJgjJSTznHNRzKUaCNlO3AbuCtphZmsXSbUDBIG9GQZnP0xur1GOauA0IUUKaM0KzImlUYIZkob9qBpQBo3vaDm4umgPAY9sOYEVN+4ZrU1d3AKCu2yaMJ/uMlAcOrM08Bl8/CtZmauwWGo6RQkMQYk6TKjLsYlRkeDBAbxWiuqCIRaillsQqwBQaFGAWyPJ8QRCEJMbmBFGy5858vbA95b4MOqR4GLrXToOr9KzNHGBn8tHK2kyHaKrYRdQmCMlNyxC06R2zY2lmt0mjdRswyt8MI28GwnLKEbcBwSnSnBFaFrFUgdnOnQGjybQUvDzY8TOy3XX8WN6VN3cOCHrXb1/3BWB5bj1t2+mOfhr0LHVsWJup8ejcNsEeylJoCELyYqvIsCDqRYbWN1lBzzfZiLhaminESQGmxkIBpqCMm8o42oCyPF8UYIIgCE2GcCeEopI7E6HjgJ/gZk1NTQrf7vLXTccNKlGd0VqboXkdY8cBE2KR12kLEbUJQkxp9oI2NR69g9GyNHOymsbKbSANcRsQkglpzggtg1irwMLOnVEGBP8gcWnbHxiQtZHS+gxu23QkPoIH8r8WdaSoxkety8VBx/6hOmPH2iwMpNAQBEGHiIsMCwWYGjvPAjt2iUGY2THG1dLMCCMFmLYhEyUFmI28GQWn46QowARBEJo/iXIcMN5P5fO//WqRo7suw3jsNcqbiaK1mZUARERtgtAiSbSgTZe4ZXRCdC3N9NCuolH2DWolxW1ANwuN6OfNWCBuAy0Lac4IzZq4qcD0MFSB6Qdc9s3YyuVtvwXg3s2j2VzXitAl+y7abSkAIH+v1TizNlMGujhZm1kN7A6RQkMQkp9YFxlGCjA9ksfSTI9wFWDq/fgqwMJdni8IgiAkJ5YTP0nuOGDFlPn+3wGDM1bRvrDK5ruiZG2mJ2pzYG2mJtK8TluIqE0QYkIyC9oSa2mmtn5WtpFYmpmNDepVltpaSW9+jqjkzSgYjpniNiCokOaMIJgQlgqsHRYqMNAOAGm4uK/De6S5vMwo68vnZf0wUoGtX+CfkFzTuhy3u976+yXK2kwH9Q+ARBUaZg0aQRCsSbYiQ40tSzNtkREzSzP1fvNUgDlFFGCCIAhNgCbiOGCVP7N2Yza/VHfH7fJx/P7qiTO1QMLIccCqSaPzw0BvcjJBeZ0iahOExGP276TZCNrUePQOxtPSTKmT9ERs2n3tykwDIsibEbcBwQnSnBGaHQlXgdmmUeV8adsv6Zu5kR112dy7+VjAeCD/6fM+ZHl9bE91M3zUKoOrkt/aTE08Cw0zpNAQBHNiXWSYEVPfZIWoWppFQwFmFHgcRQWYmjDzZhqW54sCTBAEocmSnI4DWrQq6TSdc418vt6fOzOu919YW5tpmzJm1mYaPAb7ClpRm0JJYBsFazNZPSMIzQwHgjanRE3Q5tG8ThpLMzXaVTX2V2ACkjcjxAVpzggth6RUgUHfjC1c2tb/5f615UR21OcGzuhPytXVpNGjNAOA4uHLCV0SGiNrM/W+lYdySWCbpIWGrJ4RhARhUWRE4ptsWmQYofcs8xjsB2Gw6jAmlmYxVoBJ3owgCIIQBeKTOwP6jRl9x4FPFvUAYGT+H2Sk23EcgJhamymYZN+Fm9epxfI3layeEYSYEW9Bm3rOJHkEbUbH9SzN1Dh8Bus2W5TaSTteaMeTrMZNAeI2ICQMac4IzZ7kVIH5D6RSx30dXg7YmQ3gi7IBgevUg4hWIZZK5TL/IFrawe5aeb0BzqG1mRlaazMdkqnQMEMKDUHQJ1mKDDV27BGByHyTbVuaNV8FmOTNCIIgNE/sTgBpHQfMiF3ujHpSzdzSTGHu723YVF9Ajqua0cM2qc5ox2rtsQRYm0VB1GaGLdW9iNoEITmIoaBNF7t5wY4EbVpLM/Uz00xQrH1tVEsZ5XEaXaceQ4xWZxKTvBlxGxCskOaMIFgQy9yZ89pMp2/mWnbW5XDv5vH47cysJ9nmfrYXLp+PkkwXPfoZ/eLXFhoRTBLqWZtFUGhEk1gUGtKgEYQYEKUiI26+ySEoViZKYRELSzMFUYAJgiAICUSZMJqpf1qZcIq340DwmGduaebHxfSt/QE4fq+VGE8AJsDazEZepxo7ojat44CI2gQh/iRS0GaE+vlh2vQtCWxjltEJ9gVtTlfQQKjbgJ7zgA2nAYhgnBK3AcE50pwRmiXhqsDimTvTOW0TV7R9H4CHtp7MjvpW2A243LKugF6V/n+++x39B+FZm+nhwNpMIYxCQ/2DwEj5nshCQxCEYJKxyFBjyzdZwa5vskfnfAOxsjSLswJMTbQUYIIgCEKTJZkdBxrR1kfmjgNT//DXGkcX/Q7YFQpEYG0Wh7zOcBFRmyA0ASIQtBlZmunRPCzN9NDWUEb7NhC3ASFOSHNGaFlYqMAUoq4CK0BTaPi4vehZstw1zK3Yiym7Rqouthdwmbm2CABXz/U2v5x2oNNTgelcokeMCg2rHxBGSKEhCElODIoMx77JJraLyW1pFkMFWCzyZmR5viAIQosllo4D+hNs1hNtX8zvQKUvnc4pOxjSz6M6EyNrMz2sVOclgW0i8jpF1CYIUaNZCdoUwsrojJelmbopr0U7PqjrJVXjX+02UKBzG3EbEOKANGeEFkFcVGBGAZc6ljJH5c7j0Nz51HhTuXvTJfjtzNSDinXA5W8z/ZOZK3NqadXGqDCIkbWZQhw9lKXQEITE0KSLDCuMnmEe7YFILc2cYGZppndNcirAZHm+IAhC0yDhuTOOMMud0Z+gq6xKZfYuf910wuBV2LM2056PsbWZDrHK6xRRmyAkMckkaLPK6FSTcEszbcNeOy5o3QZMiFfejCCokOaMINggrE64gYVMlquCW4oeAeDlHadRUttdddb+JNtvczvTocZLjdvFweOajrWZFBqC0MKIU5GhS0lga8c32aN5bVhkhGNpFk6opR5adVjTUIAJgiAITZhE5M5khezoncSu48BnK/sAcExX68aSn1oSZm0mojZBaJIku6BNl0gEbUlnaWZkealeZZNEeTPiNiCokOaM0GxJVhXYeW1eo33aVtbVdODFHacHjmbROFDYU4GBm3ZbWwNQsPcqq68eIEJrsyh4KEeCFBqCEF/iXWREQux8k81wYmkWjgLMyMYMmpICTJbnC4IgNEEcThxFnDtTgIGoQBEfaFfN6DVm9B0H/vtTNwAGZZTQpb1dxwEIy9rMY3CZ1hrIQV6nGhG1CULy4bhm0iPGgjbTJm9JYOtU0ObROR81SzMrzNwGtOODQa1kVkKJ24AQR6Q5I7Q8LHJnYqYCK4C2Odu4qM1kAB7fdi01vnTNRdoQS/OAy/UL/YPt2tbluN1GA1kUrc300PMgVVMS2FqowKTQEITkIhZFhhbtv2F1A9bs375C1HyT1Xi0B9SWZkbN7GhZmlkhCjBBEAQhtkQ6URR27owtjFwGjMRsfjZuzWRhVU8ATth/Dc4cB5TXNq3N1HjMT1sRtvCEKKnqBUGIDc1J0ObY0kx7zEjQZtdtwGpc0NZKecHHChC3ASHhSHNGECyIZu7MlflPk+Ou4JfKfZheNppQFRg4Cbj8YVofcuu97Ex1M/yolUTH2kybrRDAY7CvJUwPZbskYvWMNGgEwQY2igxbDVQdYuabrODRvLZdZEB0LM0URAEmCIIgNA9imzujdRzA4HUjn6/tC8CxvewK8GJgbRalvE4je9eIRG16NZOI2gTBkkSsmomqoM0Ku88tw4a11tIsjGerJVq3AT3nAb25Nw3iNiAkCGnOCC2GuKjAtKge7l1dazgl90MAHt56E6AeDLXqZ3vWZnXVqRSXZgNQPHy5zS+lHQgr6J62irGtZnNV21d4vNO/eaPrVXzU/QymFh/DK10uYGLeXRyc+S0uvAnxUE706hlBaGnEY9WMFqeN14iKDIWksTQTBZggCIKQOCwnhAITSlo76Fg6DgTnzuTpXOTMceCTRcUAHNpqKdmZdh0HIGJrMw+hRJjXqUYtYNGSiNUz0qARBAtiuGomLEFbSWAbFUszsLY003vt1N3FSNBmdI39bGcKCdtxQNwGhHCR5owgqIhl7szZ6W+S6qpnTvkIFlYONbi784DLsj+6A7C9/U6Tb63gLyw6p23mjIJvebzTZL7tdQvTe97CI50mcUXhBxzV6jv2y/6NfpnL6JXxNwfk/MTpBe/xQtFlfJJzPMVunXybMDyU9X4YGFmbaZHVM4KQZISxaiZS32RTSgJboyJDXWx4DPYB55Zm6BwPl+RSgCmEKMAsEAWYIAhCEyaWuTMWjgOhROY4sGhpAWvr25LpqmXM/huJmbWZ0erbpiBqi8LqGUFoSUQlnzMCG2g7xFTQFpalmVkTxqmlmRa124CCdtzIatwUmHwVNcWNu+I2IMQCac4IzRqnKjCFaKvAcinjpPSPAHi9/PzAUUUFph0s1IOMdcDlD1P3IsXnY12Gi70Gb1Bd0ziwuahh36x13NRuJlOLn+Krno9yR/tPOKrVEgpTd1HtTWVJ5R584DmSB7dcwrXr/8klax/igrXPc8vG+3lz59mUeXPZI2UFr2efS2fWReyhrEZWzwhC8hCPIkNLJMrOqPgmq4nI0ixWoZZ61zhUgIWJMg4q42IIyvJ8UYAJgiA0GxLtOGBOeI4D4GL65v4AHLeX3e8Xof2O3qpcK4ugEnu3titqc4yI2gQhoTixgY6aoE1L2Bmd2mNqSzMIfqZGIwtZWSmpFbHpuQ0YrMIUtwEhCZDmjNAyibMK7OTsj8hxVbCivhc/VB2o8wlZhDvptnNzLr3LUwDYZ9SfKAOcCx+DMjfyf+2+45uer/BWt7e4oM18emVsp87nYn5FD57cOoZz1lzL/ise48w1/2Ti5kt4fecYvto9gu8rhjG3Yhif7jqB+7b8k6P/nsHSmj1p597GFfXPNn4Bux7KJYGthUoj0YWG+CgLLZlY/Tce7VUzUfdN1rU0M+vShGNpFq4CDKKiAJO8GUEQBCHOxC53BsJxHPj0d//viqMKf8eFkYgvwdZmOkQzr1NEbYIQHZqVoC1qGZ3q1YVqtM/QSCzNrDByG7CgHaFjUTTzZiwQt4GWjTRnhBZFTFRgXSweou3g6NrpALxVcxbBWTNawlGBgWtVZwBqum5gQOYWbm73I1/3fIe3u/+X89r8Qoe0csrq0/m0tD83bDiOg1bczHlrL+aFHSNZWNmbGp9ZM8g/6u6sb8NdO+4G4Ji0z2jFrmZZaAiCoE+TKTJKAtuo+CYnwtJMFGCCIAhC/Emq3BllHAvJnYnMceCrBe3Z7cukvbuUAwbsoKlbmxmJ2uwEhJsiojZBMKVFCdoiyuiE6FmaKZi5DWgb9VoRtA1ilTcjbgOCCdKcEYQwMOyU66jAcr1l7O39DYBv6wKjbVChoSULZwGXKeye3pFzvqnnutfreLf7NM5v8xsd08rZXZ/Gp6V9uGLdWEasvJhbNh3LF2X9KPOqByhFqaCtIkILjV9KB7Kprj3prlqK1WtgpdAQhCZPshUZRiSHbzJE19LMisQqwBREASYIgtACSUTujCnhOw7U1KQws3QvAE4c9LfNd8XQ2izCvE67JErUJnWT0BJpeYK2eFmaad0GjJ79eoI2CGrwK24DhhlnKoobd8VtQIgV0pwRmj1xU4EVa14HJsSGVs4nBS8lvu5szO6kU3Aog4T9gEsXPgZmbmJCu7l81fNdHqv/gXHzfLTbBdUpbj7b1ZOr1o9ixMqzuWXTkcwuL6bWl0ro4KhFrQLTZ2O9f2avY7XOkhgjb1IpNAShSZOIIiOuvskhxNvSLLkVYJZ5M4IgCEKzI7lzZyBcx4Gpy/0/asZ2/tPk3uFYmymrbVV4DPadYCFqU5NsojZBaI44tjOzSdIL2tQkjaWZntuAdl9vrk1FAaFjj7gNCHFGmjNCyyXaKjA1qod523r/bODfqT1tfIpxwGUq9QzPXssdRbOY1et13un+MRe2+YXOabup8KbyW7dMHj7ZzaMXZHDzxkOZubuYGp9xYeLH7mDYOPqm4AWgTvluHpO3mXmW6mBUaKiRQkMQok9TKTLURNU32aN3QbiWZuHShBVgsjxfEARBMCDi3JmoOg6kMmVuV+p9Lvqnr6W4cznRszbToD6kXj3jNK/TArsCFhG1CULkhPXfc1MStCnYzui0wq6lmd1rzARteucdCNrEbUBIINKcEVoc8c6dqclLByCDagef4i8+slzVHJm7lAc6fMyc3k8zuev7nNH6V4pSK9hdn8a0Xb24Zv2RHLTiDB72DmN+HzfL21aTmV2jupdW2a1s9RQL6ipC30O5tWsHADt8bYK/sl1rs5LA1kK1YWRtpkUKDUGIjKZUZDQNS7NoK8CMLM2STwEmy/MFQRCaNkmZOxOEc8cBLVt3ZDC/yv8748Rhq21+wQitzRQ8OsdimNcpojZBiB/NQtAWtqWZ3jySkWuLmduAEUarYtTouQ04RNwGhAQhzRlBCBNbuTOFsDajGwBD6+fTU1EjF6AbcJlGLYMyl3FZm095pctT/ND7DiZ1foPj8n8hP6WKbXXZvO/Zh8vWHcdBKy/kpo1H8fXuHlT7Mlkwswdt67xUuN2MOG4pxvkHdq3NQsl0VdIhbRMAm7wdwvNQ1iEpCg2dyWarH1jSoBGaO/EqMpwSlm+yutjwGOwD8bM0U7BSgGmvSYwCTBAEQWjBxDN3ph02cmfAzHEgeBvK56v7AnBsz+Um94+BtZlCnPI6tVhlVoioTRCMEUFbAMNSycrSzEggHA7KykitiM2h20CBzq3VY1Jx4664DQixRJozQosgYbkzwOLcffk+60DSqeXDlJO5K/dOzk7/D8dmT2V8wdtc1uY57utwJx93v4gFfc7i7e53cm27jzkgZzkZ7jrW1rThtR0Hcfaaczls5ZXctXk0c8qLqfWlBH2Oz+em87ZWABQNsjNIOA1jq6R/5m+kuerYUteOTTs7NJ7y2Hi7HslUaBggSjChOZOoIkOLuuEatyJDTUiREQ9LMycKMIXEKcBCludbKMBkeb4gCELTJ+65M2GNWeqxUW8FTegY+/GCYgBG5Cwlv1UNcbc2U4hDXqeZqM2WcEZEbYJgSbMVtJliZTsQbUszK8JwGwDrrDMdlxxxGxCijTRnhJZNPHJnXC7uaPcvFmTuRybVnOb+gNsy7+ehwpu5s/29XNtuEifkf0bfzJWkuerYWdeKGWX7c/fm8Yz9+3aOWvVPHto6lkWVPfHixkwFtmWJf1JzQ+EuCGTD+NGzNgNjazNtoeG3Nts3awkASyoHAYEf2E48lJtgoWGFFBpCcyWWRYZVQ9UxJRbno+abDLGzNHOiAGtccdnwsgBRgAmCIAhJie3cGT0K0HUcCEZdH1mtPE3l95X5rKjtQLqrnuOHr7PxJSAiazOPwb4Ws6w8EbUJQsJocYI2BY/BfgNGgrZYWJqZuQ3oNeaj5DagwknejCA4QZozQosk5rkzmof5pg4dObfjG1yW9QIveS9mRu1R/FR1AF+WjeIDzyk8te1yrlz3b45c+RoHrXyN6zfcwHueoymp7Yy/CWIv4PK7T/uR5fWyLdXNkCNWE761mT6H5c4CYF7F/qEnPTpv0Fqb6SGFhiAkhGQvMtSonwmx9002wszSTItdSzMrjBRgFlgpwHQQBZggCILghKg7Dmgnxwrwiw4MycJ4Is5YzKYwdcPeAJzQNw7WZno4zeuMESJqEwRzrP67tT1XoPPvKGkFbVHL6DSzNAsHrfWz9hzor5bRCNr00I5Bxfa+kVHejLgNCE6Q5owgRIAtn37lIe9yMSf1EB733cj1lU9wYc2rXLdzEhM338tz269lVvkYNtQVA9maG9gPuKyqSKdXaQYAe4xYavNPAXbU3u1S/mZw1mIAvt59pP+gR+fCOBYasnpGEMIjrCLD5r+NaBUZ6sasLkllaea0CaNurmsnkOKgAJO8GUEQBMEAu3bQVoTlOKDgSGSglzujfR3arPlokV8QcmTB76Sn15M01mYO8zrtiNq0NVMiRW1SNwnNCcP/3sOYU7Ci6QjawF6dFImITc9tQLsPjastNdhxG1AhbgNCrJHmjCBoiEXuTBAFWKjAIHQCzn7AZdXybgDs6qgdYfWU3kahbKGFxqhW3wKwpHIgW+raJ0WhoUUKDUGIHEcrxaK4asYOUfNNTkpLM/VWIcoKMJtI3owgCIIQQpgTS04dB8LPndG6CpiJHFL54Ze2bKovoJWriqOHbbL5OU6szTSTlx6DfSeEa01kQTxFbYLQlIil00DzFbSpG912Lc2saigzSzO982lExW1A8maEOCHNGaHFYDhZpBQaSakCcx5w+cOU/qT4fKzOhD0GbCFa1mYn5n8BwBdlRwef8BjsO6GJFBpibyY0B5K1yEi8b3Jjvlbo8XhYmmlVX3pqMIhIAVbcuCsKMEEQBMEO0ZpIcuQ4oKUAg9wZ7eSbfccBHy4+3+r/XXLi3mZ/SKfWZjq/J4xEbTHO6zRbPWOFiNqElk7U7MyiQNMVtEFklmZG9ZEa7XPfRlNGjeTNCEmANGcEIUyipgILKTS0OAu43LIhn97lKQDse/RvBh+qh7Hqu1/GCvpnLqfGm8anu0Zb3yqBhUYsV89YIYWGkOw02yLDiLB8k7UKMDXRtjTTO2f2nBcFmCAIgpA8JI/jgJ61mbXjwMe/+sUgY9v9htvtJTrWZhEQ5bxOM7Q1k4jaBCEKtBhBmxHaxnW8LM3UbgNq8ghxG9AbUyRvRkgwLaY58+CDD+JyuXC5XPz000+J/jpCEpBwFZiRyrlBBRZ+wKV7VScAarpv0JwJz9rs1ILPAPh69wg89a0bL7WyNlOIY6GhJd6FhjRohGZHshUZetjxTY4I9XMympZmVgowhcQpwARBaFlIzSRAU8ydgXAcB76c14FSXzZF7lIOHrjd5udE2dosgrxOO6I2NTFbPSOZnUIzI5aCNqdi0bgI2hSimtEZS0szvfPqZ7zeykoVBTrHJG9GSCAtojnz22+/MXHiRHJyLGU3ggCEpwILemgXq07oPeQL7H6T8AIuF0zfG4CV2V7ad/MQibVZnns3x+X5/0I+KB0XOBpFD+WS0ENNudAQhGSlKRcZQSiN3BKb1ytFhkd1zKO9SLEg0VYeat9kiK2lWXIqwCRvRhBaDlIzCZYkOncmyo4DtXVuvtrRH4CT9zVbRRpDazOFMPI61UQjrzNsUZsBImoTmiJh10w2BW1akkLQ1iQszbTn0jDPFrOgHeaCNnEbEOJIs2/O1NbWct555zFo0CBOPPHERH8dIcHEUgVmm7BUYM4CLlf90Y6eleBzuRh+fPASS32MVtHAqQXfkO2uYll1MXMr+hK2h7IUGoKQMJp6kWHasA23yDBckR9LSzMFUYAJgpBcSM0kmNGcHQf++2cfAI7t+BvgI67WZh77lwYRQV6niNoEIXEkRNDmFI/BPmCc0QmxtzQzyhTT1ksOMRIDqBC3ASHWNPvmzL///W9+//13XnnlFVJSUhL9dYRkJZEqsAIMAi612A+4BMhZXwRA6h7rNGeMlN+hCoc0yji74HMAXtsxDrA5wHvsXRZCogsNsTcThEZsFtnNo8gwI5qWZqIAEwQhOZGaSQgXI8cB2xSr9hPgODD1p85U+dLokbqFgX2MJh212LU201mV69G5LEp5nSJqE4TIaJaCtpLA1krQFlFGp56lmbEAODK3AfW+WtCmHgM082pmbgNaiht3DW04EbcBIbo06+bMokWL+Pe//83EiRPZa6+9Ev11hCQkKVRghmhVYOAk4PL3b/xL9Fe2qiWvbTnhWJuNzZtP+zQPW+pa8/muEZr3Oyg0jEi2QkMQmiFRtTOLYZGhxnaRYUVYvsmo9p1YmoVDGvq5M2pLM1GACYIQW6RmEvSI1HFAmbhylDujJQ6OA2XlaXxb1heAU4aYrSI1szbT1lnayUrVYYUY5HXaRURtghBK2P8dNldBW1gYzTlBeJZmegI1o8xO5XqL2qmA0LHFwG1AD3EbEGJBs23OVFdXc+7/s/fecXJd9fn/M9u0qzqqlmRJXktylS3LuOBeAdt0DMYY00uAhG8SIKEkJCTEicMvJJAQSAIkMQY72BQDpti4927LclfzWtXqo77aMvP7Y+Zoz5w5/Z5z753Zz/v18mvm3nvmzl15d+Y89/Oc5/OBD2DJkiX4/Oc/7/S6Xbt21f1HjE6i9p2ZDkcXmHuDy2UPHYo5A2UMFgo46+3PWbxPvbOhHf345NTfAQB+uOMSDB58v4AZyjJIaBBEMIK6vySEFBl84VWK7LMhSm6yrpeMbqm+rwOMR9Z7hhHGAabD6AAjCKLlIM1EOJPLxAHxZpxb4sAvV1aLM2+Z93xtj2u0GVDvJDdQUjy3oa9xF5naCCI+eTS08QQ3tJXEHbaRZrpUAR+dpCrGiJ/zjoY2ShsgckTLFmf++q//GitWrMD//u//Oi3Nv/rqqzFp0qSD/82d69aYnMg/mfedMTqaVV8qbg0ugQKmvjoVADB+UZ9wLnO02VsmPoHDurZg+9B4/N+Oc2CVoVxSPLdBvETkTGhQgYYY7TSbyGCUFM8BZBNpZuMAE5+Hc4Al6jfDvh8VN+ZoeT5BNB+kmQhbMk0cKMLCjCB+X9onDvzswXkYrhRwwphXcNisfebrA+AdbRayXyeZ2ggiCLHjzGIZ2vh7JAcJaWjzijQTH0NGmokmNj5tgEdiaJNBaQNETmjJ4sxDDz2Er3/96/jyl7+M4447zum1X/rSl7Bz586D/61dS3+Io4a0XGAyihC+T3gXmH+Dy5cfrLrAXi7uR1e37ItRfrOxA/345NTbAAD/s/312FfpVryDRcPLnAkN8eaxtdAgiCYjbyJDh1FkyAiWm8xHmvEiwyfSzKVQo3OAMfLtACMIorkhzUSEwqfvjFPigBO6xAEWJ8rvAzZt68aj/dW5yGWnv6w5tynaTHfTUkPJbthB+hp3SQ0tAmRqI4hGsvi9az1DG2A2tNliMrRBOM7v00RCj4O8SMN/5/SOPI3Zb4YgRFquODM0NIQPfvCDWLx4Mb74xS86v37MmDGYOHFi3X9Ea5N/Fxjg2+DyoVvmY/pgGfvb2nD2O55vON5I9UvzrZOewryubdg6NB7/VxLvwDKh4RhtZkNf467QQkOG9U1mEhpEk5BHkdGcucmAW6SZK83tAKPl+QTRvJBmImyI2XfGCqvEAdnvnipxQEZ17K/6qvOYt85/sbbfJtrM5fs/YL9OA6rEAZFoq2cSQLqJyBtNtWpGhmhoY3gb2iAMEiPNVISKNOuEPLrSLsbSveA/Qqh+M5Q2QIi0XHFmz549WLFiBZYuXYquri4UCoWD//3gBz8AAJx++ukoFAr4xS9+ke3FErknet8ZJ9wbXFbQhkO3VgXL9BPFLxK5E7wTB/DJKXcDAP57+wXYX2lH4gzlnAkNKycYCQ2iSbH5Xct61QxfUFWJDK0DzPSZkTg3GUg30gzC8YQOMJ7ekaeJHGAEQbQUpJmIROQqcQAIkTjws4d6AQCnj12OqcUBiwth2MwH0unXmcbqmdimNoJIk9j9OWWEXjVTh42hTRZl5gS/OpAnrUgzfh+g1UoqKG2AyBnmWUqTMWbMGHz0ox+VHrv33nuxYsUKvPWtb8X06dPR29ub7sURueEtlYr+i/hOWN2cX4hV7l+YQPXLQPal2IPadxz/5SJ+6XWi+mXGHvVseOII4M1PYe3U3Si0D6MyDNT/6Q/VtquPb5/0NOZ07cCWofG4oXSG4exMaPSMXCq79L1oXBW0FSM/+zRUhYaiLwLWFaRfjCpWYoHSybAMxxvjFSrnSm5CXoBGR+BFkArStxxhdpHfXCiQS4LInCxERrRVM9Fyk30izWxJwQGm+lyV4NxvRgF9thFEc0GaifDh5hVhbq7PxVqz61vUS8Xao3GF/kFBhep3LvvuZt+pgxjRPiOsWDsBzw3MwaKudXjH6Wvw/d+pjGDs9eyR3w/IJxuKG4cljPxc/HMb+lBvBBRYi7kHV8quwkLlzcXlOEprPFw65ZiDhg0lMs2kwaSbSDMRaRDcPJlRDHQQQ1uJ21cSB4Xo0emKTgPJkgcYbDWlkDYgS6uhtAEiR7Rccaanpwff//73pcc+9KEPYcWKFfjSl76E0047LeUrI5qCW+F1g3IBVjau5JhTGbmpOAsjX4yywgz7sijJzs4m8/yXIi8I+MdG7vvVUTjlkiews70Np1+yEg/+Wn2TtLMwhE9MvR8A8L1t56K/0iY5L79t4VAoIbHQ2Lx67kHHd6pCQ0WCAg1BxKQZREbiVTO2lBTPrVEt0U+rqWVOHWCebmmCIPIFaSYiFku2v4ClU47BYjyDZTgeR+IlpUljxvy1I9/5vRj5rj8E9atFpsO8Er8OVqAZRPW7lv8e57fZ82qx5tfrFmHR/HV4+1Ev1YozzMTGzlNvahv57rc0U9Q52RRsQfXnVZnaNqJxhVGTmdpsoAINEZPgSQOWNK+hzbVHp4hr2oD4nG2rTGydMH62FiX7YvSbIQgPWi7WjCB8iNp3ptfwoqLt2cWbdXYNLgcHOnH4jrEAgHmni8WKekf4ZZOexuzOXdg0OAE37jxFGKtraumQoaxapi9zdTj2mYjSe4bizYgmYdSIDDE3WfxMUd282S9uqCLNxAz5pJFmDNONG3KAEQRBEE1I7Ya86cZU8L4zPfyTiZDfmJOtSlXz08erWuKCic9jbLfL97xpXpBNv85UI6FVULwZ0YQ4/36OakMbEC/STPbZLe5zMLSl0W+GFamp3wzhABVnCMJA4r4zPIn7zgA+DS5Ly+YDADbNKAEoo/ELcQjj2g7gU1MfAAD85/azMFAB7JagktDgsZnIUYGGyALnwkwriAzjZ5DMAcYjRpr5NrKUPWfbukgzcoARBEEQ2aK8kZRG3xlZYUbbb431neG/P8XVqeLjCI+/UMTa4anoKQzgTadtUL0J1M2v2VxBnF9oKEmeq4wmLWZqM+km0kxEDBL9XiXoz9k0hrY6QvToTIopbUCkR30IoH4zRC4ZVcWZa665BpVKhZbnE3ZY5uZqK+i6BpcysVGEwgVm0+BSnVJ47y+ORU+5jC2dBSw5d410zEenPImpHfvx8sAU/Ky0iDvCXOQyoaERGyXJc9MEYJQIDYIITVZxZlFFhgzb3GSekrjDlJsMqCPNRJI4wFRNLYHcOsAUkAOMIFoL0kxEGsRNHADMiQPisQ4ABfz61epc5u3HsiWi7Dte/M7nH21vQLJooBomU5sugojR17irWUxtNlCBhghJXpIGTKRmaGME7dEp4hppJjMg8wUaWdoAu2c2sf5lRcgL+pQ2QOSMUVWcIQieVFxgKsQvgyLkETUNiC4wmcBoLNLs3tWDhbvGAACOPO/F2t4RoTGjYw8+OPkpAMC/bDkLQ2iHm9AAlDc8bYTGJskYCSGEhnjzmIQG0ew0u8jgqRMZqeQmi6QVaaZbTSMWZiSRZiLUb4YgCIJICZsbTeknDqhMDS6JA1V+/nR1jnLJ1GfQ0TFs9ZoquvmBwdTGU7J8O5NhJSHBTW2UOkA0AWnFmeXG0Fbi9pXEQZafWVEizcT+Mgxdgd2A+B1CaQNEjqDiDEHUSK3vTOJosx6ohYbKBQb0vzgPALBzdqO1/NNTH0NP2xCe2DcLd+wRb5omzFBmlDSn4GETiT7L8RJ0q2dsIKFBNAuJCjMqUhYZfKFVSp/VaUYoKZ5LUa0EHER6kWaA3gEmgRxgBEEQRJ4I3XfGOXEA0CcOmCPNGHc+MR1byxMwuW0vXn/KZs1Fqm5EyhIHRAL065TBGVtsTG2iZgq2eoZSB4gcEjzOzJLWNLQB+Yg0c0gbAPSGNgnM0EZpA0RsqDhDEBYE7TvDE6TBpR33//I4dFQqWD8GOOKkVw/uX9i1He+YVF1N8/UtZ6Fx6T647QQZyowUhIaIz+qZtOLNqEBDxET7O5iRyLAiam6yCdtVMqGbWrJH8aaSBnKAEQRBECmQet8ZnkSJA6Zos8YYnXK5Db/eshgAcNlipvVM0Wb8MRPp9OsMgffqGRWUOkBkRJSkgVYytBmxjTQTPwdDRJrxyAxtQMN9M1XaAKC+B8d9ByU2tFHaAOEIFWcIQodj3xmnBpcyIja43LJxIo7Y0w4AOPENz9b2DuFz0x9Fe6GCW3cvwNP90w/uH3mMlKFsQ5/n69JGIzSoQEPEIor7KwWRkX1uMvuskokMWaQZj+8KGh7XppYGLFbM8JADjCAIgsgL8RMHALdos+rYny47EgDw5hnPoK2t7PBebJ4gM7kF7NfJjCqGfp3SuRXcVs+QqY1odtJKGpDRNIa2krgjSY/O0JFmbFsVaabRT9r7azV6LS/LA0obIGyh4gwxqgntAmNYVdplLjBrfBpcAuWVcwAA/fOq39ivHbsB545fi8FKAd/c8lrN+yXIUC5Jnrei0EjgBCMIV0a9yBCxyk22JXSkmaqpJYTjlg4wmVuY/z4x9JvxhhxgBEEQRI3U+s7wOCUOiDfr9P05RW599BCUyuMwvW0XLjhJtzRX12OBTxyQkaBfp4w+zTFYuO7TgkxtRIok/n1x0Pita2gTse3RmRRZ3xkxAtoSXdoAh+y7iBnaKG2AiAkVZwgiLXq558FcYLKCjJqHbl6MQqWC1T3A/KM248+mPw4AuKF0LF4ZLHIjZQ4Itu2YoayjlYSGBhIaREgSF2ZaQWR45SZL+mPVocqMV42F5riITFx4OsAALwcY9ZshCIIgohKi70ywxAFAnTjAjqkTB4aG2vHbbdX5zWVLlrO9tUdTtJnppmSkfp08uh4UNZrZ1Ea6iQhJrDizTA1tKrwMbaoenUB6kWYyExuPYGiToes3IzG0SVNxdFDaAJEAKs4QBIfuplOqfWeK0LjAVO4Bc4PLDasnY/6+6p/9R058GIu6t2HPcCf+Y9uJUAsN12gzDlm0WcnyVC0mNKhAQ4QgmvurmUUGo2T7JqpIM4aqKBM60ozfx0jHAcYgBxhBEARhS6y+M9kkDsi261ey/vSZqm546yHLUIDNTTXRSZ5Bv86+xl2qfp2pESHejCBsiJI0kIBUDW2qSDMeb0MbEDfSTNwWzWyidpLopyLMaQMKqN8MkQVUnCEIEzH7zsi+HKI1uKxud/Ydis7BCi5YsQ0A8N3ti7Fj2OaGYKAMZUZOhYb1zWdyghE5pSVEhozouclAo8iIEWnGj+GPp+cAC9VvhiAIgiCi0Ms9d04cmCjZ55448JtHZmFXpQcz20s450TdMn9TvzoVAft1epraeKKZ2lSQqY2ISLSkAU9Dm0kzNZehLXakmewzWvz8lpmWJRShN7T1Ol2YE5Q2QLhAxRli1JNp3xmRousLXBpcVnn05sV402MVFPcAmyo9+OGOYxUjA2UoZyA0kqyekeEkNDTY3jQnsUHIyLvIEHEWGTx9nq/zyk0G6j/nTCIjVFNLBl9MB5rJAUbL8wmCIEYvuek7U4SQOMAeXRIHGpMHBgbacev24wAAl72G/bAho80UlCTPXUxtElSmNl4z+RDT1EYFGsKXNJMGYtA8hrYQkWYM2We1uFKGPTokDTAs9BL1myGyhIozBBGZug/5Xu6AjQvMu8GlOtpsz4qxuPShMgDgviXjcaDSgagZyjJKtcdIQiMpWcebEYRImoUZXxKvmtE5wEwio8S9hn8uxRRppiJ0pJkoQsgBRhAEQTQ5jn1n8pM4IB6rRZs9dzQA4G2zlgHRos0UN0NdTW2MPs/XIV+mNluoQEPw2P4+hNLkmaya4enzfJ3R0KZC16PTJ9KsQ/GcbesizYD6+2S1h6Lm7RjUb4bIGVScIQiBJC4w66gYscGlKDYSN7gU94/s+8OpS9E9AKyaCTx9dsniYkU8hYbPW/H0uQ0PvXpGCTnBiMik/bvQeiLDNzc5dqQZ/0gOMIIgCCL/xLrBlNfEgZsfmo29lTGY074Npx+/TTNSVpSxQdOvk1EynIIZWAyJA2munknb1Ea6iXAhpqHNtT+nSDBDm4izoS1Jj84k2ESaseca/SS7n0b9ZogcQ8UZgrDB0gXGMLrAeBI3uFTdCGz8Yju8aw/eXazekPvhBW1YMb6MWYft0Jzf9yalRmiwG6glwykSCo2kJM5RDgAJDSKI+ytwnJkr0UQGo2R7JbIoRtfc5NhNLYFmcYARBEEQxEEC3ZBKJ3HAHGnG2N/fgdtK1Wizd5/MvkdtEgfYcdV8IWC/Thl9mmMGmtHUBpBuItJPGrC5N5CZoU33OaFclefao5Pf5j/rQqYNyD6nHUxtlmkDVrGaDlDaAOEKFWcIAin3nenlnjs3uFRh1+Dyc9MfQUehgjt3z0X/jDaUCwWc/vZnakdDRJuR0KiDhAaRgLQLM7bkRmR45ybzDjBxv4okIqM1HWC0PJ8gCIII3XdGSvDEAbZPfBQNb9Xtnz1fnee8bfYzcI82Y9ti4oBIOv06M189o4NioYkE5DFpwJXEhjYTJcm+RD06dQVol0gzmVaSrXTsQONnuWBok2Ghl2RQ2gCRJlScIYgUCNfgkneBuTW4PLVnAy4YvwZDlQK+vuUU9Lwys3r4CNsbdzYZyiL5EhquN42dhAYVaIiApJ2XDDSpyGB45yYD9pFmLsv0XZtakgOMIAiCaBFC953hCZY4ANhGm/3iwbnor3Ti8I7NOPmYknBUXF0r3rg0oYleLUmeu/Tr7LN4ewWZmNoCQJppdEKGthqmHp08JdMb+PTo9I2EBhp7c6K2rTIhK/RTEfJeZPx3B6UNEDmEijMEISFa3xlVg0sZARtcFlDBn894FABwQ+kY9A1OwlO/WwwAWDluGNPn7NS8hyzyxwbTDVGkLjREbIRG7HgzKtAQ0RgNIqPEjeWfN6DKTXaJNBPHyTA1tWSoIs0AcoARBEEQo5GwiQPM1CZD1p9THW22Z18H7ti5CADw7lPEaDMothnsZiU/n7Ds18loxdUzZGojAhG1MKOgNQ1taffoVCEr0oiRZhaF9aLheG/jLkobILKEijMEYUuIvjMqErnAAJML7M0Tl2NR9xbsGe7Ed7a9BgCw/OlDcHh/BeVCAWe+wyfaTNUAziPazIaUVs+kHW/mAomN0UEW7q/WFRk2iHGNYkEmZKSZygHGIlfIAUYQBEHkm1hx0IwwiQMMPnGAP6jSTvKYnZ+9dDQA4B1zlkEfbaa7YckSB2Ro+nUySrVH0QWfo9UzeYk3I800OoieNGDZn1PG6DO08SSNNOPH8McdkgYAfdoAh813DjO02UBpA4QPVJwhiBq57juToMHlmMIQ/nTaIwCA725fgh3DI68bv6YabdZ25Brdj8BhylAWETJLZdFmJcVbZbR6RkZenGAAiY1WJ8/ur6giQ4YoMhgykaHENjeZbTNUkWY2hRqXSDPAWmgUDcd7G3eRA4wgCIKITbTEAZ4giQOAPHFAtQ2w7+ufPzAXByodWNj5Kk44ghVSMurXydD16WxmU5sBKtAQQMD/vwGSBprS0FaHraENSCfSzGb1DH9frPZQlJxaLOQbDG0MMW2AQWkDRCyoOEMQKZHIBabE3ODyA5OXYVbnHmwYHI8f7jixdqz65fb0LccBAFaNH8LUWbs17+OboaygJNknRpvJyJnQoEaXREiycH/Z4iq+VSKjDl5k9FmcVPfZUDK9WJebDNhHmulQR6KoHbrZOcAIgiAIIjaJEgei9J0B6m/82fWd2bm7C/fsrt6AffdrHVejAgjWr7NkeBuZqS0geTC1uUAFmtbE5f/rqDK0iegMbdIenboiTRaRZqq0AQ3j0Pid4BAHTWkDRFZQcYYgFOhcYExoBO07IxMaTi6wxpt/k9r242NTqr1m/nXra3GgUn/z8PknZ6O3v4LhQgFnXrqstlflAhPhM5RFoaHJUJYJDRMBhIYuR9lXaCghJxjhSLDCTEpxZr4io84BJoMVYE1/86XaY6LcZCBepJlNU8sIDjANSfrN0PJ8giAIwgnPG1ZxEwdUTmzVYz0/W16dH7197rPS4yP43rhUzFtKkn2qfp0y+rjno9DURrqptUilMNOshjabz4OS6YqS9ugMGWnGPyrin1XoDG29jcMpbYDIGirOEIQLjkLDufLOsGpwKftyql8O+sEpj2FC+wBeOjAVv951pPRsE9fNqL7yKFO0WWChwVOqPUYUGiLRc5Q9nWBUoBl9RF+Wr8A3M9lEVJEh5qwDDrnJDL6ozBBXB0I4JttvQ4Cmli4OMOo3QxAEQaRALvvOyL4vGxBv8InObJlDu/q9/bMHDsNgpR3Hdq3DMYeHijaz6NfJsDW1MVwjjyxJ1dRGBRqCI1hhRgcZ2hSYIhqTGNrEbfHzWBwjMbTJsFgxo/uuEQ1tNpChjfCFijMEkQJOLjDrBpeAusFlJ4rte/H+yU8AAP5965moKL7knv39IgDAygmDKB6yr3YsxQzlFIWGbvVMqtBSfaJG3txfuRAZtmhFhgzZZ5Kq2OwbaWbjAGMEjDTrbRxODjCCIAgiLaL1nXFJHHBCFTeqZlupC/fvrc6BrjjdsleO0hzi0a+TUao9qkxtpn6dst4VcFs9k5qpzQAVaEYPQQszAZIGZJgMbeLfUT4NbSKqHp0hI81kn8GqgrlCPxUhT53hvysc+82IUL8ZIiZUnCEIjkxcYEEbXI4IjY9OeRjj2gbwXP8huGOPuiix7NE5OOxAGcOFAs6+9GnbN+NImKHMKNUeAwiN3K+eAcgJRjSF+yuKyODpk+xjf+sqkVEynzZMbnKMppbZOsAYLg4wgiAIgkhCor4zNlgnDsj6dLpGmx0LALj0sGXS4yPIYoBMOPbrtMFgfDEaZyIQMt4MIN00GkilMOOIj6HNlqiGtjpUxRhZj05VykCoSDNeL8k+hy1NbUW7YTzM0EZpA0SWUHGGIDSk0neGJ7ELDAA6MbV9N95bfBgA8K2tZ0PvCCtgci3abMwxrxjO7eOU0CyVLVmeQiTF1TOJc5TJCUZIyKPIsCGIyJA5Nn3/pkumAT65yTwhI82yd4Cx7ycG9ZshCIIgopFG3xmnxAFA/r1rH212w329GKy0Y1HXOixasLN2zCVxQJyHePbrLEFOE6yeUeJpagNIN7UyQf9/eSQN+BraTEkDQVbN+BjaGuoxtpFmuu2QkWYM/rM3btpAYihtgAgAFWcIwhXPvjNpNrj82JR70NM2iKf3H4p79/ITA7lTYdltLNpsAJMP2VPb6yo0RPIpNERMQkOGk9DQQUJjVJJaYSaPIoOnT7JPFBkipdpj4txkQJ//7kLkppZFx8sBOcAIgiCI+LRS4oBttNnW0hjcu+doAMB7T7f9jpXNK1jigAwHU5uYOCAjweqZXJnaSDeNOlz/P6WRNCDDtT+njuCrZhgl3UE+/UQ0tPH3gGL06BQ/ezshL5QDI/pJSBsoSk4buN8MM7TpIEMbkQQqzhBEyng1uASsG1zO6CjhPcUHAQDf2no+ADapkYmOqihZ9vA8zDlQwVChgHMvNS3TZ8gcFKYMZQklyb4mEBppxJsBJDRajTwWZmREExmuq2ai5SbnKdKMoSnWkAOMIAiCyDmp9J3hCZQ4MIJdtNmNL1VNbZfOM0WEym5s6jD062TY9us0mdoUZG5q80wdcIV0U74JXpjJMM6MDG0iPpFmhgL6ODR+BzikDTBDm/a7CNRvhogHFWcIQiCJC0wUGtaYGlxau8CAj0+5A2PahvDEvl48uO9o69cV11fv/nVZR5sBuRQaKa2eUZGHAg2JjXySSo8ZwFnU5k5kmNCKDBFVbjIQN9IMsI80Yw6wGik5wAiCIAgiLVz7zsRPHOANETIjhTra7Cf3zcNApR1Hd67HCUeaos1EeJNI5H6dPDIjjKJfp0jqpjYdATUTQAWavJJqYSZyf04R3d9TdoY2FaENbR2K5+I2//kbN9JM+l1jA6UNEIGh4gxBGHDpOyPCKu/SaBlNf4CDODa4nNS2B5dOegAA8O1tF6O6asbOBfbM7xcDAJZPGMS0mcxFYRNtxo7nRGhwJBEawXOUgdQKNACJjbwR/P+Hh/srlsjQ4SUybHKT+ed12OYm6xxgPiJD5gATnzs2tYzsAKN+MwRBEER0Et7A8kockH1/SlH1ntGzY9cY3LX7WADAFa81fVFm3K8zJ6tnVGRtagNIM+WNvBZmbDEZ2nhSNbTVIYuez6JHpy7SDGi8byWJNJOR0NDmBKUNEIGg4gxB+JDTBpfvLt6PnrZBvNA/Bw/vOwIuLrDnHj4U8/qr0WbnvOtp9Q9Rh8tS1uYSGjKiOsEsILHRnLS6yMhs1YwSm9xkwO+GiQ6ZA0wWaUYOMIIgCKI1iN13RoopccAJWRwp/yjnxheq0WbvnPs0APHfwKZfp4wE/TozXj1DpjYiBHnQTCpyuWrGxtDGaDDFhjC0+aCLNFPtE1c7chRhTpkx9SqrwQxtPv1mCCIpVJwhiAzwqtQbvnQ6MYgrJ1eV0DU7Xgd9rxk5E9bOBAC0HbXGMFKWodz8QiPvTjCAxEYz4RMxN6pEhgxvkeGam6y6YSI+T+oA4/dZNrWUQQ4wgiAIogkI0XcmncQB3YpX/ntbvkL2p/fPQ3+lEws7X8Vrji7V9rrMGWSJAyJkatNCBZqWIrhm8qTlVs2UdINUaSe2hrZQkWaq/SMJMUaKwrZKO2nSBkxQ2gAREyrOEISE1PrO8F8OCV1gF4x/DDM6dmDzUBG37DoJ9YLCLtrsyd/Uos3GD2PmYaXaXtsMZXaseYWGDVk7wQASG82Az795rEaWSeLMoooMVli1FRmMku6gbW4yQ9wOFWmWoKllEcEcYAxygBEEQRCZUVuZ6dpIOW7iANB448/G0Fb9Pt+1pxN37qpGm703eLSZY7/OkrBfZ2qT0aymNgt8NBPppvSJUpgJmDRgo5tMv//5MrTJSLtHp7hiUZY2IMIZ2mRQ2gDRxFBxhiAsiNZ3RoZng8t3TLobAPCLnedhUOn+4h/559XxK5bOxPz9QLlQwJnvMEWb+WYoGxgNQsOE5014HSQ00iOTwkwKcWYiqYsMPtJMKzJscpNVq2d8idDUsihsezjAVP1mGOQAIwiCIPJOjMSBemQ3Btl+NTc+Xy3OXDqHjzaz6dfJEgcC9etkyBqBi/CJA33m4Wma2vKSOgCQbkqLKCkDQC6SBvJtaNsF832ZkD06GTbxZeI4FmkmaKgeyA3MlDZANClUnCEIX2L1neGxbHA5o2MLzhz3JADgpp3novptpbo5qKenbzYAoHyE7EvLVmjw8DdKxf2NTxsYjUIDCL5UHyA3WBpEKcyYiOD+khFFZLhSUh1wiTQz7c+oqaUMcoARBEEQOad5EgfEaDMRVcKAfKXsz2rRZod3bMapi3ao3lQgcL9OUy8+Zngx3dRN2dSW99QBgAo0sYmmmTwLM2n15xTJp6EtdI9OfaG7fgx7tIw0k90f478T+O8KTTwm9ZshsoaKMwSREdKKvWe02evH34f2QhlP7j8KrwzO5o7oGlzKvyQfvbkabbZybBlzj9xW22v6IlbllMqqLq0hNJq1QAOQ2IhF3kSGDFsxnJrIYH/DphVxps8EAObcZF5c8PsZoSPNVPsUTS2ZA0x0+5IDjCAIgmgi8ps4ICL7PtYlDjCqc4A9+zpw287jAABXnLrccKGyFby6sY6mtlLtUezXKSMHpjYVMVMHqECTHzLRTBpCG9p0NJehLUaPTlEX2UaaGdD2HFND/WaIvEDFGYJQENIFpo2WCdDg8uxxTwAA7tx9FvQ3Bs3RZi8/Px1H7CugUijgtW+zjTYDzEIDyFpo8DeNXYVGkngzLSlnKTNIbITDd0VSXtxfqYgMHp3I8MpNlomMpLnJLqgctpGaWpIDjCAIgmg2cpQ40Ih/4sANz9WizWa7Rpux4yn06yRT20EoeSB7MivMREoaaG5DmwyVoU0c44pPpBkwop+4tAGbSDNKGyCaDCrOEIQlSVxgDFaZD9ngsqtwAKeOfRwAcO/ek2sDZDcB7YVG56pDAQADh2/g9iYVGuKXevpCQ4dJaMgIIjSAIEv1yQ2WDb7/hnl2f0URGX2Swa4io6QblDQ3WSREpBn/SA4wgiAIgvAlZOKAmk6okwbk3+O/eOBQ7Kt0YV7HVpxx/DbpmEZs44E0N1D53eI8KQerZzKJNwOiFWgA0k1JyczMBiQ2Q6pwNW8GXTUTxNBm6tEpe3TFJ9IMqN7LMty/MvUWm9W4i9IGiLxCxRmCSELCynmIBpcLulahu+0AdgxNxMqBw4TBsmWjssd6HrxpCdoqFaweW8Hhx9k0fQHcM5QNRBIautUzupvOQEQnmAnLm/TkBkuPJP9uWYgM3zgzE9mvmhG3fXOTY0SamZpakgOMIAiCaA1ynzjQwz/h+864RpvVmy727u/ErTuqWuDKU019c2RzEnGfztQm0U8lw1uKpLR6RkaweDMq0DQdUTWTCY+kgVD9OTNbNVPSDQplaIsZaWaZNADo0wZ4KG2AyDlUnCGIPJCgweXhXS8DAFYPzAcwFnKXgWypvlporFs9GUfurX48nPwWc7RZV2EIx3VvxKWTHsNnpv0O/zLrWvzv3G/jp4d9Db/o/TtcN+8q/Ovs/w+fmPJjLBrzjPB6SbRZyfCWIoFWz4jYCI1gBBAaAImNNEjybxW7MBM6ziz6qhlbSqoDSXOTxTG64zJUbjCVA8wAOcAIgiCIFiB3iQNG/KPN/u+ZWrTZrKVoayvX9qoSB1ToVtMEMLXJbupa9LJwMbWlFm8GZF6gId1kR3QzG5BK0oCM5ls1I26nYWhj+EaaAVJDm4hYkLc0tFHaAJEnqDhDEBp8XGBMaKhcYNYNLkUULrBDOqp3OTcMiXftxGizTrgIjcKKeQCAvYfxM4Pql/HEtj24YPzL+OL0u/Hzw67H40f8C2487FpcNfM3+PjU+3HxxGfx2rGrcGz3Ohw5ZgNO7FmJ1094An8y/Ub8pPfzuG7eZ3FE1/Mw3lhNKjT65KdtZaEBkNiIRdJ/n6wKM75xZiaCrJoR/5ZFkcHnJitFhsplCsh7YclcX0kgBxhBEARBGMlB4oAZ98SBXz54KErlsZjZXsKFJ202nF+84ambh7CbpzyGfp02yAwyfdxzT1NbEoKnDliSZGUGaSY90TUTkFqcWS5WzdhSUh2IYWgzkTTSTKGfimj8bFfpJY7EhjZKGyAiQsUZgnDAxgVmSyIXGNfgcqDSBQDoUH55mqLN5Nz/sxPQUalgTTew5KQ1OHfcWnx++qP46WE/x4MLr8W/H/o7fGDKMhzdvRUdhQq2DY3FA3t7cd2Ok/GPm9+Iz254L/5g7cfx0bV/hP+3/hP4x81X4Nbdr8VAuQMn9ryAHx/2GZw19pHau0USGjyaiU/aOcp5LtAAJDZkJBUYeSrM2NKcq2Z0uckqsZFWpBmgjDQTIQcYQRAEQchJkDhQRRZtxiMzWqgTBwYG2nHz1iUAgCtP4k0QSft1ihj6daoijZpg9YwXOdFMpJvqScXMBuTK0GbSTPzfjS4u8CBJDW11yAxtMmyKxbLX2GITaQYEM7Tx3w02cZguUNoAEQEqzhBEUmK4wByExvbhyQCAw7teqe1RfaGpos3Em40dGFsYxMIdu/HpWyv4h2uG8KO9d+A/5tyOD015Dsd2b0NbAVh9oIgflxbhsxsuxvmrPoyzV30SH193Of5+88W4dsepuGX3Ity/7wg8tG8B7thzHK7dcQ4+s+FTeN3qf8d9e5egp+0A/mn2VZjVobgTWKo9qiYaAVfPiITMUW7mAg2JjRwJjIA0v8jgyTo3WbffoallEeQAIwiCIJqaTBMHDhG2RZNDEYZ7fmLiAOCSOHD9k0cDAN42/Wl0dpQNoxm27nPxhqqlqU3XNjTB6pkQprZWSR1gkG6qkvTfIG+FGRmuSQNaYhnaGj4XxMKuLNKMwW+HiDQT0RnaZNqKM7TJkKXLWMC+WyhtgMgLVJwhiLzg2eDywfIZGKq04+ju5Ti5hxU6+KWgYrRZI+PbDuD0sWvxp9Mew/XzfomHj7gO/zXnNpzxVBkLNwJtFeCVgQm4sXQk/mzDuThn5ZV4c99l+Oqmc3HL7vnYNDQBAJtc6IXG1uHJ+PS6L2HZ/qMwqX03Pjrl/2pHDEKjVHt0FRo8GQgNb0hsZEqInztoYSZHcWZaQosMRtDcZJ6kIiPFSDNygBEEQRBNRKqJAzyyxAFrTIkDcmPG7x87BJuGJ6HYthdvPWO94T1kN0Jl+kkW2Wq5ekYk4OoZER9Tm4qsCzQhdNNoJFe6KWBhJkScWfMY2oDGzybdGJOGUqUNyMaIn6uy+1i1zaLkNPxnvmXagC2UNkCkBRVnCMJA6i6wXu65RYPLneXJ+NWutwEA/u3QP8NbJ96KzsKAMHDkC29MAVjcvRFXFp/A1TNvwc29P8SjR/w3/nvuzfiDqUuxpGczOgoVrB0Yj5sPHI7/eFMBn/yjdnxp7in4m01n4re752Pr8FjJxagcFyL7MQjgG1vfCwB426Rb0VmQKQXEWT3D3TzOtdCwIaUCDTC6ijQhfs48FmZsyURkiKiiOQ6yC/YZiGKMiG9BJmmkGdAQaSaDHGAEQRBEq5Jx4sAIYrSZiCpxQNxXHVcut+GmV08EAFxxwvPcONtoM7aPjzYTUcx7Qpra+rjnOTO1pVGgAcjY5kKookzeeswAKfTnzJ2hTTSyiQkDSQxtqqI30BhpZmFqMxXcxVbMoLQBIv/YdGgiCILj5hXqCUThHrcb7nOx1r2pNmMcDt7A/PtNf4kjupbj+J5n8Y+zrsZXDhmD5/sPx/bhCRiotKGnbR8mte3FnM6tOKRT7rhaPzgBj+2bjUf3zcSj+2Zhw1D1i/HyeT/G9omDOOHCZ7D03sNqo4dQ/fgYRPULlW2zR4ZYoNkP9oX7yL7jUBqegGL7bizs6sMLBywyfEqoiqwtUN/A3ITGaANL1mKu1k2xEgsO3vgEqhM0dgOUsQzHH7wZyrN0yjHSm6OVczWOjAtgnghcBG2hkMF+Z5M6PNgEvBUd9aGEVNaFGRW5FhmiA4zHSmSI6JyoQKPgYK+xxSfSTCI2yAFGEARBEO7MqZgbaU9HY3GiB4oaB3+DcLewz25+8KNHj8In33E3Lpn8DMaNvQB799nEorE5iG4sG8PPI/ajYV5RgqYYxbEVjUWsjRi5odkH6Xxj8+q5yhucq7CwznwoaibAXjepNJMRk26y1ExAdS5PmklN6poJSBwBnav+nDx93PNohjZbQhnabOiEumBjQVHYVt1K0qQNqAxtWihtgIgErZwhiBBk7ALbXxmL9625Dt/Y8hlsHpqGnrYDOGnsi3j9hMfwpomP4ILxz+CksasPFma2DY3FPXsW4ttbz8Qn170DZ638GF6/+v34i1cvxC92HYkNQxMOnnv309VJxquzt6EAlwxl05d5AS/2V4s9C8f0sZ+E+6FqjyXLt0xh9YyMzJ1gQKpuMEaruMLYzxFqtUxahRkdSeLMgogMHp3IUOElMrLITQ7Y1JIcYARBEEQLkNvEgXGwK1wA8I02e2DZVPQNTcfYwgG86yzTd7QsfpV3sLPVM7JKkuFmqziPYgUq2fwrYCR0LlIHcrSChtEqmgkI+7MEL8xEjjMz9ecUsTa0yQhuaGOfJTrTWkhDm03agGr1DB9pxqUNyCSVaNY1GNoYpl5m7DvJNX6TIJJAxRmCSIKl+wZoFBpSfPrOFAH0AIPowve2/wHOW3Ur3t733/jM+r/AVzd9HFdv/gD+6tUP47MbPobLX/kTnLnyr3H2qr/Ap9a/B9/edhbu3Xsktktjyqrc9ZPj0VMuY1NnG85803LDxbkIjUHsLI8HAIxt249UhYaGLIQG0JwFGqB5BUfo605TYADpuL9Egq+aUeUm8yiTy1Q3LHhccpNN+ESaOTa1LKL+8z1jBxhBEARBJCV63xkbin4vc402Awr4ybpatNmiZ7lxqmgzCPtNiHMfQ79OHRFMbTa4mNqyjoUGwvShYYQ0hKVNM+umtJIGrA1tKs3kY2hjlFQHbAxtqh6doQxtukgzftsi0qyIRkObRfgKM7R5f4dIDG2UNkCEhoozBGGBzzJF0QUmwir2xgaXJhdYAwUsP7AAt+45Gz8uXYIf7ngTfrbzXNyy+yQ8078QO4aLAApQfzHWC439e8ZgwfbqGx12tk2Gsj1DlY7aO+7j9o5OoWEkQoEmRpEmz4Ij1jXmvTCTqsjgCS4yZEVc39xkkViRZoBzU0uewA4wF2h5PkEQBBGFtBMHlIh9Z2Tf12L0jp4fPXAkAOD8Cc9jalHsASoiW+krm6PIHO8KU1up9thkprY8pw4AYTUT0Bzmthi6yVl/RizMpNWf0xkXQxvfl1draDPBF2NiRpkxOlH/2Row0kySNkAQzQAVZwjCg5AuMEa8BpeyXgei+0svOjY9Up289E3fjc4xJneX6LhQjd+PSe1VYbGrtoKmmYVGKk6wJhAbQL4KNTGvJW2BAYRxf8kIJjL6JPt8RYYS26qtLjc5rUgzj6aWER1g1G+GIAiCaAa8EwfEvpOKxAEzftFmy1ZOwnMDc9BVGMYV57wsGSGa2mTzETFxQKS5Vs+Ic0yXfofRUwcySh5g5EkzAfF1kxMpR0ADYfpzimSzakbVo1NnaJMRK9KMR9znGWkmg/rNEE0EFWcIIhSh8vt9os0AC6HRA3W2p557bjoKU4bK2NnehvPf9axmpEpoyKLNgGntOwEAO4YnIW9Cw3QzOhMnGNA0BRoGP8lPS3ik8X5RBIaBUO6vphEZytxktq1aAZNFbrJ4POcOMOo3QxAEQUQiRN8ZEevEAR6rxAETrtFmwI0vLwEAvOeoZdw401zDpl8nu7nKaA5Tm4zcpA4AmSYP8GStmWK9p9e/WY6TBnK5akbKLrgZ2tKMNJMVuXugvT9VlOzjP+MNaQPsu4P6zRB5hYozBJEUi7z+/DS4BOp7Iage678Yy0MdmLOp+iZTTuLt1WK0mYjaiVFAGfO6XgUArB2YbLzqNISGzgnmKzRUeDvBgGgFmphFGoYoApIIAdm5YouZaAIDCL4svzVFhojoAOOJJTJ0+2UFm+ZzgBEEQRBEKFonccCOH95XnUed3rMCc2fuM4x26dfJo7jp2qKmtqipA0CujG2M0Donbc0EeGqmJkoaEBH/DvJtaAMaDW1Je3TaIBa5O2C5nNF830tiaJN+V7hA/WaIlKDiDEFYkqTvjInwDS5lQkPEJtps5Ebkiruqk5YVxX5MnGIjNBi8k31EaMztXIuxbQfQX+7EmsGZtbGynhJITWiIZBVvBmRToAHSK9LwqMSH6b+08fp3CVCY0eFSmJGRO5HBUIoM2Q0Kkdi5yYEizYrQR5pl6ACj5fkEQRBEVLJMHCjCM3HALtrs5fXj8Mj+hWgrVHDl2XwsW/J+nVUsCjWl2qNofGlSU5uOLAs0aWsmoDl0U1ZmNiDdpAGb33NrUje0iUVh/lF8zgidNmARaSaD/0xXxUHbfDcQRI6g4gxBeJK5C8w52ox92flFmz12ey9mD5RxoK2A8979NHdEJTT0X97Hdq8BAKwcmIsy2tEMQkNGLCcYELBA41mkIap4i68URIYLSUWG69+Gl8goQYEoMvKWm5xepBk5wAiCIIhWJnrigBN8UYZt84/885FxN6yozvsvX8BHm6mw6dfJ5jziMUW0mUhJc4xMbSM0ibEtz+RVM6WdNNC8hracRZr1wJwSYzC0Mdh3iCptQGuspn4zRGSoOEMQIYnpAhMbXIoUXd7A5otS/IJsw9S11YvoXCRrcKlCLjRO6aneDXxy34KGY80kNGSEcoJZQWIjGomKMi0oMupITWRYHRAQizVpRZoBQRxgKsgBRhAEQTQJSfrOmEgncSBZtNn1987HUKUNS8b04dj5Ol1j269TRLzxKunXWTJcJJna5Hj2iRzNmgnIt5ktdtKAyOg0tNmQMNKMR7ViRoLrdwb1myGygIozBBGCNPrO8Di5wFTxOrI4Hn202VO/XgIAWDF+CDMPL+neFHIxMTIZOGXscgDAY/uPFsa0htBIzQkGpCI2RpPgSPTzuvwbk8ioohUZYtQhc42KhRo+PtFEktxk1QoZ3lmri5OE2gHmEGnGiOEAIwiCIIjQNF/iABugMmDoo802bevG3XuOBQB86KyXuCP+/TrVYwymNlW/Tp4cmtoyL9CQsc2KVHRTSlFmgF/SQPMb2oDkhrakkWaA0dBWFLYpbYBoIag4QxAOuCxbZELD1QUWvsElg29MDWG/HS8+NRML9gHlQgFnvvMp7ogpQ7leRMzs2IGFY15FuVLAE/uP5MbwE4nmEhqZOsGA6AUaoPUFR+KfL6eFmdYVGbwDjCfN3OROeEeaOUerjBDTAUbL8wmCIIhUaKrEAb4YYxdt9qNna9Fmc59EAabvVnO/TnlMUf5Nbb7xZjpSKdAAZGzTkLgoE7kvpwnfpAET2RraxG1dEVj8zAmJT6SZpCAjq9EEShtghjaCyAtUnCGIBEStnEdrcAnYf2E2Fm66V80BABxYsN7ifeRC49xxzwIAlvb3ojTcgWYQGrl3ggHR3WCMVhMcQX6ejAszKnzizESiiAyRkmyni8gA4ucmm/Z7RJoVhW1ygBEEQRCtTm4TB8RoMxG3aLOf3DsPuyvdmNexFeefxE+A/Pp11tM8pjYZSVMHTOShQAO0nmYC8mVmA9JPGsi3oW2X5IAs0oxhMrS59ujU4RlpVoQ+0syQNsAMbWLaAIP6zRB5gYozBBEaBxeY+OUgpZd7HrTBJeATbXb/z05Ae6WCvp4KjnyNbGZhFhrnja9++d295zjJ622X4SJVoWFD5k4wIDWxATS/4AhWlHFxfiUozOgI6f5KRWSUao97NWMAyG82hMhN9sEm0gxwjjQTi+wyt6+mWC9GmhEEQRBEHvC5cZV64kBPwxPhoE20mZx9/R24eesSAMAHTnlOO7Ye3dzGtHrGvLuBnJjaVPiY2qxw1UwBijTNqpuCXX+OCzMhkgaiGtqcV83oiGVoSxJpZqAobEcytOnSBsjQRsSEijMEEQqLBpeiC4zBKvnpNbj0jzZb3zcFR+2pipElb36aO2L6Eq9OAortJZwxrtpv5s49i2vHmNDgRYjYa4LbbYOv0OBuPocSGq5OsNQLNIGKNM0gOIJea4oCA0ju/pKRqshQUao91v1ti6vnbP7wXXKTY0WaBW5q2du4S3SAiVC/GYIgCCKPxOg7cxDfxAFrdNFm4pxhpGhz7ROLAABvn7EUY7qGDe+h79cpn7tY3JQt1R4jm9pi9ewEcpA6wEiomYDm0UxAYM2Ug8KMihBxZg2ENrQZEe+bsNV1aRvaXCLNgPr7U1zaQMRIMytCxW4ShCVUnCEIR3z6zrgSv8ElkCTabOi5XgDAtjmbAZQN71MvJi6a8Dw6C2U83z8bqwemIndCQyBEjrKKaEv1gUzEBpDPQk3wa0pZYABh3F+JM5OTigxVbrISnchg8Lns/L6QuCzV58lHpBn1myEIgiBySy4TB8JFm/3+kUOwYXgKJhX24Z1n89/jbv06G5FFvWZrahPx7dmZSepAiskDjDxqJiCSbnIhYmEmSdKAk6HNJo2j5Q1tPLIVMzLDsEBRss8h0owhRpoxozQztBFEHqDiDEEkROoCS6vBpUxoFF3fxD3a7I4bTkRPuYxNXQWcetFqboxZaLx5wjIAwK93nSh5j+YTGjJiO8GAiAWaQIIDqJ/cpy08or1vygIDiOf+8s5MVhFNZMiQOb9C5SbLyHdTSyuo3wxBEASRF9JMHDBFmx2EmSrEfTbRZnIjRwUF3Li+qnnet9jmRqC8X2e98102NzLMnUq1x5RNbWnEmwEBCjRA6skDPHnRTEHf28fMlnJhxpZE/Tn7uOepG9p40jC08egizWyMbzW87m9Vsf2OoH4zRB6g4gxBZESQBpc8RdWB8NFmO3f0YOGOsQCAhec9a/GK6iTg0M4SThq7FuUK8Jtdx6GVhUZsJxgQqUADBBUbPOLkP4QAkJ0ziqjxEWGRCzOZxZn1cc9TERmyoi1P6Nxk20gz/jGdppYMcoARBEEQzUAuEgd4ZNFmlomkjUUZmamtMdrsmgeqc7nXTXgO06cckJzX3K9TjXgTVtarz5IIpjYbQpjagAwKNEDwIg0jlr5JRTMBUTQTELY3J5Bhf87UDG18pBlPLENboEgzGfxnN/WbIVoQh5IlQRBGboVyMrJk+wtYOuUYLMYz2onlXKzFWszFjPlrG2+QzsLIF/s0jHyJT0djQaIHFitM2EfAELc9KDwOofrFWf+lvunRo4BLnsYrM3aic8wgBg+oijtDB8/z5gnPAwAe2Xc4tgyLzjQe9l7sC3oXGp1slmxFo2NuE0Zugm7EyJd6H0ZuiK4rODnWV2GhvsCmYRmOV95UXTrlGK2bo3KuhZhlE16XFV3s9zhyb4q8LeVvwFdwZVSYySwz2UdkMJSfU7J4wyxyk1UEbGppCTnACIIgiGbm5hUR535zKo2FgkMwMkfhtRNQNUsYzSIibF5gP694evkkPHNgHo4fswZXnrsa37yJze+YRlIhvkcHt4/Nkfibmfu5be45e1pCdf6xF9WfnW1vQWPRSqafVJpJYPPquXU3R9dibt38RaaZVmJBQ2zdchzVsGoK0OsmHVaaCajO4V1TMC5CKv38WlI3BSrMxIqADt6fk6dlDW08HpFmPTBrpUMk+6jfDNHE0MoZgvAgDRfYQXy+ZIquL5At1ddHm939s2NRHCpjZ3sbLryMnyCL0WaMCt48sVqcuXnXcZLjbPWMrt/M/sanpdqj7+oZS0yrZ2Sk4QQDIrrBgGiOsKbAV2DkvDCTG5FRahgJq35TDYTMTZbhE2nm2NQyB/1mCIIgCCITLG6E5TtxQB9tBgDXr6pGm1155FKLixPnMgzdHCal1TOq1AHH1TO28WaZxEIDpJlc8f3ZIxdmkuIcw9e0hrbQBIw04xELxoxe9SkSpQ2kUHAlCAYVZwgiAFH7zjB6uefODS5VQoOnEy4O8PJQB+a8OhkAUDxpuXH84u61WDBmG/rLHbhtz9G1vabJQr6ERux4s9wWaIDRJThyLDBcCjMymkNkqHKTZSIjdG5y5EgzHSk7wGh5PkEQBJEpDn1nVLAVGcZoM14vyfq9WUWbySJ77KLNfnD3QgxXCji5ezWO7t3NjTX361TPb9jcaFDYJ3nepKY2H4JqJirS6Enys+ZMN/nEmbWOoS3HkWZFYdvS0OaaNuBiaKO0ASIGVJwhiAwQG1yKLrBgDS6NQoM1uLT5Im0UGi/8fgkA4KXiAKbP4m+oNgqN9xSrP/Pvdh+NvWXTR0+OhEZOnGBADsQG0NqCIwcCIySJRYZrhnhQkSFbRaciZm6yCs9IswQOMPa9QP1mCIIgiGYi14kDXs2mRe1kZuOWbty991gAwIfOftHiFbIeEbw5RTansZ03GSBTWz2kmRpJqpmaoDDjnDTQ9Ia2HEeayQrqIqEMbRLI0EakARVnCCI0Fi4wWxI3uHTCNtpshKfvnYfD+oHBQgHnvOdJ5bhJbf24ZEL1W+2G0om1vfxEgZ88+LhALPAVGgJpOMFyLzaA1hEcFyE3AgPImcjg6eOeZyIyRPjPDkgexee+BIg0EykK24EjzajfDEEQBJFncpM4UFS9OE602Y+eXQwAuHzeUwBk37uimcR2HiPrMyFJHCBTWx2Vc1NIHgBaRzMByX8Wh3/Hpk0asCH3hjZXOoTnkSPN+M/0XvOpREObFuo3Q2QEFWcIwpPMXWDO0WZA6GgzABi7ejYAYPBI9Q3Et096EWPahvF8/3Qs67exPgCtJjRcnGAmclOgAZpXcIS47oACA0i3MGMkc5EhOy5zi8pERFJhETjSLCMHmO57hxxgBEEQRLMQLXFAJHK02Y33zMOeSjd6O7bgvNfwAsVUhBHd7yZTm2lOZUEOTW0xCjRASskDQHJDWFaEuG5HM1uahZngcWZ93POmNbTF7NHJxnoY2iwR0wZE2HeKj6GNIGJBxRmCCISvCyzdBpeq/f7RZvf/5CS0VypY3QMce8p67hzsS30AlxefAwDcUDoeALvp29pCI6YTDIhQoAlVpMmz4Ah5jTktzNiSuciQwosMWXGWR1akkQmJJok0IwcYQRAEMRpp6sQBt3nAvv4O3LztBADAB0553uIVun4zqvE8+TC1hYo305FagQZIrpmA5tJNScmBZlLhkzTQQNMY2kISMdJMhP+sDpw2wKB+M0QeoOIMQWREKg0ui7B0gTHco802vlzEUburYxa/+amG468duwG9XTuxZ7gTv9l1JJpKaPRxzx2FhoyQTjAg8HJ9IIzYAPIlOEJfi2MhK+3CTPDMZBt8REadltCJDFVuMk+I3GQZKTe11EAOMIIgCKKZab7EgTjRZtc+sQgA8I4ZS9HVNcwdaezXWY9KPzFTmy4WOltTmw0hUgdSL9C0mm66COGNbClqJh0hkwaaz9AG1OunmD06A0SaFWFvaGNQvxmiBaDiDEHEIKAL7CA5jjbb9/R8AMDGQ7ehUKgXGu8pVn/eX+06Cvsq4kdOzoWGSKR4MyBegQbIUGwAjRP92MIj1nt5/Lvk0f3lTB/3vGlFhg1JIs0CNbUkBxhBEAQxSsht4kDkaLPfP3oINg5PRrFtLy49a43Fe4lzGTFxQDae10jNa2pTESsWGnDUTEBYzQS0hmYCUtdMQEZJA/zveZ/iJLxmysTQxn9OZN2jkxEo0qxXfYh9N4hpA1qDNKUNEBlCxRmCSECqLjBGr8WYoupAnGizO368BOOHy9jW0Yaz3zFSfJrRsRcXju8DANy481ju3DkUGinGm4Vcqg9EFBuhBQdDVrCxFQa618YSMR4CI2ZhhkSGq8hookgzBjnACIIgCOIgqSUOOOEebVYut+HG9ScCAD54gupnEh3s4qoaGc1naksSb5YbUxuQX80U4vWuZGBmA+IkDRjR/X7r/i6AlA1tQHqGNgjHZJFmis/MQJFmtj3IKG2AyBtUnCGIgIR2gWm/XHIUbbZ/7xgs2DqhellnvHhw/+WTXkJHoYLH9s3E8gPsImU3U5VnRmpCw5aMnGCZiA0gnthQkUXxRUUkgQGEL8xEz0xOVWTYoCri+iATGRk1text3CVGmpEDjCAIgmgpcps4ECfa7Hv3Vudyr5vwHGZN380dUWkj0ZQi69cp4mhqg7CdgqlNRp4KNLkq0sgwaaY0dZPnz57nwkzw/pyZGdrEey8xDW0BIs1cSWhoo7QBIi9QcYYgMoAJDZMLjOEbbaOONgNCR5u9fHd1ZcyqKfswfsp+dBaGcVmxevPwuh2LoL5xKptABBIaKayesSGE0AAyLtCkXaTJkogCAwiblwxkmJkcTWSwlXKyQo0oMlSfFb65yTy6SDMgb00tyQFGEARB5JnmSxzQ4R9t9tyqKXisfwE6CmV89IJVlu9nO58RnfSWpra96kN5MLVlkToAkG4ykkAzxSzMqLD9PQren5OH/T3J7ksYCWVoC0mSSDMJtmkDvbbXZ4nE0EZpA0SaUHGGIGIhuRmWaoPLouokqi9DU7QZ73poFBoP/u5IHDpQRn9bAee95ylcPOFlTOvox8bBsbhjT6/k/VQ3VWV4Cg0drkKjj3vuITTyVqAhsSEhssAAcuz+CiEyEiP7WxYdYCaS5ibbRprlt6klOcAIgiCIZqB1Egcco04BXPvCCQCA9y94EgD/XSxGmInzGp2pTZc2oDG1iZRqjymZ2mzizVTE1kyAp2YCWls3JfjZYmsmwC0COpX+nDqtVKo9JjK0yRDvtfDbskgzGw0VOtKsp+7hIEVh22BoYyRKGyCIjKHiDEEkJPTNLfYloqVXsk9cni+SONrMRBumvjKj+upFL+PKYtXFfUPpaAyjXBuTotAo1R5DrZ4RiewEI7GRIikJjLy4vxoQCzN93PNoq2YY4t+wTmQA9UWaWLnJ+Yk0SwQ5wAiCIIgWIPvEgfDRZtfeuRD7K104snMjzlqySTluhIimtlLtMeTqmT7uuUEzyciTqQ1IYGwDWks3JfxZsizMRDG09amv8SBbJc+9Vs3wyO6P6Axtss+O0IY2k5YyUBS2xcK5BkobIFoBKs4QRGBcXWBMaLCKPkN0gRkbXPJYN7jUNWYTo83MQuORm5agUKkApWEs7tmKgXIbfrJT51xxERpsPCPw6hkb0dHn9zYMlTOn6cVGMwqOANceMsYsTfeXdZyZiiAiQ/z7tREZYuY6/yg+l227IMaTiJ+HcSPNGOQAIwiCIEYFuU4cUOEfbbZrTyd+tW0JAOCjp6luFJpMbSI5Wj0j0uSpAwxvzQQ0r2YCgmimlivMiKgMbTpKtcemMrTxmAxtsntMms9UShsgRjFUnCGImOS+waW4TxdtBpiExqrnZ+HIve245InqSpnf7j4cO4a7ZRcAudBQjRNv3kZYPcOTstDwJTdiA2ieQk2AawwpMAASGSPYCATZmFBZyioRIX4Gxm9qSQ4wgiAIohWJlTjATAxSeiX7chBt9j+PVft1XjrjKYzt5idKKnOJON8REwdM4xkBVs9EMLVlnTqQirENaD7NFEA32ZJWYUaG7e/fQfosxrSkoU0WaWbTo5Pvc+yQNiBLh+lVX53K0KaF0gaIHEDFGYJImUwbXEaPNgO6npiL01+oCq/rS/zNYlnGqbgtizYTcRAaOtjESCU0bHOUM3SCAeELNImLNED+REfA6wkpMIB0CzNG+rjnqqJkFJEhHlOJDF5cqFyjIXKT+eO6ZfkWkWYpN7UkBxhBEATRTIRKHGB4JQ7wRIs2a5xr3PbobKwZmoaJhf24/Lx1FhcH2M1vVDdwNcN5MjK1yUgzdQBI2dgGtLxmSiv+2SVlAPBMGkhiaJMVM0u1x1QNbaEI2KOzB/aRZpLPcPZZr+05hpHvDm+DNEFEhoozBBEA8SZXbhtcKgkXbTb91m50lIHls4FpF23TvSlHRKFRqj2KEx+eJDnKFsR0ggFhCzRAILHBuADpC48I7+lauEqzMGOLk8jgiS4yxGx0FboxSXOTecSiDP852AHnSLOisG3qD1bDFGmmhRxgBEEQRLORdeJAUXWSENFm/PEqFRRw3SuvAQB8+LinhXGiqc2lX6c4TpY4oFk9IyNnpjYdaRdogukm0kxKfAsz0ZIG+rjnPv05rQltaJNFmvka2lx6dFqQh0gzShsgMoSKMwSRA4I0uPSONusR9vlHm3WgjEt7VgMAfndyG2ad86LwnhkIDR0p5SjLCOkEA8Iu1wcCiw0eUQT4igHVeQKLGZ9/h7QLM60pMnjEz4AQuckitpFmlhSFbQ8HmAlygBEEQRCtTtTEAZNZIoVos+/ddSTKlQLOHvcS5s/ZafEK236dgVfP8GRganNJHQDSLdAAOddMoc+lIW3NpCOVpAEbMjG06YhlaLOJNNOkDQB2kdC9zheZGEobINKAijMEEZuYLrBei7FFYdtbaJgFx+sm9GFGx36U2jvx8NEFLJ+8H5OnlyzeJ7LQYJeQotBI4gSLVaAB/MRGFMEhoiu2pCAmRHyLMjHdXzJaV2Soog1D5SaHjjSTQA4wgiAIgmggN4kDPFaJA7bRZioXeWO/zpfXT8L9+6pzv4+dt1xxcbFMbRJK6kNZmtpU5K1Ak0vNlGPdZML0/ynzpIHMDW0ion6KZWhTETjSjNdOmlhKU9qA1ghNaQNETqDiDEEEwqWiHswFJiN4tBmgX7Y6cuzKYrXYdN2mozBrCDjQVsAF7xWX6TNae/WMjNhL9YE4YgNIUXDkAJ+f0/bfNO1l+UCziwzATmQkoUN47hNp1lP3cJCixdv3WowJDDnACIIgiGbFOXHAJdoM0CQOqJCtmrE3t13z7GIAwPt6n0ABA9wRldnEZf6jM7XtVx8C4praAsab+d6It9VMuS3S5ICszGxADpIGbIhuaBMjzcRj/KP4XLYtI+NIMxm1z3TbtAGG1hhNhjYiY6g4QxCRSNrgUuUCC9vgMkS0WZVjxpRw0tjNGKwUcGNpESauPhQAMHD0GtRPFmILDcWhUu0xqdBIkKOsIuRSfSCe2ABaV3CwnytWYWYZjg9amLHFO85MRXCRISLLTeaP8Y/8c5/cZJ6UIs0MDjCxqSU5wAiCIIhRRYzEAUavxZiisJ1CtNkNd8/BrkoP5rZvwxtea3K+iHMhlYFtCHq9lPHqGQvykjoAkLFNJKZmArIpzDgzKg1tPClFmvGF9F6f62wkqlGaIDyh4gxBZEwmDS4jRJtdOfk5AMDvd/diy/BY3PN/S9BRqeDlHmDJmSZXg0xoiCQUGjpchIZIn/7UNk4wFbELNICf2ABaR3Ak+TliO7+AyCKjT3MsNZEhroCTEUNQAI2fZ7zjS7csPz+RZtRvhiAIgmh1bG6kiaY2Lbw5witxIGy02b7+Hvxs02sAAB855TnIjSniXEi2rYqJFh33FqtnSrVHW1ObyuCW0NSWtwLNaC/SJNVNNmRVmAnenzN3hjZGjB6dgFekmYhHpBlDZWjTYhGfSWkDRFpQcYYg0kDiAkutwaVTtBngE21WbB/CmyZUvwiv21GdeG1aMxVHl7oAAMe88Rm4Cw0x2kwkZaGRUbxZWgWa0VakSXrdsQUGkILI4MlcZOyHXGTw4oLfjpWbrFuyb4g0EylavH2vxRgLqN8MQRAE0Uwk7TsjJg6IaBMHZFhHmylMGQdvWspW3ZpX0/z3Q9V55VunLMWkCQOG0QzTauEUTW08Lqa2gKkDQLICTVrGtmbTTUmv20VnJtFNMmzjw4PHmfGwv5eS7QtiGdpi9eiUPVpShOKz1oyYNqDCNgaTQWkDRFZQcYYgAhKqsh60wSWPMdqsh9unijZrFBrvmvQCxrQN47n+aVjaP+Pg/m2PVidQr8wsoavH9gZqToWGjj5hOyUnGBCuQAP4iw2gOQRHiGtMS2Ck6v5ypeVEhi43WTyuIaWmlkkhBxhBEATRrERLHOApCttOiQOAT7TZA8smY/ngLHQXBvGB89cIR00Rrqp+nSI5NLVZkDh+qobNjf00jG1A/nVTqOtz0UwxdJOMVOLMdFHpLWFoixhplkXaABnaiBxAxRmCiIivC8yEV4PLojAuUIZyO8q4ovgsAOC6HYsAjNwEvvPGozFtqIxd7W248IpnhVeGEBo5Wj3TJzkPRywnGBC+QJNEbAD1E/qsRUfI63D5d0mrMCPD+LvWJ2znXmQA6eQmMyJGmmmwdfdSvxmCIAiipWmKxIGw0WZAF364qhpt9oFjlsJurqNKFxARe1RkYGrLQbwZELZAAyQztjFaVTOFMrMBOTK0ufbnbAlDG4+qR2eASDMemXbqNZ/WBuo3Q+QVKs4QRA6I0uBSd0OwqDsZExoyGh3m541fi1mde7B9qBu/3T0f/Bfy8HAPDl0/FQAw9sRVCC80VNuOQqNUe+SFhu5mNCNCvBkQr0CTtthgpFWsEd8n1Hu5Coy0nF+ApfvLNs5MRTSRYYssN9lU3LUhg0gzWVNLcoARBEEQo5C0Egek5ofEiQMyEkab3bkAQ5U2nNy9GsctVGkZV1ObbBwjRVObSJ+w3eQFmlC6KU2DW9aaCWiywoyIT3/OpjS06SLNZOkDlhSFbbEwzpB8VouRZqq0gaSGNkobINKEijMEkRYJXWDODS55+C87Y7QZv88cbXZl8WkAwE93HouBSuOX8lO/OAEAsHz8IOYes93wA2QkNHhkE6mUhYaOZhUbPLIiio048H2dL3kRGECrigzm5NT93QPhVsrkJNIsMOQAIwiCIFqBWIkDB0ktcQCozh9Mc4x6Nm7pxu27FwEA/uCclyDv1yliO0cKuHqmVHsMaWqT0EwFGiCObjJpn7zpJhfSLMzIsE21OEjS/pxGmsXQJt4bEj/rHCLNdGkDaUWaEUROoOIMQQTGqsKuERoM0QXGsGpwaSs0nGiMNlvYtRWnjVuH4UoBN5QWSV/1/FMzcdTeAiqFAk5751LIhYY46chQaPCEiDcL7AQDWkNs6EhLRKjw+XmTCgwd3oUZkT5hO1ciQxVdKIqMGPFmqtzkAJFmMsgBRhAEQRDeRE8cSBRtJs4dVIYPWbRZJ/536WIAwBWHPo7OjrLqjWvI5kitY2pzwbdAEzp5AAibPmBDs+km23/TkIUZr6SBPvNLlIxaQ5tHpFnR4nJ6LcZYIDW0UdoAkROoOEMQkbHN+2dCw9sF1uv2MhShcYHxQkNk5Iv38mL1Wu/YMx8bhyZAJTTKz1Qvbtthm1AoNKnQ0JGyEwwIU6DJe5EmbXx/vhCFmSjur6RxZjzRRIYMVSNb/rhu2wZ1H62gkWbkACMIgiAIdzwTB8RoMy1Ros0AdbSZmZ/fNxObypMwrW03Ljt3PcjUZp864FOgAVrD2JYFMc1ssQszzkkDSQ1tsvsNdTSToY3h2GeGpyhsq9IGNJ/RJkNbUsjQRqQNFWcIIk3SqMzzX2IqF5h1tBnb33gjswNtuGRC9QbhTxWrZhi3XX8Cxg+XsbmzgPPeKX5htoDQEOkTtiM4wYDkBRrAXWwArSc4khRlsijMBHd/ZSoyZLnJsmPia11FhswBJjvOP0aMNOu1O7UJcoARBEEQzUzSxAFmalPdkGtIHMhVtFnjPGNoqB3XrTkZAPCxJU87vJdKT40eUxsQv0DjW6Qh3RTXzKbCK2lAJIShjVGqPQZZNQOkZ2jjUa2mAdKINNOmx3AkTRsgiLSh4gxBZIzOBebV4FJGsH4HIzcsTx/3MqZ07Me2oR48tG8edI6JPbu6cfiWqsNs5hkvwP7Gas6FRsZOMCD5cn3AT2wAzS84kly77b9XGoWZRO4vFXxhJprIUJGmyFBFmonEizRjxHaAEQRBEETecU0c8KZXsi/1aDP58f+882gAwLnjXsCCOaJIsekfITO88DSfqS12304grrENaG5zWxLNF9vMBiSMgG4aQxvb52poc6W1Is0obYBoFqg4QxARCOUCs8bkAuMpCs89o83eNOF5AMAtu4/BMMbUHatS/4X94i3VDOUXJh/ArHk7Fe/pKzQCrJ4RI5p4TJMvRgZOMCB7sQE0j+BIWlByERiZFGZMiL+jKpEho1R7jLpqZkgxPoTI4MlPpFlaDjBank8QBEE0FWknDqhIOdpsxdpxuGfvMWgrVPCJC1+C3RxIZlqTHVfNywKa2lTzyQSmNhUhY6EBN82UtEiTd90U4jpjm9kA+whoKaHjzHJpaNPdX3HBtkengaKw7RFp5gOlDRB5h4ozBJECuXSBeUebdaC7MIDX1SLNfr3rWKtLe+rueTh8PzBUKOCc9z4JeYayiK3QEPcxHISGDJXQyMgJlnexAeRTcIS6plACA4hYmAnt/mJ/AyXJ69m+oCKDIeYm8899CzWRIs10RIg0IwcYQRAEMVqJljiQebRZ/Q3O7z99AgDg/XMfR3ubbb9Ohs74wohoauNJYmqzSB1QkYZmApIZ24D86aaQminLwoxX0oBIkjizUu3R2J8z74Y2nx6dCSPNNLDPcOo3Q7QiVJwhiLTxbHDJEIWGFJnTIFi0GXDe+JUY2zaIdQOT8HT/bNhkKANAxwu9AIBt8zeiUMhCaBhuCsuERtIc5T5huwkKNEBysQHUT/DTFB2h3ze0wIjSYwYIE2fGI/vdN4rxpCJjUDFeRhqRZjKRIVBEvcgQo1AYFg4wqwbGNcgBRhAEQbQCuUgcsKGoOmAbbaaLT+XpxI33HIpt5QmY2V7CO87eCLmpTdav0zQ3CrR6RkZoU5uENDVTWsY2Rh50U1JcNVNmhRkZfZpjPnFm1sQwtDFiG9p0YwVsI800aQO2sO8EqaGN+s0QOYWKMwSRc0QXmIh1g0ueorBtFW02MujiCdUGlb/ZvRgAuyFszlC+7UdLMLZcxqtdbTjnHeIyfV+h4bJMX0AnNGT7bIWGoxNMRauIDYYoOpIKgdDnE3H9+dMSGEAA91cokZHKqpkkbjBTbjIU+5kDTIaFbdYUaVZDdIAxRAeY840nAXKAEQRBEM1C6okDvHnCK3FABYv8sY80YwwMtOP69ScDAD72mmUWr7AxtYlzqAirZyL37ATSK9AA7sa22Lop9PlC6yZbkqQMACkkDajizFTYGNqCr5rRGdrS6tEZKdKM0et2eifI0EbkDMe/JoIgbHlLpYKbCyNf/DevAN5yhDDoTgAXyF+/GM/4TfB60TjBmIaRm67TUT+BKEK4+crfeBRnER3owDDOGFtVTXfsORrVL2a7G6Z7tndj4eaJWDZzD2ae9SLwc9OEcAj1H1Oy9xE/xvZj5Gfgn+9CVUDx+yTsRaMA2wK1G56xCeqbsH0wTi42r54rjVtYi7nSVVKrsFDpsl+JBcpiHmM5jrJe+st+D5PeIFaRl6X8DNe/u1gCQ0VQ9xePKTO56UUGjxhhFiDSrGjxmt7ao6drlxxgBEEQxKjiVgAX1e8q3ANUztW/bCFWYSUWYAFWYhUWYi7WujWQ57VTEfVaqQd2XpM6OjEyd+lAda7TUdvXKWxXx//n3Ufj/73/LrxuwrOYN+s8rNnIzgOMaKRB1M9b2Hl0MFMbm8jw51DpKG6zhMY5j2zfVozccOWfi2xEfYGsD/W6aV2hYd4k00150ExAVUeMFs0E5Ec3RSnMiOR+1UyWhraAkWaStIEkkWYuKTUEkRW0coYgckawfgLeDS71nNDzMsa3H8D2oXF4rn8ed8QuQ/nF31UzlF+c3I+Zh5Vgv3rGRMarZ0QCOcF0pOkGA8I7wvKGz8+XVGDoSF1kyCgZjjegExk26ESGK7JIER5RfHhGmvGk1NRSCjnACIIgiFEKuxluuonulTjAU1QdUEWbiTcxbaPNgOdXT8BD+45Ae6GCT1y43OLiGPwcKunqGctIaJ6k8WYyMkwdANySB4ARTUG6aQTbf8OQSQMNWP4eHcSlPydPSxnaGAF7dBYV4wJGmknRxGMyKG2AyAoqzhBEFjj2ncm2wWW90Dh1bPUaHt63EBW0wTVD+cl75mLBfmC4UMA5733KMJ4nBaFRqj2GEhqRl+oD6YsNoPWKNL4/T4jCTOrur0xFxn6YRYYMsVDjKjJUDjBVbrJjpFkGTS3JAUYQBEG0EuINMekKUIsba870SvYFjTYD9Dcx9atcvvfMiQCAD/U+hvY2sV+nan6UI1ObDpd4MwkqU5tvgSaGsQ1oLd3kW3Sy1UxRI6Bl9AnbSeLMSrVH2X0EKXkytIm49uhkaNJJisJ25EgzW8MzpQ0QeYGKMwSRNTEaXPZKjqmEBiBxL6iXoh7f3QcAWLr/cLdr4+h4rhcAsP3wjUAhR0JDRol77pqjbEMOCjRAMrHRrIIjSVEmhPMreGHGRFOIjCTL8XWYRIY4xgA5wAiCIAgic6ImDniZ2nToYlRliQOduP7OudheHo/Z7dvx9rNehf3cqAlMbSIBUwd8CjRAPGMb0NxFmiTXnraZDchJnBnb1xSGtlCRZoBTpJkMTdqAaGizQWpoo7QBIodQcYYgImLlApOQ52izRd1rAADP9M9DYxSQTbRZJ37/wyUYP1zGq11tOOfSl5ALoSE6wUw3pEMJDSCVAk1MsQE0j+BIUlBy+ffx7THjlIcu0ids51Zk8Kj2ic9j5ybzx1S5yQJFYZscYARBEAQRnliJA0GjzdjNSFO0GSTbag4MtOO69acAAD5x8lLNSJ/VxRmZ2lJIHQDiFmgAP2Mb0DzmtqTXGcLMBmRUmEmaNGAkz4Y2nsg9Onm9pDG0qVJibNIGCKJZoOIMQWSFo9BgpB9tNiI0ugttmN5RnUysHuBP6pahvHtnNxZurt7sPOSsFzUjUxQaMsQoJyBOvJkDvgUaIL7YAPIpOEJck0tRJnRhpvVEBl9YFfcxVCIjVm4yUC86ZIRraslQRZrpIAcYQRAEMWoJmTjA6JXsc442MyG7mWnXr/M7d1abwF84/jksnLsX8n6dIoPIzNRWQiM6I1uk1AEgnQJNK+mmUJophJkNMP9/EkktaYCnpNkXZNUMI7ahTbVtU1zmC9ISisJ2JEOba9qAaGijtAEiS6g4QxBNgrUjIHi02QiHdGwDAOwtj8Hu8gQkyVB+4TeLAQDLi/2YdfhO5FJo8Mj26TBN7PqE7QBCA8iP2ADqJ/hpio6Q75uWwEhUmDGRK5EhohMZqjGu2OQmpxtppi2mY+Tz3vtGE0EQBEE0GblMHIgabca29eNf7BuPu/ceg7ZCBZ+60EYPZmxq44kRbwakVqBpdd0U+j1d/g1C6yZrzdQnbPskDdj052yA/T25rppJy9Bm06NT1E6ygkycSDMXgn0nEESKUHGGICJjVYGP4QJjBIw2G9t2AACwZ7gbAD/Zcc9QfuL+eThiHzBcKOCsK56yePcWExoyclagAcKIDYYoAJIKgdDn43H9uTMtzPQJ27kWGWJBlSETFuLffOzcZKCpIs3IAUYQBEGMFrJKHFBRlO00RZu5FGXq+a8nTwQAvH/uo+jsKENuahPnUuJcS9yfY1ObZ+qAVRN4DpuVGa66KRRp6KZQuJrZUinMyOgTtm2TBnhk/TllaG85NKuhTRxjIEmkWQ32Wa1KG9AZmCltgGgmHP6yCIIIwc0rgLccYR63ZPsLWDrlGO2YBViJVViIuVhr7lVxCEYmHNMwMuEoon5i0QPJZKL6zTqMLgBAR6GMkY8PfmLQCZdJQeHZXuDUPmw/fAMKbWVUyoMY+fIfqr0Hvw/c+cWPL/59+fH7MTIz4J/vQlVE8fu4zRIahRe/bwtGbshuxcgEg3++CfWTjY2oL5b1ofGm7bpCQwb25tVzpSJyLeYenKiIrMJCbbM8Nim2iVECRsRGjEzXvCzlB9xFlY1gy7Qw4xpnVpJeTpVEIkOGGFvouxzfFjE+JNtIM4IgCIIgwrIYz2AZjseReMluTjenUp1f9aJxTsXrpekYmS+Ng8S8JRVQEnidA4xonY7avk5hu8pP752Fb5xRxMz2Ei4/fx1+dNs8w/swHcVQ6SQ2ltdDvPYSdJKJvWicH9loJqBRN4n0oV43STSTCpNmAvRNxldiQS40E0C6yRpTBLSILmnA1J9TNLQpkwaa3dAmM7VFijSz/NtmuEaaEUTeoJUzBJElCRtcimgbXMpwjDbbPDQFADC1YzfGFvpre/0zlG/94YmYMFzGps42vO7dquWnFRzSsQcXjl+Nj0x+DH814y58Y/av8D9zfoz/nXstvjfnB7h65g341NQ7cVLPanRguPY6m9Uz/ARJQqn2aN13Q4KPEyylFTSAmxsMCLuSJk/4/FypFmZk9AnbusKMDSXueXCRIYsh5I+JiIUaV5Gh2s5PpBk5wAiCIAiinqSJA4kxRZuJGOsWqgEu/To7MTTUjmvXnAIA+MQJLHEgyeoZ9pqUVs/kPHUACBtzBrSuZgL8EgZsVsukGgEdOmnAmiSrZtIytDF0hjbZZ1vzRJpR2gCRN2jlDEE0AcwF5kwv9C4wHqkLrJ7S8ERsHJyKWZ3bcNq4l3DnnkXcUSYsBrlt3YShE3t3A72vFvHMobsw+fQXgB8vwri2vTiueyeO796Kxd1bsLhnK2Z07NNf2EHuwMbBIv57+/n4v9J5GPmKVa2ekR1XDGFOsBLcVs8A7k4wBTFW0ADuq2iA+K6wtPARTbbCLGhhJqn7K1ORoSOWyHCNNONzk/MRaUYOMIIgCGK0EzNxYMb8tep5F584wFPESNGBf34Qcd4wiMZVNX5zne/cdjQ+9/Hbcda4l7BowW48t2qC4RW61TNA4/zItHqGpQ4oKKH6b8KvnmH7bEkhdQBAYt3ko5kA0k0qTIWxzJIGVJQ0+zJbNeP6uaIztEXu0Rko0kwHGdqIZoNWzhBEClg1uJTcbMtXg8vqzcrf7DobAPCpqb9DZ6GCJBnKHSij9NOFeP2TZZx73wHcctRNeGThjfjfubfis9OfwOsmrMGMjn0YqhTwYv9U/HrXEfjPbSfh7zadi89veCM+u+Gt+MLGN+FftpyP3+w6DjuGxmJWZwlfPuQmfH/OtzG2sAfeq2eS5ihHcIIB/itoYqyiAUacU83mDPO95lwUZmTYur94THFmbF+UVTOD3DGRUIUaUWSwR12PLBnkACMIgiCIzAmcOHAQdmO/V3KM10u8GUM6D7CN/xLjgWwSBzrxysYe/H531az3R+c/LzlvktUz/HHTjWPF6hmeJD07I6cOAGGSB5LopmbCV+vZ/hvlqjAjkotVMwxbQ5tLn14gXI9OCUVhW0xtUWGZAtPsxU6C4KGVMwTRpCzEKqzEArULjGUo88hcYHyGMiBxgfE9WoAfld6Ey4q3YVH3K/iPQ7+Dv3r1CmwcmgBThnIH2tDbtRVHjtmO47o34/juLTi2ext69g1zYmsvUADWD47HM/unYVn/dCzrn44X+mdgf4UXMrx4GbnZ2lUA3jXpKXx2+m9x+rjl+Nah38PH1/0JygCcV8/IcMlRFgnkBAP8VtAA8VbRMPK+miaJGApRlAECFWb6hG1fkSEjl6tmQuUmi8LCrRkvOcAIgiAIIv94Jw4wZsHOSMXQtpsR+2Dycx02b2GrVMymlP987ARccuEyXDHrMfxZz8nYt198DxORVs+E7NlpQx8SraABkicPAO6raBh5X02TtIAU08wGJCjMiPgkDZQk52H7oqyaydLQJh7TrZ5xjDTj/97ZPZFezSVqSJo2QIY2Ig9QcYYgsuZWABfV7yrcA1TOlQ+3bnDJ6EXQaLPNQ1Px2Q1/in8/9J9wxrgX8fv5f4OH9x2BZfvnYtPQeOwtt6OzMIjxbfswvWM3ZnWWcMSYrZjfVUJnodxwvl3Dndg4eQyePW4vNswEfvWtN2P9nkmQO8pUVCckA5UOXF86FU/vn4dr5n0Xp49bjiuKd+G60oWwExqKeLMSwgiNJinQAGGKNIysREcoZ1pTF2YyFxm6VTMyYmYoq1bLpBBp5ugA04oMgiAIgmhh3lKp4ObCyBxIGm12J4AL6nfZRJs5wZvaeO1UhEO0mTgv4ucfsrlO1dQ2Uqxh21Vuvn8m+s6djt6OLfjw61/Gt3+lmnsyvcOizdi5ZPAmGDYmkqlNB6+bTJoJSKVAAyCqsQ3IT6EmhG6KrZmABL05Ab+kAZ4S91zsz6lE1q8p61Uztoa27CLNGKKhTYdv2gBB5AEqzhBEnpAIDUYUF5hKaAAKF1j1m/ahfSfh3a9cjb+c8X2cNu4FnDluOc4ct9x4CXuGO7FiYDJe6J+CZf3T8Uz/NPQNjEWhfRjn/cGN2NzZhpPf/xLW/8epijPYCY3nDszF1ze/EV+ZeRM+MfVW3Fg6tzaFMQkNDSX45ShHcoIByQo0gN0kx9cRxpNWsSZ0TIBLXEEmhRkTNg1XS9zz4CJDhyhCQokM2baq6W6TRpqRA4wgCIIgDqIztTG8EgcYvF7iEwekpjbtMhoBvl8nv4pGPb6CQXx3xWvxD8f8Gp9a9Bi+/asFqDeeiX1meHj9JLsONsZ29UxgU5uITYHGEt8CDZCOsY0h0zLNoJvS0ExA4Ajo0HFmSkObap/LqhkZaRjagCiGNh21ex6qv1f296X7u6C0AaIZoeIMQaSElQtMQioNLh2jzQBg1cAcfGTdF3B418s4bezzOGbMGkxu34NxbfsxUClgf6UDW4bG49WhsVh5oIiVB4rYMDQejc4PoDI8iBkvz8LmIzehvHg1gFORVGj8dOep+NS0OzCjYxfOGf807thzEhIJDR7XeLNITjDAv+ElkK7Y4LERA+KEK8185pACA4hYmPFtZpmayLBZNSMrviQRGba5ybJjmkJNxEizEJADjCAIghgVSBIHZDBTm3fiQLBoM9mKE6B+/mMXacb47u8X4q+O7sKirrU4/zXbcNeTquqGz+oZMYYtkqktR6kDQLgCDTC6dJNr353UCjN9wnbqSQOMUKtm+L5QaRraIvXo1KUNOEJpA0SrQMUZgsgDntFmzAWmhLnAehE02qxK9Qv65YHD8fLATNRPLvgbsbJGk43cfd1JmPU3v8HLPQWc8rqX8djthytG2gmNYbTj97uPx/smP4DXjl1RK84AcqFhKThKSM8J1oeWFxs6smiW2TQCAwjj/ipJxrJ9xjizGKtmxNe5NrRkuOYm840seQeY8JlQFN4mUKQZOcAIgiAIwgNN4kAwEkebqe7aylbK1PfrlEebdWJbCfjplpPw/hkP4dNnPY27nrwQuVs9w5Mk3kxGhgUawN5c08q6KbRmAlIszIjYxJnxWN0nsTG0iWNcVs0wsjK0abA1tDFSThsQDW2UNkDkhbasL4AgCHsSOwNkX378l2RRONZQr2A3LGWFDHajU/wCF2+KdgrbVV59ZSKO2VY974KLl9X28hMO3Y1a3kkyMrF5cv9hAIDju1/BiNAAN1ZEduO5cRNA/cSspLgsXayUeDPd1p2nWLaty9/VTXaB6oTZZtLMWIkFzpPyPOP689j8e0Vbkg/4u794StxzK5EhI8aqGR9kuck8CXOTeVKINCMHGEEQBDHaEW+Y2a4Q1fUcYDfJ2c11diP+4M16nZmC10u8OUM6L3BYYdIwJ7G7Efqte6pR128pLsWhM/otXmEz55LN1WRzPcvYtpJhH28Y0t0g1xmQGH2SfRE0E2BXaOBpJd3k87NkWpiRoUsaSNXQBtivmhFfH8rQJkNmaGPbnpFmMmRpA0KkmZg2YFPolBraCKIJoOIMQaSIVWVeUuG3aW6mFBoyZF+GYg5oURwQU2jUVuH8fjEAYPmUfZg+Z7fF++gnImsGqkpqVud24YiH0GCHS4ZLkkVGMUzOHPGme59iXM7ERjMKDt9rT1VgAGHdX6Y4s1LtMdVVM/x+8XxJMBVqGKLIkBChqaUN5AAjCIIgCAmSlaO6G3LOPTt63YYfRCmVxJuZzMym00p648ljzxfx6P4F6CwM44/e8GJtr8zUJs6rxHgk3kjDk5KpTVWg8TG19Un2eWqm0MY2oLmLNL6aycbM1lRJA9b9ORmy+ww+q2ZCG9pk2xF7dNqmDVjCPtOtDW2UNkA0AVScIYgMSdIngH0pGR0EzAXWKzmmWjJuvfRcJjTE4/Y89PvDMX9/BQNtBZz3/idqe/2Fxpah6kRhWvtuABWohQa/TyE0eEq1R1ehISKbIOagQOMqNoDmKdT4XqPtv0vUwoyM0HFmRmxEhsuqmTRyk1U3QWSrZwJFmjEUTS1FB5hzpBlBEARBEFqCJQ7w5gtV4gD//CCS+UQDYmyQamVvYyHnO0tPBgB8uPdRdHQMG94HSHX1jMzU5rpSO8MCDRDH2AY0n2by1U0mTP++0QszNkkDPE79OXeJB5B81Qy/X3Y+F1or0kyKxNBGEHmGijMEkRcsXWD5jjYDRm6EyoSGLtqsE0ABHc9We81sm78RhbZkQmOgUj1/W6GCNgwIR3VZsJ5OMBUuTjAZfYr9kdxggH+RBsiX6OCvxfd6bAVG9MJMGnFmLS0y2LEUIs16/d6CHGAEQRAEUSU3iQOMaNFmgNq1rub6O+dic3kSZraX8J7z19f2prF6xtPUptqnM7U1SYGmFTQTkPx6QpjZgJQLMyLecWYioVfNJDW0ifj26MxfpJnNZz5AaQNEvqHiDEHkkTQq/dGjzQAfoXHrtUswabiMzZ0FvO69z9f2JhEaVdrAvnxthIYFJcM+l3gzX6EBRBUbgJ8jjCdEcSSr98tEYABxRYaMUu2x6UWGDF5kiAVjmcgQCBhp5oJtpBlBEARBjBZMiQM20WaJEgdskEol2XxDXN3LY9evc3CoDde8cioA4A9PfAJ2hFg9I+6LGG9mQ4ACTdLkASBZkQZIXzOFfE+Xn91kZgtemDERLM5MZmhj+wZR/weRhaFNFpUoe+7aozObSDMZ0u8AMrQRTQIVZwgiZTJtcNkrebFNtFmiDGURvdDYu7sbvesnAwAmnWbnglAJjXFt1dUyB8odGEIFjf0x+AkTvx+ILjSarECTtEjDEEWArxgIdR4ZoQQGkFFhZtSJDNm2b25yvEgzn6aWJsgBRhAEQYxKLG+4pZo4UISAwlkuxa9f57///mgMVdpwes8KvObonbVjaa+esaDkNjxx6oAK1fwbYZIHgHC6KaTWiaWbXDWTV8oAoP3/VodMP2cSZ6brz8mTR0MbELVHJyNhpFniz3aCyBlUnCGIPBG7wSUjkdAA/DOU7YTGwzecjLZKBS+OL+O40/yX6U9or86mdpfHCOP4iYwsR9mAS45yGk4wIFGBJosijYhKNMQswoiEFBhApMKMCZfCDE9Ti4xIuck81n24QA4wgiAIgkgTz2gzhlW0mSlxwGWeUGdqE3Gbq6x9tQe/KZ0AAPjT85dZvspn9UwAU1up9mhjahNJwdQGhDO2AfF0k6tmiqGbXH82b80EhO3NGSLOzBrVPYa8G9rE1IEEPTplaNIGVIY2Hb5pA2RoI/IGFWcIogkJ5hSwiTZrIEm0mR2rnpuOY3Z2AQAWv/1Jy1c1Co1DOnYDALYMTeTGqG4sOwgNnpLlPoarEyxQgSaG2IhVqEmbGAIjWmEmqfuLp+Q4vqlEhuwR0OcmC/RAXaiWRZqRA4wgCIIgguKbOKBDTBxowCdxgMcYbcaTJNqsuu+b978GAHDZ9McxY8oByXuHWD3DkM0FDea2JkodAMIa24C45ra08dFMqRVmYiYN8JChTY9PpNkcfaGEfWY7G9o4Qnx3EERMqDhDEHnF0QVmHW0mwzvaLEaGcnXfhjsXAwBemrEHU2fulry3WWjM6tgOANg4OFEyLiOhIZJSgQYILzaA5hUcPgWmTAUGEMb9JYsza0mR4ZObrFgRqBIZOgyRZjrIAUYQBEEQFjRF4kDMaLMqdz85BUsP9KK7MIhPX7y8ttdkfPFdPcOQnd/T1KbDZDzKqEADuBnbgObVTIDftdv8++SuMMMj0/Xe/TkZzWRoY9sBe3Qy0jK0UdoA0URQcYYgMiCkC8xZaPTWHoNEm9ngJzTu+9UC9PZXcKCtgNd9+PHaXjeh0dtVLc6sGyxyrw8oNEzxZvz+JE4wIHWx4Ss48iw6klxjpgIDSNH9pRqQJ5Fhg/tnTh02CwQTNLUkBxhBEARBpENqiQPW0WYyU5t4nEffrxMo4FtLTwUAfHz+Q+jsKGveO/bqmcA9OwH3VeIBNVNMY1ueNRPgf522ZraoukmH7neoZLnvILL+nIxmNbQBQXp0ytBEmjGYoc0GX0MbQeQRKs4QRN6wdIFlKzRUdy5dMpRNQqMNnc/MBwBsmr8R7Z2ySYh+9cxRYzYDAF46MF0yhn894CU0eErqQ1IyLtDEcIQx8iQ6kl6LreiKlpUMuBVmbCjpDjaDyNB9dqhyk/ltmQNMUmwuKp5nGWlGDjCCIAiCaCRW4kCv5MUyc0ZR2G6QSuxmpkpDsfmLX7/OH94+D5vKkzCzvYQrL2Q3N2X9OmHYJxvjYmrT4GtqE4mhmTIwtgH50kxAciNbEM2Up6QBHuukAZOhTfV3l5WhjSe7SDOx5xf7bGaf1UkQDW2UNkDkESrOEESe8WxwGTTajMc7Q1kUGiaq4393zYmYPFTG1s42XPQB1uTSdvVMBUd3V4szLx7gK1C2q2eMSwrUw3ycYDJkYkNGn2K/RmgAccUGg5/kpyE8Qr2fi8AIWpgxLcsXEX+HmkJkMEKKDJ2QsIk0A4JFmtVIM9KMIAiCIFoV18SBTKLNeIrijvj9OgeH2vD91acDAP74xEcA6G4+ZrR6hqekuTxGmqkDQOICDZBMN6WtmUK9p8vP7KWZgOySBth+y9sCIwPF+wyiVkrb0CbbFqPoE/To5ClyzwNGmjmnDZChjWgyWq44s379enzzm9/EG97wBsybNw9dXV2YOXMm3vnOd+KRRx7J+vII4iAxKvbWzoLe2mNqGcpAo6PdTP/eLsxdU70j2nGSKcOnfgIzs2MHiu39GKoUsGqgCPnkR7yxLJt5eQqNEE4wwC5LGYhaoAGSF2kYogjwFQQhziEjmMAA4hRmkrq/2P5MRYbPUnyGa26y+FkjigwJSSPNAjS1NEEOMIIgQkC6iWh6LG/A5SvaLE6/zm/dciT6K504cUwfzjlxe+1Y0tUzGZraRHJQoElTN8XWTKF0kw3eZjYgTmHGhpLuoM39AXFfloY2VaSZbJxHj86iMERmaLOINLMh2Gc5QeSMlivOfOtb38JnPvMZrF69Gm94wxvwuc99DmeddRZ++ctf4owzzsANN9yQ9SUShJRMXGCMqBnKsi94e6Fx749OQWelgtVjKzj5DX21Y2ahcdSY6gzs5YGpGKiI1yCbHOmEhji50uQolySXI+LiBJPhU6DJkdhQoRMPsd1krkWZoEvyAbu85FDNLKWkLTLEsa745CbLPosSRJoxIkSakQOMIIg0IN1EtBSxEgd6JS9KFG1mwq9f56Zt3fjp1pMAAJ85+ynDaNvVM6oxgLepLUm8WcYFGsDe2AbE00150EzRzWxAvMJM0P6c4sBWMbQ5orpf5BBpJqYN6IzHlDZAtBrexZlLLrkEN910E4aHh0NeT2JOPfVU3H333Vi5ciW+//3v4+qrr8ZPf/pT3HXXXWhvb8enPvUpHDhwIOvLJAg9AV1g6UabqWDRZu5CY93qIo7dVj33/IuWGkaPTHCO667O0l48MJ075iI0ZDeWZX04hJfxpO0EA5pebKSN688RXGAA4d1fJc0xICciI1ZusvgZYys0PCLNNA4wUWQQBDF6yatmAkg3Ec2Da7SZDmZqc+5lkGq0GeDerxP4xp0nAADeXFyKw2bJ5nFJVs/w2JjaUkgdkOFSoOlTnMNCM4023eSjmRKlDPRJ9ttoJhGfwgzbL/3V1f1+838L4v4Qq2bSNLRZ9OjsQSqGNudIMw5KGyCaBe/izK233op3vetdmDNnDr70pS9h5UpzlnoaXHrppTj33HMb9p999tk4//zzsWPHDjzzDC2FI5qIhA0ulfTWHlONNgN8hMby3ywBALw0ZT8OPbJU26ufqJzUsx4A8OS+QxXXIU6ShqDvsRFZaOSoQOMjNppJcPhccyKBAcQtzDSdyBDH2qL7rBBjE2VuMY/cZBkekWbkACOI0UteNRNAuoloESSmNtkNu+aKNpPdMNVRPf7ki5Nw/96j0FEo408vfs7wmpCrZxxNN9ZGoRquqQMhdJMheQBwM7YBzVmk8blmK80UImUgSQS0ipLuoE9/Tp60Vs24ktDQxpOHSDNKGyCaEO/izMqVK/H5z38ebW1t+NrXvoajjjoKF154IX784x9jYGAg5DUGo7Oz+oHT0WHTGJgg4pOLBpeMzDKUdVRf89hdvThqbwHDhQLOvvIxw2sG0VkYxpKeV6uv3T8TdkKDIbsxbYkq3qzJCjSAe5EGyHehxvfarP8dmqEwIyULkRFy1YxNbrJujCE3mX9ODjCCIDxoRs0EkG4iRhf5izYDfKOG/vWxkwEAH5rzCMaNlc29QqyeCWBqc4k3E0mrQAMEN7YB+dZMQGTdFEszAW4R0CXueaI4M4bufoLKuJaVoU32yB937NHpEmnGMESa6SBDG9GKeBdn5s+fj6uvvhpr1qzBTTfdhDe+8Y249957ceWVV2L27Nn47Gc/i+effz7ktSZizZo1uP322zFr1iwcf/zxynEHDhzArl276v4jiEyIGW2mwyvazDVDWbaMVs/+R44EAKyZswXjJrKIDbnQOL57M7rbhrF1qAerByYrzsjfVG48R2KhYYvJwROzQGNZpPEhD6Ij6TVYF2XSKsyIuLq/WkJkyLYj5yaTA4wgiAQ0m2YC7HQTaSYil3gmDlgTPNosfL/On983Cy8PzUCxbS/+4KLViotj+K6eYejizURSTh0Acm9sA+r1Sla6Kek1ZG5mA/wjoKMa2lR/X1kb2sR7MSZDW8Aenb26axyBfTYnMSCToY1oJryLM4z29na87W1vw80334w1a9bgq1/9KorFIv71X/8Vxx9/PM466yz84Ac/QH9/f4jr9WJwcBDvf//7ceDAAXzta19De3u7cuzVV1+NSZMmHfxv7tx8uhmIUUboaDPRBZY42kwFLzR0mIXGLT86HocMlrGrvQ2XfPRx7dlO7qkWoZ7YPxtAAXKhwSNOklLMURYx3ZAH3IVGn+ZcEcUGQ5zwxxAeId8jmMDoUxzzyUsGIri/mlFkhMxNFhB3FSVvT00tCYLwpBk0E2Cvm0gzEWnTtIkDKffrLJfb8K3nTwcA/PGxD6CtrVw7kmT1jIupTdxn2bOzpBjD78+6QJOBbopBJpopy8KMa9JAw6+zqeCouo8AjOgkmVbK2tAGuBnawvbodEFmaJN+xpOhjWhSEhdneGbNmoUvfOELuPrqqzFr1ixUKhU8+OCD+MhHPoI5c+bgn/7pn1Aul80nCki5XMaHPvQh3Hvvvfj4xz+O97///drxX/rSl7Bz586D/61dS818ieYh+2gz3RJY3wzlKuVyBya/WJ387TzmFbS1qZfpnzy2OpN7fN8s2AkNnsA5yiXueRpOsByLDR5ZMSXJfyGILjCAHLq/xO1mExkybHOTDZFmPClFmpkgBxhBtAZ51EyAm24izUTkipiJA4p+cgDM0WZF8WA6/Tq/e8t87CiPQ2/HFlx+3nrD+W1Wz/AENrXJ7nvLjEaMLAs0gJVmAsLpptCaKYRucvrZTP9efYr9MQozNmiTBkz9ORkuCQNZGtrEIk3kHp2MBJFmBNGqBCvOLF++HJ///OcxZ84cvOc978H27dvx/ve/H7fffju+9rWvYfz48fjiF7+IL3zhC6He0ki5XMZHPvIRXH/99Xjf+96H//zP/zS+ZsyYMZg4cWLdfwQRkxAusGANLhne0Wa26HrQqIXGLf9zCiYMl7Gxqw0Xve9Z6ZnHFIZwUq3fzCP7DuWOJBEa4r5A8WYl7nkooQGkJjZCFWqyxPlnSSIwcuP+4pG5GZtBZDDECBBZ4YYXGRaoHGDsc5EcYARBJCCPmglw102kmYjcEjpxgNFbe3SJNrPCpV+naS5TPb53fye+98oZAIDPnfIwAKY3Q6yeUY33NLXxlLjnOsORayy0iiTJAykXafKAs2aKnTLgGgFd4p7Lfr+8kwb4seLqMll/zrQMbTaoohR5AvXo7LW7IkobIEYjiYoz/f39+OEPf4hzzz0XxxxzDL7+9a9jypQp+Od//mesX78eP/jBD3DBBRfgz/7sz/DSSy/hzDPPxLXXXhvq2rWUy2V8+MMfxg9+8ANcccUVuOaaa9DWFnShEEHEJ8BNOesGl8GjzVyEhonqa/bs6EbvuupFdZ32ImRC45SedehpG8bGwXFYOTCltlc1qZGtABALN6LQCJijLJKXAo2l2ACaV3A4X7fNv0ufYr/q/0Vm7i+XOLOsRYYK2c0JWW6y+PnCPpMmNu7mKVpeRm/tMXRTS4IgWoo8ayaAdBMxOsl/tBlg7gWh5hu/PQb9lU6c1P0yznvNdsPokKtnxH2OprYS9zxJLLSsb6fLChogmLENaF5zm9d1+2omIFlhJnjSgAn+/oBsNY2J2Ia25u7RqfuM1sVWApQ2QDQf3rPuT3/605g9ezY+9KEP4ZFHHsHll1+Ou+66C88//zz+9E//FJMn1zfhHjNmDC666CJs3WprY/CHCYxrr70Wl19+OX74wx9q+8wQRFNg6QJjX2LODS4ZqQoNftuOe659LTorFawaC5xxSePPePb4dQCA+/bORaOQEIWGiHgT2uSMAbxylF2cYLEKNH2KYwzPIk2eBYe3uLARGH2KYzELMyXuuXecmbjfFFNhIrTIcMlNdhUZ6UWaOTe1JAcYQbQMedZMAOkmorlIK3Egu2gzU79Oce6j7tf56tZu3LD5FADAn5/D9+v0XT2TxNSWQeqAal+MAo2DZgKaw9zmdY1JzGxA2MJMkKQBn/6c/PO8Gdoi9OiUETjSzDolhtIGiCbGuzjzne98B1OnTsU//uM/Yt26dbj++utx7rnnal9z3nnn4a//+q9939IKtiT/2muvxWWXXYYf/ehHJDCI3BOywWUwTEKjgdgZytV9G1ZNwjFbxgMA5lz0NHe8Osk5ZxxfnOGR3STmJ0vifkbAHOUS9zzrAg1gFhuAs9gA8lWoSXQtNj97n+ZY1oUZ58xk2QvzIjJscpPF8TKRYSBvkWYc5AAjiOYkr5oJIN1EtBiGG3P5jTaz6dfpZ2r7p9tOAABcPHEZjjl8j2G0afWMiIupzZES9zxJ6oAKn96dfYZzJijS5EE3JbqWEGa2NAozKrwKM7r7BTZ6J0tDGxCsRyf/nCLNCMIbn1BCAMBtt92GCy+80Ok1Z555Js4880zft7Tiq1/9Kn7wgx9g/PjxOPLII3HVVVc1jHn729+OJUuWRL0OggjGrQAuMg9bjGewDMdLjy3ASqzCQszFWqzFXMyYv7Y66ZpTqU6kelGdLM1C48RoGuSToCIMkx1RYMhmPWwSwj6K7CYjy35+MvDJe/DCpAEcc9JGvPBE9Vt/XucuHNa1G4OVNjy0dzZ3TnHiMQT5xx97/w7JuP2o/kyy8/GwcfpdDZRgH6UEVP+fiBOgTZDfPN4ItdO/r/bYq3kvNtnWuQYViJP7gw7ESAQRNkmLMkCywoyJEvc8qPurmUSGDNkNDBHHSDPdjZbe2mPMppbkACOIpievmgkg3US0MHcCuMA87Ei8hOU4yv99DkF1bsfrpekYuTk8DiNztR4I8zM2H7EtZnSifk7Vgepci+1n2yM8t2oCbtl1PC6e+Az+/A1P4yP/xT5XeD0j00WyfeJ8rpPbz57zooftZ/vY4y5Uf3aDZiphZE62FyMGGn4/UP235g01okaSaSbAXzf1Ko4xmkQ3paaZgHhmNtk+XQR0lDgzhiyxI6+GNrFIIxraLMgw0syEyehMEHnEuzjjKjLSoq+vDwCwZ88e/P3f/710TG9vL4kMojmRCI0l21/A0inHSIcvxCqsxAL392FCgyeK0JAVZcxC49mHZ+HdV3Tg+UlDWPLux/HCE28BAJw97hUAwBP7ZmBfRbYwkAkNvgjDvzc/GVEJDQj7RKFhoAS50BCPmYSGap9OaACZiQ2GTgjYCpBozrIQAgOIKzJKhvcGDIUZFSqRIa4oy4vIEHvKuE5nLCPNZHhGmskgBxhBtD551UwA6SaiOXlLpYKbCyNztptXAG85Qj2+cA9QERar6UxtDKWpTQdfpCliZN7GPwcgEVCon++Ix0StJO4TGSnkfP2+k3Hxm57BFYc8ii9OPRmbt43RvE40oYnGNdk1+ZjamrxAAzSFbhrVmgmIHGcmO4Gok1RkbWgD9Ia2BD06PSLNdFinDWgMbZQ2QDQD3sWZvHLNNdfgmmuuyfoyCCJTErvAGLJVM0U4Cg22H2icvPCIhRIZI0Jj4+2LgXc+iZem78HM3p14tW8Szhm3AQBw3945wutMq13Esex6AL3QCOwEEwldoAHCig0gkeAQyWw5f2yBAeTI/WWTmSz+DedBZOgIlJscOdKMmloSBJE3SDcRLYdl4oAOljigpBfqxAEZvKlNikwnicgKJuJxuantjsen4ckLD8drul/Gn17yIv7iRydw57FZPeNrauPFj2yfBp/UgTQLNICdsQ0IUqSRkYlucolt6zMcb9rCDG9oY/tkWilrQxtDZmhLuUdnr/4dWNqAs6GNIFoI754zBEGEJXaDS/alp2xw2VsbGCVDWYZLhnLjjdr7bl6IBfuBgbYCzvvgo+guDOGUsdXZ3EhxRtfkUsxRFidM4rZrjnJK/WdU+zbBrw8NYNeLhuGRr5wL2HWn4fxqGpHB78ubyMhJbjKjt/ZITS0JgiAIIp9IVqLKbvAx84TuxqAWdmOfnzfIYn4AScHBtV8nw2blMJsXFfAvT7wWAPAHhz+Isd2mOZnt3I5fUc2QndujZydPiXsuFrlK0CPr6ejau9PUv7PPcA2MZtVMgLtm6tMc1/2bxtBMKhK1RjL155ShMrDFMrTJdJGtoc1AXiLNKG2AaCGoOEMQzUCABpfOyL5U+S9f/ktZKjQsvtgPIt6A7RD2y8a2YejR6uqgNfO24Oxp69DdNowNg+OwcqAoeZ2r0OC32aRJtvJHJTQUqCaCMQo0QLICTZ9hDI9rsSMLfK6xD2GdX0BGhRkTPiJDHB9KZKiImJtsgn0eUqQZQRAEQTQdphWqJhpMbTr4Ik1R8RxAo1YSV/vyN09dTG2N/PjOOXhlaDqmtu3GJy5ZzR0RjWkiOlObOI6fB9qa2jQFGhtTm3hMtuLcxdimIoaxLc+aCfDXTTrSSBkoCdtOSQNJ+nNC2Bb/JmzJg6GtySLNOChtgGhWqDhDEM2G5c27JE3U6pB9yRYVzwHIizImoSGOteOWa4/HrIEydrW34b0nPg0AuG/vbJgdXC5CQ/ZctqyZf/RwgpmOxSrQhC7SAPkSHb7X0gc7gZFlYcYJH5ExJNlOY9WMTW6yTGQkyE02RZolJNjnMUEQBEEQ+sQBy5WnLokDDfTWHnWmDZWpTYqPqY0hmtoaCzjD5TZ88/kzAQCfWXQfOjqGDe+RxNTGsDG18SQo0PBkWaDpM4wRyZNmAuLrJhW+hRmRkrCdKGlA3K/qz8k/F/8eGLFWzaiwMbQB9YY2GfmINDNCaQNEC0DFGYLIEU0XbWYlNGzEhiz71EJoDHdg0vOHAZUKFu7YDQC4e8+hmvexmfjIJlYZxJuJx4A4BRrALi+7D+6CA6if5KchPJK+Xx/sfk5XgQGEX5bPH0ucmcwQCymyyAr+tTFFhkxYiJ8Vrq3z4kea6SAHGEEQBEFkS9TEARtTR4+4IYs2U5naeNx6R/zX7+ZjS3ki5rZvwwdfv4Y7EnL1jOy5ranNEZfUASBcNHSMIg2QvmYK8Z59SG5mS1KY8enNaZ00oNunM7QB6a2asTG0iag+TxL26GREiDSjtAFiNEDFGYJoFhI4AqJFm/FYZyiLQkPETWj85nsn49hNZUzdBQy2FfDIvpm1I75CgydEjrKCPBZoYhZpGKII8BEEIc4h0gf7okxWhZmocWYmkcH2Q3FMRUyRkXJusmWkWZCmluQAIwiCIIjkBEwciBNtJmKKZFXdWOWPA7L50/7+Dnx7xVkAgM+feD8KMP0caZraPOLNgHgFGiA7YxsjlN7JWjepSKKZAL+kAaM8StKfkze0yZIGsjK0dQjbNjj06GQ4RJrZEMLQRhDNBBVnCKIZcWxwmZjE0Wa6G6V8JJGH0Ng7Bmc9OgEAsOKwCg7ANNE0TYDESRUk26YVCJbxZrELNK5uMMBObADJBYcMnXiI5STrg/3PYRIYKudX6oUZnlYQGQwxN1kVjShikZssI+1IM3KAEQRBEIQ1aScONNBbe/SJNkvcr5Mh6ykhMjJP+uZvjsTOylgc2bkRl52/gRsT09QWMd5MRDwWu0DjUqTpsxhri61mCqmb+pAPMxvgVpgpwYBJt9vGmemwMas1gaFNhkekGftMZZ+xidAY2ihtgGgmqDhDEE2Oq9BgRIk2y0hoTFtaff7gEW24+IPLuDE+QkNEvFktExo6LPvPiJh6i9gUaHT7Q4gNII7gSIM+uBVlQgkM2f7ghZmkcWZ5ERkMVW6y7PWeucn88xxHmhEEQRAEYYFhBWrq0WZFxXMAfv06G+OeG2mcV+3c3YXvv3IGAOCLpz4AJF49YzK1BYw3C2VqA9wLNKF1UzPRh7CaKWZhRoQ/5pw0YCocqoqTuv6cWRna2LiEhjZTj07PSDPntAEytBEtCBVnCCJnhGhwKYN96Tk7FFyizerQ3SBVCQ0eu2W3k9oO4PiO7QCAJxcW0HbaSwDKitGyiZDM4aUSGrLnJqEByVjFrpJwTLdUH4hboAHcxAaQf8HRh7ACAwhfmBExFekasM3vVsWZQdjOUmTwgkInMlLITQ4YaWaEHGAEQRAEEY6A0WYHcY02s8a2Xydg7tcpH/tPvz4W+ytdOLG7Dxeftpkbk9TUFjneTNxd4p7HLtAAdrrJlj7kWzf1wf36fDUTEK4wEy1pwGRo0+kfXW8m3T4droa2DHp0MhwizYKlvhBEE0LFGYJoVrKKNvPKUOaXzOpQNbsE5DdsO3HWuA1oL1SwujwBeydU8Eo3cMHl/L+D7c1i1TgXoRHYCQakU6CJVaRh/2WF73XYFmViFGZK3HOd0HRuZmmKMwP3vNVEhmEFX1GyL2KkGTnACIIgCCIbkkabNdxg7K09MhOHydRmTByQITO18bj169y0rRvXbjwNAPClsx7RjHQ1tcleKz5neJraRErc85AFmrSMbUA+NBPgfx1JzWyZFWZc48z2w/w7b2NoU237YGtoY2N1kWaGorDJ0MYIEGkWIm2ADG1Es0HFGYIYxXhFmzESCw0ghNA4d/x6AMAdO+bhyA3F6jue86zmFTqhIeYo85MumxxlHZ5OMBHZMVmBJrQbDPATHEB6xZqk72P78/k4v6IXZnh83F95Fhm6sZ65yUXF/hxFmhEEQRAEoSd24oA3Uft1Anr9BKhMbYx//M1iDFbacc64F3HG4u3cmLyZ2gypAyKhCjS6/TGMbYw+yX8xCPE+aZnZAP/CjBKdgdIlzozvzylD1Z9TNsYWF0MbG29bvE3QozOtSDOeBJ/xBJE3qDhDEE1AqAaXiaPNgmcoy47ZC412lHH2uGozy7v3Hor7rz0NnZUKVo6r4Iw38z+jTkCYCJWjDMlYze4S99ymsWEoN1jMIg2jz/I/3/GuuBRlYji/gIR5yfyBpHFmWYgMFfzfO79axlZkeOYmM7KINOMgBxhBEARBBCJG4oBvtNk4yb46kvTrtKNvw1j8ZOvJAIC/vOAxzcgQpjaXnp2BUwdkx0MUaIB0NBPgpoNcxvqStCgDxCvMiCROGnDpz6kztImEMLQx7Hr0qg1tIjmPNKO0AaJFoeIMQeQQ7U04jwaXwfM7o2coM/RCY0nPZkxqH0BpuAtP75+GV5ZPwbGbq4pn1kVLhdE+Ocq6lQERnGAiJe55yAKNbj+QXpHGRB/irrQJUZQBkgkMIOCyfNk+WZwZIy8iQ/w7txUZbJsXGTIsRQYj7UgzHnKAEQRBEERUMo824+kRN8TEAd7UJuvXKV8ZU78tn1f9w62vQblSwBsnPY0TjuQ1SmhTm+y57iY4I0HqQOgCTShjW2zdFAuX6zdpppiFGf54sKQB3/6cKkOb6Z6ELbKVMipDm87UFqBHJyMnkWYE0YxQcYYgmpkAzoHsos2SC41zxq0DANy3dzbKtY+zJ/7vtWivVPDShCGceP4azcUzbG462wiNCE4wILsCDWAnNoB0BEcoXK/V9G+QVmFGCf8L4yIyTE5HF5HBE0pkMGS5ybLPCf6YTyEY0SLNjJADjCAIgiC80UabJSD9aDMel7mMWISxM7k8t2oCfrvzBADAX7zhCWFcUlNbiHgznhQLNDGNbUDzaCbAXTP5mtlSK8zYJg2I401Rz2mRxNBmIkGPzpxFmlHaANGMUHGGIFoEXxeYNcGizVRLaEVU0WYjnDe+emP0nj2HHdz34tJDsGh7NwDgiLc+LrzCR2iIr5cJjYhOMJEYBZpQYgPIZ6HG55psBEaahRln95eIbZyZDyFXzTBkbjBVQUZF+pFmMsgBRhAEQRAZETtxQDS1yTBFm1n361SZ2nhUGks+f7rq9lMAAJdOeQJHHbZH8VoeW1Mbj87UZpq3WqQOhC7QAMmMbc1ubvMxsoU2s0UtzMj2ucSZgRubh1UzDJOhLaUenQyKNCMIJ6g4QxA5JUmDS53TIFQ/BPcMZZ0bQxQasmP1zOzYgyPGlDBcKeCBvYfWHXv+Z9UM5ReKB3D86Rt0F1VDJTRsc5QZOXGCAWHdYIB7kQaon9ynKTySvG9SgRGiMCPi5f5yiTMTj2UpMnhBoXKA6kSGSLqRZuQAIwiCIIickuDGnjHajOETbVaHytSmu4lq06+TZ2TfI89Nxu93HYeOQhlfeXOI1TPi63VzStl76UxtlpSEbZsCTUhjG5DM3JZ2scb3vW01U0wzm3jcqjCTNM7Mtj9naGIZ2gL36OyVv0usSDOCaBWoOEMQzY6l0NA5EqyjzRJlKIvo4od0hZrq8VPHVosuz/ZPw87ymLqxSx+ci0WlTlQKBRx72SPCOVyEhuxYjHizFAs0QLpig0ec/IcQH6HOGUNgACmJDHFfK4oMNt606s4iBkRWRJYVZUSR4RBpRg4wgiAIgsgvtokD2UebuWIzV6rnb24/DQBw2dTHPFbPqOLNRFOb7LnLvNUxdaAkH6Y97lKgMR1j+Oommb5JqptCnTOEZspNYcb0miT9OcnQpiPRZytnaBPTBsjQRjQrVJwhiBYierQZI7UMZUaj0HhtrTjzyD557tDyn58KAHh+cj+Oe+16i/ewERr82JDxZqrjSL9Aw47ZFml8CzUMlVCw/S8Jtj+D6d8j08KMq/vLlJmcF5GhGsNvM5EhQ1ippxqmc4BZQg4wgiAIgsiW0IkD0aLNovbrNK2ekff0fOiZKbhtt+/qGR0uprbAqQOAvm+neJzhkzwQs0gjkpVmAux/hiw0kxU2SQP8L5EpQcOUThCSnBvaGL21x4SRZtK0ATK0ES0OFWcIoonITYNLRpAMZR+hUcFpY6uzzEf2zeGOjYx96t6R1TOLLn9UOIev0IgRb2axVD9EgcZVbMBwjCeU4EgLl+v1KWKpspJLwr7U3F+MJJnJoTGJDP7zgN+WLdP3yE3mn08THj0JFWlGDjCCIAiCiESAG3yJos0YRcVzAP79Ohku/fmq/M1tMVbP8GNNpjYdnqkDQLgCDZDc2AaEM7elhcv1ZmVmAxwjoJMmDYjk3dBmGh/A0Gbo0ekbaUYQowkqzhBEjtHelONdYBKh4eoCCx5tVhR3hBMa8zp3YVbnHgyU2/DUfrXyWfGLkdUzi07T9Z7RTZx0OcriWP55wKX6IiVhWzZhFccAccUGkG/B4XptMQUG4FmY4XFxf7mKDBl5ExkqEuYmMyjSjCAIgiBaGtdos8SJA84mEFW/Tp2pjSFbcQzlmAe51TN//Sbb1TOmOWQoUxtPgAJN0uSBVtdNMTRTpoUZ2T6bpAFwz8W/AZf+nElxNbSxY6LRlY808zS0MVKKNDOlDYQyLhNEHqDiDEGMAjKJNlPiKjR4qtuvHbsOALC0fyb6Kx1QCY0n756L49jqmXfres/Y7HfJUXZ1gpn2SXaXhO0QbrBQYgOon9hnITp83z+2wAASFGaSuL90cWZ5FBniOFFkRMpNpkgzgiAIgmhKQkebOWETbcawijbT9evUoZtL6WG9Z949zbR6RjYf1JnaQsebeRRoRJIkDwBmTZREN2VBDM0EzXHVv21J2A5SmPFNGtAVF11I29Bm+zfvYGhjZBVpxqP5LKe0AaKZoeIMQbQYtjf/gkebRRMavOOjOtlY0rMZAPD4vtnGy1txU3X1zAvWq2dk+2yWMPs6wQIt1QfcCjRpiQ1G7GJN0vOnITCABHnJNu4v1XFTnJkMnXgIJTIYMhHRyT3aiAyfXlZI7PxihPo8JQcYQRAEQUTGcuWqU+KAiJg4EK1fp26eZLN6ZoQHl03B7bsXWayesdkfMt4sQiw0YD8n9zW22Y4RETVNaN2U9PxJf26VZioJ+4IWZmTHXZIGINnOq6GNHfc0tMnIOtKM0gaIUQAVZwgi56TZ4DJRtBmjqHjegOfNVHTi2O5qcea5/ukwCY0n7pmL4xOtnnHJUfZxgvGkWKABwogNn0INIBcGSf7zwfZncBUYgJ8ITOz+MjWz5MepxLEpMzkpusgN/rjs71kmMkQ0uckZRZoZm1qSA4wgCIIgMiG1aDNG8H6dDFO/ThX1Y//m9tMBVFfPHJlo9QxP0ngzSMaY9kkOlSTHk0ZDA3ZaKIlmArLXTECYYlQoM5tzYcY3zowh9ueEYpxIFoY2Uxy04R5MUfE8hR6dPJQ2QIw2qDhDEK2CwVEQNdosuNCQZShXJxpjCoNY0LUdAPD8AdUd1vpJyfKfnQKgtnrmdHH1jO+kKUS8WaCl+iVhO0SBBrAXEkkFR9qE+rlsCzM2+dapuL9c48xkxBIZ/N9sp3BcFnHIjw2Umxw50owgCIIgiHRJY0Wq0dQmI7V+narx8oLNA3WrZ54Ujvqa2kxj+eeBUwfEQyXESx5w0RbNoptcrtekmaIXZnh0hRkRsTAjFg9tovqAfBvaZMTv0akjaaSZ+NlOhjai2aHiDEG0IKlHmzF0X94HSSY0jhqzAx2FCrYO9WDTkBiGqlg9c9+8kdUzl4mrZ1TYCo0k8WayY0CwAo2LGyxkkSaPgsNVXPgKjJKwz0bwZeb+MhFaZPDoRIZujIrAuckOkAOMIAiCIPKB9gZdpMQBI7bRZkpc+3XqVs/I5lby1TOXT3tUsnomDVNb4NQB2aGSsJ22sY0fmzfd5HpdocxsQMLCjO2qKt3vk2+cmYw8GdrY/RbZfRiOouI5w9HQZhNpZoQizYhRAhVnCKIJyHW0WeIMZdk+tdA4tvtVAMAL/TNQv4xXRL165oQz1wtjVZMn3xxlFycYj22jQsXwkmRMKLEB+AuOLESHz/vbjE0iMGTjEhdmVPjGmcUUGTaZ6GJ2unijIaXc5N7aY4xIMx5ygBEEQRBE+mSROMBwShxQmdpso6FtYo7q4VfPfKWh94wKH1ObKt5Mh6epTXaoJGy7aqZQRRp+fFaFGp/3D2lmAwIWZmQR0DaGNpkW0sWZQRjHvzZkYSYFQxtPyoY22WctGdqI0QgVZwiilWgqoSFOCuyExoKubQCAFw/McLq8J+6bh+N2dKFSKODoyx42jE6aoyw+T2mpPhCmQBOySCO+JpbwSHJ+26JMbgozPEnjzCDsl72W3xbH+aDvEzUyRnec4ZmbzPAUG9TUkiAIgiBagxCJA179Oq2izXhU8x2VqU02DpDPw9SrZxYt2C2cx8XUpirA+MabyY4B0Qs04jhGGpoptG5Ken7fogyg1kyZF2YgPBcNbfxxnUYKRcqGNtVHS54izTjSiKskiLSh4gxBtChBo82CZyjzqG6uyoXGoZ07AQDrBovcOMBGaDx742loq1TwXHEQp1z4svB+SXOUbeLNdOePuFQf8BMbMQSH+NpQ/yW5Bh0mgVGS7LeJk/POS+afh4gzkxVqYsSZmUSGLDdZJzJEHHOTKdKMIAiCIFqSLBIHrLGNNjP262SmNp2xzcYMI+eBZVPwu12L0V6o4B/enMTUJqIytcluerumDgQs0MQytuVBNyV9fx26f4eSZF8iM5t40KY3p4hN4dA2aWCUGNqyijTTfHZT2gDRClBxhiCaFBfHQOJoMxVBMpRF9BOHmR1V99bGg8UZe559aDaO2zYWAHDY2x8zjHaZTCWJN+MJ7AQrCftkE2A2VoVJbADZL8V3wUVc+AgMk/MLSJiXbFuYyVOcGY9OZPAxhSaRESA3mUGRZgRBEAQx+shD4gBvapPFCQEwznesEJ3zZlPbX/zmDJQrBby1+BROO26H8PqQprZQqQNA7o1tQPNoJsD+WkOY2dhYHq/CjKvWluki3zizEPgY2tjxgIY2BkWaEUQqUHGGIJoE6waXCSJz8pGhrGdSez8AoFSWjTcLjceuPQMdlQpenFDGOW8Tb6gmERoh4s1kxwBvoQG4ucHEsQxbsQHks1Djek2mokxJst9WYARdli8iFmZkq7maSWSwfbpijSE3WcY04dERijQjCIIgiNYis2gzRlHxvI6JkOsnfb/OeuxX0SxdPhE/2Vbt2Xn1RfdLRriad2xWKYj7ZJgKNLDYLzlUkozxNbY1s7nN5bp8zGxAyoUZn6QBkbwb2nT3UBwNbYwEPTp1UKQZQcih4gxBtDAyoSG7eRgt2qzI7Ss2DhvBXmiMazsAANhb7uLG2wuNFcumY9Gr1Zu5xYuWAigLI3xzlPnXiK4b/vUpL9UH0hcbjCxFh897ZyYwxAG6ZfmyfSr3F79P93uaV5EhCg2P3GRZpFkgKNKMIAiCIJqDXEabpWBqG0F2A9dsavvyr07BQKUd541/AW84dbPle9kUXESNlDR1gEecRwco0Lga2wB/c1vausnnvW00U0myX5UyII5NVJjhcU0aAHdc9zsJxfGk5MDQVpTsczS0BYk046FIM2IUQMUZghglpBZt5iw0ZBnKasq1j60CKvAVGvd8/0x0l8tYPbaC11/5nPb93HOUZc9NTrAIS/XTEBu+hZqQ4iPpeW1+lhLc//14vAszSd1frnFm4j6Tc9GGJCJDR+Dc5LQjzTjIAUYQBEEQYbG+YZdltBnDq1+nbF/41TMr147DD149AwDwD+ffA0D8d7UxtZnizWSvCZk6IB4zHCohnLENcNdMQBzNFOK8vpoJsO/jk7gw4xoBbRNnZtoW96VhaGPHIxrasow0o7QBYpRBxRmCaCKSNLi0JVuhIXOF1QuN/nJ1ktFTEMWBvdBYt6qIo9ZNrb7qnGdRaNOtnlHtd4030znBZO+RsEAjO1xSjNOJDdVrGD6Cg0dVXLH9zxeb6y7BXWCI41MtzNi4v1SY3F9ZiAxZbrIoMtLNTaamlgRBEATRmqQabRasX6csskiFn6ntKzediH2VLpzU/TIuO2+D4T1ixJuFKtBoCJE8IBvP42Nu40mqmXx1k+11lxT7bc1sQITCjAqdSY3HlDTgZmirVL6iOepqaLNdOZfQ0MYIEGlGEIQaKs4QRCvC3QRMJdrMNkNZib3Q2DhUPeGhnTtre/yExu+/cwYmDJexdkwBb/zYU4brMwkNm3gzCM9dl+oHLNCUJONUYkP3Gp6kgiMNXMRFSXEsmMAQB4Ralh8jzswXU8FU1sjShkC5yQyFyNBBTS0JgiAIorlwWakaPdqMkTjaTCT86pmNW7rxH6+cAwD4uzPuQltiU5s4RhdvZps6IDsGNM6vA0VDJynSAPnXTID9NZbgZwDk8erLKe63iYA2JQ3Y6CR/Q1uh8LfKYyOEMrSJeBraAkaaeaUNcIY28TOcDG1EK0HFGYIY5QSJNmMEy1CWUZ1wrDpQrQSd0LMBSYTGllcn4PDV1XMdOHk5OsaolyuPbxvAku5NeNekF/HpqQ/jrw+5B1fNvBN/e8ht+Nz0+/HOSc/iiK4tGFnqH8MJJj5PUKAB3MUGe43sdSJ5KtS4XEsJ+p/PJc4gSGHG1/2VNM5MJPSqmU7hmFhElYkMkQC5ybqiMtKJNCMIgiAIIg7aG3f8CtY8RZsVuX3FxmFVJkLdr9OEn6ntqpuOR6k8Dkd1bsBHLn7F8B6m1Qi28WYxUgdU5+AOuRjbQhVp8qCbXK+lBD8zm/iaoJoJCJs04GpoU48Ls2pG3K9DZna1+IxIoUen7jMVAEWaEaMSKs4QRJOTtMGljGyizXRCY2RC8uC+IwAArx//AtrBnFt+QuPX3z4dk4fKeLWrgLf/4aMAgDaUccyYbXj/5Bfwb7Nvxx3zb8CjR1yH6w/7Fb468z784bQn8Z7ic7h00ou4rPgCPjrlSfzdzDvxy8OvxS2HfxdXFJ9EJ4Zh7wTLuEBTkuwPVaQB0hcdPu9Xgrko4yIwohVmbN1f/D7bwo5qX2iRoTqua2AJjESaaURG5NzkGE0tyQFGEARBEPkh82gzhpOpTXWjVtWPgj9uR2lXJ76x4lwAwF+95i50dtqunuFRFWRkJiLX1IGABRrV4ZJibCjNBKSrm3zfq4QcmNlUY2InDdgUa+SEWTUT2NBWVDxnOPboTAKlDRCjHdcOvARBZMxbKhXcXCiYB94J4AL14cV4BstwfN2+hViFlVggHT9j/lpsXj23+iW8rlD9Uu5D9Ut6I6pf2ptQ/RIXM22LaJyQ9YCbS9VtYGRi0TjRv3vPsdg+NA5zukq4ovg4flQ6URjRAdsJ067SWMx6Zh6mTV+LJXv6cHHvPhzfsR0T2hvfd+PgWKwcmIz1gxOxfagH/ZUudBTKmNy+Hwu6SnhNzwbM6yrhrw65De+ctAyf2/BWvDJ4SO1a2EftIKqTJ36fjP21fwM2nt8nPt+FkQkXv19xWkiGlCCfkO2FfJkz/zooXqvCNPk3uXVCCpWS4bhJbIkYBYY4KMSyfJl49YkzS0tk6FbN8Pt0xZr4uck6qKklQRAEQbQ+S7a/gKVTjqnbdyRewnIcJR0/F2uxFnPd34hpp+lonOcWoZmvivpJ3MfPEXkt0gF1MUWto77+i6PxqT97EPPat+KP37IC//xz8d+B1y0y/cPv5/eJY3Tj2TbTPLz2YWNsNJN4TILscAlqzQSodVOJey57vQqd7rFZ4RBKN5UsxrgUZYCEhRnbAh0kx22SBnSvV+3TG9rUuilHhjaGp6EtzUgzgmg1qDhDEKOAwj1A5Vz18dSExjjob3gDMAmN/spYfGfb6/DlQ36JP59xO/aWu3DTrmNRLwBERoTGhLYyju9+FSf0bMUpPZtxwm+3oKdtuDpszCYAwO7hTjy5fzoe338Into/HSsOTMPu8hjuXIwRUdFTqODtk57Hp6c+gmO7N+GH867HB9degZcHZmJkstYJufgQhQZP4AKNakgJfmKDvZYhO4cLsV1iJctxUQUGEHZZvio+QuZEhOK47PWxRIaIKg8dGFk9Jx6Lm5tMkWYEQRAE0XrcvAJ4yxG1jVsBXGT3OpmpjbEAK7EKC/Un6IXc1CajCIv56kTUzyV5rWAD0yRMI7Ft2Zgq+/o7cPUz5+Nfl9yELyy6C/91ywLs2aeb85nMaOzcooZTmdpsCjSMyAUaILlukr3ehTQSCUoWY1I1s/HHbFdO+SQN2BjabFaLVcm9oY3h0qPTgpiRZpQ2QLQaVJwhiBagTmgYkLnAZFgJDReKqJ+MlmSD7ITG9aUzsLj7Fbx10lL8/axf4/LiE/j1rmPxbP90bB7qxoFKB8YUDmBqxz7M69yOuV270NtZwnHdWzF/zM6Gd93b0Y5n55exYg7w0O1n4J6Vh6Fcl/rIT47kTrD9lQL+r3QCbtu9EN+d8ysc3b0F/zb757j8lQ9iX2Vcw3g/JxgQvUAD+IsN/hyq82RByWFsdIEBxCnMiNF5KsFgm5msJ5nI6BAebRHz1IVf4KLkJYFyk9OINCMIgiAIIiw+iQMmUxvDKXFABTOzyZIHeFObMnGA1wSy4zL8V8985zfz8afHzcDhHZvx5Xc+iy/+cInk3DKzjWpVjWobqDe1qQo9om5y1UziMcVbQDKkVHssSl6TtrktJCXLca6aCQhYmOGR/R6HThowrappUkMboyjZZ+rRSZFmBBEUKs4QRBOiFRq8CyzraLOIQuNLr74HLw9Mwx9MvRuLezZicc9G9Q8qsGZgPJb1T8MT+2fg8X0zsHpgAt72Rz/FyrEVHH/4MpQ/f7jwCpXQaGTr8Dh8bN3b8NPDfowFY7bjD6Y+hG9uZQqPF0KuTjCfAg2QidgQz8OQnS8G4vuaMK3mUp0vtcKMChv3V15FRif3yGeii7nJst9fRW4yT1FzLEFuMjnACIIgCGL0kFq0GcMp2iz91TNDQ+348sMX4rqz/g9/fPjd+M6so7Bmo27Vf4h4M13qAE/EAo1uSAnqeaePuQ2a88VAfG8TPrrJ9lfSujAji+4LkTQgwz/ODAhhaDPFmMkwGNoY/O9lwh6dwSPNOMjQRowG2sxDCIJoBUyOBNmXJkMX7WOFzL1eVA0Wb7zKMlOBCtrwX9vPx0Wr/wT/uPn1uH9vLzYOTsCBcjsAoL/cjk2D4/D4vlm4aeeR+NctJ+GT616HM1degYtfvgyf33gWbigdiVUDRVTQjs2/OgUA8Oz0fVhypuznVU3CGvdvHx6Lv9t0IQDgg5Mfw5T2vZrXqSaOuiXbIq4OI8shJc1r9nL/2VKS/JeUJOc0Xb/ufKkWZkK4v3iyFhn8ow2BcpMdxYZrpBkPOcAIgiAIIp/U3ei7VTnMCe0KW9bPrre2zcwhzCwim58Ua4/8vKZuKqSZFx2Ev7ErmmDYcRXqY9ffMQcP71+InsIA/r9LH9Gcg2HqayjOYflt2fxUpZn457Kb9uJzT80kG1aCnW6ypaT4LwlJz+mrm6w1UxqFGVPSgE2cWSiSrJpJaGgrKp6LGAxtOhIb2jSfzWRoI1oRWjlDEKOQoNFmvUiWoaxcPSPbBsSPra3DU3DtjjNx7Y7Xgk2oCtiHCgrQT+zrefCW+bjioqV4ZvIAFrznYSx9wOR60zvB7to7D8v2z8TinlfxrknL8N3tp0PvBJOhW0Ej2rcCuMEgGVaqPRY1r3VZTSNSMo4Ij404Kin2ezm/gPCFGZP7SyRMZnLYVTPsNSqRIWKZm8yepxxpRk0tCYIgCCJ/hIo2804csEXWr5NRhGbOLNNPgN2kVVwtY149AxTw2d+dhwcvXYnLpz2Kbx53Ah5+drLwmlDxZq79Z1TvpVtBAzhFQ+uGlWqPRcXreA3SDLopc80kHk9amAG3zzdpwM7QptdMPM1raLPp0WmCDG0EUYVWzhBEkyI6BpQuMEOkjszBIHM6MA5G/pgcE7ovd6vJqGxiYTFZBlBBV+2ZrKEelPue/MGZ6KhU8NKEYZz3btmNVvub2EABN5YWAQAunsDO5eoEA+ybHorbHm4w3bASzILAZzVNWtheWwlq15f3knyXwozsPD7ur7TjzHhsRYYpKpDlJvN/9xPhJDIYYm4yRZoRBEEQBGGBzHwhW0HLzByJEwcYRZtBvBlL1m8C0K+eYchuAKtvCj/0zBTcsPVUAMA33ngXANM8xnRzW7eCYVCyT7ZtSh1QzcGBxrl6glU0QHPrpqSaCQhoZnMtzIiIvyshkwb06DVTTgxtjIiGtqSRZgQxWqDiDEGMInIRbVbk9hUbh1UJJTRkqCdDLy09BIs2VN+758KnUWgbloyyjze7Y898DFcKOLp7C2Z27JK8Try5Dsk2T4gCTQpiA6if2GchOlzev4SAAsN1ST4Qxv0FYdvk9NLtk9OUIiNhbrIMijQjCIIgiOYm9RWsPtFmDKtoM5WpzTQ/k+knO1Pb53/+WuyvdOG0npW44sJ1kteIc0yTMUgWbyY+T6tAIzuuQDesBHfdlDaumq0EvWYKnjLAb5sKM6YIaHGfiEucmd7QZkdoQxvPxMYhIkXJPktDm46QkWaUNkCMFqg4QxCjFFvHQvAMZYaX0GD74gmN275zDiYMl7F2DPDWP3zC8D6ArkCzs9yDFw5U7xIv6VkvjPHpPyM+5ws0ujzlBGLDVKQp2Z2qYeIfUnz4nrsE/fUHExji8VCFGdlKLJs85BYRGTKK5iENIkNDsEgzgiAIgiAyQ7tCNULiADN56FbkSmHaSeZkL6peJOvXKSPs6pk1G3vwby+fBwD4h9NuR1eXydRmu9/Uf8ZkLMqoQBOiSAPkUzMBgYsyoQozPDYR0PHjzIA0DW0iYtKIwtzG34fxNLRlFWlGaQNEq0LFGYJoIVKPNrMlmNDQNbsMIzQ2r5+AectnAwBKS1Zi4tR9klH2qw2W7a9WqhZ1b5C83mepvvg8sthgQ0MJDh6ZQPD5zxXT9ToJDPYCnpCFGREb91cziAy2nUBkuOYmHyIZA6QTaaZxgJHIIAiCIIj84ZM4YGvqsKYo2ac1tclW0JjSBfxNbVf97Di8OlxEb8cW/Pk7XtS8B8N2HqraDhELLe63mccHLtKU7E53kKw1k+56E2smXfyzqTATMwIahn1y9IY2XuuEMLSxv3lLQ1tR8TwAFGlGEMmg4gxBNDE+N/UyjTZjFCX7lEJDtR1PaPzi387AzIEytnW04c1/fK/i/HbxZqsGpgAADu/cphmXZYEmQpGmZH/K1CghUlHGJysZsM9LNhXyshYZPCaRwY9TFXRUvaUsRAZDVgR2gCLNCIIgCILgMd1Q1JnaDmKbOKDr11lUndzW1MYftzWwqU04e/Z14CtPvQ4A8Pmj7sSMqQcko3zizWSvd0kdANwLNIGSB2Snk1FCPnVTCXbXlamZjX/uGwFtwj1pADAZ2mwxGdpkaO6f+Bra0urRyUORZsQohYozBDGKSTXaTCc0lAjNv7VCI9zqmQP7u9B+//EAgBfm7cARJ2wyXShUN8PXDkwAAMzu3K0YZ1qqH7tAIxtjwFVwlNxOHwSX93cWF+xFPDrnl7jtmpfsuixfRQyR4bJqxqVII/tbN4iMouQ0CXKTKdKMIAiCIFqXkIkDMryjzRjMbFKUHHMytQH1q5R5ndQpHIfkmH7f935zOJ45MA8TC/vxD+9URUKHiDcD3FIHALcCjWw7srGNUUJzaaagZjYgXmFGprXFc7skDeiJs2pGdR9DtVIuH4Y2E3WGNsNnLoPSBohWhoozBNFiKIWGgejRZgyZ0Cg2DnMXGjL8hcat1y7C0bvbMVgoYMlH7lec3zxZK5XHAAAmtfdrxts4fPh9PhNYGzdYhCINoyT5LxS+5/YuyrhGxoUozPhEP8jOYzpXPeFEBj9OdmNAJzIUFCX7EuYmy6CmlgRBEATR/KSVOGBl7uitPfr261Qim0fpVtDIcDO1VVDA5+64AADwwZkPYvERu5VjR7CZl6oKNPzzLAo0sjEGXIoaJcl/ofA9t5dmYi/kMa1ICl2YkZ0rSdKAvj9nnFUzbF8AQxujqHl7laFNA0WaEURyqDhDEE2OtdDgbhZGjzYThYYOXmgotUMaQkOcCLXhpevPREelghcmDeL8d6kmE/p4s53D3QCASe0HuP22N891N+ZdJ7Ky7QBig73EZ9JeCvRfGtd68MUiJoERsjDj6/6SjTMTVmTYrrBh4/m/b8EB5ioyHEg70owcYARBEATRHCSJNjtoAtGs2K0jcb9OWR8KcfWMaJjxN7Xd9tgM/HbnCegolPHPb7GJhAbsV3Tbpg4kLdCYkgcCGNt8X1YK9J/PdXprJtdVSEkKM+Lr0k8a0JOxoY1hE2mmgiLNCCIqVJwhiFFOlGgzEfFLP5dCo5FnHpqNRRsnAQC6Xr8U7V2u+bSD2DVcXTkztm0QnRg+uN9/qT6/L1SBJpDY4F/q+fIoJL4mlcCwdX4BYQozsnPZuL/yKDL450xkiIhNbiUEyE2WQZFmBEEQBNH6pBVtlphi7VFpatPNmeKvngGAz/3iTAxU2vG6Cc/h0nM2KEbZxpupbr6rjoco0IjHbAoM4mscGDWaCUieMsA/16VL8K8PmTSgJ7eGNkZR8pwizQgiV1BxhiBakNjRZokzlBlFyb5MhUZj0eaWfz0HxeEy1o8p4O1//LDivOrJ2+5y18Hn49v3CUd9l+rz+2wLNLrCAaAWGwkUw35kIzyCva+LwEijMJPU/ZV3kSErvsbPTWafYzKRQZFmBEEQBNE65CLaLEm/TkZRdUCcN+kiksKunnmxbzz+c905AIBvnH0Lxna7m9rkmFIHTAWfJAUa2XZgYxv/8iyKNcHeV6WZYhVmZL8Hpt+N8HFmZjI0tMl6dMoIbGijSDOCcIeKMwTRAqQVbRYkQ1kmNNhkoag6aVpCo5Ftm8Zj5rJeAMD6Y9ZgxmGyIgagWqlQVrp0VJNDWycYv8+mQCMbZyM2xNclQBQeSU8rO1+QS9UVZVwFxiC3P3RhBort5CJDX5iJLTLEwqtBZBQ1lyqKDAtsRQYPRZoRBEEQROsTJNrMFVm/TobW1CabX7m48t1NbX/54xOxYXgK5nVsxd+8W2Vm0c1JXfvP+OxTFWhckgeAKMY28TRNr5t048TtpIUZGPbFSRpoakNb1j06CYI4CBVnCIJI5GRwzlBmNJHQuOnfT0VvfwV72ttw4R/dbfk+1YlcG0b+XcqVgjAm6VJ9fp9PgUa2DUQVG6rT+vwX7WJEbB1yLgIDcCvM8JgKfOI4EbPICEOGIkMlNnprj66fVxLIAUYQBEEQzU+uos0OMQ/ROuAByE1ton6KY2rbs68Dn3vwYgDAH/fehWPn71aMtF3NbRtFpZs7mwo0LskDLsY20k31Y3XbIQozMpMjG5dV0kATGdociBJpRmkDxCiHipUNprkAALpWSURBVDMEQdThGm0WjBwLjeHhDmz/1SkAgGdm7MPpl6hWEDU6wfgP2Qq3v/41thNO2X5+n65AY3KDuYqNKLP8jND9PDYriVydX4B7YcZ1WX6ziAwRTXNL2Sr+ouZSLQgeacZDIoMgCIIgckeuo80Y3v06TU3C45vafnznobh99yKMKQzhO5feDV4BqdHNW21NSb4FGvF5KGMbG9tKmgkIb2YLVZgRx7BxLvooZJyZCyka2hiRenTGiDSjtAFiNEDFGYJoUWxcYJlnKFsLDZH0hcaDv12I47eMBQBMfttjaO9QTdbqJ3WFArdyBvzKmbQLNOJ402SZoSrSqM7RTJiKMq4CA0hWmBHPEcj99ZWzgS+fpXgvBV8+B/jKuXZjpdiKDNVqOANJcpM1JI40o6aWBEEQBNHSZBJtxijWHvl5kHLaxB+QmdqAGKY2oIA/+tm56K904txxL+D9r1f9zLbxZrrXsH02q89DF2hsjW2q8c2EzpxnG4vtukoJcNfHPhHQwriDuskhzkyqm3JqaEvYo1MGRZoRRHKoOEMQLULSm31Bos18KdYelUJDsVw3VaEB3PGt8zFxuIw13QVc+pkHrV7TiYGDz4cqZeGo7VJ9ftu2QGOzXB+wExuAneBoBtFhutYkAsOnMMPjkpfMj1Gdr8ZwBfi7c4Evn4HDOnfjunm34fPTn4LoZDwoMr58DvB351dfV0dMkSH+LUfMTa4VkZN8blGkGUEQBEG0DkpTm4Eg0Wa9tUeXfp2MouoAP49SmdpsNZC7qW35K+PxL6svAAD8f6+9BRPH2xZbfPrPiGNCFmhiGduaQTMB/popiZktZGHGI2ngoG46B1YodZMtzWVo032WmQxtdVCkGUHUQcUZgiAaCBpt1kJCY/OaCZj59OEAgL5j1mHeUdsV5x6Z9E1orxZnhioF9Ffaob6JbiMgXAo0QHKx4So4+NflSXTYXFMogSEeNxVmXPKSodg2iIyr7gf+6h7g787Hl87ow4k92/ChKS/h+aNuODiyoTDzV3cBV90LP0KJjOxyk0NGmhEEQRAE0aSkkThgi3e/TuXAGnFMbV+98Ti8PDQDM9tL+Mf3PKkZaRu/GyJ1QLXPNnlAdryVNBOQzMgW2szGXsPv0/2/5XFMGmBcdX9VB/3d+coCjVk3tb6hTfvZxsEb2kyfoTIobYAYLVBxhiBamFxEm5loMqFx07+/Fgv3Afvb2nDaJ+/WjKxO8ia0VYsze8pdwMFYM5/sW9foK0YSscH2+QgO/rVpi479sH/vJAIDsF+SL3NwueQli6/VjRP3A3O/djPe8OF/xuqLT6nb/+QRPxnZ0BZmQooMkYAiI4XcZB6bSDNygBEEQRBEvoidOKAztSVOHGDoHPEAGk1t4hwsrqntwEA7/t9dlwAA/mD2vTj5mJ2W72XTf0Y3NlSBxpQ84FqksS3U5FU3mVYEmfa5mtn417gaFT2SBvjxV90rLdCcNmcDKg9MRVvhbxwLMyrI0EYQRBUqzhBEC5GF0GCMFqFRKbdhzY/PRHulgucnD+D89zyvOf8gxteKM7uHdddkWqqvumnvW6BxFRviOXhsBAd/3tDiw+ecumu2/flNzi+bJfn8Pp9l+ZCM48eOHPvWoffjmw99Hx+65pq6I91tw1jxhgdRYALEujCjwkZkdMBbZDCKhrcw4JqbzENNLQmCIAiCYAS5Edlbe0zSr9PJ1Cbui2Nq+82Dh+CXO05Ee6GC77z1dhSgmgO59J+xMS25Fmhk0dBAo2ZKYmwD7DQTf47QBRvX85p0XkgzG1D//y1mYcaiYCMUaI6bsQUPffTHwO//Eou/PNcxaSAHhjaGraFNg4+hrQ6KNCOIBqg4QxCElFQzlJtMaDx571wct34SAKDtwqfRM2FAOXZCe3XCV105Y+sEUxVoZJNQnwKNuK0SG7aOMIZtoUZ2Tt//bPERF6r9IZxf/D5VYUb1fqb9jeiKg4OvrMEPXv4fxyizkCKDxyAyfHKTNVhFM4IcYARBEATR6qSZOHDQJOIabcYo1h6dTG1A2qY2APh/N56FPZVunNK9Cp98y2rN+V0KNOL+pAUacb9r8oBvkSZNzeSjm0zXYtqnM7OJ27bxz7qxSQszgv7mCjTPfOqHB3e/8JX3NZ+hTXa/RUfgHp0UaUYQeqg4QxAtTqbRZq4Ua4+ZCQ3ZjWb5eX79L+dj+mAZmzoLePuf3a08+8S2fQCAXeWu2p4kBRp+jKlAo1qu7yI2ZGP484Qu1ITGJUbAZr9MYIQszIiEExnf236s5PwjnPL44/jG/1wtORJLZPCIosNRZJhyk3trj2lFmpEDjCAIgiByT8jEgXz26xTnVrptRnhT29pXe/D3L7wOAHDVklsxbfIB69eqsZ1nhyzQ+BjbdPsZedNMJiObTcIAYDaz8SkDIQoz0GzL9psi82pcdS+u/6u3HNz8xJu/gYG/fTC/hjZGCoY2ijQjiHBQcYYgWozcZyinLjRsVs/YZL3Ws3N7D3ruq97wfmHeNpx4lvxnn9ZRnXhuHeKvK3SWsuyYuC+G2DAdY4gT/hjiw+U9TOLCVJQB9P+GvoWZeO6ve/fOhomLJqzFVw951DjOT2SIyCIJE4gMByjSjCAIgiCINMmmXyfD1tRmo4f8TG1f//lReGHgUExp24NvXvGw5vw+8WayY/w+3U3+JMkD7HWuRRrbQk3Mgo3Le7jqwKRmNn6/S2FGZXizTxqQMWHsEN7Z9gAAYEdlPL57/AcUhRnd34FsO5KhrSg5VcaGtjrI0EYQUqg4QxCEkkTOB1FomBwZ0YSGK/ZC47fXnoBjS50YKhQw58oH0N7R6LiZ1tEPANgyNFbzniGW6pvGi/tDig3+mO2SeZn4SPKfDT5FJtO/gSgwTKuXkhZmoNhff50n92zApRNXY2Kt59G4tkGsPjBBcZ4R3lVcjZ4CO29IkcGO65bly7ZhJzI8cpMp0owgCIIgCJ5YiQMygvfrLNoMDrF6xt3UNjTcjj+89WKUKwVcOf0RvPnMVzWjkxRoVDfmVTf7ZcfE88rm/rbpAyF0U2jN5NrzRnecx2RmE7djFGZEPOPMOP75vU+gqzCMPhyCuX/0NDCmA/jyOcrx9Yh/Kz6rZhIY2op2V6mCIs0IIj2oOEMQo4BMMpSTkguhYSrQFPDAv1+I8cNlvNwDvOtz9zWMmN5eWzkz3CUcUU8C/Zfqm8aL+23FhqpAEapQEwubnGVdUSa0wAixLN/s/jp6zHZcM/dOXDXrUdy54Ff43PSluHvBLzF/zG7jawGgqzCsOeojMkREkSGtvtbvTioyNLnJFGlGEARBEKOT3EWb9dYebft1FiXnsOrXmWT1DI+9qe3uJ6fiexvPBgD81/m/xiRNz85sCzQ2yQOybZl2YON0miiPusnluI2hzzZlIGlhJkCfGY5zl2zHR2c/CAD4o6FPYO8h3z7Yg6a+QKNbURZy1YxuG6ka2ijSjCDCQsUZgmhBsow202IrNGRkIjTMrF9ZxMylCwAAK496FUeduKHuOIs12zLUg+RL9U0FGpfl+i5iQzy/ONZWcMQWHrbvoxsTSmDw+3XFM99l+erfnRkde9BWqD4f2zaEj055EePadMXAEa468d3YWR6DsCLDZtWMwgFW1JzWJDI02IoMHoo0IwiCIAjChtT6dTKM0a/8PMviJm8UUxvwuetOxstDMzC7fTv+7X2PmC5aQ8wCjW6/r7GNH+uim2Lhos+SaCagUTOFSBkQx6rOodsvHqunq2sY333LbWgrVHDd8Hn47d/vqR646l6hQGN7n6FDeMy3oU0GRZoRRDyoOEMQhBaTAyJohjKjWHv0Fhpprp4BfvGtU3DUnjYcaCvg2I/dBxTKB49Nb+g5E7NAw+8zucHE19mIDdk5xPE2QkImCFzcWi6vk51DRlKBwc7BH+P3yQRI+MIMMIQXDxRRrv3p3bJrLlwYeMci4MvnWY4OvWpGkpvMU6w96oq4PL21R0eRwUMOMIIgCIIYXYRIHLAlUb9O2XyoWHu0NrXx2/FNbXv3deDjt74JAPCBGQ/hzWdu0Iw23UgPWaBxSR4A5PrAtkijGq8bF1o3ubyviE4Lys7BcCmemAozPr053fjq59fjyLb12FyZhD/59+PqD9YVaM7iDqgMbSqDm0gkQxtDNLRpCNWjM9RnJUG0OlScIYhRQu4ylFVCQ0ax9pjT1TNAG5Z99zx0l8tYMa6CS/94pNHltPZqz5mtQ93c+CwLNLr9gJ3YYOcwCQ5fx1eo1TY253ApyugEhs2SfEiOqfaJ16h630Y2D43F7XvmAAAunmj+W7zildfh2Jcuxzkr34afLpsG/N25nNCwERmRV80Yi7R22EYuUqQZQRAEQYwech9tZqJYe2yS1TN3PD4d/7WBxZv9FpMm9Guu2X6lgxzbAo3smHhcljwQqkiTN90kwyVFwcfMFqIwo8P+d2nJX87FZ9puAgB85oGLsK0kxpOjVqC5h9NNNvcOQq2acTS0qe6z5LRHJ6UNEKMRKs4QBCEleLRZb+3RJDTEDOUmERorls3A/OcPBQCsO6EPhx2zFZ0YxuSOAwCALcOKZchSkhRofJfr+4gNdh6V4OBfl8USfRWDkF+3bVHG1/kVMi9Zxsjx/91+tGHsCGeP2wiggK3DPcDfP8QJDZtml64iQ7VtEBmMFHOTeSjSjCAIgiAIF5z7dSaNNmMUZTvzs3oGAD533SlcvNljhtEhTW2y/aoCjU3yAGAuTvDndtVNsXDRZ65FGV8zG6D/f8P26/5fJI8za//LM/DfY/4DXYVh/G7n8bj+jkOVY3HV/ZxuOr22M8NVM0kNbYF7dNahMLQRBFGFijME0aKEvBkYJNosKUXZznwJjZ/+y5lYsA/Y19aGk//wLkztqDZgH6wUsHN4jDA66VJ92T7b5fo6sZG0SKMTHOI5XJfaJ1mir7s+1evFfa4CQPb/w7cwI6L//Xm6X7cUrZ6PTXkBE9u4pqxMaLSzKUJIkSH+bSpEBoOJjKLhrVQkyE2mSDOCIAiCGJ2EjjaL3q+zKHmt1OySD1NbNd7sEgA28WZAOgWaJMY2QK2ZkuimLDST6fpimNn457ELM3q+UvgpXjO8HDvLY/GJ688CUFCMrP2eX3U/8Ff3cbpJRU4NbRqSGtpsPiMpbYAgqlBxhiBGK1lGm7Wo0CiXO7Dqf87EmHIFL00YxruvfBwAsGVoLCrSiV3oAo1sv2zSqzqX6jpcijTsfC7FGhFfIaG7Dt37mPbLBIap2GUjMMT9upUxbu6vNpQb9qnoaivjtLGbUPc7ftUjwN8+YHhlBJHBKGreViUyNNhGmpmgSDOCIAiCaD1iR5tZ49uvk6E1tdia2ti+uKa2Ox4/xCHeTEboAg2/T6eZXI1tuv38OfOimXQFGRs9aPr34o+JY9IozKj11slH78QX2n4CAPizxy7B2s2q6ofwt3HVQzXd1ISGtt7aY056dFLaADFaoeIMQYwiXG8aBos2y4XQsFk9IzummkzJ9z/7+FwseGE2AGBg5mYAwKYhna0lrQJNUrEB6CfltjFiPsLDBvE9TAUZW+HkM/kPUZjxd3/1FIZw+/yblcc/t+F0vPuV1+PPN5yGncNd2DPcgdUDkxWjUxYZ/GHx717W+JYnYKQZLzIo0owgCIIgCB+co810mExtMrxMbaIeimNqA4DPXXci+oamc/FmOqOSaQ4co0AjM2OZ3gswaw1bg1taukmFq2ayWWmkMrOFLszIUB/v6hzGD95xC7oKw/jtzhPw/VvmGc7FMGmkTrSioY0izQgiLFScIYgWJla0WXMKDRGZ0ADkQkMnPhr5yT+fjqP2tmHCnur25oZIMxHzSgj18TTERijBwb9H6P90mK7LVmCY/s1snV/ifr/CTHV1TOPvyuT2fszsVP8/mNA2iGf7p+I3u3tx1sq348yV78DKgSI3Isciw4WYuckEQRAEQbQ0rtFmmfXrZBRrj16mNtl2yNUzqniz7rp4szedsRluBZrQpjbX5AEfY5t4rJl0k6tmguSY+NxU0PEtzLjp63987zM4tms9tpYn4KPXnQ1jnJkR23HNbWjjoUgzgvCHijMEMZrJIkO5t/boKzSkhFg9I3OCyTA7wSrlLjz9H+fhkN3VaKlJJ5RgLriIuCzJ9inQiM91k2p23Fdw+CyvT4ptQSapwBCdX/yYkIWZEQ7p2Isfz/sdnj3qetxy+K/xiSnPodh+AAAwu2MPfj//19LXAcDfvHoybtp1+MHtYbRhEHzx0PZ3X4fsb0w8r0FkMIq1x4xyk3ko0owgCIIgCBty0a+TYWVq6xS2k6yecTO13f74tJF4swt+i0kTZJqDJ1aBRmawUs3zfYs0SQs1MbDRbDaaCTD/O7FH3b+1uF/cJ3svEbfCzLkn7sAfH3YXAODTD7wJr27p0o4fobUMbTIo0owg0oGKMwQxykgSbRYkQzkpxdpj4tUzpskRYFOIUbHy2RmY/cp4AMDqeftwwjnr4b5UP3SBxsYNlqRI49qYMhS259YddxVaKudXrMLMyLE/mfY0FvdsAwDM69qDP5n+DG45/Nf4wOSX8OEpL6FNZfQCcOPOBRistKsHHCSEyOhBo6iwFBnaYqyG3tpjgs+bkJFmBEEQBEE0B7FuDGbSr9N6HuW7eiacqQ3oOBhvdmj7dvz7+x7VnI+RtEAzKHkN/zpTUcH2PWXv4aJXstJM4lgeVVEmKzOb7P3sGTduCP/7xt+gvVDBDVtPwQ13ztaMtr03kOKqGUax9pjA0EaRZgSRHVScIYgWJ02h4RxtlonQUImI8DnKB1ZX33fLhAIOvfJB9IwbQNgsZcC+QGNyg4Uo0gBuQkLVxNL1P9v3EBmCWmCI4/hjLs4v8ZjsnC7HgOO7twIA/nzDafjCxtPwYn8RE9sH8cUZT+HKyfrq6wnd24Q9sUSGuI/HUmQwTCJDQ9aRZuQAIwiCIIjmJbNos6T9OhnF2mPdPEs16XJZPSPD39TG4s3KlQLeN+NhvP/162BOHUhSoGHHdTogtLEtiW5KQzOZrse2KJOmmU08btLKjXzzfU/h8I7N2Dg8GZ+89jTNSNPvfUarZiIY2ijSjCDSh4ozBDHayTLazJdi7dEoNGTL9Hn4SZINbjnKh3T0AwAGx1awZkwB7/jLu2pHQi7VF8eoCjSqY+Kk2TQZ9hUcaS3Tt3lPnbhwFRiAWmCwYy5OO/P/7+FaBvKO4TG4eVcv3vXKG/C3m06SvPcI+8vV1TJXFPmZcEyRYbFKRgUTGUXL8ZFyk3ko0owgCIIgCBd8os2C9OssWlwcAHXiQJLVM/6mttsfn4ZvvHIBAODbp/8KRx62D8lioQF7XWUzN09qbJON48mjbtIZ2Wx/fpmZDVD/v0mvMPOu817Fx2beBwD4+G1vRGl30nsCQPJVMzz5NrRRpBlBhIOKMwQxCokZbaYVGrqooeBCI83VMzIqmNGxDwCwb9mRAIBls3fj9e99zvA6IL0CjWq1h26yLRsvjrUVHUnEh+u5bMUFJON0AkPn/BpS7Bff3+bYyJjtQ2MAANNrxb8y2vCJKc9Lxlb54/Vn4oNrq2L34glrMbYgO28aq2YUwp89LWpOPU1zTCRQbjJFmhEEQRDE6CKraLOg/ToZWrOLrl8nv5+f19lEQvO4mdqADnzhh8fjkf0LMKHQjxuv+B26uobhnjoQskDjY2yTaQgXcxsjK92kuzZdQoDsuKnIJTsmO6947P9v787D7CjL9I/fp7uTTtJZTkICCVuaJSyyC7LIDkIU2REBRwHRUUedcQR1hHFkdAaRGeenjuO4I46IgOxbZJEdQdlli4QdEhJISLqzdjrd5/dH5ySV6lret+qt5Zzz/VxXX+mu7VQ3nVD3eZ56Kmi9X/T6zTdZqZ8eeL0k6QevH6yb/xRVwTBtaPM3sXk/D3puU13UGOgIDdrQBmA4ijNACyh6hnKg7rV/ZhY0vLK6eya8E2xcW7/GtA1Ikq69bhftNneoaLT88Ge0+baL5f5Wff829QvfuLFdScOGd5skhRqvLG7LX6Nk4cI2YPiXe9e5L8xI0gurJ0iSduhcvG7NyMpgwPZDjh3/iub3j9Frq8dqZNug9h7zlsJ/d5OGDK+ou2YCOsC8/H+/p/jWFz032QIdYAAANJmMRpsFMnl+nk1Tm9XzOvNuahsyMNimU359pBYPdmm3zlf1vdMfX7umqAKNf51JY5t/W+82aTOTlH1u8osrMHm38+/j3y7qv1Wawoxd4aatbVC//didmti2XE/2Tdc5v94jYuuwxrOwSQP+baMKO1F3zdRZNrTZKOgZnUwbAIajOAMg39FmaYJGEKdBw9/NEhY04gs5m6y9a6ZnYKT6ah26+luHaKtVNfW2t+k9X7xLlbZBZV+g8W4T1w0WFTbiijRB2/nPKa5YkpbJ8W3ChWnAcFGYCTpPv/X7P7VqkiRpl9HvrFt23NIjQ4/2vnFzdcfWN2rLkcskSdt0LgvZMipkBBVpvOJuzU9x14yJ7rV/Zj03mQ4wAABaSpo3Er1vYGbyvM44Tu6e8cqmqc27zatvjtbf3nOcJOnvNrtXHzpk/tp1ZSjQ+NdHZab6tkmb2/LKTHG5KWzfsO1smtnqUwZcFmbinX/qbB3Q9Vctr3Xq1Cs+oP5+07dFTSYNhE0SML1rpi5hQ1vYSDPLhjae0Qnki+IM0KIKG20WxDRoVNf+mVnQSCI4aNRHmi1YM0ZSh1YtH6mXfnGQRg8Oak7XoD507v1rtyyyQOPf37RAUd/WNnD4z9HlRxiTAOTfNuzrqIDhXW9yx1HQfnHHHfL4yqGr711GLVJXW7+kQR351T/oI19u32C7g184Tie9MlOPr9xII9vW31nzzNrizhDTkOGXQcioq679M4e5yV5p5ibTAQYAQHPI441C66a2bkcvXNKmtg116Op7p+qHbxwsSfrZgdere7P63R9ZFGiCrsPj7nyxaWwLO55326jcVERmSnInUB7NbEly8IYO2mOxztv2NknSFx85WrNfHhOxtavncwbt3xwNbV6hDW0AYlGcAVpEqWcop1WqoDG0fJOOoRCxYE399Tr0zMObaItHtpYk/XXbBdrvgy+vXZd3gcY0bPjXBX0ddkzvtqYFGxfiXs+k40u+r00Chknnl/+4/m3MAsYb/WP1yupxGlGpad8xC3TiFx7WU5NWS+3S326xt67v6daHXjlSbw+M1nN9E/U3rx2hD796hL7z1m765OuH6aEVQa1TBYSMuuraP7uGr4qUcG6yFx1gAADAioPRZl7WTW11ps/rbICmtiBf/NUeemLVdFXbluuKv7lNHR3hI3zXS1Kg8W/jz0VRr5E0M5nkpqwlyUze/cK+TtLM5t8vyfqgbTY0cUK/fv3BG9VRGdRVi/bUz27ZMmLrqOdoxm3r17wNbUb/1nn+zaShDQhGcQbAkLLMUC5l0LCfo1y/c+at/g27ca753z2126JODVYqGnX8w9pk8/qIqawKNGnChsvA4d8vTQAJO0bYcUzDhX9ZWDed/9hB62wLM3buXz5VknTCjDl6effXJEk7zJ6mB+7YWufO31fP9k3aYPunV22kixfvoj+u2NSzNM1scAchI6ioWl37Z1jIiGI4N5mRZgAAwEaWo82Mda/90/Z5nX4lbGobrkP9a9r04cs+oN7aaO09+kX9x0f/snadbWYK2se2sS3oNWwb22xzk23eCeMyM8UVZUya2aTw/RSwXgbrg7bZUEU1/easu7Vl+0K9umayPvHL90qqRO6zXtTzOb3bhD2f06ucDW1xaGgDskVxBmhhpZih3L32z1IHjTDhc5Q36lglSVo0MMq3TUU3XXC4pq2u6e0RbTron25TW9vA2vW2b9abFBvCtgt7vagL7aD1JoHD5PuKCiC2gSTqtaPChWknnPc1FLLOz23IuG/50F+Y3dsXqF/STj0duuKi94ZuP1zYqDIpeGZySUJGRnOTGWkGAADqSjnaLIhtU1ug4pvawsx5fYw++8AxkqQvTr9TR+//1to1RRRo0hZpwpZ59y06N4W9TtSyuKKMItZnn5nqvn7qc/rA+Ke0qjZCH772WPUuT3InTFBhxuT3ufwNbYw0A4pFcQaAU06CRl2uQSOsIJOsE2xC+9CF45KBTt+aDvW8M0YLfnWAOgdrem7coE79mvdqJu6uEz+Ti9qg7WzCRtB2URfrJqHDNHzEMTlm2HmZFF1sAoZ3nf91go4ftD7qOOs9smKy1rRLk5ZKuy4Y1B3fOkK1yP+l285MjpNByPAzCRfda//MIGQkRQcYAABNznLiQGbP67RVXftnyZragg1t85s/bK5fvLm/JOmXh12vzTdZtXZ9ngWaqMY2/3HSZibvMbLOTGlyk+mEgeKa2eo+sM/b+pfthv7SnvPkMfrzs9WIrZOMM2vuhjYvRpoB2aA4A7SQ2DcNi56hXL94MJiPuoHq2j9TBQ0v78PObazfvtrWJymoODPkLw9M06Z/3laS9NRWi/S+U5/1rM2zQGMSNmwDh3d5XMdWWFAw/QgT9fph35NJUSYqYMTdyZT0v9Nwx3/lT3p6y6Fb8be6ZroWLYi6Um+QkFH/eopvO//cZMchw4uRZgAAIIirNxYTPa+z3oTSHbFjKe6ecT/eTJI+d8leemb15prctlS/Pf0PamurP38mywJNVo1tUYWaqNyUR2YyKSKZZsE0zWxB2yTLTNOnrdSvj7xO7ZWaLn17X/3v9d2x+6xnMs4sTg4NbSa61/5p2NDmldVIMwDrUZwBWlxeM5Qj/6ffbfBiud4945VsjvKE9qHiTM/AyIDzGtr3xh/voV3nj1WtUtHyI57Wtrss8myTpECTVdiob5MkcPjXmd5ib8Pk+GnCheQ+YARtE7bdhg487gU9u8N8Pbn1UHGme97K2H3WMw0ZUYXJkoWMBBhpBgAAohQ12sxYoU1t3uvA9E1tcfpWt+vDlx+l5bVOHTDmr7rwo0971roq0OTV2ObdJu/clDQzSXZFGZtmtrDtotYHbTPcyJEDuub032ujtqX6S9+W+ttf7BuzR9ykgaBlUQ1tHb7tM2po8xdmEza0eTl5RmcMpg0A61GcAeCccdAIGm3mDxomo42kDO+eCRI/R3ldcWYw+M6Z+nZXf/Mwbb1SWtreph0+d5fGjF/t2ca2QBO0j812Loo03m2iAkVQOEj6YXJ8L9uijOuAEbRN2HYb2updizR47GMaqFS0fNLQ84z2HL1QIysDIXskDRnyfF6ikBGEkWYAACBvRY026zY8P2l4U1ukIprazMebPfvyWP3Dn4aeP/OVre7Q6Ud6fzYuCjRR27lsbLMp1ARt4zozRZ1D0Dm7bmbLpjAjST866xG9e9TLWjLYpZN+c5RW9dmMgI56Pqf366AiTdg2Upkb2pw/o5ORZoAxijNAi3E92qw1g0aQDYNGV9vQReOyweirqr4VI/XU9w/ThIFBvTJKOu5fb5U06NnC5DZ0v6TdYPVtk7x21G3zWXaAhb2GX5JwYRsw6tvGbZcsZHSN69NO/3CXetrbtNWqmi757yP11ppRGtU2oD1GLwzYwyZkRHV/Be1TUMioF2+7h69yOTeZkWYAAMAvi9FmTp7X6W9qC2tu8zfFZN7U5na82cW3bqkfvH6IJOkn+16nA3Zf7NkmywJN0La2jW22hZq8MpNNQSZJM5tk9nM2nQJh5m+Pfk1nbfKABmsVnXHn8Xrh9TEWe8dNGoi66yvsTjIa2gCEozgDoBwzlINkHjRGKDxohD1jw7vOv8/6bUesvZNh9WB7yDHW7/vy7Elqv2EPddRqempSn0499z7fdkV0g5mEDe+2YYHDJKhk0QkWdg424UIh26Xp/BraZs/RC/T/pt2rq6ffrIs3v02f2+gpvavznZBzkqRBHfuN2/TKKKk6MKhn/vtwLesdpQeXT5Uk7TdmfsS+UrqZyY5DRp1pyIhS4NxkOsAAAGhuSe+I3aDJw0Li53U6k7apzXvHtSm7cWj/eMnumtW7i0ZV+nXNMVdr681XeNbmXaCpb2ubm/zHLDIzRZ1DXIHIpJnN5ZSBsG03tPcuvfr+u2+QJH37hffphgfiZv+5HGfmVYKGtgBleEYn0waADVGcAZAJb8dFOYOGf1lYAKl/bRc0RlSG7n7pj/1nduiC7d7rZ2jrJ7eUJD2z3Vua+TdP+7bLI2ykKdJ4t48KHSbHSSLq+CbdZyZ31ZgGjPq2wducOP4F/XrL2/X+8a9px1GLtW/XW/rc5Gd0Vfdt+vUWd+jQrrmSNrxgPe3c+/XUpD511GrS9XvqpWcnSZIeXDEUNvYds8D3eiUOGdWQUwuTw9xkFwgZAAC0mJCJA15hEwdSNbV1R+zob3IJe15n/WunTW1B3N49MzjYpg//5FA9vXoLTWnr1U0fu1Hjx3qvu5NmJpvJAybTB2yb27z75ZmZ0uYmm8xU396/bdw2UdtuaPONV+m6467W6Mpq3d67k7522c4xeyQZZxalZA1t3Wv/TNnQxjM6gWxRnAFaUN6jzYx1R6xzHjTq/BdIcUHDbI7yuuJMrc23X7hrvrePdpk/VrVKRYsPf1a7HzjXt4XLAk3ajrCkhRrv67j8sH19/3n494vaxv86cduu32Zax3J9bZOHJUnX9XTrM28cpPPn76Vbl26h/lqb9hyzUD/c/D79aos7191Jc+TfPKdnthsqvmz15Ba6/4Zt1h3vkRVDfxF2HLVEI1R/7kzSkBE1si/nkBGl22AbMdIMAAC4l8UbjamaSPzP60wt7lrOfVOb2bXr+u2WrejQBy85WvMHqtpx5Fxd/ak/qL3NdCy0ZJ+DwraNK+jYZiaTaQBZZiab3BR1vkFMpgwoYJu44643ZtSAbjzrZk1rX6y/9m+qk39+mGqqxO63nuk4s7CGtqDndtbR0AYgHMUZAJJKOEM5l6ARFT78F1dRz9+Q/EWdenFmTa0tYFu/9fv+7vz3abtlbVrR1qaNTn9AU7dc5tvWVTdY2LZx29sGDu9+phf+tpIc36QoI7kOGKdWn9eotgE9vGKKzpu/j+5dvql+17Otvjhvfx3+4jH62aIdtWqwXe8Z87au6r5NP97tTtX2/osGKxXtvKBL135vnw2ON29Nl3oHRmhEZVBbdS5VdLiNCxnefQoOGRnNTWakGQAAsOFitJm3CSSOk+d1Znr3jHeZm6a24cvjvfbmaJ1w3YlaWRup9417Rj/8xKO+LfIo0ERtbzuBIGjfMmUmk6KMaTNbfXuT7cK29avp0k/dp907X9U7g2N1zKXHqmdpXIHQ5ncurDATtF3UGECT9yE8mrChjWkDwHAUZwDkopxBoy4uaMj3dVC48Kqob7BdkjTKuFln6Fhr+jp07wVHaurqQS0Y0aZ9vnqrxoxb7ds2z7BhU6SxufXeZCay6YepsHCRX8DYv2ueJOmKJdtKvk6uhQOj9d2Fu+mol4/SDT3TJUkHrXpLF/2spk/dVdMN5x82bB+popdWD/1Obzliacj5SnbjzIK2iwoZQcsNBiVX1/45JWqjterF2u7hqzKbm2yBkAEAQIsyGG3m5Wy0WZBCmtr814lhTW1ux5tJ0kNPV3XWfSdKkj696X06+0R/Y6DpXSt+SSYPmDR0NUpm8p5r0PcRtV3QucdtH/ezjXfh6c/ohImPaXWtXR+edZLmvDYmZg/bhrao4wQ9p9bf0BZ0LBraAAyhOAO0qCxHmzV30PALDhrLB4eWjWnrl1kAWW/hm2M1/2cHauzAoF4cU9Mx/zZLbW0Dvq1cF2jSFGnSFmqyFHQ+LsKVf5+gbTc0urJG23UukSQ9ujK8IjF/TZcu7N9N/31am/66mTSqX3rfQwO6dtptOnzsG/I/j2b54NDv1+g2/9+hpDOTbUOG5V0zXcM3kTS8+OogZHhZz01mpBkAAAiQ9Wgz66a2qOd1mja1BTJpavOKKsiYsBtvJkmX37Wp/vWvH5Ak/cfON+jYA97ybWt6l7+f7eSBqH3q+zVKZkrbyFbfN2gfk+2ith/u00e/oq9udZsk6R8fPVZ/eGSjmD1MCjNB2weNM/Nv5//7YTu9I0B17Z8FN7QByA7FGQDrFDFDOfOgEXn3jMkdAP4Lr7g5ykPbLh8cOXQ6bUEXnPFB4+mHp2nk9e/WiFpNz1T7ddq/3y5p0Ld9mm4wl0Ua775hoSOv4BH2mnFdY0kDhnnImD5ysToqNS1a06kFa8K7uTpGrtHM82/V/d1t+q/TKvqv9p20oH+0thi5XD/Y7H5dOf02zRz7mjorQ689tm3oz4Ga966aNDOT5dsug5Bho2Rzk+kAAwCgtVjdIev4eZ2Jm9pMxh55Dbt8y6+pLc14M0n6xuU76Ddv76P2Sk2XHnaVdt/efzd50gJN2L7NnJmk9JnJxRgzs5/DMQcs0P+8+zpJ0n+9crh+dPNWRvsFi2qojLrDy5+hOhT/PkPjNbQx0gzIDsUZANZczlDeQBZBY5iwB1sGfR7XvRU0R3lI/c6Z8e2rPdtEHWv4a957w7ba7IHtJElPbrpMp/3TAwHbR1+4jm/r0+FjX9JXpvxZ/7vZ7bpm+nW6f5vL9MA2l+nmra7QDze7VWdNelKbjej17BV327vpRXnYbewuAojJsUxu43cdMMJDxkbtqyRJb6+JKl4M6tR/u03Pjh9Q52BNfVfsq188u4s++PJR+vGid2nlYLt2HrVY393sj/rjttfquu5Z2m30Ig3WpCcC78Yx7f7ysg0ZdRYho/61f25y1N/v7oh1Hk7mJlsgZAAA0OJS3GEb97xOY1FNbXWZ3z1j2tRmU6CJz0xSRR//6X56cMUMjaus0vUnX6tNpvT5tk9boHFdpPHub5qZkuQm28yURVHGtjBjZp9devXbw65SR2VQly/cW1/61a4Ge5lk8rDfRX+GiivgxDWv0dAGYAjFGaCFJR1tFiZstFnQ/+zjOjbWyTxopLl7xmvDTrAFa4ZOYmrHSt82JtZvd9PPd9MOz20qSXpqxwU6/lOPBWy/4UXs5PYV+tjE53T5lrP0x21/px9sdo/OnPSMDhn7hnYYtViTOvo0saNPW43s1aFjX9OXpvxZt299hb636e2aPqJn7VFMZhObhASbi32bj7jXMQlKeQSM+j7SRh1DxZlFA52hW556/t16YpMVqtRq2uiuHfXwHUPPnllRG6H/XrirDn/pWP3vwp00v3+0RrcNaLvOHg3WpP9euLvmrRm79ihJxpmlCRkZ3DVTL8YG/f3PY24yI80AAECErN94TPW8Tn9Tm2lzW+q7Z7yfmxRkTMUXaPrXtOnYn83Uy2s21pbtC3XjWbM0elTQWOikkwfq+9vu431d00KNy9xk8lpRTH9mYfuGvbbp9sG2nb5SNxz/O3VV+nTPsh10+k/eq+HP5vRzNc7Mf4wM75qpf11AQxuA/FCcAbCBUgeNOqdBI2iZadAIH2/2Wv9ESdJWI5eEbmM6g/mqi/bTznMnSJJe2e9FHXbK7GHbtGm1Dul6Qz/Z/A+6a5trde7Gj2rX0YvUVpFe7Buv3y7eTufP30efev0IHffycTr25eN15mvv14Vv7a0/Lt9UgzXpyHGv6Lruq3Vq9Vmtf66JzUW7aUeayUfS7dOeY9JAFvW6Q0avHT9Wv6vK78NffkBPb/WOJGn7J7bQbb/eedg2SwY69T+LdtFhLx2r415+v/7ujUP0/peP00/fqW+bdJyZV2OEjCLmJtMBBgBAa8pytJmz53WWpqlNvq+jm9qC97Vvalu4ZKSOvux4LRkco/eMekk3fe52dY70F2ikbO6iidonaP9mzk3Jm9lMTJuySrd+9Bpt3Najp/s217E/OUL9a+Le2oz6fbJ7Pmzw8zn9v+c53TXjoKHNi5FmQDEozgBIJNcZyvWLjah5qrkGjfg5ys+tGjqhnUf5H0ppP95MquiKrx2und7p1JpKRT0zn9L+Hxy6K2BCW58+PnG2Zm11s/5387t1YNebaq/U9MTKyfr3Be/RwS+cqGNeOVb/9tbe+l3PDN2/YhPNWT1OL6yeqD+vnKZfL95Jn3xjpo5/5Xg9sHxTdbYN6OubPKCvTnlIGz543vXFfBzT13N1Lklu47fbZ8Tan+ea2vD/9Z7w2Uf17E5vSpJ2njNFV31/35jzrWjO6sm6Z/nmeqN/3LpXWC+u+ytoXclCRpAEIcPLOmQAAACYcHSnbarndQZpgKa2dOPNNvTsy2N10i0na0VtpA4b+6xu+fssCjRR+9tkmGbLTUkLV+bf/+Tqav3hk9dr644Fen3NRnr/L45R7zKz340NmY4zS/J8zqDjlbehjZFmQPEozgAtzvVoMy+bGcrlDRqjFTxrVooaBfX4yqGK0m6jF6irbbXSdoLVam265twP6F297eprq2jKwY/oh+++U3duc4O+vPET2mLkcvUMjNQv3tlBM1/6oD7y2vt12ZLt9fZA2IPnN7xAfmH1RP3tG0fqO2/tJUk6fdLT+tKUP4fslyRwuAgetq9nIulsZfuA0VEZlCQN+G65P+qMp/XSe16SJO3yxgRdfsGBEcdedzSDbaT4EFySkBGkXpTtHr7KZm6y9UgzLzrAAABAiDwnDjhtaqurrv0z16a28Gd2xjMp7kh3PrqRjpt1qsMCTdKskCY3ZSlJRjMZtxa1X9S5mBk/bo1u/8yN2nHkXM0fqOp9l56suW+bNIeZjDOLK8z49/f/LrZgQxsAZyjOABimiNFmGwgKGnW5Bg1/J1hQsPB/PXRh9mr/RL28uqoRlUEdMfalgH1MO8E8s5T72rXom3vpgt8M6Fu/qOnQ5W9pdNuAZq+q6mvz36NDXzxW//X27nq9f5yShY2KLl68i742f39J0lmT/qLjxz8fs69tl5btTGTTfW1DTNqgFXWOwepFmYrnjqTDTpmtNw95TgOVinZ+e7R++y+HK9n/mtOOMxuh4QEkx5BRL9JYhgwvRpoBAIDSSDHaLPOmtilhG66VaVObfOttx5uZFWjueHiywwKNZPZMmah9XWamtLnJhunzdML2TbLfcF1j1ujWv7tFu3e+qkWD43Tk5Sfr+VfDmhC9bJ8zEyaoKbPFG9oYaQY4Q3EGgJ08Zih7da/9M2iOaqZBI+pCyyxoXNezoyTpzElPqF2DMScXXqDprKzRyRNe0I3ds/SDjR7QjNdqGqxIf9q+ov88raKvb7ybrunZRqtq/vNIFjau6dlO/7Nwd0nSP2/ygDbtWGywf9pb6SW3xZcgpg/YjNo/SvQ5rh5slyR1rr2D5vAPz9aSmU+pv1LRTotH6MrzZkoBI8+GS/owy7BjhRUapcJDRoxM5yYDAAB4WE0ccCRV84lJU1tdbk1tacebRYku0NwcWaBJexeNyXGSNrdFvVYWucnkPLNpZgvS2Tmomz57u/Yd/YJ6BsfoA1d9WE+9MC5+x0STBpI+n7MurKEt6O9OAiYNbTFoaAPKh+IMgNyChrczI+iiIK6zY500d88EXg9FXSwln6N85ZI91DPQqe0639EnJj2+dql5J9iMkYv11SmP6K5tbtA3pj6ibTp7tWygQ796Zzt9uOdwXXVUux7ubtfET9+nXfZ7M+gbU9Kw8eNFu+nRFZuoq22NvjXtflXUb3Acl4HDlbThwnuMpPsPWb228NJZGdDhH56txe9/SqvbKtppyQhd89UParDfJECYFmbC5iiH3aYfFDK8XZAK+DyHkNE9fFFuc5PpAAMAADGSvBGZ2/M6g/ib2qoh22XW1ObncrxZ/Xjr3fHwZB1/y1CB5vC1BZpRnUEFGsnNXTT14zRabjI9n2yb2fxGj1qjmz9/mw7pek7La5065oYP6+HnJhjsGTaOzL/edJxZ2PM5152p4v8++GTV0JZgpBkNbUCxKM4ACJRnx0MmQcNY0B0AJkHD+3nwHOWewdH6z7cOkCT9w+Q/6ROTHlObBhVVoOkesUJnTXxGl285S9dvdbNOnzRb1fbVemN1l7791h469KXjdNHb79azb07Rw/9+pKavqmlxR5vGf/I+7XHQGxHfp90zWAbVpvPmH6AVgx3ae8x8HTu+/ua3aUdWv4oJHTav62IEgXnA6KsN3TkzdepSLVlbmHlXzwhd808fVH9fksJMmDQzk/3bZRQyojgKGWG8IcMUHWAAAMBayMSBMJmMNqs3v5g8pzPx3TMKWGbT1OZyvNlwtz+yYYHmps+7KtC4KNJ4j5VnbrJ9zXxzkySN7Vqj2z5/qw4f+6xW1EbqxFs+rPuenGR1jPWyGmfmbWjzvg4NbXU0tAHxKM4AsFeGGcppgkbkdVHc3QJ2c5Sv6d1dl7yzh9oq0jlTHtIN3Zfrsxs9rMPHvqx9x8zVYWNf0VmTntJ/TLtHs7a6SrdsfY2+tPHj2nX0IvXXKrpt6Rb6zBuH6v0vH6f/W7yTlg+uvzB8643x+tO/HaWtVko97W0adcYftef7ot6wtrkg7tfr/eP1v4t2lyT9/eTHNbLi39/mtnl/AEgbQNIcy7S4ZHIcc/XijCYsU19bRTv2jNC1X0lTmDHt/graL8nMZIchwz+SMEj38EVJ5yaHdaXSAQYAAGwUPdqsHHfPuG5q83P7/BlpeIHmxs/fEVOgyaJIkzY3JeEiM7koytjlponj+3Xn52/SAV1/1dLaKB1980d02yOmXZhJJg2YjjMLe52wZzCFaICGtiRoaAOSoTgDQFKLBY1hIp6fERg07OYo/8fbh+r8+YepZ6BTW3cu0ecnP6wfbHaHLt7iFv3PZrfrS1Me0dHjX9L0kUvVX2vT/cs31b8teI8OffEk/eO8g3Xv8s00uO6f6w0vNhfO7dL95x+lbVZIS9vb1PGRB7XfUS9HfK92weA3i7fV/P4x2nTEcp1SDXt+UNr5xlGFG1fhxPaOH5Nj2dlj5tB/l44BacclI3T9Px3luDATxuXM5IxCRr3YahgyvDKZm0wHGAAAMGT6hqSL53V6FXf3TJ3bprbwYov7As37xj4TU6CRbBvb3NyBYvIajZKZ6sezM2Wj1br7s9frPaNe0pLBLs289iO661HTO2ZMfz9sx5mFfR1WiKxr3IY2RpoB+aE4AyCUyxnKpQsaie+eCZujHD7eTKrodz0764iXztS/zD9MN/bO0JMrN9Gcvkn6y8opuqV3a33v7b30mTfep/e+cJo+9cZM/XbJznpnYJTBNyYtfmuM7vraB7Xd8oqWt7VpzYce1pGnxo1uMrtQ7qt16IeLdpEk/e2kv2hEJSrA1I/r4kGUadk+FDO7gCFJJ5/9oN7YZei5QOP62nTtP31Qq1fZzteuS9r95V0XNzM5aHld8SHD+dzkCHSAAQCAxFK8kRn3vM4NFHr3TF1cU1vS8WZJhRdoVloVaLIs0pQtN8XJtgi12dQ+3fvpa7Rr52taODhOh195mh58qmq4d9zvTJJJA2Ff09AmiYY2wBGKMwCSsZyh7FWKoDGMP2jEjXWy7wRbNtipq3t21j+9eYROe+0kHffKqTr1tZP1pTcP10/f2UP3Lt9KywdHBryu3/CLyN53Ruu2cz+o7Ze2a2Vbm96a+bRO+OyjIfvXmV00X9ezjd7sH6PJHat01Lg5SnYrfNbBI+nrZN3lNqiP/Mvdem7XuVrdUZEkrVwwWmuM7piR0nd/+bcJCsFBnxcQMoJYhIxUc5PpAAMAAIaKmDjgFTumqDtinUlTW1117Z/Omtq8n+cx3ixo26ECzXGeAs3tf/97Ta6ujjiGZJ8DbO9cybNYk+S1bMeh2XvXVst0/yd+px1GzNP8gaoO+c1peuyvQUXAIFlNGkjS0Bb1voKFhA1tXjYjzcIa2kzR0AYkR3EGwDpNHzSs7p6ps52jXBd0a75JJ1jSN+2lZb2jdP2XjtG7Fo3SmkpFc/Z+Waedd6+kwZhjRV+YD6hNly3ZXpJ0+sTZkmpyc6u8TfEm6X5BksxbttfesUYf+fat+ss2CyVJW8yfIEnqrMT996gz7f4y2SZqZrI3jBQYMqKKrh6Zz02mAwwAAFjKarRZoud1BglqhglrausavumGsm9qC98+aLl/nV94gWZ5rVMHdP1Vj3zuCr17h96IY0j55A7/a9nmn6T7hXH9XM9gR+y9SA989Dfq7nhLr66ZrAMvOVXPvDTWcG+bwoyLcWYdih/pF7Isq4Y2j6B/E8Ia2kymmzDSDMgexRkAkZoqaMQyeQPaNGj4Pw+64Ava3s8uaKzp69CVXzpKu7w2VAh4aru39DffvlUdI9KN7LpqybZaOdiuHUct1p6j3/KsSfuQyqBzyOKOmyTnmfy1q1NW6sTv3qi/TF2uSq2m7Z7aVLf8dG9JUrtRcSbvkCGVJmR0D1+U6dzkCHSAAQCARLzNHo7e0HT2vM7SN7W5fv5MsNsfmawDLz9dr6yZoukdb+u+D/+f/uZ9cw32TJoRXGYm73m4vuMmSSNb8tf+9NGv6qb3X6pq23I9smpr7fOTU/TCG2MM93aRmfz72o4zi2poS8i2oS3g7z4NbUDjoDgDILmmDBpRb0bbzFEurhNMtTb99utHaIenp6lSq+nJqct1wndv1JRpyyKOUxd8cd0z2KmbereSJJ04Iayw5rpQk1bw+UxsX6XtOxdrxsjFGt/WF7BfuoCx7c6L9N5/v0nPjRvQqMFBbXrXu3TNf71XAxoaa9ZRibtQdV2Yac6Q4WxusvffMUIGAAAwEHRN0DDP6wySW1ObyTM7/V/bTB0wHwstSY8/P057/fAU3b1sR42prNal+1+h75z5pNraTJqp0hZpypKZpHR3+CRTUU3fOfMv+vGeV2tkZUDXvvNuHfT9o7VgYWfiY5oV9fxcjDOjoQ1AchRnAGwgi9Fm3v/Z2ygmaNT5g4btHOW6PDrBgsPGVd/ZX1PvepdGDdb03NhBvfsbt2iXfeZHHMdreNi4tmcbSdLMca9pTCXuwr1f+QeP8Nec1L5K/zD5Cf1+q+v0wLZX6drum3X9VjfroRm/0w3dN+rLUx7Vjp1vK23H2QFHvaxN//EPeq2zoklrBqVL99Wt/7eTJGlNbeh/uR2RY+bSFmai9i0gZERxFDK8TEKGKUIGAABwpmzP6zR5nkUmTW0K+DysqS3sa7s7ZMKPOWTRkpF633dn6gevHyJJOmf6H/T7f7xDE8fncedIEZnJxeuma2YbPXpA13zhHp0z/Q5J0n++fLhO+sHBWrnKdMS3FJ+Xw45lO2kg6POo3/0WaGgD4BTFGQCx0o428wobbVbOoBF0YWU7Rznoc+/XLjrBwo9z+//tpP6L36sp/YOaO7JNXZ+6VzM/8mzEcfzWX3g/sWqyXlk9TmPa1uiIca9ZHEMaHgDSBBC7Y50w/kXduvV1+sxGT2vLkcs0WJMWrhmld9YMdWVt29mjj096Tld336pfbH6X9h1jWsDyGtSHz35Qy05+RO90tGnLvppe+8779MidW67bYk1t6M6Z9tA7Z+IKM6bLbWcm10UVIBOGjHpxNOrvo2HI8Aqbm2yCuckAAMCFMj2vk6a2sHV+wZlpYLBN/3Dxu3XWQx/SqtoIHTHuaT3891dqlxkmkwfqXIwWC8s5SXKTy2PVpf8euzddqfu+cIOOrz6m1bV2ferPJ+or/7ebamunDJixeTZn2nFmcc/n9C4LuZOsBHfNOMW0AcApijMA7GUwQ9mrPEGjLihU+NdJduPNwr6O6/ixDxtP3r+Znv23D2ib5RUta2/TvCOe0Ue/cYc6RtpcWK+RNKDre7aWJB07/mWLfaNEhYZ0YWKEBvTvUx/UBdMeVFfbGj21cpLOnneg9nnhwzroxQ/pgBdP0H5zTtDZ896rWb1baE2tov26FujiLe7WDza9T5uPMAtjYyf06aPfuUXP7jpXayoVvWvxSN331aP1yuxJG2w3sPZ/ucFjzUwKM7bdX6bjzMLCcl3CkBGkXqQxDBlezE0GAABll8fzOpuvqS3peDO3BRpJ+uWtW+qgqz6muQOTtE3HAj1w2v/pQ4faNm65fP6LVzaZKZ67Z9qcctg8Pf6J/9OenS9ryWCXjr7lI/rZrG7Lo7iaNJB0nFlc8TGFnBraGGkGlAvFGQDDuHoT0jZoxGqIoJG0E8zmVn37sLHgtXGadc6x2vnNcapVKnpi+hId993rNX27JTGvtaGblm4hSXrPmAWqtq+y2jdPbRrUt6f9USdOeFEDtYq+//ZuOvW1D+j3S6dr+WCb6uGiZ7BTv1+6pc55c3+9/6Wj9ZvFM7SmVtHh4+bqxu5b9PGJs9UWMYZsp73n69CLbtATk1eprVbTDs9O1ZVfPFpLFw+/QB9ZGZAkra75/9drW5ix7f4qKGQ4uGvG9dxkU4QMAADgXMZNbYG6A5bl2tRm8sxOeba1aWrzH99NgebhZ6va80en6cEVMzSuskq/O+gy/fvHnlFFthk5qyJNXtyd/6jOAf3s03/W5Qdermrbcj2+qlv7/Oqjuv2RqFtGgiQtzPi3TzvOzGFDW9CPoGwNbRbP6ARgj+IMACORb1ammKHs5e3saNygoYDPTTrB/EyeJxImePvVq0bo8nNnaov7ttPowUH9taumLb9ym448dbbxkef2j9Wzqyaqo1LTYWNfVTlDR03nb/JnfWD8q+qvtelzcw/RT97ZUTUNKOpc563p0gVv7akTXnm//rh8E3W2DerLGz+hS7a4K+AumkF96B/+pNGfuVcvj6powsCgxl+7h676jwMU9r/WUWuLM6sG2z1LTf7bpu3+Cto+LmTEPOS1wLtm0sxNZqQZAABwKevndXqbT+Js0NRSaFNb1LQBF+PNsh0LLUkLFnXqoP93lH427wBJ0j9vfatu/Mc7NW7sQMTxwri78yR77s91lxnL9MjZV+uTU++XJP33a4don/86Xs+/1hWzp5/NCOiwfW0mDZg+n7P4u2aKamgLwrQBwB7FGQDJ5Ngx0bhBw3S8metb9YO2X2/WL3ZVzw8O1harpMUdbZr3/qf10W/fogmTVsYcc8htSzeXJB059g3P0vKEjg9NeEEnV1/QQK2ic+btp3uXR1Xthntx9QR98o1D9PX579HywQ7tNeZtXdf9e5004UVJNW2y+TKd+r0bNPvdr2tVW5u2W9amVy48QvffsG3kcUe3rS3O1OrFmbD/Rq67v5LOTE4h6V0zHsxNBgAAjcjl8zq9wkabla+prS7qLgPb8WY2UwfcFGjWDLTrUz/bW5995AT11Tr0wQlP6okvXK6j9ns74nhxypOZ1svqnGr6zLGv6qHTfqWdRr6hhYPjdOztf6Mv/PLd6l9j+1ZkVpMGvJ8H/e51yGzEefnumsmjoY1pA4AbFGcABAp6MzLvGcqNEzTSjjdTxNfZFGief3wT3XXOMdp53tCF5BNTV2ifb9+og499KeaY0m1rR5vt17VA49pWB2yxRsUEjzXarnOhztv4EUnSdxfuqjuWbZHwWBVd1bONjn/l/XpkxRSNaVujf5v6sC7fd5be/U836+nqGnXUatrxuam69u+P19wXJsYecUrHUPFr4ZrRSl+Ysen+CvrcZJxZBnfNxGFuMgAAaGYZ37kb29ySe1Nb3OjcuPFmSaYO2K6PvpP9RzdvpSOu/6jmD1S1dccC3Xzkr3X9F+7S9GlpRzwXl5myft3JE/t11T/cpx/tcbXGVPp0z7IdtNuPP6Yb/2jXNDfERWHG/3VQM6V/+6hJA3HPotXwzBSlSRraACRDcQaAGy0bNFyMN8vqVv3oY61a3qnLzztSE6/fTRv1D2reyDYtPuFRfexbt2ri5BWh+73SP15z+iZoRGVQh4ydG/P60vAAkDYIBB+vQ4P69tSHNKptQPcum6ZfvrNDitcYMrd/rM58/VD9fGB7DbRJuy7u1Xm/qumgFwdVuWQ//e6iA1Qb9gyZYNM6lkuS3uwPu4Xf9ncgrvsr6FimM5P92/i+rEbsEnXXjGEHmPfve+ZzkwEAAFJwNdosTVNboJiml3ya2uqSNrUp4OskUwf86/2iCzT3PTlJO/z36frx3IM0UKvo2OrjevZvf67zT3tOI0aEP5/STj6ZKUttbYP6h+Nf0vOfv1gnTXxEa2ptOn/2+3Xof31A894eleCIrgozcc8zMhln5l8W9YzaADZ3zcQpaUMb0waAZCjOAEiuWYNGbIEmaieb8WZJO8HcFWgk6b5rZ+jxc4/Wu94ercFKRY9vulS7ffsmHffJJyQFh43bA0eb2QoLDHEfwc6Y9FftMGqJFq8ZqXPn76OaKinObUhb24BO/PtH9adz5+i809s1b5K00VLps1cM6uBnFqkj5OcTZNvOXknSG/1jA9ZG/Te0udMq7TizlCEjSFDIiJmbHCRNyIjESDMAAOBYkokDaXibWQKb2ro9n0d14peiqc07Ysrl1AH/er/ozNTT26G/+/le2vvKT+ihldtqTGW1/nW7WXruK7/VB/d/K3LfdNxmpqwc/O7FeuKfrtH3d7tOE9uW69nVm+nQa0/XN694V8JclqYw498nqhATNs6sLqqhzXDSQBQa2oCWR3EGQChXo83SSB00ou6eqa790+QN5g2EBY28OsFM1vtFh43ehWN05Zc/qMk376JNVte0sKNNLx7wgk77wXXadd8Fw7a/Y9lQcea9XfM1qlL8zOTNRyzTZzd6WpL0H2/vocUDSTqzNnTgUS/pmB9eq9l7vqqe9jatmVjReR376XdLtlZbRfrURs/p0i3v0PQRvUbHe8+YodD22MqNfWtcdH+lGWfmX5ciZJT0rhlGmgEAgNJI8IZnWFNbLNOmtqBrNydNbUme2elf5v88rmCTXYFGkh6bPV77/ccx+uSfTtJbg+O1TccC3fS+S3XDP96prTY3e4ZnM9lskz5d/vn7dPcxv9QuI1/TksEx+tJTx2jXi07W/U9NSnjUtIUZF5MGTBraDCW9ayaHhjZjjDQDMkNxBoA7eXZWuLh7JkqiOcr+bbPsBDNZ7xcfNu7+3fa6/4vHaccXJ6ujVtNT4wZV+fS9+thFt6h7xpJ1283uq2pu/xiNbhvQfmOGF2/y9oXJf9HotgE9tHxjXd/bnepYM3ZepI/+1w1a/OHH9OJoaczgoHZ4Zppu/fvj9ef7puv8BXvrC3P3V8/ASO06+h1d1/17fXrSMxqhgdBjdo9Yrq1GLtVAraJHNyjOuC7M+H+/gvY3mZlsIMeQEcYbMozRAQYAABxLM9rM5HmdYcKe1xnb1BbFaVNbXdKmNtOpAzZjoYOOFbe9X0W/+P10bfe9M/SjNw7UmlqbjpnwhJ4+62L96988p87O8FzQLDpGDOrcU57Xc5/+hU7Z6GEN1ir69YJ9tcP/flz/dc0MDQwmebuxQ+4LM1HPmYkbZ+bw+ZxBStDQxkgzoHgUZwCkk8Nos9iLjW6zc8gmaNRFddNk1Qlmst4vPmz0LR+p3/3bIVr6vYO1Q88IDVQqenyTFZpy7m064xt3aMqmKyRVdNeyzSRJhxo9dyY723Uu0QfHvyZp6K4ZJRxntvWOi/XRb89S1zl36omNVkuSdnprjF775vt11X/ur8H+9T+725dtoRNeman7l09VZ9ugvjDlKV2/1e919LhX1DZs1Fm7vjDlSUnSPcs31dLBkWuXJ70tP+zroN+bsHFm8i0rR8jw8v69D+sA80o0N5mRZgAAICORd+QW3dQWNXHAq/CmNv+yqG284jJT1L5h2w/Xs3SEPvuL92ify8/SQyu21ZhKn87fdpae+dIVOvqAt42O0Wja22v6yBHz9MyXr9C3drhJ4yqr9Piqbh147Zk6/cfv1YJFnQmPHPYzN/lvGbYu7jkz/s9tn89pobr2zwIa2hKhoQ3IDcUZAJGsR5t5/ifuarSZ1wZdYHFBw+Q2/SDVtX+mmqNcl6QTzLZA45e+QCNJLz65sa76wnEa+7s91b2iopVtbXp0+hJt/e836aPfvF3PTKhKkg4ZOzegIJGff9ho6E35Wb1baHbfROv9t991oU7/1u9V/fIdemLqcq2pVLTdsnZ1XLKPrvjKUZr/SvDzV+av6dKn3jhYX563rxau6VT3yKX6j00f0u+3ulnnTHlCB3bN0/5j3tJ/TrtfM8e9poFaRT9cuNvavbPs/go6jsnvagLVtX86ChmBXZ4uRYQMRpoBAIBMlbWpLaiJppRNbWEZKelY6KBt4rYP9tjz47Xffx6jTzx0khYMTNA2HfN14+G/1n1fvllnznxdo0c1/p00kyf26+unzdar5/5Kv3nv5dpuxJtaNDhOn33keO150Qn641P2OWy9pIWZLCYN2DyfszEa2nhGJ1Bu5v+3AQCHdtVT+ot2kTR0sfC8tpc0dBHxgraRNHRx8aK2lTR00fG6tgg/YLekVwxeeLKkhb5lVUlLNHTxtDzuAKMlrdTQBVnYs0bq29T/7ND6hzLWPx8hqd+3zC9qm7j9veuDhL3mcA/dvJV083Qd9jfPasRBz2luZ5ue2LJHY77yqFb9tzRZfdp11Dt6YlXUE0SzseuoRTps3FwN1Cr6n0W7WOw5qIOPeUWbHfaUnqmu1mOViqSKtlneprdn7a5rbtra8DgV3by0W3ct20x/M3GOzpo0W5uPXK5PTJqtT0yavW6rgVpF31iwt57rm6RkhRmTr4OCa1QIbc6QYYy5yQAAwKFjajXdWEl2B3flHql2sNvz2Xjr1/XWS2vz0+Y16Q3fuU2T9KZvp6CsNEVS2E0gVQ3lqHrsWcefmfz5qJ5VvDv6c1L9OrZfG2abqBwUlZmC1gdt41U/B5PcVNHFt07XVfefrm+d8qQ+vfn9OmDMX3XAvn/V9/YZo6sX7KGfPbiDHvrLeCW9078Ie+ywVP946DM6ecojGl0Zmi6wcHCcfvnK3rrw+h21uHdkzBHi5FGYidrH0fM5o1TX/lnihrYNGmtpaANy1bR3zjz88MM66qijVK1W1dXVpX333VdXXnll0acFNKQsZiinYXX3TJ2zu2f8wu5IcN0J5t8vbp1/fRCb+nyb7vzNzrr1Mydp6m07aasVFa0Y0aZHth0KFp858Q86+fMPq7pRn8Ux0/vC5L9Ikq7v7dbLq4PvcPGauvkynfbFh3Tcj67WopMe018m9mugUtGMpe2qXr+brv/cCfqjcWFmvRW1EfrZO+/SYS8eqy/OPVA39Xbrr31VzemboBt7u/WR12bqqp4ZSl6YcdH9lSBkRP3+V9f+mSZkeLgKGUnmJgehAwxAqyA3AdlxPdrM5O6ZWN0By5LePRMpzXizLKcOZJubepeP0Ocv3kvb/fzTuvDFI/T6wEaaUFmhs6Y+oAdP+IWeOe93+qcPz9Emk6Ma6YrV1jaokw+br3u/fLMeO+VnOn3jP2p0ZbWe6ttCn/7zidr822fpK7/erSSFmajlppMGHDyfs1kb2gBkrinvnLnrrrs0c+ZMjRo1SqeeeqrGjRunq6++Wqeccopef/11nXPOOUWfItDwbpwjHTMjZOWdkg6zO57J3TOxujX87plNJPmfV+/87pmwTjCvNJ1gUd1cLu6gUcB+IWptuuOyHaXLdtCeM1/RvBmPS8+u1jYv1/Q/R76qHfd4WdssHqNFj2+le2/cVst7k84cjrf36AXar2uB+mtt+t+FO4VuN35inw4+7nlVd3tRz07o11NtQ3fJjBysacbCLj197R669sGYyoGhlbVRunXZdN26bHrAWpeFmaDPTbq/LENGXXXtn65DRszcZJOQYYwOMAAYhtwE5OxWSTPjN9v9nef0xKQdJW04ccCEd+LABnfPBMn07pk4/rtnbKcO+JcpYBv/aylkfdA2fuaTByTp5Xmjdd6lu+iftbNm7rNQn3jPszp60l/0rhFv6Ns7vqFv7tCu23p30cVP7KQb7ttYAwPF3k3TMWJQ++7cqyN3mqfTux/V9Pah/+Bram26Zcmu+u59u+nuxyfJzV0/UW9F2hZm4sbbmU4acDDOLEh17Z8lu2uGZ3QC5VGp1Zrrb9KaNWu0ww476I033tBDDz2k3XffXZLU09OjvffeW6+88oqef/55TZ8e9MbZcL29vZowYYJ6eno0fnx8VzbQrIJu0R9WnPEGDU9xxn+Lfj1oSNogaNSLM5LWFWckbVCc8Y422yBo1G/Tf8XzQvWg4S3OLPT9Ka0PGks8y5b7lg0LGvUFvb6v63/2hyxf41m/xrdt0LL6Pv6vvfwhIig0mHRmmYeNurFtq/XAttdqRKWmb5zVpmc2WX9DZtfgoLbv6dSKFzfTk3dtpTnPuLqYl6SafrPlHdpj9CL9ZvEMXfDWnhus2+ndC7XzgS+pfet5en5cv1a1rT+vTfukrhen6f7LdtXbb4xzdD6SeeeXf9skhZmg7q+okDFaqUJGde2f/hnj3pAx2fentL444w0Z3Z7P1wYNb8hwOTc59PZ8X8gIKs4QNNDquAZufi5zE78vwHqxuSkkM0kb5qZMM5O0Pjd5izP13OTNSvXPvcWZJWv/jM1M3oUmuSkoM9W/7o9YFrSPd5kC1gWtD9rGzz4z1U0cv1qnH/qaztzhSe0+6tV1yxcMVnXZ3D113eNb6skXx6mnN/s+6nox5rAd5+vgaa9q764XNbayat36xYNd+uVr++gHt22vV960bOyKfuWIdS4LM3GTBrzrTadhWI6Brnq2qecmb1aqf+5taKvnpm7PspDiTFBuimpoC8tNkSPNPLmJzAQM5+IauOnunLnzzjv14osv6uMf//i6gCFJEyZM0HnnnaczzzxTv/rVr/T1r3+9uJMEGlDZZiiH6lb03TP1TrCou2eCpJ6jXIZOMMllN5gkLRscqT+v2Fj7dy1Q10U7a/Q+o7TNgXM0b+NevdPRpscm9kt7vaLOvV7Rkf0DmtYzRn1vTtb85zfW849vrHlvjFWSgs1BXW9qj9GLtHKwXTeOmq73nTRHG2/9ltqmLdKb41ZqwYh2PbNu6zZN7q9pyvyJevbWHXXb/Zsmes1wNgHDv73N3TRR29iOM7NQXftn1MNfM7hrJi3TkWaEDACtitwElIDhxIE0z+vM9O6Z+tSBqhLePePnz0xRUwfCMpLpczuD1gdt42c5ecBjce9Iff/6bfX967fVHjN69MkD5+jUzR7TJm1L9MUt/qAvrv3P9ObARP111VTN7pmi5xZO0jNvVPWXl8br7UXJ38Lr6BjUfjv36rAd39RBm76qfbpeVFdlw5HUiwe79EDvtrrlha11yR1baGWf67cMsyjMBH0dNQLPv8xxYSZIBmOgwxrawkQ1tIXiGZ1ALpquOHP33XdLko488shh62bOHGpRueeeFEPoAazjYrRZ7kHDqx46TIJGoCQPuqzzhop6iFDAsjQFGimPsHHnss21f9cCHdo1TxfPep8enrW1VBnUboe8oW0PekGrpy7Wq6MGNW9Eu+ZN7pMmz5V2matJJ0k79Q+ourpDnX0j1L5ilGrLR2nl8lEaWN2hWn+bBte0a7CtppGj+jVy1Gq1j1qjtjErdOasJdIi6f59BjT4vrs1X9L8dWfUrvZaTVuubFPn3Ml67u4ZuvuBqcrmMWtZF2Zsu7/qmitkpJ6bnGCmOwA0O3ITUJCI0WbepjbvaDNnNq+tv3umW+lHQgfJpKmtrr6tzVjoLAo0YfuZe3zOBH1uzl76x/Y9dNKB83X6bs9qj3Gva2r7Ek1rX6xpXYt1SNdz0qaSdh3a5+3B8Xq+b6r+2ruxnl24kXpWjND4Uf0a17l66GPkanWN6Ne4EX0a29Gnse1DH11tfZrS1qsxIcWYe17fQn94eqqeeH6cak6b2Ori3npMU5gxGQEd95wZ/zrDzBSkSRvagtDQBrjRdMWZOXOGWmBnzBj+jvHUqVM1duzYddsE6evrU1/f+v9h9fb2hm4LwKdRgkZQwKiranjQqC/LtBOsfmEftKzOtkATtE3Ydn7mYeOuZZvqXzZ5VHuMXqhJ7av0zsAoqdamJ+/aUk/etaUkaUx1pfY49HVtsstcaXKPFo7p1zsdFb05ol1vjqhJXaulSau1ftxBuH2fG9TURTWtGCn99r3tkqTJ/TVttGKk2hZO0Ny/bK5H/7CFnlyW3fNu7AOGf5+0hRn/fjYPs2zOkOFqbjIAtIo0uYnMBIQLmjhQxPM6C7t7JlDSAo00PB+ZFmPyKtAoYF9z/QPtuvzuzXT53ZtJkqrjVmuXrZdrp82X6F2bvKMdq29r+9ELtEXHIk1p69WU0b3af/TzwdfbMbzFmDufnqrHMyvGeKVpZgvaJs0IaP8xUk4aaKGGNp7RCWSn6YozPT09koZuxw8yfvz4ddsEufDCC/WNb3wjk3MDGl3TBA0vk7tnIqXtBAvq+gq7VV++7aTgoOHfJ9uwMX9Nl55ZNVE7jVqsQ7rm6ZrerYdts2LJaD1w7XbStdutWza6ulLb77VA07qXaPSk5Woft1IDY/qkEf0abKtpoFLTmraaOmpS+0Cb2gba1L66XR+7cyio3TNlivqunqGXH5uspxePifleXHJdmAnaN033V8oxZmEzk70chQwv25BhzDJk0AEGoFWkyU1kJiCliKY2L29Tm3fiQCpJmtpM756pL0vd1FYX1tRmMnUgqNEti9HQYfsms2TpSN335Ejd9+RESVutW941ql87b7VMO2/ZO1S0mbhQnR0DWtbfqWVrRmjp6k4tXT1Sy1ePUO+qEVpa/1jZoZ4VHXpn6Ug998rYHIoxdWkzU9A2poWZsGUZTBrwoqENQAJNV5xJ69xzz9XZZ5+97uve3l5tsUXEG78AEilN0PCqF2iqShg0knSC1flDRZpb9f37RG0juQgbdy3bTDuNWqxDx84NLM4EWblktJ64o1tPGG095LjxL2vKtD9pycBIff3ug7R8MKrQ4VoWAcO7TVygSNL95SBk1FUDlqUMGd4OMFvMTQaA4pCZgOwkeV6nSVNbrIZpaouaOhA1FlpKVqAJ2i5I+rtooixfNUJ/em6i/vTcREnTM3kNd/IuzPi3cT1pIEJ17Z80tAFIKIsB/IWqd36FdXn19vaGdodJUmdnp8aPH7/BBwALEW9+Jpln6r3I8F58eC9KYsVd7ARdIFXX/hl0kTXsDeykdyqEXShGXUzGLfOKm90btl2Qjoj9pTuXDd2G/96u+RpVySaQjNCAPrfR05Kkn7+zY5MVZoK+zrn7y6u69s8ShQxTaeYmA0ArSZObyExAtKA3LiPHAhk+H8/bnGL7BuwGTTHejvzugI29zTZB13n168FqwLr6ssiI5L8ejbt+Dbr2jWpckmdZXGYIykwm2SpMK/dAR2fGIS4KM0HH9P+OBC1LOmmgxRraAOSq6Yoz9ZnJQfOR58+fr2XLlgXOVQZgpmWChlfU7cnrhL3xbRs0vF8nLdp4uSzQ1PcffozZfVXN7R+j0W0D2m+M/ymibpxYfUmbj1yut9eM0mWL8/p3PKuA4d8m7rb8qN8HRwEwamZy0DLTkOGVImTkMTeZDjAArYTcBBTMsKkt1zdUS9HU5s9NQYLeiK8vD8o//syUpImqvl36xrbmY5qZTAphJoWZsBxlMwKahrZQTBsActV0xZmDDx66//e2224btu7WW2/dYBsAGTH8n3lhQSPoDeSooBHEqNkmbYEmaFnWBZqkHWEV3bX27pnDxr5heAxzoypr9JlJz0qSfrxoJ62qZR12kgaM+r7+7aK2MSnM+PfLMGTUuQoZ3fEvlQZzkwEgGXITkC3rpo+UTW0mEwcao6ktbPu4qQNBb8jLs21cJjIt0IRtG6TZizSm359N8Stsm6is7P08qoHNNjNFHKIasK6+zPt3I+e7ZmhoAxpP0xVnDj/8cG299da67LLL9MQTT6xb3tPTo29961saOXKkTj/99OJOEGhxpiOHMg0aXpkGDZNto94kD7vAdFWgcRs2/rBsc0nS+8a9oRGVAcNjmDl94l+1yYiVmts/Rlf1mD3TJpk0AaO+f9R2/uObzkt22f0VIKuQ4RUSMmznJrsIGUEIGQBaDbkJyF/kxIEIDXP3TNCyTMebmeaj+vIiCjT1YzRTkcYmM5k2syUtzHj/u7qeNGAwzizqfYLJIZ/XWdw1kxbP6ATKq+mKMx0dHfr5z3+uwcFBHXTQQfrUpz6lc845R7vttpuef/55fetb31J3d3fRpwk0NJejzbwXCcZvuibl6u6Z+rLAoGE63szPtBPMZYEmbLv6tnZFmodXTNP8/tGa0N6vQ7rmGe4bb2L7Kn1y0tDvyffe3lX9tXZnx17PRcCwuSW/vj5qmYvuryA5hozuiNNwIOkbFUnfGAGAZkJuAkrAwfM6S3v3TCSTa1XXUwfqy02u0U0b22wyU9RxGoXN+SdtZvNvY/LfOWhZSceZ5dTQFoVndALl0nTFGUk69NBDdf/992v//ffXFVdcoR/96EfaZJNNdPnll+ucc84p+vSA1lD2oOGVJmjkMt7M5E36LAo0YdsHG1SbbuwduqvluPGvGO8X5+82ekZj29fomVUTdcvS6c6Ouz5cuA4YQdu6KszYdn8RMiTRAQYAIchNQLbyeF5nJgptaoviz01eJk1t9eUmY5+zamyrH6eRijS2mSlpM5uUvAgTtCyDSQNBqgHLgv6OeGXc0Ob9NyKyoY2RZkDhmrI4I0l77723Zs2apZ6eHq1YsUJ/+tOfdMoppxR9WkDTyGOGciZcB41AURd3Lgs0YaHCpEBjO+bMLGzc2LuVJOnAsfM0ub0/4phmthyxVKdUh960/87bu6umSqrjDbENQ7YBI2lhxv+56+6vACbjzLxKFDIiETIAwBi5CSiYYVOb6R3DqZra4mTW1GY63swvydSB+nLXBZqw7aPYNovlKcm5pWlm829nU5hJMwI6CA1tAPLTtMUZAPnLeoZyoXfPVNf+2RWwLHUnWNICjf/zoG3ry1yFjejA8cLqqp5cOVkjKjWdXK3/QiQPHV+e8oRGVGq6d9k0/WlF3NVslKThwnXAiCrMeEOiq8KMl8U4M68CQ0YU778bhAwAANBqvE0rpm/Mxur2fJ62qc3qmZ1By1xPHXBdoHF1F43/mEUWapKeQ9pmNml4bopbZjJdwoRBbm/BhrYgNLQB2aA4AyA7CUebZRI0vGyCRi6dYFFsbts2XR50fP92yTvCLl28vSTplOocdWgw5NjxF/2HjX1Dh4+bq/5aRf/59u6R20a/hqtw4T1+0D5R24WFBpfzkoNYjDOLW1ZAyPD+/c/8rjoAAIAMZPW8TlO53z1TDVhfX2bU1BakDAUa941t4dJkmTxfx1Uzm6vCTE7jzBq0oc0Gz+gE8kNxBkBiDRc0umMOODnk87qq4bJ1XDzoMoxJ0Iha7j+Wu7Bx29It9faaUdq4Y6Vmjns14hj11xj+MaZS0z9v/Jgk6ZJ3dtCLqycY75ucSVHGVedX0PKowpvJ99X4IcMUc5MBAEBTyfB5nda6PZ+Xpqkt6ACuCzRB1+J+SRvbkhZp/K+RJPuk2TdMXFEmSTObd5s8CjNBLMeZVYdvFvg+gvfvToENbUwbAMqJ4gyAbJUpaHjFBQ2vVEHDZLxZlp1gUcvDzsO/rUnYWL9vv9p12dq7Zz47+S9qH3b3TLwvb/yYpo1YoTdWd+lHi3aT2zDhl6QoU98vblvbwkxccHMwzsyruvbPoJDh/b3PKWQwNxkAAEDGTW2mdxY3ZlObzd3iWeQmmwzgPUYYF0WasNfNogATJOmEAdPCTFyWSluY8XI4ziyuoS0ODW1AS6M4AyBfhnNNmydoeBXVCRa13M+2A8q/79D+v168g95Z06mtRi7VceNfitlvQ4ePfV2nVOdosCZ9fcG+WlXL6lZ+k4CUNmCYLM+i+8vhzGSvgkKG99+DpCEDAACgKNYTB3xas6ktaJ1JY1IeuSkqD5hmpiwKNVkwOd80haz6sbPKTF4ZTxqgoQ1AAhRnAKSSNmh4Rb3pmmvQCLqo8l50Vdf+GXS3QWTQCFpm0s3jokATdru+TdiI2mfD/VfUxuhn7+wqSfqHKU9qQltfzD5DthixVP829UFJ0i8Xv0sPrZgWs4ct0zCUtlMur8KMl0X3l1ezhgwfOsAAAECpJXyDNMlz+XJvausKWGY13swrbuqAV5kb2+r7l7VIkzYzZTllICwzBXE0aSBuWdzfB69uz+c0tAEtj+IMgOwlHG1WWNDw8l5Y5d4J5v08SYHGRdhIHjh+u2QHvdg3QRt3rNQ/b/Jo5LaSNKGtTz/Z/E5V21frqZWT9N8Ld4vdx4xNd1ragCG5DxmOur+8qobLbEKGl0HIMJXk3wFJdIABAIDSy/p5nd7mlyRjkdZJ29SmmGXreK93XUwdCFqWpkCTRWOb9xhFFmpGyPw8or6vvJvZvBxPGvAyfT6nV1hDWwpJGtps0NAG5I/iDID8NWrQ8KoaLgsUdCFocgFpU6BxFTaCtjXbb3WtQ+fNP1ADtYqOHv+STprwssIu+DdqX6mfbv4HdY9cqnn9Xfrc3EPUX2uPeN0oNsEi/vuwCxj1W/KDbtX37mNbmAk6H8vur6BxZi5DRnfMfj7ev7emd8MxNxkAADQqo2sPB01tpm/Ylr+pLWhZ0F3kJgUa+ZbZFGjyaWwbfpysCzZZZKYiCjNpJg1EbCJFT8zwyvCuGVNRDW2MNAPKjeIMgNSMRpuVOWjEdbHYBI3q2j9jg0aSW/Ul8wJN0Lq45d51QedkX6R5atUU/XDR7pKkf93kjzq1+pyk+n+HoWO+q7NHl0+/VbuMfkdLBjr1mTeO0MKBMRGv5T2fNOGlQ/kGjDSFmYy7v6oB6x2HDFPev+eEDAAA0NJaoakt9dSBoHVRY69sCjS2kwfq28YVaUwLNf5j2uYeF5lJSpeZ4prWvOu8y8P+O/mPL9kV78YHrAtQjVkW9L6A9++Co7tm8mxoA1AMijMAilGmoOEVFjRM5yh7Je4EM12XtkDTIbtusPpxTIo064/x40W76col26m9UtPXN3lIv95ilj428RmdWn1O39/0D7pi+k3abMQyvbJ6vE579YN6YfVERQeJtN1jcYHItuvNNGB4l8UVZrxMxt45GmeWYcgIm5tMyAAAAK0i7fM6TZ+3l2TcUWFNbbFspw64LNBErTPJFFGSFGmCXiPrzJSmec+/vfe4Yeu8X0f9t0syAjpoXfSiRM/nDNPt+TzlXTNZNLQxbQAoBsUZAE40fdDw8l6MVQPWBy3bgG0nWNTDLqMuZLMOG6aBY4T+dcF79Z239lLfYLv2HLNA5278Z319k4d0xLjX1F6p6aberXXaqx/Uq/0TYo6ZlEmXWpKAMSLkuEE/e5vCTNLb8h2NM3McMkyZ/v02/fdCImQAAIAGk/AO4Kg3aMOa2pKMT1qn0KY2r6g7y5MWaKImD9jkJpvMlLZQ44KrzGTTzOa6MONg0kA1YNegZV7cNQMgIYozAPLTjEHDy3i8mZdJJ5h3mU2BJm6dAtZlUaQZ2u7ixXto5sun6Htv76Xblk7XH5ZuqR8v2lXHvny8vvLmweoZHGVwHFMdMg86LgNG0DqT/x5RhRlZrotelHicWc53zUT9vY/ESDMAANBMEk4cyKWpzXt9GMSmqc14vJnJG/ImY6G9y4Lu6ojKRVGZKojp3Sw2GcaFPDJT3JQB77okhZmgYzqaNJC2oS3s+Zw0tAHwoDgDoDjNGDQSdYKZFmj869IUaJKEjfSB4601XfrpO3voH+cdqb+fN1P/vXAfvbB6Ssyxo/gDhWmQMTnfJAEjbaHMNmxGPbvIoxqzLGjsRJgCQgYdYAAAoFnk9bzOKM6a2rzCGnycjzfzMm1msinQBG2vkHW2jW1h+4VJmnVMj2XCtKgUtJ/3dePWJS3MOJw04OWioc0SDW1Aa6I4A8CZLEebNUzQ8Kp6Pnd+q75pgcbkItcrKmz49w9jEzj8x7b9sGXTsRa0b9j6uM6vqMJM2tvyQ5iOM/PKMWREYW4yAACAh4OmNtOxSKma2uLuJKh6Ps+0qc20QBNWGAjaXr51aRrb6vuVNTfZ3ukTtK8C1ilgnWlmNS3MpJg0UA3YxbssLv+HTRro9nxOQxsAH4ozAPLl79BoxKARFDqqns+tOsGS3qrv/Tzq4tW7zOR2fdv1UUYoXehwweYcoooypgHDNuDZFmYy6P4qKGR4/94SMgAAQKtriqY2r7CmtmrAtt5luRZogpaFNVYlbWxzdVd/1mzOwaQok3bKQNCyJP/NE0wasBlnFleQNERDG9C6KM4AKJWGCBpeYUHDuBPMK+5WfZcFGu/6tGHD9pb4LINHktcwLcqYdoUlLcz4j+3nqPvLq5lChg8hAwAAlJXRNYnF2KHIZhaPUjS1eSVqaotb5qJA4/3ctrEtTZHGe4ysizWuc5N3G5P1toUZr6jCjJfhpAEvm9/JsEkDNLQBsEBxBoBTaWco+zVE0Ii786Dq+TxVJ1jYPkkKNFHrvUzChn87U/5AYBoOku7nl7QoI4UHDNO7lArs/qpqOJNxZhmFjCimf/8lMTcZAAA0N4s3WL3NLlFv6DZHU1vUHehh620LNP71UY1t/v3827jMTFH5J+l+YefsopnN9i4l7/FMx5UlmDRQDTikd1lYQ5sjZWtoA5AvijMAipewk6M0QcOr6vncaSdYXHdQXIHGphvMRZHGNnT4uSrA+JmGC4VsI9n9XP3LvMvTFGZCuBxnlpJpyPD+PbZ6oCUdYAAAoIGlbWqzeQPWqzRNbdWAF/UuM25q8y4PK9AE7R9XoIm6Gz7s+HlnJu/r5JWZvK+piO2S/FzDftYZjoCuej5POs6szHfN+Bk0tDFtAMgXxRkAzhkFjQhJR5uVMmhYdYLFXUymKdB4lycNG/VtTAKHd1sXoSMNm3ChiG3j7jiKW+89rs1/17qU3V9eJQkZNpibDAAAWp6D53X6FdbU5hXWTGRVoAlaH/UGvhR/t0bYdX6SxrY8CjVpmJyHaVHGtJlNIevTFGYUs8yj6vncqrHSHdd3zQxDQxtQehRnABTD37FRxqCR+3gzr6wLNN7ltmHDZjv/9nkED9PXSlqUSVLkCgoY3s/j/ns66P7yLs9Ykrtm/JibDAAAEK1pm9q8ywPZjIU2LdB4tw3LUmka2+rbmTa3lS0zxRVlJLOfT1zhxtvMlqQwk2DSgFfV83mzNbT50NAGlAPFGQClVIqgkUbV87mTTjDTh116P4+74PUvj9rGpkhjchu9PwwkCSBJjhEWLlwEDO82Yd13SQozilkWwmScWYlCBnOTAQBAq8nzeZ2lbWqrxhw7dupA2DIXBZqwpqw0jW3ebW0b3PLKTN5z9B8nblvTn5t3X/mWx/13VMh6h+PMvGwzk6HcG9p4RidQShRnAGSioYNGt+fz0nSCeZfHPezS/3lYwcAmbPhfyy9p4Ag7jsmHqaBzSXrLvmkw875OXdLCTFA3YMim1eBN1klTmDGUecjwY24yAABoJb43XBu+qc3LqqnNK27qgJSsQBO1PE1jW9rclEVmCjoXk0a2tD8r/3KTwkzcszkdTBqIm4oRpdvzeUkb2gCUB8UZAOVRlqCRJ6NOMJcFGv/2phfQJh1h3m3DAodtsSaJsNezuW1fAdsmDRhS+sJMiDzHmXV7Pi8yZDDSDAAANDma2gKWJ546EFagCcpVSTJTVHawaW7z7pNHZvK/XtJGNv/2tncZeZdnVJgJk/WkgQhFN7Qx0gwoD4ozAHKVZuRQ5kGjqDnK1gWaoPUuCzRBF9P17UyLNN7to0JH2gASdxzTrrG4rq+on4l3u6DlLgozjru/CBkAAAClkOhaxaJZpRRNbWlyU6ykBRrvcm+xxqSI4M8EUYUNm+a2qAY315kpSW6KK+K4bmbzbx8UnA0KMzaTBlyioQ2AAYozADJjFDT8XWBFBg1ThRVogpa5LNCYdIQFHcO2UBN3W7zNR5LXiTr3JKErqmMuKPA5KMyYqHo+NynMmMoxZAxDyAAAABjG/8Zs6Zra0qh6Pk88dcD/ed6NbUHrw8TlmSwyU5JGtqyb2fzbJ3w2Z1ENbRFFzDI2tAEoDsUZAKVmMzc1t6CRRNXzeeKRUy6DRlgBIipABF2QJwkc/v1tw0HS/cLOM+x7ido+aDt/wEg6gi5F95dVl6FPCUOGzd9/QgYAAGhUiZ7XSVObT94FGpvskKS5zb9/UZnJpuBk2vTmqjDjeAS0bWHGUCM0tDFtACgOxRkAmXIdNPxv3uYaNNKMNzORS9BwHTbCtk9z2733mEnDhJ9JuFDANqY/lywCRkHdX4YKDxkGs9cJGQAAoFU0XFNbwxZobBrbTLNH0tyUR2byvo5/u7CvTXNmxrkpaw3Y0AagfCjOACi90gSNKKXvBMsqbJjeTZO2WGMr6nVNikk2P4sGLMxEKShk+BEyAABAq4ttaotgM/aokKa2tHIv0Ng0tpkUaVwXapKKy0xpClAmUwakUucmk3FmEbx/Z6Leg8iroY1ndALlQ3EGQCHKEDS8F0eRQaPb83mSOcqFFmiyDhthy8I6tvxFE1fhI+6YpmPYbANGQYWZtByGjCg2IYO5yQAAoJUleoPU90ZsVHOL/9qr4ZvajGU5eSBNkca7T1ShJsvMZHs+YV/Hfe9RmalBCjNRuj2fRzS0edHQBqCO4gyAzJU1aEQyvKgyGm9mqur53FmBRjK7yJXShY2wZSa31weFBNsPv6jXTVqUyaPzy1DV83lW48y6PZ8XfdcMc5MBAACsNURTW+ZTByT3kwfCmrck88KGbW4qKjOlLcqYNrOVoDATxeT5nD5Z3DVDQxvQfCjOAGhIpQkafmk6wfxSF2iSdoMlKdKYBA7v/mnnIdse0/Q80waMtIWZEo4z83ERMvysQoYPIQMAADQL18/r9LMZj5RbU1uUqufz3As0UdtJdpkhaBvvsjxyU9LMpJBlQd9f2HqbZrYMCjOmkjS0eRV91wwNbUBDojgDIBd5B42kd89sUKCJurjKY7xZJNML1SQXvpLZxbZt4Ai7HT8oKNh8BIl63ahONf92/m3q4n5+BRZmoiQZZ5ZByIh6YyA2ZBiMQCRkAACAVuW/lopqgvFfo5WuqS1K1fO5cYEmbJ3NXRxpGtu82+Sdm4KYZKag7ydpM5v/67Cfs//rlCOgXY4za9CGNgDlRXEGQMOwCRp+pkFjGNM5yknGm1U9n1t3gvlXJinQ+NfZho2gbbzbBYUKk9vrbZkcM258QNC2YdvE/XxzLsz4JQkZXt3hqwgZAAAA+Svj3TPOm9pymzqQdjR00NemeaLI3FREZkoyZcD/tUFhxq/q+TyrcWZ+nr8DUcVLVw1tw8Q0tAVNG6ChDSgHijMAChUbNCxkEjSiuO4Ey7VAYxs2TIs0UYEjLFAEBQWbD5Nj2pyrf7s6m4ChgG3rUhZm/Lzrwgozfk0cMgAAABqZizdM87h7Zpi0TW1+SaYOOB1xJsVPHohqbKvvb1Lo8G4bVYRxnZniXtfme/JK08yWoDCTdtJAknFm3Z7PDZ/P6ZemoS3qmbwAGgvFGQC5SRQ0fG/ONkzQiBIWNPy86xIXaFyFDcnswry+XZJb55MyCR823WpB28UVsOJ+lnUJCjN+Sbq/Wihk0AEGAAAQLZemtm7P50ma2vyqns9Tj4WWkj27M+g4cY1t9WOYNox5ty9jbgrats5VM5vktDDjl2TSQIJxZn40tAEIQnEGQFNpmKBhequ+d12iAo1/XVRRIWi9TUdYVOAwLda47AIzCTxB23uZBIwkt+QHbRuyuOr5PG33l1/GIcPPKmT4ETIAAECLcPG8Tpuml0ya2vzSjjfzSz11QEo3ecCksc22uS0sM3n3zTM3xeU8L5tmNv96x4UZP++6DCcN+NHQBsAExRkAhWvqoJG2E8zPWYHGJmwEvXCS2/FNijVpxB3fNlykCRj+rw0LM35Vz+dJur/8cg4Z/r9/XrEhw2J2eh0hAwAAIJj/2iv3pja/tOPN/LzrMinQSPaNbZL9s2b8mcZ1bjI5vs2EhCTNbBkXZvKaNOBTyF0zABoexRkAucrizdNSBw2/tJ1gkqMCTdDXacJGkrnJQcEgzUcQ2xnKYUUZ286vhIUZ09vyvetKGjL8CBkAAADpNFVTm5+LpjbvulQFmrjJAyaNbS6a2/z7ZpWZos4jbF/XzWxSosKMX5aTBvzKcNdMgpFmNLQB5UJxBkApNGTQ6PZ8bho0/KI6wZwVaGy7wZKEDSn/Z80keR3bcGH784kKGEHrQ1ZVwzeLLMyUKGRE3TXj5yJkAAAANBOa2jyf+695q57PMynQ+NebNmu5bG7LMzOZNrJJZt9n0DHjpgwkLMxUPZ9H5eckkwb8usNXcdcMgKQozgDIXUMHjSzGm2VeoJHiu8HShA2bwOE9vu0s5KT7Rp1PWKEprigTtE1GhZkk3V9+UePMuj2f+36/s7prxv/31QU6wAAAQCuwbWqzkaapbYPrRpumtqjxZkmmDkgZFWiCvjbJCPVlUc1tpg1utoUbV8+ekdJNGMhgyoCUbAS0X9SkAe/vardvned33J+ZXN0140dDG9CcKM4AKK8yBg2/7ogXNe0E86tarCskbCQNHKbzkpM+yDLsdYPYfB+mhSqvqGJYxK5V37qk3V8ZjDPjrhkAAIB8uWg+8V9rZdXUNoyL8WZ+pk1tfs4zU5K8ELVccp+ZTHJTXFYLa8YzmTCggG0cNbNJyUdA5zxpIM1dMzS0Aa2B4gyA0rB90zXLoOFsjnLSTrCq5/Ooi00pp7BhGzhMQoerB1yaHjPq3NIEjIS35PtXV33rXHR/+eUQMvwIGQAAAI2jFE1thUwdiLumT9PYZjqBoM72eTEm8s5MJSjM+OU8zsyPu2YABKE4A6AQRbyZahM0/KyChotOMCnnAk3cxbMU3H0VFSriQkddVg+1jDuHsHU5BIy41a66vwoIGTZ3zQAAAMBe2ud1lr6pzc9m6kDiAk3QBq4a28K29S6PPTkVk5mizj3s+4zb11Fhxs+/LumkAe6aAZAjijMASqVhg0Z3xIFtgkbSWcpSym6woAMkCRsmhRqT4JGUyetEhYscAkbQ6qrn86y6v7who9u3rqC7ZugAAwAAiEZTW8DXUU1tfv51Tgo0po1taZrbssxNNpnJ9HszbWZLOP5ZSj4C2s9m0kC35/MUz+dMc9cMgOZGcQZAYcoYNKLeeB528ZXFeDPJbpayf31mYSOsSJOkUOPfxiaEuNjPzyY4xQWMoG1iVlc9n8f99zXt/vKz6P5KEzKKuGuGDjAAANCKWqKprZCpA0EbJG1sk9JlJv92NoWbrDKTaVEmw2Y2yW4EdAaTBvwi/w742N41Q0Mb0NwozgAonSKDhl+qOcpJO8H8qr6vMy/QhB3ENnB497Ht+koSQKL2D2JblElS2ArYxKvq+dymMOMXFTKiur98/L/fhAwAAIDitWxTm1/c1IGq53PnmUlK3tgmmWemNJkn7TGCJBnT5uW4mU1KPgLaL8U4s0a7a4aGNqC8KM4AaDlpgsYwRXSCSTl2gyUp0piGDtuCTRTT45rc8RO03M8yYARtUvV8bluYcdX9FRWUfQgZAAAAJdeKTW2ZjYWub5BVY1tYZvLu6zo32Rw36bNzvEx/VhGrq76v04yATjrOzIeGNgAuUZwBUCjjN1cdBw2b8UuxQaOITjAp57BhEzgks9DhPUbajyimRSOT5Qk6v4I2qUZs61+XVfeXDyEDAACgsWR9jVTaprZcpw4EbeSisU0yz0ze4+SRmVwUZRxPGZDcjoB21NAW93xOGtoAxKE4A6CU8n4zNi5o2LxRnWknWNX3dW5ho76dbeCQ7EKHC3HBQooOKY4Dhk3I8K+LK8yk6f5q8JABAADQShK/wdosTW2FTh0I2yhJZjIp1BSRm8LYZsAEd8sEbVL1fZ3VCGjJ2TizOFk3tAFoTBRnABSuEYKGX2zQ6I7Y2bYTrNCw4TJwSMNDR9rwYXusqPPLIWBUfcuSPsgy6Ouo7i+/JggZdIABAAAYPK/TsVyb2vKeOpB48oBNZjJ9sawzk+lYtbB1fgmb2RSwSdX3dZoR0H5xhZnuiH19ytbQxrQBoDFRnAFQWmUPGlZvaNt2ghVWoDEtTkRt711n9MIKDwxxHybizqWAgCHZzUv2s+3+6g4/VNw4M0IGAABAObhqTil1U1uRUwekHBvbvOsaOTNJqZrZsi7MOJw0YJP//ZmJhjYAYSjOACiFsgSNVG882waNIgo0zsOGyYFtZh27YPJ6UeujgpXhy3tVA7YpsvsrapxEjMxDhiFCBgAAwHqxTW0Jr7nCpG5qs7kezSMz+bdJ3dgWVaQxySdlyUyKWe+wmU3KvjDjcNJAXENb3mhoAxoXxRkApZZ30PBLHTS6fQeM6gRLour72mk3WNIiTdwL+INA0gCS5Di24SJun4BNvaoB22TZ/eXX7fu67CHD9/ebkAEAAJAN101tqa4bu31f593UFrSNcTxJk5tMjltkZkqSmwxPya/q+zrrwozFpIE4cZMGaGgDEIXiDICmk3XQsJqjHCdt0JAyLNDUN7QNG979bAJEWHBIG0yShov6vhYv4VX1fd2l/EOGBdtxZqlDRkKEDAAA0MqMr4Vybmrzc97UVlSBxio3BXHR3Ba2T56ZScqsma3qW1Z0YcayoS13NLQBTYXiDIDSaJSg4Vd40JAyvF3fu3FUkca0UFOmW/TTdrN5NvWr+r42CX8Fh4zcxcxNJmQAAACYc3HtlHVTW6rxZiayKNBIDhrbpMbPTA6KMllMGXAt5aSBsjS0AWgcFGcAlF5TBI1u3wkVUaAJ2sb62j9N4PAfI234SNIdFldMsjyfJinMEDIAAAAah6umttJdo3X7vrbNTEFMCjROJw/UN47LTHG5Ke2dMHHHipNzZpKSFWYKHGeWi4QNbUwbABoHxRkApdK0QcNEUQUaKcF1vovAEXS8LG7Rt7m7x/J0varKpjDTjGJCRhhCBgAAQDia2gK+luILNJL55AFnRRqp8XJTggKRq8JMEBraADQBijMAGkJLBI0geRdonBZppA0v8G1CR1q2HWkWkgaMoO1MCjMtGDIYaQYAABAtcdNKCZraGqZAE7ad88wkFZ+ZHDayhe1SVfI7l0o+aSATNLQBLYHiDIDmkUPQyL0TLEjSAo1JN5iU8G5529vjXYWPoOPZjAiwELZLNWBZksJMEELGOoQMAACAeMOaXAyvtbyybmpzogwFmsRFGptCTZa5KUqKUWouM5OUvjCTAxraACRFcQZA6YS9CVuGoOFEt+/rJLOUkxRopIzDhnfHpLfP235keV6+Xf2qSj4SIYt5yRkgZAAAADShDJraGmLqgJSuQBO0bdLHwJQ+NyXguplNsi/MBCl40kAiNLQBLYPiDIDm4iBoZD7eLEgZCjRB26a4Nnfz8Mo0HLx22K7VgGVJb8mX3HR/dfu+boSQAQAAgMQSP6/TgG1TmxMupg64LNBk3tjmP0DRmSmDokw1YLnplIEkhRnHkwZcoKENQBSKMwBKqcigkUTqTrAgeRdowraVHOWELIs1aR58GXE4v6rcdn5JhXR/ueAkZPj+/oaFDDrAAAAAzBm9cVvGprYg3b6vsyzQVAO2M73OlxzFHMe5Jvb4Dg4VpBqwzLTgZZKZgtgWZgzQ0AYgaxRnADSUsgYNI1ncqh+0zKZAk9tdNEGCgkfSD8enFKQasjzrwkzK7q8ghAwAAIDm0PRNbUFMCjR+JgUaKf3kAan5c1NcZqoGLDf9GZoWZlyMgM7h+Zw0tAGIQ3EGQGmVPWiUphMsaFlY0KgGLE9yF00Rd9xnKWnAKLowY6CIkGGEuckAAACZaaSmtsKmDkjZNbZJzZeb4r6fashyl5kpaFkGI6BN2D6f0wgNbUDLoTgDoOFkFTTi5ij7L75MZNYJZlqgySNsNHLgSFKUkex+VlkWZrp9X5ckZAz7+2YQMpibDAAAkEyjN7UZ6fZ9HXRtnHWBRrJvbJOaOzNJ+TWzBS0LysoFjDMz4eKumTA0tAGNi+IMgObgIGiYsJ2jHMikE8xVgUZyFzbC9pEaK3CYFJWqIcttAobkrjATpNv3dQbjzEwQMgAAAMqpzE1ticabdfu+zqNAUw1YnjYzNVpuClOVmyJW3oUZy3FmJvK6a4aGNqD5UJwBUGphb84WFTRMOBlvFiTPAk2zBQ7T86rKXcBwWZhJ0P1FyAAAAMAwJWlqC+Jk6kCQNJlJsrsrpL590D5eZcxMknkjWzVkne1EhqSFmSAOnjMThIY2AFmiOAOgeTR60Oj2fW3SCSZl0w0mhYeNuP3qii7U2Lx+VdkHjCwLM00YMgAAAGCuEZraGmbqQF6NbdKGmaXsuamq6O/HduRbmsKMgxHQQZI8n5OGNgBpUJwBUHrNEjRKW6CRkoWN+n5h+3r5Q0cWwcP2+FUlK8rU9w3iMmBIibq/gjRiyKADDAAAICMZNbVlMt4sSHfAMpcFGimbxrawfb3yzkxpG9mkZM1seRdmAiSZNBD3fM4gNLQBiEJxBkBzKXHQCJT0Vv2susGqIedgWqQJ2z9IUDBI82HK5DzjvtcgeRRmun1fB/z+EDIAAABaU9FNbSYymzoQJIsCjZQ8M9X3Dds/SNGZKepcXTWzSelyk193wLKcJg3Q0AbAFsUZAA2hjEHDyXgzKdmt+lL+YUOyCxxRx8lbVeZFGduxBK47vyRn3V9BCBkAAAAosqnNyXgzyf3UAdvJA9WQdWSmYKbjnxWyzNEI6CAuJg0EoaENQByKMwCaj8HFjP+iKEnQCJJovFmQ7oBleRdoqiHrJLPA4T1O1LGyUJV9uHAVMKR8CjMG3V+EDAAAgNbissklq6a2IEbjzbKeOhC2PKwZS2rszGT72ibTFILYNLOFLXc4AtrFpIEg/r8LRs+3paENaHkUZwA0jFQXHwYXPf6Lp6yCRiDTW/Wz6gbLukjjPZb3w5UkxzXtaAtTosJMEEIGAAAA6pJMHAiSVVNbkEKmDkQtT5uZbAs1cce1lfTYaZ5DapOZwpZnOAI6SFbP56ShDUAQijMAGl5g0Ehw90wQ/0VWUNAwGW+WuBOsO+Ck8g4bUvq7T6KO6+LDRtpZ0FGdXy5nJUvGhRmT7i+TcWaEDAAAgBaSMDPl2dSWy9QB28kDSYs0UuNlpqRFGanYwkyApJMGXDyfMxANbQBEcQZAK2nETjAp/wJNmrAhJS/UZMn0nKoqJmBIzm7LD5LVOLNAhAwAAIBCuXxep1RsU1sg15lJsps8IMU3tlUj1kvNn5lspww00aSBIDS0AQhDcQZAQ0kVNAIUOUdZMuwEs2FboMkqbNQVFTq6lGx8QBjbgBG23KYw0x1xPmuZdH8FIWQAAABAUvC1Wis2tUluG9ukZM1tReamOFWly0w2dyg5LswkmTQQhIY2AK5RnAHQvEoYNDLvBJPsbtdXxHLTsFGN2MbLHzpcBY80x63KfcBQyPK0hZmE3V9lCxkAAADIXiM1tfmvV42b2spQoJGiM5NkP1Ysq2JN0txUVfz5R/0MbH+mGRdmgpiMMwtCQxuAtCjOAGg4jRw0gjgPGpL7sOE6cNQFBQTbD1tVpSvKSPZ3I2VQmGnUkEEHGAAAQMnk3NTmdOpA3gWapI1tkpvnvxSRm6K4bGaTcinMJJ00QEMbgCxQnAHQ3Fo1aEhuw4aUbeDIWlXm5xVXlHERMCRCBgAAADJVlqa2TKcOhOkOWBZWoHE1eUAyy0xSOTOTZJ6bkjazRa1LU5gJYJqZkk4aoKENgAsUZwA0pGYMGqUo0MStk+wDR9Vg2yzYvn5eAUPKpTBDyAAAAICxDJvagmQ+dUAqf2ObtGFmqRps75rt65tkpiRTBkwLM2EMRkAHcTlpgIY2AElQnAHQ/BwGDZfjzTIp0Njerp80bEjmgUMaftFfNdzPVNLjT1E2AUMqXWGGkAEAANB6ytzUlvnUgTBlbGyrqyrbYk3S45t8H1k3s0kNNWmAhjYAJijOAGhYeQQNl+PNTDrBQpkWaCS72/Wl+Ito14Gjrurww5ZpuEgaxDIozJgyLcwQMgAAALCBnJvaguTS1CYlK9C4aGyzzU1Vxx82TM85j2Y2qaEmDQSioQ1AAIozAFpDwqARJOl4syDGQSNMd8hyl2HDZL2UPHDkxeb84r7XqIARdkt+innJUvLuryCEDAAAgNbRDE1tpSjQSOkzk1TuzCTZZaY8mtmk0kwaCEJDG4A0KM4AaGhlCRpBcgkaUn5hw2R9XVkKNbbnkaZI5SJgSIQMAAAAFKeAprak16tSxgWaLBvbpA2zSiPlJpPvrySFmSBpJg0kfj4nDW0AQlCcAdA6ShI0GqJAYxI2khRqsg4dSV8rbcBqoMIMIQMAAAB1WTe1ZT3eLJSLAo2UvrHNNDNJ+RZqkuQm06KM7ZQBKXVhJkzSSQNBEj+fMwgNbQDWojgDoGkVETSSjjeTCirQpAkb9W1sAoc0PAgkCSAujmFaZIorypSgMGPKtDBDyAAAAGguVtdehk1tRndVB8ilqU2yL9Bk0dhmuo1fWN6xyTwujiG5+R6jfoYOCjN5NLSZoKENgC2KMwAaXpmCRhDToBEm0wKNlD5s2GwXJSo8uLz7xtX35CpgSKluy5fcjjMLQsgAAABoXkmb2oKkaWorvEAjpWtsMy1g5JWb0rA51zTNbA1SmKGhDUBWKM4AaGrNEDSknAo0Los0aQOHa7bhIm67jAszYQgZAAAASKrsTW02nBRoukMOnrSxTTLPQWXMTbbnlFUzW3fA8pSFmTRMMxMNbQCSoDgDoCnkFTSKLNAESlKgSRs2kgSOIkKH63Ahue38khqi+4uQAQAA0PwaraktjFWBRiq2sc2/fSNkJu8+UXJoZsti0kCaMeVGaGgD4ENxBkDTC7zYCSrQlEjqTjAp2e36UnzYkJKFhyyLNf5juw4XUrKAIWVWmCkVi5ABAACA4rRCU1suBRopn8zkMjelPXbaZjYp88IMDW0AGg3FGQBNI3VnSYmChlRwgUaKDxtSusAQFA6SfqR5/ThJA4aUaWGmUUMGHWAAAADlVPbmmsIKNK4a24rOTWlfO4pJZipxYca5oMzEXTMAAlCcAdASjO+ecdzdUpoCTXfICboIG1I5ZyYHsTnPtAGjO2RdwbflBxVmghgXZrhrBgAAoKE06t0zYTIv0EhuGtukxshMkt15xn3vGTezSeaFmTDOG9oAwBDFGQBNJYsOkzRBw0ams5Sl9GHDNnCUJXTYno/J95okYEi5hwxTQb/PWaADDAAAoNxcN7UVOXVAyqBA46KxTSpfZpLc56a4n1d3yHIHmSlMYZMGaGgDEILiDICWkSZo5NUJlumt+lK6sCHZBQ6puNCR5HVNizIlLMwQMgAAAGAjr6a2tAor0HRHnFSWmSnP3JT0dV00s3WHrHNUmEkzaaBoNLQBrYXiDICmU8aLmdLcqi9FBw0pPmxI9oFDyu7hlmmPa/q9xP1cuiPWlaD7q2hl/HsJAACA4crU1BYm8wKNlH9jW10WucnFMRugmU1K/5wZGtoA5IniDICWUlTQCJN5J1iabrCsijReRT3YUrIryiTt/JIIGQAAACiloptnspg6EKXwxrYic1NSNudegma2vJ4zk5Wi/04CyB/FGQBNKc+LmrSdYJkWaKT0YSOvwJEH2/NMGzAKKMwQMgAAAJBWHk1tYbIaCy1lWKAxyUxSY2QmyT4z5ViYCZN2BLQNGtoAuEJxBkDLcR00bBRyq76ULmxI5mFDKl+hJsn5ZBQwpPy7v2wQMgAAAFpLGZvawrgYCy2lKNB0x7xokiJNWTKTlKyRzSQzdUesT1CYSTsCOkyqSQMO0NAGtCaKMwCaVlYXN2nHm7m4VT+TAk13+GpJdmGjrqjQkfR1MwwYkrvCjE33l/NxZpYIGQAAAI0pr6a2vKcOSAlHQ0vmjW1Jm9vyzE1JX9f0++uOWJdgyoCU76QBKzS0AUiI4gyAlpQmaNhw0QmWW4FGyiZs1Pkv/l0FDxfHdREwpMIKM2EyGWdGyAAAAGgqRTe1hcly6oCUcWObVK7clHdm6o5YnyAzSflPGqChDUAeKM4AaGpOLnJSzlFO2wkmZVCgcRk2kgSOuqCQYPuRhsuAUWBhJm33FyEDAAAAQbJoastqvFmuBRrJLDNJ6TOT1BiZScqkmU0qx6QBKzS0ATBEcQZAyzIOGhay6gSLkqhAI5mFjW7Dk3AROPJiU1TqVqqAIdkXZmxl0v0VhpABAADQlPJsagtTRFOb5KBA0x29yToumtvyYnuu3Sp1YSZMJg1tlu8p0NAGtDaKMwCanvXFTk5zlCU3z5+Rogs0qcKGZB42pPIGjiTn1W2wTUzASFKYyar7ywohAwAAAGulbWrLarxZrgUal0UaqZyZSUqWmbpjtkk4ZUDKdtJAmNQNbQFoaAMQhuIMgJaW9iIp7XizMC4LNFIBYUMqvlCT9PW7VUjAkLLt/iJkAAAAIEpWTW1hbMabFV6gkcwb27rjN1tnmorNTUlfv1upm9mkYgszqceZ0dAGwAGKMwBaQhmCRhhXQUMqYdio81/0uw4eLo7frcIChpR/91coQgYAAAB8smhqC2NboAmTtECTurFNcpebXMszMyWcMiAVU5gJk/b5nDS0AYhCcQZAy7O6WMpxvJlUUIEmy7DhFVRQSfqRRrecBAyp+MKMTfcXIQMAAAB+ZWlqczF1QEpWoJEcNbZJ6XOTy8yUJjd1y/z7SJGZJPvCjK3MJg1YPp+ThjYAEsUZAC3ESdAI4WK8WZ4FGudho9ts89LolvOAUcbCTBhCBgAAAGxl1dSW5dQBKeMCTV5FmqJ0yy4zFVCYKcWkAQBIiOIMACh90AjjohNMcl+gkRyHDakxAke37M4xw4AhuXnGjGTX/RWKkAEAAADl39QWxlVTmxRdoEk1eUBKVqTpNtu8EN1yXpSR4qcM5FGYoaENQNlQnAHQUqIuggIvnhwEDRedYFEyLdBIyYs09Y+idSuTooyUTWEmiovuL0IGAAAAkipLU5vLAo3kYPKAZJeZpMbPTJJxZnI9ZUByV5jJ+/mcAOBFcQYAkggJGmXrBJPMCjSZFGnqupVv6OhWunDhIGBIxYeMIKG/n4QMAAAAeJS9qU0qJjfFcpGZuu13z/31Cm5my7owk+XzOWloA+BFcQZAy3EWNBqsE8xJ2JCSBw5peAiofxR9LIvvySRglDVkBCJkAAAAwIWg3JRhU1uUhm1sq+sO+CjDsUrczGYrq0kDAGCD4gwAmMiwEyyvAo3kMGxI6QOHV3eCDxcsizJpAoZUwu6vMIQMAACAlmfd1BamgKY2KbsCjZRTY5tfd8IPFxw2s0nZFGZcTBoIRUMbgIxQnAHQkrIOGlZvhEcookAjWYQNyW3gyIvlOacNGJLbwowtq3FmIQgZAAAAiFWipjYp+8kD1kWaRspNluds2sxWhsJMEQ1tZCYAQSjOAIApB3cT2AaNKHkUaBIVacoaOBKcn+nPIIvCTJRMx5lx1wwAAADWKqqprQwFGimDxjap3Lkp4bkVmZkyL8w4amgDgCAUZwC0rFYIGmm7waQERRqpPIEjRbgwDRgNGzKChBRmuGsGAAAAxhw8szOJogs01plJKkduSnEORTezZTrKTHLW0EZmAhCG4gwA2LAMGkUXaCQ3YUNyFDiyDh0pX8vmezQJGKUtzDDODAAAAIacNbWFcJWZpGwLNJkWaaT8cpOD1ylzM1uUohraACAMxRkALS1R0Mh47FNRBZrMizR1/jBgGwzS7h/AtiiTJmBIBXd/hWGcGQAAAFwpqKlNyq5AI+XQ2OaVJvdkkJkkt81sUjaFGRraADQSijMAWl7ZxptFybJAI5mHDclR4PCLChEZdJDZfg9pA4ZUgu4vQgYAAAAsNVJTm5TdaGgp58a2MDlmJimbZraGLMzQ0AbAMYozAJBEA3aCSeYFmsKLNBlLUpQpsjBDyAAAAEDRytjUlvTu8qIa2xopNyU55zJmpqzR0AYgDYozAKDs5yhHKaJA4zpsSOUPHEnPzzRglKUwYy3BzGRCBgAAAIxk3NQmJRsLLRXT2CaVv7ktaWbKujATJSo709AGoMwozgBAUo6CRpSsCjRSdmFDKk+hJs15uAoYUr7dXy7GmQEAAABB8hhvltdYaMlNgUZKV6QpOjelOQ+b7zttYSbJszmzLszQ0AYgrUqtxr8WUXp7ezVhwgT19PRo/PjxRZ8OgIzdWKmErjtmRsiKmSHLDwteXDs4ePkTk3YMfe2/aJfQdc9r+9B1kvSCtolc/6K2jVzv9bq2MN42zFsvpT9GGBehxiZUZRUwJEIGgGJxDQwb/L4AraXIzCSF56YsM5NknptcZCYpu9zkqhDkqigjZZOboop5LnJT3IQNchPQ/FxcA1OciUHQAFpLHkFDKl+BRsq/SBPEJIBk1VXmsigjlagwIxEyAFjjGhg2+H0BWk8rNrVJjZOZpNbOTVkXZiQa2gBQnMkFQQNoPWXtBJPKFTak7AJHnmzHD5SxMCMRMgC4xTUwbPD7ArSeqMwkheSmJmlqk1ovN2WRmaQSNbQ5zEwSuQloFS6ugXnmDABYsJ6jnOD5M0kf9J52nrJk9mB7r6TPpCkD23M3/dk0a2EGAAAAqEv05rNlZooSdW2c9Lmd0tC1vKvnd3o1am5Kct6mmak0hRnHKMwAsNFUxZn+/n5dffXVOuOMM7Tjjjtq7NixGjdunPbZZx/96Ec/0sDAQNGnCKAB5HUxleTiMO4B8S4KNFLysFH2wJH0PPPo/EqKkAEAsEFmApCHopva0hRopGwa26TGKNKkyUyumtmyaGgLxaQBAAVqqrFms2fPXhcwDj/8cG2//fbq6enRjTfeqHnz5unoo4/WDTfcoErM7bde3KIPtKZEt+lLpbhVX4q/XV/K7pZ9rzLcvp8m/LgqykjxIbDI58xIhAwAG+IauHmRmQC45HQktFSqsdASmcmETYEqbTNb2ScNkJuA1sIzZ3zmzp2r66+/XmeccYa6urrWLV++fLkOOeQQPfLII7ryyit18sknGx+ToAG0rkRBQyrF82ckt2FDShc46vIIHi460VwGDImQAaDxcA3cvMhMAFzL45mdUnG5qRkzk5RvbiplZpJoaAOQCs+c8dlss8302c9+doOQIUldXV06++yzJUn33JPT/BcADS/q4irRszlyvFVfcne7fl2S2/b9vLfIu7il3/XxbL/HUoaMBDO7oxAyAKC5kJkA5MnVeDMp+ThfF6OhbcZDly0zZXFMm++zlJlJ4vmcAEqho+gTyMuIESMkSR0dLfMtA8jYjXNCOsFuVfSt+gEq94R3gu3+znOhnWC76qnITrDt9NfYTrD6xbJpR1j9ItxFV5jkpmMrLdsAZRrOCgkZYQgZAIAYZCYASRxTq8WOhbZypyLvoAkSlZkkd7mplTOTlE1uSpOZErMszMShoQ1AUk1150yUiy++WJJ05JFHRm7X19en3t7eDT4AtK7EF1mOO8HS3kHj+i4ayU1XWNGSfA+mAaMRur/iEDIAoLWQmQBkwfrumQhJM5PkbvJA3tMHyiDJhIE8CjN5TRqgoQ1AVlqiOPPTn/5Us2bN0mGHHaajjjoqctsLL7xQEyZMWPexxRbFP5gNQHm5HG8WJ02BRsombEiNGTiSFmVcBAwp58JMBEIGAKCOzAQgjUZoapPcFGik5I1tjZSbkp5zHlMGpPJMGqChDUAalVqtfP+KnHPOOerr6zPe/gtf+IJmzAh+ytxNN92kE088UZtuuqkefPBBTZs2LfJYfX19G7x2b2+vtthiCx5uCbS4uNv0y/CgSyn+YZdS/AMv62wefOnn6vZ9l9IEIVcBQ0pemJESdn8RMgAkwAPey4/MBKBsnGcmqbDcZJqZpObKTY2QmaR8Jw1E5SYyE9DaXGSmUhZnxo4dq+XLlxtvf9ddd+mQQw4ZtvyWW27RiSeeqClTpuiee+7R1ltvbX0uBFMAdVFhIzRoSA1doJHShQ2p2MCRtjPNpiOOkAGgmXANXH5kJgBl1EwFGqk1clPLZyaJhjYAiTRtccaFm2++WSeddJImT56su+++W9tum+x/cgQNAHXNFjSkfMNGXZahw9WYANcBQyJkAGgsXAO3BjITgCyUpalNarwCjVdWucnlaDUKMwBaGcWZEPWQMWnSJN1zzz2ht++bIGgA8HIeNKTMCjSS+7AhuQ0cfiYBJMs5zbazo10EDImQAaB8uAZufmQmAFkpU1OblH+BRso2M0nxuSnrZ9uUrZlNynfSgERuAkBxJtCsWbN0wgknaOLEibr77ru1/fZ2/wP1I2gA8EocNKREnWBSOQs0UvaBI0+2RRmpMQszEiEDgBmugZsbmQlA1pqxQCOVr0iTtyya2aSCCjMSDW0AUqE44zN79mztvvvu6uvr06mnnhoYMrq7u3XmmWcaH5OgAcCv0YKGlF3YkBo7cGRVlJEyLsxIhAwAmeIauHmRmQDkIZOmNqkhCzRS62UmqXELMxLP5wRghuKMz913361DDz00cpuDDz5Yd999t/ExCRoAgpRpvJlUfIFGaqzAUXTAkOj+AlBuXAM3LzITgLyUbeqA5K5AIzV/bmrFzCSRmwCYoziTA4IGgCCN2Akm5RM2pHIGjqThQjIPGBIhA0Bz4BoYNvh9ARAm76kDUmM0ttWVLTc1TWaSaGgDkDkX18Btjs8JAFpC3EVZ5EVdxBvoUReQcRefcRevktlFcN12+qvVBbbXtnpx3UeR0p6Hzc9gVz1FYQYAAABwIWFmkvLNTWkyk1SO3OTiHJqhMAMAReDOmRh0gQEIU8TdM5KbTjApv7togmTRIeY60LgOGFKGhRmJ7i8ATnENDBv8vgCI0si5ySYzSW5zU1Z31bjMTbaFqTIXZshNAGwx1iwHBA0AURo5aEjFho0wUSEkj46yIgKGRMgAUC5cA8MGvy8AosRlJinheDMp1VhoqXFzU1zhJuvclEVmkijMAGgsFGdyQNAAEKesBRqpccNGEYoKGBIhA0D5cA0MG/y+AIiTWWaSSlugkZovNyUZ4VaKwoxEbgLgHM+cAYAGkMXzZySDi1OZFQYku2fRSOlnK5dJku+lGQozAAAAQF4ye2ZnjCIzk9Q8uSlpZqIwAwDRKM4AQEqZXqyVuEAjrb9Ib7TAkea8S1OYSYmQAQAAgDJplKa2VirSZJ2ZJPOffWIUZgCUGMUZAHAg006wBggbUmMEjjTn6LLzS6L7CwAAAK0l9TVoSQo0UrLGNqkxmtvSNrLZZCYmDQBodTxzJgbzkwGYSvWgSynVLGXJ7TzluiRzlf2KnLPsIvTYBq+yF2YkijMA4nENDBv8vgCwUdTzZySzzCTZ5aZGz0xS/rnJtBCWVWFGoqENQHouroE7HJ8TALSsY2o1owJNqFsVHzYiVO6JDxv1i2DTsLGrnkodNrwX+nmEDpddaKULGFKqmdsSIQMAAADlduOciAJNXGa6U5EFGpPMJA1d29tkJildkcafYbLOTUVlJslRM5uUagQ0hRkAZcGdMzHoAgNgq8hOMCmbbjDJTUdYmCThI8tRAFkEDMlByKD7C0BOuAaGDX5fANhqhKkDkn1mksqVm8qUmaQcCzNMGgCQA+6cAYAGlGUnmJRNN5jkpiMsTFlmLmcVMKTiu78AAACAssh86oDBHTSS+8kDkpvpA2HKkJtKn5mYNACggbQVfQIA0GxMLuYi30iPu5g0eAPf6KJW5g9h9LJ5yGOjSPI92fzs6P4CAAAA7MRe4+acm2zU80Uz5aak309ZMpPEpAEA5UNxBgAykPqiLsegIdmHDak5ijRZBwyJ7i8AAAAgSOqmNhMFFmjqGj03pclMjVSYAYAiUJwBgIKUqRNMSh82GiVwpD3fshVmCBkAAABoVJlPHZAKnzxQ10iZSUp3vrlnphhMGgBQVjxzBgAyYjJHOfL5M5KzZ9BI5s+hkZI9+FLacP5wlg/CtOUiBDkPGFIuhRlCBgAAAJpaXGaSnD67U7J/fqdXWTOTlD432RaunBVmmDQAoEFVajX+BYrS29urCRMmqKenR+PHjy/6dAA0IJMHXUYWaKT4sBETNOpMw0Zd0sDhV0TocNWVlknAkCjMACg1roFhg98XAGnlkpkko9zUSplJKnluclCYITcByIqLa2DunAGAZmDQCSbZdYNJ6TrCvPwX/FkEjyxGBBRWmInBKDMAAAA0k1ymDkjOJw9I6acP1AXlGde5qekyE4UZAA2OO2di0AUGwIUydYJJ9t1gkruOMFPeIJL3bOYkc6QJGQCaCdfAsMHvCwAXTDKTRG7y8hdv8sxNZCYArc7FNTDFmRgEDQCulK1AI5U/bOQt04AhETIANAyugWGD3xcArjjJTFJLFWjyliQzSeUrzEjkJgDpuLgGbnN8TgCAFGIvIk0edGgxMsuqsLDW7u88l/iCvKySfk8UZgAAAAB3TK5ljd54d5ib0mSmZspNaTJTnoUZU+QmAGVAcQYAcmJ68dcIBRqpOYo0ab6HMhZmAAAAgEbnrEBjwqJA06q5qZEyk0RDG4DG0lH0CQBAKzF50KUR04ddSka369s+9NLLe6HeCLfvpw1GzgOGxG35AAAAgKUb58SMODPJTNLQNbvhiLPKPckyk7Q+hzRCZpJKmJsozABoQjxzJgbzkwFkIddZylLmz6EJUqbQ4apTrYjCjETIAJA/roFhg98XAFkwbWpz8txOySozSW5yU5kyk+QmN1nfYZRjYUYiNwFwx8U1MMWZGAQNAFlwFjSkUoeNuiJCh8vRAZkEDInCDIDS4hoYNvh9AZCVVijQ1JGZQlCYAVBSLq6BGWsGAAUwHW8We6u+lMnt+lK6UWd+/ov+LIJHFnOcE82VLiBkAAAAAM0o17HQktVoaCnbzCS5z01ZPfuGwgwAJMOdMzHoAgOQpUa4g0Zy2xFm44lJOxby8MykD/skZABoFlwDwwa/LwCyVEhmkhomN9ULOA2Tm8hMAJqEi2vgNsfnBACwYHqRaHTRaXDxKmnoYtj0gnityj0pChYpNFTAIGQAAAAAzhWSmSTrzCQVk5t2f+e53HNT4u/TYWYCgGZAcQYAGkSrho28ZB4wJAozAAAAQAKNVKCRmjc3pcpMjgsz5CYAzYDiDAAUzOZikbDhVv37yHyMmcQzZgAAAIAUnBdoMpw8UNdsuSkRx5lJojADoHlQnAGAEnB+0WhboEkZNhotcKQ+Z9ufmePCDCEDAAAArchpgUbKpbFNaszMJDnKTaYozABoQRRnAKAkCg0aUqqwIZW/UOPs/Gx/ThRmAAAAgNxlVqBxUKQpa2aSHJ1jBs1sEoUZAM2H4gwAlEgpCjQpizRSeUKH0/NIEjAcP8iSkAEAAIBW53wstJR7Y5tUnszk/FwyaGaTGAENoDlVajXe6YnS29urCRMmqKenR+PHjy/6dAC0iBsrFaPtjplhcdCZlidxmOX2hmoHZ3NcKcNgk1HAkOj+AlBOXAPDBr8vAIpgmpkki9xkm5mkTHJTlplJyig3JSlYZVCYITcByIuLa+AOx+cEAMjRjXMsgsatsgsb9Ytrx2EjKAgkCR+5dJhlGDAkur8AAACApI6p1YwLNMa5yTYzSZnkprCsY5ubWi0zUZgB0Gi4cyYGXWAAipJJJ5iUrBtMyuxOmlJKOqaAkAGgSXANDBv8vgAoUqlyUytlJonCDICW5uIamGfOAEBJZTJLWUr+HBRHz6MptTTfIyEDAAAAyF2pclMrZCYp+fdJZgKADVCcAYASyzRoUKRZL21RhpABAAAANAQa21KgmQ0AnKI4AwAll1mBRkoeNqTmCBxpvwfLgEHIAAAAANyzvXamsc3CnSptMxsANDqKMwDQAGwLNLmFDanxAkfacCFlHjAozAAAAAB2Mi3QSG4yUyPmpjQsf2bkJgCthuIMADSIUocNqfyBw9W5ETAAAACAUip9ZpLKnZtcnVuCBkByE4BW1FH0CQAAzB1Tq+nGSsV4+xvnSMfMsHiB+gX0TKvTGs57MX9YymO5OAcXEgQxAgYAAACQr4bJTBK5aS1yE4BWRXEGAJqcddiQhi6oXYQNKfhiP4vgkWXnGQEDAAAAaBiZF2gkt0UaaXieyapYk1VuyiEzSeQmAM2F4gwANBjboCGVJGx4RQWCqBCS963/CccWUJgBAAAAipWkQCMV3Njm1SiZSaIwAwAJUZwBgAaUW4FGyrZIE6QMs5dTzJKmMAMAAACUQ1M0tgUpQ2aScmtmk8hNAJpTW9EnAABIJsnF6Y1zkl0IS3Lz8MuyS/DgyrokP1sCBgAAAJCtpLkpkRR5oqGkzE22yE0AmhXFGQBoYEkvUgkbPim/LwIGAAAAUF65FmgkclOApI2C5CYAzYziDAA0uNwLNNL6i/JGDxwOvgcCBgAAAFB+uU8ekJojM0mFNLNJ5CYAzY9nzgBAE0gyS1lK8dBLr7znK7vgICARMAAAAIDGkiY3kZmSITcBQDiKMwDQJOoXr4WEDWnDi/cyhg6HHWsEDAAAAKAxlaKxTSpnZpIKL8pI5CYArYPiDAA0mULDRl1ZQofjEQIEDAAAAKDxJc1MkqPGNqlpM5NEbgIAUxRnAKAJlSJs1Pkv9rMMHhnOcyZgAAAAAM0jbWaSHOamPDNT0Os5kur5PCI3AWg9FGcAoEmVKmx4hQUBmwCS40M1CRgAAABAc0qTmaQMGtvqGiwzSTSzAUASFGcAoImVNmwEyTk8xElblJEIGQAAAEDZuchMUk65qWSZSaKZDQDSoDgDAE2uocJGSVCYAQAAAFpH/dqd3GSOzAQA6VGcAYAWkLZAI7VG2CBgAAAAAK2L3BTPRWaSyE0AIFGcAYCW4aIbTGrOsEHAAAAAACC5KdBIzZebXGUmidwEAHUUZwCgxbgOG1JjBg6X4UIiYAAAAADNwlVmkhq/SENRBgCyQ3EGAFqQy7AhNVbgcF2UkQgZAAAAQLNxNXmgjsxEZgIAP4ozANCiXIcNqbx302QRLiQCBgAAANDssmpsk1ojM0nkJgAIQ3EGAFqc67BRV3ToyDJcSAQMAAAAoFVk0dgmDc8seecmMhMAFIviDAAgs7BRF3TR7zJ4ZB0qvAgYAAAAQGvKqrGtLutiDbkJAMqF4gwAYJ2sw4ZXnsHAFQIGAAAA0NqybmzzIjMBQHOjOAMA2ECeYaNREDAAAAAAeOXZ2NYIyEwAYI/iDAAgEEUaAgYAAACAcGSmIeQmAEimregTAACUW6teaLfq9w0AAADAzjG1Wkvmh1b9vgHAFe6cAQDEaqWOMMIFAAAAgCRaJTeRmQDADYozAABjzRo2CBcAAAAAXCE3AQBMUJwBAFjzXpQ3cuAgXAAAAADISrMUachNAJANijMAgFQaMXAQLgAAAADkpRGb28hMAJA9ijMAACfKHjgIFwAAAACKVubmNjITAOSL4gwAwLmyFGoIFwAAAADKyJ9VispNZCYAKA7FGQBApoIu9rMIHoQKAAAAAI0qr2INuQkAyoPiDAAgdwQCAAAAAAhHZgKA5tdW9AkAAAAAAAAAAAC0EoozAAAAAAAAAAAAOaI4AwAAAAAAAAAAkCOKMwAAAAAAAAAAADmiOAMAAAAAAAAAAJAjijMAAAAAAAAAAAA5ojgDAAAAAAAAAACQI4ozAAAAAAAAAAAAOaI4AwAAAAAAAAAAkCOKMwAAAAAAAAAAADmiOAMAAAAAAAAAAJAjijMAAAAAAAAAAAA5ojgDAAAAAAAAAACQI4ozAAAAAAAAAAAAOaI4AwAAAAAAAAAAkCOKMwAAAAAAAAAAADmiOAMAAAAAAAAAAJAjijMAAAAAAAAAAAA5ojgDAAAAAAAAAACQI4ozAAAAAAAAAAAAOaI4AwAAAAAAAAAAkCOKMwAAAAAAAAAAADnqKPoEyq5Wq0mSent7Cz4TAAAAIB/1a9/6tTAQhcwEAACAVuMiM1GcibFo0SJJ0hZbbFHwmQAAAAD5WrRokSZMmFD0aaDkyEwAAABoVWkyE8WZGJMmTZIkvfbaawTTDPT29mqLLbbQ66+/rvHjxxd9Ok2Hn2+2+Plmi59vtvj5Zoufb/b4GWerp6dHW2655bprYSAKmSl7/JuXLX6+2eLnmy1+vtni55stfr7Z4uebLReZieJMjLa2ocfyTJgwgV/iDI0fP56fb4b4+WaLn2+2+Plmi59vtvj5Zo+fcbbq18JAFDJTfvg3L1v8fLPFzzdb/Hyzxc83W/x8s8XPN1tpMhNpCwAAAAAAAAAAIEcUZwAAAAAAAAAAAHJEcSZGZ2enzj//fHV2dhZ9Kk2Jn2+2+Plmi59vtvj5Zoufb7b4+WaPn3G2+PnCBr8v2eNnnC1+vtni55stfr7Z4uebLX6+2eLnmy0XP99KrVarOTwnAAAAAAAAAAAARODOGQAAAAAAAAAAgBxRnAEAAAAAAAAAAMgRxRkAAAAAAAAAAIAcUZwBAAAAAAAAAADIEcWZlF566SWNHTtWlUpFn/nMZ4o+nYb3m9/8RieccIK22WYbjRs3TmPHjtVOO+2kL37xi5o7d27Rp9fQ+vv7dfXVV+uMM87QjjvuqLFjx2rcuHHaZ5999KMf/UgDAwNFn2LDe+KJJ3Teeedp5syZmjJliiqVig455JCiT6vhPPzwwzrqqKNUrVbV1dWlfffdV1deeWXRp9UULr30Un3605/WXnvtpc7OTlUqFV1yySVFn1bTmDt3rr73ve/pyCOP1JZbbqmRI0dq6tSpOumkk/SnP/2p6NNreKtWrdLZZ5+tgw46SJtuuqlGjRqlqVOnav/999cvf/lL9ff3F32KTeeiiy5SpVJRpVLRQw89VPTpoIGRmdwiM2WHzJQ9MpM75KbskJuyQ2bKFpmpGGlyU0dG59QSBgcHdeaZZxZ9Gk3l8ssv15w5c7Tvvvtq2rRpqtVqeuKJJ/T9739fl1xyie6//37ttNNORZ9mQ3rxxRf1oQ99SGPHjtXhhx+uY489Vj09Pbrxxhv12c9+VrfccotuuOEGVSqVok+1YV133XW68MILNXLkSG233XZauHBh0afUcO666y7NnDlTo0aN0qmnnqpx48bp6quv1imnnKLXX39d55xzTtGn2NC+9rWv6dVXX9XkyZM1bdo0vfrqq0WfUlP5wQ9+oIsuukjbbLONjjzySE2ZMkVz5szRddddp+uuu06XXXaZTjnllKJPs2EtW7ZMP/rRj7T33nvrgx/8oKZMmaLFixdr1qxZOuuss3T55Zdr1qxZamuj98iFp59+Wueff766urq0fPnyok8HDYzM5B6ZKTtkpuyRmdwgN2WL3JQdMlO2yEz5S52bakjsO9/5Tq2jo6P23e9+tyap9ulPf7roU2p4K1euDFz+85//vCap9qEPfSjnM2oeb7zxRu2HP/xhbdmyZRssX7ZsWW2vvfaqSapdeeWVBZ1dc3j66adrjz76aG316tW1N998syapdvDBBxd9Wg2jv7+/ts0229Q6Oztrjz/++LrlS5YsqW233Xa1kSNH1l555ZXiTrAJ3H777et+hhdeeGFNUu2Xv/xlsSfVRK6++ura3XffPWz5vffeWxsxYkRt4sSJtVWrVhVwZs1hYGCg1tfXN2x5f39/7ZBDDqlJqt10000FnFnzWb16de3d7353bZ999ql99KMfrUmqPfjgg0WfFhoUmck9MlN2yEzZIzOlR27KHrkpO2SmbJGZ8uUiN1EmS2j27Nn62te+pnPPPVe777570afTNEaNGhW4/OSTT5YkvfDCC3meTlPZbLPN9NnPflZdXV0bLO/q6tLZZ58tSbrnnnuKOLWmsdNOO+nd7363RowYUfSpNKQ777xTL774oj7ykY9s8O/qhAkTdN5552n16tX61a9+VdwJNoH3ve99mj59etGn0bROPPFEHXzwwcOWH3jggTr00EO1ePFiPfXUUwWcWXNoa2vTyJEjhy3v6OjQCSecIInrBFcuuOACPfPMM7r44ovV3t5e9OmggZGZskFmyg6ZKXtkpvTITdkjN2WHzJQtMlO+XOQmijMJDAwM6IwzztCMGTP0ta99rejTaQk333yzJGnnnXcu+EyaU/3CuKODSYcozt133y1JOvLII4etmzlzpiTCMBoX/85mZ3BwUL///e8lcZ3gwmOPPaYLLrhA559/vt71rncVfTpoYGSm/JGZssX/y1EW5CY0K/6dzQ6ZyT1XuYnf9gQuvPBCPfbYY3rooYcCq5FI78orr9Szzz6rFStW6JlnntGtt96qrbbaSt/85jeLPrWmdPHFF0sKvrgD8jJnzhxJ0owZM4atmzp1qsaOHbtuG6CRvPbaa7rjjjs0bdo07bLLLkWfTsNbvXq1vvWtb6lWq2nRokX6wx/+oNmzZ+vjH/+4Dj/88KJPr6H19fXp9NNP1+67766vfOUrRZ8OGhyZKXtkpnyRmVAW5CY0IzKTW2SmbLnMTRRnLD355JP65je/qS9/+cvac889iz6dpnXllVfq6quvXvf1Xnvtpcsvv1xbbbVVgWfVnH76059q1qxZOuyww3TUUUcVfTpoYT09PZKGbscPMn78+HXbAI2iv79fH/vYx9TX16eLLrqIEVEOrF69Wt/4xjfWfV2pVPSlL31JF154YYFn1Ry+/vWva86cOXr00Uf5XUUqZKZ8kJnyQ2ZCmZCb0GzITO6RmbLlMje1ZHHmnHPOUV9fn/H2X/jCFzRjxgytXr1aZ5xxhrbddludf/75GZ5hY0v68/W66qqrJElLlizR448/rn/+53/WnnvuqWuuuUaHHXaY0/NtNC5+vnU33XSTPv/5z2v69Om69NJLXZ1iQ3P58wXQ2gYHB3XmmWfq3nvv1d/+7d/qYx/7WNGn1BTGjh2rWq2mwcFBzZs3TzfeeKPOO+88Pfjgg7rllls0fvz4ok+xIT344IP6zne+o3/9139l1AEkkZmyRmbKFpkpW2QmAK6QmbJBZsqO69zUksWZn/zkJ1q+fLnx9h/60Ic0Y8YMXXjhhXrqqaf0xz/+UZ2dnRmeYWNL+vMNUq1Wdeihh+r3v/+9tt9+e51++ul6+eWXW/rhga5+vrfccos+9KEPaZNNNtGdd96padOmuTzNhuXy9xd26p1fYV1evb29mjhxYp6nBCQ2ODios846S5dddpk++tGP6sc//nHRp9R02tratPnmm+vv/u7vNHnyZH34wx/WBRdcoIsuuqjoU2s4a9as0RlnnKFdd91VX/3qV4s+HZQEmSlbZKZskZmyRWYqFrkJzYLMlD0yk1tZ5KaWLM4sW7Ys0X6PP/64BgcHte+++wau/8lPfqKf/OQnOu6443TdddelOMPGlvTnG2X8+PHad999dd111+mFF17Qjjvu6Pw1GoWLn+/NN9+sk046SZMnT9Zdd92lrbfe2sGZNYcsfn9hph7Y5syZM2wEyvz587Vs2TLtvffeRZwaYGVwcFAf//jH9X//93867bTTdMkll6itra3o02pq9fn/9Qfkws6yZcvWzaYPezbIfvvtJ0m69tprdfzxx+d1aigQmSlbZKZskZmyRWYqFrkJzYDMlD8yU3pZ5KaWLM4kdcQRR2jy5MnDlr/55pu65ZZbtMMOO2j//ffXHnvsUcDZNb958+ZJUkt3gLlQDxmTJk3SXXfdpW233bboUwIkSQcffLAuvPBC3XbbbTr11FM3WHfrrbeu2wYoM2/IOOWUU/TrX/+amck54Bohnc7OTn3iE58IXHfvvfdqzpw5OvbYYzVlyhR1d3fne3JoOGSmYvHvoRtkJpQZuQmNjsxUDK4R0sskN9WQ2l133VWTVPv0pz9d9Kk0tN7e3trs2bMD1/3iF7+oSarNmDEj57NqLrfcckuts7OzNnXq1NCfNdx48803a5JqBx98cNGn0jD6+/trW2+9da2zs7P2+OOPr1u+ZMmS2nbbbVcbOXJk7eWXXy7s/JrNhRdeWJNU++Uvf1n0qTSNgYGB2hlnnFGTVDv55JNr/f39RZ9SU3nmmWdqy5cvH7Z8+fLltfe///01SbULLriggDNrbvXf6QcffLDoU0GDIzO5QWbKHpkpP2SmZMhN+SI3uUVmyhaZqThJcxN3zqA0Fi1apB133FF77bWXdthhB2222WZavHixHn74YT322GMaP368fvWrXxV9mg1r9uzZOuGEE9TX16dDDjlEv/3tb4dt093drTPPPDP/k2sSs2fP1re//W1J0sqVK9ct8/5ML7nkkgLOrDF0dHTo5z//uWbOnKmDDjpIp556qsaNG6err75ar776qr7zne/QsZ3Sz3/+c91///2SpKeeemrdsvptzQcccIA++clPFnV6De+b3/ymfvWrX2ns2LHabrvt9O///u/Dtjn++OO1++67539yTeDKK6/U//t//08HHHCAuru7NX78eM2dO1ezZs3SokWLdOCBB+qLX/xi0acJAJkiM2WLzJQ9MlN65KbskZuyQ2bKFpmp8VCcQWlMmTJF//Iv/6K7775bt99+uxYtWqSRI0equ7tbX/ziF3X22Wdr8803L/o0G9b8+fPV19cnSbr88ssDtzn44IMJGinMnz9/WBhesGDBBssIGtEOPfRQ3X///Tr//PN1xRVXqL+/X7vssosuuuginXLKKUWfXsO7//77h/2OPvDAA3rggQfWfU3ISO6VV16RNDSH9oILLgjcpru7m6CR0NFHH6158+bpj3/8ox588EEtW7ZMEyZM0K677qpTTz1VZ511ljo6uLQF0NzITNkiM2WPzOQGuSlb5KbskJmyRWZqPJVarVYr+iQAAAAAAAAAAABaRVvRJwAAAAAAAAAAANBKKM4AAAAAAAAAAADkiOIMAAAAAAAAAABAjijOAAAAAAAAAAAA5IjiDAAAAAAAAAAAQI4ozgAAAAAAAAAAAOSI4gwAAAAAAAAAAECOKM4AAAAAAAAAAADkiOIMAAAAAAAAAABAjijOAAAAAAAAAAAA5IjiDAAAAAAAAAAAQI4ozgAAAAAAAAAAAOSI4gwAAAAAAAAAAECOKM4AAEqhVqvpqKOOUqVS0RVXXDFs3Qc+8IHAdQAAAADQKshNANA8KrVarVb0SQAAIEkLFizQrrvuqr6+Pj355JOaPn26JOm73/2uzj77bJ155pn65S9/WfBZAgAAAEBxyE0A0BwozgAASuX3v/+9jjrqKO23336699579dRTT2mfffbR9OnT9dhjj2ns2LFFnyIAAAAAFIrcBACNj7FmAIBSef/7368vfOEL+uMf/6ivfvWrOu2001Sr1fTb3/6WgAEAAAAAIjcBQDPgzhkAQOn09fVp33331RNPPCFJuuiii/SVr3yl2JMCAAAAgBIhNwFAY+POGQBA6XR2duoDH/iAJGnUqFH65Cc/WfAZAQAAAEC5kJsAoLFRnAEAlM6f/vQn/ed//qc22mgjrVq1Sn/3d39X9CkBAAAAQKmQmwCgsVGcAQCUytKlS/WRj3xEHR0duvvuu3XSSSfpyiuv1MUXX1z0qQEAAABAKZCbAKDx8cwZAECpfOxjH9Oll16q//mf/9HnPvc5LV68WLvttpveeecdPfbYY9puu+2KPkUAAAAAKBS5CQAaH8UZAEBpXHrppfrYxz6mY445RjfccMO65ffee68OPfRQ7bHHHnrwwQc1YsSIAs8SAAAAAIpDbgKA5sBYMwBAKbz88sv63Oc+p2nTpg27Ff+ggw7Sueeeq0cffVTnnXdeQWcIAAAAAMUiNwFA8+DOGQAAAAAAAAAAgBxx5wwAAAAAAAAAAECOKM4AAAAAAAAAAADkiOIMAAAAAAAAAABAjijOAAAAAAAAAAAA5IjiDAAAAAAAAAAAQI4ozgAAAAAAAAAAAOSI4gwAAAAAAAAAAECOKM4AAAAAAAAAAADkiOIMAAAAAAAAAABAjijOAAAAAAAAAAAA5IjiDAAAAAAAAAAAQI4ozgAAAAAAAAAAAOTo/wOdbrBzfTHwmgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] @@ -332,13 +546,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 27, "id": "9db496a1-7b7d-44b3-a5f8-a662ea10bb5a", "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABqgAAAMyCAYAAAAR60BPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxV1d7H8c8ZmQ6ggAIKIs5DzmbOw600LUutnO5VKystNW9d7XZ9Si3v0/SUpaXeVETSTE29Ug6pKThnOZvmjCgqKCIznPn5Y3MOHA8qqAjq7/167dfeZ62191kbieB891pLZbfb7QghhBBCCCGEEEIIIYQQQghxl6jLuwNCCCGEEEIIIYQQQgghhBDiwSIBlRBCCCGEEEIIIYQQQgghhLirJKASQgghhBBCCCGEEEIIIYQQd5UEVEIIIYQQQgghhBBCCCGEEOKukoBKCCGEEEIIIYQQQgghhBBC3FUSUAkhhBBCCCGEEEIIIYQQQoi7SgIqIYQQQgghhBBCCCGEEEIIcVdpy7sD9zuLxcK+ffsIDg5GrZY8UAghhBBCCCGEEEIIIYR4kNlsNlJSUmjRogVa7YMb0zy4d36X7Nu3jzZt2pR3N4QQQgghhBBCCCGEEEIIUYH89ttvPPzww+XdjXIjAVUZCw4OBpRvtNDQ0HLujRBCCCGEEEIIIYQQQgghytPFixdp06aNMz94UElAVcYc0/qFhoYSFhZWzr0RQgghhBBCCCGEEEIIIURF8KAvC/Rg370QQgghhBBCCCGEEEIIIYS46ySgEkIIIYQQQgghhBBCCCGEEHeVBFRCCCGEEEIIIYQQQgghhBDirpKAqoSmTZtGREQEnp6edOzYkQMHDpR3l4QQQgghhBBCCCGEEEIIIe5JElCVwKJFi/jnP//JlClT2LNnD3Xq1KFHjx5kZmaWd9eEEEIIIYQQQgghhBBCCCHuORJQlcAXX3zByJEjGTp0KI0bN2bu3LlYLBYWLVpU3l0TQgghhBBCCCGEEEIIIYS452hv5+T//ve/zJw5k71795KTk0NoaCht27bl008/JTw8/E718aYWLlzI1q1b2bNnD4cOHcJkMhEdHc0LL7xw3XN+//13Jk2axI4dOzCbzTRp0oS33nqL/v37u7QzmUzs27ePSZMmOcu0Wi1du3Zl586djBw5sqxuSwghhBBCCCGEEEIIUQyz2YzVai3vbgghBAAajQadTlfe3bjn3FJAZbfbGTlyJLNnz6Z27doMHDgQX19fLly4wObNm0lMTLyrAdW7775LYmIiQUFBhIaGkpiYeMP2cXFx9OjRA09PT2ffly9fzoABAzh37hz/+Mc/nG1TU1OxWq0EBwe7XKNq1aqcOnWqTO5HCCGEEEIIIYQQQgjhLjMzk9TUVIxGY3l3RQghXHh4eBAUFISfn195d+WecUsB1fTp05k9ezavv/4606dPR6PRuNRbLJabXmPBggV07tyZiIiIYuutVivTpk1j9OjR6PX6G15r7ty51K1bl4iICD7++GP+9a9/XbetxWLhlVdeQa1Ws2XLFpo3bw7AxIkTadOmDRMmTOC55567br+EEEIIIYQQQgghhBB3X2ZmJufPn8dgMBAUFIROp0OlUpV3t4QQDzi73Y7ZbCYjI4Pz588DSEhVQqUOqPLy8nj//fepVasW06ZNcwunQJkC70aSkpJ45ZVXCA0NJT4+3i0MstlsDBs2jO+++w69Xs/o0aNveL3HHnusxP3ftGkTp06d4sUXX3SGUwD+/v5MmDCBF154gZiYGCZOnAhAUFAQGo2GlJQUl+tcunSJkJCQEr+vEEIIIYQQQgghhBDi1qWmpmIwGAgLC5NgSghRoXh5eeHr60tSUhKpqakSUJWQurQnrF+/nqtXr9KnTx+sVisrVqzg448/5j//+Q8nT54s0TXCwsL4/vvvSUpKolu3bpw9e9ZZVzScGjp0KK+//nppu3hD8fHxAHTv3t2trkePHgBs3rzZWabX62nRogUbN250llksFuLj42nXrt0d7ZsQQgghhBBCCCGEEMKd2WzGaDTi7+8v4ZQQokJSqVT4+/tjNBoxm83l3Z17QqlHUO3ZswdQFv1q2rQpx48fd9ap1WrefPNNPvvss5tep2/fvnz//fcMGjSIrl27Eh8fT1hYGC+88AILFy7kr3/9K9HR0ajVpc7QbujEiRMA1K1b160uJCQEg8HgbOPw5ptvMnz4cFq1akXLli357LPP0Gq1DB48+LrvM2PGDGbMmIHJZLqj/RdCCCGEEEIIIYQQ4kFjtVoB0Ol05dwTIYS4PsfPKKvVKj+vSqDU6c+lS5cAmDp1Kv7+/vz2229kZWWxZcsW6tWrx+eff86sWbNKdK3nnnuOhQsXcvbsWbp168agQYNYsGABAwcOJCYm5o6HUwAZGRmAMqVfcfz8/JxtHAYPHszHH3/MhAkTaNGiBceOHWPdunU3HKY3atQojhw54hyxJYQQQgghhBBCCCGEuD0yekoIUZHJz6jSKfUIKpvNBihT361cuZJq1aoB0KlTJ3744QeaNWvG559/zmuvvVai6w0YMACLxcLf/vY3Tp8+TZ8+fVi4cGGxa1uVp7FjxzJ27Njy7oYQQgghhBBCCCGEEEIIIcQ9r9RDlBwjj1q3bu0MpxweeughatWqxalTp0hPTy/R9ex2O5s2bXK+Pnz4MCkpKaXtVok5+n/tKCmHzMzM646uEkIIIYQQQgghhBBCCCGEELev1AFV/fr1AahUqVKx9Y7yvLy8m17Lbrfz6quvMm/ePAYMGMDChQs5ffo03bp148KFC6XtWok41p66dp0pgOTkZLKzs4tdn0oIIYQQQgghhBBCCCHuB2fOnEGlUvHCCy+Ud1duqmbNmtSsWbO8uyGEKAOlDqi6desGwJ9//ulWZzabOXnyJD4+PlSpUuWG17Hb7YwYMYK5c+fSv39/vvvuO/7617+yYMECTp06Rbdu3bh48WJpu3dTXbp0AWD9+vVudevWrXNpI4QQQgghhBBCCCGEEBXNSy+9hEqlIjAwEKPRWN7dqVByc3P5/PPPGTx4MA0aNECtVqNSqThz5kx5d83N8ePH6d+/P0FBQXh5edGsWTNmzZqF3W4v1XWSkpIYMWIENWrUQK/XU61aNV588UXOnTtXRj0Xd0q3bt1o1KgRM2bMKO+ulItSB1S1a9eme/funDx5krlz57rUffzxx6Snp9O3b1+02usvb2W323nttdeYM2eOM5xyrDk1aNAgl5AqOTm5tF28oUcffZRatWqxaNEi9u/f7yzPyMjgww8/RK/XM3To0Dv6nkIIIYQQQgghhBBCCHEnZGVlsXTpUlQqFWlpaaxcubK8u1ShXLp0iXHjxvH999+Tn59P5cqVy7tLxTpy5Aht2rQhNjaWnj178sYbb2C1Wnn99dd54403SnydU6dO0apVK2bPnk3Dhg0ZO3Ysbdq0ISYmhtatW3Pq1KkyvAtxu+Li4jhy5AijRo0q766Ui+unSDcwc+ZM2rdvzyuvvMLKlStp0KAB+/btY9OmTURERPB///d/Nzz/woUL/Pe//+X555/nu+++cwuzBg0ahN1uZ+jQofzyyy/87W9/u+H15s6dy7Zt2wA4dOiQsyw+Ph6Ajh078vLLLys3rNUyd+5cevToQefOnRk4cCC+vr4sX76cxMREPvvsMxkyKoQQQgghhBBCCCGEqJCWLFlCTk4Ob731Fl9++SVRUVEMGDCgvLtVYQQFBbF+/XpatWpFQEAATzzxhHPmrIrktddeIyMjgzVr1tCzZ08ApkyZwmOPPcbXX3/N4MGDadeu3U2vM3bsWC5dusS0adNcgq0ffviB/v37M2rUKH7++ecyuw8hbkepR1CBMopq9+7dvPDCC+zZs4fp06dz4sQJRo0axW+//UZISMgNz69evTo7d+5k0aJF1x1pNXjwYI4cOXLTcApg27ZtxMTEEBMTw969ewHYvn27s8wRXjl069aNbdu20aFDB5YsWcKsWbMIDg5m8eLF/OMf/yjhV0EIIYQQQgghhBBCCCHurqioKLRaLW+//TbdunVj48aNJCYmFtvWarXyySefUKdOHTw9PalTpw4fffQRNput2PZxcXG89NJL1K9fH4PBgMFgoHXr1syePbvY9iqViq5du3L+/HkGDx5MUFAQvr6+PPnkk5w+fRpQlorp06cPAQEB+Pr68txzz5GSklLq+05PT2fEiBGEhITg6elJixYt+P77793aGQwGHn/8cQICAkr9HnfL8ePH2bJlC926dXOGUwB6vZ4pU6YAMGfOnJteJz8/n3Xr1hEcHMyYMWNc6p5//nmaN2/OunXrnP8WQlQ0tzSCCiA8PJzo6OhbfuNatWrdtE29evVKdK358+czf/78Ur1/mzZtWLt2banOEUIIIYQQQgghhBBCiPJy5MgRfv31V3r16kVwcDBDhw5l48aNREdHM3nyZLf2r776KvPmzSMyMpJRo0aRn5/P1KlT2bFjR7HX/+STTzh58iRt27alb9++pKen8/PPPzNixAiOHTvG559/7nbO1atX6dixIyEhIQwbNozjx4+zatUqjh49SmxsLJ06daJVq1a89NJL7Nmzh+XLl5OWlsamTZtKfN8mk4nHHnuM7OxshgwZQk5ODkuXLmXw4MGkpqa6hTMVnWPmr+7du7vVdezYER8fHzZv3nzT61y5cgWLxUJERAQqlcqtPjIykv379xMXF1eiz+OFuNtuOaASQgghhBBCCCGEEEKI8mS328kzW8u7GzflpdMUGyCUVlRUFABDhgwBoF+/frz++utER0czceJE1OrCCbPi4+OZN28ezZo1Y/v27fj4+AAwYcIEmjdvXuz1Z82aRWRkpEuZxWKhV69eTJs2jbFjx1KjRg2X+oMHD/Lmm28ydepUZ9nrr7/OrFmz6NSpE5MnT2bs2LGA8u/11FNPsWbNGvbu3UvLli1LdN8XL16kbt267NixA71e77yPFi1aMH78ePr160f16tVLdK3SiI+Pd4ZJJVGzZk1eeOGFm7Y7ceIEAHXr1nWr02g0REZGcuTIESwWy3VnIAOoXLkyGo2GxMRE7Ha72/dYQkICoIzYEqIikoBKCCGEEEIIIYQQQghxT8ozW2k0seKtL3StIx/0wFt/ex/Fms1mFixYgJ+fH3369AGU6ez69u3LwoUL+eWXX1xG5Hz77bcATJw40RlOgbL8ytixY3nvvffc3uPacApAq9UycuRINmzYQFxcHMOGDXOpNxgM/Pvf/3YpGzRoELNmzSIwMNBlXSSVSsXAgQNZs2YNBw4cKHFABfDhhx86wymAsLAw532U1dIt8fHxvP/++yVu36VLlxIFVBkZGQD4+/sXW+/n54fNZiMrK4vKlStf9zre3t507tyZuLg4Zs6cyahRo5x1K1asYP/+/YAyPaIQFdEtrUElhBBCCCGEEEIIIYQQ4u6JjY3l8uXLPP/883h6ejrLhw4dChSOrnI4cOAAAJ06dXK7VnFlAFlZWUyaNIlmzZphMBhQqVSoVCqeffZZAC5cuOB2Tt26dfH29nYpCw0NBaBp06Zuo3ocdcVd63q0Wi3t2rW77n3s27evxNcqjcmTJ2O320u8lWa01Z3yxRdfYDAYGD16NE888QRvv/02/fr14/nnn6dp06YALiPrhKhIZASVEEIIIYQQQgghhBDinuSl03Dkgx7l3Y2b8tJpbvsajgDKEUg5PProo1SvXp3Y2FjS0tIICAgAlFE6arWaoKAgt2sFBwe7lZlMJrp27crevXtp0aIFQ4YMITAwEK1Wy5kzZ4iJicFoNLqd5+fn51bmmJbuRnVms/lmt+wUFBRUbMjiuA/HiKR7hWPk1PX6nZmZiUqlwtfX96bXatasGb///juTJk0iLi6OuLg46tSpwzfffEN6ejrjx4+natWqd7T/QtwpElAJIYQQQgghhBBCCCHuSSqV6ranzrsXnDt3jvXr1wPKNHLXs3DhQueUev7+/thsNlJTU6lSpYpLu5SUFLdzY2Nj2bt3L8OHD2fu3LkudYsXLyYmJuZ2b+OWpaamYrPZ3EIqx31cb6q821VWa1A51p5yrEVVlNVqJSEhgcjIyBuuP1VUgwYNWLJkiVu5oy+tW7cu0XWEuNvu/5/eQgghhBBCCCGEEEIIcQ+bP38+NpuNjh07Ur9+fbd6i8VCTEwMUVFRzoCqWbNm7N27l61bt9KvXz+X9lu3bnW7xqlTpwB45pln3OqKa383WSwWdu7cSYcOHVzKHf1q0aJFmbxvWa1B5QgZ169fzzvvvONSt23bNnJycm4YRJZEVlYWP/30E4GBgTz++OO3dS0hyooEVEIIIYQQQgghhBBCCFFB2e12oqOjUalUxMTEUKtWrWLbHT9+nJ07d7J7925at27NkCFDiI6O5oMPPqBHjx74+PgAcP78eaZNm+Z2fkREBKAEJL1793aWb968mTlz5pTBnZXOhAkT2LBhA3q9HoCkpCSmTZuGh4cHAwcOLJP3nDx5MpMnT77j161fvz6dO3cmLi6OtWvX0rNnT0CZZvG9994D4OWXX3Y5JzU1ldTUVIKCglymbczLy0On07mMtjIajQwfPpy0tDSmTZvmsmaZEBWJBFRCCCGEEEIIIYQQQghRQW3atImEhAS6dOly3XAK4MUXX2Tnzp1ERUXRunVrunXrxosvvkh0dDRNmjShb9++GI1GlixZQtu2bVm1apXL+b1796ZmzZp8+umn/PHHHzz00EMcO3aMVatW0bdvX5YtW1bWt3pdoaGh5OTk0LRpU3r37k1OTg5Lly7lypUrTJ8+nerVq7u0HzduHKmpqQAcOnTIWWYwGAAl/OnYsePdvYlrzJw5kw4dOtCnTx8GDBhAaGgoq1ev5vDhw4wePZr27du7tP/66695//33mTRpkktotmfPHvr168fjjz9OeHg4mZmZrF69mrNnz/LKK68wZsyYu3xnQpScBFRCCCGEEEIIIYQQQghRQUVFRQHcdOq4AQMGMHbsWL7//numTp2Kl5cXc+bMoV69esyZM4evv/6asLAw3nrrLfr37+8WUBkMBjZt2sT48ePZsmUL8fHxNG7cmO+++47g4OByDaj0ej0bNmzgnXfeYcGCBaSnp9OgQQO++uorBg0a5NZ+2bJlJCYmupQtX77cedy1a9dyD6gaN27Mrl27ePfdd1m9ejU5OTnUq1ePGTNm8Nprr5X4OjVq1KBr165s3bqVlJQUvL29admyJVOnTuXZZ58twzsQ4vap7Ha7vbw7cT9LSkoiPDycc+fOERYWVt7dEUIIIYQQQgghhBDinpOfn09CQgKRkZEyXZkQosIq6c8qyQ0UMoJKVEgmi42MXBOZ2dlkZ2eSk51JbnYmedmZ5OdlYcrNxpKfjcmYh81sRIcFHRb0KjM+GisGrQ1PnQYPvQ5fTw98vTzw9fbAZleRZUbZTJBrUZFr1ZBrVZFrUZFj1WCyqTHZNZjQKnu7GpNNgwkNVnTYVBpsah02tRa7WocFZY9KjVqtRqVSoVaBVqNCo1ajVauUTaNCo1ahLSjTaArKi7zWqdVKG02R8iKvNWoVuqLX0ajQadToNYXHuoK9VqNCr1E7j3VqNWq1qrz/aYUQQgghhBBCCCGEEEIICajE3We32Tj1v63R2K2osaC221BjRWO3osGK2m5Fh5nKGKmist3x9zcAoXf8qgqTXYMZLRaUvRktFrsGK2qsqLGhxoYKW8FrR5nj2I4Kq921vLBehQU1RuWrhsWuxYoaMxqsaFz2yntq3OqsKg12tRZUWuwaHag02NU61BqtUq7WodJoUWl0ULBXaxxlejQaHSqtFrVWh0ajR6PVotLq0Wm16NQqdFolVNNr1WjVSljmONZrlU2nUeGhVaPXaJyvHXUeGg06raogcFOX0b+SEEIIIYQQQgghhBBCiPImAZW461RqNZGW02hU15ldsphBPma0GFVemDWeWDReWLVe2LXeqHSeoPUg36YlLR+u5EOaEXKtGgA02NAUREIa7GhUtoJxUFa0WNFiQY9FOVZZ0WPBQ23DQ21DX/BaqypobzcroZrdgsZuRoV7/5VzrDe9nwrBkf1ZAfPtXcpqV2FBg+U6YZl7nZpctGQW1JnQYSrYm+1aTAXhnlWtx6rWKSPWVHrsGh1WtR40euwaPXa1HrR6UOtRaZUNrQcqrR611gON1gO1zgO11gO13gOt1gON3gO9VusMxfQaNbqCvV6rxkOrjDrz0Krx0Knx0GqUY62EZkIIIYQQQgghhBBCCHGnSEAlysXhbnNRq5WROGqtFrVGGcWj1erx0OsxeHvj7euHxsMHdD7oNFp011wjITWHH3afY93hZE5dznGp02lUNK7mz0PV/ahTxUDtqgaqV/LCYrWRnGnkfHoeSVdzOX81j/PpeZy/mkdyZj62UqzIpsaGriDc8lBZ8NOr8NXb8dfZMejtGLRg0IFBa8NHB15aFV46Fd5a8NKCj16Fj1aFt06Fjw68dUq9XgXYrWCzuu7t9iJlFmWzFuxt5iKvlWO71YzdquxtVsdrc2G5zQJWx3nmwmvaLKhsFlT2gn3BsdpmQWVXxnRdS6Oyo8GCBxb3L9SdCOiKhml3gLlgCkczWkzoMNp15KPHiLLPuOa10a7DiA6TSo9F7YG1YLOoPbBrPLBpPbBrPLFrPQsCMk/QeaHWeaDSeaPWeaLx8EKj90Kv0xWEX47gS9k7wrHC8muOtRp0GhUqVUVNPIUQQgghhBBCCCGEEKLkJKAS5aJp1+du6bwco4XVhy7yw+5z/H7mqrNco1bRKqIyXetX4ZHIQBpX88NTpyn2GvVCir+22WojOSOfpKt5XMzIIy3HxNVcE2k5ZtJzTaTlmMjIM5NjspBrtJJrspJnVmMEcuzKyC2Mt3RbLvRaNX6eOvy8tPh76QqOdfh5FrwuKFOOtc56fy8dvp5adAWjfFQUZkPFfyVukc123VBMCbqsxb++JgRzCcesRuwWEzaLEatZ2WyOvcWE3aLU2y1GsDr2ZlTWwr3KZkZtNaG2mVDbzKhtZjR2MxqbsnpYUbqCUXHOf7DSZj52lLDsFkafmewa8tGTjwe5dg/y0JNXcJyJB3l4kGf3IBcP8tGT6zxW6ixqT8waL6waL6w6L2wab+xaT2w6b9D5oNZ5oNdqXMItT50GT50aT61GOdZr8HSWF9TpNAX1yrFHkTIJxoQQQgghhBBCCCGEEHeaBFSiwrPb7ew9e5Ulv59j1cGL5JqUsEGtgi71qtC3ZRhd6lXB3+vaMValo9OoCQ/wJjzAu8TnWG128sxWco0Wck1Wsgv2jhArx2hRjh11RgvZBeVZRjMZeWYy8yxk5pvJzDNjs4PJYiM120hq9q2lXd56TZFgS1skzFJCLr8iIde1IZivhxa1+iZBhFoNaj2gv6X+XY8KJUi7o2Gag80GVpPrVhByYclXXpvzlDJLwd6cB5Z8bOY8rKZ8rKY8rKY8bKZ8bOZc7GYj9oI2WPJROTarCbU1H43ViNpmRGszorEXjixTpoHMw4+82xtd5gjIri22q5SQCz15dg9y8CIbT3LshfscPLla5DjbsceLHLunS1kOnqhURcIsrSPAKhp6uYZdHtprg69r6q8Jw7x0Grz0GuexBGJCCCGEEEIIIYQQQtz/JKASFdalzHxW7DvP0t3nOF1kCr/IIB+ebx1GvxZhhPh7lmMPlZFbBg8tBo/b/0/JbreTbbSQmW8hM88RXpldX+crgVbhsZmsfOV1tlEJQXJNyuiuixn5pe6DSgW+HtobjtLy89Ti66mM1vIrGLXl56m0N3hq0dws4CoPajWoPUFX+u8XdcF2W/GnzVoQZBUEX+Y8MOcW7HOUvSm3yHGOs43dlIPNlIvNqOyVOuVclTkXtSUPtSUXtU0ZyqVR2TGQj4H8O7b+WY4j6DJ7kmP2JCfPi2y7J5n4kGX3IhMfMu3eZOFNht3b5XVmwWtzKf53o1Gr8CoItLz0aiXAcr7WFL4uclw04HKcUzT8uvZ8T52mYn6vCiGEEEIIIYQQQgjxgJCASlQouSYLG46ksGLvebaeuOxcE8pbr+HJJqH0fzic1hGV78vRFSqVqiD40VG9klepz7dYbWQbLW6jsq4XbGXmW4qEYGbyzTbsdpRALN8C5N3SfRg8tPh6ap3BlW9BoOXnVSTYKrJ3lDvKvPWa++/fV60BvY+ylVKJR5ZZLQXBVcHmCLNM2cpmdOyzbvI6G0xZ2I3ZqOzKEC0flREfjFS9jX8Wk0pPjspAjsqHLBzhlQ8Zdi/SbV6kWb3IsHuTbjdwFQPpJgPpRgNpGMjBkzuWthWh16qLCbjUroHWdQKu4gIy74JNOdbiJSGYEEIIIYQQQgghhBDXJQGVKBdmq42sfAtZ+WYSr+Ty58VMtp1MZVdCGiaLzdmudURl+rcOp1fT0DsySul+ptWoqeStp5L3rU29Z7RYiwm2LM4Ayxl8Fbx2/PtlFuzzzcq/W7bRQrbRwsWMW7sPjVrlDLh8PdyDLb8bBV4FI7o8tGUyUWDFptGCxg88/e7I5VR2uzLqqyCwujbAwpgF+ZlgzIT8jIKt6HGGUmfMBEBvN6G3p1GZtGLejBv+38im1mHWV8Kk8yNfV4l8rR85Gj+y1X5kq33JVPmSURBspdkNpFl9uGz1IdOiIc9sJd9kJc+sbI7vU1Cm0zRZbGTklXIhsVLw0KoLgiutM8Ty0mncy/QavHXaIgGXo1x73XMc680JIYQQQgghhBBCCHEvkk/8xV1nstio9+7a69bXCPCmb4vq9G1RnZpBpR9xIm6Nh1ZDFV8NVXw9bul8k8VGVkFw5RJgOUKvgtdZBaFXVr6yDldmXmG5xWbHarOTnmsmPdfMrY7i0mvVRcKswuCquMCr6AivCj9V4d2kUoHOS9mocuvXsVkLQqxM1+Cq2FArHfLSIS8N8q5CbhpYjahtZjzyL+ORfxnf0ry33gA+QeAXBD5VwCcIu3cVzF6BmDwCyddXJlenbNkaf3It6oIQy0pe0VCryHGeyabUF21TsM81WckzWcg1W7EXjP40WmwYLTau5t75EEynURUEV67Blpdei7dO416mLwy5vPUaWtaoXO7TpAohhBBCCCGEEEKIB5cEVOKu02vVeGjVGC02vHQaQv09aRjqR4salehavwq1qxjuvyneHgB6rZpAgweBhlsLuOx2O3lmqzPAyrxmhJZLsFWkvGjAlVWwDpfJYiM120hqtvGW78dHrykMtjyLX3vLUX7tCC5fTx0+9+NUhbdCrQGvyspWWna7Ml2hI6zKu+oaXuVdvX6d3Vo4veHVM85LqgB9wWa49v28KoN3YZil7KtAQDAYQsC3YPOpqoxYu2637eSbbeSaLEpoVRBe5Zos5JkcQZbyOtdsdZY5Ay6Xc4qUmazkmq1YC+Y+NVvtmK2OKTlLb9ZfW9KzSegtnSuEEEIIIYQQQgghxO2SgEqUi10THsXHQytTVAknlUpVMLJDS7DfrY3qsNrsZBvdg6vMogGX0XVk17UjvhxTwOWYrOSYrLc1VaHBQ6uM2PIoDLD8vZRRWv5eyqgt52tv13IvnQRcqFSFa3f5h5X8PLtdGaWVkwq5VyDncpEttchxQV1uKththYHXlRM365gSXPleE1wZgsE3FJVvCF6GYLwMwQQavG/rS+B+a3ZMVts1oZZ72OUItQrDriIhV0H7qn63FiYLIYQQQgghhBBCCHEnSEAlysWtrpMkxI1o1Cr8C0IgbmHADhROVVjcyK3rTVV47Zpc1oKpCjMK1vK6lakKdRqVMuWgV8HmWRBm3SjkKtJW+yCHvyoVePorW2Dtm7e3FYRTjuAqN7UwyMpOgawUyE4u2Kcoo7NyLikbh258be9A8KuuBGx+1cG/OviHFx77hoJGV4pbU+Gh1eCh1VDpzmZfQgghhBAKmw2sJmWz2wA7zvmLoeDYfs0eUGuVUeZqLah1yu84D/oDV0IIIYQQ4oYkoBJCiCLKcqpC5bUSWmXmXftaaZORZ8Zqs2O22rmSY+JKjumW+uGj1xQGVjcJtBzllbz0VPLW4anT3NJ73rPUavAJVDYa3LitzaqMyspKVrbs5CLHKZB1sTDIspmVtrlXIPlg8ddTqZWRV44QyyXICoNKNcE7QD7cEUIIIYTCnF9k/c6CzZwDpoLNnAum3IKy3ILXRcoteWA1FwZQjmOLqUgoZb1z/VWpC8MqtUY5dqxzqvMCnXfB5lW41/soew+/woeOPP3BqxJ4Vip8XYqHfIQQQrg7c+YMkZGRDBs2jPnz55d3d26oZs2agNJnIcT9RQIqIYS4g253qkK73U6uyaqEVvlmMnILg6vMghFZxYVcjrock/KBgmOKwgsZ+aXug6dO7QyrKnnrqOytHPt76alcUFbJW08lL2Vf2VuZotBD+wAEW2oNGKoqW2jT67ez2ZQ1sbIuQsZ5yEwq2J9X9hnnIPOCEmJlXVS287uLv5beAJUioHJEwb5mkeMI5UMcIYQQQtw7XKYjTlNGb+deUV7npbkHUPmZhcfWW19jtVzYbUqfy6LfOh9lDdGia4caqhQeO8uClb36AZ5hQAhxX3rppZeIjo4mICCACxcu4OEh07g75ObmMmvWLPbs2cPevXs5fvw4drudhIQEZ9hVURw/fpx3332XTZs2kZOTQ7169Rg5ciQjR44s1dIPSUlJTJkyhbVr15KcnExQUBA9evTggw8+IDw83K29zWZj5syZzJs3j6NHj6LVamnevDnjxo3j6aefvpO3KMQNSUAlhBAViEqlwsdDi4+Hlmp4lfp8i9XmHK11bZiVcc2IrQzHCK88M+kFr602O/lmG8nmfJIzSxduees1VPLS4e9dXJBVeFzZR1/QTgm/7su16NTqgg9LgiCkSfFtbDZlGkGX8CrJNcTKugimbLh0WNmK4x10TWhVEwLrKJuhqoy+EkIIIe4GqxmyLxWZFjhZeZ2TWiSAulI4utpmvo03c0xn7Ace/gVrdnoXjDwqONY5XhfdF5Rr9EU2XfHHWr0y2kmlLvK7hKrguGB/7e8YNivYLMrXwma55tis1FvNyigus2PLLRzp5XjtKDNmuod1eelgylLez5yjbJlJN/+SqXXgFwp+YcpI9WunYK5UQwm7hBDiHpGVlcXSpUtRqVSkpaWxcuVKBgwYUN7dqjAuXbrEuHHjAIiIiKBy5cqkpaWVc6/cHTlyhPbt25OXl0f//v2pVq0aq1ev5vXXX+fIkSN89dVXJbrOqVOnaN++PZcuXaJ79+4MGDCAEydOEBMTw5o1a9ixYwe1axcugWC32+nfvz/Lly+ndu3aDB8+HKPRSGxsLM888wxfffUVo0ePLqvbFtfo1q0bOp2OUaNGMWrUqPLuzl0nAZUQQtxHtBo1AT56AnxKv86b3W4ny2ghI9dMeq6Zq7km0vPMpOeanK8zcpUwy3F8NddERp4Zmx1yTVZyb2HUlp+nlkCDh7PfAd56Agx6AgteV/YpPA708cBLf5+M1FKrwTdY2aq3Kr6NOV8Jqq4mQvoZuHqm4DhR2eenF3zolVr8CCy9QVmHK6B2YWgVWFvZ5EMYIYQQ4uZsVmUq38zzRabyLQihsi4WTPGbrIRO2G96ORc6H2WKYe9A5YET70Blat+i09gVt+kNFXM0kFqjbNoyfoLfaikMrxwj0LIvFa4n6tiyi6wvajND+lllux6vylA5EgJqFdkKXvtUkYd+hBAVypIlS8jJyeGtt97iyy+/JCoqSgKqIoKCgli/fj2tWrUiICCAJ554gnXr1pV3t9y89tprZGRksGbNGnr27AnAlClTeOyxx/j6668ZPHgw7dq1u+l1xo4dy6VLl5g2bRpvvPGGs/yHH36gf//+jBo1ip9//tlZvnz5cpYvX06HDh3YsGEDXl7KA9IffvghrVu3Zty4cTz11FMVbrTZ/SouLo6wsLDy7ka5kYBKCCEEoIze8vNU1qcKDyj5eTabEmy5BFl5RUKu3IKQq6DMcZyRZ8ZuRxnxlW8hITWnRO/nqVMT6FMk0CpmcwZaBg/8PLWlGhZfoeg8IaiushUnL70wrHLs005D2inlAxhTNlw8oGzX8gooDK2C6kCVBspWuaby4ZIQQghxv7PblYAjM0kZxeyclrfgOCNJCaFKuiaTWgs+VZWHTwwhyt6nSpEAKkAZXe1dEErpSj9aXgAarfK19A5QAqSbsZoLQ8aio9Uzi4xgz7kMeVeV7cJe92vofaFKfaha8PtSlYbKsV91Ca6EEOUiKioKrVbL22+/zYEDB9i4cSOJiYlERES4tbVarXz22WfMmTOHpKQkwsLCGD58+HUDrbi4OBYsWMD27ds5f/48AA0aNODVV1/l1VdfdWuvUqno0qUL3333HePHj2f9+vUYjUY6d+7MV199Ra1atfjzzz/517/+xZYtWzCbzfTo0YMZM2YQHBxcqvtOT0/nn//8J7GxsaSnp9OwYUPefvttBg0a5NLOYDDw+OOPl+rad9vx48fZsmUL3bp1c4ZTAHq9nilTptC1a1fmzJlz04AqPz+fdevWERwczJgxY1zqnn/+eZo3b866des4ffo0tWrVAiA2NhaACRMmOMMpUIK9N998k7///e9ER0fz/vvv36nbFeK6JKASQghxW9RqFf5eOvy9dEQElvw8q81Oeq6Jq7kmrmSbSMsxkZZrIi3bxJUc5bVLXY4Jk9VGvtnG+fQ8zqfnleh99Fo1VQweBPl6UMWgJ8jgQRVfD4IMHkWO9QT5euDrcY+FWV6VlC20mXudxagEVldOKoHVlZNw5ZSyZV1Q1rhI+k3ZitJ4KIGYI7CqUl/ZB0TKYuRCCCHuLY4A6uqZIiORzygPcThCKEsJfp9Qa8G3mjJFnCEYfEOK7EMKAynvwIo5sulBp9FBpXBlux5TjvK9kXa6YEso3GecU6YVPL/bfcS6h1/h70qhzaBaCwhuLOGjEKJMHTlyhF9//ZVevXoRHBzM0KFD2bhxI9HR0UyePNmt/auvvsq8efOIjIxk1KhR5OfnM3XqVHbs2FHs9T/55BNOnjxJ27Zt6du3L+np6fz888+MGDGCY8eO8fnnn7udc/XqVTp27EhISAjDhg3j+PHjrFq1iqNHjxIbG0unTp1o1aoVL730Env27GH58uWkpaWxadOmEt+3yWTiscceIzs7myFDhpCTk8PSpUsZPHgwqampbuFMRRcfHw9A9+7d3eo6duyIj48Pmzdvvul1rly5gsViISIiotjPMyIjI9m/fz9xcXHOgCo5OdlZV1x7gE2bNklAJe4KCaiEEEKUC41aRaDBg0CDB3Wq3ry93W4nx2QtCLCMbuGVY7tSJNjKNlowWUoeaHlo1S4BVhVfvTPccpRX9fUg2M8TT10FH2Wk9YAq9ZTtWqYc5UOXKyeVLfUkXP4TLh9XPqhL+UPZilLrlNFWjieHgx+CkIeUda/upVBPCCHE/cWcrwROzvApsfD4amLhekU34lNVWY/Iv3rBGkUFx/7hyggZQ1UZXXy/0/sowVJwY/c6i1H5venyUbh0VPmd6dJR5QEgYyYk/a5s+xYo7VWawsDKuTVV3kMIUTbsdmX9uopO531H/naKiooCYMiQIQD069eP119/nejoaCZOnIi6yMMS8fHxzJs3j2bNmrF9+3Z8fJSfRRMmTKB58+bFXn/WrFluwYXFYqFXr15MmzaNsWPHUqNGDZf6gwcP8uabbzJ16lRn2euvv86sWbPo1KkTkydPZuzYsYDyt/1TTz3FmjVr2Lt3Ly1btizRfV+8eJG6deuyY8cO9Hq98z5atGjB+PHj6devH9WrVy/RtUojPj7eGSaVRM2aNXnhhRdu2u7EiRMA1K3rPmOKRqMhMjKSI0eOYLFY0Gqv/xF+5cqV0Wg0JCYmYrfb3UKqhIQEQBmx5RAUFOSsa9iw4U3bC1GWJKASQghxT1CpVBg8tBg8tNQI9C7ROflmK5ezjKRmG0nNNjmPC8scx0qYZSxFmOXvpSPEz5Oqfh6E+HkS7OdJsJ9Hwd6TEH9PAn30aDUV8ElqvQ+ENFG2omw2yDhb8OHLUbh8rDC4MucUHP/peo6HX8EHOgWBVUgTqNpInhwWQghx51gtSvDkeLDCuZ1Spmi7Gd9QZQpbx1aphhI++VdXAqiyXjNJ3Nu0HlC1obIVza8sJuX78PJR5cGeiwfgwn5lzatLh5XtwCKlrUqj/I4U/gjUeETZ+z+4a00IcceZc+HDauXdi5ubcOG2w2qz2cyCBQvw8/OjT58+gDKdXd++fVm4cCG//PKLy4icb7/9FoCJEyc6wymA6tWrM3bsWN577z239yhuVI1Wq2XkyJFs2LCBuLg4hg0b5lJvMBj497//7VI2aNAgZs2aRWBgoMu6SCqVioEDB7JmzRoOHDhQ4oAKlDWSHOEUQFhYmPM+Fi9ezD/+8Y8SX6uk4uPjSzWSqEuXLiUKqDIyMgDw9/cvtt7Pzw+bzUZWVhaVK19/DWlvb286d+5MXFwcM2fOZNSoUc66FStWsH//fkCZHtGhZ8+eLF68mI8//pi//OUveHp6AsporC+//NKtvRBlSQIqIYQQ9y1PnYbwAG/CA24eaOWZrEpgVSTAcgZZWSZneUpmPkaLjYyCdbSOpVz/yWy1CoIMHoT4e1LV15MQfw+CfQvCLH9Pqvl7ElrJC4NHBfnfsVpd+OFd/ScKy202ZU2Oy8fg0p/KlnJICbKMmXB2p7I5qNTKaCtHaBXaDKq1VNaKEEIIIYpjtyvrBF0bQF05CVcTwGa5/rl6wzUBVESR43B5aEKUDa0eghsp20P9lDK7XVm37ML+wnVAL+5Xyi7uV7bfvlHa+lUvCKzaQs1OSgAmo9KFEDcRGxvL5cuXGT58uDNUABg6dCgLFy4kKirKJaA6cEBZj7hTp05u1yquDCArK4vPPvuMlStXcurUKXJyXNeLvnDhgts5devWxdvb9e/u0NBQAJo2beo2qsdRV9y1rker1Ra7HpPjPvbt21fia5XG5MmTi506sSL54osv6NixI6NHj+ann36iadOmnDx5ktjYWJo2bcrBgwddRtYNHjyY+fPnExcXR5MmTXjiiScwm82sXLnSuS6YWqYtFndJBflETAghhChfXvqShVl2u53MPAspWfmkZOaTnJHPpSwjyRnK65QsIykZ+VzONmK12bmUZeRSlhHIuO41/Ty1VKvkRfVKXoRW8qRaJS+q+Xsp+0pKoKUrz5FYarXytHmlGlC3yEKzFhOkHleeGk4+VLD/Q3lyOPW4sh1eUdi+UoSyNkO1FlC9pRJceRb/tJgQQoj7lM2mrOnjHKVb8PBD6nEwZV//PK0XBNYu2OoUbgG1lLWf5IN9URGoVOBXTdka9Cosz0iCc7vg7C5ln3xIGf13eEXh70o+VSCyc8HWRQlY5ftaiJLReSujkyo6XclmArkRx/R+Q4cOdSl/9NFHqV69OrGxsaSlpREQoDwcmJGRgVqtdk7pVpQjiCjKZDLRtWtX9u7dS4sWLRgyZAiBgYFotVrOnDlDTEwMRqPR7Tw/Pz+3Mse0dDeqM5vNN7tlp6CgoGJDE8d9OEYk3SscI6eu1+/MzExUKhW+vr43vVazZs34/fffmTRpEnFxccTFxVGnTh2++eYb0tPTGT9+PFWrFq6toNVqWbt2LR9//DGLFi1i9uzZ+Pv707dvX8aNG0e9evVc2gtRliSgEkIIIUpBpVLh763D31tHveDr/6Jotdm5km0kJVMZdZWcmc+lzHxSMo0kZyph1oX0PDLzLcqWnMXR5OJHY6lVUNXXk2qVlBFX1St5OUdfVa/kRXiAN/5eurK65evT6gum9XsImg1Uyux2yE5Rgqrkg0podWG/sk5DeqKyHVlZeI3AOsroKkdwFdoM9Lf/h5sQQohyZrMpP/OLBlGXjxZOG1sclQYqRxQJoIqEUb7VlAcmhLgX+ResbfbQs8prUw6c3wvnfoXEHZC4E3Iuwx/LlQ3AvwbU6gJ1u0PtbuBx8w8ohXhgqVQPxDpv586dY/369YAyjdz1LFy40Dmlnr+/PzabjdTUVKpUqeLSLiUlxe3c2NhY9u7dy/Dhw5k7d65L3eLFi4mJibnd27hlqamp2Gw2t5DKcR/XmyrvdpXVGlSOtacca1EVZbVaSUhIIDIy8obrTxXVoEEDlixZ4lbu6Evr1q1dyj08PJg0aRKTJk1yKXfc67XthSgrElAJIYQQZUCjVlHVz5Oqfp404fq/KGflm7mYoYRVF9IL9hl5ztfJGfmYrDaSC0IuzqYXex1/Lx3hAV6EV1ZGgYVX9iIswJsaAd5Ur+SFp+4uLe6uUoFviLLVfaywPC9dmdbmwj7lA5kL+5X1rhzTOB1aWnB+kTUawtsom3+4PEEshBAVWW5awUjaw8qDCSl/KEGU5TprOmr0EFgXqtSHKg2gagMIqq+MhtLqiz9HiPuJ3gciOykbgMUISbshYTMkbIGk35Xfk/YtUDa1DiLaQ70noF4PJbwVQjxw5s+fj81mo2PHjtSvX9+t3mKxEBMTQ1RUlDOgatasGXv37mXr1q3069fPpf3WrVvdrnHq1CkAnnnmGbe64trfTRaLhZ07d9KhQweXcke/WrRoUSbvW1ZrUDlCxvXr1/POO++41G3bto2cnJwbBpElkZWVxU8//URgYCCPP/74zU8AvvvuOwAGDhx4W+8tRElJQCWEEEKUI19PHb6e1x+NZbPZSc0xciE9n4vpeZwvCK4uZijHSVfzSMsxKWtinTfzx/nMYq8T7OfhEl45pjOsEeBNiJ8nanUZB0BelaBWV2VzyElVAivHdn4vZCe7r9HgGwphDxeEVo9AaFNZ0F4IIcqD1aKMiC06rWvKYci6zrRKGg8IqucaRFVpAJUjQSN/igrhpPWAmh2UrdsEMGbD2V/h1EY4vk757y5hs7Kt+xcE1Ib6PaFRHwhrLQ/yCPEAsNvtREdHo1KpiImJoVatWsW2O378ODt37mT37t20bt2aIUOGEB0dzQcffECPHj3w8VFGmp0/f55p06a5nR8REQEoAUnv3r2d5Zs3b2bOnDllcGelM2HCBDZs2IBerzzQkpSUxLRp0/Dw8CizQKWs1qCqX78+nTt3Ji4ujrVr19KzZ09AmWbxvffeA+Dll192OSc1NZXU1FSCgoJcpm3My8tDp9O5jLYyGo0MHz6ctLQ0pk2b5rJmGShTCF47/eKyZcuYN28eDz/8sFugKURZkb8KhBBCiApMrVZR1deTqr6eNA+vVGybHKOFc1dzOZeWx7m0XOdx0tVczqXlkmOyFkw1aGR34lW38z20amoG+hAR6E1kkA81g3yoGehDZJAPVX09yi688glS1rRyrGtltytrNCT9Bud+K1yjIesi/PmjsoHygWe15kpYVbOjsri4rGUlhBB3lilH+Rl88YAyZWvyH8oUfZb84ttXrgnBDymjYIMbQ9VGytqDEkQJUXoeBmUket3H4ImPIPUknFinhFWJO5TAaufXyuYXBo2eLgirHpapMIW4T23atImEhAS6dOly3XAK4MUXX2Tnzp1ERUXRunVrunXrxosvvkh0dDRNmjShb9++GI1GlixZQtu2bVm1apXL+b1796ZmzZp8+umn/PHHHzz00EMcO3aMVatW0bdvX5YtW1bWt3pdoaGh5OTk0LRpU3r37k1OTg5Lly7lypUrTJ8+nerVq7u0HzduHKmpqQAcOnTIWWYwGAAl/OnYsePdvYlrzJw5kw4dOtCnTx8GDBhAaGgoq1ev5vDhw4wePZr27du7tP/66695//33mTRpkktotmfPHvr168fjjz9OeHg4mZmZrF69mrNnz/LKK68wZswYt/d+5JFHCA8Pp2HDhnh6evLbb78RHx9PrVq1+OGHH9Bo7tIsLOKBJ38tCCGEEPc4Hw8tDUL8aBDivvis3W4nLcfEuavFh1dJV/MwWmwcS8niWIr7GlieOiW8qhmoBFeRQd5EFAmvVHfyiV2VCiqFK5tzjYZcZXRV0dAq94qyP7cLdkwHlVr5QDSio/LkcY124B1w5/olhBD3O0cYdWF/4XSsqcfBbnNvq/OB4EYFYdRDENwEqjYET/f/Bwkh7pCgOsrWbhTkZ8KpTfDnT3D8Z8hMgl9nKptvKDR8GpoOgOotZWSVEPeRqKgogJtOHTdgwADGjh3L999/z9SpU/Hy8mLOnDnUq1ePOXPm8PXXXxMWFsZbb71F//793QIqg8HApk2bGD9+PFu2bCE+Pp7GjRvz3XffERwcXK4BlV6vZ8OGDbzzzjssWLCA9PR0GjRowFdffcWgQYPc2i9btozExESXsuXLlzuPu3btWu4BVePGjdm1axfvvvsuq1evJicnh3r16jFjxgxee+21El+nRo0adO3ala1bt5KSkoK3tzctW7Zk6tSpPPvss8WeM2DAAFasWMGvv/6K2WwmMjKSd999l/Hjx7uNrBKiLKnsdru9vDtxP0tKSiI8PJxz584RFhZW3t0RQgghXFisNs6n55GQmsOZ1BzOXMnlzBXl+NzVPKy26/+a4K3XUKuKD7WrGKhTxUDtqgbqVDVQM9AHvbaMnt612yHttBJOJe6AxO3Kaxcq5YPTmh0gomDzCSyb/gghxL3GLYzaD6nHig+jDCHKiNWQpgVh1EPK9HwyQkOIisGcr0wDeCQWjq0FY5GpnoPqQbOBSljlL59FiPtDfn4+CQkJREZGuk1XJoQQFUVJf1ZJbqCQgKqMyTeaEEKIe5XZaiPpap4zsDqTmkPClVzOpOaQdDWX62VXGrWKGgHe1K5ioHZVH+pUUYKr2lUN+Hnq7nxHMy8oYdWZbcp25YR7m6qNC9fAimivTJ0jhBD3O5sNrpyEpN8Ltt1w6fCNw6hqLSC0uXLsG3KXOyyEuGUWI5yKgz+WwZ+rwJJXUKGCyE7QbLAyFaDep1y7KcTtkIBKCHEvkICqdCSgKmPyjSaEEOJ+ZLLYOHc1l1OXsjl5OZtTl3IK9tlkGy3XPa+qr4cy4qqqgbrBBuoH+1I/xJdK3vo717msFGVkVeJ2OLMdLv/pWq/WQXibwsCqWktZI0UIcX/IuwpJewoDqfO7IT/DvZ0jjAptrgRSEkYJcX/Jz1TW7jywGM5sLSz38FdGVbV+Cao2KL/+CXGLJKASQtwLJKAqHQmoyph8owkhhHiQ2O12LmUZOXkpm1OXs132KZnG655X1deD+iG+1A/2pV7Bvm6wAW/9HQiOclIhYQucjofTcZB+1rXeww9qdioMrILqypoNQoiKz2ZTpuY7u1MZGXXut+JHkGq9lBAqrDWEPazs/ard/f4KIcrH1UQ4uBT2L4SrZwrLIzooQVXDp0F7Bx8UEqIMSUAlhLgXSEBVOhJQlTH5RhNCCCEUWflmTl3O4eQlJbA6kZLFsZQskq7mFdtepYIaAd7UC/Z1jrSqH+JLZJAPOs0trn9it8PVhIKwKh5Ob4b8dNc2vtWgzl+gzuNKYOVV6dbeSwgh7iSrWVkv6uzOwi3vqnu7gFoQ1qYwkApuDJoymF5VCHFvsdmUB3V2z4Njawqn+vSpAg+/rGw+QeXbRyFuQgIqIcS9QAKq0pGAqozJN5oQQghxY9lGC8dTsjierARWx1OyOJacRWq2qdj2Oo2KWkEGGob60qiaH41C/WkY6kugwaP0b26zQvJBJaw6FQdnfwVrkZFeKo0yHWCdR5XAKqQpqG8xHBNCiNIw5SjT9CXuULak3UXWlCmg81aCqPBHlDCqemvwCSyf/goh7h0Z52Hvt7A3BrIuKmVaT2g2CNqNUkaTC1EBSUAlhLgXSEBVOhJQlTH5RhNCCCFuTWq20SW4OpacxfGU669xFeznQcNQPxqF+tGomh8NQ/2oGeiDRl2K6frMecoHwSc3wslflOmzivKpWhBWPQa1/wLeAbdxh0IIUYQpRwnJE7Yoa8ZcPAC2a37eeVWGGu2ULaI9hDaT0VFCiFtnNStrVe34Ci7sKyhUQf2e0P4NiGhXrt0T4loSUAkh7gUSUJWOBFRlTL7RhBBCiDvHbrdzPj2PY8lZ/HkxkyMXM/nzYhYJqTnFtvfSaagf4hhppYRWDUJ88fEo4dpWVxPh1EY48QskbAZTdpFKFVRvBXUfVwKrai1Arbn9mxRCPBgsJji/WwmkErYoa0jZzK5t/MKUD4gdgVRQfRnFKYS48+x25QGdnV8r0/851OwEXd+Bmh3Lr29CFCEBlRDiXiABVelIQFXG5BtNCCGEKHvZRgvHkjM5cjGLIxcy+fNiJkeTM8k329zaqlRQM9CnILAqnCYw2M8DleoGo60sJjj3qzKy6sQvcOmwa71XgDKqqu7jyt5Q9Q7fpRDinuacUnSzEkid3QnmXNc2ftUhsgtEdoaaHaBSjfLpqxDiwZV6QhlRtX9RYWge0RG6/lMJrG70u5IQZUwCKiHEvUACqtKRgKqMyTeaEEIIUT6sNjtnruRw5IJjpFUmRy5kcinLWGz7AB+9c3pAx75WkA9azXVGK2ReKAirNihrWBkzXetDm0HtgukAw9vINFxCPIiykuHUJuVnxalNkHfVtd47UAmjIjsrwVRALfnwVwhRMaSfg21fwL4FYC1YFzSiAzw6EWq0Ld++iQeWBFRCiHuBBFSlIwFVGZNvNCGEEKJiSc02OsMqR3B16nIOVpv7r0QeWrUyRWCR4KpBqB+Ga6cItJohaTec3KB8EH3xgGu93hdqdVHCqjqPyqgIIe5XVjOc26X8HDj5CyQfcq338FM+4HWEUlUbyZR9QoiKLSMJtn0Je2MKg6oGT8Gjk6BKvXLtmnjwSEAlhLgXSEBVOhJQlTH5RhNCCCEqvnyzleMpWc7QyjFNYI7JWmz7moHeLiOt3KYIzL4Ep+IKRk1shNwrrhcIqq8EVXUeVT6s1nmV8R0KIcpMRhKcWA8nNyrT95myXOurtSgIpx+D6q1BU8I18IQQoiLJOA+bP1FGVNltoNJAy6HKGlW+IeXdO/GAkIBKCHEvkICqdCSgKmPyjSaEEELcm2w2O2fTcp2BlWOfnJlfbPvrThGoAi7uVz68PvkLJP2mfLDjoPVUFh93fIAdWEem+BKiIrPblVGSx9Yo27WjpLyDCgLox6BWNzBUKZ9+CiFEWbh0FDa+r/z8A9B5Q6e3oP0boPUo376J+54EVEKIe4EEVKUjAVUZk280IYQQ4v5yJdvInxezOHIxwxlclWqKwEo2DOe3KSOrTm6EzPOuJ1WqoaxFE9lZCa78qt2lOxNCXJfFCAlbC0KptZB1obBOpYawh6Hu40ooFdJMpu0TQtz/EnfA+vfg/G7ldUAteOITqNe9fPsl7msSUAkh7gUSUJWOBFRlTL7RhBBCiPvfLU8RGOLLw4ZLNMz5Hd+keFSJOwrXd3AIqA2RnaBmweYbfBfuSAhBfgYc+xmOrVbCZFN2YZ3OB+r8Ber3grrdwSeo/PophBDlxW6HQz/A+nchO0Upq98LenwIAZHl2zdxX5KA6s46c+YMkZGRDBs2jPnz55d3d26oZs2agNJnISo6CahKRyZAF0IIIYS4TZ46DU3DKtE0rJKz7EZTBJ65ksuZK7msOZRc0Loe/l6NaVJlLI95n6SV/RA1s/ZhuHoYVdopSDsFe+YrTYPqKyOrarSD8DbKiCuZElCIOyMvXRkhdWQlnNrkGhgbQqB+T+XD18jOoJMPxoQQDziVCpr2h3pPKOtT7fqPMtL05Ebo+k9l2j+Nrrx7KcR966WXXiI6OpqAgAAuXLiAh4dMs+mQm5vLrFmz2LNnD3v37uX48ePY7XYSEhKcYVdFcfz4cd599102bdpETk4O9erVY+TIkYwcObJwjeMSSEpKYsqUKaxdu5bk5GSCgoLo0aMHH3zwAeHh4W7tbTYbM2fOZN68eRw9ehStVkvz5s0ZN24cTz/9dLHvsWvXLj788EO2b99OVlYWNWrUYNCgQfzrX//Cy0vWVRa3RkZQlTFJQoUQQghRVGmmCPQjhy6eJ+jufYJWtj8IzT+JimvaGYKVoCqsDYQ/AqHN5INzIUoj7yocXVMQSsWBzVxYF1QfGj2thFKhzWXqPiGEuJFLR2HteEjYorwOaQrPzIDQpuXbL3HfkBFUhbKysggNDSU3Nxe73c7ixYsZMGBAqa5xP4+gctwbQEREBFlZWaSlpVW4gOrIkSO0b9+evLw8+vfvT7Vq1Vi9ejWHDx9m9OjRfPXVVyW6zqlTp2jfvj2XLl2ie/fuNG3alBMnTvDjjz9SpUoVduzYQe3atZ3t7XY7zz//PMuXL6d27dr07NkTo9FIbGwsly5d4quvvmL06NEu77FixQoGDBiARqPh2WefJSQkhO3bt7Nr1y46dOjAxo0bJSQtICOoSkcCqjIm32hCCCGEuJl8s5XTl3M4cSmLEynZHE/J4uSlbM5cyaFobuVPNm3Vf/KI+k9aa0/SiAS0uE4jaNfoUYU2g2otoVoLqNYcguqBWnN3b0qIisyYBX+ugj+Wwel4sFkK66o0hMZ9oNEzULVhefVQCCHuTXY7HFgMP78D+emg1kKHv0Pn8fIAjbhtElAVmjt3Lq+88gpvvfUWX375JY8++ijr168v1TXu54AqOzubnTt30qpVKwICAnjiiSdYt25dhQuounTpwpYtW1izZg09e/YEwGQy8dhjj7F161Z27NhBu3btbnqdp556itWrVzNt2jTeeOMNZ/kPP/xA//796dGjBz///LOzfNmyZTz//PN06NCBDRs2OEc/paam0rp1a5KTkzl69Kjza5WXl0dERATp6enOrysoQdeYMWOYMWMGH330Ee+8886d+tLc0ySgKh15BFAIIYQQopx56jQ0qubHM82rM65HfWYPbc2mcV058sETrB3biWkDmzO6Wx3aNq7NiYCu/Ns6lKfzP6BxfhTPGSfykXkQ662tuGz3Q2U1QdLv8Ns3sHIkzGyL7cPq2KK6w9p/Kh8aXToKtuLXxxLivmU1w/H1sGw4/F9d5b+Pk78o4VTVxtDtf2DUbzDqV+j6joRTQghxK1QqaD4IRv+uBP02C2z9DL7pDBf2l3fvhLhvREVFodVqefvtt+nWrRsbN24kMTGx2LZWq5VPPvmEOnXq4OnpSZ06dfjoo4+w2WzFto+Li+Oll16ifv36GAwGDAYDrVu3Zvbs2cW2V6lUdO3alfPnzzN48GCCgoLw9fXlySef5PTp0wD8+eef9OnTh4CAAHx9fXnuuedISUkp9X2np6czYsQIQkJC8PT0pEWLFnz//fdu7QwGA48//jgBAQGlfo+75fjx42zZsoVu3bo5wykAvV7PlClTAJgzZ85Nr5Ofn8+6desIDg5mzJgxLnXPP/88zZs3Z926dc5/C4DY2FgAJkyY4DI1X1BQEG+++SZGo5Ho6Ghn+Y4dO7h8+TJ9+vRxhlOg/Nv/+9//BuA///kPMg5G3ApZg0oIIYQQooLy1GloGOpHw1A/l3KjRRlxdfJSNicvNebkpWziL2WTkJpNiC2ZlqoTNFEn8JA6gYdUCfhY8uDcLmUrYNF4YQx6CI8ardCGt4bqLSGglqxnJe4vdjuc3wsHl8AfyyE3tbAusA406Q8P9YOguuXXRyGEuB8ZqkL/b+HIj7D6H5B6DOY+Bo++B+3GyJSp4o6y2+3kWfLKuxs35aX1KtWaQtdz5MgRfv31V3r16kVwcDBDhw5l48aNREdHM3nyZLf2r776KvPmzSMyMpJRo0aRn5/P1KlT2bFjR7HX/+STTzh58iRt27alb9++pKen8/PPPzNixAiOHTvG559/7nbO1atX6dixIyEhIQwbNozjx4+zatUqjh49SmxsLJ06daJVq1a89NJL7Nmzh+XLl5OWlsamTZtKfN+OkUXZ2dkMGTKEnJwcli5dyuDBg0lNTXULZyq6+Ph4ALp37+5W17FjR3x8fNi8efNNr3PlyhUsFgsRERHFfn9FRkayf/9+4uLiqFWrFgDJycnOuuLaA2zatIn333//pu0rVapE5cqVSUxM5PTp0y5TCQpREhJQCSGEEELcYzy0xQdXFquNs2m5SnB1OZsll7L5KCUD6+UT1LGcLBJancHbmoc25XdI+R1+V87P1fiREdAEe/VW+Ndui09kGzBUKYc7FOI2ZSXD/kWw/zu4crKw3DsImjwHTfsr02BKICuEEGWr0dNQsyP8OAaOroINE5XRq33+A/7Vy7t34j6RZ8njkUWPlHc3bmrX4F1467xv+zpRUVEADBkyBIB+/frx+uuvEx0dzcSJE1EXCYDj4+OZN28ezZo1Y/v27fj4+ADKyJnmzZsXe/1Zs2a5BREWi4VevXoxbdo0xo4dS40aNVzqDx48yJtvvsnUqVOdZa+//jqzZs2iU6dOTJ48mbFjxwJKoPjUU0+xZs0a9u7dS8uWLUt03xcvXqRu3brs2LEDvV7vvI8WLVowfvx4+vXrR/Xqd/7nSnx8vDNMKomaNWvywgsv3LTdiRMnAKhb1/1BKY1GQ2RkJEeOHMFisaDVXv8j/MqVK6PRaEhMTMRut7uFVAkJCYAyYsshKCjIWdewYcNStb9WRkYGV69edZ4jAZUoLQmohBBCCCHuE1qNmlpVDNSqYqDoc3h2e2cuZuRz8lI2hy5lszIlg7yLR/G5coja5hM0U5+isSoRb2sm3pe3w+XtsH86AJc0wVz0a0pOtfZ41OlKZN3GBBhk8VtRAVktcGI97FsAx9eBvWAaS60XNHwKmg6AWl1BoyvXbgohxAPHOwAGLFR+Pq/9JyRsgVnt4ZmvoWHv8u6dEPcUs9nMggUL8PPzo0+fPoAynV3fvn1ZuHAhv/zyi8uInG+//RaAiRMnOsMpgOrVqzN27Fjee+89t/cobpSMVqtl5MiRbNiwgbi4OIYNG+ZSbzAYnFO9OQwaNIhZs2YRGBjosi6SSqVi4MCBrFmzhgMHDpQ4oAL48MMPneEUQFhYmPM+Fi9ezD/+8Y8SX6uk4uPjnSOJSqJLly4lCqgyMjIA8Pf3L7bez88Pm81GVlYWlStXvu51vL296dy5M3FxccycOZNRo0Y561asWMH+/fsBZXpEh549e7J48WI+/vhj/vKXvzjXSbpy5QpffvmlW/sOHTrg5+fHypUr2bdvHy1atHDWTZw40Xlc9BwhSkoCKiGEEEKI+5xKpaJaJS+qVfKicz3HiKjmwECuZBs5eSmbFSlXyUw8gEfKPoIy/qCe5Th1VBeoak2h6tUNcHUDHH6fJHsQu3RNuVy1A/oG3XmodgT1Q3zRaWSqHlFOrpyCfQuVEVPZyYXl4W2h5RBlDRQP3/LrnxBCCGXEasuhENEBlr8MF/bCkr9B+zHw6GTQyMdT4tZ5ab3YNXjXzRuWMy+t180b3URsbCyXL19m+PDhzlABYOjQoSxcuJCoqCiXgOrAgQMAdOrUye1axZUBZGVl8dlnn7Fy5UpOnTpFTk6OS/2FCxfczqlbty7e3q6jw0JDQwFo2rSp26geR11x17oerVZLu3btrnsf+/btK/G1SmPy5MnFTp1YkXzxxRd07NiR0aNH89NPP9G0aVNOnjxJbGwsTZs25eDBgy4j6wYPHsz8+fOJi4ujSZMmPPHEE5jNZlauXElwcDCAS3uDwcDUqVN5+eWXadeuHc899xwhISHs2LGDPXv20KBBA44ePepyjhAlJb8BCCGEEEI8wAINHgQaPHikViC0qwM8C0BWvpk/zieTfmIXmnPbqZL6G7WMfxKmSiXMsgkubMJ8/kN+3dCQj1WtuVTtMRo2aEiH2kE8VN0fjVqmThNlyGqB4z/Db7Mhocjc/N5B0HwQtBgKVeqVX/+EEEIUL7A2DF8PG9+HHV8p2/m98Nw88A0p796Je5RKpbojU+fdCxzT+w0dOtSl/NFHH6V69erExsaSlpZGQEAAoIzSUavVzinainIEEUWZTCa6du3K3r17adGiBUOGDCEwMBCtVsuZM2eIiYnBaDS6nefn5+dW5piW7kZ1ZrP5ZrfsFBQUVGwA4rgPx4ike4Vj5NT1+p2ZmYlKpcLX9+YPWjVr1ozff/+dSZMmERcXR1xcHHXq1OGbb74hPT2d8ePHU7VqVWd7rVbL2rVr+fjjj1m0aBGzZ8/G39+fvn37Mm7cOOrVq+fSHmD48OFUq1aNTz/9lNjYWKxWKw8//DAbN27kk08+4ejRo27nCFESElAJIYQQQgg3vp46mtYOh9rhwHNKoSmHnBPbufLHenzO/EJgXgKdNH/QiT+wXYxh5/lGxKzvxHZ9e1rUDuMvDaryaMOqBMqUgOJOyUmFvd/C7nmQca6gUAV1HlNGS9XrCVr9DS8hhBCinGl00P3fENYGVr4OidvhP53g+WhlvSohRLHOnTvH+vXrAWUauetZuHChc0o9f39/bDYbqampVKniurZsSkqK27mxsbHs3buX4cOHM3fuXJe6xYsXExMTc7u3cctSU1Ox2WxuIZXjPq43Vd7tKqs1qBxrTznWoirKarWSkJBAZGTkDdefKqpBgwYsWbLErdzRl9atW7uUe3h4MGnSJCZNmuRS7rjXa9uDMjVgz5493cqHDBmCWq0u1XSNQjhIQCWEEEIIIUpG74NP4+74NC6YNuTKKWx/rsb4x494Jf9OB81hOmgOk2OPZtWxdkQf6cE7RPBwzQC6Nw7hiYdCqF7p9qc2EQ+g83vhtznwx3KwFjy16xUArYZB65egUo0bny+EEKLiafQ0VG0ES4fApSPw7TPQ6/+Un+tCCDfz58/HZrPRsWNH6tev71ZvsViIiYkhKirKGVA1a9aMvXv3snXrVvr16+fSfuvWrW7XOHXqFADPPPOMW11x7e8mi8XCzp076dChg0u5o19F10W6k8pqDSpHyLh+/Xreeecdl7pt27aRk5NzwyCyJLKysvjpp58IDAzk8ccfL9E53333HQADBw4sUfvt27dz5swZevXqVWYhobi/SUAlhBBCCCFuTWBt1B3fwKvjG5B+Fg4swX7ge3zSTjFAG88AbTw7rY2Yl/gE/05oyZRVR2hbK4B+LcPo1SQUg4f8KipuwGaD42th+3Q492theWhzeGQENO4HOs/rni6EEOIeEFQHXt4IP70Bh36AVW/C5ePKCCtZl0oIJ7vdTnR0NCqVipiYGGrVqlVsu+PHj7Nz5052795N69atGTJkCNHR0XzwwQf06NEDHx8fAM6fP8+0adPczo+IiACUgKR3797O8s2bNzNnzpwyuLPSmTBhAhs2bECvV0bMJyUlMW3aNDw8PEocqJRWWa1BVb9+fTp37kxcXBxr1651jkwymUy89957ALz88ssu56SmppKamkpQUJDLtI15eXnodDqX0VZGo5Hhw4eTlpbGtGnTXNYsA2UKwWunX1y2bBnz5s3j4Ycfdgs0i2t/4cIFXn75ZbRaLVOmTLnFr4R40Mn/7YUQQgghxO2rVAO6jEfVeRyc2wW7voEjsbTTHKGd5ghntRF8kvcMa0634dfTaUyKPUzPh0IY0i6CFjUql3fvRUViMcLBJUowdaVgyhO1Dhr3hTavQlhrUMkaZ0IIcd/Qe0O/OVClAWyaArtmwZWT8FwUeMrT+EIAbNq0iYSEBLp06XLdcArgxRdfZOfOnURFRdG6dWu6devGiy++SHR0NE2aNKFv374YjUaWLFlC27ZtWbVqlcv5vXv3pmbNmnz66af88ccfPPTQQxw7doxVq1bRt29fli1bVta3el2hoaHk5OTQtGlTevfuTU5ODkuXLuXKlStMnz6d6tWru7QfN24cqampABw6dMhZZjAYACX86dixfKcVnTlzJh06dKBPnz4MGDCA0NBQVq9ezeHDhxk9ejTt27d3af/111/z/vvvM2nSJJfQbM+ePfTr14/HH3+c8PBwMjMzWb16NWfPnuWVV15hzJgxbu/9yCOPEB4eTsOGDfH09OS3334jPj6eWrVq8cMPP6DRaFzaT58+nYULF9KxY0eqVq3KuXPniI2NJTc3l6ioKJneT9wyCaiEEEIIIcSdo1JBjbbKlpEEv8+F3+dRw5jIDN10UivV4UtrfxamN2bFvvOs2HeeZmH+vNChJr2ahOKh1dz8PcT9KS9dWVtq138gu2BNBA9/aP0iPDIS/ELLtXtCCCHKkEoFncdBUF1YMQJOboCoHvDXH6BSeHn3TohyFxUVBXDTqeMGDBjA2LFj+f7775k6dSpeXl7MmTOHevXqMWfOHL7++mvCwsJ466236N+/v1tAZTAY2LRpE+PHj2fLli3Ex8fTuHFjvvvuO4KDg8s1oNLr9WzYsIF33nmHBQsWkJ6eToMGDfjqq68YNGiQW/tly5aRmJjoUrZ8+XLncdeuXcs9oGrcuDG7du3i3XffZfXq1eTk5FCvXj1mzJjBa6+9VuLr1KhRg65du7J161ZSUlLw9vamZcuWTJ06lWeffbbYcwYMGMCKFSv49ddfMZvNREZG8u677zJ+/Hi3kVIA7du3Z/Pmzfz0009cvXqVwMBAevXqxT//+c8ym15RPBhUdrvdXt6duJ8lJSURHh7OuXPnCAsLK+/uCCGEEELcfXnp8Oss+HUmGDMByAxtz2zvV5h91AuT1QZAFV8PXu1Ui7+2rYG3Xp6jemDkpinfG7/+B0xZSplfdWj7GrQcBp7ufyALIYS4j13YB98PgqyLyv8P/rYCqjYo716JCiA/P5+EhAQiIyPdpisTQoiKoqQ/qyQ3UKjLuwNCCCGEEOI+51UJuv0Lxh6Ajm+BxgO/izsYd3o4Bx5ex/90CyHYz4PLWUb+d82fdPwkjlnxp8gxWsq756Is5abBpn/Dl01hy/8p4VTVRtDnP/DGfmg/RsIpIYR4EFVrAS//AkH1IfM8zOsB534r714JIYQQogxIQCWEEEIIIe4O7wB4bBKM/g0aPQN2G177o3nl0CC2PZPPp882pUaAN2k5Jj75+SidPo3j251nMBeMsBL3ieKCqeCHoP8CGLkdmg8Crb68eymEEKI8+YfBSz9D2MOQnw4xT8Px9eXdKyGEEELcYRJQCSGEEEKIu6tyTej/LQz7CQLrQnYKuh/+Sv8zE9k0shGfP9+MyCAf0nJMTIw9zBNfbmHjnynIzNT3OHMebJ0K05q7BlMDFsKIrdDoaVDLnydCCCEKeAfA0Fio8zhY8uD7gXD4v+XdKyGEEOK+Nm3aNCIiIvD09KRjx44cOHCgTN9P/gIUQgghhBDlI7IzjNwGHd8ElQYOr0D7TQee9T/G+jc7M+WZxgT46Dl1OYfhMbsZOu83Eq/klHevRWnZrLB3AUxvCRvfB2OGazDVsLcEU0IIIYqn94FB30OT/mC3wrLhcGhZefdKCCGEuC8tWrSIf/7zn0yZMoU9e/ZQp04devToQWZmZpm9p/wlKIQQQgghyo/OEx6bDK9sVNYfyrkMC/uh2zSZIW2qEz++KyO71EavVbP1RCrdv9jCjLiTMu3fvcBuh2M/w6z28ONoyLoA/uHQ9xsJpoQQQpScRgd9/wPN/6qEVCtegYNLy7tXQgghxH3niy++YOTIkQwdOpTGjRszd+5cLBYLixYtKrP3lL8IhRBCCCFE+avWAl7ZBA+/rLzePg3m9cAvP5l3ejZg/d8706FOIEaLjf9bd4ynpm/jUFJG+fZZXN/l47CwH3w/AC4fBc9K0P3fMHo3NBsowZQQQojSUWvg6a+hxRCw2+C/I2D/9+XdKyGEEOKuWrhwISNGjKB169Z4eHigUqmYP3/+Dc/5/fff6dWrF5UqVcLHx4e2bduydKn7gx4mk4l9+/bx2GOPOcu0Wi1du3Zl586dd/pWnOQvQyGEEEIIUTHovODJz5Wp3zz94fwemN0VEndQM8iHhcMfYWr/ZgT46DmWkkXfmduZEXcSq03Wpqow8jNh/bswqx2c2gQaPbR/A8buh/ZjlBFzQgghxK1Qq6H3dGj1ghJSrXwN/lhR3r0SQggh7pp3332X2bNnk5iYSGho6E3bx8XF0aFDB7Zt20b//v0ZOXIkycnJDBgwgM8//9ylbWpqKlarleDgYJfyqlWrkpycfEfvoygJqIQQQgghRMXSsLeyNlVIU8hNhZje8HsUKpWKfi3D+OWtLvR8KASLzc7/rTvGwNk7OZeWW969frDZ7XBgCXzdGnZ8BTYL1HsCXv8Vuk8Br8rl3UMhhBD3A7UanvxCCamww4pX4cQv5d0rIYQQ4q6YO3cuZ86c4fLly4wcOfKGbS0WC6+88gpqtZotW7Ywe/ZsPv/8cw4cOEC9evWYMGECiYmJd6nn1ycBlRBCCCGEqHgq1YCX1sFDzyphx+q3YNVbYLMS4KNn5l9b8n/PNcVHr+H3M1d5cvpWNh1NKe9eP5jSTsO3T8N/X4XsFAioBYN/gMFLILB2efdOCCHE/UathienQuN+YDPDkr9BYtlNPSSEEEJUFI899hgRERElartp0yZOnTrF4MGDad68ubPc39+fCRMmYDKZiImJcZYHBQWh0WhISXH9u/rSpUuEhITckf4XRwIqIYQQQghRMem94dkoeGwyoILdUbBkCJhyUalUPN86nLVjO9M8vBKZ+RZemr+bLzYcxyZT/t0dVgtsnw4z20PCFtB6waMTlVFT9brf1a4YLVauZBtJvJLD0eRMDl/I4I/zGRxMSmf/uXSOXMgkITWHlMx8MvLMmK22u9o/IYQQd5haA32/gTqPgyUPFvWHiwfKu1dCCCFEqWVlZZGZmencjEbjHblufHw8AN27u/9t1qNHDwA2b97sLNPr9bRo0YKNGzc6yywWC/Hx8bRr1+6O9Kk42jK7shBCCCGEELdLpYKOb0JAbVj+MhxbDd8+o4zO8Q6gRqA3S0a05d+r/mTBr4lM23iCA0npTBvYAn8vXXn3/v518SD8OAYu7ldeR3aG3tOU0VN3WI7RQkJqDqcuZ5N4JZfkzHxSMvKVfaaRzDwzplsInPy9dAQZ9AQZPAjy9SDUz5OIIB8iAryJCPSmeiUvtBp5nk8IISosrR76fwsLn4WzO2Dhc/DKRmUUthBCCHGPaNSokcvrSZMmMXny5Nu+7okTJwCoW7euW11ISAgGg8HZxuHNN99k+PDhtGrVipYtW/LZZ5+h1WoZPHjwbffneiSgEkIIIYQQFV+jp8EnFr4fAEm/wbweMGQl+FfHQ6thSp+HaB5eiQn/PUT8scs8N2sH0S8+TFhl7/Lu+f3FZoVtX0D8R8rUi57+0P1/ocXflDDxNqVk5nPgXDoHkzI4eD6DkylZXMjIL/H5PnoNXnoNapWqYAOVSoXRYiPfbCXXZMExwC4jz0xGnplTl3OKvZZWraJOVQONqvnRKNSPRtX8aFzNX4JPIYSoSPTeMHgxRPeClD/gu/4wfJ3y/ychhBDiHnDkyBGqV6/ufO3h4XFHrpuRkQEoU/oVx8/Pz9nGYfDgwVy+fJkJEyaQkpJC69atWbduHX5+fnekT8WRgEoIIYQQQtwbItrBS+uVJ6VTj8P8XjDsJ+eT0s+2CqN+iC/DY37nxKVs+s7cQdSw1jQNq1S+/b5fXD0DK0bAuV+V1w2egic/B99bn4886WouO05eYfupVHadTiM5s/gwKsigp1aQgZpB3oT4exHi50mIvwdVfT2p7KPH4KHF4KFFo75xSGa32zFZbWTnW7iSYyI1y0hqjonLWUbOX83jbFoOiVdyOZuWi9Fi42hyFkeTs1jBeUDJ4OoH+9ImMoCHawbQJjKAYD/PW75/IYQQd4CnvzKyeu5jcPlPWDoU/roMNPJAgRBCiIrP19e3TAOg0ho7dixjx469a+8nAZUQQgghhLh3VG0AL/0MMb3hagJEPwkv/ASVawLwUHV/Vo7qwIvRv3M0OYsB3/zKzL+2pFuDquXb73uZ3Q4Hvoc1b4MpC/S+0OtTaDao1KOmrDY7u8+k8fPhZOKOXuLMlVyXerUK6gX70jTMn6ZhlWgY6kftKj5U8tbfkVtRqVR4aDV4GDQEGjyoF+xbbDubzc7FzHz+vJDJkYuZHLmQyeGLGZxLy3OGVt/uTASgQYgv3RpUpVv9qrSsUUmmBRRCiPLgH6aEVPN6wul4WPV3ePrrOzK6V4j71ZkzZ4iMjGTYsGHMnz+/vLtzQzVr1gSUPgshSsYxcuraUVIOmZmZVK5c+W52qVjy15MQQgghhLi3VAqHF9co61JlnFWm9blyylkd6u/FDyPb0bleFfLMVl75djdrDl0sxw7fw0w5sOJVWPmaEk6Ft4XXtkHzwSX+0M9ms7PjZCr/XHaQNv/7CwNm/0r09jOcuZKLRq2iRY1KjO5Wh0UvP8If7/fg57935tPnmvG3thG0iqh8x8Kp0lCrVVSv5MVjjYJ549G6/GdIK7a+/Rd++59HmTG4JS+0r0mjUD9UKjianMWs+FP0/2YnLads4M0l+9l0NAWTpfTrYgkhhLgNoc3g+WhQqWHfQtgxvbx7JESZeumll1CpVAQGBmI0Gsu7OxVKbm4un3/+OYMHD6ZBgwao1WpUKlWFDLiOHz9O//79CQoKwsvLi2bNmjFr1izsdnuprpOUlMSIESOoUaMGer2eatWq8eKLL3Lu3Lli29tsNr7++mtatmyJt7c3fn5+dO7cmR9//PG677Fr1y6eeeYZgoKC8PDwoG7dukycOJG8vLxi21+9epVx48ZRp04dPDw8qFKlCs899xyHDx8u1b2JW+NYe+radaYAkpOTyc7OLnZ9qrtNRlCVkRkzZjBjxgxMJlN5d0UIIYQQ4v7jV00JqWJ6K9P9ffuMMrLKPwwAX08dUcNa84+lB/jxwAVGL9rLZ883o1/LsHLu+D3k8nFYOgQuHwWVBrr9Czq+BWpNiU5PvJLD8j1JLN97nvPphX+0+nvpeLRhVbo3CqFDnUB8Pe+dKZiq+nryZNNQnmwaCsDVHBNbTlwm7ugl4o9fJj3XzH/3nee/+85TyVtHz4dCeLpZddrWCkAlT/ELIUTZq9cDnvgE1o6HXyZDSBOo/Zfy7pUQd1xWVhZLly5FpVKRlpbGypUrGTBgQHl3q8K4dOkS48aNAyAiIoLKlSuTlpZWzr1yd+TIEdq3b09eXh79+/enWrVqrF69mtdff50jR47w1Vdfleg6p06don379ly6dInu3bszYMAATpw4QUxMDGvWrGHHjh3Url3b2d5ut9O/f3+WL19O7dq1GT58OEajkdjYWJ555hm++uorRo8e7fIeK1asYMCAAWg0Gp599llCQkLYvn07U6ZMYdOmTWzcuNFl7aQrV67Qrl07Tpw4Qbt27XjmmWe4ePEiy5cvZ+3atWzatIlHHnnkznwhRbG6dOnCRx99xPr16xk4cKBL3bp165xtypvKXto4VpRKUlIS4eHhnDt3jrAw+UBECCGEEOKOyr4E0T3hykkIqgcvrgWfIGe11WZnwopDLNl9DpUKPuzbhEFtapRjh+8Rh5bBj2+AOQcMIcoT6RHtb3qa3W5n64lUorYlsPn4ZWe5r6eWp5pW46mmobSJDEB3H06DZ7XZ2Xf2KqsOXmT1oYtczip8krlWkA+D2tTg2VZhBPjc/RFhQgjxQLHb4cfRyigqr8rwarxzKmBxb8vPzychIYHIyEg8PR/sNSDnzp3LK6+8wltvvcWXX37Jo48+yvr160t1jft5ir/s7Gx27txJq1atCAgI4IknnmDdunUkJCQ4r1URdOnShS1btrBmzRp69uwJgMlk4rHHHmPr1q3s2LGDdu3a3fQ6Tz31FKtXr2batGm88cYbzvIffviB/v3706NHD37++Wdn+bJly3j++efp0KEDGzZswMvLC4DU1FRat25NcnIyR48edX6t8vLyiIiIID093fl1BeV3/zFjxjBjxgw++ugj3nnnHed7jB49mhkzZvDWW2/x+eefO8t37txJp06dqF+/PocOHUKtvv/+Lijpz6rbzQ0+/vhj/vWvfxEdHc0LL7zgVm+xWKhfvz7nz5/n119/pXnz5oAy5V+bNm04c+YMx44dK/f/Ju6/7wAhhBBCCPHgMFSFISvBL0wZSbXwWcjPdFZr1Co+6teEF9rXxG6HCf89xH/3JZVffys6mxXW/Q8sH66EUzU7wYgtNw2nTBYbi387S/cvtjB03m9sPn4ZlQo616vC9EEt+P1/HuOjfk3oUCfovgynQPlea10zgMlPN+bXfz3KopcfYUDrcHz0Gk6n5vC/a/6k7YcbGbt4H3+cL34eeCGEEHeASgW9PodqLSHvKiz5G5hyb36eEPeQqKgotFotb7/9Nt26dWPjxo0kJiYW29ZqtfLJJ59Qp04dPD09qVOnDh999BE2W/HTEcfFxfHSSy9Rv359DAYDBoOB1q1bM3v27GLbq1Qqunbtyvnz5xk8eDBBQUH4+vry5JNPcvr0aQD+/PNP+vTpQ0BAAL6+vjz33HOkpKSU+r7T09MZMWIEISEheHp60qJFC77//nu3dgaDgccff5yAgIBSv8fdcvz4cbZs2UK3bt2c4RSAXq9nypQpAMyZM+em18nPz2fdunUEBwczZswYl7rnn3+e5s2bs27dOue/BUBsbCwAEyZMcIZTAEFBQbz55psYjUaio6Od5Tt27ODy5cv06dPHGU6B8m//73//G4D//Oc/LtMSxsbGolaref/991361K5dO3r37s2RI0fYvHnzTe9PuJo7dy4vvPACL7zwAj/88INb2dy5c51ttVotc+fOxWaz0blzZ1599VX+8Y9/0KxZM44fP86HH35Y7uEUyBR/QgghhBDiXlcpHIauhHk94OJ++H4QDFkBWmWKCbVaxaTejbDb7cTsTGTcDwfx1mvp0TikXLtd4RizYPnLcLzg6cqOb0G3/wHN9f9kMFttLN+TxFebTjqn8fPRa+j/cDgvtK9JRKDP3eh5haNRq2hfJ4j2dYJ4r3cjfjpwgUW7znLofAax+y8Qu/8CHesE8WrnWnSqGyTT/wkhxJ2m84QBC+CbLpB8CH4aC/1ml3j9RHFvsdvt2K+zBk5FovLyuiP/zz9y5Ai//vorvXr1Ijg4mKFDh7Jx40aio6OZPHmyW/tXX32VefPmERkZyahRo8jPz2fq1Kns2LGj2Ot/8sknnDx5krZt29K3b1/S09P5+eefGTFiBMeOHXMZDeNw9epVOnbsSEhICMOGDeP48eOsWrWKo0ePEhsbS6dOnWjVqhUvvfQSe/bsYfny5aSlpbFp06YS37djZFF2djZDhgwhJyeHpUuXMnjwYFJTU93CmYouPj4egO7du7vVdezYER8fnxIFOFeuXMFisRAREVHs91dkZCT79+8nLi6OWrVqAcr6Q4664toDbNq0yRku3ah9pUqVqFy5MomJiZw+fdo5lWBycjJBQUEYDIYbvke3bt1ueo/3u27duqHT6Rg1ahSjRo26Ydtt27YRExPjUrZ9+3a2b9/ufP3yyy+7XHvbtm1MmjSJJUuWYDabadKkCZ988kmFmRZUAiohhBBCCHHvC6oLf1uhrEmVuE2Znq7vf5wfRKlUKib1bky20cryvUmMWbSPeS88TMe6QTe58AMi/SwsGgiXDoPWE/rMhIeevW5zm83Oyv3n+eKX45xLUz4QquLrwaudajGgTTh+99C6UmXN4KFlUJsaDGpTg4NJ6czblsBPBy+y7WQq206m0riaH289Xo+/NKgqQZUQQtxJ/mHQPwZinoZDSyGyE7QcWt69EmXAnpfHsZatbt6wnNXfuweVt/dtXycqKgqAIUOGANCvXz9ef/11oqOjmThxosuUafHx8cybN49mzZqxfft2fHyUh4cmTJjgnO7rWrNmzXILIiwWC7169WLatGmMHTuWGjVcp8w+ePAgb775JlOnTnWWvf7668yaNYtOnToxefJkxo4dCyiB4lNPPcWaNWvYu3cvLVu2LNF9X7x4kbp167Jjxw70er3zPlq0aMH48ePp168f1atXL9G1SiM+Pt4ZJpVEzZo1i51u7VonTpwAoG7dum51Go2GyMhIjhw5gsViQau9/kf4lStXRqPRkJiYiN1ud/t9MiEhAVBGbDkEBQU56xo2bFiq9tfKyMjg6tWrznMcAVVQUBCXLl0iOzvbLaQq7j0eZHFxcSWe4m/+/PmlnpKzTZs2rF279hZ6dnfcn/NrCCGEEEKIB0+15vD8fFBp4OBi2PKZS7VareKTZ5vQ86EQTFYbIxbs5s+LmcVe6oFyfi/M+YsSThmC4YU1Nwyn9p29Sr9ZO3hr6QHOpeURZPDgvacasfXtbrzSuZaEUzfQNKwSXw5swebxXXmxQ028dBoOX8hkeMxu+s3awY5TqeXdRSGEuL/U7AiPvqccr3kbLh0t3/4IcZvMZjMLFizAz8+PPn36AMp0dn379uXs2bP88ssvLu2//fZbACZOnOgMpwCqV6/uDIyuVdwoGa1Wy8iRI7FarcTFxbnVGwwG51RvDoMGDQIgMDDQZV0klUrFwIEDAThw4MDNbtnFhx9+6AynAMLCwhg7dixGo5HFixeX6lolFR8fz/vvv1/iraThQUaGMuWzv79/sfV+fn7YbDaysrJueB1vb286d+5MSkoKM2fOdKlbsWIF+/fvB5TpER0cUwp+/PHH5OfnO8uvXLnCl19+6da+Q4cO+Pn5sXLlSvbt2+fyHhMnTnQeX/seNpvNbYq/Xbt2sWrVKrf24sElI6iEEEIIIcT9o86j8ORnsOpNiPs3BNZyCVu0GjXTBrYgfd5v7Dx9heHzf2flqA5U9XtAF9o+HQ+L/wqmbAhpAoMWK0+cF+NKtpH/XfMnK/aeB5Sp/Eb9pQ4vto/ES6+5i52+94VV9mZS78a88Ze6zN56mujtCew7m87gObvoVDeI955qRL1g3/LuphBC3B/aj4XTm+F0HCx7EV7ZBDqvm58n7hkqLy/q791T3t24KZXX7X/fxcbGcvnyZYYPH46nZ+Hvr0OHDmXhwoVERUW5TBnnCIA6derkdq3iygCysrL47LPPWLlyJadOnSInJ8el/sKFC27n1K1bF+9rRoeFhoYC0LRpU7dRPY664q51PVqtlnbt2l33Pq4NTu6UyZMnFzt1YkXyxRdf0LFjR0aPHs1PP/1E06ZNOXnyJLGxsTRt2pSDBw+6jKwbPHgw8+fPJy4ujiZNmvDEE09gNptZuXIlwcHBAC7tDQYDU6dO5eWXX6Zdu3Y899xzhISEsGPHDvbs2UODBg04evSoyzkffPABP//8M5999hk7d+6kbdu2XLx4kWXLltGoUSO3PokHlwRUQgghhBDi/tL6JUg9Cb/OgP++BpVrQvXCaV/0WjX/+Vsr+s3azqnLOQyP2c2SEW3x1j9gvxofiVXWnLKaILILDPwOPNxDEbvdzqqDF5n042HSckwAPNcqjLd71H9wg707pLKPnn8+0YAX29dkRtxJFv12lq0nUuk5bSvD2tVk7GN18feSEWlCCHFb1Gpl/alZHeDSEVj3P/DU1JufJ+4ZKpXqjkyddy9wTO83dKjrdJWPPvoo1atXJzY2lrS0NAICAgBllI5arXZO0VaUI4goymQy0bVrV/bu3UuLFi0YMmQIgYGBaLVazpw5Q0xMDEaj0e08Pz8/tzLHtHQ3qjObzTe7ZaegoKBiAw3HfThGJN0rHCOnrtfvzMxMVCoVvr43f2ipWbNm/P7770yaNIm4uDji4uKoU6cO33zzDenp6YwfP56qVas622u1WtauXcvHH3/MokWLmD17Nv7+/vTt25dx48ZRr149l/YAw4cPp1q1anz66afExsZitVp5+OGH2bhxI5988glHjx51OScsLMzZp7Vr1/Lbb78RHh7OBx98QM2aNRk4cKDbe4gH0wP2V7gQQgghhHggdJ8Caafg+M+wdBi8uhl8Ap3V/t46ol9oQ5+Z2zl0PoN/LD3AzL+2fHDWANoTA6v+DnYbNHwanp0LWg+3ZpezjPzPfw+x/kgKAA1CfPmoXxNa1Kh8lzt8f6vq58n7zzzE8I61+PfqI6w/ksK87QnE7j/POz0b8FyrsAfne1MIIcqCoSr0+wYW9IXdUVD7L9DwqfLulRClcu7cOdavXw9Aly5drttu4cKFzin1/P39sdlspKamUqVKFZd2KSkpbufGxsayd+9ehg8fzty5c13qFi9eTExMzO3exi1LTU3FZrO5hVSO+7jeVHm3q6zWoHKsPeVYi6ooq9VKQkICkZGRN1x/qqgGDRqwZMkSt3JHX1q3bu1S7uHhwaRJk5g0aZJLueNer20PyrR9jukBixoyZAhqtdptPbHq1au7fR8BzhFpxb2HePBIQCWEEEIIIe4/ao3ytPTsrpB2Gla8DH9dppQXqBHozZyhrRg0exdr/0hmztbTvNq5dvn1+W7ZHa2EUwCtXoAnp7p8XRy2n0xl7OL9pGYb0apVjP5LHV7vWge9VqbiKCs1Ar2ZPbQ1W45f5v2fDnPqcg7jlx3kp4MX+ahfE6pXkimphBDiltX+C3QYC9unwU9joUZb8HEfVSJERTV//nxsNhsdO3akfv36bvUWi4WYmBiioqKcAVWzZs3Yu3cvW7dupV+/fi7tt27d6naNU6dOAfDMM8+41RXX/m6yWCzs3LmTDh06uJQ7+tWiRYsyeV/HGlQl1aVLlxIFVI6Qcf369bzzzjsuddu2bSMnJ+eGQWRJZGVl8dNPPxEYGMjjjz9eonO+++47AOc6YTezfft2zpw5Q69evUoUElqtVhYvXoxWq+XZZ6+/7q14cMhfl0IIIYQQ4v7k6Q/9F4DWC05tgs2fuDVpFRHAxN6NAPjk52PsOn3lbvfy7tr7bWE41XYUPPWlWzhlsdqYuv4Yf4vaRWq2kfrBvvw0piN/f6yehFN3Sed6VVg7tjP/fKIBeq2aLccv0+OLLXy3KxG73V7e3RNCiHtXt/+Bqo0hN1VZr1J+pop7hN1uJzo6GpVKRUxMDHPnznXb5s+fT7t27Th48CC7d+8GlJEtoKwHVHQtqfPnzzNt2jS394mIiACUgKSozZs3M2fOnLK6vRKbMGECJpPJ+TopKYlp06bh4eFR4kCltCZPnozdbi/xVtLRVvXr16dz587ExcWxdu1aZ7nJZOK9994D4OWXX3Y5JzU1laNHj5KamupSnpeXh8VicSkzGo0MHz6ctLQ0Jk6c6LJmGShTCF5r2bJlzJs3j4cfftgt0Cyu/YULF3j55ZfRarVMmTLFpc5sNpOXl+dSZrPZGDduHMeOHWPMmDFUq1bN7ZriwSMjqIQQQgghxP0r5CHoPQ3++6oSUIU/AnUedWny10dqsCfxKv/dd57R3+9j9ZiO9+faSnsXwI/K07Q88hr0+F+4Ztq4jFwzoxbtZdtJ5Y/eQW3CmfhUY7z07iOsRNnSa9W81rU23RsH8/ayg+xJvMr//PcPNhxJ4fPnmxFocJ+SUQghxE1oPaDvLJjzF/jzR/hjOTR5rrx7JcRNbdq0iYSEBLp06UKtWrWu2+7FF19k586dREVF0bp1a7p168aLL75IdHQ0TZo0oW/fvhiNRpYsWULbtm1ZtWqVy/m9e/emZs2afPrpp/zxxx889NBDHDt2jFWrVtG3b1+WLVtW1rd6XaGhoeTk5NC0aVN69+5NTk4OS5cu5cqVK0yfPp3q1au7tB83bpwzyDl06JCzzGAwAEr407Fjx7t7E9eYOXMmHTp0oE+fPgwYMIDQ0FBWr17N4cOHGT16NO3bt3dp//XXX/P+++8zadIk5zR5AHv27KFfv348/vjjhIeHk5mZyerVqzl79iyvvPIKY8aMcXvvRx55hPDwcBo2bIinpye//fYb8fHx1KpVix9++AGNxvX3/+nTp7Nw4UI6duxI1apVOXfuHLGxseTm5hIVFeU2vV9KSgqNGzeme/fuREZGYjKZWLduHUePHuXJJ5/ko48+unNfSHFPk0cghRBCCCHE/a3ZAGj9knK88jXIcR0lpVKp+N++D1E/2JfLWUbeXLofm+0+e6L68H/hxzGAHdqMgCc+cgunTl/Opu/M7Ww7mYq3XsO0gc35qF9TCafKWe0qBpaOaMd7TzXCQ6sm/thlek7byo5TqTc/WQghhLvQZtB5vHK8+h+QlVy+/RGiBKKiogBuOnXcgAED8PLy4vvvv3eOXpkzZw4fffQRKpWKr7/+mrVr1/LWW2/x5Zdfup1vMBjYtGkTzz77LL///jtff/01Fy5c4LvvvmPUqFF3+rZKRa/Xs2HDBrp06cKCBQuYN28eYWFhLFq0qNgAZtmyZcTExBATE8OFCxcAWL58ubPs5MmTd/sW3DRu3Jhdu3bx9NNPs3r1aqZNm4ZarWbGjBlMnz69xNepUaMGXbt2ZevWrXzxxRd8//331KlTh2XLljF79uxi1zIdMGAAycnJREdHM336dFJSUnj33XfZt2+fcyRdUe3btyc8PJyffvqJzz77jF9++YVevXrx+++/M2zYMLf2/v7+PPPMM+zdu5evvvqKefPmUblyZebMmcOPP/6Ih4c8bOXQrVs3GjVqxIwZM8q7K+VCZZc5IspUUlIS4eHhnDt3jrCwsPLujhBCCCHEg8mcB990gdRj0OApGLDQLaA5dTmbJ6dvJd9s472nGjG8Y2Q5dfYOS9gKC/uB1aQEdU9Odbv3naeuMGLBbjLzLVTz92TusIdpVM2vnDosrudociajF+3j5KVsVCoY060OYx+rh0bt/qGDEEKIG7CaYe6jcPEA1H8SBi0q7x6JEsjPzychIYHIyEi36cqEEKKiKOnPKskNFDKCSgghhBBC3P90XvDsXFDr4OgqZS2ma9SuYuB/nnSsR3WUY8lZd7uXd17yH7B4sBJONewNvT5zC6d+OZLCsOjfyMy30LJGJWJHd5RwqoJqEOLHj6M7MPDhcOx2mL7pJK98u5vMfHN5d00IIe4tGh30+Y/ye8Gx1fDnqpufI4QQQog7TgIqIYQQQgjxYAhtCo8qCw7z8zuQdtqtyd8eqUG3+lUwWWyMXbwPo8V6lzt5B2UkwcJnwZgJER2g31xQu07XF7v/PCMW7sFksfF4o2AWvdKWKr4y3UZF5q3X8vGzTflyQHM8tGo2Hb1E3xnbSUjNufnJQgghCgU3gg4FazOuGQ/G++DBFCGEEOIeIwGVEEIIIYR4cLQbAzU7gTkXfvo7XDPbtUql4pPnmhLgo+dochb/iXcPse4Jplxl5FR2MlRtBAMXgc51eokfdp/j70v2Y7XZ6duiOjP/2hJPnaw3da/o06I6P4xsR4ifJ6cu5/DM19vYcVLWpRJCiFLpPB4q14SsC7Dpf8u7N0IIIcQDRwIqIYQQQgjx4FCrofc00HpCwmbY/51bk6q+nkzqrUz1NyPuJKcuZ9/tXt4eux1+HKOsq+EdCIOXgFcllyarD17kn8sPYrfD39rW4PPnm6HTyJ8G95qmYZX4cUwHWtaoRGa+hWHRv7Hq4IXy7pYQQtw7dF7K2owAv30DF/aVb3+EEEKIB4z8FSqEEEIIIR4sgbWh2wTleN0EyEpxa/J0s2p0qVcFk9XGhBWHsF8z0qpC2z4N/lgGai30/xYq1XCp3nQ0hbGL92Gzw8CHw5nyzEOo1arrXExUdFV9Pfn+1bb0ahKC2WpnzPf7mL89oby7JYQQ9446j0KT58FuU0ZX22zl3SMhhBDigSEBlRBCCCGEePC0HQWhzSE/A9a+7VatUqn4d5+H8NJp2JWQxg+7k+5+H2/F6c3wy2TluOcnULOjS/W+s1d5beFeLDY7zzSvxv/2bYJKJeHUvc5Dq+GrQS0Z2i4Cux0m/3SEqRuO31vBqhBClKceH4KHH1zcDwcWlXdvhBBCiAeGBFRCCCGEEOLBo9HC01+BSgNHVirBzjXCA7x58/G6AHz881Ey8sx3uZOllH0ZVrwC2KHF36D1cJfqc2m5vPLtbowWG39pUJXPnm+GRkZO3Tc0ahXvP92YfzxeD4DpG09ISCWEECVlqApdCh5Y+eV9yM8s3/4IIYQQDwgJqIQQQgghxIMptCk8XBDi/PwOWC1uTV7sEEntKj6k5ZiYEXfyLnewFGw2WPkaZKdAlQbQ8/+gyMiorHwzL8fsJjXbRKNQP74a1ELWnLoPqVQqxjxal3efbAjAV5tOSkglhBAl1WYEBNSGnEuw9fPy7o0QQgjxQJC/SoUQQgghxIOr67/AKwAuHYHd89yqdRo17z7ZCIDo7QkkXsm52z0smV9nwMkNoPWE56JB7+2sstvtvLX0AMdSsqjq60HUC63x8dCWY2dFWXu5Uy2XkOqLX06Uc4+EEOIeoNX/P3v3HR5VnbZx/DuTmfQGCZBQQ+gBBCkCigiioKBYWUAFBdeyYlmVdVcWddXdRVcEV9G1UJWmr4qoWJCmdEGKdEILECAVEtLLzPvHmQkZUkggyaTcn+ua6wxzzpzzBGPIzD3P84PB/zLub3wPkg+7tx4REakTBgwYQFRUFO+++667S3ELBVQiIiIiUnf51ofrJxn3V/0T0pOKHNK/XQOubRNKbr6dyd/tq+ICyyB+L6x4xbh/02vQKMpl98y1R/hpTxyeHmZm3N+D8CAfNxQpVa1wSPX2imjmrj/q3oJERGqCtjdB5ADIz4FlL7i7GhERqQNWrVrFnj17GD9+vLtLcQsFVCIiIiJSt3V/ABp1hqwU+OWNIrtNJhOThkZhNsEPu0+z9diZqq+xJPl58NVjxhtpbW8yvpZCfos5w2vfG6HaC7dGcUXT4KqvUdzmj9dGFqxJ9Y9vdvPt7yfdXJGISDVnMsHgf4PJDPu+heO/ursiERGRWk0BlYiIiIjUbWYPGOToQNoyE84eL3JIu7AA7urWFIBpPx2oyupKt2E6nNwKXkFwyzSXdadSMnN5YsFW8mx2bu3SmPt6NXdjoeIuj1/fmjF9WmC3wzOf7mD9wUR3lyQiUr01ioKu9xj3l78MWsdPRESk0iigEhERERGJHAAR1xqdSD+/XuwhTw5sg8VsYk10IpuPJldxgcVIjIZV/zbu3zQZAhu77P7nt3s4mZJFixBfJt/ZGVOh8ErqDpPJxEu3dmRI5zBy8m08Ou83DiekubssEZHq7bq/gYcnxKyFQyvcXY1IpTh69Cgmk4kHHnjA3aW4TUREBBEREZVy7v79++v3b5EyUEAlIiIiImIywcCXjPvbFxjhzwWa1fdleI9mAExd5uYuKrsdvpsA+dnQ+obzn/R2WLE3jv/77QQmE7w5vAv+XhY3FSrVgYfZxLQRXeneoh6pWXn88eMtpGblurssEZHqK7gZ9HzIuL/iFbDZ3FuPSDHGjRuHyWQiJCSE7Oxst9XhDGJMJhPffvtticf16tWr4LjVq1eXeNwrr7yCyWTCarVy+vTpEo974IEHCs43ffr0Eo8bMWJEwXFz5swpy5fkFnFxcTz++OP06tWLRo0a4eXlRdOmTRk4cCBffvkl9mrUzXnq1CkefPBBwsPD8fb2pl27dvzrX/8iN7d8v1+eOXOGCRMm0Lp1a7y8vGjQoAF33303u3fvLvE5CxYs4JprrsHf3x8/Pz969uxZ6n/XvXv3cu+99xIWFoaXlxctWrTgqaeeIjm5GnzgUAAFVCIiIiIihmY9od0QsOef70y6wOPXt8bTw8yGw0lsOpxUxQUWsvdrOLwaPLxgyBTX0X4ZuTz/5U4A/ti3JT0i6rupSKlOvCwe/O++boQHeXM4IZ0nF24j31Z93ugQEal2rn0GPP3h1A7Y85W7qxFxce7cOT777DNMJhPJycl89dVX7i4Ji8XCrFmzit23e/dufv31VyyW0j80ZbfbmT17NiaTiby8PObOnXtZ101OTmbJkiUlXnfFihWsWFE9uiSPHz/Oxx9/TFBQEHfccQfPPvssN910E7t27eKuu+7i4YcfdneJAJw+fZpevXoxe/Zsrr76av785z9Tv359Jk2axN13313mIC0pKYlevXrx5ptv0rBhQx5//HFuvPFGvvnmG6666io2bdpU5DnPPvss9957L4cPH+bee+9l7NixJCUlMXbsWCZMmFDk+I0bN9KzZ08WLVrE1VdfzZNPPknbtm15++236dOnD0lJbnw9JwUUUImIiIiIOA34u7HdvRgSDxbZ3STYh7t7GGtRfbTmcFVWdl5OBvww0bh/zVNQv6XL7inL9hN/LpvIBn48O6idGwqU6qphgDcfjemBt9XM6v0JTFm2390liYhUX36hcPUTxv1V/wZbvnvrESnk008/JT09naeffhqz2czMmTPdXRI333wz3377LQkJCUX2zZw5E7PZzODBg0s9x4oVKzh69CgPPfQQgYGBJQZPF15327Zt7Nixo8i+efPmkZ2dzZAhQ4p9bqtWrWjVqtVFr1EVunTpwpkzZ1i2bBnvv/8+//73v5kxYwYHDx6kQ4cOzJgxo9TOoqry17/+lePHj/Pee+/xxRdf8Nprr7F+/XpGjhzJ119/zaJFi8p0npdeeono6GieeeYZ1q9fz5tvvsmCBQtYvXo12dnZjBs3Dluh7tUtW7YwdepUWrduze7du/nggw+YPn06O3fupGfPnrz55pts2LDB5RoPPfQQ6enpLF68mC+//JI33niDn376if/85z8cOHCAv//97xX6dyOXRgGViIiIiIhTWCdoezNgh/X/LfaQP/ZtickEy/fGczD+XNXWB7B2GqSegKBm0Pdpl127YlOYvykGgH/e3glvq0fV1yfVWqcmQfzn7i4A/G/1IVbvj3dzRSIi1Vif8eAdDEnR6qKSamXmzJlYLBaee+45BgwYwIoVK4iJiSn22Pz8fF5//XVat26Nt7c3rVu3ZvLkyS5v/l8oPj6ep59+umD0WmhoKHfddRe7du0q8Tnjxo0jNzeXTz75xOXx3Nxc5s2bx6BBg2jatOlFvy6Ahx9+mOHDh3PgwAHWrFlT6nPuv/9+PDw8ig3pZs+eTYcOHejTp0+xz71wDaqkpCSaNm1KQEAABw+6flittH2lycrK4m9/+xvNmzfH29ubDh068M477xTpNLJarXh4FP3dPSAggJtuugmgXNetDOfOnePTTz8lMjKSRx55pOBxk8nEa6+9BsBHH31UpnMtWbIEs9nMyy+/7PJ4nz59uPXWW9mzZw8///yzy/EATz/9NPXrn58Q4efnVxA0vf/++wWPHzp0iF27dtGzZ0+GDRvmco1nn32WkJAQPvnkE9LT08tUr1QeBVQiIiIiIoU5Q5/tCyH1ZJHdkQ38ubFDIwBmrDlSlZXBudOwwTFjf9A/wdO3YJfNZufFJbuw2eGWK8K5ulVo1dYmNcawLo0Z3bsFAM98toO41Cw3VyQiUk15BUDvx4z7v0zRWlRSLezZs4eNGzcyaNAgGjVqxJgxY7DZbMyePbvY4x9++GH+9re/YbPZGD9+PIMHD2bq1Kk89dRTxR5/6NAhunfvzltvvUWrVq144oknGDJkCD/88AO9e/cudvQaQO/evYmKiipSxzfffENCQgLjxo0r9etKTk5m8eLFREVF0b17d8aMGQNw0e6wJk2aMGjQIBYsWEBOTk7B41u3bmX79u2MHTu21OcXFhISwscff0xGRgb33HOPy5pKDz74ILGxsUyfPp3WrVuX+Zx/+MMfmD9/PnfeeSePPvooaWlpPPnkk8WOpCtOVlYWK1euxGQy0bFjxzJftzJs2LCB7OxsbrzxRkyFRowDtGjRgnbt2rFu3Try8y/ecXr69GlCQ0Px9/cvsq9lS2NCxMqVK12OL7zvco43m800b96cjIwMNm7ceNFapXJptWQRERERkcKa94IW10DMOtjwLgz+V5FDHrkukmV74vhyayzPDGpLwwDvqqntlzcgNwOa9ICo21x2LdkRy9ZjZ/H19ODvQztUTT1SY/19aAe2xJxh76lUnlq0jfl/7I2H2XTxJ4qI1DW9Hob170D8Htj/HXS4xd0VyQXsdjuZudV/BKOP1aPIm/qXwhnYjB49GoA777yTxx57jNmzZ/Piiy9iNp/vR1i9ejWzZs2iS5curFu3Dj8/PwAmTpxI165diz3/mDFjOHXqFD/88IPLSL5JkybRo0cPHnroIX7//fdinztu3DgmTJjA5s2b6dmzZ0G9ISEh3HbbbaWu9zR//nyys7MLvq5rr72WiIgI/u///o+3336bwMDAEp/74IMP8v3337NkyRKGDx9ecF2LxcKYMWNKDO+Kc/311/Pcc8/x2muvMWnSJF5//XXee+89lixZwqhRo7j//vvLfC6AAwcOsGvXLoKCggB4+eWX6dWrF9OmTWPUqFH06NHD5fj4+Hjee+89bDYb8fHxfPfddxw/fpyXXnqpzMHY9u3by7UuWXBwMH/+858velx0dDQAbdq0KXZ/mzZt2L9/PzExMURGRpZ6rtDQUOLj40lLSysSUh05YnwI8MCBAy7HF95X3PEnTpwgIyMDX1/fUo+32WwcO3as4BoDBw4stVapXAqoREREREQu1PdpI6DaMhuuew68g1x2d29Rn27Ng9l67CwLNx3nqRuKf5FWoZIPw29zjPs3/AMKvcGRnZfPm8uMF3DjB7QmPMin8uuRGs3b6sG791zJLe+sZePhZN7/+RDjB5T908AiInWGTz0jpFrzpvFBkfZDXf4NFvfLzM0n6sUf3V3GRe15ZTC+npf3VqxzhF5gYCC33347AP7+/txxxx3MmzeP5cuXM2jQoILjP/74YwBefPHFgnAKjK6jp556ihdeeMHl/Nu2bWP9+vWMGzeuyHpRbdu25aGHHmLq1Kns2rWLTp06Falv9OjRPP/888yaNYuePXty8uRJfvzxRx5//HE8PT1L/dqc61Tdd999gDE27r777uOf//wnixYt4uGHHy7xucOGDSM0NJRZs2YxfPhwsrKyWLhwIUOHDqVRo0alXrc4r7zyCitWrGDKlCk0bdqUv/71r0RERLiMkCurF154oSCcAggKCmLSpEmMHj2auXPnFhtQFR57Z7VaeeONN3j22WfLfM3t27cXGZ1XmhYtWpQpoEpJSQFw+XoKc4aIzuNKc/PNNzN79mxefvll3njjjYLHN23axLfffgvA2bNnXY5/7bXXeOutt7jnnnsIDg4GICMjg8mTJ7vU6OvrS9u2bYmMjGTz5s0sXbqUoUOHFhzz1ltvkZSUVOQa4h4a8SciIiIicqHWN0CDDpCbboz6K8b9V0cAsGjzMfLyq2Dkz6p/gy0PWg2Elte67Fq46RgnzmTSMMCLcdcUHWMhUpzIBv68PMwYFfPf5dHsP+2GNdVERGqC3o+B1RdObYeDJXeAiFS2JUuWkJCQwPDhw/H2Pt/BX9I4vB07dgBGN9KFinvMOe4sLi6Of/zjH0Vu+/btAyjYXqhhw4YMHTqURYsWkZWVxdy5c8nPz7/oeL8tW7awY8cOBgwY4LJOVVnH/FmtVu677z6WLVtGbGwsixcv5syZMxe9bmnnW7hwIb6+vjz55JPk5OQwf/78Uru4SlLa3/22bduK7OvUqRN2u528vDyOHDnCyy+/zN///nfuuusu8vLyynTNBx54ALvdXubb0aNHy/11Xa5XXnmF8PBwpkyZQt++fZkwYQL33nsv/fr1IyoqCsClG7Bfv36MHj2a6OhooqKiePTRR3niiSfo3Lkzp06dKgjNnM8xmUy89957WK1Whg0bxl133cVzzz3H4MGDefbZZ+ncuXORa7jLgAEDiIqK4t1333V3KW6hDioRERERkQuZTHDVH2Hps7D5I7jqYbjgxctNncKo52vlVEoWq/cncENU+T+dWWZJh2Dn58b9gS+67ErPzmP6KmPB5CcHtsHHs+jiyiIlubt7U77fdZqV++L5y+c7+PJPV2PxcP8LdRGRasUvFHqMM9aBXDsV2tzg7oqkEB+rB3teGXzxA93Mx3r5v6M5gxpncOM0cOBAmjRpwpIlS0hOTqZ+/fqA0U1iNpsLxp0VVlxnUXJyMgBLly5l6dKlJdaRnp5e4r5x48bx1Vdf8cUXXzB79my6d+/OFVdccUlfV5s2bejduzcbN25k9+7dpa7BNG7cON566y3mzJnD6tWrCQsLY8iQIaVetzSRkZEFoxG7d+/O1VdffUnnKe7v2flYaZ1GHh4eRERE8Pzzz2OxWHjuuef46KOP+NOf/nRJdVQEZwhUUt2pqakux5WmadOmbN68mZdeeonvv/+eX3/9lWbNmvHKK68QERHByJEjadiwoctz5syZQ48ePZg5cyZz5szBx8eHwYMH85///IeOHTtisVgKvvcBBg8ezJo1a3j11VdZuXIlS5cupVOnTixevJgVK1awc+fOItdwh1WrVrkEs3WNAioRERERkeJcMRKWvwxJB+HwKmjtOpvcy+LB3d2b8tGaIyz49VjlBlTr3gLs0PYmaNzVZdfcDUdJTMuhRYgvI3o2q7wapFYymUxMvrMzN079md9PpPDBL4c16k9EpDh9xsOm940RwLFboUk3d1ckDiaT6bJH59UEx48fZ9myZQBcd911JR43b948nnzyScAICmw2G4mJiTRo0MDluLi4uCLPdXYIvfPOOzz++OOXVOeQIUMIDw/nr3/9K7Gxsbz33nulHp+ZmcnChcbEgvvvv7/ENZ5mzpzJ1KlTSzxP586d6dmzJ++++y5xcXFMmDABi+XSvy+mTp3KunXrCAkJ4ddff+W9997jscceK/d54uLiaN68eZHHoGxBDsCgQYN47rnnWL16dZkCqspag8q59pRzLaoLRUdH4+npWeTrLUmTJk2YMWNGkcf/8Y9/ABQZf2g2m3nyyScLvr+djh49SlpaGt26dcNqtbrs69WrV8HIwMLeeuutYq8hVa/2//QWEREREbkUXv7Q9R7jzahfPyoSUAGMuqo5H605wqr98Zw4k0HTer4VX0fqyfNjBvs+7bIrKzefWWuPAvDk9W2wqvNFLkGjQG9eurUjz/7fDv67PJqbO4UR2cD/4k8UEalLAhtDxzth52ew8T24q+ibqiKVac6cOdhsNvr27Uu7du2K7M/Ly2Pu3LnMnDmz4A38Ll26sHXrVtasWcOdd97pcvyaNWuKnKNXr14AbNiw4ZIDKg8PD8aMGcPrr7+Ot7c3o0aNKvX4zz//nJSUFLp27Ur37t2LPWb+/Pl88sknvPbaa6WuZTVu3LiCAOdSx/uBMXpv4sSJtGvXjlWrVnHNNdcwYcIErrvuulK7uIqzZs0a7r333iKPAVx55ZVlOsfJkycBioQvJamsNah69+6Np6cnP/30E3a7HVOh9fhiYmLYv38/AwYMuKxgMD8/n0WLFmGxWLjrrrvK9Jz58+cDMHLkyDIdHxMTw9q1a4mKiioY9Sfuo1ewIiIiIiIl6flHY3vgB0g5UWR3ZAN/rm4Vgt0OX26NrZwaNrwLtlxofjU07+2y68utsSSmZdM4yJthXRtXzvWlTrizWxP6tW1ATr6Nl77ejd1ud3dJIiLVTx9H98TuxZBSSf/uixTDbrcze/ZsTCYTc+fOZcaMGUVuc+bMoU+fPvz+++9s2bIFgNGjRwPGej+Fx/LFxsby3//+t8h1rrrqKnr16sXChQv59NNPi+y32Wz8/PPPF633mWeeYfHixfz4448EBweXeqxzvN/UqVOL/bpmzJjBHXfcQWJiIl9//XWp57rvvvtYvHgx33//fbEhXlmkp6cXhGoLFy4kPDycBQsWkJuby6hRo8jKyirX+V599VWXkXgpKSn885//xGQyuXSL7dixg9zc3CLPT05OZuLEiQBlHllYWWtQBQYGMnLkSA4fPswHH3xQ8Ljdbuf5558H4KGHHnJ5TkpKCvv27ePUqVMuj+fm5pKZmenymM1mY8KECezfv58nnniCxo1dX984RwgWtmbNGiZPnkyLFi149NFHXfalpaUV+Z02JSWF0aNHk5+fz+TJk8v0dUvlUgeViIiIiEhJQttAi2uMcT6/fwrXPlvkkDuubML6Q0l8tT2WJ65v7fJJwsuWlQq/zTHuX9A9lW+z8+EvhwB48NpIdU/JZTGZTLw8rCODp/3CmuhEftx9mps6hbu7LBGR6qXxldCiL8SshV8/hBvL3qEgcjlWrlzJkSNHuO6664iMjCzxuLFjx7JhwwZmzpxJjx49GDBgAGPHjmX27Nl07tyZO+64g+zsbD799FN69+5d7OizhQsXMmDAAEaOHMlbb71Ft27d8PHx4dixY2zYsIGEhISLhjQNGzbk9ttvv+jXdfDgQX755RciIiLo379/qV/XwoULmTlzJnfffXeJx/n7+5fpuqV56qmn2L9/P1OmTCnocOrduzcvvfQSL7zwAn/5y1945513yny+tm3b0qlTp4JuoC+++IITJ07wzDPPuIyXmzZtGt9++y3XXHMNzZs3x8fHh5iYGJYuXUp6ejrDhw+/aDdaVXjttddYtWoVjz32GMuXL6d169b8/PPPbNy4kVtvvbVIF9PixYsZO3Ys999/P3PmzCl4PC4ujo4dOzJo0CBatmxJTk4OP/74I/v27WPo0KHFhkd33303mZmZXHHFFQQGBrJz506+//576tevz1dffUVAQIDL8V999RUTJ07k+uuvp3HjxsTHx/P111+TkJDAq6++yrBhwyrl70jKR69iRURERERK0/UeY7t9ARTTVXJTpzC8LGYOJ6Sz+2TRT/Vdlh2LICcNQttCmxtddv205zRHkzII8rEyUmtPSQVoGerHI9cZb3q98s0eMnLy3FyRiEg11Ge8sf1tNmSnubcWqTOcXUYPPPBAqceNGDECHx8fFi5cWNCd8tFHHzF58mRMJhPTp0/n+++/55lnnilYg+dCLVu2ZNu2bUyaNIm0tDRmz57NBx98wPbt2+nXr1/BelEVYdasWdjtdu6///5SP+Q1cOBAmjVrxrJlyzh+/HiFXf9CX3zxBTNnzuTGG2/kmWeecdk3ceJE+vXrx/Tp04sN9kry2Wefcc899/Dll1/yv//9Dz8/P95++22mTJnictzo0aMZMmQI+/btY+7cuUybNo2ff/6Zfv36sWjRIj777LOK/SDcJQoPD2fTpk2MHTuWtWvXMm3aNJKSknj11Vf5/PPPy1xjUFAQt912G1u3buWdd95h1qxZ1KtXj48++oivv/4aLy+vIs+5/fbbyc7OZv78+UydOpV9+/bxxBNPsGvXLrp27Vrk+M6dO9OlSxeWLVvGlClTWLJkCb169WLlypVMmjTpcv8qpIKY7JrdUKlOnDhBs2bNOH78OE2bNnV3OSIiIiJSXtnnYEpbyM2AB5dDs55FDhk/fytLd57ij31bMumWqIq5rt0O7/aCxP1w83+g1yMuu++dsZF1B5N4rH8rnrupfcVcU+q8zJx8bpj6M7FnM3ny+tY8M+jSxuOIiNRaNhtM7w7Jh2HIFLjqoYs/RypEVlYWR44coWXLlnh7e7u7HBGRYpX1Z5VyA4M6qERERERESuMVAB0c4x+2zy/2kNsc6z99veMk+bYK+vzX0TVGOGX1gy6uozIOJ6Sx7mASJhPc06t5xVxPBPDx9GDS0A4AfLTmCPGp5VtnQUSk1jOb4aqHjftbZhfbXS0iIiJlo4BKRERERORiujrmve/6EvJyiuzu364hQT5W4s9l8+uR5Iq55q8fGdsuI8A7yGXXwl+PATCgXUOa1vOtmOuJONzUKYwrmweTmZvPf1dEu7scEZHqp8tIsHhD/G44sdnd1YiIiNRYCqhERERERC4m4lrwbwTZKXDk5yK7PS1mbujQCICf9sRd/vXSk2D/d8b9Hg+67MrKzef/fjsBwL3qnpJKYDKZ+JtjbOSizcc5lKA1VkREXPjUg053Gfe3zHJvLSIiIjWYAioRERERkYsxe0CHW437e74q9pBBHR0B1d7TXPYyr7u/BFsehHeBsE4uu37aE8fZjFwaB3nTv13Dy7uOSAl6RYYwsH1D8m12pvy4393liIhUPz3GGdtdX0JGBXVPi4iI1DEKqEREREREyiLqNmO7bynk5xbZfW2bULwsZo4nZ7Lv9LnLu9aOhcb2ipFFdi3ZHgvAHd2a4GE2Xd51RErx3E3tMZng+12n2X+539M1SF6+jTPpORxPzuBoYjrx57LIyMlzd1kiUt006Q5hnSE/+/y/2yIiIlIuFncXICIiIiJSI7S4BnxDISMRjq6BVte77Pb1tHBtmwYs3xvHst1xdAgPvLTrJEZD7G9g8oDOd7vsOpOew+r9CQDc1rXJpZ1fpIzahQVwc6cwvtt5mndXHeTtUVe6u6QKl2+zszM2hQ2HkvgtJpnDCekcS84gz1a0C7K+nyfN6/sS1TiQHi3q0SsyhCbBPm6oWkSqBZPJ6KL69mnYMht6P2Y8JiIiImWmDioRERERkbJwGfO3pNhDCo/5u2Q7Fhnb1jeAv+sIv+92nSLPZqdDeCBtGwVc+jVEymj8gNYAfPv7SQ7XorWo9p1O5dVv99Bn8gpuf3cdr/+wj+V74zmcmF4QTnlZzPh5ehS835ycnsP242dZsOkYz3y2g2teW8mw6Wv53+pDnE7JcuNXIyJu03k4WH0hKRpObHF3NSIiUgMNGDCAqKgo3n33XXeX4hbqoBIRERERKasOt8Bvs+HAMrDbi3xS+vr2RqC0KzaVxLRsQv29ynd+u/38GldX/KHI7iXbTwJwW9fG5S5d5FJ0bBzEwPYNWbEvnv+tPsQbw7u4u6TLsvloMu+uOljQiQgQ6G2hd2QIV7WsT1R4IC0b+BHq74XVw/g8p91uJy07j2PJGRxNzGD78TP8evQMO0+c5fcTKfx+IoU3l+1ncKcwHuzbkm7N67nryxORquYVAB2Gwe+LYMcCaNbT3RWJiEgNs2rVKpo2beruMtxGAZWIiIiISFm16Gt8UvrcSYjbDWGdXHaH+nvRITyQvadSWXcwsfxj+BL2Q9JB8PCENoNcd53LZvNRYxH2W7sooJKq8/j1rVmxL57F22J5dlA7woK83V1SuZ04k8G/v9vLdzuN7kazCQZ3DOOubk3p17YBnpaSh4uYTCYCvK10bBxEx8ZBDL0iHIDEtGyW7Y7jq22x/Ho0maW/n2Lp76e4vn1DJgxqR1TjSxzzKSI1S5eRRkC16wsYPBmsNe9npIiIiLtoxJ+IiIiISFlZvSHiWuN+9LJiD+nXJhSANdGJ5T//vm+MbWR/8HZ9c3vVvnjsdujcJEjr3kiVurJ5Pa5qWZ88m51PNh51dznlYrfbmb8phhum/sx3O09jNsGoq5qx8tn+/O++7twQ1ajUcKo0of5e3NOrOZ892ofvnryWu7s3xcNsYuW+eIa+s4a/L95JSmZuBX9FIlLttOwHgU0gKwUOfO/uakRERGoUBVQiIiIiIuXR5kZje3B5sbv7OgKqtdGJ2O328p17ryOgan9LkV3L9sQBcEOHRuU7p0gFGHdNBAALNh0jKzffvcWUUUpmLg9/8ht/X7yLrFwbV7Wsz9Inr2XynVcQEepXodeKahzIlOFd+OnpftxyRTh2O8zfdIwbpv7Mj7svY006Ean+zB5wxQjj/vaF7q1FRESkhlFAJSIiIiJSHs6A6thG49PSF+gZUR9Pi5nTqVkcjE8r+3nPHoNTO8BkhnZDXHZl5uSz9qCxZs6NUQqopOrdGBVGk2AfzmTk8rVjLbTqLCYpnTvfW8dPe+Lw9DAzaWgHFj3Umw7hlTt2L7KBP9Pv6cbCh3oTGepHwrlsHvnkN174aleNCfZE5BJ0vcfYHlwOafHurUVERKQGUUAlIiIiIlIe9SIgpA3Y8+Hwz0V2e1s9uCqiPgBrD5ZjzN+BH41ts97g38Bl17qDiWTl2mgS7EOH8IBLrVzkknmYTdx/dQsAZq07Uv7uwCq080QKd7y3nkMJ6YQHefPlY1fzx2sjMZtNVVZDn1YhfPfUtTzSLxKATzbGcMd76zmenFFlNYhIFQptA016GL8b7PrS3dWIiIjUGAqoRERERETKK7K/sT26ttjdvSONgGrL0TNlP+ehlca2zQ1Fdq0+YHwae2CHhphMVfcmu0hhI3o0x8fqwb7T59h2/Ky7yynWzhMp3DtjI8npOXRuEsSS8dfQqUmQW2rxtnrw/JAOzBnbkxA/T/aeSuWO99ax9Vg5fi6ISM3R6S5ju+crt5YhUlZHjx7FZDLxwAMPuLsUt4mIiCAiIqJSzt2/f3/93i5SBgqoRERERETKK6KvsS0hoOrh6KDaEpNctk6T/Fw48otxv9X1RXavP5gEQN/WoeWvVaSCBPlaGdI5HIDPNh93czVF7T2Vyr0zNpKalUf3FvVY+HBvGgZ6u7ss+rdryLdP9qVDeCCJaTmM/HAjP+zSulQitU7Ubcb22EZIrf6jUKXmGzduHCaTiZCQELKzs91WhzOIMZlMfPvttyUe16tXr4LjVq9eXeJxr7zyCiaTCavVyunTJf97+cADDxScb/r06SUeN2LEiILj5syZU5YvyS3i4uJ4/PHH6dWrF40aNcLLy4umTZsycOBAvvzyy2rVvX7q1CkefPBBwsPD8fb2pl27dvzrX/8iNze3XOc5c+YMEyZMoHXr1nh5edGgQQPuvvtudu/eXeJzFixYwDXXXIO/vz9+fn707Nmz1P+ue/fu5d577yUsLAwvLy9atGjBU089RXJycrHHZ2Vl8eqrrxIVFYW3tzf16tXj5ptvZt26deX62qTsFFCJiIiIiJRXi2uMbfxuyCj64qZL02CsHibiUrM5cSbz4uc7sRly0sA3BMK6uOw6lZLJ4cR0zCboFRlSEdWLXLI/9GgKwDc7TpKRk+fmas6LT83iwTmbC8KpueOuwt/L4u6yCoQH+fD5o30Y2L4hOXk2xi/Yytc79Aa2SK0S1MQY04sd9nzt7mqkljt37hyfffYZJpOJ5ORkvvrqK3eXhMViYdasWcXu2717N7/++isWS+n/NtvtdmbPno3JZCIvL4+5c+de1nWTk5NZsmRJidddsWIFK1asuOg1qsLx48f5+OOPCQoK4o477uDZZ5/lpptuYteuXdx11108/PDD7i4RgNOnT9OrVy9mz57N1VdfzZ///Gfq16/PpEmTuPvuu8scpCUlJdGrVy/efPNNGjZsyOOPP86NN97IN998w1VXXcWmTZuKPOfZZ5/l3nvv5fDhw9x7772MHTuWpKQkxo4dy4QJE4ocv3HjRnr27MmiRYu4+uqrefLJJ2nbti1vv/02ffr0ISkpyeX4rKwsBg4cyIsvvojVauVPf/oTt99+O+vWreO6665jyZIll/aXJqVSQCUiIiIiUl7+DaBBe+N+TNFP0/l4etCxsTFWbEtM8Z/Oc+Ec7xc5AMyuv6Kvc3RPXdE0mCAf66XXLFIBrmpZn4gQX9Jz8ln6+yl3lwNAZk4+D87dwsmULCIb+DHr/p7VKpxy8vOy8OGYHtzZrQn5Njt/XrSNL3474e6yRKQidbzd2O5e7NYypPb79NNPSU9P5+mnn8ZsNjNz5kx3l8TNN9/Mt99+S0JCQpF9M2fOxGw2M3jw4FLPsWLFCo4ePcpDDz1EYGBgicHThdfdtm0bO3bsKLJv3rx5ZGdnM2TIkGKf26pVK1q1anXRa1SFLl26cObMGZYtW8b777/Pv//9b2bMmMHBgwfp0KEDM2bMKLWzqKr89a9/5fjx47z33nt88cUXvPbaa6xfv56RI0fy9ddfs2jRojKd56WXXiI6OppnnnmG9evX8+abb7JgwQJWr15NdnY248aNw2azFRy/ZcsWpk6dSuvWrdm9ezcffPAB06dPZ+fOnfTs2ZM333yTDRs2uFzjoYceIj09ncWLF/Pll1/yxhtv8NNPP/Gf//yHAwcO8Pe//93l+OnTp7N+/XqGDx/O1q1bmTZtGrNnz2bbtm34+fnx0EMPce7cucv/SxQXCqhERERERC6Fs4vqaPHjHnpG1ANgc1nWoXIGVMWO90sE4JrW6p4S9zOZTAzv0QyA/9tSPcKVl7/Zzc7YFOr5Wpn9QE+CfKtvkOthNjHl7i6M7NkMmx3+8vkOjfsTqU2cY/6Ob4SUWPfWIrXazJkzsVgsPPfccwwYMIAVK1YQExNT7LH5+fm8/vrrtG7dGm9vb1q3bs3kyZNd3vy/UHx8PE8//XTB6LXQ0FDuuusudu3aVeJzxo0bR25uLp988onL47m5ucybN49BgwbRtGnTi35dAA8//DDDhw/nwIEDrFmzptTn3H///Xh4eBQb0s2ePZsOHTrQp0+fYp974RpUSUlJNG3alICAAA4ePOhybGn7SpOVlcXf/vY3mjdvjre3Nx06dOCdd94p0mlktVrx8PAo8vyAgABuuukmgHJdtzKcO3eOTz/9lMjISB555JGCx00mE6+99hoAH330UZnOtWTJEsxmMy+//LLL43369OHWW29lz549/Pzzzy7HAzz99NPUr1+/4HE/P7+CoOn9998vePzQoUPs2rWLnj17MmzYMJdrPPvss4SEhPDJJ5+Qnp5e5Br/+Mc/XP5btGrVinHjxpGQkMDnn39epq9Pyk4BlYiIiIjIpXCuQxVT+jpUv10soMpOg5PbjfuR17nsstvtrD9kdFBd00rrT0n1cHf3pphN8OvRZI4nZ7i1lm92nGTR5uOYTDD9nm60CPFzaz1lYTab+PcdnRnRwwipnly0jU2Hky7+RBGp/gIbQ3PHG+F7NApKKseePXvYuHEjgwYNolGjRowZMwabzcbs2bOLPf7hhx/mb3/7GzabjfHjxzN48GCmTp3KU089Vezxhw4donv37rz11lu0atWKJ554giFDhvDDDz/Qu3fvYkevAfTu3ZuoqKgidXzzzTckJCQwbty4Ur+u5ORkFi9eTFRUFN27d2fMmDEAF+0Oa9KkCYMGDWLBggXk5OQUPL5161a2b9/O2LFjS31+YSEhIXz88cdkZGRwzz33uKyp9OCDDxIbG8v06dNp3bp1mc/5hz/8gfnz53PnnXfy6KOPkpaWxpNPPlnsSLriZGVlsXLlSkwmEx07dizzdSvDhg0byM7O5sYbb8RkMrnsa9GiBe3atWPdunXk5+df9FynT58mNDQUf3//IvtatmwJwMqVK12OL7zvco43m800b96cjIwMNm7ceMnXkIpR/eYeiIiIiIjUBM16Gdu4PZCTAZ6+Lru7NgsGIDr+HJk5+fh4Fv1EJACxW8CeD0HNIcj1U6UnzmRyOjULi9lEtxb1KvorELkkjQK96R0ZwvpDSXz7+yn+1N89o3Fiz2Yy8cudAIzv35prWtecENdsNvGvOzqRnJHDT3vi+OPHW/jyT1fTplGAu0sTkcsVdRsc2wD7lkKfx9xdTd1gt0Ouez8wUSZWX7jgTf1L4QxsRo8eDcCdd97JY489xuzZs3nxxRcxFxoXvXr1ambNmkWXLl1Yt24dfn7GBzkmTpxI165diz3/mDFjOHXqFD/88IPLSL5JkybRo0cPHnroIX7//fdinztu3DgmTJjA5s2b6dmzZ0G9ISEh3HbbbaWu9zR//nyys7MLvq5rr72WiIgI/u///o+3336bwMDAEp/74IMP8v3337NkyRKGDx9ecF2LxcKYMWNKDO+Kc/311/Pcc8/x2muvMWnSJF5//XXee+89lixZwqhRo7j//vvLfC6AAwcOsGvXLoKCjPHfL7/8Mr169WLatGmMGjWKHj16uBwfHx/Pe++9h81mIz4+nu+++47jx4/z0ksvlTkY2759e7nWJQsODubPf/7zRY+Ljo4GoE2bNsXub9OmDfv37ycmJobIyMhSzxUaGkp8fDxpaWlFQqojR44Axt9d4eML7yvu+BMnTpCRkYGvr2+px9tsNo4dO1ZwjYEDBxZc4+DBgxw5coSoqKiL1iQVQwGViIiIiMilCGwM/mGQdhpO7YAWrqNDGgV60yDAi4Rz2ew5lUr3kgKmY45P7TXvXWTX1mNG91XHxoF4W0sIuETc4JYrGjsCqpNuCajsdjsvfLWLc9l5dGsezJ9vKP6NkurM4mHmnVFXMnrmJjYfPcPDn/zGV+Ov0VpzIjVdu5vhh78ZIVVGMvjWv/hz5PLkZsC/G7u7ioubeBI8L6/T1zlCLzAwkNtvvx0Af39/7rjjDubNm8fy5csZNGhQwfEff/wxAC+++GJBOAVG19FTTz3FCy+84HL+bdu2sX79esaNG1dkvai2bdvy0EMPMXXqVHbt2kWnTp2K1Dd69Gief/55Zs2aRc+ePTl58iQ//vgjjz/+OJ6enqV+bc51qu677z7AGBt333338c9//pNFixbx8MMPl/jcYcOGERoayqxZsxg+fDhZWVksXLiQoUOH0qhRo1KvW5xXXnmFFStWMGXKFJo2bcpf//pXIiIiXEbIldULL7xQEE4BBAUFMWnSJEaPHs3cuXOLDagKj72zWq288cYbPPvss2W+5vbt24uMzitNixYtyhRQpaSkALh8PYU5Q0TncaW5+eabmT17Ni+//DJvvPFGweObNm3i22+/BeDs2bMux7/22mu89dZb3HPPPQQHBwOQkZHB5MmTXWr09fWlbdu2REZGsnnzZpYuXcrQoUMLjnnrrbdISkoq9hobN27klVdeYf78+QVj/o4cOVIQchY+XiqGRvyJiIiIiFwKkwmaOl5Qxm4p9pDOTYwXb7tiS3mRdsyxmG8xAdW2Y2cBuLK5uqekermpUxgeZhO7T6ZyJDH94k+oYN/tPM3KffFYPUz85+4rsHjUzJe23lYP3r+vO02CfTiSmM6fF20j32a/+BNFpPqqFwENo4zu6IPL3V2N1DJLliwhISGB4cOH4+3tXfB4SePwduzYARjdSBcq7jHnuLO4uDj+8Y9/FLnt27cPoGB7oYYNGzJ06FAWLVpEVlYWc+fOJT8//6Lj/bZs2cKOHTsYMGCAyzpVZR3zZ7Vaue+++1i2bBmxsbEsXryYM2fOXPS6pZ1v4cKF+Pr68uSTT5KTk8P8+fNL7eIqSWl/99u2bSuyr1OnTtjtdvLy8jhy5Agvv/wyf//737nrrrvIy8sr0zUfeOAB7HZ7mW9Hjx4t99d1uV555RXCw8OZMmUKffv2ZcKECdx7773069evoHupcDdgv379GD16NNHR0URFRfHoo4/yxBNP0LlzZ06dOlUQmjmfYzKZeO+997BarQwbNoy77rqL5557jsGDB/Pss8/SuXPnItd4+umniYqK4tNPP6V79+4888wzjBs3jq5du9KiRYsix0vFUAeViIiIiMilatIN9n0Lsb8Vu7tTkyBW7otnZ0kBVX4eHN9s3G9edPFmZweVxvtJdVPfz5NrWofyy4EEvt1xkicGVl0HU1p2Hv/4ZjcAj/VvTeuGNXssXoi/Fx+M7s5d/1vPqv0JvL0imqdvbOvuskTkcrS7GeL3wP7v4Yo/uLua2s/qa3QnVXdW34sfcxHOoMYZ3DgNHDiQJk2asGTJEpKTk6lf3+jcS0lJwWw2F4w7K6y4zqLk5GQAli5dytKlS0usIz295A+njBs3jq+++oovvviC2bNn0717d6644opL+rratGlD79692bhxI7t37y51DaZx48bx1ltvMWfOHFavXk1YWBhDhgwp9bqliYyMLBiN2L17d66++upLOk9xf8/Ox0rrNPLw8CAiIoLnn38ei8XCc889x0cffcSf/vSnS6qjIjhDoJLqTk1NdTmuNE2bNmXz5s289NJLfP/99/z66680a9aMV155hYiICEaOHEnDhg1dnjNnzhx69OjBzJkzmTNnDj4+PgwePJj//Oc/dOzYEYvFUvC9DzB48GDWrFnDq6++ysqVK1m6dCmdOnVi8eLFrFixgp07d7pcIyAggHXr1vHKK6+wePFipk+fTsOGDXn00Ue55ZZb6NevX5Ga5PIpoBIRERERuVRNuhvbE8UHVBftoIrbCbnp4B0EDdq77MrKzWfPSeNFXrfmwRVSrkhFuqVzOL8cSGDpzlNVGlB99MthEs5lExHiy2MD3LP+VUXr1CSIyXd25pnPdvDOymiubRNKjwiNBROpsdreDGveNDqo8nLAUvpoM7lMJtNlj86rCY4fP86yZcsAuO6660o8bt68eTz55JOAERTYbDYSExNp0KCBy3FxcXFFnuvsEHrnnXd4/PHHL6nOIUOGEB4ezl//+ldiY2N57733Sj0+MzOThQsXAnD//feXuMbTzJkzmTp1aonn6dy5Mz179uTdd98lLi6OCRMmYLFc+lvfU6dOZd26dYSEhPDrr7/y3nvv8dhj5V9XLi4ujubNmxd5DMoW5AAMGjSI5557jtWrV5cpoKqsNaica08516K6UHR0NJ6enkW+3pI0adKEGTNmFHn8H//4B0CR8Ydms5knn3yy4Pvb6ejRo6SlpdGtWzesVtdRyb169SoYGVjYW2+9Vew1goODmTp1apHvtTlz5hR7fEUYMGAAVquV8ePHM378+Ao/f3WngEpERERE5FI1vhIwQcoxSIsHf9dP1HVqYrzIj45PIys3v+g6Us7Oq6Y94YJxETtjU8iz2WkY4EWTYJ/K+gpELtmNUY0wfwn7Tp/jxJkMmta7/E+GX0z8uSw+WnMYgOduao+XpfaszXZnt6asPZjIl1tj+fOn2/nuqWsJ9NZ6VCI1UpPu4NcA0hPg2HqI7O/uiqQWmDNnDjabjb59+9KuXbsi+/Py8pg7dy4zZ84seAO/S5cubN26lTVr1nDnnXe6HL9mzZoi5+jVqxcAGzZsuOSAysPDgzFjxvD666/j7e3NqFGjSj3+888/JyUlha5du9K9e/dij5k/fz6ffPIJr732WqlrWY0bN64gwLnU8X5gjN6bOHEi7dq1Y9WqVVxzzTVMmDCB6667rtQuruKsWbOGe++9t8hjAFdeeWWZznHypNEheGH4UpLKWoOqd+/eeHp68tNPP2G32zGZTAX7YmJi2L9/PwMGDLisYDA/P59FixZhsVi46667yvSc+fPnAzBy5MgyHR8TE8PatWuJiooqGPVX0dcoj1WrVrmMtqxrFFCJiIiIiFwq7yAIaQVJB+H0Tmg90GV3WKA3of6eJKblsO/0Obo2C3Z9/iljXQDCuxY5tbPr6oqmwS4v/kSqi3p+nnRvUY/NR8+wal88o/tEVPo1314RTUZOPl2bBXNzp7BKv15Ve3lYR7YcPcOx5Az+sWQ3U0d0dXdJInIpzGZoOxi2zTPG/Cmgkstkt9uZPXs2JpOJuXPnEhkZWexxBw4cYMOGDWzZsoUePXowevRoZs+ezSuvvMLgwYPx8zM6zWJjY/nvf/9b5PlXXXUVvXr1YuHChQwbNowRI0a47LfZbKxZs6bUDi6AZ555ht69e1O/fn2Cg4NLPdY53m/q1KkMGDCg2GMyMjJYuHAhX3/9NXfffXeJ57rvvvsICwvD29u72BCvLNLT0wtCtYULFxIeHs6CBQu49tprGTVqFL/++qvL+l8X8+qrr3LLLbe4jMf75z//iclkcukW27FjB1FRUUVCqOTkZCZOnAhQ5pGFDzzwAA888ECZayyrwMBARo4cyccff8wHH3zAo48+Chjfn88//zwADz30kMtzUlJSCtaICg8PL3g8NzeXvLw8fHzOfxDPZrMxYcIE9u/fz9NPP03jxo1dzpWamlpkHbA1a9YwefJkWrRoUVCPU1paGn5+fi6vpVJSUhg9ejT5+flMnjy5yNdY3DWmTZvG8uXLueOOO+jZs+dF/56kfBRQiYiIiIhcjkadjIAqbneRgMpkMtG2UQCJaUlExxUTUJ3cbmzDuxQ5rXO8X1Tj8i/GLFJVrm/fiM1Hz7CiCgKq0ylZfLr5OADP39y+Vga3Ad5W3hrZlbv/t54vt8Vya9fGDGintQ5EaqS2NxkB1cHl7q5EaoGVK1dy5MgRrrvuuhLDKYCxY8eyYcMGZs6cSY8ePRgwYABjx45l9uzZdO7cmTvuuIPs7Gw+/fRTevfuXezos4ULFzJgwABGjhzJW2+9Rbdu3fDx8eHYsWNs2LCBhIQEsrKySq23YcOG3H777Rf9ug4ePMgvv/xCREQE/fv3L/XrWrhwITNnziw1oPL39y/TdUvz1FNPsX//fqZMmVLQ4dS7d29eeuklXnjhBf7yl7/wzjvvlPl8bdu2pVOnTgXdQF988QUnTpzgmWeecRkXN23aNL799luuueYamjdvjo+PDzExMSxdupT09HSGDx9+0W60qvDaa6+xatUqHnvsMZYvX07r1q35+eef2bhxI7feemuRDqPFixczduxY7r///oIxeWCMOezYsSODBg2iZcuW5OTk8OOPP7Jv3z6GDh1abHh09913k5mZyRVXXEFgYCA7d+7k+++/p379+nz11VcEBLiuS/rVV18xceJErr/+eho3bkx8fDxff/01CQkJvPrqqwwbNqzINZo0acKAAQNo06YNJpOJ1atX89tvvxWsfSUVz3zxQ0REREREpESNOhnbuN3F7m7T0B8wxvy5yMuG+L3G/WICqr2nHQFVuAIqqb4GdjDCk/WHksjIyavUa81ad4TcfDtXtaxPr8iQSr2WO3VrXo8H+7YEYNLiXaRnV+7fq4hUkpb9wORhfIjl7DF3VyM1nPON8Yt1xYwYMQIfHx8WLlxIZmYmAB999BGTJ0/GZDIxffp0vv/+e5555pmCNXgu1LJlS7Zt28akSZNIS0tj9uzZfPDBB2zfvp1+/foVrBdVEWbNmoXdbuf+++8v9YMnAwcOpFmzZixbtozjx49X2PUv9MUXXzBz5kxuvPFGnnnmGZd9EydOpF+/fkyfPr3YYK8kn332Gffccw9ffvkl//vf//Dz8+Ptt99mypQpLseNHj2aIUOGsG/fPubOncu0adP4+eef6devH4sWLeKzzz6rFh/OCQ8PZ9OmTYwdO5a1a9cybdo0kpKSePXVV/n888/LXGNQUBC33XYbW7du5Z133mHWrFnUq1ePjz76iK+//hovL68iz7n99tvJzs5m/vz5TJ06lX379vHEE0+wa9cuunbtWuT4zp0706VLF5YtW8aUKVNYsmQJvXr1YuXKlUyaNKnYuu677z4OHDjA+++/z/vvv4/NZuONN95g7dq11KtXr1x/V1I2Jrvdbnd3EbXZiRMnaNasGcePH6/TsyRFREREaq1938GiUdCoM/xpbZHd8zbGMOmrXfRv14A5Y686v+PkNviwP/jUg+eOGAt8O+Tm2+j44o/k5Nv45S8DaB5S+Wv7iFwKu93Otf9ZxYkzmXw0pgc3RjWqlOukZORy9WsrSM/JZ/bYnrW+qygjJ48bp/5C7NlM/ti3JZNuiXJ3SSJyKWbcCCd+hVvfhu73X/x4KVVWVhZHjhyhZcuW5RqxJiJSlcr6s0q5gUEdVCIiIiIil6ORY6HkhH2Qn1tkd9tGxqiJ6LgLOqgK1p/q4hJOARxOSCcn30aAl4Wm9XwQqa5MJhMD2xth0ar98ZV2nXmbYkjPyad9WAD92zaotOtUF76eFv55h9GdOWvdEfY5OipFpIZpdb2xPbzKvXWIiIhUUwqoREREREQuR3Bz8AoEWy4kRhfZ7RzxF3s203VUV+GA6gJ7TqUA0D48ALPZ/aM8RErTt40RGG08lFQp58+32Vn4qzEe64/XRlaL8TZVYUC7htzcKQybHf61dC8afiJSA7UaYGwPrwZbvltLERERqY4UUImIiIiIXA6T6XwXVdyuIrvr+XkS6m/MUD9YeB2q+H3GtlHnIs/Ze+ocAB20/pTUAL0i62M2weHEdE6lZFb4+X+JTuDEmUyCfKzcckV4hZ+/Ovvbze3x9DCzJjqR1fsT3F2OiJRXk+7gGQCZZ85/MEVEREQKKKASEREREblcDdob28QDxe52dlFFOwMqux0S9jqe27bI8c4gq41jPKBIdRbobaVz02AA1h+s+C6qBZuM7qk7uzXB2+pR4eevzlqE+PHANREA/HPpHnLzbe4tSETKx8MKLfsZ9w+tdG8tIiIi1ZACKhERERGRyxXS2tgWM+IPoGUDPwBiktKNB9ITjU9TY4KQNkWOP5xgBFStQv0qvFSRynB1qxAA1lfwmL/41CxW7jPWtrq3V/MKPXdN8fj1ranv58mhhHQWb411dzkiUl7OMX9HfnFvHSIiItWQAioRERERkcvlDKiSDhW7u0V9XwBikjKMBxIc4/3qtQBPX5djc/JsHD9jjEmLbOBf8bWKVAJnQLXhUGKFrpW0dOcp8m12rmweTOuGdbOjMNDbyp+uawXAO6ui1UUlUtO0uMbYntgM+bnurUVERKSaUUAlIiIiInK5Qh1dUMmHwFb0zeMWIc6AytFBlbjf8bx2RY49lpxBvs2On6cHjQK9KqVckYrWo0V9rB4mTqZkcSw5o8LO+82OkwAM69K4ws5ZE93buzmh/p4cT87ky60n3F2OiJRHg/bgUw9yM7QOVQWpyA9CiIhUNP2MKh8FVCIiIiIilyu4OZgtxptP504W2d28vmPEn/ON+wRHQNWgaEDlHO/XsoEfJpOpcuoVqWA+nh50bBwEwNZjZyrknMeTM9h67CwmEwztHF4h56ypfD0tPOrsolp5UF1UIjWJ2QzN+xj3Y9a5t5YazsPDWIcwN1edaCJSfTl/Rjl/ZknpFFCJiIiIiFwuDyvUizDuJx0sstvZQXU2I5eUzNzzI/4atC9y7OFEo8sqMlTj/aRm6da8HgBbY85WyPmW7jwFQO+WITQM9K6Qc9Zk9/ZqQai/FyfOZBZ0lolIDdHiamMbs8G9ddRwVqsVLy8vUlJS1KEgItWS3W4nJSUFLy8vrFaru8upESzuLkBEREREpFYIaW2EU0kHIbK/yy4/Lwuh/l4kpmVzLCmDzkmHjR3O0YCFODuoIhv4VXbFIhWqW4tgZq2ruA6qH3adBmDoFXW7e8rJx9ODsddE8MaP+5mx5gh3XNlEXZYiNUVzR0B1bL0xCtisz4tfqtDQUGJjYzlx4gRBQUFYrVb9LBQRt7Pb7eTm5pKSkkJaWhpNmjRxd0k1hgIqEREREZGKENLa2CYdKnZ3ixBfEtOyOZ5whs6pscaD9VoWOS4myRgD2DJUAZXULM4Oqn2nz5GRk4ev56W/3ExKy2bHibMA3NChUUWUVyvc26s501ceZM+pVDYcTuLqVqHuLklEyiL8CrD6QVYKxO+BsE7urqjGCgwMBCAxMZHY2Fg3VyMi4srLy4smTZoU/KySi1NAJSIiIiJSEYJbGNuzx4rd3aK+L7/FnOHsqcOA3Xijyq/om8snzmQC0LSeT2VVKlIpGgf7EBbozenULH4/kULvyJBLPtfPBxKw2yEqPJCwII33cwr29eTu7k35ZGMMM9ccUUAlUlN4WKHZVXB4FcSsV0B1mQIDAwkMDCQ3N5f8/Hx3lyMiAhhrTmmsX/kpoBIRERERqQjBzYxtCQFVeLDxJnuec7xfvQi4YCRNXr6N06lZADSt51spZYpUpm4tgvlu52m2HjtzWQHVyn3xAFzfvmFFlVZrjL0mgnmbYlixL57DCWlENtB6dSI1QvPeRkAVuwV42N3V1ApWq1VvBouI1HAaeisiIiIiUhGCmxvblOPF7g4LMjqiPM7GGA/UiyhyzOnULPJtdjw9zDTw96qMKkUqVZemwQDsjk295HPk5dv45UACAAPaN6iIsmqVyAb+DGhnBHefbi7+542IVENNehjb2N/cW4eIiEg1ooBKRERERKQiBDk6qDLPQPa5IrsbO8aU+aQ73lAuJqCKdYz3Cw/2xmzWgt9S80Q1Nubt7zl16QHVrpOppGblEehtoWuzehVVWq0ysqfx8+bz306Qk2dzczUiUiZNuhnbpIPG7woiIiKigEpEREREpEJ4B4J3sHH/bNGuhnBHB1VwtmNB7+ICqrNaf0pqtqhwI6A6kphOWnbeJZ1j0+EkAK5qGYKHgtpiXd++IQ0DvEhKz2H53jh3lyMiZeFbH+q1NO7HbnVvLSIiItWEAioRERERkYriXIeqmDF/4Y4OqrB8x5vJxQRUJxwdVE2CFVBJzRTi70VYoPG9vu8Su6g2OgKq3pH1K6yu2sbiYWZ4j6YALPy1+HXvRKQaatLd2CqgEhERhwEDBhAVFcW7777r7lLcQgGViIiIiEhFCW5hbM8WfcM42NeKt9VME1OC44HmRY6JLQiofCutRJHK1vEyxvzl2+xsOWqMvurVMqRC66ptRvQwfoasPZhY0H0pItVcU61DJSIirlatWsWePXsYP368u0txC4u7CxARERERqTWCSu6gMplMRASaCUrPMB4IDC9yjPNN5iYa8Sc1WFTjQFbsi2d3bPkDqj0nUzmXnUeAl6VgPSspXvMQX3q1rM+mI8l8u+Mkj1zXyt0lFctut7P12BlW7I1ny9EzHE1KJyEtGxPg52UhIsSPqPBA+rYJpX+7BgR4W91dskjlKeig+g3sdjBpjKmIiNRtCqhERERERCpKQJixPXe62N1R/mmQDnkePli8ir75HpeaBUBjxzhAkZrI2UG1+1RKuZ/769FkAHpE1NP6U2UwrGtjNh1J5utqGFDl2+x8sfUEH/x8iEMJ6UX224FzWXnsjE1hZ2wKn245jo/Vg2FdGvNQv0haN/Sv+qJFKltYZzBbID3e+DBLMd3UIiIidYkCKhERERGRinKRgKqVl9FRkubZkOBiPjUdfy4bgAYBXpVTn0gVaBdmBFQH49Ow2eyYyxE0/X7iLABXNq9XGaXVOkM6hfPSkt3sPpnKwfi0ahPq7Dh+luc+/539cecA8PX0YFBUI65uHUr7sAAaBXpjAlIyczmcmM6Wo8ms2BvP4cR0Pt1ynP/77TjDuzfjuZvaEeKvn4dSi1h9oGEHOL0TTv2ugEpEROo8BVQiIiIiIhXFGVClxRW7u6nlLABnLKEEX7AvOy+flMxcABoGqINKaq5m9Xzw9DCTlWsj9mwmzeqXfU21nSeMrqsrmgZVVnm1Sj0/T65tE8qq/Ql8veMkz9zY1q312O12/vfzId5cdoB8m51gXyvj+7dmVK/m+HsVffuhYaA3bRoFMLhjGBOHdGDz0TN8+Mthlu+N49Mtx1m+N45/3t6JmzsXHYkqUmM16mwEVHG7oMMt7q5GRETErczuLkBEREREpNbwL72DqhHG+LJEc0iRfQmO7ilPDzOBPvocmdRcFg8zEaFGKHUoIa3Mz3N20wBc0TS4MkqrlW7r2gSAb3ecdGsdufk2nvv8d/7zw37ybXaGXhHO6gn9eahfZLHh1IVMJhNXtazPjPt78PmjfWgfFkBSeg5/mr+Vf3y9m9x8WxV8FSJVoFFHY3t6p3vrEBERqQYUUImIiIiIVJSARsY26yzkZhbZHWpPAuC0rej4soRC4/1MWjRdarhWDYxRcwfjyx5Q7Y41uqea1vOhvp9npdRVGw3s0BCrh4nDienlCgQrUl6+jScWbOP/fjuBh9nEP2/vxPRRVxLse2n/HXtE1GfJ49fwWH9jXa05649y/6xfSc3KrciyRdwjrJOxjdvl3jpERESqAQVUIiIiIiIVxTsYPBzrpRQz5i8oNwGA43nBRfZp/SmpTZxrIR1KSC/zc3Y4xvt1UfdUuQR4W+kdaXRlLt9T/HjRymS323n+y538sPs0nh5mPhrTnft6t7jsoN3L4sFzN7Xng9Hd8fP0YP2hJO6bsYmzGTkVVLmImzTqbGzPHIXsc24tRURExN0UUImIiIiIVBST6XwX1bmibxT75RgB1dGcouvrJCigklrE2UF1qBwdVLscHVSdtf5Uud0YZfzcWb636gOqD385XNA59c49V3J9+0YVev7BHcP47NE+1Pfz5PcTKYz8cKNCKqnZ/EIgwLGuWtwe99YiIiLiZgqoREREREQqkvNNp7Si61B5ZRoB1aGsAGw2u8s+ZwdVQwVUUgs4O6gOlmPk3L7TqQBEhQdWSk212cAORij0W8wZktOrLrzZcCiJ13/YB8A/hnVkcMewSrlOx8ZBLHq4Nw0CvNh3+hwPf/wbWbn5lXItkSrRyDnmzw3rUNnyIX4v7P0GtsyGXz+CbfPgwI+QGA35GqUpIiJVR6svi4iIiIhUJP8SOqjsdsyZiQDE2QJIycylXqF1dtRBJbVJZAM/AJLTczibkXPRtYiy8/I5mpQBQNtGAZVeX23TJNiHqPBA9pxKZdW+eO7q3rTSr5malcvTn27HZoc7uzXhvl7NK/V6bRsF8MmDVzH8fxv49Wgyz3y2nemjumE2a80+qYHCOsHBn+B0Fa1DZbfDoRWwfaERROWUMlrQ4g1Ne0KLq6HdEAjvYnSIi4iIVAJ1UImIiIiIVCS/BsY2I9H18Zx0THlZACTbA0lKz3bZnZimgEpqD19PS8H38vHkzIsefyQxnXybnQBvC40C9f/ApbihQ0MAVu2Pr5LrTf5uH6dTs4gI8eVft3e+7DWnyqJ9WCAfjOmOp4eZ73ae5n8/H6r0a4pUioIOqt2Vf62j6+CjATDvLtj1uRFOWf2gSXdoezN0GAatbzTWxrL6QV4WHF0DP78OH14H/+0CP70ESfr/TUREKp46qEREREREKpJviLHNSHJ93BFYZeNJBl4knMuhdcPzu884xnLVv0iniUhN0ayeDwnnsjmWnHHRdaX2nzY+zd+2UUCVBB21Ud82DXh75UE2HErCZrNXamfRxsNJLPz1GACv3XUFPp4elXatC13dKpRXb+/IX7/YyZvL9nNl82CubhVaZdcXqRAN2hnbxANGd1Nl/NzLz4Xl/4AN7wJ2I3y68j7oPBwaXwkexbwlaLcbNcWsh8OrIPonOBsD694ybpH9occ4aH8LmKvu/3sREam91EElIiIiIlKRSgqo0o0/nzMHAiZSMl3XiTmbaaz5EORrrewKRapE8/q+ABxLzrjosdFxxlpVbRv5V2pNtVnXZsH4enqQlJ7DvtOljO+6TDabnX8u3QPAqKua0zsypNKuVZIRPZszvHtTbHZ4cuG2gg5UkRojpDVggqyzkJ54saPLLzsNFo6CDdMBO1w5Gp7aAUP+A816Fh9OgRGUNWgHPcbCHz6GvxyE4XOhzSCj3sOr4bMxML0H/DYH8vT/noiIXB4FVCIiIiIiFcm3vrHNSHZ93NFBlW4JBuBshusi5GczjMCqnjqopJZwBlTHz1w8oDoQd76DSi6Np8VMr5bGz591ByvhDW+Hb34/ya7YVPy9LEwY1LbSrnMxr9zWifZhASSm5fDCV7uw2+1uq0Wk3Kw+EOxYty3xQMWeOy8bFo401riy+MAfPoHbpoN/g/Kfy9MPOt4O9/6fEXD1fQZ86kHyYfjmKXjrClj3NuSkV+zXICIidYYCKhERERGRilRSQOX4hHSmtR4AKZnnAyq73V4QWCmgktqiqTOgKkMH1cEEo4OqTUMFVJfjmtbGqLt1hyonoMrNt/HGj/sB+FP/VoT4u2+9MB9PD978QxcsZhPf7zrN0p2n3FaLyCUJdQS8ifsr7px2O3z1mLGGlGcA3P8NRA2rmHPXawE3vAR/3gWDJ0NgE0g7DT+9AP/tCps+UEeViIiUmwIqEREREZGKdJE1qHI8iwZUadl55NmMT/8Ha8Sf1BLNyxhQ5dvsnEjOBKBFiG+l11WbOQOqTYeTycmzVfj5v9lxkhNnMgn192LcNS0r/Pzl1bFxEI8NaA3AC1/tKljLT6RGKFiHKrrizrl5Buz6HMxWGPGJMc6vonn5Q5/H4MntcNu7ENwC0uPh++fg7W7G6L/83IudRUREBFBAVWb//e9/adGiBd7e3vTt25cdO3a4uyQRERERqY4KB1SFR045OqjyvI39hQMqZ/eUt9WMt1WLjkvt0MwRUJ04k0m+reTxa6dTs8jJt2H1MNE42KeqyquV2jUKoL6fJ5m5+ew6mVKh57bb7Xzw82EAxl4TgY9n9fhZ9fiA1rQPC+BMRi5v/lSBnSgilS20jbGtqBF/Cfth2STj/qBXodWAijlvSSyecOV98PgWuGUaBDSG1BPG6L/pPWDHIrDlV24NIiJS4ymgKoMFCxbw17/+lVdffZXffvuN1q1bM3jwYFJTU91dmoiIiIhUNz6OEX/52ZBbqHPE0VFl9ysaUJ1xrD8V7KPxflJ7hAV6Y/UwkWezczo1q8TjYpKMtUua1vPFw2yqqvJqJbPZRLfmRpfmb0fPVOi5Vx9IYH/cOfw8PbivV4sKPffl8LSY+cewjgAs2HSM3RUczIlUGueIv4QKCKjsdvhuAuRlQesboNejl3/OsrJ4Qo9x8OQ2Y/SfXwM4cxQWPwL/uxr2LAFbxXd0iohI7aCAqgymTZvGo48+ypgxY+jYsSMzZswgLy+PBQsWuLs0EREREaluPP3Aw7EuS+Exf477Jt+SO6g03k9qEw+ziYYB3gCcTik5oDqWZAS5zpGAcnl6RDgCqpiKDajmb4wBYORVzQmqZj+rekeGcMsV4djs8PLXe7DbS+7YE6k2Qh0j/lKOQc7F1+or1d6v4cgvYPGGoVPB5Iaw3+p9fvTfwBfBOxgS9sFnY+DD6+DAMtfOchERESoooHr99dcxmUyYTCY2btxYEacsl3nz5vHII4/Qo0cPvLy8MJlMzJkzp9TnbN68mSFDhhAcHIyfnx+9e/fms88+K3JcTk4O27Zt44Ybbih4zGKx0L9/fzZs2FDRX4qIiIiI1HQmU/HrUGWeBcDqX3IHVT1fdVBJ7RIWZARUcaV1UDnWqNL6UxWjewsjoNoSc6bCgprTKVms3BcPwD29mlfIOSvaxCEd8Laa+fVoMqv2x7u7HJGL8wsB7yDj/pmjl34eWz6seMW4f81TUM/NHY5e/nDts/DUDrjur+DpD6d/hwXDYdZgI0gTERFxuOyAateuXbz00kv4+flVRD2XZNKkSXz44YfExMQQHh5+0eNXrVrFNddcw9q1a/nDH/7Ao48+yunTpxkxYgRvvvmmy7GJiYnk5+fTqFEjl8cbNmzI6dOnK/TrEBEREZFawifY2DpCKQCyjfHQnv7Gm8fqoJK6ICxQHVRVrXOTIKweJhLTsjmenFkh5/xi6wlsdugZUY9WDfwr5JwVrXGwD/dfHQHAm8sOYCtl3TORaqNeS2N7OQHVniWQdNDoWLr6iYqoqmL4BMOAifDU73D1k2DxgeObYO6tMHcYHN/s7gpFRKQauKyAKjc3l/vvv5+uXbtyxx13lOu5n3zyCTExMSXuz8/PZ+rUqeTk5Fz0XDNmzODo0aMkJCTw6KOlz9nNy8vjoYcewmw288svv/Dhhx/y5ptvsmPHDtq2bcvEiRNLrUtERERE5KK8Ao1tdqE1S7OMdVF8AooGVKmO+0E+CqikdmkUWJYOKmMNqhYh7vvQY23ibfWgUxOjK+O3Y8mXfT673c5nW44DMKJn9eyecnq0Xyv8vSzsPpnKj7v1gVKpAepFGNtLDajsdlgz1bjf+0/gFVARVVUsvxAY9Co8tR16PgRmKxz5GWbeAAtGwKnf3V2hiIi40WUFVP/617/YvXs3s2bNwsPDo8zPO3HiBA899BD9+/cvNgyy2Wzcf//9PPvss3z44YcXPd8NN9xAixZla2FeuXIlhw4d4p577qFr164FjwcFBTFx4kRycnKYO3duweOhoaF4eHgQFxfncp74+HjCwsLKdE0RERERqWO8nQHVufOPZRlhlV+gMeIvNTO34BP+aTl5APh7WaquRpEqEBZkrMd2upSAKvaM0eXTrL5PldRUF3RvXnHrUO2MTSEmKQMfqwdDOlfv18D1/DwZ19foSJm2XF1UUgNcbkB1fBPE7QSrL1z1cEVVVTkCwmDoFHhyK1w5GkwecOAH+OBaWDASjm1yd4UiIuIGlxxQbd26lX/961+89NJLREVFleu5TZs2ZeHChZw4cYIBAwZw7Nixgn3OcGr+/PmMGTOGxx577FJLLNbq1asBGDRoUJF9gwcPBuDnn38ueMzT05Mrr7ySFStWFDyWl5fH6tWr6dOnT4XWJiIiIiK1hPMTzI5QCls+5BhhlV+QEVDZ7JDuCKbSsoytnwIqqWUaXWTEX1ZuPmccIy7DgxRQVZQrmgUDsCs2tfQDy+C7nUYn0vXtG+LrWf1/Rj3YtyUBXhYOxKWx+oDWopJq7nIDqq0fG9uOd4Jv/YqoqPIFN4fbpsP4X6HT3YAJDnwPswbBrJvhwI9GZ5iIiNQJlxRQZWdnM2bMGLp27cpzzz13SRe+4447WLhwIcePH6d///4cO3YMm83GAw88wLx587j33nuZPXs2ZvNlL5PlIjo6GoA2bdoU2RcWFoa/v3/BMU5PP/0077//PvPmzWPPnj08/PDDWCwW7rnnnhKv8+677xIVFUX//v0rtH4RERERqQEuHPFXaNSfl38wFrMJgPTsfMfWCKgCvKv/m78i5RF2kRF/zuDKx+pBoL7/K0ynxsbPoH2nU8nLt13yeex2O9/vOgXAzdW8e8opyMfKqF7GKMIPfzns5mpELuJyAqqsFNi92LjfbUxFVVR1QlvD3TPh8S1G/R6ecGw9LPgD/O8a2LEI8rLdXaWIiFSyS0p/XnzxRaKjo5k9e3a5Rvtd6O6772bevHkcO3aMAQMGMGrUKD755BNGjhzJ3LlzKzycAkhJMWb/BwUFFbs/MDCw4Bine+65h9dee42JEydy5ZVXsn//fn788UcCAwNLvM748ePZs2dPQceWiIiIiNQhzhF/zg4qx/pTWHwwWbzw9TR+hy7ooMrWiD+pncKCHB1UqVnYi/lEvHP0X3iQNyaTqUprq80iQvzw8/QgK9fG4cT0Sz7P3lPniEnKwMtiZkC7hhVYYeV64OoILGYTGw8n8/uJs+4uR6RkzoDqbAzYyhkm71sKuRkQ2haaXVXhpVWZ0NYw7B146ne4+knw9If43bD4EZjWEVb+E1Ji3V2liIhUknInQBs2bGDKlClMmjSJTp06XXYBI0aMYO7cuRw+fJjPPvuM22+/nXnz5l1W8FUZnnrqKY4dO0Z2djbr1q2jS5cu7i5JRERERKqrgg4qRzDlDKocwZUziHJ2TjkDKo34k9rGOeIvK9dGamZekf3ODirncVIxzGYTUY4uql2xKRc5umQr9hprMfdr26BG/XxqHOzDsC6NAfhozRE3VyNSiqBmxlpMeVmQFnfx4wvb+42x7Xgn1IaAPzAcBr0KT++C61+AgMaQngC/vAFvdYbP7oej6zT+T0SklilXQJWXl8f999/PFVdcwd/+9rcKKcBut7Ny5cqCP+/evZu4uHL+o1wOzs6pC7uknFJTU0vsrhIRERERKRNvx++TF3ZQOR73LQiojBF/BR1UGnEmtYy31aNgdGVCWtFRTadSzndQScXq2Nj4ebP75KWvQ7UmOhGA/u0aVEhNVWlc35YA/LDrFInFfO+JVAseFghuZtw/U44wNTsNDjrWSu9wa8XX5U4+9aDfBPjz7zB8LrS4Buz5sOcrmDMEpveEtW/BudPurlRERCpAuQKqtLQ0oqOj2b59O56enphMpoLb3LlzAejTpw8mk4mvvvrqouez2+08/PDDzJo1ixEjRjBv3jwOHz7MgAEDOHny5CV9QRfjXHvqwnWmAE6fPk1aWlqx61OJiIiIiJRZSWtQOQIqvws6qJxBlUb8SW0U4ucJQHJ6TpF9zrWpGimgqnAdL7OD6lxWLluPnQGgX5uaF1B1ahJEl2bB5Obb+eK3E+4uR6RkQY6AKqUc36cHf4L8bKjXEhp1rJy63M3DCh1vh7HfwaProPsDYPWFpGhY/hJMjYIFI4xOsryi/76IiEjNUK5XwF5eXjz44IPF7vvll1+Ijo5m2LBhNGjQgIiIiFLPZbfbeeSRR5gxYwZ/+MMfmD9/Ph4eHpjNZkaPHs2AAQNYvXo14eHh5Snxoq677jomT57MsmXLGDlypMu+H3/8seAYEREREZFLVtIaVI7gyu+CNajOZWkNKqm9Qvy9OJqUQVKxHVSZgDqoKkOnJkYgvudUKna7vdxrfG08nEyezU5EiC/N6vtWRomV7p6rmrHj+FkW/nqMh/tFap0zqZ4Cmxjb1HKss3RwubFtP7R2jPe7mLBOcOt/YdA/Yfdi2DYPjm+CAz8YN99Q6HgHdL4bml4FlbCmvYhIZRkwYABWq5Xx48czfvx4d5dT5cr1CtjHx4cZM2YUu++BBx4gOjqa559/nt69e5d6Hrvdzp/+9Cc++ugjl3AKYNSoUQAuIVVYWFh5yizVwIEDiYyMZMGCBTz55JN07doVMEb+/fvf/8bT05MxY8ZU2PVEREREpA4q6KA6Z2yzSuqgyndsFVBJ7VXf0UGVVEwH1elUI7TSGlQVL7KBHx5mE+ey8og/l13uv+M10QmAsf5UTXXLFY159du9HE3KYMPhJK5uFerukkSKCnIGVOWYJHTkF2Mb2b/Cy6nWvAKg2xjjlnAAts+HHQuN9bs2f2TcgppBpzuh090Q1rluBHgiUqOtWrWKpk2bursMt3HLK+CTJ0+yePFihg8fzvz587FYXMsYNWoUdrudMWPGsHz5cu67775SzzdjxgzWrl0LwM6dOwseW716NQB9+/blj3/8IwAWi4UZM2YwePBg+vXrx8iRIwkICOCLL74gJiaGKVOmXLT7S0RERESkVN4ljfi7oIMqO4+8fBuZuRrxJ7VXaSP+nF1Vof5eVVpTXeBl8aBFfV8OJ6YTHZdW7oBq0+FkgBod6vh5Wbita2PmbzrGZ5uP1+ivRWqxwMbGNqWMHVTJR+DsMTBboHmfyqurumvQFm58Ga5/AQ6vhl2fw95vIeU4rPuvcQttC53uMm6hWs5DRKQ6cssr4CZNmrBhwwaaN29eJJxyuueee+jRowdt27a96PnWrl1bsAaW07p161i3bl3Bn50BFRhtc2vXruWll17i008/JTc3l86dO/P6668zYsSIS/yqREREREQcPP2NbXaasc1Jc3m8oIMqJ4/0nPyCp/kpoJJaqH6pAZXxWKi/Z5XWVFe0buhvBFTx5+jbpuzhTEpmLgfijQ7QHhH1Kqu8KnFX96bM33SMZXviyMzJx8fxAQGRaiPQ8an5so74c3ZPNekBXv6VU1NN4mGBNjcYt1syIXoZ7PwcDvwIiQdg9WTj1qgTRN1urGulsEpEpNqosFfAc+bMYc6cOWU+PjIy8qLHlCWcupRrA1x11VV8//335XqOiIiIiEiZePoZ29wMsNshJ8P4s9VYx+X8iL880hzj/TwtZjwtWjNBap+SRvxl5uQXdA86j5GK1aaRP8v2xBEdn1au520/fha7HSJCfGt8d9uVzYJpVt+H48mZLN8bx61dGru7JBFXzg6qso74i3F8GLvltZVTT01m9YGo24xbVirsW2p0Vh1eDXG7jNuqf0LDjkZQFXW70YklIiJuo1fAIiIiIiIVzerjuGOHvCwjqALwdARUns4OqnwyHAGVnz7VL7VUiL+zgyrb5fEkx589Pcwab1lJ2jQMAOBgXPkCqt9izgDQrUXN7p4CMJlMDHOEUku2l2ONH5GqEuhYgyo9HvKySz8WIPY3Y9usV+XVVBt4B0LXUXDfFzAhGoZNh9Y3GKMR43fDqn/Buz3hvT6w+nVI2O/uikVE6iQFVCIiIiIiFc3RKQUY3VM56Y7Hjc4qP6/za1A5O0h8rAqopHYK8TM6cJzj/JycI/9C/D0xaRH7StG6oTH+K9oxrq+stjoCqu61IKACuK2rEQD8fCCelIxcN1cjcgHf+mBxrBF37lTpx2aehaSDxv3G3Sq1rFrFtz50G30+rLrtXWh9oyOs2gOr/w3vXqWwSkTEDRRQiYiIiIhUNLPH+TebctOLdFD5OjuosvPJyrUB4K0OKqmlShrx5wysNN6v8rRq4I/JBGcycklMK0NnBmCz2dl+/CwA3ZrXjoCqbaMA2ocFkJtv58c9p91djogrk+n8mL+Ui6xDdXKbsa0XAX4hlVpWreVbH668D+77HP5yEG57D9oMArPVNax6tzesfg3i97m7YhGRWk0BlYiIiIhIZXB2UeVkFFmDyttq/BqenXd+DR5viwIqqZ2cAdTZjBzsdnvB487ASgFV5fHx9KBxkDFyNCYpvUzPiUnOIC07D2+rmbaNAiqzvCp1U6cwAJbviXNzJSLFcI75u1gHlXO8X5PulVtPXeFTD668F+79P/hLtGtYlbAXVk+G93rBu71g1WSI3+vuikVEah0FVCIiIiIilcEZUOVmGF1UAJ7GiD9vxzi/7FwbWc6AyqpfzaV2CvSxApCbbyc7z1bwuHNNqlB/L7fUVVc0q28EVMeSM8p0/J6TqQC0CwvEw1x7Ri/eGNUIgDXRiQU/d0WqDf+GxjYtvvTjnB1UGu9X8S4Mq27/H7QZ7Air9sHPr8F7vRVWiYhUML0KFhERERGpDJ6FAqoLOqi8LMav4Vl5+YUCKnVQSe3k5+mBM+dIzTy//o9G/FWNFvWNYDwmqYwB1akUAKLCAyutJneICg+kcZA3mbn5rD+U6O5yRFz5NTC26RcJqBIc4+Yadazceuo6n3rQ9R649zNjDODt70Pbm4qGVdOvglX/hrg9UKhDWEREyk4BlYiIiIhIZXAZ8efsoHKO+CvaQeWjgEpqKZPJVNBFlVI4oNKIvyrRPMT4uXOsrAGVo4MqKrz2jPcD4/vwBkcX1U97LhICiFQ1Z0CVllDyMblZkHzYuN+wQ+XXJAafYOg6Cu751DWs8vCExP3w8+vwvz7GulUr/6WwSkSknBRQiYiIiIhUhoIRf+nnR/xZnSP+CndQ2RyPKaCS2ivQ2wioUrPOB1TOsCrY1+qWmuqK5vUdAVVZR/ydcgRUjWtXBxXADR2MgGr53jiX9dBE3M454q+0DqqkaLDbwDsI/BtVTV3i6sKw6o4PoO3NjrDqAPzynwvCqt0Kq0RELsLi7gJERERERGqlghF/medH/Hk6R/wZYVRW7vkRf15ag0pqsSBHB1VqZl7BY85xf87wSipHC0cHVUwZAqrk9BziUo21wdqF1b6AqldkfXysHiScy2Z/3Dna18KvUWoovzKsQRXvGO/XoAOYas/6cDWWdxB0GWncslJg/w+w5ys4uPx8WPXLfyCkDXS8HaJuN0Yz6r+diIgLvQoWEREREakMzg6q7HOQn+147IIOqlybOqikTgj0MT4bWbiDKjUrz7FPAVVlcq5BlXAum8yc/FKPPZSQBkCTYB/8vWrf51m9LB5c1bI+AGujtQ6VVCP+zjWoShnxl7DX2DZsX/n1SPl4B0GXETBqoaOz6kNoN8TorEqKhl/egPevgek9YMWrcHqXOqtERBwUUImIiIiIVAZP401h0gu9CXpBB1V2Xj6ZWoNK6oCCEX+F1qA6l+XsoKp9QUh1EuRrLfg7vtiYv8OOgCqygV+l1+UufVuHArDuoAIqqUacHVTpCSUHF4kHjG0DBVTVmktYdQju/AjaDQUPL0g6CGumXBBW7VRYJSJ1mgIqEREREZHKYPUxtgWfhjaBxRs43y1ldFDlOx7Tr+ZSe51fg6qYEX/qoKp0zR1j/k6cuUhAlWisl9eqgX+l1+Qu1zgCqk1HksnJs7m5GhEHP0cHVX4OZJ0t/pgzR41t/ciqqEgqgncgXPEHGLXA6Ky6cwa0v+WCsKovvNMdVrwCp35XWCUidY5eBYuIiIiIVAbniL+MxPN/dqw7UDiMcr5J721RB5XUXgUj/hzf7zabnXPZRlgVoA6qShcWaATmp1KySj3ucIIRULUMrb0dVO3DAqjv50lGTj47Tpx1dzkiBqs3eAUZ99OKGfNnt8OZGON+vYgqK0sqkHcgXDEcRs4vGlYlH4I1b8IH18I73WD5ywqrRKTO0CsBEREREZHK4AyoMs8aW4tXwS6vQmHUWWdApRF/Uoud76Ayvt/TcvIK3ndz7pPKEx5kdG+evmhAVftH/JnNJq5uFcK3v59ibXQiPSPqV9m107Lz+HbHSVbvT+BIYjrZefnU9/OkS7NgBncMo1fL+pgcH2SQOsi/AWSnQHo8NGjrui/zDGSnGveDm1d9bVKxnGHVFcONtUoP/Ai7F8PB5ZB8GNZONW71IyHqduh4O4RdUfBBJxGR2kQBlYiIiIhIZbB4Gtvsc44/nw+orB4mzCaw2eFsRg4A3p4KqKT2co7xS83Mc2yNoMrTYlY4WwXCHAFVaR1Uefm2gjWqImvxiD+AXpFGQPVbzJkquZ7NZufjDUd5a0U0ZzNyXfYdTcpg67GzzF53lA7hgbwwtANXO8YQSh3j18AY+5ZeTAfVmSPG1j/s/AhhqR28AqDz3cbNGVbt+Qqif3INq+q1NIKqqNshvIvCKhGpNRRQiYiIiIhUBg9HIOX8xHOhgMpkMuFt9SAjJ/98B5VF07el9vJ1BLAZOc6Aytiqe6pqFHRQpWaWeEzs2Uxy8+14WcyEB3pXVWlu0aNFPQC2HTtDXr4Ni0fl/fxNy87jsflb+eWAETq0DPXjziubcEWzYHw9PTh5NpN1BxP59vdT7D2Vyj0zNnFvr+a8cEuUwtu6xsfRzefsvC7Muf6UxvvVbi5hVRoc+OF8WHXmCKydZtzqtYSo26DjHQqrRKTGU0AlIiIiIlIZLI43eJ0dVB5eLrudAZXzjXpPBVRSi/l6Gi8903PygfOj/pxrU0nlKksH1fFkI7xqVt8Xs7l2v9nZtlEAAV4WzmXnsT/uHB0bB1XKdVIychkzaxM7TqTgbTUzcUgH7u3VAo8L/n5v69qE52/uwLTlB/h4QwzzNx1j18lU5o7tSbCvZ6XUJtWQjxGckllMZ58CqrrHy981rIr+EXZ/dT6sWveWcasXcX4MYHhXhVUiUuPoVbCIiIiISGVwjvjLSnX9s4PVw3gDwdlR4lmJn+AXcTdfL9cOqnNZ6qCqSo2DjJFgp1OysDsX/7pA7FljvF+T4No/PszDbKJr82CAShvzl5tv40/zf2PHiRTq+Vr57JE+jOkTUSSccqrn58krt3ViztieBPta2XH8LCM+2EhiWnal1CfVkE+wsS0uoDp73Nhq/am6ycsfOt0FIz6BvxyEu2cbHVQWHyO8XPcWfNgf3u4KP70IJ7dBCT/rRUSqG70KFhERERGpDM6OqXzHm4sW15FZVkcgleHoKLEqoJJazM/RQZWRbXy/n3N0UAV4q4OqKjg7qDJy8kl1hIMXij1rdFc1qVf7AyqA7o4xf5UVUE3+bh/rDyXh5+nBgod6c0XT4DI9r3+7hnz2SB8aBnixP+4cD87dQqbj3wmp5UrroDp32tgGNq66eqR68vKHTnfCHz6G5w4VE1b91wir/tvFCKtityqsEpFqTa+CRUREREQqg8V1pN+FI/4u7JiyeGgki9Re59egynfZ+miNnSrhbfWgnq/RrXa6hDF/sWeMEX91oYMKoEcLY72fLUcrPqDacCiJWeuOADBtRFc6hAeW6/ltGwWw8OHeBZ1Uf/50W4mdb1KLlNZBde6UsQ0Ir7JypAbw9HMNq4bPMcb9WX3hbIwRVn00wAirlr0Asb8prBKRakcBlYiIiIhIZbgwoLpgxN+FgZRG/Elt5gyo0h0j/pwdIc7HpfI1CjS6qE6nlhBQ1aERfwBdmwdjMkHs2UwSzlXcGL2s3Hz++sXvAIy6qjmDOoZd0nlaNfDnozE98PQw8+PuOGavO1phNUo1VdBBdbboPmcHVUCjKitHahhPP+h4B/xhrjEGcPhc17Bq/dvw0fXw3ytg2SSFVSJSbehVsIiIiIhIZfC4MKAqfsRfwZ8t+tVcai8/L8eIv5x87Hb7+Q4qT434qyqh/sbPpKQS1jSKPevooKojI/78vSy0DPUDYPfJlAo778cbjnIsOYOwQG8mDml/WefqGVGfvw/tAMDk7/eyK7bi6pRqyBlQZZ11fTw/D9LjjfvqoJKy8PSDjrc7wqpDRljV8Q5HWHUM1r/jGladUFglIu6jV8EiIiIiIpXhgo4pPFz/fGFAZTFrxJ/UXs5OqXybnew8Gxm5eS6PS+UL8Td+BiWl5RTZl2+zF4z+a1xHOqgAOjUOAmD3ydQKOV9KZi7vrjoEwDOD2hLgbb3sc47p04JBUY3Izbfzl89/JzffdtnnlGqqpDWo0hPAbgOTGfwaVH1dUrN5+hph1fA5Rlj1h4+LhlUzroe3roAf/w4ntiisEpEqpYBKRERERKQyXNAxVbSDynTBn/WrudRevoU6pTJy8snINjqo/BRQVZkQP6ODKjG9aAdVwrlscvPteJhNNArwKrK/turUxFgbqqI6k+auP0pKZi5tG/lzV7emFXJOk8nE5Ds7E+xrZe+pVGatPVIh55VqqKSAyrn+lH8jMOtnplwGT1+Iuu2CsOpOsPpByjHYMB1mDIS3OiusEpEqo1fBIiIiIiKV4YKOqQs7qi4MpDw14k9qMQ+zCW+r8T2enp2nEX9uEBpg/AxKPFe0gyrOsS5VA38vLHUoLHd2UO2qgBF/2Xn5fLwhBoDxA1rjUYFdsSH+Xky82Rj1N235AU46xjFKLeMMqHIzILfQWnEF609d2npmIsUqCKtmG2tW/eET6HSXI6w6XjSsOr5ZYZWIVIq685uniIiIiEhVslzQhXDBmlQa8Sd1jbOLKjM3n0yN+KtyoY4OqqQSOqgAGgbWne4pgI6OgOp4ciYpGbmXda6vt58kMS2b8CBvhnSu+HWChvdoSs+IemTl2pj204EKP79UA54Bxhg/cF2HKs0RUPkroJJK4ukLUcPg7lnw3CEYMa9oWDXzBpjWCX6YCMd/BZvGjYpIxVBAJSIiIiJSGS4MqCwXBlQa8Sd1izOMcu2gUkBVVUpbgyohzQioGvjXrYAqyNdKs/rGmlu7L7OLav6mYwCM6RNRKT/PTSYTzw8xuqi+2HqC/afPVfg1xM3MZvAONu4XHvOXnmRs/UKrvCSpg6w+0OHWC8Kqu8HTH1JPwMZ3YeaN8FYn+OF5hVUictn0KlhEREREpDJ4XCyg0og/qVv8HB1U6dn5BQGVOqiqTqgjfEpKK7mDKrSOBVQAUeHGOlT7LiPwOZyQxvbjZ/Ewm7i7e8WsPVWcbs3rcXOnMGx2eOPH/ZV2HXEjrwBjm512/rHMZGPrW7/q65G6rSCsmmmMARwxHzoPd4RVsbDxPdew6tgmhVUiUm56FSwiIiIiUhk04k/EhbcjjMrKzSdTAVWVc3ZQJablYL9gHRFnQNUgoO4FVG0aGoFAdHzaRY4s2VfbYgG4tk1opf8dThjcDrMJlu+NY++p1Eq9lrhBQUBV6L9thiOg8lFAJW5k9YEOt8BdM+Avh2DkAkdYFXA+rJo1CKZ1hO//Bsc2KqwSkTJRQCUiIiIiUhkuMuLPcuGIP3VQSS3n5fgez86zkZFjrEHlY7W4s6Q6JcSxBlVOvo3UrDyXfYlpdTigauQPwMH4S+ugstvtfLX9JAB3XNmkwuoqSasG/tzsWOPqg58PVfr1pIp5Gt+P5BQKTDMcI/7UQSXVhdUb2g91hFUHHWHVH4yw6txJ2PQ/mDXYEVb9VWGViJRKr4JFRERERCrDRUb8eV7QQWU161dzqd3OB1TnO6j8vNRBVVV8PD3wsRp/3ykZuS776nIHVeuGRiBwIC6tSGdZWeyPO8ex5Ay8LGZujGpU0eUV60/XtQLgm99PcTw5o0quKVWktBF/6qCS6qggrPrIEVYthCtGgFegI6x63zWsitmgsEpEXOhVsIiIiIhIZfCwuv75ImtQWT004k9qNy+LEY5k59nIyNWIP3cI8jF+LqVkXhBQ1eEOqlYN/DGZjL+TxLSccj9/xd54AK5pHYqvZ9V0BHZqEsS1bULJt9mZs/5olVxTqoiXo4Mqu1BHn3PEn29I1dcjUh5Wb2g/BO780AirRi0qGlbNvgmmRcF3z0HMeoVVIqKASkRERESkUphMYPE+/2ePkkf8mUzgoTWopJbztjo6qHLzyXB0UPlU0Rv6YigxoHJ2UPnXvYDK2+pB8/q+AERfwpi/FXvjABjYoWGF1nUx465pCcDnv50gyxH4Si1QMOKv0Peis4NKI/6kJrF4QbubLwirRjrCqlPw6wcw+2aY2gG++4vCKpE6TAGViIiIiEhlKRxKlTLiz2o2YzIpoJLazdlBlZlrIyfP5nhML0mrUnEBVVahwLC+v6db6nK3Ng2d61ClXeRIV0lp2Ww7fhaA69tXbUDVr20DmgT7kJKZy9LfT1XptaUSeQUaW+eIv/w8yEox7mvEn9RUBWHVB46w6lPoMgq8giDtNPz6oWtYdXQd2BS8i9QVejUgIiIiIlJZCo/583B947fwiD+N95O6wMvRQZWWfT4cUUBVtQKLCaic9z3MJgK86mZHW+uGxro/0XHlC6g2HE7Cbof2YQGEB/lURmkl8jCbuKdXcwDmb4qp0mtLJbpwxF/W2fP7fOpVeTkiFc7iBe1ugjveh79Ewz2fFQ2r5gwxwqqlE+DoWoVVIrVc3fztU0RERESkKlhK7qAqPOLPqjfppQ5whlGpmXkFj3nqe79KFddBdTYjt2BfXe3kjAz1A+BoUnq5nvfrEWP0Wu9I96wNNLxHU6b9dICtx86y73Qq7cMC3VKHVKCCEX+OsDQjydh6BYGH3sKTWsbiBW0HG7e8bDi8GnZ/BfuWQlocbP7IuPk3gvZDocMwiOhbdJ1XEanR9K+biIiIiEhlKdw1dcGLaUuhNacsZr1JL7Wfc8Rfatb5cKTwqEupfMUHVDkABPvU3Tf8mocYa1AdS84o1/OcAVWvlu4ZvdYwwJvr2zdk2Z44lmw/SfubFFDVeF5GN1/BiL+C8X5B7qlHpKq4hFU5Rli15yvY960RVm2ZZdx86kG7IUZY1WpAkQ+AiUjNo1cDIiIiIiKVpfCL5gtG/HkUCqU8NeJP6gBnB9W5LKODytOitdeqWrEBleN+kG/dDahaOAKq2DOZ5OXbyvScsxk57DttjGHr6aaACuC2rk0A+Hr7SWw2u9vqkApSEFClOraOUX9eCh+lDrF4QttBcPt7MOEg3PsFdBsDviGQeQa2z4eFI+A/reDzcUbXVXb5RrSKSPWhDioRERERkcri0kF1YUBV6L4CKqkDnGtQnXN0UHmpe6rKBfkYbwGkFl6DyjHiry53UDUK8MbTYiYnz8bJs1kFHVWl2Xz0DACtGvgR6u++T/AP7NAQfy8LsWcz+e3YGXpGuC8skwpw4Yi/goAqwD31iLibxRPa3GDchk6DYxtg7zfG7dxJ2PWFcbN4Q+sboMOt0PYm8Al2d+UiUkZ6RSAiIiIiUllcOqhc3/w1mzTiT+oW54i/wh1UUrWcXVKuHVSOEX++nsU+py4wm000q+cDQExy2dah2n7cCKjcHQh5Wz0Y3DEMgK+2xbq1FqkAF474U0Alcp6HBVpeC0P+A0/vhgeXw9VPQr0IyMsyxgEufgTeaAWf3Am/zYH0RHdXLVJrffnll9x4443Ur18fk8nE0aNHL+k8ekUgIiIiIlJZPEoe8Vd4DSpNOZO6wDniz9m946WAqsoVvwZVrsu+uqpFiB9Q9nWodsYaI9g6NXH/2kC3dW0MwA+7TpOvMX81m6fxfUiOIyh1dlI5O6tExGA2Q7OeMOhVeHI7PLoW+j0HDdqDLQ8OrYBvnoIpbWD2UNj0AaQoxBepSOnp6fTr149XXnnlss6jEX8iIiIiIpXFUnjEn+sIKI9CAZWHEiqpA7yt6qByt1LXoKrjAVXz+sZYv2NJFw+o7HY7u2NTgOoRUPVpFUKAt4Wk9By2Hz9D9xYa81djWY1OPvIyja06qEQuzmSCsM7G7fq/Q8IB2Pu1MQbw1HaIWWvcvn8OmvSAqGHGKMD6ke6uXKRGGz16NAC7du26rPPoFYGIiIiISGVxWYPK9c1fj0Jj/QqHVSK1lbNj6ly2Aip38fcyfg6lO/4bQKE1qHzrdkDVwrHuVEwZAqrTqVkkpefgYTbRPsz9wYHVw8yAdg0BWLYnzs3VyGVxBlS5CqhELlmDttBvAjzyMzz1Owz+NzTrDZggdgv89CK8fSX8ry+sfh3i94Jd3adSM8ybN49HHnmEHj164OXlhclkYs6cOaU+Z/PmzQwZMoTg4GD8/Pzo3bs3n332WdUUXAbqoBIRERERqSwuAZXriD+PQu/Nm9VBJXXAhYGUc00qqTq+nsbfeVrhgCpTARVAs3pGQBV7NvOix+5yjPdr09C/oDPQ3W6MasTXO06yfE8cz9/cwd3lyKWyODuossBmU0AlcrnqtYA+443budPGOlV7voajayFup3Fb/W8IaWN0VUUNg/Cumr8t1dakSZOIiYkhNDSU8PBwYmJiSj1+1apVDB48GG9vb0aOHElAQABffPEFI0aM4Pjx4zz77LNVVHnJ9JE1EREREZHKUrhr6oKAqnAopQ4qqQs8PVxffqqDqur5exmfUc3Os5GXbwPOB1SB3nU7oAoP9gbgVEpZAqrqM97P6bp2DbB6mDiUkM7hhDR3lyOXytlBBUZIpYBKpOIEhEHPP8L9X8NfDsJt70Lbm4zf0ZOiYe1U+LA/vHUF/DARYjYYQbFINTJjxgyOHj1KQkICjz76aKnH5uXl8dBDD2E2m/nll1/48MMPefPNN9mxYwdt27Zl4sSJRQKuv/3tb5hMplJvFU0dVCIiIiIilabQL/AXjPizeJzfp3xK6oLC3/NQNLCSyufrdb7bJz0nnyAfc8G4Pz+vuv32QHiQEQwkpuWQnZdfaoffwXgjAGrXqPqEBoHeVnq1DGHtwURW7osnsoG/u0uSS3FhQJXjCBsVUIlULN/6cOV9xi0rFaKXGetWRf8EKcdg47vGzb8RtL/F6K6K6Fvk93mRqnbDDTeU+diVK1dy6NAhxo4dS9euXQseDwoKYuLEiTzwwAPMnTuXF198sWDfs88+ywMPPFCBFV9c3f4NVERERESkMhX+hFkpHVRmJVRSB1jMF4z4syqgqmpeFg+sHiZy8+1k5OQR5GMlPccIqPzreEBVz9eKl8VMdp6NuJRsmjvWpCrOIUeHUuuG1SsE6tc2lLUHE1l/KIk/Xhvp7nLkUpg9wGwFWy7kZpzvoPKsXt9rIrWKdyB0vtu45WTAoZVGWLX/B0iLgy0zjZtPPWg3BDoMg8j+YPV2d+VSS5w7d47U1NSCP3t5eeHl5XXZ5129ejUAgwYNKrJv8ODBAPz8888ujzdo0IAGDRpc9rXLQ68IREREREQqTckBVeGxfh6acy91gFUdVNWCs1PK2TmVnp3v8nhdZTKZCA+6+Ji/fJudw4npALSqZl1K17QOBWDj4SRy8zWWqsayOsLR3CzIdrxhqQ4qkarh6QsdboE7PzTGAN77BXQbA74hkHkGts+HhSPgjdbw+TjY/RVka6yqXJ6oqCiCgoIKbpMnT66Q80ZHRwPQpk2bIvvCwsLw9/cvOOZSJCcns337dvbv3w/Anj172L59O8nJyeU6T93+DVREREREpDIVDp4u6B6xmNVBJXWLRWtQVQt+nhbOZuSSlp2P3W4v6KDy8yp5pF1dER7kw9GkDE6lZJV4TOyZTHLybHhazDSp51Pice7QISyQ+n6eJKfnsP34WXpG1Hd3SXIprN6QnWJ0UOVkGI95+rm3JpG6yOIJbW4wbkOnwbENsPcb43buJOz6wrhZvKH1DcYYwLY3gU+wuyuXGmbPnj00adKk4M8V0T0FkJJirJkZFFT8mpmBgYEFx1yKr7/+mrFjxxb8eejQoQDMnj27XGMCFVCJiIiIiLhB4RF/6qCSusByQRBb2ho/UnmcQVRGdh4ZOfnY7cbjdX3EH1Cog6rkgMo53i8y1M+lE7Y6MJtNXN0qhG9/P8Xa6EQFVDWVcx2qvCzjBsYb4CLiPh4WaHmtcbvpNYj9zRgDuPdrOHMU9n1r3MwWaHkdRA0z1q7yC3V35VIDBAQEEBgY6O4yyu2BBx6okPWq9JE1ERERERE38HDpoHJjISJVxHpBB9WFI/+kavh6GkFUWnZewZg/kwl8rAoMw4MvPuLPGVBVt/F+Tn0dY/7WHkx0cyVyySyOgCo3A3Id34vW6tWtJ1Knmc3QrCcMehWe3A6ProV+z0GD9mDLg0Mr4JunYEobyw4S2gABAABJREFUmD0UNn0AKbHurlrqIGfnVEldUqmpqSV2V1UlfURKRERERKTSlPwGvEtApQ4qqQMsFwRSF/5ZqoazUyojJ5/0HMf6U54WTPo5RFiQEQKcPFtaB5Wx/lRkg+o5cu3qVkZA9fuJs2Tl5uOt4LHmcYZRuYU6qBRQiVRPJhOEdTZu1/8dEg44Oqu+gVPbIWatcfv+OWjSw+is6nAr1I90d+VSBzjXnoqOjqZ79+4u+06fPk1aWhpXXXWVO0pzoc9qioiIiIhUlvycEncVDqiq25gokcpgLbIOm16OuoOvpxFYFO6g0vpThkYBxpoPCWnZJR5z4oyxJlDz+r5VUlN5NavvQ8MAL3Lz7ew4ftbd5cilsBbTQWVRQCVSIzRoC/0mwCM/w1O/w+B/Q7PegAlit8BPL8LbV8L/+sLq1yF+LwWzdkUq2HXXXQfAsmXLiuz78ccfXY5xJ70iEBERERGpLHklv8npoTWopI65sGNKwax7nO+gyiOtIKDScBWAUEdAlXiu5J/dsWeMwKBpveoZUJlMJrq3qAfAb8fOuLkauSTOgCorBXC8cW3VGlQiNU69FtBnPDz4Izy7D4a+aaxPZfKAuJ2w+t/wXm+Y3hOWvwwntymskgo1cOBAIiMjWbBgAdu3by94PCUlhX//+994enoyZswY9xXooN9CRUREREQqS78JcHA5dCv6i3/hN+c1WkvqgiIj/hRQuYWvl7ODKr+gg8pfARUAoX6OgCotG7vdXuRns81m58RZZ0BVfTtaureox/e7TvPbUQVUNZLFEUZlJhd6rPp+v4lIGQSEQc8/GreMZNj/nTEG8NBKSIqGtVONW1BzYwRgh1uhWS8tVCtFzJgxg7Vr1wKwc+fOgsdWr14NQN++ffnjH/8IgMViYcaMGQwePJh+/foxcuRIAgIC+OKLL4iJiWHKlClERES448twod9CRUREREQqS5Nu8JeD4OlfZJfriL+qLErEPYqM+NMaVG7hbTECquzc/PMdVJ56awAgNMATgOw8G+k5+UWCu8T0bHLybJhNEBZUfTtaCndQFRe0STXn7KDKdASMJg/wsLqvHhGpWL714cr7jFtWKkQvM9ativ4JUo7BxneNm38jaD/UCKsirtXPgVpswIABWK1Wxo8fz/jx40s9du3atcydO9flsXXr1rFu3bqCPzsDKue5165dy0svvcSnn35Kbm4unTt35vXXX2fEiBEV+4VcIv0WKiIiIiJSmbwDi33YrDWopI4pOuJPyaw7eFuNgCorN5+MnHxAa1A5+Xpa8PX0ICMnn8Rz2UUCqhOO8X7hQT5Yq/EnCzo2DsLLYuZsRi6HEtJp3bDohySkGvMwOvnISjW2Vh9QyChSO3kHQue7jVtOhtFRtfdr2P8DpMXBllnGzTsI2t4MHW6BVgPBs3qOmZVLs2rVKpo2bVqmY+fMmcOcOXPKdf6rrrqK77///hIqqxoKqERERERE3KDweDOz3niSOuDCN/Q14s89vK3Gf4esXBuZjoDKRx1UBUL8PclIziQpPZuIUD+Xfc6Aqklw9R635mkx06VpML8eTWbbsTMKqGoaZ5dEVoqxtVTfbj0RqUCevkYA1eEWyMuBI784wqrvID0Bfl9k3Cw+0Hqg0VnVdjD41HN35SKXRb+FioiIiIi4gQl1UEndcmEgpe979yjooMrLJyvPCKi8LNW3G6iqhfp7cTw5k4RzOUX2xZ6p/utPOXVuGsSvR5PZfTKV4e4uRsrH4uygcgRU1ur//SYiFcziCW1uMG62aXB8E+z9FvZ9A2ePwb5vjZvZYoz/63CrMQ4wIMzdlYuUmwIqERERERE3KNw0pQ4qqQsuDKTUQeUeXoVG/GXn2oDzXVUCIX5GOJCYll1k3+kUI6CqzutPOXVqYoyX3X0yxc2VSLmpg0pECjN7QIurjdvgf8Hp342wau83kLAXDq8ybkufhaY9jbCqwy1QP9LdlYuUiQIqERERERE3KBxKKaCSusBkMmH1MJGbbwfUQeUu3o5uqew8W0EHlbdFa1A5NQjwBCAprWgHVYIjtGoY4FWlNV2KTo2DANh9MpV8m13/v9UkHsb3INmF1qASEQHjE27hXYzb9X+HpENGULX3G4jdAid+NW4/vQCNOkF7x8jARp20lp1UWwqoRERERETcoPBrRA81L0gdYfUwk5tvhCLqoHIP72I6qLzUQVWgtA6q+FRHQBVY/TtaIhv44201k5GTz5HEdK1DVZN4aMSfiJRRSCvo+2fjlnoS9i01wqqjayFul3H7+TWoF+EIq4YZXVZm/bsv1YcCKhERERERNyjcNaVPtktdUTiU8lAy6xbnAyob2eqgKiLE39FBlV40oHJ2UDWoAR1UHmYTUeGBbD12lt0nUxRQ1SQa8ScilyKwMVz1kHHLSIYDPxijAA+tgDNHYcN04+bfCNoNMUYBRlxrrHcl4kYKqERERERE3MCsNaikDrIWCqXUQeUezvWmsnLzySpYg0oBlVOwrxEOpGTmujxut9sLOqga+Ff/gAqgU5Mgth47y67YFG7r2qRCz51vs7Ns92mW7Ykj9mwmgd4WekbU545uTWgYoEDlslgc31/5jjGT6qASkfLyrQ9d7zFuOelwcLnRWXXgR0iLg99mGzevIGh3k9Fd1XogePq5u3KpgxRQiYiIiIi4gUlrUEkdZPE4/72ugMo9vBzdUtl5NrJyjQ4qjfg7L9jH+CT5hQFVek4+mY6/r5rQQQXQsXEgAHtPnavQ8x5JTOeJhVvZFZvq8vjyvfG8tTyap25ow8PXRmLW/+OXxsOz9D+LiJSHpx9E3Wbc8nLg6C9GWLXvO0iPh98/NW4WHyOkan+LEVr51HN35XXGgAEDsFqtjB8/nvHjx7u7nCqngEpERERExA1c16DSm3hSN1gKrXlQOKySqlO4gyo7z9FBpRF/BQJ9iu+gSjhndE/5eXrg51Uz3kpp0ygAgIPxaRV2zl2xKdw3cxNnM3IJ8LZwT6/mdGocRMK5bJZsj2XHiRRe+34fW2PO8PaoK/+fvf+Oj+Ow74T/z9Rt6ABBgATBXkSxqFBUL4xly5adnFssuZxb7EQxn5zPJU8cJRc/d77Y8V3sPPnd6ZzIOls+W3LsPI7tyLZKbFHV6hIp9k4CLABRF8D2Kb8/ZmZ3drEAFsDuzpbPW6957e7M7OwsCJGL+eD7/bI6byGmBVSKN+dBRLVHVoF1t1vLO78J9L8MHPmFFViNn7XuH/kFIMrAqpusNoAb3wk0dXt95jVtz5496Onp8fo0PFMdn6qIiIiIiGqMyAoqqkPZwSyrdryQmUGls4Iqj2YnoIpmB1SXJuIAqqd6CkB67tTARBzhWCr93hbqYjiGTz74CsajKWxf0YJv//ur0dmUaef3iRtX4Z9e6ceXf34QTxwaxO6HXse3PnI1VJnfX/PCCioiKgdRAlZeby1v+6/AwH47rPoFcOkgcOopa/nlF4Cea6ywatO7gPa1Xp851Rh+SiAiIiIi8oCYdaHeu/MgKid3tSBb/HnDqZaKp1wt/lhBlebMoJqIa9ANM71+aMqqoKqm+UpNfgVddoC02Coq0zTxp//8Ji5NJrBxaSO+/wc7s8IpwGpd+8GdvXjwk9fAJ4v4zZFL+OqvDi/qdetSbsUUK6iIqNQEAejeBuy6F/jMb4E/eR1463+xgikAOPcK8G9/BfyPq4D/dQOw56vAxTcB05z9uEQF4I/CREREREQeEOCqoOKFeqoTkquEiq0tvZFu8ae5WvyxgirNXWU0Gc9UUY1GkgCA1lB1hQXrl1pVVCcuLW4O1f/32jk8d2IYPlnEtz5yFZr8M38dbljbgf/5oasAAA/+9gwe2XdhUa9dd+ScKj2xur7niKgGtK8Fbvws8KlfA58/ArzzG8Ca26zWf5cOAk9/HfjHm4G/3w48/hfA2RcAQ/f6rKlK8VMoEREREZEH3K3O2OKP6oXICirPOe3WTBOIJq2LSZwTlKFIIoKq9fVwz6Eat1v+tQarq93a+k5rDtWxwYVXUMVTOr7xxDEAwOfeugFrljTM+Zy3bl6Kz9xmtYH6Tz8/gBG7Ao0KwBZ/RFRJmrqBaz4FfPTnwBePA+/+B6vVnxyw5la98D+B774d+MYm4JHPAid+DWhJr8+aqggDKiIiIiIiD7gv1EsMqKhOZLe25Pe9FxRXT1GnQogBVbYWu4pqPDo9oGqptoDKrqA6vogWfz98uQ8DE3F0N/vx8RtWFfy8z711Ay7rbsJ4NIWv/OLQgl+/7kwLqFhBRUQVItgGXPFB4O6HgP/7JPCB7wPb7gJ8zUDkEvDag8AP3gf893XATz4FHPo5kIx4fdZU4RhQERERERF5wH1pnvkU1Qt3taAs8sdRL2QHVBoAwCfzz8KtyQ6osiqoYtZvgzszqqrFBiegGlxYiz/dMPHAs6cBALt3rZtXmKlIIv7mvVshCsDP9l7Aa2fHFnQOdYcBFRFVAzUEbP494L33A396AvjIvwA7Pgk0LAUSYWD/PwM//ijw39YAP/wQMLDf6zOmCsVPoUREREREHnBfqGc+RfXCXTXFCipvKFLm656ZQcUKKrfmPAFV2KmgClRXWLCmwwqoLobjiKfmPx/kmWNDOD8eQ3NAwfuv7pn387evaMHvX70CAPD1x47ANM15H6PusMUfEVUbWQXWvQV4199ZM6s++QRw/f8FtK4CtDhw9JeAwM8alB8DKiIiIiIiD4gsoaI65A6lZInf914QBGHa/C9WUGVzAqrxrAoqp8VfdQVULUEFjX4ZANA3Gp338x966SwA4P1X9yw4yPzs7euhyiJePj2KZ44PL+gYdUXM+TqLsjfnQUS0EKII9F4L3PHXwH/YC9zzHPC2vwY6L/P6zKhC8VMoEREREZEXhLx3iWpados/fud7xd3mL9/jetfot0KoKbsFIgCMRa0Wf82B6qpmEQQBK9uDAICzI/MLqMLRFJ46OgQA+ODOFQs+h2UtAXzk2pUAgH946uSCj1M3cgMpVlARUbUSBKBrK3DD/8VfyKMZ8VMoEREREZEHslr88ec1qhPuTIozqLyj5FSv5T6udw0+q4IlksgEVOkWf1VWQQUAK9tCAICzI/MbVP9vhwehGSY2Lm3Eus7GRZ3Dp25eDVkU8MKpEew/F17UsWoeAyoiorqya9cubN68Gffdd5/Xp+IJ/kRAREREROQBd0DFkRxUL9wt/phPeUfNaekns4IqS8hnBQSRpBVQmaaZbvHXGqy+sKB3gRVUjx24CAB4+5auRZ/DspYA3rWtGwBw/7OnFn28mjYtoGKLPyKiWrZnzx4cOnQIu3fv9vpUPMFPoUREREREHnBXknBoPNULdzArsnTQM9Nb/PHPwi0dUNkVVFMJDbph/T1dnRVUdkA1jxlUsaSenhf1jq2LD6gA4FM3rwFgBV8jU4miHLMm5c6gYgUVERHVMAZUREREREQeEDh5iupQVgUVAyrPTAuoWM6WpSEdUOkAgHG7vZ8qi/Ar0ozPq1ROBVXfPFr8vXp2FEnNQHezHxuXLq69n2PL8mZs62lGSjfx0zfOF+WYNSk3oBKrLxQlIiIqFD+FEhERERF5QHB9Emf9FNWL7AoqD0+kzrkrpkQBEPmHkcWpoJqyK6gm49Ztk786g4JV7dYMqnNjMWi6UdBznj8xAgC4YW0HhCKGyXddswIA8E+v9LN6eCbTWvxV5/cdERFRIRhQERERERF5gDOoqB65gxAWUHnHXUHF+VPTNfisChanxZ8zi8pZX226mvxQJAGaYWJwsrDWer89abX3u3Fde1HP5fe2L0NAkXDi0hTe6B8v6rFrxrSAii3+iIiodvGTKBERERGRB3htnuqRe9RRMasyaH7cAZXKgGqaoJpdQeXcOpVV1UYUBSxt8gMALo7H5tx/Mp7C/vNhAFYFVTE1+hW87fKlAIBfvnmxqMeuGaygIiKiOsJPokREREREHsiqoGKTP6oTnEFVGdwt/mSJfw65nCDKqZyKVHlABQDdzVZAdSEcn3Pf/efCME2gpzWALvt5xfSubcsAWAGVYfDfv2k4g4qIiOoIAyoiIiIiIg+4r82zxR/VC4EzqCpCVos/kZcFcjU4AVVCt2+1rPXVqLs5AAAYCM9dQbX33DgAYPuKlpKcy83rO9DokzEwEcfrfWMleY2qNq2Cqnq/74iIiObCT6JERERERB5g8QjVI8n1jS+w0aVnVDlzKUBhBdU0IXvWVKbFn26vr96goLvFrqAan7uC6s1+q73f9p7mkpyLX5Hw1s1Wm79fsM3fdLkBlVCds8+IiIgKwYCKiIiIiMgD2S3+iOqDu8UfQ1rvuCuoFM6gmsaplEpqBlK64aqgqt6gYJldQXWxgAqqfU4FVU9Lyc7nHVu7AQC/PjwIk2XE2aYFVPx/lIiIahf/lSMiIiIi8kBWQMVrc1QnRPcMKvb48wxnUM3OXSkVSWjpgCqoVnEFlT1L6uIcM6iGJhO4GI5DFIAty0tTQQUAN65rhyqJODcWw8mhqZK9TlXKrZjKnUlFREQ1ZdeuXdi8eTPuu+8+r0/FE9X76YqIiIiIqIqJWTOomFBRfXB/3zMW8Y7srqDiDKppFEmEJArQDRPxlJFu9VfVLf7SFVSzB1THBicBACvbQyV9v0FVxrVr2vDs8WHsOTKEdZ2NJXutqiOKVtWUaViP2eKPiKim7dmzBz09PV6fhmf4SZSIiIiIyAMC+5tRHXLPoBL5/4Bn3H8OrKDKz2/P6Yqn9Jpo8efMoBqeSiCpGTPud9wOqNZ1NpT8nHZt7AQA7Dl6qeSvVXXcbf4YIhMRUQ3jv3JERERERB5j/RTVi6wWf8xFPCO7vvicQZWfX7HCqLimYyqhA6juCqr2kApVEmGawODEzFVUxy5Z7fY2LC1DQLXJCqhePj2arlIjmzugYgUVERHVMH4SJSIiIiLyGK/TU71wV+6witA7YlZAxT+HfNIBVcpwVVBVb0AlCAI6GlQAVhXVTJwKqvVlaLm3uiOEntYANMPEq2dGS/56VcUdSnEGFRER1TAGVEREREREXuP1YaoT7k5VzKe8466gktk+LC+fkmnxF01ZFVRBtXoDKgDoaPQBAIanknm3m6aJY4NWBdX6MlRQAcB1a9oBAC+eYkCVxR1KsYKKiIhqGD+JEhERERERUZlwBlUlcFdQcQZVfn7ZqaDSkbADqoBS3UFBR4MTUOWvoBqJJBGOpSAIwNol5Q2oXjo9UpbXqxpZM6iq+/uOiIhoNgyoiIiIiIiIqEwyE9c4g8o7nEE1N3+6gspA3A6onKqqapVu8TeZP6A6OxIFAHQ3+dMtDkvt2tVtAIA3z4XTrRQJOTOoqvv7joiIaDb8V46IiIiIyGMCe/xRHeL3vXekrBZ//HPIxwloEpqOeMqw1snVXckyVwXVuTEroOppC5btnFa0BbG8JQDdMPHa2bGyvW7FY0BFRER1gv/KERERERERUdnxmqt3JFd7RYkBVV5OQBVP6Uhour2uur9pMwFV/hlU/aNWQLWitXwBFeCeQ8U2f2nu2XBs8UdERDWsuj9dERERERHVAI7ioXphZjr8cQaVhySJs8Dmkt3iz66gqvYZVI1WQDU0QwVV/2gMALCiLVC2cwKAa1a1AgD29o+X9XUrmiDlv09ERFRjGFARERERERFR2bFwxzvuCiqRfxB5Oe384ikdcbuCyidX9yWU9AyqmQKqMW8qqLavaAFgzaEyDHP2neuFwAoqIiKqD9X96YqIiIiIiIiqEmdQecc9d4r5VH5+1QoFJuNauvLPV+UVVEucFn+TswdUve3lDajWdzYgoEiYSmg4NTxV1teuWO6AihVURERUwxhQERERERERUVm4W/yxs5x33FVTEv8g8nIqqMKxVGZdjcygmohrSGpG1jbdMHFhPA4A6Gktb4s/WRKxdXkzAGBvf7isr12x3P9fsoKKiIhqWHV/uiIiIiIiqgG8PEz1iLOPvOOuoBL455CXE0aN2wGVIACqVN2XUJoDSjr3cAdvADAylYBumBAFoLPRX/Zz277CCqj2cQ7VdEJ1f98REdHsdu3ahc2bN+O+++7z+lQ8IXt9AkRERERERFR/mIt4RxJF130PT6SC+e12fuPRJABr/lS1h3miKKDJryAcSyEcS2JJoy+97ZLd9q+jwQfJg76PzhyqfefGy/7aFY8BFRFRTduzZw96enq8Pg3P8F85IiIiIiKPVfk1T6KCmcj0+GMFlXfcoRT/HPLzydYXacKuNPJX+fwpR0tQAQCMR7MrqAYnrPZ+nU2+ac8ph23LWwAARy5OIqUbs+9cD9z9UNnij4iIahgDKiIiIiIiIio7D4o0yOauoBL5B5GXYqd4k3ENQGYmVbVrCeQPqJwKqqUetPcDrLlXDT4ZSd3A6eGIJ+dQWdwD+2rje4+IiCgfBlRERERERB4TOIWK6oS7KKDa26VVM8n1pWc+lZ/iVFDZAZVPqY3LJ81BFUBmtpbD6woqURSwsasRAHBkYNKTc6gorKAiIqI6URufsIiIiIiIiIioIJKrx5/EoDAv1U7xphJ2i78aqaBqTldQJbPWOxVUnR5VUAHIBFQXJzw7h8rBCioiIqoPDKiIiIiIiDzG68NEVE7uUIqVbPk5Lf7iKWsekr9GKqicFn/hnAqqSxN2QOVRBRUAXGYHVEdZQcUKKiIiqhu18QmLiIiIiIiIKp459y5UBrKrr5/EHn95KVL25RJfjVRQtQRnCKgmrRZ/Xs2gAoCNXU0A2OJvGobIRERUwxhQEREREREREdUR9/VuBlT55QZUqlwbl08yLf6yA6qRKavlX0ejdxVUTou/8+OxaQEaERER1aba+IRFRERERERERAURs1r8eXgiFUyVs78wilQbX6iWoAoAGM8JgMbsmVStdoWVF5oDCpY1WxVcxwbrvYqK9aZERFQfGFARERERERFRWZi85loRRNeVAIkJVV65FVSyVBuXTzIVVMn0uoSmI5rUAWQCLK+sW2pVUZ0amvL0PDzHvyyJiKhO1MYnLCIiIiKiKsbLw0RUTu4KKpEBVV6ymH25pFYqqBr9MgBgKq6l1znt/iRRQJO93StrOkIAgFPDEU/Pw3sMqIiIqD4woCIiIiIiIiKqI4I7oOIMqrxyW/zlBlbVqsFnB1SJTEDltPdrDihZ3xteWG0HVKeH6jygYj5FRER1ojY+YRERERERVTNWMBBRGbkzKeZT+eW2+Mt9XK2cgCqSmF5B1eLh/ClHOqCq+woqIiKi+uBt7TYRERERERHVDZNlARXB3daPM6jymx5Q1cbXKeQEVEkdhmFCFIX0PKpWj+dPAZmA6uxIFLphQlpkghpL6ni9bwxJzcDly5vQ2egvxmkSERFRkTCgIiIiIiLyWG1c9iSiapFVQcUSqrxyAyq5RgIqp4IKACJJDY1+BWN2BVVrBVRQLWsJQJVFJDUDF8ZjWNEWXNBxTNPEj17px988diRrxtb7rlqOv3zXZjT5vX+vs6qNbzciIqI51UaNOhEREREREREVJGsGFSuo8lJzA6oamUHlV8R0VVIkoQPIzKBqqYAKKkkUsKrdCqVOLaLN3/948gS+9C/7MR5NoavJj41LG6EbJn786jm8/1u/xfBUolinXCL8/5KIiOpDbXzCIiIiIiKqYrw+THWDHf4qgpgVUHl4IhVMkbO/MLXS4k8QBIRUCQAwZc+hSs+gClRGVdGajgYAwKmhqQU9//GDA/jmvx0DAHz+rRvw/Jd+B49/7hb88z3XY2mTD8cGp/CJ776CeEov2jkXncDLdURE9WLXrl3YvHkz7rvvPq9PxRP8F4+IiIiIiIjKgvlUZXCHUoud8VOrps+gqp3LJ06bv0g6oLJnUIW8r6ACgF67gurcWGzez52Mp/BXPz8AAPjUTavxH96yPv09fs2qNvzTH16P1qCC/efD+OtfHi7eSRcbf3OFiKhu7NmzB4cOHcLu3bu9PhVP1M4nLCIiIiIiIiKak/vaN1v85Td9BlXtXD4J5QRUTiVVo78yxpQvbwkAAM4vIKD6znNnMDiRwKr2IL54x8Zp21d3hPB3d10BAPj+i2fx6pnRRZ1rybCCioiI6gT/xSMiIiIiIqKyYBRSGQS2+JtTbks/pYa+UA12EDVpB1OTces2pFZWQHVuPDqv500lNHzn+dMAgC+8bSP8ipR3v9s2duKuHSsAAP/p5wehG5VY21k7329ERESzYUBFREREROQxgReiqF7wW70iuKum2OIvv9yvSy1VUOW2+HNuGyqlgqp1YRVUj+y7gHAshdUdIdy5tXvWff/sHZvQ5Jdx+OIEHj1wccHnSkRERItTO5+wiIiIiIiIiGhO7uxFYIu/vGQxdwZV7XydnEqp3BZ/TnDlNSegGoum0udYiH9+tR8AcPc1K+YMXttCKj5502oAwP/4zQkYFVlFRUREVPsYUBERERERERHVEVZQzS33y5I7k6qaOTOophI6ACBi34YqJKBq8itosqu5zo8XVkV1ejiC1/vGIYkC3nPV8oKe84kbVqPRJ+Po4CSePja04PMtCQbHRERUJ2rnExYRERERUZXidSiqF2xnWRncf+cwn8pPEISs8E6uoQqqBp81m6lSK6gAYHlrEEDhbf5+fWgQAHDD2nZ0NvoLek5zUMFd11izqB566ewCzpKIiIgWiwEVERERERERlQXD2MrgrqAS+YcyI3dApYi1c/nEr1oBVTylwzTNigyoeuw2f+cKrKB68sglAMDvbOqc1+t86Nre9PMLrdYiIiKi4qmdT1hERERERFWKl4eJqJwYUBVGEmqzgsovWwFVLKUjoRnQ7flLIbuyqhIsb7ECqkIqqCbjKbxyZhQAsGvj/AKqNUsacOO6dhgm8ONX+ud/okRERLQoDKiIiIiIiIioLGrnEn91c7f14wyqmcnuCqoamkHlV5wKKgOTcS29PqRWTgXV0iarTd+lyfic+75wcgSaYWJVexCrOkLzfq3fv9pq8/fImxdgmua8n18a/P+SiIjqQ+18wiqxv//7v8fKlSvh9/tx0003Yd++fV6fEhERERHVCBYwEFE5Ca6/dPj3z8zErICqdr5QAcW6FBTX9PQcqpAqZb1fry1t8gEALk0k5tz3tb4xAMB1a9oX9Fq3b14Knyzi1FAEhy5OLOgYREREtDAMqArw8MMP48/+7M/wla98Ba+99hrWrVuHO+64AxMT/OBCRERERERE1cWdQ7DF38zcFVS19HVKV1Al9cz8KX/lVE8BQGejVUE1ODF3BdUbZ8cBAFetbF3QazX45PTsqkf2XVzQMYiIiGhhGFAV4O/+7u9wzz334KMf/Sguv/xyPPDAA9A0DQ8//LDXp0ZERERERFQ1augaf1XjDKrCuNsf1lIrxHRApWUCqpCvsgKqdAXV5OwVVEnNwL5z4wCAqxcYUAHAu7YtAwA8doABFRERUTnNO6CKx+P4/Oc/j1tuuQXLli2D3+9HV1cXbrzxRnz3u99FKpUqxXnO6gc/+AH+6I/+CDt27IDP54MgCHjwwQdnfc4rr7yCO++8Ey0tLQiFQrjuuuvw4x//eNp+yWQSb7zxBm6//fb0OlmWcdttt+GFF14o9lshIiIiIiIiKimGUoWRar2CKmUgltIBAAF7XaVwKqjCsRTi9jnmc/jiBBKagZaggjULmD/luGVDBxRJwJmRKE4PRxZ8HCIiIpqfeQdUU1NT+Na3vgVBEPDOd74Tn//85/Ge97wH58+fxyc/+Um8613vgmEYpTjXGf3lX/4l7r//fpw9exbd3d1z7r9nzx7ceOONeO655/CBD3wA99xzDwYGBnDXXXfhG9/4Rta+w8PD0HUdS5cuzVrf2dmJgYGBor4PIiIiIqpPQg1d+CSiyuf+K4d//cwsK6CqqQoq61JQLKkjYYc//goLqJoCMnyydZ6zzaE6eMEavbB1efOi/i1t9Cu4ZlUbAGDPkUsLPk7R8H9MIiKqE/MOqNra2hAOh/H000/j29/+Nr761a/iW9/6Fk6cOIHbbrsNTzzxBB599NE5j/P9738fZ8+enXG7ruv45je/iWQyOeexHnjgAZw5cwZDQ0O45557Zt1X0zR8+tOfhiiKeOaZZ3D//ffjG9/4Bvbt24cNGzbg3nvvnfW8iIiIiIiIaGEE8KJrJailaqBSyq6g8vBEiszd4i+eMux1lTUBQhAELG2y51BNzjyH6uiAFVBd1t206Nd05lDtOVoBARUREVGdmPcnEFEUoarqtPWyLOM973kPAODEiROzHuPcuXP49Kc/jdtuuy1vGGQYBj72sY/hC1/4Au6///45z+n222/HypUrCzr/J598EidPnsSHPvQhXHHFFen1zc3NuPfee5FMJvG9730vvb6jowOSJGFwcDDrOJcuXUJXV1dBr0lEREREREQsCqgUoutKAP9IZpY1g6qGvnmdgCqRMtLt8/xyZVVQAUBnoz2HapYKqiMDkwCAjUsbF/16t220AqqXTo3O2laQiIiIiqdovyJjGAYee+wxAMCWLVtm3benpwc//OEPce7cOezatQt9fX1Zx/nYxz6Ghx56CB/96Efxmc98plinCAB46qmnAABve9vbpm274447AABPP/10ep2qqrjyyivxm9/8Jr1O0zQ89dRTuP7664t6bkRERERERLWshq7xVzVWshVGdgVUtdSK1Zk3FUvpmYCqwlr8AUBnkxVQDU7kr6AyTTMTUHUtPqBauySEpU0+JHUDr/eNLfp4RERENDd5oU9MJpP46le/CtM0MTIygt/85jc4cuQIPvGJT+Atb3nLnM9/z3vegx/+8If44Ac/iNtuuw1PPfUUenp68PGPfxw/+MEP8OEPfxjf/e53IYrFLTM/fvw4AGD9+vXTtnV1daGhoSG9j+Nzn/sc/uAP/gBXX301rrrqKvzt3/4tZFnGhz70oRlf57777sN9991XUItCIiIiIiIionLhDKrCuFshSjXU489p5xdP6YhrVos/X4W1+AOA9pAVUI1F819XGZxIIBxLQRIFrOtsWPTrCYKA69a04+d7L+DFU6O4YW3Hoo9JREREs1tUQPWf//N/Tj8WBAFf/OIX8bWvfa3gY7z//e+Hruv48Ic/jF27dmHHjh348Y9/jLvvvhvf+973ih5OAUA4HAZgtfTLp6mpKb2P40Mf+hCGhoZw7733YnBwEDt27MDjjz+OpqaZexzv3r0bu3fvxrlz57BixYrivQEiIiIiIiKiRaidqKW0ZKnGZ1BVeAVVa8gaLzEayR9QnRqaAgD0tgWLdv6ZgGqkKMdbuBr6hiMiIprFggOqhoYGmKYJwzBw4cIFPPLII7j33nvxwgsv4Fe/+tWs4Y3bXXfdBU3T8JGPfASnTp3Cu9/9bvzgBz+AJFXWh6PPfvaz+OxnP+v1aRAREREREVUxXnStPPwzmYl77pRYQwmVE+YYJjAZ16x1FTiDqn2OgOrsaBQAsLI9WLTXvG5NOwBgb9844im9IoM7IiKiWrLoEiVRFNHT04M//uM/xv3334/nn38ef/3Xf13w803TxJNPPpl+fPDgQQwODi72tGbkVE7lVkk5JiYmZqyuIiIiIiIiIqp2bPFXGHdbP7GGvlB+Vzu/8Whq2rpKMVcF1dkRO6BqK15Atao9mJ5Dtbd/vGjHJSIimsmuXbuwefNm3HfffV6fiieK+gnkbW97GwDgqaeeKmh/0zTxh3/4h/jOd76Du+66Cz/4wQ9w6tQp7Nq1CxcuXCjmqaU5s6dy50wBwMDAAKampvLOpyIiIiIiIqLFqaFr/FWOfxCFcAdUUg1986pS5lLQRNwJqCqvUqgtaAVUM82g6huNAAB620NFe01BEHBVbysAMKAiIqKy2LNnDw4dOoTdu3d7fSqeKGpA5YRKiqLMua9pmvijP/ojPPDAA/jABz6Ahx56CB/+8Ifx/e9/HydPnsSuXbtw8eLFYp4eAODWW28FADzxxBPTtj3++ONZ+xARERERERHVmqwKKu9Oo+K5q6ZqKJ+CIAjpkGoyXrkVVG3pCqpU3u2lqKACgO0rWgAA+7wMqGrpG46IiGgW8/4EcujQIUSj0Wnro9EoPv/5zwMA7rzzzlmPYZom/viP/xjf/va30+GUM3Pqgx/8YFZINTAwMN9TnNVb3vIWrFmzBg8//DD27t2bXh8Oh/HVr34Vqqriox/9aFFfk4iIiIiIiIiqizugkmpoBhUAKJL1ftIzqCqxgiqUqaAyTTNrm2ma6LMDqt4izqACgCsqIaAiIiKqE/J8n/DjH/8Y3/zmN3HTTTdh1apVaGpqwvnz5/Hoo49iZGQEN998Mz73uc/NeowLFy7gpz/9KX7/938fDz30EGQ5+zQ++MEPwjRNfPSjH8Wvf/1rfOQjH5n1eA888ACee+45AMD+/fvT65xWgzfddBM+9alPWW9YlvHAAw/gjjvuwC233IK7774bjY2N+MlPfoKzZ8/ib//2b7Fq1ar5flmIiIiIiIhoDrV1ib96uf8cBFZqzEh0/UpvLc2gAgBFFoGknm7x55Mrr4KqNWR159ENExMxDc3BTLee8WgKkwkrXOstcgXV1uXNEAXgQjiOSxNxdDb5i3p8IiIiyph3QPWud70LFy5cwG9/+1u88MILmJqaQnNzM7Zt24a7774bn/zkJ6cFTrmWL1+OF154Ab29vTPu+6EPfQg7duzAhg0b5jyn5557Dt/73vey1j3//PN4/vnn04+dgAqwBo8999xz+PKXv4wf/ehHSKVS2Lp1K77+9a/jrrvumvP1iIiIiIiIaP5q7Bp/1WIoVZjsCioPT6QEMi3+KreCyidLaPDJmEpoGI0mswKqi+E4AKCjQS36uYd8MtZ3NuLo4CT2nQvjrZsZUBEREZXKvAOqHTt2YMeOHYt+4TVr1sy5TyHhFAA8+OCDePDBB+f1+jt37sSjjz46r+cQERERERHRwuV06SKPCDPcp2xC1gyq2vpKKTkBlU+uvIAKsKqophIaRiNJrO4IpdcPTloBVWdjacKjbT3NODo4if3nw3jr5qUleQ0iIiJawAwqIiIiIiIiIqJa5x47JdVYQKXaLf10w0qN/UplXh5qC9pzqCLJrPWDdgXV0iZfSV73su4mAMDRgYmSHJ+IiIgslfkJhIiIiIiojtTYdU8iqnDuv3P498/M3C3+am4GlZT9ftQKnEEFAC1OQBXNCagmEgCApSWaD7WpuxEAcGRgsiTHJyIiIktlfgIhIiIiIiIiopIQ2NivIO4KKrHGrp7kBlJqhQ7ZavRbkymcVoQOp8VfyQKqLquC6uxIFJGENsfeREREtFCV+QmEiIiIiIiIag5HUFUGVlAVRqjpCqrsy0FyxQZUCoDpAdWlidIGVG0hFZ2NVvvAo4OsoiIiIiqVyvwEQkRERERERETkoawZVGJtB1S5Lf8qRVPAqqCaiKey1mda/JVmBhUAbErPofIioKrMPw8iIqJiY0BFREREREREVKfY7m9m7qqpGiugmtbSLzewqhRN6Qqq3ICqtBVUALCpy55DdXGiZK9BRERU7yrzEwgRERERERHVnBq7xl+1ssIW/qHMyP11kmosocqtmKrcgMquoIplWvyZpomRSBIA0NFQugqqjUutgIot/oiIiEqnMj+BEBERERERUc3hDKrKINRY2FIOtTaDSpVzZlBVaAvD9AyqRKaCaiKuQTesv01aQ0rJXnvNkhAA4MxwtGSvQUREVO8YUBERERERERHVqcqMJSqPWKEBzkLlVkzlBlaVotGuoJqMZyqoxuzqqaAqwSdLJXvt1R1WQDUwEUckoc2xNxERES1EZX4CISIiIiIiIqKSqK2opXRMV8lfjeVT02ZQVWoFVVPAqpCaiGUqqMaiVkDVGlRL+totQRVtIes1Tg9HSvpa09RYxR4REdFMGFAREREREXlseUvA61MgojrivvbNdn+FkSo0wFmo3AoquUJnUOWroBqPWmFVKdv7OdbYVVRlD6iIiIjqRGV+AiEiIiIiqgPf/fg1+Mxta/G725Z5fSpEVEcE1lAVJLuCqra+Zoqc/X5yK6oqRXoGVVyDaf+BlKuCCsi0+WNARUREVBqy1ydARERERFSvdm3qxK5NnV6fBhHVsdqKXUqn1gIqWcytoKrM99dkV1AldQMJzYBfkTBqz6BqKUdAtcSjgMqdjhIREdWwyvwVGSIiIiIiIiIqiRrLWkrGRCYkqLEOf9NaFlbqDKqgmvm96qmE1eYv3eIvWL4Wf6eGpkr+WkRERPWIARURERERERFRHXFHEQyrClNrM6jc70eRhIqdRSaJAvyKdekqltQBlLfF34q2IACgfyxW8tciIiKqRwyoiIiIiIiIiIhyuLusVWqAs1DugCq33V+lCdlVVFE7oCpnBVVPqxVQjUaSiCa1kr9eBlv8ERFRfajsTyFERERERERUMzhWpUII7ru1FbxQYSQhu4KqkgVUCQAQsQOi8Vj5ZlA1BxQ0+qyA7MI4q6iIiIiKjQEVERERERERUR1hKFWYWs5Ts1v8VfalIaeCymnxNxW3gqoGnzzjc4ppeWsAAHBugW3+UrqBxw8O4PsvnsW+/vEinhkREVH1K8+/5kRERERERERUEdzd6mqscx0VSK6igCpdQZWwgqkp+7bBX6aAqiWAIwOTOL+ACqozwxH8wfdewcmhSHrd725fhv/+/m3wK9LMT2S5KRER1YnK/hRCRERERERERCXDfGpmtfy1Ed0zqCq8xV/QDqhiKbuCKuFNBdX5eVZQhaMp/PvvvISTQxG0h1TcsmEJJFHAI/su4A+//xp0gyEUERERAyoiIiIiIiKiOlLZcUTlqOXqsmqqoAraLf4iCT3rNlSugKplYS3+vvqrw+gfjWFFWwCP/seb8X8+uRMPf+paBBQJzxwbwj8+c3KWZzO8IiKi+lDZn0KIiIiIiIiIqKgEV/LCy+AzE2s4oXLPoHKHVZXIqaCKJjUYholIsrwVVD2tQQCYV4u/Y4OT+PFr/QCAv/vAFehs9AMArl3Tjv/y7y4HAPz9r4+jfzSa/wBs8UdERHWCARURERERERFRHansOKJy1HA+lRVQVXoQF/LZLf6SOqIpPZ3dVHKLv+88dxqmCbz98i7sWNWWte39V/fghrXtSGgG/ubRI0U9VyIiomrDgIqIiIiIiIiIqI64q6bECq+gCih2i7+kjog9f0oUAL9Snktay5qt6qdLk/GC5kaFYyn8bO95AMAf3Lx62nZBEPBXv7sZAPCrAxdxejhSxLMlIiKqLgyoiIiIiIiIiOpIhRfMVAyhhmvN3KFUhY+gclVQaZhKZNr7CWX6Rm5v8EEUAMMERqYSc+7/m8ODiKcMrOtswI6VrXn32dTVhLds6oRpWtVWRERUv3bt2oXNmzfjvvvu8/pUPFHhH0OIiIiIiIiodnCuSiVwBy8cdVOf5Cpq8RewZ1BFkjqm4uWdPwVY7RDbG3wAgEuTcwdUvz48CAB4x5auWUO0T9xoVVf9fO95xFN6zlb+j0lEVC/27NmDQ4cOYffu3V6fiicYUBERERERERHVk8rOI6gMJDFzOajSA6qQaoVRMVeLv1AZAyoA6Gx0Aqr4rPslNB3PHBsGANx+2dJZ971hbTuWNfsxEdfwxKHB7I1MjomIqE4woCIiIiIiIiIiylXZuc2iuNv6SRU/g8qqoIq6W/z5PQqoJmavoHr1zBimEhqWNPqwdXnzrPuKooD3XtUDAHhk34XinCgREVGVYUBFREREREREZVLZF8LrRYUXzFAZZFdQeXgiBfAp1rkmdQORpF1BpZY7oPIDmLvF36tnxgBY1VFiAV/YO7d2AwCePT6EWDK3zR8REVHtY0BFREREREREVEfcl81NzrqZUYXnNotSTTOofLIdUGkGYkkDAOC3q6rKpbOpsBZ/r/dZAdVVva0FHfey7kb0tAYQTxl4+tiQawv/vyQiovrAgIqIiIiIiIjKhBddK4FQ4YEElZ47lKr0Fn+qHVAlNAPxlFVl5FfKezmrkBZ/hmGmA6qrVxYWUAmCgDsu7wIAPHFwYJFnSUREVH0YUBERERERERHVKaGm64RoJu4KqkoPqHyyVS2VSBmIa05AVd4KqiUFtPg7MTSFybiGgCJhU1djwcd+6+alAICnjg3BMBjiExFRfWFARURERERERFRH2OKP3KFUpVfU+dIVVDriKafFX5krqOwWf0OzBFQHL4QBAFuWN0GWCj+/q3pbEVQljEaSODIwubgTJSIiqjIMqIiIiIiIiIjqSIXnERWj0oObxXAHVFKFv810BZVmIOG0+JPLXEHVYAdUUwmYZv5Q1wmXNnU1zevYqixi5+o2AMDzJ4atlTO8BhERUa1hQEVERERERERlwWuuRJWhqlr82dVSSc1AQnMqqMobULWF1PQ5xOyQLNdRO6DaOI/2fo4b13YAAJ4/aQdUrGwkIqI6wYCKiIiIiIiIqI5w7hSJVdTiT5WcFn8G4k4FVZlb/AVVKX0eo5Fk3n2Opiuo5h9Q3bCuHQDw8ulRJO0QjoiIqB4woCIiIiIiIqKyqPDr4HXD/efAqrb6lFVBVeH/YzoVVNYMKiegKm8FlSAIaA0pAIDxaGra9nAshYvhOABg/dL5B1SXdTWhJaggmtRx+OLE4k6WiIioijCgIiIiIiIiIiLKUdmxzeK4q6YqvsWfPW8qpZuIJq2AylfmgAoAWoNWm7+x6PQKqrMjEQBAR4MPzQFl3scWRQFXrmgBALzRN8bkmIiI6gYDKiIiIiIiIiKiOuIumhIrPqDKXLqaiKemrSsXJ6DK1+KvfzQGAOhtCyz4+Ff2tgIA3ugfX/AxiIiIqg0DKiIiIiIiIioLFgVUhgrv6EZlILq+CSo8n8oKo8IxDUD5W/wBmLXFX99oFADQ2xZc8PGv7G0BALzRNw6Af1kSEVF9YEBFRERERERERFRH3KFUpc+gkiUxfb4TMSsc8ntQQdUyWwXVmBVQrVhEQLWtpwWAFXbpBgMqIiKqDwyoiIiIiIiIiOqIUNPTlYqnwnObRcmqoKr0Eipk5lCFnYDKgwqqNjugGs8zg6p/dPEBVXNAwQq7RWBKNxZ8HCIiomrCgIqIiIiIiIioTtVyCEMzy5pBVQXfAz7Funw1lfCuxV9L0GrxN5qnxV86oGpdeEAFAJd1NQEAkgyoiIioTjCgIiIiIiIiIiLKUQW5zYK5q+ikKkioVCn78pVfKf/lrLZQ/goqwzBxfjwGAOkKqIW6rNsKqFIaAyoiIqoPDKiIiIiIiIiI6girpkh0XQ0Sq+AbQpkWUJW/gqrVbvE3lhNQjUaTSOkmBAHoavIv6jXSARUrqIiIqE4woCIiIiIiIiIiqiPuUKoaKqhkKfsccyuqyqHRLwMApuJa1vqhyQQAoD2kQl7keW1mQEVERHWGARURERERERFRHan8OKIyCFVQWbRQYtYMqsp/n3JOiJYbWJVDgx1QTc4QUHU0+Bb9Gj2tATT4ZMBc9KGIiIiqAgMqIiIiIiIiKguTF10rjsC4qi65w7fqCKiyL1/ltvwrh0a/AmDmgGpJ4+IDKlEUsLGrEYLAvyyJiKg+MKAiIiIiIiIiIqoj7kjKg6xn3nIrpnIrqsrBafGX1A3EU3p6/dCUHVAVoYIKANYuCRXlOERERNWgCj6GEBEREREREVGx1HLrOiqMWG0VVDkpmiKX/3JWgyqn708lMlVUxaygAoA1SxqKchwiIqJqwICKiIiIiIiIiChH5cc2C5cVUHlQjTRfSs45KmL5L2eJomDNh0J2m7+iB1QdrKAiIqL6wYCKiIiIiIiIyqIKCjXqAv8YyP3/olQF/2NOa/EneXPOTpu/yXgqvY4VVERERAvHgIqIiIiIiIjKwjS9PgMiArKrpqqggApKTos/L2ZQAe6AylVBZc+g6ijSDKretiBE8C9LIiKqDwyoiIiIiIiIiIjqiDveqYYWf+5ASpEEz+aoZVr8ZSqoxiJJAEBbSC3Ka6iyCA9GbBEREXmC/+QRERERERER1ZEq6OhWGWr465Q1g6oKviFkVwWV7MH8KUejXwGQqaAyTRPhmBVWtQSVor2Ol++RiIionPgvHhERERERERFRHXEXTVVBAdW0Ciqv5Lb4iyR1aIbVjq85ULyASqqGPxQiIqIiYEBFREREREREZWFyrkpF8Ko9WrURariEqtq+B9wVVLnzqMopt4LKqZ5SJREBRSra67DFHxER1Qv+k0dERERERERUp6osp6AicRfoVENYpbhOWK6ICiormBqPWvOnmgJKUb+OrKAiIqJ6wYCKiIiIiIiIyqKWK1KIqkk1hFJu7lDKy/lMIdUKqKIpHQBKMn8KADzM4IiIiMqKARURERERERERUR1xF+iYZuW33nS3+FM97H8XUK3XjiXtgCpqBVTFnD8FZFdQJTWjqMcmIiKqJAyoiIiIiIiIiOoUCzVmVmVFRvPirqAyKj+fym7x52H7u4BdQeUEVONOBVWRAyrhzv8GALhP+z1cDMeKemwiIqJKInt9AkRERERERFQfTFTBlXCiOpBdQeXdeRTKXUHlvl9uAUUCML3FX3ORW/wJm/8dfq/hYbw5DFw5FsPK9tC8nh9P6RgIx7GsJeBpxRkREdFcGFAREREREREREdUR0VVBVQ3BsbtqSvVwQFNQtQKqWFIDAIyXqMUfALS0LQGGh3BufH4VVD974zz+4qf7EUnq6Gry4yvv3oK3bl5a9PMjIiIqBv4aBRERERERERFRjhru8JfVvrA6KqhcLf68rKByAiq7gmoybgVUjf7iB1TLWwIAgPNjhQdUjx24iP/4o72I2C0IBybiuOcHr2HP0UtFPz8iIqJiYEBFRERERERERFRHsiuoKp8sulr8eTmDymnxZwdAzm2DTyr6a/W02gFVgRVU4VgKf/4v+wEAH7muF4f/y9vxniuXQzdMfPHH+zBhh2lERESVhAEVEREREREREVEOoYZLqLLeWxWUUClZFVTet/iL28FUJKHZ64s/QaO72Q8AGAjHC9r/n1/tx1g0hTVLQvird12OgCrhb963FWuXhDASSeL+p08V/RyJiIgWiwEVEREREREREVEdEVBdFVSiq2pK9DA5TFdQ2S3+nFZ/oRJUUHU0+AAAw1OJOffVDRP/54WzAIBP3bQGqmxd7vPJEv70jk0AgP/93OmCjkVERFRODKiIiIiIiIiIiOpItVWHuUMpycsWf84MqjJUUC1ptAKqocm5Q6U3+sbQNxpFo1/Gu69clrXtjsuXYuvyZsRSOv751XNFP08iIqLFYEBFRERERERERFRHqiyfgiRURgWVE0QlNAO6YaZnUDmt/4rJCahGo0loujHrvr8+fAkAsGtj57SwTBAE/PvrVwIA/umVPhhGNdTMERFRvWBARURERERERERURwRXyFMFI6gqrsUfYLX3iyRLV0HVGlQhCtafz2gkOeu+Tx4ZBAC85bLOvNvfta0bjT4ZZ0eieOXMaNHPlYiI6s/XvvY17NixA42NjVi6dCk+8IEP4MyZM/M+DgMqIiIiIiIiKotquBBed6qtlIaKotr+2N1d/Tzs8Ae/krmMFkvq6VZ/pZhBJYkC2u05VEOzzI7qH43i2OAUJFHAbRvyB1RBVcbbLu8CADx6YKDo50pERPXn6aefxp/8yZ/gpZdewmOPPYbR0VG84x3vgKZp8zoOAyoiIiIiIiIiohz1EqhWwzwq99wpL2dQCYKQrqKKJXVEEnZAVYIKKgDoaJh7DtUzx4cAADtWtqI5qMy43zu2WAHV4wcH2OaPiIgW7bHHHsPHPvYxbN68GVdeeSW+/e1v48iRIzh06NC8jsOAioiIiIiIiIiojlRDKOUmVsgMKgDw2VVUcU1HLFW6GVRAZg7V8NTMLf72nwsDAK5e2TrrsW5a34GQKuFiOI4DF8LFO0kiIiqbH/zgB/ijP/oj7NixAz6fD4Ig4MEHH5z1Oa+88gruvPNOtLS0IBQK4brrrsOPf/zjop9bOGz929LW1jav55XmVzyIiIiIiIiIclTbRXGiWlV1M6jcAZWXPf4AqJIVUI1HU+l1IV+pKqhUALNXUO0/b10Q3Lq8edZj+RUJ169tx68PX8KLp0awraelaOdJRETl8Zd/+Zc4e/YsOjo60N3djbNnz866/549e3DHHXfA7/fj7rvvRmNjI37yk5/grrvuQn9/P77whS8U5bx0XccXv/hF3Hnnnejp6ZnXc1lBRURERERERGVRDRfCiajySK6rVx7nU+kKqrGoVdUkCIBPLs3ltSV2i7+RGWZQJTQdxwYnAQBb5gioAOC6Ne0AgBdPjRbpDImIqJweeOABnDlzBkNDQ7jnnntm3VfTNHz605+GKIp45plncP/99+Mb3/gG9u3bhw0bNuDee++dFnB96UtfgiAIsy65TNPEPffcg76+vjmrufJhBRURERERERFRnRLAsjaqfO4KKsnjUsxMBZUVUIVUOe8Fu2JoCVoVVGOuai23owOTSOkmWoIKeloDcx7PCaheOT0KTTcgS/y9dSKianL77bcXvO+TTz6JkydP4hOf+ASuuOKK9Prm5mbce++9+PjHP47vfe97+Ku/+qv0ti984Qv4+Mc/XvBrmKaJz3zmM/j1r3+NZ555BkuWLCn4uQ4GVEREREREREREdcpE5Zc2ugOqUoVBhfLJ1ryp0YgVGgVKNH8KAFqDCgAgHMs/g+rA+QkAVnu/Qr4ul3U3odEvYzKu4dDFCbb5IyKqAJOTk5iYmEg/9vl88Pl8iz7uU089BQB429veNm3bHXfcAQB4+umns9YvWbKk4JDJNE3s3r0bv/zlL/H0009jxYoVCzpP/qoEERERERERERFVLMnV18/roh/VbucXjtkBlVK6gKrFDqhmqqA6OTQFANiwtLGg40migGtXW8PrXzw1UoQzJCKixdq8eTOam5vTy9e+9rWiHPf48eMAgPXr10/b1tXVhYaGhvQ+C7F792788Ic/xMMPP4xAIICBgQEMDAwgmcz/SxUzYQUVEREREREREVGOyq8rKo5qaPMougIq0fMKKiugmoynsh6XQqbFX/6LfWeGIwCAVR2hgo953Zp2/PrwJbx4ahR/eMvaxZ8kEREtyqFDh7B8+fL042JUTwFAOBwGYLX0y6epqSm9z0J861vfAgDcfPPNWev37NmD2267reDjMKAiIiIiIiKisqiXC/5EVFyufCorrPKCmg6oNACATyllQGW3+Juhgur0iBVQrW4vPKDaaVdQvXZ2DKZpet4ykYio3jU2NqKpqcnr05g30yzOJ3u2+CMiIiIiIiIiqlPVMINKEtwVVB6eCDIzqJwKKrWEPQdb7Qqq8Vhq2oVATTfQPxoFAKzqCBZ8zI1djZBFAeFYChfC8eKdLBERVRSncmqmKqmJiYkZq6vKiQEVERERERERlQV/T5+IFsJdNSVVSIu/CaeCSi7dDKrmgFVBpRsmJhNa1rYL43GkdBOqLGJZc6DgY/pkCes6GwAAhy5MFO9kiYioojizp/LNmRoYGMDU1FTe+VTlxoCKiIiIiIiIiIgqlnvulNct6XJnUKklnEHlVyQEFCsAG49kt/k7Y7f3W9kWnHfbw83dViupwxcZUBER1apbb70VAPDEE09M2/b4449n7eMlBlRERERERERUFpXfSKz+SF73SyMqgLuLntffs9NmUJUwoAIyc6jGY8ms9RfGYwCAntbCq6ccm5dZARUrqIiIatdb3vIWrFmzBg8//DD27t2bXh8Oh/HVr34Vqqriox/9qHcnaJO9PgEiIiIiIiIiKq8P7OjB6eEIrl7Z6vWpVKwizf6mIhAragZVTkCllK7FHwC0BFVcDMcxFs2uoHLmR3W3LCCgsiuoDrGCioioqjzwwAN47rnnAAD79+9Pr3vqqacAADfddBM+9alPAQBkWcYDDzyAO+64A7fccgvuvvtuNDY24ic/+QnOnj2Lv/3bv8WqVau8eBtZGFARERERERER1Zn/9v7tXp8CUcGyAqoKqaCasmdCqVJpK6hanQqqaHYF1UW7gqq7yT/vY15mB1R9o1FMxFNo8iuLPEsiIlqoXbt2QVEU7N69G7t375513+eeew7f+973stY9//zzeP7559OPnYDKOfZzzz2HL3/5y/jRj36EVCqFrVu34utf/zruuuuu4r6RBWJARUREREREREREFcvd1k/0fAZVdsWUTylPi7+xSE5AtYgKqtaQis5GHy5NJnDy0hSu7GUlJRGRV/bs2YOenp6C9n3wwQfx4IMPzuv4O3fuxKOPPrqAMysPzqAiIiIiIiIiIqpT1dDK0J1JSR4HVGrOzKlSV1A1B6yAymkp6LgYtiqoljXPv4IKAFZ1hAAAZ0eiizg7IiKixWFARURERERERESUw0QVJDd1QqqgGVS5AVWpK6gafFbzI6elIACYprmoCioAWN1uBVRnRiKLPEMiIqKFY0BFREREREREREQVK6vFn8cJlZzz+rkt/4qtwWdXULkCqom4hmhSBwB0L7CCamVHEAArqIiIyFsMqIiIiIiIiIiIqGIJQuXMoJoeUJW4gspvV1C5Wvw586hCqgS/srCAbBUrqIiIqAIwoCIiIiIiIiIioorlrqCSvK6gypk5VeqAqtFu8TcZT6XXjUatgKo1pC74uCvbWUFFRETeY0BFREREREREZWFypA9VE36/Vgz3DCqPC6igSNknkDuTqtga/dNnUDkVVG2LCqisCqrRSBLhWGqOvYmIiEqDARUREREREREREVUsdygled7iL/tSmiqVp8XfpKvF36gdULUGFx5QNfhkdDT4AAB9rKIiIiKPMKAiIiIiIiIiIqpT1VAo5s6kPJ9BlVNBldvyr9gafHkqqKKLr6ACgFV2mz/OoSIi8s6uXbuwefNm3HfffV6fiidkr0+AiIiIiIiI6oPXrbmIqDoJyPzlIXo9gyqngiq35V+x5WvxNxqxWvItpoIKsNr8vXp2DGcZUBEReWbPnj3o6enx+jQ8wwoqIiIiIiIiKgvOoCKihciuoPLuPIDpFVRSiU+owacAAKbiGkz7L9HMDCplUcde3eFUULHFHxEReYMBFRERERERERERVaysGVQeJ1S5FVNyqQMqu4JKM0zEUwYAYNRu8de6yBZ/K9tDAMAKKiIi8gwDKiIiIiIiIiIiqljuuVOC1zOoclr8SWJpL60FFSkd0E0mrNZ+6QqqRbb4622zKqj6R2OLOg4REdFCcQYV1Q3TNGFEojAmwjANA4IoQvD5ILW0QJAkr0+PiIiIiIiIKgg7UlYOdyQleR1Q5VZQlXgGlSgKaFBlTCY0TMY1dDYCY3YFVcsiA6ruZj8AYGgqAU03IEv8PXYiIiovBlRUk/SpCMZf/i0GX3gKkeNHIZ09D3VoApJmTNvXFAWguQm+7mUIrF8Pdd06+DduRGD7dkjNzR6cPRERERERERE5KmoGVU7FVKlb/AFA0CdhMqEhltQBAFMJDQDQ6F/cZb32Bh8kUYBumBieSqLLDqyIiIjKhQEV1YzIwHkc/NE/IvXrp9F88hIkw/otq4ac/VISYAiAaAKKDgiGCYyFkRwLI3nocGZHQYBv/XoErr4KoeuvR+iGGyA15B6NiIiIiIgKZbImhajimFXwv6W7rZ9YYRVU5ZiJFVCsri/xlB1QxYsTUEmigKWNPlwIx3ExHJtXQLW3fxxtQRUr2gKet10kIqLqxYCKqloiPIa9//wPmPrVo+g6PIRG1wfrgRbg5EoF0Z5GCMvaIC5bArWtDWawAbqvAZPQMTI1hMmh8whfOAPf0ARWDAMrhkysHjCxbMxE4tgxJI4dw/gP/wmQZQR37EDDrbei8a23Q+3p8ex9ExEREREREdULYcYH5adMq6AqfVs8vx1QxVI6dMNExK6kavAt/rJeV7MfF8JxDE7EC9pfN0zcff8LeOXMGADgw9f24q/fs3XR50FERPWJARVVnWQ8gr0//y7Gfv4TdL05gCYNaLK3ne4Gxtcl0dEdwxVKHLt060MbJgEczTmQpAKNXUDTcpg9mzCyqQNHVBmvG1P4P9ELOHP+JNadS2Fzn4krTppYNqYh+uKLiL74Ii59/evwb92KpjvvRNPb74DS3V3GrwARERERUXUSvL6yTDQPZjWUFhVBNRS/VFKFjicVVKpTQWUgktTS6xsWWUEFIF01dWG8sIDqgWdPpcMpAHjopT6sbA/iD29Zu+hzISKi+sOAiqqClohj78//AZce+QmWvjmMxgTQaG8baAPG1qWwpmcSb/drEEOdQOMKwN8MiAogyoCRAuITQGISiI4AkSFATwLjfcB4HwQAHQBushcAiAkCXm4NYc+qNvzX20XIYzquOmFixwkTm/tMxPfvR3z/flz6+tcR2LgCzTdfgaabtkNqarB6JJhG9mLorsf2fcN+DACCaP1kIAj2fRGA674gAIIEiM4iuxb7sSBNX5e+n/sce33ucyrogz8RERERERGVVjXkcO4MyOufWJWcgCr3cSn45UwFldPeT5VE+Oz1i9HZaAVUQ1OJOfc9MjCBrz16BACwuiOEd23rxv948gS+/thRXL2yDVevbF30+RAR1Ztdu3ZBURTs3r0bu3fv9vp0yo4BFVWs4ZNv4OC//E9MvbIXS45H0RgDVtvbxhqAS+s0rNi2BDfveCvk5VcAnZcDbWsAqYBvaz0FTA4AExeAiXPAxEVg4jwQPmevO4/A5ABujUzh1sgUDAD7fSoe2xjE/7oqhFRCwnVHTFx/2MCmfiB2tB+xo/0Y/M6/orEnhubVMYSWJqyMqdpMC7nmCsPybS/ic9LPmyFgmxbEyXM8L9855L6WDJShTQMRERFRveEMKiJaiEqqvpRyflYsawVVUsdUojjzpxxLGn0AgKHJuQOqH77UB8AK5X71H25GQJVw4tIUHj0wgK/+6jD+v3uur6hqNyKiarBnzx701PEoGQZU5JlkNIzY2ACiYwMYPb0PY6cOYqrvDPTTg2i8mMCSMaAT1gIAE0FgcIOMrus345o7PgZl9c2Ar3G2l5iZpAAtK6wF1+bfJx1inYcYPoftsTFsj4fxxXgYL0+dxS+b+/Hfrx6FOqnjhsMmbnvTwMohARN9QUz0BSE3Smje0oTmbS3wdfgzlVCi5KqIciqlkKm6gmnfd1VhwXRVYelW5ZWhuRY9c9/Up6/LeqxbFWUzMXVA1wF97g+ntU0oQRiWZ7swSzA3U+WbIM7xnIW+Vr73JbKqjoiIiIiIPOX+kcTrAETOCaTKMYMq4JpBNWlXUBWjvR8AdDSoAIDhOSqowrEUfvL6eQDAAx+7Jh2a/effuxxPHrmE186O4bcnR3Djuo6inBcREdUHBlRUdlo8isNXXQ3ZyKwTAbTbi9vFTiC2oQWdN9yIK9/7eagty8p3olkhlms1gOvt5S+1OJ4+9zR+cdkv8Oc7n8WKAQ23vWngpkMmGid1jLwwhpEXxhC44go0v/c9aHrHOyA1LjBUK7ZpIZdmh2AzBF/GLMGXOVMYNp/n5TvGLMc1823P97ycfXLfo2nM8AUyrSBvtjCvXswabMnW/yuSAojK/O5Lquv5s91XrMrIOe+7X0O1tjn32b6SiIiIiOaJ9X6UjyKVv4LK7wqonAqqBl+xAiqrgmqugOrR/RcxldCwYWkDbnaFUJ1Nfrz/6h489FIf/nXvBQZUREQ0LwyoqOxkfxCm6/ObIQCTQWCyWUCyRYXYuwTt267E2tvuxmVrrvLuRAvgl/24Y9UduGPVHRiLj+GxM4/hF1t+ge8P7MPVJ0zc9qaJK0+ZiO3di9jevRj86tfQ+Na3ouW970Hw2msheNlGThQBUQWgencOlcAw8lSd5Qm+ShLeafmDQmefWavh5hvezfEeTX3mr1GtVNXlC65y7+cNuxZzX7Xvy7Pfl32ZQM7ZJvns82G7SSIiIiIqnWpovSmWIQQqVO6plGMGVUC1fiaIJTMzqIoeUE0mZ91vb/84AOD2y5ZO+/N4+5YuPPRSH548egmGYVbUnxcREVU2BlTkic4f/yN8Te0ItnRBCbZAlBY/2NNrrf5WfHDTB/HBTR/EmfAZ/PL0L/GDHY/gHwfO4eYDJna9aaBnJI6JRx7BxCOPQF7WjZZ3vxvN73431N5er0+/fokiANEKBOqZ01JyXiFcynqspwA9aT3WtXnetxfDPoauFed+vso4pyKu2oriBCkTXMnuAEvNWZQ8+82x72zBWN7j5TsWK9SIiIiIqLSEGe57ITd8KcsMKruCKp7SMZWwfqAp1gyqDnsG1UgkAdM0Z2yh+PLpUQDA9hUt07btXN2GoCphaDKBgxcmsLWnuSjnRkREtY8BFXmi6/JbvD6FklrVvAq7r9iNz2z/DPYO7cUjVz2C/+f0Y1hydgK73jRw4yEToQsXMfy/voXh//UtBHfsQPN73oOmt98BMRTy+vSpHglCZo4WfF6fzeIZhiu4SlmB2pz3cwOz+d5PLjygc85BT1jn5GbqgBazlootYhPyBF7ObZ4QbFow5jzXufXZoZj71pe937Tn5HkuK9CIiIiI5iR4HvnMrZJ+F0oSKmQGVZEqqNpDVleVlG4iHEuhJTi9y8r58RhODUcgiQKuX5s7nAHwyRJuWteBJw4N4skjlxhQERFRwRhQEZWQIAi4svNKXNl5Jb6080t45twz+MV1v8BDZ57ClUdSuO1NE9tOm4i++iqir76Ki//1K2h++9vR/J73ILhjh7ctAImqmSgCoh1qVJvccE1LuAKsZJ7FXq8lcvaxA6/cdbMez7VOy7POvWQx7deqwARNVOYIshYYkE0Lw5TCn1tJV1iIiIio7lVFiz/X5yevP0qJQvkrqPyqHVAldUSTVnv2gFqcS3p+RUKjX8ZkXMPwVCJvQPXa2TEAwJZlTWjy5+8+8jubOq2A6uglfPb29UU5NyIiqn0MqIjKRJVU3L7ydty+8naEE2E8fuZxPHLyEfzDyTdw6wETt75pYNlYHOGf/gzhn/4MWncHWt71u1jyu++Gb/36GcvsiajGVEO4ZpqZ6rPFBmPp5ycy+2TdJuywLN9tnv2NnB6ORgpIVlhfRyk3wHIFWbLfvnXdz7vNfiyp9vo822Z8rt+af8Z/V4iIiKhKVNKnltzfIy3LDCpXBVUspWetK4YlDT5MxjUMTSaxrnP69v3nxgHkb+/nuHXjkvS+E/HUjEEWERGRGwMqIg80+5rxgY0fwAc2fgBD0SE8c+4Z/Kx/D0Ze/i1u2BvH9YdNBC8OY+rb38XUt7+LyWUtEN9yI3rf/UF0Xn6116dPRPVOEOwWfQqgVlhbUtMsMOSaIwxLh2qFPnem10pOD82ccG72OdSlJYg54ZXqCrH8BYZieYKvadtyAjT3cSWFIRkREREVxvWRweuPD15UULlnUMXtgMqvFK/jSkeDD6eGIxieyt8R4djgFABgU1fTjMfobg5gdUcIp4cjeProEH53+7KinR8REdUuBlREHlsSXIL3bXgf3rfhfYjfGsdLF1/CE6f2IP7k01j7+iCuOGWi8cI48P1fYuT7v8SBTgUXd66GeusNWHPNW7ChbSMa1Uav3wYRUWUQhExgUikMww6l5gi50gFX3H4ct/Z3P9adx7Nty/NcPZHdmtE0MrPNvJQvFFP8rlDLP//Hhewj8SMwkWcqv5MYUZrJ79eKkRsKeSk3kCrHDCp/OqAyEE8ZWeuKoaPRaus3U0B14pIVUK1f2jDrce7c2oX79pzE4wcHGFAREVFB+NM5UQXxy37cuuJW3LriVuBWYDg2jFdPPoOLj/0rmp7bjw3Houi+lEL3L44BvziG0YYH8eA6ASc3t0K7ejNWdW7EutZ1WNu8FisaV6DZ18zWgEREXhNFQLRDEi8ZhisAywmv0qHWXMFXYn6hWNb+eeaUOa+JcHm/FqK8uBBszpDMBygBO3wLZB6ztSIREdGCVNK/nl5UUKmyFYIlNQOJErT4awtZAdVYdHpr7FhSx/lx65ea1i6ZPaC6fk0H7ttzEvvsloBERERzYUBFVME6Ah14+5b3AlveC3wRCA9fwIlHHkZsz9No2ncabVM6bt9r4va9o0j+6Dkc6n0eL68S8J2VAs4sBUK+RvQ09qCnoSd92xnsRGewEx2BDrQH2iGL/GuAiKguiCIgBqygxCtONVne4CthV3XZ61Ix17oFPNbiQCqeWeeuIDM0IDllLeUkiNmBVW6A5TxW/K5bP6AEXevsZdp6+7Hsz96HgRgREdUA9y9eel3ZlptHyWUIqBTJCqgSupGeQVXMFn/NAWte1ERsekDVNxoFADT5ZbQGZ58rtbWnGQDQPxrDyFQC7Q0V1NWAiKhC7dq1C4qiYPfu3di9e7fXp1N2vDJNVEWaO5bh6k98EfjEF2Ekk4i+/ApGn3wCk0/ugTowhCtOm7jitPVpfcoPHOwN48DKCRxceRi/6QDMnItUAgS0+duwJLgEHYEONPua0aQ2ZRafdduoNsIn+eCTfPDLfutW8sMnW+tEofQtDYiIqAZ4WU1mGK4qsbgrwCrxY3fVmGkAqYi1lKu7ohNYZQVf7pBrnoFX+n6efSQOQyciotKopF+3yK2YEstYQZXSjPQMKl8RK6icgGo8On1I6unhCABgVUdozg4tzQEF6zsbcPzSFF47O4a3Xd5VtHMkIqpVe/bsQU9Pj9en4RkGVERVSlRVNNx0IxpuuhHmf/p/kDxxApHf/haRF19C9JVX0DA1hWuPmbj2mBVYJYMqBlY14lSPjAPdKbzeMYUp1cBIfAQj8ZHFnYsgWgtESKKUvi+KIiRBggABkiClH0uCtY+zThbkzGPBOoYsygjKQYSUEIJyEEHFXuQgmtQmtPnb0OpvRau/Fe3+dgTkANsZEhHRzEQRUIPWUk7pqrH5VILFXbdRO/iKWeuy7ufuYz/PcP32c7qN4ljp36sg5an2misUmynwCtrbgpl9FPvPTw5wjlg148c1ospTBbO2xAqqoPLi507VrqBK6qWZQeUEVOE8FVRnR+yAqj1U0LF2rGplQEVERAXjT3ZENUAQBPjWr4dv/Xq0fexjMDUN8UOHrLDqxRcRfeMNqNEYeg+NoPcQcJv1JEhrV0O/bA0m13RipLcFQ8uCGDcjmEhOpJfJxCQmU5NIaAnE9TgSegIJLQHN1NKvb5gGDNOwH3jxFQB8kg/t/nZ0hbrQ3dCN7pC1dIW60B3qxvKG5QgqZb4oSURE5EXVmK7ZIZYdXjnBVSqWWT8t5MoXeOUeY4bjOUwdSE5aS6lJak6AlRNkKQFACeVZl2c/NZR/m6SyRWIpVMGFcCKqQK6/juvxrxH3DKpYCWZQzRZQnUkHVIX9PH1lbyt++HI/9vaPF+38iIiodjGgIqpBgiwjsG0bAtu2AX/4aSuwOnoUsb17Edu7D7G9e5Hq74d+4hRw4hQaATQCWCVJ8K1bB//mzfBv3gb/5Zvh37QJYnD6B1HN0JDQE4hpsXRANdOim3rWfd3UoRvWOs3U0ts0Q0vvk9STiGkxxLQYoqkoIqkIopp1G06GMRYfw1h8DKPxUSs00xO4ELmAC5ELwKX8X5fOYCdWNa3CyqaV6aW3qRcrGlZAYVsiIiKqFZIMSI2Ar7H0r2WaduVXNKfya6bAK18o5gRhsex90ks0c+tcltST1hIPl+69CWKeQCs4QxiWp9Jr1v2Cmcoxka2SiSpVPQYhlcr9+wKm1yVUHkhXULla/BV3BpUKIH9A5W7xV4hNXdbnj5NDkSKdHRER1TIGVER1QJBlBC6/HIHLLwc+/GEAgDY8bAVW+95E/PBhxA8ehD42hsTRo0gcPYrwT39qP1mAunKlVaG1YYO1rF8PdWUvQkoIIaWwD6mlYpomYloMo/FRDMeGMRAZwMXIRWuZsm4vRC5gMjmJS9FLuBS9hJcHXs46hiRIWNawDKuaVmFV8yqsalqF1c2rsappFToCHWwdSERENBNBsFvzlaFCzB2GuUOrrNucdcnozNuy1tmPkxGrEgywZoYlp6yllOTcECsn5FJDObdBqzosa3sgzzr7ePwcQ0Q1oN7/JkvPoNINJLRyt/iLAgBWFtjizwmyhqcSmIin0OTnL4MSEdHMGFAR1Sm5owONt9+OxttvB2AFPdrgIOKHDiF+8JB9exDapUtInjmD5JkzmPy3f0s/X/D5oK5dA//6DfBtyIRXcmdnWQMdQRDS86l6GmceKDgeH8fZybM4O5FZ+ib6cGbiDGJaDP2T/eif7Mez55/Nel5ICaWDq5VNK7G6aXX6fkAOlPrtERERkSMrDGsr3evoqdmDr2Rk9lAsvW6W/bR45vU0u3IsNlqCNyPMEHYFMy0QswKvPCFX7nr3Os4CI6IyqfdfGsxq8ZcsQYu/YCagMk0z/fWOp3RcDFv/ZhXa4q/Jr6CjwYfhqQTODEewraelaOdJRES1hz9REBEA6wO/0tUFpasLjb/zO+n12vAwEseOIX7sGBLHjiNx/DgSJ07AjMWQOHQYiUOHs44jNjRAXb0avjWroa5eDXWVc7sSos9X7reV1uJvQYu/BduXbM9ab5omhmJDODtxFqfDp3Fm4gzOhM/gzMQZnJ86j0gqgoMjB3Fw5OC0Y3aFuqzwyg6wnPCqK9QFUWC7HiIioqokKYDUDPibS/cahuFqZ5hb6RV1rYtktiUj+de7n+/skw7ATDsoK1GbJUktoLormBVumcPLAdgXOY89kb2vsyghQFZLc85EVPXqr8FfJqBK6AYkrRQt/qyAKqWbiKV0BFXrcuHQZMJ6fUlEW6jwv5fXLAlheCqB0wyoiIhoDgyoiGhWckcH5I4OhG64Ib3ONAykzp1D4tgxJI4fT4dXyTNnYExNIb5/P+L792cfSBCgLF9uhVWrV8G3ejWU3l6oPT1QurshqN5chBAEAZ3BTnQGO3FN1zVZ25J6Ev2T/TgTPoPTE6fTwdWZiTMIJ8IYiAxgIDKAFy++mPU8v+RHb1Nvet7V6ubVWNawDEuDS7E0uJTzroiIiOqdKGbCmFIw9AJCrjmCr7zr7cow02ovlZkFNl74uSX/GMDN1v2Hf3/m/UTF/ho1WCFW+r6r5aHzWA267tsBVzrwcj2fVV80T/U466hS1Xf9FKBI1lcgqRnpr4VPLl4FVUiVIIkCdMNEOJZKB1SjkSQAoC2kzquKbU1HCC+fHuUcKiIimhM/nRPRvAmiCLW3F2pvb7pFIAAYySRSZ88icfo0kqfPIHnqFBJnTiN56jSMyUmkzp1D6tw5RJ7NbqMHUYTctRRqzwooK3qs0KpnBZSe5VCWL4fc0QHBgwHiqqRibctarG1ZO23beHwcZybOTKu66pvsQ1yP49jYMRwbO5b3uO3+dnSFurA0uNS6DS1FR6ADbf42tPpb0eazbv1yGeZ5EBERUe0RJcDXaC3F5p4Flg6x5gi5XKGYcKwHGLOP1b19+r66dTEURsoKvuYTfhVC9tsBlyvQyg241Iacqq6ZQjLXfQ8+qxIVC2O4yueTMmFUKWZQCYKAloCCkUgS4VgK3c1WO3t3QDUfq+05VKeHGVAREdHsGFARUdGIqgrf+vXwrV+ftd40TegjI0iePm2FV6dOW3OtzvUjde48zHgc2oWL0C5cBF5+efqBZRly5xIoS7sgL10KZelSyF1dULqWQl66FHLnUsgd7RD95Qt0WvwtuMJ/Ba7ovCJrvWZouDB1ISu8OjtxFgORAQxGBpE0khiJj2AkPpK3baBbUA5agZW/Dc2+ZjQqjQipIetWCaFBbUCjat9XGuCX/fBJvvSiSip8kg9+yQ9ZlOu+bzsREREVgXsWWHABs8B+tBcYO2/d/6Nnpm/XkpmQKxkBklOuAMy+n29JzbDeOYap28ePW0uxZ37lBlqFBly+Bnudcz+UecxqL6K0ev9Rxmnx51bMFn+A1eZvJJLEeDSVXucEVO0NCw2opop3gkREVJP4iZeISk4QhHSrwOA12W30TNOEPjyMZP85pM6fQ7K/H6l+q9Iqee4ctIEBQNMyAdYsxGAQUns75LY267a9DVJb7m0bpNZWyC0tJWkrKIsyept60dvUi1t6bpn2XscSY+mwaiCauR2NjWI0Poqx+BhGE6PQDA1RLYroVBTnp84v+rwECPBJPiiiAlEUIQkSREGEKIiQBRmiIEISrXWSIKW3S4IESXQ9tu+nF3H6frnPVUUVASUAv+SHX/YjIAcQkLMfB5UgmtQmNKlNCMgBhmlEREQ1as5KDVm1lkBrEV/UtCqz0oHXfMKvqUyFV3of12PnHTmVYpGh4p237M8EWr7GTHjlDrXUkP240XU/ZD1231dDgOzjVX6qWkKdN/nLF1DlW7cYTfYcqnBsekA13wqqNUsaAACnhyIwTZM/3xER0YwYUBGRpwRBgLxkCeQlS4Crrpy23dQ0aMPD0AYGkBoYhHZp0LodGEDq0iC0gUFog4MwUykY0SiMaBSp/v6CXltsbLTCqtbWTHDV1gqptRVSaxukNve2Noih4KI+WAuCgDZ/G9r8bdjcvnnG/UzTxFRqKhNYxUcRToQRSUUwmZpEJBnBVGrKWpKZ24SemLakjwkTcT2OuB6f8XUrhSRIaFQb0aQ2Zd02qo1o8bWg1d9qLb7WrPsMtoiIiCgvQbDCGdm3sKqvmZimPZcrmhNeTc1Q2ZUnHHNuE1NActK677Q5dKq9osPFOV9RzhNeNeSEXnOFYY3Z7RBr/LMXW99RpZBEIT0jyqEUubVoc56AamSBAdWKtgBEAYgkdQxNJdDZyPb1RESUHwMqIqpogixD6eqC0tWFwAz7mKYJY2oK+sgItNFRaCMj0EdGoY26bkfHoI+OQBsZhT4+DhgGjMlJazZWX19h56IoOUGWdV9qa7Wqtlpas0OtlhYI0vz7gguCkA5kVjatnPfzHaZpImkkkdATSOpJxLU4UkYKhmlAN/X0rW7o09YZhgHN1PLuk++xYRrQjOn7a6aGlJ5CTIshpsWskEyzlpges261GCKpCCaSE9AMDbqpYzwxjvHE+Lzer0/yZQVXLb4WtPnb0qGWc7/F14JmXzOafc1QpeJX0REREVGdEAS7dV8QCHUU77ha0g6upuzgKmKFVwl3qOVsy90vZ71TIQYAhgbEw9ZSFEJ228JZK7nssMvXZN+31/sa7X0aWeFFs8r61qjT5FCRMgGVJAoQxeL+/9Lgsy4RRhNaet1oxPqlx7bg/H5u8skSlrUEcG4shjPDUQZURESz2LVrFxRFwe7du7F7926vT6fsGFARUdUTBAFSYyOkxkaoq1bNub+p69AnJqCPjUEfHYU2NmYFWGNj0MdGoTn3R0ehjVvbzHgcZioFbdCq2ErM+SoABAFSUxOkJR1Qli2zl+VQlmfuy0s6IJRoqLYgCOl5VNXANK0qr4nEBCaTk5hMTWIiMYGJZGYJJ8IYi49hLD6G8cR4usrMCeIGIgMYiAwU/JoBOWCFVWpzOrRqUpvS95vVZjT5mhBSQtYihxBUgggp1q0iKiX8ihAREVFdklVAbitetZeh56/UKij8igCJyenhF0xrcfYvBlG2w6smV3jlDrOaXEGXK9jK2td+ruwvSthVL3HZ9WvbvT4FKoAiiYinDPt+8b87g6r1y5WRpJ5eNxqxqqna5jmDCrDmUJ0bi+HU0BR2ri5i9SoRUY3Zs2cPenp6vD4NzzCgIqK6I0gSZLu1H9asKeg5RjQKfWzMCq/G7fBqhlBLHxuDHg4Dpgk9HIYeDiN54mT+c1EUKMuXQ121Curq1fbtKvhWr4bU0VFXLesEQUjPp1oaWlrw80zTREyLYTQ+mg6txhPj6faI6XXxcYwlxhBOhDGRnIBhGunKrvmEWm6qqKbDqqASREAKQJEU+CQfVFGFKlmLM//LJ/nS62RRzprXlb4Vsx/nW5f3Vpxje85xcrfJogxVUiEKpQlMiYiIyCOiBPibraUYTDMzlysxmRN+zVT5Za9LTNj37fWJSasNImBXeI1by2Klw67GPEFXvnArN+iytpuGPvdrVbGX7n0LTlyawg0MqKqCImU+pxe7vR8AhOwKqlhWQGX9amb7PFv8AcC6zgY8e3wYp4YjxTlBIiKqSQyoiIgKIAaDEINBKMuXF7S/qWnQx8ehj40hdekSUhcuIHX+vHVrL9qANTsreeYMkmfOAE89lf2aDQ1QV62Cb/16+DZugH/jRvg2boTcxt8+cxMEIR0Q9TQW9hsnhmlgMmlVaIWTYYQT9uK671RshRNhRLQIoqkoIqkIIqkIUob1m4RJI4lkIomxxFgp32JZqaIVqPlkX7oCL7241vklv/V1l4MIKAEE5WD6cdb9nFtFYtUZEVE9u35NO376xnmvT4MWQxAyM6gaOhd/PEPPDrsSk5kl72N7XXIyO+hyV3MVK+xK/gmA6637f7fFFWDlC7qaMoGYvzlz371enH/771Ja2uTH0qbqaL1WR7+3NyPJ1dJPkYsfUGUqqNwt/pwZVPPvyrG8xWrSfzFc+XOQiYjIOwyoiIhKQJBlyB0dkDs64Fu/Pu8+ZiqF1OAlpPr7kDxzBonTp5E8bYVVqfPnYUxNIX7gAOIHDmQ9T1rSAf8GK6zyb9wA/+WXQ129ekHzruqVKIjpNn4rsGLez0/pKUQ1K7CKpqKIaFZwldASVmilZ2Z/JfVk1iwwZ5t7ZteMt4ZR0D4zbZ9rn3yShnW+k6nJxX6Z85JFecYwK6BYFXR+yQ+/bC0BKQCfbAViATmQXp/ex3WrSioUUYEsynVVfUhEVE3ef3UPAqqEK3tbvD4VqhSiBPibrGWxDCMTas1YtZWniiuZE3TZtx+WfoNfGtfjWuEQEO5f/Pm5q7qyAq0m1+M51qsNQIlahFNlU9wBVQlb/EUTmZ8TRtIB1fwrqJzwc5ABFRERzYIBFRGRRwRFgdqzHGrPcoSuvz5rm5FMItXXh8SpU0gcO47E0aOIHzuKVF8/9KFhRIaGEXn++fT+YjAI/5Yt8G/dgsDWrfBv2Qpl+TJepC8RRVLQLFkBVzVzB1aaoaVDtLgWR0JPZC9a9uO4FkdMiyGqRRFNRbNvnfuu9UnD+uFWM7T0TLFSkgUZsiinA6v0raRAFnJu7XaLkihBERRIopRue+hsc9+695dFGbIgW/ddt7Iop4+jiDnHdPab4ZiznoN9fP6/TUTVShQF/O72ZV6fBtUqUSxq2HVDcgq/HR5Fp7IWSP2OK9zKreJyhV+JSes2PpHZptsTbJ0qr8mLizgxwRVyLSboClVVWZLgmghmwvTwTLwjuUIpuQQhZVC1LhE6FVQp3cBk3Lq/kBZ/3c1WQDUwwYCKiIhmxoCKiKgCiaoK37p18K1bB7ztben1RiSCxIkTiB89isTRY4gfPoz44cMwolFEX34Z0ZdfTu8rtbXBv+VyBLZvR/CqqxDYtg1iKOTF26EK5cyiAgCf5ENIKd33R8pIWYGWHVrFUtPDrUgqgoSeQEyLIa7H00FZTIshrsXT65zH6aBMj0EztKzX00wNmq4hrtfmD8TOfDKf5MvMPbPnmzltGmebgebsM9PMtKz7s8xTU0SFYRkREdUmO+xa1lOEsEtL5A+uZgq08q6fsFoXwrS3L/KXfQQxJ8jKDbScdU2ZtoVO0OVvAvwt1n2pPJeV+HEjO5RSS9DiL+SzK6jsGVRjdvWUKADNgfm36XYqqAYm4jBNk58ZiYgoLwZURERVRAyFENi+HYHt29PrTE1D4uQpxA/sR2z/fsTf3I/4sWPQR0cReeZZRJ551tpRkuDftAmBq65C8OqrELjyKihLizA3gKgAiqhAURU0qUW4yJOHZmhIGan0bUpPQTM169ZZ59ruXqcbuhVoGRp0Q4du6un1uqlDMzRoprVNM7TMOvv+tMdGznNct+7XmPYcU8s6vrMuH+d1k0YSSJXkS1qwaWGXE2aJ1n13BZt7Sa+bodpt2n456xRRST93rtfI2mZXsBEREZWN7LOWUMfCj2GagBbPhFjxcPZ8Lie0SkzOHXaZOmAa1jHi4cW9NyXoCq2aM/ez1uWGWzn78d/lgsglb/FnXSKM2hVUTnu/1qAKUZz/6zkBVVIzMBZNLahNIBER1T4GVEREVU6QZWsW1cYNaHnf+wAARiKBxJEjiL25H7G9exF943VoFy4ifvAg4gcPYuz73wcAKMuXpwOr4DXXQF2zhr/ZRlXJCR9qjWma6TBKN+zgzHXfmXHmnm+WMlLZM9CcmWizzEdL6amsfWY6lrPOzZlb5nVQNh+iIGaFWzOFWbnB1mwBmlNRlr4VVSiSkr3Orl5TRCVrm7NvvufU4vc1EREtgCAASsBaGhbxS2amCaRihVVvpcMue594OHM/FbGOl4pay9TAws9JbZgl3LLuC0oTALs96MgpYHDS3tZcN3O5JLG0Lf5yK6hGFzF/CrCqvNpDKkYiSQyE4wyoiIgoL/7ES0RUg0SfL1Np9e8/AgBIXbyI6OuvI/b6G4i+8ToSR44idf48UufPY+KRRwAAUns7gjuvQejaaxHcuRPq6tUMrIg8JAiCFYxABirkl4sN00iHY4WEW041m1MxllvJ5l7S68zZ9532XFe1nGZOf16+95A0k+nZaJVMFMRpgZcTcLmrydwBmzsAy7uPtIDnzLLdfS5O21AiIqpQggCoQWtp7Fr4cXTNDqtcodW0ICucP9xybrWYdSxnNhfOz3zapgDgIevBc38HvPiUe2tOuJVTsTVtXXPO0gIo/oV/LcpEkTL/xiolaPGXnkGVsD47LTagAoCuZj9GIkkMTsSxedncnRRSuoEfvtyHlG7iEzesWlDlFhERVRcGVEREdULp7kbzO9+J5ne+EwCgT0UQ27fXCqxeew2xN96APjKCyUcfw+SjjwEApI4OhHZeg+DOnQjuvBbq6lUMrIjqnCiI8Ek++CQfGtHo9enMyalCmyncSpmpaSFavn3zhmWuIC2pJ9O3mqGlK9RSesq6tVtPpvdzb9NTWRVxboZpIKEnkNATVVGlJglS3uBrptBLluT8YVmeMM1dWZZVhWaHd87zs/aZ4Tmcn0ZEtEiSDATbrGWhtGQBQVbYXj8B7LWeZoaWAFKHtV1PwprLFbaWBb8fX3ZoFWjJH2Q593Pfh1z66iB3BZVSguAmqGZXUI3HrA8ercFFBFRNfhy8MIGL4cLmwt7x/z6DU0NWdd5XfnEIJ/76HZAl/vILEVEtY0BFRFSnpIYQGm68EQ033ggAMJJJxN98E5GXX0b05VeswGp4GBO/ehQTv3rUes6SDoSu2WkFVtfuhLqKgRURVbZ0FVqVtMozTTNdEZYbfDnBlmZktufOVHPWZ63Lmc+Wuz19PDMFTZ8+sy091y3nmCkjBcM0ss5fN3Xoug7oHn0B58EdiOWGWk7LRXdIllvFNlPwlbUtT4g263Fy2jvy31giqmmyCsgdBc3mEgwT2Psr68Fb/grY+YB1PxXPDrIS4enh1rTAyw7BYuPWrWkAegKIXLKWgjycuftfl1izuHJDrELDLn9zQXO4smdQla6CygmoonYlVdC38DL+pc1WZdrAxNwB1X9//Eg6nHLc+PUn8dK9ty/49YmIqPJVx0/qRERUcqKqIrhjB4I7dgCfsQOrffuyA6uhYUz86leY+JX1w6G8ZIldXbUToWt3Qlm5khfTiIgWQRAEKIIVVFQD3dCzArVpoZa5+DAtd/+Ztrkr1fLtp5vZqZlTCRdDzKOv3txmqhZzh1nuOWY+yQdVVLPWq9Ls2/I+nmEbWzgSUcVR/Nay0LlchmG1F3SCq/i46769xPKsO5tzHGcW1+TFhZ2H2jhnkCXHlgGwPh8oegQY77NncDUWZQZXZgaVFUxF7KAqpC780mFXkxVQXZojoDJNE/ftOTlt/eBEAoMTcSxtqvwWjEREtDAMqIiIKC9RVRG85hoEr7kG2A0YiQRi+/Yh+vIriL78MmJ790IbGsLEL3+JiV/+EgAgL11qB1bXILRzJ5TeXgZWREQ1TBIlSJDgk3xen8qcdEOfNdTSDC2r5aKzONVr7haOWc93qt1yn6vPfIyZ9jFhZp2zsw3TR6l5wqkAyxtiudflbFNEOyArMAgrZJtUQLUBEdWOkv1IIYrWbCp/E4AVhT/vS7/M3P+zMzMHWXMFXim7Yig5aS3h/hlfUk7+OYCtAACl/7fA//tea4Mg5p+tFWjJU9FlPw60Zrbb87ecCqqUbiKpGYjZQZXT+m8hnPlVzjyrmTx+cDB9/7H/eDMafDJu+voeAMBn/+kN/NMfXr/gcyAiosrGgIqIiAoi+nwI7dyJ0M6dAHZbgdXefYi+/HImsBocxMQjj2DikUcAAHJXVzqsCl57LZSeHgZWRETkCUmUIIkS/Kjc38J2V5Olw7A81WFJI2kFXXrmfkJPZGaZ5Tyets1IpGeiOc+fdj/PTDSn6iyqRT36CmVIgjR3sFVoADbTfgWGcNXSQpSomlX0zxCBVmtpXcBz9ZQ9Y2t8epiVE3jJR1qAKetpiiwBsh/Q4laLQuf5CyEHgEALgr42AH8BAIj+9D8icmE7gC4Eh/YCbx7IBFrpcKsZkGav+HYCqrHo7AHVQy9ZJWk3rmvHpq4mAMCnb16Nbz97Gi+eGoVpmpX9PUBERAvGT9JERLQgos+H0LVWaz8AMOLxdGAVefklxPa9CW1gABP/+ggm/tUOrLq7Edp5DYI7r7VmWPX0ePkWiIiIKoosWvPSAgh4fSoArJZL6WArT3g12/2EnsgK0Ao+Ru4+rnXuCjPd1BHTYohp3rdoFAURPsmXabtYpCqx+eznk3zWnDOBs8uIqoqkAKF2a5mD/L1XgMPWjCx509uBD91rzeDKW7U1PkPYZd/Gxqxb0wC0GDAZgzJ5ESpSSEJB5MAvEE2FAHQhePxfgdO/yn9S6daELdZtTojVOtkJoA2j4+PA+dcz4ZavOast4ZGBSQDAPbeuTa/7/Fs34tvPngYAXAzHsaylMv5tJCKi4mJARURERSH6/Qhddy1C112LJfgTGLEYYnv3ZmZYvfkmtIsXEf75vyL8838FACjLlmXPsFq+3ON3QURERA5BENJBiNdM04RmaoWFZAUEaAvdljJSSOgJGKaRPjfDNKywrALmmQkQZgyv8oZe82jVOOM20fUarseyyLCMSsOce5eaJImZ/58UyQ53nBlcjUvnf0DDsNoKxsbsAGscgQenkEwCseu+gOihJmAICHZvAhqm0vsgNg4kJqxjzNGasM3oAfDfMDYeBr79QdcWId2G8LS0GkOTnwYAXHXwb4CzjUCgFQF/C9Y1t+NEWMA3f/4i/vbfrbUCLrWhhD0fiYio3BhQERFRSYiBAELXX4/Q9Va/cCMazQRWL72M2P79SF24gPDPfobwz34GAFCWL7cCq2t3Inj11WwJSERERACssEwRFCiigpAS8vp0rJll86j+WkyAltKtUCxvYKanoJmZIWUmTCT0BBJ6wsOvTka+8GqmSjBnTplP8sEn+7IfuxZVUuGX/DNuc99yVhnVElnKVBzJYhF+RhLFzGwquz2h3/9rhJMJxLd9BNHzR4ChYQRv+APgyi9nP1fXrJAqHW7Zt7GxTIgVG0frZBw4CIyhEXpjD6T4KJCKAjDTlV7fS90CAFgrnEdo7wNZL3Nr6iM4gTtx8sg+4PS77fOWM3O00pVbc9232zAqldvml4jq165du6AoCnbv3o3du3d7fTplx4CKiIjKQgwGEbrhBoRuuAGAFVhF33gD0ZfsGVYHDiB1/jzCP/0pwj/9KQBAWtKB4JVXIXDVlQheeSX8l10GQfX+t7iJiIiovjntGINK0OtTgW7oBVWCOdVfiw3Q0oFZnm2aoWWdW9KwZ5mlvPnayKI8Y4hV0OMZgjLnsTsoy92H4RgVmzuUEkv0S3w+2fq+TWgGoknr/+egmufSoSQDwTZrmUWrbgB/8ShMCAjf84Y1k0pL2iGWFWod/pdhYAC4a3MIWPHnWRVdd42P4H/3AcexAobog2gkAEMDoiPWMl9KMBNWOUuwzfW4Lc+6VkD2zf+1iIgKtGfPHvTU8QgMBlREROQJMRhEw403ouHGGwEARiSC6OtvIPqyHVgdOgR9aBiTTzyBySeeAAAIPh8CW7cicOWV6dBKamnx8F0QEREReUsSJQTEAAKy9/NZDNOYVuHlVH/lVoLlqwxzgi+nCsxZknoScT2etc+0x5r12F1RphkaNENDJBUp+9dCFmSr2kvODrFyQ63cajC/7E+HX05AlvvYue+X/FnP8Uk+dh+oYe4Wf2IxKqjy8MlWlVZC0xFN6gCAUL6AqkCKJKLJL2MirmE0krQCKlkFGjqtBUBf7DcA4rjm1juB3tas56/WDahffhxTmh/HPn0cm9qVrHBrXvdNw6reSkWBifPzfCMhO7RqmSXIyrNOUhb8tSMiqhcMqIiIqCKIoRAabr4JDTffBAAw4nHEDxxA9I03EHv9DcTeeAP6+Diir76K6Kuvpp+nrlmDwPbt8G/dgsDWrfBt3AiRVVZEREREZScKIvyyH35410bL3X4xX6iVLwRLP9byrCvwcUJPZFWQaaYGTdMQ1aJlff/5Aq3Zwq8Zw66cdU6YlnscL2aNmXU6hEoRMy3+XN3+isqnOAGVkQ6oAuriqgHbQiom4hrGoslp2+IpHRfDcQDAyvbp7VsVScQNa9vx1NEh/GzvRXzpHZsANQg0LZvfSTjztqKjdnDlWtLrRqevi4/bwVYECEdmnLU1I7XRDq1a56jScq3zt1gVakREdYJ/4xERUUUS/X4Ed+xAcMcOANZw9OTp04i98Qair7+O2OtvIHn6NJKnTiF56lS6LaCgKPBt2oTA1q3wb92KwLatUFevhiCW6Kc4IiIiIqoYXrZf1A09K7wqpBosvbjCsbgeR0JLpAM19/24Fs/sZ9/XTT19Ds62CUyU5T2Lgpg36Jqpwmu2bbkBWu4+jnotEpOkzBuXSt3iL5Vp8RfyLS6gag2pODMSxcjU9IDq9bNjAIAGn4zWYP5qo9svW4qnjg7h2ODkwk/CPW8Lqwt/nmEAibAdWo3NHGTlrouHAZhWKJacBMJ98ztfX1OeFoR2S8Vgu32/1XW/HVBD9fs/BxFVNQZURERUFQRBgG/NGvjWrEHL+94HANDGxhB7Yy/iB/Yj9uZ+xPfvhx4OI77fuu8QQyH4L78c/q1b4L9sM/ybL4O6ciUEibMBiIiIiKg4JFFCUAyWPRxLGal0m0N3wJUv7MoXcLnXzfV8577DMA3EtBhiWgxIlPqd/g0A4CsvfAX/vxMH5gy4csOudBVYTvjll/0IyIH044AUSK/3ST6IQmX8ops7lCpV1Vq+Fn9BZXGXDtuCVneLfBVUjx8cAABsXd4843ta02FVVp24NLWo81gQUcwERbOP28pm6FZINVOQlVXJNZoJwBJh6/mJCWsZP1v4a0qqK7CyFye8ygq22u2KrjYrsGOoRUQeY0BFRERVS25tRePv7ELj7+wCYFVZpc6dQ+zNNxHffwCx/fsRP3TImm9lz7ZyCH4/fBs3wL/pMvgvuwz+yzbBt2EDxID38xuIiIiIiAqliAoUUUFImd4irRRM00TSyARiTmiVbqM4VyA2S4XYtMf2/ZSRSr++buqYSk1hKlWewCIdZElWkOWX/QA+nt7+p0//aXq7E2w592cKvrL2t2/nCsLcY6ekEs+giqdcAdUiK6haZgmo+sdiAIB3bO2a8fmXL28GAPSNRjE8lUBHg2/GfSuGKGVCovnQNTvYyhNkRUes9dFR+769LjoK6AlATwKTF62l4POU7SqtQoMtuwUhu5MQURExoCIiopohCALUFSugrliB5ne+EwBgahoSJ08hvv9NxA4cQOLIUcSPHoUZiyG+703E972ZOYAoQl21Cv5Nm+DffBl8GzfCt2ED5M5ODpwmIiIiIoLd2cCuWioX3dCx9t7HAACfu/pzePu2xhkrxNzzxgoJy+KavbjuJ41MmOI8J4xw3nN77MxjRXmP7kqwgByYFngdubAFwAoAwOuXXsX/fOPJ6fu6n+MKxXySLx2WSeLMgZPT4m88mgkEg4ucQeW07nMf0zFgz59a0Tpz1WFzQMHm7iYcujiBPUcu4fd3rFjU+VQ0SQZC7dZSKNMEUtFMWJUbXsVG829LRQFDAyJD1lIoQbRCqnxVWblhlnM/0Mq5WkQ0I/7tQERENU2QZfg3boB/4wa0vP/9AABT15E824fEkcOIHz6C+JEjiB85DH1oOD3TauJXv0ofQ2xqgm/9evjWr7NvrUVubfXqbRERERER1Q13qNLsa8aq5t6Svp4zT8wdWsX0WPr+hw+Pp/f90s4vIabF0sFXTItlB145IZh734Se6YuYnh+WzD8/LD7ZACegOjiyHyfeXFgwporqjGHWoZFrASzDz4/9GkAPAOB/H/gHBJTs/XODr3xVYc6fWYsdUIXzBFSXJq33v6Rx9rBz5+o2HLo4gZNDkQW955omCNb8KTUEtMzj/4tUPE945VRo5T62g63EBGAadoXXKDByovDX8zfnD6+cYCu9dGT2mSVMJaLawYCKiIjqjiBJ8K1ZDd+a1Wi68870em1oCPEjRxE/fBjxw4eQOH4cydNnYExMIPbaa4i99lrWceQlSzKB1Qb7du1aiKHytFchIiIiIqo35ehrMPc8sV+m7334sg8v+HUM05g5zLJDMacK7F9eMPHcqPW8rR1bcNXGlmkBmrNvOiiz77uDsKSRRDKZxASmB2HxWA+AZTg+ch5ADyAkcf/+f1zQe1NEBQE5gMTodQB+B0+cfhYDj37D1fowgJGpGwEIeOTsQ3hxVEVADqQXv+RHQLHu+/xWpdXxoTFEUpGsAIwWSPEDyjKgaVnhz9GSOe0G3VVaucGWvS0+bj03HraWsdMFvpiQaT8Y6sgEWKGOTIgVcoVaoQ5AYbt+omrEgIqIiMgmL1mChiVL0HDzTel1RjKJ5OkzSBw/jsSxY9bt8eNInTsHbWgI2tAQIr/9bdZxlJ4eK6xatxbq2rXwrV0H39o1EIPlHZhNRERERESVSxREBJXZgrCMoycO4zmcAgDcsPw6fOG6jQW/jmEa6aqt3DArXfGlxfHQsxqeDwPrm7bj0DjgVyTcvfHuTHWYq4pspgoxR8pIIZVMIaVb7eMm4km8fun1zDmlmmDiJgAGHjr2jxAEY8bz1yYvA/Ax7Dm5H9c9/IcArEqwgBJIzwbLtzhhmPt+UA5mrc/dzwnHGIDlIatA41JrKZSuWSFVvvAqt4IrOgJEhu1Qy3RVah0v7LWUoB1W5VRjhdz3nYCL87SIKgUDKiIiolmIqppuEQi8M73eiESQOHnSDq6s0Cp+/Bj0oWGkzp1D6tw5TO3Zk3UsZflyqOucwGptOsCSGhrK/K6IiIiIiKiaiK6ZuOI85+OKgpgOX2bzxuFDeB6n0e7rATCMZn8Qf3HdXxT8OoZpWDPA7AqumB7DniMj+MpPL2Flwyb8+W3ftNanYnj1pIF/PgF0tCTx3svustbbS/r59hI2NcQAGMkOmKYIQTCsSrBEcsbZYIvlbl+YW9UVlIPZwdgsQVluSOYEZKJQJ8GIJFuhUKij8OfomhVMRYbt4GrYvj/quj+SHWoZKWuuVrjPWgohSDmVWW2Zaqys9a6AS1YX9nUgohkxoCIiIloAMRRCYNs2BLZty1qvjY1ZgdXJE0ieOGmFWCdPQh8eRur8eaTOn0fk6WeyniN3dWUCq3Xr0hVXUnNzOd8SERERERFVKGkRAVWhfLJVNRSOpbIeF0oUxPQ8qha0AACGljQDuAToAbx15W3pfS+dPw7gGH5n3Trce+37Zj2uphu49qu/wUgE+NaNj2HLCn9WkBXVolmhVm7AlRV+2VVjsVTOY3txODPBkJjlxBbBmdc1UxXXTCFXUA5at0ow7+OAHIAsVvnlXkkGGjqtpRCmac3Hio4AETvQcoKrqB1sZYVdI0ByEjB1IHLJWoYKPDdfU552g205rQddYZev0ZoTRkQzqvK/sYiIiCqL3NoK+dqdCF27M2u9NjaG5KlTSJw4icSJE0iePIHEiZPQLl2CNjAAbWAAkeefzz7WkiWZiqt1a+FbawVYcmtrOd8SEREREVHFqNdrvaLrfUslKr5R7ANPxTUAgF9Z/As1+hUAwIR9TMdoNAkAaG/wzXkMWRKxeVkTnj0+jIGwhpvXt6IVxf+ZyDTNrLaHM4ZcOWFYVIvOuC03GHPEdasl4nhivOjvQxXVGQOsoBxMV4HNuj7PY0VSin6uRSEIgL/ZWtrWFPYcLeEKsUZy7g9PD7uiI4BpWEFYYqLwWVqSmqnKCi2ZfhvMeaxyLADVHwZUREREZSC3tkK++moEr746a70+MYHEyZNInjxphVcnTyJx8gS0CxfTM66iL7yY9RyprQ2+deuyZ1ytWwupvR1Cvf7ETkRERERUw9yf80WxNJ/5Zck67lTCCpPmW0GVT6PfuvQ4GU9lrR+LWAFVW7CwlmnLW6z2hOfHY3PsuXCCIBTUCnGhDNNIh1ZxPT6tiiu3EixfUBbVounnOfejWhRRLQrDtOZ4Oe0Pix1+yaJcUBVXQY9d4ZdP8pX/51jZBzQts5ZCGIY9S8tdmTUyc+vByDCgxQA9CUxesJZCKEFXiLUkJ9xaYs3Tcm9jy0GqAQyoiIiIPCQ1NSF45ZUIXnll1np9KoLkKSe0yrQLTJ07B310FNGXX0b05Zezj9XcbLcIXOsKsNZB7lzC4IqIiIiIqIpJYulb/CnTAqrFV1A12RVUCc1AQtPTodeIHVC1hgq7wL7MDqgujscXfU5eEQURQSWIoFL8KhnTNJE0kunAKqbFEE1FM0FWzuPZtuU+1gzr+0EzNEwmJzGZnCzquTsz0goJtoJKMH0bUkKZdc562VofkAOQxMUHrJmTFO22fW1Ax/rCnpOMZoIrJ9SKDNmLs34oc6snrDla433WUghf88zVWblVWsE2oJhfE6IiYUBFRERUgaSG/DOujGgUidOnkTxxIqviKtXXDz0cRuy11xB77bWs54iNjXZ7wOx2gXJ3N4MrIiIiIqIqkNXir0Sf4WXRCqSiSR0A4CtCi78Gf+bS42Rcg6/BukA+5rT4KzCg6m72AwAuhEtXQVXNBEGAT/LBJ/nS87+KJaWnMsFVvkArTyiWt9or53kJ3RrwZZgGIqkIIqlIUc/bL/mzAi0nvHICL3fA5dwPKAGE5NC0/YNyEH7ZD1GYx/8TahBQe4GW3rn3NU0gOZUdWOXeRnNCLVMHEmFrGT1ZwAkJmRlZ0wIsd5WWvc3fXL89VamsGFARERFVETEYRODyyxG4/PKs9UY8juSZM9kVVydOINnXB2NyErG9exHbu3fasdTVq6GuWQN19Sr41qyBunoN1FUrIfrm7gVPRERERETlUY4Wf04FlcNfhBZ/kiigwSdjKqFhMq6hw545NTI1vwqqFW1W1dGRgUkYhlmyrwFNp0gKmqVmNPuai3pc3dCzw6xZKrkiqUh6fSQVybQ2TEUR0SKIpqLpx7ppBazOnK9RjBblfAUI6Youd6XWtGquPIFY7v7O43R7Q0EAfI3WUsgcLaflYMRVlZUVYA1ZM7Sc+7ExAKbdmnAYGDoy92tIqivI6rTuNzgBVqe1vqEzE3JJjBloYfidQ0REVANEvx/+TZvg37Qpa72RTCJ55kzWjKvkyRNInDkLIxpF/OBBxA8ezD6YIEBZvhzq6tXwrVlthVirrRBLXsJ2gURERETkHQH1+Vk0u8VfaV5DkbKrQ4pRQQVYc6isgMqaQ6XpBgYnrFZ9TmXUXK7sbUFQlTA0mcCJoSlsWNpYlHMj70iihAa1AQ1qQ9GO6bQ6jKasIMsJrdIBlh1wuQMt962zzQnFnG2m/Z9zjGIRBXHm1oU57QvztjRUggg2dyLUsTr9WJXyhL66BsRGZ2gxOGTPzXJtS0xY87MmzltLIQJtmcDKWdyB1srrgUBr0b52VDsYUBEREdUwUVXh37AB/g0bstabqRSSfX1Inj5ttQw8dRrJU6eQOH0axsQEUufOIXXuHCLPPpt9vIYGqGvWwLd6lR1aWSGWsnIlRJUDWomIiIiISiGrxV+JEio5N6AqQgUVAIR81uXHSMKqbBmaSsAwAVkU0hVVc/HJEla0BnF0cBKDE3EGVJSXu9Vhq784YYhpmlktC90VW3OGXc5z3GGZXR0GWO0Np1JTmEpNFeVcAUAW5XSg5VRwhWTXfXt9KBRCsHktQsq27G1yCCFBRjAZRSgxBTk2BkxdcgVYQ/bjYSByyQq3TMMKwGKjM1dnffpJYPnVRXufVDsYUBEREdUhQVHgW2vNonL/aGeaJvTRUSu4OnXKCq7sECt17hyMqSnE33wT8TffzD6gKELp7oa6shfKypVQe1dCXdkLtbcXyooVbBlIRERERLQIorvFX4k6GuS2+PPJxamgCqpW0BVNagCAC+NW9dTSJv+8wrYljT4cHZzE0GSiKOdFVAhBENJVSwgU55i6oSOux7MCrnQbw9zqr5m25VSHOTO9NEPDRHICE8mJopyrT/Klq7dCSgihYAjB5uUIKRus9VIAIYgImQZCuoZgKolQKoZQMoJQfArBWBih6Bhagx1QinJGVGsYUBEREVGaIAiQ29sht7cjuGNH1jYjmUTq7Fkk7NAqefpU+r4xNYXU+fNInT8P/PaF3INC7uqC2tsLdaUVXCm9vVaI1bsCYqBIn/KJiIiIiGqUO5QqWQWVmB1I+ZXiVFA5AVUkaVVQDYTn197P0d5gdWwYjSSLcl5EXpFECSHRqmQqlpSRSodYTpAV0SLpECySiqQXJ/TK2qZl75cyrJacCT2BhJ5Y3CwvH/CwNoGtRXqvVFsYUBEREVFBRFWFb/16+Navz1pvmia0oSGk+vuRPHPWah3Ydxaps31Inj0LIxKBdvEitIsXEX3ppWnHlTs7rUqrlZnQSlnRC7V3BaSmpnK9PSIiIiKqBvU5giqrxV+pZlDJORVUuRVVCxVSrcuPMbuCajxmBUytofm1CG8NWvuPMKAimkYRFSiqgia1OD9Dp/RUwSHXbEFXNBVFRIsUNYyrNbt27YKiKNi9ezd2797t9emUHQMqIiIiWhRBEKB0dkLp7ETw6uye0qZpQh8bQ/LsWaT6+pC0Q6tknx1eTUxAu3QJ2qVLwKuvTju21NJiVVutWGFVXtnBldrbC6mjA0KJ2psQERERUYUyvT4Bb4hi6Vv8qTkzqHJnUi1UMGcGVSRhBVUNvvldlmy3A63RKQZURKWmSApapBa0oGXRxzLNOv2Lu0B79uxBT0+P16fhGQZUREREVDKCIEBua4Pc1gZceeW07fr4uB1WWYFVqr8Pyb5+JPv7oQ8PQx8fhz4+Pn3mFQAhGITa0wOld4Wr8soKr5TubggyP+YQERERUW0oS4u/nIopuUivE1SyZ1BN2UFVyDe/FoJtTou/KAMqomrCXyyl2fDKDREREXlGamlBoKUFgW3bpm3TpyJInetHsq/Pah94tg/J/j6k+vqRungRZjSKxLFjSBw7Nv3Asgxl+TKoK3qtwMquulJ7e6H09ED0z6/fPRERERFViDq9zukOqEpVQZU7g6pYQVjQ5wRUVjAVtSuoQvOsoHJmVp0ejhTlvIiIyHsMqIiIiKgiSQ0hSJs2wb9p07RtZjKJ5PnzVnDV12/NvLIrr1L9/TCTSaTO9iF1tg/5fnyVly6FumKFNffKbhuorOiFurKXc6+IiIiIqOK4u+2JJaqgyp05pRSrxZ+aHVBF7EqqBnV+lyW39bQAAE5cmkI8pcOvzK8Ci4iIKg8DKiIiIqo6gqrCt3o1fKtXT9tmGga0wUEk+/ozLQP7+qwZWP39MCYnoQ0OQhsczD/3qrnZmnvlVF45AVZvL+QlS9iegIiIiIjKzv0ZVCpVBZVUmgqqgB0kxVNWQJVp8Tf/GVSSKEA3TIxHU+hqZkBFRFTtGFARERFRTRFEEUp3N5TubuDanVnbTNOEPj5uhVV9/VbLwLNWcJXs74M+NAw9HIa+fz/i+/dPP3YgYM+96rVbBtqVV6tWQVnWDUEszm+ZEhERERG5Zc+gKs1r5M6cKtYMKqfSKWYHVBG7xV/DPAMqQRDQElAwEkliLJpEVzPbdhMRVTsGVERERFQ3BEGA3NoKubUVge3bp203IhEkz52zK66c+VdWmJW6cAFmLIbE8eNIHD8+/djBIHxr18K3bh1865zbdZC7GVwRERERFUu91rK7Q6lSVfSrcvZn1mIHVE4F1WQ8BQBo9M//smRL0AqoxqOpopwbERF5iwEVERERkU0MheDfuBH+jRunbTOTSaQuXLCqrVwBlrOY0SjieSqvxGAQ6tpMYOVbtxa+9eut4IrtAomIiIioAGI5WvzlBFJSkUq1AukKKgMAEI5Z4VJTQJn3sZznTMQZUBER1QIGVEREREQFEFQV6qpVUFetmrbNTKWQ7O9H4vgJJE4cR+LECSRPnETizBkYMwVXzc3wb9oE/6aN8G26zLpduxaCqpbpHRERERFRtciaQVWkyqZcSk4gpZSogmoiZrX4a15AQOU8ZyLGgIqIqBYwoCIiIiJaJEFR4FuzBr41a4A73pZeb6ZSSPb1IXHipCu4OoHE6TMwwmFEX3oJ0ZdeyhxIUeBbuxb+jRvhu2wT/Js2wbdxI+TWVg/eFRERERFVCnfVVKmK8GUpp4KqSAFVQLWCr3RAZVc/NfkXUEFlPyc8S0AVT+n49P95Fc8eH562beeqNnznE9fMe/4VERGVBv82JiIiIioRwQ6cfGvXZgVXRjKJ5IkTiB85iviRw0gcPoL40aMwJiaQOHIEiSNHgJ//PL2/3N2NwJbL4d+yFf4tlyNw+eWQWlo8eEdERERE5AV3ViSWrMVfzgwqqUgVVHKmgiqlG4gmraCqKTD/y5LOcybiWt7tKd3Apv/02IzPf/nMKLZ8+XGc/OqdJatEIyKiwjGgIiIiIiozUVXh37wZ/s2bAbwHAGCaJrQLFxA/cgRxO6SKHzmKVH8/tIsXMXnxIib/7dfpYygrViCwdQv8l2+Bf8sW+C/fDKmhwaN3RERERFQe9TrDUyhDBZWa0+IvN7BaKL/qzKDSs1rzNS6ggspp8ReOJvNuPz44VdBxTg9HsK6Tn52JiLzGgIqIiIioAgiCAGX5cijLl6PxLW9Jr9cnJxE/fBjxAwcRP7AfsQMHkerrQ6q/H6n+fkz86tH0vurq1fBv3YLAFju0uuwyiIGAF2+HiIiIiIrIXewjoEQVVDkVU3KxZlClK6iMdOVTo09eUAXT8pYgAOD0SDTv9q8/diR9/8/evgkJzarWUiQRumHim/92DADwyL4L+NxbN8z79YmIqLgYUBERERFVMKmxEaGdOxHauTO9Th8fR/zQIcQOHET8wAHEDuyHduEikqdPI3n6NCb+9RH7yRL8GzfCv30bAtu3I7B9O9RVq+r2N4+JiIiIqlU5KqimBVRScSqoAnYFVTypp2dHNQXmXz0FABuWWlVPp4byV0o9fWwIALBleRP++La107Y/8OwpTMQ1/P1vjjOgIiKqAAyoiIiIiKqM1NKC0A03IHTDDel12sgI4gcPInbggFVttX8/tKEhxA8dQvzQIYz/8J+s5zY3uwKrKxDYthVSU5NXb4WIiIiIClCOXy9ScmdQFauCSrGOG9cyLf4a/Qu7JNkaUgEgq1WgI6Ub6fv/9x2b8j7/i3dsxF/9/OCCXpuIiIqPARURERFRDZDb29Fwyy1ouOUWAPZMq4EBxPbtQ2zvPsTefBPxAwegh8OIPPMsIs88m36uumZNusIqcMV2+NatgyDzYyIRERFVnnqtAxeyWvyVhigKEAXAMK3HC2nBl09AsSqoUrqJMXt2VNMC5k8BmWBrKqHBMEyIrnMcnIgDsGZp3bSuI+/z79zajb/6+UEIApDQdPjs9oNEROQNXnkgIiIiqkGCIEDp7obS3Y2mt78dAGAmk4gfPWaFVvaS6utD8tQpJE+dQvinP7WeGwwisGULAq7WgPKSJV6+HSIiIiIAgOn1CXhELEdCBautX1Iz7PvFqqDKhEDjUavyyWn7N19OsGWYQDSlo8GXubR5ejgCAOhu8WcFV27tIRVBVUI0qeOZY8N46+alCzoPIiIqDgZURERERHVCUFUEtm5BYOsW4CMfBgBoo6P4/7d35/FR1ff+x9+zz2RjCxAgEAiEhE32oCIgWqHazbqBvXXrrdVbvG1/9fba9udV23ur9V67/bxee1vqUrdqq9W2YhUVUFxAVtkSQgKBEAh79tnP749JJjPZyDKTyfJ6Ph48Mvmec77zPePJZDzvfL7f+k8/Vf3OnXLv3Kn6nZ8qWFurus2bVbd5c/hY2+jRcs2aKecFodDKOXWqzA5Hok4FAABgYInKp+KXUNkjAypzbNagclib+mmcms/WxfWtHFazrGaT/EFD1W5fVEB10+9Cn11LT9e1ebzJZNKQJLvqvPV69uNSAioASDACKgAAgAHMOnSoUi+9VKmXXipJMgIBeUtKGiqsPlX9jh3yHDggX3m5fOXlqlrzRuhAm03OvLyoqQFtmZlRC3gDAADE2kD9pBF53vH8uBVZeBSrCiqTySSXzaJ6X0BV7lBAZbd2rW+TyaRUp1Vn63yqcfulQS33mTg8ud0+br4oSw+9UaDTtZ4ujQEAEDsEVB30q1/9Sj//+c9VUVGhefPm6bHHHtPMmTMTPSwAAICYMlkscuTkyJGTo8HXXSdJCtTUyL17d3gtq/qdOxU4fVruXbvk3rVLZ599VpJkGTIkHFa5Zs6Uc8YMWVJSEnk6AAAA/UJP/RFQ5NR41hitQSVJSfZQQFXZzQoqSUp12nS2zqcqtz+qfWSaQxVVHv1ixax2j1+UM1wPvVGg3UerujwGAEBsEFB1wPPPP6977rlHv/nNbzR37lz913/9l5YvX679+/crLS0t0cMDAACIK0tKipIvvFDJF14oSTIMQ76jRxsCq9BaVu69+xQ4e1Y169erZv360IEmkxyTJso5c6ZcF1wg18xZckyaKJOFxagBAAA6w9TG41izRARhlhgGVI3rUFXVh0IlezcCqsZp/Wo80QFVTUNg1bhOVVvGDHaFH+8pr9S00a2UYQEAegQBVQf84he/0J133qmbb75ZkrR69WplZGTo+eef15133png0QEAAPQsk8kke2am7JmZGvT5z0mSgh6PPPv2hSqsdoRCK9/Ro/IUHZCn6IAq//SyJMmclBRax+qCC0KVVhdcIGt6eiJPBwAAoNeLLKCKZzVVZN/dqXJqzmUPBVThCiprdyqoQrczqxumC5SkQNBQrTcQtb0tg5KaAqyXtx4loAKABOr0b4OjR4/ql7/8pZYtW6Zx48bJbrcrIyND1157rTZt2hSPMZ7Xs88+qzvuuEPz5s2Tw+GQyWTSU0891e4xn3zyia666ioNHjxYycnJuvDCC/XSSy+12M/r9Wr79u36zGc+E26zWq269NJL9dFHH8X6VAAAAPoks8Mh16xZGnrzzRrz859p0jtvK2fj+8r8n8c07BvfUNKCBTIlJSlYV6e6jz/W6d/8RmXfXKWiSxbpwOWf0dHv3q0zv/+96rZvV7C+PtGnAwAAeqmButylKaJuKp6vQWQmFcsKqsaKqVpv9yuomgKqpgqqyLAq9TwVVJJ0/dxMSdLJGtahAoBE6nQF1aOPPqqHH35YEydO1LJlyzR8+HAVFRXp1Vdf1auvvqrnn39eK1asiMdY23TvvfeqtLRU6enpGjVqlEpLS9vdf926dVq+fLmcTqdWrlyp1NRUvfzyy1qxYoWOHDmiu+++O7zvqVOnFAgENHLkyKg+RowYoeLi4ricDwAAQH9gTU9X6mWXKfWyyyRJRiAgz4EDqt+5M/zPW1wi39Gj8h09qqo1a0IHWixyTJwo57Rpck6bJtf0aXLk5cnsdCbwbAAAABKnp4I5syk+a1DZLKG+ahum5Wv8visaA6iaiIDq4Kna8GN7B6qzFk8erj9uLVNFpbvL4wAAdF+nA6r8/HytX79eS5YsiWp///33dfnll+uf/umfdPXVV8vhcLTbzzPPPKPFixcrKyur1e2BQEC/+tWvdNddd8lut7fb1+rVq5WTk6OsrCz99Kc/1Q9+8IM29/X7/br99ttlNpv13nvvadasWZKk++67T/n5+frhD3+o6667rs1xAQAAoGtMFoucubly5uZqyA03SJIC1dVy79oVCqx27FT97t0KnD4tz/798uzfr8o//zl0cGRoNX2aXNMIrQAAwMDRU2tQRQZU5himYtaGiqn6hmn4ujN9YGtT/P34b3s71cfItNBnyM2HzsgwjLhOmwgAaFunA6prrrmm1fZFixZp6dKleuutt7Rr1y7NmzevzT7Kysp0++23a9SoUVq/fn2LMCgYDOqWW27Rc889J7vdrrvuuqvdMUVOv3c+7777roqLi3XbbbeFwylJGjRokH74wx/q1ltv1dNPP6377rtPkpSeni6LxaKKioqofk6cOKGMjIwOPy8AAABasqSmKvnii5V88cWSJMMw5K+okHvPHtXv3i33nj1y79nbdmg1aVIotJoyRc68XDlyc2VJS0vgGQEAAMRB1BpU8XsasznycQwDqoa+6nyxC6iqIiqoGqcM7Oi0hHmjUsOPD5yoUc7I1Hb2BgDES6cDqvbYbKESW6u1/W4zMzP1wgsv6IYbbtDSpUu1fv16jRs3TlJ0OHXzzTfrm9/8ZiyHqPXr10uSli1b1mLb8uXLJUkbNmwIt9ntds2ePVvvvPOOPv/5z0sKVWGtX79e//Ef/xHTsQEAAAx0JpNJtowM2TIylHr55ZLOE1oVFspTWKjKiD5so0fLkZfXEFiFvtrGjpXJHLuFvgEAAHqSqYdqqCKrpiwxTMIaA6m6hgqqjkzD15b0lNCsTRVVTdPzjRnskiTd89ncDvWRFrFOVfHJWgIqAEiQmAVUhw8f1ttvv61Ro0ZpxowZ593/y1/+sl544QXdeOONuvTSS7V+/XplZmbq1ltv1bPPPqt/+Id/0JNPPilzjG8kFBUVSZJycnJabMvIyFBKSkp4n0b/5//8H/3jP/6j5s6dqzlz5uiRRx6R1WrVV77ylTaf57HHHtNjjz0mr9cb0/EDAAAMNO2GVrt3q37PHnn2FchdWCj/sWPylZfLV16umnffDfdhTkqSY/JkOfJCUww6cvPkyJkkSyo3IwAA6EsG6kxspp6qoIqa4i92/Vob1pzy+oOSmiqeumJocmgpkMr6pin+ahrWtkqyd/xW57KpI/XW3gqdrPF0eSwAgO6JSUDl8/l00003yePx6OGHH5bFYunQcdddd50CgYD+4R/+QUuXLtW8efP00ksvaeXKlXr66adjHk5JUmVl6O9rBw0a1Or2tLS08D6NvvKVr+jkyZP64Q9/qIqKCs2bN09vvvmm0tqZPmbVqlVatWqVysrKNHbs2NidAAAAAKJDq4jpngPnzslduF+ewlBg5SkolKeoSMG6OtXv2KH6HTui+rFmZMgxaVLoX07oq33iRFlSUnr4jAAAANoWy/Wg2n+epsexXJfJ2uwen83S9b6dttB9x/qG6QKlpsqsFEfHb3UOTw1VYh08WdvlsQAAuqfbAVUwGNStt96q9957T7fffrtuuummTh2/YsUK+f1+ffWrX1VJSYmuvvpqPfvssx0OuXrKt7/9bX37299O9DAAAADQDsvgwUpekK/kBfnhNsPvl/fQIbkLCkPBVUGhPAUF8p88Kf/x4/IfP67ajRuj+rGOGtUUXDWGVxMnypyc3NOnBAAAEF1BFdfniU8FVfNAytaNKf6S7A0BlbcpoGqqoOr4/cQRqU5J0t93H9N9X5ja5fEAALquWwFVMBjU1772NT3//PP66le/ql//+ted7sMwDL0bMf3Knj17VFFRodGjR3dnaG1qrJxqXiXVqKqqSkOGDInLcwMAAKDnmazWcNCkz38u3B6orJSnuFieogPyHDggz4EieQ4cUODkKfmPHZP/2DHVvv9+VF+20aNlnzhRjuwJsk/Ilj17ghzZ2bIMGxbTvzIGAACIFLUCVRw/c0T2HMuqLauleQVV9wOqOm9kBVUooEruRAXV2KGhdavKK93n2RMA4mfp0qWy2WzhGdkGmi4HVMFgULfddpt+//vf68Ybb9RTTz3V6Sn5DMPQN77xDT3xxBNasWKFvvCFL+iWW27R0qVLtW7duriEVI1rTxUVFWnu3LlR244fP66amhrl5+e3digAAAD6EcugQUqaM0dJc+ZEtQfOnWsIrIobvob+BU6dCq9v1Ty4MqelyTFhguzZTaGVfUK27GMzZbLZBAAAYsMU1/qh3qunKqgixTKgsjUrx3J0o4KqcYo/d8QUf7We0OPOVFBdmjsi/NjtC4T7BYCetG7dOmVmZiZ6GAnTpYAqMpxasWKFnnnmmU5PyWcYhu644w6tXr1aN9xwg5577jlZLBaZzWbddNNNWrp0qdavX69Ro0Z1ZYhtWrJkiR566CG99dZbWrlyZdS2N998M7wPAAAABibL4MFKmjdPSfPmRbX7z56V98ABeUoOyltSIs/BEnlLDsp39KiCVVWq37lT9Tt3Rndmtco+blwotJqQLfuECaHqq+xsWdpZzxQAALTOkJHoISRIzwdzsVwa3tp8ir9uVVCFbmdGTvFX21BB1Zk1qIYk2WQ1m+QPGjpT69Xowa529/cFgvrvdw/oV+8UaeX8sXrwyzNkjuU8iAAwAHU6oGqc1u/3v/+9rr/++i6tF2UYhv7pn/5Jv/3tb6PCKUm68cYbJSkqpMrIyOjsMNt0+eWXKzs7W88//7y+9a1vadasWZJCU/49+OCDstvtuvnmm2P2fAAAAOgfrEOGyDp/vpLmz49qD7rd8pYelvdgiTwlodAqFGAdlFFfL29JibwlJarRO1HHWdLTw1VXjaGVfUK2bKNHyRTLO0IAAKDPi6qg6qFMpLdO8edqqHSqj6igqmuooOrMFH8mk0lDku06We3pUED1u40H9at3iiRJf/jkiJbmjdDyabG7ZwkAA1GnA6of//jHevrpp5WSkqLJkyfrP/7jP1rsc/XVV4eDn9aUl5frz3/+s66//no999xzslqjh3HjjTfKMAzdfPPNevvtt/XVr3613TGtXr1aGxsWtt61a1e4bf369ZKkSy65RF//+tclSVarVatXr9by5cu1ePFirVy5UqmpqXr55ZdVWlqqRx55ROPHj+/gqwEAAICBzux0ypk7Wc7cyVHtRjAof0VFU2h1sCRcfeU/cUKBU6dUd+qU6j75JOo4k9PZME3gRDkmZofWvJo4UfZx45guEAAw4A3YKf6iHvfMaxDPKf5slq737WqYxs8fNOT1ByVJ3kDoa7K9c7c6hzUEVHvKKzV9zKA296vx+PXTNwqi2u54ZqsOPnQV65ACQDd0OqA6dOiQJKmmpkY/+clPWt1n/Pjx7QZUY8aM0UcffaRx48a1CKcafeUrX9G8efM0efLkVrdH2rhxo55++umotg8++EAffPBB+PvGgEoKLTy2ceNG3X///XrxxRfl8/k0Y8YMPfzww1qxYsV5nw8AAAA4H5PZLNuoUbKNGiUtXBi1LVBTI+/Bg/IePBgdYB0qleF2y7N3nzx790V3aLXKnpUVWt9qYrYcjcHVhAkyu9r/i18AANC3RYYg8cxD4lWp1byCyh6DCiopVEXlbwinJCnJ0blZnoJGaMrITSVntGL+uDb321p6ttX2Zz4u1c0Xje/UcwIAmnQ6oHrqqaf01FNPdfuJs7Ozz7tPR8IpqWtjys/P1xtvvNGpYwAAAIBYsKSkyDVjhlwzZkS1G36/vEeOhKYILC6Rt/iAPMWhqQONujp5i4vlLS6W1kYcZDLJNnp0Q2g1KVR11VB9ZRnU9l8CAwCAviMRNTqxneIvui+7tesBld1qDq8dVe8N6FSNR5KUnmLv9NSBCyYM0/6KGtV4/O3ud67O22r7fa/tIaACgG7odEAFAAAAID5MVqscEybIMWGCUi+/PNxuGIb8x46FQquS4obQqljeA8UKnDsn39Gj8h09qtr33o/qzzI8PTRVYE6OHJNz5Jw8WY6cHJmTk3v61AAAQDckYhY5izmWU/w1q6DqRkAlhaqoqj1+1fsCKj9XL0nnXUOqNfkThuqZj0t1rs7X7n7f/sOO8ONHrp+pf/njzk4/FwCgJQIqAAAAoJczNVRJ2UaPlhZdErXNf+aMvMXF8hQXN1RdFctTUiL/8eMKnDylupOnVLdpU9QxtsxMOSZPbgqtJk+WPSuLNa4AAL3eQF3uJ3LdqZ56DWKYT8W0gkoKrUNV7fGrzutXRZVbkpSR5ux0P0OT7ZKkyvq2AyqjYRrARtfNzdQnB8/oxS1HwscOcvEZCgC6goAKAAAA6MOsQ4fKOnSokubPj2oP1NSEpgo8UCzP/v3y7N8vd9F+BU6ekq+sTL6yMtW8+254f5PNJnt2dovgypqRweLfAAAkWNTaUD004V8sf/83n3qvO2tQSaGASpLcvoDqvAFJUoqj87c5G4Olc/WtT+EnRYdX2emhKvR7Pz8lHFD99I19euiaCzr93AAAAioAAACgX7KkpMh1wQVyXRB9w8R/9qw8+4vCoZVn/355iooUrKuTp7BQnsLCqP3NaWmhwCpvipxTpsg5JU/2SZNkttt78nQAAECDvvh3I1ZzjCuobKGAqs4bkMcflCQ5Gto6Y3BSKKCqqPK0uc9LDUGUJP18xSxJ0WHYC5uPEFABQBcRUAEAAAADiHXIEFkX5Ct5QX64zQgG5Ssvjwqt3Pv3y3vwkIJVVarfslX1W7ZGdGKVY+JEOfPy5Jw6RY68KXLm5coyaFACzggAgP7P3EOpVLyqs5qvZxWLKf4kqd4bkNsXqqBydKHP9BRH+HHp6VplDWu5TueDawrCj2eNHSypZXVZrcev5C5UcAHAQMc7JwAAADDAmcxm2TMzZc/MVOpll4Xbg15vaJrAwkK59+6Tu6BA7oICBSsrw9VWla+9Ft7fNnq0HFOmhIMrZ16erKNHM0UgAADdFDXFXx/8tdo8YHNYOl/tFCmpMaDyNVVQObtQQeW0WeSyWVTvC2jTwTOtBlRt+fkNM/Xdl3ZKkp784KDuuiyn088PAAMdARUAAACAVpnt9lDYlJenQV/6kqTQQuH+Y8dCYdW+fXLv2yfPvgL5jh6Vr7xcvvJy1bzzTlMfaWmhPqbkhSqtpuTJMXGiTDYWEwcAoKN6ag2qeIVfzSuoHLbuVVDVekJVU+8XnZKzoa+uVFBJ0rAUu8rO1qu+YS2rSB8eOBV+/L3luVHbLshsqhw/VdP2GlYAgLYRUAEAAADoMJPJJNvo0bKNHh1VbRWoqpK7oECeggK594XCK8+BAwpWValu82bVbd7c1IfNJkdOjhxT8uScMjVUbZWbK3Nyx/9qGQCAgSQylOqTFVTNp/izdC+gOl7pliQ5bWa5fY1rUHWtz4UT0/XiliOqdvtabPvF2/vDj5dNHRm1LWOQK/w4zcUf3gBAVxBQAQAAAOg2S1qakvPzlZzftLZV0OuVt7g4FFgVhCqt3AUFClZXy713r9x796pSr4R2NplkHzdOjqlTQqHVlFC1lTU9PUFnBABA7xFdQdX3WCJOwGo2tQisOuuG+WP1/94pkqSmKf6sXZs2MNUZuj1a7fG32HbodF3EftEhVIrDqvHDknTodJ0Onart0nMDwEBHQAUAAAAgLsx2e0PQNEXSlyWFpgj0lZU1TQ+4N/TVf+KEvKWl8paWqvqNv4f7sA4f3hBaTZEzb4qcU6fINnYs61oBAAaUvv5bL7Jgyt7FqfgiNU7n5/UH5fGFpubragVVY/BU7W4ZUAWCRvjx4KSWVVJXTB2p375/UH/ZWa6Hr71ALnv31tYCgIGGgAoAAABAjzGZTLKPHSv72LFKW7Ys3O4/fbphasC98uzbJ/e+AnkPHZL/5En5N5xU7Yb3wvuaU1LkzMuToyH8ck6dwrpWADBA8AcKfXSKv4hBxzKg8viD4QoqRxcrqFIaKqhOVHmi2qvdPp2pbVpbymlr2X9txLpVp2o8Gjs0qUtjAICBioAKAAAAQMJZhw1TyiULlXLJwnBbsLZW7sL9odCqoEDuvfvk2b9fwZoa1W3ZorotW8L7mmw22XMmNVRshda1ckzOlSWFda0AoD8xDOP8O/V7fS+hskRM6dfd9aekiIDKF5S7oYLK2eUKqtDt0Y9LTke1P7fpcPjxf153QavH+hrCMan1CiwAQPsIqAAAAAD0SubkZCXNma2kObPDbYbPJ09JSWh6wH375N67L7yulWdvaMrAtte1ypNzyhTWtQIA9DlRa1DFMZ8KxikAjAyobDEIqBqrsLyB7ldQjUh1SJJqmq1Bda7OF36c5mz9FupNF2Xpj1vLJIUqqAAAnUNABQAAAKDPMNlscubmypmbK119taSGda2OHpV7796mda0KCuSvqGh/Xau8pikCbZmZMpm7f8MMABBfTPEX3/qpeBWoRU7xZ7V0/wysDb+z/UEjIqDq2u/xWWMHhx9X1vs0yBWaMjhyKkKXvfVbqBdkDpbTZpbbF9TNT2zWoZ9+rktjAICBioAKAAAAQJ9mMplkz8yUPTOz1XWtPAUNlVb79rW7rpUjL7eh0qphXavsbJns9kScEgAAzfRMMBevCRQjK6giH3dVY8gVCAbl8Yem+OtqQDU4qel3fVVEQBXZn62dMbt9wTa3AQDaR0AFAAAAoF9qc12r/fujpghsXNeqfstW1W/ZGt63xbpWU/LkyM1jXSsAQELFs4osXlP8RVVQxSCgagy5/AFDnoaAyGHr2hR/kjQs2a7TtV7VeQPhNnMHX+dUh1XVHtafAoCuIKACAAAAMGCYk5OVNHu2kmY3X9fqoNz79oZCq30Fcu/b1/66VlOmyJmXF6q6ysuTdeRIpp0CAPSIuP62iVMJVXQFVfen1G0MuQJBI1xBZe/G2lZJDotO10q13qagye1rCqvmjh/S5rFPfW2+rn38I0nSwVO1mpDOH7IAQEcRUAEAAAAY0ELrWk2WM3dyq+taeQoKwlMERq1r9femda0sgwfLkZcnZ25u6OuUPKYIBIA4GKh/ChD5NxDx/HuIeE3xF1k1FZsKqlAY5QsaqnaHQqVUZ9dvcybZQsfWeZpCqcaA6vZFE+Swtl2ddUHm4PDjf/njTr38Txd3eRwAMNAQUAEAAABAM5HrWilyXaszZ0LTAhYWhNa3KiyQp+SgAufOqe7jj1X38cdNndhscmRny5mXK0felIavebIOafuvsAEASCQjXlP8xXoNqoY+aj1+efyhKf4GJ9m63F+SIxRARVZQ1TcEVK7zTB1oi6jc2lp6tstjAICBiIAKAAAAADrIOnRoy3WtPB55DhwIVVoVFDZ8LQhNEVhYKE9hofTaX5r6GDkyNDVgbl44vLJnjZPJ0vW1MwCgv8sdmarCimpdMik90UNJCFPU4/iVUMWrgsoS4zWorJZQH6drPOE+Uxxdv82ZbA8dWx+xBtX7RackSU47v58BIF4IqAAAAACgG8wOh1zTpsk1bVq4zTAM+cvL5S4slLugQJ59BXIXFsp3+LD8FRXyV1SodsN74f1NLpccOTmhKQKn5IXWt5qcK0sK61gAgCSt+fYiefwBJdm5ldUXlzyMXHYqFhVUjX2crfNJClVPdWctyKSGEGpn2TldPXuMvP6gDp6qlSS5I0IrAEBs8VsdAAAAAGLMZDLJNmaMbGPGKPWyy8LtgZpaefY3hFYFhXIXFshTuF9Gfb3cn34q96efRvVjGzeuYV2rXDnzQsGVdfTobt2EA4C+yGI2EU71YVEVVJZYTPFnjvo+zdX16f0kqcrta+g3NLZz9d7wtmqPv9VjIn33isn6+dr9kiR/ICirxXyeIwAAEgEVAAAAAPQYS0qykubMUdKcOeE2IxCQt/RweF0rd2EovPJXVMh3+LB8hw+reu3a8P7mtDQ5J0+WY0rDula5eXLkTJLZ4UjEKQEAekDkHyb0xb9RsEStQdX98KZ5Fdb51ok6n8vyRujjkjN6t+CE/u/npsrbsK6VJN10YdZ5j589bnD48UtbyvSVBeO6NR4AGCgIqAAAAAAggUwWixzZE+TInqC0K68Mt/vPnm2xrpWnpETBqirVbdmiui1bmjpp7KNxXavcPDmn5MmaPjDXagEAdE28si+zOcZrUDXrw9bNiqXGiqzkhnWsPBEBVfbwlPMe74wIyDYdPE1ABQAdREAFAAAAAL2QdcgQWS+6SMkXXRRuM7xeeUpKota18hQUKHDunDxFB+QpOqCqv/0tvL8lPT1iisBQxZV9wgSZrPyvIAD0JZFxTF+c5jVyir9YrkHVyNbNaQNzM1IlSR5fMOrr8NSOVSdHVnA5rd2r5gKAgYT/KwEAAACAPsJkt4fXotLVoTbDMOSvqIhe12pfgbylpQqcOqXaU6dU+8EHUX04cnJCoVVuXnh9K0taWmJOCgDQKX0vnmo2xV8MArbm61g1X5Oqs5y20PFuf0CS5Gn46rB2rN/G45s/BgC0j4AKAAAAAPowk8kkW0aGbBkZSr300nB7sK5OnqKiqHWtPIWFCtbVyb1nj9x79qgyoh/b6NFy5DVMEdgQgtkyM2WKwVohAIDYiWcBVbyqs8wx7rd5IGXrYJDUFkdD1ZPb1xhQBRvaO9Zv5Lpazm6uhwUAAwkBFQAAAAD0Q+akJLlmzpRr5sxwmxEMynfkSGhdq8KCcHjlLz8mX3m5fOXlqnn33aY+UlLknDZNzunT5Jo2Tc7p02UbO7ZPTi8FAH1ZX3/bjaygisW5tFiDqpvTBoYrqBqn+AsHVB0LmzLSnOHH7xSc0A+umtKt8QDAQEFABQAAAAADhMlslj0rS/asLGn5snB7oLIyvJ6VuyD01VNUpGBNjeo2bVLdpk3hfc2DBsk1baqc06bLOX26nNOmyTZmNKEVAPQQUx+c5M8SUYgUi18Xzdegaj7lX2c1r6DyNgRU9g5WULnsFl2aO1zrC0/qwIka1Xr8SnZw2xUAzod3SgAAAAAY4CyDBik5P1/J+fnhNsPnk6e4WO49e1S/e7fcu/fIU1CgYGWlaj/8SLUfftR0/ODBobBq+jS5poeCK+vIkYRWABAHcZ3iL079Rk7xF4uArcUaVJburkEVCqg8/qAMw+j0GlSSlOa0hR/7AsFujQcABgoCKgAAAABACyabTc6GtagGX3utJMnweuUuKpJ79x65d+8OrWW1f78C586pduNG1W7cGD7ekp4u57SpcjVWWk2fJtuIEYk6HQDo0yJDnb4Y/UdVPMWhgsre7YCq6XiPPyhPw1R/jk6sJxUwjPDjxgosAED7CKgAAAAAAB1istvlmhZaj0orbpAkBT0eefbvl3v37qZKqwMHFDh1SrUb3lPthvfCx1tHjAitaTVjelOl1dChiTodAOib4plQxanvqDWoYtCf1Wxu9n1316BqCqI8vmDEGlQdD758EaGUt49UUFW7fZr947XyB5vCtVdXLdSssYMTNygAAwoBFQAAAACgy8wOh1wzZsg1Y4aGNLQF3W55CgpUH6602i1PcYn8J06o5sQJ1axbFz7eNnp0qMKqMbSaNk2WtLTEnAwA9FJ9fcbUqIAqBicT6zWobBazLGaTAkFDbn+gS1P8RYZSfaWCau3eiqhwSpL+8+8Fev72CxM0IgADDQEVAAAAACCmzE6nXLNmyTVrVrgtWFcnd0GB3Lt2hYMr78GD8pWXy1deruq33grva8/Kig6tpk6VOSkpAWcCAL1PLNZw6mkWU2wrqGzN16Ayd2+KP0kKNAQ1R87URVRQdXyKv2R7021WX8BoZ8/e42dv7W/RVn6uPgEjATBQEVABAAAAAOLOnJSkpDlzlDRnTrgtUF0t9569cu/Zrfpdu+XevVu+sjJ5S0vlLS1V1euvNxxslmNitpzTZ8g5fZpc06fLkZcns8ORoLMBgMSJZzVVvLo2R1VQdb+/5hVUzb/viiS7RXXegGo8fp2r80mS7J2ooPrussl6fdcxSdLmQ2eUm5Ha7THF29FWwqjSM3UJGAmAgYqACgAAAACQEJbUVCVfuEDJFy4It/nPnpV79x65dzdVWvkrKuQpOiBP0QFV/vnPoR2tVjkm58jVGFrNmCHHpEky2WwJOhsA6Bl9r34quoLKiEFxUfOKqVgEVLPGDtaHxad18FStfr2hWJK0t7yyw8dPHJ4Sfvxvr+7WTRdmdXtM8RQMtv4fwjCkk9UeDU/lj0AAxB8BFQAAAACg17AOGaKURZcoZdEl4TZfxQm59zSGVrvl3rVbgbNn5dm7T569+6SXQvuZ7HY5puQ1hFbT5ZoxXfYJE2SydHyKJgBA7EVWUAVjkFA1z6OsMQiofA1rSJ2s9oTbdpZ1PKDqa+p9gTa3lZysIaAC0CMIqAAAAAAAvZpt5AjZRo5Q6mVLJUmGYchfXt5QYdUQWu3eo2B1tdw7P5V756fhY81JSXJOnRpa06ohtLKNGydTPOfIAoA4iuf7V7z6jqxwikUFlclkktVskr+hCigWFVTDkh0NfTe1fW3hhG7321u9tfd4m9sCbVRXAUCsEVABAAAAAPoUk8kk25gxso0Zo7TlyyRJRjAo3+HDodBq1y7V79kt9959CtbVqW7LFtVt2RI+3pyWJtf0aXJOawqtrKNGEVoB6LUiQ52++E4VOcVfLCqopFAo1RhQxaKCatywJEmS1x+U3WqW1x/UTRf17mn6uuP/vLizzW1nG9bgAoC2/OIXv9Dq1atVWloqq9WqOXPm6KGHHtKCBQvOf3AEAioAAAAAQJ9nMptlHz9e9vHjNejzn5MkGYGAvCUlqt+1W+7du1W/Z7c8+woUrKpS7YcfqfbDj8LHW4YOlXPGdLmmTQ99nT5d1uHDE3U6ANCmvpilR445RvmUrGaTGifjM8cgoLJbQutaef3B8BSCNksffLG7YP2/XKqgYejyn2+QYUirnt+mz13wuUQPC0AvlpWVpZ///OeaNGmSPB6PfvnLX2r58uUqLi7WsGHDOtwPARUAAAAAoF8yWSxy5OTIkZMjXfNlSZLh9cpz4EBTaLV7tzxFRQqcOaPaDe+pdsN74eOtI0fKOWO6nHlT5MzLlSMvT7YxY6i0ApBQpjjWUMWrZ3PE+6ah2CRUVotZUmgdpVhUUNmtDQFVIBie4s5qNne7375gfHqypNiFh/2dLxDUD17ZpYsnDtM1czITPZyYeuiNffrfDSVRbdfMGaOf3zArMQNCr3XNNddEff/II4/ot7/9rXbv3q0lS5Z0uB8CKgAAAADAgGGy20NrUk2dKq24QZIUdLvlKSyMCK12yVtcIn9FhWoqKlTz9jvh480pKXLk5cqZmxf6mpcnR06OzE5nok4JAHo9c5wqqBpZYhAkNQZUHn9QvkDs1rZC//PnbUf1p61l+tPWsn4XUDUPpyTplW1HCah6iWeffVbvv/++tm7dql27dsnr9erJJ5/Urbfe2uYxn3zyie6//359+OGH8vl8mjFjhr773e/qhhtuiNm4vF6vfvOb32jIkCGaMWNGp44loAIAAAAADGhmp1OumTPlmjkz3BasrZV7375QhdW+ArkLC+UpLlawpkb1W7aqfsvWiA5C0ws683LlyM0LV1tZR4yg2gpATERVHfXBt5XoCqrYiAyPYlJB1TDFn9sX6HK/V88arVd3lHd7LD1h1tjB2nHknL59eU64bcnk4dqw/6QkyTAMfoe14UydN9FD6HFcD73Dvffeq9LSUqWnp2vUqFEqLS1td/9169Zp+fLlcjqdWrlypVJTU/Xyyy9rxYoVOnLkiO6+++5ujef999/XlVdeqfr6emVkZGjt2rUaOnRop/ogoAIAAAAAoBlzcrKS5s1T0rx54TbD65Xn4EF5CgrkLiiUpzD0NXDmjLwlJfKWlEhr3mjqY9AgOSZNkmPixNDXSRNlnzRJ1uHDuckDoFMiq47iWdQTr7emeK1B1Sgma1A1VFDVeSMCqk6uQfXNpZPCAdWaXcd01YxR3R5XvPgCQUnSrHGDw233f2GqLvvZBklS8ckaTRqRmoih9XoD8Te4P2gMmDXZerPVq1crJydHWVlZ+ulPf6of/OAHbe7r9/t1++23y2w267333tOsWbMkSffdd5/y8/P1wx/+UNddd52ysrLCx3z/+9/Xww8/3O4YjIg38Xnz5mnHjh06ffq0fvvb3+qGG27Qpk2blJ6e3uFzIqACAAAAAKADTHa7nLm5cubmatCXQm2GYch/8qQ8hYVyFxTIU1Aod2GBvAcPKVhZqfqtW1W/dWtUP+a0tKbgKmeS7BMnyjEpR9YRBFcAWheZ6fTF94noMccmobJYYlxB1VpA1cmpAyPH8aO/7um1AdW5Oq/2lFdJkpJslnD7+GHJ4cc/eGWX/njnxT0+tr7gk0Nnwo/P1Xk1OMmewNH0jEDQUMSlIn8g2LAOHHrSZz7zmQ7v++6776q4uFi33XZbOJySpEGDBumHP/yhbr31Vj399NO67777wtvuvvvudqcLbM7lcmnSpEmaNGmSFixYoJycHD355JP63ve+1+E+CKgAAAAAAOgik8kk24gRso0YoZRFi8LtQY9H3oMH5Sk6IE/xAXkOHJD3QLG8hw8rWFWl+m3bVL9tW1Rf5tRU2SdMkH18luzjx8sxfrzs48fLnpUlc3Jy86cGMIBE/sV6XCuoeqA2JHYVVE03x2OxVpQjHFD5u9yvLeKGva0X37z/f+8cCD9OsjfdHo6sRPvk0NkeHVNf8va+E+HHP1+7Xz/+0vQEjqZn+AJBORsSqksefldlZ+v1wBem6taFExI8sr6vurpaVVVV4e8dDoccDke3+12/fr0kadmyZS22LV++XJK0YcOGqPbhw4dr+PDhXX5OwzDk8Xg6dQwBFQAAAAAAMWZ2OOTMy5MzLy+qPejxyHvoUDi48h44IE9jcFVdLfenn8r96act+rOOGBEKqxoCK/uEhseZmTLZ+/9fbgMDXTBqCaq+V0EVKVZrUB09Wx9+HIsKKkerFVT9M6A6VdN0A9llt7SzJ87nZHXnbsb3VcFg0+Oyhp+9B/66l4AqBqZOnRr1/f33368HHnig2/0WFRVJknJyclpsy8jIUEpKSnifrrjnnnv0xS9+UZmZmTpz5oz+53/+R2VlZbr22ms71Q8BFQAAAAAAPcTscISnCYwU9HrlPXhI3tJD8h4qlffQofC/wJkz8p84If+JE6rbvLlZh2bZMjJkGztWtswxso8dK1vmWNnHZso2dqwsQ4b0yenAADTXFOv09R9pI0YlVN5A0x3zWK5BVd8QUJlNne83cs2qWFR1xYsrYq62JAKqbonFtddbtPezGWjY5vYF2twHXbN3716NGTMm/H0sqqckqbKyUlJoSr/WpKWlhffpivLycq1cuVInTpzQ0KFDNX/+fL3//vuaMmVKp/ohoAIAAAAAIMHMdrucuZPlzJ3cYlugslLe0ujQynMoFGQZdXXylZfLV14ubWql36SkUHg1NlP2zIavDSGWbcxomWN0EwRAfEXeN+7zAVWM+hk9yKnySrekGK1BZQkFNbUef0Ofna+Aiqya6s3/mSKDtOYB1R1LsvW/G0p0YfbQnh5WnxSrwLU3CLZzKv6GEqo/bD7cbh9v7jmuFzYflmFIK+eP1ZW9dB223iQ1NVVpaWmJHkanPfPMMzHph4AKAAAAAIBezDJokFwXXCDXBRdEtRuGIf/Jk/KVHZWv7Ii8R47Id6RM3rLQV39FhYJ1dfIUFspTWNhq39YRI2QbM0a2zMxQBVZmZtP3GRkyWbltAPQGkfeNzXFMqIyYxUftPEeMnmL6mEHhgCoW1Ur2ZlP8daVPW0Tw05tjixRH03t78yn+5mUN1f+qRB+XnOnpYfVJ/kBv/i/dOYF2EqrGKf4q6/3N2o1wFdmBE9W645mt4W0b9p/UG99epCmj+l740h80Vk61VSVVVVWlIUOG9OSQWsUnTQAAAAAA+iCTySTbiBGyjRghzZndYnvQ45HvaHnL8KrsqHyHDytYVxeeOrB++/aWT2CxhKYPDIdWDQFWw/fW4cNl6kKFAYDOi6qgStwwYiJWt/MdEdPUdaXaqbnGgMrfcJM+ssqoo3rzulORkhsCqjSnVQ5rdEAVOf3f1tKzmpuV+BvYvYnHHz3Fnb+9sqM+xu1ve/q+Go9PklO7jkaHHd5AUE5z6JrZX1HT4riC41UEVAnSuPZUUVGR5s6dG7Xt+PHjqqmpUX5+fiKGFoWACgAAAACAfsjscMiRPUGO7JYLmBuGocC5c6Gw6miZfGVl8paVNQRaZfIdPSrD65Xv6FH5jh5ttX+T3S7b6NFtBlisfwXETjAioYpnBVVPiNWUaPaIMCgWuZDDGt1JV6YNjMVUgz3B6w+Vw1wzJ7PFtsiKqkOnagmomnn6w0NR30cGen3di5uPtLnt0XcP6FcrZ+vtfRVR7ZEBXWtXv6nPR+p915IlS/TQQw/prbfe0sqVK6O2vfnmm+F9Eo2ACgAAAACAAcZkMsk6ZIisQ4bINWN6i+1GMCj/yVOh8Oro0aYAqywUWPmOHZPh9YbXxGqNOSmpabrA1gKs1NQ4nyXQf/TUGlR9aTmdyOn0LDGooHI2Cxq60mdkKJ8zIqXbY4qX/153QFLr0xhGXl9pLltPDanPKDtbH/X91NH9pzqo7Gxdm9sOnapttT3Yl940BpjLL79c2dnZev755/Wtb31Ls2bNkhSa8u/BBx+U3W7XzTffnNhBioAKAAAAAAA0YzKbZRs5QraRI6Q5c1psN/x++Y5XhKutfEcjAqyyMvlPnAitf1VUJE9RUavPYR40SPZ2Aiyz0xnv0wT6jMi1ofp6ZWKs7mebI8KVWFQuJTVbi8nWhSn+JOmrF47Tsx8f1uEzbd/s7y0Kjle1aJsUEazZrX1jysKe1PyqaG/dpoFg66GzWpo3ItHDGDBWr16tjRs3SpJ27doVblu/fr0k6ZJLLtHXv/51SZLVatXq1au1fPlyLV68WCtXrlRqaqpefvlllZaW6pFHHtH48eMTcRpRCKgAAAAAAECnmKxW2TPHyJ45ptXtQY9HvvLypikEjx6NCrACZ88qWFkpd2Wl3Hv3ttqHJT29KcDKzJRtzOimAGvUKJls/GU/Bo70FEeihxAzRoxWoYrMpFqrBOosl715BVXX+swckiRJ2lNepSNn6jR2aFK3xxZLgagp2VqeY5rTpgnpyTp4qlb+QLAnh9YnNA+IB8xr1EYw/qO/7iGg6qalS5fKZrNp1apVWrVqVbv7bty4UU8//XRU2wcffKAPPvgg/H1jQNXY98aNG3X//ffrxRdflM/n04wZM/Twww9rxYoVsT2RLiKgAgAAAAAAMWV2OOSYMEGOCS3Xv5KkYG2tvEePhgMr39GyqO+DNTUKnDql+lOnVL9zZ8sOLBbZMzNlz86WY2K27NkTG75mM3Ug+qWRaU49ces8Jdv7/q28mFVQmWJbQdV8LaGu9hl5XOHx6nBAFQwaev/AKVVUupt2Nkkt8rrGw5u1m0ytvHad2Dc3I1Uzxw6WLyJQMbdxjkOT7Tp4qla+wMCuDmpNVb0v6vvyyP+efVxb14MktVVQGHmFePwtwzq3L9DNUfV/69atU2Zmy/XgWvPUU0/pqaee6lT/+fn5euONN7owsp7R93+rAQAAAACAPsWcnCzn5MlyTp7cYpthGApWVsrbuN5VY4BVVibf0XL5yspkeDzylpbKW1qqmnXroo63Dh8u+8SJcmRnhwMsx6RJsqSn9/mp0TCwXZY3MtFDiIlYrVkTGVC1d2O9o2yW6OnsrJauTW8XWXllibir/8r2o/qXP7YSuPegt7+7WCPSmqZPbetla5ze0DdQqoM64ZXtR6O+/9PWMj1y/cwEjSa2cke2/QceF00c1mp76emmqSy/8+KOFtu//8ourcwf1+2xof8ioAIAAAAAAL2GyWSSZfBguQYPlmv6tBbbDcOQ/8RJeQ+WyFNcLG9xiTwlJfKWlMh/4oT8J0/Kf/Kk6j7+OOo4y9ChcuROlnNyrhy5uXLkTpZj0iSZHf1n6jSgu3qiXiZWFVSReXM81krqagVVVEAVMcjXdhxtbfceVVRRoyFJ9vD3bVXkNYZ1BFQDyyBX09S5935uiiTpP17fJ6lp6kog1gioAAAAAABAn2EymWQbOUK2kSOUfOGFUdsC1dXylpTIU1zSEGCVyFtcLO/hwwqcOaO6jz5W3UcRwZXZLPuECXJMzpEzN1eOybly5uXKOmoU1VZAnMQqBIusoLJ3sdqpucwhLpWdrZfU9TWoghFrPEX20RveUwwpatq+byzObnU/OwHVgBRoSI8XTBiqry8KXRufHDqjN/dURK1dBsQSARUAAAAAAOgXLKmpcs2cKdfM6OmWgvX18hwolmd/odyFhfIU7penoECByspQgFVcrOo3/t7Uz5Ahck6b1vBvqlzTpxNaYUAwYlXe1O6TxKYbcxwqqCIDJXMXf94j7+NH9tEb3j2ChhEOnUwmaebYwa3u11hBdawfra+E82sMoaKqABseG4YhP4El4oCACgAAAAAA9Gtml0uuGdPlmjE93NY4VaBnf6E8+/eHgquCQnlKShQ4e1a1GzeqduPG8P6EVkBsGDFKqCLDn+brR3VV5I35XUcru9RH5Bpb0RVUXR9XrBiG5G0IGVIdbd8WfntfhSTpl28X6TufablWIPqnxks3KlhteBwIGuFrp7nTNR4NS3Fo4vBkFZ+sjdo2dqgrPoNFv0FABQAAAAAABpzIqQJTFi0Ktwc9nlBgtXu36vfskXvPXnmKiloPrYYODVdsuWbNlHP6DFlSkhNxOkCfEbs1qCIDqtikP5YYpEiRAVXk416QTzVM8RcKGdqrOvMznduA1FhBZW5lHbWAIfn8rV8XR8/Va1iKQyNSnSo+WauHr50hi9msf/njTmWkOeM/cPRpBFQAAAAAAAANzA6HXDNmyDVjhoY0tLUZWp05o5p161Szbl3DwWY5Jk2Sa9ascGhlnzBBJnNsqjuAeOuJWCJ2a1A1PY7VGlRdXXcqUmSRSbCXBT2hadpCY4pV1Rn6j8Y1qCJ/DBofG4ahl7YcafW4xkDzo5LTkqQztb5wAPrJobNxGi36CwIqAAAAAACAdrQZWhUUqH7HDtXv3Kn6HTvlKy+XZ/9+efbv17mXXgodm5oq1wUXhAKr2bPlmj1LlpSUxJ0MkGCxWucqspdYhS1dXXcqUmTVVCCygqo3zPEn6YmNByWxvhRaavzZjKwkbKymCgQN/debha0eF2gWxO48ck5JDktUv73l+u+Nli5dKpvNplWrVmnVqlWJHk6PI6ACAAAAAADoJLPDEZ7er5HvxAnV79wpd0NgVb97t4LV1ar94APVfvBBw4FmOXJzlTR3rpLmzZVrzhzZRoxI0FkAPS9WNUX+iFIlWzvT1XWGNQZTBUZWTUVOldcbbs8bhvTK9qOJHkaftfPIuVbb+0sA8+i7ByRJ7xScCLdtKjkjSXrojYI2KwybZ87+YFD+QNPPZNCQYjQLZ7+0bt06ZWZmJnoYCUNABQAAAAAAEAO2ESNku+IKpV1xhSTJ8PnkKSpqqLDaobpt2+U7ckSeffvk2bdPZ599NnTcuHGhwGruHLnmzpV9/Ph+cbMTfVAPzEgXqzWoajyB8OMkm6WdPTsuFhVUkev3RIZVveFHOhirF3+AagxwmusvAUzZ2foWbUfPNbW1Vf1oGEbUtT52aJKS7E0/k4GgEZPpM9E/EVABAAAAAADEgclmk3PqVDmnTtWQG2+UJPkqTqh+21bVbdmquq1b5SkslO/wYVUePqzKP/9ZkmQZNkxJc+Y0VFjNlXNKnkxWbuGgf4hVUBN5Q9wco5vfsbiJfs2cMeGp0JpPfZZoHR3Oz2+Yqe++tLNXhGq9SVuvx0APYAxFVwveuWSibBazHltXLKn3/Rygd+HTDQAAAAAAQA+xjRwh25VXKu3KKyVJgepq1W/frrqt21S3dYvcn+5S4PRpVa9dq+q1ayVJ5pSUUIVVfr6SFiwIBVaW2FSMAH2VOTaz+kWJRcgwapBLs8cN1vbD55rdmE98gBHsYFCQNSxZkjRuaFI8h9NvDJTKNKfNojpvoEW7YUS/BikOa9TPUmCAvD7oGgIqAAAAAACABLGkpipl8WKlLF4sSQp6vXLv3t1QYbVF9du2K1hdrZoNG1SzYYOkhsBq3rxQYJWfT2CFmOlbt5FjH/hYYlQyZLeE0rPIG/O9oRqpo0FBY7jgD/StKyJRBkqF0DVzxujZjw+3aD9R7Y6qoLKYTdEBFdcR2kFABQAAAAAA0EuY7fbQ9H5z5ki6XUYgIHdBgeo2f6K6zZtVt2VLKLBav14169eHjklNjQis5suZR2CF/i8egU/kTfV7Pzel2/1EBhe9IJ+KChHa0xjUDZTKoI5q67/hQKkQsjUEr6uWTlSa06aH3iiQJH37Dzu08/5l4f0sZlNU2DtQXh90DQEVAAAAAABAL2WyWOSaNk2uadM07LZbmwKrTZujA6t161Szbp2k6MAqeUG+HLm5BFZAB0SuZTVr7OAu99NaQNUbdHSKv8bpE3vb+BOtrVC0o69rX3foVK0kyWI2KzcjNWpb5LViMZlkNptkMoWm//MHgz06TvQtBFQAAAAAAAB9RFRg9bXbQoHVvoJQWLVpk+q2bm0ZWKWlNQRW85W8YEEosIrHAj7o84wBXulgjQioulOh1VpA1RuqkXyBjgUF1ob3BwKqjhkoL9O6wpOSpN1HKzU3a0jUtsZrxWRqCnqtZpN8AUPkU2gPARUAAAAAAEAfZbJY5Jo+Ta7pDYGV398UWDVWWFVVqebdd1Xz7ruSJPOgQUqaP0/J+QuUtGCBHDmTCKzQ58Qj7zFHpVJdT6gapzfz+EN35k9UufX2vhPdGVpM/Mfr+5Q7MlWFFdX61cpZbe5naayg6gWhWm/y5p6K8ON/+/xU/fvf9kqSPjhwSl+YOTpRw4qZcUOTdPhMnRblpIfbbrowS898XBq136dl51ocu+toqC3ykvE1rD21/fBZXTljVMzHi/6BTx8AAAAAAAD9hMlqlWvGdA37x69p7P/+WpM3fazxf/yjRnzvX5S8eJHMSUkKVlaq5u13VPHggzr4pS+paOElKvvWt3XmuefkKSoa8FU0GLgsEXdKu1NBVV7pliTd++puSdLafRXt7d6jGkOn4SmONvexUEF1XpfljQg//ucXtidwJLEzbXSaJGnZ1JHhtosnDmuxn8VsahHffv/lXW32+69/+jQm4+uvli5dqqlTp+qxxx5L9FASggoqAAAAAACAfqoxsAqFVv8YqrDas0e1mxqmBNy2TYGzZ1X91luqfustSZJl2LDQdID5+UpasED2CRNk6s7d+n7mh1fl6cE1Bfq/V01J9FBibqDHEdaISsLuXPH7jlVFfd+bMl9/wzR/FnPbZ9hYAUZA1bZ2Xr4+q3Eaysj3+9be+61U3MbUunXrlJmZmehhJAwBFQAAAAAAwABhslrlmjlTrpkzpW/cLsPnU/2u3arbvCk0JeC27QqcPq3qN/6u6jf+LkmyDE9X8vxQWJW8IF+2rKwBHVh9Y/FEXTsnU8PaqUBB32Q2t39jvj9onHbNamknoLIQUJ2PuR9eH43/uSPPrbUg02oxtagwbPfl6H8vFWKIgAoAAAAAAGCAMtlsSpozW0lzZkt33qmg1yv3p5+qdvNm1W3arPrt2xU4eUpVa9aoas0aSZJ15Egl5ecreUEotLJlZvbbm/ltIZzqDWIfnkRmNrG8onvTj4cvXEHVdhUMFVTnZ+6HJVRGuIKqqa2107S0ckG3F9j1xzAPsUNABQAAAAAAAEmS2W5X0rx5Spo3T/rmNxX0eFS/c6fqGqYErN+5U/6KClX99a+q+utfJUnWUaPC0wEm5efLnjkmwWeBrupNU9ElQnQFVez6bbliT+KcqPZIaj1kaNRYNRMY6BdEO/phPhX++TdHBVQtTzS0BlV0e/sBVUyGh36KgAoAAAAAAACtMjscSs7PV3J+vvTPdynodqt+xw7Vbtqkus2fqP7TT+U/dkyVr72mytdekyTZxoxpCKvmK3nBAtlGjUrwWQAdY40MqHpRqBQP7a5B1bDNMKRg0OiX1ULd1V7A11e1tgZVa//t//WzeXLYoivwlk/L0BMfHGy1389fMDqGo0R/Q0AFAAAAAACADjE7nUq+8EIlX3ihJClYV6e67dtDFVabN6t+9275jh5V5SuvqPKVVyRJtnHjwmFVUv4C2UaOSOQpAG2yxKuCqhdmGe2uQRXxOviDhuwEVC30x2lNG2d0jDyz1v7Tzx8/RGlOW1TbiLTQtKfXzG6qoL1+bqb+uLVMGYOcsR4q+hECKgAAAAAAAHSJOSlJKQsXKmXhQklSsLZWddu2qW7zZtVu2iz37t3yHT6sysOHVfmnlyVJ9qysUIXVglBllnX48ESeAvqoeMw+F6+1cnpjlNFeBVVkJVmQaf5a1R8zu8b/1pE/B639TFgt5haVVY3rlUVeV40haJC1zNAOAioAAAAAAADEhDk5WSmLFill0SJJUqCmRvVbt6q2YQ0r97598paWyltaqnMvvSRJsmdnh8OqpPx8WYcNS+QpDGiGBvaNZGt/TB3a0N65Nq+gQkvtBXx9nTli9r7WMtvWrh3DaBlQNYZbXEJoDwEVAAAAAAAA4sKSkqKUJUuUsmSJJClQVaW6LVtVt2mTajdvlqegQN6SEnlLSnTuhT9Ikhw5k5SUv0BJ+flKyp8v65AhiTwFxFhvLsgxD6Ap/jqyBpXUVBmDaP1zir+WFVStrbXVWkD11IeHJEklJ2vDbaWn6yRJv9tYom9/JieWQ0U/QkAFAAAAAACAHmFJS1PqZUuVetlSSVLg3DnVbdmi2s2bVbdpszyFhfIUHZCn6IDOPvecpFBg5Zo7V0lz5ylp3lzZRo1K5CmgH4u8GW+K4cR8sewrVqyRZTLNRL4Obl9Ag1y2NvcdqPpjAVUw2LKt+VR+UmiKv+ZO1XglSZsPnQm3bTxwSpJU5fbHaITojwioAAAAAAAAkBCWwYOV+pnPKPUzn5Ek+c+eVd0nn6hu02bVbd4UDqs8RQd07g8vSpJsY8Yoad7cUGg1b77sE8b3y2qGROiJ6qZYPUU8xmqJUwVVb9ROPhUVSvzy7SI9dM2MHhhR71Z8sibq+/44xd9HJadbtLV3msOS7Tpd643jiDAQEFABAAAAAACgV7AOGaK0ZcuUtmyZJMl/5ozqtm5V/Zatqtu6Ve69e+U7elSVR4+q8rW/SJIsw4Ypac6chtBqnpx5uTJZueWFzotbQNULs4z2Kqgivb2vQg+JgOqeP30a9b25HyeYO46c05dmjZEkna5pO4DiDwNiY+nSpbLZbFq1apVWrVqV6OH0OH5bAwAAAAAAoFeyDh2qtCuuUNoVV0iSAjW1qt+xQ3Vbt6h+y1bV79ypwOnTql67VtVr10qSzMnJcs2eraR5c5U0d66cF1wgs8ORyNNAHxEVUMV0ir/ep6MVQP2wUKhL6n2BRA+hxwQj1h1rr1CRfCo21q1bp8zMzEQPI2EIqAAAAAAAANAnWFKSlXLJQqVcslCSFPR65d69W3VbtoZCq23bFayuVu3GjarduFGSZLLZ5LzgAiXNnRuqspo9W5bU1ESeRq/VE1P89WaRVTHdufn+uQtG6fVPj8VgRPHjslkSPYQ+pfn1YI1I7iYOT+7h0cRX5NtAe5VirYWXkeuVjUxzqKLKE8ORoT8ioAIAAAAAAECfZLbbQ9P7zZkj6XYZgYA8+/c3BFah0Cpw8pTqt25V/datOv0bSWazHHm5Spo7T0lzZss1Z45sI0cm+lTQC1ijKqi67vufzYsKqCJv+M8eN1g3zh+noGFEhR4mmWTICIeEkdsMQx3et7nGfe95eVdUu93asSn+YllJ1p9YLWbddGGWnvm4VAuyhyV6OHHTXgVd5LWRl5GqguPVevDLTdNB/vhL03XHM1s1eWRKPIeIPo6ACgAAAAAAAP2CyWKRc8oUOadM0dCbvirDMOQ7fFh1W7aEQyvf4cPy7N0nz959OvvMM5Ik6+hRSpo9R645s5U0e7YcubkyWagwGWjMMVqDymELhT+N3RkRpWnXzx2rG+aP7XrnXRQZUA1Ntnf4OKb4C2ktqBuZFpo61OhnpYeRp9NeBVXkpsbA02VvCj6dDVV6NkvHwlAMTARUAAAAAAAA6JdMJpPsWVmyZ2Vp8LXXSpJ8FSdUv22r6rZsVf327XIXFMhffkxV5a+r6vXXJUnmpCS5Zs2Ua/YcuWbPlmvWTFlSqALoTYx2V8fpmuh1mbqezDTe1G9cyqcv5xemfrLQ0Mlqj378t736Sv44XTSx8xVPrb0Mja9NMNjd0fUuUT9b7VZQNQk2XOSR14spvC12Y0P/Q0AFAAAAAACAAcM2coRsV16ptCuvlCQFa2tV/+mnqtu2TfXbd6h+xw4Fa2pU++FHqv3wo9BBZrMckyfLNXuWkubMkWv2HNnGjO43N+8RYonRGlSRhxqGEXWDnksmMR746x69/ukx/XVnuQ799HOdPr61/2xNQWTfT2D++YXtrba3X0HVtG330aoW+zc+7m8VZogtAioAAAAAAAAMWObkZCVfdJGSL7pIkkLrWB04oPrt20Oh1bbt8pWVyVNQIE9Bgc698AdJknXECLlmzw6vY+XMy5PJZkvkqaCbLDFagyryxr1h9O0A43MXjEr0EGKi7Gx99zpoJahpvFz6Q4XQX3eWhx9HT/HX9jGtVpVFPG56ffrBC4S4IaACAAAAAAAAGpgsFjlzc+XMzdWQlSslSb4TJ0LVVdu2qW77drn37pX/xAlVv/mmqt98M3Sc0ynXjBlyzZkTqrSaNUuWwYMTeCadN9ArHaICqm6UOkVVUDX8a21bonTkv/ON+eP0wubDSrb3j9vH3X3d26ug6m8/N9HXa9uvXGvVVWZTy5+hfvbyIMb6xzsMAAAAAAAAECe2ESNkW75MacuXSZKCbrfcu3apbtt21W8P/QtUVqruk09U98kn4ePs2dlyzZwZ+jdrphyTJslk5XZcb2WOWQVV9Pd9McDob9Uv7VUCdUTra1CFvvaX16hRdyqoIvfvb9cQ4oPfiAAAAAAAAEAnmJ1OJc2fr6T58yVJRjAo78GDTetYbdsm76FD8paUyFtSoso//1mSZEpKkmv69HBg5Zo5U9b09ESeSp8Vj3vekTfXY7VWlGEYCvayOeA6Uh0Wrg6K92B6SHtrKXVE+2tQdavrXqjphNq7VlrbYqKCCp1EQAUAAAAAAAB0g8lslmPiRDkmTtSQ66+XJPnPnlX9zp2q37lT7p07Vb/zUwVra1W3ebPqNm8OH2vLzIyqsnLm5clktyfkPAb6feR3950IP25varPziTy2+RR/88YP6XK/PakxZ+iL1V+tiVXgKEmLJw+X1H8rhC7LGxl+PGlESpv7tRb6maigQicRUAEAAAAAAAAxZh0yRKmXXqrUSy+VJBmBgLwlJeHQqn7HTnkOHJCvrEy+sjJVvf66JMlkt8s5dWpTldWsWbJmZHRrTSR0zNk6b/hxt17uiGMNo6nCZvLIFE0akdqNjnuOuZ9Vv3T35yfy+Hs/N0VS05SQ/eU1ajQ3qylETU+JDsvv+Wxe0zetTvHXsoKq/1WYIZYIqAAAAAAAAIA4M1kscuTkyJGTo8HXXSdJClRXy71rVziwqt+5U4Fz51S/Y4fqd+yQng4dax0xoqnCavoMOadNkyUlOXEn00/F6j56ZBZiyAhXIU0dlRajZ+g5vbn6ZcuhM7ru1x+Fv//qheP071+a3moYFTl9YyBoyNLJRam2lp4NP7Y2HNv4PK/vOqbHOtVb7xb5yjR/LW2W9tdpa20NKmPA12aiPQRUAAAAAAAAQAJYUlOVfPHFSr74Ykmh6dR8hw83BFY7VL9jp9yFhfKfOKHqtWtVvXZt6ECTSfbsbLmmTwsFVtOnyTllisxOZ7fG04uziB4RuVZUdwpuIg9d/ov3dOh0naTur4PUkxrH+vSHhzRxeIrsVrNOVns0erBT5+p8WjF/bMKr+iLDKUl69uPD+sIFo7Uge1iLfSOnXXx5W5lumDe2y8/b+NoUn6gJt43//utd7q+3ae8/a+R7ROtT/LWsoDpypl4VVW6NTOve+1N/tXTpUtlsNq1atUqrVq1K9HB6HAEVAAAAAAAA0AuYTCbZs7Jkz8rSoC9+UZIUrK+Xe8+epiqr3bvlP3ZM3uJieYuLVfnaX0IHN1RouWZMbwqtJk+WyWZL4Bm1ND9i+rDeZuroQdpZVimp+1PCNWoMpySpyu2PSZ/d1ZF1pRpPv9Yb0N1/3Nli++jBrvBaTL3JmVpvq+1mc9Pj45Xubj1H42tTVe/rVj+9VUfXX2vtR6S1CipJOl3jJaBqw7p165SZmZnoYSQMARUAAAAAAADQS5ldLiXNm6ekefPCbf5Tp1S/e7fcu/eEpgjcvVuB06flKSiQp6BA+uOfJIXWs3Lk5ck1fbqc06fLNWO67NnZMlksPX4e6/7lUq3de1w3XTg+Jv3Fo9jrpguz9MLmw5Jan76so9oKt9y+QDd67VnnmwFvf0V1rwyo2hJZ7dPJ2f3a7CvRFWRx014FVcRPXmtBlinqdW563NkpFTFwEFABAAAAAAAAfYg1PV2pl16q1EsvlRSqiPEfPx4KrXbtlnv3btXv2aNgZaXcn34q96efho81JSXJOXmyHFPy5MybIufUKXLk5MjsdMZ1rZgJ6cn6xuKJces/FuzWpjKbWE3x11f1t/Cltannut5X6OtAz1zOV0FlaqOaCohEQAUAAAAAAAD0YSaTSbZRo2QbNUppV1whqWE9qyNHVL9rV1Ol1d69MurqGta32tHUgdkse/YEBaZ9TTI5JEn+s2dlHdJ7p+OLh8gb6h2d5ux8/XSkvTc631j72nplkafT3bXAGgOuvrSmWGd09LRaC/oiX5PInyEzCRXaQEAFAAAAAAAA9DMmk0n2ceNkHzdOgz73OUmSEQjIe+iQ3PsK5CnYJ/e+Arn37VPgzBl5DxRrfvI2vTHhIo2oO6uiiy6WdeRIOfPy5Jg6Rc7cXDkmTZI9K6vXrWsVK5G30LtXQdW7b8Z3JFs63znEs9ouHiLzEYu57f0605e5m/30Vu39l48MJlvbL6pqKuL1sfTTMA/dR0AFAAAAAAAADAAmi0WOiRPlmDhR+nxDaGUY8p84KU/BPv1wT6GmFe/TvOIPJUn+igrVVFSoZsOGpk5sNjnGZ8mRkyP7pElyTJokx6Qc2ceNlcnac7ca41HBE1390XW9/V58Ryp/+lvBS1uVPd3pq79Ng9iovfOK/LFrLaAzt7EGVX+tNkP3EVABAAAAAAAAA5TJZJJt5AjZRo5QypIluquhPVBTK8/+Qrn37ZN73z55iorkLTqgYF2dPEUH5Ck6EN2P3S77hAmhwCpnkuwTJ8qelSV7VpbMDkfPn1gXDJR76B0Jn/rdFH9Ra1B1s6+Gr/0txGvU0dNqLehra92p/lpthu4joAIAAAAAAAAQxZKSrKQ5c5Q0Z064zTAM+cvL5TlwIPSvqOFrcbGM+np5CgvlKSyM7qhhfSz7+PGyj89q+Br6Zxs9ukerrs4n6oZ7Pw0fQjpSQdW/XoDoKf5Yg6o9HV+DqmVb9GvS9Li7rzn6r97zGwAAAAAAAABAr2UymWQbM0a2MWOUsmRJuN0IBuUrL5enqKghuCqS91CpvAcPKlhdLV95uXzl5ar98MPoDm022TMzZRs3VvYxmeG+bWPGyJY5RpbBg3t0GjVTVD7V9eft7blFR8bXy0+h02I53Vx4Dare/h+6izp67bf2sxnZElVB1U9fK3QfARUAAAAAAACALjOZzbJnZsqemanUpUvD7YZhKHD2rLyHDsl78JC8paWhx4dCjw2PR96DB+U9eFC1rfRrTkqKDq1Gj5J1xEhZR45QoMYb+/Mwtf640/20cYO/t6xZ1LEp/trfqY/N8Bc1xVx3i3kaX5te8p8z5to7r7yM1PDjiycO084j59o8ljWo0BEEVAAAAAAAAABizmQyyTp0qKxDh0ZNFSiFqq78FRWhgKqsTL6j5fIdPRr6V1Ym/8mTDetdFclTVNSi7+o5K6Vx8yRJxZ/7vKwjhss2YqSsI0aE/qUPk2XIEFmGDJV16BBZhgzp1HSC3bmd3tvvxdut518Q6Hzn8NM3CnTd3Eylp/SN9cXe238q/PjfXtujf3ttT5f76u8VVO1ZMnl4+PHXFk7Q4+uLo7a3tTYZU/yhLQRUAAAAAAAAAHqUyWyWbdQo2UaNUnIr24MeT2hqwMbgqqxMvmPH5D9xQv4TJ2SyWsL7eouL5S0ubqWXaOZBg2QdEgqrLEMbgqvBQ2ROTZUlLVU11qbqEO+hg/INHypLaopMLlevqX6Khc9MGXnefep9gfPu883ntumlOy6KxZDirsbjj1lfTWtQxazLPiPy58BuaRl0Gm08HoivFTqGgAoAAAAAAABAr2J2OOSYMEGOCRNa3Z724g5p+1FJ0rgnn5CvokL+EydDAVZFhfxnzyhw5qwCZ84oUFkpGYaClZXyVlZKhw612udJ1yBp+b9Jkg7f+A8656sLbbBaZUlJkTk1VWaXS2aXS6Ykl8yuJJmTkkJtSS6ZXKG2oNMlaWiL/oM1NfIeOiSTK0lml1Nmp1Oy2Xo8/Pre8tzz7uPxBc+7z+aDZ2IxnD7H1M8rqDp8Wq3sF4wooYp8bCahQhsIqAAAAAAAAAD0WckXtV/FYwQCClRWKnDmjPxnzihw9pwCZxsenzunYHWNAtVVqqltqhqypKVKZ91SMCj5/QqcO6fAuXMdGk9AJunq/2rRXr99u4r/Z1V0o8Uis9MZCrecTpldTpmcroY2p8xOV1Oby9WszRlqczbfzxnVp8npjHpKl80idJ05vAZV/wxd2lpDrbnWMidHRGWjI2IqSZv5/NNKYmAioAIAAAAAAADQb5kslvBaWO2tmGStrJceeleSlLN2rdJcVhl1dQpUVytYXR36WlevYH2djPr6hsetf9/qOOw2mVNSFKyvlwINYVggoGBtrVRbq/NPqtcNVz8Sflh81ecaQixXU5CVnCTL4MHhf77Tg+I5mj6taQ2qxI4jXjqau7UW0I0flhR+nDkkSXcsyVaK3SqXnVAUrSOgAgAAAAAAAIBIptANeFNysszJyVJGRueO//7rLZqSF1yo3P/9lgzDkHw+Bd1uBevdMtz1CrrdoZDL7Vawvl5Gw7agu15GfWNbfXSb2x3R1nR8+KvH0+rQfAcPnnf4tdO/KE1afN79Tj7637IOHy7riOENX0fIOmyYTNb+e9u5scKo307x1439modWP7hySrfHg/6t/75TAAAAAAAAAEAXxDN7MJlMkt0ui90uS1pa3J7HCARCQZfbLf3Xx+H2rGd+3zIIq6kOT2MYOFcpmzGyQ89x6rHHWjZaLLKNHCnbmDGhf6NHNzxu+JqRIZPNFqvT7HGmfl9B1bET66f5XI9bunSpbDabVq1apVWrVp3/gH6GgAoAAAAAAADAgNfRtXe63H8P39A3WSxNFWARkubPP++xaX/bK208f6XV4BtukP/kyaZ/p05JgYB85eXylZdLn3zS8iCzWbbRo2WfMEH28eNlnzBejvHjZZ8wQdaRI2Xq5esV9f81qDq6X/88/562bt06ZWZmJnoYCUNABQAAAAAAAAARDCPRI0isjkYPo378o6jvjWBQ/pOn5Dt6NBRSHT0a/bi8XIbHI19ZmXxlZap9//3o53W5ZM/KkiM7W47Jk+XInSxnbq6so0b1mkCoqYKqd4wn1jq+BlV8x4GBgYAKAAAAAAAAQJ8Sj/yIG+5NuvpamMxm2UaOkG3kCGnO7BbbDcNQ4NQpeQ8dkufgQXkPlcp78GDoX1mZjPp6eQoK5CkokNasCR9nTk0NBVaTc+TMzZVjcm5XT63bGoMppviL80AwIBBQAQAAAAAAAECc9bf7+V0JKEwmk6zDh8s6fHiLqQYNn0/esjJ5Dx6Sp/iAPIX75dm/X56SEgWrq1W/davqt25tOuDqR1r0b/j9nR9UJzUGU+b+mlB1EFP8IRYIqAAAAAAAAAAgUhxKtIL9bNrAWE+DaLLZ5JgwQY4JE5R62dKm5/F65Tl4MBRWFRbKvX+/3Pv2tdpH+b/+qw4O8ss5Y7qSZs2Sa84c2WK8vo8pvAZVTLvtcwb6+SM2CKgAAAAAAAAADHhOqyX82GqJ/d33DftPxrzPgcBkt8uZmytnbq70hS80bfj+6y32NQIBuXfvlnv3bp174Q+SJOvw4dLCe2I+rv66BlVHDeyzR6wQUAEAAAAAAADoU4xYl+9IGpRk0799fqrMJinZMbBvm8bh5e0RGffdpzH1R1T/6S7Vb9um+r175T8Zn2BwgM/w1+G1qoD2DOx3WgAAAAAAAABo8I+XTEj0ENANlmHDlDZjutKuvFKSFHS75d61S/ZXKuQ1YhuoDPQKqoEe0CE2CKgAAAAAAAAAAJ22t7xKRsOCXaaGSd8MGa0+bm/b+fqI3BY0jDbDob3lVcoalqQUh1VBQ/IHgkrOmS6vcSJ2J91goFcQDfTzR2wQUAEAAAAAAAAAwjo6w99V/+/9uI6js/573QH997oDPfJcDqu5R54H6M8IqAAAAAAAAAD0KX1xiaSX7rgo0UPosDqvP9FD6LUOrbxRyYsu0WX5Fyd6KECfR0AFAAAAAAAAAHGWP2Fowp770Rtn659f2N7h/UcNcsVxNH3X/938tOrLd6l+xw65Lb+RvvBgeFvxDxfJkpaWwNF13fjvv57oIWCAIqACAAAAAAAAAOA8Rv3oAWUc36Pa9zfK+/EnUdv2X7xQSfPnKXXpZUq5bKnsmZkJGiXQdxBQAQAAAAAAAEA/ZjJ1bn+jL86h2AOsQ4ZoyKLrNeT66zW01i39+ztNG/1+1X30seo++lgVDz4oR06OUi67TKmXLZVzxgyZzKxZBTRHQAUAAAAAAACgTyFAiS+jT67y1bPMdlvU9xPf/Luq161TzbvrVLd1qzxFRfIUFen0//6vLOnpSl16qVIuv1zJF18ss92emEEDvQwBFQAAAAAAAAAA3WDPytKwW2/VsFtvVeDcOdW8/76q331Xte+9r8CpUzr3xz/p3B//JHNKilIuvVSpy65QyqJFMrtY7wsDFwEVAAAAAAAAACCMCrXusQwerEFf+IIGfeELMrxe1X7yiWreeVfVa9fKf/Kkqv72N1X97W8yOZ1KWbRIqcuWKWXppbKkpCR66ECPIqACAAAAAAAAAOC8mhbz6ui6Xia7XSkLFypl4UKNvPf/qn7HTlWvXavqt96S7+jR0OO1a2Wy2ZR08UVKW7ZMKZddJuuQIXE6B6D3IKACAAAAAAAA0KdQ4BNfvL7xYTKblTRntpLmzNaIf/2e3Hv3qvqtUFjlPXhQtRveU+2G9ySLRUn585W2fLlSly2TdejQRA8diAsCKgAAAAAAAADox0zqYLlPI+b4O69Ov6bNjzeZ5Jo2Ta5p0zT8O9+W98ABVa1dq+q31spTUKC6jz5W3Ucf6/i//4eSFyxQ2lVXKvUzn5Fl8ODYnADQCxBQAQAAAAAAAABwHpHT+nV0ir+O9WuSIydHw3NyNPyb35S3tFRVb72l6jf+Lvfevar98EPVfvihjv3ox0q++CKlXXmlUi+/XJbU1NgNAkgAAioAAAAAAAAAAHoJe1aW0m+/Xem33y7voUOq+vvfVbXmDXn27w9PA3jcZlPy4sWhsGrppTInJyd62OiCpUuXymazadWqVVq1alWih9PjCKgAAAAAAAAAAGFM8Hd+MSygapd9/Hil33mn0u+8U57iYlWteUNVb7whb0mJat55RzXvvCOTw6GUJUuUdtWVSlmyRGaXq4dGh+5at26dMjMzEz2MhDEnegAAAAAAAAAA0Bn/tGSiJOm6uQP3xm483TBvbKKH0OuZIub4+/6VeT3ynI6JEzX8n+9S9ut/04TXXtOwO++QLWucDI9H1W+9paPf+T/av/ASHf3u3ap++20FPZ4eGRfQVVRQAQAAAAAAAOhTpo5O094fL5fLZkn0UPqlsUOTEj2EXqmtqqkFE4b27DhMJjlzJ8uZO1nDv/1tuffuVfUbb6jqjb/Ld/SoqtasUdWaNTInJyv1M5cr9corlXLxxTLZ7TEdR3qKXadqvJIkm6WnasrQnxBQAQAAAAAAAOhzkuzc2kTi9JY4xmQyyTVtmlzTpmn43XfLvWtXaBrAv/9d/uPHVfnaX1T52l9kHjRIqVd8RoOuukpJ+fkyWbv/82M29ZZXAX0V7+IAAAAAAAAAAHRCb8xmTCaTXBdcINcFF2jEv35P9Tt2hMOqwKlTqvzTy6r808uyDB2qtM8uV9qVV8o1d26Xn4+1ytBdBFQAAAAAAAAAAJyHqTemUm0wmc1KmjNHSXPmaOQPvq+6T7aoas0aVb/1lgJnzujs8y/o7PMvyDpihHTxvyZ6uBigzIkeAAAAAAAAAAAgfvpQrtJn9KmwymJR8oULNOrHP1LO++9p7G9/q0Ff/rLMqanynzgRte+JRx5R/Z49MozO1Ud1cndAEhVUAAAAAAAAAAAMCCabTSmLLlHKoksU/NEDqt24UXqrPrz99Orf6fTq38melaXUq66UNKnNvgil0F1UUAEAAAAAAAAAMMCY7XalXnZZVFvqsmUyORzylpbq9OO/jtrmPXSoB0eHgYAKKgAAAAAAAAAAzqPvTOrXdZn/71cK1NSqZt06Va1ZE7Wt+LNXyjl1qtKuulKpn71SEiVU6B4qqAAAAAAAAAAAgCTJkpKsQV/4vMY+/j/NNljk3rtXJx75mYo/8xkFKisTM0D0GwRUAAAAAAAAAACgXTnvv6eMBx5Q0oIFkskkw+cPbzMCAZ194QX5T59O4AjR1xBQAQAAAAAAAEA/NhCmpkP8WYcO1ZCVK5T19FOatGG9TMnJTRsNQ8d/9GMVLVqsw1/7R537058UOHcuYWNF30BABQAAAAAAAADAeZhI+sJsI0bI7HQ2NZjNck6fLgWDqv3wQx2799+0f9FiHbnjTrn37k3cQNGrEVABAAAAAAAAQD82I3NQoofQr2UNSz7/Tv3Q95bnhh/fuXSSJvzpj5r45t81/DvfkSM3V/L5VLNhg0w2WwJHid7MmugBAAAAAAAAAEB/8r3lufqvNwsTPYywzCFJevu7izXIZU/0UGLuVytnadPBM3p+0+Eef+6Pf3C56n0BDU3u26/rVy8cp2c/bv/1uyxvRIu2lfnjZLWYZTFLX5w5RpJkz8pS+p13KP3OO+QpLlbtBx/IkZMTl3Gj7yOgAgAAAAAAAIAYSrJbEj2EFiaNSE30EOLiS7PGKCPN2SMBVfMp/jIGOVvfsY8ZnnL+87CYW5/f8Lq5mW0e45g4UY6JE7s8LvR/TPEHAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAMWQYiR4BAPR+BFQAAAAAAAAAAJyHSa2vwzQQDNwzRzwRUAEAAAAAAAAA+iwK1oC+iYAKAAAAAAAAAIDzoYwIiCkCKgAAAAAAAAAAAPQoAioAAAAAAAAAAAD0KAIqAAAAAAAAAAAA9CgCKgAAAAAAAABAn2UYiR5B/2di/S3EAQEVAAAAAAAAAMQQeUn/REYDxBYBFQAAAAAAAAAAAHoUARUAAAAAAAAAAOdhYp47xNjSpUs1depUPfbYY4keSkJYEz0AAAAAAAAAAACAgWbdunXKzMxM9DAShgoqAAAAAAAAAAAA9CgCKgAAAAAAAABAn2XISPQQ+j2TmN4QsUdABQAAAAAAAAAAgB5FQAUAAAAAAAAAMWQYVPQAwPkQUAEAAAAAAAAAcB5McgfEFgEVAAAAAAAAAADnYSKhAmKKgAoAAAAAAAAA0HcxoyLQJxFQAQAAAAAAAACANlE9hnggoAIAAAAAAAAAAECPIqACAAAAAAAAAABAjyKgAgAAAAAAAAAAQI8ioAIAAAAAAAAA4DxMYiEmIJYIqAAAAAAAAAAAfZbRQ89jIp8CYoqACgAAAAAAAAAAtIlwDvFAQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5lTfQA+rtgMChJOnbsWIJHAgAAAAAAAKAnnD19SkFPXVRbWVlZgkbTNc3H31uVlZXpRMW5HhnviYrjKnO64/48Pa3ybNP12vw6bWyvrarsc9dwb9aYFzTmBwOVyTAMI9GD6M8++eQT5efnJ3oYAAAAAAAAAACgF9m8ebPmz5+f6GEkDAFVnPn9fm3fvl0jR46U2cyMio2qq6s1depU7d27V6mpqYkeDtApXL/oy7h+0Zdx/aKv4xpGX8b1i76M6xd9Gdcv+jKu37YFg0FVVFRo9uzZsloH7kR3BFRIiKqqKg0aNEiVlZVKS0tL9HCATuH6RV/G9Yu+jOsXfR3XMPoyrl/0ZVy/6Mu4ftGXcf3ifCjpAQAAAAAAAAAAQI8ioAIAAAAAAAAAAECPIqBCQjgcDt1///1yOByJHgrQaVy/6Mu4ftGXcf2ir+MaRl/G9Yu+jOsXfRnXL/oyrl+cD2tQAQAAAAAAAAAAoEdRQQUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVAhR71ySef6KqrrtLgwYOVnJysCy+8UC+99FKih4V+6ujRo/rlL3+pZcuWady4cbLb7crIyNC1116rTZs2tdj/gQcekMlkavPfoUOHWn2eN998U0uWLFFqaqrS0tK0dOlSvfPOO22Oa//+/brhhhuUnp4ul8ulmTNn6vHHH5dhGLE6dfQT48ePb/N6vPTSS1vs7/F49OMf/1g5OTlyOp0aPXq0vvGNb+jEiRNtPsdzzz2n/Px8JScna8iQIfr85z+vbdu2tbk/7+PoiKeeeqrd91OTyaTLL788vD/vv0iUZ599VnfccYfmzZsnh8Mhk8mkp556qs39q6qq9N3vfldZWVlyOBwaP368vve976mmpqbV/YPBoB599FHNmDFDLpdLw4cP14033qiSkpI2n4PrGh3V0evX5/Pp5Zdf1i233KIpU6YoJSVFqampWrBggR5//HEFAoEWxxw6dKjd9+UHHnig1TEdO3ZM//iP/6hRo0bJ6XQqNzdXP/nJT+Tz+VrdvyufXdA/dOb9t7d+Tujs7wT0H525fs/3mdhkMunIkSPh/Xn/Rbx19l6ZxGdgxJc10QPAwLFu3TotX75cTqdTK1euVGpqql5++WWtWLFCR44c0d13353oIaKfefTRR/Xwww9r4sSJWrZsmYYPH66ioiK9+uqrevXVV/X8889rxYoVLY675ZZbNH78+BbtgwcPbtH27LPP6qabbtLw4cN16623SpJefPFFXXHFFXrppZd03XXXRe2/d+9eXXzxxaqvr9cNN9yg0aNH6/XXX9c3v/lN7d27V48++mgsTh39yKBBg/Sd73ynRXvzazQYDOpLX/qS3nzzTV144YW69tprVVRUpNWrV+udd97Rxx9/rOHDh0cd85Of/ET33nuvsrKydOedd6q6ulp/+MMfdPHFF+udd97RwoULo/bnfRwdNWvWLN1///2tbvvTn/6kPXv2aPny5S228f6LnnbvvfeqtLRU6enpGl65d7QAABE4SURBVDVqlEpLS9vct7a2VkuWLNGOHTu0bNky3Xjjjdq+fbseeeQRbdiwQe+9956cTmfUMXfccYdWr16tadOm6Vvf+pbKy8v10ksv6a233tLHH3+snJycqP25rtEZHb1+i4uLdd111yklJUWXX365vvjFL6qyslJ//etf9c1vflNr1qzRX/7yF5lMphbHzpw5U1dffXWL9tb+UOb48eNasGCBysrK9OUvf1k5OTnasGGD7r33Xm3evFmvvvpq1HN05bML+o/OvP826k2fE7ryOwH9R2eu37Y+Ex84cEDPPfecpk6dqrFjx7bYzvsv4qWz98r4DIy4M4Ae4PP5jIkTJxoOh8PYvn17uP3cuXPG5MmTDbvdbhw6dChxA0S/9PLLLxvr169v0f7ee+8ZNpvNGDJkiOF2u8Pt999/vyHJWLduXYf6P3PmjDF48GAjPT3dOHLkSLj9yJEjRnp6upGenm5UVVVFHbN48WJDkrFmzZpwm8fjMRYtWmRIMj788MNOniX6s6ysLCMrK6tD+z7xxBOGJOPGG280gsFguP3xxx83JBnf+MY3ovbfv3+/YbVajcmTJxvnzp0Lt2/fvt1wOBzGlClTjEAgEG7nfRyx4PF4jGHDhhlWq9U4fvx4uJ33XyTK2rVrw+9dDz30kCHJePLJJ1vd97777jMkGffcc09U+z333GNIMh588MGo9nfffdeQZCxevNjweDzh9jVr1hiSjGXLlkXtz3WNzuro9VtWVmY89thjRk1NTVR7TU2NMW/ePEOS8dJLL0VtO3jwoCHJuOWWWzo8nptvvtmQZDz++OPhtmAwaKxcudKQZDz//PNR+3f2swv6l868//bGzwmd/Z2A/qUz129b7rrrLkOS8bOf/SyqnfdfxFtn75XxGRjxRkCFHvHmm28akozbbrutxbannnrKkGT86Ec/SsDIMFAtW7bMkGR88skn4bbO/o/P//7v/7Z57T7wwAOGJOPpp58OtxUWFhqSjKVLl7bYf/369W3+jGDg6kxAddFFFxmSWoREwWDQyM7ONpKTk426urpw+w9+8IMW12ijW2+91ZBkbNiwIdzG+zhi4cUXXzQkGVdffXVUO++/6A3au8EUDAaN0aNHGykpKa3e5E9JSTGys7Oj2m+88cYW76WNLr30UkOSUVpaGm7jukZ3dPUG6fPPP29IMlatWhXV3tkbpFVVVYbD4TCys7OjbnYahmEcOnSo1Wu1s59d0H/FOqCK9/tpV34noP/qyvtvfX29MWTIEMNutxsnTpyI2sb7LxKp+b0yPgOjJ7AGFXrE+vXrJUnLli1rsa1xip8NGzb05JAwwNlsNkmS1dpyptP33ntPDz/8sP7rv/5Lr776aptz6nb2um5v/0suuUTJycn8HKAFj8ejp556Sg8++KD++7//u9U5od1utzZt2qTc3FxlZWVFbTOZTLriiitUW1urLVu2hNtjef3yPo6OWr16tSTp61//eqvbef9Fb1VUVKTy8nItXLhQycnJUduSk5O1cOFClZSURK0hsX79+vC25mLxPst1jVho7zOxJJWXl+uxxx7Tgw8+qN/97ncqLi5udb+PPvpIHo9HV1xxRYupArOyspSbm6sPPvggvN5VVz67AL3lc0JXficAkV555RWdPXtWX/ziF9ucSo/3XyRC888FfAZGT2ANKvSIoqIiSWoxx6gkZWRkKCUlJbwPEG+HDx/W22+/rVGjRmnGjBkttjefI3rw4MH61a9+pZtvvjmqvb3rurEt8rpub3+LxaIJEyZo79698vv9bd4kwMBz/Phx3XbbbVFt8+fP1wsvvKCJEydKCq0tEQwGW722pOjrcdGiReHHKSkpysjIaHf/RryPo7tKS0v1zjvvKDMzU5/97Gdb3Yf3X/RW7V1Dje1vvvmmioqKNHbsWNXW1urYsWOaPn26LBZLq/tH9nu+5+C6Rrw88cQTklq/ySNJa9eu1dq1a8Pfm0wm/cM//IN+/etfR92o6sjPSGFhoUpLS5Wdnd2lzy5Ab/mc0NnfCUBzv/vd7yS1/UdbEu+/6Hmt3SvjMzB6AhVU6BGVlZWSpEGDBrW6PS0tLbwPEE8+n0833XSTPB6PHn744ahfmDNnztQTTzyhkpIS1dfX6+DBg3r00UdlMpl066236i9/+UtUX+1d12lpaVH7nG//xmOCwaCqq6u7d5LoN2677Ta98847qqioUG1trbZv366bbrpJn3zyiS6//PLwtdKRaytyv8bHnd3/fM/B+zja8+STTyoYDOrWW29t8T8rvP+it+vs+2xX35fbOobrGvHwm9/8Rm+88YYuu+wyXXXVVVHbkpKS9G//9m/aunWrzp07pzNnzujtt99Wfn6+nn322RaBQE/8jGDg6m2fE7h+0R0HDx7UunXrNG7cOF1xxRUttvP+i0Ro614Zn4HRE4gRAQwYjTdG33vvPd1+++266aaborZ/+ctfjvp+/PjxuuuuuzRlyhRdccUVuvfee/XFL36xJ4eMAa75X4nOmjVLv//97yVJzzzzjH7729/qu9/9biKGBnRKMBjUk08+KZPJpK997WsttvP+CwA9629/+5vuuusuZWVl6dlnn22xfcSIEfrxj38c1Xb55Zfroosu0pw5c/TKK69o27ZtmjNnTk8NGQMYnxPQnzzxxBMyDEO33XabzOaWdQO8/6Knne9eGRBvVFChRzSm2m39BUZVVVWbyTcQC8FgUF/72tf0/PPP66tf/ap+/etfd/jYyy+/XBMnTtSuXbtUVVUVbm/vum7cL/K67sjPgclkUmpqaofHhoHpjjvukCR98MEHkjp2bUXu1/i4s/uf7zl4H0db3n77bR0+fFiXXXaZJkyY0OHjeP9Fb9HZ99muvi+3dQzXNWJpzZo1uu666zRy5Ei9++67GjVqVIePTUpKCt+4avwcIvXMzwjQXKI+J3D9oquCwaCeeuopmc3mVv9oqz28/yIeznevjM/A6AkEVOgRrc0Z2uj48eOqqalpcz5ToLuCwaBuu+02Pf3007rxxhvDHwg7Iz09XZJUV1cXbmvvum5tTtz29g8EAjp48KAmTJjAHLk4r8brsba2VpKUnZ0ts9nc5hpQbV2PNTU1On78eIf3j9wWifdxnM/q1asltT/Pflt4/0Vv0N41FNneuF9ycrJGjRqlgwcPhhclb2//8z0H1zVi5fXXX9c111yj9PR0rVu3TtnZ2Z3uo/nnEKljPyN2u13jxo2T1LXPLkBrEvE5obO/E4BGf//731VWVqYrrrgi/H7YGbz/IpY6cq+Mz8DoCQRU6BFLliyRJL311lsttr355ptR+wCx1PgL9/e//71WrFihZ555ptWFGttTW1urPXv2KDk5OfyBUOr8dd3e/hs3blRtbS0/B+iQTZs2SQpNbyJJLpdL+fn54cVvIxmGobVr1yo5OVnz5s0Lt8fy+uV9HO05ffq0XnvtNQ0dOrTFFD3nw/sveoucnByNHj1aH3zwQdRNISl0nX7wwQeaMGGCxo4dG25fsmRJeFtzjdfp4sWLo/aXuK4RP6+//rquvfZaDR06VOvWrdOkSZO61E/zzyGSdOGFF8put2vt2rUyDCNq/9LSUhUWFmrhwoXhm0Vd+ewCNJeozwld+Z0ASNLvfvc7SV37oy2J91/ETkfvlfEZGD3CAHqAz+czsrOzDYfDYWzfvj3cfu7cOWPy5MmG3W43Dh48mLDxoX8KBALGLbfcYkgyrr/+esPn87W5b1VVlVFYWNiiva6uzrjxxhsNScZtt90Wte3MmTPGoEGDjPT0dOPIkSPh9iNHjhjp6elGenq6UVVVFXXM4sWLDUnGmjVrwm0ej8dYtGiRIcn44IMPunq66Gf27dtn1NbWttqekZFhSDI2bNgQbn/iiScMScaNN95oBIPBcPvjjz9uSDK+8Y1vRPVTWFhoWK1WY/Lkyca5c+fC7du3bzccDocxZcoUIxAIhNt5H0dX/eIXvzAkGd/61rda3c77L3qLhx56yJBkPPnkk61uv++++wxJxj333BPVfs899xiSjAcffDCq/d133zUkGYsXLzY8Hk+4fc2aNYYkY9myZVH7c12jO853/a5Zs8ZwOBxGRkaGUVBQcN7+tm3bFvV5otHLL79smM1mY8iQIVGfHwzDMG6++WZDkvH444+H24LBYPi9/Pnnn4/av7OfXdB/tXf99tbPCZ39nYD+63zvv41OnDhh2Gw2Y/jw4VGfC5rj/Rfx1pl7ZYbBZ2DEn8kwmsXrQJysW7dOy5cvl9Pp1MqVK5WamqqXX35ZpaWleuSRR3T33XcneojoZx544AH96Ec/UkpKir797W+3Wt579dVXa9asWTp06JCys7M1f/58TZkyRRkZGaqoqNDbb7+tsrIyzZgxQ+vWrdOwYcOijn/22Wd10003afjw4VqxYoUk6cUXX9SpU6f04osv6vrrr4/af8+ePVq4cKHq6+u1YsUKjRo1Sq+//rr27Nmju+66S48++mj8XhD0KQ888IB+/vOfa/HixcrKylJycrL279+vNWvWyOfz6Qc/+IEefPDB8P7BYFBXXXWV3nzzTV144YVasmSJDhw4oFdeeUXjx4/Xpk2bNHz48Kjn+MlPfqJ7771XWVlZuvbaa1VdXa0//OEP8nq9euedd7Rw4cKo/XkfR1fMmDFDu3fv1qeffqoZM2a02M77LxJp9erV2rhxoyRp165d2rZtmxYuXBiuLLnkkkvCf+VcW1urhQsXaufOnVq2bJnmzJmjbdu26a233tL8+fO1YcMGuVyuqP5vv/12rV69WtOmTdPnPvc5HTt2TC+++KJSUlL00UcfafLkyVH7c12jMzp6/RYUFGjWrFnyeDxauXKlcnNzW/Q1fvx43XrrreHvL730UhUXF+uiiy5SZmamAoGAtm3bpo0bN8rhcOill17SF7/4xag+jh07pgULFqisrEzXXHONJk2apA0bNujjjz/WF77wBb322msymUzh/bvy2QX9R0ev3976OaErvxPQf3Tm80Ojn/3sZ/qXf/kXffe739XPfvazNvvm/Rfx1pl7ZRKfgdEDEp2QYWDZtGmT8dnPftZIS0szXC6XkZ+fb/zhD39I9LDQTzX+RUh7/xr/yqmystJYtWqVMX/+fGP48OGG1Wo1UlNTjfz8fOM///M/jbq6ujaf54033jAWLVpkJCcnGykpKcaSJUuMtWvXtrl/QUGBcd111xlDhw41HA6HMWPGDOOxxx5r9a+kMHCtX7/euOGGG4ycnBwjLS3NsFqtRkZGhvGlL33JePPNN1s9xu12Gw888IAxceJEw263GxkZGcbXv/514/jx420+z7PPPmvMmzfPcLlcxqBBg4yrrrrK2Lp1a5v78z6Ozti0aZMhycjPz29zH95/kUjn+6xwyy23RO1/7tw54zvf+Y4xduxYw2azGePGjTPuvvvuFn/V2SgQCBi/+tWvjGnTphkOh8MYNmyYsWLFCuPAgQNtjonrGh3V0et33bp15/1MvGTJkqi+f/vb3xqf/exnjbFjxxoul8twOBxGdna28fWvf93Yt29fm2MqLy83vva1rxkjR4407Ha7kZOTY/z7v/97m9UCXfnsgv6ho9dvb/6c0NnfCeg/Ovv5wTAMY8qUKYYkY+/eve32zfsv4q0z98oa8RkY8UQFFQAAAAAAAAAAAHqUOdEDAAAAAAAAAAAAwMBCQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHkVABQAAAAAAAAAAgB5FQAUAAAAAAAAAAIAeRUAFAAAAAAAAAACAHvX/ATVThAbN3UY/AAAAAElFTkSuQmCC", + "image/png": "iVBORw0KGgoAAAANSUhEUgAABpEAAAM4CAYAAAAklsrbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd5hc5X3+//eZPrOzvTetyqpLCFUQohqMMLgbG+LENrHjxI6cODEpP5PExIlbEuLYiRUnNt/YJI7jOMYdDBiQaBISqKCKykrbe5vZ3ekz5/fHmZ3d1UoghKTZcr+u61ynzjmfs3rU9t7neQzTNE1ERERERERERERERERExrFluwARERERERERERERERGZehQiiYiIiIiIiIiIiIiIyCQKkURERERERERERERERGQShUgiIiIiIiIiIiIiIiIyiUIkERERERERERERERERmUQhkoiIiIiIiIiIiIiIiEyiEElEREREREREREREREQmUYgkIiIiIiIiIiIiIiIikyhEEhERERERERERERERkUlmRIj04x//mLe+9a0UFRVhGAaNjY3ZLklERERERERERERERGRamxEh0sjICNdffz1/8zd/k+1SREREREREREREREREZgRHtgu4GD70oQ8BcOjQoSxXIiIiIiIiIiIiIiIiMjNclp5I3/ve9/i93/s91q1bh9vtxjAMvvvd777mZ1566SVuv/12CgoKyMnJ4eqrr+aHP/zh5ShXRERERERERERERERk1rssPZH+8i//kqamJkpKSqisrKSpqek1r9+2bRubN2/G4/Fw9913k5uby8MPP8xdd91FS0sL995770WtL5FIsG/fPsrLy7HZZsQIfyIiIiIiIiIiIiIicoFSqRRdXV2sXr0ah2NGDOp2QS7Lmz/44IMsXLiQuro6vvKVr/DZz372nNcmEgk+/vGPY7PZePbZZ7nyyisB+NznPseGDRu47777uPPOO6mrq7to9e3bt48NGzZctPuJiIiIiIiIiIiIiMj0t3v3btavX5/tMrLmsoRIt9xyy3lf+/TTT9PQ0MBv//ZvZwIkgPz8fO677z7uueceHnroIT73uc9dtPrKy8sBqzFUVlZetPuKiIiIiIiIiIiIiMj009HRwYYNGzL5wWw15fpgbd++HYBbb7110rnNmzcD8Mwzz7ypZ0SjUaLRaGZ/ZGQEgMrKSmpqat7UvUVEREREREREREREZGaY7VPgTLm3P3HiBAALFy6cdK6iogK/35+5ZlR/fz/79+/n2LFjABw5coT9+/fT399/1md8+ctfJj8/P7MsW7bsIr+FiIiIiIiIiIiIiIjI9DblQqRAIABYw9edTV5eXuaaUT//+c9ZvXo1d955JwB33HEHq1ev5uc///lZ7/HZz36WQCCQWY4cOXIR30BERERERERERERERGT6m3LD2V2Ie+65h3vuuee8r3e73bjd7sx+MBi8BFWJiIiIiIiIiIiIiIhMX1OuJ9JoD6QzexuNCgaD5+ylJCIiIiIiIiIiIiIiIhfHlAuRRudCOnPeI4DOzk6Gh4fPOl+SiIiIiIiIiIiIiIiIXDxTLkS64YYbAHjiiScmnXv88ccnXCMiIiIiIiIiIiIiIiKXxpQLkW6++Wbmz5/P97//ffbv3585HggE+NKXvoTL5eLDH/7wRXnW1q1bWbZsGTfeeONFuZ+IiIiIiIiIiIiIiMhMYZimaV7qhzz44IM8//zzABw8eJC9e/eyadMm6uvrAbj22mv5nd/5ncz127ZtY/PmzXg8Hu6++25yc3N5+OGHaWpq4oEHHuDee++9qPW1trZSW1tLS0sLNTU1F/XeIiIiIiIiIiIiIiIyvSg3sDgux0Oef/55HnrooQnHXnjhBV544YXM/vgQ6aabbuL555/n/vvv53//93+Jx+OsXLmSv/u7v+Ouu+66HCWLiIiIiIiIiIiIiIjMapelJ9JUp0RRRERERERERERERERGKTewTLk5kURERERERERERERERCT7FCKJiIiIiIiIiIiIiIjIJAqRREREREREREREREREZJJZHSJt3bqVZcuWceONN2a7FBERERERERERERERkSllVodIW7Zs4ciRI2zfvj3bpYiIiIiIiIiIiIiIiEwpszpEEhEREREREREREREROZebbrqJZcuWsXXr1myXkhWObBcgIiIiIiIiIiIiIiIyFW3bto2amppsl5E16okkIiIiIiIiIiIiIiIikyhEEhERERERERERERERkUkUIomIiIiIiIiIiIiIiMgkCpFERERERERERERERERkklkdIm3dupVly5Zx4403ZrsUERERERERERERERGRKWVWh0hbtmzhyJEjbN++PduliIiIiIiIiIiIiIiITCmzOkQSERERERERERERERGRs1OIJCIiIiIiIiIiIiIiIpMoRBIREREREREREREREZFJFCKJiIiIiIiIiIiIiIjIJAqRREREREREREREREREZBKFSCIiIiIiIiIiIiIiIjKJQiQRERERERERERERERGZZFaHSFu3bmXZsmXceOON2S5FRERERERERERERERkSjFM0zSzXUS2tba2UltbS0tLCzU1NdkuZ8qIhEcI9HUy1NdJf2cjA20nSQ00kRNqx50cJoKbiOEiipuI4SZmeIgZbmKGm5TNgYmNlGHHNOxg2MCwkzJsYEsfs9kxDSem3Qk2J4bdgeFwYdidGHYXhsOF3eHC5nBgd7hwOD3YHS5cLid2lweX04nDYcdhs+G0G9htBk67Lb02sNtsOGwGDruBY9y2027Dabfhclifc9ps2GxGtr/cIiIiIiIiIiIiIjJFKDewOLJdgExdL/zTb3JzbBvlr3WRecb6MoubduI4iOAkjJuI6SKCizBuhs30MVxETFdmO4yLsOlmCB9Dpo8gXkKGj4iRQ9juJ2rPIenIwel0jAubbLjGB0/pbdeEQMqG02FY19ltOB1j59x2G26nDbfDhttht9ZOa9vjHHfMYcfttD6vYEtEREREREREREREskkhkpxTwMglbtoZxE+nWUg7ZfQ5ywn7qnHlluAlhj0ZxpGKYE9EsCUjOJJhHMkIhpnAZqYwSGKYqcy2tU5hN5PYSGI3kzhIYDcTOEgvZhIHSZyj+6T3jeSkGp1GEidJfESBYbgYuUsKUlGD4aiHIXwMmLn0m7n0kUefmUe/mUcfefSbubSbefSQT5dZRAznRXj4GJfj7KGT22GbGDw5RwOoM0KpdGjlddrxuuwT1p7xx9LbbocNw1BwJSIiIiIiIiIiIiIWhUhyTvN/4x95eujvKC/wMbfYx0qfK6v1mKkU8XiMRDxKPGatk4k4iViUeGSEWGSEeGSERDREIhoiFR0hFQ+TioUw42GIhzHiYUiEscdHcMSHcSaGcSeH8aZG8Jkj+M0QLiOJzTDJI0weYaqNvvOqr9fMp9sops9WQp+jlAFHGQOOMvpcVfS6axkx/ETiSaKJFNFEeh0f247Ek6TG9eiKJVLEEimGSFyir+hkE4Mm28TQ6TVCKM+4bZ/LWnLcjsza67Ljc9px2Gf1NGwiIiIiIiIiIiIi04pCJDmnK+eWZbuECQybDZfbg8vtuaTPiYRHGA70Ex7qJzw0QCTQTTTQTXK4B3OkF0e4D2e0H198AH9ykKLUAB4jTokRoIQApE5BDGsZp588upw1DOfUkSyaj6tiKeUL11M1dxGGzYZpmiRSZjpcGg2bUmPB07hj0USSaDxFJL0+WzAVyZxPEo4nCceShOPW/axtax1LpjI1htPXXipuh20sXHI58LnT6zNCp/M570sfVw8qERERERERERERkUtDIZLIGTzeHDzeHKioPa/rzVSKwf5u+tpPMdTdRLSvmdRgK86RdnzhTkrjbZQyQBFBiuJHYPAIDAKngZ0wZHppcS9gKH8JtsqVlCzeSN2Sddjs9kv5mhmJZIpIIkU4lrQCpvEhUzxJZNz2xGtS1vkzrg/FEoRiSULRJCPp7WS6i5UVdsXoH7l49dttBn63A7/bQa7HWvs9o/vOsWPp47kTzjvwu534PQ58TrvmoRIREREREREREREZRyGSyJtk2GwUlFRQUFIBXHPWa4aDA3SePkyg9VVi3SdwDjRQMHKKOYkmco0wy2KHoOcQ9PwIDkAQH42e5YTK15K76DoWrL4Rj89/Sep32G347Tb87kvzx4FpWr2rQrEkI9EE4bi1Ht0PxdJh07jQ6XzORxNWD6pkyiQQjhMIx99UnYaBFSyNC5n8Hqe1f0bwlO91kud1kudxkucd2/e7HAqiREREREREREREZMYwTNM0X/+yma21tZXa2lpaWlqoqanJdjkyi8RjEVpPvELvyT0k2g+QN3CYedFj+IzohOsippPj3lWEam+gcu3bmbPoSgzb7J5fKJFMEUoHUiPRBEORBMPRBMORBEPp9XDUWsbOxSfup7eTqYvzx6DNgNx0sJTnsUImK2ByjNs+y3562+PU0HwiIiIiIiIiIiJTgXIDi0Ik1BhkaknEYzQe3kXf0WdxtO2mdvgAZfRPuKaTUhrL3kLB+g+waM1Nl23ou5lotKfUUCTBUDpkOlcQNRSJMxRJEIzECaZ7PwUjCYLheKZn1JvhtBvp3k3WUuhzUuhzke+11oU5zrFtn4sCn5MCnxO/26HwSURERERERERE5CJSbmCZ1SHS1q1b2bp1K7FYjIaGhlnfGGRqMlMpmo/to2PvL/G1PMPi8AHcxtjQbV0Uc7r8rZRcew/1KzdmsdLZLRJPpsMlK2QKhK2gaTRksrat84HM9lgQ9WZ6QzlsRjpQclHoc5LvtdaFOeMCKJ+TfN/EAMrjVPgoIiIiIiIiIiJyNgqRLLM6RBqlxiDTSWgkyLEdvyBx8CcsDTyP3whnzh13LCKw5G6W3fpRcvIKs1ilvBGmaRKKJTPhUyBkrQfDcQZDMQZCcQZDo9ux9HacgVDsTfWA8jrtFOW4KPa7rHWOO7NdlOOixO+iKMdNcXrf57Krx5OIiIiIiIiIiMwKyg0sCpFQY5DpKxIe4chzP4EDP2TF0PO4jCQAIdPNgfJ3UXfHn1BZtzjLVcqlFIknGQjFGBiJMxiOZcKlsdDJWk84Ho5fUM8nt8NGid+dCZmKMwHUWNBU5HdRkg6jctyOS/DGIiIiIiIiIiIil55yA4tCJNQYZGbo727j+BPforrhh9Sa7QAkTBv782+mePOfMW/5hixXKFOFaZoMRRMMjMToG4nRPxyjbySa2e5PH+8biabPXViPJ5/LTmmum1K/m9JcNyXp9YRjuW5K/C7cDg2tJyIiIiIiIiIiU4dyA4tCJNQYZGYxUykOPfdTjB1fZ0V0PwAp02Bv/s1UveeLVM1bkt0CZdoZHW6vfyRG73A0EzL1p5fRY/0jMfrSgVQk/sZCp3yv89yBU/p4Sa415J7dpiH1RERERERERETk0lJuYFGIhBqDzFwn9j/H0JN/z5rhZwGImXb2lb2HJXd/ifzi8ixXJzPZSDRBz1CUnuEovel1z9C4ZdzxePL8/xqy2wxK/W7K8z1U5Lkpz/Nkloo8DxX5bsryPOS6HZq/SURERERERERELphyA4tCJNQYZOY7uf85Qo/dzxWRPQD0k0fDlX/Ounf+PobNluXqZDYzTZNAOD4hXOo5S+jUO2wNt3e+f2P5XHYq8jyU5bmpGB805XsoT4dPZbkeXA61fxERERERERERmWw0N6ivr8fpdLJlyxa2bNmS7bIuO4VIKESS2ePgcz8jd9tfMjfVDMBR1wr8d26ldtGV2S1M5Dwkkil6h2N0BSN0BiN0p9edgSjdQxE6AxG6ghGCkcR537PE76KqwEtlvoeqAi9V+V5rv8BDdYGXUr8bm4bPExERERERERGZdZQbWBQiocYgs0ssGmXP/36BVQ3/js+IEjZdHFj+p2y480/UK0lmhFAsQVcwSlfQCpWscGncfjBCdzBKLPn68zY5bAYVmYDJkw6YvFQXeKhMB055Hg2dJyIiIiIiIiIy0yg3sChEQo1BZqf2xmP0/8/vsSK6D4AD3g1U3/MfFJfXZrkykUvPNE0GQnHaB8N0BCJ0BMK0DYbpGIxkjnUGIyRTr/9XZI7LngmXagq91Bb6rHWRj9pCL0U5LoVMIiIiIiIiIiLTjHIDi0Ik1Bhk9kolk+z+3y+x+tjXcRtxuili8B3/waK1N2W7NJGsSyRT9AxHaR8M054Ol9oHw7QHxoKm/pHY697H57JTU+ilptAKlWqLfGP7RT7yvc7L8DYiIiIiIiIiIvJGKDewKERCjUHk1JHdOH50D3NSbcRMBweu+EvWve+Ps12WyJQXjiXpCIyFTK0DIVoHwrQMhGjpD9M1FOH1/pbN9TioLfRRWzQxaKorzqG2yIvbYb88LyMiIiIiIiIiIhnKDSwKkVBjEAEYCvRz8t9/i9WhFwB4sewuNvzeN7HZ9Q1skQsVTSRpH4zQ0j8+XLK2WwdC9A6/dk8mw4CqfC9zS6xQaW6xta4r9lFXlIPXpd+fIiIiIiIiIiKXgnIDy6wOkbZu3crWrVuJxWI0NDTM+sYgkkom2fWf97Gx6d8A2Jt7E8u3fB+3x5flykRmplAsQVs6XGodCNPSb/Vgau4P0dQ3wkgs+ZqfL89zTwiX5o4GTMU+cj0aJk9ERERERERE5EIpRLLM6hBplBqDyEQv/+LfueLlz+IykhxxXcGcT/0Mf15RtssSmVVM06RvJEZT3wiNvVao1NhnrU/3jhCMJF7z88U5LuaV5DC/NIcFpX7ml/qZX5rDnCIfTrvtMr2FiIiIiIiIiMj0pNzAohAJNQaRszn43M+Y9+Tv4TfCvOpcSs0fPKogSWQKGQzFMqHSWMg0QnP/aw+T57AZzCn2Mb/Ez4JMwJTD/FI/RTmuy/gGIiIiIiIiIiJTl3IDi0Ik1BhEzuXE/uco++ld5DOiIElkGhmKxGnqC9HQM8ypnhFO9Y5wKr0djp97iLxCn9PqsVSSw4Iya72wPJc5RT7sNuMyvoGIiIiIiIiISHYpN7AoREKNQeS1nNj/PGU//UAmSKr9w8fIyS3IdlkicgFSKZPOYCQdLA3T0D2cDphGaBsMn/NzboeNBaV+FpX7WViey6LyXBaW+alVuCQiIiIiIiIiM5RyA4sj2wWIyNS28MprOcEP4acfYEn8KAe2vo+ln3kUp8ud7dJE5A2y2QyqCrxUFXi5dmHJhHOhWILT6UBpNGQ62W0t0USKIx1BjnQEJ3zG4xwNl3JZWO5nUZkVMNUUerEpXBIRERERERERmfYUIonI61p45bUcS3wP5y/u5orIy7z8jd9k7R/9L4bNnu3SROQi8bkcLK/KZ3lV/oTjyZRJ60CI413DHO8a4kTXEMe7hmnoGSYST3G4Pcjh9snhUn2ZFSotrshlSWUeSytzKfW7MQyFSyIiIiIiIiIi04WGs0Pd0kTO1yvbfsSy7b+L00iyq+KDXPWJb2a7JBHJkmTKpKU/ZAVL3VbANBouxRKps36mOMfF0so8llTkWuvKXOrL/LgdCqRFREREREREZGpRbmBRTyQROW+rbrqT3YFuNuz/LFd1fp/d/zefDe//02yXJSJZYLcZzC3JYW5JDrcuHzueSKZo7rd6Lp3oGuLVriFe7QhyuneEvpEYz5/s5fmTvZnrHTaDBaV+llaOBkvqtSQiIiIiIiIiMlWoJxJKFEXeqJ3f/SwbG/+VuGnn+K3/yfJNb892SSIyxYVjSY53DfFqZ5CjHUMc7QhytCNIMJI46/WjvZaWVeWxojqfFVV5zC3O0VxLIiIiIiIiInJZKDewKERCjUHkjTJTKfb80/tZN/Qkg/gJffjXVM1flu2yRGSaMU2TjkCEox1BXu0c4khHMNNrKXWWf5343Q4rVKrKZ2WNtZ5f6seuYElERERERERELjLlBhaFSKgxiFyISGiYpq/exOLEcZpstRR/+ln8+UXZLktEZoDRXktHO4Icbg9ysC3A0Y4g0bPMteR12tPBUh7Lq/NZWZ1PfZkfp92WhcpFREREREREZKZQbmBRiIQag8iF6m5rwvz2jZTTzyveq1h576PYHJpqTUQuvkQyxcmeYQ61BTnUFuBQW4AjHUFCseSka10OG0srcllZk8+qmgKurC1gQalfQ+GJiIiIiIiIyHlTbmBRiIQag8ibcWzfc9T99D14jDi7qj/CVR//52yXJCKzRDJlcrp3JBMqHWoPcLgtyFB08jxLfreDK2ryWVVbkAmWKvI9WahaRERERERERKYD5QYWhUioMYi8Wbt//u9s2PtnAOxb/wCr7/h4lisSkdkqlTJp7g9xsC3AgdZBXmkJcLAtQDg+ucdSRZ6HVbVWsHRlTQEra/LJ9TizULWIiIiIiIiITDXKDSwKkVBjELkYdvzbp7im87+ImE7a3vtjFqy6PtsliYgA1lB4J7qH2d8yyCstg+xvGeR41xCpM/4FZBhQX+q3QqXaAtbWFbKoPBe7hsETERERERERmXWUG1gUIqHGIHIxJBMJDjxwO6sju+imGPsnt1NcPifbZYmInFUoluBQW9AKlVoH2d88SNtgeNJ1freD1XMKWDOnkLV1hVw5p4A89VYSERERERERmfGUG1hmdYi0detWtm7dSiwWo6GhYdY3BpE3KzDQx+A/X0+d2cox51Lm3bsNl8eb7bJERM5Lz1CUA61WT6V9zYPsax5gJDZxGDzDgMXluaxOh0pr6wqZW+zDMNRbSURERERERGQmUYhkmdUh0ig1BpGLp+n4KxT899vIN0bYU3g7a//w+9Z3XUVEpplkyuRY5xB7mgfY2zTAnqYBmvtDk64rynFleiqtrSvkipp8PE57FioWERERERERkYtFuYFFIRJqDCIX275tD3PF9o9hN0xeWvJnrL/7L7JdkojIRdE9FGFv0yB7m61Q6WBrgFgyNeEap91gZXU+G+YVs2FeIWvrisj3agg8ERERERERkelEuYFFIRJqDCKXwvP/9XmubfgqSdPg2C3fZdl17852SSIiF100keRQW5C9TQPsbR7g5aYBeoaiE64xDFhSkcdV84rYMK+I9XOLKM11Z6liERERERERETkfyg0sCpFQYxC5FMxUit1fu5urgo8zbHppf9cPWLTmxmyXJSJySZmmSXN/iN2n+3mpsZ/dp/tp7Js8BN78kpxMoLRhXhE1hV7NqyQiIiIiIiIyhSg3sChEQo1B5FKJhEdo+KfbWB47QJAcet/7f8y/YlO2yxIRuay6gxF2pwOl3af7OdY1xJn/+qrM92RCpavnF7Gg1K9QSURERERERCSLlBtYFCKhxiByKQ0PDdLyz29jafwIg/gZvPNh5q64OttliYhkTSAU5+WmdKjU2M/B1gCJ1MR/jpXmurl6fjEb5xezcUExc4t9CpVERERERERELiPlBhaFSKgxiFxqgcF+Or5xG0sSxwjgp/99P2TeSvVIEhEBCMUS7GsezPRU2ts8QDSRmnBNRZ6Hq+cXsXFBMRvnl1BbpOHvRERERERERC4l5QYWhUioMYhcDoGBPjq23s6SxKsEyaHvPT9g3qrrs12WiMiUE4kn2d8yyM6GPnae6mN/8yCx5MRQqbrAa/VUWmAt1QXeLFUrIiIiIiIiMjMpN7AoREKNQeRyCQz20/6N21maOMoQPnrf/QPmXXlDtssSEZnSwrEke5sHMqHSKy2Dk4a/m1Pkm9BTqSLfk6VqRURERERERGYG5QYWhUioMYhcToHBflq/8XaWJw4zjJeed/0P81bflO2yRESmjZFogj1NA+w81cfOhj4OtgVInhEqzSvJGeupNL+Y0lx3lqoVERERERERmZ6UG1gUIqHGcD6GB7o5tesReg/uJtbYBMMR7E4nzpIS8ufMp2rlesqWr8ZwuQEDDGPcmjP202vDBobdWtvSa83vMCsEAgM0f+MdrIwfZAQP3e/+oXokiYhcoKFInJcbx0Klw+0BzsiUWFjm55r00HdXzy+mwOfKTrEiIiIiIiIi04RyA4tCJNQYzuXRT92Eb18nngjkjoDtda6POaC3JEWkNIm3ME55XpR5ORF8xhtpYsYZodJoyGQ7R+hkt4KnSdenj5/1+te63+i2HWyO9DJ+e3T/bMfS28bZzp/tc2dcc9bP2c9y7ei2E+xOa38aCgYHafyXd3BF/ACdlOD7oxfJKyjNdlkiItNeIBznpdP97DzVx46GPo52BCecNwxYVpnHxvnFXFNfzPq5ReR6nFmqVkRERERERGRqUm5gUYiEGsO5/PLDG1iweyizP5gDg6UQL3Jg5thIJpMkh5J4BqGqG3zRyfeI2aGzBAIFJqFck1huimhuipjPJOFJEfeC0zBxmuA0TZyYOExwmyYFqSRFyRSFyfQ6lcQ961vrWRi2cYGSw1rbXWPbNifYHelj469zTfzM6Lnxn7Gl72UfH1q9znV219jicJ99O/3sYHCA4NeuocbsYF/eW1j9mZ9k+6spIjLjDIzE2HXaCpR2NvRxont4wnm7zWBldT4bFxRzzYJi1tUV4XVNzx9QEBEREREREblYlBtYFCKhxnAuJ7f9J4G2k/iKqyiet5KyhRus4OAsRqLDnDr0Al17XyBy+AjOky2UtAzhi75280oBQz4I5MBgjkHIDWEXhNPriGvisYTXic3nw+nPwZXrx52Ti8vpwmnYcBkOnIYNp2HHYRg4sOHAhtMwcBg2HKaRPm7tOzFwY8Nt2HAbdtwYeAw7LtPAYxi4TPAAjpSJzUzhSCUxzCSkEpBKQmY7MXZswvZrnB9/n0nXnHmfceeZYb9d7W6OpuawMNWAw0hxxHUFy/IiYHdbbc2RXtvdZ2yPBlHjt88SVE0IsUav91jHHR5weibun6N9i4jMJN1DEXY29PFiuqdSU19ownmn3WB1baE1n9KCYlbPKcDtUKgkIiIiIiIis4tyA4tCJNQYLhXTNOk+eZCegy8TaW4k0dYO7d04ewZxBEZwBkO8oZHuziIFDPqhLxf68wz6cqE3z6AnH3ryrfWQl4s215LNsGEzbDgMBzbDht1mx26kl7Ns2wwbDptjwvEzPz/pXuPuYzNsE+5jw8BuGNhMsBsGdhNsmNgBuwl2wGZa+w7AbprYMbGb4DBN7GYKu2lmtq21iSOVxG4msadMHGYSe8paHOm13UxiTyZxmAkcyYS1nUpgT8axmUlIxiAZh1TcWidjkIhBMjpx20yd9eu6M7mUjfajNKdKqTZ6sGdraizD/hohk3fsuMMNzvH7ntcOp8YvZ55zeq17215vwEgRkUujbTDMzoY+djT08mJDH+2ByITzboeNdXMLuWZBCRsXFHNFdT4Ou/7MEhERERERkZlNuYFFIRJqDNliJpMkBwZI9PWR6O0l2ddHcmiI1EiI1MjI2DI8THJ4mPhQgOTQEMnhIRgOYURj5/WcuNvOSLGPoWIvwSIPgSIXg4Uu+gsd9BbYGHAniKXiRJIRosmotSSixFLnd//ZzsDIhGGjAVgmOEsfH79vT/cEsxs2axsDV8TGFw5up8gY4l+Kb6F5fk4mGBsfho2GXpkwLDUaiiXT26kzArBEejuBPZnAnhoNwaI4EnFrnYynA7h02JZe29NDK9rTz3dM2oeLmnWNBkpOX3rbl973jtt+vWPnOpdeOzwXLVAVkZnJNE2a+kKZ+ZR2NvTROzxxvNocl50N84rSw9+VsLQyD7tNf7aIiIiIiIjIzKLcwKIQCTWG6cqMx0kODhLv6ibR1Um8s5NEZyfx9g7ibW3E29pI9PS87n0MrxdnVRXO6iqc1dW4qqtxVlVhr6rCLC6AonySDhtJM0nKTJFIJUiZKZKpJEnTWjLH0vvjz2W2U+nPm4nMfSbd8yyfnXBfM0kqlcp8LnPsjHrGP2v880f3R589/vzosbPtJ1IJEmbikv56fqithD+L7eWUWcq753oxp8EPudswMkGY3TAyPcAcpoGT9Dxf6V5gTjOV3k/iTFnDIzpTycx1jvS1Tsz0/tg8YaNzhjlMc9x909ekr888i7FtZzr8ysw5Zpo4nD5sk8KnHHCNLv5x274z9v3W5yZdl6OhAEVmKNM0Odk9nAmUXjzdx2AoPuGaPI+Dq+db8yltXFDConI/hgJrERERERERmeaUG1gUIqHGMJOlolESHR3E0qFSvK09EzDF29tJdHfDefwWsOXm4igqwl5SYq2Li3AUl2TWjuIi7EXFOEqKseXmzshvnpmmmQmuJgRM5wqqzli/1ucSqQThoSFu/cmfUmAM84P6jxNbfcWEkO5cnzufWjLXnm9odkbYljST2f7yX1T2caGTk7EAy2WamcWJtXanjzvHnRtbxj7jNOy47C7cdjdOhxuX3YPL4cHl8OFyenE5c3C5/LicPlxuP05XHi53Lm5PIU53PoYnH9y56cVvBVsa4k9kykmlTI52BtmZDpV2ne5nODrxhwxK/C6uSodK1ywoYW6xb0b+vSgiIiIiIiIzm3IDi0Ik1Bhms1QsRqLD6rk0IWhqt8KmRG8vJN5gDxynE0dREY7iYuzFxenQqRhHUaEVNBUXYS8qyhy3eTyX5uWmoRce/BM2tX6bV53LWPIXO7NdToZpmhN6eb1WMBZPxsfWqQTxVDyzZPaTZzmWPj762TPPJVJnP362z42/JpaMYTL1/5g/a0iFDZdhw2nYcBkO3DYHTpsTdzqscts9uB1e3E4vbocPj8uPy+XH48rF7c7F4ynA5c639h3u9GfceOyeCfsOmyPbry8ybSWSKQ62Bdh5ygqVXmrsJxKfOP9dRZ6HaxYUc/UCK1iqKfRlqVoRERERERGR86fcwKIQCTUGOTfTNEkFAiT6+615m/r7SfT1kezrI9HXT6Kvl2RfP4n+PpK9faRGRt7wMwyfLxM6OSorraH1Kiut4fXS+7a8vFnxU9y97U3k//tqnEaShvc+xoIrNma7pBlhfMA1KdRKxomlYpnAaXQ/lowRTUYzx2PJmHVdMk40Gc1cE09EiMVDxBLh9BIllrSWeDJGNBUjlkoQSyWImwliZpKYmSI+hYItBwYumwOPzYnblg6oHFZA5XH6cDtz0sesAMpld00Kotx2N16HF4/DY63t1nr02Ohxt92NzVAPK5m5ookkr7QE2NnQx46GXvY1DxJLTgyVaou8XDO/hGvqi9k4v5iyPP0whYiIiIiIiEw9yg0sCpFQY5CLJxWJZIKmTNjU30+yr5/kQL8VPPX3kewfINnXhxmPv/5NAZvPh7O6CkdFpRUsVVZYgVNllbVdUYHN5brEb3d5vPzAu1k3vI2Xit7B+j/8XrbLkUskZaYmBFSj29FEhFhsiHgkQCwaJBZNr2NDxGIjxGLDROMjxOIhIokw0USYaDKaXmJEUnFiqQQRM0mUJBHDRsyAiGEjahiZJWbLXig7GjJlAqfxa7t38rHRMMrumRBK+Ry+CaHV6DGHzTErQmeZHiLxJHuaBtjR0MvOhj5eaQ2QTE38p+eC0hw2poe+2zCviBK/O0vVioiIiIiIiIyZyrnB17/+db761a/S1dXFunXr2Lp1K6tWrbokz1KIxNRuDDJzmaZJang4HTr1k+jpIdHZYQ2l155ed3SQ7O8/r/vZS0qsgKmiAkdlBc6yMhyjS2kpjrIybP6pP9n5oR2/YsUTdxM2XSQ+c4zc/KJslyTTVSoJsWGIDkEkCNEgRAIQCZAKDxIN9xOLDBCJDBKNDBKJDRGLDRFJB1XReIiomSQyLnyKGgYRm0HMMCYcj4w7FzZshA2DsM06HrFZAdblYjfsmUApx5mD1+HF5/Rl9ke3vQ7vhH2f00eOw9of/UyOMwefw4fb7p7yf3bI9DAcTfBSY39mTqVD7YFJUxPOLfaxtq6IdXMLWVdXyIJSP7YsBr8iIiIiIiIyO03V3OD73/8+H/3oR/nWt77F2rVr+Yd/+Acee+wxjh8/Tl5e3kV/3qwOkbZu3crWrVuJxWI0NDRMucYgApAKh4l3dBJvb7dCpo5O4h0d6cCpg3hnJ2Ykcl73MrxeHGWlOEvHBUyZpdQKnkpLseXkXOK3OjczlaLlCyuYk2pj1xV/y1Xv/cOs1SJCPJIJnjJLNDD52OgSHoTwgLVEBsG0hvFKAlHDIJQOmiLpoMkKnYz0dvpYOoAK251EnB7CDjcRu5OI3UHYZrfOGRDBJGImCacShFNRkmbqtd7kTbEZtkzQdOb6XMFTjjMHv9OP3+WfsO13+nHZZ0bPSXnzAqE4u073saOhjxdP9XGsa2hSqFTgc7JmTiFr66xQaVVtAR6nPTsFi4iIiIiIyKwxVUOk9evXs2nTJr72ta8BkEgkqKio4Atf+AKf+MQnLvrzZnWINGqqNgaR82GaJsnBQRIdHcQ7rGAp0d1ForuHRE838e5uEt09pILB876nze/P9F6aEDCd0bvJ5rk081i8+N3PcnXjv3LQdSUr73vmkjxD5JJLpazeT6Oh0oRl8Iz9/on7qcQbflwcrOApp4iwt4iQL4+QO5eQ20/I5SPkdBNyOAnZnITsNkKGwYiZJJSMEEqECMfDjMRHCCVChOIh61gifNG/LABOmxO/0wqXcl25mZApx5Vz9uNnhFA5zhz8Lj8um0s9pGaYQDjO3uYB9jQO8HJTP/tbBonEJwakTrvB8qp81s8tZG1dEavnFFCueZVERERERETkIrvQ3OB73/sezz33HHv27OHgwYPEYjG+853vcM8995zzMy+99BL3338/O3bsIB6Ps3LlSj7zmc/wgQ98YMJ1sVgMn8/HT3/6U97+9rdnjt95553k5OTw0EMPveH3fD2Oi35HEbmsDMPAUViIo7AQz7Jl57wuFQ5bQ+Z1d5PoHguXRvcT3d0kenpIjYyQGh4mNjxM7PTp13y2LT8fZ1kpjjN7NpWWTujZZLzB+Zrm3nQPfOdfWR59hY7mE1TOWfiGPi8yJdhs4C2wFuad/+dM0xqG72zhUygdNoX6YKTXWod6IdSPMzaMMxknL9gFwa7zf547H3KKwVcMOaWQUw8FZZBTRiqnmLCngJAnHUTZ7IwkrIBpNHgKJUJW+JQOnkbXI/ERhmPDDMetZSQ+wkh8BIB4Ks5AdICB6MAb+YpO4rA5yHXmZkKl0fApz5U3YX3msdHtHGeOQqgpJt/r5KbFZdy0uAyAeDLFkfYgLzcN8HJjPy83DdAzFGV/yyD7Wwb59nPW31MVeR5W1eazqraAVTUFrKzJJ8/jzOariIiIiIiIyCz1l3/5lzQ1NVFSUkJlZSVNTU2vef22bdvYvHkzHo+Hu+++m9zcXB5++GHuuusuWlpauPfeezPX9vb2kkwmKS8vn3CPsrIyGhoaLsn7KEQSmSVsXi+uOXNwzZnzmtclh0dI9JwtYJoYPJmRCKlAgGggQPTEyde8p72wcEKvJkdZGc7ychxl5TjKrW17URGGzQZARd1ijriuYFnsAI3bvkvlR7540b4OIlOeYYA711oKXvv36wTxsBUyhXrTAVN6e0Lg1De2H+63htuLpofo6z816ZY2ICe9AGB3W0GTvxRyysCfXnLKwF8LRaP7peAttN5lnJSZyoRJowHTSHyEofgQI7GRzH7meGxobD99fvQcQCKVeFNhlM2wWSGTMx00ufPOO4DKdeXicaj3y6XmtNusYKi2gI9dOw/TNGnpD/NykxUo7W0a4HjXEJ3BCJ2HIzx+eCxAXVCakwmVVtUWsLQyF7dDw+CJiIiIiIjIpfXggw+ycOFC6urq+MpXvsJnP/vZc16bSCT4+Mc/js1m49lnn+XKK68E4HOf+xwbNmzgvvvu484776Suru4yVT+ZQiQRmcDuz8Hun4d73rl7TpimSWpoKNN76Zw9m7q7MeNxkgMDJAcGiB47du4HO504SktwlpXjKC+nt3AhcIDy0z9leNc7cVVU4Cgvv2RD6IlMe04v5Fdby/lIpax5mzIhUy+M9MBwD4x0w3B3er/LOhYbgmQUgq3W8npszrHAKbcS/OXYcivJza0gN7cCcius4yWlYHtj39hPmSlC8dCksGkoPsRwbJhgLMhQbIih2BDBaJBgPJjZHooNEYwFiafipMwUgWiAQDTwhp4/ymVzke/OH1tc+RR4Csh3jR0rcBdMOu+2uy/oeWL1vp1T7GNOsY/3rrGGEhiJJjjUFuBAa4D9rYO80jJI60CYhp4RGnpG+PHeNsAaBm9ReS7Lq/JYVpnH8up8llTkkqseSyIiIiIiInIR3XLLLed97dNPP01DQwO//du/nQmQAPLz87nvvvu45557eOihh/jc5z4HQElJCXa7na6uiaPQdHd3U1FRcVHqP5NCJBF5wwzDwJ6Xhz0vD3d9/Tmvy8zXNCFc6rICp65uEl1dxLu7SPb2QTxOor2DRHsHALleB/F32plvtPL0//cZKjv6gdEh9MpwlI/1Yhrfo8lRWoq9sBDDoT/eRF6TzQa+ImspOY8hI+PhccFSdzpoOjNwSh+PBCAVh6F2a+l45dz3NWxWT6ZMsFSRCZ3IrRzbzynJhE02w2bNj+TyX/DrRxKRsaApFswETxMCqHHHRgOoobh1LmWmiKVi9IR76An3vKFne+yec4dM4/fHnS9wF+Cw6c+1s8lxO7hqfjFXzS/OHOsbjlqhUssgr7QOcqA1QP9IjMPtQQ63T5wjcG6xj2VVeSyvyrfCpao8SnPdGupQREREREREABgaGiI4br55t9uN231xfkB0+/btANx6662Tzm3evBmAZ54ZmzPe5XKxevVqnnrqqcycSIlEgu3bt/OFL3zhotR0Jn03QkQumfHzNbF40TmvM+NxEr29VqiUDpcS3V0cDfwfV9hfJX5FPsZgGDMcHjeE3onXfLYtPx9HYSH2wkLsRUU4igqxF4zbLizEXpjeLirC5vVe7NcXmVmcXiiss5bXk4iOhUrD3TDcCUPjlw5rPdJtDak33GktHa9xT8NuDZWXWwH+Csirsnpd5Y0uVdbiPL/fyx6HB4/DQ6mv9Pzef5zRnlDBWJBANMBgdJBALEAgEiAQS++neziNnh+9NmkmiSQjREIRukJvYO4qIM+VR5GniEJPIYXuQmt95rankCK3dc1sHm6v2O/mpiVl3LTEmlvJNE1aB8Icbg9wJB0kHekI0hGI0NgXorEvxKMHOzOfL/G7WFqZx6LyXBaX57KoIpdF5X58Lv3TWUREREREZLZZdsY89Pfffz9//dd/fVHufSL9Pc6FCyf/gG9FRQV+vz9zzag//uM/5mMf+xhr165lzZo1PPDAAzgcDj74wQ9elJrOpP8Ji0jWGU4nzspKnJWVjP/2b/tPHbD/fuYVd1K/52XM4eFJQVO8q2tcr6Zukn19YJqkAgFigQA0Np5fDR4P9qJCnGXlOGtqcFZX46yuwjW6XVmJ4XJdkvcXmXEcbsivsZbXkkpaYdNQBwx1jYVLQx3WMHqZsKkHzGR6/7WSJsBXnA6UxoVL+TXjjp1/0HQu43tCVfmrzvtzpmkyHB+2QqVoMBM2jYZQ448FomOBVDAaxMTM9JhqDDae1/O8Du85w6YiTxEF7oJMKFXsKSbHmTNje98YhkFtkY/aIh+3rajMHO8bjnK0Y8gKlzqscOlUzzC9wzGeO9HLcyd6J9xnTpHPCpYq/Ol1LvNL/Lgctsv9SiIiIiIiInKZHDlyhOrqsekDLlYvJIBAwBpiPz8//6zn8/LyMteM+uAHP0hPTw/33XcfXV1drFu3jscff5y8vLyLVtd4CpFEZMqqv/4ukvv+mkWpBk6dPMr8RcutIfTOksyPMpNJkoGANQ9Tfz+J/oH0nEzjtvv7SQyMbZuxGGYkkhlOL7x//+QbGwaOigorVJpTi6t2Ds7aGlxz5uCqrcWWnz9jv/kqcsnY7GND2L2WZGIsbBrugmD7uKXVWgfaIBG25ngK9UHnwXPfz1s0rhfTuMApvxrya61tx8UPjQ3DINeVS64rF3LP/3OJVIJANMBAZICB6IC1jgzQH+1nMDKY2R6IDDAYGaQ/2k8ilSCcCBNOhGkfaT+v57jtboo9xRR704vnHGtvMbnO3BnxZ16x3821C91cu7AkcywcS/JqZ5BjnUMc6xrieNcQxzqH6R2O0twfork/xJNHx3qROWwG80pyWFSR7rWUDpfmFPmw26b/10hERERERGS2y83NvWQBzYX69Kc/zac//enL8iyFSCIyZfmLKjnqvZKlkX207PgB8xf97et+xrDbcRQV4SgqggULXvd60zRJjYRIDg6Q7Osj3tFJvK2NeFsbsbZWa7u1zQqZOjpIdHTASy9Nuo8tNxdXbS3OOXNw1dbgrK219mvn4KyswLDbL+hrICKA3QF5ldZyLqYJ4YF0sNSWXtLh0uh2sA3iIQj3W8s5gybDmoepoNYKlTLrOdaSXwOunEvyqmfjsDky4c35GO3xNCl0ivQzGB2kP9KfOTYQtY6HE2GiySjtI+3nFTo5bc7XDprGrfPd0ytk97rsrJ5TyOo5hROO9w1HOd41bIVKXUMcT4dMQ5EEJ7qHOdE9zCPjxmR0O2wsLPdPGBJvcXkulfmeafX1EBERERERkUtntAfSmb2NRgWDQQoLC8967nJRiCQiU1p00R1wYB9lLY8Brx8ivVGGYWD352D350BNDd5VqyZdY5omyb4+Yi0txFtbiTU3E29uIdbaSry5mURPD6mhISJHjhA5cmTyQ5xOnFWVuGrn4JpTi7Om1lrX1uKqqcGWc/m+GS0yYxkG+IqspWLF2a8xTYgMnhEujYZNrdb2YAskozDUbi0tu85+L1/xuIBpzuTAyVto1ZQF43s8zWHOeX0mFA/RF+mjL9w3ad0f6c/s94Z7GYmPEE/F6RzppHOk83Xv7bA5KPYUU+Yro8RbQqm3lBJfCWXeMkp9pZljRZ4i7LapG7gX+91s9LvZuGAszDNNk85ghGOdVo+lVzuHONE1zInuISLxFIfaghxqC064T67bkZ5jKZfF5f5MuFTsv3jDIYiIiIiIiMj0MDoX0okTJ1i7du2Ec52dnQwPD7Nhw4ZslJahEElEprQF198NB77A0uRxTp0+xfx58y97DYZh4CgpwVFSAqtXTzqfCoetcKmlxQqYWlqJtVjreGsrZjxOvKmZeFMzI2e5v72k5KzD5DkrKnCUlGguJpGLxTCscMdbCOXLz35NKmUNnRdogcHm9Lpl4joaHBs2r2P/2e/j8o+FSqM9mArqoLAOCudaNUwhPqcPn9NHbW7t614bSUTGgqazhE7j10OxIRKpBF2hLrpCXa95X5tho9hTbIVKvlJKvaWZ9WjQVOorpdhbjNPmvFiv/qYYhkFlvpfKfC83Li7LHE+mTFr6QxN6LB3vGuJUzwhD0QR7mgbY0zQw4V4lfheLxg2HZ237yfVMjXcVERERERGRi++GG27gy1/+Mk888QR33333hHOPP/545ppsMkzTNLNawRTQ2tpKbW0tLS0t1NS8ziTgInLZnf7SeubFjvPkws9xy2/em+1y3hAzmSTR1UWspZV4SzOx5hbirS3EmluItbSQOkdX1fHsRUU4yspwlJXiLC/HUWpt2/PzseXlYc/Lx56fhz0vD1turobOE7nUwoNnhEtnhE0jPa9/D3c+FI4GS3MnBkwFc8DpvcQvcXnEkjH6I/30hHroCfdk1r3h3gn7/ZF+UmbqvO5pYFDoKZwYNqUDpjJfGeW+csp8ZRR7iqdcz6ZYIsXp3pFJ4VJzf4hz/Yu8usDL0so8Vlbns6LaWpfleS5v4SIiIiIiIrPQxcgNvvKVr/DZz36W73znO9xzzz2TzicSCRYvXkxbWxsvvvgiV155JWANb7dhwwYaGxs5duwYc+fOvfAXeZMUIqEQSWSqO/Tfn2XFiX9lp2sjG+97LNvlXFTJQMAKmNLBUrylOR04tRDv7oZ4/A3f03C5MNxuDLcb27htw+3C5kzvjx53ObG53Riu0WMuDJcLm9eHvbAAR2Eh9qIi7IWFOIqLsXlnxje2RS6peNgaHi8TLjVby0AjDDTBSPfr38NfPhYsjQZNo9t51dY8UTNIMpW0wqZ0wNQd6ra2Q70TjvWF+0iYifO6p92wU+wtzoRKo8vo/uja5/Rd4rd7faFYgpPdw5lh8Y51DXO8c4jOYOSs15flullZnc/y6vxMuFSRp7mWRERERERELqYLzQ0efPBBnn/+eQAOHjzI3r172bRpE/X19QBce+21/M7v/E7m+m3btrF582Y8Hg933303ubm5PPzwwzQ1NfHAAw9w773Z/aF6hUgoRBKZ6vpP7qLoe7cyYroZ+fQJyorys13SZWGmUiQDARJdXSS6u0l0dxPv7ibR1U2ip4dkMEAqOEQyGCQZDGKGQpe8JntREc7q6vRShau2FveCBbjq63FkeZI/kWkjFkoHS01WqDTQOLY92GQNl/dabA4rSJoQMM0dC51ySrM2H9OlljJTDEYH6QmNBUvjezWNDpvXF+4jaSbP6565ztwJIdP4gKksx9ou8hRhM2yX+O0mC4TiHOsa4lBbwFraA5zsHiZ1ln+9l+a6WTunkLV1haypK2B5VT4e59TqiSUiIiIiIjKdjOYG9fX1OJ1OtmzZwpYtW173c/fccw8PPfTQOc9/5CMf4bvf/e6EY7t37+b+++9nx44dxONxVq5cyWc+8xnuuuuuN/sab5pCJBQiiUx5qRT9X1hAUaqfbev/nZvuuPv1PzMLmbEYyeFhzGgUMxolFY1hxka3rbUZi1n7sRhmNGbtx6zjqWgUMxa3zodCJAcGSA4MkBgYINnfjxmNvubz7cXFuOvrcS9YgHvRQtyLFuNetBC733+ZvgIiM4BpQnggHSo1jgVLo+vBZkjGXvseTt/EOZgy6/Qxb8HleJOsSqaS9EX66A510xXqojvUnVnG74/EzzZT3WQOw0GJr2RiwHRm4OQrw+u49L01Q7EERzuCHGwNcKg9yKG2ACe6h0mekSy57DZWVOexJh0sra0r1DB4IiIiIiIib4ByA4tCJNQYRKaDw/92D8s7f8KTue/mlnvPneTLpZMcGiLe1jZhiTU2EW1oIN7aes7POaurcS9ejHvxIjyLFuFevBjXnDkYjpk1HJfIZZFKwVDHxGAps26EYDvwOv+08+SPC5nmnhE4zQFXzmV4kalhJD4yIVTqGumaFDr1RfrOe76mXFfupJCpzFt2yXs1hWNJDrUH2NM0wJ6mAfY2DdA3MjlsXFCaw8YFxWycX8LV84so9rsvah0iIiIiIiIziXIDi0Ik1BhEpoPWF39EzWMfo8Uso+S+o3jdCiCmklQoRLThFNGTJ4mePEH0xAmix46T6Oo66/WG2231WDojXHIUF1/mykVmmETUmo9pwhB544bOC/W+/j18JRN7LmW250JBLThmV/CQSCUyQ+edrWfT6LFwInxe98v0avKWnXsYvTc5V5NpmjT3hzKh0p6mAY51DXHmv/qXVORy9fxirllQzFXzi8n3Oi/4mSIiIiIiIjONcgOLQiTUGESmAzM6TOzLc3ETZ8fmR7hm47XZLknOQ3JwkMjx40SPHSd6/DiR48eIHj+BGT77N1vtJSV4RofCW2wNh+eur8fmnl3ftBa5ZGIj6VCpedwQeeN6M0UCr3+P3MrJQ+SNbufVgH32hfymaTIcH54QMo3O0XShvZr8Tv+kgKnUVzohbCr2FGO3nd+8R4OhGLtO97OzoY+dDX0c6xqacN5uM1gzp4AbF5dxw6JSllflYczQubVERERERETOh3IDi0Ik1BhEposTX93MwuCLPFKxhTs+8aVslyMXyEyliLe2Ejl2LBMuRY8dI9bczKQfkwew23HNnYt70UI8ixfjmjcfZ1UlzspK7EVFGLZLP9m9mUqRGhoiGQiQHBwct4zbDwRIDQ9jmilrNDHTBNPE8Hiw5fiw+XzYcnJwFBXhKCvDUVqGo7wMZ2UlNo/mKZEpIjw41nNpQtCU3n69OYQMuxUy5VdDXrW1zq8d286rgZwSmKXhxPheTWcLmUb3Q4nQed3Pbtgp9hZbAZM3HTDlTB5OL8eZMykQ6h2OsutUPzsaetnZ0Mep3om/tqW5bm5YVMqNi0u5rr6UfJ96KYmIiIiIyOyi3MCiEAk1BpHp4tQv/oH5e77ATmMVV/3VM9hss/ObkDNVKhSyhsM7fpzIMStYih47RjJw7p4RhtOJo7ISZ0UFjpJibPn52PPzsecXYC8owOb1YDid1uJyYdjtmIkEZjyeWVLhCKlggGQgSDIYJBkMkAoOTQiHkoGANRfNpWAYOKurcc2bh3v+PFzzF+BZtgz34kXYXK5L80yRC2GaEOqHwcbJw+SN9m5KRl//PnY35FVBfs24cKl63H6NNW/TLA2aYPJcTaPzNXWHuukJW+FTb7j3vHs1eeweir3FFHuLKfGUUOK1lswxbwnxaA4Hm1O8cCLIjoZeQrFk5vN2m8G6ukI2L6/g1uXl1BRe+FB7IiIiIiIi04VyA4tCJNQYRKaLWMcRXP++kYjp5NWPHOTK+ZXZLkkuMdM0SXT3ED1+LB0uHSPe3EK8o4NEd/fZey5dQobPh70gH3tBgRVWFRRM3Pb7wbCBzQbp73+bkSipUIjUyAipkRES/X0kurpJdHeT6OoiFTpHjwOnE/fCerzLl+NZsRLfmtW4Fiy4LD2vRC5IKgXDXRBss+ZlCrSObQfbINBmnec8ft86feAvh9yKsXVuBfgrILc8va4Ab+GsDZuSqSR9kb5zztM0ugzHh9/Qff1OP0WeYtxGPpFIDr0BF4NDHlIJP2YiFzPpZ1FJBZuXLOD2FXUsKvdr2DsREREREZmRlBtYFCKhxiAybZgmg19YQEGyj4eXf4P3vf9D2a5IssiMx0l0dxPv6CDe0UlyYGCs51B6bUajmLHYWM+jRMLqleRwjPVQcrux5+dhy8vDnpc/tj0+JEovF7tnkGmaJPv6iJ0+TfT0aWKnThM9cYLI4cMkBwcnXW/Lz8e3ejW+dWvxrlmLZ8Vy9VaS6SURg6GOsVApmA6bMtttEO4///vZ3emQqTy9rrS2c8ogpzS9lFhrV86sDJxC8RB9kT76wn30hnutdaSX3rC19If7M9uxVOwN3dtMubCbfgrchdTmlzInv5QiTxGFnkKKPEUUuAsy24WeQvxOBU4iIiIiIjJ9KDewKERCjUFkOjn94EeY1/pTfuJ9H+/58//Idjkil4RpmiTa2wkfOkzk8GHCBw4QfuUVzHB4wnWGy4XnipX41qzFu2Y1vtWrsefnZ6nqyycVi5EKBq0eXuleXmY8Djab1VPLZkuHg/nY8/Kw+f3qwTWdxEIw3AlDXel1ehnuGrfdCeGBN3Zfh3diqPR62/bZNQeQaZoMx4czgVJfuI++SF9mf/RYT6iX/sgAKRJv+BkOm4NCdyGFnvQybrvIXUSBp8AKnNLHC9wF2G32S/C2IiIiIiIir280N6ivr8fpdLJlyxa2bNmS7bIuO4VIKEQSmU4Gd32fgl99kiOpOir//GUKc9QLQ2YHMx4n8uqrhPbsIbxnD6E9e0n2T+6x4V5Yj3fNWnxrVuNdswZnTc20+8n/1MgI0cZGYukl3tZuDf+XXs7WS+s12Ww4SkpwVldnFnf9AtyLF+OeNw/DObvCghkjEU0HS11W76bRkGm4E0Z6YaTHWoZ7IBF+/fudyVt4loCpDPyj63RvJ3+Z1ctpFhkNnFqDPWw7cZpnGxp5pb2NpDGEYR/BsI/gz4mSlxPFZh8hGB8kfAG/BgYG+e78TOA02qPprD2d0uedsyz8ExERERGRS0e5gUUhEmoMItPKcA88UA/A4297js1XXZHlgkSywzRNYo2NVqC0dx/hvXuJNTZOus5RWop3zRp8a1bjWbYM95Il2HNzL3/BZzCTSeJtbURPnSJ2upHY6dOZ0CjR3X1e97D5fBg5PmvtdELKhFQKM5XCDIdJBoOY0ehr3sNwOnHV1+NddQW+tevwrVuLs1Lzrc04sREY7p4YLo30nLGf3g71gpl6Y/d3+ccCpcz6XIGTf0YOqzcSTfDk0S5+8UoHzxzvJp4c+y/Gqpp83nFlKdcu9pKyDTMQGaA/0s9AZIDB6GBmeyA6kDkXjAUvqA6/0z+hd1Ohp9Dq4eQeC6CKvcWUecso8hSpp5OIiIiIiJyTcgOLQiTUGESmm86/X09F6Dj/Xf2X/ObH/zTb5YhMGYm+PsL79hHas5fw3r2EjxyBeHzSdc6aGjxLl+BesgT3/Pk458zBNWfORQ+XTNMk2dtLvL2dWHOzFRidOk3s1ClijY3WEHTnYC8qwjVvHq65dbhqa3GUleMoK8NRVoqjtBR7fv55DVGXikZJDgZIdHcRb2sj3tZGrKWF6PETRI8dIzUyMukzzqoqfBuvxn/DDeRccw12v/9NfR1kmkklITwII90TezON9FjHhs9YJyJv7P5OH+RWQF61NYdTXpW1nVc1tp1TCtN4CMZAKM7jhzv5xYF2XjjZSyr9vw2n3eDmJeW8b20NNy4uxWk/9zvGU3EC0YAVLkUG6I/2j22fETqNhlFJM/mG6rQZNko8JZT6Sin1lVLmLaPUV0q5r9w65i2lzFdGgbtg2vXoFBERERGRN0+5gUUhEmoMItNN6//9GTWH/51fGjdy+1/9FJtN39gROZtUJELk4EGrp9IrrxB99VXi7e3nvN5eWIizpgZHSQmOkmLsxcU4ioqsnj5uD4bHjc3txkylIJHATCQw43GSQ0OkAgGSg4MkBwdJ9PRagU1HB2Ysds7nGW53Oiiai2veXNyj23V1l2VuJzOVIt7eTuTwEcJ79xB6eQ+Ro0chNa4XisOBb+1a/DfcQO6tt+Kqqb7kdck0YpoQHUoHTd3pcKl73P744z0QnxxanpXNAblVY8FSQS0U1EHhXGvJrwXH9BjOtWcoyi8PtPPw3lYOtY31LirOcfGuK6t539pqlle9+d/vKTPFUGxoLGA6I3gaHzj1hfvojfSSOs8eZx67h0p/JVX+KqpzqqnyV2WWan81xZ5ihUwiIiIiIjOQcgOLQiTUGESmm9iJp3H993voNAvp+939LK8uyHZJItNGMhAg8uoxIkePEH31GLGmJmItLSR7ey/NAw0DR1kZztoa3PMX4Jo/D/f8+bjmz8dZWYlhn1pDSSWHRwjv28fI888x/Myzk4YI9FxxBXm33Ube5ltxVitQkjdodFi9oQ4ItkOw7Yx1uzWvE6/3z3PD6rFUOBcK6yYGTMX1kFN8yV/lQrzaGeThPa38ZF87vcNjQ00uqcjlA+tqee+aagp8lyccS6aS9Ef66Q530xPqoTvUTXeom56wtd0T6qEn3EN/ZPLcc2dy291U5lRSk1tDXV5dZpmbN5eKnApsxvTtVSYiIiIiMpspN7AoREKNQWTaiYeJf7EWJ3H+e8NP+M3b35LtikSmveTwCPGWZuJtbST6+kn09ZLs6yfR34cZjpCKRjCjMcxIBGw2DIcDw+EApwO734+9oAB7fr61LirGWVWFs7oKZ3k5hmt69Jg4m1hTE8PPPMvQU08ReumlCb2UPKuuIO9tbyPvbW/DWV6exSplRknGYbhrLFwKtMFgMww2wUAjDDRBIvza9/AWQvFCK1Aqqbe2SxZC4Txwei7La7yWRDLFsyd6+NGeVp480k0saf2+cjts3LGykt+4ag7r6gqnRO+eWDJG10gXbSNttA+3Z5a24TbaR9rpDnW/Zo8ml83FnLw5E4Kl+oJ6FhQswOf0XcY3ERERERGRN0q5gUUhEmoMItNR59ffQsXAHv49/9P83h//TbbLEZFZINHby9Cvf03wV49ZgdLoP6EMA9+6deTdcQe5m2/FUViY3UJlZjNNa5i80UBpoBEG09v9pyHYeu7PGjZrKLyShVC6BMqWWkvpEnDlXKYXmGgwFOPnr7Tz/V3NvNo5lDm+sMzPb2yYc1l7J12IeDJOZ6iT9uF2WoZaaAo20RhspCnYRMtQC4lU4pyfrfZXs7BgIQsKFlBfWM/CgoXMy5+Hyz5131dEREREZDZRbmBRiIQag8h0NPjI/RS89DV+ltzEW/7iZ+R6nNkuSURmkURPD8HHnyD46KOE9+4dO+FwkLPpGvJvvx3/zbdg92fnG/Myi8VC0N8AvSeg72R6fQJ6T0Js6BwfMqyh8MqWQfmydLi0zOrJZL88f7+apskrrQG+v6uJX7zSQTieBMCV7p30wSnUO+l8JVIJOkY6aAo2ZZZTgVM0DDbQGz77EKJ2w86cvDksLlzM0uKlLC2ylgJPweUtXkRERERElBukzZgQ6etf/zpf/epX6erqYt26dWzdupVVq1ad12fVGESmoVPPwH++kw6ziIPv38GtKyqzXZGIzFLxtjaCv/oVgUcfJXrkaOa44Xbjv/FG8m6/Hf8N12PzZH8YMZnFTNOaj6nvBPQeh55j0H0Euo5A6BxzotmcULIIKlZAxRVQeYW19hZc0lKDkTg/22/1TjraEcwcX1Tu58Mb5/Ke1dXkuB2XtIZLbSAywMnBk9YyYK1PDJ5g6BxBX1VO1VioVLyUZcXLKPGWXOaqRURERERmF+UGlhkRIn3/+9/nox/9KN/61rdYu3Yt//AP/8Bjjz3G8ePHycvLe93PqzGITEOxEIkvz8Fhxvn68h/y6fdvznZFIiJET50i+MijBB95hFhjY+a4LSeH3FtuIe+O28nZuBHDqd6TMoUM91iBUvdR6D6cXh+F2PDZry+oSwdKq6BylbWdW3HRyzJNkwOtAb6/q5mfv9Ke6Z2U63Hw/rW1fHhjHXNLZk5vP9M06Q51c2LwBK/2v8rRvqMc7T9Ky1DLWa8v9ZayvGQ5q0pXsbJkJStKVpDjnDlfDxERERGRbBvNDerr63E6nWzZsoUtW7Zku6zLbkaESOvXr2fTpk187WtfAyCRSFBRUcEXvvAFPvGJT7zu5xUiiUxP/f/yFor69vBV76f4zJ9/MdvliIhkmKZJ9OhRAo88QvDRX5Ho6MicsxcWkrv5VvLvuAPv2rUYNlsWKxU5B9OEQAt0HYbOg9DxCnQegMHms1+fU5YOlFZB9VqoXnNRg6VAOM7De1r5z52NNPaFADAMuHFRKR++Zi43LCzFZps+Q929EcFYkGP9xzKh0tG+o5wOniZlpiZcZ2CwoGBBJlS6ovQK5ufPx26zZ6lyEREREZHpTbmB5bKESN/73vd47rnn2LNnDwcPHiQWi/Gd73yHe+6555yfeemll7j//vvZsWMH8XiclStX8pnPfIYPfOADE66LxWL4fD5++tOf8va3vz1z/M477yQnJ4eHHnrodetTYxCZnsKPfx7vzq/y4+S1XPtnD1OWq6GiRGTqMVMpwvv3E/zlIwQfe4xkf3/mnKO8nLy3vY28O27Hs3y5AiWZ+kL9VqjUeQA6Dljr3uNwRqABQF61FSZVr7WWyivB8/qjBLyWVMrkmRM9PLSjke3HejLH55Xk8KGr67hzXQ15s2CexFA8xLGBYxzsOciB3gMc7DlI+0j7pOt8Dh8rSlawqnQVq8tWs6psFXmuN/drICIiIiIyWyg3sFyWEGnu3Lk0NTVRUlJCTk4OTU1Nrxkibdu2jc2bN+PxeLj77rvJzc3l4YcfpqmpiQceeIB77703c217ezvV1dXs3r2b9evXZ47//u//Pg0NDTz++OOvW58ag8g0dWo7/Oe7aDeLeOndz/Gu1fr9KyJTm5lIMLJrF8FHHmXo178mNTQ2/4m9oADfVVeRc/VV+K6+GtfcuRjGzOxZITNMLJTusfQKtO+Htr3Qc/QswZIBpYutQKlqtbUuXwEO1wU9trF3hP/c2cT/vdzCUDQBgM9l5/1ra/jotfOoK55dQ7v1hns50HOAAz0HONh7kIO9BwknwhOuMTBYVLiI1WWrWVu+ltVlqynPKc9SxSIiIiIiU5tyA8tlCZGefPJJFi5cSF1dHV/5ylf47Gc/e84QKZFIsGTJElpbW3nxxRe58sorAQgEAmzYsIHGxkaOHz9OXV0doBBJZFaLhUh+uRa7meAri37A//fBt2W7IhGR85aKxRh57jmCjzzC8PZnSIVCE87bi4rwrFiOd8VKPCtW4F60EGdlJYZdQ1PJNBAdtobAa9uTXvZC4CxD4dnd1hB4tRug9ipryX1jocZINMFP9rXx0I5GTnRb8zgZBmxeVsHHr5/H2rqii/FG004ylaQh0MCBngPs797Pvu59NA9N/jWo9ldnAqU15WuYlzdPAbaIiIiICMoNRjkux0NuueWW87726aefpqGhgd/+7d/OBEgA+fn53Hfffdxzzz089NBDfO5znwOgpKQEu91OV1fXhPt0d3dTUXHxJ/gVkSnE5WO4+Arye/eSPP0CoBBJRKYPm8tF7s03k3vzzZjxOOGDhwjtepGRnS8S3rePZH8/I88+x8izz2U+YzidOOfMwVVXh7OiHHtRMY6SYuxFRdi8Pgy3C5vbjeF2YyaTkEhgJpOY8QRmIo4ZjWHGopixGKlo1NqPRjHjZ+zHrOtS0RhmPI7N58Oen489Lw97STGu2jm46ubgrKrCcFyWf07KdOP2w9xN1jJquNsKk9r3joVL4QFo3W0tO79hXVdQB3OuHguWypbBa8zrk+N28FtX1/GbV83hhZN9fPu5UzxzvIfHDnfy2OFOVs8p4Heunc/m5eU47LNnyEi7zc6iwkUsKlzEnYvuBKAn1MO+7n3s7d7L3q69HBs4RttwG23Dbfy84ecAFLoLWVO+hvUV61lfsZ76gnpsxuz5uomIiIiIyERT7n/927dvB+DWW2+ddG7z5s0APPPMM5ljLpeL1atX89RTT2XmREokEmzfvp0vfOELZ31GNBolGo1m9ofGDSUjItOLd8Em6N3LvPBhmvtCzCn2ZbskEZE3zHA68a1ZjW/Nako++UlS0SjRY8cIHzxI5NBhIocOEWtsxIzHiTU0EGtoyHbJgFW3e+lSvFdcgXfVFeRccw2O4uJslyVTlb8MFt9mLQCmCf2noPUlaNkFLbutYfEGm6zlwP9a17n8ULMu3VNpA1SvA2/BpNsbhsG1C0u4dmEJx7uGePC5U/x0Xzv7mgfZ8v291BR6+eimeXxgfS1+95T7b9BlUeor5da5t3LrXOv/WsOxYV7peSUTKh3sPchAdICnmp/iqeanACtUWlexjqsqrmJ95Xr1VBIRERERmWWm3P+eTpw4AcDChQsnnauoqMDv92euGfXHf/zHfOxjH2Pt2rWsWbOGBx54AIfDwQc/+MGzPuPLX/4yn//85y9+8SJy2bnmXg27/oU1thO80NDLnOI52S5JRORNs7ndVjBzxRWZY2YySbyjk1hjI7HmJhI9PST7+kn09ZHs7ycVjWBGopmeRNjtGHa71VPI4cBwODDcbmwuF4bbjZFe29wua9vlTh93Wr2ZRvcdDlKhEMlAgGQwQKKrm3hLM7HmFsxolMiBA0QOHGDge1adnpUr8V9/PXl33I57/vwsfQVlWjAMKF5gLavuto5FAlYPpZbd6WDpJYgNWfMgnto++kEoWwp111jLnGsgr3LCrReV5/L3d67iTzYv5ns7m/ivF5toHQjzN788wj89eZwPXjWHe66ZS2W+93K+8ZTjd/nZVL2JTdVWj7F4Ms7hvsO83PUyuzt2s79nPwPRAX7d9Gt+3fRrAEq8JayvWM+Gig1sqNhAbW6tQiURERERkRnsssyJNN7rzYl066238utf/5oTJ05QX18/6Xx1dTXDw8MEAoEJx7/+9a/zj//4j3R1dbFu3Tr+9V//lVWrVp21hjN7IrW1tbFs2bJZP7ahyLQ03A0PLCRlGvz5wl/wD791XbYrEhGZFcxUinhrK+EDBwkfeIXQyy8TPXJ0wjWeFSvIf+c7yX/XO7Hn52epUpnWUknoeXWsp1LLLqv30pmK5qdDpU3WuqDOCqnSwrEkD+9t5T+eP82p3hEAnHaD96yu5vduWMCCUv/leqNpJZ6Mc7D3ILs7d/NS50vs795PLBWbcE25r5wNFRtYX7GeqyuvptJfeY67iYiIiIhML5oTyTJjQqQ3Q41BZHqL/OMVeIaa2GL7S77xV3+in4YVEcmSeHc3I889z9ATTzD8/POQTAJgeL0UvOfdFP7Wh3DPn5flKmXaG+6B5p3W0vQCdB4EMzXxmrxqmLNxLFgqXQyGQSpl8tSr3Xz7uVPsPt0PWFnTbcsr+OSNC7iipuDyv880Ek1GOdBzgN2du9ndsZsDvQdIpBITrpmbN5erK69mY9VGNlRswO9SQCciIiIi05NyA8uUG84uP/1TqucKiYLBIIWFhZezJBGZ4pxzr4aDTSyMHeF41zCLK3KzXZKIyKzkLCuj4H3vpeB97yXR10fwkUcZ/NGPiB4/zsD3/4eB7/8P/ptuomTLFrwrlme7XJmu/KWw7J3WAtYQeC27rUCpaQe07YVgGxz6kbUA+IphzkZsdZt4a901vPXjG9jTEuSb2xt48mgXvzrUya8OdXJtfQm/f+MCNi4o1g+lnIXb7mZ9xXrWV6xny5VbCCfC7O/ez0udL7GrcxeHeg/RGGykMdjID479ALthZ2XJSjZWbWRj1UZWlKzAaXNm+zVEREREROQNmHIh0uhcSCdOnGDt2rUTznV2djI8PMyGDRuyUZqITFH2ORvg4P+yxjjB7tN9CpFERKYAR3ExRR/+EIUf+i1Cu3bR/5//xfC2bZnF/5a3UPqpLXiWLct2qTLdefJh4VutBSAWgraXrUCp6QVrXqVQH7z6S2sBcOezdu61PLj4OhrXr+VfDjj46YFOnj/Zy/Mne1lVW8Anb1jArcvKsdkUJp2L1+HNBEQAwViQlzpeYmfHTl7seJGmYBP7e/azv2c/33zlm+Q4c1hfvp6rq6yeSvPy5imsExERERGZ4qZciHTDDTfw5S9/mSeeeIK77757wrnHH388c42ISEbtVQCstp3kR6d7+dDGudmtR0REMgzDIOfqq8m5+mqip07T+2/fJPjLRxh++mmGn36a3M2bKfvMH+Oqq8t2qTJTuHww73prAUjEoGP/WE+l5hchGoBjj8CxR5gL/KOvhL9dvpGnIov5xukqXmkx+cT39rCgNIdP3LCAd11Zjcthy+JLTQ95rjxurruZm+tuBqB9uJ2d7TvZ2bGTXR27GIwOsr11O9tbtwPWfEobqzaysXIjV1VeRbG3OIvVi4iIiIic3U033YTT6WTLli1s2bIl2+VcdlNuTqREIsHixYtpa2vjxRdf5MorrwSs4e02bNhAY2Mjx44dY+7cuW+6lq1bt7J161ZisRgNDQ2zfmxDkWkrlST5pVrsiRF+0/GPfO8vPqafahURmcKip07R+81/I/jLX4JpgtNJ4W/cTcknP4lDwxbLpZZMQOcrcPpZa2l+EeKhCZcMOUvZHl/Cs/Gl7Egux8yv5ZM31fOBdTW4HfYsFT69pcwUR/uP8mL7i+zs2Mm+rn3EUrEJ1ywrXsamqk1cV3MdK0tW4rBNuZ95FBEREZFZRHMiWS5LiPTggw/y/PPPA3Dw4EH27t3Lpk2bqK+vB+Daa6/ld37ndzLXb9u2jc2bN+PxeLj77rvJzc3l4YcfpqmpiQceeIB77733otanxiAy/SUfeif208/wF/GP8nuf+SJzin3ZLklERF5H5Nhxuh94gJHnngPAlptLySc+QeFv/SY2tzvL1cmskYhB2x4rUGp8Dlp2QXJiuNGcKmVHajlH3FeyfNMdvOu6tXicCpPejHAizL6ufezs2MnO9p0cGzg24XyuK5eNlRu5tvparqm6hvKc8ixVKiIiIiKzlXIDy2UJke655x4eeuihc57/yEc+wne/+90Jx3bv3s3999/Pjh07iMfjrFy5ks985jPcddddF70+NQaRGeDpL8Kzf88PEzdgvHsr719Xm+2KRETkPA2/8ALdf/8PRI9Z30R2VlVR+sd/RN4dd2DYNISYXGbxsBUknX4OTj+L2bYHw0xOuOQUNYRqrqN+4zvw1N8Abn+Wip05esO97GjfwfOtz7OjYweBaGDC+UWFi9hUvYlrq65lddlqnHZnlioVERERkdlCuYHlsg9nNxWpMYjMAMd+Bf9zN0dTtXzniv/m7+9cle2KRETkDTCTSQI/+zk9X/86ia4uANzLllJ27734N23KcnUyq0WHoGkniVPPEDzyFAXBV7Ex9l+opOGAmg3Y698CC26CqtVgUy+lNyOZSnK47zDPtz3PC20vcLD3IOa4r7nP4eOqyqu4tvpaNlVvotpfncVqRURERGSmUm5gUYiEGoPIjBDsgK8uIWka3J7zAx7/s9uyXZGIiFyAVDhM/3/+F33f/jap4WEAcq65hrI/uRfPsmVZrk4E4sN97H76Z/QeeIzVsb3MsfVMvMBTAPOutwKl+TdB0bys1DmTDEYG2dG+gxfaX+D5tufpj/RPOD8vfx6bqjZxbfW1rKtYh9uu4TBFRERE5M1TbmBRiIQag8hMkXpgMbbhTt4XvZ9vfvb3KcvzZLskERG5QImBAXq/+U0G/ucHEI8DkPfOd1D6B3+Aq1ZDlkr2xZMpfra/nZ889RxzA7u51naQTfbD5BGaeGHhXFjwFitQmnc9eAuyUe6MkTJTvNr/Ki+0WYHSKz2vkBw33KDX4eWqyqu4oeYGrqu+TnMpiYiIiMgFU25gUYiEGoPIjPE/vwHHHuXz8Q+x9q77ePsVVdmuSERE3qRYSws9X/s6wUcesQ7Y7eTdfjvFH/8dPIsWZbc4ESCRTPHLAx38y9MnaOwJcoVxipvdh3lP3gmqhg9ipBJjFxs2qFpjhUr1N0P1OrA7slf8DBCMBdnVsYsX2l7gubbn6A51Tzi/tGgp19Vcxw01N7CiZAU2Q/OsiYiIiMj5UW5gmdUh0tatW9m6dSuxWIyGhoZZ3xhEpr1n/h62fZEfJ6/llXV/x+fftSLbFYmIyEUSPnSYnq9/nZHnnssc87/lLRR9+MP4rtqAYRhZrE4EkimTRw928M9PneBEtzUUY7U3wV+t6OMW91Ecjduh9/jED3nyrR5KC98K9bdAbsXlL3wGMU2TYwPHeLb1WZ5pfYaDPRPnUiryFHFt9bXcUHMD11Rdg9/lz2K1IiIiIjLVKUSyzOoQaZQag8gMceLX8N93cjJVxaeK/p3H/uj6bFckIiIXWfjwYfq+9W2GnngC0v+Mdc2dS8EHPkD+e96No7AwyxXKbJdMmfzyQDtfe/IEp3tHACjxu/n9GxfwwaU2PM3Pwcmn4NQ2CA9M/HDFSitMqn8r1G4AuzMLbzBz9Ef6eb7teZ5tfZYX2l5gOD6cOecwHKwtX8v1Nddzfc31zM2fm71CRURERGRKUm5gUYiEGoPIjDHcAw/UkzINVsW+zY77302uR998ERGZiaKnTtP/nw8R/PkvSIXSc9DY7eRcdRW5mzeTe8vNOIqLs1ukzGqJZIqf7Gvj60+doHUgDEBFnodPvaWeD6yrxWUzoW2P9UMwJ5+E9n0wrtcM7jyYf+NYL6U8DdP7ZsRTcfZ17cv0UmoMNk44X5dXx3XV13FD7Q2sLVuLUwGeiIiIyKyn3MCiEAk1BpEZ5Z9WQKCFu6J/xR9+7B421ZdkuyIREbmEksMjBB95hMEf/pDI4cMTzrkXL8Z31QZ869fjWboUZ3X1mxr2zjRNzGiU1NAQyaFhDJcLe34eNr9fw+nJOcUSKf5vTwvfePokHYEIADWFXv7wLQt575pqHPb0HD3DPdDwNJz8tdVTKdw/8UZly2FhupfSnKvVS+lNag42ZwKll7teJjFu7qocZw7XVF3DTbU3cX3N9eS787NYqYiIiIhki3IDi0Ik1BhEZpT//S04+gu+EP9N8t/yx/zBzQuzXZGIiFwmsaYmgo8/wdBjjxE5cmTSeVtODq7583GUl+EsK8OWn4/hdGI4nZBMkQqHSYVCpEIjpIaGSQ0PkQwOpUMja23G45PuaziduObNw11fj3fdWvzXXIOzrk7BkkwQiSf5we5mtm5voGcoCsC8khw+ffNC3rGqCrttXHtJJa2eSSeftHoqte1hQi8lVy7MvwEWbYaFt2oupTdpJD7CzvadPNv6LM+2PktfpC9zzm7YWVu+lptqb+KmOTdR7a/OYqUiIiIicjkpN7AoREKNQWRGefYBePpv+VnyGn624G/4j3vWZ7siERHJgkRfH6Hduxl5cRfhV14h1tBw1gDoghgGtpwczEQCMxI56yXuhfXkv/d95L/rnTiKii7Oc2VGCMeS/NeLjfzbM6foH4kBsKjcz59uXsItS8vOHj6O9E3spRTqnXi+ag0sug0W3wYVV4ACzAuWMlMc6TvC081Ps61lGycHT044v6hwUSZQWla0TGGxiIiIyAym3MCiEAk1BpEZ5fgT8P33cyxVw92Of2LvX71V/7kXERHMeJxYYyOxpiYSPT3Eu7pIDQ1jxuNWuGS3YfP6sPl82LxebHm52HNzsfn92PPysPlzsef6seXmYsvJwbBZQ5ClIhESPT1ET54k+uqrjOx8kdC+fZAOrAyPh8K77qLoYx/FWVaWzS+BTDHD0QQP7Wjk359pIBixhlJbW1fIn9+2hA3zXiN4TKWgY7/VQ+n4Y9C+d+L53Cqrh9Ki22De9eDyXbqXmAVahlrY1ryNbS3b2Nu9l5SZypwr85VZgVLtTWyo2KB5lERERERmmNHcoL6+HqfTyZYtW9iyZUu2y7rsFCKhEElkRgm0wT8tI2HaWBb9Do//yVuZV5KT7apERGQWSQaDBB/9lTVPU3pYPcPrpeSTn6T4no9guFxZrlCmkkA4zr8/08B/vHCaSNwKKG5aXMqf3baEpZV5r3+DoS448Tgcf9zqrRQPjZ1zeMeGvVt0G+RVXaK3mB0GI4M82/Ys21u283zb84QT4cy5HGcO11Zfy021N3FdzXXkuc7j105EREREpjTlBpZZHSJt3bqVrVu3EovFaGhomPWNQWRGME34u7kQGeSO6Jf46J3v4n1r9ftaREQuP9M0GXn+BXq/8Q3Cr7wCgGvePCq/9EV8q1dnuTqZarqCEf75qRP84KUWkikTw4B3rariM29dzJzi8+xNFI9A4/Nw/FdWqBRomXi+4gpY/DYrVKpcDekedfLGRZNRdnXsYlvLNra3bKc3PDbEoMNwsLbCmkfp5jk3U5GjOatEREREpiOFSJZZHSKNUmMQmWG++3ZofI57Y5/Au+G3+MK7V2a7IhERmcVM0yT485/T9Q8PkOztBbudkk98gpJPfgLD4ch2eTLFnO4d4R+fOMYvD3QA4LQbfHDDHD71loWU5rrP/0amCd1H4Fg6UGp9CRj3Xz9/OSy8FZa8HebfCE7PRX2P2SRlpjjUe4htLdvY1ryNhkDDhPMrS1ZyS90t3DLnFubkzclSlSIiIiLyRik3sChEQo1BZMb51Z/Drn/jwcTb+HHpFh799HXZrkhERIRkMEjn336B4C9+AYB39Wqqv/Y1nOWaK0kmO9QW4O8ee5XnTlg9XHwuO79z7Tw+fv18cj0XMPfOcA+c/LUVKjU8DbHhsXPOHFh4ixUoLXwreAsv0lvMTs3BZra1bOOp5qfY370fc1x4t6hwEbfMuYVb6m6hvqBec3eKiIiITGHKDSwKkVBjEJlx9v4X/PxTPJ9czocTf8Ghz2/G59JPeouIyNQQ+OUjdH7+86SGhrCXllDz9X/Gt0bD28nZ7TjZy989foxXWgYBKPQ52XJTPR/aWIfbYb+wmyZi0PS8FSi9+ggE28bO2Rww91orUFp8O+RXv/mXmMV6Qj1sa9nGk01PsrtzN0kzmTlXl1eXCZSWFy9XoCQiIiIyxSg3sChEQo1BZMZp2wvfvokB8lgd+Sb/8/GNbFxQnO2qREREMmJNTbR+6lNET5wEp5OKv/gLCu++K9tlyRRlmiaPH+7k7x8/xqmeEQBqi7z8+W1LuGNl5ZsLH0wT2vdZYdKrj0DP0Ynnq1bDkjtgyTugdDEo6LhggWiA7S3bebLpSXa07yCWimXOVeRUcMucW7h5zs2sLluN3XaBAaGIiIiIXDTKDSwKkVBjEJlx4mH4UhWYKdZHtvLbt13N799Yn+2qREREJkiNjNB+318w9PjjABTcfRcV992H4XJluTKZqhLJFA/vbeUfnzhO91AUgCtrC/iLO5ayfm7RxXlIX8NYoNSyiwnzKBUtSAdKb4ea9WCzXZxnzkIj8RGea32OJ5uf5NnWZwknwplzRZ4i3jLnLbx1zltZX7kep+0Chi8UERERkTdNuYFFIRJqDCIz0jfWQ+9xPhz7c3KWbeabv7U22xWJiIhMYpomfd/6Nj1f+xqYJr5166j+56/jKLpIgYDMSKFYggefO82/PdNAKGYNj7Z5eTl/ftsS5pf6L96DhrvHhrw7tQ2SYz1nyCmDxW+Dpe+EedeDQ+HnhYokIuxs38mTzU+yvWU7wVgwcy7PlcfNc25m89zNbKjcoEBJRERE5DJSbmBRiIQag8iM9H/3wOGf8OX4b/DL3A/wwv/3lmxXJCIick5D27bR/id/SmpkBEdVJbVbt+JZujTbZckU1z0U4WtPnuAHu5tJmeCwGfzmVXP4w5sXUux3X9yHRYfg5JNWoHT8CYgGxs558mHxHbDsXbDgJnBc5GfPIvFUnJc6X+LJpid5qvkp+iP9mXP57nxumXMLt869lQ0VG3DYNOeniIiIyKWk3MCiEAk1BpEZafvfwfYv8aPk9fxJ/BPs/au3UpSjn5AVEZGpK9rQQMvv/z7xpmYMr5eqL3+JvNtuy3ZZMg2c6BriK796lade7QYg1+3gkzct4KOb5uFxXoK5dRIxaHoejv4Cjv4SRrrHzrnzYNFtVqBUfzM4vRf/+bNEMpVkb/deHm98nF83/XpCoFToLuSWulvYPHcz68rXaQ4lERERkUtAuYFlVodIW7duZevWrcRiMRoaGmZ9YxCZUQ7/BP7vHo7aFvK20Od56KMbuGFRabarEhEReU3JQIC2z9zLyAsvAFD8yU9Q+gd/gKG5Z+Q87DjZyxcfPcrhdms4tKp8D3+yeTHvvrIam824NA9NJaH5RTjyMzj6cxjqGDvnzIFFm61AaeFbwZVzaWqYBZKpJC93vczjjY/zZNOTDEQHMueKPEW8te6tbJ67mTVlaxQoiYiIiFwkCpEsszpEGqXGIDIDdR2Bb24kYvOxJPRt/uTWxXzqLQuzXZWIiMjrMhMJuv/xq/R/5zsA+N/yFqr+/u+w+y/iXDcyY6VSJj97pY1/eOwY7YEIAMur8virty/j6vnFl/rh0PqSFSgd+RkEW8fOObxWkLT83bDwVnDnXtpaZrBEKsFLnS9ZgVLzkwTGDS1Y4i3JBEqry1ZjMxRAi4iIiFwo5QYWhUioMYjMSIkofLECzBRXRb7BqmVL+daH12W7KhERkfM2+NOf0vm5+zFjMVx1dVR+5cv4Vq/OdlkyTUTiSf7jhdN8c1sDQ9EEAG9bUcF9ty+ltsh36QswTWjbC0d+ai2DzWPnHB6ov8XqobToNvDkXfp6Zqh4Ks7ujt2ZQGkoNpQ5V+Yt49a5t3L7vNtZUbICw7hEvdFEREREZijlBhaFSKgxiMxY/7wG+hv4zdhnOZ27nh2fvTnbFYmIiLwh4QMHaP2DPyTR1QU2G8Uf+ygln/oUNrc726XJNNE3HOWfnjzO93c1kzLBZbfxsevmseWmevxux+UpwjSh45V0D6WfQv+psXMOj9UzacX7rLXrMgRcM1Q8GWdnx04eb3ycbc3bGIqPBUq1ubXcPu92bp9/O/Pz52exShEREZHpQ7mBRSESagwiM9b//AYce5S/jn+E7yY38/Jf3kKJX990ExGR6SUZCND5xS8S/PkvAHDW1lL2Z39K7i23qGeBnLdXO4P87S+P8MLJPgBKc9386ebF3Lmm5tLNl3Q2pgldh61A6fCPoe/k2DmXHxbfbgVKC94CDtflq2uGiSVj7Gjfwa9O/4ptLdsIJ8KZc0uLlnLH/Du4be5tlOeUZ7FKEZlJzGSSRHc3ie5uUqEQhsOBLS8fV001tpw3PydeKhwmcvRVYqdPkejvJzU8gs3jxpabh2tOLa65c3HW1LypeSRN0yTZ30+yv59UKAQOB/b8ApzlZRhO55t+BxGZfpQbWBQiocYgMmP9+n544Wv8zPk2Pj30Ib7z2+u5aXFZtqsSERG5IENPPknn3/wtie5uAHwbNlD6B5/Ct359liuT6cI0TX59pIsvPnqUpr4QACuq87j/HctZP7coGwVB50E49DAc+jEExg1558mHpe+wAqW514P9MvWamoFC8RDbWrbx6OlH2dG2g4RpDW9oYLCuYh13zLuDW+puId+dn+VKReRSSPT1ETt1ilhrG2bECpTthUU4yspw1y/AnndhQ4qmIhHCBw4Q3rOH0J69hPftIzUyctZr7SUluOfNw7VgPu75C6z1ggU4ysvP+gMxqWiU6KuvEj50iMihw0QOHSLa0GDNvfcaDK8Xd3097kULcS9ciGfRItyLFuEoKZl0bTIQIHr8OJHjx4meOEH05EliJxtIDg5OvrHTiXvePNxLFuNdvhzP8uV4li4973BsNJyKNTUTb2slFQ6DCfbCApwVFbgXLsTm9Z7XvUTk8lJuYFGIhBqDyIy1//vw009ywnclb+3/Mz7z1kX84c0Ls12ViIjIBUuNjND74IP0/7//wIzFAPCuW0vJ7/4uOdde+6Z++nY8M5kkcuQIsaZmkv19YLdjLyjAs2QJrnnzLtpzJDuiiSQP7WjkX546mZkv6Y4rKvns25ZQU5il4eRME1pftgKlwz+B4c6xc74SWP5uK1CqvRrU/i7YQGSAXzf9mkdOPcLe7r2Z4w6bg+uqr+P2+bdzQ80NeB36ZqbIdJWKxQjv3cvwc88x8vwLRI8de83rndXVuJcswb2wHvfChbjrF+KqrcHwejMBTyoSId7SQvTUacL79xPet4/w4cMQj59xMyeO0hLsOX7MRMLqMRQInPPZNp8P1/z5OEpLMRMJUiMjxDs7SHR1QzI56Xp7aQmexUtwlJRgy8nBjMVIDg4Qa2wi1tSU+bfRpM8VFVlBks2GGQ6T6O09Z+CFYWAvKMDm82HG4yQHB89+X8PANW9eJlCyFxZan4lFSQ0PE2/vINbcTKy5mXhTk9Wz6Vzsdtz19XhXrcK7ejW+1VfirKt7Qz3OzVSK5MAAGAaGy43d/+Z7f4nIWG5QX1+P0+lky5YtbNmyJdtlXXYKkVCIJDJjte6BB99CyFXMsuC/cMvSch78yLpsVyUiIvKmxdva6P32twk8/GPM9DdwnLW1FLz//RS85904Skvf8D1jrW2M7HiBkRd2MLJzJ6lg8KzXOSoqyH/H2yn6yEfO+pO9Mn30Dkf5xyeO8YOXWjBNcDts/O718/nEDQvIuVzzJZ1NKglNO6xA6cjPINw/di63Cla811qq1oCGdLxg7cPt/Or0r3j09KMcHzieOe5z+Lil7hZun3c7V1VehcOmXmAiF4uZSDCyaxehl14icugw8bY2q+eL3Y4tx4eruhrX3Lm4FizAXb8Q98J6HEWv3VPUNE1ip09bf3+/8AIju3djjg8sDANnTQ2u2hpsOX7AJNHXT7yjg0RHx7lv7HRi83gwI5HMvzXO5CgtxbtuLb616/CtXYN70SIMu33CNcmhISvkOdVAtOEU0VMNxBpOEWtuPmtQNMpeVIRn5Qq8y1fgWbECz/LlOMvPPbKImUgQa24mevwE0ePHiZ44TvT4Ces55/jWp7OqCne6t5J7YT3uBQtwzZ+PzeMZu69pkmhvJ3L8OJGjR4kcPkLk8GESnZ1nvec5GQbOykqctbWZHkzJvj5ibW0ke3snv39hId7Vq/GuXIGjrAx7YaFVTyxGoqeXRHcX8c4u4h3tJDo6iXd3Twj1bPn5uOfOtUKpDRvwrVt7wb3ORGYz5QYWhUioMYjMWJEgfKUWgCsi38KbV8yu+27JclEiIiIXT7yri/7/+A6DP/4xqaEh66Bh4F2zhtybb8Z31QY8ixaddRz/xMAAod0vEdq1i5EdO4g1Nk44b8vNxbNkCfaSYkiZJLq7ibz6KmbYGgrH8Hop+eQnKf7ob2M49E3m6exwe4C/+cURdp22wpryPDd/ftsS3rO6OvvzbiXjcPoZa7i7o7+A6Lhws7gerrgLVr4fiuZlr8YZ4MTACR49/SiPnnqU9pH2zPESbwlvn/923rHgHSwqXJTFCkWmLzORIPTSSwR/9RhDTzxx9uHSXoO9uNgaom3BAuwF+RguF6lQmOTgILHGRqLHj0+6p72kBP+mTeRcdx05m67BkQ4gzpQMBIgcO0b01VeJnjhJ9ORJoidOkBoennStLTcXV10dnuXL8a1ZjXfNGmsOogv8e8KMxYi1tBBtaCAZCGA4nNg8bpyVlTiqqnCUll6Uv4NS4TDRhlOkggFM08RwOnGUluIoLXtTvXUSvb1EDh8mfPiw9TUbGiYVDmNzu7Dl5OAoK8c1pxZn7Rxcc+tw1tRgc519rr94V5c1NOD+/YT37Sdy6NA5e1VdMJsNz7Jl5Fx9Fb6rrsK3Zs1rDseXikZJ9vWR6OtP90p3YM/PxzVvLna//+LWJjKFKTewKERCjUFkRvvHpTDUzntjf83e1CJe/stbKPG7s12ViIjIRZUKhwn+6jEGf/hDwvv3TzhnuN04a2twFJdg2G2kQmFira2Tf+rVbse7ahU5m67Bv2kTnhUrJoVDqWiU4e3P0Pf//h+RAwcA8K5eTfXXv4azTPMOTmemafL44U6++OhRWvqtoHBdXSGff9dylldNkbly4hFoeAoO/giO/QoS4bFztVfBFR+A5e8FXxbmd5ohTNPklZ5XeOTUIzzW+BiD0cHMuSVFS3jngnfytnlvo8SrXogys0RPnmRk54tEG06SDAQw43HsBQU4Sktx19fjWboM19y68x7O1UwmCe3ZQ/BXv2LoiV+T7OvLnLMXFeG//nq8V67CNW8+9oICwCQVDBJraSV2+jTRhgaiJ04Qb209Zy+a8QyXC+/aNeRsvAb/tZtwL1lywUPPmqZJaiREaihIKhzB5vVgy8nBlpub/R8smCVSsRjRI0cI7dtP9NgxEn19JAMBMMBwOHEUF+MoL8dZUY6josLq4VRRgaOsDOx2zFCIWGsr0ePHCb30MqFdu4g1NU16jqO8HGdlJYbXg2GzkxwaIjk4SLK//6xB4ihn3Rxyrt6I/4YbyLn6Kmy+LA2FK3IZKDewzOoQaevWrWzdupVYLEZDQ8OsbwwiM9J33w6Nz/El9x/xrcAG/utjG7hu4Rsf4kdERGS6iHd0MPT00wxv2074wIFzDksH4F640Bri5OqryLn6auy5uef1DNM0Cfz0Z3R98YukhodxVFUy51vfwl1ff7FeQ7IkEk/y/54/zTeePkk4nsRmwG9dXcdn3rqIAt/Zf4I6K6JD8OojcOB/4dR2MNOTrdscsPBWK1BadBs4NbfPhYon4zzX9hw/b/g5z7Q+QyJlzZ9lN+xcW30t71jwDm6svRG3XT+gJdNTcniY4KOPMvjww0ReOfC61xteL+5FC/EsXoJ7yWKcVVXY8/Mx7HZSkQiJ7h5iTY1EDh4i9PLLE74Jby8oIPetbyXvbbfh27DhvHvwpkIhawi4kyeJnTpFamSEVCyKzePFXlAwNhxb/YIJQ7CJnCne1WX1Pn9xF6EXXyTe3v76H3I6cRQVYS8ugmSKRH8fyZ6JP4RkuFz41q/Hf/115Fx3Pa55c88rbBz9drSCSZnqFCJZZnWINEqNQWQG+/kfwN7/5JGiD7Ol/Tbuu30Jv3v9gmxXJSIiclmYpkm8uZl4ezuJ3j7AxHC5cVZV4Zpbd96h0bnEmppo+d3fI9bUhL24mLrv/RfueRpWbCZoHwzzpUeP8ssD1nwZRTku/mzzYj6wrhabbYp9w2eo05o/6cD/QscrY8fdebDsndaQd3XXwgX+VL7AYGSQXzX+il80/IKDvQczx3Ndudw29zbeueCdrCpdpW8GypRnmibhPXsY/NHDBB9/PDNEKw4HORs34lm6FEdJCYbTQTIQIN7Wbg33duwYZjT6hp5ly8sj95ZbyHvb28i5+qqzDi0rki2J/n7ira3EOzoxYzHMZAJ7bi72ggLshYU4ioux5eVN+nM9MTBA5MABhp95luFnniHe1jbhvC0vD3d9PY7yMmu4vJRJKhQiNTxMMhgkFQySDAZJDg1hAI6yMtz19XjXrMF/4w3WvFr6u0SmEOUGFoVIqDGIzGjP/xM8+dccK3sbm5s/xHtWV/NPd12Z7apERERmjMTAAM0f/RjRo0dxVFQw9/v/jbOqKttlyUWyo6GX+392mBPd1k/Ur6rJ5/PvWsGVtQXZLexcul+Fgz+EAz+EQMvY8bxqWHmnFSiVL89efTPAqcFT/OLUL/hFwy/o+v/Zu+/oqMqtj+PfMz09IZUeOoQuXbqIICigICAoYkER7L289nrtDQVEAQtNQMVCsYACIlVq6CV0COmZlKnvH2cmEAghlZNk9mctVoaZM8/ZQ3K9mfmdvZ+sU3n31w2uy/X1r+f6BtdTI1D+GyBKzpWTQ9aGjeTu2YP95AlwutAFBGCsWRNzo0ZYmsehMxevA8529BjpP/9M2vff5xvrZapfn9ChQwkZPAhDxMXHNLodDmwJCeTu3k3Ort3k7tmD48wZnKmpuJ1OdH5+6MPCMNWtq3b5duiApVlTFL2+xP8OQlR0brcb24EDZP69EuvKv8lavwG33V6qNU316hHUvx/BfftibtZMAiWhOckNVBIiIT8MQlRp8T/CvDGkVWtN6+NP0jQmiCUP9dC6KiGEEKJKcSQnk3DrGGz792OOa0bst9+i85MxYlWF3eniqzUJfPDbHjJyHSgKjGhfm8f7NSG8ou416XLBkX/V7qQd30NO2tnHoltAm1HQcjgEypjjknK6nKw/tZ5F+xbx++HfyT5nj6qOMR25sdGN9KnTB4tBRmyJosk9cICkKVNIX7oMd07OxQ80GrE0bYqlRXP8mjfHEheHsWbNvK4Jt9OJ/cRJbAcPkrVpI1lr/s23X6DO35+gAdcSeuNQ/Nq2kQ+phSgjrtxcdU+vfftxpqTgyswAvQGdvz+6gAD0IcHog4LQBYegDw7C7XThOHWSnO3bsa5dh3XVKtw2W956xlq1CLr6avzaXYFfXByGqKh8HX1uhwNXdjau7Gzc2dm4bTYM0dHog4O1ePmiipLcQCUhEvLDIESVdnIbTO6G01KNBqmfYNAp7Hi5H2aDXBEmhBBClCX78eMcHHYTzuRkggdcS41335UP5qqY0xk5vLl4Fws3qaNrgi0GHuvXhFEd62DQV+BRcY5c2LtMDZT2LAWn5wMq7/5JbUarXw0VaM+nSibLnsVvCb/x0/6fWHdyHW7UjxmCTEEMrDeQGxvdSLPwZhpXKSqq3L17OfPZZNIXLwbPR1SGmBj82rbBVLMmGAy40jOwHT1Czo54nElJBS9kMICiQEGdEIqCf6dOhFx/PcH9+6ljtoQQFYoz00rm8uVkLFtG5sqVBYbJusBAcLlw2e0F/28dMERGEnBlFwJ7X0XgVb3RmeT/30XJSW6gkhAJ+WEQokrLzYA31P9dd2MGR3NM/Hx/N1rUDNG4MCGEEKLqyVq/noTb7wCHg+jn/o9qo0drXZIoBxsOJfP8jzuIP5EOQLPqwbwyuDntY6tpXFkRZKfA9oWweRYc23D2fv9wtTOpzSio3kq7+qqA45nH+XHfj3y/73tOWE/k3d+sWjNuaHQDA+oNIMQsv4sLyNm9mzOffkbG0qV59wX26UPE3eOwtGpV4IUIbrdb3ado6xayd+wgZ0c8uTt34kxLy3ecYjRirFULS4sWBHTqSED37hijo8v9NQkhyoYrK4vMlavIXPk3Odt3kLtvHzgcBR+s06kd8AYDrvP+W6APCSF40CBChgzGEhcnFziJYpPcQCUhEvLDIESV93YjsJ7m6YiPmX00nLeHteKm9rW1rkoIIYSokpK/+opTr7+BYjIRO/87LI0ba12SKAdOl5tZ6w7zztLdpGWrVwIPvaIWzwxoWnFH3J3v9C7YMgu2zIHMs3v7ENNS7U5qeRMEXHyPFFE4p8vJ2hNrWbhvIX8e/hO7S/05MevN9KnThxsb3UiHmA7olArcxSZwJCWRvW0b9iNHcdvt6ENDMdWuhaV5c3T+/iVaM+u//0ia+jmZy5fn3Rd0zTVE3DseS7OSday5cnNxpqQAoJhM6ENCZD8iIaoQt9OJMz0dZ2oqisGg/rFY0Pn7o5hMeeGQM9NKzvbtZK78m/Sff8Fx6uz/v5sbNSR40CACu3fH3Lgxik7+/0dcmuQGKgmRkB8GIaq8L/rBkX+ZX+9lHtvZkDu61uP56+O0rkoIIYSoktxuN0fuuQfr3ysxN2pE7Hfz0FlkT5SqKtlq4+2lu5iz/ghuN4T4GXn62qYMb18bna6SXO3rdMD+P2Hzt7D71/zj7hr394y76wt6Y+HriItKzUnl5wM/s3DfQvam7M27v1ZgLYY0HMLghoOJCYjRsEJxvszVq0ma+jlZ69bljZjLx2DAEheHf4f2BHTsiF+7dugDAy+6nu3oMTL//IO0RT+Rs327eqeiEHxtf8LHj5cLDoQQZc7tdGL95x/Svv+ejN//yLffki4wEGPt2hgiI1B0auDsdjhwOx1gd+B2OFAMBkwNG+DfvgOBvXoW+t84UXVJbqCSEAn5YRCiyvv+Xtgyi21N7uf6LV3oUj+c2Xd31roqIYQQospyJCVxYPAQnGfOEDZ6NDHP/Z/WJYlytulwCs9+v52dnhF3V9QJ5bUbWtKseiXb3DorGbYvUAOl4/+dvd8/AlqNgCtuhSjZ26ek3G43O5J2sHDvQhYfXEymPRMAnaLjyhpXclPjm+hRqwcGnUHjSn2X/dRpTv/vTdJ/XZx3n6lhA8z1G6CYzTiTk8ndvx/HyZP5n6jTYWnWDFO9ehgiI0FRcOdkYzt2jNy9e3EcPzvaEKORkMGDCL/jTsz1612mVyaE8GXO9HTSly4lY+kysjdtwpWVVaznKyYTwQMGUG3sbViaNi2nKkVFJLmBSkIk5IdBiCrvr7dg+WukNB5O261DCPU38t9zfWUWrhBCCFGOMleu5Mi4uwGoPWUygT17alyRKG8Op4uZaxJ4b9lurDYnep3Cnd3q8WCfRgSYK2EocCreM+5uLlhPn72/Vge44jZocSOYArSrr5LLdmTzW8JvLNy7kI2nNubdH+UfxdBGQ7mx0Y3SnXQZuZ1OUr6dReKHH+KyWkGnI2zUKKqNHYupVs0LjrcfO0bWhg1Y168na9167IcPF34CgwG/li0JHjiQ4Gv7YwgPL6dXIoQQhXM7HOQeOID9+HGcSUlnuy0NBhSD0TMuT48rO4ecnTvJXL4c28GDec/379CBsFtuIajPVSiGSvj7jSgWyQ1UEiIhPwxCVHnb5sOCO3HV7kLj/Q/gcLn556mrqBHqp3VlQgghRJV28vXXSfnqa/Th4dRf9KN8aOgjTqRl8/JP8SzernYq1Aix8OKg5lzTvJIGAk477PsD/vsa9iwBl2djb1MQtByqBko12oJcoFRiCekJLNi7gB/3/UhyTjKgdif1rNWT4U2Gc2WNK2XvpHKUvXUrJ158kdz4nQBYWrUi5oXn8WvevMhr2E+eJHvzZuzHjuM4cwYAxWzCWL0Gprp18WvVssR7KAkhhJbcbjc5W7eSPPMr0pcuBacTAENMDCHXX09gr57qPnEyvrlKktxAJSES8sMgRJV3ZB180RdCatOPT9l9KoMvbmtPn2bRWlcmhBBCVGmu3FwODbuJ3L17CezZk1qTP5NOYB+yfNdpnvtxO0dTsgG4ulkULw5qTq2wSvxBcsYptTtp01eQfODs/TEt1TCp5U3gF6pZeZWdzWnjj8N/MG/3PDac2pB3f83AmgxrPIwhDYcQ4RehYYXacuXm4rY70AX4l8l/Sx0pKSR++CGpc+eB240uOJioRx4m9KabUPT6MqhYCCGqFvvJk6TMmUPqvO9wJieffUBRMFavji4wEMVsBp2CouhAp1NvG41YmsUR2KMH/h07oOjkwojKQnIDlYRIyA+DEFVe+gl4rykoeh5uvJTvt5zmif5NmNCrodaVCSGEEFVezu49HLrpJtw2GzEvPE/YzTdrXZK4jLJtTj5Zvpepfx/A7nTjZ9Tz4NWNuLNbPYz6SvwBitsNh1bBppkQvwicuer9BgvEDYF2t0GdLtKdVAoHUg/w3Z7v+HH/j2TYMgAwKAauqnMVw5sMp2NMR58IpbP++4/UOXOxrlmD47Q6VlHx98eveXMCul5J8LXXYqpbt1hruqxWUr77jjOfTcaVlgZAyOBBRD3+OIYI3w3phBCiqFy5uWT8/juZy1dgXb0aZ0pKkZ9rrFuH8DvvJPSGG1CMxnKsUpQFyQ1UPh0iTZo0iUmTJmGz2di/f7/P/zAIUWW5XPBqFLjszOz0My/8lc4NbWvy/og2WlcmhBBC+ITkr77i1OtvoJjN1Fu4AHODBlqXJC6zvacyePaH7aw7qF612zg6kNduaEmH2GoaV1YGspJh6zw1UDodf/b+8EZwxRhoMwoC5IP5kspx5LD00FLm7ZnH1sSteffHBsfmdSeFmEM0rLB82E+c4OTLr5C5fPklj7W0bkXwtdcSdPXVmC7ymYbb5SJn61bSlywl9fvv88Ijc5MmRD/7DAEdO5Zp/UII4SvcbjfOpCRsR47gzs7GlZurXmzicuF2u8HlxpWZSdbGjWQsXaruO4caJkVOnEjwgAGyt1IFJiGSyqdDJC/5YRDCB3zYGlIOsbbXt4xYotC8RjC/PNBd66qEEEIIn+B2uThy9z1YV63C3KwZsbO+RecnexP6GrfbzYJNx3j9150kW20AjO5UhyevbUqwpQpciet2w9ENapi0fSHY1Q+J0BkhbjB0uFO6k0ppd/JuvtvzHT/t/4ksRxYAFr2FgfUHcnPTm2lSrYnGFZaNzNWrOfbgQ7gyM8FoJGTQ9YRcPwhLs6YoRiP248fJ2rCBjN9+x7pmjXrRnIepfn0sTZtiiI5GMRhwpqZiO3qEnO07cGVk5B1nrOO5En7ojfLhpRBCXCbeTtCkqZ/njcMz1a1L+D33EDxwADqzWeMKxfkkN1BJiIT8MAjhE2ZcB4dWcrrvx3T8KRyzQUf8y/3R6+RNvBBCCHE52E+f5uDgIThTUgi8ug+1PvxQ9tzwUalZNt74dRdzNxwBIDrYzEuDWtC/RYzGlZWhnHTYsRA2zoTjm87eH9lMDZNajQBLsHb1VXJWu5VfD/7KnF1z2JOyJ+/+K6KuYFSzUVxV5yqMusoZTKbMm8fJl14GpxNL61bUeO01zA0vPobbkZhI+uIlZPz+O1kbNuQLlM6n8/cnsHdvggcOILBnT/lvsBBCaMRltZL8zbckT5+OMzUVAF1gIIG9e+N/RVtMsbHoQ0JAUXDb7eByofj5YwivJmNHLzPJDVQSIiE/DEL4hB8mwOZvcfX+P5r91oJch4sVj/UiNiJA68qEEEIIn5G1YQOH77gTt81G6E3DiHnxRfkQ04et2Z/EM99v4+AZtWOnX/NoXh7cguhgi8aVlbHjm2HDF7BtPtjV7hmMAdDqJmh/J1RvpWl5lZnb7WbT6U3M2jmLPw7/gdPtBCDKP4rhjYczrPEwwv3CNa6yaNwuF4nvvUfStC8ACB50PdVffRWdyVTkNRwpKeRs3Uru3r04UlJw2+zoQ0IwxkRjad4cc8OGsv+GEEJUIC6rlZQ5c0j+5lscJ04U6TmG6tUJvuYaQkeOwFyvXjlXKCQ3UEmIhPwwCOETlr8Of/0P2o1lwIFhxJ9I5/Mx7ekbF611ZUIIIYRPSV+8mGOPPApuN0F9+1L9jTfQB5bvRR1ut5vMFStI//kXcg8cQGex4N++PWGjbsZYvXq5nlsULsfu5OM/9zLlrwM4XG6CzAaeGtCUmzvUQVfVOsazU2HLHNjwJZzZffb+Wh3UMKn5DWCsYgHaZXTKeop5e+Yxf898knPUEUFGnZH+sf0Z1WwULSJaaFzhxbmyszn+5FNkLFsGQMR99xExcQKKjD4UQgif4Ha5yN60Ces//5C9bTuOkydwpKaiKDp15KhOhys7G2dKytmOU72e0GHDiLz/PulOKkeSG6gkREJ+GITwCZu+hkX3QYM+PGR8jh82H+fxfk2Y2PvioyGEEEIIUT7Slyzh+ONP4LbbMdauTcwLLxDQ9cpy+cDUdvQYJ555hqx16y54TPHzI+a55wi98YYyP68onp0n0nlqwVa2HE0DoGNsNV6/sSUNowI1rqwcuN1waJXanbTzJ3A51Pv9wqDtLdDudghvoG2NlZjNaWPpoaXM3jWbbWe25d3fMqIlNze9mX6x/TDpi97d4+V2OHCmpKBYLOiDgsqsXvup0xydOJGc7dtRjEaqv/4aIddfX2brCyGEqDpcVivWdetInTOXzL/+AtQxeBETJ1Jt9CiUYnSviqLx5gYNGzbEaDQyceJEJk6cqHVZl52ESEiIJIRPOLACvhoMEY2Z1Hw2by/dzZA2NfhgZFutKxNCCCF8UtbGjRx//Ansx48D4N+pExETJxDQsWOZnSN7yxaOTJiIMykJxWIhbORI/Dt3wpWWRsrceWRvUveqiX76KarddluZnVeUjNPlZsY/h3h32W6ybE5Meh33XdWQ8T0bYDLotC6vfGScgv++UvdOSjty9v4GV0HHu6HRNaCTkY8ltS1xG7N3zWbJoSXYXXYAwi3h3Nz0ZoY3GU6YJeySa2Tv2EHSlKlkrlqFO0sdR2ioUZ3ga/oRdstoTKX4DCFz9WpOPP0MjtOn0YeGUuuTj/Fv377E6wkhhPAdWRs2cOqNN8nZsQMAU716RD/zNIHdu2tcWdUiuYFKQiTkh0EIn5C0Hz6+Agx+LBvyH3d/s4m46sH8+qD8n6sQQgihFWdGBmc+mUTKrFnqpsGAf/v2REycgH/nzqXqTMrZuZOEW27FZbVibtaMWh9/lO/DXrfbTeIHH5I0ZQooCrU+nURQ796lfk2i9I6mZPF/P2xnxe5EABpHB/LGja1oV/fSH/hXWi4n7F0G67+Afb8DnrfpYbFqmNRmNPiFalhg5ZaUncSCvQuYu3sup7NOA2DWmxnUYBC3xN1C/ZD6FzzH7XSS+MknJE2eonaPASjK2dsAej0hgwcTce94TLVrF7ke+6lTJH70EWkLFgJgatCA2p99iqlOnZK/SCGEED7H7XKRtnAhp997H2eyOso1oEd3wkbeTEC3rsXaV08UTHIDlYRIyA+DED7BngOvqfsfHb5zGz0mbcNs0BH/cn/0VW3evhBCCFHJ2I8f58znn5M2f0FemOTXpg0RE+4loHv3YodJtqPHOHTzSJyJZ/Bv357aUyajC7hw3yW3283Jl14idc5c9OHh1P/5JwxhVTioqETcbjeLthzn5Z/iSbLaUBS4rUssT/Rvgr/JoHV55Sv5oLpv0qavICdVvc8YAK1HQqd7ILKJpuVVZnaXnWWHljFzx0x2Ju/Mu79HrR6MiRtDx5iOKIqC2+Hg+BNPkP7rYgCCBwyg2p13YGnWDFdWFllr15IyazbW1avVBQwGQm+4gYh7x2OsUaPAc7udTrI3byb1u/mkL16MOzcXgLBRo4h6/DF0fn7l++KFEEJUWc70dM5M+pTkb78FhzomVzEaMTdujCEiAl1wMDqLBcXPgs7ihz4khIBu3bA0aaxx5RWf5AYqCZGQHwYhfMbbjcB6Gue4v4j77AS5DhfLH+tFvYjy3cxbCCGEEEVjP3mSpC++JHXevLwPWP07d6bGG69jrF69SGs4UlJIGDUa28GDmBs3pu43X6MPDr7o8S6bjUNDh5K7dx8hN9xAjTdeL5PXIspGitXGq7/sZMGmowDUrubH/4a24soGPrCBtC0Lts2DtVPgdPzZ++v3VsMkGXVXYm63m42nNvJV/FesOLICt6fzq0lYE8bE3Urbz/4m85dfwWikRiF7FGVv2ULix59gXbVKvcNoJKhXL/w7dcIQHQV2O7ajx8jdtRPrP2twpqbmPdfviiuIeuwx/K+Q8dpCCCHKRu6Bg6TOnUvar7/gTDxzyeP9O3Qg6onH8WvZ8jJUVzlJbqCSEAn5YRDCZ0zuDie3wqjvGLjEjx3H05lyazv6NY/RujIhhBBCnMORmEjSl9NJmT0bd04OuuBgqr/0IsHXXlvo81xZWRy+/Q6yt2zBUKM6sbNnY4yOvuT5sjdv5tDIm0FRqPfjD1gay1WZFc2K3ad5ZuE2jqflAHBL5zo8dW0zAs1VvCsJ1PFph1aqYdLuX8HtUu+XUXdl4nD6Yb7Z+Q0/7PuBbEc2g9e4GL3ChUuvI/y9N4npV3CAdK6sjRtJ/PgTsv79t9DjdIGBBPXtS+jwm/Br06ZUIzuFEEKIi3G73dgPHyZ3714cKSm4rFbc2dm4snNw5WRjP3wE6+rV6gQAnY6I8fcQMWECisEHfq8qJskNVBIiIT8MQviMb4fD3qVw/Uc8vK813/93jMf7NWFi74ZaVyaEEEKIAtgOHeLY40+Qs20bAKE3DSP66afR+ftfcKzbbufoffeT+ddf6EJCiJ31LeYGDYp8rqMPPEjGsmUEXt2H2p98UmavQZSdjBw7byzexay1hwGoGerHm0Nb0r1RpMaVXUYph2D9NM+ouzT1PmMAtLlZDZRk1F2JpeWm8fvct2n6xgJ0bpjaX8eq9v4MbjiY25rfRu2gS+95lL1jB5l/Lidn1y6cZ86gGI0YoqMxN2qEf4f2+LVqhWI0XoZXI4QQQhTOfuIEp995l/RffgHUfUlrvPcuxqgojSurWCQ3UEmIhPwwCOEzFj0Am2ZCr2eY5B7K20t3M7hNDT4cKSMkhBBCiIrKbbeT+MkkkqZOBbcbU/361HzvXSxNm+Yd48rN5fjjT5CxbBmKxUKdL78s9oio3P37OTDwOlAUGixZjKlu3bJ+KaKMrN53hicXbOVoSjYAIzvU5pmBzQi2+NCH8zYrbPWMuks8u7cPDfrAlfepI++ky6VYbIcPc3DYTbjS00nv34l3+2Tl7ZukU3T0i+3HnS3upEk1CeqEEEJUHem//sqJ557HZbXi174dsd98o3VJFYrkBiqd1gUIIcRlE+TZSyHjBA0iAwE4kGjVsCAhhBBCXIpiNBL18EPUmf4lhshIbAcOcGj4CE5/8AHZW7eS/ttvHBoxUg2QjEZqvv9eifYYMTdoQEDPHuB2k/y1vHmuyLo2jGDpQz0Y00UN+uasP0K/9/9m+e7TGld2GZkCoP3tMGENjFkETa8DFNj/B3x9A3zWFf77Fhy5WldaKbiysjh63/240tOxtG5Fh7emMve6uXxxzRd0rdEVl9vF4oOLGfbTMMb/Np71J9cj1+MKIYSoCoIHDCB2/ndYWrUi5rnntC5HVFDSiYQkikL4jA3T4eeHoPG17O3zOX3f/5sAk57tL/WTeeRCCCFEJeBISeHE08+QuWLFBY/pQ0Op+cEHBHTuVOL1rf/8w+E77kTn70+jVSsLHJsnKpZ/DyTx5IKtJCRlATD0ilo8f10cIf4+1JXklXwQ/v0M/vsG7J4LpQKjoeM4aH8n+FfTtr4ScmZkkLVxI+5cG5bmzTHVqlmm67vdbo4/+hjpv/6KPiKCegvmX7CX2s6knUzfPp2lCUtxefakahnRkjtb3EnvOr3RKXJ9rhBCiMrN7XbLZ2MFkNxAJSES8sMghM/YvQRmj4Dqbci980+aPbcElxvWPtOH6GCL1tUJIYQQogjcbjcZS5eSOn8Bubt3owsMJLBXL8LH3YWhWuk+JHe73ezv1x/74cPUePstQq6/voyqFuUp2+bknWW7+XL1QdxuiAoy89oNLekbF33pJ1dF2SmwcYY66i7jhHqfwQ/ajobOEyC86HuFacntdJI0dSpnJk/BnXu2oyrwqquI+b9nMdaoUSbnOTP1cxLfew8MBurOmI5/+/YXPfZIxhFm7pjJD/t+INep1hQbHMvtLW7nuvrXYdKbyqQmIYQQQlQMkhuoJERCfhiE8BnHN8PUnhAYA4/tpufby0lIymLWuE5c2SBC6+qEEEIIUQEkfjKJM598QkDXrtT5YprW5Yhi2JiQzOPzt+aNK77xipq8cH1zQvx8sCsJwGGDHd/Dmo/h5DbPnQo0GaDum1SnS4XdN8ntcHD8iSdJ//VXAEx166ILDiZn+3Zwu9EFBVHr448I6Ny5VOdJmTuPky+8AED0s89S7dZbivS8pOwkvt35LXN2zyHDlgFAlF8UY5qP4abGN+FvlC5GIYQQoiqQ3EDl0z3XkyZNIi4ujl69emldihDicvDuiWQ9DU4H9SMCADh4RvZFEkIIIYQqZPAgAKxr1mA/5UN77FQB7epW49cHunNPj/roFFi46Rj9P/iblXsTtS5NGwYTtB4B96xU901q1A9ww+5fYPq18PlVsG0+OB1aV3qBxA8+UAMko5Hqb75B/SWLqffdPOr/8jN+rVvjysjg8Li7SfvllxKt73Y4OP3hh3kBUvi4u4ocIAGE+4XzwBUP8Nuw33is/WNE+UVxOvs072x4h34L+vH51s/JtGWWqDYhhBBCiIpGOpGQRFEIn+FywiuR4HbCI7t45e8Uvlh1kDu71eO56+K0rk4IIYQQFcShkTeTvXkz0c8/R7VRo7QuR5TAxoRkHpm3JW+vpDFd6vLUtU3xNxk0rkxjibvh309hyxxw5Kj3hdSBK++HtreASfsOmszVqzly510A1HzvXYIHDMj3uCs3l+NPPUXG4iWgKMS88DxhI0cWaW233U7GihWc+ewzcuN3AmqAFPnII6XaB8LmtPHLgV+Ytm0ahzMOAxBkCuKWZrcwutloQswhJV5bCCGEENqR3EDl051IQggfo9OrmwsDZJygfqTaiXQgUa4SFEIIIcRZQX2vBiDz9z80rkSUVLu61Vj8YHdu7VwXgK/WJDDgw5VsTEjWuDKNRTaB6z+Eh3dAr6fBPwLSDsPix+GDFvDXW5Cl3b+R22bj1CuvAhB688gLAiQAndlMzXfeIfTmkeB2c/LFl9R9kwq5PtZ29BinP/iAfVf14dj9D5AbvxNdcDA13vofUY8+WuqNxE16Ezc0uoEfh/zIG93foH5IfTJsGXy25TP6LejHh5s+JDnHx3/2hBBCCFFpSScSkigK4VOm9objm2DkLP4xdmLU52upG+7PX4/31royIYQQQlQQtkOH2N//WjAYaPzPavTBwVqXJEph5d5EHv9uKyfTc9ApcHePBjzctxFmg17r0rRnz4bN38LqjyA1Qb3PGADtb4fOEyCk5mUtJ2n6DE7/73/oIyJosGQx+sDAix7rdrtJ/Ogjkj6bDEDYmFuJevRRdGYzoHYsZf71F6nfzce6ahV4PvrQh4cTeuMNVLvtNgwR5bMvqsvt4reE35i6dSp7UvYA4GfwY3jj4dzW/DYi/SPL5bxCCCGEKFuSG6gkREJ+GITwKbNGwp7FcN0HnGp8M51e/wOdAjtf6S8fJAghhBAiz/7rrsO2bz813n6LkOuv17QWt9tNTnw8jtOn8WvTBkNYmKb1VEZp2XZeWrSDhf8dA6BpTBDvDm9N8xoyZgxQ90WK/wFWfQCntqn36YzQagR0fRAiG5d7Ca6sLPZd1QdnairVX3uV0KFDi/S85JkzOfXGmwDoIyLwb9cOl9VK9qZNuLKy8o4LuPJKQocPJ+iq3igmU7m8hvO53C5WHFnBlK1TiE+KB8CkMzG08VDuaHEHMQExl6UOIYQQQpSM5AYqGWcnhPAtAeHq16wzRAWZCTDpcbnhcFJW4c8TQgghhE8J6qOOtMv4409N63CmpXHkrnEcGjqMo/dOYF+fq0n78UdNa6qMQvyMvDeiDZNvaUd4gIldJzMYMmk1n/y5F4fTpXV52tMboOUwGL8SRi+A2O7gssPmb2BSR5gzGo5uKNcSUr/7DmdqKsY6dQgZPLjIz6t2223U/PBDDDExOM+cIWPpUqyrVuHKysIQE0P4uLtosGwpdb78guD+/S5bgASgU3RcVecq5gycw6d9PqV1ZGtsLhuzd81mwMIBvLH2DRKzEi9bPUIIIYQQJeHju4oKIXxOgGd0hDUJRVGoHxnItmNpHDhjpVF0kLa1CSGEEKLCCLqqN0lTpmBduRK3zXZZP3j2cjscHJ14H1kbNqAYjRiio7EfPcrxp55GFxRM0FWlG8ebtWkTZz6bjP3EcQI6dSby/vvQh4aWTfFlxO1ykbND7eCwxDVD0Zeuc7x/ixjax4bx7PfbWLrjFO8s28PvO0/z7vDWNIi8+Og0n6Eo0Ohq9c+R9bD6A9j189k/sd2h20PQoI96bBlx22wkTZ8BQPidd6IYivdRRXC/awjs3YusteuwHTyIYjFjaRaHpXkcik77a2cVRaF7re50q9mNdSfX8dmWz9h4aiOzds1iwd4FjGwykttb3E64X7jWpQohhBBCXED736aEEOJy8vfMPbeqV/zVjwwA4ECiVauKhBBCCFEBWVq2RB8RgctqJWtD0Tow7KdOkfjRxxx/8ilS5szFbbOVqobkGTPI2rABXUAAsd/No8GypYSOGAFuNyeefx5nenqJ1874808SxtyGdeVKbPv2k/LttxwaNRpHcnKpai5LtoQEDg65gUM33cShm27i4NBh2BISSr1uRKCZybe0473hrQmyGNh8JJXrPlrFt2sTkGnvKrfdjvWYk4yYcTjHroA2t4DOAIdWwjdD4fPesOvXvH2GSivtp59xnDyJITKSkBuGlGgNnclEYPduVBtzK2HDh+PXskWFCJDOpSgKnap3Ynq/6Xx+zee0jmxNrjOXmfEzuXbhtXyw8QNSc1K1LlMIIYQQIp+K9RuVEEKUtwBPiJR1BoD6EeoVpwcSM7WqSAghhBAVkKLTEdizBwAZy1dc8njrunUcuO56znz6KWk//sjJF1/k0K234szIKNH5nampnJkyFYDoZ57B0rQpik5H9DNPY6pXD+eZMyR/9XWJ1rafPMnxp54Gh4Oga66h5gfvY4iJwXbgACeef75CBCnO1FQO334HuXv2oPj7o/j7k7trF4fvuBNHSkqp11cUhRuvqMWyh3vQtWE42XYnz36/nXFfbSQpM7cMXkHllXvwIAcGDebw2Ns5Ov5e9o+cQEbIMHhwC3SeAEZ/OP4fzLkZJneDHd+Dq+QjAd0uF0nTpgFQbexYdBp0/V1uiqLQuXpnvr72az67+jOahzcn25HNF9u/oP/C/kzaPIl0W8lDYiGEEEKIsiQhkhDCt3hDJGsScE4n0hnpRBJCCCFEfkG91XFxmcuXFxqs5B44wNF7J+DKyMDSvDnh4+9BFxJCzpatHHvssRKFMilz5uLKyMDcuDEhQ87uD6Mzm4m8/z4Akr/6CldOTrHXTvz4Y1zp6VhataLmu+8Q3L8/tSd/BkYjmb//gXXV6mKvWdZOvfk/7MePY6xThwZLFtNgyWKMdepgP3aM02+9XWbnqR7ix9d3dOL/BjbDpNfx+85T9PtgJct3ny6zc1QmjpQUDo+9HdvBg+hDQjDUqI4zNZWj992PdedR6P8GPLQNuj0CpiA4tR2+Gwufdoat88DpKPY5M5cvx3bwILrgYLXTzocoikK3mt2YPXA2H/X+iCZhTbDarUzeMpn+C/ozZcsUrHZ5nyKEEEIIbUmIJITwLeeNs6sXoYZIhyREEkIIIcR5Aq68EsVkwn70KLZ9+wo8xu10cvzpp3FZrfi3b0/dWd8S9dBD1J3+JYrJhPWvv0n/+ZdindftcpH63XcAVLvj9gv2AQrq1w9jjRq40tPJ+P2PYq1tO3qUtB9+BCDm2WdQjEYALE2bUm3UKADOTJpUrDXLWs7uPaT9qNZY8+23MEZFYYyKosb/3gQg7fvvyb3I96MkdDqFu7rX54eJXWkcHciZzFxun76eF37cTo7dWWbnqQxOvfIqjlOnMMXGUv/XX2i4dCnBA64Fp5Njjz6GMzVVvSjr6hfgoa3Q8ykwh8CZ3bBwHEzqAP99A057kc+ZNO0LAMJGjkQfGFBOr6xiUxSF3nV6M+/6ebzX6z0ahjYkw5bBJ5s/of+C/nwd/zU2Z+nGYwohhBBClJSESEII33LuODu3m7rh/gAkWW1k5hb/ykkhhBBCVF06f3/8O3cCLj7SLu2HH8jZshVdYCA13nkbndkMgCUujvDx9wCQ+MnHuJ1FDyOsa9ZgP3YMXXAwwf37X/C4otcTMmSIev7vvy/GK4K0hQvB6STgyi74tW6d77Hwu+5EMRrJ3ryZnPj4Yq1blpJnzAC3m6B+/fLV6N+2LUF9rwbOBg9lKa5GMIvu68bYK2MBmLkmges/XsWO42llfq6KKGfnTtJ//RUUhRrvvIMhPBzFaKT6669jatgAZ1ISiR99dPYJ/tWg99Pw8Da46jnwqwbJB+DHifDRFbD+C3AUPhowa9Mmsv/7D8VoJOyW0eX8Cis+naKjb92+zL9+Pv/r/j9ig2NJzU3lrfVvcf3317No/yKcLt8KNoUQQgihPQmRhBC+xduJ5HJATipBFiPVAtS56wlJ0o0khBBCiPzyRtqtWHHBY26HI2/foogJEzDGxOR7PPy229CHhGBPOEzGb78V+ZzeY4OvvRadxVLgMSGDBwFgXbu2yPsuuV0uUn/4AYDQYcMueNwQGUlgnz4ApC4sXjiVdw63G/vJkzjTS7afizMjg/TFiwGodtttFzweftddAKQvXowzs+z3tLQY9bw4qDkzbu9ARKCZvaczuWHSP3z+9wFcrvxjCXN27yFlzhyyt2wp8zq0cGaq+rMcfO21+LVonne/zmIh5rnnAXXMou3w4fxPtIRAj8fUMXd9X4GAKEg7DL88Ah+2gX8ngz274HN+ona9BQ8ehDEqquxfVCWl1+kZUH8A3w/+nhe7vEiUfxTHrcd5dtWzDPtpGH8d+atC7F0mhBBCCN8gIZIQwrcYLer8dsjbF6lONbUb6XBSllZVCSGEEKKCCuzVC4DszZtxpKTkeyz911+xHz6MPiyMsJEX7uWiCwgg9OaRAKTOX1Ck87ldLjL/XA5A0NV9LnqcqW5dTPXqgcOB9Z81RVo7e8sWHMdPoAsMzAuLzhd64w2AGtK4Xa4irevlSE7m8Njb2derN3uu7Mrp9z8o9hrpixfjzsnB3Kghfm3bXPC4pVUrTA0a4M7NJWPpsmKtXRy9mkSx9KHuXN0sGpvTxWu/7uTWL9dyMk3dgyrpiy84OHgwJ198iUMjRnLqjTcr9Yf6jpSUvNGI4ePuuuDxgE4dCejRHVyui3eBmQOh6wPqmLv+/4OgGpBxHJY8qYZJa6eC/eweXpmrVmP95x8wGokYP748XlalZ9AZGNp4KL/c8AsPt3uYIFMQ+1L3cd+f9zF2yVj+O/2f1iUKIYQQwgdIiCSE8D0B4erXrDMAeSPtEpIlRBJCCCFEfsbq1TE3bQouF5l//ZV3v9vp5MzkKQBUu/12dP7+BT4/1DN2zvrPPzgSEy95vpwdO3CcPo0uIAD/Tp0KPTawRw+AfHUVxrpyFQAB3bvljd07X0DnzugCA3EmJZGzbVuR1gW1K+vovRPIWrtWvcPhIGnKlGKPncv8408Aggdeh6IoFzyuKAohg9QuLG/HUnkJDzTz+Zh2vH5DS/yMelbvS6LfB3+zfMb3nH77HQAsLVoAkDxzJmkLihYUVkTpP/0MdjuWuDgszZoVeEzEPep4xrTvv8d+6tTFFzP6Qefx8OBmuO59CKkNmSdh8ePw8RWwfhrOlDOcfOEFAMJuHompVq2yfklVisVg4Y4Wd7D4xsXc0eIOzHozm05vYsziMdz/x/3sTdmrdYlCCCGEqMIkRBJC+J6ASPWrVf0gp66nE0nG2QkhhBCiIIG9ewGQec6+SBnLlmE7cABdSAhho26+6HNNsbHqvj4uF+nLLt054+0qCrjySnQmU6HHBnTrBkDWhg2XXBcgc+VKAAK797joMYrJlLduRgEj/C4mZbY61k0XGEj9X34m+vnnADjz8cfYjh4r0hqu7Gys//6r1nhV74se5+3Qylq3Dld2wWPSyoqiKIzqVIefH+hGy5ohZFhzcH38rlrHqNHUm/8dkY88AsDpd9/DmVn83yfdbjcp8+ZxZMJETr/7Ls60y78HkzeQC7nhhose49+uHX7t2uG220n55ptLL2owkxvag6SgBzmddQPJh2qQviOJjM+f4fDg7tiPHcNYswaRDzxQVi+jygsxh/Bwu4f55YZfGNpoKHpFz4qjKxi6aCjPrnqWk9aTWpcohBBCiCpIQiQhhO/x7otkVTuR6oQHAJAg4+yEEEIIUQDvvkjWVatw22y4XS7OfPoZANVuvRV9YGChz/eOjrP+vfKS5/IGQv4dOlzyWL82rUFRsB8+jOPMmUKPdaamkrN9OwAB3boWXm+vnmq9q/+5ZA0AbpuNpC/UjqPIRx7G3KABYTffjH+XzrjtdpKmTCnSOtY1/+LOzcVYsybmRo0uepypfn0MNarjttmwejufisBltXL86WfY3bETB28aTva27UV+boPIQBbceyWvBh2nhjWJNFMA95g7svtkBuF33I6pbl2cKSmkfvddkdf0Ov3m/zj5/Atk/vknSZ9PI2HMbbisl+/iJkdKSt6+ToWNUAQIv/MOQN0bqbDAzJWTw4nnnuPAdddz+p33SFq0llP/wrHV1Tj6dzg5p0FndFGry2n0e38Ap6PMXo8viA6I5sUrX+T7wd/Tt25f3LhZtH8R131/HR9t+girXS6OE0IIIUTZkRBJCOF7/MLUr9nqvgZ54+wkRBJCCCFEASwtWqCPjMBltWJdt56MZb+Ru3cvuqAgqo259ZLPD+ypdv5Y167FlZt70ePcTifZmzYB4N++3SXX1QcF5YUtWf8VvjdK9tat4HZjio3FGBVV6LEBngArJz4eV9alfz/K+OMPHCdPoo+MIHToUEDt4ImcMAGAtF9+wZmRccl1stavV8/frVuBo+y8FEXJ66ayrlp9yXVB7fY5+uBDpH3/Pa70dHK2bePw7bdjO3y4SM8HMBl0dN2jBmu/NevFjhQ7gz5ZxbcbjhF2+1gAUufPL9beSNY1a0ieOROAamPHoo+IIHf3bk5/8GGR1ygt68qV4HJhbtIEY/XqhR4b2KsXpthYXBkZpM4vODBzu90cf+JJUr+bD4pCYM+ehI0aRVD//vi1bYspti4hVzah3hAXFsNh+HEifNIeNs+WMKmY6oXU471e7zF74GzaRbcj15nL59s+Z+DCgczfMx+ny6l1iUIIIYSoAiREEkL4Hr9Q9WtOKnA2RDqRlo3NUbzNn4UQQghR9Sk6HUFXqR0ap99+m1NvvAF4upCCgy/5fHPjxhiio3Hn5BQ6ei5n1y5cViu6wEDMTZoUqTa/tm0ByP5vc6HHZW/Zqh7futUl1zTUqIEhJgYcDrK3XnpfpPRf1VFooTfcmG+vJb/27TE1bIA7K4uMpUsvuU7Wpo1A0QK0gE4dAci+RHjmlfbjj1hXrUKxWKj50Yf4tWmDKzOTk6++WqTnA9iPH88Luu5+7X56NYkk1+Hi/37YzkvWmigWC7b9+8mJjy/ymomfTAIgbNTNRD/1JDXefBOAlDlzsJ+8PKPJMleoe2oF9up1yWMVnY5qt98OQPJXX+G22y84JnnGTDKWLQOjkdrTPqf2lMnEPP8ctT54n9jZs2iwZAk1vvwB03Nb4JpX1SkBKQfhh/HwaSfYvhBc8jt5cbSIaMH0ftP5oPcH1AmqQ1JOEi+teYmbfr6Jf44XraNQCCGEEBfXu3dv4uLimDRpktalaEJCJCGE78nrREoFIDLQjL9Jj8sNR1OkG0kIIYQQFwofdxeKyUTu7t04Tp3CWKtW3mivS1EUBX9v6LHp4qFH9ubNgBoMKXp9kda2tGgOQO6unYUel71VDZEsrS4dIimKgn87NcjJ2lj4fkvOTCuZf/8NQPCAay9YJ2TgQAAyfv+j0HVc2dnk7FDDF7+2V1yyRm94lrNr1yW7pdwuF0lTpgIQce+9BF9zDTX+9ybo9Vj/Xpn3b3MpGb/9BqijBqMb1uXL2zrw7IBmGPUKP+5NZ2OUGvxlFnEvqeztO8jeuBHFaCT8nvEABHbrqo4ytNtJmTW7SOuUlneUXUDnTkU6PmTIYPTh4TiOnyB9af59vrI2bOD0O+8AEP3UkwR2LWR0oskfrrwfHtwCV78IftUgaR/Mvx2m9oS9v0Mxurp8naIo9KnThx8G/8CTHZ4k2BTM3pS93PPbPUz4fQL7U/drXaIQQghRaS1fvpz4+HgmTpyodSmakBBJCOF7LKHqV08nkqIo1KnmGWmXLCGSEEIIIS5kqlWLmh+8j7FOHSytW1F7ymR0AQFFfr5/XsfQposeY9uvfshradK4yOtaPB1LOXv2XvQYt9udF5T4tWpdpHX92rRR173EvkFZ69aq+xjVqVNg91Sgp4PL+s8/he7zk711GzgcGKKjMdasccn6jNWrq91STucl9zayrlmD7eBBdAEBhI0eDYCpbl2CBw4AUMeuFUHmanV0XqBnjyydTmFcj/osuPdK6ob781e4+voP/LgEp+vS4Uf6r7+q613dB2P02RGDYaNHAZD280+4i9GR43Y6SfriSxLG3MbJ11/HmZZ2yec4UlKwHzsGgKV58yKdR2c259WY9OUXeeP7HImJHHv4EXA6Cb7uOsJGjSpa4eZA6PawGib1ehpMQXByK3w7FGZcB0fWFW0dAYBRb+SWuFv49cZfuaXZLRgUAyuPrWTooqG8+u+rJGUnaV2iEEIIISoZCZGEEL7HO87O04kE5IVIhwvaF8npgA3T4bux8PPDcLTwK3KFEEIIUTUFXXUVDZctpd7cuZgbNCjWc/2uULtrsjdvwe0seJ+S3P0HADDVL/ra5oYNQVFwnjmD48yZAo9xnDqFKy0NDIYiB1SWuGaA2ulTmKy16gf8AV26FLiPkblxI4w1auC22cjadPEALWe7GgT5tW5d6H5I5/IGXd4OrovxhjXBg65HH3g2+AsdNizvcZfNVugaLpuNrPXq74ABXa/M91irWqH8fH83wq/qqZ7n8D7umrScU+k5F13P7XaTsWSJenz//B1cgb17owsMxHH8xCVf27lOvvoqp99+m6x160j56msO33EnrpyL1wCQs30HoIZqRRnN6BV2880oFgu58TvJ+vdf3DYbxx5+BEdiIuZGDan+8ktF/j7msQRDr6fUMKnLfaA3Q8Iq+KIvzBoJp3YUbz0fF2IO4cmOT/LDkB/oU6cPTreTubvnMvD7gUzfPh2788JRhEIIIYQQBZEQSQjhe87rRIKz+yIlnB8iOXLVqyB/fgh2fA8bvoRpfeC3F2S8hhBCCCGKzNywIbqAAFxZWeTu2VPgMbkH1E4kc4P6RV5X5++PsU5t9fkXW3fvPkANChSTqWj1erqKHCdP4khJuehx1nVqiOTfsUOBjyuKgn9HdZRf1oaNF10nZ7caVlmaNS1SfQB+3lF+uy8edLmdTjKXrwAguF+/fI/5d+iAISoKl9Wat9fRRevbsgV3djb6iAjMjS8M4oIsRl6/qze5UWoXVfqm/7j2w5Us33W6wPVsBw5gP34cxWwmsEf3fI/pzGYCe/QAIHPlykLr8rKuWUPq7DmgKITfOx59WBg5O3ZwZsqUwl/XDjW8s7RoUaTzeBnCwgi98UYATjz7fxwZfy9ZGzagCwig5ocfofP3L9Z6+QSEQ7/X4IFN0PZWUHSwZzF81hUWjIPkgyVf2wfVDa7LB70/4Mt+XxIXHofVbuW9je9x46IbWX1stdblCSGEEKIS8OkQadKkScTFxdGrCBuICiGqkII6kcLVq1IPJ583ZuW3F+DACjAFquM1Wg5X71/9ASx5qrwrFUIIIUQVoej1eePCcuIv3L/ImZ6OM1HtJDLVL3qIBGDxhBq5ewseaZe7Xw2RitM9pQ8MxFjbE07t3l3gMc70dHI9nUoBnqCoIP4d2gPqfjkXk7tLPYe5SdFDJG/QlbO74PAM1P1+nMnJ6IKD8/Z58lIUhQBPgGP17OtU2DoA/u3aXbTDRlEUorqoYVoP2wmSrTZun7Ge/y3ZhcOZfyydN7Tya9MGnZ/fBWsFdOum1rX6n0Lr8jrz2WQAwm4eSdSDDxLz4osAJM/8qtCxdtnbSxYiAUTcfx/G2rWxHz+O9Z9/wGik5vvvYa5fr9hrFSikFgz+BCaug7ghgBu2zYNP2sMvj0LGybI5j4/oENOB2QNn80rXV6hmqcah9EOM/3089/95P0cyjmhdnhBCCCEqMJ8OkSZOnEh8fDwrirjxqRCiiiioE8kzzu7QuZ1IibthnboJM8O+VMdrDP0cBn8KKLB2MmyceVlKFkIIIUTlZ27qDT0u7JzJ9eyHZIiKQh8UVKx1TbGxANgOF/xBsHevJXPD4o3gszRVA52cnQV3+uTE7wS3G2PNmhgiIy+6jje8ydm6tcCxcS6bjdwDBzznvHBfpYsxN1aPtR08iCs3t8BjstapYU1Aly4oRuMFjwd2Vzt+rP8UHtZke8a++bUsPGzx8+x9NciQxG1d6gLw2Yr9jPp8LSfTzo6W89bl36HgDi7vyLycbdtwZmQUes7cAwfIWrcODAbCx40DIOiavpibNMGdlUXqd99d9Lk5O+IBsDSPK/QcBTGEhRE761uq3TaGkKE3EjtrVl4HVZmKaATDZ8LdK6BBH3A5YP00+LAN/P5ivgvDROF0io4hDYfw8w0/MyZuDAbFwIojKxjywxA+2vQRWXbZH1YIIYQQF/LpEEkI4aO8nUg56eDZrNg7zu5wchYu70bIayaB2wlNBkDjc8aftB0NVz2r3l78JCQfuEyFCyGEEKIyszRV9xnydt2cy+YJUYob9AAY69RR1zh8uMDHc/epIZKpmPs4mRs18tS2v8DHc3aqHVXe/ZMuWl/duuhCQnDb7dj27bvgcdv+/eBwoAsOxlC9epHrM0RFog8LA5eL3ALWBcjapI7QO78Lycu/vXp/7r79hYY1Odu2AZfu2PE+bt+zmxcHNWfSqCsINBtYdyiZgR+tZOXeRE9dmzznb1/gOsboaIy1aoHbnXfui0lfvBhQgyej599PURTCRo8CIO2nnwt8nuPMGRwnToCiYIkrfogEYIiMJPrpp6nx2muXDNhKrUZbuHUhjP0FanUERzaseh8+agv/fgaOwve1EmcFmYJ4vMPjLBi0gM7VO2Nz2fh82+cM+mEQSw4uwS1ju4UQQghxDgmRhBC+x9uJhBty1fEe1UP80Clgc7g4Y82F3AzYNl89rMvEC9fo/hjEdlffvC56QPZHEkIIIcQlWfI6kXZf8CFt7n41RDLVL36IZKqthkj2i4RItkOHADAXc0yeqV6s5/kJBT6es1PtYjE3KzxEUhSl0K4m7xg+c+NGFx0Vd7F1vSPtCgrm3C4X2f9tBsDviisKXMMQHq6O7XO7yd6ytcBjHCkp2I8dA7hk2GJu2AB0OpwpKTjPnGFgq+r8dH83mlUPJslqY8yX6/j4+w04Tqqj2AoLpfxatQIge2vBdXllLPsNgOD+1+a7P/iaa8BoJHf37gJDtpwdaneVqV499IGBhZ6jQontBncug5GzIaIJZCerY6YndYQdP8jv5cVQP7Q+U/tO5YNeH1AzsCansk7x+N+Pc8fSO9idXPAYSyGEEEL4HgmRhBC+x2ACo2ezX8/4C5NBR3SwBYBjKdmw73ewW6Fafajb9cI1FAUGfaSuc2gl7Fh4mYoXQgghRGVlatgQDAZcaWl5IYJX3si5BsULegBMdT2dSMeO4XY48j3mslpxpqQAqJ0txVpXHcfmDaHOl+vtRLpEiATnjMbbdWGI5B3D5z1fcXj3ebIdOnhhffv24crIQPH3L3RMnl+bNgBkb95c4OPefZ+MdeugDw4utB6dxYLJ0xmWs0fdq6leRADfT7iSmzvWwe2GP35epR5bqzb6wICL19XaGyJdvBPJkZyct2dVYK+e+R7Th4YScGUXADKWL7/guWf3Q2pe6GuqkBQFmg6Ae/+B6z6AgChIOQjf3QZfXAOH12pdYaWhKAp96vbhh8E/MKHNBMx6MxtObWDEzyN4e/3bMuJOCCGEEBIiCSF8VAH7ItUKUzc1PpqSDXvVKzppMkB9k1qQavWh28Pq7d9fBEfBs/iFEEIIIQB0JhNmT3ePt/vGy7snUEk6kQzR0SgmEzgc2M8Lp+zHj6vnDg4u/l5LnlDHkZiIM9Oa7zG3zUbuATW4KVKI5Bl55+1eylfj0aPq+WrXLlZ959ZYULeUN/yxNGuGYjBcdA3vGLacXTsLfNz7vTE3aFikmsyNG6vP23P2e2wx6nnjxpa8P6I1jTPV79FaXTXW7E+66DqWlmqIVNg4u6z1G/LOaQgLu+DxvD2fVl+451OOd5+nS4zoq9D0Bmh/OzzwH/R8Sr3A6+g6+PIamHsLJBU8ilFcyGKwcG/re1k0ZBF96/bF6XbyVfxXDPphEH8k/CEj7oQQQggfJiGSEMI3efdFOmcj3lphanfSsZSssyFSo2sKX6fLRAiqDqmHYe2Usq9TCCGEEFWKKbYeALaDZztnXDk5eUFKSTqRFJ0OY82awNlAxsvmGcPmfbw49CEh6KtVU9c9nD+ksR09Ck4nOn9/DNHRl1zL7OlEyt2z94IPo701G2sWr1MKwBTrCZESCgiRPJ1AliaNC6/NG/qcF+x52Txhmbl+vSLV5N1LKnf/hSPkbmhbizti7ADs8o9m9LR/mbR839k9OfPVpa7jSEzEmZZW4Lmy1q0DwL9DhwIfD+h6JQDZGzfiysrfUZKT14lUiUMkL3Mg9H5aDZOuuA0UHez8SR1x9+vjYD2jdYWVRo3AGrzX6z0+7fNp3oi7h1Y8xP1/3s+xzGNalyeEEEIIDUiIJITwTQV0ItUMVTuRsk7tBetp0JuhTpfC1zEFwFXPqbdXvQ+5mWVfqxBCCCGqDJNnX6Lcc0Ik26FD4HajCwlBHx5eonUNMWqQ4zh1Kt/93k4kY80aJVr3YiPtvJ0/xti6RdrHyFS3LigKrvT0vPF6eWt5g65axQ+68uo7fBi3y5XvMe84OW9IdDHe0Md++Aiu7OwLHrcd9HSJ1StawOcdL2j3jOm74HzH1L2rYtq0wOWGt5fu5vYZ60m22vIdpw8MxFC9OgC5+wvuqMnesgUA//btCq4lNhZjjRq47Xay1q/Pu99+6jSOxETQ6fJGDVYJQTHqyOl7/1EvBnM5YN1U+KgtrHwX7Bd+f0XButfqzveDv2dcy3EYdAb+OvoXQ34YwhfbvsDusmtdnhBCCCEuIwmRhBC+yezZPPic0Mc7zs4/UX0zTvXW6v5Jl9JqhDraLjsZNnxZ1pUKIYQQogoxecbZ2Q4eyrvPGxCY69cvUiBTEGOUGiLZT53Od7/dE9CYStCJBGCsrXYHeYMeL2+oVNR9jHQWC8YaapCVrwvLZsvbH8pUzD2bwNNhZTDgzsnBcTr/a8/d7Q2RLr4fEoAhPFztuHK7yd13YViTu987zq5oIZKxljqWz3a04BDJdkS9/+6be/LWsFaYDTr+2pPIdR+tZMuR1HzHevd8yt13YVeT2+HI656yxMUVeC5FUfK6kaz/nt0nKGfHjrz1df7+RXpdlUpUMxj9HYxZBDGtIDcd/ngZPm4HW+bAeYGjKJifwY8HrniABdcvoH10e3KcOXyw6QOG/zScTac2aV2eEEIIIS4TCZGEEL7J5AmRbGdDpJqeECk6fbvnjoKv6LyA3gDdH1Vv//Mx2GTzWSGEEEIUzFzPM87Os88OgM0TUphKMMrOyztS7oJOpGOeTqQaJetEMlZXn+c4kX+vJe/4OFNsbJHX8h57bojkOH4c3G4UiwV9RESx61MMhryA7Nx9kZxpaXn/Ft6xcIXJG0F33kg7Z6Y1bx1TvaKNszPVUUMkx4mTuGz5u4ucaWm4PKPpTLVrMbx9bX68ryv1IgI4npbDTZPXMHvd4byRf4WFSLaEBNy5uej8/TEWsp+Uf6fOAFj//Tfvvio1yq4w9XvC3X/BDVMhpDakH4Pv74EvroYj6y/9fAFA/dD6fNnvS17t+iph5jD2pe7jtiW38fzq50k9Z7KDEEIIIaomCZGEEL6pwE4k9SrM2Nzd6h01ryj6eq1GQGgddQzeppllVaUQQgghqhhvEOFITMSZqf4ecrYTqUGJ1zVER6nrns4fInm7c4qyb1FBjJ5xavYTJ/LdX9xOJDj72s8djWc7enaUXYm7sArYFynXE9IZYmLQBwYWobZYdY3z9n5ynFBDOF1ICPqQkCLVow8PR/H3B7c7rxPMy+YZcaePjMjrAGoaE8yP93XlmrhobE4XTy/cxpMLtpJjd2Ju1FB93v4DnC9n5y5AHden6C7+1j6gcycAcnfuxOEZJZi9wxsiNS/Sa6rUdDpoPQLu2wB9XlAvJju2UQ2SFtwFabLPT1EoisLghoNZNGQRQxsNBeD7fd8z+MfBLDm45IK9zoQQQghRdUiIJITwTaYg9astI++u6iEWwE0DPBtSRxfjyky9Ebo9rN7+5xNwypxwIYQQQlxIHxyct++Rd6Sd7YAnRCpNJ1KUGiKdP87OkXRGfbwEXT4Axuox6ronz+tEOqzu61O8ECkWgNxzRvnZj6q/d5lqFn+UXV6Nni4ruyfwAbB7RsaZ6tQpWm2eTh770fyBgiMxUT1HVGSR61EUJW80n7eOs3V5/t1q568r2GJk8i3teKJ/E3QKzNtwlGGT/yElVP2+FjQaL3e3J0RqeolxfREReWFU1tp1uN1ucrar4+z8qnon0rmMFuj+CNy/EdrcAiiw7Tt1xN2KN2WaQBGFWkJ58coX+erar2gQ0oDknGQe//txHlj+AKespy69gBBCCCEqHQmRhBC+yewJkc7pRLIY9TQNzCFYycKt6NR9joqj9SgIiIT0oxD/YxkWK4QQQoiqJG+k3aGD6r42njFspgYl70QyXmScnfNMEkBecFVcBk8nkuP42YDG7XLldTgZaxR9r6WCxtnZPeGIsQT7IXl5QyTHOd1S3n2HvHs6XXINzz5G54c+3hDJEFn0EAk8ezUB9uPndXB5OpFMBYyf0+kUJvRqyFd3dKJagIntx9IZ99vxvHXcTme+43M9/47mhpce1+ffuQsA1n/X4DhxAmdSEuj1mJsUHkBVSUExMGQS3L0cancGRzaseAM+6QDb5oN01BRJ26i2zLt+Hve2vheDzsCKIysY8uMQvtvzHS637DklhBBCVCUSIgkhfJP5wj2RANoHqlfrZvnXUq9WLA6jBTqMU2+v+UTegAohhBCiQN6xbrn796thh92OYrGUeN8iONuJ5DhzBrdL/QDXlZ2Ny2pVHy9mCOLlHWfnTEvDlaV2ajiTk8HpBJ0OQ3i1Iq9lqhsLqEGNNxDJG2dXxLCn4Bo9nUjHzu1E8nQ41S5qJ1ItTz1H891f0hAp7/txOn9nmN0TxnlDpoJ0axTBT/d3o3WtEA4QgEPRgd2O7bx9qewJRe8GC+ii7ouU9e9asrdsAcDSpAk6SzF/361KarSFO5bAsOme/ZKOwoI74cv+cGyT1tVVCia9iQltJjDvunm0jGhJpj2Tl9e8zF3L7uJw+mGtyxNCCCFEGZEQSQjhm0zePZEy8t3dwqxevZtkKdoHDhfocCcYLHD8Pzi8pjQVCiGEEKKKMjdUO45s+w9g8+zdY6pfr9B9bS5FHxam3nA6cWWov984ktQuJMVsRhcQULJ1g4LQefYU8o608wYj+vBqKAZDkdcyVo9BMRpx2+3YPYFI3ji7MuhEspemE8nTGeRMSsoL3gDs3j2lihsiefeoSjxvvKA3lLrEHlU1Q/2Ye08XRnSK5bS/+r19/cs/SctWRya73e6812iqc2FX0/n827cHnQ7boUOkL1kKgF+bNkV/QVWVokCLG+G+9dD7/8DoD0f+hc97ww8TIPP0pdcQNAprxNfXfs0THZ7Az+DH+pPruXHRjczYPgOHy6F1eUIIIYQoJQmRhBC+KW+cXf4QqYGifvhwVFf00Sz5BERA65Hq7X8+KWl1QgghhKjCTA3U/Wly9+8nd78aIpnrl3yUHYDObEbn7w+AMyVF/XrGsx9SeDiKopR4be++SA5PiOQNVoyRUcVaR9HrMXr2KLIlHFLX8oRIpRtnp3ZL2U+dyutwKu6eSPqgIPShoWpt53QjlXicnXePqvM6kRzFCKUsRj1v3NiSoFj1NRzbuZ9Bn6xi54l0HKcTcefkgF5fpA42fXAwFs/+RxlLPSFS2zZFfj1VntEPej6u7pfUyvO7/OZv1f2S/v0MnBKEXIpep+fWuFtZOGghnat3JteZy7sb32X0r6PZk7JH6/KEEEIIUQoSIgkhfJOp4HF20W71jf0hZ8k2nwag80T16+5fIWl/ydcRQgghRJVkbqSGSLaEBHJ37VLva1DMvRgL4O1GcnhCJIcnRNJHlOL3mnOe710vLwiJKl6IBGdHr9kSEnBmZuJMTQVKFyIZIiPBYACHA0diIq6cnLN7NhVjXYMnjPKGZVAW4+wS891fkn+72nHqz0sjVzoJSVnc8Olqlv+5EVDH4ilGY5HWCejU6exfdDoCunYtcg0+I7gG3DgF7vpDHXeXmw5LnoIp3eHQKq2rqxRqBdViat+pvHzlywSZgohPimfEzyOYtm2adCUJIYQQlZSESEII3+TdEyk3f4gUalff6O/NCSn52pGNoVE/wA1rJpV8HSGEEEJUSYaoKHVEnNNJxm+/AWBq2LDU63pDJG8w4zijjrMzhIeXal1DRGS+9bzBSGlCJHtCQl4Xkj40FL1nZF5JKHo9Rs94OPvxE3nr6gID87qLisLbWXVu91BZ7onkdjrzRgwaooq+nrGm2mk0oq6JHo0jybG7+G7Rv+pjtS89ys4rqF8/dXwbENirF4ZqRd/PyufUaq8GSdd/CH7V4HQ8zBgI8++E9BOXfr6PUxSFGxrdwKIhi+hduzcOl4MPN33ImMVjOJB2QOvyhBBCCFFMEiIJIXzTRTqR/HPUK093WINxu90lX//K+9Svm2dBVnLJ1xFCCCFElaMoCuYG6vg6t13d48bSpEmp180LkVJS1a+eMElfLaxU63pDqLLpRPKMszuUUCaj7LyM1T0j7Y4fx3bYsx9SndrFGuPn3afIcercEMkzErCEIZIzKQm3zaaulZQELhfodMUK9gyecEuXnMT0sR0Y37MBMZ7fL1dZTaRl2Yu0jl+L5tT84APCx42j+ssvFefl+CadHtqNVUfctb8TUGD7fPikPaz+EBw2rSus8CL8Iviw94e83u11goxBbDuzjeE/DeerHV/hcru0Lk8IIYQQRSQhkhDCN+XtiXROiOSwoc9SrzbdnxtCRm4pxi3EdoeYluDIho3TS1GoEEIIIaoiU6OznUeKv3+ZBCn6sFDgnD2R0tPV+4OCS7WuIdI7zk79PelsiFS8YAXyj7OzlWWI5OnWsZ84jv2oZz+kWkXv0oFzuodOnQLAmWnFnZWlPlbMEEkfGgo69e22wxPq5XVwhYej6PXFr+v0afQ6haeubcrAGuoIu225ZgZPWsXeUxmFLZEnuN81RD36CIZSjjj0Kf7V4Lr34O4VUKuDehHab8/D5K6wf7nW1VV4iqJwfYPrWTh4IV1rdCXXmcvbG97m9iW3cyT9iNblCSGEEKIIJEQSQvgmb4hky1CvCAXIOIGCGxsGkgjmZFpOyddXFOji6UZaO1WuVBRCCCFEPgEdO+bd9m/XDkVX+rdmhrxxdmqI5MrwhEghpQyRPIGDsyz3RDp6FFtCgnpf7dKHSAZvJ9KxY3mdSKY6xQyRovOPoHMkql91/v7oAgKKtZai06EPUccj540X9KxX7K6m6AtH49VwqhdCuaqFcygpiyGTVrNsx8kCny/KSI02cMcyGPwpBETCmT3w9RCYNwZSJQy5lJiAGD67+jNe6PIC/gZ/Np3exNCfhjJ319zSTYAQQgghRLmTEEkI4ZtM58zdt1vVr+nHADijiwAUTpQmRAJofiMExkDmSdixsHRrCSGEEKJKCezZMy9kCBk8uEzW9I6zc3g7kdLUEElXyk4kfbi3E0ndz8fuCUOMJQiRDDExKGYzOBxkrVuvrlOz9CGSydPNZD96DPsRzzi7YnY4eV+PPS9EKtl+SF7n71HlTE1T7y/mXkTeupypqbi8o/E8Y/YeGnElnepVw2pzcvfXG/nw9724XPKBfLnR6aDtaLhvA3S6FxQ9xP8IkzrCqvfBWbTRgr5KURSGNR7GgkELaB/dnmxHNq+ufZV7f7+XRM9ECCGEEEJUPBIiCSF8k9Hv7G17tvo1Q716M8OoflBwIjW7dOcwmKDT3ertNZNArrATQgghhIc+JITYBQuo89VMQq4bWCZr6gLVTmtXhtql4iyrTqS8cXZncDscOD1hUknCFUWny+sQsu3fD4CxDDqRjJ7RdfajR8+OyatdwnF2ZRUihYYC54RIaepXb3hYVLqQEBSTyVNb/pGC1erU4Ju7OnFbF7XD6/3f93DvtxvJLM1YZnFpfqFw7Ztwz99QtyvYs+D3F2Fyd0hYo3V1FV6toFp80e8Lnur4FGa9mdXHVzN00VCWH5bxgEIIIURFJCGSEMI3KQoYPEGSXZ11T5bn6lqzenVoqTuRANrdDkZ/OLkVDq0q/XpCCCGEqDJMtWrmG2tXWvogtdPalekZdebpRNIHl9E4u5QUdb8gtxv0+mJ31HgZPSPtvEx1Y0tVH5ztOrIdO4b98GF13Tp1irWGIToaAGdSEm6b7WyIVIK9n+CcEMnbGeYJk4obIimKki/gctvtOJM8QV5UFEa9jpcGt+Ctoa0w6XUs3XGKGz9dTUKStUR1i2KIaQFjf4Ehk8E/HBJ3wvT+8ON9kJWsdXUVmk7RMbrZaOZeN5cmYU1IyU3hgeUP8MqaV8h2lPJiPiGEEEKUKQmRhBC+y9uNZPeERVZ1LIjbPxygdHsieflXgzaj1NtrJpV+PSGEEEKIi9AF5g+RnOllNM4uNBT0enC7ydm9G1CDJUWvL9F6pnNDJKMRY/WYUtUHYIyJBoMB7Hbcdjvo9RhjireuPjRUXQNwJCWVQyeSZ5xdaPFCJMjfJeXwBEgYDHnnABjeoTaz7+5MVJCZPacyGfTJalbulRFh5U5RoM3N6oi7K8ao9/33NXzSHjbPkmkEl9AgtAGzBs5ibPOxAMzbM4/hPw0nPile28KEEEIIkUdCJCGE7zL6q1/P60TSB3nG2aWXQYgE6rx0gD2L4cy+sllTCCGEEOI8ugA1RHJavePsMoDSj7NT9Hr01dT9fXK27wDOhholYapzNkQy1a5d4jDqXIrBgLF69by/G2vUQDEai7eGTpcXGDlOny6DPZFCgbMhkisvRAot9lpnQ6RTeaPsDJGRKLr8b+nb1Q3jp/u70aZ2KGnZdsZOX89Xaw6VqH5RTP7VYNDHcMdSiIpT31v8cC/MuA4Sd2tdXYVm0pt4tP2jTO07lSi/KA6lH2L0r6OZvn06LrdL6/KEEEIInychkhDCdxkt6lfvnkhZaieSJVj9oOBkWhmNUYhoCI2vVW//+2nZrCmEEEIIcZ68TqSMTNwuF668TqSgUq/tDVJy4tXugFKFSOd0IlmaNy9dYeeue87eSqbzRuYVVV6IlJiYt/9QmXUipXpCpGKOs7ugrkuM2YsOtjDn7s7ceEVNnC43z/+4g+d+2I7DKR/GXxZ1Oqt7JV39kjo+O2EVfNYV/njl7PsOUaAuNbqwYNAC+tTpg8Pl4L2N73H3srs5aT2pdWlCCCGET5MQSQjhu7zj7Lwztz2dSP5h6uiTE6ll1IkE0GWi+nXzLJmPLoQQQohyoQ8MANRxdi6rNW+MVklCi/PlhUg7vJ1IJQtWACxNm6BY1It5Ant0L3VtXqYGDc+eo1nTEq1RYFhTZiGS+lVXmhDpdGK+TqSLsRj1vHtTa57s3xRFga//TeD2GetJy7YX+9yiBPRG6PYQTFwLjfuDyw4r34FPO8P+5VpXV6GFWkJ5v9f7vHTlS/gZ/Fh7ci1DFw3lj4Q/tC5NCCGE8FkSIgkhfJfBuyeSJ0SyqiFSaLgaImXkOsjIKaM32rHdIKaVGlht+LJs1hRCCCGEOIe348hlteZ1IWE0ojObS7220TtOzROsGEvRiaQPDaXuzBlUf/MNgq+7rtS1eQV06ph32++KK0q0hjccs587zq6Er1Xv7QyzWoFz9kQqSYgU5Q23TmP3hEiX+h4oisK9vRow+ZZ2+Bn1rNx7hhs+Xc2hM9Zin1+UUFhduHkOjPgWgmtCyiH4egj8MEEuLCuEoijc2OhG5l03j+bhzUm3pfPQiod4c92b2Jw2rcsTQgghfI6ESEII3+XtRLJ7Oo48nUiW0CiCLeqmyqfKal8kRYEu96m3100FR27ZrCuEEEII4eEdZwfgSFJ/r9H5+ZXJ2ud3vZRmnB2AX+vWhA4ZgqIopVrnXIG9ehE6YgShI0YQ2KNHidbwvk770WN5QVxJO5F0AWpnWN4eVXkhUmiJ6ypJh1S/5jHMv7cL1UMsHEi0MnjSav7Zf6bYNYgSUhRodp3aldTxbkCBzd/CpI6wfUFex6C4UGxILF8P+JqxzccC8O3Ob7l18a0cST+ibWFCCCGEj5EQSQjhu/JCpCz1zZsnRMI/nOoh6mMn0spwpF3zGyCoOmSegu0Ly25dIYQQQghAMZnAaATOdgyVWYh0XmhU0mClPCkGA9VfepHqL72IoteXaI3z935STCZ0wcElWkt3TieS2+XClamGSfrg4u9RldcJdu44u2IEec1rhPDjfV1pUzuUtGw7Y75Yx+x1h4tdhygFcxAMeBvuXAaRTcGaCPPvgNkjIe2o1tVVWEadkUfbP8qkPpMIMYcQnxTP8J+Hs/TQUq1LE0IIIXyGhEhCCN9lPGecnSNHnVUOYAkhJkSd01+m+yIZTJ6rD4E1k+SqQyGEEEKUKUVR0Hu6X7xBQ7mFSDExZbJuReMNa2z79wNqqFTSbilvJ5Ir04o7O/vs/f7+xV7LG24509KwHz2m3lfMbrCoIAtz7u7MoNY1cLjcPL1wGy//FI/TJb+TXla1O8I9f0OvZ0BnhD1LYFInWPc5uFxaV1dh9ajVg/nXz6dtVFsy7Zk89tdjvPrvq+Q6ZcKDEEIIUd4qfYi0cOFC+vbtS7Vq1VAUhUOHDmldkhCisvDuieTIhhzPvgEoYAygujdEKstOJIB2Y8HoD6e2wcG/y3ZtIYQQQvg8xV/9/caRnJzv76V1fueRsXr1Mlm3orlgbF8pOq50AWc7kVxZWeqdioJSgmBPFxKidpoBtgMHSlybxajnw5FteLRvYwC+XH2QO2euL7t9QEXRGMzQ60kYvwpqdwJbJvz6GEzvD6d3aV1dhRUTEMOX/b7krpZ3ATB391xu+fUWEtITNK5MCCGEqNoqfYhktVrp0aMHL7/8stalCCEqm3M7kXIz1NvmINDp8sbZnUzPvsiTS8i/GrQZrd5e/UHZri2EEEIIn6ezqL/DOJNT1L/7Fb/rpSDndr0ofn759l+qSso2RFL/7d05OTg9+yvp/P1L1NmkKEqZ1aYoCvf3acSno6/AYtSxYnciN376D0eSs0q0niiFqKZw+xIY8A6YAuHIWpjcDf5+G5wS7BXEoDPw4BUPMvnqyYSZw9iVvIvhPw3nt4TftC5NCCGEqLIqfYh066238txzz9GrVy+tSxFCVDb5QiRPJ5JZnXlfbp1IAFfeB4oe9v8JxzaW/fpCCCGE8FmKxQyAM8UbIpVRJ1J4+NlzKEqJR7xVdPpq1eCc/ZRKEyJ5RwvCOXtUlWCUXYG1GI3ow8JKvBbAgJbVmXdPF6KDzew9nckNn65m85HUUq0pSkCng47jYOJaaNxfHbH956swrQ+c3K51dRVW15pdmT9oPu2i25HlyOKRFY/w3ob3cLgcWpcmhBBCVDklCpG++eYb7rnnHtq3b4/ZbEZRFGbMmFHoc9avX8+AAQMIDQ0lICCAzp07M2/evJKcXgghysbFOpGgfPZE8gqLhVbD1dsr3yv79YUQQgjhs/I6kVLLNkRSDAb82rcDIKjv1WWyZkWk6PX5AjNDVMlDJMVkQjGroV7eHlWlCZHO6QYzRkai6Ep/TWirWqH8OLEbcdWDOZNpY+TUNSzZfrLU64oSCKkFN8+BGz8HSyic2AJTe8Ffb0lX0kVE+Ucx7Zpp3BZ3GwDTd0xn/G/jScpO0rgyIYQQomop0W+d//d//8fUqVNJSEigehFmYS9fvpyuXbuyatUqhg8fzvjx4zl58iQjRozg3XffLUkJQghReoWESGc7kcp4nJ1Xt0cABXb9DKd2lM85hBBCCOFzdJ5OJId3nF0Z7YkEUP2FFwgbdTORDzxQZmtWROd2/Biioku1ls7TjZQXIp3TnVS6uqIKObJ4YkIszBvfhd5NIsmxu7j3241MW3kAt9tdZucQRaQo6sVmE9dBk4FqV9Ly1+Dz3nBiq9bVVUgGnYHHOjzGOz3fwc/gx9qTaxnx8wi2Jsq/lxBCCFFWShQiTZs2jUOHDpGYmMj48eMLPdbhcDBu3Dh0Oh1///03U6dO5d1332XLli00btyYZ555hoSE/JsgPvXUU3kjEi72RwghSs3g+VDFcWGIFBWshkjpOQ5y7M6yP3dkY4gbpN6WbiQhhBBClBHF24nkGWenlFEnEoC5USNinn8eY82aZbZmRZSv46d6TKnW8u4dZS+LTqRzQ6To0oVb5ws0G/h8THtu6VwHtxte/WUnLyzagcPpKtPziCIKioaR38LQL8AvDE5uU4Ok5W+Aw6Z1dRVSv9h+zB44m9jgWE5lnWLskrHM2z1PwlAhhBCiDJQoRLr66qupW7dukY79888/2b9/P6NGjaJNmzZ594eEhPDMM89gs9mYOXNmvuc8+uij7Ny5s9A/QghRakY1KCqoEynYYsBsUP8TmZiRWz7n7/6Y+nXHQkjaXz7nEEIIIYRP8XYiOZOT1b/7lTy08FWGiIizt2NKGSJZ1N83nWfU8Vql6kQ6N9wqZV0Frq/X8crgFjw7oBmKAl+tSeCerzdizZU9ZjShKNBymNqV1Ox6cDngrzc9XUlbtK6uQmoQ2oDZA2dzdZ2rsbvsvPLvKzy3+jlyHOUwolwIIYTwIaUfonwJK1asAOCaa6654LF+/foB8Ndff+W7PzIykqZNmxb6pzRyc3NJT0/P+5ORkVGq9YQQlZRe/ZAFpw1y09XbnhBJURSigtXHT2eU05uO6q2gUT9wu2CVdCMJIYQQovS8nUhuu7qHSlntieRL9GFhebdNtWqVai3FEyI5UjyhXkDJQz1jzNnuI1PDBqWq62IURWFcj/p8OuoKzAYdf+w6zfApaziVLh/CayYwCoZ/DcOmg384nNoOU3vDn69JV1IBAk2BvNfrPR664iF0io4f9//ImMVjOJZ5TOvShBBCiEqr3EOkvXv3AtCoUaMLHouJiSEwMDDvmJJITk5m8+bN7N69G4D4+Hg2b95MsufKu4K88cYbhISE5P2Ji4sr8fmFEJWYoaAQKTjv4agg9U3/qfRy6kQC6OHpRtoyB1KPlN95hBBCCOETvJ1IXoqfRaNKKq+QIYMx1q5NtTvuQDEaS7WWzuztDPOMFyzFODu/tm0xREWh8/cn8MorS1XXpVzbsjqz7+5MeICJHcfTuWHSanadTC/Xc4pCKAq0uBEmrIW4IeB2wt9vwdResldSARRF4c6WdzKl7xTCzGHsTN7JzT/fzMZTG7UuTQghhLjsFi5cSN++falWrRqKonDo0KFir1HuIVJaWhqgjq8rSHBwcN4xJbFo0SLatm3LsGHDABg4cCBt27Zl0aJFF33O008/TVpaWt6f+Pj4Ep9fCFGJ6T0fCjhs54yzC8x7OCrI04lUnlde1u4I9Xqo4ylWvlt+5xFCCCGET/B2InnpTCaNKqm8zPXr0/C3ZUQ/8Xip1/J2IuWNFyxFiKTz86PewgXUX/zrZdmX6oo6YXw/oSv1IwM4npbDsM/W8PeexHI/ryhEYCQMnwk3zQD/CDi9Az6/Cla9D65y2Me1kutcvTNzr5tLs2rNSMlN4a5ld7FgzwKtyxJCCCEuK6vVSo8ePXj55ZdLvEa5h0jlbezYsbjd7gv+jB079qLPMZvNBAcH5/0JCgq6fAULISqOc8fZ2bPV28azb+yjg9U3/afLa08kr15Pq1//+xpSDpXvuYQQQghRpV3QiSQhkqby9qjyXDipM5euM8wQEYExOvrSB5aROuH+LLz3SjrVq0ZmroPbZ6znuw3SPa+55jfAxLXQ9Dpw2eH3F2HGQHkvUYDqgdWZ0X8G19S9BofLwYtrXuR/6/6HwyV7fQkhhPANt956K8899xy9evUq8RrlHiJ5O5Au1m2Unp5+0S4lIYQoV95OpHwh0tmrdyO9nUjlHSLVvRLq91a7kf5+u3zPJYQQQogqTTkvpCjtODZROt7vh9tm8/zdXNjhFVKov4mv7uzIkDY1cLrcPD5/K5OW78Ptdmtdmm8LiIAR38DgSWAKgsNr4LOusOlrkO9NPv5Gf97p+Q4T2kwA4Jud3zDh9wmk5ZZ8Ko4QQghRmG+++YZ77rmH9u3bYzabURSFGTNmFPqc9evXM2DAAEJDQwkICKBz587Mmzfv8hR8CeUeInn3Qipo36OTJ0+SmZlZ4H5JQghR7s7dE8nhGVlnOPvBS9TlCpEAej+rft08G5L2l//5hBBCCFElSSdSxaJUke+H2aDnveFtuKdnfQDeXrqb53/cgdMlYYWmFAXa3gL3roI6XcCWCYvugzmjIVNGD55LURTubX0v7/V6Dz+DH2tOrGH0r6M5mHZQ69KEEEJUQf/3f//H1KlTSUhIoHr16pc8fvny5XTt2pVVq1YxfPhwxo8fz8mTJxkxYgTvvqv99hflHiL17NkTgGXLll3w2NKlS/MdI4QQl5Xe8ybekVtgJ1KUd5xdee6J5FW7AzS6Rt0k96//lf/5hBBCCFE1ndd5VFlDi6ri/PF1OnPl/X7odApPX9uMF66PQ1Hg638TmPjtJnLssheP5sJiYewvcPVLoDPC7l/gsy6we7HWlVU4fev25atrvyImIIaE9ARG/zKaVcdWaV2WEEKIKmbatGkcOnSIxMRExo8fX+ixDoeDcePGodPp+Pvvv5k6dSrvvvsuW7ZsoXHjxjzzzDMkJCTke85TTz2FoiiF/ilL5R4i9enTh/r16zNr1iw2b96cd39aWhqvv/46JpOJMWPGlHcZBZo0aRJxcXGlmgcohKjEvCGS0659JxJA72fUr9u+g8Tdl+ecQgghhKhSzh9fJ+PstHVBJ1IlHGd3vtu71uPjm9ti0utYsuMkt36xlrQsu9ZlCZ0euj0Edy+HqDiwJsLskbDoAcjN1Lq6CqVptabMHjibNpFtyLBnMPGPiczbXTHGBQkhhKgarr76aurWrVukY//880/279/PqFGjaNOmTd79ISEhPPPMM9hsNmbOnJnvOY8++ig7d+4s9E9ZMpTkSdOmTWPVKvVKjW3btuXdt2LFCgC6devGXXfdpZ7AYGDatGn069ePHj16MHLkSIKCgliwYAEJCQm88847xMbGlv6VlMDEiROZOHEiR48epXbt2prUIITQUN44u4I7kaI9nUjJVhs2hwuToZxz9xptoclA9crBFW/CTdPL93xCCCGEqHIuCJGkE0lT53ciKcaq8f24rlUNwgPM3P31BtYfSmHY5H+YcUdHaob6XfrJonzFtIRxy+HPV2DNJNg0Ew7+BTdMhTqdtK6uwojwi+CLfl/w0pqXWLR/Ea/8+wpHM47yULuH0Cnlfr21EEKISiYjI4P09PS8v5vNZsxldHGQN1O55pprLnisX79+APz111/57o+MjCQyMrJMzl8UJfp/xlWrVjFz5kxmzpzJpk2bAFi9enXefd6Ayat3796sWrWKrl27MnfuXD777DOio6OZM2cOjz76aOlfhRBClERBnUjnhEhh/kaMerX980zm5epGelr9uuN7OLXj8pxTCCGEEFXGhZ1IVSO0qKwUy3khUhXoRPLq0iCc78Z3ITrYzN7TmQz99B92nUy/9BNF+TNaoN9rcNtPEFIbUg7B9P7wx8vgsGldXYVh0pt4teurTGwzEYDpO6bz+F+Pk+O4DOPMhRBCVCpxcXGEhITk/XnjjTfKbO29e/cC0KhRowsei4mJITAwMO+YkkhOTmbz5s3s3q1OPYqPj2fz5s0kJycXeY0ShUgzZszA7XZf9M+MGTMueE7Hjh1ZvHgxaWlpZGVlsXbtWkaMGFGS0wshRNnItyeSd5zd2RBJURQiAy/zSLuYlhA3GHDDirL7PyQhhBBC+AbFcH4nkoyz05LugnF2VSvUaxoTzMIJXWkUFcjJ9BxumryGNfuTtC5LeNXrDveuhlYjwe2Cle/CtD5wepfWlVUYiqIwvvV4Xu/2OgadgWUJyxi3bBwpOSlalyaEEKICiY+PJy0tLe/P008/XWZrp6WlAer4uoIEBwfnHVMSixYtom3btgwbNgyAgQMH0rZtWxYtWlTkNaRHVwjhu/I6kWzg8I6zy3+1aKRnpN2p9Mt4NVqvpwEFdv4EJ7ZcvvMKIYQQotKTTqSKRTlvnJ2uCo4XrBnqx3fju9AhNoyMHAe3fbmOX7ae0Los4WUJgRunwE0zwS8MTm6FKT1g3efgdmtdXYVxfYPrmXL1FIKMQWxO3Mwtv95CQnrCpZ8ohBDCJwQFBREcHJz3p6xG2V0OY8eOLbARaOzYsUVeQ0IkIYTvMnjfxLvPbjZryP9GPyroMnciAUQ1g5bq1QH8+erlO68QQgghKr0L90SSTiQtXfD9qEQfOBRHqL+Jr+/sRL/m0dicLu6bvYkZqw9qXZY4V/MhMOFfaHi1uifsr4/B3Fsgq+ijbKq6jtU78s2Ab6gZWJPDGYe55ddb+O/0f1qXJYQQoorzdiBdrNsoPT39ol1Kl4uESEII36U/50rQvE6k/JsBRwerb/QTL2cnEqjdSDoD7F0Gh1Zf3nMLIYQQotJSjIbz/l71Ol8qkwtDvaoZIgFYjHo+Hd2OWzvXxe2GF3+K5/3f9uCWbpeKIygGRs+Hfq+Dzgi7fobJ3SFhjdaVVRj1Q+vzzYBvaBHegtTcVO5aehe/JfymdVlCCCGqMO9eSAXte3Ty5EkyMzML3C/pcvLpEGnSpEnExcXRq1cvrUsRQmhBX8Cb+As6kdS/X9ZOJIDwBnDFGPX27y/KqAkhhBBCFMmFoYWESFq6INSr4t8PvU7h5cHNefjqxgB8+MdeXly0A5dLfpetMBQFukyEu36DavUh/SjMGAB/vQUup9bVVQgRfhF82f9Letfujc1l49EVjzJv9zytyxJCCFFF9ezZE4Bly5Zd8NjSpUvzHaMVnw6RJk6cSHx8PCtWrNC6FCGEFnR6QMl/33mdSN5xdpd1TySvHk+AwQ+OroPdv17+8wshhBCi0rlwTyQZZ6el8//9deaqHSIBKIrCg1c34qVBzQGYuSaBR+Ztxu50aVyZyKdGW7jnb2g1EtwuWP4afDUY0o9rXVmF4Gfw4/1e7zOs8TDcuHnl31f4bMtn0lknhBCizPXp04f69esza9YsNm/enHd/Wloar7/+OiaTiTFjxmhXIGC49CFCCFFFKQoYzOA4JyA6rxMp0hMiJWZe5k4kgODq0PleWPUe/PEyNO7vCb6EEEIIIQqmGM7vfJEQSUsXfD+q6J5IBbntylhC/Y08Om8LP2w+TnqOg0mjrsDPJL/PVhjmILhxCtTvBb88CodWwmddYchn0KS/1tVpTq/T83zn5wm3hDNl6xQ+3fwpydnJPNXxKfTyvkwIIUQhpk2bxqpVqwDYtm1b3n3eZpZu3bpx1113AWAwGJg2bRr9+vWjR48ejBw5kqCgIBYsWEBCQgLvvPMOsbGxWryMPD7diSSEEPlG2unNarB0johA9fHkTNvlrOqsrg+CJRQSd8GWOdrUIIQQQohK48JOpKrf+VKR+fp4wcFtajJ1TDvMBh1/7jrNmC/XkpZt17oscb42N6tdSTGtIDsZZo+AxU+BQ4ML6SoYRVG4r+19PN3xaRQU5uyew5Mrn8Tm1Oj9oRBCCE307t2buLg4Jk2aVKTjV61axcyZM5k5cyabNm0CYPXq1Xn3eQOmc9dftWoVXbt2Ze7cuXz22WdER0czZ84cHn300TJ/PcUlIZIQwrfpz7k6VH/hm/rwQPW+M1abNqML/EKh+yPq7eWvg12DsXpCCCGEqDQuDJFk+ISmzu9E8sHxglc1jeabuzoRZDGw/lAKI6f+S+Ll3m9UXFpEQ7jrd+g8Qf372s9g2tVwZp+2dVUQo5qN4q0eb2HQGVh6aCkT/piA1W7VuiwhhBCXyfLly4mPj2fixIlFOn7GjBm43e6L/pkxY8YFz+nYsSOLFy8mLS2NrKws1q5dy4gRI8r4lZSMhEhCCN+mO+eNvP7CD1m8nUg2h4uMXMflqiq/jndDcE1109v107SpQQghhBCVg+G8EEknb/m0dEGop/fNEVgdYqsx9+4uRASa2XkinZsm/8OR5CytyxLnM5ih/xtw81zwqwYnt8KUHrB5ttaVVQj96/Xn0z6f4m/wZ+2JtYxbNo603DStyxJCCCHKnbyjEEL4Nl3hnUgWo55As3pMklYj7Yx+0Osp9fbKdyFH3qgIIYQQomAXdLr4aGhRUSjnhXrndyb5krgawcwf34VaYX4cSspi2OR/2HMqQ+uyREGa9Id7V0PdbmC3wg/jYeHdkCvfry41uvBlvy8JMYew7cw27lp2F8k5yVqXJYQQQpQrnw6RJk2aRFxcHL169dK6FCGEVs7dEFVX8HgR70i7pEwNx260HgURjdUZ5as/0q4OIYQQQlRoikk6kSqSC74fPh7qxUYEMH/8lTSODuRUei7Dp6zhv8MpWpclChJcA25bBL2fBUUHW+eqXUnHN2tdmeaaRzTny35fEm4JZ1fyLu5YcgeJWYlalyWEEEKUG59+RzFx4kTi4+NZsWKF1qUIIbSSrxOp4CtDwwM8+yJp1YkEam19nldvr5kEace0q0UIIYQQFZaERhWLcn7nkY+HSAAxIRbm3dOFNrVDSc2yM3raWv7Zf0brskRBdHro+QSM/RWCa0HyAfiiL6z/ArTYL7YCaRzWmOn9pxPlH8X+tP2MXTKWE5kntC5LCCGEKBfyDkMI4dsuMc4OINyzL9IZLTuRAJpeB3W6gCMblr+mbS1CCCGEqJgkpKhQZE+kgoX6m/j2rk50bxRBls3J7dPXs3z3aa3LEhdTtwuMXwlNBoLTBr88oo63s1m1rkxT9ULqMbP/TGoG1uRwxmFuW3IbR9KPaF2WEEIIUeYkRBJC+LZzQ6SLjLOL8IRImu2J5KUocM2r6u3Ns+DEVm3rEUIIIUSFI51IFcsFe1TJ9ydPgNnA52Pac3WzKHIdLu7+agNLtksnR4XlXw1Gfgt9XwZFD9vmwed94MxerSvTVK2gWszoP4PY4FhOWE9w25LbOJB6QOuyhBBCiDIlv8EKIXzbuXsiXWScXYR3TySrxp1IALXaQ4uhgBuWPevzYySEEEIIcR7pdKlQ8o2z0+tRFEW7Yiogi1HPZ7e0Y2Cr6tidbibO+o8fN8vY5gpLUaDrg3DbTxAYDYk7YWov2L5Q68o0FRMQw/T+02kY2pDE7ERuX3o7u5N3a12WEEIIUWYkRBJC+LaijLPz7ImkeSeSV58XQG+Gg3/DnqVaVyOEEEKICkQ6kSqWfJ1IEiAVyKjX8dHItgy9ohZOl5uH5m5mzrrDWpclChPbFe5ZCbHdwZYJ82+HxU+Co4K8X9JAhF8E0/tNJy48juScZO5YegfxSfFalyWEEKKM9O7dm7i4OCZNmqR1KZqQdxhCCN9WhHF2FWZPJK+wutB5vHr7t+fAade2HiGEEEJUHNKJVLEYzvn90uXSro4KTq9TeHtYK27pXAe3G55auI3pqw9qXZYoTFA03PoDdHtY/fvayTBjAKQd1bQsLYVaQpl2zTRaR7Ym3ZbOuGXj2JW8S+uyhBBClIHly5cTHx/PxIkTtS5FExIiCSF8W75OpIuNs6tgIRJA90fBPxzO7IGNM7SuRgghhBAVRL5OJOl80ZxiPOf3SwmRCqXTKbwyuAV396gPwEs/xTNp+T6NqxKF0hvg6hdh5Gwwh8DR9TC5O+z7Q+vKNBNkCmLy1ZNpFdkqL0iS0XZCCCEqOwmRhBC+Ld+eSAWPszu7J1IFGs9gCYFeT6u3V7wBOWna1iOEEEIIIS6gSGdYsSiKwtPXNuXBPo0AeHvpbt5dthu37ANasTUdAPf8BdVbQ3YyfDMUVrzps8FpoCmQyVdPpmVES1JzUxm3bBx7U/ZqXZYQQghRYj4dIk2aNIm4uDh69eqldSlCCK0UY5xdapYdu7MCvRFqNxbCG0FWEqx6X+tqhBBCCCHE+c4NkaQzrEgUReHhvo156tqmAHz85z5e/WWnBEkVXbV6cMcyuOI2wK1e6PbtMLAmaV2ZJoJMQUzuO5m48DhSclO4a9ld7E/dr3VZQgghRIn4dIg0ceJE4uPjWbFihdalCCG0UoRxdqF+RvQ69U1/SkXqRtIb4ZpX1NtrPoVU2YBYCCGEEOeQ0EJzinwPSmx8zwa8PLg5AF+sOsizP2zH5ZIgqUIzWmDQRzDkMzD4wf4/YGovOLFV68o0EWwKZmrfqTSr1ozknGTuXHonB9IOaF2WEEIIUWw+HSIJIUT+EKngcXY6nUK1APWxxIq0LxJA4/4Q2x2cufDHy1pXI4QQQgghRJkZ0yWWt4a2QlFg1trDPL1wmwRJlUGbUXDX7xBWD9IOwxfXwLb5WleliRBzCFP7TqVJWBOScpIYt3QcRzOOal2WEEIIUSwSIgkhfNu5eyJdZJwdQLgnRErKrECdSKBeYdzvNUCBbd/B0Y1aVySEEEIIIUSZGd6hNh+MaINOgbkbjvDEgq04JUiq+GJawN3LoUEfcGTDgjvht+fB5dS6sssu1BLK59d8ToOQBpzOPs24ZeM4nXVa67KEEEKIIpMQSQjh24owzg4gwrMvUpK1gnUigbqBbeub1dtLnwaZFy+EEEIIkHF2FY18P0pscJuafDCyLXqdwvyNR3l8/hYJkioDvzAY/R10fUj9++oP4dubIDtF07K0EGYJY+o1U6kVWIujmUe5e9ndpOT43r+DEEKIyklCJCGEbyvCODuA8ED1sTMZFawTyavPc2AMgCNr1Y4kIYQQQgghqpBBrWvw4cg26HUKCzcd47HvJEiqFHR66PsSDP3inH2SesPpnVpXdtlF+Ufx+TWfE+UXxf60/Yz/fTyZtkytyxJCCCEuSUIkIYRvOzdEKmScnXdPpJSsChoiBdeA7o+ot397HnLlzYgQQgghhKharmtVg49vbotBp/D9f8d4ZN5mHE6X1mWJomg5DO5cBiF1IOUgfN4H4hdpXdVlVyuoFp9f8zlh5jDik+K578/7yHZka12WEEKIS+jduzdxcXFMmjRJ61I0ISGSEMK35etEuniIFOZfwUMkgC73QVgsZJyAle9qXY0QQgghtCbj0yoW+X6UiQEtq/PJKDVI+nHzcR6et0WCpMqieiu4ewXEdge7FebdCn++Bi7f+v7VD63P5L6TCTQGsvHURh5d8Sh2l13rsoQQQhRi+fLlxMfHM3HiRK1L0YSESEII36bTn72tXPw/iWHeTiRrBf7l3miBfm+ot9d8AskHtK1HCCGEEEKIctC/RXUmjb4Co17hpy3HeXCudCRVGgHhcOsP0HmC+ve/34I5N0NOmqZlXW5x4XFM6jMJi97CymMreemfl3DL3rZCCCEqKAmRhBC+Ld84O/1FDwvzV7uUkityJxJAk2uhwVXgtMHSZ7WuRgghhBBCiHLRr3kMn45uh1Gv8MvWEzww5z/sEiRVDnoD9H8DhkwGvRn2LFHH253Zq3Vll9UV0VfwTs930Ct6ftz/Ix//97HWJQkhhBAFkhBJCOHbzg2RlIuHSNU84+xSK3qIpCjQ/031de3+Ffb9rnVFQgghhNCKjE8TVVzfuGgm39IOk17Hr9tOcv8sCZIqlTY3wx1LILgmJO2Fz6+CPcu0ruqy6lm7J891fg6Az7d9ztxdczWuSAghhLiQT4dIkyZNIi4ujl69emldihBCK/lCpIv/JzHUEyIlV+Rxdl6RTaDjPertxU+Bo4IHX0IIIYQoHzIaSfiAPs2imXKrGiQt2XGS+2ZtkiCpMql5hbpPUp0ukJsOs0fA2ilaV3VZDW08lAlt1PF+r619jT8S/tC4IiGEECI/nw6RJk6cSHx8PCtWrNC6FCGEVs69QreQcXbVAs52IlWKWdW9noSASPWKvnVTta5GCCGEEFqoDL+z+BLpDCs3vZtGMXVMO0wGHUt3nOKhObJHUqUSGAVjFkHbW8HtgsVPwC+PgdOhdWWXzfhW4xnaaChu3Dzx9xNsOrVJ65KEEEKIPD4dIgkhRL7gqJBxdqGePZEcLjcZuZXgzYwlBPo8r97+63+QeVrbeoQQQghx+bnkQ3ThO3o1icrrSPpl2wke/W4LTpcEqZWGwQSDPoarXwIUWP+52pWUk651ZZeFoij8X+f/o1ftXthcNu7/8372p+7XuiwhhBACkBBJCOHrzh1hV8jVoRajHn+TGjKlWCvJeLg2t0CNtupYiD9e0roaIYQQQgghylXvJlFMGn0FBp3Cj5uP88T8rbgkSKo8FAW6PQQjvgaDn7q/6xfXQEqC1pVdFgadgbd6vEXryNak29IZ//t4ErMStS5LCCGEkBBJCOHjzg2RChlnBxDm2RcpJasS7IsEoNPBtW+pt//7Fo5t1LYeIYQQQgghylnfuGg+vrktep3Cgk1HefaHbRIkVTbNroc7FkNgDCTuhGl94Mh6rau6LPwMfnxy1SfEBsdy0nqSh5Y/RK4zV+uyhBBC+DgJkYQQvi1fJ1Lh/0kMC1BH2lWaTiSA2h2h1UjADYuflLE2QgghhC+RPZGEj7q2ZXXeG94anQKz1x3hxZ92VI59TcVZNdrCuD8hpiVYE2HGQNg2X+uqLotQSyif9PmEYFMwW89s5aV/XpKfXyGEEJqSEEkI4dvyhUhF7USqRCESwNUvgjEAjq6HLbO1rkYIIYQQwjcVMjpZlL3BbWry9rDWKAp8tSaBV37eKR/EVzYhNeH2JdD4WnDmwoI74a+3fCIgrxtcl3d7vYte0fPTgZ+YvmO61iUJIYTwYRIiCSF8WwnG2SVXpk4kgODq0PMJ9fZvz0NWsrb1CCGEEEIIcRkMbVeLN29sCcCXqw/y5pJdEiRVNuZAGPktdLlP/fvy1+D7e8BR9Ue8da7emac6PgXABxs/YMWRFZrWI4QQvqx3797ExcUxadIkrUvRhIRIQgjfVpxxdv7qOLvUyrIn0rk6T4DIppB1Bv58RetqhBBCCHE5SOeLEIzoUIdXhrQAYMpfB3j/tz0aVySKTaeHfq/BdR+o0yO2zoWZg8B6RuvKyt3IpiMZ3ng4btw8+feT7E3Zq3VJQgjhk5YvX058fDwTJ07UuhRNSIgkhPBt5364csk9kTydSJVtnB2AwQQD31Vvb5gORzdqW48QQgghhI+RSE87t3auy/PXxQHw0Z/7+PgP+SC+Ump/O9yyAMwhcORfmNYHEndrXVW5e6rTU3SI6UCWI4v7/7yflJwUrUsSQgjhYyREEkL4thKMs0utjCESQGw3aDUScMMvD4PLqXVFQgghhBBCXBZ3dKvHMwOaAvDub3uY/Nd+jSsSJdKgN9z1G4TFQsohmNYX9i/XuqpyZdQZea/ne9QKrMWxzGM8suIR7M5KOB1DCCFEpSUhkhDCtxVnnF1AJd0T6VzXvKJeuXdiC2z4UutqhBBCCFGeZO8XIfK5u0cDHu/XBIA3F+9i5j+HtC1IlExkE7jrD6jdGXLT4Juh6rSFKizUEsonfT4hwBjAhlMbeG3ta7K/lxBCiMtGQiQhhG/LFyJdqhOpEu+J5BUYBX2eU2//8QpknNK2HiGEEEIIXyF7VFUIE3s35IE+jQB4YdEO5m88qnFFokQCIuC2RdBqBLid8PNDsPTZKj1toUFoA97q8RYKCgv2LmDWrllalySEEMJH+HSINGnSJOLi4ujVq5fWpQghtHJucFTEcXaVuhMJoP0dUL2NetXeb89pXY0QQgghhBCX1cNXN+KOrvUAeGL+FhZvO6FxRaJEDGa4YQr0flb9+5pPYO4tkJupbV3lqEetHjzS7hEA3lr/Fv8c+0fjioQQQvgCnw6RJk6cSHx8PCtWrNC6FCGEVkowzi41y165Rwfo9HDde4ACW+fCwZVaVySEEEIIUfVJJ1KFoSgKz13XjOHta+FywwNz/uOvPYlalyVKQlGg5xMw7EvQm2H3rzC9P6Qd07qycnNb89sY1GAQLreLx/56jIT0BK1LEkIIUcX5dIgkhBAlGWdnc7qw2ir5mISa7dSOJIBfHgVHJe+uEkIIIYSo6HTy9rsiURSFN25sxcCW1bE73dzz9QbWH0rWuixRUi2GwthfICASTm6DL/rCqXitqyoXiqLwQpcXaB3Zmgx7Bo+seIRsR7bWZQkhhKjC5LdYIYRvyxciFX51qJ9Rj1GvHpOWXYn3RfLq8xz4R8CZ3fDvJK2rEUIIIYSo0hTpRKpw9DqF90e0oVeTSHLsLu6Yvp7tx9K0LkuUVO0OcNcfENEY0o/Bl/2r7NQFk97Ee73eo5qlGntS9vDqv69W7mkZQgghKjQJkYQQvu3cN/OX2BNJURRC/NRupPSqECL5hcE1r6q3/3oLUg9rW48QQgghRFUmnUgVksmg47PR7ehYrxoZuQ7GfLmOfacztC5LlFRYXbhjKdTpou4B+82NsPU7rasqF1H+UbzT8x10io5F+xexYO8CrUsSQghRRclvsUII31aMcXYAwZ4QqUp0IgG0Hgl1rgR7Fix5WutqhBBCCCGqHENUFAABXTprXIm4GD+Tni9ua0+rWiEkW22MnraWI8lZWpclSsq/Gtz6A8QNBqcNFt6lXjRXBTt1OsR04IG2DwDw5ro32Z28W+OKhBBCVEUSIgkhfFu+EOnS/0kMqWohkqLAwHdBZ4BdP8PuJVpXJIQQQghRpdSdNYvIBx+g+iuvaF2KKESQxcjM2zvSODqQU+m5jJ62llPpOVqXJUrKaIFhM+BKNWBh+Wvww4QquRfs7S1up3vN7uQ6c3n878fJsksAKoQQomxJiCSE8G3nBkeXGGcHVTBEAoiOg84T1NuLHwebvOkQQgghhCgrplo1ibj3XvShoVqXIi4hLMDEN3d2ok41fw4nZ3HLtLUkW6te6OAzdDq45hW47n116sSWWTBrONisWldWpnSKjte6vUaUXxQH0w7y2trXtC5JCCFEFSMhkhDCtxV3nJ2lCu2JdK6eT0JwTXVfpL/f1roaIYQQQggh/p+9+w6Pouz6OP7b3VTS6L1XjdIEERUVbNi7gr03Ym+vHaw8dlGjPggCjyiCooBiQSR0RHoLSg81tEB6z75/3Nkkm95nk/1+rmuvzM7OzpxNNrM7c+ac2xLNQwP09T2nqWVogLYeStLtX/6tpPQsq8NCVfS/S7ppmuQbJO2IkiZfK6XFWx1VtWoU0Ehvnf1W3vhIs7bPsjokAKhXhgwZovDwcEVGRlodiiVIIgHwbpVsZ1fvkkj+wdLFb5vppR9JBzdZGw8AAABgkXaNG2jyPaepcZCfNuyL1/1frVR6VrbVYaEqup0v3TZD8g+Tdi+TJl0hpcRZHVW16t+yvx7s/aAk6fW/XtfO+J0WRwQA9UdUVJSio6MVERFhdSiWCsEuJgABAABJREFUIIkEwLu5tbPzwjGRCjrxMumEy6ScLOmnR6WcHKsjAgAAACzRtXmwJt55qoL8HFqy7aiemLpO2TlOq8NCVbQbIN3xk9SgiXRgrTThEilhv9VRVat7e96r01qeptSsVD214CmlZTGuFwCg6kgiAfBuBcdBKkc7u3qdRJJMNZJfiLR3hbRyvNXRAAAAAJbp1bah/ntrf/k6bJq94YBGzdokp5NEUp3Wqrd0xy9SSCvp8GZp/IXS4S1WR1VtHHaHRp81Wo0DGmvLsS16d+W7VocEAKgHSCIB8G6VbGdXb5NIYW2k80ea6bmvSPH7rI0HAAAAsNCgbk31wbA+stmkr/6K0Zg/t1odEqqq+QnSXb9LTbpK8Xuk8RdIuxZbHVW1adagmd4c9KYkaeq/UzVn1xyLIwIA1HUkkQB4N5stf9pediVSaH1PIklm4Nm2p0oZidIvT0tcbQkAAAAvdlmv1nr1ipMkSR/O3aqv/oqxOCJUWaMOJpHU9lQp7bj0v6uk9d9ZHVW1ObPNmbr75LslSSOXjtSexD0WRwQAqMtIIgHwbm7VR7YSF3MJDfSRVM+TSHaHdPlHkt1H+ne2tPknqyMCAAAALHXr6R316HndJEkvz9yo2esPWBwRqiyoqXT7T9KJV0g5mdIP90iL3q83F9FF9I1Q72a9lZSZpBcXv6gcJ2PeAgAqhyQSAO/m1s6u7CSSq51dQlpWTUXkGVqES2c+ZqZ/eVpKi7c0HAAAAMBqj53fTTef1l5Op/TY1DVasu2I1SGhqnwDpesnSQMjzP0/X5FmPyll1/3jPV+7r946+y018Gmg1YdWa8o/U6wOCQBQR3l1EikyMlLh4eEaPHiw1aEAsEoFK5Hq/ZhIBZ39tNS4i5QUK/0x0upoAAAAAEvZbDa9euXJuqRnS2VmO3Xf/1Zqw14utqrz7Hbpojeli/4jySatHC9Nv0vKrvvHfG2C2+iJfk9IksasHqM9CbS1AwBUnFcnkSIiIhQdHa358+dbHQoAq1SyEikjK0dpmdk1FZVn8A2QLh9jpldNkGKWWhsPAAAAYDGH3aYPhvXRGV2aKDkjW3dM+Fs7DidZHRaqw8AHpRsmSQ4/KXqmNO12KSvd6qiq7Poe12tAywFKzUrVy0tfpq0dAKDCvDqJBAAVrUQK9veRw26W84pqpE5nSafcZqZnPSxlplkbDwAAAGAxfx+Hxt7WXye3CdXR5AzdOv5vHUrge3K9EH6lNHyK5PA348NOvaXOHwPZbXaNOmOUAn0CtfLgSk39d6rVIQEA6hiSSAC8XNmJI7elbTaFBvhI8pIkkiRd8JoU3EI6uk1a+I7V0QAAAACWC/b30cQ7B6hjkwbadzxVd01aoaT0uj+ODiR1O1+6eZrkEyhtnSNNGS5lplodVZW0C2mnx055TJL0waoPtDdxr7UBAQDqFJJIALxbwRZ25WhnJ0mhuS3tErwliRTYULrkXTO95EMpdqOV0QAAAAAeoWmwvybdNUBNgvy0cV+CIr5ercxsWoXVC50HS7d8L/kGSTuiTCIpI8XqqKpk+AnD1a9FP6VmpWrU0lFyOp1WhwQAqCNIIgFAnvIlkVzjInlNJZIkhV8hnXi5lJMlzXpIyqnn40EBAAAA5dChSZC+vONUBfo6tGDLYb3w4wZOztcXHQcVSCTNr/OJJLvNrlfPeFUBjgAtj12u77Z8Z3VIAIA6giQSAC9nK3ayNF6ZRJKki9+R/MOk/Wukvz6zOhoAAADAI/Ru11Cf3NRXdps0beVejflzq9Uhobp0OEO6ZbrkFyztXCBNGVanE0ntQ9vr0VMelSS9t/I97U/ab3FEAIC6gCQSAO/m1sKuYu3svC6JFNpKuvA1Mz3vdSluh7XxAAAAAB7ivBNb6LWrTpYkfTh3q6at2GNxRKg2HU4vkEhaKH1zg5SeZHVUlXbTiTepb/O+SslKoa0dAKBcSCIB8HLlLD8qIDTANSaSFw6ce8ptUsezpKxUaebDUg493wEAAABJuvm0DnpoSFdJ0nM/btD8fw9ZHBGqTfuB0i0/SH4h0q5F0qTLpeQjVkdVKa62dv4Ofy07sEw/7fjJ6pAAAB6OJBIA7+ZWiFS+hFJIgI8kKTnDC5NINpt05SemL3jMYmnleKsjAgAAADzGkxd21zV92yg7x6kRX6/Wxn3xVoeE6tL+NOm2mVJgY2n/amn8hdKxGKujqpSOYR31YO8HJUlvr3hbR1OPWhwRAHi2IUOGKDw8XJGRkVaHYgmSSAC8XMXb2YX4myRSYpqXtbNzadRROn+Umf5jZJ09cAIAAACqm81m03+u7aUzuzZRSka27piwQnvi6u4YOiikbT/p7jlSWHspbrtJJMVutDqqSrntpNt0QuMTFJ8erzGrx1gdDgB4tKioKEVHRysiIsLqUCxBEgkAXMpZiRQc4EoieWElksup90jtz5Ayk6WfHpXoow0AgOcp53cbANXLz8euz27ppxNahuhIUrpun/C3jqdkWB0WqkvTbiaR1PwkKSlWmnCxtGux1VFVmK/dVy+c9oIk6cdtP2r94fUWRwQA8FQkkQB4N1vFK5GCcyuRktK9OIlkt5u2dj6B0o4oafX/rI4IAAAA8BihAb6aeOcAtQoL0I7Dybpn0kqlZWZbHRaqS2gr6c5fpA5nSukJ0lfXSNGzrI6qwvo076Mru1wpSXpj+RvKzuE9CgAoiiQSAC9X8St0XWMiJXlzJZIkNekinfuimZ7zohS/z9p4AACAOyqRAEu1DAvQxDsHKCTARytjjunxqWuVk0MFf70R2FC65QfphMuk7HRp2m3SinFWR1Vhj/V7TMG+wYo+Gq0ft/1odTgAAA9EEgmAdyt4cqW87ez8fSV5eTs7l4EPSm1PNVff0dYOAAAAcNOjZYjG3tpffg67ft0Yqzd/2Wx1SKhOvgHSDf+T+t0pySnNflJa+K7VUVVI08CmGtFnhCRpzOoxik+PtzgiAICnIYkEwMtVvJ1dXiWSN7ezc7E7pCsjJYeftO0Pad23VkcEAABcqEQCPMLpXZronet7SZLGLd6pb5bvtjgiVCu7Q7rsA2nwc+b+vNek5f+1NqYKGn7CcHVt2FXH04/r4zUfWx0OAMDDkEQC4N0qU4mUm0RKTMusiYjqnmY98g+Yfvs/KWG/tfEAAACDJBLgMa7s00ZPXNBdkvTSzI1avPWIxRGhWtls0uBn84+Lfn1GWjvF2pgqwNfuq+cGmNi/2/Kd/on7x+KIAACehCQSAOQpZyWSf34lkpP2bcYZj0itT5HS4mlrBwAAABTj4XO76uq+bZSd49SDX6/S1oOJVoeE6nbO/0kDTWs4zRwhbf7Z2ngqYECrARracahynDl66++3ONYFAOQhiQTAy1X8Cl1XJVKOU0rJyK7ugOomh4901Wemrd3WOdLar62OCAAAUIkEeBSbzab/XNtTp3ZspMS0LN01aYWOJqVbHRaqk80mXfiG1OcWyZkjfX+XtGux1VGV25P9npS/w18rD67U3N1zrQ4HAOAhSCIB8G6VaGcX6OuQw26WZVykApqfIA15wUz/9pwUv9faeAAAAAAP4+/j0H9v7a/2jRtoT1yq7vtqldIyuTCtXrHbpcvHSD0ulbLTpSk3SrEbrI6qXFoFt9IdJ90hSXpv5XtKzybJCQAgiQTA69lKmC7lGTabgv1d4yKRRHJzxsNS21Ol9ARp1sO0tQMAwELUIQGeqXGQn76841SFBvhoVcwxPfP9elqH1TcOH+m68VL7M8yx0VfXSHE7rY6qXO46+S41b9Bc+5L26avor6wOBwDgAUgiAfBulahEkpSXRKISqRC7w7S18wmQts+TVk20OiIAAADA43RtHqzPb+knH7tNs9bt14dzt1odEqqbb6B04xSpxclS8iHpq6ulpENWR1WmBr4N9Ngpj0mSvlj/hY6kHrE2IACA5UgiAUCe8ieRQgJclUiZNRVM3dW0m3Tey2b69xekuB3WxgMAAAB4oDO6NtXrV50sSRrz51bNWLPP4ohQ7QIbSrdMlxq2l47tlCZfK6UlWB1VmS7tfKl6Nu2plKwUfbT6I6vDAQBYjCQSAC9XuUYveZVItLMr3mkPSh0GSZnJ0owRUg593gEAqG00xwI83/AB7XX/OZ0lSc98v14rd8VZHBGqXUhL6dYZUlAzKXa9GSMpM9XqqEplt9n1zKnPSJJmbJuhrceolAMAb+bVSaTIyEiFh4dr8ODBVocCwCqVbGeXV4lEO7vi2e3SVZ9KfsHS7mXSskirIwIAAAA80v8NPUFDT2qhjOwc3ffVKu0+mmJ1SKhuTbpIN38v+YdKMYulqbdKWelWR1WqPs376IIOF8gppz5d+6nV4QAALOTVSaSIiAhFR0dr/vz5VocCwDK2EqZLFxzgK4lKpFI16iBdNNpMz3tNOhhtbTwAAACAB7LbbfpgWB/1bBOmuOQM3Tnxb8Wn0ja73mndR7ppquQTKG37Q/pmmJR63OqoShXRJ0I22TR391xtOrrJ6nAAABbx6iQSAFS2EsnVzi6RJFLp+t4qdRsqZWdIP94vZWVYHREAAADgcRr4+Wjc7f3VKixA2w8na8TXq5SZnWN1WKhuHc6QbvpW8m0g7YiSvhgiHdpsdVQl6tKwiy7tfKkkKXIN3SUAwFuRRALg5So3JpKrnV1SOlcIlspmk674WApsbPp/L3zb6ogAAAAAj9QiNEDjbz9VQX4OLdl2VK/+RCV/vdR5sHTXb1JYeyluhzTufOmfX6yOqkQP9n5QDptDi/Yt0tpDa60OBwBgAZJIAFAJIf6uJBKVSGUKaSFd9r6ZXvS+tG+1tfEAAAAAHiq8dajGDO8rm0366q8YTf4rxuqQUBNa9Zbumy91OlvKSJK+vUla/KHkdFodWRHtQ9vrqq5XSZI+WfuJtcEAACxBEgmAd6tsO7sA2tlVyElXSydfKzmzpZkPSZlpVkcEAAAAeKTzw1voqQt7SJJGzdqkv3YctTgi1IigJtItP0in3iPJKc0dKc2MkLLSrY6siPt63Scfu4+WH1iuVQdXWR0OAKCWkUQC4OVsJUyXjjGRKuHit6UGTaVDm6So162OBgAAAPBYIwZ30RW9Wysrx6kHJ6/SnrgUq0NCTXD4Spe+J13yrmRzSGu/lv53lZR63OrI3LQObq2ru14tSfpi/RcWRwMAqG0kkQB4N7ccUsWTSCkZJJHKLaipGR9JkpZ+Iu3+y9p4AAAAAA9ls9n09nW91LNNmI6lZOre/61UMq20668B90o3fyf5h0m7l0qTLpOSDlkdlZu7Tr5LdptdS/Yv0Y7jO6wOBwBq1ZAhQxQeHq7IyEirQ7EESSQAXq5ylUiBfg5JUnJ6djXHU8+dcInU+yZJTmn6vVJavNURAQAAAB4pwNehsbf1U7MQf/0Tm6gnpq1VTo7njZmDatL1POnO2VJQcyl2g/TlUOmY54yJ1TakrYa0GyJJ+mrzVxZHAwC1KyoqStHR0YqIiLA6FEuQRALg3SpQfVRQEJVIlXfxW1KjjlL8bunXZ62OBgAAAPBYrcIC9d9b+8nPYdfvmw7qw7lbrA4JNallT+mu36SG7aW4HdKXF0lHt1sdVZ5bTrxFkjRr2ywdSvGsSikAQM0hiQQALhVIKDXIrURKyaASqcICQqWrPpdkk9Z9I23+yeqIAAAAAI91SvtGevOanpKkj+Zt08/r91scEWpUky7SXb9LzU6QEvdLk67wmIqkfi36qW/zvsrIydCXG7+0OhwAQC0hiQTAy1WunV0DP1clEkmkSulwujToMTP906NS4kFLwwEAAAA82XX92ureszpJkp76bp027qMtdL0W2lq6/SepaXcpYa806XIpfq/VUclms+nB3g9Kkr7f8r0Opxy2OCIAQG0giQTAuxWsPqpAJVJQXiVSlpxO+pJXyuDnpRY9pZSj0qyHJH6PAAAAQImevfhEndO9mdIyc3Tf/1bqcGK61SGhJgU3l26bJTXqJB2PMRVJyUesjkoDWw1U72a9lZ6drombJlodDgCgFpBEAuDlKlmJlDsmUo5TSs/KqeaYvISPn3TNWMnhL22dI60YZ3VEAADUL1ygAdQrDrtNH93YV52bBWl/fJoemLxK6Vl0RqjXQluZiqSw9lLcdmnKcCkz1dKQbDabHuj9gCTpuy3fKSUzxdJ4AAA1jyQSAO9WgeqjggJ9HXnTyelZ1RWN92kRLl3wipn+/QXp4CZr4wEAAAA8WFigr8bd1l8hAT5aFXNML83YSGeE+q5hO+mW6VJAQ2nvCmn6PVKOtcnDM1ufqfYh7ZWalarfd/1uaSwAgJpHEgmAl6tcOzuH3aYAX7MLZVykKjrtAanbhVJ2uvTDfVIWbTkAAKgWlbxYBoBn69wsWJ/cdIrsNmnayr2avHy31SGhpjXrLt04RXL4Sf/8LP3xsqXh2Gw2Xd3taknSzO0zLY0FAFDzSCIBQJ6KnWgJ8jMt7UgiVZHNJl0ZKTVoIh3cKM173eqIAACoH6hOAOqtc7o307MXnyBJevWnTVoVc8ziiFDjOpwhXfWZmV72ibThe0vDuazzZbLJplUHV2lPwh5LYwEA1CySSAC8m61ylUiSFOhnWtolZ9DOrsqCm0uXf2Sml34s7VxobTwAAACAh7v3rM66tGcrZWY7NeLrVTqUmGZ1SKhpPa+TznrSTM96WDoYbVkoLYNa6vTWp5tQdsyyLA4AQM0jiQTAy9lKmC6bqxIplUqk6nHiZdIpt0lySjMipLR4qyMCAAAAPJbNZtPb1/VSt+bBOpiQroe+XqPM7Byrw0JNG/KC1HmwlJkiTb3F0uOmK7tcKUn6aftPynHy3gOA+ookEgDvVoWxAvIqkdKpRKo2Q9+UGnaQ4ndLs5+yOhoAAADAowX5++jzW/spxN9Hf++K05u/bLY6JNQ0u0O69ksprJ0Ut1365RnLQjm3/bkK9g3WvqR9WnVwlWVxAABqFkkkAF6u8u3sgvxNEokxkaqRf4h07TjJ5pA2TJPWTbU6IgAAAMCjdWkWrPdu6C1JmrBkl2as2WdxRKhxQU2ka8dLNru0/ltp04+WhBHgE6ChHYdKkmZsm2FJDACAmkcSCQDyVCyJ1CC3nR1JpGrWboA0+FkzPftJKW6HtfEAAAAAHu7Ck1rq4XO7SpKe/WG9ovcnWBwRalz706RBT5jpnx6TEvZbEsZVXa+SJP0R84eSM5MtiQEAULNIIgHwbrbKVyI18HNVItHOrtqd9aTU/gwpI1Gafo+UnWl1RAAAAIBHe+z87jqnezOlZebogcmrFJ/Cd+h6b/CzUqs+Utpx6adHJaez1kPo3ay3OoZ2VGpWqubsmlPr2wcA1DySSABQSVQi1SC7Q7pmrBQQJu1bJX12ptURAQAAAB7NYbdpzPA+atc4ULvjUvTo1DXKyan9pAJqkcPXHDfZfaWtc6TombUegs1m05Vdr5RESzsAqK9IIgHwbm5XalVwTKTcSqRkKpFqRsN20hWfmOkj/0oxy6yNBwAAAPBwDRv46fNb+snfx675/x7Wh39utTok1LRmPaRBj5vp356V0mq/leHlnS+XTTatPrRascmxtb59AEDNIokEwMsVSCJVtp1dOpVINSb8ivzp7++SUuKsiwUAgDrGHhRkdQgALHBS6zD959qekqSP/tyqudEHLY4INe6sJ6XGnaXEA9K812t98y2CWqhv876SpD93/1nr2wcA1CySSAC8WxV6Rjfwp51drXhun9Skq5S4X/rpEUv6fAMAUJe0+2Ks/Lp2UfsvxlodCgCLXN23re44o6Mk6fGpa7XzSLK1AaFm+QZIl75vpld8Ie1bXeshnN/hfEnSHzF/1Pq2AQA1iyQSALhUsBLJ1c4uhXZ2Ncs/WLp2nOnzvfknadUEqyMCAMCjBZ91lrr8/LMCe/WyOhQAFnr+khN1asdGSkzP0v1frVRyOsct9VqXIVLPGyRnjvTz41JO7V7seH57k0RafXC1jqQeqdVtAwBqFkkkAKikQD9TiZRMJVLNa91XOn+kmf7tOelgtLXxAAAAAB7Oz8euyJtOUfMQf205mKRnpq+Xk6r++m3oG5J/mHRgrbRiXK1uulVwK53c5GQ55dS83fNqddsAgJpFEgmAlyt4EFW5SqRUKpFqx8AIqev5UlaaGR8pI8XqiAAAAACP1jw0QJ/dcop8HTbNXn9A4xbttDok1KTg5tL5L5vpP0ZKOxfW6uZdLe3mxsyt1e0CAGoWSSQA3q3glXgVbGfnGhMpKZ1KpFpht0tXfS4FNZcOb5bmvGB1RAAAAIDH69ehsV66LFyS9J/f/tHS7bQaq9f63Sl1u1DKSpW+vkHauajWNn1BhwskSStiVyg+Pb7WtgsAqFn1Iok0evRo9e/fXyEhIWrRooVuuOEG7dq1y+qwANQJla9ECvQ1lUhpmSSRak1wM+ma/5rplV9K0TOtjQcAAACoA24d2EHXnNJG2TlOPfzNGsXGp1kdEmqK3SHd8JXU9QKTSPpmmLTn71rZdPvQ9ureqLuynFmK2hNVK9sEANS8epFEWrBggR5++GEtX75cv/32m+Li4nTxxRcrK4sWUwAqoIKVSAG+ZhdKEqmWdTlXOvMxMz3rYen4bkvDAQAAADydzWbTm1f3VHirUB1NztBD36xWZnaO1WGhpvgGSMMmm2OnzGTp6+ukg5tqZdO0tAOA+qdeJJF+++033X777QoPD1ffvn31xRdf6J9//lF0NAOvA6g5AVQiWefcF6U2/aS0eGnKTVI2Fw0AAAAApQnwdejTm09RiL+PVsYc0zu//2t1SKhJrkRS2wHmuOmra6Tje2p8sxe0Ny3tlu5fquTM5BrfHgCg5lU6iTR58mTdf//96t+/v/z9/WWz2TRx4sRSn7NixQpdcsklatiwoYKCgjRw4EBNmzatsiGUKD7e9F1t3Lhxta8bQD3jrI52dlzBV+scvtK14yXfBtLBDdJ3t1sdEQAAAODxOjYN0jvX95IkjV24Q3M2xVocEWqUX5B08zSpebiUFCt9fb1JKNWgLg27qENoB2XmZGrJviU1ui0AQO2odBLpxRdf1NixYxUTE6NWrVqVuXxUVJTOPPNMLV68WDfccIMeeOABxcbGatiwYXrvvfcqG0YR2dnZeuqpp3TJJZeobdu21bZeAPVVgSRSBdvZ+bva2WVly+mWjEKtaNxJGvK8mf7nZ2nVJGvjAQAAAOqAi05upbsHdZIkPfndOu0+mmJxRKhRgY2km7+TgltKhzdL398l5dRcNw2bzaYh7YZIEuMiAUA9Uekk0rhx47Rr1y4dPnxYDzzwQKnLZmVl6d5775XdbtfChQs1duxYvffee1q3bp26d++u559/XjExMW7PefbZZ2Wz2Uq9FeZ0OvXAAw9o9+7dZVZFAYCkKlUiudrZOZ1SBv3ErXH6Q1L3i8z03JFS4kFr4wEAAADqgGcvPkGntG+oxLQsPfj1Klp013dhbaWbvpV8AqVtc6U/X6nRzbmSSAv3LlRmTmaNbgsAUPMqnUQ6//zz1aFDh3ItO2/ePG3fvl033XST+vTpkzc/LCxMzz//vDIyMjRpkvsV5E8++aQ2b95c6q0gp9OpESNGaO7cufrzzz/VrFmzyr40AN6qgpVIAT6OvOm0DJJIlrDZpOsnmunUY9J73aVsDlIAAACA0vg67PrkplPUqIGvNu1P0Ks/M6Z0vde6r3RVpJleMkaKnlljm+rdrLcaBzRWQkaC1hxcU2PbAQDUjkonkSpi/vz5kqQLL7ywyGNDhw6VJC1YsMBtfrNmzXTCCSeUenNxOp2KiIjQ7NmzNW/ePLVr167mXgwA5PJ12GTPzTulZXHlnmV8A6X7F+Xf//NV62IBAAAA6ojWDQP14fC+stmkb5bv1ow1+6wOCTXt5GtNNwdJmhEhHd1eI5tx2B06u+3ZkmhpB6B+GDJkiMLDwxUZGWl1KJaolSTS1q1bJUndunUr8ljLli0VHByct0xlREREaMqUKfrmm28UGBio2NhYxcbGKiMjo9jl09PTlZCQkHdLTEys9LYB1HEVrD5yf6otr6Ud7R8s1qqX1KSrmV76UY1eVQcAAADUF+d0b6aHh5jv0c/9sEFbD3J+pN47/xWp/RlSRqI0/W4pq/hzZ1U1uN1gSSaJxBjCAOq6qKgoRUdHKyIiwupQLFErSaT4+HhJpn1dcUJDQ/OWqYzPPvtMx48f11lnnaVWrVrl3ZYuXVrs8qNHj1ZYWFjeLTw8vNLbBuDdAvOSSLSzs9zDq6TeN5rpabdJhzaXvjwAAAAAPXp+d53ZtYlSM7P14NerlZKRZXVIqEkOH+nacVJgI2n/GinqjRrZzOmtTpe/w1/7kvZpy7EtNbINAEDtqJUkUk1zOp3F3gYPHlzs8s8995zi4+PzbtHR9P4FUDlUInmYi9/Kn/50oMQVbwAAAECpHHabxgzvq+Yh/tp2KEmjZm2yOiTUtLA20uUfmeklH0pLP5b2r5Xi91XbGLMNfBvo9FanS5IW7F1QxtIAAE/mUxsbcVUglVRtlJCQoEaNGtVGKJIkf39/+fv7u20fACrD39fk4kkieYiAMOn6SdJ3t5v7rzSURlW+0hUAAADwBk2D/fXhsD66efxyTVu5V2d2baor+7SxOizUpPArpAH3SX+Plea8mD/f4Sc17CD5B5tph5/kFyw16yE1P1HyC5Lsvqaiye5rxqj1DzGVTYGNJR+/vFWd2eZMzd87X38f+Fv39brPghcJAKgOtZJEco2FtHXrVvXr18/tsdjYWCUlJWnAgAG1EQoAVKsAH1OJlEoSyXOcdJX0XYH7b3eWntlhVTQAAABAnXBG16Z6eEhXfTRvm57/YYN6tW2oTk2DrA4LNemit6Qm3aQ1/5OSDkspR6TsDOloMeOWb/m1fOt0+JmL+wIbaUBwY8kmrTm4Uul/j5V/g6Ym2dSyl9SgcZXGKAYA1J5aSSKdc845Gj16tObMmaPhw4e7Pfb777/nLQMAdU1AXiUSYyJ5lJHHTRWSJKUclf79TepxkZURAQAAAB7vkfO66a8dcfp7V5wenrJa0x88Q/65F86hHrLbpdPuMzdJysmR4ndLcTtNW7vsDHNLjZMOrMudn5H7WKaUkyllpkjpSVLqMUlO83jyYSn5sDodkZq2a6MjPtL6eS/q1LT0/G3bHKaqKaCh1KCRSS75BUvBLaTGnc3P4OZSSCsppIVJTAEALFErSaTzzjtPnTt31jfffKNHHnlEffr0kWTa27355pvy8/PTbbfdVhuhAEApKn4VVKCfOaBKz6ISyaPYbNIzO6W3O5n7U4aZxBJXugEAAAAl8nHYNebGPrpkzCJt3Jeg0b/8o1FXnGR1WKgtdrvUqKO5VVR2lpSRJKUnSGkJUsoR2RJjNWDzl/ol/YCWt+mpU7P8pcRYKW675Mw2y6YnmMRVWRx+kn+oFNRUatVHOvFyqet5pp0eAKBGVTqJNG7cOC1evFiStGHDhrx58+fPlyQNGjRI99xzj9mIj4/GjRunoUOH6uyzz9bw4cMVEhKi6dOnKyYmRu+++646duxYtVdSCZGRkYqMjFRGRkatbxtA/eBqZ8eYSB6oQWPppKulTT+a+4yPBAAAAJSpVVig3r2+t+6etFITl+7SGV2a6MKTWlodFjydw0cKbGhuBZzWwE+/LB2pFU3bShdPMjPTE031UkaylHbcdI9IizcJpcRY6fC/prIpfo+UckxKjzcVTilHzO3wP9L6b6Wg5lK/O6TQ1qZSyT/U/Axqaub5+AsAUHWVTiItXrxYkyZNcpu3ZMkSLVmyJO++K4kkSUOGDNHixYs1cuRITZ06VZmZmerZs6feeustDRs2rLJhVElERIQiIiK0d+9etWvXzpIYANRtAb6uJBLt7DzS9RPzk0iSNCqMRBIAAABQhvNObKG7B3XS+MU79fT363VSmzC1aUjFByru1JanSpLWH16vlMwUNfBtIPmHmFt5peVWLKUelxL2S9v+kKJnSUmx0sK3S3iSzVxY6BNoElv+oaZ9XsN2UouTTaKp0zlFkl4AgKIqnUSaOHGiJk6cWKHnDBgwQL/+Ws6B+ACgDvDPHRMplUokzzUq3iSPXKJnSeFXWBcPAAAAUAf830UnaMWuOK3fG69Hp6zRt/cNlI/DbnVYqGPaBrdV66DW2p+8X2sOrdGZbc6s+EoCQs0trK3U8mSp+4XS+a9Ia7+W9q+RUuLy2+ilx0uJB6XsdFPhJEkJe4tfr91Hatheanuq1KSrqWJq009q27/yLxgA6qFaGRMJAOqr/Eokkkge7ekd0judzfS0W6WXj5l+3wAAAACK5edj18c39tWlHy3Wyphj+mDuFj099ASrw0IdY7PZdGrLUzVz+0z9Hft35ZJIxfFrIA24t/jHnE5TsZSeIGWkSGnHTPu89ATp0D/S8RjpyBZzi9thbgWFtjHjLXW/WGp+ohl3ycdf8guWHL7VEz8A1CEkkQB4OVuVnp0/JhLt7DxaUBOp/93SyvHm/quNaGsHAAAAlKFDkyCNvqanHp6yRp/O367TOzfVoG5NrQ4LdczA1gM1c/tMzds9T4+d8phstqodh5fJZpPC2khqU/IyTqcZc+nodilmqZR8yFQwbZsrJeyTVv/P3NxXbFrkdTpb6n2jFNzCtOXzCzbzSTABqKdIIgFAFQT6mWoWKpHqgMvel2I3SHv/NvcZHwkAAAAo0+W9W2vp9iOa8vcePTZ1rX599Cw1C/G3OizUIUPaDVGgT6B2JezSusPr1Kd5H6tDMommhu3NrcuQ/PlpCdLuZdLmn6Tdf0nxe6WsNElOc0s5asbdLTj2rktQM6lRp9wxmEKkgIZm2i9YatLFjMsU2MiMyeTglCyAuoM9FgAv56zSs12VSOlZJJHqhHv+cB8fad7r0rkvWhcPAAAAUAe8fNlJWhVzTFsOJumJaWs16c4BsttruJoE9UaQb5Au6HCBZm2fpZnbZ3pGEqkkAaFS96Hm5uJ0SjlZUuoxU7m0Ypx0aLNJKGUkSxlJkpxS8mFzK4tPgNR+oHTS1VLfWyW7o8ZeDgBUB69OIkVGRioyMlIZGRlWhwLAE1SipD5/TCTa2dUZT22T3u1qphe+I532gBRESw4AAACgJIF+DkXedIou/2SxFm09os8XbteIwV2tDgt1yIUdLtSs7bO0Mnal1aFUnM1mWtUFNze3Dqe7P56TLaXESUe3mp+pcSa5lHJUSk80847tNOMzJeyT0o5LO+ab2/KxUpPOksPfJJcCG0ptTpHC2km+DfK3CQAW8uokUkREhCIiIrR37161a9fO6nAA1EEBvrSzq3OCm0l9bpHWTjb33+lCWzsAAACgDN1ahOjVK07WM9PX6705W3Rap8bq16Gx1WGhjujZrKckaVfCLiVkJCjUL9TiiKqR3WGOM4Oblb1sTo50KFra9oe08F3p0CZzK01Ia6lFuGmXF9LKtMpz3Xwb5LfM8w+RAsIkvyAzv6bHngLgNbw6iQQAVeXnY5JIGVlUItUpV0VKe1dIR/419xkfCQAAACjT9f3basn2I5q5dr8e/Xatfnn0LIUG+FodFuqAxgGN1Sa4jfYl7dOmI5t0euvTy35SfWS3Sy1PNrfeN0k7okw7vKwMM/ZS3A4zlm/qMSkz1bTHS9xvbhXh8DOJpYAw06IvsLHZZnBLqfM5UvOTTCwAUA4kkQB4uapdmePryE0iZZNEqnMe+tt9fKS5r0jnj7QuHgAAAMDD2Ww2vXbVyVq9+5j2xKXq5Rkb9eHwvlaHhTqiZ9Oe2pe0TxuPbPTeJFJBIS2k3sNLXyYtXtq9XIrfY9rjJR2SMlOk9AQpPcm0zUs7blrlpSeYmyRlZ5i2eqlx+evaEZU/7Rdith/aWvIPNdVM/mGmmqlBE8k/2Izb1Lhzdb9qAHUQSSQAqAJXJVI6lUh108vHpFcbmenF70u9hknNT7A2JgAAAMCDhQb46sNhfXXDf5dpxtr9OqdHM13dt63VYaEOOLnpyfpt12/acGSD1aHUHQFhUvcLy798To6UmSylHjcVTqnHzLhM8XukI9vMuE27lkgZidLRROnotlJWZjNJrt43Sq37mhZ5Dk4lA96I/3wAqAK/3EqkTCqR6ia7XbrzN2nCReb+p6dJL8eZntYAAAAAitWvQyM9dl43vffHFr00Y5NOad9IHZoEWR0WPNzJTU+WJG2O22xxJPWY3Z4/XlJJstJN8ijxgJR8xCSbko+aCqeUI6bCKTFW2vOXtG6Kubn4BEoNGku+gaZFXkCoGYMpqJlJMgU1zR2bqaFZzj/EPOYXLAU2YpwmoI4iiQTAyzmr9GzGRKoHOpwu9btDWjXR3H+1MeMjAQAAAGUYMaSrFm09or93xenRb9fquwdOz2v3DRSnU1gnSVJscqzSstIU4BNgcUReysdfanGSuZVmz9/S319I2/80rfQkKStVSthXue3afU2bvKDmpsIqqKlJLDVoYhJOgY3NvAZNzbygpmY5Ek+A5bw6iRQZGanIyEhlZGRYHQoAj1DxLyYkkeqJy8fkJ5EkM1YSiSQAAACgRA67TR8M76OLP1yotXuOa8zcrXpqaA+rw4IHa+TfSCG+IUrMTNSexD3q1qib1SGhNO0GmJvTKWWlSZmpJpmUnmjGYko54j4vPVFKPmweSz5sKppSjppxnbLTpZxM014v9Vj5Y7A5TCVTUFMpuIVJNoW0NAmn0NYmCRXSKvfxlqYSC0C18+okUkREhCIiIrR37161a9fO6nAAWKJqV7T4u5JItLOr+0Yel15pmH9/+X+l0+63KhoAAADA47VpGKj/XNtLI75ercj52zSoW1MN7NzE6rDgoWw2m9qGtNXmuM06kHyAJFJdYbOZ9nW+ua3sKiMz1bTOS4uXkg+ZhFPSISk9wfxMizct9NLipaSD5mdGkuTMlpJize3gxrK30yA32RTaykyHtjZJqNDWUuNOUsteVDYBleDVSSQAqCpXuwYqkeoBm016aqv0bu6BzK/PSJ2HSM26WxsXAAAA4MEu6dlKN/Rvq2kr9+rJaev062NnKTTA1+qw4KFaBrXU5rjN2p+03+pQUJt8A6WG7SRV4CL+9CRTtZQYa34mxZrKpqRDptIp+bD7tGSqo1KOSIc2Fb/Ohu1NUqnDmVKHM/IrmRo0Ma32ABSLJBIAVAHt7OqZ4OZSr2HS+qnmfuSptLUDAAAAyvDy5Sfprx1x2h2XolEzN+n9YX2sDgkeqnVwa0nSgeQDFkcCj+cfbG4Ny5F4ysowiaaEfabiKfGASTAlHZQS9psk04F10vHd5rZvlbT0o0LbCzPnBDqcLgU0lELbmORSWFvTQi+kleTLOF7wTiSRAKAK/By0s6t3rhmbn0SSGB8JAAAAKEOwv48+GNZb13++TD+s2afzTmyhS3u1sjoseKBWQeZ9cSCJJBKqkY+fFNLC3EqSfFQ6/I8Uv1eKnmkSTYmxpnIpO0NKjze3o1tLWIFN6j5UuuxDk1SiLR68CEkkAKgCKpHqqVHxJnmUd59EEgAAAFCafh0aa8Tgrvokapue/3GD+nVopJZhXLUPdy2DWkqiEgkWCGoiBZ1ppnsPc38sJc5UMB1YKx3bZSqXEg+YeQn7pMSDUna6tOU36f0TJNlMW7wTL5ea9ZB8G0hhbaSAMKlhR7MtoB4hiQQAVZCXRMrOkdPplI0rUeqPl45KrxX44rfofemsJ6TZT0q7/5LujTJXOwEAAACQJD16fjct2HJYG/bF6+nv12nSnQNkt3OMhHytg2hnBw/UoLG5lTQmstMp7V0p/fSIdChaklNKPiStHF/Mwjap2wVSx7MkvyDTFq/reZKDseJQd5FEAoAq8Hc4JJnvE1k5Tvk6OECqNxw+0uUfmS+JkvTnK1LP66QV48z9nx6Vrv7MuvgAAAAAD+PrsOuDYX102ceLtGjrEf1v2S7dcWYnq8OCB2kVbNrZHU49rMycTPnaObGOOsBmk9qdKo1YJqUlSBnJUswSacd8KSPJtMVLPmwql9KOS1vnmJuLX4jUqIPkEyC1OElq1ctUL4W2kToOkuwOq14ZUC4kkQDApRJVRK5KJMm0tPN12EtZGnVOv9tN8ijlqLn/Yc/8x9Z9QxIJAAAAKKRr82A9f8mJennmJo3+9R+d2bWpurUIsToseIjGAY3la/dVZk6mDqUcUpvgNlaHBFRMQKi59bzO3Arbv0baON0kltLipZhlUkaidHCjeXzfyqLP8Qk06+xwpmmL5xMgtekndRsq2TnPBOt5dRIpMjJSkZGRysjIsDoUAHVUwcqjjKwcBflbGAxqxjM73MdHAgAAAFCqWwd20NzNh7Rwy2E9NnWtfhxxptsFePBedptdrYJaaXfibh1IOkASCfVP677m5pKZJh1YZyqW4veYtnipx8w4TAfWSZnJUlaqlJQqbfrBfV3Nw6WbpkkN29XuawAK8eokUkREhCIiIrR37161a8c/I4CK83HYZbdJOU4zLhLqqZePSa82Kjr/6xukm6fVfjwAAACAB7PZbHrnul4a+uFCbdqfoDF/btHTQ0+wOix4iLwkEuMiwRv4BkjtT8u/3++O/OnsTJNMyspNNMVuMAmluJ3S1j/M+Ev/PVs642GpWQ/J4W/GVgpsKIW1kxx+uTffSnXXAcrLq5NIAFAd/HzsSsvMUUYWSaR6y26XRh6XXmnoPn/r71ZEAwAAAHi8FqEBGn11Tz349Wp9Nn+7hvRorv4dG1sdFjxAy6CWkkQSCXD4SiEtzHSjDlL4FfmPHd0uTbxUSjxg2uyXxreBFNZW8vE3iaYmXaXTI8zYS0A1oJYYgHerhis1/HLHQaISqZ4r6b0yKkzKSJHmvWGuHAIAAAAgSbq4Zytde0pb5Tilx6etVWJaptUhwQO0Dm4tSdqftN/iSAAP1qSL9NAK6byXpc6DpbanSq16S427mERRQZkp0pEtppJp30pp/bemgmnVJEtCR/1DJRIAVJGfj0NSFpVI3uC5fdLoYnp2z39TWvqxtPBtaVR87ccFAAAAeKhRV4Trrx1HtScuVa/+FK13ru9tdUiwWKugVpKk2ORYiyMBPJx/iHTWk+ZWWE62lJVuWuEd22kubs1Ol5IOS8sipYMbpJ8eMeMvnXKbFNiIlneoNJJIAFBF/rkDxJJE8gIO3+LnL/04fzonW7I7aiceAAAAwMOFBPjqg2F9NGzsMn23aq/OO7G5Ljq5ldVhwUK0swOqgd0h+TUwtwaFWoX2vE6aeou05Tdp7khzk82Mn2T3Mc+1OySbw9wPbi51PV/qep7UcZAlLweejXZ2ALyb01nlVfj50M7Oazj8yl5mw/c1HwcAAABQhwzo1Fj3n91FkvTcDxt0KDHN4ohgJVc7uwPJB+SshmNyAIU4fKVhk6VBT0hBzXNnOk2lUmaylJ5gKpRSjkhJsVLsemnx+9LG6ZaGDc9FJRIA5KlcWa+vwzyPSiQvUJ7S7x/vk3oPq/lYAAAAgDrkiQu6a+GWw4o+kKBnvl+vCXecKhutlbxSiwYtJEmpWamKT49Xw4CG1gYE1EcOX+n8keaWkSKlxUvObCkny3RQycmdzkyRdi02iaQ2/a2OGh6KJBIAVBGVSCgi6bAU3MzqKAAAAACP4edj14fD++iyjxdr/r+HNXn5bt06sIPVYcECAT4BahzQWHFpcdqfvJ8kElDTXG3vStKW5BFKRzs7AN6tGq5883MwJpJXGRUvjTxufpbk3a61Fg4AAABQV3RvEaL/u+gESdKbszcr5miyxRHBKq2DclvaJTEuEgB4OpJIAFBFPrlJpKxsejl7DVfycVS81P/u4pfZt7r24gEAAADqiDvP6KiBnRsrNTNb/zd9vXJyOI7yRq2CW0mS9ifvtzgSAEBZvDqJFBkZqfDwcA0ePNjqUADUYa4xkbJyqETySpe9X/z8L4bUbhwAAABAHWC32/T2tb0V6OvQXzvi9PXyGKtDggVclUj7k0giAYCn8+okUkREhKKjozV//nyrQwFQhznsVCKhBK81tzoCAAAAwOO0b9JA/3dRD0nS6F//0Z64FIsjQm1zVSIdSKadHQB4Oq9OIgFAdfC1U4mEEmSnS6PCpM8GWR0JAAAA4FFuO72jBnRsrJSMbD37w3o5nVyU502oRAKA2jF69Gj1799fISEhatGihW644Qbt2rWrQusgiQQAVeST284uk0oklOTgBmnpJ1ZHAQAAAHgMu92mt6/rpQBfu5ZsO6opf++xOiTUIlcl0sGUgxZHAgD124IFC/Twww9r+fLl+u233xQXF6eLL75YWVlZ5V4HSSQAcLHZKvU0H4fZlWYzIKz3GvFX2cvMeUFK4Co7AAAAwKVj0yA9daFpa/fmL5u173iqxRGhtrRs0FKSFJcWp/TsdIujAYD667ffftPtt9+u8PBw9e3bV1988YX++ecfRUdHl3sdJJEAoIp87K5KJNrZea3mJ0qj4s3tsQ0lL/f+ibUXEwAAAFAH3HlmJ/Xr0EhJ6Vl6djpt7bxFmH+YAn0CJUkHk6lGAlC/TJ48Wffff7/69+8vf39/2Ww2TZw4sdTnrFixQpdccokaNmyooKAgDRw4UNOmTav22OLj4yVJjRs3LvdzSCIBQBX52M2uNItKJEhSw/bSc/tKfnxUWO3FAgAAAHg4R25bO38fuxZtPaLvVu61OiTUApvNphYNWkiipR2A+ufFF1/U2LFjFRMTo1atWpW5fFRUlM4880wtXrxYN9xwgx544AHFxsZq2LBheu+996otruzsbD311FO65JJL1LZt23I/jyQSAFSRb+6YSFlUIsHFP1i6f2HJj48fKnGFJQAAACBJ6tIsWE9e2F2S9NrP0ToQT1s7bxDqHypJSsxItDgSAKhe48aN065du3T48GE98MADpS6blZWle++9V3a7XQsXLtTYsWP13nvvad26derevbuef/55xcTEuD3n2Weflc1mK/VWmNPp1AMPPKDdu3eXWRVVGEkkAKgiR247OyqR4KZV75If2/OX9ErDWgsFAAAA8HR3D+qsPu0aKjE9S8//sIG2dl4gxC9EEkkkAPXP+eefrw4dOpRr2Xnz5mn79u266aab1KdPn7z5YWFhev7555WRkaFJkya5PefJJ5/U5s2bS70V5HQ6NWLECM2dO1d//vmnmjVrVqHX41OhpQGg3imama8oX0duO7tsDnJQSKdzpJ0LSn78y4uku36rvXgAAAAAD+Ww2/Tu9b10yUeLFfXvYU1fvU/X9St/qx3UPaG+VCIBwPz58yVJF154YZHHhg4dKklasMD93FKzZs3KnQhyOp2KiIjQ7NmztWDBArVr167CMVKJBMDLVT3x45NbiZSZQzs7FHL7rNIf371MOr67dmIBAAAAPFzX5iF67PxukqRXf9qkgwlpFkeEmhTsFyxJSswkiQTAsyUmJiohISHvlp6eXm3r3rp1qySpW7duRR5r2bKlgoOD85apjIiICE2ZMkXffPONAgMDFRsbq9jYWGVkZJR7HSSRACBP5aqSfHIrkbKpREJxRh4v/fEPe0pp8bUSCgAAAODp7jurs3q1DVNCWpZe+JG2dvUZ7ewA1BXh4eEKCwvLu40ePbra1h0fb84JhYWFFft4aGho3jKV8dlnn+n48eM666yz1KpVq7zb0qVLy70Or25nFxkZqcjIyApl3QCgMB/GREJpbDZpVLw0qvgvA5Kk/7Q3P0eRTAIAAIB383HY9c51vXXZx4s0d/MhzVy7X1f1bWN1WKgBJJEA1BXR0dFq0yb/s8jf39/CaCqmOi7G8OpKpIiICEVHR+f1HQSAyvBx5Lazy6adHUpRngTRgXU1HwcAAADg4Xq0DNGj55m2PiNnbdKhRNra1UchviaJlJSRZHEkAFC6kJAQhYaG5t2qM4nkqkAqqdooISGhxCql2uLVSSQAqA6+ue3ssrKdSsvM1hPT1mrUrE2KT8m0ODJ4nOsnlv74f8+ulTAAAAAAT3f/OV10UutQxadm6sUfN9LWrh6iEgkA8sdCKm7co9jYWCUlJRU7XlJt8up2dgBQ2XGQCnIUaGf3y4YD+mH1PklS4yA/PXKetTt5eJiTrpYykqWZESUvMyqMtnYAAADwer4Ou969vrcu/3ix5kQf1E/rD+iK3q2tDgvVyJVESshIsDSO3Qm7FZcWZ2kM9Vl6dro2Hd0kh81hdSgoQ4/GPTSw1UCrw/A655xzjkaPHq05c+Zo+PDhbo/9/vvvectYiSQSAFRR/phIOVqxK/+L55JtR0gioai+t5SeRJKk/54j3b+gduIBAAAAPNSJrUL10Lld9eHcrRo5c6PO6NJETYPrzjgUKJ0riZSUWbSdXY4zR38d+EtHUo8oIztD0/6dpn+P/ascJ23kgZoyrMcwkkgWOO+889S5c2d98803euSRR9SnTx9Jpr3dm2++KT8/P912222WxkgSCQCqqGA7uzW7j+fNX7P7uDKycuTnQ+dQFDIq3lQcleTAWiqSAAAAAEkjBnfVbxtj9U9sokbO3KTIm0+xOiRUk0CfQElSalZqkcf+iPlDTy14qlbjaRvctla3502yndlqF9JOzRo0szoUlOLkpidbHUK9MW7cOC1evFiStGHDhrx58+fPlyQNGjRI99xzjyTJx8dH48aN09ChQ3X22Wdr+PDhCgkJ0fTp0xUTE6N3331XHTt2tOJl5CGJBMDLVb2vto8jvxLpQHz+gK8Z2Tk6EJ+qDk2CqrwN1EMjj0uvNCx9GRJJAAAA8HJ+Pqat3ZWRSzR7wwFduuGALunZyuqwUA1KSyLFJMRIkloGtdQJjU+Qv8NfrYJa6faTbpetGtrSF+Rj91GYv7WD1gOoXxYvXqxJkya5zVuyZImWLFmSd9+VRJKkIUOGaPHixRo5cqSmTp2qzMxM9ezZU2+99ZaGDRtWa3GXhCQSALjYKvdF1NXOLiUjW/GpmZKkJkF+OpqcoX3HSCKhBDZb+RJJMyOkKyNrIyIAAADAI53cJkwjBnfRx/O26aUZGzWwcxM1DvKzOixUUcEkktPplK3AMXl8urmY7uKOF+uJ/k9YEh8AuAwZMkS+vr6KiIhQREQZQxRImjhxoiZOnFihbQwYMEC//vprJSOsWfRYAuDlqn4Fk4/d7Er3HzdXT/n52BXeOlSStPd40SuqgDw2m9Tp7NKXWTNZWjWxVsIBAAAAPNVD53ZVjxYhOpqcoZGzNlkdDqqBK4mU48xRRk6G22OuJFKof2itxwUAhUVFRSk6OrpcCaT6iCQSAFSRq53d/uOmlV2zYH+1bWS+DO87RhIJZbj9J6lxl9KX+elRKZP3EgAAALyXv49D71zfSw67TT+t26/fNsZaHRKqyJVEkqTUQsc78RkmiUSbOQCwHkkkAKgiVyVSama2JKlJsJ/aNDRfhvdTiYTyeGS11LB96cu80VJKT6ydeAAAAAAP1KttQ91/dmdJ0oszNupYckYZz4Anc9gd8rObtoSucZGcTqe+3/K9tsRtkSSF+ZFEAgCrkUQCgCpyVSK5BPn5qFmIvyTpSFK6FSGhLnpsQ9nLjG4rrZpU9nIAAABAPfXIed3UtXmwjiSl69Wfo60OB1UU4BMgKT+JtPLgSr2y7BXtT94vSWrWoJllsQEADJJIAFBFDluhJJK/Q02DXUkkroxDBYyKL3uZnx6Rtsyp+VgAAAAADxTg69A71/WS3Sb9uGaf5kYftDokVIGrpZ0ribT8wHJJUsfQjnqq/1Pq3ay3ZbEBAAySSABQRQ67exKpgZ9PgSQSlUiooIdWlr3MN9dLh/6p+VgAAAAAD9S3fSPde5Zpa/f8jxsUn5JpcUSoLFcSKSUrRStiV+i/6/8rSbqgwwW6/aTbZbdx6hIArMaeGACqyG4vWonUJNj0dT6alCGn02lFWKirmnaTHllT9nKfnlbzsQAAAAAe6vELuqtzsyAdSkzXa7Npa1dXFaxEWrZ/Wd78CzpcYFVIAIBCSCIBgEuhtnTlVbidXcFKpIzsHCWkZlU5NHiZxp3Lt9yoMCnpcM3GAgAAAHggV1s7m036ftVeLd1+xOqQUAkFk0jp2aaTx50n36kTm5xoZVgAgAJIIgFAFdkL7Ukb+DkU4OtQiL+PJOkwLe1QGSOPl2+5d7tK6Uk1GgoAAADgifp1aKxbTusgSXrxx41Ky8y2OCJUVKBv0SRSgCPAypAAoIghQ4YoPDxckZGRVodiCa9OIkVGRio8PFyDBw+2OhQAdVhxlUiS1DSEcZFQBTabNCq+fMuObiMd2Vqz8QAAAAAe6OmLeqhZiL92HEnWZ/O3Wx0OKqiBTwNJJomUlpUmSfJ3+FsZEgAUERUVpejoaEVERFgdiiW8OokUERGh6OhozZ8/3+pQANRhhcdEauDnkCQ1zR0XiSQSquTluPIt90l/WtsBAADA64QG+GrU5SdJkj6bv13bDlGlX5e4EkYZ2Rl5lUgkkQDAs3h1EgkAKjsOUkH2QusI9HUlkXIrkRJJIqEK7I7yVyS921Wa/WTNxgMAAAB4mEt6ttSQHs2UkZ2jF37cIKfTaXVIKCdXwig9O11p2bmVSD4kkQDAk5BEAoAqchSqRPLzMbvWRkGmEul4amatx4R66Klt5VtuxThp36qajQUAAADwIDabTa9eebICfO1avjNO36/aa3VIKCc/hzluTs9OV3oWYyIBgCciiQTAu1XDFWqFx0TydZhda8NAX0nS8RSSSKgGwc2kW2eUb9kvzpV2L6/RcAAAAABP0q5xAz1+fndJ0pu/bFZccobFEaE8aGcHAJ6PJBIAVJG90J7UVYkUlptESqASCdWlyxDp/oXlW/bLC6WcnJqNBwAAAPAgdw3qpBNahuhYSqbemL3Z6nBQDm6VSLlJpAAfKpEAwJOQRAKAKirczs7XYe43bJBbiUQSCdWpVe/yL/tqIynytJqLBQAAAPAgvg673rymp2w2afrqvVq6/YjVIaEMVCIBgOcjiQQAVWS3FT8mUlhg7phIKbRRQDUbFV/+ZQ//I/32XM3FAgAAAHiQU9o30i2ndZAkvfjjRqVlZlscEUrjShilZ6crLSvNbR4AwDOQRALg3QolgCqjSBLJNSYSlUioSRVJJP31qXSIdh4AAADwDk9f1EPNQvy140iyPpu/3epwUIri2tmRRAIAz0ISCQCqqHA7O1clkiuJxJhIqDEvHS3/sp8OlLKzai4WAAAAwEOEBvhq1OUnSZI+m79d2w4lWRwRSlKwnV1adm4lkg9JJACeZciQIQoPD1dkZKTVoViCJBIAVJGjxHZ2uZVIKZlyOp21Hhe8gMNHGnm8/Mu/1kSaeFmNhQMAAAB4ikt6ttSQHs2UkZ2jF37cwDGZh3KrRMoylUgBjgArQwKAIqKiohQdHa2IiAirQ7EESSQAqCJ7oT2pr6udXe6YSFk5TiVn0IcbNcRmk66fVP7ldy2SRoVJM7zziw8AAAC8g81m06tXnqxAX4eW74zT96v2Wh0SilFwTKSMnAy3eQAAz0ASCYB38w+p8iqKtLPLTSIF+NrzqpLiaWmHmnTSVdJVn1XsOWsnS2kVGFcJAAAAqGPaNW6gxy/oJkl685fNikvOsDgiFOZKGCVmJObNC/ChEgkAPAlJJADerWF7aeho6crK9zS1l9DOzmazqWFuS7tjHKygpvW5qeLP+U97yemU0ukRDwAAgPrpzjM76cRWoTqWkqk3Zm+2OhwU4mpnl5CekDePSiQA8CwkkQDg9BFS31sq/fQiSSRH/q41JMBHkpSUnlXp9QPlNqoSlUWvNJRGt5E2/Vjt4QAAAABW83XY9ebVJ8tmk6av3qul249YHRIK8LObJNLRtKOSJB+bj3zsPlaGBAAohCQSAFRRkXZ2Pvm71mB/8+U3mSQSaktlEkmS9N0d1RoGAAAA4Cn6tm+kWwd2kCS9+ONGpWUyZq2ncCWMsp3ZbvcBAJ6DJBIAVJGjUCWSb4FKpCB/KpFggZHHK/e8d3tIo8KkHA6qAQAAUL88NbSHmof4a8eRZH06f7vV4SCXw+5wu98iqIVFkQAASkISCQCqyF5oT+pToDIpKK8SiZPyqEU2W+UqkpJizc9XG1dvPAAAAIDFQgN8NfLykyRJn8/fru2HGRfUE/jY3CuPru9+vUWRAABKQhIJAKqocDs7e4H7tLODpV46WvnnbvxBmjtKyua9CwAAgPrhkp4tNbhHM2Vk5+ilGRvldDqtDsnrOWzulUiBPoEWRQIAKAlJJACoInuhdnYFBfmbL8S0s4MlHD6Vb233/Z3S4g+k15pIHFwDAACgHrDZbHr1ipPl72PX0u1HNXPtfqtD8nr2Qq09SCIBgOchiQQAVVR6EolKJFjMZpMeWVu1dbzSsDoiAQAAACzXvkkDPXxuV0nS67OjFZ+SaXFE3q1wOzuSSAA80ZAhQxQeHq7IyEirQ7EESSQAqKLC7ewKCvbLTSJlkESChRp3kp7dU7V1/P1F9cQCAAAAWOzeszurS7MgHUnK0Ltz/rU6HK/msLu3swvwCbAoEgAoWVRUlKKjoxUREWF1KJYgiQQAVVRKDimvEikpPbuWogFKEBBatef/8pSUmSb9/oIUPat6YgIAAAAs4O/j0GtXnSxJmrw8Ruv3Hrc2IC9WeEykAAdJJADwNF6dRIqMjFR4eLgGDx5sdSgA6jBbecZESqNFAjzAqHjp6rGVf/4bLaRln0jTbpX+/bX64gIAAABq2RldmuqqPq3ldEov/LhR2TmMA2oFH7t7O7uG/g2tCQQAUCKvTiJFREQoOjpa8+fPtzoUAPVU/phIVCLBQ/QeVj3rmTJc2reqetYFAAAAWOD5S09USICPNuyL1zfLY6wOxysVrkTqFNbJokgAACXx6iQSANS0/HZ2jIkEDzIqvnrW88W50qgw6dDm6lkfAAAAUIuahwTo6aE9JElv//6vDiemWxyR9yk4JlLb4LZFxkgCAFiPJBIA1KBgVyVSBkkkeJjqSiRJ0qcDq29dAAAAQC26+bQO6tkmTIlpWfrPr/9YHY7XKViJ5BQtBQHAE5FEAoAaFOhrvhCnZtDODh5o5PHqXV9KnLT6K8nJwR8AAADqBofdplevPEmSNH31Xq3cFWdxRN6lcDs7AIDnIYkEADUo0M98IU7LJIkED2SzVV9F0ldXS293kmY9JL3SUDpWoKf8zkVS3M7q2Q4AAABQzfq2b6Thp7aTJL00c5OysnMsjsh7FGxf5+RiNADwSCSRAKAGBfi6kkgchMCDPb2j6uvYPs/9/phe0tHt0v410qTLpI/6UKEEAAAAj/XMRScoLNBXmw8kaPJfMWU/AdWCdnYA4PlIIgFADXK1s8vIzlF2Dl+I4aGCmlTvGEkuH58izXw4//7fY6t/GwAAAEA1aBzkp6eH9pAkvffHFh1OTLc4Iu9gt+WfmiSJBACeiSQSANSgAN/83Swt7eDxwq+s/nUe3JA//esz1b9+AAAAoJrcOKC9Tm4TqsS0LP3n13+sDsfr5Djp4AEAnogkEgBUI7vN/X6AT35pPkkkeLwb/ie9eLjmt0NbOwAAAHggh92mV688WZI0ffVerdwVZ3FEXobDBADwSCSRAKAa2W3uWSS73SZ/H7OrTSWJhLrAx69m1/9xP+mVhtKoMOnQPySUAAAA4FFOad9Iw/q3kyS9NHOTsrKpjqkttLMDAM9EEgkAqlHhJJIkBeSOi5SWycEH6ohR8VLX82tm3Ue35U9/epr0SiNp9lM1sy0AAACgEp65qIdCA3y0+UCCvvl7t9XheA2SSAA81ZAhQxQeHq7IyEirQ7EESSQAqE5Fc0gKzEsiUYmEOuSW6bW0Iae04ov8u9v+lF5tIm34vpa2DwAAALhrEuyvp4b2kCS9/8cWHU/JsDgiAICVoqKiFB0drYiICKtDsQRJJACoRoXHRJKkAF+zqyWJhDpnVHztbeu7O6TsLGnyNVJOljT97qLLZCRLG3+Q0hJqLy4AAAB4pZsGtFf3FsE6npKpD+dutTocAAAsQxIJAKpRae3sGBMJddLLx2pnO5t+lN4/wX3e0o/d7//0mPT9ndJXVxV9fg7tIgEAAFB9fBx2jbz8JEnSV3/FaMvBRIsjqv9sxbX2AABYjiQSAFQjxkRCvWO3SyOP1862kg+735/zohT1ppSau/0N08zPfaukPX/nL5cYK73aSHqrY8W3uX+ttGZyJYIFAABAfXdm16a6MLyFsnOceu3naDmdjNkDAPA+JJEAoBoVk0PKGxOJSiTUWTabaW3Xpl/tb3vBW9JbHYrOH3+B5DqI/2Ok+Zl6TMpMy1/m2C4pq4z+9WPPkWZGkEgCAABAsV649ET5OexatPWI5m4+ZHU4qG5rJkujwqTDW6yOBAA8FkkkAKhGxVci5Y6JlEESCXXcvfOs2/aosKLz/nxVihotrf82f152btJozkvSmN7S683yH9v6h1nPzoVF1zXTOwfHBAAAQOk6NAnS3Wd1kiS9MTta6Vkc19UrruOAyFOtjQMAPBhJJACoRvbiKpH8ctvZcbCB+uCSd62OIN/i96UF/3Gf95920ur/SUs/yp8Xt8P8/Po683PS5eZnelLNxwgAAIA6L2JIVzUL8deuoymatHSX1eEAAFCrSCIBQDWyFVeJ5JPbzo5KJNQHA+41re082ayH3e9/dU3xy7nGWKqsnBxpzde0vgAAAKjngv199MzQHpKkj/7cpsOJ6RZHVD/ZVMxVmQAAy5FEAoAaFuCqRMrMsTgSoBoNedHqCMrv2M7i5//8eMnPcTqlhAPu8xJjzbhLLhumSTNH0PqioMiB0qiGUlY9P7GSnWl1BHXb3pXSb89L6YlWRwIAQLlde0pb9WobpqT0LL0351+rwwEAoNaQRAKAGhbom1uJlEklEuqRc56WWp9idRTlt2uJ+/21U4pfzumUdiyQ3u4svX+CNDm3BV7Cfum9HtJbHU0FkmROhLsset88t67bNleafo97sqy8ju+RDm+W5JQ2/1T28umJ0j+/SJlpFd9WTUlLkFaMl7KzSl5mTB/ptabSljk1G8vOheZ3Wh+NO0/6K9KMa+by+wvSL89YF5O3ysmWlkVK+9daHUn9kJ0pbfpRSj1e89vKTJMmXCIteLvmtwVAkmS32zTy8nBJ0tSVe7RxX6Hq/K1zpXd7mO9TqN9yavEC0dgN0nsnmg4IqNt+eVqKPE3KSLE6EqDCSCIBQA0L8DW72jSSSKhv7ouyOoLym3iJ+/0ZDxRdJnaD9EpD6X9XSKlxZt62P6TkI9L7J+Yv99+zzc+MAmMq/fmKNGNEtYZcZTmV2OdMvlba8J00I6Liz40vkPBIOlj28hMvlb69UfrYg5KR/2knzX7CJBAlad9qaVSYSSxK5oSBq7Ltm+vzn/fvr+akUeHqtYKyMsqfaIxZZsbu+vDkir+G0hzbJWUkF/9YZlr++GG15e+x5uf3d0vLPpH+/m/pv8Pi5GRL4y4o2sayPLLSpRXjpLhC1YqxG6T4vRVfX20oz3soJc4kMspTEbj6f9Lvz0tjz6l6bOWVekz6fJB0YH3NbePo9pLf6zVp1iPSd3dIb3Wo+QsLln8mxSyRot5w31ZOTuX2/y6lJdHLIz3JVKX+9bn7/KPbrW3/mpEsxSyt2u+msKwM6aurzQURxUk4IH17s7RjfuXWn5NTuYs6alPSYel/V0pbfs+fl5Nj/tbl/R/ITDWfta80rv0LcmI3Vnh/369DY13Ru7WcTunVmeuUk1ogkfT1tVJSrPk+hcqprW52GSnmApYjW4s+tudvacpNRb8fuPz1ufR2R+nAuhoNMc/ng6TE/aYDQmVkZUg7F9X/TgF1wd9jpcP/SJt+qPq60uKL/8z+63NzgZDL6q+k6ffSSQFVRhIJAGqYqxKJJBLqpZePSZd9aHUU1aOkq/ve6eJ+/+AG8zMz1X3+um+Kf/6ySHNypPAJ+mO7ij85FL+3+JMox/e4f/lPOCBFz8xf1uk0B71ZGeaE+quNzXYXvVd8XJI5mVbctv6d7X4/M80kyaJnlryugn5/vuTHXJVHrgPvhH3lW2dBWemln2jKyqj4OgtKPmx+fjHE/PzfFebnjhISp1OGm5NGruRTYanHpf+0lyYXMz7XsV3m7xQ1On9ezJKiy1XVwU3SmN7Sm62Lf/y97tJHfaUN31du/UvGSB/2kuKL+XtmZ5nERnGcTmljgW1mV/AEx6oJ0t6/TTKkoieHF7wlzX5S+qhP/rzju83Jmg9OKvv52VnmbzcqrOInepd8JP1wf8VOmP78hEl0H99d+nJvdzKJjD9eLvpYTo6UfDT//rpv86cr83+z5CPp39/c5zmd0qQrTEKlOG91NIm6/56VPy/1mDT3FelwMe2hjm6XDv1T/phWTTLJadd7fe9KU9lXGwp+DhzbVfqySz6Svryo5GTX3pVS0qGSnz93VP70rIfyp19tZPb/lUkG/fubqbRcM7niz3V9Pv30qKlK/e3/3B/7+BTT/rWyyb1ln0rjLzQVo8X591dp7OD8RNXk66TXW5jEUWKs9PX10oSL3U+sFbRvldmPVOR/8vVm0vZ55oKI4vz8uPTPzybJUhynU1r8YclJqK+vy/1/2Vj+mEqyfZ608QezzcP/Vi1ZuHys2X9J5vNxx3zpmxvyP9+nDDN/a9f3j5wcadz57r/7bXOlBe+YeN5oaeY5s6WN0ysfV0Ud3yN9fmb59veFPHfJCWrg59Do2Ptkf6u9ueCoME7Ye4adC00Vz7+/us//T3tzAcsn/Ys+Z/wF5rvwd7eb+9mZ7vuG3/7PnMD/6dGStxu3Q5p4mXRoc9VfQ2GZqeZ1lZQUyMl2f0/OHCFNukwa3c7c37vKfHf5+YnqiWf1V2Z9qya5z598rZlfWmV9ZqpJelTkQqbEWPN9NXZD6fuyo9uLHq+5zHlR+uTU4lsrH9mW//fOyZE2zah4d4CcHLPPLfhdoODFM4dzv9dsmyt9flbFL6xJPGjew2MHu89PTzLvz9+fz//uPesh04Z9fRXHA4bXI4kEADUsgCQS6jO7Xep/p9VRVI/ln5V/2dHtpOgZRecv/UT6+wtzNW3MMnNw50qofNRX+ic3OXN8jzmh/1ZHc2JrVJj5oj8qzJzMKHxAu+ZrU5XyWtP8ee+fIE27zZxUlqRf/88c9L7ezJwIc3G1DHM6zUGj62Dpz9fMiUbXya0j20p+vT89Kq392myvoLGDTczlTS79/YX0RouiiYrytHRwOs0Be/IR6fXm5nU7neYWPTP/IG3uK+Z3UPhAtqC0BGn3XxU7WZiTIx1YW/z84sQszT9B/89sKSvVnMSTzEnU31+Qdi2W/ptbAbLgP/nPLXjCoaTkS0UVPHB0JQvWTpG+ONe0a0zLvZJ6+t3uz9u/Rtq9vOj6EmOlLy/OTyD88bJ0PEaa93rRZcedaxIbnw8q+lh2ocTFtzebhMGoMHOisqBNM6SxQ9wP+Gc/mT9dMMn312fuCZKcHPeTAlLxCdaDm/KnC5/M3PqHicvVpmjpmPzH3urovmzBhFZ6orT8v/kJtowU6Y+XpPXfmhPMkvn7fH190RPk026X5rxkpleONz/LW/W4/POi876+Tnqns3ktkvtJm4XvuC97ZJv5+5dk5yLzOqYMc5+/eZa0c4G0elLZ/2OufeJbHaXF70uRA9wfdzpN8uHT0/JbxB3fI029xf3E+uIP809Q/1QgeZWw37RPnHR58Sd5q0NJreui3nS/73Sa5IyrdeAfL0m7l0krvyz63N3LTdzvdnNPUH5/t9nvFj5p5kr67FmRP+/QJpXL/Lekqbea6SnDJDmlmaVUozqdRfdLBzeZ/fLcUaaCt7CsAm1LK1vl9/tz0p7l7m0wC5oy3Lxfp99tfmfb/jDbnXCxaUfrSs7/8VLx78svzjUXYGz/s+QYMpJNC8ElY0pepqCyqhRilkpzR5oklCumTTPyk56uWIp7j0gmSfH3F+ZEaVm+ulr6/k6zn4kcULRCu7yObpd+fdrsv35/QToUnf+Y62+/Nbfd67zXzM9XG0l7V+R/H8rJNieXo17P3we6FL5fksIXDTid5vNgawlt5JxO6bs7pZ8ey59XcH9fwdZSrcIC9eh53dTFbqpn0+e9VXShH4upekftm3S5qeKZMtx9fk4JCZiC7V0PrDP/9+92Lz4Z7PqM3DjdfD8o+Jn5UV9p1yLp04Hli/NgtPn/LPy5W1wS5I1W5nXNLiEJ9GpjcxGca9+x4Tvz03WhzrhzzU/X94qCUuJK/7x0Os1+cE6B8XFdFzIU/Pw9uj3/+9Jvz5a8vv9daZIeH/Ut+ticl6RPzyjacvfzQWZf//kg6bUmxa936Sfm+4MrSZ2dadqbu5K7Sz+Wjmwp+r11yk3SJ/3yj6/WfWOSiSV1B0g4YDoXFPbzY2afO6Z3/v497XiB+D42PydfK8Wul74ZVngNxrEYc6FB4c+t97qbn66LG10Kvl+yM933lUcLHes5nSYBSEtjlBNJJACoRs5iDkr9GRMJ3uDx6LKXqU/SS7gSes4L0i9PmatpJ1zknvSRpG9vMicw9v6dP891suXtTvnzCn7J37vSvX1F9KyiV9WvmmhagZVm43Rz0Ph57tX/i941P3cuMCfYP+lX8nPXf1t03tHt+Qe6024r+Ur/5KP5B1e/PGV+Fk5UuJ675CPTmqxgkiAzzcT37U3mirufH89/7FC0SeZNu80cpKXFmxPRUv6BrNNpkjWuk6tOp7ny8Muh5gDRVUmy/jv3mArvz1OOqNgeK4UPyFwmXCz9eL9JHBV8PWnx5mT9sk9MS7+CB5QuBStzCrf7c/0+SpOeaFoLFawsKdhuMDb3gHPGA+bq+8InOFxJrJwcc8L6ywuLtnSZcqO0e6k56Vzwd7Xum6JXX7tOpMZuKHrC3ZVEcDm40SQMJHOisuAJ5+9ul/avlka3Nb/HwlePZmeZpMHBaHPC4sf786tufnnKvMeLO0kh5Z+oPxaTP+/7u8zPv78wJya/zh0jzdWmqLgWOGkJ0rqp5gTO31/kvo7/k359xlyJL7lXYriSOT/ca/YF/2lnWmYmHDAnGKJnSEs/Mglgl0PRpg1f4bHdUo+Z5FlpXCelXa8luUCly8ICY+tkZZjf19jB+SdanU73doMFTyAVVDD553Sav0tJFTXf3lT6uGgFr7L+e6xJXH54shl37fMzc1/DEXMi/vfni54ILvjeLvi3LU5KnNmXFn6POp3mvTv+QmnlhPx9RtJhk7B/q4NJYhW2odAVv1t+N8mZwq0DixtDrmD7s2+GmZaao8LMvmH/GmnPX0Wfk51lTsq7rP6q6DJOp/l9px4zf9dRYdL8N03ib9VE92VLqgh5paH5vFo1UfriPHMCbspwyZkjLf6gaGs9yX2/vWpS2RUaPz1mYkuMzY/bZUXu/9WkK8wyha/Ej11f9j7SldxITyyauN1eSqveBW+ZZNQfLxdN5BWuosvJMSeuS7O1QAu4NV+Z/ex3t5uTwwUVPNGbGGte9/ihpuLwl6eKbwtbUgLXlejZU+DigLR48/cs7qKF2A1mX+ZaX8G/5bJP3Jctrm1fces8XuB/cV2h7xebfjQ/F39oPqs3/mD+BxJjzd96/xrzncJVbe3ySkOzX/z62uJfe9wO0z5q1YT8k6oFkwjFVZSsm2rGiTwW4/4Z+PcX0tc36K5W+a/Df9XYos8v2K7qrY4m3upsp4iKK0/7ucL76O1Rps31zgUlP8f1fWHs4LIvnsjJyf8sPRZjvqMm7DfJlK1z8r8ruBT+jHA6JeVuo+BFY8Up+P/qUriauWDCKDPN7N/f6VL0s3nHAhNn1BtmP7j0Y/MdrnA1skvB78cFk8M52SYB7boAreC+6J/Z5n91/FDzGbT0I3NBxNhz3C+ScHUMKCzpkLkY6eh2c0zmkrDfbHPiJeazvOBnROELbgp2Y1g7pehnQkqce+L+/RNM54JvbzYXuayZbL5DrS5wMdvMh1Smwp8XyyKl2U9JY3qZCw0KV9KVZEOB45nsdPeOD65jJMlUf319nUkA1mZLY9RpPlYHAAD1XYCPa0ykWhx8E6htYW2kUfHuB/Qo3qoJUtPu5Vv2wDpzNXpB024tulxp7TQkc7J7eW6S6djOoge4FWmfFrdTatyp6Ekr1xV1hZd1tQq7pZQWNSu/lHpcZK4Ql8yVh5d/aE6cjztXCmsvxece9G6e5b7+357Lv+86iHeJ32sOyF0H0bf/bA7KkmKLxvDDPe73i1xVaFPeQXtBvz7tfj9up6lycvnqaqlBgWRiRnJ+EqewhP1SaKF2cwWrEFKPS293ltqeKt39u0o0uq352eNS6cbc9loFTwiPO9f8v7qkFZiWzIn3UfFS9I/58z7qIz21TQpuZk4A7y/w+5lSqJXT683N37vr+UXXfbzQifzSrk6VTGXeqPii7Zz+077osvNeL3pF5o/3mVhcJ2GP7TQnbcLauS/3+4vSVZFF4/nnl/zkZ0GZaSqSVMzKMEkgl1+ekgbca6r4pNzxDB4yJ4tdVk+SrvjIfT0H1hVtjei6aliSUo7mV2B1u1AKyr0Kd9rtpZ/kqoiDBX7fSz+Wrv48/6pc3yDphf3ulXk52ZLdXDSjmMX5839+zLzf96+WHlwqtSimbVThZIvTaf53//3VVLu6RL1RfKwFq1wKv/6C779x50ovHTEDWsdtlxq2NyfTRiyXmudWdu5aZP7mN08zcbhes0vBk11Thkv7VprpuSOlQY8Vjc3plGY8KDXp6t4irmASZc9yk5BwFDw0d7o/7mqp6TL93qLbStjrfhJ/z3LT1sYvSLLlvlcLv56CCn+OfH+X+f31v8ucmBs7WBr0eNHlXb8Dl4IXWexaKHU6x/2k1l+RJhH09HYpINT8jmwF/pfS4s3npCT9cJ90+6yi48vsXp7/t36tqfRYof2DrdD/ZmGuKkTXz5cKnERd9ok06AnJN8D87pxOU3nU8mT3CqQJF7mvM3KAafG7+D2p3UCpdR/3x5OP5v+vuhRc36yHpX7lqOx2/d4LJxKP7ZIadTTTW3437zubwySKb5xa/LrmvSGd+4LZd+yIMuseeTz/9+d05leQbp8nDfu69H3MinHSpYUqPAt/Bhzb5X7iu6TKo7kjzc/vc38n7/UofrnsrKIX9jidJsEWPdO8z4KaSDkFkn5Rb0gnXm5aJLpsniX9c7U5WXvTd1L3C83nh5T//n1so9SwXd5ngu9W98/hzTH7daKKEb83/7N81sPSVZ8WtxRqw3/Plu78VWp/evmfU/B/LW6H1Lhz6csvfFc65+niH8vJNglQSbrsg+KTPDlZ5iKAU241VcoxS90fL3zxjcv+NaadZ9fzi3/c5cOe7vff6SI9t0/yD5b+KvDe3PSj1OdGc6FXwcS2wz9/+tubSt9WcTZON/vZZZJeLpRkLri+wvu48ee7f3ctzrvdzM/C+5X1U/Mvtjuwtuhnqkvh46MZD0hNuuXfz8nOv+jvkbXmeMjln5/zt1vwM0+S1k423zELf28srVVv4dbgrrapw4ppv75uqrlY7OynTNWuS+G/tWRe45+vuieUJLMv3fu3FNBQahFeclzwaiSRAKCGudrZpWdx5Rm8wKh4c4VueVu9eKuCX/BLsmKce6uuqvikv9wOXAqfSEw5qiJ+uE+68PWiV/t91Kf4g7iMJPf7WenmKmKXkloQSe5XY0vmBGKHM0x1hpSfQCps6s3u911tM1wKj3Mw9eaiJ7RKUnib0TOKv4K48JXXWWnmoLOglAInKDfNKBqny/snSg+tKjr/6HapSZfcK1yzzYF1wYTtgPuli98yLSxclW1S/tWUxbUq3FZKyybJvNbCSbl3u0rXTSiaaNlSzNWRk681JxoKj3E0qdCBe9LB0uNwcVXOlKZwAkkyv+vCY9OM6VV0ubwD/EInEEoa6yQ9oeh4XsnFVNssKNQirmACyaUqY3h9c0PRk/iFRb1pKihu/bH05SRzRXGPi/LHA5OkdVOkwQX2WZnJRaswFr1nWrMVrFSS3K/E/fkJcyK+sMLjDGz6oeh7rySHNps2WS6F2xUVtnqSSSBJ+Vdjf3qadP9Ck0CSzP4oLV5y+JW+rsK/93e6Fl1m8yzz+yvsmxvc789+wmyzcWdz4rCskeWLq3ApnFiKXS+NbiOdcJk0/OvKXeTx8+NSn1vyT8wVl1AtzaHNxbeAys4wydXMVOnPV6S+t0pXfmL2OwXbNe5cYP6HC7eYm3qL+/0JF1csLsm9nWXhz8B3ck8Sj/jLJABcCY2CCr7vXDZ8l98a6dlCnyHvdJZunm5OZg64V9pSzNX7Wwp8Fhb+e6UlmKRbcc+TTDWu67O58PurcMtJl4VvS6fd794K9JWGJpGUleZ+IvTfX4qOEVmcIq31Cu1TD6wz1aLVZe7Iou3+Xm2UP/1OZ/N7KfieWfRe0XamWWn5+/tvrjcJwcI+PFnqcl7R+bne+3WTxhX3QMFqubVfm5vDT+o9XOp/d9GEo5ezlbX/q6oJF0t9bi57OZeCF0lFvSkNLNRStvD/6pIPiyaR0hLcLzKRik8gucx6yLxHXInMggp/9507Sjr1nvyxccKvcn+8PK0ady+Tul3gXrU+4wGp53VFKyPLGrsyLV4KCCuayDkWY5KpPxT4rKroeIVFLrjIlZNd+nFG4XH+XJ/3efeXmAsHvjhXRRwtUHVeMAH+UR+pZTFJGqn4qkyp6AUOrsp7l/1rzfi23YcW/3yp6N9fyn+fdCpHRdGKcUUTSJJJcLvGdXwhVvINLHtd8Do2Z3G9l7zM3r171a5dO+3Zs0dt27a1OhwAdVDHZ82JukYNfLXm5QvdHvsj+qDu/d9K9WnXUDMizrQiPKD2UZFUvz25Jb8Xd0k6DHKvSCjLi4fNWEaeonXf0seDkaSL3ylaiXTnr5U7oVmWJ/4pWp1SHt0udE8sldcDi4sfw8gKD60qvd1idel9Y/En/MsrpHXZ7auKU5Ovr+NZRU+WuDy2sfge/3fPNVf8ertL36u+RH5FDbjPtO+rLu1PNycJPVlJV+VXRt9bi0/YluTa8UXbrFbGoMdNS7+aEn5V8eMxulz0lqn0KtyuqjQPrSw6DmN1vv9u/dFU5FbES0eKtgOuijMeLr5aujQPry6+TWApeqWN1fqAQif9n9lpKrwKVlGj/qvt7gz+oSW32i6Pxl2kR1YXjfnMx0xSrKKePyC92ary8ZSmuH1KTX33rk73LzIXRpTUCrg6XP6R+7hUlTV0tHR6Ocff9BKuvEHXrl3l6+uriIgIRUSUMn5kPUUSSSSRAFSdK4nUsIGv1hZKIi3aeli3jv9bJ7QM0W+PnW1FeEDtS4lzH+MHKEvvm/KvgAMK6nR2xa9WrUu6nFe00gJA7Wl3mnurwso67YGi42ug4v4vxow1VsfcmfG0Jvi9U/aCgCc696X88eKqqt8dRcfZQ83zCy7aGaJSbNKo49WwnvqDvIFhL3sRAEBV+Pu42tkxJhK8SIPGZfetBgoigYSS1OcEkkQCCbBadSSQJBJI1WVZpNURVMrNfvX8swr1W3UlkCQSSFaplgSSVOwYsIBIIgFAjQvwNbva9EzGRIIXGnnc6ggAAABQVyx82+oIKuV8VVMyEgAAD1Tnk0gffPCBTjrpJAUHB6thw4Y699xztXw5H94APEeAr6lESqMSCd6o8ACiAAAAAAAAqDPqfBKpQ4cOev/997Vu3TotXbpUXbt21dChQ3X06FGrQwPghYobZS4gt51dGpVI8FZPbJa6DbU6CgAAAAAAAFRQnU8iXXPNNRo6dKi6dOmi8PBwvfvuu4qPj9fGjRutDg0AJEn+ue3s0jKz5SwuywTUd6GtpZunSef8n9WRAAAAAAAAoAIqlUSaPHmy7r//fvXv31/+/v6y2WyaOHFiqc9ZsWKFLrnkEjVs2FBBQUEaOHCgpk2bVpnNlygjI0Njx45Vo0aN1LNnz2pdNwCUR3FJIlclUo5TyswmiQQvNuR56YWDVkcBAAAAAACAcvKpzJNefPFFxcTEqGnTpmrVqpViYmJKXT4qKkpDhw5VQECAhg8frpCQEE2fPl3Dhg3Tnj179OSTT1YqeJdFixbp4osvVmpqqlq2bKk//vhDjRs3rtI6AaC6uCqRJCk9K1t+PnW+CBSoPN8A6cQrpM2zrI4EAAAAAAAAZajUmcxx48Zp165dOnz4sB544IFSl83KytK9994ru92uhQsXauzYsXrvvfe0bt06de/eXc8//3yRJNSzzz4rm81W6q2g/v37a+3atVq6dKkuvvhi3XDDDTpy5EhlXhoAVElxdUb+BZJGaZk5tRcM4KmGfSU98Y/VUQAAAAAAAKAMlUoinX/++erQoUO5lp03b562b9+um266SX369MmbHxYWpueff14ZGRmaNGmS23OefPJJbd68udRbQYGBgeratatOO+00jRs3Tna7XRMmTKjMSwOAamez2fISSWmZ2RZHA3iI0FZSqz5WRwEAAAAAAIBSVKqdXUXMnz9fknThhRcWeWzo0KGSpAULFrjNb9asmZo1a1bpbTqdTqWnp5f4eHp6utvjiYmJld4WAJRHgK9D6Vk5Ss+iEgnIc/8CKSdHerWR1ZEAAAAAAACgGDU+MMfWrVslSd26dSvyWMuWLRUcHJy3TGX83//9n5YsWaKYmBitWbNG9957r/bu3atrr722xOeMHj1aYWFhebfw8PBKbx8AyiPAl0okoFh2uzTiL6ujAAAAAAAAQDFqPIkUHx8vybSvK05oaGjeMpWxf/9+DR8+XN27d9cll1yigwcPatGiRTrxxBNLfM5zzz2n+Pj4vFt0dHSltw8AboobFEmSv49DkpSeRRIJKKL5iZLD3+ooAAAAAAAAUEiNt7OraV999VWFn+Pv7y9///yTVQkJCdUZEgAU4apESs+knR1QrOf3Sa81tToKAAAAAAAAFFDjlUiuCqSSqo0SEhJKrFICgLqmhEIkBfiaSqQ0KpGA4jl8pVHxUuMuVkcCAAAAAACAXDWeRHKNhVTcuEexsbFKSkoqdrwkAKiLbCXM9/ehEgkol4dXSdd9aXUUAAAAAAAAUC0kkc455xxJ0pw5c4o89vvvv7stAwB1XglZJB+72d1m5pRUqwRAkmSzSSdfK93+k9WRAAAAAAAAeL0aTyKdd9556ty5s7755hutXbs2b358fLzefPNN+fn56bbbbqvpMIoVGRmp8PBwDR482JLtA/AePg6TXcrKphIJKJdOZ0uhba2OAgAAAAAAwKv5VOZJ48aN0+LFiyVJGzZsyJs3f/58SdKgQYN0zz33mA34+GjcuHEaOnSozj77bA0fPlwhISGaPn26YmJi9O6776pjx45VfyWVEBERoYiICO3du1ft2rWzJAYA9UtJ7ez8HLmVSCSRgPJ7ZLW0Yrz0+3NWRwIAAAAAAOCVKpVEWrx4sSZNmuQ2b8mSJVqyZEnefVcSSZKGDBmixYsXa+TIkZo6daoyMzPVs2dPvfXWWxo2bFglQwcAz2OzFZ9GclUiZWbTzg4oNx9/6fQR0ur/SYc3Wx0NAAAAAACA16lUEmnixImaOHFihZ4zYMAA/frrr5XZHADUeT65lUi0swMq4f6F0uvNrI4CAAAAAADA69T4mEgA4E1KKESSrz13TKQcKpGACvPxk148JD23z+pIAAAAAACAlxkyZIjCw8MVGRlpdSiWqFQlEgCgeCWNieSbW4mUQSUSUDk+/uYGAAAAAABQi6KiotS2bVurw7AMlUgAUAvy29lRiQQAAAAAAACgbvDqJFJkZKTCw8M1ePBgq0MBUE/YSuhn5+vIbWdHJRJQNQ8stjoCAAAAAAAAr+HVSaSIiAhFR0dr/vz5VocCoJ7Lb2dHJRJQJS17SqPipZOutjoSAAAAAACAes+rk0gAUN1KGhPJh0okoHpdP1F6YrPVUQAAAAAAANRrJJEAoBqV0M1OvvbcMZFyqEQCqk1oa6sjAAAAAAAAqNdIIgFALchvZ0clElCtgppbHQEAAAAAAEC9RRIJAKpV8aVItLMDasjdv0unPyR1v8jqSAAAAAAAAOodH6sDAABv4JuXRKKdHVCtGneWhr5hprOzpNeaWBsPAAAAAABAPUIlEgBUo5LGRPLJHRMpkzGRgJrj8JFC21odBQAAAAAAQL3h1UmkyMhIhYeHa/DgwVaHAqCeKCGHJF+f3CRSFu3sgBr14BLp1hlWRwEAAAAAAFAveHUSKSIiQtHR0Zo/f77VoQCo53ztue3sckgiATUqsKHUZYh07XirIwEAAAAAAKjzvDqJBADVrcR2do7cSiTGRAJqR8/rpFHxUodBVkcCAAAAAABQZ5FEAoBa4Osw2aXMbCqRgFp152yrIwAAAAAAAKizSCIBQDWylTAqkm9uJVIWlUhA7Wt3mtURAAAAAAAA1EkkkQCgGpXYzi53TKRMxkQCat/dc6RBT1gdBQAAAAAAQJ1DEgkAagGVSIDFBj0m+TawOgoAAAAAAIA6hSQSAFSjEgqR8pJIjIkEWCQgTHrhgDQq3upIAAAAAAAA6gySSABQjWwl9LPzceS2syOJBFjvwWVWRwAAAAAAAFAneHUSKTIyUuHh4Ro8eLDVoQCo53xzk0hZObSzAyzXIlw6/SGrowAAAAAAAPB4Xp1EioiIUHR0tObPn291KADqubx2dllUIgEeoXVfqyMAAAAAAADweD5WBwAA3sDHnptEohIJ8AwnXyulHJXidkjLP7c6GgAAAAAAAI9EEgkAakFeOzvGRAI8g80mnXa/mR78rJSWII3pZW1MAAAAAAAAHsar29kBQHVzOouvNMprZ5dNJRLgcQIbSY06SKPirY4EAAAAAADAo5BEAoBa4OdjdrcZjIkEeLYGTayOAAAAAAAAwGOQRAKAalRSnZGrEikjO6fEaiUAHuD+RdLQ0VZHAQAAAAAA4BFIIgFALXBVIkm0tAM8Wlgb6fQRVkcBAAAAAADgEUgiAUA1spUw379AEikjm5Z2gMe7N8rqCAAAAAAAACxHEgkAqlFZ7ewkxkUC6oQ2p5BIAgAAAAAAXs+rk0iRkZEKDw/X4MGDrQ4FQD3nsNvksJs6JZJIQB3R5hTppmlWRwEAAAAAAGAZr04iRUREKDo6WvPnz7c6FAD1hLOU4Y78cquRMmlnB9QdbfpbHQEAAAAAALDQkCFDFB4ersjISKtDsYSP1QEAgLfw87ErNTNb6VQiAXVHUBPpoVWSb6C5vd3J6ogAAAAAAEAtioqKUtu2ba0OwzIkkQCglvj5mEok2tkBdUzTrlZHAAAAAAAAYAmvbmcHALXJ1c5u4754iyMBUGnBLayOAAAAAAAAoNaQRAKAWnI4KV2SlJieZXEkACrt8Wjphq+sjgIAAAAAAKBWkEQCgFpydremkqT41EyLIwFQaQ4fKfwKaciLUo9LrI4GAAAAAACgRjEmUhVkZmYqOzvb6jAAeJAcp1NpaWnFPtahUYAk6eu/dmnEWe2rfdsOh0O+vr7Vvl4AxTjnafNzVJi1cQAAAAAAANQgkkiVkJCQoCNHjig9Pd3qUAB4mOysLO3cubPYx47Hm7GQGvrbSlymqvz9/dW0aVOFhobWyPoBAAAAAAAAeA+SSBWUkJCgffv2KTg4WE2bNpWvr69sNpvVYQGw3A5Jkt3hUKdOnYpd4ipbQ03ftFppOfYSl6ksp9OpzMxMxcfHa9++fZJEIgkAAAAAAABAlZBEqqAjR44oODhYbdu2JXkEoAibzaaAgIBiH2sSGiRJ2nc8TQ5fP/k6qndYusDAQIWEhGjv3r06cuQISSSgNlz4urTySyluh9WRAAAAAAAAVLvqPYNZz2VmZio9PV1hYWEkkABUWJfmQXnTS7YdqZFt2Gw2hYWFKT09XZmZmTWyDQAFnPGw9Mga6dk90oi/pFHx0oVvSC17Wh0ZAAAAAABAlZFEqoDs7GxJYuB6AJXi7+OQr8MkoONTay7B49pHufZZAGpBQKjU/EQzfcZD0uVjrI0HAAAAAACgGnh1EikyMlLh4eEaPHhwhZ5HFRKAyjqnezNJ0prdx2tsG+yjAAAAAAAAAFQHr04iRUREKDo6WvPnz7c6FAD1hLOMx10JnolLd9V4LNUtMztHHZ+drY7PzlZ2TlmvFAAAAAAAAEBd52N1AADgTa7u20Z/RB+Uj92m/y7YXiPbyMrKUlzccTXeHyMfn+rbzf8bm5g3vWJXnAZ2blJt6wbqneAWVkcAAAAAAABQZSSRAKAWDenRXJKUlePU6F//qeGtxdXYmr9ftZckElCasLZWRwAAAAAAAFBlJJFgmV27dqlTp066/fbbNXHiRKvDKVXHjh0lmZiB0jjL6PIW6OfQe9f31tLtR2sshuzsbCUlJSo4OEQOh6Na1rl422EdTEjPu38wIa1a1gsAAAAAAADAc5FEQqXdddddmjBhgho3bqz9+/fL39/f6pA8RkpKij777DOtWrVKq1ev1pYtW+R0OrVz5868hJSn2LJli1588UXNmzdPycnJ6t69ux544AE98MADeeP3lMfevXv12muv6ddff1VsbKyaNm2qoUOH6tVXX1W7du1q8BXUPdf2a6tr+9VclUJaWpp27typTp06KSAgoFrWeeeEv3Uw4XDe/QvDadUFlOmc/5P2r5VunCK92tjqaAAAAAAAACqMJBIqJTExUdOmTZPNZlNcXJxmzJihYcOGWR2Wxzh06JCeeuopSVKHDh3UqFEjxcXVXGuxyoqOjtYZZ5yh1NRU3XDDDWrdurVmz56tESNGKDo6Wh9//HG51rN9+3adccYZOnTokC688EINGzZMW7du1aRJk/TLL79o6dKl6tKlSw2/GtQkh909odgoyM+iSIA6ZMjzVkcAAAAAAABQJXarA0DdNHXqVCUnJ+vxxx+X3W7X+PHjrQ7JozRt2lRz5szR0aNHtWvXLp166qlWh1SsBx98UPHx8ZoxY4a++uorvfXWW1q9erXOOussffLJJ1q2bFm51vPoo4/q0KFDGjNmjH7//Xe98847mjFjhqZOnapDhw4pIiKihl8Japq9UFVaRlaORZEAAAAAAAAAqC0kkVAp48ePl4+Pj5555hkNGTJEf/75p2JiYopdNjs7W2+99Za6du2qgIAAde3aVaNHj1ZOTvEnoaOionTXXXepR48eCg4OVnBwsPr376+xY8cWu7zNZtPgwYO1b98+3XTTTWratKlCQkJ06aWXaseOHZKkzZs366qrrlLjxo0VEhKi6667TgcPHqzw6z5+/Ljuv/9+tWzZUgEBAerbt6+mTJlSZLng4GBdcMEFatzYc9sXbdmyRQsXLtSQIUN08cUX58338/PTa6+9Jkn64osvylxPWlqafv/9d7Vo0UIPP/yw22PXX3+9+vTpo99//z3vb4G6qXAl0qfzt1sUCVBHnXK7+XnyddLN30tP8z8EAAAAAAA8H+3sqoHT6VRqZrbVYZQp0NdRoTFuShIdHa2//vpLl1xyiVq0aKHbbrtNf/75pyZMmKBRo0YVWf6+++7Tl19+qU6dOikiIkJpaWl6//33tXTp0mLX/9Zbb2nbtm0aOHCgrr76ah0/fly//fab7r//fv3777967733ijzn2LFjGjRokFq2bKnbb79dW7Zs0c8//6x//vlHM2fO1FlnnaV+/frprrvu0qpVqzR9+nTFxcVp3rx55X7dGRkZOv/885WUlKRbb71VycnJmjZtmm666SYdOXKkSALF082fP1+SdOGFFxZ5bNCgQQoKCtKCBQvKXM/Ro0eVlZWlDh06FPv+6tSpk9auXauoqCh17ty5ynF7PqfVAdSIwkmkbYeSLIoEqKMuHyOd8bDUuItkz72Gp8cl0r+/WBsXAAAAAABAKUgiVYPUzGyFv/y71WGUKfrVoWrgV/U/uat13a233ipJuuaaazRixAhNmDBBL7/8suz2/AK3+fPn68svv1Tv3r21ZMkSBQUFSZKef/559enTp9j1f/bZZ+rUqZPbvKysLF1yySUaM2aMHn30UbVv397t8fXr1+vxxx/X+++/nzdvxIgR+uyzz3TWWWdp1KhRevTRRyWZpN9ll12mX375RatXr9Ypp5xSrtd94MABdevWTUuXLpWfn1/e6+jbt6+efvppXXPNNWrTpk251lUR8+fPz0v4lEfHjh11xx13lLnc1q1bJUndunUr8pjD4VCnTp0UHR2trKws+fiU/L5p1KiRHA6HYmJi5HQ6iySSdu7cKclUPqHuKpxEAlBBNpvUtND+9povpJ0LpIYdpM/PtCYuAAAAAACAUpBEQoVkZmbqq6++UmhoqK666ipJpnXb1VdfrcmTJ2vu3LlulS3/+9//JEkvv/xyXgJJktq0aaNHH31UL730UpFtFE4gSZKPj48eeOAB/fHHH4qKitLtt9/u9nhwcLBef/11t3k33nijPvvsMzVp0kSPPPJI3nybzabhw4frl19+0bp168qdRJKkN998My+BJElt27bNex3ffvutnnzyyXKvq7zmz5+vV155pdzLn3POOeVKIsXHx0uSwsLCin08NDRUOTk5SkxMVKNGjUpcT4MGDXT22WcrKipKn376qdv4Rz/88IPWrl0rybQCRN3lKJQcvHFA+xKWBFBu/sHSCZea6danSPtXWxsPAAAAAABAISSRqkGgr0PRrw61OowyBfo6qryOmTNn6vDhw7r77rsVEBCQN/+2227T5MmTNX78eLck0rp16yRJZ511VpF1FTdPkhITE/Xuu+9qxowZ2r59u5KTk90e379/f5HndOvWTQ0aNHCb16pVK0lSr169ilTHuB4rbl0l8fHx0emnn17i61izZk2511URo0aNKrZNoCf54IMPNGjQID300EP66aef1KtXL23btk0zZ85Ur169tH79ercKNdQ99kKVSEF+Vd+fACjgnj+lpWOkuaOsjgQAAAAAACAPSaRqYLPZqqVNXF3gamV32223uc0/77zz1KZNG82cOVNxcXFq3LixJFPtYrfb1bRp0yLratGiRZF5GRkZGjx4sFavXq2+ffvq1ltvVZMmTeTj46Ndu3Zp0qRJSk9PL/K80NDQIvNcLdhKeywzM7Osl5ynadOmxSZCXK/DVdlTV7gqkEqKOyEhQTabTSEhIWWuq3fv3lqxYoVGjhypqKgoRUVFqWvXrvrvf/+r48eP6+mnn1bz5s2rNX7UrsKVSJnZORZFAtRTdrt0+kNS8hFp2SdWRwMAAAAAACCJJBIqYM+ePZozZ44k0zKtJJMnT85rHxcWFqacnBwdOXJEzZo1c1vu4MGDRZ47c+ZMrV69WnfffbfGjRvn9ti3336rSZMmVfVlVNqRI0eUk5NTJJHkeh0ltYWrqpoaE8k1FpJrbKSCsrOztXPnTnXq1KnU8ZAKOuGEEzR16tQi812x9O/fv1zrqeucTqsjqBkOh3sSKSO7nr5QwEoOX2nwsySRAAAAAACAxyCJhHKbOHGicnJyNGjQIPXo0aPI41lZWZo0aZLGjx+fl0Tq3bu3Vq9erUWLFumaa65xW37RokVF1rF9+3ZJ0pVXXlnkseKWr01ZWVlatmyZzjzTffBzV1x9+/atke3W1JhIrkTgnDlz9Oyzz7o9tnjxYiUnJ5eaLCyPxMRE/fTTT2rSpIkuuOCCKq0L1ipciZRFJRIAAAAAAABQ75FEQrk4nU5NmDBBNptNkyZNUufOnYtdbsuWLVq2bJlWrlyp/v3769Zbb9WECRP06quvaujQoQoKCpIk7du3T2PGjCny/A4dOkgySYzLL788b/6CBQv0xRdf1MArq5jnn39ef/zxh/z8/CRJe/fu1ZgxY+Tv76/hw4fXyDZrakykHj166Oyzz1ZUVJR+/fVXXXzxxZJMS8GXXnpJknTPPfe4PefIkSM6cuSImjZt6taiMDU1Vb6+vm5VS+np6br77rsVFxenMWPGuI2hhbrHUWhMpLjkDIsiAeo5v2CrIwAAAAAAAMjj1UmkyMhIRUZGKiODk6FlmTdvnnbu3KlzzjmnxASSJN15551atmyZxo8fr/79+2vIkCG68847NWHCBPXs2VNXX3210tPTNXXqVA0cOFA///yz2/Mvv/xydezYUW+//bY2btyok08+Wf/++69+/vlnXX311fr+++9r+qWWqFWrVkpOTlavXr10+eWXKzk5WdOmTdPRo0f10UcfqU2bNm7LP/XUUzpy5IgkacOGDXnzgoPNCcJ77rlHgwYNqt0XUcinn36qM888U1dddZWGDRumVq1aafbs2dq0aZMeeughnXHGGW7Lf/LJJ3rllVc0cuRIt8TWqlWrdM011+iCCy5Qu3btlJCQoNmzZ2v37t2699579fDDD9fyK0N1sxeqRPrzn0MWRQLUczabdOuP0sJ3pZglVkcDAAAAAAC8nFcnkSIiIhQREaG9e/eqXbt2Vofj0caPHy9JZbZJGzZsmB599FFNmTJF77//vgIDA/XFF1+oe/fu+uKLL/TJJ5+obdu2euKJJ3TDDTcUSSIFBwdr3rx5evrpp7Vw4ULNnz9fJ510kr7++mu1aNHC0iSSn5+f/vjjDz377LP66quvdPz4cZ1wwgn6+OOPdeONNxZZ/vvvv1dMTIzbvOnTp+dNDx482PIk0kknnaTly5frxRdf1OzZs5WcnKzu3bsrMjJSDz74YLnX0759ew0ePFiLFi3SwYMH1aBBA51yyil6//33de2119bgK0BtKZRDAlCTupwrdR4ijW4nZSRaHQ0AAAAAAPBiNqezvg4DX36uJNKePXvUtm3bEpdLS0vTzp071alTJ1pzAXDT8dnZkqSmwX5a+aK14z/VxL7q9Z+jNW7xTrd5u/5zabWsG0AJstKlnGzpzVb58waOkP761LqYAAAAAAD116h4qyPwKOXNG9R3dqsDAAB4PnuhMZH6tm9oTSCAN/Hxl/wa5N8fOEIa+qZ18QAAAAAAAK9DEgkAqlF9re0s3M0uJ6eevlDAkzXtRm9JAAAAAABQq7x6TCQAQDkVOm+dXV+zZYAnum6CtCNK6nur1ZEAAAAAAAAvQxIJAFAme6Hqh+wciwIBvNHJ15gbAAAAAABALaOdHQCgTLSzAzxI485WRwAAAAAAALwESSQAQJkKD8NCOzvAQq16Wx0BAAAAAADwEiSRAKAa1dcx7wu3s6MSCbBQk2750w+vti4OAAAAAABQ7zEmEgCgTIVzY1QiARYa9LiUmSKdcJnUpIvkGyRlJlsdFQAAAACgLgtqbnUE8FAkkQAAZStUiZRNJRJgHb8G0tA38u8HNZWOk0QCAAAAAFRBo45WRwAPRTs7AKhG9bVAx16oFIl2doAH8Q20OgIAAAAAAFBPkUQCAJTJVqih3f74NIsiAQAAAAAAAFBbSCIBAMpkKzwoEgDP4RdUdF5Yu9qPAwAAAABQd3HyByUgiQTL7Nq1SzabTXfccYfVoZSpY8eO6tixo9VhAJYp3M4OgAe56nOp2YnStePz5/W+0bp4AAAAAABAvUESCZV21113yWazqUmTJkpPT7c6HI+SkpKi9957TzfddJNOOOEE2e122Ww27dq1y+rQitiyZYtuuOEGNW3aVIGBgerdu7c+++wzOSs4uM/evXt1//33q3379vLz81Pr1q115513as+ePcUun5OTo08++USnnHKKGjRooNDQUJ199tmaNWtWdbwsy9TXizZsxbwwxkUCPESz7lLEX1LP6/Ln2fiKBwAAAACAt/vggw900kknKTg4WA0bNtS5556r5cuXV2gdnGFApSQmJmratGmy2WyKi4vTjBkzrA7Joxw6dEhPPfWUpkyZorS0NDVq1MjqkIoVHR2tAQMGaObMmbr44ov1yCOPKDs7WyNGjNAjjzxS7vVs375d/fr109ixY3XiiSfq0Ucf1YABAzRp0iT1799f27dvd1ve6XTqhhtu0MMPP6yEhATdfffdGj58uP79919deeWV+uSTT6r7paKKikuOkUICPFjTblZHAAAAAAAALNahQwe9//77WrdunZYuXaquXbtq6NChOnr0aLnXQRIJlTJ16lQlJyfr8ccfl91u1/jx48t+khdp2rSp5syZo6NHj2rXrl069dRTrQ6pWA8++KDi4/+fvTuPj+l6Hzj+mex7kAghJCG2WGIJYheKovZdrVGtWr8IVVWU/oouSi1tEZHaW1qprSiJfWvttLaK2mMNgsj2+2OakTGTZCaZyUyS5/16zSuZe88999yZO3fu3Oee58SxYcMGli9fzqxZszh27BiNGjVi/vz5HDx4UKd6Ro0aRWxsLHPnzmXbtm188cUXbNiwgbVr1xIbG8uwYcPUyq9fv57169fToEEDTp8+zbx581i0aBFnz57F29ub0NBQs+y1VZAp0NITSc/eakKIXNBtGbSYruyV1GqGeoo7IYQQQgghhBAiQ/k0vU4B17lzZ1q1akXZsmXx9/fnyy+/JC4ujjNnzuhchwSRRLaEhYVhZWXF+PHjCQ4OZufOnVy9elVr2eTkZGbNmoWfnx92dnb4+fkxY8YMUlJStJaPiooiJCSEChUq4OTkhJOTE4GBgSxatEhreYVCQdOmTblx4wa9e/fG3d0dZ2dn2rZtyz///APAX3/9RceOHSlSpAjOzs507dqVO3fu6L3djx494r333qN48eLY2dlRo0YNVq9erVHOycmJFi1aUKRIEb3XkVsuXLjAnj17CA4OpnXr1qrpNjY2TJ8+HYDFixdnWc+LFy/Ytm0bxYoVY8SIEWrzunXrRvXq1dm2bZvqvQCIjIwEYOLEidjb26umu7u7M3r0aBISEggPD8/R9plKfo2raO2JlE+3VYg8rXInaPBfT9J6Q5XBpCJlTdsmIYQQQgghhBDmr3SQqVuQb6xYsYL33nuPwMBAbG1tUSgULFu2LNNljh49Sps2bShUqBCOjo4EBQXx448/GrRdL1++ZNGiRRQuXJiqVavqvJwEkYTezp07x6FDh2jZsiXFihWjX79+pKSkZHjR/91332XChAmkpKQwbNgwWrVqxezZsxk1apTW8rNmzWLPnj3Url2b4cOH06dPH+7du8d7773H2LFjtS7z8OFDGjZsyJUrV+jfvz9NmzZly5YttGjRgjNnzlC/fn2ePn1KSEgIgYGBrF+/nl699Bt0/OXLl7zxxhvs3r2bvn37EhISwrVr1+jduzfz5s3Tqy5zEB0dDUDLli015jVs2BBHR0d2796dZT33798nKSkJb29vrePm+Pr6AsrgYJrbt2+rzdNWfteuXVlvhMg1FlqCSNITSYg8IuQ36LAQ6g03dUuEEEIIIYQQQpirGn1M3YJ8Y9KkSSxatIirV6/i6emZZfmoqCgaNGjAvn376N69O0OGDOH27dv06NGDr776Ksft2bt3L05OTtjb2/P111+zY8cOvTo/WOW4BUJ5O37iM1O3ImvWDtq7E+gpLXVd3759AWWXuKFDhxIeHs7kyZOxsHgVm4yOjmbp0qUEBASwf/9+HB0dAWUPlOrVq2ut/9tvv9UILiQlJdGmTRvmzp3LqFGjKF26tNr8U6dOMXr0aGbPnq2aNnToUL799lsaNWrE1KlTVUGr1NRU3nrrLbZs2cKxY8eoWbOmTtt969YtypUrx4EDB7CxsVFtR40aNRg3bhydO3emZMmSOtWlj+joaFXARxc+Pj4MGDAgy3IXL14EoFw5zXEzLC0t8fX15dy5cyQlJWFllfGhonDhwlhaWnL16lVSU1M1AklXrlwBlD2f0ri7u6vmVapUKcvywvS0pbOTGJIQeYSTB9R4G3Z9+mpap0Vw5zQcyHs3QQghhBBCCCGEMAKF9DcxlCVLllCuXDm8vb2ZOXMmH374YYZlk5KSGDx4MBYWFuzZs0d1zXzy5MnUqVOHiRMn0rVrV7y9vVXLTJgwgVmzZmXahtR0F+4CAwM5ceIE9+/fZ/HixXTv3p3Dhw+rrtFmRYJIhpD4DD4rYepWZG3iTbBxzFEViYmJLF++HBcXFzp27AgoU7d16tSJFStW8Pvvv6v1bPnhhx8A5U6fFkACKFmyJKNGjeLjjz/WWIe23ilWVlYMGTKEHTt2EBUVRf/+/dXmOzk58emnn6pN69WrF99++y1ubm6MHDlSNV2hUNCzZ0+2bNnCyZMndQ4iAXz22WeqABKAl5eXajvWrFmTYU+pnIiOjuaTTz7RuXyTJk10CiLFxcUB4OrqqnW+i4sLKSkpPHnyhMKFC2dYj4ODA40bNyYqKoqFCxeqjX/0888/c+LECUCZCjBN69atWbNmDTNnzqRZs2bY2dkByl5Nc+bM0SgvTE9rOjskiiREnlWtOyS1lyCSEEIIIYQQQghhYG+88YbOZXft2sXly5cZOHCgWqcLV1dXJk6cyIABA4iIiGDy5MmqeWPHjtXp+m8ae3t7/Pz88PPzo27dupQrV47w8HDGjRun0/IFOoi0YMECFixYwMuXL03dlDwjMjKSu3fvMmjQINWFf4B+/fqxYsUKwsLC1IJIJ0+eBKBRo0YadWmbBvDkyRO+/PJLNmzYwOXLl4mPj1ebf/PmTY1lypUrh4ODg9q0tK6C1apV0+gdkzZPW10ZsbKyol69ehlux/Hjx3WuSx9Tp05l6tSpRqnbUL7++msaNmzI8OHD2bhxI9WqVePSpUtERkZSrVo1Tp06pdZDrXfv3ixbtoyoqCiqVq3Km2++SWJiIhs2bKBYsWIAauXzEgN09jNL2lIVpkgMSYi8pXi1V/8rFGBtD1b2kPTcdG0SQgghhBBCCGEeipQxdQvM1pMnT3j8+LHqua2tLba2tgapO7MhR1q1agWgMeRI0aJFKVq0aLbXmZqaSkJCgs7lC3QQadiwYQwbNozr169TqlSp7Fdk7aDs5WPurB2yLpOFtFR2/fr1U5vevHlzSpYsSWRkJA8ePFDlVIyLi8PCwkJr17i0YEF6L1++pGnTphw7dowaNWrQt29f3NzcsLKyIiYmhoiICK07uIuLi8a0tBRsmc1LTEzMapNV3N3dtQY20rYjrWdPXpHWAymjdj9+/BiFQoGzs3OWdQUEBHD06FGmTJlCVFQUUVFR+Pn58f333/Po0SPGjRuHh4eHqryVlRVbt25l5syZrFq1ikWLFuHq6kqnTp0IDQ2lfPnyauWF6WmLjcmYSELkMZXaQYcFUKLGq2nDj8DFHbB5jOnaJYQQQgghhBDC9PLrndEG4O/vr/Z8ypQpBrvpP7MhR4oXL46Tk5OqTHZ88MEHtG/fHi8vLx48eMDChQu5fv06Xbp00bmOAh1EMhiFIsdp4vKCa9eusX37dkCZMi0jK1asUKWPc3V1JSUlhXv37mlER+/cuaOxbGRkJMeOHWPQoEEsWbJEbd6aNWuIiIjI6WZk271790hJSdEIJKVtR0Zp4XLKWGMipR2YtB2EkpOTuXLlCr6+vpmOh5RexYoVWbt2rcb0tLYEBgaqTbe1tWXKlClMmTJFbXratr5eXpiW1nR2KbnfDiFEDigUmgOlFioNtQdJEEkIIYQQQgghhMjAuXPnKFmypOq5oXohgW5DjuSk88LNmzfp2bMnsbGxFClShNq1a7N3716NceozI0EkobNly5aRkpJCw4YNqVChgsb8pKQkIiIiCAsLUwWRAgICOHbsGHv37qVz585q5ffu3atRx+XLlwHo0KGDxjxt5XNTUlISBw8epEGDBmrT09pVo0YNbYvlmLHGREoLBG7fvp0JEyaozdu3bx/x8fGZBgt18eTJEzZu3IibmxstWrTQaZmVK1cC0LNnzxyt21Tya+ccCy1RJBkTSQghhBBCCCGEEELkd87OzlqzXeUFy5cvz3EdEkQSOklNTSU8PByFQkFERARlymjPkXnhwgUOHjzIH3/8QWBgIH379iU8PJxp06bRqlUrHB2VPbZu3LjB3LlzNZb39vYGlEGMdu3aqabv3r2bxYsXG2HL9DNx4kR27NiBjY0NANevX2fu3LnY2toaLehhrDGRKlSoQOPGjYmKimLr1q20bt0aUKYU/PjjjwF455131Ja5d+8e9+7dw93dXS1F4fPnz7G2tlbrtZSQkMCgQYN48OABc+fOVRtDC5Tp8l4/+K5bt46lS5dSu3ZtjaCjMC1tPZFkTCQhhBBCCCGEEEIIIbJPlyFHChcunJtN0iBBJKGTXbt2ceXKFZo0aZJhAAlg4MCBHDx4kLCwMAIDAwkODmbgwIGEh4dTtWpVOnXqREJCAmvXriUoKIhNmzapLd+uXTt8fHz4/PPPOXPmDFWqVOH8+fNs2rSJTp06sW7dOmNvaoY8PT2Jj4+nWrVqtGvXjvj4eH788Ufu37/PN998o9alESA0NJR79+4BcPr0adU0JycnQBmgadiwYe5uxGsWLlxIgwYN6NixIz169MDT05PNmzdz9uxZhg8fTv369dXKz58/n08++UQj7+eff/5J586dadGiBaVKleLx48ds3ryZf//9l8GDBzNixAiNddetW5dSpUpRqVIl7OzsOHLkCNHR0ZQpU4affvoJS0tLY2++0IOMiSSEEEIIIYQQQgghhGGlH3KkVq1aavNu377N06dPqVOnjimapiJBJKGTsLAwgCzTpPXo0YNRo0axevVqZs+ejb29PYsXL6Z8+fIsXryY+fPn4+XlxZgxY+jevbtGEMnJyYldu3Yxbtw49uzZQ3R0NJUrV2blypUUK1bMpEEkGxsbduzYwYQJE1i+fDmPHj2iYsWKzJs3j169emmUX7duHVevXlWbtn79etX/TZs2NXkQqXLlyhw+fJhJkyaxefNm4uPjKV++PAsWLOD999/XuZ7SpUvTtGlT9u7dy507d3BwcKBmzZrMnj07w0HaevTowc8//8yhQ4dITEzE19eXSZMmMW7cuDzbPRTy7xiECm3p7CSGJIQQQgghhBBCCCFEtjVp0oQZM2awfft2jUxX27ZtU5UxJUVqqlwGvH79OqVKleLatWt4eXllWO7FixdcuXIFX19fjdRcQoiCzWfCZgDcnWz4Y5Ju4z8ZizGOVSsPX+WjX86oTTsysTkeLnIsFCJfWNwMbvyZ/eUdPSA+1nDtEUIIIYQQQgiRu6ZqT6dWkOkaN8jMzJkz+fDDDwkPD9faQSMpKYkKFSpw48YNDh06RPXq1QFlers6deoQExPD+fPn8fHxyf6G5JD0RBJCCJElhZaEdjImkhD5yNvr4FwkbPpfxmWKV4Xbp7XP86oN5zcbpWlCCCGEEEIIIUResmTJEvbt2we8GuZkyZIlREdHA9CwYUPVWPRWVlYsWbKEVq1a0bhxY3r27ImzszPr16/n6tWrfPnllyYNIIEEkYQQQujAQkuavlQkiiREvuFQBKp2yzyI5N8x4yCSEEIIIYQQQgiRTwUHB2Ntbc2wYcMYNmxYluX37dtHRESE2rT9+/ezf/9+1fO0IFJa/fv27WPKlCmsXbuWxMREqlatyqxZs+jRo4fhNiSbJIgkhBAiS9rGepKeSEIIIYQQQgghhBAiv4uKitIrnd2yZctYtmyZXuuoU6cOW7du1bNlucPC1A0QQoj8JL+OMqc1nZ1EkYQoYDL7zMvxQAghhBBCCCGEyI8kiCSEECJL2noiCSGEEEIYXd9fTN0CIYQQQgghCjQJIgkhhAHl12CLQsuGpeTXbldCCO3KtVT+LeQN1Uyfk1kIUUAU9jF1C4TIvsK+pm6BEEIIIUSOSRBJCCFElrTFxiSbnRD5jZYPdZewV/87e8L4KzD8D+j0/WsF0x0lWkyDVjOM0kIhhBBCCCGEEELkLitTN0AIIYT5s9Byy0Gq9EQSIn8rXQ/KNlOf5lAkg8LpjgeBIfDoX6M1SwghhBAi3xiwBXwamLoVOjt4+T69Fh+ikIM1hyc2x9bK0tRNythUV1O3QAgh8g3piSSEECJLCi19kaQnkhBCu3ya11MIYVw+jXJnPe3n5c56hBAiH6rjWwRPVzsePUtk86lbpm6OEEKIXCJBJCGEEFnSNtaT9EQSooCzkA7tQog8yK+FqVsg8hNrR1O3QIhcZWmhoE+QNwDzoy6RLHcWCiFEgSBBJCGEMKD8GldRaIki3X2aYIKWCCHMxpszTd0CUVDYupi6BSJfyacna8I0qnYxdQuEyHX96nlTyMGaf+7GE3nihqmbI0TB0fcXU7egQAsODsbf358FCxaYuikmIUEkIYQQWdKWnGr1kWu53g4hRC6ycVLvhqh47bTRySN32yMKrnGXTd0C47ErZOoWCCFywsLa1C0QeZ22lA9mztnOmncblwFg7s6LJCanmLhFQhQQr49XK3JVVFQU586dY9iwYaZuiklIEEkIIQwoD/4G0El+3S4hRDrpL4QV8oa2X4F9YQjoDdV6gKO7enn38q/+L1JGfZ70HBGGZGVj6hYYj3s5U7dACCGE0Fv/ej64O9lw9f4z1v953dTNEUIIYWQSRBImExMTg0KhYMCAAaZuSpZ8fHzw8fExdTOEEEII47G2g/bzoe1s+N8pKKzMd0+nb6HzIs3yHpWg908wZB8EhqjPK1QK3pyluUxBvHsucJCpWyCEEEKI3OJewdQtyBWOtla839QPgHm7LvEySXojiVzQfLKpWyBEgSVBJJFtISEhKBQK3NzcSEiQsVHSe/bsGV999RW9e/emYsWKWFhYoFAoiImJMXXTNFy4cIHu3bvj7u6Ovb09AQEBfPvtt6TqObjP9evXee+99yhdujQ2NjaUKFGCgQMHcu2a9pRnKSkpzJ8/n5o1a+Lg4ICLiwuNGzfm119/zXAdhw8fpkOHDri7u2Nra0u5cuWYPHkyz58/16utQn/axku1s5KvECHynZp9obYeQY/yLaF4VbBM14spreti0BDN8m9MzVHz8qT6w9Wf2zhDg1GmaYsQQgjDyqq7vnTnL3iGHzF1C3LN23VL4+Fsy41Hz/nxD0l1LnJBo7GGr7P2O4avU4h8SK4Aimx58uQJP/74IwqFggcPHrBhwwZTN8msxMbGEhoayurVq3nx4gWFCxc2dZO0OnfuHHXq1CEyMpLWrVszcuRIkpOTGTp0KCNHjtS5nsuXL1OrVi0WLVpEpUqVGDVqFHXq1CEiIoLAwEAuX1YfyyA1NZXu3bszYsQIHj9+zKBBg+jZsyfnz5+nQ4cOzJ8/X2MdP//8Mw0bNmTbtm20atWK4cOH4+bmxvTp02nRooUEMk2gbTVPUzdBCCGEyNv0vGlHCCGEMBd21pa837QsAN9GX5axkUzFseh/f2W8UiGE8UgQSWTL2rVriY+PZ/To0VhYWBAWFmbqJpkVd3d3tm/fzv3794mJiaF27dqmbpJW77//PnFxcWzYsIHly5cza9Ysjh07RqNGjZg/fz4HDx7UqZ5Ro0YRGxvL3Llz2bZtG1988QUbNmxg7dq1xMbGagw6t379etavX0+DBg04ffo08+bNY9GiRZw9exZvb29CQ0PVem09f/6cIUOGoFAo2L9/PytXruSrr77i4MGDDBs2jP379/P1118b8qURr0nfM61kIXsArC3lK0QIIbKksDR1C4RZy+NBpPojTN0CIUxLAsGGUbyqqVtgOhZWpm5BjvSsXRp3JxtuPHrOhuM3TN2cgmnAFqjaDQZsMnVLDK90fVO3QAjxH7kCKLIlLCwMKysrxo8fT3BwMDt37uTq1atayyYnJzNr1iz8/Pyws7PDz8+PGTNmkJKi/S6VqKgoQkJCqFChAk5OTjg5OREYGMiiRVrGYwAUCgVNmzblxo0b9O7dG3d3d5ydnWnbti3//PMPAH/99RcdO3akSJEiODs707VrV+7cuaP3dj969Ij33nuP4sWLY2dnR40aNVi9erVGOScnJ1q0aEGRIkX0XkduuXDhAnv27CE4OJjWrVurptvY2DB9+nQAFi9enGU9L168YNu2bRQrVowRI9QvJHTr1o3q1auzbds21XsBEBkZCcDEiROxt7dXTXd3d2f06NEkJCQQHh6umn7gwAHu3r1Lx44dqVWrlmq6QqHg008/BeC7777TOwWfMZhBE4wi/XZZWig0pgkhhMhAodKmboEQxtNiei6sxNDpwCS9mBDCjJQMNHULcsTexpJBDcsA8O3uy6Roy4MujKtoeeiyBIrmw/G4SpnnDdlCFEQSRBJ6O3fuHIcOHaJly5YUK1aMfv36kZKSonbRP713332XCRMmkJKSwrBhw2jVqhWzZ89m1Cjt4wHMmjWLPXv2ULt2bYYPH06fPn24d+8e7733HmPHas9/+vDhQxo2bMiVK1fo378/TZs2ZcuWLbRo0YIzZ85Qv359nj59SkhICIGBgaxfv55evXrptd0vX77kjTfeYPfu3fTt25eQkBCuXbtG7969mTdvnl51mYPo6GgAWrZsqTGvYcOGODo6snv37izruX//PklJSXh7e6PQkvPb19cXUAYH09y+fVttnrbyu3bt0ql8oUKFKFy4MFevXlULVAnjUQWR8vrd00IIw7F2fPW/9Lx5xbGojIchhBD5mpwPixywdQWLvH9Zrk9QaZztrPjnbjw7/tL/Zl0hTMrS1tQtECJPyNv9Zs1Eamoqz5Oem7oZWbK3std6kV9faanr+vbtC0Dnzp0ZOnQo4eHhTJ48GYt0J0HR0dEsXbqUgIAA9u/fj6Oj8iLTxIkTqV69utb6v/32W41gQVJSEm3atGHu3LmMGjWK0qXV7+o9deoUo0ePZvbs2appQ4cO5dtvv6VRo0ZMnTpVFbRKTU3lrbfeYsuWLRw7doyaNWvqtN23bt2iXLlyHDhwABsbG9V21KhRg3HjxtG5c2dKliypU136iI6OVgV8dOHj48OAAQOyLHfx4kUAypUrpzHP0tISX19fzp07R1JSElZWGR8qChcujKWlJVevXiU1NVVjH7ty5Qqg7PmUxt3dXTWvUqVKepV/XVxcHA8fPlQtU7Zs2QzbmhsKwrXCgrCNQgg9ObrBW3PA0gas7UzdGvMngTYhTEgu+hcIbb+CzUYYgF0YifzAyMuc7azpE+TNt9GXWbznH1pVLm7qJonsqvMeHPneePWXrAU3/jRe/dlR4U04tMDUrRDC7EkQyQCeJz2n7qq6pm5Glg73PoyDtUOO6khMTGT58uW4uLjQsWNHQJm6rVOnTqxYsYLff/9drWfLDz/8AMDkyZNVASSAkiVLMmrUKD7++GONdWjrbWJlZcWQIUPYsWMHUVFR9O/fX22+k5OTKq1Zml69evHtt9/i5ubGyJEjVdMVCgU9e/Zky5YtnDx5UucgEsBnn32mCiABeHl5qbZjzZo1GfaUyono6Gg++eQTncs3adJEpyBSXFwcAK6urlrnu7i4kJKSwpMnTyhcuHCG9Tg4ONC4cWOioqJYuHCh2vhHP//8MydOnACUqQDTtG7dmjVr1jBz5kyaNWuGnZ3yguP9+/eZM2eORvkGDRrg4uLChg0bOH78ODVq1FDNmzx5sur/9MsIw9IWOJJ0dkIINYEDsy5TLIdjHoRsg6WtMi/jGQC3TuZsPcZWbzjsm511OSEKCldJ+ygMyKMy1H7HTIJIEhwRrwnoDSdXmboVBjewvg9he6/wx9WHHPv3ITVLZ3wNQZixNp8bN4jkVdv8gkh5fFwykXuCg4OxtrZm2LBhGmO/FwR5v9+syFWRkZHcvXuXbt26qS78A/Tr1w941UspzcmTyos4jRo10qhL2zSAJ0+eMGXKFAICAnByckKhUKBQKOjSpQsAN2/e1FimXLlyODioB8g8PT0BqFatmkbvmLR52urKiJWVFfXq1ctwO44fP65zXfqYOnUqqampOj/06bVkKF9//TVOTk4MHz6cN998k/Hjx9O5c2e6detGtWrVANR6qPXu3Zvg4GD27t1L1apVGTFiBEOGDKFy5cq4uLholHdycmL27NkkJiZSr149+vTpQ2hoKPXr1+e7776jYsWKGssIw0r/GUr7T2JIQgi95cZxus8vxl9HelW76VdeoVD23BLatZtr6hYIU3hjiqlbIIQQuaOF7jeI5iUeLna0CygBQNg+zQwiBV5eSJnmVExzmruBx1kyxJ2o5TSHZBAiN0RFRXHu3LkCGUAC6YlkEPZW9hzufdjUzciSvZV9jutICxKlBY3SNG/enJIlSxIZGcmDBw8oUqQIoOztYmFhoUpHll6xYppfUC9fvqRp06YcO3aMGjVq0LdvX9zc3LCysiImJoaIiAgSEhI0lksLPKSXloIts3mJiYlZbbKKu7u71iBF2nak9ezJK9J6IGXU7sePH6NQKHB2ds6yroCAAI4ePcqUKVOIiooiKioKPz8/vv/+ex49esS4cePw8PBQlbeysmLr1q3MnDmTVatWsWjRIlxdXenUqROhoaGUL19erTzAoEGDKFGiBJ9//jmRkZEkJydTu3Ztdu7cyaxZs/j77781lhGGkz4Ma4i0mEIIoWbCNVjWBm6fNnVL9Bc0FE7/ZOpW5B+1BsBG7eNmCiGETnL1XDUH63p7HazsarimiLzBKf/+Zh3U0Jf1x67z25nb3Ip7jqdrzq9BiVw0+qzmNN/GcO+8AVeiZxCp7vuw/7UbjHr/CJ8U0l6+bDO4vEv7PCFEjkgQyQAUCkWO08TlBdeuXWP79u2AMmVaRlasWKFKH+fq6kpKSgr37t2jaNGiauXu3NEccDEyMpJjx44xaNAglixZojZvzZo1RERE5HQzsu3evXukpKRoBJLStiOjtHA5ZawxkdLGQkobGym95ORkrly5gq+vb6bjIaVXsWJF1q5dqzE9rS2BgYFq021tbZkyZQpTpqjfeZq2ra+XB2UavNatW2tM79u3LxYWFnqlJhT60Z7OTvoiCSEMxE7zho88wzMgkxR6EnQXIsfk5hWRX5Vrof8yH9+D6Zo3aAphDvxLuBBUpgiH/nnAykP/EtrKwL1YhH76b4SIdrqXt7Q2Xluyw60cuHhqTs/svMDeTNIotvw/2P6RqVshhEHlqyDS+++/z3fffce8efMYPny4qZuT7yxbtoyUlBQaNmxIhQqaJwNJSUlEREQQFhamCiIFBARw7Ngx9u7dS+fOndXK7927V6OOy5cvA9ChQweNedrK56akpCQOHjxIgwYN1KantSv9OD2GZKwxkdICgdu3b2fChAlq8/bt20d8fHymwUJdPHnyhI0bN+Lm5kaLFrr9SFq5ciUAPXv21Kn8/v37iYmJoU2bNkYL5AmwSHeiZvHfvxJCEkKYpdy+4GxhCe/uzviOSCGEEPlXbn/nmNtFXiFe07+eD4f+ecCao/8yorkftlaWpm5SwaXI46+9RXbabyY3nmSr7UKYt3wzgMimTZs4ePAgJUqUMHVT8qXU1FTCw8NRKBRERESwZMkSjceyZcuoV68ep06d4o8//gCUPUQApk2bRnx8vKq+GzduMHeuZs57b29vQBnESG/37t0sXrzYWJuns4kTJ/Ly5UvV8+vXrzN37lxsbW11Dnroy1hjIlWoUIHGjRsTFRXF1q1bVdNfvnzJxx9/DMA777yjtsy9e/f4+++/uXfvntr058+fk5SUpDYtISGBQYMG8eDBAyZPnqw2hhYo0+W9bt26dSxdupTatWtrBB21lb958ybvvPMOVlZWTJ8+XYetNp6P2lQC4ItuASZth7GopbMzlxMzIYQwF9JTQoi8wcKAF+DtixiurvwioJepW5C/dVtm6hYYT0H9Hs2Hm/2GfzGKu9hx7+lLfjtz29TNEQWO3OoqhLHki55Id+7c4f3332fLli20a6dHV02hs127dnHlyhWaNGlCmTJlMiw3cOBADh48SFhYGIGBgQQHBzNw4EDCw8OpWrUqnTp1IiEhgbVr1xIUFMSmTZvUlm/Xrh0+Pj58/vnnnDlzhipVqnD+/Hk2bdpEp06dWLdunbE3NUOenp7Ex8dTrVo12rVrR3x8PD/++CP379/nm2++oWTJkmrlQ0NDVcGW06dPq6Y5OTkBygBNw4YNc3cjXrNw4UIaNGhAx44d6dGjB56enmzevJmzZ88yfPhw6tevr1Z+/vz5fPLJJ0yZMoWpU6eqpv/555907tyZFi1aUKpUKR4/fszmzZv5999/GTx4MCNGjNBYd926dSlVqhSVKlXCzs6OI0eOEB0dTZkyZfjpp5+wtFS/c+Obb75hxYoVNGzYEA8PD65du0ZkZCTPnj0jLCzM5KnsBjcuQ7/63gXiTivVbzw5PxNCCJFfFfaFhzIweJYUCmg4GvZ9beqW6KaZDqllavaHYzqk0HZwg+cPct4mc1W8qv7j1LWfBydXG6c9eZUhgyOVO+VseWtHSIzPupwuKrWHv341TF0iX7G2tKBXndJ8/fsFVh7+lw7VS2a9kBCZKd8aLmzNupwu6g2Hg/MNU1dWrOwg6UXurEuIXJCtnkgrVqzgvffeIzAwEFtbWxQKBcuWLct0maNHj9KmTRsKFSqEo6MjQUFB/Pjjj9lZvYaBAwcycuRIqlatapD6hKawsDCALNOk9ejRA3t7e1avXs3z588BWLx4MTNmzEChUDB//ny2bt3KmDFjmDNnjsbyTk5O7Nq1iy5dunD06FHmz5/PzZs3WblyJcOGDTP0ZunFxsaGHTt20KRJE5YvX87SpUvx8vJi1apVWoMk69atIyIigoiICG7evAnA+vXrVdMuXbqU25ugoXLlyhw+fJj27duzefNm5s6di4WFBQsWLOCbb77RuZ7SpUvTtGlT9u7dy9dff83q1avx8/Nj3bp1LFq0CIWWH089evTg9u3bhIeH880333Dnzh0mTZrE8ePHVT3S0qtfvz6lSpVi48aNfPnll/z++++0adOGo0eP0r9//xy9DoaSnwNI2t7DVIkiCSGEyK+GHTZ1C3KuarfcWY8h0uVY2WVdxlAssriP0so2d9ph7hqFmroFeUOBHSO0oG630EWP2qWwUMCRKw+4FPvE1M0RZkWfwHo+7KonRB6WrZ5IkyZN4urVq7i7u+Pp6cnVq1czLR8VFUWrVq2ws7OjZ8+eODs7s379enr06MG1a9cYO3ZsthoPyp4R8fHxOapDZG3VqlWsWrUqy3IuLi48e/ZMbZqlpSUTJkzQGHcHlGnyXufr65thjyNt5bVNA/Dx8clwXtOmTTOcp01MTIzq/++//57vv/9er2XMWYUKFfjpp590Kjt16lS1HkhpSpcurXdQOKO6MtKsWTOaNWum1zqE4RTUDBNCCDPi5mfqFoiCRAIJuadYVajQGg5/lzvrc/KAwBD4Y2nurC+vkpM/IUQ2FXe1o1lFD37/K5a1R6/xUVt/UzfJ9OSYasbkvRFCF9nqibRkyRJiYmK4e/cuQ4YMybRsUlISgwcPxsLCgj179rBo0SK++uorTp48Sfny5Zk4caJGEGrChAkoFIpMHwB///0306dPJyIiAguLfDO8kxBCmB21MZH+OwYX2BsvhRCmUapuPjnwyA9VIdT0WJ75ANQ2ToZf51uZpd8z9GdUPvNCGNTUOAgcZOpWiCx0DywFwM/HbvAyKcXErRFmz8VLy8T8cN4vRP6RrcjLG2+8oTXdlDa7du3i8uXL9O7dm+rVq6umu7q6MnHiRF6+fElEhHrO6bFjx/LXX39l+gA4dOgQd+/exc/PDysrK6ysrLh69SqjRo1SW5cQQoicSZ/OTjUkkpzTCSEMqfUXGc9rFAqtPtOtHrtC2qf7vaF3k3Is7djp3zH31y3yBmN+mTYeZ7y61RhxG5yKg6O78erXxtB3i9d427D1iay1yeT7JCvWDnoukIdOiM25J0Trz03dAmFgwRU9KOpsy/34l+z8646pm1MwGeIzb+zjRtGK0HAMDNxi3PWk985O46+jajfoGm64+hqPh3ejDVSZAqp0MVBdoiAxeved6OhoAFq2bKkxr1WrVgDs3r1bbXrRokWpWLFipg+Ajh07curUKU6cOKF6lChRggkTJmSYDg0gISGBx48fqx5PnkiOViGE0JU5//4UQuRh3vWg7wbN6a6loPnHYF9It3oy6p3u3SC7Lcui18JrnD1f/Z8WIOi2LPNlilXRu0kin7N1zdnygSFQtELmZTosyNk6MmOI9JOT7kLo+ZzXY2qGGDOqIBh3WXNaQC/96+m1BrzrZ78dgSHZX1ab0kGGrS8nzPUOsEl3zet1MiRLG3BwM3UrTMLa0oJutZS9S1YfvWbi1hjQpLvZXFB+RAPgXFz9uVMxeGMKFM6so4KBj11egYat73W131HeAFOxjeHqbPYRlKhhmLomxULnJYapSxQoRg8iXbx4EYBy5cppzCtevDhOTk6qMvoqVKgQVapUUXtYW1vj6emJn1/GP1xmzJiBq6ur6uHvL/lZhRBCV2lBJDP9GSqEyMssbV79/8Ynyr/tv8l8GV3upGuk59iZr6fPsrDWb/nXZRV9r9Y9Z/WL/KHtV6/+z41rTek/b4aW088MgJUR25cZc73Qnt9ZatlnFNm4XKGtnvRyGqDVV6sZubu+TJnhvt1jhek+67nB0oaCHDzoWbs0AHsv3uXmo+cmbo2BZHt/1fL582uhOa1LGHT/IZvrMBCtqeUMJKsbXNRk8tnJ6CYDc/gON+b5lSFY2cidwdkUHByMv78/CxYY8UYsM2b0IFJcXBygTF+njYuLi6pMbvnwww+Ji4tTPc6dO5er6xdCiPwg1RxO0IQQ+VfD/ynvlCvbLOMyngHKH9uZ/didFAvNJ+u37vFXwLO6fsvkhGPRrMuUf9P47TCmCga8G9MU6o80/jpq9tc+PX1wKT/K1QsZBfyiyaRYI68gB69vbp1Xlsvl1KZ2Lrm7vtxQtZth6pkUC5XaaU4vUL1z8/cxqbSbA3V8i5CaCr+evGnq5uScg4HTq5aoAf87oz6talfw72DY9ejLpcSr/036mz+TdTf9MHtVSvBE5EBUVBTnzp1j2LBhpm6KSRg9iJTbYmJiGD58eKZlbG1tcXFxUT2cnZ1zqXVCCJH3Kf77sSMhJCGE0VnZZj7f0jbrH4NZ1aF1mWzcoVemqfKv0Qb7LqA/evP74Om67GeW2diHRQ7k4zMcfY6HefZmoSw+U5Xa5+768qpiVTOepy3wkx3Z+X7Ok/LqZ8kwOtUoCcCaI/+SmJxi4tbkUKFShq8z/Y1EDf5n+Pp1VWtALq3IUMfMgv25EsIUjB5ESuuBlFFvo8ePH2fYS0kIIYT5kZt3hBA6G3cZBkfBkH0w5m8dFjDQD8KJN6HZx/DhdcPUp4ueq2HgVmg0Rv9lze1ibdFKMHiXqVuh1HqWqVuQj8kXuihAileFQTtMf4d/XvDenoxTTg3a8V8gTo4fOjG373cTaBdQAjdHG2LuP2PRnn9M3RzteqyA9w+aZt3WdvD+AeU4hfr2nDek1l+keyL7rRBCk9GDSGljIWkb9+j27ds8ffpU63hJQgghzFPaT0b5TSSEyJKjO5Ssqbx45+JpgAp1PPDYOELjULDNxd7mNg7KwdwtLHNvncbi4AYla5m6FUpZjW+S1+Xky7R6n+zNE7mrTBPlXwc307ajoLMrBKXqyN1QuvAMyHie2byG8kMkr3CyteKD1hUB+GLbeX45nos3+OjK1QuK6TBWesPRxll/scpQo4/hzyE99Bj/3VRjk5nF8cTQ8uM2CZELQaQmTZQnzdu3b9eYt23bNrUyQgghzJPa+EeqEz358SaEEAZhbj+gza09eYF/B81xDXSiy3dpBmU6zIfRZzWn2xUCLxMHAV/fh7xq534b7Itkc0ED7//OJSD0kvb3ypDKtTRu/SYhxyLzVYB+B4w+C07FTN2KPK1bLS8G1PcBIPSnU2w/e9u0DcqOwr55ryejfSFTtyCHcvE7wE4yZAmRFaMHkZo3b06ZMmVYtWoVJ06cUE2Pi4vjs88+w8bGhn79+hm7GVotWLAAf39/mjZtapL1CyFEXiQ/54UQZiO3gh3ZXY8h2yeBHRPT4YKpIcdK0OX6rEKhvHv6dc7Z6PVXJlj/ZfRh7ZDxvHzftTkVnIqCtb3xVqGwBEcP49UvMmao8YGE+XL1Qn4B5YxCoWDyW/50qlGS5JRUhq06xo5zd0zdLP0U9jZ8nXnh3M4c2li9t/JvZmO15ZTC6JfHTUeRD7IkCLOQrU/JkiVLGDBgAAMGDOCnn37SmLZkyRJVWSsrK5YsWUJKSgqNGzfm3XffZezYsQQEBHDhwgU+++wzfHx8DLIx+ho2bBjnzp0jOjraJOsXQoi8KO08Mt9f8xFC5CEF+IDUbZnp755083v1f/MppmtHVrITXDGmLmHqzy2swLcJeFYHt7K525bea8G3sdxtn11tvjRsfcP/hA4LdS+vy0W+Gnk0xWGfn8HWFd74JHfW1zj01f9Vu2Vd3ru+8dqSXsfvcl5H41BwLAp+LXJeV0FTOsjULcjzLCwUfNG1Gm2qFicxOZV3l//B1zsukJScYuqm6S9bae3MIBiTV1VqD0P2wzs7dCiczd8EPVdnbzlz0v0H7dPNIRAo8oVsBZH27dtHREQEERERHDt2DID9+/erpu3bt0+tfHBwMPv27aNBgwasXbuWb7/9lmLFirFmzRrGjh2b860QQghhVAotJx4F+JKtECJPyqdHrcqdYHwMFPYxXRssbV/932iM+d7NGTTUMPUYIp1NQC+o2lV9mkIB/SLh3ejc/8FvZQv9N8KYv7Mu615e/bkx7ioxds8og0i33XUGG7Zqdz+wL2zYOgMHGbY+g8tgP/JrDh/E5F4aKfvCMPkhTHkEXZZAQO+slyla0ejNonovzWn63kDg7AljL8AbugT7TfidaY53qmV3fD65eKvGytKCuT1r0D3Qi9RUmLvzIt2/P8i1B89M3TT9vDEV3tmp50JmuF9nyMz2W4UCilcxbo9e73pQ8S3j1W9sfTcovyfz8jYIs2eVnYWWLVvGsmXL9FqmTp06bN26NTurE0IIYUbM7JRSCCGEhZkFbRQWkPrancXmcFHQUBfzDBEky6gOU19w1GVfys72m3q7RN6W28e49OvTZd811fEtO59Fc/u+EAWOtaUFn3cNIKiMG5M2nOHYv49oPXcv49+sQK86pbG2NMU+mo3vKPleyx59Xrfcfo3z8ntqrjdwiXxF9jJhMjExMSgUCgYMGGDqpmTJx8fHZGkXhTA3ab2SzOF6oBBC6E7PH4Z57SCX19orhBBCaJOXL+QKnXWu6cWWkY0I8HLlaUISkyPP8uHPp03dLJHvyfFFiOySIJLItpCQEBQKBW5ubiQkJJi6OWbl2bNnfPXVV/Tu3ZuKFStiYWGBQqEgJibG1E3TcOHCBbp37467uzv29vYEBATw7bffkqrnxajr16/z3nvvUbp0aWxsbChRogQDBw7k2rVrWsunpKQwf/58atasiYODAy4uLjRu3Jhff/01w3UcPnyYDh064O7ujq2tLeXKlWPy5Mk8f/5ca/mHDx8SGhqKn58ftra2FC1alK5du3L27Fm9tk2oSzvtSs1TXfKFEHlCVimB3CsYdn2eAYatL7c5ehi3fp9GupUr5q/+vN6w7K2vbPPsLZdTdoX0Kx+gJa1UtZ761VGls37lDcU2XforY6aFyY6a/Q1bX/qxuvRhNhewTX2epef6rcxsf8otFtlMc2ZOyjTN/XVWaJuz5St3yrpM2WY5W4ehNBhl6haYLR93R9a9X58PW1dEoYB1f17n/zafIyEp2dRN01Sjr/pzfQ/RpesZrCk6eX2MQyu73F2/seX0HFjbmHe6pC81hnrDc16Hzucu5nKOI/IiCSKJbHny5Ak//vgjCoWCBw8esGHDBlM3yazExsYSGhrK6tWrefHiBYULGzinuYGcO3eOOnXqEBkZSevWrRk5ciTJyckMHTqUkSNH6lzP5cuXqVWrFosWLaJSpUqMGjWKOnXqEBERQWBgIJcvX1Yrn5qaSvfu3RkxYgSPHz9m0KBB9OzZk/Pnz9OhQwfmz5+vsY6ff/6Zhg0bsm3bNlq1asXw4cNxc3Nj+vTptGjRQiOQef/+ferWrctXX32Fh4cHw4cPp0WLFmzcuJE6depw+PDh7L1oBVT6oKLZXFsRQuQ/ju4w6iSMu5zBfDflfEMJ2ab5I1sXhvixl152exG9/VPWZUafVY5zM/GW/vX3WQ+91mZdzrGo+vtW5z391wXQa032lsuJss3Axkm/Zcq3gvcPqk+r2Eb35YceAr83lP/r8t7ru3+k/6K2cVafZ2ULo88p9wmLDDKbh16EopX0W2dOpLX39TGicmLUKXAoYrj6dBF6CToszN11qslh0KnFNGg/L3vLuleA0PM5W39e02MljD2fP9LDVe+Ts+UHbM66jFcdGJru91fj0Jyt07ex8liaGV2+v8b8pXwYU33df1cXRNaWFrzXpCzvNi4DwOK9V2j25W5+PHqNpOSULJbORVW65Gz53B670rm4+nNrBxiyT3vZUadem2DkmxjG/ZPuSTYvLtg4KM89vRtmb/mO32lOq9gGyr+ZvfrS0/eCSYtpOV9nZrI6jyxSxrjrF/lGPjjjyb4FCxbg7+9P06ZNTd2UPGft2rXEx8czevRoLCwsCAsLM3WTzIq7uzvbt2/n/v37xMTEULt2bVM3Sav333+fuLg4NmzYwPLly5k1axbHjh2jUaNGzJ8/n4MHD2ZdCTBq1ChiY2OZO3cu27Zt44svvmDDhg2sXbuW2NhYhg1TvyN5/fr1rF+/ngYNGnD69GnmzZvHokWLOHv2LN7e3oSGhqr12nr+/DlDhgxBoVCwf/9+Vq5cyVdffcXBgwcZNmwY+/fv5+uvv1Zbx5QpU7h48SJjxozhwIEDfPXVV6xatYro6GgSEhIICQkhJcWMTkrzEAWSzk4IYUSFfZTBpMzmG4q1PXj4Z13udZm1LyuGDMRb2WZdxtULXDyVP7azU3/xKjoUVKi/b9m928DKJnvL5YRLSd3KVWyn/Jt25+vrva/04ZGLARobR81priWV+0RGnDyU+02OmHh8icLehqtLV05FoVDp3F+vrrJ6fd3KZRxYzIp7ObBzzbqcvrKzS2S0nYa+C8rVS/MibV6V09fGpyFY6nD89kjX29gQ43dkdSzV5TvFpYTyYUxyB55OxrWswKS2lSjmYsuNR88Zv/4ULefs4deTN0lOMYMfnnn9fVQooHhV7fPsM7vpOZvbrbDMeJ6jW/bqfF1hHyj3RvaWtUz3fZf+vc3Nc7Q0Fpm8VoaQfvu07cdu5Yy7fpFvFOgg0rBhwzh37hzR0dGmbkqeExYWhpWVFePHjyc4OJidO3dy9epVrWWTk5OZNWsWfn5+2NnZ4efnx4wZMzK8gB8VFUVISAgVKlTAyckJJycnAgMDWbRokdbyCoWCpk2bcuPGDXr37o27uzvOzs60bduWf/5R3uHw119/0bFjR4oUKYKzszNdu3blzp07em/3o0ePeO+99yhevDh2dnbUqFGD1atXa5RzcnKiRYsWFCmSy3dA6uHChQvs2bOH4OBgWrdurZpuY2PD9OnTAVi8eHGW9bx48YJt27ZRrFgxRowYoTavW7duVK9enW3btqneC4DIyEgAJk6ciL39q9QX7u7ujB49moSEBMLDw1XTDxw4wN27d+nYsSO1atVSTVcoFHz66acAfPfdd2q9ZSIjI7GwsOCTTz5Ra1O9evVo164d586dY/fu3Vlun9Div/MOMziVF0IIpepvK/96GeCmDa13NGb3woGOy+l0YSKPX7wwV7peFCpdV3kX/chjhl1/bgcdsnMRzK+FEdph5Asm5irTC3U5EBhinHr1JmeHamoN0H+ZQjoEQNNSjdbsp3/92hSvZph68jND9E4QmbKytOCdRmXYPS6YD1tXxNXemn/uxjNy9XHenLOHTadukmIOwSSDyua5XWa9/9J6ElcxYO9eU/HTIc1x6frqzyv81zPcWN+3ZiGT/aaitlSh8hvCUIKDg/H392fBggWmbopJFOggksiec+fOcejQIVq2bEmxYsXo168fKSkpahf903v33XeZMGECKSkpDBs2jFatWjF79mxGjdKeG3jWrFns2bOH2rVrM3z4cPr06cO9e/d47733GDt2rNZlHj58SMOGDbly5Qr9+/enadOmbNmyhRYtWnDmzBnq16/P06dPCQkJITAwkPXr19Orl5Z89pl4+fIlb7zxBrt376Zv376EhIRw7do1evfuzbx52Uz9YEJpwdOWLVtqzGvYsCGOjo46BVnu379PUlIS3t7eKLRcmPD19QWUwcE0t2/fVpunrfyuXbt0Kl+oUCEKFy7M1atX1QJVt2/fxt3dHScnzRQ12tYh9KfvuFlCCGE4/33feP/3w7HJeHh7PfT5OedV99YhpVpupvoyJo/Kpm5B7srpmFoeFcHWWcv0dK/jiGPwjh7nF05GHtfKELpHKD9fOfH6WBBDdevtni1edZR/S9Ux3jrg1Tgytln1wEl3fjz8D3A1cODw7fXQ8lPD1mks/SK1Tx91ynzH6yhR49X/ugZt7ArB2+ug9ef6ravvBqjWPetyvVYr3/dmH2de7vXXOy1lkcNrvWmLltdctmRg5nV7GfHzpUsgLbdV6aJ8f95el7N6fJsYpDn5mZ21Je81KcveD4IZ/UZ5nO2suBj7lOGrjtNyzh5+OX7dvNLcmUKxTHqJh/ymTEesz/iU+t5gUljzuozOXEvpXraCDqmCS9eFhmNePa/aTflZHf6H3k3T8G50zuswhqKZnM/qO0an0EtUVBTnzp3TyLZUUGSzv7pILzU1ldTnz03djCwp7O21XuTXV1rqur59lQMLdu7cmaFDhxIeHs7kyZOxSJcXOjo6mqVLlxIQEMD+/ftxdFSm1Jg4cSLVq1fXWv+3336rESxISkqiTZs2zJ07l1GjRlG6tPqPr1OnTjF69Ghmz56tmjZ06FC+/fZbGjVqxNSpU1VBq9TUVN566y22bNnCsWPHqFmzpk7bfevWLcqVK8eBAwewsbFRbUeNGjUYN24cnTt3pmRJHVOi6CE6Olqv3nI+Pj4MGDAgy3IXL14EoFw5za6rlpaW+Pr6cu7cOZKSkrCyyvhQUbhwYSwtLbl69Sqpqaka+9iVK1cAZc+nNO7u7qp5lSpV0qv86+Li4nj48KFqmbJly6qWiY2N5enTpxqBJG3rELqT+1iEECY38hhc3PHqgp6ltQ7pLHQIfI84pj1I8LryraDdNxmnBckLgoZCg//B7dOwMod5/vXVagbYOsGvI7Iua0jlMupRk8ObIt6Y+up/t7KZl339jlmjM8ANHzaO2U8XA1C5MzQcDdEzXk1zMFAqm9e1mvHqIny9EXBgHjx/aJx1tZimTAGjz3hYju5QpjEcX6F9ftdw/V6b8q1z9t7ktrTAW3qlgnRMP2iEm5fcK0CjsfDLuxmXqTVQ+denIbj5wbEfsq7X2iGT400mygbD1QNZl7N11u19f/31Lh2k/OviCd0i4Mz6jMcmqjNY+d26JYP5Qe8rU8LqGhTR5zpE8apQuaMyZWB46yyLZ6p0PfjXAEFrhUL5/uSUtX3WZQQALnbWjHqjHAMa+LB03xWW7r/CpdinjF57krm/X2RAfR/aVy9JEUcdU+G2m6scA3H9IOM2PCP6Xour0kX5Gc3I8D9gZVfN8STtC70ad9EY3pwJ/h1h24dw9hf9l9cpTXIaHV8z7/qw779rgYb6rIJx0rQaW34Yp0+YLQkiGUDq8+ecr1kr64ImVuHYnygcspEPP53ExESWL1+Oi4sLHTt2BJSp2zp16sSKFSv4/fff1Xq2/PCD8iR78uTJqgASQMmSJRk1ahQff6x595S23iZWVlYMGTKEHTt2EBUVRf/+/dXmOzk5qdKapenVqxfffvstbm5ujBz5ajBLhUJBz5492bJlCydPntQ5iATw2WefqQJIAF5eXqrtWLNmTYY9pXIiOjpaIyVbZpo0aaJTECkuLg4AV1ftX4wuLi6kpKTw5MkTChfOuCuwg4MDjRs3JioqioULF6pF5H/++WdOnDgBKFMBpmndujVr1qxh5syZNGvWDDs75Z2H9+/fZ86cORrlGzRogIuLCxs2bOD48ePUqPHqjsDJkyer/n99HeHh4XzyySd88cUXqumHDx9m06ZNGuWF7vJ6OmghhJmwdoDEZ9lbtkgZqPueYdsDrwUAMrlgqVBArf4Zz88t2e0R6lQM3vzvgr5zMf2We6p/OmAN9YbC+a05r0dfxvoC02e8KQfzTXWcfVnsh3Xf0zJ+VxbvRXbfq3pDX/1vZQN13oXds7JXV3ql68O/r13ct3GEoCE5rzu9Kp31K++VRW8RfWi85q89dyoOT28bbn3mIqBH5kEkSytlQEUfeeFkuXJH5SMjltbK7d41HV7EZTw/u7J6jXTpkaWLMsGGCSIJk3G1t2Z0i/IMauTL8oNXCdt3hZj7z5i68Rz/t+UvmlcsRo/apWhcviiWFpnsV2npJbUFkczxM/t6j8HXuZeDUSdzpy3pBb2v/JvlTVcGCPyb4/uijYc/xJ4zdSuEMDoJUQq9REZGcvfuXbp166a68A/Qr5/yTuC0XkppTp5Ufqk1atRIoy5t0wCePHnClClTCAgIwMnJCYVCgUKhoEsX5V2yN2/e1FimXLlyOLwWIPP0VA4YXK1aNY3eMWnztNWVESsrK+rVq6cxPW07jh8/rnNd+pg6daqyt5uOD1OM8fX111/j5OTE8OHDefPNNxk/fjydO3emW7duVKumzLGdvoda7969CQ4OZu/evVStWpURI0YwZMgQKleujIuLi0Z5JycnZs+eTWJiIvXq1aNPnz6EhoZSv359vvvuOypWrKixzLRp0/D09OTLL7+kYcOGhIaG8vbbb9O4cWP8/f01ygvdKf67qCDZ7IQQOSIHEQMz0usp75NppH/d88pFFIGMCZTHmOVnS/ah/MVIAfMCyMXOmmHBfuwZH8xHbSrh5+FEYnIqv529zcBlR2n8eRRfbT/P1fvxpm6q8cl+I4QwAemJZAAKe3sqHPvT1M3IksI+512n04JEaUGjNM2bN6dkyZJERkby4MEDihRR3mUZFxeHhYWFKh1ZesWKad75+vLlS5o2bcqxY8eoUaMGffv2xc3NDSsrK2JiYoiIiCAhIUFjubTAQ3ppKdgym5eYmJjVJqu4u7trDTqkbUdaz568Iq0HUkbtfvz4MQqFAmfnrNP6BAQEcPToUaZMmUJUVBRRUVH4+fnx/fff8+jRI8aNG4eHx6uc/1ZWVmzdupWZM2eyatUqFi1ahKurK506dSI0NJTy5curlQcYNGgQJUqU4PPPPycyMpLk5GRq167Nzp07mTVrFn///bfaMl5eXqo2bd26lSNHjlCqVCmmTZuGj48PPXv21FiH0I2cswoh8hcdL9iVbQaXdyn/mkq5VnBxW/aXL+wLD6/olmPe4Mzkwqh3Q7i6z7B1uuQgnbFXHTi6JPMyRSvqV6fOgb9svifu5eDu39lb1pi07dee1XO9GdmW03FgSgfBiZUZz3fSo9ehNkUr5O2eSEYdi8ZMjm9CC3lv8hsnWysGNy7D4MZlOHHtET/+cY1fT9zkxqPnzNt1iXm7LhFUpgg9a5emozEbkld+E7uXh3sXoGJbU7dEO3O+uGCfcUYeFZ9G8Oy++rSSumc7ypZyLZXnYY4eEB+b8/rM+T0QZkWCSAagUChynCYuL7h27Rrbt28HlCnTMrJixQpV+jhXV1dSUlK4d+8eRYsWVSt3545mOpTIyEiOHTvGoEGDWLJE/Qf1mjVriIiIyOlmZNu9e/dISUnRCCSlbUdGaeFyylhjIqWNhZQ2NlJ6ycnJXLlyBV9f30zHQ0qvYsWKrF27VmN6WlsCA9XTbdja2jJlyhSmTJmiNj1tW18vD8oUda1ba+bF7tu3LxYWFhqpCUuWLKmxH4Gyd1dG6xBZSzvHSJUfZUKIgqRrOJyLBP/2hq9b14v+nRfB2Z9h0+jsrSdkmzIIVcXIYyCZU+8lx6IQf/fV8x7L4a+NsHFkxsvoatAOePYAimgZYHr0OfjaP+s6qnYDUjMfxD6rAe6H/wHzDX1Ok8l7WDwAqnaHQmmDY5vJxYdO32lOq9AaOi+Bn99RPjfXCyXt5ymDxOnp+zmq/nbG44zV6AMuJbLXtjQla0LtQfBjv6zL5iYHN80LeNrUN8BnPqd8GkHMXlO3wvyZ6+dUmJ3qpQpRvVQhPmpTic2nb7Huj+sciXnAoX+Uj452WdehxlBjaBmEgc6lBmyGC7/pd+5njPO41p8r00uau/Tbbl8YuoRlPI5Wu2+gUjtYli5A12mRcixIYwr+SBkc9GsOsytlXV6u2wgDKdBBpAULFrBgwQJevnxp6qbkCcuWLSMlJYWGDRtSoUIFjflJSUlEREQQFhamCiIFBARw7Ngx9u7dS+fO6gfSvXs1T6AvX74MQIcOHTTmaSufm5KSkjh48CANGjRQm57WrvTj9BiSscZESgsEbt++nQkTJqjN27dvH/Hx8ZkGC3Xx5MkTNm7ciJubGy1a6Da47MqVyjsoe/bsqVP5/fv3ExMTQ5s2bXQK5CUnJ7NmzRqsrKxUKRKFfiSdnRCiwFEolAMV6zsOkqEvhNkXgpr9XwWR0tevyzHZuRjU1OECsL7tNucLftV6wMH5r547FFG+j7oEkazsIOlFxvNL1cl4nmsGvZNef60sLCBAt3OeDLmXy3wdxpBVMDXLNhjhJELbANgKBVTr9iqIpBMT7M+6fC6zYmGZ8bwOC3SrQ+3kTst75K/5Gy3HMt1XdHgv7ItkHUSysFaOkWVI2fmcZTmGiMj/zPj7Mg9ztLWie2ApugeW4sq9eNYevcbKQ1fVyizbf4UutbzI9FNYrIruQaS88lvYySNn3zGGOqcwxlimBpPJNmaWgUDb74KAHjlvTlas7aBm39cm6vpdqqWcXNgROirQA4IMGzaMc+fOmWQMmbwmNTWV8PBwFAoFERERLFmyROOxbNky6tWrx6lTp/jjjz8AZQ8RUI5PEx//KjftjRs3mDt3rsZ6vL2VaRz27VNPM7J7924WL15srM3T2cSJE9WCjtevX2fu3LnY2trqHPTQl7HGRKpQoQKNGzcmKiqKrVtfDW798uVLPv74YwDeeUf9B/e9e/f4+++/uXfvntr058+fk5SUpDYtISGBQYMG8eDBAyZPnqw2hhYo0+W9bt26dSxdupTatWtrBB21lb958ybvvPMOVlZWTJ8+XW1eYmIiz58/V5uWkpJCaGgo58+fZ8SIEZQokcM7MoUQQgi9ZfPHuDkHavI0eV2NL6vXWN6DbDO344JciMqCmb1fhiTvvTATvu6OTGhdkaOT3lCbPnXjOQI//d1ErcqAuR3DRf4k+5kwkALdE0nobteuXVy5coUmTZpQpkyZDMsNHDiQgwcPEhYWRmBgIMHBwQwcOJDw8HCqVq1Kp06dSEhIYO3atQQFBbFp0ya15du1a4ePjw+ff/45Z86coUqVKpw/f55NmzbRqVMn1q1bZ+xNzZCnpyfx8fFUq1aNdu3aER8fz48//sj9+/f55ptvKFlS/Y7T0NBQVbDl9OnTqmlOTk6AMkDTsGHD3N2I1yxcuJAGDRrQsWNHevTogaenJ5s3b+bs2bMMHz6c+vXrq5WfP38+n3zyCVOmTFGlhAP4888/6dy5My1atKBUqVI8fvyYzZs38++//zJ48GBGjNBMrVG3bl1KlSpFpUqVsLOz48iRI0RHR1OmTBl++uknLC3V76b85ptvWLFiBQ0bNsTDw4Nr164RGRnJs2fPCAsL00hld+fOHSpXrkzLli3x9fXl5cuXbNu2jb///pu2bdsyY8YMw72QBYwqnZ38VhRCCMOQH3fmSb7ohBBCiDzJzlr9ekLZoo5cvhuv9SpoYnIq1rnULiGEyKskiCR0EhYWBpBlmrQePXowatQoVq9ezezZs7G3t2fx4sWUL1+exYsXM3/+fLy8vBgzZgzdu3fXCCI5OTmxa9cuxo0bx549e4iOjqZy5cqsXLmSYsWKmTSIZGNjw44dO5gwYQLLly/n0aNHVKxYkXnz5tGrVy+N8uvWrePqVfUu1OvXr1f937RpU5MHkSpXrszhw4eZNGkSmzdvJj4+nvLly7NgwQLef/99nespXbo0TZs2Ze/evdy5cwcHBwdq1qzJ7NmzM0wZ16NHD37++WcOHTpEYmIivr6+TJo0iXHjxuHi4qJRvn79+uzevZuNGzfy8OFD3NzcaNOmDR988IHWVIKurq506NCB/fv3s2nTJqytralSpQqLFy8mJCREY2wroTvFfxc75dKaEEIUMMYMqriWyrqMseTnYFHFtvDH0lfPK7TJeZ2GiHlmlnrN1Iw9IHZeVSKvvC4SlDeoSu3g+ApTt8KwCpWGR/+auhUil/0+pglnbz4GLQluWoedJ7CKBSNePCeDZLT5m2U20n2WCoJjPxi+LQVV8apw+7SpWyFEpiSIJHSyatUqVq1alWU5FxcXnj17pjbN0tKSCRMmaIy7A8o0ea/z9fXNMFikrby2aQA+Pj4ZzmvatGmG87SJiYlR/f/999/z/fff67WMOatQoQI//fSTTmWnTp2q1gMpTenSpfnxxx/1Wm9GdWWkWbNmNGuWST7a1zg7O/PDD3JSYyj5+PKaEMJUcrv3jb6BAmMHFkaegG+qG3cd5qDjd+BRCaJf6wGc/vXtsAD81NPOCD2NOglzAzSnt/w/9SBSq89yr00ZcS+vHKPGXFVooxxIu3g1WFA75/XlmZ6GWbSzUrvcaUZG8nOwNydGHIPrR2HHZHh6x/D1t/5cOT7c0SUGrtiEn4sh++CXIVBnsOnaIHKdQqGgSkn1sfMW+S3k1D83uPTMiUtHr1HJ6g79c3KV1KcRxGQ0lreZfhdUaANWtvovF9BLeUOIV22Yp+Umg9w8Zpv6+8EQ6++7Ab4om/N6hDAiuRVfCCGEXtJOf/UJxAohhHhNEV/D1uddP+syplCxLZSonnmZGn3AuXiuNCdv0vJ9+/pdw4V9tC9q4wDlW6s/Vy1jqH1Qz/MBnfZVE15sUyigalcoWl7/ZXXpUVcqSP96Tc3GyXDBsHKtXptgygurxlp3DuotoZlhIVNuZSGgJ1R8K/vrzIyNIzQK1W8Z/46a0zIbnN6YPKtrTrNzhV6rTdcmNWYaWCgg3u3zNnM+CmXlO3XpVac09tYZ95J9GP8yw3kqRSsYsHWZsMxG0CeNb2P1594NslePhYXy2ONWVstxXZflrbS3R40Zfz7sCmV/WXct+4mje/bryylTrlvkKRJEEkIIkaX0p2+qMZFM0hIhhBBa1RpgpIrTHe3zTI+KfKxMsLLn1og/c16XR0XouQoG78p5XXnRe3sMX2cRX+i1BgZlMnh79beVPfQKeRt+/enVNnEvixbTofE4zemdF0G7b3K/PXlFhTbQebEy5Zo+Wn4K1g5Zl8uKtqCLiye8vQ5CtulWR/t5yvfY0ePVtIajc9627CjfCjovgQajTLN+YfasLC1o4OfOjM5V6RqofiPArMSeqv9rfbqDbt8dYOm+K9x49Dy3mwkKC2Vvlf4bwdou+/V0/8Hwx+DsHNdHHocOC6HecMO2JT1nT+PU698B3j+Q/eUHbMq6jLGkncsrFDBwK7y9XoJIQmcSRBJCCJE9EkUSQuSEfRFTt0Apox4c6WW3rW5+upXTJRd9Vj0cDDnGjL4Xt3XpRaRPnU7F/vubD3snuep6YTiDL9lCpZU9t/S9wJyRim2hZC0t68lhgCOznnYuJhxxwrHoq/89A7Tfza3I4U/kCq2hVCZp8CwsoHovHXpB6HCilf4i/etsnbVPz87xLDt311doAyUDtay/ENTq/+q5jQECH6Zg7fjqf12P9dqkvR9px3iFAqp11//OfhsH5bHhdWmpI4tW0q2ejI4t5VpAaR170dm5KN9jB7dX0/RJl+WUyX6tL4UCqnWDohUNV6e+6xd5hkW6425qsSp0f+vV5zAlFY7GPGTapnM0mLmLsT+d0KzASs/gjj77etFKUDY4i547OrAvrH4MzoougZjXj+u6KFQaarwNlkZMb2tfWHOavp9JCy35DYOGgWsOzmWcPMDb0OOjZ3LOoO3mAFD2DC/3hvGCbSLfkTGRhBBC6MVCfgwJIQyhTBNlKohi/qZtxxtTITkJAnpkXKZcS/3qHBwFRxZB88m6lXdwU17guvu35ry+G+D0Ogj+UL825ESLafqVr/PuaxO0/JBt9hG8jFdezMtK/02w9ytorGMKpa7hyjEI0o/9Y27qvg8vn0BTzTFC9aLxWusiG3d9NJ8MSQnKi9n6CNkOfy5T7kN/bVSf1/tH5bT6IyAlWf825UTXpRCzH6pl8jkPnqS80G1tnzttemMKJCdmfuzJSkAv/ZcJ+Q32zoYqXWDV65/HDPaVbssyrzPDc0Md9j1dAvmZyuW7mt6cBc/ug3u6wNFbcyDhSSZjoWiRFmBJez9e77XVbBL89Wv2xzkatEP5973dsP+bnB97skXP96b7D3A5Cmr0NU5zhADo+C1seD/LYoqeK/Et5A2PjoBnNaK9mrL59C02nrzJ37efcCn2KfwXG91h05yHlfrSyuoKrplVmraP+zaCK3t139er9czd88D0n926Q+D3Kbm4blN77bhl56L8rZCaojxu378EperkfDX6XFNJf9OCLvr8DCs6v3re9xeYkUnQKzAEfjPFd4TIawp0EGnBggUsWLCAly91yG0qhBACSDcmknRFEkLkiALe/MzUjVDepdhxQeZlLPTsmVCyJnT6Tvs8rT8aU2HYYZiq5dJD2WDlQ3tl+rVLVw569FRoPjnzO8zTttfONevXOU3R8tD5e93bUKWz8pGdIFJu3RhRNliZVimnsjP4dXbYF9L9/UqvdF3lI72017h8q1evQcLTHDVPb1W6KB+ZaaIl9Zox6XLsyYyHP1hm4+d80QrKz9f9y5mXSz/2pc493wyc/tJQn09Dfs6DhmhOcy6mTE+k7RieUTvSxiFJez9eZ19ImbpyZdfstTPtImexyvodT015s5h/B+XDKMzkJji5Gc/0PAN0K5cW4G49EwAfYFiwH8OC/Yi5F8+xg0nwX2bZwY8HwWG4aHmRjzLrWJN+H8/qOyk9fT7DhpaT1Hl5RVafS2Ok49RnfGltvaE0pNsGv+YwNU73+q1slSk/98/VfRlRIBXodHbDhg3j3LlzREdHm7opQghh1tKf4qjGRJIYkhBCmIBcgBJZyQv7iJxE5FxeeJ/NRIE5aZV9Qojc4OPuSOcar3p2jGjmh5+HU6bLJBeY45AQ+VdwcDD+/v4sWJCDm4DysALdE0kIIYQQQgjTkwt/+Z++73EB2ifkznwhhBB52NiWFRjbsgIPd/wB+7WXWXXkX7bePkQDP3eaVfSgnIcTVpaZ3NcvQSchzE5UVBReXl6mbobJSBBJCCGEnpQXe+S0VgiRt+h71JKjnAaDXtAwp9fXnNoicld+fu/z87YJIYT5KeyQSS67VDhw+T4HLt/ni23nsbe2pJZ3YWqULkSN0oWo6+uGo21+ukSbm99Bpv6+M/X6hcgd+ekIJYQQIhdIOjshhEFI5wMhREakd1LuK4ivubG2uSC+lvqS10hJXocC4w3/4lwrVIb9l+5x9uZjnicms+/SPfZduqcq41XYnoBShahSwpUuT17gkRsNM+d90JzblpvkZRBmQoJIQgghhBBC5HUS2RfCTJnp1R85ZgghRK7xdLVjYptKACQmp7Dv4j2O//uQP/99yMlrcTxNSOL6w+dcf/iczadu4WB1m/7/XbEdvfYEVUu6Usu7MNW8XFFIcEUIYQISRBJCCKGXtFPWVOm2LYQQ+V/6C8158aKFXCg3gtf2A3mNhTnJi8cpIUSBYm1pQXBFD4IrKvsapaamcu3Bc479+5Bj/z7k1PU4FLdflf/l+A1+OX5D9bx8MSdqeRfG39OFemXdKF3EERurTMZXEkIIA5AgkhBCiCyl/zku6eyEECIbFBaQmgLFq2Vd1sIaUhKzv67iVeH2aXBwz34dOWFl/+p/CzP/uVHYx9Qt0EG6b2FrB/0Xd/MzXFOyw87VtOs3hOJV4d8DubtOU3x+nYplb7kiZZV/FekuYlo75HwbdPl8FqsM9y/lbD25pUgZ/co7FDFOOzKT9l4agoc/3P3bcPXliI4/XOxc4UXcq+cKy+ytTmEJqcma3/luBnx9RfbYOmc8z0Q/cBUKBaXdHCjt5kDHGiWVEzf/AkeV//asXYozN+M4e/Mxqalw4c5TLtx5qlrexsqCisWdqVrSlSolXalY3JmKxV2wt8nm/lu0Etz4M4dblQ02zvDyieHq8/CH2HOvTdRyo0H68xTXUoZbvxD5jJn/qhNCCGEOJF4khBA59P4B+GMpNArNumyTD+DZfajcMXvr6rka9s+FoPf1X7b3T3ouoOXHuKMbvDkLLK3A2l5zvjno9ytc2AZBQ+Hw96ZujRbpvnktLKDdXEh4Cq4l9a+q6YeQkpz9/Sm7On4LT25D0QqGrTdoGBxaYNg6M9ItAu6cVb6Gtk7K7TmxMnfWXTpIud4L2yBwICQ8UQZmNv3PeOv0ewMaj1MGzVKSsy4/cCv8tRGajFc+t7SGt+ZA4jNw8VQ+sqN4NfBppHyvs9L2a3AqDkdy8DkeeRy+qQFtvsx+HZkZ+Buci1Qe2w98o/tyJWtB8CQ9gt05OGNPey8b6/Adpeu6234FjkWheu/stys7gj/SP2D39jq4sgcajYXomeBdH2L2QmCIbst7VodbJ149f3//f9/5Y5XPQ7bB2Q3QZIJ+7RI5034+/Dochh19Na1QaWj5aZ66wWFmF2Uw8kViMpfvPuWPGGVvpZPXH3Ep9ikvk1I4dT2OU9dfBUAtFFCmqJMqsFS1pCuVS7jgaKvDZeBWnyq/86p1N9YmaTd4FxxdDEcWGaa+Nl+Ag5vyGLSoieb8jt/C0zvq5ym1BsDDGCgbbJg2ZEfRSuDTUPlaCGFGJIgkTCYmJgZfX1/69+/PsmXLTN0ck/Dx8QGUr4WhNW3alN27d5Mq3UWEgSn+u2Aoe5YQQujBo5Lyx6wubByhybjsr6tQKWibjYuhvo2hfMvsrze9oCGGqcdYyjRRPrRRXXw0o2+6WgOyv6ytE7SeabCm6MxYF47f/Cz3gkiVO74KvjWfDH9vyb0gkkIBTScoH+kZM4ikUECzScr/z/ycdXnv+spHeoED1Z9X6wGn1ma2Us1Jvo2h1f9lvX5QBq3bfJ6zIFKRMjA1Luty2eVdT/lIT9e0ezn5LtCHtvcypxyKKN+b3JYW1NRHuRbKB7w6Xvq31335mv1g84lXz1//zi8dpHyI3FWzr/LxuvojjLRC45432FlbUrmEK5VLvAqApaSkcvzaI87ffsKFO8rHuVuPefQskUuxT7kU+1SVCk+hAL//Akuz/1s+MTkF69dXZF8YWs8y6rZoVbS88nNjqCBSVscgbecplta6f//kSCb7SpEyyvN4CSIJM1Ogk2YuWLAAf39/mjZtauqm5EkhISEoFArc3NxISEgwWTuaNm2KQqFAoVCwadOmDMvVrVtXVS46OjrDctOmTUOhUGBtbc3t27czLDdgwABVffPnz8+wXI8ePVTlzDlYdufOHYYPH07dunUpVqwYtra2eHl50bx5c37++WezCkbdunWLQYMG4enpiZ2dHRUqVOD//u//SEzUL/XPw4cPCQ0Nxc/PD1tbW4oWLUrXrl05e/ZshsusWrWKBg0a4OTkhKOjI7Vr1870ff3rr794++23KV68OLa2tnh7ezNq1CgePHigV1vNieo3rxntE0IIka/oenFRxv4Q2sh+UTAY/X3OR+d58pkQRpePPi8iT7GwUFDLuzC965ZmavvKrBocxPGPW3BgQjMW9a3FyOblaFbRA3cnW1JT4WLsU35ON77SrN/O8+acPXyw7hSrDv/LtQfPTLg1Itvke07kggLdE2nYsGEMGzaM69evU6qU5L3Ux5MnT/jxxx9RKBQ8ePCADRs20KNHD5O2ycrKiqVLl/LWW29pzDt79ixHjhzBysqKpKSkDOtITU0lPDwchUJBUlISERERfPDBBzqtd/jw4RrzHjx4QGRkZIbr3blzpw5bljuuXbvGDz/8QFBQEJ06daJIkSLExsayceNGunTpwjvvvMPixaa/E+L27dvUrVuX69ev06lTJ8qVK8fu3buZNGkSR44cYcOGDSh0+AK9f/8+9erV4+LFi9SrV48OHTpw69Yt1q9fz9atW9m1axd169ZVW2bs2LHMnj2b4sWL8/bbb2Ntbc2WLVsYOHAgZ86c4csv1e/6PnToEG+88QbPnz+nQ4cOlC1blhMnTvDNN9/w22+/ceDAAdzc3Az6+uSGrWeUwdV1f16nbz0f0zZGCJF3WeVymjMrO/3K2zgZpx3amMsPv6zG27FxAu4YZl0WGve9mo6tMzy59eq5jaPh12Eu77HZUGTw/39yKw2irQs8u5c768oN1o6QGK/8X99jnkU2x87ISnbeSytbzWkKI9z/aozPujHrzU+M8X7mRbl5riEMS99jpr7H5BxQKBSUKGRPiUL2tKxcXDX9dtwLTl1/pEx7d1A5LSHVkr9vP+Hv209Y+8c1ANwcbahRujBVS7pS26cwqj6K+owRZusKCTns4anr+kw9/qahP8evnzPaZDKOl6Hk4v4p8i755hbZsnbtWuLj4xk9ejQWFhaEhYWZukm0bt2aTZs2cffuXY15YWFhWFhY0KpVq0zr2LlzJzExMQwePBgXFxeWLl2q03qPHz/OyZMnNeatWLGChIQE2rRpo3XZsmXLUraseQyuGRAQwMOHD9m+fTvfffcdn332GUuWLOHSpUtUqlSJJUuWZNpDJ7d88MEHXLt2jYULF7J+/XpmzpzJgQMH6NmzJ7/++itr1qzRqZ4pU6Zw8eJFxowZw4EDB/jqq69YtWoV0dHRJCQkEBISQkpKiqr8H3/8wezZs/Hz8+Ps2bN8//33zJ8/n9OnT1O7dm2++uorDh48qLaOwYMHEx8fzy+//MLPP//MF198wY4dO/j888+5cOECH330kUFfm9x28roR034IIfKv1l8ox7tokvlNGgZXa6ByfAldtZ8PHpWhi+nPcYyu47dQrAq0ziLtUOfFytekW0TO11m22X9jngzVf9lqPaFMUyhRA96YmvO2dFumHHy5anfDvudOxbMuU1DZOinTnFVqB65er6a3nw/FqkKrGZkv33wKlKiZ83a8/ZPyvX97fc7rMgf9f1WOpdD3F+WYaCVqwBufZL2cV23wbWqcNgV/pBxjSZfxhlp+qhxbRluqqcI+UPG1mwWbTsxem7qGKz/rHb/N3vIZ6b5cuT91W2bYevOTGn2h/JvK10kXfX9R7tM9Vyn3o7Zfqc/vFqF8L3NTjf/So1X9b7yY0vUyLpuRXmuUr0HPXEqPKQyveAD4tVCeX2am1QzwDHg1PpYJFXe1o2Xl4oS2qgBNJ5JaMpD3Rn7Mgt41GVDfh9JFlDcT3Y9/ye9/3eHr3y/Qe8lher38iH8UpZjt9TWL9lzmz6sPtGeqGbBF+XkdsEX52XUvr5xeT/OG60wFhijH6CtRQ7fyvk2UaVCzc05pCB2/VR6HuoZrzqvZT/m3ZKDy/Kb9PN3r7bVWeZzotcow7QS03rgDUG+Y7ucMosAq0D2RRPaFhYVhZWXF+PHjOXnyJDt37uTq1at4e3trlE1OTubLL79k8eLFXL9+HS8vLwYNGpRpz6XY2FhmzJjBxo0buXbtGs7OzjRp0oRPPvmEKlWqaF0mJCSEjRs3snz5csaMGaOanpiYyIoVK2jZsiVeXl5al02/XQDvvvsuycnJhIWFsXfvXho1apThMv3792fLli2EhYXxzTfqA6WGh4dTqVIl6tWrx6+//qqx7OtjIt2/f5+AgADi4uI4fvw4fn5+qrKZzcvMixcvmDp1KqtWrSI2NhZfX1+GDh3K8OHD1XrsWFtrvyPY2dmZN998k7/++otLly5RuXIun6Sn8+TJE9auXUuZMmV47733VNMVCgUzZ85kzZo1LF68mF69emVZV2RkJBYWFnzyifqXZL169WjXrh0bNmxg9+7dBAcHq8oDjB49miJFiqjKOzo68tFHH9GxY0e+++476tVT/oi4fPkyZ86coXbt2rRvr55Pe+zYscyaNYvly5fz1Vdf4egodysKIQqQuu8qH7nN1kk5YO9vE3UbT8XdD4YeMH67zEH13rqNX1O0wqvX5KccrtPSCgZknIY4U51zMPaJNsUqw9CDWZfTV6k68Jfm+Z/4T2ctYx5kNH7F6xqNUT6m5nBg9JI19XzvzTxllVcgDDv06vm70bot1+9X5WfSGJw8YMg+3crWH5HxWCUKhfKCe9p73vLT7I9rUqWz8mFo/u31G0enIOqQcRp4rco2e7VPV2yrOT9t3LKcHgv00WH+q+3o8l+mjqNL9KujQmvlQ+RdFhbQZ13W5eoNVT70ZuTvm6YfoGj6AV6AVzFoW82Tqe0rc/dJAof+uc+xfx9y9sZjTlx7xMHkyjR7PgsuAhf/VlXh6+5ILe/CVC7hQm2fIlT2ro8i/XfQ8KPZa9tbX+tX3tIK+m/M3roMoWj5jH8ztJ+nX+AovQpvKh+5wc5V93MGUWBJEMkAUlNTeZ6YbOpmZMne2lKnNF9ZOXfuHIcOHaJNmzYUK1aMfv36sXPnTsLDw5k6dapG+XfffZelS5fi6+vLsGHDePHiBbNnz+bAAe0H2cuXL9O0aVOuX79Oy5Yt6dixI7Gxsaxfv55t27axc+dOjTRjAEFBQfj7+xMeHq4WRNq4cSN3794lJCQk0/RxDx484JdffsHf359atWrRr18/wsLCCAsLyzSIVLJkSVq2bMmqVav48ssvsbGxAeDYsWOcOHGCzz//nORk3fYPNzc3fvjhB1q0aEHv3r3Zv3+/KrgzaNAgbty4wbJly3QOIAF0796d48eP06VLFwDWr1/PyJEjiYmJ4auvvspiaWUQateuXSgUCpMGkAAOHjxIQkICLVq00NiXvb29qVChAvv37yc5ORlLy8y7Pt++fRt3d3ecnDS7Hvv6+gKwa9cuVRApbXystHkZlU9ff0blLSwsKF26NMePH+fQoUM0b94807aaAxn+SAghhBBCCCGEyL+KOtvSLqAE7QJKAJCYnMLNR885fOUBx/99yJkbj/n79mMSk1O5ci+eK/fiWffnq+XLF3Mi0KcIVUq4Use3MH4euZCKTQiRKySIZADPE5Pxn7zN1M3I0rlprXCwyflbntZbp29f5V2CnTt3ZujQoYSHhzN58mQsLF5lSYyOjmbp0qUEBASwf/9+VY+LiRMnUr16da319+vXj1u3bvHbb7+ppZ+bNGkSgYGBDB48mFOnTmldNiQkhNDQUI4ePUrt2rVV7XVzc6NDhw6ZBpFWrlxJQkKCarsaNWqEj48PP/30E9988w0uLi4ZLjto0CC2bt1KZGQk3bp1U63XysqKfv36ER6upVtrBpo1a8b48eOZOXMmkyZNYtasWSxcuJDIyEh69epF//79da4L4MKFC5w5cwZXV+XdWZ988gl169bl66+/plevXgQGBqqVj42NZeHChaSkpBAbG8uWLVu4du0aU6ZM0Tl4deLECTZs2KBzGwsVKsT//ve/LMtdvHgRgHLlymmdX65cOc6fP8/Vq1cpU6ZMpnW5u7sTGxvL06dPNQJJV65cAZSvXfry6edpK3/9+nWePXuGg4NDpuVTUlL4999/VevIC0EkIYQQBZXcRSCyIvtIpmQ8LCGEEHmUtaUF3m6OeLs50j1QOZZ8QlIy1x484/CVB/wZ85Bztx5zMfYpySmpXLjzlAt3nqqWt1BA5RKu1CxdiKpehajjU4TSblmMwSmEMEsSRBJ6SUxMZPny5bi4uNCxY0cAnJyc6NSpEytWrOD333+nZcuWqvI//PADAJMnT1ZL2VWyZElGjRrFxx9/rFb/8ePHOXDgACEhIRrjF5UvX57Bgwcze/Zszpw5ozWtXd++ffnwww9ZunQptWvX5ubNm2zbto3hw4ereghlJG3cpD59+gDKFGl9+vTh008/Zc2aNbz7bsapd9q3b4+7uztLly6lW7duvHjxgtWrV9O2bVuKFSuW6Xq1mTZtGjt37uTLL7/Ey8uLDz74AB8fH7777ju96/r4449VASQAV1dXJk2aRN++fYmIiNAaREqf4s3a2povvviCsWN1z+F74sQJjTRxmfH29tYpiBQXpxyDJ/32pJcW6Esrl5nWrVsTHh7OJ598whdffKGafvjwYTZtUqbXefTokVr5mTNnMmfOHHr37k2hQoUAePbsGTNmvMrbHxcXh4ODA+XLl6dMmTIcPXqUzZs307btq/QLc+bM4f79+xrrMGdy/UMIIYR8GehLgivCnMn+KYQQ6uS4qAtbK0v8PJzx83Dm7brKIS2ev0zmn3tPOfSPssfS2ZuPibkfT0oqnL4Rx+kbccBVAOysLShdxIHG5YpSvrgztbwLU7aoZoYYIYR5kSCSAdhbW3JuWqusC5qYvXXm6b10ERkZyd27dxk0aBB2dnaq6f369WPFihWEhYWpBZFOnjwJoDUdnLZphw4p86feuXNHa2q8v//+W/VXWxDJw8ODtm3bsmbNGr7++msiIiJITk4mJCQk0+36448/OHnyJM2bN1cbN6lfv358+umnhIWFZRpEsra2pk+fPnzzzTfcuHGDPXv28PDhwyzXm1l9q1evpnr16owcORJLS0tWrlyZaW+ojGT22h8/flxjXpUqVUhNTSU5OZlr166xevVqPvroIw4cOMCPP/6IlVXWh40BAwYwYMAAvduam6ZNm8Zvv/3Gl19+ycGDBwkKCuLWrVusW7cOf39/Tp06pdarrnHjxvTt25fly5fj7+9P+/btsba2ZsuWLSQlJeHq6kpcXJxqGYVCwcKFC2nXrh3t27enY8eOlC1blpMnT7J9+3aqVq3K6dOn1dYhhBAiN8gFgoJF3m+Dk2Bi/pOv3tP8tC3CYCQ3txBGZW9jSeUSrlQu4QooU/o/eZHIhTtPORrzKhXejUfPeZGYotFjyd7aEl93R5pWKEqF/wJLJQvZG2RIDjX56vtOiNwlQSQDUCgUBkkTlxekpbLr16+f2vTmzZtTsmRJIiMjefDgAUWKFAFQXVRPS+2VnrYeOg8ePABg8+bNbN68OcN2xMfHZzgvJCSEDRs2sH79esLDw6lVqxbVqlXL1naVK1eOoKAgDh06xNmzZzMdEygkJIQ5c+awbNkyoqOjKV68OG3atMl0vZkpU6aMKg1grVq1qF+/frbq0fY6p03LrMeOpaUlPj4+fPjhh1hZWTF+/HgWL17M+++/n612GEJaD6SM2v348WO1cpnx8vLi6NGjTJkyha1bt3LkyBFKlSrFtGnT8PHxoWfPnnh4eKgts2zZMgIDAwkLC2PZsmXY29vTqlUrPv/8cypXroyVlZVq3wdo1aoVe/fuZfr06ezatYvNmzdTpUoVfvnlF3bu3Mnp06c11iGEEEJkj1wgy1fkgqcQQl9y3BBCmBFnO2tqeRemlndh1bQH8S/58+pD/rr1mD0X7nL57lMePkvkeWIy52495tytx6qyDjaWVC7hQlAZNyoWd6F+WTcKOVgbPrAkhNBJwYh8CIO4du0a27dvB6BJkyYZlluxYgUjR44ElBfzU1JSuHfvHkWLFlUrd+fOHY1l03razJs3j+HDh2ernW3atMHT05MPPviAGzdusHDhwkzLP3/+nNWrVwPQv3//DMccCgsLY/bs2RnWU7VqVWrXrs2CBQu4c+cOoaGhOvXaycjs2bPZv38/bm5uHDlyhIULFzJ06FC967lz5w6lS5fWmAa6BVsAWrZsyfjx44mOjtYpiGSsMZHSxkJKGxvpdRcvXsTGxkZjezNSsmRJlixZojE9rRfc66n+LCwsGDlypGr/ThMTE8PTp0+pWbMm1tbWavPq1q2rSo+X3pw5c7SuQwghhBBCCJPLVwEJc7zgaI5tEtmWrz4vIttkP8hSEUcbWvgXo4V/MUY2V17fuRX3nGNXH3H+9mOizt/l2sNnPHqWyLOXyRyNecjRmIeq5e2tLWlYzh0/Dydq+xSmgZ87tlY5z7qU6wy+r8h3ijA+CSIJnS1btoyUlBQaNmxIhQoVNOYnJSURERFBWFiY6iJ7QEAAx44dY+/evXTu3Fmt/N69ezXqqFu3LgAHDx7MdhDJ0tKSfv36MWvWLOzs7OjVq1em5detW0dcXBzVq1enVq1aWsusXLmS5cuXM3PmzEzHVgoJCVEFWbKbyg6UaeYmTpxIhQoViIqKokGDBoSGhtKkSZNMe0Nps3fvXt5++22NaQA1atTQqY6bN28CaARIMmKsMZGCgoKwsbFhx44dpKamqt2BcvXqVc6fP09wcHCOgnfJycmsWbMGKysrunTpotMyK1euBKBnz546lb969Sr79u3D39+fqlWrZrutQggh8ovXf/jp+kNQfjDmHbn5Xsl+UTAU9Pe5oG+/EELkD56u9rStZk/bap6Maam81nj1fjwnrj3i2NWH/HH1ITH34ol/mczzxGR2nLvDjnN3+Pa/5T2cbantU4TyxZypW6YIAV6FsLfJg4ElIcycBJGETlJTUwkPD0ehUBAREUGZMmW0lrtw4QIHDx7kjz/+IDAwkL59+xIeHs60adNo1aoVjo6OANy4cYO5c+dqLF+nTh3q1q3L6tWrad++PT169FCbn5KSwt69ezPtCQUwZswYgoKCKFKkCIUKFcq0bFoqu9mzZxMcHKy1zLNnz1i9ejW//vorXbt2zbCuPn36ULx4cezs7LQG2nQRHx+vCnytXr0aT09PVq1aRaNGjejVqxdHjhxRG48qK9OnT+ett95SSwX36aefolAo1HpdnTx5En9/f41A0YMHD5g4cSKAzun5jDUmkouLCz179uSHH37g+++/Z8iQIYBy//zwww8BGDx4sNoycXFx3Lp1C1dXVzw9PVXTExMTSUpKwt7eXjUtJSWF0NBQzp8/z+jRoylRooRaXY8fP9YYl2rv3r3MmDEDb29vVXvSPH36FEdHR7VgV1xcHH379iU5OZkZM2bk4NXIXXJTlRBCCK0kpYgQQgghhDAgbzdHvN0c6VC9pGpaWmBp19+xXL3/jBPXHgEQ+ySBzadvsfn0LVVZX3dHyhZ1ompJV6p5uaL9Sp8QQh8SRBI62bVrF1euXKFJkyYZBpAABg4cyMGDBwkLCyMwMJDg4GAGDhxIeHg4VatWpVOnTiQkJLB27VqCgoK0pvlavXo1wcHB9OzZkzlz5lCzZk3s7e35999/OXjwIHfv3uXFixeZttfDw4OOHTtmuV2XLl1iz549+Pj40LRp00y3a/Xq1YSFhWUaRHJyctJpvZkZNWoU58+f58svv1T1FAoKCmLKlCl8/PHHjBs3jnnz5ulcX/ny5alSpYqqV8369eu5fv06Y8aMUUul9vXXX7Np0yYaNGhA6dKlsbe35+rVq2zevJn4+Hi6deuWZa+u3DBz5kyioqIYOnQov//+O35+fuzevZtDhw7Rrl07jd5Av/zyCwMHDqR///4sW7ZMNf3OnTtUrlyZli1b4uvry8uXL9m2bRt///03bdu21Rrg6dq1K8+fP6datWq4uLhw+vRptm7dSpEiRdiwYQPOzs5q5Tds2MDEiRNp1qwZJUqUIDY2ll9//ZW7d+8yffp02rdvb5TXSAghhBBCiDxBAtFCCCF08Hpg6WVSCvfjE9h74R5nbsZx8nocZ27EkZySypV78Vy5F8/vfymHcoj57z7szadvEWN5ieAKHviXcMloVUIILQp0EGnBggUsWLCAly9fmropZi+tt05WvUt69OjBqFGjWL16NbNnz8be3p7FixdTvnx5Fi9ezPz58/Hy8mLMmDF0795daxDJ19eX48ePM3v2bDZs2EB4eDiWlpZ4enrSuHHjTIM4+lq6dCmpqan0798/08H5mjdvTqlSpdi+fTvXrl2jVKlSBmtDeuvXrycsLIwWLVowZswYtXkTJ05kx44dzJ8/n1atWvHWW2/pVOePP/7IlClTWL16NXfu3MHX15dvvvlGI11g3759SUlJ4fDhw0RFRfH8+XPc3Nxo3Lgx/fv31+gVZiqenp4cPnyYSZMmsXnzZjZu3Ii3tzfTp09n/PjxOg+y6OrqSocOHdi/fz+bNm3C2tqaKlWqsHjxYkJCQrCwsNBYpmPHjixbtoyVK1fy/PlzSpUqxYgRI/jwww8pVqyYRvmqVasSEBDA9u3buXfvHq6urgQFBTFmzJgMe70JIYQQQgghhBBCiIzZWFng6WpP99ql6I7yGl1ScgoXY59yKfYpp64/4syNx5y5Eada5ur9Z3yx7TxfbDuPQgHlPJyo6+tGRU9ngsq44VXYPm+OsSRyRXBwMNbW1gwbNoxhw4aZujm5rkAHkdLe9OvXrxstKJBfrFq1ilWrVmVZzsXFhWfPnqlNs7S0ZMKECUyYMEGjfGoGObIKFy7M9OnTmT59epbrjI6OzrJMmu+++47vvvtO9fyzzz7js88+y3I5CwsL/v33X7Vpy5YtU+vZkpmMtj8mJkbteZcuXTJ8TSwsLNi9e7dO6wP112XWrFnMmjUr0/LNmzenefPmOtdvSp6enqrAZlYySq3n7OzMDz/8oNd6hw4dytChQ3UuHxAQwMaNG/VahxBCCCGEEELkKZL7WghhBqwsLajk6UIlTxfaBSiHJ0hNTYX/huz2cXOgPE5cjH1KaipcuPOUC3eeqpa3tFBQq3Rhyno4EVSmCPXLulPU2dYUmyLMUFRUFF5eXqZuhskU6CCSEEKI7Kss3b+FECL/c/ECu0JgbQ9W9lkWz1dyelG08Tj4qwDcTBIYAvvnQkXdesnnafrsE27lIO4aFK9qvPYIIYQwnKCh8MdSqGK47DfZJoFZg1EoFMpzlL830WbAh7Qp7EPcs0SuPojnwOX7nL4Rx8lrj7j+8DnJKakciXnAkZgHrD6ivJHcydaKCsWdCfQpTI1ShahcwpWSheyxsJB0rKJgkSCSEEKILKXPEtg3yJvlh67SvJJmCj8hhBD5jKUVhF5UfhFoSfUqMuEZYOoW5I7CPjDxljLQKF4ZdhhSksBK7mAWQog8wb2cfJ/lVz1WQOJzsHEAwNXBmmoOhajmVUhVJCEpmUP/PODUtUecuhHH8X8fce9pAk8Tkvjz6kP+vPpQVdbZ1opaPoWpUMyZ2j5FCPQpTCEHm9zeKiFylQSRhBBCZEluhBJCiALMSn4Uiyz8d1FGpGNhqXyYKzm5E0IITfJ9lj8pFFm+t7ZWljQpX5Qm5YsCyjR4d58mcP72E47GPOT4vw85d/Mx9+Nf8iQhiejzd4k+f5fv9/wDgKerHR4udrSqXIwy7k40Lu+OvbWlzuN25wpzaovIcySIJIQQQgghCo48f+E0r7c/l+XZ99uc2y0XIPKffPSeFugLZOZ83DA1eW1EbpD9LD9RKBR4ONvh4WxHo3JFVdPjniey/9I9Tlx7xJn/UuHFv0zmVtwLbsW94OS1R6qytlYWtKpcnKLOtrxVzZMShewp5mKH7CsiL5IgkhBCCCGEEOZE14ugBfpiqciQ7BYFg7zPujPHY6U5tkkIIUSWXO2taVPVkzZVPQFlj6XLd+M59u9DztyI42jMQ87ffkxKKiQkpfDryZsAhO27AkAZd0e+S3xC+f/qe5GYjJ11up7L2fl+kO8UkQskiCSEEEIvcn4ihBAGJgfW/E/eYiGEEEKIfEehUODn4YSfhxPdA0sBkJKSyo1Hz4k+H8uZG4/54+oDLt+NB9MyVIYAACplSURBVOCfe/E8skmE/4Yarfjxb5Rxd2TXf/XFPU/C1QTbIURWJIgkhBAie/JsiiAhhBBCCCGEEEIIw7OwUFCqiAN96/mopqWmpnL4ygOu3o+n0O/WkPCq/D/34sFO+f+By/cY+dEWCjnY0LaqJ8Vc7OhQvQRFHG3UeywJkcskiCSEEEIIIUReJ4F9IYQQQgghzJJCoSCojBtBZdzgtDP8q5y+aURDtp+7A/tflU1MTuXukwSWHYgBYNZvfwNQsbgztbwLU6WkK/XLuuGdy9sgCjYJIgkhhBBCCCGEECLvkDSgQggh8oEqJV2pUtJVFURqXsmDn+rXY+/Fe1yKfcLu83eJf5kMwN+3n/D37SeqZWP+6700b9dFKtQpg7ebI2WLOmJlaZHbmyEKAAkiCSGEEEIIIbJJLuQKIYQQQghhCDaWFtT2KUJtnyKqaSkpqWw/d5t/Hzxj06lb3Hz0gntPX+XD23PhHl+d/xMAa0tljydPVzs6Vi+Jh4stfrm+FSI/kiCSEEIIIYQQQgghTEAC0UIIIURmLCwUvFnFE4B3G5cF4NGzl/C5cr6vuwM3XtpxM+4Ficmp7L14D4Af/7gOvOqxlJSSyomYBwSmC1AJoSsJIgkhhBBCCCGEEEIIIYQQeUAhBxvV/593DQDv+lyKfcKl2Kfs+juWC3eecvZmHInJr8ZNTU5Jpet3Bynn4YSDjSX2Npa0qepJ5RIu+BV1xtXB2hSbIvIICSIJk4mJicHX15f+/fuzbNkyUzfHJHx8fADla2FoTZs2Zffu3aTKQNtCCCGEEEIYgJxXCyGEyA3yfSP05+fhjJ+Hs6rXUmJyCr8cuwGblfPT+v5ejH2qWubQPw9U/9coXYg2VTwZ3LhMbjVZ5CEy0pbItpCQEBQKBW5ubiQkJGS9gJE0bdoUhUKBQqFg06ZNGZarW7euqlx0dHSG5aZNm4ZCocDa2prbt29nWG7AgAGq+ubPn59huR49eqjKmXOw7M6dOwwfPpy6detSrFgxbG1t8fLyonnz5vz8889mFYy6desWgwYNwtPTEzs7OypUqMD//d//kZiYqFc9Dx8+JDQ0FD8/P2xtbSlatChdu3bl7NmzGS6zatUqGjRogJOTE46OjtSuXTvT9/Wvv/7i7bffpnjx4tja2uLt7c2oUaN48OCB1vIvXrxg+vTp+Pv7Y2dnR+HChWndujX79+/Xa9tyg/nsEUIIkce5V1B/7tMw47KKdKfvzsWN0x5zVL2P8m+94bovU7WbcdqSHR6V9Stfa4Dyb6kg3ZcpXV99WZE/5db7XLKWcevPDf4dlX/rjzBpM9QU+e/CXJUuhq+74lvKv44ehq/b3NXoq/xb933t8/U5lmZHZt/bouAoE6z+3K+FadohzF+N/85rS9R8Nc2tnPJvla7615e27GusLS3oXrvUq+dWFqx/vz4r36nL512rUc3LlbJFHVXzj//7iBuPnuu/flEgSE8kkS1Pnjzhxx9/RKFQ8ODBAzZs2ECPHj1M2iYrKyuWLl3KW2+9pTHv7NmzHDlyBCsrK5KSkjKsIzU1lfDwcBQKBUlJSURERPDBBx/otN7hwzUvajx48IDIyMgM17tz504dtix3XLt2jR9++IGgoCA6depEkSJFiI2NZePGjXTp0oV33nmHxYsXm7qZ3L59m7p163L9+nU6depEuXLl2L17N5MmTeLIkSNs2LABhSLr3Or379+nXr16XLx4kXr16tGhQwdu3brF+vXr2bp1K7t27aJu3bpqy4wdO5bZs2dTvHhx3n77baytrdmyZQsDBw7kzJkzfPnll2rlDx06xBtvvMHz58/p0KEDZcuW5cSJE3zzzTf89ttvHDhwADc3N1X5Fy9e0Lx5cw4cOEC1atV4//33efToEevXr6dJkyasX7+eDh06GOaFzAHJXC+EyNvMKAQ+9jwkPAHnYsrn46/Ak1tQLJOAg0IBo89C8kuwdc6ddpqDdnOUF81L1Mi6rF8LaDoBileDHR8bu2WZG3sBXsSBi6d+y9UaCMUDwKOS7sv0/RnunIOSNbMum990izB1C3KPwd7nLI6FRXxh6GFwcFOfXqMvHF+ew3Xnks6LlQEkXY4bueW9PXD/EnhWN3zdZYOV9Rf2MXzd5u6tr6Fm/4zfa3c/GHoIHIsaZ/0eleD9g+BUzDj1i7yhRHUYsh+cPeHhFShWxdQtEuaq+tvK40bRdOd570bD/Yv6fT+M+RtexoOTbsc2BVDLu7DqefdAZYAp7nki524+Ju55Il6F7XVfvyhQCnQQacGCBSxYsICXL1+auil5ztq1a4mPj2fMmDHMmTOHsLAwkweRWrduzaZNm7h79y5Fi6ofQMPCwrCwsKBVq1Zs3rw5wzp27txJTEwM7777LmvWrGHp0qVZBpFat27Nxo0bOXnyJAEBAWrzVqxYQUJCAu3bt+fXX3/VWLZs2bJ6bKFxBQQE8PDhQywtLdWmP3nyhLp167JkyRL+97//UbmynnfSGtgHH3zAtWvX+PbbbxkyZAigDP717t2bNWvWsGbNGnr16pVlPVOmTOHixYuMGTOGr776SjX94MGDNGrUiJCQEE6fPo2FhfKO7z/++IPZs2fj5+fH4cOHKVJEORBhfHw8wcHBfPXVV3Tp0oV69eqp6ho8eDDx8fFERkbSvn171fQvvviC8ePH89FHH/Hdd9+pps+fP58DBw7QrVs3Vq9erXovJk2aRM2aNRk8eDDNmjXD2dkUFw3N6KKrEELkF87F1XsTORRRPrLi6mW8NpkrS2soVVu3shZW4BVo3PboyrnYqyChPhQK8NKzJ4i1vf7LZIcON+vkOmNdGDZHxnifM3pPPSpqTnMy914u6bbFysZ8jgVpbJ2NG9TyDMi6jFGY+LeCLt8R+gTls6OYv3HrF3lD8f8CR45umZcTBZtCodnj19ZJ/+8HfW9SyoCrvTX1yso+KzJXoNPZDRs2jHPnzmWa2kxoFxYWhpWVFePHjyc4OJidO3dy9epVrWWTk5OZNWsWfn5+2NnZ4efnx4wZM0hJScmw/tjYWEaPHq1KM+bu7k6XLl04c+ZMhsuEhISQmJjI8uXqd8YlJiayYsUKWrZsiZdX5hddwsLCAHj33Xfp1q0bFy5cYO/evZku079/fywtLVXLphceHk6lSpXUAgvp+fj4qMZFAmXvGC8vL5ydnbl06ZJa2czmZebFixdMmDCB0qVLY2dnR6VKlZg3b55Gejpra2uNABKAs7Mzb775JoBe6zWGJ0+esHbtWsqUKcN7772nmq5QKJg5cyaAzr2lIiMjsbCw4JNPPlGbXq9ePdq1a8e5c+fYvXu3WnmA0aNHqwJIAI6Ojnz00UcAagGhy5cvc+bMGWrXrq0WQAJljyY3NzeWL19OfHy8xjqmTp2q9l6ULVuWkJAQ7t69y7p163TaPiGEEELklNxAIYQQQgghhBAFOogksufcuXMcOnSIli1bUqxYMfr160dKSgrh4eFay7/77rtMmDCBlJQUhg0bRqtWrZg9ezajRo3SWv7y5cvUqlWLOXPmULZsWUaMGEGbNm347bffCAoK4vDhw1qXCwoKwt/fX6MdGzdu5O7du4SEhGS6XQ8ePOCXX37B39+fWrVq0a9fPwCtwaH0SpYsScuWLVm1apVar7Zjx45x4sQJBg4cmOny6bm5ufHDDz/w7NkzevfurTbGz6BBg7hx4wbz58/Hz89P5zq7d+/OypUr6dy5M0OGDOHp06eMHDmS0NBQnZZ/8eIFu3btQqFQmLwX0sGDB0lISKBFixYaKeu8vb2pUKEC+/fvJzk5Ocu6bt++jbu7O05OThrzfH19Adi1a5da+fTzclLewsKC0qVL8+zZMw4dOpTtdQghhBBCCCGEEEIIIYQxFeh0dgaTmgqJz0zdiqxZOxgk/URaUKVvX+XglZ07d2bo0KGEh4czefJkVfovgOjoaJYuXUpAQAD79+/H0VE5YNvEiROpXr261vr79evHrVu3+O2332jVqpVq+qRJkwgMDGTw4MGcOnVK67IhISGEhoZy9OhRateurWqvm5sbHTp0yHQMopUrV5KQkKDarkaNGuHj48NPP/3EN998g4uLS4bLDho0iK1btxIZGUm3bt1U67WysqJfv34ZBti0adasGePHj2fmzJlMmjSJWbNmsXDhQiIjI+nVqxf9+/fXuS6ACxcucObMGVxdXQH45JNPqFu3Ll9//TW9evUiMFA9xUNsbCwLFy4kJSWF2NhYtmzZwrVr15gyZYrOwasTJ06wYcMGndtYqFAh/ve//2VZ7uLFiwCUK6d90MBy5cpx/vx5rl69SpkyZTKty93dndjYWJ4+faoRSLpy5QqgfO3Sl08/T1v569ev8+zZMxwcHDItn5KSwr///qtaR/PmzVXruHTpEleuXMHfXz0dgrY2CSGEEEIIIYQQQgiRNTNMSSzyDAkiGULiM/ishKlbkbWJN8HGMUdVpKWLc3FxoWPHjgA4OTnRqVMnVqxYwe+//07Lli1V5X/44QcAJk+erAoggbL3zqhRo/j4Y/XBjo8fP86BAwcICQlRCyABlC9fnsGDBzN79mzOnDlDlSqagxT27duXDz/8kKVLl1K7dm1u3rzJtm3bGD58ODY2NpluW9q4SX369AGUKdL69OnDp59+ypo1a3j33XczXLZ9+/a4u7uzdOlSunXrxosXL1i9ejVt27alWDH98+BPmzaNnTt38uWXX+Ll5cUHH3yAj4+PWro0XX388ceqABKAq6srkyZNom/fvkRERGgNIqVP8WZtbc0XX3zB2LFjdV7niRMnNNLEZcbb21unIFJcXByA2vaklxboSyuXmdatWxMeHs4nn3zCF198oZp++PBhNm3aBMCjR4/Uys+cOZM5c+bQu3dvChUqBMCzZ8+YMWOGWhsdHBwoX748ZcqU4ejRo2zevJm2bduqysyZM4f79+9rXcehQ4eYNm0aK1euVKW0u3LliioQmb68EEIIIYQQQgghhBBCGJOksxN6iYyM5O7du3Tr1g07OzvV9IxSv508eRJQ9up5nbZpaam97ty5w9SpUzUef//9N4Dq7+s8PDxo27Yta9as4cWLF0RERJCcnJxlKrs//viDkydPEhwcrDZukq4p7aytrenTpw/bt2/nxo0b/PLLLzx8+DDL9WZW3+rVq3FwcGDkyJG8fPmSlStXZtobKiOZvfbHjx/XmFelShVSU1NJSkriypUrfPLJJ3z00Ud06dKFpKQkndY5YMAAUlNTdX7ExMTovV05NW3aNDw9Pfnyyy9p2LAhoaGhvP322zRu3FjVCyh9r7rGjRvTt29fLl68iL+/P0OGDGHEiBFUrVqVW7duqQJbacsoFAoWLlyItbU17du3p0uXLowfP55WrVoxduxYqlatqrGO0aNH4+/vz9q1a6lVqxZjxowhJCSE6tWr4+3trVE+d8kdK0IIYd5k/B4hhBBCCCGEEIYnPZEMwdpB2cvH3Fk75LiKtGBKWnAlTfPmzSlZsiSRkZE8ePCAIkWKAMpeGRYWFqrUXulp66Hz4MEDADZv3szmzZszbEd8fHyG80JCQtiwYQPr168nPDycWrVqUa1atWxtV7ly5QgKCuLQoUOcPXs20zGBQkJCmDNnDsuWLSM6OprixYvTpk2bTNebmTJlyqjSANaqVYv69etnqx5tr3PatMx67FhaWuLj48OHH36IlZUV48ePZ/Hixbz//vvZaochpAVqMmr348eP1cplxsvLi6NHjzJlyhS2bt3KkSNHKFWqFNOmTcPHx4eePXvi4eGhtsyyZcsIDAwkLCyMZcuWYW9vT6tWrfj888+pXLkyVlZWqn0foFWrVuzdu5fp06eza9cuNm/eTJUqVfjll1/YuXMnp0+fVluHs7Mz+/fvZ9q0afzyyy/Mnz8fDw8PhgwZwltvvUXjxo012iSEEEIIIYQQQgghhBDGIkEkQ1AocpwmLi+4du0a27dvB6BJkyYZlluxYgUjR44ElBfzU1JSuHfvHkWLFlUrd+fOHY1l03razJs3j+HDh2ernW3atMHT05MPPviAGzdusHDhwkzLP3/+nNWrVwPQv3//DMccCgsLY/bs2RnWU7VqVWrXrs2CBQu4c+cOoaGhWFll/yM2e/Zs9u/fj5ubG0eOHGHhwoUMHTpU73ru3LlD6dKlNaaBbsEWgJYtWzJ+/Hiio6N1CiIZa0yktLGQ0sZGet3FixexsbHR2N6MlCxZkiVLlmhMnzp1KoBGqj8LCwtGjhyp2r/TxMTE8PTpU2rWrIm1tbXavLp166rS46U3Z84cresoVKgQs2fP1tjXli1bprV87pE73IUQQmhhgPE2hRBCf3LsEUIIIYTILRJEEjpbtmwZKSkpNGzYkAoVKmjMT0pKIiIigrCwMNVF9oCAAI4dO8bevXvp3LmzWvm9e/dq1FG3bl0ADh48mO0gkqWlJf369WPWrFnY2dnRq1evTMuvW7eOuLg4qlevTq1atbSWWblyJcuXL2fmzJmZjq0UEhKiCrJkN5UdKNPMTZw4kQoVKhAVFUWDBg0IDQ2lSZMmmfaG0mbv3r28/fbbGtMAatSooVMdN28qe9q9HiDJiLHGRAoKCsLGxoYdO3aQmpqKIt2Fq6tXr3L+/HmCg4NzFLxLTk5mzZo1WFlZ0aVLF52WWblyJQA9e/bUqfzVq1fZt28f/v7+qrR2hl6HEEIIIUSBIcHMvEveOyGEEEIIsydBJKGT1NRUwsPDUSgUREREUKZMGa3lLly4wMGDB/njjz8IDAykb9++hIeHM23aNFq1aoWjo7LH1o0bN5g7d67G8nXq1KFu3bqsXr2a9u3b06NHD7X5KSkp7N27N9OeUABjxowhKCiIIkWKUKhQoUzLpqWymz17NsHBwVrLPHv2jNWrV/Prr7/StWvXDOvq06cPxYsXx87OTmugTRfx8fGqwNfq1avx9PRk1apVNGrUiF69enHkyBG18aiyMn36dN566y21VHCffvopCoVCrdfVyZMn8ff31wgUPXjwgIkTJwLonJ5vwIABDBgwQOc26srFxYWePXvyww8/8P333zNkyBBAuX9++OGHAAwePFhtmbi4ONWYRZ6enqrpiYmJJCUlYW9vr5qWkpJCaGgo58+fZ/To0ZQoUUKtrsePH2uMS7V3715mzJiBt7e3qj1pnj59iqOjo1qwKy4ujr59+5KcnMyMGTM0tlHbOr7++mt+//13OnXqRO3atbN8nYQQQgghhBGkSs9sIYQQuUC+b4QQZkaCSEInu3bt4sqVKzRp0iTDABLAwIEDOXjwIGFhYQQGBhIcHMzAgQMJDw+natWqdOrUiYSEBNauXUtQUJDWNF+rV68mODiYnj17MmfOHGrWrIm9vT3//vsvBw8e5O7du7x48SLT9np4eNCxY8cst+vSpUvs2bMHHx8fmjZtmul2rV69mrCwsEyDSE5OTjqtNzOjRo3i/PnzfPnll6qeQkFBQUyZMoWPP/6YcePGMW/ePJ3rK1++PFWqVFH1qlm/fj3Xr19nzJgxaqnRvv76azZt2kSDBg0oXbo09vb2XL16lc2bNxMfH0+3bt2y7NWVG2bOnElUVBRDhw7l999/x8/Pj927d3Po0CHatWun0VPnl19+YeDAgfTv31+VEg6UKf0qV65My5Yt8fX15eXLl2zbto2///6btm3bag3wdO3alefPn1OtWjVcXFw4ffo0W7dupUiRImzYsAFnZ2e18hs2bGDixIk0a9aMEiVKEBsby6+//srdu3eZPn067du311hHyZIlCQ4Oply5cigUCqKjo/nzzz//v717j626vv84/jqnnHMsvXFplQKlFrNiKIgaURek2BGL2cxmwibBtKsKzFlqthXXVmfoJhnIhkQ3CV2ajSIbEIyODIcBFAquY5VsLHG6CZ0wCkNAsTcu5bTn/fvDnfPjcHo7tYdTTp+PpAl8Pu/T8/l+z+mrp9/3+X5P4LOYAABfEH+UDy083kDX+NkAAABAH9FEQp/4D173dnbJvHnz9L3vfU+bNm3S6tWrFR8fr+rqamVnZ6u6ulovv/yyxo8fr9LSUj300ENdNpGysrJ08OBBrV69Wlu3btW6desUFxen9PR05ebm9tjECddvfvMbmZmKioqCzha50uzZs5WRkaGdO3eqsbFRGRkZA7aGy7322mv69a9/rfvuu0+lpaVBc88884x27dqll19+WXPmzNEDDzzQp++5ZcsWVVZWatOmTTp16pSysrL0i1/8IuRygYWFhfL5fKqvr9eePXt04cIFjR49Wrm5uSoqKgo5Kyxa0tPTVV9fr2effVZ//OMftW3bNmVmZmrZsmUqKyvr8XG8XEpKir7xjW+orq5Ob7zxhlwul6ZMmaLq6mo99thjcjqdIbd58MEHVVNTo9/97ne6cOGCMjIy9OSTT+rpp5/WDTfcEFI/depUTZs2TTt37tQnn3yilJQU3X333SotLe32rLeCggLt2bNHb7/9thwOh7Kzs/Xzn/9cTz75pDweT3g7CwCAoYjLY0UY+zf2xNBjys8/AADAgKOJhD7ZuHGjNm7c2GtdcnKyzp8/HzQWFxeniooKVVRUhNRbN++AGzlypJYtW6Zly5b1ep+1tbW91vhVVVWpqqoq8P/ly5dr+fLlvd7O6XTq2LFjQWM1NTVBZ7b0pLvtP3r0aND/586d2+0+cTqd2rt3b5/uTwreLytXrtTKlSt7rJ89e7Zmz57d5+8fTenp6X0+K6e7S+slJSXplVdeCet+i4uLVVxc3Of6adOmadu2bWHdx9q1a8OqBwAAAAAAAIBICX2rPQAAAAAAAAAAAJSXl6fJkydrzZo10V5KVHAmEgAAAAAAAAAAQBf27Nmj8ePHR3sZUcOZSAAAAAAAAAAAAAhBEwkAAAAAAPSum89vxRDH8wIAgJhGEwkAAAAArsRBUQAAAACgiQQAAAAAAAAAQMxyOKK9AlzDaCIBAAAA1zrOmgEAAAAARABNpH4w/kgHMIhFIqOIPQBA13hHI4Ao4N3UAAAAVw1NpDDExcVJkrxeb5RXAgDd82eUP7MAAJejKz608HgPDVeroRBLz6fBsi00gwAg1GDJaAD4HE2kMLhcLnk8HjU3N3M2EoBByczU3Nwsj8cjl8sV7eUAABAFHJSOKM4AiT0x9ZjG0rYAAAAMDsOivYBoWrNmjdasWaNLly71+Tapqak6ceKEjh8/rpSUFLlcLjli6kU3gGuRmcnr9aq5uVltbW0aN25ctJcEAAAAAAAA4Bo3pJtIixcv1uLFi3X8+HFlZGT06TbJycmSpE8++UQnTpyI5PIAIGwej0fjxo0LZNVAoVcOAAAAAAAADD1DuonUX8nJyUpOTpbX61VnZ2e0lwMAkj7/DCQuYQcAAAAAAABgoNBE+gJcLhcHbAEMCXwMHAAAAAAAADD0OKO9AAAAAAAAAAAAAAw+NJEAAAAAAAAAAAAQgiYSAAAAAAAAAAAAQtBEAgAAAAAAAAAAQAiaSAAAAAAAAAAAAAgxLNoLGAx8Pp8k6eTJk1FeCQAMTp+c+VS+9vOSpNams/K1n1fzZ5/q+PHjUV4ZAITnuqZWudtNktQSQxkWd/q0Eq7idiX/774unm3SpUGyH/1r8rae14X/remLPN5JF00Oxd5zpb8Cj/mnzVf9MffftxT8WPjHz50+rc64yK9p2OlPNTxG8sP56cdKDGzLCSmub4cGPE0t8gzAPvA/du1NrWoP4/v8//MwOHuCngvDru3H5lqU1E5eAgPpuqZzMfl6FdHh/x1p3k618nwKm79f4O8fDFUOM7Pey2LbgQMHdOedd0Z7GQAAAAAAAAAAYBB59913NX369GgvI2poIknq6OjQwYMHdcMNN8jp5Ap/fq2trZo8ebI++OADJSUlRXs5AGIQOQMg0sgZAJFGzgCINHIGQKSRM13z+Xw6deqUbrvtNg0bNnQv6kYTCd1qaWlRSkqKmpublZycHO3lAIhB5AyASCNnAEQaOQMg0sgZAJFGzqAnnHYDAAAAAAAAAACAEDSRAAAAAAAAAAAAEIImErrl8XhUWVkpj8cT7aUAiFHkDIBII2cARBo5AyDSyBkAkUbOoCd8JhIAAAAAAAAAAABCcCYSAAAAAAAAAAAAQtBEAgAAAAAAAAAAQAiaSAAAAAAAAAAAAAhBEwkAAAAAAAAAAAAhaCIBAAAAAAAAAAAgBE0kdOnAgQP66le/qhEjRighIUF33323tmzZEu1lAYiwEydO6MUXX1R+fr4mTJggt9utMWPGaO7cuaqvr+/yNi0tLSotLVVmZqY8Ho9uvPFG/fCHP1RbW1uX9T6fT7/85S81depUxcfHKy0tTfPnz9dHH33U7bp27NihWbNmKSkpScnJycrLy9Pbb7/dbf2hQ4f00EMPKTU1VfHx8Zo2bZrWrl0rMwtvhwC4KlauXCmHwyGHw6G//OUvIfPkDID++v3vf6/77rtPo0eP1nXXXaesrCzNnz9fjY2NQXXkDIBwmJlef/115eXlKT09XcOHD9ekSZP0+OOPd5kDZAyA7vz2t7/V448/rjvuuEMej0cOh0M1NTXd1sdCnoS7DRgEDLjC7t27zeVyWVJSki1atMhKS0stMzPTJNmqVauivTwAEVReXm6S7KabbrIFCxZYRUWFzZ071+Li4szpdNrmzZuD6tva2uzWW281SZafn2/l5eWWn59vkmz69Ol24cKFkPtYuHChSbKcnBwrKyuzgoICc7vdNmrUKDt06FBI/YYNG0ySpaWlWUlJiZWUlFhaWpo5HA579dVXQ+rff/99S0lJMbfbbQUFBVZWVmY5OTkmyUpKSgZuZwEYEO+99555PB5LSEgwSbZ///6geXIGQH/4fD77zne+E3hdU1xcbOXl5VZYWGgTJkywd955J1BLzgAIV2lpqUmy9PR0++53v2tlZWU2Z84cczgclpSUZO+9916glowB0BP/MdfU1NTAv9etW9dlbSzkSX+2AdFHEwlBvF6v3XTTTebxeOzgwYOB8aamJsvOzja3221Hjx6N3gIBRNRrr71mtbW1IeP79u0zl8tlI0eOtIsXLwbGly5dapKsvLw8qN7fjFq+fHnQ+O7du02S5ebmWnt7e2B8+/btgRcQlzt79qyNGDHCUlNTrbGxMTDe2Nhoqamplpqaai0tLUG3yc3NNUm2ffv2wFh7e7vNnDnTJNmf//znMPYIgEi6dOmS3X777XbXXXdZQUFBl00kcgZAf7z44osmyYqLi62joyNk3uv1Bv5NzgAIx8mTJ83pdFpmZqY1NTUFza1evdok2aOPPhoYI2MA9GTXrl2BY60rVqzosYkUC3kS7jZgcKCJhCA7duwIecHjV1NTY5LsJz/5SRRWBiDa/O8MOXDggJl9/g7fsWPHWmJiorW1tQXVtrW1WWJiok2cODFofP78+SbJ9u7dG/L97733XpNk//nPfwJjv/rVr7rNnR//+McmydavXx8Y+/DDD02S5eXlhdTX1tZ2m28AoqOystI8Ho+9//77VlRUFNJEImcA9Mf58+dt5MiRNnHixKBmUVfIGQDh2r9/v0myhx9+OGTu0KFDJskeeOABMyNjAISnpyZSLORJf7YBgwOfiYQgtbW1kqT8/PyQuTlz5kiS9u7dezWXBGCQcLlckqRhw4ZJkg4fPqz//ve/mjFjhhISEoJqExISNGPGDH300UdBnzlQW1sbmLtSVxkTbib1VH/PPfcoISGBDAMGib/97W/66U9/qsrKSk2ePLnLGnIGQH/s3LlTn332mR588EF1dnbq9ddf1/PPP6+qqio1NDQE1ZIzAML1pS99SW63W3V1dWppaQmae+ONNyRJs2fPlkTGABg4sZAn/dkGDA40kRDk8OHDkj5/UXSlMWPGKDExMVADYOg4duyY3nrrLaWnp2vq1KmSes6Ly8f9defOndPJkyeVlZWluLi4Xut7u49w6+Pi4pSVlaWjR4+qo6Ojp80FEGHt7e369re/rVtvvVVlZWXd1pEzAPrjr3/9q6TPfyZvueUWzZ07V08//bSeeOIJTZo0SU899VSglpwBEK7Ro0fr+eef17Fjx3TzzTfriSeeUHl5ue6//36Vl5eruLhYJSUlksgYAAMnFvIk3G3A4EETCUGam5slSSkpKV3OJycnB2oADA1er1eFhYVqb2/XypUrAy8++pIXl9eFW9/bbcKt99/G5/OptbW1y3kAV8fSpUt1+PBhrVu3rss/aPzIGQD9cfr0aUnS6tWrlZKSonfffVetra3at2+fsrOz9cILL2jt2rWSyBkA/fODH/xAmzdvVltbm6qqqvSzn/1MO3bs0F133aWHH344cPUGMgbAQImFPOnPmjA40EQCAHTL5/PpkUce0b59+7Ro0SIVFhZGe0kArnH79+/XqlWr9Oyzz2rKlCnRXg6AGOTz+SRJbrdbW7du1fTp05WYmKiZM2fq1VdfldPp1AsvvBDlVQK4lj333HMqKCjQM888o8bGRrW2tuqdd97RxYsXde+99+oPf/hDtJcIAMCAoYmEIP5OcHcd35aWlm67xQBii8/n02OPPaaNGzeqoKBAVVVVQfN9yYvL68Kt7+024db7b+NwOJSUlNTlPIDI6ujoUFFRkW655RZVVFT0Wk/OAOgP/8/pHXfcobFjxwbNTZkyRRMnTtS///1vNTU1kTMAwvbWW2+psrJSJSUlqqio0Pjx45WYmKh77rlH27Ztk8vl0pIlSyTxWgbAwImFPOnPmjA40ERCkJ6uPfnxxx+rra2t2+tWAogdPp9Pjz76qNavX6/58+erpqZGTmfwr4zerlV75bVuExISlJ6eriNHjqizs7PX+t7uI9z6zs5OHTlyRFlZWYHLSwC4utra2nT48GH9/e9/l9vtlsPhCHytX79ekvTlL39ZDodDW7duJWcA9MukSZMkSSNGjOhy3j9+4cIFcgZA2N58801JUl5eXsjcmDFjdPPNN6uhoSHo+AkZA+CLioU8CXcbMHjQREKQWbNmSZJ27twZMrdjx46gGgCxyd9AeuWVVzRv3jxt2LCh2w9hHDt2rOrq6nTu3LmguXPnzqmurk5ZWVnKyMgIjM+aNSswdyV/xuTm5gbVS33PpJ7q//SnP+ncuXNkGBBFHo9HCxYs6PLL/4fC17/+dS1YsEA33ngjOQOgX/wHdv/5z3+GzHm9XjU0NCghIUFpaWnkDICwXbp0SZJ05syZLufPnDkjp9Mpl8tFxgAYMLGQJ/3ZBgwSBlzG6/XaxIkTzePx2MGDBwPjTU1Nlp2dbW63244cORK19QGIrM7OTisqKjJJ9q1vfcu8Xm+P9UuXLjVJVl5eHjReXl5ukmz58uVB47t37zZJlpuba+3t7YHx7du3myTLz88Pqj979qylpKRYamqqNTY2BsYbGxstNTXVUlNTraWlJeg2ubm5Jsm2b98eGGtvb7eZM2eaJKurq+vbzgBwVfmzZ//+/UHj5AyA/sjPzzdJVl1dHTT+3HPPmSQrKCgIjJEzAMKxadMmk2Q5OTnW1NQUNLd27VqTZDNmzAiMkTEA+mrFihUmydatW9flfCzkSbjbgMGBJhJC7N6921wulyUlJdmiRYustLTUMjMzTZKtWrUq2ssDEEGVlZUmyRITE+1HP/qRVVZWhnxd3mBua2uzadOmBV58VFRUBA7aTJ8+3c6fPx9yHwsXLgz80VVWVmaFhYXmdrtt1KhR9uGHH4bUb9iwwSRZWlqalZSUWElJiaWlpZnD4bAtW7aE1P/jH/+wlJQUc7vdVlhYaGVlZZaTk2OSrKSkZED3F4CB010TiZwB0B8NDQ12/fXXmyT72te+ZkuWLLGvfOUrJskyMzPt5MmTgVpyBkA4Ojo6AgdNr7/+elu4cKE99dRTgYyJj4+3+vr6QD0ZA6An1dXVVlRUZEVFRXb77bcHGtH+scvfEBMLedKfbUD00URCl+rr6+3++++35ORki4+PtzvvvNM2b94c7WUBiDD/Qdyevq58R0xTU5N9//vft4yMDHO5XDZhwgRbsmRJyLtT/Do7O+2ll16ynJwc83g8Nnr0aJs3b541NDR0u64333zTZs6caQkJCZaYmGizZs2yXbt2dVv/r3/9y775zW/aqFGjzOPx2NSpU23NmjXm8/n6tV8ARF53TSQzcgZA/xw7dsweeeQRGzNmjLlcLsvIyLDFixfbqVOnQmrJGQDhuHjxoq1YscJuu+02Gz58uA0bNszGjRtnBQUF9sEHH4TUkzEAutPbcZiioqKg+ljIk3C3AdHnMDPr66XvAAAAAAAAAAAAMDQ4o70AAAAAAAAAAAAADD40kQAAAAAAAAAAABCCJhIAAAAAAAAAAABC0EQCAAAAAAAAAABACJpIAAAAAAAAAAAACEETCQAAAAAAAAAAACFoIgEAAAAAAAAAACAETSQAAAAAAAAAAACEoIkEAAAAAAAAAACAEDSRAAAAAAAAAAAAEIImEgAAAAAAAAAAAELQRAIAAAAAAAAAAECI/wOU75UEeggURwAAAABJRU5ErkJggg==", "text/plain": [ "
" ] @@ -348,19 +562,19 @@ } ], "source": [ - "N = 20000\n", + "N = num_iterations+1\n", "fig, ax = plt.subplots()\n", - "lns = ax.semilogy(range(5000), jnp.linalg.norm(all_b1_params_array[0,:5000,:]-jnp.ones(2,),axis=1),label=f'Adam b1 = 0.9')\n", + "lns = ax.semilogy(jnp.arange(N), jnp.linalg.norm(all_b1_params_array[0,:,:]-jnp.ones(2,),axis=1),label=f'Adam b1 = 0.9')\n", "for i,b1 in enumerate([0.99,0.999,0.9999]):\n", " lns += ax.semilogy(\n", - " range(N), \n", - " jnp.sqrt(jnp.linalg.norm(all_b1_params_array[i+1,:N,:]-jnp.ones(2,),axis=1)),label=f'Adam b1 = {b1}'\n", + " jnp.arange(N), \n", + " jnp.sqrt(jnp.linalg.norm(all_b1_params_array[i+1,:,:]-jnp.ones(2,),axis=1)),label=f'Adam b1 = {b1}'\n", " )\n", "ax1 = ax.twinx()\n", "for i,b3 in enumerate([0.999,0.9999]):\n", " lns += ax1.semilogy(\n", - " range(N), \n", - " jnp.sqrt(jnp.linalg.norm(all_ademamix_params_array[i,:N,:]-jnp.ones(2,),axis=1)),label=f'AdeMAMix b3 = {b3}'\n", + " jnp.arange(N), \n", + " jnp.sqrt(jnp.linalg.norm(all_ademamix_params_array[i,:,:]-jnp.ones(2,),axis=1)),label=f'AdeMAMix b3 = {b3}'\n", " )\n", "labs = [l.get_label() for l in lns]\n", "ax.legend(lns, labs, loc=0)\n", @@ -377,7 +591,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 28, "id": "2cf96de0-cb01-4338-87b4-dd80f0498ebd", "metadata": {}, "outputs": [ @@ -391,8 +605,8 @@ "Final value with b1 = 0.999: ((1.0000061988830566, 1.0000123977661133))\n", "Final value with b1 = 0.9999: ((0.9527199268341064, 0.9080769419670105))\n", "AdeMAMix Values:\n", - "Final value with b3 = 0.999: ((1.0000168085098267, 0.9999828934669495))\n", - "Final value with b3 = 0.9999: ((1.0000070333480835, 0.9999932050704956))\n" + "Final value with b3 = 0.999: ((1.0, 1.0))\n", + "Final value with b3 = 0.9999: ((1.0000046491622925, 0.9999949932098389))\n" ] } ], From 81da0e44049e5f3900b11ed52c9e992283dddf64 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Tue, 22 Oct 2024 11:56:34 -0400 Subject: [PATCH 15/32] fixed import ordering --- optax/contrib/_ademamix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index f5d5bd97a..1e23ab486 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -6,6 +6,7 @@ Pierre Ablin and David Grangier. """ +from typing import NamedTuple, Tuple import chex import jax.numpy as jnp import jax.tree_util as jtu @@ -14,7 +15,6 @@ from optax._src import numerics from optax._src import transform import optax.tree_utils as otu -from typing import NamedTuple, Tuple class ScaleByAdemamixState(NamedTuple): """State for the Ademamix algorithm.""" From 94e3f0a4f806e0545a4db37782d3ec1e2d5c05b6 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 09:20:23 -0400 Subject: [PATCH 16/32] updated references using rst format --- optax/contrib/_ademamix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 1e23ab486..769d58e2e 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -37,7 +37,8 @@ def scale_by_ademamix( """Rescale updates according to the Ademamix algorithm. References: - [Pagliardini et al, 2024](https://arxiv.org/pdf/2409.03137) + Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older + `_, 2024 Args: b1: Exponential decay rate to track the first moment of past gradients for From 40c0e6e839eb16af8624788b40a63673740f3396 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 09:42:04 -0400 Subject: [PATCH 17/32] updated docstrings --- optax/contrib/_ademamix.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 769d58e2e..6bc7ac0a4 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -17,7 +17,15 @@ import optax.tree_utils as otu class ScaleByAdemamixState(NamedTuple): - """State for the Ademamix algorithm.""" + """State for the Ademamix algorithm. + + Attributes: + count: iteration of the algorithm used to update the fast EMA and second moment. + count_m2: iteration of the algorithm used to update the slow EMA and alpha. + m1: the fast EMA. + m2: the slow EMA + nu: estimate of the second moment + """ count: chex.Array # shape=(), dtype=jnp.int32. count_m2: chex.Array # shape=(), dtype=jnp.int32. @@ -87,8 +95,8 @@ def update_fn( count_inc = numerics.safe_int32_increment(state.count) count_m2_inc = numerics.safe_int32_increment(state.count_m2) m1_hat = otu.tree_bias_correction(m1, b1, count_inc) - # NOTE: AdEMAMix does not perform bias correction on b2 to let the momentum - # buffer fill itself slowly. + # NOTE: AdEMAMix does not perform bias correction on b2 to let + # the slow EMA momentum buffer fill itself slowly. nu_hat = otu.tree_bias_correction(nu, b2, count_inc) updates = jtu.tree_map( lambda m1_, m2_, v_: ((m1_ + c_alpha * m2_) / (jnp.sqrt(v_+eps_root) @@ -130,13 +138,13 @@ def ademamix( The ``init`` function of this optimizer initializes an internal state :math:`S_0 := (m1_0, m2_0, v_0) = (0, 0, 0)`, representing initial - estimates for the first and second moments. In practice these values are - stored as pytrees containing all zeros, with the same shape as the model - updates. At step :math:`t`, the ``update`` function of this optimizer takes - as arguments the incoming gradients :math:`g_t`, the optimizer state - :math:`S_t` and the parameters :math:`\theta_t` and computes updates - :math:`\theta_{t+1}` and new state :math:`S_{t+1}`. Thus, for - :math:`t > 0`, we have, + estimates for the first moments of the fast and slow EMA and the second moment + of the fast EMA. In practice these values are stored as pytrees containing + all zeros, with the same shape as the model updates. At step :math:`t`, + the ``update`` function of this optimizer takes as arguments the incoming + gradients :math:`g_t`, the optimizer state :math:`S_t` and the parameters + :math:`\theta_t`. It then computes updates :math:`\theta_{t+1}` and the new + state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have, .. math:: @@ -183,9 +191,8 @@ def ademamix( Objective function: 1.34E+01 References: - "THE ADEMAMIX OPTIMIZER: BETTER, FASTER, OLDER" - (https://arxiv.org/pdf/2409.03137) by Matteo Pagliardini, - Pierre Ablin and David Grangier. + Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older + `_, 2024 Args: learning_rate: A global scaling factor, either fixed or evolving along From fb095e11f0de0acc5cc952613069b5c8e6b90202 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 09:47:59 -0400 Subject: [PATCH 18/32] fixed linting --- optax/contrib/_ademamix.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 6bc7ac0a4..47e116091 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -20,7 +20,8 @@ class ScaleByAdemamixState(NamedTuple): """State for the Ademamix algorithm. Attributes: - count: iteration of the algorithm used to update the fast EMA and second moment. + count: iteration of the algorithm used to update the fast EMA and + second moment. count_m2: iteration of the algorithm used to update the slow EMA and alpha. m1: the fast EMA. m2: the slow EMA @@ -95,7 +96,7 @@ def update_fn( count_inc = numerics.safe_int32_increment(state.count) count_m2_inc = numerics.safe_int32_increment(state.count_m2) m1_hat = otu.tree_bias_correction(m1, b1, count_inc) - # NOTE: AdEMAMix does not perform bias correction on b2 to let + # NOTE: AdEMAMix does not perform bias correction on b2 to let # the slow EMA momentum buffer fill itself slowly. nu_hat = otu.tree_bias_correction(nu, b2, count_inc) updates = jtu.tree_map( From af8f22d221249cbbe10dddf6ae464a19202b7d0a Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 09:58:32 -0400 Subject: [PATCH 19/32] synced ademamix api to adamw --- optax/contrib/_ademamix.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 47e116091..46bed690e 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -6,7 +6,7 @@ Pierre Ablin and David Grangier. """ -from typing import NamedTuple, Tuple +from typing import Any, Callable, NamedTuple, Optional, Tuple, Union import chex import jax.numpy as jnp import jax.tree_util as jtu @@ -14,6 +14,7 @@ from optax._src import combine from optax._src import numerics from optax._src import transform +from optax._src import utils import optax.tree_utils as otu class ScaleByAdemamixState(NamedTuple): @@ -42,6 +43,7 @@ def scale_by_ademamix( alpha: base.ScalarOrSchedule, eps: float = 1e-8, eps_root: float = 0.0, + mu_dtype: Optional[chex.ArrayDType] = None, ) -> base.GradientTransformation: """Rescale updates according to the Ademamix algorithm. @@ -62,12 +64,16 @@ def scale_by_ademamix( (as in the Adam paper) to avoid dividing by zero when rescaling. eps_root: Term added to the denominator inside of the square-root to improve numerical stability when backpropagating gradients through the rescaling. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. Returns: A `GradientTransformation` object. """ + mu_dtype = utils.canonicalize_dtype(mu_dtype) + def init_fn(params): m1 = otu.tree_zeros_like(params) # fast EMA m2 = otu.tree_zeros_like(params) # slow EMA @@ -121,7 +127,9 @@ def ademamix( alpha: base.ScalarOrSchedule = 5.0, eps: float = 1e-8, eps_root: float = 0.0, + mu_dtype: Optional[Any] = None, weight_decay: float = 0.0, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, ) -> base.GradientTransformation: r"""AdEMAMix. @@ -208,11 +216,18 @@ def ademamix( eps_root: A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. weight_decay: Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. + mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Adam gradient transformations are applied to all parameters. Returns: The corresponding `GradientTransformation`. @@ -223,7 +238,7 @@ def ademamix( for a use case. """ return combine.chain( - scale_by_ademamix(b1, b2, b3, alpha, eps, eps_root), - transform.add_decayed_weights(weight_decay), + scale_by_ademamix(b1=b1, b2=b2, b3=b3, alpha, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype), + transform.add_decayed_weights(weight_decay, mask), transform.scale_by_learning_rate(learning_rate), ) From 47e42482f8d51166c110c0c64ed83a20473946c7 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 10:02:22 -0400 Subject: [PATCH 20/32] added defaults to scale_by_ademamix --- optax/contrib/_ademamix.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 46bed690e..dae57ca97 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -37,10 +37,10 @@ class ScaleByAdemamixState(NamedTuple): def scale_by_ademamix( - b1: float, - b2: float, - b3: base.ScalarOrSchedule, - alpha: base.ScalarOrSchedule, + b1: float = 0.9, + b2: float = 0.999, + b3: base.ScalarOrSchedule = 0.9999, + alpha: base.ScalarOrSchedule = 6.0, eps: float = 1e-8, eps_root: float = 0.0, mu_dtype: Optional[chex.ArrayDType] = None, From 934761a0dde6722ceb9163974a44651c744549ff Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 10:05:16 -0400 Subject: [PATCH 21/32] fixed syntaxerror --- optax/contrib/_ademamix.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index dae57ca97..fb1a3d1c4 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -238,7 +238,15 @@ def ademamix( for a use case. """ return combine.chain( - scale_by_ademamix(b1=b1, b2=b2, b3=b3, alpha, eps=eps, eps_root=eps_root, mu_dtype=mu_dtype), - transform.add_decayed_weights(weight_decay, mask), - transform.scale_by_learning_rate(learning_rate), + scale_by_ademamix( + b1=b1, + b2=b2, + b3=b3, + alpha=alpha, + eps=eps, + eps_root=eps_root, + mu_dtype=mu_dtype + ), + transform.add_decayed_weights(weight_decay, mask), + transform.scale_by_learning_rate(learning_rate), ) From 1b72988b06f5fdb1b6c83f98ec86a898a2aceec2 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 13:52:36 -0400 Subject: [PATCH 22/32] updated docstrings --- optax/contrib/_ademamix.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index fb1a3d1c4..903611d80 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -24,8 +24,8 @@ class ScaleByAdemamixState(NamedTuple): count: iteration of the algorithm used to update the fast EMA and second moment. count_m2: iteration of the algorithm used to update the slow EMA and alpha. - m1: the fast EMA. - m2: the slow EMA + m1: fast EMA of the first moment + m2: slow EMA of the first moment nu: estimate of the second moment """ @@ -45,31 +45,32 @@ def scale_by_ademamix( eps_root: float = 0.0, mu_dtype: Optional[chex.ArrayDType] = None, ) -> base.GradientTransformation: - """Rescale updates according to the Ademamix algorithm. + """Scale updates according to the Ademamix algorithm. + + See :func:`optax.contrib.ademamix.` for a full description of the algorithm. References: Pagliardini et al, `The AdEMAMix Optimizer: Better, Faster, Older `_, 2024 Args: - b1: Exponential decay rate to track the first moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b2: Exponential decay rate to track the second moment of past gradients for - the first Exponential Moving Average (EMA) - same as AdamW - b3: Exponential decay rate to track the first moment of past gradients - for the second EMA. - alpha: the coefficient that "blends" the two EMAs. paper states values in - :math:`[4,10]` work well in practice. + learning_rate: A global scaling factor, either fixed or evolving along + iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. + b1: Exponential decay rate to track the fast EMA. + b2: Exponential decay rate to track the second moment of past gradients. + b3: Exponential decay rate to track the slow EMA. + alpha: Mixing coefficient in the linear combination fo the fast and + slow EMAs. eps: A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: Term added to the denominator inside of the square-root to improve - numerical stability when backpropagating gradients through the rescaling. + eps_root: A small constant applied to denominator inside the square root (as + in RMSProp), to avoid dividing by zero when rescaling. This is needed for + instance when computing (meta-)gradients through Adam. mu_dtype: Optional `dtype` to be used for the first order accumulator; if `None` then the `dtype` is inferred from `params` and `updates`. Returns: - A `GradientTransformation` object. - + The corresponding `GradientTransformation`. """ mu_dtype = utils.canonicalize_dtype(mu_dtype) @@ -147,8 +148,8 @@ def ademamix( The ``init`` function of this optimizer initializes an internal state :math:`S_0 := (m1_0, m2_0, v_0) = (0, 0, 0)`, representing initial - estimates for the first moments of the fast and slow EMA and the second moment - of the fast EMA. In practice these values are stored as pytrees containing + estimates for the fast and slow EMAs of the first moment along with the second + moment estimate. In practice, these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step :math:`t`, the ``update`` function of this optimizer takes as arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t` and the parameters From 3e0699b09614c206bdc37c492ea41992bc4790f5 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 15:44:22 -0400 Subject: [PATCH 23/32] fixed typo --- optax/contrib/_ademamix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 903611d80..565f16640 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -209,7 +209,7 @@ def ademamix( iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. b1: Exponential decay rate to track the fast EMA. b2: Exponential decay rate to track the second moment of past gradients. - b3: Exponenital decay rate to track the slow EMA. + b3: Exponential decay rate to track the slow EMA. alpha: Mixing coefficient in the linear combination fo the fast and slow EMAs. eps: A small constant applied to denominator outside of the square root From 3ff8abaeb6727b055457353b02caeb32011fd3ab Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 15:45:52 -0400 Subject: [PATCH 24/32] reformatting note --- optax/contrib/_ademamix.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 565f16640..c7994730e 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -170,14 +170,16 @@ def ademamix( S_t &\leftarrow (m1_t, m2_t, v_t). \end{align*} - Limitations: AdEMAMix consists in leveraging very old gradients. Therefore, - the method is best suited to settings where the number of iterations is - important. The paper reports on this effect in Appendix C.1.5, showing how - smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands of - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is - non-stationary. + .. note:: + + AdEMAMix consists in leveraging very old gradients. Therefore, + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in Appendix C.1.5, showing how + smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands of + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is + non-stationary. Examples: >>> import optax From c933d175acf9493a1d1e75b3d9057c6bde426b04 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 15:52:53 -0400 Subject: [PATCH 25/32] fixing formatting issues --- optax/contrib/_ademamix.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index c7994730e..ceba3a7da 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -170,12 +170,11 @@ def ademamix( S_t &\leftarrow (m1_t, m2_t, v_t). \end{align*} - .. note:: AdEMAMix consists in leveraging very old gradients. Therefore, the method is best suited to settings where the number of iterations is important. The paper reports on this effect in Appendix C.1.5, showing how - smaller values of b3 (e.g. b3 = 0.999) can be better for low iterations + smaller values of ``b3`` (e.g. ``b3 = 0.999``) can be better for low iterations scenarios. Moreover, retaining gradient information over many thousands of steps can pose a problem in domains requiring fast adaptation to a sudden distribution shift, or general cases in which the distribution is From bdeb3e2bb460103cf3c181fcd65df330fa01e7aa Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 15:59:06 -0400 Subject: [PATCH 26/32] reformatting note --- optax/contrib/_ademamix.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index ceba3a7da..c37832123 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -164,13 +164,12 @@ def ademamix( v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ - \theta_t &\leftarrow \theta_{t-1} - \eta \cdot \left( + \theta_t &\leftarrow \theta_{t-1} - \eta \cdot \left( (\hat{m1}_t + \alpha m2_t) / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} - + \varepsilon\right) + \lambda \theta_{t-1} \right).\\ + + \varepsilon\right) + \lambda \theta_{t-1} \right)\\ S_t &\leftarrow (m1_t, m2_t, v_t). \end{align*} - AdEMAMix consists in leveraging very old gradients. Therefore, the method is best suited to settings where the number of iterations is important. The paper reports on this effect in Appendix C.1.5, showing how From 8d3805868014b7a81f8fc09badba4bb5418638b4 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 16:00:32 -0400 Subject: [PATCH 27/32] reformatting note --- optax/contrib/_ademamix.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index c37832123..d93c589e8 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -170,15 +170,6 @@ def ademamix( S_t &\leftarrow (m1_t, m2_t, v_t). \end{align*} - AdEMAMix consists in leveraging very old gradients. Therefore, - the method is best suited to settings where the number of iterations is - important. The paper reports on this effect in Appendix C.1.5, showing how - smaller values of ``b3`` (e.g. ``b3 = 0.999``) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands of - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is - non-stationary. - Examples: >>> import optax >>> import jax From 6f4ec8a174892c457128c04bf7dc373d3a52e769 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 16:42:16 -0400 Subject: [PATCH 28/32] fixed ademamix docstring --- optax/contrib/_ademamix.py | 39 ++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index d93c589e8..38e639328 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -147,29 +147,40 @@ def ademamix( the parameter vector at time :math:`t`. The ``init`` function of this optimizer initializes an internal state - :math:`S_0 := (m1_0, m2_0, v_0) = (0, 0, 0)`, representing initial + :math:`S_0 := (m^{(1)}_0, m^{(2)}_0, \nu_0) = (0, 0, 0)`, representing initial estimates for the fast and slow EMAs of the first moment along with the second moment estimate. In practice, these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step :math:`t`, the ``update`` function of this optimizer takes as arguments the incoming - gradients :math:`g_t`, the optimizer state :math:`S_t` and the parameters - :math:`\theta_t`. It then computes updates :math:`\theta_{t+1}` and the new - state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have, + gradients :math:`g^t`, the optimizer state :math:`S^t` and the parameters + :math:`\theta^{(t)}`. It then computes updates :math:`\theta^{(t+1)}` and the new + state :math:`S^{(t+1)}`. Thus, for :math:`t > 0`, we have, .. math:: - + \begin{align*} - m1_t &\leftarrow \beta_1 \cdot m1_{t-1} + (1-\beta_1) \cdot g_t \\ - m2_t &\leftarrow \beta_3 \cdot m2_{t-1} + (1-\beta_3) \cdot g_t \\ - v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ - \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ - \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ - \theta_t &\leftarrow \theta_{t-1} - \eta \cdot \left( - (\hat{m1}_t + \alpha m2_t) / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} - + \varepsilon\right) + \lambda \theta_{t-1} \right)\\ - S_t &\leftarrow (m1_t, m2_t, v_t). + m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1) \cdot g^{(t)} \\ + m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot g^{(t)} \\ + \nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\ + \hat{m_1}^{(t)} &\leftarrow m_1^{(t)} / {(1-\beta_1^{(t)})} \\ + \hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^{(t)})} \\ + \theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( + \frac{(\hat{m_1}^{(t)} + \alpha m_2^{(t)})}{\left(\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} + + \varepsilon\right)} + \lambda \theta^{(t-1)} \right).\\ + S^{(t)} &\leftarrow (m_1^{(t)}, m_2^{(t)}, v^{(t)}). \end{align*} + .. note:: + + AdEMAMix consists in leveraging very old gradients. Therefore, + the method is best suited to settings where the number of iterations is + important. The paper reports on this effect in Appendix C.1.5, showing how + smaller values of ``b3`` (e.g. ``b3 = 0.999``) can be better for low iterations + scenarios. Moreover, retaining gradient information over many thousands of + steps can pose a problem in domains requiring fast adaptation to a sudden + distribution shift, or general cases in which the distribution is + non-stationary. + Examples: >>> import optax >>> import jax From e8ba7635fd7ce7ef3da1f66d8737d0198f785a2b Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Thu, 24 Oct 2024 17:41:53 -0400 Subject: [PATCH 29/32] fixed notebook ordering and line lengths --- examples/contrib/rosenbrock_ademamix.ipynb | 210 ++------------------- optax/contrib/_ademamix.py | 28 +-- 2 files changed, 27 insertions(+), 211 deletions(-) diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb index 4ccac1917..588fcdb52 100644 --- a/examples/contrib/rosenbrock_ademamix.ipynb +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -22,15 +22,7 @@ "execution_count": 1, "id": "55182561-ad63-4fb1-ba21-116ca65c21b1", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n" - ] - } - ], + "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import optax\n", @@ -85,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 4, "id": "a153b4de-331c-4c78-aca6-63864e1551e0", "metadata": {}, "outputs": [], @@ -103,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 5, "id": "92b6987c-8ba1-43bc-8083-4c2b6324cb28", "metadata": {}, "outputs": [ @@ -199,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 6, "id": "01fe6b99-cb4e-4203-8490-75be300448ee", "metadata": {}, "outputs": [], @@ -218,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 7, "id": "e652d62f-4135-478b-8995-a34d7729c30a", "metadata": {}, "outputs": [], @@ -242,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 8, "id": "11a4561a-1d92-44ce-bab0-5af22f028167", "metadata": {}, "outputs": [ @@ -252,206 +244,26 @@ "text": [ "Objective function: 1616.0\n", "Objective function for b3=0.0 at iteration 0 = 1599.227294921875\n", - "Objective function for b3=0.9047933220863342 at iteration 1000 = 11.408196449279785\n", - "Objective function for b3=0.9512062072753906 at iteration 2000 = 11.29694652557373\n", - "Objective function for b3=0.9672003984451294 at iteration 3000 = 11.221567153930664\n", - "Objective function for b3=0.9752980470657349 at iteration 4000 = 11.092667579650879\n", - "Objective function for b3=0.9801891446113586 at iteration 5000 = 10.87605094909668\n", - "Objective function for b3=0.98346346616745 at iteration 6000 = 10.514174461364746\n", - "Objective function for b3=0.9858089685440063 at iteration 7000 = 9.911301612854004\n", - "Objective function for b3=0.9875717759132385 at iteration 8000 = 8.91100025177002\n", - "Objective function for b3=0.9889450073242188 at iteration 9000 = 7.268399715423584\n", "Objective function for b3=0.9900450110435486 at iteration 10000 = 4.631922721862793\n", - "Objective function for b3=0.9909458756446838 at iteration 11000 = 1.0591912269592285\n", - "Objective function for b3=0.9916972517967224 at iteration 12000 = 0.0565548874437809\n", - "Objective function for b3=0.9923334717750549 at iteration 13000 = 0.0014951552730053663\n", - "Objective function for b3=0.9928791522979736 at iteration 14000 = 2.036277919614804e-06\n", - "Objective function for b3=0.9933522939682007 at iteration 15000 = 4.5986325858393684e-11\n", - "Objective function for b3=0.9937664866447449 at iteration 16000 = 1.4210854715202004e-12\n", - "Objective function for b3=0.9941320419311523 at iteration 17000 = 1.4210854715202004e-12\n", - "Objective function for b3=0.9944571256637573 at iteration 18000 = 1.4210854715202004e-12\n", - "Objective function for b3=0.9947481155395508 at iteration 19000 = 0.0\n", "Objective function for b3=0.9950100779533386 at iteration 20000 = 0.0\n", - "Objective function for b3=0.9952471256256104 at iteration 21000 = 0.0\n", - "Objective function for b3=0.9954626560211182 at iteration 22000 = 0.0\n", - "Objective function for b3=0.9956595301628113 at iteration 23000 = 0.0\n", - "Objective function for b3=0.9958399534225464 at iteration 24000 = 0.0\n", - "Objective function for b3=0.9960060715675354 at iteration 25000 = 0.0\n", - "Objective function for b3=0.9961593747138977 at iteration 26000 = 0.0\n", - "Objective function for b3=0.9963013529777527 at iteration 27000 = 0.0\n", - "Objective function for b3=0.9964331984519958 at iteration 28000 = 0.0\n", - "Objective function for b3=0.9965559840202332 at iteration 29000 = 0.0\n", "Objective function for b3=0.9966706037521362 at iteration 30000 = 0.0\n", - "Objective function for b3=0.9967778325080872 at iteration 31000 = 0.0\n", - "Objective function for b3=0.9968783855438232 at iteration 32000 = 0.0\n", - "Objective function for b3=0.9969727993011475 at iteration 33000 = 0.0\n", - "Objective function for b3=0.9970617294311523 at iteration 34000 = 0.0\n", - "Objective function for b3=0.9971455335617065 at iteration 35000 = 0.0\n", - "Objective function for b3=0.997224748134613 at iteration 36000 = 0.0\n", - "Objective function for b3=0.9972996115684509 at iteration 37000 = 0.0\n", - "Objective function for b3=0.9973706007003784 at iteration 38000 = 0.0\n", - "Objective function for b3=0.9974379539489746 at iteration 39000 = 0.0\n", "Objective function for b3=0.9975019097328186 at iteration 40000 = 0.0\n", - "Objective function for b3=0.9975627660751343 at iteration 41000 = 0.0\n", - "Objective function for b3=0.997620701789856 at iteration 42000 = 0.0\n", - "Objective function for b3=0.9976760149002075 at iteration 43000 = 0.0\n", - "Objective function for b3=0.9977287650108337 at iteration 44000 = 0.0\n", - "Objective function for b3=0.9977791905403137 at iteration 45000 = 0.0\n", - "Objective function for b3=0.997827410697937 at iteration 46000 = 0.0\n", - "Objective function for b3=0.9978735446929932 at iteration 47000 = 0.0\n", - "Objective function for b3=0.9979178309440613 at iteration 48000 = 0.0\n", - "Objective function for b3=0.9979602694511414 at iteration 49000 = 0.0\n", "Objective function for b3=0.9980010390281677 at iteration 50000 = 0.0\n", - "Objective function for b3=0.9980401992797852 at iteration 51000 = 0.0\n", - "Objective function for b3=0.9980778098106384 at iteration 52000 = 0.0\n", - "Objective function for b3=0.9981140494346619 at iteration 53000 = 0.0\n", - "Objective function for b3=0.9981489777565002 at iteration 54000 = 0.0\n", - "Objective function for b3=0.9981825947761536 at iteration 55000 = 0.0\n", - "Objective function for b3=0.9982150197029114 at iteration 56000 = 0.0\n", - "Objective function for b3=0.9982463121414185 at iteration 57000 = 0.0\n", - "Objective function for b3=0.9982765316963196 at iteration 58000 = 0.0\n", - "Objective function for b3=0.9983056783676147 at iteration 59000 = 0.0\n", "Objective function for b3=0.9983339309692383 at iteration 60000 = 0.0\n", - "Objective function for b3=0.9983612298965454 at iteration 61000 = 0.0\n", - "Objective function for b3=0.9983876347541809 at iteration 62000 = 0.0\n", - "Objective function for b3=0.9984132051467896 at iteration 63000 = 0.0\n", - "Objective function for b3=0.9984379410743713 at iteration 64000 = 0.0\n", - "Objective function for b3=0.9984619617462158 at iteration 65000 = 0.0\n", - "Objective function for b3=0.998485267162323 at iteration 66000 = 0.0\n", - "Objective function for b3=0.9985078573226929 at iteration 67000 = 0.0\n", - "Objective function for b3=0.9985297918319702 at iteration 68000 = 0.0\n", - "Objective function for b3=0.998551070690155 at iteration 69000 = 0.0\n", "Objective function for b3=0.9985717535018921 at iteration 70000 = 0.0\n", - "Objective function for b3=0.9985918402671814 at iteration 71000 = 0.0\n", - "Objective function for b3=0.9986113905906677 at iteration 72000 = 0.0\n", - "Objective function for b3=0.9986304044723511 at iteration 73000 = 0.0\n", - "Objective function for b3=0.9986488819122314 at iteration 74000 = 0.0\n", - "Objective function for b3=0.9986668825149536 at iteration 75000 = 0.0\n", - "Objective function for b3=0.9986844062805176 at iteration 76000 = 0.0\n", - "Objective function for b3=0.9987015128135681 at iteration 77000 = 0.0\n", - "Objective function for b3=0.9987181425094604 at iteration 78000 = 0.0\n", - "Objective function for b3=0.9987343549728394 at iteration 79000 = 0.0\n", "Objective function for b3=0.9987501502037048 at iteration 80000 = 0.0\n", - "Objective function for b3=0.9987655878067017 at iteration 81000 = 0.0\n", - "Objective function for b3=0.9987806081771851 at iteration 82000 = 0.0\n", - "Objective function for b3=0.9987953305244446 at iteration 83000 = 0.0\n", - "Objective function for b3=0.9988096356391907 at iteration 84000 = 0.0\n", - "Objective function for b3=0.9988236427307129 at iteration 85000 = 0.0\n", - "Objective function for b3=0.9988372921943665 at iteration 86000 = 0.0\n", - "Objective function for b3=0.9988507032394409 at iteration 87000 = 0.0\n", - "Objective function for b3=0.9988637566566467 at iteration 88000 = 0.0\n", - "Objective function for b3=0.9988765120506287 at iteration 89000 = 0.0\n", "Objective function for b3=0.9988889694213867 at iteration 90000 = 0.0\n", - "Objective function for b3=0.9989011883735657 at iteration 91000 = 0.0\n", - "Objective function for b3=0.9989131093025208 at iteration 92000 = 0.0\n", - "Objective function for b3=0.9989247918128967 at iteration 93000 = 0.0\n", - "Objective function for b3=0.9989362359046936 at iteration 94000 = 0.0\n", - "Objective function for b3=0.9989473819732666 at iteration 95000 = 0.0\n", - "Objective function for b3=0.9989583492279053 at iteration 96000 = 0.0\n", - "Objective function for b3=0.9989690780639648 at iteration 97000 = 0.0\n", - "Objective function for b3=0.9989796280860901 at iteration 98000 = 0.0\n", - "Objective function for b3=0.9989899396896362 at iteration 99000 = 0.0\n", "Objective function: 1616.0\n", "Objective function for b3=0.0 at iteration 0 = 1599.227294921875\n", - "Objective function for b3=0.9900476932525635 at iteration 1000 = 11.411341667175293\n", - "Objective function for b3=0.9950113892555237 at iteration 2000 = 11.296905517578125\n", - "Objective function for b3=0.9966714978218079 at iteration 3000 = 11.221532821655273\n", - "Objective function for b3=0.9975025653839111 at iteration 4000 = 11.092639923095703\n", - "Objective function for b3=0.9980015754699707 at iteration 5000 = 10.876032829284668\n", - "Objective function for b3=0.9983343482017517 at iteration 6000 = 10.514177322387695\n", - "Objective function for b3=0.9985721111297607 at iteration 7000 = 9.911361694335938\n", - "Objective function for b3=0.9987505078315735 at iteration 8000 = 8.911260604858398\n", - "Objective function for b3=0.9988892674446106 at iteration 9000 = 7.269254684448242\n", "Objective function for b3=0.9990003108978271 at iteration 10000 = 4.634438514709473\n", - "Objective function for b3=0.9990911483764648 at iteration 11000 = 1.0626317262649536\n", - "Objective function for b3=0.9991668462753296 at iteration 12000 = 0.0512956939637661\n", - "Objective function for b3=0.9992309212684631 at iteration 13000 = 0.00029214631649665534\n", - "Objective function for b3=0.9992858171463013 at iteration 14000 = 0.0002549797063693404\n", - "Objective function for b3=0.999333381652832 at iteration 15000 = 6.654856406385079e-05\n", - "Objective function for b3=0.99937504529953 at iteration 16000 = 1.4632985767093487e-05\n", - "Objective function for b3=0.9994118213653564 at iteration 17000 = 3.8770863284298684e-06\n", - "Objective function for b3=0.9994444847106934 at iteration 18000 = 1.1391578027541982e-06\n", - "Objective function for b3=0.9994736909866333 at iteration 19000 = 3.6230431987860356e-07\n", "Objective function for b3=0.999500036239624 at iteration 20000 = 1.222501424535949e-07\n", - "Objective function for b3=0.9995238184928894 at iteration 21000 = 4.350654592144565e-08\n", - "Objective function for b3=0.9995454549789429 at iteration 22000 = 1.6131096458593674e-08\n", - "Objective function for b3=0.9995652437210083 at iteration 23000 = 6.2786540411252645e-09\n", - "Objective function for b3=0.9995833039283752 at iteration 24000 = 2.7355170573173382e-08\n", - "Objective function for b3=0.9995999932289124 at iteration 25000 = 1.852578179750708e-08\n", - "Objective function for b3=0.9996153712272644 at iteration 26000 = 1.0457262078489293e-07\n", - "Objective function for b3=0.9996296167373657 at iteration 27000 = 4.015987542516086e-08\n", - "Objective function for b3=0.9996428489685059 at iteration 28000 = 1.5677557030358003e-09\n", - "Objective function for b3=0.9996551275253296 at iteration 29000 = 7.655046374566155e-09\n", "Objective function for b3=0.9996666312217712 at iteration 30000 = 5.6290506478262614e-08\n", - "Objective function for b3=0.9996774196624756 at iteration 31000 = 2.7529409862836474e-09\n", - "Objective function for b3=0.9996874928474426 at iteration 32000 = 7.927974365884438e-08\n", - "Objective function for b3=0.9996969699859619 at iteration 33000 = 1.0755715607047023e-07\n", - "Objective function for b3=0.9997058510780334 at iteration 34000 = 9.467848371969012e-08\n", - "Objective function for b3=0.9997142553329468 at iteration 35000 = 7.07339040673105e-08\n", - "Objective function for b3=0.9997221827507019 at iteration 36000 = 2.70986788564187e-08\n", - "Objective function for b3=0.9997296929359436 at iteration 37000 = 3.4848568475354114e-08\n", - "Objective function for b3=0.9997368454933167 at iteration 38000 = 4.460630975700042e-09\n", - "Objective function for b3=0.9997435808181763 at iteration 39000 = 1.674882810220879e-08\n", "Objective function for b3=0.9997499585151672 at iteration 40000 = 5.5214698591044e-08\n", - "Objective function for b3=0.9997560977935791 at iteration 41000 = 2.1872224920116423e-08\n", - "Objective function for b3=0.9997618794441223 at iteration 42000 = 1.8436061566262651e-09\n", - "Objective function for b3=0.9997674226760864 at iteration 43000 = 5.354401366730599e-08\n", - "Objective function for b3=0.9997727274894714 at iteration 44000 = 8.93862761586206e-12\n", - "Objective function for b3=0.9997777342796326 at iteration 45000 = 7.074592645039957e-08\n", - "Objective function for b3=0.9997825622558594 at iteration 46000 = 1.0837551656095457e-07\n", - "Objective function for b3=0.9997872114181519 at iteration 47000 = 6.424252774195338e-08\n", - "Objective function for b3=0.9997916221618652 at iteration 48000 = 2.5165093120449455e-08\n", - "Objective function for b3=0.9997959136962891 at iteration 49000 = 4.583888824072346e-08\n", "Objective function for b3=0.9997999668121338 at iteration 50000 = 4.14928891245836e-09\n", - "Objective function for b3=0.999803900718689 at iteration 51000 = 1.0028271901774133e-07\n", - "Objective function for b3=0.9998076558113098 at iteration 52000 = 4.8693053145143494e-08\n", - "Objective function for b3=0.9998112916946411 at iteration 53000 = 7.591589223920892e-08\n", - "Objective function for b3=0.9998148083686829 at iteration 54000 = 2.0316193172220665e-08\n", - "Objective function for b3=0.9998181462287903 at iteration 55000 = 1.0680969353416003e-07\n", - "Objective function for b3=0.9998214244842529 at iteration 56000 = 2.4416024757556443e-08\n", - "Objective function for b3=0.9998245239257812 at iteration 57000 = 4.2578651715530214e-08\n", - "Objective function for b3=0.9998275637626648 at iteration 58000 = 1.0644058079378738e-08\n", - "Objective function for b3=0.9998304843902588 at iteration 59000 = 6.184544076859311e-08\n", "Objective function for b3=0.9998332858085632 at iteration 60000 = 5.135669667311049e-08\n", - "Objective function for b3=0.9998360276222229 at iteration 61000 = 2.991244230088341e-08\n", - "Objective function for b3=0.9998387098312378 at iteration 62000 = 7.362999099314038e-08\n", - "Objective function for b3=0.9998412728309631 at iteration 63000 = 7.823453529454127e-08\n", - "Objective function for b3=0.9998437166213989 at iteration 64000 = 2.4600041115263593e-08\n", - "Objective function for b3=0.9998461604118347 at iteration 65000 = 9.106253173740697e-08\n", - "Objective function for b3=0.999848484992981 at iteration 66000 = 9.104553555516759e-09\n", - "Objective function for b3=0.9998507499694824 at iteration 67000 = 4.70359395876585e-09\n", - "Objective function for b3=0.9998528957366943 at iteration 68000 = 7.890244546615577e-08\n", - "Objective function for b3=0.9998550415039062 at iteration 69000 = 9.323736094302149e-08\n", "Objective function for b3=0.9998571276664734 at iteration 70000 = 1.139520122706017e-08\n", - "Objective function for b3=0.9998591542243958 at iteration 71000 = 1.6440836247966217e-08\n", - "Objective function for b3=0.9998610615730286 at iteration 72000 = 7.427800596815359e-08\n", - "Objective function for b3=0.9998629689216614 at iteration 73000 = 9.9156537203271e-08\n", - "Objective function for b3=0.9998648166656494 at iteration 74000 = 6.971504262764938e-09\n", - "Objective function for b3=0.9998666644096375 at iteration 75000 = 8.126028205879265e-08\n", - "Objective function for b3=0.9998683929443359 at iteration 76000 = 1.2932197535064915e-07\n", - "Objective function for b3=0.9998701214790344 at iteration 77000 = 1.196456196339568e-07\n", - "Objective function for b3=0.9998717904090881 at iteration 78000 = 6.57613128396406e-08\n", - "Objective function for b3=0.9998733997344971 at iteration 79000 = 9.470069528560998e-08\n", "Objective function for b3=0.999875009059906 at iteration 80000 = 1.3248055097392353e-08\n", - "Objective function for b3=0.9998764991760254 at iteration 81000 = 4.3568547880568076e-08\n", - "Objective function for b3=0.9998780488967896 at iteration 82000 = 1.0066288780308241e-07\n", - "Objective function for b3=0.9998794794082642 at iteration 83000 = 1.3248055097392353e-08\n", - "Objective function for b3=0.9998809099197388 at iteration 84000 = 4.3568547880568076e-08\n", - "Objective function for b3=0.9998823404312134 at iteration 85000 = 1.0448232501403254e-07\n", - "Objective function for b3=0.9998837113380432 at iteration 86000 = 6.760933501936961e-08\n", - "Objective function for b3=0.9998850226402283 at iteration 87000 = 8.262901474154205e-08\n", - "Objective function for b3=0.9998863339424133 at iteration 88000 = 3.462270470322437e-08\n", - "Objective function for b3=0.9998876452445984 at iteration 89000 = 8.820779839879833e-08\n", - "Objective function for b3=0.9998888969421387 at iteration 90000 = 1.028013230097713e-10\n", - "Objective function for b3=0.9998900890350342 at iteration 91000 = 1.3802278431285231e-08\n", - "Objective function for b3=0.9998912811279297 at iteration 92000 = 8.891692004908691e-08\n", - "Objective function for b3=0.9998924732208252 at iteration 93000 = 6.484791015282099e-08\n", - "Objective function for b3=0.9998936057090759 at iteration 94000 = 3.8966874171819654e-08\n", - "Objective function for b3=0.9998947381973267 at iteration 95000 = 1.4512920643028338e-08\n", - "Objective function for b3=0.9998958110809326 at iteration 96000 = 8.160476738794387e-08\n", - "Objective function for b3=0.9998968839645386 at iteration 97000 = 8.40046254779736e-08\n", - "Objective function for b3=0.9998979568481445 at iteration 98000 = 4.609432835422922e-08\n", - "Objective function for b3=0.9998989701271057 at iteration 99000 = 8.32615398849157e-09\n" + "Objective function for b3=0.9998888969421387 at iteration 90000 = 1.028013230097713e-10\n" ] } ], @@ -475,7 +287,7 @@ " updates, opt_state = solver.update(grad, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", " all_params.append(params)\n", - " if i%1000 == 0:\n", + " if i%10000 == 0:\n", " print(f\"Objective function for b3={b3(i)} at iteration {i} = {rosenbrock(params)}\")\n", " all_ademamix_params.append(all_params)\n", "all_ademamix_params_array = jnp.array(all_ademamix_params)" @@ -491,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 9, "id": "69d8642f-dfcc-4fac-8f85-3ee1fbfa135f", "metadata": {}, "outputs": [ @@ -546,7 +358,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 10, "id": "9db496a1-7b7d-44b3-a5f8-a662ea10bb5a", "metadata": {}, "outputs": [ @@ -591,7 +403,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 11, "id": "2cf96de0-cb01-4338-87b4-dd80f0498ebd", "metadata": {}, "outputs": [ diff --git a/optax/contrib/_ademamix.py b/optax/contrib/_ademamix.py index 38e639328..07a0a536a 100644 --- a/optax/contrib/_ademamix.py +++ b/optax/contrib/_ademamix.py @@ -153,20 +153,24 @@ def ademamix( all zeros, with the same shape as the model updates. At step :math:`t`, the ``update`` function of this optimizer takes as arguments the incoming gradients :math:`g^t`, the optimizer state :math:`S^t` and the parameters - :math:`\theta^{(t)}`. It then computes updates :math:`\theta^{(t+1)}` and the new - state :math:`S^{(t+1)}`. Thus, for :math:`t > 0`, we have, + :math:`\theta^{(t)}`. It then computes updates :math:`\theta^{(t+1)}` and the + new state :math:`S^{(t+1)}`. Thus, for :math:`t > 0`, we have, .. math:: \begin{align*} - m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1) \cdot g^{(t)} \\ - m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot g^{(t)} \\ - \nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\ + m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1) + \cdot g^{(t)} \\ + m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot + g^{(t)} \\ + \nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot + {g^{(t)}}^2 \\ \hat{m_1}^{(t)} &\leftarrow m_1^{(t)} / {(1-\beta_1^{(t)})} \\ \hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^{(t)})} \\ \theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( - \frac{(\hat{m_1}^{(t)} + \alpha m_2^{(t)})}{\left(\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} - + \varepsilon\right)} + \lambda \theta^{(t-1)} \right).\\ + \frac{(\hat{m_1}^{(t)} + \alpha m_2^{(t)})}{\left(\sqrt{\hat{\nu}^{(t)} + + \bar{\varepsilon}} + \varepsilon\right)} + \lambda \theta^{(t-1)} + \right).\\ S^{(t)} &\leftarrow (m_1^{(t)}, m_2^{(t)}, v^{(t)}). \end{align*} @@ -175,11 +179,11 @@ def ademamix( AdEMAMix consists in leveraging very old gradients. Therefore, the method is best suited to settings where the number of iterations is important. The paper reports on this effect in Appendix C.1.5, showing how - smaller values of ``b3`` (e.g. ``b3 = 0.999``) can be better for low iterations - scenarios. Moreover, retaining gradient information over many thousands of - steps can pose a problem in domains requiring fast adaptation to a sudden - distribution shift, or general cases in which the distribution is - non-stationary. + smaller values of ``b3`` (e.g. ``b3 = 0.999``) can be better for low + iterations scenarios. Moreover, retaining gradient information over many + thousands of steps can pose a problem in domains requiring fast adaptation + to a sudden distribution shift, or general cases in which the distribution + is non-stationary. Examples: >>> import optax From c84ce4945480e0b3d4c65616937381885dc9a63e Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Wed, 30 Oct 2024 08:37:08 -0400 Subject: [PATCH 30/32] added docs image for ademamix --- .../examples/contrib/ademamix_rosenbrock.png | Bin 0 -> 585241 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 docs/images/examples/contrib/ademamix_rosenbrock.png diff --git a/docs/images/examples/contrib/ademamix_rosenbrock.png b/docs/images/examples/contrib/ademamix_rosenbrock.png new file mode 100644 index 0000000000000000000000000000000000000000..29eeced718db2245363e4245e568c95a3f05c643 GIT binary patch literal 585241 zcmeGCby!?6v^NUl46cO%irY|%7b!Z}U_}NB6!+p*TnBfG7l-1mMT%RYKq)RQ?pm}A zGT5D-bMAT1`+fKQ+xF{$n1WJl>nkXn349L(6 zz(VFQmJ98npx~?7%F1de$;vWld~kkm>tKa~q8Oc~i~UY(gd)#ClMw?T2Ugx!B~Bs* zD+9PLKQg>v#35FW!;v2R$X&axPeSsNvmLka7b*$OVI^m*xffR~W_X&1tI?m)Y+`VF zz>eSF0JpuKr<~ip95?HeE_9Tx6rQx*!Yq_+r3Xlxe_oQPN=mt5RGCw~+(v0I%=(f47go9j zOYbxsJbON0UuiueW1hx0&?zyJNmBalkhmwUSOFqkO@I;;8yU9%O1L=THO2cucIh|aBA>xIK%54wZ%G$9 zU+WpOXZv3|ZP&-YA*qTRL^ID+B>hu2{e4uxQo2F*s^jnWK!ZhDDX#>e7pe8iYw+Um8lqKohm( zN8`7VTgeVM3&yuZBXI$mM<}omTq&*#5|W3kF_NO*4vReX8Iyp5?#rOn$z%yGl~w4& zz(cLd?VvK^y#Gz3wUr|p&{Wdnv4za5Q#)l*#wr^E+(}Q~P*RRSe5g=G z3j!xK0;hY%iJUM6ja*^NG?lVeIngvI*vpvX1APG-W3`Dpgts+gxb~@6YfLwXI0V0; zIUtY~{$PT+fNya`Zp2Yk4F7`3p2ChjLb1DzS)yofFEA~D(P02$V7_J^S^xK!?H{K3 z5J5BwPg48-#riE}@y5+O{MxS@zs<29MuJ6)7RJ)qtHz3#E5<9n9~4#cn7k4*X`^Zr zD0Z{l9IN)Hzz*b-dzmM$A0s{2}ou@;`?WE0dg$=AHRE`RVS zeMy7RPtMX@F7d1fBJwE&QMjK!16g9eAk;Mmoc)xJ>`ZTylT(zT_o(i8c6{yc+9Bn)MfMODJ_c<#UYq&zLz_QXwkqxJa?~&&^=w zg-V`L`-6$Y!X}Ba3KKFIK1E4v#?_P1vk>Dbm?v4X5RAyNM4@FyzLZu@YF1=zB&d)l zWKc+&8I~_bZwnQd^G%dWx&}p57KxUV*oA*o_zl7)klmW9uwfG=eBZFf%sDY20BD}-Cq*sI}|F;;5AC0_T`FxSgbReKcoZZUpZcS0iC;lYj^ezYz9dg^V}DEMrXMWTQA@ zI0}Hhb>l;(PBzl$H-&;zv?qg>=0dLAu9&U@uJq^R4WTy~ztowME|soWoHfa#)koQ% zr3fWoC8{OYCbug(kJwp~H440r%8~h$7(2YPy|PWVUA1k!O+>C5!#bEkDo>)htDyt+ z)X3sU=P*hoddpQQIG0VW{xf}I(iX**kM8zLKLK_K-h2vM#6rYE&f-%-1-+Zc7)@7$Z=f$z z7im|3ukhvS(U(hfxZCBzzVP9ia4*3hd~^KwR1SP1){|m7*3}E|`hF!hV1Iu!kD^|r z8loN-Tz2~P{+F|*1Gxi>;G&KDQ29cJ?H}tLN5l92{a?R_nQdAy32_L#w!NTHNgp)Z zX?*d{S}&k9qLi@|cgC^wxRidZ`X}jF%-B`tif~!hm!AS-KgJAx-o1SDlIi8H-mKoe zxl)6r!FxTYdU5C5q0z+`BX*hEMcS>+d4dz|lkMY(Jc6G?3-3!+n}2#36ml&IefRuw z{7h_ie2)7^?dVCPgx%?To|ZL-HNKC$;}$uVhVZWjgolLBBhey7BO^b)`R?+a={w{* z+GHNFXn~E)-!Q4%;~dO!?ePWnX*MiXLy-pCO%+N% z%A_9@=}Wcv3xEZy+M3$@<>eJiC%BWu3B`%+N%pdHgU7F_KhOT)(P0q562j@Y=uX8~ z#Q#bV`95y!tfy@?FID*^!&#<^X;bL@vVHz`wfcA=>iwtvyRQ!}1>rk)I(L#60@q}S z<4eDxh8;_{D;MJ%vD+W_2{%0(z9077)G(|!N|stm$7G( z;1g^RI#<;MPcIgak8H6|WLCBtrlv!rI!Zw#ar|+kY z+yM}TM$$Q2KM^>{Q*w|lK>yCq+G*7z4KN#z;LZe$)6h3%H zoX7nZ5G}r%Hn-h-Gm&{@eKpc=DYg^kU)jlFg8jJP%Hcr12BM?5A%FXWT&Ow=;R&0pY_^sJ zR!z-M*PQlH!eD*|XDtzSRfiYH-;Obt@=p|(T_3Tph?>O+JU^Y>H>osMePL_ab*S-k zJa!sbsXCk#l{cEmo0ZVL!usP_=~ceglF(M?F7>PK*V9sP6S^=_1g(R_H_5lKUl(Vk z%yG=D1d0PW5#2-8L*u2}r5}q56b!PvvJ#%ViMLb@HkBNHJvrBAFK7Rh2NcosZhT6; ze0zE@Hj`jb*19Soa61P7!e_lS5uV#|W-*yyJ?=K1-?<)8uv54 zXRo!9uD=)QAtsW`M(znNb=%#$4;R=GM3u1S2M?1ngT3=eRH7xiL5cRecV{=ZWVa~^ zMk@K!d7q>Nd~}wMmsh6AX5DsLIeI)K(j}iJ4emSUPqU4Gb&A|auF*RNq};U|j~Wx+ zw{509RvE$C?&lWiSDw~AR*#6Bl~)gYYJ9#2dOZeA-&xbY4`2zXzOlTkxOMxr=ZEft zj;`4qCw0Ragt{}uh8`s4g$h=|3PPQTpXI3hyJ0H*)_PGIg=EQINhnO~t@U)>J>3dg zNhKa|D#{CF8i0aw{ z&8@h-9bNwMK@s;BMJ64s+|3xg9UYw9M7<@L{-cH{GW{=`hl$}os<_)rFuhaNV32kG zV8tNB&CAWpB#Fzwz##s?^1Z009OS>qkzW!_Hty~&qC7laUS8Z@0^H6Ytauq$ZsyJ9UhF{9mr9hOM`igPxqNBT{C_HYEA@`Gmy( zgW&(K=zlf&KSNz?N?H>R5RZMw6=S2HPI)ZBLHuxi0;WP6j$v(gxLlo_Q`}8CZu22L2VQ7~){6ANE zftDMz|E%EunZrMZvN zpB;3`b@Wlk))h}?maa0_pxz^sC&a~}ZcqEN-)UFX&dp_+RDg?u)bdZj{Y}D?_wCQM z`=1}q*848>^5J?rMyoqUqpMwhUT(?xO|~80=jOLtQjc2~iwoT?;%B|+J0~5d+!L_J zzdwEouiY09ZtDvKZhnitFzA6Bcr3L#oA~5k`{ZAc(4Ui7+1VZH7oAKH5sP@8EWa41 zmJAgMTdGzL)48e7jipu=wHW;X0kN=VnlQ3U=)U zyXG3Y_}2a)*1aV*%4vKab8*$$)75zv&hs+=={DaaEjf3T^-tGPK_p%Krf&HYdC5Q| zQF>cj8)x7Nr2Bj$lAil}0)3s{*N3Ymdssb`{>r~i<`lHfZWUb&yuK%j(5OadKCj9k9G^3Q1R8LA>XX#hC z6>kK*LaS`lz3^ARRRanP~1IN`(S;MZZ6HYq+jzg&tps;~MiY@b3 zVq#)3C^TJ19oo&0e8r}Xsv-eJT1>VT(f7;KPkzCK`!X3aam zxHER`g6CEroT<1r?MjVY_1JlEDKl!(Y?)<@_y#!zd;#j7N8D@(qB-| z6Zd)6aH&v>jkbNRe*aB$uYFgJjJ(-v$N+ES#D@QGRBqhVfsUs{YJ9V;F`y`T9H-~GqFJ4DSq-RD1j@t?T%SLb`Sg~{VBy%JY=UGIT$ zIf!(zV5GtBoSz(h&ib(JS(f{y{t8Ha@t@NTn44!_*^WLzZNLF{?4KqE@1( zOcJF@A%)Aq=u)MSYb&rIc9x8weA+Z2pTRSbpeE^?fG`NN#=hn%Y+@tyqcUq_u@eZS{&_E9mC4U)c~J3XiKh%K8- zL9KM0uz(n=hzF0h1}wE|%Nv>eoTG(aozp*^@BHh75y9(k=bqteUb)I{KR*)!++R+*Z5ysr82RzAGiT3dD%Jmy6BbSw3=x2XP3;XtK(TctAZijn?; zu|sGiN_!?~qgq{H1J^Me2hY81{I30k}gurV@EtypmTIQl6f*XLXu8ej;8Ms1 z-3~FbWprQK@lD10au2C!by}N^kHetYxsUt`ftP*s_kBfAe_@XkAm-?$tT9PU3xkb6 zNi__KK9Aa7QC>sCkc6lO$am{(V(+;sW5Ln>pB{)mYmZAXzf0KD>C;MShTkhB`s;S>(aO8aiV_k{U*>A zMLoYsL;KKY=xqbXyU`e4K^56d|KH^WxP$vx4`#ista3Sx+&P9^@z z3}9#Q#9E-US6KIDo-8vJuXr%YyUCswcID)BX;dw29DQit-^J|y-vm7TIMT&!sgV19todVX`R!jljT+0OGJ?IS?8m-b-Zz1KJTtfx zhH6XcugB=&zxhm96WQEu*kwgOBJdY9e0ZfAOTO#kZ|7!vE2dFT_}Zgn5F!2vrD{U| z+0xyZ?ZcqS{a{SANC^Sm#kCpzm6?ZREeDr<_y`AZH}gZ05RC5q?&OD_PwYYkD;X8(7j!sP47Ox5o7-CHAExMy_Vf{BC8jkA!acV?o2wgpO}i%CTv9x( z9P~rH?%tH!iqK&G75yoHr6=MC9}~B-jo@n4H@T~4>$QAnOXdLbj+k3}z#X(gzdFU7 zIzRsG(aXDPDswz=3my<;*hGbP-^BKOt<|lEK1Mt}AUARSk7{fM0P zB<9)PSR^Yj7QvcxbexA#I~$|kaS@$^2KFSc+nVv`)`YtbROQ7b39LW1s<(a^(B;ln z#fIiTt`j|ezDLwO`Nprk7PfV5{xJ2CztVVf=j~F~#J9?X!%w12U%%f)n4mz{c2uw-|FQYp1U`%HBy5e8QefJV+&-Yx!+LMS|y z?_$&KO1WJ}w2Ic8CNAjY4?bh=&dZ*zHwg7Ca;jKDNavS4lsY3y7j;w<9mJc3+qV3>wq*Igl~oUG4;`@9QOxs))2D|X``>-iU7LjuNyrTmK2wVb zl&0HOl49q6NCY}f$yM6@Qbu_Z&A!kp#thS4#JEEwCnf~I9al7u?1yY~8?Z@MFC%R+< zypIgl^k*vXhxMfCY6ud;-Bc*bs$K5~U^hfZ9qf1bw;F*|4&}!Xs>igR3tc;dDQ6S= zpenDY`-vz2cwDS*Zoal5bo^QD{E#`dwAX)jr1fU-nyeeI=e8h0;kbt$XGEVSrK zgGPwjs-@1g)93TcDGwj@J(8;(0`k3XUKPW2<_|fCl>!2VxhWUkYba zTP#*_2p-1yms4))mX*jGgr|F2)k%(W2E3usz+SKJ?DQ+2m}H%0KGMVYVD&lUSonl} zN*+0hc=I86yxv%1@~TFjl(hIAd$c6FaHH&l!wxFtscxAWk##o0$&Pu&Dx)PU>i(5P zZ3)y*a8BWYJ$AVGWrpz*hfkBv^G#BsxMC&4L((qI2dFXceCwp@Wvjcas5x?#I*JCf zGt{y#J{T>~2GMpaykrHoPuNB%>Izf0|ZRuxahFaf1LaDgd+ehqh zm~A-EU;zgX?_agi;13ML9HM~qe`u@vyM59Na!+Nftq8P=ZU)@WA`aitxFFrtxZH`m z&WWm9C#x1JRO#Ou5m5|yy3jEYfMO|GX!Q3auhV&LYmW|Eh`DnB#!!I=l-N6}x*1K@ zAnWJPXWejvp5_;81AbIL$lcHTvO{XcXB%3-fs)orHSE1cF{fCc;)+K^tDdLDiVkMX zV1b4YLxmAO@e_d; zfr&0dOVZlaIQ!}H*@aT{f*1Z8m+B$61)0!v!OPhI5xb(pKouz~QXO(j$FRaS19aJM zXcdas2iEv@%&7sz0|^hm66PK$ZA;b`ZED5kX>N=2gYb5YmS*YT>hWPWsp_|>=2Ot? z1MG!6{4zLJlPktcdE=~EPDi5^l5erkXTKjAP_fjE3yux;c_(9^GFt#u<4EDMnJ=V{ zgOd6H)7X+vw-a?4war$OvH{kQMGTC-JmrFQ@&EGKp!k=+O--P~9{2|@5@r@Cf#Ru|a?d&EcGqOl_$YlPun@tH|<-(f3F&mkzu&T6s zU;sBuL44%bia4wZG@_oz1uz9a>(~ zzBRcF?wwoz#<_YBmC_MIDBj~o2`5*ezJnc$*gy;ewMgc>sgfNKaZ{QH*+ zR!2w*0Os@8IC%5ncNAI&wp&!NK!QLKX8IWRONF`#; zsHaF5S1G#}Vj55+E|NyXy&Oqsym&6nEip{3ENbujId4t0P!mb!ke}ii-b)gl`3YwxN&OJ$hjvQ z2N1z0|IrNH&*RgE45Q_nO0rZVlR79nNcoHoYp0OgPQ^uQbu{z?tK`mfoX`&149`sn zPdPlFFRH$t4+AuWMjB);{9HmlPmqXqL715#wFdz))BXe?g@-4B0v2%x>*oS)gD@=s zuM6pc5sav2mNP>j?L9Lz=52!IriuCao=wi~!rf=x9bDyq@k8FIRNf58O9(d-(wNYg z`j1dvz-Oc$W_WB>2JXMoLqF5;>TDD~*q$OtylsxIqUeh15QT66;JhL|?iD0o^uRb< z^u_ULdJ>A;jR2sy3+~7j5Fa!NJMy~hi$k3_h}G}eFf!*`}ZU z*Q?c35=6P^5xN1CGVdhU_l@Cme=W29;5~UMfyObGyS$*-cmUwq&Ka?TMA7us7<{2I zIW!-I{DimOlp2_;ugV%@x56My4hB#j;)6klG}ifB5-f@fyU= zuWte->`=(;Hws9VfwT&?jkmnaR4HWHgK~|9jC=0iI8`KH=ma>U_&{E?YM>C>KldcT{rU3J<&PM0(E(?HetA?S5;brppU1&UZ#-LqK<4 zOdNg`-&etZ9%XVjdvb~A5?Un+fVm!66L+4`j0Q2H7R0TC6%r|YCV9S^APz{jqx?;9 zP&Xc;-r=3}Yi7o1`RYp+tN-rTTS|f0YFl5jugX73-*XClfJ(xvrrFd2B7F;A&T<7L zb<5KJDPVRnB2nD>=kN%VX9y_?{2~^!1Y*WN;+DY)0tOa=j*0+f?KgVw&Y)onR%=7! zHorkdsvac8?)ODN{2~xWk^8;J)F|L9G|V!0<~ddXIjVM#0o_i`>Xs1+c@f|kO1=Tb z4~3?dKqU`!_g;U0S9u*hL1AG|F63uputE_oX3Dm4n&u=CZ=;rpn-*Lg$@gwh>8@J2JxjHF-6-OR~v^!Fxg^F%CyaE$^K zV#|&r?ZI-q2~3j;yT*|#^zwlD8zz7P9~J}(JVbNfj3X6~X`C$7Tlv0kUgH}SQMog~Z(9T2i5E65E05pud zLzsD_2xwOXbZ8!jMAKNY*ltUV7H$AFoi6809_Enw&z#}RSYuY_KGt#F`rL;O19yKSrJYBPDmp@wUc;f1ecL*Jj&s3dX0KRR` z4fCh2xH~k3blV#_Ix0Qj9lVZF7y^x&{j~0kwp+11wY`h`vmnYCe;A10O>F{U)me-EV2y!_W}Ca3y*+0nu@oE=4m$!r@SpNM9`&d4RVwhwKr)$HoG#$ji|s=A zML_B1Zu~pxhd1I4c6T8Kce6dfn`S9$NNrCGMY#LmB=sPO-imtAt_Rm1u6=7;c38px zh_p};KPYJ9_Ee5Jw+H;-uIKo`K~)&JY}LFEoHb$&6k%W#M8;$i{o0s9N=KHW*st&;-_mQ$ATW}KFET51{1nk4k#p(lPHJDVuf<=|H1s+Xh5vAn zC-8j#WcIf=ciaWevI?Tr6V|y_8CW=eczPMXZRc0_Dlz!>ty4{_dm(TTPUlgLd404c zt1&C|xsFZ!a`5nxw;T9dFyEc)LPi7PGThV(75A`vp1G@WWPDV#lxIA9Sl3cU^wzmE zWn7+ZD5LJwK{N<+^gw=8)_AMoEj;W~pkKfnta3Y8ZGLcF)4OyTF?%1IZ_Omsq3REkt%lE=)p+Ws)@}JHL8WLg*a%p*D=p=u?J|>2Ilm;UTAdzht0H*a? z{&8e(Yj#>C1J6oCCl94qR^Qh+1c=2D7Uh{Mn>_=dg<)%k^?vk{dj`OG=VuZ-(Pj3l zs(G;Gb9ILL8`wGD^t2)cg+#q@IjP)F?OHWF=z>D04C= z$S}e1mr6nSTQsD^QAEo8n6bSM=fAZY`CGqkf|2k7X&O9T!@JG=Wzo=}m;FE|a&F}{ zVqw?zQhb0t&|0o8Sv-{{12->;9)ke$O(s7@NlKwGJ{2py)vLN)-5jbkeO66<{Gd8h)dq+HT8;^eeRRHg^tetHq#h752|q0z~b;xr1vP~usRO3;u?=Do@>l}nO?&! z(?*3FjAGWy&q7s3_1i?7!oa9m_nK-RK!N27sj@RFP9+*h4Ep@$NWk!~f9%A=;q*WC zZ7`R{vEo%ea>8FUTCS{WtX|=h%ZH^^f{`c zU-V>dURHPN6gRPmxt1N2do2*d|3Y-yx!G%}<61CD!ctqH`f@KW(M)4v4#(EIw^?LS zV)Y2de>NbC-~YO3TwACqaJowtO-8aarLX!V8JlaE2tDnk8HRlazN^*NDr3ZS2)>R9 zylb-wSRu;$2d`;|laZmAuV{7-tjL57h88EQ)dVVKTX;}x)C5jSUw%JtPs*9&WnpNt_!TqJJ4sCD+CqSZoF=QP z##WA{F#`S2wa?vZEmnKiVcuFT7c4dZPGEGgVVCljY(|mKrTY&pB9FgXL43{LoYZQk zF9rmxk^JI7y6PIVhOfN6FwtbuLKscm503ct6OLP;fjG2924U2@OREwoSk2@48lS~A zKgEa944!>#Ml(l00Bi703t$%DI)_A?7QbX-06=8&OsQ}inm2PghZGe)@V?yosGAm~ zWG3l96os}!)3zSSEHf|(0QBcx(6;^%3{Mx=SdDwf&xF(83>k{Vz<8cW5mvCrj?S7C z99(`DSVRGQ`Zp!JHDxsFD8DJ-6r#kY^jF+Rj2ijr_w#QW1dJA5Vofh*kpW?BFHk}BAS9$yI_U6>_^Jzsr2 z-M_A>JEsNzG0cEw@k=cNBTVuTc~0=#ibdr|x~J-GST?k=+G^%;QQL2n-#nr}0VwDz z@ltHGA*nASR%k4Pesl93>`*EqIdIkHWl3O7F9Vp{-rAmo7@>YD?|&9C%|<1G-vbYM z0=k{eubt@tsdr@2`o&1JM!oolJC5R@*U>6Q8Gp(o%IN zp%V}%2bL(B@b{JT6NyK}dXg`kLk+iiDol8Fu;E10rEd}?ud6qrC-F7q68+$UQL4v2 z{Jg!5+;7s#EOGRPJ1jtWb*Vwbt=+{XT1H?BfoXhRhIgq+?jurET`L=Q0Kib;qrzgO zhhJAy1$&&^?QF#5IiOYN(_9Z-yy=y) zUnzApbx$WS`+Mod7BuXe#?Ub6!tEDCUp)J`RR(PH9ij9gLCtIuCA!{tovSF2nUY_k zep6_+3@PFpW8jBa;XU6B?}_&qVV)Js{PC5(wr3~Og@PfLHBC&~-UL&eK2yM1lGI?& zDq3bzpF*4bMJMEDvJ!#xA$b}^UtT(J{?l<@$a9T(VuCAj@ll>&Sw(0n@O8>|s2m9k zgp|Yg#4-QCQ2mCNF8+mUStmq!6Fu*RT)r&8gKDLx?6~;6gEMTWuVuve$22h0*$-!jANdL+*-u8m`WY;`rCuk;^Jrhxh#g-i1=q z=0h*Q0!&3h{m%|79s?Zs27ed|mCYQ8j2wWB(vE>DArOuY{6tSFc`Tt++grih9lNvz z2$=Vicb#_`i9#_kQ{-K>Z@rP3BDfbd_&tkV7Xi*UMcasVHBw@7R(POTYJ=55RTcOh=+f2N%d~2=$cLN=0uTMi6>`B>M$me{2Cd5XXr48l_;OM6!@T zM`bIn4fjFWPZtBX1bpb=fo#DnC{WMSvb;K@+cNez_gu=xZL7GF845eN*^Hg zuVcwu^HfIWl6VJjJ>%j$?sO&nrLhSAI2(@Z%~!Sw@u((BINs~`#aRo)w_CTn;k{o} zsUxupy#xo0C7WJfnAM2x8y8hs7W*^g!&i>rnTEYd=m7eQWLHrZ6Qn-%D&-~VAw*p- zUd&wE5=|Gb)23+ApzB(@Cm%bwF5EH0wlRye)D7Ns--XA&;Csw}*}Lpmp; zxI|GNV#mw4YP3z3#0b16$<`s*()t9hNbqo&_Ovkbq@R;bLOd)QRxGKOIVCn6qU0aB zXNfLR3s+UXoa>r=FLHbi1$Py*8R-$_o zgBWf=HxMGev#Z-ix#Mz<@Hm@2a#W|q6jYAhcjo&VmZ0*13d)pgt^WFn<+nJmik_1@ zyle(T3jk0wDvnfD+%Tvi?#ol$7%J&5V_x53kS((*q}E%cx|Qz+G#kb#zoK}}`w9(> zLYx?I21knnfmKHeF4h03>j@XDcTrH6uF&elXI>d$eB^{mM=5AGM&4NB+kR%-^c%#h zYic#YP|$Q91C3COdOz8wR_`{^t&h420QkJUIt;hi>Cc@3F=G05^d|<@&!g2*csO^( zlM47hA+`6AV7v>u*+y;C*J$kAV-=Lp6j4vcOdGC4V#Xo9nH>>SK+uHdJU{w%TK6oQo9ncEFN@&YsK0oL8PsE^wy95wo5v ziVcD9GMQl!%r!_S#)QD~;M(xdkyoSHzTA#clR}PW7m|Aj_JXirM_XsPF?w^3Dmy`MmJ9D~@C>6oa?cf0)oO$JYN+gu<5=wW^H;oiEEFZzj{h)5`iFr}9|`tu#)T*#+DpEDY6q;7`sJLd*=rXzm27qVBIB z7T)_l?~m^OtTsct_?E;+tB!he+I#G$7?V<*J1OwO!8`@kASZEQdThT#A05 zV86hK1T6KZVb6ach{JucsDT}?6S!$I+I#;M5LcCb|A{hXey8+*VyA=ahn#V~l3JOA7ySKPrCgZoztSD`QN`exmS5q<1tD8w6 zsT3j6)5;nD!0URRlB+QfgJm4|E3+;n+MNWx*s&k9Zz4Nq;K4^hY83L1#zN(&t0}*h zGQqafdWA^qp&Hzy*pC;Ma-aIEpL0wc0)m~XBM`351ti3ys=44esuA8%IUFPB)HAy^ zVHO6@C8pG2bSu$UO_8^=fgoMS?nf5XKG~si5!N`B7yMgx1D>SB2*M)A8b=_%N&Fp*g#Je_(qZLJsYR1@BgEd->uxUpN_pJH(^z)}L=lRih zVwFP+f%mGf^{b@j=)4&;#^6Q>%!l_hnU<-pzZFP%^QbxyY{#B8>+PQL-=LBy7-8OU zvEA5MbuoF2>bZQvI!5T_z-E(6on%os1RSOE#3)6BOzclk;ioi~OT?1%nk#!IQcZK6 z4y;o<6SK$c^GC0cMHM2pQg?;2E+j;;LRh0g3S^gpA0M z+hxfB_>)afsr8igF!>JSPs`Xx6xV16>3WClXZ&%k^hiVVpL^PYJxik2C=$7x^2&sR zU1~6uh1fDaID2dj$uJYJJwu~y7Q`ne#;8b2El8f-bnqKO-H<^cwaH^7j#SRQ*6Idb zKfu@Eq^qK>(O4$$bAgY0t+Wwx>`dKzQZ@v+)&}7666N11D)Lb%P1_3sl<`FMv_wHPbFn4AG8bZXUpZFe5 zbBg3XJD759w43}a3~}!nUqw6-{T*_vM1oUJpCvnuQ;En#!2k!Onr2fK*mtE%J66&# z>)g$jzLDUUH~lbUZBGel+&G|$D0x(vg!m-9>Ie5!)EN`e+}ag=++2zZ`n)Z0lJqPa zUNP|0{JLta^`AlGvFb>JX-qd=YO;_}p>7|pgcc$_1pwSyo*B@DnmLp4pld&9VCAug zVtDg^jUykJw6SIW@qqNtHwlMLzPDNf03xcV&LGs$&cI=C5n@RmJe+dXhZi+9s1eGZ zoc>(d>`<7;WtBZ07j5FKG+TDAX|EEM2Ko8+F$&A&yu={a#e{y7$VyEB^>FL zG796m zByJ*MesEWwD+k>G4WEq_Cr@2?mE7eb2l0@v7P(bA zy6MH&L+jZTnvlo{d3&s`BLA=Dx`w2X76mX$wNHT5%t??{yNw|b^mcSgk$5XHu_p^@ ztXZaq{~oK6r{fb!M&r=Whzio86ha6C-IbEuZnQe!~xUy`fUGVkd+v=5CC3kYw&jK+Z`dmbZ80jT{0J zp6H3t`j=gC`7S(b;)7e0(o@~lv3$&Pv;udYNxCdL2u$x8V`X5Tbn4!j&T4*+&*7AM z_%gKwhb%Q4K;#YfU6X?}JRNZ(gg91zjymxtr0gYNqXxgTZsHeRA2Q_+>87*&*d*O9 zL=;7n7No(R%4m@yoO`8o^+E+s(567>Jm1a!<4k#?ZU}u-8316R3Du%|(>m2!%T5pM zO2ODKcB-dB<;w{PRy4mEfZL&kL^}ohYxMxHhLhpbJyT-)7yzz^!S^P8pV80Wdhnp7 zJT^^<9mG;}%_Z0fbkfmGaj1$(gJ1EJb6=95(i0{-h7G1YMk>x;)s%$=ZJ`#tBo;U% zZLCxu`KtNqE?_AfX|S~5jnGZh%B>^M6`~iE1cLa!r5X*PgtoQyXukUjCku?zFIotj znN$RZfTzCfqHl0X)<)w0hM!!&+!)?x-RIy>mBD(!s`nf}a|110oH%asx3@L%&z@SAhOiZxB)lASj0M_NR(8J zEX_Ehqk6%1+|Bu$2nS$^K1j;li5mBk8uOK%66$l`4nw77`xUnuj4+`-fGn=-TAtxo z_Zzi{Pf5q`hfiOjN>>!7x|uYWMnhYw2$qn(tM7;WL~I_J!F+n@kq2*iOvtSCR&-5B zG{t2j;Ba3u;^{IX^RedqL_DkkRo$}@vpY^> zvLnlGziO`MaKpkYjCbp+%2F^eLi+Uu5uCQND9mNT{Du)$ZS&__(qIscM?_8t4Ug!r zx-_B%U~xBL0a1IB1Org>+OoZn#7~QGP9SzBYu8kQJRrO=@@IWOjkc4n8^H-c2`les zLvYkcFCLb3;kYplYQgv@`l*~24>0}+voUma60uPrEaUfH0BQLFthWhru^F4`Di$^k zoz;o9e2QK+nK;^ueEqWOW^beb1{Qw~b%cN!@V(U!tcQ6_e18v5w$uwrpgK(}<>6bM zsSy(+xP61#D^4|?q6Uuf58DD&*Y*Mc+NZ&Ix0db4YFwl3xwreS5)J=u0+UsV`#G@@ z|H4Hr2z~#S-=C@)gQv8ct60wEpm?d%3X4d0efL82{O~XyT>mwJvA>t)-Ml6Wd%ojd zs(pS|F)=9x`7@_It_)n$rqn`KBS=`Q9dXmwKVHc}HZcsr-4s~6HOwfWA7M{qmT)MjFG&{2n+gk>*(eSelN)=iQH4}tP4duaK4Wd(sJGxY};(j+mW zr_SP2$KccBTTz;v$BT25Gd4BWLKi_UJy4yQsA8476`OCC;7^F=wB#W}X{eQhGTu?dy5uat^N5~``2cV<3IjEsA1mbtSmO_H*b zPy+Ac4aW;DRy*%VcyQq&kZgF*h~4KUf-;j}xB5T!=XI%Rg_Ps=yP-G$57ZXgi z`%Ydh))ALuXon1U%J;b1pQRYJ+AyJx7A>UTgY$wORa)&lQJ=FHY#^=K6nKO5ns|hN zXP=;QjKTlhGnvvtjbF}tGS|mu;zLDHc4KIpv(23#9w1mZTCn%q^7*E)%X)5&39Hru z-sUbXcwt&<_^H0@4+bhKO&*ST+NEG=GL>aJHC2moH(MQZ@~}ud3ECUjh<$bBzx#bi z?7cvC{8)SRu@|{tMHdYpX5e_|eyJe9mxVA$eQ5Ql<%$c_&%FkH)hC9`V&lfY|)N@ib zV_u!2K!~htfyg!GGY0S${P~r57tKgNC`gPnypW;C@E5y+VT%md^v(>^5Ar!}B|hw& z)V8*mw@?H6l9+BZm1NlQuOL+t9oDU4=XH#|6RCMAwz{{;*#om82TV2sFO!j&B*#f)IG=yv%JJ<^b> z&h7m%eEIC;TAXV3RB02^@d^>TuPKz`VbW%!u$bw)*&W~i9s{FDLHy^#(j_;YwxLtD zCZ5n6JHA|m)?AiiI-iPlgfpDp|1e^|}mp!S|$RN2+9+LKh4Y zh#a;b>qBms4Qy%`PNX#OZ(m9m`Y=$qW(6bA{(6gA;E>gWRJcH6z>;d1T@h!ZDvo2*0ju zO~H^8AP%7ZE67*5ql?(aF73nx>BfOGY6UeLo}tNDkx6a+w!yjw&Ew->Rdxc%;HFQ`hcw<<_R31Gj4-{-%*7ZS>L=R&?SRPLo_FrK=biiB z-xIrjvA*@KwYE5C%XLk8>Ep{B?(wuaOP_!HU=G#!Am@Jf+l8V7qS|-m^#$sVq;W({ z^5r)9IAE0d*oT6s?}cH^Q8IjQ-Qc`6w~-K)6O!irZ&lP-0mNOJ$vgYQQ(A61Go4u2 zkG)y_fnrG>rW=*`L@15D*?ux}ns@5CRBo%p{1j^%ht=B|R~iW939*fWD524<8#X_AB21q{=I{0ea$+WafVCyp|%#XjemHcmBmX=t0v;WF@k)(zY_ zXk-}6%cH?!Qj0zQ>i-3d#z}bkv)uaxLxHg&moA4`ZR`E z-En}xJ2vf*Eo0RND3i zJi<$M!`N_bM;o&!g1j&ux<^|HccL5B%yTcUvr>%blD}vFV5DF>G|vgb6$l&vQ+-Z{ zo_qHaWxf$`PVJL2InI^!44|X`qGYq0Wi^Ge)a~xK|(+$oyn*CQ6 z8Fs)=@}IOT->wP`=U1Pzk+4Lthku}l8cH*ZU-t9#{Jdi(TgI@c?~0ee;Hy{B3Ca#w zKA~`s;^Ov`SN(`(*QMj_s@tC~m^sW4{tXyY+2mllps=b}w*l>0F|S9%gZg3*f!#Ty zU-cJI8@*{IeZ{%SJJ>Gc>ENR^6RybDbK?#Hz**`AFC3~;MDpQOR=(h2KDIW`GdI=f z%^n5TGB?qOQL84Iuqk{Hv~#=~&5D@wT>x;TD?qT@GOZEz&R`PH3)7DJtJeW*wccJ% zPPyhyIv6DNK`zp_;KWo|F~a`okm-3srk#H6&RU3EuJZx*Re9JOtf#a(Z8#t9vpyH3 z8^IfTrPI%<2+%T5RAfs&W~~sV*II(&&+^NO%e_mcd}tH)@xDA5xe9}k*WMs4p200~ zu>!pUpX+mm`~3pZ`3F9rvc^YLUWZm|@?BzwT^NZSiw0HI_*P4971~4-{LC(1`&@IE zN(0@c-z3b^Z=7=QD!VN`jB4AjYmGFD?6>db6VvIl-Gu zWnqXp7vV|hdZFz}G$4`I-kW5;`L$JE#P{{^73Q*0`{Mx~YsN0&n~jKeB^W3Dy0e;y zGbk+zoFnS7mv?=u4E~yw1|I!B-aF+=bErV{mvEy=$~6|8 zODLTu_5LNn=i+f``*`pTVP11Ld%0A`jE9yT2}QtfCN9oUivXqvdikT(2(@)Az0P0p zBuM(>_}4ETESAk%USO0;F&BP}9%Uk=mzP`WQfUl^NhyV7sdBtAcoL;xPq}=c{az}B z$-#h~^fl9a-7?dhYFCa$L$-iNEUFpuG4(q7Tp9{tW49}(vJIvIxu&K~z0duNq+l!# z_~msyuIvsgYy^L67PAs!nN;b&c7J-MIxH{4{0UVYx^J5IwB+vz+^t9p*2?lub@|hL zmToM`T>O-;0#zD+2)fq#Rh5n?U^Ot%9dcQ(#Z!(|MaFz5NK8tDXh~B>5UL*(G*b2y zO1)tx&181KFFPxw9{fnsoxM1DN8(6N`_r|PpFP4mW}Cg_7K*fCskro`Qb$ey`GpF|na z%P9Q}=B_Rj>&j>?xI2&f)Uh6Vj1EFv3E-1AT?ta&!0!aI8ZU#OO^%G*`U(3E{ zE1kZtu@CLkTW42UCsjL;7smS}WbGiLe;4(IK8>L!EK4|^nm~y?O_U+d5s*3A@K@nV{=~mWF-5wQ>;R1&ii-ZW*=E+vh z081P+jk~8_EuzxNlRoVR1_ZgS-BIOUEv=dlD-y_PLbaxDlk8TG(6!j8%A=E^YnY;y z2&0ESYEUZq)IPhSSOiQBm{+f*JW3air|{)>fct%(=`&5-dr*9hQ;N^< zA@-xpCk2@yD}!umFZa#LDrDt*iNm-kKYH7P8p`i?n`O;U{OgS?V(T+zT8D)sPUBof zrq0QiJzn?ihc$N>SBg?$w8Y_}c%S?#;_$Eg3tr}W6n#MyNx)-VUG!;KMUs7Evflnelc1V$R=Q?fXWY14nD=d+YN(ZK{e*E_L~^3p z20r^BK@&rABTb&pW_Mr3e;K^Vq$iESnXh=8_^o3T-OEtBub*bLK$$Z`nQ3C$fcHHS zM_AF$8xuL?Jja;OYQ6~XeZPlD@qH1;W`g(Z9pglq9nKq-yD7+X&CfmR;zS2nR|LHk z(_bbG`nxj^&+Mb0&y%DK$fB&>Gd*JPJY#qha#u3q~_0+nA}}S&Qm)omtAQWoge= ziD08(Qq2+>Pm%Yifda?OO;&8b`0oneav?)}DLArHxLcGR#q-p%&q(#NBSLV;`y>sDiIhIdDlTEOrC zldw=DQB$!7p25V5`~ZX+@SeZ6Kll1skw<7KiL ze4qSjD=Uo$9vj-IoHNss#sZq2&Mw=x(sILi)hWgmTEgo++hJ4#q5 z8s(Q${R+MDU#w1UnOqbW=Nm5jiP1l5*GucUXbPcW zx5lsnOywz$fwwoE(t?Z_=gOin#BN=?F_A&A64KK5Vg-clr-s!$IsN!YxQXSq?Zn^0 z|7gCn!t{Edr#&T@)r!Q2<%AnpjfrkayJKDn9UsTzxm%|9zw!#d?ELW1tg1r1rL}Y| z+n6uYa&+ zzVcL}S|)stXLnTJqJCS=Gt0HR_kOcmN8BJ1JY7x$iDIh_M;B09seDuSxjgmh1h1z) ziZ^9z93Hu|i6=Twn{uUPU)%BiVdwo-$3;IbH-y)eaY-Xe>*LJvTJe3^6hrXGeE`y? z!v94%_61VC5Xo}wQPoczy> z30-FN<^Wf@Q!)2#cTZag)lNFt>#IRya07gsT7qgVcCq=hyLC;+Im;Xq>D3%7UQY%!~r|kc_${~V^sC-O0FSp*#-qw-!XPuY02K}cl zIkmexFGuB-^?`(oQ_#Pgq>PX_eE)w%YLo;bIGBHb9r)+Xy?qWgpgW@nrBh$^Knyf@ zm+`G&GvzM883_3T_v(CYrlHksh5+w~RS(^yc5O z(pJIF=<7K8?IIrkuN;qm5gPv);4zB_e8o7ING4mg(uKHsNrl~m?7w~d-_CgT6J+)$ zuMHzqEejd^!(w#W#@oB~JEW;$tSpc4+T#D`M4${OGU;; zg6|g9{H@sgQJnno26weJ7i-?;TcR`bwT|rh>b&_95&2HKb#N|`c=o&NwhhvmYsJ_( zkp{8yt8}Z*;YH)bL_IJL@uO#XLlV5&IR$>6@^m{x%;9GbPAI!tW3|CQYJYJr<~w&U z)kxVv*|~~#f|#3UV%b|1DI-_07k~7qS}L@4se))n-ESOM%f686|C_A;71?q~0~%Sr z8#(W%Zd?0Kqh@^GfnhMba zI>=9dEQyMui_Xyx_KL4uKoW)ckbkh7|_-CoM=}g2~v}h16emYaz zd9Hyrm3J3)-jB!*GT%$@Ix41MNeM;jk?XRLBzsI2>4FA%>EEe1jYggDQB zE&E?=wL0dcgtosALkOtaGBf_rW4q8iO(9=OdwFUp@MVA7@w$6$z9ZF#f9qDI7CU~Z z;aP5qkE_Orn9g9JdakZ;;{8B9lYl7YH%HNhIb`^uER4Mu?frK#g_qMu64$j(b6`ro z+EScM86AG=FTEQcSM_whrGd`}R5t5y@H}?*sWP*=D)usLjyzZPQ5wjAy+3Bp%&XED z(uMHNg;HjpDw~lTT(n(y@x@o2%xcW}D5U)?UKI0rhW6rqb#|cYGlJD;hI0}_F9Ty{1vSvQWUEOqC3 zN+@!=?bu9_BB*)QxiFZ$6@hCpG=5>Sg<7XhWgfb0jl#cdrEgrORoGI*%J#nVo!H7E z$PbYGBG%%$iT}rP&ZkfBs`o-*Q>Gpxi1IyWyFFi$F3jFKI8f%DOy?X1Xqy zoXoc8?uLdZoJ;mZR}RJ!5MZ9aBN%ssEn9i=aKrY zLeg;IKil{5&mh#7S1W4E_fGmi^q<&&UwTDBe=qy=RlSdjTi3q_9N~4G`qojiIk%7j z2Wggdj=N{wAw?#6oIoIFo&_K&tki&Lk1V&tpCor*vAbzw5hM>^vIrp;s?3)3brCqo zC*dj7OD}#N$h|6UaQr*z+XqJOdRg$i^Fyvaw=%#%>b8+w*)q?($o{1VhB?46G_Z16 z&*`unOXGkoJIZB8VaAvqx_ZRlrO3EB9A4tDE5QjGG6*DIBH`T&S&5%?iyZS zw$<9@fb_;(=%gla^s;d~dwr1DeI_CfEOaA)djur)vc|Wtg-`l~hf`YY32EgCptwaV z6!yAvHT?9zZZY^L(+Nhi>)V%&y-fXSBeQw!+ZZf#%l^7AkDG{u&=32tG77&KBz4)^ z!FR5>&YP=)e%J$jp8~Q7BBA?_j`x4N7oqdmQ z&iQkE^4GTJAYSpxNPTXFa>$p7C*U>%#Y1`E+>VebIUz zMr%UremmDpeI1jyuswK5=`omgV+-n51S3(7EFyb*{cAax0tE7vYMmSd;TTD_L#e=V zi`s#>Xp$APbtav;ILU9&vvA^{uYCx7Yf_sW^|7<3)dn2GV1LN92pZPz--I5A0#cv6 zCY&TmHx_hrQBb#T*(jmZ9LpL6AKAi6g+F_YQMYwV|4-5LNtkNQk!AR1k0M>3WyCDh zaIlu6?@m&fqW|yeBnnLDS09r+68!j+qAQ4ip*gTeOz=Y-^`O8^^D~DY5BtinkzOGh z^u1*s;-Lj)12#sEc6*BBKNXQ|qkm$QIYxiy5LYVzdjJ0dfbB2E+cSP8`>bnjQI?~6 z6x6yg*8Pv`V-N{X(YT(Yck+7OrEzHKwwE2D)EX0xqc5&?nScJNN_-x#AI9# zE-57BlLg{aETZ3y9(w*@LIsJDY=qtyxCxHo?1?(V8?(QiJ#5lfI5V#u1$<_Tc+y@dBaz-n1${*%N?4|9k5E z-c=7P{n+>Is{u`Pzeg&1fQV8ny z#lT*NOzh8IH(Ll9?@>X-D*@@FYAkQ02@ksQz9C3GPj9$o)OQFBE_02{acc2#t|o+R zKT5+VbzouRM@PSFNqHb9zSftA3-{O9Hb3o_FvV5|mwn!8SUf|vyUUib5lW$am}mAW zY6v>S`*#DrLEf(N+Xi6xuDc5{YA3aS1c6`^+O~87O6{$b_<>;?nHdgB?X$XHl;mh( zNn`Jm5eY*!9**QN9`bnE0};a}9I}~ficRx)ZEie&Y|tq9;t~GxFVJWLg2Y22C49y- z$JOf~Tf@)Y(|W%nXJ10C4kCVO|8qPFq3wClP({>d{!ijx(<9#*Z&ZA&hxctC#~YmUwQbbgIThb%Jj%6F zXEfBC?_T4poO5`#li78HbHXmi`x~q6*);AuXDnxCMl4WwOLpZBtJjIK;+lFk?KuDS zmaAyedj^F0?UAu_rmAzc0vL%MIa$;8(mW#syL)cex0&*VlEo^sAhj~Mky<25zhB5t zf7g`&Wf+>IE&?s;P)r}fh_%kRc)i}yx(68hEpH7MCVZ>`Sre3WUWPh{xYoBoMih|k zHIi-k7S279`MFtDSa)YPs?Mrkc`=TGCpC_-=bvZv1PAFAtSp=VV_MR=bJrbEck@|1 zwoUzf7U#6FR|ixyDBXvDtlR>O%z##&i#5}u(^&t${8m5lGh7531OmNSTx^4( z`J!aty4w8LBxq-bpi#Pu7BsXL?iifOexxdB z{)WQH_5f*;@0P78sMAjcW~fVxpi$l5{_Tj~soU02$yS~tc@Y5YJfIQlzduO6xj?}b z6Fe!jv{WXo<=+wgm!T#*hMZ4BPwG`(ocJ24xOV~j(&JaD2ojS?AEEzUSK8CkRJ=fQ z*zpzF^0LjX3+tNvnP`4yYvR=t3!JxWf){b`B|p%uXEQD6n*0zNHE}l*;0YO-DkP34 zgzP9EapWI|Dx1q;eWb+x2m|?Li<`QJ-#0 z_v*1qo<0$VUQ>$_gkMOl4>~2vw8;1Ud995K@Prfsy! z7ib-+BhI*1B4KsNh+-OOQdFhghi9B1H6dGcWV^B10&YO$MhzG|1HqNUHkG8M0uv-# zfxcV>SoV=%E$HfrHrtl0qG?3Q7_M7XkRORkl|BxIi;l++WS-ZWN&eBb4RIeg1#38GLl; zO(6*WJ2_!Gr9Q@s#ls{Jh$Fbi$;1x1jmAX>-3@J~dY5mYvQ%^|V<8TQQgte6q>fe| z^at}AZl#Fsc6@I6hX0n)`gV0tPVt!H!0kF(oij;v_7W~g-4X>njE)){d^w?toyH|q%i*3C zjHW>!OxnLe(rhFcf{XhAS%wpOzMF9M27@6P*Z2tI8|3qj=o8i{WftjG#jWXE=3EMs zmf~bettQEh2JgScGd#~R(2OeO{VbIGFRnGiM6j!>WZsFTMYN}eL?z<>B7fUa0alIk zEwU2-t-Zg@CgV1@K<%=I)5Ughs^I%(`c%FF-s3TbnOrIfw!1pf%6mzGt)9!nFlg+^k=+urHH}deGzCr4i`Kw*F*ACQ$AJ1GPv4Fxya8-p`O73_1 zOOm)urSYGQWbYq23&!tXh6-1aqIEo*CJJk~-+0~iblJTs3m;A^{R8CAg zV!E)7nCFj6B_YJQ#f#KgN3@~-CEA%BR+?hDmql&vu!HFB^2vwg*MgKww#SgsaH?9| ze^PAX?=#G4z*MBpV27g^%YUz(@1{3bww@LSU&rNB&Z|GYrP!JH^>H~y(}Omq1Pum- znx9N}$Uw|sgI`7{8^aK;!_QBqorC9IM8@(kRzX!;gB;d06n8x_Hwe*boF0ZRis`!T zw>ET(t>C6@ia{D3l86BANP@p4IKG&&!mq?P(mpGWy*kgN-1${>>Yy0)%4lLF`3B=3 z!24X9`iwhbO&mPOw=LwbJ$^fotINGIY8(TYeg~kFR%iOvB)Kls5JHcM2d5SuY6wQ_ z4JSH{Xb8^sLl3I5Y8ADrx$MqjCzg7EyJbInf-;Jtbw@h26v zd5em|9~%GmvKPE_=FW>%8=I$OE`sc_q}?&GOdpEsbBy2!vyEW*&XLjc3fN4Kf-fS? zxy}nrMQiq0EuC9(3rMEXn;`hJ*8-pEdL)`*QO^MVZB6@2M(8fBb9?4Xxf3f>(0z}{ zTYP6*t8U4KrhJyTIX=o@eVLPIyFoU9C+f0w8)y6yKC|oNY*s(Nup77$C~}}R?pWxZ zBytxn4Bel{CtaEQqOe?TQF$ozy{8vn_S>l*m<_{+4L7D-0p7gl zc@MhVJt7|ZuxAeNsp8Z_0bL8S7jXTOl589g`cXl~aeZl0ar ziP+J(jx0r6E9>>-=_LSQJWvQMO#}eWM5(L-33DdrqzafP_yu|{4#nLzY79is@oXm0 z3JQ(F_ZxMls(;O8c8`t5<*<6KfP+kgJXP;7HW2NJd7T$6Xwc$Xux$5w(_bCrRA~L1 z4sGq$u@$=|s=+tg^WK;iU1DI7Jm`|8HLRVSZg8B&3-5CP)`@uos1wxA-_;f|tZYR#Y3|D0~%n zb6^7M-n+Y*Dy%&gBOjglfiL4U)E;fmMJJ^fz|DfI2fhPkO}tQHUKx5BYVV#lML*h@Q&XtyPpGlnVR4XEx>cTXEcYO06hX;zlJv&V5Gs{ z@5E)xE&$TQhr{U%`}5avxV-H1hJ{U9hedYU>8T)$J2cKTt`QbwL4i!8c^8!BL7q8} zX{jK9%IO7_4w5zPBMI>jLhg5-Q@AvpeP`%1&9kkc0%g=4+Vs%+d7~!M-f$FyAq8K@ zL;ys5CSWY;$hHsVp@)vn?CV+ax+*jE8K*r>KP4>;y!70tIs!Z3d4$))oDO91pknoc z�?T9v=2i_v|G-^l1=+zZwUP>+4Nd&y0-{oQ~(pDW&bfkP&J)F8QyA&@FB#0PJ6-u5e!XJ&PJvzllrq z{It|Ari|Q6&c68>iEVl;WM84%zk@z1TkZxm5YB!(exVeZRZo|~!l#uwN;qeceS#8{ z={?}V{Bqdqv;BL`aXS+wSik9Q<(=s5!6lfw%1WX+Gc>~m)TA}G-$sa+#^O03gn9a7 zXvMc$BF}f-FWd4C0N}zjZY3s*I6z|63J`-67@X%Fo(W3&VuZp7#OopOGNVQU!t0#3q7Jcb zift9p+}AJ&{)^PhHzQ{{5R(oQvR>;&8g`9=4K$eQjIW&YJE=6WM(u5U^YKh;C!R>U z8ewQz3ZeqCwH6w>Wmwz5v;>$jqjUy6+qba&RYmPyr}%vfD|rJS9zJ0`xlT-1PeTo) zzl&%&xoav7{*U%S#Q{Oalb394F}TDpV?v2}C~cSzmN%1tPus)8+5?~U4JGTNXu0J4_wsS$IvT3~s8 zbRolk27=r4Sk$al@hH8}3Jw&GvIjSr0)I}1NdA;!qE(XkH+$jVPWsET zqB)3$KE?U@+zuF{eiqE<)!q+o-=d(doNaX9e{%im;9{NusngOXRebWxcWE41;PN%0 zwFJl3GO8p+IfPkP*!lGeboz3B7UqwvPGqo#KpIjX6*)af@M@F?yTTJ}*VGlJ-Rw_V z*`)S(6p>&1|0uR_Qqu0YnSXL`plH2VtNJoZ`(`gO#~voJ>nfEh&D`gNR*K;O?c9TJ zqm69OdcrTz+2xrR9Tll>(O^F`vHv4);=M*MWv zd}3}L{>&Vo_8D&zw-I9R5GMEkhlOe)^OX@z5p*dUYJritrubjh}_lLeV^3*UkK6R1pPoq+E#Lq%IgYws6b*oJUrhIs2 zOtJRMGVkw6+S$W~N4)F=DgfBg-%e!(N!@45aQbQ4c(WL_#CaW4Hu!Amh?Mtv>(~o7 z7pC5|Eep@qSzSv2*-?F391k`bzb#xBNAx~eS-G1oMA{`k&MU6lDYqA@d9ttnigkwV zT|<#ynQn3#PL(=@dS}kt{8M3?WK){fW&C;WkC+iAxJ6RWtQEh2zMMY!HH;oqblFp> zqdj^0Xu^#w8;{kV>&}Pz&y)f)^1j4dD;bEqk3aFn6(bvZMu|0MKZCgq=rVnbZ8Z^Pg$?HfBG?=MN=DLAn5n14edV4}j9ypWsthxG05~cQZ*xmfSHAaX(daEgv(ml+u<~%gBMg zL+xFk(r~j3KB@E-kTMqar*@ir!D}xwUp#(*wR6@SCQ(V2a+iM7wH^jTS-yplx(Zac zFJyir5T;dZDhQBn5fu_u=8k)@s~#fusmqd!Og?PANQFN?Xzu%R+~#LMm8*gG-hrd4 z(*f_PHI*%2jnzot=hkU&&pv49XCvcu=W?vRqh?C{8$9?7#ZcY-poq*^YOTWD#GQ4S zrlau|Swg`27UAu+T{(Qk@A0(1{c4h)eFb!(Nj^-Ql69O z%ASd$DfcD~V(y>H&w5<^+*;N9QZvzNApXeB@WYdiFE2sqKhN0f7S*0o32Sfln;W{I z8}kHwo_hfJ25-zNfD0;fp8z9o1T|htWr;NMDXbqRmaB7eZaRQ)(HyMkq1>YBZiwjs z{v_Ko?KL!`*QRA(-ig(cRrMfix6#J!x$gavjR0!FR4Qy+^ITkZkvFy024J0it!Oc$ zaFAFY$p_)C6?HJ`^&(Dj_~fWP4K)K8?DWw8rR=k4%g481Du9iwt%141PlGCsM81C{ zIzuEZ{fHt;RelVCEZ^GCMW4WNTX95)&tX=o7DidYYBbsMaEfro0Kguo&hDA@@yrHA zCUY%`k90Rd2Iumz6v!0?;F7AM!T}lxL4w=n&Lx(Lu#~&>0-#UD0s&l9_`ILiWxnSd ztt<=Rs+(EBK`!b0iFh5r?2kLJhYIW3?!UL;HIDWrq=r0Jw|RfJOTH1knABAm7O#Pk zAg&m`2K#mZMh0r4`|T<0WJ)L6`r21acg@~q5)l5x%_1aB^!7!-g~Q$i%(jFXvvfbU zo2)JP+F4=AL})&-1RZ~%V|#(^YQB+EWXaX7ZW87s+IIDB{{GJh(p{MrBv#h+hxZsB=vu(3H19a}Bx zerYLS-(GzXF3AYx?X{S>2X4oK7F{%eol6c*Cq2JjL|{0XtxR0U6?Tw8wjcANOXT0m zfstR17up!m#O2sc760m{esw5G5sMypl+P`_ak&h^-%xgViUU3nndC6~{Yw4tBknU4 zMKor}ixRYNUa8NtGsRzRjy4L^)-aPq5a0=JZ{vs?&H+a5S9@c~-~o|ezJ2}&Idvli z-ymhu1)Gla)L0;pgb}Zo$Y`+0`zT9MT*zzmEO#k7d0Yq#%BJ(t26=xLvEac^-7Xll zm2sWc$O1t9X+#A|rn=%Zh$oTOwUT>l2t~1T?k_On{G_`JSWHgNxu^}6fwtuj;VTod z%HA2D|EOfzZKc|cqx-SRa|zJ6R-PX(X={kQ)$)Y{4E^EgOE*F^2dFyRp_L&nH3@ll z@eXT%)N4;S$&vsrG_Gl*7qXEAklwjIP{5mvjR!Z$SNJ%6LuU!v`d=T+YKwl9LcjnK z&ysKvAP+AoFb}sZkg0mms&)5@1sqrS_%)#;d_p7a#W(yfgeN3~yxAU1)H3f++!u76 z!VD7@@yo|4wbBN*0KE9T%1BXglXvGaGex9hf%<2Rqe(@kRK=FmG*aaTl6XHfCi+JX z3lh?8?sA|-fC|*de{sG^;2XnF?r6L_?EX7(l?Mba#XplM?o{cJs#$oV*L16= zkbBwiBDa%JO2B_l=1DV_;;v07G}Dd*Tb~N#$x8$v3vl0=_hO>2mgF@X(0E?y(PfDb z24Ar5bgF722`ejRW;m*zO3_23ATZi}_X2qm0~IOK)x-xyNDV|O0YpK8E1bBmQG=c$ z%bi5-bF)00-{1_9kp?4E9Ix+|^mFIc7~dZ_HNN|C#2j2ZaG_f?b3fCm(`Q0y<=MAc zd9fxVYrO+3G7ASE-mUlN0;`5Y#z!lWCoM1*f^U6fz`8=d4x0d8EfF^3M&cnnEL~?a z1A%ztDGuQ=6F@j%(1NysZO1&ruq>*XLXw6_93;X%wxdxz|409Ul#IXrZ?A@zU>}j>ha+^gNF6NEm(nQ-sf; ztDoq}Jv;jP?^cn$+r0CTJle!Z_SGrF#xadU6EN6F%tDp?d1ZGZX1IhzijdbOA-Z31 zmt%l(a*(fceHa3quhND=iQdU5&}VDrJ}tu!v?tKMk+aOTx7RjAt}XQk%gAi1;RQrJ z9N|P%j!{qxF~WmPF*5y3;Vt00@c44sdQ=D^9HPyGetjzU$$>&rak;G}0WgK@>3!Wln9>5)epp)@j?NceCWtp##?Ldj93!Y7C*ViL)tR0R z1e00^S%(&;`GL)?5ZW;?hGmpFh@|A9a5w1`x5hyfZ0!FJ8zn}#M z6Fowwu@WzFg@@(qcHgaiQ^7&KtUi$x(vLdugYB)InrAtDVs6 z*1F{=pnLO*?muStZ0SGcV_lr0@=OiQ>2C6hE)LR2xD|MNui1U@CoCTCn=+Et1{IuO zHo3V@AcR|(&ym1{=5SuV?lPId0!jVV(c6LfXr5q)DmL0k zmzFF;{tlVAd=Rjuf&kb<=IB|T>X-P@2(Dd_)ai1NAtiW%^SI?^8F<7PX-(W;fmyp^ zJogxlD_PlLWHEL;+r(*Wyvw5)J7AIqN;Pun{{!{z5|z-NFCnvehkLdyMs(rhqqv$4 zXk7l%qLf^=GXnYQA)RZ^FR~W}(K?u9ggLuOyTgILQ@FLgZTIrM4Q)XXdI=>TOoAIc zLcX_3mgyH}xr`qk0GES8b-#qU!lgocSwhwS#|+zB~fWzTn90aQKZj-zxBclpPo$<0@$%V}rqek3sU-WUjskIm=#ne{bg+E2pMiAyN60%_R46Ffx>us%%3OTe0|gua)RnxTR)fDbDz`Z^gAF zJrshgS{|g%0m19_c{k%@FI+zz;ege(aa&bXSGw3N&%$qET6e!$$*A|EW-MZ%RnP8s zb<4rgY^%+s_`>O;L(|jb*ayc#LI|r*xh#D0U=XR(aIUC#*Fntb$9-#|c)>O$PQSD9 zkyOFuajzmg4abnjf!Q`x$*=esq{E3EKybAL!%JkEKoMs*02FZv4Zytg|J{H+l(=vc z8@%)5+E$kV0Bp9mTu9-|+?-c|e0%h^0zx-(WP`-o0agBVpd|y=$xU1oDSpvkY~aA>&@OY z7#ymy>?t}VTUu}R6bG5>ajqFwg2r_oo)@t7W%GBY;d2>1?IDcyYkbt$dB@77%~rd5 z8kwV+F5mvo|7<-T68_7QRCt`lXtDs|+%lfaH(R-`b>w2b!0!qINlhlp%XQOA6r>TD z*us)jTJ=@)4adG9{{nj|Ky7Eegg4YzJu=e>mX|P9EKq98ke6yHNS_Yde=R_766;)W zv#Ea*ITcAUw!jb)w1$%3EXJ?F<_Ix`7~P4*7mZkOggG=guQ+$s%JTBp+VzdB)jHZL z**q<->l!RgCWM3soLaufC^2Z4Yh5so4iQe;>7?J8M$ATI{))l}lG&y7aB1F)#=mYo zH@lh|BhJ(fG@kGW=H=t@a3A2Ymf4Z&v!kg-95e4+gy-MkAF?~hHI4usHf%bn7ZKW5 zen3#=-x*kP(s(m8WuYt%0q1-o0wvo+1XQ4?_aMC)wg=B3a@c-_gam7*Wq`_NPKd?h zI-LXC?)BEXtJLZ=7+DC;K-GU+a@dx5!9KX2j)4!xRhnidwF`PLvHdY3B3RXj_0Kc| z)cpF8{f;~T_MggPaD1m#xuqP4hx><|mjd8`u3Y`&WwkXndZ+`Y7;2l(0_}v2M5OX> zmkg*|WAZ!mGC|ln5X?{+HC5?_mK_LG%fi$U2tT56HOEV%>>6ieDv+_I{pxwF^+mY> zU6#CPMM`i^*J=Hd$x1DSzT@IZuty)OA-}GO?aW~?x#xVt#^>b-tr4I-QkhV0q>bPn za?>LOs0)ecdk?z@S?(7okUPiwv%PJqpf# z;GOwKU_&O8778~=4LW<66T*MgP;_qB_CsiaFL{|62O@#{3wHTZ`V&|$r;v#Jx=kA{ zt_qPRk|IJ2`6ZsDmP^|LLjM>A`n6#0kvNC)9&NZ8UQRu<{TzL0d)kC;tJL7KBE++k zvePX6^9?VhpMcD_VfL?Qm;|5Tt2x}v(NEez-#EZ??+eD>Q}|=euT>)3^7RVSA;49Y3%L^0~wYeOW@uMa?J^7YzxJs%zWE#%E-wl zn`Ak|=P@;&B#)3qdDZMu&juSb`+vNqS5P=wWxkn@qWt{HCQP9cWQ@~II18tPU?A}} zrMn|5m+2lM3R1YbbcojIlMwiU{CE~rB>}Y7*c2p1Dl@-3lRwl8w!$RvkPK+ZuwGV6 ziPZ<#!HhUo8|x zf&NmDA>=mAUbW)Bf+D2}U02$1WoP5Sid>J0HmYgyz{GyhJ@j~9R%iGGe-qsT1K;KK zWJ^dp_G<|!S;fJhK&RQVmG23-5W7oS3AWZ1H+Id-+oI=)QqPB$!=H5|D%Q!7Of}5} zjxqx95vJ$gdF-Szx0op`);@2&z=Xtxs1*+oSrvWQSLEE!K||okL@N^R^K@M~z0Z7^ zHL1qz#;pCS|5N-X<|7H1z!{D>SP3rKqR8_I60mL^<%j|CQ%PZVz|1nt(fC4{)O7&I zMqrSe8CY~av4q}0{FIWEz6zQ>@EoEAl*^(8oM>2P2j|T|clyYVu<+>%JRXUUh&sYS zW>@@z&n;p`@B=c?Pdlu{_#cgPoSq?g1|{Cx6S3Z6K0>#cgd-gjjfvzcv*4eTd+sqF zW#S$TYR)4#(TPH&8%=VLk zbWI%eezeqpRwt5-AC0qw))Of}pm?v$(hEOz-A;#>75;ZMdV=*D#{kdx-cSS@UCda`% zHp{X+Hj@RFyq|~PO_zHba_Vkm*_QPF4tpkInQ=TM8d(!EGY)9|OzWwNj=2e-JFdY^o5C(;jr_~lry>mTVgqU!vv3hq3>)uc15hA1YAdSYsb}1Is=6HWO5u-qo_bJYw~Iu<_?P9f`r?_2ik5H8 z0hhNgIwX^~I-IkdV37$PJ%~hB+MCLqom#exIUfWO>WktE?k;x3@|PXPc5;FrYu5X& zt7Ym($yDgcygljFp`P?9j$-Rn&9nx*>dKG60CWY*B6(0J#}q3Q?WcSV;|@vxB{kQ~ z4%;Mn%^T-j8tlm;TMZvSW>mkdtrvYDm)XKN$WvX5K>Hi2Qq2mMMM_t;u}*K-GR=+% z8vF|E4$>r4C?Elp^7iwDs{uBx7G`?~`*&(ag;6fGZ-0&9WZiGy^3yeWY(};E?yc=x z{Qp9`J@X6fy&>=gm~7@sST@Ig!iH6VPEr&h5o&K<(CYYyeagc;3s2h z<&I|+K!+lt4c6MVa+x;9lcXYDZL8mCw2MiQFL;!Hc$6IzNm>J8?1hjxw(!5fTaW$h z06^V|M3`O`;Vr5Q?GfbfWqGKkS*H-0v;rY7L`?V0ZJRwFXZ&yi!VWLAXP$o`-5d|D z^&r5EWYUn*+0O(WW!9~(5C5E4{+FM?r%0><1@n8e_da-8cRKgx(#G%ncgdX63Km9E z=}_L#t-<`$s-uz?hfY@YFgujw2EL@$T`}?b_KNEGr1fU*`fNC8nkZcn;N{fFu;j{e3HdS7 zdw>4d`e3FFsoZ6yvFkUg^51C|@l`)Gr3bwh0=$EI?IllMuBF>BhNX)+WxR1gqiZKjUoG!|a)t?YgORABJ|p^d-K+BO-2O zj6$yEK}qb3N{LUgyOE?eNbnk2y(oa}c4U@Pa{n%WC3oZ7TdWx{D0^lZ&u3WoEP2De ze}nUz0h&(2rP8T!JHGC4$Zs-?K>By}vovAfxDI?{p++Q?lZ`rUII2b6Z-{!wZH;Dn z^?=*OOoHx%^d8CtYBtJ@?R`pgZY}>-g~fls#Aj3T!ff+GuR+(k-s)Z5>K6^i)@HBI zeJU*WR97OB&Pfs94FDwP`MsvSko#vV$;V4E_okr`e$F`a zm>4j~W{>B*O{LE_3bUyntNe;GaSnKE0aVGpYX~;s{r5_re}$n1I(t(kl^@Eo?Lihr z(Qoh2D;IJu*V6r*M|`I>t31yj_kC23g3_i%P^23v^9F>+otWIyv{-W&`agE*+mR)DV#R zfgLeSl#F0}7O_Uxxf=DkNpgZ3vtcJp3`IS1Fn$ro_WO;SQiC=-fv;A1>cz^tqrv0R zkID{FUv+=_O8cKYpV_&r_(WOq4EnFLbCmd3sdk$P^^S9WmIy9>*0)%>uMG%p`t3B7 z{4g*pNU(03E*6DEeJcO@{l(YcoSFCi>bC1UY-kG_yJ}=A5+SVz#64|6Oz!`>e+DUGwR*Nao=v zSKZw4vQx>N*N6nii1pShd?orIGB??DJ2j+)t%d44N<0*atdCdT)jS6~AvmugijF50 zMN9l_$Ga66t2|}el^S|3`s=c5-go^;{aE~4MIrsR=_tJshO*CY;op75o_DFv1YQ`6 zCdtavqhWI&&o?WBFBDLdsx9h6Un&b_IFO7K%wv2&etzKzX;N;=_#uoKIlxCw{?kaK z>&Ej}!13jyd(T#1y~#+VnNDQADl7l0XCOGC$MODiKgrCbehG`sa!>3i@(qq-jrjlW z0;u{>H1U9fNF-{f^(N47tEz{{rOla>Q`43?FAe|{lcRSB{vFwte6d@;djl~67McC` zla>Bv|fMx5r%8*Ik{EFQXb6N zA4Wj2HS0d?xrgE_o_`1?Q9iYk#Q?CXi+AjjU~%GN;I858%+#;iS)P6VhB=Wpk!DyO zo=357#2)XES@oBUC$ipM7jSq%$HA3iHAH8%{IO`F5254-K!Z>k2gJcFQJi*4pWlfL zzYuN5eZ;`WZ#R=7^Fk4*6=vDID_p;arQ;j~MfCg<-)s-n(TP}4KGkT0Gfb0S?iJAG zFOO5m*By>$k(z1ej4vLXfe9Vg4nyK#OsOZZ;Do5>b_y<<`+05hRtk%E?WBCVNvfzmfZe#uk-3(uOsaoLiMNy*>cjn+dv?* z1`QQelRX>|&gO))F#Htrq{0m3$}O~9%u zawaW;$L<*p7ZO28qGbD?%2pmp>0^!B>9@zsauG)%@X#c$2_g`vo#zlEBy|*R(CNz+ zg~bQEO^3>vazW*7KhX6cGzBo@zkW>>JOJIL2e0)0%GJq-;?LmG_z%}j+lO%0shi~E zapYcnu1LNhmel4v0pUuIXkh3-k4A{llC0&;7=v(yR$-?$+)G9Xd)cVFMh6!Z#~<4>qAM7+d}$sg4{efH+fI8 z3=ZH4kYM&wTQ9#|?!?I$}ua@YIz0A!nI zsP$|ulgR@ZfJjA@X)v+BO)V{WUuCPT8nK`*hoI=!7U{N`e0P#LmE?UH03AC;k-6=Pp@)ACSh1=+ zMqD4f^uyZ}?F@k`-EQ(XAtbu?Z6fUNBrCj?qrH?0kC&r_*z0goV#bmJ6>WWN@np?! zVHA#+7=SsAL-@>^g<&0FndDBE=uZ7`5#N;>fPg(S%SDi|;gN`G+Z4vjU`bL6J&*Vw z5*V|YMoQSENR?tx&}u?b^sR5sukBO&o>OD;$mCV@WR2En6+L>~rbN~d0Gs}J>C@5o zjflFUaEbiKB1F33;3u-)`s+=qdG|^Elq}W-@h5kUuJ)}-L86IJvkND!(6=^@!cV~S znZs7=nJm8+st=(FkN$;P6N`wsF?T^;!v(8_ZFtE zylqUWTx2Z1A2zI`16qF`H+}zTD||M^z=HZLFt3?v-bb~)s(9k6b%s4*wyf`udU~(F z{HHw?wd>`Y9^<>O-9?;6y=1G$rN2-0OeSjZ$cWrRV+n^AIZ~0NN3yK`@Yp8kgJy2r z?oqiM4A}o(eQ&zCe$LK&?qwsB*SNgb&s5FYP?Gfby*AVLf@mHsyj<96`XktQ@JoCK zazc&MJ>o6G#JRz9-PieK|pR#)8=loY-(0a78Z1u>z+9k)fbBV<7d|4)GygcYkSoO)g7O>#aO>6G`1;w5It7qX zI;<_Vcw4C6fH2(W}{>vSHv?|vA@IMfoa1NL1Rf)zyoPDkMI@ph!@Sn zoC`IzR5yt*mATBJ^Dsbuc*z`w4bFjoX_MVt!655)V8{+fa)LrtH28p(AaEC`M$~J0 zDA5>pg8*2>M}(k$Clpb!mZTWtp&q~xQ_6^dpPVMXao5KZ-^inGqy97NOaxi3m+SY9 zz;W)O8yIhjda?efWkH-ZJegGQOj@YZ%x&n_dSQaBJvR3Ce{a9^aZtDo&44dH+v)tx zQ)nReAzdN9 zEHk+km)Xs{GdESs-7}+HLG7Gdm5l|iwnQtS$(AlfbQUU`IQ;`VX|pgq_3`_v?zbTI ziGP0oj5EI>hxYf@o!I)`RBT@1jPxFS$ZI?A33cS~0tn=9KRvERWsS(d*qzgwk$q_P#na7yzPN{N?Kt)#I@iu z!OR-+*bj__EJ^YOn14o|M(~ebZWp63F8MHuRI6y6Hb6JbnwlvwutkGWmba>iGj~WXBWH@Ts7CYQOl3-n~(JLP^#6-{`7I1XL`?4DOV!9)}$GY!I zHd^qN)b2P@Y4Gs=Qa+$Zl@|o*3jZmyO0v9Ng+s7aD_ygFTqRX_-bvck^W<}%19IqS zJ$BY6NP6R^bb#oP1<0yL-5yyd-?)P{%gbIufb>Zaw;)gduaJ;j?b2O2VjpaP7;myf zf}$Z-O>7ON&H<{uCXJ^{F>@;~4K=c6{yU8~gWg;o4)-uy z?`m!4ru{>K!F7g2|A@KQvT+@;aJeQv^nyu)h}?i)B6lU{^pzQT_$x_G0qGxD^jF~1 zKreC+OYsm(#M@emsdgK9rgsNGiSAIuHAWR2IG433vxhBH+qxzAN{WTw9Jp8uXL8q5`6 zF!5jFx#)k>nxtl1we3B*?l?xbKAqTJC}J??m}yRB`hgKx+_b-ODbCW4phqotJ2_`M z$_%XyX8sm2^=awumI|nQ`{ktSY|W5RfzdIN$t&bjArH!5zcbS%lbc1C-8jX z+<$IWT`*qmS|lJMeuBfN=@>$)%Cs7LHkq2KGV|B_>lI_+ih1Zq5`o^k-L)` zljU;W#P9IF^iEXhmf?DiIRi26V#{G1eenZ`QgohtYZSX$9!ti9Q`6HS@0R5nQi0?s z&X!Prm5SzL1Do<-7Q_H|cRUj>Me2j0u-{yAXVJoT9j!1Odp8OIGKdtxo|-I6P)EL= z*Mb)t_=^h!upwlXA^5)(VWAHH+><-r0N<@Z(wFStfCBNN*`5=S0$6Cq;%Uj*X=Qsp zJch_%^H>ERinx0SBx$~QY8^hKHpu@xMCTLkhi+SScsfMr8|`O$u*CT!hNKr?hvK2} zK~aNDfL_mPWlns=1$}ot1AqsjDFakV;x{a@w`^qq#z~p5Ju-DCZPc{2ax~qn<}pH>l@L8^qGnq@APi^qKR0{OJV+H5RZNRfdIws1{wZE4 zju9%Ly{JstYp-J1!1{?s892oNSi} zBR~hG1*BvIqP{0alP1%vGw%)TJ{pI^;t-sAmwe*$K806*OXNJ0_{8Qs9(haL`D5#o zHxW!4C1}wtG2yP~EDc3nX3H=L;l8H%<_U(7BF%`C1pz!VrIFLUWT6}X24G3`0S#UG z`px0C_m^G-O%`YE<(&6hWu>Hv-raU-H>A&XLXdUl3z0FB`5;mTxxmA}2%XERb{SPK)}E_#z+=UL1RUJ7&5D z(IcR2`D6c*rR?isAucx}IecuEmS5uOGu^vRh%-Ex@|UacqD_D4QD;JB0LIc|`20!# z_N!gf=cEf)S~K9Kwyn)S{MAnNbcIP?0@Ut+LX4}h&A<H+>?EC@{wy^g1BD!eLk zY{BT&QuTy2S`(qqcfvgo7)nQW;wTNO;isEcr`RYcvH}uJpqvL0d@Pr`dalN2i--E} zA(JDbTom1^gQQe;Ic|4v*1@x8VtZTse~}vk%zZIaTsZ* z6KZxlPS3Abvg{E7?7#c&IQ^tFfS*EE%Xq z+Q#V{Jw~U3S0t(clO8r z0rj)i3~_S7kK6kC zC-^{gkJqlvkuk%dm+|qZ%SXjV4heFB385$zmVl3Jg6duHSRhdqeUWY2!2;GdMH}8BpBFA$WW>#{Hwu5B%Pizp z=Q@P-T=a9(tl!f8l4(fHDAPqEcl?-Pz~+(8UFZfXr~C>O;U1u@+EuEz#Z#}xOkQr% z&A7MDZAZ#oyzO}LVg-O(G!|JVC-!YMktdf{JFLmhHj7Wb1AJ>|L!$#~#CzJ!D%kDlM(o^vKX< zY^Z;?8WuT_=+A9b(1u$M-j!{=ZOGH)NXh2r)Kb#+(2O7tI@!Ma(7A{f^EV>p;*=Te zbQ@_#n1CcXb$YTZLd#b=_Az%TomoKySoFc!xLAFSHk1*ojZ>YXAP3x{&I>WHg1{#q z+{Fl7-Wz}B2cS-f8gx_?sF?s#EB5jpG#K#e4TRWG1%Pf~yI~oql)eDt^$$UWch1x^dcrF%gRp`2!>bpJ!~ z|1JB+N4f(ceFWF^+TolY%s5CgloVx>YzpWp>ttI$hhn2L%yIWZ-lNnd(mX7F#c74N zP309o`zi|y+8VwABd2fX#vvPu%rYU*AYFem+NbrGBFb1Y`Q72LoZL(nAsSS-R(1vt z1(G`Ka-Q!r-ZrJC9oMiAGoi^^YKE!>z_Gz$;%Ur*C)S0glP zPFCEZfS{XY9>kSZ7M%j3G>P2;cTqLw55Juh20(X)sL{uYrzxQeS5wRv>+E9>KYU(p zb%_{EOh~YEZn66PI41iflx3>iZ|7cmEV(Vxdcg^RE0!lxfZz+Cl%7>UL_ZM8PcE}` zGxgHJ05g4DW34|rfnCLe=*dz@-8*dSXPEwQeb@Swch-D z_qv*ZV@ajlY>Of*L8-7uIpn8CsfPlrT!_zCC#_lp@d0u&5f<*r=1!ZeU7~9?fWVfNjM@wAie21HntZA*;&GYg-suPkhxn9bvKYye zv%tNxo|Vw9BTdOwN}HETOcL!U$jDUx-AS6OG8|o#G8M`5o&=WN_aXNyFS*sHANs0v!$UPcY4YCi3f-dMU_mK$w80Ba zhPWXpjiYxzpFJvwqPfjauDrce2PjxO58VxoqLnD<;(}{(@56#AzSkU6SV*bbidWhX zD>0~~V32w6PtO?PSRC#2R8s@qcFWVhPWEom<)OrTxJ8(xC5r|^41gU~B3vas^FkM7 z5&?xuEzoPBGRRLJk&z4~16jbn@RK8LjyU9iAeIvE9Rcr~-+1(7CsCG7e-4JNSQgQbF1B?zJd9UeyPlS9A5}e)+5?ph) zcqtfj(=PLl2VZgGg5Oh2V(XQKk@5bE;WS&tnJ?TZ5{E?4X-#>r2Ukk&6C-m7B-%N6 zHlGR6VDj>TUNhaW<=(54g@rw0`Sa_O)?Eb%bKvRX&_1)tg<#JH5`(5;bSU^Hs8~#v zW}c_Ns39dFZ+?_6rhxrSJ^)8)5a+&C9lQ!|+^Jc(+Q*-|p_%0QoGo`<9Xa*rUW=Kx z+`+xf{YA${aL2@FryeJBt=5B&SDVA)Q=Ek=)8h9oo3`JNWLC7Cy#Ub4;lNS|m^Eme z5&<}YO0r{-H>J~Jj^S=)&rZ9Vi2WoqHTX1Fkd;Bgj!`*`6HaekS>6DBqOrQEh>hpiv$$+E=u zsp>q&^#qQ}3~PK|AI{}UJXZH4$%51_}Ta77oC3oz?F|d4zhQ%%uezrhV75!MjQliIM5}3+%Et zZ6Z5;Cj`B?*2pw306h}YR)35Bq3tuG*G*|On%PnlcBXro?V@#)&k4O{F_q%FAdjf{ zJ-|N!EQND~y?ChA@H4R&m?~oKhX$kF;ZFV_K^fLZWbs5Nm`{5Pgm=$)Oi`Xfinc$) znsitt?_(2}(ikWe{Jh)L!K}+yrglH(4y*g>i^B>*9){qf-q5mh2Bhf`$W*h_BZ zb<6bf7HdGBV>dY3;653%v?XO>j@@IYP}`S=Kz<~Zzir*P$3N@jL&^NCaK_M@n)0dn zXCVWIkBHC)R%EIF+>#r91Ri*py_C{AaTS%#fSZj!s|2$}jR+i!v}Yxg0A@g>E}^~9 zE#0E^{M%ErHn7hji-hk`CknG2W2KTgdi>V<)JIEkm{<8KWliFlHN4ea|3 zis9J;k+v=ZJj>Idq&NsI{FXsG4^EqUWvk0Q{O`Zd!1MjFZ^xwRL9_dN>|W=ixz)y; zjusWo7dcm}I~Q9`!&-F|TtPo0`XyKUE1Ih5v2745se23a{(dhRSWxT6XmsBfsY6t4#)(FWY{TWM1D>?|Dp<;tBCV?o>*^!(7xlc|j_?{-~30`n?LItLwh3 zL(d(ebrR{VcQD15i(1-OL(m*x!X^Pv40qLcD9CE`Z5uJsWEe)7C_eF*{=&*;B^H;^H)lSb zI^vz=Mu}OWvQ6jjqfL3F2;>LvMY0T3I`CiAHlKF{Io*Vw)a~AE@IpPH_sLx-G5$g} z(f3b#e;hY5T)(go@s;}%w!F1*MVffn7smMKYya0EWUGGZuvxTN?&;-zv%EA!6%c`= zm*C{ch=Eh)LRYLDH5zIb{GLu$of8H9K?K0fpg6z`#_Ym}g)R8j@SHN?=>3OJOSRGj z+V9^o;5>beYlK_kE+C2E666TjCXj6V6hhPkIxB_{$YW)&tU9YUBxHumj~Ig{ zQ_)+UbU>8+<{bcay9@x;T#e6pBhG0P_ejx7uo3x~6JkjqM3vTVFxqf()`<2!fHf-& zCJ=78SIshsNl?uMEm}$b6{_h0E$oFkgy{D_8Y?06?_p?}FhIlyXqU~BjJhKM$onAx z1rU>+gkH+?)NrnXzw1f%umz&V@~58$Q-IhSxO3}aM(Oh~1-I{z+wt6^SneczkTeo# z1^ZmXc1SeU$6t+~nY$wTpDLxtJHyaL+aofD2b~tRV+IhUCkkMS|B;vD}! zYT&sh6x2QfPQr4Bq=jA0t!`fb$M4GfUVUU(=1|?Har=fX-rVzVmiv&a&h2O093k*z zU(UarqF?`f~3?W-$ys4MTW=Z6x>Ob0)XmlNh<{sdQ7v z>WQ^F+uuTde$yU}i%*4OCIh77><#o_C>m8*fj%rMR!?~Tt#zokg73$ekWnz8?EC6z za$XMYT>n|vGCi5CD0w(pBs^So7_?`xESIouU!0-2+_w2yR?!5_n+qGGI==3 z@gKV6KR~Tl3=vrjWu?1i_*5%AU`CxI_j=GDY^Qz=%10rXzZeG+T`$$3ha{H$&m;kl zLVS5DVTbb+`q-xI6WeM z=x#j-d=fwI7BskTef+Py#)c)Rn{%f+#`yTqd6Ifr%rog;0q}fhE8J_nIbk}ke%8Uz zN$O{d&k~2U$Pg>-BH z?~9$riXTxp{ooLQT)l9Fb~6riXv`enpPzqHSiVDjI5`St- zqY2SXeGjs;bjpJOi0xh`0)@qwl8~Yh`(9KIMM^Q}g(5N3vf9PMP=!hzYiK+*-17tf z`(J$%*gP0i#w;TEP)&qU1`7yqcuI{x0C_sTXx%rXShH3HYH9YhW8BQ%jdn^KfmL>$ za&W|(ml@W)$!m(~hi||qg8roZ>-4UH@?{$Ko; z{SW^Mf3|UJf3^OA#<)loCiLmK{(LV3&7? zE|5`C#UczSl9Zl>TYIJn6oYn%(inA|hK>{z75Y793-9dPm;E8Xt+QVXl?R{6`&uWJ zi5Q-@wq7E{tYdWac!a}@h@1EF)bYKugrHA3pi=r*X~wI zzxX=1bJ<_^Q+vVeDC2H?Ze=r*>k^OH8KLQCwr|8I1|+r^OLo&XI6%6X*zX`Guyvlu zO-l;b5jdV)MrpVWzkPT;FBKem)s1IRWOt|{pRu6!Ehrx|zF?(XZCXN)l2VIye|K@m zDJqhs#jqG7m&&6Ye(8IujoMJ%hA^K}LJ0KVw>{JP`c6Yh@JvZoHOKo1Cv%KV7w*T;J-X6y^@BFx-}R}gMD*;eAPkKlTW zhxwyQU0EC6!3(!q2yKb?iBjxfVSA00E{h5@5EyL(`2YH{KvX`;z=Mx!W?hwSQ2{~R zF<#NT&A(kAdS=FcnfA#>!BF4aZawDT6ELYZyK-4`<^TkPU9u#2km>?&`>OBRsTFs; z?fs=#j{h%ERCyTC04w56K0ySdh;%Ng=vdf>EEYiLU17mC2DU~sI_Q8oWuygzXD?qG z2$@%UHq(wrO51xs){JE&7()9wZ;Lf-rOWqR5MKd>*ZH7pbX-mltir- z!njAEL*z1|Z1aG7=wR^^3~6knrME5i-Buo#sQZ;uf#tu64_0s%`%+M&P&I);7f4q% z6cjeiVu}K@>};!wy+0Y}e4q$>HeH0ue$C`kK*&kEvvh8AQ4!Qrchl|d|Nq1~u7769 z;V33gN5?urHGnhkPBtOSWp=t!9_Ua(oMP#q^SiD?QO>#+5}8nIJpRVE1X+L}|xXS*kuGYPK!U&6~{ZyMx{yci+% zU6*A30u*4F=%8Tn`OHA$`d26E)KO4lSsX!k-of(=pLVX}-XFu$!|(8PN_t5Rodkc% zdiMr;|= zmoH6xAzs4nbMwi0W}VL>8-uJ{O}5|YiQ}wZ?x>}%K}E))G!R@2BVxrNoFMQ{{#t2( zOu3y_E>J~sLB@FZL}~l)9_ts3vGE&OhhO&43oZ?1!rMO>jJP!h6FM0ZlfTeqYRRWp zW|7|aC{Y*Ze{=bLB?zZ0*(76tXdv`RjQopndtX*YcGpiIJjv@8f3zNX0+Z2RVQJ;P zn?PGzm3>l{JGmj4X^@1`(Nhco7iD41$Or)WEi#R-qF(*~iBC_+UDaMK7YA)fC)tQ! zZKu}#c`;)cJx~Ba4b48Vs1{$iY->3T{j}ZuyLzjrleF47`wEXi@G<8I{f;u|sfD2y zO-EjIHxX?!a$dJV%vjBZ1;1_Bcj(+MHn7zJ*aI(S7Nktc<|*woel1==E;0YzA|}p( z_@sXduPXjix_wQHDJ@%(eXtHHYR;y5-T=Iqwog!K@O>i1!*BEn(5TfABGL3f1Mhp| zH9nO;mk|1ejF0PV73S}5BT3ONLecsU3e^R`&`z_dG=E$5l>q+$DzSkrrBkNF0JmlQ z0T~}Yq?8yb8+Y=173q6NL{m9k-pZg=9!reD7u&3<);U2bz*z}d_weK!haN>zdMXjb z?g1T=*(sfU{>HDwuyAt;fgQ=Be2Jjqg5Q<8>Vm2`1|fS4y(z=$2_QZkpIr%;~4&#nh4b zJCJH$Yv9%)@=fvbag_@Vfmv`mmF>o|v~FdgG3?AjtT zJc~4n#?;@)ZC2j?gr=XX z-Rs;!ci-=>ivN73XVyNot*1PMHV)o^I05*yn3&`RS z3i7r`__|)}gF*E?L8CUN`P%nxR_nK-*iWjfGUF2$_|H~84nJGHb*AaFMs=}SA7uV+ zEb#0zZX~=U`#Ihi{rid2GX34)UDHC@gsWAv&+5rR1Gxp!!#Th7dPRd3lFfH!8#LJw z+7ez{&uA1Vi1QEJh*dkg&irlMhu7zAAK`2#ZcK2c)GD85oa(dX<^q2ClIZS;+9HM^ zXu*J`vwc>1TFpSM|6h@A!R`Q2msRSOqF3`@XYG~WOuN}(_6~#b;Se+drxHi(lUHG* z+|F9sUSkz=M*oEa3BIu2Ffp6Sz$O(qH)lTEM8GbrmMO2KF%60SEV)6(IkhDy;n7_%p2;@bj(i;&r zdL!}gwVj_e`4@Gpi^3V3Vw@%9H zT81Ks=^PE68`9N9W2p35&HLZhPk`m6rf};&)2k-wgdp(KOL+&<#5&8Hm%NWb(X6&i zu_|7WW-}%;rT^==v4M*a4gP1g{j&j%qG-OSJ zqe;KDcK81|viO%;TIK)sa7?9q*vkBHJ+6csK2@f<;EY>^fql@d{pHOBXsNK|(jB$g zvtpdwZ(>yW-*(IkVX@m)n}4|M#%Ew)^ZjwN^k6B4RMgdRLGT{WBaV=G@=r}+!dP2s4?4$a74e4B=;)S7rtpsc2r{`fGMhE9@U)_9dHBT>x$ zC?G<_)Wmhqyd%tbD}<1rI*)#nv;WBxUdn*Q+Zv86N!1-v!TE9T!i7hIqs(-h$3PMM z_7~AdYXVO)d=lQ`_dgK2nQCXEiQV`3PyFYhXvEic@DG<2|C^=GGu{Ejt0c`hp z;t-H1r4a0qomw*=*7m)L-M!0o-^R^@g_$M17^i+dF|zkl(_#$L({K5W?3|myY?UlDkC4n;6cGQZcp#f@U3B{oeZY&7%|rJFb(Zf%`5v#?v{1LiGN}2DHd=;z ziCAts0+<$OrLQ25?VTm8TA)`6S(ow0No65qIz^yJRc`H;GT1=p>71 zD}a@`8~M>z$~@=WItx04;-Px24?Y?}A}|_=hTE#pKOJp#3V3tI;7e`!EnsSz@Q9(A zWiM=jSRsoICL(cqy`GWwVEFeZHw99L?5n>M?tgy=fY$tyrTSx_tM{`4gRzONBAl*3 z7F_`PM?wip)x_)xiwx=n5SVcXXgqHh{-0EEd`zC5Et^XY^=iLSNk{F{kpwi6E z?33;$DTj%$0wky5+;>x1qT0V(WypV|ZRof~`R{|nR_>0taHPKc~Aab z#T5SSx3_X@dd0(<2)Uri#1)W5{^Fnvo1gkGt1QFMb2?m$X~g^F->r;400bTD?>km6 zYt{^ENtJ1+tX8nB|DQbq8^oYr|M}HJy7lDxg}_GP7_%|VHK|+Bfg*1yem2Q3&Kg zq(yG%izfe}#8fiBep+0G5ny!r!|qEjJksKC_3rA#Gko##CzY7n3M$4PQzWG3cgSZa zae1z#5A1`Kgk$KcU%|((L)V@+ty}-}sd`GU!|}sb$4Cv?`MN1{oBu=+6uTKuCd97F zytj%!1zEXUYt=i|Eqvq5KE9o~mp%u#6Crs8MaJEp29Vp}WL?JnSc>nFFQ|&9o{*8% zC53iOWo%4H{J52;;OY6!osmE!O=2ID{*IBdJl+OQ<#F?S*Y?1#9`o%kgqAK2jekX9 zD_6wv6+a*cA1UUv!)AuMUl?=~oB_uVQrwdz%c5a!fUj@Lo;bB=4t1Yut34?BjhSc0 zx?Fe?x52IErX4UVG)UNLy;QEceOC9g!OZ6zs~hH{Tr4y0TCxsWmlu*3K0E#n)bP8! zPXvx4tP$A}-*3LpH}2T>FnVB}q@LGgJ?1P^j!`MHw_9o|>f1dV!*GwTr0#CsPBmIO zJ_|~FyzKa9mp0k4BA{(ODXbVejUUPA78Od`F#SM%NU#cbp=Qa@7V-;=sSdts6K`<_ zodc~f%7z6k9F*ekpYq0`xrJRDg8|`13;e{MGb03m3P)c%|Nii=(%qgS-uMbSGGe3IvvIGCcEPvpAN=R5f{dC(}us5`>Gnxb694U=>{SJs|?#h zh)#vwJx*F}FLi;B9oku^SO8P>BODZ`xVG^{n8%*urCE+wXg3Q3taXXhvAr9oXiIK0 z2r<&FEmXWdj|hv4OqoqB+ygs@Z7+cklOH8U0NF8r`Eow&uq`<2tFy(M&QJm@+-mAE zz&0*9c80rPHEjb4zY-3vu?SKS1p>rvk^W3EEJ#Kjay-QnX_b_3p-@!tni4;DB+f(S z!}x7IUCML);+Kk%!1Fw*&j)KMR3QUzJ_8*ST};+?o_IDXfrQgC^Z(gG!a-WFD)1c%HciOtmEks=QawYY&kLMy5j)0h z&n5K@FgQ0YcI*nuvkq|Z=9jX4b{l@kI}Np!K@o=#^!6sPQ*E1h*H1^^AV4lE5tpO0 z?p{Nofb!fg=O~Z<(I0R;)x#7|hos?xyJnzPz5wPIIcp_`_YZvdvb}?fZVXXs*&{n6-ZcHJ!(^ zgRgGUd2`!cCgP#!S-udwAj7Nv_?7+w*O{<2EqB?8Pt0tFU~DR3^4)6R5dpderVh7T z8ky_FP5ge5$ZTPx-NtnE#9X4vjH~Ue>R$yPAsk0j>-ILhh3V-GSp`e*dLE`=*eHtrlm_#OK21(jzH<_7O%%K5Oq=Gnhbl+voL1 zeL$Y#=y>2BE%SlFY+$^(KLT_v-TOibguJ&};V@IuB@-MYo;V(ObZ^GtCZOMXFpWJq z+Mbv1PMwBI!sl!*M~lcE;nD90E2$bcl7o>bj9K7*uzFNrP!E&dU0<{5P#$fO{p_{F z;XC)QnSmJfFK2!sG777z#wwL@?Ev;>GB50FGkEyZvBfuoS1g^2`tWBO#yDs%K zoAVCI!m_0M+80&&n*A2R89?Dcy0iK(r_bInJq&p-`eqS~SBCEXfQwCD$IaTOHZ0u` zq&!l0{w9ORt2_ZDy7sw}> zvQvPEMud@oJ5?CJpzu(>n?)n9_TstSAxbU$gn8U_1tY1F^r){o^#HtQIgDHbsr$dE zdhbB0-}wLgeFo>4=h&NLi;%rJW>!`u*+llp-VRQjtcaAE8B(%Wc01-GtyR?us5NMC*8Ry)LEQQyA#v(g%<)-xYVR zJAmwnSu;i5xU=2|x|uGJo7p(h)4PUkA{gdH|0GTH04EQQNMH;1_S zb>EH`bI)%3J(9m71M*A`ef~6i?fLt3s8S20o%z(UH^kilW%jRg{Ig9#TuMYW9Us&K z?~q{okk#mkw_0ZTSNgsVO~=@7W1#DS&WSpcz+POC7}DjnRXpUwOz-whe^YXdbmHVV zIN~v)O3La#vyc`tOxETY$wpk70D~I}bF3O0a?fO;D`p6kn`8rG>>9@{ai}g- z$STkssO-1UNxabs@52($6#37Hb=VfoNdd+J5@er{MDejyN%*KpWX`Q0iBF7j!0j6j z!jE`)p44w!g}Uh8MXgGS5@4KP zt(uC&U$wUCb5*$Z*JwirNHuUP#(ce-6SoRQt~!e^L+=8d0Sn~)GRPt@*N(RPk@?(&|q+>DmOJ?SSy;nS17(~90}SDjuo?Cf|#kYXH8uL^wk z(!j>?R@&%_w*cj@KmK9r+6q#nm@GO7i~_Mdg@z%&+Z%#a{C*u?2DjM**u#QILh!!; zq#R4tUt`S-raRw>u~&Qd8NlCXKF%c5VV{ASk*m*ILFG(vy6!p(an4LW#gghCzAYeDp0*vBrZ>*}mX7>COb+-IFwaNlrtkd)IOn$F`_5 z9kPrcpOMYe_|;#@X)@$->ysqXOAIZ2n78EC3cG5F+RFL zmaJFNB2%E=kx5R{=+zhU8r`lq=CrTSlGrb{b~w9orHQciD^>8l9|^e*S`l7|o6{l>t71sQawb9wsgVLmw8VeG>p7?QU2N@}^D3i@x|TenA(b>3D2(S_DX1ov+YXy?drbNi)*vrlXgwVXV3c7n$;M}X(^eA#9>(Z z7v)KM)-f6qZ0QK}F_-P0IXWFe!!8Yg>rHA*}^Rkpi6@GbLMav^TXZrz+SNPs8oN| ztziuD`IgRse#sc%#CiQr{fdvq$>|Tktm)Ivi~ryJvUcP0hmt8_n|BM>wgpE%S!m;h zVJmr*tbY_|Jd<3|-66>#m-?A{#k&C?WrGl6R*YZn7T^Y^IZ7nFT zRxckn7znhH5ChS5O}Bu;Dq&pNc=5YwCdhtc%a>~fxPKec%Q9-IoUR!tcNiv1s^`Y~ z^*WvQXO=vP^-#2aVkS9A$HFb#VLHZwbNo68%i3BCDJE!96x<)wcHeW?(M-zFx*D}8VGKX2SW#g20P*JA8vxR#wce?R^dG(es`zoVP*>k z3Zzmx3{~V>%mns@>V_bMBX9|+1%{S@0bR}gn>P<L~%xo>F#m3e0Bl*Z4BP0(8efmyDr#+utN(L{t%4_F6L+nLPQh^ zu@RJpyi+40x+(|s<_6?L3J#>4;tP&O1qQ%lD!?o|tV?HFk4+Y^Xv^3LUK$g={AXGG zIQ3&^1%ZM~ZNg!{T(ZL_Ron+-jvP6haefy8Yy~OV>U2wi+M~NuTD0chhoO3&*7n9g zwPiupEAfwNz{hyJnAs_!nDJ8`?|0dQSfLFm@P*HQ#( zAsl6CHxbIm6yQPzQuS^Hojlu4is%cOD|AiM6z5<2I=(tf6K^!k6a4!B16-*?r>!#3{}t_>Uy=U*#6(0f)H&}@=v zN_#Dp;-*?e!HlzF5CxKgY|d37%`Qm{ha8As%^g|jf;=2Pf~|C}!j|Z?l%fogeW0_V zMDwdXtI|1_yt>83$?W6HjSXgWVj~yAJ4Hnn^g9Zq`Rpan zL_c_<@G)sI+T>c^y%+O_>`$WIf6ay|+*3BB@q799RryVzz|Jn8v*0V!(L-i5#1fEo zkn|+`s|;X{XTkFP5#N#NSQRtQo9`Nwa*%RLLplxYGKxuAZc#nK%9N5xI_a^Ug* zU7^9Z6Kq1d*i{aiwqU^kB-}7jpsF<~#z3g~Mmb%gXzeoJ#|(kq#*gA%WmlFtS=(!25&fv(rnTX8T+v-EfH7ZMlmK`VoI)y`ubT5?H_*4~G4XaL!% zPOCip38j?Fp(N1!d)Ws9Z_}?!E2Ol|=eCClfaZF~c`Zu6ogKRY04j{E<0ga{vp&$u zyycJ>=I;OyOq$08p76nYAed{50s@{j@l-+@Umj>p=gSP+e%w7uDeV9kBYAy@vP8?% za@>YB7HI6v3UxnmMy)>t(S>O)uGJt!&nJk-H=!@RM8y^P7wc@4TAsmO{FsP;2Uc)# zN9v`=OuT$Wnw);T>UAEE(0`L2J-@A4T^=R;?@Jndd5l&aS?BG}b@Bri)hev2^1qP< zkYb(#OM6iVsHc0#w%%clDn8*hPL%r2pog1Ews8NA5W&P>aCq2coy3YZ#?8_i!no zYf5idYpLtuxBcIsjHmV-DghO9j=RirAy6^L1k+7`<%Wy~wouW9Xf}QJ$-k_P04Qek z7k`8Ppuh@2d<<+TKe<4jr%dRO6$bRS0LXl(Ir_ukVHHFeGyvqXrgIekCismrg;G_!)xUy;WSw)q1~P|LlPp?04!)4U6*6kwjq9f1*KVM?t7t)z8*Tq2j`h2hZuAX~li?O$8;Y~KKQzK-D0 zSno6o_NWjB0KNdhsH#n?5ITS`B*>G29%FS^0A%AS`z}O=^V1P529QA;(Nh-H{A7sW zh&wLG@U7{kq@)!6DKc#R%?2#UBm))Y9)SJ4bg{ql2|F%*B^&X!F``7){v$7|HIm}#DH!N zZv;LpK=smpS}futDjwCpICP^JPj#w7PjK#3^0%$1^4sB1wSjIkJd|iU|F;%9oYpY}&d1FkbL$ z)Cr}r<})0!J+1KOOM~kgnUpNU3yU$Lp>Gh$C9K~_OWL13lj`zLe&ok&j>7-&qq?@7 zIO-+`uq>OuKMD5+)Ea-ta@ol;7F6Yw#6T&K>>yYb$Cyh4Mb>eg&W=(-@TVF=2BjXh z>a|acM?EAla}VLHS@Y!-0Gxu_xF0bZkH%p*>+dl1#sq@yU=J}IR{lXe2+hcE3I{Z_yAg=1oO)i_yba!k&8UCtg>kfX9SI7Nr2cPW*ZtMwXahC2CPcPi>OY||9 z!puMkn7R97Jolu4)KMQTS$$UqknD_%yRi*LZu@@WRODk6JFYfec8H{sDcm{h^_BT& zHWRk%_~I-8zkJg(GsN$n{`Iz^oImVm(N{?^yE{Dv-=6oL7T8T7@Ugh7To4Skv&wS+ zkvjY5xJLAG$ifp4!dZ;ZC0F;Bw?RR&2@p7Rvjkjxw5Im6`F=?vy>i}gWNiT$-!!!y zAH|jF(#zA3&#f-N>^1K@zXJ-Z8!X#7bn8fNfy?{)Vy)`By8<|jxE9TVs%BA;6&-wA6f#K^_LL0r>nc%3cwwxi#}f zzD9RpFnQj;D#Kzc1eY$NUi)65L65YV zDp-!@gZRh2gosDP>-RC-wFUl3rXxI=&bPVi^G8Pe9bBU(-kU=SRH55$$I(;EJ43zY zd^Pd^y9yj)eW{&-nR)`vBB`BGSm)gCzU2$fwRL|{3SmEU#y)%G@rhJ0v8;yw$dSU>9a zdQu2*4a-2Nh>W7oow2k?ENwAtW0(ntAq}CqCoRgC%W;Oo#&70<&Dme#a|>@TvCFG; zD^Nnj!;^;X1rwn(hT>>rA7!RLkW_{d+`{!HUBJGe^_R$Pfk7=7qOrbx05Gl?&i1J& zx>1QLfuzGttBca-+p2MZT28^kq-F`T?_Iuj>k*JZ%X{YUxX-<27w+ zOi5?$b1&ik#T&wMzAOag)tBFoc=FZgBpL%U*D7>$Yd5d98Hv;4vLt#0cVuoy`php>((q- zAK1vDfINrOP{jm;K7fU@u58P?(=gkdZtb!+#PG8=9WB1JDnDIw+Gapa{pD*dc*F}v zU1#4E)$!M`K3l@#Co*ua@DoYIo^3%}%lvhNz5xYSvQHYnXi=0i1HL^0IXznM=KnHQ z`QKpyIECpD7+L-7uMtSmt_Ps08b-i&x2AR;!Ddj_8_t?W27MafozOLWo6FU>PL|(` z;+}-yq;&m5)QP61skmQR3|gMsU=wqxi(+<^g-b6>%ix>M0JOfQ_}|HwDsdL^mkvpQ zeR=nTAG3j(m=7i_4o@>xE7dFL{K{@iCWb&rBBpMVew;RvDWkw)-$y0K3ossNbgzly zun|}U-{@c6>rxnj+TdfrJ73=*u8|KYP+z-k>9pL>E^Bn;4Ozp+tt`e2MWNkU8Cr#?yhh#((Fg6;^FLkE2H7jVuttPof5uOS*er9c*Q!iUZ*9ZSt zO&CK>n9QU?8@}jONFCZKj%I;uPRo6=OsBGYD$pycVBM#m071*<>6g@%6*hx5 zzfL%>Q#1r;3sdfd1#O|RvPp!Up`DcIa}S6{P_#lSS<2ssQ`e`#BfI*~%;!bnS(t?Y z^3K7Evw>U-C(Xuk^GcQQ+_r6uD2#Omf@U}J?gfp9i&DH-QJi6)Z-7WM$pCOu?F~VZ zMEDhCq)@6}zwMvtzh@G@4pzE5ut9%IgI6;KQ~Ml**;QQ{noyz?m7uCUFb>|RfXEC(9wGXkpSS>;u{ zoVfgCdW|FsG6?+qc4Wm{mK-AVb5qIwBGKUj>-&uE-(YAg5sGsLhIUp0SQH8Q22PFA zm!@JAFF@grB|-6}K(bH~FW&0ach!G&t0w}98kQMM21}6Z>9QqPO`2{1aL3=j3mEsv z>z5GG-Z{K6w;zLdQo7{E=*9=yCuw=FjEsoE2DnN+oh~>0hq6HRm_Oc_g23e>x}Qd( zb)n@zGlKJ$YO6)DTvf8EDS^GtcuU~D(*K(LI-mS~Q^%Nr3HlpiID#&>NYAAXV!6~o zUfOc)J@R^3z<%XR=EPNi>IVMXRq}I{;a`Q8w#i+??r!J1C{z^eQ=-VJ5V{yL%>$DFLc*as-bLIDZIc!wuh*2dF9!mZ1cQh5I6)vD4UdQTc!W z+P^iG9?bk{So^qIi~0#@JnB@>WEDuB#LL?lpdFv2&b8>wx>Nfs81gY^t#VbbH zU0;DN3=(q7+`Iin(cMWd#ZT4P@FWUK9fPCY$47A>aT=>jR{+{myVbr2SUVN!|Eo^N>$b!r{yiO-x}QP^~^ zx@r&qiD&LhL+o~rMKY8_iM9U&oZZj4bR+oIB)x5A=N*RN1V*bMh;(IMNxb(@I=dKO zS3@xxc70k!vB2(;u?*Mf@4;m_kz4n9>;1ty=V zorj}T^(;w@6;xLsPm)`& zKJHx$IR>Y;cx@!pp2q?L-{IUjM{B&#r=;w@$_<)gJ)+m$$lO$%u>v1~?nebZ7mpws zj}w}Hdx7sudnBUSaQHTm+9CIoC=32rw034>_X0hFE1Pi0elfmMz}-h2%!M&>$W6eD zjWKT^%(r(*-BNQoXDI9S%pn^2=a@B2Ex9C3wp&b1LTpq6KH2L&hp*%56$E7&z~P6B zp>kF*db#RQG${7B@GPl807uBE@8^}r$a7JztZAXr$T-4S7G>J#HE4D5Xthu#vtAa*%fQei># zru%+W9KK#Q0aVB!SbOMlEThCfd)&8q9s1mI;i4eqmt~SywDN>avjq^{ES6^FqMcF* zK%b}dX}tF&W=4ZFfs+VlmyvX{C2^UPl6RBDIPZfLqM)gT>CPD#(z8=Doc%k-Vp1fM z=GjetqO?Eb{dI8g*q(sL_6t1`#rF8Q+1mqe1ja9`znTX(s_O~{g?+G#J(IBIJ zGLU(~balAvBy>0^5lk0>VRA@fl{tF-`pj{Z)JYPd;E%bZ8@hg3?>xyr+m!~zqTYIA zAt^il^Cu&aCqA6;Tq9cSaWZKc_IG!LPb$!8eCQ7(At<*VPmIg?0~mnP z&zor1XIsrT7J%L)#}X-X`am?eIzH;2PMWU+3u--lj3MCk=YYXaAb^j_3WnC`alSj@ zUe5ISX1IMav4*a6K|fvm-Zn4k^m?-MIX^M_O%36{u#WxLmnW=tH|cHrFKXw?^6J|h z5oM1`^qF!8)HbEw!7wxUpJTkBZlqhNg1U<(Rc$8X3fOOnbCLmvG&T?Sju`;VqV zk6wabzg2MHil-Y4V;;+3;i1}oo1S5c>mQCYfwSaCNuyXl>$J0vnJc`SO9s6HZ{hCw|q3#J)rWHu!HW1Jm{HykleNH~^?Y+FNx#unVVsnC_Mjl`TTio~sN(aG> zLootyYb5kz11PIotMV|lHSw^ z%>wmM$N87)wG0bjzRL7!d#)H$?`r$j=5oZ|J^79F@Ndsj!33VHni-U&o7)m@bQ?AQ z&|ooF7FnR!VMIOTMur*IAkxAFWAF%SXDoo(CBhunf+(shcNt}<3-s;dcPpk4`0oQD zAS~LsM_8}mn^;W6${8X@}EmN$%4-j!r{UKxZELWF0QHVf{y=2uhOpg}0KpDC6oM-IZ$ zdV#RvJ2FS5%T2sT$ZK?zuk~4qX073IWIXcxeK_KIN=ra+SRxpp8V}3=BKF)dp9r+k zdGpcDXd=D72zyJeL+}PZ27=&b0f6QS(@gKBh0&0+^8)B7bs;XhWgSY5j{-NYj{c%= zOJY`v+PaetO0T+AGR4HoR3?OAVJ?779=KKV#!-Hk9CR<(S4ea?-%HM1?)Evox1{(# zNf=6;psL@;-wHj-qM)FV4nOs~1Pi-^NX$#D`pnMa`_LoKkR4CZk073D!oy^bzrrc} zS-@61eR&KV9cfB9{8dRQy6nK`|LJVPkKAY%wy<{OXjK^aIb>sh?~{kiQp2q2zoH%E z5)h`-oo(&>;w!<{`n8x^2In=%%NuESr6&N&wkX#w{cT**__)y4(qpHY*5KL$YZ(x) z`K?<}P%>o^zty!!WN9@p9nh!zfV_8ARhTi&nDUdhs*t12kt3c}cy)rQ|Dgx#OM80k zFvx!?y9WxllzUcI!AQ*R9*Ft^Sg*au@TbCfaOhrq&dr}tF&xwM7?vK{f@u47*Oma8ox$IO-FR^rsgMsFSXHcmJ`LH7!$5zXYAEJaz6mNs+9z#{4YSoc?6!b z?(LhvWe1vUJ@1+f3XDHpJ8IHO7TEl6bp-$SdA}Gw1_dS%e6j|Yz7bg1PS@o6ej!F} zTm*uYFRd!=g@S5&ZEboD4$_XNuixDuqf`ShV|QJieZ{3T#Q18ICzIVkaM?M4f1LsY z?J+eMHQJ>_6Fx@2SlNQBEU~`2Ckeycba%#j?`S05Jld4UIak{~x*fvphclr{1dyyo z(nGJ(GRcZaD5$uY(Wfv_l*k5b`VQ9SMjqdc|E~oFKCqAfdtmuy#zhie-vI+u4h}9J zx9#B`7s*~6g>DQq#0xWIQB2|1+0-OnX5A0@%@f6e$EpbyqwUp5;cCtWGU@^aG$H)> z`{a!GUU`t4iPShA1dN_{3~dd#_-8ouX3#vm7`$F|#lq9;Fu>_&W}{?Z04{@$l++#E zY2JAIy#7s?8Hk3ly2OryBqf2Kl)|c z-%GNff-5p`Md!q%WM{i+YHEIp_-3U*Lw)%683<0^y}u{@+rJ= z?Cv>oY2>&~Y(;KKw6Wdi-$orsxz^X`XZG-vxIeCS}2? zUtljVR1~tv*Os2PDQP|4&6e^SgkkYq_davAKycWqzZ=L?`pO>p*8s%nzjGe~bxlEB zo;^rOO`WCR%2A9`NTET?=YD?5!rX+o;lnT^_>@W*U)DDFT({%hmoNI?Ru%jE`+Z%$ zq{1#jdBzv65OIXbSzgZ*v#t@z+^5IN_9Bvv;Tjl7>QgF!cwBxV1thSM88o(8VWWTZ z3^~;6vAm7|XY7;xn5Mpot4qt`l3BTLsvBRYJFT7Vh7L@4H3rQ!c}Bj=X{_kSYZ{DK?lT%GKa@(ZM_6h062{=N{HvQ!zhnLqao4I#AGOKWJMytx*!4!Iz4Za-fLX_M+A?Y7< zYbf+#Dozw08GmC2jrNMg+`s$U^p4PY5*p_nz>zic=HTA6zq3&sBLrD*N?S-zFt^1t zV&Ry7zJaSTDCMeb&z4^6h!aN{lBw{cZ`-^5qJcEe^Z3YxTDs{V1QySxE-*4i`E09p z$kb1;@%%IQ>*_d+3qzu&O?LKAbnt3<>JH;adWC>sPf$m=h>ErMdED! zUb0{?V0_!^-rH8C0ngwC&TOgBWC}sSDi4|Y5C4Aeb6@V$`>8)BSG2hXH~s(s@Dg|A9t@QjV{Asp0=BWO?p=@8c>01CVgI!;Q3hj_5EC9^R2P(MXBOO)ZPuA?R zDb!(d0>%-4KJt5fN(|)r_~+qmQxWQg)QwWUT>TI|-98?%+J_$Xx&Ym?H=^Ill~a$&36Kd2a`69;5zA?YjGYt$jqzlz#V*N8$e#e@PwuBfaeRn%IQc{-lnL-)KepKb64;O zE_gukZXaA<7@DGI_&^KI4;yGi|>V4WCDGWC8%?tbVOuK%(LP7|t1pzmz zOFve?t4&^|nERT>daJGW=O`aV7X()&5jnXr@h~n?S{{gc0}Hva02@&UrR*8iv;mWB z>Xy|_7e=*VRR}H8&qUc|wFrTFZGc|yRP+XvkJrF?B?P%me^dIV2_@rmze?FOYIK={ zgG&ko|2d5ZANZ~uCfuP9*faWDX7`!AM;<%D zR^P&0Qcc`j0+(isCBc%LSn&e|do?v9;zbkG2Q^0+6eEc$q2XG(q#dQoJUyZ@fBTM)4Rado;F%q^Px%i5nJADeBsd6gYI7xX?nVATUPo&C0B zwX<(R!QC;g|IKN+Q|?AFhf$ANop<$($a`4>ndc=QRVVZf>jATOBMK6u8|^MseNsHc zpM6dt?7d=n257ySBf&a7|{z_CJ-x-d#smRJc@S znI|pX`W{yFC*X@JV%qHdC7GV?DG#JtzeAMo+)(AbE_{FFx2NvCa3)92H*qNy`Qz>%jD1eAc+g?8(&5w!4`)gX zAP=U>C>KHg~_-J`;3d!>5W%0bbvy_bhX*bh8(sd8MC;-BcvpC%efJLo|$-1GEg zlIT|@@y5Id3GzbzD?c^LQ?Sh%@c!2iGTqh3R(t=&yzJ!!qJ#|Qyc!J9R> z+Lf^nF}``_MstTy(azSsps|A^#UVJ9oNW~Q23c%PN)5`Bz?%B)|n?G|*8A1Mb^dLKff}0uOhMXaJH?%aOc(nPCDUqR;SyJJVN^r2S7eKuM(9vu_Bm%heD zDvJm3v&Ja}{;@nc8{_{Tr1i0N{@i}(-%u6$?};83BvAu!K-)k`UtHs_LPVnyGLU-FM{=eAEyc)hplg4 zmtC9BxOIEhYf$CuB?ficfH4z_Cebl@mEGWYFd)6sU9r6wm-o~xG+Wr(%q82SEfe@) z2St}U7pJk|NSU%KE4l696e^F*WEm5>BV_!-O>g*(j-xIwhx)U4qN2k^1SR4ZdQVTv=wv>(Lg>Kj!o~g{e zCh<%2sQwe%@QpruDvP_{WBmtnsMDvb*?cVQbO}P+b2km^+XALPesV48XcS~D)5A3T zjc78+9e3i{NTRCm^~k5xI)muhTqaNZj$bXCGj?%o_vcRr16i>Pjij~tpdNT zU={+R&2ZT7F!u{k@bH%G7N&VFuy*`Ac$>SQ8u_m&K>{RS_RNQDzM+)gYFbLBMCW&Z zlxjcv9!jYX$ee@~KbM{l^9_6mSUc5bjky)cf_}}Fv-E}na@@|z3Iu+Q(;|z7`6m1oBeN(3 zk2S9LuhC_Lxp23}o^4dKZ$)&ETG0RQEbyC9ZobM%2q%vXvs3GHay(_8< z@uLqw4&D`ypqX{sfMsw~bqHt(qyKyUiY}i&d2|00U-3syw<9;k$^*Jd_+^Vy33f%v z-qViHn$sU)1EbD92nvW<1?G?-sVTor<94XgPoF-ejCQXbknUHF`KA2q^ct?S0ltHO zIuBA;=cjWzcN}H&H@aMZe`K)P;&ZhWov-Hy@msHXY=Raa8d%{;Ose@*UJU)4)8%>K z{ZTa1rN}|@nhmy)A25FVXhPmCsV1`HJddsQ;#a`j2x3Y(G*a)dNK$*4BbI;DV%B3K z?!|1n&Bg6l1r@PdqSnES^I#xr(+ao%OV99X)BLq&#g}Fx&0Ij^UU-_+MDH*B>4hRw zZr{Loz01NP4t}np4!E1&n%p?1?;lH4P(iSTm#_GUk#k6TZeGLc5Q$-^Yn>vW*FmTN zLT{R3RsAG;5uKMiy3Tp`6;)w`_!I`1aB$jemG@Z1eZ$eqJo@s*;k736H*IYkdyC!+ z-xmY9XPImf$5+fBzHr~HR1j(A=+HSlatX=vJ@rtX(2?MYf_)dA%6SFT@+PHb+kj|j znReIINO}6aBU8fzVD+^a1)p_bqChb|i3mDaaS2$pdaR?iYCIeH!fH#Qhy8x&R$a_l z`!+K}hs@r+cfTJwZS7ikHCi=%`;@G*j8<6A?=`LIBKG94(s>NgffRFyG%r&-EgXP0 z?m%hNlvlR)gT5r#&R=~NUTwU$Gu~{)3ss``&3klsfbkTA1u$lu^LgwtIv+xUgrcEU zoo5gdvhOII&MmGm;LA7Ax9IsUG8rIO#yr6LeYgaUP5mv&dsY=sYr+5wzc}A=q7?CI zE4zE*hGm8V54wmLWTmKh#8M_^z{48qh~k2x{~3hEd_`i(zhuwO1BKt?ZCedU3ptIy zbK^Q`w6dF2(^~EJ*|EdPetvc=XzSgmu?ky&JtYLVX2cMSQsrmLDU|`DlS$F4a1kS| z-6-wvdh!5g8TUsK_!G{pc_!>|fupm`tt%-wZ}w{d+NXykJa7~(Unc~p_B`WuAw(Al zAuww|g5Mp03XZ}!nTs@W&~$x=#5Ns1`lksVkj_r`C?xi;)31K1Ny{ZuRi~jm zk4xaa#cfA5BV*&mF&4AinmH{Moy)8!1_G|uB>n?TkXHM~7MK-CY=mu#86T)jG&zWy-$PmUMgLzEgepMySbgl^m}%ZKhIO)w^*;%TbW82)|pVZuqnEooc+&TV*bVb zqU&C^XB>Xf9yNagGPvIpau;EN?M^^-zsOl1VV!z26>^_->Bq$2~guNV4tD2zC*_r(hky zjkwZ_K(N)7>j9~F8cH>1as++{;8M>MP6@*9cCoxD)-O)v(YWu!0eqORd^hRYy3PR& zv6&i=RGvRPo}iR&ZTb=}2Wa@;sBJr*kUZ=1DwXvP?4!CqCFhi5x3Wu1vk5tFrdgX8 zdfRFb)_HGWu4P)+T(0|cHuCNOk2n?}sd6(QD9egJ3Red=qFBd+a&S@!F=cQv;1#hn zrqNkz6a1YL6#Rz4kkN~z>hDZ+px6x4{1LffuVeL^XUrTd@jJ7*)lC`Pml``yY|~f z9>rrRvb3bWd(*f39VT0JM|uDeFtMOK3vy*NSGH73&#u97#; zp9IukWB99sZql{#B1;_s1~REBC1FYXclxCO)bJlI=7lZc*`@5>5 z(B{8CL}E78bIfxMPUQ=byy5`0h1X*M`Y2xI?;>h&iZ}LLe|>OI_e8JP^UgOF7Iv0A=6kP`kUD!86pTiTu1>SORrYH9 z=3OrFt7#Jt+geaJthH0Wt}}qO>C{uX_S%P&9J3Q01^;j>xU0$H2@eY>DwW{r!pVk$r_exfLb0 zPWB4!x%Eer0W)*ggQvS)7#5VavSc5N*13I&6@YYCvBlTirczI5EO5Wsmmr745pDCR zCqNQypOE7K-Kv{+fbPiyI_$87n5&UjS~MkqL~)lf-1Y{bjH=4yyaH|<{yvn-+8dlV z#lQsv(rGA#wwCtjh$3z{iqd4cpU5fimZU84WNw)onwa8kO>4BPzu!T;@adTA$@9vu zPwnv3G3kuuwZv`I?Td0E>f_(l2HmJZzxe(Zk9%8_YJP zp3wu9o;8LQ&PT~cO+U>WrjH^WxC5sLNVO6-rqC;w1s}#WD#A6&8iIV2&zH5yGc9!^ z{QBzAEfuRzk=v5h(bWCcPAvX{Ut**G<*SIs9)1jBH9?a9LZu>QY0#YzYu@MHV6-fI ztOJC8>%32Jvh#m(ZAMXv7dtuC<$Q+v_n%+0b&DcVenMAfwhzacEovyT4lhQj9>3e2 z)^{Bi>Y0c^+0iAw_`A$69kTCSnwHFiX!#Sz5z#~==M*_6Ivm)O3=`_wN?4PtXtefI zw6c%>-CGZm|Ic51HM~Bs^!BiZDA6l@>pfdKTjbMCWqBL-`4#@Ubd7MPEs}J*TO%Q2K;P{baB)E zVb^V~Fgr!@Y09uYu;A%qR?j%Lz@}VkakZ1KcM>bHd#`A5vN}hV#eNh=>jO1%*W6A+ z1D0NhRNs|&Is4P%ciB>h2S|R%u9uXaWqGUU^!R(^yBnclSvK4s zto#K8cme~xRjWRi+wh>e*zrP}4`R-UN^g8_*GhgP7J^?{?K6Al?n7jP`V7(gp?Xg= zA6||?C)To*%KF}D-j`QW&jWHZkJZ6%3`85LUZ;+>$+p6Rl&ixSY zS`8eK&P$_a>VRW&KsZ2daw&EehQYydP*s8WVyV^;nipgmorib?dNGh_bwLgmt#~+f zZ}1SnX}%q88q+_v9N+l?+G44cLaZ9}w~SKr?24=!91|UI ztab4i_|VzbmQdQ}IMf&IlTAzy3eXzy{cbi9nH)EX%n5w3dh?Gtvk=^7A3DYQQv#_k zcuQpP@M50-BBn-*W3ME@p6h^%G}Y#{Zf!O!4uqNrkSe>lr9$AByn_hfOnKrwq1FbY z(mN=mvG4}E+g^at!~#)hY@9b!Cw4jp+0L*NrF|X`z^e+1(o9sN*bxg0s_%?Os$J!Wmkqrg|LAC9-^2Izp#3K^Aq=l9n3@f=1jbt89 ztm1ToMyTz6*~G5>a(##J_|pYmZjt%~0J{)Nb1&U59k0ez!CcTq0ijDg4qF+4-8N$a z7AeO%tll>FPUQLJi)y)yK`)>8VK3ofA%L?tuA#5lG;$vGrtQ96NBALD)_7vZ!hDe` zazgjbDKkbx1@WOa^En36smnD=`R+5CJ(lIJm7^vd<4zIB`R5z@!>dP-mQ)_6KN5Knqf0=o15ji%GqQ65 z^F}~o#I8OaxTQMAG8O$@-Wzf!M+t)bFb28elo$L`Vk^H|TJXTW!qVO{qq^_kz{KFtA(6^-GlF%AhjFTu8@`mnt4F^cc7c9fHp0_rOf%@}bW~9IL+{ZZYN9wy=F#>A zcE{y8mT?HW#H$j!DKg+g^Ha|Ez7FX^`vRJB&p7U*NwDMb(yeyus;HJxJI8?GKi8+K zcEzWAhq+Q&oW$+P-BpgoKk-ATrY4RzYMZmBY${0pu{3NZTYxP$h5iW(r#CCWhpo5t zQ$<5cZjIkBP_t$MG{C)}*K55riIu}dty(*O#UETR#I4f3Lq_aA-hhrO{B4i&Y?{`P zihrbF%@d@Yst?RTT%R9vQ*=3<`R=GrJ?iBB4c4%8t5B|FX0ZYc4E-bbw4sGMo76-ZzES?b?HlTT5PvgJy zxgSWT0oWfvkbM05W0_6Y#OySwMqEwRn1+`OP~qMf7!A$-oA<7ajLGU3f5Be6FAXW5 z!PG6RO7hd^jSoVr*jX1ogv){(m9wl>U6FkS;F%?Af-?#nSLoAoOH>J>bB` zv$1jhv@>Eqq1NMZctAbMvCP{MqG53}Tk7+-rEfC-#=*Qes){$24ZYX%++T$MjSAoI z@=GA=3_JT`gJjqDe#MY6Jk^8BOqXX|oVZrTvviM(1oO*JkkQfcJy9&5+mE1^7_E$7 z95?bc`EqRsa}0|QecJebT%?C|5{t736`0P;I3kD!S@3ec%HzMc6?OWqdViuh*~ZX` z-TxvBDQzQ4r3@JTf4F+@cq;$+fA~7fIp#TLwnH|VC7A~&BO@zhgpjiL9_JtCE`GB)gUc9b8CA#85s ziQWsnVA43vD3L+F^pv^*i&u@zMN zDpW0A{2(b(Yjh0Mo@xydal87U?Ai+$#xp*d062SV;FS#BD9!nHX|wqVVPCfY+b1ba@SA~Pom~}-NJbKL38*wIW&}_XLo0&Uc_STXbse2)o)-2crT2#= z!||6BDOV7MHdeII@mO9UA`j~}45Ik=$y=g)oL!>?@&G)e48R^_%f})AG}|USi5nsF zenIl?kBx1^930=8af{F(wII=4hVKT1cpy1jFMaxy=~bd+ zY?AfM_8b_-a$JaPC_@RizL-f0ZmgK;D;{skEl)DLE!L1 zerK#N&BM96w&0pINDe36{QDn?a;*jj((1f7VmcihPdEI%nYQGbl{jsN-1?^7a%Jef zxm!!i7n?IJr`1-|ATqSRBqxP80N~p|=67zA3R2>|8zm)LhCx;b;j=nMSHY@L88OmZ zucgJUagAx?{aFBf;*7|-@!XdO(t1trQ#+NcmxS;uDr}t07C%zrB^~FzI6SEh(3!~*QrHWLHKBhn4i@SHq)8JSW@6t?sEqW&l1oBZgj3B8FU@;zvl9sc~dgKgPw5pV#fDL4QFSJ zWcM@vD$PEQn1L2wD&)9Bj)P$1rLCZYyDTYWh86Kt_e0~dinlk$ErtCKBq}s;LO!z98^TO#nXp31$CxoI?R2a_s#NOjNipsp=0HafUa{FyYtw z`wdvP0erfl_)BR0RXWrDFd8J5xg-MRu;!ovB@>Xvgi$rYFpaATVBn&k0BR{5Nwc{} zlmH!l-6^0`MtO+Ve-GDd1r|h-Ibro5u&{CJ3^r*#)nohFToa{kRoVF z@)%8e;9_F_zaW{8M!XO@Et%uR(9nYWL_>%eQ1ADTl|CdRiv0SQw@UxpYZGCPxc?lI z5(OStT3TMZl9fnWJ&$IX(EHMpKRTUUl}yX6q(YznrJY4?>t1Vx;cf%aVNavG2_IX2 zz~k3mhNJ5OiefR0$G0yPB6Xwl`=8YQWe+3m?j>FEo_NBX3({mYWM%P+n;4o`?HhE8j!?h9)00 zaI?#uwCH3`SmUK*=CdeJ(ZryXH_x5fc2acc*Vg~%PM_yx#7RP|z=LI8<_K#E8dUVjJt zvIRfy7g0IBr`*^7rQ+#T7K}A%G!>BGN}d77zj*r*Q75do)$G(n z?L*^zims4Fs(5A@lJ zk}yun4`dGhXKl9wd=7GwcH*0U87^`Ns~U>zFoV^Fy>+sG{=C%wU_yg-OrBLu ze>&>mnZK?5rAo;4Bs`B+s+yr+e{qQHY;dp{v_?DOz=QT^>(q11xof-K*M5pJ_1>d* zau=m+9$df*Qp)RO52R2V#zYqVju7_0mS3sujwXo9u6B=Zdl8>;dXO`z9OX^|Vf!zP zTTzWxS}!YJ^Y-&?{c4Z!@(1>p!VPxH8_Fk}P7GCPUE>A&;L|9pFJspTC}!iyJh9L} za?_h94P(+SUWBg^w9S0#wXt;-@)~q9YenQYcivJzkCb87?z(MQtp_l?jQj~1bT_#fc8KL9MnKL1K_I>m2Z(|l?0&vFb)5Sh zfv};dj>l-Q-bbqMu(opZI|NjudnhxkYUQ% zCc_B}w@j$Ky~9IRrtfn7Z#4k&vRiV3uofg1=C2PR64C5nl}~tHT=GCXE;ztGb=!Z= z=LpyA>oG*ELQ>%n2Y3RE27^CPC}+gke#bBkvrss}-We zzI1&P)>^Ut{f6gGPJ?SD@<($|+EoE37pDq$78jTETh1#(mz!<;KQMoKs3>~6G#bZA zYs07WLM}%NuF78t+4t4Bx-A0B?UEfQ^X2l*=5XcVe&m_NjV&MM_95cN$_BEzX;-Z* zd5T)L7o)I)vP=72yGh8$v1|KQ(G+0#GA&>3R;y2^+w$0Y(fV&@(mJWxu*Zi^GqR%j zOVFqrd)f1uf+P)mmyej)kp<&oVwbBed}K(bb!dI-T6sR}-x-D`bY_9%Z>e8wRtd>@ z&!9+Zz@ULtNVvF983oa$X5ixhg0Ed4jeS7AG#|9sbbN9sIu#**>~@b`m$*UpZ9n)U znbv|hndZ+Cs8(-6*wQQ2k0@BBI>*vyL~0cw+6|C8xQ8<($_=6iQJviWidOAJ&ul@2 zZifzH1h3%dO|`A%Mmd$59e4gEgr1(BiqTEj!LU&>{wE6AjMQfcLg!t7|0@VxPE=0^ z==Hh)b*wwl;wAlK6t1-JaWiE0zwh?__As*ZmCC5?(T3fMwCUIRB3^w5$}9?6s&9I^ zjc#Zyh}r(-d?Pn2G1==ah5e^lh4Rj9Y0-u|<)nB+F_cgmcVY}f5obq=FAauVlLCFI z?1-4tQ_hW@8H=)^ZDYzETT!UvbnDuDs#i6&dnVcfaf8rPps|H-wB)misHb}gJFD`9 zB|HI@yGxg%?QQ;3uK6oe4qz=j^lz5+=X9%oCIN(9L`xhj<-Qx8&R$a&-QC`~5{{4P7NL8}Xeuf+9SrJkFm^e4Y9yc2ylM}O%TgOs z2}_duVZ#SXnYEL>3o1!QD295|^Du_g>%;!rxEfI-GpS8>(sARl`WfR|k7whz8hj!^ zZ0OCET=Rc3;ZIBO%;o-_kqF{lg2>FipMddN`AM)ToLqptKEC~{0Gh6l6S%B9i2Xa` z_^zkAwRwOYPD8gTR9{*wq1Ph62%b2PUT?P{&jQRD;g@5Upt#!@>;3oT`!n}+j)v1Jj)nw>&Egy*LDnhG4;?AOHj zpinN|+BRq2*Tb`foLh#q707T?#4hFBAqnQuXZz)O$nTSjUC`orsRe1MEd_QP({p6Z z!73oA#=SOjHr7^3e&<@Qb2KYEB>LXjAZTTw+`M0X|wkwDQuGaM=IPv={CSs?t$ASlCtBl zh}PclC@OUo-3K^Ouz`;+81Q`=(zZ>EkjjTPY2$ijRgjkNq}8sUvhO}zkCKLqq zOR3l)Q_nTVVMY|S!*N~qa!T{>f!qMIIZVdnWZ8$V1*tzjtIM1YuORM*4UD_Sj{L9q z)XoeG#R$!M$a5frtkyVS3GbUb1m(}BEfBn4-=C}eWUfUts#QFB^75DV!5^_ZT$7Ap z?Nsn2sFC2^THBG<3zaZJuo(W=YZ6R!%H>UBx*jx3Mhr))g$1Dx48lHz-rc=f)(X+{ z92)#?;-eL3OD68G#F)0`%drK&-QYm`Z^9b1t1~y(Gwc>;9+7LLB*TFf`VY&EIxf=4 zwb>p%tvZ|AQ@30ACd(2-#SSIU^YKTi!gn8}yHr-P1$?oc4e2U_i>^K{tumFuF6=as zb^dZQ5CC!raQu7~bNWS(n$ON_ks5zfzZti$S9ZQ|nA`3)9zAK=OJ>MZNVBHp)v3Ec z%PD1PN!2=)B}#81U&4kFig8>=EN<;F)V&QE@Q!%S{KQ+*INI6jM4|E4DV`jqxzfA4 z8qgI6ZMoORvMz#~Nw2^Es;e|(TmO|HaBb;})|C)Du}&^{S0XGj=;BCjTKJ?PL}gOr z&P@@O!L!E{bUL?wbT#9xUToFway9+p<4A2f5KJp6%6zd2{xcXB-{?5i^5D6w3T zEIN;=9iP^Wh6a?EHBF%UQt5BuU;tDDY@M|;6q6Y9dv!MODv|wRYTD;AUpc4i^MA6N zu79;@SpDtkG@{cmHM7>*!9f~boF)^vR(5^I{?p%Z`mJdhgJLZ|r+rlu|46?-@#;&j z?zcU^r7T_XJk`kRwcC=kVcwsUF2#>M=NrRsGFxwPN5mT)5Wn7#ve*h6{C2mcpKk;Q z0NfmV%GH~1)6d%Hs|UO>BZErvBWOUKTlB=mZy1LN{}qE4I%8*Sy7`36jJD9y3gffV z5z`cg1!upq-~Y{id#*075d4Hme;R#N7rkkF+7=4L*S;f2qnWk6#lvd2#lIAXGrM2()Lh#%Un#u>mlOGa_{~KI(JP=yO#kW>}mS)C3{F*8-bggBkT{WKTjD5-aU1)7Ju1$_fn3>m?(03!Nz;IF{VNu^hZC+bY3OTZN`#h=~GFld}y;O80INSSd#Nj3_AHG%VA~_9G;@_*=GCX*&KJX zMILFuRCw6rcdXng_^Z2p3u3B2q2PzZ1qwzSipS~}t9C>GPKN1yfBzk6)WJjQ z_Pxvla@MBE_EcuOg0{w^`4a_aoc@RDB3rQ>5q^G}_+cOh@;tT$8TVXvGm7tZkZv)K zV^*YL+Zj|j6mOF{rOpkiAma{C2LNawI!jHumi$?tnDpjBkFySCxMbC{!gGx;UtC5b zK3W)an_(QhDS~Z+s$_wL`Eo?A&$H{-Oh|BzrxO$n=goZEuJv=S4KAN*gMW2jb{Y|p zZCJEy0y<_TAIyZ`R2m+)Q4*#wK4gW>aLW}13hlho~{Ntw|r{*1Q|Jsq#k^_Sf;GOe1$bxeJ*^5&2(l-nF z^#5~d{Cx|GzAKhLOeCi~N}-WA!>}}tfRZnyp+tPhhCBfX6l`^`f6{kOUpXJF{_7EBxz&f^JZSI%4Ru4riyt#h& zZ1>q_+|li*e-;1rW>^n(Ng7DC;RDQ=9@=^*#?S=JF(%rWN8btTcpnJGY z)nO7ZzzQa09{7ldE_0pt?fA}ILCe{$(HML>Mt`cvte+hzzFM>e>XjTBo^H0KeN-*G z?`W;`gQuFDq@f2pI`0Lz@42hZ?AvFSpE@svLoO#@w+OfyMcvCyz=56x{F!CQnR~;g z=%;$?cD0!hB!?rC-6xll$?(C>S?Qoc)}ukii|l#t*_4le^>d{6@UwpjwA{lvcW<iT|Eh)8-yz|Q1JnvpFJL`?5 zWtaeURrb#B$FZx|Y|W@}0QYasTi^$K@2E{Bx&n@uAruysS6{jp&?oZYr)r4_|F5LP z#_3yxM_)`GO&*Uqz73V`d!TWvK;xJuq+oGMLEo^~P723EYfHc;_&CVT+YsO=&<~kF@;< zeDOnWZX5A-9qg*OU+fMFdwD4DIwus<^eF%DelQ5a@3}7Z0?}U%SYtH@+?u@@gzl@z zK%!NrgCD*)gXm)lt!j$r%00|-u{@CByvHXDM5Li`+KV2R$X@V`xPO< zBA`HEc^PJx$4;Q>XjaGVDQh0Y5i2of{;(mKVPB(uVH_{ZpwTwvORe!ca&Ac*8RlfAqCKl#wjo=x_+g?J}YEC%m4J+ z`AQ#RVuq%^Ei6bW$_N~Zs;gFLW|e3UY-)=Yjq2_+o*~xl??2y(uz4;&A%7cJ2|D-U zwToKd7rC(_9$7+rG7xpR*xnCrfZS)!U2S@OYPL1XYmvS=6QTpv36E)rNuR$xt>tdM zlNWW5guh*o(HhWu#Yql}k<+&G%pfDl1;fg{Ce`U-0H=Xm0Oh)=Q3dYEhc7KQQsXo-7oTy%n)s$t7Aaxx)zWKhIam@b&@nl(U3y)xX6 z+Ulm=l-s;|BQi*OD28X)w_t#x2uVWXboowHhG6wGiG)TZS|Pug9Dqb%{BYyqv$J6s zLN`QrN4P#rWD8k-dI8pdBm{W}ny(vqQvlIKkF?~``#}K6+1peLHMY}q zF~+34(gbqen_5OMZ7X_XY5>}R(hW)XmTtBR&>Og|2q?tIFG;p{t{2)P7{tFXM1Vu$ z0_gWtkjRzF%;SuQfvj@V!K}dj0ejEwn84lo zUyDIqJ`0YCJHdO)QhvHIcdnS;AC&t|X!`x->lxJzpSJqJTFNe0fgT!4iK}oZSQ=)^ ztMNRv>33e!{!I0Y{ENDIpWK~~eUK^}q70l<&gEc_6Tr~ZTcQ-?3j+T%j*0q=!96wS z&u^MmhFX+NSjzYjO8VWuO1&`cj|;=3Kjt=b`!L1z9I06&qtIVuENm`q*^=AI*2TBJ zMz5aups=M=npkw@+UA^LizdlvJ6>t0!=z6y-0+q~)jbWrU-VFv`?F;3h#yw~E}_dm zenjdNQ!gC=VbpFu!i*et8eguLhP4jWXp88wl%(%GHeRT?{np{N>-)A! zY?EICff(e1?0RBWyfOyeU07KPL6DYiF+0_At zAT`x{Y`$b*3-d$KttnNea0UEU0|`cd)2qPOMcVFHE7Kb~0@fDvw5b>a%Gl}YxNdFc zd{)EeF9+nQ8wL$XOhyu?)~meB**{BIaC`pDOjA0kmNquL9N(S{fBIiMm>xVt)a=Hv)O;Xo= zI?W`n3>ObRLMy}~UPn0L;9vRbJ;mU+s6yokZa?1hI4|`&eLWc)45vt_Oak4kvq_2M zY_s7$-}%2U4LQ`mE2z8d^_x5V`y_SGGt+MOyML&e98cwnINqGf^lqy#a|e@vx`2(n zI|vNvG*MlQb$rsrRaZfzhpJZcm?NlrIN~E-P*?u(o@-^pqOafUpPt^nT7IAI$Y*($ z4S4?iCL8hTwloJM+RZUY1vGavo?NfnYDaG;$-1V<4$K-{YN!J`R5^hJab+%(gOG#| zE?H|%=C;{F@Zh28_=41I$UVzf8xTh%F?9pEd zG6w+4h>SmpPXY>27CVUN&p1M`CdEinGxS-mIwXTnXBiss7jJ>RXN%fj|g*EyRlo$aqm{^L#21?aX|c zJB+EWQ;P`4GvWZo_$S$MbQ?GY0vbxQGqdnkh6qElB&J>hDLDd?N+03Y&hG4+6db8E z!vLVxf?r)Eb-`0Y8k6T*cQ^F6f>yDwe-#y*vk4Fb!fuD_?{N@tdqsoi5(3l}@mc7} zQ*pBcY3Bn#S-b2S_p^c^wgymmF37}u;6ICYv6SND!L}&i^EBLBV{D5=t6OMa5W3^j z58wpiaSHjN7-5NB7>VUR(rttSIhrwTg8$jkbN2rjQ#}JZ{&r@bfpIfCnVKOiNI8VeUY$;+aAD?|GA6#W--2g06mY0G@M$Xg1&x%hy|4RNj*I z{Q%D3A)|p?+UF9HV7;XHkiiLe04j_b3%mg3y&Ub?y@+9v1)v8VaGM*<`zo}FkA%Rv z+_VbY1)9X!-b7JnM5*79&x4^B6B3{-VB3;nZ{C=H3=-@)=h6*JX4kfOD`f{@qi5z} z(X9U1M0`&51%a8wj+?Urkst(Xf4^*NX#w83%FyOD(I%B_EhZkVWUDXF|9a^hm z#XmQvA~pa_)rxg^q7Rm_lddp{_3rKSqQ@^ZF6q^A%rGEv&~x+-dA<+lIVJ*MW^_Z) zj$HW|`>@%r-`yU0KP2;b)8wPuxhl|hP35QPKSup52Q(WGqn|Nm-cK>jzPqwIUTrVx z8^=mXS!MSZB8yP~of&)Ogmh2E56&rUu&26bpp7~^l`kI3j?$;lHGUe`6!3S#ssV=a z8rBMeQ1)0qH%mEI;ooASsYI8*wUHS5x0&YFx`1HI=5S&;jyqTPjy<3lM+1jrG)Z&`A#5`JGB@{8F1|zM&zzqOp z|6R%7(g)0{Vl2G36lTEPLqrCD9E-jTD9Qe=R{`^RW__0w_fDw^C+Mplia~f#k|=Df zkY*PlVH3_(BHUf6>fc@%w`RtO6Q&J%7!Ol_+o~Ni0^qobc$GT|hN4gOK^gz8U3(|x zLeCDYU6NJD6G+705dg(SYmWub3D}Tz(!I(AC}|yc6j3F61Z2||*jyfod3<@6K()N_ z9iYUaL@$P90f@0q)d|6kj3uJ|Co20|78`OEj$>ezLP4h zLPuo06(VB8;|k5Jqitc~n4W%XfS9KO#%k&0E8Ae+g)MJg8|V3n!j4e}Zp2%m@|Fnj z7Emb$4Zd!duTz}t5cejFaiC;Fa4vF)$^{ocH_frSCCzjcu^MCGZ@Dg*2tM*AhC1=ep}}woYt3t9msQ{D1`316mp1)16KTf?mpa z`X)XA1&hvjxsB`o6r@EiLEzx#Kv5Ii`u8mVUV=pbJ-cFZaS2=z#}_*<_DO|cSP+%{ z=wjzioLhN#%u0eeGW9HYQ<}n!56*2=1BhcP=~K9``mIsLAhwzfO|pHe2W16Fai}f+ zz$qNt`B~g(J8Dt{J}e95#JaKZV{%(G@TDO@0_Pix?`;01&WdZdgM$7Gi~>DWDJ~#*RO&T) z6M1Yv4rY0E3_d{V@DNf)o4ftg-sONF-J3!#ki&2}$Y7YQAxWFto>y;0Qzi zF!%Y&*J;Qu&F$;)Rao%_*U@Zn=Hk1Ot6(-5fZOAgzkP#M-fz$Q>F4hAfZHsFU6xhd z-(0|4x#fKQ5m2X2`OiVx0vRU0T&G z1O@2v^wm8}3li0)6MXT=!7OjmEwZ9T7ZSakc9yz9pZR9;z^de(ppFL4>S@&nqkCf9 zh{^*1I!<2pMRo&s`#S^<652f2HXIaKmc1p@R}Mt%cC@j*O7*Zxk3u)cNyI}o0K#5^ z%;xR)0XrosO=Ki>*o}?&AZ)G=sHs2Ztnwnu!-DGe7A0?WdAeRIxSf;AQlg~nyE{kE zUR-Cq7R&Hi#yjk23!5%{-H1LfQKRP)4HtzwkUQX=aSzI6kf(6FH`?wQ*RLwlVM`E9 zUN}=@x0n6B>`|rnpP8)?Vf67|u)>J;CAKN_m7({3mDh}nMAk3DsFgm#@uj*g#8Wp3 zkPl+D>P{wm?4%fvhM^cdpy_C(Egmju9>*W0?FAq^{7A6$rTPXdMc~zreV|Q4kvmFG z2ihhH&3ED0cHlV_p+rBQJBkwe-3os5yY{3k)o$%QlIAG)l-=Yg)7X2YOV&H~*v2nN zH8|ka{;RE|sGrO@F~iFq8Tz7-N-kj8Ny+b0#=HhDTlJ}6q)D8dn`tb!{>;NV+wK@A z+{s%LXYlzMnT=9q(1Ll9OLMVV<=kwzn&Rb46KYCEx zi~^$5SY_HMO5!h~Zhe?FXUB;_o|7U_I)*8x!FCDDn@})WGBK#oa*PRAqfX!yNx1e@ zSLQZj+I5?Zsw+(5s!VVkw1r5<6%Veaw&26eYu;{!y9V8Vxl!O7>GdE4u5`;@HP09- zeS&hzUQj^|6UU^$;*iMJkg5WuDBhO@$SssC=*pxlv;eSfpI-3{_X{t<7R2G$c%PPW z?|ZMtByCxsvY4r805??hjnbt;`NJx=ZO=EA$Qhh+#tRO!r4jSm+X+{^XXOc2g_0)c z^gRst^5bWl-=7!LcI&#neCSkPhw%7xcRF`Ke~niu<_Ux2J11}dE5qHq5!*Q7JAuWL z&J*I66LnKzXAC8-lt*vRI>1PQ2haZ}Qb6!OhNBYlrro#*;jqR`hoYupVy3cn^ozf| zAAF$T1=eppPoL}k6E0(qWH;TIn|~knC?Ef~=^&Fn*M?`&~~O z0v5osdg&T#ZEY)xHR^c*TsOA&N%%9#NX|#RqeiG6w6yj-6Vc8poo!rX*YSR0n4PsW znoaln*Ru@+ho1AQ!9$Cl^Q##JaXuO@n(QNvK+f!PWAdgOpFjeL`oPO{`bM{wILtkK z)N$BT4DJcuo}Qx7@?qhsT*AYbY4J!3<)oJ9S`YJ;ee@F6QeK$;sjx2k$Upk`AF}qx zbrj;;31S=X%|_M`{UJ;D(*1WVOCJn4dF%YtN|Wfz9jR-NH7{Uh;YBmDbf2uAoxTv~ zAC;KwuKfXY=xfMcRj1(+AqBXx%yedJOrVV&jb%<1=R}*kPl#>-m@C&b1}`zDkJ(-2 z>lrOpQy{vGN}XhLGnNtDj9Rv?PzOkW6AK#SUEcmpsPR-WmN07JoBG`u;j4{zyk#! zDM%wL^80sy2$9am68>KpbL+@qsCuP-P~=fo+xI6N?LJLx;m=3!yp}S%e|IlE=g0i# zqq++ArBUjMIY-OcmBvHoy_eoR7fKEI|0^{N{3r9+b!K22MqkN(70_7Jvn3=*LI2bZ zwCs$ai{dko_UfP8C`p5ueA9cZBW6{bprGCh2XJI$M&+gVgDC#FUkU}cHVBQfEAKIR zO;dokyRNE5T@@0JO)4zg6GU3;nTiU(d=Z8nhq{+dk^&(BFC$rm!%gI5~ z5)6lS2hCcyY?1^kVypVV*4xW``H6Gm(*)gxXM28$vcreo zHT5+nYYn5@k4NKDmAY|SiflOBr^6k3eu|Dx&33NlQ&+(Be05GjP)Q;;9N_vfF>zR? zSIh_our+36o@-O+ZAql9lD2CoEm5|^P~wq3079e5*x!i>A?g@+TmB4BNF$Gd5 ztl!t9JoubzBo9_A@y$IpUMQwKhNR=4>nwVVEp-gHM(nohFY>rd)~ez;0gHaIUcgaR1SL6B^u-4FnKxgGBo|60s^- z%am`d#YfOp5YB#pLa(lyg+%1Wek5`0W>i)>-Kju3Cm&-5CNnRpO`xsFxErt&+*p5=a8Cz{nsces zSRHOyq=2Eob(@nKB>6p>2I4`Zj+Ur+%%jJzLnCm47c7!O@ zzCN~(UVYQYef{HSmQO^P*L>Qd25Nif@cit#J6Ruf8}bjDZCFE(aKiwBf<#K2)F>~s@M zwJLr8e(JN+mE7i@6~ZR*uip*(4taU8K6T~+cVa1z*we316QMQEyu!O=|!h<$z9zo z0N#>wcL~CR9>C;V;ddeM9deVF?}$i#`G`GzYHMJb7=fg8x_rApWA8{w;NO6K;^<<@12h)qIjV;2Uz1H7J1I8qk9T=rduBvq*0#(* zf=Sj)x0R~VXnHHtUq?1gG!dRf*Vdrz)-Lsk;3+RN?NpigU|;Z0knEgzTg3Joo6pc~ z@A`~@N9W|l{~{dI3=i$N_^D53W=w@ZLF(_{L4Ca|^DJOWoslIuIK2_j)27!_I~@dn zBEXw?Z@(pyazjySu?Q7KE`VDL~(}J0`PuNkPx{HUKt8< z4-6DO2q4%K)4@YrXyYa%jqhhy*ASB;=f*>L$>HXF(M@SDLx|^wZ(X`#TKZh` zeU5epkx?B7x2ND{$4mU&1ozN2bn;$2kq&@7dHc&w)_XNP#w&ir{Vg&;R1ufGejf*z zXD_Tz>#&K1HKqn+$AS7wxjIfftyd{aUWG=%PakXs-&1Zv3p-xQrQs*_fOI|;OhLDk z-;S7e{hbL7jf@&7nC`ZHIFfz)t#*xNslG8uGRU#hIWD$Vs_t;q+$6)je=}}+m`3+# z7F!c6z%O2J!^2;Rb=k2T$x?c0JnK~b!fX4sb8EL=-B*tCimLv)nhZ{w_2~a=%jsgU zS|bDWEG%|N;%-XJyr7*<*|ZT-ZTOXssE8deP8pmLZLX+$&`Hh4dHeqFs_A*^tHN%< zvVxc3mZ~eQspC;yiECm6Tvg5$YAyTeaj*`Rc|W4?x*_;UusA z@6rT5NhOQUb|@Y>P9e&ZLrda6NT2vRIhR%*iVW42A!IAC{tj7|d0mj1=)6{Bo8?y0r8?jIvmtxzHM56I<7%#7mClYEJ9Ygsc6(-wp3riK zPmt+@{yPAd*Jy>3Sxw=c_t=_J_HPB%wIgHQaf2 z(fs}!w2a}R;(uBS@Smlqs?Ux&X17D_lxWfm}$0n1`?57HvMa&<}-JC914F*HCWqD^MKKOfsFAdiReL%jX zs=&c08$o@EQl(utM3wS@JMz5WO%uGm>z%J5(KIhaEF9FwDA{DjNF1F$<}N*tH54-% z(}&pxf@^g5vyIPod9Y~V&GwXUcGE98KN!iGluO61J@HK5R(L8}Q9%}&N&s*eK^{aU zKPB#}NhJo?EO(u1*zqF>Fhk5+eyAi7l)+E$gOQg>M$3n#=!X$nJvAt2G%?oodzuXL7r_Q1Nu<3GcJM{um16@bSGjP65Ng7TaR z|Lk`I3=t9hRSA*td=y~iiNn~{3^q`I>2wp|c@l{{b<$80cUk{D}`PE(wmImYP?X82X z&(3bqP&%P6MXCNQy74B{cnC^`>%Mqk4B3p**}4>l#vwTfl$pIAO-v&78hPSbOix9@ zegsoQB9;1ec+z(-aDfnz;l{4g!&orRGq0CpAJn)wb$oe2TF$bf%Vd5B8DwRqCV&3X zMk7YV4Kr-a>+9hiVs=}A&sw&%VAH=|@9#;o!by-c&!U0t+-DzL#B;Vg>d&jD^Lu5V zUot$kNMaY)eSY(RXi)KN-|%e`0etcSJdAnljIx{%B{mZNgy@z(Dn_z4nzd=2?@5Uy zky$z;col=j`$Kfno*4;1x?&FG7wACIa4dZhn3&e*ZKn1LC0sSoI%W# zg+L#)i5@v+#STb6SY6MH>MQztCY0`d=zTmsdfRl~_)tI4MWJ*^m2T^qaDtM#`-9RH zFv`REmE9A^Zs6refYK8&f(h5P@H2eE+nUov4Ss5pcm?9Ueo6`aG}j>Gl`H>ef8OAHgEdO`Y{q$K@0yL z28wBQZ$e;x>Uuf-w8Ho1{nW?T-Y!tjL{c?QP3aUUZ<|8fG~}WUmO^N2$()o{8IV9* z!xFw10>(q$rP;UT0Q1MVD{`}iIj!_WZIZQLq&%YpV~?(GjyjItnAp6AvbEiGZ_rYZ5MUR>#$E93f_W>=cCr}L9j4O-*nqdT~!iG>sDYW<+fLNq_J z@;!(+=r6W=*n&|kM$v!N{zr2NoHbCeCqz7Som|DK@((ZecP}-SUiapvtgO;zVVk^3 zUZZZ9 zoex6Pe8a5PSLr{(3{F>L%%|5iNS8fHuFe15eOPU^eRF#9G$>%J>rq~7LrbRRPDU6- zk2VfvoGSNN5~q|BMGwbbK z#K{Y^-EjKKf6(eOQRvK$+Zzxx_3TF}yl2?{8)^+G{Z;y;^qL6iU0{}o1Mq(rQ^y0q zCBAgxo5dBr(9NewJADOzUdi}^Z~Zqmd=TPw3gzQ7Q(4QB7QsVLJ1qfFelYehO>_M{ zzFfM*dCCI-@s^0RC;I-am=<+t^aKA-(W@U*(}DlZ0&pCNwZ$8xa9<8UV5Mm@FPHoR z-&yQ(K~_-h1sGD_(x(q97>LR#l|>&}AOd)DcUR}SmF=uxN9s3p)!0vP)036Y3yHBc zEJymSx`E1Kee2yO*B70m#nXvApTmx3DC*H2isWjc;@Oo93SbouKt?2%io(pFyxhyH8hiM7C_D)bfR zr^JTK9~BN4+15Dj4$_T|_J zNxT3S%rNW%upaqYkc<|TU2DZ1kZ;PSaP7^HojLL6*PC!@&AXb z|Bh$->;K2`BoYaV*jg+0URAY;h`skHY8FKst2RN9s#U9~y{l$RjiUByt+qy6)TXGS z6s_`mzFzO^x<0?#_jdm9kKD-f`8bbrU+0`tpxMc(C3wH%)SP?urYrlyf47fXJ!5o0 zBZ%mlEUiWT0*h_0DOF<%k*^S``u5(gkZm&k&sSOe33|t`St#E&V^k4L87_&BtM59# zs8N>pTu{us$^qr{OypCDF`K#Bn|6-zOzSGwWBg3}>>1avVK=f|uclgNe3Hm}yENOu zBtb^fg@NucWQt4BlIG`Rw{BT@oJPSb6Y6bHs>W# z4Yyr$&22pkaXoYNkrjJ2gHLVebk*^byWq-E?X<&iQjd2zEb0BvK3D4&opZ~%za{^m#YCs81Tv2Jj5A#Xv7c1a%26mWdE!LKsH>@YI9$G} z=5wdf%!L=lV6Q<)GND(DAp&_!lsyTmeWCndu>Y{}A9@`nl2BxZ=0KJp^~e)Poo9{? zaIRZ1^$k#HVNO-1>;Bb5`{*d_eK2$=S?(0~<=dg>@c-(poYBEyYgH&;KwkE4R1;&+ z&J&fF6PE`kO*Y3pY4-DXgQ6%$U9S`swfPLT+CE}9({B2`nnWS$jPrf{U}RBi-D!|a z)yfLhilTOE6}z1ZlI{P~BS?7U7bom75e@vBSyJ#VAS;%=#RBl{Hq8nerd;~{4A~5W4g~q-xcbo zpFJOx?h0{JjhW`Lb8xw{Ik(rc=&IH-Klq*L?=2np(dd1!5N0=D-LZw7V^83lh~Njo-@lK>(~_q&<<;c1 zO}LJ~MUAJ9KW4IQjX4UISCJpUxoNY#r~3F9E}#0FDm^6M+#HZxmSPNGiNon^2p%@z z*bM}*AOx4^gP~s<*rGHO&sRvFISl_&vrn&MK6#1^`4opExRzCvM_+3OyELy3#m{pj%L+YKgY2V?<|Nvvw!z&kP7=;y8mjB zKf&>qw=+w@p&d)}7wkjU?MLeWQMFGacRR^4YOSp++9nkHS$$Rwg46TH2=C|xRjV6| z(CII}ncLY|nULT8BIekv=J{;Pq>-(d!QzV;1%JjXf#*;#nC6kqOgqEk68 zU5*O8oT@N1H=Z8|D|aAkf6E@jc{kNOxh4RNS;~217u{zTebAwL@XCh{C0p}MIJneKIJbgENuzYGkx7=glp?RSMS64G$s) z6%Y4ZS6&dguvSDTN^cte|OtKlKIwGG}qo> zPKCbXIvMyrKLkx(%vT4cLy3O0s4gR!d)T^~IN!VZp8km((sAkR(XpL6v1y1^U z^`}#R;oT{WJbw>ys7IRN-LGuGbX|`}+d)$vjB18>OBESJvy>VbgZEfO(z4CG#jUV$ zPUT=T%m&bnaFMYAquUX5NUZAv98?ypF|(^65nZ?Yl9U;&R`|ml50YSGWfQT>JP8NU zHTZX9!*I(zntKfB?W_QB^cPO?b~cnh7}DHT^n|9;9YXl{d!;qQ_HT$H>OBM+T26LW zz?uf0J~Y`WzPU!Ix?|v(EP}l|342|@h!PmkRz`4{7L$mf=aPN{&&X!$cIspUpX!Qo zRJDPf;+2@t-_!y!{j?ACKf6s#_dUCI^Xt5jX7aTq1N+F;4GJFypy2**>@zfVqqU(d zn-j5-VIL;qHIv6z@{1uWI)ADm;^)7M)JzT#Ih@RiVE( z$jx=Xq;KRP;ZUVp`MQ%pEY5ln$0dp(l}SxiQv~3{b?j_&Y)~%$)v~u~wYG zm2(d}JTm332u}&g{3x|wHn*6OtJhz&+i)?Gbw6&&VJM9!mjc=`k{;TZgJkUXrIC2r z_tdJdk-3ek@Ng=4I-LP!8-_Dj=gydSZ01>JIFNZk;waD@)JnubJu2%$eV!j8SDpz z(Dqqiy_bOMN)46kp8TVt^rpN+q*)3(_2ur@H=^$rE^NOhuQJ?;5-JTPhXu((X4z`mTI9cCZB<>h>lr6aH?-=zs&NsV&X-Hy=Bl$a-R zQUI4Ub3L(H&K{7`2^y@Oggy)>LxS$uq=HajZ1-c&uEpn|;wRbHslOC)DuM_^1m64_ z{a+g!{E~b(QT|OQC<+)|cjqn(rXFvjC_`z38q|BfY;O;N;ev(wX*3$)ewyc@%4YTms+ zUx9h-LVxFRr-g$^MA`jkoZOE&BmD-b+^yV(e$sOAXk=Ik)L`Ooh@oo|=tKl-qMPEH zzC#hv__z*^-|{z$93B|9p%8slqI%zT@$Rdt)nU}Q=f>a?@oE=2RFQgKAG2Y!%G9#1 zKY{dzm@G=H42GDojj^1vzD)&&h{Y~-BQ|}zV1-U!+ZcK$_t|rrP)gBx)P%~k9dJrK z7N2?&Llv-eYw)EBJMs$|62>-@+7A>-0yal5)^X`C+GJEvsksSctW;jsO8pH-8gYb^ zI+|z_)jlzR^gm|m(pn#h=f?zV&;I?ED#(o)cFlN5=Il3;KWSUS2T7vZzk8m zT>8<-1O7yhvNkwD^r-aDS?;Kx2DtY-&5(Q4|BV>$sOF(J^vb_;E#sPp*cT4ty=YJs zyy$so7Qd*1C>S$8%wkVH@a62))7>%{(Upvxf)>@`&5y2{iG~tOf6GrCfEG)ABh8+R z>fNJ7*~k}p2=0~$Zw-<}$2_fr!k-J>Tr*mWC`uwn@{~5ET2!SiUWTwQy`Gv==-7{*Gs1+D5I>u=q|(A(N(jtzlL8 zL8)M{>O0r*M`sT?Ms7ZZ5{Gt*{UTSEC&SV38`ak|0vht6#8vUXgZq@1%XYu#8|_9I ziID!xbN^RoUdH&jK>@hK&SKzs#A*EPu)g2?3;OtBfg9X1;xsE+2MuNxjhX{D?Tuy% zrF;d;YTa#3F5-dNQ70po7(Xqw%7+Fee0R*c18Qm%2S*2zfX?Cz=_r`kK!#mbBt%=d z)wg$GO?gka23+6U+XKYwpyii(L}N5x_ZkGo%M-85J6vZv+AiM^1G*3s%4jdP5`k17 zW0URipyTFIMxl#V=t`na(F0fPZ(pouUv zEJl=7I??yl!z7p5p{K=_yJImnTw00b$QY@FNL!Y0l?DHe8$9-p8HjO9q7bVYy-UbcpnhGE?=`>tTeu?sgjGKAc)g%L(1`GqY7 z`E;|-?XTJ=fK)&Q`+^_f-sQ&&2j1!`&Q2BuDzuDWVerVe`tU2_h>@@XLMT* zLOSnA*|uIT?l*FrE*gcd9O!ebeDcfOuX#v$p!(G)h736y^X;yUo7+0L_?&vjGZ;%B zfaad0WXcHgRb?HyWrMf6_U;=qp4>4wJU^f5cdzL1yqVbrd{PWLJ_)9R5MphKGBls>h9wwT5D+~#%O_^rGFVweE@ON+_7%p;ZEf)~ zQR`x8*(Xj%p|<`wv4Lh_Dl&2Se`;k2d-&#)W+AtpgP$zj?2`3(WdrZ4jccL4K9ixX zmf!4Q#O-3940G{UsaKd#fLu?kQHji#x8_YqXh6g2!kEeQcWa|hE!MxZX z`QDmfh??@44Lyo#jwA`v4-R$cO9^cQ0MXxI@(dlg@Jf##!sx;-2Ns_D@O?8CjR)B{V73~rf3D)?5 z=Uw_1O?X5F;&BM}0Y@h=wjx))EkUrVj*LQfibrviRU+dLPf;yBYQU6fQ`F{V`Hj`) zDlgH|9I;gYsG*(liUy@shu!;r=6Gkb2~JQx=L1Lkn*L4yK4}_z<^KN%Ay_ZzMjkBh zr@Z;COUu#4b06H=tHHv6x;{j%Gb-93#Ij znHf~mGG?~5@g-&dd|>wX^HPHYMdmePE=4oN_U8Is0WZ(`<+FxQB#@pcPk|F2BZ%G@ z9fo@}oasJd48D{d z^kmy$2)&{Y1N25)T_9c8og3;wC{L*~)bO#;l{10|yA_5r<>zn&Kx@K#1s&0cXYOu6M%V$}CsNCCDEM`l7rpWIt1CHM5 zZIHz&Do4>!MBkW(}QvLNR<_JI6pq{5rf&F)cxrzcX=PSd-o5J(Q z-YWf7pvRpyPJPBW6^}O0x152}R)tOHFGo2qMGQ96fyJUqrnw?*-K?j`dlNjhyi-eS z^uB#XpOx}x{V$sAr@iLpj}n;xiq6!Qm{wZ$vLKboB~v2huFq}D^~S5SFvQr{A21kl z0Img{kgaMXD93X&+UQHJoVzQ4qdhy~DQFEf@HpiGgyCUw__7(7+^ZqPlW0>P(UBN+ zEl+#JkwcGfCNzsYV(J$nLJ6OGs7CnRjt$hP5vLYi#u?d1(xMk?n}Z6w^l~kfsXE~ zQYcy|MvlCt=Zm_j0VB5QDBCJnQ%SAc4s(bS-sM$02<%Dn)op1-wlN>E8gUfI57 zYOYWXcZ8Own+z1uPfBZI^5HYy zd{lWw%?6LU(ua}k)-vS;>2p>C6oALG`(w}h0XV386Liy4%^~R@SpDnl{SN}kRF4u{ z7!?STF~7P1PM_+bZ-CzTq}Isf?yG%&^b=8IKU-26S0|2$~Dy=41`q};nIYPSFyd#J!MJXkckUW zqJdpc`Z0m7_Sq4i>rJ^DReXcoY=3c{f=?15gW$^CqC-o~N z%bK`~qt7Qxrecb+l}^--(Nj-3RvLbhu%6QYI-v21qpI}1H}F{N;;0v2rMb(8RyABGYj{RUZWFHYW^)~Q^%JdM04!IRHM zLnW$Xy@Glk9~h1AynQi>Gog&hdh?8yyd~L#^`-dK=JTg8VpccdE;9vrCmKJg@|{ol zW9IZ#aP(2U9dJ%drCSD(_P|j9jlhDjX5r#PpE{~QJ^U*ifzRk~UWd!ZHtl$#=}<0T zVF>+M)lAQ(zX&97r*y?5>O@faZ8a}+vtf+j4yQX81`+e&Ud#@J0EN?`j!c+PGcYpK zeVI{TE;KSLYWsHnk9e}_#&M=KaB&`0@XJRcb%(|e;n+j3As)KNpMuTsvRo&~Ikj9Sflk+7_)p}r>stoGG~q1+b!^yAoN!A7Yd|< zvJ=!cf$E8GG5DgNoh#-Mt;cfIAFB@<1Mpw(A%veigT`P;Xh>j0tp2ju{M%Mew2J_f zxQ#1Zu0095=mG9k(t0c8Xpk%_L^P-R##RF%Z9~mq0yN^x-8n1hnN3YGR@SBI}Ms-t02iA5P_%%E^1D2hN z%z>BYyAT>FBi}knnUPmAPw$O{AzrGT)BHz{!byni=iL62xxfGCE8SiG!nvcU76g(= zJu7tXWKu~Q>;uVf_hU}ac^p{P1^T(rj&qiFapN(fms|My6r}21-JYiRi|`RaZ5Q6Q zM;5jIR^X1w$0$ACk6C;fPq`|~(w!s)!6hR$=KaH&v7o7}{Ye@a#tH}S?UgL!M-MqFWZl@+#0zK> zy1AoMfwRC+$y=Q7s!gm^xv-Y&GNqtm&BGf2S(J64mL^BSccaRc#jJ}|es5dGZS)3z zSun@T3ZWRcOm^DktZy*&tvmH#By90kW~`-opKD|o-V)(Z>Dq@LUNZbMN+^()jt`^& zL6fB3vj}UWy&K7fjZJDrhWy${LS9i8a)O{AE~S#{W68STgQZ2$plW~nN1)|ZgKDL2 ztr)8yH}M6qPCxt?GaDAHJA>!FM<8wH!6 z>H>RLh2F~9q)COK9eU>eH=ENK0fFhSL(nP7kA*{;R|r5$eIfG%Q;|yqh<3vXAmBN?yI zlf?WXzw<`Cuh@F27)!kl$5yvkhch=5XW00b9}|0+|3ObI{MWLHZ@}cAU28mZND6I2 z6cdyS%E=srP8>^ZcwJ%eX-wUx{O5|Q{P0vrTK7r@yx-^QgRbNxmren6Hd{fmN(n#? z3-12Q19ZdZf!eq}ZCk}~xct$|`d+N_`$^s3ZlQWV)qd>-{QcQkaXk-#Y+@IS8R)$=P;yr0VhR($cI3Spd78Qu?TX+doI5Qi1fu%P# z-gt8D?|aHgg8wP-9awDJI2ceMUHtnWYC#|gKyIB&=ub^dYoE?aO>48ivl7!Qfn($O z(H9bMr50?^E7=+k{0FpaF%mL;x=0k?yVWQ{<{DuRiNZ@}f`jRHuZz)$R zBHxZ4Pa5Z~5 z1lIMAan1m^Gy_FFqyNE`LbW-ii5gtOuX34%v1EEXF!?@oIWp)j8Z)w{g_fr8tv*yu ztqOeufksEi+G!;&b;rf9^oq&rpgrPm%D48UXWY>i`n#`jV>l??{k_Ijzt$8#ehvAs zkyym0&Uk(pD+3CEmv6qlu9(S_X~Y%xw9#=Zj z3Dt|f1@}}*%b{k}B;p|bbB>)c8~XxvDP!DdWJP;XAj_7FxUk%bwdAZ{2qwJom2{@y z?dxame5MrfUy^{AzFxTae=q(1e_uMkFglN&ggKfJe}L0imjQf;!6hR2o(-P)%5~DB zW4{<+lEpe9WNEg*x2VU^Q!+G|Qz8o|TD}4M&;toW({O`y;pT7y(6nkG{fz*R1(xVD zcM={pEStpDv(MvSsj4be=_kub%50t2lr=tv)-~1#PXDT4)`)V`{`z)gr~AtjGr*jH z8F}CV!gC~r^8VL_~qKwR*x-k zk4-eRt>GO>bTbuQLQ8V|u8VSN*kReBimV8rP~v@k-NE{SShIEcMAGKYiZW{FmaQwcjs~vuL`^PgKylmN8 z298@N%kb4}0J2G~D3Yswj~GukdXNM^_Q&~sb&?9xwl7+AqC&LvXR2s`S4JZ>o+>-i zqi*J`qD2dIz$N$A%g7Ua^60|GYD3Wglm_LtQbS&CU<6_#1Hq6YlA|)K?2rkSUmgsF z%`p`eG`nXzAQO8-KVvRy?+ZC0v@8))BPIbN)~&pRyJK9=+heu|V2F>wC%~x;;(Tb7 z6|!2jq&(~yKlQ^nT4Cc5I{2h{fGd3#>4)y<43Dy8^k;-JpoafWzV}duH7;vlyaRmEMuUnR*O=Bs ztC5NKgrAtTBfX>OrhlQ(xWaV;GI3Qf#0QKKmsniuji#XKic^q0q1t%81d1#HG{DkV zGc1aNXT5J;g20g_3WiQpKYYQEdjqTo;Mgc5hlQw9l+M7`2>;asM)fefrZqz%(6$XA zSZxeRytMWfY=r*_FdepQ*Lw2j+AG5_0B_G9VW*}LeeTb%Yg8(**+@t2-BrIw_`Mmi1&%tMG#3=ahzp3S z6G`vIzNSESNI5}3I*7PMtixj>6<)vzT0{n9V{M(O1Z1-T2@)((@j9k5R27$`R9pfU zD%$RDe+^n*At`2XWy?aLk@fEDWk-xWCcR?~MZag(5}8XgG9+?dD zzBH7xc!g+DO1!a9-&VGVkrI5u<-MtE&$=A+EP1db`5efxt^TFwCwt; zo^hNyXuIIc%al+S0nBQ&qg{JZERX+0v09tGDOx@QX=2cL?#Ec# zWHW{5^VGO<4PUK3N6Lbd4Dlbt!OQeP-fT|8`<7N>jV?qsQ!YqZdW{(8Td`-`_-lv2NVI=OXOWVxk#^eBU zLviMvt^9kji*03GRwv3{d|#Bg?i2m0nJex(7An7FocTOQN>fIE@BG&+VdvtwWq@|= z&ST+AO@J4kyET2I(*QfA%TaxF`hFlD`#NHssV~PclIzAqzTM?+3rBU_>MA6r#A)N+ z)SG?Y@SaS><2T+S^mUor^)(X#c7_y4CMdF%-sJ8!Mz%>Ek6x}cPI&p?75z1xijN>{ ztdBPtRsYSs+jtO)hSPQg&iVmDAQ9r8?I5YfLytg9^Hhq>-bLSpltJ6*gus|blw!(d zU zTqhJk7TI6=r>KZ|qB?8Dl< zyOq&*?SOB?Zv8fwr0dVF4iN+O(x1FcoO44KyC3rT+dyDT-$kBiXTROb&Gf0LY%$k>xA!1W z`(iT$w;N-;0-=evR|he;X1pFz)pY#r^9}@S%t}j*z?i#9@s!RZQ9_+{u?~7-Eto%h z=E842seB=0qC+hy1a|xkcij6`KjoJ49c2Y18{o@=<;u?VD3d|iq7QHWEYHUNN)oR@?Sh&V80pdikZQ!@DrIa2-|+^A&(zm(a* zz63P_8)6Y~9qOU|=9rooeur;J9C=+Uj-m>9uc5{HMffa#RoF41%C!1Xz( zN0(5!(_BM0txSXJzURluD)gK24yWI1km@Ice)3TrK*!X%FD=eklJyHH!3i?25A4W! z)Axc12VLDz^qdWErIyyW+CNw+e~!VK7a8-^e`tB$dG)c7`l{}#bog98fb{Rr>;v~# zBrRwC2+tQ^W&VFbg$(PZ3zZN}V41mCd-WsoIhCyeZfQk-&J)s4>Clop^U0d6TakZR z#mfhwb^-rv(&jM2?_JE{0p@ykc(Oqh52AwFB0nz#T( z*iwY{YVL88`0IYgFqz-BYO|;8(@|{)Eu*%($%C%L6^-}7w>FgbfZT~9*aU70*oSp; z4!?(`yJ?q>lr$_4IJx}4f2&1ufIj5)hYLCG@}bv`(mHy@_v6shf#7JD3|nEHW%1-= zSks|>_ASbR^O6Wy?-hdrHo2D~J1uqFA9mW!B7TXP<5ASu%cgl`VfTghBM9LZ*Z-!e zWg6@kH;`Uj6FvHPsIrwZEOc}uP_GR>d{UI9U_*Z9tYaXTRdd3HPQ8}fVTEzx^tI?h zD8xH_oDlqjPt&K>Fupw}g0@n>66=Vgq(-c2(FzP-N+1VR6SYFoyJG02VuYoyBt+4T z$}L*F@s2@Y3Y>{Uw#hL<4u)W2$y4ijx}?i$PboL;N{iBAke$puN=(K`jBgR$UxFn0 zH55dtjj#K&*6ug}xs4D~!^?k-yu|-A@@{gRm_Uj12PRw3kLYs>#kaov+(z4>R1u_Z z-@`iXjQ8Lxtf3GZa<%~aOCo5>*23GXjUmL#`7=rH>NOuFuXYe;k$#L2rXQ4dBXv#p zCYsy|Qt@qX>Dvm)iLASP6{O66?^7M8JP(%pY7Vy$3f;;K##L~(MB$-Ajw z2=8(j;_V;QJKk4~MI>48`{P1xJL&gocgEQwT_WmqfSyiC0_J{dI=JEND=AX*aN|ae zUuPLM!zHBRJ*>O}$pvs>a|u2X8s1+Y1iQu9k+whbGe8F@F2nbLm~B3wng);d&12)O z*K1k5JZI@Wu0L;C{v3H+|KW=d`~PMd?vp6)lH>aV30G)Re_Y>E@Oxh!<($_-JzftE zXU#Adce|MRMMcr6CtxuGiOQG~s(l5#&)FYeklIpuxNgU5S+)p~LxudUHh;#0|61r`58hbBB zJeXF;a6YMpgp-LeTJi|Lv{WIHwKmNKEw#0td`*+U<>AZ%Ucp(JF(j`2ejFl9j=b}% z@-jsL&D!9``j4mE{pzgM&;nnyCoM-Xq`U*@6tz7ycni?;WRUUgJTpCKxuzYtw+i~FWCQ#2cM=&a&0VV=At&@Zl9RW%!RYf9I zi5U76frC{0cV~3H1U;ViMo=8pEM^gOj!Mw!8Uv56dxC|NQm;5&drLlr8f+A)CpLi| z3;XMb4Z|Cm{&6i2RrE%sqStiuMo<`5OH%i1377u5*WXJ#_50BQjUO`qiz%0toMi+% z$&HXIp{SvJ7Tk14m4@}PHg2tty}nG)1%{ahefx1)`OzTnzt|m1VQWtzqJoPIT1vU} zYu&s&ThPCsHw_-eO0wL`7K=Y4mo4~5Dh2>}S{TKnhYmhE>o;jQhT~&f`nDZ(1a#&5 zD8U$$NcU<8Y3=i+p{^j+H%}?vR^IU@T;5neo22uOB9Ug4PT-jOjG#(*!6Om_%|y_> zJ4gg%!wYQSePgP-q#L56JIiPsZ@VA2#-`prk9?^A#(1Fy4Dl=2@V7;nOgX5zo9gb8 zwwU7Sd~muS_ohzbcs`O1;dk|ruPuB3*U?TJ477QULkKQ%KOE-%FhNk0Bru77jR$-)zq8Ww5xyF|iP!!yqo0>^0mjaMnDPq(xvum9v z!e}FKX+_^vETRJg0Qvn=$7~JSIw1{iJMe^YuUTjbmQNY&qldy@;+wsHXW%i zn%FF(6{gxWHZP9m+dkZRz0D`Rseg|}mL}A{urD9zQ?_Ku(75Z-4kOTXJ%5S?rlQWRa776-rZ%SFm_Fud|1ZY)t$8~s zFsA?J(O$Y}CQEd}?Ls9>shiKJz5nHCF;W1zUQ!hPQ}(Y4>-986K?t47kSoMIJ0>vR zm`j0?)FI7VZu|H}%3IFwP7uOt`d6^3 zj3=Mp{IlSWDi?jLRz~kQ?39uSaxZDa^}w^sE#6~kaYNt*7<>aSu=H?li4C=0U<~e! zMQFfHi+S|#y9T40@?lofGz?l|4?(KTq>GfTRFE25XZ>KoLo|?q+i2y^DnF!)ZcK+m z-(>ASdSRfd!`M#Sn}-fm07HJjb$su@AL{Q%Y^B0%ApJ$t)G_6Cw5Fru4|+>v04e*m z;f?^>VdH;_TURgR=@*WWS}2kF-WfdtDK0C|&Rn~GJaD`n9h8i-W$4BjFbj93=P;ZI zeq+xRY!PzM(aI6zum1#?VQbMg6hca3^>8c#9VtUi+(>&N^-M&rQ*(_`5kC&k;PNvCqkoujU8`rgv-lPd)vOlD>lDu2CzDK1Pr|Y-8m~vP^PkGaw2OIb<)?2Q5M<$ zavcKo4v^56UywD){1MbA;BU2-82naGu%5(O*cW=q8J#a6VU4OR(dYV!E zUa5Ve+OcSlnjzwkYsDl<+2^KlR%H3hp)DO{Zciq5R$IFsxJ!85w-DeETSNRu1<-bM z48Xt8sXPMAAx{hV=$Zd_TbZA%6gNe-TM))RjkUnUW0}Ik9Iwo zf`l@n`tQGJ*_w0}5QCD@4h9bP+?6In4tyC>djRn=1X_jN^4aGi9&9bq+v>AhFq#KK zK$=$pG520$1^FWFLluD|%+9@A+WfOZ#ce{xPO6YtRJ*uPG$_haE&iPPa|}x@1(g~& zx-5O4>y#D)b(PEKpnSl9Lb5OE_(S7wXdX(R56h*c;CS{6nmOjnsJ)kTa6H$zU-=gF zYXx>e4WsEAf63*MEq=3oh`Z7&;g76%0-#_#hYvAE#>3a)B`2Dvt#SKzt+0@Kgw2~a zz6y)k`*y&E3_jRdGWC4338g7SK71+t_uC)$%M?*AZ@i(z=a&P*y(*DkBAfoKro=b{ zfnU}&OnqMl@ZJl34@7--1%M%_Oa5r$P`5dU>Ju_M+!Oz26%#cLJ(Me{%(IbDj5@#CI( zw~WDYrQfRTp zcmG9}b8rH#8-l*TWh70QO#GqGZcL`dn!G5w;akyY9#sR^^#JjKvGAGLJA>Sw6G6Bj*!^J8E4kQM3Cwg9#yH;zM2U! z^j10{v4F<4s+w&$KDXoA=MzCv-{P0%_=cT!7|n`e8Snom5&qUOisEE}wexEiPE=O+GEdQ5qsOnP!O+wI)_~tibm#qD9xuYCyu>@51q>K$YlMubnNIN5H*&^_XM_- zUEV>cCP4`7eNturg0rpv1803h;V&Tln1@nF(bW!#&nE=>x&dyTz^tuinIo8w@Xx%> zOpO5I4c2;i+1H%tCCpr=E(<4<{bmRd;R|c(H^FnEb+n>{|JY&=0+9g?uKvhlxfLko zu$rfb4{)1|4y50^4{0iPhc9Dk zzxD76@*bxbFEw`0dNfd784;}OGW1jE&9g8h`5Pdm40yf!6!$&APGeAFe`zr%`^WcB zwhGwvi21ucBqfchL<$_|+lRahMWqtv9qP9ZFPa*ztUq)Zn^)4HK?%J6w$Kl!CtF4D zbIB`?$kDK;MWFlG{WRq3{R-BcLkyn;0z^)I{OJ-lNWD9%qq2==mgd_0&mL;;xl&<~ zxO3JOIqRv)4u7&km#ZAt*e7|48AAY_b!Xc7Qux&^59KOQAX$nqK*trLO({#{_TzYi@M;<L(H`4Fjv!2hiKRP_P9Q0k>HYo+YIzWip}*xl%t~C9h$x%M5yP zn#alPDJ+UcDc4Za`#Rpv67&UvK3wRU?e5r&34#p{@qt|r5j@m?9mzXLZ*NSczaHsc znDGs0^vsehk%~x(Lv4(w z*y3^Q{mQdF9!xB)r2J0bLlc{f>(<~mBD6@vqd6SizB*GeMGMPmI*HRs9+f~fdfW1mJoCXAA^p(m_ccS8Af7sxhC`LM{x>08c7Q(l(}GfFj( zU2E61>zQ6w%)+rvK!SUxFjTjgFYDDi2mx*vqcplF{Ux-<2CQ90U;8TUEHs0R z2HggfO3s!D4ja+4)O zY0jNzmJjOO`R#2j8Iu1d^dXS`t4lMU$qf3Fcc?O%7z}3Eb$^?Ds+rTX_ScnL860LN^_RW(81VRv>H< zP--VC@shq^W*ScBo(EsyC;8kyg+EJyd{z%w3>)*KJ7!DSt*kOQm4X(51ePC-cQj$N zQ~ae*cGdA!$cW;;Y%C$QrdV#?ePzq0N{>J6jsn<8BN)XeH4@3MdBcE;$HG62Op9Wv7;%gvoL)GubT%O^!)Pul}E_(`a#X0Pvn@DVOSZ|)^Puuj#dQvX#8 zo;7#>jXX-m&DLMM=ny(C_p{RY9JvZOyWukpqfB25m7k&2cCH8@X2qn{OWCpfEf?*6eeO{42U6nSM(Vg+zVyl`+HYwIWrW zum=yOTMDjat%xv*Lb<<#RZn5-vOA7&bj`%a z0Ry+tSH5{6lyz}yMS8S);DRtPsY8xbh}(DksFvGA&qq~@TRWleG+8h)FEg56{P63h zO$}&_3ZL-^C=6{_{RP2VwAl1&Cn2r`DQ1H>Q{cWzz>h#RI#gUV)>=pU^MVcwcy>vLCmKO~gUJR+sV6b73>^NS}nlra8lL0<0maq#v z1yQ>H?cu6?4>`H$`4Jm=x?Z^O%M#eYRmVn-?0fxE^ATsIfjz=zk2s^LGYSXS#Wx5P z#2}UdF}FtW%tbBetMoj-zzwpVx#w``GE)XGKaI>;Ba$`yDX5rUpJ3y^Zav3DpUy%7 zjuyIFBlzYM6O{UfVVYsDXC8@U^aSF{lI+~raPD{~j;qkm+YImUy*DQ1v(0^`pZM-O z_#G+8S1Zu0inwKV$*)`B&5B$}t$4q5|4e6uA*A}GR)CjZP%_l<>J?)_(8|O8 z{XXjW*w|R3=Yr{-w5OSG*CnK-84U7dX$t=Cy2b0L-xl*5QRxUh696KUk20wSm7b?4 z(kmZs14rMQ8qUFit(T9@S1LkOwR3<+otY`h?rRRhuz6j3m~D|s0b#--i}~>$%^Q07 z{~uLv9TjD`Mt#pPzzjV!C^0nB21ql6NJvU5tw?uE4Lu+N(kUS&AfdEK3xY^0q0|t9 z(%tago_C$|zTdyD#ahDL^E~&puf2cUwyJEi4xbS7g8-2nWX0wtqqN~d!mqJ2@slR_ z-M1O-${0?QYo1U=h@LPAYB4>9n!z7!T5h7(AYWxYGG9>j+3#sx2EPI!EPo#n0+CN2 zkkEyHs!R6_JaU^#9?x zfv%|aH;73JKr;Jj7ygP)yNbv47=Gf;Zd}Y7{eKq$)xspwI}@x#6s9jb~7+3yig|;$^ei7juH3f182t;eX1*ViIU$)5LFzb9M{7 zxVJ6Zz%F*HvTDEoAWwvh)`#Zn{3Fr*(GeFih2XGYuS!t%)5D|e+JQj@zZkTvi5 zGAEdx3F#k9c|7EX1Kjuv!a=ziRt^(#Fd;R?tpx5nW0AKkHQUu>!3cz3wiGm<8BMku zj>fnO6G^jGkfzLL;-{2e+<7k#`;y@(_cqXUJ%&p3fQWwMxhfCjv*iFtCOzFur$ZzO zk!TmQ9J-|xTFH)VMlwcN&_w#7TvgX)%N`6bg(OS~NQ`lO6S3u9>xeL04OaDLHd{&a z-v6nUYerH|#i9JGCapyvU@(*rNFpye#gs_YnA@D`+}Sk(y--^enlvCYZb)nv*J(0l^#QwC&F} z=<-V^BdD}vjw-S<_06Ix$RBs>c6-@+MU&InLp}M{@rK~;3fcpqikZdziHlJ1ClS_g z4p6<*tCr+Y+rM<^j3C}-^+oROUlYQ=3ON%OUXb*k{}4a$=>z<3fm83RYJQm-jn0s4 z`i`*45-FZgxC9T?pBflrC7ZG%0?9OxaQ!Qv1(u%DKBfW2ya{9N|Ks%KOk1~<&yQ6c z%P1uvuMohSwH48?C#VPfN5AOziI#H1$!`L9|_bTOrC zBzK9@c@=-8KDN@TjJXIJPZaqv_*S7u8-vQH!u&)LEw_UBTC#4Uf;~(hLS4iPAyk4s zf}~yCnZzq5Z$v^R*EX&})}DUu_(6nt19N&#WB;=?OtAfhk>Zg9aMJT0W67{c$U7Jq zhaf^8J%96gc+jXjPn86o-Qv^XIs-CL>;v6Y3%}a&_l;jrH}k8g^m6=@lizZj{RxJ6 z9KC+du2!AOFcOURX5oGl3nTy^$N%b#-YDSqWh|kxX}YoOJhc~x7K;lp0eg7Gf~|f9~Wfeb)U7+Kax~A8JKEj zF5TDlQ!aqy_xhqiUQK8*%;C;8+{4^cJ>Romb1a*!XPd4wUK0^6m%m-=-K?w>u57;b z)3(wY7#Iw!;Zh$v6WzXjoBP4~!4oB@RaI*9?iL8@xvRau^SCd{fk)laTes5nVOhAF z$zp|?0ke=%j-xx;7oq$UW<~kKJnwz+`-kP0nzt~mGUcL#_iN>ZRn*7JZHl{oPVF{B zLPq01S}GEy9}1%(#gd>S0}8o!aWwezGaPQNLgTPIc6mV|8nE@+2 zf|=*#wQ&4{ifFe7)1(LXdckF8wsY>X>AbH**rt(`ypuMX#)F@pqaoea(PHsR~NgY-%>?_NV|s_pV1J{Qgv^r5z$UKE7$ZLHw(m6bK^Z!~4r8 z+6D@$QxIr@y?#5AV_QSQ#UTvGU6h**r_E(0QNz%Z!N1b2IMn(2-d0tPsi8B7I~pxi!$Tu11(L4VxX@ol z%*{a|&%5J3a7NL{t;t0H{9Z|CN>pTJ)DmX?PCJ#3TIr-pGd>Eh5W>&guEBq~&UQBRg3GibNKRKlB^5NT_aLE#0#f0pRGs=) z%UzA6sa->-E#~LMw7Ir!e}G?J6T7~C;s*9Ejie7;P%<)6PY?^@=xC?YRv>VMh=14HXiEw@gRf4nWTf|7eX4@L04weavZO@#gWBl#re^}#{ZCZJxxLF z%4cUGbR&ulrzH$^jP~NJz@ubWEhzlxg~AMDP`FugUuErQG2QBfEjOT@R_p}bjU1yr z&kPm4>~(U(v&oS}@7;LK8>nuMjN*0IDu{?v=BW{Gl)B$b$bQR9f$cg-T@C?uh$0Mnmx0C^a|Uwd38Et()WxJboY{NH}n>0`eZlnf{p_28Fb zM_uTB?zcfu=Ps{NxX57Oyz+Wi$oxIe%9_Z%;i)GZIWoToUs9Rnz0|wDtmMkFqC6Op zebw%Fehtv38l%=GiGPu8g~099$PtnOyg4wucx@7m6gWQc+}fDHMV7+mXWt)gV|Y{O z_pejmrIJSVfX856wq;+4`aP}MingZ5Q5O;Ws5R8oRfSChkM6+UP#+ zZne=s*?ko9%3QwL^r+$VQVW!$7qqxyQ(I?s;~qZx068f|Ie&nRl%;%2NobyMC=1S^ zLy;yF(WWEVgb~8MP5PMjN(kp>q$k_}GC>;*w`@CqYXO0W(tA+&VEc+C??Lv`Q= zlvx@MUnu;8SUt8dm50`H;Zd>mWUdEP^r&q|l|%v4W{A$~N7Q70|2!15TeqDy1AGUD zB%khMWVI=#y6ur0%^q3G=K2@!TeKyx8e}~1p*}vU|GQR zzh!}|gaFV^*l(A+j@l4;EH9%(YU@ExZWa#QS}i;lh%;JL;?l3%&)+3FSLrJT>cDTQ zp@~+e;vrCq&b5%P%nK5vozN0HEJ|#fot`b5`sQ=+xWHwr2%E0l-#lm+{S9PW0`N^~ zgh20i?m;)d1WHI*^nFbVDV+8j;iP6QJTn@E`-MF_paYzq6p0yrH}<553VkvNI7Qae1pSfrH|+Ic4k7AbY4PxL zFfC6nzompZX`Q?=yV4Kv9b`svCu+w6ozdI0=xumdD%PDBjfANNJTc^w)K3Hm5H)Sw za|FUg7;WJZPB2YA<`%!6e+nyvI5z___<%X4dF#zkUEN#oEm?rg$5glf3>l9gM30Bi z%@=(jv}lSF#lv^_^!m!~Z;YH7MJkB`$2`gQ0<=Sx1*zBH4#Uf5uNuwrRYQi>Y0(9||>W?rxep z9{|d+zdD1mGS36OuEsw<=9nvkPWaE4{~!Hz@(11-6^i^u{12~rm(g3)O$iUh?YxyI z1$07N3?$q#*JF3(Z8b97C~hv(+CdB6!cv;~f?uiLXsZ7KW8XRc-e3g{@F==wPtq4M z2Sh&vDRCRpnvFEf_2Hs?=tJRqvy^pg(y+nqx)~YNybOZEl|_lL5t$>JPOAA~OB#Dq zP@z3!b-m+7_H{dZGNC1KA6${Iz)9NG$QOPdpPW63-($x|JnxLUCF#|-nwtH<`urI$ zI++XI9x|`E%{m@l{N9@a`?qC$^;-Yy4X|d%mFQiD9_hWz(G?Zk0c;irkFs!ijuEf` z8K4lHBr0)oauh$C+_wsGaPzvEV%)Fz_Je`amq zdVz`NC7VWYa$TJ$weV)!Rgn6lKNCMUypPM7{#nFNks_6!t+gIg>Rm zJmwP0UjEaqKq+Fnih@WcKi+0t%=hz7z>G7YlI9bla=nMmn-y(uxeLX!Anhd)ec)zr^LE-i0#~~~e(#gcJ)xZ;6Y}7=z$!46K9#VZ^DAP^vLt#Pk`1Z+B*uj{ z+Hm+pU`hAAMXV^87d?Li*LQ;w)d)|h>B<{d9z) zAn!*b{*L|Fk7nZmlJAWvXXzxp&+|Wn2(TS~(W(3T4Gy3rFvoSm82{C~jqGR*TW&we+fGh6Ozw61>fg*0NZO>pQ62I#_HN#!jMAWQy(?5d}$}UxvZP0+} zQvLdkRuMN9P0u3*+-Q1{LG7KQiCUH99wxXys+&>uxE6*lBKS}AFh){$?am`PlYG&` zxVzkv*k+W7ku&03lk()fs4j*b6iFM>^U;Fp#oMKzY7xCsGMXM=lF1n6c`$y7lRHQT zeDO=z<5)qjclnRb+=FjtHC<4>yd#dxjw0;c+02fr+e0140cW7BD|;id1uMZUj0n%Q z#8Z=r1mRH<28Jj15*h1q_=cHrDmpn-joQj)LY`QX;M=J9c_91D`m?Cj@6*7u?YlHi1UrFiH)i-Nteim42EYk?gYDNJ5{1zWN;Kzf{iQ05Ac;VOeZu>vje1W$; ztwDR^`&2?(Yr)gx>nun+fSZ8~s~>8ew-{u9(uyd?033iQ4_m?zqu3=)>^rHX7U)AeIQeJ=5N6* z7@_9Br@Ii73Z4Pv^Kl4dvl0tJniWas-W;Kn3~?rchN?0AY)Bw|cjH?$Kf6#2=bJ({ zrFpUtZ74(hx2lN#r>0wJ;T8#$Kl_yQ{lLCEBduuRK%a<;8t++wfgKh}KD^ftyIF}O zt%@A+Q(H$=ZX>6=xtTvkzlDw+#92I%UXdg zgxivxuqO7LwYx0w*)ve|y1fqL?RaTd)6t~;Tv$-HI@c8w&QbIkZrTh|E&UPa^IYG8 zk*t$G&YZd&zi|>Z?*2%S=I}+ghm}Q9L9N!S`yKhtLzOS1bf8v`Lc1M*XJ6v^sTLl; z^h1|EG7Q$Y?%8@wlJhZTB_=^Pg?ej--}!C86P}5mxXnGRlzyy=PW>%+ zPp=0`#lZ*g{yq9ANOeMnv2s^?auna5-ELW+twOZztGCfqemeXqB9_YgaVaEEdIe~a zRgYV-pcwiGel%zwcluh=K4$6WHX=zR^l4{yD_}*cE5~OR+&1s_Ls>y=wSY>D`8t2( zW0eIZu2^c_68*>|UHy|R{(RT9V{2-ZE5GBr8x-9q8^FV`_!_|r>*-cJEWGVW?ILN+ zf(A(>N0oSt$kog!b28KrzVh`38YL#at>A7XHGlfWONI2x{3(7w6c5oTQh>2&;N@;E zID*DX(FJ@aD3OY)DC1loD%7sEPF~sfOGSRX>DSO(el3)2_0*^@FrDOgDk>7h!xGCO z>AI<_8Qk?_Ql+FeT5R+)_iLKQIdi{XW9tMLi)pI=y%QU==_cbLuTr)XGN>q-cdFGOo4fupf(R&Ds~axO`n>b_HI*QFcyqW}ZOUHRFRIf`!1d$1qKy)!^Jbf@ zQ@qwVLR?b?u;sfF**E^6e)l1e%RGO3Z05yg{5=8a61q^9HbZH)P0~hHc0c}xpLpfB=`-(SCG>1MAkkFw@5HeFe>``qQvakRoBhkyT;ZCrJ*p38q(*-QP3 zTWj51tGK0n`o>i_<@V~pCpN2=gLf`uTYf|S1rVIF#Qb80R={+bh4e8`aLZ*HQwn5f z+OJ(9FK@%nvO{W4y#) zxiinrWtn~Q%*vIJVJ>MD@ioDtFcw8sOS*e6lb1O6@q`i66oSIx-3U!62wD0*q$z(V z;?}K&Z?K=Wjhr&W%-`02;OLx1jaR8rx70S;9;CXKV;3<q}m$R4;F+G$gb2MmTHG8H=mXy$t>QgC;XsjR7%qop?k{)j1Try zbk|^sj$t$M5f#)F2-iZ(-Fu-X4`< zFIzsqa|_RyprPc#i3xdZJeo=5ZSAmRX(Nf{0MvS@#6Plu-pYZ!pa19=N^tL#j>?;F z5ocp#uzsihp2F<49R-#n?1zY;nc`*Cz~VQI4bMCOt z=AI1Lbg}{hK%`I0A~SAvKH8P z)~3;wTdaRpgKx3)yDkvnz*d?4O}GZ917AONXuXk8gM@besS zSbQww$A)u&i0RRS-fYtp+S*m^riADxi|7iV+ zSB*eIup|_acgSB5`UGZLRKP|7j6G{ebf%t-uJIVM4J~XXw1w(I(q$@ZqK45n-$Xa_ zKO#G34FFA+g#YW?2HqQgs30#U(b$NPYMS71#5kuFyJq4!1?J9w=t==)B$A};S7BvV zSOv-$s1vX)>Qkomk$QcQioV|su3@4?X(jq(Pwr9JGTfd@N;m;MAt8myZ#KS=nA2?f zuAChlp>3^R{Lex$X&=&dZhq+`eo)gk>uy<)PBJ7X$r7VDM9^lbhxx*ve8->0QS%Wl zvW&Ux7?=m!Z45ZYzmtC8(k$(#SFP9oZWbF_RoB0PTAD70jL_b*Q}?yS?%kn{9HFOq zj;FROQ(bRjqzGllDN2eBw?-qMUjVyaBSm&p7Xdwkkz!oHHyMk~RY^fv5g(mUf4hpi zX*+|o%xw4$tUr^)7oV} z-#&G-_bkOV8zj0}Iv5qkMY}t|2#D|e;!X={u#eUah`+28E0_~!#ixU!Uo z^YIt8SO`*U#thYl$AXm_>SQS<(Q&M?szi8Vsfm5PQyP^Vdp+obcaU0WSifJ4eY=;v zO=7kof((I_(MeoFa;u8lo!^&~SC>EmIxeQVFxi%PzCxKAl_%oJFoc^Y5V#MhrcC4( zjg+qyS0tZgQgrwk-5zcKyPrLWBeP`N5uX$f2w4=L*m-q^WB{^*o*DH*>l;? zN|@?%3U1!jI7@N?e8g|2Mx(}v@lEK>AR`av#wB&w-Z7J0OJPYJYH)NkDR2GQm>Lzc za5p@C@F^YeJ15G9JS{YnBSyZzyxjyWqqk%NcSQsKKFobzn8ZL!f#}ALDBFEAKaMa* zv|apa6LeXLjiwV0z>x?tJ=ve#7zd_-K|Xl!)>|6X&)nRh#;7?`LB{}cf&1!9&r~rF36byM#V(UzuqlONR;Ed2KGA~g6R3g zFLmr}*B1*EBP#iN`9IIVN(m+#z1bO6-7)TLa?@wt=L^AUnEmh zR#UvhA`mX6v3I>$p0+7DxNI{uT%RT+PyQ`G7%~~QI6SrImw8p)+d-OR^TvlNcS-{U z`gzHz-8~?R2wJo48dcm`z~}FV7$dvyRZ2yG$dD0J z7jF0YnS0ihOX4%*Wq3R$ckxgvc4+z7<+$ITF|da9B6x{*k4W{}3235>U_I|(t<}xK z61(_zU9Z54j4}iw_I96E(5>?xmviOg$zEAH71_S*r^#4Ke_*Ko=Pa3(0hVYZ%F;VD zaXMp-lpLyw#cx8}6bXkvdCOL0zb32u0Lbz`gl)N{1;~4yz|8V-zB{jcYXkrcAUyK` zD+^DcWS)Go-vp};yJep2_wEH{X6AAGtoO^{WMEnFfLlN^$mfgiA>bB=t9BxU3C1ge zBb+YxCL_XJ&785RNAeAMmq{3^4nSR`ATps)Rj^a#FIQ#gdH$gAP|&jIuay?dYxXAMJ1CPpXx$}f}p{EKy=9-d( z^#DxP)$`XnEZE+ON0#{^-x}(rF86Z7-I4exaQKi{aSKf5ZOu)ZI({z}EqA%3Nbx7U=_;fO*{O@FlZGC#AGeN$z(ORCnlO#1NUYTY>aHT|}0k``FtnQ8>lYut zPEqq~bHx2&m(Zn*C17bfY-iqCu^hyalWM+=Nt2fJSCgJHv%eSGn(a;+;`zja`G{$utWi|JJ7jTJWOpP%?di zvjsTvl6H;bIjSq%7#6KUBx?6FQzrc8%zJYn$v};20x?!q2!wQ=T~b8yxVVEs7ji5K zacC}wKp|3yv(*Ba0zSyV<-+PQ0&nI>#K)%K2^l2eUJ~Vk^e6#t-L!!>@_iW zDb2}7_q<3G&o(wC<0+@!8)<=uDN|t(c@{*`@U2Im{l7gMba>gIq~(z#@Fmi3s;JQ% zst{jv72fYf+I^nXl17ZG*#qp2}{*|5^0TXE=p?xZ3X99@+9O>b%pj{cR(& zPq)#UPJmg=Pwr`Pa`|6A+|KYXaD8g$DDeH|`$ezA-cnx;qdj8BW_445tf$-D-HO6` zCqvPjHR9c`&W%%%ktML=r*ibUC;!DiMd)K+CX*4CytbY%)R+-$Yp|&(rQwZvYmgnF z4Ba&8l|VW;6143&40=u8cf?4YuX4O|>@Z zCMz*=_jusN%KS+mj)=LYLs8@b1_dorctCWw$dFS785#mO=8saO6;?l&1tp2c@;!fb zmd_Uy1_!cVI+Y*OL%SWul)3dlOVNG0=*O;aq0N3fPi)U{8bT~z76rG}Fo4gnh zk|_DIoZI`EvPIC7-Fp%~rV>k<+uw{z*73HNyHcw4_##7ENs?dy8rTTOC(1&iq}4q? z%Gls3T?F*d1VOwtc3y>BH9|%b=OKv4y2X~az#O#{j;L{Chq&S)G1XzhpKZ34UQ3e7qUu!;j)eAS+VxMw6TuB`ge6#jk?DxsrjRJC=d9iI~_3n3R?u`n^1OdHN+G-evhnrD}n<0>L+ZumVXmyuQ{7G?gLJ@QbpWUj--TolUAuh?l5S7)C(7URI|mJSlDb*?m{XY zeDbyHtJ&du*RO&NMw?NeBF(p1bTIA2T5aQrM*}BA=Xa^u088d4Yoj5wK+Y85^;zR} zc`69|D=mp03#x&?;Y6m%|MkAtCJ+ zph$)e@>H75=&@0xGz)l-vt2hSkre{~ornF4GCK6*KKp&xmrAjvQSTFPjU0l+YeuIM zw&|Iap4W4TA>$qY{5=(peR}KrrmT08;R_?(FGpsLGPSnk0%Ny|Gn8sKM~x1ijQ5dXc4OtjCgQw&M3qlG-b%oMuCLTlnyO7xh}}# z-qAN}DHk3K1Po2Ja0Z0_V}~+AZ%dPC9#MK1Tzhy6InesUfAzTGXC?pSyn-#jnovk) zB$o+hdz#|+4v|bqnbBf9gmig0Vv`B(+|^1#hM)l0exKDh1u&Pt(@`ga4o5yO6G6MW81PGZJ8(WYAdbmBG_3fFs-)`Y@Xx=DbWuvAS zerLIm^Jr!xiqyre0lMqqc6(()Lq+a$_5dZJxbk>e$6VVgV9~Z@fQauT>6w`6PnB6m z1MWgl71{ytb{+{EacnFdd zohkz??(+xKc#t&zRo<%##!n8-MG_E?4(MI7FLonPO!0>_LhEse$V`Xv!%xf>YO4_; z@3%9WNRZixyyLXx0}CGf=j!3a@MZAHUrB+=!Ayvpv9lS5jaLZ8y3B1v$m>Il&N5 z&BZu{YKU1!{7(x=^(wtF_!2et#KLnF9O79DuL!F91_)?KNzcbjcXz zM@L~^6{o6&T4Fm}5j34{sZU0ydic*^WV4!^cP3w4EqhvTKBPkx!(~eq=dOlk!Cnv{ zF~SQl^Lpk0s5Lz&6NZsFnHkGydNFq#`tF+^3=1{M9AXY<25&9M)RFN=2oC+OkywsE z#DWE{CDIE`w+ak+kJ+sy{veR4*yy_1YtT+<_GIj6Wl|{-PN8ZF6b6wpb1NWc)UJ(dS0)~6hJJWVdeh2=@YUF`FM3v6>)XB<}!O&xx&ms z5|t7#`nWh#!`dJ_g$UxS_~BAf0ii-jC*1g)`$sqRyjXF zoP_bhc`}q-t4)vvya5J{pCPcD7(3VGKdfG#=rfUMs8S^F?Om)qvezVk@s_>au`}Qaumi> zy1KI2(wGM_Uc3z^#-?r+NmDk}QR3g}8B_zLl9inIx7p#%iQGtH0hmpDBt zx$o9YVB=WlHT}G>&h!_7l5sb!HWEop;t=xl!(l~Pg^AUidH0Kkfi3^dv-{J&l#Gl` zG*6%A?7dAo%J$g$=r(t?i#zgf+cQu8P$t1QI7)ooX$VRQa*BV25&(2a@R0otq~^2P zq@DC;B%V^6JKBVM8)>n0*x-)ly^k+}0vh(Y1c?8*P~Fw&KkJ|t)7Rto*YW7rA9)se zrcb0OMauDxK^4a-MYzsX!GnhU&Ke?qxjGL4x-!HPlEM0md>}~L6zrE7=ZmCJuzFd- zbf{^2nnSvjiU%{~o5@58lo6^>LP^!~_Z~bWukS2RT~&*RLC}8)aHZR(r((mEQ^aDy0%4Z zO-rvL4zA=dXrrPuO+->L8S-gy{a&0jfBTpermZ&2-0J^t^INb&>k zWl`D8)%7-eEOO&z4E+Q}ic>XyB{$jY#3{nV`b-QT0txRh7X20mal6MZfXmg+v?Mg#K2gXG&EaUU&Tf_VQYgfVu+# zBJjuGd}D`%aNVZzR2C$UM6K;NH%CtOFH971-=g>!Eak62j6ksFARf#vR$C&Vb6`Cp zr)SX_zy?(_b0XGdKd&F$=olmatY>B>X@l1tifI-yg=lkwi!wOgeBK1EmZa?Akmro) z5pDTm-jGM%gXx*Yi=|Mc)x%?Xr1QuCB?AF28Gd_{MLdb2v3`nd>qq?iN^al|;;m%Ir9<5@U#_3Z`yORzT!$ z^|M4+B*@An>3Pnm_M0W*}is6KIG91G${CWi7PiV*qj^rQYg$Jd!>{)xjkZD zHkLA|Q8RaBBUMKBE*lydq@&x)FZOscUE(^_Dyy+KQA|DagJ?OcA>*J8r!;TA_o5ru zw^WFHCEu=|hR}G$Bdyt77D-|x0}k(>RAeO#7(M^z)2}B%N^A(NZwfKoOufLfP3l!V zRD=4Yg=3xkzMB7{^`?UCecoVZgX0HR&cLiR&sG`*{;pJclX?>qpW%mCcbFTqAYpjc zc4K%J(${~}#N3o?W=9JUY4uUccNQ-jslsh@%GYl1b2D6r{D3kdySM0HgI#17JbNsOkVE(2t874VP0)y z(}hbdMJDiYm{AdYzE3AgE`>jBDn($T9zoi#7x6xa%TL_40i*nuPm)X+VZCF-g`Tfh zVZpvvFy6)jrrUOg84)sUH+TdbaDM6Tr1x{X8oSq( z1H1w?)#Kw*&W(L$d#&**;vHiDlPg<>1<>~@wOTK}c`wBMng}>(jm!SzY16R$Z~;bN zM}j=z|C-8-=aFO}ZUhR0hE!!GxF|cbNRXg~l+#`;w*P2JJn4Z2cjUGmmBYeqNkO<_=$8M)?FJejaRjRv`B{#eW z{u2FG+#nDwT*95dT8fW|rmXFuj)vSccbLrSwQ-Wa z|0us?LMZs_ul3QpK_yNfe>YdtT}}P`x3#155B|>~T))HOW9j&P^Q7r$a_84o-)}7@ z^UI|ddglIBe;unph)~spVed@MV9vb$CiJ90&FCdj!WbKacpIgZs_LMhekaP|&PM{# zCIL+fmxt$1Y_^>t)hoQ8qapdsVP!2WlLYP7H~R=)b*m84-}6R1f_!2}Z+|&oU5=`T z)rN^Ukhcfnr!z{4H-x$2NjF>M5N=C?^XGw&m4HKF1GCJf(5G-aOr}KMfYF4?&Ad|r zHhN#VrKbxhkyY%1)~#Y~B1N|ZXn4lj5S{8czO4hZwy;Cmye9$H#P@d6^O&;Mdx~1E zH0@b8cXitEiQEI>@(X%+-O1jxdSq`T_^lQ{!f)9u4u_gBQ-1`{;C<7eh?(PuTQ$rR zbGFQ|Ur)3C=Rq-3p%Vv1c-J`@m*x9SI9;MxT*d9OXi>*grnG-G#9qxvb?e;Pjbxx9 zF6(YYQj#mmmN!>QI6)QIihO++PO5o9AE01MKRpX_z}V*H8t|vs2f@tUaBQFYIwhHU zKS%D5^qRmXRQ+W!p9A)4A^CsY!xkF_ad&5Nw__znn+VNdx%plG^LJdc?(|v$bUY)M z5Czly7Q#4cK8xz;Q8ihjiM~H#oYR|_tyQ3oNT3pjzbj!@R~g2}z9gIqfgKGomteHp zC=ykBP}Ek|Pv1{aV^c|0<}2E*p$Qrdwpzb9`(H)9v$+a!ojV>HZvf3D^^;OyYNGSE zwApSnKYz>G(0jQbpC;GD2xSi@vUt)#vS3^3Glth43qX~0L{%za3ClM+c>|IT-ZGH) zklPCQQs=S8uyCD5iPPy`%q{eKd-gfn3~HV8Xbdv*5+Bh%T|%%T<|srPyK9pyEJL~he~*N_)hrAhA&#&$4>V56&j-dEIb>%kbZ zl4t6FiyH5g6FqTb?&{<@Dssz z{9VG1nH~Kv3t&A4QbTk2-A7>V(31cmEj6V_q%SWzg z0rM7FkX|9$`fH;Fb;YfmRnZxV-(IS|arQ%NKc^l7PGD`}vsePe+>Lnzq5*V{eMeB) zVV*sJWd}0R>0XJ8aNHTrfHlWbG9jr(b38DhN+N3yb8F@k%4lrHQ}q!Bu`E&7Q;FF^ zI&z{x<`qe{D_8wOdzf6I=3{<$is;%;F{hJ-g4gcyT2e)Cm>#59B<||hg5OtS*_Ah% zX;WKpBin>Alj{1lhr}O!)o!M|5X{O^z|6a7@Kbb;jlC#Ol&;zQ56{i_|9}O2aYh#B zA^sbixq%HD(&u>!H2V%RKfHZtor2Ay4tTG^zLEj2XfcOF6D%iHCu948&TGog70bec z#s(4<40j6$se|N+b}~D~#rV+zeO(@Jr}(|mmr-Vcah%EXK~$kP<+$U6Eh~(I~jE%`e!`c()J?Zni!^W?HHuW~pgZVk?=;#zK);BC$ zMc$Yod-WfK>NM5#vVM5SE2ZS*GBtm6_^}GnR5ryA5&TKcH~6JV465z)^Pgd%9FUw% z1f|L1Lf8=0%-rPyL*BmrB*$!QbOBZnMuOC4S0d=fABQ}N6A~8}KigRXs>r0q7Xmyb zm6q@BI!YD*;^9Cf&f|ek3|3ZMc{Dt)k$0x)iH7nHHdz`%)eNooxrukZ^KKA}VeUSm zN4-D2w4Qopf>awRaxetF0+lFAtDy?_`w4uH-WfX?9b6ha7(5CuHpL8E=eN*uRX#IZ zKRsRwmC0{411I~rLA1mFvE9(u6CemapILtLe%l?zPl((w{CN$kpvt``yh&^YGh|bt zsaiAER5nh7xYd3fD4L&fSOK5e8%*~$qnOUM&QoPmucF#l z@5LWe+4Q~mz?m-2qY2aM=%S{;K0-1issT$(cniggKI!mhyfJy)K@aJ@?68^P=M>lA z&eWIe;zSLUIHD{*S5QqF=`m??Dq*~yzi`hN#+YI7l%#|M& zHWD1GdB+tNG~CTX6J+a+fq5(pC2xjM5SH8tXLwZmI)}QM^_T2Yfc9j>uJK5Jrq9>4 ze9p%ypUvxE%w-RNv`dUB*>bf#X*pF-sWeBoR&ia1yJx-q(Ay*tsp;F17Wh&oDqy6r zdhzKSAkx>+dlSb05<;&(yo^$R9y4B4s5KQBQ?>r1vdfz*Us{_O0Kn+I9tFiGK1}t1 z4vztkMC|wB*V+Qh_E7d4pH)xF`X3yT@wS}7r3cx@=7$h0`6hmKqD1ll;Z%9E+KI|$QC)NL+B^Qw5JgMFC*?&~zew=aO@p+0(lbaVgL1e|Q<@!L6j;@1cY4{dqmv27pYxw8 zPmVgGZWr@#s9MW(YwHfih>(UyZR}#&so8zyYe|s(7ux;`60F>u;;Y#WJ(SRkn~phm z&@93LjZ@ojeeS^d*IBbm-KP3seQwO&e|v68hVeXxWE_Tq7M^%dTQ#TD}-WDby8jAYi?YhtG=n z)yUA{w<2aCup-gv*KRHOWb2yHUB^FJ{IcFGi!G6Ru8do*tV%@q?7n0ZI8cwRvBGt! zeG$s}OTtuvBD%h=gGx-jIf=^dDpx=`}RV{*;hc0x}XVq7l)Nx3W_rMOj`k*SWbrExy+85=qf; z84V8KN+!i3qfNst$At;(Bm`Wd`=!I z;w(xOoqlIbu~nRy+=oC#DBdf!r%9>E8gxZrcT^Gya+-Ete_9!04YxuYB3Q?n@+9yE zOctNOpT+H`unuO{iqc;<^42~Sg)pE90|EzLZ32RXn>v1}AwEJQkr@y2=-CmchOhXr z!e^2YfF1A19v}IIw|;vYGrz|b@dT{lOfVJ$Gw|Dh7VZ7W1IWJuE7vFxLdcT0a>hi) z#TlbN2@tkJjeM5V?_$Sd#{`oezO4(sXq$Q|cQW>8Sv>B);2N}`>sHmdq5pi_Ri^i4 z$9&y8Q17v>XSc0>Z28fB1E_GEIZV!nf{R*U1DDqOtVn-3^kZ%%L|)fPOv{#OCjf1m z7g3nJoxy2tg=j;tAc3}kCH0lXe2V=uXpYG3Hn6L~=>goyT!oEIIKf?B69VU?-x96$ zuOT7Bklc_p^ds^2sR(3vccJ2gwDJHpaco!hpTvp#nu6#rSr7B+eLZyh$_w$6Kkfb% z_##HUY08CO2hdBfPbX%4C(~e2*M(HS*%tL1^ibX1Q}|a}Y~>>J|9q0uR;4D@*CVy7 zxzHxM36Ol&@t7!Z3W{A zAZ7Tt3PImi#=6cwn*|?(RrnBd9@I&;Vf^0Bb%s?M9KEj$ewq>ynXequxdRA}y!PbS zwA#l=1+hlu`As&4fQAYL$>s}*5Fax!jz}uK&154*CN5xF5@A&%LR!vJp+xi*Me7{y z(^cGzb4oG@cJKH*B%hiHw1nyyh}(oDxE~@>&Zlibk(U3v!v^%g|T6Qop*Ssl`;eB@b@v<27W{ z)ymCU@^GSk9}A-QiI}H2k#a=R#R8?UhN*gx4*_lO@(HTx1o^b#@q|NgI~bU-!6v!1 zdPJd=P}(+-0wHn-xb~dpg(|w7vU}UT0&;F<6H0qKFFFhahDEsfceR3_n1oa+^d7o0 z#0{8wT72sT#$vH*O>Xr6$JKj>Q~mz&%O1Q$K2-zm34(Xe>m5a z!R_YHJu1w3AAim3;4fh;QjllBf|vw*s;&oN+j?mc$jG{~a+MH@IWDjVuc>$7R^ett z@S#o3S(jwz@Isp-{YSJb$r)b_3x3&LrU>7!*Vc6D2_%8^Wr7 zG87RFHe847KUwILCbExHP_u{NUJ`U-tHMOrQVDt?A~qQwUr5m431Z_^uI1PdMaFh& z>j#653CNE)BYU+EUkly{4275z7Fa^2*l^03=2KbsdRd6cXG#BDKUm=Xv|klN%|=In zy<;aPfJWZbV8O>(){?{VZj$uj$x`rHipjo*7#9m~+L;RG#?B8u#GY(Oow zUZffQtzH99c*?cxpKmXwN=3sqRD%-Txw=X%c#qPLb>Em;S3*VORf8QDoZ!~V%}NhO zdggl%H_blQlOJtc6Ra($T9%ZGC;)EqyuGpYg&Tv6lXH&mn!KfdoY_rky(rRG4L)lX zL@HWXOEE&KP^P@Pa63CYfA*6FYi&rQBFAN}JjNeBi!L2M_3+(LxKiK-8J* zlbG2*pIfp({RF1sKl{b>e|Ny6y2y(^Zx4=H&uTwsD40X7n=Ye7WlM0{9n{RjFqI|f z5cfJif~0n`J*{2%TBG5GNe_S6VHnEjnALuh$>8H<$R}+Pkh_HGjVHkPGC@ok(m1nk zf!(7YTk}j>=H3Za*`@dqNbz@3d$IeaY&WES8B;O4*S{n89KA)wYxLf0qSG_O&glCf4H$uwI2$f}FZ+vCf50Vz)MN~YPG5f&3L*|#_8AT^9 z+xU0I+-Q90cMZG!oQX*JSmTY)c&_*4@wW}IIy!h^i+dz&4B=`RyB%ykt3}l+M)cE3 zh>EUu*S#moID5HXe|5%DVHU>1iL_oDzhQ30Pmia?$0wcU4y{&dk;+j5|K?|fR-*4P zh@N$RI~e@gG@gF6kWum{NZL=(^Yg}Hfs5_NT?IbZ$*^9{qk7bOeR3-?g6F--n$J3!fRnkzf7*gYdJFK#xvAa+DLSYGQRMaI01aivEs2)VuS{L~27MU;#)d-c4B z${e5Nko`cL(b6|N9Vt(S`_G0!2*JO6( z7#^ll7IgYa05oCLcG5|MuA7*q{K|m7wL92+5UF*LSHy$q#`#V1vRCsctwYkoNBZOS)X1fUy|5Pm)ja4wt{uKbZAh=g)Q>Foh zm9yh4c>G3QB9IfIz4V=TZf9Nxumjr=hwQfAcOAxk91}8(qzj>KSNzP#<-2~O?Ry8E z#JnlCxtm~~z@tNGUi%G1)g$8;$B)6Z^eF1Tn`)h~fS$#@GTn>TX_upZ*=1-ve*Rqh zM0VHvTfbnk8I8oB>^QKhp7&PR`XC@7t47n{BznFhT%3`Qix~!!i~>i3)%$9;*7^}z z6lR4~pQ7o?*ku%EJ-k11c{=p7Ei`+DoN7PBR%f{XA(C!}qunp&Z*xx_vhCP&ZTg6b zrRu)NBLa)8;~7Gtu|TqCS+XrkeZRyzuSijZ!q>X4NbsUPqg#@yvM^l2O`J2}3R#Pl zFkkJep7?0OkKQk2SF1fDHsc~ewk`k4@cNr#=N0R&AOvKy0YS2B+mrpd5D|f$1do;` z)DvYcGEq%Q_>Volt{y>HNvsz!as+rFNG`?KVH?@StDrkoAg{X4*4P4lqNONkON zx5s26&l$Om-sO5anOYvLX^nuMYmb z^l9r|Ok{DX|IkN%=7WC46S|Y#O8W1Nkm+!uI8XFs8)N)OQjC-PL)4Ya7!^0eTg1pt zwWVK3Gx`;%(kN;BbvM&a0$8gW7SgIL%Dob8Ndy>s6A67`=bV;S6`m>4CHmQ0PhK?%bp%KJHi*XOL{a;1dPB6HCr zee6A#XHx&D_v;3)V?M zfgbf^pj-LQRnF!21*}HNCA@~63`W#j-;#(*rx)LM>4l1!C-X;Gu8*2y0k&hhnfv!Z zXf6{AnDK^Y=cQK07)E=fBr1C4!yn@7XI@m!Y*U2zv+d7p>wI<*n#ghV%u^O#CAdd# zxa3HrXhoQUpq5txp+$H|dDT~1IKIo_>}zr>`h~qU4}a1Ue2`0RVGFBegMXjH7%{{6 z*xZaCeUvh$9a@+jgT$5+NTbPxE5|6NilDs|Ut`A!90Y37G8{%#!E0%=b;0c)wkxPk zR<)i@So*cSOCdQJG$=jm%(xKpo==!?9}SLxYz92@K$Y%rn@q-0E!)v@iI5At|Wy$?x*L9^*k`(yoa{jpF}Be zV$Bm}If!||NG>8 z_cZO#BZcY1KdzPOXGBIio%;WBtAth|Bh)CKW{vYmwqI1?=Dwe6iIL&_4cED$?bCR7 z#amj~9S~>3LP@&BzV}5?Bs6XTcz{&2H)vf_2a&Ap@1oj83q0P`NEI?6#hBTcs$_^B zCK#Z~Ln79ZnofKP5F-_gMVo9{z530Hx zV7hD%3_ReM_a3mG@8<l}L$w`7>y= zkQ}-+EZOwZ!I?fDD9r#hZ}%7SCY9~@B>!hJxT^8ubEC__%MO?pTKi&6X#CS_A1~9= zPQ2ceJrb;{aFuR*$J~QkT(Of%TeWI9{Kdrt1bU4HSm|Y;-kxyigiCuXbgA#%M8k#ypsq%xQUMBno-^*p-Bnm*NR{7Gug(< zv0O5OM?5D+W+IPSuhE1mQ6L;KPbG=Es7WwYvGK*3zeC)>Ald!Xlkxs*tzEYVHaYD5 z**xv7x;rztQoT|B9eQSmos=B1hwdyTmnHN)J7dJ~$ZTG6V(R5cgrpfAT>uB6^MgJ@ z0Mpw@MVB6@ReS1sFr`O*U<|I#ta1K2mg;AMb(?EvOzT=>qPlu7?AsB84nc^1y@5Hk zoJ{pHTQ1vtm$hj& zjm}2c?)^g4ZH0%Q8fBDu`~M=v-yy5Tj8l0)6Y`Z}4D-9p2`pZyTaZ3T5ulUf(rr(8*3_;CkEVt%39JGQ`*4jM+=6%f^cl z-gOXM6&i;P-9cDqq`5~7U7~(3{7YAaA!~Z|YLhefhytGk(!`qRcnme4V82TAoJe+|48uf;<-|06#cO<4^hRL6;Z9Rag znA7V&d2Xo;kF0`}=dgYj(~SQTLA(W6m;&F$Gl*+jW7gEQzmb_2A$vHR*2fqI-uAF{ zO?HKyl*McPccq5kFOjvl>6%CHe2!(aW@U!AXK#yO5Q+t6L{5;g}IH z{n~kVNO|!Ez!NJ1{Y>K_j}5p`d{e5aMu|Pd4kI|rI-<^^);GSr1&G!}?<+4De~|up zsipQ^WF^wx554fNPUSmay&!Bo%g(!(0JcNn#|CjTHy!bU_zpvyT3-^f%36AOW_s_e zGGecwy?4Chy<5lL8*O|ysn|HwnedtMop?0*_!&m9cQ+zmQp|h&2pf3xIc4(G&&IFm z%#-hQ>yEIs1hcdB%$df@cyxd|G4fTo9MZCOM-FGPUa?S;Z-eNHA4Ja3P3jN2g!;fC zCDKbOdEVsFjC)s_P-A=8ZhY&)Ys{@&$AtWxe1!@Qo>ZH}L2- zcHQ82j9!n{+#B2r@2HLiXq%d*25nRHl1Vs>B^76m55tKriWSPCF3}2tRxuZF5VTSJ z${77i4mtX#9ai$7(+*45)8rTGktl4fM#umj7a<*T6pBtm$jcN1@u8#oep27Vn}2AY zWcg~X1X-qFX9?t3ARhE6@GYqjLzqHBbuhiQP)r8#id07l_L@_H6=dnppr{_GBKQco zu|RIQN6u7sg;cpxdwR&C+VtSrlvAmUU1{Tw{NGx>-49<$+#Q${^RPYnCS?97Bfv64 z{$R#C5H{$&*!GNK86tru&s27UIHF&RjAJ0EI}NZlrMwP(R)SRJIii;oR&3w{zT0{^ zR3!;ySXhPbJxI8=DrzxSu9Z~aqhtAjXFwgpOM_uUugTJ{v`-{lp$wiT#0w88o|)Xu zDN3(nYrUIL5!=M_VB-!6Qq<-wU(-E~^H)kx%SVq_tl?-b7XuBRXFT+sZmJtmJ)o5M z{&8tg>!b6R85Y%a-2b|`O-EoHKJyN`f&0n_>lvqQ4HlUI<4sD=G?va6l_hj0R^j7( ztBvJZDe`@jXrzhH9OaZ@+A{jU%j^K~9l_XZ7aeKsB0n%S;Lzbo$SzoltBP=Y`AV<) ziE%xXOqWvc5)bta6TiaMq}u4+1%qa6WxQF%$4k*i#B$d-GYu4gyyKM1ITPoUIDmWt zeI{KG31L`dhtFI}Us@u#x!P|Yz|8V3ozm^4rv;H^*aG&;0_VOrwB3)miST@h>EoI} zIa`qyk_3#2>w(|BjRo3ly4PCYsq6oNn-Sh$Vq-uc-%QoE`HQ7Fg(rEFiz`pqm!vwU zf-bKl=`7g}CAv^ZTg_eL#jLX>VeZSQAAGbFl$2FLb560!}<-&Ecfm?>RZoJZF2ccV~PmL0>&4 ztd5idE9r!-Ch&&!^LvazAJelL!SO{eeuW~k7*bgWfd*@xFp674?FvR@KaVswhw;+i z6cvml#W+p!W|SHW9`pXmZD_WZD;XssQ5US>h|6cn;$%6wXNJFImCEUhWB3d4Qzg5-Vqoa@z;oB!_kpOA) z!XusIQRp4x#5?RNXh|e1hD&Jurac*wd1g99gt@;xa@~wTLd+Wq+lJ*?y-8;g%X_RX z`=OoKB)VVJC#1C1;hc}D>|p~PNHQQZ%DLy{>o8lW@XT~OReg7@jrAw=}%DMC2*kgzvoU3rqap^^*^Yf z_`JW08O}S6>79Bk6$mmC;vFO8_o1)#>tH}>37+M4LNhulb_p;AapIBBe?V@a`-TV+ zt*K7y6su+gQnGGh??zqbORh9-eY81{7m{&KIrw5K?MuP=kTqegW1eZs=I2?bxWpY^ zhZH#KQW4Y(c);12291ew0+|!LrZnIPVe{IKJ+=q9M^IlK)|;v8gE~+cJd~iGeJ555 zK(P%G^P190^)mI4EN)Wv&I3*H2WfV5X4H#M8mxbRn~?aYa>wT19Ze0^`CHk%#+*xv z82TH~6%Tx91!`OFNE3LHALSR&A8RF3W{YPSPITq-`BFtFy&l@BUSOir3XK?AMk;~HY5hZcFG&x3v`iQ88}`qtn^fEc{+MC z3>K~#PDPuLYM(($CqW0i!L{Gt%+a@&t|dw@LPJmqox0a-i|I@niSrJykT2E6NH|dt zd_@0)u@)%>#WvG1*-}R|BthS|OWs75T`e<1&F2%l;;+=K-N!oyxg!bz^Y>2* zh0LufS+D4=L$I67!yR#G|HHj}60C&f_P0@hSyn5gk2;{!0+4mrmf9Jbu2#kIbF^;F zQG9rg71l?v(LZ4iq|ndK`I7z=^kWoj^Y&+3^Jqv-@}4XbD{)Y&Xb8sHHfsQb29DT+ z7@5?#TeH6$fN0+F1Bah2*$nH56$4O-?B9R#()@I4e6m{cR40(A{KfxjsP z4@}QJgLR(M9ICRbc4f-ns#%b@*xN7pmx~ew%W~Y)mfx`-hWWcx_`&kna7n`C182y; zJqcBFNG@WSim$LROW}!s==N!&DE?^Ch!O`t6cNQZ@dH_Qb#6aW!$W z6JWyp6^^J%+D{fzwi{16?merL;tagkpBWpqhfP#ztM@vYXF)O|yNNBG;r z>S~ixo&FkAO=fJAx&oK(s?V@bo6fiKZ;Jms!6#peoYL*``}bs`D?{77cN_>@y>qb6 zdJ)15wlX96$B~$8mp=PV&rds;htrm`74FU_L3L{@sZa&&bDEXcBjTMC>%j9JtLc#v zU#6@$>hXiG_Px`K0{86R26cY z{G=_$LkpY%9|nRlzO;dySwa19O?OIcJ)?G0He6qVOpvw?X&M2@+U&ocTcBWuV7R@ zkoEkJAhYB*&Xke2%?LPo?Q$A&GA4I$aC`24gBTUzF!wYe!d>pcAy@>3Qh9HEm8&hZ z@w<#GBL$JLnwHlUn=2DXWDS$~>kpHBqhct@tqz*JH|c+ONDH8Ebl#mnWf@*;bO~ZV zc2+%wQ^DA*k|9}7srf9gYr5?)8TJ)cz9PNW@k)xc9>ti=!*)n8qk5k>yVp8NI?%%1fSgyd6Wesd!C)11 z?Fg%$=6!+bW9`Mdj}^eBd3+wxf428V3yM+HQhxWHH{{q3j+mjDlt<(5>0%MmY`8BJ zyLIS@Tf9*FmIh4)Ev!&Z8}=R!7RSe= zVMK;q)C`3vcrPEcoIE$(O!&rG3`=vs))7c$lJiu4WI(UVjHMDx6&*979|pNYLrvMx z@oQ~_2s1_vj4`K86`aVM3Uga?7rj2zuZBgDaR@b*kD$%BfN~cXxm}QENE@xh0kyTM zVoX#FCp`10fvME!#m6ki&t^L^(9a)y37Y;dDTb$;V5gI!2eg-EPlYV*?Xtff$3DBP zdVk0R1-?ZqAd+1RMlpG+O++J=JfNXzlu1Mu!<3F6A98>mVuGKM5kYGm8lmeb5o?SWvsv=SLb4zCF|R z%#29!7lp-Rif`6%82RkD*WJJwNj;qvDDc?NW$A9Dj%L^d@d4EYtq78JfxDLzw9FVQ z);#_Dxm)(wEEJBY z&}xKGVNDrg`QN6rv=b(7VC1;4C6Axa;;^-l(l9CY33B)4dpJS2r+Z(&&d1ET&$ zq8NrsL@5@XG|^2E`ZFyh#zsjnj^^6u9<^qeALs#{P^vF1>AM9v1p%5L)By)0O{Ae| zViV9)K&g~%vL``_=vvG-iJ0Mhwh_+HjjI*Wbnh3O`XxO#Fm(`ZaUn2$;o@rI&t9VP z-@3W9TDsPa#l`s(0olKZdRNph@qv19z47TbCS8Q=kQ&bKzGyKC?eG@9;SgTfaJ`m* zOq;QubqBj6k-hv}Y%%h2MIy+p<`iEP@nb3IrOKj-q-N)l3Iu!@MuoWeUgD)PPW;== zYtk&;rmFSY1@{N3INO-j&?&ZnWh(2=dH!X7koWzVD>9?vQtV>Vrq7&204)b-AmqT1^KFKerSQCd|$#C?- zEr_LAJGJHD90-cAhL|Q+d$&{-kdwsObM5^Tk>*~4Yzt7*9h^yy|CuX~l8K-%dZ#k+ zlsI|L;5<@~B%sLdsSV$6dE?sh?_B=gui|x!tK}qZ5Ng6JglD>c2k4D|2I#eg#{l=3 zHCJt{%d?+eX8%r{g@Pn~>*Kj%`>0k_?=xW{F6mQ)tN?T9BCIpVT`3-y_#i2UGjKS~ zXrxZG@IeyyA58nTnX6QcpxdnqdIVW-OBRWGKXVTY@r1erqA6YwXA<< z{B!<#YfvPa+6A_x;~hZeMvzk^rsFdP@kBhQ2^*BA#1qWn=sCw9bi+!l5M6`s^|VZH zf+K@)fEb!tc`5%n2W;nEENu_Zq<*HCW+r4tL#4-yS)MOrQr4AlrLo(V8an`QTo-fq z;AFD}W_%XuN#{mL2FEe$AJM%h*1a}81Q?ABBVuHw*XWGV)|LiT!MY=Dm9J|uHECA- zaLiXFi*cl|XUK{w?4|*uQ!_v^eiu=X|6_pdcO{iOiE61HI3GMLK8cs_N^27n;qQ4$+KtHM>4HL=LW3pnYPYh-shn^HV{bg^mNEG=dTrKAb<8YY_S8xa5>F zIeXV2nbu_JC1>x89lasfx#&dgVY=mGh@{mwI9Rbm>ABM>5lCN0c1+#cY6TGvj%XL! z|Maz0@mjww!^t~h8#r;?4&BebMd5%K9OE8htDaKvtj0cUn2wbi(B$?A#=6+tYP?<~ z?X!q*U2KX$Qa|;)62RqYjtSN1B=SWB5T*nEtxD}t0ZK|8yM7yQzq4y#@8B>U`kUhE zM(MNfU$`_XG|3$$a8JJnh*(bcDyw-_5ReMxSp&TlVZ=x14SFHb%AaM5X1dH?O4n(=$G+ zcM4;7?w4(y#T>9&9ODB|V&W!zb1mmjLQ2d1)321NM$r!edQ#pH;m$3sN8+`sm!eE9)(eu0Zap^0d(|X;85sj!r0%=tDq42c!UN(!qx=l-HL-7GVSd;S_ca zN3BDIBo-cQ>;f`JzVg46YQckNI$WOOnkf~jv6oy0S)3tU4DtCW9XD}(5uZnm+6l-U z4qq!;Qb(xK6Z3FkkdHL8oq;cant6hT*H&%cTkS^4k5WPKj<5`ueu+1)ei?tEL}|60 zmCfYK;EDnIqFxEN-{J?o^GD|VKc|=#=r8Jb{%TSFhYf+@8|)sA(H&{GdHXLjuT)Z_ zx}hQd){w{HrfHyX8v$K~NPK%B8PZEx8ax?zurT_J*ELP1{P5S_Mu!c#(o z6iGDU5^SdHm^2h77tzn?Av5t(;g3uxaUcg2rm)1!)~<>lyQIEwT~Nwx->dV|nh?_@ zvXvloZwV#iybJ{~qUn^*u?yvsS*Gc?usdC z#=NEEX=iZ7?sg1kkW>b#1=qO;^dIum_J5~MhfmN@`k!^8t@6W*q1{=D4Jb578#uBA zd1;g(;-IV-n7*orIx#$5lU%-AhT{|+w@V>rPYyg&5_S0=;{M_q_DQuS3>F~b@B}JM zG@*=-D4VL(jLgyTg3x2{Q!yph?=mj0I_JAi-1nbt_MT4Gog$aLS=3Tzf2a+pb$Pd_ zz|}!P|D$?^Leh?7#X|T5)U%TH!xj}WQg@Q;jp0eh&BH08pA$}pa%wzL1U`)2pq$lX zEPy+NV`0_9n|xhR53=luO?#{GnhxdZfcP}{N%(vl1nFFDYqC?qF1m*Jwlv^GttX7i1Fk6GIvpHcP|3;0SkR zTwv^c;DYC{b`!&iNCCKO1sb#+K9!cO_VQ{!UHIB#-`*@qUQArUK+GyQ;8MyEQSdC zr-Q2E)RwHtYg{*}wfn0vLiwGlJUnxhJzln|ga_*~ox#Y9uPn$to89f_>uB|^Y?Vvy zm_Y_2UF~0h0B*n2O{w(0zg0hn`Nsa0C+cTNg-b2i(uENGQR_tbrHsbBFxL7@=^YQp zFQb?&jtPKD$?fvyG83Wf$V=IqR@*c1^w7Emrf_MwkH;jP9pj#B+9Uw5LgsF1eBj53 znqhf<3#rHJ^?XUJ!#4@b(mv5o3f$=g{>OJD5e4^KZ+d?weIBpJXVI!?pSJ;hywiAb z*MD;0%>BsubgjMmKQh0DO=6+?4dd`47Ju;x4F$D+R*9H=e$UB@F&OEW1Mne4i zl?=VEf?xOvD|hwSt|KWl9kiQR)Dxd7==c&|<-*0VOO)|eph54_Wp_oy?Gz8iDhaKf zP<6&7NB}_`yXa*pQf}k07U0=Q^P)#F)W%Q>svk878n24Pj*y=#dBiPQ6Wo-kITWu$ zfRYR06u_>8(%UH%&-!Zrk$xq5G{4X6S`VDMbm-CYd{YLvRl%h@kN zR%e-99KW&dXdWn+G^$lh4h6i#X{|AJpUd@#IjB%SU_x+&qszDDc|9*;IR1FZ8A*fB z%ITMoB?l?Uab$*K!vJ32n|^7)fb?>?P0edIfK{-s)XGQqJe==;yon#0-0 z4k$J-+^v15n4#E?=ZB;L(U*cZil`NUIwcy6vLv6miam|hG4I9Y*dYfxqFNoK`)gtT zZ^wW<$qIu-eZSnp-AHTGBhD6}oRRf=Gp)?pPye%46IU^cvKWCh-?ai@4bm2ndw>*{ zHMsIt5}E%X+&wK&6TVXHU1wd*5&u31s4R7OF8}suP)H_%3E_n(2PQ+Cc*llNMGxnL z{J)$iSew(EZ`MBjUvchUH%vSHo4}X833o0zEd!mYkKx;Sdcf|*| z|M6wNXV4{&1bZU23Lrz_h-`8vl3I1CsG#i*eNUDT)YP_Gu^=0M!&m<6I0CuAIgx)^ z$x@w^0zsYTqXZoa*(<6?wG9n8@oe6gKYU8lpL>-le3hw7I)-(TYl|uTu&sd<$@Jar zHC?d&?j;FLAl<}#*M9Tys+jM?=;wNjOgc+W+t4BLBCkDQ4*J2^->e7$RsF zcx=V`7x|oX{YNNw^I2m8BunOLK~O^FX3*%|_SgPSRc52Q0#JzsW3yI@K2F23K-niB z$CEzLwqm44ImWoBs65c~21?NhePAIrRaw}F#(i-1-; zJE&djE7pBaqQ^u&{q0Uss_N@#-cdGGp_}*T11oQ28j+!1dpl7Mdl(YXW;os}D zUx@U&j?M(L;xATbAoLy+Xs%C(Bs3c|g$}tisEf!E)Ed;iowBfjFTyEAD!03|d{^0VL3`BEHhC`4~(WsyT;R#hGot2%pavj z_@PUh*qfG>wH>=eE5zk*k-d+N{7EDVWoGVlg)(!WVlr1)kewFn=)MU(B*HUfWyi$xp{Gtg^_*D<3O8sUvf5=UCZf5e`l0Mtzv^~#@V?! z(PkSr!~z#y1xcV|mAXmPN<|~}X$Tzbia$}H_^MTG{Ee3zC4$cq8C{-+za7l-Gcw@j z5@86D3X*64N}^nN?-h4{PK!=Ec@&}xMWmuMN1pQ&gI%v^?MV1H3*ZGe{ORUMu3qSK z^VcKvs8?dtS0f6I6r;a9A4^Nl>)>EFf3!U&b05jQ>PKG)C{v?pj{nYO{gR~DW8)?i zo3@@x{GNF0#8PZJS|RvdV5;A}O=h`yL7;H>^~>~R=?!34m7m&|YAx{}ez)Q!eOt3E z5iDOJhswDcnU+gRO1primSgKyHYBxT4Dj|~L1tB5%WdQ0M6xp$+{g35AF(T0`)$K; zmWjw%6b8}T{`ktDR*B9_Wl3pTX_+0W>#^pz1(}k(4()*1PfX2D46T4iaEoiFx;nDR zf}oC-Zr8r|4R5}{%k-ww(#PV{J6th!MApDM9avm^9ov*6l^y?Bh~-1yE)@k_n!*Sm8Dv+d!D zoa0X)UVa;_k!CH`#Y@rjT|jOAcXjE~Sh;!YUojM&fBj#_n$RYb{PudvwuI>2XESF) zMsIYyKa4ce0&B%13)ps-Qv+>cmgX0B8crM|y|6w^i!cfEHz>NY_pEG!32DaXkfe^T z{-6!$pJEezx($aeR}fCydC82)sXfjRhQqR0u|G%~qXM-LA1^hM_Cga92EOIWvR2={ zf2`_1J8Dr@I=zScGP-xLvz{6FC+hR&z;;tx1b4FRmFtw~^h^Fhgak%dX^4(4F3Pl& zp*-XK;AC?uUApmX!rT9%*?&vOwMAtOx_u*G%p2(w?}PMdeYnjY7S>4M{~7|Vi6MGW z%|xy5NQVE`i>FuFaT}EOMnZl!JqqIqfhc{G;j@m%C+d;y-dv!FN`d>KWDIusN*|;p zEp;lBG6df8ZemibVdFZ=8K9&5} z>C{x#diu-8`}82VLg->wXyUYdZA@+?Q;zPDbMk>s23YtP80*bdKX1dYhKI->3lD73 z_*MgxXQt)1qC^W_LLtl@_DyIOtz!K18g?{~qza%8OL!RCJ+hT;xZaxmFbLInO!q4x zUgfrQUMP1M4=j{Gx&)b;G?CUX=rmYjEUYpA9ZtpIhVvtX8=P)n?zs8UkrHAmg2! zjr{nB?84#r0Si(P-KvFQcwZvGCux|ntAIi$)ok>?(d3!Tb159oIq$7YcM~ya<>B8C z-5~`*&>3^WJ={IY)lOCl%Na{E-${_AN53zuPFl8}2HdD2odJ9gcxI*ks(ZuVbbQM7 zJ_K+37m4%M=EV}Q$m>yP9Y=5B)LVU@2T<0Rod}Odx0vafZ>#z#f34zIdVO{4`C3@e z96OrKNm7a7vyQSV+ZC`>!;sWy;;+q&Y7@5c_Y#Z^6gD$3Oh;4K6Qi{Pg>^ zgp&osv;%A1hV4Ou><2YY8+{zXpm!Nj9n_b|A3NVg)c@Fgq-0XkVgBM3_vW;q!Xxei z)svkt`pMGs(NUS31bxs)z&b*ReRauK?6AMVr(Z(=VGYmJIh0;YjM#~X&Tz5Bg$a4^4y8bcYE5u}s%!c^GZX%)!+>eT;q zGhCHr)$pMCXhwOAyh41l~Htr$ZIMywGW)_prf# zZP~9EqeE_Po>&~@^_H9^ZGC@ryP_vMaH&4*U`W=r+&^Q=^;;!Z6umH<$OZjo8}f4l zQWqwJnlFpF^lo$;? zE@GlUbs7wO1LoLjo(DPy(3a%oP(7mM5txep2F|fe&HFom{SY1T7BTi$MN$jW6LztK z9rCU)a0tfrU7xyhr|ESyK&_ozsPFy*dVH9on^4D2ZdAuNvOKC3S^me%d(Lo#gffz?X;-yf=?1+Q+S5`;mCTo9 z9O$-J;fOLFN#eB5dg3(BFS_@9(kouQgx z0>!+0<}&j4&U>fr6=VjKaeLq71H{vp%t$~AVw?7By`Dl1hjAcGEnp_zMeC8@N$Mo|Z%&}Ch)%7sG`=%x93B3Al}R%rxfeSz;mbbWg8a*FS2t3@b7C$}B{w>it;=5tPe8_+ z<*QNe8bmqOMS|D7>No__rbe8ef{R>w2`ZA82m){$K_-|ve88bb;nx@;( z;FX7li3BG$F|G3LZXAOEp)gSsFW_*L49R}D68*{|+z~|ISj&x&RNGp8<1@V>bXL^t zIM5t@#hGuugX`eeC)bLWz;)Y$?!{&!ZYA8gJ8TysapO9?wKfa zl{xD2bNr6&#%kL50ezp1=&@{(eq5!rpubQNo|-dOMy}cZ4oR_bDd$?Q>#5Gkl!agE zyeP{*dtx@>hiJOgfc<#+-$Rs>VJ*k=V}a4YsHw6?w$jwa+hqG>&&anNOD2`8pMK|< zb+=cti|Ub5$@^bAhGV^JIGHP1{%zN|qzrJm-ODy4zn0ZOY_d3;ajrDY*mIQ$E^}_3 zihO@7dXq^G3Xk>QQ^DxLd+N#m9txaRqOvHKC$c#i;OuhP>k_*TkwA zioQiIrjhKhz01uRz{T%lr>v4v7kModWTqqTzgpsRqv^clizh|ehKIIWRcLN&Bh+qM z-I4>eNEBqqrsK~B-}{o<89{WsQ*_0=`4aI?Ol}9?N%8P&uP9{I$G0RKY19`;ztr1) zKRr^t`~{6BFt%%Hg$*hM{#>Hs_}WH_9Et~0v+=P@imQX+RKqF$}u+L z&#=&6gZfiJ62g9XcKmsLb8E%_eD3^rNK0+lEVuLfgvzyoHwQ5nM`s6KeY{5Qo<)QZ{LcRzd*ELXhfQ;?O3_zZoa zfzTCJV<1?#F?I#$+#ixrLAs&7t5*@oV$+0J(eWw$*1eEk1_T*%@oWTe+XuJLGLOJ; z=V<&VPByOMDvloF1R*}|qem-#-I&qUuPgVnS79ot>#PM<@kycrXg3O7ynJ-e^-CuX zC5P1yYp7N5JH~VKXT5sz&9L)oA`=*{-%smd(LuHGN>bg%QChCn!{hq@G)Ig4)1gu5 z8X`OQ+vebi=={Mf5!>oTZ%86l#3K13TjkYTW)5Jrs*!b7m5RV2OsbRD8RL|mdp`P< z0ECn%Hi32c`CA5&l97iGJ3 zJp~LsG$=7ccPRo=Lx)O9gM@;Bq;xa%2uLF;tssJwG}0m6N;iUZNH^c@^PF?u?;n`? z&D{IGuD#bxpIO4R{P>XL2x!bvFo#0V9PdH(AkU#RPqGy38T(^ zVK(WV=RaGeF)hk-_*8#2=h3L^u>c2u_08F!(^vUHMjac4SUy|P002mAXXxstUwMY- z6B7q&o{cTV#k&Ykjh9G7;Se>v*fBc%0pkd$;n;PSis@#)KET4%g0Taa`!@z;I# zj)@pM+8?zYig>@%+{ju)}hf zGB~dt0J5@)^7b`whGm;P0m!!=B?nj#g)i_0;kADPSWi3}CP0vKd!1jwIW)HU!PIx` z9YDK<+#^|KqhDm;M~jiZ++jv>?olY<7gGofD+B5*fKR*)Fc?d@s&HUuV}&6~kMZrO zQY7SgMZA_~Yd6V?FA~g_$_PXiZcov#_uIJ2%k{k@K+L9WRi<1Y27YA_ZPZzj(G6g! zBAokY04w?509K@*e7P$iZKL70HfXMgc~uKrvt7n;a?@FH4s!{8sDI=+ zi^gOz`;do-mozW0Rgs!=cJk-6PMCZ0GArvdpaDC_s<~un zs;Y16c|F+GD;-O7FmR2odGsWtdlXopTWi{QA(16WRAAvov=;T9`ag1pJ*sN7yH5{b ze}dP3`+BB>sXnlOe((XN##_Jo=miVQ>|aeorkcYJuYre(j2A`!jCv>ZsVY{-mis2e zImrZZZp(vr4ZE$w!s8rcbP6Ue%LMF>gxSN<;?Bc(8e-iC0+A`9^$2d3M_x9cv8aL8 zTQe-I*_7YKDnZ~?v?V4}>}Ib&&f`xD4tBJjb&-?*A?w`jZ-xaevV!ak7ZP}XgiPAc zum_YeSQ22?f4sIwaeoF-LVWx3rtK4!zkd_(rb3-0jo}{=&AsLOyBXA~%ZW44p}K^x z^>wGEeN$KIo{?T114N_WH%Pll5M{V`t@7~~cF#a0P85V+PSE+71NKbz z&)Xa^djb#Y=XGyzuyIyK=F(p@iD1}`&>o@w=qW|OqTqIwe^>FYpPWecria_hRR66t z;^RN^B~^9FC)xKtK7_S8n4l8PzGA+;l|?Yn9zfqR{fq{?{kCnk9*FoMHTkmZdG0q- z-K`t1QPjmwilhVK1M!O>(z&N*UM>;+`t;35!&bvB;0%`4hKu_k(m22XkgC^+jjT6dq$`;SeJc6|75?Uno)O?c?)$n+VK+CW zh_~Z^ehB(~fb$>2SHLd%(Y=GiMUlXA&e*QTnEEHVdr>PJG{HF!RnKxfxC&h*n0-fF zd?k4EOi1U18G1xSG)r5-7TF3!+@^gzyKQgPoDv9`xuy($;wo1UajQzJo;?CEgQv2o z$OkAlv>1ye0JNFzmKjwd2Ol_Kb&^Km*|OO*DALe^CpsN$fK;36^(9Htw1P<%v%W4f z>?|KaFL?q5T+n2jW!?u1xuEX`^*R%Z;hg40wUPy)fCNs{*x1tFn=>Io%As(_tzU0m zlq8P=8+VFuG2CbN18+&@ogJ5+$?WsjOEAe|%z3j_+`hNTtkJFg8sQwy&lHgGl;1N)C3h*z_@4D(FO^Ev+^X6yLBXw zF1noph62{?m}q#dx1_4go-s8!RWnF)hZ1KM%Y(cCG&K%-OmYe*0x*KC0r$8dK&f)o zjff=MaEPJ_Wli!^*ktv5>Y0ty3L4{L9^&GELyd&bW<6C(nPI}uCJt8YW%*#O;!XdJ z+U?5`4txjbMID}6y#SWRT2-&g8$R6UP$OU2eh7o#Z27Ysc-WH!SA7G5t}3Pe!%8pk zI0U_hVbfQFpUAL=NzTGzCwaoMW@8e&y8$^+z4nf7H`QYXoZHE#3l2u7s`*Y%I!(YS z6hEr*q@I7JNp#-g4hMM)WD%?t)X{;b%0|WvXvtPu{)EDH1r?ayP&|!dM0hMhLJNY8 zp{S=+mMvTh3=~Bl2}uBMk2J=aIJ3NCMiFmkSb9tw!Cd3m5dQj!{eq<{&nhCs2AoHm(ub&)x-qxPuEd8Q4dpH4*g;+KdP-OkC|V z37)?Xxv=jf`No+N*|>jCX8wBV0EK>7K+kcp(QY;HBxQ*{<1aB3_D5$><@`C zGR*_=fGagKY?Xe-pST%V!+t)0;jS@pGJBrW@Y!x|?xtPwI7s4nl7aNWm@*C=JxCUp z2i?Hb$ASaq9n2rD-onhvW{5wd;iyr1RBaGbO?e%Wpnsp4Gc5)b1ed*BnIu5yv^4XT zNrXhaf<*wy6mJD(!a&11-=jLb`ZqMzXDezR*WrcEZvs1RRR$Ff<0jVQ*Ig3ZOHP;U zO_KjNrOExHP~X86nAn>=@d5B3Y-9lEFEDsMtH=wT)(C{yKr@I@Gfx@7!9#?D_$}^0 z;KM1*56UQ~;Yzy7S}4A&l8Hy}9|GfFqW}IK&c_G9dF!~l&eBWF3g_@@AU}g)foVVs zaSKTsU>YfhV5(6{Qpf`~=>azd!1T`aEx#NRU`14wtl3(pg?}cptw{6lbdHJ z6MfK>8D9*pfD@=-O7Dt`L0NanSI#1But2!Oq%B56S69`pMMdG2jY1TP_7#Jv%FRDA zs_lO|*JDXZM_|7S!j+)8v(9bwJi>%yv>aHsBcwP0UlIuPT&>iPYO zhU+DYispbEG|E$rapjQhWQg=?Z)cL^k#Ev^>hC3wcM|5XCuKAAQ%QtD=my%vJe| zd%(k*S37R11HW7!cmSrmEL4Dtz;RjT6`REA@vhXe+(7)B*iKj+G^~mX(E>vj39CXHUK|oRqV?og>Au4P9;g; zuf84rNb^71iMz>W40!0Ij%ogpt!EovPZAUhuabkPjfK$0rxBhpEsF<6P!j_jWIh!_5N;G6iM;1AX`u1}uMNlxqsfV0}#)&IvbR-E83w;@IP*MWGCqe|5P`HqsW zTv-T1!?uhJEb4+D)z`UlzhjHjp5d@uG|10Uquo+ai9>h+OU>f7vU*e{-6V>{Wtv)U z^DvQGs@p#HKhHC32l)e!DUONzFL1ndCG?m|eQ8K1B9!J{Xs9z*cyD-Uuw^x|Mx#hv zM|VO3`ii|IL2_;whN0($q%f@(*-UphSh*b+&#PlD)uTU%%lCj-c`IF1>P07|%(>l- zSMF8Ic8YNk$pW;GejOv}ik|PBAVid*#4wTQ(1 zXJ(o58(yPjFWFviCYT^p8R5SbSE3ib?PFXd-2~Pc{q!^99{!|q_r7l#6Y;E!MTh|l z?&kQhj2TR_YI>qZ3AndCt?se%CS)sTL>Swx69c5~SlH`M-8NrlZ<3oWKwC%}lhE|c z0nkWYW1n1)c^d2n^^Rpk^ZeJuU@!_N<8%mbEA!@8Y_adIwzOScrG@M@r_;OZG%M0z z&wO^T7c|Joakmn3rh>`BM?5L8`e3CF%=&ZHi{(qfdu??I%ZYH3q@_Ym7A;-)=uh{9 z{wH#-(#pd&`!a`y4PzEjYOGQ+a})&*i|%yXhFFp~PwwT67h#r2bZ8t=I!j-$GcoC~ zLBcp&CLunjqCnS8GrJ1B#n~U2=H?F<)|NtC$;IF_5;TBAasY|sLsts)Pa*Q+5s2>V z1l-?FtmtNuG^b|OwF&INg=!RO-=hD7(Js0E!&&L&k0VyucWYu=rcut3R3w1$7!W$P z3mvUrpyx+Z-e<#C#c-ZztgCh$exi0`-ThfJHw0kJ!t+;Zep0L-s(twvDc4H}`r&_N zv{jp9`V;-T*WNtG9H~#v<-R^y_?JI!`4msE&sYHZx*zr&baRlAg?vMuxsjO>G;2Kth5nBC!Hf}k%FwEKsy5)Uw$VSngzU+;l_Px z=ZC$?nnUXAPA+0d?mm*ui27@mtyWjyvTDAAptDKBCCwn+vRE^!QHHk-tyao90iLW* zz)orPg7O_v^r&-C(8mK3F}P&N<^oFk=4R5DLa=wFY~Y~5Zf2soFM8$+ce${gKo%## zdqp*vY`RB~ZEp6c(s$y@*>l^9WF_NU>5SjCzA4fgSKy?8Z+%K0#+|gv5BtDbR#n(j z$d=O8H^H()?U{PBD{=DLcS&iPJlRlFHe=Vj@wCcSBh5=gxAEp8P?UpV2?jfGjVHZN zLc8n-gXPH6yNZEZ8YvLbZhI;U=p5KxVEt55$bGrHv|E-+Ep+pZBTSqIEOb|h364!7 zkCaDKQu@Tx=!M8dQXqcibJ5-VeR%vr&5`e8$75C`$@Le}&#+Y|e2#VSuN$omfd^Bf zB`|Wx)5M19P|Sp-z?QAoOqep$*Upw`SkN$(F&uR+5a zeUh1oLAQK7l_f-wak?FA9=eUGqfhBr%rRt&^C)Cl0p-gJk-V*wrW^k04C<{xvS)13bB?hLa=rb znagL#_0`uLSH!AyX{>PwMCz6U2dtgl*FxjRFUFgno!f~{Ub8*30^t~ggmYx4P_iHh z{u1(`>a)v34vp!#k7kv2OpR?-N!CSg87eM(ipwQN1WMaRjqMBSC6c-!tPA@F@!8j{ z`0yl2Rv5$gn6Q=>P!Hqxou15LDk)b={|F3vQpR(>;%e$-&U%o}u);Sld|P05+Ai>@ zo5!7j*eW*I8V1YSTko4~MKCcSYDPRY7wkCLYOI&@5&GAxANw{~WGAESn~8<<#Ef~# z$Y?QdvHR#yvA#wdb+%|%GKAKtqG}*c897$v(fWEy#SQ|o_3P*TPlco6-K3)3Ejo$= z@2OPL1I0aULOJB`q1y>*tQqzXxk=mKStWD1zEhiuh(w{-P7ZL-hWyezDJsuiw8il0 zzV<`757I40+@gRmvqxLLS3$<3nS=5SA(=P1nfNe;lalVPZ}kl_)^&IX_db;3J8{Jm zZEU8fvC+R5~(pPc33* zgY*`h6wZN!SAPz>lC^cXOU{T`cS>tY1GP!xL5Oo#nP#IJ`96(^)>0fth~EZVo*A zCp^}iqBO!FSnxgxnIT21;PMsHhN2Vc2t%D&%F^TD#xas1Xwl7#%O5s$;Yupf`F1X? z$Sq>W)Rgg+4|6t#8k~XDi`ZZ}nS)l_52=(cLPflc<}qW2_yBvDATr2#w!B=XM924O zRHU`rZ55*ehc~qLma_e=wRL>(*KX5^ABQ!otjMaKXR_|~MabsTCVjfX%1_TvuB9Iw z^%84;&3s9U6$5{zC0F2*cfV6YKrOH5by>6Rdum#Dw3n_e$Cuh4JEec_R2E3=7ukc@PVK}ZLohx*GPl4$B)iZaZbvO%yp+Snf)_+9M z%`IpeFbapMRp`hLd0X?m*24?e@YBmzp_W+|VQ_q&qc33Saxh$#`^Zi_q2Ij;>Q~8nzoRyHhs>@0?wp@iO zEj!5qqOl)6CJZcW-|*qZyrZce;&$Psgya_0FqZp_2y=o` zJcAH%hS|D4vq67_hex&w}jY6>G)t0}MDEYr04f&A33Kf)B@DJCMx{y|y>+tv@2s4t9N$TgnvL zX6r5eFybh#o^c7L5a>F2Zq+yOfxx)PrELRHAp4fu02{TH>XBHna9*}aqaPUtg;T%O z6xR7rNE7mu@wb1^MF0Pti5XFtg;X}`*N3uPY~YwaM?=2ipENSHt1Bhz)!!NXt=W(u zVlOnWE^d)Wr>bwuqQJEKomgX2qYJk?)|oZ?zT3UiO>JuZJPv=N3mJ8-%HA;M$I-uF zfP6ru|ImrR!dJv{$D#OSZlnYZAhh)k2va*GuX$a(<*)FnHX$~4YV5k|YeK{6?>oml z)lxRoNurB^40F$&E)Qtd`lw{oGJdPgC;OY6ee5%oR}zCy(1K;>aN+QJP`GTA*yGbKkQ@FKq0vBV;(UDRH+d@--1DJI z3ldy2(8wUMeSC8vV86Uj8IWrE5GbvGPP_-4LEk%Lr}6?0%_f&9VpVsQ%ubi8ZcJ?LAz!nzP9a@NEoFL{?vL)>ffl~mwkcCWk~ z3dBh)cHV{zOS}ys;eXk;CyT$~?;LVD;dQ)v|$ zss$eJUkN|$i*nZF+C73}OVYTe-X+v?He7$Y4i3e{A^6+g{C)Dd|9kQt6bKNZ2{&mp zr9DzWzbB|7wV8eWxE{kQ-IM@}1{kP}cLXEHQkMg~^Wzx<;$J+-@?Kej_W9t!9(lIR@8bsjh1;?HEkcCYFF3r0f$n&ZSxzQq%5UITTf^p&=R{P=+bWLdcv}vc`el6=)f0VkmTXuM2nkw=8d1(J_b6*o z*1t5N;5vgN1!F@n_pqAcWeaCny+6K%k{bYoMgSk1IDRdN2+eMZb4;rW$H9-+b;Kf4 zCIw=t79&k9DuL!s?I@pvgA^tlY?Gxhmh_te!pwWvHp_-$4R^c3xX7Z#hr_KZX-$fa zH9kV$ZfNKgUW=W)p1K8*n=TR-X*hC}Q`K!6*|OvA^fs{)uQRSH87w}_W~$`5IXjtL zI}-I^;@JFmR$wKsQNd4(!mJJEI(!gt@{RB*RK|@Oy8EeQG-X`+Mtf$ivCRCP)LQT7 zTs*3d{7X{9c=FG{>Q2`Udi~K|p+%YnxcH7;ks5Vb-(Yo;PSz}N7s;C$+VO=sZfx?_ zGC|Ie{*ib9#f{yMzT039S#rE>Abj9ISU5U015@lQ78mpMD?B6LV3?fGqTl1XKB}#>1)5V78R~ zhU3w%{>KX~tJ6PS%|4yVeRm$f z5(MsD4hGF*5cK%TA1J|+Ycnk{%H|HPp)nWEMbahH=zBRRrl0R9w8?<2273%93ny1V z0}jWQ;_`A5iV(+Q8~?!RHHDyGaK^2DVWU+AONyq&F}CK6Y-OKXX*rNSI7BLxibjG! z!)TGs^lTqXi@zV^(ZSt&&q(%-7SL?H&!Y_fW2uW@4=X@OJi?TC>-@^REj@>McQ+7zpY+nQs*!5siWf#z>p6)B_$PV zb|&vrmlJJmclmDVww7}_un$J$=A+^XC(+CcXA2ogR&;YMx9+L{ZCuYs`g(e~Tt3|n za|TI%92O*-vm5d-cu1pT!JW==Q9)i541R+oYU_q~(wif`JtI~cn|av8x^>Z&v1|A3 z)Av_+94yVZ%t@Kzc}|0+ew!1H0XMa);3L-YBz}q8U~o7n1jDh{ibryV^v z?+&L0sK7!}Db+t>pWE*;XiU6etNqEf{8Uv4n%S`bJk@XPps3=6+pfn?hfWa-K!(es3<-@cHxr=ngdB6Y#G+BVR%&b~x00DNj`*GOb z_@H+cR>b9>X`U(`2*o4H>!Utv;rvo~Im!4Icc#bz7@qXlyVd^%HTd-)%F?If9^+0| z-$tpa4_8tJ+P_TWkRf^~k{gN#M_5Lp-uu?kLLkQCFu)PXAQof5e>5@jZgBx82=$Sj zJc#^#(qgnEqaFo<%KJWA5-p6rXTH_E)^e*P0)fkYie4df(QKhe%fvX#1(RZU(lF+3 zkk*qS?-(hQu|lnAxA`A#xj{m4Tnsad6VDoIpAmM_n?SvR;Kx*F{*&;bcbmnlD8Rb` z6cJjlz3a=KMl%?t-%FLA@C^}tVibK6|ZG*DU^{k$Ye-*0l5QS6lX3~ zUyzMY9EEtQPGNws6Wn?w&t*XNgi@dggRLs^R)B@}*nOYI4>O<1 z(~LfvnjjVp^BUX?ynM9Dz~TGfrfZLq)%nyv@VVJ}WnkS72UlIKUDlVw=S_0Bp36zA zvIHIS!V3479XHI9FTne#az$_E^C4d?31Ev*Ybv)eP95+>dSh9G#Ah4je`DM>A}wsL zqJRHWGxTZgXThN#wD3^eP*m4EM)u=p$;YUi#|YG49dq-LtT75%(P`bKN|j*x>k0%0 z*T+Rh7MQgQK}8o34S1i1@bseTw`7IIsQ-v$#?FV>t#PV#A{P)op3rCQ_;6=jVfRb$ zKvp?khItGTgB?cpZMgsfCGJ@-^vp|Mc@{-WD*!aLtSlOD^BfategmDe*TWx=sE5wqi8LccpiHVC__qyX_-Fg$ai1+Yp*E%N!1u zmUQ8w8B)6M&g8wdGpUsyf6#rA;oe^~=nl8;)%#h!ct@gRds$b)VSl4RkAD7HZl6c~ z@0dEvN-^_f#nChVhA z=(LVGy??*NG`s)V>6uqf3V2K;Gb+e^d@#Qs*RVMD5u9}S^QLB0`fTW4s+rtZ350o$ z1Ool%2LoL2oG$aD1kzcv*)$q6)YI!!oYY3{*tH>?B;B1w&}ta^RMRz;8nEcc9Ybm5 zI;QBpb~IrdR7CGpKZospt|HOIc&$Bn^n&{_iR4s`cJZp7O*N?u~uXPD@%o z95X58c|&nW2bROxHmM^Lt(lOh8SOF632y21Xj;Sm zj_!UNFO=iuZqa6Xvilc;udRGyM7WRoqQxc|VjYTo<%&S)=4bmDnv%5=Ekqdk?obKn z?aQj9TfqMa4gn4^nb{lEK(~WiQy8d=>65u%;64#k)Y2uFv`Jf31dW6jQNNd?bji?;(qtvRur zWVcK*MTa?Y#~Kw_i?$eVuxD;DVPzy~L8m59&``P(Sr9Tqs%6Co$AY&LAG|QA`Jk~P z8d2cRla@cRk80qp?D)mZl7Ws%c<-WExhyN5XSb8aTU-#(;w^_phV_>JgQo5N2TcP+ z%w&kmZ&wAOFDviA)=tZ}D0phg>*?Agf+J}3Z0V~NvG-++)wHx()wR-WD^n#-+bx{8KinxOGnj;__V*XMrl9a@Tg}ux3F+wMIhMx^S)b+9LS>O;@rrGFvgZ3 z*oZs^XX&1NFeezoOOgJ5hSQ!>k{X`@2Ew@D!9a+$Mvpy=i)(KPGzg0^&lC*Jyr_!U zvNrXZ=ny9MyKT+cTen% zS*FZE+F%SqSPM@B*WV}>4BAUb^-^+XB-hyZ3aa#Z2aR99FMwsm)s6JV zUUeA_qmV8MbSq-r&>-G>57Cq>fA*Y(dM<{F-CA~;i9U*dm%thwbQ>*$$sQ8U8JBxi82LYv+E+8ujlKV3Gurz9^S*jJ?^87@uLOndY%rZW%V6o{QzT6J_Hi z0?vT7uO5jsFd#4!&lhUCx9td7T1L018M;~hWM3K8e33}HrlpBNUhE=?rz9i}jo%Ocz;= zBQ_`QWa>-DblqOe$nYfC#fzq)m0Y8h@JX^HFKfYvbK5Qc!C1i3z+fRU6mJa<5w%4m zqeKJu*UFLPP1P6sX=z{a=>)tsZEMraYml$pU8|~4Kzp`}zf$XTQFPPhQGc4dc)~8j zFM!Yb7uTKtUnOdUVK^TI4hXDoZus?7=@z`vFIbT|ll<$0 zAZKp~7Cz+z94$}_x)3GaJV1@T1`gshuKc6|CnIs~5`6(c!zj8UuH+eVvc=a4X^Q^G zJb6)irW|xua_E|z1fjn~fdYg^AgG)SffG+UI5lP+!rh-9^xCHSkl$O`F;Tj=<5;Qe zO|8T6b;3H?p~*aP+*f+T8kSoS zt7r;RKwH#7YfjKfpY~lS{29sj$1rlaw`gPFw=7ZzIm3aW!U`_mj*o5RYo#FAY!}m} zm5>1o6~+x%BriYQeB`LgXL=VbsMgYekveOvf&IU_i}F%Nn#a#LP;6-{hRXd@O3_yPjzMU8+f?(j;qvP zx~w(K9B}BvAz&~T`@+L1acW&_w_5F`7Z5#uYo^_t{0lx9Oh;zIoO{W#gxbf}I0M%u zhN?9vIFUG7{hpxmV%-uRKm&TBk>ic28k2TTzO4b~^Fnpj0Yh_FlV7$=WCC-GbF*i( zF5U28(%+UM9-`062N_){b3?*_QWyXRkz6iDLKLMulPd(Wk4l;ejYf#Uy-{v40;2{=6i#6n7CWU{n?`^RXW}$MQSK%?7FI?CQOfXA;t#~+~T>as$UxHr@F%`+$VCglct_5(imCX(|B z12R=FnV(0Mlui-3V_ALB1`868W~^)$GFLRkSl%M^R;(|2;S6Zs4ABOw1fMhmxTf}( zUXTLXe3ggY?p|i$mFCb4M5y0bn5Uh^P6r7QL*aVNogsHO4klW| z5T2TWml3{bu^scwTF*!Nw`ki|2H|g<3tPGv4T8n3)dkyM;_il3X39#RX6%%bfWecP z6DTdqvZu{25;;LZ%R?tBf0Sf=3K>jtM3mqY#vI#v|6ti4VEs<^QlBTQxxE1BK$GMO z{xRs&Eg%DCW%nYhUiK(XwT_qVAL6s}-x@WG=xjLgoOe8FIAGvu%1_07j#)0udE5WS zMb=XHo)FI$5i_6L5cAaa*>LdPUavS#-Lm8M8Zqz!R@@AsF+ zk|N|7Srrf)DINz>t}E>D?wiQJ=gMu?FoIW{VNmCCqNDO)7d%R_vjt2uK0Y2!Gzm=e zd=3nA!^qv0DYYJMyfEzC9ydM3q~}|0r)P@W?4I|5jf=(M04lyE4%#~aJ`Mz8g;@ht z3pCa$i1jSvc*0=g@NYBsJ4DD+lu5rkJ2xNcG+;Tf%pQqp2$T7y?HL$sSw89Ot=hHq zq)k*_Xrh*_Duq0(z~wA#1G%grdt%ln>?{smMq{4PZY(GW-0_?P5S<~A7An84sVa&Z#3^00W$LF7(7qc zMuvygBxt~kmZ7mRT+L)|oye}9RiI^4r8}EB12B0wX>Au5J{uz`0&&%IgMMqh7YtW= zZ9DtY$ceHknQfEUE1NmGl=ZHdryhhkuRXEwqF3RWW$}|>Y=k9;vzn67{_Mf_*U0YxJOO^(HVjhb7MvT3-k09^dQ8iQtsyBp~CDh7Q203L@!F4w7!MoSLC z+u$$X8E{(q!*%%~i-`2c+rjl@D2`A7nL`O&2Cc%HgRU~bD_)v#C8IOGQ+KWYok?dX z2l7LmMFTSAD1qND?6jY;n|c?XvF`^{*e;b5{Di~$bW{L|x)r;h@wBMG^CoL4hwvd;CvdgOJZm(KFn5jy|K1a4vA)|edM;p(_h|Y}U^9IeL!s4}!!z*?(nB_bklJGT8@mmBPCBWX*4Ol?L%HnC<#>qG0+lR0oC5K#{?b_2 z@Jpb*p63<<0LGf5$VMrHt5 z-jwP8GF5d^Q^A_|+dxCr!(UX8_Yjl?&EWQF?I4Kzz{Dg(6TkY#FLO!YZ+|RPzIU7= zJbpA8=zW!9?_2B`80Z#p)jfN$b#USsPY!z}j^A?h5)CpyFG4{e2nnu%D?_AA013XV zb(6eYS}X>+%*>jtBd%VpxrUoP^Bkc4joFrC;x<{uQ9cTq%je0;k^%;K32C;;1K$vk zVBcRv5?O^~{s{+R2t?7!ryg5f7Uve?pZqxB_z|X&I)_%r^GUNZXvszPT1WTb|HuF& zKE2jxRLSL>lkwRh*ScvVWr-t&zj)pA!tpNX*Mw{EzC~uM0B(*-{UKl{`v}NhQ4ED; z)6CtILu+~@E+@v3Rsbh=7oMO(12XxJd^20`BPi0FY$7 z;Ex~fYy$4c3K5FqtAWtB!u4L(x2(Nw8?2}-$hM9D_=~ucv9N_YT3t|Vy;;S29v|LG zlP;87^H5G|whMd)(p&Hk5dahetpbA1hyncH&XA|(UnJQS(U!u!G!N1CB?41=@uY@A z1g|@L00oAi9R0g5x3b^OkwaS9UY^DL?c(SB{{=eSy=ELCPth;mKQlhJ-+E8h8*5!; z7RkvGcE_+c%pO>;YQ3oXRJ3at*`Omn@V%aZpGzUcU;roi2lJGGIywo5_46#l=pK8% zP7^#Vi7$=5eCb^wFNO9GVO~V+tK<)o6l4g-L{*|ZV?O6hqLm7O|AP~R?S6g7>ASlH zC8tMZ9Rg9lp!T5PshO#vC_xP)VgUmx z1qnJQw45+dqVJv3ji22zUoe-=(qMwp=Nt;P3Y2Plo#F_kn7SldCIT<|V)Mi{m`sI$ zPReKH3f)G?+=HKup&i}YVMhul7^_ngvavoj7gJ~6 zHq_uL2%4&FyGk)4qC-s(1!S$l73}SI-1_COBb=w#xFmo-4F3rNj_bx`hs70rA=vQ; z_gpYWC@pDEtCMFL=IXR&zuOhFa~kZw2dxSA7sO$ZY-LDLf%9riQF$xp|-#-ag$>I z{7f1$?>>L5d%H;2!I!L7^%{VQWBT;B_UzppPB8(ju+Rrf`wagaQy2d`rT`%tSDYs8 z_SfNA|EhyeI(i+uLyvwJ*HZ*B!F{;?(g-$WnBVlpWq}|N0lZ@Z2(bZw8P4{)J_y+b zZE;^^MrtR7X=B#)nQ00NY2+_IQ3ne)D+#?k6ksnyVfB%cH^VZ9>7RG&?2=QW&XDPK8oW<1{FNfTMc0XXK zbBqNUyc%MvY|081gUT_7T2$a;!7Bl9C8Bh30DHsk$t!idlG~`*gV@orQU+(Yercs( z>!+ct63^pkR~h!|fwk|A;@J=KH1#_A{TPh8yUyHwn8LQb9|zD8_GKq&|FG}=42&m3 zb=N~A4HxMy#z`KGOr`1LBon$9g!gg;W`v(lr(q*pq^x|YcH^2F$47oW9k>`&NAs%A z#tr)z;phHN{_Z0J7An3Ytk=uTys2o*eS5_tPTB(kc>p7Jf)|(-Dx*Uck|30zaLtbr zYYrGt7HA2)vWskyg|0e=VMCyJ05wI=Sz&K)GwE3C*fDqhLEadw9)~(RTUz=NVJOi z1mZcEJ}&)z__QVGyIi4mryjsFE%}{a8*X*KEoC?Bn^)~xT9@p4(Ins&DHR2A&R$&N z-G~6)HYANWaa70W=I=d5$%07K=jg?v+MlIUt5#=O$$?)QSkpw6GI#-Kjj0VV8TA>fxb8DvYWFsw6 zdCJQT2yZloMvjf>HCsM6;ARQm4E4*BMeD4bfzHrSWAmSw(1kjgGZwkvOhnbTO0U!%6E zv>H4zK!&m9ikY-gB_kOb^nCw)%6^yWyr&in_MKiHzv#CL68R$8MDsdnpJ^|K6BR6|kL4Qm3j;nRA&UVfW{=%HBPHNv_`{aF)*y+`0I< z0Y!9Q?}2P(VvZuSln)Bl?xGstttLHL$YZX_3R$_M%4n$Fg~Lb^YGfY@(E81v!PboX zXcIv&xEwqQ%E$`gsFxUUe%R6_()c{!D=hRGp^{*z=NRqXwXLt=e$9!YWS1`&VG4U? zKb7GZZV<>vTvaZia3cFB5739b2iY3zocl_|dC$oqa^xCQxqqLWiIKPTw&w#lHt$J$ z>ITi-mQ*l*J z-FWv9MD(Xd`-7v(F?v#%g1;dw6eO>DREI(qoHrb~Og|~>g)&=KkBKvqyRRQt?j&$} zB+2|HdP3O+ALH~np5a~up`@P%0jC9DuGnL)c>V&)KX@f|y{3&f75!9%#}{UeQ`dib zryX2yKY-$>x% zR9=z`__6`;r}us_B2d<$0aU6VM|@zwOwM8hHU91H?@a}cf5W>M+!dt*2(x|Bdvs}* zot7(mN?Zf!4syW_at*c;QQtUD)rwLv$P@X;HxBV}Wo$03#&+2jMYcz8X%tjX?z0JC zJAr)g*KFDQgNM5_4!(_75?k*@>xXeqOxLIXubtXdS$HgSE$^}QEU@V}Ye{>8KGD&H zg4x`;YusA;*fuzb#`1X{x%Lv-x@8C)d-(|ZhB4ZWCNpDK>mKU=vGv~JRLA}Q_&LXM zhIvp{#<8+T%3cS@NM&!*NyFwCWpnI9vPt$1DYNVuS(Q;_JK4(42;ukczVFXJzw7&_ zf6jH`b)Ms~o)HeFscwY#=$iXwCJ8XrBMX5^YVeQGzy!dNc6s#Xxd0wO`3jnfnr_=c zO}}Z&TO>F23RLq_0jEFZ{uw`zZ`9W!zU&au8`qb@hM&admhcDs38y%WHsvdt{=t4uH)c`9W-);7suSQWOqSuJim#-Q^8NM!ydw?F{#{W%rsbs64^qA89{LRJS+9?nJGiff==qqeWnaj`egY3yd^C1_x z9{&{)VVS|Y1(m@wxlFTIJGta*F=9RH5>;jD7!jeP4>cm}1Z38p;~RIrXRHz;EMpq8 z>QLy-eEjopTnCRWAC#;r&u`XTR`z$21+XxOM;J1NqCQ1y5(wUh|8_$*0sI!Jqcik9 z#5^Iav`yn;?`Xzt-g6Uq#b}ur&oTcv`0o*X^T(o#cNhrLz>xq#fi88g(jeGmt6->M zFa-^FfSVK;kbYy7*O>Q0W@rp@;DN@XgovkxrByUb>QKf$>X27Z4th{5pabnn@K6*8 zd-DEO`78c9?!8F#!h3P4GO3qu++*&3`>M$&E2D#{6B;<)!CLsDp+d$iAFR$)Q`Z$g z_^;K0FmXG&@-7xHYqm424LOo%EJozWZ!7~y<3*lJ^TF}%=EI!-oBUmnB11#d>k93- ze>kIb;I@*}5@a3!>o7$Bm^&!_XYxDiM>0m`#AVS|4Q7(Om&vzMpF|E*->p`1zB250 zjm>Vt-{qR@gPT7<l>rZD`1oZFN_{~dYaaWmrn-6+jwXsT=u_gAYr}eND{M0z z+TJZcav#yQl23MMuf`4m3OLfB61zyqOyHJms8Z7d#7@`)>fGyUSZtB{4eom44T@!% zcV4%_nuz=?5^_17`xB&cK>y74HqVdDCca0q{=HI;;$|pm=CtYpiC4{Ee!7@EY)EGs z&9uCsD39I^(?>hpCUUAO|vLQ2AeZh;b$tz%elc#|RNiv>4UEtq5kWsl`?~?mIN~iU3rM<0ZOPQOx#Wr;f2B)EC1|w;x z0_MTs(;yo2JhGq6lkpkJC@xb5h}sL%SVv_PA@4yp$znGMxZJkyPEbdT3?H-(n9jDF zZ0ntJLDg}g*IA)ZJ~@YS2kt)gmvgh!8D(%KHoR&{9VpJ)GkohG4_lTi@BanK8IqzX^w7ZG;dR<4x2`ys-*kpzPjr0coj2Ebl&9`sa4CKF z;-3Y7`}3+ZBcMwC9M2g#M-V1T$d7`V#W946N0kF`x>%inDp*0WI2n z7SZ9{rrfsPs4N=pO9Ht!<=na?)V{#a)P!s?^1wxu8A4*16gvcx1fdUgWa7FY%V%D= ztY%ZTb7k1XeGh!Mlonpm@p2KqLW=m4lP3J#1NNf>ag2HJdFJ~@&lL}h|7c?C1Gr|! z1Ne9xJDO*~iJ!9}O>ROBr|Lpof|>+9T}w$*!Ep-fSh$nH8Gge#@hvVTxP)qN1GL$2}$Nenk#r{B{-B zgsE>I-hB6*9(9=xkfoE%IybG3MzXDF&v?c+y)2Y|3Ukz=|#$qK{Z81U>6 z#A{R@f(R&UUx^_L3K@-*Cf zp|>M8b>|05zA-;9wwW8|R%GlnRp^;&l7_m_2B=cr|D#HO-6H+Ei7{SU(0fwR6w6@z z=B?=Fwi1zCB7E`LpDd+1*0VAcDvpbdAMk1-9^*J`z}B6tSzK3DQ!wGz z$8BS9DMUI{_!GmkP-qGlm6>!Q@K0bCiBunZRvzSkoPsq@#eNM_Mg#M-D`4iJOA_faHMjxios5Q?*<%VE{tZMvF&674bL{_D1pL~ zQ8e6W?tplADRVRx*p5V~l(b%ayzUa!f<8$}7`?(LcjXmUwz~xl2qVJ* z^ed-!H(~4b@N*x3daRTd*Kn3J3qf`s$D1(8(%7D!aIP^Y~ zkA@!)aQBgpCSbXQ&NU>lD{u)I<@(E=-l#e}H2IWt)pCdoQc3QimA{9a6>-%4U?jN( z?tHl#jQ8NKD7KMsCf#SD*VEwzVtRvq&v2mD6yzm~dZVMZV&A*#o!KFdF$FG-t?w$PaoehPLRCFgBDfCOpvB zdIkQolZ-o45Y~b+8Opc>GatSpvR==mPEIu@h@9*Yi-336pRb>~eBTuRKA%`Oc=tc= z?ws;J)KZsN=i;8qx`#t2HunKx;Uo1sb99L=f!5IQFDsAYUU1Mv3%KWJ zC0XA4#-zXQp{WgUo7x;2ADupuECtUvJyh+*b=V>R>Ca9?THJRsTu5M>e(e?v*|w75 z180-P+fTL!?b4mToaTHV7fp;;qex+5es61}z~+7rP3Z zgY9RTcB!2?ogB?UY}pJ>zU#4023G@+TXmn#=Mo5`-iCn#w8i?j_f&kQ5#W5|p%bBd z36UI0AA7c4^?zg}y1Lw4ghJGc<$slDklU1|`m0TJ3qgP}?UTN63?ab}>B%A$+Nt&j zjk+ukT%+tJ15_VYnZ7j~6ytwS#U4u(X!!Z6Xl zZSwNs;A)N%kg#kJ=z7&tNY^#kcA~ec-CTsAmsR=w2xv;^JzO+K5*0Z@(@7w7G}oAE zQb}M44cbHk^R4y&0ta-=}%~e@}DI{g|^Z|6#I`)>gHa%C60{ znW8iK$l0xz%9$Ip#XrF+PKEw%bvdDR(QlQ6!TG!ZD;n@Cwa5BpW$md~Yl`_64p?KP zUHIYL-hdXqEJuVBYd6GAuNYD_|JVhN75B-~>!dJ`Dtg$AIv8r>w&yPfa*nKxIqaZhpgL zNJ)%-p|sU@IwjE2@!R=dO0!9=wg#E3Y(NMd3PzKBB%M{`tWPXtfdfhKtlO?jJU{~!=PWCK}f$GU;uJSAeT<@ zReGs!-Aw({Iqz<$XEOo`n0(WKW2Fy`7Abmm03xZ1Gs!sUod7-Z`1RPsNk4Oz>Oo6y z*UNFbFC4izvF2W94eNG6Bd&j6=DuLkh9kW1CW$co*IZ?B^Pdy+Z2P$K>GpP8;B5Pq zC^&-X)H|%O30oabeoy76qKjEu*k}ta_Rgy>NNX+sUehdpv9hQ4RPIG}^4SQ7Q}NZh z{&DR;t8A5k`S`J=Rh{m-GbY1w-c|h;d91#PE5dSCvpChH#OOyHC4X-PT^9!r8)i#@ zif(WL-gklOlBl~X684k%hJ>2N&u7+jZ8XIS5uEm&3Dq$5yr^M|AL1j#b21|Wc?VTi z^`X)Z#PdtKLJ2^ZW}mvdt=w1{@r+n5=S-lDyDYmShaovbq1#X@)z!P)!7Fi66N7Lv zP0KH+&=4*#Ut}!Fq-+oZ=3)c5+YV}qjS5@wA&!sLlPJLA~=1Z<#L>_H_@^2fSYqb`fKyov?JyDr}KwO1GX)< zOx?yd(a7%2W##Y&m!dIaf*9FMK4D}2@ z8)|ei*jnS%a*Z4Y7xr;zwAxJ(FpPi&6d6f^I$mmu*34k=F)59rf*AxQkom_FKyU^S zY@3D~1oJxlob%)$wsobkcjJ|tvaOz}=V5QI5^nossT0o#hJeUQsqrM9`7Mh|IP1hB z8y=)@&MC5no}n3n6M=Onlw7@CQW#D6#XBY@Q2lb7sWVTCmkb!;Ix@hdoh+2_Do`O? zcKdk`VLDob0tgg6-mcqjn1R|P!FXZ-wP0Zl7w-}_r?NoRAY|pKJI_S4IugIM0 z{|w4_-&>@&^M9oP>)QF37~trjFp;c~e0n6e-iC#~rC76Q)w|46P1V0mlngt3I(i}K ztK|c@XBhY7GL5vg-rV&|K7s!hEp_nXP$mx)>WMbvJ)wO*10Z928BiX8DYuFzba>vk z2v<`tXOYa0=`Auo(FBN4Q3-nzb@deS2NY`F*_=#GB3oIL?WMpW5C5``lOq50Wo;FO z#ZOs|CA{{A#uJ5et$W0ncInL|G;sZKIzsBou z-Kc0efYl1S)}Ei$Po@dGje6J75CKca#9R8zX4qbNI2|}*f2^ESNeEZKN}9|(Q9pA- z5NE~&q-($G$V+9LJ$L0R=w`4r#RAY)sI1#Ir`w>cuF2L2tj~jR0Foq;2n@QA$CX32~Gn7vgYqhOPki1SqBdU9huxin*Hs6 zaX77@{N&AN@iI+cN&-QIC6GoI<66ODahfy4!Y(1?K? zNoVsKa1}=jx?HpN?Pr`nrs820BQYwa=KAYD9@7p9G?SUI z*qW^T{wv(q6YD=ovs{IZ7#Vy(U3)z`Gxs=^($Taiw4Tr5xyEwz9wR?`V1N>^)Xg~~ zgi%4NF^d5Gy5pFT7z$g*n3Lb+Z3d4H7lQ<|5vbd(c`9u_~NQ2?=YE&n*d7~ff5 z+=c;iOhC-Qm+e^Jh}~_5_hqwiG2q3HN!_!^h*~n}cB{B058_<|Av!>~P6b?Oqm=;L z4O#PqH&>^tR<|>2j{XiRI4j<(r&XyMoICCzd%D(FC8hnFy#ta*SRtXgY2`aWcoHa7 zZWTmM$fs!gd+G{T8^(?_!aw%lkO%NgQvsNi= z)8I7F1=7W!^T;oS(aYK3m^O4M@DjJG?(8`gJ@kb$5hsTArDTXT6->uzkk?IT6#WDZN<-)HN7kOcI%wu(@Oah z1x2fCBE0T{+JBZ$uKCSUKOM}Rw|yX^Vzo3_>c;=zf7{S4WY%*6dU6S`{NPlJLfRC= z^u!V{<{cAOAFPyz?(1g~611=YAAW0)VbsdRhSpO_&WfHU&!m`zap?@>;mDfLgo98q z+-L!~g(gdf%4`C`s}#-?QU3VG(oAIwZOSH%XP1yI6e?FBVIv|8G}@TRa{feqiWZP^ z>8GI|sax^xRJL!e_`&*HMu+#&Pi=PdB9$$?{m***A@CM6Z~QlbaYNp8K})oxee$Qszo17{@A009hQ*0NJ$y=y;uIdEAp~C^pEx@1v{Vd!^O+bN?&B< z<;~aqp0h=jKjJCjd!98{8EdgF166zw@KLRMA`uTgeI3 zfDi!>?cm`S-wV-P)5xu0(Zy8t`l^ybBTWh=pdAJDr+{`r8PKA*^JVVAj;oBe(sL*a zv>0hcIM*>WZ}`xEJrAM=$Bk2CeRx<*ej5NsQB)OYo;a>o^b0iZ{S(Q$1pl}t&#CkH zzYd*-EZ)3eH>;Stf4k#b=5&KrlAX@?z<`WBlUbiCc;D?}fZVd=qnSUZsD4Gn--YSv z8aLYVWG4FD=5@E4oUTAoCF79?^e{R$VTmItS%^i=B$B&Aag5-@Y>MFu#{pKcRFq^MLbwC+bc6J>PZ7zbbp590ip& z{4VNMYIF7+f-=GhJ{jbHshgC-Rf}cq;sbX-qy;a~TuE23WpS7?nxy;H{v%a7;B3%X z#)!$*3S3KBz(I*YVMT7AG_StayJ$p&Vcfy=;;_$m5NOQX*IqRojC^z^f;o0 zv&t24#|=D*#HQ4^>eD^wNy_)dV(?CI6q=U)s=O@z7CEB3Q~jvV4ost~Y0o+?+U#N5 zAGBc-wb^`hyIk$z)cN7g)6;C)dHXTe3sQDG0VFYOkmru+sxp&{r6kZ_g97k2VUm6&ENfpB0EOuI1545 z=LO7^cL8#FDFD68%O`6!Dh)h+G^*6=ha4xGp#yw{3ytPgYrIo-q95vc8ECw4{)Sbu z32+N(Df9tClb@?0^4((O)<{3>?W^A=4_hmF`lg#7Ksh^^0Ev?qVIu#Obp{jV15)|Z zbO(cA!BHT*L4 z`?na_*o~Ur1aMu`=*5@qz^ZF^uHYx6*Wh(>O(Ag2$TXP5>4INRZ}Tc~3Bt!8u=ylX zbNb{{ZvZ?2o6gU_n{6t8m70%^g8Y}e0Bzu)Ms@Bt2>R5Jt>pf-}CYVY!I{}&zuDYp8+T}LH+JlYJ)}O?Qs`+ZgWYXI{x^2 z0?Z=hA5_E0$U)@(0mtMb07oVSIGfoTN*-B7zCIKkQhwC#^XvY!AZ3SLE_1z1>xHkW zv2PXQN-hY+#w&{>7~@et2qD+nXmKEsKXTg?pdZt$15p^s#hXrYNst2}5) zZZ_EEm}?C!O}uuS`2|v6(5Iyl9@ei4ldRk7bKp?9GF=hT>F9-;&mScpp4u@VwVs^1 z%sU04cSDJkd@uR`rH4#vlRnLQPxPKIzrJiad@M75@#j8Y^o&$OZT7%!%KE^IRKgNl!?DggITCfK^+4DSsBdqXD9%O9EL}41uK+W@z0Zw$H4M zM=tii#d=^O6!izko#WqJ7wWy!P%Qtbq1p34&+`8N>sk7&j5KfV>RmK$W}K8yI0j0w zwNAzO-@~1{ozVzW#YHDzeD9ho@Pn(ZEUJ5p{!Tq=#P>tfi~O#~kef_#Ek7d=O4qD5 z>wZ4HrEycr;a(aw5P3JU6e?JMJL@oo9aHXz4)A@|ZMy4a+Ry*UAJx~i%{~=89^DT9u5{Qk z`}OI3|NYYu+L_16*SYo6!0qs|E+N~8uj)`OWIZ7W%Tn`~2?}yIKSFD=ml>ShBzSKi9PeoodKZrRXZfNzYw3ohPOl4X~2M(sd3p`)r|LizV#(nuM_T5dUZGClfvi@S;%vI~Qnz3wJtXx( z*dOLx1x;L$Ck&AwNoyX`xCpO?%frP14I62blnyS`^ad~@xm!!vTpSx0v-u@%>g+RQ(N;Uz5T^C>V&5yuj$1Cm_7B8pbIdGM;>4S%NORtWVTg`c9 z+t;@(yAfBoIsTgc%kawh7C!UmuwbWXr3ryZGExO@n}t2qEPk&myu=LzWbCp^sPV-< zN~8nkf^o|I{iL6MFN{~Gk%&VzX{U}JBW|?5tGgNmMNdB1engXi1Wb!*$4y3vaa@S z5%yk!ga6qj1aYT7Lrt2aJaRN^B)cFNAjMy3O8!qx#~V@!C=C7+2x_AIs=;}j53Ow) z@cfcaMfcU*{Qbv%T1CIlIDm*s!B^d1&(=Ch?+2ji-T{{qlP@}vQIve>v@ckWq?8|G zxzSzXKpxMPVF^Qwb^>L&nKCJNSHPH;?CKg#wP&A8WpVJ}gOnua*Zu zO4j;pjXnW^s75U%)G80}ilvP$w2TcY3Q>M6U5@0h@XjmzE93D2__j`wtV|+*lgzmq z99>fg=}bSWM2dlQ+8Uq>>0P?HI) zY97g7zzfH!E0srrqd`AFgh6^+_H-S1Z047k)MyJ{y`jhyTSF$^>m>`lK4l59wF##4h_?g=1~&Y<^7j8ixBjdoV(K}z&j-i( z(WBhxB6EN}A>&5B_dY-Sw6RuQG2mttH$4q_BfbLsAA9u5rVoZdmbT2ID#5u5d0S}J?0jaLh>pCkPWoF^?(H>U)^)dG*STlh(ZMdnyr z<#itnNLHZ7!YD}xLG&{z*O71Iy#wsd)I&;tuRs(2t#1T3U-{tDA38$hfP7uba71-t`XJh}C zvg8*HM|Tp*lmCi2836g}D*j>dcZY=jVlxKu%Er$D7w3@uXkl^=%7#0zE>%t1J%xm$ z&Q>-L>t4%Ky&f`dQS{Z|)ZqhAOxw$`f<7qt%F4t z2t$Rl$pCVPuM?i1NIa)&BK#c8(|&hqb;APk!srqTlj-_qkFuQc&Lr7Tsl6EQUK%=Z zC(}`im|hrd9l0*TT-Bc=6ev%m@PsRB`57c)kV*U6e58Ar9X*(9yKnB?AwN-?CSmnK zr8c|@l#as!mU4rfGFRN0UGj< zJ@^ff(4V9MggG$`Nq{(Ms*s`oOm&8lA!}FfsyuRQI4EsBA-O9U4FSizRO+W*am=_L zXcpxaFC6&7X*R$p^Xu>%RfjYaq*$rE&Lft_G^P9q5OB(w34yo+W(YUdH>H#S+mp;B z;xQ2jp7Aa|AGL&ZdN(A7+%yiA#-;TVw~N4F8EIx^RH`Z{Se|i9rAP>Ablld!*=owH zt{YF7*r2fcqbK{)zJ+rzFXi+Y0@D8jgl{QqN80+$WSYnR{AK^oFEhqC#D!bb_FOzE z0XNoId-=rho?}h^ad*u--CY}bm$s=Y_+uT@6dCBUT=R6F)c48XL5rKx03bHtN!N4P z`(ARzf9bBBJ6vEEZso%hv)5UND%l4^zF*!4O^-(ft=D~%jxeey(4sR=BigXbV5qJRc2(L8G-}93FTxfj*Ooz$qi&|E{HD0c{ zH;Vv+H=Kz+4X4Hx=0T&!9RBN^?cbjU4c0%UNwWAadfGP0yF&#eOIcld zz>6`U&~Uj4{g8x!-RngXJ)B{s=QX*O{*Jf(Rt;?K-c2Y9X%xlu^!L|PBAOoX%&*N~c&?EVv1<&{x1 z@-POHsyU0?>;KCFxardauy5yYsISbtHJo7yu)Hbndmq{z=P?=i&MeOC*h?YAgkq>Esj8lNsvej0eENmF!DszdjDgCO0 zLcM|k*m?O~1U5yZob1L9p(B(|^8PKiM1ET!j}_e&jo{}5Q4nH$ymvtOoh_28#)l$8 z@E;AWjn#XHd7GcUZ)5|abY=VyNAtG9P%{{f$Nb3SQijlH;qU}74uq-N4h)1qDj4Z? z)@%N)dtGb^$OK9sHINvZ#Us$;)i*^Hs_D9GN(f4Dp5@WI&c0(sbPYG4{i*XOueR&&*0;C(&VQKpKd%Be@{=;YuWc`P z{;|=$#_96fL@D|6)q~~Hr?{f;A1N#b{nIs0LtSZ!`$R7Vzd&sYwfqXPcX9AOc&GkgnP8UncuGR;$$=)MN|h?MsU2?WCX7`)a!UA+HeYgyAk zG}>ReIzt8_BljW1SSW=`pIMnK+x9wlz)+}icW$>n3nc@M_Q3n8@a~2dF3n4ElRPYU zQp%<4M3PjY-Fx4QR#Q)-lco<&$(O!m=xIWsNpckfrmu{_VSjb!=xP5gbszu5GPpZ9 zhx!8ZS?G&>DxYpwNI@Pp#UJ+jeXELFU`FPE$D_BYrreswlhyr6cLv!v>SG6;#SNjmP;8YmfDp0z?1m*?SM$_?iK!!;;NL42*o`m_JH6Krg_{PJIxqsL` z>-Q<>Z*kOS?(@&nDs1zQx7&YHoUUazcd^_XGy}j6%0m!stQQ0Wxab0J5qwb(P`>~a!Mp8gk5=uhWVh zLB#XlJV~sP(%Tc0r}qB5a)0jrD?TVzi?LtipX+bX2$Gz8Q4#p}pi=G^Cns%j?7b9L zgXoW<0~Zo8uFXxge{l0j<1!RVp5XrA~*T<W;< z@|zZ#4Ulp6_&hdoecnQ6uH29|ggfY4yY8F4YOF-IjkHwR&H0dc;ixOs?rX$VU>04H zt}X<$TV?f<_}y zHmIdgTxBbDyt{qVL*Al_m>Kg*rR_Ht{f8|-n=9OGd;NO*ZT{PYrQbjYUg=F%$6tg6 zuFqfWH3b8aiHi|gvsJ}1#2S<`^bIbgQc4+= z_q!R15F-#tv9#5i7YcGe3LlVb_ELUo6i??;{_b(-*#q*quto9}0b9+716g#wl78Nh z&}+(W45%jdlQ|_c;`@HW9b!~MfMhHk4M>;(?gP@FReiVm2y}F)#{HBz`u;Mf5CTy$ zAY}=BjbT4XGgus13TVEs`7E0}-MkXi^)IP-WInA93t|5^tZ6Z3+Ww0yUyus?z`1A0 zfO@qYij>X2w3-F#(?k#lvWy${eM&>QDo16(bRFb*SW?kdeR|ACRseVc`GBlr# zFq=IMFNo@fHu*~15G!8QZYx^~7bdl62{dSW9S;iQf8TxsRSl{n-NZb}@P}*f6KKyM ztlgmyT?&u(qBmL?G+dDj`1SuZm_xXUCM)q%JbX@Bs&ktqtt%{zj>?~!on&l6tmSq? z?Ry41q{A^ct)F=1c-Oywbb7MU#Cdkpe_XuHck_Nt@_#$;9g&dJxnmlIeXU@$p6dSe zl;-5;*IE>se6P2jazU|=Z_RnYiv<1zoH8=0WAwq4+fk;&A#~3-q1Y$Y-(#+KQGGTZ zKJ`(twH6o<=+qXt3h+zpX_re__|fGQtsr<_Qztny|4t01w*nY4I5|+N+tB-YtBpk3 zP5P02RP&!w~yM|6eSlL_9#6MqG?*wxq}x(5~t_SR^KmL zxxS!_XqR({hRXlTzl{@-?hyfm7(GrWktz>?cwpW6U5(~#+H<{3Ij{MF--;%*2m}Hs zO$CY^j)v)cavO;Rp)Ih%JyWF}L$0MzAfy9YBW zI^b9@1B!-u0&d)J5ZV|@Kd~W3gc*SI&hHZM9CVBGm@FWK?JUGuX@Yo1~Qn#+rw3)b+Bj(kG!IBMq*XD7rdNUyOiflGk}v z9e;8Cam*CYnM(dpo~S_UP3Eh9dPN^X9R*5-7@yDRp{RT?h8YS=0SOdv`0#n~;T-h# zXkGblrNqQU|HqFWZ;p*o&U$$te4O)8O-t-CU@>q9XZq?_#+)>K@)_xUq^ZT$v-^eU zF{|%|e{e-ifK1E`1tT{If%(lJmA~bCF(4($Ey=xg;oT<7a!>Qo!$&4hydNF3tghG3 z9h@!<9{u9~qH?}%A&*3e@NfxuEdsrwDXD@7t^ZqWEzpbRnLr37n$=!QC7y0iHbmN~ zQ4a^5y0pR!l7oSS5GS9=fOl3E-=c${N@p@C7LFE6%bY0+rJNs3Bm>jYNm``SOWY5>KKXh-%ZDrql@+sZ&2_M7 zQg}nR=Hg`*66{`JUE9l8{?uBvLYkiNDRut%YV#kbcALD|v)f5e(VLj55K3=&#zB+P zNqw-(QR~IfM&Eq>lPBlmQ*Tl-;aVrUH(OuEFQ?p__diHD*!a=wa#e{!`nipGObWt6 z=wNE(zL(Kei=Ur^8ZBvbo~cP-?l4?{&*oki$XTeUNuYfp1!8aXz0GcWH54HFAq6!l z**@NU=*JOcMLMQ~EI&xQoJEd(apYnhm&p_{iUrP zVIt@~XgBfDqX_NwJ>`SToe<&T#ami~rp>~&UmRk>OF1ij1NcYXr1km^JZ!47yNA4B z$-2;RH{TPxuz&9omRXY>MGvmdp8s;_|6UiekkPL?{Th>U_F*7&mnu7fN~gKMzAi5v z)hZ)!uYvmkLwFKkjEl|>QkariI_|i+qRmEB1gq)~mo&G>f&9f?=IFV;apLC-+UAR;quf-1HB8k7x01%h^P z(@Hlmfx?_532V^sw1r4;Tz?;VApHv+5UT}2@a@>FWk?CWhxwZW+?`v;FP?D8&|Ut1 zemvTo&y-G4dFex`p$E4w1c(VjBeZX#gqCh2iP-~`%G?CC$FV2`aY~6r|9-3pbB=J< zrEEEo?>_os%)eRtdSY>Bp9abU&jq0UA0>UyYN_BGE4^<(bj} zQB?+xgZdxS%Z6O@Vx@SdrVx|yE4+2P8+qjic^MjJ1P49A!EEixHHW2|&ejnK zh^`;2ZmTT| zW7cQ?9&e~u_JT{!2m&#q($zw9s{zgIeJ&ySwYS)tXUCx3VcClPyS=_l0HTNW7dI6C zSN2{V(PW8&4JO{)$kZur)l|06q#%>mZqQV{+V2t4g{Ts-|*2kgoZLI>PH9jm_v*Bm~u;zi3Ap}$5v zATWZjwP?K*?*NfPcljD0*dFKI$U3?5Xt61d!2IHJcX}0wMyfgWY~eN#F?VRDD`Ih{ zC?c1onmh@Md3c&|uJi=>kGucOasHJEbje@eYd1OGCKE_tWM?+H{~^eXEhb7(paj0>g5gk4)9)5j_hurRU6OYnt( zTPqtbl`XS5!6UJ=9?nsL7jPS#*5k;E_Y~=f0xkmV0fBtF z`v(nGey92vg8unV%MA_&@UHB3iKxl(3*E|Z8zIS;tt#H!8!B&OU)&Q6jJ|tBrsAlO zdp2_zQI`6@7HSWFyxs8G^)%aiVr1wn%;a&+*R{Wkw8Abc%L*{(R~4KmRxgEY+p5f< z-F$1Qz>e(WZ#@5toe@<5_IJl}-^4ho2275@!XS@F`Gx03+Tc)tQ5BiWnZH8YG z$F+0|eRh@QT0#|IGCU(+TBecns;yAc$_tKxXi9ZbA1rtLP=7esrNdpDcF7?B z)Hy()b9Q)fEl2kk+>3%-iJTD=-)(O9St)(>xU^6rNFD-N&R)OkaR1$;o(?IDx$trD zgvuOCpd-Zt(zOasdm~F|xGvN1JRQ$~JE26@tApMnZ#guCJRG<(G9r`I_?YJB0R}?= z`*f>5>Ix+Y2##HilH}uRXd&+o7vZi7!hDG&pu@(P;@^ofd0YFwo8942+1llVJ>5J! zEeRf_S@K5Hp^j%8KvU6#29V9Dv|o4g+}Y{ju$HnjfS3AimVJpJ=>Op^+gE-0c*{^p zXiQHRA8N35aQgk&!E&yV3mhT_Cz3Mgpjeo)WbG#ewm;)JpnZc9AXUI7t@5#ex&ee4 zenm~p>-GUr2Qdq$m>Q(?<37*rHf3%)RmwlsT7_-$tGKrYuK#~ry=7EXZQDOQ1v7}i z&?q&OG!hb$L#Lz&NJ@i%fJny>%1DDJh-O0s{Z-eLe4&_mj9>t~J2k z=Xw0{xJeTdBSL}zsPh^66xNeTXa=s!9K4T~#}J_9S*Jlph?h z=mZ%5&-*gbLD2>9RI+D*7}O2)b?42%$_oKVww|<*a;L%lr%09SL0Hc28`QUP?nZjH zQmP(QeY?b7p@%#>_ok(snwILxSI4Bk3bMdy)x|(pRLk?cYQSGrj$?v|Q>Y^W5lVws zu`VS#^*pKMET9!79ATGMS7{viLm7k2GUiHGE{=4aLv-sY+oVp^kh`{*cyY?mS>PUQ=QS|r<|fuixbv7 zdmOZwBnqD5!C(dJNMzj$ZwG|S!V<YA4Aa7)Ppg&iZIY>xPZ4Vn=HmxWT$!i4 z@g{W&cz#!&%{7J}6w z#u(6WL5z&;G8qKoOgvSQ{{-zk7PtmshP46heb+!4!A~v}3WP#PW7r^71-<<~E$VMA_oM5o|h2Ofxy{a4G)wWL#?D^ss>#DWP{}C|w%wwWWpHTCmAB zoRtyBwoog~3oGV_lx zU5jKYq9O(oTZ>){K5heDCHPPl5ldKSxm2mV{bbDDm57T)9FzyNz>kjjAtdn_V)4A0 zqQN8Ao4zUod|f!3KEQInm^S!84iNJGHDjxzLp7(=2C{8#!)uBc$HiD?hO`ht7`ssP zZHPe1X$vf!bKpt+J25=0KR3hv`13PP2^oD^7&!2X#na4I&l=}(!>h?(=oi-ruyn~S zG~wMxkbnfF(2QX@sEFxnf5=hY?K;}2Y$< z4)SyN{^|d#th8cyk#(>60hj|sg8&8bVc=$9%qX7FrT}~s>vfs%w=i@++m7E1fmdwR zAwmlIHtxJlcO7<%S%VcY*O`ONK#Z27(;Q%8_Q!Z?^6nndqzr9BSs2}F<=&aH=7nFm zA(!85{K8fVHQ$>hXaGMe&Mw(w&a)+&%)_UFo^{)c{ka)=s4Po(&sSi?GkDh-#ti4W-jUq6}#%?m_*#GOB5 zaG-&#g?MUdtSF9-fEdmPxzLvqOwYoBa(mPjj=b_c-^kaH78_}PTh~Y<%v1DvFy6?8 z7KE}x{8EZndSWS%t$>NXn(XZ!OF8Je?aqXk8i+8H7ncsxfhXSf91Do1BPNuwF~-CC9fi*8Qe0$0F_+)1zR>HA z*b)zug#;`ISFK;{U;=xv9nAlaZr*j*}k%X~k1CO>V2@AE9*1c_m1xV01al%?9|BI&| z^6`c5THYKg8;X{GhCnELR$PdYto%!XKcMSL7y@Gi_O4q1A~b{%#D#5?wenU;Sd4UVyfBo#Ou6_-<2E_X!*IbnN6n?I9sp?wBrYcOlh{rIg z)8U^<53J)@yy*zA4O2UJ9|GQzm*EcRa{rFL_Fe?7Oh-G^5&+~7e5Jf#FcR=2F}o2| zDxjT-4haaMMh~;5WDi33Z6bnz=t0PkO@#YW;HJPk^&Hq?n?4;=@s#o>2Rt7Yyt`=H zj~JmVxOyfG7hscsXLPR(y6j&e2q5Mzk{D!qB@mto<_X16&7|b*c3rwKLCiHLqv}j3 zYV^GDvC3AR_aAbFLkg!YYs`I~SESjWbBlokOQZDFOREkB3E@#Z?0#qFo%!T&MrMQm z!M{1+7|Na-au`1Zz`!_VRV4ivU|7!l!IE=L=wKPS)smaury3x;3WRu^c>10me;jjR z)Pb-9mcPVR0gf`6Yu_~IV>L29nB0@suxw@&3rx=I>+pMAE?N`{!G&{(iDm^c_XouW z5KX1fz->u$Z$rUrlB&G3WMSx%XnMg9cv!tz zA{)u7T`T;!tLn_eD>cc>wRW*Tk8Jk?di(dbh5gF*0%lrCv`7y)taZ}25Z%TK*8}mt;4L>zt&0YAJ1r7EKyUqMa#PF7uky=D?IYT`q z_EW-jB_XN$HZNHCZZ1Rfab1w|N(`{5A;!CETPnuSHWGueY&*^7z*k&Tciqz(4}>;+ zF^TWe+#WpoS+0b0?~iJ?1x_Fu?c9S#+85J&*6Y8=-o6avGU{Gv%c{UAFx745kn00> z5nW)Ty0;81>phu+aUxoxQKY48#2O#iVj!56&6d)Bzn{i2>g2&ZOAUvD5 zt1)wUVJsi^3_*hwqzd$SVn?gCNgD{BglO7-491Pd`Qh4*fIz0gBSJrExilDTSZqwMn)-ublLjkO`p{10-8zT2n2$+89Z332 z{)fi{l^H=~X_(JDaG}-UYzPo0Gu(RJ21Xc&C#HFV*p!Cf+U1iXQ{8*c$2Am0@QC%7 z$vXU`aLmQ@EhY-))jJ@G%y>7eVe4_!;qHGEY?RFEV|Tc{O$AFK`V7U`j}_PCXR2!=hA7K$;s zszbaj_y7DQpz<*u?cQ;n@lsI5_ToujWUef(9$SOZ_%Z1M>+jXEM-}+|^HKYg1eVgD zTY}VodEeuqOn;|yjKo7CU)l-31&d^Y=wpbXkni(6x!EZZh0%f`U)*ou&7E*ZIW&xL z31Yl^nO=dS{Cqr(LWG{yp{~j!Pu24`Mhg>0&b=IB=zmFQQ{})2W6VvyLk7kRPVk^g zNnf@oV#f~jC4%5egnwN@>kC`~@7!Oi8}P-9HdhIWJSwT}Z;zDB zayAmBp^PZnx#<0tAgSB_xkdvRaJ25z)<7Tnylp-a4ltdImk17UzGU4)I{~~(>|6wX zRx$#Mzs&7eX5-5bc$03v zu7YvfAlZyi`o-T(-MgWtUeWV)77I6LJ#;>cln*Of|84|-QVepy3aJPx!Dlg2DwuHb zKS+D$HsWKj0#(5XWzkZr?VXI6sADw*f=IJaaeL~0Xm`j+H$guonM^h{FqSDqV^gyi zP$V!IGj#QPIgOLj;Js^FQCh9k*P%vKOZyyL4fXk^#Q~yt*nBl13C|x;q(ZWT5}H@{ z$Md{Z2f7X88Wv*YUb9k3HPjByvr9XD2ki2_lwppWGD5pwo8e}KH4dYRm9Dil`XMY~ z31b)5=(``38Gph+OCj-Hep0a(4jvt;l`Kuvmq$?(BY|Hz958>=umN-!2P0j1%qr-c z&MSLvK&mDqzRGGB-}aBhN5Ar5Lg`~$q{}ucK#t|0I?T>i_kDX6I?{6MJ?M}91d7Ik zS>CdPoXD2p9pA2F3sDy5>tU)e`l%lx#Ut%8z$}j_zR$MlUCiL4JCul_diGnn(l~3Q z`%T?%ITB=O%f_JlUP{39yR*Hi#RNb^>iT!cM4nsCMFEG*T+1iNE3+z(1t+50HO1W_ zDbIppV$C6qR2&%JhcDXM`23yB+>~4OMM9-#>KeVe0GiGV#)M-KtoJ{f$?%5Ygk-oP z6hY``tI)omlXwV_5*bE51_^qPqCXPEFcFD^D9O48O`LyJH{Zw;%}|$#I1rVd%w^p4 zTb&JQ8122DVTWUw05G_OuPH4?!~-JZqg)CiC7LaUI*f$+fct z-R_w|U#qA~isP`}q&NJO#@ZY4uT%%$2bLMQkgU9PbJR!2;N%44IzTP%v1?Au<)Ea| zo{F~-cXDI$IW1)*v-ztdOX1nfoC-Zl_SOx_wLzaX*uBq;@3pr#I*1nx3=s-mGCWj7gL)1XeT~Ks{eY!~*?LX0gzR)3LpD+#*OuFPv`YU@m6YAK z0ovMG;{*RB+MCtx5`KaC(?X@@$B!|kw&9H?G2ZdJlqIrjG8tTqgl7iB>YeYY&_q_+ zKfg>9`8?v8iilCi2S|GworXcA5r=Xat;|pOic|8Vb*tI(Cs;m+98zk{4(F%jTaYGgicgadl>_k#fH2vDf)OO=))0f?<8%SHPqv)w} zO7x2Rg&0w|&!5F?dUDZMMBT+l)CMsN7Od+FSL@(uTkTC3JW9=5cr4Kyq3txhQ;0ow z)`s7mbyj8tN(}xGZT<%5wU^RATj!{MA{3=p^>kKXR@LCqx&_*;!3;@t`HtQvd{OK= z-aYgGKc)@Rzr37$ILB&vW?(qrZ|nZL9rZy^%Rg5lzIoE;Fw;IQAUfgatLmR~7O&q* z8)~BZtO7}mc#mK(L(NgVzOX$Y6ySfcish!E1-@Fm)ED7mr+}ew`r*FP@8Ru#!m_R5 zU}CQ}D8~8H#U7>72FEc2V-dV@aN1Ga1VHU4j~F%4K-2tev@C15`P_V(=b_&`U48C* zsO>YiXk^1lo(2*;M#(0X#u9=E+TiG5h$qOImkz}Y!u?ge#f+i~ze0F!Qpke=xMlmx z8%zwq+e~0V@QO_J0C$=%(w$)lD$QXq^#%O_Ne;b(#&cPAVTP$8Z1MQC?msNgG;*c&FD9 zA_;4(w*1?XW6n%8D&;~c<4s-l4_2g>yerZ%=&78H5#mi>H!%PgyiP#Tjoj0s%5mlv z^u)L)P>sh72NDa#wLzuDM~Ebbd~%+=Xl390YpWysh!YfVfF3*jTV-+iP=i?LH!1~X#q#%k*gG4e^? zV)4S)&%2a&Nc?0fQWC&WJ08yK_3AFm3K6EbmF2f#-7k-nuj3Ke>Gw}xeFebow1daj zWg{Rlw5(bWqAQHYP*K!*?WZ$P?{YuaD&Qb=AX~b_5Tkpb%O=bI)rQc(#q_pea|V=@ z!+5_ZM7o>fEF=5r{5* zd^QI9x_IRMmm9aZWFQfg4w|nTOo+LwE(u~B1ppU_0;E_yMT)$>Gw>%h+-5g=;kQ?E z>$5#34-?DV3!TtiK4P{nDx9B)3BY%P-7+&?K}A6n5pX2&OQ69kCMcN1ep2%z{%;6x zuM)f`l&&0`k0S79ps&=+MkNHXa{w1(ot_N3xHAOMUFjIiWt>~3Bf-z~#Xb*EB@4=@ zW7=?n{VZR3jy>LaXt@gWTj0t6gz);7mEqirV6eX2mr-!wn^ZR%^iYPymA>zn zLl-X;x|zI&`1gH`O(33<3h>A~So|PNud?sI${d<-b9L7wv*`8pYqB~>1 zn)Q@k9mot@y2?oBZpfoUkiH}m{xXO3NVn6lmYMgr7-ST0ztftrS9hdpzx68gl5nG` z$DJGotDsKQuq=F}T=HZW&Rl8iNpW|}&mN*BgEl)wx2Z^ZB5^YhLU6q4$ulKGiPX`V zUyVmL)Q&8u=7(V{C*pflq52@`NWtJIeMNfd1HQyE3-K%J&6>B(3L8vm4N_(2$-n&G zmfG+vDrR-l2>OLh{-)Edr_Oh4Z8R@Zm4Wx(=u_Moo$!)q;IVsn(eJb(1>&L4Le zaZdp#V0RxT%C2Kztd+PF6lDv4CmSF9_2tv`6ej+)rn?-*Qf9Siv{{3}RP+{3z0pAL zJfjABDM)(yo9>T(R#pcu>yF>>w{9pZ^0tm^i8rEvZ@dlA0{?(|oMl!5*@Rh8sQW;X zBri!DQ=1}&;IbNuDJL|Oe$VHhDZJO`cYL-Z?AKmOZCYs<9XUtuyT*U@r>e}l)-DKAfGK5^LMv8V-QANi8~z(5JPY(lX{3uLEBsU2mBgIsJeIyLdWFtkrAvLyV~T#`7v=)*B4N&stN~3h%dGU|X%YA#Xf=mO49FI? zl=%c~3laz21EaGEL6DxzC$xEXlzT=13lixu+bZe#v+p$fdid&W=GE;QLy93T7f759 zQw?GO-BF1l1%*)o2foCVA|8^YEm((02{E}n#HH7s#>i;q<9_#v?4uJqpET!@PvB0U z%*YQg3KCFsi9Y&-ZCr72B&<7-EfFE7OygD|do!xi_dqy8vX<9wd+;ka>LYzz4Xm9F z(4nx}%`cc#y6bbYsdh>R$YD9aZD960KCuM&J1l44t0OW{o;&r*sjFyViuW+bJ+ri$lx%xm(cv^T4RK0HVDe~17xh~PoCbeSgW?4pWN@y6Nr(P z)!LVidjv3;RwG%msz!ya7{1P@@9BAty`EnwkZ;XLOP_J&+ccfMVW*2nRC5-Xe-vO--bFk$e7|<}wt3c3>>yFGvq8kt z)6+snh>QS)BRk=5Zf+r~UYiP-{99zDqGdZ(V!wRGWt};jK+{zup>6AT!5qE(WH^Oh zaYzP%%s^;p{~gIfRD69~eB>|JLeJEbl%_JZPgWO6BoW%BRItXr{w)Od`k`07;^M;V zU-+lzkfZ;*6PNz(p_`xAkY?PnHXai#+J#D;zAq(mUF*OYmrkVY$UVEO&q~1)_hCgd zpZa6>X?C4P1>pBV$jk_6RW7LFPMOOT_dWU5)06q*`OG7Jz3{j}!O)Xq0o5Aiz9+@> z&h2P~WbX^eOi2t^qwZCG5C*lm`d6#|xsatS>uE1oMz>+E@pNy~hP!08FCkdCkx(C-WdtUAMr?t$=V&xbA9Z_c>5k^o_d*U=N# zgoQ3+o>r_ZUl{4vqtx{;LMB?uA?T3JX;`*VGP$_s2uC!AyDctnd4L!RqU>vkN1`E_ zGNgbujDccg#Do8xm!Fe2MBC&sv7q}WR{FE1wf#gxa}S%3+@ z)75-Z_u~IeEZ7z-v=k&d3l@Q7YC?I~1}dU}N(b(VrE^dlZ@$JTi#nP&@5N7U-cFdM z$a1Ek6-JIZ*a^PX_8vjP4Bf$H6)byM5kCCq3l`Y6WBd?eemfR++Iyez>glsB0ybKkYsA54?nr}+xP^Sm zpA33J7QjGXmAN4-_Y*iY>WQ57(-M~zux!-}|n-7!$6ph+z^O_6%u*H7$$HPB- zaNfsb%7(2^dpY?R)O{vu_ezSQ;pHlW-^@mzAF^-=7CBWA+Q4^zOKyh*@N9*uAc5gI2R=om-@S z`X-=lZZh~LC@Lzcoo@Y95N&dQQPbmL66&cI@&)9Ra^m4;7t@PHlzEgzA5zxl%8w{(FO z#%o0c5?>u^e4cw7%t?#x{qU(in9c992m@!i&7|?6NWjZ;DK|4K{~p+I>Kfxj7KnLc z(|i2m$AGYi<*7~QODo5QY~I_^jZv}cs#7aqW2Cnx&${{;wbV$A6lB*Pb17v*uK^kB zT3_U|&k#8`{y0|DPok%P0^?(K{z}uLK+Kvk9ZVo+SRKd>1h}=z5t`qqK*9HyNYG47 zC~q=&c59h^>siCYIbnbz9poXJ9=^V=Ija}*m^-|OhTNeWk*3`fN4|*iCqpZ?5VYOp zCBcO7-2pK}Gm%I-e8`dX_htM~tSE%SOpaGEUAe0WB_v1!{G3C*VSD>AuvF?p2FAP;jS5t8p`0!l%T9P}FrogF&R*wopk3wnJnJBtivuhJjv0#05rY`0wqV zISz|Dhc1iaLqBdgo|VasNnrgaNZ- zv1R4QP!_L_k6odoCHZ_3k)y-k{@lgHPMGq;VVFKrUt)mstq?_oz9cFIn6bU9^7V;n z*dDoY4|(!WPw9^Rse;&gdk;poU5BK6;UM zg>x>)-DpCXHi(jg3gXOvcvZ!tGB8fmOn;7dozV>v*C7+sAU@+Z!9BolrfZO}@5+JU zNoRPFdP{Z%#~@SM1Vy!+(;i;adeiT83hJpy1E#s(WOMhp4686yaV#N?FYXB3Jus?V zejo2Wp8JDmr)0Pk_?OWa-!o5y*_>Xz7&`UnHDpK*!;Pypp}N3IB&Jxt%854H6O9y`u*_0~?#nsC0&maVS$> zP3kX+B+l70^;l#I61hs!U4G-DQ9r}^L&|LYj$}O-lOz|aNZSh7b($5)!6}x8DbGj8 zsEaE7zyU883gL+Ij^Bkqc(qB{15agZtcWWfgTY%MJTL;BFtP!E=-6co`yG>*Xu2qT^^pkj#?>hqO zp3Xdc8k#^zP^Q}0`zHvfPtm&^jFI2k`c+b$5+QivsBz!%CtI&A&brEzV`he_>I_{%b34>py%22;Mo5VQ9}Djsh})xIpGJBcX~~o z$wV$mxtnD~F{Qqg8Nsss%xzY%f4;(SPCt7!dqj?uOj)=1B~w@O4i{em`)z6jRhCCl zNJ~NfZ$h@LTQAduk!$4L1gh?_z;A-m3Yt}AjP#F;G^c_!Ufx@U!f`hJYh63Wh{F}R zLorLppz7Jpi!9^h6qbF;o=ko(0u8mD_<|q7lG+QTIH`=UM%q9iG)P%a`sq^nV!uBQYZB7w=?v0Qb*nFuZ_7a;cix zHTYE;1_bOn1(E%2i*%m=L@%c%_t}L%O`&`FFX<`@f6!eO5=qTTy7x!Z$N%Eigzvuw=K$HNY9OfQT=-<7AU?z|zlou)#fW!c9=|T~D*3Q&Mscu+MLofc~*-M}!BVl^u|I!AY z+wPfKUuI?)%6?O451>k4Ls@i; zzd$bFtJqvEHe!)28k34uH_YJS7DerJHA@akLptT*c5(evN%XI_%o>^w?kPa^Z%H;jQ1lba4u;CN{@Yc0JgAx!W!YM$c;9Q00GhCD&)?sg<7F*)u zXy-kUSfJuB11U$*vKsIOVSf8AeM9cYjlRV!$-z+ZYdCOy=OZv15Q+kU)p*rkaG~$) zK#KvyngORb@Nrd~S#bL$dnRD|I%ysM3_qi8Qu5DQxT8Hq`M)QPUWpC^6Tu4tGcrnb zd16S`vxlkPUe%V40!JvJ%2OX21Tj1Y&kGVHXn^c{=FA~MNI?wn1(gY@QblkQt?!{uwsv4ndzuPgS!>?!>f=iUYOQ{+F3VqXuJ7ftrz_~L~?laOEOU7g-DDUVYYZr2#SVEvD;Zflx=-0NAt6V z+Tasl&=)xf4qC_`rqiU7eI=m&UP#b~>N3rL(~I41>Ku4Eu~!qJYfQH^;s{4oWx&Pl{9_A#qje(uxMReVK2V)N?>#ikQr zwvsgLG7d^4O|RP_6%3$&X^*GkgdODvjYKopAkBT_6^T{3cIGdDv)dx}d@S+ud!O@R z&t=V^d?&Wf@E~p~fwY*O@F7R%`~x7=hkSTbv)E_+;nRO^ThAkprHp_xZo1p7IUq4& z^k~fH>RlE6#l;(m(95ZMeO|rkEX;_{T4L5mwJh&DES z9gf)wWKNos1mBAUWSaAD!dja7h7>tm%Oc!nPmTdfw&~fv$G6O{6GNV(K&I!}`O5{uaqa%Ww-Y32fd=GP2M+Sily&>o=*=%umkbiVV zo#M|>A3h2=@rj+H)-f5tolYOH4II*Kg{l|OJlP2sSs{;J4z@%20NuqyHzhQC$wpPSVFCIa3pvu2_Nkdz1;*b=1nh{5U;F@A&=Bo$CrnNOgZa%pe|u*&_j zK5DvJdjHerv+Ul-R|6HwK0ro`DBz65c-Rhb1%Q;Rp&=c3c$rXM5*E;%%;3kD;Afi~ z?Ml>YaRlE>=tV9xslS0Umt9aWG;l+fv9U5D!0b#9V2{5YQgX9zGI!wk5!rWk+h4@2 zpqOl_dC`nEdNV`wuVjHmLJ{ACg`b3 zr?pdSr_^WOB_Zt}ND zKbL-C1Z*#8`w=Jmj?S$vpG{vd1pnRC*v6*IicDr>Fb2`A@mw8&ZE3GW_JyT8-;e>S zg9H=iB;@3ee`lcPM{S(&XSfkoXWC+(2XXK)r)T%SdVet&$O^WxeD4_`)cW)#+aJpx z`#LJSG5<^Ld4Kr_DoXYzgbb2^JNAf?XQ>+i6@*{S(UwRBZ)d%-Q*r1TJC0|bw##-1 z3j(2Tkf_q^?j{0NA7^BU;(X{S3<*L%BbM;F+04yIg*h8MrNM$2m*2$#ss1{pRC(;?=y3Vx2BFAhmy6SkwkG)j8dmPr<^J*rlOf@|iQ13i; z)GLGJP$*nno0GyBX*4YU0ch>xlNrn9ko4SQ4HF4Y$O=#F4N>^tVPH}}YK)yVoB<|v z2dkv7qo063bCO&FB~@6#qh)T!VsM|m?f2459Oua-K-r~oq4EdfCq4`AC26^4K!*(s zW(f@}kXj}|a0OPk$AQV^uL{D~U?clp$o6HQXG^@`?`Q4Bk)G}9amw%o(JxSRArRvJ~Jc-4h>=NWP|$s!W+ zeDH#%fE_x!llE5LxSQ@38f}|G-N$Ys#g&Ag?Dfu5Mcu0Rvs9ap;X}Yh#>x07zyFJS zEl7)-euCe_juGVw7-s+7lj4P9&HzC99KSDH=|9bqCN>fd7&n9=kp9qgh&f@#V!WYFqfao=!w(F0S2&?pk1h2ErxAc<+b)p<*djd8S4F~jX;Za*sn(d4_@K| zRyZa6&6i?aD1bmL$4!f>VdA08ozHi##h>plp#_{qsjL$Uzagcy!c3@-;j%b91Mx%` z{y!6_bvC;{ev<)`W@0c7V~)6a@P>w-@$Nh zB1n)djElM=tOMtl8>}Oy(&TO*Aufrz3?BmtIIBd%)#RU6wDIO@`0sZhQ?TxnSgPL& z1@7Of@2(Y|?JdOA$a^mNu#QwUffVg70hP{r=5u`iPkn8}#()<{&Pt0ILIKASCsRh} zr~L7MhNGlZdyfL?6vfW9+>)&_j-zCth_)TFKv!eJ1wkZ8Tyu_82N3pma3UWDkL6Mt z&D-4oqK_B{oIrO^%4XkC4$g|TUWdCMPKWuev`CzBX0&we`?(Di34#9eTAWgzFA1mCdZ&lF!oNbC(P0p;)LgWJrC})Xsa!<`uM25q13WyH!PU!Ws2JzS7*a~$>(5wCKH76~d$jpNNvElwf zhR9PddgI6;lQ+gsZMXIkWkv+0(Z!`F>}hAY0S_BfJDk)q;B$QY=xa2$ z?vA|aaff@~`p2m^G(t)evNA~h*?CUE1v&Y~-^Lnn+R~Kzh5X1HLSv1LZ=x!j7b81- zUCr59eo;#mONlRk+@1afRDC`I4&4&dhuc;<#h{iU=4m| zMxQZi&$yeTO1nR~paWTxgL9XbUTBtCn`u8C#S-Ex$Ui3TKy8qg|L$wbnQ#1=fuzVb#ng}vcD&a}VzOBbNt+TV?lYKaVB;zR zhj+O_?auSDoWkt3cin-RWh+av78nx59d#9Pg%BszA|BY(T`p#X%==ch9Tby7LM1+f zg4}{z{E{w1v#4Nubb%dcjHImKl~_VhV7JI-_U>MGzo~`t6NcNQ2x(` z@jGO{4N8BBx&Ec`*LKP$QBHDq^WY6WLh1I+SlNM7r_pvjENz5@i^OUf=pATl$V{k{ zIosD!L&PE(-pr@e5$Ob1x-a#c04)f=2PH~8pL{8eA90o_{gM;S)6Yd&|0;Ujr^Y>W zI}Q;n1eO+cJP`=G>52^cNVw_Zk@l3{FRCL8@{7*s`EoELH!B|sy92`uG+_l<8J4{k zyrlM18(qD>h4hwKXP_A2)Y6YmM|=zPZ=ns5HN$+38v920;^8l2oEdcFLX@QAWMz^-~Gka%jnHZIdkdN|Jf|LN*F*i7-BTywO8i11`tz)g6zu>*pzjphi|9< zpT(b6W%VZ?HzsdVLeHn?q+DyB3__WK|3HeteNzqn$g4qJ4o1G9wK##sXjVPIsM3Ap zO$1DWOBN`4?7u}}(QGg|SdlkB0=vupNmTi9R)O)PX=ObQY-c7iO=bDOE^T^Jap4>#n&;EM&pyTJdw@!ojg=JX!RuYF-ug_xOeL z|3#%X%KwQ&XLmpBK5h(6ySxAQ<2U&}=PcC7P3O#~4#sTe9XOynB#j9IONNbF6PHGG z`utA$sgIMI|AdqH)L(I_7&f&;h*3>BJ=4I1D z0WA!=H*#?y)F|%-d^;n1n;x_J5@@xM{a3NX{C+@}$v6#9`MmyPCpmOol-^xd7|Dto zOfDnMc?V8N;M}^YfZjEIhS07wEq+9nV%PqX-?VrkCu!`G9sqMKds`W-&VUR55azV5 z(1xAi!20uFHygkaxq|MPIe{Sbj16{-7&n2WI7^5_{jxeucP#ayGUq7_(%_ zB5*#@d_r9#_Rr(>Sq^W70iz03UEuG5IK6lwfM+Yy#ZK7Kq0YL@yPf!|eM8(wv zt6xe)emWZ`Ic1>#Ng|VY7fH%eZ!of0gcH(@MN)k83;sg{&FOdg*1Pcz2j*y5?rdby zf2sPII^OWz&@dJ*-I{xYMTGjp!)GpyX@M=5`;<(s$sdhg+zPlrR?Y|^GCFlHfT_*A zS{#{YDD7s}WvI?I_bY6mIT?*^0>TafRV;Qq|6PSPO;C@f9S=Z#n4f)uLZvst=Us|U zuoyO)o|XfPue_5l^K9(=&Qb*4m1w}^hKEjHZd|Fjn32GilkCzQ* z1$O(!+%Nr*xABI6C#kewP@$NWi!KP`e8iyBJTsW)hAs%D3w7?XZSIa7s)Yu8Q!`tx$D`Hbf*2*c~D#3FxOv79Z<61!NOl^wm$VUpFrJ7Cir)8Af02=6o$7I4G}V36 zXO_+W>aC`eR!wd{PG)fH!(jXvR zLx+^0grqdmNaqmJNJ>bTC?G8$NQWRGDJV!I4qZdn|M8sjetN&TmafIZbwBsM_ukj_ zyQ;(B(T9|4^3-?cV({Sb+?*x4*r$L)kc<$>cR0qW1lfuS94M}l04Vz;;wg7s2yz+E z9Yls!Zvg&IgO=2Se5q{Ng{IGPB=|VnwbnJ@3%ye!zxHZtO z2T4Nn(4TI`MRqEk-vDQ;B@F*DPc1vDM2U)74L|FF*wc-J>+biS3;5BP#P_l#dT%Tp zgH1Z^8g!=3zb)7TFRZiojbF-BGx`HhvT7S|?E5xJ4`g1l2#xpmu-q3*F#2Qq31;DW z>6wji7;P}dVPP~oZkKI3s37@-r}rNwQ3H-$?Z4sK`d-IO`_*D_A4$mOs_C9vf}pvH zVE<`A6h`vP-xqzpv~oTTu!^3Q(hHVRwNcga3zsqeagvO9EA_ax_t>Az(A-`@_vw)B zqF*JhUSenkdU+l=Ddo-wTLtUl~WPllQ;8Q$&j4KJbRAg$^pSv^57vJm>zJ{RKieH|lO5_U6 zj;mA`2yqNQE?!)n%^aP<5XnTyOJ2OlV<{ZXvIrTUqnyRqtSk8F;VmQlWd@=;K^RXs z{3Jzu{!!D0G73v5@q7nO@^>Fp5sgG%q2!uED7ws+@a+F zz+?<5TN4M_ImI))iU4uX-I3smrx#AknLPKk;Cz^-BAh!A0R7Pk=?;y5AP3V?K&X>Z zJ_mO_(|Q8pBd(%aMCr$C*zd3E(Ezt9h#W-J3q9h{LaQJs!5W!_sBXhboSdnh#gLx) zlZ(GHwwI5DFFV8Ww%hVQ54U@STt4K)VFM$9l^1Ht6HN|+CUr|(q#@}cy^0|(Az=+U zy}_MC6f59IcqoF1@NBC*9I)KCSfDFXKIPnbH|qz1)Vyo@3+6oD_a7RW9{i@(vcNPY z*-6fFK?~mFw1(nZ{nbFZQM8r!UlW-5rJPQsE20`sSw8)1B7jAL#qx3#fQJc4mOsK|KXfKDFw0{ypU9Zm2$nbOMHn#l$JeR)(L^=p`-D?YrMJ6i2%N zU<%|wgBBwGYw%mdCQrUtaLJMSU%&+B3&s=tDhTP;R*GPLu*G@_Lj!}-PJBm!7Cy|4g&`T{jWBf3X7-J`}4GQ&-$~Ya(rI9%-r=Cx}CRsm&}r9->OIE1%&^n*tV6d z*&o9>ImmARktOFA7D2F}B=|mf+g*fKeij zkDn|B$4SKbW(eWYN#bNg2n4es$G0un7MJrRGQ=+(Q;D3gVOg0CC{&*y3BP3EL;)0} z=~It>KQIfPQDZaE83uShh1D;xe)JMc!DRkln2(9%x`IrPCFDR|Xc^LZO<;1epsX@B zy5-iNnb=K#5t^+_TgId`^71}X-VTo}bm9BfA(GCJ#y?l$9+E&pcpsa+* zh)azC!t5Z#mK;;dC~y_P5$qDmur@fr*0CsaQc<05tna| z7!w`bo;ozym8-{?4pt&ci;o15-LT?GNRikBi&4L!(VtxgA-M^FVDs7(=+2#jfG(u= zxNVl2n*vp<*;4^ zWX9uNoJFp@23kdKTq{HQ%JbLqQt(tK4HQTp#EVW1Y^A7N3ng*}5p=SmV&6JivDj8+ ze)?g`@Q%VaxNHMA_%w)4t$+&+NAfQ7RyH?N-j$OHUs4kM$wo zxlly}>6G7L=m~&!#47%l^u)bB_E*B;l?&1Rx&3~8rf}%P_$}p$#OyYmGgb@=edTw- z;S^oqREd=5`ptE(ilT+yG888w1?!DK&?5_~#}C;0+`~L+b9BrRg{%xRbphj%{eQ6y zB_p^Anyo`6GTW38OT|lxXNI--xl096aT8yUi>WH+qhKGr9rQP<-bB_PG@CH|aLLm# zdJ-8=LK8o+AKkXu#?%Pt0AHJD+^9*q-T&902>tr#Tq6Qpv`cJK35?53!PpTN<{Qf$ z{%7})>Js=J8j3ir9mKafQmKu9WcpG5wZwsZ|T-nlrKZ!o{XG6(F9WXz6#5<{TVKwMzm zLtzjrnh?dhh*$G~8WkyGMG1#rm}HX?9^w`t2_#T#HN+ljh_cuuf!*k+H@R zU>;z2{R(G>Sj+ab=Y=IIofOz6_O(By`z8DrVAn{0nHK?fNqEFj{EihZKN6KLoh~QN zyhn(e+!zDYBToLK(4-A;Oy zHGh%qy4OOz6;<0_WI`9^^uFgpF1D9|f`y9j&t*^>cHx1O^eQI0F)6hCR4};H7P5g7 zDPR2-D5ql|@sl)h3jxX&j&OAe>fyTVBMIWRoAH-|e>O(oggMdJFp%u?=HV}5C55Xr z?o5mxKOG;|v&y>$JRk*8M_t9t)F#?nOCV`RMI2lg;b;O9iaOt@s?~V{`OU0PhO*7D z-q{&njGbl`!~KFGBSaKo!?-9W)sX2T=OA;-so1fZA6DoBq)9xxWXFnoVHWNG9{V6j zx0%_nVwc3@h$GBPK-y(*zjeJO0=#Mj!9fXH-+9P3;(U`}fHG;R_>_zH^`8w0M0^-!W%F23AbU+I{}VW*SP6G|ymztOV{{O&+~$B}ZT$hL+_$ zvSDn|BWnS~kOeDTx-IpZ$wc2G-1z``5G!zA>vE*QU>$IYjuja=DMkkUO2kSPeFsGt zmNGK?N`AimKs@nR$VLC$wxtnX;Qm6~c95vKW&UTI>gwSed^dUi0o+bJawK$#eCd@i zh)kWadF-{jAQYvEYUxHP*Lx!%(z zVD2=iGbP2@uDYIn{6?<*!AJKK>=PhQpb~lC;6-C`cYu8iO3Rd;?~|mEo-PxNDc0Ap z?d5q|Lk>`Vc<^}U#Z4C;^4je!C7E7%Ro7MrU%EEQNIm@@)oqsKYI5|-{ABF*s>Xz% zPH#1#YPO;wJBbuQ78v)pw4fndI^a0>^<7L75V?L-C$k1Z4&b=}f)SKVL0tWlY?wm> zqtgDbQrC^`F$*C`STiZ%*M2H%-B?BQasaS{l31z(?r4Hn!Y zffc?Gum0+wH#Q=`2-r;{fuL>glU2W$hjJN`F+IkFZQtVv+8B2*#nlQb#T`cyIU?V( zggKlEc1G}{W%_(L+}+TOvX3B^&2E*Y8MO|N?`ri~vf$wIB?CW|-n{CUv))jZad@2? zkj!kAQory|@_0O*n@L9|2jFNN+&Si2UxM^v&yC zRx&BK?n@gZrKPaY?mmd+Cjq|viIlqM9G_-92lZ_QnI2n8r%|aRd5~AEFo-{{h7k+xmtI-SOQuoK@9%vp`J5^;pGj zQ}r()=cXL`Pp3n$Rc#rs&+7oO2K2@r?-ycq94`sD4HE?z6_r4=S~^AZ_mN@@2a^<< z6wxWPh_lx%#;}O+2U-9#PEOZX1O6h*)&Jp6It=R0Qj@n70ULr3i}$>>^xUgdnYNYn zZ-|1Huoedbkg*QxqURdPjX;c0yQLJzJWcHnGSwLH&K`-)0PUgAa!+Hu$)^J;aN%;~ zgi8upv7WxGd0oGTcgH7jB?ESZ&1%hpX_Ob#Vq>3WudUVO>XW%E|81M>ZaZ@1e-`C% zhO^z;ijs4tFZVSz~RrjtKgJ;*6}x3!WL!z*`%XYl#P;VvLO z1du4%wPUM=Lssz;(ePW*Y^O|FGz%1&Lb`8gQS|>c?r|60BqDqL$ zbK)JK_3O3zwjzVk^E4)*s6lsnxT7K>w%qUN0jTI>g~ggN0|&MIg+>6=isCO0QR|Aa zVetS~@f)HY9z}l)R(-(StTge{`X!q@|CVNUwGMxeQh;MQ3b3Mrya5HSxo`4ent8Pz z-I@ttwB&zu&HFaYkduSDi^DN5Pp>11&AQuAYQdGNbY3xC>){_OE)27f7M>7<<7qOn zAA}`P4a^IoXR&`7>+u4+jYNVTpi?=pN1d;{v0y&14q!T+tl}y}_7BJgTuF zfM_Ic-*GkW_;CT`wiX)OLZT674NoHPgcWoPKc>ZNhDH~TzQr8zf%xv&^VrVe_Wu0Pg>^4& z#kX7qFI7NFdD9%I`zJ$PBl?SPf%>qbqdQ;f2qW1!D$8uS&4q_z#>CDbWaf0X%OEJ1 zc>DLKgQACT|BK_ACV8Uueby@fjIQ-k`=)i|_r;s4OF{9mEo+$z&s!@ek3Rlsb)(jqD6VM8)PmOfPd7Gb56@{38AS8EbMN z%klk|EN=w~WpBse0Em(S^(Swc9xyaTk@Q}iM^=P23h$*M8d%t-_WGl03|W<-QP3o+ zBofRE^VECvCQ*|!ITVOs;l-Um^Wthz79_0|Ns!{mv)RXEGvn)KQWRDV#bYWM*ZY{A zz-80I`VX965@`udN2FFMqj9m`!;kCpznHg$R9a#?loQjbgol|S{$eH%#=3gL3X-_8 zX|)=dkz+p?E7Z2|sG1zUw%;)w>H;Opy7d^Ih-u<@wE5P-qm$$KYvVdj=(fjBzx;~i zmw4a7sWuxbH23eS{1u#g`U9w#d(}%0#~|~o(tIDGiNlyx$B-Gogqi2whiF#|A|S3Sd%Pr_Nz#IpYA}ai5;+LBlN|L*0Og1U^EMcz zlHvF*jx&ZLRB@y6I|(BoOA&f0G9J-ef*_4NG*XpK{-Lwa@|biOpsGF6p76hT%kW6# zVU`gqCCv4YaAny5>NmQQUOlXt8Cnc5bKo0dp`HI3RVR^NX8dctsn^2X_dKZaL44mxZ%UfdwHC5v4N(NE}y*T0M!A+A!xvRVL#_n?bJVT z4q*9HmUKJJYu%<0ZTqhOt}oq{WoC!IF1(mw4QGogt?us&BXJI#ehKE?3F|@s>Xj(R zM`y`g(#9S;Jd+dud^6k%*Iwf-&aaz@-Ps2%J!lm)f2) zWQd`90zD2wB|y4^qo)3s#h4i><(=va8ntC>+lhF!i-RA!;9Y@tOxG|2hqrGj$SSa` z9g75EACQtcVkTdUVq8RiA`SzzZThx5K$Mw zHd0nL-5EL&HsNGuxHw6`Rp#~ z3-eMdNnYeY;6Ziptv2=)@T|Id|LbJx=bDbZpRypo&rQqn%4Pz2uI;DV+mAXf@@H<^ z0W?ByB#82uLD%enEVqTD<+Qdilo!~PFltDNI~a+HI8a^te)SJcO&f>^y~D1LFoT*Q z=^m^SJ?Oz70`N{Ke3%ZTpb(@8DS`Te2OdzroI+C*tX?E9>NPz&R=ApXy`X9J_xN@5 zXF%)H_15#veBfU#V=g=~B67&TYZWD*T?^u^OcX|^Ni8F`Ea@R($G^k9wt`eXCm4&v zSAjC8^$25N4FJDd`tK#g%~;CA5ox7SPXUK(UU#^Sws zzG)eBus-2_qS}#ObthdivVGwE!gcU~dHl;n>-ia}HZ?fZrGH+)66X2}jZ&cqoH%Zr z#o|j1N+ap7_CZf8gY~$?>d{Ma^wPsnRv~{$I4<_#D|b%_o2SU?_`0M_bx$$F;x(3Y z+~9~(F6ygoaoo`>_jTYcmodIqSg}!D%$)o_t62HG=qob|W{a;J32gU}E)=~uob}^# zE+04;7@@uBW>!yeJRAffyw#AKr$7d(c-tADM-Zkw9Yug>BLmkya?f(>Ku2^QB4|{@ zP&2#cC)k-0C0JnlqqSSzp0|?(`#iZXD!#Bjf<|>jHdPe2zGq4OozQVVXCK z+xK>ZMw#YhZ@=c7Tvi-GBmQzOvV}|q5o8KiMbG11)K6JJ)PiOb?|$e?hea>~iBC>C zzimoap8&aiQNhMj4#QzQjEB^S3fY?Xa7zW> ze!>x}a}9)DLuJ?_x;+9qJDY5oO&hkd`T z!Nm^FsfqR8`I?Q0G${}$jFh8>CJh1+*)aPU_f1-y939$XkW?{ET#U2XndbLF>X^`*3MaiUOL<3lOm|{Oxk7n9^R$iItrUbCptXq$x*`T^p z3+bM3iy|r){`SHMqeIhwD5zb_eRd#fTEG01aZmmBjZY89w};PF{RY?pNR47|UHTJO zV=-aOS{=`wcaHiYihFfEC57%C@aXTfWGysC0x}^^?z|p8zhyV`7E*RDH#25`xL%mG zVyLQLFFz^!FUn?~g#P^UA1m_;&l_7G!ruhJ4t~%jmkT*RL+>PeEg^d4v`q5VGpPF? z=H0t%sQcxKJ0s;SGryrUcD|E;;_7`ck@})MDq>iznok7pOyo85#_BFLLO8d3M#`AV zpKzp6=Gb-t*kH#K!u%ynxLBmUoAoQJtR2TDo6W6vsWDm`FW~TlAMwnN&BRqp)>+9J7Scn@p_DOcdB`=Jdjj^GZ>88IQi7WG-yFW1Co9Q0$j~E_2D? zFJ9gAYlgOrm>J;FpS%W08e_k9J(zQUR}9T}7utYBw_uK9NHm&u_8t*#bz!hF3V}7s z9%1)FAnti&^fGuD+Ha#P$paclp&<)w5Xgu5wtrNK)@GZ*kH{RfFEX1wo!my`@Apg| zWZ>3>ETH57N^{hRBvZzHpqbsin7kAMfYt{$X04u;i{8ZFn5Z112FWIsP_vO5^@(E?#DHfHXbf9#%_@&N9_&nOr7PE7=m z5aJPA#{vp}q@0I#L*BE9po=U!fY_ap8s0?^qch*%!Fdn5d#9LT2^hX)_L_jGHi??! zNe#B&8A#c?Heb~2Jie?~Vc-5M>)XbV*HtOlc^?HuD1a-+%4}t30s7@NZ80voZl>XA z={}0+$e;bAvHL|AXKd?pd+{KMD&&dS<4{#hMjs9ecL~xQlZi!{qj}dNTcS&f+W*A%Xy|&F)uX;`c&3uklGhi4if(xnHsY)BohVe=4 zozX3qUAkAJg%1`re{TOY#(JlH$@|_kK7+CA=fM7wRL3~?H>#pgVs40nq3Z{78HE+h z$bm??2M(mFoP>R$_|ak69RHIJh<{)DME2QY@B}pc_~-O6=uRq z|1<}lTL&_E@WVtmi5m_N)AOHlT8lrE1sJS0%s}0md&W1An>B25BToZbZWmimT2ITLDri4tkgq-+G}y|Xo^5UaalR08 z-Qk02%)dFZo4coiON9q12m{H31kynK8;qj+%2+BAE z>Gcsx&ow#;Sz+it?I+d;=#BLc@MuP*DcmJcpgKehz#xzAa1d0v&`k>A+bI_ess+(I zdagqwIN0Y{EAYDu$a3kPm)1d{*Ke+0>igxK{!Ff5K@|1 z{9}c2iY09nh>SX;GR0F*u6=7}ilmfb8ioveAuarpp00M?$>XBkR=4c0!3d?_IpTa_ z&!=yfG+W-K#k^>jlNn3Bk1R+7(2&v|$vRFk3O(N}ndgf?X8AtKD^+iyGS$er&-w@z zk=H*P8c-8mAS0UBP#fC`>k-m-nCpLA)ZYazpJ9sOd+z$Kko2${8H6A*sjahO1jb8# z7v(Pb8Pk8TuVnt3ADkzaAbx{ykN&|X;y$qvSz zvmiHhe-g$vOYtKwJ(8g^;@gZrWH_=JHK{JxJg!3EC(D?Uq3>vX#<-0S;6|iK^B!N`omJlXEqV zSTv|JIi+3!%x+X6^a+jjL@v+S{$c&Ah`Zb*VG0SG(K9fK+U%*K8snLrKVTSkhB+aS3fuv_c!F9d6sZCq(g?z0!Xveg^Fh3i zrg2GU1N-gkO+mC6Y_f z&2Sz9nPBcDw0I?Xtou(YDxz>=C8gf<$8q*0ptepulazP;hOUk z10I%dsO_Jw{XK#%k>PAm&-NWZ5>C^bZfe0GsCGV!c=Ej_>%R z?H1#Z?-%@SJkn`1N5EaR1ThO8Z96z_yv>)X-6U`__%=dE$5~lVG3QPFvF6ov8{4jJ zt-|%U;Kw@D@ky`a)TbLC^WBN9r3~AGW$`!n>%)mPPCk@xAutxNn4dUBiC>?q&1$}1 zF@m93uWcK|q3;LoLN+sV4GZDI+G+dpQ6qWVZ-yc~{hvxUJz-0af!8wC*rI;t%HN3^ zYsJ``2w{4>$g?uGU{F{TjFA+=kg)lqI0vozSwmi(wSy=?ro|{c4huX43tK1s-^%nC zKj?qy%kk>QQ}V*<_~92td2^x8R?F+jmYjAD2ZoVn6HG04CqggB<~a1fS++57Hq-{A zPAvhui4s2^__5N5ZtMTY0+=mm@pLciA0IbV#d6)?xTNOhr3=-fk!P%~EeiUG-VjvF zhcX_dNOTgd4n16a=4a$NY>RtH#Ooc)yoBo~3JZozz5K^uaf`=vNCcf)s_S&`+H)H5$s6fHl!HN#w^5 z-C_i)Kxwq35tOLN+Vc5B)5JMm(Y+=+lKh5IdWKrr<}HHDtLtoGYP0|YqJbXWSXZ4LLwv@DD_lnAVf!4tt)p z9P$^s&|zKt36Cx^Qih)|?EenKE5>Hul~o=3>{g9%zK^j!Mx!Td>wd8NiE@&uJ)P(E z{8B?{Z>twmy8A+ZE>6j3LJ($w=MCe$_9C!^Dp%(uQY>IT%0sZFp1((|*{KopVw8ru-yWFK_F~+|7AW+Y!ag zY_9nN=gS|^vrhxK>ysJ0pgR}M5U%`OWG zP|<{-A^ZfDR$DsC=(?ZwU$1*0iU=}hI?Ako@Ve(ocX=X^X@BO?tj9KL@Z6tIPFV{$ z*Hi1iG}nc3v~SbAmN) zH@Pa5jsNWHg(IcQzxlU!j4;ACMc*r-m-dhOXl96f=BCYV5F^n%bhF-GMupk3&-7S= zHUTn$V#WnrE?>Cgr9pKv_+G?PDm0It`%_+sZp%`0C4;;huq}yV(7OWO@2oOV5)VWA z`_4QJwn{pz)1=!|I=u=lT;`8s(Vg=w1935OJ0^jYNqlcW7_!?h-L>>g8*Y= z!Hd~Je8BtpkW1fK>|!=tS@if}3w7?0Z5utf>qbj$YrQ!t%7cmR<7if#TY`MHAMd>U zs3krG5|lCh%i)^-8uQ#!a0baLuU#Bp{QZ*>s=#Ws6Zr^*x)D8YaEYPJ=?!sFo@ECE=&9V9~$V*y*$drN|y7H7u1dZv*6;LGE^HuH^Pf|Kn0_UW7L-A6lX ztQI)n66jJ1BN`W^#(+}Db-eo)U`cuy2x=g?`+rVJHu>pM3%}CFzK>%#@w5 zR0@-a$ot4UEW~ABLlEOje0L>!!I=H0wPC;;tm`dG!jlTgKrRo~EmHaGs<@(zZn=)V z9GkbF%wR(z8QuDRO7&biM5{(gREZ{P81Jy|ePGLs=VM0tYO-*@xP%q;`TI z-?SAAaliW$k+m`3IA0YqnZ2`lyRY{#`{-FaSkj1davo%a3Rh`&s$lu z41(z}(SaSyf}gua#p%O_K$u5|D(A;HkIToX{N)2YL!!LQ=2NjT@o@3!;WKH7&$c9h!V#Xc-a@yV7$d{7tMGI+3E_n0CPN4n&jyTtoJ|p@L}R{| zhsHmh5{IKm^(LT-y|0ii5qNPV*BZuiz6?*d!WfnL^?sfzh)!x3j#$rQX5M%TUk?}R zkk*BA2dQG&fUDGxrqJAI>wVN~<-fXyLRj~P&w_Ycc$~(2qnpmjIy-d{&&wGZy4WYF z0qZ>GnWs=yD&}AS4Xi~6V!V$PP*`W(7Dt=17hCT!lhnp2i*sbT+M?SoAh`8QCOu8i znmhXfku4jz%|VYz4*$t9_;=v^@JIJ>fD?FCb^7Yi!3Yv&zpy6I;agy@AjtcPGFk%2 zWMyQ`N@DBH*`zAnG+EkY7h(fvkYtPT7{nJt)F7Oh4k^2FhKQW*C--*4^J`dC+uv0buEfxNGu_1&@n)P0TO*oI+ltn=&+wM6_-mDkQJtLvrf9;nEm9o zvrTemw8JANXop$D5X=iUoA5#rBH2*Hs?emA9YUQ{m{^Fr{eo478H>OszL-$xX zT>Hg&&kERWRe%{w1|(LXbG7=yP%9;{IC!AuwI@Gd^i}^J{Hu9@7zdvaxFl6RJm6ZI zl;GT^Lg@Bi(591esXa4Z?L}kw_p8Hv(o~yt-N};#xd~6T2W8)!PNvx4NB6tstTfoN ze}aXUw0&w>MP>eOQA0IIezXttcuTOMLZ1;aGjg)h))?G{v_Sl2O&l$+?TTVwzxB4QL?pl3&`lw4P4*XZ23Rls-Qe7zLePo;6Jr%i~lR>y1C)!B>c@$&Y&-DR?tG>z;>-6yNau$Q^e=cB$_b$DVE2uar(g^$^1N77Zk~ZRCPB6*pRH=_5zX}V#^|h_grtSPgZQFcw zfw#GNGVL8mk?~Og0c*U~H3c)rv_2UHvr=dt^XKO%#S&8ixT2Bt{R&*zJ0fNa{RDu? zr_9r@8hKluovg*rxWsBVdX~eG%UkUPD45 zP#0tSHBRHH>z?6(=q}ktjan9e)lZ$e?`eJVzjbKSFIT9XayZ8s)dQm%8S2nFMA1p| z+O3~lx$bU=D%AL99)m_NUg~m+DXjl=pmKDWI0ik_cpjp(tsoy!sf^lZ=Wd}<5w~7f z?7%B1$wS~xv=VEb0RxbA{Sx)@#g`sGmhK<`HF3P3ztbmaV=AlnPiamL$A&Nw?_b)) z^UFzNJup_Y!*;G^TKH`!^Wow|5%N_v#|>>PKRGb-5!e~~^M&72$Y6n6LTncE2UDEi zeTmj7;yaXU^1H(0mE^1kww}A;jB;&{uIb59z%94b#eS^HL!*{#I&mKP)khb#(LO>T zD6T58?YCik!xqnl@GMfPheAJ811uVy{5NJw9h>IEPx{IUWYo$XA)*FUIALN4)udD4 z_B&HG&bx%b3cmdC4jetb`2C;pPwT(&Pb2p!>pbwn)Sl=L$qw0GTLswtH>-3`hvU)` zBdA*TI^YXl;L6*3x$rdb;>cI1~or-TPY$<1IZ)lKmsOr z3dTm?$PnxRe_REF(m^8bVdThnwlk)JF4e(>D=3qLh=h&Skh6uI6LVSrwktC*AMmeu z%SGgfJJ~=a#~^Vut$GUsGScrc-Vzkt6)2-1>t!sbfIyhd@M_V&PjH0Y6q^d}Db{#{ zmX@A%IS_!NWj4`_Ad(|-cV;SfrY?NqGCsz8^3;g79*B%HeoZ(UGq&*v4vaSBk&4Ur z>2PJFLZt;dD8URjrM`YYu^V}y%9cSC^X*}T~! zmBa5uzeKf%XWL_@JC3)-c1A{ntAb9Mu0Eg(!n(cc7C!&CQL2IXTu(V}FWJeqZC0yy z4sIz<@N+lw$)ilZE##8th6T)yJc%S!HqCekNUhgPv`J6i=(e)A*+!L`+g&Bjmq@=G z*0pYjGv9=F6UFXoS_+I;MLT6qN3H}F_XRfv5o+R zflW0`qUwJ0i*XdZe#7EVn?V%ek8xNBD0;2Rnin}!+4#W!_T95h{_#3H#?iI0djoUA3Cx9QgDAK61H zu?(sK{OFfPZS2bnz-o$|`WkXJFxwrbxf}4i{ld7zMLY0J@qGF9F=!O+b@IZ zbYs7DgjiHucSq&B5dJBWxtoLn7a*_LnAkb5nCV!doe|;yj!6t#N@z7sqN9*)$I2R! zO0pRS;J(-Z<`2`Ul~f}VFZ_CcYX;%}heMPef`FDiMD+Nt2tuk*~d7X4zRPn%ro zle6;%XTGuhEBAb8Dlr3qL)A7pE8;-$rtcSvI4Kad!Bm4ro{-lS6pY!BAwEUQ5d6a& zzVOt6hI^)~dQKT55J~W6xfPjZaQ+_;@Hr3*O^9`Sx;cTPSUB%G+259w8hy+Yg1vZe zeg3av{)VwXjs;(GjH`si=gI&PpyQ>qq$bjgsn~EJrZk0=c@PG5Ua_(PUZsFazkP>4@PTQFtmEGswcXhq?3&4IS70xJDL#ggjN3IQV zp`Kz4?!k&?(|W%6=H6Jo$>HE6@VTdZDtBIN$;Gu`I^H6>IlEe5uU{tEY!LxdES0c6 z5;~O2)w%diL6BRbM-3uA?;VuZTH+r4XF=g%UGHU!fI$I+FlksS)c_tk`J!0@_kii4 zr^c!x)C)zdsE`JdGU3FQb!nip0nb*^=?(pH%%+&ScOOxiXnJ6yfzNf6|LQ9pq z@KDGS&hbNeNMfAKZNm52=F4y1?nN(-@A*f)%}A$sNhMh_u%KO16j1hc@;Kw_Pwuza zgWQd(FL~9dcF1BR_k&L%`O;^hvW-EvTUH|1y*E@L3c4=4I(PVLeafI=dCjDC!2`DK z9N2Fome~+(t|Xyfd=n|3(3#1{JtZsrnqG~0q_x^_w;XmZOO}ZSpP>+!UvP@6kX787f2#YqJ@Sb{(C87P($$pMz%+&2T`Qa zq{~8K@v6ZzUe*z(h*!d;W6Crt+1wZDj!DEAtBdA|QU53x<#WONk&AIRy!NrioJHEl z(QaLELby*Gl292^dW8Z3rJ8h@bJGH?01bP*f~!XV&XPueiF`g5rs(th$)d$KuU+)) zH7s{RRVy2<^W~%J^B~<&_EYWebXmpU51Ab|D!%Nm-m^?PL?~kgD&IXw-+JGEAC@z`7TU-kWk(5@?ygz}h&mNKnaI1del& zt2Pq_H2D*bJi5XU9)k<+`ap#K8mhP2%ckqW$A*lEq#p*b<5FN3kQn+6=pi`J9xoby zspk|Z=bvov1=V}c@tW8E-J@+%5PW{UddXj#cs=JWDFa4;JCUIDhn$mCYm@tqczK8lp8k>y_OQ{!D}Tb zA8oX;BIcI^5GbY!MBeMuF^u)h&<)uYDzPYJ@|Eyt`q&5pf9ekD^{`Nwr-jt!cs1X3BWCl>t5eFByYqCf zZr<0Madb5w0?lZ@Z*^hC_z1=x9=#S6u&Ifydn<>r%2nOQ!#bs+5x>ywN_*sG__=ds z=rKbs5A<05pp!*~34hDCi|L~a>aTPX26CJ>$EYj35^af7KpqTiS$Mg4#lgvD315P% zxRWRle_O)-p5Nr4#YPE(LxW7x(MvOn)5Z;D&gyBHX9z@Q6ifZ_NEmrG^iGM*aalP9 zLd6DOuAKilE<8HxaTkS0LPc9N8gVDofdj=JrYFGxq21UUe1jhik_1sj!zrE6R2Lvo z_fn>Pc_uI5EMcnW(1&gv8#bhK+*M*u*L~$d7h}=a-&Lw=uZ!uepzp!J1QIefrbU4O zdw>K$w6nv7V?fcR>kb8!PgV(|m5Dy2OnQ)JCh+=luXsTjKqpuadk>eCyiy_E%MAh z0_1ZJHE4KB*rmuYJ5({$##L1ls&Edrl)(g*T->=It^eowf}{Ea?jT=aSo!ELlOFP0 zHB;(sg_f^iqPjHl!@}87hlOa{@y?@*qu2lT-)AR=VAZH%pOCKIiw$FaR%F9Gr-)XU z%ahF!UINr-Ra%};Mwd@#(e==X6LWTCm71r+k&JYOuR{JO;+l>oV0+E#&LBZ~TL?@O zApIBjcdBAYkg&fIo-dlsmOKQE7t6g5x0Rp~B43_8)u&d8v<5RP1SP}jiGls6f(gCp z!%=*A@>Nw7ot2>&67%PsQp9doch?~P#&c4VkK8UaPBpv`2tzF2Iy2@6Fb|VqEno^7 ze&ZJp0~fx~GZ%{N{oR}><9$L~mP-{8oDaHH%|Y(eqI`f z0*AEeSaQ$jFSOUs{E%=HXH*m%YF?gm2TF&lgc@-2)HM=}kb7{HBf}Mk$KZ)$3lhgZ z2Mh2c)1dFCKExR-fRdonNT%=6Xdt)T_bgl({V0jlq&U$JM2jZsjFfM4nKvaoDY9d% zH_rb&lUJ%26Q$)n7jkuE_e((qn-Ppdd0!ENjDi%za2G^MpkMh@u7jRo66LeqGhqas zkr0j|^Ha=Y1gXpBhYdhqp_k*PL-!~e_F}fdYgvRC>=dk z;b9PuQ}>#7b^a4{azd?iHE`*yQa6E@(Q!om1@LK`CV$DT?rfSN*ghY_+Xrqn?q_$2t;Y{s4&<0ScUtYROg{5UX8{5eJ%uRxgu;L;Ok%XiwnsB_BxqKK+@OwY}Z!rSZ)NFP4>LgDF} z7HM`Y;J50NSg%+-%9MxtYC@sj?fdz=D-jl6k_E zttqHJ!oKAp0qvcp3Cxf_+Uc+iH!%4SP8{4}0`w?N|LIYTJC1^Vfc@ktTStM+ey5#` zO=9p__0B=!t?ALJ*)#+aSeo{H$C%RIF|KN=iEGKadLo?4i=Y@{@wRd``6sVS`&Zkv z`%V*)Q1NHfhzkn+eRRQ_ZxPIpu)sA#5d_@v*-crdyGN1yjl1A;6mUy{qJOiLt8n_dSAGn zbEopsZ0bYsPrFDYsbR^BsaCz{`rw=%Qe0-$p_o4T=vi?+51Sc8Yh?&2v?@ZQyi)5v z+UXQaft+?yN`6aSUDfGJv-a0!7`or;B3s`8Xr6QiY_+Y{iF&--ML&B!yus#lUH&4I z^kqtw7cJBmX{0JDb!_ajz!N-O)BgLE*$V@#xcIjlWV#{}4ssze*N^-dp=w+jhf;MZ|twKUFHcNzH7{}IY5O}Ni zw~bN+v|R*HknD#xsEI5--L0~3d-enfx0X-#Z&Wsk|7u!8H~GPco;i`kt{K&Dqk_=CjpPYd%eJT!lLJZJLE;cGhyV&v z#l(%TfCp!uJj>B~hG68h|H?!KEp@^MW1McK@CtOJQ%J5u3Cd+@asL)-;vv2mHEqTP zcN58Fw|{hc%zdW*;~e#AZ{g~`+sExV!zGT+&KU+NkP#thNW{3YL;yt;1We`J6ZxHx zmqp_wt%YCj;SLtcklM=vtAj0?uX! z4N9b>4=kz6YI(%E+7Y1hbDRRV*c>blpzmZ#mXZ>38eP4K=`QGZUXW(YYH@6>Gn%Y# zq2MTu+=j2DoSbeAiBxDR(_-Vf+NhFv-<|T{V_f>Gvu>O3;@0b_TpA|0eb5udS@Myg zaNEYSdDUB~D&BBI3)iu~0$ncIdG3wgLWrM3jy-5h+rR9Cr`8}n3-n0##P#8a@MIrP+eXdRAvN3kmC+YmLH)0e7L3Fd5>l0xJpBdeAY*eIKcgMV zdNfLh6&{qO5>(pBU=Uc!dS~S44Pj3EeD4UH%5zzF@R3a>9juO-7U7BU0#tWj=jh*q? z3a14gZ)FylC0lWQrg9e?>SNl{cx$Pb57uPbmYV_bCjZt5 z-G4*Wv?&<*@{?Y(cGn)q&LH)zcb_Fnh8b@-@T*{m5kDm_z&05%5Bp49GcZ&^y9?h3h$^;nsu_u6_{_IMgGOMowT` z+#f12Z!Yjqf!>J>?jMPn0+5Yq4k~ev?U&*{duMhY`()RPT7Jh1sV zanfvfy(>w?4IyjN*!hx)L&Ryo8Eg9goiUh(CT<$>jtN*m3pE>yf{{=7wOOu){fkqM zz_P+^)CNtyC53|n>(HQ&3WYM}kniJj%=D9|)H1)&pDs6U&W1;b7u<~9zn=}6pCtwG z&t8Ls&A<;-rk${i0AxzHWyUoc(}}o`m6!YB$5h#LR*Oqy91Ug-znw4gPiIr9PnVik#=-$M<7{b-B^*nlYS-z zE7lDhZ-TM??ID5&6X$Ps0rz_78sYyx-m8$GCn_k?YIAc zU)gA-TBoBHQnJrtUjJbG86GScTLta#Vz;qV3wyzYg#dPKKuyu;XYs7)X()?^f&&MR z>()Gq951JRHA<-Z72ItZh$R}8`~rcdZqZ8i_ zB>9PQ8s)DMWs@J?LlH^?r7|ZhJ{Oh2pPNJG;taBq&;R_>oVEPLzkk~jB6_mCl5qW% zt84Rt==tJ?%Phg(*0?l!yASPU`$a^E!2px?0pXr3IwB#L4tLsYty+@sse94Fy4$Jp zLXpmoOgZKnxpNv5Kcn0v->}coP#^(SSU3fppE?F?AM-rmmPm?5!GLdF&$4l{)@=ge z-7;S1budq)m%?L=EO5wfH#8G=OR}!V)FkR<>=9Q~H?1HZtu>P8#+_pMQ=5-6XySdM z175fjeMexS=YdEeOjf%?2nc5ZDhP0Ey`>Jyk7v#B%iI7?(UVC zDz5wTfY>HkwxiQvFo;gkV|C{FiL{ww5A&Qf2D2V)mgivi3H4jwIp4=!`@e6EFA@T= zXKnt8&-e@uC8b$_MfNWn%K}F4gJa)?M=LjvMbEl})QYinjNYOgm)?O)xFdrVpPrU~vUcap@=1))hO^c*SeHNqhRkSG|-cwU5-GlSn(<39>7 zzjWVY=G#P?OlT9LBLmXX8^5<_)Nhd=Yaw}KE_6s&SBTmzB=ZBbz8UT^^@U|7lDPC? zWlnZ{H1pja^LujeTlsk9KW@jsGPL%`nW@iyfSc**bF%G1%TSKmWbEEVUT_911i3fa zZ(^n)C|K*NU*>H$<>qXXjgC2sqmXg3eLlv&?fLF8`7&)j7(1X`dRp>}vP1fZFIKa<=tyU>~#@D{<#C9Ib(q;D6yY>;?+uE_Xndg?MmBB)FoF!R8l58(x$2(m^1 zF|0ve=2g-pNUl%YmqzVWkfJH28%qc@1i0VuAyTo@GwyE^XTx3(_PJj*ZFyK$X>vuL z2Br8Xcs$mNFS&fx$DObxHLlf%0+xeg7q5RbR|X%QUM_nK8CvcBM~A!djg6!LbofW! zW9&PoDcn54Fb_@X+sxWNH7Yn^Gm)@&!=NuUdM3kT^z)94*Nc5;^&vlMdO(jrIiNZKq?K)h10fheEyjXqY5bt3 zu|$-LMVvz&2m>HT7%r3tUH1}kK$*#V_*AnWC`f6o{Z#CXwd5sSybW|Ld{1EHF@SHz z(*W=8ftj$)6NdGmGCN90nDGIzFJoTAq2J2sa56cZr&d!sxwc#18e1zGMT6Lmth!5w z=Ehqd0L5g+Vszi;C#Xj-&3C390({ z{j$4JfKN8hDDic7V5K?vx9E$-5}HI@#oQC#IZPX64%W*;l{2KW_2%`fPqusmo?1En zKi>bOhadUR+d_;0j^*m!hPK!#J|OXE9dgQN4z}dAAKCm%O9ofiV~4YCiaEFxn+XCe zN}5`}u!-Z2LKj4r3a|-%W=A1LWx1=9s;lWLc8eP8F`U3-n72LMpF6m6p( zFNdUxeJl!rUAB@FT#RR4X|Xt?75J3LdG3TRiPz1{)5x;(P!AIJ{kL5&#G7cI$O3 zkrM%WiV0Eo2G}e2MmY4KyjG1D^`wp?7x!>z15ur_aet#^SXNR1&Fcr7cuyY z^)GZ3_ePJuk(irdT_F&7XbLNVd)dAHK=1GNU@E#PTTWbv-&(U9n{4BoltGSb!G+vsz#Zc|u~v zf0EMcnGJQH?1*#tf33L#37DX;34v z2LHTlDHC0`dS{69E8PzPBL*Eu18(&r78kanFE1$L`j3?UsbB-2*C4&Ed!)Tb66Wg# zLY!s!`)~yu=xw&XH-+}D)yuV*2~dy}1;u@*ird+Z zVSmf*Ghd=6ar#Um(pxf!3p}wo8<9u?o916AWu9j(ktbUuRK~hDiE18 zIhF^}%HV95OS6t$T9Spf45n0lNOxY;ATtc`cBnC!a{Kx7)4sf|FXN+H;Mt7l6@Q&MxF!+JGG9z%HOw)zMZ=r8i$ zM*U@kxFeu}1EXG!WU>$|WTD1s+6KAaD1IlU?8XAE4I@DaBw*$H*)7ryAI=m-;aa8O zJW=gv>2NW0%WpWp9yfHI>*u=47*N{bW9%3rW)xWX*lh{%@^Xh`WtBoPcyJ1P~w{tBdqd@V@Ujl%;Q41c3U2Z8+1)Bd9Btcj<9Wk9-@{XU`l zVX8Qad8_q>j&kOl>_=ED=!SVt-)7EiPf_TXnS=6ZBF&|4r~8S_-ZT+r35z)9<% z7TYV;NDC0!ka*Q-_O)1w{n{P0V}pA|p&*=M z&Z3S4nBC<5djxXdsTZtA7Wc~(*K9XI3#i{%RPbpa;cm7UB6qHpMzLy4*U4IMl1OR& zrPRhCLK@oX=A|z0<$^KZke5S>A%KoebsV6{p~|aBj+;^>ZYgrxqS@fbyM3Uy_kQ0#+)S)FmW?mY4o}qaHkJ)3^lZ~>@J|=rf5S7(r-HM!>4UMX zKd1?q@;B0+ebOQw;RT!ubN?H2qEB%n*apO{R!AgL6#3WI7(FAu|7>$HNn!5n4HBub zX-fX)q$FNDVbk;w!gQR$fe?+&;T)EGY3J61*f;d(fkAAP-WKd7T`|!XP6D81b7pY+ z`GX4da1l<=_RI zqruZM2?p4RH@6H(x_}F!=Yib|1>x&J6APgT`~vbDuovXyW%U9McFg6?T+mQni|uy? zy8XuUz*GFI(Tqn+hbbU(5D5gOip2=j2hHONLP0>BMd818aRyDCc(F)4xzLnMV5+LR z_cyi#_gB+1Hi%r~{f9{{_wgW~=^#q<5EX@;H==Gd_=U4B97!F;t{yE$8c4yUzx)3l zA1vb-lvS($tCit^wd25c+$9M*NSO9xX;WPT!rR;`e&CnByj%xl2Z5DU-92L={nA&B zwP^vUs(c-?%0XwS!qV>so5a)n7r!eUDB2gRennaEcr_lr!+9b_4Y4ul_+=M_TrxY- z`He?*ffBKvG|0o^U@Y3T>l}=6EW?B8y-!;fDT4Qt zaXkdH(6pzB`#`{pZ$B!ah~BfUKOnBYRa-vrrDu+hfl=*vA7V$U=;A)wG%e^pe zDgvure0g?wJ019=m{r=*gEwo_?htIC&TH8B9JVp8Ih$S7i|>av|Q*ZKDIs z1ypd#JU03Bm-vm6fVIB}^b#ZujxWb@Uh8Gsn!v#*fT3U?N6(pxV5gOpOL>yTky=R? zwl=oFwQlx+yMfwjN%E>C8?9QJi(T#a+1H5>UnH!vqO7X;_zT~j=5gBc70+!R~j*T?z3gi1_5&W<{SOS z;!Mp4T-&wkg--aM=bZ9AY4L1|63)C?vuNqUjYUXdVaO9;7OxDolNEO02H)g5R zlGjo()QWS_DqCSidA14&^DzNhH8NEosmsxMX*-Gi`D+27WJn>3zt121DdwJMQ01=R zb1>;!fi*<;ev%2{9KocU&CM#9lkTHz+z(mh$+xp@iSUzR&*E`nb6{>fDSXi;%z+;u zdedb4B0<`lkEhD>qy7qGq^QKh{_XKtZ>Dm{DT6e5v#0 zpAZiUB-Zc}PgF1v$N_1cSnV^_EOnw9#zCqy2)jQt(Et91?d8CT_dMVy~~H)v>8 z{ofer%r-XV2}~H0d?(nqLIfPzi(eS89|Y_a19dXP8s6g6gb7$omUGRTe?JW050NjU ze4#R`5woS~^YV~eMxHJ`HN3Cn?847Y)B=R7i1OminECelm|kh zvN%t$#bTvrA9Rxh={mS&hUv|vc!RGgpuw;R?4L-+XL&#YJX8+ay{5pvhJ0ndihh`r zNRddEluVvMG#P%l^AN1NL-}dC?0UTcD!bT}XM}vI2zxurmDo!en)Ml-Ih~`DvAFEf zhH;CZ;@G!;oB2(mVC%$JzII`D({Fo*>zT=YGt%RqG_Wk+$>d=cd_6z(c_wL?k+j&0 zQ>T?bHa)D~ck3VN^Z1v{`gy$%H5xU6cwqAj@|Jlvfq=+ab%J#D|Dd6Ib+_Tl>GQL0 z#?BWr`_pW~R;Dk?%%^{Rkh~Z#4ASn!I%FIVa&ppliGEX=aTn5ce8lkl$2Z0bzPFn$ zE%KkPg*uGn43%nq6gO3pj&Gv{7FFu+H_eslzj%iS6aOd<4yA5N#qMJk%C(WT z!yoF6vc_ASN#SgL##@tFWPu(!H$OXRXiL;4jZ0P|Q03^r@*T|+m*3eB@Fgll*l#xQ z`sDQRwIB~1%mw|M-C8mN559=LT%f%C^p8Ms{b#D0^^uypaROeYJ@uRrq%(Hs!CA{Z zOUNnp*E~gNFI((^RJUb;9wv5t#;z{Q5qOYgr&YqS`-t`ojB1edC!z#;N_J(Nj zq#TuwBq`&6Pz)`iaM8sfTqH13h_B>>gn>-#20)xYGyFH8(A7wpWL+T54Y+N+k??L@ zD|tpkloqP6WO>En!-%iOf~m#dx|>07To;eR<>0dyZp{Lc?@`&FD9UbP$asu|A^xw4 ze0IvoU(;An_fn$r(r=Uqg%>*{Y>w-Kq)UJbIYZPT#1* z&_i>Gfb-IWbetfc`8d}bx7-mAUn5h4vhRKCYj`EWVWthT}pQXA`{;)tF!H zPu*ycdx=R4*xTgb3E9Qi)M-jAo1YfHB@B0>o1UC+PyeY4bG5foIN*MsP!S@xuSZ2X zf@c6XlFv~SH4AzTTFN!JN9sKKp9Z^=&7i=dda-v!%IC+`57z*hqeqp&91v|_%i30u zf@zoQvXL<>k}^F{Db_DQb_VhQkMrVQ%#a)^FC3}KGAmbKqq zE<1p zVfZ#KG$@vmDGV5|0ip*Snuaw?0*CMQ_dF1xvXi;us!2#y2ty(bi!7NKI42UMOBo{R zxJJto+_xw$7v_q;rHTH?mLS_L`EbIqPb5V*R~2l(%r%NeNK)D$*jHoG=SgK~kScEUhZx(!xGHhs?nFrxe-aBbss5;i_4R zqYPqkQBB)3d9KHsGfl_fy7jH;YHgDY?%k&t7O1?>hemUpo?s0)xLQy}sQ{*GVZc=F z^pDrhrSkd2kPVosKQvr;W|)Rj;=$O}Wl~}r>x_pl?!(ZqA`y}bODgi3ox;3Y#K%AFg$rIcI@U#LWN(A3?vH&(t zJEnr-1DXj)ya3^TO9cG(vp*Q*1>`-85g!#g*xHb~cLII7-Z<35IHXDcK-Nnq&2-YV zXiyKWwT2tbDa*${r2uGFpJUx0GF+`_fc&oDy4o4L#(jJ?6X50C^{Q*|mKY#$FcR7_ zU#YT_-teTI0-l%7n|J;A%zP5HaY~@$ znD=Tkz~cuX#R8BXg&a87PZSzKPG%w}1BK)LEQ!jahDHpS;`4Wh;LXx`9$*(K zWM|KZ6w7$=8{&^y6Wd|2^~*C>Zl|tec^Vauw@nw*Us3`Ppk;cZ{^QAbN7KXyv*H}! z_$~JEk(?$aUM5}xbR!V+t(x;rF2EI7NOM(21AY)@R;MypE@FQy4TU)Fd!)^o=ovc> z2ogW7n9aKW`(Iv8_X%PEl1#T7YcBLNEJJz06P7OLPjWK2B*fOBy#`!+o~(#hvo1j? z1-P=!YGG72vGaU}gnChU4gmkHud8KtUtx$Ii!4mkDwdp9%NxCw+l)H+xZcj#^+ZE{ z+bx|$7lipqqKhQ$X4iUFt`Cw0$Xt1MqPmn$XOk3EfkpjX$hk$p7f;5P?W?k#F9|Oi zyRKVu&jik2%+hq@;w7Ubf||Hc^dM3oF+c}YjIrItf?Gn_;Jfn(sBkT8(c+ zq2JV?=W)O)u6{8Mq4jdF#2`$c!KP)bL)@vXHpIhn78177bJ5^B3dDC=OtW(v_{rVY zaQH>@_3VFWFukKmPg)GV=w;ZD5}?$II@9ExXDZz=)0DWn{2hBp@AD<(vdi_jF~4zf zI*|9@u>0)`ethsQ53sBVo_n)=%x3)c18-yc6{gzb(>&K)91&ZaA@762x=0!ly!veV z{?pHR=ur?<;jlk=EIx@rf`0@ThHhFdLSHGOcRD2ZzJ%QlzJ8J&Z_kGq&!yx*BtH&E zA(l6GjPm&gvBL|#83|aaVyFo8LX956)ra_S6jy)6f|C-|v1=$v_3@s^G}x@aWRws! z6CLj=w7q%za}VwBf*-){FY4t2C*JuQuRHH6|D!)GTCg85oq6L++$RIjWfDf)L_Orl zfU7V>{=7)0INw(eFNkgzBj2KCe0l$tFXEdNAbkBS8`+fAD(l~00o2_6)0!1%yZ0an zaAsplTKgeIjbcnUnXL~1Nm}};RXS?EJmV4Prk3iM4hlI?Q=%UDmqVyT{Qk5 zR1HlIyCSit1E%$?A+ypwNp3TbUi~x%7Y*hj_l|~azr?{{D3VoJox-gLNgj4R;{N9f zLaSjs*jKD4ZFlc$lJLS&seD~ekUIA$bjwJ(v2^ge$-6m;l0o85f+f$$DcDY_rEUP# zfU)*`CI4f`bM4dlfc^H#y}y%ML34;sw7dk}-E6#ulVNq7N0a z+z}^>A8VAYsqm9PexQ1kCHERXwHna|dx3wBC0I$)#B!;cT|p9IB=wi;b!S5gSr%># z7*D9qTVt#mLnT+Cx!q#(I&^|Sm@#*{C3nIIr2O{na1I4Cvt8>;jV{ef&HLz12$O~> zD_E8jmx_BYHdIl;WSB%Vz~q~kWzc))r+uO2TnNM2GEM3s2bBL$L@n0EQwJpSSmy`)+ScBd7cg8Ms(g~3@)@TuXr9x?1|1~E z=Fx|o`|u)#pY^Yunk473QokL6p`E4ntxF!i8sqy>pQa7l8&t4@c9dgT6FD z0C#tz8MWf^`!3C_^0{S!_ce3{FVydC&lr`Z5R*qoRdXMKnY)BjFxtMwiF-K)QC!fY z6?6V+*8b!ayhY}TN)_XCStSI<2-t?DVL}T&R6oWmya=92|9YJABNaa~D3A?^**|%X z`QF+vrteth?ba0Tf*u=h(kFv|#PXt!{7@DV{JL|z6d2;)0Hh3a?`5dp$zi9X=TlRB zD*EjOf$85ui0NF6X4@0=%kCM{pe@CgwJ=@11J(ZVF7t*rh`)A#EsrXpgMkhFJRQa9 z@0s^G@WWs`3)w9D-(~Y8nC{D_6Z2tgfVCntb5AqB zY+!7zv^<|`Q!lKwxwWnS4@3GZX85h7j^6e6_qzm|HUCV0P4{T=;GoXiyFMxREnw4n zTWu(P_1$2VZvWPG>9~+&@uUaqD=wclQsu4149iP))H`y<2>twKS?nZ=L5X?8Bbmj6 zHIC1Zgbf%c-nQ7V3~m*2+vl+%-VKPpTCEy3>o5ftKV^f_PHlkb>eZ^M^}Wk*l+nQ} z_!c?)p+4Q$glVfva=5JE%Nb40RjheD1cf2e^d8Lw4it=r_l34P4hR2z7k>7JdZ&5Q zq{r=QbmlDj&fkxlSOJAV5|Wi*?Rw=(F%u;heUt15semIK^+9q%=^p_>R3Q_IMuq8j zfu~lup8;bJGSx^BlqyyzD~U{(h9m<nvvR*ypuX2^cNV9>1 z#mGN#;Bh;oQ8{7$7vo*|^6~b**zZ3+zD*0ISmw?8_To+*daQEkf#V93u0HTPmCZ*} zke}l4dLqO5pC^?wxk zG=Opav^y-w#f1d)Y`wZ;(6O)x^dTzvgixQEE^&$(Q9{*5Qi+RyN$Gy`-l5o1_sGo& z6TmSRWmXVs@f)Y+B=WX|eqH8D2s>BZI{tZkzC-ENU&704EV{XMUp-q|>t!Q#{O3xt zA^e78$t0uUIYqrRhJ(05xHX=qfwxmKOxgiN;KKH8&6$aMS!634^w;DkATHlw|6RoN z+!sl#r|!rXO`LzSUXvT_m}wxr)TqgN+YE)6l<=_XVDiQHz&LSJ(ejQP}>&AyNh+$D?DxXN_ltBte6F zt-pvF+fHXqcZpuS3OD!ro%Z%^ncD`mw_=Vnzut#OCok_vg9b=7ZxLZ^=LPqaYlSVQ zg;fP9Lso&UE&PRmn66if9*Uw&99-lqk{dmOZF_|N)+t0sxz~Nk1b{;)n8ao}2@wGd zNy>lEjgKjDeN#*%`bo~ey=6SWccp_Kqq0Go;HXS_DUi;ne#T{7?$dzH56U~|A=e`Y zlM~mto$eJkA(rjaoo$A&-X1OBm=i$+@dZQDGB?|Rq75$q2GW~;sBh*7u$A2Q z!b3ZKhf^8}@A|s~oZ59$6_KJeQUT!vv2inlcBVtkV{(Xr1$N|QR%g z(QdCL@X3R+Fykc*u9pagu*8CInRx>tAAQ_aclG*G>&RFKfxERpix)QcsUa2+D}2iS z9m~ngXueoekYYrJAan@@7li}JvBdJer-;1(}2{Wq{ z3aaDFiC4JLC}1pf&f*rxkF0^nz`yU?e}`+iw?bmyAIDT+?(f}$7&$>8=L8z^Od0j6t5sea3TXoZj=N;aU+R~=oN`2d`2J=@Qrd{XU@_)-+bwjt)---3*+^rqR*c$ zvl2hheIyr~Ca__jM4=_NhjS3LMFK=`Y~rQ!=xsmZZnY~jVhA< z9})Bp;UZJq`tZZt2l2+9*7(;OzAJuxSoHlNklO9>2ykpWdV1(@MT1sH-DK_6BE5g_ zeiJbl%3P|rP3~u5;|PMoZy**-B|m$6y;L^dVdLgr(np{>(%f9a*JM|dlc+cl>#_AJ zCs}AB-Nyx0G`1fH;$b7RK4Y6r=<_zCVc{SIU4P!8xV@P1wQ-!oR>%IwX@37h_1*qQ zT{D9~htEHDDkqs>e^^uPglGU3eMm)7%=G_xoCN+;dqkR{N{Xd&$UoU{+v3eeB_-+na&XU;O#5oiV# zoMfJYcY|#r=;Veqzjch5Gy9gS0 zL&x$h3Te5fraMgm5NG#4uhp@(3{2wKa93z2Fmc-NKl?yepQaY>Ia}+sqSB$0eBVRz z3cJ10&^;~U%|ktz$B?yhuhm~4u^IIlUA_T-nZ3lC7Tfc1pC^QXBI_0T{=L8f#TbfV}rL4Q>h(VrPhpCj2iC8OwD;9k&y z{zSru-{o^>4gj_pF|vT4NKT^I%m@!+WQ5H-hDe`wV9kVW(P&QndPzHI-y)#z8;eiF zqFc!MSJ$&|31($ikiRd@Ylv54gh1M;qK7K@I;q&I-v^|?vE@@4FpdYV5IHkaWP&Nu zF)KUFxgwj2L?Q`XD=UrE1A;9F@kk`NCYgXSp;n(2+cS|0oYVx?%~^}^8Zv7t?@VMk zM{NYLPO=q5zTlCrVKRr5EqPg*5^P?Mr_<(AaWN>EIHyovbPS)TWs)$UnnX;NS*Tx! z2h>Nm@jsKiKDi{5SoA@UXF>0jD zeC+;R`ZUXM@vnS`_Jsx-chaqfHPWJV?t3YbE3HS7raR+caJ|aI+90u*t8vk8#%eyM z-Ho3s{TuTb_D-x&jugsxJ`at4NiL3y#b%|v-HUg*^&)Rht$?!% zV;Cr`KLywyaIT`#t5MvqJ;IPsGh_xcX6o5JtsA}_&kwFZ`rgU&-e5;kxGWsG-Z-E) zUXhBR#@OcgwAa-t{S=RntSQy9#F#El7gujJko97PW4?@+7CJZ`*MyCdG%%g;8g3ly zz2Y$Aaxzyp9s4Eo^b3R;hVDI*1uAGUqADRUyeiz7b5t60I@u?{XN{S+5r9J5i47IZ z_!T4`>KTzY`BlTvFYad>lUa##AWkGajpzPt+PzIy#IaF+ClkuMt}g#@ z(9*KRA#;M4?aY@PLfJk7!gTEu}rkij3hcwGlJXpD>8yOoZxz zdLLPlDzcpD0AeTdRe_1Q8_AEN2EkQ$x=9TePq4y3(jc}CC2K8!QNRHey`FX%IjH)y z+OHdM<;l?W^``NxLRoYnr2Xa@Plos9Q`}KdFua=>6eAbX4u!9&#nhMM!T^;dwJQv! zhDCAFNN%+*TYv8(bUhq*bh5ft4}=COQ)t;svO`Pc0At}ChKv^M1a^cd2xTN#^MgMm z<`r^*`4#sJL1kSd$QlUR0mF?!I-(Ue7&dO*ugH z3dDtoQUN+6s=e$J6cC1UPv-rV9avVdCUU)_bfA2W7#$%t*cDnT*K|<)zdfa)gfK`E?8(F zLSbLXSa|g6HaGr0H6Gwbh10I8vyirmFUZ$1nvP2m5Pc5y(JVWd=@(lNKp>uJu^vlV zd4>mer%jj(mN4RtYtYBZd9l^dVAKb3s}N?j!$;$d7LGg$UH%&cGxdS1dC#J_mjBE~ zmdnkhVN|)yI0(mcW#x1nEbS$$9VT*x5)Q|!C&r||in(qD!o6z>L!O!&j4K~Xx4t{& zbnQ3Z_mXj4d5SUU|LgntN@*@C_%}XRVVli5)9-r02+Y6@5Hs~pDW%4I=I(9BunrtD zSB%-F9^`QfxVTuonU@*iuiSr1l>Aq9lTZ)%JkRXjERLfyrmIC(c$epy)TM+KTobjb zY$-1Xn{X}qk7|z!EUE_o`U7ONIxxCNfe2oOg8=|J>o*O$|S;Hy$k?{nWlY z(cYb@4gO#ra&vlim=N;)b~FEW$l_{o4nW8f)d9EjA1p%w8753=AvnoQ*oTw51;xPH z>lGcQrmvQxL%y%%Ph<5aJ7aM0EAo@WCx!3oPvyCut>R`V-IG1)FZn807ekwo3N;875Y z6O39Lrtz*x1<6n3W2;U)%lU(q3DO15R<{wf$ORXHcfS;{r7J>{d*#GD(g4drUY}W! zpDh#E8tjPwnHdHPH0%KBdi1G46E?(yvU{r@#t;b=i_zMovo^W+VW?N}0ars6S^WwV zlmGo`_6P2L2o@ZG<~LwQ90XN!F~9n44J#{cxCgFBprRMHfbwZ0P5jGqi+c?94tK+k zUH_0=uR7}=*Z+4zhf{X%uEp*DRm_$VdfzI$GsiPhs>TttfOPVkBUm<==cOiSvovC< z^SD&qZ3uMyQ6iK?;RI+B#F*yd@a!YP&@iT5_I3HEk1g%mt$5Xu8Vhwk5)g3?30tl9 z#-e3`MI&j>M`3#!{bdQ~kHMsejA1C$!%PE7HZ+!RMdFIrXIVEIBPl^?ff;H8;$vn; z$&K*-K(ml@pqX@|0B5;fXUQrqaP<*P3P>6X){uQMIre%&$zf8tP@rGFR*5kJ>+>s& zaSg^=Exxk7iq(XL*5-A?C6E$pHt`k?)H}8{adD21_ar{X7|!MfCq134bhGEYiE-dLB4ld!G) z&Qn!$*J^6%%vpy1VKEvO2fxRxViPe6io<%{$=kU8XMI-Tr8N%MyL_#h?KN zWSG#)7x1njy>@tw2ec@f7MK)o^K_Ns{VJw<*HgW6Sl?`svu++u(J$Z|FG7mUX6NGx z544?E%C+JY#6ZdK-b_p7`=OXYRBqHjn=nN*Io-J}8fRV)67 zSHQpW3KIt;L^Oj~HeL#_EuI{TqgXb!hhB48GjFYjzxyd1iVaO})n)ZA^2F^~HT=5m zxX$f>!%v&D25J~IXF%}?2@o|-rZWQ3V$g^%zlRDJ+LF~qK<^h0g%uC90c;p=fnB_F z;+qujqU!=%lP$r9*rasNF*DKpZZ#7k{^qw6eL}OYVD651e?{r#;UwO@_Bq& zDv&d4Tk~Z2E02L%N83|(MkTwSWn#)BsskcYHQST60K4AW!Xv6MYmYS&onpQ-#Me`P zTlZZ+?yr7IcgpeNVVysW={S=4yMq26ZFXtW{1SPOi69dqS+s`Q2#Mr}KOU;4E%_|x z>z{*LA$G$4`Mwv|)^ja0Ap^3}o-DX^e5P6=mGnEl*=4?mK>LB183!UzRAxM5=Zj=2 zkADh}gzyQs;o?qhp>112@M0j4Bc(H+Up)Sd|Fh(!Ve)vjnO{cg_y>&Q?Azy<@$%>u zfCs_aM+GlpRya>KE3&p^^Oz71p+HGi?;2A~Wdxf7kh=Xr9;0U_D0vp{B`5)e^M%B% zaS+u7Cm!)ZippLpTGv*q0OkG^v|Eu|J)`UN(l=b zFZ@b#1OGCGd}sbgmYJpnd+KDRCCBKRvV{1A7s1KV=8P>?iet?Uek+ z{mN)bWp!`+o60yx;*Emj5PLoLdrlfTItaB%4+H3d78@Ksc9(me<7@(`gMiQ_hc<|1 z;K0((Pme7r_M~v18zJ%81yP>lu7)wd5MNKo3sn6d8 zrz$Xym>Hibc1I(7wP!lMCw9)vl|G!wF5QLB6gtoD$Ed|b z#>9Y_>j~Q<{i1Jae_j&NtLM&?$n-yPJ=3++dbNa6j5sYDr}pkIDbIPbdJj993HXQx zqYwzWuE}@#&3W%lr=SW&JR;n0D&!NM3vhLpc(JypqP}8Gd?$3imslhb1Fo~9&KJjk zDotLZQ67qQXM?~zZhC|H0?212N2ZKsPG>PLeGDlgKU6orCgz~fX@j*uaS_13L{+NK z#?{vg1^jvZ!zusj`ya`XzK)es{-rk0vSrUPDF+2)>WDoZBjdT;Mm&5?^}5gUI&#^2 z&ODtQUe65-t%g?OkTs`le-{<@(nM0sd9yhi&&iPG*?ky{Iq1vz*mMiNKOuoX66qxY z)}tfocs zeS#f*QhS&t!|aRxOAg53=X=ah8N(gBIy5XaVm?!CN;RZR#Z(J*5U#2I8z6ionbo7C z@SHt0pbGZqE;jUq{w9MFx=Y3T$d!}X0nsvz7}VcOO#IjM=VeV{^A|S)Eee_5njtUv zP_#L-1eD^6wL4p{o^937_r;v}kuNPke20zn$r|7BabhW%yy5vpZgC9|-s|9d+Suu= zyX>?@fu3hEVijBr5o@f4bejW2@=cYOzqgF`>SDf3@=c~4)WWenClfHzdPZY|HxSwL zLZ14UFM05UtyOk)2e7f0@efhCT}Mkd~0{lJJ~V+V=CJoFREFPrYn9{!Kh{_}ic*l@w2xXj ze$a|x;}dSGUY6&=xE5JD#BdZN(1l|0GoUmnnZU{W1SomQB$4YG4Y7%`_bVc{4iA;# zN0*bnt7LLwtKNI#S1sW}%iI4UF*kxiae1^GupE*H@aJYACw&kI`Uh#WvxA62eV9PZ zl8reQ0QH7ALCo7L)4S_uw*>;>V znZB@OcK8Xo9$#JsefQ}i%;5ju=8Zu*X{w}m|KCgUb?c$7Uv7-(dv9guqo_z+i5}2B zvPd{{ z!bAeHBu)Q)i!;BG1vJ_w9EjjEDM0&XFyW>M4(ZO!%L~d1X+}k}jV4TCeZU28k`c5L z^!B)wp14@a>Sp2iumKqCM6p|vwUNI{;@IyysjajqtY!-}Z6&4)M;_|gqi5;#U+9yp zE4C}S+U;xq)RIlf6PQjskfpmNmQGDYm2~6wS9%J&exy9!wbDG1@>c%#oSCuchk% z6xI78%`*qMhP=B;q`|Y4t9&5MW>X~6vG<3(NQD1MkNM>s6xxmH2z7)gq0`Hlqu4+? zDEi6^Oh*;-^7$7POOqY@7I#Z~#|_W1)o#9bdN%l7*sZ*~mRjWnAz>i)pC)f~h#V79 zv_GTKVR`C6HeLc-v^+2r$MD>$kNzcxAo~mrDG~wazJcR3n}8A`mbtbQ3pSHJutC@f zftq^*Y0WFF?j*9=BgUWv+hqcCp{Mz{b{0odWUf`5mU1Bg?7)aJrujmgS<5S2wyS5b zCwviTd?g#k@q}m#I1op?r?hg)l@xDEV+74d6rD0vvVU zcA&pEyXSfnbYCsC&`5SP{*dRXF|E3LvvN-}If}n`{IB5@=_3`n`s8!eBmKZAezv4x zXxMz98+14@ZcNsuTyW~d(P>xzgnO6JSLM!~5U$VJ=osPS)X|O3$aT&jI=JH4A=9!ZiF=_Hha8qoD4JDfJ?JPevukVm zY5(~BTJ*Z*`isu>{ES_1#$IC6_jB&O4CB)luk7y+nYLQIU%fh3-0AlVh;Z!i?>rb@ zxIZ(*^gdH|*9u*0a5tQH+aLaIS)~-hWTq`J7~I%yQ$ek1OSH5Yea!cz^S^2Z%|ZIf*0Bkx%|S0=GUzPA@F4cD_pXYaho*2CTukDFwEv5&gTslqIf z@+Q!x%UonaTzF3B7L*B@mRh+MwDDZ2>|JThJ$o#HCj<$TW4qjv%(_3Ap{GZ4A;$a% zP54!0wW1>93nbVr6B?7gT7by=XGe4MNah+20QMhe_%fDFf2!3({abVLTyE23L*$~f z^`l81=`bwfy!woD3wX%WtOE?&w9CFI^ZTu{GX0LHtInfb%X- z;=d4odF0B7x5Q`Dn>ueCg5Gi2lTIkDKVUy!15u|OR?&`PA4yts_&Zl(!gKr4F(=qH z1BX?fP|m*sPSk7(%89smF2=`6FGdW%??T@Pr-4(ErXYGuMHKVk{xdHfYPJE;$WRHP z11so&bW(!4^E8bhtoY-2SeL)GkLTWA6^!ojYr!r!dIrRPg0?z^!8wpK+>^Kv4UPsW z^Eu|>y~JRf&Nf-Wg00V|E|nDh8pKfMq;(Z0#EHX$jzW0pk6|!IPD)gybm)!Tg+d0Z zI_$6Vn2Mp*KJyk5ejVth+m73Dz)lh0eaj-t37h%Y!qmgIC+MNP$9%`Y)IYHrUAdR@ zUs{K051aDn6C9Zq#FH}Hi!rD2hyr(6iij=>HcuG)fGiQDUr z-`~i#?N)TM8t3cZaT*7{L!UPc#FRQ)KZ}*V$v+%Z2~(z$6kQ~))V+#ooT>6uuwob= z@r{3HG5)4S3oyNC!3<<7;H&(O;v^$6l~FdGQ#3IImF?tOjxqB?^}!9hTi!CrSGBS& zB=QQJK-yL12U0?E)gl7oLrj;pxb^M6 zo*^8gPySm_FqB5>E%I%LVne0QqeG^gJ~lR*xO7y1NeNhNTaqrFPNyCMNcZ`ZGK#g? zgP}JUL!nZY+9=h#L~ZrH_>@K0}ZAA z?3_0LhykqAU27HOW>>W-hOeUi?yf>}b91K_iq*hW7g%8^&-3)}?<#4ZfH%w2zmeOa zcm5fD6kWZJk$$&sj}CTfP{O%}63Gf;!1;P9wosEhS06{78%IpP%G*BMii&Ihp)S11 zzu&vMY6a|UIO4u^hv~mGJShp6L;C*G5Aa`e!Dhg&)EbN$4O`k@hV?LG6}1c8$fRXW z`*qZsy)%_eCVn&R(pT+|5)x_|{9qt~L}pOG%zq9&#PrvrCBHeDC%DLcOn!fQ<#m5S z-MF>K9aUjYVihg45PkQp{yd>W z`iaJ03J};&>smY@H#5}D&&z_~ zMAmUDR{#Je#@}uuf4`Jt!GZ*?Re12jh;Q{W<~xERyxBkqg3K#coU&EK_gQHF8L3ZS zQ^z5iGvTV$_3+;jy$#t)qyjP{{2MdxcJ}>9b_m^F4PD}-2lu;`#_Dm{nZm77h;sa}m<>w^z^2Nb%Smk&lctD(+hbrH_p->c7^$X(ETC3`hScr%w zbQCKqHAstOSq+NRqNWygbVlqn6jgwfzg62=^0n`WS#1hu1?aOq#qT$k@^JGUwXMWD zj{+eggPDLOKu2ds5;`Bca}f+&EwluXJbuViOH+5x-ZnKFi)qrfJ&^v#pKXOs~K4ynoGD^A@FL|N+G+P`Z7ni@&U>`%I z_xme3NUtq$s7z}03#%G{Q0THM*0D=Ag=el^dG2*ODs@;`;9sbEM#OF*+K4}x{j~O9 z-g2SA}d90P8wL;E@l+3Tm?pFaB~`eUEJZZo^2{2l1eBX-1A zk$tEqxk3tszCX316(fB(2o<^|>QJT=`!=+nLR9Z`>qhQY$OYuvJKDVasZNMrKYpAS z-%Gg#2C(vQ1Ik^WG7e6JQqNL~c%g>MAS*+tH8h0F^J?;k|NkDa6201`BvQZQV-4a) z)Kk*4vk4SqsB0?7;3SEOi5Aw@OzkIbRX9!S-geDLwWDI0t$EUn2t7{4`Q~AcI2J@B zTTx$SuI%|2?FCvwc)Lu(gtQNY!E01UyPC`EKkSf8oFKOQ4feDz6YNbbj>CCCj3)9W z|Ic}Ne^B>WlCUV~sjt}@0J($^Ya^4>;$)Ga=57fV6eA3VJ;a9vX5T~&_ap=}ErS>x z&>}%f7dOVA-|p*Ng&SXI<2rcVr1ZV+q?GbJJvqDLF?n(#QY{|>Lk!Dkt06&=)D(Q- zR^9k6?INhWW6+0X9OTHISMam+L#8cmirp|UifA4JCk&bK)kjaR_%doW4lsxR?Pkc* z9jBsDbJe<546F#mZ~|C#b0#0iycRF|k~V9!*RX2}(SsTCDhY;FmE#-?M%*3)SCfFL zh-!HJkjGA{yV|r=A{=+)Htlxw6QPw-+Eay;q+0=J8!L<+wsJ+!39i7@5_W0t=c2QR z%mMz>B8Ge?J5jSgqhb-X>YT}a7A;nBc`jT})_q z2fjvE(JAip1BmLIxv_s=q)K`sDBlE76o0gre&(~-fQQpL2j!${CArJBlWqM{)!@O8 zci4G9u^vaaP{fK8WuT6K5oSmTuYq5K55o!S%I@as ze94h^d*?(~f7!^1SmZEObf$m)?x)u6Qn6{3^O1|_72oZmL8)KHp|_LdC^h!OA=V@Z z_x@<}nQ!~YENef*XSXG9eCS-JbIlMebHzv{>rX+xsIl<4;*T-^L1Zy9USX9uK}7Ni zfVd0a4GLADRR=ruf7a> zkmiKgK)a(*Nho%uaNwD%ghfRd;p_+L?IDzo5yc?!$_{42RJlNl11ZdrfD+IOz5f(- z=xpknyW2Lu@uXiB6(st5G$FY;DdpuQhZNC-3=a{WprwdFN(+(c_n)clr z3DQCG6%WKpv5cM&!&}?zV})IR>xBzBc!Ny1B_8|T+a57H%AC_$juZI8E}9d)UIeY-Wo|+Q>}KSvwOeSCbyfw+?9X%pO@894lQSQ+ zG03rsD|9pdtbpz@t@^HU53}2xWeMAwqxuSaOxjmO;&vo9ysGm?xGy0h;xgnNxq*}k z*L75&{nd2R@K&H}ddnIE#kP8Uat4H*+8(2$a59I8v+-=28h%-y=aeXIWS+5Iyq1GsJGvo`sSy71cXHPX+d zfGwx&Sf}W$3#r|iE7|mi7B(0{e`@8aI(6;a!Qp<*{CUEHbalANvfq(6bd9K^AuE#K#bMOr=!(f_2@mQ$1_g?H=yo zK}A^I(ohcwB1z+*B9^r82TQ^77ek0hDeR38cDJcUpRp~bfVcoLP-&b5!%a1i5>7_& zr%sj)SPY6?Os8EH_)n@?lAf~} zqcG?B;x`>t0-j8=#dAes1Gd*>!iz@LQ7kZDizc?DN=lmV*FkjM7xMgYg1k(mtS-U% zs=n%R4kyzqfCv!V{uc5^G?0;*>n=d-^c8zCP_Mb2zb zaUbu`rCV|@hgFA;C^$o?KDpN=!kb`+;f)N}AGkmM@Q-hF%vl3gJ}=U1Hk)S?rk3Mb|GG^;A3Y&bLDob zXuYAnYe8(?iGtSf|US9mBlg&CG;eNQ|;*dCp2|xk;=L)9@Ct6FbPBP z96^2a++i?Hd3g)oPQ2ak}z%(f%3CFvz>li`2b3r^M?p+$(EEX~Nf7CPDPOV<6tE}i=#D89 zKmYjKohXQ|5XA^0Af{%Uew9T&IFs|C5sfJ(-iU&ZJv99~8>`FvP|~7m-O+cq=gTJM z_scm=6TQD-lrM+CiF9bXBI@LrI`|10nmKf+?VPV*(!D#1jSr&r%&q9#!RUE&Q8W>d zWsE$}>HJ^7SWScg2MbB|i@_4Yuu#all+Ri&0{3bus@4m6bvWN9!yKy<$fk+?+)xD6 zYVyyhN{H5eqqV4=0*FK92JjnRZbu=5U4Jf7V~8wRMfSLvlZhm?Hj{*-0jc>F^OaGE za+}wB3L(6c$T0Iu_cBK{P`SP%hh`o**xpPvj~hXuPk1=#bat z9qJ8=V-bvAvI`#Wcmv4_$Mpa)*4u)53p-Q!`z{b8LRIbF{rl7xNI9z|f zoFgyGbQgZbBjx=iyM+<%6RnL8N zD^!(52p=yT#-|!mBqoH4t}dIng-hoC4x^xNtc-mr}6C9Ex>t-vdR0#xA2= z_-$FNmD+#L#vJ>5yRF!&k0sOn*pNN$H4QDk3l1RAse$NXXqdFbQF(NZ$d$V-tHI`E zG>&XlJk-{6{K;S!BKp5*$#4T!O{&*3B+Gv)W>_ZB&=EHONSSMp-aIdX_3S_OWR!yX zbgxs>eiu{3%fxY+5)fHhJF9S00T^C|yV=i!-Yh{&?DKH$_6IO(Ko1^?OQeuj@uXmz zdd7va#x+YKJo%lD(L}tik@}Xz%OWF5mM9FAA=Bkj8G302&3nJ1*CKPVSW}Wm+0wMi zF?%FvTZ=(kGQ)EXwd$H*X(m10n}3R%+n0*!o3kH|0L3p@Z(|DcJGoMu*M&>=`c#mK z@H$;3g~fL}o|fr$^ye&>uZ&t*DtT4kR2NU^km>H*MC|{FZQCsW8Fa+j$!i zN(ae|+(A1b&$Ym z5NMkkaBqmoT;cSaa2C*wE~`&xlC}VGo(FppRQlO8s^6AU->R;ao{zU`?rr)9#>tcc zMBwv9#_7yTV$mX=`9!G`?}QdJ6Nef1juHQ(fk`cWoi689WdolDU`2t$%Z%i<KX~J6EhCP>4ZE1JN2Z9wDbUT#@=wLi`=svo&Zsa>y}A89^XxzzLT{!VZZA! zbgNm_6+k1rr09bPmgtj1vzXzD_Ah`0S*n#6>0-jN-ky#TRAz?PpQ3*WH<9cXPJ z9COKYWD(nMXmY~-^75^|mB~Eer*@b-A<1lC ztxDW+Y9l)*ExhiYSSHp7eP%2&_kD!u^OUHtUljf5WW+~S%n?ya8XeYHl} zM)8kn>6qlgGzYiwe)yTnPHPlRqOY>uXjIZ4vV`_-YZfZS{nT_m-+ua`-dpEHqcSdV z)hn2$oTk{@swgx~~3fuxfx=1v_qU{(@?Q6#0{bOE?hzZUYefj6AWFa%$l>z*aUVVmg zo_TC*X#arG!MSU8G)hDRC(U%XLLc!Qsm2D|No*SJkah6YyrZH=#C%n3EW2rc!hiwY zEfTnN@}bft1Z{9>Mxi!75=>dEvT;R6p){jm^0z^bY{DQSCj^;>APh!`4`@~>M7C(o zn9Qevz(o=;I5h|bdR+)6d5@`vpG2+sW_S>!&7XH%7DDI0*>g6-N{y95z#SHlu}tLu zdclX%o*l-TKvQG3tR#Ociv%iK zma-=F+&P;J)J9>JO%J?8)*C&i`n3T7;L+~=OFa(u{FOYkD-46#?Pmx6YMGK=z75PK z??E~p5H3xnGfK%Z07$U)ex}PL5S1OiZ$;P$1~`88TkCN)eap)ugV+4(BOqx1+V3=G z3g%e$gykoMaF26&>3UsF*?a_+mvG6D-|h%UKxFHqO~AR~rAWJB$8 zIJh^*a1KHwWCF^yvv5RD8L~MNn-)T13IroeOPIGKj zn(5C9J>WFQZM|z|rK?P7IgI;x34^KWjEe2)cHNs^GLE|B^Bfji)ku|OXIRl-MZ9%W zn#bpNB0V#bDD?ONG?T(Y7#Sq!)Cd4gFK9uortze>o)rNSWV2bXM$2QvPuP#WqvoAH zb=gOw7WLes$Q^G$7kByQc$laWdfIiO;K^8L_a;3QbkaTtfcqZ*|G$9o9ziQBm zn_~hKgIWBm`G*35Y&w%D(TiE&{Z`u79lyh;K?p~&@Ee0Fs!@wDu*G($eQ+teLl`&U ziENWqif|eTpSpYYx4o#I8FBIYr~J)VsfwJ$k0qPq#B%CLe%(%P>181@B8d}kt?eK+ zV=e)O^Zsqnx}{nCNSBl;HY+&%P)U9IT(PQ?MdFvDKc~sfR!XISIMsyS#^XoQMO?7V zVAO-vgGDt$A4{mhJ?}&vDrc`L!wS%r;tHRMoy+^@vAUYiQ>dD@(g0kTbfO;hS_>u2 zkXtoV_QK0xvV^rASvXvQEg~-5ch3_*dHKB|BiL7YpH32iG5aZwyuJlvLH64LkFqPB z;N6h|7VJufXJ7q-DO749w_X%b4DLJunV?rJ{~SIq%zdG=|BNn=eEP$$n#1R}mZmLy z07VR=?)Plrw4XZiCEos%{A$!b-GW35lkEO8B(C~cR2*B_38ay~nX!@up)e=S&g&$2*J|FII%waI5vgsa{dC5)<>6r1olKSVL*{3h{&15G#h!^a6vExL0?Yh$xi__qjp1l*GWid1r~V-T)&+m(vI*!^(0T^zXmulfO`bG$ zVt{2{0$S8B1TVXwuN5j>UlGNQBn5rwts%!$$=F7Iy_-R*=vYEz7rygpVOIsr_ne_M zQPtx`ijD&V;txg02Z}HPBZiFDI!# zWHDm({=AE*GGHzCo%LHGR^vd2eNqvxJMzAs>rQcrKbr(^12nEe8~$!6+jMx&2=NtV z%7K~N8b7enEo3bwS`xbAil*o?a*U}) zZRP&3D-AIbp5F)BxrLAQ>F`Tv_EF*vS|x?ofpP`y>SeD|ghDivYhH`UPe7kxOsu_1 zq#b($j6WITD2Vx}v_J5G@Ns#W7(GEg%y1NB0L4z#qAISPQ; zaUrh!hdjc{&2GO|9~x)b?!9^h8bLn2e~sYa)9fK1pfvC@89UOpZGf;o?%-_32tUL+ z-cE?iftbDmqh!7uWKzmVlb)bMoby*-t!CxVhM#MFG|Um+A~V1yh?P{8w@Zb|9*&A) zBd5*_U-M7LMs86kw&yr!BOaviP@FLIsTn6_6b>Y)NfgA6(pPo}jP{)}riC~e-eX=` zaap`76yxzKN9zW2f?Wk`uyRB&toe2~u)DXBk^Y$KfIc>I_eNh{Y_9Ln-Vf2fY1{>X zKDfi1{4dMS{_-5-n~U-mIRUD1$)CkLkn)8NHRtr+#@8M)Km45i{NtCA{?g%ArmBx= z(&hG6PxwM==;_A*ftudS8lzx!rS6^6a$1KFV)j-RJFSCS;*wpg zYjum-m+!FNh|)8dcs~0tL+AE6TH|{8;Rnw}!R&&a zcb7j$*x<~Crkt?d_=Lo{jN<}#xf+D}p(r4Tu?C`=&vg&a-99Tm3Rc~EY<|}75{j1o zRn$;qW3`0s`>d+_Lee7??eu2W@lpKo^|Jt0ZgTN(?9DQ*_hQF#+h}}rt?qS?U&v|& z-QIJB+orT7&CmR)dKe%n z&w%_a#xrX*DJCJ);G##0Y;mbqzvx615Kmy$xW$oDa^G|BEFnQcl@t+#=h5ZBD30ji zLrh2oucholuqP4-so^(Exco`y%COBxJneH7P!nUF;eYy!ymj3*Q#sv#;?gPRreZEy zhjBf-Uhi>KAR>yGz^u(55P#+0^6V><&^8>ZFH3k-36eLIvXcoYo2LLd`Tla(x!2wo zr27MxvQM-8t&s(2{ZOd1U!V9yNT)@Rnk5h8o^D|t_b?5ir0`QeMLEtHb%!KWi4>>b z<;&lonyvEeIMqSrdH#_9z`u@!AI%o;9Ly}O{wIf3{FB3?8fNwmq<~H^i;TZ*O!Q0l z@1wM{U;17DsjXeUXj4~1{s|rLcQZhp0`6)~`nM3UJh;|I{xJq96b#u%Li~n(vVeK_ zL#iWz*VH&^KEOa@jiL|*cjuK%j;Z!YVK{;mfE+lPL^N(BDXS*1QYyP5lNeiGE}GH_ zH*+10Kl7UuBgty=d+K|;ZSW^?@)VEmwpdbVP~(IwLIqitOLOS4Avd0?83$(UCLY4C zevOuD2_N9|lbY-yIhJevZ?(`KNotv+vhlAMtwj69`(GNg3Yzq5aA*J)LBoKcG(Cnz z01JrrzpAw$L5Cb6rG~QgCq-SQXyBgvW!I&9+m#M$0VF0k{-tM86jP@Zz6+_t7SN znbG94D&W%k0_2joe?n= zd=#)RbHO-NNp*uX+qYf4WPU#ge|s8x2WVb7qkmX5mk-nzVk;F*S+#+(RA4@n+G?P; zPL7XdLC_RS>$@lCzoB{erKiu2-C z!lm3JV8e*8L;@3J$)Ka`sP=!lK|J}ri31=O0&v=;n>{OZZgg-D6@B#a=r=8w9>9NH zf-;8VwO_KUc%&2&A?2+doYTa?J;pTH+kH!b{g`?EsiwGQp1Wpf!l@i`D!ct#%fZjI zvxH&^%Js7X%wE2Mz&~q~rZnqU*LR%FuOK1w5H0*HkSvHBse@bbdZ;^6R4Ue=&Q1gA z2yAd77(g$w(9v|XZ|ND&oZj1ldmL;AajZM1Oy0u6Q&AhKO|l&jYZIOY=E%|}B8(y1 zZ|2jirZbAK4UuOU!|PN}DVXPhaYZ{G165c;>(5hVIoBjkW-{yH41=Aia%<1q`8eTB z-B(x9t%JV1B!$P&qquYAsH{cNURdpZb25+i9?mXfYcRKqRe9khu zK0VeFlhd}3xpoU*Pok3Z{cQx5`&%Ye=iiN;mkBrv9O*MobVQJA0v*inLQvg5Eo4rc z%9lqR#~qc8k**8~1a>k>^pcfCKQ!b(Y(Tc`Eq-NW!MdD=kr3NAA?_x;RZhnyUOSmf zD;hwMO|CS07_+m9wsg=9KQ3J+c+Of&-exw3#Q}ru&I|!K1_QyE2c!BlV)9!P4`9Rn z8ApgWUDg_H{y&O$4(M5~4HeYJ)BxoT`}YYtcD<8M%CWu$?uLHDa~2C%K5K>v4& zNP~OKczgq?i-?@&RG_O#JNfiCO-rbx_99+Y0r^M8`f|V|FdEgITJA@C0#u!Q7KC-G z8R!4zt!7gvlt)-?;oiNeNJMQGkUw50hnLOKea_L>F!lnAO3D60{+mm!X9O3bI)oRJ zS42UGN4Gc=4blIJM-a3!&^r=>%IT1(s-W=BM*~~@t%SS;aoqyLsSP@dtX6iLZUUxY za@40_4iMQ8{^ydkQ&glk7eu!2bC=tww&%LrP@hBBKE6-O!vv2x?cZzLDlA3ytEdnV z#Q-AG1*C0eV7@ET?0hCiHvlGk0C8q(;8o`V8HKu1*>(q!xCc`+ksqvq9aN^kwD_{| zp`e$Jvs%6;^b{Z+5Kkv&{n}X4XHS9)V@K|RfHSX$-+?==EY7M=K2a$&S;}spsp0~5 z2=0fLAGfNYwhy>b1^XYCFP?I@LtL00;wWd2jO%p#pheL(dCl)iMSn<%&;Ky)R0Wa- zr#QGS4daPDu*X?V1|v6bgT_oWM#f$gu8AaaM5=sX?$3W(H_9PB85t!=DOw|<-*G{@ z`SxbdK44j3`T59gE$N>yqrgP-TT9Q4hf^#9-;Ivj_%XA#&bZ61WSVCd$UxSW$gAvC zIS=h+z~896Cj9hxm+WfzQ^KJZuj&9FgTGc~_Cqam-HbWwrw4}Q&+c41SA2dLpMKUB zNP1|sekm~hYGdfPFG}ueVkpDGQM&69}z5nylUL78eHs)vT*`=u6w>$d4>PqU@;#Zxo|jme5E z?@>q|kJ`bB@>o981mSRm0bJgNhnkF0?(G|ul#1)Jp&fUppoVWQQ6{{r5J=`Z&(Q9( zwEf|DYp7g3#nPs{@xOEbTGh|rJ_$JYwlsC+87u$Jeb=&}mEZHn6(9zyWb4ZPjZ(wC z?Rnc}k+R=QsB=*w9-1h3mq9PrCYAV$??n{w0(vSEllVOV#gc(4VIzaeV%LTyV(Wb% z;6SP+69B%&_ZRcwR|#)f5vk_$E402QnW*ah^JYb|V&Mre73#Bo0y9)n*Nz2b z2*2R9IBF}kH#L|QN&Gdd5smGkd*1l5*o2SN|MZ(Vw)oxGO`Y(jQ0aC*hd*;EO`&gr zscQs;*Xx&%;LDLc&3`}M+KBwjM^}uO`dbP_x==ixdQztPObRKnO;62afK2r&Et)0B z1ue>Yhf{?FO&ky}0%AKM7~K67mnj*4jFp!eg)0lefz`!qFe*Z!=kN+db(kYih`t$V zeb*5fQFpryEZZ4A%6VUog^~Ia3FCPzfqETYbXWQZn#GtZq>!k}cH`Makme1AqRYKk zaiX!lgyZNDehV&SQiqlb&*r_f7v#4bJ8XAO|225uroj`YxJ#6i+9}s|)v4bpW^%Vo zhij$bY1y_}!2`5lsie(&1pch6e+SK%&b|j^9L4A?rap9UB0E?GHDYBt)bA?`IyfHcG0637ZgVX4$A&ochTvUuwW{iiu|0a zc8$Lp(&Q}JZsn$|cRibi^K znypaism?)Q;s8`UIi@LJL>}-4N8@l<=xb=R?b8U@Tj+IZoCx_*n03~3=N{JdaQU~| z#Azmw63p9x_5#141gwgmTGGLTR}IoPiEssM1clpzFAfO(jMTOIZgx@r8fg@y+8Fe z_f0zz0gh=i=^`h%Mpe&4SPl0-|GX-*m{l_L%++E)o~0nY=ffjEa~*?)8ijtWF^qD( zc22MPQR|yZ4vtU$NCbxyc*k~ZU5vxwaX8K9%|Mg!@M?OB;z$l@;P~o&_P(r=PX!%( z7;_=2o{I|OXBGQ?XT86xHtVt zM!ZTD5VD3LaDHi?EqWHjh00(MH^)fx1d^lE&WFuL>4zxZYQJzP zhjbpV-?pzFUvW6f{I7{0t~G0pdIjwdHC{RTN@X5sRA>oiJG}{R3hfIx*B>Xg!hx-e z?KHB!s#4o08y0|ADnX%dCWMrbad`PpU$j>UEB#(3r#^V$!!(wy7de3e)&e}^!6;DG zM)GDF?-5=sij<7d9&Eu<(yP=i!>MlU$_88)C27{-Q1L!slW2p@Qgy zJ(+Q@iQvAqW=pNn!>hV2Q}G$|0uE#ZQ0Q}OMkthw&=|=5ojQI>aPU!E93pGg|DlBT zeTT(c-}YB8#G8eHc#N$d!@i;tj$xf>{Sf`_C{z-&2PxB)iz!;aUw`ub@@Q9bGW@+GsN_rzs>lcCRi3|~yCq21{ZicF@-fPNrle@90L5V6M~w(F18? zQ%0A}oV|(`%y~7#O_9HFl%jrv-2Yw#4A22JDgJ-9g5N>ehsVHbgNZ(`_@i%Scgo)t zp75%N?CH89F35U0j z|7zch-Yn}eW>CUJv${34%X$7DHUAa@ag>|Ri6_J_*tb(eP7$C~0LNbv8rxu;5yg=? zzK;i2*mvI^{k}wxi4gJF*bushjqy#wfj~COVWyc9RP-x6rLL6^w;LBWbD8|+~|=Pma2XB2FL9Z1?<2DhcND%Nz&b0ukGZi+Np9z{9Voa zi<@Q0aL-AhZKVE|e3fS`T(5`ioV{yt8CR-dG*^Z8Qb^lR?r&rGGBM0R+CU;gK;rI- zpEbb5Qij^p03TXnq7hcLNLA@SP~P{6$Nvq*YKLvd!ICFH??@*#kZ}Kx`gZg>Jowa4 zKh!M$^i}*BKHQ;j z6gXdWun+zY{6o-e(Y#SAZR-@hs7hS=I^ek_WxT9F8V+Ag1x=&@oOHz_u>}BwrtG-y zYkvZyOr#X2I`exNG>9{QG!C8ctXa$RD7b!j#*I`o5K9nUso_0YYe1oO;^zUH(<$hFrBl)1_W&=P+-vq zWq~IAn(4|WO$CObq{he#*B{QQ{gMpF7hd#;ZES-!jY(xy&sWL?O|E#A0fMycOpJ(k zPih0Wr~i$Wn2r9_b-V|*5{|#x?3qT8DiOm|@UFu+6uf(1{`ffO{ci`Y_zpXl6X9Xx z5VKo&lyYHpqvoeJ#l(lqaOn31z8__Kp{ERPd8&&2oDKAfuoFFtEGY~ZP}NY_pjX?_ zZM1H+W(-xp-zWkG-@NRse93mEPD}v~`ERK^IhYkh#*9}+;Q=3%r}lMHlR0p~&(pLK z^#Xe=^qqG8RNY*;&ctQiyX&nAun71wrq9~@iyfqan+4LsFKEY*yQv9gzXchYxg2ey z2|ldvQcrkHeS=J_Qf>t{Nf>Y$sKIESkx+nhrG_zv(OC(jUr!AU0*@J*2!dQ=88J;y zw4^vfpf?n5I+x_K{eG6$oxUYTlHdl3+qwKQEI9uKVF~_jP;1#7M`$JIHA7Epxx}FV z(VsEasNVV5y?XoT{Y5(BWg;{a2wgUwZcIL0h#kmOX%n5GZ8ma7i&U;JW_4w6 zwiZ9|+KL1CKd~}m)LP5WBO(mczj_D0{8~dI&p$Rd^n?37x~Vfq##GFFtHX2YuoN3+ z0OKLstz1EuqfH1=w>(yqnR1AOa|08oHBDV0`yW#YEh=KM`SO1%>n*DNVHwSKl zACOnY*$e+}f|s>egxJETQ-|OY;2x-kAZ@+oE;&B69~6_KXEe#h!38280rzRiLR(+L znD$mi&vua=;P84&nh({J<}oL5#_eAaQ>Q zof9r6!Tqn8uR{WS)~0zY>DK5(YbEb{9y`Q*k1cY^(RA>$Xpl}}XyvYWuk#3a(T9IK zEt^prXfc8~yGapIKf?CilL?H)5!8IEG z5f0DL**C^!{2H;DP_UDso#=q$v`O0f{0X1ZPsKR?t zz2*?0+Wg~rSEXNyq^ngeH#oOD!Sv4OnLUz<=+UJ()Gsnst`_xLyz+)T7B$R%wvGRQ zyD+%)hxh%gJ(=M8qJj|MbKr&F1-gt;YQB$v-Xh2oI4N2S2gjN7YaR@a!ZGM4YrlQV zhQp7VIu(Ca{1Zjd*x}gfLY~>J@8~=vKeN#Lw*LyBea9hTEvY-}HCJ!?*kTsFqAt>0 zX^p>)HVIl&PUl;uVnsw;0~K78qHv`mK&Ish>-k{V5sh;GY$$`ojnZZ1F-VzzC*gq| zsgms|dIS%QY54A`e8sQAj2GDHd+l42IsI8EnH`v{F7kVgRsEZ+?$S+vB}gTNb1X30 znJ)i3J-(m6JgjY&dW(nl0~6*{P73cI^=%6fEqC+EcNEx4Pr9^#B-Q}2-t>I*w44NV zE3MOJ4oZ&nfaBB{)&e^>d!QOsv(Ph|JJk)+ni-V)l#}!-=;{Blbyi_fJz%$=8ERmN zLAr+qX^^g=yOHiN=~ys`#;}#&ds^;3NQ58?ESv$ zUF)}k&=6}HNf?U`Mo4|iXJxcrj?NIs7zPMU;+juBGU73k_>M1gm3+C6dba^Z>aZf7 zL?;!;Y#{AJ*Cw7#|r~&y#|Ee;$iuYmmIvz>5p0j zJz&sE<1s4waF0f-YO`vj8FrH@G}hLHX6EhFTG#M4n!0yepse-=Cl~UFH!v(?-hpVo z9MdtJV|5b4v9L-o$27u^RB5`Nmo0$HtI z)OOa7zO*3W4!m#2*Sg=aOn!AMsdxr`LOEPWmR7{(WG&TT%4AGV)=nfC?a z&94Xp|NMNOa$LX`{Pb(9@J{b8{kXlP*pf6AdQXv0VPACFCqQMt$mE33BPdsa1t`2X zaEvs5rk=-f>B1e=T8;j=Tr5QFSz&nU@`*b4eGNU7vug3F>B34NSd9!TP zg{MD-Up8BuNnNPt?ya5qhn5qx=FYk9);Edaj1I)_$2 zyukwNKvkPR;gTgE$cbq}W8YysRm^4IGR$-`Iublj%fT}UD0=<%4;cY>)H^Ry30w+DEQd0t4 z?R-C+=W6OP-A$KUrudtykw@2oD|WI-x8QJB8D%6AfxcXeNCjxjLC@$PmqR>!H1;_;*p-o2lGo2eWbiF7oxMc9Es_2Y~Sgu*5AXCzFuLCYHxn7 zlxYqk7C~ZaU-Jn%;~>9>k#@Y|!_t;I2sTSvT|;rHHfe3`ylpTiADCF{B0&ShDI|`7 zo*5ssp`H{BZnTnl{0UfQ7Gm^lm@3($r1a_@Zcj z!BYC7|4X2}kQYZET1YK7NXSN}NA_(qS0iawCpi=P%i*g=_z(&r_cooVZsv=A()X69 zH3*P@O>%LaBJX!SJi~auWLrX_b0U(DXH%#BB!v<0jW|supu4ZXkL3(?3szk(y7k82 zYb=ZWd;DEF{=4r{-}aBXT-mx3T_eW^2W2j*@eA~`8_d1fK_1G4(D!d9{bnOqrDAcn zM7ZDtlESg*T6i*2!{-CooFqqYwHuY}E3gO&VWdIv6M8ktKAt}&99UYnpX`g=RERj; z;N5bcQC57cN#dg`nOLJqlkHYnE3UFNm zcIH4M>1sRlobp^!cKSOxz@ut*D=7_kc_aQWrDXQ`-zinC{^Ys;5@5R;iW<^vw&FP2 zW|bzDy1uJ&u?WQg*J!*H2CT6pvDl*0=siNs22VfLsP@)W^9nK?rIC#zp1+zLq(XF% zyM}HTV&-w}eftF1s zu^tG8k%oxJbEwCvg2XuWM@`Qq;UT;JfUFBJbl! zS5Q@)=;$nE+^qpFi~+Xt&k_uhzX$wG&n9VDD~XU!S;tk5{jfjb;gFd?ZrAsMbs9?* z8=GFX;KLLxiHy7IXfU2y>?hq`Gd|KGL4U8K-#;@KP8=WPX5c)yXyPowKC=I03AGl; zMij=7*YRlk(d5b*%{KtAQtbsGJm+fedFf%zi4%o%0XQPR>SH8b>emAhbi{rqyV2N& z4sowj(oZRcKBYzUW1f5R0$}^s>0fef>ihY^Z}SQU*$^ACeI}P-%7#T0Y~!0zHM<<& z@47w~)8Ou5WyDPgINgd)WbqM-FD1V-Y#invft zrfOn&0$IDf@UhS{)A92%hwR-+L10arwRg3ZPCTc-<+~X^r(OHwb!gvA4#;UU~ap{3H0&f8yEn_uP$O0mv#R8MpX;ctr=GIyceH##cHPjYSdt5C-d{ zy>hI0M^M*){&goZ;gp1@nq@eg3iH`#6h{Ze45123LJtO7xh+cL)DS3qz-=TU5OUK+N&N%d*yX zwYxQ(bwQia>A-V}A8%0zO!Swc;`Y7Ej_6>Wz}Q2Dz?#@I9C`m+>^l*?=!Dm*mmOAq zT?F?rM@6|;S+8^6jtm(rzpdZOa#E2cWT zQR@KTah^BLHooS#kq!BFA(1uljYzGTeYV4A4cP0$HFJdaqGryO>qX632kbaVST7m& zpQ%Kj10B9?p7e+@Z-pJfGpdP16z}gC&8bIn7FHw!QV~CZveU$@LW0>alojtaU$ZV* z0akEA>&Q9nEg7)zW=h3fUej*fxzBq8DwqeC_5S$t*gfg09VN1f;)eUeBwIQ;R@GO6 z{lc2aGv!bji98rR<9pP0nu~yW={%6&zP2*v@#28SuFyO$61LySSj+aX#*P)& z5Ee)^wH>(HW}@>pYYP?E2lv_>EJu~?}F7og(Gog zNQZmbewqYST&Z6 zay>qf6YLG?;f@HwNAKCW5!x5~^Nq_mc>@hS5=Oz})_MdW)h#lm>LjVnnP*}4t!zU$ z5a9G`%wyFUlid5tj{xdImJ^Wj74Q!tkuA34R(y2wNW``Tj=`gJ7rj||9tI<%5#pJ1 zhNcTs(`%|gYo3o|#hZ3#Ie_JWr%9@4KRBeEd+$_8Fn8SB_QXGH+QRJcxb?_(HvZ-B zt(!OJMZ8Ir=E_{X(^Id5l(f~^3grz~L&m2H+$&RSb)?lVVXNbW>Re;}3z+@DSX1J*pNJvu-N_qz&_>?C$=~;9 z;m2_vY9$T4qeE zFGMv)a1i?kOemVZF($ETHF56fQplzQZp@ zVZ4lcE>&Tn83uBowl|QBZpMbyU*fUur@?k;gs*1K2dCIQ(I?X{2!APLaVq!$c3#$FJ1B8g@@s(8zXAq_Mkh zXnL0ipkCN6x<@j!Vx<2bgksMF)`dbW)^5}LmzE7{D35MKYIQ!c{QcmU9T-M&VQHi1 zAt(}5yik;FZg*z0HE%k|g%Ls7{t|3q4$*Fz=1bA52b%#IAV`C;#5RA37=FOL=4In(=BOuR$CNi6f8V7M z!NB^{OSkK~)f~vVSs7=oj(=8;Dd407C?bH7)LmBIPhHB5){Es=woO=U=2N)6!tF~R zq9BjPHU&zC>j&Zw27(i^RK-QblI?zq;0$}PwX?IYJt2NFd?4KG{TqV0oY>>#o7c7j z&?Eu?z5+3}MvwZM!c2IWjv=>;4GHJMp8&5yPChK%r{5JceyqhR7#a&Bp7pIfEp4sl zY0g>8Y+NjUCVeWjW)YNYn{cbsq7_2}R39~PJF<~gv_4TS9E(XTq}h}5%G*<5XJn+A zXwq&{GB@8ASCuQFXslF<#fWUSK&k>@>EN5k$uu~@(IoN@@)M(1V?t2qgwx4)`#xPY z+mV<>F)B#GjW)YJ0qZ&%y6=9&Dn89~NjYUxB<`ODnv3?Ez%IP12W+a*f6OM`@_#}{ z+*9e8K;Y&q%4y-Vi(wq5rhuF}TcfC(_HYhj29HB06as2{^{~NV?4KIVx4#t=!CkJG zUC4{p%`}kJb|`uf>9%^g0+Bv2y2RzI7Ddo`(yls#$Yl3H@#+NU)uFRySM?d8f)e&Y znwIwi)O9R1rf+_^l6iS`vGIfEZVa3#%LVQE2W-<>-;R-UNA7Tx98>xoz6f zQ{9%>IqNVvPG&J>1ybV7;gEACS88FFqP0I1zaQmXw=Yrx*4k4cAKibi=lE(!ztltH zQRBmvob~l^!eZ^W>xY8w=MC02h(F|PUSEU8{L_zDarU1wJo;%m_Z`rmN&TW&*_cao zeTzq>3sbZ8Jld602o~0yu%SNX1!hkiZt}BxA_i2%6F!Y3Ev^$F>h%0;4A+qp+^^(g z)P1FkCIKZ z#2ux}U(d+vy??&%&)ieG5yx&y5G*JGVQ3-bq&;xz*rvvSGNNS=GMC~RJ92Fb*g;}p~)k7 z1j=nvFE%!v1pPVM`)U*_(hanUjF$qD_TIK=;3XI2+F}`;`Gb?EyDKYtk3MhF;#bXG zzSurwkI|A4Aq-h*6^?3rWQEPEAOq6HZ)wSuLOtRb<=g5(4BckRS>sms^3|j9Z1y?} zdJh&Kmhom?ckOzZsWLSpa?;c3KnJq1v@b z>@EqG46-wC3LLVw+tQCfEw81X@9?SF-x*wSg@(VQXU2gHlPi(^G5+<0W(aQb zNgQkXut45-^Jg71>*;LYwbTqF(JcMP^+;f@I3S(2RdM)nM!x^OH9p_Rrcpd!@J9I)p zNg>8kX7wLO`7O={F>JVM%E=6Cp4XF=)1G9UV=g?5Db%d92FQlg@{tleNd!RW?++~f zohO`suqhJO#6A8x%aXED!80V$>r{Owuv50(!1fKtj0ZuR9Er-Da#@?TSDODE;u^mx zCAOGw*CA@x))C$b#o{+pCXo>uWBJLC-rQ)w_VrDp`HO`G?xL@5{dqj@BV(ds7kp=k zyUmdN5M^j0l_|%!^f6#Cw$_@I=b#SEro2H+$Ate=Xj4P z`XySJVNEE8>Wn#t{1M!PuC5KcP8@+;%iq9V;&OSpa`A_I0t zeoPOoRR4=rM?NKrX8O_7ZhBUrnz*TX%gYc9S7C+Nl*MA75zPc_muhm8*I>x+O93#d zr;}9M1ZU8jo(AP)MiJjhkw)-!WJGXd8J#ibIr>#o=X`Nx`MX+i8BOIo99cz8 zg;y22G_o>~YGcVZ!X{l$G)98PeP~N=hH`*36YV@VyGvK{Acc& z&nu~JMEn~q&t9{j@C;V%UYD(uUN^OW|PL zS(q8G%^SuHpU>%Pn3ZHbJPe71{Xo<9IhqKz6ua1$$G4RV7`YMxdBo$C88^JYnaOY% zVMWoKBumgP@JT*{N9S)g1hZQ&lCz`eE9quh(OLt3I=;~@{dQY%TgvBfH{n$oTfj2t z*m!S~Yg1KNTwRsFX2ajLRWKSPF%yx{AGIftbiG*B9WAw9T9>yq77P^(h?mMq`?gc4 zu|ggkt66yQHH9EZnk_+ZyQEZW7>wT1Nb3Nli9L`TC7cYKQ})}Dm~b+g(|~9@;75wf zq)URBUy9x;ZZNRPwsItS{wR<#k4J4!@})j~WGsIW!!R7P?$N-Z)XtDF!0)16B$DoL z{``A>fNnrB;b}e%n6FMB&r*WLQBwHBxMd_@9}6MFo^3m%#c(n|(c*xkqsE0LXa%U^{+D<}_7F%4=awmsQm#$CIj#igBN6)cU(qU$1!wS~uCK}g{D%l*Cs^u>CeqIb_w@Djp+=B3AQ6d( zu5{ocH#&x8DC*G!62X@E9bTn1L56~7mxkT*7CwP~_kyflgkfX3x3^SX8EtLM>L0+= z;A)U&Yy_vYPjauH**2D&==RB1H10XG_(u^ufH^W>`V)fYg{*usVfO$uxClLnLbFa- zY9y23Bv{Z8Z^ReUby5tAJQjlPL?y#ttZ8{dx-ieQ-g5jl#lik+Yl2D?)I=9f;AV_^ z>yx{{8;eZ{L+?p$bYyC6eJA~xHzt>>M~<^-W1-->=8PS)+ov~ToAhml@BmIHq~OF& zC7oYlb~+pb5dHRsD78Ke{kchi#&+N2u<_bAA78~`JyTr)r;Q8@JT;5~Tt>c}Eb^Cm zv_}7a0g6aUIk`=SV?wXAst5!w7wbMX%ZkBUt`r+By_gU`aYzpH_xSBw$xkKS&z(;U zIqG`3FPf&qe*lGNGxbywRsmdjeyZZWH++-vMmqzi`9Uj&?$FxDAc>s~HCxQqcj+abWkxAg-Smp)k z8i3VPd~2LjMQ|eMlixY8(q=k-{(f)WZm&?9m0+><6B~1mSToN1LkiJZZ`A3kP`_EO zWU6hX?aj=iHV_`Ojd%G20Hb~0U-x{t-+yNDS;A_-C>L!92em+YGJ&W@4X3skNtqNQ z*_V&9+(;3t))-B>%W);25-H;=nbJbKapuvUk0Pj|A)Rs>mcuZ5W$8kn{Mk15? z0jRs}U6(x>drUwoxAA73oJK!Tp49()PhJ8i2d(XqyRvh`MTvC_~(pE@CMFh*yr_0+DhC>Z26kc8K5M*0}5p1gGGSTIj=5 zL8;C;C{U^>6b-!xG$KRR@jVG_L_7zz4KKe*yH(?99NDa5p2y6ea}%F(hUJqVRA z#?a924ML_z*e23gjgygKneN|!EKr=eYe-yBUrm0{w3%f=F#uZV+jNFI4>{#ex)*lT zf_c9JS3u>p_lt_f@$nkKkbQ*j5zi1`GU)-x&zLVe^=CPyt zdhe6aDZ|{V=Lktvui>~Oq z0}7S=p-y}Hk6)K`t@Pyt4OqawpAVQM{g>s>H`CipQj_xe>**PBJ@$_ZvpH44IRz00SPz5Jz58@8k7|w zK~9Dov7n+rL25)1XihrYfbNbmXaao!nV*R-7mV~0CbHAF?3&WyEO}2du57*%SQ)D8 zUU?nrAAD7U*20YD1%3wy8Qf$ZM`y_rp+q}77RXblwAO83q)J1%X$G6)=rLUOL z^)(_nLDiU`*b1It9Z6{#S%Gu_n&u`g2GXo6lglFIUax3KTs zrcu`x=w#)yeyS`?HM%F7+?Rl~;6^WEqLpv|^GpID#HW76IqGE>%^GbU$bZktq-y|A zScwi3S2O{LyN|zoUf^?Jpqj7&2&MWn3HSz&Bs_w=Um0lrR&i_Zz*`$*(e=Ek_X#=` z=$htvIB=U*v)nQNM-0O-4U-GKzIQtsT_S->G#aZk@7IejEpYp={Vra(c2aW`Rw)Ya z4c6Y(U}T^0{8?4xDS_c7VfJ>q`wcxrzA{yyC8_e#Vb;3J3Hy_939)`H<$g*6H(NE( zf(Gc!-xV#6A~JrrSt=O6cUsbT)(cbvur=Hw3L7-?s=hmg-ni9ZpvyNA!8a}F1tVxV z)Rtr(=b<^9ZBbxXRrD?Jduedeab84h=AW8`OUYKLh?weeorn(3kxFK1PzX? zjBtxlFL{jtuwZABZu4($2LkBV$UMcTm~Tm74CC>yFywthG;?c9!^?X4?np?2iA+cX zslK9?n^B2`6!`h((w}bDNxOC+f}NiYhB2vb3g+J*M(%I!<0ki@?13#6DffODRA2%Pop6Z-<8)Ahhlt-1?5?@TQ8T(Ajt<28Mrs_f>&spW)1i|BW5=wFdL+?tl3HU zwSPQA&Ka@%le57l`gD;BiaR42T=r}g%gWM!1Fao$)J?az>SgDa&kJJ^AI>_-sPX-~ zzfSXZt))Y)kV&IHP|SzO1oHC~28Pt3)|BC)zoZZb8twAQzYU@N0y#95f%^o}AHIc@ zA>ATYUa{Ks1E`uc9Q}+8&YAAu!>d9RKm#mNitL=Zmu=8>kBFw?ixB&eJ3m9ybWv+q zpJA4(>yhY)kN$1Ku^0ommh#!p{G257xZ~*0u)pK2$|Ml@9`xoOJIIWiR6gw$#tAsU zGheFww|FC|0o=ZIh~*h!C#!k79dk`NdOSCiDbsl|>}65L)-C59OYqQZ{c$N5`6e&2 z@E;6iL`;!%h2Ib)h?$WQ*~i#jkXPPR>4ZH=ZWCZfX&V}M?%-F|hU3Vk{8<}wrS?6+ zGulAXs1Tg~1icrYTy&Rv6FwjMYYtye?W7sb;#hw$XU!{X(b#W#O=-TppXs+u-6jpk zlc4L{j@T(4bvz%WDWyOeo9SA}0sN)11nn_qoSx9J!jA+WVvtB&7qxQw7ero-zVg;8 zSZVby_2vD^xFZ$PzT-e_jKJycB-F%Nl`jk*2aFUL!GD`{f8{1Osbm53vBozA`ah!| ze)UDF%_#i`Ty)WB;+S=jylvn7@SBs05c|I{5Fj7}VIXX>2PYTc#Q40<;g{$?s{e`f z68S;EysNOQZxo1GbOD?R&Ejpk$UFMjG7n^d@+E%H5`EVFTXOytnqufd3tx`kUhP43 zl#o^{Q3E13;C@B*qX0JUln?; zrv;^Ju=Vx366t_~^=ooc&-jpWQXQFedU_yp{DGGJU^75lp$MK@NSZs$7vCMHK{y|D_{0%+PNfU-08aH`@riWj^mv z9T3^Y*(UYcbl@MDF#`mJ@SxCPljie;)dqrYF;^slARbBakU)!N#TsB8_lO_Rpf9#Dp2; zGRz5cPSckN3U8}Mlw(IY&4he^lavz2hv6#c26wj;)L{{ni%aJ$^U>veaZX+ebc9i) z0mWuwcGGsReVvIKt4A-tPvO4a>WAf5x7%mIM>9!}d71BsU_b||t>$AlyYY*@88%Vx z9|I&~%vmrzu}br(yC;Oc6g@bQPA=e$0fT1r?&NE~f0^K@7A}DjK_alyr+hMGy)qm1 zmA%~(e%LgJ8+Ejf*-}`82aJq^)$qbU>{*k@SaS%_#4O6;31Xa}T9t9eE@UuY?lDce zCU{cs32%;RX~^*Cp04}_NOBG)*`ms81S*s?_s(qiRQ9iRNEiES$Z{ChEX*>(q|N$u z0<-J40>RjCg0R$=@=wR~wuSKkwf2C_Xot@Q%tntOo`wF#d?N({)Bk?v+#5wcSLr%z zz$gDm^k7@>8*tX0A%evM4&$|F`G(^q#yH}{RuuGtPtei%UhUSplu~m?@9I>v=U4;X z?O;LztYRYmF7CYl<(1LF{+n^4N~&Lt0>e(WQMuEuf@RnfunxRnePl)hhb{W8MG^{_ zgjHPgH3zVKjoW33a_6mOQke$Gbo$C5xIJM;rH3!HkkoW)pzK8i$vOW*xDWOzs84m0 zs{<2-mMRJc?|@KhW`gMPLK;XBglF<*Uc698f)Hq^GYjxr{LL&HXf`V9R&(`d`g`^5 zH{AsRaW@wJKF;CwM9$G5kL#=mqyzFvZ=xT6AZP)c9>I&E+`5uLQg{X}%%k`>#jraY zBA=Ja!Q`o;h z%Wnx`AUC87`8-+rRh=P}j}{bx4sZaUVY|x`4w0nRrZ}ad*amUdWQ=7civEyi6WYk) zsvqELUH4kTG9(|OvPA{6vnx8@ak#DP;{gQXyu69LBKv!sDkD8yD4gnQYTg;$#vgB{_wnh=OP+#=KLqxCn#c_)i1zm`0!QgSEFgI62>r1~x37-G&DuaS&Q1ql z70FACLmOii9RnM;shh^eyl!B-MgX-(YY3DNIQ>&B+%G!*&!6{nw+#SlPC8VbaK!Z& z-&7WAxgu>*b;B21;xUW7@jGg|g}4D26xu=8mCXXmRiHLJ8njn8vuD;JT~<3O73M?$ zJ_0x_o|Om^eo?ENC8G>lgaaA!48l!m4r+{LM$$-A*>V5FQ6 zRx8rY=I!@`xHreO@}(o6n}7_5RfPiOEx8N@?h1y5?J<9np{)%sDCA6WWywK^}S(Z%%y~<@zWcBUMGPlQ6rqR zX9m!naZ@JGL2aVEXU9xWVnhgARKk<3-j}X$qn|T`i-bOFQ*wdFfn`j7m@#r;xV4#v zwJhL3V&2~`xL1Jo=HAM*x-$$Ia`O8bm}myGGXMpT?7EQo>@UeAwzEOw1`?76mgUjA zpF9wf&w*%@pfOVapkKB)f^)~2Ig2lGr_HXZYZcsoDvDv+DrA5AP&toDNOBD&Dqu?P%)uvcok4@8W zuxxo*e)^K#d>z(k@+>v%r3I-UdVYd}s9UNJQJ2+5@VJ7)$N0_GdNr|qmOCS9zqxHY zT5d|BUDrRKZs;0B%FZS{fE@yjKAJExe5eb%(G*FlEQu_#{T-8r2E-bHC)|xgzB&ix z#+!lLiYyT_#CzRbhRGehEn>$E?reKSYx7_Wq3iLcx`_*M703`eW!F?3({jJt>>sx7 zFIXA{66p3^D-_y(io}W62ijwXi=%u|=jvp15l+j$@*hgsD-X9Pr@bG8GipMrL8uNdQwt4W4pFw)*i<(u(&cz%xeVCg-F zWOO;0LTm~esH?l=zHu%wXS9np?sLUmh3Jb}soF()DN1^D4sas=5I%d5i2hH(=q4!n ztuzFU^(lLME@@z3C4i1D!^W4v3}@DDaNz}R0z&G@zYl*}u`}%e3+b)*Q^UjGlaqDl zoD;F*iA<63+h=RlJ*G~v-J3tJ=sWv*{1eEWJl&m~VjP`VH(8GB*UX4mJC8t}WK{ER z^j3A~31Ageq>tjnAAwf&L-XI$PrV=7FQ-h;&v!l_>}C;u4m*{)Ijh*3%JTf&erfq| z%55aeq|vdC_B_HV*2s@Adja(fq%5{_8p&rt?;{A_PqrqUOw3P{nGmp{jerV zupOe`9eCKn|M!5^1sXr(fnVy;4zj>+E2Z%p$F#bBM;^(y*(*by2SaTJqsSb$#Sw)U zYiFP8i{HSy+*;^87=Z2sRC*Qo%M+j;71Y@gFc2#-{^y#_T*v?(!Z}0e!rm$2LQXw& z-4uAnqU=A_>}3h(q^VBZnYhzR_^9$ysIaD|Wtwy+6eSLYn(>+W%|BrR#d;<++qhvV z)w@(XDy#vb1aT^Nl(3GG?0R+s5-D50DCx>l`?~NgsZ7^*Rs$|iMS<`xgB=cSQD$Ct z--HT&Th8>yE7|;YAJWD9RO5dV?>HXaw>_5>Ubu?R#a|l@2^rA!x-lQz-CqdvaQ_9- zSiC%tZbO`ejRxFYY5!yn572u9Gt8)cx{m+C{IRXVoyeQ%#_^y5np92tIt7z6r4RG; zv$#O%mruWL`ZUdsP)!H|Q=ZdElJoJ=y8j!z21CTFku)Eo-e9A8x~f6KlY?1qegq!s zr4d7eHCtmXztH~J@$bhXzHg)ge$JJDemvc`@X~(VNM`~2?w6dzvWvBR*ae!Gj!qMd zrkDIUVc0vo{+&Q@vHOgU8wY3pa$532XIo0&GN62*?(>+?bFm-XRFZvdB@s?l zMWZW001@zG~o^W0z^ zQ69sVz(+H$Q2o;{dkNp(BPY;%2HrS!k?`EMxebl}4cew%UdmDN!OHn54$MYcf9{qp zN}i+le04SU>G*bpzi=A!K)-O?ADRa;d9}nq*ZKn@J3i!zHMWSIIRxX~Ok#2^;(m|% z>G&WU#m7)ux%#s@nL@bYQ};r$nvx2jU!1ArCqFJ1~*cIWbO z9$t$b5}>mKisR4HUW;-xd(>kZr;gV|pk?wxA5ihL4+(kZ_`$6kZu4nd-OZNQY1O_~ zf`88=Zo3bI4w{aR_>6vMQy(#JX%J5H9QfX1ZQgr&*y2&Y$F_^ zq3=E8>!-SEv)))2bkL=9rBF)6s&*aM%L1xam!~s>2OWLl+7`yxo=S8$x&CW=-CZ?G z!rVV?B8>=oLi|3jO6AcBrEq&ioAH>${gAU-JD}RwECB%!B)JANVi#&0Q49%9*g6`i z@5|0rYXegJWP+Mx!{@5pByhr-Wl|_WgthoPWU>cQIe1y0 zVsK!3WjM=xMi36K$~b#vU}g2BC5w6y+hK^|^LN=QlVf*Y^ocq&;g(_5o!FY~QH9gW zwNK-ZwLgA2nS>1u*>2J-)OdCNu4+1F<~(~(LMqw)W~xnbL?bFN5x=0pLiKXo^8UkW z;966H?anQ^pd))5CM5jprq^5(EV>N>m>lOc7-^oAT)SXsD~`$_t;V;|#n5~5&dBUA zZU_OUe?^e__thsNJ{E?vfa^A;?c~XL07;>%md*j{`!ZAC%M4D0A~96-D=HQ{_PYw< ztl&vLlR3(pnkjPV^}|U?_rB{b-r+NE;PF5D@A1E4{JDq$=tma^hF@zimn#;uBVtK^TFV1H`7Hz>kobLP+y*)@o#z7uzKL}w|V}J_rv%U@x=Ne6x`sisL6|6}1$dX7La zwFs3*!QWDblp*uNk8A^(y^((9$A^wUGpP%GKQY1TI4l;M&+jPCZJwu<%j~*U`seQ1 zFFzkFp@lG_+#A1!QyNzS7;BQN_$C@__r~h8UW{GCg&2wUb)h#EKtBunMs868sa=az zvnn&ubwJKd*zirkN!F%4VqG5B_$JBdnxUeN_BiTo}8O%B3`Q?L~TdPtl&56v=zVQ(* zdkyy*1s@Bu%?-{61@Vem;BUM)mB>4MA2e2TqbVx|zwBIqjhnJVAbXgc^wl?xF(~pm z0oGt6+-nIPg?B`7^gL?5J;_FEWnVlFJ-sBOkF9ED`2=O+>Xr;-2?ppr!YmnGqHiY3 z2*dulynC{|rf7i3OJV@EsiivL@(xa_CDHmr(?NEX^6*z_mu#G)ZYCHl%56hgi<}b( zb>%^0_dbs3>}qf#fXhdzciqY0huNy9>Q-Cm_^Di`t|^#Lhb#^DNfx=;GiN>Ot-&&F zXu>bevv{CIul$mAAn1`z%3?pXA6cnZ|6EidUj0^`HfyG#@> zh_&wNA~AjW<&$UcyE&O0JYB$V%5kmzmejW?Q0S_W%PDrfh}s4ix_oXqmDx$tq2#+Q z4Wcl1LK~?w2PfnVeybwl+Tnh>jFg~&@9mdZMX7R>)znDQS%(dq){*%O?zF6WZBav< zo~Reh>mF4S14&^OQD^T7a$Ns5f0HVxvH9CLN}yzM|J{u#9(%lj-MDOn#&4bHUV{4W z^fG4h;Z}fa)=8*FN#+#+{$-}-PbV5ryO!;_kNm?S(`FulfwGt9^%@0(kDzh1#?ntx zn0K-oX*Yi8d`K5~>Xi5Ml+wl@rzk?Xm7>{h*{9}lQ?1kOv)Y`?S;sh?BBppDZJiIq z7nAa;-{+nfdQ#Xiu>cXhtrmbWEE&a=#W_YYLHo{B59g&(r=GcG&4adCOpT>~ZflY|D`n=INMCDDK%%!rkeQrN`rqbl zXYL?^)h^G|E)n3Ry~o2|Bt}n|bgG`YW-Rgx|5U?tkD>J$6MBHgWHR3NXA6uBrXdV5 zoV#evZV2S=Y4THa+n$i>FYn`=<=vi50&V4QGZpl1o{c=dM711lD~m6 z$9>JqZ^)AtM|eLQ&o`F}U-~p`zWex7NTW4yt1^Q_00~lqCW0T8aS1kBEKZc6u%go5 zL?+e)vp&c{s^s)Ew+Rkwp=Z!0v*AO0AqGV&pyiE#-pxRUK@RwGy#~LB%pg4WmpVoJ&>d*@~IrcD6HEH@R#MLMFb5-#y@lala z-f|zA+&3x1_Yj4!!MWa(km}vmf6XB~8q!o3Qg_C+vM=v#xlD2d+tCg? z=H4!T4%wg^?mZYLlk%1E1vxfwq@CLYRbMDA>~hAPgLcN4eI?HnIOs#PT`ZFzzA>4+86O(tgd-yIUE}yfmU*E6>JVw;yeXmpwt>p)&60Xr5?x*FPYn*m9#xp z|Ane}bU3T{bw`5_EG=4G$^R*shl~XCTY3OdxRpi?)tC}g?2|~!mlPxX3FA^y>>FGg zFn`#&(H|M9e~`q3a;5gPqQfyqp(^LMw#wgBK&q-FKk~T$JjJbMtr4Dh?m8`7Mq#^n zJKVpQX)C#Av0djHuyme%DQJJX?_~0)wsNmwcCln%?8krSD)T~K#Svd{ZRc#>7F&YO z9<^9K(cZJ)tgVrz^xWaa?O%u?5C#Z`wAUj$|3{hp?#%M8=djObGiS3$ixe2zn&Da| zAv_qWEY43l-})-k$;6(%vloCA6j-H{UG9dd#tKK`v(v5lZ*wbhlT(1LzcJlf#w-Cu zyX_;a)Ha|}OnK=|`O-^YD{*vm!l7n+slCqL6Cmcl;($Uz_W_!XKf}l`4{x*GW+iRR z8J;cN1b&wYTE>8v{6HgAVbYcV{);AEj$8K59ET zdXbrfVI~SXr9?)7Ca7b3oi{05h5E^t<4-3rOa19O6thRI7p0XLcju@W(qB-zJu~J% zs{siCPv%jIf^t9hAZprMQZe!eSlLs}Nf{@eq4Ipwk|S^Hn9;90k~vB(w;(Ri!n+SJ zSY9#WAz>YAwKkyiMs?uhf5)NfAUuP&_q-=>QLNY22X(JxfLl&KF;VtSx8*~Zk}#LO zv2bc)*9+(Owp`rgjs(XTlR>A<7t^04%GtI~Zx)Zdf5{IdxX(5!z?6TjdkBg0XYJ0{ ziw9liwy-8hUXzBo=6JIFqES8oeG+_58-NJ9!IP#<*UVkGa9(lEP96TXPlHovYTq)J zW6gA4KPO@W>@={k>!zxwGcXK(m&amX^UBGt-7hN7yK<@A>EamZCm`2uCQK2gB3BZ@ z&{%(rR4a$igs>8@_; z&&HhW-`aoNh?GDEaRW985%?Ok^E_D!aH_!Vh7r{szp$~L+3|w=bhcS8M4(VFs;1ae zzP?43+^8tq;&*)Iq{a+QoEGZk{pcps@jKj_R5K*b2uBp0x*e@qJhE=CQ?1s2I+ z7N36_l-u8XHzvH^g$>me%g~x?BX0UYD!lDgum*H!$*{Wm-DyRY+;^Goi`6~Y`{ofH z03L6zF1??9g1~4k z!6My7X;v11g-V;}ul>k#gQhm{Pml@fi3TO2Pz#&@#>p(iv=hDZKy9K%`i;8^C7v<%E^Dwn%vbjXBmdoTKK+JlkWOGF+$q;{Se0dk<0lk$d{eDV|nTN@CI9isT{ z+}MEuf8pEv9671s7owLLp&J2^aK3?e`J>qlRns%06 z>PZZVm+odQmAnC78pg*1Vr**}4q?txyzM~1Qo9=&D*TRzb1!4?t<={2&IYgfLXdNv zz(3Cq-A?-@C?KE720F3`u>Hq)7YN`MUXL zBYakI7|V=1^|;GMtQxokZJta}G!pJB)T8?L(qERljiBt?w%;bIz{K}`u%VoEZmo!| zL|7$Ctw4ZiHh$ex=n;HS+=aRoF=Q_EKWtr9Sd?AYo*80>9(q(_XiyrF5*R|HK`=l% z6zLAhp}V_5P(hIH?rso}?vU<=|9QXT|B$n}c=q0Ft$W?UfJ`v{9$4wy?Pddkh`w5q zv;R3FPG5zd4nuSIL-T22!5ctuI+ni(p!gnsl)lb1X}iTf>9%K{_jkuV?Xf}Fnb_+= zldr!wezY~7xl!(LWcUw{I^D%jI*J9}+lh7~FyWoSIGXc8(`|K;kqqtX7XK;GT4!1f zDYjM1$7cP|h-rrU#CiLK#(a_LEbCz=X5=Le>w7p`+S&OlN*c)&w99pIGlTY3q!$uD z?4bH;*Xu9!mfwHlT>2qdl`X*|aA|mIJcozme#=TP>2hiFRsZ40&AnTa*Y4;we+?dX zuGCbLp~TYSESEnQ-`_Tlv9-s|Z-JF!or9+!dm1qguy)E_e?I;;IUiuFrL0Rq7D!@z zY(c5XrD`XWvXdkYqGpk3P@aUCsz}MnjQ8Uf{KHq7FIv3Yg#Co%Va4}FM$3PyGhixH zPvS&O0@@Gm1__q!FAhe-Fu2UDF@_qHbfna<7eJp*^H(A899j%MAV7fMzJbn=+Mdiq z?YF5n>lqAe6$P81Sd?%X8GK*rE<59D&!+!6)sU0C?k(_`t_rf8)EL@uP9Dkxdr?dh zOVf1k#*48Dxdnt!=rlwp*eiWV#=I&spYZ4yPH^A(>c$xc5i!9OA+?`0_VXf_8JBwr zo8d$lk;22W;!2k?MRI%Aq*=Ow$2C_*tMRp|Q2gDueAbV($)l3oL4{wKY7E}8tm~Fy z;j|vr7^P`zRM!2udhvwTYag#I`U5L-hq%X|&;4b!2T>x2J$fW3%NWb2n_?ak4=W1P zC)ws2R^+R8f7*|;{rxd>rG`1$`y!%u)sOrW?Swih|Gl4s)?RFPe`d$MV!?wtVR=Rs zC7zTO`U2w0JRRdVwn5q@ZO_UAg?Z;BxS1?PN&~@@m5N@jY zfyL5hW=|anFB0XCx@qAI-$Xn4Mblv`|~; z?`T;2M_#opuCW8>KW3hQ_;k5%Y{6Pzof3vU-9yn zYmxWkHs_(4#d_WRIUdJ%0C2F;CYd4SOZvLhVI+s2>(x-WiYSg8Zfa{knipO@!N z>7h|I6<)b~>s5i5hr(sg6p`Nwu6}=OVhS^eqlLJ$HK{lh{MZtdRD9IbEV0rMHtIYg zGw2$T!)QoB^l<|*c_T^jhRIT#yVVOV7x1H%_53kW)tQJ z4YQfLPU88PU+xoniJT~%1g$jFV%}}x9Q9@o10xaz1Mk;T0=`4KUS}vZ|*MB(`#OA zVL!a3%rx-)pup+0irgmVWTPlndHT&P=%c%;@)|}FAxHf>}=mm8UDzB zND}#vefk2a^L-+9isfn%Vk?1A zI73h-3V(~wP+WDy^x=smkcq4Rd0rBCt-=5zL}mw1_&QiG$7&D8|eYtIORSZPUcymwWret^eclh@v+ zPrd8MQWyTTNPnR{tiVIt7h7PkhZvS~{-eX~aFd2c-Z8t~`)-CB=yHDLd<0&0rY{iL zCIndEETZ8rih$&&nHssjF)X))&l!+4c`P_#y48H(NFMO0N9Y6*JYwZ-sKq$Ty@OinmFxzKE4MFJ(v7&$;6Fus67^>`&H_+{BZBm?_X zBs{!dbb!+xPsP85lv9KTgRlOkocVX1V8Mb;Ai)GJojjO1jLd!t#ot1A?tav{w(xJ4 zW)f*AJO{dSLp5{oRJbjnw|fgUgRY)PJXY@9Z|%LoTt5`krmGq?ATson*>B9D>T;U-HGhC>!0Wcq``H_E(v{Dm;&*xdE0WKx zhN{07{E}_hJZR>+-BW20CSzf#x+`AiszL5x7B0xRPV3FUE#K6dS0RxOhM`X!9KT7u zNPBmJna-}AxO=nx^_~!k9C}WN4BF&Yrydz@J?oYfQ&d=d&i**2oKtofg%-$FsK>yR z{uYT~+3|0iev!cae#rldPvlc=)NXbZR)5);uLb%o0DJ>YMhrS8T*#12g8Q&rsS3Z` zDe7sfPq8v;0_Jf2Ej@cT)4) z(jrIX8ch?tgg*fy`%(ec&_;0d+UISSXN?^Yh@p0ZONjXM5l5L!=`D+IAr-GtaFreh zzao-_fVV^IvB5(?mo`;joTfQlvS@;dDZmk{T>l^JVya>sy3-2_;;>SEF;D!z`Lea6 z4qIA`6~VMLx}eJjXZlC`wCsdbUw~$K? zwqGYQ+(=^>@yQ$lIe70hMc3(iLe3JZ>Uf%q~vK7eB~M8=^zjC2?SeFHxvM z>dl*yH%S5?@V)y0t9^)15*)9r&P8@x=qD$VS^9$*t&_G}qqZ(~TgFdOhPnHy#p!=z zWmnd*A~w55UgZ1tzRv!dSQ&+J)74ch#`QJ2(kJj-Uau0ZVC^HjK(}0IMU}r6*3TYV zX7w0eqth$T>iZ7k`kn*Q;OrBAQZ z{~G5Q0LbS}ik$WTb?J$J832T^f7jb1PyXMDYTds(_EPND00sp#TR%hho4 z?1P{9wcgZ>Hin-N0z+2l3&6(J9|R}+@D*Uie7KsvJO;g0BqVH z54&bWo7X{5zYuxYl}fh&JDLnyWCTQ-!~>+$c)`FqlZkRFw^E`@lcgZ2PC}|OnQI5} zE)=-HIT^<>4&P|6UHuM2kS*>gs z7Xnx)BR`9@bY(Ce))D>FlxlQrYK{5ZpO`OM`a>Q8P6w;z!Lug7!#eS*!*5BuJus6j zd{vGn}yYku*FwxcPh zKAh@sT~YD75nlJ=dFXW`SkCz?8i$jrb*e|LU%Tk)j@Z6rm=a%bZtN1ge#$ujv@1vK zs;as!x<%usXXDFTTmVdm4|7QPF{)6btC1E6xLzB8M_E=ArVXuaxOQrO_@TF+=n#0w zlPjft^U(&8roXKdgh*<}$~$v0Cr86n6)lk4nClGHOYK0=dCR>o5WLkmd2V%ns8}X4 zBIgV!aD=x_D{MJ{m$yr)FX82~Gdg^I=;35<`I1G9f+9lnV{?WhCj$EFMHmj%DHE>S zGl=1P6gW$R%mqn46Vl=*$^YBFRkerNWnijJcSJdNzDlRyzA;wLh4{G85BIXuT%)bf<9d4hF@ zhWHY{^V8nt#Kd~?R+XE|V~p48zWzyf{h5w8?Mlsa*Z zcC2-A+B+tHr=16g%5i@fY#D+4G(zqNKl7vKH9DFqWJW!(v#_xnZz6ZvAfql%GVHYi^Zr zY>295i?L3O)Zt|zi!!x-9_P(|u z9bTLHyf0kZ@yhgmzxz`Pk5+&{wa*`Qc4N);CDj<;OwN6_v2*1{&0(Qe4?=ft5n zbdGHakVj1)84zK)QCQNN|1j?Y{4+HHJ=B*0c4URpvSiSL-3NSiE=Ou#PD$tN?&-(PI=uySJ8KdxqeXrXgKMdJ09Gvn{#kyis(kEh;+eN8>U zw1mM<6DTt%jRu(Np8cd_7;_Ysb->Fg zT1Cg`jC!TsGCtFx%XN`N1>ugrV?E4;;|qV5m570!S~AR^Bo!W9uVeEUBH$gAlK(gn z#NapS9D#klOJUN0*?YZpy=cRl2km6Dua^-3ylxf#52`tPS=EpdhMqpA_=G3CML2gy z``~doQ}6lH^A!)c6}iOc{OQ~66fs<&A0pYpSB2IW9qk=vLm+scjR@u>5!vDg0qg3W z51;;)J0)7a*El8go~&f7@ZJ2gSU1`;t_hnEc}_EUjSn&|k~ zI7?Wv7(sLsVRVE;D?BX-Ec8YThi*3&D15U`klzSq*@E8^&hKueCV?8M#78u|jo}XH zlr$iP&}$~6g?cbyp@PM(&mcal$mc!f`(pxynRxPFH)5EPi0%stQVp~|dlX!nEcpBD zVcQRVhCWyCG5u&2mbf87-+_!ax3!%50q1X(6BuEYx+Fbe-rK)D-o4P=RL|YIJ^t&j z`c^S0G+HZO;eNyG>~pki_>CveW`Bf-;QF=hC9wLr2LtuA^ywojDQ-*%yJ4XfADxR|6iQY0L^E>h2%5Q zuE?$3IIw;m(Bm0x^i*uBy9SSkp-ZYYCM`z)Zo9e zrv@pbY7hc1gCfL?{e1b&LWbSMbJF!4|MUH@A`@ZO?HzBx;LS1ZU-V{ zVwcI!+UDp1-}g0ChA=boi+Y;0O^XZxEGMFyDhqh?8^YcoFiPIsrd1BB`LFwKDo;Ux z?6Xp|#*z|=bpX&ZM(MLVm4X@oGPt6lrnM%Y1DkQr(#wWVD;EdA?KxoX4Tw$@@i*zH zA;Xl8Y;rX2KSye|5)WG!ZDk3?<=@?o#P zhWi`(MVaJhsW-1^d72-^Kdij`Y=_GXLAT75h+v@R95&fi4FTQO@`^gFQr6RZnP5fkpa znb;DIQ6SHLi{CKs34aAR185Y?x!nfH#DJyApBVs_+q9X6xc64vQTt=HK`!6r3T@D17 z9-b4;Qw^PtJmj@x#U zC-qcN6O`TZgvra-@1d~Up#(vkvI>O9%-ao?o(a?X`Rde(VdLTSmlt_u)H?SN_4h}<{ItY^4$V&Aq$RiV+Ly#UbOA*# zsUjLy-Z>fEWSpgy9Q+-gy`?Yz8&Tid3&1fT6i7~5;w^@PiE?0$*FE@|eWd$|Pw|0| zk(kAb(=w%$B63BnM)}aT*uEZjjJT1mxGtGJ*u5phN`i?aiswy^uuq?i0Oj@al`=6B0r1Obya z%jL<3%mKcsVYW5Sr!X`&6zyswkf!qnx-SeUtWwfcFSq%LYt1MK;SSIBl)IekR? z9Z4yBk}h;!mvVg#K1-|?g8GZylwABNj2vB%*AD!;cO8_%W0nz3nem;n>q|~`LsoIa zAr;#I8+LGl)-F^@>#3=d#LFnxSWsWWsN@kXdlv7$7e6%?j`B--t!F-vc?!U{DSYrB zkj-|=I?c146-^gS3~Zu1@OZ05JK<%55?lWEIFE?y^MEYv_f?e@X&Mr1>$=`dn3h0;YtpXcYbzT> z{9R6=KgbZ6p%|@)zwsyxxZ722?z)X2e)6*(8<@P2yCqADVTnmN?l5B&9hq5d64oO$ zcRl$uN#MFunG6hf%EPd@r%|6`75>_6t=X}GG&ae1%X~yHUiC!Nes!!lIuj>%S~xTo zjr#mB7ApAFX^W{DgO6yAh1Y44SKIwcV1ZdO-s8)SPbqkx;={ofo*Q)fNmYWAq1d$M z+R!L}65jlWGyfuL3)aE!Hos-fq9=c2_6zaq<@(nUba;FLs8&H88I>)uJI(zy*U{%f z2_&aRaZc*>u}%ohon$QOvixW#`~88Wo`AJQ@S`PD0aGc-sCb+Lk?Y`iDZ}J$1gFof z($u)sKS+!q5k0fCaS4rrXb{fK*0b`AynIr?@Bd~I274Qc<8{JrlkQKLbEpGw;nAI^ zY=z&sDXLe`!t4w0iJe}04EX{!znazYHV-;_Zy4s(aWBTq_1f*7JYJ=?37Z9Bn`LB@ z5><@qeawk0L&&I^PW<@b@Jkp>cM;n3cEtEpT~d%FEO)koe($XmS3Hb0e6--f|VToZuwWaWn2UDjq&CSjnxA#O|9JhyqKilIV zDR+H~3sm>_-lg2OG)R9?pW#wh62|iugyja?sX1t1W`Jj-B|-@55Jq?U=%)oE@_CXp z2{gzcK1*E^HNmF|JLwdwl#$(v+c3;FN+pJ;k+$Nt4Y#io?JFURHAmKNp7obC?2KX)u9G-)~%+6%n~=b)S^pL*Dt z)_Ayc<5QfT*QCdOd9&|IjN87;m#RPsX0oQKcsr2#yI}$olTBdr70vTlCs>ffjM7j! zzzN7N5zKDxnmN}u_A8{-{zwep_(sQXXzb5^k=BD!p-VXPd0imBcji~CwBxl!+XD@< z_iG}BQ_0ebb|&rQ%;4-uW1fh0J)S8lpH?5cb16Eqo@irk6X}t-22hZ3b@s_uWpi;kc_3;{7B3|=cU1y z*~k7ii^6oxv)5ea@bOCSEjXlnPmokUi38K2FG9fK=ge}50vaeiJ#yDF=J4H0 z0C*Y@q-@-VlX#%0{Hy0CZf~gLt{N_qL-95-0sn;bj*o$Umz@XA;tj`>qy~!Cmx~vN zt%vu+dXM%!J0$o(1FcUiG0_|TJk}VB4cy944rTT>!6iI01_L-SvmuNqBZ6+mmHdH$ zV+nQOWH6C{bm{?j6$lT#OTS+s19F0PLW>!Or(3tW4-Sq_)8eGC>qI`2$zmw1(9@J8XUP*e=Di+c_{zMvu zUus=Jp?!1=B~Me^7SiH*q9dyD`ZRDZkM~PnI)$OAV(&GW5?}k$P;*pI@Y+|{GPNQL z)O|Acq;#pJl2Of2$!j}z$pOf4E27**a2huPZJa`W9zq5P4#Lu{q&MNVe7!|`N|@^n zlU3D*sAjNf@i z2JcqmtHJ_jZ~Nd1!wRq=s)k_iSAO4PCSPkF_L6}|U+Y;R{!>5zM<3rIo$v136x}xz z1;-7Ru*us!gWA)xXByo}5LH6fSdtFMT_0jC6as9UvKbA0$eJsFO_=fPPVm2z|L?z( zuk69gQS^)f`Ef4EqQ>Z;rMCE1>T=CoH@liYxQr3LPPZPsC_YV!jVoXJ4ttm*Xx`w_RdqQN0V*EON z6OIPqLk<&8yKwrkL$TFM?@yz;Jz%-JcU~&k)BZ27~6!3m*n}_bUuJ}Lnx98n!$%h4xfFCU*x=cO;U2_88_O?AZkDO^x$hQoV$}& z1G+O&>;Gu>h8i?ga=+PMQYY=8#b+R0r5+LER!Mpb}rVgI^s zs{H5ZIsa+nVDtD*(*@5|#38lq-FoFtz5}hOv;GW59Jk*qaE&W>xiFAchoE9TJM`Fa} z*wV?J^FFM-u<}GEgTwct@Mz0kPdavx`Z%B~0vedLVmfiv=AaCP43zjR$eMl=iVy9* zSAV$Qt!Xj>xdEOmWp(NY|B>VV#hTjc0ATS-mvK9_8*~iQPjAc84La!T0)SPs#m}(V zp?=u8-E>Iq@>k5*uHiud{R*k{xC+w$$2Q-2LBoKt`5C)%1gmr1!aFV;T>ayx-d|d7 z$x$%oK)*x%xYO8*G{EGpB`2D&P3_5e3V0v|P^WOBslk}Tvmid*Bnqs@Ad>g(VtnaNvKCQo3sBZY{A0@-*jKcI& zv#yG{b+KmEZvQa~xQ-1GmI+wK@!!QuaVWFd7=QdB&008kh*k)s_1};VWV`q&vW%3FVoG9?Xp5=? zC#Ykce)Pk@m22VYKTQ9LC;(&_ZvF;nIm7othCjDB!i7UUe0W$WR{$_g272|lSSYhP z`SmSg%HY(aId?uHt9XRR;^EM^%};>BXFxuPj2l0j{WhtshuTo z?ALx6OHq6LHy)z-T-yJ5#NK*gX>*Z-C8&bdlO<|(@pZVt20wncy+zlRp>rE`em!(y z_<`|n!(&JZK}FBAa)w=@W`$UVZ*Ki6Jah-I8&E1LLJjhm9U@H+8QjRkF!6?ye}{l+ z=7l)@sS$)V2Wbl0GzxQjl7+305oJr|v5^mU4C~&Z-sK^T${DsTR6--(iROKNI{bSs&ipBFTQfC~6$yCd!QFd=X;-%oD#sI8wr0b4eH?sd#jA`&)scpt{u*8f~nO>F~JP68wZ`Rx?z<^}mWB!#@Y7!5$Z1X{-VwQ!e^elqQL4 zU!IaDCvXyg4f2y!wl)9A?uMa9j0V=qngB+EZL(<7|BzP=r&(;tV2soRT zC$R@WP)lC&kx~2%7?^N;r*C7b=HPv6c~rW-r~Zbk@94MEf|#kI73Hhn32fZ8)0ekCHz=dNw-CM||E=OiG1h!VcJn3x_`V-i*^&2d?AgdXo zOdW|*y_{vcHf}=_cC-wA#qg`qblA?ozifoy+mH90t&=~MDl!v1(-k^fr-*OUAt;`MH0w(K~I zR9xaQ6LLBo@4guv`QF-=NII5x=K~zyz@iHD7k@gIUiVk; z_4-6We-Hn7UCbKjn(~K^s~a2_iE#FAJ6|{lvaro!9l6a!TxIPsIhS-7(pK5kx8QPm z3HG{-zRTOAQz~k>(a)^aL#YXn^6D=As-_nR(*W{XX=TeocXZjMeUhtP2t}`Vtf;DU zFD8r|M+nvQl+_mAbq2T8asD;>o824ES6DYS17vRQxBn|+KQzW#E|dK6eQUU08MSpl zc(syLL-qy(v&xd( zT~x!cves^$*Zg$I=f;?RW`+BT(Cz0G=rLJXr`JeLE#WVjZtmB=z7!&XO*URbhJ>cz zZsrS74>_EmoU#~ocBGZHzo5AMA9g;f@&!dL)ZFz+GfB~H6D__-QPcZ}`~P;r$^Qd^ zPiiX~egPoxSM#!9!bXfs8u7~vae_#X)8e+3-qo`L2Ve-4c{)|^6C92sE1O+SkL3Pc znYMp<@$pdZ<=LkdGcX$zGats@rGn}!IW8x{?3E1z`7ZVvV_5*VIu3-ui3kqT zh(3EMRs)h)5M*NF@$EcFaDjjR_YuRH!c<#u=+=DiIZ*tch#Tg!5pFaniDO$Vn1qpF z2+zPazYhkE5H2UM3}wde!k~M>cuZbM zA|56x@C;+KR6p+HCL?w(#w1xsC=$0ex%VG8$k0HQF<@xb+LVq?wm+GR@gm|Ti0;Kc zLYmciJD#g3=KSlF8N@%_FWAut$hF!#XiH7ooQZ147!Mi}$<;UXpD927tJ{f2Xug>i z)DR4y`Pacz&sBc|1fm+A zhY=kI&i;U(NZbhB2zv(zkoqWXoU{JGu%hHhAladNt%$7ETZ^M(;J{4Rj&tkLjwIRc zQv?1=eaVp+Og>xuxLiKQ+?F=MtTp5bAFg&`4H44~cow1aNAE)3v5+S#L2@A_!r+4h zP>uIK=R2j(RFTd{8-Fx>q8A^cp%Z204OLwnvymsyQWz08$IN{D#1L#qb!0#CA{6%9 zNHiCFHaHp{MSNX;mKBRBw~&`P!gp&~Z}n|5xbU)OhpFq)2|#-Ot1g&GyW^n~HaKb$ z(G;0q`~C3tHWcl-ET;B_Z{PA-_tNTHpu4CC8~zi303eV3O4=aeWI~&JM1sDUu38By zC!ncFfBl%HNGkI*+QiZay|ozM!QfMfMk7f{yGcPp%t)In-|7AQ&d2}W)vF{B=Q_#+ z;>#zZwPotaWVBGer-LH7EJ3tn_4;u}kQGl&ezqm4Rb>w;c_sx^6`s14S-MgC`=dYRPzm>%1*0_+A7nE~-< zLiI;B8*D}V8jYSU*9s^GIJ;jo8vOKO_gq`JTmJ3ksF0*JF|5l#N0qm&it&-K53FNr zKx4Ck90XIDH|#A9P1$hb4%2alfgwUKh|4|UtnJm4y%+$iiwx4n{@4Nbxd^6DSms?> zqmE)lJxA6Upx%NYV81kI3^)b?>fwQ%T&_>v@cL5&M9p0fru5L8B-sUCo#%#6;buk1 z{(!Xk2V(*WRwx#-G(fi71Hq{4Ipx~cHr)7OlaBk-$75c;>G}o|*Y$l+%N#Oy-rYs+ zhz3cQGB8D$$7=KGs$+wP;dU;5Rd}kRNxk6r@TPeqbJAaoarX~r;S{!IGcsl>;!XwE zhZhgAHao+2@(5vfjp2}kx6G8T&>r^Z57RjtpKZ@+o8V2yPps=l)GY-$e~K4@?~@G906{#D7ZMtlzm z>i&kpM?K%j6i8OHwN`Vk3sf8kPROYh(5vsZV*~U_zb|!tREX*`dsW{pPsh%>5v3)b z^(;sPlLtM*f)`tRx~)PI({dWM@kr|7QkeZ9Og{>jl}(n{J8JN$WNv%Pg&*fHh9iY=Ei(WCez)u?F$Gf7m4%CtS`-VsoItdd z2^f${qLQ%ul*e;iUcuoqSvxi0N?=1>WUjOb!^*%P$JKS0p?#G!hP6$a{8nD3{>N93 zZ_hXaG5R{BskM5EPrwi#b~#x-h@sN8v0UBSlUM}Z7&?EjDU(5rS1Fh-01 z@V30Yy>WRaee62sGbU!UJsquU#>-rR?GURbB?vt@bWIq6r@;4=7%Tzjk_~+IU4jAr zDFqE00H53RBlW)%!)@@|)IQS`tcrHmEkDCBL{|^OU{DLR5Yz+x;-J7Zi_ix81>>#i zN)syDP+YsxjqY}e_QtNpYE$j9Z@cu+%VRs`u+y!pZH0HKzcn=ld%>)E<@X@f-D6mJ zDPm1XfAwXUo1T(L7nlYevC?gbwd}#q1c!rO%KD|ju*|L-*vw>5HhV6Ql1re8Vy+Dw4y4GJ+YE4VB)XidL_Tkl?DW@OTu6g-PZ=y zV};-sskn?CIsMu0u6LWiwF-)*KE0PSTs?Rw2V{May{CZL2HN$7EY00|G!XzS*N;->UFvZF8Q%i_a z4^%nw(_9C7j%AuV1?D8bxn>xxT-!Sae5{BMCuJ!Lb3BhX+21?wY@TJ^&pkGhn^E$d zD9e<8&ANjHj-;d427vt$o(?yC0umrI^JP}fC7*V5XuvjAKF%q84dO9*cws&1GNUZ0 z;)aRoHM}csL9>fj6#C_c$~BN*n`$-AWngw9EwIh?T6dyiHQD@ijrE)!t}C`d7#FF> z6o17d=i^pKvQ^zAsjp~&5PCX)O*iTx6J?IGhCv(buM1F{{h@S2uMNQ9Hxd5Sdi<~; zaN`bA1pAs*2}LH&&6u3nlTqOT0o?pT1)a0aWl)NttoQ$FUun8A7a^`cBmsTeVsPR9 z+W)qUw+<=VS~95TwQPS!KG_232_~h7ORj)Ej<+C~j%77!mQK&;yV1@JUF7$(09*pm+EuvwoySp26ah^nGJ(FV-f7y1NvZ zE>ILLcpxb68i8O|col}=h6aYJ{g}7Y;~D@!*=RtcLaGZ479$KV|9qHxqU-R5mp#6& zU5|>I=wZ0Kn2@v_&4CGr%7%tU3T5RSSQb;a*rrwIQ?I=lcP{f@H7t4E%q)zR00W6C zYXAN%hKZSgn_)+4j;*KzB#v9EeCm3_45EXKq7v{58Q}x^rWp5vMBSA~lrZp;JD>vbyL-9m}1#mOFheC5srcvg0@l zH7kCDpu%;_yrS|Tx^BC)&tSO-o2qu!uj<=xX+Q8%5j^{=VlMuXz~o{10OJuCXS{Z$ z6lS!V2FZ>=p^6om1@Rv2<{TVNlyi-d_-y{OWJ1pE^2~jc+yHba`osIMw&)q=k+s~K z_4ESG&hI}rOM_zTX$ar!eDCiG{I)TrW)Fv#m!*v>m+b>cqV8`??bn*dfOiSMa7G~| z9w`AHj;zl+;^_$)!Y?;mzaj~b#P%s+Mm3-^3ZBH!87C~9AwdeF4o06il=CBSc4fuj z{5Kg?5+3$1jCcc*9ZHs-D`b}r(db2~izv%6|3q3G&!o*>#a%mK)0FaShhs(yM>z0C zIh^Jw<;Cj677eb)3l|5X;Y6N|?Hoq;xhYj&es^V)5t4p)@BL9WWOqIP1_e9{FKWr1 z3}B!+7nf;A5o4GK-KX~^y+?pY#eY~&^u);GSy-l_qtv&s;PCI(oJ@=rFw>R$w?97h zRu$g#dGS-?mr>4V<0DJ~`u-_(?{&J%e{wI<%s%F+IKa(vF3sXs&H-u`Q+iKt)T41B zVq0nEJ1{uJ>pa8N?)wNTe}8LX`x}i0jV$Ye5q@Y86$BRAb_Te;LKr~bqe+A(djN;3 z@TohCR&5}v!}j(t7bUobDgav(9)H1@%mdrmE=vyx!QPe=UVq{HK)`_ObauR%fiLcP za&(wG>TmONp`mddqdK6Q87pk36iHPBI3auq@)f=YD7hb@c4DvhvWQTxk9cOppm`Yi zS4U}l#0=DacP{nD)@g60P)LeHJZ)GDsTC9xklmnZ4Rd`bA?HzRdvAGP*K1k-bSDfA zp;-~17$T4W=P!?mynT)rBq=`3NW7p5Ceta2Khlp2N5-=VGj{!#k(J3*U@DXH*t;e6xT}~Jtuv9*1rrMTQ2gmmx)6ib?MqOO{NbQi z5WnGJBphl8r`uv@r_)ej|JFJmN^zJ}3qVmQBZnZ?r+-2;m>PMq?xXfaTzxJjO7UdU z(eP1iwd+rKFBM1Rg4W7iDa?N21l@bAE{7!EuG+M#+!`ZluOD+V{I^Ql{cn{N;qiQs zi=6?PvY2R5adOZ?xp;lD-0^VzAZaV#N@-RGn3H{72?khoRe7iAQCM^6&B5t+Af zqPEGDfBmS_4HldD{1Z09%x16OuGmm=!89B3p7(mWakp6C=oYC2hWgj zM7Gi(6;s9u57gNce>E5ZezUWIxiT>I-lJN(Mc&&fa)Dby5OfZ*ckOWct08}Q9_$72 z!(znPlm-Q&B`IZLx3?d=IkGUmP-h zuu=%x&U%|VKXqhO;=lD6nSu8q)Q{%Oqo6dwniepLtV1uw4*FS8ZhmLTf*N_dYmDn% z6#B@K2=rj3gO5=E1~vL-P7(@MugeXkBuPv-{Hy-#j|d7R5b<1rz@S=?cX*Oj^pLH@ zu;8-3Ipv?%ma-!v9bPCd1D@8!=dJ-_?_CQ6b8QUMzLx1m-L*}c=>S<-C;a--g2sX# z4&jvNd>a67Vjw#<>1Xa4=}B6oM~u^)34ZqH!S`)VdlyM5oAS$7<0cC?F9mQHg*Mm* zMt=%|mv^@6Bfd@6D4Y1z&}2ii7y&iQpKS*XiBq}itk*qNel!vn`hecuTUYX8DL`U? zF4xOJ6zr{_r+Y!|;X|I@Bv&`4HvL`9)mY9-P0`AQKH=}r zbTs$zGF^r`Js@_1Ty!4X*>9^iM>g>ml#7gf{U7RdkZj@l3b+v4d%esemxP4_X+#%| zH$dXv{gTdE0dSXr%vdshYT)I41eTCLi;3GwmWU=Oo6;i#rHyED-a5*CuVOt+GG4^i zK9}pAa)Y*ej29t!mG>aJWiat;s7VY%C* z&%(En2ItbDV-X{z0ZqPwX4ZV0{yb2ZmpL`Vf|+2X^Wo1Qj7q2^CO@&EF^7K>QG^LR+7kHa0T>176X-}Ki83e>eQ%LX?Cf? zMdV-o%U$?Jy;@%_>g{s5p*UJdSLJN%R)E*eK%aaqwOBO2?C*8C&+Yyv`ke6gg-)f* zvGnC`^EIWx_9+{2#;fELa|~CiHQdgXTA5N8*D3|Et&cHDL^F65UeydWlk3GJG^6xC zFgB5cpN9!>h4U*itN6bwE0*-~`Fa?9(`VCqE-#7Rc<8SAVue?T zK1*iW#}HJj?jmMxz#PlUt&o+aYXn@rLk_OWc zjQf?HfDW^my$WDnnXSa}ZpVGpzh?f05$G+yi9FcO5tA`o_ zCmpz{JZe;JK-IXw)09F*loLV>RdX?nv2tCWLB4&+l`1}k2?q#D0(Y%mT4pt2s#2|J z@2V~l-U#$-(D>f!)&J^1|871IR`Th*rL14a8QOG|`Ji62>FF`Qo*!8H}I zWb)t9%%I#y_dHhxtOKY%!xJJyX&P{uj)0oPd;B+pgL|n345VUlyiq!7QTQ7yVKU5a1_CS9{uufXV9nqr&3r? zi*}l(cv+GDKQT<*g0X{y%Y!G@M=4hYd^c4RmI`mE}8^D<<^F&8r z@u`VtZ2&*`kUNQ&;gbo!njNR-zDn7;A6)Jx#a5yej=JxLS2}9V2ny?{5Nq^HRPCRg z1&!i^$4j<|Y5_a7yh~T0(c0g9cBmk;TjzZH(BR~U(k~}@AA?G)6s2z?530ua!Rn_? zD8LuuN&<`?evUAWzFQ02lL@`gv2*r5#GWs8`uEg(AOuibp(pw}T*lI-)`dFv&)+$- z9>Q9$n^k4$WuGk?GUD&h>KK>b!5wc8+UJ5O*@Pvt49AED{mZd*13QIHiC-z^5z13` zYS_vwp~qj{j|g9QDFyDwsKc-`7u|NQBXche-`ONC)>MwrbvzWjo&w z0|R_Nnu%XB_?(AF4|i~@ZWU)Lvbx8gz66gpe^Ip11+zWVf5rH2ZQO@vsV&}QQV;f*|tYi_g5C-0;3pfp1l(COb z*XKprN`32DXXNqGaRR#=x%+E#i3?!3Odk=!8Yv8eCI0@m^gy5S#Op1P#yt~TST1s+wRjO8UzaiuB8g^fRL)A76fdo_!E{H-KV6v2ck6K?d#BD zWkDk8)oRt#2zglryhWHy79K)g`@hoe!>4I$99P2v8=}3Np=fC{Ej@d@sXFnAoLx}A zBNJYe1T2?eLSo`+1<9F$-_64W6;$29uY`rcmD#=8%%99xztf`?9C`<9@x=Ag2{l8ATv4#81)Vs&-Y zUW9Lc8SR9!EW7Y`NK@$(Q`a@CA5G;{Ip`C|)~h~=h-8YkJT1Kn`t%M_)J%y-{`o_q zvTPzHS^7`J_uOwkNmdd(T9e7KzbUPG$g;kt)s!Y9LN--_&`LuS( z5eQ2;NjLTLvH=VJ0;-n8 z*NdL(0AD353E4zj&?~;R@iY~IxZm}3K*ES51q9c2A|!fMTvG5pjw7%(LU>!sW~k9}^eWd&JmKJ3O^i~XPd{OEs#p__;&<4I(QuXOLd z`3t2=86aVy2HN0`zxu$BjTNvj@zVmtLY?T7P;JPLXyE{T>nOGyQtc)yymN&$os2%m(JGcz!051R*p()Pg!(YKr9S z3q#Xa2WLXRSiEH}*;NqEbKpb-Bz+Mpg~&C>i~k0kEe+TxS{4|dErOXt@Yuj?cp*Pv zuDM7$#40mKsTWH6E@%5dRF4}EOyl*RwGEFm&BMv#5bfrMlfF9!cX%J9e~k%ISvUVm zNDsfMRT9wY#?3leYSXo2kp-7-<3)YVp3_`7L|~#|PRc)C;30O@-t|dP{vZKBUUi#0 zUU>I@YO^p1w3m2R`q6&4yxB@Yb1Trc@!?V9pN9?}lyuQ?+~VvX5j?iH;BxogJbK#D zzQfjvxk3Y=Jw4{ET-BB}8c(_xMqb{G)y6x{@N^3nF%~~5 z?Ia>Ml=a&FSooK+O78D>?XbMI;)hnUBYlIO!?+=jAF}+*hiQ+zG*8Q_o$i_N!}eA8 zIx7>;x2Ocm?vdRolGm}g#!(#+uqpA{G~nu=1b_Kv(dfv7Z`(ivJiePCcMR8qkBEq{ zTA+-1xqe?mP8@4h~kE0 zkUMC6K&fLxA7z+_!S*p6PzadK>mBpeV^?F3UV9POyP4 zR`2jSZ`;w7P)|r_AyC0wGfYO~;#`UU%WGxEUmaFNH8+DERBHi z01Sgh$qDxRIYO)+-Aq-GlXtx!;jOQ>1 z$mf^s_Tqb{a}*h;#d0H!@qsA-(OdHJKZp8Dxy-xzz!{)(Q5SWiox!vxpphg>dPwUR zU~ZYLi7>P+a1^}t#_qM~qm^fMflQWp9Bjov254S5srv^ zD1`U30D&JKYM_9i?p)k6aZ3Geo!}Zk!0cfK)k6QS9FFbG1(QH9G8x&s~yk;>Ox6Hi;Lg*eMvJaJ4&N_`6(9n z@u~Uu2UKLL$LGoouG`(I#g30R<{L!0J!}G_f7~(;b_HcxDg8Uy{eiv&!XUDwjE44m z8*f*{2#Xj^UgYOj&I!+YU7RiH$=UX!CcZmSLo@0S7(X<6f;UPr$jq}Ar=gtGGg$gg3M-pHnpTqmQW7|1q7EN6mi}%lS!4^SgE{qTh@dTvnb{**E+~KKD!tC{O_H za0(`guRynmcK&}Gp&LWEC~FAid~fv$A8gg z@5I&XHQHKn8SqJd9Ps^~aOwAN%#LGMfW-TyteqdPhGIU}b?g2gNxiw|0$;nhLs&Hf=V~D#4)#RCk0Xlt8}5WQZbg zsUF^DT^yLvnb3E{-2q+q(oBkB_Gl=&_4Hr*9ny!y5u_p87H3O3COOAgJW0Y8(9fS* zOna-SflF-x0pLMriLG+Uu7MD^HD~kI8VW&(LhEWkcAi~TmO~M;2Vc=Mx!Gz1UBk0l zvKpsX5rz(OBX1t-pdV8sOlMSpAU705i9?Nl8%RD}n=h8=p4{1TZ|vBhQ?!4Y#D>bN z>$IIjz3Bs>uLOzmtW;P+ow64K7IVq$tOVQ=s+6QH@-MAl+`hQRF+IC?q*r4#;KE+$ z+ix5d`{FQGKj{Ic`sEpEH6=w@_|fRMzOG#ZWc4@n$LV*^f2Sww`VdE46kR1WK3CYR z!Q;QmxXa|b)EHzV&dr!1R4FgqKxsMpm|=?snc3g?mzLs0S$QvxaH?P%AV~%R-Xq zWRy)KHd2#1rW}hBE9Ikl;_`s`a~!5dGxu{nA{aBP2GQbJGLXxX6DacK?%NXja5I4) zgAa42KEembnIg*uxna&~E?x*LLB?FgqVK+;WT!sMcj3iQdZfw>9b%NBc$H zTjh#9ytYlR-UtLRhT4(>eL~7idO0-qRP?=j zvE;#c+=Gozt2%bt9x5KbVnrb8uWi(go^Wp(KjMswbImn5MPP7_!*rV2EK9v%MZ+d8{(?Xg<}iq&XfKlKvD-w{-^8g|IazP zA^YCg6pJY}xEbB+9B)vNTc!TDc=^kdD&l>$$Mq{kh2;`zg8{kJ2;u<=HEznqFBoyvCc5_mFPw@{!E*}^L*1Xht2FtqHzB@^L{}dC zSEA|A9cB9WMKuWs?B5gVHhT8NZ*sUVA~SHR^@7=)gsH|i^h_+&U&Fv~RL3d>&e_z9 zslk(q%iRusH3m5wYU&%;yQD|`ZTIH@Nv{U=G+P|h3BRig{${ECBy)0_>tQ<~HAd4- z0W3!RSFq=|LQ}kk(2zWn;#-u;9zoai49f#?0k)ZGK$BMh9VNu8X zY2=s)wMBW*M9t~z`Ts&ZdGU|S!jK?a$WP!QO)S2D2OKOI3=4z3u0$kytaW|E3}#UA z-F0k_$(H68`1SL@9=Hv|m5_Q^2gCmzZLiN8)uu!a`aWa5=#4*VRu&~_x z`x!i}bFE2JfAh!B-ZT){jG`bQWi%4=VNTziGD7PbNXXSM>`l7L-fP%bz154byF3kT zRp;{nwP1LG?J%cPFVog3(tu|`%wCjOr`f_wpOFB1+ zV!WL50`i=n3}yGIYjf)fAuP`4O|MJeNSV)6Eu`Uq*asYXDu>BKgn*@^;TqQZ999B09J{LNmn)`l1cM+-E9hi zSce0?R1(@`L>MatZs_|qlEtmN=^A1gJf|JZ19PfVl;;Yx^ICcR$+VG|3CpB3)!Pfh zhIMq}5g8Ty3UV3!`%hd9%*C^rq_ZF7CFTf)20jSfDg(UtWo?IQY^=yDZJXnj|GKwWhjVbCHIYOV`;he1q03PdNqf{L@y`x zKU)Y5-kAg(jLRamw~keq1~ZBr3L>Cw{&Qf4>tnQQRwI3}oB-3!*@|xzs}yzX9K#Gn z$oW!rM8X;rr>wx%bk_=$z3cp{=^nnonrCH7%+?Y8d|xx;mi^`Z)im^8{{XH{jZkYT z(9xOURW6#(k^#)@);`T01IMKB6pY=6!O=`FCOSVAPNn5D1Ah7y2%3#Cf@lkFumN$S zpJLr2_#IhlY(a6=gB-PkT+wXqrVh*)Fdbs{*?0(Og|8+B?*Dt36kFjfHWtrW5kT~P z{v+$jRj^I4)?-m#W~0vguAOoGTZC%Y-D9sD3<3=VnK``pOdd|%<7IvjzdqM5xXtuJ zPAcV5gJ>De*ik<}k>FDPpu<6t?^UhaUZ9!BVA69wzUcnz9=etzv7Lr>Pt$J{I6bZs zIYcb}guXAX&!e^_6TVNE_yJZ}`NNh@kZEH4u*)UlSdr)J`dk?cWU=T8KVWutt4X#e z^Lpgd9n;VB^0ARzhPeK_Uz1``pYZojyf^}lt&dx3220B)e!B27rRiFXzCU`??wO@i z$IQu8S}OFJjxtN`?dOn+&rj{?X!$%KrjV~-!MBbT*>FCqZ;W_e^>#+oqT!|405>dnO8e13 z!P~ffi7CJc)H7^H9EM#}k}E?4;5KSvF9Q5)i2B_H5ZE&K=(U%Y+&)vT*nsjn7Yw$_ z+(>xJB{X#7kp8LAP`5@H8nHcH;BWR?9NZjou_8od`C>!T8kiCP$J*~+0BZlXmA%Bz zj=-C}ZfDcmeL&(A4P-C?CLWvd*AWB8RT^3-Uaf^j9uQZHrGpiKSQ~&1u*z^cNT;JM zh?@*BSdb99kvJK=hkA`hajfG_*#~Cc7m#v zOM!Mj78S|$_aHY@$;gzJP2E@MEMt^7sEsl_8Bp|-3ssa*rBQ;KR>9XmksqeJO5|v1 zUmd2ZVkkHSRnqP+LM9mYyt4j}Q6a>!e`?w|8((h#M1hEKu@cC0C+wzX`de%L#dguw zs|UvJT73FkLi>V3MS-{aQt`99%URUt3Brc50;(4PM4c{kM!a`vs>M5paIBOG^#51@ z0tPW$zGRutMoT*Lo)#S8m3iL0g{DXgx6KuV1iVI|V<&##FsoQz8k@K{DU;gwZ9xbJ z%xD|ce<5Y=wP*_$=EFDVsDVy9ZZQcPAxKNOHaYqc5Urx5ZV7e=yc(@5Pbd&Lc%cXp zAr^L{mWD=009k+Gv{uCO_Q_Jn0Obc@Q(Nv#Z(@;Lv7t0HYFzgWaRhhvx_c4bu!Zz` ztjRWT(sKXiO>l;f5W}){>W_YFQ>hl;h?hFwEh@z`=@-Pp=3)@kX+Phd5*X8-d}^yt>!(eqbNT&0Giw9`gBbN5knfPNsQF2(kzd z0Zq1>A2Cc&fwOglZs&~&zPkXE9*);|g8JLzmxMt9mn(--Y(FEjFVPpdPv~_2ajd-3 z*Myk}Ab3( zT0$zoU7+Azz%hpIk($V{liu-UA@nuF+o37&tG%y6@OYu1@icp{bv$j5uD@z;W7v8j zohigaTog_=zlbDlB6L#(1)0%UYY{`hL4Z|h9Oa^WCBBPyW3Uvwsg#D!N}@oA`=NL# zQ$MEYxH$LBYQwo$aKt~&P{L=QxQdv2loegS#K!6}>0-YYB~Zv*tWDEBG$nD|dO=~5 z0VRjr&g_`&1hbs&{o*f)U%{W)-}o@|mefH@**EGamFO4V+a*I(=5k2*9np_JB9IJg zZb?r(r+jlleg)*{$jwIIBekj+Nn4I7>CmTf#BP$$y&ZcJ56>1(cQ#g7FI1kY27qy7 z=I}9uP*t_}^7qWAatZBgBo_N_e!2YrlGtg>y!oE}iu%B9!N1RVmg8{9%ffv?I30UF zkV!goM^oJ=wio-R<~Fll3{(D4YU|Qxv|dENzQ(-q^w+3op>Y-%{2pn?m_Vb+>*Umg zAgAMJ7sYg#wRM&+`Ki4aW8vVfaN@KOfWjeHR^eb`XYS9*GT)qfgnyo?8J^9eqigPfPnE^u%4^3y+b zgoH5tsHDP5{=398+HWhTny2_IR&2(4Y+u@~EOGvYn$dJ>;S%#`Rr5oYu$2Aw#7v?( z_QyQ2Gr>E8WYtojfN>@qBhrE+NRsWmwmwAfiDp|26I~3>8Y?Qw8dyVGKAj8_WN26? z@9m_Tec^pzbv;iTE^V=GDZ|`*)t0(&B>)Kv=&C0ejK-vMYSG0ftX-Qtr6$brliNQ( zomx!xNFX@(M-PMj_odB`pHB1d-bwGc#f!swQp6S$L}-uRDAB zHdK@$bXsb0x9Fy@uZb-U18|!o3KJQul~FdFIghoMX$=DRBahDHZd5n^MTZJGr^}6B z)BWP0nYcg4ZV_N$j|xGmkRkBJM~C9(V4b#P#N+r1pvc`=8m10HDtQ3ERQgMWyo%}*4<5qc$0-lfBa}l|1S^nlK@ap&Rvf7yf?&UYt1Et% z3)mnD-t7ocrNK8j<}vfqf-c%I&JvekQ4((V=m1vu;E1SWvg_HeW0mlTCZ@K!o$w82 zbg;H8aw5aC;p=X7-S1djS+If{#pmOp6RnE~9Mfx&of_i=>MC2#`(%`GcgHpNQZbv> zp`J=jzWkwZ)zx7u2~Lt9bewWkytj9Ku|XCEF@kwphYO4jn#ArbDWt>4=We8)t!~Z# ztTaMfpF7A}ga~q-T{rhK0msAIWz!Y2f$!KPZ~%TX)qpuLJE2HSS~$&Sw`zebfc8th zwQi6t)!Kr$OC91qzjPJH@Wv2U&>>{cRU#=wUIWy?JdjG%lKt%c=TO9~kt&jIPi9$! ztLN&OotBhUB_pyv@V@W9<3+qk6w~Q_#f+*b=|gNcgOJ!tGZbnYyHph=Y#X~!F0DiB zP4ak$XN8Ty5damPrk6$DGWJ|B};hSMNJ{ar@ zHB$AbbZ6oe&*2^xiTqgQq8mp{ju;k^UE$95BLtvd|oVIW&HC)L)RW zd-3^8$$y6-_kRvUBF}Tim&z#33H(A5-(}jfx4SoSM_nb?kE?Gyb$I_INfi?5$RR~$ zmmn>EYSjd7`;*T?{q*>g6rdN?%q8n+Hz>!|?{V&b?pmk_Nhb{DV)Y$VN75I0a!3{@6)Vb& z5Mrw86c zlEBa+MCbHZD;5JQ1eiVwyaW1Cn)i#@`u+;mLdlV?11%i#^eD^Tc^p?BQ1o z<{1MnC?t0!{PBLm8-eGjyK#h# zXp&#yYW56RVxKiQv47@6Vi72Rnp9U5&-3riV93YI=V8IhbJwQ#;xBItS&)14ar1PG z6$3JDKSR4-Rx$w7jQ|v*WI0kb;qQDkk~q)b`y!;b)?PN2=jdHY#zkv@!6`*b_`Fj& zUn6NwglXK5nta8t#B%NC=s6pSr)W1+4ltZgBUION zFIL)r1N@!HkMxZd^$%LVb;~G(siulR!oW|O&hNtDv*Z5Ho-y7h?W%|>)p`jFK33Z< zg{ocrm>Cgn`59{|N1iRHCmU8IQnnPC(YgRcSWr8{NFSuri*prxJ^iJiX}k&dm)MZj z?Vm%rNd$LN#PgUKmk_p@XoQW^va`CDV>~u%`r${GWX#(W$IEaWdD}G`^(t$BZc1PQD0P;EPE=Xdf_|N90#;9X?H(vSM@RktmKVd1qFqYH7;c;QfJ zttIF)leh`G$r7q`7Gdx$FD05NA=a1~nfe_+*fw=k)M2Xl@bj4eJ|3v$EP&mK0BK{a z9j;o*uy=_jfLh>AmVxUWtO!#2*GVEugwdq|3cF<_y9df1p+*aOCkiXWFgX)nePJ~R zH5z?SXjpU1d_1>boHVqNqH@tB+d%4i9KG^;m=&aji!^~_*m*761W-A4HNlKd({tD;7?;fSbHUVze=PgYJj8XVb>`m zAo_=UHObQ)zZ+lR2XQp#ydNmXrG7t5ybI)I69>Kg{xVl#{~OD9=({_hUXv&?M5XE_ zrhe}HLPYb7#phW*gHxxd=HB}d8{=0E7e=Cw-BczUI5H9aSy8PqP4vl|I-5go(hKQ5(fs2_g2ZI{4nud4mcsv7etRNHgSNnIs& z*w>zptX^K0eq=4Yk}+TBb>)QWA#;X#?{wMds;n$iabNd7Yn-`ChZz()?OgolN4K|3 zaF4nStqRJ*qwKrsqY00SgkGJ8?F$c7P2#Ky#>4{0UaJJOKM0CKhc~ZH}KK6=R&oml60)ZCvU~G6O#9-9_Pp2Ee zOPqAc*D;-ZZ2Y$6o?>WnT>TJwUi#T1|Cd=TNREH0&!KG*rlp&1su>u0NLZ8;^{GTV zk*+g@rJ*Z*z=a0=SFv2{qi4lrH6Q0Qf7eKKQ$ux)P{NCThJ^3mMJ?KH!-eUPNKoX& zJVElk4C_nb1|>aGaPv4H?6tt}Qi6zsvWKaArqROU50z5hL4kKTAmJ`rVu*ymw~YY9 zhxd>6o}~5$-R!l|#TdVzMDVsL^{{ne_XgALRO<${oKMn}5%F0ET^j9`9wC+$N#MWi zuqD|}5JnJMna$guu$me`9HNNRofP=3n}lQpg|QGj-6zv-dZx-aZ7jfsA`^;Dc@I!J z=eH=sWQw=FI}eZVpdAbohXqdjs(Zf9<})T*c1zV3eHYQUcz0(ZhN{U!w{$oCohyKf zU8PR8L@5!!l3J%MsJz{-@|9a$nxYJMt;n(A%|zf3QyCT9l(S_)3MPrxkDmpBYY1_D zztS|!sNn9|f#FP9zM;b!#=VuymoL>`C)s4kn{=)Vq5rhXq_V!u^DgB|h>@!F-v9nM z*_d}dCDEs4M3AvdP_?aQ#o;^%chZ`uwh*km)YssgSNrzh%gK0U)q(A{oA6_S8iRz! z%)7mFx)C)_TQR5kJ&%yg9vEyBHNx0M6sSx_KOgHl)lTxT0(Y%rMC7)pZ{@a~;&hNfI(9W8yn$O0w?O8|jiO))0sv*for_k& zH!ttbr`QCed)W4@mq)W}LrPZ>T!+6D6*;`d_{^oIxp~YoD^Mm>P$Gv7Lnd z45Zg)!A%ggkN(%ZdzpkV#8Exq&4FPRjpV4RIC^x-*f~|?TeSp=z!=xeRcpG1UFHxV z6@t%xx>qjoM3IjY{4VqD0~z)Id962IZ5ROa(iT1g@1LywO(D>>k@QH0gAo>x5>oy_ z!+xGK6pT@PM6YK>SC{WRnaf)u%$G?5f7Q!9zYwI{eOqUN0t1m=7X3tiZTIk+I>x!_ zb*7kU+{^BAHtC`C&hBPSup=W5i;fW?$TLZNZ6Yv~gaE~e1Zo2BHUcyO79>qrjrWg? ztAChV#-rcpV%4`HwPAOrKlga-*+ui5AKSHgqpz#c3@D44AhZo6z9!Mn7p{iX(xg+! z=sQfrgMe50)~eMg*1RjTgP9SeAjRDL_!nqJHG&R!2fib!jhcIE6S3s@p+PM>y?30a z+L{>6*7x8YDCr{AQ2|dE==LwXwlwC*lwV)i8}}&|6iXHgHC2y^$XgUYm;3pLb|uZp z-;>*?@n>;kfNJnhQJHrlfksq=C}}m_dA?m-+7pkYIPxBWzM{h?Wz`Wnuy;h4?-4-` z2~7!thK|)y=OynL)C#Wz)a!-EfnPKE0kjN5-jM z_x-pjeolrE&!FcRz~D#=_rv_HW40JlMmrG?sOhyH-Byt7rAtzqbS(M}<@+l$@L9hc zr81ZAF^?=tmm|MnN{bgkIm+kw-NuSwOy3)b;acJqUAMG+s?lESV;g#Q#QCm{qmeo= zhZ*J~c_`GgaqlvVJMuX4EBnEWNYo1(8k~Ro)}2GRLkXSqS|#)H7Lmi0xBvl0I#`WD zig>l2K(8@VKu&7lqjw8tYe}-JE+4mQBt8%)Tm~l`s(e|`d_;jLdRQa^(&C-xPt4UMUQJ#emSAG5^WtjL_gGTD2w;d^*&YY9)s z?y zhook2tA`e%nFi-f)V&Ef)`*`$+m=lV#fN}?w3O%t!%m0L-t^p|{xDDD#V)i3q!zh+ zACWi@yIW;pBM#Qnpr%^st@B&6+0a?b4T&k@Hsjy+)u+(ru8OZiHruUfH$R|2u<-OT zeRxx3CWz1G`}oqt$xnZNscX3?$!4`HIi+MyTUUk;PjpDsB#{7MgbJneq7%=Gt@ z?9DTGvk}7TMWr&-hl*@8Ij#*3`M|} zAz=CAszN>$o4R^J2NA-ov{-W zt;~7t#6Z0=b9PzWzcf~1{FR-5OKd4N`#Q1mA(B}<-+fl*fQj*)Z1td8X5C*YpG1}^ z>#JVYs81bMY3i#~l&%s#aDBYghx5+o04uF>`1et3fJ^t9{Xchai($_8qw}V$`;9(c z9KCh;RcC`?jnRh(m!%R`?L?Ifd~Xa?4lR?jR!ScRf32AFTo&5miSbKJA9=01=di!P^dRH-&Q-P5jls#3)1Ax$gGGD78C zD=<93N&}4;yO*nh`of=z;Gl20RZQVGx(Hnb2CW$kR$~N=k2?r94W6pi_Bq0x)uaZZ z9-TgaG!&~!7x7EFBs})f<_U$%&)J&o?+=T3%y_lmTRb>8AJ(Wxu!eE zD$XJ%2-4KsRye!XvnX{YD0G$D(k?rDpQkN=JS0MxH;$V+7mQ(DQs^#5BJ&@w{nk+3 zyG52FJJmNRKlE-=U}InZg0?Y#KXq!3{aeGWFlFb>Vn)~BMc9T<)}BHDl`^l)1X|R8 zoA(q^L5o&8c#B6#!L}EhnOBqluPaUePogR1c|nGt@M-?m%HU5hX67d=bKIP%mU*#u zTQt3z1%+l&@^4U6S1m=xing;O&&q$FywciE8+>2!r?;Eu06*Jx7@!?8H&cH@GUeLE z{J}G?0;PadU|3Udr$GC>L zBnv?b=n(0{X@d7Z@fYmgM(O17i`RXWbzjlFm~7v9EO)$hTKF9wO^8(lMQBq{fcy#G zv=Ib@qzf3H>VQURVNX;uJpSZl0)H`YT_DU$S4P4myh7Ea4F8sOR6K*-xVucYhZ8;}E|c>9<+k^UGX6u+!r_8; zN^9rKHt{FJKA(<?!@4afscaBpcjJLLDCVvGo{9A<(^E z!laGrU{efXN@d~P8?th6jBD9vyFv06kDRo$u*_uFrCUcu`jC@D}tXISt z8IR5P!yCixQD-w@(C8+6$FEuOzV%o^nwYPRJViwJ5Spmc48QHHG||Me6I?-Da!rIs zyu4=0G>&7>EwE0yWbsRm6a4~x`3*8Ha7{f>i+LMG+w4f%kB@mYM$-j|mbA@VZ~qMT z0}22ebQ$kHS@K~DI(hg$bOB)(n%78Ji5!TcVM7|Q+EIC{j_px+9Rh1hcnyJ)jzda{ zOO{aZewgNn*_h~`PY1UaUtzG!*AFt0;_22tl?)jzGQ-iV>+q{wVI`gg&SWXg6up z4XR({8U5Jc1r`NIpc`B6<(y67FQ6oHm7X#;JcDX59PBV^j?fyFc zX~N^J<3NHT7G4iSthJmzj}4>B3zMQ4mo6Ed`|S9uh&qCDlBhy!O<)rG5I!=Ez5H; zLsriImQa&VCB<~WS>KbO)~wRXY?Fygx}+l(;#q6Cj5qC~0ND7Sodq|6ZTXT-mut;* zlRj;{Tl?LZ?JOKlKk&UH743t!MR1>Z`K#lbG;0Q|h$2)J*wlt>MRsD5Kb7rRkE;Jb zSIGQ=3`NO`}P)r4{-_1@v zT|T|Yyo~ey?MURm;|&3%L?Xw#&S@OGDBA6s5E3J3SS(7TMSi=9>`n6A_9Ym?;-mM( z>7f*$n&h3MGOEHvpuK?TFCU9U4z#qwA31hq59_zMB3AvYl~DE#e}jV__RLE<{q}be zgIQ`$d)&oxE(g7Lj`jQs!(KT(+fbLn1MF&+`bGaWq+!paYu9Lj17hD3a$@fHO5$Id zC*-pawfSrSNxvPVNg_q1RpF0o7_2T6Pr>&%01zOX>~LsQJyp|B?!%ae!i=DZfY1mu zY8Z@yb48%D_|l@dNEVcw_ z?ZPiau=KHFIM3UoIKtw*+hfODu9-kD*6T4(&Y`(0F;^Yx;=Y0z0uZT1QoH|@H(I*u zF3H^sdeV=0P%w6cf+o5S%pw@fg~$8xb0+L5u}CIW>5>NJd)IDOFG%=#)Vi^(;_zy1Zv5%@5TO98mD88#Ag+_Vlmzv9 z&0QzchC-)&7yRCfUdUWT5ky=Z)@B=y2j{8_#Fs{u{=vPBWAVFo z%PhbWVyH7N48_DV1qgu082R#O^pkn9W_2o9goi!X*!i$SS>4lP|^M z{mozqlD4Bk7V_+MrseSWJ(t<|+4&sEFkK6V!BPZKvUS_9&nIU@VG+rZi30-*=qIaz zJ~q;7s8XIG&9gJZQL>#)KBG~xeRUh-o8XC8HQ6*eZ1~*H6vV3z6bNVM_`gBrV!Qg# z2tHeO*NOkV+#UX}VFfPtG8py}MfzKyZ0;%tA4RW?>d8MmxJC!_$8ixcA(UT!yb(DM zo7hZyb1}=f`-9wkBES<+ko~KM8fcn_ zSiopv?eNSs>G;9|jJdBv~`IvDI)T^>bD4>4~d=A(rifH|%AIX_q{$abhU zdR$uoh!IKbhx@qH;+bqOy_bnwWBf$FsN*zEH~4uy+aE(dpLIuxs{f0153mk2&qTcC z6nmmkYfVgluqWR=bYcr$;Pfi-pkvgPTuc3Wedn}1_g0wFwtD~eW#!wo!*PF~fEFDp zBiZ`$ey=<49rui81MT_hYyPSndP_ANs1~>je|581*iW+X&Yxor>JV= zyD3o@4;wlZXM+VF+gE(kujyuCis}@32Q$Xg=5&bWfQEz`36dV4Chp+I&_`1}vmOSc*NJ*+ zEffDQY%s9c4`qt#sc1@-pGTiLkKHK|tIYoE_-358%Ry0MQ8uIcFJJ%K#_XKI5z}#F zN!2*b)Vq_^_iOdDelnIGzQ#MYZs=*;sbp*Lz;Ky|26-E!swIp$uoNGNDgkMj+@muuWf2xYp?@+Klcg z;U*&+L`u@-M5(~wWT?R(5*ijRO%zB0oMbA85JMmc!qD(=B^1*+J^d3E&|k7Ip2%JI zG#@Dt%c4iUO~2UiW#dNY(OuFE4UMW77gZJQ!?!mv@%hwz%lrgrsd)q#fdw7F62bq{ z@@G+lg4giAVa9n?&6MF-`IN?If<9ox2N=AU&|97e%mUGh5@^>ItE3;$MHONORMipz zR1Gjt{y_~SA7wTt2Mt;g9|Y~EQZF=(^J$RqaepcX$TUttNpUgmn|DTDD27gcvnqbd zY_D*(mDG)O{fKTdmb9SM!s>F3^8AJ+>r|F}~%yIOe7znSyrt;Zyb^&d=l zSKr6o?=pc$4+PtFQIh=KU+-%_Fz1xNw0t@5YCi3pr(-(i-l*N^FVrggfb)IynxA&S zqXG${huCqAiAk$X#tgxkgNk1pa6&|o^ax|(AH4Ps>AGkRMOmk_AlwH2G+R8mFcY6c zq1NiQr7Y85((lp7$vM&S-nx*_6Yve)BJj=K_Wx-xrRY4|cpt*LkzmY8%a1YWYN(62 znthKTlDWg$#S#)G+^nRmjf!YwKDMR5#9+UReuyTX%S6vlihgZMRY1D3U|!-X$3F&X zxl4_a3<0+-jr*;1Wjx0QNo z$d60|6zJqnjrlot)3*K;v3#FG|C&wAdCSY`S z{^tB5$ov(l^NE_?eQeSB{XbHCDwHHnyx25ix?=f5v0xtWC*qDIh5!#%6k!5Bg znJP~Rp4HD2=B6NaM9ojqqb3G!wv#RAgUG=GEJ8coX&{#khV%zw{!w9Yw#U7 zL}7C8TGxYZ!KVCY-M;sOgZp<1EQ6j$8GO;Gr>8U|p$)^77^t1a4*YnfS0D{l)E2k+ z2*ho&cM!d!_5YOOAKd%?`)-zB|7l&Cso*Ms@cq4W-7gC+>%?vhfeIm(ubv9i^gXO` zxeik}E8j_=bnb^U3qJmH3?KYenqai%9shQLf+a*FdgP}R`1I9) zp}5O9VIq(qX003JXiXj_xyH?EDjieAIfh0Zf)6kvEJIxhB%unMEkzgY#UPkpnfEo1qukP{ zL?YwCy#ZGq<4)Ja8}WO3;yAzJvq6KnHy|;^jYP{jv&Q&3loT`|IZR=Q@j#ZozrSh&G81Eak0FrYctK zJoOvc0m%*TF;NHuGUs9Vv`e|l@-c3-uR;a&EQ_6RmonAsLHK%{zC>e60CtFI#MLX=6MVrFEwK+1ttO zgvOxlxizZwf4FG<`)az?)-WCrTMl|GgMq&AYt?#|F+!2_7i8HQsKiv#B0(HU5t&nc zCskUJP!Z9Bm|C!*4iahQ{Elj7aX;N_edPJwzQJN;WUvraNn=@Y$zc!O#}S8S4l-Yr z+>A{EX6}&do0{VYJN^Hg5WW82!RNRC8+@M10HgT`sk-eptHthxf&}ygVqFRJiH`aL z4r^jF&vMhPVf6?rG05g-Zh@3T<`5J6EcU2MP@0$)3M5K=Ma+BnsNgjj$SQ-wyyG-v zCSij=bf0KL+nN7-NoO`e_9)P|AYR6MWb_R<3<`B4M6oT;4;=G(5+m5QPjc^)_%{m} z9CZ_TgOt!L3K^o2P`I9*V}K5f(CWe=Es7bf3}S&;C^-Tu9v9)WgV6_SMAwOM*d~(3 zjn__o+Hxb#ECh#kt5qym-j|*x3-@^D%FN=?%zE+pnL8oFlOX~v*pD@OjtRImzx3lc zOKG-(4>>mUkd|ONoledO^p1RPa9&DkCjIs0Qrz#1lWa;sm*{UZ6E7_`vDolvf3jMU zuZ^~XTAA}T^yU*E?By~^8MnKxoPJ5l1zc^hAXLtSEMm1}xX1H41vEv6^ajIZ!kV2W zQ|Er<9*d3@Q%^ryk@DqD)BYv5-%E}tQ?XAi-T1FU(1Cpi1dI(eUk065kLNz2Eu81^ zsUD$^!`4zCl`teeRS_sdrH_+NNow`{S_mj%Zan~o>&e!KB5XQB3v2#6cT0$hIExsd zQv)9Z)0C0s=3Jf`6T3_LggdG1LU2ho`WgE7QP630|07}Qupwit#50Z+C(jfvO zAt^`4mmx2XML{ z@Ua*NlbtSnxF;^RRSGA3cOn>LCBx=OMO6OeKBUDm^W%KpzvrRn`Tq*xz!Wu*Eykxbyc$i}~0{}5gA<^K^~d(@Nh7(`d<$j$MjpcZ#eNg^71F&)Q} zLXJq98sRyKp=>f7{nsz_Pm|1*li2h_NhjXrq zd}~CuO(qL6c@p9(#Hodu#QHbGN&>qvfxqum9T1#@@8=&80mT z1--{MXP9F$N851yVpy z-OYTW#_+{_OXKq*Ekpy?Mg_(MZf+WUgrueje=x`t2UNB3kNK3ql9UKK;?#70U;iO* z@gsjE0&kc zlFS`(e8C;s&cQ^Aiq5kD6uFrJ8LE=pc%5&h=^f2Uk0|#=f4HZ7o7(565t0*CJ`bo* z2^W@_7GFI%70|z@AT#rE;m|jQ=kMZ>m)49us=s&9-;ywmfjZbMlX3BQlQT#!_=O~|F#UX9tO$E{ugmoeo~ zQ)YJU4%ThqK!I`Sfo9C0e{c0`b{e$gp1{-wrSpSyu&+qU#K#x1V}yV^x-(D~?qb;w zP?gpy=@xhy^?5E3Ge2P*h}lh~wqiAfdc4zZ7&fT=o74V>gemU(5p+p4*;!NLvyg`N8cMzB7*|NYGH8I68kIRX7H=K;O2P#Y%UFK1xUlDGPN*Gnmvu=CzPVl zAN`kMFOuYK5iS_q6@XdjeLi`3pCqB}E(;v)+auES#Ou!kDkNDlCFI3SCJ+PZckYD5-~Y}gM$+m^2}a@aXYhDd zLmGqXd93kQ@49&5=Mc%$qVIGN7?WItQ0bpLNGKcGLJfMD*vx;swbdm@g3;0xfQkJf zT^u9cxqtBjG?Tq4Bvk`E8Jbw_%>C>5myJYIy}r}y$ys?9W^&xh)mOI^C!xW4s2DTt z$s!yd1ZZl5>DVrU7yA{Lwe6F67129OUEm&%nc8z91rzP2?SCyD|JX3e9id?oM%Ca{ z|5>8w9AL+v{237UVMI^U5 zbovgJWK$VPnGP;dZ%udSjqX)PMUn6i+Ayd$|!$U7y(fTP_V4HIHs zrwve8%qo5%c$R5+5dFcnwwphY-29qC3S>@p8sOsVj&1VzgHlNS(4#s!VEdBf0442P zag-4*#qS=mB(zed#JMGz&PjjGcnUw|@fDf~Y0?91ccihS9X7!eBoi@f49H?_z{jzL ze}2F)gfNoz3(zLN0`V6qSe42Y6SwhDtJEmC=xV)<2s`Y40{9*yNhaJ9OY%t?Y$7lIJh_-kyV2b^lbm~b z@*@*?AXfiWgNCpxz5 zupxy^FltXdRdunswl{%arllyTST1}Y zps^>QFtR<8cFR*%P-#FGZV(>Xpoe0LUUg(sw=#PFi@3fNc zt9xeuDnWs79Nb~PmhnY2e~-%R@41794CBvU&sP4Q=EE<)bjv9cl6T>`3bue(($kNN z8DOoSynfqJTbx~?hTmfx270NS93}{Q_*(rKKeG!xITt<@>i{T_I9w0OAvSnhaQ7m9C=9mj&+z##OLh)o5HX2B^zA2@Ab;Gu zBt@ukl#8Ir{aOvAPBz$(1NI^#=9HpO2H13Aitl#5-*VjkW9tbpZ2iyFGikX{d3)iO zOL@D~xLaWy)x)F7xhDzhLZnv}4Ph3p-@j4vM((Dnjdo^}X7Kt8+8e5l_5v$8thSOW znm9~Khu$`k1RxnInmkU;N!*&&3+3PQTPA6KZ5`K5D9G1)5cZx040eFh9ke{+bI2xa zQc%W395%1OSk|AS6o;3%GcBHuZ}Q`4k(1yK|ZRVnBGrv|?qO1$jnL z0tRY7k0)IIY)#ILJ-b$|w`tzv4k>P;?dFSei!Ni;z*83p!Vgjmw$SgZ{PWl>4bfTo z)w(%%S{jx{Z2Fw36uU~GA2pZdZhS8hQ81CM?D=4o>Nj?gMcz21x_!E6T^xDM@$!`( zzgMN*d-aEt*7EFDl?pj~LlhmBt3nZ*a(_r=F=Km)B%=hby_0GZG^anif;FsVix1m5 zfjr4GeLLb&&vcSi3fhq{E2}~Pp=jTMK;$zxOrx20D&j6OaiY}ZP;`T#%=P#CD}&Ve z9+rgSz(Aw6Tym&bNtVYzM=?x^3qdG{;$&oZPi&cO1e`kJ%N6}R>;4sr3C30uB;0qT zrxs6r*Dx0853b_YJntt!yGmbIC07U4O^_Bb#!B`xDy0*dN9Uvpl}r=&=1WZRkBhYc zONv<>1`+=(>Nu58SzwO$+W;j7f+)a|Z|;_Wv*Z9njs(?1ve7}4VUy?tlc9B~PECog zS%4ON@{bmLo-g}lSF$2i=IYP>k*N1R{l=eP=Vhf^@riVt5CqW-mna|Wj>+Hk1QFs` zY9TUOz8X3QRecN#a>V>^ZD6`gua+t9;|I@P`FnEClNQf@s;hx;GI9C|HGNCCz-7vy z0)?1_KQ`vvWG+f$CyDaCYa%%MQJZ>cfuu#i(EjWBXN@oBZ?`IID0PBgp85TcN#J#{ z^mICxGWe;eb1*RnDX1iL>e%cqzOO9y^S2#oRrtvlSymrJ@14@ivjD8QT?nxKO+fBV z^R||3VHtQpx0=JH(lGk4Y{b=l=0;(9V@;BhpNK0}_GsSxlyEKB)dPCh>~-_Wmlc8! z2nPYl80kUo+jj|%P*O-8m28c7v6gmHzHjsBxi2(rItjYw0CSKG|3kd z3T$NT@iRf}Z;Rk;%=8A&BQLfnh<)VVD&=p`%0zOfFtK-y5}SU4>kxcpD%d&A$Q7q7 zIAU}Pn}8J2@|ZVRep)zZ%SjfPTJ!aYXck>5MXCYrb4 z{HM&#)#FP$*%UUE{(A25Q2N9CaTl7nB=RNk$;{(yRDb|SD*+4 zmLL3^8u9u^ILjI6v-%(gCX{MW18~Ijv0e?U))vUbc8#Hl;+N5@dM8d3$i}lB&tb&- zb_I#)(*1VAt*hA}5vkY^c_)?U7x@=CDL?`F0SVU4sqX=j`A_fXVw&vMiUUu21smq= z0G#Lk7zD^RGz`A*%{yiH@>}yFC#0o`tXN(js@NGzLU2Fa5ydJ`d+jos;{2*4>L$Ye z(*LO4z9!nJswUZc;gnI6(R*95##0(^10yxS3+4+3-NBQ~+A_4mtk;=$;Icv}xDH?N z;E6p^5#PZQ;N!(^wy04}=K|xEO58|vaRt3+M5mWH}wH1Rfk$>3Zr zPq-_#d%Z>S^2KYlfj3SUWUD2H4bGmz{XCllx{UZGUmG1XEW>S{qF?0p5-Jpl-9HuP z+63i$3g>^jc1R-gq5@dJw%Kpc)Q z^Y0$3bM35h;`ue-R$9{-wsn}1A$c3^%G`2g7HhCWH=*OW#GeINFj{ly@PT@lCw(j1 zp6{fls)!Z7#c|dqd!`qb2-8m6xlV7)nE?%Olsxm1;tWnTf#tr!O=L9W7OBF>RiD|P zRzj0nh-GL1WLY@T`HGl0%zD&C>YMWqYP|=XTlj;#qEa?cSw%K?GYSf`RfbJ2z=CE2 zvR(@@RNk%tN@+$1M%lHT9xlP1zCmO7d|jHh0;3;yNTbrZU4D?X@o?-o7K)lidLns; z5k?PM&RKyafK24$KfnnlC=_3j0;<@%m>BgmKsjJMLJpFdWBY?L{B3rih36TZ5dF&F zt|UVggEES4oZiC#hU56goWt|bL<>AT^j^+s<(FUa4jCd*a(eLH^&M(f{QBl#thjOzjU|J>dqOz93uv z9bn0(9;A$(pK^KEllThfxcc`HFLm<_{YicEiLWB4;P=9|0 z&~Q&7{(;C;L}cJZ+(=MMb9yFcmbAPVm}1RRf5X21RA)m=_Pq?ZX5~mxvUJo{#?BDR7}I? zO7%vqTBtZ@IkKF}LMm5y^Xjd6NFByLXn%2k&%b!stVex$Y?7VUJtf{y2R& z^K@u@46y7@9eNUlxxv+0x51OgezTvD%pL%aXKJ1qod=lcCL!ke;}-KUH56=Dk{v5T-flOdQ-bKLSsr+m1H%^ci_tjgVQq{dv zN! zWK$%-Ch8mmG4|(x$3%#cxk$km^CD8vtWW6014oG2et7Kj-Ma6Fku)BJ6)v%&S}5lV zJ?qnm;rB!5%Mxoeh)V>*vf)quX#CCJh&TQEL;M$4_k3)QeE=K%&cArO$IRX?+KW5Z5v22!`ratz2)Zkb`t1Jr!W1vT1 zW^fGhF)GWK#Uxrlo|u@+4TSKZ8v*<~N@i`gu*Ek!N5^~OX$@P4M(W1S7mgn9oOrH9 zWe?9cNQ@;Cr}9Dr5KxK`d<@+h5%Zxy6Vlg4-ko74O~P;5kGNZq!2*K@Wgmt;Ml&YF zZoXaqMHp;Gm)iJZn1+v<6}&=IQD{3h1bFNcbm`Ma^hhT!T@NqZ5Ax=5RWqbc8js4` zDw645v%gp9WEYEX(Ur~i9rSsuKk8zkOXy|OpY-5U>7$0lpS3ur@>)aP&5W>6wRY_?aSK}wJ-~m5iutI`5@(E`ef$#*IS;AY zfjAgXh{yd#V>3_kJ$Q~R_Dl7$HTb;Z?H<}4tDA{_)Sr1INXrwgBM_4}t20xkGsR~0 z#&kgiPm*w%|99SeK7CNx=#aOh;ZQmE!Scssjor0F--`1(b#HAylgN&#@YNvKQ>AjN zz*QPI$!S~Si-4gR#A_F;ierK?WVYgocD`@r%&SWc0@mi7GJ3N&&zKbDtt8j?U)_*I zzf5!}BN#zd(&bRQd(%eC*%A;1r*0@EY<<=N!9=!jM^sR;b@L9878}Um-7U-nK@m{z zA3GfGF}33BIz~Pk2LP^Yd^@#slR1mFrH^ZlyV4_VMwH1Mujw87gnNMDk=?6^Wz6Xl zz9<=D?i$iBUPcEu^f#!Bzgv5{ z(<-Biy{2<2UMYO7{0n7@A@F);jV1ydngrD+TwnEV7l6UY z^lX+V_meol^x8(*JN$L zboTJg9}NA&KU;Cc^yT_p_+?Cfk(@hTOQKIBf}rMlps%Hvk%O{6v}&7X@67AhzS4^; zB=R{_u8%@WRWJR!H2Bwbwmv zHT9l1ikVL4vTn+KfZ@LSx8q*Pxb9j%ru$`^p!2!{(4kq?jq>{~!@Jx`?{JXTWSkdns|$&QjiphuOnI9v8#B^YxC zg<+57_J3flDt?t^{ow4&vOx$0#&A2VL=?Ag^E1EU zE;{W=i>}aHEX%C;H;Jbnx?`FZ?Ygg$>ST&|pE|x~hd_|n3TPtY>;slp{}0ib07VXU zi-R0Gc7V(OJ;D`kzXohid$I!l8cx86*E9(b*Xj_s)53b4+=|Itq^dY1f%PtGv&ie$ z=!Tc^fv1GN%%`Am!`BN|L#pEG zQI1pmJQKsHpP+#ILrOfTSFF#7suYL(;M}-G&oGLgvIQF6FnM-=?`5)xr0Rj*|Hd&$ zOeJT=`Dt9TiTr1??FNZEComMUQ_ zOcp+uIMyhT$*>0%yV`esvPt^a-{n4(F@|?pe>J)da@l(zS*(TCT_Km3GzPsCH>r2|@Ls`0k!!M#pW=|Li-;EG~6*=qwtlEoj;b;r*lR(Be zS#NdUubAz7)c(X+*6pd2>mHpPObGc!kOAQ+UNQTy7WZ2^7j+tOHdA8#7LLY{mBVeG zZYq5YqQ`pBi^3*|;o11Taa@?+l%&zoI5tyLI(9*3^*jeUNbbJsG7-hGBQdqwd`-XR z0(c_KUEEW1lTp8`S`M49(ds4xrydKAW`*FDGA7$Uz6oo=s6+_}0Y-xd+a13Z6$FRK zJ5>Wm*bZ+Tas`rO#cRfNBzZ?W9`xPzZ6U}BI@@XE!h?sg36iW+k8T*_w9w?XuDU8A z_#My0@nwLT(6E71z*5_ZEWqzyviiZ4dzbe+F?kiWIJg65m%)V>14NnRracpr0Uh zxLAxB4CWMC$RN;}(tHV+ZuabtD3ZX@To?1Vk z&+i3C-kB{jgk%C#kRo)&N@Lj_jSE;B7x&NZQ@nUdq(X|o948uO0l?PpnMAFbfp%Z{ zHohk3`NxXn!lOH-PZuA1Nz#Aj4KbCM&Y?Q5zt4cr)xLD0mJhf zEkAhv7XNMKit(RO_IF!*UNRCFezV@ngJV^aZ`YN4E=+H9&TIS~11>N;{;f zOtY!`@Pj^jgVzj3kAz8Um#k#?BxD$$TW*>9l}r@zQ>Mkg#95_m6{21LP~fC~c+}G6 zA;dU&Y*s9*H64e36&6=RV(HEYvd@x0iz=VKru335AjhU{@5MDzFm0ZwO_44|7s^>b2V7?oFF)61 z{L6!^{KtbZtP@HdH`n81h4PC$gKez!B+Vj$pi?u6owaD#mXqV9s17E1qKxC}1_qQ? zL*&y?Fz*UoC6s~G)4aty&UPEKBoN3V!Hf3-l-{j3qOMH{F{OeErg_BMZfl5Ne35qf zFGVMo4dP%$eA@l9|D2BJ}Xf@{^+oY|S3jM1KlhDXdeNnigS{J`yX zODP>wD73ke=6!Z^#KP!}EU*V{BTSwUBf);Sp>4}V#IwW*RYPK)?zQHoVOwe371)#n zQPo2wS+FVMhEYiBHC4v*2S{pOz3&k_qhmNN&637L98$?@HsE)(F^H%Mkz`MtP+&74 z&0S`8#F#5M3&>BJ1CopZ4X2snf7{vo#a2sh( z=8$uI7G+6!<^_(e056prjnv;*j|EY=*LkS7XesTHQvVjZ@4S;sys=8tDrCCnYo|~U z4wfac;#L%E6Lfa+J5jvWI5wSvNZXoeXwz-^16mH8 zT)L-<()`<^tal`Yc)$ki-t&_2Os>Ub&IV&X6?yVLo9i%wl%zLbj5x?V#6Q!7(bVs# z`j`fl>bQ0|xrI9sa_ozu26BB+vxBiH4j@7SQ!ThZPB0DyqEAR<_ArHwUlk!7>&_)B zA`Emus<-l*Z@21WxWylL<5)r7zLWp0g=#a-x+Q#g-l9aKjgNLtPZnAT zFhQtNTVu?L*;6>!K{WJghCbh-8a9UsHW&HJ*g1hPX#+9N3q=4k_^S1p%hs!PJ7C!LHcI;|A9*+73-@b>W^69?fc z9M+-A%9!n*OFR*^GavVzJqSg<<7M!rF!VrVid%TE}jfj@p{98&fN;VGR^XJuUg=VqGXQ`SwWR)o ztd=LjdI=@Iw0hj91)I`NUHkk=%B$nvF23=D!Q17-$RcQ&DR-AD?m?_h99|(uvmWW# z@bkmm=xtBZ0ghkyeo6>2JOb3H8un7DiYZ|9^6pe886{l3{(Fst(AeT*UyoB&=r6#lk`x`7iHg0t#G^zh6`bySdcGctVL7>kXV~)ygZ{z0Me|Ww> z|K96tKmM2BW;i4KA#`TIRm{_9$h-HPOD_oS#0mAduyeyZ0t3B9C4+oBP|D zS!biUp!JDU06S|dkdsL^j1hrVba2+0tI@f(I~sMI*kXR=7+hrFm(<_aUk<}y92RkM8BZZM}b8f=d*Rr=SJd%u()2b?xWsbUQMQ56mWgy>oj1?js9$ zvt5f^ud4km*C-SfTVp+6_DPl9l{5^KBnxOwROjKVpVV^JRXsF)>?UCRJy=cOO(5r8 zf-xZr63HMU#6d!S6oN#$3gw;>bZBmSG|!BXdtSaZv^-eHU*?-4p}6Zeq==d(I>2;_ zth#)e4Ge+-fQ!Gfv;%g|S2=n_mt8gvT7AFyk^s>W=nQ|7rLKU2@q^_u=2<_n^vI8K zroOxLf;K@RVWOKnVX#}UQzt%t%?5I8ryv>NU6zqt=F@w9Q1`vo79GK|641j+lMS8l zRV`0`krgI%im**yO+86OF+yzVkl}TjkvaCiHC4wUUBQ#yI)_ef11bdam52t0*H4WQ z-FlQghP*FS{+?nslzz}lTmfuybsfc&HSNW(9AG`lg$3cqDutrF!h7BxabMY({Ii|q z4<%`=tN-zY3Z>0#N0VHz06hg2LhLdgcG9*&RXoyw0|Cu9Y{^9a%AwTA7LqZE?H=+%|51KqTG7Oy?L29#b#@{R!^712>1 z)Zhv6-(FvK_a4JFkF^kRJZ{XGAkRuH_@OHKdsC*tv#)tviYVY5Fdn&8GV@LVK4RRWrpOe#s6?-UHT_8m z6kQcRbf^{bOE;0~ zB1ab$h1K!y?g;=u-nL>0#!v~>FW^_b1b%3*tfNcvY+pP3)%R$4&SSLY>iab&B4qlZ z)G96TQ^AC7U_12YuD-7gq_k)#$(VCC#YkTnpjgJ?YfQ8+pj|Pm=kGimxJ^>h!yU~6 z5^jvh6IhZ5F7w4!BmSJTRyI1&SBv8Jt+d6u+i|$ku6h;Gww=?U=C&}?kF^rlfJI(q z9rEPZZ14Nyc+VDD;N%kO8-B4o`0z4XzkuhJeC-Ig_Nze&Uy3#dh`heTNAzm~O?TQT z#XkJI)Y zYyxH=u*_+jS58ndQ1S|lyKt=`;44aH5zoIn5puwn^IxxEhK!oDG7&8GQzX4j7uM|IKcn3{R z$Y&8$GO127UmxeaAnPqvzq0)1V)8f3#&zRS#Zk}s*RpD&zok?t2(uyaduP4`R(p)9 zTsLH$;ePdlVeaRCd>Sy@V*~ zN*Vzb+KGJF@zg8c0qY0B3d9AzTU}ObF;{pCv=gSZKE%=5C{=qxAzN2CzmB&JsPP<% zJYt!!z5xwGa_t>M+Nm*@`gf1S9}z$f!*G7_hMG{Zi7t^tVhFDk!XE6{Oqr4I?dTVb zc`EUum)z(t$QN16^x3JIS*Lh=6%`toZxlmUYo?N3*@>_?5qwG3m<9IUUP%dP{cK-7 zB4xnVy>`9;r1oV;Ue{WaIr`PZ=Ek3e^dBpZe~6k9J6Q7uc{XjZ9y;?xw$(C=-YF0s zR8u0DsovBtdo649{ELFK{>K;Xdr*D*isjRe7CJl2s965arCKJ@dXs^0ThZPR(|Y!D zta&_GoE*bnqvPJpR4iqG@?s)`-=`aj)s7YPCpLLtCh-)bmV5^HvtLpokuUH#MO9H# zHaOOJ6XZ^yw%ty_&ugD|!d2*fIFG$g4019{0-Odkl>W|WB?8DgVM}g-rj(*8TuRZu z*IE#bik*0&;N~T31t0fEA{FI3cU7#E@(3G=4}r$WXy6L7>aL`oLxM{jo*o|d#Fer0 zvDW-(gxj+t5Nj?EhgkA!Btyl|5`>tC)cwz-$(?Gh6s>xfT5A*Y?>XTHksYMrsfGD|I9qUt>6C_i3k_|UnH`vhFb5PKeG0X$|^UTt`5VP++JK5KGEtZe7Qhx z8C%2+HG3i^SF2|?EJ6%!p4-R#`c*m-_u?0&@D51?i-0;`4Gfk;b>K=|46DnuBLID> z<6{}%B+`D|i1U69jiRE0+yy75KlN*7U%_bg%#5RAL@-KFZ* z#hLoHUoZ??5IMki=mh(#m$fGrhXF z{9B(en1S_8E?7WgKx{?gZ4l8@>()uZNW_Z0j)B$9> zOrn#_=cbur;>^Y90Fr7N{sLBd&a0*YyT~C>Xq0_`<&G4l|Q$ra> zK4u!Wb(kEe$Sux=a{fJ?mC*Kg28datr&Ydm8x+=-zN#Q6^Kkv!=9iacNz4zB(H0(Z zQ!RCU6-jLrLX>~!&AC?H#p!}rhYHXQU*i?W!DTG#Hfh91&99)wWN^)o(jsjE*I<6! zR|=xf)A+haS?Jy!t;+7Y7lGSg;~+?*Ts_pKk1_R-UMDTSlvrE2ega-Z#acy)^7K{l6@IR7ce|1wf^5tZmVqhrKq;&X`)2fBC2nbFqmRw$H6 z)-Tj~7}^SsfAE9cWo2!dpyf~k*5)U~_JVQJ(fyp*Ho~*2>P%Dtg<5BrI7xzsbLbzW zjMiu`*RoU9r0HpGDnAJW=EL&e)y}KEjrZ%%C2vBVFVq^Dr=9|d>xEAa|8tiYn*DFT z?rD>y$@N#t(LK+-%aQWj)JT>`6u0sM>b^oHO76|ou{I$I!HXDtT*!JGDccmWC-w9_E-*D1 z&SvasO}K4-A`iU)OO)BkilNW}0=lyZn=wcRz>p7y11I@rSS418ifg#V^o3L^PwmJqf? zE^l`+Qu56u<@cib^3tDw?rc~i;`f)qBq9yb)M!;)Vm6o1Sqf(P#*v2m3pa#=N=)<> z-)P->)!i=MMkw(Gyi{^eF2_^MWLOig*2on9$^OK3HEJ&Oyo!afZ>_ZqJ7+Qcy4bN5 zM~9#buWum;KO55Ekr0oEk2@uLr1-YiaxndFj#Gl;5jkTdiOm|Gr{X%rM?>C^Zp~C5 z<8N)x-HQit?|eD@?f9FFhK#KlaafaM(Gv*>RjJi!Whn@Sr)@Rq`>hYNk|~D#N?W_J zmkb)-F_w|;X^dTm=MoD_0Rq4oCsbBao-?;K-(dTJmCJy6uPU4maEZ3X;rNy77H)kFLyM8&-3DnCQKZoS2QTbyIvrd;Yaa zNB?J&hNk(tW<^!-Bx^BVU;0SaU*$K>?A&@F5Oqdig9#-uD?uf>J6>9GqOKKx=m15M z7Pq27f5a2e6M@2vA5ZWmVO=7Dv@E=2egd0-0ExtCz4iG(+zI;lG+Y5S^j!8X81qPi z*NL8A-T?E3nP~DZI%zZeb{%L-H$L~H9Pv3=ICp+#j5mKex~IgM9NgQDW zUx*R{eG-y?%cq*y;I1w~f}0Fdi_nlZiCLHr!XPzhr{Xdnf@1y@P8ARb3t=4Zg5|0; zIHLsF1+Q3X^MlPJI7HpZE+-=%$Vue6f5`fbWU0%&uDTAcAYVibaP{_tmoX)3 zox(#hyx4yrRBFaG7lwvCtDZ5EoO9)+d{H~DT_f-zac`&EP;(Rs;hOeNBtZ~RF8BKG zgrALqy}P9`h}Zdhh|jR&B}`{n|dVODwew%myZ-;x7App1x&99%c&sLufac zt@gIbo6U^;c&L4S+wpHN(L~vQqQ|J^LhIhG%|C?!+LgD% z=N=L)9Ij$63<*mE^Uq z@qxj`gFSZ{ZuQF1`OAKl(MGQ;XOG3V)~xvq5*ZdqA`Gj&i7K@_^bu6utOvfkR5tBE z500}VNOz_+gTqh17T&!UdWAakMG&`diIcsl+zfh&d(wdE)-HXf+u-;G(-iL?D1vzc z)27k?xnF9*@Y!(eJ-4ScP~6|kr2YtHG&$y9AvhIIKD}7Iazkl>$v+_(`n_~<1PeT( zj)G4W7waZ`1$qH|RdJE;6HxLFYmA9gBPj#dKP&%on6-MM-?W8l>?ZBwv>;FFd|V`5 z6qRQ*H@aA}SPI|-sD|VFgYC=cov^aXS3Ix3NMzayTR6Ux#L2BIDagT1(fiH;V2KA- z;c^nlY}!PfQ?iH4MZKNjz@HQb#hA`rhZw_($x5L4VhB@OEY^+i+Dp#W6l=RlFB! z0V7V3A+}RvppIKJsk2ke;=fs{xny8Jnb8op#w(24LYLw$`$`m2I?3?6l(A?nObKqp z@T_{RURlHOg#EF=q9Pj&n`$|_EU>RQu9KJz4one9XuS#7{jhrc0vfAK064M5txJAc zRyvopyY@Ia(u8%>Iig>|60NPIi-iT4%MgcKemzLWxUZ@r#$;9F7?XxKG7Uxv@xxp! zNTcaVjR@lahfP%_)X{*_XAkp#n*p|7{%>Dm{e~bP8pVgGHK{YycVyP#fKA~)z*(2G zt$x01ZciR&D4tpOk^Iqb-b0*-A}7mo3)-POe0A)SX0HGz(J`gupV#;S6TEM93D1=Q z)v&!dqsU_anHm)_o`vb2zTXgEJ*W=yJdhT+mW#lOfyXtksqUBFu>;eA{V?)IdQ3=1 z2Yf86mbf)0$CqVc%=$vI-xw$oQym$L;aI1;e{&*xZ00c^FOwvrJB_~SUDNuScGD%j zA?x{PAdmM(u-Lip+|>hNK=U5LOeCZ#sl9pwLPOPi(6z=!D%8la>)6=hwGX;bKT|f` z^lWe%f-n-MsetZ-J3xenG3n9^rM$e2#Jeieuc?mA>847)a$rPUhZllbaep3__C3Uo zNR(ylyB+JdpRZn(ex`l2-in4d6fPz>CRF2?P!35x;-fQj3!7dcZ)t>LUndO&LZcB! z_VTD7m3MNZCtT(oFK1u-E?9M&uB1D^xsB1yJYhQ>YZVu;kckn|P0OpG2oOyIvOt)8~NoK{wP$Gb|AV7Z6Av3@5}6L-}b zyS5&OzZ233j`P3hETr6|WTzOe*zVl3=0?3BSnnJQWlTpt7^}?1QF*Y6&p<}w(ZDg7 zJRqfolz!1Mv+F+ZGB6q}(BUe1@J}Vc!n_;(`8zga@`QZpwb~hrzLl}4H3svM>^|$k z)7Ui^1CP$lEkAvp2Bsv<1HOs7sV`mEKBA9pNoQ~eOLW=_>4o|1Izlz>}{$s6%^KU z-N5GSiA~Mi!3P~u#JkP!w>JU*v#Fw(!0`W`@nS|Gr>f(qD1NOC=DUPrReI}pdJ8rk z6uaZ=faZSJ@s%0%1UogLzpqRVrDz0$@f#2Fb%SOqV)5D_Z-CRAT9%>32nLI6n@OP< z)181fz4@lFHVpNB6gLhx+WoDzDs{^q;?8+=rIEVNzBKPw`wxH1kNUKgQVqs-SBxPpA)Pcpr1EHr3CLbvRaX!x^8@!0E;aL)l!y)slUv~P| zkM>^}yZ33w$#`t@o|X^3agE1>Z{p;XhVf{X4xTXeX3fc&e|o=(A^)5GTTBlFeP=j@zQ8B{x9jmOs1Vy2 zr)JB+&1!Rjs0F4y{KI#Q=jqc??n&n}elKxqY2A%p?&{~ybD;7eD&Bm zZ!&=lzWO0f<|xwDb$9VEfT<8x09XHujEgV(<=D6NMv%?6X<7*tAU?HZBd7Q`3Bb(N zlj=K^RpI~Wgyq%eWuAle8@)Fh5no$L>VNltuabb;WB+B8ssCe?MZIwC5-~%(n*AEN z7fIQq@9s23n{k9qtD+W$lX>+M%+kT85D12#3FiD$^ePnwoBRpO@jd<1_zoYAVKE_1 zPlzG&N5x$66{|MNYtqlF8=s1=jgoj@Gg`NOJ6d9_7rh>`myYEzGFBMYOLn0+)v|DCLZP0c%!URXfxcRuFzYB7g4Dm|B)N2|gXlD`8YQ7!`Qv(> zznh9}l9~5A?-U}dO@?V#oA%{t#G7&MPL4M9+LT>@BR(mYEzTpzA^)@{(kisnbZwsr zZLFL~8~3n#KxZ)5HZ;q#{rqwKZBSI^Wmxt4DATA!+?QaA(D~`kS7wv@>D;c5ms<-j zDjt(esZ{y%h6CAFQy`B4RB0$ILvbJZ-_=#1A*~(JYkmAT9XqE}z4~p}^Jy7XJO(SS zcjJ-J$c@K@l-H3@B#p#JYoa2F8A-oq(QnCbdoy`c3$(6M@=H_B9?K3=9Eykf1KW)e zB(T9Xc1A20E#Z;NZL>r~vwd~Au-^i?9B@|#Ubdf#I|*6$0Cg!gXnG4H0;${DK34w; zh#6!>%0_+t2W7jo1R7=H-JHy*Sh{NIS`C=PUPu)_F24*+(Kf5u|1FP7KL1&yZhR(t z&1=0;SX&c2?0W}07Z%hhhjxaae}l!acFUu0>z3u1StfH z6{mP9PSN5}+}+)ZyL*8z?|Hv-?j7U)%TGpF*?aA|=6dFF_JtQ|KoCq+0S%=1Js5FzwcP69;~Re^Mdw@Yijz z7=5#mgMwkTgTg@ySSwAO*I+P6hzgL5x_Mb&k)gxKmOOTG*;%*Tdni=!a#0bbx;u+k zhDzes92v51g9o|G-aGEMCsB(jzW4YW*x4JT1c7q6#4Y$xtLU9y{+mWi@bL?bZ{b#- zevNw#BfO8sl5Ct1AmxKFp~NqZkhUHLZ2*tq)*FR~IEQ#}6>TCFuaC#!I)pw+6iO zN6*NnoKl2UeJNnH!E#a?^U*jbjvl7i+LQeX4>J+O&#jm!l@svYXTR6c8;YXH;YB|` z3)^l!4SitQT*}?Fx^vDXi+IU5!9FlKurV+V0ZU%H8TCz)8GB~&ZLC{`Xj4mmdr-~u zPFjpV81dusOc zjhc(&2e7SrtKH3FYrrM-@ForZO0i=r$%?B7)Z43PlI49+JV2IL6>qJorFLqW!WOSWfGh#nK&)h4aQ@u6p{|n_-)=D#@Oua+E7%0Y7h7?@CA1#0L3G? zW;k+0lL34Qyi=E%%N%Qoag8IeRPRX{F3!1nOlBU_Ol^!QIUM@ocQ%o><9F97Ggji- z#R6wI0K{P=1NfWW>lRkVY;jmHO>5{e^gv8jQsO2c*!X=urn}J5iHjLd)}Iki6(TrG z(yQv7cT_(Bu!&6)=#wcO2>kZ3@`|%1!Q8@US?FrvDtmcxT$C$<55I|#-XAA;C60h4 z0n`OCjo!b0pC!2}86T#<@l=`^Q?M4o&re*&hUM^#G~a|h%P|#3w>s0sgb1g8K~MX` z};l zUjQY&6rdT$!uW5cma)V$5l%%ysNDPL3ITdM8k=Wb{Fj`3D8@|{@@Ji!|>#=R^6%kceuKav%TITZ06NRO$o^R0KRkWVSy2kp@A11QwZ^x)Y&- z;b;jF$TJjT!H9(9S$=&2-FsYK?uJr5a7+P4^E%OFlhL&A6=|(plwG->l9zzb7BV22Z zuu}UTNmbsIhWv^xTiqTtE>`4GVaj^7y23wQH|zrrddJ=XgK|IOw;Y-#p18K3(GvB0 zB%`}#h(!-UyT^4jNH2Aw`ic__!P;3lV6t z?Z>SxM;u5e6Sm;8SnhO*oW$=@2t?0&+HUAa&g(34Eb9PGPOMt%8}xM`Dr+d7$2OV> zBOA;Th>B}9n#DM>2B`lILI=_gjABazX%SC}a1p8qR`_TBY56iMMCgz?!dzS99{n3mpJbPnyt(Fe{vz8~MYg z4Z+`6!kf{Lk*-rxt?vBpR7@0_HM3RQt5Uc&*I}g`VYIz@D~QKLrlWr1LJv z+R?YNSV+T0ETYSBOeKUrngtkd^@G42Ctc=?$&I25nS#Y@cjs0>t46v@dJC*yOoZQw zOr697i+<}SPK|HIdY?=Z!dW$ZkyDFln!d2BgfwaFZCO7B6iOFLupj3j$MzrQzHRtlJUX3SSh%u!uZj^YOVV7y|XV;n9_r{4|tjCGwP>UtA zmTkU(F9c`CBIyH*Z_3v?UA{I3nh3p6QdK(FzquZ{v=#mQ^HPq=LQUD>NU&lJy_|;U zX2c8EE(C#)?&6f9gt?2IPnb8;=Sof5f#@KZU-6^+qDgArz>OMgOKD*WoX5IBaV>VS zBkFlvp|0oV$#un2!(@*0-d%An?@g(lk@JqWM}BqU3z%vTkJi=8GnYK0Zw#H|up4b-qd zZVlL~`D(!zSmq9Q-W@x03BRqL1O0l!75s%$fEx_x?PuYNK)l{%5Q^5rM2;!(2;?Zx zI3gI#(cYf&0A-y_QhGi(3Hz+yN9A3KUB)czQtV_ar_0mazrxRD9$tIcPXr(`fXsFM zQ|hlW-<7Ke^OY@`0+O~B_-d&{eQNCwrSJe74|N$~NIa_bzzOzNr0D7-#}uZuRxxe~ zoOK;d8gy8vI+Zl@UGao!B=$O>>ogoj2`GG>8e%zZ)uE9-kaJ?Q7iVRV884$88jMYTw{+nV*R`(wi5bg&Pd)o zBD9`brZ0u3VbSx_*GQy3K`cp1M5TL;Ho;(9U;wwy=8oEdrjXm%(??C>`AnB-fBC^W z1Mc>W+wDk~eYmu6-JN@QxAF4*Bj8pwZ?HPUAsw6Qk8fS|RrAcpL!wY?<)O;g-v>8% z@-hy@%$boGHX^p$*M111VKj4`st0x6OO)+BUjzn=GM!gc45SJX>1uC1p_$1+{uG$~ zNHi~!H-o=QmFI!FV~ON|{u@;MwDSu4X_S5z@jWVZA)3zB3$cfU;bUk4w#Zi*L#|$T zjGexUJH3A!Sl-a={Ll4W^8n~cCwEz8`^z;Ua6Bow`)qdb)F zCZW-{mH}4s91m>Q5p>p2(6lj)PG5~42cxxPsPgVw~#4t^+XnG6% z4f8KWiWUu+{a7EN_Tl{VnZXarzq90}o_IAF=SEAT9 z@vm}+MLg7W9YGizh_{5G{2w^m&XJK1jL<5)rf!uFID&{+p1M|MK_5B5FkzC~hli4L za8}0`iE#;@8TtT567QQfY>*1MxuRm)!OS+r_glWh`tbLEu%8)D!f4pp%UjypBw@Ue z5MM$%WayuF-O?fhsC<(3a7f7SpMw4Oo1&#>^AT7;tkx1T6&omS0~R;!2MQ7E0~p<9 zCHq>?lqeC6$*G2GAB)nHQskh8rf@nz>ggq}Xg?j3!6aQ&3AgRyGja6$EO#Mjf35?? z5}qEH2-Q^421Rrk3d>knymN(h^T`yFwc?@W+8R^6p>HBK;>Or^v4hDE5AOEvNqe>N z5{RYRSN&(gruaGE$|U;i+V0VqX~2CIY2aOVbrKzeCrG)_uudB7)aX6`f;iDWpDX1N z1~?_E29kr+WIfpW>QEW0**I8@7w|?$ZtqA%My8?} z)yQpcAMv;ly6i0`3jN_0<9F#py8P!yU_>~bws1LD_l}B+i%YuwzbpWZ1mE=;SDftr zbp~gKM{xlC-UK78tkYfG83GX{bZ?b_E9{O2yn~zM(anDFsxscPSigWr4)kGvXMiz% zcuhELwKLCca%RKQt6?3wps;3gSjzZ}S4SWSYG|Q$6~OojTFp!c&*>tSe}U?A&*bNT zmS67@JHZ_Z)Wz1;x;s02KhptmN4Gu7x!q#diHl7`du+#r4WR$r0al7|$_fgNj6+*P z|D3Jn1hMtg5&)eX1YXJR}3!cJ}(z=dze|`LU(4~tT!{qx@ z2+r?fV4QriunWBi?OS~&SjpSTCK~gG-`ZcuYE=q7`#OSP0&~Yn1fF3OU+TY$;Co?r zv}>bl&y+2E_%&v|{=kjO$=u)o79s^1=kCinFUdnC#n zQCxjA>h8N7_#=zowQ1XyS>Sk#vB{hSB-n!YdS+D(MuK|GeHVW^W_0oP;mggA1WTE4 zuwO9B33cxckwyXX)3n1vR#aQ{NdAa`fmyyNcoRP1_w8&Fml_a-y-BjRfCpi-=B;Ib zENYw-%74B-v7>DGKpk@zorvNXJJHl+QM`8-v$?Z9d!{?*Jzaj`%lVc zE|eAFahlnj&E%!cuhX@+7A3T}oHLtmn2|+3X?}_WEDyKyqg~1&k{eapdl$8F9 zsn|`3lQ8pn>szj*@;+t{ME1kJ?+%dL>-lu~js(s!2CzJnDgb750OXHGXpNtp6@yeN zCKNh8D#EuDrW`d5{QqeRYMTCeHsKa2Usad3q=SQAD5Iik`IT5X^64^d@sYtGb$(@5 zybkS7q)M{OFxL>sUkEwQ9o$GcmIURWV^xl{4YWBl`+;Q9UQJ+7Y1!iEZu z8#Q9Fzc(hT&S^NQFEFh5HesNzgeNH>@;vKgKBP6za1t5q_jBg zc{}<6vCDIyj|B=?C4{pY5Ui}`zF;>TJpTB?3h)d?IF$@we2RE4|5%+Gjy9o_WUW_D zzfIab;vM@fd%--=6ojN@YYY5NE|SGawjU;;fHW{Z87_SA>N?> zI9}1hrZnAE2Lqv@uD(kZP}@_{4ct zi*zKefBKI#Ng*S(_L4o-AZOUji!sz(|NZ0GY^c)T4>$VR<3VZs=i5EyH`HqjrXuOf zfKb;tmc1WBLeruKila(2W2ln-^xe;@pH+C)SGZfi$GQM_k?BpPrF2Yix*@xi z6a-%9;_#E`tX2L+nd-b>=3T?rQWr&d$422}9rnLxUy8y1Ir}oJ+*y4tj~?(t-Jxux zLvEXU?5+8HruYiEUD~v<0SF|Me@t%;`;YSMH7P7)3c;ej!~ob?gYK9iD~DvP{`c;Q zjtnIvAj8(y>gNgZfqW?=mi4O(#d^oZ&d+We6)RatOu!4ydZ6BCZXI_?%fS~$B(z4) z*1-o7A`pDk`(}4#2--pm50;3-725&i|4>RM)#}5D{Z&UL3JVcBh)dJudJj;CTP4&q z9_rj4x-~{W+|FmnY1Vw($@n{jw&V9O9ra`KC5Pe3Z9E_YAOnWIO{R2~RpP0YmJ#k1 zI(;JNI!A_f4R((aVHRA8~_k&nI2N``d5f8!)+HL$Bq8Stn7g#xIdC|gZ z*c}admkZhP#`;n}_q}wFoD|go+c0{ori==iBsjoU`*b;mSKLWtLm$eC#`g&DO}y3Q zDZ6TE^l8b=mmVU#X4upnuRbHmq$Nk3G>R@R;y#L8`IPN5Gl6-_Z~|Wrt>1-WwQ8A2 zx+mgH$-Kh1)W&dflTNfhOP)L=j41S9e0iLDoehNycR8Oxw;3RD!Y9#72={Y*t?F}V z$@_9DyN@+A$dG%8u1(3#tJmI?0ZNCAkvC;OYz50vVCPc_WzU&O2_G1xv}#Tjzk$Z$ zGLO%^P*hn19oO4dJlhghlHn@B+4nY52{4)&p-LSR|0!yVqRFw^?Sd`z6#W={?u3=t z__VJ6)*x=vvYk;wQ!8wcwz#s$bmIXNQ^J~y2+XXlY2(Ab^TZj;={!?bM?!4A>6e>> znEbOx;NC(8x6!^hoUqao$yl6ge9ROUgh&e>0{dhyHqMfQ|_W#tVyp;#_y^|1#c z!xUFv7ee`+XFStOy)y#MsznpY9cRJd%NcVB#N@{4uDm}QFRr}ZUAD}Nz!MUh9J~9f z>g$JJ=ob^k)gu!xhjl+yJLisBv63sF{4$ZWY>gUH>0exzCVerd5?E|;Oc;$FO)rjz6IH23-Wab4eSnWu%sO#!7%$d0NDWDb5U>*7S^CGD4^6r zmc1R}oev4+!~;JCiA_RLOa3v4PgGDVXjmkV(4M|4-U!1eKDWSNRoyHqz7n4+_wnF9 zL-{OLcPB&K*Z#J^?@`wC_J;n+o1G%tPs!cRtxTRZK%P2G>NQV4Jgq9cW2}*lhK$>Q)M#|?`!FuckC@>)fczy_A}RwsCKf%+o?GQ7 zAW6gFKSim`s!Gl-UOHM-hD*1m(kv_+Un@1i;XYgox?6=_q9rapfo5v`d@#CTEVze~ zm{W8BHT~c z@CSZ4pm+GxnP~^Lf3Eu*{C0q`ky^P3xqcrHTE1gphlTpcm&7K8T!$f}s zK3v(f8AE{*;na{P58^6>(8|2|ogTqik&uPq#;njuxpHjESl)n>0q z*JipulSh6@s=}dz&GqYAtzAvX^m9bVAl!YeL`BPc9+w++S6?1cPj<8kJRq%2htTfky}H$YI3j#E5CX}> zps?0Y*yA!hyOf9jDzWZ=bpo`$t-xsEz=Ftrrp#0@1H(10zy3xkd7lxPsm-+o#lLHJ z{-6WjNN1GOWkOd3pQAF)_{hu&30)BQwiCD_$tnCnJgo9gHf&8d4@Eo7zmG!CH|hI# z@YhdH& zI@B>|2~v+_Qv#pR9wJ{!JG2}lqDedQ+eqbQzZeIQ5;v%Y)3MjUnkXIhQG2UN6MLg# zt~D&fU(89TiBZcR%{6Jt>a)CwC+;oQHZ@83v$@o)QGG@r-f5B2hNpEt3rS^kPODd7WWbO{J8v>qDPVw+1*Tb8 z;FvjsKn+7q1y}(Hm<1sr1FxW4KC-;UfzQ zFU9YXOBXyf9Nn(-x#LQ+uBrSsf(04g@WiMP%pfczEUAD^;<<7t0%Fh05g{d`M@`tj zmoG(~2=4fmYG=~0&Pvc)BL`276HS?p;kBVK!AE{WAp2Dx)eZfr{+SxM>cwbQe`kUj z%wFY*H3++*N(~a}JiIS`Yie0fVix?4RT>r>2ffInQp) z7~Z3*t#~mN>(kp`EC%gt<1G`O00N`f%chq6`uyShyo7PX+iHukJ3}{GMMaxD zL%&V4+j75?Txa#ep@V9FW*^{Jgc6!^V($*~#Pp8TV&;VB)>H$2=2K`GHB~Qwm}0L3mx>`x`T z-=QJ;t=q{c^F8Crq^l^TbdsN$N|5@5vkP@Hkp?T?g64ViMic}y# z^AdZ$eZyx@JF>iT@&4v}i2O-XPxQ0VPvP^5NGxM6Bzk8sSrlo;;+uV&KzEoCMK5sP z_0sd^<@~KqosY+Ll?fIKOl5!tnq4s`DGIY4H3HE26|DqG@pgnCU&-oD_S<7)D?zZ#qe0J@> z^I81;Cn&2w-w+FgQBHT3mI@@f5js7XTOLc4#uc+gjEw0y&Y5?RgO_RE*SFvs9$mlf zemL*&XGt_d1t(?_!|I(?8G?S7pzmaQ+HB>njbr4a_s~rUqlDCVFVh84ydL9G1^5Ul zBJ@iMkOdxZwt^*)V-WBTP(XAKNK)FCKbA{di^fq>7vib;y7Vv;f1dr%;+*Z3p0cu@ zo~-!a@|wIYPAi|<-y$DY3X0Jrf$kVt+XR6sTle*ZL5{=uGV8jTNV0zmBWr6n-RYmQAIGM9;>zB+a)C{1WV@%aK+~m#8dGp*pUh!OXvTPn z>cOZZEkn`0<|{=^pBf+VW}(}o+GtYU|=;h#_l7d$pcDjIzYi zYfzWquZ4zl-obsDg*WOS`1lw$;dq7RBV9blEwmp}|vMyytCMyaXL` zf6!GOw_c?@jj7zt%$mdfIdsimOB8ZnKABcOz~J;#)8GP1QCISC%hk2H@xYzH-!;x9 zYv#)zKI-l~fSTl?OtJQQ{PJ`2cXU}lLy@x5(s-1t5?{a9!Vns3JC@S zRVl=PJ;W;j`~AL^#wQO_Gsw!dE6NfWS{v(#}w4=s=K3!`**!?|s%ijimXEvdkYRL0=p)!Ym~K$noge75^0QB4jBtp!Y(Nv?d4tad)q7WKr>dxoo>2+2iNO+<|&J$WM z0;uzIw@#G2HV31JXnpWkc_Be?fY2u(An8gxp&wbhc;Uy?@K{e%=WkzNiq+Nj?C0<^ z*fXf6y%^dYJ{?m~p6GKz6Af;LH}Ql)Q9`-R%Rw!-bvcvL2K~ou2+Zl5x*c7EkE_YE zCTC&$^Z6BRdz#?T1url^(Y&C-)YCI#AK3xSE*q=jwgf!N-#_#{rV_53+!%Wn+p_&h z4E`n>#JmE3D2tv+%Xj~EqQllik3uPsk+^uMXuSmmtVzPbH?Z*m#|UWh>?%0ufLFA9 zZN_E)Wn1rDiF&uj{O`nSN@2SOJFi36t!+VC!upDR{sJkyZ-|y)Q&Kr5suAOcI=8rf zX`20-gT}q}zsc{k^`^S54Gg1?(k3qv;((h@%zx`t4bmJ zNAU@n!ZXxs=C019Y)NmK&9tUWl|2ljP zCyR^WO;~`%*7uv+trTuTf3Dcs7GD9jUU$0q;zXn74@SZfDeHmax8-GNFPz;o%oV z&{OZEL>X|v6{)b%%q1=zteW)I*woqFy@}kFbYi`}J6y!2Dc(SDPPV_3q_gz?@JMm_ z@3=VW*}sEW-+u?OtN$IuqBVhr|Mo}a`8{Sm7V>U2!f$-AV|Nx7%)Wix?JPY;WRA~p zqeqNKlF7=!*9D+!p;hBO=wI=`KNg^=wRHBFZI7rw>B%%Rjz8OJ0m?^{z`4qlDoUmE zB$dhm27V{pIsz8b2p!-&1s+p4&_1s*g0pp?Ily4FDCRM|=!He?S%91rAd3kGNMwTH zDD}$Wyoi1Y;c)qgXY`iBD1nle8TjVj@mv4$ZFT24>GDa{xqg(MEGbG!J$4_q??h=Ipm80BJoX|9fUq5<_zlUIfz6(uWJHFA* z>rdpB1tg_YYIf{j*%P4$qP7-qSkw);PN&p+w~4$l1&{I9!g|cRB9H$9TMVN9xay z)VHGMgln?t?bY)K>3QPQ(PT1yQ4+5TjrO?hhFG1Nd(-kwv*pg|w`SL#J)a*rqck$I zauVO&R)kv`C=ylc>BC}#PiXi!T!{lvQoWkLCpZ(e1v7v-moEhm;z*b#U+)NaeIDjp zQFWGg787;?v#iAIK>c+6CIt^9JGii7oDHgXkz*Z~)zu68oUJpS#vHbyosSKbao~<4 zeI2>2#JX;Ml#B4NDqp_zflcKfU6y+KpBKFG#SZm0@ivMx?k#E8Yv|`HMwM)n+to-n zEH02tJ`Un&6GG9Xx=!AooZ_@>8$lo<t^lr2`7I=ilD1sF}|56GoJ(GvsWTL3U8fXt9%NwOzcJzG(&R& z1SE&lf95XdOct++xxU3P!tioa7Syp&UIqOrQO0G*%~Cz$Fa&&O$JF4@%tO zlR2~&W||?9q2O3~t}?&2gwgg)oLf`TOibY*swV zB2Se-Aln_l>ZvyBL%i&B{Fn=_0 zg8uBjnR8dWRPK<*2ab-NIm-cAxIP)G7h)l6aZoR%lYd`IXa9RCo&WPv0xx}K?wB!E zFW}B&B$E<>2^0vU4%bqPfMlH%ktaX`xK=H+6qyDFTVRiX%lhK!9Q*FUU>7vCK+Wki zGy#1KBppT>Cz?`a8P1y5(uMc!NJEB#eX71<^W~p{bNR=U-{*dhmx3+;X&oG;G&DW} z#Pi;k3HYR-lM0iY_^l59iyMQWmwptom?H_z&2OlKEP zTt73CC(I5GI=d!sJx367hlY-f9p%&pnJ4<2r{nctR*x;n zI}Ttr4g_CKzNtL@A;9}<*cS`3zp-ob8v2yE=a9r*EY6QT9gpBkUOm8FB<=zha8g)0 zmptEK@WQ2}sLoT9<_+BDY^XQ8#Di}6#4bv4^Yvu2;%@cYbJrepCfZ~cSoa6stp4@$ zUam^v`H6|0BM*@!+dB_BviAGVH}vZ>wm+5=8VojPtO?g_=4~5BO#ub!UD1*WE}}36 zxbzES7i_{hI@s|CD%xFAb9d5F6j|nlh(h$`D+vfMk+mzv5YRz&R zlwagd?jX$J{Z)DW{!3t%u3o>s&%gKd!~bh!6Q)jcyFQftygTzD3wgV}k%sv=&iRX`8Xcbsku}G}ZYU ze;-1P(7dLBGyc~1-c)lufFJ^+*|?__`AZb-mdNWFdk@7Q*{z%+o2DwBex6)@ zk2jTNUR|LiT&5v1B}&$EK_-G)S_b8e$|Zq>jZ_KqH~uEs6HSd|eQS~AXS7_S0_6OT zN9Ls1o_*K0hfhD>sn4gfLz_1Yw43-drk(=a*j$%-ZD-y$lA7#DU+`?Hvgv&A*RLu& zA56EL>f{kPNHQ+B__jA=(KzuU^qdjO$bc?6q=tJN^eV`)+Xo(Eg*(6KBa*YO#dj)f zN(X_=o1WjBp}E+I-rvP;(Yk4ODkY=w1~)3C2F||==c!lyY%w0nUERGd(!I_(Un0zf zh2cD75o7D7z4dIhj#2f)5T!aMJ_G_uFfuYMUtYa{flydY4yQ`nt|>Z{NOMWC!wmbN zY`L0~&={dy5sTk|`_0${pQFcUr7Tj97At z4q^ky0j5SMSKCaS{LsdM`sk1y*}keW{1?TRGE!^An*44M`}ocIMF7GGaM&1)3x&!; z33Xh<@PO}RRCqUP&oPjbln>t%fJ~WBQ64}nM2?aY=vU0b8@v_gMZ`300v1mR`xsqp z5}r*{k#;Cp_w8Ywf@FNz<0^41+#uX!u|FFO*c?uSUWAT;?rWhP=Lxut!V6cYE6S6cbv$hl?pjX z4=^C*VnQ^k`?&$wEC4+Q1hBz+Wki6x4IIuXd&z86{}g2>-U>p(;O4BR^;f2zBw9J* zvfn&CX=%IYzu&EpMxb=|*7BKLC%W`137Mq=)3kHFKnr5=c)OA6G*NrScE@mfYmu!c zp3+)^(l}Ux{HdK0CN?&=Wgi*-)UuM29_F7J7+cs?8F_Qm=!c2m!1`?J2DcWuw#bTo z6GeUpfx6L=v}uhlI;T5RvFl$5oYxB4h%4APwDxaTeG^Lg!gvpLrJlJVhn zP@S!aU~tSidG^zc9J1Id{97=(ArK#aOLcNMPm==&b8dH85EaDevFN!#Cb+E;4q8? z+r-6KJswx~bchAzGUd!kzwpdo+-*0&}#Ol|%y zU8`;K(R^PC_TwZL=bI%2I=-1BA;4WuAW#LEblN8OOhu}RW=2iMFyLKZi7;k(Jr0yn zRpxDGp>@FwzJ24B&h0GCwWQ~I&&S)pou)e(-m5^4c7a%pX+Y@BZRZF}176ugXH-yR zg#?>R;*KUra}Sk@GzZK>yxW~p&Es;JGbR)vb?HLMJt4bsZSk5?v`j*y6EDGVYexWK zz|U7ZfpE(GY7XWWQWfiA$nHL}l?OZHEqk5H3%)sT%FXnrMx5Q>r66ww-FAbAq=KYq zuv+)@BQaYm%g)SA#JOpY&y1NPRD(E#wgMK&e)R}Yz zKQCtQ3A+=fJ0z$nD8@)>%D2kC@TwQQ9CGrs2$H4f_k6HzzI1wO(3D_m5iW*d)59Pr z25!K(=rm`!bu$%9jeEK3HpkUCR=)gpNju->v-gZv+PekJstf&h5KzzjMOcqHG3&nZwRboMyTyao=rq{(ucB9WTZ)sB%DG}BZD-;ET z6^aeeD#n-fQj*<;N7f(kUHGO%;t7WvU~aRj%ZRgZV`N2CY`LcKci8^pUlnP1=zRBv zZQmPTJ-p{)BB}K?6SjHzC_|2dYzl>6FOAT-k5w@V03DwTezC{>%p(zDpty zD8@u`Tv}h(-o&(g-aX$qan{FDV;S1>^+x@U8CGVvwNC3#X>WV(ZirGCHi`PJ?!pe^ zIb9m6(&h~3We+HH%lA4hIN9@F&St881buHLce$~v{f|(3==eXO)X?7-6o3Yo{Ku#* zS_?Y|suUy`b+mqdiRwV$l?k9G{iKx0#E?K34EYNW_922BjAm{EN1+EmiCIxxPIP4X zrC*Y?hzJN;v|JW^YhW~<5Xb`|02@Hkfs=&ZqJ|lLNkqsfkVA8@jbo79Z(afp6_TXr zMIFcDRdK%8KFT@b_buNI!fM9zGDMSo|C-t>OrOW|pebLW2gN`UITCDwp3WG>sDnc& zw7O!^+_8mKauzY=f_Q#F8OXD>ULO{Z8UcToYX?*zW^;3o*i+!8QXZP&|H!{IFioRZ zJ)aB3(4d>cMM3<|qSW}i*wN`Ef}>m!+?_5X{0_V_Mat!(R1l#1drcIrB`+vZ%2d_o z`~@z70Xp%iis8j&Bv$Lw^XWZjp%O&m9dXwwX@$$;>?bGAjLQ3=YA-*1e!0Kn;|^*q zkEtfIW#e5EkQ4Hu{^0#;`PaJEZg`uPfx}5ZUzYCW*v2K{-t)q%|JaD_YMF=?cgBRK z4Hs@PZA!^EV4*^6V<^v7gTA>|6FEC>>;~9u4c(Myr^GLwH$u_a=xEBS63SY4DVCnp z9~--w$LhN^FG3*u&UFLrV6an2;y1?tCD-bexX{nx+r&1Um!#zfu2d-sr!dF4ZOu9R zL@VaHA({d8PBO<`h7O+|0>>H;J-+a3{HxxpcGuafq3V=I@ln#0@WI^XZnf?=jZQKD)nk}|0Yt>pC+iiHxTgL*Xe`dTRx-ihd&cV8Y4lc+7z{pmE=9L;4{*l?{n4V1x&l!XUIrMIIHciq_&-Yf znPY=MFV70_g^l9+d08r7G6Uua&#WHg(A7-S)m*<|jgI)(@7`SCbkBT$5dGUUYW`((RY;GEp6?v@6j^6KvF(bHtif zPLlB_Tcq&_+Ro_OyzKRK{`GaFA!^T#J};uK$F@@6RanrLq2u5`Mt_xbTS9Z@^k+4y zOvbe%b>U6bK5w>!{(0$}HgM<8*mu2%iY)wIvxH;+QjzFVT^h9hLYni$6-Fxr4$A#( zhRh-B_%k|ec1V3Oss-7cRi|?yMp^mO%?D53ryVF#+3L&0*J!)D5de4VcRA7`dS<&6 z(-`}gwD|aD&`i$iH0E^PW%d%6)Ew$<07nf7fG6gnXel?hD2KgX`agCjB$*xal9l-1 zRs6hRpF^h|qKs?lZz(DAgLGnJY#Tf;ohzy{OXHk8XNR*|2Ou;L-Fk`c!lJU0m*2hmwB1L(lkZ#yYrXyxqi-^E9*MW8DK(x$klTN2ShM}tKN zm0!_b^H(cdrqtF!;;PGwT^5TaNMMX)I&yF*nQfrx)@Lag4A909koIhQ^XD>H_+>hm zO*y5MPkjr?8DhB-= zqdODNVN=rai;dp%ugl)-_gzHH<2U_?H!V5nnY%(B7%|!wQ0x}xI&kqhW?YO{@jR}< z;^yx|dszR%cB;2Gm8ulEbL`3xc4h=mRA_yQ0ANBA1kj?TZ{!O+A?omiYUFZ-ZanRX z{6tkY|Bj|h$3*l0W9ux#qHNc%KQq+O14u~?B_a(<$IuNbN=P#xAl(f^C?E(3(kV!Z zGz?wRosyz-cQ^0-?B~Dtr}xWzo#S9`uKT>swbpMDaBDHP2ff73${0k|Z*a#nWkjVl zz5Fh^Y7eZgknVR5{5W}yg&aTW&&B4Fp%UBaiHWV_#IH=Rf;itDIQ+yDoBHc99t~IJJHMo{q0IM&)w|?#Bo^uvEyZILObjLu^0JLZjs#QN&MbZ zqtXSR*S*SX9+Z!h+)GD`t;n*t1*705y!}%!D9(-pXJ{N@XHF!)96u(epohU#A+!GG zMjc1Rq5hTXgOG}&J2M)j|7*`EI_1P~e@AAj#&EqpGsff|f6`_vu&={J-X7*w`KkAF zu56BI{|+st6re7r1oM@b#n4|wo&P7xGG18fV@aETm){KWX#aQPmH5An*A3ne!=~7~ zmX^z}u7@50V#UTq*}_|oOC?0fReAG@swX8dJ8}dHon%5n#Nq#BP`#U3nyKiCLN4O; zEBkr?g-(Dj+WoAs+l>nT&kRSlc;q6A4$X`$3bH;DzhVj`2!jXd(X6 z3Zp7vw{h1;{EAZ=!TYNH00jk%illu+{J3GcJJGXR-UOg>3bY#rrvA<9pqrW07bkk% zxB;Dd(8OytaW>tLB!7r5%TFA^Hh~tO*e!fj7IbV9x*wff>FufNx%Gi3N;Wl@GhAA{ z-%}BfL&khJF-(YoVkSo==$c7rq-lt}h|%X)W20TyrA0uBHJhfzpDS>A&fAN`1I*H2=pajL z+OroDx6uu=A1hhlP=r@;buxWRW9(#2EyG9I!v&kMhJeFm>$^LgY@>f{TMeuK`UA)R z=MVJ9n*HeA;(z<*YDe>C_hf}9ph`f>2|hHS@vSBU(gDqy#t7=GHKJljQW?AkOG2Jt zh{H!19)K5j=lH2dfyl|b%X(}ma8s>swRLBkREDZTybnX*_1=l71p+V^=ryQVx;T)l zH4g0$w;559O6&y}W0wtV1H^$MDj7ckzw0Su6^4s?%CT#DZ)JVwsQml0We1IxtBbmK zszcvx%s~hY9gLR`AL~1my)&^vZ2w*^U|xw_1oA~xC|%c#&?6Iw455=J88J5cD+RV7 zK;MV84IM4V0{SkmWU1C#n(N!w5@$KY>XXuv1LcCrQu)CDtQ_ekZ(b`lzGX@H?*;mE z{L)A}!`yi@nI0|1Fo=Y%Y>0^GDT}dE;xh&wMIYSKETH>2u;{2*#sAlu!^j>7p+T0{ zqkW@5E6+l|S1scH_g}V|5)EKH(hpdG-NagmLPAb@s1a*8iG1Xo}OL zztTV8Gm2u7KRhK)?8xZ|``6`A~JSA&PHaDAv$`0}bC6 zrgbB`ku#xd_-0AnOYw&b^Po1x2$@4@lLe}|_3HEyP0wXO^&B5s3^(lHRndT$ATov>0>08z1&w;^sIE(XnjP=x0nV4#smDCX{ z+t2TP&JF()h&gZkZy+YhF!1l?!C#eti{ZQTj7v>m3D-@{_&}K${MGwdngmt~hPmN} zaWzqshQ-eF>P9&|$CRrj#2?@`9_Z&bn^TA4cp>0B-oSD5GnI1?1q`^3kG3PgK;Wh| z)>GD;2da>Y8esuj5xxT#0wNfysEtd+k5Gfq0LTZ4_OAL4Z49NqQ4d(80YczcYi!GI zEBp6r%$vx}79ojSrYTyu9-bQt4w|(4q6~<{wb9@bm#(@`{76vAWc=;pafNK2y18C~ z$E*M}=xGdShmWDI?-7QXwY%0(_*nNAets&zftO79gg)BrrLfK zab0}R#o1tzkI%zCXOl-bfBM=BMRNbP6d6ql?0p*+&TF+Mb*kp`8oQ@2vyBgk<|f8) zl`!zpt<7upoHtEz{{XrW0`!*HnGNgI0H zQ-BH{Yg$)VQ5b=j{<3Z6rjQoaIbhlD>UP+&EsrmLT9{S)#-#A*LmqV_W;;FX=GS@x zu-;-?YhRmsV|{&s+F}3tKi4(+;75_~`v=XQ)=y^EuQc&sHb5^10)!(3e`y8I2A&Vn! zE4wJWJgMU-fYp4wYpfnURGMe(ccyl=VY#oGc@uqJz%WH^_11~b5Z8pzyzXqTooZ_I z#UtvGdRqxU;6WZAZl$2x#jc@cK+Q=AIc!sms~Xz(_v#3;C}F89N0qGrn{^^)@P+p- zpId^JXA@s&-z`UMwy3M`S27(8FA`NAs>f|^Nr?E!8b^cezFe9r&7mi*3qMs>_KeqZ zIz!;7m3OFJXWIa;mwsW9G6?jG+(OHb#iCxH-bdEbC65{4f8w@cnDV_LHCBi8KF9vd zLh(r^^v`4_lrUsh=8YZJaB-iHI^HF!n0faRjLeCxiX{XJ4`#vHrlUU&^dzuxWUfO#VjpqexNRm=6A-3vHQxR5xk*DA< zF*b5Nm9h{Y=Z%n-e0AKJb%Ep4%Yv>>1G)L*u$<%$%CFzq*AUWABwlW=OMV%RcY7@c ztO+X_+N?}@}K1crISfdEo@dv3aW`t|U!LTXg5HL@=4L+2R264b3+ zy{m!1jM4UoE&DeNpiyPNHlbenIYQLL;V?4k7dO@PB<{>~Z=HH>^-~u)!Qw(;7!!O+ z+1}te^_bRfg&U`QyT7`S_DO(OF-9z~;DOUmGEi!18~&k(Ed6{td-rBa`xkI8>-y(j z*72X`_6+_XYd>HN_*4E5qNOd6Mlr>D=M+U780P~3!*LamPlCT6UgwWr$X-?o%xy>` zY@v@0UzvLvVP9M+eY@b8)!+{HGEs`UY<*+de6z^eOl{wXhbZJ?cVbr1t!&J$eT}J% zQFhcxM2v&U1MR_8!UG8aAr+pVr#fyZ2Tv)KKIS_!RgEim6gL)vEu2!oh(I+)btYW$ zgzYp~ZwtGao2Dllm{mX6<~n%pIA1N)h-WsRf5m_Ef$y`wGG@o{ja2lv;^J`0b8D34 z=^7mp{B=SZyKLps-_`9zz>4ddFRHl;;}u8Q>?rllSf~$CfxKdRvn2ZC#8blXx6NgB z^bN?7_(#uUgnl{H-2T7~l3p1#3u4e<$ zAq~PU**crVT$lW^mg%=U+{1ahzdikP*c>1Ac6Zk6jpwWi{tTx7CFYe&o2M(*`HmdB+gfQI&g_4Z~m^zNn4&o=FcD zQFIn;Ogr0OW_6tdi@(vkYl$G?@V6HF?esi7HHzh~FnPIA<4FKER9kUEAi`7%jU%3> za^Drnb}tE%fCH=2UDt9yI{`a?z4HtU3M_DWUJ|i!98O2vRK46@$Bm5W;bcv|!`-!1 z>f7Or2;}5H><#<6e|Na9|J~uzH9)Re;b9nj?rQpu&4EHHfze=0xx4REN)b$qCw_V( z`;$OW1*hZvpd)_OY0AqyyNj`(C8Z*!>Be=C{kj!4nQ+@l^kt@ls75!^xU0E(EZv zy(N6ZK923mOkEorC*I|Tp}k(!pl0-SoQmViOmgaB71)Yemc}AVIr&WG8(vn#8w<2I z(i%HNCI&{3A5FRfGM7_EWDzg)dn0c{%LV{a!w6qy@XMEq3Xk9ui7w{i2tYy$RdCV2 zp5xH|Ig$ixeEIPk)lw4(moGO7M&`ZyTivsl4Zhlly7UK{eoJRyn6`DjQwmaw4Rm(r zfwCgtmK9x#0*HRz{+9h#^2}tj{&d}78uokc+Chmo5D|*jbY$_4_FMNO z$C*jjH*M%#GeeKmO8?PbU>16PI&G#D;2e9+TM~_W?xIYPifV0OR5DHfh?H9R^da4% z=4&h}iR=(D0EI@(F~ZA4&U?Dlwv8?lMZ-|jqAeNdYC1MCQ#dR5P#!Czv8?fHdSyNF zfb8E(%W<)Q6Nh}jO>NO%r(F#^RD1vTPytccf58>NLlu#EKYwzrv!N|7bx+7#~Abt^rkB| z4t9cF3DO}#$b^pA^Ui6+k)>{)uG@dhymE=o7oykK&@_0XBz6CO>YL`>&}LN?UIFH} zH;)^Js!vrZMEkIR>>*oK_5_`0(;L+2vQ%Xho#&TD7y<@!Kk*rE`dY_5fkNw5M`XEU zm>X{tBcXImJjIWrean*6NKXeGQz&X)pRbuK5^i>$wIo3d>2j`&OOYIRrlJ?!>U$A-gm1a9CqVf)mVX zSpcUBKa;Tt6jgW{+b0>ZD@?}CLwvRHiGM$(^ zn#FH^uX*J}-XVRL>y%#)fKx~gzyIUb2espFka(zY6Y8_Jj;bKT9<>16Z?CuE4|t7O zPPpU7wI5{=l8PoKv-v!MC5GlYk>JR*2ZzPeMH%vZ&jgNORh}JnY;d3@=Albiw}8i^ zIo+Mje?lpJG-fI-UwXN`WUS63W(Y4IZU#P)fTNqy3FwT?*~M7Y#ZBkx$=JvH%pxhR zqn^xrldG+!mm_xivMedpgB`r?0Y=pH9{;^(~q@ptpYx4$xwPfx9{wuguEky;Q% z0i^X#+7iNwU15wO4ZxP<*@<)yYvh(MLsWjgnaY1OwnSiJ#S8xIftg8M*38W|NZ1)_ zLgJW+fTug_-_5AIkLHSnwbg!poV_Z|gOU19_? zwoRd-p4bQE6ub84*t}S^u^@4jK_7@Vc8fN#R2hS8})h%8$F?1)i5aPj3}D>k#A3T zju*2n7;Q?sO1hI^7g;LTsp~CWvbP9sbHud(@?*UC)+RDR#eWL);Z)vRS1LM_j0|x0 zzG889GTCecqd*-j7mTd_X4+7H)Q0p}IFMUeH$f!wjan}sg-|_`!%v!^Y<^oy>C)ms zNe1H_Y!xt*`WZ{GB#TfwusGDfom*9Zj9TSo>|s4*vDdl4hd;pH)hHkB>P_Wo@K88- zg>r=dwbI<;KCkQDan@DBE$N5phq8N$lq;~NM%u?RnfeL3SApS13=Zo8GpbK$G3<-*D+(WQEq zdo)$0aht+y%rudfqVXhqTv=(ADQ|I(!DQV!mHOvA$^jXu^$Nr(;CUBq}L0|jSMlntHo7*))ID-ncJ*j>4o02?6FnffOilkZoOMg zKP=II1TxXYCC3Ky`na#Kn$n` zwlW)lxAnV0icmbpk&|X~iQC``fKL4ps8Chp3EDe5)2(;9t6h?M{3rwp)Dx7GTM~sO z0N#F#G)5Y;5Uc(uAiU4rCE3EQ~Xu%Wtf9e=9-L5^`^Mk9c9 z7X;_9;J5yAb4ERpNfsmpbO%wk?W?(BoHB1o4E?pNmbQ{PFWi$Ga$sJSn{VR4Hpk4O z4D+^zWYn@fTCTiUbhrqR#aM&u~%0C)sNv%xa11|9u zq-?W4>4NM9g^XGlvt?;Py3ZsB(wotY9b5B&*GUa(rZD1A)0`1pJr)DYIXR_rX3D3H zW~CO%>yYLge7w$(t413xok5lp$#v~meVD|A+U;kJ!I`1|)Z4!Ftz_*kiZ0~+&RlGQ z!Sg?=N99$b&3+2Em+FKGYARs4yI>VZ-KiZ~pbCTix|VYCV%qo?6?~4d z{+s-XYn4Q`g!{a^PB0SgMvAocNe-5q7wx(Zp;DExoKp+fg4)qX&@(Ew7?*!uoGFt) z54kzdnI)#n^2YSl+wmi} zkC=u1kYOk+tCbCyA_c2|n;nbvJTlxQ^pX|lA%?P=UvtFW-kA8 zTKkaxgGc;_{JOtEUaN(PX)NII*!d>jk$Linpvh!GSn>(Dd^Dwpp6HFe>A7jpOGpX{ z6#DxUhDk&L$h=aGIoRsg>n>Ty4HS-Za|YD@$wE5DMp2Xp7;voVhbu3I0Mx&mnD-j^ z+Ze-FH5p@wVK95$f{jpQ-OS*Ll>a*AxJT2a%$qUQPGDNd*_>)|eB!suY9tBDpa|C8 z+5lHj^kEl%*;FM`lHY#?F?0y7(6lOVx};>AJlr=&&4^v@azsAvicyW zv4^3fJZbrDN6bR<+R$27+zaZ%%mkxqZE}t1@|-aLU2AnPm&^U0z>HWz{(92YfzNWy zBpcA(*|>r8HLIJLsWpg*6x^VzA_XRK|B@A#l27|r6Hyit%u16{$^vKFFz2JMpq!)2 zuzc8^1hjmmcZ#gxq2uCK&Vvc_)D-*ZQMcdb-JD}jBfsv{yL7=z&Vij4wd$6GY>n~a zAE#(#afX$R$LCY34QkX#Wnhq)Y~Ojw-HWsK;TP-S?+vSWQLBy>rPI>3mAKF_?5$}j zDAY$1Q$!zl(9|KCW)8?uMksDN-l6jWYML0lm{k)4n=Up5{45M5dAl6-nM_PE|D>v) z|FfR|`p-f6!6-%Nt*kntk$+#+PdnhLj>)FYaPGR)Y0;*t)P3~$Q3Owi@(E`ff;n>h z5S2915B#;B?vLA3RldpK>i}wBgihYQXTy{+sd?Sh5v78#edhsfZd`mBSyjc+wME{h zx%i?7XHvi?ov)-5v|J2W$R5f*br)w2LXAwr;RpWD`JqHVs7=UFVqqxWn9dri77K@zCqf7;nKP8TkRUyC}1(vEkhe%_qt*$je1=pb%s` z8r&j{cx6Nkng+G2v3+G&!E~Y#em3i3NIvDJ)LJRQ8ias7&vOX7DsN;?Fmv>N&&y+z zbdALnayOvgmjrpKVTIc@FlqsI-E(L;`1xt>utHO;ptki|;mjokQhf_3?tvnp}P!G(9x~Z zV}km@vv^6CSt=yWK&5oLKh=xa{*}u*UN8xeUwPU%HbF4cI+=Q=7K}p0n6P7uiUYrN zn4R!1kGqT=Mn00|3c4n8ND!g*ur}sq+L9!#-tZxV>(3+q4R3_)-JSCeDv+KJJUaVZ zWB9kODkb^7>1&IA=D+7{C-vt6#-hh14i(>lpsJ(Jzgj>|!~e7ZPq;RX8q$yVlT($BXWs}IVHdGdNhTlG46vWIdjfBM!L%thwz2QdCclj^L(Yg{ z2+PgRs+UnXEJ;*z=*qBppViKGP#$isfHV=Yz1| zt93gQnD(O)ZtfCmoL+hu3$aC7Di~PR@mhH^ zCx)v8&N8@qax}VP+pv|xLaFwS-{2=@E$HWW4Ex}$kwr$}$`|zQd9trc)-)zY2V}41 zL^j9~rnEPaOMHZw7F!#b59**|+r!9z9o}-}Bp&Udf5()R=%>3*3o&zU&fVSgByv{) z)U;o*>mPjTv@T|a*b`TrK0`b;Bbhv9H=gNvm{#50x(t1!O}2H^(0dg0dx5%946F3x zvl;5WiFz~fA3MxlvR<8zPbr^~8xiOF>r($T)w-RU>t32}zxv9yumbSkt7IjAaBWXd zT|oYDOrrrD!@w&>STT%jx}CoZ!(!*HvWf}5`VO{qQMvTXS4Q}1_=T`J=c^Yl>3nRX zacz0Jk=$#o%bZ<>wC0%6EOuLob`JJBTtcq4PllUz`2^!Tq?NxsGrbX1NmZjc9$C9v z8IY%uX;Rzx{TEqy?Q2z&aa3SNqZ%-0zRmlNN@$AOsGiB(}J+ zd6W_l72mudp8gwl4i|%h10@Kub^L1mxF#*40= zRl6posuWc?ba>U42rN)md~11(GkZ!J8H*CGtEUG)7STb^{jSeqrHuHkI?ql919VhB z7qM#N-%7jjyRI&N9RJ`0hxaH??$-hiqQ`9R2d>#CUjo;(qkpeyyZ^hU-Qb;&!6X-Q zy7Pm>Kn$NGIHN;z;Hkf78R2Uw)=Qs`@el>rw>5ugRe6RYo{!L2C?gCgVci%^4zE4< zosEbtRBFd0IUYlynIs0{^fTW^>@<2U_ytR|{~kD$eyBj<4a z%x&%A#Z~&y%#RE*oijQF4|fQMPw!ZaFKMc#T?yW1&kP?68V9GHf8f(&|&RS8B@`i?h*=gP*yRr_O%f(&co) z-^bZfv4O$(Jy?8l_#W-px<%@KcoFZh=Pm63bci#AM12vlo->?vD=W*{D9fL?%(2d~ zr`fJ|sVP()=xzjn$8P6t>3kGdW0tQL_1Nv0EHcDm=M30*cKIww-6#EkROe)NXcl5! z)Jqz%0V{u?A^US%TD>E{ggvFHr>vT2tfTo%u0w~5ma7D_i50F;98y?))pqpXNIG6KrduRnWNb`F z-y?=)+%|0>KXnR#gL}jWi0@>ddp9oGJOdwoY>_|3Z644u(S!sYCy%Kv&V)}0V}MAr z3dP~EzgZuSBdwPfXP)MP;sN>92&BDw=PMATPAZg|@Z%1kZxUE6C4QiUYjIf>vh+aJIe^=6w5f zm%@wfRQ*i}2eHh%u7I=nlf1fLCl1vF%Q=}1Jbp_;`>(|o1TbX1J)}8BW9RxSq;<0a5esyc};n_Cr}J=27sRwc7Z=8D++%Er^1-w{p!&$ zUH#L|y=I||6hP2aZEwJ-L47b1im1;LPfmU~r^G+cpeU4H=zqDCW<_QV_IiQWfnj{J zrWqkjr6P0iu`kX~OXs7dWtn-lnZ>sx*Z}Co;sp_=+w$CXG2U?*2z0bq5TspUu>6aY zGp?4RoWj$cL?-Bn(<3wPdGrCXDrdqInFnI=;h&mtO04Us!S2+Rsfbs$yQIfoi<5{| z6QlYnO`N+c@Z(AsU!cO-JKTADf>bm`#Hq}`!-+Y!RlPP9lwSeo4y=J@OVQDIfExYf zLM*+0O%UAX6l7HpRUQEO$%08{qN9w?DM<}m&_heG=4%+Qa|=+4l3I}3(W_E_HA>j5 zVOXR-c2U%j>2^VEl_I~5jD9F@7modr!^tcZ<9jfjj@gkk-~uB_>B7WBM?XMohXf&a z5lTRcZh<5)wBw@6s)7V)jcyXo25^duaOQUqnVPbME4<%<|AY!W7P9~JvIC;4hS9wD zUb`oj0u%fs^x0mz67qais0Y9V-L}O!q-fgq>yg5K*c!jn+|=hTXL!g&w>IY1L<Vt4z`gc(_lGGqrmO}W2J=)xAW%G;{gk38J{L@Ya7BYD z_g=M<5;F%}1tbworm_;;W5K+LNxKaK&tidEuC^b*bd>fw#x}R6`RvWroz$Vs)j|%> zH=ESu&BPW*?%NeM2G!JA!W<7k@50bfhQc49Y+@oxFdOz%rB1MmvD+u>G<vEu{{Yt`Y-#8LAjkk4dGSc;IQG^qM}OOG-nUyCbpYBNSyzMY(BV267JKKcwe8 zv8+>IkHI7PQjs!c7e&VROHpQIXsT`coZcuG78cqoVmIa~XMFR4vpY)+fP4dkjj-Ao z8CiCkEX~-~sciofcShm9yY{|Ds9LwkQ6!U{-D;sO<>AiTLDL(`D@*ygOBpQ27);Ct zw_EM|M?Is@X1I=CVCgFi;6OA;x8Ei$Y<4cu4mpV#@!tCEMVZXV|FL)R#U`WCy8mI^=9HgLsR| zoV#N;a%c+kNdN-dLM~=CSGROi3`>Sp);o7B%m7<13rZHmO^~&7N)gsO5?O=uiK`R`$IA zCl$0nQaSoJ@uhC#@Dwh4U5{t-#k$3Wgp%pS`IDp6!Hyvgc#!PNg;x*7B&Y))TMK~8 z*`vca#=Q+6P8EIX+o3tgR>Em+r+)x$J9MOeYa`M;0QL``Bw}pyKS!CEA>`7afflVE zAiKVw=)-soKQj*DNB9W*0<4SgTA{X}kcS6jd*%l!sc~&Q$UbMO+Gl0YM&?!WqpHV* zB;B9+t;{ul(b#@}u4OC|=YKP0TsqI2H-3c;AJ5Y6eBAJd*b!V=^_7irqAp+}U5oqr zg{3X7>--Q^v;Y*qAO11OcPK5xPL>VfaIAJs=l=sO*U%e`fuQv9Eg-piKij`L6Dm+{ zCxP43qBH_U58EOit}y2 z_siybBQ1y9iEQXdH~RzTo8ZyMal83;HUaa3@qD@S@(vAsx*2g5^KJxe;q_|^Qai;T zmcXMSA+Tf=VF!iG6M%B`)5Gk|3PvF7>Mcw#UKhs12)9qhsbE5{tU$-&G36?iVSmV3 zgI3X1RXg4*(=0q~k8#7|KCxbkzBO}ytCeb38lEvX(liQi>&yN81QLupSf}_@A!hQ@ zY`Bt_zif)Dss|;MH#eB0b3N{ty2l2Reg4n+clz(Zdi>vkm1mCpVm{he*uDKRv8VG} zalc!M%i=N?_{O-e$&+UcOZbt9!R_^z$1EF)P=%wh->=@Jz`y)4ftcx)bbBHWaviJv z6b`UV05BVu(l6DnzF$~g0t0xw{Vp4awERow2=& zxgR+UNUL@F`pAs>z|uKzOC&2e>DW()@mnSShkA?;^8*o5tyEyDWAY)pWXFwSS>7ZA z69PERf-Qt`>x~SFxvUeO@2_o!e$SAFceF*uaE6Pe?zCHe0H?90QAW!w6>z&d;nhT| zsq^Lfz2)=L9~h0pPq~4VV3Rv~z=mdIhp^3N@4#h7#MxmmTz7cT3SPo zDXjgBk_9X1)*=H|Eq%nF5t6T`__W5(K{3B5X7^DD5ANbj{~2$ho#v*^wf^jz6Bb)` zKjla^FW}|4r%Cfum{i6`k7!l$t~}j5%#LN4``r6^`kg&<_n%l5iRH3}X&tt!tag+e z5N0%iFrm)^j}{fgGLXM_lfk3OYy2OxCQ5@BKf5GGp=G33BjIQeD0W3;X2 zam5hKfk)!-6aHrxLp}TQ7>d(G`;7tZaXzkd88yUH*G&MBp+2r(-}R)1+{$@bRn_{s z>upCM{4UtR!pEA!y>-F8?7B0t^Z`NA+7AFeX@c)T0kWo08X&LE_jr5uY z?!Nj0&f zAgq@<8f2b49Db=BjN-Z4b#?w4e>dxsp?)vr;{bW9ul>twFR5E7T4Wd(SD7D|DH!7^ z^fo+5CF146$!PXG`i)&g?k@{UdAIRjum7@Xvp;`N#v1k}o5DByi!M9}5gQr10!qGI zuxR$xYG!0I8X&p7FZR1T-}71i6Gn4L-|MrY_jk@=ziD#zI?ZRVj$`sNvwF(^l8`uG zqEF)BFh^`m>NxMXF(sGJy?DLlpZFrOYgFVBfKPyJmKO;5ttdGs+Bgg?D-B159%ENVLv zBWXA^IlBE@PGEq&jTP55*;2x+5Nkh7;G*m-!<=)W{Um%ygu%?-&_DLjyC(d|a~=fE zW}Q6+otLkqF5%6d$9DCw;QWO*c84vp*?OMBT zm|v!LJmdmWyFA#J zOrN)i&jh-If(&`8YG3-^1=+M&lvfpzfy()F!$M2bR5{*we@<|VCMSUvv0*CFj-Cr7 zI#fJFjZEol`C-)?1U^JBgkw5}VyIxq$ zu|Dul%Xy25$v8gljTERB&5PfeRk0@@3GBhq9*P$FBfQkZ)92R))>rdcp?H!R0tze4 z5wXO~`uUr^F7qc&PnX8-?h2D~?#-lUQGSi*#?s^QXmOagK2$g`N9{2`D8%Dut27bA zQ`*{5?IZ}C{b945>sdwk=zN`}GDRG+z z6INH9=kqC&*R4^m?%0+OLVuZeCrH|*X^XX5%n=pS8Lnlu=6iCAcfDAio~^6quAb`= zW`M5O*a|4dL2 zZT*Ws*FHDK9+)19@z#WgBrrps#l(jYMR{{>FCft*B8&=;zQjq7!h)H92=D12taspqDdo+nne6vDH}f5X2&rYCUTg;8SObziFQ^f{GReu zIpH4L_jK~cQh^)Wy3?=242VBpgY_&Rqt>ic(_>>9vqsJp2cus{Y6jgtm8?!m6F0B( zm6_0cABYo4`r#oN&kKcQeeby=+c%t~UMwW~&F&$;Jf@<-^F6QIH(X#|zx!%ZTOiR4 zf>%v29U+wHXz(u(D;R>r5LB2UpSe_d@-EPo(XtYhMNbL+3^u;){F5W~BfAYGq0bCg zI1zGr`Pq1ch>LHzRMceEeI`((vnlPBA1nliMObJYp+m>A_!sWGEKAk*8nP zbHs5z`MeZ@p4EQLLiPc|Sq78AHZIS{y!{YBl%e0@e%9io7nhG; zgvPWtedWNrANR8&O-PpFQ_THEA=$=2^Fxqa)z9}mB<0D=HQUZNc6~h3y6(yXAKyM# z8IC}k#5O7N2a4pFAQKwsjKSoagX3EFM41DI+*T7F(#$XQYif9b{8yv73* zsuk0Dwu;Q12!Y0Acqqcq;SlZR=9;vQYd8olGd)VPoWmt2zo4N`xa ziH_Li_M5d>{8QDG+e8m)Yl<-V)*iz3D+W9-Q+^}EWQ$xugI|plZz=JipV@;*Im=-g z(yt!6JInNsZ;)L&N`f#KmRBLnA5L!xHy+Wxz6+J8e1pg>)=~j6kv`ARRqGGE$oUa~ z1#)`pFG81*K}j`yo@u@Gv1n}d@8!+O#Qn0)_yJLZ&EzZN!0tWr0&tk>PDANGpZGuv z$cqL@=4HVB>#MQ-Mslz;<|XrjD)hNO{Ht0uzH@iGCO~3|yD_BaVKOMqD02JIBuU#<|odVR)Jd!7G)8q0H_y z=2Q>li5^;lpe^O?%770lcA2-d=f2Z)h%KCKYj|*598bnmxWheUaa1AmAsYqVM=y7C z2?o&zP`cTT5y92FqRfEDw>{(ZYML`5@seI|eXhsijbgDgT$8~8 z1+7x~HO_9UV6MlJ{XCb}>HL6dgo!*fYx1JmvUzAb*IPn8lG9{vHrQ*xLa>7`_bs^Q z*YBIc@)OyjIUV8?@uxg&%qbCnl`Vgw%Li=z`&zyVO3Y)Uu!lUfRLWfhew83trPvmB z9q)N{s#%y%esw>BtR={0vO45}Sn5GAQ$@aLa#Q2>yzdM^BKZ-wYgsZqXzRu7bBz!% zZR)K1uNH=#JUfRO!C@C=F54fMlDJ*9>bpUb9ZsV&oQJp0(JxfQA_sRS6d9AVV%F%W9Im1n3KqK4dw<&_LnkwQ2VK@kF=$z#1|Z{qo) z%0j_P8;+leEfEXgO9v)M5a%P7QC5iA3lVsfLOL`2LB?ZeOgUUcBoB9f|Jw_nCu6Ke zBDv5oF#(f$Un@zjfa+LAZcvyeL#lFjJgVYbs_`w^X0n5)R`2(@Q*zG%BVZV)?P$LG z&nI9~V5sKcb-A$F9J&6RDtQ}lZ~O^K^rUcQWdu=gF|0c)DLx_qDlZ`H9)5QoInI1p zSV0Doo)$12LBZY-GB}X0W5iI!+6Wi^m?E=Qo*Nwmc;4{kh_`kCaNj@IQ1 zKBB!gl_w_1MRNwz5EBxfIrUt!YJaCCnIA6hX6LbfmQAnZshZAAHESP9FGbhiK5a&( zR7lw;!q8*ge%SZd*5#znIRHhh>hdx6xesA4i);U7D{1A7;)6|AmtV~tU&%mi&QK=S zwDCYba8I9Y{Brr{{x|+4myE75pt=X=VCDo@TOJH{a2WZ?69Focw@FV`g3F43M?VNA zM5j*6*|Pbq5XdpmcP`tC6&BAK2~f~K^QPbS7s280Y5YC+``*`M*PV)9aQvLcZ}zto zda8m){e#It{N9g5$!ipkDN!ak2wSXJ?0sB2gQE z1(UF=o~^e^i>N`khoHW~d>@V{B2ge$6P*K>KCG{b#h$#vsb{_NELeX$cftZCapypc zp0b{2J@!qxTFxq_kjhySW)L=41AI1+l?p}z^H&uvwXEs6h0}-p3HsO`>6DFRBalZd zD1{1qx~C@Mh$mtbe!bdqg>Y zE_u0lQpD3JwmbG`d;e;QeA6(qnSb72=MSA6rM9@<)w@*$m8DB`XGFP**Jyr0@eF6bNP8TkM^{BdgAp&d8B${ zn0r=dz9cQqH}^R|b*giJ_MCJ(3$4G)wmOB$?S3blT0TEg7%e$N zxGM%@A_+e{!CSUMZd^!OYc=enuP3NFJi`&A(muS1|>PQ(c2j{J+xbKN>u3ULL zVkF3!cpMN!Ox8WWyjh8^h*g)Fev{t!SD6mH zSoCV@lzhV%>m4Zj*sad0?it8w)c@IOxFM#Fi7Y5vwEM1`9bk_zF{P;T!cNr*zAduZ z#gR?mA<`%o5Jb>FnFvd5k6uGY>Q&UAueL-e6LAOZ;dD%c9*0xM38zLUG#Czi{g?4i zaReag2bf01&=>1Cm`yh+ zf2hHC{G4i8#1QC{AmSb37F`?tp!gR#+R2x9Cj1QlBd=*V#1Zryp`K6au;ZEpf^9@Z zGZVCeM!OjWTM>i&5meCz-bT%)bX|*{WLz&4WERq{2aO6NB#EvH{y@=+GVq8rn4L?T z(5EMC9UIspj1}C`1+((9QI`tvkahYOJ;x?2t}A!54_9qDu1H>Wv_tBcnv=u2v-v-- z`~az$Bo>n{w8z>~kmxnR{{9I11-MFzRxjq=7^)i!9OP8F(O-_G*VqzrR$mK#3O4C3 zv$Gvleiee75qFvr$}I}EFQ7yIvg}u%g*}6FF9)0V0l5aXq**?Nbs0jaw3+cXR)Kz+ zECKWyI}SIwLlB+9dqtZDAr@1WZmG)Y84294ThJ>_Yp*8_rKNU>e`W_9s@uB(kq_Kq z+3Ia&|J;*2(K+bUi3fAK$K?&&kv2H+jU0|KQ(GUuV~n>9r`%$wKg6oIWJD`h5?_pce!kPU*#ORulKig z{$r$YzehqL_CMak`esOaDS{B;+RRPKRf7%Zo-NIjERR^g10j4js!J1)*4>;9FFrxT zwrG98@S!J=B~Nm2TN%5^MmKpEfgu!RR&wOT^g!>9a{)u_|1tH}ZBe~n*!LcWk{r63 z0i={hx_xE=n&wV`q>=!Vz z_jRpxo}Y8Amxr$_8b~8?<4})Q(O8KQKONLW%*gCag7`mnv>~gtTvJm~hfJvEhlES% z3G6<;d$V1^+w6=y=PgMJPWHh1&obHeGG^Wv9)5ijaYZ@){a00T-8naFZCL%AfyD7| z?KQ^6`}5D?n$vnBdi;77w-(0#M#Oz(GaLBCVY4~bJytY=mhOxR`_pnVW(r5ft4jBS z=i#teY&CyAtH-AyJQzFN`O-(l1ZIind^%_E;BA`=DUnCJC>pUOJCkA5$6k5IXf+G& zMG#_ja$3wcu&$bBC_0s_4?=dPIB$PR?C zoqT1;yX$TX;5GIiy*FxKytwV4G*+owyU^t{XT({MP!$;Pogqdp3v%WS zOW!#BOPzjn_a-dGVla!%Scte;Vv?GSxaS88P$iJV6U9HU{|dBG!+8=(Rb!=`PdGhH z6kYEEZG&Exu4(mn%F70i_89A@sk=m%)h3J92Amk*?-YlL?9dmEz; zyD6a}0=r_J@RgIO%=Q8NvDX3?EX)&2c?-7I_Or{aGHjmOw!^{)9<`1lr2He>4>wPu zQ@Q`JLaMBTL>(Tzt>b$RxJ=`~VA;Rva?Cv~FWscgzEb-lIz92j=?HVA zKjOS8f9HKuBHuG+Cr)m0Emt%a>yP50%McLB$46}9iwxfB^fhuC48V7cjq>k!zVBX2 zv_YK!4~^eka7DS7)ZIdkPI*6nU3Gg)`bhbSFC6>_UowzgdA+jAjQRq9*E=;MP=6NI91)+?(>d&Un`us_ZRBo zuQwM#uYYngaB#YuCJmp_72DECv;d8f6%hNc;iDW*R(xR zY_jB@>&(lT?fIF*CRvq&$##=A0o8 z1~Lz{VWE<@em#ExkG`Fl-Sds{XEGY**T|?N_vQE~U~&&7njiD`_9s?@YlpEGREy%b;`Q3;qJX=O)V-elRH9|>?rsg`gZ>=dmFlFIqk8L}%Oq&w zDz)x}#Wm_@v@&p6)>4e9+b|LGvD-rrrPPS@M}oevTn-?#H|iH%kZn1wFLUn&0#iq5 zIqF#BG)#0PcYe{5F*Rm6e6d_W^$fCXi8pHUcwFlZ-`P`R*%>XL6l*T>WjuKoGzuK1 zYSN#QBX0fmADj%M>X5t7OH`Ak9?LGa%Z$u*ZrJ)3`)bfy?tSdZ@*{!zLCG}cm)W4; zo<|DC9@jP{YcK1fV%mJZx0VcDsfjLIX2+PT%Ugn3M)9HY&+U|Uyk6pcnK~k;!*^mIAn+|Zjzsr;T%YI&4gbD0_(=t+LL3SlDWNJ@Z#(?-z3iB< zZ`j|wI<5bw_XF3p)L}mj>P9${ZIVf#liQ2szvpShuae7879o0|HssndJO*w_oD5C| zBF)^glW1XL02SMn6!tv&UZt_AJw4DVBV$}*#EU&2K1IKajb96Tj&WAC`NHe?3%m#R zcK#vU4t%`6N4p_zr%Q}|RMrhlv1CM|))SkUk}ML%I|k_JieY-S0ibAG$)XWXzDDy@ z&&SKe69H||{jO&m9)bkupTp^%SCHMI~(e8TBq5}40nCk5QXoiEWrsFNO%conm_?P!Mf{`LTvA~$?! zRH#a#^7Rs%C498o0DKI3^q1B;=PdW z_S*M=CWxxVLUuRh1#581E7(vx*Bw8&>2S5V=}*1D%!MriVynjHV~(IdI<6L;*%7^j z4jmCZ`78kW_ut2*{#gF03o#6ZUZ|E6bN4ah)L%mCRQw+Ub6rG-9x6Z$l&T)}#&W7pD}c?yq8Z~7U@_NXmI8WYM)I~nlT+xAgBUT#Djo9wn9&u51z z`tj`kSbFOy_fGMQ;xztc=ImmW^Ce3040eVuZ-X5Q)5rMD7(rn*)kAorlFZLt)lZNNKxxtrYePP;a3)xdGey|GmH(#@n2U z(?I9_(}>9u_c5~&a{zT1LP1e;mEM5lA{9L8QZ#l_wcvZ432WKtzX;S`-5 zRBbKyaQfp@1v-y-N1eFsOkQof-I`pYyGK$UB~yRO8W%TjyW=U&J{O=2pGC zbh8krGubg4<$a6q(7Cu;XrRx8ed>cOD4yMhU2H_vxe1(-ALiX|B~m8SGzM5((>J63 zzgYl-eL{pt13oVT#67dU{6}e@2?365M|;uaj2WFhwusAc_jW?u84@k&hLV&g>1=Zl zU+~HitqD|3x%8oq7<4yu56FSK&9UnTi-(sV2U?px$occ5d8GMepbeH>q}7%TfD%0W zT;O#0975dsB&AZt;Vl^eaS^#K$AAd8;cDN1Q1;%8JHHGcQR!b7xE}r80O(w7;TL@H zW-go=*}L@UDmaj!h21!p5CJ|^v>ofJ-_y6CpSYnOMwu?#F=h(hkx-r&W9>5TOQL3wI`qO zR6vw_nRt&NC0P%M+v9JEo@TbwEs|(2$^llk(-bvtkM~(I&mx`8c{Lejx(e&(NUGrM z&$X0}Lodfmv#&wf6eRuou81b&l)PTy))yM@+?V^o9>w*Z=Ap{#_P(N4cb4?5jc*!G zn#_uU(*gBf=rI)yxT51`NiHQqMOUSSlN@>*^#b-Q&!yNJ?oa+faw3n*Hz=&c>sQ<- zVhDW|8Y-~^D*l>QgD^ec7CI}5C9*q{GqlHRX~zDGnfh!!$!a4<)8&7&q0~u5BR9;@Q+9>{7&VQQOU^?PfQoIm3f1-x0ebwKVr(8 zB3i}(sX9rYj#N$7odlA>G2Hneu($^8CX0UfL#GhK03Jnj__jI{ zB0$Bpn#)G4oU=^;Jk5upy!iM5)Em;(gt&buAD_E?_`=&vE`T_t;d(8B>f z9SDWRaQg`iH~2*^u>z0gbTIeBZcX;C7)q;9thP9Q>3PHzWs&u8`wNS;LdC0W$_C#W zDYfL4H)K^qYyQQfnnWFzOpXjb^J=uaqmsSvS=a7|P@XG`_LdgdndJN+Z{+`lg3$mW zWEcj}qbVyE2UHZ%Chs^drR#W?Qul$*b9TY_q@i{%a$xa{Ag#cvD)_)L4xrgc|_))P|R#MA0&J_?R| zabIvNjV)U0opk?tXs>0cs@$(FK~CpCxjqGa?6Fq~0kbCg9IYISc-db8;(!LTDcS!mwU@FNVh#k+{ zmSw%hZkzsWejQ`pClT9sHlw|w?9UwS6ULAmu16XrLv8dM22h66uPOI1p3J@!Z~wrg zZ&5#=n}R1g91e{wSfEgiw*h;-Z%7~n;!dW9P}WINdtukdV5-ftVpW@>mur+ z4(;7lYsz)tWr$l*Q$zIRbUmqZr~n^3`b?6!!uUPGLNdL+cRCm zj{w9=nMlqK6uOn-)cZuf1tUp>Ea>D-}mj#65Zg3+^j`6cGB$eog606|#6rUSjie-dN!j2C!melU=<_#mPCUF{4K zM$waPjKYR4O0JHADh2ENtAl2t40)J)2|(}GLIe<_x0*C$kF}a24Hbj|&MB3FC2~y` zAhh21snX(l#9CPOJHtUW#IRt5BBi(V=&PK%EbH!M3P~}ahy6ZpaI`4b6pTz{+@jl$ z-^&hu7m0C&7D0>6ZurLI&O!fh*vop019E8ayj?F6a_erz2+T%`4WACn}ard1? z!}%AwUm&Q>+E1n@bbSBGnn}iM)OKKf$$iNEtHkHgi;xch$|aoV((q>J2)VAG45*KP zWf16N{9&FQ|A_{GOy%sj>H>{X%aTLR%?z4~TvQ0Ru#bdH>1v*ZD zViV22318d!uoP_jFG!bA z>GK$~xL&@WS!fv7yizmlJ-!cg_ILc57&RMv4o5y` zulbEhegmlM&Ouz<@7CmBO8se{!1z$>fv?xAn;==ZM9NbD0q$!0Kz;hYm*^+NiSNNu zU3r@!7iSy3cG59~c~G3FxCo&$M)x+DW&y>L|@b?^X5P48qMQvx5c9^*RQBUYX4v7uqlF!Be~kA$E9;tOnUZ6l*g=;xB5B6T+=C&@Q>7 z?*cmA{C~t-@xPH0C|5=esyI?TjQLo+&xABl0Kp*k@iU7+2?-$Yy)j6-`GK%LFt!%) z!(>&`yVKfxZjdSwaBc8Mx9Au3L9-?RGWT zSH?tZv<(Zyio3}TOE+hc;Mmvd*xpWiq0Bloy(JV@et*V(aaKZu!)>Tu?fuaScHIiW zz6}o2c|PV9Z@_o{6!* z1HFN7r21 zhYZLMTV*QsegyX2AB6ViI{Y;qzLjSUyERr}CQ4Y-;um0iTyE*gEU3F0Ozc0LH{y`k z8dxRMZ+)fQE7h2LGOxPY5{;Ut@a;3YHmIp*KB+9@w)BmP3HlX8ycbmUT3`Qno_Pa1 z*8jzmtREWzqa(9&r!~T*mR_x)CW<_UP^7_Qfxb_B#&Vx3O&m4n^DF+=1s*bmv7c73 z!U6RVXUP0ZrNAI^K((H&RSlG9^3@^gPxc2_f*t62#mR+e=kF>T$CmS;8U&W5}{ToqVO&KGa8ySpys;!7+; z12(6rikLCZ+TmFRY=C+^I%y1 z)6!?ona!Rsg9hNEK(>><7oEnI)qwg^Y#SQoR?)92H7%d$!Le;k0^3zzVvjv(7&>c=(K3m29w`0zBZTv%uu|B+8sJpHw>Gp*>NEswgGMMVH7ZQr)c0aFFfhdPR(C*8-(=WsMt;JVeQyL!b`_*+uZX%nYC11Pcj&VCT<+y6?U>BQ`9CK_5c0h#3w0PQi-`%#dR&mo z4-i{X@3DvLcACF@TdDBTGNI06J_!-)Po@bCu0QAK}4h-s|ivKRk{UfS@*rHph>qNuEKJSrzKj z*wN1)^dvWp*IZ6v#%q@@uXZ+rkJi5R(aAbm{ZE1i*Q|o#dn4q3LOdwJUqzaQGzQ=7 z?|falinISc3jN?dqbMJr{oPys1LbLnieNP#7sB9oCgTG&{RCd|!l;_&4*SF3k2X-P zI@%w@6N(rOXpr0b|9RY9Ot*}qY+P-@9V%sfZ2HxumufkC#{+oWhj&NbFyXx|bIqNl zf@UPT2Og+~OTcI;P3vlMa(?yl>6^i-}0=sp)`-m%dQp()r3rk#pm!hitJ; z(x30U{U(vFSlVft9_pT7>Pr-*D&czpuX&l6ijvYbhhC)k3k~du4^xN-_{4|;DsUOU zL*Fo>ilkCgn4TvR#;i(z5_dV{cCg zJ(LoK^Rv)O`B^u1r_pJglxlKSA9xy{qv2F5tslB8$bf<~BGy)-gwq#wmpLcxpRrcdyMxV(k zcZ}a5|8{!W?)B?8Ew)VQhy#|n&A)2i@Mq8BT`4iw&j@oQC0YIL8!iv$0AjXSZpQUN z2(nf3r8S@qH85ylZhg8(>E%*0m%zvl;Qs(VSF2zbpCI$q&6$vl zt4~#4!Fg&=gqEQ{wNEQ%i_N?>PQ0r?Gx0}mVInX6#A@J0kR1e(^xkH}3Zx_lmPSfq z^-2<}H@uU_1(hy-V6nm1k<#_JpE#09<6JQK1gP(xeBbu}577t)lg;M0jV*nhBo)r{ ze(DCFO(-o&bFm*1o-7%j>K86(KWJbFRM=5LK(r|jt>oMi6&TT5tcs60deP)NFAIno zZ+IhOWD*`HBFXC-;5tkv7BtrtB57j8-?pcBR!NN2 zv+}0nwY~jQFV}NmC{Bs1agAN5TiaXX4ATbtq!?>53a`6NNaC->OhEXki*h*F8e`df zmr-Hc1ASgUMN2lhXd7;C_FQzUICBiE;H3At<3lKpp`ji_x2Uvng}CoF8E0Y%DHo-B zIVsjL`)PXCa-M4^59i?*>`V`BJ?>+2^!&3wJsym^*e~eD#qO}&&9=eNu}3KC@QV@m zHUBC^Pv<31C!J0`&d|FeIq9VDe9=52<^7s7Yhs)<1#D{^`uC#A{ObM{VhBOn4G>Xo zA!B$Z5td^BR75RiWV3fkY48|TY8a1I#u$P&T2hM|C|hAvbjgajP&oooTYyiTc0fO18B+A$rG$H63y9|g2$Ta@Y2c$NP9TJ=ik zA?OF~f5D`*l~^+HACgOsYHqe{@quq;HnHeKzPd5c&7=rhjS`&LXVv1XEv57W|G?XM zlN^;FwxG&n+qi?@a(UuBTOeESeoKQ1QsrSMDX38R$6`#$0IYGmHwNZ*R?AvSe%~#J zpuVf01cM|HI0*_+A4z=^*mZGn;c>Ly=dqRLT@1!9@Z^)>0qlnG1~E$B!YfG!Xa~h{amP3BmbPd4i|Z&H zJzP8cB^MdG0^VqEBtPA63)#Si?-q}bC|tH+l@q`9jMuk7aKcjO1kRYiqkZ$eGN@E!vQ-4egi?}C;~}{jOJ+8 zavW3(K&}XWULi zQowYCw!I7f#V=(XR+%vzS}vd|znRT^@rn;=u%0C-F3Lv*z$ggP41UwdR1Ics2KYxv zu0<1zyOQO@A3~&&OsLA}MGjPuEfDG-3pHSQQuzT~y~HTH?KN%-T0M19+il14lfBCr zz4owAOr&5(rO=COsBgHbvajP#=6EyHeyVsD;h2S)Q6mZP5#nC9%)B)ckRR~n!t-3? zfCwPcg-Em9^7J%~G@GF1`etlJT0{wabhNIs0`h0X_O>n{Fdh3iFz1BOILZ1mX@oAo znQA_KL<+dV0sLS&aj+Bj9^q#Q|5=gr252|_{1u!6f%y;FT>Lr_`5d6QIi$RaU%a_K zD^ZK0;6{cj`XY6}5T5Td47SWJV(ZYw4>ysMZGPU)P;9HP`)udifDp!I(JO)(Y4?-U z^Jfd*!hda;xH(le?7aOP_#fJY?F$w-41F&vj)B&$w}p`U9_I zieJDCc!5xf1=3Q?`s!h&2*-573)pk_^7^>r=l;zi;z7j^^ckluyGNJN|V#41CJdI=?%ZP#6S0I+N;jwf6+b? zwOu^Y(^G)glKvG@o;mhIghwSXytHvM+G@M1W#Y^v!K}FIXIaFehFIz1Up_GX;}}72 z-1ny)IyId$Gsd|Td|@(xrzizgjJI(M$_SwJTWEqs5z&Cgyn;WF}>HddkEDBG%lk zAGznGyi|u;ZEuScZ!8OSm}{IC99(9dLzhy-e@%(K}wKS9pc)6Fc)S2M0e_ zh-lBqVFphXy8#)=VafQp0PM#{?iO>BoF(LOO9aa=tl~Ys0|@+FJo)gfNfWZn0 zUd4gw;S+qx4-3NW8qq`$W#Bc$pv|KbKj1recO-4v9j*SMaN z(3O40A4t-s$aj{=O6Bmro|i(@|IN3*k_jTV=pgjbPQg1L13NeD#u|uFiLkf;h0Emd z+=wN-)$M9=4FL9KKm2%8m}Z!kfH(CmKpc|z0WTno>Tr+GG>19c(*JtR$n zzeFJ>UoW?_CgQIrH#P;5Z8CQ50&+8l_FQ2W@#<8s1?r|{S>){i8Zrhdx{xqO849Uk zPpm@un|xZbC+i=5S!4aRw!un-Lo?9ig0sK(Rvg>Y8S#(wtYWci)}|_+>*1EbG(v3rTQ zEk5xQ+{=%J0;I0yRDY4rLJQ+`Z~qSQTF)?XziV$W!^?BuWKPD1)@R!aWX#embQd`_ zM^398E=^*jawI#oxXxU)&DLQ*4yAN1*n11kT_;LhoIQVAW8(9y8|~zNS7G@9vhBuE zDPW~8woOTQmY05DqBk*2H-rcU^JT^7v*fC}FCcO3Rt>Py(uKq{-COf_FNYRYH8U|1 z9L@*LavhaK%VZe!orb+K0T)YCU?yAx?jR@2I#o(K(TSzNAJ3eI8fbdWG6mu9s+kHi^Jri6fuj0qW!$_$@!xBA$ z^TVn7776UwPBMr>Pv3pI`^#8auf07qwnEq*Ymmg%)E`pyBHNrZ5FEQLfF0pJeSYMQTX;_DXW5JW4bOAH{so+br) z!bWIgS|70)5+KC-RmmJhO_m_;zg&}I@!NBLGw-hd>kR~7&-;&~r*dvC;}+wu*OXou z3U(+%i5Rw~)iHsri{nOfOkP8b9y+$^iq^|xbwUyWZ_o}TZo57)anbYE8wygJ>gMs0 zEJ$SnB7fwTVlmzQ9MjmHyD}WThdXDbn~s)H7g&Wcpeo)epXH`{$V=E-V68{YMk7OU z9nOl2m4Pz~Uu+-L4#-G5Y02X=dZf*l+)lQh_s--9lDG(B9?cwR^6!4#9qhiH@Pv99 z(ytue60Qr2b}EUxWkwCrC`ip~)p3`N~Zdo(@|NS{=9(HkU;}5G$U}_jHG&_Uh`!r@&ZJs zq!@$%31JtTJKQ!Uj0-cA@0OjiE_>xvpwDvTGk3_iRTZP=l>fouZuml|h}ga(MtDG5 zPKzb*4U`PaUza|a0+k-^S@vc3VUj04%=_kvtd{ZPz|e7WYS0lAJ&?`KL(+h*=mH6* zCK8P-uy7^_C+X8C4t|12@>FMbJJ%FflUdryCtDEQ0#2>g9w38ad zAGl`?*QsCdGW8oUT}%M(XlNd*83>m=hQgX(6g>ymyA9o~{K9PzeEbFeKvU^&6^~*0 zF23S?)7_g>T%G+zmtwMlw4Ph+7h+T)adb%5T=wz1vB{r-mE3qBqp{IXR$@r?Zpa@a zTI(L(5tGm{bYBlO1cgcek)BppUH)EuNni;D@H;bA^4Y&V3=~{yz1W06wt@+skpgXe ztaH@JR<0gM$H!vbJl!@Xuytc*7nAFbDKg2O_bZnGBpsr6bQ)>kEiSpIOB2>_5Hw@O$2gk~}+|-!NI1r-j{DS%_YCKS0v1N6w%kHENY&I5|yHs_5Ke zi4QnhQ6Ip({RSh$L>ihxcq<{VIsFy=#+GIDbQ32#z#+-$Fwp;!&ORcVk3C;rQqyvy zx6I20!j=y+(Tmi*69Tn~aU8h-;CKO+F{uw*-2yj)k`wRJ>y2Q>{oTplT)4_7Q~lZS?6sJtl^Gw$N#9VxG`3sKWk%eydE4M z*!k5mv5CO7CZd@G&e*3ThQ7y;yBYwXTm^?LFLmuQl}jA^p=#{JAa4P{NG;N|OT`1c z^YM@-S!P89E_em;If^~YvEuZBIkjxEx9rf;VIdwbu01#644hT9fO z-IJ3~a2poa8`#v~i`1Rpr4@S4LC2X~&tJa8qpXzz&9B9;hUHH!AQm7=MeK1&K~{WS zTHUb-AiiitI%PkAUPhS$#OnASiiM8S3>&@R_jIrHNZBD3eX_b9y1b%FNwn2n)7bSr zK{qy>tNC_3#300R$EAO3WL>gnd5Gg}x$C-XJ40ayvmNWM%ua-RERFCpfsW;&T#sj8 zs3`o-%cUPn57BZWT)r>=?NnijVqglK>;1HU&(tY9YxA-GZ&3=S@G_4Krp64XW}E|V zhhZVd+@mYnkBZu}vpqJ^a})WdwRj@pJVjrY-Zo$apN}ef=8guE z#1VmyQ%hcRQtnt&%jrx)-$^Y%q&Ru>-?E}Z2i3Li5~}7h0P01C+4Wm|yLqA^Fd8Og z?t)TpVAZq!Jg}gGtf&4N85;H!-8p7YA*%n!{g=OwA@NVoP!{iiAiZt%6&}(ubLfN^ z1Eajeqo`zCAa1Niut9V8{g@EFr*A(z+{6dyvJ#{v>tfb4pQv%iG$eV9%W!U6PoR$> zwTz^27QSOuC1oN-XyC7^U4s7PG6{ZRJYrMch)P4T7L^%33o^YQHQuk_YzNle{uF1Q z%^3ArmkUVcfxSt*;7*$PXwOaeMvT0R3Dv^B{*NfnNb`ApqOL?AM4@~-Nr z_wjR?GBe>#d_}vh_sNY7-eSLi@CtQMAhvWjUQWP6B#pXyS^^zDJ5&rAFJKkHSGyBL ziiaT{9pqAlOaF$J9!|e-AIh-16)f9rppAz-|FCra0ej&6Q?B`}`OgPpl6;{zUn-Qr ziSQO7p{|ziV2%U*kwfAXBI%Tbksh@GX6Pb2E&8AUqxf) z5B)bDlMIex`l>a(h~hpGDihgl&b*QJ7L&V@o&sR<_{*PYBISh_P(JaLV7!zfR`wAM zNQtG!cD_BbHM8BKCt~&wJ??~Z1kOn}(&A{*Z*kJYLQ}3e%ux-KhEmuSwe6%+(uH3& z!PBa^I^JyUpGS^og{59-9k{QhH*V>Aug>FlnbD@uoJ#kY#9y|V1Fm`sw>i*NAC#ZY z?(iF=90_2x9k~z6q}v5OjF{)Dl-pcUnhV+x)hECA{Fh8JI3(s?wM6(p`FL4)S^E(o zxUL%MRU4F)%SQ>Z`bEm3D!$c`IBF4ViY{hRY zc}~$!zkT0pawh+8gWnHa3AX&)O7wX4vG7vrRAbt2G)=L5!P5mf-@8Mr?}h06?Bqr| zLR+KKX=b|}?g_B3o=6P7>B&mTv6C_2bfAz$=4bKkO@Q=)O=JHy;bv=5o6ykR&^#4% z8qI7unuTL0s5){5l|0d?4z;CP~uR8GdMIE0?@c zBDq9h`y5vB1yZjZ(47VwcXQ>iyI%t*$`x#)qC znnsV*5Zvbe&kE28o0>!FI0u%*)% zCKCTELfF2xauX=?g{=L)zM}ZGe`qiCa;l7aS>do1y2Bf^$~)eCwwP&-&1WcDy2x6J z&Jx8vJI843WwcQ+zkmN;5CM`q&R0{;MQ2T03ZPl-HT2sF`~=MY`vd?=XexP0$lJs4 zbv_W9{T?=e$DtQ%A9xa9_bst>f*&tb0z>*dK}A*Dt{T9{xU4>@CJXZf54MnwAr=)* z4v-iPV(I-K&hTt`epP)fD-{T5vp_qtI>^mRw;eL74^u(4Qz_$U{a!=%1k)) zHTYvxaYM#Il7YK)w&T;ZcdqM3EHbZgiI?MJhSHVNDbkHRxFdO%C;e4fw7WHDOk3^f z8b{}$l2f!-W!IZ+PVps*9}r@!z**B_XV*^Mk0*E@z4eA<5JqAHJhSz2&zD*_>4n?2 zZ{vOBx8r)XC=K$cTRW^Qs#yz0ZNDZ2$fI$3RAeb}cX>PgeDmFj6=^^#;zKtpW zN=y^DngR)RxG&W{9Df`_gsVG4h~l^YPqNeBzL!1t1!^?_+*39BD@~7Z;3`q;awbAB zM(G}3jJ__5Q&~5jktdUI7-sV&b$vV|o%CryQcUQbFE+h3uzsi{l2J|Npzdvz>`{Qg zfK1Lgso}4~ol8cnl$sbIST}s8VSkoZvDf&aywFm5uplx@Fg(uHeMx?gvqvT6gGG)y zP`U1_tgg-ol+YZ$ltYCgG9acJPyWmE|JLV|GedA`l3PO&`&3BAN9vv5B?oyR+zs$v z$jU#;EmCDxS-G48CGnii)Xb%^ue|da*_WI3MSXWGCT%zUCjZy6dHldjr3L;PDxDaX z#s&Xd;|gvUte;N+xP&55#I51NL4V)xoCV*UO*Af^eDydUk}82;e1vON1n$4x@Hx_a zDVQbg${D!M1&>z6)1fi-TlAbZtE#pAFXQ?c%($9tk@T7dZg&J{9$a;J7ghu=Uy*k? zc0hMp&kfc?Va220@S#OP5LlShYcRtdb)>up*Mbtc)I65`%7?1%CsJdtGAEz9@Pb+1iej^;QWIi>u~kPVOhmyWUNH&vk932T>q9*nV^h?n}y%(T5NHc z_G6#RzWp8Y;L~%DX*CSiB}BL880^(g?bps#YXqwcVxVLV`s5XcaK#9jb0^n~9dV@y z$?y&^MfL*hhw?v^A$0qovdZ`qYs%0E%ScU&f7Fa%YFFjlb^D9H4K!E%7IiC&gK2DA zrpo3^EscPr)qAQrg0dg!Y~rC?>{$|p?ynUAO4e(qP8olBN-C@rXTp;NbV7LYFiq81 zn)y07Lg~yl5xCPElLOz(_*Y?jSGe}5`D2`V4Rk|w^dHuwc3xGdX_$O~=Uj1Lu@QbW zhT3@kZ?F+|8u8$6%7b&OpHGIUjAf`x&7Axr^yI9jmx=5QZtV!$Hx509H@-5Pn-S)O zqhg6oi?@ydjXO*uJ0H*5+=keep_}h)&k5UFQ)YRrCgx+EWc*I%wbYcO8nC>RR--Ogw;)n;PZ8QVI_SLkX%C3U-o5EyYFSEZ z3GfpNykAt$r)EPpUz>EO60kfiCa-D z&P}2zqGXibN4YJ3m+ml}4_JVC=%vfW)Wy@f`dcd&T1x*qKv?IZ-h4F-E*aU&&|A*Vq&4#JXmrVAG< z_YA1}ufiU!MI{6HVT2Shg?X9l{Ly5}uw~!pT70(a@gIpo*+kf%EtE*I1bPx8Y9_v8 zH73SaO)?Y|I5bUt5wCGv(JaGqIwINtUXy^*s z-18%eBWh%crv9TptGjE^wt->%Gm4)Ga?f+o5*Pxyn_Kj{m7|q_f&-71G2xgz*Afre zx*}A-E*k@mG>Q==2wbLs`5ebVy{eF5et!?!qJhOJ@oRDB-aQYMSFYd$sbIuKM^ z`k{YzQb4K<_{Q%>_^-AAJ& zl#th0Vtav!6Wbe#=lx?=tP*0=Q=U0<_lF{d_J^nmTa`TKgo*^iw^|$-tsU0p95X#D zKhR=+ej8cRL=q3>eQqDNH&Iu%?D64YfXmT-q?{XTU*-oV!U4QC&@4OZOuWilAn)>F zOnxs-MK68KQJX+P|k;~FuOos&A}H}aR3A~=o;VfZBj`I$X6c$Der35?c=xik z?{p||YjgHY=LGuHBO_K^7H#Q{9?n#UL=wBE#HoivsVHHw(jQ@@{j_l%T}g@sndRh^ps*0MHEP_W~CLQY60=hAJ_PXs;l+fP*N?D)j#v?2SjpP48FZ?>;E z4(gcR^Ie}jZf@AKPp#Vr?L^ELSKjoS-j`3}dQH$0AeUvU+aNp8;ke|>*r#C|lJU!CQ$=ek~MTFuct7FWI}y)W6m!wY#K=VB$XF3^KatIuS2189E^ zIRxn8E5P11XeJrNzGA+*`q3~Q)WrV!V)JqIE4p_ioL2iWoDqIJSJ>5%7tl(!nGa6t zv|*mJ?`SS3qDNizGUyw+HDfb`WZ(dgEL%=PKbmq-fAT((Qc+n#N^*3J9w0-~N(#qV z_dNqB6R55#Q|>*Q5yA*MAU$D|JBK9W&I@OBAx-!P9*tvV*6{nDpva_l)8PnsdNW$>)yd#Wxu#R4Ce-WZX zf;^9~IFERmPe{QgORGcEa+F+A`1kK$01>nu6~vd%DKz*>=H}V=N+hLo@Fiv<=Xxhc zgSJCF=VB0htaNj{6RjWc^7pGP(#_nYbu*Fc?YQetu}D7QhyPsxg#Y92kg&OI_hr{L zlB{V=Iw;jl8O|gD#7JV060W*)7WIm0!oGf9_r83uOs44q3vL%2U2~-CYnf&^X{RTz z#f}=Z7K#1GiBK!jJnBCt0%yV)t|8YPQYq@Q?0g9Cjz(znvMM;x7c~nS!EghA5trk( z>{=B;Q!ChiP~M5Hk%xu#LFaQ=P595pZgRcSU9^U?^%F5^@7@a~+?!^YH2l z1onPu+fs^+7<%W;?I_kthAHVrei@@s28!{F{tr!O85PwVt?@Gq5>i7q3?QI?qzpO+ zK|)%TkS+nGV`vx}6b29k=`NKLq`SM3p`^RJ;U4c@_Y^Iw7D!s41FbU%EvS#w=8giMgITkJ<5b1g- z!E)CA>0n(gN_uP^G1gV_;`+QqC(4HU_w;qt^cChw$K(|B3KYJ_2{C}p6L}>M8?|C1#Zp5cK*(K^Q$v3)oj&D>-`Y4Nd zQCLv5B(&oT3>P!Z*k|@Oq*z8acUR0j`EY;W*%Q$fCXF74lHq`SMV(36;8t74NYWpf zZgA@9XyVex(xXe?r1uwIE(csL1S%Rx)GdWBc_n=#XX@p&|Erieh2shs`$zgCd!D6f8mdn9vf=&+CUydoLn#C=&-sZZ%hEmh+|Z zf{sUV9|jq&dLeM9e``lPYQ||Ab7Y!Yaa+YlVLP>o^>N3TO;Z8bt5!tDAwp8SoD_%j zJGtWITya_a{2~?V;mxa)4LYlIK2{`!oJdpR!ERwUhTm63{YToM!;>Dm%Q7ulSx~CV z&Mazi6(VI)XT$7uyr90-37+NkD&D|!(wnF~eI+gcTi33G2SJBUp?^I`s*m~7#n0vx zso#2oyJ( zs=bbQnGNN8<)}w`tQwrWx6b+>G0Pl<7z6%x!Pv2K@2N&6-5chV;(a@hMwKP6H=S1q z8gx&Sx88z&`Ax2&oqv^bph|^tv`#PtIM`~u-BM_h@fc0|8*6ZsIN^Mi97&g^fz4q2 zqhInoKvG@~&7d%MyEB*Svybb`SHH~88SPfd5((TUAT_6 zGT-HUHCn&w2)P_35|tfIEp2|5S@^t{oks5m846CiGEbQm_D+HA9T56R*ihp~mdDS| zM~WNZKFPi|k^KRn?bERq1x*bruI*`g)29S=2;S6VB@`A3mW&uE52;8D2pVA{pdtss zF-k{=zEN@=ZVWybDrWw6%YNy_{$;87Az^6nGZYNm$yUj{)R#NN$Zk`f`vYYsRTUKH zV_mP-YP#tebrc3`H_SXndmFy-OkCcPF#nDDdyv32YUuL>`n(YRE7uA|{559UJ~?{f zackw%h*71q7fG*0#>(|y3115R@rq|M75{KPCiIPBN@y1}=b_yb%W(pSo;~j51KlT4 zNmw#bOGGrJ>3$y=&g7RnSh)NJ;i$g(pbG|GGZ7VrX{_-BO^46IEZ>4du>PvZwkSz# zikQK$DA6Bb?^9Sl$^c}hwYjEegqaU;j5uj8ylskgNXuLNzd7nohG zwGp2$q=*rsU|Xalx0ndwIH{jb!lj&wiN)qR*yLLleN`up60n;l_ph-|V&wy!)kaxD zvM9>LB#~_cbJ?HdM+z!CVuF3lUJ;KV5ng)J^wxAS>PLR13Kiz&^r=_tuf=eeevp4y zAF>EH59dDFh&Q#A-hTvJqJ^s0Vw=X?UEqdATiTh%^i@Xu5(oR{s}-!s7tFLs^ou0_ z<`*2&Z?Rw^;ZkqGYJeGxyp@+I$cO_E5T$Kc^tl@^a&wHvFm*_PwMmt`n^BmGwZ77XhIf$67R~gvzerB7_E5^C_Vq`WOX4= zZV%?hFGV!-Vo`_R$ke`szIK|s-FFsw-g)*XoAdT~FemSJq*#E~<;TCY=eox|9v#46 z^*D&vi^Y*E%5)*Z&dXGW#`z{Y(jTQq^z`Y|jWeIUv&FjIx}$;P7U~AR`dyb|#ed44 z``>vea|RBaDTw!%M4R9ar4aI}bJX`LdT)ZhOtW*r>WXtjVkb1Z<8rhqqAspEV5AX* zIes1yoJtF}o27a_ZO_EaY`g3SJ_W<$*Pm*z;khGB)sk%g$so4lk21sFY8TdM?|hMZ z(&LopXg0Que~4IbMl1{njUIh>Xmf`FC=@OLndjpL+r!ilqL%m!Ja^4(nV->`Y5U9O zf8;8~%;vFMqu?M|L8FQXSu;+^~S#W&sLBftJ)StN&U1;u>KjMN~UOc`9v=toPM9;Y01j6Ei<$cQm8bdM zU$R5^o0}F$>6;ZGHUbMzD1}o+xA=HqMGMpjH6?}av$;uQ zr$q;^zif1I-Yxa)sPOi9`&X8fdw4}``c0#Q{_Jf^Q_u#vDC>wZ7w0POEF7(Pq*Q&AtI9r+oNhx-5|4Q*!60TGv+_Z}I`O1bef>)MwT2k6aL)3N>5(<< z_1aTCP}L5Z5&2*nb`HU$em%XfyV}y$_3CP^;-fWF&bH|^pV)rXaGt%=IgP*dH-?2E zJbDlprQ5PM2Vi@%7L{1-!z)hMUtwMGb&&BWchguC{VQJ-^lESB*&Hm%`iLJ6!05I_ zbIX?O=ve&-Hb_&3liid7lcdXq+gES?p9OGr`Ktlz2Nnt=m?RZK1-!Umj$<0KrFrMy zJW=jAq4zFj;i_luHRyH9$^YkdiYsnnCqjsw&j)OrPTQ$6->l1xf_li|@wSDeXD{`l z7p^0VYKq8LpEFlq!TL4vy|82wRqsuWm_tpOfsmbPuBtOJk~dJ;nmc4`A#&p%P~6W@ zAX6q*izKIEXNuC45%9;3J^GX=lxt3-IMrDOLM_`@Q+j{?ZSO0HqD>Qn^*VkAD_GVZ zDdIC>_P|~{94D1fa7^CZsm|#|w{~+4W<@QNxfu4y?CqIxtDVbdWo_3I4$Ae%!jZ_S zM2~%zRc+dGdJO%4n&j5ylvtFqa{!{4YdO2^Z;8LnLuzRzL{TAMc5T~>Iz|p zZ8L}@i)^q-bZNVv+Nq1?S#CH)zyv#5o-c{02zr&Rj0K2l#AhuJf9 zHnSpidfoWJV^?OT9DIrVnE045#tLN(<0OEkL>EO~T3>u9X`~xQFs?GYcCSHn zV5Obc$Kmo1I9VBcv<5}BfvRGOE`vy~c|5<+jL)40gMke!fn(2mtvHc4Ssn%*wRbEM>EaryiG9pNV{66~oK$F{gE+ zg}(QU^nerixRP{2bt{P zPV|U$jDgk#J*fV)wg`LE-r~o$D3@P&$B=Wd-ajv_^w~x~g$NjN(e*NjeAtNR)g9%} zlg^3k-*`A-u^D#+NpCUjaawVL_229IkV!H}fmB-i+jW#)oqXlDMFU{J|IWtGkNiY1 zkX@n5gi1EKG829-629{wAcWX4(obu%CXyjD7e7YcrXTImgqYq3j?txg4;u3=BIjK$K+T*8uU8o3% zWfO7K9b||SMTQl=Zz~Dna6Vst_SCbe$t+7%1;?&Zx44VMVAGi9C}~X%7=z*Z_QAo2 ziy4=@$G>S&-*og83{w(SaFw`%gmi4*fe7u!#nmSo`II^HnJ`60a?^-v; z@#`%{J#4E*oE+U%-qsp8)c^|BKj*kc*up^^;laJSF$E#2S;DCivJpP{dl zS%-&pM_Tdhh}Tqq6qf>&V)pxgqWoK#&|-ro3r`Q2(bT^wGU3@xod+%ad3KyksRSZC ztNhe99|pb({J!`gnL#RLX(Ql$dR%eb(dedkIg`O?Wi@$9 zo`{Pd3#4Qhem`U`#6$g){GmU%i9Z$weVv^eW zx8C^gU)0-M172h$A$kkS58GTG4%X_smU$z>qPrKZ zDq69Z<7pVilU|=LZ3Sgl2o>Yy{-O2=Q((K8JVl5Nb!yjl zn_HyV28u)Q;J=zXt%k?fI6VAJD|>0_pRm$i9~I$1C1kMlyOkUsCy-cc(FR4kB&Ctp zV^**<*q&fuz1&9diEA8Aqvv7dr4g9|>8RO^IUZ`Rb4=}P*(A>Z{_q^t9rqd8ki~}1 zYLH{#X09%H*;V=C+DDJZPS;F#rNO*HwQDxF!yrS^Mncl;hL2vX10KRl!D3dv5cZ%0 zeV}W*>^tvTxB)Oi=m2;LDF6#tH--X?<;$$cW4w@15U;}2%^A}1b;SKedP~tG@5Os5 zg5G2l7MGf#AJ0Gj zM2+6J{$G-(rIFzNy_3R?0Oo6>#8iWi$x~~dBQdG5KBpGG0`>FvctanCBHb%}K+b9H zljmTjpx)4rq+!+fm7K(3MjQk*8z@B1^-nVhD^7qbQ!2$cytqM}&`it^mr*xxogfyh zOY_-03w!!lu7N%T57}a*t0I}g{YVbG`(gZQ$jBsHEt^=<3gsF)nt(M#cqu<7{fl5K z$U0QLp3+QkuP9?9R|yeTOi zq`|K^r*5w2=@x?-m<*FYMsi*4f2+LorymgTW*)Rjw(DrA45tgJb*yQ{w@@KJBNYj6 z70$C|Pj>!GoXE;4Lu!ulLQQA0GB_n4pBC7!LpWba5Z{&nL0>w(9(S&vq6VRmFm~v& zwr350{DfKj(dBhd;!y&oxVkDZW%a+WpOFtLmg&H|kVZ?XbYrzCis=<`@*kH2@tRwT zWqTn{=XjIa-C=nEI1|Pib80bP9G5=xTfWY0tCg}9LOSz)sb8ns~ z`~i*CYxP5^urd)UQG;?=4LjT~b$}kl653EC^|+Ex@qGbHu^Z7!i7Fih@O_xhyQ@Th z9}K(5Nrc*@@_w_S_Nr9WpW^9La7at!hb<}opP21j@nT2vnRPV5(QpO&V%33pv?I>= z$Iw#}2{A%K$Mf<9D&wmLf%HL(MnFZG>G>UUlcoC4T}tdu%wGK z8|uI@BIh{WlJ_i=qS09Fc3{sADkrTN5b)KK5e+|Oq~1&J$wUN`i#&ZTE(i&|>AB&f z_1$vz(%-D_xoIK`9}oIkz8|zvF4F7D`X3NFKH6)wVlYQizy%!m-(7QV{95yoV=}x0 z%Z|(Y)^1>))}&Q&ED0;UbvR3Ibh6BXT-4cGW@%|@qsb3y)rn@bvL~ND0jWc>S`~4i z!)l2Od`P9pdTzY~3if_&ztNWmOkSd=S~ii@^6m<RQf zbLbv2a(BIrTh7*Da#_!-|Mm7`!I*E5|u$Z4;wc{m%_)p{*v zp<>3XL#>JgTs=M*LKu?lu>WwV1HhKHFM$T8|0(CLHCeX@Jt-0-$3jQ!p{TMsZcgmz zdl9%7AG~Db?RyyU_Wk7Xm|!B&D?Z=Wv!RiuWX~Z+!!|`KrWbk#`F$Bvwg*DNC>1wUWDU_W7m2rK{W%4*$cP#m%mu5mT&6*GUnUfeIa3z&< zf`RbrrH$^Yz;oP+WQmu!_^Brwy3{wn>~wMG+8rYwUA>-$q4tH-YbjfZH{_`AiN38! zATjVXCdB?`H2sn&ZR!$a-qVmXYshH|R?Lp^`Si*NmP29g)j>XUU)F2{bAGg9iIj}< z;k$%Y(2YJJ%;W$JCZO=kli7$2GvQubcvI9CeaNVbNZ{dZ7HAg7Q*X|JAi#2#q zeCvBU@EG%JS)|u}S8QBwTwni<>MxI|sX|D#I7%54M?p-EF=D|r4!|;lk?b3}l^g^_ zfWbDP8QL6xC;ly}bPB_W3A7OoE#b2j3yl&P!}+)^rCI6mJ)^cmR5j#IxBAZMXnShy zucz(A&-%@D#+d<%KrPX*4>Q*!7w}LY z?qf9#{f6Lr058Xdrj$rUqi(7p`<$+`Mz~o)|VAZn2=D zZ4}{WojEIvDt>A>d0F_V&1){PSaha7ck~kzvE3#L)t1DOo>fNuT~d2Rl3Xyd-bU43?$Fq6)TK9 zg<-g+`QoERSy~06lRC|UDCWc2$XhQ$pye8oL}=Q@D} zoA3ynzHOwGHPQzB`5`D8Pgsbe%2-&zvs@1g0Z{{sJkF6E1rteY2yFeJr(dw9v%Gxm zYm$bC4n`dl0br}Xr5^0ZS5spMbr!twhYOU+u+Hb&&ifvEnUL;-D! zs34$|$9~n*@8OD}T}nd7LWf_6`}^Pit;CVoBo;H>#K?`3BsrB|PRgF~4n-`z3)5C1 zs*F0X&Qi3MB(K*KjtS<@lq|hAH)=;a@0M?m4<|P^9~GsnZr8o+Ha}JkqG05(5LpAGnkdXR4{NiCNHxUmge$8_qQ^Rrd}iCAtvXIQrNdErFNFkm+sph^W)!?8g1}uQOfX# z=1Pqm`e8+`eIZpu`qXNy63|?!RkqEaEY1CB>nnlUgAWl1TM8$8Xb(7w%gwO}! z5#&#LR-rsbt@A8dyAlBW@nB41H7)$RjD(M$w<)pd$57JN>$T8(PIR>}yEd56#>A~` zGOiyPh7kabJ)w+Jb$C1BNmINCV_PZL(%jm*2QdnN2@wnSrZ}Qij)r{K88@9+$ zfY-bY^%n2roZU}7#l<-%#kMT2Ad>~?8U6fYqx$HY=BYA+&^L6$(d>r4Ea zEATeg@npAA`J3Qn*s@+4_UtyQjtyKLQTIn2?uutYRQU^p7>yH$QU-N@Nn+PJE}40q zT&)!E(%|)W_{gD&`u7zCXtBtee9cHeBu>cd^{16-E*C8aFIgun(yt8*2XJbFemj?MD;k3ULp(N_cD z1;4)G>}4YEecQKid`^A$`m#{6<_G3pGqu08b!IB#Ja=b}LNGBtE}dga@~wPav~H0} z?1jqFS#-ygNSrlbC0|?r3>PCS|ABg%9E` z=Jyjz{N0?pSi;QZMFJxMkk-A+Pk(<#8aw(}&UT3ZqwHE0XE_XiDyU)ltGgJ{%GWgZ z@$Wh7xsQ)xcthR$#b(LhdY7_&ntd($<&7q#{7CT*WALlqhz|SD3D=cj8nt%7%y_iNu}f)q_LA zu_J3BxYl&$c-6<_&yaQnTL!?ZNzf_-nhHX|E*fB$;|J82o+nRTG}%BtbP(g62<8br zdFr};f7~|g#x;f9sPw*2oppowBz?$Il!j_7o`wJfJ|E zt=uNZ@JC2!((-3ZyQMfL$lu_DzT#^~kJAIyQ;E)OgB%bD__$u#?_sAl{>Lp>i#toY zYB@A@@zP@?VkjRU)s&(aQ`Rxfx4M!}gavnddJ(Lbv$9LNp_cmfV{CJ{2+%CBN-2DT zZDXq5q@n1qOpFD$KdnaqsL;HS57eS8NNl+Mko~uvsHKN{RG>+rVz<>?baEiNR*5(| zYfkx$*-(kuXYTVx=dO5)?8&=!$@7lrW8PG^2i%X=4JCSu_Ba?!%%#Q@*GM?0PF~|g z1~NFN%y_J3NdI@SJAtltnPK8g!DeRcB|;3vo5vEoro?50WBr%*K$fw(h8S1!*&EGz zYBLrrQp?qvEwMPu3`T8OPDmTS)=?l z&N{-tZg3`>=uGYIu?xV8w#;X~>~a!uCEkYt(k}gg;^*5J!3hPk8t?ZPO1U5MRo$YN zF;B@H$;3ohTtmtoXlc#2G-M}SdWy2PK5EEzR`;&AuvPywQH?w_2o(-d_UjVQ=+dtw zuFDM^Hhxh}gnE&%9ujDPk3hZ4l)HQNx>+KkYo?tC&!u}eUj!iS=jbh6*Vec)0Ya}h+P2uKHnJ)HZq!%F_q5b-^Xl6bhwP|n=h{eBrIb=c z_Hf9GGk!RR7^VGyxb0j!F|CU$OR&btG3foWf$a_o&VkKCb^w0@@PDPM?OeK~|J2{l zC7c2NzoPWc+5SN|G7m0=OiE*^$sTiP!=lPc!gj{LcXEk9;LCf9>*P|Yj4;wVWr^`W zSXgWKao;f>g9JSX(Pg1NA!nieb@kvivxKc>WcfxCHVNzCfF94|Ix5C^)=#f#d2^r zgz4=BAFePw&jlZbr=4XhT-dS`nxEM7-!SM@rP~y^u~U>A`S$e zMlH03Z5X_e>n=zwA)$hj!4tnjs8Cf!t+NK(0!)4xkWgt;tK-yyv4ogWjZjlq=0T9>0!2iF=tX_Xru&S1gN zNk}Vv$xhOl=#lVdM|j=Ef|2%SJp$W9a_U@L(0e<;@Rj6olTQ6x%Ln2IC$%F>s(({b zeG9dt&ysDXUWD5r-&OXUhzvaQSnlenNwIhAee9Z2A+)|p{@Fwo^MJQ>sdOgNu)h$w`im`#xMfVYoklQRUB9@1 z$tg7LOrXscz>S%ewb5c-t#%bE`RrF+Pj9m8v}{t##xz`+%U8+?&~0K1HUAEKsWCSj zwWEsG(6H;I8P3IiI-|Ac<6K{&{w?a$q7R00u>RpN$D+iu8Qltg+i>v~3H3^wUn%}1 zkt*klL#5s?aG0yzV+&Q8AM|z&yBeC$+PpHT|Cw9g@MtgS+taV5>yum@B;8MzWBWNn zn!f~bFFY6>G z-%s(*FJMe=mnQUQz)eXv>`kuTt%JNd{S()^BZ6-rFt6TtEx{22YORrv1R!~@T#k<> zC~k_B#s~WyU~ta|MwoBbDJ?B6IBrUKf-ilEqpMgWD7$h212)UMA4)73woK`wRUE z!V^m&TJfJQb|cdf1ib9pitv)y0<~bs4(2Gd`f6EISz#%Iz8g$zJHf7`T&ruAO4urj zK|*H0Md9R^cJtBi0u5Q$;))6g3)_x@4 zmuHDhjDPTXV>pqbZ9JT_CT-~;f6qk9Ew!vR10O8qR*p+Isv|6ZdS?bslcg)?Cztnt2jZjO1?R4wjJ)&BYbuHCZsd9~Fwmw$8Q+ zNr0HyFHW!bn>NFyGH@5JKMYAzFg%}N9^XZ6lSXc?xBD=*x@|QVIJ><D?NGxl50+cf`_X? z!~F%Wyiq6gGk%3j@qMF*t_U~3CV$kump+G&^rtqL4k(6;m&41(551QPOwQY`L)cbl z&|)ez&A5)vT@G#bJ;ScmeUFy(rXBLai=38WVq|7IMRPy2Vayk))T1GDNB!CZcS(}o zeBPIXdrH0spP1tBO+qTkRc6H~UUANNt;JOnxo|Srp8v`4J$fk0QW>-B!`*7^>edH&ZNhbJ#MFR^KJ_S*NUDBOfI+PA~A`kdGX zvoJ5jo8OU{Tu@11a_f|g^(loe)r82Vl)6;(Lz6{!r}&tvGFn_`bl^TxR*k2-g0m^W zi=G9|s1xMFP^Q*s?rBYS@Zr2}=oAQShS&nrW}dBxi`;cQQQ3k1sA*(|h5j_N*`F+u zFIPF-`2Kq$)sp$sUuu7vtY4;f6-HWXw+t9bE71^4)1ULO<%Os454s$c81t}|DYsfi z<+pO~From$dsT$BY=q!8&D>+z58Lqo!iqiBYC#BW^2^Cl!-FxJy(O!6~( zVQ0-iz}M>w!?jgK5g!Njptm<2xB$%VV%nEe{CXiWyVY)Dnm72R?(;I~Sa0tqh}2Zn)0XK@ReU#84RbR;_x&hT zNDB)b-V0q4pX}Ql!j;DqqEe23?fzo46+PJBF=Go*C7#0<_S#%_ zS)1`#Ha9d+5O^@v+F9NB#e|fnlzPb__1$+ImOA@jyZdTVBG$UO)-_Z7EbD?5np&bf zvJ}kbwOPbn4CMZ0e^IQpIYK;G=b0BEwj#3Ac6#oEN8d}73_}^=v)x;A;Mq?uVrU>E z#&!Z)VWdI8g?FWfOFi#I5#Uq8?jk(Fn%ib;DzRT(#rzsy0Pv!*4E@;Q9j28K8P)wD zR%AUa>&+CEHIBe%1_|W-U&O*GBgBl$t2($o2}7{q+uF}_H!)>zR)xX(6JT#u@Yz_y1PF`Ulk#1BnM<({$#UDkWXR5 zV#O8;_eokrSguKjRGLu$*x!>56Ae<5UU6Zuy#h?s(_8(qG#2TkK2L3wyt$V$1HP1r zu|H#{GR1)ez=xSriovFXrJpIejReYJ#?cFc7UyAR4E0e=#6H;B+Ys0=W&I* zI22vO=-}wKj`6G1*qsqspFF4tXJG0!r}T8`5%S8uo@`iZCfEpwtWT7$GGYMPGZZ)7W3-PF97t#9z+@S=ZK{fZC#NA#PD5*49*3xU_Wt|~-j>nv_+a?DR)scGWm04r zHsLh+3%0!=X)T{%ZTD^efteCt&+?e!74{$yaPm0sFE%1Qa@Iuj%T@101cu~^DMUm* zD2p1yzKOxXO(aCT5QXQ>(r#-v z8$Wk{{!o4F?Olm0Lj6Xo5BG-qh%{W3;>vG6h64)I_jA}XOx}YKa{~*cACAB-b@pGE z`kQ|Gex(mzvU6 zop5rjH3!BG7om58{Z?{qd5u0|-q!_JBcq1GQGW#O16f4dsw-Ad8kOXKAGA(ZK+!WU zj$JMcHln8uR8M;L9R~Vw)wCbs16aA^huNQRi9&beF z5O`W&L+L5}3B%%O{Y5m(y7mPw;p6VpdpZ@|)#B#MQN8Q0IpjK?Ka6)d+YlhbO^7U# zB*Z$-kMRx26Pv15^5z6G=i(p;nIQ!~BDL)*GL;2W ztsf3p9mH~aLQ8UeyuQlhXN{FP;mRDnPq7|`b6Zs3?HyGnt(xcr7r zSnO{l)I_?)ocqb3)f$wa1qTANX42o6mE*@;G!5WSb#7fQU$wB|V2i@{gOqBLbYGvYSOc zvc*Nd;hI{{b`zr7_?)|N@v*d9F0v&+{lNZbbTn8$M9R@0_15jiX30tJ{`jfU!6GLs z8gPF%z-K|7GlK-gn1mC1Eo+gfuSn;RwX5JUoZp{riq6wxvO#r{w|OXqs_O@9Y5s{+ z()Kk$DB&%8$xu-$*HUNJU#@N=s}&b|RI_Z4`Mrzv?26(3X_=RFV%&wkJiWGY$<>t) zfb@6t?W5iXA2%-f&bsSx2~po#PfW>98Zap+iL;^d9{|kgt3J^T6EZXTOG*1Q&nn2! z{&5Uop{ zQ0&)t$l*POF!wwjW(Sq7J&)M47ZpA@YT}H``YLx=LhNWAC!(Xx zz{Mcupr^Gd7Y`N6s6+M=cm;{Tc8~Z*^2L-sT?B}Bm0zwVfH#+NIp;>cze$jeA3c`3 z)h^(zm;TiFJ!_ly%{A`r_g}pk#E9*QUEK9_v{|2vrgNuP{3CZ<8CRF5g^lTg271~P zKMVxF;_RW|`^!$$ajd|~CC416fc3NbD+^zV*&AtJ@txYU+xg#jzuAV+P4;XlXG0o{d%Ca?>q+$eh^Zo9E zR@ao|nyZqAE!!=he%=%NU4?!4R@9jWFPwe$E%w7^E)ivevBt@7ux)3W6b6PomH}%g zo4E?R^qxsa@h9K26fQx-pSQHZXbxTHOA3pijGR3{gzgyGnmBfVCcXSldiJFGS?%jL zqjetRz|}Y6=Tl>oj%fLI)jE!8<*)7D4CX_zlWKFe;GM#kNOK zl{Vp|gg0vEt9e)SbM`YkT9?pLumb)d~Z5a(h|DAi)I}sN} zPj#&1W(7Kh#irTaa2Ot2kj3PSW@K#@D6VR@ls!g)pEgov4f z6qt-2-EbsQMUK}E+5y^ZLst6BFLQdMN9C9S7CzsjWK`JG%YliSjqSyE9+yfuh@l`6 zQ`3!`$h-G7OQMaOPp9st{}pHJ|BADWknIiu@GrgS?asK}AgS9S^G*HnA@jA*Tu80w zoSDZ}^asO&2LO!iY{{#;xyIrNbMuhylXNnakf?ZcT!#zU+p9Lv6sbHg`_hlJ%@3st zdu7m#zk`h^u{WxO68ODjY`4UMdp&?oenLv`my_B1r;wKJN9_^NkP?uckblPB=&1g9 zHQ^xmrhEhsT*)G#QqJh+CUjNA#znrP?@GIt{EN|k87BNB^ogU!A5NN;1?a~OQ^~`X z;{ZE%@%je7gxj69zlo{WM{FCv1(yp>yHu)Zy#$e%ttXo~KMQtGkwcZE&BABT&m)Ee zcNV0+jxu>UM@KKsXao=SHx}np-QB}ma0xs=$W~YGy(hXdKgt+ijAy}=*D#hdpdO0p zS6d$m*y)Pi%2S{dQP5-%dYoc^~^^8_xqnc3#1?Z(a6y>BV{#ng_pOx|*ca<3=YEnf;pD zJ3m(a%;n);khR84T-bmNf7&ljbnn1$MtFa7l+bWU?=mLqb%+SQs3Kazk)z-?tbOK$ zpO6_BY6%=eQATnUMh`^cZZ{m45P8c6={`@SW-~4-^xKnPQZd_>bQ4z@38jlhRP3r1j3N!t7{_$a(6sau}Uk>)$H;6xOG zPeqyB{2o#w4m9g8g~y|yXfK`0r@J!)q9-HGofpNTE9)5Jyzm*BX}t;iJm1@_!5n7K zs{&_74|Z=AIgQ(!J&A9r9uj+W-Q4yiEA6MGE=QY=z1zubgDMWl9`JVWW0~l^?UP0` z0a#%^L(;^}PSroNpk3zr$BBB!uK2}X;IZ^ zkx}!CHod1or9&$2Ee03~!ly}`!F?1Bc*6LMYkHq6;-aPhIHWM>f9YM1g1hl(f$~!<6w@pgoYwA!rrpT(c1@k&f8uSJB4#TJf#Hg-Ip-eeCZe5~AM}{}u2+eB;fcpbFJWrAtM{Z|#%K#&IQ04N=KuKYc3**- zG6-F>Pws`;UsM1WZqEF6LCb2%_LtD;9I9p zN+;wAPE)r?(WszrEdW29D*>d{WAFg~B_sZcaY{f2yNELF+f#BJc(yJUjG-Q|40R(N zx#x_X|BY~=f@hobaPcX!^}?UB0OPWV2*Y4^Re+vr=nQ{4+-N~BZdq$F4z zm^^s){i}vMNIfD$rAv4?vqnfJGb@1Ly(d`e2KkHPG$>3|QhD;~GOKd$%f%Jr7^?zL zJsumm6=C-JVkPH`J}Ey!4RHW#jS=3g`W`~-#?C|#ymC~%eKgaxcwKSrlB(IlStCXO zLn7L^fVUkN7E&cIh5Gm~m;VLX|1Fr#IAaA;89jTucAD6-tUHr`ZRtT@;E?!RH1tk} zNtXL^-z=N*=J{UM&&j*0{LEGLx!65!8|55XYtY>FFM*;*3DpH7MZ6enyS&!+qiaG8xT?W%=O0E|8+rPx>Z%&L z5H`U8T#k&)^`xt}_rLh;TVwrqswo>%@#x`XZn_j_R|zSG)p6?wX4J*gEUs3GetFf2 zXbW_$%5+}-cRHnZ>aDF}>nJ)A(tmLyI!pyc_=lvVE2+#rIhr@s_$IV(ZTLz&fM~f( zKi(l!q`J4}(;EhCK@&+pK)!DVTLylfEhUY7?ZD8A#1SrKF2QApXkk)I74N9u`cK3| z9TzTpJd4f|d?At1m>u&rN`VnMrq8`@3+dGWueYJ@+HB}KAIrY6t&KRNwoarrj6P3` zta$DjxV}zHj=!cj<}mW$<67Lv=lFAdo8j%Sn2l8CTl$S=lABm8wZE0@mWuyONPJ`e z&Zk%Cdm&O*tlut-#6hSwwQIX)fIR z+#R$5P?ul%F~YWz8DZe@587p^)<6*p1uN9T(LKtkU#@DP081Jk$}b2XT#0e)u&Wgd z^5Y;O<~vdIRu3;9Q-SZ{p7l<0`HS{C(g0JM()ni9|Dox-quKo5_n$}zwPLh3L9NzS z)Q(Y#7DcI=wbdwU&lp8pn^HAP6}4*bReMuvuiCX@3o(Ar`|~}&lYf$ba&k^y&wbyo z>$>jiBFjdvytA+gXS89-paI5Jd(v%fHsWz>{hb%}Xz2Qh6Omaf z3MzG54V6_!{_m)McnW{*pWZ#aLs1Vc!A|0>h9n0#5X#dEcOZF@``$N4?%$_OA=wHq znSF@Ny!uY4^K7%xEt%P1^hfM^-1MDiU3Xqdaa)MYO?Z5ebdASxZzr+mG~t{BPwG$n zpKeTWg%>5;`{G*CK36kMaok{y`K|WT95nNFzNhz(m&r_VxK5U-?uU{cTi*Nd>-sTc zLn^>hB@vt$1U*5eWc{osd7HlSlHl(02;6d`b`}Ic zwkj(T)LV4MeSGGh2Zu>Md0YcSo|YEnTBu(V2!n^t(@F#^kBzO15?8_|MuTl!6|Z08 zA@?qtF`Aew0ZmE_)HjyNuE`g9W&zwROdiaiu>G`}}ux1-sJGRxzf@Bta%+CAB*= zsKUamw2_oBw&0&ntJGAI$-#wrp@9ipb-3un+`Ob`n17<__fP4Z#9Mf;>DUx*TU-L%ln`0OS6ZgfWc!Tv7NZ_4V@gNrg@GUJ?L_IKX?7mOp#_ zf<^J!3@c~ZNts-3ny~PN@&pOWc_(&I#MxN+zahV0(6s3z zSf2hc^jGJ3!1^erD9|W_1{P(Oc=6muk=~MsvIjnxZ`tS=P}i2vPl$dfnyRnAMn-6v z+<@~opzUM^MnB;yzv6GcZsf60E}?ErP&76Do1(R(8o@BdTmAl6`256*TFd+;`sWHLebc zX?#L(e`#8xW^NLV5eCQO{1#c!x)f#^Z=S2)ix+A>o2?2ydvit?rKYlR0V^?@-``oK zSyK&(5WyMC+xH~NYWtJwqSD{gR8%-gxSu&@%+?A>9gOoB#xTy8llPCl!$|rh!ei^% zt`%d()3_4eh;hQ&DZB5fp|k5BuWvk3z|=nqdPj(aNU`EAEwMW)Zsrizwk2pEz9JO- zMi$YbqYGgWr}E~~_SWRZhD%t>8?JMr+0LD2!;8%%?e6{Y6vGBQlx-u6>349IMK zbFjd>)cu64EyJk|9`f_p5hVi7OZoUNOrOV|O=L=~d&vgJ^`mpVJwwq-#;=vT1+!TI zLj0^y<>ML~U1yon#Z_!)t+~`@=Q!&TS%gb(wA+wc-u+t3=TfX*d|YMh)GFCikg|g< zwV8(%U&JW%ei-jGhJS6I+d4Q#_wLM$22d zglgD5dh;H;-FbZQi{?1UM^NDS3;pOra*9XVd) z3tA;`V}nu#V5`p-=b#58V=MpKu3Wa|X{5Y^xLcmAitY2PWX-C2rZ6fv!CCyNq%BkS zBrDTokD-fC4d?d+%y_IF)#hk2FgsAt(qjfJ(~_ z5d#27n5vSz?z@W;_?-hX4Ffp7mqM?-WG_e+-Fm5&rSoQSob$NK`QHxWO?H7}dUq2p zLd94)-M)*PQzKBc~{y5fKED@3?I;HH^5gR5(1f^E6zWg3c1Z z7%vmKd*sOC;Q5ncKW)Rmc5Eu$`jbxyRk3##(ZQ1Z!vIiY1tGj@IFXNkWecn+2GTHcLQnvl{tv^X0 z)GP$xO1j*)VKB=02V9SZ+1M?jy^b0%3)eTNo;PP1zKcx)CLO*McMo@Oa}7!OVk%lN z8Ms|c`c;b!t=;=&;V`7#DnX7yohw?!-h!QUC^Dw&D%{3mc=bhOOtjLI=#I_w1R-H@dLc1N_8&2eaxCl> zKRzmVshBTc%Lqw|-|1hmaCz+^OwXP(sy(*$Ig+aQs-TAUsBCJ6ia{vpWNnH{y{z*c`2k26~YLq<4^Nelptb$ zFMPGWTK^U+rr{)Om2bM2id512WAj~Xxx_GeS>?Jxvom?ogEOP?A=AzJ{8F5!V80m^9yV8uCt$f zp75Mk#LjY-o6yI-h|Zwz>P{?9SaWVzHq7I-)X5nDG_6NZu(?58FT9Yy|Aj=qK+-I zqtqi|-_(!lrD-hripMPSlIEs+Bb=f-)ph#$^zMa^zSP=wuZg*;G!cZ%}vDlSccOwbxn-n?s4CL%4SYw68tY79RQph{qoBvxdj;yK^}hXbiaz z95wWa-QC)pk#!v5I)_eGzwkER;@oZs5S1Xen|&{`k&YL~&i)B=$L(t9s{`E-dCTgv zaMQln)NW!+y#lYee6j4UsYu>`O+lJR0iR=MYQ07hub&)f{ACJau`zbsUrT`=g2JThK8txs|UwA*>+5_EieDQe=(nIJyv->=^v}=S-eS8TgY$DQ~h%m zllJ^g<*~jS4_TD%#mb=Gj|Vq1G7I)=?V0bjck;93E#`dRr60cN*o@DogKq-f`eEByq8BfU z7eCb1)zLB>`R@IAit%ZWqepACK7%3}uFd`sY%^!U>D57heMOc?xvRVz!3qAfVU1p- zgWH2Wi3e0FHlo>m2N?t!HAJBgA69b!aqYTQG7{pq3oUYF03~&3+jG*Lp(=nD;HoA^ zS)y5g0?1&%>O9S4IWHWFV4ydJ=oRry^;F{t5b^>@+6WHhS2H|-4Q@?u=m~`p95AO* zjrM499-SwW{LW8RNf#qopPP^nj<>3;A-ydnoE#@VVQ01UPB_4ZjFoCVz0f|9rTrG~3iRG%-hKx00?^#MN6gDi zV`K38WYQ*kN~Ho><(WERAlQ=(L&mVg)Q$_0v~r4tkstaJWGt#&BYKKVx9X!7)q9Y$ z?ox|T3F-Vw@%*M)-5LE;IdX0cK9$u11ENkRuxb1+)3PVr&p|71jvUphl9?RpNDfDO z)c~%`>MMP_c>vkB_Yy&*@lS6Vp%%zd1B~<(J%aI=y9+F1XsX|!nI=L6m!wu&dgYN#oXvi_b_=n^L6(dqLZ?P|!oo zALrYx53YSOdA@SY+6($B^vJvk&XvWh>-mEl|4Un=7iLID^VVI6gfziPi7z^aOle+v`>Wxh zjG(~U{!e+~MaNj7Eb$f=GLN&T7rQrN9FGf+(~0vAJ67KN&bQAfRf6o~FM|-$ghA0#DFP5(({*EswlK z3$t*3;%m2N)1SBPM8KHFS^o7Ki=_HQOj9rveklvOPlf~t7KpE0%^xn@2EK3z#$iaO z4;@IKXb(*zjB;ve+;oQdoBv%rSYkpA;o1lU%~RG^X-i3Jx*P~a@7ti&tQA%pP*%g{U|b?4wg=ZH9Fp`_r8D8_0v|v*GS{Xq z)T9~sZM<~A52$AuD+=q4Bd%Y$ebO4XS1AQGm*poykFcO9mfA0SKFC5AtvW^@0TDl* zG(l0e-xXIZK90T~_R5D4`s97DhnKt|9|t;4b}>w*#RN!?xp!4T_3*IM*4t2oqq@a# z#``h`QQNWd$b-~8i=F+@B&*010AWurYiv6xocJ5@x)P3lK^#PlN8dB~pssTCOM=Gj zT+wB8ScM2ZHm19i3Z7F}23omknVGftRo~(zMCw)DWe*{ZS|QS+@5ke!mKMYXaWEow z>F-I6yt<>LXuq0|%OHlQ?J}a7z&CS(;d)=F28Fg3b#JDEN5>aLcDb7EL%)RA)5&19? z)9i!NqOqSJ7y4a;77raFuK)W(*?P zy9LVxWih_qL-PW%x`83O`GCfAs&QQ^^q|p)ubf1@gnplY-&0s1$Nn2<4muL<-661+ zd%C+OJ31Q4W39P)haZl5(n06HB|5)bo@AOfaJu3{#(ARerrZhzGwHfvN(1YQ`PA+} zgr4H?<@w2!m#gTPWO^cL5-2i>Veny9BXf-Li>TY5)PI24G)>Qvx0$C5{yT|77nu5* z5=g}M!rBRb3)*MtS37xU=GDhQ#$;w+UFeBNV#)ALV~B{$%BAI#a`I5y<2w)%l?R8v zGsv~z=$<2Rf1$F+MGfCdr5TecHxssEB9gAL88%&}bNS?K?56Sn`&o z+{f+#y#r{|HF*@wr*dO*=+)t=m2kfd8}?!L?xy~i`I{sCNwNJnBlp?wxV;~N8tQbl zG%<(uXuz7?n8*7RBEq@9+fQ*}RA@6Cmi=rdp&z1F^lx4Zk0k$DskIGt)Kb4w@Tb~B zq6`z#)K>TbgoA*J;9D$);6eV%N~6sZjw#1{3t16;Bt|&~QE?oT5aUj;7~B2%=n&WC zH#y27zxl~Aq2OjNjFHVV)zYXlF!yF9!oxPXz%aaAr*}mrA_CH00tGsQre}30b zLzBU!pw>Kjpd<2HZRnnS&V|N$OZ5Op)hGo>*bmJd;Q=xhJCH6oI=rKkJ@fygoa6c@pSOJ}&dKf;k@& zk$$r1{>B)Jm~86bCJum{px*lPC1uLe1s*8MA-Zp{^XZzh1u3NB#OPWaO6mjEzM!iu z?^!U)jur?iVjWqYOKg2taKQLwj{uO@YDLOH!ab1XZuM1sK(MLWdFMC|wL0z?ReYqc z8`a<%B^FYy(Ygc?jSK+rQsDbA2$-_Dpt|2bKAc<&E*8cn5a0uwBIpk5M>Q;l(fW&l z_X0N80;v*K7!IN}89%bz)b*zfL)sWW5Wgkiuev4YE&PE5Lp}Z ztH`w}&zYs}>Q}{d==BSGk+1H&XFaWb14K}M2CxGhcRIWO1`zC>BlVeOkAgIJ&Vr-P zWfandubt>H$jN?2nTR>Z{2ZNTk$fo^bY~9>vuO&93Ly8=m25 z?2CM>`;WEGo4-nP{~iA(UoH>8q$vYe=1TC8k{GIa0sz(u#SmW3n&GvAxqc`phI)ab z9kee{tp6aEW_btZI3-~2mVNt3WvIL3_{#$Hn0;38O~!xb=MK^m99exhIa>7RIomGM znVVM%NElAe0UuR4^hW7u%%`N19A>(R1k4?LIGcC40y(3v(*-{EHr4LtXoD4 zijdm-wQUc&BbB@foYcqt+6j`oTEIkW-t1wH{k(t(oEux1-D>PlnmX*cP>{iFo&}s=_N+1 zT{P71ak_3D9TvQn6m1FRFBs7N=#DP|BQ=;%ehU!)p^SV@0Mq$S#I0B5aA*JPSJ#r= z3G!&zExMw7QiyUgt*!{xMvqV;#D$tRU#d-fIK+L1m5UpBwv@rhR>SWtx?gnY>Wq<> zshfR`kZPhhTX5wG*g46J*w{6V3?CZmEcLrg!VaFzQzs|UDWyt#*&yEt$JF;HiP+{n z`4!#2I`Jrvw_&duTrKhVv5|vEH4>tR+s{t`HIsQer<=5~s0Ee!2kH!jj7U89fU1E6 zq1yD|B7J@Q&nM=|sMOQPK@;K@cCc?+vkQk3RLO^PFllYg`roobg}#;}^ft_# z!w`=1Bp>>zv~JhGO7g7U2zi6{jEMB~o5Rq{emvilbfMD7os*t1s2M2!$n1Cg_-eQB z^vbHh#&G6y(aOcA^@n*>jyyH#4;MunsaE}ZxjgBZ`Y#vI;UZe%F^tR@3QEZ-ba=wA zl{buav;SUeuUtVm0MheS+tDmyTmeq1+@fzF`I-M>=|s(26GqX!jg!$Vz#0Vm=s`yZQo zUb~x|q2{4dJ77-i)GN3NW*xl4H2F3*FVL8i8}R`0EK#o)!w-D#q45OGcQ$n+_bkwp z_(<-|v1_uA^kPD~!9&$#_Z>1Pe!gd&U?y#vf}5Mj_1FqT7==q^zbZk6p9-ZD=bhT+ z;Zg>HP-_%Q_IpCst4IAJj8A?Z2Nz5MlrUxhf#?^RqWUGm`dM#2aZZHQ)hGk%NMJ`m z3s4)0*S2OcXKc@rhrGq^8=WFsuvL2r+(K2Z>13kFn>lzsXD; zqHrLIHHga!%?dG9I+Eq3jkow{hF||%hV6%5$jz|(2A*OGhiKKree{u`2KStZQ*iA#WB`ygt)yGNRh=om->vA z!r=GaZIN3tBpb#MZ?kcwU>N7Gt4n5^XJ<;bU(Z;KO>%h}ZjTU@YZ%l7; zsWNQ{v!tX)l~YMe)%QPSBxD(A3+e~{Q0s8gi^CTMTI11nD7-|ORl<>Ge}h)5Sr zg0g&cWBj50)=Hiz?9a1Ar}-*qm@i>`v&dKl zFzxPMPhyl=6j)NW_r-KdCil81>u8j0l2FsSYjafDjc&7yK?E%(!irxr356SXSl6I# z!kZ)G4*l@ye>Lh%4QmdBH7wt7-os-S#{6%4thDDHZL2T6Dxre=W$EpSG6|5(Gv&%$ zeh(7mB)>i)=gLK5Cc4Snr42@f@wADXn~wLER?sQ2k>3J6Lw;?^pyz}a?jtaenO_?+ zD`zS^D(fDb4JOULmqYT_!TOv>qX)RK?Kq^H@pp~7N;pQ1Xyd>aU7dIpAg^35VyN82 zYseDN5-Euxp~NsxLhcq3W3>Q8y(Jzwn(VSyTgFDEV2FqYozMtvyky(%`62gaSzqNJ z9_P9&c|EUil3CVAKUb5$waQSZVbTrJiXUG!+Xuo2fsTcTxeX?N_?@o5m1Ys}10#esf&{d#Oj9{3`R3jClWH>&%y@S2Yu@y-j>pglwxWZ)+g;xX z*L{)|Vc7RHvWs=RlwKtEs8}|y+fgW)w=nf6aT*oG`jJy-9G6XO*Th?6wCL5#adyZc};KdDz zuxkqAlDMt7`<;BGAO2G9 zXkyTR#d0iP4MTp>EK%8h5Zxw%IOsY5gx$%ldNE3JYiu;RiLLp#JAiH7ss2%g%x32Q zjdtQb;7I3@LsOa)MUSqHPB%8)BCdvOT=x1d$C(EDN?#8xZv6|POSeAwZE+;eJch?g zX3O4}L;v3H^_iXsQV;fm_sJIR5D$Sz6Zq-4@U%R7<8C;?AllXrk)q==gM(^~n)qs;W7!C(ZS^$M{KR7^Qr;xr)hP!dz^Y)z-GQ&UOF`P=D75NX_}`)dk5q>v3$HVs!{3b6x*9 zoTElQjIYMUscySD*$Jgpb$@$;gGHg9u~q(tl325VU;f?rm=%%WMX!o2S^QT+z8WnJ zUXf5?uYx&bvNcJCXud>!uIyxgE>lQi=HJFSk~^MuRM-6u-*+j|3KqFO*E|Vf&w|0r zIpfh8&+j`a$N8)s z@sYYrU&(CsZmOnA?)>m-XdS_ zP5Itu7RCMT2^lQ*nFj;GzeHmJtGWB`Zhi8B2FiPomJdw+eH7x%22tXTSg*!|+c6!? zgTdI&m^u2JwRZ7t73PwQs^5nl`p_Szr-KToBe&e z1t!{xNFxF%cko0CyK>DcEK2t|nI3W*vPdB$h-6nKt!B}ard>(B!jq+90_mji+@baz z_{>QFovRnNWsyB`9;AOHY?|K|eLOOwX|`J3zt=}cey9_D@bo4=1xBYE4j^5vO8owzn#S2A-t!bD7I1O}D8Efo}Uc{??!~szlKN( z=8sEc1n>Zem0vHeVdB=5qL(J zi?dJ^d*C$hQw|+{EnaOrs_L%N^c|v6z3qkI+cW3if+@suw&%<<=bi^s2gt0wJh!`y zEBZOuND_IG>Xq_XF0f`Bb{fKcW;}%_NIU+$J>luVay! z3MN?Iom(_;lK7e84|{mt@CZ&;%eeaDEa+!}Q%M(77KGnj5+=uaz%TT7^7+V%^M|g- zSJ$X$S?RuN`FZKv2^g`>L=3~KmEq)H`OO(-N7ilqss;NL&#Jmbs>(AdK&5qozbLs!#G4FeuM`?l(rbH7X37ipBzDiU0M-;1-*nwF6m$&B0IiC{|^1UtA&RE z)!hIBTQg981{Far4ciiQe2!kUBJ(+ zoCC~)zh~M*(Roicf)Fc$>D?fYu%`7_{Z{>SJ^Ij{dwEXGAR;QTz0ZoB;BnMBMYBi* z$@A!oZtts3u{!U&k67tF^=9htWLA0|A0+$wyO==WT8CRIA+75jXfnU*#SRfKV{@eHOM|S+zo4( zQt&xZv-rq~N;(b(%?v89vJzKSQQ+CBvT=#8+1dX_u~LE3kGbF|#y@Xl`VDl2JL`p4 z!aJn!kmgs~?x8S-XUys85r$_^*1{4yg4p5R7ON4z8}D_JMC>R8Arj68yEmB!O|)j; zy+5eU_eV}V=eI!#79{#3lU%<;%e_e}PVo`jXvIxm~8qX~D; z0lxdDe|9OQwT9?Oe-QwjFE$3<90JZ)lXWAespu+g2QURQFjg)8b}fm(zeQcX%i5@* zG!$ehND{=`uy?msNqAMQrro!HN8`$(KTH#3obQjYr*P{-^&eJ@F&s3e(k%0tp7qwA zE#-P$UOHU%LB+dEZ>9%b0?Hz5N3DSpOpzkVsp|SC*P~um^5tS;7XKQa{>|E^wUP#= z)449m$_oEICZ_!H9J0XsV7t6qL8M^z!4RFuHHFk!<4i1=Y>py%Tx43+SG9;H7)YX3 zg3*y(9zB!|jFj(NUSp={P;#XIL_X@22h$QJNt-Hgtm)pMY=45G&hN+CdyMfKW$xVpoZj7%WKZ^YyS0-+2Ol6LPI&Eorg_bSc z$aA`$HkUljt6QtOd2O{_+wNoe$MAf^FE2FY?~A+1y~bU3yaHD& zb=v{twI=HismZp^Lf!x7gok+F3H^`_I~wm$uk3ky19#={iy!2uXnhU}_um}kv@eCjC%D}MhGN-vyoJA`WT zb(xReA+Bj&w_h$1_p_fn+$wc4EUKu7Z!u35|F5=f^=@0{fwfIXWiq+I7bW9-5ZCfv z4tx40axFpg z!pRr$I7=6a*4MR@6#NACD+pnq5IK}FDum1oD_aHy<(7~q9 ziMR=0PNl^QgAv#zL>h$r;N#U;Dr6W;DKTM!_;%IDK)TQL|^P{%+tNm z5?ce!=+c>C3g=fOZ6Um8xDE;BVnw>9!z+bp4BYbWcAICqu!kDR(SiOS;eGWdJ*>0& zujlP+1HdMVUwPJm7H&3aRy5)iSft=jXoyUR9ax#y3ML`q(DOf`&LBqje7sm&^1vCR5W(EZ6UpcE*@xCmy7=wtulyr`I>2^tSs{tw>y}*^B?!%>HFUWwMMSAt_rM*>&8Z=pbkZY+(-x*b&@zO`bXQ4 zk{RuT$KoAtoZ1&OxMK^ zPYB^l;aK+ARz!%L3fJlbGiS|dcLIzR5;g7K-=WqpRUOGJ0y^PjN4=LU|G>%h#bsmg z{(++Dc>(szMq1?C`w0gjm^v`_Edjn#AVN~FZZ0^Cn(xS}YZ*c==z>o4v5L8T4i8!1 zPS=n3m+w*%qKw8SIaHzpyl=)wRR~@smi8b1XA+IFXonjm7;)8e*uWoq)}L zdzFX*8k`I6;wG(!;ijPm+E}Ugf67+P{#27slF-xWK- zp<Zu_sqk_Tj1EJ6NYUJnoPg`#cusrnEJp|$GZyjHqToCtC% zOIyWhXbbx^$!sJu1SGm@v!sfwk>xk-jeSSEbO`b2wB#ZmXRCYBe8n9_)WUie32(7EJ&hqT4ew&hyYz zAG2-%X7MwnjJD+wx`Q-9!f2ih7#&AD8;2++y>h;HfdsPu&{Mp~>iw~C+OhRZSL4d$ zHsLeAvXKCCJZi!xQ=z|8bzRvA7{*AUNbK{3Q-{yVLE&3 zWV}?Ag@@Y!D0rjK!C2BMDpI-=8Xt{_wkhgSw6s3e6kB0w*&4-ur(uqnwb5IS(PWOe zIacxP%Jf_Q`R9WXsVksC7Uv5dI|$l}WP;U=J7>i*V#}E1vI6S?`}0oh@s}Ol z#eZu+gg-mD{Coc*^pB?{V6H7fJ9YC)O@iT>2@&ApQ%6CbZm6;Ul_*(^f%Xd4E8OzwM!z%6=YXwz-T}LdhJF$C_n>vV00H zm$Ie|&8inH@JA*lRgpjsNg_x|3y*P6dF>3ssz_Fqx1%;;)WgL}y z%hCqCUi{{vg)zXGKHgK@IRSvSDPKCYLOG0%nmqPR(PMsQG=5@B65hievaxa^g)eFn zw^LnMEIKoSq=Vj3OZv(RE_pi3J0XI*dK^N_x0Afh-Y{)w4-y;h%|Vs7*1s|w9ByPr zy$Lu-IN%$WW44*>UddZLE5Z(v&p$A`1Dly|6R&!W7&e>hZ3Bdpd;4`s7w%_-0A8E5 zd)50m?}-%?UiR>wIK{+Kmvl59vzRxtYEnUM@0~_j?Xj=@LY5_SRKZDT>2(gTzq1ah z9JZ`eOG&Pp@IFPtpTPhDUSw9YFSfLwiU% z_E}gmb}cbZVw#}t353^fQisI-n3Jz z^MCuu@^%>?13$K@f^c+wNDt4v z-7L{v5pdPadF+>3FnQ(lEwnA4SdM!=MX|^Q*;)eEoyp~#ykq#!tg?S8UPCq4F)ka6 zzKV=KZ69x`8W;J3y={K$Kk~9Tqfx`-HLxt{?&hCT4s8|s!X=IIcOs?HMX?X)k`b&-5<2@>?wjFV>8n;BHxaYTag zaDTp{xSihoseR8pK9Zh3tc<2ZnT%TWNd83GH=*x7K9cYrp_-qPX+~CC5n+<4Z5k=< z-+wc8ETO@9(zg{?M6DGqo#Q=+0@1_+0U&zTW41qJmM+g0;er zONxnh(705Lt(JPVLgTH_ z39xU!d|7teOfXBe^Y_YAsesR&Zw>}ug+}|2F+3&vJ(zN*yy{=T%|g3wfVdqH1ja)kV+1{oR!R%vtIkIalu8w7SUM+{MEW7-s$HM(5IGkN*-UXiPT6hT!~bw5tX$xPn276Vx7x%3p&uA@ z%V)+GiCi0S+Im?yXq#aebJBe*t_na;CM0&@2esE4b|Px6;p-Nc#y@O9$E&))z{U$l zOa%07&b~^r^uIG_y!YRkTO(_lSO?GC%QI(35S<;A*>A?+xDKY?T=~S9-3RbGM?pK>%Fb^reNS~ENi@>SSk4V1XylvXnk^t0eXFO*>-p+cTRT@VQn-fX2XXy4Lzj@7ZF zMFXK|aGM|a6$J6a1J~V1mW}s@`7C~F-huu?ORFn*0(ri8O5c7TGjj>I&)gH7 z!7mAl$!jR^MJ?l`^Y?$$hqat~fIyaOhug?t8n!I`g~mCKEc%6Ta$62xgG> zH8S+@d6LD1witF#ijaq|VGbXM@_D8f@ww0#SSr2gAwiMCn{mOG9tdp+~;dJ`fvQFHA3z49wMr09%#3C?-0STqGR>rqc%v4 zYf5__pM;ObJhU+MmXHv19~v6)ArWNDrdDOPVa;ZL4M#9DXeXI2e^^d(vPuOUR(2m`?9F~WDY zAnMhmrrbYSpZ4+URf~`)GbwepV=q&u7&wXW%2(F2t&_pM+Y_8UBAofdlW;cYkVd zHQ8_wu{=t!aH&XoHCJgP%z09%Uk3{zHy%RwTwK4Gj!zKA$r=5YMcmDYA(KS8jV2-q zlU|{cmtz$?nHq>_N5Mj*onnAs*yX7N5X6#Bl;%mGJT9o~m+GLGmKk0?z;gIKF;V}4 za>;2ZLW`9Nh%hoODHB(5wKmHxlVJGc9=)1a#IPy&n2W+_|2`GSJ|EFd>A?MvZ8rPl zi;}BX?rx7Kz7AYTeI++2{oy{Vdvb7~H|OTe5wN<=9H~F&$Dah7XxuJiShfkk3_Wxy z*)px!VcJjKMx%Y`5VaYUklkv`c%z5M&>UWI_KH~Uz!I0`IUP2vNdqN+MS7B>nYNtE06RO zY}$XcC1)c>RFtD>U(!o6P_1YrOvsX?Sv8fk_h1e8=1=`Kk@I){*u79^w_q(eFdX&I1a z={qWR=}lXr_=r#F(#)ePNK<4ZY? zBQ5LVLVK58_K_AG%8W2$m(&RfHFO&n8a#|Uj=}oYpFojViDzAn0lICC;r#sW)VC{p z|GdoMa5&gAP~9Z^AN9>E#QTuhvzlO)lbn~Yh?IJ{zFG8UX>gM|X{1Myb~%C)z?K7; zgh{cqhoZNn{fUWKfBNM_&NdZDFlKN8yrm}nBv$9ob2Fzj^Mc|n)#eC*7~1g<*S}*@ z5X1y9&*?&P4D~jTpA)RUIPpbGsCgTi4$?Chc*f#wq}iZ(HjCg+za++ zCZ^}?KExzj3FwJKbvKoo&-8ONJtZ>4^9`jP%-@`yppNJ9^>+2Unbx#m9zXF^&^ATw z^<6!(L}l_~G+!lQ2g59-omi&cWBtAO`9U8VNBru?>bZb8+xgiC$CWcrg8(!pLCcUo zHnSyO+4Qr&8ISm(N9TWPqL%Y_E)Y?xd_<{EF;~TgFIES?h8F?QF6V)@g4?ClK{=3! zY>GFF*B^MPk&5>->xSWn5B&sMfNd%>X3KK$=IOSBPq3pK@Ndnu9x9ypYkD89+do%gJ0Zc7 zgIAq_Uc9@kIX@QiHi(iYfae(M!aF|YKY?HdA%a7%*2z*P?;+sMFR-LkIzdVS3SiP( z<~qyD9E)roHV*A!LMAOF@yl4)c#MOvrHHDfXrTU+88Sp013(}T#0dBm;1eD_WLCa! zh$Uh+!AV|3ZlkGz3v5gYnvsi+KGhRsnBkuY5o$eaKaZ!9DCF2RIb@Xdk-t8}i~P;T zIMH;=aAepVT=)eG^O)v5d)u(xtJXEcu-^HDg%0sU%WXcs5K)4q_lgq)3`#AO&O2lZ ziI>|w)5C%9j^-*NQcDQh+S^~YnlnsH*f*$Q@@ybTaax~YYOatG`mVgG`@^T+TtCjA z1l-p+?bu&M<56F80QVh3$?>fViADQ@m& zQ|_t#u{=>NlYjHI7mb)_J2P<_`pn)@8xE#wsOy)-y^O1Tqarq8Th>Hr>*X64Ji>kR zS19K|to2wZ_-St0Hxl>Fj)OyAFRCm%J-1vsiEZMqGzh4C6rDt=JPqJ4Tc7o}8xI2# zi}^6Gz0(YnXGt3q5_?Ho$<_cTmbZV15YA|g#1kO@RZRg5gI+oT)hn&PLZIe;uLc~j z7e38_Ub;h(yL+WOxZs4;cUd$J@;$X43#6bkdlzqR(gZBZ2_f|(XZGweS&B?bp?E!r zXxsRK*|UrV;Gg+vk}hPtEOOtK0%Sj#ylo9q=~|kS=sQgw{%X!&ouFR4L9K$cG4#b} zY$iaqX1y`T#U2HXoIKeTBJLnpqXLbC$_gF=seTM2k)byDtkjbo%)E8M@{Y|!u%TSL zn_I(2O3w4%i_^?szEJH-V1O}_=%$3Fd!Y3W2cu>R?k^Xw zrbvkfxHzi&&IY)zY>?7L$3{!BAb06y^6!^D0KUOlw zR=#!_bm$F-`k)O)4E^Fx^@bMkmRpwMR!28C2Tevmk;Y z4z1FC5rYBl$vND;O=#bPmx2f{2oM-GFFS$2HbZ^U%#((9! zHgfH_cZVpE;)j+}A}qZ;L1PC3T`*)y3NgIO$O=fBIrVyU+Bx5OCGNtvo1 z28vFDB+Au+NA}7l1HD&gof|EOmX`A&VX@@(B-pEcxl=&Hhuv7 zrPhL0lyAO%^OVL9taFoA6c{(t5pFw5vVu1kH%Buw+W)^S07l_Me*?Yuo&D-O4|3d8 z@c}#_r0YjJ!8Z!=81d&R8ILFOiu7=Wzp$KBJn!DRXs@Fr0`ML2Z~)@RxF>@gV!Y+9 z{_@erG=iJ~B=ss&?Vm3;#6;@g8pFqSIwAIMhll5E+}=w0RS})s-Wz9bC_MYn;IyF) zH*#JC8o`*^k4>P~I@2ETslS)0md&-?P!AKM|4+7ETK`5pr-)aNqvQF8QShjN>%vma zo6U-s-omSV!AE~>$FP|m=&le2{;gKt)WBD%%N_mQ1$lxFfq!lb0XD2cu9!BazSO19 zsQd3t6K{f=zR3q|=$|+|h9V5=8`otLbAc_GQl$*qp1TOdHuOu;@jW3HM&d9Zag-i zRmu}3s^0wc;=<(fXSUqo=oWOv`DC|>oyNvTx3(0+_(<6)zANrX*{r?FZDBA|Or)N; z2S2LeX4cHqM*QXM(AH=l?+3M|g<#4mvtq-x z)wqlZmT^lnx}*5a4>1m_fc&Yt)yZTV0%T>n?bAN&-({H5M3_H~K3;lP4O`n0ok}m3Y)_MCmP_eD2O?C(Y+ky%?xV0`MH0aO z(0woj;>!a%nA?D8DS9)^dlMY|b!;C$W=Ks*VVX@WzIK5l?&Dz+(+%;e{!Rn3i5^97 zZVv^;88{;g-Hs<<)UJO@I`LHZqYOqWQtUnVIfrd3SG6R0LoC|tjRECUx~%ZX(1z35 zn$djv>o;p2A5dnsKlL0j@2z~n=X)I#ZA2&?TxoXW{K(O*ns>?toR@@DA^!a+FQ&oS z5K*7VvhQ^qE9W+?Fg`bx>TIHz z#)#F)%pDhS^W#HF+0n1x>@Fstx2?ZjzIn-qJpFS~Qo9$He!F=GyFK>3gYA1 zI~LndW#m+5n<~SXqn}k(wP){uLNqZv)AFcn8G4M~vmaO&_}@ z8DdyH-o0U9&26+d)e;B7Lujcqel1{~5(5>Pt+v4Zs4~lrm);Xk|4h26AvZ{uJ8y&*vFz()ELbu&LC#aj&*bY7-$Ose|1C%=i6%v5-eZG(bCHyP~+^Yiu z&|UOeehuyq!tTS7veAZOH7puO!=HO%l58eUL)K?1sprp+$PZJI-jB>#Czf3$on^f4 z>}MGJi?Ru%ot|^>8d@yqllftJ^#ol0&=^zn?k|Z-Lv=6%<>527z4RHH#7IZH2cszD zOrB@EiKZoeKQUFIv#Uy^qmXLd_~}}KOr5+xr6a47 zj{7YA-No-6$DRYyS4Kz^@yw%A439r#WzozBb+y{|#~?Ecu<=dI%ZbfwMz=}H?< zrXLoAb2h5==P76yKW+OPy6mu|dM6CKCYPjqny-3VAaOUOG}DkzOA|3Sjtiv@*mMf7 zNX|Dt`A7joS}}%HBLjn!Fr+@j)nD#6_m5&p5@ zDu_c6yuEsjt_<^$h}SY?e3iUGHJWrdw9W8Y&X?4na{f`a-N*K>-rdHvZ=%(_>4aAu z2{nt^GJCT0ska`zVCPOk!)|nJ#$VjkH(kv-Bzpa?7_Paz|4hu#8{C~DqZqH-B|}E? zi&!xAK4>jvax(7uw8~NiA9g4w1u}^D7pSjgayu6jrlS*2Ev*6n6l|nu!_((4Ti*;G z5#k$|j*}7~q9y16QV?neaL>FAUC6EvH?5|A4=`as_>jHS2vUfrw9>$%3GrsEgi$-? zUBOY>01Nm44piofByKtZ=u~E2e{3x|32q05NIY{f0q(63Weh0*vdw5>WpLGq0>5+? z`wKKCv?Uf*3dU%F*#!>&ycy5;7G?zvz`&N-?4QjIcs27_gsNVww+E@t$)oMrQugav;0{+RYfo2c~1jy5Z<^1T&@+^MAK1TTz@uTf0Mss-eXnp zd-^R$klx^1iAwz_jb;tW^f0fDqzjfEz6f8R*(~fOaZG}sUXW_zoB5grBb2C}`j*ZF zU!kkbH?tGy$xN5avrk}&$St}uVfLZHA%$xf@@B@M>*5$u+JT?$zY}~;8`{us)_fx6 z^pW_uk57E!&;)@KKp21MibsDgeg7}9vz}@m`FS=wB_MnG)a-fxDplWoxwd4;Wc0rW zQg-C(IkU9a^+%9-;e`56kt2w&?8K8r`ni7GFj6`TDQcJ0Wz?$Sh=QW^_&tqA=w>G& zIG_>1!>vhzFHYrl^lUq2)#YvVtw0<2n7$%+MhOt|`DcW1MID2xu6Bg;6h|R0v~RAU zA4C;U|Dg(~)<(N#P&Cu5XdUpM@)^`By6z>96NU69(pW$WrDiHZshg#k`E18=Y;oXF zNuiP(hCCW5D8Me>X#J#ZRgsmcNdOnXepOsUZ0S~FZTWxz7!JXj%E5+$CJj}*h%gM& z$?&g;;3<;ruqW2ne~1X+IBB)O9)7yf8kOecD>%kcBR0ldg z0xVBV9KC8RSX*##IQ&-|wwc>aua_epyXnNiEXFOE_u=)Uul5$?Rxz-9#fr}bqPS^IV8Kj^~_&M8{(U$aV`={dzD>?6Kf z(6}cJ2?0`jqWk<`Q5bKQ_jkAZoG!nQVn7atE6xURuamThVRP}9C$PQTCD7PmuU7-WDvYZn%9e8u-CQO#X? zet(hFG$L_4FIMQoUGSFOi86Sm6{{E~(V2(kbW+KxU$qmk--vUuy}9W9{PX<$!>yq6 z2PdybXEPVYZtb*p?`CcrCR}lx|M%YeFDV0**jSs?fl3PH)^QC`UZFw)W@~gk;Gj=U z4!I(!WoC@nmphC4y`lFG-iQLQjozxuRIZNS@9X|1hz0RVn^~8Fb~dzO-l#nY(+#bTRPlY+3=h)K13yG2}S*39INN!4Mx z1@q4@aS9TkALz0_NdFV1B|pCa53Y#7^!dEca;RW0MOTt^C6d7pcdd%DGkW$SH+>4) z()pJC=<7$HtH;-s#Pho`hkll-S>XfD6eI5olS!gPU3T`z_iH0k0_iG!-WAyxCIk-1 zHaFcJO^?@4H=P*S#QD5T!u|%^cSShIwDZre@^!1aQ!R8uz-lOkb3tbtN>~X>(bqHx z(+2c*Csr)P8i6;+4q2_nq>9&WPQkDKCnqXmNjX=6FrLA51OP<{+faK6K5mg;>D z5>yBhaPefG8;&8gZ7FTU0V*7fIRzP(Ngm8+x9qA%Si;&FOMbw>-BH9*#%mRMH}9~J z<~c^CJxS&3~+P$)7Qak2n%Tn)+YV8Lz|K7S?Y~#7|o!H$M;n zhgr-I-N7!ekoC71OFhMHAi_b_5% zUHNpr!0-n3QtDMmYg@?tvrsD*85U|RP2+F*{fHOoXKN>a~T z{#*RrL_%DEEA9bnUn#o3FgX~8fziMMrlKFysnzLkfq?iWXzpd_@;uPpQk?l9#y+$; z!J;NbQt4gVsQgvLoVvWm3qo4H(yb;p#1AR96jIVVLCNw|El&|Pv{svBiwztTj4Zo9 zhswBr0K(I>$UIE15^Qb$cK?C?oXt_ZszK!xqYulb?_baR1S1m{78`8?db={5tEj$+ z+8#<7#ECY5`gsO|wEj|*ey=O-PJGSmoOrDtNesAv$_Xme#$}Hv*3Q;&`<%!d(Aax# z8-Uq3Q~92sldaC<84Vr3x)0(V-lv3iP~F&NDEX5h`L5Gs#m>)+`kIg@uB@N7i=L1Z zP(Ilo@M>dxhvuOC+eZT--?_-hrAnh{|4L~7eBfhs-Ytf3xe&I1Yo{U?wp>AYYvq2& zK+B7t!&<7^cD&#D>DTyUW8 zDQ-z!J3d`xbdwS^QT6K6!(SNMA+5sn32Eb*yIG(h$7^TpOXSLmc_M7^bz1#f8Iz6~ zyex?{6r`)ItnAtE4YC05N6ffiRoEeLHCHGm(xaM*8tN)t_Kcibv%4K1tW2ei8?Fim z)yr=m(9hH|l>Mzpxx0Kv@;l-obt^hF;IJlYhOYI7N=V~g@+~@J9us%w;D?{-P9)X_(Wv*iO20S+Gv%4h| zbae;yRmHAeZpyP8cKRsN&p(pt0khOcIcteU1g~N>>05KWKZJRF=oU*DCLN}prS^mNO%AHpRlaAH=cfj~P}oW4ax#S|-BlvvmXuE*k#B%o^sWJz&vEeL z{PC5&Rl$jv>|JE!%_)41AMi4C z!7}Uwr=miMeqP^1y@>jv9d)V-1Jq}^z04pyT-o22RHKNkux$Wkbc#S+&HuH2?Zs6JAILNj&Oh@o`FbzDDH#QGq% z-sKP8kJDpf$DrsBWmKrH+DO^!Cq#4DnCd$%x_U#Gv88RABnA~HHtCNBC_KEnh zS_lv|T%0YZL2;t4YeIw0$yzkY#NpDkG9-*$wMqvCsO}Fk{?WsfBH~W3Xv_>yqGrv{ zm|bTDdlqAr4xZI2o7yQ7I~Z`W$=nT>rduSz73pi0JgyIqhb_+^>bct&o{%%V+FJW2 z(t_KbJ#f|GxOJ{ez#8Xqup7IvWVzN}O9c0q)yqNgf~>CU4-PO;3=}%0c+=&bz1u>N zl7IS+k*T=e6~`Cr2`k@`PFRj3RB16ZC(Qr_rP0K+YR2|F0Mn$`3WxV!=UHN%r$Q(1 zcWlJUU-ET;RN$gew=v-m$Fc&uO?BHe4tE^H>{34 zp4sG%X$xcHF8STzO&c@_`q3(V|Hj4BiM##*t!SZDa zY5O$VMJcAW%cX0NXprXdQGXm3Iqv5J1#Db)a_qP!K23-Efe$jjS}y+Sv z-@KP^`U9G8btso5oAyi&o0!86+cujI8DSsaH{E69E$zI(J-%6K_mxaqzw)+nd!?sjzd`e*owfJYWA#GS zsR{Wx-?ElSl-*$KRigpf{2@JexKPvLuBb~ai@6Tz?K0j2-%O=S#(3BV#^4EDg1>i` z!Nmi!#0&)9U$>vQxHXX+7}-QWLBTv?6;HrmMHfR!{3PQu83K5DCuTH;fgn^$M714a z$jdeIte#pOrTaBgnLS1(0Sn5CBRoLhMQ!Wr$^ukKQkvZLd^H_&t^LP8j=m8+}GPvAS>x0;ps_L`VIsphC+Xaf`;r# zEQ{OAj>8T6TWzncg!yTc57O$4`W`m0GbO8Xp+byF{oR|7v5y|52JuRJ-_=Dfr$oj~ z4YMa08s?;6RFt+aW%>6Wiw0E{N?D6wWo1gO{0OopnyS?E#KjIkw)5-4tFH792kuk;nwNVj2h8`Z~=$}thBBETMIJcrN;y^h3q$F53 zV5dT3C)v6a-MY7UzGad6rFnV#aQ&avr}eo0aKYhJZv;a&*WdIlOKhQ^WU)&gry&aF?_tT8(;C^_en9aBh zCSq6KJ5j|f=g#-xqOaN}+vuT;=?cV0rv4gH%}Y`b;xLkd!3?|CMqZrj;~bxM;+^{r zEM@!`jA}=puJ&AlKPqtEH>G`NJ9+0i@6Eo~u&vg#E)mt$cMr`U?lyGrH1?*?2QNK% zhUBr?#P0i*a|W-zaR(KhK3wQd{syLabSwZJlB#Xc$ApA9i2;!8H%RZ-oyl+M0Y7q{ zgCr{sIp?pxtK$veJ#6bm$Zokf7>>!UC5|XGm_&JB?@QsuYGR`H*cdlQ#f&j?tOZ=@ z-i9&OD|G@m-Drv=fDis5kfs?LF@i}$S+rhxhyX&WeDY39U`-E8K79*6W2CTjlFm1b zb6R0Bc!<8`o~M<+8+>YPtdFnwzq^3mKT~9>qR{@_3os}EeOXxpx;Ruk7Aod%Bi!a` zTl2eRNdQ0D<>ujaf{%&TBm&P_lk0rB8D%C4C* z6Z5=+fhou+gbpLt@+4+9f;Rz+q!ph$F=tS_ zd!iG(?VmoHq@muRCA4dt^)ZPEWY7O%eRsypt$CIoaZFn3kDqs>QI@sq^4%%zZkL^8 zsF#LFe9IQIdSk$nr)uDoI-l6qgK_hXwZarkUAKI>N}vMyS}wt(m60jHzV%VCP2*^O zIG2#s%`KnvSLflX$TDV)qv8qsB7?n#{oiMn79i8qeqpVDyrZdhaCwXNH7}PA?#Hsd z%iIxds?;4{T()>z%W)IQcPV}yw>zH zO2`&kj4cHV-6S2Ez*4>7L*WwD+Dy>Qd-Rp(GE9pb13z*$g><3KKo}GSnokHB?@xOS za0HLTS|ERp_1YG#fDo~}VEgd**DJ`p=F8UqI3u6`-oUal*VR|x8~AiqleX|*vK>gO z{}SSHpZ5yin!03$RhqWbEXIO2RX7MH~Z|?O@k^5(hVmwX6k%D&;(4;YefCSn7lTSf`j$7C- z>q&D@IM6E3Ek5GEVOk{;FHpyIGK9RG^&{}=JD~qqr=%bMPT6yfR1Km4UpeU2)8HkL zFx4}jM-=v&dm-P8KfGHoOAn>ZA76HbeG>3FMcRj=y>^D8wID=in^#})=@yddc{ryaZ`fMD`3fog9P5wH6|fyETk8O*d=&_ zt#V-Cx=smxKUTEmTD45}>Ds&acoCHr#R^s*uwwZ*X<602JitKL8~*}sU+_(-P8VcN zZA$8P!UFVt>C+!c)#zF0iboS^C~%QDZ0f|_Ak=G!?xYh;=b3BgV;cAPPWcu=^{wI60cbmX zPc~5=^?nS@R-><;!Q|zS5WeuAB2gPesEMGxCbP{btRe^F&OQThQzV{B7aYcdc%=Gg zh{b}WR<%z}C*D3hlHi`C)9Zdw4d*0KjF;J}Kf;`Ur2VXMNLS(pXM~Y@Sf9RDBtMA^ zf~T}UAs6RK-WFN+!HC7&2|ND-JJ9~V@0LXLCd=7~S<|5cgl!qb8|hTJPeql%v0i8! z9P9tjux!0ezgdt`;y#fxfgVsGvYW(}R1_&)J(n&XJnF&CiUwK1CYn z_$QDL5f4(A6vPk)qy>Nn0&+T^^93WmSs_En5I{y7Yt2!D!A~=pzWAktKL9os3`QRT z&=A2368&%;|MCF-7#3$LAZZc;$D|>sOmpsKn&4=>zsg10<1^)Lv7BshWa@?blonpK zy?n^O-^xFCd5c>9#e#>Uck@G*X94kx4k597oKQL6)F`{2AMvSoQ9*rOJ*0 z167|vx}_MQ!^F#&=Vg)!;!?0}k^KoU{PH}QF)m`)mnEvwnOC^LF1?CQY*LWMGN==iWwD}2wX0lt7>;?*RrGfJGfWG?2G7ZbIw7k9@?wij|voFI|X&@k?MN8$J zPuQ@%<;e@Q^QOo@moe5xW}tm1xqxA9JGt%Gek4yT3w;%*c}~he3-yw=I!fvw)fGhN z8@&Go`~TllIF*MXD2Csp3N~k5s*zB4(9AdHp)_R-iJ`4tF_sLFyMN+Dx4sb(%j*%+ zo`q`<8XvWWec9IXOkO5y@-<4uzIDF*9WH%LaKE>aA3C&(&IxM|$5&C=A4FIdDyw55 z?{oHRj=&&KsisJbpIAJ7mj1(BfT!PRanlM~<5HNLL_9W@#L^y)7q44vnf7eDO+v>6E@2IA;_DP|+5VX(Q$Lk-Ft6mcZ3`Gw> z6A!l#-qjddJ9M!grVvT#*k7!QfYb0r(ZMcPZV2jQ&Fs{V2vo`!O=-Y{Ff%+kids~c z$3ZlYT z-4|{9t$sJov6j&vY=3!levy8+K*7)`(!=(LxfU6HJQUC8h0k!S zAZ^~7Wpmeo$C<4p&3MpmW-jwF;>q!|M?f$(Q_goFP-*R!n%To#t5Ftng zz9holVmzSxlO*r>3dYQU_&q7he8wP_-^?<(Z=~$nnj!X^RJ-A75A?>u0C*zkv&oA& z{@&j{ANJ45A1~x|y=EnVEFp|0pM+)@Efbu14uw*Q5@i3(8R3EDT0j4>@8Zx*!{}zk zWFa+4d7(FdV(TzL+rnb!w%Py61q_}5{Xg47|5gEdQFkHkf29&r$v;a#;<5!##d{;* zaDn?aU-*{Imu@>DWW4c2oPNyZN*~t;wDxCtGuW99X%0`?j_|s8vM{!qiOP`>Mkki} zi(^5?C8nkmXy4db;@3`2XXD%-dAPNmsKD~CP$IA7=h>Hh6Y=K*SP08%bcm*Swjk9{ zw!ki9Tx}D7!2}qNsN?1}FEI(e{&&jG#}}Oe!u2J#j_zE{l*n~zt^^Ok%Yt+&1N=1V&gZYu z*xePy#}e$S0S`p+pnnC!uE+ZWl@$B?4dh2)zjcjL%H=E)Girng3)5;sF|{hj$GFvkA(9E zl~7LEq-{E+$5mo?d_$Rdd?%qJidjc#Lbpmp|DrF|VL1Y|gvEIN_a%TY`u>({=x%ng z;EMj#({IpmPV7Y2$AIG)O{%l)hOQeT@2-ivZheFe-Q3NtyC(1V9BO|_Rj1fBoIGbN z#n-D}qr(2Ik`8fjm`v9ls#jr_4bUnmY>tud0{9sxgKLyw4L(TMF}c%~--3i#kYccm zsnkI9tJ*JGM1eKQi`E+bXhfyr74~}6pC@+#*U1KYCZ=5Gd#jQ;(&cY}C6?A(9fwWS zTwzc_ftBA6oV0Dq0w83);WzTYhXQa2y02OPp7SH6j-l~g%XuTCha#Qt$>NLGsyK)? zQb^XZF>XXmjbV5N=A%DNkBH!pb77H!g~hsG-Mw>bm16ni+Yt- zNgPq$C;8KPnh@F9dNYW(v=zbZFz@&L&>r^b-OSbGcVW@Kg%%|xYVpt>XQg}M7w|~s z!P=1o+mJsJ^pdSqw>?Z-8;;G}8%#Dy#ZaWl7m3BG@=rFP`F1#^0_+qynYzA>N6j@Y zCmOgXu_}v`Ib9lMmCFC|i+Hp!b*04C{%i}K4`~#Msl>a0st}ZyFH0xktz;gdfz)|P zv6Zt6vCyW0^?g#0)~L?D-9nS?uV~UrB9TfLkJj}`nfizudZb4;pS@V$45fsrcl*z) z>x*<>_okcOiS8nUn_!0>>ekkH=2pDPq1TI?{fezRJq5?G_|u+^lcrA_P#fE*6hc~? z!t9QC(fwRC|F`d%UhR`UA?=>Yqk!;PyucP=RDkuGV8L@Jh%0x^34?4jG1{NtrQ2Xo z`w>5!EcSxTs64BgbUGEP#*x{K`H2%C4iN!{b;JttxirL`Has^;SPCD&F>G4&-I*Fj z=g;TEuSt@>kz~o=gf)>A`|xlQ!Z(WS!k_F1)EJgy8mlb)uGp-v#xZFiP!m?cf!DG8 zsI>B1Z^3y2EY;f|^Ea5_z*AlyzZ-lb}K%a-O-qyoKPKhkztqWJwSDK8YxF~ z^5=J}`|-`;bSmCAc3T+C*HG^|ILd$ly|29$B8a<0bQ^}@P8zDpilrOPhKKXst7e3} zgGwXP*7rE>oH@7vg+v;oOoytCaE&Edxmf!b4g=UX12&+2De`NcStgWUW0yasU+dTZ zxwrjCt@xwiBWLX#`4jTg_`8=WO~#9l&27in0AeR3|D!n^t^&GnG;&{s&+^F3^ZI(u z7!T>W?iSb7=o9QXxBrlD>pZFW)3f86Rgvg+nR@YrRo)n}wC!Fk#?>pav{S`|r+Cnf zyFI@!DPQ zF4rO;gx~Jh6ZjaWMF~zkCUpGNsjH^_5f!%3GWFmz?yBG9FT>YOH9Af-$8IMw2Dj0aMy*Bz{#cHYt{cn~5CNpT_0Nb73sj@fojom!E5F7G zL0^)|Qw*pkZ%pK39AoS=EQrmrp%^ZUZ;3LBr@G(26}zK0K9YaO#Idc@zn26W;La+n zWGud!5lO^Ai{cF}A7BwxeJkDz@48uF>MvwQI-ki;fUOP_Gj7niWB>f+Z&Zxo+5CEk zMd>Y%V1Fu7@WS?s%yu6J_<#o`FU3Az-n?G%sx5hgzsGj@Zw*(>RiD|}PsT&>;mqo& z(>T_X7Rkz)mW}?;EU*Jy;d>trdvEC364K@LGS8oGW44oDy33WgX1g6!Py3a+%f+Fc zw|wyE9?b1)`0%%?)=N0|AEG@?XkX1Sa`l|WwRlBcOUSHgcGhy{u8z{z+HjBg(~GjB zKh27E3#Fsj)>$Nd)-3cNK2B>o9+Ys{E+OfvzWvf7u$mg}%H*qBob9xpYnS#H77xB! zjS0t1Cf5v=CGGfacw4%Y$?xCAX~26pR}s-)C?GbifTm$1j-hPdEL0=|i%kzs*PK*i zs?hQb$fUkWzBiQRs*ta}aYBfHAZ2}gLjD*1AiEyZe~d)N0WNxp0e07B8Pw=}yK9fb zf*c217`v1mNfrdIeM1}hAf4K`t?hQm@7_9apQuJD{T3o7c#bTvaP!tZHIbyCx{Vd; z92!*QSAJKbodI4bKmNN=tTwBh((fUJ74K_(F4SIlR^4`TXyUN%>)|W@2{qr@LmtKb zrLM5`seZYZwNCO|ou@;Kj<}IHW8*Gz6V89FQ{<)l25lvVD>Sl&4@CgKQ->Spi;N>l zxCtf&>%&*il5oF+cVVVDWLqF*dqhpeK_Dl2n1d$gx(_3S#t+-33J_nQ9_+08;m0h* zf@P>6!Xyy1Tulhzm*OO8sJeyWwtP?w#`XC?i@OkDlVq6N=}EZk^W@Fr13y3J0CxmF zmz(MX>hNh|S}s+Eg6!H9QBh-5mqWSV&}(!+qP&I9c68yz*ez{b^ieanx9HqeR{oNI zOX2wjGiph-)lq-OXC9YxlM{FCv-555ivLSL0x{DXq5YR3yW$uX46xb?hNR_HXI9$t zX@jx1$eUvM^nqyzi%A_Z5HEc@}HK2hoO%>A4l44|TD5?whb?qIX+`sFVg z>D+?H@?I;LsPA-Dcw_f{P3<7;0CLvy)vz~8)oV%O@U#pRou7YJ<@%t`zA>_g+?&7f zTtOE-TlR*>Z!dDXh7LHIUfZop^l#TlsnjSOPlh;!VZ(cV4jv|vRndo*P;LqMEl6-rWulkT z8Dm%kc>SJ{X}`npZv|&I4%}&bC8nLmOyJEep5L@F!e; zUVHsQP2=rP@VZ*h`*+}Gs;@rG|HOoyF^`o77l8k!?H zI6`R`rAkc|7@I^W`tLQ9yCXMYVgU~V+Y7!HO2`tRWLi&W;H1QtsHV1_wWm1t!ovo9 zoFqscBsji{8mB6hH7&xghZJ3Y>u^oe23u?&;@y3G?R5}I?_d*`Lc1GsX0c$vYO194 z?=90$-ox>{Fw<2vljg1JSIKB2_=lh+L9{&msRbl{f6F2a2BunUxuq~De@=hwZ zP?WnBuQ_a9>_4>mgqv%ZM7G08wPZ72;~lK&=8xSs%z_lX=<;Ff6M`cfe{t4wGe95=l3dy1`!aC!!^Xi)I-+dOpeEAB6kW>A`#x9Vn_RU4zKa45%5# zDdVj-x2cPZFqTv=lbKJ*xpUXMJ^rx!lnzcgO;PITEBLJFoVrbYasJA$u6+{~A28DdxaIC@}sA<6)J?aVCj6y0&EVL*^o`{kI>Gz$aG z)4F*w&W|-1_=w=%!3MH7TMEi!AotaSg#qv~{7x^M#(>*A!BK&}meRS2JhonmAIA zao3bszhHQTp;t^fn%USo5#mg}+e*?q`);eAqU+>wgU$US;mMk`EyEnh)9 zh+Aii(?{AOq_)3r~U zXY2Ntm(O4I&XqN922poBu9cnx4)#=HEO4}bL|@P0ObkE4cQeGn0)GGM!}}~s4;X_M z5%hxWD_HL=3Iifar+sI_A(eYcWhhGwFCMuwmHA0JhPMoOe-iVa)6v?V;D;=ngiU8rKzc#*g=YG60}ecTI!qjLeq#CC9}(WX zZQ}A{I|sl9ubasBny|Q!d^#9@DW-Ai(^<(o0_iWGK6 z14l8K>%=*XheA`xL|jIb6x1#&XL8x7;wI}(u(9U^e{gl?5R-P-yo9Xy*V&*nuP2Mr z!uRVfeJ}gxc2(^{3@)~xF0~EsC4M*0n2(J9v`TK`T^pVp&_7$&la}|*H%N}Y@XoZWB=U^C;mkg14302SzjPc(k&o||#@p9i+hrc}mk>Ow zCGqHKI*l~ZW8fCqV6LsOakyxh(90v|cIo0r(nUr2Uf;SM^=dHq|BnCYLozczo3J>$ z&e`vNav>JXt6n8pUB~a@#g6-Zg@9 z`~}LEzswbs7!7Huu$Ea6WW4ZFv$jwkrH(*1!1Gy!t2xJ&`k z8nY~;WbvTWxM()k$&c>-{=_ZRi8ZXu8<~~1OnPF%v=cjYp?A9!B;64@QC(@v&4(z1 zDVp?pc>)fS~x5^OibtB4@6UoVVu~ zo=QQM`XhULq=pbq@_wk0Cq3i z#Fi(u8ZmxD+sTByKeMoMl>wvTOZGoVNpXsZ>^(wKS94#i*X)|8r((IE$90>y*^OKfbL9s zpBg$9zh6=)K}_-w(bDzz9*E#U%TrCi_)%KF1cJ?5soCA;c&194niU zQD#Pzm1FOjksLxs_9&H2#<68PvNJk3vN<^R!7=W4-`{;d?mztDpYzxCzOL)_T$eEC z!?mpOv^lRE(Js}R^~t)g<|rO0$sZ_!+#XD8_(MQE&zE~vjCAf7YoH&1ATD9=Wx*M$ zN_&&6PYvtoJ9qZm8Wic=TWmvJcCrxUy@Vx9(DwJ! zgS*)D*>`^_qwStZ2*TVB?sc_jyHu^(5)r+ia*x2TkUZlQZubp=VldU79Nv5n+Q|1a zW-^ip&+Zp>^Cd=7%b{gC;S{4U1uZ2;WAaGbw%xPb-zHxhxlQf;g!`bDmnD0bb(unGlUR|4>Dg5NLBcMsxSo zGpqAU*kgWJr#N5W}Z(38FnS)9n1@ZmQu6I;6SI7c@++%riR@b^786d;w{R4q8JN=8uENL z^(Er&ryQdR%fHo6Ebie2`SMuGb~0?&1w<$LlMciR;x;>mS;hJffFR>S6azEw#(2mM z$sieF9dEyzJ@Carfi8sF*yUvkvua<6tGK-`>u zI}LaC>9Z)hdnHyLamwZEo&X?bSCrKbf6BxkIZAYVq=)k{@ErmV`L*&9hx)tHDRf%f z0_=BC^wP0gzc^%U+{&@RrX~N@_jj-L~a@I*3;=YZ6-f5VF`^q1U7BueF zr8|AQdZNRb8_Vw&{FOwxk&8b0Y6~CHn+-@GKIvT^Y9YDt>wz?cp4Cq zLdfaHx8j#r()&m*IEbg*CCNX=MMHjIkX-M_BYIHb!slymucMcaiOD)bj>Co^)ja9j zCpW@ujDd5Gfb+u6RrSgFj}VEoQthX}I#l?-b!hPa)}iVDe;vyFw+`)1e=gU$=fCi2 zcV-mj^d$1<@o_34+)jquy;I}8VhhpFiv_T2!^YNRQke=!MgRT7^M}U#YNpPlB=sQL zK0Bl2b->cC=(*{nNA`_~gQhYr&hFDA=P*aW4@a}&kug3xQ6f9fJcMqw<>CBE zUI{x%om1Ci zPrI!{k9!-JcBZ@QnRA$%6|e2(1F*y%d$}o*bPe+-clY0l$uqVeWaTC5ip?-hw_qQy zDsRM=?E1-bp6)4}=d9Nc6BO}pORNo>DqbfeE6+K?8p<6~Tc|81t3BprVPq(QzhLhB z^gokKn7M_4l`|s__&^8AzkxZ%eQtp1F0LnQ^(5ZD`1tR_qiTuQ=D5%MLVGpU>hbWM zI{zTNQL1I+1FkE3Y8I5!n01F8lL+R%2mEP4Cz5 z036-*uQEMnPD?zo0Hl;-QX2G~$?wEGp9A!_*-CnjO>X*t858k(-CM2i#J0+;UBjPC z+WWI=s?z|m0cM7yE`%)a78&GrQQ~C+H`9;TrzXrS;`bgXqaQvU0aF}6bp)+}N5i)Q zjsXG3EKMP15Jk@?S(7lTvgctl(SL(zQ)2GCLXn1%Ah{v+<{;JLd#MMcUN-U3opC?T zl>+SP`0c9h@bDH{WXL^$9Np6z1VC1D7|li)O>VpJ_Iz) z`+U?Oriyzw&TAgAy#IxHJ0Y*-Z=92p*9E4%8ZgUJ$t`u3d~{billt?oo24ReN{FWL zbFQHj>-&a1*wOX=uLu@tP-5@CnuO&5!$|g@6&FFsP>I(eqa_^Ea6%WsYRaI{eK<9F z0sqU@V#JQynR%qXwABlG>^W|V#39{_ zE{9G^euP$Z<)Ug@cJWWQRO_)}E?wK_9fXe{3^wGtGF};c(OBuUaRjWI4TmN z>eJ5rHwe2$lZU>YO+Gajt-#mjg4DKcM&xEH;>f+Rej`a9++ES7r%<(zwxm(2pj#{9 z8>6An73VO3Dyff4JybID`eVz@fn*P;YyBBUf)>jsku4M4ip-*d@T`6q-DxI$?sYu! z*_HW{$Z0n@3VAyAO6n7-*TDrj`Ps19vNb>&XpR}>r?Gylmn%?@jFnc!hfE(lQ zT9i3fEp$7HQl0b-}h#K;>eFeDd&|R7K%t(&&qvweW)9$%-&^c z&6XZ?P22ksdabvRJ5rXr!A@PJQp!liMr{B^h^k~H@!OTT+P^d#?9 z{`a!EI}cZXu!F)aFC%ilTF(&)$c|tWhB+>bg>9|dudF8LH>;P%)B+%8d-)l19o$5#giFPQbe~6&&EjED_%hM$%y@xlku;ogCI^U6lsMg@!O+&;huMzO=w!6evnnkR^r4rSU7l&|AsBEgHG(49-KM%s|7ws+*Z+8v!mtR)T$T&vt79@r$J!|a=L#d2Wqm%qLm z8~hIft@-x?u&BR!HP;M4phaUs=>T{Y$&?`*aF1r*(==!mb5+XXe^IUGD^CxKkf-Yl zlwEex7jHbeL4oN2k#xz8hv*8tW!4n*1Vw~zw?u~va<97R>m)i;_-TZ0-Oh0ZmIiK) zk#BG-8&x7;?RTdIu6QgG+4L)s@hR&FWtg#s?cF0s@{6Y;UdZrv-|LYeFr_5^>^rWZ z;e=w18O28zB;*;dkI;b3*%JRWyZxp7?w=dJ?Jq1FL}yKXAdrwKRoAkU~mw6mdA6uhLa2X)DnW< znd4-|*ABJhvHD+8@A8I`*Sr#PbgPET(8v31WcUF z9Dj6N0)v?dxCQ!&zSCBckAyE65&Vmz`T_3{Yk@J9mcBF~o_E1p>%1MKw7ofMf zVh!5Azd56eEcs1a_J&UH1Bf!yZ&l^Pp^|7}EhP*T$*k_qgoy(K`A8ixdY8HDp+)Rr z5M0;!pNx) zVlUYO(Y~UJ?k(C>rCy_oW62Ll?$M$+>}PnedH=2e+z%yhAfA_=Tjd8=-qim6R#qGm zWEnHUZ?{K!?1QcJ%KEZo-Fi~cc(xQN{-1e#qGF4P+u0oH{-T>J!YWx_e4m6wML$<0 zvLT0Gs$Rv-(yAdR*fY!GB^P$A?of^P8BX&id@BjN=IWfY{L^lkY9Z<$%iq*ue><9! z;1j0ih7I6EOQfUbQPL&m<1rg{J74G8(`o=!O6cFo(v^y*aDz}AOrmdGvsUv>} z(3}Y+(6a<5`E&age9uTXoAnf9ib~dU-h65YPpc4Z?wQ*5$lR;l-?aVJD$zMg6n-l+XTjrW9w|+J zBSHFl>#A}hx{o^yAo{S|*+C!R@<&Wa&LqYbEx2NTV$dzcXnO088}4IH4Mc@NJEVqH z9TaGH5TPGv$jUH~AE(xy3%)GWKTy1MCu1r2(O>=9#ASArw&p%aJk(2i4KH}z71ALU zjK2G)(BsJ&7tObkty1WB3y$X0>lY1xFAY#XzW>fh_ss6;@9(6;9#l_3-7;51ddDY# ztmV6i$?QK4x~{_z)a@_3wDYqdmt{+oEUxZmjsd#RIV}7$S00n=sH!#fBZ$J zE4+D*Z~89qnB1}OeT4)9q1zBB&zX^c?X%g%a^dr**egSki=dNQYV-apfX0zMF$_Ru z+P`o<7_kZA<%JgIIq;s8*e56d{OusL^448t;q6T-YFaoj@M*Q zj<|m)IZLC62f{JZ>kC{=@vZxH2^dLxgbJlW?mI=9%`Uv?sZzPiqG~)A z8Hdo#yNEws7sD{h{ioA=Y02CdYKE_1X}a2)@f`W!2b1B-ZQH4STYg840dJ0X)@m24 zu+z*_r7^(29g^L$I8sU9SQLy?=ST|oMNlnxyL_#+9Keyt!~OsJeJ|(U*}ZNP>Uy1| zuw+I$_RhFA^U}KkQ}v*?xl8S-B!)?}L_7oc+CcP?=JzI`OOuQxrft~)V@1DhT2*?) z{E6w#t$EL#10Ae_a56x|eN~0eqn`D<A49;rCKJ$xG}^&bacCC0}fJ z2bdmA&fLxan`r_KIXHbPpP!-)-`d6|9#30wO#{FhTtbRaiwNy}JIb=eUcj4aPi<7*Oj%GKR~2>CcY8I$3_ zpDFK;JMRgG>aCpBcCJSSpXDNR;}}M&Zrg62!K+PfFI%8T`!en`;ua3{BSmm9xh0PM zz^7Ap#ro7u>Psr=G4&z7_q1?e{C@?9%^%wN+e~+}u={<9!tJ4Mi8n6>j=Fn9`Ipih zws;ke8<)tXz`;0OkHeGs{gcE?6lqDHn4i+z!4$qvfvf4iHV;&5H5S3o99gUrkX2B5COtc?&8 z;5uRz&yp7?HWjM8Nc`hvyk5ly9_O6fXa@FZ0+Gb=PwUo!NP$oh$LC>Pn;<%sbPJ}+ z)>XT)1_$SWwSR~S(9webbu{EPjov0Duz{t1DwU`JP%_u>v)S240oTDE81hj5+>35& z(I$K*h_05SE@PvtkK45?W8$la9`O|AO?g@=FCvxi4Im1{dNvb8JQqf;y9&HOqt{l# zStj?ai}VvF^vQs({NiQVOtMnW4(ltz+uj(LiHsB%J+$Pbn~{ukJD^xVBIc7FEneyI z`cHKAsDutvhQ@-hO3m+DR=EH`h$R;W1YSchYNf}!<&j-R1hzunQdD< z7^xV3JQuXn>4(<_qU&>fANc5d^@&)^iG|=09D}3&DM)fD-23l=5(VimEcqtQ13Vew}hI_znlG<&!~-#v*$nqyM*w zCIA1M*esFtC^6i`pa zfj=1sKZf;`F*PZB67hk!S(e$tKo%Js%%c$M4u6P1HP{Ja!28hq2a;FCGeIyX`>zPA z(Vurp!{qDHHjJuGQFn(V+dmMpmq3pn=!3sH-O`~ET?CH3yd9>eadC;E&JoK^l#MhK zMO8-*8QD4`v3nK&2cTX0FMNsfTQw$MW^Tv+HJg^CxAqGNz8EtVdq1+!D#+)U@$|mh zM&mnNb=AR2m&D=fqZ1VzY9EfC!+577TbK745b5Y=xo88&rLOM&s6>9#N!=M`#IrY~ zxVVqO<(>ly2WIsRZ59&G4lm&x;?#(h1IOBh?W1_NzadBsGusW%BZJtG^?=*}!L#1!yL}GUxu~m?U;e1r$ZV=-A*(elh=c(TKscJphuX#x%xEzcmVXAei#SJR zxS+>>9GBFsPP{ChawTp)7HNBP2~B9<2F}^-R6<**Phr;{(Bk4mWFB&N+~>&H{!H$# zEN?IxN9Qx?>dq0IhVlX3Tx&h(tiItJ7M&Vq_>#SP)0nOdXd@frq2Lv7jr1yNSoCo2 z^P`8Rs(W@Pn;W`81LR{3ZKwWK{CHxRT6MBATH;=03sJ1?P0r5@AOvvAmpmL z@+DSpF8C&eA4FWxw!T_TG)ejW_y}s~jOvDiNRV7>C%=k~9>l6G#Plk0p!{1-ck%B9 z@m(7_?AeQ#&IWZvg}<4M0eIo}I8b8M&)eb{71c>$WmN@2wLqtOLG^d}9NBCN%D zLz+`H5XrCn$X~WVP<@>S4eUL~R2-}9ExVAk@Ukh06B}6^AM!C?a6NL3=%2!m73{&z z4{Q-L^6sqXPcYIx~9%DCrI2LvGo8b&BEOArE- zQw}HN9Zl4t6y%v@>E)Z~<8;o`T{z>fF9W&Wr@JZ$l_DB1GmrP#8mu@2X|Tl~YKfHh zS{YKprsy9SD7}=C?!Rv;b?e;1Em+husITT`i#Lg9*&JekE@}v_Uz|jbvcI-?$${^r z9kgRip5mu}fFa+6*Napx$(*7*fSHt+ARC;ah`c?xv*EteLABHvYUq3Yg{DCABcYz~ z25H^3PA`?)HL?`=F?h77b_af(awen`Fdh41aG5)@VWr#Fg}&V->Xq^47=rlaE*pL9^N{IHN&mdhdZBOUt`||bcMhqi zHd!cY2{U&NINo@x#cpGR&<|`;0Q``@3psqAQtrG;Oerui7*j?fzwxYu;szz8^9K1> zb2wMmEKS&kan0a@_ghy?FOQ6auIqE$?XuPo)v(?lAwQWwV?BJGcWhLZWI)yRKC0Jj z;cm?KpooomSu*eW6&>o~jlfY;S*RkjXEUFjTYCsDJ`d^Z0uJg3yebgfFAq7XyR`GL zk;hFC8A!Xwm5T+cwQN{lYaLNDqZ#7mO!NLn9f|&rI~^|k_C?%|zsIG_NOaSQhW zktB{5K-F5EP%vt|L?gqi$=@ExWPKX({`;rHrR0Tn!P*$_Rh!BGynxDElr0J0`rFH= zt2+}h_8c5uAp#0lei&l^fDf^yZ$3F90Rq#3|NZ^K(pq{pnUKL2JEK6Oo%F&z2n>H+1`b`Zjpmi6_G7*Bol|I;u(VqS0lOa3ykO$RE+Iwn@_s=hIojtU4Gls)%S zG*HPyKM`vupCVX=m3EK-N47W%vH5YC;3xZJOK$u?-$ce#WoiGQ`rik-WqUG0AltV$jWdl!-q4sYp4ugLr-UBwHxe2`g2#-i7*2Y$ zRFNQ<2HHc73>J%P;Uk80`&!B*>Gfo!AEws7VBg%)I)j*P=D)4*C>BpTPy+V0cpFD4 zCyvCrJ4?Rb7oLO1=%AJ1_l3Fhfx_wln+x#GzzbYiv%SR2U0FOo)eh6N$_R=@dvV-w6Kpv=(APkl?ern~Wn1it0mu1Fx8 z-di(hw;(}W>KBQf;6W5A+ZNvzawv#X7n_ajnNH zsolPIWq^R&AKBXkXXj;o{>nu3u8f%N`tYM`MmgKVx~%2g>i|c2A8h0+;m3b~Xag61EAaEkPoE9-%y{m#X5hTi>N(zXDjsU;pCJI~Z zrmD9|nh5_*lOr2I&t?Lxx)UM7JRd2-T4IwAxE`)fm8`%Yi0L?<*O};}Mc!e4pWUFy z-{ZX0SG^(!=%hy>o+`5MJfB1BBv^s5u*(P%hra}Muj0=Cp4MAiU0lBjNcw-5?^pkW z6I`#1+Wi6DLqL`>C%_fOGx@F6wxp~vD z@8kLlWp#KB3)7XE%lK%8wu(YdONE`BWAXIQXZs4DzA(%1z(El^Op=47FFI^o+B?Wz z5CL4}=!a|&xuAX&ToL9{>%`HO$uaS*oaG6Tz{&`q452iECB@#>;`Mw-stm5`2utY{ z0@*%9D?F6+EP3 zqsS!KbMog!8H@}7nIF>zVh!h3`$Lg0mma!0T$O=cIT1Y(K%}wvv{CnldTWSE_jw;)XKkH5aRpcYnC> z`R;!h_(0pn#!l=A$K0KE?{)9}8Ou+n71nIbb1io(f=3dF^w*a*(Oik@R|iTtwqCGy z>_)oGD()je;u7V(J&)OluRxQ@tylmH73V0}2Are4l~0jsF+-tXmh?@?wv==$wOI*5 zY%Xqipz)E?B8d&A;(NGwFO2L!cJJCC(SC*q=9fjP$yXizC_r*nND-z=#AgraYP?;) zeZ_{deKoqK^8WqUh!V;YLl06Cja38>T?+o!{bPIGE^y~-$atZ?C^_s6!}Hqv$g5s-<(ZPWH7{N$SpHTj3L0~;nl$g%w?j0shwoxH6vY&U*lolC6 zy~%*60@6oZ$^l{k-F^9%74#&DoTofMBVgbmjcQ&8^{-nRhMPOguSlaV1SW#+&9mBs zcits7l;NWz%B6bnQ+7MUi>0h;E``M_ZIu@;y4@J#?*fp0>nSzg04i8*c3z6hz0t${ zLLYN%X3^kwpOc*z1%{wYI&0pp4y=;Ro zODmTRk-Lje_i3tmuc#IJoO0D?uIo|`uSR2F>(0qaf@{kq(FHhn+-`G!W03XqFYUSw zPGqjEhXDFFuN8#w0ylGl^52qihmtF=I|6CMIrpd4UCf|tnBbl8)6diLEM3R5r-Ne& z`(u^viL}?pYouM_uFgw^)0TIpTi5w66<>GL=Y`zgz20jTuLvS4&!#(yZ^1S1UM1!> zX`eS?XPUSUte%WqCVBkrQiI|Nk*=aDG>{ihl#LAr@eoE~u9<+p#&lBi(cK?Xvu%M+ zW?4h*2M8-5YU5ZE>aQFiFl?;yg7(6U!&~W3u#tGxej|(;vIDGMrgHfuTq&&ZsRfP+ z{f--ymx)Gl8Ctu*PlVCBw9Y(Kk4# zOv_XQq5G@H^aRnHy_EWm2r?*QxDO0n#Q%owzCzilATV@$bd6a}MBeEABv|Qhj_*HCdm#G1qC`HmnH5kGs9BQRla%E;;2V zhCCp8lQDTh4Y<=O^D;glHW^G{%j= zcA%U65I&AakmU%+QO##<@|*BD4rcm0_G|l+t0^2?7GEqJpOv;aVO3jB&WZQi(iIsZ z=RNnol}PvEg3gvs``oGfRv74kQ%@<$bI}8!0|{|TLB1OzH}?+%pY(~Q&;G&lkOicl zwI*dX=+0u~11ho?Gw;>iU6Ll8Q@)0xDc|XY-Awu680l2|IX9tAvt1@jHLs9sf;s)} zxzhqOs$SmoY_C-^R*%psxA>2PlqJ0svWk;R-(8os~K}_VKP( z*Pq5t3(~g}aGRm>h!NjzT+r^)>2he}0~#*p;6ABK#R@fwVPM#@|fnr zu+m=6276q1=(+6$Zs5Cv=YxO4)BBMSmtPYNjoa(8^0*uI7Gw~=$54KW%AhAwv!|Q^ z%HNEJ$*)kOb|9u=5!M?G44_Ux%G2J__Gn(2F@$7n02I4WXEz2dqZcJngaNKoP>j?E zVj68?I^ySbRSrQCpZ?lj_3P+h#5%w*rE4J&kCf-JL0!jxz(Nk{KHSEZ0e;e1W)p?o z1{W$FPI)>)`vtAXG~qg!J#>3aZRgEE(?a9Z^l$aLpzbqHm)9_IP{BvKFHNA>3hf_P z#;TmLw_H0PO&QJ9M*0+wm9|SV22_UGJd`O4-zqjyNlKNl26w`Itz-{iN%XcaYvd3( zL5Zk!5J@TXmmf7Rw9JJbDbt6A%hQ&5Wyp_{0QdbyLA;BphgIJ z_W_86UV)EUCpRP8!(XB+imcUPW;Rzgo&U+q&*tLcuuN3_+75A|LMj%{iJ|}2;HW5y zQM_}#1O!2A*Ka(1k9tqB-FCF3|blp{VhuhE;vC)*Q}TZ_m%Ig1m9;fm%k~@l zc712&u|c|sxq!d`@fg6OJZUgdtZI$$Q~172JCa6i@|X^44ITsPU_vavNgCwi z)ysjHJJ^awAq8ikCIQP%X$inr+a3oo3kM_vL)Adi_%YDM-Z5oxa3$*hVM7I6~OtRXb3ktp^ zDT@p)hJNXN_>kSG=>c$NoIvu>+v1FiL>d-GQ|4y|9DkI4&!gl3*M~fQPRV@%dJr-3 zpq~~HB$l@Mf%u$RAjDp8<^i=da|#2&;GiY7-j_qJzg`Al4kb_={4$%wA!Eem9|r4) zB|Q$0$(PlwPwnR#SN*oRun5f665@0yIBPjG^4Dx9Ph$3q{L>H2M~O_2vFY{TWV5smNjaDtRA?$r!uU7VM6e{sjPn{plHY z%}r<`*vOn5T$WS~6LPWs?0s(2|BWf7p7bDOd4Hs`2fLZqNBXi)yQ)}FgksQ5@@^8Z z|E%JbcW3&!F><-LbbX2&Fepoq6zKHr&ZS) zSo}*=>DQY(X6?X7Y@!ZVJVcf(e@KKf#Be=IHSw-HFD85>eAPaeyrhF;>ZxVC32SgK z5S2x}3hl_@b6QB}@1Pil?tp%uB zKe|~Zo!YxpY(mrcM95(e5m*5|jxL$@7|it+xCr1Mouuy&%-dtq|3vKPaGgdr?j5_k zEPNLF&h+)UjnqGhe%TjyJJJzDFrBbd@nG@)wl{4s9dp@cTzpJGVS9Y%SF-nPaz!eo z5HG)gB{^wdwacFifR^muclT|?aQq^xr6?_o4n()4QKTUu!#(=KaDab3i+K54Z&h@R z8oc{iIQ>m9<4SxO0%uv-{ekG-!>KIb$xs&d3mhw$6ZHW^vH4AGC5X=C$x4tDn4#<) z$R5HpATS%Ew4`Z2ba(B?PbNh$UHlrojy%iUjZHHU-J-|QN^~c`eB}7cj$#$-HaOkS z6n3q?rM1b!m=5w>l;U^3HkbI8e460@=4B6X)@{25eyF-YJ7n&9JU0)*p8_QTA^y(z_Ni3USB# zqRsj)M>E-dkyp;vXv4KjxpnmKU@@(uIjYd>?KFc=Up-f${3j=!!yaGPqvR7mV&9#o zl}`0Gv^&^c_~FVubo&+_DNfxAJ)Dd_o!8%uEum#9?bbW0DZYFDA8l_B#xoPHJomNS zL2`dIG!VksxZ3v8)(U?MFNhDm*vva4*yr936h#}U4-@idL$+}B1teZ??jMDv&n(f3 zjTco(fU61?bmTKlgu>*s$(@q)T9a<_h7rkd1&xymvw;NO$nrB?6oMy-Qb$&u6$MzFMEGOR;^D0CyX}DtVCXYC~E||$)@{o?I1Kvg< zTyMS!%q3&d7g?w$+}_ohnc-gs>sZgtc#kYOyR;s?*v)j>etmqniG0kSogGJ@(G!>+ zUTOnQC%DxCihi5p|J7<6&VH8vJ1Ev6tHjVmr0V-~8^D7vPsbJ@xBNwf_)b|Iiw4b9 z(XvFs4OY9?eCFmxi4~<%%G;@LPZ40hS*jmqZlolCfk*5M9`gr}i)Q{Us_&n!lRO!E zr9qKH9cB~Mr^)Ivkv}|^ZFQ?t`X(G&d{?f{`0gxR>#?UcnLm$)awc%0s-%=9@Qb4W z9~bo_R-Hmn>>h~029yb=(D%6?iU0iiQ@;($L56olQ*Su%c2uxpMi1XQDqx9G_2pTX zD3l2o)G4{&HLbu=JNV@~57#812w(U2z!d$0%KY&g7*9!@3(Jj|BN?&(tTU4@DZ?Zo zYnnM2KrobL3YpYhHQN*RXS%4illYi|i+dmV@fay`d@(!NB$3xAiiPI}u0ic8iiV5( zfi5kKIN z*z*~)iLVHmo@&N=oXJ(h7sSWd9c6xKyjHzbe)n zd3}su-fTbXF6?6ia;KNdBVYN&6-tlZ;xIjM{Qk3yZ1IgkXp<1%le((_FKc1KE+XSM z#J{Xm5^?0yll?fg_Ri4j{ZbPJV0R#`E@z}uB@SbG#o3Hx39}Ibfedus--J~?55J%V zQM>fmkg90OeiP>e7DKN5Z+4_0aO{I`jCF#7ydh`eIb|V9NK@6~yE0OFTA$bSUNJ`P zzh(G1Lv4{c1X6sH=~%Nmz!D(p2bASsTpS%e$5yJprOUflNdF6NXr%vkQYJg*VhQd9 zo5fe6xu91E{7cc^Mf-czZKoUa=%s7qvMq0MS*gN1oL9B!u(-@CpD2s$04g}aXvF{h zn8z%}8~rR%W zd-?Gy;zc>r=qsfZPw|5lr}KG}BVAw|e+~YCtmKi@zy7kDtoV{+Bb?CQN#;d%cXS;_ zB5$hX04j{UmMvb>7)~WteqlyB23$>_Ei@ASB+LW48xTY4hN2Cwtha{Ej}yM?B^&#k zFiLN<01IBBwkF@%l^p4j*DGx`$8*bHpPGQ~q(wAY%pS&)yz}W(y+<(yeLhSAE6Z8( z^;4=%qzaqq!Cge?k-_9cT!c7f3$N`Z?mW7VXuNU^RjhBlXf6!J*6SD9-5yJQ1?P?ccoW&ObmFigCZ z+Rgecvt1UDkFfl%O%~R<731}d2K1Mab|pMaLt}uHCERV_LPw9%%;wEt6l#&!Ti*}9 z*9qqnnlf5fc2tnnf9aBSm`gpYw?=g`^}?O&h=2UEmd05C$FDPbF+ZWHlc8SbTuub- zRhQcSD*1HZDFb~X@PK_WC?T#X*mkHKnXuy;ub7_MK&_Z)Foj0=RTwqp0-L0mYkbZ? zTfNHgh%HhZy&izN$X+|F$!R~{$ki@%PmODv&~t0H$iuf4Q>!L1BLPXZHACE@^#)({Fo+qcAv#2tVkppK}+Dki)P z9AubC21T_UoT*5Px0{}EAx6a}B@shv%RuG>`*9#0fP%5KZGH0lB@%i|p{r*=ug%oh z(xJ4P65;NF1MHnVlt`o5#S+8zmo!1Y!yIA;$|0mv`Nv}bj4td#l`+X?DF6u zdyDsS((xsRWSR@!voJ5GQZo6Dt`OeBm8b9h&ww{86!`aV@0W)<*<+ZjURyJX%VSHG zzGpYd`R|w;vimw4Oj1|P`>Ee9;g=`TV0{F6g3~5lHRzypQeE34UTQAWB+)(TjaonE+`3Ix6dKS z$f5-1(?$A*RTZY!<-O6jODyj%&G0vT-gWFCtFi2?TM#)+ZBv)S)ZE8UGzjm#SpM_v zbk-Eg1JgCAJC*!Ry-6!~2 zf9Q2Yb&>-UC5rpH>wI97Ws-@y?W%V&Xh>1&L$cm$q}I@!BvZnNms}r1t~xdMcb%WJ z{U^ue(&?vcs>cZW(C`@0irIMoOVKGC#yJd>01F#$))&L4B2};d2fzJ%TY*3{zAy@{ zx>Lx4fBg-s7fz>g$3@E;J8qpPicLfuzs)_|k`voWJSQWo_@qeT_h@FCqezy~P6k9Y z3f_D16#d|()W$C`h1br1gla|A=ao&esG#( zqUy8_MDOqu(do#7$XG0hU>W324?vwZsqt~FKB=7fOWHZ$`^-UB{G80v8$FKE;L8)|$urCzym*FS`69WW zKSbGHBPPAY&;MkWi_>}=CH|&9S9T?^E|6(A9h$Yw*SAJ4VY{|s=D;dnuXa6LR~`@o z-L*!{*I3L;;WeSW!QXWHH?UL4ZbZO>=?A4<6rgXk&B9x@GuIFJZy`nS{r+1ah=bOu zI=$bjGwUOid3{~|P;H>Ps0lp}i}Bkzm2-^HH>(ScT74?JY`AyUjmMu(2aBAS`}+sZ z4bM38L(df8nDvklvC-zdX7^^=HF@zl=!a&62qbbDsTX&19}J{^%#OkeAq`QAAlE=I&QcLzb1!Gl0JBI@uj^FP+ITXCz?EB zKys&%LAK6mU&JF8wtq|ad}>?jExU5}x4FNjK3UR#T}1?1+9h80cw@U3bwLZ%Ok2}G z`pS6zpGf0-doDd!1wbdbq-Yre9D_T+kw@j7shN}6sc;dGL};i!GcWmIlOkJr8&q8_ z6}#JH3)#wINn+|2ioFG<3!|207ePop+AA|xxtO{^aes^}QTJ1>SLrbKt6=eYGvQ{@ z{rOK#?mA#V*YC=3_95F&mbFWmm5=0caW<_m34|iI-X)CQDJT=MJSiWw!<-%a!36p> zx8AmxUz79PhNzc>Og)lT#|Fff0V=4y*`h?TVv)O-(LgJo!X}FSYUhspPIZ-L^r`Bu zRp9sOW9GJkZ>QC`Ac;2IR#y)7X?evWwoC^caA1=6)7fd z@oQGa@T2RFqa2olfBn$UA@H!ilgpYO{*`@1+W(a~dR_i_iV&_AA8SzW+A1f~>}gq) zIBE)qTcaE*>~JBEF+&5Fbk8^^_P5F1TeX8m*8%-Vnl^SD7l;;ZPxn5}IPe{kz%&mL zIFT*)Yb%T^II-C zl=gJ>YGER`XWWo{4ZBNN1Fg3xh%BVUsE;-5&1qk=0_Z*1t++cA{#KNsddV5Ec19Bev;$uIW#s^p)-=e|CHmD0;^(WlX{Jh#M5U8K_<;KDCV#20^YoFiz~B z%q|GNgkO~iVGO~p`q-gzG%a=NIwd6c{~Y1J$Nm+UY+DQ?R1N4%jo*a2zLT2hQpq0= zm@j{TUOw4wyPV*cpO__ru%Fu4B}7LBe=-y>_`F5`y>^S69C8Ew=|I(G>`9W{BS1?% z z6x4{x{AIqSJU-7y*8a3fzq8rhFlG+_b=^nCdN&=ZZv`w}x4Q!D5k_r^tD@y2!EvOk z5?g$P=Pwtk2YuGi>7bx#iwC#e3H1v52w~rn%4 za7lyfJj0bbEYlY%8P=ij)sr>o&1S|Fu|U4$3~b1v7iWphgfX|r7;7eW2}J`&Nju4B z_W%-I11-)FE$h!|2L6DUdfKYpC)efqhg1R--E>V1mo>Wo;sJrO`5((JbZMIwPzG#T-yMYf zfd#9X^dxAv)&m{5;PX=@ncUVWG#x1N$W}yr9jTpJJB760eL$W$lLkbrJVvSR;*ZYWk{1s?LHQB~cN5n)1k2j82iKBjbz(|DqDs zQ1`s8vOJjj9zJGqvzG`&u|*97vx4q|*cUVVS%Mx2{2jQNg1q1%3#wqwTA~x4z8h1e z?pPB#3@qUBHRhUVE8m^A)1GOC+v%?%*LzN(7jJ70uhY?@*c2m9;xj{h z&G1N8pP7YqyN_%eKDu>3zB1%qwA}hK>bd2Sqpxc`ZRp8GPc;9VBW7z8iOmtht>9Vd z4G7+HWf!}%LX`zB37_GyAi+W$DX;;>TEy?`|+pnyRk^sr++dslf#CIwOJuh zp>EOBe~C=wB2^5avGS;7X4>OkMBV%QYD7%3(!P&fKFkP<|?Lj*~Y25AK;NofXANf{6j zkp>m%9;8v}?izY1fuTFk>woWk_Bm^veZH{x01FoD;d$=3uHV&#t8$O!@v=Jc8}Mc3 zetcUA0$WJ=xy;8IM-#T76fTgwQcG&-MeZ#~<+GZ9;QZf$v&b>wK-G+EsorgpM6#UUVW_H-bKmbafW;Uqla9X}2!HZZ={ z?Gt9@YM?1`JF_;Z*6mA8&~|2r_iY=SKL+E;et}GxbzM#Nb0?@YN|X9b_8Q&|x8SKc z-*D~5pML0U`Z2@k`MRJNfy+*4+ne~+ona~qj%=2yP0t3t^`h620^Si)LBl~aU@q2T#RcFZlyBa%k*Uiv<9*mgm?0Tx&7bgqp{h^W=gU%R9LR;BYkRzmu|;JhD9mo)nA1rxoD6 z6++App3i0JiQ{(rKX@&GC3>2HnXZbF@WEB3oYA6+D z-4cL@>+K*j&PD$3mVI8O^+(O4|E@YG5%1cXR8_T7erqfVVc@!E%b%b7@&Ku(<70eK zfQAPsIZyiwgJwZ}W5Db_2pMAx&@@=?mU_bwK2PA+wVdUpf_hHkr^q7QiT@SY#e$4g9TFWGJ%lII`*~QBjsXflk ziO3a{r+VOq{oJwNRLNW#_EcEk1a&`?krk6+pOZe_)#Ee3rPJ)-v#VN@o%wUecu~$S zL!_qbXtcQJ!d$Rt^X$s}_1}vY&lNjHn-%|qrtlzN`M8z+iek(%IL&gCPInP~u#)jcL%|J-qQI{D|{I3niJVL6EBgbwTuXXyUQ`^}=Z&1z+RI%|J; zE)KlAz56Q|pK*{PY<||fv-nR&VHOzK#9?pIeC`C0d%SSMV+G)KNgcn`UVRs!cW%^` zXn9wKRr2~vFv##DpQ+`FI z!1MG#mr2*6!&^f>tC()OH&4yUfv;V&$l#j5p7~=_ymsSc&KTKuNiccAGsJT$?JX#Q zFh1I)Sqb5>CUajQyHPQ}Y;=qKAl}fDUHJd|GRDgJhRmG&nvaytY?kFBgBc%|y`VzC zcR#-mzDTs9{4!`c($0u;=bQ#mB~{4E9H#x?=+cZ>JrpnUBcWpK3*7-xsfonUIMA32 z?#+D?J{D8+PT=?=c)IY!II@(uE%}-LNIucz85U0Gw%Ia)Ot2JDsXYa-}>jIa2@}4Rk#-acyE6*w^R1N*7we zG>vK_-L(DY@AfRBXZe!#UiZp;a+wk8tCulxMz36dS|t8u_R#Sq_d&e_V{&Z#h1tYi zX^=wNv-`B+V5NSkFBuzsfSVWoXgs`7Hj4 z#hx9a>rV40C-wmdybKap{mi(O22Rhqnfzl7ktsXKKgz}Iny7Np)@=k; zHx|(Geo&X`8cN;0sxT=GB;y*NaVhvK|CY;r>pO%{yAh;9FniV$E93|lU}-O2*ti-c zuw)|J7i%517GDP`f*VNjo9l4QrBO&t-~TMR{_{!Rz4n-vH3eTorz8!->Xw6CygQYZ zd#`r_owjcF>F8v4iGWtFVuXZFZuj$y@K}NAH|4oTT^*uXqw*dKUfb&_w3K5c50>s@ALPOcOKcS2u*wQ)E zw~teT<_`$NfRID!%Lwe?bhuf#XBJ`7!|(mZPRJm zh&;^(Z99DLxojr9j4tyAwj+>N7?gzE_nkoN-?o`PoC)+?5?DfJBe2S+%KiMcswjRc zv!bYjG9{?IT(&N)#?2M^!4X8SrXUn!^?IzwroYijb6vNU1mk|Sw-U&4>Q_6nB`*a7;7^%F<>2VdEndW;*|QW=g8?h=AAUM zy*cCx$NhT2OQy)clqJB|DR3QifZ453bET}b2vJ@Qjyth9X9*ezG8$V(y0~WI5cx4P z&LgE3pmbNHN)MSC7RoYlfHwh@(T_RAAmn8tg~az3Hd;V%?*ks{4u=wI#3M-2a9AQHdTT7IW8Ly{um#*uJj93nNac>>ew^4ehhN4qRA1+cW<1W+wY>CQD<0+2z?<%>xoy_pYmM zZyn=q4^U$v>Gvd|=f{1DyIgDnTCh}q6C|g|g{hm_!FV!u(eroH?12sDV)?k_LTX-Ljp(nb&c^r%pVDc=lR2sZ6d-Ow5+0B>jpL}oTEhxj0e6Ra>nQOab z*Y9tO`!!lLSIgF2R>+HxzcF8vdIa%VCkgzO>qek3*|eU^R8MqoVMOQ$j$E2Ipj5-o zYcE9+)a0xsdDFa|BKR)00a`LHWKkeR=xK+^Pvs|n< zR!D0+`wv%IcyS@Q_!o2|zU9it1wn9b|2D@qKk8v-1a+94>|(m6K|xobQxhk}%tfSQHkp`oTA$ti$AV=u|h8&8s9^{D$7CtS(#M z-$TV&GuX2tRm)5?%79D~h>u-n8tb1eXOnMh%{G)H0Ht9xe~R{He7$kJ74UdmhCs=K za~F@HZX6v5a>{y_Mji!z3w3OZSWFeO@+l-(z$WJp9{Dd40J;ESiL>stXVgIdbY;}o z+``de1}VnZ8Jet^Zcz1^x%tghg#>Gz8OSiUOL^{-$&n&<|v1zb_lea)5e}xmj$xpOmnYSd#13A^mJh z{-nnnyXAQ`Rj_uI#pEezF=zQ>?dj_508ep`Ejy+neN)mcWhdbx85Bya_sDg+Y)*Dg zPpwwWh}GmZ$;|rx-SgellJYp1o)tFEj7G{=amwXjoaF;DJtoo)IT+3r=HwSV8Lozc zNV0%8H6)y3tNFGsBm;w4*M8$GuoA%H+@~^V$We;hG!z9))!)dv==pAAbm#l(0|aDcaepF$$k zkk>a`jRc9DUEUdfmVWaVvliHP>D%p;uE1?3B1jf|C<0d@GKY%uQNW{ZLDp{FB@<>3 zqsK7yM9!1AM3Xm^tVNIBkiMZ&fI^AXtnPjOb-UfxTu0|(dt+mN8+)VL)cow}aX@JS zi9&;<-+-N}IhA-Nn{5UAb)CDdaK>QL$a<1do50Rx`Oggjw+};nCEW%FjT@l_1WoDz zU@zFDZmC({0gl*xWc9lWVW+D*p>?rL0Sn1D1~LbWWUW&jEVG9#()au@$3A( z_W-`I`?t_s*4uLMj?d9jbR@ha+QtLW&LhKbbdRgl+%i5AD|Oh<`K-fz53xUB^<1AY zoZ`K$^8vF7?j0GP?l!^j;G4*+J+n+%@^Eb)$@Ky{sWxASwS42{NtE-7s;*!S|H)(U zLXo0)F5#wl2te?bvk4?-QYspv-D!S)HxGf`wUfcxvA%0H&^*YY`&2ni%(2(JF;QB z_B6u`>9{#;E+ZmK*?gYjgu#$$wH<2ieQ?^z9;46t(2;xuYlfxnHpQ9f5Y7MfkVX(} ze=`WW8zkiGHK_;P-|u_Z@$8!>^VtFu`Esz;1zs0V?d08G@(cjUoR}iC>q8pZ@1$~2 zD=-?LHkgGn29y?9qKxR2k-xAl^Whr1IFyBH z2*1Adno4V2JJJn4xP*q1MNd!rSJUtw2Tp0Wc~fDVSc1Qchq8ZJ=+{pdO6%YbJ%JR zu)Pn+IeGXzyjSTc^R&qfEToJp)mOiV$qt3|7$5S8yeT7!#mGoWQ9quJ(n}8F26BG7 zYmIcuelaEOHspRbdh(LAb2(Xt5(moN$)FeA#=u!FbN{;ou2+=!py9I?0*7lmSyjnI zkm+g$PUOENRP?>col8#o>7jMd4SKo{j1L{9!*_8jQYsDDA>^hbuh{UZhH;e;!p@+k zOK^LD!{~#UF9JWNkw-*AF;X@fK7EG?j7q3rV)iCXM=#5OA9qx{WHp<@&kXHlABGz) zxizimHegqC-8cKEBr0bZLVZd@}Zv&>g!RjC!ShH1Wg%+ zZOVik_ugMVdDAiX&(VKTQ|#?`0ls7Yw9c~%F%-yB?SY=j?v4bP_@s_Ib^E>tliu|e za~>M=g|n5_b(V6+*y!V9dxjOC?R2zTN##!%qsfPQ*ObUi|CNNj0>1B<(@U$%jjM@P zwe9q%yo`zNdUj#Cec?!95v0~cd~!$589qF2+L}x!z_P$~R3I1Ec;UvjCREK8*7D0= zBaFO2sZJK5Tu`v3D6sT4ni|4t1Q6Zwe+=MypGUWTocx+b{yUa;5po9(D(x9-d_ik>Vv1G{9e@0nli5UQI19I*b z!qzijpcZG_&*-l@ACS}Y45Ww~kX5aRf=EjaHIsiXt8vfexUgc9&A+a)oAnj2X4}`l_yepci1)r;t z$K+6#lHu3#gCiB?_Es8LfY)xufT^=uo|*nR;^s8Zs6%ZVJweF#0!~dF=96(t(ud#GltQW`Z8tKA zZ{9utJqc^tBAfrTRyUJ;WtAF+|7i_1nMl;=;%UK`lKG(K;^8>DF{E8RrObu1h8Z(_ z-f(D)x_vP-9$4u9h|w~5e1`1EaKQcsT|K)oCu19MXMG|*m0)#gt?6LMcWP8}1ypkC zuuL)B?1zcBS8`X&I&-o(*~`a7SHMi<#Ck_)&HfGQ4$}d`FbY?v`CWh%PUwn^B`$-= zdf<>9k0>yG{q>a@P;AA?g216IINoe@5y4p;6`|~s5SAxiVX?~`l+fv02>LIB@hMWl zm3gJ~mXd(*SHf^PmtCsm^gz(Bc1lOMYppx~oKo+YNanMtZq-{4aX`hHNWo#UNWPKj zsIm^(Li7-Aq|@IK`7X(r_FlX`SiUuaMTtHY>c0wHI#&@(82txv-o@Dy&} z^|0P~6k;wbOm2JiAtfjH4@^~Ylq)qb6Z|n-_BLj^H(I$vh590z-0ul&pO@)Sq1YH! zG=q%Yp#XK%bKU(eyp;fBjfAD=xNp(@6DN+-ox*8-goIv&Qu)@aonH}ku-dz z)}O!pf$}dx=K=T-2k@IM@43VFXAe2*`1&rfmGHq!cfAnv$7u08R=6judmI3+3NdPl|L4L%W>DJC69cY#8iSO zr-|;U;-;ei_&8y67M8+W%Fc&I#}3PRn$&7vU?-E%zYd!GIYm5+&9r#yhxr?2_an1reFV{!h29vB z0hyOF8HjonZ}_&wzz48xf-T{ZLh^CT2&k0yJrtj3=59K~>uVL$X|+VTiGHVUv#K)& zeq(ib_m-zVRIYeeJr4^19;vDwem0r??V{{^R9kB?k;3i?U+f0Qotzg`0Aw&uyUMDs zRE_(i^|Lb;xOSKE(2S@{^WacwVFfGn4Ix1bz86Qn*XW-DFM#@#<1SRv2SG}Nuqw$v z<5uTQZEXZ#pFihVF~0Z1pAhGqQGX09nKyC}6!JwuA=%kNazGOGqc!u!Wlt5;q(n$)L& zc$!n~ii(|?j6J&wq#=gVBxQ^HBy(bDtpQzc5$xtwyd|*d72=Hg9#&6>skihO4L%$j zX!BMEgXHtePhB9MJ7+UVBG}0T1yH zTVfk;Xm%Q@06j}noimtWrUl`~k-e$}&&{9d`%orYM$E&ub1@(S9el}+`6hp* zGhAsPUjX~zpB3N_7UFel7T>tR%1SW6dJJPIJNTU-He}IrZ;!<{l@^`KeY6H4pCIFb zD&5y!EzW651*@h~u2PbwoA>c-&>Z)`_AYDeAhT|(vT$_G;yt3KdiiO*UNoJF^3fpO{SGdgp)7{-aZ&(pIRxOD#8pTW00>KqJ1^n@qUy!v z=noNuU98&uKpJAV5+X;aZ)|J!5Xb0|nny;>bBHkz2)tDw+Ic*N%e4Tm&Be>f$zGNpxoy|V z$$8$`rBM><=?GQr4fhE4W6EuZ5_nnM z9wOHy17i?cDc?fHuzSIDR)^U0dW8;8mGN7w;uOa?s6^uIc1O=-?^xfCbdytqGTqAt z9ei=O6rc;Co%jkkDAU{9appC6&hQhQ<+mS4^|c8^LULocsC19fZ|hJt`-? zIj>z`F2V;;={_jSYE7K;a3=zA^{BnAiHQPr(*s9lQ!uJ|5q zJ-Qr0Ht=i*`K$E`mHn{3yHl7h)MIjc=sQg|rkkj_Nlh?su3v8Atd_@2Z>kx~14{9R znlAZkOOjwI`QY{s{ny!3*go%dv^?_NjJQiT)&o?C_Amr+^|(0A`D@JokwRJ=O#ZV< zvW5$`!p!D{9=Kk%l$M&$)yYXU(AJoX)9utM9w(v(yhqQDbNjnr)o20x#2mq~C?DDwHYuK6o(Iz1D5IWjZRH0cAZUse!^b)^Md{KwW0zY98w{j;>lK+erEFIg*%EspTF z;iqUco$LfP0$P!A1!bZpvW6noL*k!*BY`3$#7qI|h*g@ofTw0cS@vUe8&XqBwO!dV zP-w7;&lkBTV>3v0Nwp_LA={_gEAI3fX8eXn1bqAQZC}aMU z_pC_JF#unW6iUydZ{d%!b$0$k&Q-=`xoJhrwid=C5nM632sK|QeSy!!6B%U%?R)p( z7eP(7=(-gn8jbYR;A*T*_tSar$(Zszz?8r(V9S&OB*Be*y-5<1qLTNR@Z7_SXT$^x z5C(5%%z?EQQ(QIG{U>5b zt!FiJE{j8{-Tuq=jOe;vw3&}&_oYS6{@PyrX4LDsMBfCxhpubpCipC#VC_=T z%PzBX{+nCT!*(f|eI6QKO&)%m=L;?VGe1KFua7s$+EPm-8U8Ho7}$*8tm#j*g$=}X z97=}a6{G2FIanAQJVBOcyNscS#j`azNitsqzfe8~IgFaz#a0V=Y^;{64xFU9HDJs6 zI)_LdFH}Pr=@818OoR&#dEMBqU3?&z^yTnE_!`23Uu6&$NPfD($gBMuKzx>XWr`LD z7$p2XpG~~ka5^WUvsF1`iF)0oFU}6fKWt>q@7uH_{;1y_KE0&})u76Iww1n?YCiZ0 zXgxHI3NE|Ke>e7X&G47JJhk zt2n|AO=psO62qY%+u}m;w-$KIz0-%Ixx=U?PamQZc&4UUWTXVCybsBb_H+0qxO0b* zOOgQVTyXNo?OAV$?YR&!QX#VTfE`Kt=A(CfOM^s6BB7ipJk>6}$fMsslE6kj!F;@c zM8w&{!o+aZ^JcWi|DyZKbQSe%D@AA%q(`_)Cl_xO*E4bMlrsCttejWKVn9bD+E*sK z$KL+JTxS-`<7sx$>GC6W(9=(5&EtZ<$K$59zlq<;vI1zbz6lJxaN%CEC=I&%@O&K8 z)4{#;8nP%C1Yb?2$4QTPz18CrLsU~q%J6pP$Xt!t3P@j(M8r*CGd~hDw?OChFVtLH zLZ$psmVnCVH95R^-(jy*y16l&3>fWi`PqOL?ofNv=kRYYDSu?AY2ZyIQ%@K=P#t(g%>Jb1_K7?CU1c zXNSu2cf@dEYj;uQ$ZC(z;Q0;V6{NhebP0n|;89z|xTB&8x( zaQS(-P;KRD`0O0o3OG_iIpCtj!W}ZcSYvQ8V^aU(k6IoIn;~$LKCw5kebS2+MD^DX%%2zrcVoUeBb&Ac zY~w1Jhqxy02YY7t6bG`4XAg9Uc+8{41f{R$R$jQLj;}1Vm|l*7h?LArR9=#oha{=M zaYw**tw`_+vh*=pP>MUf^+n;{C#2BE8GUER+YUzC*m8SIa8{8GhBc550uYf`L-;wQ z3s$n;fa*bk7*Bd2E6oCq${Y24`O`;Q*EN9My{wk;s1N%068j0PAhKv7ILes(DDDwk z4qG?Sia=S!>`0H?TdDz}*S9-rapkDnr)}ba2&n36EXD*EH=ayI;j2{TU)KYx;MC##5D`WU1r@^mANXKRE5cU*myFsI&?gwj%0r$z>+= zWLUje9SkTc`vE3~5{K7VErbz96hjEV=ka?0bd&ED9W#iE=+rttFN;v1=A;6RV`J<5 zye8exo=qnDMScY99^I5Y0lIy|r!%?Dr7$5W3A&J@?KKRuIP8PFK=1aL<(rU|=!w>3 z*c3%xQep-C`$d!KS&u#cj|Teg71vCm-XgW|hICSCv~Kv;z{#6%(=#4XRYvq=Kh`-6 zjL!9T{%L9Qer8ntMe1V#bfeMb`Rv+BCpP+Zrc)}0!!5O@t2q8*!!C2P8_O{9{Zek! zF}bJg56my+lF9y}V)pqpwp5mK2mW{Y*L{r~GoZW+*bG+ zL7*xhS4Me{9f?@SRX8NJqW{siUVbYJP5^IrdV!bAzb?L($}rs^PG|doG+3~G_#Kq| z$hNRNSgDCN%H4Z?#zRKBlDe{+UGMuX1kcH%HlzRJRRNYW|F#%eAWeZdCs}lLu2)YTr4Yddb7de@|ehX_erdV zIJgQlV<0Sn5IL*XuVb9v2b20IIo)Er<|uFGMmqh(as3*yO)SnFs=!sZWS;zaQqJsY zr^bv|Tj1sCCI}pTCRtpAf@MnA-JiRvEyVQb=(J9KXxxn|*~aPzU{b-##7v8*N$Q*_bpZCEt8j`l72IbA@6<5^PHC!Su1t z0(rexr`IdTwHAeIJ?NDtm(=+`O@k(xpuoxA&ne?;g1E)E^6Tk53<>B=yl=ULB*Vtu z>`(DwnqF|(1ec$FuXc5DoO-|a1SSt7{uIDp*Bbr{=hB7%nI`p3XnJx^I3<~VEFeWViOjJA*-l7_Fo+H#t>2sO-XPdSEYl% zRp-&7yQNUX@?^g&H_(mP|FYLc#t$`FVx7`+3-Lgs4G#qUh@zKH(B(~}l6L(daqyGX45);JiLnc~x%?;XX9VgD;1~h|g@ZTdpDQ@imoA_zo}#2lbvAx(DiHOw0m#wR6Mhhb^B)8mliuR zr{;ONA_fq^s1jj9qeCe&?aqIU%LO=5o!v&RNYXfnS&h52Z%a7S+lADgB-!v82=3R` zgj@GK*)%{iCXqQ%L+*aaWu+p8r;O$w{zN<#uUj!lr$>5u4$y&a6nfCg?L0(3*SCY{ z`!anu?~mP{_4e5NhG(j>9O+HTS5d+6ThK3q(w0R3Z0z$> z!k+UP(;~Tg`pp-IS7U<((cCA0&UyDLxLH)L)-^_T#BW%@!s52(Rzw%t6`6e1JJ7Z* z{cK4szq^0{97yN<8&V7Uo5+A)r$bvbMaYd(JLJ}-3ro2!Qy4?gqRg}#vh zUx)}%R{HK-u8r)rlj|fde6XCxcD1v_*m5M>864o$fIw9ByOPY9O#+2eSAbYEMpp9u z$vct+Qv`uvr?-v|Z>_P{3PPD7CQfQ<5|$wqEtqh%9LU{>HQ&GeE=XsjBt&^?BMH*_@9ifS*rHU>3G)4j&x%zxukmAGQ{i`{0Nl9NM+X-b+;oKYZl` z<$;;J+*hX`x`mj$=&r{I>|Dtn8CBlgSvAl*QW(Rr82)!Da0+!J!S%v+WI>Vt!n+Y1 zG!!m6Q2gOg1sippNXACZpvO3ZiYW}wc~K!raUb|qmx!f%@%~)@iH3i4mCcLlBO5WT zw>K-Vp9`?VeVtNH(mhqlv$#_Jt$qRvdFqY?t z0E$qJeZ{Zq_z%?0?zNt5auBnzL|CnZ>Di+f2HZ+e1jMEX4=oCc*)wb^SbGbxUI7xMT*q+G>!F4~YZgctQea}*L=IiI z_daRzFJXz8cI86h9W^>gfZ9Pj;I~X9&j6{&qQH6;rJlApK)|u)CH0Os1+mr zC@!3`kT)*}Cfe2uArwX65_99SIrMj{(;aX}w>|DY$g-8tFClDsr#gked(`+NLC7H) zW#>?juOf=qf~x`}JVe}fdP(fg!L8=0#LKPHbUh8={Df%VX3~W|d_~4pA-V)bw2aHt zU3y$Jkf3(SuNTV0W>5=XPTrg}M$cE4mGxfF2nr23^1avo{nO{AT+<1S#_;RSp-K5B zY8`E4Yru9Ro7{0yzVCHa&Qme-N{RiXifGcu4dyaKBF~k z^YS%4LCYTUj|DIMNBxKxh1lv93tyWzpnHF4ja_e{n-(kvQm0GND^eS-BLzYAD3mw4 zLg&Ke^+4b8*Xq9`pRY@jnYGlpEnoLBTvU*(06(TJy;fJpQ09+F$Q4OJ_i@FZ$yyaL z;yoC1K|uZOnri(XoE!qx1&ntMW8@%I_&hxEm9n;WI@V<5khovvc{sMw!lGhAEjg$6 zPQctQ({nM}f#-ly>iOYlTApa$yw?K=S-lgtoms;Hq>Gd=4hTdD(jbgS2(THsp2CgG zJ5GBtmtaHl-12F40YP0g0Vp=?nKS=?Jr?>QyG_39^4A)g^2ZUhGpXEXLQ~RrE*^dX zeuFLZJdN^Dy(L_goq==VPp`uKLy>N6kQ`i>r98UoP0d~PG;_4dF<=u^ zEzX%^Zt>)$ipw9+deZV9PgVJAYy6i~sFH+Q>wc}xV|-6T#4T_;fG{@1#XV#0Bv|Yu z*x)62yR0G$r~p!c(pLw<>Tfd*bG8+?$^Z@PSHNy)*MqC>m4}^yL!`_JjuYiQ_Mnpn zql66V*AE}8m2Ci{5%6n+ex=`lvEzfyumowtC6=bYk-i5cc{V4zF9ThId9G(WN~7h2 zoL{GXaV#?Yh(mj)ssH4s$r7GHxx`danDTr-U37_$)U?t%mYt}ArKv{sSe2-j|7UT? zr>okwcjxZ)#JoK_58O6|lY6GedV*2+-|vgn2A#nsL5uCQds3tr)(chr;jDjUa`2^j zbBNsaVvRgSso@&~EM0Vfwq~nP(B=8hlN9E1@VBLd2{o}`;)SicJ8$ynNBAXyhVdSaUTR_hD*Zui) zVbvq&#h;E+`TttJOX$LX%*JwC0g#H}xK5iUyk~F<^{&Bwy%Oqwo@hEDC$_2Ob_ zou_g}cd6@`f_xqH%c}2JAlUurH}6UwZykZ0U!*qk~IMU#25IR#!$Zl9KBHnYi{ zz6<*Ho?p!P(mA|rQ#7T+Nczgqr8Xq*cwK9TJ>>aDRU>tdEb*^W@x2ULIl+}@XQ$drQJWmyF7&;O7G4&G;2*Frh0j zkaXs{`ReJ+S;CFN!tJ2{x!nH$c%!yL6HaCsLk`2-qQ}nq!QBo{xf-{{mTgX<JKbY zdm&f4hTTI)Ijr%$VJEiqJV^1aKkjeFe9sjwew)P-@={}QlK@jmwiPq>0i2ERzbOQ_ zYAbSG9A zcyJ}wK_t>1mA}O}XZyX#ld(InVa3NUV9SF!7+YTyCT^WjQ`r!7CIE^RUpF86UhZNC zICmEh!nrXuEE4oA*ZWUjwBA;?u&Z$>(Fow>HVGU3Da_+ur3elu@Fp+Q=(4>i zSh*#oVyBAs@n{KcSLhmJ|@0Sj7mHrS3+kjk5cVA{F~QD>5yV}nv~%v-j_C< zp`^I0u`kbNCv#YLtsQ-g^=@mH;W4fM&}zX$JqBGE#WcL7%Ist+-D_EXt2rnK1yE(( z@}85kf*PSWTJ7~wpaq|90j&5Q$)%v|?Cd>IdpPr&&r=>iW+adn}fEl8+U11Um{{zr?Zx^ns4veouFQ zN4VTx@3wEv#}ipkXP(unI44SCR!c3%i#x*ch}nZJykdkDfc>-Rwj62d;fwPv0c#=li={emaLn zwqd3F+ZwsKmrh`lEIZ91Au=IKAU4{VxeLP?vNkHa?cmIaWYD;|iQAmPx7fykr9<#G z1@aEtUS<-*E*i81G3d0-wx}oD;q5PEIMey%a*5yzlf6~gH&5u)IUvmc1b>8K8uK`1 zz`{Ma6;wmb=T^+4taq-m>$Vd1;vxp1Kmu#2rRGNZVSu>Bs{WmJ5jj5z#nqr?>WG^m zQxC7NX<6mRV^Igq{XJ85O|6O^5`{GIi3mVG(0w^yr3Y3k`O0@u8=-<&&|CqQ$43A6 zBGM{VR^d8a=rSQQ_OBW&YQ?j!3)WW(V~H{I!6Bsff&tn3_xwv4@9xVve4MJ(N3?ot zXiMJKkABxQC-KcLu}5|~yHK*=doq-%z54L^iD1aZOK{($sIRXjo>#Rjz*Qy1`HYD7 zB$|8vtnAYSrr`dslD6U6#m}-P0AOIc51w-E_!yXA8%sXihdVooO#EXtQrA(KC@->;_Vf&#A2T0nwRqhA81>H zjy}_ye5jcZ5}ga~0fm4WDb%$GEX{%5f$P`#zNabenP+C=17FX4%jr$zhp@^kzB&y{p}TJftEq?O_SPg}$$< z;V5E!c*J`v=A_uw-ie1Z9HC^I9*k78t&D_lHZ}r=S9U1{j*IVM0kd!lZ?JWzoBmB| zc1Z4B^FW7j#LypIyG-YN&K?#Nvg6;dS)UQryuXBTuN0Zew7^D&9Lrl$vjo3#W0D*? z^~_JrQ>BT{#(u;f`i^H>=yL`wNy@;Z3r4Xy`=YRuHXT9e_@L^Mklf|ed85sGn?Ii= z+;%_p_Sk;=I@c<4Ep_Jfx*01&JTTOC@||GH>UhgM)@UwR?0V6{G8*kO&`zYf6?l4< ze|m(pph2?0X4_H3^)DYh+}~43@39+fX>cLIZ3*aQ{gc60inbIa(9B{Mxt7lrMG6wI zDo_kUv`?Ep0y3hLiM~&cz--sWJGL`efmS{Ze+5%TgPca*3bnx1SW)Ver0S9l~$txXP<% zu{6my935MyHdK+=D~f}3m)8xKk1mgrWmztg`mX0Nm_2E!{@tJxrkJVPrYFYIwr$C> zUU$pwA*-s!Qz-g?cLRoEXN|^Ry>p5KS6q5d?p~lu8?CMG~LvwPbvmL>_LNUjhC9SjK8Idd@efsX3%YnT`N9}nTVC3U* zm(x-ThF!SEIQTX8@}R$GhsH!u97z@ib+wh&FuYzJGJD8a12T17^4VWNZf7B8Tox?E z6+(+^EVB+Oe}FA_4)~YEUXqDuB9tBjN|HoAFE&kWgPGFH_GH-Y>v+qpbg)6KGAHzX&MbF(`GcDvJ7V`n&qy z^fw^je+#h$%Lu@ARf2Mm`!^58Ek=V1^&285UF)CXPUHEMb=1Ibq_3$T{L_;kuUI9y zHMEfXTEiGPqa-HSE(vGb6_Wj4uF zL9F@XGZ;W8%54t3$@pI^009qq9x`0t(1>Gg<1VonTp5RSl?=l1O%=JK9BC_Z;3uQa87dlm$gcf@LoH=l5e=s5lcd9kqCUQ}Xd zkYoIiKGMisp4p9#aPQbGRy8w+P3~r<|Fl?$B?c*<)fya)RUizcu?;s3*k2+h-a=QG zTS&_i-}{k(G(5*ndhQsG=DToP@LAfM955$7Xfx({EGyT&FM?zq?FzBhyPWCwb@bku zk@ub#Khc`?NkntY>F2UL?Y*bVjNNHGcocZ@p@TYmgf7$9s!PO3OEgq56uyjo#ok8Bp!m!x1db$b5;wF z5%y|7l)-6dUm}5$6?BN0w-gNh0Z2fx-7=#j>m88LfW!Ta1FjtW$h0exfmclhUm<2Y zBQl$)DH87igs^Gr;CGd}MAvvVIc|_rvm+*q|9X=|LB`*oH6z1uPDu7`@YDb4bMUqP zPmVPe01kitr!e$YsO)ZA=0%F{aV|^DJ4}aIHPe_n?Q1N{2Su0cM^B;fh_~zY>V7O- zqQsncl_Tdg!f#FHDs<^irv8Y+gmrzo>*i z`9K4xJlW__3UY(80vzKud0U^~p74k$>QCm@CtfIqa8pT?wH`;c*WDqv-c~91Wffj2 z*q7l*H|2T-RZyz>f$dn11?(17ZWCHoac@9cpcWMr4_m}yrGU52{})y784p+2zJKqT zQKQTtN}>%CBt(zsMoA>169hq&AbRhNK0%OZNpwk+=)oX*bfQP^EqZUGKReg;|2_A8 zzu7Om@|nH%TIV{?)>)8=CEtO3q;*}%&wlHx zyA{G%L+cPnaBm|U=No=Dp7x6}-ZQwW{^#p`)$xEjM4du^6tw1XGc4k^moiJW7(#~Bg!`4)0%_+L;uZllh}#!eREO2;@;ep?a9y| z+s_j0Tb3{N-6m3M5ghz2EMamK*NoOMU$+c4q(BgJCOP2YYlj~j7NpDE~FN0ZB4V4 z;ogl)CViVSHe>928QEj=5lcCfH4s0&yNWl zRoC<*3MkuI^q)P5cpIIBPdG(Ko4L<7NZB#4n{zyyS)k*!^DR-Zzr;OkJmETPLBda*N{M&c@*KA{8 z4A1|S8lm8~Al-6t^(W2#b=7U4aq_zkRsT9T=U^_?q(C;8H}zMi#Ta)JmkWqU{${oy zjQGey%$dFZnLp)c@JF^IOZwKztk0k3!oeQ;lRFbbc(SCnP~-TH}-DI`x?+Z@Yq%Tu-8iFuX4&#~pruW&kx$J1BQIm-p_;S;s(RGpFTCtqPv>1&M z!f9<(JuGKA(YAk#MpC}RS9!<$ESYzGiidYsr|gNpEt6hEQqiY~G|vUo-A{R0rcW}q zj+2?9jV?!BcdOdiT_re3a;c?Wqku66=3#a-KG`Ygu#Dxk{(4)x#2C zEPWP0+eTX_d_z$2!w<>V4}|yK7k@|&gz~@cJZ#_JwVwV8!5<%wcE!o`R+N{VstPBJ z)$ML71HnJ3g#i+f2vcPMHeQHXK%om14SMqc=Jdbdyj%SUAC7c(prmkSCDMlP6Cfc}C*PyQV;a2#2@IZsx9w6(&_Ov_g~Rbm?b?CzPRVHi zJeurK0uw!C@MRS(=tB? z@A?Ow-*VyFeO{6QhcME{5r;b)F(ouVk?k1=HyKSp#ZQ)I`s@@2znVzC0KjC`XF5dFV^PX$= z%k}Su-8&|*UWEqHVQzksTjQ%!*da+h>E}GRZ!SmWy_#2Wk>8*%lA@xL%iHbS?Z>J5 zyFK&yQ!?df@``V`{F$|PGac8&&;8ocNeL!N9_JL_KK79&{m|M*!Y+e>P_M$}vrp?M zb$iMWTSJ(<&ARRRD^&IGao#=}j?nscg90uu0C1~)KMdTYLzC2Yo;3}Vu(HHJgp3~u zjZK=&5fQT9^3e!oLXO=60_~f7o}Uc9p}lB1R$OMu0 zP+UM@0quwqX=~lPOS)AH4Fk1aSD5-llUYT27cpwqQKlcUPmd0c+PZFBl5qK>ZSP|; z!l31r_fd0{?kD^h&W4boRxB}Y%=DAP>#4K`eA}kS#z6-FS(1~K5D`Yg)o~$MieA0%S1sza04zr1^B0*3d^_+Odtoc1)DsV&rsQH8nMI z=42oMSIWU|>X(~L6^P|-@VW0>Ui5OEe>8W>&Bn=SvLv=%TlII(^s(L0pY?kSDK>t* z-QsRuY07z-zFR?F)!4cjZLjk*=~sKz!rMtR-lY%43ys=r>kh_OaZ%z?EywkXCv@tI zO|Z7sg&_Fm+vCfzc>ap?Nq*~^qRp5YYdh7H<283{`Yuk}vSE}N?n$Os#YI8 z&-ioD)3&)ugu&@ilIQ|kP7=i`=9?b!aFiU3c8fa47!}ZtPt*tL1EZTsum)c+VZ>u= zH1x<83_!GdLKE>$QV}(Y@c^1vj!RO!n5hU=NPK4Ft_e@yTlw_V)k%tUB#QdMZ`D1i z;)|c4+eThGk~n8HW8s0OP;dhwrRtX2d^;0_SY9;WDyd~{AKBR=0=frv)#I9JqkCGb zmd@#QCu^a|bDD48*8*r&{=LD28lL|fv(om%!25c2mGkRc}akwdAn{@N6z7Wj89*D!jl=#fJ`vvVPv?B9@J}6wb=D^Ug^YO^X?h<6>P30yIa<+yWahl;*IxQGhw$be?3# zn6Qv7`yR;yv$z<~1B0DWu-^#EgR_G*pRY$7GDkDth#1?a)Ewv7E})}Luz4}?D$fna-P$~SS#qy+d>EUOYL{ox zIbqnoy>Xma!8_9H&tV>MRCD3&Q&z#%(PXX@nH9A% z(1YLa^uxIA+xDsY8~(DnqTpNL&eIpdjzh)bq1s7{GuZC~@=&Urb!~YRa4S7~V{quR z7>9m(7lh~l)_DPprQV~%-O$#;JJm{iuN1AWc_3G$(s5(rvn^G)-wFk~k2YJ=+wZS< z0+H;)pEW9iAtH*Jo_x-=;l5@R{G#*1c)X;H(mU>~vCPfvj&|L!k`pmei4>RGd`U*bd__l8IYl<>RI2owy+uX zl3AYA_a!guVy8`GDaCovknuB(`yDo*xp)*Rg=6|%P(l9D7Q$S7ADQuuld<}V;K##j z43CY2Kb00homHw zvvR7@(}Hzk0HW!19KH(m0Cj+Ey0))^>eT>~i)O`q$vZnT6Fj))q27G_Hv6m=u zOzJohePuRz>h8x%z~N3EI@J637VjbR?^3Y--=%Q$rft4YLHauUa-4VlH>Q=Jf=qge z2s5s$`rC~Q-fYO;gnzg}0XRmys`}%8*w#?+cGR%r&zaXkI~~fesxk-+D~@_ zM85ttOs79vwScDH$<5@rept-ikU5(~ff)Uc9|Q(wlA{`UlMA_IWvSGZ#TX%wz6NY~ zyy(S4CF`%?#K#y`mRWG~R!ML|9ATg(Mgkg$1?bBlJ|KT+G2-40C6@4A-Sq`Ds* zmz>xq>j28ijbBH@nod+*cbuo;#*#+}r;3mY>50U{(b=ZWIsJDN)MA&cwbpHlHA~po z-Kh$uw+{?l#a@*&E>Nj%h9CQz8~JSbEd?Wto2y5aTEuzHxW8eaV!zg|7mO!dEc=Wx zy(nz6?@8R`O{hntExV4kOlxfl*Ovy)boxuVffN|wi5m|n0$$_ND6PVp?uP!#j87&C5?C<1sUZ$wU5AqXe@72s=fKEe0=K@xNCoiv^r}Eyy^{4 zZ3l=NU}*eh50GQ7XYqV5jxCcS7|<8i;u&^i(?nyWQ+Gv9TSG^j!=W z$CxC|>hX2T8e7f0D=KOX{v30<=xn?Qe=rRatyM)?Fu@<5Z@-m(X3}|e)DpD&mBZa_ zG)l&s_h;wvYIEtzo@`Mpqu2&=T!_~C$?rXd6Hn|!;nz}Wj+d8)lA&0y>l^1=`$q-i zw(j|T!{Bf=zWzv{Aq>>@(Aw8t$e%ZK!P_oDm*Dw$~iJo-zP3yJj~Qo3%O=3rK1l^$%Nj`X z{uKsU{6`oRvGF${e5()vIG&}x3U2ZQnhWbzhG-(k(XUcJ)gNt`k+fodqFdIbF}K_H z!vl4-go0Ne`&x`gm1MpV(&uZ3$y{&# zIGJUAXA(Vj@SZ5YYZYM|d$eg7Rc$XtC4KnXN%q$|Bbd0W64_RNxYYY zgL*47=u-8tJ3{DJIR0-MvM)y>lZBE@o`u@IZY)we`-hp`+ZiY172eVn(pM>|+l|f)&t6PzHe1H!fcVL@FoLoJn5;!W7mqwh#Gz4?sS=!lx z8iVKDc=TP!Z8RZPu6oi?l%bVoa5o9Jt{EKENLaTOB-S0sEAgAHZZuHc$I^X7Wu?OE z5tqXrJ~G(qD*3#S@6RqI{pQT#JQPKH3ec`ud~P7xSlr(9E47JmlK=dM$QyAy5n=6%j%=K)c1(xW_-(e9i`gS9n3_oz8sH|xYbn0DRoJSL8&6hN>=|^UNUv}G0KK>BRxF$OD#6G=o0jA%Ti?Q>wKM5{lA`Lp>XnWPo4H_~ zKY$(E?`cG#`_hUtg>!A`&hDRWSOjuwjQJ~2IRJ5VmA>3wtr0iw4%WMYp3NfW>~(J; zYICO-eWDLVx!1Ri&=h%hbq1_%C^ZafX+PakUnQX2;d6EofKZ@`Y*GMl1%SR(adjoB zu}6*k$!PMvJ)6?HFBoG55w-U*0=A6^6_knM_;HHxnIGoqiNSe($C` z7e{lFz2uD<*Yy(IaGq*Q)lePCzc4*_2Nw#nk-7>|ilLLA3elNeWL_EIJoQ86HWqsf zwW~hDex3^H(>~(9%Nu`gH-56~ZJEv6)pcZE7hl`fppCV&_VySz5FtA%YTMN5Za%I9 z$7wC}s$j7Z)V}O9@mkg!SHQ%WcVuJUb$Tw}Dfvvy$bnykP}7e_?Wl z#{YJRW&eX?GBAOkcc|6~re3rY3VV*-|NT#gX6vPGxv}syVuuyqk_PB)4#9e`nQ~oz zCSL5lCh=WsHl;YW9@4`$oBFJzPGX4ih=PzLF&iyr-0@j_QC zvy(c2oIPq`QCCH!N%sTO+a*u)dL3b(O!`0gVst{tOb8!Dv;k7ZlKyN+CmxKo0lff} z08CL|=;rzR0h402tb1;1U?X)RE{<~M8^?403ZBRT)vb`27P=ls0P%7M_~_dHi#{I} z2Jd9BbVxGT5<9BfS^n{@gRq7;t%`|RG@4PMYp%=eY|{RvoC?O?yFV9qnaB{8sFV`n zES9;i3Nmkw{pP|ew{ue*RbC!!$`JYR)%j3MRb0h74%4MsHDXf|MrjvHi`~vP%WTax zk48FQiGNqQs82asatCRsqcUaglEh-?meNdOLgQcN1qG2ib3J($^StXHy4`CQA8%~w zD~mEqT{df-A6*bGjIq(acc473Au*~;Ihm(__W0dL`>z&09yaPG>-?tEQu%_j##Nh6 zl9uSrKG@6IwHoCh(t-v6*mRH=>u$#bl<=};lgH8h#;X;iGP zz>pQ!$SC|iK(rys0)qz6=BTjVMq&*CNd0Jg`?dV!{Rq3Ow^r4j@g}`&TIIWI}gYAnMeczdE z*YTFQue#D4S7NW7ef3U|sZ_Iknv>Dqg^2(rO!Z!K>GA;!FUYfx6Z>M4$K(-k_!n0& zM)DtUio+YpzH8`)B?lwbzfP?9RT^_Zw$0v&@=@WgK#LHB( zxNJeNlWmgaD9@qK#{dbq$lR09&&nZWU*;YH(`)!o`Q4pZX?h`#Ek737(%yPw=13w9 zj^l0uHy)!t($o%sUEl(8St6&|IV2nUGXYkD;%V{L@~d|!W`T|hTWkY6uM z#7>9|*1iQJl~ivfc5MKNym`#Z6tHS*5Z&&U?D_dBbNoq18W%7IYxbCd0 zkj>{Wk`ecFM1@J}231?1^AC$XcMU5X65l*n^kXY%eVXmIq8A-hp!Xb(LoGHnN=EEb zf!wk4r6kp6rT}=Gkc!{JwDmlEYlc~sk^S1N>>Kyzepl@h!W+EIyylV;wDRIS;ND3u z?_lMRy>Dqe@km!@&pf8lb~aa&lJO4_e-MIFM&L;Y2mK#Ofu|IR9}g8}bM;|7x6)Ap ziU1+T^uA_u5R-LI$r;)m0QlR5Ny=MN@nR94%)l+AlMJBoMU($D4(W7dbz<^s=P?+6 zhB4KQA|%&g0&@;{9Dd^Ail2%Nf(jp4xgfngN#ERG2ak|c-C}HQ;JmY_t#D%dfbFlf4;)>f|6%ZKnE%&_ z*UVID2?P@F7^@x_3SO7={)kWOatCPs!M>;{n)LzFf)_!w@*+3d`r9|K>#_-EA)sMh zibBNcF)##iOiWY`<}DF1gc7`%_l4Th*{KRBva(}b9t+&R1495z`L!CWhx&V@rVRna z0?1$)LedSmAP+OC3e6AS%B{c@`UW+uJp_2b^Q8DG$Z^$jLi zqVKOPUBZW`KObA4S7n-F>`hhs{Tr)VJO2M1q|}M zuVrPPwY3L>BLY_l&(>A}l*6^w7?pZsLB@r}b#fH<`q^b%Cd;3Ls{k~SCKq5blV?O# zbQxs|<5Ca_?`)7d(lG{ym9k{^ul;xZ+5QEUMkEZxLiZZ&FU}i9q))#!Dc| z@4Q%e`!Qedz!3bnb%car5%t%BK*Bz&z41@#0QnBrJZ(G|RHM-KMbi(V1sYK(xFM7X zO-UR=9bP_BCj-+rQ<{Y%ug|3)N>WfD&`<&UQsj8Rf_O$88x)YW7Hl9wM>2p!YGg3~ zmmHE|`#pS2|1B5-7rTJ>Scvoi3dX){AoN8(Z31cUI@WXhxH#oot{o$Y(E8B3+2u7+ zNoFqYG|?^0kRehYn~dO?ueK9Sb&^efrh-wGl)S-wx;Bl&1$mn~<|@2nH{2-S-+nl$ z?WSL@&|6tcC*8d}T@#U(cju&a$iv-y&zrMrVs~_+x23GMPq-k>bQ$-$W*!^}Y=Kbj zBs&pm(~G>_m5zQ~QxOv;+RycFs&qlDi;xOd)sMGeTF`HgfA)th8Gy6HhGpAH9t|(; z3G}I9WuoCc3PPg{|xUqQZt)1|v+y~3oKERUr+?h8CaUmslpNlxR|ugdTIbzFgT(!lk9^rSkVwC4YTZDs`y8cer6-F@aPu6Ef~VdL743q2X5 zf|E0op)3)p)&!YY9l?HrcyL6J8?YkDM<+=YLMDoQMQ~=4;mEL@j0(Q7GS*_Tq-tYC3i^R#NCI+u;3 zx#BKIGraadlyOrm32QpJe=dLCBT2Z>73S8uTPhZ8|M8(Nhsllkub3{CqfM7d4*Rq- z+mWo4nvJ}a7d*Y&DP6tW6$ajZzIn%vei2z;&g+ss60rmy5G|E2JmRm5o2OT(pbf*7HAYBt5i3~?AFKdvEP=1wMEJmv*WMkZtQlFz01Kvaa)nJfzzhdgX?ucNx4v@@UcS6GszvjN?@g+Sy^nQ!I{vC;0Hckb z!9x*60!ojq&8g*O8erTGx3WLeQu0d6OMJl)odi)L`z`CUJb8~mG+d<&b{U2!Guc~z zBQ(L&=?V}LiPMA{j5@HGIw@O8N(hVEOQP{zLGr97=v3I(U$t8H9#^OqsP`Xy_1a0) z&~twbN9Zn1@j$s}y{P=q!;t0ZYHRR8Yc_LAWYe{hnM>J0lBTSyh$7E-9=lA~nfJST zkz?Jq=c#+{-^b?n94CX0efbxcYd&C;CJxpvwi&2!4qZy;yuAKmhOTPfSL@{-b{W!s zOM@dVVa%h64Po5{*rxva?hvn1e1U?_2Slmv_G@3-II!m4T|3q(=4CQbMl%td{EG&c zsTw19rFZ=ZJl1@8t7Zk?e;X`GT69_&$xd9s&!i`y7iDtztD_V$mJ7y1;^EECKE}i? z$ao?#UqJqy(0I0~=#saC67W+C!nY%N2Ou)vrY0L!S$!~|P%gs;Wd%XhH)&P(o(zIc z&9(ifC8uVv)n}g%*`RExvrQ97GG$cwA;-a6TX_`N`gknmxs)_>v*>K;YIgWv)dX+D zzf!;4|C0L2!3~H23fQNHt$}yRI@DAP<0#dGmBol@`cFX95+3ry_T}-i-nzwnQT;|0ZrED26+@E|)=Sn(m^#YLj9aU#JbuTzt-!IZ zoT4rFhVz9zh9{?eLICS@8OUHeSefE~dbTcuyH_7eQmZ4y!_zBh=-(n+vl^wj3#*Wz z9of{yxV{Yu>y2zK@>~0X+ZzS5jq7b4o!SbEE z=?Xm%A|T?6wS*fv>foBqih8hI0snAaKOlecOD~^!xI*I{#p>W!Lu=3u^>?~sc+lj} z)mr#NPn-JLKV7i@=eq=zQvYYt@9KkDFkQZuxCoHOv`z2Rha)X7&999xv^>FU4p6@u zQlT1oQQPkfqgo5}aywR-DR}E)_yz>ve2ZF19@LI39S<)VfTj0+PMo6t>W98Y0;g~+XBD)$E z&qF9e4t9-zn%i0ehGRe1?4}de_zMI4ZLiS8z0>k; z=*{cHi2#u#lYk9`9bnQ34*C`2c#U9FLn{R92xxsd3iqR9(7&Fp?o)kgq8B|Jsm2E? z@E({wCgg>G*}Q>YEcT7}syJ$V!yYK5vf?73-wlVaH}AcUT)(NoL}dBbcK~9p4oiKO zuiaHZ{Z=i^!iyX3xJ0TbkG>~3qvk+5bbUd8&~tf<*WQUW7mXI(bv4Lgj#MBgW*HUuqrP5=q4z49;P3yeXsHOsCJaHD%hti#kJRn zxQLIR%36Tvaa@@#57+6};$6u;xg$-7?zF1ytg`3(rDN`za<-Pnl6JLD?9)f#Y`dp< z*$d;l4=?7vb6jpT_uy)I?^4d^GCD5)3bE-Z_NTGS`tzZ5@#j;TX!D=jOFJ))y!mU+ zswNYTo$o9j5BFS!oV|jGiT?Q@Y^Y2_mv#Y5hjCXx5$3tXjzS(1-B zFHh2HeK2IIar?`;AsyVHY7)!DRbx6*WnGvZmD(j`qe?#lA+`x(6f-OQ48wPzcQHWR z27T%Q6_OxfoH!cU4XPX@ z^P9vpnW*XC{>jPw`BWex%L33UE{B<{<5M7f*xHZ#*kUp`9=d+zEe~AmS$m)#ljNFW z;Ueo=XI@#5{rLgm%~{xYShhfnZ4)0@ zzOp%Y?Z#6nqnh3Re>~r8$izQm#%*wX^nX4)^RXJWidGV zgBD)%;U}b^z3?4E`71C=ndLLT)$pY2Y}D>sv$zkowCLilvC*$qwp2KfK_JR<4G643 zD2?7P7=pU}Suq1ME{QGo!N5|I*&gDu_E3%xRWkXZg6BRe z5)WmWCy0`RBEvo1_wZ+DiLPN^;zWA1zNgN}D`g(?t7&Q0~!B?Y!%LwKp=qCE2+(P*2H0LpXoYtJ2~v?fOsxcTo9=bbHeT zd_H6bvZId1bv8qIoxB``Y>A%^bg{F!DrZw63p*)UfJlb#B;D zHYYgv5rVu)0D}GQL{0jPNuQ2Yp4it2UZf61|D$S{JNl=;R{dZ7wU1uLwvPNgQYRjh z#}W02UjH?!q?!Z>Q6|G{w$6)mDt?zy@0Vv}VCBM}y-BNRfi27-8m_i!(fy$PQ_;2! z1OZ;JNF*Lez(Crc>pf4e1Zdynn!hSx5emoVb6Wk-OX;&ajn>b}u|Kve1F*bZ#c%_QF$^Ut1 zkATn1riYBEfFBbP(+RlcsEHQQzLl=&#XvNA_n>nDpc6?})byaHo$>5-Y>t z&kttatCIm#U^@wr_`>-nvP)n&`*Uwrw_Z^K`Azx{K(IDp4X^HuM3Qri=^7p|!s0@%sK&hf8z3x;2HwI|r% z8Qv_0R45CMx*Vf7g}cID8Na$rm)>u*X*DJFu{qC>oH`=>v}GlMie6>}3^IVZU072o8V;+u{MTz9^zU>iDqUFf}qcx4E!VaX1YICay%D2wd~qMI*l3O_Q%Q zDHMhoz$j(~-BKsAmt?nmT}M_$XO>yr!_9XL!iGCMS3oG%$GqY{OqF=+`2yXdrydbAH*f-UW< z(U?VO@9CMvtf^jmTgjwn`y33~;9NK1e8V=JAb*N0FK7}HS{cIJRnsdTCFO?mU7luS zJzxzsemwbk#Hby{;`;3wN~?t;!zFK!rT*Ms(aZlZ?|0TL zd%-o=C_Q1V^jW7D0v+*rFDqHDgu#%=yqUth1NIZmk)IrNSXyc(4xK^7>{<;%iUBizkd(|+0a$}Q$(}LQB|BhOt zmAd|b=u)Ty6s@cF)q_KpkEjKGa4cdAf?|^Ac|$1xibtoPKfHe;+@70T`b;^VzkLQl zsSr*r!Rm5kp_JKKo(#V^f-f8_knE&|A_x&k2B>~QiPnkIVaxJwM&g&0r%0je{7!A_#eNY?(*K2xf4p@f&utY44%JbG4@Enc%EUJvURPW56RjT}o07F0Y z&HO~|eZ~7mzJDgPK5C_vvgNII8!u!m$o*zLA$2gya<&3fR5@9}q$Ddz=nwZ!cI51f zcl57z-nvT0UP`2Pj32Bx(=!{EBvCu(V3Bn7_L7bISAzz=4L0u$ z+*ptJON--$(tR#*Dyg(m|KUQ~Xj9Ssg_2%Ez8%rB$akVK0K6E9-JT7Y0B$(=jvRV2 zy@I5k>f_VeP9p~Z<14ft>pNu)y`LFHo%_6$4>FCiXu8Xlps zYjk(6nu6hS(}a7{BJW)Giua_8z52d6NeV^1yCHIs@~m2{=p3#(RG=)YSerT?kvA-L z2aXY)wjiF8t}<&|_OvL>WRvv zeGgQbeu^?<#mY6*jJ z4HuWiugL;eT-Rm}d7-r5v7`6G3yzIHm*z_$mJ_rb3Y#!E!}qJR&tdr)W}hdDV){2l zXHOuGf+wZ%{+GS8A1@buLKnR~mKdpyLbt_U7EM0*`q_1&=b8_0K9|D#n8U8DO@&A; ze^tn*vH7euI`=u-I($n^z`~`}>5pme%*z$$6Wyskcd7jskEF5NriuIHcm`;A|LJyW zJ@Z@uboK%rS)y{m|7I#ktl$^OqYfn8Y4{WiMjVpQ7c6Z>J~u-m$`W_L}84ibFN`~SY@)j$9Jo`DWd z=CcO>^0TYi+I4Pw09mHVeQ>BdrzRS=6D_B~5|Mq2H7Eq9-a`K{o|bN;^gX1f&JP6Oa|E`4$30B*%lXL^p1u<)2zyQhUne zd&-+Zh{o{qiRNg(-}q({!zL^IbLoSO2Q)Q^gdGVa2EiE?A-7C`pYY=q&zWVX^tOck z)usj6nWbq9(iX2@A+w^=d;|Q>rcR7{9LrU{%NF`0U3=mS8o9QFp^2LhM|wx2>lfpu z?q_*82Ov~`9gk&4N00iS{Tw>l>5$QXzH6#zvDM=>cH(C_`7!>fe$ru&p9<~v8}Gb4 zH%p@?zw@5#&7~2QF=g-9!r#ndjXp7y)W_0xDQr(!s45?YB{LNaZJgR4SK*R>bcfIe zLF{+aknyNK)7u0SYkI z8pMnD5N!tVL|HkTo{SPmn|!Ue^ZiV+m)R#SGt&X;^%dPQJBl+*H}+73lzHcEaS$p@Ve zv-mQV4>j=RfZ(H6F!_av3EuuQ9!i?`lPqRmuKXbE$9>>u3Incjn3I5U;KSSE8wxkn zkOX9-aZTCXM>pC4JZX*^kGd}3y5^TlzC6zJBvOkKtp_P9_O1I>g-NzkuWamV$)6t` zlRwxfT>dqlXO~*M{liPI-nq*yOT8pox^hHnA6Mynfh$gt$bO%{Ts`GAm=uXidFr<< zc5Td2!D;pAyY-s2?`|5d&v0qC@|6_0Q;fOBy2UHz@qA3NQ;rdyqsck4c@uKOl}!GtnABm2WM(PcG+bJcdF z8kKd;F-l=IaE*J9-^LI*$>8-8?$W_-I#38eR%S(6?!umoDgVN|!9@Ok0X(bcjp6-7 zUxg3ZKFg^P{d@qOMHYqMr54awW7N|)Bx%Wd<8@sH0?3fu$%{tI22j1ufOtLvWJSpG zxR=qg_mhu^lIf_Cb3)h}6Ym6Z-p-+_tnwaVS#lynNn2UaDu3|jMy*XeGpkxF%;ceu zle9zkeFj}+t?0V7u#KgwIbXl(n(n_F370Xx)8AS)=06qfENedM0r1n_dBP*jX;F>c&q5tT41S2Q&+6|S--Ee}qa}r~O`KL*_fXt);g~3=ZCRNnIrhicdh0~) z=G4G8)~&9v+BLdhbN19Ip{?_!r`V+LVE)CdE!ELCw(NpjMfSF3@6D$ik2ZRva-089 z1qlmo&kPpW+VL#K%-CMci2{h;D9;rK*J*|I<@AVOhm%?R1}Y?$DEjqg(*}G$?Zq;l zCK_Kom-CBNU$ep@o?F@gb91XWOWZgfCT5k){J~chx<{PV9a^jLPwsZuewoP|NZJAJ z64j%N$&A&No#hcf=FA_wGwp>CIY}SfjaJxx484i)f&g>B^ZLS_@c1M&flN=Iu}?F< zyF|YZ6W{*qBLNYIK7C2I=&S3ixZUw;??ujioz17&w2{J3LBi0BcRtBQ4W<77E|JLp zSt1RB5Tr+PYRUWf)RB&sodP^$rF28n}q(}(MLQ!Dh0t!z6hrV7nc2w2A(q{U{yv`iMA#6W9Lyk3SZg4;xK$!d6c*AfFeS8ZTJ6J0Q9^hPVd8- z^PR5wShmXCYV~L&-CqNR_qx>&=DRgQ%_wFj&drAS&DzH}6=9?fx z1IEbapH`q2z~L(L8KD@^P(hEjOtT$di>-l)gWj=}@^T8P-fXnqLXVrNjm9nf67uW&n-5% zwsBW})JfaHJ%p(I*NFy+w9RHlx66VJh06LlZY+JjOc);4 zrDznMVWO$O*?RA7Vw{fe?OV2NI`R)s+8((`aAZ|1KpLr; z>fT8P6Wb0S!+xJkE50<7GaT1rs>JNXbWcZO#MPj8DUzk$NMP+eV*1thDDCBGetXmU zitjgN1?)sCjjs-(N{Y>U7QJH?X5MGoHW}(=s>BCp z6Ga=P-l+^FgQd^(txjbwWYdC9riQOGCAH`OI!xIXlj`ZbA^sS#-T)p@RQrA1ot{@! ze@j4Pq$#;Cj}257=iv?D8IrJ&%H3n3@({waCk!Ay*p0(LFv!96cQk$*kP?vnLZ`0y z)P9Y{_Awy;l>}(iuhjc6tg1e!s_=|7oV9l?`?h>giK6wl0eosc&e{%-XxaOm`AVzG zmNmo00o362Yaw^X$5J_`?2ps3c+4TX9g3hT&yo}_Wsv377M#E(`OjNyv@&~tetYLi zPUY`{a#+g$rF zK5}e*Dxqxe*U>2R+dTQaIkKTOND+PBJ8-;N04>vgfiw<~N^k0X-+&Lm4k33EvL8SZ zzifxaMOoV!Apm5jPp#4!OfzdEVI7T6iTR*$PhyLzXX$e1B~oS-1H0?M@5V$`)B9?Y zw}EiETAh%PZgAPj*QFLd({_-d%Y0*T|3#bUb164pDKVA8FdweHsi*(;5nwzKd@p{L(Oddxaa@CBy2`` zI4q)MLWlN!F2ajN*ppdyi!ciVk3rluv3#uYgpl#uH$A2APk4Uv$VL%P-SoWw*5@;P zrb);I>|MhhdoXpo@k>WwuYWtquI{7lq_OSLHaXrAg)nEOk$U&@(@|*2gvGU_E z_ZtgCYF?!bfTmxtTaFKy2-Kf(z)sDVeDj!hbj?M8d3vom93-n;HO18@9kh0MfV8ve zurqcs78bjRqZWJr?5TqyKe1O!7k|stN~ks%e!x=-7yakS?LX;*&ogk%25ena2f@tG zqI?;VND|cb&~WV^8NJsoHUh_zIkY^dV-6#~ud1y00C-FY0I`Th=WxS;xdn3I-Kg~# z#w=x7Pvd#dO>7&tUD3geCc?6juvd@)a9wExy5y6udL!3=fYPwqhDAMeB1=^RTsrPPT1nXNt$E zRqa=_XNgMY$6pFh3>(JH1g6G2 zvDdZAYuN9;nSXHeVWuC#qh5@B?yGuvw4e+d=VoKA4#;LGzk5a&*pSCLM97Vt_QUZf zDrQHOHAk9O|9s|aChmIOy01RznL-Bq{7AB5j8;KQAdk9`=WQ>}w{~g#II%?TZn%1J zsT&LIOD}m>^IF@Rk+*X;@9do*k=(1p$7JZ@C7yg z6C9B&sPuS(L7 z?m4?!A}gD7uO!)kp^J{Buj)PkNsntuwxiLxS5!Jayzw16qZHMM*N7(xEbzO(t&{}Z z-v(-)TCeiECF;+=1>lp3&$){zO7_4*FoEH&)CR&dwdy0WeE_k}kf~l*WWWTp-#Y9=5QX$E)p5`!b~GXk3+ z*N<77AQ#|wD9cVi4=^7`?#QXyiMj#OK{ecTP7VQGhI0q8NblM zu57Nek%~zzz)RVpqadNoayIl8?|HN1z(hhv_^gcK=Jx@frz_{kpjWcN<#RXErGAEF z8253_=m7)CyF-zegQoe|?*r1-b6?5}KCuQx9k|Pij=6>$g6P%Y+VRxz+Ooq(f;_X_ z+>Ec-eThnL=*lZkiEnnc>z>Vcmu z_`0R=e&^j`M_2C*k2no}uCxkFPuH87Q!n>Wr z-ZEo$Mt)8Gdy9AIKz)~aOz+1&NObS=iypSS=Ap6mczs5Fw{fS1GsoxGv+6;s$(!40 zhZxOgD1P!IgMpDro8CV0gm-&bJcJIU^^2Vfmn4jr);grzk{i}$^4-akubbp4g>~cC z9-`3+$)b}CAX?q9j{q}Hb8yPU#3jkJ7l=tp$_$cv#~s{xhl}4yfbW;0u1GF`x|O<& zSA(*8q#?S&3YVucmAUP81hbX6->E17FcMD~;NQ{$IF{q#f&h8E0x&g}v)|{=1f$VAPoM5l%#*4vk-@Q=% z;Zd%*+xoeB@u6|6T>#1%AZ5gZv(dzWsla|Fj;R4 z$IJKRWbQ6zB{qtWPL6YkZZAj)dTKVEmMmExtk?aHjyk3%=RaWLc$|Z|O z-bu%Yxx|)jO28L9r3Y>*2x;j8QSh)lktRG*RysF?pG*Ov62crmT$zCV28LGi0iZ-& zJ;>bS_&#=xKq;C~48rWM`mVS#w8w$$3omNACv&V zX_{HDbqF)@P!bSS93h5d0YEXQh=*%`|FlD%I$K=${J~T6Uq@QKdX0wK-D+s@f}RNN zQGBbzvL8Tnf8AP!=b@&*CsNzLLR9Lx&`Eps_(I5DkeJ)Zh>M`>K3$joJgMRT0Pgkw z58$rp;@^o7zYdc<3Y9bUP0v#uDq$@gZ@zk|E7PXnJkY0`I%o6Y&hHNKlx5m+cz}G$ z>jR^Wca%-SoIsdOiaG{HgT3cw_zOf^R__sN{+T=v{PDWBVUY_U{FQ2?OD6(aC)_L^ z*ot=_0}Cst2V^|qI)9`jS(!1!s2BYw%}fgCW=({v&Bwy8=67>tJs$_`r`~hm0_z{U zL|t{8HS6W~KzNpH6#+oPGbkD}Ds&>auUFiJpX}-7Lko`cG-KGGHW}1dBsV%B1i9FUJLvq6xYXQ5Pb*eD^|>;ZoXz?J zubR62039I$WR9wZNKNL4 z@Kg|z778i~k7Az?bivPl^KXcCk>p#juG}9>(pUYNdHao$-3}0Gn=HAwWZGQgYv4M& zOZ%0sNO=OSsw%v3@B~7<}^OqgFR|5o5n@z0Qu^}PH&o&WW#g#YVT**yNQes%djKBTu?LEs-B0&wCve3J02Ifuracd>(Xr)9v2Pa;{tq<%N%SWuR zq74Iw!4zWGCrdWfSxPtKRm3IZKt%5dLY>ynRcDfC+7h-p>Ub2( zC9XEpn5LY?ju5~3JYRh4B>I0s*qe-_dS$NL{a4 zva3VR@U(A+CKtn^X;WVF?Fep+v#@StN(fd8+nvvJFSMWviO zSkj%Qyd1(9T5MdcoeUd41@syoY`OCozNQX*`iS(1Za8CqBGSp(d-2)=UK z31SU80nuIKGK88hfF~5i3L$g^2ud}&AtW-ciVqmRzJU3O0erv{;MM{iRm**?4BeFo zi@3}kGM?o`IOX(!!L?=lG)!FtkbeyYh-DB6kTYibNXkIs|9J0_4B|{16ZVq6{!SZ`yZ$wP1{>=6Nq{;Z<^5Et7HNW^< ze<&~oadDF;`?zno3+&wYW=>_(`~OcBGWuT?+WP-hp>NJW*jf{r!ioeLbi9e%3kif_ zZ5wndwekPaxp<-}Ta|oP->|m7e0+a2=v#{&JUXjDTW4a(*~J6SF;%rreo+9Yg6%De7K(h@*Mdk_u<~ot)eDZd8;dqjwTb`$Z-eWvfD^IpNM_&Ed zusf_7d$wQkBTFO?-{y=T%x7FX zQ^qSwlyZuww^;34kI{D*Js&mVHt%0x*gS`@^`0#;Qop=e((k4>(i~leDYHpeCTyNp zH2JbxgvpUzE}N$8Fr3RK`zF1<`a(J z*n%23bKPh{86F}8!0S&aAku8l4b5f5SQ`m?ko zn?ymM)-siOYTe5(W~5K|YMp)~5KyYCdR`R2ZmWuyv|RP3+y!VG(!aiYUgCQt%}>EY zkf3y1a)m#kw27`7>{s321}*>hX>q#rU%P7Yw=n>_`|}JVd(wDav#rh~gV1sq7{NnO zo_;pYn2m!f<(1mE@0O>&ey*6kP0>y0?(t0=fWOI8z4&re@VYJr57qjyMSvOz`{3{t z&B4GOMahwpJ_&8NvxZT9sn zE3bI~^DyqVm%dJWTvTP=O#@~PSJIn`ThIAUrOb#(DSPW>JPnSnAk&o=(JoWgVP9#i zX!RM;=JZ*7qIKU|`bfw(D=0v3E3N3pR*B=YEwuH%kJmss%T2SY-RGPapUVWzrI62a z0^EMFK8@+4(QXi{eiJz6vu3xZw52Z|+x45~P)Y>%VoS!aUzrgtW|R5d z%hg!yhP&h`TT)80BKG77gklcpCW}G`D(vaTq|YvY$v+&%L*4pg z{AN%i-A}QJF`-@`@q#8*h=}9V&r%-gm~Ao!j^8Na{86P=Bks*+GCWpYAtP>DfA-So zokMyd01n&3Q=cURorHc~{)VXr@2ZMrs6|ZATT3g_ug>M74bgJup@D=Y-@wK?I(8>? zps#OYcc~M*5XDyYC2y=W(E>w&+zZG zL9eHg$F5gv-f}iiL-eGjOuj4fb=TgAWPc?93{~XY$U^x>E2{BT6XMD}>Zld@;_#)| zfU-g??tPLP_3NQOqvDTZ#CY<}sMSSp%;zR$@lPJzMkqls#Nvb9o_c5mz>!JZ3L2LU zK!VQ#KL{1`PSt}bS7Gh?^PU2r^^NK?p$TTaT8A3UyL{eNHx7wG(8z^540?2oUCdJ@ zZVyi_MtzGDp1QX5>Zt^>(tov3h z2Ml!-9qK#fKf1(*kJ0X%`wqSj0jjZrz~xw%H)Ybt0}~YI>04a}YPFBkT1#gg(48_; z_w8ZXtd*Vp2b0_%sk4!{6q+h(N?1E_za`ClCXT~%uxX_PCU<*G zar+zQ$KxP$i1t zmd(UJ9Sn8gDqH^4nUTJn>?dBl&avbCNHS`zy~y!Nld;--YuE{?@9#Iu$(Ioic4;K@ z=oJZPr|WyEIZi=%U5o#WwESeLc|Syy?+q?aw^EUzXPZ8-{4f< zlv)QCxw6t}P3kbukgEuu%SOU6wH-j>ys zW|h-#Wt1sisxf^L`}z~KswN3!JEURFE?)=kc8i?9zuF7YTVSl{p()c5GE$?@nOIS% znNH5EFnX%0cNzGCFW;{(mCx&C`>Xfq)Ij5^; z=Pj3{h})sne`qk1ND61IKqyU;BKV@mu$(8JP5#G1&Ggi=JLX&vk+u+^xL= zx6S~#N&hh#Lx$`;M-V&Qc+@bk!)?H+29;Y%PzTU`l~(T8I#%7|8C-%{|+ge z6son%x^rs=7hD~5vBU;5gZa({RbwYAI&X7xV;#-U-z~807V0rj~AI={$-hy zg(7_FZBE0%ih0QYjRpw2^tY_EmlapaZ9@aZ{XPK9?mV@pxVV+on(_V3$js2OQD2bE z*B)nuZKwjVM}?Wub}6;y+be_5K*kD1P#Ok-F8~Cv8$BLL#imjsK&Swtwfqi>9^fFp zSDG5ez%KA%KdKp_elHQy5M z#0p1RU4|!L#J=(qH0dy84J^nA=9oAU605=W$6UbabkPMoWs4Y0l(fsc>U2KKl^GCU z_Kdob(x}wQ$MPecQ1ed7!FXA%{CUf+?9r)m2et@Cj`i4nktKJn{tT#L9-sP_@mg>K zt89m@gQ-p4X^pMYKV{+T!3}niwVP2Z;F+f0>=- zV9mL|`KV6VUKwjtaLJp&SCo5>964;g6~`SnJRM9sB(gW9x@r^TDftl4`)<-*ej3fb zjQfKNdyP|6WF?=W36Qb^CIC!MEWmW)BSO>Xm2%3!WgBGR_l?2LE*X3b82SZ~KN0{W zJ}Gk=-|!J?Wd6#5zsEEBcSaors_P?*Ns@RtpT%I*O_;#=-?&-%Y$vUH@Nz?rRgNtA z=YLwD|20L9soB5dTFrFbDQNH}Ui^nU`)qL9u@7X+I{pN$87wkv#MORF zOmp@FAy@m(_9__%?GFMDhLo|#A=k&wmX_>(3iViLSF${YPt@1R)2v{HI~wI3>(QTA zH41HC(wK_T2yS`Za6D&(-FoXSclNE&g862__;*g0++1M}VDYnG{Ve`_O|AM}`-zwO z$3_DdcG|&4F+B9WS;r%WXnBF~pRBj5rx}6fAs>Tu1z9Z)n3D5t#wZh5%JLJ%WRe_2 zDL%<9b`;dFnM}bCSdk{F2JRW?c}oGrf_unfflsW|XPjoTUrnHz^1s2|606~E|a&GLXgo zEWu8kt9J*wopUz{22T6;M^DQdd&|F0HNj zoR6#W|5P%fl`@SB$GZ(9o;}gUHczkzoNVRFx>pVKtM@k73t~2>GMkcqTHqHm&+1om zo~L2g{nS?6)3-En@a&TcHrTdQ~#|ml0ace+va(3M1GP6=iCo^{r&RGVQz1gz1w(MbgR_|#dwA$&v zeqv(e)chgBZ>}~-{TyCjgZm->27p`Ke9!tN5>Eb>92M>oTCK>F{}#y#;LpO~K?*{| z!x(^fe%E?XNKgB1hfX8lgb;YqmJBIK2=co;|JW2Sz^AM4bTXd;#4MkGrvskq**_@l zBTL%VjEkh;xu^J6$42l}&`uTsuuux|*>Ltupf>tO3}+2VzGv$RMtc@|m5rmXD+h=% zelLBE6p}#>PEub#BJHzV>>)0!ZtgUM0BG^v7yGXf{PnL9G(DS!&~4Y@kmt2I=aLyY zIj~uZpL_g2tEp~kJfhlv z$|eOvDT=>7k0Cl`RWOKJ1*H6li7tSL#dFG=B)3#d-;SlPEQe~oUU3?oIA~RV9Y>bE z-@q0hl->JALvGcPVT;4Fb_%r`WTRe|y>DN)(K7dh%`k27FaVt~GI7P=8X7wLdB($Q z$YoZ}Rs7YRN$(H6K8H2bcPRy@+-kqfeF>8lEtR24#y>i!MP|I=<@b_L+?5)Pdfcr> z&s0d3t0=P~yOyrX?i+ShUGGRRj#6n~PxxC$F6I=k8uyylVYz_vl|I&#WFR1Q&zPCD zcp~)W%aNYp51-buhbpp+izjzC3S~dr{R}axWyxj+Ul#LX_k_kbZQWJ-LAYST#%pd%a)Y`dcTf~!m5w~j&-Ul@LtdPfUjL9it{wO;-msx|L?gFLvso;Dlj8hO zKGGJ49thZKhO$+T{k%Fnx+}MT*Jrsa(c9mc0{n}jk#1X1&Q{C+jT5jc==3kKUMuPo z3^+(p?ZF?q+w)K+Z|~LR=v}iR^Q$DaT_yAO8+?-`<+17{@Wo68kxYR7S;}Vf7l?}G z67o-3-Q08RrK5FKPoGNTd!-c5{G?O?JXvlHrh0jI9^o z1_;Ul!F8Em44N-Q=ur<(I9Wuoc@2CId3t#fc+bi1TFg$?O1Ytl1t}+P*?foxf#f?( z9`)=A!@l_2>{Pvp?(#0lXqV14QSR0AH|U*9kAOIKBQI-@?9G~AYiANJWubv}%L5Ij zm;0L(NB6NeS04LK+|fQ-Y#sLLT@QP8TzK!32d|EIfzU>2>xJq7ua33V1G@XPJ#PMG z1L_0Ele}n}%r`n5y%h~&cQv+&buFlf;es3Z#5;AU?Gd*wGIJQqQJ_HGf#R;ApOOCk zP<^fJmrv36ru}RtTpCqT-Lr&smVV78*rJG1MoSus2*#%_l_l^bLQ|R^+|G&C&zX>C zZH*V}(KGww315ltbzcmfT*P|F-;_Zgtl*xvdYAjhRmgZrejo0j-TWbYuy}efOcU}3 zXT4P;cqb)ww~?9}-qlPHa$_t41~bLI;1GV;xmJ{+Rb5E`W@nm`eOY066o5A1HN+SH z&RKO%K&w8-6Aqu_oDoe~(IGn-y{JO^;;gX;oHNrn}g0D-uN04Z`G^Cv1>CV zweDo-YS`l{l?4Ok(-OHJnd+DPccd%30tC1q)Y!AauQK&EMf>&1)F6tSGV3%@98&Ft zkMLjKnDX68@Q}N}`SdvkkN2%OH{yY3ERP?|H>o>lPM%xPxLVfOioZUI$9^vjat7GU zuJ+PPD;oK9P7gMQk4CE-N3K>~uR6^3lG4YeXLl*VVpn z5CVN|OUCeLfaa5)^Z6$YH-VSh`$1w~6JAGRLT6R^Z$;gbx5-JmI{aH0`S_Tw>#1&b zq#O9-kq1e?AY8AQZ-qQCY_~lOw3Ssfyhyq-YWR4bP5bR=eHtSxBVse>=ikLDJLH;w zyQqpfl={`@T#B^K_tbO8%L#Rd&VvhsIU}~gYgvio<4?q;b42uIa*#4zE4$;~rZ?g* z8?TUfCV&o#-jG}+cS&rgN#V61{pBGD@e&^w#ev+oV#q}f_Fzq}ZaWHh=Wu+sF==rvz&lsfw{j9V zo}P9jL@P=h#yw@X4tmu8TNe*e@>t>aV8+)p3Jv0X0f61k>=eIXC!?N2@B6}93)%b6 z;mO?$QRZ|(j`8*;zeV1DOw-*Z8YwJ{q>SX}`>CSPlSUaGu9c9W)q;wn@4p3JKICet zlWp2zV<2*WSgRz`6%L)JgFv2yH?}RmB+g^t6$1?9Ylx1CEt!%^Q`ZT!Q6B@S(#o=H z=#@<=&w?#!t&u)x{}=anFlFZJ{&TSFai8l#%F)EqtjxdYRJ&4kLnVd|QQ4E?L@*ZOk6dt4s=rk)@a?D9Be%u<<^IZwT zu>L`ddvzExt+#m&Dsiu#zR3~K#(A333nO|x8Ta1|tjHW>r8{$x0t8YC(-bW-9al)t zCOow254EntT;svWY1kK%mfKy|Kx+3=uKSM3AR1ixp?I%r3B?BI0O;1c1u=@cbaxOv zu4%fF$x5m4PttGKbf^%a`2++fnv4dc7t4g<$BOX^e8jDoonRjNl!E&}ZaJqUD_Hcx zX@7nA%kX$_U&ZVh?UVkr*eX_0^UkF((TwE~d;W{}GwXrJ6|8U#-iyUQEie0c zQaC?{7R8Sr|GXA*-dVpdB8~g+@X`JFH-0~6lhGCG!tK-!1oh8etcp)*Tpwt--Za`x ztFCXqnulE>QWgkQ{1(W}YZiFBICpniVN;4QSy~U(8oX4ot1H~~ENjTcE~qGY(EXpH zz`!{48h6LFDr)Xp+0hzDdzimO&%&VNpeXj21M5u2*J(JeU8x|)bQb$FF$7%KBM*K*!qP=Z_l@C6pS$$zXCLDT^jCB78YqIs6-`8dL^|MsY7kz`r#~ z=rpT%Zi&W!Ov%2B=LUh}<9#I7711&w(0I9r@lg4=1e)i5KT6cx=!L2k{UJ$T=g*`O z>g(q}&0Q|k58ps>@_nCR>ptdVT=|Bu%zy4ULZs{CS(#lQuSaiGB{lSIb+=!3QdOze zc@yVTs+Mx0NxWLPnuzqO-&fO z*dggtYQuss)=~4I^e#?s)+J81)9A}l*UPk^6@GJ(Jrt;HkZ{ZW06z>~5)bCoX7LZ!WP90SG51IPi^h-h=99I#wq#)r%V82nq4KM9uo?hRjp{K z-ao^E3fH7lw8oIguqT0v>KLWhtRNdr1-E2_*sEN}4_nP;WuMpSw_CLKQPhr)SFq>9 z*f<4!kuqn7l7Xv3&(7BsWuoU|-Nl1W;GVysFzf41=ksx6uf+Cnil4r}j*+6-mUZEm zU$>mc`Xe@;E|zm(7I$WUn120BtIzv(Wnt(YwbAIZWZs;`<=IGTl%*k}tz3a4ZKEh6 zDypp<6TU*X62B9yXz@;v`k{J}QTrFqZrn+P>`eQc)WDzpu5WQ3{(G~zaj%vGCaCfT zj%CqMuM?T2Q!};%F$MQM0u=poq}igmS>zo+T(x64I6+R0p*Kpoo6GKCm>YPJdzVSC1e&BN1wP5$cS)Z<&*Mw= zaXXF%{=i3enewD8Iw5(OHa7DbYMvLtHW+{JWhyz`qkVSqD&}2Q(XQF1=9iGB4wr^L zvqE@WXsfRMNQtxU)Y%lnhxB9zO3A@)xKg)FBKVxJY}*~@PO z(iv(Jf%Tjh3EMIq!OgKNH%q$Tu;G)hORCI!iVa7>MC3*(o?I?f1wA2ynU=iyv0Lu3 zU7_6~cziKRoyWxGaW5q)jUb8jS+ixuow4k1)MuUNc_VM&6gH5p_iSEm@%0>y-{tzA z`7YJ<$OlS?v+D0%n#r}=ij1KRgDB5M;zj7sK(=~(;ViBpzGNVjZH|B2(+YZL2R-g7 zO)1f|WW`!SUQ)Du6C*E|$ADQ+JPy?_B(`LD1BWO`08{`GS_0rffskxIzIA##2H%o_ z-9~ZW2M-1JcuHLZnDib;_Ru#(kbI}uWe`6U%No^q8y9=Ft$q4lKvt-s+jE;1W1r}IiQ#u^npoA6Ld{n^b_8e2*Bf#b(ji#qQQ!H57OfK*kDmX`GC}_h zy}|gz%_muNr!40eYjUnX4<7%tkvr6uJB(8hKCu87=RKJFqE{4QLKAXi7m^l$I(fhR zUk_RQwP`{qWGMTj2sX|ca>nf_Bjx0Ee3**u(!9U?!p(9N|Glt5Wp;xtzgge5N;^6K6VnX#e26k_jls(aAFREd_0VDn*eW;BL4<0(k)uce&t}<|rp$NAK*CfT} z@lG&$$gcm2(gLHUmaL32f@i_G z`QDAyf@=+Io9&k_qHC5=Q|`6%hVsB=k@ho@IcJW6e))l%(hPwrn&X*K_0_ih{i_bT zdFgk-?qid$j94i1?F4g%6TME#iZdGLXc*hZ56i0(Ap~x<0;AXt-vDH~qJX~$&^*4v z+7rQSdeU~c#@T30U!S%KO7CbwOL^1J|DI6YBXuY$%IbNNCP|Wp!og$R(C^b);cJR2 zs$Yzw0;#(0#UH!P`Mp=BjDf$*Q{IQeFfWBQ2DEe;*wECzQ{k2|Ec8}l={kxZ7|IV#u{#3YNFM}?^Qo>di zVFa3T4{esI*xLmFkvaQ`^Z|i_p56}J9>w+I;gKos%-03^pV|cUS-Dm9DGK{6&@bH0 zIeh{UAukqSQMXb6c=fFRrc^IL;!fSOSyBhCMQS|uR#T{TvAWk)WT6zB==)DIm%B+S zb2=wylXGew{JCiBa~ZPEWh@E@kn$Y4x<=C|o!$`{_!Y z=1<-$ukTimZJnL9S)OG(!nDZ+zrPPp#bclqx7PphX-DP(7KSke9Hpy%>D`O9*-=q4kAHS{D6O(~TvDCJVu3|8aMwogm7f z6ZMig$&{Lp*8HQsU#4Pl3*~*22GtpTeXA?rW-+%WpGDyiq{7^q1o6CQrbfnhLf6Y? z*)p0<|2eGaYMA}offbL*ebX|v)i6LfYBVRRY|5u`Q1}2snP(3-Cq@uA zps-U5;gwd!2;bQgNNGlZRHoet1S#nekf2VhI|%98gBOy z6MpnDiazZXGd%R*&^XJqZ1u~Juj*273ad4UFbMWfYWx~)yk<$=f51@9QTA)`f0sj5 zHC`sdPnnb#0I`U$-f5v4w2Vdl}@6SV~Sf{pb?_AUUcGvejU!Ydu1>`iL9NHa?cG z)95|;a%m&yq)mF*{Ck9|-Cel}R$5>vU)G^KCR;n25Ku~53q6QyZN_iuq*!@nAH8PH zXX(bp4e=1?sr>{*S1suQHcW-ltUcT1>5CXI%RDUs33aR66Ae*uA6&(eE;RokZWG^*6MZO{wsXAS85|Qn-Mr>lB0EW%623 zCYnuRo6}R$SwML}$a{>+e_F`2>Hga(BkPYWAyon)cZ0?}C9&$!Ym*z9S>qv9fA=oe znN0?+wLOI68^YmCj^7=1O2P)x`GGqRlmFs)?J`e*J)(X zJc*!TXXFaawf^0VALed|i@?Vdu+PgD0P+T7I?4!*v(^G~lJ?{@FW-NHljm1qqDGdo z(#;?HJlni;em3zSHP!U`@(zrGkrnQHdFFa{@ISj|AfX2=dghD;@$}X#%x1s-@@NXH z=IXCqV|DH)7kDR8s{->$@HbGh} z@N4RqB5a&~_JZsmYq0sJ3LiCps;Sv=^!rti9>a#y_y82OBmN`8r_|#cw;9>359A@q zG;l~FG%P98PXVuRQ& z3{wzj6fzi&8X(F_6A5Rnckwv+bvU{vF|gL|;k8&6xk+3rry&4_H>me3$~_j0J+x*y zxe^cwkq)YEV4309je5Zc-O8foUUJ!=HuC_?T(KG>>&>Oe`9K@P3=G*i$A|b#$$owDgBd~V#Lol?@_#EGv*FPkKufPfOA+4|N4TB6 zI>1INwy;1a0zeQ?A2qP}Wint8?zZ={I;{bFf`RCRUH_=GVpp9NF*<(;7DSa&tJ@Nq zi0Z*;&$Gi-7X3pqIw?wPCl@Hj9%)iB!rM@)CpEV(0JwYPS;qGX7=QheU`eig5GMY! zy(D1fXXO<>l+$JQe|BB41A&3$5k;19hYv?p)nIFSu}DAFrBW_A*Z%>GiSW_m9s4g$;Ow9>=z^lO zVgtHuUd`}fBuVrt9b4cb|L>B^H22@e9BzRJq5>fJB(b&vp$d=}AVx_PX_9k$$UiBC zkY5D-)G(u_lyJM)CgYMc&C$x_uNsR|6)>!-)egV|AT$~g@AxbL@)N&>xQ@wu8c}%$ zA`1lo;dSI2ADzp;hSws7P!M?Mjh4s2qKOwQSZS$sRL49~eunBkV>=3h#6~kT2rJHo z_W3AR+K=@_D#8y}@5n`)?<|qA$i;m~#-6_!bp}1aWCPtRU2kZTd&P$1y*GTp50m}| zAJr@J%GtRuuk1b`qnx-YC#KjB0$DG&WYXT%sauD_4IS;kvnJ0BoN{SrJwJAXR3qQ0Hbng@AIv|9+X(P z6vVbENQr06qh)w#eMH{((Bt!m#mIZ+6+^?vQ)WLBU!ARt5l7pz@5IpxA0u`2~ zN%l#4$<-5Fj8Kx`C`}+G%8^3;|mW!!A;2ItwUnFf82sJ%ngGK{F%^td|Ejh@13>wbt=stu~v7d zL%oD)--M0j%(#2{b+t>pyHClc(@2Ns;jXOCD?9 zjcbCBysU$8p0FHcGF#hqfZ$}m&9fd??k6~~0N@1)E?2JGt(1_7HeHdj7=;LWGsX~t zTKVU@7#_DrvIh>fH*o}?BX`KVAMrg>A&J4#BgzlEWqp`kfi5ek`j_UzdfG-L;==tpEmE&T~( z|2AH(oPkuzTmHepy4B?CGxLeI2J>O|x7r?2M+a$Jgy+y=m4XLR&a&eAV_Y8uW`4WM zE(cm%XKU?hj`WD%`gG6I8y|6dQc+6OSD>_~$tED@^)NHu_}C5F=3j3>7>VyD7GKIJ zzs}eF-CmLbMMyI_saVKEy1QvoC}lNEdiSL6Pca`k9`dyE0%2(s_i`b0q5xaqlYZzj zXdQgHBL@(nti)=~0R%HYoXmm)L-CgeOgSQ+Sf@l;ynya_$3uwyn0weM{{++c{R;In z5G1Z;TQp?_+e+KY5)jzy{7mBk70gub_4FQAUOoPh!?h~xdw0)n>PaxN@mgkWIH0Qg z6wjwoL;8F}!qWBr7;L{xcN_YX$@Y-(oQxw_eS71_Y$vNs$-y$)>e>EjMaS7r((%LJ zKh#{UC2=$-GSGbVwd92?r=F4sP^NdNFswKds{^wLeR)HG9gqh&6GVtGn%hufB+2b8 zfSFknC~m3?&4g(Z4{P$OD@cA*Tzpi2J2MjY#p^RDN=;Jx%e^3*plQBPtOxc@)Q&11 z7N*>)u7yO2Fz?u^u*Edz928mfUwWO!@t=Rn9=&<6#Kh9XfhO$A*t+6!95}60nY$tf z{6x1^-AL(d$!C!|L@c^Qr8V8*(8a+1w~i- zSel#@C4hbBr2!~nRsBSbiN>8~BvN+%FrZo>#26@spa(65c&Y&*PIkaL0ssnlq$IC% z5Z0;mBOZL{HsO{DOkw*H-8>6lxfx+m-ZE?weibpk!l8|9qoBB0hFFq5{Fif^e@X8v z>&9d2#{ED_ANI8mb#KYD-Y|ID9uB!rM7E=NJd88?G%BmE#H$uAzuK8(Yf)BpFH{Y? z9CI(s_Yu~DnUWW=zfNa=Ov8>0ksj&oKU~Oyauhdzl*z0@n7xuk#!dm)FUu!$SeM(r zva|XJ?5oQ8=S|66WdPnqVjTp`;m456YZQXm20ekTxM>w55Sl3*2T#juOkWVN7B_tz zOM!~(zH3$Rt+?7_FN?+0tZX4t)&C(0X9b`X&}iGI?6?;>=qt(h^(`TSxIv<##;Vus z5vS?gKfr(H=VhJP)dju5gY3rtwC|u*S<}Ol2c)eG$)Hcn zGG5hVTlY~vl}`r=1W8+p^Sez;8~jv<;o+lfxL%5<+tV2CUakyXml)l+mnsj@s$3HK zEvq@!C;mW~X^3gahoas`Z#2nOhPrw@AJWCO!O*(cLaqE-dL`?!vA<+Nj%IK_xniFB zmIlN$J_p0FoHtq)?y^t`szB{VKR4kkNXXE>kf8#>4d457(`VWM$%fvNSf! z4dza#65FfFgF-rrvBk5`mbPTJO#9n>xlgaWf?saca8A^Bevy(gbk+1S6rb9NsoC_* z9Wd!SFS}E{CNufaQTvZwNXq_-f)m<+|K3@mZX?!E6?7?o@L9*!TtJ>V1*mruN$4?ae zkmL_osRJDOa9=K9pmpx+-9j&_)LE0PrxKBSa>a>zE5%9ImPR=C!>q>7J@j_zQdOC9 zftRn2dXBFn4%uXui~rwAWb~ixZZvIiZyoF=G~SjreJaP327^IXTk^^eO|rJj z@Au=u<@0~qr>;@lwv}o(NUXzl%j=dV*b%l>p$lbm?+DSY#Y%M?jc4wm{(LueiNIIT z9IiquNu98B>%eR!r9f})P0VV+q0t+J;ZcyO$#a)H(OLz(;LpUFO z+^}b$Vt^xtsDGRS!WeE&0LU-dqU8zP6zulsnuz;IW~rr?A_L*iUqA^wwF2brT9KJ{ zFS-}rAnz0q-%0NkUIc(UKPq>IV$b7d(xxo8()Slf4;#Fh-reXr4R-xs&oHXS_IAp% z!NoHXT$Mcc1UlhEY%C(H9Z>c)cH~%5`e>tbDQkLoZRl@!UGwYSvbU*qacPP{Y4e^U zFiDv83nc(SNi1Q)*;!#mr#2vV8@bIj(&hTKh8-*^*uCqe9IHK>J`#q138Gg9`q6TequeER(e|C?5(#zu`4b-MYpX z1MZ*w7R~(iX=`<-nfFI%*QKVOUqpy_*X5wT#0T9+!DF%hA>>n%mmj-pJ|4Gcw@6;L z1qvMsoT(qG+9{nP3c?mr2odiz{uHCb(k&JIWQ)~_Q%;_>Wzgf;%_2)uD`$Y{(4vy6 zR-tQT-4cCBU*TD+@qfV*38g3>5FO>H<48Y|00RJaVLT!AU;scF-N|3+;8xa)pH~_z z1=38KtW!1WIiD-eeDb%0HcWP9{>a-JVDx^1@DnMcRgp<^*k7l zuPXxFvVxH)+-ouT8*V9)&}`B~3KeAt&?N_AZ-m@hrUF0$8*Z;i%=`o#Mv()ntkldm zFYdm_L7)2_)DA`jJ^@F>6;JyGMZ?<7TA3l$E8A>2yI*$d;z3@E0qvgk@tKbN9_N*H z52iiEvDi`1T9)9bu@G~U@_oJhsXNZ@Iepvh?6Rjw@1H_d*62W?GxhbFlH9EX$64j4;nvuyH9D7JWm>{VZ zK-se*@Yv}v44MpNnbDm9d>X0|ZMNd?#sO`^?8y>n(E!9~<&&CeV|rldi%B9u1w>^h z;K`6x(qkWRv6BYnM7W;lX|OGXSV54I?iqoVS7Y6B2!KHlw1N>&fJ|lfk+0hOk$wQa z5EI@Le7TjBRpjlI7c|Ti{M+op80WG7KY#xzsgnWdo*VymD_x8eeAL(h?vNkgqyA23 zAhTO3dwLz*1#pz#GZ37-DIKPbL4sbVLCAWdcEk4r1|`rxxkT8klN1j71w@))!m zJc!cN^1*OAou*uYH4mez^J*pQ*uc=RMk#rxpld$bwlQc8{Z|6)FZl@L2)-; z(5nAN-+e#<7N1Tq|6nI!a2)p{onRRtSi{>Y-iwmQHgAAoW`u^y!V=@>3)y|LK_>?e zlD8y4NEuy;Oh;w}&DdkZ)~wHvpW8pj5A8Nq_b2mc{=$K4z_D`? z`~oxnBg?6u5V=CT|3!IFhHLow5zbQTw4C_1IT6&~=6 zTXO-$SV2Z((>M6i@*~Y(ZnKIcs^O`j?+xBFo@7N9k=vGZK2_eEvx&9c8%rd*qEmD^ z8UN?KrH6q4U;?Y%0E`3$*_}(luLaDsXt~)vS*UUhRY+TLqC($$( z9g7#1{JqKWpS`_V#y35LQl|vc+$h_SSb~6UO7SKA>kCp>>+= z$6!uZneZttaAfhU(-G0n3B>xM^reI-AIo-aSN>4cYAabW21Q! zaB?jw*L&kL+druxB0)?$mtH%gw{e-3Q8MG0B>83L!+sN@&-q4+;ryeg5rgw)t%PZ) zJ4zBqZ2qyY<583+O$X~|Z$yyKJ+Jx?4)>CdOH)w9sC%Hw&j}R@@NKFcSRRz7cnGs* z%4E|f%mqQ47LA}Q4dPZ)uTBGBt$k=>P|x4y5~?%%)4YmQ9rDhmicxvM=7C%P@bTxb z^F`g5Hz`R5(QUL?}3OKUm@B8A%ht-%GLO3Ry zU>ZhOv~m;#YZbtaopJdD3C*vg1V~c>y4e}XJh*)X$N9*Pw3AYxmKW3B-u1}M z>_-x@<4rw5wInDl7%+So3bH#znuHdH5{ch z+Wn3aFWrM&DT-G#HriDFRk)>N-*N+O1*HH)Y|$CC36v5jIu&igrKEfED*ku~3%G`A zm8sDyh%R>6L13v-WQ#^f4=iEEt~-}A|8NHa1da_TdHwYh=>;%V$$=>f1PYI8)25(Q zC{$R4LL>o_?dKIB?mSyD8!^FW5!*QWLKb33vFhUWEX z66$rX?k#m{Wb{H|hg%`jYe1_4XHUVZ$elny$Z&2*Zk&neI>;dC$9u5+9(;nu`)h-> z84!cLhae!37ck!?Qye%InGPfqw#TOHArth88iZPJxdHBjq^DW8b3HQ8-i#bYF)pbv{*(X5A8d-|{JYMGI&K`=LfyEZCs@u4d+P&mwFRB4+f87lDZoY>k{{om4iZiwLZU}HL&0Grd_08 zx<~APJQuvGHo2+VNAueAdv~AnRwpcg`#`mX3C1(#CF$q2-HAI+``%&V02GG-avuck zM7RK@*Zj8+fCRFr7wM}i-@^G?1*<6}`id!ylvE`T31B`^!Pkx(jSs{JNS@c@&?9h7 z52>6z{}Alq!lkOP+&+~O3Q6JO_;>)c=5YsRdYqNglUP}l;<@n>j`Pt*SrA( z1b;GdCI_fYu5%Y6T5&;epC>!bNwMA9E6%5s&Plr9kPm2k{ljt;2}Z!MTIg``+x%4N zUTJogvcZINdLSN3)<(8;-)!%Si$p?&9Dd0^eAN3R84Re8|MH1S`7y^w1z#NY9*ZZh zSWV3yD}g^&Q6#M2^>9RR>kjCMH${y6ae_7y!TDzhlr(%I-^6l(0sZ%Q^!Eve1!{mL z>xp!892Hc+p~+|kH1`fXZo_|Z4fxM))Jiy4S82~q3^Be7yeE2?YUf;;9Ch!=bDp?w zlX3NsDcDi)$U>?jN!iU{{7iZMoh|ZbN7}yDMx^iJoQt_9iTUB6SersDsRQ9W3Gr95Uzu=j z79mV-dZ!pFm@0_wd-<<1qj1+Z-B8$FOC-umxt}0(54*SKiBR{+*PKH;JrE9tHYbl~ zTkyC;CAe$5#jkdrxnzO?!jEc&x3OZ=xU^1ed^pq8KX7gRDVY&%~Mgp`iQ3ga)QG;#z+nG}7}@QhcU8x;6&0RMWH!9tEvhe?c66 z6$Hg88?=mef^t}L7HoJl^aWrd#7pTtL<4-jH}SR|>7)gJ{NudLYv3@PVMOM66ml9d=ZZ6lR?|NBzoy=W?P#|L0{Bw$v>lw^o!@b z^;s!Q-}P~BeCtFs*zIn*;cjNPp*n;7clp>5O zs2uJsC%OV^<`lgQ@67};a9#Swoun~xzqBO3u1yu|ZpVZ-IK3NwV0{#Gz$gMHZ#Dce zI5d{^$zBCX~O?!G}MR-{`X19|Ji_iEI{Z|53xrf=NK+pZ!J~+ z?c&r*{L#m64TXeCH-bc`ijq?v|EG`%DD1;0j(2z6t^Tdfhj327zf(p(Bt<)Jk@XRq zEJNsrXMW#yP}H=LP?COHrTxCoHQ#?~cD_g8+qR~ly^RV@!rsW>M@4sjI9m!*f!UGt z*yXLLo(A}0qcZcK-;)pe`o(Jb!boYuyF~=v0o+dy;=_hlC4?Af)@sCmf$$kaHw)PR zwqOb!$oz)F!d+Y8GB|Mml4GkIr=%Ux49}8=U~3%m3CEPp8sAy3jOysYV~X^_T;l3` zzJk(wtLNvAvo#F6?<(%zT4&Jj>hJG&YqjF*aEdgiiN5_4REN~ehDYr?&Rw(p@}%@T z!oozN6kJ}*2z)kv*Y;E&(xIaVpHuz{;N>z!mk<(8066QN=|n`bnXRe!04CQuDvYKk zb$y5r>f7TWch2XcI2fJGy!JWA2TrunL5{eNz;S*sx5?*F7$i ze%y5%{Th_h)t9~B{e~s*{ps6Y2lOXJd*v|`BmU2eHwQhEH%dPvcOC!ye=hUcUFyc9 zi`A1?QlQUNa)#1Bb+Gt8d3;dreiJ`;%Q!!J>Ilxi6+Z-;?+C~2B~ueWK3w%z#BKk@ zTFU6s^=u`1flo{K z;ssOCi$$XnY6pOUCAx%A9?qL%ezujAzE!lnJ?%VLd2w62aPGI@S!wXj+|~T;`N=HV z;nBSJK=p?2{`!@H&&Dkn&w79X#)>78-DU=16h2s|v~IDEu&`+Ha?3{QZ#J3Ixl-v) zI32V0++W+%?JV5zBJweazDl2QNNG&Bn0x5FAQHYc$Xsl!E&S}3mCHFwC*(59BxTU_ zILU0?Mta`oiA8VJ%8RQK!-0~UmGMuD53ujw5|M|x8|wDfE73AO|b0{)%m+E@=KyA zN51*4hK$q?`tA9im8n*X1 zum@Mly_MW37%<@#ag9hh0QzJ6rI)&f^0^Yj_Gg`_IYoIZxjalbn7!y_{-#OFFjI;Q zOG+I&Oo&bHrk+h^1Ob5gh8`Q`Rj>F#64Hx*TwrP^&waC?%uXCDna# zHG6@eCW%<7P`7@uju+=X8Ll;woKJG5AnREf4G+wCS|XubUS+;Zp@|bhZ2gJVYTy3q zYQBHtl^&8RY&Rz-8FRVGD>(gp&AYH@>t{xb`bEl?8D+e zl->6QGS2rTQY&Dm4h<_FI=??Naoh4%tn6Lp68X$X<`2>ttkI}RZaWko2W(L2`{7Iy zr8iv0w}1VTzJCsPm4D+Tf$e~ge!r2%qG^fAtGCcJ#ZL?O2l)@YMABC-hBA-w9D{$GfhK*rLCVIlF_@&t6OVX#YWxO=eGL5l%R z!w8a^urc#c^K#y53$fst@576J{f$4maJN+r-L6}^wI{KlLSLv?lUnre@$Gx``t5b2 zLRiWiNsV&O**17yzKinIL&;Hz^HCEw-@yZO<)>EPR0A#Ep{Jr}$fLY%zK#3W?2f#G zP6&VN`cRpnt+?P&fy7VfPyj-08YJ&c&@i`B= zZXTq-ny@-~4{-QGhs6giJkL9UAYmcsL(|X-SNPayT&1a*8>Z!)rrPle zgcq#PJe_t;`_Z*juz0AogCzg7ux*~#FwS*`14P@DkLofF-Yps`=c zbJ16q51;eU-_4B~{f#s*;d-Yn<4%wIjW)i5OPuY~=JfT8T$@LLj7Dr|DQ(nf9;wO` z&aH-l1j%vbxsjp^U;Cm(V=xO92z$kr!AinbmUYcbNzS{(TStnbV%!GzsrChP1zT-( zFkm^9t|!ZENWUQ6ok#>2@pBJLHs}ZEm-(B0q#Sx};fh_|Oa1ms@-5T8!p$4^%)O>V zQB_Fj%K3&Xdm+%avuF64`}nGn5bs?2ZZQlt%t{Fa%b3sz`laI{_kJy8Zd%Kekpe=ULHA>l9|?d3;AXO| z)%3%q`1?o*=xUAjvYDwvQ@cKshriX~?-ihAc*`Rc;Qz#HN3*7mq)U_AwVKh(8CE~Qml?@DOq>_j%#$`Vn)Ph$ z-&T1gxck}6zQ6G>v_Mdubu!^>6g0~1^qvzJM=KH_CKIP;yaP|(^xOhQx+C6=E43pb zw-*^4Q?7d8xj@YMv$$tsz~t=uciw;}WMbns*Wm4(e{n9o$00i-1T{npbx8CsPXCp5 zEG5Uip02IvO!+qE4c{9)gv}}C2y)GTa;uU7bzec+Q@@(=C0*kp6?XT-kf|w4R=QR= zO-q?*1G8YAcq-RA{L#M5zvpGclM6ZuSGxGp^S0~vzUx4~BV+?`c>-4MhuEgeT@2i! z=zcK;+}qTxOj>}@cs|LL*&zmAtcHeCNOpCKZtfV1Zv;d$kif#7S%Y|?_f&9k00{TJ zi_m#{4!YNe;_8Dpi`EMR|#%Iw}4 z>Jl^18h1~P$5HJv7o4f=^0?G{efZ91mP48!rF*9S0<|{?M~}-g>0{Q8U&;tY{q{b?Z!HW3~^Bz6uhGU=8}nmpoJK zVv53yG)x4D^vI#qYA|Y{XlaJW#^!WTC!r}eXF@GNM2HzYwh?d5 zbOd>xit%bp)W+c8uPRScLh&5h#C$P;m)piuf=jHj#9oxtg1)y38~vc8Vxz&_&Njnw z?4z}8T32O+_O4pRyR8=ndXJAxZ@*5!Hu|N*eAzWnV=a9Y=S zCt9|r%b2!4zinn9I>3FE9#yZ6_K}pWJ(_0sQd>BeaMKFh#L-zF4T(3B!IZ_0Wq z%ngSJ5+0a8kdrY|HqD2A$;g6WC2!&sP|L!zjo;6%H}yN|<2#`zwO$wDywP?eyF3E* z_;X6Mnl<0C_So_bzDuauecVkRMb?ki{-B?xjJhcH=Ua=`it(<_a*mNZFlO?j-OB|{ z9lk2PX9IZa9HWUO=emFAY6ETQktzCbFxw*b2)2*i&qg**lAqsBXWR~-%*y$;a;mNU z`QJ(Cx;qJGwnP?j%^@ya;S{aoC7FLR=S}^bB5~jOBj8S)M)eLFXp8 zrnYQ&dInP*@SMB(!$cyDIUu+$>)c%Hda&_F5L|}p5U+v<_yHrSLm(>f5iqU$A^a(B zASkW;dUK9XnH<$LIPjK-?vEJt&%nkWt-(zkkw_^mCouPgA)5vK@N;5ze!s*(ex97d zgmgd%c^6j;s1NTGKtm6 z4bL~=g!K+y^2*F)Rzl(n^ED zpc#kUv0ev_oWSzjwb|qlqWh6M<@#vhTDvSG&6e6_Zi&+~vykjz7PN7#_$Nrn3gs0B z)x%yZ5!JEQ24<$0guw7~F;$WkBa277^V*tb0Sq&zbmC*g?yKH#C#DVSXvxNJBgwK4 zn^rd4nvoHI28C$_Bxdc_#G;j=5mm3@V~?YzefAq*Ak3bvK$TMB>Jy@VDWpWFzgKMW>1+MJcU zu^ha~v0<)Mp0W@wZbjPD+I>!Nj?N-bHL5)ji$n!azNy*NihRj%i;04hwwy&jm1V`? z5T2IOpB9GTcovxR1CU52W_u_%B2;| zR#tTc7KF5i*wNsUi6S}VGRyb1k0Vl;zE7|s)T8S0hKYAwf!x;FM}yJ*$1PITbGC|& zH}#SyAf=yg1dAd9Dk5gWxw`8;>e9 z?~5?gb$qqj$MtAC-c3QCu7-W*Knrjp>%h%cNK9i4V*Lbvw2{qyxB0DWFUcmRC-2K63QV zeKz9jmt%Br${P*3P!>u85$TfeL0`bHv1vWi>B6aIjD3O+>xMf*VNQhIz5M;xxp&8wpT-z1@$x z_CHxx`6oNw=<2+#dugsYK9WoB>uF=T$0ZWrXrOX>GeGhXf-B%1=4^EsNc*EQpq0s8 z6=>`I`ROGrQEI_OiGn; zS1~>fi~s$jdj8Bz3;RF}(H`h$DX4fS6!CX~VnLS&KSp63>&62{5R%_g(RmQ8v(DKilNLr_)9jEQV4DufdeCNHHs&! zGnL{xASqMI8JR?^tx+^5pVBytjb5Amu5@dB^m<{%az;$Y&~#Ni!uQ;BU!!#@y?tMR z;jebaVEvUEfrQ(uO@Z@J;aAqavBMi5I1O~iTpLUH3U@BME*by~H0A$H#y7DfwQS zoR>E~XfSZwWI_4MzboW-y{=B_CFkQjPVv!N7SUPVUAw3kTP=H1I-(~xIrhPWyI+^e z_Xyr|{(=A|7ngfp9qw2XK6#_f3_q_=$_rhT$axfN4zlbO8}~YIgCh$tV{_H>S8AYl z?+9H9*y#(Im>LxS#YfZ`-148+Y!mtH?l%3~o!Q#>_cueK@j}MWZoNGL><+V)>XisM zx~GIO*&%Qq<@;@g2?UlY8S-L0=czXLw2VfiI3=}nG~@G4 z9c9UZw5K@wJA*NU%f`Q18wX=NXPaZ6lBeviD%9z_tvzhpdv2%Siz9xM5 zyc|8foakF7x3C}P8~Yq+0c6V9P79*?ta-uC@>qA!pdPzvv%#M4=MC8x<(6as#-~(< zu_*f>f2X zd@q%s)Pf?Rm#>giKuLkUIJmJhZT!|W(h!gZDYtA%ns1M1ysgyWvY;yg_${x~^Nlxa zTjjXTeqq4o|4iQj_XB~u%dKLXpyC*Tk4Rm9r}5_QLw;4cPP%=DT*moVnXq(F&ND>+<&e0VKUV0en@4Gci|1;*f zP%v+oRCmH^-u9M3o4r3;T#ckP<-MbN@&a}Ux-goBIl%(^ zPrQGOP$R~XN>_3V=3NYg;MlB|An-#?$)EHmbsn8U=$-BdQ3)SOi!mNkb;pybuUq0? zc2t`mLEDOBgy;*fLV?XiUM{gzYlFxDL;|H)7jkEqkGS+n#dQGAb82L9m`%K0cNn^l zcZ$Nac*(k4*=E>3vUEZ|>UPI9a^n5hPJIHJGaJTtVS_fu)yTy4vbs~ZZ;83|9d6fi zQ6eRbyVZ;v+1C5BP89Bp)vk-)bZZCTNPk5$mexb|1tazS&e=wgN!Zh^Tq^aEu=CgP zegqDlbtIrT%l(g00TKuZ5(Nb?bePGVk_I%C=R>-+oSN{+uJzo$xXSk77=;4uX!3qa z$=>pLd{mp6m|}{i{nkS{_s0HjS6=6=5kZ}mJwINaSnxk zP+NWjj*wSLws2(VCV_{Co02CRrYvGb=DAn>uY?zsL?a$i-rty*8H2fL>TEpNfRbF? zPu#OpDuU;+zdSat1}R?WlfdQ-#ti2CYoJjL;9d4O5%+#&0b+xvbN^~QKuj=8nmi4e|~C?mO;}GCdx0WfX3ktn0(|TTB_M9JQZ`Z?A}rTlsz^gf#3` zdv>u{$*`pcT0V7Zlxarhl$_=Dq7OwPlc9|D+Q*9Nr#@U2co_}*ja`)9%&j{wI=Opo zCVXZN6FuvyH0vq2#HJBC)7A7^6mVb>0x(A{StUoQz&%5Bax7ge+EOBaUo z$sZMe;>pLL!^qjg2z#QD3y}pp!nv^$f=I)@Z0G%}j+60)6Wx)kh8@Wv-;%?)|B<@~ zmUo<>+)VL=UML-B8@#5;Va?hx67-D}@Xo`#izd9e8ljjiGrwS$gjym|s zFBTQa^AeKX5jje6=)G;_h+*EuO;=muxyV#YKIgBB={WfYFFSO6IfyXu z;P!hy<)lnT55|*XR5!?6l{*qQ}*=Q3LD)aI~8HD{W z0?u!gW%pp6CKf@_j0L5Lc=s)YKD4?4u5*keEL-;^j4IqOSxe^oqKJ%5+}TNc+Fen& zW1W_J<+U`$bbjih{8M<8i>SucviEXyIzkKsQ*GgxCCZmh7pE2*Ru}9lzV2H|gj{`a zyabWB5RWxI3BLyz9h4T;_%nm~kI~J+8czdBjNdDZ{3xc0G(Z$T1$e*N|Ma96vOW?y zAgo|NSR&t@JV74v0L!oU8!{)5{W}snAlMO#;k_PHrul(sf#1o&0Sx;{Y%kmasB}c%p8hCMhlIYF>UjV4j$8c&jUq&pRZz^=$E+7ES_;G}IT;xAnG5&NPymj${gT?t z9E05v2BeF>@)tH7Woc@1O{uPyQ$l@J;iY;Gq8wE#Dfv|H^Bi7uhH>IhQm+mTg@IOj ztS1X^e<-u%u}b^rP_U-1PC`wcgu%JKgjl4ugE~vpsZU0f$);t88K+-b4h72=2J31L zm7;vo8jW)qM#_`QZjyU8@DQg}nOY5oUIVd6-etMIfnr218=WmCfPLa=d|x#{gECh4 zsrFz^cwrQgkB&if=Ps7A%aV*iLdousw3^%O#lssL*6R|6f~tnghF~7^{IGP_oTduq zeXEm09}D4;VTBA=`mPl*^k+)*HJ_90qN{X%B|CZ@GW2XznPbtV2O8!)v|=v0IaGO1 zA06###bgFhR&4N=S|x)f0I&KV^YpR_4X;vRcd}OL_1HSeGP7roC7P16=Jh%KfEDsI zPOiQ0a>{-*-HTAx9`K2UBA-ig7L;|fJvDfp-SaOw(!R`4jj)QzQ#5b69qhrkxoHJr zZIOXoUv-n`h5{NcHM}Oo#|G#)f~}QmOW5@nG9JZ|z%*kCFaYM);OXA`ak3u9QSYg{ zK*=?UI1OuNuN?8PK>I#~gzi+ItowSIPb}cgDehm1FQm<%={7vozgXnnN-Wp>AE_=k z>yM8RTMn)t-v&!@;VsnlrUq(XM_+Y*Ihaq=LHsr0->~G%T5fmZwz|d=izqZ}ak@s|AChq-AhNep(xZ;8zn#qSPJJzby11(9Q|J?W5 zsVCx|5FIuHo3OjfSc|~WiijGl%!A$(LLnob9L47_JekM2QE$D$Pt`1zP3Y$Gw*`-h z+FpO(NjUS1vkbsoHn;$g0+e4$%m88rsQkS%FahSWU(4wi21$(?V|26gXBD>Gu~a$| z-d&c=@t9`p$Ciip#1=ht-^ta{@#7>xa7Bd=hRf;b&;Rs0Q>BE^ahzr=$F`nNIxkm! zR=z!@+ZpykTATDvm>5jyYzWn_nB~?x%_9FPDVzndN1g|cSP{9|5D?GYOx8$>{~j(6 zuXZ~dKCX)GV;qtP-ZxW$F%uzU03$v+1 zO#Jwh{J-zc!O@lh07m3KnOEwi#kYSCzO&Bgvt&&>oV%XAXDPD=QHU*C&Mkl1VfHcz zD#ty3|8Quq$&>{ove*2V8;G#2WVlBU&{87THtw>4tg$(|;2hpqlCibR3q2XNPw2b{ zI)Eihczbgrm+ahETYaYeZsh(cdYx(dDVhE8y~Av-I1HXR9Kd0h?%A|5NWxmz-nyQG zM`o4&;$r{Wc5C|f_zQ7edw=`o$%e;^mcXB;lUS@YtXc0r!Rs zp-;m*I|_~F&i7Z^PuDQ1AWQ%1#UyeTpp>UwbNyIBRg3_R_$MXMG`?SKJ!Zt#!sqR? zXjI`unqkc(PgM+=d8(P?6S0l0ttC(?V)dQ+~OO|Du!E~`i zC`4%1=J%WsXJgY7F1Z(A1Z{OvmZT$t*kJjro7j`JH7aCrC*90>`INHF=XPN=jcL&V zb9g2~O8<6?{}uf5L3S*jNYOa{nSMKhpsHpi$6?-vzJJzO2S83RQdkGtkl5 zPrdMbTf=jKO3Ixte01P%N-Zm46NLNm`CQaS9AnGdCsmk;kG&k=A}1oT+r`{lvXIej zu~lpBX~@L+=wcY^3t@yEU7mlW;mgm=EH5%h;cBw7>VV1ngzf)EbL_wFfILPDn)Q_i zu&l6WSFZ2AxZ{n(K#AHUp>F;hR_QLsDX%E}O_A!xoghfdDm8)>Wi)}~CAL9@9W@R< zAMobYEMKzO*X;!CDLzzyzo9tIz_&Er9?FtAJ<3jYqF6B}4BX8rq{<2|kC+0$Z1`oo zFB57XZ+)Kv#8wX?go&if!>7QxCY?npi70FM=i>mXmCYF{)8m3LI+Jh%x|P+|5KET^Go{T`aC6*qX}d-0?wX@$}ogRMe6 zc$A?7QRlrC!_Y`;@gd$Rm9CwtLXP`I4`s7jxTBj_^3JrcIsfk5ev$RIO5kC4zuz zLr9o!P=@i-aklt5Z6VI>K(gWV@r7{(Lzyo+b^I!g4}AAaC3Cgz+5Ai$^;*1a_(k+) ziGLl-(*~(?92w&r)gg~#(|jzFzWewg4RBj%r6H764*Nii*TkA%`lK~KrEmQCA@?)I>_~@VC~!g8>qhBLx9+Nt)pIJ6nxTk zc760K_nK(^a3n4zug|xX{I2{|+$By4_!s^P>KiC&YB(~8Z0qaDJKKY+GVOb5+VaQt zIxn3q*09Rhgq47n-}XhE>pn0Bdjdq2ph34BZ(j~6Hr}IdGDyT~)!&DzNnc}Qw*__; zLrPO0#Hhu@j+FWNu|Z*U?I?hhC{B+Q)+~F}wVSXz`wYA@PhSFuY+Xg%WOjo0%5K6KY~t{mtY^6-B(wA>}?UKo*mF(=?o5oxGDq zxjQ;%hq4BCcGDAXVl%UunBB0})fa-p>n2%ILB6q$*PfBHT0y=``R4fG`{i$!sFlMh z8;OkH<$u5H>TFBc_{PfVR=hc)Iu00m)Hyt_<}~H}C^vM^{i@N4}UTj~kjpPOwE@-hb2Pk?a0yf6^$cNUEAK z6nqN=GpzKD+Jdv`U&@L{d&G>SNW_!~(dAlHo{flwK4|;=sWwpHV%xxuHgAvTZG9%V z2;FYie`tFsLD_J9bmEBKvYy$BRP>5>xlHjhM|a8HCIj(X`?s%)Bkz=V-$fl_?nbK$ zZdxZ&TjxIOU`K2A#m2V~4_!juwac`TF>WTS7CL;qck zkIn|uMMTRjSCHz)M8~jDKR7}}M~}FCKeUZoW!(>y!ZkaO+gO6rO-z0Pll@p&g29jt zZDl=R)IZW_R#O6(%&2Le>usCmjef_BqH0pc#_3Cyww81Nrx!wSc3;@IdKdQ8zsSn) zl7;z~BV4%R6?GXdT9_SE3&Wa}*SHUeKVgO&)9QyYJxJ8T#mRP6bTX55d8qX;GRZo` zaE>$V=ruSsQIRYftm*7lLjJJ8x1TPs5u8rOjogaeVa!P`<~O)tWLR)u+Qxm-JvNIkNeF^8)u^hlgX z>EdZpG@eS!O&4+MA+?cu}(kfQGIpo4Wa zj}%6=n~OH$%0|XUJAJH#UTVhOHUn)Q?y*F^d%iULDhvF9gL`?y~Tn6`(>N3fw@g*8$75zCS6>W$1H$ka4?6J9Mafl7?*rKhSP`y6xE5 z%cBWe^ANBGVw!b*#@+RhQzHiOFTV$+w+k0Z#X=>J1B_XvX`ZtUQE#AdE$IoKapEt$ zl5MPy%xmXe0$U@TO^1Muz_4;>wI1pfqf998vq96<3X7!EVPisr2k2odW93->;GrzT z$j?b~JJst<=b*eLM`*C-AaYrE<4;8z*|auFFf6J=?b^<5C2>YL3hlO-Lw}NOCtP7y z_W(UvYOAx}&U?;g+Ia8}l;RCW?J!3@OJwZChA(Ps&c&3;Mmkx(+StT661i|<8yqGs zv|1J34|E_#VH<>D5TqvfD;enQ9|`=T8zvc8X-aI1B})m%;slh9Xq`qBoFoukA2SmH z)&qRNM45zO)fWg_{4H3_`lYu%WtwEKaQ(Ls0ZJ&xl;S>pWUdfc*!9V|2lnanu;}ZD zW_n#-118raSjSk?ZJ<$vNubUU~nm~)F+RPmqS2hLFxFb#N#eG`< zalJ_V&j`IKCz&j9RuYwKk%4@6di76X@sWYNe*PU!5L_jiy>1)K+qhjfRLS@kb{m#@ zwCR%PqTmUiFKb9FD)_KZ_guIyO_L$C*9+}_Nv<)h}0|Q{awz;4`DfQB3FSxd0aAZ2mO;V4a2s3 zhizI>!?)LI`H_&NgZ+H`4%Psv!TH#=uN=W2O{vJL6NS$Mk7R^;Ak^q_tBz487kVtn zrhnw=@!`Egy<~nU(KH$lH_N&#ve>Go-MF*6xx}C9xN+qJzj1Y%hsS0IuUJR)m6~#f za2MQWB#QeuT6fS1%;>_>exZAmo!YQ=!qj2F!q0HD!CKRh9+H^s)DksXZtam8X}(!8 zZ}*P(?0T|Nd5n8)voCo^I6kg|S@MNQW+(Em_mrEMWDnQxhE5-ggCh^LHzMHW)Dn)o zWdd`H@oT#`Z}_R0__0|4On?X=<4KEpUiKNLoluyj^NO1k1df*U=+<6KAw;Z_pe$<9NaxKuD>6 zX`;4Zm-niB!W)D8=w{ZT`^Ld>0&d#s4WKe}x zVs~@Vz=sKQJ{$J86({iv7xwROZzQ8=$_@Xo@1MEr`=@>D*Qdez-hYdAiUCSBmfRQP zGA=E)Ru;|*6~9lWDV(MF_IAAXTKM1)P1791Fuw5fo7das(6R%<$5zVlAgg~puag_| zuVJ)8f`oBE(B=Q6bWhiLeV#wzNq-^ZjjujW1hgWKpj)4VV!A$-Fh2+$p*9!oFNcWD z*18!C_<0i^37?Y>BaY+~I{L_YUG4{b#|`~Ng~GwcU33h`Xu!La^6+JsHbF#w9JjFV za5~g|3Fx9I4EW_tyD~D{i?JQxx2>#rxfj%lLL>X@Iy`n*I&}8~v2#;-`2s<{6Wx{h zI`ZcGetCgCdq|1rUMKG{ymxaj+4^cc_o?=R>CDN+pRz-*$P_;sWvYx$kH#VHT^*tY zq)ubOeu?Z<2d93oqVB6MV1PmGF7v!yW#?QryCCu3r1JM{MboCCHSPBml$Pc4gSY0m5!W3UnL+~RF z0O>@p8cGOU9)ah$noT@8v=eEm>gN3BlIdmaB&G-aJaDsI@dw(FPYa zseV7ms!mWbXNIHvGFKm|B5+P4$BUgYsocdZ4g$Gm$%Bwl?n@azf#mNXopU#B6Z=no zQBtwLrBVd6jV>J5n1JHRN_rt8l!u)uZ~b{)J$sb7aQ?zesm^*wc-Xll^^B;t+ zg;9Cm0~~MkXaGcb?fEYi>Vds3*MZYt6PNhHR&1!ioG%bKQDu!IHoPt6V$Nla1^UiR z)xYE#RWC}IA&BBq^vh0ASjj@&+i9hJQQcAegk5%{PXPXWqe@I@-6Xp;z;5gt<8Bu1 z>C?Ux*v29(PGmFAx>;tBwq+8t*s-85$xeG;=(6n0TC&vCZ!8YJLEb&!eO_(*Rp{im z3rmzRHddm!?g>lx)d1k{nF+yWT#U2sA2}B(@J~&^^NvrMD1QS4<$)XlL3B?hLL$BS ze)mToFldbiO5t$LOlHbOsI|z8UUfzbI2 ztDqo{*k_QjulTh_=iTQ9cPPs5_;bx#=i=cya8Uk&)-a(EtmtLXP6#Ya#K6(lld1l; z<^8s3)>noOZzbm+=X#Q-18)D1ueS_~s(rse_Y9pg2vQP5gGehKgEZ2pAT5Z1gtSNw zAq?H!AR*Es(v74HA)qwUpmgWiyubhXcCK^HH|7g_o@d{&?zPrE)UN{I-_c^|UFGKX zyv*9k+SG(1V^W-RdXCFnH89gZX28+8WQq%g#mSK8l%9_2SSNswgdlPA*!DY{ztiQ4 z4vVgZEwr0p4r#Pv2oG5SbW`m_Uo6=V!ERH^p8CqbT%rM#KYfbM5Vh6!9H_;?&&)%6 zjNq1GNYd9BxruO#1FO9}IZLr)582g6`5Rwf_c7w?{~8*5bLQ=yz4Ll}V&DhXxl-Sf z5hnMj`87kZUH)>rNVITNm4z0NYKp?y_R^?{o^g-qsR|C0fAg0)Fy>2A&#rhHYo=Yd zOIlRt&bNeieBLjKJ^2*+5OnYRdsFsIeAqpF-Zjt{rhL&~j8r%>C;iV@b6sgP?WBoO zUS=P00KI3 zeTxTPG(jmw!OSeqhn+sglP}-NXnP(EoScjey?obPIyUz1pyE{5PrP|!^7_2SbLD=i z>i>>SH1W3l{$#f_CZq#F7k!&_D`>znt&pg^*grZfXgz8zz+9h37dB>J3}Q`J&Sh&d zwlriVO$nwyPw{GWmNu$kVQF2=P<1t#hz%d`AR{4TlV$Ed`y7@6w}$L-Ugqs1|8jw3 zXPn9$Nw~z~r0#w5zfpSK`BMr@ABb0e&OD(yu#jTS6S!Rt(>lD>%gYyOA(PBLuBu^cOQpzjthAlutIgO73 z23|^rCJj@!1>{qQd#l&j9w+iKD;Zf6|2RopJ>KN`Z7-dzJ|3Q~v*v3)rQBFqa43o* zX^1s6cox(ZcHcN9A(g$Xq-(W0%08v{Y&Rmx6wX5iPOX9ye*Ls{Eoxj-nuLum2xXQ@ zn&x!wj-5kwzD6i#w|q^{qQcj{bH*UOZ{f#de3sZ+LNV_2e|W6zEgq{seR9a52aa}j z-Wykeiis`K%dU6VLxs279TsJl z!;o3x$2HRK;G%~or<{toOm)G}HPRuWHZmOppB2dPIPdSaI0`O|t$pYWi0UB=f_o9h zn9HgaUVq<5{K)U-^@NKb zN-0LE2m17m>}(5>h{vU%vD0F!?Xumgg~Gfq59@##b1Uq+)NBuE1CtAbi$4oPIlb-%2E)H_pKvMc_d z;r&!6E{yaZpPOx`nn~$$mf07;8N9lDUniaIRcSC;g9Cu2NG77=$J(B4+2pSK9t0ItrW0Sun_#gAe@ASg=-Fbh{O%|Og9VTi5jjF3oN}_cAn`co}|LrRo+&>k9 zKEX%D=kNk*yd&o;=+QLcnEyW397#z3}BIK!^zARPH|n-={kWLbeAZ5*{f^cwYSUGVJTgA zV*I_~>_Es%$vmDL7QH@ny2Z-OpPYq+1VZY_A>6Ce8=fA~>&!3tvJz3Gu>C(IR?T-F zrdWn=asONEI~RH~FIZmh{Wht-dt1AM@kPE5!@rVncFYE#W6bA3fQZQ5>Qk2-|#}u1#8z|0x4z9&aWX2rDB*wzIAU zdeGRdI~=g9wu{kEiNOH4r|q&2KLA0<{^|2>%Fq1bEUUuSX<1?X{G(i~E!YAliBN>Q zRW*k8omJ*6QN^Nz^#2)8Zi}4Ik&&@xVF#^StGVBUz>Vd*#4MY{iqL3f6b$3=Ns^=2 zUxjg}6GH(6b%>Bl{SAH>aLAyle9jR!6~H!o$Ig(QB%b?AP5Y z$y1vDf7RXUE&MAWvTLJR0akZ&dkrAJredDH?xG%3h`B6a)X2V$WlWpGzN1ynH&pO0 z(^!S-d4Nl9j8&A`pwzTK{-+8&9|iHjRu{ITr`1AZpkcFy^iHozL#izLyMQUC14tOyn+otXiy&@ht zSNHfkARxI6 zdUC^x_)L1%R?_>Ko~$h~21DM&QmpJ``}f4B@PCMGH+A!HhNJ^XCMfHp15rUn*B66p z_zHC7zhnGiyp09mW63{-)4zlV+JIcZ)w_$8A6Y1t19t_&z+$(Riyr))LcIR40@&j` zq%vR+4-U>q#pr`g=QLU9`*I!>ZD>Q#M2zQ!hl@KG*RJ693-zl==wj8#e+)Q~8UO!@ zv`(&(UjBhy)gNM7ZZ2N8oWYC@)G?(WvPlZY8xg=C`Y}l)`WkOi=Pij9Hd{Sg^0T;u zuzfi@cL4XUKrJQJs8q9>m$S37@$UENGcrtAm__C;m#1JO>e^_X^%F z-hlpFgT>o+kU|+2)wp&_BJZ;*mt}e^h4=n%vS$R-*6=rVQ(Z1AVgZiLSFbO6gqe4o z&fPn#ez)bq<{X}L9nsu_SCT|OFnQ7IY_r~aMA)kYfI>RUKePL5b4_V=`qi`9g*>rY&9;J?LL3Ew(x4Jb zaWnQDi1J;j?Ync=qIxa*8DVJUu(7l#ls*ITSa8&n-TkD0Y;VwHxbtEL4=Iqo%&&>ZxsBk}dSg>O8IQj0)_hU@ zg!%F3-gn_rxh-kI6Pr}>o?aG=uuVF~JAX&a<8s2Es$$&bL9%E?H=Rk08?pFQ@Z_7C zKc-R2hxkeSHmXAhIn&aCBoPHz!wXKhgsd1c{J60=V!;3+)(BS(JT*STVYG$a-~g^? zOwW__AJhz?_qiH3cvk9gm5{@bru8Fii;_Rx{q;O-pd_!$0lqPu*uHF7Sn)PrK|utp z*Wd#VLmOXyPx{n-5=>me!ki1Dl1?iJ5fx*h9$GO=3yWT;*uZ1lws`4t+@Se>{xnxoj=ePI6^g9}CfIldM zD0fz*rUSfx<=lfH-}$5BYr#@v z{FoTKx{N5_3(B6-;K<%35={vu>qnA-L?lHJ>?smK+tDQCcljJrmwwx5=khht_dg5# zW(*3pRlbXEk7)b-qH|4kp>W!vvwS;Yu#Tw?DKec31p+3mK(3@mMCy}!p%)_M`4`?$ zyBxe-_`kMva6}BeN`dt;#9~o+*euJl0l3JBWK0Rq+yvNw-Zxq%q>Q9_*0kP+PC5br zNL!V&_6XTOgM)A=Ua@EH$M{%Pz`XvQuEfZSD-ul9qV7&DC#nzQFP^r*UB$a)2mo+M z4N_;uJp9h56fs+(-y=O*3o~5ztxU6q1Qu~xH)PJ|4*hq|0vRU$!^<16if+lPqDMSF z)wmFZ?*g7L;_5W|#^d}~k91O&&++e{JB(LPI^&zu-ZcM*yq=TVwH}Ti{*k%7*vP*5 zv@=$5a$et((*r>`s|g-D0WN$t`0EB&yRPxYTiLRV$ItJMo^R9Lqeg$0>mJ@BZDjuKJtN8QAaC z`R}y%>T$3IhG_qTM`WXnqBy1qNrKc}%F8}a}mDJG6 z!7rbzdwmalwh{!tkL*P|uk!Q0m%NwY-3IxqFLKNRlQJa|$GvM8O^JfGtW}xR6up#1 z;hm)@*rrLsS+jy*Oarm7GD*h9VSCk@0PyhvJkFv}ULRCgOQn2ty(H)Y%kDvc`!rG|`4_yPqxp`anN;S5!F9&g81T(@7h1EoPx!4)`^{jU#7Xsjrm*4|(g=N&6qD`}V%QFV2bBr=OQ9JoVw zkB-HpY%i45!o)XleTvMPv7W%hhEYt?PTu&OVY*O1m70;mcHG&bxi|cPCEK_DIhWs8 z3pyu>r2-Pk%@@_Z>5p2Ex(|dF zpxTvVYLtIiZ5N`I(=9A`n&XR2AbJe{tnGNnyyi1(pBAlTEm^VrZFgTkIR<2Fy{UOi zO0gUvjL%&M&DrRtt<6T7YeuHyBfm2UXKSy7KU+beA9_2<(EH{} z>V-y~=qtHkgXJnx7M?TC2)3Jk^o#PKC_2iP*bP`ZYiLHl@xV#qFS0_n5#R)Y{! z*r8lLKJ^$k?0Z?$UpWC>R}B>*St>D{I}14pQIZEKJM}f&UfKq2#ORi@-O?>e`r_H9 znQziq=Sd2_&+S%X880S6?Q5vcibg;Z9{te^;YN%H%fIapujy6}W0NHx^5nw=R=Y7L z1=87j9>x=qlAo5p0(U_a#EH&rir>G9(tTJWpoR_sjF9?Dps-w^kis@xWRRoEAq2m7 z01&eZjMhddO;E*{JlYiwW7Esh{-?oA_kJ{cW9KgxWSqP#6U}SCWs#?Feyit% zFLrGE8ayY^81t}7!JRxN2L~00qq4N#yxv?)qQfV{Mfgi;EgZgk6AtJ5%?{=ETAl}k zjq~YqH(sq0%}a)Ujlch2UD~Ie+k4{kB)ha7yeGe}$|i27d+>M*-t5%Wb z-FG+9TsRC@dZ38uwBNwMq;iS55Su2b?x)yk>bqPJbJ1Il0DFWygt0V_=vqFgefVwqflnrU0Cz9d+1XNDMk?doH%XA63&BU8QhzuS z6(o}l5HyE?g!#4o%5J!~qt;Fysvf*=R^_=-#~SZ-_^4hLFa+=T2Okj-2E zy9GDErfTn2SQP!uPt(U8lpCKXP2$|1j`h0*j@Ao_(d3(Rh2U_}{i+p*;NXh+kRk<59}g3E&`OZWg!WM(aMrWnhV^v*xi9c0A+h#c|Rt5cXN_ zpqd2WnaOf3XL3IzhDV6V#Spj~O(Eh9IZLz`{XDrr37wropZYDuuz*+=7K$S{3oZDd z7Vvd45=dnmw~`d+n>6X404Y5Hfq`+b1BJJM?PHe;`W5-z8)gR8Vr-sOzU2{Q85Ecu z172)Lpl1#nX4OYz4!n2GS$NM|)_3`$NtDsYwbT?!vCFjeFkf-rXJE-7>3vI&!y+Tw zmm!;)f=)*J+APOcmZ&zDsJQf<30weJ*sCVVP?CsA8c8cEm)a+0*XI~oJJ_i@g0LOh< zbksBUBEK+YTGF?tz}o2U@(FvTF>VW~zLoyGL^V4uJn6nmWN71Rd+saF1N>o;Ii4IV zFg;#;oqPW6wr*;`(zL$CH-m(aRZ>AqnE0vRZ~prtS{T6p)l+tZ~tGCz#(h~~bbN&Pu!YOH$jr<)T!JrL%gpA-T|qt=qYuN~^o z58`+rXjaMtG9WB!c(glPBHyv3Hq&uwckUZ>$U8pl;YSA{^yJwSSDHRsPzN?U7`J`s z%-rC_-+#0mlu#xB1jwetI%M>r49g|dWqvTf901GJ=V!nkl4(?6Gxr5X@AtZ34+47u zdQM|jnRdDBEh*=#f-71x(TBw69jY*UAn{mKOjM`ougBu<^hQXJQS|ta1_m!)9#5@| zojq+(Kg?jyx@HxyL?#NqXQpmlfTt+Idgp$Tt!!s7Z)dWb(>x)!ns} zkU8yxMz7k`f#$pIQA&7>`fj8sfG(dc!K0(%WOm8_X;^n~f)A~w3|Is|?HTSak9dsSA8wE^?_%PYs2q2kLz8}FPu=@l_+FJ-`T z=%LQz{t>t4pJ~diPVSr@ z*BlqFFhKCA@6Xbu?nM%4rS^J>BvUW+@kJHJ6#cDFICZO?`hD+zbUK5HOYK=3^1)czWo6#H5DdibEtMhiMv3@T^!rdnpVh#5 zg!MsfME&)#ul5ckPIc(4jlvXaoVo?MQ&XtA>n}h?f+Aq|ke3E$n$xdJ2Nq9p1kvvA zb?r~ZW%1xl4EZ@7I6dNOD0&nLw}O1vCfgh`U!hthkKsArjONt}=|g?xZ(k)EviiU#oOYF3+!EgPW-X@=1y1N3Paenawt1Y`qOBZ zhT!k}FzK)HiX!5>5aCb`CDV*7znu3nZuZ>vpYjw{-kv9!Cbz8cUF!Fh|DXO>GWNDN zCjULbVO0ZO0Ef4qSN^*I^Ztu-Cg+2u=ZpVC)9z;p1OC%>Amf;?jE4h(UL8D z`J2KRIN@v}WM+5o(fc{&e;>f7p|h=;o!+uJaab!$02iA6Lt;*r1NrN;pzyP)k4g^U ztP1>dE!qJn-OZ5P-k()Jjh9b!OetctBr$rZ>io1zA~OSS1k)I-VHpF6*rLbJKGb0Y z)H3V4IkR+IDQp(Tc3&_2H?g{!mX@?O;bmztVVpM9E(INj2D&=_yqv0YR*I!LZ^8ldvvt2i zoMr8WS`vVr$TxIUoHDX8+?p2sicVvX8(6$i{+|mU%WaX|^8NbrUDPS{pq|1n>~ z2Nf&p+guf0;kM$;WDM7fri|Iz!_MD^ax;oD_ugFpAQQv>3V`G zSE$J7z&UoZdn=Ex7vhu?mmvXVuI4$blxJI(k1x9cc9G6~M((;iN)OPgz0ddq{7CuR z_aZuL-(($#cd62XB25`{hV^14X>q1$F~*=gwqcP(6|PjJE-c#=LCqzOc=CW}zW@{t zWx=CVDqhj8sl~_J(k|5b_l_3jCcgd%p1z87&+{!BIjY>_-}i)#eWAl^KGiZvFnk4z zPe#do!cbey*XBw4U_p|?6A5T4tob&TXoNx4-hz{#JkUJ}<(nGkhXZm^@|9eD)Gmk2 z2aoQe5NrOs)&IQL-dBp{k)wlqk&g!kfuKA_wS)6>;A6^KT=(mlWXZMg%CsrAmGG2b zqaF%vSjQtmmX&fY^!`>P-Ntb!INc_!O|tW0 zOSwRS$UE`5f6cPz9kW;7hvxkEcQO;+Uq#lF60@TvSR(X-thj%A*(o|%DC-lm6eJ2@cVJe+nZ8SA|VOxEN<{{DK%{nI85=0&T(g(#v$DPyrG+ApEL%o0}EGRMOrhK!n0kb$bo4 z^8H%&Nl`Tm)o^-cu`=jxAXUKv&;hLNa3-wDORcV#^n!*`aI-dA3S(GYqOLa0X=ADxCF#TKD*dmAJ4*#5H zgGGxYh?UHw(>K@n#7H|WweQ$IO^KIY-sYUg-HzR&}n!~&v zjL5GHKabpg!SE7q+xJ-LKQ5Wg$@c-+peG#wwQ#YbbXJHMtTI7UL8&+g!vvzn(gA2rr z6IUXX?kg**{OP6sgsj0b3m{lP^`cA1szXzH$w;6S>r2Ty`JJMt7@FajoyB}gI*i#> zv78Jzy|BR8OfDHX`=akq?J8&(P);hmdc*hNL@eJysgKGU4#hdBuD&+QHzKCr`am?G zn^!hCoL#Kjzx4IRJdbTx{c59>#_LQadx@TRp6m=~+JdJ^BY!TvN*&mvt@`pyFa8Q2 zL}0h^L2zCfQ)m@wQfSo#`g}Kz~ciL5L=gcq~G)-^S#kf*Yr1N%?<0P3Sa0$C&MxwD^|?XOIblz4hH^ZTyiZ z6dSvvzO>$n7Zq&&r1`ge@zM$Fc(r{ej{D}99l3eT4QF1R+MYIB{64Gh@1%s~E(UJm zWm4x(yjmox-?aW;RqlFQmCJuSIgA8Z2Ztv$2Bo(x*@*kX5g`&ji)4$K2IbQbaSp-- zS+Uh4LyNVQ1R!XJzP>wQg_w*&ayWW*^eP#{p2q}2CH`MDuY;Qzij{?4ro!kifmaF9}-L6O=he1#x=ly6+>u3oh!r z4L$4ZtPdTEa0`v&uS2nWyG9!jzxX-HkIDU@PnD@`LS)iZTkH}i?Vev2KPaYw7m_uZ z4?mw>D@pw=Hjv4d;Wy2@a+|DUbBVyCkxBpLk| ztvjl>#Yc&;-N*Fv;AvvamAb4iPH0gFN=RvWttEE>KHg55lXF}2;a^Q z!5u{#DF3-MXE~$Ms&J8a0zaEu=pu!#j7ai};xa)=#+>dmV$e&s7!8t!rq4tbmegay zuai{m6#?j#J%q^XoMiN2!ji~vQ{YE6S02dI2um1Mz@1F?sF(5plp{NzdT0YR*TM`M z&$tZ{nssVjL$qP>$~d8*i_;Q~rH*Rbhun2JD6_AXML8HHq&dd2Qq7C7j$-0LChq|s^P%41c2@nw0p zQeYtJ$ii?ctXgjx#_&&#ym#-XdIZUXlr_p0nVS!5na?EOO?a>FokdpH)vmpkD{~+K z*Wr-Zjby7Ut$HK&{Wo5|AwUASjuGb04b?3R39uYPD3&KcZ-!hI!H5+KmS zp+9EG?&eDAddRG(-e%Z1V@xqZ#ahIA4lyhQ1wD^i6>z#U^bVmte4@cGb;Jz|>gqXu zXPy37BLK{WX5kg}rl?0Eu96li4ORQy4XbvW>k5NcmBq?9&NA>T{%pCtv_w*0U+zeU zcUv9wEgmt!+0lrcl#-7^mk&WzK=!}v*82D-Z-plsT8X$nZv^@+z=}j}&gT<^=l6y* zw7hQb1IYO6xNVI%n%ax(Tfhw6oIn0QlZ^(fMm~`2eZF`j<^wLP+6I3Yyh1EV?}PE9 znhBWZ;b@xIfH^MHpP+gdBAZ1ldVaItfAEAyXMU4nqcl~9(z(Z4sTtR+t_${FNF9sV zuCP#vCxPI18$S66tVvnbW|YhJSy$cPm(zdDjbbE5%D%kQbous5Q%O(_4@&okj3`PT zj??Z$a6t;d#H_FpvP;p+8};`Ucj>)9nwNYmy0pHd$J|MZf}rH19Py5mxY zO5bJJB^j?u5&St2y*hg z!lfk5PjZyhp8-qdA(wvyHbRW!G15_fxWpYZxOT&h>$|Tg5$dsUs-U7th|-^tjQ@|`m3h+^($BV4_x~Gc z$N#m-oWuM~J+I$hAK|g{4V@&Qop~~i(=>^goL=pu`ldBL(3DRUP?K8m*7ulH-PC4c zSwEg{PEuA?kq-)84>24FGs5?jmz~j4(Lmw%yQFP%(yTb3Jy3qgD13jS!d#;dQuLFE zNTVf3rbX(W7kj$_;yczSL+2lvaSiCwa~@hnNVo0H*K-Xc|FLe$HqbQ!Dg0!2WBtPx zxP?NL(UKV7>M5xS{1&3N5Ltq0A=Kg{4piX3oZ_#+C>9)c$Lb8Y2|=5GcvNt`=qdbC z=lO#aS!io+aXQI5?6_G#sZVO&<@Lsk)g&>afzEZ-+6K(;k%?2aJ@GKrdmflChgKU3 zSr>ylK=tcq1O2xC(Jf7Kr*R|9ptjpCA$V)Zl6^H02 z9k2pH&nISZ6l9Hb?v<+?iw4V@tAd8PyN;TN?V#*RNzze(;OOe z&mkPM6cp_ra1w`*(@>@Hle8?qOvmylQ$gXDW0sntUeN@hThJb7qKeDvplscYPN~GD zu#gI{4Ar=P;A7Vj5(UM15N#$elz(O`#^S2syyv80vC`gZbkS+ISI;fEgGcf;!)DFh zfP_QFqxXgWrVwvC3ar~qo+NC1PdisrqH*~nNm5THbc*znTWPB9}D~00Nd~(qat#=x`8;?~>yw&Qs8J z{&GPKD%wD*FAZC&T^>_fJ>7o8Ok;Yq$*@S~tc;qK+a)tYGqwGv{_djMr{=nzOn#jU zLX9>bt1|>{u%7WLs!FH6FMH7z;DQ`ozXeNE#-q;PR&vnK1%3*!$zJ(A;3&_g$Jiv|8P-^}HMa z%e~1x%HYTwOYk~=4Vq+SUMISn0$t%-1LQ!*IUhHnIX7K}XiVVO&vdYa)rZ<~@tkO5 zf1<;C%ztD#svh~leaRdEba}`3T{EVCZwCtMxV@IXV?l`>v+(lZY7!-JWaRRz_&D@M z=p1YfnNW^Cyx9CZ4`~=*-k4yq|9Vb7J51LN4jaJPQ$K?=p*Q3nP1`VaqoqP5vGBUV z5*f*5&LEzi){RHUVvlw-BFL)EiY1RYaL|9D5tAbAZ1RQjo7$gH-Yex1-(^fS)ykK1 zrQi46h&1adFmysa* zHmA0SRVF;K5*Kgp;j&gA+^Be1e~-Dc@?-v;l|AA^^Pkwq7XVf1Z26N$--5;WkqQuo zl^F%LOwhk*@CpD%G%KuWvq`RP18Z0Te->)v6(R+_SEnw9jdMO^rR&pcrFo3Z5OK+| z#xmz3RB#VQSy-mDtB7Iyd$=s1Omsrvt^m0j~LZ7omQ22 z^W;Cjx}BTy9Tm)FiD^6~res+eq;CILs&(HsNV+qB$Mj3U#9S3Vu6Wxz7k$M-!ER2Y zuX0mYUqj)aT|T}0s;*c+EV_Scet_XAs1Fx+u9_2*by-V{*e^L14~xBeRMxRai_?*$_Omg`SjBBYaI|sR>!Lc8WX_z3C7G27hoebJ z6=T&U^v!D|kzvI!#(OO(qYTdI_pz$?v=!%%Rj;*=pETfEGzV;evk^uAWoeOYMopgtDcM>flnpqO5F}?}B*3Jr&mmgRtFdWQV)sBWPd^h3FXSCUwu-V5c{B=Yi66=fj5F{<{ft%Pd47xBnPH(psWQ6&A%Gx%2t3 z5Wvqvs7Zw6yZWs0OZmO!qhI#`;mm82RVZv(?$5M;~h7tWqz2I(Zt z&812I)pD2rDRU<4R}N>+b(lH2C~PE8vI>405@khDtj;%Ntj-J4(+hm7ilLM#&X&gd zg(|bPJ^xW9^Y7s-6~2?4l7Jd(UXzH5)^8)MMBeS917}UDBcu5;KYnSJNi*H1A@;ih z%E~WBLAR9zcy*%&GDxrpiwsUyFd_$egDzSDZ>7F} z;R=KmfB%%0#BJC$5Z5$#)<%cKMg^PMihNWjKfhq9ghwsiVUFEx0#|@$P8!$K_UskatA1=rV-J5ILM?T=b!MX}Q zrG%7#MBYOy>5cX7e_Xqe*+)N`^VUO^Kc;!O8>;Xnq69}5J~Tf()vDVF?Elmva8c&PE_&G^3zYOFa_T(3+Lmm3wCbCVvYan(7=lP&22?qR6AbuMGG7w zTQPo4yp0GNR<}JGz))V!Q+SdIlhLjIs|!UaGb#a(wBenX+;0!EoLJlw0<%h}*=O)e z%>PF6rLmnHzXGW*IP#Ko>!(2kIY3eN5;b|7tS1=?Iq717jYD>4P--DWFW!Tu5{@qr zni5lrjo3F*5CP28iZh?Q>Q3l;s{_s?kbsK3>{HCpBd9Ai%naJlFvK<2HC!6S{>rM< zWsa6d!kRi-@KC+CkiJOKJTrgU`bAOd$SlCy*;2p%I_N|f3MYb6q?J%z(|s8;6!%^9hDXY_JQAk3 z;;(p1oBT(9-xC)heTa9X?0NF-YP-gJPp6Ul>nAY9f~=0Q*;`TL?3M4|aRA}Fi|5Dh zbC}Pl&Uo>wXY>n3I{p0W^q|v%Pn+bnkOJSmE(ymSqHImR%h}dU-$l$#w9U=&KfG0A z&d?#$s13A3@$(1)$|P>9 z@_usNo;$#Fm|k$YWz9!Z0wCNv7|uU;hy zy=M7CokcKR;PYK+!)tEL`s@MjTalBcx+YxdvEKJ?*=HI=S@sS8K9bb6xyZhI>#F0C zHSZL?nJKy!$eWmdm!{O(q1mo_@C=B@55n;$q)P?sgdWy~RB$AMN}=-b&hA&;yQ&}Y-N4wCGC7HXksES&I|pMezWYCICt=Go@U(k|HZ0q$+vq?zg_H6 zeFpam?Y#W@-*ws8ZoA@>%9g&ipHHi8Jx|0V&erA{0@uT6tR++F%=d6E=L~*gwbD_D zm{Z+-{xH2d;|g{c_2D#pzHV+g0H1s2882Gzc|Z>{4WjFbxVr!+-8V1`!^w5-TvhVv zE%Kx8Jni^3sZHGP@;#p#fV0$-Qj80McGPRn-_2l2BRhs7k5huai=KEviA8nD@+kpY zfLj$F5R|aXuEe*j`}eL2R~h*~ke6#K?*$GjdV2)KTT39IPXv??i+nUQk(^mVq9Lky zN6MqJd$D?P==~cH_of&nn&*yL#K_@qB(Gvq*6P=Dyy)nQI%h(D$PBqB&^$ml{mRC< zL&t%~n$p5Y`*yl1e|eY5;Dn~<$m)ke5u%xm?t*;H8(ti9fULaB15r*8V+6?2k2#}N zbJF7<2M^#%qEXmp-vgW_yP9Aq8agZ{Pjl$&b=vbkWLI#v&=|xfvyR^jyb8|7-*=R_ zQ4hpCWiJy{ew*sFeoM?ph|eJ{1b@sVI~=5DPOcqz^N#%Xhpn5g&*}}1AI1`*Labab z^H}d_Mr~_g7=0~eNLoe9B<7y|{-L??onFl^YUeL(G+aDjufu2FWsDniAo7&iEWYiN zi&}3Th|ayP^|?U@qN>^ojq=+!oZxTUv7U*t%$d}D>bCdl7~w{?Z|`fdD07gg<2j|D z@C|OgEyk9KjmCGS07m>y4ANQ=#=iw;^$t7LA@ELk2+6La)S3&zGk2_+W3C)*i4@HO z!>CUfd-sN)NGyGJW@)e{a2+V2_Bwe^ULd97wjO=R=fiDvbRk>MMj7tXu0zyyII7) znys+}YbQMG1mZF+kS5V9%ioilsCIt9_w2(# zkA~^YPWph1pu>vz>Zwh`*!gM1UcE;4SDPuP?9T)5+I&Su#cFsj{=Ms-$X&})Tz+*y zjFJ8~n8h}*mv|pOJvN|AzQ*=S_wJy5P9y?3;%zKS>GvSrrcg%tBb1p@O$+6T&4^E# zVypeFzl7?$0XAH12rtT#6^DhRrq>%JRwc|XHyMaB9^F?Q*3HFNCR>gtDsr!(tsM?Y z6xo~h&7%D(yUWiG@-VyN^D&^JA%bq|?uhu(ffen26pI-9~4X(Alp&-pKb!8;G&4;-%^@G}rUy}10mX>oJO_(VKd z;@}8C5Pr?OpIP>4Kl4-kqu_5NJXGT=?4zvf?~(Z8#CzTEh`6ZivE&gua(E~#6|S#F zL;zW2INGkbWa2w0N9lqDU>>)G%#kW{>6*$-;g;c}KA6y4eB`~fW6D+YjwA-!xBPYU zRJe|;X^3tDJ|>X2bGh~1TU30!`wgU6KL062!4YHvoX_km`nYdLYQS*~I^3_d*R^P( zl4>@juNMNu|Hz6*`l(4}G1|Bci}c(ZGa0wp1Q;Der7XxDd+%Y~$#Hte@}oXkZ3q0r z?%V&K)-c>L!JgwUhOLW#+Gc<6Uxa521;G~*0Avn93it>()tK+{--Uco9YdQS??Ffc zV3KLJ;IspDMR_FfoA2j~y?{?ttNPbPvdSEgxD?wj%j{!(VJn9;65uMybMey(q9;y; z7TO>6cmiI<*{8t*xM+&tS6V%dZFe#ZQar1ITK7fm@9mI`FKqbSJvyx}_}6JA?(|#z zY5QG7$6oH@oiAEFw0j8(>O`)Rh?5 z%?}6uIoJY8^7y!?qKwo8+U38SNa9}mPfn;cFI-}CrMzKkdo1|Y=eI_cnGm=AVCHao z#0k^5U}aLhCKUFd{zW-iUD$^Vm_l2Uz}ZI!xip2BdU=e>qMy2X8vBlMnOT#B=*TSG zYTF2DHuN{k-L#^}fgrA>%-l^sT2{d*Vwp%Zz9TZ_jxiM2dja+SR5`Qyq){qi{_khu ze#eH#qkZ36PDFl=ZZSAEcI=q>-+t2H_yfq9r!J-!=HScgelM6N04D&)Ap&O~z;B+C zD-PaEQxJN7Pow<~nuD=&y4cj$PUCK}{aRTapO zYu}QdT3Mao_%#ZZ`d^VOED!P9EEu(kOXXh%;M2D8_p$%xV-sEknTv4E@Y)oSjsXJM zo8T51*shXj1wb17mbo?8YKR7g;)M_uX|~~6!SZB|7#$u<7)pu=BbI)Bx5j5UOU9$E zG(K@6-P9o0oiFOy+cxp`d~Xk~jK*FjxENdG)8x+T^e#N5TmA7lTN_G5i{iIh72T}7 zi3upcQQozCHbuZm{X`%Si6p3{nSRY06r)Icptj--6q`W_1sN)0;cD!LR;`YTF-uG@ z!j+&DcWY9SIGE3f)KaF=$3(JG!7qZz>2ZFx7)gGKTqm0L=Yv!)XH&L?LnWU=5bJ2% zLX)S7p@HqW zD@tjI%pfHT6rmkoo{r+m2#$B;ql6J(0vqbr3dEa&(12KYVGNeJka*}$Y;ZC*PSVO> za61b{0BkvMxPud&Yw%(E2?w8X>(>_6y8?Nf_AL5sPl?B>IMBl_#6|!B00IqB6U?F#uOLMKwguRZCR?6Z9G+MQOjXb$hfXd z7SVgJm#=w>U|?*KkhXF`%recPH1Ux-FCJX-sv-ciTQ; zMGgLdm&~GUS|j>Zx_oug($ecK7)=20CYQB`& z&X4R7ZjjVk(tB7OazZBVbrpMlx>CC-=HgQ8)pA`?%;z=P`Y4%r?SFaq+u&{9wJdCH z_<#jCxw#YN|98$!I8#(6&-lC-CdWB^|MWD4p8@z#7)}C)dq(%!hu52;cQ+ND$&BuB9ZllC%n!HjNySKVASnYhuR72W={5U9xy>L8)2ozVu82 zG%({hg=6@mmx>}<$pvDcDn8f}WMg}JFr!Fybp?MD$njLqt8zFC!hP@jGhDKtGJHrUdcF;?{`W4e== zjNbG2#uo;RfIWjEe)|@sKo6VxXc2_iv~^;Dfg=O-ZQAF2I?A3LIQzMLhlfIZvs*cp%>{zKp%u=Y8+oJM+yQ{y2=k*yo(R*4k_B z4ddk<7$ZNs^+C&+|71a_n1e;hd7VNR%F7G4@o*mN{5vJN56dIP(O&N-`h$RzGNbOR zO%SAEYuwwNh4gct^u*2Z;K1_geOh~GNP4px2suSJ&vMZwewPEiamWVi=jYE>qXKgT3j_LK>c(&jI2>714^{hitIVZ7za&!- zM|jI%Qx8tm^wYpnA0B69fakjD>|(XN4?z@v>^b3u%!Qrh!!NBK2hQj+d<0f#8DDw# zXlnW^k^N8Pu==rTS>lL=Ny|#(XGK3S#gKD5`;O1%&g(dL;3(p0Pff%tFy8H{2lbHE$2eigfg%zK$maIPT~R zh%Li9Go7X;kPzU{zE?fSg-NGQ5ea6t_}bk<16x$8L4{RCjFCT`8#4!^u94_!F*cYh z3qt!2>t{qTiT{#=mf4{9{YWg#G zrWJbZK3Dox?bsP9bIU@;SYB)$VO*LHd8MRC`BB9*HT+E2Ms+n1-3X-(}Gc0uX zrkmE0Jk@+S%DGhJ0|*7E)y@=RE_QL(h-@D25?eRrliprhc=K5mR8U`SoYlFcD*5So z1l1i5=M6YvvMdRs_5R^K#&oj%wr1f@PY9`D@habSLXsg->3h~xK`8z4m(U32?QoXC zg6L)_&ptadby$ z#7T!`+fI=k#L%Q6T@Z2b9m=Ta!@5gnGt#jKk4hURK#^{$MM7C$3$^$0atIuwYB}83X8^ObQiN?Hb49YGc29m-ozPjhlHONrX< zs!H5LjTRWqWx}HFf`0XepmNcWji_Ui!;h=aYTy}t74p|;s~oj9|Gb~kUjHZtLOIO3 zc9~wg#ue7ud{Fn-YOnWQsP0?YMkwOqXByb23|e4fIlcar64dp^MkLVPr}Nb*q5M`xPI4~zVbHhMWxQWjZsOP4;o8>z>Ac8*vf)sCkPVvnSj=586r0D9e*^2!JL6rcv@#tHSm0dN`V$oCx zT!ADw4#pQ8@$6D;Z|s+o<0W+xu0@Zy^UJEA`?b`7osfSrsIDRXS=6F;K)i-WKZ0+_ z!1zqgK0G?B8#|r*tF{~vyZ=9U_wEJW^?&U6JpyQ&mrtTU|D&$6kT0M(3OxblDuKQ2 z>zadn-W{2@VJMw#z2=G-plnrU8V>rn@Xs{g_1=51i!k#w2Fx2H<3=>@M?I0g?gi^% zm`2V%OtYn?%rIE*Fbe-Lu+vbkMb}S@i93DyKyklPR=*xME_zJLom*mve>O=*#D__! zy7NPe=^IBXxjWwf;%mM@7Figl4i2eidqw%~AlApL+zIr&wowl=#}KZM{v64#syo0D zSmOS6&siK8q5ISgV$6#Nr9RNuCTx_yz6GIsus9Idn)fYGeuXRb4DWduQaWt&mvU*a zQ;+Pl_Qx7WJt`(%bBncBiF>-2Y-n{U1pihnPmg^41O$(u=FiB4!$>Sqo&sXlStwF) zOhV^({5^<+lfA{e6l;^^(g$=rG+N1@k$0|s%a4K)ICZ-;p(9)Tg+2l zUhnm2qm~(x#`0z;mlsa_$Pq3PfOFb4FY<%W9NLOIQxamE|5^)i4rgR1UNN1c5WAyF zQp^=Id>>yIp~Pv*1N7DY|N5%Rbrc9G8)!|AY58gY%ly~Fxsx$6wL|rQt!X8z%B*v~ zTiuevcN8v~?c3M*f*EtEv#&k?OWkbsTBPS5v~y%7;s`wMI=S60B3isQ1EP0b(M0T>qHqH3Bwj}rNN#`+Z{}qWvnY-Z~rGlfAz=DxlwoV&Bc&3XPZc3cy@+5uF!6`#2zG! z)#qbXGP{)!v$Nh8RA#?DK8}OzYerm!Fixj3r3n3sWDPf~LIuXR(Cl*39AYG68ASN_ zZABWluc<&KY@EQJY(K)JOUbug`YJOBzS-_(Z(lE{z7|pr6>0TM;r5}iJeU3UrknZ1 zD@v4^)&T1jzn>P|AY=JiaB853g97Nk;~xd$Lz)hx9u>Q=>7@PlUM{)-?b}#`jBTMPB3h7T zNJ^&ofHcx$Iw#TDCmcnA zkJP+}#;O+cQ)0sXWIuLDeCvA!eZI* zBo-_%bzglQOeR%*#E2^i{}Vq_8^S6XKE;a)MLI0fQY~*;<^+W*p>~f?2#N89w4&k2Uw%7VL_-ymQDQYU)8h?&oXZ}9JP&2YSICHN&_|rpl0h5DY>RpPS|BCGD zd{N}Y@7O~lpvc>ADPH_5@>j?MP}J-=i5OMDFY^~w`27wIaz>er%Wt1?gMFaIoQLhJ z^dh-{v%lErz@C&;yTMKU>qH|te-3F4ul9SR}Zcuv~ zi{OV;<5OBmP4kP(5F{}Ul^-ec00a$)qVG7u8b}j*$Ul<_b&bRYj@(8Xt`a6Lg;PFf zKTBifEZlp*&sxUpfpe7iRTLe#_va&vi1Sp(y4om>^1U_dpL|qrGS?1o=3c(Kq^EW` zFf}}q9)p-%3Yyc+#*aSjjn3IjBkLeTV_Ce?i zO<*tg50$YkhJW&}{r&2?(9 zA?!>*MI%t@d$m;t?iO_kR?-QdZd&OuXcgWmMov&*l6I(G)0h^%d@9t!?zo7M#6H02{sz=c=F5p4fiOzYc1y@#0CRxX!A} z4N_lmP~>6x=5#1>hkZuR7`~NAO?N8B)$kTabw>eq`q+T=)=Bzs1DyE$!JQ^>SSDdbA;kDPd#!YD!SET85}+`6`@hYvh78VC)$zj3H&C| zcd28-3!=AG_t21cJ_SuFUfu7Fzcd94L zb(TF)m5ojIu;DSsI7k@R3iAo7x;jKQI{IQsp4PgQ31JPA4>#$62xOJr^9E^+u~p9d z*)wh~mco=unRL)v-zS+zs|=r$VscKZ>a2VYPt*#YatE`}3g?+J`xRli{!?EcO8`7F z`_A@ldr1&-!iyk{zG!W?Y>DG)hiSoQr{@l>f4;Y`*^e@V!&y8&S|%-!Ye^BOGe>jf zt(qs!xW7OJt5PK@G(LXyjwEK#A~y}#VCoaa(Q=gLUCBhWY46d9IxaxSt3$qlPm1)= z2o1MhsAAc7pKcP>G4SYq;X2uH^>()^ z^?p!m4|z}I{9}4NN+?Q8dJ>(KPY-ERTA9xWA%cB5Siyzgf!+SgkHeC7!;y4X)I&epB?Jdj;Y;YJw(JpfM2lJW;@md)!V z1>0bj5N9%U8d1N#-n49L8XLB3WBG;4r29-$o0PxBBhVREc0ZUepl$qFdwSk?mcuEZ zzqj@T&VM_yOU%mw6#{B>XagXXo)rfUFCj2aG4pwZc|E2cFVg@b#Ge1nom-w6%tcrQ z9B3F-xinFHB)gp;plBqL2i@Vv#C337DQhY>m8$yPsyA!{y&pxQfy6oSM~OhWBn!mL z`(T8#&?f3l53V_dBTspzykC5f2wmCDBIcI`-qdqy7QZ~i3%VhHCYgNhQW6VW-{3rA zz1U!dE2Q;AUm#%M#15vAT@Ob=$;DsxlHw=~>k9Hnp&6Ah#gP!|=4554VoMHk`?Fq* z$Zf$eI(5hp;okH!5NYtM9~I=n_UEy#f{d5`ge=<(>9(io?w|M@p7Q0o?o7wU{cN)t z8lscKuP?pMK}%hm82^js>%Nr9oAveYZy)2Hzzid%QG8)j_9E4Y29EqTs@u8d(-Ho% zj_Sz#{b&qSA4M*jrGYG7F7c>${CR^&GLe~?(nu*(_t@`JfB1<^^hrgXmuA%>-F-!yPIYSXt5OCy6%gMn-0)t>w0OlE8KsN{EtJmW1^a$72Ft70D&7#edgZ$G2ILgTHeI!NysaI6iJb#tCv8%0OdOPYiLji#cScp$ce{GTK zL_K!wj5o~n>B4IlHIj*koDc-7DdDvqd0j%ShjSo(i14@6l5Y~Ym^1_BXFQ&?)Np)w zPO4k_ExD*XX&M~LJrZ&42hcd)TF3yWFu*9dY9{5V-Vo2nAlg%6KO~47emP5{qfa}w z#QeZf3?vg=S#1;i1$d~pZ3P-(03#lF1!%`IuJ@%bka=fEU?dFPBQY{#WBrr+^n6VB z(ME37-nVC$zV?!!C8BGk$U?>vZ zj6bvO9+ZXU2t|yXy>Uro#Z@bk+;F?48Tm8&VH_b@JdSW=yVsHYyP&Nm+fM9R%@|?& z)u{Q>I%uMtTd`Xp{h%9SayoirbJhBsozZ3zj|V)Z!|t#YV8j;;+MDB0`* z53S(Qk053drix&m&otEcN@#F#%EGPsXR((}wIYF-rj%skiPs}yL1N#sl&a!J!=R&y>HD%E;_rS1JkOvPbE$3skhfjpU@@t+t(7- zhmKw-t5}-_6!Qx)BHYEQg39OZ>yx`4yN`rbRZ%35k%;Aj)R)$(L~AFH#Xmj>YpW%7 zidMKOdmZC^<%E74dDlSb1-)?`?KHW>mA;L;*UgZ8T30F4UJ-uhOg_2xq_|GByPoE^ z0S?)Ejg$I83$yaRVJHg(BVEh_7yOeZNkCgGXTOy^Q3`|?>GQlb43ShmtqD_?Cj7`O zK!YQTu^L#)=%3pr=JE?=hmJLs>a?}lVA{#ld%#Uop(Ulkoxb-;S3^XVOlm(CA1jT^ zfoZq1jE4Ng3kYP6Vq6uUHw)DD^`4gRLwkp6$5rq#oB+6j_@){i{+NI1Znb<|1<_ zTOL9+m;5Ods{!(8a`~w%QF2pNe{48MyE4`BZuO^={Wn=`j=vwDlY+Vwo*T`p{?}1# zY%e;>dLT#3a{w>jDK4PBP?}rBN9Ja%xWiUxDia_5&bM6oT{^IgqpeWLeQqr#FEj<* z%8Hgfl*N$(bVxdg+n8mROxeKQ^yzMxic)A4^mp9SljqfB!}9KnN3c9S;+=TR`&#^tNA;h+VhO0#$|`aQl+ zXz+3bSyEJit4X;7f=_UAMfnH{3&}xhxR9^Rjj!Z90K(vY8Zm2%YLLL7uM;PHO~ftR zgPq>Yq@B#2qC|Ni=rg;yZ;`9hTjwL+zEUf2@i8+vK$ZxbI?I~UuE&=7v3X~yp>IhxEMv;x< z@E(fR{);x7{L;Ckw=1*mT>9PGFH&wQf`g_wxs9#OV>tXgs9@vPRqYRzwT5Gyo6N$ea4)IVCCWz@>y@5OT*( zZx`ew11(6WsN8(v@p=VK0V2GI(=|%<;5yyz-f5ZZhcJz%wn>606rv;2X@Q;iFe*)? z2CtAHqD#o`!+pjaB!^MG=01k+q1N%5kV|r8R5NX{5V|KS<=k13%Y#YeWxO@N$!f3c zC0-Mj0H^!5W_{f=;odODBblvmA9z6CUi zA_R$i*pt=bJf17d+8FWa-Rpbu?T2XjM=LYHrU4L){2x0}EW2PQGOaQXzEgm-@@{H8 zyuerY-r;r6`qV5<&N@D{ez(*4(>ar9RpK(qW#G&u!c4P3UqgaHW(nu2_ka9m*=`!S z<*WVVm#3SkOAWw|6bK0{SVrUy-ZJ6Oqk7q6XKrut;mVs(1=4Rz`V3KP*HPv?%1a>R zA9jPoYUgBaj>SP;+BU+zXw_?3vfAh|SGO)asPyuQh*G=5wIWV{xIo}Zfd!Fs-v@(* zf+c}qy;DcrZT#BGdy$C_a>4^esbXGn!rP`4%RcN^`YZcF+#v!k5n)bw{C-BwN9?y} zpK9PG$k)^3U@OlR>TJzLGIaGg4mj-f`%qcwU!*D-ZN4Pe7lgr35CZtV`1k&rKy@Hp z#K?c0xAO=TT4sh+#ICr=MX_tEx01p|l~O-QZ9V%-ODB?EGg!qu_~Z&sIJ72RhmJuY z*@gRtKQ{}Yx_DB`l#{oH;ON|Lk-(OACm57@PXs}B>r08 zyrXs)xpFYR+S2OKs9^phjAhZoR~1{d8Un$vynehpzm}UawMA_;H=kB8wa|J{`L5;Q zwCcffj*Msy=|3f%L+#>qIGlN0sX_tjvfPaty3hk$?}U4TpzlJ@eeGvr)vi`y?TmS- zTp1n6eFu$1a!G!5hF(r3`Q&?z`-Ul$&(PBvqZtd2J2jw?tIwYkkt%qhJ*@w1GLfn8 zfsZlR-gkLe^If=Kuww(|hz=V5TO+H3A*o3;ru9}kaLI?|G?Xj~^(&H$f0b9eIMD-Y zv5LR_TmRW&xF51nIvR%uvb9(}q#hPbq5mEqdqbg0V_3oo-hC4T7lv2zpGw~m=wwef zMA1^mKN3E!Of7#_p9)3M<(uq#ojiEEb?;moPY97aR0*wl^+8vyn058f)hOuj72eRk z#kxHs`7-U1zVeaucl)@@`$k2)^-o2{4uI=8bpvRZ=plKmq>fq>U3nnLP|uOXY@SU+ z*|WXlo+vzsdw?Ou{AQd-Vr_{714t{YqY=T-A)gK6O>k4q*h;$)El8> zDNSJ@REkG^!3BB6-VxvK8a_=5KEBc(<;5ZueBeK!VKnpF#)7}hKiZQeTHwo#$YQ@& zMF#gLTvOt@+IhAyTyY%`45Hl99f9g9Rr>y&2h;%8Ffg3}l6elya|%}M1&|p=>%ejH z5Q{ei(hRfx;w5NSoINGx~_Q3bJbemW|LIYDm| z+vb~g4?s>aCXR|c+dN-$^0SFkhMU_v3iZi3$RaW|aoKBll4b!V7w_Lq*Yy6varjK~Sdyz@Q*{+us;d>SD>}98bylI|#3GU+YzVB5 ziCjeJGo2m#vHYr^q=nZmr~IQ)^_2olTP&TqXhP8IH}w!wLTs?}RRyDzL@!RCO3p&S zPP;#5?##aXk>+Ks-|MHAZCQpm`fIIPTDldZuM-+wfyfr(G=*qml;F~L7;+`{N^tc%K|DQA|RLWeDL3+`(NOic1c_$eL zn*;(A?~PEm50K2o3P=NaD_XhqoTUDevU>lZ{N7cxrFy?%4T(*;e}R$vrP*BZwew{RnbhB;8eGjy+G; z&$>EVwY`82H;B^kjSVDi@n(RWwUHw_!4SBk?pMVRPu z1QyfI<571zAwn$aAf&%I$O>~_ zx8A;-uYQBo=fmqKd|;$`;I35+HBgC%J7A>L=Kh$haJ$lPEj6qi2-UdPAp_t6C#DO; zZrgOVn~MbG%)yrmb2w@}uQ=QGIX8KK;s68f+_LP#M%y8yAIEuNZcv3@M&!xFpuVl~>tq2XTPoM;lmYKcj%*rSahxPLC+4Hw6YPzw6q!unm>@h7qnVSC z4q<8tgH6-AbATcwO<{iW^RG(q6yzdg_E@7k!zAL;U_^$RG7chG?rKED6+j*^8x~qf z28Aw=h2^P1iAA?mC-)oQR6J?NRY;Uy7Ng*l^w4M{d~L-m&Vh@1=mo)H!=D`UdbrVo&1y@fw7@=h$nolCMXX6=@We4>pDAp>gJM== zK2%t7A{`8N z>JT$-)N-%kMVjb53<=+8ZO4|D=gkg15NU7%$~oTo|5whyyuY%0V|9O9f{)HmImc&g zhBxxpnCsL?4^{yRbr~Z2uX3P8UDuXvk zLE#KGJmLJ|B#!)$FES;=`}SRhzwYoZ5cY2{tgs3GrCts&OWmb!!Yc8@`B!Be-19`3 zv!1!Oo^Iz)>JmFAq)^%mvaj-q{gwm=0k4Q=#dbD21Ll#e=2`hb3Kk*`L}sMFGUZJNDr*4^;1X0c;+3&Z0w7|Pl155EYt=nZ8oOocnMd@Y)Fv4**I$*S^5A7PAHcRvRx^1#{tAUO%e)g}ebm@u-` zViQ#ODQf{JC5*;3%01jzjYG(~iH7@L>DR?+Wzi0WJit5dZxFM=e%pwCEA}&zzVF-j zlJlm#ZQ9Gy(mF;A-btT6_+5!f33c!^dQd}d>J)P+=e{h&`I3k-;I+lzVM2-yb+S)( z!QlqiO@=?sNv=e0(HDR*5HL)lh1wS_H_=paH#xN-jM6acN-4!`i7a4us^|6)RHwxW zvpx~mY!V(K|h|4N)rY*teuLCK^R?eLw}*x zo%l56BAW>ILcBu?^DJr3ryzGu9c{zK5U2Q|f#$5pC(?Fd@of8;O9srnM077q9rQ6* z3=sFWk=Ctx+-a9EBoQ?C7!wcOA| zgTa)8l=l2<(P_R;S}NEXCUdd<93VSBV)|GQ9zk1%;J1uMD#gG34Ra`bn3=*yBihH}>^&Bv=574VfQI{6b2r+%IEay>8q zoy{BjS9v)GdXk;4vzf{UI0D@WIuuuYH@JCOukmU$qUN!JLNSdZck4|Ed2L#8_|41O zJ)2!V5%hMQn?ic1&vdKz`6OAK35|%_6$sUjQ0!HI-@9e5ScwB=(PXr_l&K_rkA}J$ zD#^&@VhGS|K7C$&@SlvcyUry~|a zxgRvI9OwQrH6#!p{LgURwgCe|-C8mhI01}INi+5JGZ*~J4iWOV=e+WG;QZ*zY}nI9 zRatARtMcEk0Qu?DQsv|)HngsFqzBNF%rfa$^N>;LX6jf? z8WDz*ipDUvJi&W_igg z`5I9fA~{AD4F%qk=L6xDVY^Y9f`aMoNE#@Mn<|*}gz@na%RK-{AR$)=Z2MepgJ@lx zIt=@eLu7ABa99U+FsGh!pPI$O?-$voCIs;QGmxI8spRVjg|Vmb4T+y|y05-)ie^+1 zAT5>XKEdpf8=}(a{=7%DIW$7KudD7(Z95(S8V8_q?7Rjlmoj(dd2=xcV<>O3 zdeNt{co2pps36ul`>N`Rm-t1Q((sCyiEE6->^TEu+Q_26%0W^w05JFsI^6(WkZIVup0c- zEdy)fB2Ic_uJfxWT8Dyk+@BQkTqn`rWI3hXlLuN>SMtbh;?(L?2tsz&MGVA0*$+pWM^kS< zZbY0kFw{r>L^u_Y@#{K}@NpaB!tHZXLu>Bvh_H4m`kIfWWjK>}APj&dlVf~Is|FX{eZv0=IGd{I<`AMw7^nvn&|4wh+L)^4Zl^ONEp&(%NeFG;N^5gd ziE-On;@GpZAHx;vzE9tEx5WJKi1+_GU&>-d*fE0C=Sm{#8V^8GHm}r{j)64;0#&$v zwxx&cRsN&Y7k;}BpVg1$g_JMWoer}Tj_$<49?h6~6*!ZdmOM`%2;eUE)T z{{DH*{db7*4CQ3soC99EqE$uFk0-aQ>Rm{X;+?j3)fJ<;No4QH7?}IbpT72I3XQ23 zCH>@k(a=IexV^Ol-bD72O2~}hibC~XvUEI=i5{2EiXrr~D8CBtQDk%_*L`U3yZQQ3 zHw!-6RwB|5I7{Sqh*}MjEiL+G*JZt`Favqe4+ulkZGE*~!S~z> zrJ6+|%L}8qTk5> zM%;hNPH|VSx?}{;V>MLYeX$x|kTT&q{_11lbDaL@=>+H0|40>YE~JVYACo{oFs^rs z|DQ>BpcXcC{`pt++&TMP&O`7gqG!;T=So=PDV)Z-QB}5RLsO{)}t2d33_Kw7}VSVnHctR&` zS?5A4oa(M>^n-at;X0EV%xh2({JRQ2j{qNvl=ER&%1ui!7)AwJJ>p-A$U$Nc?S&k0 zF_*jXzkILL_}`_+9Ldr-AE0BEhf`?94c z^*OZ5b_&09QkvG1=y6W{-Of=%?MVv25z%{e_i%3vK^^1r;CSCFpN=pX^+PTYMzJo} zP~f!d5OgVPqH73PuZX2w8S&(RN*OwjX$DXah+Xww^(@aD;o{>SS;o;A)}eOX-Z<%M zV~5JR7>3OkWVE#!VD>aH2=N2PBVX2FRH`AS;$=ep%H?QI2rTa zsU@|?@e(8B(lTg0BNE&?Ll3op(bn#!0kDhOLjJ8QM1P_|2!pqznhVLv$Ac#_?@f`( zF)?jN=Fm;2mOEDarqH6-n}7Zc3w3MH+%#NP(Q{0dO@#y#Ul^soH2Rib`S8uk5TEjM zK?|7c=wrW#9`)gm9$@8t3)pc|m(Xw|`%9-3diVU2G=2Z5Z#{l%LQQF;XEvgOHckQj zu6?lil%)V9_PNUQxVYG!y?;gVkL8~xa^qezXKo>upomhH#KDX-b0)UH;~0Rd z#Not!cuM`NU8E^GV5n8~29%Sbfv2ttG6*#?l58B8T_Tma#w;UJy17FOEa@%9*WZlYFe|9g<MZn z*`9J~k3(OG_7sEa$W#Rt#Y9z7#PPqzYw?#nHf+Tz@%#qGJ@Uot$JL)>u?u@a6!q&10_ed|CuaiQONhcP;0GE7!~g0RY%LF0B=fSU=S2)^NYB zHj#)YN)?2YfUw384mK!?rEoa(XG8Mt^8&J-pP4BwG7mO)W`76Qwc0;gYoB`VFNYf- zi4D8XA6XNwo0m?RF+|XVoY3apiQFvp`WyBm9ATI3XSCUFFho`v75?6y7IKwlMYAEv z16E>JzW^IHa_)!mpv#3br3O z=~5ZDb(qWZ7*uEDv!`Q-GN&xM&EzP0mljw7VLphZT@`d5srHca3tJTE1rGZ%6!djVY1^k?p<{Pm&$}ioy zj6e}9Ulauf?)B{CU?OppeZ`;VerVv`k1KgOs6(a;tiiXIbU3mi<%Kl46xHnqj|%jK z26rBZwYlidhv?Ntf2Mk{#${CIHre#Tvbkbt&(ih%lgc34NSc!%l?>&&J{{Y^BCWXk z`ov-?LCeo=IDaZ@xH2*X1$Qh)D>+8q#;grJLivygIky@?O!Y;whf!i?H$bA+YUdC{XE~gU;UJW`NG-x!-dH>_AOrM(R$DL;`y-Dd?yBijPl+n z7ISAM(E6cm?5mhx5+7&3a&3hn&`mkkBr4RL5UFe<_UpPJNd|r8 z2pIYCNoUIao|%!kZ@Q+u%vryQ=so;|4U|ErN_;BgV419}JfMeh*2Yth^X z!j!y)yZ`fkKZDg$;0hd}+l4_2ylhW5grV9C{YM$1;c4a|icUMg7LD%&?V%^K^&VpT zeYt@B@!O!YFfHRWx6xVajY~Y*eAndzyDZc|GNiOiAQ6_ARM~$r0OBf*VI!?}J&Za9 zyeQ47DJ>>&^FOpA>bOoJD0P-sx(e4Rl*ElI!*U9pD=!c*Iaomn+JliiJ;cR>Ff8}~n-GOa{+EoI zAOWm+|JxWwm?jUqdd2|$61Mk4>g>a$Ux3alUeYhpTMjk1`hA&xRabUYZLaz$hR6PP z^?RO?^HCeCh0-qe@7||agXZV!vKne>W{qklXo*#Y>$Z!eyqdALK8Z`i7Nhu=Ad<+0 zzvL{=DXhFNis*CYva!YqEy`zEMb5u&t7<(r#UW>A1dAv(ERkh0l=z>RjNXCe5iP6> zcRC%~SsC^zc$xOfr3mLJ|AdyXvS9dq;2s!3o}2+`!%{R|-nT;hP7@B_itl(pvk(_2 z|K$;R!eKXXA$6sVRil6q&>vke+U8AmH`)(A&j*I6rhFL{cDuA{mDXWawWNE+)MWHh z1sMu2bVsV#JO(5vlj2-?PxXsT7JA-6i1FKSFMQC%nfj>*B56y zPt?{hh+7ZiNil1nrFJLFC@7H^%;g0VV&QbAP~ut67(*~;*>tm2ugaEEqUp;V2B>1H zBK$0}d}V;8R>ekb{DE`Z&r{eO8|(O{xO*y@E5hb|HEIChBlfbbImbG9_iDRF>sH-q zK)}SrsWuKB7?t_H{C|`HX~acu7oQez1qON2Nwoi_v3FGvcE>)Cio$+v`s_lGEjAHI zAGB;;f1IUsDTuefgYE3aQ2BMNX69CH^@8YeTCKAIttxUn*8483nf7~tmL!hGg{Y-Y zNgTSPOLRrHkv^=wjES+|Ryiywy(Sx#$>(c;$=6jdiUf7(;Zka1_0Vx)>^bz3aBzE4 zsnS~Ghm^~FEX}LCK!Bv(|BzLiXPWvgx4<*a*t~S4>an@VKm|$d;2e4$t>bpubG}p> zmbJ+Lv15B4?$9)GYu!X;#m-tKaH^#Zt^BR6iW+VII{7mwNqB@7;lE`&I=b>Bdy!=j z!jAt2vW(x_T5(smlFA zMN;e~vD5Ez@RSXoEMs& z#+G<&NKQxJE794<0aHq~A&)c}ho?e6j#C);@+wC#%H_O&`;+mFg=PRnd9I&56kAX8 zt|v28OAd*{in*A^+97Ka=J&}19Q-MqNLF4kyFj~4^~AjkB6kC7QNfUs%yf6ZA5m!@ zupY;v-PtzEv~amr4l*Y;#YUV2Y=xlV`(@2ki>#IeYNu>ovdzhg6`EKI_d&yVhEJl&PC7f%KisX zM+tYT7l|jU)v$6)SvUEBcl+g_Siblm8^u<>mEhXLUWJn6c5x@XP2qKCdeZ{V2E34{ z2MEySkXPq9(TP+y2EbRVXca}?vN1g>t}rB@4K_w;Ld|Hoczth3nwI?*Hr&k zn)mf!r-|XZTCYn)4-{yX zoXhQ!kj|fV*hr61cEFV06&im4X0rnm_=k;n$e`f3(2!>Vu(>b-4A5@2gl>{%SdX;v z>-yyY*6m*M@FU)4{Ys>VQ*T&xrr?Ok*j0$iY4jpK{UboE5j{^Uc&VdIU{wzkQ6s5ZY@0qFjl&K<`=}w*0e7`be z>|-&OF<1?phq)r>hOWUfucA5GA?h^mO;fC(o&Z4k-;+Ar;9Al2F60C9myGP<#3CrL zak-Cy(xUeOvmNfk7!GWkxa((ZPf*P4_vgX`S4$AY5F2%?gap-^Wl>Ov!$!sDV71Pj zveZ5Rd*hoF?6q><ag1Y4Odu-aLWv( zQ6yBvgWP1A)SaUG}FoQ-QuY-Lz9_U>Q&^&@4C95v+jj0#*W31dwAsPnTx!DJTEWSv z;B>2uhWe-qD8+0pSrEH-lzj|^4eZ%n8S+D}XtGvvCWVZcF)wp)43ANi0iQn}qAu4; zCuFwZNPHzS??Y0vY)G*o&#I}&M!2xA8VS1o$1N&aS}S+TaQN2}Pn^@M1udoeLkSBh z40$$Bz8&d7=L?HK+4WS{S+dN@B<0J}bZ>OL{0r;4*+amb1@o^#Vv;x2nIK4Roux!h z%gNrybw?i^tybE8-~QdKdEtfxn9BdXC|D$DXR5SWS;f(8m*`awg@%;I5(BKfK@rLy zA2V8o!5uP0u*>}ia+c(4ERwof+@4(7#K_er>eeek4nS~H%cxoN&ufWq;DV9xPm2_M zJb~l-O{#%ENU!=*7r`1oU`X0eC5h}loVO=k|4e^iHUoA9BWp*}hnJ#oCKIifa_Ng` zjNuRxNt|=7X72yt>#d`rZlkW@85m&bp+Rbf5D+85SM!HczN~OC&TDlvNMrouQ zq?8h*8|g;)e%{{C``*vDzV$8EV(}+)=DN<=XPMamZUerm65+>3<<9vQStDqk7l%ej*`5{BW7Ds2-UVb4(PgCx~M24vY0U{JMk5 zg}GkuYk@^OsaxQ8q~W2QCFO}uf0n3&COTuT|S?HbC zSW|hV{Y)GvRWsG#NYOA>Okv^H1+9Jmg+7)KibFfyWG(96cNf2N^u14UH$MI=-Zm-v zGm@rg_WsY<)KQD{GVsgQo7c-R_t5Bn?UxE?pNlJJCep=VyLC;S$Mw!Y}*R7m2{;A zqiEOV$ejjakWV%Vb3gNEfh7i>HohGdd5jf!W2Gs{zKMFB4x$AtS3rsn%~ut}o+jNs zugLqj9&QLT4=B#H5z|@oVqP9GJ&t4L1>t=35ck2*fO5jTnCa8}U&XP68B0X;keuRj z7Po#^2ZxQLgF=#zX5-K~@Vz*aXBr0wbq#y}R8p-uY0gR0HM?EBo z>t5dHE)He-+Rc|P;=pzZwO0&@D3lmdcP@_?hWG4=E_K(s7r6;eBi(iC6Wk_4mqWdc z(l+}F=j)SQEkgDAu5Hp>%yThC6}e* zcV}l#?6ucY1V0!IRyid^P{5O?sqyOaoRxHnlN-azj&v$lVH!*s(+Z!vsynazrmzF3 zivZO|4vhx;3n2BjXR}JTJchs{K)6Z>yT1O=2tRy2PH_|e7JOcOCaHI?DCnp_v}uzA zS9Eh5*oLP9C;o}1)naNV-XFJPRij7q_lh*mBvgl|0LKg;*yT`cb*G=y$?LS03Qt;jHwwIazs+&@@{&e8IzW@oM)n3S7_SWDN zn$c}Jeu>!4m2cT+ntO&eHK8mQZdL*KqL-DC7WKqHCxsL;1B74b)ZVs)B zZeyDB&&(M{g~qK1nJ{$UeMiYS!OFwk+g;&j@ea4a5Ihim8oUIiM=(QJU4a+SyExiJ zH2e9(P7v%}6vYm@t0H1@Ve*?lo3muIl_6}c<* zH*dH_51|tCV@@L)#6?4d5o2(5;<2NV9v|Dj>IvqOwC7{NzoR-+IGuh-XlPPy%OXpegThiEvyO| zB&$3NL@&HS3cHU^hT49J$$QkBM?LP?3V>*nz1+@V)gmx#Dnl1?7EmUC59-eVJCgp#E+DdgZCB?Hf0 z`-auwmv&FkuOuHG5(rp*AVo2r=83EJ-d$;2*-}zX3sV%^ZE9J7kmljm_N1V(xt;2&E5t6-TQ1w77Ds-ORt(wnyM z4+KGrjMy|4ckbRd>6gvtbeuupA@B-2oxvBi>LU&EN?&FoAM7ZEOy~N&`na?oLrS3O z;Jdi-^k*(4hGRMXW9zFsnw9y-{?KedP$2_sY3&)JF-g_H@-eT&IJd67pM-0h5O+Tr z4VBkK5iTcCc)8)E(4?sjj9g${m0Csza1cbUNxF~}@^t)O z0;YErDXtUvNWe3WgbI%lgu#v!hsJD^F*3#b3&YmDO$0^>mjpE&NI@MR$X0Pb?gP2u zuw4}M&@`V%b0L@M1NPG^<@>Gd4?{>k94HdR7KH**<5=|ojJ&E4a#yCpyfEUZmF*9}>l%ISl+Ow^ zcyDuG4WOd~LVSPh0AEzE!Y9G&{1o@_?{NZR@pbixZ+7o++3IM zwi|TeZJ2?Ezzu!_n21t3n)~65`L`;_Kra$eM1!-YU)NJrBg)KMVKaqg!+xj3la0xJ z*IP_Jxd)JT6q@4jE%V1a1g9xjqMU@NEF=%_g9M%qxoH!WcI2y80EV3>1YXRT7m|Ez z#W;>YHZHK|4YuS>^$8ZtCZmm%j%rM_X*l$U&qgnYgzo~fHy$liZuSz`tC&OU3(|3%T)fW0FPSvH zEnNIH_QyL8Se+31vpdf7hU~vwUX$(;e%smp&Ns1JwOLC$Bf!ya%iq{OU9?+^0Y5B9W0W2ulLv@q7B2VXzYnxm*{NCoL0N z4sdK|z2>rlEOow%Hxchl!4EL9f75zYRWO$xsndhmA;Mkbw@p5YL}z)A`BRaO#!hUk zNohG7xq}XdX3;|Q9Yw)nb@yRExzYPx5*#bzIJbwtOx1~b46YPp{wBVwBBWpu=%h&D zkMs$)1x7`fM%w5(obOZv{DPvhOoD=>`1ljBXTVGJ^*Buaf_t`)Hc{zlb`;@Z8D9l& z4Wbjv+!T&Z=ax{x=*OZ{So2%;eKjj+%sN+m;<94!?TsU=yt;>qZdSmoU}#idf7bx} zFbC+WUg5;%-t#-h%ba(~2FQbHjhrT!a~`fGY+^i`s}>Idw}-N}=1L*!FlJkI$dac9 zyJAY7-c^cy&^EqLS=8tMyPNdm&0i{E;y+YE13M`{)Ne`FRJ+T~VgyWV^}sJo`RO|% zltg#OE*RmznCP4&hyUHn$B-^2w0EKJqhn`qQUGi_?VXL-QZ>D{@6z zB*-?Zlq-&gw8&wVuhLL-Agh)*JFhUQl!yYbU5ew9;(BH+$i<^J!f)$vi}cyDW;`Rvrkcd}K(k1CCO;Rr*)3Pm(h@f~g z2O&G~KXk6wj+@BXu5_Yc% zNlj_mpAO!HvJ%69#>FpL>I$g>y;MUj&^UDoW0Gf{mGz|ownGs}q=@#wl2aNdN`N1l zhA{uHDv3jJy2p+gy(~^MSJCM-`7qZhl*~I}Zn>Xr6bN;P6}`_qAzmMxx=~O(i?*Fb zodpw9N!B`8f=*1ivWWt6Bwf%l@cFL=XvwdtSk>{|oIiSV@8{{yysu~6T`)J+0JO3# zHaU9_+hdxQTOEq;zaSq8XZ0Q@VxHwFf~8c>Bs^F zIWyu`1x5;om*6^PH%Ne1R+HnV6X~tn)`q@4Up%?XSR4ZkBLCdu9%aD$}La(bw~{O)9|z3?A)ElBC} zf_)Dfh)D!i6T`h1G*cZV>N+sFM$d_!nAj=`o{Op>+k;s$U;}(BM^G)Me$IC%{x4(6 zBv7%G-tm9rBx1&uEV;r7h~6`XTFwS!d7%kb7@4ybl7@=U1Y6< zM(bxD7aI=`QgvY1M4c`R7<^w!CR)0(kGMdMDdMz5Wn(Q#6hG{!Fzb+FbQWbJ76j4S zRpmz~VjqjLt$ZZ`)3yWeuYT@`vH`{7#^J9BIQH#*LYO8}H`-Hhc?ghn=t5jJ-GR-IAM*_jnwg_wSD?$q*Rb>8fALJZ9zc zR0Xli7`0lyW>>T`hfQR@F?MPz7`))ss>Vd-B)Gz9FB z-he0#5G-3BeGH?ah5esd@a?$-W+lDJ|Q70DeKgXkgJ# z{+36T@m+_z7oen-{eiYI=0czbS96So*}*xaX~G%$6uYGfMaGd2LZiE@Ha6zQ;`*Vq zyja(Y?{i^j)|Ap2@P*#_b_CCj}I~U?|0yDLl z#I5}RTw_Q9#Swp#-QXeNYFIo&hx0Ruhk1iAMV~30zs{R%+{Yp9(z$&nMU7k-z067)85-qCCJ*;3m=Fm-BG9II;`QAecnP z(^op-qP}R)Jn zu({P>&O2RCwP*Rna`0}PESr#S9_12=vEb4*Hpgrbf=3H#&R*)uJCwGdCwbuY9L6Oc z+sRamf)qYq_*qQzJNiq^3kE*uPIJ*1u;vB)aPI!CGadh}GZlYihs9H&ZR?bpzSH+w zwGjLl6s6=B0JPU%I7({2Qqz?p?xqeSM9ws^tc5dYm6S}ptQmE80t*vZuL5b8Tt?KD z0Ik}v@&@dqVf|uPJk+ocy7x3}%3U>X0#* zp6*6`{^{_15}9{LtKx_7pT=jSuTx}Xo@_nA(D@S9_xNk|+t=P-C|cD}QXf%!g1Pr( zpYlk3u!4UJcv16)4{pU>)J7~TX90>2ml;;(Dot9}YknrXGRCc7_WfFepZ1*a5J5?p zpxv@%`t@A7FCC2^*NYIu?gd8$hplIPAdB>hRC`18D$0r1OlH|Tg)%1(omNdU&s&PW zWlDYgmm1eXXX4R!{yRAPq8bOdz_$KQ z^}(%CEz3V*8$fRJ`R~^=_y78O0IN@t-@f+&9^3W48vCu662E@LWZ4AZcP3~W_v-uz zeL*oHucB$dQg)p0)7KIepTRb=oe;@9G<%-7&Tv>o$s4P!9(ZFBh*ls)T`i@kDH?rkFq-u2I6_0*JH^(-oV}5}c3CQo1m6H!Tyvjr4dC z&M+0I--D>*z~1^FuITZ9``JWd0MUb!Gs1u-)a2kw;PNb$i`WaWA)Nz#l z?!ybM!53RF^7{g-aQt!@_wK{U8V z!F~YBCHe%C7+5i69&(>>8rzORDCghyQ-QEDZL;lXmjX0<+}vogQoZGsdQ7BoPzXc^ zP~jzt5(9>b8^iI>;^)6Wr(fbaB>Eaf?@JYf_!^eb4WE$G8}wj}TD`-qp_B1+A;^Hy zP&&Q3#8MjbgG*+~)uG%ro(qap3zs(-=HcaZDUce%^l2FL#A7TmRS69i9Xk8blpnEd zKX7mN+w3gsUn|YQVUi9cr}-qWf&0y97xbuTTZNzY@`2m=Gq!i%I=Af$LU7`Y;(f&c z8U)}t!N|ixZO7V^r6`jg?D{W3Pt6LeFx_xpdeStNp}|8c-vAE!ZYczxmIe=55*I9j z0x5a@ozwHJQ(}UWdnn#uT0X@Dp2h9jPE(=sp-P z$DLf)%m*EAEiMlb>bY7{!h0^S8p_A3LL7d$F;#Tvm@Il(lS~YS778-jP?IT;%5c}! z5h^(q{WiH=&z8|qaTZk?5#3dZFx(d+3Rfixz%&P83P&%s$0aORNsM)_Ws=jYip4xR zsd8Dw5|QDrVkjallAunPp7?S?Y|jAlZ$;PwImEkytgd!3COQtIzWbQq1I{#9C^ND?ebgQq&Qbc z)kO&|rsS=j@5ZOIec1-acu9*47(uEbz+7Xdu8~2i5*xI_VjwB@G8FJqzh3jW8>+hA z+AQsB+oj%_BJ-~P@VN3a3D^FQDhM$CEc_?X^}PQLbP8J5b)j#u=6lG--`3gSo6To@g5sTIO0&6j2OA_;Y6zaq{N7;;jo#kLe8M!G0I8ZT5HZ z%fTd0#G{tu0rV5*9jkTD9fwo>{RxnXpFQr!5O-F{j#?h%Vhl8q%7O)L{o(~vO|@*3 zLlCZ|N+=jn68%I)c2#LT@aMzrLP0LLLe(d|C;0_djuJb&M;j=O>r%>v>8`(=1s!MU z8DlMwq!17NhfvT51;^imiw88 z#@o=VUtF04Fh~h^tDY~s7_k#$!?;lnk#`tLOHUYg;Adul*Taw()E;D3CZxIf-23d& z-SJvNRHg6L2G1<_pqh`v9eo%$o7)4KVl(~fN-gyEs?jBL3& zZ839K(ZOOZ33K}L6Wf9}YGUE8ZSc3UmVqP$EdvQSi#j#xX+h{xl9sVp%?*%sT_9N< z4-oVXu#F^<9dqIz2W#Fb5K{7~rIVpOiRb8K3SABgYsVay2p<(zfr@ZUa-oCBQ)J?p zqSR)7`bmb|Cp;q8a2*lQWeEhKOK;SBH#Mxr-k8EA(ZNtMOvwtDKIV_Li7(_k2|6KUFlAB6@G^&#<1*;MzEr0h3r%0E)x^DWC zbN55+2GgbIU#ZEVz#nJS;Te12EvEmopGRr+M_A1#cII@J-L~qh-de7FAovu;kJ+y! zJ`T3tZ2t@GdI&=EJ)QTJRutf*@@Kzh9-d^8s-gp=z$4uPbt$IhlN(_`7y2cn-@M@rM+lcjS;zw{YoQKPUiS>5EAkpcVE?k#z+m+_Zob3YRt~ z2iDNaNL82p3lMcuGwL_cf#d*~wgG*TJ;i64NV|o+bt&dNj2$;F&Hu}1+>ol{9v30Q zPm2^K-T@B;j)TQ+RsONTS9xXguQ_b%-<0s>zbQfT@mQm)y2#P5Yyh_{`f8ScZA)qX zPXAK*^|ZwIZR^I7l|Wvtlpz^B;s;A!4Mg5$OI{#dQw1lqC5VDc(inyTf*t}qkhGjY zS~DZBupIz{Fwc#n%uy~fLr0}Szzcxx@44{0F#T7;FK2eErE6wQX)|Ct?U>aEn?YmK zQ=va)?t*PIU{jXBVlds-Em17o%}F;L97X(i$snITN^DQCUOxteE{^e-(OK&9=yACj zwSW;SUQ~Ap6=%{o-juaRulSOr{W50;z99&E%^p}%c8F(4v=rT^8O1m!LYDMv{3J|y zXI})1)s-9!iHVK(L!H$~0wDpRaXsuNK;p2_`tqvtzRg8Z*4FNIVoXQo2$@JF^^4*9kPM%i zByL!uD!;hTJS{)Ep3kMsxBWB+mjE$?Z;_Y`LlcK}z6bIw_qXME-YMYiId6$2fHe1fL9#I0xax|Nc+euF z62`@k5S({BQbW+ILP0iTxy`{g&@W<7DcG;WUwsFE?NV289ArS>LN2ZLuQCZWo-sw~ zo7gHX8!go)#+7sSjiB$=m=T9;mnq2$pPk}Wv4>Z*JA|Af7!RnSoCD~&a! zDihp6z(+^P*2b&R2E^UJ{XUR7OA*#a$s;BBT>)iAA)`l8K-SG@=%D2F(4?^I5U4@E zL;P!h{{BpLx5u;&yHn!&5tsNfR1JPWz}8WPZZg#9&WK5{t0-s3ziuPp=K6k(NRayF z=-f*N#6~U#*LxNICfK)W7C8Y7LeoJ9*?{Z8OeQoI+aZC-c$_rS@96Ro5r|wpqP(#E z=LjTGUesf%pIqqyq?wI+Py-duRqcfsJknqRemyrz0?ntdqi3`7Vi_2z-a!vUb|QR! zVJ2Dgd*J=T^qO;I&C%$S!Pa-^R{y#%3UGr6+PII?M{$StIY0 z*fLs$>u(vbX50YX@B~Yh~hOb*LHsRK-74R18bl0!U!JrWp@1?-=*^;CXX_lAy zQRVXm3(GJ9qwg7}PY^}HdG8l@SG-2Xf|%FUJH)gQ@DluvgCcmTT1I7ymK9%6wqdbQ zh$M0>HQ+{Y4>R9Js=>S=d(RkYCAIHqpXBf{rs_iPA*ays=y(KED#!DBdn#66Yh>}Y$pt#3-gZ+6H3F^ zu&H%U_-codWXO5qKk?y{xH5F@KrWPbS6;OH&=pa?vn4tek53gzf&bjPHfs%4kz2ad;LxDB5PHn@3Oq7kpso_&cUZgl<7iTh_3Q0!j8pgn zAgVb~gDEI8c-$<;7&DC2)1=8j8cCqb1Bt#pp?3b7y+Eg@>pWMuo{s&NV9L*JXG3LM zp?p}UzG=y;8e39;V6gs)aki@Z84I$hbu!ro5zOruf(iVh>No-TS>;6 zKq=Wub&)I>T`bdwGtdy~vL@=+^i`@@M2|d-`5l;w@0y=-Y$qepse0m{DJSYW6lQ%tcouN_oc>a} z@^)E`i4OOfInhU9BBSZ*TLuMqh-ve{*PSYL2Xz__xn-`opo*vot_GvAHBkA4pJbsA?Y-Vdsm=SraC;mod+xx-DlgzZS=foZ=?-zxw9lv_K~5Lv^qY>%^F^CLCrPat6)DybcGwOq$^%7@uu}{Bl-%K-}~tn9I_epYr$291$Ft zvG|{+BbnfW?pUe#{Q-wqKrZ!q5>tb-q(Lt#UoD?xsOA+l9zNn(PT~&y^5)jvP)zJaBl%l4!P?3FF_Xjg zLobcKHX3q*rF4Zd}|^*@W^R%EKG%ztN{;*7Uy^4Jv-+?SEF&*~Fk%1e_mFG+U( z&Yz85kDzERP#>SIUanx|)qwz?FLmAQ6I;+vLssS0=$zIMFfPzg+S{#sEO+eQij4gBIX-B~FTBjOHQX+2t}0d}ZYHjS z_q5N+o;X*?=>w70ZpDR}7AixOBz|7z`;(c$as~o9?qNS+lSZ`qfKTGN3i>Y3@H;KO z4u}HD^$`S8pEx|DE08R7(aEqUx~bN-oR7Kud7Y^Tp(tSY=!jN!({%l=rs-;e3lpmU z7FfF8i&45_E~Hirv~Nen($v=yD}K^LC*qSL>nRLinFh>Q-q`il^Z_cv^{$opnly7C2_ENv5xUSUwlLmv0SK2~&AP z*&mtfp7_SkPnY1bHF4QZUFsZFk+UW55@h$3N#r{3v@^@UtOt0^)n(|wo2H-(n!$mm zY*lE|?-N!$S#xFkHI~(9oj3ZjVETuL3Q+Kg{UdR~e~ME$E09m|m6FQ<#sD96wP3(U z=`Z;lg!x^Ly)7ooPD7>Jl_ z#WZX8MZc6LDcDC%O4wHWdP~K+5O~+|ka6^I%B>3igaTuG?34RM-+I7E=_Y5makw}c z__3T4L+p5Lb#KN|>wSl@KE>$_`J9KP?4e)=E2qqZFAkMq`S);cTHSzU64y49ZRun( zu()+1#*znDGk2i>7m62>#I#vt&EH8-y#n8ft_SM!^p?C^<#Bg+8{YHiI8XH#zbaMr zGV)+=ZTZ~nqXPlX^irNQ@=umH{e!kpC)XdLM6@@-VN%F~Rj%B}OgxB|&KSQc?y>Nh zGzpfFiRDi;xxH{vfqtu$+Khdodq z6cWZT%ugUr_{bjP$V-!_qI&Evh;LLuSyU~EfmrN!#*T-{$KYO>$xf(ZLY4J%%H*U{ z=+FY_FqP3c>apupQ3}M1B`t`#76;#X5xSh}5~Wm{G~gLx#Nyi;FsZ9kI$6-#>G#V& zpulu(*Ch~zKor>{Fi5*cS7=_!p?eZnkQF=LyGLzveI$gj#(pRIC!q#O06|jYw9k-H zK-y(k=d!8;hx1)z`}QkuNxYv&Ebg}_*}UvEA_NH|a6}4bT!+(ZHPlXtq`*(Jo`Ez! z@peZR78$^boGa-_wf3}0R#HMat5D{`WL~EEm_U=1AksjV?=nPr+|BQTt_Zmw4Z2-W z5ns8te95wXdx{8{I~E*b&R~Pu!S%An<>b*8c#;&JkTP+w&pbO!dH}6$37M}V4p=9i z(c-|~5e#BU(EVPzqlof*)pkqR#DGSKf<$Y&6;GbeqA9m^EDM`B-C|anidcv7q{Dpk z9cJfCQ8qvRl&49C4Qcokl6q)E>`mopOeYJW#(h*_(=ykCK1g=eDsFdgcOUFchMv|3uLJwApu&tyNx z0O8NBtLHGIVGpI~-Ql2=WE#7l6Qk9&%rU+YTYzPy=2no)!Jh~+nHPS7aIPgWD@qB> zl7=>`5aJ8)8#UmOj|^JCS>i$u3w08i9c@)WhdXA%eRk*^gnn~4sYVBz=2XLu{(137 z(xR-h)xv?DV?6?y!^keM2mTQ;gWOH~F+?iZD<`&~u`Va?3dALE@mP(vlNiXId(+_t z+3dP6(gHm2hkXvQ;(_;^;oxNw23iGr_uQLcoiXQxQ8FXbuM!ugkk9Er&5yW|-B=&a z&=DB8`re#%kA0kf@cZzJ$fE#Agg`Wx;Qs^@_P>MKPbS83|9j)9f-RWkd)4|079mV^ zaSGA89prWhQAwUzx2d3?-vp#l7mnRa zGsO9TE+`q!SjmHJ_JgMTY(15-MVq>I0;48xK~A&k@d8&3c^O3-Y-mGtde=g72H}eM z;UjIqX2JBR!2d*IK3&W)efqb6*j&aQRJ1KxnK|%K{hT&tiO}+iaujs&*(wPIe?@%8UH<;G?7qvHN!vxD0T?^6e9*vS_zfjCNgQmLFl=hB4@VZJG@`zk9KRLHWj z9kqAeNVzV1_vIobZOa`f=xPvl{eRA)Y~`O;h|{nF0Pr6vQ5pv@DQ@msj~D7&25g1~ z9vgacvwM0>)O?!sB(tT1_v^-5gdGH*NPiuChAV~=&@G84NKC(9Vs~L8Dkr>F?M?YQ zgF`!dWEewxTXmj+yk%N>sgxp45=DrbTM78YQdeS7gJ;jih=E%E;U)>oshA;9;=O60 z9!y0KU2x4Bj@O{TNE{s!^puI@O{wC(nQLxXFD~2JnUKkMRCx?5okI?VS#8V>9#Jq3gLsTM*p z6D`$#cF-$9AekJ@!UCDB6qy|SKfZlA6RJ4~=7<}mPIil`DE+e|kzK{M8yJ#@CZ)68mEvM4zuAnK&|_A>vLt6knZ@nl z8S_S71s9*7H4)RK1n>3?R4#=ncJBe! z*x!x`ry|RbF6#a!#2lz50=UE^u#o-L(j4vDI7ruE1X#eFq=W;1T*$Y(K+2_K8#@T{m|olxWnlEaFWS$z7dUw z46!owVLv5V1nX3gu3vkhG38r=y6pm_Wy%V^hX=nXzFh|` zdUo!x@A?{?-Q>tw&YO4#bxpj3@A^wVTy_}$Wb3@`ItIq}isufo63+=9_X5QQ7>n0O z1&kfi^Vr=!D~g*{`#m;VnRqcC7=@Ftf%^Q7q^mP@4!?-`-A3UdbIPO2_VT1Y>}Osx z@T%}#JaF~TbU%K<%mmrKk+lf4$c53&V$S>NBsC2EaDl;-8TXO4Q3xD@uD#Fo#i`3Ma z8{U$jFR_5~iCPey^fk?j8(}UttsLp_YU#+b^`0!LW>RcGk|d!%j%ypIwDgJTAcA%{ z4Ih?R_t2Y{7M&A6GF&MQKjRAlN}ekgBSqC(-KLI}jH8{PQR3GnQL>gw3;*czyM(mi~x7Ew*@ zOF=aq@(uI(=uqybs^j>{DW@JcvryfE_moaX^q2%cab$7v1M>Xpi+RQZmW;K7BQ*nJP_E57J3>aPVMNDA6!%>y z1C;`sg-qS@Bk&Hw-K`ec{37w%uUyJalwc2ZRQANcZX_tnM>$4K*{j z{79%K(-T@X9|GdKA=h%=(q#~)=1+Dz)%r-RE0c$hI&DO&_E4lHCamHvOU$BPXmjEO z1gw)4wN(z9p|}@f!pO}?wL6-IxfR8gBtIZ*AkZQQ<)cMK zaE4y=i!*eyJwaeaTLoP>Bx8Y2CL;N4Woh%QSQL%1<Iv+1}wZwD;Qp1F0r47P$%z~5X7ZRE%<`K%ME632O>I3o4O7jT}@Mbv(<8ZY?-!u zehh)>%dy+3MDZ}wdz*5dEzs(TLq7(h~eEoNS z)FG^};Cq)+8)PB$qmsns1hL2sy9+d|%xCp=(L)cTKTR&cx)S`8x{Cjsy2v5XKHJ#G z*_9aJSY8b z0Da&>9UNO;TT!NT#? zy{+tzGs|p5qf87=iZb4@@XbrJw+=7s%QLloKS=WR>hk@jz+B)rLK*YEZ4_s86ptY4!`rO02NnjSXIB=V(ij69_r zC!n`i=2!g~&$XP%(*9XqO?#g_Js{!Xgx_W}>JwUiZKmEGaa(j3kCr^tbyDxy7W4XT zY3uh)h)3?gdpRx*V))irR=mV97OZQ5p|2mF!S~FAvE0FVd3=d}CDKw~SpO8aF@7K- zPokI8qH`eKX`ST>K*?)I3qd-W8sT|w5Jl5;&-}iiIBnhJ!IP{$E1>1tk+geiN4u) z@2vquGI~;a5p!nXZjZUuy84a;xgm#JJ)Ak^_Yw(PBR}u+*n}}3X383lXRY2fo4sWH zS!6A`*Oe`EFUduU|F%GCfqsNAUO&6jN}w}8Wh&U!06I!-Mv)8FT~`a zJK}D^ynZ-grnll`D3$&?QaEc8UQ7zSbK-bzn&RCa3GodK^ju*Y5usZaq;x6fsdZ}z z137!B2GB8bqMCsrpaqWGfKf6R#1jO+3}%7EC=I9-1eUK%FG~%30e)ipvCgvR%f0@K zCJyfrm3l=7JPHE7#c@fdfR*=T5^D1udA*V-jdZ9yahRc4=$5xQta^~)7zCK$Vaio( zUCHW_ab8Chx|J{lfS#i`8^3C8j<%^cd1)k zbfkX6+Y7K^9k70If( zrzpNfy9!9R&JTY-stq}iOCN+d+GD7pnAsnFpZ&vebxN7+ESWp6MZO(GJ@8IGLQm*d z+VUq~xAdmi4q)$&03881cn;kfN7wcHpNUk`BogUpO|q*XXC}mrP)k}SnRQbtg3E2| zM2nwvI@$1~UWwP+HWZ#PdXVxvP;N9PKj5G_6lqEatD}#r z_mZDM6uDm53$d7_rWaCg_2O6w)?Ltdg)GLUAjC{mNO?m0Ltswa^pns!)#A%i8&S5yId{VEsr*I+W6qAQw7pK$DDPQOf~2k>{h^q#EWDq~9p+IJ^aQYXM;ZfU z?|WqCv(D@qcb*z?piYbGt!!WW^!@U>F8y=6+Ixb4%k}uY?*0EIYj7jmk9_ZSu6irK za65Xlb(si+u$bMby%^sVS8IR6BZ~7lOc~Cs8(|u@M$=$s$NS18;0dCJIXLC>eAStv zxr=davth<$^+J#*`ViMv>@pdEDd(!nc`MV+IZ^Y%8Y!f=$t8VvEpkvZPe+eNksmmym4XarICrgfJ543AKg#YHpl>uEX+Q|O)ydTj6)jE6({~sS_s0Zo}(eMUb8A1^$i2K;05Y$#)E4je+fg_ zaZ|B@H&Deu)7Tu&VAOdu7|gXbhqF*8*3T`mGm&*5m})~zL?@PrN5J^lzsu-4UZ)@b zzA?4^G)@11fcBDaaok|P$o}K{A9NQBZ?k{6Z)Q%6k0Xv6?!UNv#CStW`ttHi{+aIR z^>8Vq`+BLO_7C409NIp6;6oWrUf+M38z|wVpkW{|_E2(PegIC{!OARStM16X-g*^t zAQjSORX5;vp$hbRpEIJff~Fz-^&rzhjZ7=qFk!yC{C=4wrk}y-bS)1dLtNLFSyZdR zH)H{T(V>6a4VOto`El{8-9HY z-IRN%_GE&1yW$FK*5rlSwEKkOdqOCda&NK$djN)ch!o?lDb?>`h`-P?=@T#oB{y2r zU`}q;41Y*n6@UIsFWnVcnvV?x2cE0($qx*$VD3zoiv<;S9-Cw4dVX0nc@<#Fb{fiykU&&b|?? zdkEnRwJ~nL6$sOliNq4{D`)zZE22LQEC1d7h{a04<+ExE3jyQYjTOVB{loGMZtO|c zTV^7yPnJdh4_jXy7FE}NEkg}Ggmlc%DBaBfLrUX=Al)D!DT1QJ3@{)_cbA}|q#&Kr zjnYyQk|N!F2lV;9?|WU}zx=^9=bXLoz3z3dwYDPY5fd+-JaD&Cx-G@fb0{e(d+sCI z{f8z=zA}ETZBWjlD~m^Bg#Dr3O#QeY>%1sGQiTTAg%+@vB&jSs`%PUg(gpS5?sRVu z2M?>Y9h73VN7e5&ly97fJz4`UZN9;cP-XxYAh4Vp((YoE&1%C}0Y*QN057n0Gh*f- zHqCZ3@YmU_fr=n|S~PZ9Np*j3r9yE(Uf7nf>$ziZ%I7{f+`A{Tl?H z)WdMy*Y&mAd1YYA>pbyO zB@!%lXPE|Pq|@c=JV$q!74I_GSeYjL8S(FY*cZ7TxU7pB<^6JclaYWEn3wa;gb`hb znqsn`o%^V{?|O(`uT~-4V26rAYtM3ws}6vHTo#$@@uty!7F8&&sy&fx!o~eme@bm|6NDD@-|h(^LRzR?>qCOtcu!I-8PMuKaLS(PuBl=LqE00kH7d zF8mv|Ko4xf<1mM!4arnmyLej$E$=))KD&I&YgARH?4DJ|s@TFO0uL;%I6QNE0O3F~ zx`};|cDT}Nps>#4-xD+Mb58YeSB4XGK z^2=rRwZ@l|Cqh^EcwemDEx3BeVE|mV_UHWsMgx5vDH*4}5a@Gh={(pmn!ann+>RhP zC|!=4~5c9>y995p24&vU3SFz)GV6p{&s zX`q9~X(^>{mp)-At1A=PA4erCrTL;&b}xLAvz1nZewdMJIhO~NrQo$>iy~gIf~tQs zX!(~*eKR7n9IGFmzK>uXbglo+TCM!|dTb{!VebKUBg&Hu+aDFIC?&>Qs)kMTSE^yxh2oA%F3x;qpX7!2?Ehm2xe- zj&i5%ByQD+Q7AMTd*L?0mvZ`)fo_5Gl%S=yH%ynLzEdB5C3?BK;l(-9%ar^%J6{7; zis&2!HhX(N_Uu=a`F7*k!3lMxoKF>~ZDRoy5MUqglLSGj2%fy_CUt zLqgYXS||*l^ydQ|AKZK%LNE#jJ(l~9cm88alRfsw1O@J9J^BY{2mO@ zQFTutp=9gde2VfDh6FJ2E<`*l4<8JviJ%nOR$10fC5-TfeBG4T+ty`l=0k>w<$L+J zxP&KhdcqeTXvqUNjEXG3KWP@=6rapeatw*;@PS|~OnAr=>h?%LD870FhF8{QWR1FG zNMPINUo+_uJQOgVg<*g9`WJl9YD|t+j9Diq*(fhfwG{s}!r8LGHZ39%D||o-vRfU< zh&!rZmGKxieJU|Yf>=xIL6$bw4p4Jv^|F;4~KU2{P(i$WWTE>eTn1;hox9Q_>z;IDzeYJ3u~;%)i%%Uk`zB_z{%_U2rZxvJpi)Teu@3Z?s2j~@?o zqmMhp8eu9?j0s!Q)xFbmm)^eX`dN1x8g&0@y<2eyVxOI>!UG6O{1A zA(0*H9?J@S1I}ZSlxSCWxtUL;^q7s<6R$Hlux6azXs^ptY=%$cMwzy#>z`%s-3D0| zJV!^CiK)L;E=XVWdkZqt=4OEgZ4dipJS7<lwNT=@}|Id_uuuT^B~`jfv!@)ocrfpFRjJPk7kPYv^+E8_uTFK+3izk`<3y-pd!u!G6oQ)vF>+_%HUm^f>9sAC-!0;__ketY}80Z^PUA!Ebjw+4^erAt8{ z;%)BV$CeI!w+&hNoJ+G1j7M<0Z<34jjp10OUgLa2id`6X1%2x6zQlRI^IpTe-1lGN^|0g@Jf~sOFF#KaNGEr!5zKYSM20$?T@I7sGtG z?z0(qBZad&7@apeu_}*>092lVxhIoZew)@zaa8s5*t_o@H1rH3-KG&I_d(3~52p;S|3jby z7{vx%gam9V;C5_NQIw5-Jat^wlADNq^8oH!TM?Z zSXgm`Mv?-w*?27ys?5=T$)&(WD91^Lr62BcOWeJfz<-{U;PatjBx>5>^Ro%Xfx-uM zrvtjd7Hqz)>4D)PsIY^#uA+(mKttOrJcH8>Q!al^inpshbYhQ*Y(`6ZCq*oTm20rf zJbRXX5-p#oq_cuq;)gpSvHl822W--|li)A{4cAr_Svxfw(P^oQBM- z;I^~AY}m*vPtNZwSPzto1G#v2eICKdNlWGLWSxwOEtUD7CevRX`5!RU z(0!z?T;BI;x6uA<)y}8Fg;@cu2R)xRbXBn8{*?Wq5JmW5!IMz$+ zj3#x>(C{ZKC!YP%xuvW;A~fz$dbW+6kwSD|sH5}+{b?p2t~<{DG46LBNu8;f7ILDJ zljbKWVc1%uS^Wf`n9hR?(oS&gBiaMwPTc5g0iKq6{$&|6r=;3<+<0^J_`w%q+?)&e z1#Z*NnMVDO&$j4nrnP6F010v`^3CaH$8mySz$BBIRVPyTtlIxf>?G9x7`J9`Kx$c} z#5b~jzku(F7xcVi32Et#ef18B)qZ^q#mRo59FO_=B9-mN*awAbqCb-aiog|^+g06_ z0;A7d-3wf2#gi`|0in5kbZjRu7$hjOjMv_WEz7ys;Vim+LAxH@p(=`?W(u%a=2yet zyBed4(QVN{*Ram+{FvZVRnVB+Dq7dJ6k_~JnTb2(4Axe#8LH|}$?U=4?nRT`&gltP zWs9QwhDHk>DLeIUhKOgz>I+*12Ef-4UsWY0iTmsUBI#q(`;(zC?pNv|!M+9~N*M-} z9{Ch4`y}hKUaOXU;oAB`zk^aql)(3%*~(2293SQCIw8y3N7pZF^e+wY^8YN;B7ipThap`rFMT{0|_kO6Jy))=L z;7R{ntLZnVo7#)8?rych_BNL~07yn0d@6aNi3g^lxVV4KbZ2*irFUaUy4#p814^h1 zq?Y3>du!f)*C1~9Zcg-s(77XDE1t(jv{z+EK6u-)$Xd3IV&&TifUlGv`{ zs>$DmG?kYA(4OENPVQB!UDCuaZ6fLq8F^bw5tY_?Hw#0;y*YHVd8aR&lS#obMCB0E zMYICS`I29dW7OB*B+gBE`dm?tdC+Ij{9bupuIJ*@tY?~kiXN06w9R;uX#6|J+PIgoi{g-$C&Tls z<-wKn-OcwqtjJ0je6j86u{lvzZ~|`G->+1(aKLFN<9_ysO4ldt($?&L>v|sQw?WiI zD*hOPA-~u{i+KBm7*nl9BTW9Y)#iH}VjnH=NZ^vZ{M+engjM!PHwtl8>uq5Jptb7~yl zgu6jWbzO~c@9{|H+}0&6CcG-qvJ_#0yLZ2Vq>^AYelR2p*XqdmG#`7{j3nRfR&JcN zZ<%Ha(N78ld?L=RA^Rt57h%aGNUfPPc(|LMTdWu3O5Jzul{v4rU)wLW=^^$qD@GbE zGqIf2oAyHo3q<<)Cm5~wmbH(Onuf+mcNT&kv%)09tDN2hNj=85NImZJXe# zG_NLp?C#JJqso+!(tMQ(Yo#)kd!HX?zAB(PN(#-hV7&eSP|EL<8mTt@t%Ag;|TP53&4hb}3J|yW1$wmU)B9w{D zc<3W8D{3J@@ClStT_VaO!p=#-W~=876W)Oj$>Z9Tg!lHMSa@JNQJvC}s~5*Tm!Isz zS%tH07jUX)Y$KX*JIC0O{TfIw;M=M7!>BOlIO*(OP@&h)J? zveT>PCw|UJbCxY#B_nWn!?$qt?+Y{GbJpF$@VGR8k|?I4jPNbzEQpgZ_sH?Xnj~jg z@h`FN{*qEt|KnwuUr@r_%iX-J=ns(0uQwbGoYX-UM8$y7QjUI5zj)gNUqWJKDf(o#B-1R%E8HRAGtlrL|*xvHp@(=5vRc6a5R= z1U0jvx4@8de`ws0AN&w0V0Z`PLBt{pRMpU6;kLA9?4o!%%XRnc8=*M<>W9$f5iPRl zX>Eyl=I<+6);77_JTt{LiN7Ze&h>r<*Q%$ByDO+7?%Pu$THM{-bZYDB*b*LIzF|x2 z@3}ad<6}A<3SHVTzpAbL^r_o;c9-(xH`A4S=LcrhBUm3TfYy<*RRB8(96Y|@z6?n& ziUPm=Jf=ei;phM;9|l9{gg030Fii+IRSnL6X33rq{kw1Xs_9(=>Gs0U1pe3MA%Kmd zns0h=(sbff`I~LFpgjGr8Lptodb03vyLJzg6w;VaxM}KMpO1bWrJ$B=@kJARk$3lJ zp?t1X+jMwX>(V;Eauj?}8=VGfHiG*BegXFCoC>Ug9HOW*OB{sGB(?}s$O_Cg+<#6` z$9+J70C(k%K`Ofy#)qEOc^`LSW3GPw9THado+y?xOUh+BDD5L;Ya9vu$kH!|6Z@ym zvC`E)+VM=j<7fF!tGsHJS%V*cZ{+>as_S6o!rX>QA;a@{b5q~^j?bE|o}X+#ZMq9A zul-!A?!C+->B(~2Ui@Qk+RcDt;{)ZAZw zDvt$nS=GVegek$SMsG-`<736s1vD%RvT23^cg2UkBL2ZImw@IdSZ4J;?Q1HGZE9U$ zTIP*Gca9xz(-&q$9=25uuLp7}R7S9)p0J;(jQMq8 zx@VOWJkUJ=s})t z=Macu7k2&p zVZcX;Bo`ME93^9Kb^F;0eH5l2gtGV0C)LdIRJUNLoVaZpe#QGcHexB@|j!^%KSGb97pqP~|Iy_0Xi zopNG_(*;MWrjm5|{p>z*l65nia!D-AK64{XA=f#7GKrcG zwj6pArjUf{yc96szxw4e*w|~+IVi3DF>p_byZLRO6VD$}K86wiU)15Kw(f~=g_y5# zpSP-}QIjNIpR<**lsLp(XX}fzg2lAY@8x#zGag+feX(VD-^hufiygvvBLOZA0Y}Vu zR?i%aq~&7?{K%dleh~2_K80?sd zNp{LjS)^4y54I+_*7~@b6}X4<8UFPeW9}8OqZe$|_;}pk7aL*Zo{@V;9sA`L-&d#l zy+{9a2y>#L?WYr_KHDpdKI#MSU&ehH$ z`|M8rlXyLfqqaoP8IDmb{mDEi)S=E1?YW0t^UI`ER=K?5-pAAa@$PRW7cOn8&mR75 zWsR=IF}GqV%A`G4B4(a#e9$_?+!D+>tSsz)GH%Kt<-KVba{@s|3EvaS+|TFeo^5!( zk!pm8uC*A|`j-?C!UL0$P~=4nrwX~)VHQ^z;BxCVY{gN$a z$4aA$j^xMFMAOuJlF;?pW%Dn`e_yphpte@4tZ8r#VAKSh8(pECS#D=$s#)A092DP% z()C$D1rHEXoEchZ6=nBMRjZFJ(WxYr6ZD2%_sEuzCrKtozH44u;gHPygEw58z7rmd z0a{cQEKpv@@Le@DV0a7IEqWUUbljQV8+EqRM1wI^Ihq}1`$2!D8=lu)DF!{4+NwNwSVXiyqe@nKsuSK4OE}?`%KQ?qhuiz(`BnV- zp#hl<&ZZQ3+Yk!O+36+%dy|8wi=TrhU~LaqKvBK2p~0>mjF9f3EhjxSn2UQ)x{zt@ zjWuXmq&sYz+{WIhhbH^fmrF}U*;)yv(oIysKU86v@)Lrvt+6Onh6tU2t|-3M$0{y@djX zgcY2&&g4CZRRM+p;GwbhxWFSL69aCH;AlD z^BV_Al6-v*_zJ7DCu~YYR{d}eD&U+pIzdxu!7G@8(DT+|zhqwZ-#@JkklH^+@S-WA z&L%v{U5Rb!qjDo*dx6SSaW%QuTsl}4+@z{l==8I8V|i-sMa90Udx5!r!_9dF2w+&R zSt7D%Cy@v;h*NOoK;F$8qzMd_JIAOw8mmsByKm~dXi>+aZuI2Qq@^=17^yj-nkIo? z231a}$7j9O4N_DHmg4LVUwUi<#extx%i)t*(WXMYo0GRE$T_={+eIbtQ{cKi?uM4- z=m64!5~-5$RcRDU#XsA02~Q7TuI*0EBH|A+bRnG~O~%S}%=GV)G(dTCw4Q>D?*?%h z$OXbRiKP}(Z&O^Hmdc1QRrLj>;zMj~DHR=a=SxoJwK}ReX1^{ZhdX2-pc+}Q<5yPza z=#x&UXKb+Ts`G6}9U1Sh*5e{=Tb+bq+?S$23j%%dBNlSLRAYxF#8mr%3ow8%$ZdYk zfPWZI#dX&uT&zlmIhFxLzmW=cf2ckH@z#PHAR0LF#P%5@1XA|A@3X(Z(cHpKG^74Z z#s+Pb^>w+C=ro@KCIw*BZFW?w4*yD&;r)jQa0Exp`j<$*T(w`G=eQIYx@}#=vt|P) zR3vYuPN7x>i(A70AwO1a7I-xeQWI)KB#2L9l=I2+V} z&XU3W<^#HcBjkLV{FZ_1SS;B;kG;(1fjxgNIFj+Yt*Et>MmbS&(}N60_7k-sts44R zv~mM~_MD^SXY{d!o3|p36D`48yoFpLVwX_*s>XfUY%RoB3wA7}J1MYX!bot*NsUNW z5~DqF$;P(Cx>|HGnE1IsZ}WL#rgeUXbj->8!n%X}i%{_{LNZ9-GxXx)mNlqMc-;r* z#WSGe*?pB=bZ>6~EX)J~B)Q`8o%pxE<+l1D5uH@mkG({>2IKTtX?Ir%)Y>#T0s-Kpx(W;V@yDym z?{uHL%~SJN0;N%LA_DK#W#KHJ)L&AQo(vJ%;d2bmXV0btG3FFP*&q+B42D_IoRhDL(j`byu zzeQh$&UME)5>0Bz7AR=`HQh0cz^d!6!B6^%6`5M9jFs)l?)U7-U@ zE_;iBnC$AVmF;GE^=0d@v20K)jLD7l3>$9{R`5*e^ejv~dKg?Q%;?kk^(ip;&Len6 z!rcO8i8JLzH%klHXHhD`Z!40J^~3(9z~zbE?o- z2ITJYT;h+fd3Y@^Sz&Le@KE7;@;fG*tG7Kt{vicrdQXykd?FiS2FvP^`It`_i+=04 zR$=}U2zCx*6gSav7F5^SgApYW)S6b_#_Yl+bk37zzM3ZlBj>A@yK^a+z`PiiD@Ym8 z?gDoguHq6@SR#g}=Q0eto05p|i`CpfM%_~EwJZKS^s19THPhg_y9ELU2!I39kRU!y zJd_NgiKQQ^chmK%fe|pM_;TXwcoU(o+j3He+hx5%rQi4G+B>UR>M#(9WK}a9O|<4h z1wN;S1dT3vt7#`qU1yAySJU{JQ3qu@Eor}3U(r(@r{R)u+NXj{v>C6;nT(6W3N}MI ze+*lB_L>08?Th>aWB?}5^?>y zIMn7CXT4U37n5vDzP~|^F9Ao9EL&mYiAc>6gKNV1e!RK<>J&d1?Suq_EAt1CpJik?IIqeyX6yOodFhky0KwD0~*vHU--$YJ@_bdY$?BkFX(mZ-0z^ z6YTzc=doPvEb8cZ<9QW1xS)71F{2v?SalqxoIz_Q^ijlA0!yz^k&VtR-GDLoT@N9c z%Dw4wnaf6)-ZWaUf(ca2XY@>0GJ)#(Sg2vwl%DK>XqN+L>iFgH{Q<7J>@%h0&)fa} z%nyPBhW;`>hUV8F!~C3O@FNz&rmGTp^N3DCaBnXJ@ykC29Q;drih#MCc}_a*@(h;CWA#66jKw9S@-(^bYD8j4u2JR&x- zez$M)afo4jFGfmcdzto}yrEA;N6nj|rMj5`!7CEc_$(OPlU_3tD0pJRqE`#>X zX}7|>^I^P6`L*BgfoA!*E(4!dc-p>wZ0sQnRt&R&^Aj)!7|x?|L}0N)s0jj1G-EVf zw*^^i0}a0Z$QvQ1BMym~!*5wXym6~Kxj5o-0(h!VzI#6@aqHAC#o!r3n~%)K`ET5& z+>S>m^3|T8GFvAH$FQ8w>LY17*tdEgk)ah7td7A|4mhj}tN*l1_=< z|I4e4fdP%_82G21W*!!zBKoWQpDoCM)k+3Q&$vx~?f!rLvC>z-LD4>9!dBL0X4Zjj z4qp?TXpMy}j#@HvP#9uhv56Ko4DeZ5dky(99$^HdH`EHegc7QA5<3|1-^~?W+2l{&pohdZk?2pwG^=NNm*1WE;Q$EgYaJU zAWg^KoZj;VvC)z!dI=(|l&C>gj8<_+=@|*ZpmMf%9J}JPmz!CMOwX@WSQA0Rb|Sxy zRs<311Z)q^ z{`nPmjAa^q&O*aT3o|&W|C_E>K`%xnf&-jgxXw1e%jN(f+Rje$zBa6wkKJ3TXjjFz;hw2;w%tt=Y>z=Lkwdd8xE2V)rgZ0b5t36N2~U&3>V;_C7l?Ux=)&K{mAO=zgYl;9v!86iS1Q7s~9LH_*%#} zT_B!fiMu;n8Ihc<%%zOMMXv-Qt9V6AYI*ukHH=h%h}nYp9rgk%$@}MG-s!k9^HY9f zLVB%tf+X>(;)?B^OJphHHg7#qd^I4v@@7_1CMK{$(Qf!*#a*V?ex zJ><6?Qwp&>sR|{+nH+%UwkW(>>-NymR_^f97)8o<2V>evh|elhsdtHhQMf>m2i}8{ zz!XyM(VaW*ch>!Lvhq4GhtWtM>PoxlQPOlxi&j=6-h0{~(Q^JgU>bsWsOKqeMR#?Dw^x6B1SO^Xh?Nf3-17_Iwx}zR)^CRWS zgLAs1!~)#E=#vf;*qx#sp5O+2HmHJP0nyFDWB33o7Xq27E-Z^_A~u3}b?p$jVH!{g z!+3+)5jqR9f)tC!dJ6!T#5GOA&gN_6CK^(wyiMma)12~J*hCOCw(b4yql@4q61-UeW1 z9F$36ht{h>Tn3jTGGOGDmwA))#;?bo?Hp?XIe+#Tb7|Ri*Jmvg)%*<@w1Q(q|9l%o zs7`uToaa;j2aKm_N8OXZ&L7ODsw(Ukkb9^-eP|fFK*imHHpPexg-CU?X`=59$)bJ) zTdz$L5QSy9#0bJ12Gx=X$IC}uFJ6u(pTN(wRPMHF^(@_SL>zyxY$*98Z=Or%of=V@ zXG0pTHeWHX6>sDf#PB`jkmYMgG9~?bzvQp-s2CFfo5h$vBQ@S)Ay|Gon%_XsF~-17 zxhon8zs|zt{pv>2M;@^x(~|m^tO#0Kk~H0i!=c<5KV`6R(73=Vi~bH_wOWe^)tjAe9ibq= znhZ@ofUp4{K_D8h&~u>+cL%gA0Q7Tif~tF#?y%(xNJaXg`M5?St9|r=n=?CylHGkQ z#P~;<^1YhDq`NOq{u!xq^n^P#{7c+00?l0+W04qtA)xUKC^Y@E#DIIF{kQdZx5OHF z_cxI)YQ8-y^MYtVrf1kNu@lf_jr z%by?QH4ke4aCXg-*WM7xAoIbmK!c7cGYduQ|mvbg&M`dU+4@9h>QBMxvPV-y^} z*g4?~tJ}tDw+0YXMJh8DiK!niBZ9IK5{Q}(iz}0^6V)C5>Ex^dx(B&xiXF-t!PLg*neI4J9 zWvc)6J(fX&>N5mikGPV}Mq681az0{#&aK#jD@0KNOL$rQH9NvzVNps@p>81Ca0E4h zp^3hAZ*Zy%zbtu(4Jm*bmT4sya3@UxdtexlZqo)N;u4qOW4z#6kKV8Jo`1~;)DlgJ*Rur;20v~$ zd5VG=Z26vr*YRGs64n8#Xd=bAz-8*8P;dm6AYGNx5UM#EFWy>wNeLg6+!cR_0dm-Q#`+<(}kj1!x~k z%MOZ75gr?<{Z=y}QiZ!MTfN%PiA9rzlveMrX{h#xk`+rVc#|sawPc#GGVQEU`4;=z zFxHYlp{>y*VTZ(%FWWqq%9Ad_olkQLPl!JJ3!Fr6T{oTqp%7aCk6Wu5pWQB`OsA2A-68;s{r~NqSevFXG*n%)9s8uj(fzA#5xTiw?5VuQl-Y z_g;Yj>&$0U>07=1{JC-616pf!yJ5Q#S7$HRCK}sZ9sXr0OXaW2H&Pm1+6xGcxb>#~ zgcwDJ!2!WqAd!?3R|&%-l13PWWQ^TG!I>Z>C*ZJ5Q%wy8wEy^}j0Cw!-FGEf$8LX5 zXq{Z*Bcq~tOOG^MeuVqwPFoR-Vt70KJJLC9xq$hCbNgpytMdr#)>>bCok>IX`!7Qv zFA4}NOVNRjHU(o^+=zHrZ1gjd^gL<&>Q{`_9_+6I`n};pU1t2i zEJ0NnaWtH)#q#W^PO`*|+#xn0wqqj(8XIAqHcJ^KD`o}5PmdPCS5;OQ8*GO6@mTqE(1DAD0%VWvMUgsl1qsI zv0`3yhde%oZUIvk7+oq(3p9)1d@X&ebjg}h=;{Jl`g zT4mpQBvQ$3rSV9#k7du=i%6K6u-CA9l$HLJssh8%L${;@`5Y7I^=6S7ZI830F2TBt zz$*VSL7eH>m++gayjPhsQr*-I7bJCKu##-hw{O(?c z^IHsJmR*vI7X!R`LD2I&7_tPev zo$e6K%3;>Y=vIMye5s_iXbnoo61B2!^n!zII&!7$zMCB=^mXn@^C1h=NDJ;0Ta^a; zOf>0rF}k$-;5^m&tlOSoOHAaZ_XPaWB*2e;xS7KBpQN;K_$^SV+?krh@!ItA%=asX zq22T4*ESl5`!+VW04gPXhaeO+kq&to!j8TE6*VuZYQgFaMS^tRIu;^x|8Zw_6jXi+(K6C2gYr+ z5HoEr&;X78+8!}I_arB70m_G{5@%Iaa0px}sW&7;o!KiA?B>?_$+0l(DxQPj^(^wa z>^{foGHNnTi>Dx|!d^>;?Bq9|GqYno z-yY-HmCTWUbpw=$n<*Nyfqs4&RnD4VxB4P7YTQ>S^=KzkAS%3 zai5yPCfrEr*9mt}c39DtMGFl}EIZ;`Hbs8rGpL@gqspinI9xyR9XiR;M)f{9*+0F8 zM{V1Fy3cgXbQvfG8Gk+=u5308?p!I5j*7s>R|z4FC{C&O`fSyY$26jT1?kP}RXU;G zj-(}RIY(+NtMAS``1IpG4I950Hnu$)oa1&$6Rj)U|5uHTeXV`?!kGG>u%B!&3Xz>Y z*!#QW{QdO_-IH>4YS(49h}JV3qtX&Cf2oNE^w8MT{f6c~GI+CC3?fz+UAY#!*j?a9 zh+^d>o!2oR?hF@YBA419COAbq6tPi`K;s1~nhe;BN zRxr$a=t{?Ya8jqUC1>ST@AS}GK@`cmJ+EA;97M4&f#+CzcwQ@g_4_$hbFb?A4t>o^ zwZFOE_wPB0ZzBJ0TT#(#M6|hDVf6&SfB^{Y8<6WJ56FwnFAAi0pZfEBsi_${`i5;E zu_5|I6`yYP-8(E#|9FFTUhJ5mkV+IJU+Dm@rS>IZ3xbAsqF=o4RNHn6UFJ|%@F{O8 zd)MjGA@D(t>zSH;IM0xyYT|f+4!F*(xO>yoCh8(H;rD7BN%plv_8xbQp*b%|p zIO}34!sleEBpf;l@{wWD-vPZB-bYa&006LjGsAB-W7mIgMjc#GMQ2p^u>VPr^r1FW zTn%}rVpOw7gi8k3f_-w43VHH)AW%Ub_{|2sdapL_UzdkJdE69%IHbod&u7_4 zUhK%|_2jL{+IX@5*p2UJhz-%Sr%-P$Tbqd>d(D|aIUXOtiy-8sIoLm_{m`E2twzXa zr--uDtH*Kt;;ChX>BShG673JBC3!O8`0=xgfn)Pw#WBQLG6sW`YiNAf&xf$7$?w>M zB+*n_v<#|J*mm{UIt)cr8LyY+gwQ|#`!kc}Uf>_5nS0zik}d@G}O!&hQOO5{-#CAW@rT@cAPQFQ^pTkIos&*zsmI^ZjB*rW5hjy z$Mn2)u%+U#+w*u<1qt(;@Z2M?9n}2lcU*>k|8?arqxRuC2H?YfWvRh`V`~S{H6_e> zG zea)JzPhl)pVKeYCoc6m6Dv@>r73$UrTu?kqWwG$i`9}f_W!$Y>&ZpdPF^~~SL1)DQ zo`K+`C2{)~u4q_{FjLQvUzf^=pPTD~5mSR{)kS;7_mV%M1wh7R;GbM`!~glUlps~7 zrj2F3Xuv`+f90!!_s~oBiDzM@JUO#^yLSNGkfJDJWYC3fE<83u=C*FPoulGsYrYrE zEX+@yB(>MxO`-a>n)G^nqxpxxGE+*zuFt+RN~gM1oaqT;ILScp`r%Hcz@?b$yufFSv&|1?et=iF{@As6JNBlLb4UPE3?vq(TzgamcKQf~1Qo=#ieCI5C>TuVaaSs<(?NE|D$9zNhI!qk+p+ zgsnNJ-Ats`O|Y<;0T-YKj#TKzkcHtEP-~6Y1B;V)jUc!ud&>+ttIBVXc?ZqHAx80L z2p~Wv_}dB>fR5YE8!2S1gPf-K6}6GuZ3MP_`%K*877l;UCX?&Vb=Ydq@V`4LK*ZmD z5!H2cxqsDljFdwc;|YK@(J#7C7Rm?*O|TnwS%zD6b|&UwSmaMh{<}f8G?eAAVl;D7 zKT1$EI=+4E%WkwJ4)0)g-$IuHgG}02?B_4+B;sGM=1V_fR{+@~Ke0by?CIoJMu03C zA5s=>0Ox+gqI%6@LP;!OcQu+P`vk1Z>sd;3@O5Sl&ms4JVl{vgAOdNgeq3h~TXMO6 zZ+Ty=iP9J6k-YSi&-zn&4hr1t+=kzwyu41Ol^E7+yiQ3I?4BVCOKI;474y`nhs}=1 zD-=AL;@yE9Lh^^sqGX4Yx$#jf`gyOP`%)XEaoBkrcqPzXE{qDEM=(9zXQDM~Is}qe zuLs8tJlB2aht&dJYM#UTn?!4XVq*vWAd!JaeLvKd0ImaI)+9%=BdYW8Vt_b71lg^J z1O~XplGPRY%&A4T2h2vV7YvlzwjP(KXJVzU-Mu5^zHddArFc}ra)~X7a!^525YvM=$Ur(&c(IN17qkpR z0a-vw+e=!ys>>{3P`1t;=a=lry%D%gDcW&23cg31Gg3#dvjYzm(JTGE0L&5@Hk|dyBaF_K)&L z(Y1hoIKB(!q7)SZ&=!U?`z|_$f@4Vm#vQr5n3c@DW5wXK!=!(V@F-Hb8Tyrdg)?1T zUaM^fiKeOOdXw__%CE$1!CSf&;|@LmBDfPw-^C68aZi_LdVfYY^NW8#&A@+yn$jL{ z{44BhcHm^Hq!EE%Q4wR$vKk%Mxda%V8OWGNZY|xVlz25!;Fu$eP;4oqOb{T4yo(SJ zT4at8oDm8SeJ6TST#Z(x^l34Wz14stX~(}#9hFMb=Lg4DHqH4jbtW(8^-?M81bF;^ zlTtZBP66Mgr!U*L3iXEhd0wQyZsc(if()A_i*6y z2pHo!v7|!MqiG;xP*wCx8-?^$;`N93TsUG$3zL{1e?>ZNz&x}e)cY+GZQ7n+$znXqe(wJgEVXj;_6pHbz|9D|BOpj{ffT# z@^JNr_4L}FFnXdy@|RBFRla#`g#elGU#~5qIficrCg=03&SvE&se=My&g;;g%tOgP zt#UjEZ(-&U9g6j7ju&Vuo)X1on=r$(XdFYTYnmH)@Ytar>^z$iHXkBmp19Z&Tq{U6 zBM_0Lnh9_qFR)(IBp1#a!bLj(jRVt^C`s?@Rp|#We{T63{UWRU7lLAk+-xrR{}iA| zQqE;h!OP~0-Gc{B3PAj5jQvv)SZ4iO#afGRG!n@t&8ftfxS#}qi60WAkpYIQf>I$DrE6G` z^AOTC9A{C6-T{nt(P9QVIEXn!j|i`ns^FsVUVL(Zm71cjw29g)84M7!jq#{o%z%ES z2LO6j;qdv-v%nl%l9$Rp`$Wf8XVtP(TTc7q@$!Rv0DpBroOkiJi{!oTf`L5k|4Oz5 zg7PDpuR9&TtBhaKHF2E)i&7o<(bR*G?)W|Kw9PGInzu@6eN&>Ajr!>$QxCbi-S%l| zp~9_wR8^Y)uf4B|i+bz(l@vr$LO>-1K~j`ZkPtyh$)Shtl%6*uSKZMjop;Vap_C$x!98}x3`IjhxPy(xbDCp_2r za_)(V|Cc95Qu`aaJh^krL$T#>QZFvZ(QH*$o~*NDcJ_@*ywX0soR0?6O~ZuT&J}Il z`%3%u5iWrm`4mx!2GwL)Lgl%5#Hf(GGs33>sk%ssQYk|Ahq%T{_e5uA*@Q;ZQ+nyhH z((SHhr7f!VxIHm^SN&{r&i)2Be1DCg>9WmV^KV_CQqTW2zZl+)qpQ3OBA4YuHjxGR zAGDXHx=qv-v|eWtzI%H)moTuw)tIjH^#}2xss|Brfg6O;6r`r>R*2U$#4l_0$BmdP z&DeT&Uu0aic(UN%OJqcUi)O@Y+&y&7ZEp~CU!7=K(9J4uM*olIOJ2Cu%>SkN0SqtK zKIuh!o~}`9KJ?7DmB*VM;d{D1f>m%;O_et&q%Bd+PRCG)fu{Eo%Y3}m!#VC+-+Qze zwv$hn{lvEh;b_+h;8F8G_a-GMWnN_wH^#NB(u_8}@RIUI|VS#Rlc zLSJs*%X*bW5;E>G?$XY;m`#W$7f<}Hv%4T?Yl!J2)ab#ZM~{)~0?KUvvBb0_7||pJEnAR@v#rR?8;e4OrBM={^wT_G=A3Ka zsQa%xzm3VC+5?`u254M>n5Aa6&5A!{aMV z&-dmUzpt9=b`~j9P5C+KqoTWrOaFGA6}dm*n@fBQ-~h)MLVOi?-E(LN7YT_p85y-P zH*YxiOMJYgr>dBOq{0sy67S9r%)v7FIe3`|Maf7i+N>2-t0;r1qX>T+Jto35tR1?3 zIE3FHye{zjz0Ph*In|`?#T8Fz*;4}W9UPZHn^{UmLi>ebiR|2H0960*G>-ez+3&4Q zXESJDaYH+;lC#$4W51Z}rn85qIUU2FvO?(TUtzM5 z1n86Snqz(-ie?Ar%T>EU79u11{db|P89*dHmX^*$`^f-M~FKd1@ ztC~zp#Bn#uF3!ExvZB|7){)a}^lek?&5y(sR>&$-vka@Ana97y)H*y##V%oB-`DWN zmgsz?fXQfxkyDRH?p{_K*UP#lDUR`UFD>Gn3fGwFv{3z3f^pt79ZzO> z`c4z=S*J>}4YQFP}OZnB(X@#Q$(2Rfg%tNW2|)ll;< zy-pnKxHIPcK8_-sbR+(~f+S*>{I62bV~|$1#rTD2w6Stw;?a{}_zO2Q)H+PA?j4Q~ z@2wv{7*RHiAjnc9NoIID^MF@BPcPhb@LtU;DUH4{4l@(W0z0{|KnT@(k50L4SKVwP z(UzyL9f>q*gyGb47^%Fb?%imUP>LW=nH6fCsuki4khiQ2&*54(i)h?8{v2%hJWzn% z?-C{+)0HScsj-l@K_+@e5~i>a@R5Y7{_=xGo8HcXwMJsoJR4%R1a7Nf6l^VmKxx*B zTFIf5JT{iqiIUJxE>u-pzK}1#9shNEyha&nVV<2ELrFYbrrP${c9esvxS2^pFy3SR zZ8eOLxQ1=FYJKSxD`}naW?=~y1~$pZml$Luq(YaZ#3VTM$*I66A&G73qXmu_>?vFJVJxUS9T_H z6$6{<*0zqc%%917pSkTKvC`+o*|}88QXwTu9dBDt=yZ#NM}O6J5d)Tkt==fSgdc)M z!7;xdE{+O%ofV&L^Q%M3Ri z{)n>4`S@+Po=`U->>fw*#DbNbrTlKgc#CJbti%oT1{j~MVcIEwyJel_A{psrxX5fI z7#&EO4iG&ZM#ed5F$pF`@@wdASG1+439WxZ=~YhGSXnPFNY=VdC%#{o$Kl3T5-ole zx7~t*sSePnmL}FZoArh9mVWQ!)E)Vtjr_4OkYf7a=CNMu)8|6ju6kBGpVHM@Wa6Ak zqMgroNFT)lfJ7uCjcAl#$g@;dX0QZBEo|gKiq-Z;^OTzJ!%-L8@hIfX2nTgAqVOxS zXjup$_=>r~?@pCQt5F?lhS8L^Y18sBY9nu>a=qlxoyr|QL}Xc)Z}lnABNDMTd@%61 zKbYh%O}a}mUVR%i3*TYX7C!c#TCI8e)94ZG1%*)V}a9A{w7HLz&g-)rj@X zuN_2ouz^R#b>p?r!yj967B$CBMRJ>tv#{@Kk7ATrBi9RAMl z{8%xJ`Tfx(t%ww}-qf9H@=D@#JHXHuw4W9^lSOe`o@8!NQ#4gFXn7h*4ACrjPUO)# z%#={-61CTQ^j>CW8fhLgM^3+`^Fa|CiHynZuV~Vt8FUHL9g~Uv7i;zw7Ji6%XQ(nG z_$W_FjFUdT_2(PTTW7g;jp_p_cNV|ED0CRJJ%SwEOBubU^zqc2?0aq$|K!W ztZcWsy}-_!P&jJaKW~>fG<&#c4C-k&$jrG0Zy9Nqpg9k0(K-F>|4zSa8vDy$#|LHe zw&ncg8n!9w9NRX;RuDG3VARoRemyMV4QpvBqeJrqd+pj0Pj6=K(}Db*$So3L)h-%7 z`9uWDzM2^7@h){@EX)$#jr^{QKBWkV{m;B`1WwyT6E;jYv9us^rnU2Zs`!YvnH)X`rL8 z=*~GK&RVe-6S>|`FjxuIuJh7^cYYwkWV35Lf|UBG;mp;QLDNpwDOJLQ+Y;!a343(C zOK%E)Np1^&igEsRk|;o}keio*J*T`E7LHy$+Mvx6)(fb-))MpDnl#qt{a)FIGlrQh zj$i9Ho>q<{A&oUAoL^c)=MX++jd&Iz2fe3pF&G|bp%1wyk#*gK9z*FWhi({}MkIx8 zu8txVCW9|3Ok&b2g{Ndq$eiZbA9pWMSHBsL)r9AxyvKLLy1h12%0p@hOlJ)LF>-aJ)>3V! zS*#=DtR4Q=W4bSX-xS9nQ<{k2{0k%mQGR&#FQX|UuqEzOoW4Ha7mk0ap05cl6}N-r z$e&w&Gnh6QmRi*ZFkH07n$NhAQLzC(jjj-^3A6Kl?^>%5d2R0^ZerC5o zaY3L!3E&!}FqOs(qx@K5?9%62?Bn5Y()=q8WE#a!oC#pwEtS+JPmCEar zWqASolkmAqR?~~`$!0`oFRU^>6iD$1v!?he18~Sg*0|_pWlCThRyNOGrEZvQ`r^4B zuC80PkE;h|oQfEEMkKV3LT(NC!Wk2aHF|j$aR@Fuhpvj;>I@y;)?y4_f9Fid>`mHrKP{KRB1^_m1H z)O_zf*laE~YuSvT-tu-}+>Fdy3J$BW)Lh&s&^#>Rs8f%#L{!FmbMc&{;asqPgb{5j z-!Xwz`6=mKFdxqx`Q{t<<3+fod1}baR9%5d^OMqCoo{`%I>A-n`WhI``3an>3snnj z6(tQDFKB(hC5GAVy0BvlNbVu5NsHdD7%#Sq-JGeXgR7g~)cX3|84~zO+K5+FihWv~ zkyT~H5w&4_Eq3#Y`8t}TW636c@tj|7{-(ZI*bsP7fTAR>jg+UV3NBCb?pIIh?*BH3kVi{)OXabdQE!MR}lrUX|$ z)xAtc0gT$8SweKqKvg`iU8L%<<>Z+69<%E@jrQU9)mqp6GVjxkeQxJS3oGYH+r~jt z7{8_UbEuc3E~a6I@2@|cq&_E{|DC@Pt__=?@_Hob~mK zEmZ(j!?QBJ!b*i7f~kM-;6fx0vG+D6O@xr;5A447d7W+TqkbV^nn5tc?-@L+3zq|+ zKidBA0(eJoy-V`k4gJrrm*~H_(4(n!0e*pdz2FJGpWijKxMBGNaE|%R{v4`bs2SR4 zrJh?hy#H9X3rukS^)J`}MXrgVzx_un&bM&!3k`t`FJl6`i1TIv?alw14!nAJjzLNO zgF(?V{h9M$*yiuwe1VQ2h_L^=#s6~g|FNKo8)6!cES?9UY&?D%y zq#wtwUVGo9FR-RQmf(N14tihf)`w#3TeWZ#-f@a_RuD` z%KmiWDZ~VrPTA}!Sqj$O*DGspZwjr|pmZI7%)6{l4u|cCi9OxXg^KoPh|VFjy$95a z)%t-cr_Cl)j#b5Wih)~<-tpW;sF%v8@FM{=6h)X7Oy=$%3d4U4|0oeLvO>U27JDM= zJa$(!yTT}*_K0J)6NDapn)cjmc8A`(j>YKIW6_7BY@+H_-te7fs>&ohf`+eeH+TVy z*JNa)(v zuTuNe#e;7c2vI$(y&b4~2J}z5bRN8TD5Sg9`oq*yLYCintt!`SIUtfD&KJ)Yh z{_2(6$&+;<1`j{B`ZTB6cT>Cc-7Of8VtA^P#yO!}!6%CC#HB`HfaVPw#)sWXx)J8a z7tHkW9GVL&YY=s)iMozG4u}Z|{kgTnBQfl^`vo50<#D%L=vo+Jic?8?taRh^_kYAF z1!=O@RCQ_0ST-rjFvZ6-xJ0Y)0(z z;D_)o;MT+3Fc)@46$GGkF@AXWu3Jny8DcKvM2a|6>o29Vydj2KV9evBV@uFPo<@}d z^!gDqx_jXq%!a5*Vlqm6kaAQ4%c16FD!z?D8ZCC92%f+f$$YIJ$3!nqf=w3&5X)s% zQT3v5#t6d9miv~b`4$g5#`U@t0!HAr%=F}Q9|A5n^3dOM`+~7={F;(vII^CMgr-o@ zn+jqlKt2U3uOk0(~$J=jae2&*`UzPyNGCEm7 z9;k*N5Y0>SdW}NTeiTAS`2Bn6hiNu`F|#FiFq^iXlFs z!<_4t#TC!CIFKeZqf-w*Sf6RETXmH4{iPUODit&;x!E2>V6(F{Ts8{^OC`u*>?vi=DnmY>OVI18bt z&*{9!wBDN=th^@eRzmr5@Q6}C#GLoiTxacf)v&NuQ{~b)VY+U51@W(~kR?2biE9dv zhk2*Sp@SW?SgeF9s%9Y4APvty0{1ZpE0yXt!x80r8G_mpjTNnLRu?Ot zv)2478h}w0tysIb)FO&hf(27(OqOQz1dqX~aP{#d^2xmR_JoT+IPgP;ZVHm){nKcV z0BDrG(-kW${vXtH{)*lSpfqtRvMYgq$2@3^ybK_k(ve}`Erx%25&)&L2=HU9Ggr-)??UeT|K3$(!#~VVlJcfa2W4D3rH6qP2arh&{4jc=n!-MRU z4ImqO|8%qUtOf0%L{4^8v>P1q(54AWzZy`Z+w6gaJ6o+z))b$B!(?kFYpgCqRBN@b zVx*ZTaB7><`Tnl--Wh(r@x!k51lgs=GKot_GtSd+v{|;8NxrzAM{rtvpr-MUFq;kxw?>{xRsP zqFcOyHh#!FSt@n9sahM_@6b%^Lu}rz6p`nmXTKc0Y?r?{=SGF&P>iuL++v36OIt ziS0UUOgHeG^pWo$79AC9tbXTZXa5AUum_j3~{DR?9;ikQ&-mq)fS`qO$F96S5~Hvj_1y1eGWhAy)Nt| zH}S7Wi_Ycv!2Llf~n3;BX$W zC;5py)N}tvks{&^FM`Md=P{jNYwdgs1hrLuCyF-^pVald5$RsPr1znmkhh|zZdN0? zp`Jk^3TJ?p@Dz1$;L<&PKOkKk+zdAcnzL2O68Ws+XTa%w8`Ci3^5mDR9#Wh`PL48| zj2=C`33@)ux61lcAB|Gz?gG4-&y`e;9jFz);_DGzcA2-?bH7%WyLoKp<52Rw=`FQ= z0p^5spCbz@A?VgTCa(i6zQ)?!a%xyp-2UAxYz)s$o;^AXxr(M5)^ zyZ|zyqrO&F+Q-f-s>i6-QBo`Tab zs~-$Le$XT&*_-OOl51vm{>l&=1d84@Xxwn@A z;~m{y!&dfup#Q&V=mLf8sZ{kK<<2p%V7ntP?}Pc^QfQ+K6XCTi@gqkAR5M%E0Lrp- z?_e=EVxcfPx+S>zqmRcreCTGy{}C=w~oQlqnTfhbwp7BYHZ|W*q)1p>fmvE^uZpR94LFzyN$pj~-P$6N3bJw?<#H ztST?~E@)@vMVwa6M?@p*ggut_Z04PFL@@^I?iZ&oZT@BHD^S#_-7_De-`|l}umKqzrYIb8l{kpNeM_yvk}SCC($ifu)bk!d{QCBafa|6F?&3*2 zfd+-|Tfer{M~>4$F>qAjx2JP0-n~Dc)7u_~gL+1*)PWC5V^07Dm0@SWcz55vw^`}K7-3=ZycLolBC-5hc|OBPa48!b2!_Ig>9j{OWmfj{ zG29lg!dJ}r5b7$MF$y9R8}*As0*ZOvoVS8(aG|&? zW)=+6)3(*+a4o*3)1%cg*&z7)> zUCPUHTfr^xI4KZ4>5ah1DpRHz(#CRUadDOF$kzx`z=W3%2k_T84Ar|FaUg2SwN5PK z8BeBjo)fR4$}<{IcZ;r8kxt8Cr{)(>exYpcg+JLW4Mc0{6TB4aWvFAbmsln#O;Cy>{-0ecjVTYTJa=VP>0kVhjQi!H&{*X0|T z9B`Ks{t%)Q)(B$n7x^?e{J!kCyKy6r8{is9*0mL%tl?(o8+Y-CD~+BfxQ$e1bzYP8 zz=`!&XD9b&&95?#wT(cQK4Bry+W7!Guyk@EY|kOqi`Xa;d7JzzlKBMF(x-g6i5#W6 zQj)w48$gypr^Y@DoQYp(fqI(CFWrDZOI`PVm#l5`D(t{~4M>_Q6e1hSUfy)yOGqO0 z3g~N4VpYgJD{`#A1(#2{-^I6%d~~SZdK5={TP&k>>xs$j<``;aw=QMkNnOu%J^|W9 zksJi!3sn0`eAZdfxAN|J;AZ7Vww+%qN}38#Jnec>EmhQ&Rh9zc_j+edP0bC#H`sxL z^Ec42Fa7a^@P`*BDsU;maB1oDRB_+HHB5sYU{SiNvY2Rf~&K{3TtgHx`*>`nssABGa{KHx%B+ z+7R29{bbc+Z_NTHVla!AhR~sag`b9Ao~|reTwQ)_2b*Q0--hkm7vbY8p8>F+c?o-w zr));cZ@qMsX({Jt+{;Wxa_kYjCBRJSOv>whiq-Bq2NT(z* zcw;_q!ib?%ub_HXOsiJoHB?iRpM@ISLS`jTyasX{RyxqI;!dKo@9Mqf1NQej^DV3J+NpoUnRFGkT^Wli-)21O9Q? zNLRf<0|Y;YvxGLwlOL4p1z^FtBMjgByRUC+MRYf5<{Hih2&mgF?w8C4WG14j?zya0 z_H;MdYa7neA8!c5iHP905|rd*p1v?|*w0nL`s$W>cm1tv%xbzr@!=*or+oAKoeTlH zbSsQ`sNClE*Jsq?eD1@0yeCn~@s?3C{lO(Hk7gsRWK6!upRrfCZ=0f2QPj7ER!B@?3h_Ra*&EYxCfg&{G7TcEGP6w3_kYrcHWO1=*AqeZkgXUHohw7+ci(79r%T4u^H-yZwHGn;#N(o+pC zTW(7wIAQraUSQU(W2VruPo9t!_Vwv0ZRSX+lHmSZY>9(K*oO6rR+g0ALomoVS+xGK6Y{7^1way5KG} zs50Zk@qu;)CM*Q*u9SyplkUY;7_m}SdJYg4E|+Q-;(!rW(tDJ1x$(KRuoAt<_kg30 zyQVbS;#h1JpDzms^^4+JEO;J15(N8cr(;J*M#SZ^n0wqGt+7&XT+UC?VedE)iw(=c zeM&3jQX;Fi#3&Og?j2wCm!DL$a6ZYwZppU*XYyyuO0e=2rYI50eYDoc0ZF?}8CLV* z1_8~h6s|RH!Szx@AfMATOKMCh5aT<04@edrseQYHQWV$!LJ;ABUtopT2 zhXdk`2M-9xvdbPEW^^ydZG|rc-}3mFFQ4+At-#ha&e#yY&hLI3$hk*1E-O_%airhh zoSS3DElQYw3$>Sdj+y;X>qC=TMuy-V@K!Ztw6C>p74Ze8=dMf_pRMYYuviPW!XQ~% z$rks_YuJoK<3WYeyo+7aViDE6KkXPDEW3?5XH?2E&Z8P5?`u9aHk11p1Z zS#{b*oa7XJX?l|Q%T~v^x|qeqDd`ITOL*3>wKyyx=8n9Bw|tzRV$=IV&lE0~UHC0w zkNOQNOUc9%6Rir*cH}lep-E=`B;-T|I+n;0_M=3zT4zMp9;tz$wL@K;u`59K;gGpw zyU`G48=gF%u$WigO}%p!kM%mkdzfo+lB?M?alq9Rw#cFs(}q614s4f(Aix8fEUSMgz%VTn(~R{jaxPnn7hYF>%Wp4mAJ zV%ll37@F&9Jj!P(oH7R>pN@)1%k?PyP;muT{$T@0dyt!?75A5Ek;ABZacJE_5r(g` z5sWs~EkTmm0}-U8peoJvE#d0mCQYOyRuAk^9o&Fbx*R*eFn;ZB5qJP{XGJj7AJuWK z$Cz2^m+t7NRuFY=B&(yEw|vfIeirj3FO&oZEEL(hy}FTlZ$ji}-54m2=YFT0bX{Jq z$1O^|XYh)4`^oLll~8usdK{4}Cw`-_i^|Wzm}%TWXi4>hxeM%Wn&_sIsM* zpZ)ZiJ`AjWvEr0sZ(98S%VNTJC_bW(m*xK zh_Eu0fZ%mR5^59crx>1KEM1_#|EjbA^w9@uF|4#a=V|_)Y&Bi;POA;Oc?bvy9&08% z3Gxj>S()h%Ro39M+tQ>z-BH~d9o*u!QHxX2v%CYVUlC$4$iRxuM-;z+4K}odRk7Cg z9c)XKv1)dz8yk0!+kVbC3wpZtz<_J-zSUh;7;X6tuN>!ZR zy~?&Vonij|G{bFOwbHPwDe-bM6%+ldx0j|a+AIE(tEa~TtTh&knvndfXW_5Ey#oCs zq)yo5B7bjoKsQWK(}QjZtXH48{?Q0QPx2mEugu$Rd;e&=Ao+L+gtgQp#xCd0O@B>* zZtTg#krKnciBWse+4`SKffz1mZ2Ra(UmgCpPK)zy4Qx&@`x0J-_dnNrUkrp5mP5i`$!0i7wERoUDpKlL}kBKpAXlMZCWOj6Q zx5N?wJ0@QQ*1Z~sfN~sEt4jGh;^#r(#?pe=GVHNdo3@qGwO>T5b zP2%17oj)>%jn>&qPW++gQX!U)w>Hy=s&N~#rUJp9k+^v%3;kQcl|s2JebGM?xuGZW z#k+J-qx18{1j2_d#DcgqQ9eUt2BubONkf&5?r>;ywL9LocF$Fv&ZflD7(Bz*JF4Ix{V@*O+iM@ zp*Pqc>#`OgXcGQy=+BbN7Zi&^x1?F*yVmRmdRK5M;bZklhE@2fIg3nG~M#fyR4?KU}&wT6kRqBQAz@e+uf7+Dk6$ODc zj2)=2QooQ{Y{pD`Iv^Bgt~1UFw5wUx6IEQjLk*5*LQ9zezxv*T8`DxAIp(8VLt;yo zM)JZ})4`OHO{kw(f7l+&1-O{Up1wwS@qz!9wTbE;`Nl^e>yO&a2F)zCYnLN1-nBY+ zfM|Igv>H@3IR4NArQjNyDI};MQkfr4Jp41*0Vwshi)tg^{KJr4!W0F@`itV<`riIn z;5o3LgBPE1|I>l@=anBecF(^?!Mo_FH01wBSs1)!qJQ}BLE2ow=31^*|HDRuxAYhv zc>X;|3t-aEv_!(c8{vyZ;N!jh_aHx^e%OJVo!0;GCg81<*vqT`8pLlCtkR3c6zh+D z{LlPin6HffJ%|^`f6fc>(EgihWF$BKJ;;~FOZa$mp%ee68X*1ke-H9s(xWN$U(#d# jSL)HK{r}nE9QX$Azhv#XjemFv{76f_kSG(^fAxO=X@_Q} literal 0 HcmV?d00001 From dac42607e68c1bd7c4cef2a34c597f5c87d72218 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Wed, 30 Oct 2024 08:47:01 -0400 Subject: [PATCH 31/32] added ademamix example to gallery --- docs/gallery.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/docs/gallery.rst b/docs/gallery.rst index 5f7a3b134..e05742c02 100644 --- a/docs/gallery.rst +++ b/docs/gallery.rst @@ -284,6 +284,22 @@ Examples that make use of the :doc:`api/contrib` module.
Sharpness-Aware Minimization (SAM).
+.. raw:: html + +
+ +.. only:: html + + .. image:: /images/examples/contrib/ademamix_rosenbrock.png + :alt: + + :doc:`_collections/examples/contrib/ademamix_rosenbrock` + +.. raw:: html + +
AdEMAMix.
+
+ .. raw:: html From 7c57abd5245b8c7dee0dbf78bd76a811925a28d3 Mon Sep 17 00:00:00 2001 From: Daniel Marthaler Date: Wed, 30 Oct 2024 09:35:06 -0400 Subject: [PATCH 32/32] reran notebook with colab link --- examples/contrib/rosenbrock_ademamix.ipynb | 67 ++++++---------------- 1 file changed, 17 insertions(+), 50 deletions(-) diff --git a/examples/contrib/rosenbrock_ademamix.ipynb b/examples/contrib/rosenbrock_ademamix.ipynb index 588fcdb52..a6ebd3bec 100644 --- a/examples/contrib/rosenbrock_ademamix.ipynb +++ b/examples/contrib/rosenbrock_ademamix.ipynb @@ -6,6 +6,9 @@ "metadata": {}, "source": [ "# Recreate AdeMAMix Rosenbrock Plot from Paper\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/contrib/rosenbrock_ademamix.ipynb)\n", + "\n", "This notebook attempts to recreate Figure 2 from the [AdeMAMix paper](https://arxiv.org/pdf/2409.03137)" ] }, @@ -105,48 +108,24 @@ "text": [ "Objective function: 1616.0\n", "Objective function for b1=0.9 at iteration 0 = 1599.2254638671875\n", - "Objective function for b1=0.9 at iteration 10000 = 10.89592456817627\n", - "Objective function for b1=0.9 at iteration 20000 = 9.620516777038574\n", - "Objective function for b1=0.9 at iteration 30000 = 7.285767555236816\n", - "Objective function for b1=0.9 at iteration 40000 = 3.306288242340088\n", + "Objective function for b1=0.9 at iteration 25000 = 8.618563652038574\n", "Objective function for b1=0.9 at iteration 50000 = 0.26169437170028687\n", - "Objective function for b1=0.9 at iteration 60000 = 0.009876935742795467\n", - "Objective function for b1=0.9 at iteration 70000 = 9.95625596260652e-05\n", - "Objective function for b1=0.9 at iteration 80000 = 6.432726706862013e-08\n", - "Objective function for b1=0.9 at iteration 90000 = 5.157154703283595e-10\n", + "Objective function for b1=0.9 at iteration 75000 = 3.934588676202111e-06\n", "Objective function: 1616.0\n", "Objective function for b1=0.99 at iteration 0 = 1599.2254638671875\n", - "Objective function for b1=0.99 at iteration 10000 = 10.799932479858398\n", - "Objective function for b1=0.99 at iteration 20000 = 9.439836502075195\n", - "Objective function for b1=0.99 at iteration 30000 = 6.946890830993652\n", - "Objective function for b1=0.99 at iteration 40000 = 2.7601280212402344\n", + "Objective function for b1=0.99 at iteration 25000 = 8.370291709899902\n", "Objective function for b1=0.99 at iteration 50000 = 0.17759834229946136\n", - "Objective function for b1=0.99 at iteration 60000 = 0.005802110303193331\n", - "Objective function for b1=0.99 at iteration 70000 = 4.045083551318385e-05\n", - "Objective function for b1=0.99 at iteration 80000 = 1.2925656989182244e-08\n", - "Objective function for b1=0.99 at iteration 90000 = 6.390479256879189e-10\n", + "Objective function for b1=0.99 at iteration 75000 = 1.1651114846245036e-06\n", "Objective function: 1616.0\n", "Objective function for b1=0.999 at iteration 0 = 1599.2254638671875\n", - "Objective function for b1=0.999 at iteration 10000 = 10.194862365722656\n", - "Objective function for b1=0.999 at iteration 20000 = 9.375121116638184\n", - "Objective function for b1=0.999 at iteration 30000 = 7.936856746673584\n", - "Objective function for b1=0.999 at iteration 40000 = 5.422780513763428\n", + "Objective function for b1=0.999 at iteration 25000 = 8.757580757141113\n", "Objective function for b1=0.999 at iteration 50000 = 1.4608842134475708\n", - "Objective function for b1=0.999 at iteration 60000 = 0.057731419801712036\n", - "Objective function for b1=0.999 at iteration 70000 = 0.0010820545721799135\n", - "Objective function for b1=0.999 at iteration 80000 = 6.941367587387504e-07\n", - "Objective function for b1=0.999 at iteration 90000 = 3.984723662142642e-11\n", + "Objective function for b1=0.999 at iteration 75000 = 5.853441689396277e-05\n", "Objective function: 1616.0\n", "Objective function for b1=0.9999 at iteration 0 = 1599.2281494140625\n", - "Objective function for b1=0.9999 at iteration 10000 = 29.86247444152832\n", - "Objective function for b1=0.9999 at iteration 20000 = 9.297667503356934\n", - "Objective function for b1=0.9999 at iteration 30000 = 7.363901138305664\n", - "Objective function for b1=0.9999 at iteration 40000 = 3.581587553024292\n", + "Objective function for b1=0.9999 at iteration 25000 = 7.632230758666992\n", "Objective function for b1=0.9999 at iteration 50000 = 0.872508704662323\n", - "Objective function for b1=0.9999 at iteration 60000 = 1.0354793071746826\n", - "Objective function for b1=0.9999 at iteration 70000 = 0.3354209363460541\n", - "Objective function for b1=0.9999 at iteration 80000 = 0.09372159093618393\n", - "Objective function for b1=0.9999 at iteration 90000 = 0.09824670851230621\n" + "Objective function for b1=0.9999 at iteration 75000 = 0.16873982548713684\n" ] } ], @@ -167,7 +146,7 @@ " updates, opt_state = solver.update(grad, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", " all_params.append(params)\n", - " if i%10000 == 0:\n", + " if i%25000 == 0:\n", " print(f\"Objective function for b1={b1} at iteration {i} = {rosenbrock(params)}\")\n", " all_b1_params.append(all_params)\n", "all_b1_params_array = jnp.array(all_b1_params)" @@ -244,26 +223,14 @@ "text": [ "Objective function: 1616.0\n", "Objective function for b3=0.0 at iteration 0 = 1599.227294921875\n", - "Objective function for b3=0.9900450110435486 at iteration 10000 = 4.631922721862793\n", - "Objective function for b3=0.9950100779533386 at iteration 20000 = 0.0\n", - "Objective function for b3=0.9966706037521362 at iteration 30000 = 0.0\n", - "Objective function for b3=0.9975019097328186 at iteration 40000 = 0.0\n", + "Objective function for b3=0.9960060715675354 at iteration 25000 = 0.0\n", "Objective function for b3=0.9980010390281677 at iteration 50000 = 0.0\n", - "Objective function for b3=0.9983339309692383 at iteration 60000 = 0.0\n", - "Objective function for b3=0.9985717535018921 at iteration 70000 = 0.0\n", - "Objective function for b3=0.9987501502037048 at iteration 80000 = 0.0\n", - "Objective function for b3=0.9988889694213867 at iteration 90000 = 0.0\n", + "Objective function for b3=0.9986668825149536 at iteration 75000 = 0.0\n", "Objective function: 1616.0\n", "Objective function for b3=0.0 at iteration 0 = 1599.227294921875\n", - "Objective function for b3=0.9990003108978271 at iteration 10000 = 4.634438514709473\n", - "Objective function for b3=0.999500036239624 at iteration 20000 = 1.222501424535949e-07\n", - "Objective function for b3=0.9996666312217712 at iteration 30000 = 5.6290506478262614e-08\n", - "Objective function for b3=0.9997499585151672 at iteration 40000 = 5.5214698591044e-08\n", + "Objective function for b3=0.9995999932289124 at iteration 25000 = 1.852578179750708e-08\n", "Objective function for b3=0.9997999668121338 at iteration 50000 = 4.14928891245836e-09\n", - "Objective function for b3=0.9998332858085632 at iteration 60000 = 5.135669667311049e-08\n", - "Objective function for b3=0.9998571276664734 at iteration 70000 = 1.139520122706017e-08\n", - "Objective function for b3=0.999875009059906 at iteration 80000 = 1.3248055097392353e-08\n", - "Objective function for b3=0.9998888969421387 at iteration 90000 = 1.028013230097713e-10\n" + "Objective function for b3=0.9998666644096375 at iteration 75000 = 8.126028205879265e-08\n" ] } ], @@ -287,7 +254,7 @@ " updates, opt_state = solver.update(grad, opt_state, params)\n", " params = optax.apply_updates(params, updates)\n", " all_params.append(params)\n", - " if i%10000 == 0:\n", + " if i%25000 == 0:\n", " print(f\"Objective function for b3={b3(i)} at iteration {i} = {rosenbrock(params)}\")\n", " all_ademamix_params.append(all_params)\n", "all_ademamix_params_array = jnp.array(all_ademamix_params)"