From a591992b80c8dc52ce25bdf0278ec1f8ae5ff671 Mon Sep 17 00:00:00 2001 From: Xu Song Date: Tue, 7 Nov 2023 03:22:54 +0800 Subject: [PATCH 1/5] Fix config format (#4594) Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/activation_checkpointing/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/activation_checkpointing/config.py b/deepspeed/runtime/activation_checkpointing/config.py index 13df4b981298..dc07388a95da 100755 --- a/deepspeed/runtime/activation_checkpointing/config.py +++ b/deepspeed/runtime/activation_checkpointing/config.py @@ -17,7 +17,7 @@ "partitioned_activations": [true|false], "number_checkpoints": 100, "contiguous_memory_optimization": [true|false], - "cpu_checkpointing": [true|false] + "cpu_checkpointing": [true|false], "profile": [true|false], "synchronize_checkpoint_boundary": [true|false], } From b1cb0dfc4623fc44e9931d1481c246bd8a59b778 Mon Sep 17 00:00:00 2001 From: Guanhua Wang Date: Mon, 6 Nov 2023 14:15:16 -0800 Subject: [PATCH 2/5] Guanhua/partial offload rebase v2 (#590) (#4636) This PR introduces Twin-Flow feature of ZeRO-Offload++, which improves e2e training iteration time by up to 6x on DGX-H100s. This PR includes: * Twin-Flow implementation inside ZeRO optimizer * json config tutorial * example using deepspeed * unit tests cc @jeffra @awan-10 @tjruwase @mrwyattii Co-authored-by: Olatunji Ruwase Co-authored-by: Jeff Rasley --- deepspeed/runtime/engine.py | 4 + deepspeed/runtime/zero/config.py | 9 +- deepspeed/runtime/zero/offload_config.py | 3 + deepspeed/runtime/zero/stage3.py | 60 ++++++-- docs/_pages/config-json.md | 7 + .../partial_offload_test.py | 128 ++++++++++++++++++ tests/unit/ops/adam/test_hybrid_adam.py | 78 +++++++++++ .../unit/runtime/zero/test_zero_offloadpp.py | 75 ++++++++++ 8 files changed, 355 insertions(+), 9 deletions(-) create mode 100644 tests/small_model_debugging/partial_offload_test.py create mode 100644 tests/unit/ops/adam/test_hybrid_adam.py create mode 100644 tests/unit/runtime/zero/test_zero_offloadpp.py diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 04611671fe06..3d75022ccd86 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -700,6 +700,9 @@ def zero_cpu_offload(self): return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu return False + def zero_partial_offload(self): + return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0) + def zero_sub_group_size(self): return self._config.zero_config.sub_group_size @@ -1565,6 +1568,7 @@ def _configure_zero_optimizer(self, optimizer): offload_optimizer_config=self.zero_offload_optimizer(), offload_param_config=self.zero_offload_param(), sub_group_size=self.zero_sub_group_size(), + offload_ratio=self.zero_partial_offload(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor(), diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 35d60b5b3290..f16dfd7ac4c0 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -6,7 +6,7 @@ import sys from typing import Optional from enum import Enum -from deepspeed.pydantic_v1 import Field, validator +from deepspeed.pydantic_v1 import Field, validator, root_validator from deepspeed.runtime.config_utils import get_scalar_param, pp_int, DeepSpeedConfigModel from deepspeed.utils import logger from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum @@ -300,3 +300,10 @@ def overlap_comm_valid(cls, field_value, values): assert ("stage" in values), "DeepSpeedZeroConfig: 'stage' must be defined before 'overlap_comm'" field_value = values["stage"] == ZeroStageEnum.weights return field_value + + @root_validator + def offload_ratio_check(cls, values): + offload_config = getattr(values, "offload_optimizer", {}) + if offload_config and offload_config.ratio < 1.0: + assert values.get("stage") == ZeroStageEnum.weights, "Partial offloading only supported for ZeRO Stage 3." + return values diff --git a/deepspeed/runtime/zero/offload_config.py b/deepspeed/runtime/zero/offload_config.py index 1bd79412d39f..b7adc13a0ea2 100644 --- a/deepspeed/runtime/zero/offload_config.py +++ b/deepspeed/runtime/zero/offload_config.py @@ -92,3 +92,6 @@ class DeepSpeedZeroOffloadOptimizerConfig(DeepSpeedConfigModel): def set_pipeline(cls, field_value, values): values["pipeline"] = field_value or values.get("pipeline", False) return field_value + + ratio: float = Field(1.0, ge=0.0, le=1.0) + """ Percentage of offloaded optimizer states to CPU Adam. Only valid with ZeRO Stage 3.""" diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 1030ec3e1120..7fdf4e4d9998 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -20,7 +20,7 @@ from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload -from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper @@ -104,6 +104,7 @@ def __init__( offload_optimizer_config=None, offload_param_config=None, sub_group_size=1000000000000, + offload_ratio=0.0, mpu=None, clip_grad=0.0, gradient_accumulation_dtype=torch.float32, @@ -159,6 +160,7 @@ def __init__( self.offload_param_pin_memory = False self.params_in_nvme_and_cpu = False self.max_params_in_cpu = 0 + self.partial_offload = offload_ratio #num of ranks in a ZeRO param partitioning group self.zero_hpz_partition_size = zero_hpz_partition_size @@ -191,6 +193,23 @@ def __init__( self.persistent_parameters = self.parameter_offload.persistent_parameters self._configure_offloading(offload_optimizer_config, offload_param_config) + # backup fused_adam optimizer init + if self.offload_optimizer and self.partial_offload != 1.0: + backup_gpu_tensor = torch.randn(1, device='cuda').to(self.dtype) + backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor) + assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam' + self.backup_optimizer = FusedAdam([backup_gpu_param], + lr=self.optimizer.param_groups[0]["lr"], + bias_correction=self.optimizer.param_groups[0]["bias_correction"], + betas=self.optimizer.param_groups[0]["betas"], + eps=self.optimizer.param_groups[0]["eps"], + weight_decay=self.optimizer.param_groups[0]["weight_decay"], + amsgrad=self.optimizer.param_groups[0]["amsgrad"]) + # Multiple param_groups configs for back-up optimizer + if len(self.optimizer.param_groups) > 1: + for i in range(1, len(self.optimizer.param_groups)): + self.backup_optimizer.add_param_group(self.optimizer.param_groups[i]) + self.module = module self.elastic_checkpoint = elastic_checkpoint @@ -780,6 +799,17 @@ def _create_fp32_partitions(self): nvme_fp32_dest_tensors = [] fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + # Assign portion of subgroup to cpu, the other to gpu. + if self.offload_optimizer: + self.subgroup_to_device = {} + sub_group_size = len(self.fp16_partitioned_groups_flat) + # print(f"Partial offload sub_group_size is {sub_group_size}, ratio is {self.partial_offload}\n") + for i in range(sub_group_size): + if i < int(self.partial_offload * sub_group_size): + self.subgroup_to_device[i] = 'cpu' + else: + self.subgroup_to_device[i] = get_accelerator()._name + for i, tensor in enumerate(self.fp16_partitioned_groups_flat): num_elements = self.fp16_partitioned_groups_flat_numel[i] @@ -816,8 +846,12 @@ def _create_fp32_partitions(self): self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) else: - self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( - self.device).clone().float().detach()) + if self.offload_optimizer: + self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( + self.subgroup_to_device[i]).clone().float().detach()) + else: + self.fp32_partitioned_groups_flat.append(self.fp16_partitioned_groups_flat[i].to( + self.device).clone().float().detach()) self.fp32_partitioned_groups_flat[i].requires_grad = True # keep this in case internal optimizer uses it @@ -885,10 +919,20 @@ def _release_ipg_buffers(self): def _optimizer_step(self, sub_group_id): param_group_id = self.sub_group_to_group_id[sub_group_id] fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] - - self.optimizer.step() - self.optimizer.param_groups[param_group_id]['params'] = [] + if self.offload_optimizer: + cur_device = self.subgroup_to_device[sub_group_id] + if cur_device == 'cpu': + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + cpu_loss = self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + else: + self.backup_optimizer.param_groups[param_group_id]['params'] = [fp32_param] + gpu_loss = self.backup_optimizer.step() + self.backup_optimizer.param_groups[param_group_id]['params'] = [] + else: + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] def _swappable_optimizer_subgroup(self, sub_group_id): if not self.swap_optimizer: @@ -955,7 +999,7 @@ def initialize_optimizer_states(self): if self.offload_optimizer_pin_memory: subgroup_gradient_buffer = get_accelerator().pin_memory(subgroup_gradient_buffer) - self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer.to(self.subgroup_to_device[i]) else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index cdd18e62a29e..e9d7166b05b3 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -576,6 +576,7 @@ Note that if the value of "device" is not specified or not supported, an asserti "device": "[cpu|nvme]", "nvme_path": "/local_nvme", "pin_memory": [true|false], + "ratio": 0.3, "buffer_count": 4, "fast_init": false } @@ -598,6 +599,12 @@ Note that if the value of "device" is not specified or not supported, an asserti | ---------------------------------------------------------------------------------------------------- | ------- | | Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. | `false` | +***ratio***: [float] + +| Description | Default | +| ------------------------------------------------------------------- | ------- | +| the ratio of parameters updating (i.e. optimizer step) on CPU side. | 1 | + ***buffer_count***: [integer] | Description | Default | diff --git a/tests/small_model_debugging/partial_offload_test.py b/tests/small_model_debugging/partial_offload_test.py new file mode 100644 index 000000000000..2094448d534d --- /dev/null +++ b/tests/small_model_debugging/partial_offload_test.py @@ -0,0 +1,128 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import json +import argparse +import torch +import deepspeed +from torch.utils.data.distributed import DistributedSampler +import deepspeed.comm as dist + + +class SimpleModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False): + super(SimpleModel, self).__init__() + self.linear = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear4 = torch.nn.Linear(hidden_dim, hidden_dim) + if empty_grad: + self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)]) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + hidden = x + hidden = self.linear(hidden) + hidden = self.linear2(hidden) + hidden = self.linear3(hidden) + hidden = self.linear4(hidden) + return self.cross_entropy_loss(hidden, y) + + +def create_config_from_dict(tmpdir, config_dict): + config_path = os.path.join(tmpdir, 'temp_config.json') + with open(config_path, 'w') as fd: + json.dump(config_dict, fd) + return config_path + + +def get_data_loader(model, total_samples, hidden_dim, device): + batch_size = model.train_micro_batch_size_per_gpu() + train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half) + train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + sampler = DistributedSampler(train_dataset) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) + return train_loader + + +def get_args(tmpdir, config_dict): + parser = argparse.ArgumentParser() + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument('--zero', type=int, default=0) + args = parser.parse_args() #args='' + + config_dict["zero_optimization"]["stage"] = args.zero + print('config_dict["zero_optimization"]', config_dict["zero_optimization"]) + config_path = create_config_from_dict(tmpdir, config_dict) + + args.deepspeed_config = config_path + return args + + +def print0(msg): + if dist.get_rank() == 0: + print(msg, flush=True) + + +rank = int(os.environ['RANK']) +print('seed:', 2222 + rank) +torch.random.manual_seed(2222 + rank) + +config_dict = { + "train_batch_size": 256, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 15 + }, + "zero_optimization": { + "stage": 0, + "sub_group_size": 8, + "reduce_bucket_size": 20, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + "ratio": 0.3 + } + } +} +# "initial_scale_power": 15 +args = get_args('/tmp/', config_dict) +hidden_dim = 4 * 1024 + +model = SimpleModel(hidden_dim, empty_grad=False) + +model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters(), + dist_init_required=True) + + +def print_params(tag, model): + if dist.get_rank() == 0: + for n, p in model.named_parameters(): + print0("{} {}:{}".format(tag, n, p)) + + +data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device) +#print_params('pre-train', model) +#while True: +for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0: + print("LOSS:", loss.item()) + model.backward(loss) + model.step() + #print_params('step={}'.format(n), model) + if n == 2: break diff --git a/tests/unit/ops/adam/test_hybrid_adam.py b/tests/unit/ops/adam/test_hybrid_adam.py new file mode 100644 index 000000000000..c7ef4890b322 --- /dev/null +++ b/tests/unit/ops/adam/test_hybrid_adam.py @@ -0,0 +1,78 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import numpy as np +import pytest + +from cpuinfo import get_cpu_info + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam +from deepspeed.ops.op_builder import CPUAdamBuilder +from unit.common import DistributedTest + +if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("hybrid-adam is not compatible", allow_module_level=True) + +pytest.cpu_vendor = get_cpu_info()["vendor_id_raw"].lower() + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().numpy() + y = second.detach().numpy() + print("ATOL", atol) + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + + +@pytest.mark.parametrize('dtype', [torch.half, torch.float], ids=["fp16", "fp32"]) +@pytest.mark.parametrize('model_size', [8, 16]) +class TestHybridAdam(DistributedTest): + world_size = 1 + reuse_dist_env = True + requires_cuda_env = False + if not get_accelerator().is_available(): + init_distributed = False + set_dist_env = False + + @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + def test_hybrid_adam_equal(self, dtype, model_size): + if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): + pytest.skip("cpu-adam with half precision not supported on AMD CPUs") + + ref_data = torch.randn(model_size).to(dtype) + total_data = ref_data.clone().detach() + + ref_param = torch.nn.Parameter(ref_data) + ref_optimizer = DeepSpeedCPUAdam([ref_param]) + + cpu_data, cuda_data = total_data.chunk(2) + cpu_param = torch.nn.Parameter(cpu_data) + cuda_param = torch.nn.Parameter(cuda_data.to(get_accelerator().device_name())) + + cpu_optimizer = DeepSpeedCPUAdam([cpu_param]) + cuda_optimizer = FusedAdam([cuda_param]) + + ref_grad = torch.randn(model_size).to(dtype) + cpu_grad, cuda_grad = ref_grad.clone().detach().chunk(2) + + ref_param.grad = ref_grad + cpu_param.grad = cpu_grad + cuda_param.grad = cuda_grad.to(get_accelerator().device_name()) + + ref_optimizer.step() + cpu_optimizer.step() + cuda_optimizer.step() + + cuda_param_copy = cuda_param.cpu() + + total_param = torch.cat((cpu_param, cuda_param_copy)) + + check_equal(ref_param, total_param) diff --git a/tests/unit/runtime/zero/test_zero_offloadpp.py b/tests/unit/runtime/zero/test_zero_offloadpp.py new file mode 100644 index 000000000000..c376686f8052 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_offloadpp.py @@ -0,0 +1,75 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import pytest +import deepspeed.comm as dist +from unit.common import DistributedTest +from unit.simple_model import random_dataloader + +import deepspeed + +from deepspeed.runtime.zero.offload_config import DeepSpeedZeroOffloadOptimizerConfig + +import torch.nn as nn + + +class NNModel(nn.Module): + + def __init__(self, h_dim=1024, n_layers=2): + super(NNModel, self).__init__() + self.layers = nn.ModuleList([nn.Linear(h_dim, h_dim) for i in range(n_layers)]) + self.cross_entropy_loss = nn.CrossEntropyLoss() + + def forward(self, x, y): + for layer in self.layers: + x = layer(x) + return self.cross_entropy_loss(x, y) + + +def test_zero_partial_offload_config(): + config = DeepSpeedZeroOffloadOptimizerConfig(**{"ratio": 0.3}) + assert config.ratio == 0.3 + + +#Large sweep along hidden dim, num_layers of different sizes +@pytest.mark.parametrize("h_dim", [1024]) +@pytest.mark.parametrize("n_layers", [4, 8]) +class TestZeroPartialOffloadConfigSweep(DistributedTest): + world_size = 4 + + def test(self, h_dim: int, n_layers: int) -> None: + config_dict = { + "train_batch_size": 256, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 15 + }, + "zero_optimization": { + "stage": 3, + "sub_group_size": 8, + "reduce_bucket_size": 20, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True, + "ratio": 0.3 + } + } + } + + model = NNModel(h_dim, n_layers) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=h_dim, device=model.device) + dist.barrier() + + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() From 7e480ea02fa1837aad0e9d3eb8056da3055a392c Mon Sep 17 00:00:00 2001 From: Guanhua Wang Date: Mon, 6 Nov 2023 14:17:15 -0800 Subject: [PATCH 3/5] offload++ blog (#623) (#4637) This PR is the blog for ZeRO-Offload++, it describes the details of how our new Twin-Flow feature works and its performance numbers on both DGX-A100 and DGX-H100 machines. Corresponding code PR is https://github.com/microsoft/DeepSpeed/pull/4636 cc @jeffra @awan-10 @tjruwase @mrwyattii --------- Co-authored-by: Jeff Rasley --- blogs/deepspeed-offloadpp/README.md | 52 ++++++++++++++++++ blogs/deepspeed-offloadpp/images/a100-8.png | Bin 0 -> 22879 bytes .../images/cpu-offload.png | Bin 0 -> 42294 bytes blogs/deepspeed-offloadpp/images/h100-8.png | Bin 0 -> 20621 bytes .../images/twin-offload.png | Bin 0 -> 59949 bytes 5 files changed, 52 insertions(+) create mode 100644 blogs/deepspeed-offloadpp/README.md create mode 100644 blogs/deepspeed-offloadpp/images/a100-8.png create mode 100644 blogs/deepspeed-offloadpp/images/cpu-offload.png create mode 100644 blogs/deepspeed-offloadpp/images/h100-8.png create mode 100644 blogs/deepspeed-offloadpp/images/twin-offload.png diff --git a/blogs/deepspeed-offloadpp/README.md b/blogs/deepspeed-offloadpp/README.md new file mode 100644 index 000000000000..22910ffef7b1 --- /dev/null +++ b/blogs/deepspeed-offloadpp/README.md @@ -0,0 +1,52 @@ +# DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow + +Deep learning has been successfully adopted in a wide range of applications such as speech recognition, chatbot, text and image generation, etc. To achieve better model serving accuracy, model size grows significantly. Take language models as example, from BERT with 110 million parameters to Megatron-Turing NLG with 530 billion parameters, the model size grows almost 5000x. Given limited GPU memory size, we need to efficiently utilize GPU memory to achieve good system throughput. + +ZeRO offers memory efficient data parallel training scheme. For training large models like LLMs using ZeRO, GPU memory size is still often insufficient to hold all the model parameters. Thus, ZeRO-Offload is introduced to solve this insufficient GPU memory issue. ZeRO-Offload releases GPU memory pressure by offloading data and compute to the CPU side while minimizing CPU-GPU data copy overhead. Given CPU memory is often orders-of-magnitude larger than GPU memory, ZeRO-Offload was the first piece of work that enables billion-level parameter training even with very limited GPU memory resources (e.g., to an extreme: single GPU). ZeRO-Offload provides excellent performance when model size is multiple times larger than total GPU memory size. + +However, system efficiency is still far from optimal when adopting ZeRO-Offload in some scenarios. Especially in the cases like small batch training, model that could not fit into GPU memory but not orders-of-magnitude bigger than GPU memory capacity, CPU offload not only introduce long end-to-end latency, but also underutilize GPU computation resources. To reduce memory copy latency as well as inefficient utilization of GPU introduced in these offload cases, we propose ZeRO-Offload++, which leverages both CPU and GPU coherently. ZeRO-Offload++ mainly includes 3 new features as _Twin-Flow_, MemCpy reduction, CPUAdam optimization. Now we release our __Twin-Flow__ feature. + +The key benefits are: +* With _Twin-Flow_, ZeRO-Offload++ achieves up to **6x** training speedup compared with ZeRO-Offload. +* High-level API provided in DeepSpeed config JSON makes it easy to use and fine-tune. + +![h100-img](./images/h100-8.png) + +## Twin-Flow + +In DeepSpeed, when training using popular optimizer like Adam, optimizer offloading follows an all-or-nothing policy. For simplifed example shown as Figure below, without offloading, all the parameters will be updated using GPU adam as FusedAdam optimizer. On the other hand, if offloading is enabled, all model weights use CPUAdam to update. + +![cpu-offload-img](./images/cpu-offload.png) + +The major downside of this all-or-nothing offloading is, when offload all optimizer states to CPU side, both GPU memory and compute resources remain under-utilized. Although increasing batch size improves GPU utilization rate, each training iteration time is still super long compared with no-offloading case. To improve GPU compute and memory utilization rate as well as decrease training iteration time, we introduce a new feature in our DeepSpeed training engine called _Twin-Flow_. + +In comparison, _Twin-Flow_ allows a portion of optimizer states to be held in CPU memory and the other portion of optimizer states remaining in GPU memory. When optimization step is triggered, both CPU and GPU can do parameter updates simultaneously. Once offloading is enabled, we provide an offload ratio configuration which allows users to adjust how many percentages of model weights are updated on CPU side and the rest are happened on GPU side. "_Twin_" comes from the idea that both CPU and GPU are using the same optimizer function here. "_Flow_" means parameters are not only hold in both host and device memory, but also computed using both CPU and GPU cores. + +As shown in Figure below, with ZeRO-Offload enabled and we set _Twin-Flow_ ratio of 0.4 (40%). DeepSpeed Training engine will automatically assign first 40% (i.e. 0-40%) of weights step procedure on the CPU side using CPUAdam, and use GPU side FusedAdam to update the rest 60% (i.e., 40-100%) model parameters jointly. Therefore, with _Twin-Flow_, we can achieve decent GPU memory and core utilization rate, at the same time reduce training iteation time in optimizer offloading cases. + +![_Twin-Flow_-img](./images/twin-offload.png) + +Note that this _Twin-Flow_ ratio can be adjusted based on how much GPU idle memory is available. The smaller this ratio is, the more GPU memory and cores are used and the shorter training iteration time it achieves. The ideal case is to be as near as GPU memory upper bound in order to minimize training iteration time. +Note that _Twin-Flow_ is not limited to Adam optimizer only, it can be applied to any optimizer (e.g., AdaGrad) from the user side. + +## Performance Evaluation + +We conduct our performance evaluations over both A100 and H100 DGX machine and test for OPT model with 13B and 30B parameters. We run 13B OPT model training on a 8 A100 DGX machine, and run OPT-30B model training using a 8 H100 DGX machine. With some tuning on offload ratio in ZeRO-Offload++, we achieve 6x and 3x training speedup of Meta OPT models on single DGX-H100-80GB and DGX-A100-40GB, respectively (top-most figure and bottom figure here). + +![a100-img](./images/a100-8.png) + +## On-going Optimizations + +* Reduce uncessary D2H/H2D memcpy + +* On-the-fly fp16 to fp32 casting for CPUAdam + +## Tutorials + +Examples and Tutorials are [here](https://github.com/microsoft/Megatron-DeepSpeed/blob/guanhua/partial-offload/examples_deepspeed/offload_pp/README.md) + +## Contributors: + +This project was made possible by the contributions of the following people from DeepSpeed Team: + +[Guanhua Wang](https://www.microsoft.com/en-us/research/people/guanhuawang/), Masahiro Tanaka, Xiaoxia Wu, Lok Chand Koppaka, Samyam Rajbhandari, [Olatunji Ruwase](https://www.microsoft.com/en-us/research/people/olruwase/), [Yuxiong He](https://www.microsoft.com/en-us/research/people/yuxhe/) (team lead) diff --git a/blogs/deepspeed-offloadpp/images/a100-8.png b/blogs/deepspeed-offloadpp/images/a100-8.png new file mode 100644 index 0000000000000000000000000000000000000000..22b787f69e1e0c7b566ceda0df6591895e804b18 GIT binary patch literal 22879 zcmeFZcT`hd*Ds19Aku7rfD{$PP?g@hR4IlUxE95L&1rAP{;D zy-8@HNC_C~iN4?Se9wE%e`lO~&lvZI5s|@&h%VD#B_{l0W0fF5_`2i)R8b%*>%Y5BXk4+A*OVtBs)(gHeL+HK*Xumf zdpa>OfkvYzCnu+-rU;+^+>{{gAP~sh+}zsQ+JUe|bvwoMxVpMRl7>A9Tk?pHk55og zP*_;l+qZ9{qoWfN5>iuBv$C>2efm^TP*74*Qc+P+^>wtSrlz(H-O$kR_3PK~-@kYE zqI(G2Ai8&AbC9sj{2U#h7#kZSJet76+}zy4!orUqKQI`~#>U3x=H^!5lkM&8ot>SZ zKY!wIxZ~sFv$L~{iwm`WWh;Ud*5s?vtnv%SpkNH-UlOk3k!xFD$;dB%DRPcJ4$Kf?$fi}02#jR`M&vS$QM;t z4ci+H>CYb8t?aax6m-o4U-7Llk_t$riUgNYeJL{#K|uCfowt$%7Cy9~*dzD2SmjZQ zMsM~yF*%U}F^P#iPoJT|(3btX21kbbco%JoqYW>amZ+H}Y(cRv12VFw)AGW4epf8# zNOa#9njuC>vA^Y7mXmHFTf;Jq%SBRXn72$r>w}uJHF{9%r^gOpuPNv3UhO0wYhqF< z_E}HU?E6&9EaDBmGpT06XKHRHK~QuM;A#l$2&K$YFL!In*lV>siT>Sc?Ck)KWw~q4 zD5)o?cQ&_FeXkBDtSvWWkiE0)m_$sLH)@u1B=RmYo~_wso}1aXy~1O6V2uf zL65G-O@2kp;>-8e+5LPCyuBDRk1IB`UmZb?k|sl>2@JlU`M^Pe$X(JsTq#ddA#g}J z6BkDe_@H*{=^hbWipT-uwzn$EpTVylIlc(IT#(o6`kuJsToWJsd6+}^=c}0MrQF;8 z%MZ$v#Ha*!M2};z|J5CiA-nzZ_r;+!i~K{Zfp!`T)l+ z#bNo=z0+M6S?_O*30E+Mnu@}~6IAdSWpE^k@jtd+^FZ_=5$64YaXIrG4 ztpuKie0tANagQ%jiCISS%Zz0EL<$=8k`<5n-6XHPlC@JmQe;UG+wHNoaq)OTO|Fe3%*W)b?4ofalC1u(nrI&<#lR21q0;1`}rZ0P7YD@BfS z8kEg>Hzuw`WH~)%>#cT3;7=WR`K+L55FR4aN56j=%OX8n53ECqSY%}myrd6a(NwG% z`9ODjVcEFHikKzg4s%N9ZAYyY@5l}??v@ukoIO;iluX_F%L7&4a-)GH$%92B_$jWr zuJI)8l0z6!Zm;+do?W2Qmjx!h3fpS798<(%h4k2(u*L2PxV1%Ox6}rl1@L? zMoja%BFODd?iUxvWY6gL*Oo;JbT@To&eFOC=>v zIybxU6{o#(awz|&J1>wPtF|c+A>g?qRz9)2E{^hxXzOcYw;MdCRx3ylgvRh&>6qCT29RNn~xnr0;Fy%p`OA ztWNe1J*pQtRi@rA3q9J<+8iwQ&XWdy6%}4Re}XsO-wSk+nll_l@SpSjnh~@7xyqzLs9)X29 zNE{ckU0-J2yD{=;FKS>q)*v`lqMerqt*;y9z@+Ac2~(#OKAzrMidz45*qJmbTOf=s zoGf}L#eUC^FE*BE+jiBoEkQ2(a_!J1FEYRmtp`nQF-nyI5FB<1m|TJOlGmA8_Ucy$ z9}hb}($-(ySY{U>H)Q0=>1CeTGpg{fk9)XhS;qgd0j+B{n-)UQ$P3b%YL?Qu$Iv&K z=a}YcZi3(W+?@4q6EIEUm|ees!iuOOHXk1 zRDYu!KYycVdMJ#4U~*4Ci971(%j>gNf9R&~xdb%8lB^pS*>j~AAf|64bQXWrhHvp^GM1Y0(P z2|st|)_6i^XO-|G2)tJ)e?u}UrU93mrniLqfoosPJrzg#73}rw`-)L>bXq|6fN8zdk(I+vc)@PRUyVN3>Zdml_ zdF~ifK3!VDu3+KSJ8It;#6(PW_QN^Po%E5}i%h?eqLUc3`~w;_(Ka}b_uDaPDIRjT z?qt)LrFYz(dExe=fxb99MiklP)CfI7b>=<_821o8m=yNUz!1G(RXo9)XQ;h8La z1k^2~#Cc+0+C&T!emZxFEhFadi=Y}D4jBX&G}~X18>aan&ez4TmwY8&q{Fau_;zT*u=$FDu^2CIp2Pa)JU^S9{ zC$6$E!$_=+XH0+K2E+h+MkJW@DxJ`O8I_Ie=DA0EDo`k0c_>b zac(OD+T0k=H|0@Gc9#c6yVt*Uk7b}LIk}WtCs1v#WqGVb>tBQ0-UWu*XbVwq-MU(H zYa`3f!ie|I)GbH&*Wu+liJK*l_-JLH*ZJ@UR-Lmwboe!<9(OmErjiPC+}cw(5T5Bq z2%3OOaZHoJReo)w8c*czxja~x`|RSUjqAM6nQPeP{QAmg=HZ~0jsZ20%2!aa2i3F7 zNQB?X+KW4cUMqC;WhTGx`d|$#upnuICf&$(R!dpYFL(sBH1uoxBVt>h0NC)>K=rBD z^MeXHE3U_8x7zM|{jjL$E*vwyC;7?Uo<8ao*6|(G)Vk?qp?Z&Iq#{AUr$GHL&m}AIPv#xnhmY-K(=MPoGQaIwm9v2Qk^UQqM92QAnFO>Y! zK=w{^v$LuN$@PH*UZSbLVqT1Y0>C!8=FER(kGnM~qI~@lU?*vXWn93{jN3_VQZHmB z7W}Z*G&OVH?-V8N@nn9zZ^_RQm5-V8@EzpOPQjt)*juusc~WS;Cuu@;Du>tmBBhQs z%uKub(>UHfwA#);n-jr1&bq^QW&M*1Mucvd2)hY(K>QG=0l^g;@dnqftS403pI3=> zcNPhb-sJ_z4r!Im=_R{xvL57>aE+AP)3?46Hg&RWhkto&!eh9@ue>!bj`9+s&WO0b z#UZ@v$NY^uK(%Jgz@yJ3LaXX7c-*fl10l;VoP>Jx&S`#n=P+^2Q3at@4`F0?)#9BW{adW>IlFsh0g&Av1nG0q*vOieQN*_`UhRaMZAMQtaB zEU)CLk)#9Z!_X@-@JwQ2!jRTcAHl&|ssRh?f3fianglKNt*-tWyUfcGtOQ-r%paR8 zTmg-tQNg`Mul53r_H`q_9d3T??yRo2?n{SVZ{jJPbdC!Z#t|2DP911r%SD9UO7rw_ zW#0Op@{epXRD*Om!yQCS9HnQjm7c%OGCeEgnSOqcPSp%-TuoyiEqZ`ByD^dm&^P5~K&EXHLV z^K?yiSul6z!t4acg!Wl^CpWL^I^{I4Dg06%PH}a?DPLJds+X|eKv5PQhH_C zo+9vk@EhVFT%YpUWli0a$S;z=L${+Z=b*`f0rBNI5nI}49&Y=e_0y0U;%QA#4flQO z@%rNeip4S6Y9wUdw5(bU@tp0R$c=pwCb*t#vEF^;8?#G;ZDafnd5*@b2T*Hog6OV* z6y^`_ZM_UarBi>4o#9INpkc#D{X|aJJ8r((_zJ1YCcGkKTqS*@-r1|;#akGLiI+Bw zHHYsY|4ZVxPmM;457Ltiw=h%3w}!kQp)KCDu=)2O)+Je6NkiO@~WL_=EbWGuLuc9AdS=c z%WVt()xpg^M`hkYpu1V90KU)=Kh!j?`*EOj*1Uih>N-K^NcBfdd`m4TH`Xi(s==n! znIPh%zcQ47D3rc7<->Sw@}n-w%a*OsiHeq*10AB5tX;=YwOO=Om`8zC$~&p+4h&cB z-7X%>KYJJqqw#G2_`wiW^vM9XquanuIj25AE3hnbv+lOjMthuXNGjE3Wb@qnQYSwD z9@b4s7R5QS=Sf6OD#fDDTyEA-6oI{!qM`Q=9J~q(D5w=#U#q_B)b-LpRY1QUabco8 zA~R>iIevXqTR@91ND^X4D_O&SKhHU{V)0sk52p%bat$(mCvTXlTi9nT|Aw;S;)HcENJ|FWpziiW%TXlbW*2x%FZ&5 zY-n_6D_86zNKZ07a>s`fi7qNL+8On48%>PMde7jaGsN!kRn<;0y@~h0i-OZtKjd@& zXup(tuO#=B*@S;mCEajRnrK*%+fDnlFF2P8azYG~xvE}_GVX_ZP&1z|`s!lu;U0WW zLt50vH)*w+IIs66#Pbw4jk~_NGGueBa&mcvAYv6>6HBT1X~Q-E$obidzjOA{nmAe2 z;vffjwW|8@QpdDr*}Xy@*7fV3Y(QQ$j}%^wwz<^B+XujP1;|9FWB&9_Pa1;*6QQ#_ZUK5}LHuxQ$uA-vhHl2r-1)*m=#F9oNi}O5sY;@& zgR%Hhxdh_z2UYC!6-*%{lUS)A37ulvpRWovZ^TIr_wE09GjSez(~O}dCOR|PZ1T}{ z-qcH}el_N2a)%w21@HL)4f9Fg-Wt}pB|52c^{C7^YkTRpa6lOvbhU4y6CuIvjU&mi zqjn%y&yL^j*&l7UyM1TXs>fV5Yr{EEt>%ICI~q=~ctO2tT`3GReHbZG^>fD3OC?G# z^MZ~zy*{r*7*jOyS=ReG-_%3Vw>q&;s@@W66z%X5q+Z**+;M_OuF(ACB?MLO`^w!H ziiv={dW>2rN{j2=Yk~bi^X~F2PQ7pBQbDA?@LE`9Ktb9bTgyUUNE{uHht^aYR=TdT zoVHShy@kBJ=tM|89A4>MCtkN%q>wOGBfVh6xRMdRkaa&FUE*be9HNbQWe#u(oG=_x z-uc?y7$p0{3!g94G3}iYzMA3zlPfDfk}n2;7`aiukwm$9?AiDse1<6>o+teIThyuJ zwS*Mp#ym{pzwvLZn$@}E3)XXO!UE`t8oDzt3?umkSHyMG++B2I-ogUk8kJAb}!&~NnZEb@}-<^ zg|UV_q-4E z+kSnUr~ZgLoJbmx@Pn4G4G5dQPI*Tz%iapFB$r8)8JPMsJkz9@D7PJX=t0r&yX53e zBf694<`MJ7&r|rZ|q_ab;B6@Sc4v2 zfdpoSEX4AUe5?mw!!e@~MK?yeNGc*pvVWVvdOoah!8|1{{+VypR^wGhrCtL@BU(qs;pVJ>$VVngm>zW)+<9h>Csjs>eA#?I|K+zB0+uh$w?Ro4R_+2F9Ha}ke=Yx3 z`!qKg(fmn6k1;f7Y`9t4nfoE8(3c0}0|C&efhX5)e=6E^t}x}>saEM@iE7eqZ|t=Y zX$F8xo5eyu%Wlc2^|?2RTQPT~y}r8A9BOJoA=iA79={aWiE}TCUwMwq*g80X&(P;f z>JNcKV9@I9an(0>jPG1MHwu+W(+?Co@Q_!aQF~N)H(YIUW*fI9DMSumn}UXYb0hTh z-JQkyqPw2~wNp2$UsOBNT?4OtY-$3ESPSOwPj@Izx<4FS;e|IQd$*cKO}VPBQ_9nOw6UG^EFky3iSTuf*IC1Pdc7=mFShJ5j(!h_7LeaVNtR zu0VnOM(?yDz1h-&b{e`e6mB)Yd$bo`8o%;>k6TwTN`BB5zrUxgR?++z+KYZS_ET3- z4y6olt$B5Q?`ON|qa)?n+c!>;*IOgl5{b8 zFtk0FuXC|o{~6?wAkEfRVx>Cfpkg~`BU$-`49+|CBtD!W$1{((Z4<=FfoVP@_r0>8 zVQF>H@R9(z;nUI&8#dw&e$dfx$&9-^$RCS8zV=^Ty0)NEDJa+cNN-3&t4^aXT9#|_ znjPeGp{;!r8zf! zODJyk$pMPpv3CBJH|-C{XAXWc6YStH-v;F8-BT9V!GJf!5zilh6tWoJ5KwoH1%S28 z5_Ib7N(PSWLQd9}3!|FTJ~;jqVWRzSFya$!EEjKpTOTS68~vnTm;hLryMLTnd#@;8 z4y)$)zHN1#>6dFl5v12!W-I=x@Vy7GA%PBgyZoUxCsQ1g-i^`O_E!L81jaqc6`m&E z1CGbMYCRtkaQ}Qor@OQ{cQ3AT#u2wl`cVGlFl<&H<)0TwQ=G#x<7|K+DcA$CK{M*3 z7*V~7{{CbBysw`MF1S5PItM-PGh<2+6(I7}+V}X4wbqK%^93~p-Y75*hR7?;& z`*j>pmoc#FeN<}ON*EjJi%sHBmUITW!CQdZeKiMBTmeThlyv!o#P0^HroVl3yY%QU zCKOIzZF|RoxG?+K{H}--z()D{5P|nVbuBA~a6$%AMBrVIQbbFrO)g1_02|2dhd8gB zRZNb5DhA{;ZNnI0j|lqp#0}aU97CzU$8YF=84Js+Ut1=hRs-E#QEif&zAn9Pyyh-y z0o7&uPJl+Sq^#JwT^>Z`xSm}kA=ub=%aufoF-2TPc4Y;&S}{+zVhAywar#~l;6>Hh zw6l=USVOo@S@}TNH2O1+r^)v?hRlagGS($s2BU!o2B%Kt%^QF7N`+xU89!--)rh`M zZxCkG-0}Xos&M)lmUrE}R+kX?H5D_5b1%w5|Q zQ@?diVzK(u02q{#T4Oxl?p*OCkL1(Uqf^gD(BCQ-ueX+D?f|~g`(B^BA7I};LqE9T zQ#?f;=>90~rKfEhq53-^DefEmIQ#ARj^SH2be-At!DQ6eY8uS3ANkA`7}GeV*qWrw z6>myR++$7sNniIbkCjmOOFv2zksL2ipQ(g<`>K}g0zs;vp0HWsrYSWxj%waAhCi&Y zvC#24is-D~O;Ga*^M|X>j-oD42Fm*rbj}f6ix|{?htdIPz^9^L7SW2Hi}(e3=__fV zw{GoC6u=_jmFxqfs3*AIjW|@S%G?Kb%Ba*tK=G|~rL9QCP#eDrR|ENGYl}zS;%Dm3 z?WM(WLmz4(AHH$DxU!;8;t|NaY7_A(F|o+BW#<=B0D$JIE=E#(pF~U!wJ{O1_c^xg z5;~j2GI^mgywmO;xanwP^ls{vJvekO=TDfRaP-9>>CLe@I#BCOErTRO!l6sxW0N#6 z*|OAbXqkjKk7vGSp2T@j2Hs;|UMXS+dMJ?XkQ4Vi3Sw983N&9MMB)vd=`OO_2}zoK zEXkK7lPIH%7JDSWIAH{#q=Dh~|6C1RyqV#T3q-QE-7k8^AjurFvXV?kZr2R%3 zroVlX7}^NI2nibE! zIT>YD@$_O*>fx{P9WkTztik2^z|Io~HD9?@iGYfQZ;mARI^E~#eP2gT9c<=CZ(^9t z!lhBXm zX|fjGJ>Sxf9}pH6=5^QRQy(dzgH$dOoUt5sqB~jDp-Ik<5rpa zrN(rgss7gEC7pQ%nx`8rGgaJ$9DC#qUoDq|%(qO`+rtfD&{u(90_6ri)OQ;mt@Ot& zX!j~Jd+U}<@V>*<&>r>AHVzt*c#R4`fMkE|gTDdrHfJeq&*6nLuX!K`Bc}M=s8nBF<}IjZ z8nxv7tqEDfe51|fyXosmnPd2{{U}o2TNBd{FkGFoO*|}s?h^KWhKJz9=xMz+_muQq z+Sa58mc^8PcX>Wz{b{{#mX4GA5=x_TR$ZL_@R=Nx!MpA+*QrGBr7gWtOyr=7N-Z=O z5;Kcxa<@;f()bqW4GHpcss_a0hVASZK}KR^$jEk;i|4rQ4a|{EUTGR z9Up!d01dfd-LI0tM`}4MbTx4{D*kWldn6K+=Pci_#zR&ygKuqNpNNpv3Qm$Xo56oc$I@St2MoZAM>N*r0fudvB`x%hO?(MT zG5VrLc<+#W4!H(O9<5m}{rO!P!rb?5+ezlTUm4G6MTS{TR4WUluMb;~aKVyHeYUR4 z-UpK-qRq24RVO`8PD$D2LvJ&fC{9|c5UlV{$}nHL&t<_k(XM!L>dbgTV2coKnNN#< z>khNEYJ<-dZPSgW`QdEG!jXYNTFQFO#nSI!CqU#v_nsMpw#r-Mn!x}^!kj#7pPFGk z(_VuHdENOoFv^d0QuYw~A+B9&8f>`i-7)RpQS5$mf z4`j1my?aiblOY0f7elH`=pvVMkH4n1Bb>|26cZBO=$4Psy(;P*VXaV~hMtQykt-p% zsoyh{>5cZhvU3RPVt8Y472C>3VM?W6smvy^tCG%Iq`}}yT4Qm5uXeO@hhQfW1eX)t z`vyzPc@*|-rg6sD``6}ip(Ck))YQ>MM?eskqdaX4yF*pkh~VpR;D*04S{b^(7k6d7 zNowFwF#x}O5%BWnZPkjd>=Cm0%YXtfUXS{NQMhS^wn)!v$5IgXv@9nUEpQ$if6It> zhd#1EH4qzekhiD2#306Foo?sql}nr63W~FFOy7IiQC?-ZA*A$U@aG{gQ%R{ipX$kD zXWu5C<{Bd@l@_0?GAM)zPp&s%s>JLL-AX3}uM@BdMMmnum1d^>%yvOCZ!pCP!N>0N z3EF`&=HAiCLdAoAO`>rEa<^|}jZi*k)5pkH=ibyn$wuC$`=*Fkup4Z>A5m8adQk0d zTz*6fBz=2d8Wp5z0pFyekMwv9z-CTfGE4Zv5C1fSWXC^x#Cnx*8QJg4lySy5Ps^c& z{N^?_&~Xcym+&{1xc8MK{xa$|r{Xl4n#I2Mh{Q_IkddH3G59B` z$sJtcnK0p+u<5bI4ArM;-dv0^LLFo?Vt2I#j=4z#K!7v?Dp|F9j0gi zZ{-qixGu8}>gKeI`Ijs6a$G1O^PXP{SRB_+7lKwY2L=^hmmdA}a#scP#4&~YF^3;@ zFTcL_Y0l%kME)1GD#M4c=;oSktC}nvH2@6S`Lg=-710Si6W$x}x#<<`k_1TFOxRpz z)~ZnbDIJ&0=uNGcqWj{)U@bCW+u)zUQNa%mo8U0^e zT;&TpJv*9TDgUt?OgOj@2xj5vWVtB6z#q*s?XNzEZ0@dJhOdwzbl_6zF=q)E(-y(Z zxCeooyBEhh!DrPXf36SalXDDuOz{+zdj(KH0^lV5y?4t0r}_VwHsE#37qUchEU8iJ zK60op7K8*80l?&9w)(`WZ8bNzz|t4zQ8y=$xSW|4d4OyUQMGpE1!DTx5Yax%Z4 zsLF@w6X+7zYFy#wXv*C-r;F^t2LvH}CY%&^I;##oLKr9C5PX_Cioj9CGO(HF>CFXC zu*_g5_8mSPa@e^wRy`qQ2KtoQ>zcgs^(ZZaZS|>56wIW7z`@thCj=(rOi=}SstQu+ zGzacrnfb>isfv93TRHXbuLVfTuvhI@W}okVXqlV)=VP*y?vvjTamt7rTitRzEGR_! zQJy8ap2Ot(GHdn^`3UP3-7KGd_mpB!flMTk9IX%sUMZduZ0YMUJM^L_TmQ#f(+q<(?ObqyhM)DuchSS5uiCb+Cv;A{+1%YCqN}~{ zfr&zhtp)Ob+*1T-Fnb92Xu&wax%u6_hvl9>weoS5EGTvTtwWNX*RS^-L}34rW0@1v zIm}~H5-^sqhep1(q$&wYWZvT1Ef`N~X+n(`T=`Ka{e?RVCOQyFR@=2O>>a@UmxUuX z;f*(+OWu< zELbfQMt?B5QtQUKbf3J}mjCo3JmRia;$I6$7IN9851-UPSD~#4?ZU^6% zy5gPmI`1w;trap=qkD1l_+Hib$rCfA)x@Uy6bV6+klzj;(MzblRV⻨%D7(&v zf?NRQ%hsBjAg1=d@Cecw$x-ZB_LNkbNp4Z!_3d!5y`h>aYS-;etO;j{lWg|OKJ>uc zZ((HeZRP#_|J&J>HDb^)}Y3O3~Cz=8tIpn&B*TgM7lUd&`~semM7wnBsY z6ma`mtz~IELJW#liofo3{(CT^|1G(%nlEVTg@C9Dss9r`u%ii*7vqoDkuFf56S%uy zCX$jtx^gw>$S%2iQERMsM06>N(m8oT`IK5y~c-%fGuHx-Z?2>mOy9^p~-)F0Ap z6y%ajBk>O3?l|wN=9{OvP(54aFDighvD2tN`NC_EujAJ218U|g>x<0iI^tc~M=BTn z%3bigHN}y`b7X|tq^^})mTBUnrtcJ%Yp^ptX^VWxLr0;pLs-T8;oEPwH2{tm<vu3yqE%KnOrN}kBz(^ z$iA;0*r@!joE21tNd(U}L=hl>>vxA+X2%>29;X|#R&l+ZT9}E_JA|L+Q`acO(39stQ5sfYwU;A;h+z$Y2c!*n_ z0(zMQLdPKoB+soM5+8H3r4*R5SP z^F%b@CFTkmw)<_TH^BZW;|T#Pso&Iw*)O>eQJKW)lL`IGTsn z5-gB-h)tj-beh#nv0i~Its1F>5*WmuD;Eg)1LMtxFr}rT<^vL?6Cb{flueJL9_%SX zkB+Db8C6*XFW3ckw7ISdlDV$jYeC$|Fel@dj>-%Cyu>x41JLsQqb8)FEM#|j%D*F)^^2>GZ&C(oL2b2YMG-f*c$ zGv(#<%n`OP;3X7({!Ifhu=!^V3ZY`Nr)D42(B1)tf5WI0yrftcZyisu1T2%qs!!@6 zw{GZ3oQVs-LBVo%bZOsruHnNe%ODuzrL~vz02&Iz+-*zv-4cg}dDpnJm7VK>owb+$ z8dbX#Q38{kLSNNptsULy#v9&91F$!lbe%p`mIZD#TNEmDsyO+(;8~VQUDRVo?ei|1 zXd|X9G$+ZAa=AD%dso4=_0kK*)3_!X2ZG%f{%Di}fIUwCSYfo0ZGbM0CJZbfjh`DW zsu?9(?U3;?$RJv-=KjL@8o@Bv%bK&TAqM=xyqs3>5^Ge$F^rcWXVmRKFr*4P7Ap9- zj@+mjX9F`?pp5lv{F16qgrzvu?H)djur<_&zSK%)kTTHDnUg1!XvS8EvAxa!EN6qV zg$cp6?*RTdOw?O0SbYP~#L-%8>`ShI-g z&vVG=(%ud)Vv?!M;OAC-)dTQ);GsE!;9^8~e?yTC8|LS6Iw?aQrq^xHL2a@z6v;4| z$2>A3H&C@PRTf&*Le2*Zf$) zFl;+rd^))@sI!Ox+xuxhc?xPiEF|#QWa!;M|8#0EjZpMPgDO(gboR>+9^9Hb>NwC} z|7Q~X%HIHFisCM0uHT`p8EUhApk#ttk&GVj^sGHm=)4TAbV5*%#JG$m+;1+_FM1h2lHV1tf{0?f}E^n=sRbhI3m?i`SIr^T=uCi|3=%w zyplZ${&~*D7edaq*leNqXFiWMUPpHj1IOifr(K`%LLldsMalf=jk<1!55<%6hD;gH za-NAVsI}O~4o#_{nid^%8yt`p4vQTLmCeC^0x6!xvq>hEgcO!Y?@uhr3*LH8ZRL#r zxR=)JYV)yF0}A@F76R*=cv8IXUFr{`pi&UqVe~C+OuPDQwKlS2rm=cy#`W6#Z`vDr zE8QJu5PSf44(|nl*F{=W)A_Y)le!E{ba8toXT!yh|E-4^q4dGO`uzP92ZfX(;4_RH zxaud-V=Vjj7Bz3YRwTr<=AIL%8jP3v@XSuD;<3iwhR-Z8NsN~OthfKh(rTC{^s$l| z!+8l~4ilf`39MMZsW*A=)7FAW!ETvzgmO!@vNq<85<20Qx+ZgQNwJ{ zMf_jZ8U>1&eZN;yVa6YmY*%}jz<4u6_LZ@Um#^kXtQXUvjUR3zwJBD`BnQ3y8NqK> z)*dKsW0}%xzn32bjX1&OZw-@-$i6A>nQBIu+UpLAB{i>BYTgq;ZtDCUeHG@JzU`#k z+^gWEpiS@xnf6h*(hbaYm5hy8E%F=UakzRTsk(bG=l8SVM}{BEJ@$_;G=ozc}W`6?x7tP5c-j zQKbTgpS^;Oyn@{d`@R2Q!#w{_^FP}7|B^ERP4_OasrZ!O^VJKhb1C(6OwPqic;-!b z<{kK5GC*v?Qo=<}T*5^-O!trB|J4U{!(fSjHA^RWa_Ca_8zN{BW((WKvOIK0-1z;7 z2`PR${`tP4c!*;h1Q-Xgih~f=22sT#-YvFn?t1g0T3j(*H0~zn-TXyKWR@XXu) z_0X5CIl~bsUz-$FyfXFKbx{mIPw)3KHq0m+2F`|QXTyj^#P6Nl?`kU$pBoqs$XkX^ z28vtktn53yC|f@rcC|Yf(axi$Du_2zWgrdxi_^Qxf9e91TY@1wo4-(a$3M7MHMcMdf1j*$LeSR$+AnH6A=pw^Ia zrfdeyWyoykidy@RHYN7hA%gc!lSo(Ue@K>au({NTT=Dv+IEm`0DJIDY^F4=R<1;qN z8TE|_gD5+#Gko8Ha&U_UUiX{(_eipTzeN?r4T4yU{HlcIcumz;8%R}q)KE-97vCU< zxbU98HjA%1T5=wIC6q}1Td13Apk4u=O{g5oVzm-G>r88p!B+`)JC;H*P6(vn%bQV? zI+qEB63WW}&Oar;PPGRZG>|7z10hU=Z+rlDxu~fs>oKcQu?cSK30?Q5hnuTEwpAqG zR6sqx`j?+@zJrMh=y1ck`zqPKw463_gG=+ypO{!!@K*+cYwngU>ogfdMQ?@;ZtZCT zRN90alR9%rBB1Vrl+`?)01spW-nk0wpHj|60pqbIeK?R#TCXq_SJKK6e)I219#(lYt z%qLg3xE$1wf-|p@6KBJ{1apfQO&M>;?N$5@4tLE_IaY;&qT0jvbVuG)T?)1jD>X|= zX0CzZDEjL7TuunN+3{k(+VkW1?#;5hQMf}Jby9tgUs$V(L10L;RORP*GOwgxz->bD zp*ANO?AJJ3+**3A#i2na*qkrLTQiFZz$&RK4p0?;qAK3Mm?uSFT<6j)ao6AA=s}D6 z*-+7wmw_wrhHV_tZ@it%*yZdIl*l zVL20MJ)K%^=PV&O-ezS0n&n1dzlE30&c3!wrhkf3q4^tJ>;YtG^h}pl!GK}7sIUN6 z?L)&Uoq$rPi@yvEFXt(I^Z8-n|8?k82x<|TnL zUVnjWrlJO@I4H@J638#(4FOC-ZAGFy%<^m6p^rTuxYIZ&ZO2toW4x@ZZ{8D8n0<;; z{DzC`-`yBB?n!XgpL+mEW+q{`>*Pn(Qxt5HEp2 zXA*_sph0(|>_YIKPLs&nfN!m|U)04V*438pSlH>dY&Aw8<#P2!6L#1z@{4{Gm+U=4 z1*Mq)7dQWACrlkg|Irm%kj|jNw=^_mx<_6yGMxz0YAcj0YPMpoK75{KV)~J4oeNE< zI}GDZurJ8~G}mhL1J>IifL*~$SMp+D!6bh>+`YWUA)K&D0?_--=h+7)`boQLueS{9 zz}kY3d!J$+0uNIu10AEZ4~cl&eW+&t2@`jLsEW`w z34r+P=>W6ZAI*`inPm=>w@jpzFaW{Nr}S!F(f1I++NhShHaUhlq#dPTXRzPSWGM36GMBSubq` zm2oQlA-17@f1JpPGi2{&@KY`LDffSm!w>wN3z#8)Ct~o-u`gg>-$_AvofBkU=fKtXNkxyFNZCp>&VV>d8{A4 z4{^LYyR-EW)9eP_8tU#IvvTKla08xC2gr3oX?vet%j%Khh6il^%{Uxxk~PzM!yay zZhSSGwAh!$ujhi$w~64y@H=cm=a0S=Wqjez_a-hsQ|APYWlBoVNNONOXFBWo%-pII zNL?gDIJwR8!K^?Xy^iXq2Il>lc*Pecw1nq-RBZ1Z)mMKvp`>;voZFD_fjA_!UDIIqvk&J6pjkYWJ zvndGY8d1rtDlYsty1sx*!N&GpR2a?h2cFS2l5Ey{&9UnvKzu#&HnATq{d|G}E%=+~ zOC~4isfPDrB(636Zjo)^*A{&YWhp^HO&}RQIvU*e5r&Q<`td&qYncPXi9=!k%wY&= z{r{cj@AT^^6<>aFTz(OY-+%P`=;9s*_OCQY)faa6gyW*N$iA zCT$_ce^RPHS(TLff13Z>+BiQydAW~v-;*UA1gFWyCRoQ_pf65Oc5?aj@A{u2i3;(p zJh%Vz_L~29U)i(w2yevKXW_e#G<7As9OuN^m!0P8G=CtGh)TkFM>OQufeOmndf!w0 zIUft!)*@NTBEu(Jx}>-44z{xgIglux%#qOnia+16<+f(&=SeC~~fbd_ET-JNgRJzj|X;92%y z0hxzz_SsVvAm3xL9d|%!(&@8L&Qocoo_Um5m(%x$1*SQ*dhUPVAZvSS(Q!T=FldXW z2!W||6s%NI(dVqFnwJz$1$2N`+pxjLK7wQjYcuJuX&Sz|vWq=8Vf@5;V4Z z;~RZEV5X4Stkw)vi$*y8LiclsS~mT{*mFqtC4M3x%dMPJ6BkIjb)fBRbBA`l>rW&; zPSGO6n=5@)`XwM^BS5|d7Dmt3B8i=B@69i#PJfj*l9@M?YW>3aP4rCgfN215dEx_x70AmIlWw_n~O4i>9Y*~@H|nqehtBTdzbg_ zh_XXncZCb6x+X)Pudrj0Lkh+5FIy3X zyM?f7wftgoHX2W%!_vuXp|BTZaJ`La$DyaJd+A^K^M-X1FMJYwlr=^>mtLpTc_hCQAQO-F5 zO*N!)X*zbr<$n_pJ%s>zW^$i#4or4gzApRL7OxJSs+W zQYl!jCA=T_kEV&Oi(WN|DmEP)$g?hIb`s5_JmSRN^^ea&@HCR(>vpiu0V0qjL! zwDms$+G8yP^-nJ$j=`j6as`}S*ixo7UcnpNNgyLQ2{y-ym>R8UFLJKIN;28WH_)y@ zKFEZoe&J#wDRAI?h5~lp_@8DMvF^1I8)spGqxd-_Il=?YI;x>k?PH|8TDx|2>RQx7 zIpa_6zHf|nSeAZ6sD!D{)vIS)hg1fV5EPZfsx=8rOo@6Fp6ZfmyH()G`=oOKbbT6i zc0b@onF?C1d232YDA9qmp3tli28tuJw2aM6P|EebPlg3*-rS;lOl#qZzbJ{u%!iUx-vS>rVd0}C;of5L??GiM#nCuNhf*QL)d zn57=3DK8@KPBC%*#9+h0tgKQn9x41R-(AnfFX zTupUX8i#_zfX7cs>S%|Sr4E`b5KjIW%AvSGOSm@H#zcb#>j6mdlb#tM7_XGtQVwH&H1JJxeR zBOPShZp1pZSleHkJ=)jI3Op{bE*n{ey#7T9+-TB?Lmw6{JJd0a*tx`Y474Dl$Im_? zbWDx@n_c!)tK?`ADF-+2c*ng74)=b>g$&_(uf}QI@9<*%v#LkuW}64uhaR(^@NLnc z2|i04kW90|MHlQkY}W0hyx0qKPZ;rk~xzbVyxW zYxJe($`q9xtcD?m)%@c|c0I73A0l$U3NHIj8|o^)$crUR!Q8agiNy=S7_pvx?dcC! zwz3SY55!H{z3og8k%f{N=E@W2Jt$Y!tHP)nnQj{X%q(qs9_Og8e0Ju)S{bcd(f6B9 z3Zb8@dC&7=@1*B50Ka}A-mG!O@^{gXL&JULclLW_zGe|dDE3}|RPAN!2PeNO_7R{> z{ciEN3L2yiBn#!T;8$y|D0q>i`Y`MFa(<%>s%+)Qj=R=fn2o~`JT7B<)M>;S1OR8B zcyaztzCDgTgkQ>;E#&J*ws)tQ5-@^izib}#AULO7|C-6{oxbgUiS^|STP4VA?7sTA z8r?2o>DEaLv}9CLDfi_lynLI-s$TZAFK@ZnUM|FCfa-$3|2=)*q|M1k&-#z}nn*(L zoNi9si3`DiL#1N-?%Poeoy{$1~q7=UiYJapUR=eV~B{}_%Gp&ig0BhF5G z4`ZTF?(yQKqG2y4hv=(Ds|BppYrF=UEgCuKb1t%9=0nX!8o4jU1dK`vkEVEd1=hih zP@I%Xm)iVpD5Ynl6|+8=^o(($6A1onZObA^Em`0eJ4UW>fg4l%!q=`^v2Rti^*S3E zkQ(4)g*aLpKxnvL;fkHAnndyHoRz{PACPZ@+KC}<$ktuy4VzO7uqsxJf9tz2>Y=}= z)^m;-s#t#~y*a^`lozee|Te`tb$=8rgIWM$|JV`#;UI~X)NFdqaI>SqQ8kLU%$dJM6wt>C27o&sWEu2o=oqRa;7UeMFRNH2?||wi zTgIQ#6R`eetmj}^AEprXtLKh1*l!fnIg7L?IL;CM{an>MW5tQBiE@v4+13}WjYQw8 zDwRsp5g6Za`oWU7IH2*;1?{fiS5?GNT8A3woL?+X|K@gMIr?bnoyiXL+J;&lc?Q0thWKsi(e}gT)UtFmwk7OV!-*oF>*fy+y{6y?WD7{AXDQITb`7dK ztmH!}^*xzD5dL-ee$(74S}I4>kG$fQm8E5>T)XN@Csm>T%ZH!vjg#`afp&e{9{a zYqS9MD_^#=zu=FhB}m@#L$=j%i5Mv1@*+gsjxPeE#%p3DR=&C90pMuoY+Gw{^!y9akWi+u0@ zz31GVi*xoE`(`r+4J>-iHD}GLr>dTs)gM&e%3`9CqJcmlOnEseH4q4n1_S~-pdbNP zj~z#jB|sKKD>nt_0J;V@73Pt!IC9I zjy}$oetWtM2^8Jp@d@IiX{ALSnr>DX`@d*?${v%`Adwo(NeTsOC2n$PH2?gj^#AF% zq8aU_$FQG%&7-c$N5@jT=kl?6+;LX7i&xmlRKrmfYU4lbkX*1jEQ+Eu&eNQ$C@v|n z?Bw_Lh={m4Q}%OvZt1d?;=cpWCYDltc0*+XhBX4=Dm90gM;Z929^mRv-&i=A ziru(V)?u;vbQYZ&d+FM#f3?tW@a_&&-;^UV!judIf4{P=FIsAq)>Ko~710ojKy$-DxGf;MO z@$RaQ!A~xUaCpjc*HM-dcemxjJT_bzU~BO0X|ns!s3ywc9CyW`g=x{%E*`>Z6l zG~NA8vT3hxX3S?sy6J)fFteup6(JXnf;wVAvJrnJ3m0y35wL#~@hpeO9+!L@O#R$P z`7@Cb^plYfk67EWmPGyQvlDw~^%?&uCXXLNSRVZ2TtjQ7x@WMj>x5^x$TwPN&52)e zHrP}mW!_>E3ceQSnfnt*oE_1~)=>N|k6j~2&Q=e_d8S1pWKV%cwsztiaolCpx7<~5uU74((e4nPKn>pmEyC`rgOlsGoHdKtU3 zolD6e&T?X};`gMhKO^=}+0&Po zvC73tq0Ji2uSN8>8$w!}lmc-O#(Tjn7Pn_7RRX6ccXbpUk1q@N;urlzF#2uv145FW zAJsW~jRgiu_AOO#QRmnN$4bz>Wm{ zss_O$gAEhcq(WAu)eXcFXxqg)-l-5Wu2WaVrlVG79>hrBW^>Y$x8EHPS<|TDqI&&9 zC;mVHM}(kw&AfnIdM}uAkbjnP@~W32AuzIbmSK{rBk+AfJoeytK;kQrhi~fB*0W}M zRj2yw32fffZ(Jrxe*NV5({)X}(Y0h0hisc1Rwf)lOuj!9;wve%;G6Aeu<6M==bo1z zqm$sPO%WLOE508J4nktP;SSe`zx({tIH49s?T3%7li5!ZvN`FF0Zq8}?Q>{m8j7{D zJ)#iu8lt&4^`B7`mH1G!p;P5s-!H>-IB%J;p`&V)NN00xE98{O;P6BGS0C*e&ovRw z2S*o-Mx^@XsFj99w_`5EZ`nD}H$Id45cCAu??vR>&MxB}G_Z5I824?8o080;0*uh_ z>g9Y($bFEW_Baa5E=Uj8@Z}BYt&Gg%pQ`L9_iidq_e!I0W}6HD|v=t9)5Y zbS{yxe**nzM3D`aj@**{cQ}&ffrJK>wzFM*jWQ=6tl=WYNNPkdu|(eMABAVX$qrr8 zyWxar3r*w@-4Wt#zk3@)FWXM0?YCbu8#Tl7AscnA)2%%&CtHC84IBKJqpiF(TLC@F)6hBZ!R_!KGMJw;69%_@>v94T28^Q^Mfs2Jxw#zyk?9ToJYnGygCz|Ol<03S-5)_I;CfZ-;qKzOTlx;~B zC_gsDjS9`+p%oUIdf)rW1|S6@9l1<_Q+1C@&_AKtw%zlBLNGa`xZWVjDIYrKDdZm+ z>SZ#9J0lHC1ux$;Cm>6H5P+A7i0;`rxH}wZcPS^U+Z~g*e{T>%qWE+01jcrA+;kI6 z$oW3$i%LoZ2~+>W#jC>sJ{H3&ijggom{aRS)Bsw_L`sM^cj-6h`*`pU;Z{^GUSk(I z)D1ZzZWSgrcGJ@;NW*iykz*~c2a4LI99v|PQi6-1_AnPe*LGKb<{vi4Gkk0IJRZRT zgM2;YDJ^53R;+4(&!qY5za6NmQ)p%kmQA4SQ!sIoRBza>d!?*PR*gdhf`YjS+$I&) zxU$$O947GrQ`b8b(atu$*j)f@c+>f!N8XJdkiRPj5z`@-E0I1k8HT zS>7ikzX)v$S@q;OI3sy&e#NZ($v3c1MBN(Ucn~x`U4&z43u7ysnCh{85z_q)f?BN% zTGK$n74MuHYnNNI>UT@4n5Y~GI>UJUGi}|RI3(gwK^`zrUB{FfGxh4m9jt<(&4zic zM1?v7&fyA(F~pK=~^?@4(;CD8t=1O`r!Ii|wP6kADgY=Z})@L{9+H295;>GRyX z%Wse&IytZ9&h6)Z?+1Oy1MLKMW#abA)hEj-5m?6);wn0C$(6>q8?RK19(?9YiS^oU zi4=Zc!#5!Y&?LOU?Zg=ep`+Neg|Tmne`7N-f-S3;SSd2Amh)ftJufQ86tFSvl$|7% z?eyKR>Mu5`S)A`)l9y z15?eSBY!dR%z3Q;{k`^=2{rm&q|+r9(Yv^iSI8>tESul2zL9UZ$I3ZxuqlWjjS;zk z4MNM>l{D=dEJ2JMmT1O4=jHC`=vI~DBx(}x;h~H@`d7bMY8jW!p(zd0E{4lt6iIvW zTJ>uIdHnP1;tzCmf3Bi^&12LT?|KVlk#-6>bQwe#8=V-j>t#jx$UKt$;Idr&$G@uA zyV0JY%yHaGFEYLwb>G6DHY>%R3;FbVjMhM*H3K2D zm9NB7h0H*dZ2T0M5{V-u{OQ*IgopCIl-dUIMP~9(E^&+GEqc_KeNcoN)hU|orY+-K zC-_Yl^b2#;KuVWonL`O>Bp^XjV(i5kd5u+$f`3H@E}a>djxRKfkw2>_$|~q{L>dqU zvv9O$h84PeCF??wz?3LAnp!-&$w0IzAA6@u*a{aDxR^VDu5~1;FNckwx0A5R@GB5@ z3fb9cu;Ffa03*qW0znx;xK1aV6M{gX(q4H$Qjy=-z}fWFn0k*EJc7bWhkpDZw}~gE zp!Uy8^47o*H}oCM)V(%CsJ)T@xg&`YoD>3m7M@D?5t81)sqed^rtXuLF=rKp%KdpS zSf*sBE(#TFZ144BOtDIw8#iwV1$yGbELJ~kk(TLjIdr;eMuWtuq$?-Lb86w|peHC5 zi~xQOilcW#I)$DMo1m=H@Yzf#LxTyq##e-ComnqGh{$+moF#Tfufhmg_vJdFB8n=V-lp8I?vdB4J-~=Focm z&NOXKwN>fOH4)Wh?E>dXq%f-lMW;&R74n(?FTLuAd+xeogN;7&?~<0;@crKCxwcVJbo^mR?nA!_r?uo%cuq=|N@Ip? z#aT$%P8hO$f7c4z6wu-jc{kLOcQ}4JNL5IxV7Yb07VGSCsY)c0MD_}i>(_3LOJSxcX?wY%bD{x{XhVQuO?vB6AFs#j zK!e+y{8Ke8StS5Mng4;%uC?R`46L*|aV)6+aS>gPDfhBc%p3gSt}d~a$h!zn%U{># zuQyk7l!X{6EpQ+y@O!Kuf87yCJ3qvK=3gwla$HjFj;0hOHeS=(7u?bELSIxI@_LYq zBXd$cD}0q$ySId2^+1nEZ^o(k)9c0CCG*^W^x02z7!O*L$4dIg=~IxtzO; z4f*+VJmfx(@3gC&4ZDgeLvQmobj+?BKl9HQHaad%x{1;!LcUGAfGfe~e(PSHZb)%a z#wrIg3zjxk<(F7!D{FE-{_ui2NG>ji)7D5LLsZ;WRv3t`B$8T&h1dR+TsEBH8nz0H zT>T_^Aeya#ut#>?daNz|AmMIo`Lq!27w$MtuYc@$eXgvVvDsa%b9+7>f zCsT%`#uylW8F9DEXNhAPuDkqLbvinsa$N3~=bJ$Edo$d_ z5=SCy7HgwarwnlkFY4+p+e}B7rN;$;{yb95=4#Is`xd29%u$hY`F^9Rf=m0^d=nwB z5H)d&yZMcHq*$S)OUWi9Oue(MwWi(~hcAKa-GL&P20ZIQx9We;Tez+Jj?SWh42U5A zDFbqRM{RHKRS>)`_fm3!Rh7q;~7Dw37XzaZquY3{r(rIbad zh1+aPoJYPR%cq#0W*}&M!ycDK)w+mFrr7=7&Q^)VHMI7{e$G0N=P`7k2-2@gUh6zQ zD};LXI|H#;VhG%N__h9|R2(8~%bY^CQPKl`o}kXaW-xMlogT8N%{yn0Oxci;=$E|p z&S@v55RLqOoT$^D7q5p}IEDn6c+>mPth3Q=?D|~aKkM#F=Vf8wcSfW*&?tvZuO$qb zEIiR1N0-pkV6;N4=rUZ8rtNhB^K?ROm4Q)aM2ZnaQ5<2*`=q<>?4=~a;Pd(m`jm4H z9Es~#?31rORBG=fEV7pReTN6s=vgl#U-ni047pTKoTn8nN-gt!E&#{ek$Cs|O88A? zhZ`69!7D4{`82vLezz{*8fUVZ+jdX*X<0+pWU^Zpqp48B9ok2jHLRFp_bN1lNn zch2BK`o}+X#b0)E`JedQ?Qi%xB@`IAw>eU&RX~i9qBv~kGdIrkvVu?aM+LG&sNROz z!)_5~z6mT=`A`skbD4ej@xSRYSZcgl^0EUP$i7I$zhy||^CQzKA+Z@&fstEAHS+w4 z9$KiAHmH7ivhmzY=3#WGQf)H? zrcq-mpbO{{NBT-tl75#qqEwdug(YhI2D3n}-T6y-m4_TFdp+{jZ1^3-U6J`ui!8c^ zgX(QY)=r)cDL=h>eufBnQ{)1f;Dfm@T_+K{QCdh|oHFE#P96c9N%8A6zeZt^F!rRA z)vL%8Dja3wW?hh3`C-3CXv~`9C}n@h+82#HOwavS87mR+6!1DR)J|ucbNu3+*Cl!l zv;?erMlfF&0|{RP+p!8TjbroM{Eqv3N?u<(DYS3{Q}5Z{&iB;idKgIjI1;hng!U%` z%7^8`wJUPd=LGWKYOexK-v^`*z!jz~2()-cyhuvOrAms5kajo%=~eeqp% z8Qr|`{|M3zSM8?qHBJ5KCzdB)mNEd7ESpbIZ-78~jml zsi3*F=bJwN8tdp-JKjCQ58gm^Q(%#O_=qndcL%BN4}Tm*z=f$drcUU3)*#ouLF0lxLk)IOe%9>xDe?(F_5Gf#C-0Q#%Z6X((>?6R3Ca;R)5GU;;0vERQy$pElV2N>s2=-ndG-wX zH&@N7(9uY&FHR+4*nSFC$Z`yb&(6xg4pK?!V9nA)FuDbI{(NKPdc9yGlbtWS9M1-~ zM+wxBqiV&3joc77m8U*5ReLXoHQG`jEEiHs5!YY*=2E_zQs;1mqMDKth9f2pm*TU% zUODZ~>$*Z3^r35h(s&q5a@(|ZSBTIqKmYM?-Z>hCJxT|%l%F--9bvoc!V4Z?*-Rv_ zIUvtYXoX@LzQ$~9on0zUtMV;>ww`IW(B0WK$-?U%Ze%@)kZuob9#Ly$w738 zcOA8EvgBMVz=kQtx*_`);`7%TW(-C@9Hg28cJORmAHOxG8Br*!-axD^t0q^__yQH7Bpuu#Whut{u+?&y5D)>ev6e zdIYIP%digxMpf42JgDEa86ECFwKy&iJqv6zQ(mgUqyp_l)k=IVU#O!9f*X1m+~Y^) z8AT6DhR4+F9qTwlirPx{Tkoq797(~q%dZ-ujSZsb!>(WcA@D+F?!**_WT*GKReo)= zywO=SAgQ07BVXR#p*@qb{@&hr#`4wN&D7oHua;=0K>gDPj5j*;mfS+hFzV+geUKWq zS8Q|dOPi?Y9_uJGAUtwZIdd21p&;B*M2tBaov_ljw41qPqQ}iW8)Gbo6zJ%F@=z0B z%cev9npNPc-}~}5I!l52x=;P~)f}w$^Y0qZJTo+p`<} z&|;ebJB?D0H)&G()xLh|%^v6>Gjvl2?!WJd&weXu0R@SUr}6~{ywn}Bl<=>v=3p?* z=GC9vT=ma3rsERVj>+sIHDlHi?U?F zOGSp$76zg1U8HJCj~ z5(y@V5EtX2a@wJGqg9Q0$pTTlMjCmbjf;VY`9n;=q3`MOaG*&dG^l>R%bl{_y` znf=1!q38*lm^`~tvn*NL$bQ`P<)-Wa$|kQiz3AV9NY!Sk$^3^eT>>3hOI1zi;V!&4<2;koS?7O#BH`~O>2VovJ0S-p6} ztafF1a~yW4mRzjCthotHL#>4T=SYS?vCRa{*@D+nA**VM#+PA$kSI_SDkT#D(Ma#i zIRhKPoMaA&t3(4!@#$Zdnn1>yG-wKdtHzg7K$yV0S%C>EBxycfVh#hKSFNO)`RUR| zB_IxFO|8F|$N|z{p;HMk|7LOBoX;4ScptAabpSiW zeEN4k(k6+FeI;tF#|mB7`X+KDLbiK}jPqA|!ltf{*K9{Kg%UF}Gvn2OLg^$>x&!t% zvHju3o>{-1RpesU{8FdNtk-O%7X5ev{0?@6lE749UUd&1NGG; zuD=SUwAb-?&%x(3f5nbdqMCSDe#Pl?se{|4xtuN|qJwX}meU8@)Zf6q=}Xe{H0U|zvHvHCU*wE zo1LQS{qJ$K9(Px+*F+tce->EwxZ(&@+^&vHn<1z4_qR)r$#+j{F8llFFXIu%J=kL7 zb@np|to#nMR4|GdzuOB`9ni~Y3K50==;wGfk7PjIBJ6CN3!K=AMIp?1?f-axSxpJd zKNuL}ivjx?)c15mz++>OAP89;eiDi$#d`Lz~N8|$B=)( z^+rl!d~v2kz{53zVTZqAIT>trKh&tqeQOj|W!Hd8^BsT^Nw%NF0*soMr;u=3*U`$1 zO0|U=I5kc-Qj8}H6jLm^nU*^PW8jgo=HV&(oMl-ZDZW#Rup%Xzq;9SDMx6c7Hr_#> z$Q0kk-TPhlRXBN}K#}O?xSvw}=WIKVMze=)T27>oT)`j~uzB2Nq3A@DkM~QDs%mO# zgU>Q>z8QJ1AiBA`E9!oBBg{>T0J+~@INDCEG7$+w(TT6F)?yBWV`Ob-E4X~&8Gp6; z2z(Op(b_D?X=QBKoj_}Q?kzUF^!VTfJ^!I?d%pXF$=-oVtx$<%ccvmcIGlM70LQ^v z&g)8z5;cvMcaD(vc&{%-^^`Y;QaXxWX(#c##_2zoFcE^B<;rK*2TX^*(KdA)>vb=oeZ+99~`iCl<9CaZXOcre!zzjR$ifIuJvP(hAQovEIt^AuYo*{l8ER6)p4QJi_)buN$Ogyp!b z1#E>Q0hpJIL8|uWQ#^_+5Y3g-cpwA2F+?(5rgI-%xzlCPcfu$wmMa|s7Oq@inC3tstoMAAf%osA10OP9jrSQlq zGwcLT)TLrs&_(!amjmWji6yBd=oA{ip9H1>ZcoS+o?anj)03-UfDK{bNEr66|Ja!- zQqO%Ix-QD|9){S7aF`58E&=e+rnv?8n;FiN^G)tMQ38upVq#*t(E<+jwo@=teuvbX zKjE{pI<87A+U1E$ZP4?0R~y=JkcpX@*%|6b3^Wcj$yo6UyiMKx-*t+cKFbTwQGTN= zqO6u9QQBhm2L(*MMuj6cLGO-U4gb-_OET{k7&0d9ifkIOL5c%t`1nzR2 z)RRUwx|q*jYCOPbefG*47)3>m{J%~T<^1W3B0GWO61IvU3R^LtWh6Ta%EG1%)c@h) z+jkTNrVFfm{_5R#NY`~2WD>NUx-WrX5_V0U#G;*iZ5-Hc_v6bzB__-e)OFZ;8jaER z8(k#((<`rJBh}j?Ck-mEfcZnC04)Bw_a=7%1St11-A7A06lL-36_eK zDR&ShsJbdKL{njl303C(cFy4+#_yv#11Z45o6UZ=?;zb4bM=mJP?WwVu;?G_3zcAq zLf>^)hoG^bvv!5ik0!LVSdnmI$%0hcV57Ja^&*9s?|@fX&xBJt*%)GIqrZUjHTkXy zzzgo*84b&JRz~`bAptu1B2OkNUKT>cVZ@5377`r%!i88_%*09F+gmV-Qp{*k7t|kW z`e~kBFu(({R$y#%Oj~T`kMJN}DLu{yGdL0mmT(e)S8y>p5}?8qdC>>qp-}CENGoUU zK5)b1PV2VnBlK_{tI z9x$7n=10voO7zfo-2?fCgEh)B$}!3_+VFay4bDKL!pM*m2`OAWu&EVtI{El;_=pys zBa6_MK`R@b-yMoRmBwSOruOVjmawibh?JT6pBWL;begwZ7}lf~!#)&mC->@zAQXiX z<(jfmnvbd%D#fvTO4eA7#Q!ABdjFlRTkvI;hXe<*-XTi7_aj^unnaEi&ItPzCtOY2 z0*A2y(OL`FwrMAKey@cpD-V4XX-7?-?+(dekYDb+)gGkhcqG-OKduoRGMXanlL9$YYa(?-(;T$!vS}s z`z-TXk)!Q}B9sFlP)-Bty1E)%S9i!$iDSi z3K8h{>SuHca@fT56z2Ze!9pk1OU~~|7X?CuZ(04aQ-&q#|!qg`@hA%F`5*RR5}Ht2XHEkb@KQ z;&EUup`h1YM1wggce`SUCkm3wgBE_@PuF>uVU*eJ<4Wfbd7;r%wg~;TT|8f5zNX%L zY!ultL@yr@Z5mrzijb`Gdp*BRVF$kU7=q_N3g;h;_L2rMPi#;S5YS1G9r}0O(sI6k zPZ3soJIXGZBDP7kqUs7_9kfkmud%|WYs6}#=;I4sao$H)4%yK+)cg|006;Y}S!iG= zyMs6}>>21e$~1~YY7vNvXSoN2t)$>b3|qtuM1+1CxBRoW$D%A`N|L44oiEv1i)M=( zI{_!-jIs%)TFx7K9z+~F`$dywr(CvzpnWpJaa=NJ5tp12DpiTmkHE+^yv@WlXpxL` zTW_4rj|to3%q1~=G&MDKGny(K`|`^e+kn!5f&8oM{rGQLWsF6ykiD-apVch0p8^wB zG{fob(Y`q=T<_z2*l{=;)Rz+(!I(uRmS>oi*s6%S*sR0b3%xx#JvEha$q}s6S?NFD z=z4ZC#4#jgdM6Xp*%<_HdBU}Mnp1>^6yD8A3L8I?!Vl)959GIr*Ge#*^9~d3c#i&X^6RYb+OJ2R$>&f$f!Z zsH&G#(*t501e-pipx;};wLmb9j97b>EdKhec93VOSII zP3Q-U$C#!fVUo&S9}oKF_opNsU-2!&pNb#Bh8eQFgDC>;T%orI4T$2g?UqByXAifB z=&V$nYu?tQnM{(5h-mn+M24CrYSK`X$SmApj!sVHy3-uYD!AV0qeJ&;2_B^1){b@yzNXyl45E2zJd9$twngMzW>?MaVJv8k>EiZC}o z-X`5louC*kTJO0ci7_CuZKFXwtjx`RR2mJvFry4|@ zaf;(DBRU^R#h}zzi6v)p`%ndeOI?hY9|BED{*VciW?p2e4W;M4T}7q%DrT9YpHF3m zOGigHF{LcVhKKi~dWMrXNQe-nK9$l06z*3c`!VxA71_Q~p37L7dot#|wRA9JipGcs4=xz5l=4%d| zT{@_T%b&?bZM$f)KXKHOu=(>BS%aPhz_hz0Xp4)hjlHLdl|NE_H(`f|V}dL7s|W|v zSp_jOU@Qj>6cdk%A0l;r^A{0EP(B9E=IZIdIu9Acy?uNXl?Qw z>JA78UP_(fARP6I1Q$(dlye~BMxYq}y!xKGyMmU0d#u9(7ki(8k0QQpD&9~9et{w- zWVyWn3S3?bp_Pfm;p7UYAB34iGpKM15%kcb+;GB6fe63D_XkrW!%H0*;j$w05>o}6 z%Np4fnAsrjRZ;H4@5CXwOM@7_4Q96sy2df6;5nq>!%f7tr9m#l*EyxB`fxa1>)L@I z;{x8Mf~guu8y5xN83c%G!SGZj;c3wMJDkhFuu0yk1265J=}}-kCDy~RI_UY2k=N3IqiP8gAL-G&a_li zRGvDhoWYmjfVfjzOaoCs;6og+ILM@OJ`hhFDH;#5tT;sM1is5jK}@jZQYE;RIs}i7 zM3V3zFpf(GbNLIyVH|pTys+;& zF&$OIIx)iIJwX9c9#=#|`%{3l4S%~?_Eu67Q^AV}H&Qh6a~d!N62PjC^t}8ETHNI& zzUlpEn++tr=Wz~`#Z`J^k%KApZc&1-91&z`b!fJ~SjH3};=qf#JWSjMonmiyrC4^{ zU(6-46uQYUh3hDy(GbQ3hwq_gr}vw8Z3l^~NQr+V$548ZiZC#JaXT!S@)Qn!)hdfW zjiO-63{c7(8OeU|;`{ejdvxc%-UK3IH1W=F8YrTyx|4Q3_7l4@Lfp~9826_iGc6g`)h|W8`hm~qei)ZmT=rp8BAayo9 zGm{k!R8vAhY~1I`hc|TxKCLa(F~g&skkxz zSITpsXxDvAaulRpx^{K!#+8YkNtARBATAp)2Prg~CT;IJK^VoK`6k-K8*vwzSeY19 z3gzBdM7T!W*#r2tnsn8-iDaiDF64dh)x}Yw@%kgyerkW+68(1xXx;k2k%Ewv;*|o+ zKE>up>{C-l_@D*R*;f%H8xq$lZ-uSUw>x9eANq!-+OIdsNZckGGB%z59e}y)h+pN5 z@D8B1PP$&*9xmm0fhEbe*|=aXCRewrY4@V4Sc%A$7M;k05pHXAE`nHoEhU3Mr@%jZ zS3HRa#XZ_)rZAVTfQJ(eM0ot%4{pQh!x`-fen#uOk9;c`gL&ahJm~pw2MJ<0I9FLv z{;z3&9xUbTfQxA`lS2@}4SZ&fg4~At3z6>PD9AUiqlyxQa}ncJxJ_ICB~szPs;!}m zyO{)1IHFfOQv*HYY(QwkzlzF2gJ%Wd@1(^M&Y6BkTn_Tep)y_%!`?6r=(sykR6y1E z5~=v_2*2or;Ub(XmP1VdEEE60WsC1(dKRA_iVbjp`S89buN-2saC|oi55*^B9W;G9S1&mW3f2p%8q8mg4!f54Q*k0hwL3mshNH!Ol1SeNjoF^MIh=T>>r$#YAsy^6p6}92rMLfymTTB69$3KcCbm}`~8Lf6W^?+_64AK?f0HI$S)oGwKlj> zq0Hks1W@G_@15Xt?VM}EC?*idlmI&Zs~zOK4eY(ZV`39i4 z{uzjci!08^=Ya3nHAd_~@P3W*F{dcoU+>Pa=Jmi!&zW-zo*if%0z`fc{7pvM8NC{XfDJc!=0km^<^id*4XU~yWz~k+b zE#NDe>@~}Q5&*i*^~r`BwJMi>y#o_o10aUMk7(%Vc?{XUYIlaGn|L<9pq z$a%Wi(;oV#Y0C!S+nE$Kfj2{+Na7Np7BP9%{oXHWk7bCe+I{P-UbO`(n?a)rTiEc( z$Q$}HGy3a%`GoZe{u&QE?4=`7qJy~VasaTkn>}`(eG-Pm1FTj-GW@&EVUlzKX9~~l z@t}7Hzb(^c4kl&+61Lhh-X&Rl{cdtjI)=KaxjLN?w8z$x6IX@IPrI=74I zGx9sv{9S1(seq}f`(~!XIE8^1=pd8K_C5KahxH3^LT?$~X6Cf{T&9<6`&?bDgi(I1 zfJe5_)P;+CyxS5xTl6{HA^xP0MB_oSHes8@=Xty;Z5#SkqscJ!S*8Cy^yV<&F{ApM zt0jrlV!)#xq~()B0S!RCib~7M7Tu}oz9CNTa;a5FnhOa zaAHZgUVjXlU|&iB zt?hp4cfuRoF;_4lKrgmY0`bgeZ6k;@Gsn-ByBKnzcq89q z>5m^j#`_+Tt^-t@`s2g>VNeuM5K(RSz0U7u3Ad>DXPN|e&vSY}2(8@Yu$i#io14Nc z`zsOOX1I!X{d11TKU`a_D8;#bF7`EuyBe&QPlwHbn#MCs3Spg#lEG#mT4eU%yEvHB z7TpT~DietX{JB2>Pb?-E^L;RhrNpEw7$_SZC)T|X54N9Za<`TbSpw<+`MHu|i|}MR z+O;;S*X&Wjau@&E_W{wBgoh)EE14;QCEtYVGmd+JI5EQ!M-XWY#Pe}nY}NIpH@d{i zoq-p{jw-8t4UdJ%AEo_^7KmohB-x$YujxksHsA6?6Nm<&C2~TLu>saIC=sAURRKs7x{E861xM z*aTP#jeyxhn`q5C+mGHQ_#>umjNB;Ti2!XKpy)aT+F7F55qMF+2Y|F!XK>WVSt5}L zEds+kIy&z7-$P!2aIbjfe55EoJ&8%|Czm@Zki};tWRq z(nD@?Ih`?F82TU6_(eOslPV11MnpzdBoE>g4uMdmrp651EVOu~#C(we;bLQBr*oOZ zt^|Gkh)jZnXpRD8C;+%p{bvCn8gI3}zz?^!fC$cj;1k^js1t>L(JEt3(2Yt)e6J<_ z;TatT3Jo5;k_yg%)jto-%l5yK7e{ax1(Dc>4qV*doZ}fIQ)9v`K@31OhX|++rh`7T zaaylSP{E0PXw}qxVm~%^cH}ua?yAq+gys4yG>1t$vX8FxfND*_YJuN7ian67^ONNBVfB7khBW7!!9=59j!2^28WsH4$PE4_j0u@v<}KxiX_agv9?h@e3{+!dH}eSLjs)9>HEm+MjVd0nmj{{z=Q zs|l6-)f9l!{7Y=g0~+A?L=Eyky}Sh+G~m*u3aXz@fw-ytuW%**ua6XdegDv(?Xf&8 zQ2S6bXQx|Pm%aG;Aneb0%F;c<#;~3`paNP>pg4|>Tg_eO-BMjct)WKhcT_MPRWzCy zm8HisGSAeen5SOmg{^&U_i^Z`i1@keVIiiEb5W+Q{SyifVp3H{gYzk7G^cfInoE

uyUTrV-JK+ zI{Hmusa3D}*yfKf?X^;VObUGaIE$I`;=faA{P1$*x3Z}{yUJamBL0y0=VXZF#y;O6 zh|bT0_W!n$4a&tFBX?}MPkdj8>~_e^q_=m`pS{AD0?1Cem)znnEPfzT&u^dGMe9r^ z4^95~3Y=rK+#tk>YaaCcS<3#kz&V_ti-kZ6I871fVG$^IGdnM9b(P3@Tq4^Hx3J20{(v=Wil!BQO_rc=BLHt`sVsa+^Wsb3vTn3 z{DO@PWdkR*$#J}}UWZ1SYEHRMo~$n>QN%>uzdcZfUH6 zWauEJpzhYu-bxOy<(N4{u6pN2)-;>|{9y}tA|tT47gXH>>cF|bj|T2p-P+cR>wQWk z_+D)KG38_HEt8+06}<6@DOeBtXxqFv`k6%h218%2Eo?Y1@G1{Z%63XhZlm3M>}0?8 z3X`bh*5UuTz|rt7fwY_Y^OOza+l~6g-^jPzR}2}(zAkC8eJpNbb-TR5h3p3Gw+efT z$8!#lccf;ak0&)~G$e_@9(BKYpD~av@gmr{oU?s^+YSVCDHrTvKvpS1s~cduqAhVz z#XtrcQ&ronM2a07s?>zy_^qb+(z}z2w0cH)JyoCeIyL(%XT=#z0A0hXpZW|-JhxO! zZsjsYD^sIpisdtcrYkKef3e4}^s_GX)aw9qM@-BRM|j$QvF|j$Qm*^R3t?p#juwJ@%kR^w8RqE1wn zKePCKEQwLizu%!*VrpJ%;_)`%F$Rr*NpevYxvBew`rXgD>CbDNK+`EH8`>=04$X`! zI!$l4mT?ZsgO~0%e`eMCLJ)d#XxS#(2&Y*#Az}OGGim2W8MaqgMTG}Q3C4F^qo>V= zMg!|(LwX*>G6&5uMK-QjDGnQ(jt8gRqLUQ*Ru~Bl%>Ygr7&IYDsc5SLN`4Z8@xT`O z=s0f&YSXUf2HM$h#qb(72tx)y(5GHDznBlJ$}}^^rr92eXhbZ^j?K{p%LLYzX)Er_ zcyZSIn;Eow~= zwHsewPbEA;VIT1y1A(sm6~b;zG($+xmmjvWe7jx>3yTT4Jt7VqJ^kSJF2`})L=PAt zvcHnzYf!rbzZi;zank5S| z^~W({>^RS2UE+niFXa6xR&Lh?E}RotYG4EU3qU|pY@pwSB-L2{c{KgU#u;n6ABnGC z)ud*@ap*E->ngcs$q!f-^ak=}B8Rpq%%6E@^L9U{^*Umrc*k@g#D!8O7K@!|&eM6?I{M zYO>C3#}Rp4wLiNuNybWV>91MNa{Wg+B>_`VbW)B)53{>J_m@LlegW! z<=W-*pl(omtCH|D=20aO$ zthvZKZokFK$0@(Nmsth)P=Q|!e)3cEKUE4y*-BjZf=&*38-C!o?a%{F!1BEcnwH+P zZ=8Z!N`{fodk;I8G>d<0bkaiC-t1`BcS+;)I$pmfmDh>!sI<=V`*5<(HoT)XF&V2| z;x?i+$&&E6pmw%PcBsO1@il8cHKW=Eaf9Vb^nNW{c^sXi?QOx}Osctpr z?5j>2L}Bgl^A1;`bvVHyIS&BJ^nX20xYs-h+6oyfBa=JOKKR}@NJa@Y4y`7Yg-j1j zH#AewPMfS&bJ4SX&&|hV2j4&=IICU$Bi#z(@OjkFdfbQZs~jSd0|v#RZ12!?4r2Q& z5~*nQbLjJLfQ_cKF0CsFx|Cv9yppIR^UUKWARCX59b<&{yg`bqI@@>WP*3r%`Ob4k z1wV`fd~^n+{4P6I|ETCpB8pHWA?RBKi=PQ_PJG6XrtQ`UG5xS50xl)n)Y|3!%6k6J zvE?+-Z=F-l0~}~AFEfCrG5CugZcxJJ-DlSfnemH;7;)d6nrPdLUjMM-b$Z8on|w0u zGi!#nC8Jq`)l(u}q*^<^^pi)w144N$M6rUre)#)MGf(*ZXIJ-CVh85QK&~OiEHcL* zd!81QXyLPcn2paTdnPRCjp#K!ruVl8_NYRBEzR7GJzVx zZi1<(5dSWkjjK(sqGJOR*R>Y+P!irzS>{<^q)5U~GAkDK|6ynou$m8>+UG{QggYW=lW;MF^40*SsW*-kfIw&SIIQ z_1j}?!9W`r0q$+Ud6e(G-7lr~s=)MUjS7e}(7l@!v`vqhwXA%Np7fdh-dbK7rjDZT zo))t)F9O(cgpQw4<{TQ}#L!s;f>55KfVWGd@;zVXI-4*Ml5E?T*zIQRKq2A*wuzeM zKV%->p+880Ott?d?jz-8f? zHo3%i{50pX94M5&>?7dvR!?m|ylVIo~ZORG48=Jeki9TIT!x5>t9DLhnt8 zyh=>YhKW~VS}o;I?hhiL^qN9<`-JzIbUSv&KhL?m8ntF~eiQyw*EC57%f|+MN9z-< zgwp|YMNRb&@|Vr`dmeh?NuNGC5LBIK5U3UFb}yq+XgzknpfZ)`^8M|-jY4p#2g0!L zz1(!8R$UUM)z%)Oj9~5McwtiiUMb;=jN|tgvGST1K_?=I0iC#op`!?LzCX%70!0m9 z5=|n(rMA>DL6Rf9fc4*8V{>YUy-L|HNq<7pD%@atLtY*jS#XR*Y~V>Tpym=i z+m5FW*vG;;v{-?cPYFU=aXEAwu(T^?;y8bzzh5{Pmvg@QXi@(ltQ_3_yP;C~TY!2= z>z3N0b|HpOc$=g+szrm*Wprmh5`L z)w-L;`}L|#aig>vxIjA-4MEWSpM}<@!7B$EuCb`tnrN+I2z-`O!@$L;UdXs71xoo=(wz$e@usAm(_{G~IH^oKeiOC$Cn0 z&S=-FXL|0PwaiG0=G%K|n%f{jJ92HcqxN@Xe%9qXmqhuiOuSuF&Isqs-ki95-dl5K3wg)!}A~CC8K@$zb5(z1WW05imP@FZH+5b zf(!3(YJ_gs9y&_CfAwfv<>cnv#sB#qBlRNgxu=P*z7;!#khm5;PdaTYD<7jPrvmG6 zaUOG-qN+kEHg3bHZ%2XO`|r>cRbvdU#hMM<@hC5jzC>(pR4qgv^C#?aX!?U+m$Ny| zj7{VgPrUyskTLO$jx(o|J@*-*I*57scP!bdP=Howz1J|=2WqU;ME3G)OP&Lw&ZACo zmC)O$_Uyg5`=qPY`0&}Z-FGnX-x4XM(1?1~|1Gd^u5&Ko8lJI^OZ$?iXFB7ll!jnB ztsMH%+I+}xO{V&(r?2>j^O{*eoL?XFkv$R=LDjyq4F|Jl`CVdB0i+d3+=Y^LHQ_+LSdVF!D2+u$>1myV<*(`M-q9pik! zg7u`+qxGpN`Xb=?ChqYlDeKI}N*}g$*7v%4zifP+pD^>Ltd&@K$qN+qK zHs^QQ=-dSVMMeN}!op*q5F>d+9G#UY<-gnhm0IfOq5d5wPb6(f zOS^Ld>&rdv^X|#QD7%ZDWH0tJ5i0%;><%b;T$G|T6E!(qRf;nQZ*xPG(h79Fvu3~T zOE_MWje5{et5TClWUZDpx~m*(%A)I6LS7nVx@Jy$1gQ1RqL ziwS(-b2YTurDKXWCYjc^cIumiPF}_W)9WLY_bbg?^f>6hn-!2eC2lyu ztrhB>-WBrh$YTAgUB~7!kg0XD)cf9JeddWzyPo}Z>9uVkc(+3n?WJCBoBU@P-PjWn zZsn=YzWzD5;U}H7i!m@J=_B)5sviUnDIjnlbYhr&e)3f17j7++#sJ~Tdu;%SHMX(j-+C@Mh0`*9_hg;NaQ0riP%o$6b99@s26g3W@RuvNuB;i~yQ zDxvfAEu;Hb{@;IJ9ilqA+%s0g_A_-a>X*OWa<)#EMuNdQ@<4f74pk05(A~g54qoA{tgCVnQe`Ii2dW#KDqyy>sr<&4k^ zJj%g^!gR?I)s~wqH|sB;=Ak{Oa9|Jk+2!A%10({s50uv;h{uR0HfL}6G9)|j{-gZK zdL;YbZHaBqe`ADk8^x+niDgu2B@*3hv2Lj0B0oORiZg+`JV*`*$ga{lD*=hI@5*AH zyEy?jE40Ui_Zi5M>;QxKR%7XZGW8&Kvm?kzErVM)HCW96<5qeFRP~a#7yBCgYn;iQ zkz;`N(46>0^X;Ejp8@kV1ay?ftuYXtJ1{=or%|*@|9rD26$eSSQW()A{~NEX+c)+@ zoh~~6lvq(#gb?DKeU1d7Vz4>SH=h2nulSEM(Xt3}eV=5OizKnDc!u|185r`m zF<1yAXH*;%oKHpFE20Kg6x!#%c2yAW>M5lv{e1u3#Ti< z{%B~F(Io!Y4HEj2`%-L^_)_^&`_kOO@!#=|D_-4S#izH#Fqj-j3M*e0_Ok!Gw?X>* zXQVv{lxL5ar!q37it#J}L=y$hn)@uXD4Hooc5$yL_}=${L&5#ub-DsPXtfcH+XeV7 z3S139%N7X(rZ$uNlK`g})jx4PLk62c6eJ;j1F`-7nEMRrAU%3q21v4O)BHOIL*;%D%8?jK7)v;0H3F&1o+ zy0GCXp;K$0!R-_@1FuYJ3P_LEex?R#w0?(7*zV>?z6L)&&65;2?aX_9K8X_u$C%sL z+fn+)Rl)jm3vC3|1V1tWCYnmjvksCgTYwaNy*tI@66ovBD6MO zeE*-@T3mseI%x-lt;qC)D+n~glGXXFM=!h=$EYxFG_1xyf(tb9Ud0B0MwI6aru=4p zxFOuZ(ZL1gb5J;QPdyb(edGY(I%j(>%|zt(kf06y4iJ04jfRJfOTm$L4G+Ks{NKlc zqk4Z!p^;8GXs9{yi?H29>Yo*>W;ffXR#O^a)^7o|K$Kkbe$gBm12=v!i^1r2qt%D5an=dV1 zM|rQ)Ees2!4IEIF4drj-JH_1BqqZJ%DP#fkDxqofc$oE7=7P7r(qf1Vnw{3dflpK~I*#lRE_#gL!vf zI4xYY#?)UDwFDgvE%wI`vAyZ)?xmA6KB&LGG*d|L0G$O=$K}b!nsz&2fQkUS^d2O7 zO(*h~ou;co0`zcBdr(#KHo+SB)f3mM zha{AZ!}iYi=aiL`xN@aKAnICL30mVAqSy{0Hz3I?3k<*?Z!$`c#e2Y)I!K51xGWgc z7l(qQxw(0~B_LpXv@jA;c8?i#c6BKoKAJf9z0M}%@8q9d&CM|(yIcDUb&a z`@;2sbWJeHme>qsX_e`4QwUl|q22cMJo8zOU@dW5>qTMoB2)!w$C8)7K8|t*^?n0r zCx=O+2ry)+iX@JWl^y9RcLx8strIrl|9ACCA52t3`J&5#eYm&h-N=|Y?zt&gV^4q(Y# zo{EEPL|G6f51-T2XNqHF%p75cfV&$M>YSa`)tBFWFRRe=&_`A4;# z3x|}?EV)7!S_R$)n18Nb=+bkB@>1^X`Qi9vEg2#+*mmUBA??IM*1C@J+8|#8TjY2j z9;>F)-h7Fph4Kisj09zGyO{=t-q1O2$$3K`2s64ffGfC_O zL3@9`MW#KpXA$R@mX_A4cQH>BdsDtU$;LWrj)2U4W@9ygJ~M3g_1+?s7FA1)h45*A zd!=0oyc~U;Z~#p~UjI7ZnN&sDEm6;BdIA27osUm#)T~7ho}~5cS`+Yjgl&}tIL zK*|gtgPL|zIl7(AJnaDtszQN_j1$1DvVHwk@-(Zp#f@XhjjraauX91Gi2OE zc9;vCOC$^GO+Fr5#_}e5F_x_hw!`Msqh<&Sx?p$$GQ%?r%EzxuPM_!RYnP}6U7hc7 zuC+c+pboIi_Tp}gqdkVxb}e9*9EOUR^yh}~KJumTDX*l0XA$k4JL$({F+C zEh$hNI8;gz=i;Cg#_K5>VMY$lL01EUi16^(`B*^tDp6?hKo{sF^WE|kAeN1`(ow+a zV}_b|de4eG&3*qhJPMp*3@*rg(pG7}O#~%G7DAG?t<3o;@%QFrbaZp>N1c|nRx1GL zu%pb4Jb?jW9Qd=??JX?wb|v>Nd*0Jaq7!PX^d(9$ReKQlW85J%L+C8fY!LMAi*JY7 z+JHfs#|+tSG3MUzb3}U>c36kGY|DmBN6?7O%+JriEXhyhi1rqN2eof+p{uu~WxLK0 zbuS1)&kpf+v-QelCBMUx5uD;3UGiz|rl+QG3xqjfsQvUI>yKfDXOL^k#gXEQl=5nks>xj{RVBx^b9qC{aTA+X+ZIU%PDrW`F1ssBYdMc5ig z%7H9}MtnP>K&wdM5R*>W#9V?&<}b~Q$ssX;%z?x}_gW;jEd-;_|EpxI*m>x$iV~;` z+NOb=+|y2P0D^0VS@W?(wPJM#?)L&>c6hv0bE~vC>*HSKT$NQ{!E&2)uCGhi=4IcsWz+ii_fceeA``kaJ!emQeIhWbf2_Sv+3U+y@k%?9$5r}Mc{_$ z1_yeFX@>+m#0bRa3R8VL7bf|%0le!i|S#GQmpuoyOs6Fg7QL3I=Dt40HO-)T}GBQ+AlOC$(|j$VwPrQ~5g z5*k#neKI$Tgk3|HLKCxOilh~}C#XWULCso7WTeaiiKG@8K}S8>`es>gQFlZ+s?iln zR$F0MS1!b(j#usOFBx*6^>2+CR_aFY;*-_E6@Fw(ElgpZvI2Re9M%h?q2WnHsfp>u zP3zL+oSIMZeIFT-g&l~QWDX3(_ahGpZ$)Mh_mPSvLBorc>Ud*u8|0272_-5 z$_oQA2a}Ptdq{VpPizs&t05vNf1@KB5z+2&6Q z3#?863olf_L0ZXyJU5#*bZ%mvi_(KPR^oDSbCit81bNtMw_+_iqrOv%SoGp^Mju8~ z2q6g-4>$9|*s7F#eT(vm?{kP!*jkEiunzP)D+v8`w49I#p~YkbNPk0hBz&Y?XGEEklEngAb8XCTd2&MNlgs8Hy7;sG~M2VEwDC9(1YJamy`VF1W z(^%dN?C<3KqPTRyOLVIt0y#Uxum|D{Js5q+Pg3~cm#ifP(;m&{Y+~mTO8fu5e`HxT zh5wlW&M4Y3F{?V?!-y8lo^WGihn@9|ArZD>pm1$V?0cj6pQc%m_XY-xn6*7dkfFU6 zTKtMH6iZOTKZLe}{LnUs90|<0aF`k7vOO6I4Lp1#md&`XSWJpgIIrvh<_(CtbCc1< zX#?pJ@SB=MNLDcdqOx#L!*ep&cyG9FxYs+)X%A<@T~lPjP9foD`h+DBnQ!`Jl$$N|GK1#3Y5J+30@dXx150EKz3`F;0$dsW7GmV0!HFYX5 zrpAgd2c7uC`>+WNGxW<=WHTqM!1j_wm*CAasYwpWS8I zDQhKXB^%nM$IelJ%)my=!R__ctIhwPjn1`r;m=^O?~eh9XFVkmeYXoUsliGp*nv`}VuQ}jP5 zKHvN!OUm?JG`iIxiPQkC#H)lX5LFAQMuaZ{CMnl!G`MB-gJ0F+C}zC%*Au2nmHS5@sOUao|B9V0nnq zTr(`6M58RcgA*RN84;Yi0V01(T|#IyN|z9cXA$R!w-;&C8&r<55`hE;uKbkMl)|cY zoGi~wAeSJtdKVx+#OD)*{4`}I_*d1v=euwTLK|sy_!fzw6@8GLBI|IbtCo5}jHDxvW`z1Hp}Je4a+`Xb#HCx?D~dZUhKTPzbJpgunlAo3l(2GoNduhlh~HK|3*zA_=_Zx9)^ul-zH* z^~q&pX`0AV96ls9@;|r8hImjkLu6PE$q&5`Q8zt0`7nHxzXmwqm_o8qWOz|(`6<|X zP}>E_9?R44op#WDHK+FFyrDOP1kvs+CG&Nl=(Qc?2EFrsMdrkR{IECNGH582x8sJO z6~FhL0L5k78}k&(>%fDj4ajVV$)LC)1;Q{851x-Cq8GE$DUd*Tjr=<24b`IMslOSy zKWfoeuIfpY?i}!?0%^b(Y;xSxU4saPs{3MTIdCy)=HMx~c`cXJUF39`KH$(}L+bsshLR>6XElr&mRafdNB^djv+FhK$=2eti z7lrU%2}sB))D2q_G&`4QpGvOGY+B$SH6!ouL7gGX-nN|K_wqV|D{9}JOm~cr zvm$@Es|vJ%|0pJp*t7?}L17f_fLQvJkx&K{p8x1O8t(ow_rg$)5w2`MPkov_nE~d1 zZE`X1^&bs)-2d8!^9jz7@oKq?%iu`gZ)x0^3I4PHdsPk!$AZNjXl8%%1%N9+@V`Wx zf8nhD^s92a%}*Wq?*H$3|NSqn>>*}qnmxt;_pGmN<$?TzHvkG3Jr>G8F9G^b2-taE zf&wdR9$Q}G&MK5nJnj7xVn1?cZ5_?!>uxK_tZQF*k{x%Yll_m-lNhE?EBNZXxsAMc zCA@8NDe>4YrjxJQv8w(5%33Ppzsq#lMMW>;ra`dYS#H$emTdgPGrHWg<*uhM-WXnS zwHzHHLNa^>)RoPNFNUvIx^P7wado?Lc-R&x zP0`W`%!?H#2dTp7FF;*4fq=1m;^K?HQy=@r6YVk#h@zwuz|zlX>hXHr8MB}m6y)l| zC^%Hyw;L-`j6{aP5(Vq-1tT|f0@p{#PZv)Jc#YI;eV1m$K*hIY=fTCuP2gyctGN2U z&2I$+QzI#|>`sN1!{5K|5nr3UBYtMzC*?rmr~yKELP9yKujol-Hw_UeO?7MPbJR(? z^`&n9Bd_6d-;W0JXa6k4?)SyI+!lZ!qb^Z0-R$@FGLXfWTH^U+)7v-54|Iurr0%Ql zY@giMKe*iHmBKm9k)Sm2S{Ow{92&V-ACLWQzJBoSx z0VkUG$dQ?F?1wG=u|N!m^Z02F1# zr9VnDoX6vus)`qDz>Z#s9%uN^WIbz2_Hz?JLZ4Iv*9OvceTY8O_2cmS$SwAndauFK zMBR&&FPv-uVyD8>@70C-o<3zTQ*2Cd=n11R?;yKRRt(tzp@yOL9a`1{As;z|Iwx-NncM4& zVjyh{gH(WqS7J9`Vy_)!}=y!FR3s{&wQIJ*OzP&vNNM3?+ZiUhM=Y03r z>7C+33k#W04T0dU{-V&$%t(U$6Vu<|M3;kdOntJxVDn`cxVzAjofUBB&7oh#1VUa8 z=E8^XE+Q87_%R569u1}otJ27~{xEdYyg&J`cSupmm>_Ho1T>J`(b3Tx;|u^>jR5&` z>&HQ}vc3gwE_$G4yPp|vqcD;x_Cvc-Sr#|*=I2iLq#_tvA~@sWl*ImpNt4gWuXt8a zK@5$}71Rd1w4z@6hU1knX>u5nF6^8H*#yw{Y7;{P$k?bbHNLh69EBSZfE$jRW6IZy z%%Dr_r$)6yFE3uAJdVf?a*ygyyd`t8+k z4&7Ns@P~BI<&Eo{GW#*RBb{oU7qY74t$0tv&HBm2nYv?WRrOyP-;Wen(}%|*x?Nmj zVkjOenVQ)$2#Fm16zorIzn06mgzW*e^cfxwGAv%5UlQeD00A;HniD1cuK{A`{Y(8` zBfyJJd<#BUkUhg>*qolKu_HR$pEHCT@7;8^T>HFoY3Fd#ZbS!~s$vwdwokDB#PU6iG9qSZD~ z->y6(xQCJ0DFY}v*nv?r;%VSbPOg$)E8$zAStyWNdskgOss`#L{KqP=@ard0Ah-s) zw3&sZ4C<#cB;t%-KQRxeDnac!nm9gk*a6bkXd%uQlcvX3r5r!*>-O-XqAcpKWrBg9 zNJ^Am1%PdSgZoDBEF7>W>)#s!@D06Ab^&Zsn2g06^4fzv{K-pWm1bF@?wBGP<3DlP zlMk-eif1pe7Zh^)NL9xWBQmAR-H&SQi+3cw@8=bLnCwJ_p^W$sj{@_MtKkpaI~qx! zXyA`ND;B8i84edR*;B3GSJ1FJP0l85hVt z#l2=62Hzpk8Psik4N?3X`-X~MKX%SVdo@o@J^ zt96Mqve$pRkM7^}{=gpRmcP=*8#D!NPGNQo+s0{E8Do$D;fWKZGiq9SM(fV8#-0-) zZp3H1KN(f*{@o>K2}3>cA?DVv*Uno?*}Ky!O$IlV7rSl3O!G=j-il11jfSI(c~4`F z)rdOAU)~5)Tx0pYJ}pZSyiPVEk@eLv&0E^f9xIBe80~dT&K#>X&HqL_m-}0+BE4cr z+OTN_Tdut6nejSRLAL)ds_>|DQI(T+mC-9>T|4xRHHTWaYi)y*CDNM%kFm2sGm$s$ zPgO9-|AV>TKf0EdTj{xV>bynm4LjUP>HB#Q+t&Az~YqC$l7O@XP7qnd8i@{ayk-#_U={eIeSI~zWmbML^oIGXJEx67dk z)oJ|93s=J5vYTV_CTEWM%Z5}eRg9XMAPs|p7wLEkKOTkG^M@a{bskH9vKm|#<_~AG zLY56y89p9JE_C~3HSe-tgw7VeYAuS!fCXhRdNZ(H^Ctwi(xTwU3f+TmgCuGL+tmhF zkZc=8f;O)$?3&(7y59meu9WfYNOa4s!%(T0kh*oEN5;#&-s%I$GJL7q0Z}X*;j0GV z507lvvv$h$Gj-m6rPqn-K?yfw6Z2{E_R~>bSUh<=Fk4LbpjskzS=cmx&Csd9;HIdI zWlO-Rl9ERgWrMtq|3r%T2fe7(3}?{}ZIKmBw#eOH!heOiG~-pOFj_o2DAZsOgCe__ z{GisTrEjD2@nYF7t^|=k&fU1fZ3u8j%>2Aa{SuIpq)iZ^XNgl{eWuUt*X}oC2k(^X z2;dIsS(Ztt5lzjzTz@u%1a(a>j}6$Rcf-GX*3Y^&_w%4FJK`Mz$)Jc6$;z{tB-l_roNA%>X zY4zFQOIP20zDQGPri)%ef};;&?H4LtzUUOjarT2_Rfnn?e)Cmd#fiTq9Y{h@svSou z0+BBlO)bfNt)*r+h&#QN>wXwtVUXV^@NcX-wyR@iH<;Z%lQo}yZKr$COd-+LD6x?s zS7B2!>-GC}W;OY~p>l8!9gyve?0y7}8VKp;>6Lp4CVwtCPNdYx0u8-{r%n9CX|=!1 zr9Ekvh&v-g6`Gm1#&p{FCA>~|>CPOrz7R3vG&V(Uo**2+Pa)Km`O$Nhe|m8CLA{m0 zkUgkruL1vaL8j1GHz)Ao;w#iL)1RI0*jYj9o#6@p4zJ7FlWcQTShv%6dDnF30)z4I zFUM*==(&utA7jqRNJp@)b;V4(Bv#zQ?mzoV-+h*A8f}7Koo;BdjbyUJ;e2Id=Olf& zMTgopC6kK{34iB4Cb05eH`ZZ~^i{0q&dY3QvI*x9boAFggvJv7#N|3D@gsf@C-LF0 zm!AA8=geX9iNjmYmnGJs_DgD8x5vJicCX-!j9Ls!Vk9K}!V3wv}bp%Rv%v+MVrkE%}=UcYf1 z>T_suK00{$%Q@~Z-lv~*End!eaHTxFmo4y z=00i+!hYxQbu$XMd+LYK^2Z7z&-V~-%O_bRY$AlL)EQnFC&>D{D%3ACEmXg-<_?bW zCK0QvJBF<{#v|ScaAVEC*oUTCZwcD?W}&}w=#Hg^gr|cbrfdx}oGVhCQ5?(KV}@KED*DeN^(j-oeUXqu(Vm01fRb&4AVm-qsQx)(eqjI$G5|ECx(Nm?J2qJ z&&!O|-y3hEDW4k2XQ#*r4LBAe#HKQ=C`Iv|S}RkXlldd{-YVX_D&t`k_19jz;rSSF z!w=eU#`{`I7Q0B8RPPDA%-}taAuLf$x=ktSE9lgtlyyM9^oOUv)<+a;6}}U zR)6v0V-CJm_KJRk4ZBw|6Skq<&&;&)RUnyy}R> zx*sIeNx3=|!hmxaGFsJxYUHX`*?-a#ZhpO`tTSfsylUC1|0~cv>V4wpr?gtmT~>?l zJ7B*Zew^;1pjU_}N1{r;yzY6~WJxlx;q);k_?eT=y430rF@Y#I-rQ#uN$7*>Pz0Pf*yVY1H-x(Mhum znexO|c|z~+0>Q9D)I_-FPYahP&_i7T5pZ8GFQ@P%e*rR6ltEzYoZ@%}N}wSaQ6e*2+b z_ZA6STo9BYg)h~Ut9RZaRjHUuk4Lz@z}r=WKH8&xDMWNQ1Y!Shm+4EmPNsrJ6U!XA z(YE5SZz!)rOj~tB9`&-0l&;TqD(PkesY!0$L4XuvF&h6@hmTIZ{)}@4ooeZZ*mBBR zc}fq{`o8u_(gu9wJQDdd_f>Eqsk0+%}UtNH-p8{8s9|@O1 z(%s!ffE(QxrFr#h?*)ufrGX^S>*J!tY-x0K;IXHCeGQwn%VB3Fh^w>a*|k1y3m^P? z>HCh;;n;YC{tL3!AzMy9U9fOSo?E_~M9)**lNXIw3=dyD49p5xdPpHIZp87T<2$F( z;EEOiy3$0l&7JqGoTO)8qO>SgrZ-r1kbHl?6jT>I{W@~F)}<_*`_=73c#C_6aFp{t z?xdCX&kq%LwHxw2;;P8-ZYD_0*m%_#+xG06$02CNYf7|9EEUfNadyn6crz4EyC_R{ z+9#Qm8-hO$YB)L>PdNSwS+4W5<3v56-HNXHm^IehJeq+j*e-G_iymjhtsvDAFO7Ep zvPQb{$*W3>H|zFkzhdKdt8s|-Amd1>wGBOe|2%_Kd7u7mN1f2C>16CC-4DlAG16w~ z8%+`P=HC$MptYCMwsxYy7c}DU_!-O+W1McZ7@ymhlfRbBPN$v9X9a&$h>>;&ExZhl zf4po+o7MXxrzqA^&}NxDa^HEgztA4A)jB8-&yi`*r~cmHPfptH>hg?i+R-;Q62mM^ z!#!YSkIc>fD06!Hzlb)xAk8*xRHnyLZ+FXlZ&B|Mfzwuu9MMV{+ z9I7C^rQwN^Qa$>8$F}*au+?oZeoG;G^9NS(GZ$T|Ah;P|J(P_i{? ze&4aM&m9ByfZ2VG^dWGqn2b(mc)4{>$gGAHm#jJ}bC%yJD06Lx^tpH#SB^yc9KuW3 zb?rNccNjn0bJHrFqzjtbaQF-C9N`6@WR?NbL2na_yaj3Z##+p0Ic}&bfLTp!^uage z+5ydZlW&Jk=H(+%B$jtb6q?{XtG90f0%{J=IuY`9cR=?Qs= zU#Fn^&oqo^JeBHT+||7AZL9U&m%5ymifY+N(w1}cOnrPu&g;1W*$i$66(TuvkaCp?kMmAo8EUuB=!6WD_~U59Gopp}ONZ-wVZ)lLlB)Sd zkUr`j@pUHaZmX`>^cdaBXLXfMk9LP|3ZZz6S{6u-QY-s;RAUF;AujCTbb%WHAm<%gKExuN! zUfoFl_##4-;}Y~8jT&>?wH*Gsww9GfT8nS$o_e#a`jvKuPQrPSjQ1#vXuZpQn)k#* zyvOwwlgvxxPo20sgu4OajCeh+Z2>Hr{M&eWws5VmP7-;`W0fdu;4KrMlV@DdXM)t@z4^OF(QlKqz-s=aQ zt{KZo?3`9mY4eLo{>FWHAjAL*oWuN&Gn1u1=;kbfq>~`vxhEN7$XH$aB3XTuQ_#ti z_=RDk#&>=Gg{Lgmd85A^)Pnlpo<1=?Q~8TT+t!HIbpZHR}zWkYB7dp9tHN*fY_@^;dB{? zQwo6=o+3}4lkn)5KT2E`8WR7DJ$Cb|+@<*P@#`G?S2e3jE8OD-WK6vs|vamP9bf(<05Ce^M zV*X0Q_4vo5hs-rtq*sCR!i_Q9Dr+e9uxtq4Hi_uWUaiu}!)q!&eqzz&K^>K$k8UTx28M4OePS5(!dX^I?h9C zuim|0;5Cd_IuZ5W$+TEDWSlT2jb+!Wd2$bwU!1^Xy1IAs{~VsNGd2_A+rHix*6iIP ztTjCjGvyw~ZQgfF$=-fIi=mY|#(PCnYGKnpn>(w{Bo#ROM7F|t@QV9^lot_^!f4;yI z%03F^7-)0j(*0RcqW`09ym>>mRBCfnwn7fXn+@7jbYE~zpXV@u36(}I7fPTH$V|EK z>?6$Ymy+?CSQMF$p$nXc0=Azq)_z^^g!F7=TsUo+poQmfRXQM%MzLBvW$oRjJZaF9 zg8o3V?e$kP{cdFcXtO`ai!r1J8*jKX)!I^Y)3kX5abQEZ+uulyM(9gx#Zb52UJttJ zc}mI@o|^g_I8A9#-85N#P{i*k*&-YdMh!?WU#Yif;CH0`=&{-Bn_Lm^Jc}g(>NGx( zSnm~8_IDP4z6$eucTzeuY4cvI=4znfcUz(bJF3d)Ani*8Mw5&t$+T2~=gPiTv^mpj z?W$cOx)2)tRao5_r{@RbQhMF1OY%e?I zK8+Qy{Lq6w-qsY9YXRnL-PpB|Zqld&%bn8$CWapeRED*kN^sfO|X^*JsC?2yv%YiigP?ovCvaQ*Bl=VR5 zb~LsLN*YQ|Tvln?Bo8|Hy0({4oKY|FNPP0l-ye06Nv)+~aKpLNy4gtxEwj~MrMGiw z#%|LZJk@HrIXbZfpjfgCi>p|4!3Z#YjP*8Kl+jaT7l; zX&jvO3G}<^g7?2(>vg0l81GnCtFd&^t4Uil|Lw||;BC+bC( z#%y%g(2GcgAz1~Mt=`>wJWUbZRtY0@(Ccv^F^@s~@6R3I54y3cyI7?!o5-G{BnvKb zZRqs>+3}sn(*!jR3KP{rmUg0_gqHYS{DM;0OzG=b<&ewKm+>flMpP3s~soD5%an`V(s7N#E z_R$kWSP(3Ix7R7gm2ojy{E>+|g`xGsoD|#CFV1gFby?E-dxd3yuEx4`C{q@p8oSOX z?CYAWcinRObA{A zWCHXK!EhHBB7*5bI)M+;1x!*eVHJOaA6Er%UHLPYZ(P?$Nop*gS2S4B@-}50HZ33E z-LOmGONjaJI3Bf>(*4AxtadlE4u|1|^o%!YR1n?^z2*R{xBGB`mx?{yFa*V3mhZ3X zVVqqgZ>`KHZ@Mus;Q)PVWtOzh6_qYG%23MEwq#(ClMg&aA_t4Ylv zjSJ}$JJ&^4XwDbgjs(6VL*!j|$~qdRShD7GNOU@N_uV&aT(XscL+|*Y?FVNx%?Xy4 zaz~*9Gcj~^bWr5!Ug}un?65A~w-PZ)sBw2VcpA_JBg_f>Dom9hLMf`gXY2PJw+kMC zJz%_ZxPb{K zt@&ki!4|()*7NOBEh5Y6GjszB;rUom))L$O-0aGbdU?Mv?mSWx-$YPNDBiGDL1PJ7 zvLxZ#l-!*f=c^_;Djv}lGmPn+}zEiW)#Yw-c7Zbm}Dbp%c!GNTdL)kJ* zovJ0OW|PNix1W5V;M)bgDq9oxvBl=spbVFG+j$mn%?VD(N(g`QNW zjSP60PmtUe$C8c?CdsDwZIJH9xHl~aD!+6RxxdFFU42p5BCfXcm47KYT9U$(P;N?5Y z6UteXkr8dda*r=~Q*eNZ|BLo|NVkOex5HHQtuARLoEu|gHMGpmyp@$o)nGWVvKhei z4=%g1&0BYP+xwU7om`%n6D&=Mm_=0tV%-=ZzNcqg2n<&IfST(r2@D?3vUEpPrDx*5 zZI^){#o=>fG1xa~93#&UtzzNSlEG0HQW(d-pOB?KL$jZ7!W_7mm&ckxuq?Gw{dygZ zv{2qH3#wa@Jal>(I2QAuB)pY~xGerxW>AsG#>jZV09{%Yo#F&8&4!G7lH*6es8m8} z#^QH9E1x!_rB(ZevF?sV>=4_^LPPdAXXa#B&qdhR$R*g-0wFLfA+ZmSjJtb>@p~tm zuhHzxxi*f?Zz|XiG<)Z(XuoXek91*Z%tIG5gVe zHWGBlWrt*FyE!~0r}`w|g~r4(HmqnqT1D1QqJ)zjeectp}pk$ZC5!X zD&IiPbr?lD3PU`(>=VWq&CcyFQ=&D6!{6$r?2PgQnIO6naP4f6WjGQ{wAu*n-TF1W zob2`M)M$!waL)D(Bo6Ssvm9YkOvl7>i>}Y&A3KuWP~TwPbQLrlmENKG=iC3)ufmIF z5HcdU+AHArHq#?}tdlA4s@y1vOS3cZHNO82O*NfqtX_X)ne+ejcHZA`MeQDsXc3(R zLG%0L^paHbwrEaqTM6!eb>5w!Ts&5 zS#!?peP*A%pZ%Qg^Z5#ihRn-N2jZWJ%I!M83lEy9QK1w*ho97SzZZol#=Kx~rC z?y{N#l*N{Cu44;uqUg-J%K*HUEa^b?WN%!3_f$AIs1*>>JXprkbB6~i5B@5f@nC#* z4NH)}{6-(zB~8X*Fe9!%SjwAm<;-L8g(j(NIFUmlNO|Lk7VxG{pX)OJ;NZdOS~1Xh zaQm4xc#c0{zmW`2`ZqaI9{^H_(qWW%6d5H9J7C21U+z*!G&IAkLn_U_jpOz|gZ{?+ zn230?bCn3>;C<8qdXmmw$aQ#sO}TnHhzmdr+_jSr?6>N~kx^8jn*S!r zMo5#_V)yME&!Z_gQy&YczIZr1FQjb_(hNIqtbW20^MjVY8t$%EK%p+fSx$C?#6gJD zx;G@bGebX!^FB-9j{pd>3t}tYXgTin70>jG&lGVPw$i=t8m#HjE3wP(Uqwwr%`?7R zcEV$GWwK0f#0$-&t8Cn8;h_`_CkQ6OyP*nDCsKc75yDA~25V8`W-1#tS=%gTjGllM z=roIgo%GC!!L0bu*-TB>mHWMljJ{$nFr}o^R1jB~r!v#e5vwXELZ0&=`vA1X{c98I z0Kl1nKwJi)qqj1(1^I^CWIi!LXzRn>4eYWfynM@6+Ysh__*>E&dxemXSE~+Y>BL=$Prv@f>73mQ2%I?oTZ z2YWIuhpK2^O&#T($hH-+-)Fh{VBL19_+kpmC{K*u9k~~TU~VWH=hwL0aeYi2ptac( z5%>~qLX3Vh{1{=Au3X`MTy;}! ztft354Q?ED@#x%yy8oNU?ynjDW_5&PM|9=pFY>QDOe6RE7iQIuYf{*m(v3}&9@7m}%HrH(CJ5w8Z?ep@8!V@IXO}Wu^ z-#k+*R`n+2J<;C&u$I;#2$)-yPcum_e~tT*(OPSTlXIL4l`2AlXt)d3mq9JDNtmx+jJB4{r ziOCYKJWpsw{W~^@Oq_*`Gw?a2=nt#7!<#%qrDds?t&CoAGW};Do&6>b3I_Tv`#FI# zGv)wTGo>fuO^f58Q$=FebQiuj)GRTPe!5WhjYwM)pmCk`x~0USO>!f3bPaHnw?hrQ zfzY^3EZf|EDJiJoc=y?t?;3*l38{66s0j>f|7MKMqeCWEW8jVY=LSsfHmxd$%sB06 zMGB8498I4^g8_6d?^Eb!78L#^aS(-bH6Z*dE6}TZ%rI?X$NDDZVvuH6>Z0>+?Z8>; zcxQLR&fB)a=R=y6nTxeI0qM>f%XAa9*rU-^-`^q{!|bdl{gWWpNUUPo!d19VG!E-% zZ=FC_nezG@xmoY2+birEC*=Foak@e+@}8oiV;&-ekTJb8}4s@pFV)u zw1=-$>mtW*-g%uhc1S+k$TC#KOcbxjtQ5F3puW0rwn?)dP8z8L44p-w77wuPcHNZDJJTlHZPP>^EBlM;k-igS!9SA{&$WB@ zz8lRfs;g1%YasdL!)5g7Hff#ps_`e|C-#1T-IdiR!QJAj;B3kx*Xgm#)NQb&A7?Tk z@^G+cQKArrW!M3Q%>RreE7mJ7!@LSfE0gxs%J-_-?aQVHjcyJou26=WEI$j{C%`IqM zAe;2c6+f<4U;M-lC=dy5Kz_uEV@c&;K0D~=Ai>RPRXBC%KrI*Sz@DnLw%eJN)@!C; zl^0LI7cVuns5+c*f451c;9XZ`8ZmVpZs`7Mf0DtIeIO2Xi@eVDakAA&p6CC|VQ>;L zRJV9Uhvsz$@B%cZIa%PfcOC9Vv*`NA?=aK<2v^K>8aQCR(rf~w!fy>#w_52E{k0#W zq;7LK^@Wz|NsS7Ft)PFxn*y}~Y^zr}GhV#%%SQ#&FO7>vS=>@{vhPxv(sbjE=JS`? z_s;L%Q5yME^ChIq!@ciqM@G`Gl4OW^$EtgAh-yW0NEf31O9;%IfY>nM8kq}pkr_rH z2s=o7`|gMFgcVysvQ*m=LW6_(x?8h|fnnQ*N%vS!ACv3ceoDqM^I1jESEQuJ<(|kvPGkDPz0j@~+>S>A zO){_#T$w48Un$3TYUofTp~neKEG@!X2_0YcnMYsl#TT7Cmekq-72@74%=l(iTj%@(E4cHCjbq?!D1Mp3;I34$Onz zW2`l_wzGZI9Eg{|$P6NZKW2C8QhtohYR{^$tvcs-Vvo)!Ej*=sEqB&kKg*#$yJkO> zCX2O=4`o}d-Q)1h3VatJAoC@euubZ~+M9dj%N*}DQw30iWh&UQgx(4;0fbtJaEai{ z6>m2XpU;*S5clZK(aE@}ODey<42SOi^*Vy=L>z%`1Cl6LwF&l8vm#lsEy2$m;xpAP z)gnJtI4oT*k#cQC3*{O9yv5c0<#N5+ZH=ZoG`+ncU76tEg7+G+x&wp}nj6(-9pz^; zp+1|7$N&B;e%O~+RB5%=9U7jsju8vxsa27r#+a%SY70D0r*q`po9wR~9u2-ND8-zu ztPgjs35oW&78!!#x?b&#Nf{dRS+_sAiPWdBTq+-`l5WXF6C5(HKQ~Bm;}CEc%JI@R zZ#Lay0g35fc9<*a>76HACqk>4H&ptoR#rpD!qp(PwkHkWe^e!mgaO(@g?7L|qh`Sx z7>JiCvDX@S6?r@NE)sX846$-dY^uPsn>TFR>x%DW-%+TYB!(#6N4$|O(#t|Bm#BSP z5Op4U>P0Hc6;me#V(lh9pXc1Ad$kMCpbtWuZ{qWqH$_4l@YZh7{n_N~?5{rUqQ)Sk zvO^b%tM5Y>O^8@FnZ!L8T+Haaqk579JDmV!*v=pBxcq2tM=|f_B8oEm{vGsLSAjzQ zqtIK;+!&Pl82b_v>lH-?6<-p!b*p}2qmA0>;1pk_rQtU<@b&Bboqrrv9gY{`c^iTQjBI3&6(|YJZd&0mV+#+p5bn8*25Ucc~$lC?WapM<@ zOnjKx#%@>N-8TbXM&?~X`_*^3^_1tx)>Be@ArAAt>uG|L6dxrV36%^f0fzY)@Z1VN zI{6@K*1PWGKn~t-e?*>M$J&Z2b}GfLi94}tnS6bIKC`4Fqbz{Viu+~gqG*uf{!;by zlZ)u@hiji$HlBOnyLomglv)(wUu=}Du1&651D3^S@BYT_J;Z*99x3fn+I6hu!lYyi{41=4lCsz;ps-b=Xg14Nk%< zAA9Zl?LE}1-E)XwH$Kp#@>FjUziy41d=8H3Nm<8T+PhW0$J~2_O3=fqjql2Jk54`! zZ~MIzqz!%zRYXZk--If%ddpT|EYLc|23RS51*SSiO)$zbHGpD&IqHP4enadT$V~Lr zBtOLW&x90jzVbWB!}|dPdHncvx(}dJS&eB+`$7Qti<1q264=g;JTPDOU7Sd9POG&! z@uz)=0{A`u%r}7TDNKQj|LysEopWVpu7x(I0$y=CR&hHoNeLvskkU*}t^vnN=|fr9(x7OCVUspLRmM6-CGg7IxT!bU z>5Ka~%m%onC#C{aC%Pyp<}9Hf5H{pG>|HD^v91kc73T#TlUFMOAflCFw&73(f&~7) zIoIvtQxS{0Eul>H#7doQq{PFY+v8d8``v!Q-&? zMrKWsd&Y`mvkU4hPkQEWsK0R0L^wVY=2~D7*|$tNAh%3PJt<^cL|!jXKww}>#sLsF zfWah&te8n*dI~TDct%|YP^IE6J>#rITo_%VSdIs!!Y~34#6~!jCU4b)(wXn> zQr%AcNxE17ggz`tdmsNu2v$cEj7#w`LPZZ(9qI=B2^(K7E8l@^*ADOk^$xzQHHiw# zAM#y?gSb|W;;yqEF;ip+NJd=WYFf%P0D9J?^3-k;h50yAwIb?^bxp;Pq*xIgGvSwJ zH9hm-!Vi>J%Te{&Spi26)KD-xR-Q#+X7ir*onba0w}(H&66%iF zh6cO!_R{{cw8?UN9KLpGO|-6ZQYK0pcwB=CSb6spEdeQGf+taOD`buK~ zj#`cK%YxEWc40|aFS>R2#99o#@@-dgHqvN!2h@5=yEw7;8c}(=Tvu>VyjPtkS!)ry zeh`<=Sgm&HxqU9i7o!&^=*6~0v#cN&_lcr8MeWoRd+j%*AAZVfywD`QCtgM2L0x}{ zMQ+lUPF6*nd<|pI##qQfpQ1pc?o)9WvyF@Ra9`%mjUG%$PH5r*NhCV7H8`L~=UMvi zEwsMO8IokC>g}HOz_9LIWgpi5LUIeh$5ldsdj|p{7#m!@UP+BD3mN-9gG{;~i_Yk} zUjcyG>~0_OGJ<0g3#Od_pr^~!%P*S7<;G~HM!G!b!Teh##i?CJMW%8*rn&60`!hYj z=8I-;H41oO@-=Oqn4<7}*U|DkuAM>q+`;jChV@qTr5$JAQBRMubu>71gunA8U`=8V zQ>}-=ZaInM4#w^(grf1}5Sn-M;57Kjn{T>0U2K1;@)wIGjS&W2R$p zw@2fMbYgM8yt2$m752V8Z?ht!nHvsZKo2(beeS}HKysK?G1*hcX}sAxM;}|-a~89j zDNR=2N84LgQlo;OfM+89#*xN*HSGuW{>r%9)(LkY&P?<(eGHQ<`rmM&`HJL-*z~Fc zGyE7@MZ;F}aCCLV#uejDB#{~mLx)l_35$iV3tNafNT#jTN4#I6s{bu)dTx-3!-oX& zojqZl!###{SlaoY!WipV9)iHX6HRWL91`}KG{6@e$X2nuzSgf+D-lv;9Pe0OnO+z_ zb58G!r~yN{`o7GV3bjXNc{3;(1MPuZFe)_*!}wl#b(~$*wk1E$dvTzTM~R;vH}F z^=K;Q>SuTt((`r9`!@xlNN3}m5ZHfkKm_l=Xt7-k!l(TC0!R6|^R@Ui(|{;xATCzrXZX+I2CL zJ&{t-)shZl-N7v}_(__xB~S%1{OcJ z$X2)d{m{dHa+fbaLtnf=*K=PQ9`)~U*_UvAaNh<{tI;P3mUE@7!gSkTTgCw1Ro6`h z%*BiV3f}+mlIdQPUj_e#$r0 Ml(m&A6m3HO2j#`^2mk;8 literal 0 HcmV?d00001 diff --git a/blogs/deepspeed-offloadpp/images/h100-8.png b/blogs/deepspeed-offloadpp/images/h100-8.png new file mode 100644 index 0000000000000000000000000000000000000000..938625d52aaf1f92345be37d20e00235fc85e7aa GIT binary patch literal 20621 zcmeFZcT`i`-Zl!d5d{$y0qIB&Rp}k+p=;>fLXh5D=!$@f^Z?QoLT?gE2vwRuC`#`j zCA3ha8|q!`ea=4ny?2c7zjxd*zHu^^i)4+JHGk!qzw*o(-#yk*BPXRJB_JRmR|hLY z2na~v1O!)Tt`PxOte?kA0e=WRAZm&PW&I3mz=uCxDrhSZ5LCpFojxN5K36@~(pMQD zA0HbVo0yoGoSX#C|9mOCx0{%lSXx@Xc=5sxczx)0i0*NAcJ=^XNkhKAzCl4jp`oE~ z-@c8Cii(eq|M>A^Mn*<4p_cy!L%(#d#v(G_gsK|mn%_3|Gfx3%dM0fDZz zy0U`4ujOV2P7x=CGH#yLr5seeKdl6MtdXWIX?=uHhg}>M$>uw1A{o-YlIr<>V0M_- zyxUzN>Uku8zT6eGd5Z#g_uJ&W5WY)tdF=t1mVn^v zwf~;~4;C2YqlxUOKk{(@415sL7TJ9vcbNc`R(pf(?*fXC9isox0=Npr$p^jS-n z-}uJrZS`$=!s&*MNf0S$3R`#jYIgeSIMLuhpHFg|1PiNU3YR{VCi2~~g z|G0S-BG!jH2$0jQ8g!=aI440pdk5zIj>L17OoOmaux=@nqt9aZ!=JL6R0hciG%XJ}Da%+Tm3|@aj7*%O`r#9B zZEB?$Y0jF?4|!s4x#ncxT3t)FB>PitW9e*bIB@e+fPa~>+h3r((!Ii1hP`R6b}tU_OJ%2X6|ph3H$rddskO2!s-F-E7+8FItk+r|@qoUzr>X|v$CCYn*NrJ`Z7qOHHbiGJ?Wmsk{ zoKhYXrvt*}t|=ly;1MiF68COZaeLA)h`&y1JgpSEkO!?Nb_5HCNuC~$bL9q%8L;qU zuJT)iPp(PJz3Fq}n|?i1;PxpM7w$5zE)S3rXl17!)0=??#Q3XnZx^jW@0r# z&8;5GAjbp8^0*~ENDyzDaaau;bx@I7IM+D?xMI6=E8+IDwb~%LKyxb3uYn&I8V!Dd zFjrcJ<`NpL)1%wI*~9pxkBNVFDqxcAn7$McOVEINKE|(Q-CLUgVGq!_$GmX~I@7A} zdd2A);kiOJ^=5l7%tr6V0GAwMGURCvkbS$e)t|&hn|?69#I^Q|Z}&5ts87&Z*lL<= z%A~3bBXD;F0Qpe8a$lpKRHO{Oj|8f;lqs4}`HCGjYocmhwFE3GWk^!F2Qpyt*4T4(%0CQd%9rF{;CR5a&+!y zFz9b+B{sV*Ef-Q$=Qq%PVd7ugRu9q|+)H%Wai(yD&x97`L1Vcn5??Tju$~XiB{!@J z$)gvPjY+wEbmWb6UWL8_+(j||PGBmXqz_psI?B!JXrQbmJaGZEGO74$;ji{I5d;3GE3g%MH z*j-JShUD8i#r1LpxhDRjc0C8fvP}P}qLiD7$muEZ+`$A$62?{{2OO{G6@0|Ft(Jw# z``q;Hph(=$flJ**vs~LMt}ztBki(J1DcjiX(5W7gPTZW(o0ckcKFJB55ZWRfnw4(; z84ZuyrMbrfbt{%HQg^&XXZ-rQeOURdV@7d79{xHgYfViLGOBs^93jQ0B|qd7KyX}J zZC_6;l~hXWmI;nGPZ-@T$M&Bg_%tg-sJC=2rFD?)rXs|OLE4t|&nXd(v_~POh}Er_ zq>cDpdze+f_XJvQ&b^{m)4HpiBC!J*0Dbq+zaEKtdgm{0pG5s|_er?n`o z-=O94a}+~LZ|H4ZrQC)hqn}`PlkHR2<+8b%8J`B`l#aB_Z^aY~PXf8)*hT$*pIOBr zgQ}c4ahI}BX5Y3{duQWm*K!Ro(FjqVZe4N*M{Yx%(Zj}v0}%% z>Qz<|_#mEc2EJ&4nG-c6l&GaN21LTgMnsb1S!i^u%3jYHnXppXyxsIQqzt+u^6}Me zURKIbspIosjK2F6`|`V3S_v zEPhnEcgUqmWYTUHMgvt9Pl zmhQNn=CqU#yrh|9!RF@UU!V6FY|I5FMt~*VZj)XC9jR>WEON-SrQP3yfw*vQ;&(qa zoe79{yCuXpG|f&YK-wDkoHK~m! zS=Bd@n(sCoXCjn-a5V=}TRXhmopc+FR~}wqJ&BKE|1|{U>|u(JRqB3{&owc9$dK7d zMN5iLrr*e!`EQf0T9TK{=T&!oGEys-F({+e4~$DzmnhYLljhY}na6Z{Z@pEj$+&kW z`GivOK;yUV%q-)kvW`78oUS4Ul zkE4%r?t7Nr;8hu9WXmv!jU#Mkfu60cABGbG0SGjq!=e)Nd1k})aV!#=N}?5LD6d=r z_d@4(4>Ve*q>&h9IIV-?Q>@;D%b1y5bVR0N8+Ut?vdm=**YZp$>VwY48*I%ojSxwk zt9xpFgrlj%u+ZoFyv{k*=>6F3qIHJB2pnhY-4=VVevj9q7|xf_*OrZjuCkL+>BIN2 z@oRm!PH*(%PS|U8tFqbRX+bEzMTUJ|Bb-3sedN5B_=D$uFoB*1{K(m(zUuXNZ8Z&t zscgx?{K;zy-MjwmQpJ(XX4}O4DtLqa7*rT%@p$>9ADJcZ#|IyOfj|y=CCrZRx zAGGvBq_3BL1Xoh#I684wF=uVCEjvaz)2UG;{(bdrMxMi$@upoN;wbRYm7@D~QKNaC zkqm{H(}fAUp-qnKM@bW7z8UpPOJJz!Dn-id4e+m7W$^ z3d!HDNLbgSE3{F`KbMfeemmf2W?aGMH-0j`1>9OOFvaTX9;W>8Uic?q9*cr@#_Wll zQa&{;p4sX>9w(aml>1|AP0iHrM&|@o*Q~B44EODcNX%Hsl8T#q&)bN$WirG}8}H^T zM!?MtBwEFokwsc{lgih^f?U(9_c|W1-D%ecuunDV`bhAX_+Y` z^@NEjz!o){_^bxNuK@OmP zK@JncPQxOG>(a9^-qGxl^u(W$W{}agP04{bp7^LGa7J0PT^>WjW=e(99$<(RB^>tQ2*+xAwP zjc~fPC?$n*xOC+PO(~4(_YwGOUBQ_j6+ZP1KV7(n$rZb(oKaSDy!7P+qAxHr^;!)p zRNT$_{@Om%bm|q-fBYzdW}UAALlnOfJn6O^9`f7Pd(hx8B{+S&{okJNrC( zjXu%qMy08#DI|#lS7vZ8KS>#|#>V;Ptc||^vGd96)}O90w5@J-u1VD5*nd0O-(g9` zWaGs@z#f-in(|C~VvBcL_1K-f7SpH*XC7{@%MLM(QDUR!(XGE*Olsi5KfT6r;X_$tdyfYS3QEH^*H{v2mQ^>f-=a&pXebGw`?0FC-F&Vf2rqn|C#qQElh?ebn|`bV1`!wT=nDo zB;kR<20Ko_EbpVs;gMHtjAK2*8ZC|pPVWMuKASD5v_ce7YnbfDf$huuYmZ%dgmt$P zPo=u#*xA4~sfw`u1jwitck25j_FgO2oJU2$IP zZCA!HWf?D5mnK@O4**tiE2i`o|HQs}D^Q~`y8Sk-=5$dj&HzL? z45Ux=ME0(6C{H;lh~w*FQ5CT}dlSXn%u~LWi`K*qn6Y`iq)*fe3AX>8+l;I3SYHrV0CMutk+U?Eg|bSHmv1aTb$^#1k~Al? zwc)Cn1W)J|7Z)hr7AUE{=CT~Ft8U^Vj(U=Fo$pWK6WvAVOH^YF1MeL>$*rsj#=RH_ zEJ)l+U1BeJ$td;YqLwIrtghD~%0qvBgud{=HiQs}q1hxQg**`xCBtT>iAkTtGJ1RU z{3F0HukDytemoq`A-+o@W<>RxT41fvvHlVdgW~Fdv`OI2yV-Ejb@iM{)NJql8Nug% zRyUj#1;;fpQoBdvXU-1!k)O+9;PKWSCy0Q`Bw*W&v8j;<<>EmH9;>L4^cg>KuMmdl zEhX`R&#d;Kq924NuiSjN1l>{&m=0bF>Wklj)La94#{H1Bv_%Xs{rWJfIJ{rnj5AwG zO8()YCYlF!5}=By6N1lG@XV#OidyyTa1YtP<9wICuMJ7Es4qt<+vnf%8WMFYQ+Awn z&puJ%)3*&~2-lFLdOIdsc(`tElcUl~-I|Xm;Rro;ml!t3dTjs!PJn3;%FF1qCMKeO zkp2`UtOB!SuwloRWfB-s4@%oS+Jz*WsI3r(HtwrVQoZ~Q0*tsoQj!@$X;X^CcIDFm|T;as6OGL zxRqVw-Rn&KJ(8?3K6B(Sa-2EVUHau1hiYhapi}Sx-J;JPKV-CraSYaBkd3)am)9)b z-^h>h;V1+I4A@rg8_Nh?uDVeY0l<@ri;@4HC&8L5(B!^`^Ha({J>X?%b2Jn)?X%h~ zFg#2AQ$Sse6u8*8mBx9~dgAyEoKllXBdGbsYi*t>=Rd8O0htX<pdx)PP50O-FUb7dV!B~B}dAER5Z_c=ZbMw1>c=6C zcnF=X_RQ8#ptfz>%(hY(GaaH{0sPrDGw$Y9rHO75AW7a0S>jkkoLM%#p6U?zas`+M zCxi*3gWVn<``)jb!a|4biWR+tgmpJJ=uUuro?)Y!7!!kd_PI658=}NvRb3?CTXE zue>ACf+G5!I4Acg=2U~@sp*k1G4)haqbU^RU$W{#+QbL!)+$D#e;PjG!T@mA8Y8?CJ)11IXc;K z)*-&*Pc9`q;{7eCJ9x+5pmI~%@>i32P1Mv+&U;lKa#3}2*LFw{<`;uh4sOnkyZ$af z7w?`U9Gb^@z50y(^<}z%P&C%4kUHxoTBQ3@?58cYM=8W7sbAvIdc`ts3aU?(%O|Ok zy=Oe+Ig+B+)U}hhm0if|I+f9-4zMlf#T!Chbv*&O=x3Dihnp6Bmeyh^QxgwFo`)sk zwufKq>>YOkfb0TEPYGaeG(R)_gLb6M%j&m#i!*?e$q*d*$q>_EM!#XbI?Z%RnfJn^tVSC-cWVdn0NJMd9AQ8UmIDDqBF)=dMvF$LTAdpB>3r51K1C;D~15YtuD7eX>GdTeYasf2j+PI zv7X!3H5EC;A=G;ihrpD>LDk5*Q!~w%MfEXv4&OC=%Wj2omUX!LElysw+QxiJJGJB& zl7yKD@ns?O-Pz|hZ~8zo_naVG%@~8=rKRPyvrUt9cRqBHw5h=hBAX*GIeC%itsOm6 zervPy^ddz^9iQvM@_c%*JqLVrq~)u$miN_@*92Pq;90iyo-Xf6O(8x)#+aC`S>vuV z_jR%LqD@USs@4(JGvbH-aa}7&$#Kq&v68y)3t6F`+WR5u$j;=B1xPTSRqRz~lNIK8 ziqV4KGHD7sq*$i?>WKb;Lw7;MtK2Kai11CIPT2=h_GMRJ0F8qbjR^{|eRVcV1oA}c zW}Fo$H4I(uYrUfeT@JH*?ATJrH~O<)E-X9}sbgOCs+cKGk0;Ty?510;T<%*ab@UO> z2?rmNsI`B36`1Dx zO68jg#Li~6C5E@HS?qe)4-eHR4%`g1td#Xb5$PBwTswc7O!GHf)6^UEQg@na&n263 zE(bcYz`p)5#l_obx%KqP_c&{6(T^;$23rym&!4QZXgt$pAPsQIC$U#I=_zjxYN)?z zo|U7!uL01w8kO3CV`DK%QpUCQBOL1Y^prsNO67&>pT9<61^ z#80=O60CW5_%0G4}ZqILVHrC5cag#Fny>6l`;k5>WNfS|=K{$|5p zXOQ#XzHh4oCe~TTQ)FCNY&&3`ls_~mJ`#Y;mJdpNmnlV4P(9gYR1`y}zV(HfD>=nn zR<<0;Y6&^s2^Rj_7SH|3FI$sTLVyVqMS zhWoPBRt#J`+Hi?d^I)3UAh{HVgPP`P1xltM>vkhdYXQP2D}uf6;RI z%=Kb#73s~z8Ews8yIxc?#Wj{0IM2N+p)FzwGrUGK?v=z67=?F-m%azXKVtNbj@goWHbf<psXi(fS26E*$LY(xyr_9F z-j|vvnH4MLwc)xW5|LsahXR3#l)8KPj@6e})c1Pl^DWyR1L6_Fii{lGtBjsZcPQlY z4Vet#QWcuu%{<%L4{kuNN^ozF?b(yBCG;%uyqAu>`pb1-Prhp^I|$8rN|#9;=^t)M ze-$uBW@%U$58+p)|HIPvf42twZ&?8L;q>fiZmT>1pWJ+L9=!G45l?RSyE*ux(JR=7 z%DYz7RsFV|VWZ~zYj zn>@n2_yaUW0GbLB87Bad0VA_O>478mDsUzQ#oR_X5&_5HQz8T(%m3Ab#x2lX$P*Cc zx#0~AS)k)2h#LgU|DX3*U8J9hdH)Ix1ux3t)|M$hb!3G6|ZSzceFB#24I2SMq_PHF=aLlmX^%jhf9h+`sVAiIXI1kw9XI5A}f;`X9kJ~NZwN*?#zVjbNP`X{*nI_PrG;cuI6N|&+xZ>>C;TcS z(2RM{9FHUw`P{vUM9f^I#Uh6xjXKbS0F4MUT0JpZL(Pr%Ek?Q1jAh8U#f;E_xiktN zb@X(jW?f*XJkh^_7o>_YvzTYASpv!ivs5lMqn2xB2sks;NEYJA%dA~~EAQO9*kF5M zQ+4+6v8(4=v;K1rxvcLh;c6Ys#Hd0GMF_@XUP_E2+0>ca{y6ywM$@v)vubjq1}+8+ zaSx~zkP9ynCGj6#f@C9|c=|&pny-T#ay+u{#J)_}Dh5L`Ml28_30_hdLYKQ8y?wzG^jj*Uh37Zwo> zjN)AFV(1F8Rp?p$6y3xvZvSnG+!Lj zmv{8N$L>IYW%a&JqwpkTbZ1K2qxugi5Bi66zqfRt)qY**z1qW9{=Ja;H%V}xFtcA< z-c)vlo4D`m?!t9+0*G9C5D~$~H>{D#4kcu?|0kG2$lBh$gulLg@~O0g<&W{C-O>%a zWyyB~tq+aJ>5fCEG_v!xmSeEVO|$j!3L^~XUyS*bQJu%TU1LjS@y3|?HZPELT~OTQ z&LEX`!$PS4%E*ZOLDT+eG@OvYdvUk8f0)um&46b(}LVE^Mi`;2iab+q_2E4H#TZ=qz(K8a%Fqmeeg z__!K=X%q23Z4!$-BvyM7t0Hz@#|!~E`n_TUHP=_@YNIMm8y!Nbi=30;P?_JuF5aUW z#i+VNOUM?vkN?gOZuBF>Ov$_K93-x|y3{f>?BBvlG{8q<2I|Zv6>vxRvfOYJafxEM z3nJrp&AONO_CxCZH)SJW{R!z0@Z;L|&4mNoY}otr;U!qx5$j*9X6Gx_6HI`C_kRL< zOVq->?#I|6*b6VhA`LYS(6hPjWGQ%FH(Rx8E4G7yZ<~y&@yB*|i);NFW8e{Q=Hbh2KpFe!?C`b(b!Tc)nd z(^>ZH=}7*JYE+HzSlo;fCbTu`7m;tAaHy!BBFN#f9VX_0agOHSQHH@7!&>3(Q`CFn zrU@SFDojM{W3uD0IcYR^jZF?lMxNjpXIAT`sOm=Z+l#Ar-x}g@oIA1z8Sd%#eeNdx zfz&H?1p4a=(z2J6`pxB}{sI|25KZgupjuzQn~0VH(X8?g1tIE>)QGh5&QKl5_oyqPpth}*EHFO&22?4>)OW)kj-54Fa zL1IFNWKpfk1cUjYPDuw=-?%M0ECdn1D4Ob1v$GFWQck9X>Y^vg4p0#Q7>luy-2?dN zZbsk03+q>6$Mm-Bn+`!fY(UP8xU^?1#PGzo*mdOXl=q?AH;6a?tmRYFspyx?pviYr z&Z=c`6sYBEb_u?+NA`ed*Nm*l0tfkmH%UHiM+`SR066?g@MoTeBfXb7`?xLIGbnvD zO(BNg{i+=j&HR2>GfnNpbH_RGUZGo@O!5b#<;^Wfm&P6Bei8k!R_h4>%7_rcKa)2X zHfo$JAUIfMP?nG%RPX0gsm3p>Np9t?>CZXY$6KXJr}oHvRwQowKK9)8;D-FOAcE8a z@Yl-_LU0*E+SE-VgmWE=>s$@$)=ca)@hXSDxL0p`P04(dY*k>%G5SGnqGo79T}MW55y( z*>Hp<{3y=drB=viO1RmKxQg`AP~bFsgXv44%ls>6xUqoh?K)-#N5m!)KAGVU^47Gp z!T3e1)H#E+0`Z*8r#Mn??!A|IU4f4RV~B>oyQ#N=%NF|sb*Lf+dT}7I zcF1^f)^kzG$O|KG7>k<0K=n(f^)P|+&S!PM1~LxM zjtPI$xETIs0Nixr>m?FdK&UXY`DVqnO%4Af&AdJ7;B&ZVUybTxO%CB8-0)Os(|yQC z3wi&(sX#tKoayF9ok3$#_^K-b?75MN5ykAj`4{Qw7y|R)3UfU>eaG!Yfx8AQ1G9O2f3}**lK9n4)i}heFZubr`wKT{j-Tu9B#I<_;XfroJ z4L)UCDKNBmx*0dg{BK>26NHusfe`&yPa(ANcbOOazw^G$<#)%$kyUWu^XrCbr{N-} zdse|6xD^|cKU-5GNO&XAM*g>x{BO_ve+N6zA$V@(z9N5-fAI=V9&*F+5#5In_%r7B zR;TjI--FdKS07ykNE6jZn241?bbwE=A-ng))qk=jcwxsfT`rEwG%3L-Zwu;Rrn4tX zbG@8Hk=0YK zGjer23{F7;sUZ;hLTL9ukrOh@@J`*dx9HeLkV5zHHfRpf!@q|OY9y^YbD^9Yw!LTk5=9bV_Pbnw}4 z-{!1lPk=;-AOe{Y(2j?Q+`I1s(ly-rzyu*=D5N2p?VoN5seCmz*rrZ9ik^BrSs%MQ z{|H0kB$k}4_t-5xsu#$81rw6=n6i4b;JWcAQxG5#izJSH1+Zbak%_yeB~tXqWXg~- z`l6Y3c9oNam!HdP#F0MkD>X-I?{9mcKC^q@$F>3M7k@z@wiMnIR&Lt&eWnU=y2IzM zGh_HJ`i*mWUapnsnLAO?5TBs?Gb#0;rJUUFb%vx!=0P4JheyE zi>VPfTc9`-T|BO$nP~7fyDv3Z-&}F3vRTR?U)M-{e8BV);0ZJ^8S^~F^ubifz0p)| zLFYeUPnNkOio31AXkc;WUs%!vDu->6BFd(b^(U{xSKq*^Yb;*cp^z!wBdXCEwWR`v zEAFOk>Q1L5sQmlYnzouSwOS4b#<$D2Z{D5L@M>6a41+uT$Cq&Ae)~5Mcv5Lc?P=5G z8+gY?AH%bUR%4LKz1^SIN8iszO-20O^$WiT)4(5SZr(A4%u+eO@-C0~!wXA>8hSW1 zG7#$y@S^{kRMC)8Q8}Hq-c~tTc}5VSCTU} zZqdx8`%95n^ZPzn+9XFiWR}+XRkUJSZ#j+b5gEVCAC!$%#GfSQ*C8?X^TZ&UzkRpF z4YgQY9B~4fxlL5j9Ul~ZbLh8Ihq1jBS*){e>FIZVQJRL5@{kp|lJJ77elv%c-PdH< zGN;_XjsK)RX1mz4!Ja-y_>cZlx*(9cm1wmM!W`HS|1Y3^hF1%VrPi_UrA4PTjYf&+ zh*QlOV>B71M&4v|dG@J{ZckSiBp*J`5(`DmNKGC&S3bS0*iM+DPX|Xg@`_40J|4^J z>y{No!8QN3@@@jMtQ^K*qvGHe)!dX^vp!nd&;%Fn;WBz5+DfyhjF{Hu1cN=U$2)K8ed+7zK+^I&^nfveuH=6eoa4673EJ#;(rWsPG%%Pi61S-P*f#jS z<~j@Z3c``l`YAwYznMmqELPe)NpLcj)^C@s?l*$~-kWn6l^~n8Gr8KzLw4TfjQ#gH7l5ybj@_C*;=3Fy@tSWE%U=-xkw6&!- zs^00GqZ$!0lCu2QgR}zNjlcw3{s+!&KIcUzA`K-#sK1)Oet*s%40~LAy9r)h@Hu#_>wid? z3`ZRgzDgle)5L$?vvR?G+rjSt@ao;ArU_7sj*s282_RvGtF~dKUPy38EE*tXd(G3Q^ z-;hyjHB{k-5{CHisE-Gmdw` zKHjaS>ho$U$yQFT{Ge)G*pp`bCriLF+JXRCVWxJse?4%p0!pEO=Q=j1flcx~D16jW z?Ql2=iRme}#DU%=iX9Z|dfNmIR^npld28W;eKddSmit8hnnIb1fn<{M)wVH9uc!C^ zX(knneQ*EH_sQP+oTAH|8R4IUQuf{i866+CA2({+-C=9N3OM?-uz$v=Ihi~%i5P171EIXxP>oeoDDeh3v$3#H6;4YP4e_{4y8#&=;53CCT}$jE0+T=s`KdzFTJtML}XIg z-b`q>|34N@TNQvxu&6xHUXQ%{j$zE$9^`5MUSxd%kFQnwR4TTCPsz5?%l;%lIE#jA zbS)rgXI%29boa@b3O*^%WH&U;FhH*!BQeuSyV>s=xzM%XE|-y>Dc?noA13{I<{boK zUr`kNdRQi*a1QkT|7jZ-R54G<{_*+$uLW4h>B5Du&-qmFnH6|~;IG~SOnW#le!9@C zVEf;tPqxwC3x2IflptrOT7FJ;g+ruM0v~~KErxrTfSM~sa>|-=)QuheH)rUEU(%x9xitt zo5#nFRhm`SrK<0TO;{w!w@qQ=G#|BWv~2dlrIVxSQ5oc9G;UdXK77Kn0_LY-cuTpH z`vMk-sSUj)RZPxv${E>88Xp$%t-RC_zTAeU!LG6NhI%&qmz(ZCUJu=HNDeXwPivFBaxm?wgvo=wtqVPX0zs7i+@Zzff+*)bAC~X!(@p`T2pJ~CqT@1EPBJ#S_nW=F)5YqnikuoKt8XiSsJyx? zedy^>OiXbkOOxqr4<9+zE)hfVz1UsdPsE!aR>V1d<3u(^H}8UCuJ%cU7V?fEd5a09 z;;&oKBvu9FHoW8MkYeO2zZ!XupBI~Qm6q(HWM&DBCmGj>*xfGTa-gzpu39%ricJ%g{JmqJ9CX1DFML+^4M#z{ zEIQ-T#Ry>{Gi8o6^o>2l*(E`Tt=ORhExA)&l(f85a~38nBUQ+hE1-*8gZgnVIfvif zygN~OKn98l%ik49QARZ_fnm|#@hXs{52D-44x4)Yt^-)@39b9u`CxoRY*2MW4}#w_ zy(`n<5ev401r;PSXi8cLD_cK^MDXRVU&t-Kxtvb4L z3k>4}!>V$vV0o2vK^8my2tJ(ja`hX{uFAWc?vAhED&-k(g_Cwi0@2U+7hoxiwjezD z5+mId_PHCZ$yf&KY-^zzxE&5Ska|hZj;b&h903oGFYmP#`FcFjh~#aE(xf z$enggovTMf-=C~?F1MHs&6+Lm=#@?=&7wK>DV~S*JHuPeNw73<^KT2=XZO5 zfE!h$SVHjv&=f5{Q%OEu&{%GEcx3|<=}~yf$Hf3AF;U_DK6%c<{YI z1EXkJOc3Md?{!e~5%Kz)H+A7z`S4t1FhtB09?Z>Uaz>5ZuXUGdZc6R>)bhHLLAp!R zb3Qh3MTrU+9oWUjY63xHUU=7*Exot%ro=7XwBPNQ$=#^%E5WE+-D;gLLPzynP!kNO-nW};(f(*1U>2Ys_q;N#0ta^sp)oiuuGe-MofS}rVI z>bKX!TC|nGh>Dqqd1T%rnwxlysmt%1p}!9iH6jCOTFYa1i>o{{(52=uPo7_MBTmfS z_(qzl)2pE163+&GoesxBkXlU0DK)H7uSQk#u2o7Tu*vTAG>;wVWeVllRZuyiaYy3F zE`^y{aMCq(D2hD;fBxyn>OSGA}9 zZt?n^-JmagW;5}bVXsWI>O-nFbig54YGWJN3KKWuAJ)mtvXkk98hpB+IU)07w>?=3 zaf#KbhQt-HWh(aV`_tTpfi*r_Y?idK&FC#bCF{`D^E)8igPadcb}9J`nkK5ZZA2|g z6;%rj$Obn^fT`=A))}ZR_<(emBMXnc(r0ES!rtk+Z2Jct~(I$L5DOQ?`+qM0n0}ZxWHe~v# zWmTS42_NY;5kQP~)Qc$ENWSX;G%eAE8Fv4Oj$wImoJ6`p@rjHg!0RCk5XC)jwD#7& z*eC4w%yb0y4CwUN8p1pnEq<3aWwNf^hDd6_sLH2vovczy)z}!jxMmkw_C|!v>chz7~1P zV(EK&Pxh-*eV5D2HNt@bvwEM=<;NQ5OQYov{eEe2{Au(S;pftA%N3s%ZrqnC_jq#q znP2m~`RD#J-K#v;ouFxS(Q30|(9ydkFaPsI+Fd%d?UjJO z#o5oUr-D{}W}o<;J4jA;efQC`H5)6=>`h~~KYROcQlaJk*gL7V2Ub5bUC-ET`rg0$ z=+{4K8zY|CEZD!}pT4URaB->5a?5Et(@*DTR=EJz126#RqDDzj2owNk#TgprKcD~a xqx_wCpvZ^Ce=ghq+xnXu$gZ2T|KHjF><{*+W?FBZHXA6y;OXk;vd$@?2>@+3rfNFzvtLrD!V5(1Jc9U|SKbmK76NQWRD0@9&0N_R+?LxXhZ z+4%XK-|u_Q`Qxnhu6MnEy|Z+=_c;66Pu%x)-}iMr&j+>V@4S5+*=>Y8paD!E+7zb``ur( zPKSawAkapeB3xR_-EjL8mis-OlM@TK+l8I&I{qGrP~i7}KC%qj=!}dz>0sOdFfBqF z(K=^f4w;xso;v(=>iWu321AjJ@zc(sukY zQ26I+^#9S-irU5D*W2OU+w(Br%fSrWvoCph3CHF|Mgk)}#V`CMqYR#)>8O97xtpEE*Q15f~ zlns}(*TRN*dpTj&*49(T1xA&1>6sW=8ls{1kXpp7pkcyY)V)~ms^xQtFW^1a_-2p!QB+)0By3PA@jH%1AWC+l$n18$V(2Fj!>@A6q?uH1A$?)NnwK?xWcPbMFOgob3fhS&^aNLA7XR@KSHdE;nUS!<0CVrn z%6QDXN2OHLDEm6)V{w%6xX-cb(p6+Uiz_jcY(9-_cO6h!lem>}n{pzqmMbiMq?%H0BGa@liwo4U8URPYb%b992htvO1>hZSDjAZux^8NeA zbiZW=8}|gNclp!(7P|JzE%GycGg18dz;eU>8Vn}#r2v@K8C9P0sg~$|!wDJp;cc2^ z*Q=xh3NAV3{Ij8>r(djAKUR;`Y(YFxe@F^ts`EyMLQ2*;#$V601RuVkUWGZ(MeGNy zD5PRHkCqDe=^8SANVLLgSNqg42f6=jWKVHkzeW)HrU6ILUINxb^U-~!siEvo_K8)F z;E2LEE(P+QHH*xNmxwf(v5HaLgD^9O-z#!Y7A9kz4fDB>q+3I#=+~&*o2X5}KZ>!d zPlv{^%ApaA9b%hFAod#R^U<9d=M1ZfR!0e^#)j<3)L)DlnB-hK_Nnj0+8qZv2Z?hj zjXWbb#w|@aX47=RtEBf6JY)bX&E8uk!8TxoLfQ#G_`f867dY_P}#$)y0a&6Hi0y zQID!Z+edE8x-DflH{LN(qPt_>yHfVOhb*V>&mj@l%&Kuzjs%kx4Z<3S9q%fdt4@;( zY2}{u`sn0Uj;l526>9+)neZ=gv0!I`-%a{%*V&cBu(NW_88}`v@o<~A;Qp4w+X6{3 z&idDGLq8gNG4m2Z{>Wv{8k^Qz{NT49%8DX9Z;^cT@j@upmA3{e^!M5AIbIEiLalW{SiSt2txfEHXr9Pr zhW^WQe9cZn%^c+$OLx%-DdEy~w#%G6kKK0F0>QPYx0AKa^@OJ8*rVU+y}7JH6N=?z z*-Ek+cKuk^`a%M5t+)n-2`~L4QdU(5o6qpy{*emveod8rrhWB$U}OmIOMy0kP3Cf7 zqgc<6ZmFfFR4=1?%TFNUQBhx$xa7}u`K#RJXw3@v41{yI4KgMBofa6Pg-q-*Uwj!& zH|J6LZSyNJpufrh^$X*?!6)Y1CliscC@SOuAGms*FD>!Yir4Rojk@|-g&^lX-HJ%- zu`LCZMTlm;K@W}7<Yj%4 zp1*;rGmZ*yo8tcnK_3X=tu_NOyuvv(5b85l-we$g8P>u#_1~(X#$TcDIxfnHYAw(1 z(R=RaiW9Q5v&Yd;wWu9?K48D?%&lwHeGVXi6NX`|r1@D=rFS;e-?Cv#=+9I7WnQ^! zIi6g_L(>96VRH~Kv~Q`~%`#v@$ezV!#AUFVrQI1=vXOPFR{h8)4F)ENEcejaBSDpw zX3KRj>!y@-GVCLokY$;gp}&6owB7oKpOn&1rHLGAQwhuMVznZTFA;>75pT{wdmxbu zLOr;{3OB2O6~jE|@|yt`FMJo2>#LHNEU;G7-T+LF`Ext>=A3&R@HS2jGI) zJ543mCZ$#>88dzvbPPc!FbDjERy+<3il#n4<|YwkKP53<54w6<;9T+xYNw8VOe>IK z=qXBq;q0Lk;&W=w-cfHkVJH4{Uohi4(|buC4MYEO-^zLP`fE9EQof(Fc~6_ZKXRXY zp-bWOaqt#wc=PC(x+eNz{SS@Yl#+FydHNlimPRN28eQ-NudDc83PF62`rojN2QCqj zao_A@J1JIKZN6aQ#s>2jfRG7iySe^?iJB8QC5=51lhpQS_!`yU{+P`zI&nVQy?kFl zAlw%7$nBBmVeN>>>I$7y7O{HjjsX081aQwb2{+}B*P)}{o#(~5Bn0=fpG0`^Y0u@o zpZ{K<&~1HWo$kgWL9Mzh?CBW|CjlaKPsF_70ZPvI$2}>uIExIg7|BKPCnxF8T0czG zT~X#JD`A{!vf4Xlc-@!*89+RZ4B_Ey?nW5-%4E^RF3ZN+*GZk5 z(5?M*dKZjFjteOF!Hx0&Ru&g1zCdLmTKfwue-A7E%Scjae33Mri99>g zIj6n&ApdM%9wmfM&lS?`@rkUiLVSTQCkPMh6ts}Lz%|9&6Lc&Q@@;5T_3>nUp-OLH z6WATI&*W3snvnC#;r$<-WHATvwx`lJSw~jG&zv7%Oz3Co5`FKfPpNPe%xK<%h7bNx zB82X{ep=X>Dd}xyH)gwP`B1{ zBv2a-1~0$=`17JS$}|4!bR537ev0{bR?V?hhXVC4XwJ9FNi>T5dd+Kg4$mZ!90Nx6 zFLom4WhwFf->wTj-AVMIOD$Z?H#?h-&(cXFIUrVT;@&%7ojJD@$FO|UBMg0Kj%>NV z6VE59+GP`YDaUbSrBkH)y@ZxvpxVi{2C_-{gbp9hA^-8;JiF|Q535+RmuQwMwU75E z@TlgHf`c(8`KPZ}R|XBKsSm8pET!JH@u+3E{S;S9PO#48y2{;5cd^=lS`J4386-<* zY229gTq9`Jr*IK))18C3F=#p|TY>G0oK_v2;^4J5to_ESya$~Q9gmPCfKa%UkOgA+ zNWUSp1>YaXOlQ41R@yS9rL8&L7xB!|qjNCwb97E>)tDdsi{g11>fiCCohx-ikk^+* z)iZcdu9G=gjLy<%Id2h+`^JoFh%g=n*{f-q-4Sv7&+^Jdi*l#3)IY4y==`yy5b^0j zYf4Z|qHEgIdmo4y$=;yT@`$ai{H{01QbiWKKTVH&+ffhgv$RC>K!$?28nZyEl0^#ak~Yng}Orweg1G$ zi=&a2eOYcdllk^d#Z!ZE-oE1_OTA==B*@&<;(hn1vq(NNn9*6-TK98Hx%{NM=r6@X zD&(0sW9a+qvmfc%(Qkzdh-CM0utEhUu{Dnpb}@fA^7iT0NPoWPKT&tgP@n%$KtWz0 z?u2gB8a*pL*6ezACA9zqt1-lxr%X&fBwpE?^^2ov?43ELDo;YYAO;d?Ed+@LmrkN* zMoH)smGR9g)E}kpW1KkZPV=x#Uln~P?Yko8Z&sU*3C_%vZ9l;WI-z?K+fn5?7kRX6 zkuE%lbq*)Iy*)I#%{v%R=~lyAEizuml{Ju~hC|*ce6i)m{*sLj_ouVPaF<3!H6yi+ zwR2kJb6Gdh)Y^Q}_`YNfDja?G2w6HB(^OS-eeJPs@?rAx{!N|2Q{A6x0oeD5#HfF4 zz7^{l8W4}-h8kCWk0U3l=Z{*MJC(|#G#X%z?X;-3`@9r;y}oek8#(X36*LA#`l2BW zJy(cQHdA0I$PdToAdLCpd@OIr`LZaWx`?x4XalpZ4kV4F9>5IG=-V|RWIeqr2*#+W zu$Q0T&x_Hc!5v(DIn5*nSWm{wRzN+2Akp!+5ic0fKUZyw>-w|Uw%bYGChbKieB97h zfyJo~NP}QuiOd z?8Kka)g%Ea+;7BS6k}Tk&9GxivvVqRr<(d5`Ovy3ggh?->N_$tu+%|(`F4L+Is+Y9 z1YMCM03#S1lxE9Pnc$GfgXIQgn7=9JdRij3t4Eg!$E#|p#uT=%aFN--CMbqF*1@vt zJip};??^=phE_ztdVAncsIA60@5?}s48nVK&_GXdR!6AfV!#>0ob4^*U}~P? zk4p+s#TLt`J1pE{C8FPgKL#)9U$E)NI5TCulXy6Kc^=xe|J3dx@`BJvW>EYPUABk7 zc>x6NePJ(CnZT+f8c6P-_LZ^JnS182rmm^7ABLIqvnkpQyx?^$&z*mIcIHL+E;Bku zZ>($_)FgbUSsQS;^LA)N8&;-ll^BB5CE_ z4W_(9fiJ#qv}-X;D>cnpdOpoqLEy^5#5S^-D5iQ`n5&azel28D6Ma)4@tf~?aIgN{ z3;(qxzOGx?gwInn8;58q;Wi>&;^+>C1ijFB=tTW+W8Yqxfj=(MMQEAn?7AV#t>A=Z z$H<%H&jyW#Y&^ZJI>fTc#_Sn-xT{C5IkkuP-{8(DDp7fJ@PSP-!mdx1Jez~R5aMO1 zZ-f#RzIlJ)EH3--T!RifxfS(FUHiZQe^qW=t5odj)^5riDx3%5*?j!^(RX?+yyEq!3=V8mxDsm8CKV!9M}k z!SuYldN2G*f5dk@ZKzs_{?1S9Bhy{VN`K;*(uc7>(MSB;2 z^buY#sFQC)+8`r&9}LA1v7AUTfKS#UTW3w~(+hCe^R3Zg5A9 zA?SXOHEm)E;^Mt$-ha_TpsaHHknRw#7Eb2D$v0nKLy!#+B>ug8vLvksj?w634o;?T zV^_U)vVh{ETT>aZ?2#gEScjnfq)bk0iB_c`H^Gx1L_>DYB%*TtV&Qo_7^~{wp!Pb^ zq&L@v(BQCZ5``3sCALD9+7Dw*6R!5f_5$EP#<*WN4+E@)cN969hag6I9d?K=@eK7@ zb}zY4D{>B8<4ea)v+EbE`ia<8gSf(o4~ZV%4VQZ(_L!QlwW^=vT=?DY8fn!$(|nY; zyZL3Te%IEMbx>@~E$4gO-213%t@j8460^%Fn%?mHzDM8XbGCv@Y)U6SA7bcv%uNm49H;B&y%?q}Fg$E( zFtAE#aOB45A0InddJ>gFC-=4K1c<3pfqq5|QrMjIh({1#)NY+<`B8@+y~Y<6l;+{=T@*ij9`UR`y;-OZe;7XKz|m*CUrHP?Kj~w-wl5EOy$zfnllNP zsfRtU3RsLrPxM!nJ_?u0F2wY17&LwTcHncr114f6jqamehDl*oV7Oi>Kkz~Rht=k! z*TX~`bPURx>PFudwPt`UhBSXG*sv}$#vQX0L#!1~fCW^x%hwd&Bvt!IPgm(nGv#g@mt&NSvHMC}^} zUD}&cIJ9gS)p3g3jdJW}-?7(U__TT$pgJ7sF?98&nta160{>bBF+SEsfmHs!yK8nn(zrKfZC9!C+xkS^(}ErH9*& znEa~cvHAMrTpX?B8mcz%`}}4{?>f17o7(J;E1Sb?KZ}O8DP|zEJem~Pvf{6$>;j&(Pf#@mPsby)~{`+5kg4wVMF3Vp1N zG2oby87@e3{F$m(1Cs?!u%mqHb%Bv71;Bm(Ppj~pVoY8u3-tO~1EFow>*C+B`kKoa zdsTRw6hG!r6#}7#zS?QM{Mz`4=wmsNKU&0U+*8H0dCd`GNNwn<*Pn+D_K+Cth#iid zPpU@WVlJWt^Oq1*Ez6Y`p~b0<-ZVb?#RBmblTQ0LvvKy}I1#Bdso&^9_$TkHs}rR2nOas{YwH?y^TSSxrC{oEpJq8x=@SE>-{3hxUbwxMeEe(PYxpXUrY-}H8@4?(bR+n;6;nIcxNn0rc`?q$WD)snf8Qd`bId3{jKHy zlBU>TkEfrOI!>$lYnTGPT4&wQoJp>HzYiF@1dOIteCy}fnbONURgV1?w|a1GRmTW@ zD$6s><%Z3Koisv0(|vOdX^6(Y#dxN1+8ObYG6-oVDf#8Alc~GmFE2TxH+)k{W?zKz z(KF&!uG3kY1#2}%Jsh)>Qxc~@j}y)_(@1ibQ02;~{?0J$gC?lAc!%fID;DZ?&XGmE zyh?L+se$s1hB>{+1xe=}l}-yQk^IT1^1&D7gu&kSWvaBoBa5}__tR{)J1qp)EF_H( zjB|}HKDUA$qLs5zxh5GRB$LEe;WpOOCvjKU;h&Ws&(;iy?fNL2aR@V>4#j#K*lV z$4ZmSZ)A=5Z!QnG!DpvKvVsXm;Z1&9su_(<;f^qU?~~Y9voF5POvT=55dSS9fwD-I zh-2h`Q0dH%cmTwGTp8jj}RxB99Lr+7xxQ__A@t^$rv|#wwf35_gjxHycF~3`IFD@lT zSLlDIA*lOcRY?y{(BI7hGMe8=S=zJGOiE~kuexKto+$=`}zFLkZTA$dbHeaP=38&|{ z5ya;fM6fN_R=ZliGuLlc8&q4+;*gI)2L#q0_3N{_2R*k--;dgBYx|V|xplMqugbbo zU_alWcz(E+Pa|xflvtp}F&+!3n|F@@-mS&q75)G1&GpscpjMf&^uHwPhK4s1BS3e@V*$$OS*~Xm`nvbc`P$LR zjKiqQpak2=Y4h!;be|KQ^V^cz`H#ndMa6F_KD$$#C7eu>ip_cx?#h1&V{YB5=wpuf zXU8nLupEuT^tv11$>CCi@9Xa=4eLk{>3%mJ)AcU-#jgsrir5|je)TM#`QrL?q0nxo z=BT#$TE!UGZ>Gvhw)y%bRrEceZA*yWO=J-*X%-BLO{LBR4|%senK{gO_aH!Q6l#g) z-CnIpjeUD-ldLsY^n&vPH87AD&jH0cS7W6BOu?a55PkImXd>YvBgFL#?0P58VK7><=zqYncfz@kX;p7>tX!9yeIBsk2k znjID-KR`;1nhk4a8og>^qkyq(%s4a)?n(o4f5r2=rQlEkGC!b$N^1k3$lcL?*v{6e ze>T57`6D21jwb=i3Q3L`aMmU|XJKROC3p~XyD^$?=6CUd1eimLzjG+GH7IqPBj&b) zn+DPVancCcf*&(<3E;mW@};xDre)0FZ90=##68D|<~zl`l<+uUnQQXVdua6q@C5F^ zEl`oAV}AEI%!Aht+%`3ehte1Vgp(~-3@{`s=91g<6joXjGU0SBC1vG%3T+O>?fb?RwZQu1#sh%K zgT;^{<=5nm>Y1DSbaRFxIi$|gtF>2e4a6$+y}meT=GZ!#b6Xhx8AgdkwT7y#-9uthS9DPQm#y=k?FfZqy87 zwJ%xLq4~N7YG2fcr)s$*j&gbjLq~k#)1lfkZeSl7 zG&Duh=gCtzTzn4?7mDEF;rVFZ5l)2;8sLD>?7{CdfXa-OL$O&v_xFey(s-GtUlR#9 zS5o;0i^{o^5qCy1YNz28capmx#ecU~0`K(zQke=HM$drE@{wGnUU28{pJA*|f;&qY z#T7A!9vL>&d2=`3;ls}N-H#a=M%FH{r*d`Jj zmaTmtO_Vu8(l8Rq9=#$K$dwIq_v{<)CyV5wU&Em_xgi|^MV1;h^JLHfD<8=Xxa}N4 z)Qjeui>)fhB^fG7j9KGSf-28rYtOY$Ze0@-uY6^h}d~NgD{Op1QAo80qD~R zK*wT#@c0e~0aE1S%PQgG6)Za``2sm;WnM7DruC zM;$Vw7iN&E7dn3hCZb-(@`vQvB1VdUStuOA08(G+i7!;klo`uWB)6V#sE4eQ9j_0| zQuCls|Qb9>vp@VXIbFwe6{#S0i8~a2SIqlNMmDrFcU-$BVA{R z57q&qZ3jV<+&>9PjQnJT~|wiT7A62X30!@Y%vhG#pm=J{^7{O!=w9q9_T*1 z!}L|fNVrQt7vsaNnF>{mNXImh!47?{G(TFlQk(&jhqro795T5YQ+Fu=(09vuzgr(_ zDPQkR%vXIY?5|4GxO$O843Yr89Gv;uZw#Yo4BVwzsr6d09?CF33qU|31Xk-l#|&_R zsLbKd!Hv@Mn)%Oj0f5h^_RS_lRzWt<@R6UxX@mhf8%0#4ksI3bC|$qZo5%KWHPr$C z$LoT}4<90cz{Jp=3JGm<(n(g_B~1wb1OAy zcEeU3TI%M57FqW3VSIMkQL0vN@=Q7G2w1Svn?BHe0^7hhhgz}B*( zB(u!>Q`EJ!wXJgi{IbgMy(sM4LRZe*){e}d0x==`Om2v>^hB`z@gBBdJO|zX)${QM z!W2DeY3?NvH8eU?E3RL6Iy}XR$WfwpQ-hcEM zL`$tgOG&NkkDi_*iwc#yA5ntFX+^VA4227W;{8Ls&ktgno13BQd~w|D>`2M0J=p+t zf2>jxx(#`q{b>w;PG1_AV2~3?R~!v#)`b&OQU!ZMOJCFKOZd12Ys=(}-Z;z$Vg>A+ zF$_OG3-fB}A{Z9jW)|~jFUh18A--}DA0QEFdyCnXHG3vB8zpsP*>nH|;Cujmdnnpa z1CXK*;o;tk?>#^{>e+I@pNImJb_Cg{7IGK@{D>NA8C@K7ygfBqx=M1lFp-vZ`)`tl z-LqOd8tyS)_;-wh0=J>iWyl4!z#H7glW7})1?)YOZF(~_1$R%+o*~F8<`TbeACC1* z4Y$PUtP7c9+xPU@*;$qchVh`!C0d0aVWbI+&VhLQAST+!i$DFDOr6C`w}SD=S$aa= zVmlh&x}&kR5Lodkhgim=F?C^_o1~L41#}TuMUyZY)RBNpKvJl>%3d`?)P)HhMY=YI z2)c;6Bx>s_iT3MXhri1>#>wxmaxY0d^B0FAELr%+FmYu}KP?T-BaMj*3SO%4yE)7_n)ll6egan|FV1c!z>9oYH>vf85xt6njwt(Kqi=A zYtLD=&r0Ua7bjGQDj{+t#mSDpkY<8r;=vlt%F{T4Q!F6v4k8P*RA2t>IXGRUP5dDc zrs2MP^tch3`|b^9*&j^e(*$7XS=>p&!TG1)Y8edsWb9$eI2s4iPBq9PuB=IGYwOW3 zJqJ9@Pa;k-gf0khf2$1%-|`2PJd|)ok2o3aHF#fI3tu(cM33<860OMqC7=YUvJ)6q z@_C)u2GA3}N9zKYv}q#zIa(;Pl-)DgMyR8vm;;Y@{cgnsWrAH3I*(cU$)X*vBRbEW zj21;mOdB#nJt@O1%e~(7hTw~!6sCQe$YdRJ%)mb9z-%Xkgbm|qxa0Zi4|<(rCq3z zi`fRkm4o75Vk?+PxwnfPU{smV5x3y`p&7_k;kCbJCs++Nu}>z&YkM~-%%pqf9Yz{o zAzqpW1(4Nf@utOzfdg?EY!hf=C~j4tT3o0##f-(o23~z=yhF4aSYX!d;pL@-ij*Tf zd^3dVe4`8J8JeVWL^Cdv#}EgJ^VzTug~=e`Pl8bW z0zzJ#{7JpIp|=gHmy-)-7Z=yhQb^i;Zz7w<2001Y9dEqY{6cmttpk>5-J4wQj$QdA zhvY)b+sk?m597NEjsl?+MNp4;nRy8-87tn~1UhTeB{n7jk)2)@m04UD zmG@6N{hUGMRYN$KTzkA?0UX}~H-^4W5)VVr6Nys`U?4`&5fW5c56jET1wwFJOyflD zsNvds&l%w!AO;;Y_!-Xw#7sCN9HnpK{4QPG0~1+BE9;U3$ZoAy3@u>9bQ$5=1Je6_ zoakVxGINleJsJ~a;DImdrj?(u4LLycnIxkCNNaT1?k14>3)9okB)`J+V7}r2Lve&9 zpB6%n3KU?8wHjhvFE~_&gCYJoEy)4{hE@z&G=|rKuMhV2q6Z6&e?ucNVshd+Urhts z7vvjqNL}8Ce^{BMykSBQK8KxcD+FCr znMvQnll3{lg^NUqn5T)&vUT_bH3Ktc$fsNowe>|yhxe|3Y3%7Y>MP-UflvExVay?N zL@qKz5Ysb^c1sYZ5}(YsG-ws`6IUWqFGoHZnzw+IcOLt z3p6uH5Tc6M;3Z$^X98F-+Ykt@J)iZI;KTNhpFi=ITfjuE=@x_Onmo{wXqn*gbXrIY}OVPlUWIuBwO|%I|On4{tP+oKb)?6 znx4fsYA`sXKLSrh?Gdwzzl=c8@&%iKp$8yzW_16KVNgyL(R!=^+)GWeZ>^Qw&~Caa zhSKo&25;X~D<4%B{~0k=}53^MIeOFZ!Rd;ySQ z8aFpLG!_tnyiUq4T@>ror){5|1#ROA<-T>P5cQ2&Ah==?E8^mYF`mOB z2eyoNbavNN|-f(-y&qZdyX?Pf^?{l+dfb zQ`K88;ZI)HV4t z(27akr07Ws6l;i9+(V|p@Zt5U?m5(lZ-=6@_J4TOU_=f@GxfY~QGX}A&1Wi=V9xzu z+hnVl+fUoD;~=C&HMPSz*^Js$XWDhjq(jotz|jX(j*Yo=K?M$eg{|#ts|}=zA^$e4 zclMIzAu;L+I+9SN&4{qtA)iN16JLV~K3#!6zdUT|C{pvRCa z1avta0VVrgotVogXKK?zN^RVyB#p9}(q`7kOsF>Z@QpZ!iQHo6pd&$w?iG89M!x(; zqVdnnc6-Kw?E-v@aeSvT$BvJ{qIr8hoTv|igq@!mg%)wk-t>tH^y#8X6C^9eqdI7L zZNRQ8Jro`v)d228(u@X-rm?87$Np%(6u*6xUf5wCn%~v1H+^G~?je^>Tkdtek1;H) z=>3{(@)z0l8P@RajfG+I^XuJaxYPrN3wzFChr~VtdcfH`CvMLSW#-!6nDmOuCpf;D zD$^4>vHsP6pigxe;3&yWAJxg?)WEOH_VkNZz0njl-2GT;{&kq}-0$~ReBLy8yNsK`r8WjUY9_;p#*i| zE0W0{qjqOq)^grS&vGmLfN5A#sdP`ZRa1-hz^iNtz_`ml)Cbcg{_-fiJ`H<69ii)^ z`KrRa-A^yJUS+v3Q~5&zeVa&nSjcXPe#r}P8QA@`>R=HeJ7uOZhZ4B7(*;TctQtbwBGVIYTrpTV89cuqz} z(nn8CvUY1yDU*&D96mgMXTr|$^k|0Suq|XgKN}umbI|R~c#(W7rS%jjfvrpqcoCWE z!h~SHC!`%>E?kW0;CQ|M0}(E^r!+AFEsO&g>5HRm?OGgb!0eJX=@V&F30lXA&>ppQ zM9{D;P_$v%$Sy4{EpKk7T;xDHR>SUdKl?LNt7?Ey4D+XHa9+?j6K?^iINr-ZJ}&=> zLsTXD84G%_0+}!B_&VwP5g+4gGbp|lkwfkE&E@V-9bG+=_T^rj5VTu?{$sSGEANqp zzT)N^Syx|{;ijuo#7ObR*99a`zYy9my_Az=Nb&^-C%CXe%kOGG#8qSOM~agwjbm{4 z4}WH*G6e$K6H7TA6ZU5QN7G0oTHxLdgLq-SDgxjvZ84Z&Buq~{4p-xMaZQ{Bn*xa( z-P10PHkj3d6*IwE3P?hgRG~jj?1NY5hbuO}u#Huc1)es&`}B_9(Z7p?KC!l92Y{Bn3i8*n_|&;Z@))PPadi4d2ID^EkYfs$d3M z9AD;(aAM6!c=!vB_&JW5kiNa2R=7)+-k-k7N6`tmnm_7e3?!Gmhxe=6<`tF9WB+x4 zG(v};T~x1xkPtEGet%bTW&f)<+50DT46V|qCvTrIDW63zlV5v6vtSlDsp;&MtTIQP z!tgbBFXcR;qb;i}*>NP#Ae7e~fCjB(Ql^Zv1ggzEH2DrSazr{GrdbOvY(p`%+##+# ziR29^4NaO|>MC3&8X&oo(${mW`zSmCuKW=PM^8iWIzZl#k+#id0%;er5g?QV9B|dC z*c)NNU#(F5-beP_?@ud@3rYcRAi-klQ><9EHF57d0SMr@p>Ng4d=2+O?bamgBhb-h zt3Hg1dx*9HA~OaK%M{8Z|4F9>Vy5Ziz>NOcrLpZIa7s+TTGy8b5U0%?GfyhGt^m^NWj=#H1GN z*iSfw5C#{oJ%P%b>&srpgcZj-X1nlNJO?@xm`M_n&_>JDOZhn$i%I|4_~2aAZ$j8lW%5cBB>Nw4g!Cg4hXsG>Eq{-o5I`7-f6%sU>Ya{m z9B8SmDK)ha-&yFBNKix>ur6Xgf@f#?;IAcBaNI*Qo>30&c^%maiE5uyB~- z2u9ld=vAryW;7MkWB6p^;EHaV^!?^g+7N}|9U7V#>iTpsvB*gylWQ5F#W2)^(pc2 z@G3l`xz*|)k33QQ(KKQwJVZ!DN}9Iiv@&2p^jkB7=h3^D1MvLPcM$WPPz z+R9#y=Bcm}Jyd0eM#`U{!EiHZ+LWkCO#}{txOFO)o9g9N_5HJK2xYw_`6Zlc68|V0 zAsE8WiIQMGzED|ppRA0iR7%}pUN)6o-fuhhhI{USWaIp-;L<%F`w7244UDMC%Cd!Y zlLyg4h}-7C{D~EgR?0^?G;*hxy?Dtnl+k zdw16w61w>C=L~oH5MUN`0OQQ|SJeZWi25hJ`+M-RNdEtL4LX>DM}P84>F>GB7rJ-t z?~UIA${PTBcSjz_#%#*~-EfHB@5; zhJuDvO9sN)GqcptdyleU%ZtF5usdC~>h0Fc&G|;LD}xzEH@v!4?B{1^L%YquVTclE zOu$$Qh-pP&yA6lL*C_hi&%p7Q-M zt7J=fb5gfgqYYXsmOW$)DI!i;aQbHnT!rQ#=D}W40Nwj#Yux!4$ID`cVE{v?a{!RM zcvYm;y?LFPaEOrKU6LiWtF1BQwjjFVH)+GXSO99CQuUfgqp%|&kOQK{d&~XbKdbEL zo(PIn=H#$US6Rh2sBy8f{+8*j9GdHRGw+sSdBRk@Co0iJpfvt(9gEj%nQK9V%IKX05Y1M1~?%9!U5qT2Vh=O{0h_5 z>uvx5etZke1l4fzV~b!}fFr$4N}SB&Ld%MhyilN*)he_F#FW?qKxkey;;^_^dV98y z>Td?%Q!bd7*IyQZts;FM0HOf--t7Awh~eDrTmUri3WGa^IOt9e2M}=T?l(QscVGw@ zz3tsrg%t+i@=jx~x!cj;3i z>(>I;9IpSUz6_qup%D?e7sp%q7xI(-Xc+5v%62}(e%M|>LeqW8@zGJut71mz4+*c6 zoF6}a;CQn(A)~$lcedZWPD?!sIe?v1%WvA=%bERf`wCdO%FWHq>q_(fEFrsTT-;0< z9{sxLbl(droycQfUtd7uWAg}O^ZmQKV0j-9==B=y^=fQ0DGlAF=hJL-&R^V7m+|Q8 zCtCnSk7df@aHwZtasvxD9?~XXV?7>EJsQm@ ztAS?mXS(`myK*peL4^Q*SM_&D$FM8bUlB0Rs-mq;Yf{E6wrEuX(tJu@Hl+HI6S zTDida5FEGnvm02c>!)yyxiX+atrn;}#I$7k>Lc zDK)bXy1Tm%_HEtVPI3J2wjksK;)@z!gd4!BuI3hP0|Aa4uaQ(l`N%dL-$upn=GI|MnD#BZ7tnEjg*JeBy@I;Pmp$^h|b7lelXVT=hu_SOAjc)KnNZ9Yn(ljndCa( z-`Q~h%kypRK%5ZS*4Fl>NZQ@K?rOfyDH}cJ0Wg;zf}N~EG5k2|Cg@;p!-lf}#>dK^ zeg9NXSne&hqKhCQchyip*YkgyWaNkxja$F>I_ohey$Aq9MX$c5LlC2@1D})taqQT&q-HG&u`ovXh6dJ^=LT!yO)ocDiqFsiT8Sw9924p1&y6c$%IA8!S+W zVS@VIT#bG3K|=1x#ee1`-?TGQ9T0w(ui8$fEo&E56M2oE1660%Df?>%h#*)OknrC`Uus;3y10qbVBP=X@9?_}H zZ6;29A;Qpx%EE_F0qo43!EOLfL?*a9TmSg_3>CGSd)n(>Os9#@04>{#dLZ`ho)mt& zW>c-hLc%fP9xGHH`I@S>wO!-&a~%J#`@-9pvw@}0*2sM(F(g}l&%bLr z$J!<|EG)J_)7!hjqg!H0F*|aj-~(C!-_^BEgXn2cVX$>T@)f_y0hp1YZS|( z9ebZD6?=QpTP1#jUZPxAknn!B{DSGHZS!-f>rL^6Tj0=@C=EE%x6@hTwCp#=n1F?4 z;*jL0=KacZTd}5lNN2Sry;Ai?_3xHLS-gd}GS$TuXtq9-w4dcr??!*7HYW|%pn!v96Pu$jR>jMtfw-i6c3`A$X1A_)W}&9y z##dH1qo5t?J%X=WWgW{E_CKgP>wu`bsND}BJ#=?VH^IL{4(mzuL_N6?R(4UPZ94&f* z;XfXl3&A?2IlV8E99UE>6WJWQ6p%Xo3!xv7wRt3d7JLwP&~VOccZ=QS_~eFAc4T)Y zcr}uRAw2=lCebh$Knd^*9dKDVIabpyo;f-cq@fJ?xm|B`t;YP%$XCfVCc7ZK0!VBp zm@k!1V`Wqn2l%W!)^TiJ%OwvR3`_kW#4BidIq(^$l;ehXRLq3UOkLJJCX+|`!$GO4 z{!xdD=_;mu_d|nh0jY;vuXR&IcA>HAt?H`n%{tNb_w0A>_!!+?cr_Eb%Kw|)dDO|9 z&}{g&o+pRrr-1ZxBPPcxN_$|<&OsOgGk|if1-m;JOPvIx2F4+mxJ_dR9jzKdN&oqF1ab$yg^oc=90 zwt%jN4WHR)1|+9oC>+|P}zR_t+f|b_v=iUA`?6+&Sem3 zCig`T>h~5GT2vi`Cc%Q7b zg>mUSsg+b*PGr+9Yp>M+M8Z*TkiwkASI5P8E0)30`&oVII9}u+t z3atM)7kY6VO~6;7GqZ?TJ!^p{Je@c)jt?fbPk!vVeCaCiO}N7MSWQ@zI4q&>@!wq{ zkS)Jqd5Z40)*O`gqibc9nUXPyfBNw&IyIKNzAD!A2k6VTyB>J47P8I|En^I3SVhVi zf_bl4$2m$iT=u6}07{^43s0+G`NTkAm+D;K)JC)?2LR8gU1*^*1IOGt43!GQ@6+W= z%hbmZ03j-2XxJ4UGbB}h5Mxk2-f>rpr;pH2AkR;uH~^}s=h$3Kk7ldApCLv1?kY{U zskrR5wq`HYO<=YkUcMq;_Iariub+pF@sONXn7^385dWF@M)Fxe*hOzD+HmC-wOGWKh%wcIrE?6QZgOheWXsxTd>+6mq$17zQuni!tA- z^Jzp+2~>tubCo18?)8iV_stahkEo08N5{RD6BRa)7ki4ZyBa0b$~*Mk~Uk5QbYZIpI4^YYn$ln z+a<=`Mz~P8$g|1F4<~sl(wJjgB7s##3!#nOEJjB;;$m?znN^Q!J8lv@dD`416ITL> z-~SMPE^}`KpmU3<7$w`yTdCI;JynQ&8p^S*%~&swLD+u~9~t8ii=w-TCR^d2 zPNN<1I-WDZpZmC#SpUsJ#=HOmVy-#7-W2q}Ikc@xeSO>nTN$U+b4MZ{CrCT{*}eUb z3!+7UyHvQkm=pcoGdAPuXXzB#o_72d9Y=~kN4|xbc1QWrE|NT<)Zx9xK^Xelw3*IV z;42UA?BlYLnR3grH*PjgT8Er)+Rc z0+l3_ePuadqhT_c%e<1%viW#Ptu7+m-7jlo{hyhgD1J4fJk%G6iz}d0i-eN6NvgLbJcm#d?g`$6QXZJ>%h<8z1pVae&{B zw*!QIS$L-0PcRl)_o?q53qRo9g&m=)5#Rhg4DTntr`SAsCjwiiz%;?P^o|t*``o%(jnEZwkDL$aMJRDLa3%d22sv zQT$D+Gu7&4Sc7-_u3-3UG?H=PAmyr$W5Fu7L*IOU)X}^}K=0@{8T|p)Fe}@6Gwm

ERt<_ zjG6OHAZ@y}q5Sq|v8{0E`}Y<}(UuHRJ3(bX#w2Qei6`;Rg43y+_F1|``q^TR5n1JY zQRQN-Rua~^Qfx;qYKahNtDjT(sef>4{WiB@+%vI#wZnzZWTCg$2#R5D`!R{V!5nqP zw(zq@|9vQs;$IEtI;6`lAP{4oBwnsf(fAxPSOICZT;E+k_e2Y(^euWNEQ6!rW0orM zu+eadqObn5s5NiP?Df79{gARG>NuLu(d{zQcs6AVdA~0Z$EEK$GWxyPNOdO3c!fUe z?NUf6M?jw!wXks1_v-_a>?V$UVyY;U`(j^|BSz~|u{K70^2mK%Dq{3(hCqyM^oid^+LM z4P%=a!uXfv!zaN_h%P`_g!JMPg)g$+m2z)Tc_*ppRVwmI;;)adQ(v_hqGzw?v)j;P*BV66kUY)*By&yvz~v@pM9&#k-Ij3IiR-;FZXwUsNZZ$7F}*DWl980Wn3}cTnQt1vmZ)pL5pe=& z6>QBb9jIb8-xn$a(g#GIr?8{DKY*GS?#lXMcchCk@Zv72z9~)wGJXwP*G&TT_K%w_ zy(^xfPrPZzZS2(x$ZJ317h&=qEisL1WmPjAm+&Uhl@x%weeBJ*avOtk;eIY-45fn* z0+~8?n{4^R`g@f#6PLOl6r=%H&J3YBn4fCx2SM1P{LWA8Y14g%%?NALCA zOD=sEzRpJP5J#C_{^f$@n-dOj8B@kePYx)AH)WF zyRo@E0?U1Ql2@j~M}U*MUV3@d#%HjhpQoH`ut9{Q(b=g?hCkJEV%jW=~^6zMdV1I~ZoDP0$_+W09aEjNEY9h)?f)ptOW)7JxY(S=yAs%1vy z8kNDeN!MM4an0`+MX)|{)09N72>VzNIB+oheZK)BkzVDnAWz&6@1njpW)?8V-p@61 zzd%oGa7g+^mzvJbgd=hF5=<}0|huCdF z`ezSWbP`5dU^qRy(#FL|?v{Wpzx!GjLqy9%bd;=&yxGKhf(NWY(aX+D8hcCBw<>z-1U*pcdy!>OII*nF!2W_@c@YlxY!H9rr zdKXFPPoUFs`sYrqJ|$}L6}GBujHR;F6qG#4h2>i)zdAa+xFiGKd=B@jD@N9?c62g* zx|7W;)3ad#-6r=E-M+IB3WA#Uz{pmimbV-EUwpGgWV8K6dCQu0I#oy>8Atg`;TfN? zEzf>wWWG@SanD+2Y-7Y}e$Y2s#zOr0gP}~)#xtST%ilJXKOS$sxCYNT`8C2#x7GH| z*Fb}sZ!y3>AXT zY|#r2Wq}Y)0jFB&k1}5oj$TgLl^`s{q$LbXR@Byy(ThRA7-HwpP5&z#Kg~kO6S1Nk zn^wu3?C*)MNL;H@;r&%t_9*85z(T|4jX)-9SsQtt>!VTaapymFqX#B^4b`7dUv!{LD}i! zk@42ac}B?$QOBO*m!m5wRFR!l(u+#gewXV|Q94$Y#Xk%``lZvh^Zh5k z(yWb8cdnLKO}T#`LKA@hOg4BX(GoDLA-)~osFl;%GaDHh0mwDrO*8_VmbM<{ zA{L6yRr2eqTAk3|Mnvv~7*8Zas?F9WxSZjIKq`aqRJedZPX#l-nt@#7pv|T&)pwQ4 zl$M4oD(UNeT&QskDr_~Ef)RJK`)ucHZ)7h&^=8z{k%g=weiv?+b8B$g;6dv{mp32` zi~NX58xOg>@>(^%$HnB#(Bn-qx%|3U)JuL<`ZhcO0N_*~FW?thvc1t_`6XBlmF*=P z3Q9of-X`&fKj8v?mfAQDOh(YI6uF{HI76UkC0{7mG>Oi#?(O*yGt7TDw&yagB1*T) zc`p*p^dqJ8eYP?_<$RfQ!x^|d$QILBLx{c?ghs)osk@nz{KdEx2(DAdnm(6GR|<}v+eib(Xi@m6@vHxTKM2S7FGSUXBlomW#$G9g_Tq8t z`THf~Q!zu_K*gN>?exqxUu%1a_vo%K%(VW)cYQ@0(_wsVWMIx+pYaKm+mpZR%k+I_ z{uKWWyu-}tX{n67j57#Eu%mRm*Z;jG5OcY4+rF?xXPkBBZguZtwwlpV2;e~LF*5o^ zVWE`~;IW61gVgiS9iaqT z{Vqz%j>zfiGq&HAEa#WBZ^LTCmXD&^&Fhwo&U|Kc{^U7ytP(Idjkbnmi^}#bveZ7= z@Qn;)o61#x>M}p=E*-!*U9fe_e^|dzPkU)dk9oF9?wfaON-EoC>xI=ns4TRROStS* z>YrQmG=&I?UJL~p3=EVg$;ZpeL3P^~N)I9$mlml+URruj)s^y{Y<@s#yuXIJVDp5` zD{Z>=tn_KLj{WxhZWu(XtSY0c+rw_PH<+X06@P63dg zPo&b1a{(?@qN(S@s}}M3Eb;iJ8n+^dx6EuNrol%t6N#{r<}cY7=8q@{J927V-opGz z4ylqz>gh*?9VWHQc?A87;;bz#!Fg9Z?20c@yU?O{5`;&nm8HFDlZL zD?ZMYvGR~CZ&;XiGOscB_doc8e^$^&TrstE#~`&!=vu)Mqo6rWDxOKu*rD$S*%m@A zi2m|}vdQYWBeRp4uyMd~(>;%&%s`$1(JAvl*gl4l^Ap(|M9JC6(WXg$;j5>OzVY1G zbGxK;ncWM8X#V<$PM@z9zepopl|3+dal0O2re00;Z%C4o=X^3BJh0|lUk6^BopS+* zVvh%odrkH-kLQc$= zpJ|p*9zl~jhycOpZ`#_fhkI)VL_c4|1kE3jM-09db|P_0@HhDG>2s{?nQpN)P+osB zwuMNe{LQ_mXeBxovGKu^rHm)=eU(69r?-ef4GHvl<}BxavG88b*g}n0&4g6&&*<3_~_U*}& zUI6xy=_*2TQZr;uIuqzBn4aJ(3IyfmEnp`D%+LBs#|B+rdJ2uJ@tAb3KM*?D-_G|e z9y_T0veJX#G3+svQS{McnEpq_q7rK(6ee9=5c0$XEaWaUya$=!<<(s zUfeQ4mCA;t+-0}@;oOIi4=qKazRI*KLIo7Kv8KvCVu7%_I->owm5|w@QcfKuJjNaR zppiQ;F=Vijr2>t)X1Q|vmqo?j#YfTOC1nH#6Z1W;+tHH&T=a4+W6d<*o%cOz=Y6}=MTwBue+ecqTAWjAlbt?Kl2o7J1|s_AMZIOwwHk2{hX0bfmD zja-@=(r1BaVPXK+3x>ZwVqLoi&@Br_jPyrVhIeO7^xz@MK2?it5Zb}_%k$K0x9{nmfoN- z;49pqZ~G>EF)L20JoWUj;`BzX$qi;a*BqG|zenPGyY!J0CM@uD%@oRY^fF26d9#DV z4tmu-cWdrL55_wF+u&E3dH#iJRQGhZZ1x$A)fN@EL8+EHs~sBiR~N`+eUgqg2v?%b z)=wNm<*#=Pwf!Y?XX8=UD=)Ikp!jck&!@6ZAKNP~TEgRA{1OOKd4mFLU;5BOcJimZ zos>MC!?)YcS1DkaP&?YYmg-cq0KIF6+AplP#Ndlh&i9INO4&1C3N|M-HSz6BzUvje z6#PyqZ*N7#E1|l~i{Um-^w#^Tb9QDKUL4!Qr@XsEAJq3D=!tNn>drgve{TB_g!%qW z^oZS?)+4-xkqCMv%iKsvvc^Z}nvHnYJNxv0DcsBvhK{?jWB`b z>i+RCUv+NRP--BKvRFfvDk>Y4y8m^b!GEi1+3hujOPL1|El^o=3B%&HFm79T6j>!C zf{N7849Pkp6-uU6PS|Dc%1t{xmCg&%?qx;zu}~!v?i`dT~kL9H;KaNo-3ByXMF9AonY3Gwh8scg=uQ38`QCNI84h1IF}7nji%A%5Hc z(w@lO1wXzF8!{vi0#5%BC7EV-VeI#pZAYbd%M!yfFEbxJCk}4TMzIQ-pgSU1R`W^0 z)v>lFZN)INklh`TI}t^1xP9sJj*kNZElbA zR7f?}(9mf-8qf)r5iS==sZf?Mbdd~rZsCI-I|)GJCRvwn79+P6z$qU)6yEVup{gT!G{(s#nb!Gm zzpq<6iu;nd=2h&DrlrAE(6Llwm_K8YmfK=Z_Gh^hBB_WDldR8^y(_AuBk4b#Rf6E` zE;TK7v&p&JfAXxvJt|a(_2!13Z?v22?qB(7*SLC%&ZMfQ8qD(X-9D=MNV2mA|FJ>U zl+!F#uyrBiT=^;Xe>`IxSnNynztMj4njVVbOp9%O6|MPnf%$e?qC(F7<)eE3TuvF6 z^K6Tg^{r;$=X8<|-w525OzK=ZLGjtq&-pX&`wRN3Oln_ItnM~0_N)k(=26D~1TnNjE3>sc|y5f~N%pQ=l&p@{p0uti+QUx5sP4dc~Bn&g{g$txq zWJ;N0x?2gqi5>VSoDK6C9O2&Bbl(4Y0N8|}24b|Y6_YSv-qr)VeMCB203MTV5>1YC z!GqQQaMjOU-Y9T@8kixF!mY&igy=&e`JJalm!UKLTX(*%*O)q-0<)RE)U^Q+{|d(OVcxO4I?xd|4ezF>FDh~ z5eg|6S$cFCe$#fkyvWJ(uz+14H=KQHEdGIm&Z63(mj`O6B?IM1_fhmkFZ%xHt*k%} z=ylHxcvRa0zZ|JBn%T&pp}cOav&`jpE}d%J3TMKIv`qx2@K9b^*3GiE+=btVei`ql z`^4_JG76c~24(!yXT|;U>~lWB%AFR*J4)`HD$zM?3+7#_tqA4m_uHNw!_pf`7ncot zd)kyf;@jvZ6}WqjlT8k#=32d6k1sTitUJEvb3 zk8M`gVyIu__t>M+na#ZPeoNmsI^K{fKG;pJFe93=>G{QBIlB`oPiIeMfY|K0rrLNt zVR0E~I8`Ct8Ar`^tl*slYK`nt8&D^EL;kr6l!zGrA^Ffqok#IhlK6tUxmLkGj}S}k zAvF8P&dI^El2?42fiW@AHJh0`5523Y_>#-w4;$uu_{xR%XoPg}NQ2RX?JhJ`&8P|U zKi*8{d9y<|@i)IT7OY})8ALd(i`cb#NW>oV^vbxug~h19bsP_Rk6UeR-xoA9$ZIag z#jJz7X4Sqk`(fIXwJy(-d38AxgQ4u{8|8NePMkqOe3`)yhGZ~0@1EFLibb2|sXY5Q zK@My9(^dR=8fvq`-EU6Ld08PXI}QE+SE(bs$OECaC-q?!UYOTFYq+a`h3{ z%!)$}*)%WGQTt|G2`YtvIUSeNjtRBK86#-n*{j}+*u0Cjm|JoXr z<5rZhstC#C&sb5>PvQP!Zt9~n#UeyPx2ZZvoYT-JZtpF*eR->RB60pH=>~B{Z4xc! zZb~sM<=Fr|t+qW-;3v2=pdpb_>`}s3PVJ}RZq;h9X17_mm**1ckSJ2jcR)6!H zTV2$a&Z2v=wZ4?YVm_KlQ5`)KOzup zCO@g_A2$h4*$u3enxLv9VlkirNQsP4GH7~qu?kss=VYw@Q_f~o9vRicE$bZ57~=Kt(6 zfSjg4B!^}wM;|hk@T$YiaaNtW6F>yLrPDZR8V|7Wy#Bn@_Zu%I{4I~T=D@2@;S&~( z8868Mv2jl6<%(p` z9KKr0Xy7IqJ+>P|hAwbs3l(jtK;Y@3UuZ_d8E7cVukNI&b|j!YW- z6c3pBSKj#A6L`tex^HatK6tLa^^G0@bQTuRDG8agX>q+6xZqUT za?o|UQXHlH?eE(^1qi)45E}#avuAIjqFnth%Q7o1`*A$~T-`g?PL~Ef*3$V_m3)E9 zWVfm!zZlSaCR;tWKRN22{VM|v{4rw|Tn75x#oUGiX4wP~vpAT(siy!JqJLu+bQuBA zWvNfH0Hnx)TH3GCw|d0%Z>@K|#)AhBCb?_dQ+_Bqeyxm zK+&7wNX`Q{07B2f3X%1u9e(GR8n2Y1|I&$+e{2^7mt&LMy|0;dd-*wd#H1L=Z-Ted zUgrYImrCf+4JieMGSKxLp~*A?q@pK~^>}clO9Pq3=|dp$h(T(5ZZ8I9caV6OKR;c9 zQ*RHOfx5)=?5&KOE&t!D+quk%8oxr_Ts`Mm)(CizF z*TsRwkGTeRJW7EOa2+;9JJ$~#{^#WYDl`eC0zd(TJPy1zfxB(LJ5&9z89*I*@JNmr z|F>E*vzy;vDhMvGufHHcMe5Zf4#)}O?I!AN2bJKW{91FV|K^7G6B|-`dzm?c93z!x zd^$-ODCsZD5p&7TtlWR}w?OxrJcNUZsRyBI;{tBK%(%tqwQiMynDe5*|DLECkli>4 z=jy;|cc**8FslI1%5yX3IAEV&?B7@~1NBqreXnwj?fHtoM;)#YRylkr7xFBxmj$;+ z8nthnC*@P#-7Obb|4ALFP*j@L^Y1_6HEt>mygtkNe{P4q=lt|^%rorG2>svZLT5(e z7Qp0mf;nCOdyn9~`)_zaqzU{PoD6w^f6oIB4n&Gook8#e{2BR@f`gIg;93Gk!{xub z`S)kN5C5I?-@(Xpy!QO>r2h{7dyd`zPWtcQzvq~$2512tpw)1=z4^1n#^nzZC&-)p z*5U7iQv6rU1hdfqq08`tt5f7V4MbFSkz%y$GINvVk7XCoCRS5u+EYVJU)$m*7#IjB@k0HwHR}hLk|f z84grA)#Yv?vGMWg>zzPDa~ZRMeD+pgCz=mVeZJmv^qQD_K38wEvA)hSNW;M3axcDw zYW`sWrm)z$M=ocr>C9OLu^}K-~hU(;Gf4 z65(}FpKKs+=|L+|l;)u3%-e7B%EER>1pp)$30A8U3+w9 zJ!frabY>VaIQIRb6GRR&nJ?}C$zT!8c^Wp5HZ+0BbO*&?U-ySWfLPZVy%=b_9K;1h z-)Ih-ne#4s_?xJGQs zTxv8RiSK6P!K@YnA+?sL zA9Cvfb|%jI_xHz`__Qy~^W30RyuwF{od@ryrx(JC#^fgCLOH^;89yv*3Tnz9WW|9G zCd%o1dltZ@GNt3BV+FgF0A-f+Jv^vGQlu7NX&+CpF1;fe@inA<#bw@WplOI?Fvq*& zp5o>}+XJNe0t%LN$7gs=`%B2jeLRimn#TC#P-O1dyftdL>$N|t zHG+O?pM|mK^K0=^^EP+dwI)d!ypomjVHi^w7Bva)8zr1T3K|S(G77tqO$)~miP$8d z*X;nFY%u_UO#$nrTBefD3?we0Z2Q4LzjHG0SUV~8yPd`1;QWh*Y#BN(d8t0Fcq}kZ zn&ACp3vfuI*gL!h1Xm<$5UIhrX2e-=o^Ol>bU=mL_9GksTGMF?HOdgHwvhUSE9Z-W zN6lJr^7D^ z=cn!f`MzB#F%Z@z0nDRvbaXTUWD<`5S*gUXfkp5ld`JHkb%)jOT^fX--+gc z4Vx94(g(1>kew{M{N!=Irn)AxlIOGZovnb{r(32%_N2k~F&Ci`cY%$mlFY2wC!weF z;M$I{On{m|lV%)>mHxRh*T|xU@g4ZVBzr-C0?_7Z%bU0CYPABw>jp-jCONZge+*fw z=6*B@`v%%uEiEl~K5o6&tYWOpFaW$r=SZ^gtz$sQ0wJIuFp3Ut2DUv-zIrIv@)_qpz~OmanP8~;X`PEazBk~fst+`uAZ50$|}#s)R3T0h=nwv zivxEA-d-mS-ClNIcq_W`KQjf%@l$&KEMzaT$?9CRh|S>ZR;GguoQlaYo!xkDVn|0t1Mg zl@8tx$Y|LQG=0Vn~yB z2&alH3n{c~O{Gz|5Q#ZbC34{bkbB@sDgc{QWdpl|7FO_%eu3Jzka#>tOr!WAmO&=a zL)nh!W>GWs`lV5&6EKJc{O&ksrMtnW5mn=2nO^@5q?*?4bkc151iKUd3UjVJlnRKI zFao5kETdi1mp4yYpE^MYtgSg4adIa>XgMlIUiV3e;R5|4Qaw*{I6w&&zw-MXZM+C0 z`pumn+8fMwFF6tf_c3m<&3i=QU44;8xTG9OKQbX~uk=HH;&o9H5;86@;{}j0AF!!r zlmZ!UDJs!aFYpZS-Vk?^oAnOfiZj?_#;teX)Y11@&9Wc$yY*=>f3U`ez2eepwAacq zwMZjau#|=dy$8%)A!^8Nv?nD>cL1wrV@F3vl}W36z0G1%|Nl$+=URo!o0)SOZi^Q zv;?)JcWXFRnG5@<-yQK_Qz1fL7wh_E!vpllRDmY+eMn^zegfsHv4MdB3VC=4%KeA2 z^dws(c(mf0Batd`ClJBsiS3#W5R>5Rg=K5m%3O$`?lf>n^N~2wXv3nHR=+|HX%=iY z?+HN9i-Shh(sR;iLb1G{USm(_0$%l0bf23KLFK{|%TYxjQegLlc@aC#D4B6U2Cr43 zzF_bi@;m>6H;F#6j<9SP?}aZZvE97Q9DXFT=2;)4U% z!!AG@Nh#|6@$F5lhOF0*v`^^hh#oVv8-WiYJ2C;XXR(uLrJ=RK0I*To5VJ|?{IRv$ z9acBNlo%D2- zgc>CfP3_?gl~d%$#D^4yA>yJZANAG0gtV-BlVkJ&CgbMdqyQp1>j>K`0(_&7UJCkfZC}OrWhP45-92V^${$S|3 zGJ5dYoLzriPc-BziGWKAj-JdvMSVTYr5K?P+_Jly`iu%s8}u<5Y!3tvd?3u8^h8Af zu86UWehV>!SBeBMNIpN%JMcNM34y}=AY!v()PW&tfZ50v48b@#qUgR8jK*;1 z7X3C-ots${peIvm03ZEm1y|DjK<7Z!5HmR8=V~k-BkD(#kQIYu3Irvys@IRN$$x;M zCXmtySGVNJrxq^GM{=D}vaaA0Y7il`*VK#Avz>xzF!G$pEs)sX(^pldru)(w2;$ zen%Qp4&P5}XOfXl;q}fvJO&w)sUFSM!!)W~NC+0_LaRx>Pl1oSfOdD)d-Jk}0N9 zLmW^x%Z3obN-?R?GVdr~8~?>jRi_=AA^nt*$(kEiF}b^qalt zO?Q|VbCrCxI1hv8v4}olRa*|>rZb3%_%Txs!aaIAHR}Yki26(cd{$+QWVG(PE92}m z2d1CnP8}mboDR!T*a+4(a#j>)=fyD)1TDQE6-yOO#l1Js*~{guk#$`_{}(R`4;J`j zeIa{++sBXQDe(1#gHh-%2%f>?Oal%d96q*BoDhWhZA6dXWxhK!D$tD!KL*IJ+-!FE z1&Zzv$tLXPRqY){1yAg5&`k$E3X!A?>Z@an1sKpCEAU~{HXFRyrf%lBkWet zf(jqfL8psx!gtV!=4Tqhtjj|whmp>Zcv5oM-#O9N<8I|%tYx&Ct4S1pPS``VG8x>i z1W6|$VhwjA8bc*qHH>?-24RLpXo^0Mb@PA~L7s*PgT6r6Np>gCTt5v+jBxcGu86NR zdW0d}X0j+~aJp$&fxKb@6Sv3#Wzfst;awsDjEFmkVw-TRn>%|em?-!MbvxEQGux=ST>>oQX|h`Q4rlIfg+apWOUe69kei7+A{>3mMCS>@t zD>7C0YlU?7Ri&~+w)){YmDWv;g1mV zZrmpMMof~F$xC5)<2z{tEqpi5Hz|Z1G(|d)e_@GxgX`g$gh7bG5D{o0j(>^jx8gBL z6u}T?77-kQa$f*}G1T|Pj)TAW5r=F(W8c{T%8u1wS4u8897bK`as(C%YPOv zvZIUk8Uh)_~ zZ`*7{Z*tHmj1An>G04#un{c2d;H_`$Jg}gfvY-)1p=CJ=Sq+B0N}A?Cn<_b4L&*M` zb-K3!B-Vf9Db|jE(9OHeAog#ol{*D4*r5_V=iQN(bdv7yV3obTBz!Ffp2*({)M)e* z@_l}#!o}xG)TNPKLQ{;go@*;@D+>p*2>Jx7ps%xh)hN!cCJ)if5`VtKI|{du7{Oh} zE{GQ@3^<5}2S*YN64GU$;An~FvTww$i;m-+Wlm=nn}|1yW6pl=8@UwklS$@p`V4kU zBjNQVQBXx2f!->C2`OhjG{$%Iqj5CXZg3&v(a(m70j_xO7Tu*(75x^)XPd#E#f(K` zBW%;5S&`bxzb{VYf?k@IQO%gN_4=#N(dbLye2&0}Z7Q zjhpo+76htRY1n&*6bqC=_zv-Z%v(3T;Y@f(_m1GT%1;oYXR(2d3@jQ}# z{{;u= zgebLq*dikc`~kCF_ZS#~1rW=%)c`Aci94t#7qW^o7^{PlQG49^-|*Gn;2H>++0#ra zaS_^R;=rDr*&>5?pdHRVe<846M>^fYms(kZQfA0sKl$rxErcLBHQN7TbANvh?hY_x zvO7ROD6VD@=^e{g3msD0v}IJV8E9j4pLYd)0rQQ=Ds`u)rLp@u{|*>f9QXTuHXA<= zKO5a$mtTnb8#WnB#q z_I{#HPjMD2s=ch3Ccmk~*sJ>HKwtB|XpHH2s?a_GTr6Sj{ZFh%Yhs;D@a8(1-6QMB zo=0mhUWSEjI#mhKawp7*2RSa-ytlY*c3u)PbNd_oZAO^{M74f9qc{F@P8l+!@uJWh zT3g;gXJJ}~1ay}s(S*ZC*P12J7opb9AQ67|1lcnP#H{mQ1)hF0K(?Iu0&pdYNcIn+ zpXX*cm3>rf_<&A8oU*i`37iGE9(hfm14c@7b8~>j6%xM#dcS1+`rg64&D@qqcHP~P zMFZw;#7QpvLxiMl^)H)9DjCunw4jYhhM=U{-gSR#WKgMgq(rN+$qC%n9_U!K!BT&L zu?50Jr5)XMG(QJZ4~pJ_bjjSmaBQn4*wm-sEY{noUTRtZGWm{6@O4hWmh&Zo-AO8Z6+{~nPld-%ho!K*Eh3%NF<{? z1@4P;vNjm8&Tvl0-T5k!Uu#|UTU1n)+TXm0^b-iv1)eAC;R8R_t|R37bcB8-dvnp3 zJ1-cX=(;|N$QX<{Lo=&)`C6I$iB~%dq$gv;TY3Jr6B!VF3mO5O?A3N<>qsWJhLuOi zQsz>?Cxfnkig~CY$s(aUlL}Jnzkc0>$cA+5(Ul`-&+A z99XN`i#2Qu)6pDiMHQWLHNgs39+V?Mzw%x_>^vNx5h5ZvkNc4iz9k%3B88+@et!PR zS}gzfZ3BQ_Om^J*V66*)N~zwx;`&g61WG}LEt0BVKmbee3^cfL*RlFecw^auM1p-X3w|b*!ga{-S9^Ydu4C#X(I|MFC;82y#Igy5 z{xh+*F$FX$)aj}wCO|^L=4D}tl-&K((2=sPCQ9gzDAMm4}>+BX>xpkh`8v@KzNJ^3h73wPGk>LBYR~BMo3T z$Kq$vrL5-*fxz)V;$rWsCeUqGC@3aY&_fhKko+FMp0wJ<8P!1F;SxQxPNL-Rb{0K z6opv)mp?%=51C+T{|cHeS=$LQE8-|NVQYq0d_LHSYLNd6d>dkg!3Fr0EI8ItN$=TD z$kBiH5#<2+DVEj4BM?;{ahyKw4{O5J$YatDH#m_>6$X_{oP=4OOA+U362^Z%tDwXn zqvsudr|w+H-;toeCkj{DV*Std1^M;nUW2R&_yjWU>#gx5BU7{uXn-@ z|1kB=ae4p$A8)Q@+ge<<>q^VEy==SHRjpdKy@lmkTD5H3wypEt=leV7cK)uetM~hb zr|*xUBE>9TiuEzKl1-q+Ljc_v!N?sR`H=WW#Wa9nhK*u{iF_WO1eVXmk{k-zWK7xE zIL|whp^o48j>-LOCH=R@>1QB>_aLh#EkT?C&fbRx2l10WWcI-$go=IPm_5(t~97k%hbsnw`XGb<7Y( zG4f@C4?ZBVhK0*fN$^n#c16q+ZV>tb=Vh-rmg3+0g8+IT&I!6R4>goDA7t!+|K9@P z^kTE;zn>R;q#GWXczY!z7@~)a0`z0}-RS;9G8VT0*dYch2ukSi3+r|UE4SMtjbKW* zPb4-;62KQ8?vr(t9IExH^?NWM++2nNuOw?~un@k?JMlNC%LDq;-5K-s)XKlos2TKR zE;bpFq5{!u{z94YHM@r5VT7aO^nl50xlZC!z2msmUU7?F3n7^MtW)8-XKW(t`!pxZ z&9~L1mkU?UZg(tyU0*sbq8+)j4?Gnmb?}))-MhQ486*31xUVnAxxf$};;5vd%~d1O z_O|P`?S!8->n^==%#%Fg_g&!Qgwy5}k?$Y2%3R$}wt6~2|I+Ss=-F*ARLqgWR`0@` zC&UW(LNZ+ZVWTF4+~u4e%UGjYaKf=Mei}$O*SR(teLDDHX1wO@l4&Y^hdbC%`n% z)&>2koP)<x7C4#C0N3x6mqCn&bsRn2KcG=wc|R3Y@R46 zIhozVsEh92{Y)bAF@2}Bowuo@R!?HUhGRzrQP69ZPo}N91nRV&pY#?ve0ojjV_KF| zYg77_%m0NdTL^!M|J$q4_GRngRHe~~6~7r_#|wgcM$P!ya%1Grv8iU7Wt)kSSI?v!N9MSsp2vgaPc z1NR27GaenA#oPD%DehAno)ww5ak~Oo%vWI#@pC#&h|p2o1;p!KVv)3UmLqigQ0$jO zHWFvXXT_*MjSwmePe=fph`)06mYL@4O?cL{X{9zpxE(j;@1sM9n61ZJvGo{PXJxT( z8ii0}&)vYs!ey&upxR#wEaz0=M@Z0xLOfBma(8U&{o_O!i4h0#!dg%d4=@vN-)_$} zy3NG?$X>bU^4GD|%>6Zy73d`3+=rkHge!O!4aVK6kNw++7JPY(>-;DuYLqE|{ zBdi!ID#*3vvM)IN5>jNyNqnnuC4Uo3sLZ@YSN*3L{FiA9>V^~EZ$nx(AXt%408xdx zaN@5kb{EdDE@`-=eVZbZrbC`#=>EmJd6cDBU@6#T; z+s0~}KXh7)9a#IJh6xRBi}sU#pPUsrkU+u6ASdLxyZjXA=xw;4)ll(AY{0gGO$$c& zw`Is&hXQ@*MVX8L#OUK>=cfwa-4zW@>n0R-&WcI}hB<%Ibq0}+681bH)AYGM6SK@M42%sHUboIe{}Y2!G5?;E=) z<#Vv@W$9iwA$pkS5;)(*2tL4k?>*-_fxjdY9{UHUN+SVr;P$5bQ}|Wd2=si%i-mve zzZQCEO}kyd_Fp-uCCwzn~{G_GVWn4^x2!*>sf6O{4 za4me0&mbCF_NPg%RBu_jxnitV+cZrs)5#33AF$=9+(C>*m?_ZEYwiwI*tUeZI8Wv@ zP58J!8GgihB+uck$MbG5Q)WT-L#J8@g``uT!ID#eC~l)Y!Q<`fRyExD1pk-gP0A71 zp%jVFw?Y6kRSW$PD=5*KH*a|} z-;)cz{8(JaXAuzgAi5JQ^Wrv-51_tUA530Q!2)6iK1#Uj^NqtYi#+_1!9SHQg(a-J zD~##qF>_DVo1C>S3maOfz9Gzc!Igk-wTco~0{@C%R%@?UoQfiT=B0YRjpmc`4<9L9 z9G+9%)_g6(I^C)f^W!>M^Y)$qeEI1tDDd6`$v*Thv|`(Zat{h|mQzX;KubtQP#!Vq3==X!EG7gMPW>NfJZ*vv`*;%{1fT!GyocEaHh z(Y@Sk{6oEdj-5Yw5ox`?6#S(SEVTs9vlo7oLjok5vU_5Ja@2Q{@8vk7`9I z!uTD)A{s)JOnfXZx9zl9(fcizxc}AJ;zUkC_r8#4{qeh@-yv6Uy~ZXy!I_WpCT23V zzCmcVcqBU)HGe~juY^Q@>omu>uu?JOT=MoIn=$Pph=g0iJo?&n)al=S1`Vo&1bx`2 z5AoIg-<@Wr+`B#b6BAflv{Yu%87$E-#p2}WG1kKAxg3FS;G)ce1+Zr3Tu(&7_F~Dg z^=0OoinKMu^XC_%5>uf|?*<}|ly!qA#bg=fVN{93%C<+|PRyBtSK|zRk%X03kwVqo zV^(r^6S4=CeGJ)mpCy=d-sw=Mqu z;W!BbwO-*dh`HmJ>DCj;SFId!aaG0fX`S2yyK!6)-J7fsC<<4e{A#5NjZ(=HTVMIY1}?#XSmFqP)oIcY`tEgWCf!) z$>A4Im71y9W_S*80r%SXTPI2|=u}9q{_syOlu$9_Oukb%-^uw`JfmV%Ze7o0)tE>ZJ} zh```;F{`s_Qyc;>)JeB_-XX1Dyj=VCrrhf#VW;b=rE7@?r}v16PF~Las2C2SGd_2z zX;49Yq7Z~`FaU*a)cY?W7`UG=Rr!Qb!a%<62OAv;8BoBpx^pT$`F!$|O&?CgD4L?& z@)Z@key>7QZ)#s;n3aG(;vZ@*Ip1jD2F^|SImWG+c}4h)SFwJ;HxnWgi7eAxzRJy^Q#@aBN`>S_$UL4TTrg#e^xkkXv5CnI zpiHdw)yYBNtQdfaEvo~dy7GtH%ve2Rf?B zdnPvFgAmX|YwyWgz-wO5{u#enKff^K?D#W}FRNF5mN@)wDpQTvvp792C>PrQd-Cd! z5VACl&Bj6T-nN`b8qkGX$_I-+dMBjlmGQx>5sib7K!1_YnMMLuzbgo{^^wug(~!wy zSc^T9<=pqHxgWnt|NTtkGJaKAC_w-|*lQWMnac@7KlEk3ZYxa#AUt->*s?hK4A*F> z^k8c`=*i#XfyrRSOLi`s$aB;vs|^AX0g!lSS=Qx>-c>9=!>*X*xAi9!tvO#&?511B zMaO$c4)Zkol@U>AbSH?@X1jj?EdV9-j-!B-go{2iemQF} zXg}8%9Q$_>0+XuMz^GBvQ(M^df9E6o(h)cv+vZ5bH{F0Y|#O8KWOElVPt1d2>Yv zFoqpak{0s3V8<)>+;{QF*1!@A$lMQQ*eq%Wq=nKJXA8EN?WyN#gu;VrQxAQ_rBGz4 z8=t2Wf$JdOBvJFTlFbgbqi*N92zJxyrAYiJ7wYx1aGu|)9|lZuc{e*%43nK z&yr2x`wcq z+@q|K>X7VS*VZjw^J<1iM6p@=?N9v#NB%QoXrTQa8(!Pz0UIO>=_p&|erx-7A;LY6 z+w^yUka84XivqqIeziF5sN^+`?>2xON8RRS#9U{7e*WwJvIIxww|waL0`wH3ifEppQ3{4 zx{*O*m}1@fqd(krGlLO+IpqXfR{@&Wl%Pplt`j;@?Cxx7rk{pQ*(3DgRxZrX8g7aP z3!Zz7+DT79={kw!y}({tLD+3tEMbiO4mgxMr~uylm;KW9C2P<28{~)mS_p9LXZ4;X zVV5QG2cJR`7z8~%$2i=0x8UG|W+2JJlxr52>H}NJ_b$`(p-U3Pdvn_MSPohwgyOl) za8J#ztxY}}o;Evh2Nka*$@(Cq7c_-BG0^Vdd{T+09YfhB+GNz-NT=*ke?-Jd^dw-u ze*Y$=bf0Jln{)B`=|$_ue7yk9Q$1%(?KT8N%;qfw0VsdV(~mK^!Rzp-E$%nac*={G z_(M)*K46+YmIRTO(iD?CeD%mxhYe2HOhwS!g#{rIXuwxBOeyqZ<%Jt@JmTq227Gtu z^ce*5S`T4TOWsu8(sU?F515Y2eRI{q27h2}V|Ql8j&rC<20TURsD-g7 zli6{O{{b3n&vgcD#$~es7N{hdzpkbqariwXds^Xi%eXngP{@0zlJDj2{DPK8A#j|{ z2P(`Tm$t?sG{8SPGxqHXxmHst7kQdtPb`gAMGHVB<>>q6Nb?pWqaZhgX#xm!2=XMn zcz0l58HlQ{(i-eW>US9^TAWVP*P= zZ$@K6Kay($J~)a5F&qfdL3$fSeLeI?o#9cKcO1QFg|2u+?n|=yAj+YZ+@TfMYQ;iw6m@` zZ2@@g>W3;GQ1K?tdS_Fnxu;uk$Yh9QTs98P4_cgQL8;Dl;$>2+9-L0!h&(r*WV3r7S1C+bRwLE|KtLqLk9L z*}Eyj^$5*}naX$MD~rG*B3JR3y@7IJ?B*w?o@JlQa|jl}-Ej6Bk~4+?;cx!CL6dS8 zeBp~55|AK93AFD@H_r6O$YVqjmNL0_f`TJ1$l!%}@~aq=egL&t@}&F?EVJCZFREu+9pOJ|GN6|dnVk3%kfyb+y)aq>1W5elz??IHUYATxM#APW({X~yJ_ zhWs{k!Sg^inSG@nX!S`)!q(%J+6LVKFA(gvpnAQjNP0v`X!js7DCbA}%y(xI9>`tU zNpm~1kcyFwz7CN)bZnw+Rhczl6#I3$TSTuadig0P2@L{<>`Rna+7QEXQI88E*lC1! zzxT6N>6V3UI{)j5#80-Qlpo5{1s1$y>9_gTrRb79YK!0(08p3XgXWwM=$a}P1g6{lb2oW$XGkNq>$vP zui|$@CDtBMY=N0yDD%BKy#+iBu?Z9@01|=9C%m&u--ay~4XeL@*<|&IMCAMMkOCB_ zht_+jgLVX^P;7g}2KFAg2F)%M5SmBB9WplCiE!WaVhD$v(Iy+S=;pWM@Op7!c#yTp zsAUbGJAE`53SpP3x#pVcuiZ|59?G&>x=xs5T{Hlae!PoNpjDt>`QPPg$2B7jnut!W z5=-2Z{4CUvGKRtY{WZrC@W0jJDE$P+mW0nZ?#K?*fBQI=zL2HnnHK+u_ni>dt$Vh* z%)jvWmu4zpLk3+q#v>~>3qK$kVb@%S0}U!+iXK0d%*d<`bHAewV|ly{43DX%f4(i} zZ;Ho2F1;v!mHJfB!pNOg#WnOd(5!J~813-e%P&+52pFOO+&7gM>eTsJ-^?I`%qlbM z)wT7%`(^?UI#eQRIXr6Dej9B?Z1wiBXrMJ@(D6f1Okx6Tc&6=3mNFq~_fVz2*eUCWtYtT@2 za?iJDB8>pDTr_NLKM9s?rf5`@0P&B2GVbFEI=Z+wgz~{Hrr%AAApoD?Jq({1(3NO1@CO8^|#IW8l@wgw%QOb-P5Lm=;1Ia zeyoe-^$B_9J}B>)JWK%G&oYC!OZC+2&~p;5QKg|^UuX!f0FSbr#WHu1FAKGC4;BxP z#c!=NLWbWWV=e0LLT$RUSQpV=DmM2wgLdyX18oh0C^n+JY&UDN9SXZAieg>bN!(s{&)e>_rmGH8`wAzXT(Km3y< zFvk@1!MAUXNxbJC)yp>?;Y0g`><%a4Il?WNBhYIuED&iOpI56}O4 zi%A!e(7PhwJ{qQkW4N!R`yC|0Q6j^4MI@OA!u-03Q0I-u#IKBxh0wQ1+@RR{>8*CA)9+@BJo;-vMi%!(+JrD}%SsFYWRp^<;n3%PFoy z|1sa28wlOKCcPIb{^W(+5tdAhf$WI}_V#?i^H$x{#x?_@BY|#>AOs_+UmgxjquxOX z5*%etU(&g=h5>akZ8*vmu;!hKwlP%Jg*Tx9V;T@@N}REU$x0gHxXT;*PJHhw)M=$) z|1c%+gzqgyyoac&zDM94Mp85@5SSD-5%G_0-(M53o}1+IJ$o=LSs!C~x@~uG2hTTy z(Eue~8~@jnAOJUy)O9D7eGjRkL9F|{$if-Yg^J_%R4`lbFdhiw z!wOmBH(P?qz2+0fyIS4$!thwPUGL#AMcCK*e@W24|5q;xkIiB6M9@L4f1CP0wl8?% zC09>{@!zJ4N9MjOPPWATFLQu-D zaVbXW?)@*D1Oz&&HP~(-Dl7!tG*npIU4Hf_t6;%8xOh265%FpG%A$nyuHny=lzqUiFq;wU@*Q;oS z-(X&*OAwGIK5vwY_+%1HITK9r$!B^V$Z4}P=(iziJsA&@ZG@UJs@DFv8QQ3)9gQoQ zpq9_+96%j;daU7WyVd)DjUE2~j2%!Y1WS>N=Maz+_K9+S4MsTUtOX1=k||m3Okf=7 z2ivzD64!*+DU>t1->S*Fm5X>9CHpynlKd$(%N>c<{DyERq#^WBTGVlA%wB9xSr4my zmGGajKLpf0|5=2>@|uO@kGyDF=o5GfqgU-YgkZ!P5uLFuL-Krbh$fT5Mw~`reVMDB zMtCZ+mZYZ|;{Wm^7wEy@hafw|@>LD*LQE}_H_^Z2J0mLN1Q^IpSW>p=tGXL4IzuQT zQ7(6+zeShO>j@4o^R1UJrcCrRr zG`dUiGvUa;I(>!K1{M1$05oSgN!^`MQ6ML%URHF2?N^{Q)X3O_sv>>%nm1J#qFFVW zz^H-qL~evpu70g$mtevT4GzI%nHI0MhXNgleRcc=L3wX}1&8jhn*wCYqZphGA=08| zGt?^KvI+782jj%5y<78<<)DkTaEIDSl2Am$TnI$NF4(&V2G&K+h*>)2s7fDHXLaq( z%TU7CT}U-u&CHz*D+q)Y7i?qy3;RBRE!og7|AqK!!_AQ7=oI6a+Xh3+1*8(?L@;RH*yTsZ#ShnGf_Kk5+@w#wK09^UM!**?kB;P!ce9YN! z%Ab-GxQTPLRlhFUq-;eZ`T*bhGpZfez`k_uFx2ZG3pf>_SlY{xOg~-7z;M zTWa!|EQ1nch|zzriY8nPxSX=RQ43xXx}K!Wu3&V4wpYN13_5VeW>ixv5W~=~!N^$s zwQjDGd0rd(TU=oY*#`!FHQ5@6QU_}ydKpFn?*qwMh_K^X3)NE1y>@#X;(1c6Abp>ttghAsG_bG?G6{n10X3V*d+ zvL6EzxA*T3kg1!z+F-GfT%hnd&VE5wG8+&$3cJxlPkf_^gpw~Td6Ek&)m^<)t;uxB z`L(!W`$6tE#V_?zS+w#K3Vjtw+v?8896`#c)6(C4!!TlkF5lR^?0AgpGk1bFYfJ1S zlm~0wi+H~JLSj(b?)Oez7&PYq0a*WlCVBV2b2N~M)T0yzLp4Xzdne4$8X`z&K?eKFZ`!eJL5gj|ABzZFaILPsk42(ao2c^K7= zi~Slb+54Sq#;u?T-cABHpf8iZ@+@E_WjyX9Djl(!l0VNpziqTk#Z`= zTLwQM+LRpX+wOZaHK0_HoqsG4oW#(uJn;N|BM^lwicWw(Bebu|7~-b_;7%c-drOk0 z%1))Fkz{#kyzqxVulYSoL}EfV(~ifb2xiov_143JvVq!?0%NoNpLUbeHf_SCMjcc% zkfGl+^#vpm4s|994+(84$}UoL3mpWB?I9D&GIvS1J> zM=ZLVQK>o%+0&N${jY6SXOq4h8XTv-+gc*oqlSeq=v+DeRbD@L(EikIYt2XHs%{-e35W<=rp=--p z*_0hK_QX2%5K&Jeb@D#EeB8h1Fum|yW3$2s-x0SVW(Ko@6G8ADhydd zZkuF5{18D%ZmNeTlh4psFc6U8v}4gpH=9Hbvd`MpHr3C5*7I*1G#RjDE4Bk6>SU** z9q5IS7d=6EJCNET?-69Af@2clx2pbh`;p$ChoSCDT7@K+xrd+AwxK4C9qLP+w_{4v zC(6u|DOL{g#@jJE%Hrvbq9yHKq8v%vW{zwlnXn(4v9%5$bgvg3YgYvJIAO0OLHuT! z#vLlkY!~ldv@0RGDft0~DU%tio!63WOv^AAy&(0j1vnwv6SBQw9AU53+Zp#?#tz>t z+LK+ZzHaJoWkGSkF2LWh6@Ae0rweNap$GNlaWM+s^hU4q) zuhZ#SPoXT*3KFmonMEY*HH&+-+dMyy_w*p=yZ&Fv!rkJ3N4ReG+cV+Y&{aq3??HDH z3{T+;eD3hv?Z$WWJ74|OW>nodjTyDUh$h#4IO55e43ytaJ#p1Y#_Uxpo#&5VX84AG zu9lhoQ54xMy^5&dMF!1?`LO2Ot8KI7GMoh8H+Nu?p{1FW zA`9<|bttUu7m6P+>t>mbS5+r+&eJ}RVZSUd4NobU0Um7alO1VGXTZ0+ct{Ro8^N`+ zZ&knG@rJz_`#;1P{cgnRE~$IfE1yYx6!P~-q?9wOwyl>6xa5bL zELwGAeoq}cwks8nO^y(Qr?!`#lY-3Ibd>e@r zBhbBis;TQ~wrh^8NEW&dca3z-&>_G2#vUB}U-J$6*(sZ^h&M1XL7I;Yt-=mHIpY6c zTb><<8Z`X-R}5bu^2{|2#ZOH7}!#`pF3dy}O@{h-T4*2wj ze{A9(&Qg+$(QhtrygQ!Nb)GM=z@8||Ik=6mDQoawDKy`{=S3SHp0dM6b2~^uJwxYu z7b{{wp)+2t+$%j}CAtes?AE-^WvF`=^YF3An#Sj?SBu^#AM=WCYou6~bKMMy%Wg6E zb5bRhiL~Hd%*Uh@xRXMgD%&#c!;Bj!$IuoUKyIgW@KenWhEE!h0AVz10#ri^#eoR3 z9PxHBvcinLlOf;4Z}_Knx~^)tDh{Et(I-5HmP1NCk?;}xF`6-(z2E~*0l4FTgOBb2yE-=#sjCfv_freRxS?R_Gvkeyv1sg@9SDCV!+$564|~ zZ*+qKu1%!E>rgkT+Sr+pwa*|3UV{BL^S&f5S~H1#gPvhTqHD*P$~0jXu5%y9pC2jc zz$vBCe$2s3=ir-$Gzl3PYQ(Hj$3$2k@Zyy>udWu{Wu_FC{zyhGK7|G$PzoQUW7F$o z`wDbncWzr9E+P}0OTJ3C%07)s1A-u8xz(z0e5^L1l7V%Tmex4 zKl@J;z7ch||1;-<|7XtYdM9*{WSW`YPH%bEOKIMo>1@bq)}NC@u&UM2pNAzo zT3Zqq#VNVj^g2tG?|!x@7( zu?)7}it5o>YUFD*aR3xY;S_RWw}w`%9xx+eX1~nD@Uzdz8`Y;!>_({{1e*ySB$_7M zTzDxFtaW>oWpSj?syve>mKoefLUEjTyt6fIugB|qODiBHq}~Dsb|}sssQJpC2V6vk z1ulCKDP=#8<#p5iWT#Fp=?(C;u9XwSVNsJSP{WTRV(VSUCW#VYNlalL&ej0tm~u-1 zfK%KFB)DynCtrfc-h4oLHDXO2c3^;si7a?Nj?ap^sR3}7*`4U%gSA$j$X4$UhQTY3i*1^G zIaY$?f)5LA_q&tOd&Xn^twQC;O%yWO^jRafb`W?ScVdwK_r4-YtKNGPi{s2<%6z$R zw*v22!M5}kNhc-CCF3-2`5JqR-qjWPf<2H7E*@X*{W0@kCKn^=DZmf1es7y}xCn`G z!x-Qcta@XWk`?yt=am=pX8}fJ^`q0m)b+vCb@L4aiZzU_^%Za5mf))90RAgO>D0oJ z-qcl@_2)OG-3}gUj%aOk1L2rG7I_+jU1yv9vxvRn_cqhXo>u9$h7p)<6*ZSGT>2*J zLHg=QBbPR5KSBzo2$s|!7IcV-C0RwCnA$YtvpQeAo_jMUk*j+6 zWR{ULd?oieN;F{atlu-J(1sUNMs?6nCn@>q_}l0W1!lc5<9_N3LWw3q!54(==fgK? zqhd@R5%992?G<+jR&J`EVXA?9#KmUk`?=8s7mpAf6kV|skpY^RTfhG}c_6%fP?^p{ zm(LT~`;)+K8Ns`%Fe?z+d)SIPKnhUdkhE&voHi6g?}J$t)ha6yr3hT3s^ z_K~K{+|aO1URvbs@N2!b4M9uOOY(}|I|K0$vEI%RrZ6{*QiAQZ(f7N+!yeGb*QWP3 zmK(T-G}uWVil{e67z|gl*j>6^J^str3;M%-T-0H2%S?E4`-w4EhLSLH-Gq6%cg5gr zCw66H`@)JVI3H){PeN}am>I7psy-)&Xw;_OS6`FSDYNUT2POBZqBdpB-;NKHz8gt% zDMin7cX>f2JfsmDd_&6V&^dtFieZMh`jHcrK&%>ASVXn@VWdv%!&nh~aw>((Z^ z#9)-pd@*a=;F+BhSI5@7SN>`XjiYnF<(o-7&UUX-a`=}#wnuD<>v%$(`!7!FFo$Oh ze~#^0%nz(XCw^KBxcyy^2QOe4pfh}Ew}-xN6oY;1&D&si$3=I?Q{=(dji@H%eKZmi zG&E*jgGiZUq{KmV2{Wr`w2Z8l_3-bCY~7)k|_X`9+C<)g)TrjcW!+ zAb~%EqQ7l=&RFM$%kSN)kAET<57*LJ2GJ@EeRfaF>XK`Rjx690eb?VVyxtw&AGT}Y z4$&<}{JGNb4ij5~Qeuw~zSl*xI3~1+h+4K{>+^3u?%F>XIzGoa+YP9aXxH6GM@##6 z94<}#LJnTD!ksJaMWTT;pC7x$aL~S}R5u-49&G*XkOxO^NC3 znW>&H)C;hnZ- z{8}D2pR*kTHd5HZiebvREc`R4ST5Omcot!l7tNAMDX9q&h^e7D$L7E4%!U$^k`aJcFnh_kb$=0FUjxh~b6yRaze-*o7SY zFsoL{-4`v0fyJzTtfN5E;C)htDVJS^xe^pgb*%5cqf;rZRq`1I*>$1$XX1;V{B&8W4pgKle z?$}-&yrM{QWf#tYrtIJAcXRLXT@7K^SL&8R^Wt4%+X!1TL}u}j30Lff8U|yWrP3P@ zDNVvP2L5nyURC;XzH)kB)wD}J$WE}(~{!cUw?Unr^2vE&&899z#7Z_7yR)8H@;WAG8o_L0di4hl%Ye6{`!ytJWxT21-ah z>dH*Jjh|l}3B86x&6vAAv!r=ZRtEjF_VpDV3Pst1#S>kV(ZevJZMOCSBY8xmSI6X1GK}=GP*^@s%3FZh3||CRxQ~mLaej zmmxrhBHLGbBVEPaCmJJjZ(Eo~v$9RCcSQVW=`Ilq?0QVn!*uPx;k7u0VGb)Ep{{;T zu4;%kD^X-$es!xqBINz+Sl0A~&|l*k`o}d>;=B6z7)N6c#jj}`9gpcjASQKRFW|p- zA&#KR{9V5CJDp#`bE$}&N!z16E2t5 zh}7DrdaMt~^XX_)pVow;Q*|Wh{O|&36nJRwXmzg2yiSWsF_|F#43djY@wpuN_=>nFf2d^Je$0 zFHR@pu9C@B(;&R2vH;_m{jDq3so+SJeF@6xn*&19s~hhHD@iaNJu8oD#-%5=N9m;1 ztjNcRRpkAH?=HHxmmdY}l+Z)qDZdm7K=z|bq&2FV^$PN?RH92|(k))#cxUaBMv_9bXc?0sU1rX&|fY@7)`11 zyju+RHqA6rLpkbH_H)T|8pY688475o7b5bwji!3e%s+F!V00aR`hvD55y3!KYzor> ztkWc^lV!%H+M$u&h42t3a$XKIrW3ciVCdcNV#gN(IC$4wxr*if-WoM2^vlS-%`luT z`1g}P<_HB$2$XA|XTO*U8Uyq0WChG?H0Z26CR2{MV&;v$hu0T@-s@s=bxr>H1?i9H%L=N<{4|Qm6+TIoU)o&e{pZKY zQ@6y~vm|yg&qkQyQJlK(ah^{6yD58JQT^pyPPr?O8~hwG?U+Fx5hW@gKt!p#204Dy zU7I(|YiQ(q4%Drylu))2j%+lCfpObpA+Dv<{AwJhSDl|uFRRn)E+U=m$oR5-vhHs& zzgfzW3;Njgh0Hg)>Y-b3PdxL*T>UZlMCtkUp$m-wux(UeOHp)-qen=AQ|>-!j~Vwp zXfWd5WY^HTYQ6>8Uzv}zfsn5qPyOZ%i(>rQPGK?~XA`4N9bX?>4P#@fqH{666qiJO z3?m+5=5i{swNBi}Dr!oAjeV1RD8+W#%ZqUs%!Gxv4aOoZ^Gf*Eh}fbqr_Gky__&sz zGDHBb_;%$V_l{#jCHu}toaLC`rK2K^Fo`KpGwM)?%p2eqXxinooJW4eC{}cG%&aN) z_gc<&jb5I3#FrQyu*z$LUwF5d`t<7tTaml>JrHo=n@f~pt?-}Bn@UM33rA`rW7UP} z0{6w5lXQ;{yk?8*r=P=}u&xb$U(%hzrCz$q_}I#LIl1Yf^KHy>ZaY7ITqYUl51|`g zpjimYtH(0kM{j1sTBmZqQ>(P1xfMnC&Z3dYl(k2<)ai@ey!kGm@i%MQ0knpLX*cJS z8hTCEbH9u2W)WF7t<$ryz(Th=!ScFsJ-7J7+`jDEkICQojAV6xa-+dsc8!k(Oc+jP zss<+e&uSnQ;G&0m5N(S!mZQWEIkadr%)eu1{ZfgwcdITEQGXAz<0gW^5YQx|jI~5a zU7R_{m4n^$Pbv_REe|(?@))x-6KEmcy;_Oy-OCf7HY7xNl>&afWL!nmNaGzngien* zI?g$LwIdaSO(^8RvJPbACs$d0#DB#<7yawY)p+DpTP42!|m)`yhw9mhTk=;uR_89vAjzvno1tyG2ds|J(tP~8o;l@v&ca9m{ zHZ2$^Bv@49m^nlYN--^d3nKn9A*Ut;%1>5}EeCZM z3Aec@uGZm@1%vqMWK2?sqrQbODpsH@?Vn{LZZsr6(jv_;x_Y)My3;?l9>09Ef)7)$ z-EH5{?AXPznck*LSpFis?{i^C!lfmw=)Y-8F=q&0?EvvEEJI;M_;donq<-mUft zYUw->wV+fb7B>439rUh<`;(+-nqsW?eQ8H#BLJZkmK+LALwq4@^#c?%IR8(5gf0VF zc$zM%^OQtOp9H$JmdZjh$Og=N!c7I7C~5V>fqZ!ho+GO;n8EdBK~`67o!)i+u|wsY%U1QL+N*7N^o}-=<_^B{$C#H$NrOtTWcC{ zQudfB21x1ucVYd{I5z>A0C(p~Zc1LLoTDO5Fkp0I)E!VH+HQh6W5F^@+rF|_#Rsus z;1y}#OCg28q#-5~euB4RgaVm2O(#7v``ZraOBX!86Cal)_a5RSm!wHuwe}q|)K8+K zix|yVkef_4BG6enrfVRfE+P=R#fEm#SRuAcvAFpvZTC0F(i-w1&U&;aoz0PA!;7N+ z8pxFu#xApvH#wL_Ul8)?_2|c+qo3Qmh{{F$iUm3Tx;w9l!;JW4#&~ugOOa?%orxtXs zFe=lh>f5>s^iO8IB@u<|GLTldRcbFG;HELapYWiK*7%_u|OgoDoFTHb#3$%+2_k2H*0r7raA8kI>$_}?`&)5w*#B{BE zKaW?E2O`#pq17P$Nm17JM;i19W`|}7LI`j;wm%L?QuKkDm@N6iCCnR_^4xf~3XGkc z1))Q83XV(7^%39#$GPjjG5n0d1^@g+pICAp{}d5xbs!)$n_hKJs<`c;#>V!@8#wfr zpd#^IQNpk)^*TA@=b*#186`JpmQ2Kp$y-wiTs&QT-dp;~VF% zE@!=9P^{K##?cDd&;reoVnXki@E#*%wlF&Ac{&{Glpg|!CQ~dH4B6>mzq6u34XvK8 zYV6CK)R^@WmY;#*@GnZrEqrv8R@LX;Hogjvc^e z0?))*)crQB)8z@n+?krehNJ|y{_>$@7%>m77V-&IC+7Z~bx>F5WC7q&bLhy29BgO* zOJS{|rik$=+R5G=)H?Nh!HIqL70d^1VM=2YcP6=i6L4> zm#ERp9pCre=l%tEemisad7iz`KIdI)z3)0}0jaz0BN*)`$8qCjuF6q!H9@m)QMa4B zWA?yLXh=N89f(OJ9Dcg33KV@%#q!cvsv7YPd4ITIy{?V$A9JaX!Lk) zd&ULAJb%Ar6D*uLUD*~hV-e|o{~j{zM(8hDR23yrT(?$iJ)>{vz1C^0{*F}C7&o$1 zK0dUCOtT4&5#lsSkUR#hpYMGV_MkngVq5&IIK(SeSCU_QWal5&y2XTBGgooO$Q361 zxdrqd$2aVl;aA=s1Xv>Fb_24(Pq^o+69{F`R!B!j6kRH~Sc@mD#rqeReepJ8`v^Uk z2d-&bP(uo&FJp^4o&8ubH?yqt>%oV3zncojL7t*CeLE6no4$2z!Py)5=y$wiwQ1L% zp<BTWS&uH9QB)0T|9@t{82$fJI?A_Q?u_H8OBVvG>k zKcyt#Jh9vg2u2~(RQ{`B1{(d*bx~;QTcpQle-5Wm(iJ0~UmhfyEj@=sr4N5ID929F zae7?Hpp^Ik$mIBn{+PPOT8w0aTylktJw+o{<0sT4S~^d09JDkiw;3ym#x-}tu>%=f_1}n7zRC86?GX*e znH30{!x5{Mh49Ee&4~7OI<|pt9bjZmV2XTZIq{H$hh2jQ&#Bak)$H>jiD{R}HAgNd zv3*y*X`a9gMdBf!T^aBEIl9CIq_Bi*dgVavhDPJ3M*+mHheBvG#nk+yKP8eGhhw~| zz#Xo#@85tKtr^5GBe7Q5>e0>t@kgoA6nm`R)`r`R#fK2#Uv1VSzl=nNbtWSl%@V06 z!iTxjXdiEa^NwCmvoRei>`CH<(N2c~fj1mEeCj{P1rxe}rVc`X=CW{xv!w(#hKNKo zZl&zlYf3y9-lUo#)NNyRp1SW$;t1O6D zV~nV##zr1vUrV=bJVu>{#4E0NShmtM)ijrrjAph+_Du)MG?zcGnBk@AUyMo$u?Uev z9z5KSNsvd9cdLOHgK@q)QNMgKQwYEsdm z;5q#Xs@_M!=N2oar2xZ>qe!lV`6JB1n{fj**ED|{biNguN;|?c{Zk+#QBlLu?HSH& z2pmD)5Pg>D5-XrK@O{yTPFF@~lw4Vk_*uH1cXytGg$CMYgQ)|sR*iW$vz4XHb_&kt zY*0dJ2nn^+mD()yi%uSD7u8cFVyvsFb*x#tMySkUtU()Z8v>H&FLs8BJS+u3Osm|= z#E}|SSw)tO51Q{Hd+?N;Yxa7n*FCjs+AgzF0}+|?G6yIp*q;IaP5I=u zWdBN2e!NC9`{FPAgQLL5n%K;c%C zI&a5r}1=*C&W8^UfN#hw03O zpUh81RcS7~d2$KbUJoV@7aRFFoGe(I$8~)QyB8{Pkqk!+_dWO2v)pDbCb@)l4FIu9 zM~TGF%QehcisP)~cv7X_Og`vz#kL)nl#68Nyc@3Q_x4G*Ok~quml>9&A(21`j9shq zYFzO(770A;brfMX>%5=pv)q#lPF?bQ4?V%GDJvT4jbz5}%NalNj(=L!ch32H=YHGK zPN~%F;06(G%-Zv%yfK63Oa4sSKHy6^fC(ckSjNr4T#A)IU@san96P=Y9h70r|D|yA zPHle_K*n%!alhlZ@dH@?Oi{y5*KucPppWo226Y>h8EzP0#5 z?xD5LKqsk&z*b>6w=lYxHB;9>-@fX95y)X57l=3dg)%awt z^OoHbt?D_IJ;4g6@=`Or_<9Oir_jigg@`oZ&NKP`v!F$G!eZprxI64d9#uLOMJ6>N zgN)+YB%|z!_Sg9Dtk{|fwpg4UV$ErVsnWE@edEFJUsoiQrB2HiP^pMEP+{AT>p_=; zc9YD0dMssOKk606#L&Av6c_glF0MoV^*C?AWC=EpHFCwLriEN#L3)bQ?vGX2yu<-h zUbkwDARBOD)HqPE^pRVH{C@Ixr#DLOUr)_V%nz90kq~-9&lFo5E4F^)NF zxy+tPXH(C0t0wqrX?zeyrW5zADLq?~t{2aLv&@&g~yz|NbsOW9ts-BYOhLhMKChDNP~mF_XCab>D)1+&iW z)QQ!LE+RH{w($T`9SRa19t5-`VFy0^&~y&ae$@1aEy)!(Usq^Dg%+*Z%>3|dS;NohZuH#iDK~Q4S|x-g^5+a z04w3}J5viGX*Gef`6k#H`lw~Dd7Dn%AF4*LqTp-WkI2tlZv}YkeGTP}$fmQ;-nh?&_qHw@Ul=;~-L|iq-dvVKV zPwB{Uvto4S7V5wkF~puwJ}lmFmN^!4ax+B8D>$Dj$FhEkw=pfaQOK*s{*QdQ>lSJQ z?jYKxm0qYQ`VEfXbt_hYjYagnQtEqGX@j`i;ek^JPBkU?%&IyR)ww0NUs-u5Sa5Z@*;*`9ZS-K<4%4W878^Do&sj*PwZGrM;wW@WoD8b zahXW@^UOLX>z@I;rA3mh?Z$8m}8VqT86T z!$pmeXPxhB9X_Pus(s4w?m%xB1quHO|4hB+^!5+9A)I|1U{dUiB+hE;i;ft-g{U zj=d}kpDnx;kRVF)S6=tF6B6}s);acW>XCC3^>Yk6NMb0Vr?v0kJl-ER^Toa5gEiZ- zsmFQl++(2vjt^!$wIh?vHH!OHvGR;r#rVL?58}`+J&`C{?+u#a68){0DO4{SNgf3E z%Ur)MT7G?o4VzQojz-O=rrU zV;eWh5DVY;QtX9s;^}`!o*-_|4vuD-19^FqYc1bI$;y&bQkmX;jU!QBK1*onQ6g|| zG+b!%@yFX@PTP=(CRC*FLGs*CYgb^&<}NZ*fZ$Q{vQl0$Q?7HpQO*UhwL%n6mOb&R zH@mvp@VZXB>R~n9Vm9=n;MUkh7kZeYN_W)CP6FWWZ&`bWj3}nmDN6!EC-k(?LAzw5 zITHM}jYb|2%C$fE^7zL}$c+yB#BFob24p>&1B2ayb4;l!3w+e`pMNqjO1~iEfgE+Y z)j|nQiE@sSLU!Lem99LwYS&wov#N0(6FjL-!guP+I1lD7BEv7Ts6vjZfxw8GPza@ z*Zz~s`p2->MN&oI8-6L{vCRy--rYN$Bx{-B-z$a*t0p24Q87vW#Ox+ zQkEKY*SejnWBVDs)SbgUth-*^6d{Jm;``c^Ov8D{7K(fwl(9wMAJk_ps_no1r1(pr zuISOE9dQg%)7NA&&NP#6jh~Gs`f1@|+yu(6*78VL9m@faA=OEAOPc5l?)&RQ;+AWI zCF>X+_<7vms!N+L_`ro(bH)B8eSbsr%hc_@+NFk(9^KS*=Jubwovj~+J605Bo1?uQ z=;di}4EV{p8#dHYnmrji(zmWCSj|@)?|N|u!(GX#5wY*x-@eaDf?!GM7~MiggsZ8D zC9FUZ-u^myTcYcqni5`90b%z-kcHEl%8`a6FR)zs+{^+t&u2Z zBb|nnr;OUd@Zt&Y1Yan$t@b&ZE2GlJIjYJS==mqyqqZB^7neLIruYe)wpLhL_Na*0 zL%8&jexxlp+F=YsGj@MC4yfhrQ0gZBn<@S|y?T(N49VE_egCI@Isy_Cx?kOIjOC9f>KI6q=s1$CJLU@X>_VYcl)g3OVH+BLpvG{}YUJEo zo?+n9MvV$&py=K=3ec%^eW|pUW?#W?sadFjI@5_iB#r6kSh?grXJN1ilZFA{Lx}@K zN``dB^0yHW)Hxz_{JDJ9H%F5HrOQ=Hl5AY{?!;-Z%Cnb{j5@Z%3y1znvKG~)t0l;J z;v_zuGG2yE;KOd}C_qAV&q(b$b-)5{N>RI9*>`mlQOWJ2H)KW^UZ3JQA)y6(uMsGJLq0{X?iL+`BFMed|O3?|uMyJjT(|VL08r?*$e2eBx3@x6O*uZ$0$P~g!CXuwya>NPX;~k zF8Fga;K+?`;Y%f6%JqT1sHc@5>9xrXKU;G7HHQ`bO?9-Dr&nkoIo#Y2)`xjs+ZGnQHsgIExjQ9v!Gy;q*J(u;QK-&eP<5;4B)MEK>UJD^TtT)0 zGW&6hMe!HjJd`?}&!oz840^!LT&qMakz6b2^VF_}L+q#^SM0{ZMJ<_IK{zCKtUrs- z!BQzkfUV=0Ete|2x)7`*>A5Doa{J;I#lzz6m5vsMG3m5l0^R9|L$ z2&sgB*};i*sg^F&OKGzMKCsZBRL`HKj~X1Z=)2V@0-8A#qVnkcI&7te06fzx<-Yrkh+xwL^KMOH^-~!}IA{S7e z%~NWeyktTCkE1?AGt2lcKqOIGV4_a?RKvQ``bD}E`R|u;Gc{&n8G}IH!AkM(_?+vW ze&R<}*};-6u07@1t%4zauN)k=hf}7#44g$VFCw)H#~BZ^jBc9d4{z*OSIwOwXF}B@ z=2Zn;dGn05AnneaWx7BPk+&Ra{<&fp77rSKLd^x2pB6}WdAs0m@IVnydXmu%AZ+Ke zD@9&=`AW_ne09(#v;g8YEXjBmS*2Ih&R<&^T%!+QXps-SlhcgQ4p2Yj0@Ga-xa7Xm zSW0;x?^D`P2jZqn9L&iAQoH&JM}2Y(kE|NQEXn~&sbW9LDWO2$0~SWi>=jg5cE69) z9L81sJg0};657zApgH${ZbgzP}Sso#TLi7Hee^WhrjP%B!_25toPoCeAZ;Z@4WWaP~ zEQ^h>?WEy%M7g3dmHqyh!)M*i01n~#PyJDfNrvXIYE8iHiJm2iY5utxisj^95d;vc`=MDGk&pL}J8mu73Yq||m*vC}PX7b?U`;wf z=^L|D_lP}g)d8Cq9W^G|zG`!}Iii9(3vmG#D)sVs#h;4nKWRIhf?mrqcjxd5RB71@ zgkc#3|BHM}I1wpODic>e4G@TbNSA=+OhW&uhoTv#8DZCu^oZ?~R!XEDB#db9UcWf7d$pRnB7YcFw`zzgvbIsqL`&TC%)u--vMqXY}qn)~-8 z93-zd$29QLnMW1a{k4Wa{>B^zfBO}Z157NnSmY&?7E=;OC1+wj&?vqs1v=lJ+y?SY zYUZDqYSt4)1pM#&pofrY|9qQAQWP}G3YKef+H zjqIHE&Cg24Y^H@?2M2F#m$ZbgmQSW$KkEK(jkV;HYD|1pe}t!rF|n(6>)j-NTjX!L z``bXOje3iM$dh?%1kHC3z zi~A(#Z(JabEX`oENQi%q`@-Jf_*Q;r1X&+-*-}i&@l)y2CL7Uxs!>Hx06Kb?N(Ud! zc?CC#xiQ7JvrFKJv{}7=zj2ui*|Qh=vPMYH@Yw-}4&HV8Twg}wbC$KO>%V11;bDal zJw8y(MZ3=Xj0%nh`yGQojyGFar>m^ermT?gkFC{d6iw!l+#$sT;?xBV|8#dm(O(27 zaHSkmD@w&&qPp^x~wS0HN>P99ck1u~S9 zTVX~^`&rAL${x8}84$CcZVpFsLqz-ie8OWq%_Y_i5Y)17hwpCY+3IclHzSoyF+1*$ z1oWzh2zD7cKdkATKThHf5TqpCCM?LebmUTjQx?T7z2wrOQ%G)Ui>ZhnGIagF*eTp zBtc3V9c%SxhL9rh)brPdy!{Yp@HqFw2aljtAMhw=6O zhj( Date: Mon, 6 Nov 2023 14:52:29 -0800 Subject: [PATCH 4/5] Update README in offloadpp blog (#4641) update tutorial link --- blogs/deepspeed-offloadpp/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blogs/deepspeed-offloadpp/README.md b/blogs/deepspeed-offloadpp/README.md index 22910ffef7b1..1441da5a35c0 100644 --- a/blogs/deepspeed-offloadpp/README.md +++ b/blogs/deepspeed-offloadpp/README.md @@ -43,7 +43,7 @@ We conduct our performance evaluations over both A100 and H100 DGX machine and t ## Tutorials -Examples and Tutorials are [here](https://github.com/microsoft/Megatron-DeepSpeed/blob/guanhua/partial-offload/examples_deepspeed/offload_pp/README.md) +Examples and Tutorials are [here](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples_deepspeed/offload_pp/README.md) ## Contributors: From cbec96b00ec16bd7c724dcb7ae00ceb1d319cefe Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 6 Nov 2023 15:55:25 -0800 Subject: [PATCH 5/5] [docs] update news items (#4640) Co-authored-by: Guanhua Wang --- README.md | 18 ++++++++++++++---- docs/index.md | 21 ++++++++++++++++----- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 721ba62cee37..e0463beb1a77 100755 --- a/README.md +++ b/README.md @@ -15,13 +15,23 @@ ## Latest News DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). +* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp) * [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) * [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)] * [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)] -* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) -* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md) -* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)] -* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)] +* [2023/08] [DeepSpeed ZeRO-Inference: 20x faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) + + +

--- diff --git a/docs/index.md b/docs/index.md index 60bcf19b84da..a4027ca1e52e 100755 --- a/docs/index.md +++ b/docs/index.md @@ -7,13 +7,24 @@ title: "Latest News" --- DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat). +* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-offloadpp) * [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-fastgen) * [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)] -* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)] -* [2023/08] [DeepSpeed ZeRO-Inference: 20X faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) -* [2023/08] [DeepSpeed-Chat: Llama/Llama-2 system support, efficiency boost, and training stability improvements](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat/ds-chat-release-8-31/README.md) -* [2023/08] [DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-ulysses) [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/japanese/README.md)] -* [2023/06] [ZeRO++: A leap in speed for LLM and chat model training with 4X less communication](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)[[English](https://www.microsoft.com/en-us/research/blog/deepspeed-zero-a-leap-in-speed-for-llm-and-chat-model-training-with-4x-less-communication/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/zeropp/japanese/README.md)] +* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[DeepSpeed4Science website](https://deepspeed4science.ai/)] [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)] +* [2023/08] [DeepSpeed ZeRO-Inference: 20x faster inference through weight quantization and KV cache offloading](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) + + + +
+ More news + +
# Extreme Speed and Scale for DL Training and Inference
+ More news + +