from __future__ import annotations
import asyncio
import contextlib
import json
import logging
import ssl
import struct
import time
import traceback
from base64 import urlsafe_b64decode
from contextlib import suppress as contextlib_suppress
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable
from aiohttp import ClientSession
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_private_key
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message
from http_ece import decrypt as http_decrypt # type: ignore[import-untyped]
from .const import (
MCS_HOST,
MCS_PORT,
MCS_SELECTIVE_ACK_ID,
MCS_VERSION,
)
from .fcmregister import FcmRegister, FcmRegisterConfig
from .proto.mcs_pb2 import ( # pylint: disable=no-name-in-module
Close,
DataMessageStanza,
HeartbeatAck,
HeartbeatPing,
IqStanza,
LoginRequest,
LoginResponse,
SelectiveAck,
StreamErrorStanza,
)
_logger = logging.getLogger(__name__)
# Shape of the per-message callback; FcmPushClient calls it as
# callback(data, persistent_id, callback_context).
OnNotificationCallable = Callable[[dict[str, Any], str, Any], None]
# Shape of the callback that receives freshly created or updated credentials.
CredentialsUpdatedCallable = Callable[[dict[str, Any]], None]
# MCS Message Types and Tags
# Maps each MCS message type to the one-byte tag written before it on the
# wire (see FcmPushClient._make_packet / _receive_msg). Keys are protobuf
# classes for the types this client builds or parses; plain-string keys are
# tags with no class configured, which _receive_msg() skips with a warning.
MCS_MESSAGE_TAG = {
    HeartbeatPing: 0,
    HeartbeatAck: 1,
    LoginRequest: 2,
    LoginResponse: 3,
    Close: 4,
    "MessageStanza": 5,
    "PresenceStanza": 6,
    IqStanza: 7,
    DataMessageStanza: 8,
    "BatchPresenceStanza": 9,
    StreamErrorStanza: 10,
    "HttpRequest": 11,
    "HttpResponse": 12,
    "BindAccountRequest": 13,
    "BindAccountResponse": 14,
    "TalkMetadata": 15,
}
class ErrorType(Enum):
    """Categories of error tracked by FcmPushClient's sequential error counters."""

    CONNECTION = 1  # read/connection failures and server Close messages
    READ = 2  # reset after each handled message; NOTE(review): never incremented in the code shown
    LOGIN = 3  # failures sending the login request and login error responses
    NOTIFY = 4  # exceptions raised by the user notification callback
class FcmPushClientRunState(Enum):
    """Lifecycle states of :class:`FcmPushClient`.

    Member values are plain integers.
    """

    CREATED = 1  # constructed, start() not called yet
    STARTING_TASKS = 2  # start() is creating the listener/monitor tasks
    STARTING_CONNECTION = 3  # opening the TLS connection (with retries)
    STARTING_LOGIN = 4  # connection open, LoginRequest being sent
    STARTED = 5  # a successful LoginResponse was received
    RESETTING = 6  # connection being torn down and re-established
    STOPPING = 7  # stop() or an internal terminate is in progress
    STOPPED = 8  # stop() has finished
@dataclass
class FcmPushClientConfig:  # pylint:disable=too-many-instance-attributes
    """Configuration for :class:`firebase_messaging.FcmPushClient`.

    Pass an instance as the client's ``config`` argument; every field has a
    default value.
    """

    server_heartbeat_interval: int | None = 10
    """Time in seconds to request the server to send heartbeats"""

    client_heartbeat_interval: int | None = 20
    """Time in seconds to send heartbeats to the server"""

    send_selective_acknowledgements: bool = True
    """True to send selective acknowledgements for each message received.
    Currently if false the client does not send any acknowledgements."""

    connection_retry_count: int = 5
    """Number of times to retry the connection before giving up."""

    start_seconds_before_retry_connect: float = 3
    """Time in seconds to wait before attempting to retry
    the connection after failure."""

    reset_interval: float = 3
    """Time in seconds to wait between resets after errors or disconnection."""

    heartbeat_ack_timeout: float = 5
    """Time in seconds to wait for a heartbeat ack before resetting."""

    abort_on_sequential_error_count: int | None = 3
    """Number of sequential errors of the same type to wait before aborting.
    If set to None the client will not abort."""

    monitor_interval: float = 1
    """Time in seconds for the monitor task to fire and check for heartbeats,
    stale connections and shut down of the main event loop."""

    log_warn_limit: int | None = 5
    """Number of times to log specific warning messages before going silent for
    a specific warning type."""

    log_debug_verbose: bool = False
    """Set to True to log all message info including tokens."""
class FcmPushClient:  # pylint:disable=too-many-instance-attributes
    """Client that connects to Firebase Cloud Messaging and receives messages.

    :param callback: called for every received data message as
        ``callback(data, persistent_id, callback_context)``.
    :param fcm_config: registration configuration (:class:`FcmRegisterConfig`).
    :param credentials: credentials object returned by register()
    :param credentials_updated_callback: callback when new credentials are
        created to allow client to store them
    :param callback_context: arbitrary object passed back to ``callback``.
    :param received_persistent_ids: any persistent id's you already received.
    :param config: configuration class of
        :class:`firebase_messaging.FcmPushClientConfig`
    :param http_client_session: optional aiohttp session passed on to
        :class:`FcmRegister`.
    """
def __init__(
    self,
    callback: Callable[[dict, str, Any | None], None],
    fcm_config: FcmRegisterConfig,
    credentials: dict | None = None,
    credentials_updated_callback: CredentialsUpdatedCallable | None = None,
    *,
    callback_context: object | None = None,
    received_persistent_ids: list[str] | None = None,
    config: FcmPushClientConfig | None = None,
    http_client_session: ClientSession | None = None,
):
    """Initializes the receiver.

    Only stores state; no network I/O happens here (see :meth:`start`).

    :param callback: called as ``callback(data, persistent_id,
        callback_context)`` for each received data message.
    :param fcm_config: registration config; its ``chrome_version`` is also
        sent in the MCS login request.
    :param credentials: credentials from a previous registration, or None.
    :param credentials_updated_callback: handed to :class:`FcmRegister` so
        the caller can persist new credentials.
    :param callback_context: arbitrary object passed back to ``callback``.
    :param received_persistent_ids: persistent ids already received; they
        are reported to the server at login. The list is copied.
    :param config: client configuration; defaults to
        :class:`FcmPushClientConfig`.
    :param http_client_session: optional aiohttp session handed to
        :class:`FcmRegister`.
    """
    self.callback = callback
    self.callback_context = callback_context
    self.fcm_config = fcm_config
    self.credentials = credentials
    self.credentials_updated_callback = credentials_updated_callback
    # Copy: this list is appended to as messages arrive and must not alias
    # (and silently mutate) the list the caller passed in.
    self.persistent_ids: list[str] = list(received_persistent_ids or [])
    self.config = config if config else FcmPushClientConfig()
    if self.config.log_debug_verbose:
        # NOTE: this changes the level of the module-wide logger, so it
        # affects every client instance in the process.
        _logger.setLevel(logging.DEBUG)
    self._http_client_session = http_client_session
    self.reader: asyncio.StreamReader | None = None
    self.writer: asyncio.StreamWriter | None = None
    self.do_listen = False
    self.sequential_error_counters: dict[ErrorType, int] = {}
    self.log_warn_counters: dict[str, int] = {}
    # reset variables
    self.input_stream_id = 0
    self.last_input_stream_id_reported = -1
    self.first_message = True
    self.last_login_time: float | None = None
    self.last_message_time: float | None = None
    self.run_state: FcmPushClientRunState = FcmPushClientRunState.CREATED
    self.tasks: list[asyncio.Task] = []
    # Created in start(), because they are tied to the running event loop.
    self.reset_lock: asyncio.Lock | None = None
    self.stopping_lock: asyncio.Lock | None = None
def _msg_str(self, msg: Message) -> str:
if self.config.log_debug_verbose:
return type(msg).__name__ + "\n" + MessageToJson(msg, indent=4)
return type(msg).__name__
def _log_verbose(self, msg: str, *args: object) -> None:
if self.config.log_debug_verbose:
_logger.debug(msg, *args)
def _log_warn_with_limit(self, msg: str, *args: object) -> None:
if msg not in self.log_warn_counters:
self.log_warn_counters[msg] = 0
if (
self.config.log_warn_limit
and self.config.log_warn_limit > self.log_warn_counters[msg]
):
self.log_warn_counters[msg] += 1
_logger.warning(msg, *args)
async def _do_writer_close(self) -> None:
writer = self.writer
self.writer = None
if writer:
writer.close()
with contextlib.suppress(Exception):
await writer.wait_closed()
async def _reset(self) -> None:
    """Close the current connection, then reconnect and log in again.

    Does nothing if a reset is already running, a stop is in progress, or
    the client is no longer listening. If reconnecting fails after the
    configured retries, the client is terminated.
    """
    # Skip if another reset holds the lock, stop() holds its lock, or we
    # have been told to stop listening (also the case before start()).
    if (
        (self.reset_lock and self.reset_lock.locked())
        or (self.stopping_lock and self.stopping_lock.locked())
        or not self.do_listen
    ):
        return
    async with self.reset_lock:  # type: ignore[union-attr]
        _logger.debug("Resetting connection")
        # While RESETTING, _listen() polls instead of reading and
        # _do_monitor() skips its heartbeat checks.
        self.run_state = FcmPushClientRunState.RESETTING
        await self._do_writer_close()
        now = time.time()
        # Rate-limit reconnects: wait until reset_interval seconds have
        # passed since the last login. NOTE(review): assumes _login() has
        # run at least once (last_login_time starts as None) — confirm.
        time_since_last_login = now - self.last_login_time  # type: ignore[operator]
        if time_since_last_login < self.config.reset_interval:
            _logger.debug("%ss since last reset attempt.", time_since_last_login)
            await asyncio.sleep(self.config.reset_interval - time_since_last_login)
        _logger.debug("Reestablishing connection")
        if not await self._connect_with_retry():
            _logger.error(
                "Unable to connect to MCS endpoint "
                + "after %s tries, shutting down",
                self.config.connection_retry_count,
            )
            self._terminate()
            return
        _logger.debug("Re-connected to ssl socket")
        # A failing login here calls _reset() again; that nested call returns
        # immediately because this coroutine still holds reset_lock.
        await self._login()
# protobuf variable length integers are encoded in base 128
# each byte contains 7 bits of the integer and the msb is set if there's
# more. pretty simple to implement
async def _read_varint32(self) -> int:
res = 0
shift = 0
while True:
r = await self.reader.readexactly(1) # type: ignore[union-attr]
(b,) = struct.unpack("B", r)
res |= (b & 0x7F) << shift
if (b & 0x80) == 0:
break
shift += 7
return res
@staticmethod
def _encode_varint32(x: int) -> bytes:
if x == 0:
return bytes(bytearray([0]))
res = bytearray([])
while x != 0:
b = x & 0x7F
x >>= 7
if x != 0:
b |= 0x80
res.append(b)
return bytes(res)
@staticmethod
def _make_packet(msg: Message, include_version: bool) -> bytes:
    """Frame ``msg`` as an MCS packet.

    Layout: optional MCS_VERSION byte (only when ``include_version``), the
    message's tag byte from MCS_MESSAGE_TAG, the varint32 payload length,
    then the serialized protobuf payload.

    :raises KeyError: if ``type(msg)`` has no tag in MCS_MESSAGE_TAG.
    """
    tag = MCS_MESSAGE_TAG[type(msg)]
    prefix = bytes([MCS_VERSION, tag] if include_version else [tag])
    payload = msg.SerializeToString()
    length = FcmPushClient._encode_varint32(len(payload))
    return b"".join((prefix, length, payload))
async def _send_msg(self, msg: Message) -> None:
    """Frame ``msg``, write it to the server and wait for the buffer to drain.

    The MCS version byte is prepended while ``self.first_message`` is set.
    """
    self._log_verbose("Sending packet to server: %s", self._msg_str(msg))
    packet = FcmPushClient._make_packet(msg, self.first_message)
    writer = self.writer
    writer.write(packet)  # type: ignore[union-attr]
    await writer.drain()  # type: ignore[union-attr]
async def _receive_msg(self) -> Message | None:
    """Read one MCS packet from ``self.reader`` and parse its payload.

    The first packet on a connection starts with a protocol-version byte.
    Every packet then has a one-byte tag, a varint32 payload size and the
    payload. Packets whose tag is unknown, or whose tag has no protobuf class
    configured in MCS_MESSAGE_TAG, are consumed and skipped (None is
    returned), so the stream stays aligned on packet boundaries.

    :raises RuntimeError: if the server announces an unsupported version.
    :raises asyncio.IncompleteReadError: if the stream ends mid-packet.
    """
    if self.first_message:
        r = await self.reader.readexactly(2)  # type: ignore[union-attr]
        version, tag = struct.unpack("BB", r)
        if version < MCS_VERSION and version != 38:
            raise RuntimeError(f"protocol version {version} unsupported")
        self.first_message = False
    else:
        r = await self.reader.readexactly(1)  # type: ignore[union-attr]
        (tag,) = struct.unpack("B", r)
    size = await self._read_varint32()
    self._log_verbose(
        "Received message with tag %s and size %s",
        tag,
        size,
    )
    # Always consume the payload, even for packets we skip, so the next read
    # starts at a packet boundary.
    buf = await self.reader.readexactly(size)  # type: ignore[union-attr]
    # next() with a default: a bare next() raises StopIteration for an
    # unknown tag, which inside a coroutine surfaces as a RuntimeError.
    msg_class = next(
        (cls for cls, t in MCS_MESSAGE_TAG.items() if t == tag), None
    )
    if msg_class is None:
        self._log_warn_with_limit("Unexpected message tag %s", tag)
        return None
    if isinstance(msg_class, str):
        self._log_warn_with_limit("Unconfigured message class %s", msg_class)
        return None
    payload = msg_class()  # type: ignore[operator]
    payload.ParseFromString(buf)
    self._log_verbose("Received payload: %s", self._msg_str(payload))
    return payload
async def _login(self) -> None:
    """Send the MCS LoginRequest for the current connection.

    First resets the per-connection stream state and records the login time.
    A failure to build or send the request is logged and counted as a LOGIN
    error; the connection is then reset unless the error budget is used up.
    """
    self.run_state = FcmPushClientRunState.STARTING_LOGIN
    login_started = time.time()
    # Fresh connection: stream ids restart, and the next packet sent or
    # received carries the MCS version byte again.
    self.input_stream_id = 0
    self.last_input_stream_id_reported = -1
    self.first_message = True
    self.last_login_time = login_started
    try:
        gcm_creds = self.credentials["gcm"]  # type: ignore[index]
        android_id = gcm_creds["android_id"]
        request = LoginRequest()
        request.adaptive_heartbeat = False
        request.auth_service = LoginRequest.ANDROID_ID  # 2
        request.auth_token = gcm_creds["security_token"]
        request.id = self.fcm_config.chrome_version
        request.domain = "mcs.android.com"
        request.device_id = f"android-{int(android_id):x}"
        request.network_type = 1
        request.resource = android_id
        request.user = android_id
        request.use_rmq2 = True
        request.setting.add(name="new_vc", value="1")
        # Report ids we already have; presumably so the server does not
        # redeliver them — verify against the MCS protocol.
        request.received_persistent_id.extend(self.persistent_ids)
        heartbeat_secs = self.config.server_heartbeat_interval
        if heartbeat_secs and heartbeat_secs > 0:
            request.heartbeat_stat.ip = ""
            request.heartbeat_stat.timeout = True
            request.heartbeat_stat.interval_ms = 1000 * heartbeat_secs
        await self._send_msg(request)
        _logger.debug("Sent login request")
    except Exception as ex:  # pylint: disable=broad-except
        _logger.error("Received an exception logging in: %s", ex)
        if self._try_increment_error_count(ErrorType.LOGIN):
            await self._reset()
@staticmethod
def _decrypt_raw_data(
    credentials: dict[str, dict[str, str]],
    crypto_key_str: str,
    salt_str: str,
    raw_data: bytes,
) -> bytes:
    """Decrypt an ``aesgcm``-encoded payload with the client's key material.

    :param credentials: client credentials; ``credentials["keys"]`` holds the
        base64url DER ``private`` key and the ``secret`` auth secret.
    :param crypto_key_str: base64url ``dh`` value (without the ``dh=`` prefix).
    :param salt_str: base64url salt (without the ``salt=`` prefix).
    :param raw_data: encrypted payload bytes.
    :return: the decrypted payload bytes.
    """

    def b64(text: str, pad: bytes = b"") -> bytes:
        return urlsafe_b64decode(text.encode("ascii") + pad)

    # Surplus "=" is appended to the stored keys, presumably because they may
    # be stored without base64 padding — verify.
    padding = b"========"
    crypto_key = b64(crypto_key_str)
    salt = b64(salt_str)
    keys = credentials["keys"]
    der_data = b64(keys["private"], padding)
    secret = b64(keys["secret"], padding)
    private_key = load_der_private_key(
        der_data, password=None, backend=default_backend()
    )
    return http_decrypt(
        raw_data,
        salt=salt,
        private_key=private_key,
        dh=crypto_key,
        version="aesgcm",
        auth_secret=secret,
    )
def _app_data_by_key(
self, p: DataMessageStanza, key: str, do_not_raise: bool = False
) -> str:
for x in p.app_data:
if x.key == key:
return x.value
if do_not_raise:
return ""
raise RuntimeError(f"couldn't find in app_data {key}")
def _handle_data_message(
    self,
    msg: DataMessageStanza,
) -> None:
    """Decrypt a received data message and pass it to the user callback.

    ``deleted_messages`` stanzas carry no payload and are ignored. Without
    credentials the message cannot be decrypted and is dropped with a
    warning. When the decrypted payload parses as JSON the parsed value is
    delivered, otherwise the raw decrypted bytes; non-dict values are
    wrapped as ``{"message": value}``. The callback receives
    ``(data, persistent_id, callback_context)``; exceptions it raises are
    logged and counted as NOTIFY errors.

    :raises RuntimeError: if a required ``app_data`` entry is missing.
    """
    _logger.debug(
        "Received data message Stream ID: %s, Last: %s, Status: %s",
        msg.stream_id,
        msg.last_stream_id_received,
        msg.status,
    )
    if (
        self._app_data_by_key(msg, "message_type", do_not_raise=True)
        == "deleted_messages"
    ):
        # The deleted_messages message does not contain data.
        return
    crypto_key = self._app_data_by_key(msg, "crypto-key")[3:]  # strip dh=
    salt = self._app_data_by_key(msg, "encryption")[5:]  # strip salt=
    subtype = self._app_data_by_key(msg, "subtype")
    # Bail out before indexing self.credentials, which may be None.
    if not self.credentials:
        self._log_warn_with_limit(
            "No credentials available to decrypt message %s", msg.persistent_id
        )
        return
    app_id = self.credentials["gcm"]["app_id"]
    if subtype != app_id:
        self._log_warn_with_limit(
            "Subtype %s in data message does not match "
            + "app id client was registered with %s",
            subtype,
            app_id,
        )
    decrypted = self._decrypt_raw_data(
        self.credentials, crypto_key, salt, msg.raw_data
    )
    decrypted_json: Any = None
    with contextlib_suppress(json.JSONDecodeError, ValueError):
        decrypted_json = json.loads(decrypted.decode("utf-8"))
    # Compare with None so valid-but-falsy JSON ({}, [], 0, "") is delivered
    # instead of being reported as a failure and replaced by raw bytes.
    if decrypted_json is None:
        self._log_warn_with_limit(
            "Failed to decrypt data for message %s", msg.persistent_id
        )
        ret_val: Any = decrypted
    else:
        ret_val = decrypted_json
    self._log_verbose("Data for message %s is: %s", msg.persistent_id, ret_val)
    try:
        if not isinstance(ret_val, dict):
            ret_val = {"message": ret_val}
        self.callback(ret_val, msg.persistent_id, self.callback_context)
        self._reset_error_count(ErrorType.NOTIFY)
    except Exception:  # pylint: disable=broad-except
        _logger.exception("Unexpected exception calling notification callback\n")
        self._try_increment_error_count(ErrorType.NOTIFY)
def _new_input_stream_id_available(self) -> bool:
return self.last_input_stream_id_reported != self.input_stream_id
def _get_input_stream_id(self) -> int:
self.last_input_stream_id_reported = self.input_stream_id
return self.input_stream_id
async def _handle_ping(self, p: HeartbeatPing) -> None:
    """Reply to a server HeartbeatPing with a HeartbeatAck.

    The ack carries the latest input stream id only if it advanced since it
    was last reported.
    """
    _logger.debug(
        "Received heartbeat ping, sending ack: Stream ID: %s, Last: %s, Status: %s",
        p.stream_id,
        p.last_stream_id_received,
        p.status,
    )
    ack = HeartbeatAck()
    if self._new_input_stream_id_available():
        ack.last_stream_id_received = self._get_input_stream_id()
    await self._send_msg(ack)
async def _handle_iq(self, p: IqStanza) -> None:
if not p.extension:
self._log_warn_with_limit(
"Unexpected IqStanza id received with no extension", str(p)
)
return
if p.extension.id not in (12, 13):
self._log_warn_with_limit(
"Unexpected extension id received: %s", p.extension.id
)
return
async def _send_selective_ack(self, persistent_id: str) -> None:
    """Acknowledge one received message to the server.

    Wraps a SelectiveAck for ``persistent_id`` in an IqStanza (type SET,
    extension id MCS_SELECTIVE_ACK_ID) and sends it.
    """
    selective_ack = SelectiveAck()
    selective_ack.id.extend([persistent_id])
    stanza = IqStanza()
    stanza.type = IqStanza.IqType.SET
    stanza.id = ""
    stanza.extension.id = MCS_SELECTIVE_ACK_ID
    stanza.extension.data = selective_ack.SerializeToString()
    _logger.debug("Sending selective ack for message id %s", persistent_id)
    await self._send_msg(stanza)
async def _send_heartbeat(self) -> None:
    """Send a client HeartbeatPing to the server.

    The ping carries the latest input stream id only if it advanced since it
    was last reported.
    """
    ping = HeartbeatPing()
    if self._new_input_stream_id_available():
        ping.last_stream_id_received = self._get_input_stream_id()
    await self._send_msg(ping)
    _logger.debug("Sent heartbeat ping")
def _terminate(self) -> None:
    """Stop listening and cancel the client's other unfinished tasks.

    The task calling this (if it is one of ``self.tasks``) is not cancelled,
    so it can finish its own cleanup. Must be called with a running loop.
    """
    self.run_state = FcmPushClientRunState.STOPPING
    self.do_listen = False
    caller = asyncio.current_task()
    pending = [t for t in self.tasks if t is not caller and not t.done()]
    for task in pending:
        task.cancel()
async def _do_monitor(self) -> None:
    """Periodically check connection liveness and reset it when stale.

    Wakes every ``config.monitor_interval`` seconds while listening, but only
    acts in the STARTED state. ``last_message_time`` (set by _handle_message
    for every received message) is the liveness signal.
    NOTE(review): exceptions from _send_heartbeat() or _reset() are not caught
    here and would end this task.
    """
    while self.do_listen:
        await asyncio.sleep(self.config.monitor_interval)
        if self.run_state == FcmPushClientRunState.STARTED:
            # if server_heartbeat_interval is set and less than
            # client_heartbeat_interval then the last_message_time
            # will be within the client window if connected
            if self.config.client_heartbeat_interval:
                now = time.time()
                # Nothing received for a whole client interval: ping the
                # server and give it heartbeat_ack_timeout seconds to reply.
                if (
                    self.last_message_time + self.config.client_heartbeat_interval  # type: ignore[operator]
                    < now
                ):
                    await self._send_heartbeat()
                    await asyncio.sleep(self.config.heartbeat_ack_timeout)
                    now = time.time()
                    # Still nothing received (an ack would have updated
                    # last_message_time): reset the connection.
                    if (  # Check state hasn't changed during sleep
                        self.last_message_time  # type: ignore[operator]
                        + self.config.client_heartbeat_interval
                        < now
                        and self.do_listen
                        and self.run_state == FcmPushClientRunState.STARTED
                    ):
                        await self._reset()
            elif self.config.server_heartbeat_interval:
                now = time.time()
                # Only server heartbeats configured: reset if none arrived
                # within the requested interval.
                if (  # We give the server 2 extra seconds
                    self.last_message_time + self.config.server_heartbeat_interval  # type: ignore[operator]
                    < now - 2
                ):
                    await self._reset()
def _reset_error_count(self, error_type: ErrorType) -> None:
self.sequential_error_counters[error_type] = 0
def _try_increment_error_count(self, error_type: ErrorType) -> bool:
if error_type not in self.sequential_error_counters:
self.sequential_error_counters[error_type] = 0
self.sequential_error_counters[error_type] += 1
if (
self.config.abort_on_sequential_error_count
and self.sequential_error_counters[error_type]
>= self.config.abort_on_sequential_error_count
):
_logger.error(
"Shutting down push receiver due to "
+ f"{self.sequential_error_counters[error_type]} sequential"
+ f" errors of type {error_type}"
)
self._terminate()
return False
return True
async def _handle_message(self, msg: Message) -> None:
    """Dispatch one parsed MCS message received from the server."""
    # Every received message counts as liveness for _do_monitor and advances
    # the stream id reported back in heartbeats and acks.
    self.last_message_time = time.time()
    self.input_stream_id += 1
    if isinstance(msg, Close):
        self._log_warn_with_limit("Server sent Close message, resetting")
        if self._try_increment_error_count(ErrorType.CONNECTION):
            await self._reset()
        return
    if isinstance(msg, LoginResponse):
        # NOTE(review): relies on str() of an unset error sub-message being
        # empty — confirm with the protobuf runtime in use.
        if str(msg.error):
            _logger.error("Received login error response: %s", msg)
            if self._try_increment_error_count(ErrorType.LOGIN):
                await self._reset()
        else:
            _logger.info("Successfully logged in to MCS endpoint")
            self._reset_error_count(ErrorType.LOGIN)
            self.run_state = FcmPushClientRunState.STARTED
            # These ids were sent in the LoginRequest; start a fresh list.
            self.persistent_ids = []
        return
    if isinstance(msg, DataMessageStanza):
        # Exceptions from _handle_data_message (other than callback errors)
        # propagate to the caller.
        self._handle_data_message(msg)
        self.persistent_ids.append(msg.persistent_id)
        if self.config.send_selective_acknowledgements:
            await self._send_selective_ack(msg.persistent_id)
    elif isinstance(msg, HeartbeatPing):
        await self._handle_ping(msg)
    elif isinstance(msg, HeartbeatAck):
        _logger.debug("Received heartbeat ack: %s", msg)
    elif isinstance(msg, IqStanza):
        pass
    else:
        self._log_warn_with_limit("Unexpected message type %s.", type(msg).__name__)
    # Reset error count if a read has been successful
    # (not reached for Close / LoginResponse, which return above).
    self._reset_error_count(ErrorType.READ)
    self._reset_error_count(ErrorType.CONNECTION)
@staticmethod
async def _open_connection(
    host: str, port: int, ssl_context: ssl.SSLContext
) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
    """Open a TLS stream to ``host:port`` and return ``(reader, writer)``.

    Kept as a separate method, presumably so tests can patch it.
    """
    reader_writer = await asyncio.open_connection(host, port, ssl=ssl_context)
    return reader_writer
async def _connect(self) -> bool:
    """Open the TLS connection to the MCS endpoint.

    On success stores the new stream pair in ``self.reader``/``self.writer``
    and returns True. An OSError is logged and False is returned.
    """
    try:
        running_loop = asyncio.get_running_loop()
        # create_default_context() blocks the event loop, so build it in the
        # default executor.
        context = await running_loop.run_in_executor(
            None, ssl.create_default_context
        )
        self.reader, self.writer = await self._open_connection(
            host=MCS_HOST, port=MCS_PORT, ssl_context=context
        )
    except OSError as err:
        _logger.error(
            "Could not connected to MCS endpoint (%s,%s): %s",
            MCS_HOST,
            MCS_PORT,
            err,
        )
        return False
    _logger.debug("Connected to MCS endpoint (%s,%s)", MCS_HOST, MCS_PORT)
    return True
async def _connect_with_retry(self) -> bool:
    """Try to open the MCS connection, retrying with quadratic backoff.

    Makes up to ``config.connection_retry_count`` attempts. After failed
    attempt *n* it sleeps ``start_seconds_before_retry_connect * n * n``
    seconds, but not after the final attempt. Stops early if ``do_listen``
    is cleared.

    :return: True if a connection was established.
    """
    self.run_state = FcmPushClientRunState.STARTING_CONNECTION
    trycount = 0
    connected = False
    while (
        trycount < self.config.connection_retry_count
        and not connected
        and self.do_listen
    ):
        trycount += 1
        connected = await self._connect()
        if connected or trycount >= self.config.connection_retry_count:
            # Success, or no attempts left: don't sleep before returning.
            break
        sleep_time = (
            self.config.start_seconds_before_retry_connect * trycount * trycount
        )
        _logger.info(
            "Could not connect to MCS Endpoint on "
            + "try %s, sleeping for %s seconds",
            trycount,
            sleep_time,
        )
        await asyncio.sleep(sleep_time)
    if not connected:
        _logger.error(
            "Unable to connect to MCS endpoint after %s tries, aborting", trycount
        )
    return connected
async def _listen(self) -> None:
    """Listen for push notifications.

    Runs as one of the tasks created by start(). It connects (with retries),
    logs in, then reads and dispatches messages until ``do_listen`` is
    cleared. The writer is closed however the loop exits, including when the
    task is cancelled.
    """
    if not await self._connect_with_retry():
        return
    try:
        await self._login()
        while self.do_listen:
            try:
                if self.run_state == FcmPushClientRunState.RESETTING:
                    # _reset() is replacing the connection; poll until done.
                    await asyncio.sleep(1)
                elif msg := await self._receive_msg():
                    # _receive_msg() returns None for skipped packets.
                    await self._handle_message(msg)
            except (OSError, EOFError) as osex:
                # While a reset is in progress the old connection was closed
                # under us, so these read errors are expected.
                if (
                    isinstance(
                        osex,
                        (
                            ConnectionResetError,
                            TimeoutError,
                            asyncio.IncompleteReadError,
                            ssl.SSLError,
                        ),
                    )
                    and self.run_state == FcmPushClientRunState.RESETTING
                ):
                    if (
                        isinstance(osex, ssl.SSLError)  # pylint: disable=no-member
                        and osex.reason != "APPLICATION_DATA_AFTER_CLOSE_NOTIFY"
                    ):
                        self._log_warn_with_limit(
                            "Unexpected SSLError reason during reset of %s",
                            osex.reason,
                        )
                    else:
                        self._log_verbose(
                            "Expected read error during reset: %s",
                            type(osex).__name__,
                        )
                else:
                    _logger.exception("Unexpected exception during read\n")
                    if self._try_increment_error_count(ErrorType.CONNECTION):
                        await self._reset()
    except Exception as ex:
        # Any other error (e.g. raised by _handle_message, or an unsupported
        # protocol version from _receive_msg) stops the whole client.
        _logger.error(
            "Unknown error: %s, shutting down FcmPushClient.\n%s",
            ex,
            traceback.format_exc(),
        )
        self._terminate()
    finally:
        await self._do_writer_close()
async def checkin_or_register(self) -> str:
    """Check in if you have credentials otherwise register as a new client.

    Uses the ``fcm_config``, ``credentials``, ``credentials_updated_callback``
    and ``http_client_session`` given to the constructor, and stores the
    resulting credentials on ``self.credentials``.

    :return: The FCM token which is used to identify you with the push end
        point application.
    """
    self.register = FcmRegister(
        self.fcm_config,
        self.credentials,
        self.credentials_updated_callback,
        http_client_session=self._http_client_session,
    )
    try:
        self.credentials = await self.register.checkin_or_register()
    finally:
        # Close the registration helper even when check-in/registration
        # fails (presumably releases its HTTP resources — verify).
        await self.register.close()
    return self.credentials["fcm"]["registration"]["token"]
async def start(self) -> None:
    """Connect to FCM and start listening for push notifications.

    Creates the locks used by reset and stop, then schedules the listener
    and the monitor as background tasks. Returns without waiting for the
    connection to be established.
    """
    self.reset_lock = asyncio.Lock()
    self.stopping_lock = asyncio.Lock()
    self.do_listen = True
    self.run_state = FcmPushClientRunState.STARTING_TASKS
    try:
        self.tasks = [
            asyncio.create_task(self._listen()),
            asyncio.create_task(self._do_monitor()),
        ]
    except Exception as ex:  # pylint: disable=broad-except
        _logger.error("Unexpected error running FcmPushClient: %s", ex)
async def stop(self) -> None:
    """Stop listening and cancel the client's background tasks.

    Returns immediately if another stop() is in progress or the client is
    already STOPPED. Safe to call before :meth:`start`, and after the client
    terminated itself. The cancelled tasks are not awaited here.
    """
    # A concurrent stop() is detected via the lock. STOPPING alone is not a
    # reason to return: _terminate() sets it, and stop() must still be able
    # to finish the transition to STOPPED.
    if (
        self.stopping_lock and self.stopping_lock.locked()
    ) or self.run_state == FcmPushClientRunState.STOPPED:
        return
    if self.stopping_lock is None:
        # start() was never called, so there is no lock yet.
        self.stopping_lock = asyncio.Lock()
    async with self.stopping_lock:
        try:
            self.run_state = FcmPushClientRunState.STOPPING
            self.do_listen = False
            for task in self.tasks:
                if not task.done():
                    task.cancel()
        finally:
            self.run_state = FcmPushClientRunState.STOPPED
            self.fcm_thread = None
            self.listen_event_loop = None
def is_started(self) -> bool:
    """Return True while in the STARTED state.

    That state begins with a successful MCS login and ends with a reset or
    stop.
    """
    return self.run_state == FcmPushClientRunState.STARTED
async def send_message(self, raw_data: bytes, persistent_id: str) -> None:
    """Not implemented, does nothing atm.

    A DataMessageStanza is built but never sent; ``raw_data`` is ignored.
    """
    dms = DataMessageStanza()
    dms.persistent_id = persistent_id
    # Not supported yet