Skip to content

Commit

Permalink
more input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jan 5, 2025
1 parent 716c520 commit 53fd619
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 34 deletions.
110 changes: 83 additions & 27 deletions kraken/ketos/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,55 @@
import click

from pathlib import Path
from difflib import get_close_matches

from .util import message

logging.captureWarnings(True)
logger = logging.getLogger('kraken')


def _get_field_list(name):
def _validate_script(script: str) -> str:
from htrmopo.util import _iso15924
if script not in _iso15924:
return get_close_matches(script, _iso15924.keys())
return script


def _validate_language(language: str) -> str:
from htrmopo.util import _iso639_3
if language not in _iso639_3:
return get_close_matches(language, _iso639_3.keys())
return language


def _validate_license(license: str) -> str:
from htrmopo.util import _licenses
if license not in _licenses:
return get_close_matches(license, _licenses.keys())
return license


def _get_field_list(name,
validation_fn=lambda x: x,
required: bool = False):
values = []
while True:
value = click.prompt(name, default=None)
if value is not None:
values.append(value)
value = click.prompt(name, default='')
if value:
if (cand := validation_fn(value)) == value:
values.append(value)
else:
message(f'Not a valid {name} value. Did you mean {cand}?')
else:
break
if click.confirm(f'All `{name}` values added?'):
if required and not values:
message(f'`{name}` is a required field.')
continue
else:
break
else:
continue
return values


Expand All @@ -46,7 +81,7 @@ def _get_field_list(name):
@click.option('-i', '--metadata', show_default=True,
type=click.File(mode='r', lazy=True), help='Model card file for the model.')
@click.option('-a', '--access-token', prompt=True, help='Zenodo access token')
@click.option('-d', '--doi', prompt=True, help='DOI of an existing record to update')
@click.option('-d', '--doi', help='DOI of an existing record to update')
@click.option('-p', '--private/--public', default=False, help='Disables Zenodo '
'community inclusion request. Allows upload of models that will not show '
'up on `kraken list` output')
Expand All @@ -56,15 +91,16 @@ def publish(ctx, metadata, access_token, doi, private, model):
Publishes a model on the zenodo model repository.
"""
import json
import yaml
import tempfile

from htrmopo import publish_model, update_model

pub_fn = publish_model

from kraken.lib.vgsl import TorchVGSLModel
from kraken.lib.progress import KrakenDownloadProgressBar

pub_fn = publish_model

_yaml_delim = r'(?:---|\+\+\+)'
_yaml = r'(.*?)'
_content = r'\s*(.+)$'
Expand All @@ -77,27 +113,44 @@ def publish(ctx, metadata, access_token, doi, private, model):
# construct metadata if none is given
if metadata:
frontmatter, content = _yaml_regex.match(metadata.read()).groups()
frontmatter = yaml.safe_load(frontmatter)
else:
frontmatter['summary'] = click.prompt('summary')
content = click.edit('Write long form description (training data, transcription standards) of the model in markdown format here')

creators = []
message('To stop adding authors, leave the author name field empty.')
while True:
author = click.prompt('author', default=None)
affiliation = click.prompt('affiliation', default=None)
orcid = click.prompt('orcid', default=None)
if author is not None:
creators.append({'author': author})
author = click.prompt('author name', default='')
if author:
creators.append({'name': author})
else:
break
if click.confirm('All authors added?'):
break
else:
continue
affiliation = click.prompt('affiliation', default='')
orcid = click.prompt('orcid', default='')
if affiliation is not None:
creators[-1]['affiliation'] = affiliation
if orcid is not None:
creators[-1]['orcid'] = orcid
if not creators:
raise click.UsageError('The `authors` field is obligatory. Aborting')

frontmatter['authors'] = creators
frontmatter['license'] = click.prompt('license')
frontmatter['language'] = _get_field_list('language')
frontmatter['script'] = _get_field_list('script')
while True:
license = click.prompt('license')
if (lic := _validate_license(license)) == license:
frontmatter['license'] = license
break
else:
message(f'Not a valid license identifer. Did you mean {lic}?')

message('To stop adding values to the following fields, enter an empty field.')

frontmatter['language'] = _get_field_list('language', _validate_language, required=True)
frontmatter['script'] = _get_field_list('script', _validate_script, required=True)

if len(tags := _get_field_list('tag')):
frontmatter['tags'] = tags + ['kraken_pytorch']
Expand All @@ -108,30 +161,33 @@ def publish(ctx, metadata, access_token, doi, private, model):

# take last metrics field, falling back to accuracy field in model metadata
metrics = {}
if 'metrics' in nn.user_metadata and nn.user_metadata['metrics']:
metrics['cer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_accuracy']
metrics['wer'] = 100 - nn.user_metadata['metrics'][-1][1]['val_word_accuracy']
elif 'accuracy' in nn.user_metadata and nn.user_metadata['accuracy']:
metrics['cer'] = 100 - nn.user_metadata['accuracy']
if nn.user_metadata.get('metrics', None) is not None:
if (val_accuracy := nn.user_metadata['metrics'][-1][1].get('val_accuracy', None)) is not None:
metrics['cer'] = 100 - (val_accuracy * 100)
if (val_word_accuracy := nn.user_metadata['metrics'][-1][1].get('val_word_accuracy', None)) is not None:
metrics['wer'] = 100 - (val_word_accuracy * 100)
elif (accuracy := nn.user_metadata.get('accuracy', None)) is not None:
metrics['cer'] = 100 - accuracy
frontmatter['metrics'] = metrics
software_hints = ['kind=vgsl']

# some recognition-specific software hints
if nn.model_type == 'recognition':
software_hints.append([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', 'legacy_polygons={nn.user_metadata["legacy_polygons"]}'])
software_hints.extend([f'seg_type={nn.seg_type}', f'one_channel_mode={nn.one_channel_mode}', f'legacy_polygons={nn.user_metadata["legacy_polygons"]}'])
frontmatter['software_hints'] = software_hints

frontmatter['software_name'] = 'kraken'
frontmatter['model_type'] = [nn.model_type]

# build temporary directory
with tempfile.TemporaryDirectory() as tmpdir, KrakenDownloadProgressBar() as progress:
upload_task = progress.add_task('Uploading', total=0, visible=True if not ctx.meta['verbose'] else False)

model = Path(model)
model = Path(model).resolve()
tmpdir = Path(tmpdir)
(tmpdir / model.name).symlink_to(model)
# v0 metadata only supports recognition models
(tmpdir / model.name).resolve().symlink_to(model)
if nn.model_type == 'recognition':
# v0 metadata only supports recognition models
v0_metadata = {
'summary': frontmatter['summary'],
'description': content,
Expand All @@ -145,7 +201,7 @@ def publish(ctx, metadata, access_token, doi, private, model):
with open(tmpdir / 'metadata.json', 'w') as fo:
json.dump(v0_metadata, fo)
kwargs = {'model': tmpdir,
'model_card': f'---\n{frontmatter}---\n{content}',
'model_card': f'---\n{yaml.dump(frontmatter)}---\n{content}',
'access_token': access_token,
'callback': lambda total, advance: progress.update(upload_task, total=total, advance=advance),
'private': private}
Expand Down
11 changes: 4 additions & 7 deletions kraken/repo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright 2015 Benjamin Kiessling
# Copyright 2025 Benjamin Kiessling
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,15 +17,12 @@
~~~~~~~~~~~
Wrappers around the htrmopo reference implementation implementing
kraken-specific filtering.
kraken-specific filtering for repository querying operations.
"""
import logging
import warnings
from pathlib import Path
from collections import defaultdict
from typing import IO, Any, Dict, List, Union, cast, Optional, TypeVar, Iterable, Literal

from collections.abc import Callable
from typing import Any, Dict, Optional, TypeVar, Literal


from htrmopo import get_description as mopo_get_description
from htrmopo import get_listing as mopo_get_listing
Expand Down

0 comments on commit 53fd619

Please sign in to comment.