@encryption_commands.command()
@click.option("--clean-first", type=bool, is_flag=True)
@do_not_run_in_prod
def generate_initial_keys(clean_first: bool) -> None:
    """
    During the migration we will need to create User and Device Keypairs proactively from the backend for all existing members with conversations.
    For dev/testing, this sets up new keys for all users.

    Steps:
    1. With ``--clean-first``, delete every EncryptionGroupUser, EncryptionUser and
       EncryptionGroup row.
    2. Reuse the existing doctor group's public key, or generate a new 4096-bit RSA
       key pair for the group if it does not exist yet.
    3. For every clinic user with a medical conversation who lacks a complete key set,
       and for every active medical admin, generate a device key pair and a user key
       pair when their EncryptionUser has no key material at all. The device private
       key is wrapped with the null-password key. The user private key is encrypted
       for two recipients: the doctor group key (escrow for recovery) and the device key.
    4. If the doctor group was generated in step 2, create it with the medical admins
       as initial members.
    """
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from jwcrypto import jwe, jwk
    from sqlalchemy import and_

    from components.clinic.internal.models.clinic_user import (  # noqa: ALN039,ALN069
        ClinicUser,
    )
    from components.clinic.internal.models.medical_admin import (  # noqa: ALN039,ALN069
        MedicalAdmin,
    )
    from components.clinic.internal.models.medical_conversation import (  # noqa: ALN039,ALN069
        MedicalConversation,
    )
    from components.encryption.constants import (
        DOCTORS_GROUP_NAME,
        NULL_PASSWORD_DIGEST,
    )
    from components.encryption.internal.business_logic import (
        InitialGroupMember,
        create_encryption_group,
        get_doctor_group,
    )
    from components.encryption.internal.models.encryption_group import (
        EncryptionGroup,
    )
    from components.encryption.internal.models.encryption_group_user import (
        EncryptionGroupUser,
    )
    from components.encryption.internal.models.encryption_user import (
        EncryptionUser,
    )

    def new_rsa_private_key() -> rsa.RSAPrivateKey:
        """Generate a fresh 4096-bit RSA private key with public exponent 65537."""
        return rsa.generate_private_key(public_exponent=65537, key_size=4096)

    def public_key_pem(private_key: rsa.RSAPrivateKey) -> bytes:
        """Serialize the public half of ``private_key`` as SubjectPublicKeyInfo PEM."""
        return private_key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

    def private_key_pem(private_key: rsa.RSAPrivateKey) -> bytes:
        """Serialize ``private_key`` as *unencrypted* PKCS#8 PEM."""
        return private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    ### Initial cleanup
    if clean_first:
        # Membership rows go first: they point at both encryption users and groups.
        current_session.query(EncryptionGroupUser).delete()  # noqa: ALN085
        current_session.query(EncryptionUser).delete()  # noqa: ALN085
        current_session.query(EncryptionGroup).delete()  # noqa: ALN085

    ### Bootstrap: create key pair for doctor group
    # This must be done before creating user keys, because the doctor group serves as an escrow group
    # for recovery, so its public key must be known when creating users, and the user keys must be
    # known when setting up the group, which is the time where the group key encryption key is available.
    doctor_group = get_doctor_group()
    if doctor_group:
        click.echo("Doctor group already exists")
        doctors_public_key_pem = doctor_group.public_key
    else:
        current_logger.info("Generating doctor group RSA keypair")
        doctors_key_pair = new_rsa_private_key()
        doctors_public_key_pem = public_key_pem(doctors_key_pair).decode("utf-8")

    doctors_group_public_key_jwk = jwk.JWK.from_pem(
        doctors_public_key_pem.encode("utf-8")
    )
    # The group key never changes inside the loop, so derive its id once.
    doctors_group_thumbprint = doctors_group_public_key_jwk.thumbprint()
    doctors_group_key_id = f"group:{doctors_group_thumbprint}"
    current_logger.info(f"Doctors group key thumbprint: {doctors_group_thumbprint}")

    ### Generate keys for users (members and medical admins)
    conversations_clinic_user_ids = {
        clinic_user_id
        for (clinic_user_id,) in current_session.query(  # noqa: ALN085
            MedicalConversation.member_clinic_user_id.distinct()
        ).filter(MedicalConversation.member_clinic_user_id.isnot(None))
    }
    # Remove all clinic members who already have keys (all four key columns set)
    clinic_users_with_keys = {
        clinic_user_id
        for (clinic_user_id,) in current_session.query(ClinicUser.id)  # noqa: ALN085
        .join(
            EncryptionUser,
            and_(
                EncryptionUser.app_id == ClinicUser.app_id,
                EncryptionUser.app_user_id == ClinicUser.app_user_id,
            ),
        )
        .filter(
            EncryptionUser.user_public_key.isnot(None),
            EncryptionUser.user_secret_key_as_jwe.isnot(None),
            EncryptionUser.device_public_key.isnot(None),
            EncryptionUser.device_secret_key_as_jwe.isnot(None),
        )
    }
    medical_admin_clinic_user_ids = {
        clinic_user_id
        for (clinic_user_id,) in current_session.query(  # noqa: ALN085
            MedicalAdmin.clinic_user_id.distinct()
        ).filter(MedicalAdmin.clinic_user_id.isnot(None), MedicalAdmin.is_active)
    }
    # We always want the medical admins, to be sure we add them to the doctor group.
    clinic_user_ids = (
        conversations_clinic_user_ids - clinic_users_with_keys
    ) | medical_admin_clinic_user_ids
    current_logger.info(
        f"Generating keys for {len(clinic_user_ids)} clinic users and admins"
    )

    null_password_hashed_jwk = jwk.JWK.from_password(NULL_PASSWORD_DIGEST)
    doctors_group_initial_members: list[InitialGroupMember] = []
    for clinic_user_id in clinic_user_ids:
        clinic_user = current_session.get(ClinicUser, clinic_user_id)
        encryption_user = (
            current_session.query(EncryptionUser)  # noqa: ALN085
            .filter(
                EncryptionUser.app_id == clinic_user.app_id,  # type: ignore[arg-type,union-attr]
                EncryptionUser.app_user_id == clinic_user.app_user_id,  # type: ignore[arg-type,union-attr]
            )
            .first()
        )

        # NOTE(review): keys are only generated when *none* of the four key columns is
        # set. A user with a partial key set passes the "has all keys" pre-filter above
        # but is skipped here — confirm that is intended.
        if not encryption_user or not any(
            (
                encryption_user.device_public_key,
                encryption_user.device_secret_key_as_jwe,
                encryption_user.user_public_key,
                encryption_user.user_secret_key_as_jwe,
            )
        ):
            # Device key pair; its private key is wrapped with the null-password key
            # ("dir": that key is used directly as the A256GCM content key).
            device_rsa_key = new_rsa_private_key()
            device_public_key_pem = public_key_pem(device_rsa_key)
            device_secret_key_as_jwe = jwe.JWE(
                plaintext=private_key_pem(device_rsa_key),
                recipient=null_password_hashed_jwk,
                header={"alg": "dir", "enc": "A256GCM"},
            )

            # User key pair; its private key is encrypted for two recipients:
            # the doctor group (escrow for recovery) and the user's device.
            user_rsa_key = new_rsa_private_key()
            user_public_key_pem = public_key_pem(user_rsa_key)
            user_secret_key_as_jwe = jwe.JWE(plaintext=private_key_pem(user_rsa_key))

            current_logger.info(
                f"Encrypting user private key with doctors group public key ({doctors_group_key_id})"
            )
            user_secret_key_as_jwe.add_recipient(
                key=doctors_group_public_key_jwk,
                header={
                    "alg": "RSA-OAEP-256",
                    "enc": "A256GCM",
                    "kid": doctors_group_key_id,
                },
            )

            device_public_key_jwk = jwk.JWK()
            device_public_key_jwk.import_from_pem(device_public_key_pem)
            device_key_id = f"device:{device_public_key_jwk.thumbprint()}"
            current_logger.info(
                f"Encrypting user private key with device public key ({device_key_id})"
            )
            user_secret_key_as_jwe.add_recipient(
                key=device_public_key_jwk,
                header={"alg": "RSA-OAEP-256", "enc": "A256GCM", "kid": device_key_id},
            )

            if encryption_user:
                click.echo(
                    f"Saving new keys for {clinic_user.app_id}:{clinic_user.app_user_id}"  # type: ignore[union-attr]
                )
                encryption_user.device_public_key = device_public_key_pem.decode("utf8")
                encryption_user.device_secret_key_as_jwe = (
                    device_secret_key_as_jwe.serialize()
                )
                encryption_user.is_device_secret_encrypted = False
                encryption_user.user_public_key = user_public_key_pem.decode("utf8")
                encryption_user.user_secret_key_as_jwe = (
                    user_secret_key_as_jwe.serialize()
                )
            else:
                click.echo(
                    f"Saving new EncryptionUser for {clinic_user.app_id}:{clinic_user.app_user_id}"  # type: ignore[union-attr]
                )
                encryption_user = EncryptionUser(
                    app_id=clinic_user.app_id,  # type: ignore[union-attr]
                    app_user_id=clinic_user.app_user_id,  # type: ignore[union-attr]
                    device_public_key=device_public_key_pem.decode("utf8"),
                    device_secret_key_as_jwe=device_secret_key_as_jwe.serialize(),
                    is_device_secret_encrypted=False,
                    user_public_key=user_public_key_pem.decode("utf8"),
                    user_secret_key_as_jwe=user_secret_key_as_jwe.serialize(),
                )
                current_session.add(encryption_user)
            # Commit per user so already-generated keys survive a later failure.
            current_session.commit()

        # Add Medical Admins to Doctor Group
        if (
            clinic_user_id in medical_admin_clinic_user_ids
            and not current_session.query(EncryptionGroupUser)  # noqa: ALN085
            .filter(
                EncryptionGroupUser.encryption_group == doctor_group,
                EncryptionGroupUser.encryption_user == encryption_user,
            )
            .first()
        ):
            if doctor_group is None:
                doctors_group_initial_members.append(
                    InitialGroupMember(
                        full_name=clinic_user.medical_admin.full_name,  # type: ignore[union-attr]
                        encryption_user=encryption_user,
                    )
                )
            else:
                # Members are only added when the group is created below; an existing
                # group is never updated here, so make the gap visible instead of
                # silently discarding the admin.
                current_logger.warning(
                    f"Medical admin {clinic_user.app_id}:{clinic_user.app_user_id} is not a member of the existing doctor group and will not be added"  # type: ignore[union-attr]
                )

    ### Generate doctors group
    if doctor_group is None:
        create_encryption_group(
            DOCTORS_GROUP_NAME,
            doctors_group_initial_members,
            key_pair=doctors_key_pair,
        )
        current_session.commit()